diff --git a/synapse/__init__.py b/synapse/__init__.py
index d6a191ccc6..048d6e572f 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -27,4 +27,4 @@ try:
except ImportError:
pass
-__version__ = "0.34.1.1"
+__version__ = "0.99.0"
diff --git a/synapse/_scripts/register_new_matrix_user.py b/synapse/_scripts/register_new_matrix_user.py
index 4c3abf06fe..6e93f5a0c6 100644
--- a/synapse/_scripts/register_new_matrix_user.py
+++ b/synapse/_scripts/register_new_matrix_user.py
@@ -46,7 +46,7 @@ def request_registration(
# Get the nonce
r = requests.get(url, verify=False)
- if r.status_code is not 200:
+ if r.status_code != 200:
_print("ERROR! Received %d %s" % (r.status_code, r.reason))
if 400 <= r.status_code < 500:
try:
@@ -84,7 +84,7 @@ def request_registration(
_print("Sending registration request...")
r = requests.post(url, json=data, verify=False)
- if r.status_code is not 200:
+ if r.status_code != 200:
_print("ERROR! Received %d %s" % (r.status_code, r.reason))
if 400 <= r.status_code < 500:
try:
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index ba1019b9b2..5992d30623 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -65,7 +65,7 @@ class Auth(object):
register_cache("cache", "token_cache", self.token_cache)
@defer.inlineCallbacks
- def check_from_context(self, event, context, do_sig_check=True):
+ def check_from_context(self, room_version, event, context, do_sig_check=True):
prev_state_ids = yield context.get_prev_state_ids(self.store)
auth_events_ids = yield self.compute_auth_events(
event, prev_state_ids, for_verification=True,
@@ -74,12 +74,16 @@ class Auth(object):
auth_events = {
(e.type, e.state_key): e for e in itervalues(auth_events)
}
- self.check(event, auth_events=auth_events, do_sig_check=do_sig_check)
+ self.check(
+ room_version, event,
+ auth_events=auth_events, do_sig_check=do_sig_check,
+ )
- def check(self, event, auth_events, do_sig_check=True):
+ def check(self, room_version, event, auth_events, do_sig_check=True):
""" Checks if this event is correctly authed.
Args:
+ room_version (str): version of the room
event: the event being checked.
auth_events (dict: event-key -> event): the existing room state.
@@ -88,7 +92,9 @@ class Auth(object):
True if the auth checks pass.
"""
with Measure(self.clock, "auth.check"):
- event_auth.check(event, auth_events, do_sig_check=do_sig_check)
+ event_auth.check(
+ room_version, event, auth_events, do_sig_check=do_sig_check
+ )
@defer.inlineCallbacks
def check_joined_room(self, room_id, user_id, current_state=None):
@@ -545,17 +551,6 @@ class Auth(object):
return self.store.is_server_admin(user)
@defer.inlineCallbacks
- def add_auth_events(self, builder, context):
- prev_state_ids = yield context.get_prev_state_ids(self.store)
- auth_ids = yield self.compute_auth_events(builder, prev_state_ids)
-
- auth_events_entries = yield self.store.add_event_hashes(
- auth_ids
- )
-
- builder.auth_events = auth_events_entries
-
- @defer.inlineCallbacks
def compute_auth_events(self, event, current_state_ids, for_verification=False):
if event.type == EventTypes.Create:
defer.returnValue([])
@@ -571,7 +566,7 @@ class Auth(object):
key = (EventTypes.JoinRules, "", )
join_rule_event_id = current_state_ids.get(key)
- key = (EventTypes.Member, event.user_id, )
+ key = (EventTypes.Member, event.sender, )
member_event_id = current_state_ids.get(key)
key = (EventTypes.Create, "", )
@@ -621,7 +616,7 @@ class Auth(object):
defer.returnValue(auth_ids)
- def check_redaction(self, event, auth_events):
+ def check_redaction(self, room_version, event, auth_events):
"""Check whether the event sender is allowed to redact the target event.
Returns:
@@ -634,7 +629,7 @@ class Auth(object):
AuthError if the event sender is definitely not allowed to redact
the target event.
"""
- return event_auth.check_redaction(event, auth_events)
+ return event_auth.check_redaction(room_version, event, auth_events)
@defer.inlineCallbacks
def check_can_change_room_list(self, room_id, user):
@@ -819,7 +814,9 @@ class Auth(object):
elif threepid:
# If the user does not exist yet, but is signing up with a
# reserved threepid then pass auth check
- if is_threepid_reserved(self.hs.config, threepid):
+ if is_threepid_reserved(
+ self.hs.config.mau_limits_reserved_threepids, threepid
+ ):
return
# Else if there is no room in the MAU bucket, bail
current_mau = yield self.store.get_monthly_active_count()
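
Callers of Auth.check_from_context() and Auth.check() must now supply the room version. A minimal sketch of the updated calling convention, assuming a store that exposes get_room_version(room_id) and an existing event/context pair:

    from twisted.internet import defer

    @defer.inlineCallbacks
    def auth_event_for_room(auth, store, event, context):
        # Thread the room's version through to the auth check
        # (store.get_room_version is assumed here for illustration).
        room_version = yield store.get_room_version(event.room_id)
        yield auth.check_from_context(
            room_version, event, context, do_sig_check=False,
        )
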
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index 46c4b4b9dc..f47c33a074 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -73,6 +73,7 @@ class EventTypes(object):
RoomHistoryVisibility = "m.room.history_visibility"
CanonicalAlias = "m.room.canonical_alias"
RoomAvatar = "m.room.avatar"
+ RoomEncryption = "m.room.encryption"
GuestAccess = "m.room.guest_access"
# These are used for validation
@@ -104,10 +105,15 @@ class ThirdPartyEntityKind(object):
class RoomVersions(object):
V1 = "1"
V2 = "2"
- VDH_TEST = "vdh-test-version"
+ V3 = "3"
STATE_V2_TEST = "state-v2-test"
+class RoomDisposition(object):
+ STABLE = "stable"
+ UNSTABLE = "unstable"
+
+
# the version we will give rooms which are created on this server
DEFAULT_ROOM_VERSION = RoomVersions.V1
@@ -116,10 +122,26 @@ DEFAULT_ROOM_VERSION = RoomVersions.V1
KNOWN_ROOM_VERSIONS = {
RoomVersions.V1,
RoomVersions.V2,
- RoomVersions.VDH_TEST,
+ RoomVersions.V3,
RoomVersions.STATE_V2_TEST,
}
+
+class EventFormatVersions(object):
+ """This is an internal enum for tracking the version of the event format,
+    independently of the room version.
+ """
+ V1 = 1
+ V2 = 2
+
+
+KNOWN_EVENT_FORMAT_VERSIONS = {
+ EventFormatVersions.V1,
+ EventFormatVersions.V2,
+}
+
+
ServerNoticeMsgType = "m.server_notice"
ServerNoticeLimitReached = "m.server_notice.usage_limit_reached"
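
With VDH_TEST gone and V3 added, version checks should go through the known sets. A small sketch of guarding on them (the error type is illustrative):

    from synapse.api.constants import (
        KNOWN_EVENT_FORMAT_VERSIONS,
        KNOWN_ROOM_VERSIONS,
    )

    def validate_versions(room_version, format_version):
        # Reject anything outside the two known sets up front.
        if room_version not in KNOWN_ROOM_VERSIONS:
            raise ValueError("Unrecognized room version %r" % (room_version,))
        if format_version not in KNOWN_EVENT_FORMAT_VERSIONS:
            raise ValueError("Unrecognized event format %r" % (format_version,))
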
diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py
index 16ad654864..3906475403 100644
--- a/synapse/api/filtering.py
+++ b/synapse/api/filtering.py
@@ -444,6 +444,20 @@ class Filter(object):
def include_redundant_members(self):
return self.filter_json.get("include_redundant_members", False)
+ def with_room_ids(self, room_ids):
+ """Returns a new filter with the given room IDs appended.
+
+ Args:
+ room_ids (iterable[unicode]): The room_ids to add
+
+ Returns:
+ filter: A new filter including the given rooms and the old
+ filter's rooms.
+ """
+        new_filter = Filter(self.filter_json)
+        new_filter.rooms += room_ids
+        return new_filter
+
def _matches_wildcard(actual_value, filter_value):
if filter_value.endswith("*"):
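
A usage sketch for the new with_room_ids() above, assuming the per-component filter JSON shape that Filter expects:

    from synapse.api.filtering import Filter

    base = Filter({"rooms": ["!a:example.com"]})
    widened = base.with_room_ids(["!b:example.com"])
    assert "!b:example.com" in widened.rooms
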
diff --git a/synapse/app/__init__.py b/synapse/app/__init__.py
index b45adafdd3..f56f5fcc13 100644
--- a/synapse/app/__init__.py
+++ b/synapse/app/__init__.py
@@ -12,15 +12,38 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+import logging
import sys
from synapse import python_dependencies # noqa: E402
sys.dont_write_bytecode = True
+logger = logging.getLogger(__name__)
+
try:
python_dependencies.check_requirements()
except python_dependencies.DependencyException as e:
sys.stderr.writelines(e.message)
sys.exit(1)
+
+
+def check_bind_error(e, address, bind_addresses):
+ """
+    This method checks whether an exception that occurred while binding on
+    0.0.0.0 can safely be ignored. If :: is specified in the bind addresses,
+    a warning is shown; otherwise the exception is re-raised.
+
+ Binding on both 0.0.0.0 and :: causes an exception on Linux and macOS
+ because :: binds on both IPv4 and IPv6 (as per RFC 3493).
+ When binding on 0.0.0.0 after :: this can safely be ignored.
+
+ Args:
+ e (Exception): Exception that was caught.
+ address (str): Address on which binding was attempted.
+ bind_addresses (list): Addresses on which the service listens.
+ """
+ if address == '0.0.0.0' and '::' in bind_addresses:
+ logger.warn('Failed to listen on 0.0.0.0, continuing because listening on [::]')
+ else:
+ raise e
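
The intended calling pattern, sketched with illustrative names: attempt to bind each configured address and let check_bind_error() decide whether a failure is the benign dual-stack case:

    from twisted.internet import error, reactor

    def listen_on_all(port, factory, bind_addresses):
        for address in bind_addresses:
            try:
                reactor.listenTCP(port, factory, interface=address)
            except error.CannotListenError as e:
                # Swallowed only for 0.0.0.0 after ::; re-raised otherwise.
                check_bind_error(e, address, bind_addresses)
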
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index 18584226e9..3cbb003035 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -15,18 +15,35 @@
import gc
import logging
+import signal
import sys
+import traceback
import psutil
from daemonize import Daemonize
from twisted.internet import error, reactor
+from synapse.app import check_bind_error
+from synapse.crypto import context_factory
from synapse.util import PreserveLoggingContext
from synapse.util.rlimit import change_resource_limit
logger = logging.getLogger(__name__)
+_sighup_callbacks = []
+
+
+def register_sighup(func):
+ """
+ Register a function to be called when a SIGHUP occurs.
+
+ Args:
+ func (function): Function to be called when sent a SIGHUP signal.
+ Will be called with a single argument, the homeserver.
+ """
+ _sighup_callbacks.append(func)
+
def start_worker_reactor(appname, config):
""" Run the reactor in the main process
@@ -143,6 +160,9 @@ def listen_metrics(bind_addresses, port):
def listen_tcp(bind_addresses, port, factory, reactor=reactor, backlog=50):
"""
Create a TCP socket for a port and several addresses
+
+ Returns:
+ list (empty)
"""
for address in bind_addresses:
try:
@@ -155,42 +175,80 @@ def listen_tcp(bind_addresses, port, factory, reactor=reactor, backlog=50):
except error.CannotListenError as e:
check_bind_error(e, address, bind_addresses)
+ logger.info("Synapse now listening on TCP port %d", port)
+ return []
+
def listen_ssl(
bind_addresses, port, factory, context_factory, reactor=reactor, backlog=50
):
"""
- Create an SSL socket for a port and several addresses
+    Create a TLS-over-TCP socket for a port and several addresses
+
+ Returns:
+ list of twisted.internet.tcp.Port listening for TLS connections
"""
+ r = []
for address in bind_addresses:
try:
- reactor.listenSSL(
- port,
- factory,
- context_factory,
- backlog,
- address
+ r.append(
+ reactor.listenSSL(
+ port,
+ factory,
+ context_factory,
+ backlog,
+ address
+ )
)
except error.CannotListenError as e:
check_bind_error(e, address, bind_addresses)
+ logger.info("Synapse now listening on port %d (TLS)", port)
+ return r
+
-def check_bind_error(e, address, bind_addresses):
+def refresh_certificate(hs):
"""
- This method checks an exception occurred while binding on 0.0.0.0.
- If :: is specified in the bind addresses a warning is shown.
- The exception is still raised otherwise.
+ Refresh the TLS certificates that Synapse is using by re-reading them from
+ disk and updating the TLS context factories to use them.
+ """
+ logging.info("Loading certificate from disk...")
+ hs.config.read_certificate_from_disk()
+ hs.tls_server_context_factory = context_factory.ServerContextFactory(hs.config)
+ hs.tls_client_options_factory = context_factory.ClientTLSOptionsFactory(
+ hs.config
+ )
+ logging.info("Certificate loaded.")
- Binding on both 0.0.0.0 and :: causes an exception on Linux and macOS
- because :: binds on both IPv4 and IPv6 (as per RFC 3493).
- When binding on 0.0.0.0 after :: this can safely be ignored.
+
+def start(hs, listeners=None):
+ """
+ Start a Synapse server or worker.
Args:
- e (Exception): Exception that was caught.
- address (str): Address on which binding was attempted.
- bind_addresses (list): Addresses on which the service listens.
+ hs (synapse.server.HomeServer)
+ listeners (list[dict]): Listener configuration ('listeners' in homeserver.yaml)
"""
- if address == '0.0.0.0' and '::' in bind_addresses:
- logger.warn('Failed to listen on 0.0.0.0, continuing because listening on [::]')
- else:
- raise e
+ try:
+ # Set up the SIGHUP machinery.
+ if hasattr(signal, "SIGHUP"):
+ def handle_sighup(*args, **kwargs):
+ for i in _sighup_callbacks:
+ i(hs)
+
+ signal.signal(signal.SIGHUP, handle_sighup)
+
+ register_sighup(refresh_certificate)
+
+ # Load the certificate from disk.
+ refresh_certificate(hs)
+
+ # It is now safe to start your Synapse.
+ hs.start_listening(listeners)
+ hs.get_datastore().start_profiling()
+ except Exception:
+ traceback.print_exc(file=sys.stderr)
+ reactor = hs.get_reactor()
+ if reactor.running:
+ reactor.stop()
+ sys.exit(1)
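
A sketch of how other code can hook the new SIGHUP machinery; reload_ratelimits is a hypothetical action, and the callback receives the homeserver instance:

    import logging

    from synapse.app._base import register_sighup

    logger = logging.getLogger(__name__)

    def reload_ratelimits(hs):
        # Hypothetical reload action, purely for illustration.
        logger.info("Reloading rate limits for %s on SIGHUP", hs.hostname)

    register_sighup(reload_ratelimits)
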
diff --git a/synapse/app/appservice.py b/synapse/app/appservice.py
index 8559e141af..33107f56d1 100644
--- a/synapse/app/appservice.py
+++ b/synapse/app/appservice.py
@@ -168,12 +168,7 @@ def start(config_options):
)
ps.setup()
- ps.start_listening(config.worker_listeners)
-
- def start():
- ps.get_datastore().start_profiling()
-
- reactor.callWhenRunning(start)
+ reactor.callWhenRunning(_base.start, ps, config.worker_listeners)
_base.start_worker_reactor("synapse-appservice", config)
diff --git a/synapse/app/client_reader.py b/synapse/app/client_reader.py
index 76aed8c60a..a9d2147022 100644
--- a/synapse/app/client_reader.py
+++ b/synapse/app/client_reader.py
@@ -25,7 +25,6 @@ from synapse.app import _base
from synapse.config._base import ConfigError
from synapse.config.homeserver import HomeServerConfig
from synapse.config.logger import setup_logging
-from synapse.crypto import context_factory
from synapse.http.server import JsonResource
from synapse.http.site import SynapseSite
from synapse.metrics import RegistryProxy
@@ -164,26 +163,16 @@ def start(config_options):
database_engine = create_engine(config.database_config)
- tls_server_context_factory = context_factory.ServerContextFactory(config)
- tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config)
-
ss = ClientReaderServer(
config.server_name,
db_config=config.database_config,
- tls_server_context_factory=tls_server_context_factory,
- tls_client_options_factory=tls_client_options_factory,
config=config,
version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine,
)
ss.setup()
- ss.start_listening(config.worker_listeners)
-
- def start():
- ss.get_datastore().start_profiling()
-
- reactor.callWhenRunning(start)
+ reactor.callWhenRunning(_base.start, ss, config.worker_listeners)
_base.start_worker_reactor("synapse-client-reader", config)
diff --git a/synapse/app/event_creator.py b/synapse/app/event_creator.py
index e4a68715aa..b8e5196152 100644
--- a/synapse/app/event_creator.py
+++ b/synapse/app/event_creator.py
@@ -25,7 +25,6 @@ from synapse.app import _base
from synapse.config._base import ConfigError
from synapse.config.homeserver import HomeServerConfig
from synapse.config.logger import setup_logging
-from synapse.crypto import context_factory
from synapse.http.server import JsonResource
from synapse.http.site import SynapseSite
from synapse.metrics import RegistryProxy
@@ -185,26 +184,16 @@ def start(config_options):
database_engine = create_engine(config.database_config)
- tls_server_context_factory = context_factory.ServerContextFactory(config)
- tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config)
-
ss = EventCreatorServer(
config.server_name,
db_config=config.database_config,
- tls_server_context_factory=tls_server_context_factory,
- tls_client_options_factory=tls_client_options_factory,
config=config,
version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine,
)
ss.setup()
- ss.start_listening(config.worker_listeners)
-
- def start():
- ss.get_datastore().start_profiling()
-
- reactor.callWhenRunning(start)
+ reactor.callWhenRunning(_base.start, ss, config.worker_listeners)
_base.start_worker_reactor("synapse-event-creator", config)
diff --git a/synapse/app/federation_reader.py b/synapse/app/federation_reader.py
index 2c99ce8c64..6ee2b76dcd 100644
--- a/synapse/app/federation_reader.py
+++ b/synapse/app/federation_reader.py
@@ -26,7 +26,6 @@ from synapse.app import _base
from synapse.config._base import ConfigError
from synapse.config.homeserver import HomeServerConfig
from synapse.config.logger import setup_logging
-from synapse.crypto import context_factory
from synapse.federation.transport.server import TransportLayerServer
from synapse.http.site import SynapseSite
from synapse.metrics import RegistryProxy
@@ -162,26 +161,16 @@ def start(config_options):
database_engine = create_engine(config.database_config)
- tls_server_context_factory = context_factory.ServerContextFactory(config)
- tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config)
-
ss = FederationReaderServer(
config.server_name,
db_config=config.database_config,
- tls_server_context_factory=tls_server_context_factory,
- tls_client_options_factory=tls_client_options_factory,
config=config,
version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine,
)
ss.setup()
- ss.start_listening(config.worker_listeners)
-
- def start():
- ss.get_datastore().start_profiling()
-
- reactor.callWhenRunning(start)
+ reactor.callWhenRunning(_base.start, ss, config.worker_listeners)
_base.start_worker_reactor("synapse-federation-reader", config)
diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py
index e9a99d76e1..a461442fdc 100644
--- a/synapse/app/federation_sender.py
+++ b/synapse/app/federation_sender.py
@@ -25,7 +25,6 @@ from synapse.app import _base
from synapse.config._base import ConfigError
from synapse.config.homeserver import HomeServerConfig
from synapse.config.logger import setup_logging
-from synapse.crypto import context_factory
from synapse.federation import send_queue
from synapse.http.site import SynapseSite
from synapse.metrics import RegistryProxy
@@ -183,26 +182,17 @@ def start(config_options):
    # Force the federation sender to start since it will be disabled in the main config
config.send_federation = True
- tls_server_context_factory = context_factory.ServerContextFactory(config)
- tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config)
-
- ps = FederationSenderServer(
+ ss = FederationSenderServer(
config.server_name,
db_config=config.database_config,
- tls_server_context_factory=tls_server_context_factory,
- tls_client_options_factory=tls_client_options_factory,
config=config,
version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine,
)
- ps.setup()
- ps.start_listening(config.worker_listeners)
-
- def start():
- ps.get_datastore().start_profiling()
+ ss.setup()
+ reactor.callWhenRunning(_base.start, ss, config.worker_listeners)
- reactor.callWhenRunning(start)
_base.start_worker_reactor("synapse-federation-sender", config)
diff --git a/synapse/app/frontend_proxy.py b/synapse/app/frontend_proxy.py
index f5c61dec5b..d5b954361d 100644
--- a/synapse/app/frontend_proxy.py
+++ b/synapse/app/frontend_proxy.py
@@ -26,7 +26,6 @@ from synapse.app import _base
from synapse.config._base import ConfigError
from synapse.config.homeserver import HomeServerConfig
from synapse.config.logger import setup_logging
-from synapse.crypto import context_factory
from synapse.http.server import JsonResource
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.http.site import SynapseSite
@@ -241,26 +240,16 @@ def start(config_options):
database_engine = create_engine(config.database_config)
- tls_server_context_factory = context_factory.ServerContextFactory(config)
- tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config)
-
ss = FrontendProxyServer(
config.server_name,
db_config=config.database_config,
- tls_server_context_factory=tls_server_context_factory,
- tls_client_options_factory=tls_client_options_factory,
config=config,
version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine,
)
ss.setup()
- ss.start_listening(config.worker_listeners)
-
- def start():
- ss.get_datastore().start_profiling()
-
- reactor.callWhenRunning(start)
+ reactor.callWhenRunning(_base.start, ss, config.worker_listeners)
_base.start_worker_reactor("synapse-frontend-proxy", config)
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 0a924d7a80..d1cab07bb6 100755
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -13,10 +13,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
import gc
import logging
import os
import sys
+import traceback
from six import iteritems
@@ -45,7 +47,6 @@ from synapse.app import _base
from synapse.app._base import listen_ssl, listen_tcp, quit_with_error
from synapse.config._base import ConfigError
from synapse.config.homeserver import HomeServerConfig
-from synapse.crypto import context_factory
from synapse.federation.transport.server import TransportLayerServer
from synapse.http.additional_resource import AdditionalResource
from synapse.http.server import RootRedirect
@@ -82,6 +83,7 @@ def gz_wrap(r):
class SynapseHomeServer(HomeServer):
DATASTORE_CLASS = DataStore
+ _listening_services = []
def _listener_http(self, config, listener_config):
port = listener_config["port"]
@@ -90,7 +92,9 @@ class SynapseHomeServer(HomeServer):
site_tag = listener_config.get("tag", port)
if tls and config.no_tls:
- return
+ raise ConfigError(
+ "Listener on port %i has TLS enabled, but no_tls is set" % (port,),
+ )
resources = {}
for res in listener_config["resources"]:
@@ -123,7 +127,7 @@ class SynapseHomeServer(HomeServer):
root_resource = create_resource_tree(resources, root_resource)
if tls:
- listen_ssl(
+ return listen_ssl(
bind_addresses,
port,
SynapseSite(
@@ -138,7 +142,7 @@ class SynapseHomeServer(HomeServer):
)
else:
- listen_tcp(
+ return listen_tcp(
bind_addresses,
port,
SynapseSite(
@@ -150,7 +154,6 @@ class SynapseHomeServer(HomeServer):
),
reactor=self.get_reactor(),
)
- logger.info("Synapse now listening on port %d", port)
def _configure_named_resource(self, name, compress=False):
"""Build a resource map for a named resource
@@ -246,12 +249,14 @@ class SynapseHomeServer(HomeServer):
return resources
- def start_listening(self):
+ def start_listening(self, listeners):
config = self.get_config()
- for listener in config.listeners:
+ for listener in listeners:
if listener["type"] == "http":
- self._listener_http(config, listener)
+ self._listening_services.extend(
+ self._listener_http(config, listener)
+ )
elif listener["type"] == "manhole":
listen_tcp(
listener["bind_addresses"],
@@ -331,21 +336,19 @@ def setup(config_options):
# generating config files and shouldn't try to continue.
sys.exit(0)
- synapse.config.logger.setup_logging(config, use_worker_options=False)
+ synapse.config.logger.setup_logging(
+ config,
+ use_worker_options=False
+ )
events.USE_FROZEN_DICTS = config.use_frozen_dicts
- tls_server_context_factory = context_factory.ServerContextFactory(config)
- tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config)
-
database_engine = create_engine(config.database_config)
config.database_config["args"]["cp_openfun"] = database_engine.on_new_connection
hs = SynapseHomeServer(
config.server_name,
db_config=config.database_config,
- tls_server_context_factory=tls_server_context_factory,
- tls_client_options_factory=tls_client_options_factory,
config=config,
version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine,
@@ -372,12 +375,44 @@ def setup(config_options):
logger.info("Database prepared in %s.", config.database_config['name'])
hs.setup()
- hs.start_listening()
+ @defer.inlineCallbacks
def start():
- hs.get_pusherpool().start()
- hs.get_datastore().start_profiling()
- hs.get_datastore().start_doing_background_updates()
+ try:
+ # Check if the certificate is still valid.
+ cert_days_remaining = hs.config.is_disk_cert_valid()
+
+ if hs.config.acme_enabled:
+ # If ACME is enabled, we might need to provision a certificate
+ # before starting.
+ acme = hs.get_acme_handler()
+
+ # Start up the webservices which we will respond to ACME
+ # challenges with.
+ yield acme.start_listening()
+
+                # We want to reprovision if cert_days_remaining is None (meaning no
+                # certificate exists), or if the days remaining are at or below
+                # our reprovision threshold.
+                if (
+                    cert_days_remaining is None
+                    or cert_days_remaining <= hs.config.acme_reprovision_threshold
+                ):
+ yield acme.provision_certificate()
+
+ _base.start(hs, config.listeners)
+
+ hs.get_pusherpool().start()
+ hs.get_datastore().start_doing_background_updates()
+ except Exception as e:
+ # If a DeferredList failed (like in listening on the ACME listener),
+ # we need to print the subfailure explicitly.
+ if isinstance(e, defer.FirstError):
+ e.subFailure.printTraceback(sys.stderr)
+ sys.exit(1)
+
+ # Something else went wrong when starting. Print it and bail out.
+ traceback.print_exc(file=sys.stderr)
+ sys.exit(1)
reactor.callWhenRunning(start)
diff --git a/synapse/app/media_repository.py b/synapse/app/media_repository.py
index acc0487adc..d4cc4e9443 100644
--- a/synapse/app/media_repository.py
+++ b/synapse/app/media_repository.py
@@ -26,7 +26,6 @@ from synapse.app import _base
from synapse.config._base import ConfigError
from synapse.config.homeserver import HomeServerConfig
from synapse.config.logger import setup_logging
-from synapse.crypto import context_factory
from synapse.http.site import SynapseSite
from synapse.metrics import RegistryProxy
from synapse.metrics.resource import METRICS_PREFIX, MetricsResource
@@ -151,26 +150,16 @@ def start(config_options):
database_engine = create_engine(config.database_config)
- tls_server_context_factory = context_factory.ServerContextFactory(config)
- tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config)
-
ss = MediaRepositoryServer(
config.server_name,
db_config=config.database_config,
- tls_server_context_factory=tls_server_context_factory,
- tls_client_options_factory=tls_client_options_factory,
config=config,
version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine,
)
ss.setup()
- ss.start_listening(config.worker_listeners)
-
- def start():
- ss.get_datastore().start_profiling()
-
- reactor.callWhenRunning(start)
+ reactor.callWhenRunning(_base.start, ss, config.worker_listeners)
_base.start_worker_reactor("synapse-media-repository", config)
diff --git a/synapse/app/pusher.py b/synapse/app/pusher.py
index 83b0863f00..cbf0d67f51 100644
--- a/synapse/app/pusher.py
+++ b/synapse/app/pusher.py
@@ -224,11 +224,10 @@ def start(config_options):
)
ps.setup()
- ps.start_listening(config.worker_listeners)
def start():
+ _base.start(ps, config.worker_listeners)
ps.get_pusherpool().start()
- ps.get_datastore().start_profiling()
reactor.callWhenRunning(start)
diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py
index 0354e82bf8..9163b56d86 100644
--- a/synapse/app/synchrotron.py
+++ b/synapse/app/synchrotron.py
@@ -445,12 +445,7 @@ def start(config_options):
)
ss.setup()
- ss.start_listening(config.worker_listeners)
-
- def start():
- ss.get_datastore().start_profiling()
-
- reactor.callWhenRunning(start)
+ reactor.callWhenRunning(_base.start, ss, config.worker_listeners)
_base.start_worker_reactor("synapse-synchrotron", config)
diff --git a/synapse/app/user_dir.py b/synapse/app/user_dir.py
index 0a5f62b509..d1ab9512cd 100644
--- a/synapse/app/user_dir.py
+++ b/synapse/app/user_dir.py
@@ -26,7 +26,6 @@ from synapse.app import _base
from synapse.config._base import ConfigError
from synapse.config.homeserver import HomeServerConfig
from synapse.config.logger import setup_logging
-from synapse.crypto import context_factory
from synapse.http.server import JsonResource
from synapse.http.site import SynapseSite
from synapse.metrics import RegistryProxy
@@ -211,26 +210,16 @@ def start(config_options):
    # Force the user directory updater to start since it will be disabled in the main config
config.update_user_directory = True
- tls_server_context_factory = context_factory.ServerContextFactory(config)
- tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config)
-
- ps = UserDirectoryServer(
+ ss = UserDirectoryServer(
config.server_name,
db_config=config.database_config,
- tls_server_context_factory=tls_server_context_factory,
- tls_client_options_factory=tls_client_options_factory,
config=config,
version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine,
)
- ps.setup()
- ps.start_listening(config.worker_listeners)
-
- def start():
- ps.get_datastore().start_profiling()
-
- reactor.callWhenRunning(start)
+ ss.setup()
+ reactor.callWhenRunning(_base.start, ss, config.worker_listeners)
_base.start_worker_reactor("synapse-user-dir", config)
diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index fd2d6d52ef..5858fb92b4 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -367,7 +367,7 @@ class Config(object):
if not keys_directory:
keys_directory = os.path.dirname(config_files[-1])
- config_dir_path = os.path.abspath(keys_directory)
+ self.config_dir_path = os.path.abspath(keys_directory)
specified_config = {}
for config_file in config_files:
@@ -379,7 +379,7 @@ class Config(object):
server_name = specified_config["server_name"]
config_string = self.generate_config(
- config_dir_path=config_dir_path,
+ config_dir_path=self.config_dir_path,
data_dir_path=os.getcwd(),
server_name=server_name,
generate_secrets=False,
diff --git a/synapse/config/api.py b/synapse/config/api.py
index 403d96ba76..9f25bbc5cb 100644
--- a/synapse/config/api.py
+++ b/synapse/config/api.py
@@ -24,6 +24,7 @@ class ApiConfig(Config):
EventTypes.JoinRules,
EventTypes.CanonicalAlias,
EventTypes.RoomAvatar,
+ EventTypes.RoomEncryption,
EventTypes.Name,
])
@@ -36,5 +37,6 @@ class ApiConfig(Config):
- "{JoinRules}"
- "{CanonicalAlias}"
- "{RoomAvatar}"
+ - "{RoomEncryption}"
- "{Name}"
""".format(**vars(EventTypes))
diff --git a/synapse/config/consent_config.py b/synapse/config/consent_config.py
index f193a090ae..9f2e85342f 100644
--- a/synapse/config/consent_config.py
+++ b/synapse/config/consent_config.py
@@ -13,6 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from os import path
+
+from synapse.config import ConfigError
+
from ._base import Config
DEFAULT_CONFIG = """\
@@ -85,7 +89,15 @@ class ConsentConfig(Config):
if consent_config is None:
return
self.user_consent_version = str(consent_config["version"])
- self.user_consent_template_dir = consent_config["template_dir"]
+ self.user_consent_template_dir = self.abspath(
+ consent_config["template_dir"]
+ )
+ if not path.isdir(self.user_consent_template_dir):
+ raise ConfigError(
+ "Could not find template directory '%s'" % (
+ self.user_consent_template_dir,
+ ),
+ )
self.user_consent_server_notice_content = consent_config.get(
"server_notice_content",
)
diff --git a/synapse/config/logger.py b/synapse/config/logger.py
index f87efecbf8..4b938053fb 100644
--- a/synapse/config/logger.py
+++ b/synapse/config/logger.py
@@ -15,7 +15,6 @@
import logging
import logging.config
import os
-import signal
import sys
from string import Template
@@ -24,6 +23,7 @@ import yaml
from twisted.logger import STDLibLogObserver, globalLogBeginner
import synapse
+from synapse.app import _base as appbase
from synapse.util.logcontext import LoggingContextFilter
from synapse.util.versionstring import get_version_string
@@ -136,6 +136,9 @@ def setup_logging(config, use_worker_options=False):
use_worker_options (bool): True to use 'worker_log_config' and
'worker_log_file' options instead of 'log_config' and 'log_file'.
"""
log_config = (config.worker_log_config if use_worker_options
else config.log_config)
@@ -178,7 +181,7 @@ def setup_logging(config, use_worker_options=False):
else:
handler = logging.StreamHandler()
- def sighup(signum, stack):
+ def sighup(*args):
pass
handler.setFormatter(formatter)
@@ -191,20 +194,14 @@ def setup_logging(config, use_worker_options=False):
with open(log_config, 'r') as f:
logging.config.dictConfig(yaml.load(f))
- def sighup(signum, stack):
+ def sighup(*args):
# it might be better to use a file watcher or something for this.
load_log_config()
logging.info("Reloaded log config from %s due to SIGHUP", log_config)
load_log_config()
- # TODO(paul): obviously this is a terrible mechanism for
- # stealing SIGHUP, because it means no other part of synapse
- # can use it instead. If we want to catch SIGHUP anywhere
- # else as well, I'd suggest we find a nicer way to broadcast
- # it around.
- if getattr(signal, "SIGHUP"):
- signal.signal(signal.SIGHUP, sighup)
+ appbase.register_sighup(sighup)
# make sure that the first thing we log is a thing we can grep backwards
# for
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index fe520d6855..d808a989f3 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -84,11 +84,11 @@ class RegistrationConfig(Config):
#
# allowed_local_3pids:
# - medium: email
- # pattern: ".*@matrix\\.org"
+ # pattern: '.*@matrix\\.org'
# - medium: email
- # pattern: ".*@vector\\.im"
+ # pattern: '.*@vector\\.im'
# - medium: msisdn
- # pattern: "\\+44"
+ # pattern: '\\+44'
# If set, allows registration by anyone who also has the shared
# secret, even if registration is otherwise disabled.
diff --git a/synapse/config/server.py b/synapse/config/server.py
index 4eefd06f4a..f0a60cc712 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -256,8 +256,12 @@ class ServerConfig(Config):
#
# web_client_location: "/path/to/web/root"
- # The public-facing base URL for the client API (not including _matrix/...)
- # public_baseurl: https://example.com:8448/
+ # The public-facing base URL that clients use to access this HS
+ # (not including _matrix/...). This is the same URL a user would
+ # enter into the 'custom HS URL' field on their client. If you
+ # use synapse with a reverse proxy, this should be the URL to reach
+ # synapse via the proxy.
+ # public_baseurl: https://example.com/
# Set the soft limit on the number of file descriptors synapse can use
# Zero is used to indicate synapse should set the soft limit to the
@@ -429,19 +433,18 @@ class ServerConfig(Config):
" service on the given port.")
-def is_threepid_reserved(config, threepid):
+def is_threepid_reserved(reserved_threepids, threepid):
"""Check the threepid against the reserved threepid config
Args:
- config(ServerConfig) - to access server config attributes
+ reserved_threepids([dict]) - list of reserved threepids
threepid(dict) - The threepid to test for
Returns:
        boolean: whether the threepid under test is reserved
"""
- for tp in config.mau_limits_reserved_threepids:
- if (threepid['medium'] == tp['medium']
- and threepid['address'] == tp['address']):
+ for tp in reserved_threepids:
+ if (threepid['medium'] == tp['medium'] and threepid['address'] == tp['address']):
return True
return False
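
The new signature takes the reserved list directly, which makes the helper testable without constructing a full ServerConfig. A quick sketch:

    from synapse.config.server import is_threepid_reserved

    reserved = [{"medium": "email", "address": "admin@example.com"}]

    assert is_threepid_reserved(
        reserved, {"medium": "email", "address": "admin@example.com"}
    )
    assert not is_threepid_reserved(
        reserved, {"medium": "email", "address": "other@example.com"}
    )
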
diff --git a/synapse/config/tls.py b/synapse/config/tls.py
index bb8952c672..b5f2cfd9b7 100644
--- a/synapse/config/tls.py
+++ b/synapse/config/tls.py
@@ -13,72 +13,196 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
import os
+import warnings
+from datetime import datetime
from hashlib import sha256
from unpaddedbase64 import encode_base64
from OpenSSL import crypto
-from ._base import Config
+from synapse.config._base import Config
+
+logger = logging.getLogger()
class TlsConfig(Config):
def read_config(self, config):
- self.tls_certificate = self.read_tls_certificate(
- config.get("tls_certificate_path")
- )
- self.tls_certificate_file = config.get("tls_certificate_path")
+ acme_config = config.get("acme", None)
+ if acme_config is None:
+ acme_config = {}
+
+ self.acme_enabled = acme_config.get("enabled", False)
+ self.acme_url = acme_config.get(
+ "url", u"https://acme-v01.api.letsencrypt.org/directory"
+ )
+ self.acme_port = acme_config.get("port", 80)
+ self.acme_bind_addresses = acme_config.get("bind_addresses", ['::', '0.0.0.0'])
+ self.acme_reprovision_threshold = acme_config.get("reprovision_threshold", 30)
+
+ self.tls_certificate_file = self.abspath(config.get("tls_certificate_path"))
+ self.tls_private_key_file = self.abspath(config.get("tls_private_key_path"))
+ self._original_tls_fingerprints = config["tls_fingerprints"]
+ self.tls_fingerprints = list(self._original_tls_fingerprints)
self.no_tls = config.get("no_tls", False)
- if self.no_tls:
- self.tls_private_key = None
- else:
- self.tls_private_key = self.read_tls_private_key(
- config.get("tls_private_key_path")
+ # This config option applies to non-federation HTTP clients
+ # (e.g. for talking to recaptcha, identity servers, and such)
+ # It should never be used in production, and is intended for
+ # use only when running tests.
+ self.use_insecure_ssl_client_just_for_testing_do_not_use = config.get(
+ "use_insecure_ssl_client_just_for_testing_do_not_use"
+ )
+
+ self.tls_certificate = None
+ self.tls_private_key = None
+
+ def is_disk_cert_valid(self):
+ """
+ Is the certificate we have on disk valid, and if so, for how long?
+
+ Returns:
+ int: Days remaining of certificate validity.
+ None: No certificate exists.
+ """
+ if not os.path.exists(self.tls_certificate_file):
+ return None
+
+ try:
+ with open(self.tls_certificate_file, 'rb') as f:
+ cert_pem = f.read()
+ except Exception:
+ logger.exception("Failed to read existing certificate off disk!")
+ raise
+
+ try:
+ tls_certificate = crypto.load_certificate(crypto.FILETYPE_PEM, cert_pem)
+ except Exception:
+ logger.exception("Failed to parse existing certificate off disk!")
+ raise
+
+ # YYYYMMDDhhmmssZ -- in UTC
+ expires_on = datetime.strptime(
+ tls_certificate.get_notAfter().decode('ascii'), "%Y%m%d%H%M%SZ"
+ )
+ now = datetime.utcnow()
+ days_remaining = (expires_on - now).days
+ return days_remaining
+
+ def read_certificate_from_disk(self):
+ """
+ Read the certificates from disk.
+ """
+ self.tls_certificate = self.read_tls_certificate(self.tls_certificate_file)
+
+ # Check if it is self-signed, and issue a warning if so.
+ if self.tls_certificate.get_issuer() == self.tls_certificate.get_subject():
+ warnings.warn(
+ (
+ "Self-signed TLS certificates will not be accepted by Synapse 1.0. "
+ "Please either provide a valid certificate, or use Synapse's ACME "
+ "support to provision one."
+ )
)
- self.tls_fingerprints = config["tls_fingerprints"]
+ if not self.no_tls:
+ self.tls_private_key = self.read_tls_private_key(self.tls_private_key_file)
+
+ self.tls_fingerprints = list(self._original_tls_fingerprints)
# Check that our own certificate is included in the list of fingerprints
# and include it if it is not.
x509_certificate_bytes = crypto.dump_certificate(
- crypto.FILETYPE_ASN1,
- self.tls_certificate
+ crypto.FILETYPE_ASN1, self.tls_certificate
)
sha256_fingerprint = encode_base64(sha256(x509_certificate_bytes).digest())
sha256_fingerprints = set(f["sha256"] for f in self.tls_fingerprints)
if sha256_fingerprint not in sha256_fingerprints:
self.tls_fingerprints.append({u"sha256": sha256_fingerprint})
- # This config option applies to non-federation HTTP clients
- # (e.g. for talking to recaptcha, identity servers, and such)
- # It should never be used in production, and is intended for
- # use only when running tests.
- self.use_insecure_ssl_client_just_for_testing_do_not_use = config.get(
- "use_insecure_ssl_client_just_for_testing_do_not_use"
- )
-
def default_config(self, config_dir_path, server_name, **kwargs):
base_key_name = os.path.join(config_dir_path, server_name)
tls_certificate_path = base_key_name + ".tls.crt"
tls_private_key_path = base_key_name + ".tls.key"
- return """\
- # PEM encoded X509 certificate for TLS.
- # You can replace the self-signed certificate that synapse
- # autogenerates on launch with your own SSL certificate + key pair
- # if you like. Any required intermediary certificates can be
- # appended after the primary certificate in hierarchical order.
+ # this is to avoid the max line length. Sorrynotsorry
+ proxypassline = (
+ 'ProxyPass /.well-known/acme-challenge '
+ 'http://localhost:8009/.well-known/acme-challenge'
+ )
+
+ return (
+ """\
+ # PEM-encoded X509 certificate for TLS.
+ # This certificate, as of Synapse 1.0, will need to be a valid and verifiable
+ # certificate, signed by a recognised Certificate Authority.
+ #
+ # See 'ACME support' below to enable auto-provisioning this certificate via
+ # Let's Encrypt.
+ #
tls_certificate_path: "%(tls_certificate_path)s"
- # PEM encoded private key for TLS
+ # PEM-encoded private key for TLS
tls_private_key_path: "%(tls_private_key_path)s"
- # Don't bind to the https port
- no_tls: False
+ # ACME support: This will configure Synapse to request a valid TLS certificate
+ # for your configured `server_name` via Let's Encrypt.
+ #
+ # Note that provisioning a certificate in this way requires port 80 to be
+ # routed to Synapse so that it can complete the http-01 ACME challenge.
+ # By default, if you enable ACME support, Synapse will attempt to listen on
+ # port 80 for incoming http-01 challenges - however, this will likely fail
+ # with 'Permission denied' or a similar error.
+ #
+ # There are a couple of potential solutions to this:
+ #
+ # * If you already have an Apache, Nginx, or similar listening on port 80,
+ # you can configure Synapse to use an alternate port, and have your web
+ # server forward the requests. For example, assuming you set 'port: 8009'
+ # below, on Apache, you would write:
+ #
+ # %(proxypassline)s
+ #
+ # * Alternatively, you can use something like `authbind` to give Synapse
+ # permission to listen on port 80.
+ #
+ acme:
+ # ACME support is disabled by default. Uncomment the following line
+ # to enable it.
+ #
+ # enabled: true
+
+ # Endpoint to use to request certificates. If you only want to test,
+ # use Let's Encrypt's staging url:
+ # https://acme-staging.api.letsencrypt.org/directory
+ #
+ # url: https://acme-v01.api.letsencrypt.org/directory
+
+ # Port number to listen on for the HTTP-01 challenge. Change this if
+ # you are forwarding connections through Apache/Nginx/etc.
+ #
+ # port: 80
+
+ # Local addresses to listen on for incoming connections.
+ # Again, you may want to change this if you are forwarding connections
+ # through Apache/Nginx/etc.
+ #
+ # bind_addresses: ['::', '0.0.0.0']
+
+ # How many days remaining on a certificate before it is renewed.
+ #
+ # reprovision_threshold: 30
+
+ # If your server runs behind a reverse-proxy which terminates TLS connections
+ # (for both client and federation connections), it may be useful to disable
+    # all TLS support for incoming connections. Setting no_tls to True will
+ # do so (and avoid the need to give synapse a TLS private key).
+ #
+ # no_tls: True
# List of allowed TLS fingerprints for this server to publish along
# with the signing keys for this server. Other matrix servers that
@@ -107,7 +231,10 @@ class TlsConfig(Config):
#
tls_fingerprints: []
# tls_fingerprints: [{"sha256": "<base64_encoded_sha256_fingerprint>"}]
- """ % locals()
+
+ """
+ % locals()
+ )
def read_tls_certificate(self, cert_path):
cert_pem = self.read_file(cert_path, "tls_certificate")
@@ -116,40 +243,3 @@ class TlsConfig(Config):
def read_tls_private_key(self, private_key_path):
private_key_pem = self.read_file(private_key_path, "tls_private_key")
return crypto.load_privatekey(crypto.FILETYPE_PEM, private_key_pem)
-
- def generate_files(self, config):
- tls_certificate_path = config["tls_certificate_path"]
- tls_private_key_path = config["tls_private_key_path"]
-
- if not self.path_exists(tls_private_key_path):
- with open(tls_private_key_path, "wb") as private_key_file:
- tls_private_key = crypto.PKey()
- tls_private_key.generate_key(crypto.TYPE_RSA, 2048)
- private_key_pem = crypto.dump_privatekey(
- crypto.FILETYPE_PEM, tls_private_key
- )
- private_key_file.write(private_key_pem)
- else:
- with open(tls_private_key_path) as private_key_file:
- private_key_pem = private_key_file.read()
- tls_private_key = crypto.load_privatekey(
- crypto.FILETYPE_PEM, private_key_pem
- )
-
- if not self.path_exists(tls_certificate_path):
- with open(tls_certificate_path, "wb") as certificate_file:
- cert = crypto.X509()
- subject = cert.get_subject()
- subject.CN = config["server_name"]
-
- cert.set_serial_number(1000)
- cert.gmtime_adj_notBefore(0)
- cert.gmtime_adj_notAfter(10 * 365 * 24 * 60 * 60)
- cert.set_issuer(cert.get_subject())
- cert.set_pubkey(tls_private_key)
-
- cert.sign(tls_private_key, 'sha256')
-
- cert_pem = crypto.dump_certificate(crypto.FILETYPE_PEM, cert)
-
- certificate_file.write(cert_pem)
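
How is_disk_cert_valid() feeds the ACME decision, sketched to match the logic in synapse/app/homeserver.py above (hs is an assumed homeserver instance):

    # Reprovision when no certificate exists yet, or when its remaining
    # validity is at or below the configured threshold.
    cert_days_remaining = hs.config.is_disk_cert_valid()
    needs_reprovision = (
        cert_days_remaining is None
        or cert_days_remaining <= hs.config.acme_reprovision_threshold
    )
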
diff --git a/synapse/crypto/context_factory.py b/synapse/crypto/context_factory.py
index 6ba3eca7b2..286ad80100 100644
--- a/synapse/crypto/context_factory.py
+++ b/synapse/crypto/context_factory.py
@@ -17,6 +17,7 @@ from zope.interface import implementer
from OpenSSL import SSL, crypto
from twisted.internet._sslverify import _defaultCurveName
+from twisted.internet.abstract import isIPAddress, isIPv6Address
from twisted.internet.interfaces import IOpenSSLClientConnectionCreator
from twisted.internet.ssl import CertificateOptions, ContextFactory
from twisted.python.failure import Failure
@@ -98,8 +99,14 @@ class ClientTLSOptions(object):
def __init__(self, hostname, ctx):
self._ctx = ctx
- self._hostname = hostname
- self._hostnameBytes = _idnaBytes(hostname)
+
+ if isIPAddress(hostname) or isIPv6Address(hostname):
+ self._hostnameBytes = hostname.encode('ascii')
+ self._sendSNI = False
+ else:
+ self._hostnameBytes = _idnaBytes(hostname)
+ self._sendSNI = True
+
ctx.set_info_callback(
_tolerateErrors(self._identityVerifyingInfoCallback)
)
@@ -111,7 +118,9 @@ class ClientTLSOptions(object):
return connection
def _identityVerifyingInfoCallback(self, connection, where, ret):
- if where & SSL.SSL_CB_HANDSHAKE_START:
+ # Literal IPv4 and IPv6 addresses are not permitted
+ # as host names according to the RFCs
+ if where & SSL.SSL_CB_HANDSHAKE_START and self._sendSNI:
connection.set_tlsext_host_name(self._hostnameBytes)
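
The SNI rule in isolation: RFC 6066 disallows IP literals in the server_name extension, which is what the two checks above enforce. A quick sketch:

    from twisted.internet.abstract import isIPAddress, isIPv6Address

    for host in ("matrix.org", "192.168.1.1", "2001:db8::1"):
        send_sni = not (isIPAddress(host) or isIPv6Address(host))
        print("%s -> %s" % (host, "send SNI" if send_sni else "no SNI"))
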
diff --git a/synapse/crypto/event_signing.py b/synapse/crypto/event_signing.py
index 8774b28967..1dfa727fcf 100644
--- a/synapse/crypto/event_signing.py
+++ b/synapse/crypto/event_signing.py
@@ -23,14 +23,14 @@ from signedjson.sign import sign_json
from unpaddedbase64 import decode_base64, encode_base64
from synapse.api.errors import Codes, SynapseError
-from synapse.events.utils import prune_event
+from synapse.events.utils import prune_event, prune_event_dict
logger = logging.getLogger(__name__)
def check_event_content_hash(event, hash_algorithm=hashlib.sha256):
"""Check whether the hash for this PDU matches the contents"""
- name, expected_hash = compute_content_hash(event, hash_algorithm)
+ name, expected_hash = compute_content_hash(event.get_pdu_json(), hash_algorithm)
logger.debug("Expecting hash: %s", encode_base64(expected_hash))
# some malformed events lack a 'hashes'. Protect against it being missing
@@ -59,35 +59,70 @@ def check_event_content_hash(event, hash_algorithm=hashlib.sha256):
return message_hash_bytes == expected_hash
-def compute_content_hash(event, hash_algorithm):
- event_json = event.get_pdu_json()
- event_json.pop("age_ts", None)
- event_json.pop("unsigned", None)
- event_json.pop("signatures", None)
- event_json.pop("hashes", None)
- event_json.pop("outlier", None)
- event_json.pop("destinations", None)
+def compute_content_hash(event_dict, hash_algorithm):
+ """Compute the content hash of an event, which is the hash of the
+ unredacted event.
- event_json_bytes = encode_canonical_json(event_json)
+ Args:
+ event_dict (dict): The unredacted event as a dict
+ hash_algorithm: A hasher from `hashlib`, e.g. hashlib.sha256, to use
+ to hash the event
+
+ Returns:
+ tuple[str, bytes]: A tuple of the name of hash and the hash as raw
+ bytes.
+ """
+ event_dict = dict(event_dict)
+ event_dict.pop("age_ts", None)
+ event_dict.pop("unsigned", None)
+ event_dict.pop("signatures", None)
+ event_dict.pop("hashes", None)
+ event_dict.pop("outlier", None)
+ event_dict.pop("destinations", None)
+
+ event_json_bytes = encode_canonical_json(event_dict)
hashed = hash_algorithm(event_json_bytes)
return (hashed.name, hashed.digest())
def compute_event_reference_hash(event, hash_algorithm=hashlib.sha256):
+ """Computes the event reference hash. This is the hash of the redacted
+ event.
+
+ Args:
+ event (FrozenEvent)
+ hash_algorithm: A hasher from `hashlib`, e.g. hashlib.sha256, to use
+ to hash the event
+
+ Returns:
+ tuple[str, bytes]: A tuple of the name of hash and the hash as raw
+ bytes.
+ """
tmp_event = prune_event(event)
- event_json = tmp_event.get_pdu_json()
- event_json.pop("signatures", None)
- event_json.pop("age_ts", None)
- event_json.pop("unsigned", None)
- event_json_bytes = encode_canonical_json(event_json)
+ event_dict = tmp_event.get_pdu_json()
+ event_dict.pop("signatures", None)
+ event_dict.pop("age_ts", None)
+ event_dict.pop("unsigned", None)
+ event_json_bytes = encode_canonical_json(event_dict)
hashed = hash_algorithm(event_json_bytes)
return (hashed.name, hashed.digest())
-def compute_event_signature(event, signature_name, signing_key):
- tmp_event = prune_event(event)
- redact_json = tmp_event.get_pdu_json()
+def compute_event_signature(event_dict, signature_name, signing_key):
+ """Compute the signature of the event for the given name and key.
+
+ Args:
+ event_dict (dict): The event as a dict
+ signature_name (str): The name of the entity signing the event
+ (typically the server's hostname).
+ signing_key (syutil.crypto.SigningKey): The key to sign with
+
+ Returns:
+ dict[str, dict[str, str]]: Returns a dictionary in the same format of
+ an event's signatures field.
+ """
+ redact_json = prune_event_dict(event_dict)
redact_json.pop("age_ts", None)
redact_json.pop("unsigned", None)
logger.debug("Signing event: %s", encode_canonical_json(redact_json))
@@ -96,25 +131,25 @@ def compute_event_signature(event, signature_name, signing_key):
return redact_json["signatures"]
-def add_hashes_and_signatures(event, signature_name, signing_key,
+def add_hashes_and_signatures(event_dict, signature_name, signing_key,
hash_algorithm=hashlib.sha256):
- # if hasattr(event, "old_state_events"):
- # state_json_bytes = encode_canonical_json(
- # [e.event_id for e in event.old_state_events.values()]
- # )
- # hashed = hash_algorithm(state_json_bytes)
- # event.state_hash = {
- # hashed.name: encode_base64(hashed.digest())
- # }
-
- name, digest = compute_content_hash(event, hash_algorithm=hash_algorithm)
-
- if not hasattr(event, "hashes"):
- event.hashes = {}
- event.hashes[name] = encode_base64(digest)
-
- event.signatures = compute_event_signature(
- event,
+ """Add content hash and sign the event
+
+ Args:
+ event_dict (dict): The event to add hashes to and sign
+ signature_name (str): The name of the entity signing the event
+ (typically the server's hostname).
+ signing_key (syutil.crypto.SigningKey): The key to sign with
+ hash_algorithm: A hasher from `hashlib`, e.g. hashlib.sha256, to use
+ to hash the event
+ """
+
+ name, digest = compute_content_hash(event_dict, hash_algorithm=hash_algorithm)
+
+ event_dict.setdefault("hashes", {})[name] = encode_base64(digest)
+
+ event_dict["signatures"] = compute_event_signature(
+ event_dict,
signature_name=signature_name,
signing_key=signing_key,
)
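
add_hashes_and_signatures() now mutates a plain dict in place. A self-contained sketch using signedjson for key generation (the event content is illustrative):

    from signedjson.key import generate_signing_key

    from synapse.crypto.event_signing import add_hashes_and_signatures

    signing_key = generate_signing_key("key1")
    event_dict = {
        "type": "m.room.message",
        "room_id": "!room:example.com",
        "sender": "@alice:example.com",
        "content": {"msgtype": "m.text", "body": "hello"},
    }

    add_hashes_and_signatures(event_dict, "example.com", signing_key)
    # event_dict now carries "hashes" and "signatures" entries.
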
diff --git a/synapse/event_auth.py b/synapse/event_auth.py
index c81d8e6729..8f9e330da5 100644
--- a/synapse/event_auth.py
+++ b/synapse/event_auth.py
@@ -20,17 +20,25 @@ from signedjson.key import decode_verify_key_bytes
from signedjson.sign import SignatureVerifyException, verify_signed_json
from unpaddedbase64 import decode_base64
-from synapse.api.constants import KNOWN_ROOM_VERSIONS, EventTypes, JoinRules, Membership
+from synapse.api.constants import (
+ KNOWN_ROOM_VERSIONS,
+ EventFormatVersions,
+ EventTypes,
+ JoinRules,
+ Membership,
+ RoomVersions,
+)
from synapse.api.errors import AuthError, EventSizeError, SynapseError
from synapse.types import UserID, get_domain_from_id
logger = logging.getLogger(__name__)
-def check(event, auth_events, do_sig_check=True, do_size_check=True):
+def check(room_version, event, auth_events, do_sig_check=True, do_size_check=True):
""" Checks if this event is correctly authed.
Args:
+ room_version (str): the version of the room
event: the event being checked.
auth_events (dict: event-key -> event): the existing room state.
@@ -48,7 +56,6 @@ def check(event, auth_events, do_sig_check=True, do_size_check=True):
if do_sig_check:
sender_domain = get_domain_from_id(event.sender)
- event_id_domain = get_domain_from_id(event.event_id)
is_invite_via_3pid = (
event.type == EventTypes.Member
@@ -65,9 +72,13 @@ def check(event, auth_events, do_sig_check=True, do_size_check=True):
if not is_invite_via_3pid:
raise AuthError(403, "Event not signed by sender's server")
- # Check the event_id's domain has signed the event
- if not event.signatures.get(event_id_domain):
- raise AuthError(403, "Event not signed by sending server")
+ if event.format_version in (EventFormatVersions.V1,):
+ # Only older room versions have event IDs to check.
+ event_id_domain = get_domain_from_id(event.event_id)
+
+ # Check the origin domain has signed the event
+ if not event.signatures.get(event_id_domain):
+ raise AuthError(403, "Event not signed by sending server")
if auth_events is None:
# Oh, we don't know what the state of the room was, so we
@@ -167,7 +178,7 @@ def check(event, auth_events, do_sig_check=True, do_size_check=True):
_check_power_levels(event, auth_events)
if event.type == EventTypes.Redaction:
- check_redaction(event, auth_events)
+ check_redaction(room_version, event, auth_events)
logger.debug("Allowing! %s", event)
@@ -421,7 +432,7 @@ def _can_send_event(event, auth_events):
return True
-def check_redaction(event, auth_events):
+def check_redaction(room_version, event, auth_events):
"""Check whether the event sender is allowed to redact the target event.
Returns:
@@ -441,10 +452,16 @@ def check_redaction(event, auth_events):
if user_level >= redact_level:
return False
- redacter_domain = get_domain_from_id(event.event_id)
- redactee_domain = get_domain_from_id(event.redacts)
- if redacter_domain == redactee_domain:
+ if room_version in (RoomVersions.V1, RoomVersions.V2,):
+ redacter_domain = get_domain_from_id(event.event_id)
+ redactee_domain = get_domain_from_id(event.redacts)
+ if redacter_domain == redactee_domain:
+ return True
+ elif room_version == RoomVersions.V3:
+ event.internal_metadata.recheck_redaction = True
return True
+ else:
+ raise RuntimeError("Unrecognized room version %r" % (room_version,))
raise AuthError(
403,
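
In v3 rooms the domain comparison can no longer be derived from the event ID, so the redaction is provisionally accepted and flagged for a later recheck. A sketch of what a caller sees (event and auth_events are assumed):

    from synapse.api.constants import RoomVersions
    from synapse.event_auth import check_redaction

    allowed = check_redaction(RoomVersions.V3, event, auth_events)
    if allowed and event.internal_metadata.need_to_check_redaction():
        # The redacter/redactee domain check is deferred until the
        # redacted event can be loaded from the database.
        pass
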
diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py
index 84c75495d5..20c1ab4203 100644
--- a/synapse/events/__init__.py
+++ b/synapse/events/__init__.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2019 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,6 +19,9 @@ from distutils.util import strtobool
import six
+from unpaddedbase64 import encode_base64
+
+from synapse.api.constants import KNOWN_ROOM_VERSIONS, EventFormatVersions, RoomVersions
from synapse.util.caches import intern_dict
from synapse.util.frozenutils import freeze
@@ -41,8 +45,13 @@ class _EventInternalMetadata(object):
def is_outlier(self):
return getattr(self, "outlier", False)
- def is_invite_from_remote(self):
- return getattr(self, "invite_from_remote", False)
+ def is_out_of_band_membership(self):
+ """Whether this is an out of band membership, like an invite or an invite
+ rejection. This is needed as those events are marked as outliers, but
+ they still need to be processed as if they're new events (e.g. updating
+ invite state in the database, relaying to clients, etc).
+ """
+ return getattr(self, "out_of_band_membership", False)
def get_send_on_behalf_of(self):
"""Whether this server should send the event on behalf of another server.
@@ -53,6 +62,21 @@ class _EventInternalMetadata(object):
"""
return getattr(self, "send_on_behalf_of", None)
+ def need_to_check_redaction(self):
+ """Whether the redaction event needs to be rechecked when fetching
+ from the database.
+
+ Starting in room v3 redaction events are accepted up front, and later
+ checked to see if the redacter and redactee's domains match.
+
+ If the sender of the redaction event is allowed to redact any event
+ due to auth rules, then this will always return false.
+
+ Returns:
+ bool
+ """
+ return getattr(self, "recheck_redaction", False)
+
def _event_dict_property(key):
# We want to be able to use hasattr with the event dict properties.
@@ -179,6 +203,8 @@ class EventBase(object):
class FrozenEvent(EventBase):
+ format_version = EventFormatVersions.V1 # All events of this type are V1
+
def __init__(self, event_dict, internal_metadata_dict={}, rejected_reason=None):
event_dict = dict(event_dict)
@@ -213,22 +239,136 @@ class FrozenEvent(EventBase):
rejected_reason=rejected_reason,
)
- @staticmethod
- def from_event(event):
- e = FrozenEvent(
- event.get_pdu_json()
+ def __str__(self):
+ return self.__repr__()
+
+ def __repr__(self):
+ return "<FrozenEvent event_id='%s', type='%s', state_key='%s'>" % (
+ self.get("event_id", None),
+ self.get("type", None),
+ self.get("state_key", None),
+ )
+
+
+class FrozenEventV2(EventBase):
+ format_version = EventFormatVersions.V2 # All events of this type are V2
+
+ def __init__(self, event_dict, internal_metadata_dict={}, rejected_reason=None):
+ event_dict = dict(event_dict)
+
+ # Signatures is a dict of dicts, and this is faster than doing a
+ # copy.deepcopy
+ signatures = {
+ name: {sig_id: sig for sig_id, sig in sigs.items()}
+ for name, sigs in event_dict.pop("signatures", {}).items()
+ }
+
+ assert "event_id" not in event_dict
+
+ unsigned = dict(event_dict.pop("unsigned", {}))
+
+ # We intern these strings because they turn up a lot (especially when
+ # caching).
+ event_dict = intern_dict(event_dict)
+
+ if USE_FROZEN_DICTS:
+ frozen_dict = freeze(event_dict)
+ else:
+ frozen_dict = event_dict
+
+ self._event_id = None
+ self.type = event_dict["type"]
+ if "state_key" in event_dict:
+ self.state_key = event_dict["state_key"]
+
+ super(FrozenEventV2, self).__init__(
+ frozen_dict,
+ signatures=signatures,
+ unsigned=unsigned,
+ internal_metadata_dict=internal_metadata_dict,
+ rejected_reason=rejected_reason,
)
- e.internal_metadata = event.internal_metadata
+ @property
+ def event_id(self):
+ # We have to import this here as otherwise we get an import loop which
+ # is hard to break.
+ from synapse.crypto.event_signing import compute_event_reference_hash
- return e
+ if self._event_id:
+ return self._event_id
+ self._event_id = "$" + encode_base64(compute_event_reference_hash(self)[1])
+ return self._event_id
+
+ def prev_event_ids(self):
+ """Returns the list of prev event IDs. The order matches the order
+ specified in the event, though there is no meaning to it.
+
+ Returns:
+ list[str]: The list of event IDs of this event's prev_events
+ """
+ return self.prev_events
+
+ def auth_event_ids(self):
+ """Returns the list of auth event IDs. The order matches the order
+ specified in the event, though there is no meaning to it.
+
+ Returns:
+ list[str]: The list of event IDs of this event's auth_events
+ """
+ return self.auth_events
def __str__(self):
return self.__repr__()
def __repr__(self):
- return "<FrozenEvent event_id='%s', type='%s', state_key='%s'>" % (
- self.get("event_id", None),
+ return "<FrozenEventV2 event_id='%s', type='%s', state_key='%s'>" % (
+ self.event_id,
self.get("type", None),
self.get("state_key", None),
)
+
+
+def room_version_to_event_format(room_version):
+ """Converts a room version string to the event format
+
+ Args:
+ room_version (str)
+
+ Returns:
+ int
+ """
+ if room_version not in KNOWN_ROOM_VERSIONS:
+ # We should have already checked the room version, so this should not happen
+ raise RuntimeError("Unrecognized room version %s" % (room_version,))
+
+ if room_version in (
+ RoomVersions.V1, RoomVersions.V2, RoomVersions.STATE_V2_TEST,
+ ):
+ return EventFormatVersions.V1
+ elif room_version in (RoomVersions.V3,):
+ return EventFormatVersions.V2
+ else:
+ raise RuntimeError("Unrecognized room version %s" % (room_version,))
+
+
+def event_type_from_format_version(format_version):
+ """Returns the python type to use to construct an Event object for the
+ given event format version.
+
+ Args:
+ format_version (int): The event format version
+
+ Returns:
+ type: A type that can be initialized as per the initializer of
+ `FrozenEvent`
+ """
+
+ if format_version == EventFormatVersions.V1:
+ return FrozenEvent
+ elif format_version == EventFormatVersions.V2:
+ return FrozenEventV2
+ else:
+ raise Exception(
+ "No event format %r" % (format_version,)
+ )
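
# An illustrative sketch (not part of this patch) of the two ideas above:
# mapping a room version to an event format, and deriving a v2-format event
# ID from a hash rather than reading it out of the event. The real code
# hashes the *redacted* event via compute_event_reference_hash; here we hash
# sorted, compact JSON of the dict directly, which is a simplification, not
# the actual reference-hash algorithm.

import hashlib
import json

from unpaddedbase64 import encode_base64  # same helper the patch imports

def room_version_to_event_format_sketch(room_version):
    # v1/v2 (and the state-res test version) use format 1; v3 uses format 2.
    return 2 if room_version == "3" else 1

def event_id_sketch(event_dict):
    canonical = json.dumps(event_dict, sort_keys=True, separators=(",", ":"))
    digest = hashlib.sha256(canonical.encode("utf-8")).digest()
    return "$" + encode_base64(digest)
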
diff --git a/synapse/events/builder.py b/synapse/events/builder.py
index e662eaef10..06e01be918 100644
--- a/synapse/events/builder.py
+++ b/synapse/events/builder.py
@@ -13,63 +13,270 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import copy
+import attr
+from twisted.internet import defer
+
+from synapse.api.constants import (
+ KNOWN_EVENT_FORMAT_VERSIONS,
+ KNOWN_ROOM_VERSIONS,
+ MAX_DEPTH,
+ EventFormatVersions,
+)
+from synapse.crypto.event_signing import add_hashes_and_signatures
from synapse.types import EventID
from synapse.util.stringutils import random_string
-from . import EventBase, FrozenEvent, _event_dict_property
+from . import (
+ _EventInternalMetadata,
+ event_type_from_format_version,
+ room_version_to_event_format,
+)
+
+
+@attr.s(slots=True, cmp=False, frozen=True)
+class EventBuilder(object):
+ """A format independent event builder used to build up the event content
+ before signing the event.
+
+ (Note that while objects of this class are frozen, the
+ content/unsigned/internal_metadata fields are still mutable)
+
+ Attributes:
+ format_version (int): Event format version
+ room_id (str)
+ type (str)
+ sender (str)
+ content (dict)
+ unsigned (dict)
+ internal_metadata (_EventInternalMetadata)
+
+ _state (StateHandler)
+ _auth (synapse.api.Auth)
+ _store (DataStore)
+ _clock (Clock)
+ _hostname (str): The hostname of the server creating the event
+ _signing_key: The signing key to use to sign the event as the server
+ """
+
+ _state = attr.ib()
+ _auth = attr.ib()
+ _store = attr.ib()
+ _clock = attr.ib()
+ _hostname = attr.ib()
+ _signing_key = attr.ib()
+
+ format_version = attr.ib()
+
+ room_id = attr.ib()
+ type = attr.ib()
+ sender = attr.ib()
+
+ content = attr.ib(default=attr.Factory(dict))
+ unsigned = attr.ib(default=attr.Factory(dict))
+
+ # These only exist on a subset of events, so they raise AttributeError if
+ # someone tries to get them when they don't exist.
+ _state_key = attr.ib(default=None)
+ _redacts = attr.ib(default=None)
+ internal_metadata = attr.ib(default=attr.Factory(lambda: _EventInternalMetadata({})))
-class EventBuilder(EventBase):
- def __init__(self, key_values={}, internal_metadata_dict={}):
- signatures = copy.deepcopy(key_values.pop("signatures", {}))
- unsigned = copy.deepcopy(key_values.pop("unsigned", {}))
+ @property
+ def state_key(self):
+ if self._state_key is not None:
+ return self._state_key
- super(EventBuilder, self).__init__(
- key_values,
- signatures=signatures,
- unsigned=unsigned,
- internal_metadata_dict=internal_metadata_dict,
+ raise AttributeError("state_key")
+
+ def is_state(self):
+ return self._state_key is not None
+
+ @defer.inlineCallbacks
+ def build(self, prev_event_ids):
+ """Transform into a fully signed and hashed event
+
+ Args:
+ prev_event_ids (list[str]): The event IDs to use as the prev events
+
+ Returns:
+ Deferred[FrozenEvent]
+ """
+
+ state_ids = yield self._state.get_current_state_ids(
+ self.room_id, prev_event_ids,
+ )
+ auth_ids = yield self._auth.compute_auth_events(
+ self, state_ids,
)
- event_id = _event_dict_property("event_id")
- state_key = _event_dict_property("state_key")
- type = _event_dict_property("type")
+ if self.format_version == EventFormatVersions.V1:
+ auth_events = yield self._store.add_event_hashes(auth_ids)
+ prev_events = yield self._store.add_event_hashes(prev_event_ids)
+ else:
+ auth_events = auth_ids
+ prev_events = prev_event_ids
+
+ old_depth = yield self._store.get_max_depth_of(
+ prev_event_ids,
+ )
+ depth = old_depth + 1
+
+ # we cap depth of generated events, to ensure that they are not
+ # rejected by other servers (and so that they can be persisted in
+ # the db)
+ depth = min(depth, MAX_DEPTH)
+
+ event_dict = {
+ "auth_events": auth_events,
+ "prev_events": prev_events,
+ "type": self.type,
+ "room_id": self.room_id,
+ "sender": self.sender,
+ "content": self.content,
+ "unsigned": self.unsigned,
+ "depth": depth,
+ "prev_state": [],
+ }
+
+ if self.is_state():
+ event_dict["state_key"] = self._state_key
- def build(self):
- return FrozenEvent.from_event(self)
+ if self._redacts is not None:
+ event_dict["redacts"] = self._redacts
+
+ defer.returnValue(
+ create_local_event_from_event_dict(
+ clock=self._clock,
+ hostname=self._hostname,
+ signing_key=self._signing_key,
+ format_version=self.format_version,
+ event_dict=event_dict,
+ internal_metadata_dict=self.internal_metadata.get_dict(),
+ )
+ )
class EventBuilderFactory(object):
- def __init__(self, clock, hostname):
- self.clock = clock
- self.hostname = hostname
+ def __init__(self, hs):
+ self.clock = hs.get_clock()
+ self.hostname = hs.hostname
+ self.signing_key = hs.config.signing_key[0]
+
+ self.store = hs.get_datastore()
+ self.state = hs.get_state_handler()
+ self.auth = hs.get_auth()
+
+ def new(self, room_version, key_values):
+ """Generate an event builder appropriate for the given room version
+
+ Args:
+ room_version (str): Version of the room that we're creating an
+ event builder for
+ key_values (dict): Fields used as the basis of the new event
+
+ Returns:
+ EventBuilder
+ """
+
+ # Check that we know the room version before building an event for it
+ if room_version not in KNOWN_ROOM_VERSIONS:
+ raise Exception(
+ "No event format defined for version %r" % (room_version,)
+ )
+
+ return EventBuilder(
+ store=self.store,
+ state=self.state,
+ auth=self.auth,
+ clock=self.clock,
+ hostname=self.hostname,
+ signing_key=self.signing_key,
+ format_version=room_version_to_event_format(room_version),
+ type=key_values["type"],
+ state_key=key_values.get("state_key"),
+ room_id=key_values["room_id"],
+ sender=key_values["sender"],
+ content=key_values.get("content", {}),
+ unsigned=key_values.get("unsigned", {}),
+ redacts=key_values.get("redacts", None),
+ )
+
+
+def create_local_event_from_event_dict(clock, hostname, signing_key,
+ format_version, event_dict,
+ internal_metadata_dict=None):
+ """Takes a fully formed event dict, ensuring that fields like `origin`
+ and `origin_server_ts` have correct values for a locally produced event,
+ then signs and hashes it.
+
+ Args:
+ clock (Clock)
+ hostname (str)
+ signing_key
+ format_version (int)
+ event_dict (dict)
+ internal_metadata_dict (dict|None)
+
+ Returns:
+ FrozenEvent
+ """
+
+ # Check that we know the event format version
+ if format_version not in KNOWN_EVENT_FORMAT_VERSIONS:
+ raise Exception(
+ "No event format defined for version %r" % (format_version,)
+ )
+
+ if internal_metadata_dict is None:
+ internal_metadata_dict = {}
+
+ time_now = int(clock.time_msec())
+
+ if format_version == EventFormatVersions.V1:
+ event_dict["event_id"] = _create_event_id(clock, hostname)
+
+ event_dict["origin"] = hostname
+ event_dict["origin_server_ts"] = time_now
+
+ event_dict.setdefault("unsigned", {})
+ age = event_dict["unsigned"].pop("age", 0)
+ event_dict["unsigned"].setdefault("age_ts", time_now - age)
+
+ event_dict.setdefault("signatures", {})
+
+ add_hashes_and_signatures(
+ event_dict,
+ hostname,
+ signing_key,
+ )
+ return event_type_from_format_version(format_version)(
+ event_dict, internal_metadata_dict=internal_metadata_dict,
+ )
- self.event_id_count = 0
- def create_event_id(self):
- i = str(self.event_id_count)
- self.event_id_count += 1
+# A counter used when generating new event IDs
+_event_id_counter = 0
- local_part = str(int(self.clock.time())) + i + random_string(5)
- e_id = EventID(local_part, self.hostname)
+def _create_event_id(clock, hostname):
+ """Create a new event ID
- return e_id.to_string()
+ Args:
+ clock (Clock)
+ hostname (str): The server name for the event ID
- def new(self, key_values={}):
- key_values["event_id"] = self.create_event_id()
+ Returns:
+ str
+ """
- time_now = int(self.clock.time_msec())
+ global _event_id_counter
- key_values.setdefault("origin", self.hostname)
- key_values.setdefault("origin_server_ts", time_now)
+ i = str(_event_id_counter)
+ _event_id_counter += 1
- key_values.setdefault("unsigned", {})
- age = key_values["unsigned"].pop("age", 0)
- key_values["unsigned"].setdefault("age_ts", time_now - age)
+ local_part = str(int(clock.time())) + i + random_string(5)
- key_values["signatures"] = {}
+ e_id = EventID(local_part, hostname)
- return EventBuilder(key_values=key_values,)
+ return e_id.to_string()
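
# A simplified, synchronous sketch (not part of this patch) of what
# EventBuilder.build assembles, using hypothetical inputs. It shows the two
# format-dependent points: v1 events reference prev/auth events as
# (event_id, hashes) pairs while v2 events use bare ID lists, and the depth
# of a locally built event is capped at MAX_DEPTH so that remote servers
# (and the database) will accept it.

MAX_DEPTH = 2 ** 63 - 1  # stands in for synapse.api.constants.MAX_DEPTH

def build_event_dict_sketch(format_version, prev_event_ids, auth_event_ids,
                            max_prev_depth):
    if format_version == 1:
        # V1: IDs paired with hashes (hash dicts elided for brevity).
        prev_events = [(eid, {}) for eid in prev_event_ids]
        auth_events = [(eid, {}) for eid in auth_event_ids]
    else:
        prev_events = list(prev_event_ids)
        auth_events = list(auth_event_ids)

    return {
        "prev_events": prev_events,
        "auth_events": auth_events,
        "depth": min(max_prev_depth + 1, MAX_DEPTH),
    }
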
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index 652941ca0d..07fccdd8f9 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -38,8 +38,31 @@ def prune_event(event):
This is used when we "redact" an event. We want to remove all fields that
the user has specified, but we do want to keep necessary information like
type, state_key etc.
+
+ Args:
+ event (FrozenEvent)
+
+ Returns:
+ FrozenEvent
+ """
+ pruned_event_dict = prune_event_dict(event.get_dict())
+
+ from . import event_type_from_format_version
+ return event_type_from_format_version(event.format_version)(
+ pruned_event_dict, event.internal_metadata.get_dict()
+ )
+
+
+def prune_event_dict(event_dict):
+ """Redacts the event_dict in the same way as `prune_event`, except it
+ operates on dicts rather than event objects
+
+ Args:
+ event_dict (dict)
+
+ Returns:
+ dict: A copy of the pruned event dict
"""
- event_type = event.type
allowed_keys = [
"event_id",
@@ -59,13 +82,13 @@ def prune_event(event):
"membership",
]
- event_dict = event.get_dict()
+ event_type = event_dict["type"]
new_content = {}
def add_fields(*fields):
for field in fields:
- if field in event.content:
+ if field in event_dict["content"]:
new_content[field] = event_dict["content"][field]
if event_type == EventTypes.Member:
@@ -98,17 +121,17 @@ def prune_event(event):
allowed_fields["content"] = new_content
- allowed_fields["unsigned"] = {}
+ unsigned = {}
+ allowed_fields["unsigned"] = unsigned
- if "age_ts" in event.unsigned:
- allowed_fields["unsigned"]["age_ts"] = event.unsigned["age_ts"]
- if "replaces_state" in event.unsigned:
- allowed_fields["unsigned"]["replaces_state"] = event.unsigned["replaces_state"]
+ event_unsigned = event_dict.get("unsigned", {})
- return type(event)(
- allowed_fields,
- internal_metadata_dict=event.internal_metadata.get_dict()
- )
+ if "age_ts" in event_unsigned:
+ unsigned["age_ts"] = event_unsigned["age_ts"]
+ if "replaces_state" in event_unsigned:
+ unsigned["replaces_state"] = event_unsigned["replaces_state"]
+
+ return allowed_fields
def _copy_field(src, dst, field):
@@ -244,6 +267,7 @@ def serialize_event(e, time_now_ms, as_client_event=True,
Returns:
dict
"""
+
# FIXME(erikj): To handle the case of presence events and the like
if not isinstance(e, EventBase):
return e
@@ -253,6 +277,8 @@ def serialize_event(e, time_now_ms, as_client_event=True,
# Should this strip out None's?
d = {k: v for k, v in e.get_dict().items()}
+ d["event_id"] = e.event_id
+
if "age_ts" in d["unsigned"]:
d["unsigned"]["age"] = time_now_ms - d["unsigned"]["age_ts"]
del d["unsigned"]["age_ts"]
diff --git a/synapse/events/validator.py b/synapse/events/validator.py
index cf184748a1..a072674b02 100644
--- a/synapse/events/validator.py
+++ b/synapse/events/validator.py
@@ -15,23 +15,29 @@
from six import string_types
-from synapse.api.constants import EventTypes, Membership
+from synapse.api.constants import EventFormatVersions, EventTypes, Membership
from synapse.api.errors import SynapseError
from synapse.types import EventID, RoomID, UserID
class EventValidator(object):
+ def validate_new(self, event):
+ """Validates the event has roughly the right format
- def validate(self, event):
- EventID.from_string(event.event_id)
- RoomID.from_string(event.room_id)
+ Args:
+ event (FrozenEvent)
+ """
+ self.validate_builder(event)
+
+ if event.format_version == EventFormatVersions.V1:
+ EventID.from_string(event.event_id)
required = [
- # "auth_events",
+ "auth_events",
"content",
- # "hashes",
+ "hashes",
"origin",
- # "prev_events",
+ "prev_events",
"sender",
"type",
]
@@ -41,8 +47,25 @@ class EventValidator(object):
raise SynapseError(400, "Event does not have key %s" % (k,))
# Check that the following keys have string values
- strings = [
+ event_strings = [
"origin",
+ ]
+
+ for s in event_strings:
+ if not isinstance(getattr(event, s), string_types):
+ raise SynapseError(400, "'%s' not a string type" % (s,))
+
+ def validate_builder(self, event):
+ """Validates that the builder/event has roughly the right format. Only
+ checks values that we expect a proto event to have, rather than all the
+ fields an event would have
+
+ Args:
+ event (EventBuilder|FrozenEvent)
+ """
+
+ strings = [
+ "room_id",
"sender",
"type",
]
@@ -54,22 +77,7 @@ class EventValidator(object):
if not isinstance(getattr(event, s), string_types):
raise SynapseError(400, "Not '%s' a string type" % (s,))
- if event.type == EventTypes.Member:
- if "membership" not in event.content:
- raise SynapseError(400, "Content has not membership key")
-
- if event.content["membership"] not in Membership.LIST:
- raise SynapseError(400, "Invalid membership key")
-
- # Check that the following keys have dictionary values
- # TODO
-
- # Check that the following keys have the correct format for DAGs
- # TODO
-
- def validate_new(self, event):
- self.validate(event)
-
+ RoomID.from_string(event.room_id)
UserID.from_string(event.sender)
if event.type == EventTypes.Message:
@@ -86,9 +94,16 @@ class EventValidator(object):
elif event.type == EventTypes.Name:
self._ensure_strings(event.content, ["name"])
+ elif event.type == EventTypes.Member:
+ if "membership" not in event.content:
+ raise SynapseError(400, "Content has not membership key")
+
+ if event.content["membership"] not in Membership.LIST:
+ raise SynapseError(400, "Invalid membership key")
+
def _ensure_strings(self, d, keys):
for s in keys:
if s not in d:
raise SynapseError(400, "'%s' not in content" % (s,))
if not isinstance(d[s], string_types):
- raise SynapseError(400, "Not '%s' a string type" % (s,))
+ raise SynapseError(400, "'%s' not a string type" % (s,))
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index b7ad729c63..a7a2ec4523 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -20,10 +20,10 @@ import six
from twisted.internet import defer
from twisted.internet.defer import DeferredList
-from synapse.api.constants import MAX_DEPTH, EventTypes, Membership
+from synapse.api.constants import MAX_DEPTH, EventTypes, Membership, RoomVersions
from synapse.api.errors import Codes, SynapseError
from synapse.crypto.event_signing import check_event_content_hash
-from synapse.events import FrozenEvent
+from synapse.events import event_type_from_format_version
from synapse.events.utils import prune_event
from synapse.http.servlet import assert_params_in_dict
from synapse.types import get_domain_from_id
@@ -43,8 +43,8 @@ class FederationBase(object):
self._clock = hs.get_clock()
@defer.inlineCallbacks
- def _check_sigs_and_hash_and_fetch(self, origin, pdus, outlier=False,
- include_none=False):
+ def _check_sigs_and_hash_and_fetch(self, origin, pdus, room_version,
+ outlier=False, include_none=False):
"""Takes a list of PDUs and checks the signatures and hashs of each
one. If a PDU fails its signature check then we check if we have it in
the database and if not then request if from the originating server of
@@ -56,13 +56,17 @@ class FederationBase(object):
a new list.
Args:
+ origin (str)
pdus (list)
- outlier (bool)
+ room_version (str)
+ outlier (bool): Whether the events are outliers or not
+ include_none (bool): Whether to include None in the returned list
+ for events that have failed their checks
Returns:
Deferred : A list of PDUs that have valid signatures and hashes.
"""
- deferreds = self._check_sigs_and_hashes(pdus)
+ deferreds = self._check_sigs_and_hashes(room_version, pdus)
@defer.inlineCallbacks
def handle_check_result(pdu, deferred):
@@ -84,6 +88,7 @@ class FederationBase(object):
res = yield self.get_pdu(
destinations=[pdu.origin],
event_id=pdu.event_id,
+ room_version=room_version,
outlier=outlier,
timeout=10000,
)
@@ -116,16 +121,17 @@ class FederationBase(object):
else:
defer.returnValue([p for p in valid_pdus if p])
- def _check_sigs_and_hash(self, pdu):
+ def _check_sigs_and_hash(self, room_version, pdu):
return logcontext.make_deferred_yieldable(
- self._check_sigs_and_hashes([pdu])[0],
+ self._check_sigs_and_hashes(room_version, [pdu])[0],
)
- def _check_sigs_and_hashes(self, pdus):
+ def _check_sigs_and_hashes(self, room_version, pdus):
"""Checks that each of the received events is correctly signed by the
sending server.
Args:
+ room_version (str): The room version of the PDUs
pdus (list[FrozenEvent]): the events to be checked
Returns:
@@ -136,7 +142,7 @@ class FederationBase(object):
* throws a SynapseError if the signature check failed.
The deferreds run their callbacks in the sentinel logcontext.
"""
- deferreds = _check_sigs_on_pdus(self.keyring, pdus)
+ deferreds = _check_sigs_on_pdus(self.keyring, room_version, pdus)
ctx = logcontext.LoggingContext.current_context()
@@ -198,16 +204,17 @@ class FederationBase(object):
class PduToCheckSig(namedtuple("PduToCheckSig", [
- "pdu", "redacted_pdu_json", "event_id_domain", "sender_domain", "deferreds",
+ "pdu", "redacted_pdu_json", "sender_domain", "deferreds",
])):
pass
-def _check_sigs_on_pdus(keyring, pdus):
+def _check_sigs_on_pdus(keyring, room_version, pdus):
"""Check that the given events are correctly signed
Args:
keyring (synapse.crypto.Keyring): keyring object to do the checks
+ room_version (str): the room version of the PDUs
pdus (Collection[EventBase]): the events to be checked
Returns:
@@ -220,9 +227,7 @@ def _check_sigs_on_pdus(keyring, pdus):
# we want to check that the event is signed by:
#
- # (a) the server which created the event_id
- #
- # (b) the sender's server.
+ # (a) the sender's server
#
# - except in the case of invites created from a 3pid invite, which are exempt
# from this check, because the sender has to match that of the original 3pid
@@ -236,34 +241,26 @@ def _check_sigs_on_pdus(keyring, pdus):
# and signatures are *supposed* to be valid whether or not an event has been
# redacted. But this isn't the worst of the ways that 3pid invites are broken.
#
+ # (b) for V1 and V2 rooms, the server which created the event_id
+ #
# let's start by getting the domain for each pdu, and flattening the event back
# to JSON.
+
pdus_to_check = [
PduToCheckSig(
pdu=p,
redacted_pdu_json=prune_event(p).get_pdu_json(),
- event_id_domain=get_domain_from_id(p.event_id),
sender_domain=get_domain_from_id(p.sender),
deferreds=[],
)
for p in pdus
]
- # first make sure that the event is signed by the event_id's domain
- deferreds = keyring.verify_json_objects_for_server([
- (p.event_id_domain, p.redacted_pdu_json)
- for p in pdus_to_check
- ])
-
- for p, d in zip(pdus_to_check, deferreds):
- p.deferreds.append(d)
-
- # now let's look for events where the sender's domain is different to the
- # event id's domain (normally only the case for joins/leaves), and add additional
- # checks.
+ # First we check that the event is signed by the sender's domain
+ # (except if it's a 3pid invite, in which case it may be sent by any server)
pdus_to_check_sender = [
p for p in pdus_to_check
- if p.sender_domain != p.event_id_domain and not _is_invite_via_3pid(p.pdu)
+ if not _is_invite_via_3pid(p.pdu)
]
more_deferreds = keyring.verify_json_objects_for_server([
@@ -274,19 +271,43 @@ def _check_sigs_on_pdus(keyring, pdus):
for p, d in zip(pdus_to_check_sender, more_deferreds):
p.deferreds.append(d)
+ # now let's look for events where the sender's domain is different to the
+ # event id's domain (normally only the case for joins/leaves), and add additional
+ # checks. Only do this if the room version has a concept of event ID domain
+ if room_version in (
+ RoomVersions.V1, RoomVersions.V2, RoomVersions.STATE_V2_TEST,
+ ):
+ pdus_to_check_event_id = [
+ p for p in pdus_to_check
+ if p.sender_domain != get_domain_from_id(p.pdu.event_id)
+ ]
+
+ more_deferreds = keyring.verify_json_objects_for_server([
+ (get_domain_from_id(p.pdu.event_id), p.redacted_pdu_json)
+ for p in pdus_to_check_event_id
+ ])
+
+ for p, d in zip(pdus_to_check_event_id, more_deferreds):
+ p.deferreds.append(d)
+ elif room_version in (RoomVersions.V3,):
+ pass # No further checks needed, as event IDs are hashes here
+ else:
+ raise RuntimeError("Unrecognized room version %s" % (room_version,))
+
# replace lists of deferreds with single Deferreds
return [_flatten_deferred_list(p.deferreds) for p in pdus_to_check]
def _flatten_deferred_list(deferreds):
- """Given a list of one or more deferreds, either return the single deferred, or
- combine into a DeferredList.
+ """Given a list of deferreds, either return the single deferred,
+ combine into a DeferredList, or return an already resolved deferred.
"""
if len(deferreds) > 1:
return DeferredList(deferreds, fireOnOneErrback=True, consumeErrors=True)
- else:
- assert len(deferreds) == 1
+ elif len(deferreds) == 1:
return deferreds[0]
+ else:
+ return defer.succeed(None)
def _is_invite_via_3pid(event):
@@ -297,11 +318,12 @@ def _is_invite_via_3pid(event):
)
-def event_from_pdu_json(pdu_json, outlier=False):
+def event_from_pdu_json(pdu_json, event_format_version, outlier=False):
"""Construct a FrozenEvent from an event json received over federation
Args:
pdu_json (object): pdu as received over federation
+ event_format_version (int): The event format version
outlier (bool): True to mark this event as an outlier
Returns:
@@ -313,7 +335,7 @@ def event_from_pdu_json(pdu_json, outlier=False):
"""
# we could probably enforce a bunch of other fields here (room_id, sender,
# origin, etc etc)
- assert_params_in_dict(pdu_json, ('event_id', 'type', 'depth'))
+ assert_params_in_dict(pdu_json, ('type', 'depth'))
depth = pdu_json['depth']
if not isinstance(depth, six.integer_types):
@@ -325,8 +347,8 @@ def event_from_pdu_json(pdu_json, outlier=False):
elif depth > MAX_DEPTH:
raise SynapseError(400, "Depth too large", Codes.BAD_JSON)
- event = FrozenEvent(
- pdu_json
+ event = event_type_from_format_version(event_format_version)(
+ pdu_json,
)
event.internal_metadata.outlier = outlier
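
# A schematic sketch (not part of this patch) of which servers must have
# signed a PDU under each room version, per _check_sigs_on_pdus above. The
# version strings and the boolean 3pid flag are simplifications of the real
# types.

def required_signing_domains_sketch(room_version, sender_domain,
                                    event_id_domain, is_3pid_invite):
    domains = set()
    if not is_3pid_invite:
        # Every event must be signed by the sender's server...
        domains.add(sender_domain)
    if room_version in ("1", "2", "state-v2-test"):
        # ...and, where event IDs carry a domain, by that server too.
        domains.add(event_id_domain)
    # In v3, event IDs are hashes, so there is no event-id domain check.
    return domains
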
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index d05ed91d64..4e4f58b418 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -25,14 +25,19 @@ from prometheus_client import Counter
from twisted.internet import defer
-from synapse.api.constants import KNOWN_ROOM_VERSIONS, EventTypes, Membership
+from synapse.api.constants import (
+ KNOWN_ROOM_VERSIONS,
+ EventTypes,
+ Membership,
+ RoomVersions,
+)
from synapse.api.errors import (
CodeMessageException,
FederationDeniedError,
HttpResponseException,
SynapseError,
)
-from synapse.events import builder
+from synapse.events import builder, room_version_to_event_format
from synapse.federation.federation_base import FederationBase, event_from_pdu_json
from synapse.util import logcontext, unwrapFirstError
from synapse.util.caches.expiringcache import ExpiringCache
@@ -66,6 +71,9 @@ class FederationClient(FederationBase):
self.state = hs.get_state_handler()
self.transport_layer = hs.get_federation_transport_client()
+ self.hostname = hs.hostname
+ self.signing_key = hs.config.signing_key[0]
+
self._get_pdu_cache = ExpiringCache(
cache_name="get_pdu_cache",
clock=self._clock,
@@ -162,13 +170,13 @@ class FederationClient(FederationBase):
@defer.inlineCallbacks
@log_function
- def backfill(self, dest, context, limit, extremities):
+ def backfill(self, dest, room_id, limit, extremities):
"""Requests some more historic PDUs for the given context from the
given destination server.
Args:
dest (str): The remote home server to ask.
- context (str): The context to backfill.
+ room_id (str): The room_id to backfill.
limit (int): The maximum number of PDUs to return.
extremities (list): List of PDU id and origins of the first pdus
we have seen from the context
@@ -183,18 +191,21 @@ class FederationClient(FederationBase):
return
transaction_data = yield self.transport_layer.backfill(
- dest, context, extremities, limit)
+ dest, room_id, extremities, limit)
logger.debug("backfill transaction_data=%s", repr(transaction_data))
+ room_version = yield self.store.get_room_version(room_id)
+ format_ver = room_version_to_event_format(room_version)
+
pdus = [
- event_from_pdu_json(p, outlier=False)
+ event_from_pdu_json(p, format_ver, outlier=False)
for p in transaction_data["pdus"]
]
# FIXME: We should handle signature failures more gracefully.
pdus[:] = yield logcontext.make_deferred_yieldable(defer.gatherResults(
- self._check_sigs_and_hashes(pdus),
+ self._check_sigs_and_hashes(room_version, pdus),
consumeErrors=True,
).addErrback(unwrapFirstError))
@@ -202,7 +213,8 @@ class FederationClient(FederationBase):
@defer.inlineCallbacks
@log_function
- def get_pdu(self, destinations, event_id, outlier=False, timeout=None):
+ def get_pdu(self, destinations, event_id, room_version, outlier=False,
+ timeout=None):
"""Requests the PDU with given origin and ID from the remote home
servers.
@@ -212,6 +224,7 @@ class FederationClient(FederationBase):
Args:
destinations (list): Which home servers to query
event_id (str): event to fetch
+ room_version (str): version of the room
outlier (bool): Indicates whether the PDU is an `outlier`, i.e. if
it's from an arbitrary point in the context as opposed to part
of the current block of PDUs. Defaults to `False`
@@ -230,6 +243,8 @@ class FederationClient(FederationBase):
pdu_attempts = self.pdu_destination_tried.setdefault(event_id, {})
+ format_ver = room_version_to_event_format(room_version)
+
signed_pdu = None
for destination in destinations:
now = self._clock.time_msec()
@@ -245,7 +260,7 @@ class FederationClient(FederationBase):
logger.debug("transaction_data %r", transaction_data)
pdu_list = [
- event_from_pdu_json(p, outlier=outlier)
+ event_from_pdu_json(p, format_ver, outlier=outlier)
for p in transaction_data["pdus"]
]
@@ -253,7 +268,7 @@ class FederationClient(FederationBase):
pdu = pdu_list[0]
# Check signatures are correct.
- signed_pdu = yield self._check_sigs_and_hash(pdu)
+ signed_pdu = yield self._check_sigs_and_hash(room_version, pdu)
break
@@ -339,12 +354,16 @@ class FederationClient(FederationBase):
destination, room_id, event_id=event_id,
)
+ room_version = yield self.store.get_room_version(room_id)
+ format_ver = room_version_to_event_format(room_version)
+
pdus = [
- event_from_pdu_json(p, outlier=True) for p in result["pdus"]
+ event_from_pdu_json(p, format_ver, outlier=True)
+ for p in result["pdus"]
]
auth_chain = [
- event_from_pdu_json(p, outlier=True)
+ event_from_pdu_json(p, format_ver, outlier=True)
for p in result.get("auth_chain", [])
]
@@ -355,7 +374,8 @@ class FederationClient(FederationBase):
signed_pdus = yield self._check_sigs_and_hash_and_fetch(
destination,
[p for p in pdus if p.event_id not in seen_events],
- outlier=True
+ outlier=True,
+ room_version=room_version,
)
signed_pdus.extend(
seen_events[p.event_id] for p in pdus if p.event_id in seen_events
@@ -364,7 +384,8 @@ class FederationClient(FederationBase):
signed_auth = yield self._check_sigs_and_hash_and_fetch(
destination,
[p for p in auth_chain if p.event_id not in seen_events],
- outlier=True
+ outlier=True,
+ room_version=room_version,
)
signed_auth.extend(
seen_events[p.event_id] for p in auth_chain if p.event_id in seen_events
@@ -411,6 +432,8 @@ class FederationClient(FederationBase):
random.shuffle(srvs)
return srvs
+ room_version = yield self.store.get_room_version(room_id)
+
batch_size = 20
missing_events = list(missing_events)
for i in range(0, len(missing_events), batch_size):
@@ -421,6 +444,7 @@ class FederationClient(FederationBase):
self.get_pdu,
destinations=random_server_list(),
event_id=e_id,
+ room_version=room_version,
)
for e_id in batch
]
@@ -445,13 +469,17 @@ class FederationClient(FederationBase):
destination, room_id, event_id,
)
+ room_version = yield self.store.get_room_version(room_id)
+ format_ver = room_version_to_event_format(room_version)
+
auth_chain = [
- event_from_pdu_json(p, outlier=True)
+ event_from_pdu_json(p, format_ver, outlier=True)
for p in res["auth_chain"]
]
signed_auth = yield self._check_sigs_and_hash_and_fetch(
- destination, auth_chain, outlier=True
+ destination, auth_chain,
+ outlier=True, room_version=room_version,
)
signed_auth.sort(key=lambda e: e.depth)
@@ -522,6 +550,8 @@ class FederationClient(FederationBase):
Does so by asking one of the already participating servers to create an
event with proper context.
+ Returns a fully signed and hashed event.
+
Note that this does not append any events to any graphs.
Args:
@@ -536,8 +566,10 @@ class FederationClient(FederationBase):
params (dict[str, str|Iterable[str]]): Query parameters to include in the
request.
Return:
- Deferred: resolves to a tuple of (origin (str), event (object))
- where origin is the remote homeserver which generated the event.
+ Deferred[tuple[str, FrozenEvent, int]]: resolves to a tuple of
+ `(origin, event, event_format)` where origin is the remote
+ homeserver which generated the event, and event_format is one of
+ `synapse.api.constants.EventFormatVersions`.
Fails with a ``SynapseError`` if the chosen remote server
returns a 300/400 code.
@@ -557,6 +589,11 @@ class FederationClient(FederationBase):
destination, room_id, user_id, membership, params,
)
+ # Note: If not supplied, the room version may be either v1 or v2,
+ # but either way the event format version will be v1.
+ room_version = ret.get("room_version", RoomVersions.V1)
+ event_format = room_version_to_event_format(room_version)
+
pdu_dict = ret.get("event", None)
if not isinstance(pdu_dict, dict):
raise InvalidResponseError("Bad 'event' field in response")
@@ -571,17 +608,20 @@ class FederationClient(FederationBase):
if "prev_state" not in pdu_dict:
pdu_dict["prev_state"] = []
- ev = builder.EventBuilder(pdu_dict)
+ ev = builder.create_local_event_from_event_dict(
+ self._clock, self.hostname, self.signing_key,
+ format_version=event_format, event_dict=pdu_dict,
+ )
defer.returnValue(
- (destination, ev)
+ (destination, ev, event_format)
)
return self._try_destination_list(
"make_" + membership, destinations, send_request,
)
- def send_join(self, destinations, pdu):
+ def send_join(self, destinations, pdu, event_format_version):
"""Sends a join event to one of a list of homeservers.
Doing so will cause the remote server to add the event to the graph,
@@ -591,6 +631,7 @@ class FederationClient(FederationBase):
destinations (str): Candidate homeservers which are probably
participating in the room.
pdu (BaseEvent): event to be sent
+ event_format_version (int): The event format version
Return:
Deferred: resolves to a dict with members ``origin`` (a string
@@ -636,12 +677,12 @@ class FederationClient(FederationBase):
logger.debug("Got content: %s", content)
state = [
- event_from_pdu_json(p, outlier=True)
+ event_from_pdu_json(p, event_format_version, outlier=True)
for p in content.get("state", [])
]
auth_chain = [
- event_from_pdu_json(p, outlier=True)
+ event_from_pdu_json(p, event_format_version, outlier=True)
for p in content.get("auth_chain", [])
]
@@ -650,9 +691,21 @@ class FederationClient(FederationBase):
for p in itertools.chain(state, auth_chain)
}
+ room_version = None
+ for e in state:
+ if (e.type, e.state_key) == (EventTypes.Create, ""):
+ room_version = e.content.get("room_version", RoomVersions.V1)
+ break
+
+ if room_version is None:
+ # If the state doesn't have a create event then the room is
+ # invalid, and it would fail auth checks anyway.
+ raise SynapseError(400, "No create event in state")
+
valid_pdus = yield self._check_sigs_and_hash_and_fetch(
destination, list(pdus.values()),
outlier=True,
+ room_version=room_version,
)
valid_pdus_map = {
@@ -690,32 +743,75 @@ class FederationClient(FederationBase):
@defer.inlineCallbacks
def send_invite(self, destination, room_id, event_id, pdu):
- time_now = self._clock.time_msec()
- try:
- code, content = yield self.transport_layer.send_invite(
- destination=destination,
- room_id=room_id,
- event_id=event_id,
- content=pdu.get_pdu_json(time_now),
- )
- except HttpResponseException as e:
- if e.code == 403:
- raise e.to_synapse_error()
- raise
+ room_version = yield self.store.get_room_version(room_id)
+
+ content = yield self._do_send_invite(destination, pdu, room_version)
pdu_dict = content["event"]
logger.debug("Got response to send_invite: %s", pdu_dict)
- pdu = event_from_pdu_json(pdu_dict)
+ room_version = yield self.store.get_room_version(room_id)
+ format_ver = room_version_to_event_format(room_version)
+
+ pdu = event_from_pdu_json(pdu_dict, format_ver)
# Check signatures are correct.
- pdu = yield self._check_sigs_and_hash(pdu)
+ pdu = yield self._check_sigs_and_hash(room_version, pdu)
# FIXME: We should handle signature failures more gracefully.
defer.returnValue(pdu)
+ @defer.inlineCallbacks
+ def _do_send_invite(self, destination, pdu, room_version):
+ """Actually sends the invite, first trying v2 API and falling back to
+ v1 API if necessary.
+
+ Args:
+ destination (str): Target server
+ pdu (FrozenEvent)
+ room_version (str)
+
+ Returns:
+ dict: The event as a dict as returned by the remote server
+ """
+ time_now = self._clock.time_msec()
+
+ try:
+ content = yield self.transport_layer.send_invite_v2(
+ destination=destination,
+ room_id=pdu.room_id,
+ event_id=pdu.event_id,
+ content={
+ "event": pdu.get_pdu_json(time_now),
+ "room_version": room_version,
+ "invite_room_state": pdu.unsigned.get("invite_room_state", []),
+ },
+ )
+ defer.returnValue(content)
+ except HttpResponseException as e:
+ if e.code in [400, 404]:
+ if room_version in (RoomVersions.V1, RoomVersions.V2):
+ pass # We'll fall through
+ else:
+ raise Exception("Remote server is too old")
+ elif e.code == 403:
+ raise e.to_synapse_error()
+ else:
+ raise
+
+ # Didn't work, try v1 API.
+ # Note the v1 API returns a tuple of `(200, content)`
+
+ _, content = yield self.transport_layer.send_invite_v1(
+ destination=destination,
+ room_id=pdu.room_id,
+ event_id=pdu.event_id,
+ content=pdu.get_pdu_json(time_now),
+ )
+ defer.returnValue(content)
+
def send_leave(self, destinations, pdu):
"""Sends a leave event to one of a list of homeservers.
@@ -785,13 +881,16 @@ class FederationClient(FederationBase):
content=send_content,
)
+ room_version = yield self.store.get_room_version(room_id)
+ format_ver = room_version_to_event_format(room_version)
+
auth_chain = [
- event_from_pdu_json(e)
+ event_from_pdu_json(e, format_ver)
for e in content["auth_chain"]
]
signed_auth = yield self._check_sigs_and_hash_and_fetch(
- destination, auth_chain, outlier=True
+ destination, auth_chain, outlier=True, room_version=room_version,
)
signed_auth.sort(key=lambda e: e.depth)
@@ -833,13 +932,16 @@ class FederationClient(FederationBase):
timeout=timeout,
)
+ room_version = yield self.store.get_room_version(room_id)
+ format_ver = room_version_to_event_format(room_version)
+
events = [
- event_from_pdu_json(e)
+ event_from_pdu_json(e, format_ver)
for e in content.get("events", [])
]
signed_events = yield self._check_sigs_and_hash_and_fetch(
- destination, events, outlier=False
+ destination, events, outlier=False, room_version=room_version,
)
except HttpResponseException as e:
if not e.code == 400:
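
# A sketch (not part of this patch) of the fallback logic in _do_send_invite
# above. send_v2/send_v1 stand in for the transport-layer calls and
# HttpError for HttpResponseException; the special handling of 403s is
# omitted for brevity.

class HttpError(Exception):
    def __init__(self, code):
        super(HttpError, self).__init__(code)
        self.code = code

def send_invite_sketch(send_v2, send_v1, room_version):
    try:
        return send_v2()
    except HttpError as e:
        if e.code not in (400, 404):
            raise
        # The remote probably predates the v2 invite API. Falling back is
        # only safe for v1/v2 rooms, whose events the v1 API can carry.
        if room_version not in ("1", "2"):
            raise Exception("Remote server is too old")
    return send_v1()
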
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 37d29e7027..3da86d4ba6 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -25,7 +25,7 @@ from twisted.internet import defer
from twisted.internet.abstract import isIPAddress
from twisted.python import failure
-from synapse.api.constants import EventTypes
+from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import (
AuthError,
FederationError,
@@ -34,6 +34,7 @@ from synapse.api.errors import (
SynapseError,
)
from synapse.crypto.event_signing import compute_event_signature
+from synapse.events import room_version_to_event_format
from synapse.federation.federation_base import FederationBase, event_from_pdu_json
from synapse.federation.persistence import TransactionActions
from synapse.federation.units import Edu, Transaction
@@ -147,6 +148,22 @@ class FederationServer(FederationBase):
logger.debug("[%s] Transaction is new", transaction.transaction_id)
+ # Reject the transaction if it has more than 50 PDUs or more than 100 EDUs
+ if (len(transaction.pdus) > 50
+ or (hasattr(transaction, "edus") and len(transaction.edus) > 100)):
+
+ logger.info(
+ "Transaction PDU or EDU count too large. Returning 400",
+ )
+
+ response = {}
+ yield self.transaction_actions.set_response(
+ origin,
+ transaction,
+ 400, response
+ )
+ defer.returnValue((400, response))
+
received_pdus_counter.inc(len(transaction.pdus))
origin_host, _ = parse_server_name(origin)
@@ -178,14 +195,13 @@ class FederationServer(FederationBase):
continue
try:
- # In future we will actually use the room version to parse the
- # PDU into an event.
- yield self.store.get_room_version(room_id)
+ room_version = yield self.store.get_room_version(room_id)
+ format_ver = room_version_to_event_format(room_version)
except NotFoundError:
logger.info("Ignoring PDU for unknown room_id: %s", room_id)
continue
- event = event_from_pdu_json(p)
+ event = event_from_pdu_json(p, format_ver)
pdus_by_room.setdefault(room_id, []).append(event)
pdu_results = {}
@@ -322,7 +338,7 @@ class FederationServer(FederationBase):
if self.hs.is_mine_id(event.event_id):
event.signatures.update(
compute_event_signature(
- event,
+ event.get_pdu_json(),
self.hs.hostname,
self.hs.config.signing_key[0]
)
@@ -370,7 +386,9 @@ class FederationServer(FederationBase):
@defer.inlineCallbacks
def on_invite_request(self, origin, content, room_version):
- pdu = event_from_pdu_json(content)
+ format_ver = room_version_to_event_format(room_version)
+
+ pdu = event_from_pdu_json(content, format_ver)
origin_host, _ = parse_server_name(origin)
yield self.check_server_matches_acl(origin_host, pdu.room_id)
ret_pdu = yield self.handler.on_invite_request(origin, pdu)
@@ -378,9 +396,12 @@ class FederationServer(FederationBase):
defer.returnValue({"event": ret_pdu.get_pdu_json(time_now)})
@defer.inlineCallbacks
- def on_send_join_request(self, origin, content):
+ def on_send_join_request(self, origin, content, room_id):
logger.debug("on_send_join_request: content: %s", content)
- pdu = event_from_pdu_json(content)
+
+ room_version = yield self.store.get_room_version(room_id)
+ format_ver = room_version_to_event_format(room_version)
+ pdu = event_from_pdu_json(content, format_ver)
origin_host, _ = parse_server_name(origin)
yield self.check_server_matches_acl(origin_host, pdu.room_id)
@@ -400,13 +421,22 @@ class FederationServer(FederationBase):
origin_host, _ = parse_server_name(origin)
yield self.check_server_matches_acl(origin_host, room_id)
pdu = yield self.handler.on_make_leave_request(room_id, user_id)
+
+ room_version = yield self.store.get_room_version(room_id)
+
time_now = self._clock.time_msec()
- defer.returnValue({"event": pdu.get_pdu_json(time_now)})
+ defer.returnValue({
+ "event": pdu.get_pdu_json(time_now),
+ "room_version": room_version,
+ })
@defer.inlineCallbacks
- def on_send_leave_request(self, origin, content):
+ def on_send_leave_request(self, origin, content, room_id):
logger.debug("on_send_leave_request: content: %s", content)
- pdu = event_from_pdu_json(content)
+
+ room_version = yield self.store.get_room_version(room_id)
+ format_ver = room_version_to_event_format(room_version)
+ pdu = event_from_pdu_json(content, format_ver)
origin_host, _ = parse_server_name(origin)
yield self.check_server_matches_acl(origin_host, pdu.room_id)
@@ -452,13 +482,16 @@ class FederationServer(FederationBase):
origin_host, _ = parse_server_name(origin)
yield self.check_server_matches_acl(origin_host, room_id)
+ room_version = yield self.store.get_room_version(room_id)
+ format_ver = room_version_to_event_format(room_version)
+
auth_chain = [
- event_from_pdu_json(e)
+ event_from_pdu_json(e, format_ver)
for e in content["auth_chain"]
]
signed_auth = yield self._check_sigs_and_hash_and_fetch(
- origin, auth_chain, outlier=True
+ origin, auth_chain, outlier=True, room_version=room_version,
)
ret = yield self.handler.on_query_auth(
@@ -603,16 +636,19 @@ class FederationServer(FederationBase):
"""
# check that it's actually being sent from a valid destination to
# work around bug #1753 in 0.18.5 and 0.18.6
- if origin != get_domain_from_id(pdu.event_id):
+ if origin != get_domain_from_id(pdu.sender):
# We continue to accept join events from any server; this is
# necessary for the federation join dance to work correctly.
# (When we join over federation, the "helper" server is
# responsible for sending out the join event, rather than the
- # origin. See bug #1893).
+ # origin. See bug #1893. This is also true for some third party
+ # invites).
if not (
pdu.type == 'm.room.member' and
pdu.content and
- pdu.content.get("membership", None) == 'join'
+ pdu.content.get("membership", None) in (
+ Membership.JOIN, Membership.INVITE,
+ )
):
logger.info(
"Discarding PDU %s from invalid origin %s",
@@ -625,9 +661,12 @@ class FederationServer(FederationBase):
pdu.event_id, origin
)
+ # We've already checked that we know the room version by this point
+ room_version = yield self.store.get_room_version(pdu.room_id)
+
# Check signature.
try:
- pdu = yield self._check_sigs_and_hash(pdu)
+ pdu = yield self._check_sigs_and_hash(room_version, pdu)
except SynapseError as e:
raise FederationError(
"ERROR",
diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py
index fe787abaeb..1f0b67f5f8 100644
--- a/synapse/federation/transaction_queue.py
+++ b/synapse/federation/transaction_queue.py
@@ -175,7 +175,7 @@ class TransactionQueue(object):
def handle_event(event):
# Only send events for this server.
send_on_behalf_of = event.internal_metadata.get_send_on_behalf_of()
- is_mine = self.is_mine_id(event.event_id)
+ is_mine = self.is_mine_id(event.sender)
if not is_mine and send_on_behalf_of is None:
return
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 260178c47b..8e2be218e2 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -21,7 +21,7 @@ from six.moves import urllib
from twisted.internet import defer
from synapse.api.constants import Membership
-from synapse.api.urls import FEDERATION_V1_PREFIX
+from synapse.api.urls import FEDERATION_V1_PREFIX, FEDERATION_V2_PREFIX
from synapse.util.logutils import log_function
logger = logging.getLogger(__name__)
@@ -289,7 +289,7 @@ class TransportLayerClient(object):
@defer.inlineCallbacks
@log_function
- def send_invite(self, destination, room_id, event_id, content):
+ def send_invite_v1(self, destination, room_id, event_id, content):
path = _create_v1_path("/invite/%s/%s", room_id, event_id)
response = yield self.client.put_json(
@@ -303,6 +303,20 @@ class TransportLayerClient(object):
@defer.inlineCallbacks
@log_function
+ def send_invite_v2(self, destination, room_id, event_id, content):
+ path = _create_v2_path("/invite/%s/%s", room_id, event_id)
+
+ response = yield self.client.put_json(
+ destination=destination,
+ path=path,
+ data=content,
+ ignore_backoff=True,
+ )
+
+ defer.returnValue(response)
+
+ @defer.inlineCallbacks
+ @log_function
def get_public_rooms(self, remote_server, limit, since_token,
search_filter=None, include_all_networks=False,
third_party_instance_id=None):
@@ -958,3 +972,24 @@ def _create_v1_path(path, *args):
FEDERATION_V1_PREFIX
+ path % tuple(urllib.parse.quote(arg, "") for arg in args)
)
+
+
+def _create_v2_path(path, *args):
+ """Creates a path against V2 federation API from the path template and
+ args. Ensures that all args are url encoded.
+
+ Example:
+
+ _create_v2_path("/event/%s/", event_id)
+
+ Args:
+ path (str): String template for the path
+ args ([str]): Args to insert into path. Each arg will be url encoded
+
+ Returns:
+ str
+ """
+ return (
+ FEDERATION_V2_PREFIX
+ + path % tuple(urllib.parse.quote(arg, "") for arg in args)
+ )
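
# A runnable sketch (not part of this patch) of the v2 path helper above;
# the prefix value is assumed from synapse.api.urls.

try:
    from urllib.parse import quote  # Python 3
except ImportError:
    from urllib import quote  # Python 2

FEDERATION_V2_PREFIX = "/_matrix/federation/v2"  # assumed value

def create_v2_path_sketch(path, *args):
    # Quote every arg with no "safe" characters, so "/" and ":" are escaped.
    return FEDERATION_V2_PREFIX + path % tuple(quote(arg, "") for arg in args)

# create_v2_path_sketch("/invite/%s/%s", "!room:a", "$ev")
#   -> "/_matrix/federation/v2/invite/%21room%3Aa/%24ev"
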
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 9c1c9c39e1..a2396ab466 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -481,7 +481,7 @@ class FederationSendLeaveServlet(BaseFederationServlet):
@defer.inlineCallbacks
def on_PUT(self, origin, content, query, room_id, event_id):
- content = yield self.handler.on_send_leave_request(origin, content)
+ content = yield self.handler.on_send_leave_request(origin, content, room_id)
defer.returnValue((200, content))
@@ -499,7 +499,7 @@ class FederationSendJoinServlet(BaseFederationServlet):
def on_PUT(self, origin, content, query, context, event_id):
# TODO(paul): assert that context/event_id parsed from path actually
# match those given in content
- content = yield self.handler.on_send_join_request(origin, content)
+ content = yield self.handler.on_send_join_request(origin, content, context)
defer.returnValue((200, content))
diff --git a/synapse/handlers/acme.py b/synapse/handlers/acme.py
new file mode 100644
index 0000000000..dd0b217965
--- /dev/null
+++ b/synapse/handlers/acme.py
@@ -0,0 +1,150 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+import attr
+from zope.interface import implementer
+
+import twisted
+import twisted.internet.error
+from twisted.internet import defer
+from twisted.python.filepath import FilePath
+from twisted.python.url import URL
+from twisted.web import server, static
+from twisted.web.resource import Resource
+
+from synapse.app import check_bind_error
+
+logger = logging.getLogger(__name__)
+
+try:
+ from txacme.interfaces import ICertificateStore
+
+ @attr.s
+ @implementer(ICertificateStore)
+ class ErsatzStore(object):
+ """
+ A store that only stores in memory.
+ """
+
+ certs = attr.ib(default=attr.Factory(dict))
+
+ def store(self, server_name, pem_objects):
+ self.certs[server_name] = [o.as_bytes() for o in pem_objects]
+ return defer.succeed(None)
+
+
+except ImportError:
+ # txacme is missing
+ pass
+
+
+class AcmeHandler(object):
+ def __init__(self, hs):
+ self.hs = hs
+ self.reactor = hs.get_reactor()
+
+ @defer.inlineCallbacks
+ def start_listening(self):
+
+ # Configure logging for txacme, if you need to debug
+ # from eliot import add_destinations
+ # from eliot.twisted import TwistedDestination
+ #
+ # add_destinations(TwistedDestination())
+
+ from txacme.challenges import HTTP01Responder
+ from txacme.service import AcmeIssuingService
+ from txacme.endpoint import load_or_create_client_key
+ from txacme.client import Client
+ from josepy.jwa import RS256
+
+ self._store = ErsatzStore()
+ responder = HTTP01Responder()
+
+ self._issuer = AcmeIssuingService(
+ cert_store=self._store,
+ client_creator=(
+ lambda: Client.from_url(
+ reactor=self.reactor,
+ url=URL.from_text(self.hs.config.acme_url),
+ key=load_or_create_client_key(
+ FilePath(self.hs.config.config_dir_path)
+ ),
+ alg=RS256,
+ )
+ ),
+ clock=self.reactor,
+ responders=[responder],
+ )
+
+ well_known = Resource()
+ well_known.putChild(b'acme-challenge', responder.resource)
+ responder_resource = Resource()
+ responder_resource.putChild(b'.well-known', well_known)
+ responder_resource.putChild(b'check', static.Data(b'OK', b'text/plain'))
+
+ srv = server.Site(responder_resource)
+
+ bind_addresses = self.hs.config.acme_bind_addresses
+ for host in bind_addresses:
+ logger.info(
+ "Listening for ACME requests on %s:%i", host, self.hs.config.acme_port,
+ )
+ try:
+ self.reactor.listenTCP(
+ self.hs.config.acme_port,
+ srv,
+ interface=host,
+ )
+ except twisted.internet.error.CannotListenError as e:
+ check_bind_error(e, host, bind_addresses)
+
+ # Make sure we are registered with the ACME server. There's no public API
+ # for this; it is usually triggered by startService, but since we don't
+ # want it to control where we save the certificates, we have to reach in
+ # and trigger the registration machinery ourselves.
+ self._issuer._registered = False
+ yield self._issuer._ensure_registered()
+
+ @defer.inlineCallbacks
+ def provision_certificate(self):
+
+ logger.warning("Reprovisioning %s", self.hs.hostname)
+
+ try:
+ yield self._issuer.issue_cert(self.hs.hostname)
+ except Exception:
+ logger.exception("Fail!")
+ raise
+ logger.warning("Reprovisioned %s, saving.", self.hs.hostname)
+ cert_chain = self._store.certs[self.hs.hostname]
+
+ try:
+ with open(self.hs.config.tls_private_key_file, "wb") as private_key_file:
+ for x in cert_chain:
+ if x.startswith(b"-----BEGIN RSA PRIVATE KEY-----"):
+ private_key_file.write(x)
+
+ with open(self.hs.config.tls_certificate_file, "wb") as certificate_file:
+ for x in cert_chain:
+ if x.startswith(b"-----BEGIN CERTIFICATE-----"):
+ certificate_file.write(x)
+ except Exception:
+ logger.exception("Failed saving!")
+ raise
+
+ defer.returnValue(True)
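
# A standalone sketch (not part of this patch) of how provision_certificate
# splits the PEM chain returned by txacme into key and certificate material;
# these are the same PEM markers the handler matches on.

KEY_MARKER = b"-----BEGIN RSA PRIVATE KEY-----"
CERT_MARKER = b"-----BEGIN CERTIFICATE-----"

def split_pem_chain_sketch(cert_chain):
    """cert_chain: list of PEM blocks as bytes -> (keys, certs)."""
    keys = [blob for blob in cert_chain if blob.startswith(KEY_MARKER)]
    certs = [blob for blob in cert_chain if blob.startswith(CERT_MARKER)]
    return keys, certs
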
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index 0699731c13..6bb254f899 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -57,8 +57,8 @@ class DirectoryHandler(BaseHandler):
# general association creation for both human users and app services
for wchar in string.whitespace:
- if wchar in room_alias.localpart:
- raise SynapseError(400, "Invalid characters in room alias")
+ if wchar in room_alias.localpart:
+ raise SynapseError(400, "Invalid characters in room alias")
if not self.hs.is_mine(room_alias):
raise SynapseError(400, "Room alias must be local")
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index a3bb864bb2..083f2e0ac3 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -34,6 +34,7 @@ from synapse.api.constants import (
EventTypes,
Membership,
RejectedReason,
+ RoomVersions,
)
from synapse.api.errors import (
AuthError,
@@ -43,10 +44,7 @@ from synapse.api.errors import (
StoreError,
SynapseError,
)
-from synapse.crypto.event_signing import (
- add_hashes_and_signatures,
- compute_event_signature,
-)
+from synapse.crypto.event_signing import compute_event_signature
from synapse.events.validator import EventValidator
from synapse.replication.http.federation import (
ReplicationCleanRoomRestServlet,
@@ -58,7 +56,6 @@ from synapse.types import UserID, get_domain_from_id
from synapse.util import logcontext, unwrapFirstError
from synapse.util.async_helpers import Linearizer
from synapse.util.distributor import user_joined_room
-from synapse.util.frozenutils import unfreeze
from synapse.util.logutils import log_function
from synapse.util.retryutils import NotRetryingDestination
from synapse.visibility import filter_events_for_server
@@ -105,7 +102,7 @@ class FederationHandler(BaseHandler):
self.hs = hs
- self.store = hs.get_datastore() # type: synapse.storage.DataStore
+ self.store = hs.get_datastore()
self.federation_client = hs.get_federation_client()
self.state_handler = hs.get_state_handler()
self.server_name = hs.hostname
@@ -342,6 +339,8 @@ class FederationHandler(BaseHandler):
room_id, event_id, p,
)
+ room_version = yield self.store.get_room_version(room_id)
+
with logcontext.nested_logging_context(p):
# note that if any of the missing prevs share missing state or
# auth events, the requests to fetch those events are deduped
@@ -355,7 +354,7 @@ class FederationHandler(BaseHandler):
# we want the state *after* p; get_state_for_room returns the
# state *before* p.
remote_event = yield self.federation_client.get_pdu(
- [origin], p, outlier=True,
+ [origin], p, room_version, outlier=True,
)
if remote_event is None:
@@ -379,7 +378,6 @@ class FederationHandler(BaseHandler):
for x in remote_state:
event_map[x.event_id] = x
- room_version = yield self.store.get_room_version(room_id)
state_map = yield resolve_events_with_store(
room_version, state_maps, event_map,
state_res_store=StateResolutionStore(self.store),
@@ -655,6 +653,8 @@ class FederationHandler(BaseHandler):
if dest == self.server_name:
raise SynapseError(400, "Can't backfill from self.")
+ room_version = yield self.store.get_room_version(room_id)
+
events = yield self.federation_client.backfill(
dest,
room_id,
@@ -748,6 +748,7 @@ class FederationHandler(BaseHandler):
self.federation_client.get_pdu,
[dest],
event_id,
+ room_version=room_version,
outlier=True,
timeout=10000,
)
@@ -1060,7 +1061,7 @@ class FederationHandler(BaseHandler):
"""
logger.debug("Joining %s to %s", joinee, room_id)
- origin, event = yield self._make_and_verify_event(
+ origin, event, event_format_version = yield self._make_and_verify_event(
target_hosts,
room_id,
joinee,
@@ -1083,7 +1084,6 @@ class FederationHandler(BaseHandler):
handled_events = set()
try:
- event = self._sign_event(event)
# Try the host we successfully got a response to /make_join/
# request first.
try:
@@ -1091,7 +1091,9 @@ class FederationHandler(BaseHandler):
target_hosts.insert(0, origin)
except ValueError:
pass
- ret = yield self.federation_client.send_join(target_hosts, event)
+ ret = yield self.federation_client.send_join(
+ target_hosts, event, event_format_version,
+ )
origin = ret["origin"]
state = ret["state"]
@@ -1164,13 +1166,18 @@ class FederationHandler(BaseHandler):
"""
event_content = {"membership": Membership.JOIN}
- builder = self.event_builder_factory.new({
- "type": EventTypes.Member,
- "content": event_content,
- "room_id": room_id,
- "sender": user_id,
- "state_key": user_id,
- })
+ room_version = yield self.store.get_room_version(room_id)
+
+ builder = self.event_builder_factory.new(
+ room_version,
+ {
+ "type": EventTypes.Member,
+ "content": event_content,
+ "room_id": room_id,
+ "sender": user_id,
+ "state_key": user_id,
+ }
+ )
try:
event, context = yield self.event_creation_handler.create_new_client_event(
@@ -1182,7 +1189,9 @@ class FederationHandler(BaseHandler):
# The remote hasn't signed it yet, obviously. We'll do the full checks
# when we get the event back in `on_send_join_request`
- yield self.auth.check_from_context(event, context, do_sig_check=False)
+ yield self.auth.check_from_context(
+ room_version, event, context, do_sig_check=False,
+ )
defer.returnValue(event)
@@ -1287,11 +1296,11 @@ class FederationHandler(BaseHandler):
)
event.internal_metadata.outlier = True
- event.internal_metadata.invite_from_remote = True
+ event.internal_metadata.out_of_band_membership = True
event.signatures.update(
compute_event_signature(
- event,
+ event.get_pdu_json(),
self.hs.hostname,
self.hs.config.signing_key[0]
)
@@ -1304,7 +1313,7 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks
def do_remotely_reject_invite(self, target_hosts, room_id, user_id):
- origin, event = yield self._make_and_verify_event(
+ origin, event, event_format_version = yield self._make_and_verify_event(
target_hosts,
room_id,
user_id,
@@ -1313,7 +1322,7 @@ class FederationHandler(BaseHandler):
# Mark as outlier as we don't have any state for this event; we're not
# even in the room.
event.internal_metadata.outlier = True
- event = self._sign_event(event)
+ event.internal_metadata.out_of_band_membership = True
# Try the host that we successfully called /make_leave/ on first for
# the /send_leave/ request.
@@ -1336,7 +1345,7 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks
def _make_and_verify_event(self, target_hosts, room_id, user_id, membership,
content={}, params=None):
- origin, pdu = yield self.federation_client.make_membership_event(
+ origin, event, format_ver = yield self.federation_client.make_membership_event(
target_hosts,
room_id,
user_id,
@@ -1345,9 +1354,7 @@ class FederationHandler(BaseHandler):
params=params,
)
- logger.debug("Got response to make_%s: %s", membership, pdu)
-
- event = pdu
+ logger.debug("Got response to make_%s: %s", membership, event)
# We should assert some things.
# FIXME: Do this in a nicer way
@@ -1355,28 +1362,7 @@ class FederationHandler(BaseHandler):
assert(event.user_id == user_id)
assert(event.state_key == user_id)
assert(event.room_id == room_id)
- defer.returnValue((origin, event))
-
- def _sign_event(self, event):
- event.internal_metadata.outlier = False
-
- builder = self.event_builder_factory.new(
- unfreeze(event.get_pdu_json())
- )
-
- builder.event_id = self.event_builder_factory.create_event_id()
- builder.origin = self.hs.hostname
-
- if not hasattr(event, "signatures"):
- builder.signatures = {}
-
- add_hashes_and_signatures(
- builder,
- self.hs.hostname,
- self.hs.config.signing_key[0],
- )
-
- return builder.build()
+ defer.returnValue((origin, event, format_ver))
@defer.inlineCallbacks
@log_function
@@ -1385,13 +1371,17 @@ class FederationHandler(BaseHandler):
leave event for the room and return that. We do *not* persist or
process it until the other server has signed it and sent it back.
"""
- builder = self.event_builder_factory.new({
- "type": EventTypes.Member,
- "content": {"membership": Membership.LEAVE},
- "room_id": room_id,
- "sender": user_id,
- "state_key": user_id,
- })
+ room_version = yield self.store.get_room_version(room_id)
+ builder = self.event_builder_factory.new(
+ room_version,
+ {
+ "type": EventTypes.Member,
+ "content": {"membership": Membership.LEAVE},
+ "room_id": room_id,
+ "sender": user_id,
+ "state_key": user_id,
+ }
+ )
event, context = yield self.event_creation_handler.create_new_client_event(
builder=builder,
@@ -1400,7 +1390,9 @@ class FederationHandler(BaseHandler):
try:
# The remote hasn't signed it yet, obviously. We'll do the full checks
# when we get the event back in `on_send_leave_request`
- yield self.auth.check_from_context(event, context, do_sig_check=False)
+ yield self.auth.check_from_context(
+ room_version, event, context, do_sig_check=False,
+ )
except AuthError as e:
logger.warn("Failed to create new leave %r because %s", event, e)
raise e
@@ -1659,6 +1651,13 @@ class FederationHandler(BaseHandler):
create_event = e
break
+ if create_event is None:
+ # If the state doesn't have a create event then the room is
+ # invalid, and it would fail auth checks anyway.
+ raise SynapseError(400, "No create event in state")
+
+ room_version = create_event.content.get("room_version", RoomVersions.V1)
+
missing_auth_events = set()
for e in itertools.chain(auth_events, state, [event]):
for e_id in e.auth_event_ids():
@@ -1669,6 +1668,7 @@ class FederationHandler(BaseHandler):
m_ev = yield self.federation_client.get_pdu(
[origin],
e_id,
+ room_version=room_version,
outlier=True,
timeout=10000,
)
@@ -1687,7 +1687,7 @@ class FederationHandler(BaseHandler):
auth_for_e[(EventTypes.Create, "")] = create_event
try:
- self.auth.check(e, auth_events=auth_for_e)
+ self.auth.check(room_version, e, auth_events=auth_for_e)
except SynapseError as err:
# we may get SynapseErrors here as well as AuthErrors. For
# instance, there are a couple of (ancient) events in some
@@ -1931,6 +1931,8 @@ class FederationHandler(BaseHandler):
current_state = set(e.event_id for e in auth_events.values())
different_auth = event_auth_events - current_state
+ room_version = yield self.store.get_room_version(event.room_id)
+
if different_auth and not event.internal_metadata.is_outlier():
# Do auth conflict res.
logger.info("Different auth: %s", different_auth)
@@ -1955,8 +1957,6 @@ class FederationHandler(BaseHandler):
(d.type, d.state_key): d for d in different_events if d
})
- room_version = yield self.store.get_room_version(event.room_id)
-
new_state = yield self.state_handler.resolve_events(
room_version,
[list(local_view.values()), list(remote_view.values())],
@@ -2056,7 +2056,7 @@ class FederationHandler(BaseHandler):
)
try:
- self.auth.check(event, auth_events=auth_events)
+ self.auth.check(room_version, event, auth_events=auth_events)
except AuthError as e:
logger.warn("Failed auth resolution for %r because %s", event, e)
raise e
@@ -2279,18 +2279,26 @@ class FederationHandler(BaseHandler):
}
if (yield self.auth.check_host_in_room(room_id, self.hs.hostname)):
- builder = self.event_builder_factory.new(event_dict)
- EventValidator().validate_new(builder)
+ room_version = yield self.store.get_room_version(room_id)
+ builder = self.event_builder_factory.new(room_version, event_dict)
+
+ EventValidator().validate_builder(builder)
event, context = yield self.event_creation_handler.create_new_client_event(
builder=builder
)
event, context = yield self.add_display_name_to_third_party_invite(
- event_dict, event, context
+ room_version, event_dict, event, context
)
+ EventValidator().validate_new(event)
+
+ # We need to tell the transaction queue to send this out, even
+ # though the sender isn't a local user.
+ event.internal_metadata.send_on_behalf_of = self.hs.hostname
+
try:
- yield self.auth.check_from_context(event, context)
+ yield self.auth.check_from_context(room_version, event, context)
except AuthError as e:
logger.warn("Denying new third party invite %r because %s", event, e)
raise e
@@ -2317,23 +2325,31 @@ class FederationHandler(BaseHandler):
Returns:
Deferred: resolves (to None)
"""
- builder = self.event_builder_factory.new(event_dict)
+ room_version = yield self.store.get_room_version(room_id)
+
+ # NB: event_dict has a particular specced format we might need to fudge
+ # if we change event formats too much.
+ builder = self.event_builder_factory.new(room_version, event_dict)
event, context = yield self.event_creation_handler.create_new_client_event(
builder=builder,
)
event, context = yield self.add_display_name_to_third_party_invite(
- event_dict, event, context
+ room_version, event_dict, event, context
)
try:
- self.auth.check_from_context(event, context)
+ self.auth.check_from_context(room_version, event, context)
except AuthError as e:
logger.warn("Denying third party invite %r because %s", event, e)
raise e
yield self._check_signature(event, context)
+ # We need to tell the transaction queue to send this out, even
+ # though the sender isn't a local user.
+ event.internal_metadata.send_on_behalf_of = get_domain_from_id(event.sender)
+
# XXX we send the invite here, but send_membership_event also sends it,
# so we end up making two requests. I think this is redundant.
returned_invite = yield self.send_invite(origin, event)
@@ -2344,7 +2360,8 @@ class FederationHandler(BaseHandler):
yield member_handler.send_membership_event(None, event, context)
@defer.inlineCallbacks
- def add_display_name_to_third_party_invite(self, event_dict, event, context):
+ def add_display_name_to_third_party_invite(self, room_version, event_dict,
+ event, context):
key = (
EventTypes.ThirdPartyInvite,
event.content["third_party_invite"]["signed"]["token"]
@@ -2368,11 +2385,12 @@ class FederationHandler(BaseHandler):
# auth checks. If we need the invite and don't have it then the
# auth check code will explode appropriately.
- builder = self.event_builder_factory.new(event_dict)
- EventValidator().validate_new(builder)
+ builder = self.event_builder_factory.new(room_version, event_dict)
+ EventValidator().validate_builder(builder)
event, context = yield self.event_creation_handler.create_new_client_event(
builder=builder,
)
+ EventValidator().validate_new(event)
defer.returnValue((event, context))
@defer.inlineCallbacks
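The thread running through this whole file is that every auth check now needs the room version up front. For rooms we are not yet joined to, it is read from the create event in the supplied state, falling back to v1 when the event predates explicit versioning. A self-contained sketch of that fallback (RoomVersions.V1 is the string "1" at this point in the codebase, hard-coded here only to keep the example standalone):

def room_version_from_create_event(create_event_content):
    # Create events written before versioned rooms existed carry no
    # "room_version" key, so the absence of the key means a v1 room.
    return create_event_content.get("room_version", "1")

assert room_version_from_create_event({}) == "1"
assert room_version_from_create_event({"room_version": "3"}) == "3"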
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index a7cd779b02..3981fe69ce 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -22,7 +22,7 @@ from canonicaljson import encode_canonical_json, json
from twisted.internet import defer
from twisted.internet.defer import succeed
-from synapse.api.constants import MAX_DEPTH, EventTypes, Membership
+from synapse.api.constants import EventTypes, Membership, RoomVersions
from synapse.api.errors import (
AuthError,
Codes,
@@ -31,7 +31,6 @@ from synapse.api.errors import (
SynapseError,
)
from synapse.api.urls import ConsentURIBuilder
-from synapse.crypto.event_signing import add_hashes_and_signatures
from synapse.events.utils import serialize_event
from synapse.events.validator import EventValidator
from synapse.replication.http.send_event import ReplicationSendEventRestServlet
@@ -278,9 +277,17 @@ class EventCreationHandler(object):
"""
yield self.auth.check_auth_blocking(requester.user.to_string())
- builder = self.event_builder_factory.new(event_dict)
+ if event_dict["type"] == EventTypes.Create and event_dict["state_key"] == "":
+ room_version = event_dict["content"]["room_version"]
+ else:
+ try:
+ room_version = yield self.store.get_room_version(event_dict["room_id"])
+ except NotFoundError:
+ raise AuthError(403, "Unknown room")
- self.validator.validate_new(builder)
+ builder = self.event_builder_factory.new(room_version, event_dict)
+
+ self.validator.validate_builder(builder)
if builder.type == EventTypes.Member:
membership = builder.content.get("membership", None)
@@ -318,6 +325,8 @@ class EventCreationHandler(object):
prev_events_and_hashes=prev_events_and_hashes,
)
+ self.validator.validate_new(event)
+
defer.returnValue((event, context))
def _is_exempt_from_privacy_policy(self, builder, requester):
@@ -535,40 +544,19 @@ class EventCreationHandler(object):
prev_events_and_hashes = \
yield self.store.get_prev_events_for_room(builder.room_id)
- if prev_events_and_hashes:
- depth = max([d for _, _, d in prev_events_and_hashes]) + 1
- # we cap depth of generated events, to ensure that they are not
- # rejected by other servers (and so that they can be persisted in
- # the db)
- depth = min(depth, MAX_DEPTH)
- else:
- depth = 1
-
prev_events = [
(event_id, prev_hashes)
for event_id, prev_hashes, _ in prev_events_and_hashes
]
- builder.prev_events = prev_events
- builder.depth = depth
-
- context = yield self.state.compute_event_context(builder)
+ event = yield builder.build(
+ prev_event_ids=[p for p, _ in prev_events],
+ )
+ context = yield self.state.compute_event_context(event)
if requester:
context.app_service = requester.app_service
- if builder.is_state():
- builder.prev_state = yield self.store.add_event_hashes(
- context.prev_state_events
- )
-
- yield self.auth.add_auth_events(builder, context)
-
- signing_key = self.hs.config.signing_key[0]
- add_hashes_and_signatures(
- builder, self.server_name, signing_key
- )
-
- event = builder.build()
+ self.validator.validate_new(event)
logger.debug(
"Created event %s",
@@ -603,8 +591,13 @@ class EventCreationHandler(object):
extra_users (list(UserID)): Any extra users to notify about event
"""
+ if event.is_state() and (event.type, event.state_key) == (EventTypes.Create, ""):
+ room_version = event.content.get("room_version", RoomVersions.V1)
+ else:
+ room_version = yield self.store.get_room_version(event.room_id)
+
try:
- yield self.auth.check_from_context(event, context)
+ yield self.auth.check_from_context(room_version, event, context)
except AuthError as err:
logger.warn("Denying new event %r because %s", event, err)
raise err
@@ -752,7 +745,8 @@ class EventCreationHandler(object):
auth_events = {
(e.type, e.state_key): e for e in auth_events.values()
}
- if self.auth.check_redaction(event, auth_events=auth_events):
+ room_version = yield self.store.get_room_version(event.room_id)
+ if self.auth.check_redaction(room_version, event, auth_events=auth_events):
original_event = yield self.store.get_event(
event.redacts,
check_redacted=False,
@@ -766,6 +760,9 @@ class EventCreationHandler(object):
"You don't have permission to redact events"
)
+ # We've already checked.
+ event.internal_metadata.recheck_redaction = False
+
if event.type == EventTypes.Create:
prev_state_ids = yield context.get_prev_state_ids(self.store)
if prev_state_ids:
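The hunk above encodes the one place where the room version cannot come from the store: an m.room.create event defines the version of a room that does not exist yet. A sketch of that rule, with a plain dict standing in for the storage layer and the v1 default from the later hunk (names are illustrative):

def choose_room_version(event_dict, known_room_versions):
    if event_dict["type"] == "m.room.create" and event_dict.get("state_key") == "":
        # The create event carries its own room version.
        return event_dict["content"].get("room_version", "1")
    try:
        return known_room_versions[event_dict["room_id"]]
    except KeyError:
        raise ValueError("Unknown room %s" % event_dict["room_id"])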
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index cb8c5f77dd..5e40e9ea46 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -123,9 +123,12 @@ class RoomCreationHandler(BaseHandler):
token_id=requester.access_token_id,
)
)
- yield self.auth.check_from_context(tombstone_event, tombstone_context)
+ old_room_version = yield self.store.get_room_version(old_room_id)
+ yield self.auth.check_from_context(
+ old_room_version, tombstone_event, tombstone_context,
+ )
- yield self.clone_exiting_room(
+ yield self.clone_existing_room(
requester,
old_room_id=old_room_id,
new_room_id=new_room_id,
@@ -230,7 +233,7 @@ class RoomCreationHandler(BaseHandler):
)
@defer.inlineCallbacks
- def clone_exiting_room(
+ def clone_existing_room(
self, requester, old_room_id, new_room_id, new_room_version,
tombstone_event_id,
):
@@ -260,8 +263,19 @@ class RoomCreationHandler(BaseHandler):
}
}
+ # Check if old room was non-federatable
+
+ # Get old room's create event
+ old_room_create_event = yield self.store.get_create_event_for_room(old_room_id)
+
+ # Check if the create event specified a non-federatable room
+ if not old_room_create_event.content.get("m.federate", True):
+ # If so, mark the new room as non-federatable as well
+ creation_content["m.federate"] = False
+
initial_state = dict()
+ # Replicate relevant room events
types_to_copy = (
(EventTypes.JoinRules, ""),
(EventTypes.Name, ""),
diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
index dc88620885..13e212d669 100644
--- a/synapse/handlers/room_list.py
+++ b/synapse/handlers/room_list.py
@@ -73,8 +73,14 @@ class RoomListHandler(BaseHandler):
# We explicitly don't bother caching searches or requests for
# appservice specific lists.
logger.info("Bypassing cache as search request.")
+
+ # XXX: Quick hack to stop room directory queries taking too long.
+ # Timeout request after 60s. Probably want a more fundamental
+ # solution at some point
+ timeout = self.clock.time() + 60
return self._get_public_room_list(
- limit, since_token, search_filter, network_tuple=network_tuple,
+ limit, since_token, search_filter,
+ network_tuple=network_tuple, timeout=timeout,
)
key = (limit, since_token, network_tuple)
@@ -87,7 +93,8 @@ class RoomListHandler(BaseHandler):
@defer.inlineCallbacks
def _get_public_room_list(self, limit=None, since_token=None,
search_filter=None,
- network_tuple=EMPTY_THIRD_PARTY_ID,):
+ network_tuple=EMPTY_THIRD_PARTY_ID,
+ timeout=None,):
if since_token and since_token != "END":
since_token = RoomListNextBatch.from_token(since_token)
else:
@@ -202,6 +209,9 @@ class RoomListHandler(BaseHandler):
chunk = []
for i in range(0, len(rooms_to_scan), step):
+ if timeout and self.clock.time() > timeout:
+ raise Exception("Timed out searching room directory")
+
batch = rooms_to_scan[i:i + step]
logger.info("Processing %i rooms for result", len(batch))
yield concurrently_execute(
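The timeout above is an absolute deadline computed once, then checked at the top of each batch, so a slow directory query fails between batches rather than hanging indefinitely. A standalone sketch of the pattern (synchronous, with time.time() standing in for self.clock):

import time

def process_in_batches(rooms_to_scan, handle_batch, step=10, timeout_s=60):
    deadline = time.time() + timeout_s
    for i in range(0, len(rooms_to_scan), step):
        if time.time() > deadline:
            raise Exception("Timed out searching room directory")
        handle_batch(rooms_to_scan[i:i + step])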
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 07fd3e82fc..2beffdf41e 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -63,7 +63,7 @@ class RoomMemberHandler(object):
self.directory_handler = hs.get_handlers().directory_handler
self.registration_handler = hs.get_handlers().registration_handler
self.profile_handler = hs.get_profile_handler()
- self.event_creation_hander = hs.get_event_creation_handler()
+ self.event_creation_handler = hs.get_event_creation_handler()
self.member_linearizer = Linearizer(name="member")
@@ -161,6 +161,8 @@ class RoomMemberHandler(object):
ratelimit=True,
content=None,
):
+ user_id = target.to_string()
+
if content is None:
content = {}
@@ -168,14 +170,14 @@ class RoomMemberHandler(object):
if requester.is_guest:
content["kind"] = "guest"
- event, context = yield self.event_creation_hander.create_event(
+ event, context = yield self.event_creation_handler.create_event(
requester,
{
"type": EventTypes.Member,
"content": content,
"room_id": room_id,
"sender": requester.user.to_string(),
- "state_key": target.to_string(),
+ "state_key": user_id,
# For backwards compatibility:
"membership": membership,
@@ -186,14 +188,14 @@ class RoomMemberHandler(object):
)
# Check if this event matches the previous membership event for the user.
- duplicate = yield self.event_creation_hander.deduplicate_state_event(
+ duplicate = yield self.event_creation_handler.deduplicate_state_event(
event, context,
)
if duplicate is not None:
# Discard the new event since this membership change is a no-op.
defer.returnValue(duplicate)
- yield self.event_creation_hander.handle_new_client_event(
+ yield self.event_creation_handler.handle_new_client_event(
requester,
event,
context,
@@ -204,12 +206,12 @@ class RoomMemberHandler(object):
prev_state_ids = yield context.get_prev_state_ids(self.store)
prev_member_event_id = prev_state_ids.get(
- (EventTypes.Member, target.to_string()),
+ (EventTypes.Member, user_id),
None
)
if event.membership == Membership.JOIN:
- # Only fire user_joined_room if the user has acutally joined the
+ # Only fire user_joined_room if the user has actually joined the
# room. Don't bother if the user is just changing their profile
# info.
newly_joined = True
@@ -218,6 +220,18 @@ class RoomMemberHandler(object):
newly_joined = prev_member_event.membership != Membership.JOIN
if newly_joined:
yield self._user_joined_room(target, room_id)
+
+ # Copy over direct message status and room tags if this is a join
+ # on an upgraded room
+
+ # Check if this is an upgraded room
+ predecessor = yield self.store.get_room_predecessor(room_id)
+
+ if predecessor:
+ # It is an upgraded room. Copy over old tags and m.direct state
+ yield self.copy_room_tags_and_direct_to_room(
+ predecessor["room_id"], room_id, user_id,
+ )
elif event.membership == Membership.LEAVE:
if prev_member_event_id:
prev_member_event = yield self.store.get_event(prev_member_event_id)
@@ -227,6 +241,55 @@ class RoomMemberHandler(object):
defer.returnValue(event)
@defer.inlineCallbacks
+ def copy_room_tags_and_direct_to_room(
+ self,
+ old_room_id,
+ new_room_id,
+ user_id,
+ ):
+ """Copies the tags and direct room state from one room to another.
+
+ Args:
+ old_room_id (str)
+ new_room_id (str)
+ user_id (str)
+
+ Returns:
+ Deferred[None]
+ """
+ # Retrieve user account data for predecessor room
+ user_account_data, _ = yield self.store.get_account_data_for_user(
+ user_id,
+ )
+
+ # Copy direct message state if applicable
+ direct_rooms = user_account_data.get("m.direct", {})
+
+ # Check which key this room is under
+ if isinstance(direct_rooms, dict):
+ for key, room_id_list in direct_rooms.items():
+ if old_room_id in room_id_list and new_room_id not in room_id_list:
+ # Add new room_id to this key
+ direct_rooms[key].append(new_room_id)
+
+ # Save back to user's m.direct account data
+ yield self.store.add_account_data_for_user(
+ user_id, "m.direct", direct_rooms,
+ )
+ break
+
+ # Copy room tags if applicable
+ room_tags = yield self.store.get_tags_for_room(
+ user_id, old_room_id,
+ )
+
+ # Copy each room tag to the new room
+ for tag, tag_content in room_tags.items():
+ yield self.store.add_tag_to_room(
+ user_id, new_room_id, tag, tag_content
+ )
+
+ @defer.inlineCallbacks
def update_membership(
self,
requester,
@@ -493,7 +556,7 @@ class RoomMemberHandler(object):
else:
requester = synapse.types.create_requester(target_user)
- prev_event = yield self.event_creation_hander.deduplicate_state_event(
+ prev_event = yield self.event_creation_handler.deduplicate_state_event(
event, context,
)
if prev_event is not None:
@@ -513,7 +576,7 @@ class RoomMemberHandler(object):
if is_blocked:
raise SynapseError(403, "This room has been blocked on this server")
- yield self.event_creation_hander.handle_new_client_event(
+ yield self.event_creation_handler.handle_new_client_event(
requester,
event,
context,
@@ -527,7 +590,7 @@ class RoomMemberHandler(object):
)
if event.membership == Membership.JOIN:
- # Only fire user_joined_room if the user has acutally joined the
+ # Only fire user_joined_room if the user has actually joined the
# room. Don't bother if the user is just changing their profile
# info.
newly_joined = True
@@ -755,7 +818,7 @@ class RoomMemberHandler(object):
)
)
- yield self.event_creation_hander.create_and_send_nonmember_event(
+ yield self.event_creation_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.ThirdPartyInvite,
@@ -877,7 +940,8 @@ class RoomMemberHandler(object):
# first member event?
create_event_id = current_state_ids.get(("m.room.create", ""))
if len(current_state_ids) == 1 and create_event_id:
- defer.returnValue(self.hs.is_mine_id(create_event_id))
+ # We can only get here if we're in the process of creating the room
+ defer.returnValue(True)
for etype, state_key in current_state_ids:
if etype != EventTypes.Member or not self.hs.is_mine_id(state_key):
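The m.direct handling in copy_room_tags_and_direct_to_room above amounts to a small dict update: find whichever contact key listed the old room and add the new room under it. Shown standalone, with the early return mirroring the break-after-save in the handler:

def update_direct_rooms(direct_rooms, old_room_id, new_room_id):
    # direct_rooms maps a user ID to a list of room IDs, as in m.direct
    # account data.
    for room_id_list in direct_rooms.values():
        if old_room_id in room_id_list and new_room_id not in room_id_list:
            room_id_list.append(new_room_id)
            return True  # caller should persist the updated mapping
    return False

rooms = {"@friend:example.com": ["!old:example.com"]}
assert update_direct_rooms(rooms, "!old:example.com", "!new:example.com")
assert rooms["@friend:example.com"] == ["!old:example.com", "!new:example.com"]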
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index ec936bbb4e..49c439313e 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -38,6 +38,41 @@ class SearchHandler(BaseHandler):
super(SearchHandler, self).__init__(hs)
@defer.inlineCallbacks
+ def get_old_rooms_from_upgraded_room(self, room_id):
+ """Retrieves room IDs of old rooms in the history of an upgraded room.
+
+ We do so by checking the m.room.create event of the room for a
+ `predecessor` key. If it exists, we add the room ID to our return
+ list and then check that room for an m.room.create event, and so on
+ until we can no longer find any previous rooms.
+
+ The full list of all found rooms is then returned.
+
+ Args:
+ room_id (str): id of the room to search through.
+
+ Returns:
+ Deferred[iterable[unicode]]: predecessor room ids
+ """
+
+ historical_room_ids = []
+
+ while True:
+ predecessor = yield self.store.get_room_predecessor(room_id)
+
+ # If no predecessor, assume we've hit a dead end
+ if not predecessor:
+ break
+
+ # Add predecessor's room ID
+ historical_room_ids.append(predecessor["room_id"])
+
+ # Scan through the old room for further predecessors
+ room_id = predecessor["room_id"]
+
+ defer.returnValue(historical_room_ids)
+
+ @defer.inlineCallbacks
def search(self, user, content, batch=None):
"""Performs a full text search for a user.
@@ -137,6 +172,18 @@ class SearchHandler(BaseHandler):
)
room_ids = set(r.room_id for r in rooms)
+ # If searching a subset of all rooms, check if any of the rooms
+ # are from an upgraded room, and search their contents as well
+ if search_filter.rooms:
+ historical_room_ids = []
+ for room_id in search_filter.rooms:
+ # Add any previous rooms to the search if they exist
+ ids = yield self.get_old_rooms_from_upgraded_room(room_id)
+ historical_room_ids += ids
+
+ # Prevent any historical events from being filtered
+ search_filter = search_filter.with_room_ids(historical_room_ids)
+
room_ids = search_filter.filter_rooms(room_ids)
if batch_group == "room_id":
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index f7f768f751..bd97241ab4 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -895,14 +895,17 @@ class SyncHandler(object):
Returns:
Deferred(SyncResult)
"""
- logger.info("Calculating sync response for %r", sync_config.user)
-
# NB: The now_token gets changed by some of the generate_sync_* methods,
# this is due to some of the underlying streams not supporting the ability
# to query up to a given point.
# Always use the `now_token` in `SyncResultBuilder`
now_token = yield self.event_sources.get_current_token()
+ logger.info(
+ "Calculating sync response for %r between %s and %s",
+ sync_config.user, since_token, now_token,
+ )
+
user_id = sync_config.user.to_string()
app_service = self.store.get_app_service_by_user_id(user_id)
if app_service:
@@ -1390,6 +1393,12 @@ class SyncHandler(object):
room_entries = []
invited = []
for room_id, events in iteritems(mem_change_events_by_room_id):
+ logger.info(
+ "Membership changes in %s: [%s]",
+ room_id,
+ ", ".join(("%s (%s)" % (e.event_id, e.membership) for e in events)),
+ )
+
non_joins = [e for e in events if e.membership != Membership.JOIN]
has_join = len(non_joins) != len(events)
@@ -1473,10 +1482,22 @@ class SyncHandler(object):
if since_token and since_token.is_after(leave_token):
continue
+ # If this is an out of band message, like a remote invite
+ # rejection, we include it in the recents batch. Otherwise, we
+ # let _load_filtered_recents handle fetching the correct
+ # batches.
+ #
+ # This is all screaming out for a refactor, as the logic here is
+ # subtle and the moving parts numerous.
+ if leave_event.internal_metadata.is_out_of_band_membership():
+ batch_events = [leave_event]
+ else:
+ batch_events = None
+
room_entries.append(RoomSyncResultBuilder(
room_id=room_id,
rtype="archived",
- events=None,
+ events=batch_events,
newly_joined=room_id in newly_joined_rooms,
full_state=False,
since_token=since_token,
diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
index 3c40999338..120815b09b 100644
--- a/synapse/handlers/user_directory.py
+++ b/synapse/handlers/user_directory.py
@@ -19,6 +19,7 @@ from six import iteritems
from twisted.internet import defer
+import synapse.metrics
from synapse.api.constants import EventTypes, JoinRules, Membership
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.roommember import ProfileInfo
@@ -163,6 +164,11 @@ class UserDirectoryHandler(object):
yield self._handle_deltas(deltas)
self.pos = deltas[-1]["stream_id"]
+
+ # Expose current event processing position to prometheus
+ synapse.metrics.event_processing_positions.labels(
+ "user_dir").set(self.pos)
+
yield self.store.update_user_directory_stream_pos(self.pos)
@defer.inlineCallbacks
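The metric update above uses a labelled prometheus_client gauge. The real event_processing_positions gauge lives in synapse.metrics, so this standalone definition is illustrative only:

from prometheus_client import Gauge

event_processing_positions = Gauge(
    "synapse_event_processing_positions",
    "Current stream position of an event processor",
    ["name"],
)

# Record how far through the event stream the user directory has processed.
event_processing_positions.labels("user_dir").set(12345)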
diff --git a/synapse/http/client.py b/synapse/http/client.py
index afcf698b29..47a1f82ff0 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -333,9 +333,10 @@ class SimpleHttpClient(object):
"POST", uri, headers=Headers(actual_headers), data=query_bytes
)
+ body = yield make_deferred_yieldable(readBody(response))
+
if 200 <= response.code < 300:
- body = yield make_deferred_yieldable(treq.json_content(response))
- defer.returnValue(body)
+ defer.returnValue(json.loads(body))
else:
raise HttpResponseException(response.code, response.phrase, body)
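The reordering above reads the body unconditionally, so that a failing request can carry the server's error body instead of discarding it. The shape of the change, as a plain function (names illustrative):

import json

def handle_response(code, phrase, body):
    # The body has already been read, whatever the status code.
    if 200 <= code < 300:
        return json.loads(body)
    # Non-2xx: surface the body in the error rather than losing it.
    raise RuntimeError("%d %s: %r" % (code, phrase, body))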
diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py
index 815f8ff2f7..cd79ebab62 100644
--- a/synapse/http/endpoint.py
+++ b/synapse/http/endpoint.py
@@ -13,15 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-import random
import re
-from twisted.internet import defer
-from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
-from twisted.internet.error import ConnectError
-
-from synapse.http.federation.srv_resolver import Server, resolve_service
-
logger = logging.getLogger(__name__)
@@ -88,140 +81,3 @@ def parse_and_validate_server_name(server_name):
))
return host, port
-
-
-def matrix_federation_endpoint(reactor, destination, tls_client_options_factory=None,
- timeout=None):
- """Construct an endpoint for the given matrix destination.
-
- Args:
- reactor: Twisted reactor.
- destination (unicode): The name of the server to connect to.
- tls_client_options_factory
- (synapse.crypto.context_factory.ClientTLSOptionsFactory):
- Factory which generates TLS options for client connections.
- timeout (int): connection timeout in seconds
- """
-
- domain, port = parse_server_name(destination)
-
- endpoint_kw_args = {}
-
- if timeout is not None:
- endpoint_kw_args.update(timeout=timeout)
-
- if tls_client_options_factory is None:
- transport_endpoint = HostnameEndpoint
- default_port = 8008
- else:
- # the SNI string should be the same as the Host header, minus the port.
- # as per https://github.com/matrix-org/synapse/issues/2525#issuecomment-336896777,
- # the Host header and SNI should therefore be the server_name of the remote
- # server.
- tls_options = tls_client_options_factory.get_options(domain)
-
- def transport_endpoint(reactor, host, port, timeout):
- return wrapClientTLS(
- tls_options,
- HostnameEndpoint(reactor, host, port, timeout=timeout),
- )
- default_port = 8448
-
- if port is None:
- return SRVClientEndpoint(
- reactor, "matrix", domain, protocol="tcp",
- default_port=default_port, endpoint=transport_endpoint,
- endpoint_kw_args=endpoint_kw_args
- )
- else:
- return transport_endpoint(
- reactor, domain, port, **endpoint_kw_args
- )
-
-
-class SRVClientEndpoint(object):
- """An endpoint which looks up SRV records for a service.
- Cycles through the list of servers starting with each call to connect
- picking the next server.
- Implements twisted.internet.interfaces.IStreamClientEndpoint.
- """
-
- def __init__(self, reactor, service, domain, protocol="tcp",
- default_port=None, endpoint=HostnameEndpoint,
- endpoint_kw_args={}):
- self.reactor = reactor
- self.service_name = "_%s._%s.%s" % (service, protocol, domain)
-
- if default_port is not None:
- self.default_server = Server(
- host=domain,
- port=default_port,
- )
- else:
- self.default_server = None
-
- self.endpoint = endpoint
- self.endpoint_kw_args = endpoint_kw_args
-
- self.servers = None
- self.used_servers = None
-
- @defer.inlineCallbacks
- def fetch_servers(self):
- self.used_servers = []
- self.servers = yield resolve_service(self.service_name)
-
- def pick_server(self):
- if not self.servers:
- if self.used_servers:
- self.servers = self.used_servers
- self.used_servers = []
- self.servers.sort()
- elif self.default_server:
- return self.default_server
- else:
- raise ConnectError(
- "No server available for %s" % self.service_name
- )
-
- # look for all servers with the same priority
- min_priority = self.servers[0].priority
- weight_indexes = list(
- (index, server.weight + 1)
- for index, server in enumerate(self.servers)
- if server.priority == min_priority
- )
-
- total_weight = sum(weight for index, weight in weight_indexes)
- target_weight = random.randint(0, total_weight)
- for index, weight in weight_indexes:
- target_weight -= weight
- if target_weight <= 0:
- server = self.servers[index]
- # XXX: this looks totally dubious:
- #
- # (a) we never reuse a server until we have been through
- # all of the servers at the same priority, so if the
- # weights are A: 100, B:1, we always do ABABAB instead of
- # AAAA...AAAB (approximately).
- #
- # (b) After using all the servers at the lowest priority,
- # we move onto the next priority. We should only use the
- # second priority if servers at the top priority are
- # unreachable.
- #
- del self.servers[index]
- self.used_servers.append(server)
- return server
-
- @defer.inlineCallbacks
- def connect(self, protocolFactory):
- if self.servers is None:
- yield self.fetch_servers()
- server = self.pick_server()
- logger.info("Connecting to %s:%s", server.host, server.port)
- endpoint = self.endpoint(
- self.reactor, server.host, server.port, **self.endpoint_kw_args
- )
- connection = yield endpoint.connect(protocolFactory)
- defer.returnValue(connection)
diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py
new file mode 100644
index 0000000000..384d8a37a2
--- /dev/null
+++ b/synapse/http/federation/matrix_federation_agent.py
@@ -0,0 +1,452 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import json
+import logging
+import random
+import time
+
+import attr
+from netaddr import IPAddress
+from zope.interface import implementer
+
+from twisted.internet import defer
+from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
+from twisted.internet.interfaces import IStreamClientEndpoint
+from twisted.web.client import URI, Agent, HTTPConnectionPool, RedirectAgent, readBody
+from twisted.web.http import stringToDatetime
+from twisted.web.http_headers import Headers
+from twisted.web.iweb import IAgent
+
+from synapse.http.federation.srv_resolver import SrvResolver, pick_server_from_list
+from synapse.util import Clock
+from synapse.util.caches.ttlcache import TTLCache
+from synapse.util.logcontext import make_deferred_yieldable
+from synapse.util.metrics import Measure
+
+# period to cache .well-known results for by default
+WELL_KNOWN_DEFAULT_CACHE_PERIOD = 24 * 3600
+
+# jitter to add to the .well-known default cache ttl
+WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER = 10 * 60
+
+# period to cache failure to fetch .well-known for
+WELL_KNOWN_INVALID_CACHE_PERIOD = 1 * 3600
+
+# cap for .well-known cache period
+WELL_KNOWN_MAX_CACHE_PERIOD = 48 * 3600
+
+logger = logging.getLogger(__name__)
+well_known_cache = TTLCache('well-known')
+
+
+@implementer(IAgent)
+class MatrixFederationAgent(object):
+ """An Agent-like thing which provides a `request` method which will look up a matrix
+ server and send an HTTP request to it.
+
+ Doesn't implement any retries. (Those are done in MatrixFederationHttpClient.)
+
+ Args:
+ reactor (IReactor): twisted reactor to use for underlying requests
+
+ tls_client_options_factory (ClientTLSOptionsFactory|None):
+ factory to use for fetching client tls options, or none to disable TLS.
+
+ _well_known_tls_policy (IPolicyForHTTPS|None):
+ TLS policy to use for fetching .well-known files. None to use a default
+ (browser-like) implementation.
+
+ _srv_resolver (SrvResolver|None):
+ SRVResolver impl to use for looking up SRV records. None to use a default
+ implementation.
+ """
+
+ def __init__(
+ self, reactor, tls_client_options_factory,
+ _well_known_tls_policy=None,
+ _srv_resolver=None,
+ _well_known_cache=well_known_cache,
+ ):
+ self._reactor = reactor
+ self._clock = Clock(reactor)
+
+ self._tls_client_options_factory = tls_client_options_factory
+ if _srv_resolver is None:
+ _srv_resolver = SrvResolver()
+ self._srv_resolver = _srv_resolver
+
+ self._pool = HTTPConnectionPool(reactor)
+ self._pool.retryAutomatically = False
+ self._pool.maxPersistentPerHost = 5
+ self._pool.cachedConnectionTimeout = 2 * 60
+
+ agent_args = {}
+ if _well_known_tls_policy is not None:
+ # the param is called 'contextFactory', but actually passing a
+ # contextfactory is deprecated, and it expects an IPolicyForHTTPS.
+ agent_args['contextFactory'] = _well_known_tls_policy
+ _well_known_agent = RedirectAgent(
+ Agent(self._reactor, pool=self._pool, **agent_args),
+ )
+ self._well_known_agent = _well_known_agent
+
+ # our cache of .well-known lookup results, mapping from server name
+ # to delegated name. The values can be:
+ # `bytes`: a valid server-name
+ # `None`: there is no (valid) .well-known here
+ self._well_known_cache = _well_known_cache
+
+ @defer.inlineCallbacks
+ def request(self, method, uri, headers=None, bodyProducer=None):
+ """
+ Args:
+ method (bytes): HTTP method: GET/POST/etc
+
+ uri (bytes): Absolute URI to be retrieved
+
+ headers (twisted.web.http_headers.Headers|None):
+ HTTP headers to send with the request, or None to
+ send no extra headers.
+
+ bodyProducer (twisted.web.iweb.IBodyProducer|None):
+ An object which can generate bytes to make up the
+ body of this request (for example, the properly encoded contents of
+ a file for a file upload). Or None if the request is to have
+ no body.
+
+ Returns:
+ Deferred[twisted.web.iweb.IResponse]:
+ fires when the header of the response has been received (regardless of the
+ response status code). Fails if there is any problem which prevents that
+ response from being received (including problems that prevent the request
+ from being sent).
+ """
+ parsed_uri = URI.fromBytes(uri, defaultPort=-1)
+ res = yield self._route_matrix_uri(parsed_uri)
+
+ # set up the TLS connection params
+ #
+ # XXX disabling TLS is really only supported here for the benefit of the
+ # unit tests. We should make the UTs cope with TLS rather than having to make
+ # the code support the unit tests.
+ if self._tls_client_options_factory is None:
+ tls_options = None
+ else:
+ tls_options = self._tls_client_options_factory.get_options(
+ res.tls_server_name.decode("ascii")
+ )
+
+ # make sure that the Host header is set correctly
+ if headers is None:
+ headers = Headers()
+ else:
+ headers = headers.copy()
+
+ if not headers.hasHeader(b'host'):
+ headers.addRawHeader(b'host', res.host_header)
+
+ class EndpointFactory(object):
+ @staticmethod
+ def endpointForURI(_uri):
+ ep = LoggingHostnameEndpoint(
+ self._reactor, res.target_host, res.target_port,
+ )
+ if tls_options is not None:
+ ep = wrapClientTLS(tls_options, ep)
+ return ep
+
+ agent = Agent.usingEndpointFactory(self._reactor, EndpointFactory(), self._pool)
+ res = yield make_deferred_yieldable(
+ agent.request(method, uri, headers, bodyProducer)
+ )
+ defer.returnValue(res)
+
+ @defer.inlineCallbacks
+ def _route_matrix_uri(self, parsed_uri, lookup_well_known=True):
+ """Helper for `request`: determine the routing for a Matrix URI
+
+ Args:
+ parsed_uri (twisted.web.client.URI): uri to route. Note that it should be
+ parsed with URI.fromBytes(uri, defaultPort=-1) to set the `port` to -1
+ if there is no explicit port given.
+
+ lookup_well_known (bool): True if we should look up the .well-known file if
+ there is no SRV record.
+
+ Returns:
+ Deferred[_RoutingResult]
+ """
+ # check for an IP literal
+ try:
+ ip_address = IPAddress(parsed_uri.host.decode("ascii"))
+ except Exception:
+ # not an IP address
+ ip_address = None
+
+ if ip_address:
+ port = parsed_uri.port
+ if port == -1:
+ port = 8448
+ defer.returnValue(_RoutingResult(
+ host_header=parsed_uri.netloc,
+ tls_server_name=parsed_uri.host,
+ target_host=parsed_uri.host,
+ target_port=port,
+ ))
+
+ if parsed_uri.port != -1:
+ # there is an explicit port
+ defer.returnValue(_RoutingResult(
+ host_header=parsed_uri.netloc,
+ tls_server_name=parsed_uri.host,
+ target_host=parsed_uri.host,
+ target_port=parsed_uri.port,
+ ))
+
+ if lookup_well_known:
+ # try a .well-known lookup
+ well_known_server = yield self._get_well_known(parsed_uri.host)
+
+ if well_known_server:
+ # if we found a .well-known, start again, but don't do another
+ # .well-known lookup.
+
+ # parse the server name in the .well-known response into host/port.
+ # (This code is lifted from twisted.web.client.URI.fromBytes).
+ if b':' in well_known_server:
+ well_known_host, well_known_port = well_known_server.rsplit(b':', 1)
+ try:
+ well_known_port = int(well_known_port)
+ except ValueError:
+ # the part after the colon could not be parsed as an int
+ # - we assume it is an IPv6 literal with no port (the closing
+ # ']' stops it being parsed as an int)
+ well_known_host, well_known_port = well_known_server, -1
+ else:
+ well_known_host, well_known_port = well_known_server, -1
+
+ new_uri = URI(
+ scheme=parsed_uri.scheme,
+ netloc=well_known_server,
+ host=well_known_host,
+ port=well_known_port,
+ path=parsed_uri.path,
+ params=parsed_uri.params,
+ query=parsed_uri.query,
+ fragment=parsed_uri.fragment,
+ )
+
+ res = yield self._route_matrix_uri(new_uri, lookup_well_known=False)
+ defer.returnValue(res)
+
+ # try a SRV lookup
+ service_name = b"_matrix._tcp.%s" % (parsed_uri.host,)
+ server_list = yield self._srv_resolver.resolve_service(service_name)
+
+ if not server_list:
+ target_host = parsed_uri.host
+ port = 8448
+ logger.debug(
+ "No SRV record for %s, using %s:%i",
+ parsed_uri.host.decode("ascii"), target_host.decode("ascii"), port,
+ )
+ else:
+ target_host, port = pick_server_from_list(server_list)
+ logger.debug(
+ "Picked %s:%i from SRV records for %s",
+ target_host.decode("ascii"), port, parsed_uri.host.decode("ascii"),
+ )
+
+ defer.returnValue(_RoutingResult(
+ host_header=parsed_uri.netloc,
+ tls_server_name=parsed_uri.host,
+ target_host=target_host,
+ target_port=port,
+ ))
+
+ @defer.inlineCallbacks
+ def _get_well_known(self, server_name):
+ """Attempt to fetch and parse a .well-known file for the given server
+
+ Args:
+ server_name (bytes): name of the server, from the requested url
+
+ Returns:
+ Deferred[bytes|None]: either the new server name, from the .well-known, or
+ None if there was no .well-known file.
+ """
+ try:
+ result = self._well_known_cache[server_name]
+ except KeyError:
+ # TODO: should we linearise so that we don't end up doing two .well-known
+ # requests for the same server in parallel?
+ with Measure(self._clock, "get_well_known"):
+ result, cache_period = yield self._do_get_well_known(server_name)
+
+ if cache_period > 0:
+ self._well_known_cache.set(server_name, result, cache_period)
+
+ defer.returnValue(result)
+
+ @defer.inlineCallbacks
+ def _do_get_well_known(self, server_name):
+ """Actually fetch and parse a .well-known, without checking the cache
+
+ Args:
+ server_name (bytes): name of the server, from the requested url
+
+ Returns:
+ Deferred[Tuple[bytes|None|object, int]]:
+ result, cache period, where result is one of:
+ - the new server name from the .well-known (as a `bytes`)
+ - None if there was no .well-known file.
+ - INVALID_WELL_KNOWN if the .well-known was invalid
+ """
+ uri = b"https://%s/.well-known/matrix/server" % (server_name, )
+ uri_str = uri.decode("ascii")
+ logger.info("Fetching %s", uri_str)
+ try:
+ response = yield make_deferred_yieldable(
+ self._well_known_agent.request(b"GET", uri),
+ )
+ body = yield make_deferred_yieldable(readBody(response))
+ if response.code != 200:
+ raise Exception("Non-200 response %s" % (response.code, ))
+
+ parsed_body = json.loads(body.decode('utf-8'))
+ logger.info("Response from .well-known: %s", parsed_body)
+ if not isinstance(parsed_body, dict):
+ raise Exception("not a dict")
+ if "m.server" not in parsed_body:
+ raise Exception("Missing key 'm.server'")
+ except Exception as e:
+ logger.info("Error fetching %s: %s", uri_str, e)
+
+ # add some randomness to the TTL to avoid a stampeding herd every hour
+ # after startup
+ cache_period = WELL_KNOWN_INVALID_CACHE_PERIOD
+ cache_period += random.uniform(0, WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER)
+ defer.returnValue((None, cache_period))
+
+ result = parsed_body["m.server"].encode("ascii")
+
+ cache_period = _cache_period_from_headers(
+ response.headers,
+ time_now=self._reactor.seconds,
+ )
+ if cache_period is None:
+ cache_period = WELL_KNOWN_DEFAULT_CACHE_PERIOD
+ # add some randomness to the TTL to avoid a stampeding herd every 24 hours
+ # after startup
+ cache_period += random.uniform(0, WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER)
+ else:
+ cache_period = min(cache_period, WELL_KNOWN_MAX_CACHE_PERIOD)
+
+ defer.returnValue((result, cache_period))
+
+
+@implementer(IStreamClientEndpoint)
+class LoggingHostnameEndpoint(object):
+ """A wrapper for HostnameEndpint which logs when it connects"""
+ def __init__(self, reactor, host, port, *args, **kwargs):
+ self.host = host
+ self.port = port
+ self.ep = HostnameEndpoint(reactor, host, port, *args, **kwargs)
+
+ def connect(self, protocol_factory):
+ logger.info("Connecting to %s:%i", self.host.decode("ascii"), self.port)
+ return self.ep.connect(protocol_factory)
+
+
+def _cache_period_from_headers(headers, time_now=time.time):
+ cache_controls = _parse_cache_control(headers)
+
+ if b'no-store' in cache_controls:
+ return 0
+
+ if b'max-age' in cache_controls:
+ try:
+ max_age = int(cache_controls[b'max-age'])
+ return max_age
+ except ValueError:
+ pass
+
+ expires = headers.getRawHeaders(b'expires')
+ if expires is not None:
+ try:
+ expires_date = stringToDatetime(expires[-1])
+ return expires_date - time_now()
+ except ValueError:
+ # RFC7234 says 'A cache recipient MUST interpret invalid date formats,
+ # especially the value "0", as representing a time in the past (i.e.,
+ # "already expired").
+ return 0
+
+ return None
+
+
+def _parse_cache_control(headers):
+ cache_controls = {}
+ for hdr in headers.getRawHeaders(b'cache-control', []):
+ for directive in hdr.split(b','):
+ splits = [x.strip() for x in directive.split(b'=', 1)]
+ k = splits[0].lower()
+ v = splits[1] if len(splits) > 1 else None
+ cache_controls[k] = v
+ return cache_controls
+
+
+@attr.s
+class _RoutingResult(object):
+ """The result returned by `_route_matrix_uri`.
+
+ Contains the parameters needed to direct a federation connection to a particular
+ server.
+
+ Where a SRV record points to several servers, this object contains a single server
+ chosen from the list.
+ """
+
+ host_header = attr.ib()
+ """
+ The value we should assign to the Host header (host:port from the matrix
+ URI, or .well-known).
+
+ :type: bytes
+ """
+
+ tls_server_name = attr.ib()
+ """
+ The server name we should set in the SNI (typically host, without port, from the
+ matrix URI or .well-known)
+
+ :type: bytes
+ """
+
+ target_host = attr.ib()
+ """
+ The hostname (or IP literal) we should route the TCP connection to (the target of the
+ SRV record, or the hostname from the URL/.well-known)
+
+ :type: bytes
+ """
+
+ target_port = attr.ib()
+ """
+ The port we should route the TCP connection to (the target of the SRV record, or
+ the port from the URL/.well-known, or 8448)
+
+ :type: int
+ """
diff --git a/synapse/http/federation/srv_resolver.py b/synapse/http/federation/srv_resolver.py
index c49b82c394..71830c549d 100644
--- a/synapse/http/federation/srv_resolver.py
+++ b/synapse/http/federation/srv_resolver.py
@@ -15,6 +15,7 @@
# limitations under the License.
import logging
+import random
import time
import attr
@@ -51,74 +52,118 @@ class Server(object):
expires = attr.ib(default=0)
-@defer.inlineCallbacks
-def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=time):
- """Look up a SRV record, with caching
+def pick_server_from_list(server_list):
+ """Randomly choose a server from the server list
+
+ Args:
+ server_list (list[Server]): list of candidate servers
+
+ Returns:
+ Tuple[bytes, int]: (host, port) pair for the chosen server
+ """
+ if not server_list:
+ raise RuntimeError("pick_server_from_list called with empty list")
+
+ # TODO: currently we only use the lowest-priority servers. We should maintain a
+ # cache of servers known to be "down" and filter them out
+
+ min_priority = min(s.priority for s in server_list)
+ eligible_servers = list(s for s in server_list if s.priority == min_priority)
+ total_weight = sum(s.weight for s in eligible_servers)
+ target_weight = random.randint(0, total_weight)
+
+ for s in eligible_servers:
+ target_weight -= s.weight
+
+ if target_weight <= 0:
+ return s.host, s.port
+
+ # this should be impossible.
+ raise RuntimeError(
+ "pick_server_from_list got to end of eligible server list.",
+ )
+
+
+class SrvResolver(object):
+ """Interface to the dns client to do SRV lookups, with result caching.
The default resolver in twisted.names doesn't do any caching (it has a CacheResolver,
but the cache never gets populated), so we add our own caching layer here.
Args:
- service_name (unicode|bytes): record to look up
dns_client (twisted.internet.interfaces.IResolver): twisted resolver impl
cache (dict): cache object
- clock (object): clock implementation. must provide a time() method.
-
- Returns:
- Deferred[list[Server]]: a list of the SRV records, or an empty list if none found
+ get_time (callable): clock implementation. Should return seconds since the epoch
"""
- # TODO: the dns client handles both unicode names (encoding via idna) and pre-encoded
- # byteses; however they will obviously end up as separate entries in the cache. We
- # should pick one form and stick with it.
- cache_entry = cache.get(service_name, None)
- if cache_entry:
- if all(s.expires > int(clock.time()) for s in cache_entry):
- servers = list(cache_entry)
- defer.returnValue(servers)
-
- try:
- answers, _, _ = yield make_deferred_yieldable(
- dns_client.lookupService(service_name),
- )
- except DNSNameError:
- # TODO: cache this. We can get the SOA out of the exception, and use
- # the negative-TTL value.
- defer.returnValue([])
- except DomainError as e:
- # We failed to resolve the name (other than a NameError)
- # Try something in the cache, else rereaise
- cache_entry = cache.get(service_name, None)
+ def __init__(self, dns_client=client, cache=SERVER_CACHE, get_time=time.time):
+ self._dns_client = dns_client
+ self._cache = cache
+ self._get_time = get_time
+
+ @defer.inlineCallbacks
+ def resolve_service(self, service_name):
+ """Look up a SRV record
+
+ Args:
+ service_name (bytes): record to look up
+
+ Returns:
+ Deferred[list[Server]]:
+ a list of the SRV records, or an empty list if none found
+ """
+ now = int(self._get_time())
+
+ if not isinstance(service_name, bytes):
+ raise TypeError("%r is not a byte string" % (service_name,))
+
+ cache_entry = self._cache.get(service_name, None)
if cache_entry:
- logger.warn(
- "Failed to resolve %r, falling back to cache. %r",
- service_name, e
+ if all(s.expires > now for s in cache_entry):
+ servers = list(cache_entry)
+ defer.returnValue(servers)
+
+ try:
+ answers, _, _ = yield make_deferred_yieldable(
+ self._dns_client.lookupService(service_name),
)
- defer.returnValue(list(cache_entry))
- else:
- raise e
-
- if (len(answers) == 1
- and answers[0].type == dns.SRV
- and answers[0].payload
- and answers[0].payload.target == dns.Name(b'.')):
- raise ConnectError("Service %s unavailable" % service_name)
-
- servers = []
-
- for answer in answers:
- if answer.type != dns.SRV or not answer.payload:
- continue
-
- payload = answer.payload
-
- servers.append(Server(
- host=payload.target.name,
- port=payload.port,
- priority=payload.priority,
- weight=payload.weight,
- expires=int(clock.time()) + answer.ttl,
- ))
-
- servers.sort() # FIXME: get rid of this (it's broken by the attrs change)
- cache[service_name] = list(servers)
- defer.returnValue(servers)
+ except DNSNameError:
+ # TODO: cache this. We can get the SOA out of the exception, and use
+ # the negative-TTL value.
+ defer.returnValue([])
+ except DomainError as e:
+ # We failed to resolve the name (other than a NameError)
+ # Try something in the cache, else re-raise
+ cache_entry = self._cache.get(service_name, None)
+ if cache_entry:
+ logger.warn(
+ "Failed to resolve %r, falling back to cache. %r",
+ service_name, e
+ )
+ defer.returnValue(list(cache_entry))
+ else:
+ raise e
+
+ if (len(answers) == 1
+ and answers[0].type == dns.SRV
+ and answers[0].payload
+ and answers[0].payload.target == dns.Name(b'.')):
+ raise ConnectError("Service %s unavailable" % service_name)
+
+ servers = []
+
+ for answer in answers:
+ if answer.type != dns.SRV or not answer.payload:
+ continue
+
+ payload = answer.payload
+
+ servers.append(Server(
+ host=payload.target.name,
+ port=payload.port,
+ priority=payload.priority,
+ weight=payload.weight,
+ expires=now + answer.ttl,
+ ))
+
+ self._cache[service_name] = list(servers)
+ defer.returnValue(servers)
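pick_server_from_list above implements the standard RFC 2782 selection: restrict to the lowest priority value, then pick in proportion to weight. The same logic on plain tuples, for illustration:

import random

def pick_server(servers):
    # servers: list of (priority, weight, host, port) tuples; a lower
    # priority value is preferred, as in RFC 2782.
    min_priority = min(s[0] for s in servers)
    eligible = [s for s in servers if s[0] == min_priority]
    target_weight = random.randint(0, sum(s[1] for s in eligible))
    for _priority, weight, host, port in eligible:
        target_weight -= weight
        if target_weight <= 0:
            return host, port
    raise RuntimeError("ran off the end of the eligible server list")

host, port = pick_server([(10, 5, b"a.example.com", 8448),
                          (10, 1, b"b.example.com", 8448),
                          (20, 1, b"c.example.com", 8448)])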
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 250bb1ef91..5ee4d528d2 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -28,11 +28,11 @@ from canonicaljson import encode_canonical_json
from prometheus_client import Counter
from signedjson.sign import sign_json
-from twisted.internet import defer, protocol
+from twisted.internet import defer, protocol, task
from twisted.internet.error import DNSLookupError
from twisted.internet.task import _EPSILON, Cooperator
from twisted.web._newclient import ResponseDone
-from twisted.web.client import Agent, FileBodyProducer, HTTPConnectionPool
+from twisted.web.client import FileBodyProducer
from twisted.web.http_headers import Headers
import synapse.metrics
@@ -44,7 +44,7 @@ from synapse.api.errors import (
RequestSendFailed,
SynapseError,
)
-from synapse.http.endpoint import matrix_federation_endpoint
+from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
from synapse.util.async_helpers import timeout_deferred
from synapse.util.logcontext import make_deferred_yieldable
from synapse.util.metrics import Measure
@@ -66,20 +66,6 @@ else:
MAXINT = sys.maxint
-class MatrixFederationEndpointFactory(object):
- def __init__(self, hs):
- self.reactor = hs.get_reactor()
- self.tls_client_options_factory = hs.tls_client_options_factory
-
- def endpointForURI(self, uri):
- destination = uri.netloc.decode('ascii')
-
- return matrix_federation_endpoint(
- self.reactor, destination, timeout=10,
- tls_client_options_factory=self.tls_client_options_factory
- )
-
-
_next_id = 1
@@ -187,12 +173,10 @@ class MatrixFederationHttpClient(object):
self.signing_key = hs.config.signing_key[0]
self.server_name = hs.hostname
reactor = hs.get_reactor()
- pool = HTTPConnectionPool(reactor)
- pool.retryAutomatically = False
- pool.maxPersistentPerHost = 5
- pool.cachedConnectionTimeout = 2 * 60
- self.agent = Agent.usingEndpointFactory(
- reactor, MatrixFederationEndpointFactory(hs), pool=pool
+
+ self.agent = MatrixFederationAgent(
+ hs.get_reactor(),
+ hs.tls_client_options_factory,
)
self.clock = hs.get_clock()
self._store = hs.get_datastore()
@@ -271,7 +255,6 @@ class MatrixFederationHttpClient(object):
headers_dict = {
b"User-Agent": [self.version_string_bytes],
- b"Host": [destination_bytes],
}
with limiter:
@@ -303,7 +286,7 @@ class MatrixFederationHttpClient(object):
json,
)
data = encode_canonical_json(json)
- producer = FileBodyProducer(
+ producer = QuieterFileBodyProducer(
BytesIO(data),
cooperator=self._cooperator,
)
@@ -316,9 +299,9 @@ class MatrixFederationHttpClient(object):
headers_dict[b"Authorization"] = auth_headers
logger.info(
- "{%s} [%s] Sending request: %s %s",
+ "{%s} [%s] Sending request: %s %s; timeout %fs",
request.txn_id, request.destination, request.method,
- url_str,
+ url_str, _sec_timeout,
)
try:
@@ -338,12 +321,11 @@ class MatrixFederationHttpClient(object):
reactor=self.hs.get_reactor(),
)
- response = yield make_deferred_yieldable(
- request_deferred,
- )
+ response = yield request_deferred
except DNSLookupError as e:
raise_from(RequestSendFailed(e, can_retry=retry_on_dns_fail), e)
except Exception as e:
+ logger.info("Failed to send request: %s", e)
raise_from(RequestSendFailed(e, can_retry=True), e)
logger.info(
@@ -857,3 +839,16 @@ def encode_query_args(args):
query_bytes = urllib.parse.urlencode(encoded_args, True)
return query_bytes.encode('utf8')
+
+
+class QuieterFileBodyProducer(FileBodyProducer):
+ """Wrapper for FileBodyProducer that avoids CRITICAL errors when the connection drops.
+
+ Workaround for https://github.com/matrix-org/synapse/issues/4003 /
+ https://twistedmatrix.com/trac/ticket/6528
+ """
+ def stopProducing(self):
+ try:
+ FileBodyProducer.stopProducing(self)
+ except task.TaskStopped:
+ pass
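
A usage sketch of the wrapper (assuming a synapse 0.99 checkout on the path):
it is a drop-in replacement wherever a FileBodyProducer would be handed to an
Agent, and simply swallows the TaskStopped that Twisted raises when the
connection drops mid-upload.

    from io import BytesIO

    from synapse.http.matrixfederationclient import QuieterFileBodyProducer

    producer = QuieterFileBodyProducer(BytesIO(b'{"some": "json"}'))
    # agent.request(b"PUT", url, headers, producer), as in _send_request above
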
diff --git a/synapse/push/clientformat.py b/synapse/push/clientformat.py
index ecbf364a5e..8bd96b1178 100644
--- a/synapse/push/clientformat.py
+++ b/synapse/push/clientformat.py
@@ -84,7 +84,7 @@ def _rule_to_template(rule):
templaterule["pattern"] = thecond["pattern"]
if unscoped_rule_id:
- templaterule['rule_id'] = unscoped_rule_id
+ templaterule['rule_id'] = unscoped_rule_id
if 'default' in rule:
templaterule['default'] = rule['default']
return templaterule
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index b5555fbba9..5d087ee26b 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -79,6 +79,10 @@ CONDITIONAL_REQUIREMENTS = {
# ConsentResource uses select_autoescape, which arrived in jinja 2.9
"resources.consent": ["Jinja2>=2.9"],
+ # ACME support is required to provision TLS certificates from authorities
+ # that use the protocol, such as Let's Encrypt.
+ "acme": ["txacme>=0.9.2"],
+
"saml2": ["pysaml2>=4.5.0"],
"url_preview": ["lxml>=3.5.0"],
"test": ["mock>=2.0", "parameterized"],
diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index 5e5376cf58..e81456ab2b 100644
--- a/synapse/replication/http/_base.py
+++ b/synapse/replication/http/_base.py
@@ -127,7 +127,10 @@ class ReplicationEndpoint(object):
def send_request(**kwargs):
data = yield cls._serialize_payload(**kwargs)
- url_args = [urllib.parse.quote(kwargs[name]) for name in cls.PATH_ARGS]
+ url_args = [
+ urllib.parse.quote(kwargs[name], safe='')
+ for name in cls.PATH_ARGS
+ ]
if cls.CACHE:
txn_id = random_string(10)
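
The safe='' is the substance of this change: urllib.parse.quote leaves "/"
unescaped by default, so a path argument containing a slash would leak extra
segments into the replication URL.

    import urllib.parse

    print(urllib.parse.quote("@user:example.com/x"))
    # -> %40user%3Aexample.com/x        (slash survives, breaking the path)
    print(urllib.parse.quote("@user:example.com/x", safe=''))
    # -> %40user%3Aexample.com%2Fx      (slash escaped)
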
diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py
index 64a79da162..0f0a07c422 100644
--- a/synapse/replication/http/federation.py
+++ b/synapse/replication/http/federation.py
@@ -17,7 +17,7 @@ import logging
from twisted.internet import defer
-from synapse.events import FrozenEvent
+from synapse.events import event_type_from_format_version
from synapse.events.snapshot import EventContext
from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint
@@ -70,6 +70,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
event_payloads.append({
"event": event.get_pdu_json(),
+ "event_format_version": event.format_version,
"internal_metadata": event.internal_metadata.get_dict(),
"rejected_reason": event.rejected_reason,
"context": serialized_context,
@@ -94,9 +95,12 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
event_and_contexts = []
for event_payload in event_payloads:
event_dict = event_payload["event"]
+ format_ver = event_payload["event_format_version"]
internal_metadata = event_payload["internal_metadata"]
rejected_reason = event_payload["rejected_reason"]
- event = FrozenEvent(event_dict, internal_metadata, rejected_reason)
+
+ EventType = event_type_from_format_version(format_ver)
+ event = EventType(event_dict, internal_metadata, rejected_reason)
context = yield EventContext.deserialize(
self.store, event_payload["context"],
diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py
index 5b52c91650..3635015eda 100644
--- a/synapse/replication/http/send_event.py
+++ b/synapse/replication/http/send_event.py
@@ -17,7 +17,7 @@ import logging
from twisted.internet import defer
-from synapse.events import FrozenEvent
+from synapse.events import event_type_from_format_version
from synapse.events.snapshot import EventContext
from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint
@@ -74,6 +74,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
payload = {
"event": event.get_pdu_json(),
+ "event_format_version": event.format_version,
"internal_metadata": event.internal_metadata.get_dict(),
"rejected_reason": event.rejected_reason,
"context": serialized_context,
@@ -90,9 +91,12 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
content = parse_json_object_from_request(request)
event_dict = content["event"]
+ format_ver = content["event_format_version"]
internal_metadata = content["internal_metadata"]
rejected_reason = content["rejected_reason"]
- event = FrozenEvent(event_dict, internal_metadata, rejected_reason)
+
+ EventType = event_type_from_format_version(format_ver)
+ event = EventType(event_dict, internal_metadata, rejected_reason)
requester = Requester.deserialize(self.store, content["requester"])
context = yield EventContext.deserialize(self.store, content["context"])
diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py
index 66585c991f..91f5247d52 100644
--- a/synapse/rest/__init__.py
+++ b/synapse/rest/__init__.py
@@ -34,6 +34,7 @@ from synapse.rest.client.v2_alpha import (
account,
account_data,
auth,
+ capabilities,
devices,
filter,
groups,
@@ -107,3 +108,4 @@ class ClientRestResource(JsonResource):
user_directory.register_servlets(hs, client_resource)
groups.register_servlets(hs, client_resource)
room_upgrade_rest_servlet.register_servlets(hs, client_resource)
+ capabilities.register_servlets(hs, client_resource)
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index fcfe7857f6..48da4d557f 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -89,7 +89,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(RoomStateEventRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
- self.event_creation_hander = hs.get_event_creation_handler()
+ self.event_creation_handler = hs.get_event_creation_handler()
self.room_member_handler = hs.get_room_member_handler()
self.message_handler = hs.get_message_handler()
@@ -172,7 +172,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
content=content,
)
else:
- event = yield self.event_creation_hander.create_and_send_nonmember_event(
+ event = yield self.event_creation_handler.create_and_send_nonmember_event(
requester,
event_dict,
txn_id=txn_id,
@@ -189,7 +189,7 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(RoomSendEventRestServlet, self).__init__(hs)
- self.event_creation_hander = hs.get_event_creation_handler()
+ self.event_creation_handler = hs.get_event_creation_handler()
def register(self, http_server):
# /rooms/$roomid/send/$event_type[/$txn_id]
@@ -211,7 +211,7 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
if b'ts' in request.args and requester.app_service:
event_dict['origin_server_ts'] = parse_integer(request, "ts", 0)
- event = yield self.event_creation_hander.create_and_send_nonmember_event(
+ event = yield self.event_creation_handler.create_and_send_nonmember_event(
requester,
event_dict,
txn_id=txn_id,
diff --git a/synapse/rest/client/v2_alpha/capabilities.py b/synapse/rest/client/v2_alpha/capabilities.py
new file mode 100644
index 0000000000..373f95126e
--- /dev/null
+++ b/synapse/rest/client/v2_alpha/capabilities.py
@@ -0,0 +1,66 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from twisted.internet import defer
+
+from synapse.api.constants import DEFAULT_ROOM_VERSION, RoomDisposition, RoomVersions
+from synapse.http.servlet import RestServlet
+
+from ._base import client_v2_patterns
+
+logger = logging.getLogger(__name__)
+
+
+class CapabilitiesRestServlet(RestServlet):
+ """End point to expose the capabilities of the server."""
+
+ PATTERNS = client_v2_patterns("/capabilities$")
+
+ def __init__(self, hs):
+ """
+ Args:
+ hs (synapse.server.HomeServer): server
+ """
+ super(CapabilitiesRestServlet, self).__init__()
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.store = hs.get_datastore()
+
+ @defer.inlineCallbacks
+ def on_GET(self, request):
+ requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ user = yield self.store.get_user_by_id(requester.user.to_string())
+ change_password = bool(user["password_hash"])
+
+ response = {
+ "capabilities": {
+ "m.room_versions": {
+ "default": DEFAULT_ROOM_VERSION,
+ "available": {
+ RoomVersions.V1: RoomDisposition.STABLE,
+ RoomVersions.V2: RoomDisposition.STABLE,
+ RoomVersions.STATE_V2_TEST: RoomDisposition.UNSTABLE,
+ RoomVersions.V3: RoomDisposition.STABLE,
+ },
+ },
+ "m.change_password": {"enabled": change_password},
+ }
+ }
+ defer.returnValue((200, response))
+
+
+def register_servlets(hs, http_server):
+ CapabilitiesRestServlet(hs).register(http_server)
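
For orientation, a successful GET to this servlet (at
/_matrix/client/r0/capabilities) yields a body shaped like the dict below; the
concrete version strings are assumptions based on the RoomVersions and
RoomDisposition constants referenced above.

    {
        "capabilities": {
            "m.room_versions": {
                "default": "1",
                "available": {
                    "1": "stable",
                    "2": "stable",
                    "3": "stable",
                    "state-v2-test": "unstable",
                },
            },
            "m.change_password": {"enabled": True},
        }
    }
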
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index 14025cd219..7f812b8209 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -416,8 +416,11 @@ class RegisterRestServlet(RestServlet):
)
# Necessary due to auth checks prior to the threepid being
# written to the db
- if is_threepid_reserved(self.hs.config, threepid):
- yield self.store.upsert_monthly_active_user(registered_user_id)
+ if threepid:
+ if is_threepid_reserved(
+ self.hs.config.mau_limits_reserved_threepids, threepid
+ ):
+ yield self.store.upsert_monthly_active_user(registered_user_id)
# remember that we've now registered that user account, and with
# what user ID (since the user may not have specified)
diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py
index 0251146722..39d157a44b 100644
--- a/synapse/rest/client/v2_alpha/sync.py
+++ b/synapse/rest/client/v2_alpha/sync.py
@@ -75,7 +75,7 @@ class SyncRestServlet(RestServlet):
"""
PATTERNS = client_v2_patterns("/sync$")
- ALLOWED_PRESENCE = set(["online", "offline"])
+ ALLOWED_PRESENCE = set(["online", "offline", "unavailable"])
def __init__(self, hs):
super(SyncRestServlet, self).__init__()
diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py
index 29e62bfcdd..27e7cbf3cc 100644
--- a/synapse/rest/client/versions.py
+++ b/synapse/rest/client/versions.py
@@ -38,6 +38,7 @@ class VersionsRestServlet(RestServlet):
"r0.1.0",
"r0.2.0",
"r0.3.0",
+ "r0.4.0",
],
# as per MSC1497:
"unstable_features": {
diff --git a/synapse/rest/consent/consent_resource.py b/synapse/rest/consent/consent_resource.py
index 80611cfe84..008d4edae5 100644
--- a/synapse/rest/consent/consent_resource.py
+++ b/synapse/rest/consent/consent_resource.py
@@ -101,16 +101,7 @@ class ConsentResource(Resource):
"missing in config file.",
)
- # daemonize changes the cwd to /, so make the path absolute now.
- consent_template_directory = path.abspath(
- hs.config.user_consent_template_dir,
- )
- if not path.isdir(consent_template_directory):
- raise ConfigError(
- "Could not find template directory '%s'" % (
- consent_template_directory,
- ),
- )
+ consent_template_directory = hs.config.user_consent_template_dir
loader = jinja2.FileSystemLoader(consent_template_directory)
self._jinja_env = jinja2.Environment(
diff --git a/synapse/server.py b/synapse/server.py
index 9985687b95..6c52101616 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -46,6 +46,7 @@ from synapse.federation.transport.client import TransportLayerClient
from synapse.groups.attestations import GroupAttestationSigning, GroupAttestionRenewer
from synapse.groups.groups_server import GroupsServerHandler
from synapse.handlers import Handlers
+from synapse.handlers.acme import AcmeHandler
from synapse.handlers.appservice import ApplicationServicesHandler
from synapse.handlers.auth import AuthHandler, MacaroonGenerator
from synapse.handlers.deactivate_account import DeactivateAccountHandler
@@ -129,6 +130,7 @@ class HomeServer(object):
'sync_handler',
'typing_handler',
'room_list_handler',
+ 'acme_handler',
'auth_handler',
'device_handler',
'e2e_keys_handler',
@@ -310,6 +312,9 @@ class HomeServer(object):
def build_e2e_room_keys_handler(self):
return E2eRoomKeysHandler(self)
+ def build_acme_handler(self):
+ return AcmeHandler(self)
+
def build_application_service_api(self):
return ApplicationServiceApi(self)
@@ -350,10 +355,7 @@ class HomeServer(object):
return Keyring(self)
def build_event_builder_factory(self):
- return EventBuilderFactory(
- clock=self.get_clock(),
- hostname=self.hostname,
- )
+ return EventBuilderFactory(self)
def build_filtering(self):
return Filtering(self)
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index e9ecb00277..68058f613c 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -608,10 +608,10 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto
state_sets, event_map, state_res_store.get_events,
)
elif room_version in (
- RoomVersions.VDH_TEST, RoomVersions.STATE_V2_TEST, RoomVersions.V2,
+ RoomVersions.STATE_V2_TEST, RoomVersions.V2, RoomVersions.V3,
):
return v2.resolve_events_with_store(
- state_sets, event_map, state_res_store,
+ room_version, state_sets, event_map, state_res_store,
)
else:
# This should only happen if we added a version but forgot to add it to
diff --git a/synapse/state/v1.py b/synapse/state/v1.py
index 19e091ce3b..6d3afcae7c 100644
--- a/synapse/state/v1.py
+++ b/synapse/state/v1.py
@@ -21,7 +21,7 @@ from six import iteritems, iterkeys, itervalues
from twisted.internet import defer
from synapse import event_auth
-from synapse.api.constants import EventTypes
+from synapse.api.constants import EventTypes, RoomVersions
from synapse.api.errors import AuthError
logger = logging.getLogger(__name__)
@@ -274,7 +274,11 @@ def _resolve_auth_events(events, auth_events):
auth_events[(prev_event.type, prev_event.state_key)] = prev_event
try:
# The signatures have already been checked at this point
- event_auth.check(event, auth_events, do_sig_check=False, do_size_check=False)
+ event_auth.check(
+ RoomVersions.V1, event, auth_events,
+ do_sig_check=False,
+ do_size_check=False,
+ )
prev_event = event
except AuthError:
return prev_event
@@ -286,7 +290,11 @@ def _resolve_normal_events(events, auth_events):
for event in _ordered_events(events):
try:
# The signatures have already been checked at this point
- event_auth.check(event, auth_events, do_sig_check=False, do_size_check=False)
+ event_auth.check(
+ RoomVersions.V1, event, auth_events,
+ do_sig_check=False,
+ do_size_check=False,
+ )
return event
except AuthError:
pass
diff --git a/synapse/state/v2.py b/synapse/state/v2.py
index 3573bb0028..650995c92c 100644
--- a/synapse/state/v2.py
+++ b/synapse/state/v2.py
@@ -29,10 +29,12 @@ logger = logging.getLogger(__name__)
@defer.inlineCallbacks
-def resolve_events_with_store(state_sets, event_map, state_res_store):
+def resolve_events_with_store(room_version, state_sets, event_map, state_res_store):
"""Resolves the state using the v2 state resolution algorithm
Args:
+ room_version (str): The room version
+
state_sets(list): List of dicts of (type, state_key) -> event_id,
which are the different state groups to resolve.
@@ -104,7 +106,7 @@ def resolve_events_with_store(state_sets, event_map, state_res_store):
# Now sequentially auth each one
resolved_state = yield _iterative_auth_checks(
- sorted_power_events, unconflicted_state, event_map,
+ room_version, sorted_power_events, unconflicted_state, event_map,
state_res_store,
)
@@ -129,7 +131,7 @@ def resolve_events_with_store(state_sets, event_map, state_res_store):
logger.debug("resolving remaining events")
resolved_state = yield _iterative_auth_checks(
- leftover_events, resolved_state, event_map,
+ room_version, leftover_events, resolved_state, event_map,
state_res_store,
)
@@ -350,11 +352,13 @@ def _reverse_topological_power_sort(event_ids, event_map, state_res_store, auth_
@defer.inlineCallbacks
-def _iterative_auth_checks(event_ids, base_state, event_map, state_res_store):
+def _iterative_auth_checks(room_version, event_ids, base_state, event_map,
+ state_res_store):
"""Sequentially apply auth checks to each event in given list, updating the
state as it goes along.
Args:
+ room_version (str)
event_ids (list[str]): Ordered list of events to apply auth checks to
base_state (dict[tuple[str, str], str]): The set of state to start with
event_map (dict[str,FrozenEvent])
@@ -385,7 +389,7 @@ def _iterative_auth_checks(event_ids, base_state, event_map, state_res_store):
try:
event_auth.check(
- event, auth_events,
+ room_version, event, auth_events,
do_sig_check=False,
do_size_check=False
)
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 24329879e5..42cd3c83ad 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -317,7 +317,7 @@ class DataStore(RoomMemberStore, RoomStore,
thirty_days_ago_in_secs))
for row in txn:
- if row[0] is 'unknown':
+ if row[0] == 'unknown':
pass
results[row[0]] = row[1]
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 865b5e915a..e124161845 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -26,7 +26,8 @@ from prometheus_client import Histogram
from twisted.internet import defer
from synapse.api.errors import StoreError
-from synapse.storage.engines import PostgresEngine
+from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from synapse.util.caches.descriptors import Cache
from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
from synapse.util.stringutils import exception_to_unicode
@@ -49,6 +50,21 @@ sql_query_timer = Histogram("synapse_storage_query_time", "sec", ["verb"])
sql_txn_timer = Histogram("synapse_storage_transaction_time", "sec", ["desc"])
+# Unique indexes which have been added in background updates. Maps from table name
+# to the name of the background update which added the unique index to that table.
+#
+# This is used by the upsert logic to figure out which tables are safe to do a proper
+# UPSERT on: until the relevant background update has completed, we
+# have to emulate an upsert by locking the table.
+#
+UNIQUE_INDEX_BACKGROUND_UPDATES = {
+ "user_ips": "user_ips_device_unique_index",
+ "device_lists_remote_extremeties": "device_lists_remote_extremeties_unique_idx",
+ "device_lists_remote_cache": "device_lists_remote_cache_unique_idx",
+ "event_search": "event_search_event_id_idx",
+}
+
+
class LoggingTransaction(object):
"""An object that almost-transparently proxies for the 'txn' object
passed to the constructor. Adds logging and metrics to the .execute()
@@ -192,6 +208,57 @@ class SQLBaseStore(object):
self.database_engine = hs.database_engine
+ # A set of tables that are not safe to use native upserts in.
+ self._unsafe_to_upsert_tables = set(UNIQUE_INDEX_BACKGROUND_UPDATES.keys())
+
+ # We add the user_directory_search table to the blacklist on SQLite
+ # because the existing search table does not have an index, making it
+ # unsafe to use native upserts.
+ if isinstance(self.database_engine, Sqlite3Engine):
+ self._unsafe_to_upsert_tables.add("user_directory_search")
+
+ if self.database_engine.can_native_upsert:
+            # Check ASAP (and then every 15s thereafter) to see if we have finished
+ # background updates of tables that aren't safe to update.
+ self._clock.call_later(
+ 0.0,
+ run_as_background_process,
+ "upsert_safety_check",
+ self._check_safe_to_upsert
+ )
+
+ @defer.inlineCallbacks
+ def _check_safe_to_upsert(self):
+ """
+ Is it safe to use native UPSERT?
+
+ If there are background updates, we will need to wait, as they may be
+ the addition of indexes that set the UNIQUE constraint that we require.
+
+ If the background updates have not completed, wait 15 sec and check again.
+ """
+ updates = yield self._simple_select_list(
+ "background_updates",
+ keyvalues=None,
+ retcols=["update_name"],
+ desc="check_background_updates",
+ )
+ updates = [x["update_name"] for x in updates]
+
+ for table, update_name in UNIQUE_INDEX_BACKGROUND_UPDATES.items():
+ if update_name not in updates:
+ logger.debug("Now safe to upsert in %s", table)
+ self._unsafe_to_upsert_tables.discard(table)
+
+ # If there's any updates still running, reschedule to run.
+ if updates:
+ self._clock.call_later(
+ 15.0,
+ run_as_background_process,
+ "upsert_safety_check",
+ self._check_safe_to_upsert
+ )
+
def start_profiling(self):
self._previous_loop_ts = self._clock.time_msec()
@@ -494,8 +561,15 @@ class SQLBaseStore(object):
txn.executemany(sql, vals)
@defer.inlineCallbacks
- def _simple_upsert(self, table, keyvalues, values,
- insertion_values={}, desc="_simple_upsert", lock=True):
+ def _simple_upsert(
+ self,
+ table,
+ keyvalues,
+ values,
+ insertion_values={},
+ desc="_simple_upsert",
+ lock=True
+ ):
"""
`lock` should generally be set to True (the default), but can be set
@@ -516,16 +590,21 @@ class SQLBaseStore(object):
inserting
lock (bool): True to lock the table when doing the upsert.
Returns:
- Deferred(bool): True if a new entry was created, False if an
- existing one was updated.
+ Deferred(None or bool): Native upserts always return None. Emulated
+ upserts return True if a new entry was created, False if an existing
+ one was updated.
"""
attempts = 0
while True:
try:
result = yield self.runInteraction(
desc,
- self._simple_upsert_txn, table, keyvalues, values, insertion_values,
- lock=lock
+ self._simple_upsert_txn,
+ table,
+ keyvalues,
+ values,
+ insertion_values,
+ lock=lock,
)
defer.returnValue(result)
except self.database_engine.module.IntegrityError as e:
@@ -537,12 +616,71 @@ class SQLBaseStore(object):
# presumably we raced with another transaction: let's retry.
logger.warn(
- "IntegrityError when upserting into %s; retrying: %s",
- table, e
+ "%s when upserting into %s; retrying: %s", e.__name__, table, e
)
- def _simple_upsert_txn(self, txn, table, keyvalues, values, insertion_values={},
- lock=True):
+ def _simple_upsert_txn(
+ self,
+ txn,
+ table,
+ keyvalues,
+ values,
+ insertion_values={},
+ lock=True,
+ ):
+ """
+ Pick the UPSERT method which works best on the platform. Either the
+ native one (Pg9.5+, recent SQLites), or fall back to an emulated method.
+
+ Args:
+ txn: The transaction to use.
+ table (str): The table to upsert into
+            keyvalues (dict): The unique key columns and their new values
+ values (dict): The nonunique columns and their new values
+ insertion_values (dict): additional key/values to use only when
+ inserting
+ lock (bool): True to lock the table when doing the upsert.
+ Returns:
+ None or bool: Native upserts always return None. Emulated
+ upserts return True if a new entry was created, False if an existing
+ one was updated.
+ """
+ if (
+ self.database_engine.can_native_upsert
+ and table not in self._unsafe_to_upsert_tables
+ ):
+ return self._simple_upsert_txn_native_upsert(
+ txn,
+ table,
+ keyvalues,
+ values,
+ insertion_values=insertion_values,
+ )
+ else:
+ return self._simple_upsert_txn_emulated(
+ txn,
+ table,
+ keyvalues,
+ values,
+ insertion_values=insertion_values,
+ lock=lock,
+ )
+
+ def _simple_upsert_txn_emulated(
+ self, txn, table, keyvalues, values, insertion_values={}, lock=True
+ ):
+ """
+ Args:
+ table (str): The table to upsert into
+            keyvalues (dict): The unique key columns and their new values
+ values (dict): The nonunique columns and their new values
+ insertion_values (dict): additional key/values to use only when
+ inserting
+ lock (bool): True to lock the table when doing the upsert.
+ Returns:
+ bool: Return True if a new entry was created, False if an existing
+ one was updated.
+ """
# We need to lock the table :(, unless we're *really* careful
if lock:
self.database_engine.lock_table(txn, table)
@@ -577,12 +715,44 @@ class SQLBaseStore(object):
sql = "INSERT INTO %s (%s) VALUES (%s)" % (
table,
", ".join(k for k in allvalues),
- ", ".join("?" for _ in allvalues)
+ ", ".join("?" for _ in allvalues),
)
txn.execute(sql, list(allvalues.values()))
# successfully inserted
return True
+ def _simple_upsert_txn_native_upsert(
+ self, txn, table, keyvalues, values, insertion_values={}
+ ):
+ """
+ Use the native UPSERT functionality in recent PostgreSQL versions.
+
+ Args:
+ table (str): The table to upsert into
+            keyvalues (dict): The unique key columns and their new values
+ values (dict): The nonunique columns and their new values
+ insertion_values (dict): additional key/values to use only when
+ inserting
+ Returns:
+ None
+ """
+ allvalues = {}
+ allvalues.update(keyvalues)
+ allvalues.update(values)
+ allvalues.update(insertion_values)
+
+ sql = (
+ "INSERT INTO %s (%s) VALUES (%s) "
+ "ON CONFLICT (%s) DO UPDATE SET %s"
+ ) % (
+ table,
+ ", ".join(k for k in allvalues),
+ ", ".join("?" for _ in allvalues),
+ ", ".join(k for k in keyvalues),
+ ", ".join(k + "=EXCLUDED." + k for k in values),
+ )
+ txn.execute(sql, list(allvalues.values()))
+
def _simple_select_one(self, table, keyvalues, retcols,
allow_none=False, desc="_simple_select_one"):
"""Executes a SELECT query on the named table, which is expected to
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index 5fe1ca2de7..60cdc884e6 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -240,7 +240,7 @@ class BackgroundUpdateStore(SQLBaseStore):
* An integer count of the number of items to update in this batch.
The handler should return a deferred integer count of items updated.
- The hander is responsible for updating the progress of the update.
+ The handler is responsible for updating the progress of the update.
Args:
update_name(str): The name of the update that this code handles.
diff --git a/synapse/storage/client_ips.py b/synapse/storage/client_ips.py
index 78721a941a..091d7116c5 100644
--- a/synapse/storage/client_ips.py
+++ b/synapse/storage/client_ips.py
@@ -143,6 +143,11 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
# If it returns None, then we're processing the last batch
last = end_last_seen is None
+ logger.info(
+ "Scanning for duplicate 'user_ips' rows in range: %s <= last_seen < %s",
+ begin_last_seen, end_last_seen,
+ )
+
def remove(txn):
# This works by looking at all entries in the given time span, and
# then for each (user_id, access_token, ip) tuple in that range
@@ -170,7 +175,6 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
SELECT user_id, access_token, ip
FROM user_ips
WHERE {}
- ORDER BY last_seen
) c
INNER JOIN user_ips USING (user_id, access_token, ip)
GROUP BY user_id, access_token, ip
@@ -253,7 +257,10 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
)
def _update_client_ips_batch_txn(self, txn, to_update):
- self.database_engine.lock_table(txn, "user_ips")
+ if "user_ips" in self._unsafe_to_upsert_tables or (
+ not self.database_engine.can_native_upsert
+ ):
+ self.database_engine.lock_table(txn, "user_ips")
for entry in iteritems(to_update):
(user_id, access_token, ip), (user_agent, device_id, last_seen) = entry
diff --git a/synapse/storage/engines/__init__.py b/synapse/storage/engines/__init__.py
index e2f9de8451..ff5ef97ca8 100644
--- a/synapse/storage/engines/__init__.py
+++ b/synapse/storage/engines/__init__.py
@@ -18,7 +18,7 @@ import platform
from ._base import IncorrectDatabaseSetup
from .postgres import PostgresEngine
-from .sqlite3 import Sqlite3Engine
+from .sqlite import Sqlite3Engine
SUPPORTED_MODULE = {
"sqlite3": Sqlite3Engine,
diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py
index 42225f8a2a..4004427c7b 100644
--- a/synapse/storage/engines/postgres.py
+++ b/synapse/storage/engines/postgres.py
@@ -38,6 +38,13 @@ class PostgresEngine(object):
return sql.replace("?", "%s")
def on_new_connection(self, db_conn):
+
+ # Get the version of PostgreSQL that we're using. As per the psycopg2
+ # docs: The number is formed by converting the major, minor, and
+ # revision numbers into two-decimal-digit numbers and appending them
+ # together. For example, version 8.1.5 will be returned as 80105
+ self._version = db_conn.server_version
+
db_conn.set_isolation_level(
self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ
)
@@ -54,6 +61,13 @@ class PostgresEngine(object):
cursor.close()
+ @property
+ def can_native_upsert(self):
+ """
+ Can we use native UPSERTs? This requires PostgreSQL 9.5+.
+ """
+ return self._version >= 90500
+
def is_deadlock(self, error):
if isinstance(error, self.module.DatabaseError):
# https://www.postgresql.org/docs/current/static/errcodes-appendix.html
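
The 90500 threshold follows the psycopg2 encoding quoted in the comment above
(two decimal digits per component); a quick check of the arithmetic:

    def encode_pg_version(major, minor, revision):
        # e.g. 8.1.5 -> 80105, per the psycopg2 server_version docs
        return major * 10000 + minor * 100 + revision

    assert encode_pg_version(8, 1, 5) == 80105
    assert encode_pg_version(9, 5, 0) == 90500  # first release with ON CONFLICT
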
diff --git a/synapse/storage/engines/sqlite3.py b/synapse/storage/engines/sqlite.py
index 19949fc474..059ab81055 100644
--- a/synapse/storage/engines/sqlite3.py
+++ b/synapse/storage/engines/sqlite.py
@@ -30,6 +30,14 @@ class Sqlite3Engine(object):
self._current_state_group_id = None
self._current_state_group_id_lock = threading.Lock()
+ @property
+ def can_native_upsert(self):
+ """
+ Do we support native UPSERTs? This requires SQLite3 3.24+, plus some
+ more work we haven't done yet to tell what was inserted vs updated.
+ """
+ return self.module.sqlite_version_info >= (3, 24, 0)
+
def check_database(self, txn):
pass
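
The analogous SQLite gate can be checked directly against the loaded library
(ON CONFLICT ... DO UPDATE arrived in SQLite 3.24.0):

    import sqlite3

    print(sqlite3.sqlite_version_info >= (3, 24, 0))
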
diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py
index d3b9dea1d6..38809ed0fc 100644
--- a/synapse/storage/event_federation.py
+++ b/synapse/storage/event_federation.py
@@ -125,6 +125,29 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
return dict(txn)
+ @defer.inlineCallbacks
+ def get_max_depth_of(self, event_ids):
+ """Returns the max depth of a set of event IDs
+
+ Args:
+ event_ids (list[str])
+
+ Returns
+ Deferred[int]
+ """
+ rows = yield self._simple_select_many_batch(
+ table="events",
+ column="event_id",
+ iterable=event_ids,
+ retcols=("depth",),
+ desc="get_max_depth_of",
+ )
+
+ if not rows:
+ defer.returnValue(0)
+ else:
+ defer.returnValue(max(row["depth"] for row in rows))
+
def _get_oldest_events_in_room_txn(self, txn, room_id):
return self._simple_select_onecol_txn(
txn,
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 79e0276de6..81b250480d 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -904,106 +904,106 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
def _update_current_state_txn(self, txn, state_delta_by_room, max_stream_order):
for room_id, current_state_tuple in iteritems(state_delta_by_room):
- to_delete, to_insert = current_state_tuple
-
- # First we add entries to the current_state_delta_stream. We
- # do this before updating the current_state_events table so
- # that we can use it to calculate the `prev_event_id`. (This
- # allows us to not have to pull out the existing state
- # unnecessarily).
- sql = """
- INSERT INTO current_state_delta_stream
- (stream_id, room_id, type, state_key, event_id, prev_event_id)
- SELECT ?, ?, ?, ?, ?, (
- SELECT event_id FROM current_state_events
- WHERE room_id = ? AND type = ? AND state_key = ?
- )
- """
- txn.executemany(sql, (
- (
- max_stream_order, room_id, etype, state_key, None,
- room_id, etype, state_key,
- )
- for etype, state_key in to_delete
- # We sanity check that we're deleting rather than updating
- if (etype, state_key) not in to_insert
- ))
- txn.executemany(sql, (
- (
- max_stream_order, room_id, etype, state_key, ev_id,
- room_id, etype, state_key,
- )
- for (etype, state_key), ev_id in iteritems(to_insert)
- ))
+ to_delete, to_insert = current_state_tuple
- # Now we actually update the current_state_events table
-
- txn.executemany(
- "DELETE FROM current_state_events"
- " WHERE room_id = ? AND type = ? AND state_key = ?",
- (
- (room_id, etype, state_key)
- for etype, state_key in itertools.chain(to_delete, to_insert)
- ),
+ # First we add entries to the current_state_delta_stream. We
+ # do this before updating the current_state_events table so
+ # that we can use it to calculate the `prev_event_id`. (This
+ # allows us to not have to pull out the existing state
+ # unnecessarily).
+ sql = """
+ INSERT INTO current_state_delta_stream
+ (stream_id, room_id, type, state_key, event_id, prev_event_id)
+ SELECT ?, ?, ?, ?, ?, (
+ SELECT event_id FROM current_state_events
+ WHERE room_id = ? AND type = ? AND state_key = ?
)
-
- self._simple_insert_many_txn(
- txn,
- table="current_state_events",
- values=[
- {
- "event_id": ev_id,
- "room_id": room_id,
- "type": key[0],
- "state_key": key[1],
- }
- for key, ev_id in iteritems(to_insert)
- ],
+ """
+ txn.executemany(sql, (
+ (
+ max_stream_order, room_id, etype, state_key, None,
+ room_id, etype, state_key,
)
-
- txn.call_after(
- self._curr_state_delta_stream_cache.entity_has_changed,
- room_id, max_stream_order,
+ for etype, state_key in to_delete
+ # We sanity check that we're deleting rather than updating
+ if (etype, state_key) not in to_insert
+ ))
+ txn.executemany(sql, (
+ (
+ max_stream_order, room_id, etype, state_key, ev_id,
+ room_id, etype, state_key,
)
+ for (etype, state_key), ev_id in iteritems(to_insert)
+ ))
- # Invalidate the various caches
-
- # Figure out the changes of membership to invalidate the
- # `get_rooms_for_user` cache.
- # We find out which membership events we may have deleted
- # and which we have added, then we invlidate the caches for all
- # those users.
- members_changed = set(
- state_key
- for ev_type, state_key in itertools.chain(to_delete, to_insert)
- if ev_type == EventTypes.Member
- )
+ # Now we actually update the current_state_events table
- for member in members_changed:
- self._invalidate_cache_and_stream(
- txn, self.get_rooms_for_user_with_stream_ordering, (member,)
- )
+ txn.executemany(
+ "DELETE FROM current_state_events"
+ " WHERE room_id = ? AND type = ? AND state_key = ?",
+ (
+ (room_id, etype, state_key)
+ for etype, state_key in itertools.chain(to_delete, to_insert)
+ ),
+ )
- for host in set(get_domain_from_id(u) for u in members_changed):
- self._invalidate_cache_and_stream(
- txn, self.is_host_joined, (room_id, host)
- )
- self._invalidate_cache_and_stream(
- txn, self.was_host_joined, (room_id, host)
- )
+ self._simple_insert_many_txn(
+ txn,
+ table="current_state_events",
+ values=[
+ {
+ "event_id": ev_id,
+ "room_id": room_id,
+ "type": key[0],
+ "state_key": key[1],
+ }
+ for key, ev_id in iteritems(to_insert)
+ ],
+ )
+
+ txn.call_after(
+ self._curr_state_delta_stream_cache.entity_has_changed,
+ room_id, max_stream_order,
+ )
+
+ # Invalidate the various caches
+
+ # Figure out the changes of membership to invalidate the
+ # `get_rooms_for_user` cache.
+ # We find out which membership events we may have deleted
+            # and which we have added, then we invalidate the caches for all
+ # those users.
+ members_changed = set(
+ state_key
+ for ev_type, state_key in itertools.chain(to_delete, to_insert)
+ if ev_type == EventTypes.Member
+ )
+ for member in members_changed:
self._invalidate_cache_and_stream(
- txn, self.get_users_in_room, (room_id,)
+ txn, self.get_rooms_for_user_with_stream_ordering, (member,)
)
+ for host in set(get_domain_from_id(u) for u in members_changed):
self._invalidate_cache_and_stream(
- txn, self.get_room_summary, (room_id,)
+ txn, self.is_host_joined, (room_id, host)
)
-
self._invalidate_cache_and_stream(
- txn, self.get_current_state_ids, (room_id,)
+ txn, self.was_host_joined, (room_id, host)
)
+ self._invalidate_cache_and_stream(
+ txn, self.get_users_in_room, (room_id,)
+ )
+
+ self._invalidate_cache_and_stream(
+ txn, self.get_room_summary, (room_id,)
+ )
+
+ self._invalidate_cache_and_stream(
+ txn, self.get_current_state_ids, (room_id,)
+ )
+
def _update_forward_extremities_txn(self, txn, new_forward_extremities,
max_stream_order):
for room_id, new_extrem in iteritems(new_forward_extremities):
@@ -1268,6 +1268,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
event.internal_metadata.get_dict()
),
"json": encode_json(event_dict(event)),
+ "format_version": event.format_version,
}
for event, _ in events_and_contexts
],
diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py
index a8326f5296..1716be529a 100644
--- a/synapse/storage/events_worker.py
+++ b/synapse/storage/events_worker.py
@@ -21,13 +21,14 @@ from canonicaljson import json
from twisted.internet import defer
+from synapse.api.constants import EventFormatVersions, EventTypes
from synapse.api.errors import NotFoundError
+from synapse.events import FrozenEvent, event_type_from_format_version # noqa: F401
# these are only included to make the type annotations work
-from synapse.events import EventBase # noqa: F401
-from synapse.events import FrozenEvent
from synapse.events.snapshot import EventContext # noqa: F401
from synapse.events.utils import prune_event
from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.types import get_domain_from_id
from synapse.util.logcontext import (
LoggingContext,
PreserveLoggingContext,
@@ -160,9 +161,14 @@ class EventsWorkerStore(SQLBaseStore):
log_ctx = LoggingContext.current_context()
log_ctx.record_event_fetch(len(missing_events_ids))
+ # Note that _enqueue_events is also responsible for turning db rows
+ # into FrozenEvents (via _get_event_from_row), which involves seeing if
+ # the events have been redacted, and if so pulling the redaction event out
+ # of the database to check it.
+ #
+ # _enqueue_events is a bit of a rubbish name but naming is hard.
missing_events = yield self._enqueue_events(
missing_events_ids,
- check_redacted=check_redacted,
allow_rejected=allow_rejected,
)
@@ -174,6 +180,50 @@ class EventsWorkerStore(SQLBaseStore):
if not entry:
continue
+ # Starting in room version v3, some redactions need to be rechecked if we
+ # didn't have the redacted event at the time, so we recheck on read
+ # instead.
+ if not allow_rejected and entry.event.type == EventTypes.Redaction:
+ if entry.event.internal_metadata.need_to_check_redaction():
+ # XXX: we need to avoid calling get_event here.
+ #
+ # The problem is that we end up at this point when an event
+ # which has been redacted is pulled out of the database by
+ # _enqueue_events, because _enqueue_events needs to check the
+ # redaction before it can cache the redacted event. So obviously,
+ # calling get_event to get the redacted event out of the database
+ # gives us an infinite loop.
+ #
+ # For now (quick hack to fix during 0.99 release cycle), we just
+ # go and fetch the relevant row from the db, but it would be nice
+ # to think about how we can cache this rather than hit the db
+ # every time we access a redaction event.
+ #
+ # One thought on how to do this:
+ # 1. split _get_events up so that it is divided into (a) get the
+ # rawish event from the db/cache, (b) do the redaction/rejection
+ # filtering
+ # 2. have _get_event_from_row just call the first half of that
+
+ orig_sender = yield self._simple_select_one_onecol(
+ table="events",
+ keyvalues={"event_id": entry.event.redacts},
+ retcol="sender",
+ allow_none=True,
+ )
+
+ expected_domain = get_domain_from_id(entry.event.sender)
+ if orig_sender and get_domain_from_id(orig_sender) == expected_domain:
+ # This redaction event is allowed. Mark as not needing a
+ # recheck.
+ entry.event.internal_metadata.recheck_redaction = False
+ else:
+ # We don't have the event that is being redacted, so we
+ # assume that the event isn't authorized for now. (If we
+ # later receive the event, then we will always redact
+ # it anyway, since we have this redaction)
+ continue
+
if allow_rejected or not entry.event.rejected_reason:
if check_redacted and entry.redacted_event:
event = entry.redacted_event
@@ -197,7 +247,7 @@ class EventsWorkerStore(SQLBaseStore):
defer.returnValue(events)
def _invalidate_get_event_cache(self, event_id):
- self._get_event_cache.invalidate((event_id,))
+ self._get_event_cache.invalidate((event_id,))
def _get_events_from_cache(self, events, allow_rejected, update_metrics=True):
"""Fetch events from the caches
@@ -310,7 +360,7 @@ class EventsWorkerStore(SQLBaseStore):
self.hs.get_reactor().callFromThread(fire, event_list, e)
@defer.inlineCallbacks
- def _enqueue_events(self, events, check_redacted=True, allow_rejected=False):
+ def _enqueue_events(self, events, allow_rejected=False):
"""Fetches events from the database using the _event_fetch_list. This
allows batch and bulk fetching of events - it allows us to fetch events
without having to create a new transaction for each request for events.
@@ -353,6 +403,7 @@ class EventsWorkerStore(SQLBaseStore):
self._get_event_from_row,
row["internal_metadata"], row["json"], row["redacts"],
rejected_reason=row["rejects"],
+ format_version=row["format_version"],
)
for row in rows
],
@@ -377,6 +428,7 @@ class EventsWorkerStore(SQLBaseStore):
" e.event_id as event_id, "
" e.internal_metadata,"
" e.json,"
+ " e.format_version, "
" r.redacts as redacts,"
" rej.event_id as rejects "
" FROM event_json as e"
@@ -392,7 +444,7 @@ class EventsWorkerStore(SQLBaseStore):
@defer.inlineCallbacks
def _get_event_from_row(self, internal_metadata, js, redacted,
- rejected_reason=None):
+ format_version, rejected_reason=None):
with Measure(self._clock, "_get_event_from_row"):
d = json.loads(js)
internal_metadata = json.loads(internal_metadata)
@@ -405,8 +457,13 @@ class EventsWorkerStore(SQLBaseStore):
desc="_get_event_from_row_rejected_reason",
)
- original_ev = FrozenEvent(
- d,
+ if format_version is None:
+ # This means that we stored the event before we had the concept
+            # of an event format version, so it must be a V1 event.
+ format_version = EventFormatVersions.V1
+
+ original_ev = event_type_from_format_version(format_version)(
+ event_dict=d,
internal_metadata_dict=internal_metadata,
rejected_reason=rejected_reason,
)
@@ -436,6 +493,19 @@ class EventsWorkerStore(SQLBaseStore):
# will serialise this field correctly
redacted_event.unsigned["redacted_because"] = because
+ # Starting in room version v3, some redactions need to be
+ # rechecked if we didn't have the redacted event at the
+ # time, so we recheck on read instead.
+ if because.internal_metadata.need_to_check_redaction():
+ expected_domain = get_domain_from_id(original_ev.sender)
+ if get_domain_from_id(because.sender) == expected_domain:
+ # This redaction event is allowed. Mark as not needing a
+ # recheck.
+ because.internal_metadata.recheck_redaction = False
+ else:
+ # Senders don't match, so the event isn't actually redacted
+ redacted_event = None
+
cache_entry = _EventCacheEntry(
event=original_ev,
redacted_event=redacted_event,
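
Both recheck paths above apply the same rule: a redaction whose target we did
not have at the time is honoured only if the redaction's sender and the
original event's sender are on the same server. A self-contained sketch of that
comparison, with get_domain_from_id re-implemented locally as a stand-in for
synapse.types.get_domain_from_id:

    def get_domain_from_id(user_id):
        # "@alice:example.com" -> "example.com"
        return user_id.split(":", 1)[1]

    def redaction_permitted(redaction_sender, original_sender):
        # Room-v3-era recheck: only same-domain redactions are honoured when
        # the redacted event was unavailable when the redaction arrived.
        return get_domain_from_id(redaction_sender) == get_domain_from_id(original_sender)

    assert redaction_permitted("@a:example.com", "@b:example.com")
    assert not redaction_permitted("@a:example.com", "@b:other.org")
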
diff --git a/synapse/storage/monthly_active_users.py b/synapse/storage/monthly_active_users.py
index d6fc8edd4c..9e7e09b8c1 100644
--- a/synapse/storage/monthly_active_users.py
+++ b/synapse/storage/monthly_active_users.py
@@ -197,15 +197,21 @@ class MonthlyActiveUsersStore(SQLBaseStore):
if is_support:
return
- is_insert = yield self.runInteraction(
+ yield self.runInteraction(
"upsert_monthly_active_user", self.upsert_monthly_active_user_txn,
user_id
)
- if is_insert:
- self.user_last_seen_monthly_active.invalidate((user_id,))
+ user_in_mau = self.user_last_seen_monthly_active.cache.get(
+ (user_id,),
+ None,
+ update_metrics=False
+ )
+ if user_in_mau is None:
self.get_monthly_active_count.invalidate(())
+ self.user_last_seen_monthly_active.invalidate((user_id,))
+
def upsert_monthly_active_user_txn(self, txn, user_id):
"""Updates or inserts monthly active user member
diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py
index 2743b52bad..134297e284 100644
--- a/synapse/storage/pusher.py
+++ b/synapse/storage/pusher.py
@@ -215,7 +215,7 @@ class PusherStore(PusherWorkerStore):
with self._pushers_id_gen.get_next() as stream_id:
# no need to lock because `pushers` has a unique key on
# (app_id, pushkey, user_name) so _simple_upsert will retry
- newly_inserted = yield self._simple_upsert(
+ yield self._simple_upsert(
table="pushers",
keyvalues={
"app_id": app_id,
@@ -238,7 +238,12 @@ class PusherStore(PusherWorkerStore):
lock=False,
)
- if newly_inserted:
+ user_has_pusher = self.get_if_user_has_pusher.cache.get(
+ (user_id,), None, update_metrics=False
+ )
+
+ if user_has_pusher is not True:
+        # invalidate, since the user might not have had a pusher before
yield self.runInteraction(
"add_pusher",
self._invalidate_cache_and_stream,
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 0707f9a86a..592c1bcd33 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -588,12 +588,12 @@ class RoomMemberStore(RoomMemberWorkerStore):
)
# We update the local_invites table only if the event is "current",
- # i.e., its something that has just happened.
- # The only current event that can also be an outlier is if its an
- # invite that has come in across federation.
+        # i.e., it's something that has just happened. If the event is an
+        # outlier it is only current if it's an "out of band membership",
+ # like a remote invite or a rejection of a remote invite.
is_new_state = not backfilled and (
not event.internal_metadata.is_outlier()
- or event.internal_metadata.is_invite_from_remote()
+ or event.internal_metadata.is_out_of_band_membership()
)
is_mine = self.hs.is_mine_id(event.state_key)
if is_new_state and is_mine:
diff --git a/synapse/storage/schema/delta/53/event_format_version.sql b/synapse/storage/schema/delta/53/event_format_version.sql
new file mode 100644
index 0000000000..1d977c2834
--- /dev/null
+++ b/synapse/storage/schema/delta/53/event_format_version.sql
@@ -0,0 +1,16 @@
+/* Copyright 2019 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ALTER TABLE event_json ADD COLUMN format_version INTEGER;
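
Rows written before this migration have a NULL format_version; as the
events_worker change earlier in this diff shows, NULL is read back as a V1
event. A one-line sketch of that interpretation (the constant is an assumption
mirroring EventFormatVersions.V1):

    V1 = 1  # assumed value of EventFormatVersions.V1

    def effective_format_version(stored):
        # events stored before the column existed are, by definition, V1
        return V1 if stored is None else stored
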
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index a134e9b3e8..d14a7b2538 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -428,14 +428,54 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
"""
# for now we do this by looking at the create event. We may want to cache this
# more intelligently in future.
+
+ # Retrieve the room's create event
+ create_event = yield self.get_create_event_for_room(room_id)
+ defer.returnValue(create_event.content.get("room_version", "1"))
+
+ @defer.inlineCallbacks
+ def get_room_predecessor(self, room_id):
+ """Get the predecessor room of an upgraded room if one exists.
+ Otherwise return None.
+
+ Args:
+ room_id (str)
+
+ Returns:
+            Deferred[dict|None]: the create event's "predecessor" content
+                (a dict with "room_id" and "event_id" keys), or None
+
+ Raises:
+ NotFoundError if the room is unknown
+ """
+ # Retrieve the room's create event
+ create_event = yield self.get_create_event_for_room(room_id)
+
+ # Return predecessor if present
+ defer.returnValue(create_event.content.get("predecessor", None))
+
+ @defer.inlineCallbacks
+ def get_create_event_for_room(self, room_id):
+ """Get the create state event for a room.
+
+ Args:
+ room_id (str)
+
+ Returns:
+ Deferred[EventBase]: The room creation event.
+
+ Raises:
+ NotFoundError if the room is unknown
+ """
state_ids = yield self.get_current_state_ids(room_id)
create_id = state_ids.get((EventTypes.Create, ""))
+ # If we can't find the create event, assume we've hit a dead end
if not create_id:
raise NotFoundError("Unknown room %s" % (room_id))
+ # Retrieve the room's create event and return
create_event = yield self.get_event(create_id)
- defer.returnValue(create_event.content.get("room_version", "1"))
+ defer.returnValue(create_event)
@cached(max_entries=100000, iterable=True)
def get_current_state_ids(self, room_id):
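
For orientation, the fields these helpers read all live in the create event's
content; an illustrative shape (all values invented):

    create_content = {
        "creator": "@alice:example.com",
        "room_version": "3",          # get_room_version falls back to "1" if absent
        "predecessor": {              # present only for upgraded rooms
            "room_id": "!old:example.com",
            "event_id": "$tombstone:example.com",
        },
    }
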
diff --git a/synapse/storage/user_directory.py b/synapse/storage/user_directory.py
index a8781b0e5d..e8b574ee5e 100644
--- a/synapse/storage/user_directory.py
+++ b/synapse/storage/user_directory.py
@@ -22,6 +22,7 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, JoinRules
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
+from synapse.storage.state import StateFilter
from synapse.types import get_domain_from_id, get_localpart_from_id
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
@@ -31,12 +32,19 @@ logger = logging.getLogger(__name__)
class UserDirectoryStore(SQLBaseStore):
- @cachedInlineCallbacks(cache_context=True)
- def is_room_world_readable_or_publicly_joinable(self, room_id, cache_context):
+ @defer.inlineCallbacks
+ def is_room_world_readable_or_publicly_joinable(self, room_id):
"""Check if the room is either world_readable or publically joinable
"""
- current_state_ids = yield self.get_current_state_ids(
- room_id, on_invalidate=cache_context.invalidate
+
+        # Create a state filter that only queries the join and history state events
+ types_to_filter = (
+ (EventTypes.JoinRules, ""),
+ (EventTypes.RoomHistoryVisibility, ""),
+ )
+
+ current_state_ids = yield self.get_filtered_current_state_ids(
+ room_id, StateFilter.from_types(types_to_filter),
)
join_rules_id = current_state_ids.get((EventTypes.JoinRules, ""))
@@ -168,14 +176,14 @@ class UserDirectoryStore(SQLBaseStore):
if isinstance(self.database_engine, PostgresEngine):
# We weight the localpart most highly, then display name and finally
# server name
- if new_entry:
+ if self.database_engine.can_native_upsert:
sql = """
INSERT INTO user_directory_search(user_id, vector)
VALUES (?,
setweight(to_tsvector('english', ?), 'A')
|| setweight(to_tsvector('english', ?), 'D')
|| setweight(to_tsvector('english', COALESCE(?, '')), 'B')
- )
+ ) ON CONFLICT (user_id) DO UPDATE SET vector=EXCLUDED.vector
"""
txn.execute(
sql,
@@ -185,20 +193,45 @@ class UserDirectoryStore(SQLBaseStore):
)
)
else:
- sql = """
- UPDATE user_directory_search
- SET vector = setweight(to_tsvector('english', ?), 'A')
- || setweight(to_tsvector('english', ?), 'D')
- || setweight(to_tsvector('english', COALESCE(?, '')), 'B')
- WHERE user_id = ?
- """
- txn.execute(
- sql,
- (
- get_localpart_from_id(user_id), get_domain_from_id(user_id),
- display_name, user_id,
+ # TODO: Remove this code after we've bumped the minimum version
+ # of postgres to always support upserts, so we can get rid of
+ # `new_entry` usage
+ if new_entry is True:
+ sql = """
+ INSERT INTO user_directory_search(user_id, vector)
+ VALUES (?,
+ setweight(to_tsvector('english', ?), 'A')
+ || setweight(to_tsvector('english', ?), 'D')
+ || setweight(to_tsvector('english', COALESCE(?, '')), 'B')
+ )
+ """
+ txn.execute(
+ sql,
+ (
+ user_id, get_localpart_from_id(user_id),
+ get_domain_from_id(user_id), display_name,
+ )
+ )
+ elif new_entry is False:
+ sql = """
+ UPDATE user_directory_search
+ SET vector = setweight(to_tsvector('english', ?), 'A')
+ || setweight(to_tsvector('english', ?), 'D')
+ || setweight(to_tsvector('english', COALESCE(?, '')), 'B')
+ WHERE user_id = ?
+ """
+ txn.execute(
+ sql,
+ (
+ get_localpart_from_id(user_id),
+ get_domain_from_id(user_id),
+ display_name, user_id,
+ )
+ )
+ else:
+ raise RuntimeError(
+ "upsert returned None when 'can_native_upsert' is False"
)
- )
elif isinstance(self.database_engine, Sqlite3Engine):
value = "%s %s" % (user_id, display_name,) if display_name else user_id
self._simple_upsert_txn(
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 430bb15f51..f0e4a0e10c 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -201,7 +201,7 @@ class Linearizer(object):
if entry[0] >= self.max_count:
res = self._await_lock(key)
else:
- logger.info(
+ logger.debug(
"Acquired uncontended linearizer lock %r for key %r", self.name, key,
)
entry[0] += 1
@@ -215,7 +215,7 @@ class Linearizer(object):
try:
yield
finally:
- logger.info("Releasing linearizer lock %r for key %r", self.name, key)
+ logger.debug("Releasing linearizer lock %r for key %r", self.name, key)
# We've finished executing so check if there are any things
# blocked waiting to execute and start one of them
@@ -247,7 +247,7 @@ class Linearizer(object):
"""
entry = self.key_to_defer[key]
- logger.info(
+ logger.debug(
"Waiting to acquire linearizer lock %r for key %r", self.name, key,
)
@@ -255,7 +255,7 @@ class Linearizer(object):
entry[1][new_defer] = 1
def cb(_r):
- logger.info("Acquired linearizer lock %r for key %r", self.name, key)
+ logger.debug("Acquired linearizer lock %r for key %r", self.name, key)
entry[0] += 1
# if the code holding the lock completes synchronously, then it
@@ -273,7 +273,7 @@ class Linearizer(object):
def eb(e):
logger.info("defer %r got err %r", new_defer, e)
if isinstance(e, CancelledError):
- logger.info(
+ logger.debug(
"Cancelling wait for linearizer lock %r for key %r",
self.name, key,
)
diff --git a/synapse/util/caches/ttlcache.py b/synapse/util/caches/ttlcache.py
new file mode 100644
index 0000000000..5ba1862506
--- /dev/null
+++ b/synapse/util/caches/ttlcache.py
@@ -0,0 +1,161 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015, 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import time
+
+import attr
+from sortedcontainers import SortedList
+
+from synapse.util.caches import register_cache
+
+logger = logging.getLogger(__name__)
+
+SENTINEL = object()
+
+
+class TTLCache(object):
+ """A key/value cache implementation where each entry has its own TTL"""
+
+ def __init__(self, cache_name, timer=time.time):
+ # map from key to _CacheEntry
+ self._data = {}
+
+ # the _CacheEntries, sorted by expiry time
+ self._expiry_list = SortedList()
+
+ self._timer = timer
+
+ self._metrics = register_cache("ttl", cache_name, self)
+
+ def set(self, key, value, ttl):
+ """Add/update an entry in the cache
+
+ Args:
+ key: key for this entry
+ value: value for this entry
+ ttl (float): TTL for this entry, in seconds
+ """
+ expiry = self._timer() + ttl
+
+ self.expire()
+ e = self._data.pop(key, SENTINEL)
+ if e != SENTINEL:
+ self._expiry_list.remove(e)
+
+ entry = _CacheEntry(expiry_time=expiry, key=key, value=value)
+ self._data[key] = entry
+ self._expiry_list.add(entry)
+
+ def get(self, key, default=SENTINEL):
+ """Get a value from the cache
+
+ Args:
+ key: key to look up
+ default: default value to return, if key is not found. If not set, and the
+ key is not found, a KeyError will be raised
+
+ Returns:
+ value from the cache, or the default
+ """
+ self.expire()
+ e = self._data.get(key, SENTINEL)
+ if e == SENTINEL:
+ self._metrics.inc_misses()
+ if default == SENTINEL:
+ raise KeyError(key)
+ return default
+ self._metrics.inc_hits()
+ return e.value
+
+ def get_with_expiry(self, key):
+ """Get a value, and its expiry time, from the cache
+
+ Args:
+ key: key to look up
+
+ Returns:
+ Tuple[Any, float]: the value from the cache, and the expiry time
+
+ Raises:
+ KeyError if the entry is not found
+ """
+ self.expire()
+ try:
+ e = self._data[key]
+ except KeyError:
+ self._metrics.inc_misses()
+ raise
+ self._metrics.inc_hits()
+ return e.value, e.expiry_time
+
+ def pop(self, key, default=SENTINEL):
+ """Remove a value from the cache
+
+ If key is in the cache, remove it and return its value, else return default.
+ If default is not given and key is not in the cache, a KeyError is raised.
+
+ Args:
+ key: key to look up
+ default: default value to return, if key is not found. If not set, and the
+ key is not found, a KeyError will be raised
+
+ Returns:
+ value from the cache, or the default
+ """
+ self.expire()
+ e = self._data.pop(key, SENTINEL)
+ if e == SENTINEL:
+ self._metrics.inc_misses()
+ if default == SENTINEL:
+ raise KeyError(key)
+ return default
+ self._expiry_list.remove(e)
+ self._metrics.inc_hits()
+ return e.value
+
+ def __getitem__(self, key):
+ return self.get(key)
+
+ def __delitem__(self, key):
+ self.pop(key)
+
+ def __contains__(self, key):
+ return key in self._data
+
+ def __len__(self):
+ self.expire()
+ return len(self._data)
+
+ def expire(self):
+ """Run the expiry on the cache. Any entries whose expiry times are due will
+ be removed
+ """
+ now = self._timer()
+ while self._expiry_list:
+ first_entry = self._expiry_list[0]
+ if first_entry.expiry_time - now > 0.0:
+ break
+ del self._data[first_entry.key]
+ del self._expiry_list[0]
+
+
+@attr.s(frozen=True, slots=True)
+class _CacheEntry(object):
+ """TTLCache entry"""
+ # expiry_time is the first attribute, so that entries are sorted by expiry.
+ expiry_time = attr.ib()
+ key = attr.ib()
+ value = attr.ib()
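
A usage sketch of the new cache with an injected clock (assuming a synapse
checkout, so that this module and sortedcontainers are importable):

    from synapse.util.caches.ttlcache import TTLCache

    now = [100.0]
    cache = TTLCache("demo", timer=lambda: now[0])

    cache.set("key", "value", ttl=10)
    assert cache.get("key") == "value"

    now[0] += 11  # advance the fake clock past the entry's TTL
    assert cache.get("key", default=None) is None  # expired on access
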
diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py
index 4c6e92beb8..311b49e18a 100644
--- a/synapse/util/logcontext.py
+++ b/synapse/util/logcontext.py
@@ -285,7 +285,10 @@ class LoggingContext(object):
self.alive = False
# if we have a parent, pass our CPU usage stats on
- if self.parent_context is not None:
+ if (
+ self.parent_context is not None
+ and hasattr(self.parent_context, '_resource_usage')
+ ):
self.parent_context._resource_usage += self._resource_usage
# reset them in case we get entered again
|