diff --git a/synapse/__init__.py b/synapse/__init__.py
index f99de2f3f3..fc2a6e4ee6 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -36,7 +36,7 @@ try:
except ImportError:
pass
-__version__ = "1.6.1"
+__version__ = "1.7.0rc2"
if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
# We import here so that we don't have to install a bunch of deps when
diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py
index 448e45e00f..f24920a7d6 100644
--- a/synapse/app/federation_sender.py
+++ b/synapse/app/federation_sender.py
@@ -40,6 +40,7 @@ from synapse.replication.slave.storage.transactions import SlavedTransactionStor
from synapse.replication.tcp.client import ReplicationClientHandler
from synapse.replication.tcp.streams._base import ReceiptsStream
from synapse.server import HomeServer
+from synapse.storage.database import Database
from synapse.storage.engines import create_engine
from synapse.types import ReadReceipt
from synapse.util.async_helpers import Linearizer
@@ -59,8 +60,8 @@ class FederationSenderSlaveStore(
SlavedDeviceStore,
SlavedPresenceStore,
):
- def __init__(self, db_conn, hs):
- super(FederationSenderSlaveStore, self).__init__(db_conn, hs)
+ def __init__(self, database: Database, db_conn, hs):
+ super(FederationSenderSlaveStore, self).__init__(database, db_conn, hs)
# We pull out the current federation stream position now so that we
# always have a known value for the federation position in memory so
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 9f81a857ab..032010600a 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -68,9 +68,9 @@ from synapse.rest.key.v2 import KeyApiV2Resource
from synapse.rest.media.v0.content_repository import ContentRepoResource
from synapse.rest.well_known import WellKnownResource
from synapse.server import HomeServer
-from synapse.storage import DataStore, are_all_users_on_domain
+from synapse.storage import DataStore
from synapse.storage.engines import IncorrectDatabaseSetup, create_engine
-from synapse.storage.prepare_database import UpgradeDatabaseException, prepare_database
+from synapse.storage.prepare_database import UpgradeDatabaseException
from synapse.util.caches import CACHE_SIZE_FACTOR
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.manhole import manhole
@@ -294,22 +294,6 @@ class SynapseHomeServer(HomeServer):
else:
logger.warning("Unrecognized listener type: %s", listener["type"])
- def run_startup_checks(self, db_conn, database_engine):
- all_users_native = are_all_users_on_domain(
- db_conn.cursor(), database_engine, self.hostname
- )
- if not all_users_native:
- quit_with_error(
- "Found users in database not native to %s!\n"
- "You cannot changed a synapse server_name after it's been configured"
- % (self.hostname,)
- )
-
- try:
- database_engine.check_database(db_conn.cursor())
- except IncorrectDatabaseSetup as e:
- quit_with_error(str(e))
-
# Gauges to expose monthly active user control metrics
current_mau_gauge = Gauge("synapse_admin_mau:current", "Current MAU")
@@ -357,16 +341,12 @@ def setup(config_options):
synapse.config.logger.setup_logging(hs, config, use_worker_options=False)
- logger.info("Preparing database: %s...", config.database_config["name"])
+ logger.info("Setting up server")
try:
- with hs.get_db_conn(run_new_connection=False) as db_conn:
- prepare_database(db_conn, database_engine, config=config)
- database_engine.on_new_connection(db_conn)
-
- hs.run_startup_checks(db_conn, database_engine)
-
- db_conn.commit()
+ hs.setup()
+ except IncorrectDatabaseSetup as e:
+ quit_with_error(str(e))
except UpgradeDatabaseException:
sys.stderr.write(
"\nFailed to upgrade database.\n"
@@ -375,9 +355,6 @@ def setup(config_options):
)
sys.exit(1)
- logger.info("Database prepared in %s.", config.database_config["name"])
-
- hs.setup()
hs.setup_master()
@defer.inlineCallbacks
@@ -542,8 +519,10 @@ def phone_stats_home(hs, stats, stats_process=_stats_process):
# Database version
#
- stats["database_engine"] = hs.database_engine.module.__name__
- stats["database_server_version"] = hs.database_engine.server_version
+ # This only reports info about the *main* database.
+ stats["database_engine"] = hs.get_datastore().db.engine.module.__name__
+ stats["database_server_version"] = hs.get_datastore().db.engine.server_version
+
logger.info("Reporting stats to %s: %s" % (hs.config.report_stats_endpoint, stats))
try:
yield hs.get_proxied_http_client().put_json(
diff --git a/synapse/app/pusher.py b/synapse/app/pusher.py
index 01a5ffc363..dd52a9fc2d 100644
--- a/synapse/app/pusher.py
+++ b/synapse/app/pusher.py
@@ -33,6 +33,7 @@ from synapse.replication.slave.storage.account_data import SlavedAccountDataStor
from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.replication.slave.storage.pushers import SlavedPusherStore
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
+from synapse.replication.slave.storage.room import RoomStore
from synapse.replication.tcp.client import ReplicationClientHandler
from synapse.server import HomeServer
from synapse.storage import DataStore
@@ -45,7 +46,11 @@ logger = logging.getLogger("synapse.app.pusher")
class PusherSlaveStore(
- SlavedEventStore, SlavedPusherStore, SlavedReceiptsStore, SlavedAccountDataStore
+ SlavedEventStore,
+ SlavedPusherStore,
+ SlavedReceiptsStore,
+ SlavedAccountDataStore,
+ RoomStore,
):
update_pusher_last_stream_ordering_and_success = __func__(
DataStore.update_pusher_last_stream_ordering_and_success
diff --git a/synapse/app/user_dir.py b/synapse/app/user_dir.py
index b6d4481725..c01fb34a9b 100644
--- a/synapse/app/user_dir.py
+++ b/synapse/app/user_dir.py
@@ -43,6 +43,7 @@ from synapse.replication.tcp.streams.events import (
from synapse.rest.client.v2_alpha import user_directory
from synapse.server import HomeServer
from synapse.storage.data_stores.main.user_directory import UserDirectoryStore
+from synapse.storage.database import Database
from synapse.storage.engines import create_engine
from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.httpresourcetree import create_resource_tree
@@ -60,8 +61,8 @@ class UserDirectorySlaveStore(
UserDirectoryStore,
BaseSlavedStore,
):
- def __init__(self, db_conn, hs):
- super(UserDirectorySlaveStore, self).__init__(db_conn, hs)
+ def __init__(self, database: Database, db_conn, hs):
+ super(UserDirectorySlaveStore, self).__init__(database, db_conn, hs)
events_max = self._stream_id_gen.get_current_token()
curr_state_delta_prefill, min_curr_state_delta_id = self.db.get_cache_dict(
diff --git a/synapse/config/saml2_config.py b/synapse/config/saml2_config.py
index c5ea2d43a1..b91414aa35 100644
--- a/synapse/config/saml2_config.py
+++ b/synapse/config/saml2_config.py
@@ -14,17 +14,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import re
+import logging
from synapse.python_dependencies import DependencyException, check_requirements
-from synapse.types import (
- map_username_to_mxid_localpart,
- mxid_localpart_allowed_characters,
-)
-from synapse.util.module_loader import load_python_module
+from synapse.util.module_loader import load_module, load_python_module
from ._base import Config, ConfigError
+logger = logging.getLogger(__name__)
+
+DEFAULT_USER_MAPPING_PROVIDER = (
+ "synapse.handlers.saml_handler.DefaultSamlMappingProvider"
+)
+
def _dict_merge(merge_dict, into_dict):
"""Do a deep merge of two dicts
@@ -75,15 +77,69 @@ class SAML2Config(Config):
self.saml2_enabled = True
- self.saml2_mxid_source_attribute = saml2_config.get(
- "mxid_source_attribute", "uid"
- )
-
self.saml2_grandfathered_mxid_source_attribute = saml2_config.get(
"grandfathered_mxid_source_attribute", "uid"
)
- saml2_config_dict = self._default_saml_config_dict()
+ # user_mapping_provider may be None if the key is present but has no value
+ ump_dict = saml2_config.get("user_mapping_provider") or {}
+
+ # Use the default user mapping provider if not set
+ ump_dict.setdefault("module", DEFAULT_USER_MAPPING_PROVIDER)
+
+ # Ensure a config is present
+ ump_dict["config"] = ump_dict.get("config") or {}
+
+ if ump_dict["module"] == DEFAULT_USER_MAPPING_PROVIDER:
+ # Load deprecated options for use by the default module
+ old_mxid_source_attribute = saml2_config.get("mxid_source_attribute")
+ if old_mxid_source_attribute:
+ logger.warning(
+ "The config option saml2_config.mxid_source_attribute is deprecated. "
+ "Please use saml2_config.user_mapping_provider.config"
+ ".mxid_source_attribute instead."
+ )
+ ump_dict["config"]["mxid_source_attribute"] = old_mxid_source_attribute
+
+ old_mxid_mapping = saml2_config.get("mxid_mapping")
+ if old_mxid_mapping:
+ logger.warning(
+ "The config option saml2_config.mxid_mapping is deprecated. Please "
+ "use saml2_config.user_mapping_provider.config.mxid_mapping instead."
+ )
+ ump_dict["config"]["mxid_mapping"] = old_mxid_mapping
+
+ # Retrieve an instance of the module's class
+ # Pass the config dictionary to the module for processing
+ (
+ self.saml2_user_mapping_provider_class,
+ self.saml2_user_mapping_provider_config,
+ ) = load_module(ump_dict)
+
+ # Ensure loaded user mapping module has defined all necessary methods
+ # Note parse_config() is already checked during the call to load_module
+ required_methods = [
+ "get_saml_attributes",
+ "saml_response_to_user_attributes",
+ ]
+ missing_methods = [
+ method
+ for method in required_methods
+ if not hasattr(self.saml2_user_mapping_provider_class, method)
+ ]
+ if missing_methods:
+ raise ConfigError(
+ "Class specified by saml2_config."
+ "user_mapping_provider.module is missing required "
+ "methods: %s" % (", ".join(missing_methods),)
+ )
+
+ # Get the desired saml auth response attributes from the module
+ saml2_config_dict = self._default_saml_config_dict(
+ *self.saml2_user_mapping_provider_class.get_saml_attributes(
+ self.saml2_user_mapping_provider_config
+ )
+ )
_dict_merge(
merge_dict=saml2_config.get("sp_config", {}), into_dict=saml2_config_dict
)
@@ -103,22 +159,27 @@ class SAML2Config(Config):
saml2_config.get("saml_session_lifetime", "5m")
)
- mapping = saml2_config.get("mxid_mapping", "hexencode")
- try:
- self.saml2_mxid_mapper = MXID_MAPPER_MAP[mapping]
- except KeyError:
- raise ConfigError("%s is not a known mxid_mapping" % (mapping,))
-
- def _default_saml_config_dict(self):
+ def _default_saml_config_dict(
+ self, required_attributes: set, optional_attributes: set
+ ):
+ """Generate a configuration dictionary with required and optional attributes that
+ will be needed to process new user registration
+
+ Args:
+ required_attributes: SAML auth response attributes that are
+ necessary to function
+ optional_attributes: SAML auth response attributes that can be used to add
+ additional information to Synapse user accounts, but are not required
+
+ Returns:
+ dict: A SAML configuration dictionary
+ """
import saml2
public_baseurl = self.public_baseurl
if public_baseurl is None:
raise ConfigError("saml2_config requires a public_baseurl to be set")
- required_attributes = {"uid", self.saml2_mxid_source_attribute}
-
- optional_attributes = {"displayName"}
if self.saml2_grandfathered_mxid_source_attribute:
optional_attributes.add(self.saml2_grandfathered_mxid_source_attribute)
optional_attributes -= required_attributes
@@ -207,33 +268,58 @@ class SAML2Config(Config):
#
#config_path: "%(config_dir_path)s/sp_conf.py"
- # the lifetime of a SAML session. This defines how long a user has to
+ # The lifetime of a SAML session. This defines how long a user has to
# complete the authentication process, if allow_unsolicited is unset.
# The default is 5 minutes.
#
#saml_session_lifetime: 5m
- # The SAML attribute (after mapping via the attribute maps) to use to derive
- # the Matrix ID from. 'uid' by default.
+ # An external module can be provided here as a custom solution to
+ # mapping attributes returned from a saml provider onto a matrix user.
#
- #mxid_source_attribute: displayName
-
- # The mapping system to use for mapping the saml attribute onto a matrix ID.
- # Options include:
- # * 'hexencode' (which maps unpermitted characters to '=xx')
- # * 'dotreplace' (which replaces unpermitted characters with '.').
- # The default is 'hexencode'.
- #
- #mxid_mapping: dotreplace
-
- # In previous versions of synapse, the mapping from SAML attribute to MXID was
- # always calculated dynamically rather than stored in a table. For backwards-
- # compatibility, we will look for user_ids matching such a pattern before
- # creating a new account.
+ user_mapping_provider:
+ # The custom module's class. Uncomment to use a custom module.
+ #
+ #module: mapping_provider.SamlMappingProvider
+
+ # Custom configuration values for the module. Below options are
+ # intended for the built-in provider, they should be changed if
+ # using a custom module. This section will be passed as a Python
+ # dictionary to the module's `parse_config` method.
+ #
+ config:
+ # The SAML attribute (after mapping via the attribute maps) to use
+ # to derive the Matrix ID from. 'uid' by default.
+ #
+ # Note: This used to be configured by the
+ # saml2_config.mxid_source_attribute option. If that is still
+ # defined, its value will be used instead.
+ #
+ #mxid_source_attribute: displayName
+
+ # The mapping system to use for mapping the saml attribute onto a
+ # matrix ID.
+ #
+ # Options include:
+ # * 'hexencode' (which maps unpermitted characters to '=xx')
+ # * 'dotreplace' (which replaces unpermitted characters with
+ # '.').
+ # The default is 'hexencode'.
+ #
+ # Note: This used to be configured by the
+ # saml2_config.mxid_mapping option. If that is still defined, its
+ # value will be used instead.
+ #
+ #mxid_mapping: dotreplace
+
+ # In previous versions of synapse, the mapping from SAML attribute to
+ # MXID was always calculated dynamically rather than stored in a
+ # table. For backwards- compatibility, we will look for user_ids
+ # matching such a pattern before creating a new account.
#
# This setting controls the SAML attribute which will be used for this
- # backwards-compatibility lookup. Typically it should be 'uid', but if the
- # attribute maps are changed, it may be necessary to change it.
+ # backwards-compatibility lookup. Typically it should be 'uid', but if
+ # the attribute maps are changed, it may be necessary to change it.
#
# The default is 'uid'.
#
@@ -241,23 +327,3 @@ class SAML2Config(Config):
""" % {
"config_dir_path": config_dir_path
}
-
-
-DOT_REPLACE_PATTERN = re.compile(
- ("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),))
-)
-
-
-def dot_replace_for_mxid(username: str) -> str:
- username = username.lower()
- username = DOT_REPLACE_PATTERN.sub(".", username)
-
- # regular mxids aren't allowed to start with an underscore either
- username = re.sub("^_", "", username)
- return username
-
-
-MXID_MAPPER_MAP = {
- "hexencode": map_username_to_mxid_localpart,
- "dotreplace": dot_replace_for_mxid,
-}
diff --git a/synapse/event_auth.py b/synapse/event_auth.py
index ec3243b27b..c940b84470 100644
--- a/synapse/event_auth.py
+++ b/synapse/event_auth.py
@@ -42,6 +42,8 @@ def check(room_version, event, auth_events, do_sig_check=True, do_size_check=Tru
Returns:
if the auth checks pass.
"""
+ assert isinstance(auth_events, dict)
+
if do_size_check:
_check_size_limits(event)
@@ -74,12 +76,6 @@ def check(room_version, event, auth_events, do_sig_check=True, do_size_check=Tru
if not event.signatures.get(event_id_domain):
raise AuthError(403, "Event not signed by sending server")
- if auth_events is None:
- # Oh, we don't know what the state of the room was, so we
- # are trusting that this is allowed (at least for now)
- logger.warning("Trusting event: %s", event.event_id)
- return
-
if event.type == EventTypes.Create:
sender_domain = get_domain_from_id(event.sender)
room_id_domain = get_domain_from_id(event.room_id)
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 709449c9e3..af652a7659 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -18,8 +18,6 @@ import copy
import itertools
import logging
-from six.moves import range
-
from prometheus_client import Counter
from twisted.internet import defer
@@ -39,7 +37,7 @@ from synapse.api.room_versions import (
)
from synapse.events import builder, room_version_to_event_format
from synapse.federation.federation_base import FederationBase, event_from_pdu_json
-from synapse.logging.context import make_deferred_yieldable, run_in_background
+from synapse.logging.context import make_deferred_yieldable
from synapse.logging.utils import log_function
from synapse.util import unwrapFirstError
from synapse.util.caches.expiringcache import ExpiringCache
@@ -310,19 +308,12 @@ class FederationClient(FederationBase):
return signed_pdu
@defer.inlineCallbacks
- @log_function
- def get_state_for_room(self, destination, room_id, event_id):
- """Requests all of the room state at a given event from a remote homeserver.
-
- Args:
- destination (str): The remote homeserver to query for the state.
- room_id (str): The id of the room we're interested in.
- event_id (str): The id of the event we want the state at.
+ def get_room_state_ids(self, destination: str, room_id: str, event_id: str):
+ """Calls the /state_ids endpoint to fetch the state at a particular point
+ in the room, and the auth events for the given event
Returns:
- Deferred[Tuple[List[EventBase], List[EventBase]]]:
- A list of events in the state, and a list of events in the auth chain
- for the given event.
+ Tuple[List[str], List[str]]: a tuple of (state event_ids, auth event_ids)
"""
result = yield self.transport_layer.get_room_state_ids(
destination, room_id, event_id=event_id
@@ -331,86 +322,12 @@ class FederationClient(FederationBase):
state_event_ids = result["pdu_ids"]
auth_event_ids = result.get("auth_chain_ids", [])
- fetched_events, failed_to_fetch = yield self.get_events_from_store_or_dest(
- destination, room_id, set(state_event_ids + auth_event_ids)
- )
-
- if failed_to_fetch:
- logger.warning(
- "Failed to fetch missing state/auth events for %s: %s",
- room_id,
- failed_to_fetch,
- )
-
- event_map = {ev.event_id: ev for ev in fetched_events}
+ if not isinstance(state_event_ids, list) or not isinstance(
+ auth_event_ids, list
+ ):
+ raise Exception("invalid response from /state_ids")
- pdus = [event_map[e_id] for e_id in state_event_ids if e_id in event_map]
- auth_chain = [event_map[e_id] for e_id in auth_event_ids if e_id in event_map]
-
- auth_chain.sort(key=lambda e: e.depth)
-
- return pdus, auth_chain
-
- @defer.inlineCallbacks
- def get_events_from_store_or_dest(self, destination, room_id, event_ids):
- """Fetch events from a remote destination, checking if we already have them.
-
- Args:
- destination (str)
- room_id (str)
- event_ids (list)
-
- Returns:
- Deferred: A deferred resolving to a 2-tuple where the first is a list of
- events and the second is a list of event ids that we failed to fetch.
- """
- seen_events = yield self.store.get_events(event_ids, allow_rejected=True)
- signed_events = list(seen_events.values())
-
- failed_to_fetch = set()
-
- missing_events = set(event_ids)
- for k in seen_events:
- missing_events.discard(k)
-
- if not missing_events:
- return signed_events, failed_to_fetch
-
- logger.debug(
- "Fetching unknown state/auth events %s for room %s",
- missing_events,
- event_ids,
- )
-
- room_version = yield self.store.get_room_version(room_id)
-
- batch_size = 20
- missing_events = list(missing_events)
- for i in range(0, len(missing_events), batch_size):
- batch = set(missing_events[i : i + batch_size])
-
- deferreds = [
- run_in_background(
- self.get_pdu,
- destinations=[destination],
- event_id=e_id,
- room_version=room_version,
- )
- for e_id in batch
- ]
-
- res = yield make_deferred_yieldable(
- defer.DeferredList(deferreds, consumeErrors=True)
- )
- for success, result in res:
- if success and result:
- signed_events.append(result)
- batch.discard(result.event_id)
-
- # We removed all events we successfully fetched from `batch`
- failed_to_fetch.update(batch)
-
- return signed_events, failed_to_fetch
+ return state_event_ids, auth_event_ids
@defer.inlineCallbacks
@log_function
@@ -609,13 +526,7 @@ class FederationClient(FederationBase):
@defer.inlineCallbacks
def send_request(destination):
- time_now = self._clock.time_msec()
- _, content = yield self.transport_layer.send_join(
- destination=destination,
- room_id=pdu.room_id,
- event_id=pdu.event_id,
- content=pdu.get_pdu_json(time_now),
- )
+ content = yield self._do_send_join(destination, pdu)
logger.debug("Got content: %s", content)
@@ -683,6 +594,44 @@ class FederationClient(FederationBase):
return self._try_destination_list("send_join", destinations, send_request)
@defer.inlineCallbacks
+ def _do_send_join(self, destination, pdu):
+ time_now = self._clock.time_msec()
+
+ try:
+ content = yield self.transport_layer.send_join_v2(
+ destination=destination,
+ room_id=pdu.room_id,
+ event_id=pdu.event_id,
+ content=pdu.get_pdu_json(time_now),
+ )
+
+ return content
+ except HttpResponseException as e:
+ if e.code in [400, 404]:
+ err = e.to_synapse_error()
+
+ # If we receive an error response that isn't a generic error, or an
+ # unrecognised endpoint error, we assume that the remote understands
+ # the v2 invite API and this is a legitimate error.
+ if err.errcode not in [Codes.UNKNOWN, Codes.UNRECOGNIZED]:
+ raise err
+ else:
+ raise e.to_synapse_error()
+
+ logger.debug("Couldn't send_join with the v2 API, falling back to the v1 API")
+
+ resp = yield self.transport_layer.send_join_v1(
+ destination=destination,
+ room_id=pdu.room_id,
+ event_id=pdu.event_id,
+ content=pdu.get_pdu_json(time_now),
+ )
+
+ # We expect the v1 API to respond with [200, content], so we only return the
+ # content.
+ return resp[1]
+
+ @defer.inlineCallbacks
def send_invite(self, destination, room_id, event_id, pdu):
room_version = yield self.store.get_room_version(room_id)
@@ -791,18 +740,50 @@ class FederationClient(FederationBase):
@defer.inlineCallbacks
def send_request(destination):
- time_now = self._clock.time_msec()
- _, content = yield self.transport_layer.send_leave(
+ content = yield self._do_send_leave(destination, pdu)
+
+ logger.debug("Got content: %s", content)
+ return None
+
+ return self._try_destination_list("send_leave", destinations, send_request)
+
+ @defer.inlineCallbacks
+ def _do_send_leave(self, destination, pdu):
+ time_now = self._clock.time_msec()
+
+ try:
+ content = yield self.transport_layer.send_leave_v2(
destination=destination,
room_id=pdu.room_id,
event_id=pdu.event_id,
content=pdu.get_pdu_json(time_now),
)
- logger.debug("Got content: %s", content)
- return None
+ return content
+ except HttpResponseException as e:
+ if e.code in [400, 404]:
+ err = e.to_synapse_error()
- return self._try_destination_list("send_leave", destinations, send_request)
+ # If we receive an error response that isn't a generic error, or an
+ # unrecognised endpoint error, we assume that the remote understands
+ # the v2 invite API and this is a legitimate error.
+ if err.errcode not in [Codes.UNKNOWN, Codes.UNRECOGNIZED]:
+ raise err
+ else:
+ raise e.to_synapse_error()
+
+ logger.debug("Couldn't send_leave with the v2 API, falling back to the v1 API")
+
+ resp = yield self.transport_layer.send_leave_v1(
+ destination=destination,
+ room_id=pdu.room_id,
+ event_id=pdu.event_id,
+ content=pdu.get_pdu_json(time_now),
+ )
+
+ # We expect the v1 API to respond with [200, content], so we only return the
+ # content.
+ return resp[1]
def get_public_rooms(
self,
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 84d4eca041..d7ce333822 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -384,15 +384,10 @@ class FederationServer(FederationBase):
res_pdus = await self.handler.on_send_join_request(origin, pdu)
time_now = self._clock.time_msec()
- return (
- 200,
- {
- "state": [p.get_pdu_json(time_now) for p in res_pdus["state"]],
- "auth_chain": [
- p.get_pdu_json(time_now) for p in res_pdus["auth_chain"]
- ],
- },
- )
+ return {
+ "state": [p.get_pdu_json(time_now) for p in res_pdus["state"]],
+ "auth_chain": [p.get_pdu_json(time_now) for p in res_pdus["auth_chain"]],
+ }
async def on_make_leave_request(self, origin, room_id, user_id):
origin_host, _ = parse_server_name(origin)
@@ -419,7 +414,7 @@ class FederationServer(FederationBase):
pdu = await self._check_sigs_and_hash(room_version, pdu)
await self.handler.on_send_leave_request(origin, pdu)
- return 200, {}
+ return {}
async def on_event_auth(self, origin, room_id, event_id):
with (await self._server_linearizer.queue((origin, room_id))):
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 46dba84cac..198257414b 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -243,7 +243,7 @@ class TransportLayerClient(object):
@defer.inlineCallbacks
@log_function
- def send_join(self, destination, room_id, event_id, content):
+ def send_join_v1(self, destination, room_id, event_id, content):
path = _create_v1_path("/send_join/%s/%s", room_id, event_id)
response = yield self.client.put_json(
@@ -254,7 +254,18 @@ class TransportLayerClient(object):
@defer.inlineCallbacks
@log_function
- def send_leave(self, destination, room_id, event_id, content):
+ def send_join_v2(self, destination, room_id, event_id, content):
+ path = _create_v2_path("/send_join/%s/%s", room_id, event_id)
+
+ response = yield self.client.put_json(
+ destination=destination, path=path, data=content
+ )
+
+ return response
+
+ @defer.inlineCallbacks
+ @log_function
+ def send_leave_v1(self, destination, room_id, event_id, content):
path = _create_v1_path("/send_leave/%s/%s", room_id, event_id)
response = yield self.client.put_json(
@@ -272,6 +283,24 @@ class TransportLayerClient(object):
@defer.inlineCallbacks
@log_function
+ def send_leave_v2(self, destination, room_id, event_id, content):
+ path = _create_v2_path("/send_leave/%s/%s", room_id, event_id)
+
+ response = yield self.client.put_json(
+ destination=destination,
+ path=path,
+ data=content,
+ # we want to do our best to send this through. The problem is
+ # that if it fails, we won't retry it later, so if the remote
+ # server was just having a momentary blip, the room will be out of
+ # sync.
+ ignore_backoff=True,
+ )
+
+ return response
+
+ @defer.inlineCallbacks
+ @log_function
def send_invite_v1(self, destination, room_id, event_id, content):
path = _create_v1_path("/invite/%s/%s", room_id, event_id)
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index fefc789c85..b4cbf23394 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -506,11 +506,21 @@ class FederationMakeLeaveServlet(BaseFederationServlet):
return 200, content
-class FederationSendLeaveServlet(BaseFederationServlet):
+class FederationV1SendLeaveServlet(BaseFederationServlet):
PATH = "/send_leave/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
async def on_PUT(self, origin, content, query, room_id, event_id):
content = await self.handler.on_send_leave_request(origin, content, room_id)
+ return 200, (200, content)
+
+
+class FederationV2SendLeaveServlet(BaseFederationServlet):
+ PATH = "/send_leave/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
+
+ PREFIX = FEDERATION_V2_PREFIX
+
+ async def on_PUT(self, origin, content, query, room_id, event_id):
+ content = await self.handler.on_send_leave_request(origin, content, room_id)
return 200, content
@@ -521,9 +531,21 @@ class FederationEventAuthServlet(BaseFederationServlet):
return await self.handler.on_event_auth(origin, context, event_id)
-class FederationSendJoinServlet(BaseFederationServlet):
+class FederationV1SendJoinServlet(BaseFederationServlet):
+ PATH = "/send_join/(?P<context>[^/]*)/(?P<event_id>[^/]*)"
+
+ async def on_PUT(self, origin, content, query, context, event_id):
+ # TODO(paul): assert that context/event_id parsed from path actually
+ # match those given in content
+ content = await self.handler.on_send_join_request(origin, content, context)
+ return 200, (200, content)
+
+
+class FederationV2SendJoinServlet(BaseFederationServlet):
PATH = "/send_join/(?P<context>[^/]*)/(?P<event_id>[^/]*)"
+ PREFIX = FEDERATION_V2_PREFIX
+
async def on_PUT(self, origin, content, query, context, event_id):
# TODO(paul): assert that context/event_id parsed from path actually
# match those given in content
@@ -1367,8 +1389,10 @@ FEDERATION_SERVLET_CLASSES = (
FederationMakeJoinServlet,
FederationMakeLeaveServlet,
FederationEventServlet,
- FederationSendJoinServlet,
- FederationSendLeaveServlet,
+ FederationV1SendJoinServlet,
+ FederationV2SendJoinServlet,
+ FederationV1SendLeaveServlet,
+ FederationV2SendLeaveServlet,
FederationV1InviteServlet,
FederationV2InviteServlet,
FederationQueryAuthServlet,
diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py
index 2d7e6df6e4..20ec1ca01b 100644
--- a/synapse/handlers/account_data.py
+++ b/synapse/handlers/account_data.py
@@ -13,8 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.internet import defer
-
class AccountDataEventSource(object):
def __init__(self, hs):
@@ -23,15 +21,14 @@ class AccountDataEventSource(object):
def get_current_key(self, direction="f"):
return self.store.get_max_account_data_stream_id()
- @defer.inlineCallbacks
- def get_new_events(self, user, from_key, **kwargs):
+ async def get_new_events(self, user, from_key, **kwargs):
user_id = user.to_string()
last_stream_id = from_key
- current_stream_id = yield self.store.get_max_account_data_stream_id()
+ current_stream_id = self.store.get_max_account_data_stream_id()
results = []
- tags = yield self.store.get_updated_tags(user_id, last_stream_id)
+ tags = await self.store.get_updated_tags(user_id, last_stream_id)
for room_id, room_tags in tags.items():
results.append(
@@ -41,7 +38,7 @@ class AccountDataEventSource(object):
(
account_data,
room_account_data,
- ) = yield self.store.get_updated_account_data_for_user(user_id, last_stream_id)
+ ) = await self.store.get_updated_account_data_for_user(user_id, last_stream_id)
for account_data_type, content in account_data.items():
results.append({"type": account_data_type, "content": content})
@@ -54,6 +51,5 @@ class AccountDataEventSource(object):
return results, current_stream_id
- @defer.inlineCallbacks
- def get_pagination_rows(self, user, config, key):
+ async def get_pagination_rows(self, user, config, key):
return [], config.to_id
diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py
index d04e0fe576..829f52eca1 100644
--- a/synapse/handlers/account_validity.py
+++ b/synapse/handlers/account_validity.py
@@ -18,8 +18,7 @@ import email.utils
import logging
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
-
-from twisted.internet import defer
+from typing import List
from synapse.api.errors import StoreError
from synapse.logging.context import make_deferred_yieldable
@@ -78,42 +77,39 @@ class AccountValidityHandler(object):
# run as a background process to make sure that the database transactions
# have a logcontext to report to
return run_as_background_process(
- "send_renewals", self.send_renewal_emails
+ "send_renewals", self._send_renewal_emails
)
self.clock.looping_call(send_emails, 30 * 60 * 1000)
- @defer.inlineCallbacks
- def send_renewal_emails(self):
+ async def _send_renewal_emails(self):
"""Gets the list of users whose account is expiring in the amount of time
configured in the ``renew_at`` parameter from the ``account_validity``
configuration, and sends renewal emails to all of these users as long as they
have an email 3PID attached to their account.
"""
- expiring_users = yield self.store.get_users_expiring_soon()
+ expiring_users = await self.store.get_users_expiring_soon()
if expiring_users:
for user in expiring_users:
- yield self._send_renewal_email(
+ await self._send_renewal_email(
user_id=user["user_id"], expiration_ts=user["expiration_ts_ms"]
)
- @defer.inlineCallbacks
- def send_renewal_email_to_user(self, user_id):
- expiration_ts = yield self.store.get_expiration_ts_for_user(user_id)
- yield self._send_renewal_email(user_id, expiration_ts)
+ async def send_renewal_email_to_user(self, user_id: str):
+ expiration_ts = await self.store.get_expiration_ts_for_user(user_id)
+ await self._send_renewal_email(user_id, expiration_ts)
- @defer.inlineCallbacks
- def _send_renewal_email(self, user_id, expiration_ts):
+ async def _send_renewal_email(self, user_id: str, expiration_ts: int):
"""Sends out a renewal email to every email address attached to the given user
with a unique link allowing them to renew their account.
Args:
- user_id (str): ID of the user to send email(s) to.
- expiration_ts (int): Timestamp in milliseconds for the expiration date of
+ user_id: ID of the user to send email(s) to.
+ expiration_ts: Timestamp in milliseconds for the expiration date of
this user's account (used in the email templates).
"""
- addresses = yield self._get_email_addresses_for_user(user_id)
+ addresses = await self._get_email_addresses_for_user(user_id)
# Stop right here if the user doesn't have at least one email address.
# In this case, they will have to ask their server admin to renew their
@@ -125,7 +121,7 @@ class AccountValidityHandler(object):
return
try:
- user_display_name = yield self.store.get_profile_displayname(
+ user_display_name = await self.store.get_profile_displayname(
UserID.from_string(user_id).localpart
)
if user_display_name is None:
@@ -133,7 +129,7 @@ class AccountValidityHandler(object):
except StoreError:
user_display_name = user_id
- renewal_token = yield self._get_renewal_token(user_id)
+ renewal_token = await self._get_renewal_token(user_id)
url = "%s_matrix/client/unstable/account_validity/renew?token=%s" % (
self.hs.config.public_baseurl,
renewal_token,
@@ -165,7 +161,7 @@ class AccountValidityHandler(object):
logger.info("Sending renewal email to %s", address)
- yield make_deferred_yieldable(
+ await make_deferred_yieldable(
self.sendmail(
self.hs.config.email_smtp_host,
self._raw_from,
@@ -180,19 +176,18 @@ class AccountValidityHandler(object):
)
)
- yield self.store.set_renewal_mail_status(user_id=user_id, email_sent=True)
+ await self.store.set_renewal_mail_status(user_id=user_id, email_sent=True)
- @defer.inlineCallbacks
- def _get_email_addresses_for_user(self, user_id):
+ async def _get_email_addresses_for_user(self, user_id: str) -> List[str]:
"""Retrieve the list of email addresses attached to a user's account.
Args:
- user_id (str): ID of the user to lookup email addresses for.
+ user_id: ID of the user to lookup email addresses for.
Returns:
- defer.Deferred[list[str]]: Email addresses for this account.
+ Email addresses for this account.
"""
- threepids = yield self.store.user_get_threepids(user_id)
+ threepids = await self.store.user_get_threepids(user_id)
addresses = []
for threepid in threepids:
@@ -201,16 +196,15 @@ class AccountValidityHandler(object):
return addresses
- @defer.inlineCallbacks
- def _get_renewal_token(self, user_id):
+ async def _get_renewal_token(self, user_id: str) -> str:
"""Generates a 32-byte long random string that will be inserted into the
user's renewal email's unique link, then saves it into the database.
Args:
- user_id (str): ID of the user to generate a string for.
+ user_id: ID of the user to generate a string for.
Returns:
- defer.Deferred[str]: The generated string.
+ The generated string.
Raises:
StoreError(500): Couldn't generate a unique string after 5 attempts.
@@ -219,52 +213,52 @@ class AccountValidityHandler(object):
while attempts < 5:
try:
renewal_token = stringutils.random_string(32)
- yield self.store.set_renewal_token_for_user(user_id, renewal_token)
+ await self.store.set_renewal_token_for_user(user_id, renewal_token)
return renewal_token
except StoreError:
attempts += 1
raise StoreError(500, "Couldn't generate a unique string as refresh string.")
- @defer.inlineCallbacks
- def renew_account(self, renewal_token):
+ async def renew_account(self, renewal_token: str) -> bool:
"""Renews the account attached to a given renewal token by pushing back the
expiration date by the current validity period in the server's configuration.
Args:
- renewal_token (str): Token sent with the renewal request.
+ renewal_token: Token sent with the renewal request.
Returns:
- bool: Whether the provided token is valid.
+ Whether the provided token is valid.
"""
try:
- user_id = yield self.store.get_user_from_renewal_token(renewal_token)
+ user_id = await self.store.get_user_from_renewal_token(renewal_token)
except StoreError:
- defer.returnValue(False)
+ return False
logger.debug("Renewing an account for user %s", user_id)
- yield self.renew_account_for_user(user_id)
+ await self.renew_account_for_user(user_id)
- defer.returnValue(True)
+ return True
- @defer.inlineCallbacks
- def renew_account_for_user(self, user_id, expiration_ts=None, email_sent=False):
+ async def renew_account_for_user(
+ self, user_id: str, expiration_ts: int = None, email_sent: bool = False
+ ) -> int:
"""Renews the account attached to a given user by pushing back the
expiration date by the current validity period in the server's
configuration.
Args:
- renewal_token (str): Token sent with the renewal request.
- expiration_ts (int): New expiration date. Defaults to now + validity period.
- email_sent (bool): Whether an email has been sent for this validity period.
+ renewal_token: Token sent with the renewal request.
+ expiration_ts: New expiration date. Defaults to now + validity period.
+ email_sen: Whether an email has been sent for this validity period.
Defaults to False.
Returns:
- defer.Deferred[int]: New expiration date for this account, as a timestamp
- in milliseconds since epoch.
+ New expiration date for this account, as a timestamp in
+ milliseconds since epoch.
"""
if expiration_ts is None:
expiration_ts = self.clock.time_msec() + self._account_validity.period
- yield self.store.set_account_validity_for_user(
+ await self.store.set_account_validity_for_user(
user_id=user_id, expiration_ts=expiration_ts, email_sent=email_sent
)
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 28c12753c1..57a10daefd 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -264,7 +264,6 @@ class E2eKeysHandler(object):
return ret
- @defer.inlineCallbacks
def get_cross_signing_keys_from_cache(self, query, from_user_id):
"""Get cross-signing keys for users from the database
@@ -284,35 +283,14 @@ class E2eKeysHandler(object):
self_signing_keys = {}
user_signing_keys = {}
- for user_id in query:
- # XXX: consider changing the store functions to allow querying
- # multiple users simultaneously.
- key = yield self.store.get_e2e_cross_signing_key(
- user_id, "master", from_user_id
- )
- if key:
- master_keys[user_id] = key
-
- key = yield self.store.get_e2e_cross_signing_key(
- user_id, "self_signing", from_user_id
- )
- if key:
- self_signing_keys[user_id] = key
-
- # users can see other users' master and self-signing keys, but can
- # only see their own user-signing keys
- if from_user_id == user_id:
- key = yield self.store.get_e2e_cross_signing_key(
- user_id, "user_signing", from_user_id
- )
- if key:
- user_signing_keys[user_id] = key
-
- return {
- "master_keys": master_keys,
- "self_signing_keys": self_signing_keys,
- "user_signing_keys": user_signing_keys,
- }
+ # Currently a stub, implementation coming in https://github.com/matrix-org/synapse/pull/6486
+ return defer.succeed(
+ {
+ "master_keys": master_keys,
+ "self_signing_keys": self_signing_keys,
+ "user_signing_keys": user_signing_keys,
+ }
+ )
@trace
@defer.inlineCallbacks
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index bc26921768..62985bab9f 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -19,7 +19,7 @@
import itertools
import logging
-from typing import Dict, Iterable, Optional, Sequence, Tuple
+from typing import Dict, Iterable, List, Optional, Sequence, Tuple
import six
from six import iteritems, itervalues
@@ -63,8 +63,9 @@ from synapse.replication.http.federation import (
)
from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet
from synapse.state import StateResolutionStore, resolve_events_with_store
+from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
from synapse.types import UserID, get_domain_from_id
-from synapse.util import unwrapFirstError
+from synapse.util import batch_iter, unwrapFirstError
from synapse.util.async_helpers import Linearizer
from synapse.util.distributor import user_joined_room
from synapse.util.retryutils import NotRetryingDestination
@@ -164,8 +165,7 @@ class FederationHandler(BaseHandler):
self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
- @defer.inlineCallbacks
- def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False):
+ async def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False) -> None:
""" Process a PDU received via a federation /send/ transaction, or
via backfill of missing prev_events
@@ -175,17 +175,15 @@ class FederationHandler(BaseHandler):
pdu (FrozenEvent): received PDU
sent_to_us_directly (bool): True if this event was pushed to us; False if
we pulled it as the result of a missing prev_event.
-
- Returns (Deferred): completes with None
"""
room_id = pdu.room_id
event_id = pdu.event_id
- logger.info("[%s %s] handling received PDU: %s", room_id, event_id, pdu)
+ logger.info("handling received PDU: %s", pdu)
# We reprocess pdus when we have seen them only as outliers
- existing = yield self.store.get_event(
+ existing = await self.store.get_event(
event_id, allow_none=True, allow_rejected=True
)
@@ -229,7 +227,7 @@ class FederationHandler(BaseHandler):
#
# Note that if we were never in the room then we would have already
# dropped the event, since we wouldn't know the room version.
- is_in_room = yield self.auth.check_host_in_room(room_id, self.server_name)
+ is_in_room = await self.auth.check_host_in_room(room_id, self.server_name)
if not is_in_room:
logger.info(
"[%s %s] Ignoring PDU from %s as we're not in the room",
@@ -245,12 +243,12 @@ class FederationHandler(BaseHandler):
# Get missing pdus if necessary.
if not pdu.internal_metadata.is_outlier():
# We only backfill backwards to the min depth.
- min_depth = yield self.get_min_depth_for_context(pdu.room_id)
+ min_depth = await self.get_min_depth_for_context(pdu.room_id)
logger.debug("[%s %s] min_depth: %d", room_id, event_id, min_depth)
prevs = set(pdu.prev_event_ids())
- seen = yield self.store.have_seen_events(prevs)
+ seen = await self.store.have_seen_events(prevs)
if min_depth and pdu.depth < min_depth:
# This is so that we don't notify the user about this
@@ -270,7 +268,7 @@ class FederationHandler(BaseHandler):
len(missing_prevs),
shortstr(missing_prevs),
)
- with (yield self._room_pdu_linearizer.queue(pdu.room_id)):
+ with (await self._room_pdu_linearizer.queue(pdu.room_id)):
logger.info(
"[%s %s] Acquired room lock to fetch %d missing prev_events",
room_id,
@@ -278,13 +276,19 @@ class FederationHandler(BaseHandler):
len(missing_prevs),
)
- yield self._get_missing_events_for_pdu(
- origin, pdu, prevs, min_depth
- )
+ try:
+ await self._get_missing_events_for_pdu(
+ origin, pdu, prevs, min_depth
+ )
+ except Exception as e:
+ raise Exception(
+ "Error fetching missing prev_events for %s: %s"
+ % (event_id, e)
+ )
# Update the set of things we've seen after trying to
# fetch the missing stuff
- seen = yield self.store.have_seen_events(prevs)
+ seen = await self.store.have_seen_events(prevs)
if not prevs - seen:
logger.info(
@@ -292,14 +296,6 @@ class FederationHandler(BaseHandler):
room_id,
event_id,
)
- elif missing_prevs:
- logger.info(
- "[%s %s] Not recursively fetching %d missing prev_events: %s",
- room_id,
- event_id,
- len(missing_prevs),
- shortstr(missing_prevs),
- )
if prevs - seen:
# We've still not been able to get all of the prev_events for this event.
@@ -344,13 +340,19 @@ class FederationHandler(BaseHandler):
affected=pdu.event_id,
)
+ logger.info(
+ "Event %s is missing prev_events: calculating state for a "
+ "backwards extremity",
+ event_id,
+ )
+
# Calculate the state after each of the previous events, and
# resolve them to find the correct state at the current event.
auth_chains = set()
event_map = {event_id: pdu}
try:
# Get the state of the events we know about
- ours = yield self.state_store.get_state_groups_ids(room_id, seen)
+ ours = await self.state_store.get_state_groups_ids(room_id, seen)
# state_maps is a list of mappings from (type, state_key) to event_id
state_maps = list(
@@ -364,13 +366,10 @@ class FederationHandler(BaseHandler):
# know about
for p in prevs - seen:
logger.info(
- "[%s %s] Requesting state at missing prev_event %s",
- room_id,
- event_id,
- p,
+ "Requesting state at missing prev_event %s", event_id,
)
- room_version = yield self.store.get_room_version(room_id)
+ room_version = await self.store.get_room_version(room_id)
with nested_logging_context(p):
# note that if any of the missing prevs share missing state or
@@ -379,24 +378,10 @@ class FederationHandler(BaseHandler):
(
remote_state,
got_auth_chain,
- ) = yield self.federation_client.get_state_for_room(
- origin, room_id, p
- )
-
- # we want the state *after* p; get_state_for_room returns the
- # state *before* p.
- remote_event = yield self.federation_client.get_pdu(
- [origin], p, room_version, outlier=True
+ ) = await self._get_state_for_room(
+ origin, room_id, p, include_event_in_state=True
)
- if remote_event is None:
- raise Exception(
- "Unable to get missing prev_event %s" % (p,)
- )
-
- if remote_event.is_state():
- remote_state.append(remote_event)
-
# XXX hrm I'm not convinced that duplicate events will compare
# for equality, so I'm not sure this does what the author
# hoped.
@@ -410,7 +395,7 @@ class FederationHandler(BaseHandler):
for x in remote_state:
event_map[x.event_id] = x
- state_map = yield resolve_events_with_store(
+ state_map = await resolve_events_with_store(
room_version,
state_maps,
event_map,
@@ -422,10 +407,10 @@ class FederationHandler(BaseHandler):
# First though we need to fetch all the events that are in
# state_map, so we can build up the state below.
- evs = yield self.store.get_events(
+ evs = await self.store.get_events(
list(state_map.values()),
get_prev_content=False,
- check_redacted=False,
+ redact_behaviour=EventRedactBehaviour.AS_IS,
)
event_map.update(evs)
@@ -446,12 +431,11 @@ class FederationHandler(BaseHandler):
affected=event_id,
)
- yield self._process_received_pdu(
+ await self._process_received_pdu(
origin, pdu, state=state, auth_chain=auth_chain
)
- @defer.inlineCallbacks
- def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth):
+ async def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth):
"""
Args:
origin (str): Origin of the pdu. Will be called to get the missing events
@@ -463,12 +447,12 @@ class FederationHandler(BaseHandler):
room_id = pdu.room_id
event_id = pdu.event_id
- seen = yield self.store.have_seen_events(prevs)
+ seen = await self.store.have_seen_events(prevs)
if not prevs - seen:
return
- latest = yield self.store.get_latest_event_ids_in_room(room_id)
+ latest = await self.store.get_latest_event_ids_in_room(room_id)
# We add the prev events that we have seen to the latest
# list to ensure the remote server doesn't give them to us
@@ -532,7 +516,7 @@ class FederationHandler(BaseHandler):
# All that said: Let's try increasing the timout to 60s and see what happens.
try:
- missing_events = yield self.federation_client.get_missing_events(
+ missing_events = await self.federation_client.get_missing_events(
origin,
room_id,
earliest_events_ids=list(latest),
@@ -571,7 +555,7 @@ class FederationHandler(BaseHandler):
)
with nested_logging_context(ev.event_id):
try:
- yield self.on_receive_pdu(origin, ev, sent_to_us_directly=False)
+ await self.on_receive_pdu(origin, ev, sent_to_us_directly=False)
except FederationError as e:
if e.code == 403:
logger.warning(
@@ -583,8 +567,116 @@ class FederationHandler(BaseHandler):
else:
raise
- @defer.inlineCallbacks
- def _process_received_pdu(self, origin, event, state, auth_chain):
+ async def _get_state_for_room(
+ self,
+ destination: str,
+ room_id: str,
+ event_id: str,
+ include_event_in_state: bool = False,
+ ) -> Tuple[List[EventBase], List[EventBase]]:
+ """Requests all of the room state at a given event from a remote homeserver.
+
+ Args:
+ destination: The remote homeserver to query for the state.
+ room_id: The id of the room we're interested in.
+ event_id: The id of the event we want the state at.
+ include_event_in_state: if true, the event itself will be included in the
+ returned state event list.
+
+ Returns:
+ A list of events in the state, possibly including the event itself, and
+ a list of events in the auth chain for the given event.
+ """
+ (
+ state_event_ids,
+ auth_event_ids,
+ ) = await self.federation_client.get_room_state_ids(
+ destination, room_id, event_id=event_id
+ )
+
+ desired_events = set(state_event_ids + auth_event_ids)
+
+ if include_event_in_state:
+ desired_events.add(event_id)
+
+ event_map = await self._get_events_from_store_or_dest(
+ destination, room_id, desired_events
+ )
+
+ failed_to_fetch = desired_events - event_map.keys()
+ if failed_to_fetch:
+ logger.warning(
+ "Failed to fetch missing state/auth events for %s %s",
+ event_id,
+ failed_to_fetch,
+ )
+
+ remote_state = [
+ event_map[e_id] for e_id in state_event_ids if e_id in event_map
+ ]
+
+ if include_event_in_state:
+ remote_event = event_map.get(event_id)
+ if not remote_event:
+ raise Exception("Unable to get missing prev_event %s" % (event_id,))
+ if remote_event.is_state():
+ remote_state.append(remote_event)
+
+ auth_chain = [event_map[e_id] for e_id in auth_event_ids if e_id in event_map]
+ auth_chain.sort(key=lambda e: e.depth)
+
+ return remote_state, auth_chain
+
+ async def _get_events_from_store_or_dest(
+ self, destination: str, room_id: str, event_ids: Iterable[str]
+ ) -> Dict[str, EventBase]:
+ """Fetch events from a remote destination, checking if we already have them.
+
+ Args:
+ destination
+ room_id
+ event_ids
+
+ Returns:
+ map from event_id to event
+ """
+ fetched_events = await self.store.get_events(event_ids, allow_rejected=True)
+
+ missing_events = set(event_ids) - fetched_events.keys()
+
+ if not missing_events:
+ return fetched_events
+
+ logger.debug(
+ "Fetching unknown state/auth events %s for room %s",
+ missing_events,
+ event_ids,
+ )
+
+ room_version = await self.store.get_room_version(room_id)
+
+ # XXX 20 requests at once? really?
+ for batch in batch_iter(missing_events, 20):
+ deferreds = [
+ run_in_background(
+ self.federation_client.get_pdu,
+ destinations=[destination],
+ event_id=e_id,
+ room_version=room_version,
+ )
+ for e_id in batch
+ ]
+
+ res = await make_deferred_yieldable(
+ defer.DeferredList(deferreds, consumeErrors=True)
+ )
+ for success, result in res:
+ if success and result:
+ fetched_events[result.event_id] = result
+
+ return fetched_events
+
+ async def _process_received_pdu(self, origin, event, state, auth_chain):
""" Called when we have a new pdu. We need to do auth checks and put it
through the StateHandler.
"""
@@ -599,7 +691,7 @@ class FederationHandler(BaseHandler):
if auth_chain:
event_ids |= {e.event_id for e in auth_chain}
- seen_ids = yield self.store.have_seen_events(event_ids)
+ seen_ids = await self.store.have_seen_events(event_ids)
if state and auth_chain is not None:
# If we have any state or auth_chain given to us by the replication
@@ -626,18 +718,18 @@ class FederationHandler(BaseHandler):
event_id,
[e.event.event_id for e in event_infos],
)
- yield self._handle_new_events(origin, event_infos)
+ await self._handle_new_events(origin, event_infos)
try:
- context = yield self._handle_new_event(origin, event, state=state)
+ context = await self._handle_new_event(origin, event, state=state)
except AuthError as e:
raise FederationError("ERROR", e.code, e.msg, affected=event.event_id)
- room = yield self.store.get_room(room_id)
+ room = await self.store.get_room(room_id)
if not room:
try:
- yield self.store.store_room(
+ await self.store.store_room(
room_id=room_id, room_creator_user_id="", is_public=False
)
except StoreError:
@@ -650,11 +742,11 @@ class FederationHandler(BaseHandler):
# changing their profile info.
newly_joined = True
- prev_state_ids = yield context.get_prev_state_ids(self.store)
+ prev_state_ids = await context.get_prev_state_ids(self.store)
prev_state_id = prev_state_ids.get((event.type, event.state_key))
if prev_state_id:
- prev_state = yield self.store.get_event(
+ prev_state = await self.store.get_event(
prev_state_id, allow_none=True
)
if prev_state and prev_state.membership == Membership.JOIN:
@@ -662,11 +754,10 @@ class FederationHandler(BaseHandler):
if newly_joined:
user = UserID.from_string(event.state_key)
- yield self.user_joined_room(user, room_id)
+ await self.user_joined_room(user, room_id)
@log_function
- @defer.inlineCallbacks
- def backfill(self, dest, room_id, limit, extremities):
+ async def backfill(self, dest, room_id, limit, extremities):
""" Trigger a backfill request to `dest` for the given `room_id`
This will attempt to get more events from the remote. If the other side
@@ -683,9 +774,9 @@ class FederationHandler(BaseHandler):
if dest == self.server_name:
raise SynapseError(400, "Can't backfill from self.")
- room_version = yield self.store.get_room_version(room_id)
+ room_version = await self.store.get_room_version(room_id)
- events = yield self.federation_client.backfill(
+ events = await self.federation_client.backfill(
dest, room_id, limit=limit, extremities=extremities
)
@@ -700,7 +791,7 @@ class FederationHandler(BaseHandler):
# self._sanity_check_event(ev)
# Don't bother processing events we already have.
- seen_events = yield self.store.have_events_in_timeline(
+ seen_events = await self.store.have_events_in_timeline(
set(e.event_id for e in events)
)
@@ -723,7 +814,7 @@ class FederationHandler(BaseHandler):
state_events = {}
events_to_state = {}
for e_id in edges:
- state, auth = yield self.federation_client.get_state_for_room(
+ state, auth = await self._get_state_for_room(
destination=dest, room_id=room_id, event_id=e_id
)
auth_events.update({a.event_id: a for a in auth})
@@ -748,7 +839,7 @@ class FederationHandler(BaseHandler):
# We repeatedly do this until we stop finding new auth events.
while missing_auth - failed_to_fetch:
logger.info("Missing auth for backfill: %r", missing_auth)
- ret_events = yield self.store.get_events(missing_auth - failed_to_fetch)
+ ret_events = await self.store.get_events(missing_auth - failed_to_fetch)
auth_events.update(ret_events)
required_auth.update(
@@ -762,7 +853,7 @@ class FederationHandler(BaseHandler):
missing_auth - failed_to_fetch,
)
- results = yield make_deferred_yieldable(
+ results = await make_deferred_yieldable(
defer.gatherResults(
[
run_in_background(
@@ -789,7 +880,7 @@ class FederationHandler(BaseHandler):
failed_to_fetch = missing_auth - set(auth_events)
- seen_events = yield self.store.have_seen_events(
+ seen_events = await self.store.have_seen_events(
set(auth_events.keys()) | set(state_events.keys())
)
@@ -851,7 +942,7 @@ class FederationHandler(BaseHandler):
)
)
- yield self._handle_new_events(dest, ev_infos, backfilled=True)
+ await self._handle_new_events(dest, ev_infos, backfilled=True)
# Step 2: Persist the rest of the events in the chunk one by one
events.sort(key=lambda e: e.depth)
@@ -867,16 +958,15 @@ class FederationHandler(BaseHandler):
# We store these one at a time since each event depends on the
# previous to work out the state.
# TODO: We can probably do something more clever here.
- yield self._handle_new_event(dest, event, backfilled=True)
+ await self._handle_new_event(dest, event, backfilled=True)
return events
- @defer.inlineCallbacks
- def maybe_backfill(self, room_id, current_depth):
+ async def maybe_backfill(self, room_id, current_depth):
"""Checks the database to see if we should backfill before paginating,
and if so do.
"""
- extremities = yield self.store.get_oldest_events_with_depth_in_room(room_id)
+ extremities = await self.store.get_oldest_events_with_depth_in_room(room_id)
if not extremities:
logger.debug("Not backfilling as no extremeties found.")
@@ -908,15 +998,17 @@ class FederationHandler(BaseHandler):
# state *before* the event, ignoring the special casing certain event
# types have.
- forward_events = yield self.store.get_successor_events(list(extremities))
+ forward_events = await self.store.get_successor_events(list(extremities))
- extremities_events = yield self.store.get_events(
- forward_events, check_redacted=False, get_prev_content=False
+ extremities_events = await self.store.get_events(
+ forward_events,
+ redact_behaviour=EventRedactBehaviour.AS_IS,
+ get_prev_content=False,
)
# We set `check_history_visibility_only` as we might otherwise get false
# positives from users having been erased.
- filtered_extremities = yield filter_events_for_server(
+ filtered_extremities = await filter_events_for_server(
self.storage,
self.server_name,
list(extremities_events.values()),
@@ -946,7 +1038,7 @@ class FederationHandler(BaseHandler):
# First we try hosts that are already in the room
# TODO: HEURISTIC ALERT.
- curr_state = yield self.state_handler.get_current_state(room_id)
+ curr_state = await self.state_handler.get_current_state(room_id)
def get_domains_from_state(state):
"""Get joined domains from state
@@ -985,12 +1077,11 @@ class FederationHandler(BaseHandler):
domain for domain, depth in curr_domains if domain != self.server_name
]
- @defer.inlineCallbacks
- def try_backfill(domains):
+ async def try_backfill(domains):
# TODO: Should we try multiple of these at a time?
for dom in domains:
try:
- yield self.backfill(
+ await self.backfill(
dom, room_id, limit=100, extremities=extremities
)
# If this succeeded then we probably already have the
@@ -1021,7 +1112,7 @@ class FederationHandler(BaseHandler):
return False
- success = yield try_backfill(likely_domains)
+ success = await try_backfill(likely_domains)
if success:
return True
@@ -1035,7 +1126,7 @@ class FederationHandler(BaseHandler):
logger.debug("calling resolve_state_groups in _maybe_backfill")
resolve = preserve_fn(self.state_handler.resolve_state_groups_for_events)
- states = yield make_deferred_yieldable(
+ states = await make_deferred_yieldable(
defer.gatherResults(
[resolve(room_id, [e]) for e in event_ids], consumeErrors=True
)
@@ -1045,7 +1136,7 @@ class FederationHandler(BaseHandler):
# event_ids.
states = dict(zip(event_ids, [s.state for s in states]))
- state_map = yield self.store.get_events(
+ state_map = await self.store.get_events(
[e_id for ids in itervalues(states) for e_id in itervalues(ids)],
get_prev_content=False,
)
@@ -1061,7 +1152,7 @@ class FederationHandler(BaseHandler):
for e_id, _ in sorted_extremeties_tuple:
likely_domains = get_domains_from_state(states[e_id])
- success = yield try_backfill(
+ success = await try_backfill(
[dom for dom, _ in likely_domains if dom not in tried_domains]
)
if success:
@@ -1210,7 +1301,7 @@ class FederationHandler(BaseHandler):
# Check whether this room is the result of an upgrade of a room we already know
# about. If so, migrate over user information
predecessor = yield self.store.get_room_predecessor(room_id)
- if not predecessor:
+ if not predecessor or not isinstance(predecessor.get("room_id"), str):
return
old_room_id = predecessor["room_id"]
logger.debug(
@@ -1238,8 +1329,7 @@ class FederationHandler(BaseHandler):
return True
- @defer.inlineCallbacks
- def _handle_queued_pdus(self, room_queue):
+ async def _handle_queued_pdus(self, room_queue):
"""Process PDUs which got queued up while we were busy send_joining.
Args:
@@ -1255,7 +1345,7 @@ class FederationHandler(BaseHandler):
p.room_id,
)
with nested_logging_context(p.event_id):
- yield self.on_receive_pdu(origin, p, sent_to_us_directly=True)
+ await self.on_receive_pdu(origin, p, sent_to_us_directly=True)
except Exception as e:
logger.warning(
"Error handling queued PDU %s from %s: %s", p.event_id, origin, e
@@ -1453,7 +1543,7 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks
def do_remotely_reject_invite(self, target_hosts, room_id, user_id, content):
origin, event, event_format_version = yield self._make_and_verify_event(
- target_hosts, room_id, user_id, "leave", content=content,
+ target_hosts, room_id, user_id, "leave", content=content
)
# Mark as outlier as we don't have any state for this event; we're not
# even in the room.
@@ -2814,7 +2904,7 @@ class FederationHandler(BaseHandler):
room_id=room_id, user_id=user.to_string(), change="joined"
)
else:
- return user_joined_room(self.distributor, user, room_id)
+ return defer.succeed(user_joined_room(self.distributor, user, room_id))
@defer.inlineCallbacks
def get_room_complexity(self, remote_room_hosts, room_id):
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 54fa216d83..bf9add7fe2 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -46,6 +46,7 @@ from synapse.events.validator import EventValidator
from synapse.logging.context import run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.http.send_event import ReplicationSendEventRestServlet
+from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
from synapse.storage.state import StateFilter
from synapse.types import RoomAlias, UserID, create_requester
from synapse.util.async_helpers import Linearizer
@@ -875,7 +876,7 @@ class EventCreationHandler(object):
if event.type == EventTypes.Redaction:
original_event = yield self.store.get_event(
event.redacts,
- check_redacted=False,
+ redact_behaviour=EventRedactBehaviour.AS_IS,
get_prev_content=False,
allow_rejected=False,
allow_none=True,
@@ -952,7 +953,7 @@ class EventCreationHandler(object):
if event.type == EventTypes.Redaction:
original_event = yield self.store.get_event(
event.redacts,
- check_redacted=False,
+ redact_behaviour=EventRedactBehaviour.AS_IS,
get_prev_content=False,
allow_rejected=False,
allow_none=True,
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 8514ddc600..00a6afc963 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -280,8 +280,7 @@ class PaginationHandler(object):
await self.storage.purge_events.purge_room(room_id)
- @defer.inlineCallbacks
- def get_messages(
+ async def get_messages(
self,
requester,
room_id=None,
@@ -307,7 +306,7 @@ class PaginationHandler(object):
room_token = pagin_config.from_token.room_key
else:
pagin_config.from_token = (
- yield self.hs.get_event_sources().get_current_token_for_pagination()
+ await self.hs.get_event_sources().get_current_token_for_pagination()
)
room_token = pagin_config.from_token.room_key
@@ -319,11 +318,11 @@ class PaginationHandler(object):
source_config = pagin_config.get_source_config("room")
- with (yield self.pagination_lock.read(room_id)):
+ with (await self.pagination_lock.read(room_id)):
(
membership,
member_event_id,
- ) = yield self.auth.check_in_room_or_world_readable(room_id, user_id)
+ ) = await self.auth.check_in_room_or_world_readable(room_id, user_id)
if source_config.direction == "b":
# if we're going backwards, we might need to backfill. This
@@ -331,7 +330,7 @@ class PaginationHandler(object):
if room_token.topological:
max_topo = room_token.topological
else:
- max_topo = yield self.store.get_max_topological_token(
+ max_topo = await self.store.get_max_topological_token(
room_id, room_token.stream
)
@@ -339,18 +338,18 @@ class PaginationHandler(object):
# If they have left the room then clamp the token to be before
# they left the room, to save the effort of loading from the
# database.
- leave_token = yield self.store.get_topological_token_for_event(
+ leave_token = await self.store.get_topological_token_for_event(
member_event_id
)
leave_token = RoomStreamToken.parse(leave_token)
if leave_token.topological < max_topo:
source_config.from_key = str(leave_token)
- yield self.hs.get_handlers().federation_handler.maybe_backfill(
+ await self.hs.get_handlers().federation_handler.maybe_backfill(
room_id, max_topo
)
- events, next_key = yield self.store.paginate_room_events(
+ events, next_key = await self.store.paginate_room_events(
room_id=room_id,
from_key=source_config.from_key,
to_key=source_config.to_key,
@@ -365,7 +364,7 @@ class PaginationHandler(object):
if event_filter:
events = event_filter.filter(events)
- events = yield filter_events_for_client(
+ events = await filter_events_for_client(
self.storage, user_id, events, is_peeking=(member_event_id is None)
)
@@ -385,19 +384,19 @@ class PaginationHandler(object):
(EventTypes.Member, event.sender) for event in events
)
- state_ids = yield self.state_store.get_state_ids_for_event(
+ state_ids = await self.state_store.get_state_ids_for_event(
events[0].event_id, state_filter=state_filter
)
if state_ids:
- state = yield self.store.get_events(list(state_ids.values()))
+ state = await self.store.get_events(list(state_ids.values()))
state = state.values()
time_now = self.clock.time_msec()
chunk = {
"chunk": (
- yield self._event_serializer.serialize_events(
+ await self._event_serializer.serialize_events(
events, time_now, as_client_event=as_client_event
)
),
@@ -406,7 +405,7 @@ class PaginationHandler(object):
}
if state:
- chunk["state"] = yield self._event_serializer.serialize_events(
+ chunk["state"] = await self._event_serializer.serialize_events(
state, time_now, as_client_event=as_client_event
)
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index cc9e6b9bd0..0082f85c26 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -13,20 +13,36 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+import re
+from typing import Tuple
import attr
import saml2
+import saml2.response
from saml2.client import Saml2Client
from synapse.api.errors import SynapseError
+from synapse.config import ConfigError
from synapse.http.servlet import parse_string
from synapse.rest.client.v1.login import SSOAuthHandler
-from synapse.types import UserID, map_username_to_mxid_localpart
+from synapse.types import (
+ UserID,
+ map_username_to_mxid_localpart,
+ mxid_localpart_allowed_characters,
+)
from synapse.util.async_helpers import Linearizer
logger = logging.getLogger(__name__)
+@attr.s
+class Saml2SessionData:
+ """Data we track about SAML2 sessions"""
+
+ # time the session was created, in milliseconds
+ creation_time = attr.ib()
+
+
class SamlHandler:
def __init__(self, hs):
self._saml_client = Saml2Client(hs.config.saml2_sp_config)
@@ -37,11 +53,14 @@ class SamlHandler:
self._datastore = hs.get_datastore()
self._hostname = hs.hostname
self._saml2_session_lifetime = hs.config.saml2_session_lifetime
- self._mxid_source_attribute = hs.config.saml2_mxid_source_attribute
self._grandfathered_mxid_source_attribute = (
hs.config.saml2_grandfathered_mxid_source_attribute
)
- self._mxid_mapper = hs.config.saml2_mxid_mapper
+
+ # plugin to do custom mapping from saml response to mxid
+ self._user_mapping_provider = hs.config.saml2_user_mapping_provider_class(
+ hs.config.saml2_user_mapping_provider_config
+ )
# identifier for the external_ids table
self._auth_provider_id = "saml"
@@ -118,22 +137,10 @@ class SamlHandler:
remote_user_id = saml2_auth.ava["uid"][0]
except KeyError:
logger.warning("SAML2 response lacks a 'uid' attestation")
- raise SynapseError(400, "uid not in SAML2 response")
-
- try:
- mxid_source = saml2_auth.ava[self._mxid_source_attribute][0]
- except KeyError:
- logger.warning(
- "SAML2 response lacks a '%s' attestation", self._mxid_source_attribute
- )
- raise SynapseError(
- 400, "%s not in SAML2 response" % (self._mxid_source_attribute,)
- )
+ raise SynapseError(400, "'uid' not in SAML2 response")
self._outstanding_requests_dict.pop(saml2_auth.in_response_to, None)
- displayName = saml2_auth.ava.get("displayName", [None])[0]
-
with (await self._mapping_lock.queue(self._auth_provider_id)):
# first of all, check if we already have a mapping for this user
logger.info(
@@ -173,22 +180,46 @@ class SamlHandler:
)
return registered_user_id
- # figure out a new mxid for this user
- base_mxid_localpart = self._mxid_mapper(mxid_source)
+ # Map saml response to user attributes using the configured mapping provider
+ for i in range(1000):
+ attribute_dict = self._user_mapping_provider.saml_response_to_user_attributes(
+ saml2_auth, i
+ )
+
+ logger.debug(
+ "Retrieved SAML attributes from user mapping provider: %s "
+ "(attempt %d)",
+ attribute_dict,
+ i,
+ )
+
+ localpart = attribute_dict.get("mxid_localpart")
+ if not localpart:
+ logger.error(
+ "SAML mapping provider plugin did not return a "
+ "mxid_localpart object"
+ )
+ raise SynapseError(500, "Error parsing SAML2 response")
- suffix = 0
- while True:
- localpart = base_mxid_localpart + (str(suffix) if suffix else "")
+ displayname = attribute_dict.get("displayname")
+
+ # Check if this mxid already exists
if not await self._datastore.get_users_by_id_case_insensitive(
UserID(localpart, self._hostname).to_string()
):
+ # This mxid is free
break
- suffix += 1
- logger.info("Allocating mxid for new user with localpart %s", localpart)
+ else:
+ # Unable to generate a username in 1000 iterations
+ # Break and return error to the user
+ raise SynapseError(
+ 500, "Unable to generate a Matrix ID from the SAML response"
+ )
registered_user_id = await self._registration_handler.register_user(
- localpart=localpart, default_display_name=displayName
+ localpart=localpart, default_display_name=displayname
)
+
await self._datastore.record_user_external_id(
self._auth_provider_id, remote_user_id, registered_user_id
)
@@ -205,9 +236,120 @@ class SamlHandler:
del self._outstanding_requests_dict[reqid]
+DOT_REPLACE_PATTERN = re.compile(
+ ("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),))
+)
+
+
+def dot_replace_for_mxid(username: str) -> str:
+ username = username.lower()
+ username = DOT_REPLACE_PATTERN.sub(".", username)
+
+ # regular mxids aren't allowed to start with an underscore either
+ username = re.sub("^_", "", username)
+ return username
+
+
+MXID_MAPPER_MAP = {
+ "hexencode": map_username_to_mxid_localpart,
+ "dotreplace": dot_replace_for_mxid,
+}
+
+
@attr.s
-class Saml2SessionData:
- """Data we track about SAML2 sessions"""
+class SamlConfig(object):
+ mxid_source_attribute = attr.ib()
+ mxid_mapper = attr.ib()
- # time the session was created, in milliseconds
- creation_time = attr.ib()
+
+class DefaultSamlMappingProvider(object):
+ __version__ = "0.0.1"
+
+ def __init__(self, parsed_config: SamlConfig):
+ """The default SAML user mapping provider
+
+ Args:
+ parsed_config: Module configuration
+ """
+ self._mxid_source_attribute = parsed_config.mxid_source_attribute
+ self._mxid_mapper = parsed_config.mxid_mapper
+
+ def saml_response_to_user_attributes(
+ self, saml_response: saml2.response.AuthnResponse, failures: int = 0,
+ ) -> dict:
+ """Maps some text from a SAML response to attributes of a new user
+
+ Args:
+ saml_response: A SAML auth response object
+
+ failures: How many times a call to this function with this
+ saml_response has resulted in a failure
+
+ Returns:
+ dict: A dict containing new user attributes. Possible keys:
+ * mxid_localpart (str): Required. The localpart of the user's mxid
+ * displayname (str): The displayname of the user
+ """
+ try:
+ mxid_source = saml_response.ava[self._mxid_source_attribute][0]
+ except KeyError:
+ logger.warning(
+ "SAML2 response lacks a '%s' attestation", self._mxid_source_attribute,
+ )
+ raise SynapseError(
+ 400, "%s not in SAML2 response" % (self._mxid_source_attribute,)
+ )
+
+ # Use the configured mapper for this mxid_source
+ base_mxid_localpart = self._mxid_mapper(mxid_source)
+
+ # Append suffix integer if last call to this function failed to produce
+ # a usable mxid
+ localpart = base_mxid_localpart + (str(failures) if failures else "")
+
+ # Retrieve the display name from the saml response
+ # If displayname is None, the mxid_localpart will be used instead
+ displayname = saml_response.ava.get("displayName", [None])[0]
+
+ return {
+ "mxid_localpart": localpart,
+ "displayname": displayname,
+ }
+
+ @staticmethod
+ def parse_config(config: dict) -> SamlConfig:
+ """Parse the dict provided by the homeserver's config
+ Args:
+ config: A dictionary containing configuration options for this provider
+ Returns:
+ SamlConfig: A custom config object for this module
+ """
+ # Parse config options and use defaults where necessary
+ mxid_source_attribute = config.get("mxid_source_attribute", "uid")
+ mapping_type = config.get("mxid_mapping", "hexencode")
+
+ # Retrieve the associating mapping function
+ try:
+ mxid_mapper = MXID_MAPPER_MAP[mapping_type]
+ except KeyError:
+ raise ConfigError(
+ "saml2_config.user_mapping_provider.config: '%s' is not a valid "
+ "mxid_mapping value" % (mapping_type,)
+ )
+
+ return SamlConfig(mxid_source_attribute, mxid_mapper)
+
+ @staticmethod
+ def get_saml_attributes(config: SamlConfig) -> Tuple[set, set]:
+ """Returns the required attributes of a SAML
+
+ Args:
+ config: A SamlConfig object containing configuration params for this provider
+
+ Returns:
+ tuple[set,set]: The first set equates to the saml auth response
+ attributes that are required for the module to function, whereas the
+ second set consists of those attributes which can be used if
+ available, but are not necessary
+ """
+ return {"uid", config.mxid_source_attribute}, {"displayName"}
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index 56ed262a1f..ef750d1497 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -21,7 +21,7 @@ from unpaddedbase64 import decode_base64, encode_base64
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
-from synapse.api.errors import SynapseError
+from synapse.api.errors import NotFoundError, SynapseError
from synapse.api.filtering import Filter
from synapse.storage.state import StateFilter
from synapse.visibility import filter_events_for_client
@@ -37,6 +37,7 @@ class SearchHandler(BaseHandler):
self._event_serializer = hs.get_event_client_serializer()
self.storage = hs.get_storage()
self.state_store = self.storage.state
+ self.auth = hs.get_auth()
@defer.inlineCallbacks
def get_old_rooms_from_upgraded_room(self, room_id):
@@ -53,23 +54,38 @@ class SearchHandler(BaseHandler):
room_id (str): id of the room to search through.
Returns:
- Deferred[iterable[unicode]]: predecessor room ids
+ Deferred[iterable[str]]: predecessor room ids
"""
historical_room_ids = []
- while True:
- predecessor = yield self.store.get_room_predecessor(room_id)
+ # The initial room must have been known for us to get this far
+ predecessor = yield self.store.get_room_predecessor(room_id)
- # If no predecessor, assume we've hit a dead end
+ while True:
if not predecessor:
+ # We have reached the end of the chain of predecessors
+ break
+
+ if not isinstance(predecessor.get("room_id"), str):
+ # This predecessor object is malformed. Exit here
+ break
+
+ predecessor_room_id = predecessor["room_id"]
+
+ # Don't add it to the list until we have checked that we are in the room
+ try:
+ next_predecessor_room = yield self.store.get_room_predecessor(
+ predecessor_room_id
+ )
+ except NotFoundError:
+ # The predecessor is not a known room, so we are done here
break
- # Add predecessor's room ID
- historical_room_ids.append(predecessor["room_id"])
+ historical_room_ids.append(predecessor_room_id)
- # Scan through the old room for further predecessors
- room_id = predecessor["room_id"]
+ # And repeat
+ predecessor = next_predecessor_room
return historical_room_ids
diff --git a/synapse/logging/context.py b/synapse/logging/context.py
index 2c1fb9ddac..6747f29e6a 100644
--- a/synapse/logging/context.py
+++ b/synapse/logging/context.py
@@ -23,6 +23,7 @@ them.
See doc/log_contexts.rst for details on how this works.
"""
+import inspect
import logging
import threading
import types
@@ -612,7 +613,8 @@ def run_in_background(f, *args, **kwargs):
def make_deferred_yieldable(deferred):
- """Given a deferred, make it follow the Synapse logcontext rules:
+ """Given a deferred (or coroutine), make it follow the Synapse logcontext
+ rules:
If the deferred has completed (or is not actually a Deferred), essentially
does nothing (just returns another completed deferred with the
@@ -624,6 +626,13 @@ def make_deferred_yieldable(deferred):
(This is more-or-less the opposite operation to run_in_background.)
"""
+ if inspect.isawaitable(deferred):
+ # If we're given a coroutine we convert it to a deferred so that we
+ # run it and find out if it immediately finishes, it it does then we
+ # don't need to fiddle with log contexts at all and can return
+ # immediately.
+ deferred = defer.ensureDeferred(deferred)
+
if not isinstance(deferred, defer.Deferred):
return deferred
diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py
index 6ece1d6745..b91a528245 100644
--- a/synapse/replication/slave/storage/_base.py
+++ b/synapse/replication/slave/storage/_base.py
@@ -20,6 +20,7 @@ import six
from synapse.storage._base import SQLBaseStore
from synapse.storage.data_stores.main.cache import CURRENT_STATE_CACHE_NAME
+from synapse.storage.database import Database
from synapse.storage.engines import PostgresEngine
from ._slaved_id_tracker import SlavedIdTracker
@@ -35,8 +36,8 @@ def __func__(inp):
class BaseSlavedStore(SQLBaseStore):
- def __init__(self, db_conn, hs):
- super(BaseSlavedStore, self).__init__(db_conn, hs)
+ def __init__(self, database: Database, db_conn, hs):
+ super(BaseSlavedStore, self).__init__(database, db_conn, hs)
if isinstance(self.database_engine, PostgresEngine):
self._cache_id_gen = SlavedIdTracker(
db_conn, "cache_invalidation_stream", "stream_id"
diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py
index bc2f6a12ae..ebe94909cb 100644
--- a/synapse/replication/slave/storage/account_data.py
+++ b/synapse/replication/slave/storage/account_data.py
@@ -18,15 +18,16 @@ from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.storage.data_stores.main.account_data import AccountDataWorkerStore
from synapse.storage.data_stores.main.tags import TagsWorkerStore
+from synapse.storage.database import Database
class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlavedStore):
- def __init__(self, db_conn, hs):
+ def __init__(self, database: Database, db_conn, hs):
self._account_data_id_gen = SlavedIdTracker(
db_conn, "account_data_max_stream_id", "stream_id"
)
- super(SlavedAccountDataStore, self).__init__(db_conn, hs)
+ super(SlavedAccountDataStore, self).__init__(database, db_conn, hs)
def get_max_account_data_stream_id(self):
return self._account_data_id_gen.get_current_token()
diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py
index b4f58cea19..fbf996e33a 100644
--- a/synapse/replication/slave/storage/client_ips.py
+++ b/synapse/replication/slave/storage/client_ips.py
@@ -14,6 +14,7 @@
# limitations under the License.
from synapse.storage.data_stores.main.client_ips import LAST_SEEN_GRANULARITY
+from synapse.storage.database import Database
from synapse.util.caches import CACHE_SIZE_FACTOR
from synapse.util.caches.descriptors import Cache
@@ -21,8 +22,8 @@ from ._base import BaseSlavedStore
class SlavedClientIpStore(BaseSlavedStore):
- def __init__(self, db_conn, hs):
- super(SlavedClientIpStore, self).__init__(db_conn, hs)
+ def __init__(self, database: Database, db_conn, hs):
+ super(SlavedClientIpStore, self).__init__(database, db_conn, hs)
self.client_ip_last_seen = Cache(
name="client_ip_last_seen", keylen=4, max_entries=50000 * CACHE_SIZE_FACTOR
diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py
index 9fb6c5c6ff..0c237c6e0f 100644
--- a/synapse/replication/slave/storage/deviceinbox.py
+++ b/synapse/replication/slave/storage/deviceinbox.py
@@ -16,13 +16,14 @@
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.storage.data_stores.main.deviceinbox import DeviceInboxWorkerStore
+from synapse.storage.database import Database
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.caches.stream_change_cache import StreamChangeCache
class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
- def __init__(self, db_conn, hs):
- super(SlavedDeviceInboxStore, self).__init__(db_conn, hs)
+ def __init__(self, database: Database, db_conn, hs):
+ super(SlavedDeviceInboxStore, self).__init__(database, db_conn, hs)
self._device_inbox_id_gen = SlavedIdTracker(
db_conn, "device_max_stream_id", "stream_id"
)
diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py
index de50748c30..dc625e0d7a 100644
--- a/synapse/replication/slave/storage/devices.py
+++ b/synapse/replication/slave/storage/devices.py
@@ -18,12 +18,13 @@ from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream
from synapse.storage.data_stores.main.devices import DeviceWorkerStore
from synapse.storage.data_stores.main.end_to_end_keys import EndToEndKeyWorkerStore
+from synapse.storage.database import Database
from synapse.util.caches.stream_change_cache import StreamChangeCache
class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedStore):
- def __init__(self, db_conn, hs):
- super(SlavedDeviceStore, self).__init__(db_conn, hs)
+ def __init__(self, database: Database, db_conn, hs):
+ super(SlavedDeviceStore, self).__init__(database, db_conn, hs)
self.hs = hs
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index d0a0eaf75b..29f35b9915 100644
--- a/synapse/replication/slave/storage/events.py
+++ b/synapse/replication/slave/storage/events.py
@@ -31,6 +31,7 @@ from synapse.storage.data_stores.main.signatures import SignatureWorkerStore
from synapse.storage.data_stores.main.state import StateGroupWorkerStore
from synapse.storage.data_stores.main.stream import StreamWorkerStore
from synapse.storage.data_stores.main.user_erasure_store import UserErasureWorkerStore
+from synapse.storage.database import Database
from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker
@@ -59,13 +60,13 @@ class SlavedEventStore(
RelationsWorkerStore,
BaseSlavedStore,
):
- def __init__(self, db_conn, hs):
+ def __init__(self, database: Database, db_conn, hs):
self._stream_id_gen = SlavedIdTracker(db_conn, "events", "stream_ordering")
self._backfill_id_gen = SlavedIdTracker(
db_conn, "events", "stream_ordering", step=-1
)
- super(SlavedEventStore, self).__init__(db_conn, hs)
+ super(SlavedEventStore, self).__init__(database, db_conn, hs)
# Cached functions can't be accessed through a class instance so we need
# to reach inside the __dict__ to extract them.
diff --git a/synapse/replication/slave/storage/filtering.py b/synapse/replication/slave/storage/filtering.py
index 5c84ebd125..bcb0688954 100644
--- a/synapse/replication/slave/storage/filtering.py
+++ b/synapse/replication/slave/storage/filtering.py
@@ -14,13 +14,14 @@
# limitations under the License.
from synapse.storage.data_stores.main.filtering import FilteringStore
+from synapse.storage.database import Database
from ._base import BaseSlavedStore
class SlavedFilteringStore(BaseSlavedStore):
- def __init__(self, db_conn, hs):
- super(SlavedFilteringStore, self).__init__(db_conn, hs)
+ def __init__(self, database: Database, db_conn, hs):
+ super(SlavedFilteringStore, self).__init__(database, db_conn, hs)
# Filters are immutable so this cache doesn't need to be expired
get_user_filter = FilteringStore.__dict__["get_user_filter"]
diff --git a/synapse/replication/slave/storage/groups.py b/synapse/replication/slave/storage/groups.py
index 28a46edd28..69a4ae42f9 100644
--- a/synapse/replication/slave/storage/groups.py
+++ b/synapse/replication/slave/storage/groups.py
@@ -14,6 +14,7 @@
# limitations under the License.
from synapse.storage import DataStore
+from synapse.storage.database import Database
from synapse.util.caches.stream_change_cache import StreamChangeCache
from ._base import BaseSlavedStore, __func__
@@ -21,8 +22,8 @@ from ._slaved_id_tracker import SlavedIdTracker
class SlavedGroupServerStore(BaseSlavedStore):
- def __init__(self, db_conn, hs):
- super(SlavedGroupServerStore, self).__init__(db_conn, hs)
+ def __init__(self, database: Database, db_conn, hs):
+ super(SlavedGroupServerStore, self).__init__(database, db_conn, hs)
self.hs = hs
diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py
index 747ced0c84..f552e7c972 100644
--- a/synapse/replication/slave/storage/presence.py
+++ b/synapse/replication/slave/storage/presence.py
@@ -15,6 +15,7 @@
from synapse.storage import DataStore
from synapse.storage.data_stores.main.presence import PresenceStore
+from synapse.storage.database import Database
from synapse.util.caches.stream_change_cache import StreamChangeCache
from ._base import BaseSlavedStore, __func__
@@ -22,8 +23,8 @@ from ._slaved_id_tracker import SlavedIdTracker
class SlavedPresenceStore(BaseSlavedStore):
- def __init__(self, db_conn, hs):
- super(SlavedPresenceStore, self).__init__(db_conn, hs)
+ def __init__(self, database: Database, db_conn, hs):
+ super(SlavedPresenceStore, self).__init__(database, db_conn, hs)
self._presence_id_gen = SlavedIdTracker(db_conn, "presence_stream", "stream_id")
self._presence_on_startup = self._get_active_presence(db_conn)
diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py
index 3655f05e54..eebd5a1fb6 100644
--- a/synapse/replication/slave/storage/push_rule.py
+++ b/synapse/replication/slave/storage/push_rule.py
@@ -15,17 +15,18 @@
# limitations under the License.
from synapse.storage.data_stores.main.push_rule import PushRulesWorkerStore
+from synapse.storage.database import Database
from ._slaved_id_tracker import SlavedIdTracker
from .events import SlavedEventStore
class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
- def __init__(self, db_conn, hs):
+ def __init__(self, database: Database, db_conn, hs):
self._push_rules_stream_id_gen = SlavedIdTracker(
db_conn, "push_rules_stream", "stream_id"
)
- super(SlavedPushRuleStore, self).__init__(db_conn, hs)
+ super(SlavedPushRuleStore, self).__init__(database, db_conn, hs)
def get_push_rules_stream_token(self):
return (
diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py
index b4331d0799..f22c2d44a3 100644
--- a/synapse/replication/slave/storage/pushers.py
+++ b/synapse/replication/slave/storage/pushers.py
@@ -15,14 +15,15 @@
# limitations under the License.
from synapse.storage.data_stores.main.pusher import PusherWorkerStore
+from synapse.storage.database import Database
from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker
class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
- def __init__(self, db_conn, hs):
- super(SlavedPusherStore, self).__init__(db_conn, hs)
+ def __init__(self, database: Database, db_conn, hs):
+ super(SlavedPusherStore, self).__init__(database, db_conn, hs)
self._pushers_id_gen = SlavedIdTracker(
db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
)
diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py
index 43d823c601..d40dc6e1f5 100644
--- a/synapse/replication/slave/storage/receipts.py
+++ b/synapse/replication/slave/storage/receipts.py
@@ -15,6 +15,7 @@
# limitations under the License.
from synapse.storage.data_stores.main.receipts import ReceiptsWorkerStore
+from synapse.storage.database import Database
from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker
@@ -29,14 +30,14 @@ from ._slaved_id_tracker import SlavedIdTracker
class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
- def __init__(self, db_conn, hs):
+ def __init__(self, database: Database, db_conn, hs):
# We instantiate this first as the ReceiptsWorkerStore constructor
# needs to be able to call get_max_receipt_stream_id
self._receipts_id_gen = SlavedIdTracker(
db_conn, "receipts_linearized", "stream_id"
)
- super(SlavedReceiptsStore, self).__init__(db_conn, hs)
+ super(SlavedReceiptsStore, self).__init__(database, db_conn, hs)
def get_max_receipt_stream_id(self):
return self._receipts_id_gen.get_current_token()
diff --git a/synapse/replication/slave/storage/room.py b/synapse/replication/slave/storage/room.py
index d9ad386b28..3a20f45316 100644
--- a/synapse/replication/slave/storage/room.py
+++ b/synapse/replication/slave/storage/room.py
@@ -14,14 +14,15 @@
# limitations under the License.
from synapse.storage.data_stores.main.room import RoomWorkerStore
+from synapse.storage.database import Database
from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker
class RoomStore(RoomWorkerStore, BaseSlavedStore):
- def __init__(self, db_conn, hs):
- super(RoomStore, self).__init__(db_conn, hs)
+ def __init__(self, database: Database, db_conn, hs):
+ super(RoomStore, self).__init__(database, db_conn, hs)
self._public_room_id_gen = SlavedIdTracker(
db_conn, "public_room_list_stream", "stream_id"
)
diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py
index 1eac8a44c5..e7fe50ed72 100644
--- a/synapse/rest/client/v1/profile.py
+++ b/synapse/rest/client/v1/profile.py
@@ -15,6 +15,7 @@
""" This module contains REST servlets to do with profile: /profile/<paths> """
+from synapse.api.errors import Codes, SynapseError
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.types import UserID
@@ -103,11 +104,15 @@ class ProfileAvatarURLRestServlet(RestServlet):
content = parse_json_object_from_request(request)
try:
- new_name = content["avatar_url"]
- except Exception:
- return 400, "Unable to parse name"
-
- await self.profile_handler.set_avatar_url(user, requester, new_name, is_admin)
+ new_avatar_url = content["avatar_url"]
+ except KeyError:
+ raise SynapseError(
+ 400, "Missing key 'avatar_url'", errcode=Codes.MISSING_PARAM
+ )
+
+ await self.profile_handler.set_avatar_url(
+ user, requester, new_avatar_url, is_admin
+ )
return 200, {}
diff --git a/synapse/server.py b/synapse/server.py
index be9af7f986..2db3dab221 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -238,8 +238,7 @@ class HomeServer(object):
def setup(self):
logger.info("Setting up.")
with self.get_db_conn() as conn:
- datastore = self.DATASTORE_CLASS(conn, self)
- self.datastores = DataStores(datastore, conn, self)
+ self.datastores = DataStores(self.DATASTORE_CLASS, conn, self)
conn.commit()
self.start_time = int(self.get_clock().time())
logger.info("Finished setting up.")
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 139beef8ed..3e6d62eef1 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -32,6 +32,7 @@ from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.logging.utils import log_function
from synapse.state import v1, v2
+from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
from synapse.util.async_helpers import Linearizer
from synapse.util.caches import get_cache_factor_for
from synapse.util.caches.expiringcache import ExpiringCache
@@ -645,7 +646,7 @@ class StateResolutionStore(object):
return self.store.get_events(
event_ids,
- check_redacted=False,
+ redact_behaviour=EventRedactBehaviour.AS_IS,
get_prev_content=False,
allow_rejected=allow_rejected,
)
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 8fb18203dc..ec89f645d4 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -49,15 +49,3 @@ class Storage(object):
self.persistence = EventsPersistenceStorage(hs, stores)
self.purge_events = PurgeEventsStorage(hs, stores)
self.state = StateGroupStorage(hs, stores)
-
-
-def are_all_users_on_domain(txn, database_engine, domain):
- sql = database_engine.convert_param_style(
- "SELECT COUNT(*) FROM users WHERE name NOT LIKE ?"
- )
- pat = "%:" + domain
- txn.execute(sql, (pat,))
- num_not_matching = txn.fetchall()[0][0]
- if num_not_matching == 0:
- return True
- return False
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index b7e27d4e97..b7637b5dc0 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -37,11 +37,11 @@ class SQLBaseStore(object):
per data store (and not one per physical database).
"""
- def __init__(self, db_conn, hs):
+ def __init__(self, database: Database, db_conn, hs):
self.hs = hs
self._clock = hs.get_clock()
self.database_engine = hs.database_engine
- self.db = Database(hs) # In future this will be passed in
+ self.db = database
self.rand = random.SystemRandom()
def _invalidate_state_caches(self, room_id, members_changed):
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index a9a13a2658..4f97fd5ab6 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -379,7 +379,7 @@ class BackgroundUpdater(object):
logger.debug("[SQL] %s", sql)
c.execute(sql)
- if isinstance(self.db.database_engine, engines.PostgresEngine):
+ if isinstance(self.db.engine, engines.PostgresEngine):
runner = create_index_psql
elif psql_only:
runner = None
diff --git a/synapse/storage/data_stores/__init__.py b/synapse/storage/data_stores/__init__.py
index cb184a98cc..cafedd5c0d 100644
--- a/synapse/storage/data_stores/__init__.py
+++ b/synapse/storage/data_stores/__init__.py
@@ -13,6 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from synapse.storage.database import Database
+from synapse.storage.prepare_database import prepare_database
+
class DataStores(object):
"""The various data stores.
@@ -20,7 +23,14 @@ class DataStores(object):
These are low level interfaces to physical databases.
"""
- def __init__(self, main_store, db_conn, hs):
- # Note we pass in the main store here as workers use a different main
+ def __init__(self, main_store_class, db_conn, hs):
+ # Note we pass in the main store class here as workers use a different main
# store.
- self.main = main_store
+ database = Database(hs)
+
+ # Check that db is correctly configured.
+ database.engine.check_database(db_conn.cursor())
+
+ prepare_database(db_conn, database.engine, config=hs.config)
+
+ self.main = main_store_class(database, db_conn, hs)
diff --git a/synapse/storage/data_stores/main/__init__.py b/synapse/storage/data_stores/main/__init__.py
index 6adb8adb04..c577c0df5f 100644
--- a/synapse/storage/data_stores/main/__init__.py
+++ b/synapse/storage/data_stores/main/__init__.py
@@ -20,6 +20,7 @@ import logging
import time
from synapse.api.constants import PresenceState
+from synapse.storage.database import Database
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import (
ChainedIdGenerator,
@@ -111,10 +112,20 @@ class DataStore(
RelationsStore,
CacheInvalidationStore,
):
- def __init__(self, db_conn, hs):
+ def __init__(self, database: Database, db_conn, hs):
self.hs = hs
self._clock = hs.get_clock()
- self.database_engine = hs.database_engine
+ self.database_engine = database.engine
+
+ all_users_native = are_all_users_on_domain(
+ db_conn.cursor(), database.engine, hs.hostname
+ )
+ if not all_users_native:
+ raise Exception(
+ "Found users in database not native to %s!\n"
+ "You cannot changed a synapse server_name after it's been configured"
+ % (hs.hostname,)
+ )
self._stream_id_gen = StreamIdGenerator(
db_conn,
@@ -169,7 +180,7 @@ class DataStore(
else:
self._cache_id_gen = None
- super(DataStore, self).__init__(db_conn, hs)
+ super(DataStore, self).__init__(database, db_conn, hs)
self._presence_on_startup = self._get_active_presence(db_conn)
@@ -554,3 +565,15 @@ class DataStore(
retcols=["name", "password_hash", "is_guest", "admin", "user_type"],
desc="search_users",
)
+
+
+def are_all_users_on_domain(txn, database_engine, domain):
+ sql = database_engine.convert_param_style(
+ "SELECT COUNT(*) FROM users WHERE name NOT LIKE ?"
+ )
+ pat = "%:" + domain
+ txn.execute(sql, (pat,))
+ num_not_matching = txn.fetchall()[0][0]
+ if num_not_matching == 0:
+ return True
+ return False
diff --git a/synapse/storage/data_stores/main/account_data.py b/synapse/storage/data_stores/main/account_data.py
index 1a3dd7be8d..46b494b334 100644
--- a/synapse/storage/data_stores/main/account_data.py
+++ b/synapse/storage/data_stores/main/account_data.py
@@ -22,6 +22,7 @@ from canonicaljson import json
from twisted.internet import defer
from synapse.storage._base import SQLBaseStore
+from synapse.storage.database import Database
from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -38,13 +39,13 @@ class AccountDataWorkerStore(SQLBaseStore):
# the abstract methods being implemented.
__metaclass__ = abc.ABCMeta
- def __init__(self, db_conn, hs):
+ def __init__(self, database: Database, db_conn, hs):
account_max = self.get_max_account_data_stream_id()
self._account_data_stream_cache = StreamChangeCache(
"AccountDataAndTagsChangeCache", account_max
)
- super(AccountDataWorkerStore, self).__init__(db_conn, hs)
+ super(AccountDataWorkerStore, self).__init__(database, db_conn, hs)
@abc.abstractmethod
def get_max_account_data_stream_id(self):
@@ -270,12 +271,12 @@ class AccountDataWorkerStore(SQLBaseStore):
class AccountDataStore(AccountDataWorkerStore):
- def __init__(self, db_conn, hs):
+ def __init__(self, database: Database, db_conn, hs):
self._account_data_id_gen = StreamIdGenerator(
db_conn, "account_data_max_stream_id", "stream_id"
)
- super(AccountDataStore, self).__init__(db_conn, hs)
+ super(AccountDataStore, self).__init__(database, db_conn, hs)
def get_max_account_data_stream_id(self):
"""Get the current max stream id for the private user data stream
diff --git a/synapse/storage/data_stores/main/appservice.py b/synapse/storage/data_stores/main/appservice.py
index 6b2e12719c..b2f39649fd 100644
--- a/synapse/storage/data_stores/main/appservice.py
+++ b/synapse/storage/data_stores/main/appservice.py
@@ -24,6 +24,7 @@ from synapse.appservice import AppServiceTransaction
from synapse.config.appservice import load_appservices
from synapse.storage._base import SQLBaseStore
from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
+from synapse.storage.database import Database
logger = logging.getLogger(__name__)
@@ -48,13 +49,13 @@ def _make_exclusive_regex(services_cache):
class ApplicationServiceWorkerStore(SQLBaseStore):
- def __init__(self, db_conn, hs):
+ def __init__(self, database: Database, db_conn, hs):
self.services_cache = load_appservices(
hs.hostname, hs.config.app_service_config_files
)
self.exclusive_user_regex = _make_exclusive_regex(self.services_cache)
- super(ApplicationServiceWorkerStore, self).__init__(db_conn, hs)
+ super(ApplicationServiceWorkerStore, self).__init__(database, db_conn, hs)
def get_app_services(self):
return self.services_cache
diff --git a/synapse/storage/data_stores/main/client_ips.py b/synapse/storage/data_stores/main/client_ips.py
index 7b470a58f1..add3037b69 100644
--- a/synapse/storage/data_stores/main/client_ips.py
+++ b/synapse/storage/data_stores/main/client_ips.py
@@ -21,6 +21,7 @@ from twisted.internet import defer
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore
+from synapse.storage.database import Database
from synapse.util.caches import CACHE_SIZE_FACTOR
from synapse.util.caches.descriptors import Cache
@@ -33,8 +34,8 @@ LAST_SEEN_GRANULARITY = 120 * 1000
class ClientIpBackgroundUpdateStore(SQLBaseStore):
- def __init__(self, db_conn, hs):
- super(ClientIpBackgroundUpdateStore, self).__init__(db_conn, hs)
+ def __init__(self, database: Database, db_conn, hs):
+ super(ClientIpBackgroundUpdateStore, self).__init__(database, db_conn, hs)
self.db.updates.register_background_index_update(
"user_ips_device_index",
@@ -363,13 +364,13 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
class ClientIpStore(ClientIpBackgroundUpdateStore):
- def __init__(self, db_conn, hs):
+ def __init__(self, database: Database, db_conn, hs):
self.client_ip_last_seen = Cache(
name="client_ip_last_seen", keylen=4, max_entries=50000 * CACHE_SIZE_FACTOR
)
- super(ClientIpStore, self).__init__(db_conn, hs)
+ super(ClientIpStore, self).__init__(database, db_conn, hs)
self.user_ips_max_age = hs.config.user_ips_max_age
@@ -450,16 +451,18 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
# Technically an access token might not be associated with
# a device so we need to check.
if device_id:
- self.db.simple_upsert_txn(
+ # this is always an update rather than an upsert: the row should
+ # already exist, and if it doesn't, that may be because it has been
+ # deleted, and we don't want to re-create it.
+ self.db.simple_update_txn(
txn,
table="devices",
keyvalues={"user_id": user_id, "device_id": device_id},
- values={
+ updatevalues={
"user_agent": user_agent,
"last_seen": last_seen,
"ip": ip,
},
- lock=False,
)
except Exception as e:
# Failed to upsert, log and continue
diff --git a/synapse/storage/data_stores/main/deviceinbox.py b/synapse/storage/data_stores/main/deviceinbox.py
index 3c9f09301a..85cfa16850 100644
--- a/synapse/storage/data_stores/main/deviceinbox.py
+++ b/synapse/storage/data_stores/main/deviceinbox.py
@@ -21,6 +21,7 @@ from twisted.internet import defer
from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
+from synapse.storage.database import Database
from synapse.util.caches.expiringcache import ExpiringCache
logger = logging.getLogger(__name__)
@@ -210,8 +211,8 @@ class DeviceInboxWorkerStore(SQLBaseStore):
class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
- def __init__(self, db_conn, hs):
- super(DeviceInboxBackgroundUpdateStore, self).__init__(db_conn, hs)
+ def __init__(self, database: Database, db_conn, hs):
+ super(DeviceInboxBackgroundUpdateStore, self).__init__(database, db_conn, hs)
self.db.updates.register_background_index_update(
"device_inbox_stream_index",
@@ -241,8 +242,8 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore):
DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
- def __init__(self, db_conn, hs):
- super(DeviceInboxStore, self).__init__(db_conn, hs)
+ def __init__(self, database: Database, db_conn, hs):
+ super(DeviceInboxStore, self).__init__(database, db_conn, hs)
# Map of (user_id, device_id) to the last stream_id that has been
# deleted up to. This is so that we can no op deletions.
diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py
index 91ddaf137e..9a828231c4 100644
--- a/synapse/storage/data_stores/main/devices.py
+++ b/synapse/storage/data_stores/main/devices.py
@@ -31,6 +31,7 @@ from synapse.logging.opentracing import (
)
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
+from synapse.storage.database import Database
from synapse.types import get_verify_key_from_cross_signing_key
from synapse.util import batch_iter
from synapse.util.caches.descriptors import (
@@ -642,8 +643,8 @@ class DeviceWorkerStore(SQLBaseStore):
class DeviceBackgroundUpdateStore(SQLBaseStore):
- def __init__(self, db_conn, hs):
- super(DeviceBackgroundUpdateStore, self).__init__(db_conn, hs)
+ def __init__(self, database: Database, db_conn, hs):
+ super(DeviceBackgroundUpdateStore, self).__init__(database, db_conn, hs)
self.db.updates.register_background_index_update(
"device_lists_stream_idx",
@@ -692,8 +693,8 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
- def __init__(self, db_conn, hs):
- super(DeviceStore, self).__init__(db_conn, hs)
+ def __init__(self, database: Database, db_conn, hs):
+ super(DeviceStore, self).__init__(database, db_conn, hs)
# Map of (user_id, device_id) -> bool. If there is an entry that implies
# the device exists.
diff --git a/synapse/storage/data_stores/main/event_federation.py b/synapse/storage/data_stores/main/event_federation.py
index 31d2e8eb28..1f517e8fad 100644
--- a/synapse/storage/data_stores/main/event_federation.py
+++ b/synapse/storage/data_stores/main/event_federation.py
@@ -28,6 +28,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
from synapse.storage.data_stores.main.signatures import SignatureWorkerStore
+from synapse.storage.database import Database
from synapse.util.caches.descriptors import cached
logger = logging.getLogger(__name__)
@@ -491,8 +492,8 @@ class EventFederationStore(EventFederationWorkerStore):
EVENT_AUTH_STATE_ONLY = "event_auth_state_only"
- def __init__(self, db_conn, hs):
- super(EventFederationStore, self).__init__(db_conn, hs)
+ def __init__(self, database: Database, db_conn, hs):
+ super(EventFederationStore, self).__init__(database, db_conn, hs)
self.db.updates.register_background_update_handler(
self.EVENT_AUTH_STATE_ONLY, self._background_delete_non_state_event_auth
diff --git a/synapse/storage/data_stores/main/event_push_actions.py b/synapse/storage/data_stores/main/event_push_actions.py
index eec054cd48..9988a6d3fc 100644
--- a/synapse/storage/data_stores/main/event_push_actions.py
+++ b/synapse/storage/data_stores/main/event_push_actions.py
@@ -24,6 +24,7 @@ from twisted.internet import defer
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import LoggingTransaction, SQLBaseStore
+from synapse.storage.database import Database
from synapse.util.caches.descriptors import cachedInlineCallbacks
logger = logging.getLogger(__name__)
@@ -68,8 +69,8 @@ def _deserialize_action(actions, is_highlight):
class EventPushActionsWorkerStore(SQLBaseStore):
- def __init__(self, db_conn, hs):
- super(EventPushActionsWorkerStore, self).__init__(db_conn, hs)
+ def __init__(self, database: Database, db_conn, hs):
+ super(EventPushActionsWorkerStore, self).__init__(database, db_conn, hs)
# These get correctly set by _find_stream_orderings_for_times_txn
self.stream_ordering_month_ago = None
@@ -611,8 +612,8 @@ class EventPushActionsWorkerStore(SQLBaseStore):
class EventPushActionsStore(EventPushActionsWorkerStore):
EPA_HIGHLIGHT_INDEX = "epa_highlight_index"
- def __init__(self, db_conn, hs):
- super(EventPushActionsStore, self).__init__(db_conn, hs)
+ def __init__(self, database: Database, db_conn, hs):
+ super(EventPushActionsStore, self).__init__(database, db_conn, hs)
self.db.updates.register_background_index_update(
self.EPA_HIGHLIGHT_INDEX,
diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py
index d644c82784..998bba1aad 100644
--- a/synapse/storage/data_stores/main/events.py
+++ b/synapse/storage/data_stores/main/events.py
@@ -41,6 +41,7 @@ from synapse.storage._base import make_in_list_sql_clause
from synapse.storage.data_stores.main.event_federation import EventFederationStore
from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
from synapse.storage.data_stores.main.state import StateGroupWorkerStore
+from synapse.storage.database import Database
from synapse.types import RoomStreamToken, get_domain_from_id
from synapse.util import batch_iter
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
@@ -95,8 +96,8 @@ def _retry_on_integrity_error(func):
class EventsStore(
StateGroupWorkerStore, EventFederationStore, EventsWorkerStore,
):
- def __init__(self, db_conn, hs):
- super(EventsStore, self).__init__(db_conn, hs)
+ def __init__(self, database: Database, db_conn, hs):
+ super(EventsStore, self).__init__(database, db_conn, hs)
# Collect metrics on the number of forward extremities that exist.
# Counter of number of extremities to count
@@ -1038,20 +1039,25 @@ class EventsStore(
},
)
- @defer.inlineCallbacks
- def _censor_redactions(self):
+ async def _censor_redactions(self):
"""Censors all redactions older than the configured period that haven't
been censored yet.
By censor we mean update the event_json table with the redacted event.
-
- Returns:
- Deferred
"""
if self.hs.config.redaction_retention_period is None:
return
+ if not (
+ await self.db.updates.has_completed_background_update(
+ "redactions_have_censored_ts_idx"
+ )
+ ):
+ # We don't want to run this until the appropriate index has been
+ # created.
+ return
+
before_ts = self._clock.time_msec() - self.hs.config.redaction_retention_period
# We fetch all redactions that:
@@ -1073,15 +1079,15 @@ class EventsStore(
LIMIT ?
"""
- rows = yield self.db.execute(
+ rows = await self.db.execute(
"_censor_redactions_fetch", None, sql, before_ts, 100
)
updates = []
for redaction_id, event_id in rows:
- redaction_event = yield self.get_event(redaction_id, allow_none=True)
- original_event = yield self.get_event(
+ redaction_event = await self.get_event(redaction_id, allow_none=True)
+ original_event = await self.get_event(
event_id, allow_rejected=True, allow_none=True
)
@@ -1114,7 +1120,7 @@ class EventsStore(
updatevalues={"have_censored": True},
)
- yield self.db.runInteraction("_update_censor_txn", _update_censor_txn)
+ await self.db.runInteraction("_update_censor_txn", _update_censor_txn)
def _censor_event_txn(self, txn, event_id, pruned_json):
"""Censor an event by replacing its JSON in the event_json table with the
diff --git a/synapse/storage/data_stores/main/events_bg_updates.py b/synapse/storage/data_stores/main/events_bg_updates.py
index cb1fc30c31..5177b71016 100644
--- a/synapse/storage/data_stores/main/events_bg_updates.py
+++ b/synapse/storage/data_stores/main/events_bg_updates.py
@@ -23,6 +23,7 @@ from twisted.internet import defer
from synapse.api.constants import EventContentFields
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
+from synapse.storage.database import Database
logger = logging.getLogger(__name__)
@@ -33,8 +34,8 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url"
DELETE_SOFT_FAILED_EXTREMITIES = "delete_soft_failed_extremities"
- def __init__(self, db_conn, hs):
- super(EventsBackgroundUpdatesStore, self).__init__(db_conn, hs)
+ def __init__(self, database: Database, db_conn, hs):
+ super(EventsBackgroundUpdatesStore, self).__init__(database, db_conn, hs)
self.db.updates.register_background_update_handler(
self.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts
@@ -89,6 +90,14 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
"event_store_labels", self._event_store_labels
)
+ self.db.updates.register_background_index_update(
+ "redactions_have_censored_ts_idx",
+ index_name="redactions_have_censored_ts",
+ table="redactions",
+ columns=["received_ts"],
+ where_clause="NOT have_censored",
+ )
+
@defer.inlineCallbacks
def _background_reindex_fields_sender(self, progress, batch_size):
target_min_stream_id = progress["target_min_stream_id_inclusive"]
diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py
index e041fc5eac..2c9142814c 100644
--- a/synapse/storage/data_stores/main/events_worker.py
+++ b/synapse/storage/data_stores/main/events_worker.py
@@ -19,8 +19,10 @@ import itertools
import logging
import threading
from collections import namedtuple
+from typing import List, Optional
from canonicaljson import json
+from constantly import NamedConstant, Names
from twisted.internet import defer
@@ -33,6 +35,7 @@ from synapse.events.utils import prune_event
from synapse.logging.context import LoggingContext, PreserveLoggingContext
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
+from synapse.storage.database import Database
from synapse.types import get_domain_from_id
from synapse.util import batch_iter
from synapse.util.caches.descriptors import Cache
@@ -54,9 +57,19 @@ EVENT_QUEUE_TIMEOUT_S = 0.1 # Timeout when waiting for requests for events
_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event"))
+class EventRedactBehaviour(Names):
+ """
+ What to do when retrieving a redacted event from the database.
+ """
+
+ AS_IS = NamedConstant()
+ REDACT = NamedConstant()
+ BLOCK = NamedConstant()
+
+
class EventsWorkerStore(SQLBaseStore):
- def __init__(self, db_conn, hs):
- super(EventsWorkerStore, self).__init__(db_conn, hs)
+ def __init__(self, database: Database, db_conn, hs):
+ super(EventsWorkerStore, self).__init__(database, db_conn, hs)
self._get_event_cache = Cache(
"*getEvent*", keylen=3, max_entries=hs.config.event_cache_size
@@ -124,25 +137,27 @@ class EventsWorkerStore(SQLBaseStore):
@defer.inlineCallbacks
def get_event(
self,
- event_id,
- check_redacted=True,
- get_prev_content=False,
- allow_rejected=False,
- allow_none=False,
- check_room_id=None,
+ event_id: List[str],
+ redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
+ get_prev_content: bool = False,
+ allow_rejected: bool = False,
+ allow_none: bool = False,
+ check_room_id: Optional[str] = None,
):
"""Get an event from the database by event_id.
Args:
- event_id (str): The event_id of the event to fetch
- check_redacted (bool): If True, check if event has been redacted
- and redact it.
- get_prev_content (bool): If True and event is a state event,
+ event_id: The event_id of the event to fetch
+ redact_behaviour: Determine what to do with a redacted event. Possible values:
+ * AS_IS - Return the full event body with no redacted content
+ * REDACT - Return the event but with a redacted body
+ * DISALLOW - Do not return redacted events
+ get_prev_content: If True and event is a state event,
include the previous states content in the unsigned field.
- allow_rejected (bool): If True return rejected events.
- allow_none (bool): If True, return None if no event found, if
+ allow_rejected: If True return rejected events.
+ allow_none: If True, return None if no event found, if
False throw a NotFoundError
- check_room_id (str|None): if not None, check the room of the found event.
+ check_room_id: if not None, check the room of the found event.
If there is a mismatch, behave as per allow_none.
Returns:
@@ -153,7 +168,7 @@ class EventsWorkerStore(SQLBaseStore):
events = yield self.get_events_as_list(
[event_id],
- check_redacted=check_redacted,
+ redact_behaviour=redact_behaviour,
get_prev_content=get_prev_content,
allow_rejected=allow_rejected,
)
@@ -172,27 +187,30 @@ class EventsWorkerStore(SQLBaseStore):
@defer.inlineCallbacks
def get_events(
self,
- event_ids,
- check_redacted=True,
- get_prev_content=False,
- allow_rejected=False,
+ event_ids: List[str],
+ redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
+ get_prev_content: bool = False,
+ allow_rejected: bool = False,
):
"""Get events from the database
Args:
- event_ids (list): The event_ids of the events to fetch
- check_redacted (bool): If True, check if event has been redacted
- and redact it.
- get_prev_content (bool): If True and event is a state event,
+ event_ids: The event_ids of the events to fetch
+ redact_behaviour: Determine what to do with a redacted event. Possible
+ values:
+ * AS_IS - Return the full event body with no redacted content
+ * REDACT - Return the event but with a redacted body
+ * DISALLOW - Do not return redacted events
+ get_prev_content: If True and event is a state event,
include the previous states content in the unsigned field.
- allow_rejected (bool): If True return rejected events.
+ allow_rejected: If True return rejected events.
Returns:
Deferred : Dict from event_id to event.
"""
events = yield self.get_events_as_list(
event_ids,
- check_redacted=check_redacted,
+ redact_behaviour=redact_behaviour,
get_prev_content=get_prev_content,
allow_rejected=allow_rejected,
)
@@ -202,21 +220,23 @@ class EventsWorkerStore(SQLBaseStore):
@defer.inlineCallbacks
def get_events_as_list(
self,
- event_ids,
- check_redacted=True,
- get_prev_content=False,
- allow_rejected=False,
+ event_ids: List[str],
+ redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
+ get_prev_content: bool = False,
+ allow_rejected: bool = False,
):
"""Get events from the database and return in a list in the same order
as given by `event_ids` arg.
Args:
- event_ids (list): The event_ids of the events to fetch
- check_redacted (bool): If True, check if event has been redacted
- and redact it.
- get_prev_content (bool): If True and event is a state event,
+ event_ids: The event_ids of the events to fetch
+ redact_behaviour: Determine what to do with a redacted event. Possible values:
+ * AS_IS - Return the full event body with no redacted content
+ * REDACT - Return the event but with a redacted body
+ * DISALLOW - Do not return redacted events
+ get_prev_content: If True and event is a state event,
include the previous states content in the unsigned field.
- allow_rejected (bool): If True return rejected events.
+ allow_rejected: If True, return rejected events.
Returns:
Deferred[list[EventBase]]: List of events fetched from the database. The
@@ -318,10 +338,14 @@ class EventsWorkerStore(SQLBaseStore):
# Update the cache to save doing the checks again.
entry.event.internal_metadata.recheck_redaction = False
- if check_redacted and entry.redacted_event:
- event = entry.redacted_event
- else:
- event = entry.event
+ event = entry.event
+
+ if entry.redacted_event:
+ if redact_behaviour == EventRedactBehaviour.BLOCK:
+ # Skip this event
+ continue
+ elif redact_behaviour == EventRedactBehaviour.REDACT:
+ event = entry.redacted_event
events.append(event)
diff --git a/synapse/storage/data_stores/main/media_repository.py b/synapse/storage/data_stores/main/media_repository.py
index 03c9c6f8ae..80ca36dedf 100644
--- a/synapse/storage/data_stores/main/media_repository.py
+++ b/synapse/storage/data_stores/main/media_repository.py
@@ -13,11 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.storage._base import SQLBaseStore
+from synapse.storage.database import Database
class MediaRepositoryBackgroundUpdateStore(SQLBaseStore):
- def __init__(self, db_conn, hs):
- super(MediaRepositoryBackgroundUpdateStore, self).__init__(db_conn, hs)
+ def __init__(self, database: Database, db_conn, hs):
+ super(MediaRepositoryBackgroundUpdateStore, self).__init__(
+ database, db_conn, hs
+ )
self.db.updates.register_background_index_update(
update_name="local_media_repository_url_idx",
@@ -31,8 +34,8 @@ class MediaRepositoryBackgroundUpdateStore(SQLBaseStore):
class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
"""Persistence for attachments and avatars"""
- def __init__(self, db_conn, hs):
- super(MediaRepositoryStore, self).__init__(db_conn, hs)
+ def __init__(self, database: Database, db_conn, hs):
+ super(MediaRepositoryStore, self).__init__(database, db_conn, hs)
def get_local_media(self, media_id):
"""Get the metadata for a local piece of media
diff --git a/synapse/storage/data_stores/main/monthly_active_users.py b/synapse/storage/data_stores/main/monthly_active_users.py
index 34bf3a1880..27158534cb 100644
--- a/synapse/storage/data_stores/main/monthly_active_users.py
+++ b/synapse/storage/data_stores/main/monthly_active_users.py
@@ -17,6 +17,7 @@ import logging
from twisted.internet import defer
from synapse.storage._base import SQLBaseStore
+from synapse.storage.database import Database
from synapse.util.caches.descriptors import cached
logger = logging.getLogger(__name__)
@@ -27,13 +28,13 @@ LAST_SEEN_GRANULARITY = 60 * 60 * 1000
class MonthlyActiveUsersStore(SQLBaseStore):
- def __init__(self, dbconn, hs):
- super(MonthlyActiveUsersStore, self).__init__(None, hs)
+ def __init__(self, database: Database, db_conn, hs):
+ super(MonthlyActiveUsersStore, self).__init__(database, db_conn, hs)
self._clock = hs.get_clock()
self.hs = hs
# Do not add more reserved users than the total allowable number
self.db.new_transaction(
- dbconn,
+ db_conn,
"initialise_mau_threepids",
[],
[],
diff --git a/synapse/storage/data_stores/main/push_rule.py b/synapse/storage/data_stores/main/push_rule.py
index de682cc63a..5ba13aa973 100644
--- a/synapse/storage/data_stores/main/push_rule.py
+++ b/synapse/storage/data_stores/main/push_rule.py
@@ -27,6 +27,7 @@ from synapse.storage.data_stores.main.appservice import ApplicationServiceWorker
from synapse.storage.data_stores.main.pusher import PusherWorkerStore
from synapse.storage.data_stores.main.receipts import ReceiptsWorkerStore
from synapse.storage.data_stores.main.roommember import RoomMemberWorkerStore
+from synapse.storage.database import Database
from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -72,8 +73,8 @@ class PushRulesWorkerStore(
# the abstract methods being implemented.
__metaclass__ = abc.ABCMeta
- def __init__(self, db_conn, hs):
- super(PushRulesWorkerStore, self).__init__(db_conn, hs)
+ def __init__(self, database: Database, db_conn, hs):
+ super(PushRulesWorkerStore, self).__init__(database, db_conn, hs)
push_rules_prefill, push_rules_id = self.db.get_cache_dict(
db_conn,
diff --git a/synapse/storage/data_stores/main/receipts.py b/synapse/storage/data_stores/main/receipts.py
index ac2d45bd5c..96e54d145e 100644
--- a/synapse/storage/data_stores/main/receipts.py
+++ b/synapse/storage/data_stores/main/receipts.py
@@ -22,6 +22,7 @@ from canonicaljson import json
from twisted.internet import defer
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
+from synapse.storage.database import Database
from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -38,8 +39,8 @@ class ReceiptsWorkerStore(SQLBaseStore):
# the abstract methods being implemented.
__metaclass__ = abc.ABCMeta
- def __init__(self, db_conn, hs):
- super(ReceiptsWorkerStore, self).__init__(db_conn, hs)
+ def __init__(self, database: Database, db_conn, hs):
+ super(ReceiptsWorkerStore, self).__init__(database, db_conn, hs)
self._receipts_stream_cache = StreamChangeCache(
"ReceiptsRoomChangeCache", self.get_max_receipt_stream_id()
@@ -315,14 +316,14 @@ class ReceiptsWorkerStore(SQLBaseStore):
class ReceiptsStore(ReceiptsWorkerStore):
- def __init__(self, db_conn, hs):
+ def __init__(self, database: Database, db_conn, hs):
# We instantiate this first as the ReceiptsWorkerStore constructor
# needs to be able to call get_max_receipt_stream_id
self._receipts_id_gen = StreamIdGenerator(
db_conn, "receipts_linearized", "stream_id"
)
- super(ReceiptsStore, self).__init__(db_conn, hs)
+ super(ReceiptsStore, self).__init__(database, db_conn, hs)
def get_max_receipt_stream_id(self):
return self._receipts_id_gen.get_current_token()
diff --git a/synapse/storage/data_stores/main/registration.py b/synapse/storage/data_stores/main/registration.py
index 1ef143c6d8..5e8ecac0ea 100644
--- a/synapse/storage/data_stores/main/registration.py
+++ b/synapse/storage/data_stores/main/registration.py
@@ -27,6 +27,7 @@ from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import SQLBaseStore
+from synapse.storage.database import Database
from synapse.types import UserID
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
@@ -36,8 +37,8 @@ logger = logging.getLogger(__name__)
class RegistrationWorkerStore(SQLBaseStore):
- def __init__(self, db_conn, hs):
- super(RegistrationWorkerStore, self).__init__(db_conn, hs)
+ def __init__(self, database: Database, db_conn, hs):
+ super(RegistrationWorkerStore, self).__init__(database, db_conn, hs)
self.config = hs.config
self.clock = hs.get_clock()
@@ -794,8 +795,8 @@ class RegistrationWorkerStore(SQLBaseStore):
class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
- def __init__(self, db_conn, hs):
- super(RegistrationBackgroundUpdateStore, self).__init__(db_conn, hs)
+ def __init__(self, database: Database, db_conn, hs):
+ super(RegistrationBackgroundUpdateStore, self).__init__(database, db_conn, hs)
self.clock = hs.get_clock()
self.config = hs.config
@@ -920,8 +921,8 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
class RegistrationStore(RegistrationBackgroundUpdateStore):
- def __init__(self, db_conn, hs):
- super(RegistrationStore, self).__init__(db_conn, hs)
+ def __init__(self, database: Database, db_conn, hs):
+ super(RegistrationStore, self).__init__(database, db_conn, hs)
self._account_validity = hs.config.account_validity
diff --git a/synapse/storage/data_stores/main/room.py b/synapse/storage/data_stores/main/room.py
index da42dae243..aa476d0fbf 100644
--- a/synapse/storage/data_stores/main/room.py
+++ b/synapse/storage/data_stores/main/room.py
@@ -29,6 +29,7 @@ from synapse.api.constants import EventTypes
from synapse.api.errors import StoreError
from synapse.storage._base import SQLBaseStore
from synapse.storage.data_stores.main.search import SearchStore
+from synapse.storage.database import Database
from synapse.types import ThirdPartyInstanceID
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
@@ -45,6 +46,11 @@ RatelimitOverride = collections.namedtuple(
class RoomWorkerStore(SQLBaseStore):
+ def __init__(self, database: Database, db_conn, hs):
+ super(RoomWorkerStore, self).__init__(database, db_conn, hs)
+
+ self.config = hs.config
+
def get_room(self, room_id):
"""Retrieve a room.
@@ -361,8 +367,8 @@ class RoomWorkerStore(SQLBaseStore):
class RoomBackgroundUpdateStore(SQLBaseStore):
- def __init__(self, db_conn, hs):
- super(RoomBackgroundUpdateStore, self).__init__(db_conn, hs)
+ def __init__(self, database: Database, db_conn, hs):
+ super(RoomBackgroundUpdateStore, self).__init__(database, db_conn, hs)
self.config = hs.config
@@ -440,8 +446,8 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
- def __init__(self, db_conn, hs):
- super(RoomStore, self).__init__(db_conn, hs)
+ def __init__(self, database: Database, db_conn, hs):
+ super(RoomStore, self).__init__(database, db_conn, hs)
self.config = hs.config
diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py
index 929f6b0d39..92e3b9c512 100644
--- a/synapse/storage/data_stores/main/roommember.py
+++ b/synapse/storage/data_stores/main/roommember.py
@@ -32,6 +32,7 @@ from synapse.storage._base import (
make_in_list_sql_clause,
)
from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
+from synapse.storage.database import Database
from synapse.storage.engines import Sqlite3Engine
from synapse.storage.roommember import (
GetRoomsForUserWithStreamOrdering,
@@ -54,8 +55,8 @@ _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME = "current_state_events_membership"
class RoomMemberWorkerStore(EventsWorkerStore):
- def __init__(self, db_conn, hs):
- super(RoomMemberWorkerStore, self).__init__(db_conn, hs)
+ def __init__(self, database: Database, db_conn, hs):
+ super(RoomMemberWorkerStore, self).__init__(database, db_conn, hs)
# Is the current_state_events.membership up to date? Or is the
# background update still running?
@@ -835,8 +836,8 @@ class RoomMemberWorkerStore(EventsWorkerStore):
class RoomMemberBackgroundUpdateStore(SQLBaseStore):
- def __init__(self, db_conn, hs):
- super(RoomMemberBackgroundUpdateStore, self).__init__(db_conn, hs)
+ def __init__(self, database: Database, db_conn, hs):
+ super(RoomMemberBackgroundUpdateStore, self).__init__(database, db_conn, hs)
self.db.updates.register_background_update_handler(
_MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile
)
@@ -991,8 +992,8 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
- def __init__(self, db_conn, hs):
- super(RoomMemberStore, self).__init__(db_conn, hs)
+ def __init__(self, database: Database, db_conn, hs):
+ super(RoomMemberStore, self).__init__(database, db_conn, hs)
def _store_room_members_txn(self, txn, events, backfilled):
"""Store a room member in the database.
diff --git a/synapse/storage/data_stores/main/schema/delta/56/redaction_censor.sql b/synapse/storage/data_stores/main/schema/delta/56/redaction_censor.sql
index fe51b02309..ea95db0ed7 100644
--- a/synapse/storage/data_stores/main/schema/delta/56/redaction_censor.sql
+++ b/synapse/storage/data_stores/main/schema/delta/56/redaction_censor.sql
@@ -14,4 +14,3 @@
*/
ALTER TABLE redactions ADD COLUMN have_censored BOOL NOT NULL DEFAULT false;
-CREATE INDEX redactions_have_censored ON redactions(event_id) WHERE not have_censored;
diff --git a/synapse/storage/data_stores/main/schema/delta/56/redaction_censor2.sql b/synapse/storage/data_stores/main/schema/delta/56/redaction_censor2.sql
index 77a5eca499..49ce35d794 100644
--- a/synapse/storage/data_stores/main/schema/delta/56/redaction_censor2.sql
+++ b/synapse/storage/data_stores/main/schema/delta/56/redaction_censor2.sql
@@ -14,7 +14,9 @@
*/
ALTER TABLE redactions ADD COLUMN received_ts BIGINT;
-CREATE INDEX redactions_have_censored_ts ON redactions(received_ts) WHERE not have_censored;
INSERT INTO background_updates (update_name, progress_json) VALUES
('redactions_received_ts', '{}');
+
+INSERT INTO background_updates (update_name, progress_json) VALUES
+ ('redactions_have_censored_ts_idx', '{}');
diff --git a/synapse/storage/data_stores/main/schema/delta/56/redaction_censor4.sql b/synapse/storage/data_stores/main/schema/delta/56/redaction_censor4.sql
new file mode 100644
index 0000000000..b7550f6f4e
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/redaction_censor4.sql
@@ -0,0 +1,16 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+DROP INDEX IF EXISTS redactions_have_censored;
diff --git a/synapse/storage/data_stores/main/schema/full_schemas/README.md b/synapse/storage/data_stores/main/schema/full_schemas/README.md
new file mode 100644
index 0000000000..bbd3f18604
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/full_schemas/README.md
@@ -0,0 +1,13 @@
+# Building full schema dumps
+
+These schemas need to be made from a database that has had all background updates run.
+
+To do so, use `scripts-dev/make_full_schema.sh`. This will produce
+`full.sql.postgres ` and `full.sql.sqlite` files.
+
+Ensure postgres is installed and your user has the ability to run bash commands
+such as `createdb`.
+
+```
+./scripts-dev/make_full_schema.sh -p postgres_username -o output_dir/
+```
diff --git a/synapse/storage/data_stores/main/schema/full_schemas/README.txt b/synapse/storage/data_stores/main/schema/full_schemas/README.txt
deleted file mode 100644
index d3f6401344..0000000000
--- a/synapse/storage/data_stores/main/schema/full_schemas/README.txt
+++ /dev/null
@@ -1,19 +0,0 @@
-Building full schema dumps
-==========================
-
-These schemas need to be made from a database that has had all background updates run.
-
-Postgres
---------
-
-$ pg_dump --format=plain --schema-only --no-tablespaces --no-acl --no-owner $DATABASE_NAME| sed -e '/^--/d' -e 's/public\.//g' -e '/^SET /d' -e '/^SELECT /d' > full.sql.postgres
-
-SQLite
-------
-
-$ sqlite3 $DATABASE_FILE ".schema" > full.sql.sqlite
-
-After
------
-
-Delete the CREATE statements for "sqlite_stat1", "schema_version", "applied_schema_deltas", and "applied_module_schemas".
\ No newline at end of file
diff --git a/synapse/storage/data_stores/main/search.py b/synapse/storage/data_stores/main/search.py
index ffa1817e64..dfb46ee0f8 100644
--- a/synapse/storage/data_stores/main/search.py
+++ b/synapse/storage/data_stores/main/search.py
@@ -25,6 +25,8 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
+from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
+from synapse.storage.database import Database
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
logger = logging.getLogger(__name__)
@@ -42,8 +44,8 @@ class SearchBackgroundUpdateStore(SQLBaseStore):
EVENT_SEARCH_USE_GIST_POSTGRES_NAME = "event_search_postgres_gist"
EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin"
- def __init__(self, db_conn, hs):
- super(SearchBackgroundUpdateStore, self).__init__(db_conn, hs)
+ def __init__(self, database: Database, db_conn, hs):
+ super(SearchBackgroundUpdateStore, self).__init__(database, db_conn, hs)
if not hs.config.enable_search:
return
@@ -342,8 +344,8 @@ class SearchBackgroundUpdateStore(SQLBaseStore):
class SearchStore(SearchBackgroundUpdateStore):
- def __init__(self, db_conn, hs):
- super(SearchStore, self).__init__(db_conn, hs)
+ def __init__(self, database: Database, db_conn, hs):
+ super(SearchStore, self).__init__(database, db_conn, hs)
def store_event_search_txn(self, txn, event, key, value):
"""Add event to the search table
@@ -452,7 +454,12 @@ class SearchStore(SearchBackgroundUpdateStore):
results = list(filter(lambda row: row["room_id"] in room_ids, results))
- events = yield self.get_events_as_list([r["event_id"] for r in results])
+ # We set redact_behaviour to BLOCK here to prevent redacted events being returned in
+ # search results (which is a data leak)
+ events = yield self.get_events_as_list(
+ [r["event_id"] for r in results],
+ redact_behaviour=EventRedactBehaviour.BLOCK,
+ )
event_map = {ev.event_id: ev for ev in events}
diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/data_stores/main/state.py
index 7d5a9f8128..dcc6b43cdf 100644
--- a/synapse/storage/data_stores/main/state.py
+++ b/synapse/storage/data_stores/main/state.py
@@ -28,6 +28,7 @@ from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.storage._base import SQLBaseStore
from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
+from synapse.storage.database import Database
from synapse.storage.engines import PostgresEngine
from synapse.storage.state import StateFilter
from synapse.util.caches import get_cache_factor_for, intern_string
@@ -213,8 +214,8 @@ class StateGroupWorkerStore(
STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"
- def __init__(self, db_conn, hs):
- super(StateGroupWorkerStore, self).__init__(db_conn, hs)
+ def __init__(self, database: Database, db_conn, hs):
+ super(StateGroupWorkerStore, self).__init__(database, db_conn, hs)
# Originally the state store used a single DictionaryCache to cache the
# event IDs for the state types in a given state group to avoid hammering
@@ -277,7 +278,7 @@ class StateGroupWorkerStore(
@defer.inlineCallbacks
def get_room_predecessor(self, room_id):
- """Get the predecessor room of an upgraded room if one exists.
+ """Get the predecessor of an upgraded room if it exists.
Otherwise return None.
Args:
@@ -290,14 +291,22 @@ class StateGroupWorkerStore(
* room_id (str): The room ID of the predecessor room
* event_id (str): The ID of the tombstone event in the predecessor room
+ None if a predecessor key is not found, or is not a dictionary.
+
Raises:
- NotFoundError if the room is unknown
+ NotFoundError if the given room is unknown
"""
# Retrieve the room's create event
create_event = yield self.get_create_event_for_room(room_id)
- # Return predecessor if present
- return create_event.content.get("predecessor", None)
+ # Retrieve the predecessor key of the create event
+ predecessor = create_event.content.get("predecessor", None)
+
+ # Ensure the key is a dictionary
+ if not isinstance(predecessor, dict):
+ return None
+
+ return predecessor
@defer.inlineCallbacks
def get_create_event_for_room(self, room_id):
@@ -317,7 +326,7 @@ class StateGroupWorkerStore(
# If we can't find the create event, assume we've hit a dead end
if not create_id:
- raise NotFoundError("Unknown room %s" % (room_id))
+ raise NotFoundError("Unknown room %s" % (room_id,))
# Retrieve the room's create event and return
create_event = yield self.get_event(create_id)
@@ -1029,8 +1038,8 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"
EVENT_STATE_GROUP_INDEX_UPDATE_NAME = "event_to_state_groups_sg_index"
- def __init__(self, db_conn, hs):
- super(StateBackgroundUpdateStore, self).__init__(db_conn, hs)
+ def __init__(self, database: Database, db_conn, hs):
+ super(StateBackgroundUpdateStore, self).__init__(database, db_conn, hs)
self.db.updates.register_background_update_handler(
self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME,
self._background_deduplicate_state,
@@ -1245,8 +1254,8 @@ class StateStore(StateGroupWorkerStore, StateBackgroundUpdateStore):
* `state_groups_state`: Maps state group to state events.
"""
- def __init__(self, db_conn, hs):
- super(StateStore, self).__init__(db_conn, hs)
+ def __init__(self, database: Database, db_conn, hs):
+ super(StateStore, self).__init__(database, db_conn, hs)
def _store_event_state_mappings_txn(
self, txn, events_and_contexts: Iterable[Tuple[EventBase, EventContext]]
diff --git a/synapse/storage/data_stores/main/stats.py b/synapse/storage/data_stores/main/stats.py
index 40579bf965..7bc186e9a1 100644
--- a/synapse/storage/data_stores/main/stats.py
+++ b/synapse/storage/data_stores/main/stats.py
@@ -22,6 +22,7 @@ from twisted.internet.defer import DeferredLock
from synapse.api.constants import EventTypes, Membership
from synapse.storage.data_stores.main.state_deltas import StateDeltasStore
+from synapse.storage.database import Database
from synapse.storage.engines import PostgresEngine
from synapse.util.caches.descriptors import cached
@@ -58,8 +59,8 @@ TYPE_TO_ORIGIN_TABLE = {"room": ("rooms", "room_id"), "user": ("users", "name")}
class StatsStore(StateDeltasStore):
- def __init__(self, db_conn, hs):
- super(StatsStore, self).__init__(db_conn, hs)
+ def __init__(self, database: Database, db_conn, hs):
+ super(StatsStore, self).__init__(database, db_conn, hs)
self.server_name = hs.hostname
self.clock = self.hs.get_clock()
diff --git a/synapse/storage/data_stores/main/stream.py b/synapse/storage/data_stores/main/stream.py
index 2ff8c57109..140da8dad6 100644
--- a/synapse/storage/data_stores/main/stream.py
+++ b/synapse/storage/data_stores/main/stream.py
@@ -47,6 +47,7 @@ from twisted.internet import defer
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.storage._base import SQLBaseStore
from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
+from synapse.storage.database import Database
from synapse.storage.engines import PostgresEngine
from synapse.types import RoomStreamToken
from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -251,8 +252,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
__metaclass__ = abc.ABCMeta
- def __init__(self, db_conn, hs):
- super(StreamWorkerStore, self).__init__(db_conn, hs)
+ def __init__(self, database: Database, db_conn, hs):
+ super(StreamWorkerStore, self).__init__(database, db_conn, hs)
events_max = self.get_room_max_stream_ordering()
event_cache_prefill, min_event_val = self.db.get_cache_dict(
diff --git a/synapse/storage/data_stores/main/transactions.py b/synapse/storage/data_stores/main/transactions.py
index c0d155a43c..5b07c2fbc0 100644
--- a/synapse/storage/data_stores/main/transactions.py
+++ b/synapse/storage/data_stores/main/transactions.py
@@ -24,6 +24,7 @@ from twisted.internet import defer
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.storage.database import Database
from synapse.util.caches.expiringcache import ExpiringCache
# py2 sqlite has buffer hardcoded as only binary type, so we must use it,
@@ -52,8 +53,8 @@ class TransactionStore(SQLBaseStore):
"""A collection of queries for handling PDUs.
"""
- def __init__(self, db_conn, hs):
- super(TransactionStore, self).__init__(db_conn, hs)
+ def __init__(self, database: Database, db_conn, hs):
+ super(TransactionStore, self).__init__(database, db_conn, hs)
self._clock.looping_call(self._start_cleanup_transactions, 30 * 60 * 1000)
diff --git a/synapse/storage/data_stores/main/user_directory.py b/synapse/storage/data_stores/main/user_directory.py
index 62ffb34b29..90c180ec6d 100644
--- a/synapse/storage/data_stores/main/user_directory.py
+++ b/synapse/storage/data_stores/main/user_directory.py
@@ -21,6 +21,7 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, JoinRules
from synapse.storage.data_stores.main.state import StateFilter
from synapse.storage.data_stores.main.state_deltas import StateDeltasStore
+from synapse.storage.database import Database
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from synapse.types import get_domain_from_id, get_localpart_from_id
from synapse.util.caches.descriptors import cached
@@ -37,8 +38,8 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
# add_users_who_share_private_rooms?
SHARE_PRIVATE_WORKING_SET = 500
- def __init__(self, db_conn, hs):
- super(UserDirectoryBackgroundUpdateStore, self).__init__(db_conn, hs)
+ def __init__(self, database: Database, db_conn, hs):
+ super(UserDirectoryBackgroundUpdateStore, self).__init__(database, db_conn, hs)
self.server_name = hs.hostname
@@ -549,8 +550,8 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
# add_users_who_share_private_rooms?
SHARE_PRIVATE_WORKING_SET = 500
- def __init__(self, db_conn, hs):
- super(UserDirectoryStore, self).__init__(db_conn, hs)
+ def __init__(self, database: Database, db_conn, hs):
+ super(UserDirectoryStore, self).__init__(database, db_conn, hs)
def remove_from_user_dir(self, user_id):
def _remove_from_user_dir_txn(txn):
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 6843b7e7f8..ec19ae1d9d 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -234,7 +234,7 @@ class Database(object):
# to watch it
self._txn_perf_counters = PerformanceCounters()
- self.database_engine = hs.database_engine
+ self.engine = hs.database_engine
# A set of tables that are not safe to use native upserts in.
self._unsafe_to_upsert_tables = set(UNIQUE_INDEX_BACKGROUND_UPDATES.keys())
@@ -242,10 +242,10 @@ class Database(object):
# We add the user_directory_search table to the blacklist on SQLite
# because the existing search table does not have an index, making it
# unsafe to use native upserts.
- if isinstance(self.database_engine, Sqlite3Engine):
+ if isinstance(self.engine, Sqlite3Engine):
self._unsafe_to_upsert_tables.add("user_directory_search")
- if self.database_engine.can_native_upsert:
+ if self.engine.can_native_upsert:
# Check ASAP (and then later, every 1s) to see if we have finished
# background updates of tables that aren't safe to update.
self._clock.call_later(
@@ -331,7 +331,7 @@ class Database(object):
cursor = LoggingTransaction(
conn.cursor(),
name,
- self.database_engine,
+ self.engine,
after_callbacks,
exception_callbacks,
)
@@ -339,7 +339,7 @@ class Database(object):
r = func(cursor, *args, **kwargs)
conn.commit()
return r
- except self.database_engine.module.OperationalError as e:
+ except self.engine.module.OperationalError as e:
# This can happen if the database disappears mid
# transaction.
logger.warning(
@@ -353,20 +353,20 @@ class Database(object):
i += 1
try:
conn.rollback()
- except self.database_engine.module.Error as e1:
+ except self.engine.module.Error as e1:
logger.warning(
"[TXN EROLL] {%s} %s", name, exception_to_unicode(e1)
)
continue
raise
- except self.database_engine.module.DatabaseError as e:
- if self.database_engine.is_deadlock(e):
+ except self.engine.module.DatabaseError as e:
+ if self.engine.is_deadlock(e):
logger.warning("[TXN DEADLOCK] {%s} %d/%d", name, i, N)
if i < N:
i += 1
try:
conn.rollback()
- except self.database_engine.module.Error as e1:
+ except self.engine.module.Error as e1:
logger.warning(
"[TXN EROLL] {%s} %s",
name,
@@ -494,7 +494,7 @@ class Database(object):
sql_scheduling_timer.observe(sched_duration_sec)
context.add_database_scheduled(sched_duration_sec)
- if self.database_engine.is_connection_closed(conn):
+ if self.engine.is_connection_closed(conn):
logger.debug("Reconnecting closed database connection")
conn.reconnect()
@@ -561,7 +561,7 @@ class Database(object):
"""
try:
yield self.runInteraction(desc, self.simple_insert_txn, table, values)
- except self.database_engine.module.IntegrityError:
+ except self.engine.module.IntegrityError:
# We have to do or_ignore flag at this layer, since we can't reuse
# a cursor after we receive an error from the db.
if not or_ignore:
@@ -660,7 +660,7 @@ class Database(object):
lock=lock,
)
return result
- except self.database_engine.module.IntegrityError as e:
+ except self.engine.module.IntegrityError as e:
attempts += 1
if attempts >= 5:
# don't retry forever, because things other than races
@@ -692,10 +692,7 @@ class Database(object):
upserts return True if a new entry was created, False if an existing
one was updated.
"""
- if (
- self.database_engine.can_native_upsert
- and table not in self._unsafe_to_upsert_tables
- ):
+ if self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables:
return self.simple_upsert_txn_native_upsert(
txn, table, keyvalues, values, insertion_values=insertion_values
)
@@ -726,7 +723,7 @@ class Database(object):
"""
# We need to lock the table :(, unless we're *really* careful
if lock:
- self.database_engine.lock_table(txn, table)
+ self.engine.lock_table(txn, table)
def _getwhere(key):
# If the value we're passing in is None (aka NULL), we need to use
@@ -828,10 +825,7 @@ class Database(object):
Returns:
None
"""
- if (
- self.database_engine.can_native_upsert
- and table not in self._unsafe_to_upsert_tables
- ):
+ if self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables:
return self.simple_upsert_many_txn_native_upsert(
txn, table, key_names, key_values, value_names, value_values
)
@@ -1301,7 +1295,7 @@ class Database(object):
"limit": limit,
}
- sql = self.database_engine.convert_param_style(sql)
+ sql = self.engine.convert_param_style(sql)
txn = db_conn.cursor()
txn.execute(sql, (int(max_value),))
|