diff --git a/synapse/__init__.py b/synapse/__init__.py
index a14d578e36..e62901b761 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -17,4 +17,4 @@
""" This is a reference implementation of a Matrix home server.
"""
-__version__ = "0.33.2"
+__version__ = "0.33.3"
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 5bbbe8e2e7..6502a6be7b 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -25,7 +25,7 @@ from twisted.internet import defer
import synapse.types
from synapse import event_auth
from synapse.api.constants import EventTypes, JoinRules, Membership
-from synapse.api.errors import AuthError, Codes
+from synapse.api.errors import AuthError, Codes, ResourceLimitError
from synapse.types import UserID
from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache
from synapse.util.caches.lrucache import LruCache
@@ -211,9 +211,9 @@ class Auth(object):
user_agent = request.requestHeaders.getRawHeaders(
b"User-Agent",
default=[b""]
- )[0]
+ )[0].decode('ascii', 'surrogateescape')
if user and access_token and ip_addr:
- self.store.insert_client_ip(
+ yield self.store.insert_client_ip(
user_id=user.to_string(),
access_token=access_token,
ip=ip_addr,
@@ -682,7 +682,7 @@ class Auth(object):
Returns:
bool: False if no access_token was given, True otherwise.
"""
- query_params = request.args.get("access_token")
+ query_params = request.args.get(b"access_token")
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
return bool(query_params) or bool(auth_headers)
@@ -698,7 +698,7 @@ class Auth(object):
401 since some of the old clients depended on auth errors returning
403.
Returns:
- str: The access_token
+ unicode: The access_token
Raises:
AuthError: If there isn't an access_token in the request.
"""
@@ -720,9 +720,9 @@ class Auth(object):
"Too many Authorization headers.",
errcode=Codes.MISSING_TOKEN,
)
- parts = auth_headers[0].split(" ")
- if parts[0] == "Bearer" and len(parts) == 2:
- return parts[1]
+ parts = auth_headers[0].split(b" ")
+ if parts[0] == b"Bearer" and len(parts) == 2:
+ return parts[1].decode('ascii')
else:
raise AuthError(
token_not_found_http_status,
@@ -738,7 +738,7 @@ class Auth(object):
errcode=Codes.MISSING_TOKEN
)
- return query_params[0]
+ return query_params[0].decode('ascii')
@defer.inlineCallbacks
def check_in_room_or_world_readable(self, room_id, user_id):
@@ -773,3 +773,35 @@ class Auth(object):
raise AuthError(
403, "Guest access not allowed", errcode=Codes.GUEST_ACCESS_FORBIDDEN
)
+
+ @defer.inlineCallbacks
+ def check_auth_blocking(self, user_id=None):
+ """Checks if the user should be rejected for some external reason,
+ such as monthly active user limiting or global disable flag
+
+ Args:
+ user_id(str|None): If present, checks for presence against existing
+ MAU cohort
+ """
+ if self.hs.config.hs_disabled:
+ raise ResourceLimitError(
+ 403, self.hs.config.hs_disabled_message,
+ errcode=Codes.RESOURCE_LIMIT_EXCEED,
+ admin_uri=self.hs.config.admin_uri,
+ limit_type=self.hs.config.hs_disabled_limit_type
+ )
+ if self.hs.config.limit_usage_by_mau is True:
+ # If the user is already part of the MAU cohort
+ if user_id:
+ timestamp = yield self.store.user_last_seen_monthly_active(user_id)
+ if timestamp:
+ return
+ # Else if there is no room in the MAU bucket, bail
+ current_mau = yield self.store.get_monthly_active_count()
+ if current_mau >= self.hs.config.max_mau_value:
+ raise ResourceLimitError(
+                    403, "Monthly Active User Limit Exceeded",
+                    admin_uri=self.hs.config.admin_uri,
+ errcode=Codes.RESOURCE_LIMIT_EXCEED,
+ limit_type="monthly_active_user"
+ )
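Note: check_auth_blocking is now the single gate for both the hs_disabled
kill switch and MAU limiting. A minimal sketch of a caller, with the auth
object and user id assumed:

    from twisted.internet import defer

    @defer.inlineCallbacks
    def guarded_action(auth, user_id):
        # Raises ResourceLimitError if hs_disabled is set, or if the MAU
        # cap is reached and user_id is not already in this month's cohort.
        yield auth.check_auth_blocking(user_id)
        # ... proceed with login / registration / etc ...
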
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index 4df930c8d1..b0da506f6d 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2017 Vector Creations Ltd
+# Copyright 2018 New Vector Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -94,3 +95,11 @@ class RoomCreationPreset(object):
class ThirdPartyEntityKind(object):
USER = "user"
LOCATION = "location"
+
+
+# the version we will give rooms which are created on this server
+DEFAULT_ROOM_VERSION = "1"
+
+# vdh-test-version is a placeholder to get room versioning support working and tested
+# until we have a working v2.
+KNOWN_ROOM_VERSIONS = {"1", "vdh-test-version"}
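Room versions are opaque strings checked by set membership, not ordered
numbers; a minimal illustration:

    from synapse.api.constants import DEFAULT_ROOM_VERSION, KNOWN_ROOM_VERSIONS

    def is_supported_version(room_version):
        # "vdh-test-version" is neither greater nor less than "1"; the only
        # meaningful test is membership in the known set.
        return room_version in KNOWN_ROOM_VERSIONS

    assert is_supported_version(DEFAULT_ROOM_VERSION)
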
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index b41d595059..e26001ab12 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -55,7 +56,9 @@ class Codes(object):
SERVER_NOT_TRUSTED = "M_SERVER_NOT_TRUSTED"
CONSENT_NOT_GIVEN = "M_CONSENT_NOT_GIVEN"
CANNOT_LEAVE_SERVER_NOTICE_ROOM = "M_CANNOT_LEAVE_SERVER_NOTICE_ROOM"
- MAU_LIMIT_EXCEEDED = "M_MAU_LIMIT_EXCEEDED"
+ RESOURCE_LIMIT_EXCEED = "M_RESOURCE_LIMIT_EXCEED"
+ UNSUPPORTED_ROOM_VERSION = "M_UNSUPPORTED_ROOM_VERSION"
+ INCOMPATIBLE_ROOM_VERSION = "M_INCOMPATIBLE_ROOM_VERSION"
class CodeMessageException(RuntimeError):
@@ -228,6 +231,30 @@ class AuthError(SynapseError):
super(AuthError, self).__init__(*args, **kwargs)
+class ResourceLimitError(SynapseError):
+ """
+ Any error raised when there is a problem with resource usage.
+    For instance, the monthly active user limit for the server has been exceeded.
+ """
+ def __init__(
+ self, code, msg,
+ errcode=Codes.RESOURCE_LIMIT_EXCEED,
+ admin_uri=None,
+ limit_type=None,
+ ):
+ self.admin_uri = admin_uri
+ self.limit_type = limit_type
+ super(ResourceLimitError, self).__init__(code, msg, errcode=errcode)
+
+ def error_dict(self):
+ return cs_error(
+ self.msg,
+ self.errcode,
+ admin_uri=self.admin_uri,
+ limit_type=self.limit_type
+ )
+
+
class EventSizeError(SynapseError):
"""An error raised when an event is too big."""
@@ -285,6 +312,27 @@ class LimitExceededError(SynapseError):
)
+class IncompatibleRoomVersionError(SynapseError):
+ """A server is trying to join a room whose version it does not support."""
+
+ def __init__(self, room_version):
+ super(IncompatibleRoomVersionError, self).__init__(
+ code=400,
+ msg="Your homeserver does not support the features required to "
+ "join this room",
+ errcode=Codes.INCOMPATIBLE_ROOM_VERSION,
+ )
+
+ self._room_version = room_version
+
+ def error_dict(self):
+ return cs_error(
+ self.msg,
+ self.errcode,
+ room_version=self._room_version,
+ )
+
+
def cs_error(msg, code=Codes.UNKNOWN, **kwargs):
""" Utility method for constructing an error response for client-server
interactions.
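Since error_dict() feeds cs_error(), a client blocked by the MAU check sees
a body like the following (the admin_uri value is illustrative, taken from
the sample config, not a default):

    {
        "error": "Monthly Active User Limit Exceeded",
        "errcode": "M_RESOURCE_LIMIT_EXCEED",
        "admin_uri": "mailto:admin@server.com",
        "limit_type": "monthly_active_user",
    }
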
diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py
index 06cc8d90b8..3bb5b3da37 100644
--- a/synapse/api/ratelimiting.py
+++ b/synapse/api/ratelimiting.py
@@ -72,7 +72,7 @@ class Ratelimiter(object):
return allowed, time_allowed
def prune_message_counts(self, time_now_s):
- for user_id in self.message_counts.keys():
+ for user_id in list(self.message_counts.keys()):
message_count, time_start, msg_rate_hz = (
self.message_counts[user_id]
)
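The list() copy matters on Python 3, where dict.keys() is a live view and
deleting entries while iterating over it raises RuntimeError; a minimal
illustration:

    counts = {"@a:hs": 0, "@b:hs": 2}
    for user_id in list(counts.keys()):  # iterate over a snapshot
        if counts[user_id] == 0:
            del counts[user_id]  # safe: the live view is not being iterated
    # Without list(), Python 3 raises:
    # RuntimeError: dictionary changed size during iteration
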
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index 391bd14c5c..7c866e246a 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -140,7 +140,7 @@ def listen_metrics(bind_addresses, port):
logger.info("Metrics now reporting on %s:%d", host, port)
-def listen_tcp(bind_addresses, port, factory, backlog=50):
+def listen_tcp(bind_addresses, port, factory, reactor=reactor, backlog=50):
"""
Create a TCP socket for a port and several addresses
"""
@@ -156,7 +156,9 @@ def listen_tcp(bind_addresses, port, factory, backlog=50):
check_bind_error(e, address, bind_addresses)
-def listen_ssl(bind_addresses, port, factory, context_factory, backlog=50):
+def listen_ssl(
+ bind_addresses, port, factory, context_factory, reactor=reactor, backlog=50
+):
"""
Create an SSL socket for a port and several addresses
"""
diff --git a/synapse/app/appservice.py b/synapse/app/appservice.py
index 9a37384fb7..3348a8ec6d 100644
--- a/synapse/app/appservice.py
+++ b/synapse/app/appservice.py
@@ -117,8 +117,9 @@ class ASReplicationHandler(ReplicationClientHandler):
super(ASReplicationHandler, self).__init__(hs.get_datastore())
self.appservice_handler = hs.get_application_service_handler()
+ @defer.inlineCallbacks
def on_rdata(self, stream_name, token, rows):
- super(ASReplicationHandler, self).on_rdata(stream_name, token, rows)
+ yield super(ASReplicationHandler, self).on_rdata(stream_name, token, rows)
if stream_name == "events":
max_stream_id = self.store.get_room_max_stream_ordering()
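The same change is applied to every worker replication handler below:
on_rdata becomes a coroutine so that the Deferred now returned by the base
class is waited on (and its failures propagate) rather than being dropped.
The generalized pattern, with a hypothetical class name:

    @defer.inlineCallbacks
    def on_rdata(self, stream_name, token, rows):
        # Let the base class process the batch before doing any
        # worker-specific handling.
        yield super(MyWorkerReplicationHandler, self).on_rdata(
            stream_name, token, rows,
        )
        # ... worker-specific processing of `rows` ...
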
diff --git a/synapse/app/client_reader.py b/synapse/app/client_reader.py
index e2c91123db..ab79a45646 100644
--- a/synapse/app/client_reader.py
+++ b/synapse/app/client_reader.py
@@ -39,7 +39,7 @@ from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.replication.slave.storage.keys import SlavedKeyStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
from synapse.replication.slave.storage.room import RoomStore
-from synapse.replication.slave.storage.transactions import TransactionStore
+from synapse.replication.slave.storage.transactions import SlavedTransactionStore
from synapse.replication.tcp.client import ReplicationClientHandler
from synapse.rest.client.v1.room import (
JoinedRoomMemberListRestServlet,
@@ -66,7 +66,7 @@ class ClientReaderSlavedStore(
DirectoryStore,
SlavedApplicationServiceStore,
SlavedRegistrationStore,
- TransactionStore,
+ SlavedTransactionStore,
SlavedClientIpStore,
BaseSlavedStore,
):
@@ -168,11 +168,13 @@ def start(config_options):
database_engine = create_engine(config.database_config)
tls_server_context_factory = context_factory.ServerContextFactory(config)
+ tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config)
ss = ClientReaderServer(
config.server_name,
db_config=config.database_config,
tls_server_context_factory=tls_server_context_factory,
+ tls_client_options_factory=tls_client_options_factory,
config=config,
version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine,
diff --git a/synapse/app/event_creator.py b/synapse/app/event_creator.py
index 374f115644..03d39968a8 100644
--- a/synapse/app/event_creator.py
+++ b/synapse/app/event_creator.py
@@ -43,7 +43,7 @@ from synapse.replication.slave.storage.pushers import SlavedPusherStore
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
from synapse.replication.slave.storage.room import RoomStore
-from synapse.replication.slave.storage.transactions import TransactionStore
+from synapse.replication.slave.storage.transactions import SlavedTransactionStore
from synapse.replication.tcp.client import ReplicationClientHandler
from synapse.rest.client.v1.room import (
JoinRoomAliasServlet,
@@ -63,7 +63,7 @@ logger = logging.getLogger("synapse.app.event_creator")
class EventCreatorSlavedStore(
DirectoryStore,
- TransactionStore,
+ SlavedTransactionStore,
SlavedProfileStore,
SlavedAccountDataStore,
SlavedPusherStore,
@@ -174,11 +174,13 @@ def start(config_options):
database_engine = create_engine(config.database_config)
tls_server_context_factory = context_factory.ServerContextFactory(config)
+ tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config)
ss = EventCreatorServer(
config.server_name,
db_config=config.database_config,
tls_server_context_factory=tls_server_context_factory,
+ tls_client_options_factory=tls_client_options_factory,
config=config,
version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine,
diff --git a/synapse/app/federation_reader.py b/synapse/app/federation_reader.py
index 7af00b8bcf..7d8105778d 100644
--- a/synapse/app/federation_reader.py
+++ b/synapse/app/federation_reader.py
@@ -32,11 +32,17 @@ from synapse.http.site import SynapseSite
from synapse.metrics import RegistryProxy
from synapse.metrics.resource import METRICS_PREFIX, MetricsResource
from synapse.replication.slave.storage._base import BaseSlavedStore
+from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
+from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
from synapse.replication.slave.storage.directory import DirectoryStore
from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.replication.slave.storage.keys import SlavedKeyStore
+from synapse.replication.slave.storage.profile import SlavedProfileStore
+from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore
+from synapse.replication.slave.storage.pushers import SlavedPusherStore
+from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
from synapse.replication.slave.storage.room import RoomStore
-from synapse.replication.slave.storage.transactions import TransactionStore
+from synapse.replication.slave.storage.transactions import SlavedTransactionStore
from synapse.replication.tcp.client import ReplicationClientHandler
from synapse.server import HomeServer
from synapse.storage.engines import create_engine
@@ -49,11 +55,17 @@ logger = logging.getLogger("synapse.app.federation_reader")
class FederationReaderSlavedStore(
+ SlavedAccountDataStore,
+ SlavedProfileStore,
+ SlavedApplicationServiceStore,
+ SlavedPusherStore,
+ SlavedPushRuleStore,
+ SlavedReceiptsStore,
SlavedEventStore,
SlavedKeyStore,
RoomStore,
DirectoryStore,
- TransactionStore,
+ SlavedTransactionStore,
BaseSlavedStore,
):
pass
@@ -143,11 +155,13 @@ def start(config_options):
database_engine = create_engine(config.database_config)
tls_server_context_factory = context_factory.ServerContextFactory(config)
+ tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config)
ss = FederationReaderServer(
config.server_name,
db_config=config.database_config,
tls_server_context_factory=tls_server_context_factory,
+ tls_client_options_factory=tls_client_options_factory,
config=config,
version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine,
diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py
index 18469013fa..d59007099b 100644
--- a/synapse/app/federation_sender.py
+++ b/synapse/app/federation_sender.py
@@ -36,11 +36,11 @@ from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.replication.slave.storage.presence import SlavedPresenceStore
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
-from synapse.replication.slave.storage.transactions import TransactionStore
+from synapse.replication.slave.storage.transactions import SlavedTransactionStore
from synapse.replication.tcp.client import ReplicationClientHandler
from synapse.server import HomeServer
from synapse.storage.engines import create_engine
-from synapse.util.async import Linearizer
+from synapse.util.async_helpers import Linearizer
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.logcontext import LoggingContext, run_in_background
from synapse.util.manhole import manhole
@@ -50,7 +50,7 @@ logger = logging.getLogger("synapse.app.federation_sender")
class FederationSenderSlaveStore(
- SlavedDeviceInboxStore, TransactionStore, SlavedReceiptsStore, SlavedEventStore,
+ SlavedDeviceInboxStore, SlavedTransactionStore, SlavedReceiptsStore, SlavedEventStore,
SlavedRegistrationStore, SlavedDeviceStore, SlavedPresenceStore,
):
def __init__(self, db_conn, hs):
@@ -144,8 +144,9 @@ class FederationSenderReplicationHandler(ReplicationClientHandler):
super(FederationSenderReplicationHandler, self).__init__(hs.get_datastore())
self.send_handler = FederationSenderHandler(hs, self)
+ @defer.inlineCallbacks
def on_rdata(self, stream_name, token, rows):
- super(FederationSenderReplicationHandler, self).on_rdata(
+ yield super(FederationSenderReplicationHandler, self).on_rdata(
stream_name, token, rows
)
self.send_handler.process_replication_rows(stream_name, token, rows)
@@ -186,11 +187,13 @@ def start(config_options):
config.send_federation = True
tls_server_context_factory = context_factory.ServerContextFactory(config)
+ tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config)
ps = FederationSenderServer(
config.server_name,
db_config=config.database_config,
tls_server_context_factory=tls_server_context_factory,
+ tls_client_options_factory=tls_client_options_factory,
config=config,
version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine,
diff --git a/synapse/app/frontend_proxy.py b/synapse/app/frontend_proxy.py
index b5f78f4640..8d484c1cd4 100644
--- a/synapse/app/frontend_proxy.py
+++ b/synapse/app/frontend_proxy.py
@@ -38,6 +38,7 @@ from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
from synapse.replication.slave.storage.devices import SlavedDeviceStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
from synapse.replication.tcp.client import ReplicationClientHandler
+from synapse.rest.client.v1.base import ClientV1RestServlet, client_path_patterns
from synapse.rest.client.v2_alpha._base import client_v2_patterns
from synapse.server import HomeServer
from synapse.storage.engines import create_engine
@@ -49,6 +50,35 @@ from synapse.util.versionstring import get_version_string
logger = logging.getLogger("synapse.app.frontend_proxy")
+class PresenceStatusStubServlet(ClientV1RestServlet):
+ PATTERNS = client_path_patterns("/presence/(?P<user_id>[^/]*)/status")
+
+ def __init__(self, hs):
+ super(PresenceStatusStubServlet, self).__init__(hs)
+ self.http_client = hs.get_simple_http_client()
+ self.auth = hs.get_auth()
+ self.main_uri = hs.config.worker_main_http_uri
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, user_id):
+ # Pass through the auth headers, if any, in case the access token
+ # is there.
+ auth_headers = request.requestHeaders.getRawHeaders("Authorization", [])
+ headers = {
+ "Authorization": auth_headers,
+ }
+ result = yield self.http_client.get_json(
+ self.main_uri + request.uri,
+ headers=headers,
+ )
+ defer.returnValue((200, result))
+
+ @defer.inlineCallbacks
+ def on_PUT(self, request, user_id):
+ yield self.auth.get_user_by_req(request)
+ defer.returnValue((200, {}))
+
+
class KeyUploadServlet(RestServlet):
PATTERNS = client_v2_patterns("/keys/upload(/(?P<device_id>[^/]+))?$")
@@ -135,6 +165,12 @@ class FrontendProxyServer(HomeServer):
elif name == "client":
resource = JsonResource(self, canonical_json=False)
KeyUploadServlet(self).register(resource)
+
+ # If presence is disabled, use the stub servlet that does
+ # not allow sending presence
+ if not self.config.use_presence:
+ PresenceStatusStubServlet(self).register(resource)
+
resources.update({
"/_matrix/client/r0": resource,
"/_matrix/client/unstable": resource,
@@ -153,7 +189,8 @@ class FrontendProxyServer(HomeServer):
listener_config,
root_resource,
self.version_string,
- )
+ ),
+ reactor=self.get_reactor()
)
logger.info("Synapse client reader now listening on port %d", port)
@@ -208,11 +245,13 @@ def start(config_options):
database_engine = create_engine(config.database_config)
tls_server_context_factory = context_factory.ServerContextFactory(config)
+ tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config)
ss = FrontendProxyServer(
config.server_name,
db_config=config.database_config,
tls_server_context_factory=tls_server_context_factory,
+ tls_client_options_factory=tls_client_options_factory,
config=config,
version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine,
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index fba51c26e8..005921dcf7 100755
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -303,8 +303,8 @@ class SynapseHomeServer(HomeServer):
# Gauges to expose monthly active user control metrics
-current_mau_gauge = Gauge("synapse_admin_current_mau", "Current MAU")
-max_mau_value_gauge = Gauge("synapse_admin_max_mau_value", "MAU Limit")
+current_mau_gauge = Gauge("synapse_admin_mau:current", "Current MAU")
+max_mau_gauge = Gauge("synapse_admin_mau:max", "MAU Limit")
def setup(config_options):
@@ -338,6 +338,7 @@ def setup(config_options):
events.USE_FROZEN_DICTS = config.use_frozen_dicts
tls_server_context_factory = context_factory.ServerContextFactory(config)
+ tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config)
database_engine = create_engine(config.database_config)
config.database_config["args"]["cp_openfun"] = database_engine.on_new_connection
@@ -346,6 +347,7 @@ def setup(config_options):
config.server_name,
db_config=config.database_config,
tls_server_context_factory=tls_server_context_factory,
+ tls_client_options_factory=tls_client_options_factory,
config=config,
version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine,
@@ -519,17 +521,27 @@ def run(hs):
# table will decrease
clock.looping_call(generate_user_daily_visit_stats, 5 * 60 * 1000)
+ # monthly active user limiting functionality
+ clock.looping_call(
+ hs.get_datastore().reap_monthly_active_users, 1000 * 60 * 60
+ )
+ hs.get_datastore().reap_monthly_active_users()
+
@defer.inlineCallbacks
def generate_monthly_active_users():
count = 0
if hs.config.limit_usage_by_mau:
- count = yield hs.get_datastore().count_monthly_users()
+ count = yield hs.get_datastore().get_monthly_active_count()
current_mau_gauge.set(float(count))
- max_mau_value_gauge.set(float(hs.config.max_mau_value))
+ max_mau_gauge.set(float(hs.config.max_mau_value))
+ hs.get_datastore().initialise_reserved_users(
+ hs.config.mau_limits_reserved_threepids
+ )
generate_monthly_active_users()
if hs.config.limit_usage_by_mau:
clock.looping_call(generate_monthly_active_users, 5 * 60 * 1000)
+ # End of monthly active user settings
if hs.config.report_stats:
logger.info("Scheduling stats reporting for 3 hour intervals")
diff --git a/synapse/app/media_repository.py b/synapse/app/media_repository.py
index 749bbf37d0..fd1f6cbf7e 100644
--- a/synapse/app/media_repository.py
+++ b/synapse/app/media_repository.py
@@ -34,7 +34,7 @@ from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
-from synapse.replication.slave.storage.transactions import TransactionStore
+from synapse.replication.slave.storage.transactions import SlavedTransactionStore
from synapse.replication.tcp.client import ReplicationClientHandler
from synapse.rest.media.v0.content_repository import ContentRepoResource
from synapse.server import HomeServer
@@ -52,7 +52,7 @@ class MediaRepositorySlavedStore(
SlavedApplicationServiceStore,
SlavedRegistrationStore,
SlavedClientIpStore,
- TransactionStore,
+ SlavedTransactionStore,
BaseSlavedStore,
MediaRepositoryStore,
):
@@ -155,11 +155,13 @@ def start(config_options):
database_engine = create_engine(config.database_config)
tls_server_context_factory = context_factory.ServerContextFactory(config)
+ tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config)
ss = MediaRepositoryServer(
config.server_name,
db_config=config.database_config,
tls_server_context_factory=tls_server_context_factory,
+ tls_client_options_factory=tls_client_options_factory,
config=config,
version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine,
diff --git a/synapse/app/pusher.py b/synapse/app/pusher.py
index 9295a51d5b..a4fc7e91fa 100644
--- a/synapse/app/pusher.py
+++ b/synapse/app/pusher.py
@@ -148,8 +148,9 @@ class PusherReplicationHandler(ReplicationClientHandler):
self.pusher_pool = hs.get_pusherpool()
+ @defer.inlineCallbacks
def on_rdata(self, stream_name, token, rows):
- super(PusherReplicationHandler, self).on_rdata(stream_name, token, rows)
+ yield super(PusherReplicationHandler, self).on_rdata(stream_name, token, rows)
run_in_background(self.poke_pushers, stream_name, token, rows)
@defer.inlineCallbacks
@@ -162,11 +163,11 @@ class PusherReplicationHandler(ReplicationClientHandler):
else:
yield self.start_pusher(row.user_id, row.app_id, row.pushkey)
elif stream_name == "events":
- yield self.pusher_pool.on_new_notifications(
+ self.pusher_pool.on_new_notifications(
token, token,
)
elif stream_name == "receipts":
- yield self.pusher_pool.on_new_receipts(
+ self.pusher_pool.on_new_receipts(
token, token, set(row.room_id for row in rows)
)
except Exception:
diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py
index e201f18efd..27e1998660 100644
--- a/synapse/app/synchrotron.py
+++ b/synapse/app/synchrotron.py
@@ -114,7 +114,10 @@ class SynchrotronPresence(object):
logger.info("Presence process_id is %r", self.process_id)
def send_user_sync(self, user_id, is_syncing, last_sync_ms):
- self.hs.get_tcp_replication().send_user_sync(user_id, is_syncing, last_sync_ms)
+ if self.hs.config.use_presence:
+ self.hs.get_tcp_replication().send_user_sync(
+ user_id, is_syncing, last_sync_ms
+ )
def mark_as_coming_online(self, user_id):
"""A user has started syncing. Send a UserSync to the master, unless they
@@ -211,10 +214,13 @@ class SynchrotronPresence(object):
yield self.notify_from_replication(states, stream_id)
def get_currently_syncing_users(self):
- return [
- user_id for user_id, count in iteritems(self.user_to_num_current_syncs)
- if count > 0
- ]
+ if self.hs.config.use_presence:
+ return [
+ user_id for user_id, count in iteritems(self.user_to_num_current_syncs)
+ if count > 0
+ ]
+ else:
+ return set()
class SynchrotronTyping(object):
@@ -332,8 +338,9 @@ class SyncReplicationHandler(ReplicationClientHandler):
self.presence_handler = hs.get_presence_handler()
self.notifier = hs.get_notifier()
+ @defer.inlineCallbacks
def on_rdata(self, stream_name, token, rows):
- super(SyncReplicationHandler, self).on_rdata(stream_name, token, rows)
+ yield super(SyncReplicationHandler, self).on_rdata(stream_name, token, rows)
run_in_background(self.process_and_notify, stream_name, token, rows)
def get_streams_to_replicate(self):
diff --git a/synapse/app/user_dir.py b/synapse/app/user_dir.py
index 637a89530a..1388a42b59 100644
--- a/synapse/app/user_dir.py
+++ b/synapse/app/user_dir.py
@@ -169,8 +169,9 @@ class UserDirectoryReplicationHandler(ReplicationClientHandler):
super(UserDirectoryReplicationHandler, self).__init__(hs.get_datastore())
self.user_directory = hs.get_user_directory_handler()
+ @defer.inlineCallbacks
def on_rdata(self, stream_name, token, rows):
- super(UserDirectoryReplicationHandler, self).on_rdata(
+ yield super(UserDirectoryReplicationHandler, self).on_rdata(
stream_name, token, rows
)
if stream_name == "current_state_deltas":
@@ -214,11 +215,13 @@ def start(config_options):
config.update_user_directory = True
tls_server_context_factory = context_factory.ServerContextFactory(config)
+ tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config)
ps = UserDirectoryServer(
config.server_name,
db_config=config.database_config,
tls_server_context_factory=tls_server_context_factory,
+ tls_client_options_factory=tls_client_options_factory,
config=config,
version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine,
diff --git a/synapse/config/logger.py b/synapse/config/logger.py
index a87b11a1df..3f187adfc8 100644
--- a/synapse/config/logger.py
+++ b/synapse/config/logger.py
@@ -168,7 +168,8 @@ def setup_logging(config, use_worker_options=False):
if log_file:
# TODO: Customisable file size / backup count
handler = logging.handlers.RotatingFileHandler(
- log_file, maxBytes=(1000 * 1000 * 100), backupCount=3
+ log_file, maxBytes=(1000 * 1000 * 100), backupCount=3,
+ encoding='utf8'
)
def sighup(signum, stack):
@@ -193,9 +194,8 @@ def setup_logging(config, use_worker_options=False):
def sighup(signum, stack):
# it might be better to use a file watcher or something for this.
- logging.info("Reloading log config from %s due to SIGHUP",
- log_config)
load_log_config()
+ logging.info("Reloaded log config from %s due to SIGHUP", log_config)
load_log_config()
diff --git a/synapse/config/server.py b/synapse/config/server.py
index 6a471a0a5e..68a612e594 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -49,6 +49,9 @@ class ServerConfig(Config):
# "disable" federation
self.send_federation = config.get("send_federation", True)
+ # Whether to enable user presence.
+ self.use_presence = config.get("use_presence", True)
+
# Whether to update the user directory or not. This should be set to
# false only if we are updating the user directory in a worker
self.update_user_directory = config.get("update_user_directory", True)
@@ -69,12 +72,24 @@ class ServerConfig(Config):
# Options to control access by tracking MAU
self.limit_usage_by_mau = config.get("limit_usage_by_mau", False)
+ self.max_mau_value = 0
if self.limit_usage_by_mau:
self.max_mau_value = config.get(
"max_mau_value", 0,
)
- else:
- self.max_mau_value = 0
+ self.mau_limits_reserved_threepids = config.get(
+ "mau_limit_reserved_threepids", []
+ )
+
+ # Options to disable HS
+ self.hs_disabled = config.get("hs_disabled", False)
+ self.hs_disabled_message = config.get("hs_disabled_message", "")
+ self.hs_disabled_limit_type = config.get("hs_disabled_limit_type", "")
+
+        # Admin URI to direct users to should their instance become blocked
+        # due to resource constraints
+ self.admin_uri = config.get("admin_uri", None)
+
# FIXME: federation_domain_whitelist needs sytests
self.federation_domain_whitelist = None
federation_domain_whitelist = config.get(
@@ -238,6 +253,9 @@ class ServerConfig(Config):
# hard limit.
soft_file_limit: 0
+ # Set to false to disable presence tracking on this homeserver.
+ use_presence: true
+
# The GC threshold parameters to pass to `gc.set_threshold`, if defined
# gc_thresholds: [700, 10, 10]
@@ -329,6 +347,32 @@ class ServerConfig(Config):
# - port: 9000
# bind_addresses: ['::1', '127.0.0.1']
# type: manhole
+
+
+ # Homeserver blocking
+ #
+ # How to reach the server admin, used in ResourceLimitError
+ # admin_uri: 'mailto:admin@server.com'
+ #
+ # Global block config
+ #
+ # hs_disabled: False
+ # hs_disabled_message: 'Human readable reason for why the HS is blocked'
+ # hs_disabled_limit_type: 'error code(str), to help clients decode reason'
+ #
+ # Monthly Active User Blocking
+ #
+ # Enables monthly active user checking
+ # limit_usage_by_mau: False
+ # max_mau_value: 50
+ #
+ # Sometimes the server admin will want to ensure certain accounts are
+ # never blocked by mau checking. These accounts are specified here.
+ #
+ # mau_limit_reserved_threepids:
+ # - medium: 'email'
+ # address: 'reserved_user@example.com'
+
""" % locals()
def read_arguments(self, args):
diff --git a/synapse/crypto/context_factory.py b/synapse/crypto/context_factory.py
index a1e1d0d33a..1a391adec1 100644
--- a/synapse/crypto/context_factory.py
+++ b/synapse/crypto/context_factory.py
@@ -11,19 +11,22 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import logging
+from zope.interface import implementer
+
from OpenSSL import SSL, crypto
-from twisted.internet import ssl
from twisted.internet._sslverify import _defaultCurveName
+from twisted.internet.interfaces import IOpenSSLClientConnectionCreator
+from twisted.internet.ssl import CertificateOptions, ContextFactory
+from twisted.python.failure import Failure
logger = logging.getLogger(__name__)
-class ServerContextFactory(ssl.ContextFactory):
+class ServerContextFactory(ContextFactory):
"""Factory for PyOpenSSL SSL contexts that are used to handle incoming
- connections and to make connections to remote servers."""
+ connections."""
def __init__(self, config):
self._context = SSL.Context(SSL.SSLv23_METHOD)
@@ -48,3 +51,78 @@ class ServerContextFactory(ssl.ContextFactory):
def getContext(self):
return self._context
+
+
+def _idnaBytes(text):
+ """
+ Convert some text typed by a human into some ASCII bytes. This is a
+ copy of twisted.internet._idna._idnaBytes. For documentation, see the
+ twisted documentation.
+ """
+ try:
+ import idna
+ except ImportError:
+ return text.encode("idna")
+ else:
+ return idna.encode(text)
+
+
+def _tolerateErrors(wrapped):
+ """
+ Wrap up an info_callback for pyOpenSSL so that if something goes wrong
+ the error is immediately logged and the connection is dropped if possible.
+ This is a copy of twisted.internet._sslverify._tolerateErrors. For
+ documentation, see the twisted documentation.
+ """
+
+ def infoCallback(connection, where, ret):
+ try:
+ return wrapped(connection, where, ret)
+ except: # noqa: E722, taken from the twisted implementation
+ f = Failure()
+ logger.exception("Error during info_callback")
+ connection.get_app_data().failVerification(f)
+
+ return infoCallback
+
+
+@implementer(IOpenSSLClientConnectionCreator)
+class ClientTLSOptions(object):
+ """
+ Client creator for TLS without certificate identity verification. This is a
+ copy of twisted.internet._sslverify.ClientTLSOptions with the identity
+ verification left out. For documentation, see the twisted documentation.
+ """
+
+ def __init__(self, hostname, ctx):
+ self._ctx = ctx
+ self._hostname = hostname
+ self._hostnameBytes = _idnaBytes(hostname)
+ ctx.set_info_callback(
+ _tolerateErrors(self._identityVerifyingInfoCallback)
+ )
+
+ def clientConnectionForTLS(self, tlsProtocol):
+ context = self._ctx
+ connection = SSL.Connection(context, None)
+ connection.set_app_data(tlsProtocol)
+ return connection
+
+ def _identityVerifyingInfoCallback(self, connection, where, ret):
+ if where & SSL.SSL_CB_HANDSHAKE_START:
+ connection.set_tlsext_host_name(self._hostnameBytes)
+
+
+class ClientTLSOptionsFactory(object):
+ """Factory for Twisted ClientTLSOptions that are used to make connections
+ to remote servers for federation."""
+
+ def __init__(self, config):
+ # We don't use config options yet
+ pass
+
+ def get_options(self, host):
+ return ClientTLSOptions(
+ host.decode('utf-8'),
+ CertificateOptions(verify=False).getContext()
+ )
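A sketch of the intended use (the host argument arrives as bytes from the
federation endpoint machinery; config is not yet consulted):

    factory = ClientTLSOptionsFactory(config=None)
    options = factory.get_options(b"matrix.org")
    # `options` implements IOpenSSLClientConnectionCreator: it sets the SNI
    # hostname via the info callback, but deliberately skips certificate
    # identity verification (CertificateOptions(verify=False)).
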
diff --git a/synapse/crypto/keyclient.py b/synapse/crypto/keyclient.py
index 668b4f517d..c20a32096a 100644
--- a/synapse/crypto/keyclient.py
+++ b/synapse/crypto/keyclient.py
@@ -30,14 +30,14 @@ KEY_API_V1 = b"/_matrix/key/v1/"
@defer.inlineCallbacks
-def fetch_server_key(server_name, ssl_context_factory, path=KEY_API_V1):
+def fetch_server_key(server_name, tls_client_options_factory, path=KEY_API_V1):
"""Fetch the keys for a remote server."""
factory = SynapseKeyClientFactory()
factory.path = path
factory.host = server_name
endpoint = matrix_federation_endpoint(
- reactor, server_name, ssl_context_factory, timeout=30
+ reactor, server_name, tls_client_options_factory, timeout=30
)
for i in range(5):
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index e95b9fb43e..30e2742102 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -512,7 +512,7 @@ class Keyring(object):
continue
(response, tls_certificate) = yield fetch_server_key(
- server_name, self.hs.tls_server_context_factory,
+ server_name, self.hs.tls_client_options_factory,
path=(b"/_matrix/key/v2/server/%s" % (
urllib.quote(requested_key_id),
)).encode("ascii"),
@@ -655,7 +655,7 @@ class Keyring(object):
# Try to fetch the key from the remote server.
(response, tls_certificate) = yield fetch_server_key(
- server_name, self.hs.tls_server_context_factory
+ server_name, self.hs.tls_client_options_factory
)
# Check the response.
diff --git a/synapse/event_auth.py b/synapse/event_auth.py
index b32f64e729..6baeccca38 100644
--- a/synapse/event_auth.py
+++ b/synapse/event_auth.py
@@ -20,7 +20,7 @@ from signedjson.key import decode_verify_key_bytes
from signedjson.sign import SignatureVerifyException, verify_signed_json
from unpaddedbase64 import decode_base64
-from synapse.api.constants import EventTypes, JoinRules, Membership
+from synapse.api.constants import KNOWN_ROOM_VERSIONS, EventTypes, JoinRules, Membership
from synapse.api.errors import AuthError, EventSizeError, SynapseError
from synapse.types import UserID, get_domain_from_id
@@ -83,6 +83,14 @@ def check(event, auth_events, do_sig_check=True, do_size_check=True):
403,
"Creation event's room_id domain does not match sender's"
)
+
+ room_version = event.content.get("room_version", "1")
+ if room_version not in KNOWN_ROOM_VERSIONS:
+ raise AuthError(
+ 403,
+ "room appears to have unsupported version %s" % (
+ room_version,
+ ))
# FIXME
logger.debug("Allowing! %s", event)
return
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 7550e11b6e..c9f3c2d352 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -25,7 +25,7 @@ from prometheus_client import Counter
from twisted.internet import defer
-from synapse.api.constants import Membership
+from synapse.api.constants import KNOWN_ROOM_VERSIONS, EventTypes, Membership
from synapse.api.errors import (
CodeMessageException,
FederationDeniedError,
@@ -518,10 +518,10 @@ class FederationClient(FederationBase):
description, destination, exc_info=1,
)
- raise RuntimeError("Failed to %s via any server", description)
+ raise RuntimeError("Failed to %s via any server" % (description, ))
def make_membership_event(self, destinations, room_id, user_id, membership,
- content={},):
+ content, params):
"""
Creates an m.room.member event, with context, without participating in the room.
@@ -537,8 +537,10 @@ class FederationClient(FederationBase):
user_id (str): The user whose membership is being evented.
membership (str): The "membership" property of the event. Must be
one of "join" or "leave".
- content (object): Any additional data to put into the content field
+ content (dict): Any additional data to put into the content field
of the event.
+ params (dict[str, str|Iterable[str]]): Query parameters to include in the
+ request.
Return:
Deferred: resolves to a tuple of (origin (str), event (object))
where origin is the remote homeserver which generated the event.
@@ -558,10 +560,12 @@ class FederationClient(FederationBase):
@defer.inlineCallbacks
def send_request(destination):
ret = yield self.transport_layer.make_membership_event(
- destination, room_id, user_id, membership
+ destination, room_id, user_id, membership, params,
)
- pdu_dict = ret["event"]
+ pdu_dict = ret.get("event", None)
+ if not isinstance(pdu_dict, dict):
+ raise InvalidResponseError("Bad 'event' field in response")
logger.debug("Got response to make_%s: %s", membership, pdu_dict)
@@ -605,6 +609,26 @@ class FederationClient(FederationBase):
Fails with a ``RuntimeError`` if no servers were reachable.
"""
+ def check_authchain_validity(signed_auth_chain):
+ for e in signed_auth_chain:
+ if e.type == EventTypes.Create:
+ create_event = e
+ break
+ else:
+ raise InvalidResponseError(
+ "no %s in auth chain" % (EventTypes.Create,),
+ )
+
+ # the room version should be sane.
+ room_version = create_event.content.get("room_version", "1")
+ if room_version not in KNOWN_ROOM_VERSIONS:
+ # This shouldn't be possible, because the remote server should have
+ # rejected the join attempt during make_join.
+ raise InvalidResponseError(
+ "room appears to have unsupported version %s" % (
+ room_version,
+ ))
+
@defer.inlineCallbacks
def send_request(destination):
time_now = self._clock.time_msec()
@@ -661,7 +685,7 @@ class FederationClient(FederationBase):
for s in signed_state:
s.internal_metadata = copy.deepcopy(s.internal_metadata)
- auth_chain.sort(key=lambda e: e.depth)
+ check_authchain_validity(signed_auth)
defer.returnValue({
"state": signed_state,
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index bf89d568af..3e0cd294a1 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -27,14 +27,24 @@ from twisted.internet.abstract import isIPAddress
from twisted.python import failure
from synapse.api.constants import EventTypes
-from synapse.api.errors import AuthError, FederationError, NotFoundError, SynapseError
+from synapse.api.errors import (
+ AuthError,
+ FederationError,
+ IncompatibleRoomVersionError,
+ NotFoundError,
+ SynapseError,
+)
from synapse.crypto.event_signing import compute_event_signature
from synapse.federation.federation_base import FederationBase, event_from_pdu_json
from synapse.federation.persistence import TransactionActions
from synapse.federation.units import Edu, Transaction
from synapse.http.endpoint import parse_server_name
+from synapse.replication.http.federation import (
+ ReplicationFederationSendEduRestServlet,
+ ReplicationGetQueryRestServlet,
+)
from synapse.types import get_domain_from_id
-from synapse.util import async
+from synapse.util.async_helpers import Linearizer, concurrently_execute
from synapse.util.caches.response_cache import ResponseCache
from synapse.util.logutils import log_function
@@ -61,8 +71,8 @@ class FederationServer(FederationBase):
self.auth = hs.get_auth()
self.handler = hs.get_handlers().federation_handler
- self._server_linearizer = async.Linearizer("fed_server")
- self._transaction_linearizer = async.Linearizer("fed_txn_handler")
+ self._server_linearizer = Linearizer("fed_server")
+ self._transaction_linearizer = Linearizer("fed_txn_handler")
self.transaction_actions = TransactionActions(self.store)
@@ -194,7 +204,7 @@ class FederationServer(FederationBase):
event_id, f.getTraceback().rstrip(),
)
- yield async.concurrently_execute(
+ yield concurrently_execute(
process_pdus_for_room, pdus_by_room.keys(),
TRANSACTION_CONCURRENCY_LIMIT,
)
@@ -323,12 +333,21 @@ class FederationServer(FederationBase):
defer.returnValue((200, resp))
@defer.inlineCallbacks
- def on_make_join_request(self, origin, room_id, user_id):
+ def on_make_join_request(self, origin, room_id, user_id, supported_versions):
origin_host, _ = parse_server_name(origin)
yield self.check_server_matches_acl(origin_host, room_id)
+
+ room_version = yield self.store.get_room_version(room_id)
+ if room_version not in supported_versions:
+ logger.warn("Room version %s not in %s", room_version, supported_versions)
+ raise IncompatibleRoomVersionError(room_version=room_version)
+
pdu = yield self.handler.on_make_join_request(room_id, user_id)
time_now = self._clock.time_msec()
- defer.returnValue({"event": pdu.get_pdu_json(time_now)})
+ defer.returnValue({
+ "event": pdu.get_pdu_json(time_now),
+ "room_version": room_version,
+ })
@defer.inlineCallbacks
def on_invite_request(self, origin, content):
@@ -745,6 +764,8 @@ class FederationHandlerRegistry(object):
if edu_type in self.edu_handlers:
raise KeyError("Already have an EDU handler for %s" % (edu_type,))
+ logger.info("Registering federation EDU handler for %r", edu_type)
+
self.edu_handlers[edu_type] = handler
def register_query_handler(self, query_type, handler):
@@ -763,6 +784,8 @@ class FederationHandlerRegistry(object):
"Already have a Query handler for %s" % (query_type,)
)
+ logger.info("Registering federation query handler for %r", query_type)
+
self.query_handlers[query_type] = handler
@defer.inlineCallbacks
@@ -785,3 +808,49 @@ class FederationHandlerRegistry(object):
raise NotFoundError("No handler for Query type '%s'" % (query_type,))
return handler(args)
+
+
+class ReplicationFederationHandlerRegistry(FederationHandlerRegistry):
+ """A FederationHandlerRegistry for worker processes.
+
+    When receiving an EDU or query it will check if an appropriate handler
+    has been registered on the worker; if there isn't one, it calls off to
+    the master process.
+ """
+
+ def __init__(self, hs):
+ self.config = hs.config
+ self.http_client = hs.get_simple_http_client()
+ self.clock = hs.get_clock()
+
+ self._get_query_client = ReplicationGetQueryRestServlet.make_client(hs)
+ self._send_edu = ReplicationFederationSendEduRestServlet.make_client(hs)
+
+ super(ReplicationFederationHandlerRegistry, self).__init__()
+
+ def on_edu(self, edu_type, origin, content):
+ """Overrides FederationHandlerRegistry
+ """
+ handler = self.edu_handlers.get(edu_type)
+ if handler:
+ return super(ReplicationFederationHandlerRegistry, self).on_edu(
+ edu_type, origin, content,
+ )
+
+ return self._send_edu(
+ edu_type=edu_type,
+ origin=origin,
+ content=content,
+ )
+
+ def on_query(self, query_type, args):
+ """Overrides FederationHandlerRegistry
+ """
+ handler = self.query_handlers.get(query_type)
+ if handler:
+ return handler(args)
+
+ return self._get_query_client(
+ query_type=query_type,
+ args=args,
+ )
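On a worker, dispatch falls back to the master for anything not registered
locally; a hedged sketch (the EDU types, handler, and hs wiring are
illustrative):

    registry = hs.get_federation_registry()
    registry.register_edu_handler("m.receipt", on_receipt_edu)

    # Handled locally: a worker-side handler exists for this type.
    registry.on_edu("m.receipt", "remote.hs", {"some": "content"})

    # Proxied to the master via ReplicationFederationSendEduRestServlet,
    # since no local handler was registered for this type:
    registry.on_edu("m.typing", "remote.hs", {"some": "content"})
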
diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py
index 78f9d40a3a..94d7423d01 100644
--- a/synapse/federation/transaction_queue.py
+++ b/synapse/federation/transaction_queue.py
@@ -26,6 +26,8 @@ from synapse.api.errors import FederationDeniedError, HttpResponseException
from synapse.handlers.presence import format_user_presence_state, get_interested_remotes
from synapse.metrics import (
LaterGauge,
+ event_processing_loop_counter,
+ event_processing_loop_room_count,
events_processed_counter,
sent_edus_counter,
sent_transactions_counter,
@@ -56,6 +58,7 @@ class TransactionQueue(object):
"""
def __init__(self, hs):
+ self.hs = hs
self.server_name = hs.hostname
self.store = hs.get_datastore()
@@ -253,7 +256,13 @@ class TransactionQueue(object):
synapse.metrics.event_processing_last_ts.labels(
"federation_sender").set(ts)
- events_processed_counter.inc(len(events))
+ events_processed_counter.inc(len(events))
+
+ event_processing_loop_room_count.labels(
+ "federation_sender"
+ ).inc(len(events_by_room))
+
+ event_processing_loop_counter.labels("federation_sender").inc()
synapse.metrics.event_processing_positions.labels(
"federation_sender").set(next_token)
@@ -300,6 +309,9 @@ class TransactionQueue(object):
Args:
states (list(UserPresenceState))
"""
+ if not self.hs.config.use_presence:
+ # No-op if presence is disabled.
+ return
# First we queue up the new presence by user ID, so multiple presence
         # updates in quick succession are correctly handled
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 4529d454af..b4fbe2c9d5 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -195,7 +195,7 @@ class TransportLayerClient(object):
@defer.inlineCallbacks
@log_function
- def make_membership_event(self, destination, room_id, user_id, membership):
+ def make_membership_event(self, destination, room_id, user_id, membership, params):
"""Asks a remote server to build and sign us a membership event
Note that this does not append any events to any graphs.
@@ -205,6 +205,8 @@ class TransportLayerClient(object):
room_id (str): room to join/leave
user_id (str): user to be joined/left
membership (str): one of join/leave
+ params (dict[str, str|Iterable[str]]): Query parameters to include in the
+ request.
Returns:
Deferred: Succeeds when we get a 2xx HTTP response. The result
@@ -241,6 +243,7 @@ class TransportLayerClient(object):
content = yield self.client.get_json(
destination=destination,
path=path,
+ args=params,
retry_on_dns_fail=retry_on_dns_fail,
timeout=20000,
ignore_backoff=ignore_backoff,
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index eae5f2b427..77969a4f38 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -190,6 +190,41 @@ def _parse_auth_header(header_bytes):
class BaseFederationServlet(object):
+ """Abstract base class for federation servlet classes.
+
+ The servlet object should have a PATH attribute which takes the form of a regexp to
+ match against the request path (excluding the /federation/v1 prefix).
+
+ The servlet should also implement one or more of on_GET, on_POST, on_PUT, to match
+ the appropriate HTTP method. These methods have the signature:
+
+ on_<METHOD>(self, origin, content, query, **kwargs)
+
+ With arguments:
+
+ origin (unicode|None): The authenticated server_name of the calling server,
+ unless REQUIRE_AUTH is set to False and authentication failed.
+
+ content (unicode|None): decoded json body of the request. None if the
+ request was a GET.
+
+ query (dict[bytes, list[bytes]]): Query params from the request. url-decoded
+ (ie, '+' and '%xx' are decoded) but note that it is *not* utf8-decoded
+ yet.
+
+ **kwargs (dict[unicode, unicode]): the dict mapping keys to path
+ components as specified in the path match regexp.
+
+ Returns:
+ Deferred[(int, object)|None]: either (response code, response object) to
+ return a JSON response, or None if the request has already been handled.
+
+ Raises:
+ SynapseError: to return an error code
+
+ Exception: other exceptions will be caught, logged, and a 500 will be
+ returned.
+ """
REQUIRE_AUTH = True
def __init__(self, handler, authenticator, ratelimiter, server_name):
@@ -204,6 +239,18 @@ class BaseFederationServlet(object):
@defer.inlineCallbacks
@functools.wraps(func)
def new_func(request, *args, **kwargs):
+ """ A callback which can be passed to HttpServer.RegisterPaths
+
+ Args:
+ request (twisted.web.http.Request):
+ *args: unused?
+ **kwargs (dict[unicode, unicode]): the dict mapping keys to path
+ components as specified in the path match regexp.
+
+ Returns:
+ Deferred[(int, object)|None]: (response code, response object) as returned
+ by the callback method. None if the request has already been handled.
+ """
content = None
if request.method in ["PUT", "POST"]:
# TODO: Handle other method types? other content types?
@@ -384,9 +431,31 @@ class FederationMakeJoinServlet(BaseFederationServlet):
PATH = "/make_join/(?P<context>[^/]*)/(?P<user_id>[^/]*)"
@defer.inlineCallbacks
- def on_GET(self, origin, content, query, context, user_id):
+ def on_GET(self, origin, _content, query, context, user_id):
+ """
+ Args:
+ origin (unicode): The authenticated server_name of the calling server
+
+ _content (None): (GETs don't have bodies)
+
+ query (dict[bytes, list[bytes]]): Query params from the request.
+
+ **kwargs (dict[unicode, unicode]): the dict mapping keys to path
+ components as specified in the path match regexp.
+
+ Returns:
+ Deferred[(int, object)|None]: either (response code, response object) to
+ return a JSON response, or None if the request has already been handled.
+ """
+ versions = query.get(b'ver')
+ if versions is not None:
+ supported_versions = [v.decode("utf-8") for v in versions]
+ else:
+ supported_versions = ["1"]
+
content = yield self.handler.on_make_join_request(
origin, context, user_id,
+ supported_versions=supported_versions,
)
defer.returnValue((200, content))
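End to end, version negotiation rides on repeated `ver` query parameters;
a minimal sketch of both sides, using the constants introduced above:

    # Client (make_membership_event): sent as repeated query params, e.g.
    #   GET /make_join/!room:hs/@user:hs?ver=1&ver=vdh-test-version
    params = {"ver": KNOWN_ROOM_VERSIONS}

    # Server (on_GET above): values arrive as bytes; a missing `ver`
    # means a pre-versioning server, which can only support v1.
    versions = query.get(b"ver")
    if versions is not None:
        supported_versions = [v.decode("utf-8") for v in versions]
    else:
        supported_versions = ["1"]
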
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index ee41aed69e..f0f89af7dc 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -23,6 +23,10 @@ from twisted.internet import defer
import synapse
from synapse.api.constants import EventTypes
+from synapse.metrics import (
+ event_processing_loop_counter,
+ event_processing_loop_room_count,
+)
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.util.logcontext import make_deferred_yieldable, run_in_background
from synapse.util.metrics import Measure
@@ -136,6 +140,12 @@ class ApplicationServicesHandler(object):
events_processed_counter.inc(len(events))
+ event_processing_loop_room_count.labels(
+ "appservice_sender"
+ ).inc(len(events_by_room))
+
+ event_processing_loop_counter.labels("appservice_sender").inc()
+
synapse.metrics.event_processing_lag.labels(
"appservice_sender").set(now - ts)
synapse.metrics.event_processing_last_ts.labels(
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 184eef09d0..4a81bd2ba9 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -520,7 +520,7 @@ class AuthHandler(BaseHandler):
"""
logger.info("Logging in user %s on device %s", user_id, device_id)
access_token = yield self.issue_access_token(user_id, device_id)
- yield self._check_mau_limits()
+ yield self.auth.check_auth_blocking(user_id)
# the device *should* have been registered before we got here; however,
# it's possible we raced against a DELETE operation. The thing we
@@ -734,7 +734,6 @@ class AuthHandler(BaseHandler):
@defer.inlineCallbacks
def validate_short_term_login_token_and_get_user_id(self, login_token):
- yield self._check_mau_limits()
auth_api = self.hs.get_auth()
user_id = None
try:
@@ -743,6 +742,7 @@ class AuthHandler(BaseHandler):
auth_api.validate_macaroon(macaroon, "login", True, user_id)
except Exception:
raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN)
+ yield self.auth.check_auth_blocking(user_id)
defer.returnValue(user_id)
@defer.inlineCallbacks
@@ -828,12 +828,26 @@ class AuthHandler(BaseHandler):
@defer.inlineCallbacks
def delete_threepid(self, user_id, medium, address):
+ """Attempts to unbind the 3pid on the identity servers and deletes it
+ from the local database.
+
+ Args:
+ user_id (str)
+ medium (str)
+ address (str)
+
+ Returns:
+            Deferred[bool]: True if the 3pid was successfully unbound on the
+            identity server, False if the identity server doesn't support
+            the unbind API.
+ """
+
# 'Canonicalise' email addresses as per above
if medium == 'email':
address = address.lower()
identity_handler = self.hs.get_handlers().identity_handler
- yield identity_handler.unbind_threepid(
+ result = yield identity_handler.try_unbind_threepid(
user_id,
{
'medium': medium,
@@ -841,10 +855,10 @@ class AuthHandler(BaseHandler):
},
)
- ret = yield self.store.user_delete_threepid(
+ yield self.store.user_delete_threepid(
user_id, medium, address,
)
- defer.returnValue(ret)
+ defer.returnValue(result)
def _save_session(self, session):
# TODO: Persistent storage
@@ -907,19 +921,6 @@ class AuthHandler(BaseHandler):
else:
return defer.succeed(False)
- @defer.inlineCallbacks
- def _check_mau_limits(self):
- """
- Ensure that if mau blocking is enabled that invalid users cannot
- log in.
- """
- if self.hs.config.limit_usage_by_mau is True:
- current_mau = yield self.store.count_monthly_users()
- if current_mau >= self.hs.config.max_mau_value:
- raise AuthError(
- 403, "MAU Limit Exceeded", errcode=Codes.MAU_LIMIT_EXCEEDED
- )
-
@attr.s
class MacaroonGenerator(object):
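delete_threepid now reports whether the identity server actually honoured
the unbind, so callers can tell "deleted locally only" apart from a full
unbind; a sketch with the handler wiring assumed:

    from twisted.internet import defer

    @defer.inlineCallbacks
    def remove_email(auth_handler, user_id, address):
        # True only if the identity server honoured the unbind; the local
        # database row is deleted either way.
        id_server_unbound = yield auth_handler.delete_threepid(
            user_id, "email", address,
        )
        defer.returnValue(id_server_unbound)
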
diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py
index b3c5a9ee64..b078df4a76 100644
--- a/synapse/handlers/deactivate_account.py
+++ b/synapse/handlers/deactivate_account.py
@@ -51,7 +51,8 @@ class DeactivateAccountHandler(BaseHandler):
erase_data (bool): whether to GDPR-erase the user's data
Returns:
- Deferred
+            Deferred[bool]: True if the identity server supports removing
+            threepids, otherwise False.
"""
# FIXME: Theoretically there is a race here wherein user resets
# password using threepid.
@@ -60,16 +61,22 @@ class DeactivateAccountHandler(BaseHandler):
# leave the user still active so they can try again.
# Ideally we would prevent password resets and then do this in the
# background thread.
+
+ # This will be set to false if the identity server doesn't support
+ # unbinding
+ identity_server_supports_unbinding = True
+
threepids = yield self.store.user_get_threepids(user_id)
for threepid in threepids:
try:
- yield self._identity_handler.unbind_threepid(
+ result = yield self._identity_handler.try_unbind_threepid(
user_id,
{
'medium': threepid['medium'],
'address': threepid['address'],
},
)
+ identity_server_supports_unbinding &= result
except Exception:
# Do we want this to be a fatal error or should we carry on?
logger.exception("Failed to remove threepid from ID server")
@@ -103,6 +110,8 @@ class DeactivateAccountHandler(BaseHandler):
# parts users from rooms (if it isn't already running)
self._start_user_parting()
+ defer.returnValue(identity_server_supports_unbinding)
+
def _start_user_parting(self):
"""
Start the process that goes through the table of users
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 2d44f15da3..9e017116a9 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -23,7 +23,7 @@ from synapse.api.constants import EventTypes
from synapse.api.errors import FederationDeniedError
from synapse.types import RoomStreamToken, get_domain_from_id
from synapse.util import stringutils
-from synapse.util.async import Linearizer
+from synapse.util.async_helpers import Linearizer
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.metrics import measure_func
from synapse.util.retryutils import NotRetryingDestination
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 533b82c783..3dd107a285 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -30,7 +30,12 @@ from unpaddedbase64 import decode_base64
from twisted.internet import defer
-from synapse.api.constants import EventTypes, Membership, RejectedReason
+from synapse.api.constants import (
+ KNOWN_ROOM_VERSIONS,
+ EventTypes,
+ Membership,
+ RejectedReason,
+)
from synapse.api.errors import (
AuthError,
CodeMessageException,
@@ -44,10 +49,15 @@ from synapse.crypto.event_signing import (
compute_event_signature,
)
from synapse.events.validator import EventValidator
+from synapse.replication.http.federation import (
+ ReplicationCleanRoomRestServlet,
+ ReplicationFederationSendEventsRestServlet,
+)
+from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet
from synapse.state import resolve_events_with_factory
from synapse.types import UserID, get_domain_from_id
from synapse.util import logcontext, unwrapFirstError
-from synapse.util.async import Linearizer
+from synapse.util.async_helpers import Linearizer
from synapse.util.distributor import user_joined_room
from synapse.util.frozenutils import unfreeze
from synapse.util.logutils import log_function
@@ -86,6 +96,18 @@ class FederationHandler(BaseHandler):
self.spam_checker = hs.get_spam_checker()
self.event_creation_handler = hs.get_event_creation_handler()
self._server_notices_mxid = hs.config.server_notices_mxid
+ self.config = hs.config
+ self.http_client = hs.get_simple_http_client()
+
+ self._send_events_to_master = (
+ ReplicationFederationSendEventsRestServlet.make_client(hs)
+ )
+ self._notify_user_membership_change = (
+ ReplicationUserJoinedLeftRoomRestServlet.make_client(hs)
+ )
+ self._clean_room_for_join_client = (
+ ReplicationCleanRoomRestServlet.make_client(hs)
+ )
# When joining a room we need to queue any events for that room up
self.room_queues = {}
@@ -922,6 +944,9 @@ class FederationHandler(BaseHandler):
joinee,
"join",
content,
+ params={
+ "ver": KNOWN_ROOM_VERSIONS,
+ },
)
# This shouldn't happen, because the RoomMemberHandler has a
@@ -1150,7 +1175,7 @@ class FederationHandler(BaseHandler):
)
context = yield self.state_handler.compute_event_context(event)
- yield self._persist_events([(event, context)])
+ yield self.persist_events_and_notify([(event, context)])
defer.returnValue(event)
@@ -1181,19 +1206,20 @@ class FederationHandler(BaseHandler):
)
context = yield self.state_handler.compute_event_context(event)
- yield self._persist_events([(event, context)])
+ yield self.persist_events_and_notify([(event, context)])
defer.returnValue(event)
@defer.inlineCallbacks
def _make_and_verify_event(self, target_hosts, room_id, user_id, membership,
- content={},):
+ content={}, params=None):
origin, pdu = yield self.federation_client.make_membership_event(
target_hosts,
room_id,
user_id,
membership,
content,
+ params=params,
)
logger.debug("Got response to make_%s: %s", membership, pdu)
@@ -1423,7 +1449,7 @@ class FederationHandler(BaseHandler):
event, context
)
- yield self._persist_events(
+ yield self.persist_events_and_notify(
[(event, context)],
backfilled=backfilled,
)
@@ -1461,7 +1487,7 @@ class FederationHandler(BaseHandler):
], consumeErrors=True,
))
- yield self._persist_events(
+ yield self.persist_events_and_notify(
[
(ev_info["event"], context)
for ev_info, context in zip(event_infos, contexts)
@@ -1549,7 +1575,7 @@ class FederationHandler(BaseHandler):
raise
events_to_context[e.event_id].rejected = RejectedReason.AUTH_ERROR
- yield self._persist_events(
+ yield self.persist_events_and_notify(
[
(e, events_to_context[e.event_id])
for e in itertools.chain(auth_events, state)
@@ -1560,7 +1586,7 @@ class FederationHandler(BaseHandler):
event, old_state=state
)
- yield self._persist_events(
+ yield self.persist_events_and_notify(
[(event, new_event_context)],
)
@@ -2288,7 +2314,7 @@ class FederationHandler(BaseHandler):
for revocation.
"""
try:
- response = yield self.hs.get_simple_http_client().get_json(
+ response = yield self.http_client.get_json(
url,
{"public_key": public_key}
)
@@ -2301,7 +2327,7 @@ class FederationHandler(BaseHandler):
raise AuthError(403, "Third party certificate was invalid")
@defer.inlineCallbacks
- def _persist_events(self, event_and_contexts, backfilled=False):
+ def persist_events_and_notify(self, event_and_contexts, backfilled=False):
"""Persists events and tells the notifier/pushers about them, if
necessary.
@@ -2313,14 +2339,21 @@ class FederationHandler(BaseHandler):
Returns:
Deferred
"""
- max_stream_id = yield self.store.persist_events(
- event_and_contexts,
- backfilled=backfilled,
- )
+ if self.config.worker_app:
+ yield self._send_events_to_master(
+ store=self.store,
+ event_and_contexts=event_and_contexts,
+ backfilled=backfilled
+ )
+ else:
+ max_stream_id = yield self.store.persist_events(
+ event_and_contexts,
+ backfilled=backfilled,
+ )
- if not backfilled: # Never notify for backfilled events
- for event, _ in event_and_contexts:
- self._notify_persisted_event(event, max_stream_id)
+ if not backfilled: # Never notify for backfilled events
+ for event, _ in event_and_contexts:
+ self._notify_persisted_event(event, max_stream_id)
def _notify_persisted_event(self, event, max_stream_id):
"""Checks to see if notifier/pushers should be notified about the
@@ -2353,15 +2386,30 @@ class FederationHandler(BaseHandler):
extra_users=extra_users
)
- logcontext.run_in_background(
- self.pusher_pool.on_new_notifications,
+ self.pusher_pool.on_new_notifications(
event_stream_id, max_stream_id,
)
def _clean_room_for_join(self, room_id):
- return self.store.clean_room_for_join(room_id)
+ """Called to clean up any data in DB for a given room, ready for the
+ server to join the room.
+
+ Args:
+ room_id (str)
+ """
+ if self.config.worker_app:
+ return self._clean_room_for_join_client(room_id)
+ else:
+ return self.store.clean_room_for_join(room_id)
def user_joined_room(self, user, room_id):
"""Called when a new user has joined the room
"""
- return user_joined_room(self.distributor, user, room_id)
+ if self.config.worker_app:
+ return self._notify_user_membership_change(
+ room_id=room_id,
+ user_id=user.to_string(),
+ change="joined",
+ )
+ else:
+ return user_joined_room(self.distributor, user, room_id)
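The worker/master split introduced above follows one pattern throughout this file: if `worker_app` is set, hand the write to the master over replication HTTP; otherwise persist locally and notify. A hedged, framework-free sketch of that dispatch shape (every name here is a stand-in, not a Synapse API):

class PersistDispatcher:
    def __init__(self, is_worker, send_to_master, persist_locally, notify):
        self.is_worker = is_worker
        self.send_to_master = send_to_master
        self.persist_locally = persist_locally
        self.notify = notify

    def persist(self, events, backfilled=False):
        if self.is_worker:
            # the master persists and notifies on our behalf
            return self.send_to_master(events, backfilled=backfilled)
        max_stream_id = self.persist_locally(events, backfilled=backfilled)
        if not backfilled:  # never notify for backfilled events
            for event in events:
                self.notify(event, max_stream_id)
        return max_stream_id

dispatcher = PersistDispatcher(
    is_worker=False,
    send_to_master=lambda events, backfilled: None,
    persist_locally=lambda events, backfilled: 42,
    notify=lambda event, stream_id: print("notify", event, stream_id),
)
dispatcher.persist(["$event1:hs"])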
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index 1d36d967c3..5feb3f22a6 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -137,15 +137,19 @@ class IdentityHandler(BaseHandler):
defer.returnValue(data)
@defer.inlineCallbacks
- def unbind_threepid(self, mxid, threepid):
- """
- Removes a binding from an identity server
+ def try_unbind_threepid(self, mxid, threepid):
+ """Removes a binding from an identity server
+
Args:
mxid (str): Matrix user ID of binding to be removed
threepid (dict): Dict with medium & address of binding to be removed
+ Raises:
+ SynapseError: If we failed to contact the identity server
+
Returns:
- Deferred[bool]: True on success, otherwise False
+ Deferred[bool]: True on success, otherwise False if the identity
+ server doesn't support unbinding
"""
logger.debug("unbinding threepid %r from %s", threepid, mxid)
if not self.trusted_id_servers:
@@ -175,11 +179,21 @@ class IdentityHandler(BaseHandler):
content=content,
destination_is=id_server,
)
- yield self.http_client.post_json_get_json(
- url,
- content,
- headers,
- )
+ try:
+ yield self.http_client.post_json_get_json(
+ url,
+ content,
+ headers,
+ )
+ except HttpResponseException as e:
+ if e.code in (400, 404, 501,):
+ # The remote server probably doesn't support unbinding (yet)
+ logger.warn("Received %d response while unbinding threepid", e.code)
+ defer.returnValue(False)
+ else:
+ logger.error("Failed to unbind threepid on identity server: %s", e)
+ raise SynapseError(502, "Failed to contact identity server")
+
defer.returnValue(True)
@defer.inlineCallbacks
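The status-code handling above gives try_unbind_threepid three outcomes: True (unbound), False (the identity server predates the unbind API), or an exception (the server couldn't be contacted). A toy model of that decision, using a hypothetical error type in place of Synapse's HttpResponseException:

class FakeHttpError(Exception):
    def __init__(self, code):
        self.code = code

def try_unbind_sketch(post):
    try:
        post()
    except FakeHttpError as e:
        if e.code in (400, 404, 501):
            # the remote server probably doesn't support unbinding (yet)
            return False
        # anything else is surfaced to the client, like the 502 above
        raise RuntimeError("Failed to contact identity server")
    return True

def post_unsupported():
    raise FakeHttpError(501)

assert try_unbind_sketch(lambda: None) is True
assert try_unbind_sketch(post_unsupported) is False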
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index 40e7580a61..e009395207 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -25,7 +25,7 @@ from synapse.handlers.presence import format_user_presence_state
from synapse.streams.config import PaginationConfig
from synapse.types import StreamToken, UserID
from synapse.util import unwrapFirstError
-from synapse.util.async import concurrently_execute
+from synapse.util.async_helpers import concurrently_execute
from synapse.util.caches.snapshot_cache import SnapshotCache
from synapse.util.logcontext import make_deferred_yieldable, run_in_background
from synapse.visibility import filter_events_for_client
@@ -372,6 +372,10 @@ class InitialSyncHandler(BaseHandler):
@defer.inlineCallbacks
def get_presence():
+ # If presence is disabled, return an empty list
+ if not self.hs.config.use_presence:
+ defer.returnValue([])
+
states = yield presence_handler.get_states(
[m.user_id for m in room_members],
as_event=True,
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 39d7724778..e484061cc0 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -25,17 +25,24 @@ from twisted.internet import defer
from twisted.internet.defer import succeed
from synapse.api.constants import MAX_DEPTH, EventTypes, Membership
-from synapse.api.errors import AuthError, Codes, ConsentNotGivenError, SynapseError
+from synapse.api.errors import (
+ AuthError,
+ Codes,
+ ConsentNotGivenError,
+ NotFoundError,
+ SynapseError,
+)
from synapse.api.urls import ConsentURIBuilder
from synapse.crypto.event_signing import add_hashes_and_signatures
from synapse.events.utils import serialize_event
from synapse.events.validator import EventValidator
-from synapse.replication.http.send_event import send_event_to_master
+from synapse.replication.http.send_event import ReplicationSendEventRestServlet
from synapse.types import RoomAlias, UserID
-from synapse.util.async import Linearizer
+from synapse.util.async_helpers import Linearizer
from synapse.util.frozenutils import frozendict_json_encoder
from synapse.util.logcontext import run_in_background
from synapse.util.metrics import measure_func
+from synapse.visibility import filter_events_for_client
from ._base import BaseHandler
@@ -82,28 +89,85 @@ class MessageHandler(object):
defer.returnValue(data)
@defer.inlineCallbacks
- def get_state_events(self, user_id, room_id, is_guest=False):
+ def get_state_events(
+ self, user_id, room_id, types=None, filtered_types=None,
+ at_token=None, is_guest=False,
+ ):
"""Retrieve all state events for a given room. If the user is
joined to the room then return the current state. If the user has
- left the room return the state events from when they left.
+ left the room return the state events from when they left. If an explicit
+ 'at' parameter is passed, return the state events as of that event, if
+ visible.
Args:
user_id(str): The user requesting state events.
room_id(str): The room ID to get all state events from.
+            types(list[(str, str|None)]|None): List of (type, state_key) tuples
+                which are used to filter the state fetched. If `state_key` is None,
+                all events of the given type are returned. The list itself may
+                also be None, in which case no filtering is applied.
+ filtered_types(list[str]|None): Only apply filtering via `types` to this
+ list of event types. Other types of events are returned unfiltered.
+ If None, `types` filtering is applied to all events.
+            at_token(StreamToken|None): the stream token of the point at which
+                we are requesting the state. If the user is not allowed to view
+                the state as of that stream token, we raise a 403 SynapseError.
+                If None, returns the current state based on the
+                current_state_events table.
+ is_guest(bool): whether this user is a guest
Returns:
A list of dicts representing state events. [{}, {}, {}]
+ Raises:
+ NotFoundError (404) if the at token does not yield an event
+
+ AuthError (403) if the user doesn't have permission to view
+ members of this room.
"""
- membership, membership_event_id = yield self.auth.check_in_room_or_world_readable(
- room_id, user_id
- )
+ if at_token:
+ # FIXME this claims to get the state at a stream position, but
+ # get_recent_events_for_room operates by topo ordering. This therefore
+ # does not reliably give you the state at the given stream position.
+ # (https://github.com/matrix-org/synapse/issues/3305)
+ last_events, _ = yield self.store.get_recent_events_for_room(
+ room_id, end_token=at_token.room_key, limit=1,
+ )
- if membership == Membership.JOIN:
- room_state = yield self.state.get_current_state(room_id)
- elif membership == Membership.LEAVE:
- room_state = yield self.store.get_state_for_events(
- [membership_event_id], None
+ if not last_events:
+ raise NotFoundError("Can't find event for token %s" % (at_token, ))
+
+ visible_events = yield filter_events_for_client(
+ self.store, user_id, last_events,
)
- room_state = room_state[membership_event_id]
+
+ event = last_events[0]
+ if visible_events:
+ room_state = yield self.store.get_state_for_events(
+ [event.event_id], types, filtered_types=filtered_types,
+ )
+ room_state = room_state[event.event_id]
+ else:
+ raise AuthError(
+ 403,
+ "User %s not allowed to view events in room %s at token %s" % (
+ user_id, room_id, at_token,
+ )
+ )
+ else:
+ membership, membership_event_id = (
+ yield self.auth.check_in_room_or_world_readable(
+ room_id, user_id,
+ )
+ )
+
+ if membership == Membership.JOIN:
+ state_ids = yield self.store.get_filtered_current_state_ids(
+ room_id, types, filtered_types=filtered_types,
+ )
+ room_state = yield self.store.get_events(state_ids.values())
+ elif membership == Membership.LEAVE:
+ room_state = yield self.store.get_state_for_events(
+ [membership_event_id], types, filtered_types=filtered_types,
+ )
+ room_state = room_state[membership_event_id]
now = self.clock.time_msec()
defer.returnValue(
@@ -171,7 +235,7 @@ class EventCreationHandler(object):
self.notifier = hs.get_notifier()
self.config = hs.config
- self.http_client = hs.get_simple_http_client()
+ self.send_event_to_master = ReplicationSendEventRestServlet.make_client(hs)
# This is only used to get at ratelimit function, and maybe_kick_guest_users
self.base_handler = BaseHandler(hs)
@@ -212,10 +276,14 @@ class EventCreationHandler(object):
where *hashes* is a map from algorithm to hash.
If None, they will be requested from the database.
-
+ Raises:
+            ResourceLimitError if the server is blocked because a resource
+            limit has been exceeded
Returns:
Tuple of created event (FrozenEvent), Context
"""
+ yield self.auth.check_auth_blocking(requester.user.to_string())
+
builder = self.event_builder_factory.new(event_dict)
self.validator.validate_new(builder)
@@ -559,12 +627,9 @@ class EventCreationHandler(object):
try:
# If we're a worker we need to hit out to the master.
if self.config.worker_app:
- yield send_event_to_master(
- clock=self.hs.get_clock(),
+ yield self.send_event_to_master(
+ event_id=event.event_id,
store=self.store,
- client=self.http_client,
- host=self.config.worker_replication_host,
- port=self.config.worker_replication_http_port,
requester=requester,
event=event,
context=context,
@@ -713,11 +778,8 @@ class EventCreationHandler(object):
event, context=context
)
- # this intentionally does not yield: we don't care about the result
- # and don't need to wait for it.
- run_in_background(
- self.pusher_pool.on_new_notifications,
- event_stream_id, max_stream_id
+ self.pusher_pool.on_new_notifications(
+ event_stream_id, max_stream_id,
)
def _notify():
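The `types`/`filtered_types` parameters threaded through get_state_events deserve a concrete reading: event types listed in `filtered_types` must match one of the `(type, state_key)` pairs (a None state_key matching any key), while all other event types pass through unfiltered. A small stand-in predicate for those documented semantics:

def state_filter_matches(event_type, state_key, types, filtered_types):
    if filtered_types is not None and event_type not in filtered_types:
        return True  # type is not subject to filtering: always returned
    return any(
        t == event_type and (k is None or k == state_key)
        for t, k in (types or [])
    )

# any member event passes, because state_key None matches every key
assert state_filter_matches(
    "m.room.member", "@a:hs", [("m.room.member", None)], ["m.room.member"],
)
# non-member events are untouched by a member-only filter
assert state_filter_matches("m.room.name", "", [], ["m.room.member"])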
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index b2849783ed..5170d093e3 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -18,11 +18,11 @@ import logging
from twisted.internet import defer
from twisted.python.failure import Failure
-from synapse.api.constants import Membership
+from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import SynapseError
from synapse.events.utils import serialize_event
from synapse.types import RoomStreamToken
-from synapse.util.async import ReadWriteLock
+from synapse.util.async_helpers import ReadWriteLock
from synapse.util.logcontext import run_in_background
from synapse.util.stringutils import random_string
from synapse.visibility import filter_events_for_client
@@ -251,6 +251,33 @@ class PaginationHandler(object):
is_peeking=(member_event_id is None),
)
+ state = None
+ if event_filter and event_filter.lazy_load_members():
+ # TODO: remove redundant members
+
+ types = [
+ (EventTypes.Member, state_key)
+ for state_key in set(
+ event.sender # FIXME: we also care about invite targets etc.
+ for event in events
+ )
+ ]
+
+ state_ids = yield self.store.get_state_ids_for_event(
+ events[0].event_id, types=types,
+ )
+
+ if state_ids:
+ state = yield self.store.get_events(list(state_ids.values()))
+
+ if state:
+ state = yield filter_events_for_client(
+ self.store,
+ user_id,
+ state.values(),
+ is_peeking=(member_event_id is None),
+ )
+
time_now = self.clock.time_msec()
chunk = {
@@ -262,4 +289,10 @@ class PaginationHandler(object):
"end": next_token.to_string(),
}
+ if state:
+ chunk["state"] = [
+ serialize_event(e, time_now, as_client_event)
+ for e in state
+ ]
+
defer.returnValue(chunk)
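The lazy-loading block above derives its member filter from the senders of the paginated events, so each distinct sender contributes exactly one (m.room.member, state_key) pair. A compact illustration of that derivation:

class FakeEvent:
    def __init__(self, sender):
        self.sender = sender

events = [FakeEvent("@a:hs"), FakeEvent("@b:hs"), FakeEvent("@a:hs")]
types = [
    ("m.room.member", state_key)
    for state_key in set(event.sender for event in events)
]
assert sorted(types) == [
    ("m.room.member", "@a:hs"),
    ("m.room.member", "@b:hs"),
]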
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 3732830194..ba3856674d 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -36,7 +36,7 @@ from synapse.api.errors import SynapseError
from synapse.metrics import LaterGauge
from synapse.storage.presence import UserPresenceState
from synapse.types import UserID, get_domain_from_id
-from synapse.util.async import Linearizer
+from synapse.util.async_helpers import Linearizer
from synapse.util.caches.descriptors import cachedInlineCallbacks
from synapse.util.logcontext import run_in_background
from synapse.util.logutils import log_function
@@ -95,6 +95,7 @@ class PresenceHandler(object):
Args:
hs (synapse.server.HomeServer):
"""
+ self.hs = hs
self.is_mine = hs.is_mine
self.is_mine_id = hs.is_mine_id
self.clock = hs.get_clock()
@@ -230,6 +231,10 @@ class PresenceHandler(object):
        earlier than they should when synapse is restarted. The effect of this
        is some spurious presence changes that will self-correct.
"""
+ # If the DB pool has already terminated, don't try updating
+ if not self.hs.get_db_pool().running:
+ return
+
logger.info(
"Performing _on_shutdown. Persisting %d unpersisted changes",
len(self.user_to_current_state)
@@ -390,6 +395,10 @@ class PresenceHandler(object):
"""We've seen the user do something that indicates they're interacting
with the app.
"""
+ # If presence is disabled, no-op
+ if not self.hs.config.use_presence:
+ return
+
user_id = user.to_string()
bump_active_time_counter.inc()
@@ -419,6 +428,11 @@ class PresenceHandler(object):
Useful for streams that are not associated with an actual
client that is being used by a user.
"""
+        # If presence is disabled, override affect_presence so that this
+        # sync never touches the user's presence.
+ if not self.hs.config.use_presence:
+ affect_presence = False
+
if affect_presence:
curr_sync = self.user_to_num_current_syncs.get(user_id, 0)
self.user_to_num_current_syncs[user_id] = curr_sync + 1
@@ -464,13 +478,16 @@ class PresenceHandler(object):
Returns:
set(str): A set of user_id strings.
"""
- syncing_user_ids = {
- user_id for user_id, count in self.user_to_num_current_syncs.items()
- if count
- }
- for user_ids in self.external_process_to_current_syncs.values():
- syncing_user_ids.update(user_ids)
- return syncing_user_ids
+ if self.hs.config.use_presence:
+ syncing_user_ids = {
+ user_id for user_id, count in self.user_to_num_current_syncs.items()
+ if count
+ }
+ for user_ids in self.external_process_to_current_syncs.values():
+ syncing_user_ids.update(user_ids)
+ return syncing_user_ids
+ else:
+ return set()
@defer.inlineCallbacks
def update_external_syncs_row(self, process_id, user_id, is_syncing, sync_time_msec):
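With use_presence off, the handler now reports nobody as syncing, which in turn stops presence updates flowing to the sync and federation streams. A minimal model of the gated accessor above (a plain dict standing in for the handler's state):

def get_currently_syncing_users(use_presence, per_user_sync_counts):
    if not use_presence:
        return set()
    return {
        user_id for user_id, count in per_user_sync_counts.items() if count
    }

assert get_currently_syncing_users(False, {"@a:hs": 2}) == set()
assert get_currently_syncing_users(True, {"@a:hs": 2, "@b:hs": 0}) == {"@a:hs"}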
diff --git a/synapse/handlers/read_marker.py b/synapse/handlers/read_marker.py
index 995460f82a..32108568c6 100644
--- a/synapse/handlers/read_marker.py
+++ b/synapse/handlers/read_marker.py
@@ -17,7 +17,7 @@ import logging
from twisted.internet import defer
-from synapse.util.async import Linearizer
+from synapse.util.async_helpers import Linearizer
from ._base import BaseHandler
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index cb905a3903..a6f3181f09 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -18,7 +18,6 @@ from twisted.internet import defer
from synapse.types import get_domain_from_id
from synapse.util import logcontext
-from synapse.util.logcontext import PreserveLoggingContext
from ._base import BaseHandler
@@ -116,16 +115,15 @@ class ReceiptsHandler(BaseHandler):
affected_room_ids = list(set([r["room_id"] for r in receipts]))
- with PreserveLoggingContext():
- self.notifier.on_new_event(
- "receipt_key", max_batch_id, rooms=affected_room_ids
- )
- # Note that the min here shouldn't be relied upon to be accurate.
- self.hs.get_pusherpool().on_new_receipts(
- min_batch_id, max_batch_id, affected_room_ids
- )
+ self.notifier.on_new_event(
+ "receipt_key", max_batch_id, rooms=affected_room_ids
+ )
+ # Note that the min here shouldn't be relied upon to be accurate.
+ self.hs.get_pusherpool().on_new_receipts(
+ min_batch_id, max_batch_id, affected_room_ids,
+ )
- defer.returnValue(True)
+ defer.returnValue(True)
@logcontext.preserve_fn # caller should not yield on this
@defer.inlineCallbacks
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 289704b241..f03ee1476b 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -28,7 +28,7 @@ from synapse.api.errors import (
)
from synapse.http.client import CaptchaServerHttpClient
from synapse.types import RoomAlias, RoomID, UserID, create_requester
-from synapse.util.async import Linearizer
+from synapse.util.async_helpers import Linearizer
from synapse.util.threepids import check_3pid_allowed
from ._base import BaseHandler
@@ -144,7 +144,8 @@ class RegistrationHandler(BaseHandler):
Raises:
RegistrationError if there was a problem registering.
"""
- yield self._check_mau_limits()
+
+ yield self.auth.check_auth_blocking()
password_hash = None
if password:
password_hash = yield self.auth_handler().hash(password)
@@ -289,7 +290,7 @@ class RegistrationHandler(BaseHandler):
400,
"User ID can only contain characters a-z, 0-9, or '=_-./'",
)
- yield self._check_mau_limits()
+ yield self.auth.check_auth_blocking()
user = UserID(localpart, self.hs.hostname)
user_id = user.to_string()
@@ -439,7 +440,7 @@ class RegistrationHandler(BaseHandler):
"""
if localpart is None:
raise SynapseError(400, "Request must include user id")
- yield self._check_mau_limits()
+ yield self.auth.check_auth_blocking()
need_register = True
try:
@@ -533,16 +534,3 @@ class RegistrationHandler(BaseHandler):
remote_room_hosts=remote_room_hosts,
action="join",
)
-
- @defer.inlineCallbacks
- def _check_mau_limits(self):
- """
- Do not accept registrations if monthly active user limits exceeded
- and limiting is enabled
- """
- if self.hs.config.limit_usage_by_mau is True:
- current_mau = yield self.store.count_monthly_users()
- if current_mau >= self.hs.config.max_mau_value:
- raise RegistrationError(
- 403, "MAU Limit Exceeded", Codes.MAU_LIMIT_EXCEEDED
- )
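All three registration paths now funnel through the shared auth gate instead of the private MAU check removed above. A self-contained sketch of the gate's MAU behaviour (a plain dict config and built-in exceptions standing in for Synapse's types):

def check_auth_blocking_sketch(config, current_mau, user_in_mau_cohort=False):
    if config["hs_disabled"]:
        raise PermissionError(config["hs_disabled_message"])
    if config["limit_usage_by_mau"]:
        if user_in_mau_cohort:
            return  # existing monthly-active users are never blocked
        if current_mau >= config["max_mau_value"]:
            raise PermissionError("Monthly Active User Limit Exceeded")

config = {
    "hs_disabled": False,
    "hs_disabled_message": "",
    "limit_usage_by_mau": True,
    "max_mau_value": 10,
}
check_auth_blocking_sketch(config, current_mau=3)  # under the cap: allowed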
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 7b7804d9b2..c3f820b975 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -21,9 +21,17 @@ import math
import string
from collections import OrderedDict
+from six import string_types
+
from twisted.internet import defer
-from synapse.api.constants import EventTypes, JoinRules, RoomCreationPreset
+from synapse.api.constants import (
+ DEFAULT_ROOM_VERSION,
+ KNOWN_ROOM_VERSIONS,
+ EventTypes,
+ JoinRules,
+ RoomCreationPreset,
+)
from synapse.api.errors import AuthError, Codes, StoreError, SynapseError
from synapse.types import RoomAlias, RoomID, RoomStreamToken, StreamToken, UserID
from synapse.util import stringutils
@@ -90,15 +98,34 @@ class RoomCreationHandler(BaseHandler):
Raises:
SynapseError if the room ID couldn't be stored, or something went
horribly wrong.
+            ResourceLimitError if the server is blocked because a resource
+            limit has been exceeded
"""
user_id = requester.user.to_string()
+        yield self.auth.check_auth_blocking(user_id)
+
if not self.spam_checker.user_may_create_room(user_id):
raise SynapseError(403, "You are not permitted to create rooms")
if ratelimit:
yield self.ratelimit(requester)
+ room_version = config.get("room_version", DEFAULT_ROOM_VERSION)
+ if not isinstance(room_version, string_types):
+ raise SynapseError(
+ 400,
+ "room_version must be a string",
+ Codes.BAD_JSON,
+ )
+
+ if room_version not in KNOWN_ROOM_VERSIONS:
+ raise SynapseError(
+ 400,
+ "Your homeserver does not support this room version",
+ Codes.UNSUPPORTED_ROOM_VERSION,
+ )
+
if "room_alias_name" in config:
for wchar in string.whitespace:
if wchar in config["room_alias_name"]:
@@ -184,6 +211,9 @@ class RoomCreationHandler(BaseHandler):
creation_content = config.get("creation_content", {})
+ # override any attempt to set room versions via the creation_content
+ creation_content["room_version"] = room_version
+
room_member_handler = self.hs.get_room_member_handler()
yield self._send_events_for_new_room(
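The room-version checks above reject non-string values with a 400 BAD_JSON and unknown versions with a 400 UNSUPPORTED_ROOM_VERSION, before the accepted value is force-written into creation_content. A hedged sketch with illustrative constants in place of synapse.api.constants:

KNOWN_ROOM_VERSIONS = {"1"}  # assumption: whatever this server supports
DEFAULT_ROOM_VERSION = "1"

def validate_room_version(config):
    room_version = config.get("room_version", DEFAULT_ROOM_VERSION)
    if not isinstance(room_version, str):
        # stands in for the 400 BAD_JSON SynapseError above
        raise ValueError("room_version must be a string")
    if room_version not in KNOWN_ROOM_VERSIONS:
        # stands in for the 400 UNSUPPORTED_ROOM_VERSION error above
        raise ValueError("Your homeserver does not support this room version")
    return room_version

assert validate_room_version({}) == "1"  # default applied
assert validate_room_version({"room_version": "1"}) == "1"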
diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
index 828229f5c3..37e41afd61 100644
--- a/synapse/handlers/room_list.py
+++ b/synapse/handlers/room_list.py
@@ -26,7 +26,7 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, JoinRules
from synapse.types import ThirdPartyInstanceID
-from synapse.util.async import concurrently_execute
+from synapse.util.async_helpers import concurrently_execute
from synapse.util.caches.descriptors import cachedInlineCallbacks
from synapse.util.caches.response_cache import ResponseCache
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 0d4a3f4677..fb94b5d7d4 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -30,7 +30,7 @@ import synapse.types
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import AuthError, Codes, SynapseError
from synapse.types import RoomID, UserID
-from synapse.util.async import Linearizer
+from synapse.util.async_helpers import Linearizer
from synapse.util.distributor import user_joined_room, user_left_room
logger = logging.getLogger(__name__)
diff --git a/synapse/handlers/room_member_worker.py b/synapse/handlers/room_member_worker.py
index 22d8b4b0d3..acc6eb8099 100644
--- a/synapse/handlers/room_member_worker.py
+++ b/synapse/handlers/room_member_worker.py
@@ -20,16 +20,24 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError
from synapse.handlers.room_member import RoomMemberHandler
from synapse.replication.http.membership import (
- get_or_register_3pid_guest,
- notify_user_membership_change,
- remote_join,
- remote_reject_invite,
+ ReplicationRegister3PIDGuestRestServlet as Repl3PID,
+ ReplicationRemoteJoinRestServlet as ReplRemoteJoin,
+ ReplicationRemoteRejectInviteRestServlet as ReplRejectInvite,
+ ReplicationUserJoinedLeftRoomRestServlet as ReplJoinedLeft,
)
logger = logging.getLogger(__name__)
class RoomMemberWorkerHandler(RoomMemberHandler):
+ def __init__(self, hs):
+ super(RoomMemberWorkerHandler, self).__init__(hs)
+
+ self._get_register_3pid_client = Repl3PID.make_client(hs)
+ self._remote_join_client = ReplRemoteJoin.make_client(hs)
+ self._remote_reject_client = ReplRejectInvite.make_client(hs)
+ self._notify_change_client = ReplJoinedLeft.make_client(hs)
+
@defer.inlineCallbacks
def _remote_join(self, requester, remote_room_hosts, room_id, user, content):
"""Implements RoomMemberHandler._remote_join
@@ -37,10 +45,7 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
if len(remote_room_hosts) == 0:
raise SynapseError(404, "No known servers")
- ret = yield remote_join(
- self.simple_http_client,
- host=self.config.worker_replication_host,
- port=self.config.worker_replication_http_port,
+ ret = yield self._remote_join_client(
requester=requester,
remote_room_hosts=remote_room_hosts,
room_id=room_id,
@@ -55,10 +60,7 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
def _remote_reject_invite(self, requester, remote_room_hosts, room_id, target):
"""Implements RoomMemberHandler._remote_reject_invite
"""
- return remote_reject_invite(
- self.simple_http_client,
- host=self.config.worker_replication_host,
- port=self.config.worker_replication_http_port,
+ return self._remote_reject_client(
requester=requester,
remote_room_hosts=remote_room_hosts,
room_id=room_id,
@@ -68,10 +70,7 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
def _user_joined_room(self, target, room_id):
"""Implements RoomMemberHandler._user_joined_room
"""
- return notify_user_membership_change(
- self.simple_http_client,
- host=self.config.worker_replication_host,
- port=self.config.worker_replication_http_port,
+ return self._notify_change_client(
user_id=target.to_string(),
room_id=room_id,
change="joined",
@@ -80,10 +79,7 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
def _user_left_room(self, target, room_id):
"""Implements RoomMemberHandler._user_left_room
"""
- return notify_user_membership_change(
- self.simple_http_client,
- host=self.config.worker_replication_host,
- port=self.config.worker_replication_http_port,
+ return self._notify_change_client(
user_id=target.to_string(),
room_id=room_id,
change="left",
@@ -92,10 +88,7 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
def get_or_register_3pid_guest(self, requester, medium, address, inviter_user_id):
"""Implements RoomMemberHandler.get_or_register_3pid_guest
"""
- return get_or_register_3pid_guest(
- self.simple_http_client,
- host=self.config.worker_replication_host,
- port=self.config.worker_replication_http_port,
+ return self._get_register_3pid_client(
requester=requester,
medium=medium,
address=address,
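The refactor above replaces free functions taking an explicit client/host/port with bound callables built once in the constructor. The shape of that make_client pattern, sketched without Synapse's replication machinery (every name here is a stand-in):

class ReplicationEndpointSketch:
    NAME = "remote_join"

    @classmethod
    def make_client(cls, hs):
        # hs would carry the replication host/port; unused in this sketch
        def send_request(**kwargs):
            # in Synapse this would POST to the master's replication listener
            return {"endpoint": cls.NAME, "args": kwargs}
        return send_request

remote_join = ReplicationEndpointSketch.make_client(hs=None)
print(remote_join(room_id="!room:hs", user_id="@user:hs"))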
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index dff1f67dcb..648debc8aa 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -25,7 +25,7 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
from synapse.push.clientformat import format_push_rules_for_user
from synapse.types import RoomStreamToken
-from synapse.util.async import concurrently_execute
+from synapse.util.async_helpers import concurrently_execute
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.response_cache import ResponseCache
@@ -75,6 +75,7 @@ class JoinedSyncResult(collections.namedtuple("JoinedSyncResult", [
"ephemeral",
"account_data",
"unread_notifications",
+ "summary",
])):
__slots__ = []
@@ -184,6 +185,7 @@ class SyncResult(collections.namedtuple("SyncResult", [
class SyncHandler(object):
def __init__(self, hs):
+ self.hs_config = hs.config
self.store = hs.get_datastore()
self.notifier = hs.get_notifier()
self.presence_handler = hs.get_presence_handler()
@@ -191,6 +193,7 @@ class SyncHandler(object):
self.clock = hs.get_clock()
self.response_cache = ResponseCache(hs, "sync")
self.state = hs.get_state_handler()
+ self.auth = hs.get_auth()
# ExpiringCache((User, Device)) -> LruCache(state_key => event_id)
self.lazy_loaded_members_cache = ExpiringCache(
@@ -198,19 +201,27 @@ class SyncHandler(object):
max_len=0, expiry_ms=LAZY_LOADED_MEMBERS_CACHE_MAX_AGE,
)
+ @defer.inlineCallbacks
def wait_for_sync_for_user(self, sync_config, since_token=None, timeout=0,
full_state=False):
"""Get the sync for a client if we have new data for it now. Otherwise
wait for new data to arrive on the server. If the timeout expires, then
return an empty sync result.
Returns:
- A Deferred SyncResult.
+ Deferred[SyncResult]
"""
- return self.response_cache.wrap(
+ # If the user is not part of the mau group, then check that limits have
+        # not been exceeded (if the user is not in the group by this point,
+        # auth blocking is almost certain to occur)
+ user_id = sync_config.user.to_string()
+ yield self.auth.check_auth_blocking(user_id)
+
+ res = yield self.response_cache.wrap(
sync_config.request_key,
self._wait_for_sync_for_user,
sync_config, since_token, timeout, full_state,
)
+ defer.returnValue(res)
@defer.inlineCallbacks
def _wait_for_sync_for_user(self, sync_config, since_token, timeout,
@@ -495,9 +506,141 @@ class SyncHandler(object):
defer.returnValue(state)
@defer.inlineCallbacks
+ def compute_summary(self, room_id, sync_config, batch, state, now_token):
+ """ Works out a room summary block for this room, summarising the number
+ of joined members in the room, and providing the 'hero' members if the
+ room has no name so clients can consistently name rooms. Also adds
+ state events to 'state' if needed to describe the heroes.
+
+ Args:
+ room_id(str):
+ sync_config(synapse.handlers.sync.SyncConfig):
+ batch(synapse.handlers.sync.TimelineBatch): The timeline batch for
+ the room that will be sent to the user.
+ state(dict): dict of (type, state_key) -> Event as returned by
+ compute_state_delta
+            now_token(StreamToken): Token of the end of the current batch.
+
+ Returns:
+ A deferred dict describing the room summary
+ """
+
+        # FIXME: this perpetuates https://github.com/matrix-org/synapse/issues/3305
+ last_events, _ = yield self.store.get_recent_event_ids_for_room(
+ room_id, end_token=now_token.room_key, limit=1,
+ )
+
+ if not last_events:
+ defer.returnValue(None)
+ return
+
+ last_event = last_events[-1]
+ state_ids = yield self.store.get_state_ids_for_event(
+ last_event.event_id, [
+ (EventTypes.Member, None),
+ (EventTypes.Name, ''),
+ (EventTypes.CanonicalAlias, ''),
+ ]
+ )
+
+ member_ids = {
+ state_key: event_id
+ for (t, state_key), event_id in state_ids.iteritems()
+ if t == EventTypes.Member
+ }
+ name_id = state_ids.get((EventTypes.Name, ''))
+ canonical_alias_id = state_ids.get((EventTypes.CanonicalAlias, ''))
+
+ summary = {}
+
+ # FIXME: it feels very heavy to load up every single membership event
+ # just to calculate the counts.
+ member_events = yield self.store.get_events(member_ids.values())
+
+ joined_user_ids = []
+ invited_user_ids = []
+
+ for ev in member_events.values():
+ if ev.content.get("membership") == Membership.JOIN:
+ joined_user_ids.append(ev.state_key)
+ elif ev.content.get("membership") == Membership.INVITE:
+ invited_user_ids.append(ev.state_key)
+
+ # TODO: only send these when they change.
+ summary["m.joined_member_count"] = len(joined_user_ids)
+ summary["m.invited_member_count"] = len(invited_user_ids)
+
+ if name_id or canonical_alias_id:
+ defer.returnValue(summary)
+
+ # FIXME: order by stream ordering, not alphabetic
+
+ me = sync_config.user.to_string()
+ if (joined_user_ids or invited_user_ids):
+ summary['m.heroes'] = sorted(
+ [
+ user_id
+ for user_id in (joined_user_ids + invited_user_ids)
+ if user_id != me
+ ]
+ )[0:5]
+ else:
+ summary['m.heroes'] = sorted(
+ [user_id for user_id in member_ids.keys() if user_id != me]
+ )[0:5]
+
+ if not sync_config.filter_collection.lazy_load_members():
+ defer.returnValue(summary)
+
+ # ensure we send membership events for heroes if needed
+ cache_key = (sync_config.user.to_string(), sync_config.device_id)
+ cache = self.get_lazy_loaded_members_cache(cache_key)
+
+ # track which members the client should already know about via LL:
+ # Ones which are already in state...
+ existing_members = set(
+ user_id for (typ, user_id) in state.keys()
+ if typ == EventTypes.Member
+ )
+
+ # ...or ones which are in the timeline...
+ for ev in batch.events:
+ if ev.type == EventTypes.Member:
+ existing_members.add(ev.state_key)
+
+ # ...and then ensure any missing ones get included in state.
+ missing_hero_event_ids = [
+ member_ids[hero_id]
+ for hero_id in summary['m.heroes']
+ if (
+ cache.get(hero_id) != member_ids[hero_id] and
+ hero_id not in existing_members
+ )
+ ]
+
+ missing_hero_state = yield self.store.get_events(missing_hero_event_ids)
+ missing_hero_state = missing_hero_state.values()
+
+ for s in missing_hero_state:
+ cache.set(s.state_key, s.event_id)
+ state[(EventTypes.Member, s.state_key)] = s
+
+ defer.returnValue(summary)
+
+ def get_lazy_loaded_members_cache(self, cache_key):
+ cache = self.lazy_loaded_members_cache.get(cache_key)
+ if cache is None:
+ logger.debug("creating LruCache for %r", cache_key)
+ cache = LruCache(LAZY_LOADED_MEMBERS_CACHE_MAX_SIZE)
+ self.lazy_loaded_members_cache[cache_key] = cache
+ else:
+ logger.debug("found LruCache for %r", cache_key)
+ return cache
+
+ @defer.inlineCallbacks
def compute_state_delta(self, room_id, batch, sync_config, since_token, now_token,
full_state):
- """ Works out the differnce in state between the start of the timeline
+ """ Works out the difference in state between the start of the timeline
and the previous sync.
Args:
@@ -511,7 +654,7 @@ class SyncHandler(object):
full_state(bool): Whether to force returning the full state.
Returns:
- A deferred new event dictionary
+ A deferred dict of (type, state_key) -> Event
"""
# TODO(mjark) Check if the state events were received by the server
# after the previous sync, since we need to include those state
@@ -609,13 +752,7 @@ class SyncHandler(object):
if lazy_load_members and not include_redundant_members:
cache_key = (sync_config.user.to_string(), sync_config.device_id)
- cache = self.lazy_loaded_members_cache.get(cache_key)
- if cache is None:
- logger.debug("creating LruCache for %r", cache_key)
- cache = LruCache(LAZY_LOADED_MEMBERS_CACHE_MAX_SIZE)
- self.lazy_loaded_members_cache[cache_key] = cache
- else:
- logger.debug("found LruCache for %r", cache_key)
+ cache = self.get_lazy_loaded_members_cache(cache_key)
# if it's a new sync sequence, then assume the client has had
# amnesia and doesn't want any recent lazy-loaded members
@@ -724,7 +861,7 @@ class SyncHandler(object):
since_token is None and
sync_config.filter_collection.blocks_all_presence()
)
- if not block_all_presence_data:
+ if self.hs_config.use_presence and not block_all_presence_data:
yield self._generate_sync_entry_for_presence(
sync_result_builder, newly_joined_rooms, newly_joined_users
)
@@ -1416,7 +1553,6 @@ class SyncHandler(object):
if events == [] and tags is None:
return
- since_token = sync_result_builder.since_token
now_token = sync_result_builder.now_token
sync_config = sync_result_builder.sync_config
@@ -1459,6 +1595,18 @@ class SyncHandler(object):
full_state=full_state
)
+ summary = {}
+ if (
+ sync_config.filter_collection.lazy_load_members() and
+ (
+ any(ev.type == EventTypes.Member for ev in batch.events) or
+ since_token is None
+ )
+ ):
+ summary = yield self.compute_summary(
+ room_id, sync_config, batch, state, now_token
+ )
+
if room_builder.rtype == "joined":
unread_notifications = {}
room_sync = JoinedSyncResult(
@@ -1468,6 +1616,7 @@ class SyncHandler(object):
ephemeral=ephemeral,
account_data=account_data_events,
unread_notifications=unread_notifications,
+ summary=summary,
)
if room_sync or always_include:
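compute_summary's hero selection, as written above, excludes the requesting user, sorts alphabetically (the FIXME notes stream ordering would be better) and keeps at most five entries, falling back to all known members when nobody is joined or invited. Distilled into a few lines:

def pick_heroes(me, joined_user_ids, invited_user_ids, member_ids):
    candidates = (joined_user_ids + invited_user_ids) or list(member_ids)
    return sorted(u for u in candidates if u != me)[0:5]

assert pick_heroes("@me:hs", ["@b:hs", "@a:hs"], [], {}) == ["@a:hs", "@b:hs"]
# nobody joined or invited: fall back to any known members
assert pick_heroes("@me:hs", [], [], {"@z:hs": "$ev"}) == ["@z:hs"]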
diff --git a/synapse/http/client.py b/synapse/http/client.py
index 3771e0b3f6..ab4fbf59b2 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -42,7 +42,7 @@ from twisted.web.http_headers import Headers
from synapse.api.errors import Codes, HttpResponseException, SynapseError
from synapse.http import cancelled_to_request_timed_out_error, redact_uri
from synapse.http.endpoint import SpiderEndpoint
-from synapse.util.async import add_timeout_to_deferred
+from synapse.util.async_helpers import add_timeout_to_deferred
from synapse.util.caches import CACHE_SIZE_FACTOR
from synapse.util.logcontext import make_deferred_yieldable
diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py
index d65daa72bb..b0c9369519 100644
--- a/synapse/http/endpoint.py
+++ b/synapse/http/endpoint.py
@@ -26,7 +26,6 @@ from twisted.names.error import DNSNameError, DomainError
logger = logging.getLogger(__name__)
-
SERVER_CACHE = {}
# our record of an individual server which can be tried to reach a destination.
@@ -103,15 +102,16 @@ def parse_and_validate_server_name(server_name):
return host, port
-def matrix_federation_endpoint(reactor, destination, ssl_context_factory=None,
+def matrix_federation_endpoint(reactor, destination, tls_client_options_factory=None,
timeout=None):
"""Construct an endpoint for the given matrix destination.
Args:
reactor: Twisted reactor.
destination (bytes): The name of the server to connect to.
- ssl_context_factory (twisted.internet.ssl.ContextFactory): Factory
- which generates SSL contexts to use for TLS.
+ tls_client_options_factory
+ (synapse.crypto.context_factory.ClientTLSOptionsFactory):
+ Factory which generates TLS options for client connections.
timeout (int): connection timeout in seconds
"""
@@ -122,13 +122,13 @@ def matrix_federation_endpoint(reactor, destination, ssl_context_factory=None,
if timeout is not None:
endpoint_kw_args.update(timeout=timeout)
- if ssl_context_factory is None:
+ if tls_client_options_factory is None:
transport_endpoint = HostnameEndpoint
default_port = 8008
else:
def transport_endpoint(reactor, host, port, timeout):
return wrapClientTLS(
- ssl_context_factory,
+ tls_client_options_factory.get_options(host),
HostnameEndpoint(reactor, host, port, timeout=timeout))
default_port = 8448
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index bf1aa29502..44b61e70a4 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -43,7 +43,7 @@ from synapse.api.errors import (
from synapse.http import cancelled_to_request_timed_out_error
from synapse.http.endpoint import matrix_federation_endpoint
from synapse.util import logcontext
-from synapse.util.async import add_timeout_to_deferred
+from synapse.util.async_helpers import add_timeout_to_deferred
from synapse.util.logcontext import make_deferred_yieldable
logger = logging.getLogger(__name__)
@@ -61,14 +61,14 @@ MAX_SHORT_RETRIES = 3
class MatrixFederationEndpointFactory(object):
def __init__(self, hs):
- self.tls_server_context_factory = hs.tls_server_context_factory
+ self.tls_client_options_factory = hs.tls_client_options_factory
def endpointForURI(self, uri):
destination = uri.netloc
return matrix_federation_endpoint(
reactor, destination, timeout=10,
- ssl_context_factory=self.tls_server_context_factory
+ tls_client_options_factory=self.tls_client_options_factory
)
@@ -439,7 +439,7 @@ class MatrixFederationHttpClient(object):
defer.returnValue(json.loads(body))
@defer.inlineCallbacks
- def get_json(self, destination, path, args={}, retry_on_dns_fail=True,
+ def get_json(self, destination, path, args=None, retry_on_dns_fail=True,
timeout=None, ignore_backoff=False):
""" GETs some json from the given host homeserver and path
@@ -447,7 +447,7 @@ class MatrixFederationHttpClient(object):
destination (str): The remote server to send the HTTP request
to.
path (str): The HTTP path.
- args (dict): A dictionary used to create query strings, defaults to
+ args (dict|None): A dictionary used to create query strings, defaults to
None.
timeout (int): How long to try (in ms) the destination for before
giving up. None indicates no timeout and that the request will
@@ -702,6 +702,9 @@ def check_content_type_is_json(headers):
def encode_query_args(args):
+ if args is None:
+ return b""
+
encoded_args = {}
for k, vs in args.items():
if isinstance(vs, string_types):
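Switching get_json's args default from {} to None sidesteps Python's shared-mutable-default pitfall; encode_query_args then simply maps None to an empty query string. The hazard being avoided, in miniature:

def bad(args={}):      # one dict object shared across every call
    args["x"] = 1
    return len(args)

def good(args=None):   # fresh handling on each call
    args = args or {}
    args["x"] = 1
    return len(args)

assert bad() == bad() == 1      # looks fine at first...
bad.__defaults__[0]["y"] = 2    # ...but the default dict is mutable state
assert bad() == 2
assert good() == good() == 1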
diff --git a/synapse/http/server.py b/synapse/http/server.py
index 6dacb31037..2d5c23e673 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -25,8 +25,9 @@ from canonicaljson import encode_canonical_json, encode_pretty_printed_json, jso
from twisted.internet import defer
from twisted.python import failure
-from twisted.web import resource, server
+from twisted.web import resource
from twisted.web.server import NOT_DONE_YET
+from twisted.web.static import NoRangeStaticProducer
from twisted.web.util import redirectTo
import synapse.events
@@ -37,10 +38,13 @@ from synapse.api.errors import (
SynapseError,
UnrecognizedRequestError,
)
-from synapse.http.request_metrics import requests_counter
from synapse.util.caches import intern_dict
-from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
-from synapse.util.metrics import Measure
+from synapse.util.logcontext import preserve_fn
+
+if PY3:
+ from io import BytesIO
+else:
+ from cStringIO import StringIO as BytesIO
logger = logging.getLogger(__name__)
@@ -60,11 +64,10 @@ HTML_ERROR_TEMPLATE = """<!DOCTYPE html>
def wrap_json_request_handler(h):
"""Wraps a request handler method with exception handling.
- Also adds logging as per wrap_request_handler_with_logging.
+ Also does the wrapping with request.processing as per wrap_async_request_handler.
The handler method must have a signature of "handle_foo(self, request)",
- where "self" must have a "clock" attribute (and "request" must be a
- SynapseRequest).
+ where "request" must be a SynapseRequest.
The handler must return a deferred. If the deferred succeeds we assume that
a response has been sent. If the deferred fails with a SynapseError we use
@@ -108,24 +111,23 @@ def wrap_json_request_handler(h):
pretty_print=_request_user_agent_is_curl(request),
)
- return wrap_request_handler_with_logging(wrapped_request_handler)
+ return wrap_async_request_handler(wrapped_request_handler)
def wrap_html_request_handler(h):
"""Wraps a request handler method with exception handling.
- Also adds logging as per wrap_request_handler_with_logging.
+ Also does the wrapping with request.processing as per wrap_async_request_handler.
The handler method must have a signature of "handle_foo(self, request)",
- where "self" must have a "clock" attribute (and "request" must be a
- SynapseRequest).
+ where "request" must be a SynapseRequest.
"""
def wrapped_request_handler(self, request):
d = defer.maybeDeferred(h, self, request)
d.addErrback(_return_html_error, request)
return d
- return wrap_request_handler_with_logging(wrapped_request_handler)
+ return wrap_async_request_handler(wrapped_request_handler)
def _return_html_error(f, request):
@@ -170,46 +172,26 @@ def _return_html_error(f, request):
finish_request(request)
-def wrap_request_handler_with_logging(h):
- """Wraps a request handler to provide logging and metrics
+def wrap_async_request_handler(h):
+ """Wraps an async request handler so that it calls request.processing.
+
+ This helps ensure that work done by the request handler after the request is completed
+ is correctly recorded against the request metrics/logs.
The handler method must have a signature of "handle_foo(self, request)",
- where "self" must have a "clock" attribute (and "request" must be a
- SynapseRequest).
+ where "request" must be a SynapseRequest.
- As well as calling `request.processing` (which will log the response and
- duration for this request), the wrapped request handler will insert the
- request id into the logging context.
+ The handler may return a deferred, in which case the completion of the request isn't
+ logged until the deferred completes.
"""
@defer.inlineCallbacks
- def wrapped_request_handler(self, request):
- """
- Args:
- self:
- request (synapse.http.site.SynapseRequest):
- """
+ def wrapped_async_request_handler(self, request):
+ with request.processing():
+ yield h(self, request)
- request_id = request.get_request_id()
- with LoggingContext(request_id) as request_context:
- request_context.request = request_id
- with Measure(self.clock, "wrapped_request_handler"):
- # we start the request metrics timer here with an initial stab
- # at the servlet name. For most requests that name will be
- # JsonResource (or a subclass), and JsonResource._async_render
- # will update it once it picks a servlet.
- servlet_name = self.__class__.__name__
- with request.processing(servlet_name):
- with PreserveLoggingContext(request_context):
- d = defer.maybeDeferred(h, self, request)
-
- # record the arrival of the request *after*
- # dispatching to the handler, so that the handler
- # can update the servlet name in the request
- # metrics
- requests_counter.labels(request.method,
- request.request_metrics.name).inc()
- yield d
- return wrapped_request_handler
+ # we need to preserve_fn here, because the synchronous render method won't yield for
+ # us (obviously)
+ return preserve_fn(wrapped_async_request_handler)
class HttpServer(object):
@@ -272,7 +254,7 @@ class JsonResource(HttpServer, resource.Resource):
""" This gets called by twisted every time someone sends us a request.
"""
self._async_render(request)
- return server.NOT_DONE_YET
+ return NOT_DONE_YET
@wrap_json_request_handler
@defer.inlineCallbacks
@@ -413,8 +395,7 @@ def respond_with_json(request, code, json_object, send_cors=False,
return
if pretty_print:
- json_bytes = (encode_pretty_printed_json(json_object) + "\n"
- ).encode("utf-8")
+ json_bytes = encode_pretty_printed_json(json_object) + b"\n"
else:
if canonical_json or synapse.events.USE_FROZEN_DICTS:
# canonicaljson already encodes to bytes
@@ -450,8 +431,12 @@ def respond_with_json_bytes(request, code, json_bytes, send_cors=False,
if send_cors:
set_cors_headers(request)
- request.write(json_bytes)
- finish_request(request)
+ # todo: we can almost certainly avoid this copy and encode the json straight into
+ # the bytesIO, but it would involve faffing around with string->bytes wrappers.
+ bytes_io = BytesIO(json_bytes)
+
+ producer = NoRangeStaticProducer(request, bytes_io)
+ producer.start()
return NOT_DONE_YET
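The processing/finish interplay that wrap_async_request_handler relies on is subtle: the request is only logged once the handler has left its processing() block and the response (or disconnection) has completed, whichever happens last. A toy model of that handshake, with built-ins in place of SynapseRequest:

import contextlib
import time

class ToyRequest:
    def __init__(self):
        self._is_processing = False
        self.finish_time = None
        self.logged = False

    @contextlib.contextmanager
    def processing(self):
        self._is_processing = True
        try:
            yield
        finally:
            self._is_processing = False
            # response already sent while we were processing: log now
            if self.finish_time is not None:
                self.logged = True

    def finish(self):
        self.finish_time = time.time()
        if not self._is_processing:
            self.logged = True

request = ToyRequest()
with request.processing():
    request.finish()   # response completes while the handler is still busy
assert request.logged  # ...so logging happens when processing() exits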
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index 69f7085291..a1e4b88e6d 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -29,7 +29,7 @@ def parse_integer(request, name, default=None, required=False):
Args:
request: the twisted HTTP request.
- name (str): the name of the query parameter.
+ name (bytes/unicode): the name of the query parameter.
default (int|None): value to use if the parameter is absent, defaults
to None.
required (bool): whether to raise a 400 SynapseError if the
@@ -46,6 +46,10 @@ def parse_integer(request, name, default=None, required=False):
def parse_integer_from_args(args, name, default=None, required=False):
+
+ if not isinstance(name, bytes):
+ name = name.encode('ascii')
+
if name in args:
try:
return int(args[name][0])
@@ -65,7 +69,7 @@ def parse_boolean(request, name, default=None, required=False):
Args:
request: the twisted HTTP request.
- name (str): the name of the query parameter.
+ name (bytes/unicode): the name of the query parameter.
default (bool|None): value to use if the parameter is absent, defaults
to None.
required (bool): whether to raise a 400 SynapseError if the
@@ -83,11 +87,15 @@ def parse_boolean(request, name, default=None, required=False):
def parse_boolean_from_args(args, name, default=None, required=False):
+
+ if not isinstance(name, bytes):
+ name = name.encode('ascii')
+
if name in args:
try:
return {
- "true": True,
- "false": False,
+ b"true": True,
+ b"false": False,
}[args[name][0]]
except Exception:
message = (
@@ -104,21 +112,29 @@ def parse_boolean_from_args(args, name, default=None, required=False):
def parse_string(request, name, default=None, required=False,
- allowed_values=None, param_type="string"):
- """Parse a string parameter from the request query string.
+ allowed_values=None, param_type="string", encoding='ascii'):
+ """
+ Parse a string parameter from the request query string.
+
+    If encoding is not None, the content of the query param will be
+    decoded to Unicode using the encoding, otherwise it will be returned
+    as bytes.
Args:
request: the twisted HTTP request.
- name (str): the name of the query parameter.
- default (str|None): value to use if the parameter is absent, defaults
- to None.
+ name (bytes/unicode): the name of the query parameter.
+ default (bytes/unicode|None): value to use if the parameter is absent,
+ defaults to None. Must be bytes if encoding is None.
required (bool): whether to raise a 400 SynapseError if the
parameter is absent, defaults to False.
- allowed_values (list[str]): List of allowed values for the string,
- or None if any value is allowed, defaults to None
+ allowed_values (list[bytes/unicode]): List of allowed values for the
+ string, or None if any value is allowed, defaults to None. Must be
+ the same type as name, if given.
+        encoding: The encoding to decode the string content with.
Returns:
- str|None: A string value or the default.
+ bytes/unicode|None: A string value or the default. Unicode if encoding
+ was given, bytes otherwise.
Raises:
SynapseError if the parameter is absent and required, or if the
@@ -126,14 +142,22 @@ def parse_string(request, name, default=None, required=False,
is not one of those allowed values.
"""
return parse_string_from_args(
- request.args, name, default, required, allowed_values, param_type,
+ request.args, name, default, required, allowed_values, param_type, encoding
)
def parse_string_from_args(args, name, default=None, required=False,
- allowed_values=None, param_type="string"):
+ allowed_values=None, param_type="string", encoding='ascii'):
+
+ if not isinstance(name, bytes):
+ name = name.encode('ascii')
+
if name in args:
value = args[name][0]
+
+ if encoding:
+ value = value.decode(encoding)
+
if allowed_values is not None and value not in allowed_values:
message = "Query parameter %r must be one of [%s]" % (
name, ", ".join(repr(v) for v in allowed_values)
@@ -146,6 +170,10 @@ def parse_string_from_args(args, name, default=None, required=False,
message = "Missing %s query parameter %r" % (param_type, name)
raise SynapseError(400, message, errcode=Codes.MISSING_PARAM)
else:
+
+ if encoding and isinstance(default, bytes):
+ return default.decode(encoding)
+
return default
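The net effect of the servlet changes is that parameter names are normalised to bytes (query args arrive as byte strings on both Python 2 and 3) while values and byte-string defaults are decoded with the requested encoding. A condensed model of parse_string_from_args under that reading:

def parse_string_from_args_sketch(args, name, default=None, encoding='ascii'):
    if not isinstance(name, bytes):
        name = name.encode('ascii')
    if name in args:
        value = args[name][0]
        return value.decode(encoding) if encoding else value
    if encoding and isinstance(default, bytes):
        return default.decode(encoding)
    return default

assert parse_string_from_args_sketch({b"from": [b"s1_2"]}, "from") == "s1_2"
assert parse_string_from_args_sketch({}, "dir", default=b"f") == "f"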
diff --git a/synapse/http/site.py b/synapse/http/site.py
index 5fd30a4c2c..88ed3714f9 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import contextlib
import logging
import time
@@ -19,8 +18,8 @@ import time
from twisted.web.server import Request, Site
from synapse.http import redact_uri
-from synapse.http.request_metrics import RequestMetrics
-from synapse.util.logcontext import ContextResourceUsage, LoggingContext
+from synapse.http.request_metrics import RequestMetrics, requests_counter
+from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
logger = logging.getLogger(__name__)
@@ -34,25 +33,43 @@ class SynapseRequest(Request):
It extends twisted's twisted.web.server.Request, and adds:
* Unique request ID
+ * A log context associated with the request
* Redaction of access_token query-params in __repr__
* Logging at start and end
* Metrics to record CPU, wallclock and DB time by endpoint.
- It provides a method `processing` which should be called by the Resource
- which is handling the request, and returns a context manager.
+ It also provides a method `processing`, which returns a context manager. If this
+ method is called, the request won't be logged until the context manager is closed;
+ this is useful for asynchronous request handlers which may go on processing the
+ request even after the client has disconnected.
+ Attributes:
+ logcontext(LoggingContext) : the log context for this request
"""
def __init__(self, site, channel, *args, **kw):
Request.__init__(self, channel, *args, **kw)
self.site = site
- self._channel = channel
+ self._channel = channel # this is used by the tests
self.authenticated_entity = None
self.start_time = 0
+ # we can't yet create the logcontext, as we don't know the method.
+ self.logcontext = None
+
global _next_request_seq
self.request_seq = _next_request_seq
_next_request_seq += 1
+ # whether an asynchronous request handler has called processing()
+ self._is_processing = False
+
+ # the time when the asynchronous request handler completed its processing
+ self._processing_finished_time = None
+
+ # what time we finished sending the response to the client (or the connection
+ # dropped)
+ self.finish_time = None
+
def __repr__(self):
# We overwrite this so that we don't log ``access_token``
return '<%s at 0x%x method=%r uri=%r clientproto=%r site=%r>' % (
@@ -74,11 +91,116 @@ class SynapseRequest(Request):
return self.requestHeaders.getRawHeaders(b"User-Agent", [None])[-1]
def render(self, resrc):
+ # this is called once a Resource has been found to serve the request; in our
+ # case the Resource in question will normally be a JsonResource.
+
+ # create a LogContext for this request
+ request_id = self.get_request_id()
+ logcontext = self.logcontext = LoggingContext(request_id)
+ logcontext.request = request_id
+
# override the Server header which is set by twisted
self.setHeader("Server", self.site.server_version_string)
- return Request.render(self, resrc)
+
+ with PreserveLoggingContext(self.logcontext):
+ # we start the request metrics timer here with an initial stab
+ # at the servlet name. For most requests that name will be
+ # JsonResource (or a subclass), and JsonResource._async_render
+ # will update it once it picks a servlet.
+ servlet_name = resrc.__class__.__name__
+ self._started_processing(servlet_name)
+
+ Request.render(self, resrc)
+
+ # record the arrival of the request *after*
+ # dispatching to the handler, so that the handler
+ # can update the servlet name in the request
+ # metrics
+ requests_counter.labels(self.method,
+ self.request_metrics.name).inc()
+
+ @contextlib.contextmanager
+ def processing(self):
+ """Record the fact that we are processing this request.
+
+ Returns a context manager; the correct way to use this is:
+
+ @defer.inlineCallbacks
+ def handle_request(request):
+            with request.processing():
+ yield really_handle_the_request()
+
+ Once the context manager is closed, the completion of the request will be logged,
+ and the various metrics will be updated.
+ """
+ if self._is_processing:
+ raise RuntimeError("Request is already processing")
+ self._is_processing = True
+
+ try:
+ yield
+ except Exception:
+ # this should already have been caught, and sent back to the client as a 500.
+            logger.exception("Asynchronous message handler raised an uncaught exception")
+ finally:
+ # the request handler has finished its work and either sent the whole response
+ # back, or handed over responsibility to a Producer.
+
+ self._processing_finished_time = time.time()
+ self._is_processing = False
+
+ # if we've already sent the response, log it now; otherwise, we wait for the
+ # response to be sent.
+ if self.finish_time is not None:
+ self._finished_processing()
+
+ def finish(self):
+ """Called when all response data has been written to this Request.
+
+ Overrides twisted.web.server.Request.finish to record the finish time and do
+ logging.
+ """
+ self.finish_time = time.time()
+ Request.finish(self)
+ if not self._is_processing:
+ with PreserveLoggingContext(self.logcontext):
+ self._finished_processing()
+
+ def connectionLost(self, reason):
+ """Called when the client connection is closed before the response is written.
+
+ Overrides twisted.web.server.Request.connectionLost to record the finish time and
+ do logging.
+ """
+ self.finish_time = time.time()
+ Request.connectionLost(self, reason)
+
+ # we only get here if the connection to the client drops before we send
+ # the response.
+ #
+ # It's useful to log it here so that we can get an idea of when
+ # the client disconnects.
+ with PreserveLoggingContext(self.logcontext):
+ logger.warn(
+ "Error processing request %r: %s %s", self, reason.type, reason.value,
+ )
+
+ if not self._is_processing:
+ self._finished_processing()
def _started_processing(self, servlet_name):
+ """Record the fact that we are processing this request.
+
+ This will log the request's arrival. Once the request completes,
+ be sure to call finished_processing.
+
+ Args:
+ servlet_name (str): the name of the servlet which will be
+ processing this request. This is used in the metrics.
+
+ It is possible to update this afterwards by updating
+ self.request_metrics.name.
+ """
self.start_time = time.time()
self.request_metrics = RequestMetrics()
self.request_metrics.start(
@@ -94,18 +216,32 @@ class SynapseRequest(Request):
)
def _finished_processing(self):
- try:
- context = LoggingContext.current_context()
- usage = context.get_resource_usage()
- except Exception:
- usage = ContextResourceUsage()
+ """Log the completion of this request and update the metrics
+ """
+
+ if self.logcontext is None:
+ # this can happen if the connection closed before we read the
+ # headers (so render was never called). In that case we'll already
+ # have logged a warning, so just bail out.
+ return
+
+ usage = self.logcontext.get_resource_usage()
+
+ if self._processing_finished_time is None:
+ # we completed the request without anything calling processing()
+ self._processing_finished_time = time.time()
- end_time = time.time()
+ # the time between receiving the request and the request handler finishing
+ processing_time = self._processing_finished_time - self.start_time
+
+ # the time between the request handler finishing and the response being sent
+ # to the client (nb may be negative)
+ response_send_time = self.finish_time - self._processing_finished_time
# need to decode as it could be raw utf-8 bytes
         # from an IDN servname in an auth header
authenticated_entity = self.authenticated_entity
- if authenticated_entity is not None:
+ if authenticated_entity is not None and isinstance(authenticated_entity, bytes):
authenticated_entity = authenticated_entity.decode("utf-8", "replace")
# ...or could be raw utf-8 bytes in the User-Agent header.
@@ -116,22 +252,31 @@ class SynapseRequest(Request):
user_agent = self.get_user_agent()
if user_agent is not None:
user_agent = user_agent.decode("utf-8", "replace")
+ else:
+ user_agent = "-"
+
+ code = str(self.code)
+ if not self.finished:
+ # we didn't send the full response before we gave up (presumably because
+ # the connection dropped)
+ code += "!"
self.site.access_logger.info(
"%s - %s - {%s}"
- " Processed request: %.3fsec (%.3fsec, %.3fsec) (%.3fsec/%.3fsec/%d)"
+ " Processed request: %.3fsec/%.3fsec (%.3fsec, %.3fsec) (%.3fsec/%.3fsec/%d)"
" %sB %s \"%s %s %s\" \"%s\" [%d dbevts]",
self.getClientIP(),
self.site.site_tag,
authenticated_entity,
- end_time - self.start_time,
+ processing_time,
+ response_send_time,
usage.ru_utime,
usage.ru_stime,
usage.db_sched_duration_sec,
usage.db_txn_duration_sec,
int(usage.db_txn_count),
self.sentLength,
- self.code,
+ code,
self.method,
self.get_redacted_uri(),
self.clientproto,
@@ -140,38 +285,10 @@ class SynapseRequest(Request):
)
try:
- self.request_metrics.stop(end_time, self)
+ self.request_metrics.stop(self.finish_time, self)
except Exception as e:
logger.warn("Failed to stop metrics: %r", e)
- @contextlib.contextmanager
- def processing(self, servlet_name):
- """Record the fact that we are processing this request.
-
- Returns a context manager; the correct way to use this is:
-
- @defer.inlineCallbacks
- def handle_request(request):
- with request.processing("FooServlet"):
- yield really_handle_the_request()
-
- This will log the request's arrival. Once the context manager is
- closed, the completion of the request will be logged, and the various
- metrics will be updated.
-
- Args:
- servlet_name (str): the name of the servlet which will be
- processing this request. This is used in the metrics.
-
- It is possible to update this afterwards by updating
- self.request_metrics.servlet_name.
- """
- # TODO: we should probably just move this into render() and finish(),
- # to save having to call a separate method.
- self._started_processing(servlet_name)
- yield
- self._finished_processing()
-
class XForwardedForRequest(SynapseRequest):
def __init__(self, *args, **kw):
@@ -217,7 +334,7 @@ class SynapseSite(Site):
proxied = config.get("x_forwarded", False)
self.requestFactory = SynapseRequestFactory(self, proxied)
self.access_logger = logging.getLogger(logger_name)
- self.server_version_string = server_version_string
+ self.server_version_string = server_version_string.encode('ascii')
def log(self, request):
pass
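With these changes, an access-log line looks roughly as follows (values illustrative): the new duration pair is processing time / response-send time, and a trailing `!` on the status code marks a response that was never completely sent:

    192.168.1.10 - client_reader - {@alice:example.org} Processed request: 0.052sec/0.001sec (0.020sec, 0.004sec) (0.001sec/0.010sec/3) 512B 200 "GET /_matrix/client/r0/sync HTTP/1.1" "Mozilla/5.0" [0 dbevts]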
diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py
index a9158fc066..550f8443f7 100644
--- a/synapse/metrics/__init__.py
+++ b/synapse/metrics/__init__.py
@@ -174,6 +174,19 @@ sent_transactions_counter = Counter("synapse_federation_client_sent_transactions
events_processed_counter = Counter("synapse_federation_client_events_processed", "")
+event_processing_loop_counter = Counter(
+ "synapse_event_processing_loop_count",
+ "Event processing loop iterations",
+ ["name"],
+)
+
+event_processing_loop_room_count = Counter(
+ "synapse_event_processing_loop_room_count",
+ "Rooms seen per event processing loop iteration",
+ ["name"],
+)
+
+
# Used to track where various components have processed in the event stream,
# e.g. federation sending, appservice sending, etc.
event_processing_positions = Gauge("synapse_event_processing_positions", "", ["name"])
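A sketch of how a processing loop might drive the new counters; the "federation_sender" label is an illustrative loop name, not something this hunk registers:

    # one iteration of an event processing loop completed...
    event_processing_loop_counter.labels("federation_sender").inc()

    # ...in which we touched five rooms
    event_processing_loop_room_count.labels("federation_sender").inc(5)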
diff --git a/synapse/notifier.py b/synapse/notifier.py
index e650c3e494..82f391481c 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -25,7 +25,7 @@ from synapse.api.errors import AuthError
from synapse.handlers.presence import format_user_presence_state
from synapse.metrics import LaterGauge
from synapse.types import StreamToken
-from synapse.util.async import (
+from synapse.util.async_helpers import (
DeferredTimeoutError,
ObservableDeferred,
add_timeout_to_deferred,
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index 1d14d3639c..8f9a76147f 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -26,7 +26,7 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
from synapse.event_auth import get_user_power_level
from synapse.state import POWER_KEY
-from synapse.util.async import Linearizer
+from synapse.util.async_helpers import Linearizer
from synapse.util.caches import register_cache
from synapse.util.caches.descriptors import cached
diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index 9d601208fd..bfa6df7b68 100644
--- a/synapse/push/mailer.py
+++ b/synapse/push/mailer.py
@@ -35,7 +35,7 @@ from synapse.push.presentable_names import (
name_from_member_event,
)
from synapse.types import UserID
-from synapse.util.async import concurrently_execute
+from synapse.util.async_helpers import concurrently_execute
from synapse.visibility import filter_events_for_client
logger = logging.getLogger(__name__)
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index 36bb5bbc65..9f7d5ef217 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -18,6 +18,7 @@ import logging
from twisted.internet import defer
+from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.push.pusher import PusherFactory
from synapse.util.logcontext import make_deferred_yieldable, run_in_background
@@ -122,8 +123,14 @@ class PusherPool:
p['app_id'], p['pushkey'], p['user_name'],
)
- @defer.inlineCallbacks
def on_new_notifications(self, min_stream_id, max_stream_id):
+ run_as_background_process(
+ "on_new_notifications",
+ self._on_new_notifications, min_stream_id, max_stream_id,
+ )
+
+ @defer.inlineCallbacks
+ def _on_new_notifications(self, min_stream_id, max_stream_id):
try:
users_affected = yield self.store.get_push_action_users_in_range(
min_stream_id, max_stream_id
@@ -147,8 +154,14 @@ class PusherPool:
except Exception:
logger.exception("Exception in pusher on_new_notifications")
- @defer.inlineCallbacks
def on_new_receipts(self, min_stream_id, max_stream_id, affected_room_ids):
+ run_as_background_process(
+ "on_new_receipts",
+ self._on_new_receipts, min_stream_id, max_stream_id, affected_room_ids,
+ )
+
+ @defer.inlineCallbacks
+ def _on_new_receipts(self, min_stream_id, max_stream_id, affected_room_ids):
try:
# Need to subtract 1 from the minimum because the lower bound here
# is not inclusive
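The pusherpool changes follow a general pattern: the public entry point returns immediately, and the real work runs as a measured background process. In isolation (class and names illustrative):

    from twisted.internet import defer

    from synapse.metrics.background_process_metrics import run_as_background_process


    class ExampleNotifier(object):
        def on_new_thing(self, stream_id):
            # fire-and-forget: callers no longer wait on the deferred, and the
            # work is tracked under its own logcontext and metrics
            run_as_background_process(
                "on_new_thing", self._on_new_thing, stream_id,
            )

        @defer.inlineCallbacks
        def _on_new_thing(self, stream_id):
            # a real handler would hit the store / pushers here
            yield defer.succeed(stream_id)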
diff --git a/synapse/replication/http/__init__.py b/synapse/replication/http/__init__.py
index 589ee94c66..19f214281e 100644
--- a/synapse/replication/http/__init__.py
+++ b/synapse/replication/http/__init__.py
@@ -14,7 +14,7 @@
# limitations under the License.
from synapse.http.server import JsonResource
-from synapse.replication.http import membership, send_event
+from synapse.replication.http import federation, membership, send_event
REPLICATION_PREFIX = "/_synapse/replication"
@@ -27,3 +27,4 @@ class ReplicationRestResource(JsonResource):
def register_servlets(self, hs):
send_event.register_servlets(hs, self)
membership.register_servlets(hs, self)
+ federation.register_servlets(hs, self)
diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
new file mode 100644
index 0000000000..5e5376cf58
--- /dev/null
+++ b/synapse/replication/http/_base.py
@@ -0,0 +1,215 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import abc
+import logging
+import re
+
+from six.moves import urllib
+
+from twisted.internet import defer
+
+from synapse.api.errors import CodeMessageException, HttpResponseException
+from synapse.util.caches.response_cache import ResponseCache
+from synapse.util.stringutils import random_string
+
+logger = logging.getLogger(__name__)
+
+
+class ReplicationEndpoint(object):
+ """Helper base class for defining new replication HTTP endpoints.
+
+ This creates an endpoint under `/_synapse/replication/:NAME/:PATH_ARGS..`
+    (with a `/:txn_id` suffix for cached requests), where NAME is a name
+    and PATH_ARGS is a tuple of parameters to be encoded in the URL.
+
+ For example, if `NAME` is "send_event" and `PATH_ARGS` is `("event_id",)`,
+ with `CACHE` set to true then this generates an endpoint:
+
+ /_synapse/replication/send_event/:event_id/:txn_id
+
+ For POST/PUT requests the payload is serialized to json and sent as the
+ body, while for GET requests the payload is added as query parameters. See
+ `_serialize_payload` for details.
+
+ Incoming requests are handled by overriding `_handle_request`. Servers
+ must call `register` to register the path with the HTTP server.
+
+ Requests can be sent by calling the client returned by `make_client`.
+
+ Attributes:
+ NAME (str): A name for the endpoint, added to the path as well as used
+ in logging and metrics.
+ PATH_ARGS (tuple[str]): A list of parameters to be added to the path.
+ Adding parameters to the path (rather than payload) can make it
+ easier to follow along in the log files.
+ METHOD (str): The method of the HTTP request, defaults to POST. Can be
+ one of POST, PUT or GET. If GET then the payload is sent as query
+ parameters rather than a JSON body.
+        CACHE (bool): Whether the server should cache the result of the request.
+ If true then transparently adds a txn_id to all requests, and
+ `_handle_request` must return a Deferred.
+ RETRY_ON_TIMEOUT(bool): Whether or not to retry the request when a 504
+ is received.
+ """
+
+ __metaclass__ = abc.ABCMeta
+
+ NAME = abc.abstractproperty()
+ PATH_ARGS = abc.abstractproperty()
+
+ METHOD = "POST"
+ CACHE = True
+ RETRY_ON_TIMEOUT = True
+
+ def __init__(self, hs):
+ if self.CACHE:
+ self.response_cache = ResponseCache(
+ hs, "repl." + self.NAME,
+ timeout_ms=30 * 60 * 1000,
+ )
+
+ assert self.METHOD in ("PUT", "POST", "GET")
+
+ @abc.abstractmethod
+ def _serialize_payload(**kwargs):
+ """Static method that is called when creating a request.
+
+ Concrete implementations should have explicit parameters (rather than
+ kwargs) so that an appropriate exception is raised if the client is
+ called with unexpected parameters. All PATH_ARGS must appear in
+ argument list.
+
+ Returns:
+ Deferred[dict]|dict: If POST/PUT request then dictionary must be
+ JSON serialisable, otherwise must be appropriate for adding as
+ query args.
+ """
+ return {}
+
+ @abc.abstractmethod
+ def _handle_request(self, request, **kwargs):
+ """Handle incoming request.
+
+ This is called with the request object and PATH_ARGS.
+
+ Returns:
+ Deferred[dict]: A JSON serialisable dict to be used as response
+ body of request.
+ """
+ pass
+
+ @classmethod
+ def make_client(cls, hs):
+ """Create a client that makes requests.
+
+ Returns a callable that accepts the same parameters as `_serialize_payload`.
+ """
+ clock = hs.get_clock()
+ host = hs.config.worker_replication_host
+ port = hs.config.worker_replication_http_port
+
+ client = hs.get_simple_http_client()
+
+ @defer.inlineCallbacks
+ def send_request(**kwargs):
+ data = yield cls._serialize_payload(**kwargs)
+
+ url_args = [urllib.parse.quote(kwargs[name]) for name in cls.PATH_ARGS]
+
+ if cls.CACHE:
+ txn_id = random_string(10)
+ url_args.append(txn_id)
+
+ if cls.METHOD == "POST":
+ request_func = client.post_json_get_json
+ elif cls.METHOD == "PUT":
+ request_func = client.put_json
+ elif cls.METHOD == "GET":
+ request_func = client.get_json
+ else:
+ # We have already asserted in the constructor that a
+                # compatible method was picked, but let's be paranoid.
+ raise Exception(
+ "Unknown METHOD on %s replication endpoint" % (cls.NAME,)
+ )
+
+ uri = "http://%s:%s/_synapse/replication/%s/%s" % (
+ host, port, cls.NAME, "/".join(url_args)
+ )
+
+ try:
+ # We keep retrying the same request for timeouts. This is so that we
+ # have a good idea that the request has either succeeded or failed on
+ # the master, and so whether we should clean up or not.
+ while True:
+ try:
+ result = yield request_func(uri, data)
+ break
+ except CodeMessageException as e:
+ if e.code != 504 or not cls.RETRY_ON_TIMEOUT:
+ raise
+
+ logger.warn("%s request timed out", cls.NAME)
+
+ # If we timed out we probably don't need to worry about backing
+                    # off too much, but let's just wait a little anyway.
+ yield clock.sleep(1)
+ except HttpResponseException as e:
+ # We convert to SynapseError as we know that it was a SynapseError
+ # on the master process that we should send to the client. (And
+ # importantly, not stack traces everywhere)
+ raise e.to_synapse_error()
+
+ defer.returnValue(result)
+
+ return send_request
+
+ def register(self, http_server):
+ """Called by the server to register this as a handler to the
+ appropriate path.
+ """
+
+ url_args = list(self.PATH_ARGS)
+ handler = self._handle_request
+ method = self.METHOD
+
+ if self.CACHE:
+ handler = self._cached_handler
+ url_args.append("txn_id")
+
+ args = "/".join("(?P<%s>[^/]+)" % (arg,) for arg in url_args)
+ pattern = re.compile("^/_synapse/replication/%s/%s$" % (
+ self.NAME,
+ args
+ ))
+
+ http_server.register_paths(method, [pattern], handler)
+
+ def _cached_handler(self, request, txn_id, **kwargs):
+ """Called on new incoming requests when caching is enabled. Checks
+ if there is a cached response for the request and returns that,
+ otherwise calls `_handle_request` and caches its response.
+ """
+ # We just use the txn_id here, but we probably also want to use the
+ # other PATH_ARGS as well.
+
+ assert self.CACHE
+
+ return self.response_cache.wrap(
+ txn_id,
+ self._handle_request,
+ request, **kwargs
+ )
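To make the subclassing contract concrete, here is a hypothetical endpoint built on ReplicationEndpoint; the name, path and payload are illustrative and not part of this patch:

    from twisted.internet import defer

    from synapse.http.servlet import parse_json_object_from_request
    from synapse.replication.http._base import ReplicationEndpoint


    class ReplicationPingRestServlet(ReplicationEndpoint):
        """POST /_synapse/replication/ping/:txn_id  { "payload": ... }"""

        NAME = "ping"
        PATH_ARGS = ()  # nothing encoded in the URL; all data in the body
        CACHE = True    # retried requests are deduplicated on txn_id

        @staticmethod
        def _serialize_payload(payload):
            # explicit parameter (not **kwargs) so bad callers fail loudly
            return {"payload": payload}

        @defer.inlineCallbacks
        def _handle_request(self, request):
            content = parse_json_object_from_request(request)
            yield defer.succeed(None)  # a real endpoint would do work here
            defer.returnValue((200, {"echo": content["payload"]}))

A worker would then call `ReplicationPingRestServlet.make_client(hs)` to obtain a `send_request(payload=...)` callable, while the master registers the servlet's path via its `register` method.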
diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py
new file mode 100644
index 0000000000..64a79da162
--- /dev/null
+++ b/synapse/replication/http/federation.py
@@ -0,0 +1,259 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from twisted.internet import defer
+
+from synapse.events import FrozenEvent
+from synapse.events.snapshot import EventContext
+from synapse.http.servlet import parse_json_object_from_request
+from synapse.replication.http._base import ReplicationEndpoint
+from synapse.util.metrics import Measure
+
+logger = logging.getLogger(__name__)
+
+
+class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
+ """Handles events newly received from federation, including persisting and
+ notifying.
+
+ The API looks like:
+
+ POST /_synapse/replication/fed_send_events/:txn_id
+
+ {
+ "events": [{
+ "event": { .. serialized event .. },
+ "internal_metadata": { .. serialized internal_metadata .. },
+ "rejected_reason": .., // The event.rejected_reason field
+ "context": { .. serialized event context .. },
+ }],
+ "backfilled": false
+ """
+
+ NAME = "fed_send_events"
+ PATH_ARGS = ()
+
+ def __init__(self, hs):
+ super(ReplicationFederationSendEventsRestServlet, self).__init__(hs)
+
+ self.store = hs.get_datastore()
+ self.clock = hs.get_clock()
+ self.federation_handler = hs.get_handlers().federation_handler
+
+ @staticmethod
+ @defer.inlineCallbacks
+ def _serialize_payload(store, event_and_contexts, backfilled):
+ """
+ Args:
+ store
+ event_and_contexts (list[tuple[FrozenEvent, EventContext]])
+ backfilled (bool): Whether or not the events are the result of
+ backfilling
+ """
+ event_payloads = []
+ for event, context in event_and_contexts:
+ serialized_context = yield context.serialize(event, store)
+
+ event_payloads.append({
+ "event": event.get_pdu_json(),
+ "internal_metadata": event.internal_metadata.get_dict(),
+ "rejected_reason": event.rejected_reason,
+ "context": serialized_context,
+ })
+
+ payload = {
+ "events": event_payloads,
+ "backfilled": backfilled,
+ }
+
+ defer.returnValue(payload)
+
+ @defer.inlineCallbacks
+ def _handle_request(self, request):
+ with Measure(self.clock, "repl_fed_send_events_parse"):
+ content = parse_json_object_from_request(request)
+
+ backfilled = content["backfilled"]
+
+ event_payloads = content["events"]
+
+ event_and_contexts = []
+ for event_payload in event_payloads:
+ event_dict = event_payload["event"]
+ internal_metadata = event_payload["internal_metadata"]
+ rejected_reason = event_payload["rejected_reason"]
+ event = FrozenEvent(event_dict, internal_metadata, rejected_reason)
+
+ context = yield EventContext.deserialize(
+ self.store, event_payload["context"],
+ )
+
+ event_and_contexts.append((event, context))
+
+ logger.info(
+ "Got %d events from federation",
+ len(event_and_contexts),
+ )
+
+ yield self.federation_handler.persist_events_and_notify(
+ event_and_contexts, backfilled,
+ )
+
+ defer.returnValue((200, {}))
+
+
+class ReplicationFederationSendEduRestServlet(ReplicationEndpoint):
+ """Handles EDUs newly received from federation, including persisting and
+ notifying.
+
+ Request format:
+
+ POST /_synapse/replication/fed_send_edu/:edu_type/:txn_id
+
+ {
+ "origin": ...,
+ "content: { ... }
+ }
+ """
+
+ NAME = "fed_send_edu"
+ PATH_ARGS = ("edu_type",)
+
+ def __init__(self, hs):
+ super(ReplicationFederationSendEduRestServlet, self).__init__(hs)
+
+ self.store = hs.get_datastore()
+ self.clock = hs.get_clock()
+ self.registry = hs.get_federation_registry()
+
+ @staticmethod
+ def _serialize_payload(edu_type, origin, content):
+ return {
+ "origin": origin,
+ "content": content,
+ }
+
+ @defer.inlineCallbacks
+ def _handle_request(self, request, edu_type):
+ with Measure(self.clock, "repl_fed_send_edu_parse"):
+ content = parse_json_object_from_request(request)
+
+ origin = content["origin"]
+ edu_content = content["content"]
+
+ logger.info(
+ "Got %r edu from %s",
+ edu_type, origin,
+ )
+
+ result = yield self.registry.on_edu(edu_type, origin, edu_content)
+
+ defer.returnValue((200, result))
+
+
+class ReplicationGetQueryRestServlet(ReplicationEndpoint):
+ """Handle responding to queries from federation.
+
+ Request format:
+
+ POST /_synapse/replication/fed_query/:query_type
+
+ {
+ "args": { ... }
+ }
+ """
+
+ NAME = "fed_query"
+ PATH_ARGS = ("query_type",)
+
+ # This is a query, so let's not bother caching
+ CACHE = False
+
+ def __init__(self, hs):
+ super(ReplicationGetQueryRestServlet, self).__init__(hs)
+
+ self.store = hs.get_datastore()
+ self.clock = hs.get_clock()
+ self.registry = hs.get_federation_registry()
+
+ @staticmethod
+ def _serialize_payload(query_type, args):
+ """
+ Args:
+ query_type (str)
+ args (dict): The arguments received for the given query type
+ """
+ return {
+ "args": args,
+ }
+
+ @defer.inlineCallbacks
+ def _handle_request(self, request, query_type):
+ with Measure(self.clock, "repl_fed_query_parse"):
+ content = parse_json_object_from_request(request)
+
+ args = content["args"]
+
+ logger.info(
+ "Got %r query",
+ query_type,
+ )
+
+ result = yield self.registry.on_query(query_type, args)
+
+ defer.returnValue((200, result))
+
+
+class ReplicationCleanRoomRestServlet(ReplicationEndpoint):
+ """Called to clean up any data in DB for a given room, ready for the
+ server to join the room.
+
+ Request format:
+
+        POST /_synapse/replication/fed_cleanup_room/:room_id/:txn_id
+
+ {}
+ """
+
+ NAME = "fed_cleanup_room"
+ PATH_ARGS = ("room_id",)
+
+ def __init__(self, hs):
+ super(ReplicationCleanRoomRestServlet, self).__init__(hs)
+
+ self.store = hs.get_datastore()
+
+ @staticmethod
+ def _serialize_payload(room_id, args):
+ """
+ Args:
+ room_id (str)
+ """
+ return {}
+
+ @defer.inlineCallbacks
+ def _handle_request(self, request, room_id):
+ yield self.store.clean_room_for_join(room_id)
+
+ defer.returnValue((200, {}))
+
+
+def register_servlets(hs, http_server):
+ ReplicationFederationSendEventsRestServlet(hs).register(http_server)
+ ReplicationFederationSendEduRestServlet(hs).register(http_server)
+ ReplicationGetQueryRestServlet(hs).register(http_server)
+ ReplicationCleanRoomRestServlet(hs).register(http_server)
diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py
index 7a3cfb159c..e58bebf12a 100644
--- a/synapse/replication/http/membership.py
+++ b/synapse/replication/http/membership.py
@@ -14,182 +14,63 @@
# limitations under the License.
import logging
-import re
from twisted.internet import defer
-from synapse.api.errors import HttpResponseException
-from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.http.servlet import parse_json_object_from_request
+from synapse.replication.http._base import ReplicationEndpoint
from synapse.types import Requester, UserID
from synapse.util.distributor import user_joined_room, user_left_room
logger = logging.getLogger(__name__)
-@defer.inlineCallbacks
-def remote_join(client, host, port, requester, remote_room_hosts,
- room_id, user_id, content):
- """Ask the master to do a remote join for the given user to the given room
+class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
+ """Does a remote join for the given user to the given room
- Args:
- client (SimpleHttpClient)
- host (str): host of master
- port (int): port on master listening for HTTP replication
- requester (Requester)
- remote_room_hosts (list[str]): Servers to try and join via
- room_id (str)
- user_id (str)
- content (dict): The event content to use for the join event
+ Request format:
- Returns:
- Deferred
- """
- uri = "http://%s:%s/_synapse/replication/remote_join" % (host, port)
-
- payload = {
- "requester": requester.serialize(),
- "remote_room_hosts": remote_room_hosts,
- "room_id": room_id,
- "user_id": user_id,
- "content": content,
- }
-
- try:
- result = yield client.post_json_get_json(uri, payload)
- except HttpResponseException as e:
- # We convert to SynapseError as we know that it was a SynapseError
- # on the master process that we should send to the client. (And
- # importantly, not stack traces everywhere)
- raise e.to_synapse_error()
- defer.returnValue(result)
-
-
-@defer.inlineCallbacks
-def remote_reject_invite(client, host, port, requester, remote_room_hosts,
- room_id, user_id):
- """Ask master to reject the invite for the user and room.
-
- Args:
- client (SimpleHttpClient)
- host (str): host of master
- port (int): port on master listening for HTTP replication
- requester (Requester)
- remote_room_hosts (list[str]): Servers to try and reject via
- room_id (str)
- user_id (str)
-
- Returns:
- Deferred
- """
- uri = "http://%s:%s/_synapse/replication/remote_reject_invite" % (host, port)
-
- payload = {
- "requester": requester.serialize(),
- "remote_room_hosts": remote_room_hosts,
- "room_id": room_id,
- "user_id": user_id,
- }
-
- try:
- result = yield client.post_json_get_json(uri, payload)
- except HttpResponseException as e:
- # We convert to SynapseError as we know that it was a SynapseError
- # on the master process that we should send to the client. (And
- # importantly, not stack traces everywhere)
- raise e.to_synapse_error()
- defer.returnValue(result)
-
-
-@defer.inlineCallbacks
-def get_or_register_3pid_guest(client, host, port, requester,
- medium, address, inviter_user_id):
- """Ask the master to get/create a guest account for given 3PID.
-
- Args:
- client (SimpleHttpClient)
- host (str): host of master
- port (int): port on master listening for HTTP replication
- requester (Requester)
- medium (str)
- address (str)
- inviter_user_id (str): The user ID who is trying to invite the
- 3PID
-
- Returns:
- Deferred[(str, str)]: A 2-tuple of `(user_id, access_token)` of the
- 3PID guest account.
- """
+ POST /_synapse/replication/remote_join/:room_id/:user_id
- uri = "http://%s:%s/_synapse/replication/get_or_register_3pid_guest" % (host, port)
-
- payload = {
- "requester": requester.serialize(),
- "medium": medium,
- "address": address,
- "inviter_user_id": inviter_user_id,
- }
-
- try:
- result = yield client.post_json_get_json(uri, payload)
- except HttpResponseException as e:
- # We convert to SynapseError as we know that it was a SynapseError
- # on the master process that we should send to the client. (And
- # importantly, not stack traces everywhere)
- raise e.to_synapse_error()
- defer.returnValue(result)
-
-
-@defer.inlineCallbacks
-def notify_user_membership_change(client, host, port, user_id, room_id, change):
- """Notify master that a user has joined or left the room
-
- Args:
- client (SimpleHttpClient)
- host (str): host of master
- port (int): port on master listening for HTTP replication.
- user_id (str)
- room_id (str)
- change (str): Either "join" or "left"
-
- Returns:
- Deferred
+ {
+ "requester": ...,
+ "remote_room_hosts": [...],
+ "content": { ... }
+ }
"""
- assert change in ("joined", "left")
-
- uri = "http://%s:%s/_synapse/replication/user_%s_room" % (host, port, change)
-
- payload = {
- "user_id": user_id,
- "room_id": room_id,
- }
-
- try:
- result = yield client.post_json_get_json(uri, payload)
- except HttpResponseException as e:
- # We convert to SynapseError as we know that it was a SynapseError
- # on the master process that we should send to the client. (And
- # importantly, not stack traces everywhere)
- raise e.to_synapse_error()
- defer.returnValue(result)
-
-class ReplicationRemoteJoinRestServlet(RestServlet):
- PATTERNS = [re.compile("^/_synapse/replication/remote_join$")]
+ NAME = "remote_join"
+ PATH_ARGS = ("room_id", "user_id",)
def __init__(self, hs):
- super(ReplicationRemoteJoinRestServlet, self).__init__()
+ super(ReplicationRemoteJoinRestServlet, self).__init__(hs)
self.federation_handler = hs.get_handlers().federation_handler
self.store = hs.get_datastore()
self.clock = hs.get_clock()
+ @staticmethod
+ def _serialize_payload(requester, room_id, user_id, remote_room_hosts,
+ content):
+ """
+ Args:
+ requester(Requester)
+ room_id (str)
+ user_id (str)
+ remote_room_hosts (list[str]): Servers to try and join via
+ content(dict): The event content to use for the join event
+ """
+ return {
+ "requester": requester.serialize(),
+ "remote_room_hosts": remote_room_hosts,
+ "content": content,
+ }
+
@defer.inlineCallbacks
- def on_POST(self, request):
+ def _handle_request(self, request, room_id, user_id):
content = parse_json_object_from_request(request)
remote_room_hosts = content["remote_room_hosts"]
- room_id = content["room_id"]
- user_id = content["user_id"]
event_content = content["content"]
requester = Requester.deserialize(self.store, content["requester"])
@@ -212,23 +93,48 @@ class ReplicationRemoteJoinRestServlet(RestServlet):
defer.returnValue((200, {}))
-class ReplicationRemoteRejectInviteRestServlet(RestServlet):
- PATTERNS = [re.compile("^/_synapse/replication/remote_reject_invite$")]
+class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
+ """Rejects the invite for the user and room.
+
+ Request format:
+
+ POST /_synapse/replication/remote_reject_invite/:room_id/:user_id
+
+ {
+ "requester": ...,
+ "remote_room_hosts": [...],
+ }
+ """
+
+ NAME = "remote_reject_invite"
+ PATH_ARGS = ("room_id", "user_id",)
def __init__(self, hs):
- super(ReplicationRemoteRejectInviteRestServlet, self).__init__()
+ super(ReplicationRemoteRejectInviteRestServlet, self).__init__(hs)
self.federation_handler = hs.get_handlers().federation_handler
self.store = hs.get_datastore()
self.clock = hs.get_clock()
+ @staticmethod
+ def _serialize_payload(requester, room_id, user_id, remote_room_hosts):
+ """
+ Args:
+ requester(Requester)
+ room_id (str)
+ user_id (str)
+ remote_room_hosts (list[str]): Servers to try and reject via
+ """
+ return {
+ "requester": requester.serialize(),
+ "remote_room_hosts": remote_room_hosts,
+ }
+
@defer.inlineCallbacks
- def on_POST(self, request):
+ def _handle_request(self, request, room_id, user_id):
content = parse_json_object_from_request(request)
remote_room_hosts = content["remote_room_hosts"]
- room_id = content["room_id"]
- user_id = content["user_id"]
requester = Requester.deserialize(self.store, content["requester"])
@@ -264,18 +170,50 @@ class ReplicationRemoteRejectInviteRestServlet(RestServlet):
defer.returnValue((200, ret))
-class ReplicationRegister3PIDGuestRestServlet(RestServlet):
- PATTERNS = [re.compile("^/_synapse/replication/get_or_register_3pid_guest$")]
+class ReplicationRegister3PIDGuestRestServlet(ReplicationEndpoint):
+ """Gets/creates a guest account for given 3PID.
+
+ Request format:
+
+ POST /_synapse/replication/get_or_register_3pid_guest/
+
+ {
+ "requester": ...,
+ "medium": ...,
+ "address": ...,
+ "inviter_user_id": ...
+ }
+ """
+
+ NAME = "get_or_register_3pid_guest"
+ PATH_ARGS = ()
def __init__(self, hs):
- super(ReplicationRegister3PIDGuestRestServlet, self).__init__()
+ super(ReplicationRegister3PIDGuestRestServlet, self).__init__(hs)
self.registeration_handler = hs.get_handlers().registration_handler
self.store = hs.get_datastore()
self.clock = hs.get_clock()
+ @staticmethod
+ def _serialize_payload(requester, medium, address, inviter_user_id):
+ """
+ Args:
+ requester(Requester)
+ medium (str)
+ address (str)
+ inviter_user_id (str): The user ID who is trying to invite the
+ 3PID
+ """
+ return {
+ "requester": requester.serialize(),
+ "medium": medium,
+ "address": address,
+ "inviter_user_id": inviter_user_id,
+ }
+
@defer.inlineCallbacks
- def on_POST(self, request):
+ def _handle_request(self, request):
content = parse_json_object_from_request(request)
medium = content["medium"]
@@ -296,23 +234,41 @@ class ReplicationRegister3PIDGuestRestServlet(RestServlet):
defer.returnValue((200, ret))
-class ReplicationUserJoinedLeftRoomRestServlet(RestServlet):
- PATTERNS = [re.compile("^/_synapse/replication/user_(?P<change>joined|left)_room$")]
+class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint):
+ """Notifies that a user has joined or left the room
+
+ Request format:
+
+ POST /_synapse/replication/membership_change/:room_id/:user_id/:change
+
+ {}
+ """
+
+ NAME = "membership_change"
+ PATH_ARGS = ("room_id", "user_id", "change")
+    CACHE = False  # No point caching, as it should return instantly.
def __init__(self, hs):
- super(ReplicationUserJoinedLeftRoomRestServlet, self).__init__()
+ super(ReplicationUserJoinedLeftRoomRestServlet, self).__init__(hs)
self.registeration_handler = hs.get_handlers().registration_handler
self.store = hs.get_datastore()
self.clock = hs.get_clock()
self.distributor = hs.get_distributor()
- def on_POST(self, request, change):
- content = parse_json_object_from_request(request)
+ @staticmethod
+ def _serialize_payload(room_id, user_id, change):
+ """
+ Args:
+ room_id (str)
+ user_id (str)
+ change (str): Either "joined" or "left"
+ """
+ assert change in ("joined", "left",)
- user_id = content["user_id"]
- room_id = content["room_id"]
+ return {}
+ def _handle_request(self, request, room_id, user_id, change):
logger.info("user membership change: %s in %s", user_id, room_id)
user = UserID.from_string(user_id)
diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py
index d3509dc288..5b52c91650 100644
--- a/synapse/replication/http/send_event.py
+++ b/synapse/replication/http/send_event.py
@@ -14,86 +14,26 @@
# limitations under the License.
import logging
-import re
from twisted.internet import defer
-from synapse.api.errors import CodeMessageException, HttpResponseException
from synapse.events import FrozenEvent
from synapse.events.snapshot import EventContext
-from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.http.servlet import parse_json_object_from_request
+from synapse.replication.http._base import ReplicationEndpoint
from synapse.types import Requester, UserID
-from synapse.util.caches.response_cache import ResponseCache
from synapse.util.metrics import Measure
logger = logging.getLogger(__name__)
-@defer.inlineCallbacks
-def send_event_to_master(clock, store, client, host, port, requester, event, context,
- ratelimit, extra_users):
- """Send event to be handled on the master
-
- Args:
- clock (synapse.util.Clock)
- store (DataStore)
- client (SimpleHttpClient)
- host (str): host of master
- port (int): port on master listening for HTTP replication
- requester (Requester)
- event (FrozenEvent)
- context (EventContext)
- ratelimit (bool)
- extra_users (list(UserID)): Any extra users to notify about event
- """
- uri = "http://%s:%s/_synapse/replication/send_event/%s" % (
- host, port, event.event_id,
- )
-
- serialized_context = yield context.serialize(event, store)
-
- payload = {
- "event": event.get_pdu_json(),
- "internal_metadata": event.internal_metadata.get_dict(),
- "rejected_reason": event.rejected_reason,
- "context": serialized_context,
- "requester": requester.serialize(),
- "ratelimit": ratelimit,
- "extra_users": [u.to_string() for u in extra_users],
- }
-
- try:
- # We keep retrying the same request for timeouts. This is so that we
- # have a good idea that the request has either succeeded or failed on
- # the master, and so whether we should clean up or not.
- while True:
- try:
- result = yield client.put_json(uri, payload)
- break
- except CodeMessageException as e:
- if e.code != 504:
- raise
-
- logger.warn("send_event request timed out")
-
- # If we timed out we probably don't need to worry about backing
- # off too much, but lets just wait a little anyway.
- yield clock.sleep(1)
- except HttpResponseException as e:
- # We convert to SynapseError as we know that it was a SynapseError
- # on the master process that we should send to the client. (And
- # importantly, not stack traces everywhere)
- raise e.to_synapse_error()
- defer.returnValue(result)
-
-
-class ReplicationSendEventRestServlet(RestServlet):
+class ReplicationSendEventRestServlet(ReplicationEndpoint):
"""Handles events newly created on workers, including persisting and
notifying.
The API looks like:
- POST /_synapse/replication/send_event/:event_id
+ POST /_synapse/replication/send_event/:event_id/:txn_id
{
"event": { .. serialized event .. },
@@ -105,27 +45,47 @@ class ReplicationSendEventRestServlet(RestServlet):
"extra_users": [],
}
"""
- PATTERNS = [re.compile("^/_synapse/replication/send_event/(?P<event_id>[^/]+)$")]
+ NAME = "send_event"
+ PATH_ARGS = ("event_id",)
def __init__(self, hs):
- super(ReplicationSendEventRestServlet, self).__init__()
+ super(ReplicationSendEventRestServlet, self).__init__(hs)
self.event_creation_handler = hs.get_event_creation_handler()
self.store = hs.get_datastore()
self.clock = hs.get_clock()
- # The responses are tiny, so we may as well cache them for a while
- self.response_cache = ResponseCache(hs, "send_event", timeout_ms=30 * 60 * 1000)
+ @staticmethod
+ @defer.inlineCallbacks
+ def _serialize_payload(event_id, store, event, context, requester,
+ ratelimit, extra_users):
+ """
+ Args:
+ event_id (str)
+ store (DataStore)
+ requester (Requester)
+ event (FrozenEvent)
+ context (EventContext)
+ ratelimit (bool)
+ extra_users (list(UserID)): Any extra users to notify about event
+ """
+
+ serialized_context = yield context.serialize(event, store)
+
+ payload = {
+ "event": event.get_pdu_json(),
+ "internal_metadata": event.internal_metadata.get_dict(),
+ "rejected_reason": event.rejected_reason,
+ "context": serialized_context,
+ "requester": requester.serialize(),
+ "ratelimit": ratelimit,
+ "extra_users": [u.to_string() for u in extra_users],
+ }
- def on_PUT(self, request, event_id):
- return self.response_cache.wrap(
- event_id,
- self._handle_request,
- request
- )
+ defer.returnValue(payload)
@defer.inlineCallbacks
- def _handle_request(self, request):
+ def _handle_request(self, request, event_id):
with Measure(self.clock, "repl_send_event_parse"):
content = parse_json_object_from_request(request)
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index bdb5eee4af..4830c68f35 100644
--- a/synapse/replication/slave/storage/events.py
+++ b/synapse/replication/slave/storage/events.py
@@ -44,8 +44,8 @@ class SlavedEventStore(EventFederationWorkerStore,
RoomMemberWorkerStore,
EventPushActionsWorkerStore,
StreamWorkerStore,
- EventsWorkerStore,
StateGroupWorkerStore,
+ EventsWorkerStore,
SignatureWorkerStore,
UserErasureWorkerStore,
BaseSlavedStore):
diff --git a/synapse/replication/slave/storage/transactions.py b/synapse/replication/slave/storage/transactions.py
index 9c9a5eadd9..3527beb3c9 100644
--- a/synapse/replication/slave/storage/transactions.py
+++ b/synapse/replication/slave/storage/transactions.py
@@ -13,19 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.storage import DataStore
from synapse.storage.transactions import TransactionStore
from ._base import BaseSlavedStore
-class TransactionStore(BaseSlavedStore):
- get_destination_retry_timings = TransactionStore.__dict__[
- "get_destination_retry_timings"
- ]
- _get_destination_retry_timings = DataStore._get_destination_retry_timings.__func__
- set_destination_retry_timings = DataStore.set_destination_retry_timings.__func__
- _set_destination_retry_timings = DataStore._set_destination_retry_timings.__func__
-
- prep_send_transaction = DataStore.prep_send_transaction.__func__
- delivered_txn = DataStore.delivered_txn.__func__
+class SlavedTransactionStore(TransactionStore, BaseSlavedStore):
+ pass
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 970e94313e..cbe9645817 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -107,7 +107,7 @@ class ReplicationClientHandler(object):
Can be overriden in subclasses to handle more.
"""
logger.info("Received rdata %s -> %s", stream_name, token)
- self.store.process_replication_rows(stream_name, token, rows)
+ return self.store.process_replication_rows(stream_name, token, rows)
def on_position(self, stream_name, token):
"""Called when we get new position data. By default this just pokes
@@ -115,7 +115,7 @@ class ReplicationClientHandler(object):
Can be overriden in subclasses to handle more.
"""
- self.store.process_replication_rows(stream_name, token, [])
+ return self.store.process_replication_rows(stream_name, token, [])
def on_sync(self, data):
"""When we received a SYNC we wake up any deferreds that were waiting
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index f3908df642..327556f6a1 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -59,6 +59,12 @@ class Command(object):
"""
return self.data
+ def get_logcontext_id(self):
+ """Get a suitable string for the logcontext when processing this command"""
+
+ # by default, we just use the command name.
+ return self.NAME
+
class ServerCommand(Command):
"""Sent by the server on new connection and includes the server_name.
@@ -116,6 +122,9 @@ class RdataCommand(Command):
_json_encoder.encode(self.row),
))
+ def get_logcontext_id(self):
+ return "RDATA-" + self.stream_name
+
class PositionCommand(Command):
"""Sent by the client to tell the client the stream postition without
@@ -190,6 +199,9 @@ class ReplicateCommand(Command):
def to_line(self):
return " ".join((self.stream_name, str(self.token),))
+ def get_logcontext_id(self):
+ return "REPLICATE-" + self.stream_name
+
class UserSyncCommand(Command):
"""Sent by the client to inform the server that a user has started or
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index dec5ac0913..74e892c104 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -63,6 +63,8 @@ from twisted.protocols.basic import LineOnlyReceiver
from twisted.python.failure import Failure
from synapse.metrics import LaterGauge
+from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.util.logcontext import make_deferred_yieldable, run_in_background
from synapse.util.stringutils import random_string
from .commands import (
@@ -222,7 +224,11 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
# Now lets try and call on_<CMD_NAME> function
try:
- getattr(self, "on_%s" % (cmd_name,))(cmd)
+ run_as_background_process(
+ "replication-" + cmd.get_logcontext_id(),
+ getattr(self, "on_%s" % (cmd_name,)),
+ cmd,
+ )
except Exception:
logger.exception("[%s] Failed to handle line: %r", self.id(), line)
@@ -387,7 +393,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
self.name = cmd.data
def on_USER_SYNC(self, cmd):
- self.streamer.on_user_sync(
+ return self.streamer.on_user_sync(
self.conn_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms,
)
@@ -397,22 +403,33 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
if stream_name == "ALL":
# Subscribe to all streams we're publishing to.
- for stream in iterkeys(self.streamer.streams_by_name):
- self.subscribe_to_stream(stream, token)
+ deferreds = [
+ run_in_background(
+ self.subscribe_to_stream,
+ stream, token,
+ )
+ for stream in iterkeys(self.streamer.streams_by_name)
+ ]
+
+ return make_deferred_yieldable(
+ defer.gatherResults(deferreds, consumeErrors=True)
+ )
else:
- self.subscribe_to_stream(stream_name, token)
+ return self.subscribe_to_stream(stream_name, token)
def on_FEDERATION_ACK(self, cmd):
- self.streamer.federation_ack(cmd.token)
+ return self.streamer.federation_ack(cmd.token)
def on_REMOVE_PUSHER(self, cmd):
- self.streamer.on_remove_pusher(cmd.app_id, cmd.push_key, cmd.user_id)
+ return self.streamer.on_remove_pusher(
+ cmd.app_id, cmd.push_key, cmd.user_id,
+ )
def on_INVALIDATE_CACHE(self, cmd):
- self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys)
+ return self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys)
def on_USER_IP(self, cmd):
- self.streamer.on_user_ip(
+ return self.streamer.on_user_ip(
cmd.user_id, cmd.access_token, cmd.ip, cmd.user_agent, cmd.device_id,
cmd.last_seen,
)
@@ -542,14 +559,13 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
# Check if this is the last of a batch of updates
rows = self.pending_batches.pop(stream_name, [])
rows.append(row)
-
- self.handler.on_rdata(stream_name, cmd.token, rows)
+ return self.handler.on_rdata(stream_name, cmd.token, rows)
def on_POSITION(self, cmd):
- self.handler.on_position(cmd.stream_name, cmd.token)
+ return self.handler.on_position(cmd.stream_name, cmd.token)
def on_SYNC(self, cmd):
- self.handler.on_sync(cmd.data)
+ return self.handler.on_sync(cmd.data)
def replicate(self, stream_name, token):
"""Send the subscription request to the server
diff --git a/synapse/rest/client/transactions.py b/synapse/rest/client/transactions.py
index 00b1b3066e..48c17f1b6d 100644
--- a/synapse/rest/client/transactions.py
+++ b/synapse/rest/client/transactions.py
@@ -17,7 +17,7 @@
to ensure idempotency when performing PUTs using the REST API."""
import logging
-from synapse.util.async import ObservableDeferred
+from synapse.util.async_helpers import ObservableDeferred
from synapse.util.logcontext import make_deferred_yieldable, run_in_background
logger = logging.getLogger(__name__)
@@ -53,7 +53,7 @@ class HttpTransactionCache(object):
str: A transaction key
"""
token = self.auth.get_access_token_from_request(request)
- return request.path + "/" + token
+ return request.path.decode('utf8') + "/" + token
def fetch_or_execute_request(self, request, fn, *args, **kwargs):
"""A helper function for fetch_or_execute which extracts
diff --git a/synapse/rest/client/v1/admin.py b/synapse/rest/client/v1/admin.py
index 80d625eecc..ad536ab570 100644
--- a/synapse/rest/client/v1/admin.py
+++ b/synapse/rest/client/v1/admin.py
@@ -391,10 +391,17 @@ class DeactivateAccountRestServlet(ClientV1RestServlet):
if not is_admin:
raise AuthError(403, "You are not a server admin")
- yield self._deactivate_account_handler.deactivate_account(
+ result = yield self._deactivate_account_handler.deactivate_account(
target_user_id, erase,
)
- defer.returnValue((200, {}))
+ if result:
+ id_server_unbind_result = "success"
+ else:
+ id_server_unbind_result = "no-support"
+
+ defer.returnValue((200, {
+ "id_server_unbind_result": id_server_unbind_result,
+ }))
class ShutdownRoomRestServlet(ClientV1RestServlet):
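A deactivation response therefore now reports whether the identity server honoured the unbind, e.g. (illustrative body):

    {
        "id_server_unbind_result": "success"
    }

with "no-support" returned when the identity server does not support unbinding.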
diff --git a/synapse/rest/client/v1/presence.py b/synapse/rest/client/v1/presence.py
index a14f0c807e..b5a6d6aebf 100644
--- a/synapse/rest/client/v1/presence.py
+++ b/synapse/rest/client/v1/presence.py
@@ -84,7 +84,8 @@ class PresenceStatusRestServlet(ClientV1RestServlet):
except Exception:
raise SynapseError(400, "Unable to parse state")
- yield self.presence_handler.set_state(user, state)
+ if self.hs.config.use_presence:
+ yield self.presence_handler.set_state(user, state)
defer.returnValue((200, {}))
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index fa5989e74e..fcc1091760 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -34,7 +34,7 @@ from synapse.http.servlet import (
parse_string,
)
from synapse.streams.config import PaginationConfig
-from synapse.types import RoomAlias, RoomID, ThirdPartyInstanceID, UserID
+from synapse.types import RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID, UserID
from .base import ClientV1RestServlet, client_path_patterns
@@ -384,15 +384,39 @@ class RoomMemberListRestServlet(ClientV1RestServlet):
def on_GET(self, request, room_id):
# TODO support Pagination stream API (limit/tokens)
requester = yield self.auth.get_user_by_req(request)
- events = yield self.message_handler.get_state_events(
+ handler = self.message_handler
+
+ # request the state as of a given event, as identified by a stream token,
+ # for consistency with /messages etc.
+ # useful for getting the membership in retrospect as of a given /sync
+ # response.
+ at_token_string = parse_string(request, "at")
+ if at_token_string is None:
+ at_token = None
+ else:
+ at_token = StreamToken.from_string(at_token_string)
+
+ # let you filter down on particular memberships.
+ # XXX: this may not be the best shape for this API - we could pass in a filter
+ # instead, except filters aren't currently aware of memberships.
+ # See https://github.com/matrix-org/matrix-doc/issues/1337 for more details.
+ membership = parse_string(request, "membership")
+ not_membership = parse_string(request, "not_membership")
+
+ events = yield handler.get_state_events(
room_id=room_id,
user_id=requester.user.to_string(),
+ at_token=at_token,
+ types=[(EventTypes.Member, None)],
)
chunk = []
for event in events:
- if event["type"] != EventTypes.Member:
+ if (
+ (membership and event['content'].get("membership") != membership) or
+ (not_membership and event['content'].get("membership") == not_membership)
+ ):
continue
chunk.append(event)
@@ -401,6 +425,8 @@ class RoomMemberListRestServlet(ClientV1RestServlet):
}))
+# deprecated in favour of /members?membership=join?
+# except it does custom AS logic and has a simpler return format
class JoinedRoomMemberListRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/joined_members$")
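Under the new filtering, a client can, for instance, ask for just the joined members as of a given sync token (token value illustrative):

    GET /_matrix/client/r0/rooms/!room:example.org/members?membership=join&at=s72594_4483_1934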
diff --git a/synapse/rest/client/v1_only/register.py b/synapse/rest/client/v1_only/register.py
index 3439c3c6d4..5e99cffbcb 100644
--- a/synapse/rest/client/v1_only/register.py
+++ b/synapse/rest/client/v1_only/register.py
@@ -129,12 +129,9 @@ class RegisterRestServlet(ClientV1RestServlet):
login_type = register_json["type"]
is_application_server = login_type == LoginType.APPLICATION_SERVICE
- is_using_shared_secret = login_type == LoginType.SHARED_SECRET
-
can_register = (
self.enable_registration
or is_application_server
- or is_using_shared_secret
)
if not can_register:
raise SynapseError(403, "Registration has been disabled")
@@ -144,7 +141,6 @@ class RegisterRestServlet(ClientV1RestServlet):
LoginType.PASSWORD: self._do_password,
LoginType.EMAIL_IDENTITY: self._do_email_identity,
LoginType.APPLICATION_SERVICE: self._do_app_service,
- LoginType.SHARED_SECRET: self._do_shared_secret,
}
session_info = self._get_session_info(request, session)
@@ -325,56 +321,6 @@ class RegisterRestServlet(ClientV1RestServlet):
"home_server": self.hs.hostname,
})
- @defer.inlineCallbacks
- def _do_shared_secret(self, request, register_json, session):
- assert_params_in_dict(register_json, ["mac", "user", "password"])
-
- if not self.hs.config.registration_shared_secret:
- raise SynapseError(400, "Shared secret registration is not enabled")
-
- user = register_json["user"].encode("utf-8")
- password = register_json["password"].encode("utf-8")
- admin = register_json.get("admin", None)
-
- # Its important to check as we use null bytes as HMAC field separators
- if b"\x00" in user:
- raise SynapseError(400, "Invalid user")
- if b"\x00" in password:
- raise SynapseError(400, "Invalid password")
-
- # str() because otherwise hmac complains that 'unicode' does not
- # have the buffer interface
- got_mac = str(register_json["mac"])
-
- want_mac = hmac.new(
- key=self.hs.config.registration_shared_secret.encode(),
- digestmod=sha1,
- )
- want_mac.update(user)
- want_mac.update(b"\x00")
- want_mac.update(password)
- want_mac.update(b"\x00")
- want_mac.update(b"admin" if admin else b"notadmin")
- want_mac = want_mac.hexdigest()
-
- if compare_digest(want_mac, got_mac):
- handler = self.handlers.registration_handler
- user_id, token = yield handler.register(
- localpart=user.lower(),
- password=password,
- admin=bool(admin),
- )
- self._remove_session(session)
- defer.returnValue({
- "user_id": user_id,
- "access_token": token,
- "home_server": self.hs.hostname,
- })
- else:
- raise SynapseError(
- 403, "HMAC incorrect",
- )
-
class CreateUserRestServlet(ClientV1RestServlet):
"""Handles user creation via a server-to-server interface
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index eeae466d82..372648cafd 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -209,10 +209,17 @@ class DeactivateAccountRestServlet(RestServlet):
yield self.auth_handler.validate_user_via_ui_auth(
requester, body, self.hs.get_ip_from_request(request),
)
- yield self._deactivate_account_handler.deactivate_account(
+ result = yield self._deactivate_account_handler.deactivate_account(
requester.user.to_string(), erase,
)
- defer.returnValue((200, {}))
+ if result:
+ id_server_unbind_result = "success"
+ else:
+ id_server_unbind_result = "no-support"
+
+ defer.returnValue((200, {
+ "id_server_unbind_result": id_server_unbind_result,
+ }))
class EmailThreepidRequestTokenRestServlet(RestServlet):
@@ -364,7 +371,7 @@ class ThreepidDeleteRestServlet(RestServlet):
user_id = requester.user.to_string()
try:
- yield self.auth_handler.delete_threepid(
+ ret = yield self.auth_handler.delete_threepid(
user_id, body['medium'], body['address']
)
except Exception:
@@ -374,7 +381,14 @@ class ThreepidDeleteRestServlet(RestServlet):
logger.exception("Failed to remove threepid")
raise SynapseError(500, "Failed to remove threepid")
- defer.returnValue((200, {}))
+ if ret:
+ id_server_unbind_result = "success"
+ else:
+ id_server_unbind_result = "no-support"
+
+ defer.returnValue((200, {
+ "id_server_unbind_result": id_server_unbind_result,
+ }))
class WhoamiRestServlet(RestServlet):
diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py
index 8aa06faf23..1275baa1ba 100644
--- a/synapse/rest/client/v2_alpha/sync.py
+++ b/synapse/rest/client/v2_alpha/sync.py
@@ -370,6 +370,7 @@ class SyncRestServlet(RestServlet):
ephemeral_events = room.ephemeral
result["ephemeral"] = {"events": ephemeral_events}
result["unread_notifications"] = room.unread_notifications
+ result["summary"] = room.summary
return result
diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py
index 6ac2987b98..29e62bfcdd 100644
--- a/synapse/rest/client/versions.py
+++ b/synapse/rest/client/versions.py
@@ -27,11 +27,22 @@ class VersionsRestServlet(RestServlet):
def on_GET(self, request):
return (200, {
"versions": [
+ # XXX: at some point we need to decide whether we need to include
+ # the previous version numbers, given we've defined r0.3.0 to be
+ # backwards compatible with r0.2.0. But need to check how
+ # conscientious we've been in compatibility, and decide whether the
+ # middle number is the major revision when at 0.X.Y (as opposed to
+ # X.Y.Z). And we need to decide whether it's fair to make clients
+ # parse the version string to figure out what's going on.
"r0.0.1",
"r0.1.0",
"r0.2.0",
"r0.3.0",
- ]
+ ],
+ # as per MSC1497:
+ "unstable_features": {
+ "m.lazy_load_members": True,
+ }
})
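A client can now feature-detect lazy-loaded membership from the new advertisement rather than parsing version strings. A minimal sketch, assuming `versions_response` is the parsed JSON from `GET /_matrix/client/versions`:

```python
def supports_lazy_loading(versions_response):
    # MSC1497 advertises not-yet-stable features under "unstable_features"
    return versions_response.get("unstable_features", {}).get(
        "m.lazy_load_members", False
    )
```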
diff --git a/synapse/rest/consent/consent_resource.py b/synapse/rest/consent/consent_resource.py
index 147ff7d79b..7362e1858d 100644
--- a/synapse/rest/consent/consent_resource.py
+++ b/synapse/rest/consent/consent_resource.py
@@ -140,7 +140,7 @@ class ConsentResource(Resource):
version = parse_string(request, "v",
default=self._default_consent_version)
username = parse_string(request, "u", required=True)
- userhmac = parse_string(request, "h", required=True)
+ userhmac = parse_string(request, "h", required=True, encoding=None)
self._check_hash(username, userhmac)
@@ -175,7 +175,7 @@ class ConsentResource(Resource):
"""
version = parse_string(request, "v", required=True)
username = parse_string(request, "u", required=True)
- userhmac = parse_string(request, "h", required=True)
+ userhmac = parse_string(request, "h", required=True, encoding=None)
self._check_hash(username, userhmac)
@@ -210,9 +210,18 @@ class ConsentResource(Resource):
finish_request(request)
def _check_hash(self, userid, userhmac):
+ """
+ Args:
+ userid (unicode):
+ userhmac (bytes):
+
+ Raises:
+ SynapseError if the hash doesn't match
+
+ """
want_mac = hmac.new(
key=self._hmac_secret,
- msg=userid,
+ msg=userid.encode('utf-8'),
digestmod=sha256,
).hexdigest()
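With these changes the MAC is computed over the UTF-8 encoded username, while the supplied `h` parameter is kept as raw bytes (`encoding=None`). A sketch of the matching computation on the generating side, mirroring `_check_hash` (helper name is ours):

```python
import hmac
from hashlib import sha256

def consent_userhmac(hmac_secret, username):
    # `hmac_secret` is bytes, `username` is unicode; the result is the
    # hex digest sent as the "h" query parameter.
    return hmac.new(
        key=hmac_secret,
        msg=username.encode('utf-8'),
        digestmod=sha256,
    ).hexdigest()
```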
diff --git a/synapse/rest/media/v1/config_resource.py b/synapse/rest/media/v1/config_resource.py
new file mode 100644
index 0000000000..d6605b6027
--- /dev/null
+++ b/synapse/rest/media/v1/config_resource.py
@@ -0,0 +1,48 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 Will Hunt <will@half-shot.uk>
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from twisted.internet import defer
+from twisted.web.resource import Resource
+from twisted.web.server import NOT_DONE_YET
+
+from synapse.http.server import respond_with_json, wrap_json_request_handler
+
+
+class MediaConfigResource(Resource):
+ isLeaf = True
+
+ def __init__(self, hs):
+ Resource.__init__(self)
+ config = hs.get_config()
+ self.clock = hs.get_clock()
+ self.auth = hs.get_auth()
+ self.limits_dict = {
+ "m.upload.size": config.max_upload_size,
+ }
+
+ def render_GET(self, request):
+ self._async_render_GET(request)
+ return NOT_DONE_YET
+
+ @wrap_json_request_handler
+ @defer.inlineCallbacks
+ def _async_render_GET(self, request):
+ yield self.auth.get_user_by_req(request)
+ respond_with_json(request, 200, self.limits_dict)
+
+ def render_OPTIONS(self, request):
+ respond_with_json(request, 200, {}, send_cors=True)
+ return NOT_DONE_YET
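The new resource simply returns the limits dict to any authenticated user; it is mounted as the media repository's `config` child (see below), so the full path is presumably `/_matrix/media/r0/config`. An illustrative consumer:

```python
def max_upload_size(config_response):
    # `config_response` is the parsed JSON from the new media config endpoint.
    # Returns the server's upload cap in bytes, or None if the server did not
    # advertise one.
    return config_response.get("m.upload.size")
```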
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index 8fb413d825..241c972070 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -36,12 +36,13 @@ from synapse.api.errors import (
)
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.util.async import Linearizer
+from synapse.util.async_helpers import Linearizer
from synapse.util.logcontext import make_deferred_yieldable
from synapse.util.retryutils import NotRetryingDestination
from synapse.util.stringutils import is_ascii, random_string
from ._base import FileInfo, respond_404, respond_with_responder
+from .config_resource import MediaConfigResource
from .download_resource import DownloadResource
from .filepath import MediaFilePaths
from .identicon_resource import IdenticonResource
@@ -754,7 +755,6 @@ class MediaRepositoryResource(Resource):
Resource.__init__(self)
media_repo = hs.get_media_repository()
-
self.putChild("upload", UploadResource(hs, media_repo))
self.putChild("download", DownloadResource(hs, media_repo))
self.putChild("thumbnail", ThumbnailResource(
@@ -765,3 +765,4 @@ class MediaRepositoryResource(Resource):
self.putChild("preview_url", PreviewUrlResource(
hs, media_repo, media_repo.media_storage,
))
+ self.putChild("config", MediaConfigResource(hs))
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index 27aa0def2f..778ef97337 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -42,7 +42,7 @@ from synapse.http.server import (
)
from synapse.http.servlet import parse_integer, parse_string
from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.util.async import ObservableDeferred
+from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.logcontext import make_deferred_yieldable, run_in_background
from synapse.util.stringutils import is_ascii, random_string
diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py
index 9b22d204a6..c1240e1963 100644
--- a/synapse/rest/media/v1/upload_resource.py
+++ b/synapse/rest/media/v1/upload_resource.py
@@ -55,7 +55,7 @@ class UploadResource(Resource):
requester = yield self.auth.get_user_by_req(request)
# TODO: The checks here are a bit late. The content will have
# already been uploaded to a tmp file at this point
- content_length = request.getHeader("Content-Length")
+ content_length = request.getHeader(b"Content-Length")
+ if content_length is not None:
+ content_length = content_length.decode('ascii')
if content_length is None:
raise SynapseError(
msg="Request must specify a Content-Length", code=400
@@ -66,10 +66,10 @@ class UploadResource(Resource):
code=413,
)
- upload_name = parse_string(request, "filename")
+ upload_name = parse_string(request, b"filename", encoding=None)
if upload_name:
try:
- upload_name = upload_name.decode('UTF-8')
+ upload_name = upload_name.decode('utf8')
except UnicodeDecodeError:
raise SynapseError(
msg="Invalid UTF-8 filename parameter: %r" % (upload_name),
@@ -78,8 +78,8 @@ class UploadResource(Resource):
headers = request.requestHeaders
- if headers.hasHeader("Content-Type"):
- media_type = headers.getRawHeaders(b"Content-Type")[0]
+ if headers.hasHeader(b"Content-Type"):
+ media_type = headers.getRawHeaders(b"Content-Type")[0].decode('ascii')
else:
raise SynapseError(
msg="Upload request missing 'Content-Type'",
diff --git a/synapse/secrets.py b/synapse/secrets.py
index f05e9ea535..f6280f951c 100644
--- a/synapse/secrets.py
+++ b/synapse/secrets.py
@@ -38,4 +38,4 @@ else:
return os.urandom(nbytes)
def token_hex(self, nbytes=32):
- return binascii.hexlify(self.token_bytes(nbytes))
+ return binascii.hexlify(self.token_bytes(nbytes)).decode('ascii')
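This makes the Python 2 fallback return text, matching Python 3's `secrets.token_hex()`. A self-contained restatement of the fallback:

```python
import binascii
import os

def token_hex(nbytes=32):
    # hexlify returns bytes, so decode to ascii text to match Python 3
    return binascii.hexlify(os.urandom(nbytes)).decode('ascii')

assert len(token_hex(16)) == 32  # two hex digits per byte
```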
diff --git a/synapse/server.py b/synapse/server.py
index 140be9ebe8..26228d8c72 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -36,6 +36,7 @@ from synapse.federation.federation_client import FederationClient
from synapse.federation.federation_server import (
FederationHandlerRegistry,
FederationServer,
+ ReplicationFederationHandlerRegistry,
)
from synapse.federation.send_queue import FederationRemoteSendQueue
from synapse.federation.transaction_queue import TransactionQueue
@@ -423,7 +424,10 @@ class HomeServer(object):
return RoomMemberMasterHandler(self)
def build_federation_registry(self):
- return FederationHandlerRegistry()
+ if self.config.worker_app:
+ return ReplicationFederationHandlerRegistry(self)
+ else:
+ return FederationHandlerRegistry()
def build_server_notices_manager(self):
if self.config.worker_app:
diff --git a/synapse/state.py b/synapse/state.py
index e1092b97a9..0f2bedb694 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -28,8 +28,8 @@ from synapse import event_auth
from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError
from synapse.events.snapshot import EventContext
-from synapse.util.async import Linearizer
-from synapse.util.caches import CACHE_SIZE_FACTOR
+from synapse.util.async_helpers import Linearizer
+from synapse.util.caches import get_cache_factor_for
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.logutils import log_function
from synapse.util.metrics import Measure
@@ -40,7 +40,7 @@ logger = logging.getLogger(__name__)
KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key"))
-SIZE_OF_CACHE = int(100000 * CACHE_SIZE_FACTOR)
+SIZE_OF_CACHE = 100000 * get_cache_factor_for("state_cache")
EVICTION_TIMEOUT_SECONDS = 60 * 60
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 134e4a80f1..23b4a8d76d 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -39,6 +39,7 @@ from .filtering import FilteringStore
from .group_server import GroupServerStore
from .keys import KeyStore
from .media_repository import MediaRepositoryStore
+from .monthly_active_users import MonthlyActiveUsersStore
from .openid import OpenIdStore
from .presence import PresenceStore, UserPresenceState
from .profile import ProfileStore
@@ -87,6 +88,7 @@ class DataStore(RoomMemberStore, RoomStore,
UserDirectoryStore,
GroupServerStore,
UserErasureStore,
+ MonthlyActiveUsersStore,
):
def __init__(self, db_conn, hs):
@@ -94,7 +96,6 @@ class DataStore(RoomMemberStore, RoomStore,
self._clock = hs.get_clock()
self.database_engine = hs.database_engine
- self.db_conn = db_conn
self._stream_id_gen = StreamIdGenerator(
db_conn, "events", "stream_ordering",
extra_tables=[("local_invites", "stream_id")]
@@ -267,31 +268,6 @@ class DataStore(RoomMemberStore, RoomStore,
return self.runInteraction("count_users", _count_users)
- def count_monthly_users(self):
- """Counts the number of users who used this homeserver in the last 30 days
-
- This method should be refactored with count_daily_users - the only
- reason not to is waiting on definition of mau
-
- Returns:
- Defered[int]
- """
- def _count_monthly_users(txn):
- thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30)
- sql = """
- SELECT COALESCE(count(*), 0) FROM (
- SELECT user_id FROM user_ips
- WHERE last_seen > ?
- GROUP BY user_id
- ) u
- """
-
- txn.execute(sql, (thirty_days_ago,))
- count, = txn.fetchone()
- return count
-
- return self.runInteraction("count_monthly_users", _count_monthly_users)
-
def count_r30_users(self):
"""
Counts the number of 30 day retained users, defined as:-
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 44f37b4c1e..08dffd774f 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -1150,17 +1150,16 @@ class SQLBaseStore(object):
defer.returnValue(retval)
def get_user_count_txn(self, txn):
- """Get a total number of registerd users in the users list.
+ """Get a total number of registered users in the users list.
Args:
txn : Transaction object
Returns:
- defer.Deferred: resolves to int
+ int : number of users
"""
sql_count = "SELECT COUNT(*) FROM users WHERE is_guest = 0;"
txn.execute(sql_count)
- count = txn.fetchone()[0]
- defer.returnValue(count)
+ return txn.fetchone()[0]
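Since `get_user_count_txn` is now synchronous, callers obtain a Deferred by running it inside a transaction. A usage sketch (the wrapper function is ours, for illustration):

```python
from twisted.internet import defer

@defer.inlineCallbacks
def log_user_count(store):
    # get_user_count_txn must run inside a transaction; wrap it with
    # runInteraction to get a Deferred back.
    count = yield store.runInteraction("count_users", store.get_user_count_txn)
    defer.returnValue(count)
```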
def _simple_search_list(self, table, term, col, retcols,
desc="_simple_search_list"):
diff --git a/synapse/storage/client_ips.py b/synapse/storage/client_ips.py
index b8cefd43d6..8fc678fa67 100644
--- a/synapse/storage/client_ips.py
+++ b/synapse/storage/client_ips.py
@@ -35,6 +35,7 @@ LAST_SEEN_GRANULARITY = 120 * 1000
class ClientIpStore(background_updates.BackgroundUpdateStore):
def __init__(self, db_conn, hs):
+
self.client_ip_last_seen = Cache(
name="client_ip_last_seen",
keylen=4,
@@ -74,6 +75,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
"before", "shutdown", self._update_client_ips_batch
)
+ @defer.inlineCallbacks
def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id,
now=None):
if not now:
@@ -84,7 +86,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
last_seen = self.client_ip_last_seen.get(key)
except KeyError:
last_seen = None
-
+ yield self.populate_monthly_active_users(user_id)
# Rate-limited inserts
if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY:
return
@@ -94,6 +96,11 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
self._batch_row_update[key] = (user_agent, device_id, now)
def _update_client_ips_batch(self):
+
+ # If the DB pool has already terminated, don't try updating
+ if not self.hs.get_db_pool().running:
+ return
+
def update():
to_update = self._batch_row_update
self._batch_row_update = {}
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index e8e5a0fe44..025a7fb6d9 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -38,7 +38,7 @@ from synapse.storage.background_updates import BackgroundUpdateStore
from synapse.storage.event_federation import EventFederationStore
from synapse.storage.events_worker import EventsWorkerStore
from synapse.types import RoomStreamToken, get_domain_from_id
-from synapse.util.async import ObservableDeferred
+from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
from synapse.util.frozenutils import frozendict_json_encoder
from synapse.util.logcontext import PreserveLoggingContext, make_deferred_yieldable
@@ -485,9 +485,14 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
new_forward_extremeties=new_forward_extremeties,
)
persist_event_counter.inc(len(chunk))
- synapse.metrics.event_persisted_position.set(
- chunk[-1][0].internal_metadata.stream_ordering,
- )
+
+ if not backfilled:
+ # backfilled events have negative stream orderings, so we don't
+ # want to set the event_persisted_position to that.
+ synapse.metrics.event_persisted_position.set(
+ chunk[-1][0].internal_metadata.stream_ordering,
+ )
+
for event, context in chunk:
if context.app_service:
origin_type = "local"
@@ -1431,88 +1436,6 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
)
@defer.inlineCallbacks
- def have_events_in_timeline(self, event_ids):
- """Given a list of event ids, check if we have already processed and
- stored them as non outliers.
- """
- rows = yield self._simple_select_many_batch(
- table="events",
- retcols=("event_id",),
- column="event_id",
- iterable=list(event_ids),
- keyvalues={"outlier": False},
- desc="have_events_in_timeline",
- )
-
- defer.returnValue(set(r["event_id"] for r in rows))
-
- @defer.inlineCallbacks
- def have_seen_events(self, event_ids):
- """Given a list of event ids, check if we have already processed them.
-
- Args:
- event_ids (iterable[str]):
-
- Returns:
- Deferred[set[str]]: The events we have already seen.
- """
- results = set()
-
- def have_seen_events_txn(txn, chunk):
- sql = (
- "SELECT event_id FROM events as e WHERE e.event_id IN (%s)"
- % (",".join("?" * len(chunk)), )
- )
- txn.execute(sql, chunk)
- for (event_id, ) in txn:
- results.add(event_id)
-
- # break the input up into chunks of 100
- input_iterator = iter(event_ids)
- for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)),
- []):
- yield self.runInteraction(
- "have_seen_events",
- have_seen_events_txn,
- chunk,
- )
- defer.returnValue(results)
-
- def get_seen_events_with_rejections(self, event_ids):
- """Given a list of event ids, check if we rejected them.
-
- Args:
- event_ids (list[str])
-
- Returns:
- Deferred[dict[str, str|None):
- Has an entry for each event id we already have seen. Maps to
- the rejected reason string if we rejected the event, else maps
- to None.
- """
- if not event_ids:
- return defer.succeed({})
-
- def f(txn):
- sql = (
- "SELECT e.event_id, reason FROM events as e "
- "LEFT JOIN rejections as r ON e.event_id = r.event_id "
- "WHERE e.event_id = ?"
- )
-
- res = {}
- for event_id in event_ids:
- txn.execute(sql, (event_id,))
- row = txn.fetchone()
- if row:
- _, rejected = row
- res[event_id] = rejected
-
- return res
-
- return self.runInteraction("get_rejection_reasons", f)
-
- @defer.inlineCallbacks
def count_daily_messages(self):
"""
Returns an estimate of the number of messages sent in the last day.
@@ -1988,7 +1911,7 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
max_depth = max(row[0] for row in rows)
if max_depth <= token.topological:
- # We need to ensure we don't delete all the events from the datanase
+ # We need to ensure we don't delete all the events from the database
# otherwise we wouldn't be able to send any events (due to not
# having any backwards extremeties)
raise SynapseError(
diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py
index 9b4cfeb899..59822178ff 100644
--- a/synapse/storage/events_worker.py
+++ b/synapse/storage/events_worker.py
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import itertools
import logging
from collections import namedtuple
@@ -442,3 +443,85 @@ class EventsWorkerStore(SQLBaseStore):
self._get_event_cache.prefill((original_ev.event_id,), cache_entry)
defer.returnValue(cache_entry)
+
+ @defer.inlineCallbacks
+ def have_events_in_timeline(self, event_ids):
+ """Given a list of event ids, check if we have already processed and
+ stored them as non outliers.
+ """
+ rows = yield self._simple_select_many_batch(
+ table="events",
+ retcols=("event_id",),
+ column="event_id",
+ iterable=list(event_ids),
+ keyvalues={"outlier": False},
+ desc="have_events_in_timeline",
+ )
+
+ defer.returnValue(set(r["event_id"] for r in rows))
+
+ @defer.inlineCallbacks
+ def have_seen_events(self, event_ids):
+ """Given a list of event ids, check if we have already processed them.
+
+ Args:
+ event_ids (iterable[str]):
+
+ Returns:
+ Deferred[set[str]]: The events we have already seen.
+ """
+ results = set()
+
+ def have_seen_events_txn(txn, chunk):
+ sql = (
+ "SELECT event_id FROM events as e WHERE e.event_id IN (%s)"
+ % (",".join("?" * len(chunk)), )
+ )
+ txn.execute(sql, chunk)
+ for (event_id, ) in txn:
+ results.add(event_id)
+
+ # break the input up into chunks of 100
+ input_iterator = iter(event_ids)
+ for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)),
+ []):
+ yield self.runInteraction(
+ "have_seen_events",
+ have_seen_events_txn,
+ chunk,
+ )
+ defer.returnValue(results)
+
+ def get_seen_events_with_rejections(self, event_ids):
+ """Given a list of event ids, check if we rejected them.
+
+ Args:
+ event_ids (list[str])
+
+ Returns:
+ Deferred[dict[str, str|None]]:
+ Has an entry for each event id we already have seen. Maps to
+ the rejected reason string if we rejected the event, else maps
+ to None.
+ """
+ if not event_ids:
+ return defer.succeed({})
+
+ def f(txn):
+ sql = (
+ "SELECT e.event_id, reason FROM events as e "
+ "LEFT JOIN rejections as r ON e.event_id = r.event_id "
+ "WHERE e.event_id = ?"
+ )
+
+ res = {}
+ for event_id in event_ids:
+ txn.execute(sql, (event_id,))
+ row = txn.fetchone()
+ if row:
+ _, rejected = row
+ res[event_id] = rejected
+
+ return res
+
+ return self.runInteraction("get_rejection_reasons", f)
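The chunking idiom used by `have_seen_events` above is worth calling out; a standalone restatement (helper name is ours):

```python
import itertools

def chunks_of(iterable, size):
    # repeatedly slice `size` items off an iterator until it is exhausted;
    # iter() with a sentinel stops when islice returns an empty list
    it = iter(iterable)
    return iter(lambda: list(itertools.islice(it, size)), [])

assert list(chunks_of(range(5), 2)) == [[0, 1], [2, 3], [4]]
```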
diff --git a/synapse/storage/monthly_active_users.py b/synapse/storage/monthly_active_users.py
new file mode 100644
index 0000000000..06f9a75a97
--- /dev/null
+++ b/synapse/storage/monthly_active_users.py
@@ -0,0 +1,216 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from twisted.internet import defer
+
+from synapse.util.caches.descriptors import cached
+
+from ._base import SQLBaseStore
+
+logger = logging.getLogger(__name__)
+
+# Number of msec of granularity to store the monthly_active_user timestamp
+# This means it is not necessary to update the table on every request
+LAST_SEEN_GRANULARITY = 60 * 60 * 1000
+
+
+class MonthlyActiveUsersStore(SQLBaseStore):
+ def __init__(self, dbconn, hs):
+ super(MonthlyActiveUsersStore, self).__init__(None, hs)
+ self._clock = hs.get_clock()
+ self.hs = hs
+ self.reserved_users = ()
+
+ @defer.inlineCallbacks
+ def initialise_reserved_users(self, threepids):
+ # TODO Why can't I do this in init?
+ store = self.hs.get_datastore()
+ reserved_user_list = []
+
+ # Do not add more reserved users than the total allowable number
+ for tp in threepids[:self.hs.config.max_mau_value]:
+ user_id = yield store.get_user_id_by_threepid(
+ tp["medium"], tp["address"]
+ )
+ if user_id:
+ yield self.upsert_monthly_active_user(user_id)
+ reserved_user_list.append(user_id)
+ else:
+ logger.warning(
+ "mau limit reserved threepid %s not found in db" % tp
+ )
+ self.reserved_users = tuple(reserved_user_list)
+
+ @defer.inlineCallbacks
+ def reap_monthly_active_users(self):
+ """
+ Cleans out the monthly active user table to ensure that no stale
+ entries exist.
+
+ Returns:
+ Deferred
+ """
+ def _reap_users(txn):
+ # Purge stale users
+
+ thirty_days_ago = (
+ int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30)
+ )
+ query_args = [thirty_days_ago]
+ base_sql = "DELETE FROM monthly_active_users WHERE timestamp < ?"
+
+ # Need if/else since 'AND user_id NOT IN ({})' fails on Postgres
+ # when len(reserved_users) == 0. Works fine on sqlite.
+ if len(self.reserved_users) > 0:
+ # questionmarks is a hack to overcome sqlite not supporting
+ # tuples in 'WHERE IN %s'
+ questionmarks = '?' * len(self.reserved_users)
+
+ query_args.extend(self.reserved_users)
+ sql = base_sql + """ AND user_id NOT IN ({})""".format(
+ ','.join(questionmarks)
+ )
+ else:
+ sql = base_sql
+
+ txn.execute(sql, query_args)
+
+ # If MAU user count still exceeds the MAU threshold, then delete on
+ # a least recently active basis.
+ # Note it is not possible to write this query using OFFSET due to
+ # incompatibilities in how sqlite and postgres support the feature.
+ # sqlite requires 'LIMIT -1 OFFSET ?', the LIMIT must be present
+ # While Postgres does not require 'LIMIT', but also does not support
+ # negative LIMIT values. So there is no way to write it that both can
+ # support
+ safe_guard = self.hs.config.max_mau_value - len(self.reserved_users)
+ # Must be greater than zero for postgres
+ safe_guard = safe_guard if safe_guard > 0 else 0
+ query_args = [safe_guard]
+
+ base_sql = """
+ DELETE FROM monthly_active_users
+ WHERE user_id NOT IN (
+ SELECT user_id FROM monthly_active_users
+ ORDER BY timestamp DESC
+ LIMIT ?
+ )
+ """
+ # Need if/else since 'AND user_id NOT IN ({})' fails on Postgres
+ # when len(reserved_users) == 0. Works fine on sqlite.
+ if len(self.reserved_users) > 0:
+ query_args.extend(self.reserved_users)
+ sql = base_sql + """ AND user_id NOT IN ({})""".format(
+ ','.join('?' * len(self.reserved_users))
+ )
+ else:
+ sql = base_sql
+ txn.execute(sql, query_args)
+
+ yield self.runInteraction("reap_monthly_active_users", _reap_users)
+ # It seems poor to invalidate the whole cache. Postgres supports
+ # 'RETURNING', which would allow us to invalidate only the
+ # specific users, but sqlite has no way to do this, and the
+ # alternative of a SELECT followed by a DELETE is racy without
+ # locking.
+ # We have resolved to invalidate the whole cache for now, and to do
+ # something about it if and when the performance becomes significant.
+ self.user_last_seen_monthly_active.invalidate_all()
+ self.get_monthly_active_count.invalidate_all()
+
+ @cached(num_args=0)
+ def get_monthly_active_count(self):
+ """Generates current count of monthly active users
+
+ Returns:
+ Deferred[int]: Number of current monthly active users
+ """
+
+ def _count_users(txn):
+ sql = "SELECT COALESCE(count(*), 0) FROM monthly_active_users"
+
+ txn.execute(sql)
+ count, = txn.fetchone()
+ return count
+ return self.runInteraction("count_users", _count_users)
+
+ def upsert_monthly_active_user(self, user_id):
+ """
+ Updates or inserts a monthly active user entry
+ Arguments:
+ user_id (str): user to add/update
+ Returns:
+ Deferred[bool]: True if a new entry was created, False if an
+ existing one was updated.
+ """
+ is_insert = self._simple_upsert(
+ desc="upsert_monthly_active_user",
+ table="monthly_active_users",
+ keyvalues={
+ "user_id": user_id,
+ },
+ values={
+ "timestamp": int(self._clock.time_msec()),
+ },
+ lock=False,
+ )
+ if is_insert:
+ self.user_last_seen_monthly_active.invalidate((user_id,))
+ self.get_monthly_active_count.invalidate(())
+
+ @cached(num_args=1)
+ def user_last_seen_monthly_active(self, user_id):
+ """
+ Checks if a given user is part of the monthly active user group
+ Arguments:
+ user_id (str): user to check
+ Returns:
+ Deferred[int]: timestamp of the user's last-seen update, or None if
+ never seen
+
+ """
+
+ return self._simple_select_one_onecol(
+ table="monthly_active_users",
+ keyvalues={
+ "user_id": user_id,
+ },
+ retcol="timestamp",
+ allow_none=True,
+ desc="user_last_seen_monthly_active",
+ )
+
+ @defer.inlineCallbacks
+ def populate_monthly_active_users(self, user_id):
+ """Checks on the state of monthly active user limits and optionally
+ add the user to the monthly active tables
+
+ Args:
+ user_id(str): the user_id to query
+ """
+ if self.hs.config.limit_usage_by_mau:
+ last_seen_timestamp = yield self.user_last_seen_monthly_active(user_id)
+ now = self.hs.get_clock().time_msec()
+
+ # We want to reduce the total number of db writes, and are happy
+ # to trade accuracy of timestamp in order to lighten load. This means
+ # we always insert new users (where the MAU threshold has not been
+ # reached), but only update existing users if we have not seen them
+ # for LAST_SEEN_GRANULARITY ms.
+ if last_seen_timestamp is None:
+ count = yield self.get_monthly_active_count()
+ if count < self.hs.config.max_mau_value:
+ yield self.upsert_monthly_active_user(user_id)
+ elif now - last_seen_timestamp > LAST_SEEN_GRANULARITY:
+ yield self.upsert_monthly_active_user(user_id)
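To summarise how `populate_monthly_active_users` decides whether to write: new users are inserted only while under the cap, and known users are refreshed at most once per `LAST_SEEN_GRANULARITY`. A pure-function sketch of that decision, for illustration:

```python
LAST_SEEN_GRANULARITY = 60 * 60 * 1000  # ms, as defined above

def should_upsert(last_seen_timestamp, now, current_mau, max_mau_value):
    # restates the logic of populate_monthly_active_users in isolation
    if last_seen_timestamp is None:
        # unseen user: only admit them while under the MAU cap
        return current_mau < max_mau_value
    # known user: refresh at most once per granularity window
    return now - last_seen_timestamp > LAST_SEEN_GRANULARITY
```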
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index b290f834b3..b364719312 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts.
-SCHEMA_VERSION = 50
+SCHEMA_VERSION = 51
dir_path = os.path.abspath(os.path.dirname(__file__))
diff --git a/synapse/storage/room.py b/synapse/storage/room.py
index 3147fb6827..3378fc77d1 100644
--- a/synapse/storage/room.py
+++ b/synapse/storage/room.py
@@ -41,6 +41,22 @@ RatelimitOverride = collections.namedtuple(
class RoomWorkerStore(SQLBaseStore):
+ def get_room(self, room_id):
+ """Retrieve a room.
+
+ Args:
+ room_id (str): The ID of the room to retrieve.
+ Returns:
+ A namedtuple containing the room information, or None if the room is unknown.
+ """
+ return self._simple_select_one(
+ table="rooms",
+ keyvalues={"room_id": room_id},
+ retcols=("room_id", "is_public", "creator"),
+ desc="get_room",
+ allow_none=True,
+ )
+
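Moving `get_room` onto `RoomWorkerStore` makes it callable from worker processes as well as the master. A hedged usage sketch (helper name is ours):

```python
from twisted.internet import defer

@defer.inlineCallbacks
def room_is_public(store, room_id):
    # get_room now lives on RoomWorkerStore, so this works on federation
    # readers and other workers, not just the master process
    room = yield store.get_room(room_id)
    defer.returnValue(bool(room and room["is_public"]))
```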
def get_public_room_ids(self):
return self._simple_select_onecol(
table="rooms",
@@ -215,22 +231,6 @@ class RoomStore(RoomWorkerStore, SearchStore):
logger.error("store_room with room_id=%s failed: %s", room_id, e)
raise StoreError(500, "Problem creating room.")
- def get_room(self, room_id):
- """Retrieve a room.
-
- Args:
- room_id (str): The ID of the room to retrieve.
- Returns:
- A namedtuple containing the room information, or an empty list.
- """
- return self._simple_select_one(
- table="rooms",
- keyvalues={"room_id": room_id},
- retcols=("room_id", "is_public", "creator"),
- desc="get_room",
- allow_none=True,
- )
-
@defer.inlineCallbacks
def set_room_is_public(self, room_id, is_public):
def set_room_is_public_txn(txn, next_id):
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 10dce21cea..9b4e6d6aa8 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -26,7 +26,7 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
from synapse.storage.events_worker import EventsWorkerStore
from synapse.types import get_domain_from_id
-from synapse.util.async import Linearizer
+from synapse.util.async_helpers import Linearizer
from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
from synapse.util.stringutils import to_ascii
diff --git a/synapse/storage/schema/delta/51/monthly_active_users.sql b/synapse/storage/schema/delta/51/monthly_active_users.sql
new file mode 100644
index 0000000000..c9d537d5a3
--- /dev/null
+++ b/synapse/storage/schema/delta/51/monthly_active_users.sql
@@ -0,0 +1,27 @@
+/* Copyright 2018 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- a table of monthly active users, for use where blocking based on mau limits
+CREATE TABLE monthly_active_users (
+ user_id TEXT NOT NULL,
+ -- Last time we saw the user. Not guaranteed to be accurate due to rate limiting
+ -- on updates; the granularity of updates is governed by
+ -- synapse.storage.monthly_active_users.LAST_SEEN_GRANULARITY.
+ -- Measured in ms since epoch.
+ timestamp BIGINT NOT NULL
+);
+
+CREATE UNIQUE INDEX monthly_active_users_users ON monthly_active_users(user_id);
+CREATE INDEX monthly_active_users_time_stamp ON monthly_active_users(timestamp);
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index b27b3ae144..dd03c4168b 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -21,15 +21,17 @@ from six.moves import range
from twisted.internet import defer
+from synapse.api.constants import EventTypes
+from synapse.api.errors import NotFoundError
+from synapse.storage._base import SQLBaseStore
from synapse.storage.background_updates import BackgroundUpdateStore
from synapse.storage.engines import PostgresEngine
+from synapse.storage.events_worker import EventsWorkerStore
from synapse.util.caches import get_cache_factor_for, intern_string
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.dictionary_cache import DictionaryCache
from synapse.util.stringutils import to_ascii
-from ._base import SQLBaseStore
-
logger = logging.getLogger(__name__)
@@ -46,7 +48,8 @@ class _GetStateGroupDelta(namedtuple("_GetStateGroupDelta", ("prev_group", "delt
return len(self.delta_ids) if self.delta_ids else 0
-class StateGroupWorkerStore(SQLBaseStore):
+# this inherits from EventsWorkerStore because it calls self.get_events
+class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
"""The parts of StateGroupStore that can be called from workers.
"""
@@ -61,6 +64,30 @@ class StateGroupWorkerStore(SQLBaseStore):
"*stateGroupCache*", 500000 * get_cache_factor_for("stateGroupCache")
)
+ @defer.inlineCallbacks
+ def get_room_version(self, room_id):
+ """Get the room_version of a given room
+
+ Args:
+ room_id (str)
+
+ Returns:
+ Deferred[str]
+
+ Raises:
+ NotFoundError if the room is unknown
+ """
+ # for now we do this by looking at the create event. We may want to cache this
+ # more intelligently in future.
+ state_ids = yield self.get_current_state_ids(room_id)
+ create_id = state_ids.get((EventTypes.Create, ""))
+
+ if not create_id:
+ raise NotFoundError("Unknown room")
+
+ create_event = yield self.get_event(create_id)
+ defer.returnValue(create_event.content.get("room_version", "1"))
+
@cached(max_entries=100000, iterable=True)
def get_current_state_ids(self, room_id):
"""Get the current state event ids for a room based on the
@@ -89,6 +116,69 @@ class StateGroupWorkerStore(SQLBaseStore):
_get_current_state_ids_txn,
)
+ # FIXME: how should this be cached?
+ def get_filtered_current_state_ids(self, room_id, types, filtered_types=None):
+ """Get the current state event of a given type for a room based on the
+ current_state_events table. This may not be as up-to-date as the result
+ of doing a fresh state resolution as per state_handler.get_current_state
+ Args:
+ room_id (str)
+ types (list[(str, str|None)]): List of (type, state_key) tuples
+ which are used to filter the state fetched. `state_key` may be
+ None, which matches any `state_key`.
+ filtered_types (list[str]|None): List of types to apply the above filter to.
+ Returns:
+ deferred: dict of (type, state_key) -> event
+ """
+
+ include_other_types = filtered_types is not None
+
+ def _get_filtered_current_state_ids_txn(txn):
+ results = {}
+ sql = """SELECT type, state_key, event_id FROM current_state_events
+ WHERE room_id = ? %s"""
+ # Turns out that postgres doesn't like doing a list of OR's and
+ # is about 1000x slower, so we just issue a query for each specific
+ # type separately.
+ if types:
+ clause_to_args = [
+ (
+ "AND type = ? AND state_key = ?",
+ (etype, state_key)
+ ) if state_key is not None else (
+ "AND type = ?",
+ (etype,)
+ )
+ for etype, state_key in types
+ ]
+
+ if include_other_types:
+ unique_types = set(filtered_types)
+ clause_to_args.append(
+ (
+ "AND type <> ? " * len(unique_types),
+ list(unique_types)
+ )
+ )
+ else:
+ # If types is None we fetch all the state, and so just use an
+ # empty where clause with no extra args.
+ clause_to_args = [("", [])]
+ for where_clause, where_args in clause_to_args:
+ args = [room_id]
+ args.extend(where_args)
+ txn.execute(sql % (where_clause,), args)
+ for row in txn:
+ typ, state_key, event_id = row
+ key = (intern_string(typ), intern_string(state_key))
+ results[key] = event_id
+ return results
+
+ return self.runInteraction(
+ "get_filtered_current_state_ids",
+ _get_filtered_current_state_ids_txn,
+ )
+
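A hedged usage sketch of the new `get_filtered_current_state_ids`: fetching a single member event and the room name without loading the rest of the current state (helper name is ours):

```python
from twisted.internet import defer

from synapse.api.constants import EventTypes

@defer.inlineCallbacks
def get_member_and_name_ids(store, room_id, user_id):
    # returns a dict of (type, state_key) -> event_id covering just the
    # requested state, rather than the room's full current state
    state_ids = yield store.get_filtered_current_state_ids(
        room_id,
        types=[(EventTypes.Member, user_id), (EventTypes.Name, "")],
    )
    defer.returnValue(state_ids)
```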
@cached(max_entries=10000, iterable=True)
def get_state_group_delta(self, state_group):
"""Given a state group try to return a previous group and a delta between
@@ -362,8 +452,7 @@ class StateGroupWorkerStore(SQLBaseStore):
If None, `types` filtering is applied to all events.
Returns:
- deferred: A list of dicts corresponding to the event_ids given.
- The dicts are mappings from (type, state_key) -> state_events
+ deferred: A dict of (event_id) -> (type, state_key) -> [state_events]
"""
event_to_groups = yield self._get_state_group_for_events(
event_ids,
@@ -391,7 +480,8 @@ class StateGroupWorkerStore(SQLBaseStore):
@defer.inlineCallbacks
def get_state_ids_for_events(self, event_ids, types=None, filtered_types=None):
"""
- Get the state dicts corresponding to a list of events
+ Get the state dicts corresponding to a list of events, containing the event_ids
+ of the state events (as opposed to the events themselves)
Args:
event_ids(list(str)): events whose state should be returned
@@ -404,7 +494,7 @@ class StateGroupWorkerStore(SQLBaseStore):
If None, `types` filtering is applied to all events.
Returns:
- A deferred dict from event_id -> (type, state_key) -> state_event
+ A deferred dict from event_id -> (type, state_key) -> event_id
"""
event_to_groups = yield self._get_state_group_for_events(
event_ids,
diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py
index b9f2b74ac6..4c296d72c0 100644
--- a/synapse/storage/stream.py
+++ b/synapse/storage/stream.py
@@ -348,7 +348,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
end_token (str): The stream token representing now.
Returns:
- Deferred[tuple[list[FrozenEvent], str]]: Returns a list of
+ Deferred[tuple[list[FrozenEvent], str]]: Returns a list of
events and a token pointing to the start of the returned
events.
The events returned are in ascending order.
@@ -379,7 +379,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
end_token (str): The stream token representing now.
Returns:
- Deferred[tuple[list[_EventDictReturn], str]]: Returns a list of
+ Deferred[tuple[list[_EventDictReturn], str]]: Returns a list of
_EventDictReturn and a token pointing to the start of the returned
events.
The events returned are in ascending order.
diff --git a/synapse/util/async.py b/synapse/util/async_helpers.py
index a7094e2fb4..9b3f2f4b96 100644
--- a/synapse/util/async.py
+++ b/synapse/util/async_helpers.py
@@ -188,62 +188,30 @@ class Linearizer(object):
# things blocked from executing.
self.key_to_defer = {}
- @defer.inlineCallbacks
def queue(self, key):
+ # we avoid doing defer.inlineCallbacks here, so that cancellation works correctly.
+ # (https://twistedmatrix.com/trac/ticket/4632 meant that cancellations were not
+ # propagated inside inlineCallbacks until Twisted 18.7)
entry = self.key_to_defer.setdefault(key, [0, collections.OrderedDict()])
# If the number of things executing is greater than the maximum
# then add a deferred to the list of blocked items
- # When on of the things currently executing finishes it will callback
+ # When one of the things currently executing finishes it will callback
# this item so that it can continue executing.
if entry[0] >= self.max_count:
- new_defer = defer.Deferred()
- entry[1][new_defer] = 1
-
- logger.info(
- "Waiting to acquire linearizer lock %r for key %r", self.name, key,
- )
- try:
- yield make_deferred_yieldable(new_defer)
- except Exception as e:
- if isinstance(e, CancelledError):
- logger.info(
- "Cancelling wait for linearizer lock %r for key %r",
- self.name, key,
- )
- else:
- logger.warn(
- "Unexpected exception waiting for linearizer lock %r for key %r",
- self.name, key,
- )
-
- # we just have to take ourselves back out of the queue.
- del entry[1][new_defer]
- raise
-
- logger.info("Acquired linearizer lock %r for key %r", self.name, key)
- entry[0] += 1
-
- # if the code holding the lock completes synchronously, then it
- # will recursively run the next claimant on the list. That can
- # relatively rapidly lead to stack exhaustion. This is essentially
- # the same problem as http://twistedmatrix.com/trac/ticket/9304.
- #
- # In order to break the cycle, we add a cheeky sleep(0) here to
- # ensure that we fall back to the reactor between each iteration.
- #
- # (This needs to happen while we hold the lock, and the context manager's exit
- # code must be synchronous, so this is the only sensible place.)
- yield self._clock.sleep(0)
-
+ res = self._await_lock(key)
else:
logger.info(
"Acquired uncontended linearizer lock %r for key %r", self.name, key,
)
entry[0] += 1
+ res = defer.succeed(None)
+
+ # once we successfully get the lock, we need to return a context manager which
+ # will release the lock.
@contextmanager
- def _ctx_manager():
+ def _ctx_manager(_):
try:
yield
finally:
@@ -264,7 +232,64 @@ class Linearizer(object):
# map.
del self.key_to_defer[key]
- defer.returnValue(_ctx_manager())
+ res.addCallback(_ctx_manager)
+ return res
+
+ def _await_lock(self, key):
+ """Helper for queue: adds a deferred to the queue
+
+ Assumes that we've already checked that we've reached the limit of the number
+ of lock-holders we allow. Creates a new deferred which is added to the list, and
+ adds some management around cancellations.
+
+ Returns the deferred, which will callback once we have secured the lock.
+
+ """
+ entry = self.key_to_defer[key]
+
+ logger.info(
+ "Waiting to acquire linearizer lock %r for key %r", self.name, key,
+ )
+
+ new_defer = make_deferred_yieldable(defer.Deferred())
+ entry[1][new_defer] = 1
+
+ def cb(_r):
+ logger.info("Acquired linearizer lock %r for key %r", self.name, key)
+ entry[0] += 1
+
+ # if the code holding the lock completes synchronously, then it
+ # will recursively run the next claimant on the list. That can
+ # relatively rapidly lead to stack exhaustion. This is essentially
+ # the same problem as http://twistedmatrix.com/trac/ticket/9304.
+ #
+ # In order to break the cycle, we add a cheeky sleep(0) here to
+ # ensure that we fall back to the reactor between each iteration.
+ #
+ # (This needs to happen while we hold the lock, and the context manager's exit
+ # code must be synchronous, so this is the only sensible place.)
+ return self._clock.sleep(0)
+
+ def eb(e):
+ logger.info("defer %r got err %r", new_defer, e)
+ if isinstance(e, CancelledError):
+ logger.info(
+ "Cancelling wait for linearizer lock %r for key %r",
+ self.name, key,
+ )
+
+ else:
+ logger.warn(
+ "Unexpected exception waiting for linearizer lock %r for key %r",
+ self.name, key,
+ )
+
+ # we just have to take ourselves back out of the queue.
+ del entry[1][new_defer]
+ return e
+
+ new_defer.addCallbacks(cb, eb)
+ return new_defer
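Callers of `Linearizer.queue` are unaffected by the rewrite: it still resolves to a context manager which holds the lock for the body of the `with` block. A minimal usage sketch:

```python
from twisted.internet import defer

from synapse.util.async_helpers import Linearizer

linearizer = Linearizer(name="example")

@defer.inlineCallbacks
def do_exclusive(key):
    # queue() returns a Deferred resolving to a context manager; the lock
    # is held inside the `with` block and released on exit
    with (yield linearizer.queue(key)):
        pass  # at most max_count callers run here per key at a time
```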
class ReadWriteLock(object):
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 861c24809c..187510576a 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -25,7 +25,7 @@ from six import itervalues, string_types
from twisted.internet import defer
from synapse.util import logcontext, unwrapFirstError
-from synapse.util.async import ObservableDeferred
+from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches import get_cache_factor_for
from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py
index a8491b42d5..afb03b2e1b 100644
--- a/synapse/util/caches/response_cache.py
+++ b/synapse/util/caches/response_cache.py
@@ -16,7 +16,7 @@ import logging
from twisted.internet import defer
-from synapse.util.async import ObservableDeferred
+from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches import register_cache
from synapse.util.logcontext import make_deferred_yieldable, run_in_background
diff --git a/synapse/util/caches/snapshot_cache.py b/synapse/util/caches/snapshot_cache.py
index d03678b8c8..8318db8d2c 100644
--- a/synapse/util/caches/snapshot_cache.py
+++ b/synapse/util/caches/snapshot_cache.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.util.async import ObservableDeferred
+from synapse.util.async_helpers import ObservableDeferred
class SnapshotCache(object):
diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py
index 8dcae50b39..a0c2d37610 100644
--- a/synapse/util/logcontext.py
+++ b/synapse/util/logcontext.py
@@ -385,7 +385,13 @@ class LoggingContextFilter(logging.Filter):
context = LoggingContext.current_context()
for key, value in self.defaults.items():
setattr(record, key, value)
- context.copy_to(record)
+
+ # context should never be None, but if it somehow ends up being None,
+ # then we end up in a death spiral of infinite loops, so let's check,
+ # for robustness' sake.
+ if context is not None:
+ context.copy_to(record)
+
return True
@@ -396,7 +402,9 @@ class PreserveLoggingContext(object):
__slots__ = ["current_context", "new_context", "has_parent"]
- def __init__(self, new_context=LoggingContext.sentinel):
+ def __init__(self, new_context=None):
+ if new_context is None:
+ new_context = LoggingContext.sentinel
self.new_context = new_context
def __enter__(self):
@@ -526,7 +534,7 @@ _to_ignore = [
"synapse.util.logcontext",
"synapse.http.server",
"synapse.storage._base",
- "synapse.util.async",
+ "synapse.util.async_helpers",
]
diff --git a/synapse/util/logutils.py b/synapse/util/logutils.py
index 62a00189cc..ef31458226 100644
--- a/synapse/util/logutils.py
+++ b/synapse/util/logutils.py
@@ -20,6 +20,8 @@ import time
from functools import wraps
from inspect import getcallargs
+from six import PY3
+
_TIME_FUNC_ID = 0
@@ -28,8 +30,12 @@ def _log_debug_as_f(f, msg, msg_args):
logger = logging.getLogger(name)
if logger.isEnabledFor(logging.DEBUG):
- lineno = f.func_code.co_firstlineno
- pathname = f.func_code.co_filename
+ if PY3:
+ lineno = f.__code__.co_firstlineno
+ pathname = f.__code__.co_filename
+ else:
+ lineno = f.func_code.co_firstlineno
+ pathname = f.func_code.co_filename
record = logging.LogRecord(
name=name,
diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py
index 43d9db67ec..6f318c6a29 100644
--- a/synapse/util/stringutils.py
+++ b/synapse/util/stringutils.py
@@ -16,6 +16,7 @@
import random
import string
+from six import PY3
from six.moves import range
_string_with_symbols = (
@@ -34,6 +35,17 @@ def random_string_with_symbols(length):
def is_ascii(s):
+
+ if PY3:
+ if isinstance(s, bytes):
+ try:
+ s.decode('ascii').encode('ascii')
+ except UnicodeDecodeError:
+ return False
+ except UnicodeEncodeError:
+ return False
+ return True
+
try:
s.encode("ascii")
except UnicodeEncodeError:
@@ -49,6 +61,9 @@ def to_ascii(s):
If given None then will return None.
"""
+ if PY3:
+ return s
+
if s is None:
return None
diff --git a/synapse/util/versionstring.py b/synapse/util/versionstring.py
index 1fbcd41115..3baba3225a 100644
--- a/synapse/util/versionstring.py
+++ b/synapse/util/versionstring.py
@@ -30,7 +30,7 @@ def get_version_string(module):
['git', 'rev-parse', '--abbrev-ref', 'HEAD'],
stderr=null,
cwd=cwd,
- ).strip()
+ ).strip().decode('ascii')
git_branch = "b=" + git_branch
except subprocess.CalledProcessError:
git_branch = ""
@@ -40,7 +40,7 @@ def get_version_string(module):
['git', 'describe', '--exact-match'],
stderr=null,
cwd=cwd,
- ).strip()
+ ).strip().decode('ascii')
git_tag = "t=" + git_tag
except subprocess.CalledProcessError:
git_tag = ""
@@ -50,7 +50,7 @@ def get_version_string(module):
['git', 'rev-parse', '--short', 'HEAD'],
stderr=null,
cwd=cwd,
- ).strip()
+ ).strip().decode('ascii')
except subprocess.CalledProcessError:
git_commit = ""
@@ -60,7 +60,7 @@ def get_version_string(module):
['git', 'describe', '--dirty=' + dirty_string],
stderr=null,
cwd=cwd,
- ).strip().endswith(dirty_string)
+ ).strip().decode('ascii').endswith(dirty_string)
git_dirty = "dirty" if is_dirty else ""
except subprocess.CalledProcessError:
@@ -77,8 +77,8 @@ def get_version_string(module):
"%s (%s)" % (
module.__version__, git_version,
)
- ).encode("ascii")
+ )
except Exception as e:
logger.info("Failed to check for git repository: %s", e)
- return module.__version__.encode("ascii")
+ return module.__version__
|