diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index 49c4b85054..e3f086f1c3 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -94,6 +94,8 @@ class EventTypes(object):
ServerACL = "m.room.server_acl"
Pinned = "m.room.pinned_events"
+ Retention = "m.room.retention"
+
class RejectedReason(object):
AUTH_ERROR = "auth_error"
diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py
index 139221ad34..448e45e00f 100644
--- a/synapse/app/federation_sender.py
+++ b/synapse/app/federation_sender.py
@@ -69,7 +69,7 @@ class FederationSenderSlaveStore(
self.federation_out_pos_startup = self._get_federation_out_pos(db_conn)
def _get_federation_out_pos(self, db_conn):
- sql = "SELECT stream_id FROM federation_stream_position" " WHERE type = ?"
+ sql = "SELECT stream_id FROM federation_stream_position WHERE type = ?"
sql = self.database_engine.convert_param_style(sql)
txn = db_conn.cursor()
diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index 3e25bf5747..57174da021 100644
--- a/synapse/appservice/api.py
+++ b/synapse/appservice/api.py
@@ -185,7 +185,7 @@ class ApplicationServiceApi(SimpleHttpClient):
if not _is_valid_3pe_metadata(info):
logger.warning(
- "query_3pe_protocol to %s did not return a" " valid result", uri
+ "query_3pe_protocol to %s did not return a valid result", uri
)
return None
diff --git a/synapse/config/appservice.py b/synapse/config/appservice.py
index e77d3387ff..ca43e96bd1 100644
--- a/synapse/config/appservice.py
+++ b/synapse/config/appservice.py
@@ -134,7 +134,7 @@ def _load_appservice(hostname, as_info, config_filename):
for regex_obj in as_info["namespaces"][ns]:
if not isinstance(regex_obj, dict):
raise ValueError(
- "Expected namespace entry in %s to be an object," " but got %s",
+ "Expected namespace entry in %s to be an object, but got %s",
ns,
regex_obj,
)
diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py
index 43fad0bf8b..18f42a87f9 100644
--- a/synapse/config/emailconfig.py
+++ b/synapse/config/emailconfig.py
@@ -146,6 +146,8 @@ class EmailConfig(Config):
if k not in email_config:
missing.append("email." + k)
+ # public_baseurl is required to build password reset and validation links that
+ # will be emailed to users
if config.get("public_baseurl") is None:
missing.append("public_baseurl")
@@ -305,8 +307,23 @@ class EmailConfig(Config):
# smtp_user: "exampleusername"
# smtp_pass: "examplepassword"
# require_transport_security: false
+ #
+ # # notif_from defines the "From" address to use when sending emails.
+ # # It must be set if email sending is enabled.
+ # #
+ # # The placeholder '%(app)s' will be replaced by the application name,
+ # # which is normally 'app_name' (below), but may be overridden by the
+ # # Matrix client application.
+ # #
+ # # Note that the placeholder must be written '%(app)s', including the
+ # # trailing 's'.
+ # #
# notif_from: "Your Friendly %(app)s homeserver <noreply@example.com>"
- # app_name: Matrix
+ #
+ # # app_name defines the default value for '%(app)s' in notif_from. It
+ # # defaults to 'Matrix'.
+ # #
+ # #app_name: my_branded_matrix_server
#
# # Enable email notifications by default
# #
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index 1f6dac69da..ee9614c5f7 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -106,6 +106,13 @@ class RegistrationConfig(Config):
account_threepid_delegates = config.get("account_threepid_delegates") or {}
self.account_threepid_delegate_email = account_threepid_delegates.get("email")
self.account_threepid_delegate_msisdn = account_threepid_delegates.get("msisdn")
+ if self.account_threepid_delegate_msisdn and not self.public_baseurl:
+ raise ConfigError(
+ "The configuration option `public_baseurl` is required if "
+ "`account_threepid_delegate.msisdn` is set, such that "
+ "clients know where to submit validation tokens to. Please "
+ "configure `public_baseurl`."
+ )
self.default_identity_server = config.get("default_identity_server")
self.allow_guest_access = config.get("allow_guest_access", False)
diff --git a/synapse/config/room_directory.py b/synapse/config/room_directory.py
index 7c9f05bde4..7ac7699676 100644
--- a/synapse/config/room_directory.py
+++ b/synapse/config/room_directory.py
@@ -170,7 +170,7 @@ class _RoomDirectoryRule(object):
self.action = action
else:
raise ConfigError(
- "%s rules can only have action of 'allow'" " or 'deny'" % (option_name,)
+ "%s rules can only have action of 'allow' or 'deny'" % (option_name,)
)
self._alias_matches_all = alias == "*"
diff --git a/synapse/config/server.py b/synapse/config/server.py
index 00d01c43af..7a9d711669 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -19,7 +19,7 @@ import logging
import os.path
import re
from textwrap import indent
-from typing import List
+from typing import Dict, List, Optional
import attr
import yaml
@@ -223,7 +223,7 @@ class ServerConfig(Config):
self.federation_ip_range_blacklist.update(["0.0.0.0", "::"])
except Exception as e:
raise ConfigError(
- "Invalid range(s) provided in " "federation_ip_range_blacklist: %s" % e
+ "Invalid range(s) provided in federation_ip_range_blacklist: %s" % e
)
if self.public_baseurl is not None:
@@ -246,6 +246,124 @@ class ServerConfig(Config):
# events with profile information that differ from the target's global profile.
self.allow_per_room_profiles = config.get("allow_per_room_profiles", True)
+ retention_config = config.get("retention")
+ if retention_config is None:
+ retention_config = {}
+
+ self.retention_enabled = retention_config.get("enabled", False)
+
+ retention_default_policy = retention_config.get("default_policy")
+
+ if retention_default_policy is not None:
+ self.retention_default_min_lifetime = retention_default_policy.get(
+ "min_lifetime"
+ )
+ if self.retention_default_min_lifetime is not None:
+ self.retention_default_min_lifetime = self.parse_duration(
+ self.retention_default_min_lifetime
+ )
+
+ self.retention_default_max_lifetime = retention_default_policy.get(
+ "max_lifetime"
+ )
+ if self.retention_default_max_lifetime is not None:
+ self.retention_default_max_lifetime = self.parse_duration(
+ self.retention_default_max_lifetime
+ )
+
+ if (
+ self.retention_default_min_lifetime is not None
+ and self.retention_default_max_lifetime is not None
+ and (
+ self.retention_default_min_lifetime
+ > self.retention_default_max_lifetime
+ )
+ ):
+ raise ConfigError(
+ "The default retention policy's 'min_lifetime' can not be greater"
+ " than its 'max_lifetime'"
+ )
+ else:
+ self.retention_default_min_lifetime = None
+ self.retention_default_max_lifetime = None
+
+ self.retention_allowed_lifetime_min = retention_config.get(
+ "allowed_lifetime_min"
+ )
+ if self.retention_allowed_lifetime_min is not None:
+ self.retention_allowed_lifetime_min = self.parse_duration(
+ self.retention_allowed_lifetime_min
+ )
+
+ self.retention_allowed_lifetime_max = retention_config.get(
+ "allowed_lifetime_max"
+ )
+ if self.retention_allowed_lifetime_max is not None:
+ self.retention_allowed_lifetime_max = self.parse_duration(
+ self.retention_allowed_lifetime_max
+ )
+
+ if (
+ self.retention_allowed_lifetime_min is not None
+ and self.retention_allowed_lifetime_max is not None
+ and self.retention_allowed_lifetime_min
+ > self.retention_allowed_lifetime_max
+ ):
+ raise ConfigError(
+ "Invalid retention policy limits: 'allowed_lifetime_min' can not be"
+ " greater than 'allowed_lifetime_max'"
+ )
+
+ self.retention_purge_jobs = [] # type: List[Dict[str, Optional[int]]]
+ for purge_job_config in retention_config.get("purge_jobs", []):
+ interval_config = purge_job_config.get("interval")
+
+ if interval_config is None:
+ raise ConfigError(
+ "A retention policy's purge jobs configuration must have the"
+ " 'interval' key set."
+ )
+
+ interval = self.parse_duration(interval_config)
+
+ shortest_max_lifetime = purge_job_config.get("shortest_max_lifetime")
+
+ if shortest_max_lifetime is not None:
+ shortest_max_lifetime = self.parse_duration(shortest_max_lifetime)
+
+ longest_max_lifetime = purge_job_config.get("longest_max_lifetime")
+
+ if longest_max_lifetime is not None:
+ longest_max_lifetime = self.parse_duration(longest_max_lifetime)
+
+ if (
+ shortest_max_lifetime is not None
+ and longest_max_lifetime is not None
+ and shortest_max_lifetime > longest_max_lifetime
+ ):
+ raise ConfigError(
+ "A retention policy's purge jobs configuration's"
+ " 'shortest_max_lifetime' value can not be greater than its"
+ " 'longest_max_lifetime' value."
+ )
+
+ self.retention_purge_jobs.append(
+ {
+ "interval": interval,
+ "shortest_max_lifetime": shortest_max_lifetime,
+ "longest_max_lifetime": longest_max_lifetime,
+ }
+ )
+
+ if not self.retention_purge_jobs:
+ self.retention_purge_jobs = [
+ {
+ "interval": self.parse_duration("1d"),
+ "shortest_max_lifetime": None,
+ "longest_max_lifetime": None,
+ }
+ ]
+
self.listeners = [] # type: List[dict]
for listener in config.get("listeners", []):
if not isinstance(listener.get("port", None), int):
@@ -761,6 +879,69 @@ class ServerConfig(Config):
# Defaults to `28d`. Set to `null` to disable clearing out of old rows.
#
#user_ips_max_age: 14d
+
+ # Message retention policy at the server level.
+ #
+ # Room admins and mods can define a retention period for their rooms using the
+ # 'm.room.retention' state event, and server admins can cap this period by setting
+ # the 'allowed_lifetime_min' and 'allowed_lifetime_max' config options.
+ #
+ # If this feature is enabled, Synapse will regularly look for and purge events
+ # which are older than the room's maximum retention period. Synapse will also
+ # filter events received over federation so that events that should have been
+ # purged are ignored and not stored again.
+ #
+ retention:
+ # The message retention policies feature is disabled by default. Uncomment the
+ # following line to enable it.
+ #
+ #enabled: true
+
+ # Default retention policy. If set, Synapse will apply it to rooms that lack the
+ # 'm.room.retention' state event. Currently, the value of 'min_lifetime' doesn't
+ # matter much because Synapse doesn't take it into account yet.
+ #
+ #default_policy:
+ # min_lifetime: 1d
+ # max_lifetime: 1y
+
+ # Retention policy limits. If set, a user won't be able to send a
+ # 'm.room.retention' event which features a 'min_lifetime' or a 'max_lifetime'
+ # that's not within this range. This is especially useful in closed federations,
+ # in which server admins can make sure every federating server applies the same
+ # rules.
+ #
+ #allowed_lifetime_min: 1d
+ #allowed_lifetime_max: 1y
+
+ # Server admins can define the settings of the background jobs purging the
+ # events which lifetime has expired under the 'purge_jobs' section.
+ #
+ # If no configuration is provided, a single job will be set up to delete expired
+ # events in every room daily.
+ #
+ # Each job's configuration defines which range of message lifetimes the job
+ # takes care of. For example, if 'shortest_max_lifetime' is '2d' and
+ # 'longest_max_lifetime' is '3d', the job will handle purging expired events in
+ # rooms whose state defines a 'max_lifetime' that's both higher than 2 days, and
+ # lower than or equal to 3 days. Both the minimum and the maximum value of a
+ # range are optional, e.g. a job with no 'shortest_max_lifetime' and a
+ # 'longest_max_lifetime' of '3d' will handle every room with a retention policy
+ # which 'max_lifetime' is lower than or equal to three days.
+ #
+ # The rationale for this per-job configuration is that some rooms might have a
+ # retention policy with a low 'max_lifetime', where history needs to be purged
+ # of outdated messages on a very frequent basis (e.g. every 5min), but not want
+ # that purge to be performed by a job that's iterating over every room it knows,
+ # which would be quite heavy on the server.
+ #
+ #purge_jobs:
+ # - shortest_max_lifetime: 1d
+ # longest_max_lifetime: 3d
+ # interval: 5m:
+ # - shortest_max_lifetime: 3d
+ # longest_max_lifetime: 1y
+ # interval: 24h
"""
% locals()
)
@@ -787,14 +968,14 @@ class ServerConfig(Config):
"--print-pidfile",
action="store_true",
default=None,
- help="Print the path to the pidfile just" " before daemonizing",
+ help="Print the path to the pidfile just before daemonizing",
)
server_group.add_argument(
"--manhole",
metavar="PORT",
dest="manhole",
type=int,
- help="Turn on the twisted telnet manhole" " service on the given port.",
+ help="Turn on the twisted telnet manhole service on the given port.",
)
diff --git a/synapse/events/validator.py b/synapse/events/validator.py
index 272426e105..9b90c9ce04 100644
--- a/synapse/events/validator.py
+++ b/synapse/events/validator.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from six import string_types
+from six import integer_types, string_types
from synapse.api.constants import MAX_ALIAS_LENGTH, EventTypes, Membership
from synapse.api.errors import Codes, SynapseError
@@ -22,11 +22,12 @@ from synapse.types import EventID, RoomID, UserID
class EventValidator(object):
- def validate_new(self, event):
+ def validate_new(self, event, config):
"""Validates the event has roughly the right format
Args:
- event (FrozenEvent)
+ event (FrozenEvent): The event to validate.
+ config (Config): The homeserver's configuration.
"""
self.validate_builder(event)
@@ -67,6 +68,99 @@ class EventValidator(object):
Codes.INVALID_PARAM,
)
+ if event.type == EventTypes.Retention:
+ self._validate_retention(event, config)
+
+ def _validate_retention(self, event, config):
+ """Checks that an event that defines the retention policy for a room respects the
+ boundaries imposed by the server's administrator.
+
+ Args:
+ event (FrozenEvent): The event to validate.
+ config (Config): The homeserver's configuration.
+ """
+ min_lifetime = event.content.get("min_lifetime")
+ max_lifetime = event.content.get("max_lifetime")
+
+ if min_lifetime is not None:
+ if not isinstance(min_lifetime, integer_types):
+ raise SynapseError(
+ code=400,
+ msg="'min_lifetime' must be an integer",
+ errcode=Codes.BAD_JSON,
+ )
+
+ if (
+ config.retention_allowed_lifetime_min is not None
+ and min_lifetime < config.retention_allowed_lifetime_min
+ ):
+ raise SynapseError(
+ code=400,
+ msg=(
+ "'min_lifetime' can't be lower than the minimum allowed"
+ " value enforced by the server's administrator"
+ ),
+ errcode=Codes.BAD_JSON,
+ )
+
+ if (
+ config.retention_allowed_lifetime_max is not None
+ and min_lifetime > config.retention_allowed_lifetime_max
+ ):
+ raise SynapseError(
+ code=400,
+ msg=(
+ "'min_lifetime' can't be greater than the maximum allowed"
+ " value enforced by the server's administrator"
+ ),
+ errcode=Codes.BAD_JSON,
+ )
+
+ if max_lifetime is not None:
+ if not isinstance(max_lifetime, integer_types):
+ raise SynapseError(
+ code=400,
+ msg="'max_lifetime' must be an integer",
+ errcode=Codes.BAD_JSON,
+ )
+
+ if (
+ config.retention_allowed_lifetime_min is not None
+ and max_lifetime < config.retention_allowed_lifetime_min
+ ):
+ raise SynapseError(
+ code=400,
+ msg=(
+ "'max_lifetime' can't be lower than the minimum allowed value"
+ " enforced by the server's administrator"
+ ),
+ errcode=Codes.BAD_JSON,
+ )
+
+ if (
+ config.retention_allowed_lifetime_max is not None
+ and max_lifetime > config.retention_allowed_lifetime_max
+ ):
+ raise SynapseError(
+ code=400,
+ msg=(
+ "'max_lifetime' can't be greater than the maximum allowed"
+ " value enforced by the server's administrator"
+ ),
+ errcode=Codes.BAD_JSON,
+ )
+
+ if (
+ min_lifetime is not None
+ and max_lifetime is not None
+ and min_lifetime > max_lifetime
+ ):
+ raise SynapseError(
+ code=400,
+ msg="'min_lifetime' can't be greater than 'max_lifetime",
+ errcode=Codes.BAD_JSON,
+ )
+
def validate_builder(self, event):
"""Validates that the builder/event has roughly the right format. Only
checks values that we expect a proto event to have, rather than all the
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index d942d77a72..84d4eca041 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
+# Copyright 2019 Matrix.org Federation C.I.C
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -73,6 +74,7 @@ class FederationServer(FederationBase):
self.auth = hs.get_auth()
self.handler = hs.get_handlers().federation_handler
+ self.state = hs.get_state_handler()
self._server_linearizer = Linearizer("fed_server")
self._transaction_linearizer = Linearizer("fed_txn_handler")
@@ -264,9 +266,6 @@ class FederationServer(FederationBase):
await self.registry.on_edu(edu_type, origin, content)
async def on_context_state_request(self, origin, room_id, event_id):
- if not event_id:
- raise NotImplementedError("Specify an event")
-
origin_host, _ = parse_server_name(origin)
await self.check_server_matches_acl(origin_host, room_id)
@@ -280,13 +279,18 @@ class FederationServer(FederationBase):
# - but that's non-trivial to get right, and anyway somewhat defeats
# the point of the linearizer.
with (await self._server_linearizer.queue((origin, room_id))):
- resp = await self._state_resp_cache.wrap(
- (room_id, event_id),
- self._on_context_state_request_compute,
- room_id,
- event_id,
+ resp = dict(
+ await self._state_resp_cache.wrap(
+ (room_id, event_id),
+ self._on_context_state_request_compute,
+ room_id,
+ event_id,
+ )
)
+ room_version = await self.store.get_room_version(room_id)
+ resp["room_version"] = room_version
+
return 200, resp
async def on_state_ids_request(self, origin, room_id, event_id):
@@ -306,7 +310,11 @@ class FederationServer(FederationBase):
return 200, {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids}
async def _on_context_state_request_compute(self, room_id, event_id):
- pdus = await self.handler.get_state_for_pdu(room_id, event_id)
+ if event_id:
+ pdus = await self.handler.get_state_for_pdu(room_id, event_id)
+ else:
+ pdus = (await self.state.get_current_state(room_id)).values()
+
auth_chain = await self.store.get_auth_chain([pdu.event_id for pdu in pdus])
return {
diff --git a/synapse/federation/persistence.py b/synapse/federation/persistence.py
index 44edcabed4..d68b4bd670 100644
--- a/synapse/federation/persistence.py
+++ b/synapse/federation/persistence.py
@@ -44,7 +44,7 @@ class TransactionActions(object):
response code and response body.
"""
if not transaction.transaction_id:
- raise RuntimeError("Cannot persist a transaction with no " "transaction_id")
+ raise RuntimeError("Cannot persist a transaction with no transaction_id")
return self.store.get_received_txn_response(transaction.transaction_id, origin)
@@ -56,7 +56,7 @@ class TransactionActions(object):
Deferred
"""
if not transaction.transaction_id:
- raise RuntimeError("Cannot persist a transaction with no " "transaction_id")
+ raise RuntimeError("Cannot persist a transaction with no transaction_id")
return self.store.set_received_txn_response(
transaction.transaction_id, origin, code, response
diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index 2b2ee8612a..4ebb0e8bc0 100644
--- a/synapse/federation/sender/__init__.py
+++ b/synapse/federation/sender/__init__.py
@@ -49,7 +49,7 @@ sent_pdus_destination_dist_count = Counter(
sent_pdus_destination_dist_total = Counter(
"synapse_federation_client_sent_pdu_destinations:total",
- "" "Total number of PDUs queued for sending across all destinations",
+ "Total number of PDUs queued for sending across all destinations",
)
diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py
index 67b3e1ab6e..5fed626d5b 100644
--- a/synapse/federation/sender/transaction_manager.py
+++ b/synapse/federation/sender/transaction_manager.py
@@ -84,7 +84,7 @@ class TransactionManager(object):
txn_id = str(self._next_txn_id)
logger.debug(
- "TX [%s] {%s} Attempting new transaction" " (pdus: %d, edus: %d)",
+ "TX [%s] {%s} Attempting new transaction (pdus: %d, edus: %d)",
destination,
txn_id,
len(pdus),
@@ -103,7 +103,7 @@ class TransactionManager(object):
self._next_txn_id += 1
logger.info(
- "TX [%s] {%s} Sending transaction [%s]," " (PDUs: %d, EDUs: %d)",
+ "TX [%s] {%s} Sending transaction [%s], (PDUs: %d, EDUs: %d)",
destination,
txn_id,
transaction.transaction_id,
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 09baa9c57d..fefc789c85 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -421,7 +421,7 @@ class FederationEventServlet(BaseFederationServlet):
return await self.handler.on_pdu_request(origin, event_id)
-class FederationStateServlet(BaseFederationServlet):
+class FederationStateV1Servlet(BaseFederationServlet):
PATH = "/state/(?P<context>[^/]*)/?"
# This is when someone asks for all data for a given context.
@@ -429,7 +429,7 @@ class FederationStateServlet(BaseFederationServlet):
return await self.handler.on_context_state_request(
origin,
context,
- parse_string_from_args(query, "event_id", None, required=True),
+ parse_string_from_args(query, "event_id", None, required=False),
)
@@ -1360,7 +1360,7 @@ class RoomComplexityServlet(BaseFederationServlet):
FEDERATION_SERVLET_CLASSES = (
FederationSendServlet,
FederationEventServlet,
- FederationStateServlet,
+ FederationStateV1Servlet,
FederationStateIdsServlet,
FederationBackfillServlet,
FederationQueryServlet,
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index 69051101a6..a07d2f1a17 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -119,7 +119,7 @@ class DirectoryHandler(BaseHandler):
if not service.is_interested_in_alias(room_alias.to_string()):
raise SynapseError(
400,
- "This application service has not reserved" " this kind of alias.",
+ "This application service has not reserved this kind of alias.",
errcode=Codes.EXCLUSIVE,
)
else:
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index f09a0b73c8..28c12753c1 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -30,6 +30,7 @@ from twisted.internet import defer
from synapse.api.errors import CodeMessageException, Codes, NotFoundError, SynapseError
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.logging.opentracing import log_kv, set_tag, tag_args, trace
+from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet
from synapse.types import (
UserID,
get_domain_from_id,
@@ -53,6 +54,12 @@ class E2eKeysHandler(object):
self._edu_updater = SigningKeyEduUpdater(hs, self)
+ self._is_master = hs.config.worker_app is None
+ if not self._is_master:
+ self._user_device_resync_client = ReplicationUserDevicesResyncRestServlet.make_client(
+ hs
+ )
+
federation_registry = hs.get_federation_registry()
# FIXME: switch to m.signing_key_update when MSC1756 is merged into the spec
@@ -191,9 +198,15 @@ class E2eKeysHandler(object):
# probably be tracking their device lists. However, we haven't
# done an initial sync on the device list so we do it now.
try:
- user_devices = yield self.device_handler.device_list_updater.user_device_resync(
- user_id
- )
+ if self._is_master:
+ user_devices = yield self.device_handler.device_list_updater.user_device_resync(
+ user_id
+ )
+ else:
+ user_devices = yield self._user_device_resync_client(
+ user_id=user_id
+ )
+
user_devices = user_devices["devices"]
for device in user_devices:
results[user_id] = {device["device_id"]: device["keys"]}
diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py
index 0cea445f0d..f1b4424a02 100644
--- a/synapse/handlers/e2e_room_keys.py
+++ b/synapse/handlers/e2e_room_keys.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2017, 2018 New Vector Ltd
+# Copyright 2019 Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -103,14 +104,35 @@ class E2eRoomKeysHandler(object):
rooms
session_id(string): session ID to delete keys for, for None to delete keys
for all sessions
+ Raises:
+ NotFoundError: if the backup version does not exist
Returns:
- A deferred of the deletion transaction
+ A dict containing the count and etag for the backup version
"""
# lock for consistency with uploading
with (yield self._upload_linearizer.queue(user_id)):
+ # make sure the backup version exists
+ try:
+ version_info = yield self.store.get_e2e_room_keys_version_info(
+ user_id, version
+ )
+ except StoreError as e:
+ if e.code == 404:
+ raise NotFoundError("Unknown backup version")
+ else:
+ raise
+
yield self.store.delete_e2e_room_keys(user_id, version, room_id, session_id)
+ version_etag = version_info["etag"] + 1
+ yield self.store.update_e2e_room_keys_version(
+ user_id, version, None, version_etag
+ )
+
+ count = yield self.store.count_e2e_room_keys(user_id, version)
+ return {"etag": str(version_etag), "count": count}
+
@trace
@defer.inlineCallbacks
def upload_room_keys(self, user_id, version, room_keys):
@@ -138,6 +160,9 @@ class E2eRoomKeysHandler(object):
}
}
+ Returns:
+ A dict containing the count and etag for the backup version
+
Raises:
NotFoundError: if there are no versions defined
RoomKeysVersionError: if the uploaded version is not the current version
@@ -171,59 +196,62 @@ class E2eRoomKeysHandler(object):
else:
raise
- # go through the room_keys.
- # XXX: this should/could be done concurrently, given we're in a lock.
+ # Fetch any existing room keys for the sessions that have been
+ # submitted. Then compare them with the submitted keys. If the
+ # key is new, insert it; if the key should be updated, then update
+ # it; otherwise, drop it.
+ existing_keys = yield self.store.get_e2e_room_keys_multi(
+ user_id, version, room_keys["rooms"]
+ )
+ to_insert = [] # batch the inserts together
+ changed = False # if anything has changed, we need to update the etag
for room_id, room in iteritems(room_keys["rooms"]):
- for session_id, session in iteritems(room["sessions"]):
- yield self._upload_room_key(
- user_id, version, room_id, session_id, session
+ for session_id, room_key in iteritems(room["sessions"]):
+ log_kv(
+ {
+ "message": "Trying to upload room key",
+ "room_id": room_id,
+ "session_id": session_id,
+ "user_id": user_id,
+ }
)
-
- @defer.inlineCallbacks
- def _upload_room_key(self, user_id, version, room_id, session_id, room_key):
- """Upload a given room_key for a given room and session into a given
- version of the backup. Merges the key with any which might already exist.
-
- Args:
- user_id(str): the user whose backup we're setting
- version(str): the version ID of the backup we're updating
- room_id(str): the ID of the room whose keys we're setting
- session_id(str): the session whose room_key we're setting
- room_key(dict): the room_key being set
- """
- log_kv(
- {
- "message": "Trying to upload room key",
- "room_id": room_id,
- "session_id": session_id,
- "user_id": user_id,
- }
- )
- # get the room_key for this particular row
- current_room_key = None
- try:
- current_room_key = yield self.store.get_e2e_room_key(
- user_id, version, room_id, session_id
- )
- except StoreError as e:
- if e.code == 404:
- log_kv(
- {
- "message": "Room key not found.",
- "room_id": room_id,
- "user_id": user_id,
- }
+ current_room_key = existing_keys.get(room_id, {}).get(session_id)
+ if current_room_key:
+ if self._should_replace_room_key(current_room_key, room_key):
+ log_kv({"message": "Replacing room key."})
+ # updates are done one at a time in the DB, so send
+ # updates right away rather than batching them up,
+ # like we do with the inserts
+ yield self.store.update_e2e_room_key(
+ user_id, version, room_id, session_id, room_key
+ )
+ changed = True
+ else:
+ log_kv({"message": "Not replacing room_key."})
+ else:
+ log_kv(
+ {
+ "message": "Room key not found.",
+ "room_id": room_id,
+ "user_id": user_id,
+ }
+ )
+ log_kv({"message": "Replacing room key."})
+ to_insert.append((room_id, session_id, room_key))
+ changed = True
+
+ if len(to_insert):
+ yield self.store.add_e2e_room_keys(user_id, version, to_insert)
+
+ version_etag = version_info["etag"]
+ if changed:
+ version_etag = version_etag + 1
+ yield self.store.update_e2e_room_keys_version(
+ user_id, version, None, version_etag
)
- else:
- raise
- if self._should_replace_room_key(current_room_key, room_key):
- log_kv({"message": "Replacing room key."})
- yield self.store.set_e2e_room_key(
- user_id, version, room_id, session_id, room_key
- )
- else:
- log_kv({"message": "Not replacing room_key."})
+ count = yield self.store.count_e2e_room_keys(user_id, version)
+ return {"etag": str(version_etag), "count": count}
@staticmethod
def _should_replace_room_key(current_room_key, room_key):
@@ -314,6 +342,8 @@ class E2eRoomKeysHandler(object):
raise NotFoundError("Unknown backup version")
else:
raise
+
+ res["count"] = yield self.store.count_e2e_room_keys(user_id, res["version"])
return res
@trace
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 0e904f2da0..a5ae7b77d1 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -2040,8 +2040,10 @@ class FederationHandler(BaseHandler):
auth_events (dict[(str, str)->synapse.events.EventBase]):
Map from (event_type, state_key) to event
- What we expect the event's auth_events to be, based on the event's
- position in the dag. I think? maybe??
+ Normally, our calculated auth_events based on the state of the room
+ at the event's position in the DAG, though occasionally (eg if the
+ event is an outlier), may be the auth events claimed by the remote
+ server.
Also NB that this function adds entries to it.
Returns:
@@ -2091,30 +2093,35 @@ class FederationHandler(BaseHandler):
origin (str):
event (synapse.events.EventBase):
context (synapse.events.snapshot.EventContext):
+
auth_events (dict[(str, str)->synapse.events.EventBase]):
+ Map from (event_type, state_key) to event
+
+ Normally, our calculated auth_events based on the state of the room
+ at the event's position in the DAG, though occasionally (eg if the
+ event is an outlier), may be the auth events claimed by the remote
+ server.
+
+ Also NB that this function adds entries to it.
Returns:
defer.Deferred[EventContext]: updated context
"""
event_auth_events = set(event.auth_event_ids())
- if event.is_state():
- event_key = (event.type, event.state_key)
- else:
- event_key = None
-
- # if the event's auth_events refers to events which are not in our
- # calculated auth_events, we need to fetch those events from somewhere.
- #
- # we start by fetching them from the store, and then try calling /event_auth/.
+ # missing_auth is the set of the event's auth_events which we don't yet have
+ # in auth_events.
missing_auth = event_auth_events.difference(
e.event_id for e in auth_events.values()
)
+ # if we have missing events, we need to fetch those events from somewhere.
+ #
+ # we start by checking if they are in the store, and then try calling /event_auth/.
if missing_auth:
# TODO: can we use store.have_seen_events here instead?
have_events = yield self.store.get_seen_events_with_rejections(missing_auth)
- logger.debug("Got events %s from store", have_events)
+ logger.debug("Found events %s in the store", have_events)
missing_auth.difference_update(have_events.keys())
else:
have_events = {}
@@ -2169,15 +2176,17 @@ class FederationHandler(BaseHandler):
event.auth_event_ids()
)
except Exception:
- # FIXME:
logger.exception("Failed to get auth chain")
if event.internal_metadata.is_outlier():
+ # XXX: given that, for an outlier, we'll be working with the
+ # event's *claimed* auth events rather than those we calculated:
+ # (a) is there any point in this test, since different_auth below will
+ # obviously be empty
+ # (b) alternatively, why don't we do it earlier?
logger.info("Skipping auth_event fetch for outlier")
return context
- # FIXME: Assumes we have and stored all the state for all the
- # prev_events
different_auth = event_auth_events.difference(
e.event_id for e in auth_events.values()
)
@@ -2191,27 +2200,22 @@ class FederationHandler(BaseHandler):
different_auth,
)
+ # now we state-resolve between our own idea of the auth events, and the remote's
+ # idea of them.
+
room_version = yield self.store.get_room_version(event.room_id)
+ different_event_ids = [
+ d for d in different_auth if d in have_events and not have_events[d]
+ ]
- different_events = yield make_deferred_yieldable(
- defer.gatherResults(
- [
- run_in_background(
- self.store.get_event, d, allow_none=True, allow_rejected=False
- )
- for d in different_auth
- if d in have_events and not have_events[d]
- ],
- consumeErrors=True,
- )
- ).addErrback(unwrapFirstError)
+ if different_event_ids:
+ # XXX: currently this checks for redactions but I'm not convinced that is
+ # necessary?
+ different_events = yield self.store.get_events_as_list(different_event_ids)
- if different_events:
local_view = dict(auth_events)
remote_view = dict(auth_events)
- remote_view.update(
- {(d.type, d.state_key): d for d in different_events if d}
- )
+ remote_view.update({(d.type, d.state_key): d for d in different_events})
new_state = yield self.state_handler.resolve_events(
room_version,
@@ -2231,13 +2235,13 @@ class FederationHandler(BaseHandler):
auth_events.update(new_state)
context = yield self._update_context_for_auth_events(
- event, context, auth_events, event_key
+ event, context, auth_events
)
return context
@defer.inlineCallbacks
- def _update_context_for_auth_events(self, event, context, auth_events, event_key):
+ def _update_context_for_auth_events(self, event, context, auth_events):
"""Update the state_ids in an event context after auth event resolution,
storing the changes as a new state group.
@@ -2246,18 +2250,21 @@ class FederationHandler(BaseHandler):
context (synapse.events.snapshot.EventContext): initial event context
- auth_events (dict[(str, str)->str]): Events to update in the event
+ auth_events (dict[(str, str)->EventBase]): Events to update in the event
context.
- event_key ((str, str)): (type, state_key) for the current event.
- this will not be included in the current_state in the context.
-
Returns:
Deferred[EventContext]: new event context
"""
+ # exclude the state key of the new event from the current_state in the context.
+ if event.is_state():
+ event_key = (event.type, event.state_key)
+ else:
+ event_key = None
state_updates = {
k: a.event_id for k, a in iteritems(auth_events) if k != event_key
}
+
current_state_ids = yield context.get_current_state_ids(self.store)
current_state_ids = dict(current_state_ids)
@@ -2459,7 +2466,7 @@ class FederationHandler(BaseHandler):
room_version, event_dict, event, context
)
- EventValidator().validate_new(event)
+ EventValidator().validate_new(event, self.config)
# We need to tell the transaction queue to send this out, even
# though the sender isn't a local user.
@@ -2574,7 +2581,7 @@ class FederationHandler(BaseHandler):
event, context = yield self.event_creation_handler.create_new_client_event(
builder=builder
)
- EventValidator().validate_new(event)
+ EventValidator().validate_new(event, self.config)
return (event, context)
@defer.inlineCallbacks
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index d682dc2b7a..155ed6e06a 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -417,7 +417,7 @@ class EventCreationHandler(object):
403, "You must be in the room to create an alias for it"
)
- self.validator.validate_new(event)
+ self.validator.validate_new(event, self.config)
return (event, context)
@@ -634,7 +634,7 @@ class EventCreationHandler(object):
if requester:
context.app_service = requester.app_service
- self.validator.validate_new(event)
+ self.validator.validate_new(event, self.config)
# If this event is an annotation then we check that that the sender
# can't annotate the same way twice (e.g. stops users from liking an
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 260a4351ca..8514ddc600 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -15,12 +15,15 @@
# limitations under the License.
import logging
+from six import iteritems
+
from twisted.internet import defer
from twisted.python.failure import Failure
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import SynapseError
from synapse.logging.context import run_in_background
+from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.state import StateFilter
from synapse.types import RoomStreamToken
from synapse.util.async_helpers import ReadWriteLock
@@ -80,6 +83,109 @@ class PaginationHandler(object):
self._purges_by_id = {}
self._event_serializer = hs.get_event_client_serializer()
+ self._retention_default_max_lifetime = hs.config.retention_default_max_lifetime
+
+ if hs.config.retention_enabled:
+ # Run the purge jobs described in the configuration file.
+ for job in hs.config.retention_purge_jobs:
+ self.clock.looping_call(
+ run_as_background_process,
+ job["interval"],
+ "purge_history_for_rooms_in_range",
+ self.purge_history_for_rooms_in_range,
+ job["shortest_max_lifetime"],
+ job["longest_max_lifetime"],
+ )
+
+ @defer.inlineCallbacks
+ def purge_history_for_rooms_in_range(self, min_ms, max_ms):
+ """Purge outdated events from rooms within the given retention range.
+
+ If a default retention policy is defined in the server's configuration and its
+ 'max_lifetime' is within this range, also targets rooms which don't have a
+ retention policy.
+
+ Args:
+ min_ms (int|None): Duration in milliseconds that define the lower limit of
+ the range to handle (exclusive). If None, it means that the range has no
+ lower limit.
+ max_ms (int|None): Duration in milliseconds that define the upper limit of
+ the range to handle (inclusive). If None, it means that the range has no
+ upper limit.
+ """
+ # We want the storage layer to to include rooms with no retention policy in its
+ # return value only if a default retention policy is defined in the server's
+ # configuration and that policy's 'max_lifetime' is either lower (or equal) than
+ # max_ms or higher than min_ms (or both).
+ if self._retention_default_max_lifetime is not None:
+ include_null = True
+
+ if min_ms is not None and min_ms >= self._retention_default_max_lifetime:
+ # The default max_lifetime is lower than (or equal to) min_ms.
+ include_null = False
+
+ if max_ms is not None and max_ms < self._retention_default_max_lifetime:
+ # The default max_lifetime is higher than max_ms.
+ include_null = False
+ else:
+ include_null = False
+
+ rooms = yield self.store.get_rooms_for_retention_period_in_range(
+ min_ms, max_ms, include_null
+ )
+
+ for room_id, retention_policy in iteritems(rooms):
+ if room_id in self._purges_in_progress_by_room:
+ logger.warning(
+ "[purge] not purging room %s as there's an ongoing purge running"
+ " for this room",
+ room_id,
+ )
+ continue
+
+ max_lifetime = retention_policy["max_lifetime"]
+
+ if max_lifetime is None:
+ # If max_lifetime is None, it means that include_null equals True,
+ # therefore we can safely assume that there is a default policy defined
+ # in the server's configuration.
+ max_lifetime = self._retention_default_max_lifetime
+
+ # Figure out what token we should start purging at.
+ ts = self.clock.time_msec() - max_lifetime
+
+ stream_ordering = yield self.store.find_first_stream_ordering_after_ts(ts)
+
+ r = yield self.store.get_room_event_after_stream_ordering(
+ room_id, stream_ordering,
+ )
+ if not r:
+ logger.warning(
+ "[purge] purging events not possible: No event found "
+ "(ts %i => stream_ordering %i)",
+ ts,
+ stream_ordering,
+ )
+ continue
+
+ (stream, topo, _event_id) = r
+ token = "t%d-%d" % (topo, stream)
+
+ purge_id = random_string(16)
+
+ self._purges_by_id[purge_id] = PurgeStatus()
+
+ logger.info(
+ "Starting purging events in room %s (purge_id %s)" % (room_id, purge_id)
+ )
+
+ # We want to purge everything, including local events, and to run the purge in
+ # the background so that it's not blocking any other operation apart from
+ # other purges in the same room.
+ run_as_background_process(
+ "_purge_history", self._purge_history, purge_id, room_id, token, True,
+ )
+
def start_purge_history(self, room_id, token, delete_local_events=False):
"""Start off a history purge on a room.
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index e9a5e46ced..13fcb408a6 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -96,7 +96,7 @@ def parse_boolean_from_args(args, name, default=None, required=False):
return {b"true": True, b"false": False}[args[name][0]]
except Exception:
message = (
- "Boolean query parameter %r must be one of" " ['true', 'false']"
+ "Boolean query parameter %r must be one of ['true', 'false']"
) % (name,)
raise SynapseError(400, message)
else:
diff --git a/synapse/logging/_structured.py b/synapse/logging/_structured.py
index 334ddaf39a..ffa7b20ca8 100644
--- a/synapse/logging/_structured.py
+++ b/synapse/logging/_structured.py
@@ -261,6 +261,18 @@ def parse_drain_configs(
)
+class StoppableLogPublisher(LogPublisher):
+ """
+ A log publisher that can tell its observers to shut down any external
+ communications.
+ """
+
+ def stop(self):
+ for obs in self._observers:
+ if hasattr(obs, "stop"):
+ obs.stop()
+
+
def setup_structured_logging(
hs,
config,
@@ -336,7 +348,7 @@ def setup_structured_logging(
# We should never get here, but, just in case, throw an error.
raise ConfigError("%s drain type cannot be configured" % (observer.type,))
- publisher = LogPublisher(*observers)
+ publisher = StoppableLogPublisher(*observers)
log_filter = LogLevelFilterPredicate()
for namespace, namespace_config in log_config.get(
diff --git a/synapse/logging/_terse_json.py b/synapse/logging/_terse_json.py
index 76ce7d8808..05fc64f409 100644
--- a/synapse/logging/_terse_json.py
+++ b/synapse/logging/_terse_json.py
@@ -17,25 +17,29 @@
Log formatters that output terse JSON.
"""
+import json
import sys
+import traceback
from collections import deque
from ipaddress import IPv4Address, IPv6Address, ip_address
from math import floor
-from typing import IO
+from typing import IO, Optional
import attr
-from simplejson import dumps
from zope.interface import implementer
from twisted.application.internet import ClientService
+from twisted.internet.defer import Deferred
from twisted.internet.endpoints import (
HostnameEndpoint,
TCP4ClientEndpoint,
TCP6ClientEndpoint,
)
+from twisted.internet.interfaces import IPushProducer, ITransport
from twisted.internet.protocol import Factory, Protocol
from twisted.logger import FileLogObserver, ILogObserver, Logger
-from twisted.python.failure import Failure
+
+_encoder = json.JSONEncoder(ensure_ascii=False, separators=(",", ":"))
def flatten_event(event: dict, metadata: dict, include_time: bool = False):
@@ -141,12 +145,50 @@ def TerseJSONToConsoleLogObserver(outFile: IO[str], metadata: dict) -> FileLogOb
def formatEvent(_event: dict) -> str:
flattened = flatten_event(_event, metadata)
- return dumps(flattened, ensure_ascii=False, separators=(",", ":")) + "\n"
+ return _encoder.encode(flattened) + "\n"
return FileLogObserver(outFile, formatEvent)
@attr.s
+@implementer(IPushProducer)
+class LogProducer(object):
+ """
+ An IPushProducer that writes logs from its buffer to its transport when it
+ is resumed.
+
+ Args:
+ buffer: Log buffer to read logs from.
+ transport: Transport to write to.
+ """
+
+ transport = attr.ib(type=ITransport)
+ _buffer = attr.ib(type=deque)
+ _paused = attr.ib(default=False, type=bool, init=False)
+
+ def pauseProducing(self):
+ self._paused = True
+
+ def stopProducing(self):
+ self._paused = True
+ self._buffer = None
+
+ def resumeProducing(self):
+ self._paused = False
+
+ while self._paused is False and (self._buffer and self.transport.connected):
+ try:
+ event = self._buffer.popleft()
+ self.transport.write(_encoder.encode(event).encode("utf8"))
+ self.transport.write(b"\n")
+ except Exception:
+ # Something has gone wrong writing to the transport -- log it
+ # and break out of the while.
+ traceback.print_exc(file=sys.__stderr__)
+ break
+
+
+@attr.s
@implementer(ILogObserver)
class TerseJSONToTCPLogObserver(object):
"""
@@ -165,8 +207,9 @@ class TerseJSONToTCPLogObserver(object):
metadata = attr.ib(type=dict)
maximum_buffer = attr.ib(type=int)
_buffer = attr.ib(default=attr.Factory(deque), type=deque)
- _writer = attr.ib(default=None)
+ _connection_waiter = attr.ib(default=None, type=Optional[Deferred])
_logger = attr.ib(default=attr.Factory(Logger))
+ _producer = attr.ib(default=None, type=Optional[LogProducer])
def start(self) -> None:
@@ -187,38 +230,43 @@ class TerseJSONToTCPLogObserver(object):
factory = Factory.forProtocol(Protocol)
self._service = ClientService(endpoint, factory, clock=self.hs.get_reactor())
self._service.startService()
+ self._connect()
- def _write_loop(self) -> None:
+ def stop(self):
+ self._service.stopService()
+
+ def _connect(self) -> None:
"""
- Implement the write loop.
+ Triggers an attempt to connect then write to the remote if not already writing.
"""
- if self._writer:
+ if self._connection_waiter:
return
- self._writer = self._service.whenConnected()
+ self._connection_waiter = self._service.whenConnected(failAfterFailures=1)
+
+ @self._connection_waiter.addErrback
+ def fail(r):
+ r.printTraceback(file=sys.__stderr__)
+ self._connection_waiter = None
+ self._connect()
- @self._writer.addBoth
+ @self._connection_waiter.addCallback
def writer(r):
- if isinstance(r, Failure):
- r.printTraceback(file=sys.__stderr__)
- self._writer = None
- self.hs.get_reactor().callLater(1, self._write_loop)
+ # We have a connection. If we already have a producer, and its
+ # transport is the same, just trigger a resumeProducing.
+ if self._producer and r.transport is self._producer.transport:
+ self._producer.resumeProducing()
return
- try:
- for event in self._buffer:
- r.transport.write(
- dumps(event, ensure_ascii=False, separators=(",", ":")).encode(
- "utf8"
- )
- )
- r.transport.write(b"\n")
- self._buffer.clear()
- except Exception as e:
- sys.__stderr__.write("Failed writing out logs with %s\n" % (str(e),))
-
- self._writer = False
- self.hs.get_reactor().callLater(1, self._write_loop)
+ # If the producer is still producing, stop it.
+ if self._producer:
+ self._producer.stopProducing()
+
+ # Make a new producer and start it.
+ self._producer = LogProducer(buffer=self._buffer, transport=r.transport)
+ r.transport.registerProducer(self._producer, True)
+ self._producer.resumeProducing()
+ self._connection_waiter = None
def _handle_pressure(self) -> None:
"""
@@ -277,4 +325,4 @@ class TerseJSONToTCPLogObserver(object):
self._logger.failure("Failed clearing backpressure")
# Try and write immediately.
- self._write_loop()
+ self._connect()
diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index e994037be6..d0879b0490 100644
--- a/synapse/push/httppusher.py
+++ b/synapse/push/httppusher.py
@@ -246,7 +246,7 @@ class HttpPusher(object):
# fixed, we don't suddenly deliver a load
# of old notifications.
logger.warning(
- "Giving up on a notification to user %s, " "pushkey %s",
+ "Giving up on a notification to user %s, pushkey %s",
self.user_id,
self.pushkey,
)
@@ -299,8 +299,7 @@ class HttpPusher(object):
# for sanity, we only remove the pushkey if it
# was the one we actually sent...
logger.warning(
- ("Ignoring rejected pushkey %s because we" " didn't send it"),
- pk,
+ ("Ignoring rejected pushkey %s because we didn't send it"), pk,
)
else:
logger.info("Pushkey %s was rejected: removing", pk)
diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index 1d15a06a58..b13b646bfd 100644
--- a/synapse/push/mailer.py
+++ b/synapse/push/mailer.py
@@ -43,7 +43,7 @@ logger = logging.getLogger(__name__)
MESSAGE_FROM_PERSON_IN_ROOM = (
- "You have a message on %(app)s from %(person)s " "in the %(room)s room..."
+ "You have a message on %(app)s from %(person)s in the %(room)s room..."
)
MESSAGE_FROM_PERSON = "You have a message on %(app)s from %(person)s..."
MESSAGES_FROM_PERSON = "You have messages on %(app)s from %(person)s..."
@@ -55,7 +55,7 @@ MESSAGES_FROM_PERSON_AND_OTHERS = (
"You have messages on %(app)s from %(person)s and others..."
)
INVITE_FROM_PERSON_TO_ROOM = (
- "%(person)s has invited you to join the " "%(room)s room on %(app)s..."
+ "%(person)s has invited you to join the %(room)s room on %(app)s..."
)
INVITE_FROM_PERSON = "%(person)s has invited you to chat on %(app)s..."
diff --git a/synapse/replication/http/__init__.py b/synapse/replication/http/__init__.py
index 81b85352b1..28dbc6fcba 100644
--- a/synapse/replication/http/__init__.py
+++ b/synapse/replication/http/__init__.py
@@ -14,7 +14,14 @@
# limitations under the License.
from synapse.http.server import JsonResource
-from synapse.replication.http import federation, login, membership, register, send_event
+from synapse.replication.http import (
+ devices,
+ federation,
+ login,
+ membership,
+ register,
+ send_event,
+)
REPLICATION_PREFIX = "/_synapse/replication"
@@ -30,3 +37,4 @@ class ReplicationRestResource(JsonResource):
federation.register_servlets(hs, self)
login.register_servlets(hs, self)
register.register_servlets(hs, self)
+ devices.register_servlets(hs, self)
diff --git a/synapse/replication/http/devices.py b/synapse/replication/http/devices.py
new file mode 100644
index 0000000000..e32aac0a25
--- /dev/null
+++ b/synapse/replication/http/devices.py
@@ -0,0 +1,73 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from synapse.replication.http._base import ReplicationEndpoint
+
+logger = logging.getLogger(__name__)
+
+
+class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint):
+ """Ask master to resync the device list for a user by contacting their
+ server.
+
+ This must happen on master so that the results can be correctly cached in
+ the database and streamed to workers.
+
+ Request format:
+
+ POST /_synapse/replication/user_device_resync/:user_id
+
+ {}
+
+ Response is equivalent to ` /_matrix/federation/v1/user/devices/:user_id`
+ response, e.g.:
+
+ {
+ "user_id": "@alice:example.org",
+ "devices": [
+ {
+ "device_id": "JLAFKJWSCS",
+ "keys": { ... },
+ "device_display_name": "Alice's Mobile Phone"
+ }
+ ]
+ }
+ """
+
+ NAME = "user_device_resync"
+ PATH_ARGS = ("user_id",)
+ CACHE = False
+
+ def __init__(self, hs):
+ super(ReplicationUserDevicesResyncRestServlet, self).__init__(hs)
+
+ self.device_list_updater = hs.get_device_handler().device_list_updater
+ self.store = hs.get_datastore()
+ self.clock = hs.get_clock()
+
+ @staticmethod
+ def _serialize_payload(user_id):
+ return {}
+
+ async def _handle_request(self, request, user_id):
+ user_devices = await self.device_list_updater.user_device_resync(user_id)
+
+ return 200, user_devices
+
+
+def register_servlets(hs, http_server):
+ ReplicationUserDevicesResyncRestServlet(hs).register(http_server)
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 9e45429d49..8512923eae 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -88,8 +88,7 @@ TagAccountDataStreamRow = namedtuple(
"TagAccountDataStreamRow", ("user_id", "room_id", "data") # str # str # dict
)
AccountDataStreamRow = namedtuple(
- "AccountDataStream",
- ("user_id", "room_id", "data_type", "data"), # str # str # str # dict
+ "AccountDataStream", ("user_id", "room_id", "data_type") # str # str # str
)
GroupsStreamRow = namedtuple(
"GroupsStreamRow",
@@ -421,8 +420,8 @@ class AccountDataStream(Stream):
results = list(room_results)
results.extend(
- (stream_id, user_id, None, account_data_type, content)
- for stream_id, user_id, account_data_type, content in global_results
+ (stream_id, user_id, None, account_data_type)
+ for stream_id, user_id, account_data_type in global_results
)
return results
diff --git a/synapse/rest/client/v2_alpha/room_keys.py b/synapse/rest/client/v2_alpha/room_keys.py
index d596786430..d83ac8e3c5 100644
--- a/synapse/rest/client/v2_alpha/room_keys.py
+++ b/synapse/rest/client/v2_alpha/room_keys.py
@@ -134,8 +134,8 @@ class RoomKeysServlet(RestServlet):
if room_id:
body = {"rooms": {room_id: body}}
- yield self.e2e_room_keys_handler.upload_room_keys(user_id, version, body)
- return 200, {}
+ ret = yield self.e2e_room_keys_handler.upload_room_keys(user_id, version, body)
+ return 200, ret
@defer.inlineCallbacks
def on_GET(self, request, room_id, session_id):
@@ -239,10 +239,10 @@ class RoomKeysServlet(RestServlet):
user_id = requester.user.to_string()
version = parse_string(request, "version")
- yield self.e2e_room_keys_handler.delete_room_keys(
+ ret = yield self.e2e_room_keys_handler.delete_room_keys(
user_id, version, room_id, session_id
)
- return 200, {}
+ return 200, ret
class RoomKeysNewVersionServlet(RestServlet):
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index 87343d9db9..fb0d02aa83 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -122,7 +122,7 @@ class PreviewUrlResource(DirectServeResource):
pattern = entry[attrib]
value = getattr(url_tuple, attrib)
logger.debug(
- "Matching attrib '%s' with value '%s' against" " pattern '%s'",
+ "Matching attrib '%s' with value '%s' against pattern '%s'",
attrib,
value,
pattern,
diff --git a/synapse/server_notices/consent_server_notices.py b/synapse/server_notices/consent_server_notices.py
index 415e9c17d8..5736c56032 100644
--- a/synapse/server_notices/consent_server_notices.py
+++ b/synapse/server_notices/consent_server_notices.py
@@ -54,7 +54,7 @@ class ConsentServerNotices(object):
)
if "body" not in self._server_notice_content:
raise ConfigError(
- "user_consent server_notice_consent must contain a 'body' " "key."
+ "user_consent server_notice_consent must contain a 'body' key."
)
self._consent_uri_builder = ConsentURIBuilder(hs.config)
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index ab596fa68d..459901ac60 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -409,16 +409,15 @@ class SQLBaseStore(object):
i = 0
N = 5
while True:
+ cursor = LoggingTransaction(
+ conn.cursor(),
+ name,
+ self.database_engine,
+ after_callbacks,
+ exception_callbacks,
+ )
try:
- txn = conn.cursor()
- txn = LoggingTransaction(
- txn,
- name,
- self.database_engine,
- after_callbacks,
- exception_callbacks,
- )
- r = func(txn, *args, **kwargs)
+ r = func(cursor, *args, **kwargs)
conn.commit()
return r
except self.database_engine.module.OperationalError as e:
@@ -456,6 +455,40 @@ class SQLBaseStore(object):
)
continue
raise
+ finally:
+ # we're either about to retry with a new cursor, or we're about to
+ # release the connection. Once we release the connection, it could
+ # get used for another query, which might do a conn.rollback().
+ #
+ # In the latter case, even though that probably wouldn't affect the
+ # results of this transaction, python's sqlite will reset all
+ # statements on the connection [1], which will make our cursor
+ # invalid [2].
+ #
+ # In any case, continuing to read rows after commit()ing seems
+ # dubious from the PoV of ACID transactional semantics
+ # (sqlite explicitly says that once you commit, you may see rows
+ # from subsequent updates.)
+ #
+ # In psycopg2, cursors are essentially a client-side fabrication -
+ # all the data is transferred to the client side when the statement
+ # finishes executing - so in theory we could go on streaming results
+ # from the cursor, but attempting to do so would make us
+ # incompatible with sqlite, so let's make sure we're not doing that
+ # by closing the cursor.
+ #
+ # (*named* cursors in psycopg2 are different and are proper server-
+ # side things, but (a) we don't use them and (b) they are implicitly
+ # closed by ending the transaction anyway.)
+ #
+ # In short, if we haven't finished with the cursor yet, that's a
+ # problem waiting to bite us.
+ #
+ # TL;DR: we're done with the cursor, so we can close it.
+ #
+ # [1]: https://github.com/python/cpython/blob/v3.8.0/Modules/_sqlite/connection.c#L465
+ # [2]: https://github.com/python/cpython/blob/v3.8.0/Modules/_sqlite/cursor.c#L236
+ cursor.close()
except Exception as e:
logger.debug("[TXN FAIL] {%s} %s", name, e)
raise
@@ -851,7 +884,7 @@ class SQLBaseStore(object):
allvalues.update(values)
latter = "UPDATE SET " + ", ".join(k + "=EXCLUDED." + k for k in values)
- sql = ("INSERT INTO %s (%s) VALUES (%s) " "ON CONFLICT (%s) DO %s") % (
+ sql = ("INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO %s") % (
table,
", ".join(k for k in allvalues),
", ".join("?" for _ in allvalues),
diff --git a/synapse/storage/data_stores/main/account_data.py b/synapse/storage/data_stores/main/account_data.py
index 6afbfc0d74..22093484ed 100644
--- a/synapse/storage/data_stores/main/account_data.py
+++ b/synapse/storage/data_stores/main/account_data.py
@@ -184,14 +184,14 @@ class AccountDataWorkerStore(SQLBaseStore):
current_id(int): The position to fetch up to.
Returns:
A deferred pair of lists of tuples of stream_id int, user_id string,
- room_id string, type string, and content string.
+ room_id string, and type string.
"""
if last_room_id == current_id and last_global_id == current_id:
return defer.succeed(([], []))
def get_updated_account_data_txn(txn):
sql = (
- "SELECT stream_id, user_id, account_data_type, content"
+ "SELECT stream_id, user_id, account_data_type"
" FROM account_data WHERE ? < stream_id AND stream_id <= ?"
" ORDER BY stream_id ASC LIMIT ?"
)
@@ -199,7 +199,7 @@ class AccountDataWorkerStore(SQLBaseStore):
global_results = txn.fetchall()
sql = (
- "SELECT stream_id, user_id, room_id, account_data_type, content"
+ "SELECT stream_id, user_id, room_id, account_data_type"
" FROM room_account_data WHERE ? < stream_id AND stream_id <= ?"
" ORDER BY stream_id ASC LIMIT ?"
)
diff --git a/synapse/storage/data_stores/main/deviceinbox.py b/synapse/storage/data_stores/main/deviceinbox.py
index 96cd0fb77a..a23744f11c 100644
--- a/synapse/storage/data_stores/main/deviceinbox.py
+++ b/synapse/storage/data_stores/main/deviceinbox.py
@@ -380,7 +380,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
devices = list(messages_by_device.keys())
if len(devices) == 1 and devices[0] == "*":
# Handle wildcard device_ids.
- sql = "SELECT device_id FROM devices" " WHERE user_id = ?"
+ sql = "SELECT device_id FROM devices WHERE user_id = ?"
txn.execute(sql, (user_id,))
message_json = json.dumps(messages_by_device["*"])
for row in txn:
diff --git a/synapse/storage/data_stores/main/e2e_room_keys.py b/synapse/storage/data_stores/main/e2e_room_keys.py
index 1cbbae5b63..113224fd7c 100644
--- a/synapse/storage/data_stores/main/e2e_room_keys.py
+++ b/synapse/storage/data_stores/main/e2e_room_keys.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2017 New Vector Ltd
+# Copyright 2019 Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -24,49 +25,8 @@ from synapse.storage._base import SQLBaseStore
class EndToEndRoomKeyStore(SQLBaseStore):
@defer.inlineCallbacks
- def get_e2e_room_key(self, user_id, version, room_id, session_id):
- """Get the encrypted E2E room key for a given session from a given
- backup version of room_keys. We only store the 'best' room key for a given
- session at a given time, as determined by the handler.
-
- Args:
- user_id(str): the user whose backup we're querying
- version(str): the version ID of the backup for the set of keys we're querying
- room_id(str): the ID of the room whose keys we're querying.
- This is a bit redundant as it's implied by the session_id, but
- we include for consistency with the rest of the API.
- session_id(str): the session whose room_key we're querying.
-
- Returns:
- A deferred dict giving the session_data and message metadata for
- this room key.
- """
-
- row = yield self._simple_select_one(
- table="e2e_room_keys",
- keyvalues={
- "user_id": user_id,
- "version": version,
- "room_id": room_id,
- "session_id": session_id,
- },
- retcols=(
- "first_message_index",
- "forwarded_count",
- "is_verified",
- "session_data",
- ),
- desc="get_e2e_room_key",
- )
-
- row["session_data"] = json.loads(row["session_data"])
-
- return row
-
- @defer.inlineCallbacks
- def set_e2e_room_key(self, user_id, version, room_id, session_id, room_key):
- """Replaces or inserts the encrypted E2E room key for a given session in
- a given backup
+ def update_e2e_room_key(self, user_id, version, room_id, session_id, room_key):
+ """Replaces the encrypted E2E room key for a given session in a given backup
Args:
user_id(str): the user whose backup we're setting
@@ -78,7 +38,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
StoreError
"""
- yield self._simple_upsert(
+ yield self._simple_update_one(
table="e2e_room_keys",
keyvalues={
"user_id": user_id,
@@ -86,21 +46,51 @@ class EndToEndRoomKeyStore(SQLBaseStore):
"room_id": room_id,
"session_id": session_id,
},
- values={
+ updatevalues={
"first_message_index": room_key["first_message_index"],
"forwarded_count": room_key["forwarded_count"],
"is_verified": room_key["is_verified"],
"session_data": json.dumps(room_key["session_data"]),
},
- lock=False,
+ desc="update_e2e_room_key",
)
- log_kv(
- {
- "message": "Set room key",
- "room_id": room_id,
- "session_id": session_id,
- "room_key": room_key,
- }
+
+ @defer.inlineCallbacks
+ def add_e2e_room_keys(self, user_id, version, room_keys):
+ """Bulk add room keys to a given backup.
+
+ Args:
+ user_id (str): the user whose backup we're adding to
+ version (str): the version ID of the backup for the set of keys we're adding to
+ room_keys (iterable[(str, str, dict)]): the keys to add, in the form
+ (roomID, sessionID, keyData)
+ """
+
+ values = []
+ for (room_id, session_id, room_key) in room_keys:
+ values.append(
+ {
+ "user_id": user_id,
+ "version": version,
+ "room_id": room_id,
+ "session_id": session_id,
+ "first_message_index": room_key["first_message_index"],
+ "forwarded_count": room_key["forwarded_count"],
+ "is_verified": room_key["is_verified"],
+ "session_data": json.dumps(room_key["session_data"]),
+ }
+ )
+ log_kv(
+ {
+ "message": "Set room key",
+ "room_id": room_id,
+ "session_id": session_id,
+ "room_key": room_key,
+ }
+ )
+
+ yield self._simple_insert_many(
+ table="e2e_room_keys", values=values, desc="add_e2e_room_keys"
)
@trace
@@ -110,11 +100,11 @@ class EndToEndRoomKeyStore(SQLBaseStore):
room, or a given session.
Args:
- user_id(str): the user whose backup we're querying
- version(str): the version ID of the backup for the set of keys we're querying
- room_id(str): Optional. the ID of the room whose keys we're querying, if any.
+ user_id (str): the user whose backup we're querying
+ version (str): the version ID of the backup for the set of keys we're querying
+ room_id (str): Optional. the ID of the room whose keys we're querying, if any.
If not specified, we return the keys for all the rooms in the backup.
- session_id(str): Optional. the session whose room_key we're querying, if any.
+ session_id (str): Optional. the session whose room_key we're querying, if any.
If specified, we also require the room_id to be specified.
If not specified, we return all the keys in this version of
the backup (or for the specified room)
@@ -162,6 +152,95 @@ class EndToEndRoomKeyStore(SQLBaseStore):
return sessions
+ def get_e2e_room_keys_multi(self, user_id, version, room_keys):
+ """Get multiple room keys at a time. The difference between this function and
+ get_e2e_room_keys is that this function can be used to retrieve
+ multiple specific keys at a time, whereas get_e2e_room_keys is used for
+ getting all the keys in a backup version, all the keys for a room, or a
+ specific key.
+
+ Args:
+ user_id (str): the user whose backup we're querying
+ version (str): the version ID of the backup we're querying about
+ room_keys (dict[str, dict[str, iterable[str]]]): a map from
+ room ID -> {"session": [session ids]} indicating the session IDs
+ that we want to query
+
+ Returns:
+ Deferred[dict[str, dict[str, dict]]]: a map of room IDs to session IDs to room key
+ """
+
+ return self.runInteraction(
+ "get_e2e_room_keys_multi",
+ self._get_e2e_room_keys_multi_txn,
+ user_id,
+ version,
+ room_keys,
+ )
+
+ @staticmethod
+ def _get_e2e_room_keys_multi_txn(txn, user_id, version, room_keys):
+ if not room_keys:
+ return {}
+
+ where_clauses = []
+ params = [user_id, version]
+ for room_id, room in room_keys.items():
+ sessions = list(room["sessions"])
+ if not sessions:
+ continue
+ params.append(room_id)
+ params.extend(sessions)
+ where_clauses.append(
+ "(room_id = ? AND session_id IN (%s))"
+ % (",".join(["?" for _ in sessions]),)
+ )
+
+ # check if we're actually querying something
+ if not where_clauses:
+ return {}
+
+ sql = """
+ SELECT room_id, session_id, first_message_index, forwarded_count,
+ is_verified, session_data
+ FROM e2e_room_keys
+ WHERE user_id = ? AND version = ? AND (%s)
+ """ % (
+ " OR ".join(where_clauses)
+ )
+
+ txn.execute(sql, params)
+
+ ret = {}
+
+ for row in txn:
+ room_id = row[0]
+ session_id = row[1]
+ ret.setdefault(room_id, {})
+ ret[room_id][session_id] = {
+ "first_message_index": row[2],
+ "forwarded_count": row[3],
+ "is_verified": row[4],
+ "session_data": json.loads(row[5]),
+ }
+
+ return ret
+
+ def count_e2e_room_keys(self, user_id, version):
+ """Get the number of keys in a backup version.
+
+ Args:
+ user_id (str): the user whose backup we're querying
+ version (str): the version ID of the backup we're querying about
+ """
+
+ return self._simple_select_one_onecol(
+ table="e2e_room_keys",
+ keyvalues={"user_id": user_id, "version": version},
+ retcol="COUNT(*)",
+ desc="count_e2e_room_keys",
+ )
+
@trace
@defer.inlineCallbacks
def delete_e2e_room_keys(self, user_id, version, room_id=None, session_id=None):
@@ -219,6 +298,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
version(str)
algorithm(str)
auth_data(object): opaque dict supplied by the client
+ etag(int): tag of the keys in the backup
"""
def _get_e2e_room_keys_version_info_txn(txn):
@@ -236,10 +316,12 @@ class EndToEndRoomKeyStore(SQLBaseStore):
txn,
table="e2e_room_keys_versions",
keyvalues={"user_id": user_id, "version": this_version, "deleted": 0},
- retcols=("version", "algorithm", "auth_data"),
+ retcols=("version", "algorithm", "auth_data", "etag"),
)
result["auth_data"] = json.loads(result["auth_data"])
result["version"] = str(result["version"])
+ if result["etag"] is None:
+ result["etag"] = 0
return result
return self.runInteraction(
@@ -288,21 +370,33 @@ class EndToEndRoomKeyStore(SQLBaseStore):
)
@trace
- def update_e2e_room_keys_version(self, user_id, version, info):
+ def update_e2e_room_keys_version(
+ self, user_id, version, info=None, version_etag=None
+ ):
"""Update a given backup version
Args:
user_id(str): the user whose backup version we're updating
version(str): the version ID of the backup version we're updating
- info(dict): the new backup version info to store
+ info (dict): the new backup version info to store. If None, then
+ the backup version info is not updated
+ version_etag (Optional[int]): etag of the keys in the backup. If
+ None, then the etag is not updated
"""
+ updatevalues = {}
- return self._simple_update(
- table="e2e_room_keys_versions",
- keyvalues={"user_id": user_id, "version": version},
- updatevalues={"auth_data": json.dumps(info["auth_data"])},
- desc="update_e2e_room_keys_version",
- )
+ if info is not None and "auth_data" in info:
+ updatevalues["auth_data"] = json.dumps(info["auth_data"])
+ if version_etag is not None:
+ updatevalues["etag"] = version_etag
+
+ if updatevalues:
+ return self._simple_update(
+ table="e2e_room_keys_versions",
+ keyvalues={"user_id": user_id, "version": version},
+ updatevalues=updatevalues,
+ desc="update_e2e_room_keys_version",
+ )
@trace
def delete_e2e_room_keys_version(self, user_id, version=None):
diff --git a/synapse/storage/data_stores/main/end_to_end_keys.py b/synapse/storage/data_stores/main/end_to_end_keys.py
index 073412a78d..d8ad59ad93 100644
--- a/synapse/storage/data_stores/main/end_to_end_keys.py
+++ b/synapse/storage/data_stores/main/end_to_end_keys.py
@@ -138,9 +138,9 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
result.setdefault(user_id, {})[device_id] = None
# get signatures on the device
- signature_sql = (
- "SELECT * " " FROM e2e_cross_signing_signatures " " WHERE %s"
- ) % (" OR ".join("(" + q + ")" for q in signature_query_clauses))
+ signature_sql = ("SELECT * FROM e2e_cross_signing_signatures WHERE %s") % (
+ " OR ".join("(" + q + ")" for q in signature_query_clauses)
+ )
txn.execute(signature_sql, signature_query_params)
rows = self.cursor_to_dict(txn)
diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py
index 878f7568a6..2737a1d3ae 100644
--- a/synapse/storage/data_stores/main/events.py
+++ b/synapse/storage/data_stores/main/events.py
@@ -713,9 +713,7 @@ class EventsStore(
metadata_json = encode_json(event.internal_metadata.get_dict())
- sql = (
- "UPDATE event_json SET internal_metadata = ?" " WHERE event_id = ?"
- )
+ sql = "UPDATE event_json SET internal_metadata = ? WHERE event_id = ?"
txn.execute(sql, (metadata_json, event.event_id))
# Add an entry to the ex_outlier_stream table to replicate the
@@ -732,7 +730,7 @@ class EventsStore(
},
)
- sql = "UPDATE events SET outlier = ?" " WHERE event_id = ?"
+ sql = "UPDATE events SET outlier = ? WHERE event_id = ?"
txn.execute(sql, (False, event.event_id))
# Update the event_backward_extremities table now that this
@@ -929,6 +927,9 @@ class EventsStore(
elif event.type == EventTypes.Redaction:
# Insert into the redactions table.
self._store_redaction(txn, event)
+ elif event.type == EventTypes.Retention:
+ # Update the room_retention table.
+ self._store_retention_policy_for_room_txn(txn, event)
self._handle_event_relations(txn, event)
@@ -1479,7 +1480,7 @@ class EventsStore(
# We do joins against events_to_purge for e.g. calculating state
# groups to purge, etc., so lets make an index.
- txn.execute("CREATE INDEX events_to_purge_id" " ON events_to_purge(event_id)")
+ txn.execute("CREATE INDEX events_to_purge_id ON events_to_purge(event_id)")
txn.execute("SELECT event_id, should_delete FROM events_to_purge")
event_rows = txn.fetchall()
diff --git a/synapse/storage/data_stores/main/filtering.py b/synapse/storage/data_stores/main/filtering.py
index a2a2a67927..f05ace299a 100644
--- a/synapse/storage/data_stores/main/filtering.py
+++ b/synapse/storage/data_stores/main/filtering.py
@@ -55,7 +55,7 @@ class FilteringStore(SQLBaseStore):
if filter_id_response is not None:
return filter_id_response[0]
- sql = "SELECT MAX(filter_id) FROM user_filters " "WHERE user_id = ?"
+ sql = "SELECT MAX(filter_id) FROM user_filters WHERE user_id = ?"
txn.execute(sql, (user_localpart,))
max_id = txn.fetchone()[0]
if max_id is None:
diff --git a/synapse/storage/data_stores/main/media_repository.py b/synapse/storage/data_stores/main/media_repository.py
index 84b5f3ad5e..0f2887bdce 100644
--- a/synapse/storage/data_stores/main/media_repository.py
+++ b/synapse/storage/data_stores/main/media_repository.py
@@ -337,7 +337,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
if len(media_ids) == 0:
return
- sql = "DELETE FROM local_media_repository_url_cache" " WHERE media_id = ?"
+ sql = "DELETE FROM local_media_repository_url_cache WHERE media_id = ?"
def _delete_url_cache_txn(txn):
txn.executemany(sql, [(media_id,) for media_id in media_ids])
@@ -365,11 +365,11 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
return
def _delete_url_cache_media_txn(txn):
- sql = "DELETE FROM local_media_repository" " WHERE media_id = ?"
+ sql = "DELETE FROM local_media_repository WHERE media_id = ?"
txn.executemany(sql, [(media_id,) for media_id in media_ids])
- sql = "DELETE FROM local_media_repository_thumbnails" " WHERE media_id = ?"
+ sql = "DELETE FROM local_media_repository_thumbnails WHERE media_id = ?"
txn.executemany(sql, [(media_id,) for media_id in media_ids])
diff --git a/synapse/storage/data_stores/main/receipts.py b/synapse/storage/data_stores/main/receipts.py
index 0c24430f28..8b17334ff4 100644
--- a/synapse/storage/data_stores/main/receipts.py
+++ b/synapse/storage/data_stores/main/receipts.py
@@ -280,7 +280,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
args.append(limit)
txn.execute(sql, args)
- return (r[0:5] + (json.loads(r[5]),) for r in txn)
+ return list(r[0:5] + (json.loads(r[5]),) for r in txn)
return self.runInteraction(
"get_all_updated_receipts", get_all_updated_receipts_txn
diff --git a/synapse/storage/data_stores/main/registration.py b/synapse/storage/data_stores/main/registration.py
index 89147ad511..98cf6427c3 100644
--- a/synapse/storage/data_stores/main/registration.py
+++ b/synapse/storage/data_stores/main/registration.py
@@ -19,7 +19,6 @@ import logging
import re
from six import iterkeys
-from six.moves import range
from twisted.internet import defer
from twisted.internet.defer import Deferred
@@ -377,9 +376,7 @@ class RegistrationWorkerStore(SQLBaseStore):
"""
def f(txn):
- sql = (
- "SELECT name, password_hash FROM users" " WHERE lower(name) = lower(?)"
- )
+ sql = "SELECT name, password_hash FROM users WHERE lower(name) = lower(?)"
txn.execute(sql, (user_id,))
return dict(txn)
@@ -484,12 +481,8 @@ class RegistrationWorkerStore(SQLBaseStore):
"""
Gets the localpart of the next generated user ID.
- Generated user IDs are integers, and we aim for them to be as small as
- we can. Unfortunately, it's possible some of them are already taken by
- existing users, and there may be gaps in the already taken range. This
- function returns the start of the first allocatable gap. This is to
- avoid the case of ID 1000 being pre-allocated and starting at 1001 while
- 0-999 are available.
+ Generated user IDs are integers, so we find the largest integer user ID
+ already taken and return that plus one.
"""
def _find_next_generated_user_id(txn):
@@ -499,15 +492,14 @@ class RegistrationWorkerStore(SQLBaseStore):
regex = re.compile(r"^@(\d+):")
- found = set()
+ max_found = 0
for (user_id,) in txn:
match = regex.search(user_id)
if match:
- found.add(int(match.group(1)))
- for i in range(len(found) + 1):
- if i not in found:
- return i
+ max_found = max(int(match.group(1)), max_found)
+
+ return max_found + 1
return (
(
diff --git a/synapse/storage/data_stores/main/room.py b/synapse/storage/data_stores/main/room.py
index 67bb1b6f60..b7f9024811 100644
--- a/synapse/storage/data_stores/main/room.py
+++ b/synapse/storage/data_stores/main/room.py
@@ -19,10 +19,13 @@ import logging
import re
from typing import Optional, Tuple
+from six import integer_types
+
from canonicaljson import json
from twisted.internet import defer
+from synapse.api.constants import EventTypes
from synapse.api.errors import StoreError
from synapse.storage._base import SQLBaseStore
from synapse.storage.data_stores.main.search import SearchStore
@@ -300,8 +303,141 @@ class RoomWorkerStore(SQLBaseStore):
else:
return None
+ @cachedInlineCallbacks()
+ def get_retention_policy_for_room(self, room_id):
+ """Get the retention policy for a given room.
+
+ If no retention policy has been found for this room, returns a policy defined
+ by the configured default policy (which has None as both the 'min_lifetime' and
+ the 'max_lifetime' if no default policy has been defined in the server's
+ configuration).
+
+ Args:
+ room_id (str): The ID of the room to get the retention policy of.
+
+ Returns:
+ dict[int, int]: "min_lifetime" and "max_lifetime" for this room.
+ """
+
+ def get_retention_policy_for_room_txn(txn):
+ txn.execute(
+ """
+ SELECT min_lifetime, max_lifetime FROM room_retention
+ INNER JOIN current_state_events USING (event_id, room_id)
+ WHERE room_id = ?;
+ """,
+ (room_id,),
+ )
+
+ return self.cursor_to_dict(txn)
+
+ ret = yield self.runInteraction(
+ "get_retention_policy_for_room", get_retention_policy_for_room_txn,
+ )
+
+ # If we don't know this room ID, ret will be None, in this case return the default
+ # policy.
+ if not ret:
+ defer.returnValue(
+ {
+ "min_lifetime": self.config.retention_default_min_lifetime,
+ "max_lifetime": self.config.retention_default_max_lifetime,
+ }
+ )
+
+ row = ret[0]
+
+ # If one of the room's policy's attributes isn't defined, use the matching
+ # attribute from the default policy.
+ # The default values will be None if no default policy has been defined, or if one
+ # of the attributes is missing from the default policy.
+ if row["min_lifetime"] is None:
+ row["min_lifetime"] = self.config.retention_default_min_lifetime
+
+ if row["max_lifetime"] is None:
+ row["max_lifetime"] = self.config.retention_default_max_lifetime
+
+ defer.returnValue(row)
+
class RoomStore(RoomWorkerStore, SearchStore):
+ def __init__(self, db_conn, hs):
+ super(RoomStore, self).__init__(db_conn, hs)
+
+ self.config = hs.config
+
+ self.register_background_update_handler(
+ "insert_room_retention", self._background_insert_retention,
+ )
+
+ @defer.inlineCallbacks
+ def _background_insert_retention(self, progress, batch_size):
+ """Retrieves a list of all rooms within a range and inserts an entry for each of
+ them into the room_retention table.
+ NULLs the property's columns if missing from the retention event in the room's
+ state (or NULLs all of them if there's no retention event in the room's state),
+ so that we fall back to the server's retention policy.
+ """
+
+ last_room = progress.get("room_id", "")
+
+ def _background_insert_retention_txn(txn):
+ txn.execute(
+ """
+ SELECT state.room_id, state.event_id, events.json
+ FROM current_state_events as state
+ LEFT JOIN event_json AS events ON (state.event_id = events.event_id)
+ WHERE state.room_id > ? AND state.type = '%s'
+ ORDER BY state.room_id ASC
+ LIMIT ?;
+ """
+ % EventTypes.Retention,
+ (last_room, batch_size),
+ )
+
+ rows = self.cursor_to_dict(txn)
+
+ if not rows:
+ return True
+
+ for row in rows:
+ if not row["json"]:
+ retention_policy = {}
+ else:
+ ev = json.loads(row["json"])
+ retention_policy = json.dumps(ev["content"])
+
+ self._simple_insert_txn(
+ txn=txn,
+ table="room_retention",
+ values={
+ "room_id": row["room_id"],
+ "event_id": row["event_id"],
+ "min_lifetime": retention_policy.get("min_lifetime"),
+ "max_lifetime": retention_policy.get("max_lifetime"),
+ },
+ )
+
+ logger.info("Inserted %d rows into room_retention", len(rows))
+
+ self._background_update_progress_txn(
+ txn, "insert_room_retention", {"room_id": rows[-1]["room_id"]}
+ )
+
+ if batch_size > len(rows):
+ return True
+ else:
+ return False
+
+ end = yield self.runInteraction(
+ "insert_room_retention", _background_insert_retention_txn,
+ )
+
+ if end:
+ yield self._end_background_update("insert_room_retention")
+
+ defer.returnValue(batch_size)
+
@defer.inlineCallbacks
def store_room(self, room_id, room_creator_user_id, is_public):
"""Stores a room.
@@ -502,6 +638,35 @@ class RoomStore(RoomWorkerStore, SearchStore):
txn, event, "content.body", event.content["body"]
)
+ def _store_retention_policy_for_room_txn(self, txn, event):
+ if hasattr(event, "content") and (
+ "min_lifetime" in event.content or "max_lifetime" in event.content
+ ):
+ if (
+ "min_lifetime" in event.content
+ and not isinstance(event.content.get("min_lifetime"), integer_types)
+ ) or (
+ "max_lifetime" in event.content
+ and not isinstance(event.content.get("max_lifetime"), integer_types)
+ ):
+ # Ignore the event if one of the value isn't an integer.
+ return
+
+ self._simple_insert_txn(
+ txn=txn,
+ table="room_retention",
+ values={
+ "room_id": event.room_id,
+ "event_id": event.event_id,
+ "min_lifetime": event.content.get("min_lifetime"),
+ "max_lifetime": event.content.get("max_lifetime"),
+ },
+ )
+
+ self._invalidate_cache_and_stream(
+ txn, self.get_retention_policy_for_room, (event.room_id,)
+ )
+
def add_event_report(
self, room_id, event_id, user_id, reason, content, received_ts
):
@@ -683,3 +848,89 @@ class RoomStore(RoomWorkerStore, SearchStore):
remote_media_mxcs.append((hostname, media_id))
return local_media_mxcs, remote_media_mxcs
+
+ @defer.inlineCallbacks
+ def get_rooms_for_retention_period_in_range(
+ self, min_ms, max_ms, include_null=False
+ ):
+ """Retrieves all of the rooms within the given retention range.
+
+ Optionally includes the rooms which don't have a retention policy.
+
+ Args:
+ min_ms (int|None): Duration in milliseconds that define the lower limit of
+ the range to handle (exclusive). If None, doesn't set a lower limit.
+ max_ms (int|None): Duration in milliseconds that define the upper limit of
+ the range to handle (inclusive). If None, doesn't set an upper limit.
+ include_null (bool): Whether to include rooms which retention policy is NULL
+ in the returned set.
+
+ Returns:
+ dict[str, dict]: The rooms within this range, along with their retention
+ policy. The key is "room_id", and maps to a dict describing the retention
+ policy associated with this room ID. The keys for this nested dict are
+ "min_lifetime" (int|None), and "max_lifetime" (int|None).
+ """
+
+ def get_rooms_for_retention_period_in_range_txn(txn):
+ range_conditions = []
+ args = []
+
+ if min_ms is not None:
+ range_conditions.append("max_lifetime > ?")
+ args.append(min_ms)
+
+ if max_ms is not None:
+ range_conditions.append("max_lifetime <= ?")
+ args.append(max_ms)
+
+ # Do a first query which will retrieve the rooms that have a retention policy
+ # in their current state.
+ sql = """
+ SELECT room_id, min_lifetime, max_lifetime FROM room_retention
+ INNER JOIN current_state_events USING (event_id, room_id)
+ """
+
+ if len(range_conditions):
+ sql += " WHERE (" + " AND ".join(range_conditions) + ")"
+
+ if include_null:
+ sql += " OR max_lifetime IS NULL"
+
+ txn.execute(sql, args)
+
+ rows = self.cursor_to_dict(txn)
+ rooms_dict = {}
+
+ for row in rows:
+ rooms_dict[row["room_id"]] = {
+ "min_lifetime": row["min_lifetime"],
+ "max_lifetime": row["max_lifetime"],
+ }
+
+ if include_null:
+ # If required, do a second query that retrieves all of the rooms we know
+ # of so we can handle rooms with no retention policy.
+ sql = "SELECT DISTINCT room_id FROM current_state_events"
+
+ txn.execute(sql)
+
+ rows = self.cursor_to_dict(txn)
+
+ # If a room isn't already in the dict (i.e. it doesn't have a retention
+ # policy in its state), add it with a null policy.
+ for row in rows:
+ if row["room_id"] not in rooms_dict:
+ rooms_dict[row["room_id"]] = {
+ "min_lifetime": None,
+ "max_lifetime": None,
+ }
+
+ return rooms_dict
+
+ rooms = yield self.runInteraction(
+ "get_rooms_for_retention_period_in_range",
+ get_rooms_for_retention_period_in_range_txn,
+ )
+
+ defer.returnValue(rooms)
diff --git a/synapse/storage/data_stores/main/schema/delta/56/room_key_etag.sql b/synapse/storage/data_stores/main/schema/delta/56/room_key_etag.sql
new file mode 100644
index 0000000000..7d70dd071e
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/room_key_etag.sql
@@ -0,0 +1,17 @@
+/* Copyright 2019 Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- store the current etag of backup version
+ALTER TABLE e2e_room_keys_versions ADD COLUMN etag BIGINT;
diff --git a/synapse/storage/data_stores/main/schema/delta/56/room_retention.sql b/synapse/storage/data_stores/main/schema/delta/56/room_retention.sql
new file mode 100644
index 0000000000..ee6cdf7a14
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/room_retention.sql
@@ -0,0 +1,33 @@
+/* Copyright 2019 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Tracks the retention policy of a room.
+-- A NULL max_lifetime or min_lifetime means that the matching property is not defined in
+-- the room's retention policy state event.
+-- If a room doesn't have a retention policy state event in its state, both max_lifetime
+-- and min_lifetime are NULL.
+CREATE TABLE IF NOT EXISTS room_retention(
+ room_id TEXT,
+ event_id TEXT,
+ min_lifetime BIGINT,
+ max_lifetime BIGINT,
+
+ PRIMARY KEY(room_id, event_id)
+);
+
+CREATE INDEX room_retention_max_lifetime_idx on room_retention(max_lifetime);
+
+INSERT INTO background_updates (update_name, progress_json) VALUES
+ ('insert_room_retention', '{}');
diff --git a/synapse/storage/data_stores/main/stream.py b/synapse/storage/data_stores/main/stream.py
index 8780fdd989..9ae4a913a1 100644
--- a/synapse/storage/data_stores/main/stream.py
+++ b/synapse/storage/data_stores/main/stream.py
@@ -616,7 +616,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
def _get_max_topological_txn(self, txn, room_id):
txn.execute(
- "SELECT MAX(topological_ordering) FROM events" " WHERE room_id = ?",
+ "SELECT MAX(topological_ordering) FROM events WHERE room_id = ?",
(room_id,),
)
diff --git a/synapse/storage/data_stores/main/tags.py b/synapse/storage/data_stores/main/tags.py
index 10d1887f75..aa24339717 100644
--- a/synapse/storage/data_stores/main/tags.py
+++ b/synapse/storage/data_stores/main/tags.py
@@ -83,9 +83,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
)
def get_tag_content(txn, tag_ids):
- sql = (
- "SELECT tag, content" " FROM room_tags" " WHERE user_id=? AND room_id=?"
- )
+ sql = "SELECT tag, content FROM room_tags WHERE user_id=? AND room_id=?"
results = []
for stream_id, user_id, room_id in tag_ids:
txn.execute(sql, (user_id, room_id))
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 2e7753820e..731e1c9d9c 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -447,7 +447,7 @@ def _apply_module_schema_files(cur, database_engine, modname, names_and_streams)
# Mark as done.
cur.execute(
database_engine.convert_param_style(
- "INSERT INTO applied_module_schemas (module_name, file)" " VALUES (?,?)"
+ "INSERT INTO applied_module_schemas (module_name, file) VALUES (?,?)"
),
(modname, name),
)
diff --git a/synapse/streams/config.py b/synapse/streams/config.py
index 02994ab2a5..cd56cd91ed 100644
--- a/synapse/streams/config.py
+++ b/synapse/streams/config.py
@@ -88,9 +88,12 @@ class PaginationConfig(object):
raise SynapseError(400, "Invalid request.")
def __repr__(self):
- return (
- "PaginationConfig(from_tok=%r, to_tok=%r," " direction=%r, limit=%r)"
- ) % (self.from_token, self.to_token, self.direction, self.limit)
+ return ("PaginationConfig(from_tok=%r, to_tok=%r, direction=%r, limit=%r)") % (
+ self.from_token,
+ self.to_token,
+ self.direction,
+ self.limit,
+ )
def get_source_config(self, source_name):
keyname = "%s_key" % source_name
diff --git a/synapse/visibility.py b/synapse/visibility.py
index 8c843febd8..4d4141dacc 100644
--- a/synapse/visibility.py
+++ b/synapse/visibility.py
@@ -86,6 +86,14 @@ def filter_events_for_client(
erased_senders = yield storage.main.are_users_erased((e.sender for e in events))
+ room_ids = set(e.room_id for e in events)
+ retention_policies = {}
+
+ for room_id in room_ids:
+ retention_policies[room_id] = yield storage.main.get_retention_policy_for_room(
+ room_id
+ )
+
def allowed(event):
"""
Args:
@@ -103,6 +111,18 @@ def filter_events_for_client(
if not event.is_state() and event.sender in ignore_list:
return None
+ # Don't try to apply the room's retention policy if the event is a state event, as
+ # MSC1763 states that retention is only considered for non-state events.
+ if not event.is_state():
+ retention_policy = retention_policies[event.room_id]
+ max_lifetime = retention_policy.get("max_lifetime")
+
+ if max_lifetime is not None:
+ oldest_allowed_ts = storage.main.clock.time_msec() - max_lifetime
+
+ if event.origin_server_ts < oldest_allowed_ts:
+ return None
+
if event.event_id in always_include_ids:
return event
|