diff --git a/.buildkite/worker-blacklist b/.buildkite/worker-blacklist
index 7950d19db3..158ab79154 100644
--- a/.buildkite/worker-blacklist
+++ b/.buildkite/worker-blacklist
@@ -34,33 +34,8 @@ Device list doesn't change if remote server is down
Remote servers cannot set power levels in rooms without existing powerlevels
Remote servers should reject attempts by non-creators to set the power levels
-# new failures as of https://github.com/matrix-org/sytest/pull/753
-GET /rooms/:room_id/messages returns a message
-GET /rooms/:room_id/messages lazy loads members correctly
-Read receipts are sent as events
-Only original members of the room can see messages from erased users
-Device deletion propagates over federation
-If user leaves room, remote user changes device and rejoins we see update in /sync and /keys/changes
-Changing user-signing key notifies local users
-Newly updated tags appear in an incremental v2 /sync
+# https://buildkite.com/matrix-dot-org/synapse/builds/6134#6f67bf47-e234-474d-80e8-c6e1868b15c5
Server correctly handles incoming m.device_list_update
-Local device key changes get to remote servers with correct prev_id
-AS-ghosted users can use rooms via AS
-Ghost user must register before joining room
-Test that a message is pushed
-Invites are pushed
-Rooms with aliases are correctly named in pushed
-Rooms with names are correctly named in pushed
-Rooms with canonical alias are correctly named in pushed
-Rooms with many users are correctly pushed
-Don't get pushed for rooms you've muted
-Rejected events are not pushed
-Test that rejected pushers are removed.
-Events come down the correct room
-
-# https://buildkite.com/matrix-dot-org/sytest/builds/326#cca62404-a88a-4fcb-ad41-175fd3377603
-Presence changes to UNAVAILABLE are reported to remote room members
-If remote user leaves room, changes device and rejoins we see update in sync
-uploading self-signing key notifies over federation
-Inbound federation can receive redacted events
-Outbound federation can request missing events
+
+# this fails reliably with a torture level of 100 due to https://github.com/matrix-org/synapse/issues/6536
+Outbound federation requests missing prev_events and then asks for /state_ids and resolves the state
diff --git a/changelog.d/6349.feature b/changelog.d/6349.feature
new file mode 100644
index 0000000000..56c4fbf78e
--- /dev/null
+++ b/changelog.d/6349.feature
@@ -0,0 +1 @@
+Implement v2 APIs for the `send_join` and `send_leave` federation endpoints (as described in [MSC1802](https://github.com/matrix-org/matrix-doc/pull/1802)).
diff --git a/changelog.d/6377.bugfix b/changelog.d/6377.bugfix
new file mode 100644
index 0000000000..ccda96962f
--- /dev/null
+++ b/changelog.d/6377.bugfix
@@ -0,0 +1 @@
+Prevent redacted events from being returned during message search.
\ No newline at end of file
diff --git a/changelog.d/6385.bugfix b/changelog.d/6385.bugfix
new file mode 100644
index 0000000000..7a2bc02170
--- /dev/null
+++ b/changelog.d/6385.bugfix
@@ -0,0 +1 @@
+Prevent error on trying to search a upgraded room when the server is not in the predecessor room.
\ No newline at end of file
diff --git a/changelog.d/6394.feature b/changelog.d/6394.feature
new file mode 100644
index 0000000000..1a0e8845ad
--- /dev/null
+++ b/changelog.d/6394.feature
@@ -0,0 +1 @@
+Add a develop script to generate full SQL schemas.
\ No newline at end of file
diff --git a/changelog.d/6411.feature b/changelog.d/6411.feature
new file mode 100644
index 0000000000..ebea4a208d
--- /dev/null
+++ b/changelog.d/6411.feature
@@ -0,0 +1 @@
+Allow custom SAML username mapping functinality through an external provider plugin.
\ No newline at end of file
diff --git a/changelog.d/6486.bugfix b/changelog.d/6486.bugfix
new file mode 100644
index 0000000000..b98c5a9ae5
--- /dev/null
+++ b/changelog.d/6486.bugfix
@@ -0,0 +1 @@
+Improve performance of looking up cross-signing keys.
diff --git a/changelog.d/6496.misc b/changelog.d/6496.misc
new file mode 100644
index 0000000000..19c6e926b8
--- /dev/null
+++ b/changelog.d/6496.misc
@@ -0,0 +1 @@
+Port synapse.handlers.initial_sync to async/await.
diff --git a/changelog.d/6501.misc b/changelog.d/6501.misc
new file mode 100644
index 0000000000..255f45a9c3
--- /dev/null
+++ b/changelog.d/6501.misc
@@ -0,0 +1 @@
+Refactor get_events_from_store_or_dest to return a dict.
diff --git a/changelog.d/6502.removal b/changelog.d/6502.removal
new file mode 100644
index 0000000000..0b72261d58
--- /dev/null
+++ b/changelog.d/6502.removal
@@ -0,0 +1 @@
+Remove redundant code from event authorisation implementation.
diff --git a/changelog.d/6503.misc b/changelog.d/6503.misc
new file mode 100644
index 0000000000..e4e9a5a3d4
--- /dev/null
+++ b/changelog.d/6503.misc
@@ -0,0 +1 @@
+Move get_state methods into FederationHandler.
diff --git a/changelog.d/6504.misc b/changelog.d/6504.misc
new file mode 100644
index 0000000000..7c873459af
--- /dev/null
+++ b/changelog.d/6504.misc
@@ -0,0 +1 @@
+Port handlers.account_data and handlers.account_validity to async/await.
diff --git a/changelog.d/6505.misc b/changelog.d/6505.misc
new file mode 100644
index 0000000000..3a75b2d9dd
--- /dev/null
+++ b/changelog.d/6505.misc
@@ -0,0 +1 @@
+Make `make_deferred_yieldable` to work with async/await.
diff --git a/changelog.d/6506.misc b/changelog.d/6506.misc
new file mode 100644
index 0000000000..99d7a70bcf
--- /dev/null
+++ b/changelog.d/6506.misc
@@ -0,0 +1 @@
+Remove `SnapshotCache` in favour of `ResponseCache`.
diff --git a/changelog.d/6510.misc b/changelog.d/6510.misc
new file mode 100644
index 0000000000..214f06539b
--- /dev/null
+++ b/changelog.d/6510.misc
@@ -0,0 +1 @@
+Change phone home stats to not assume there is a single database and report information about the database used by the main data store.
diff --git a/changelog.d/6511.misc b/changelog.d/6511.misc
new file mode 100644
index 0000000000..19ce435e68
--- /dev/null
+++ b/changelog.d/6511.misc
@@ -0,0 +1 @@
+Move database config from apps into HomeServer object.
diff --git a/changelog.d/6512.misc b/changelog.d/6512.misc
new file mode 100644
index 0000000000..37a8099eec
--- /dev/null
+++ b/changelog.d/6512.misc
@@ -0,0 +1 @@
+Silence mypy errors for files outside those specified.
diff --git a/changelog.d/6514.bugfix b/changelog.d/6514.bugfix
new file mode 100644
index 0000000000..6dc1985c24
--- /dev/null
+++ b/changelog.d/6514.bugfix
@@ -0,0 +1 @@
+Fix race which occasionally caused deleted devices to reappear.
diff --git a/changelog.d/6515.misc b/changelog.d/6515.misc
new file mode 100644
index 0000000000..a9c303ed1c
--- /dev/null
+++ b/changelog.d/6515.misc
@@ -0,0 +1 @@
+Clean up some logging when handling incoming events over federation.
diff --git a/changelog.d/6517.misc b/changelog.d/6517.misc
new file mode 100644
index 0000000000..c6ffed9952
--- /dev/null
+++ b/changelog.d/6517.misc
@@ -0,0 +1 @@
+Port some of FederationHandler to async/await.
\ No newline at end of file
diff --git a/changelog.d/6521.misc b/changelog.d/6521.misc
new file mode 100644
index 0000000000..d9a44389b9
--- /dev/null
+++ b/changelog.d/6521.misc
@@ -0,0 +1 @@
+Refactor some code in the event authentication path for clarity.
diff --git a/changelog.d/6522.bugfix b/changelog.d/6522.bugfix
new file mode 100644
index 0000000000..ccda96962f
--- /dev/null
+++ b/changelog.d/6522.bugfix
@@ -0,0 +1 @@
+Prevent redacted events from being returned during message search.
\ No newline at end of file
diff --git a/changelog.d/6524.misc b/changelog.d/6524.misc
new file mode 100644
index 0000000000..f885597426
--- /dev/null
+++ b/changelog.d/6524.misc
@@ -0,0 +1,2 @@
+Improve sanity-checking when receiving events over federation.
+
diff --git a/changelog.d/6534.misc b/changelog.d/6534.misc
new file mode 100644
index 0000000000..7df6bb442a
--- /dev/null
+++ b/changelog.d/6534.misc
@@ -0,0 +1 @@
+Test more folders against mypy.
diff --git a/changelog.d/6538.misc b/changelog.d/6538.misc
new file mode 100644
index 0000000000..cb4fd56948
--- /dev/null
+++ b/changelog.d/6538.misc
@@ -0,0 +1 @@
+Adjust the sytest blacklist for worker mode.
diff --git a/docs/saml_mapping_providers.md b/docs/saml_mapping_providers.md
new file mode 100644
index 0000000000..92f2380488
--- /dev/null
+++ b/docs/saml_mapping_providers.md
@@ -0,0 +1,77 @@
+# SAML Mapping Providers
+
+A SAML mapping provider is a Python class (loaded via a Python module) that
+works out how to map attributes of a SAML response object to Matrix-specific
+user attributes. Details such as user ID localpart, displayname, and even avatar
+URLs are all things that can be mapped from talking to a SSO service.
+
+As an example, a SSO service may return the email address
+"john.smith@example.com" for a user, whereas Synapse will need to figure out how
+to turn that into a displayname when creating a Matrix user for this individual.
+It may choose `John Smith`, or `Smith, John [Example.com]` or any number of
+variations. As each Synapse configuration may want something different, this is
+where SAML mapping providers come into play.
+
+## Enabling Providers
+
+External mapping providers are provided to Synapse in the form of an external
+Python module. Retrieve this module from [PyPi](https://pypi.org) or elsewhere,
+then tell Synapse where to look for the handler class by editing the
+`saml2_config.user_mapping_provider.module` config option.
+
+`saml2_config.user_mapping_provider.config` allows you to provide custom
+configuration options to the module. Check with the module's documentation for
+what options it provides (if any). The options listed by default are for the
+user mapping provider built in to Synapse. If using a custom module, you should
+comment these options out and use those specified by the module instead.
+
+## Building a Custom Mapping Provider
+
+A custom mapping provider must specify the following methods:
+
+* `__init__(self, parsed_config)`
+ - Arguments:
+ - `parsed_config` - A configuration object that is the return value of the
+ `parse_config` method. You should set any configuration options needed by
+ the module here.
+* `saml_response_to_user_attributes(self, saml_response, failures)`
+ - Arguments:
+ - `saml_response` - A `saml2.response.AuthnResponse` object to extract user
+ information from.
+ - `failures` - An `int` that represents the amount of times the returned
+ mxid localpart mapping has failed. This should be used
+ to create a deduplicated mxid localpart which should be
+ returned instead. For example, if this method returns
+ `john.doe` as the value of `mxid_localpart` in the returned
+ dict, and that is already taken on the homeserver, this
+ method will be called again with the same parameters but
+ with failures=1. The method should then return a different
+ `mxid_localpart` value, such as `john.doe1`.
+ - This method must return a dictionary, which will then be used by Synapse
+ to build a new user. The following keys are allowed:
+ * `mxid_localpart` - Required. The mxid localpart of the new user.
+ * `displayname` - The displayname of the new user. If not provided, will default to
+ the value of `mxid_localpart`.
+* `parse_config(config)`
+ - This method should have the `@staticmethod` decoration.
+ - Arguments:
+ - `config` - A `dict` representing the parsed content of the
+ `saml2_config.user_mapping_provider.config` homeserver config option.
+ Runs on homeserver startup. Providers should extract any option values
+ they need here.
+ - Whatever is returned will be passed back to the user mapping provider module's
+ `__init__` method during construction.
+* `get_saml_attributes(config)`
+ - This method should have the `@staticmethod` decoration.
+ - Arguments:
+ - `config` - A object resulting from a call to `parse_config`.
+ - Returns a tuple of two sets. The first set equates to the saml auth
+ response attributes that are required for the module to function, whereas
+ the second set consists of those attributes which can be used if available,
+ but are not necessary.
+
+## Synapse's Default Provider
+
+Synapse has a built-in SAML mapping provider if a custom provider isn't
+specified in the config. It is located at
+[`synapse.handlers.saml_handler.DefaultSamlMappingProvider`](../synapse/handlers/saml_handler.py).
diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml
index 10664ae8f7..4d44e631d1 100644
--- a/docs/sample_config.yaml
+++ b/docs/sample_config.yaml
@@ -1250,33 +1250,58 @@ saml2_config:
#
#config_path: "CONFDIR/sp_conf.py"
- # the lifetime of a SAML session. This defines how long a user has to
+ # The lifetime of a SAML session. This defines how long a user has to
# complete the authentication process, if allow_unsolicited is unset.
# The default is 5 minutes.
#
#saml_session_lifetime: 5m
- # The SAML attribute (after mapping via the attribute maps) to use to derive
- # the Matrix ID from. 'uid' by default.
+ # An external module can be provided here as a custom solution to
+ # mapping attributes returned from a saml provider onto a matrix user.
#
- #mxid_source_attribute: displayName
-
- # The mapping system to use for mapping the saml attribute onto a matrix ID.
- # Options include:
- # * 'hexencode' (which maps unpermitted characters to '=xx')
- # * 'dotreplace' (which replaces unpermitted characters with '.').
- # The default is 'hexencode'.
- #
- #mxid_mapping: dotreplace
+ user_mapping_provider:
+ # The custom module's class. Uncomment to use a custom module.
+ #
+ #module: mapping_provider.SamlMappingProvider
- # In previous versions of synapse, the mapping from SAML attribute to MXID was
- # always calculated dynamically rather than stored in a table. For backwards-
- # compatibility, we will look for user_ids matching such a pattern before
- # creating a new account.
+ # Custom configuration values for the module. Below options are
+ # intended for the built-in provider, they should be changed if
+ # using a custom module. This section will be passed as a Python
+ # dictionary to the module's `parse_config` method.
+ #
+ config:
+ # The SAML attribute (after mapping via the attribute maps) to use
+ # to derive the Matrix ID from. 'uid' by default.
+ #
+ # Note: This used to be configured by the
+ # saml2_config.mxid_source_attribute option. If that is still
+ # defined, its value will be used instead.
+ #
+ #mxid_source_attribute: displayName
+
+ # The mapping system to use for mapping the saml attribute onto a
+ # matrix ID.
+ #
+ # Options include:
+ # * 'hexencode' (which maps unpermitted characters to '=xx')
+ # * 'dotreplace' (which replaces unpermitted characters with
+ # '.').
+ # The default is 'hexencode'.
+ #
+ # Note: This used to be configured by the
+ # saml2_config.mxid_mapping option. If that is still defined, its
+ # value will be used instead.
+ #
+ #mxid_mapping: dotreplace
+
+ # In previous versions of synapse, the mapping from SAML attribute to
+ # MXID was always calculated dynamically rather than stored in a
+ # table. For backwards- compatibility, we will look for user_ids
+ # matching such a pattern before creating a new account.
#
# This setting controls the SAML attribute which will be used for this
- # backwards-compatibility lookup. Typically it should be 'uid', but if the
- # attribute maps are changed, it may be necessary to change it.
+ # backwards-compatibility lookup. Typically it should be 'uid', but if
+ # the attribute maps are changed, it may be necessary to change it.
#
# The default is 'uid'.
#
diff --git a/mypy.ini b/mypy.ini
index 1d77c0ecc8..a66434b76b 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -1,7 +1,7 @@
[mypy]
namespace_packages = True
plugins = mypy_zope:plugin
-follow_imports = normal
+follow_imports = silent
check_untyped_defs = True
show_error_codes = True
show_traceback = True
diff --git a/scripts-dev/make_full_schema.sh b/scripts-dev/make_full_schema.sh
new file mode 100755
index 0000000000..60e8970a35
--- /dev/null
+++ b/scripts-dev/make_full_schema.sh
@@ -0,0 +1,184 @@
+#!/bin/bash
+#
+# This script generates SQL files for creating a brand new Synapse DB with the latest
+# schema, on both SQLite3 and Postgres.
+#
+# It does so by having Synapse generate an up-to-date SQLite DB, then running
+# synapse_port_db to convert it to Postgres. It then dumps the contents of both.
+
+POSTGRES_HOST="localhost"
+POSTGRES_DB_NAME="synapse_full_schema.$$"
+
+SQLITE_FULL_SCHEMA_OUTPUT_FILE="full.sql.sqlite"
+POSTGRES_FULL_SCHEMA_OUTPUT_FILE="full.sql.postgres"
+
+REQUIRED_DEPS=("matrix-synapse" "psycopg2")
+
+usage() {
+ echo
+ echo "Usage: $0 -p <postgres_username> -o <path> [-c] [-n] [-h]"
+ echo
+ echo "-p <postgres_username>"
+ echo " Username to connect to local postgres instance. The password will be requested"
+ echo " during script execution."
+ echo "-c"
+ echo " CI mode. Enables coverage tracking and prints every command that the script runs."
+ echo "-o <path>"
+ echo " Directory to output full schema files to."
+ echo "-h"
+ echo " Display this help text."
+}
+
+while getopts "p:co:h" opt; do
+ case $opt in
+ p)
+ POSTGRES_USERNAME=$OPTARG
+ ;;
+ c)
+ # Print all commands that are being executed
+ set -x
+
+ # Modify required dependencies for coverage
+ REQUIRED_DEPS+=("coverage" "coverage-enable-subprocess")
+
+ COVERAGE=1
+ ;;
+ o)
+ command -v realpath > /dev/null || (echo "The -o flag requires the 'realpath' binary to be installed" && exit 1)
+ OUTPUT_DIR="$(realpath "$OPTARG")"
+ ;;
+ h)
+ usage
+ exit
+ ;;
+ \?)
+ echo "ERROR: Invalid option: -$OPTARG" >&2
+ usage
+ exit
+ ;;
+ esac
+done
+
+# Check that required dependencies are installed
+unsatisfied_requirements=()
+for dep in "${REQUIRED_DEPS[@]}"; do
+ pip show "$dep" --quiet || unsatisfied_requirements+=("$dep")
+done
+if [ ${#unsatisfied_requirements} -ne 0 ]; then
+ echo "Please install the following python packages: ${unsatisfied_requirements[*]}"
+ exit 1
+fi
+
+if [ -z "$POSTGRES_USERNAME" ]; then
+ echo "No postgres username supplied"
+ usage
+ exit 1
+fi
+
+if [ -z "$OUTPUT_DIR" ]; then
+ echo "No output directory supplied"
+ usage
+ exit 1
+fi
+
+# Create the output directory if it doesn't exist
+mkdir -p "$OUTPUT_DIR"
+
+read -rsp "Postgres password for '$POSTGRES_USERNAME': " POSTGRES_PASSWORD
+echo ""
+
+# Exit immediately if a command fails
+set -e
+
+# cd to root of the synapse directory
+cd "$(dirname "$0")/.."
+
+# Create temporary SQLite and Postgres homeserver db configs and key file
+TMPDIR=$(mktemp -d)
+KEY_FILE=$TMPDIR/test.signing.key # default Synapse signing key path
+SQLITE_CONFIG=$TMPDIR/sqlite.conf
+SQLITE_DB=$TMPDIR/homeserver.db
+POSTGRES_CONFIG=$TMPDIR/postgres.conf
+
+# Ensure these files are delete on script exit
+trap 'rm -rf $TMPDIR' EXIT
+
+cat > "$SQLITE_CONFIG" <<EOF
+server_name: "test"
+
+signing_key_path: "$KEY_FILE"
+macaroon_secret_key: "abcde"
+
+report_stats: false
+
+database:
+ name: "sqlite3"
+ args:
+ database: "$SQLITE_DB"
+
+# Suppress the key server warning.
+trusted_key_servers: []
+EOF
+
+cat > "$POSTGRES_CONFIG" <<EOF
+server_name: "test"
+
+signing_key_path: "$KEY_FILE"
+macaroon_secret_key: "abcde"
+
+report_stats: false
+
+database:
+ name: "psycopg2"
+ args:
+ user: "$POSTGRES_USERNAME"
+ host: "$POSTGRES_HOST"
+ password: "$POSTGRES_PASSWORD"
+ database: "$POSTGRES_DB_NAME"
+
+# Suppress the key server warning.
+trusted_key_servers: []
+EOF
+
+# Generate the server's signing key.
+echo "Generating SQLite3 db schema..."
+python -m synapse.app.homeserver --generate-keys -c "$SQLITE_CONFIG"
+
+# Make sure the SQLite3 database is using the latest schema and has no pending background update.
+echo "Running db background jobs..."
+scripts-dev/update_database --database-config "$SQLITE_CONFIG"
+
+# Create the PostgreSQL database.
+echo "Creating postgres database..."
+createdb $POSTGRES_DB_NAME
+
+echo "Copying data from SQLite3 to Postgres with synapse_port_db..."
+if [ -z "$COVERAGE" ]; then
+ # No coverage needed
+ scripts/synapse_port_db --sqlite-database "$SQLITE_DB" --postgres-config "$POSTGRES_CONFIG"
+else
+ # Coverage desired
+ coverage run scripts/synapse_port_db --sqlite-database "$SQLITE_DB" --postgres-config "$POSTGRES_CONFIG"
+fi
+
+# Delete schema_version, applied_schema_deltas and applied_module_schemas tables
+# This needs to be done after synapse_port_db is run
+echo "Dropping unwanted db tables..."
+SQL="
+DROP TABLE schema_version;
+DROP TABLE applied_schema_deltas;
+DROP TABLE applied_module_schemas;
+"
+sqlite3 "$SQLITE_DB" <<< "$SQL"
+psql $POSTGRES_DB_NAME -U "$POSTGRES_USERNAME" -w <<< "$SQL"
+
+echo "Dumping SQLite3 schema to '$OUTPUT_DIR/$SQLITE_FULL_SCHEMA_OUTPUT_FILE'..."
+sqlite3 "$SQLITE_DB" ".dump" > "$OUTPUT_DIR/$SQLITE_FULL_SCHEMA_OUTPUT_FILE"
+
+echo "Dumping Postgres schema to '$OUTPUT_DIR/$POSTGRES_FULL_SCHEMA_OUTPUT_FILE'..."
+pg_dump --format=plain --no-tablespaces --no-acl --no-owner $POSTGRES_DB_NAME | sed -e '/^--/d' -e 's/public\.//g' -e '/^SET /d' -e '/^SELECT /d' > "$OUTPUT_DIR/$POSTGRES_FULL_SCHEMA_OUTPUT_FILE"
+
+echo "Cleaning up temporary Postgres database..."
+dropdb $POSTGRES_DB_NAME
+
+echo "Done! Files dumped to: $OUTPUT_DIR"
diff --git a/scripts-dev/update_database b/scripts-dev/update_database
index 1776d202c5..23017c21f8 100755
--- a/scripts-dev/update_database
+++ b/scripts-dev/update_database
@@ -26,7 +26,6 @@ from synapse.config.homeserver import HomeServerConfig
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.server import HomeServer
from synapse.storage import DataStore
-from synapse.storage.engines import create_engine
from synapse.storage.prepare_database import prepare_database
logger = logging.getLogger("update_database")
@@ -35,21 +34,11 @@ logger = logging.getLogger("update_database")
class MockHomeserver(HomeServer):
DATASTORE_CLASS = DataStore
- def __init__(self, config, database_engine, db_conn, **kwargs):
+ def __init__(self, config, **kwargs):
super(MockHomeserver, self).__init__(
- config.server_name,
- reactor=reactor,
- config=config,
- database_engine=database_engine,
- **kwargs
+ config.server_name, reactor=reactor, config=config, **kwargs
)
- self.database_engine = database_engine
- self.db_conn = db_conn
-
- def get_db_conn(self):
- return self.db_conn
-
if __name__ == "__main__":
parser = argparse.ArgumentParser(
@@ -85,24 +74,14 @@ if __name__ == "__main__":
config = HomeServerConfig()
config.parse_config_dict(hs_config, "", "")
- # Create the database engine and a connection to it.
- database_engine = create_engine(config.database_config)
- db_conn = database_engine.module.connect(
- **{
- k: v
- for k, v in config.database_config.get("args", {}).items()
- if not k.startswith("cp_")
- }
- )
+ # Instantiate and initialise the homeserver object.
+ hs = MockHomeserver(config)
+ db_conn = hs.get_db_conn()
# Update the database to the latest schema.
- prepare_database(db_conn, database_engine, config=config)
+ prepare_database(db_conn, hs.database_engine, config=config)
db_conn.commit()
- # Instantiate and initialise the homeserver object.
- hs = MockHomeserver(
- config, database_engine, db_conn, db_config=config.database_config,
- )
# setup instantiates the store within the homeserver object.
hs.setup()
store = hs.get_datastore()
diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py
index 04751a6a5e..51a909419f 100644
--- a/synapse/app/admin_cmd.py
+++ b/synapse/app/admin_cmd.py
@@ -45,7 +45,6 @@ from synapse.replication.slave.storage.registration import SlavedRegistrationSto
from synapse.replication.slave.storage.room import RoomStore
from synapse.replication.tcp.client import ReplicationClientHandler
from synapse.server import HomeServer
-from synapse.storage.engines import create_engine
from synapse.util.logcontext import LoggingContext
from synapse.util.versionstring import get_version_string
@@ -229,14 +228,10 @@ def start(config_options):
synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts
- database_engine = create_engine(config.database_config)
-
ss = AdminCmdServer(
config.server_name,
- db_config=config.database_config,
config=config,
version_string="Synapse/" + get_version_string(synapse),
- database_engine=database_engine,
)
setup_logging(ss, config, use_worker_options=True)
diff --git a/synapse/app/appservice.py b/synapse/app/appservice.py
index 02b900f382..e82e0f11e3 100644
--- a/synapse/app/appservice.py
+++ b/synapse/app/appservice.py
@@ -34,7 +34,6 @@ from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
from synapse.replication.tcp.client import ReplicationClientHandler
from synapse.server import HomeServer
-from synapse.storage.engines import create_engine
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.manhole import manhole
from synapse.util.versionstring import get_version_string
@@ -143,8 +142,6 @@ def start(config_options):
events.USE_FROZEN_DICTS = config.use_frozen_dicts
- database_engine = create_engine(config.database_config)
-
if config.notify_appservices:
sys.stderr.write(
"\nThe appservices must be disabled in the main synapse process"
@@ -159,10 +156,8 @@ def start(config_options):
ps = AppserviceServer(
config.server_name,
- db_config=config.database_config,
config=config,
version_string="Synapse/" + get_version_string(synapse),
- database_engine=database_engine,
)
setup_logging(ps, config, use_worker_options=True)
diff --git a/synapse/app/client_reader.py b/synapse/app/client_reader.py
index dadb487d5f..3edfe19567 100644
--- a/synapse/app/client_reader.py
+++ b/synapse/app/client_reader.py
@@ -62,7 +62,6 @@ from synapse.rest.client.v2_alpha.keys import KeyChangesServlet, KeyQueryServlet
from synapse.rest.client.v2_alpha.register import RegisterRestServlet
from synapse.rest.client.versions import VersionsRestServlet
from synapse.server import HomeServer
-from synapse.storage.engines import create_engine
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.manhole import manhole
from synapse.util.versionstring import get_version_string
@@ -181,14 +180,10 @@ def start(config_options):
events.USE_FROZEN_DICTS = config.use_frozen_dicts
- database_engine = create_engine(config.database_config)
-
ss = ClientReaderServer(
config.server_name,
- db_config=config.database_config,
config=config,
version_string="Synapse/" + get_version_string(synapse),
- database_engine=database_engine,
)
setup_logging(ss, config, use_worker_options=True)
diff --git a/synapse/app/event_creator.py b/synapse/app/event_creator.py
index d110599a35..d0ddbe38fc 100644
--- a/synapse/app/event_creator.py
+++ b/synapse/app/event_creator.py
@@ -57,7 +57,6 @@ from synapse.rest.client.v1.room import (
)
from synapse.server import HomeServer
from synapse.storage.data_stores.main.user_directory import UserDirectoryStore
-from synapse.storage.engines import create_engine
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.manhole import manhole
from synapse.util.versionstring import get_version_string
@@ -180,14 +179,10 @@ def start(config_options):
events.USE_FROZEN_DICTS = config.use_frozen_dicts
- database_engine = create_engine(config.database_config)
-
ss = EventCreatorServer(
config.server_name,
- db_config=config.database_config,
config=config,
version_string="Synapse/" + get_version_string(synapse),
- database_engine=database_engine,
)
setup_logging(ss, config, use_worker_options=True)
diff --git a/synapse/app/federation_reader.py b/synapse/app/federation_reader.py
index 418c086254..311523e0ed 100644
--- a/synapse/app/federation_reader.py
+++ b/synapse/app/federation_reader.py
@@ -46,7 +46,6 @@ from synapse.replication.slave.storage.transactions import SlavedTransactionStor
from synapse.replication.tcp.client import ReplicationClientHandler
from synapse.rest.key.v2 import KeyApiV2Resource
from synapse.server import HomeServer
-from synapse.storage.engines import create_engine
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.manhole import manhole
from synapse.util.versionstring import get_version_string
@@ -162,14 +161,10 @@ def start(config_options):
events.USE_FROZEN_DICTS = config.use_frozen_dicts
- database_engine = create_engine(config.database_config)
-
ss = FederationReaderServer(
config.server_name,
- db_config=config.database_config,
config=config,
version_string="Synapse/" + get_version_string(synapse),
- database_engine=database_engine,
)
setup_logging(ss, config, use_worker_options=True)
diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py
index f24920a7d6..83c436229c 100644
--- a/synapse/app/federation_sender.py
+++ b/synapse/app/federation_sender.py
@@ -41,7 +41,6 @@ from synapse.replication.tcp.client import ReplicationClientHandler
from synapse.replication.tcp.streams._base import ReceiptsStream
from synapse.server import HomeServer
from synapse.storage.database import Database
-from synapse.storage.engines import create_engine
from synapse.types import ReadReceipt
from synapse.util.async_helpers import Linearizer
from synapse.util.httpresourcetree import create_resource_tree
@@ -174,8 +173,6 @@ def start(config_options):
events.USE_FROZEN_DICTS = config.use_frozen_dicts
- database_engine = create_engine(config.database_config)
-
if config.send_federation:
sys.stderr.write(
"\nThe send_federation must be disabled in the main synapse process"
@@ -190,10 +187,8 @@ def start(config_options):
ss = FederationSenderServer(
config.server_name,
- db_config=config.database_config,
config=config,
version_string="Synapse/" + get_version_string(synapse),
- database_engine=database_engine,
)
setup_logging(ss, config, use_worker_options=True)
diff --git a/synapse/app/frontend_proxy.py b/synapse/app/frontend_proxy.py
index e647459d0e..30e435eead 100644
--- a/synapse/app/frontend_proxy.py
+++ b/synapse/app/frontend_proxy.py
@@ -39,7 +39,6 @@ from synapse.replication.slave.storage.registration import SlavedRegistrationSto
from synapse.replication.tcp.client import ReplicationClientHandler
from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.server import HomeServer
-from synapse.storage.engines import create_engine
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.manhole import manhole
from synapse.util.versionstring import get_version_string
@@ -234,14 +233,10 @@ def start(config_options):
events.USE_FROZEN_DICTS = config.use_frozen_dicts
- database_engine = create_engine(config.database_config)
-
ss = FrontendProxyServer(
config.server_name,
- db_config=config.database_config,
config=config,
version_string="Synapse/" + get_version_string(synapse),
- database_engine=database_engine,
)
setup_logging(ss, config, use_worker_options=True)
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index df65d0a989..b8661457e2 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -69,7 +69,7 @@ from synapse.rest.media.v0.content_repository import ContentRepoResource
from synapse.rest.well_known import WellKnownResource
from synapse.server import HomeServer
from synapse.storage import DataStore
-from synapse.storage.engines import IncorrectDatabaseSetup, create_engine
+from synapse.storage.engines import IncorrectDatabaseSetup
from synapse.storage.prepare_database import UpgradeDatabaseException
from synapse.util.caches import CACHE_SIZE_FACTOR
from synapse.util.httpresourcetree import create_resource_tree
@@ -328,15 +328,10 @@ def setup(config_options):
events.USE_FROZEN_DICTS = config.use_frozen_dicts
- database_engine = create_engine(config.database_config)
- config.database_config["args"]["cp_openfun"] = database_engine.on_new_connection
-
hs = SynapseHomeServer(
config.server_name,
- db_config=config.database_config,
config=config,
version_string="Synapse/" + get_version_string(synapse),
- database_engine=database_engine,
)
synapse.config.logger.setup_logging(hs, config, use_worker_options=False)
@@ -519,8 +514,10 @@ def phone_stats_home(hs, stats, stats_process=_stats_process):
# Database version
#
- stats["database_engine"] = hs.database_engine.module.__name__
- stats["database_server_version"] = hs.database_engine.server_version
+ # This only reports info about the *main* database.
+ stats["database_engine"] = hs.get_datastore().db.engine.module.__name__
+ stats["database_server_version"] = hs.get_datastore().db.engine.server_version
+
logger.info("Reporting stats to %s: %s" % (hs.config.report_stats_endpoint, stats))
try:
yield hs.get_proxied_http_client().put_json(
diff --git a/synapse/app/media_repository.py b/synapse/app/media_repository.py
index 2c6dd3ef02..4c80f257e2 100644
--- a/synapse/app/media_repository.py
+++ b/synapse/app/media_repository.py
@@ -40,7 +40,6 @@ from synapse.rest.admin import register_servlets_for_media_repo
from synapse.rest.media.v0.content_repository import ContentRepoResource
from synapse.server import HomeServer
from synapse.storage.data_stores.main.media_repository import MediaRepositoryStore
-from synapse.storage.engines import create_engine
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.manhole import manhole
from synapse.util.versionstring import get_version_string
@@ -157,14 +156,10 @@ def start(config_options):
events.USE_FROZEN_DICTS = config.use_frozen_dicts
- database_engine = create_engine(config.database_config)
-
ss = MediaRepositoryServer(
config.server_name,
- db_config=config.database_config,
config=config,
version_string="Synapse/" + get_version_string(synapse),
- database_engine=database_engine,
)
setup_logging(ss, config, use_worker_options=True)
diff --git a/synapse/app/pusher.py b/synapse/app/pusher.py
index dd52a9fc2d..09e639040a 100644
--- a/synapse/app/pusher.py
+++ b/synapse/app/pusher.py
@@ -37,7 +37,6 @@ from synapse.replication.slave.storage.room import RoomStore
from synapse.replication.tcp.client import ReplicationClientHandler
from synapse.server import HomeServer
from synapse.storage import DataStore
-from synapse.storage.engines import create_engine
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.manhole import manhole
from synapse.util.versionstring import get_version_string
@@ -203,14 +202,10 @@ def start(config_options):
# Force the pushers to start since they will be disabled in the main config
config.start_pushers = True
- database_engine = create_engine(config.database_config)
-
ps = PusherServer(
config.server_name,
- db_config=config.database_config,
config=config,
version_string="Synapse/" + get_version_string(synapse),
- database_engine=database_engine,
)
setup_logging(ps, config, use_worker_options=True)
diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py
index 288ee64b42..dd2132e608 100644
--- a/synapse/app/synchrotron.py
+++ b/synapse/app/synchrotron.py
@@ -55,7 +55,6 @@ from synapse.rest.client.v1.room import RoomInitialSyncRestServlet
from synapse.rest.client.v2_alpha import sync
from synapse.server import HomeServer
from synapse.storage.data_stores.main.presence import UserPresenceState
-from synapse.storage.engines import create_engine
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.manhole import manhole
from synapse.util.stringutils import random_string
@@ -437,14 +436,10 @@ def start(config_options):
synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts
- database_engine = create_engine(config.database_config)
-
ss = SynchrotronServer(
config.server_name,
- db_config=config.database_config,
config=config,
version_string="Synapse/" + get_version_string(synapse),
- database_engine=database_engine,
application_service_handler=SynchrotronApplicationService(),
)
diff --git a/synapse/app/user_dir.py b/synapse/app/user_dir.py
index c01fb34a9b..1257098f92 100644
--- a/synapse/app/user_dir.py
+++ b/synapse/app/user_dir.py
@@ -44,7 +44,6 @@ from synapse.rest.client.v2_alpha import user_directory
from synapse.server import HomeServer
from synapse.storage.data_stores.main.user_directory import UserDirectoryStore
from synapse.storage.database import Database
-from synapse.storage.engines import create_engine
from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.manhole import manhole
@@ -200,8 +199,6 @@ def start(config_options):
events.USE_FROZEN_DICTS = config.use_frozen_dicts
- database_engine = create_engine(config.database_config)
-
if config.update_user_directory:
sys.stderr.write(
"\nThe update_user_directory must be disabled in the main synapse process"
@@ -216,10 +213,8 @@ def start(config_options):
ss = UserDirectoryServer(
config.server_name,
- db_config=config.database_config,
config=config,
version_string="Synapse/" + get_version_string(synapse),
- database_engine=database_engine,
)
setup_logging(ss, config, use_worker_options=True)
diff --git a/synapse/config/saml2_config.py b/synapse/config/saml2_config.py
index c5ea2d43a1..b91414aa35 100644
--- a/synapse/config/saml2_config.py
+++ b/synapse/config/saml2_config.py
@@ -14,17 +14,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import re
+import logging
from synapse.python_dependencies import DependencyException, check_requirements
-from synapse.types import (
- map_username_to_mxid_localpart,
- mxid_localpart_allowed_characters,
-)
-from synapse.util.module_loader import load_python_module
+from synapse.util.module_loader import load_module, load_python_module
from ._base import Config, ConfigError
+logger = logging.getLogger(__name__)
+
+DEFAULT_USER_MAPPING_PROVIDER = (
+ "synapse.handlers.saml_handler.DefaultSamlMappingProvider"
+)
+
def _dict_merge(merge_dict, into_dict):
"""Do a deep merge of two dicts
@@ -75,15 +77,69 @@ class SAML2Config(Config):
self.saml2_enabled = True
- self.saml2_mxid_source_attribute = saml2_config.get(
- "mxid_source_attribute", "uid"
- )
-
self.saml2_grandfathered_mxid_source_attribute = saml2_config.get(
"grandfathered_mxid_source_attribute", "uid"
)
- saml2_config_dict = self._default_saml_config_dict()
+ # user_mapping_provider may be None if the key is present but has no value
+ ump_dict = saml2_config.get("user_mapping_provider") or {}
+
+ # Use the default user mapping provider if not set
+ ump_dict.setdefault("module", DEFAULT_USER_MAPPING_PROVIDER)
+
+ # Ensure a config is present
+ ump_dict["config"] = ump_dict.get("config") or {}
+
+ if ump_dict["module"] == DEFAULT_USER_MAPPING_PROVIDER:
+ # Load deprecated options for use by the default module
+ old_mxid_source_attribute = saml2_config.get("mxid_source_attribute")
+ if old_mxid_source_attribute:
+ logger.warning(
+ "The config option saml2_config.mxid_source_attribute is deprecated. "
+ "Please use saml2_config.user_mapping_provider.config"
+ ".mxid_source_attribute instead."
+ )
+ ump_dict["config"]["mxid_source_attribute"] = old_mxid_source_attribute
+
+ old_mxid_mapping = saml2_config.get("mxid_mapping")
+ if old_mxid_mapping:
+ logger.warning(
+ "The config option saml2_config.mxid_mapping is deprecated. Please "
+ "use saml2_config.user_mapping_provider.config.mxid_mapping instead."
+ )
+ ump_dict["config"]["mxid_mapping"] = old_mxid_mapping
+
+ # Retrieve an instance of the module's class
+ # Pass the config dictionary to the module for processing
+ (
+ self.saml2_user_mapping_provider_class,
+ self.saml2_user_mapping_provider_config,
+ ) = load_module(ump_dict)
+
+ # Ensure loaded user mapping module has defined all necessary methods
+ # Note parse_config() is already checked during the call to load_module
+ required_methods = [
+ "get_saml_attributes",
+ "saml_response_to_user_attributes",
+ ]
+ missing_methods = [
+ method
+ for method in required_methods
+ if not hasattr(self.saml2_user_mapping_provider_class, method)
+ ]
+ if missing_methods:
+ raise ConfigError(
+ "Class specified by saml2_config."
+ "user_mapping_provider.module is missing required "
+ "methods: %s" % (", ".join(missing_methods),)
+ )
+
+ # Get the desired saml auth response attributes from the module
+ saml2_config_dict = self._default_saml_config_dict(
+ *self.saml2_user_mapping_provider_class.get_saml_attributes(
+ self.saml2_user_mapping_provider_config
+ )
+ )
_dict_merge(
merge_dict=saml2_config.get("sp_config", {}), into_dict=saml2_config_dict
)
@@ -103,22 +159,27 @@ class SAML2Config(Config):
saml2_config.get("saml_session_lifetime", "5m")
)
- mapping = saml2_config.get("mxid_mapping", "hexencode")
- try:
- self.saml2_mxid_mapper = MXID_MAPPER_MAP[mapping]
- except KeyError:
- raise ConfigError("%s is not a known mxid_mapping" % (mapping,))
-
- def _default_saml_config_dict(self):
+ def _default_saml_config_dict(
+ self, required_attributes: set, optional_attributes: set
+ ):
+ """Generate a configuration dictionary with required and optional attributes that
+ will be needed to process new user registration
+
+ Args:
+ required_attributes: SAML auth response attributes that are
+ necessary to function
+ optional_attributes: SAML auth response attributes that can be used to add
+ additional information to Synapse user accounts, but are not required
+
+ Returns:
+ dict: A SAML configuration dictionary
+ """
import saml2
public_baseurl = self.public_baseurl
if public_baseurl is None:
raise ConfigError("saml2_config requires a public_baseurl to be set")
- required_attributes = {"uid", self.saml2_mxid_source_attribute}
-
- optional_attributes = {"displayName"}
if self.saml2_grandfathered_mxid_source_attribute:
optional_attributes.add(self.saml2_grandfathered_mxid_source_attribute)
optional_attributes -= required_attributes
@@ -207,33 +268,58 @@ class SAML2Config(Config):
#
#config_path: "%(config_dir_path)s/sp_conf.py"
- # the lifetime of a SAML session. This defines how long a user has to
+ # The lifetime of a SAML session. This defines how long a user has to
# complete the authentication process, if allow_unsolicited is unset.
# The default is 5 minutes.
#
#saml_session_lifetime: 5m
- # The SAML attribute (after mapping via the attribute maps) to use to derive
- # the Matrix ID from. 'uid' by default.
+ # An external module can be provided here as a custom solution to
+ # mapping attributes returned from a saml provider onto a matrix user.
#
- #mxid_source_attribute: displayName
-
- # The mapping system to use for mapping the saml attribute onto a matrix ID.
- # Options include:
- # * 'hexencode' (which maps unpermitted characters to '=xx')
- # * 'dotreplace' (which replaces unpermitted characters with '.').
- # The default is 'hexencode'.
- #
- #mxid_mapping: dotreplace
-
- # In previous versions of synapse, the mapping from SAML attribute to MXID was
- # always calculated dynamically rather than stored in a table. For backwards-
- # compatibility, we will look for user_ids matching such a pattern before
- # creating a new account.
+ user_mapping_provider:
+ # The custom module's class. Uncomment to use a custom module.
+ #
+ #module: mapping_provider.SamlMappingProvider
+
+ # Custom configuration values for the module. Below options are
+ # intended for the built-in provider, they should be changed if
+ # using a custom module. This section will be passed as a Python
+ # dictionary to the module's `parse_config` method.
+ #
+ config:
+ # The SAML attribute (after mapping via the attribute maps) to use
+ # to derive the Matrix ID from. 'uid' by default.
+ #
+ # Note: This used to be configured by the
+ # saml2_config.mxid_source_attribute option. If that is still
+ # defined, its value will be used instead.
+ #
+ #mxid_source_attribute: displayName
+
+ # The mapping system to use for mapping the saml attribute onto a
+ # matrix ID.
+ #
+ # Options include:
+ # * 'hexencode' (which maps unpermitted characters to '=xx')
+ # * 'dotreplace' (which replaces unpermitted characters with
+ # '.').
+ # The default is 'hexencode'.
+ #
+ # Note: This used to be configured by the
+ # saml2_config.mxid_mapping option. If that is still defined, its
+ # value will be used instead.
+ #
+ #mxid_mapping: dotreplace
+
+ # In previous versions of synapse, the mapping from SAML attribute to
+ # MXID was always calculated dynamically rather than stored in a
+ # table. For backwards- compatibility, we will look for user_ids
+ # matching such a pattern before creating a new account.
#
# This setting controls the SAML attribute which will be used for this
- # backwards-compatibility lookup. Typically it should be 'uid', but if the
- # attribute maps are changed, it may be necessary to change it.
+ # backwards-compatibility lookup. Typically it should be 'uid', but if
+ # the attribute maps are changed, it may be necessary to change it.
#
# The default is 'uid'.
#
@@ -241,23 +327,3 @@ class SAML2Config(Config):
""" % {
"config_dir_path": config_dir_path
}
-
-
-DOT_REPLACE_PATTERN = re.compile(
- ("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),))
-)
-
-
-def dot_replace_for_mxid(username: str) -> str:
- username = username.lower()
- username = DOT_REPLACE_PATTERN.sub(".", username)
-
- # regular mxids aren't allowed to start with an underscore either
- username = re.sub("^_", "", username)
- return username
-
-
-MXID_MAPPER_MAP = {
- "hexencode": map_username_to_mxid_localpart,
- "dotreplace": dot_replace_for_mxid,
-}
diff --git a/synapse/event_auth.py b/synapse/event_auth.py
index ec3243b27b..c940b84470 100644
--- a/synapse/event_auth.py
+++ b/synapse/event_auth.py
@@ -42,6 +42,8 @@ def check(room_version, event, auth_events, do_sig_check=True, do_size_check=Tru
Returns:
if the auth checks pass.
"""
+ assert isinstance(auth_events, dict)
+
if do_size_check:
_check_size_limits(event)
@@ -74,12 +76,6 @@ def check(room_version, event, auth_events, do_sig_check=True, do_size_check=Tru
if not event.signatures.get(event_id_domain):
raise AuthError(403, "Event not signed by sending server")
- if auth_events is None:
- # Oh, we don't know what the state of the room was, so we
- # are trusting that this is allowed (at least for now)
- logger.warning("Trusting event: %s", event.event_id)
- return
-
if event.type == EventTypes.Create:
sender_domain = get_domain_from_id(event.sender)
room_id_domain = get_domain_from_id(event.room_id)
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 709449c9e3..af652a7659 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -18,8 +18,6 @@ import copy
import itertools
import logging
-from six.moves import range
-
from prometheus_client import Counter
from twisted.internet import defer
@@ -39,7 +37,7 @@ from synapse.api.room_versions import (
)
from synapse.events import builder, room_version_to_event_format
from synapse.federation.federation_base import FederationBase, event_from_pdu_json
-from synapse.logging.context import make_deferred_yieldable, run_in_background
+from synapse.logging.context import make_deferred_yieldable
from synapse.logging.utils import log_function
from synapse.util import unwrapFirstError
from synapse.util.caches.expiringcache import ExpiringCache
@@ -310,19 +308,12 @@ class FederationClient(FederationBase):
return signed_pdu
@defer.inlineCallbacks
- @log_function
- def get_state_for_room(self, destination, room_id, event_id):
- """Requests all of the room state at a given event from a remote homeserver.
-
- Args:
- destination (str): The remote homeserver to query for the state.
- room_id (str): The id of the room we're interested in.
- event_id (str): The id of the event we want the state at.
+ def get_room_state_ids(self, destination: str, room_id: str, event_id: str):
+ """Calls the /state_ids endpoint to fetch the state at a particular point
+ in the room, and the auth events for the given event
Returns:
- Deferred[Tuple[List[EventBase], List[EventBase]]]:
- A list of events in the state, and a list of events in the auth chain
- for the given event.
+ Tuple[List[str], List[str]]: a tuple of (state event_ids, auth event_ids)
"""
result = yield self.transport_layer.get_room_state_ids(
destination, room_id, event_id=event_id
@@ -331,86 +322,12 @@ class FederationClient(FederationBase):
state_event_ids = result["pdu_ids"]
auth_event_ids = result.get("auth_chain_ids", [])
- fetched_events, failed_to_fetch = yield self.get_events_from_store_or_dest(
- destination, room_id, set(state_event_ids + auth_event_ids)
- )
-
- if failed_to_fetch:
- logger.warning(
- "Failed to fetch missing state/auth events for %s: %s",
- room_id,
- failed_to_fetch,
- )
-
- event_map = {ev.event_id: ev for ev in fetched_events}
+ if not isinstance(state_event_ids, list) or not isinstance(
+ auth_event_ids, list
+ ):
+ raise Exception("invalid response from /state_ids")
- pdus = [event_map[e_id] for e_id in state_event_ids if e_id in event_map]
- auth_chain = [event_map[e_id] for e_id in auth_event_ids if e_id in event_map]
-
- auth_chain.sort(key=lambda e: e.depth)
-
- return pdus, auth_chain
-
- @defer.inlineCallbacks
- def get_events_from_store_or_dest(self, destination, room_id, event_ids):
- """Fetch events from a remote destination, checking if we already have them.
-
- Args:
- destination (str)
- room_id (str)
- event_ids (list)
-
- Returns:
- Deferred: A deferred resolving to a 2-tuple where the first is a list of
- events and the second is a list of event ids that we failed to fetch.
- """
- seen_events = yield self.store.get_events(event_ids, allow_rejected=True)
- signed_events = list(seen_events.values())
-
- failed_to_fetch = set()
-
- missing_events = set(event_ids)
- for k in seen_events:
- missing_events.discard(k)
-
- if not missing_events:
- return signed_events, failed_to_fetch
-
- logger.debug(
- "Fetching unknown state/auth events %s for room %s",
- missing_events,
- event_ids,
- )
-
- room_version = yield self.store.get_room_version(room_id)
-
- batch_size = 20
- missing_events = list(missing_events)
- for i in range(0, len(missing_events), batch_size):
- batch = set(missing_events[i : i + batch_size])
-
- deferreds = [
- run_in_background(
- self.get_pdu,
- destinations=[destination],
- event_id=e_id,
- room_version=room_version,
- )
- for e_id in batch
- ]
-
- res = yield make_deferred_yieldable(
- defer.DeferredList(deferreds, consumeErrors=True)
- )
- for success, result in res:
- if success and result:
- signed_events.append(result)
- batch.discard(result.event_id)
-
- # We removed all events we successfully fetched from `batch`
- failed_to_fetch.update(batch)
-
- return signed_events, failed_to_fetch
+ return state_event_ids, auth_event_ids
@defer.inlineCallbacks
@log_function
@@ -609,13 +526,7 @@ class FederationClient(FederationBase):
@defer.inlineCallbacks
def send_request(destination):
- time_now = self._clock.time_msec()
- _, content = yield self.transport_layer.send_join(
- destination=destination,
- room_id=pdu.room_id,
- event_id=pdu.event_id,
- content=pdu.get_pdu_json(time_now),
- )
+ content = yield self._do_send_join(destination, pdu)
logger.debug("Got content: %s", content)
@@ -683,6 +594,44 @@ class FederationClient(FederationBase):
return self._try_destination_list("send_join", destinations, send_request)
@defer.inlineCallbacks
+ def _do_send_join(self, destination, pdu):
+ time_now = self._clock.time_msec()
+
+ try:
+ content = yield self.transport_layer.send_join_v2(
+ destination=destination,
+ room_id=pdu.room_id,
+ event_id=pdu.event_id,
+ content=pdu.get_pdu_json(time_now),
+ )
+
+ return content
+ except HttpResponseException as e:
+ if e.code in [400, 404]:
+ err = e.to_synapse_error()
+
+ # If we receive an error response that isn't a generic error, or an
+ # unrecognised endpoint error, we assume that the remote understands
+ # the v2 invite API and this is a legitimate error.
+ if err.errcode not in [Codes.UNKNOWN, Codes.UNRECOGNIZED]:
+ raise err
+ else:
+ raise e.to_synapse_error()
+
+ logger.debug("Couldn't send_join with the v2 API, falling back to the v1 API")
+
+ resp = yield self.transport_layer.send_join_v1(
+ destination=destination,
+ room_id=pdu.room_id,
+ event_id=pdu.event_id,
+ content=pdu.get_pdu_json(time_now),
+ )
+
+ # We expect the v1 API to respond with [200, content], so we only return the
+ # content.
+ return resp[1]
+
+ @defer.inlineCallbacks
def send_invite(self, destination, room_id, event_id, pdu):
room_version = yield self.store.get_room_version(room_id)
@@ -791,18 +740,50 @@ class FederationClient(FederationBase):
@defer.inlineCallbacks
def send_request(destination):
- time_now = self._clock.time_msec()
- _, content = yield self.transport_layer.send_leave(
+ content = yield self._do_send_leave(destination, pdu)
+
+ logger.debug("Got content: %s", content)
+ return None
+
+ return self._try_destination_list("send_leave", destinations, send_request)
+
+ @defer.inlineCallbacks
+ def _do_send_leave(self, destination, pdu):
+ time_now = self._clock.time_msec()
+
+ try:
+ content = yield self.transport_layer.send_leave_v2(
destination=destination,
room_id=pdu.room_id,
event_id=pdu.event_id,
content=pdu.get_pdu_json(time_now),
)
- logger.debug("Got content: %s", content)
- return None
+ return content
+ except HttpResponseException as e:
+ if e.code in [400, 404]:
+ err = e.to_synapse_error()
- return self._try_destination_list("send_leave", destinations, send_request)
+ # If we receive an error response that isn't a generic error, or an
+ # unrecognised endpoint error, we assume that the remote understands
+ # the v2 invite API and this is a legitimate error.
+ if err.errcode not in [Codes.UNKNOWN, Codes.UNRECOGNIZED]:
+ raise err
+ else:
+ raise e.to_synapse_error()
+
+ logger.debug("Couldn't send_leave with the v2 API, falling back to the v1 API")
+
+ resp = yield self.transport_layer.send_leave_v1(
+ destination=destination,
+ room_id=pdu.room_id,
+ event_id=pdu.event_id,
+ content=pdu.get_pdu_json(time_now),
+ )
+
+ # We expect the v1 API to respond with [200, content], so we only return the
+ # content.
+ return resp[1]
def get_public_rooms(
self,
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 84d4eca041..d7ce333822 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -384,15 +384,10 @@ class FederationServer(FederationBase):
res_pdus = await self.handler.on_send_join_request(origin, pdu)
time_now = self._clock.time_msec()
- return (
- 200,
- {
- "state": [p.get_pdu_json(time_now) for p in res_pdus["state"]],
- "auth_chain": [
- p.get_pdu_json(time_now) for p in res_pdus["auth_chain"]
- ],
- },
- )
+ return {
+ "state": [p.get_pdu_json(time_now) for p in res_pdus["state"]],
+ "auth_chain": [p.get_pdu_json(time_now) for p in res_pdus["auth_chain"]],
+ }
async def on_make_leave_request(self, origin, room_id, user_id):
origin_host, _ = parse_server_name(origin)
@@ -419,7 +414,7 @@ class FederationServer(FederationBase):
pdu = await self._check_sigs_and_hash(room_version, pdu)
await self.handler.on_send_leave_request(origin, pdu)
- return 200, {}
+ return {}
async def on_event_auth(self, origin, room_id, event_id):
with (await self._server_linearizer.queue((origin, room_id))):
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 46dba84cac..198257414b 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -243,7 +243,7 @@ class TransportLayerClient(object):
@defer.inlineCallbacks
@log_function
- def send_join(self, destination, room_id, event_id, content):
+ def send_join_v1(self, destination, room_id, event_id, content):
path = _create_v1_path("/send_join/%s/%s", room_id, event_id)
response = yield self.client.put_json(
@@ -254,7 +254,18 @@ class TransportLayerClient(object):
@defer.inlineCallbacks
@log_function
- def send_leave(self, destination, room_id, event_id, content):
+ def send_join_v2(self, destination, room_id, event_id, content):
+ path = _create_v2_path("/send_join/%s/%s", room_id, event_id)
+
+ response = yield self.client.put_json(
+ destination=destination, path=path, data=content
+ )
+
+ return response
+
+ @defer.inlineCallbacks
+ @log_function
+ def send_leave_v1(self, destination, room_id, event_id, content):
path = _create_v1_path("/send_leave/%s/%s", room_id, event_id)
response = yield self.client.put_json(
@@ -272,6 +283,24 @@ class TransportLayerClient(object):
@defer.inlineCallbacks
@log_function
+ def send_leave_v2(self, destination, room_id, event_id, content):
+ path = _create_v2_path("/send_leave/%s/%s", room_id, event_id)
+
+ response = yield self.client.put_json(
+ destination=destination,
+ path=path,
+ data=content,
+ # we want to do our best to send this through. The problem is
+ # that if it fails, we won't retry it later, so if the remote
+ # server was just having a momentary blip, the room will be out of
+ # sync.
+ ignore_backoff=True,
+ )
+
+ return response
+
+ @defer.inlineCallbacks
+ @log_function
def send_invite_v1(self, destination, room_id, event_id, content):
path = _create_v1_path("/invite/%s/%s", room_id, event_id)
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index fefc789c85..b4cbf23394 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -506,11 +506,21 @@ class FederationMakeLeaveServlet(BaseFederationServlet):
return 200, content
-class FederationSendLeaveServlet(BaseFederationServlet):
+class FederationV1SendLeaveServlet(BaseFederationServlet):
PATH = "/send_leave/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
async def on_PUT(self, origin, content, query, room_id, event_id):
content = await self.handler.on_send_leave_request(origin, content, room_id)
+ return 200, (200, content)
+
+
+class FederationV2SendLeaveServlet(BaseFederationServlet):
+ PATH = "/send_leave/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
+
+ PREFIX = FEDERATION_V2_PREFIX
+
+ async def on_PUT(self, origin, content, query, room_id, event_id):
+ content = await self.handler.on_send_leave_request(origin, content, room_id)
return 200, content
@@ -521,9 +531,21 @@ class FederationEventAuthServlet(BaseFederationServlet):
return await self.handler.on_event_auth(origin, context, event_id)
-class FederationSendJoinServlet(BaseFederationServlet):
+class FederationV1SendJoinServlet(BaseFederationServlet):
+ PATH = "/send_join/(?P<context>[^/]*)/(?P<event_id>[^/]*)"
+
+ async def on_PUT(self, origin, content, query, context, event_id):
+ # TODO(paul): assert that context/event_id parsed from path actually
+ # match those given in content
+ content = await self.handler.on_send_join_request(origin, content, context)
+ return 200, (200, content)
+
+
+class FederationV2SendJoinServlet(BaseFederationServlet):
PATH = "/send_join/(?P<context>[^/]*)/(?P<event_id>[^/]*)"
+ PREFIX = FEDERATION_V2_PREFIX
+
async def on_PUT(self, origin, content, query, context, event_id):
# TODO(paul): assert that context/event_id parsed from path actually
# match those given in content
@@ -1367,8 +1389,10 @@ FEDERATION_SERVLET_CLASSES = (
FederationMakeJoinServlet,
FederationMakeLeaveServlet,
FederationEventServlet,
- FederationSendJoinServlet,
- FederationSendLeaveServlet,
+ FederationV1SendJoinServlet,
+ FederationV2SendJoinServlet,
+ FederationV1SendLeaveServlet,
+ FederationV2SendLeaveServlet,
FederationV1InviteServlet,
FederationV2InviteServlet,
FederationQueryAuthServlet,
diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py
index 2d7e6df6e4..20ec1ca01b 100644
--- a/synapse/handlers/account_data.py
+++ b/synapse/handlers/account_data.py
@@ -13,8 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.internet import defer
-
class AccountDataEventSource(object):
def __init__(self, hs):
@@ -23,15 +21,14 @@ class AccountDataEventSource(object):
def get_current_key(self, direction="f"):
return self.store.get_max_account_data_stream_id()
- @defer.inlineCallbacks
- def get_new_events(self, user, from_key, **kwargs):
+ async def get_new_events(self, user, from_key, **kwargs):
user_id = user.to_string()
last_stream_id = from_key
- current_stream_id = yield self.store.get_max_account_data_stream_id()
+ current_stream_id = self.store.get_max_account_data_stream_id()
results = []
- tags = yield self.store.get_updated_tags(user_id, last_stream_id)
+ tags = await self.store.get_updated_tags(user_id, last_stream_id)
for room_id, room_tags in tags.items():
results.append(
@@ -41,7 +38,7 @@ class AccountDataEventSource(object):
(
account_data,
room_account_data,
- ) = yield self.store.get_updated_account_data_for_user(user_id, last_stream_id)
+ ) = await self.store.get_updated_account_data_for_user(user_id, last_stream_id)
for account_data_type, content in account_data.items():
results.append({"type": account_data_type, "content": content})
@@ -54,6 +51,5 @@ class AccountDataEventSource(object):
return results, current_stream_id
- @defer.inlineCallbacks
- def get_pagination_rows(self, user, config, key):
+ async def get_pagination_rows(self, user, config, key):
return [], config.to_id
diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py
index d04e0fe576..829f52eca1 100644
--- a/synapse/handlers/account_validity.py
+++ b/synapse/handlers/account_validity.py
@@ -18,8 +18,7 @@ import email.utils
import logging
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
-
-from twisted.internet import defer
+from typing import List
from synapse.api.errors import StoreError
from synapse.logging.context import make_deferred_yieldable
@@ -78,42 +77,39 @@ class AccountValidityHandler(object):
# run as a background process to make sure that the database transactions
# have a logcontext to report to
return run_as_background_process(
- "send_renewals", self.send_renewal_emails
+ "send_renewals", self._send_renewal_emails
)
self.clock.looping_call(send_emails, 30 * 60 * 1000)
- @defer.inlineCallbacks
- def send_renewal_emails(self):
+ async def _send_renewal_emails(self):
"""Gets the list of users whose account is expiring in the amount of time
configured in the ``renew_at`` parameter from the ``account_validity``
configuration, and sends renewal emails to all of these users as long as they
have an email 3PID attached to their account.
"""
- expiring_users = yield self.store.get_users_expiring_soon()
+ expiring_users = await self.store.get_users_expiring_soon()
if expiring_users:
for user in expiring_users:
- yield self._send_renewal_email(
+ await self._send_renewal_email(
user_id=user["user_id"], expiration_ts=user["expiration_ts_ms"]
)
- @defer.inlineCallbacks
- def send_renewal_email_to_user(self, user_id):
- expiration_ts = yield self.store.get_expiration_ts_for_user(user_id)
- yield self._send_renewal_email(user_id, expiration_ts)
+ async def send_renewal_email_to_user(self, user_id: str):
+ expiration_ts = await self.store.get_expiration_ts_for_user(user_id)
+ await self._send_renewal_email(user_id, expiration_ts)
- @defer.inlineCallbacks
- def _send_renewal_email(self, user_id, expiration_ts):
+ async def _send_renewal_email(self, user_id: str, expiration_ts: int):
"""Sends out a renewal email to every email address attached to the given user
with a unique link allowing them to renew their account.
Args:
- user_id (str): ID of the user to send email(s) to.
- expiration_ts (int): Timestamp in milliseconds for the expiration date of
+ user_id: ID of the user to send email(s) to.
+ expiration_ts: Timestamp in milliseconds for the expiration date of
this user's account (used in the email templates).
"""
- addresses = yield self._get_email_addresses_for_user(user_id)
+ addresses = await self._get_email_addresses_for_user(user_id)
# Stop right here if the user doesn't have at least one email address.
# In this case, they will have to ask their server admin to renew their
@@ -125,7 +121,7 @@ class AccountValidityHandler(object):
return
try:
- user_display_name = yield self.store.get_profile_displayname(
+ user_display_name = await self.store.get_profile_displayname(
UserID.from_string(user_id).localpart
)
if user_display_name is None:
@@ -133,7 +129,7 @@ class AccountValidityHandler(object):
except StoreError:
user_display_name = user_id
- renewal_token = yield self._get_renewal_token(user_id)
+ renewal_token = await self._get_renewal_token(user_id)
url = "%s_matrix/client/unstable/account_validity/renew?token=%s" % (
self.hs.config.public_baseurl,
renewal_token,
@@ -165,7 +161,7 @@ class AccountValidityHandler(object):
logger.info("Sending renewal email to %s", address)
- yield make_deferred_yieldable(
+ await make_deferred_yieldable(
self.sendmail(
self.hs.config.email_smtp_host,
self._raw_from,
@@ -180,19 +176,18 @@ class AccountValidityHandler(object):
)
)
- yield self.store.set_renewal_mail_status(user_id=user_id, email_sent=True)
+ await self.store.set_renewal_mail_status(user_id=user_id, email_sent=True)
- @defer.inlineCallbacks
- def _get_email_addresses_for_user(self, user_id):
+ async def _get_email_addresses_for_user(self, user_id: str) -> List[str]:
"""Retrieve the list of email addresses attached to a user's account.
Args:
- user_id (str): ID of the user to lookup email addresses for.
+ user_id: ID of the user to lookup email addresses for.
Returns:
- defer.Deferred[list[str]]: Email addresses for this account.
+ Email addresses for this account.
"""
- threepids = yield self.store.user_get_threepids(user_id)
+ threepids = await self.store.user_get_threepids(user_id)
addresses = []
for threepid in threepids:
@@ -201,16 +196,15 @@ class AccountValidityHandler(object):
return addresses
- @defer.inlineCallbacks
- def _get_renewal_token(self, user_id):
+ async def _get_renewal_token(self, user_id: str) -> str:
"""Generates a 32-byte long random string that will be inserted into the
user's renewal email's unique link, then saves it into the database.
Args:
- user_id (str): ID of the user to generate a string for.
+ user_id: ID of the user to generate a string for.
Returns:
- defer.Deferred[str]: The generated string.
+ The generated string.
Raises:
StoreError(500): Couldn't generate a unique string after 5 attempts.
@@ -219,52 +213,52 @@ class AccountValidityHandler(object):
while attempts < 5:
try:
renewal_token = stringutils.random_string(32)
- yield self.store.set_renewal_token_for_user(user_id, renewal_token)
+ await self.store.set_renewal_token_for_user(user_id, renewal_token)
return renewal_token
except StoreError:
attempts += 1
raise StoreError(500, "Couldn't generate a unique string as refresh string.")
- @defer.inlineCallbacks
- def renew_account(self, renewal_token):
+ async def renew_account(self, renewal_token: str) -> bool:
"""Renews the account attached to a given renewal token by pushing back the
expiration date by the current validity period in the server's configuration.
Args:
- renewal_token (str): Token sent with the renewal request.
+ renewal_token: Token sent with the renewal request.
Returns:
- bool: Whether the provided token is valid.
+ Whether the provided token is valid.
"""
try:
- user_id = yield self.store.get_user_from_renewal_token(renewal_token)
+ user_id = await self.store.get_user_from_renewal_token(renewal_token)
except StoreError:
- defer.returnValue(False)
+ return False
logger.debug("Renewing an account for user %s", user_id)
- yield self.renew_account_for_user(user_id)
+ await self.renew_account_for_user(user_id)
- defer.returnValue(True)
+ return True
- @defer.inlineCallbacks
- def renew_account_for_user(self, user_id, expiration_ts=None, email_sent=False):
+ async def renew_account_for_user(
+ self, user_id: str, expiration_ts: int = None, email_sent: bool = False
+ ) -> int:
"""Renews the account attached to a given user by pushing back the
expiration date by the current validity period in the server's
configuration.
Args:
- renewal_token (str): Token sent with the renewal request.
- expiration_ts (int): New expiration date. Defaults to now + validity period.
- email_sent (bool): Whether an email has been sent for this validity period.
+ renewal_token: Token sent with the renewal request.
+ expiration_ts: New expiration date. Defaults to now + validity period.
+ email_sen: Whether an email has been sent for this validity period.
Defaults to False.
Returns:
- defer.Deferred[int]: New expiration date for this account, as a timestamp
- in milliseconds since epoch.
+ New expiration date for this account, as a timestamp in
+ milliseconds since epoch.
"""
if expiration_ts is None:
expiration_ts = self.clock.time_msec() + self._account_validity.period
- yield self.store.set_account_validity_for_user(
+ await self.store.set_account_validity_for_user(
user_id=user_id, expiration_ts=expiration_ts, email_sent=email_sent
)
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 57a10daefd..2d889364d4 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -264,6 +264,7 @@ class E2eKeysHandler(object):
return ret
+ @defer.inlineCallbacks
def get_cross_signing_keys_from_cache(self, query, from_user_id):
"""Get cross-signing keys for users from the database
@@ -283,14 +284,32 @@ class E2eKeysHandler(object):
self_signing_keys = {}
user_signing_keys = {}
- # Currently a stub, implementation coming in https://github.com/matrix-org/synapse/pull/6486
- return defer.succeed(
- {
- "master_keys": master_keys,
- "self_signing_keys": self_signing_keys,
- "user_signing_keys": user_signing_keys,
- }
- )
+ user_ids = list(query)
+
+ keys = yield self.store.get_e2e_cross_signing_keys_bulk(user_ids, from_user_id)
+
+ for user_id, user_info in keys.items():
+ if user_info is None:
+ continue
+ if "master" in user_info:
+ master_keys[user_id] = user_info["master"]
+ if "self_signing" in user_info:
+ self_signing_keys[user_id] = user_info["self_signing"]
+
+ if (
+ from_user_id in keys
+ and keys[from_user_id] is not None
+ and "user_signing" in keys[from_user_id]
+ ):
+ # users can see other users' master and self-signing keys, but can
+ # only see their own user-signing keys
+ user_signing_keys[from_user_id] = keys[from_user_id]["user_signing"]
+
+ return {
+ "master_keys": master_keys,
+ "self_signing_keys": self_signing_keys,
+ "user_signing_keys": user_signing_keys,
+ }
@trace
@defer.inlineCallbacks
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index bc26921768..2ea69c5468 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -19,7 +19,7 @@
import itertools
import logging
-from typing import Dict, Iterable, Optional, Sequence, Tuple
+from typing import Dict, Iterable, List, Optional, Sequence, Tuple
import six
from six import iteritems, itervalues
@@ -63,8 +63,9 @@ from synapse.replication.http.federation import (
)
from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet
from synapse.state import StateResolutionStore, resolve_events_with_store
+from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
from synapse.types import UserID, get_domain_from_id
-from synapse.util import unwrapFirstError
+from synapse.util import batch_iter, unwrapFirstError
from synapse.util.async_helpers import Linearizer
from synapse.util.distributor import user_joined_room
from synapse.util.retryutils import NotRetryingDestination
@@ -164,8 +165,7 @@ class FederationHandler(BaseHandler):
self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
- @defer.inlineCallbacks
- def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False):
+ async def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False) -> None:
""" Process a PDU received via a federation /send/ transaction, or
via backfill of missing prev_events
@@ -175,17 +175,15 @@ class FederationHandler(BaseHandler):
pdu (FrozenEvent): received PDU
sent_to_us_directly (bool): True if this event was pushed to us; False if
we pulled it as the result of a missing prev_event.
-
- Returns (Deferred): completes with None
"""
room_id = pdu.room_id
event_id = pdu.event_id
- logger.info("[%s %s] handling received PDU: %s", room_id, event_id, pdu)
+ logger.info("handling received PDU: %s", pdu)
# We reprocess pdus when we have seen them only as outliers
- existing = yield self.store.get_event(
+ existing = await self.store.get_event(
event_id, allow_none=True, allow_rejected=True
)
@@ -229,7 +227,7 @@ class FederationHandler(BaseHandler):
#
# Note that if we were never in the room then we would have already
# dropped the event, since we wouldn't know the room version.
- is_in_room = yield self.auth.check_host_in_room(room_id, self.server_name)
+ is_in_room = await self.auth.check_host_in_room(room_id, self.server_name)
if not is_in_room:
logger.info(
"[%s %s] Ignoring PDU from %s as we're not in the room",
@@ -245,12 +243,12 @@ class FederationHandler(BaseHandler):
# Get missing pdus if necessary.
if not pdu.internal_metadata.is_outlier():
# We only backfill backwards to the min depth.
- min_depth = yield self.get_min_depth_for_context(pdu.room_id)
+ min_depth = await self.get_min_depth_for_context(pdu.room_id)
logger.debug("[%s %s] min_depth: %d", room_id, event_id, min_depth)
prevs = set(pdu.prev_event_ids())
- seen = yield self.store.have_seen_events(prevs)
+ seen = await self.store.have_seen_events(prevs)
if min_depth and pdu.depth < min_depth:
# This is so that we don't notify the user about this
@@ -270,7 +268,7 @@ class FederationHandler(BaseHandler):
len(missing_prevs),
shortstr(missing_prevs),
)
- with (yield self._room_pdu_linearizer.queue(pdu.room_id)):
+ with (await self._room_pdu_linearizer.queue(pdu.room_id)):
logger.info(
"[%s %s] Acquired room lock to fetch %d missing prev_events",
room_id,
@@ -278,13 +276,19 @@ class FederationHandler(BaseHandler):
len(missing_prevs),
)
- yield self._get_missing_events_for_pdu(
- origin, pdu, prevs, min_depth
- )
+ try:
+ await self._get_missing_events_for_pdu(
+ origin, pdu, prevs, min_depth
+ )
+ except Exception as e:
+ raise Exception(
+ "Error fetching missing prev_events for %s: %s"
+ % (event_id, e)
+ )
# Update the set of things we've seen after trying to
# fetch the missing stuff
- seen = yield self.store.have_seen_events(prevs)
+ seen = await self.store.have_seen_events(prevs)
if not prevs - seen:
logger.info(
@@ -292,14 +296,6 @@ class FederationHandler(BaseHandler):
room_id,
event_id,
)
- elif missing_prevs:
- logger.info(
- "[%s %s] Not recursively fetching %d missing prev_events: %s",
- room_id,
- event_id,
- len(missing_prevs),
- shortstr(missing_prevs),
- )
if prevs - seen:
# We've still not been able to get all of the prev_events for this event.
@@ -344,13 +340,19 @@ class FederationHandler(BaseHandler):
affected=pdu.event_id,
)
+ logger.info(
+ "Event %s is missing prev_events: calculating state for a "
+ "backwards extremity",
+ event_id,
+ )
+
# Calculate the state after each of the previous events, and
# resolve them to find the correct state at the current event.
auth_chains = set()
event_map = {event_id: pdu}
try:
# Get the state of the events we know about
- ours = yield self.state_store.get_state_groups_ids(room_id, seen)
+ ours = await self.state_store.get_state_groups_ids(room_id, seen)
# state_maps is a list of mappings from (type, state_key) to event_id
state_maps = list(
@@ -364,13 +366,10 @@ class FederationHandler(BaseHandler):
# know about
for p in prevs - seen:
logger.info(
- "[%s %s] Requesting state at missing prev_event %s",
- room_id,
- event_id,
- p,
+ "Requesting state at missing prev_event %s", event_id,
)
- room_version = yield self.store.get_room_version(room_id)
+ room_version = await self.store.get_room_version(room_id)
with nested_logging_context(p):
# note that if any of the missing prevs share missing state or
@@ -379,24 +378,10 @@ class FederationHandler(BaseHandler):
(
remote_state,
got_auth_chain,
- ) = yield self.federation_client.get_state_for_room(
- origin, room_id, p
+ ) = await self._get_state_for_room(
+ origin, room_id, p, include_event_in_state=True
)
- # we want the state *after* p; get_state_for_room returns the
- # state *before* p.
- remote_event = yield self.federation_client.get_pdu(
- [origin], p, room_version, outlier=True
- )
-
- if remote_event is None:
- raise Exception(
- "Unable to get missing prev_event %s" % (p,)
- )
-
- if remote_event.is_state():
- remote_state.append(remote_event)
-
# XXX hrm I'm not convinced that duplicate events will compare
# for equality, so I'm not sure this does what the author
# hoped.
@@ -410,7 +395,7 @@ class FederationHandler(BaseHandler):
for x in remote_state:
event_map[x.event_id] = x
- state_map = yield resolve_events_with_store(
+ state_map = await resolve_events_with_store(
room_version,
state_maps,
event_map,
@@ -422,10 +407,10 @@ class FederationHandler(BaseHandler):
# First though we need to fetch all the events that are in
# state_map, so we can build up the state below.
- evs = yield self.store.get_events(
+ evs = await self.store.get_events(
list(state_map.values()),
get_prev_content=False,
- check_redacted=False,
+ redact_behaviour=EventRedactBehaviour.AS_IS,
)
event_map.update(evs)
@@ -446,12 +431,11 @@ class FederationHandler(BaseHandler):
affected=event_id,
)
- yield self._process_received_pdu(
+ await self._process_received_pdu(
origin, pdu, state=state, auth_chain=auth_chain
)
- @defer.inlineCallbacks
- def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth):
+ async def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth):
"""
Args:
origin (str): Origin of the pdu. Will be called to get the missing events
@@ -463,12 +447,12 @@ class FederationHandler(BaseHandler):
room_id = pdu.room_id
event_id = pdu.event_id
- seen = yield self.store.have_seen_events(prevs)
+ seen = await self.store.have_seen_events(prevs)
if not prevs - seen:
return
- latest = yield self.store.get_latest_event_ids_in_room(room_id)
+ latest = await self.store.get_latest_event_ids_in_room(room_id)
# We add the prev events that we have seen to the latest
# list to ensure the remote server doesn't give them to us
@@ -532,7 +516,7 @@ class FederationHandler(BaseHandler):
# All that said: Let's try increasing the timout to 60s and see what happens.
try:
- missing_events = yield self.federation_client.get_missing_events(
+ missing_events = await self.federation_client.get_missing_events(
origin,
room_id,
earliest_events_ids=list(latest),
@@ -571,7 +555,7 @@ class FederationHandler(BaseHandler):
)
with nested_logging_context(ev.event_id):
try:
- yield self.on_receive_pdu(origin, ev, sent_to_us_directly=False)
+ await self.on_receive_pdu(origin, ev, sent_to_us_directly=False)
except FederationError as e:
if e.code == 403:
logger.warning(
@@ -583,8 +567,144 @@ class FederationHandler(BaseHandler):
else:
raise
- @defer.inlineCallbacks
- def _process_received_pdu(self, origin, event, state, auth_chain):
+ async def _get_state_for_room(
+ self,
+ destination: str,
+ room_id: str,
+ event_id: str,
+ include_event_in_state: bool = False,
+ ) -> Tuple[List[EventBase], List[EventBase]]:
+ """Requests all of the room state at a given event from a remote homeserver.
+
+ Args:
+ destination: The remote homeserver to query for the state.
+ room_id: The id of the room we're interested in.
+ event_id: The id of the event we want the state at.
+ include_event_in_state: if true, the event itself will be included in the
+ returned state event list.
+
+ Returns:
+ A list of events in the state, possibly including the event itself, and
+ a list of events in the auth chain for the given event.
+ """
+ (
+ state_event_ids,
+ auth_event_ids,
+ ) = await self.federation_client.get_room_state_ids(
+ destination, room_id, event_id=event_id
+ )
+
+ desired_events = set(state_event_ids + auth_event_ids)
+
+ if include_event_in_state:
+ desired_events.add(event_id)
+
+ event_map = await self._get_events_from_store_or_dest(
+ destination, room_id, desired_events
+ )
+
+ failed_to_fetch = desired_events - event_map.keys()
+ if failed_to_fetch:
+ logger.warning(
+ "Failed to fetch missing state/auth events for %s %s",
+ event_id,
+ failed_to_fetch,
+ )
+
+ remote_state = [
+ event_map[e_id] for e_id in state_event_ids if e_id in event_map
+ ]
+
+ if include_event_in_state:
+ remote_event = event_map.get(event_id)
+ if not remote_event:
+ raise Exception("Unable to get missing prev_event %s" % (event_id,))
+ if remote_event.is_state():
+ remote_state.append(remote_event)
+
+ auth_chain = [event_map[e_id] for e_id in auth_event_ids if e_id in event_map]
+ auth_chain.sort(key=lambda e: e.depth)
+
+ return remote_state, auth_chain
+
+ async def _get_events_from_store_or_dest(
+ self, destination: str, room_id: str, event_ids: Iterable[str]
+ ) -> Dict[str, EventBase]:
+ """Fetch events from a remote destination, checking if we already have them.
+
+ Args:
+ destination
+ room_id
+ event_ids
+
+ If we fail to fetch any of the events, a warning will be logged, and the event
+ will be omitted from the result. Likewise, any events which turn out not to
+ be in the given room.
+
+ Returns:
+ map from event_id to event
+ """
+ fetched_events = await self.store.get_events(event_ids, allow_rejected=True)
+
+ missing_events = set(event_ids) - fetched_events.keys()
+
+ if missing_events:
+ logger.debug(
+ "Fetching unknown state/auth events %s for room %s",
+ missing_events,
+ room_id,
+ )
+
+ room_version = await self.store.get_room_version(room_id)
+
+ # XXX 20 requests at once? really?
+ for batch in batch_iter(missing_events, 20):
+ deferreds = [
+ run_in_background(
+ self.federation_client.get_pdu,
+ destinations=[destination],
+ event_id=e_id,
+ room_version=room_version,
+ )
+ for e_id in batch
+ ]
+
+ res = await make_deferred_yieldable(
+ defer.DeferredList(deferreds, consumeErrors=True)
+ )
+
+ for success, result in res:
+ if success and result:
+ fetched_events[result.event_id] = result
+
+ # check for events which were in the wrong room.
+ #
+ # this can happen if a remote server claims that the state or
+ # auth_events at an event in room A are actually events in room B
+
+ bad_events = list(
+ (event_id, event.room_id)
+ for event_id, event in fetched_events.items()
+ if event.room_id != room_id
+ )
+
+ for bad_event_id, bad_room_id in bad_events:
+ # This is a bogus situation, but since we may only discover it a long time
+ # after it happened, we try our best to carry on, by just omitting the
+ # bad events from the returned auth/state set.
+ logger.warning(
+ "Remote server %s claims event %s in room %s is an auth/state "
+ "event in room %s",
+ destination,
+ bad_event_id,
+ bad_room_id,
+ room_id,
+ )
+ del fetched_events[bad_event_id]
+
+ return fetched_events
+
+ async def _process_received_pdu(self, origin, event, state, auth_chain):
""" Called when we have a new pdu. We need to do auth checks and put it
through the StateHandler.
"""
@@ -599,7 +719,7 @@ class FederationHandler(BaseHandler):
if auth_chain:
event_ids |= {e.event_id for e in auth_chain}
- seen_ids = yield self.store.have_seen_events(event_ids)
+ seen_ids = await self.store.have_seen_events(event_ids)
if state and auth_chain is not None:
# If we have any state or auth_chain given to us by the replication
@@ -626,18 +746,18 @@ class FederationHandler(BaseHandler):
event_id,
[e.event.event_id for e in event_infos],
)
- yield self._handle_new_events(origin, event_infos)
+ await self._handle_new_events(origin, event_infos)
try:
- context = yield self._handle_new_event(origin, event, state=state)
+ context = await self._handle_new_event(origin, event, state=state)
except AuthError as e:
raise FederationError("ERROR", e.code, e.msg, affected=event.event_id)
- room = yield self.store.get_room(room_id)
+ room = await self.store.get_room(room_id)
if not room:
try:
- yield self.store.store_room(
+ await self.store.store_room(
room_id=room_id, room_creator_user_id="", is_public=False
)
except StoreError:
@@ -650,11 +770,11 @@ class FederationHandler(BaseHandler):
# changing their profile info.
newly_joined = True
- prev_state_ids = yield context.get_prev_state_ids(self.store)
+ prev_state_ids = await context.get_prev_state_ids(self.store)
prev_state_id = prev_state_ids.get((event.type, event.state_key))
if prev_state_id:
- prev_state = yield self.store.get_event(
+ prev_state = await self.store.get_event(
prev_state_id, allow_none=True
)
if prev_state and prev_state.membership == Membership.JOIN:
@@ -662,11 +782,10 @@ class FederationHandler(BaseHandler):
if newly_joined:
user = UserID.from_string(event.state_key)
- yield self.user_joined_room(user, room_id)
+ await self.user_joined_room(user, room_id)
@log_function
- @defer.inlineCallbacks
- def backfill(self, dest, room_id, limit, extremities):
+ async def backfill(self, dest, room_id, limit, extremities):
""" Trigger a backfill request to `dest` for the given `room_id`
This will attempt to get more events from the remote. If the other side
@@ -683,9 +802,9 @@ class FederationHandler(BaseHandler):
if dest == self.server_name:
raise SynapseError(400, "Can't backfill from self.")
- room_version = yield self.store.get_room_version(room_id)
+ room_version = await self.store.get_room_version(room_id)
- events = yield self.federation_client.backfill(
+ events = await self.federation_client.backfill(
dest, room_id, limit=limit, extremities=extremities
)
@@ -700,7 +819,7 @@ class FederationHandler(BaseHandler):
# self._sanity_check_event(ev)
# Don't bother processing events we already have.
- seen_events = yield self.store.have_events_in_timeline(
+ seen_events = await self.store.have_events_in_timeline(
set(e.event_id for e in events)
)
@@ -723,7 +842,7 @@ class FederationHandler(BaseHandler):
state_events = {}
events_to_state = {}
for e_id in edges:
- state, auth = yield self.federation_client.get_state_for_room(
+ state, auth = await self._get_state_for_room(
destination=dest, room_id=room_id, event_id=e_id
)
auth_events.update({a.event_id: a for a in auth})
@@ -748,7 +867,7 @@ class FederationHandler(BaseHandler):
# We repeatedly do this until we stop finding new auth events.
while missing_auth - failed_to_fetch:
logger.info("Missing auth for backfill: %r", missing_auth)
- ret_events = yield self.store.get_events(missing_auth - failed_to_fetch)
+ ret_events = await self.store.get_events(missing_auth - failed_to_fetch)
auth_events.update(ret_events)
required_auth.update(
@@ -762,7 +881,7 @@ class FederationHandler(BaseHandler):
missing_auth - failed_to_fetch,
)
- results = yield make_deferred_yieldable(
+ results = await make_deferred_yieldable(
defer.gatherResults(
[
run_in_background(
@@ -789,7 +908,7 @@ class FederationHandler(BaseHandler):
failed_to_fetch = missing_auth - set(auth_events)
- seen_events = yield self.store.have_seen_events(
+ seen_events = await self.store.have_seen_events(
set(auth_events.keys()) | set(state_events.keys())
)
@@ -851,7 +970,7 @@ class FederationHandler(BaseHandler):
)
)
- yield self._handle_new_events(dest, ev_infos, backfilled=True)
+ await self._handle_new_events(dest, ev_infos, backfilled=True)
# Step 2: Persist the rest of the events in the chunk one by one
events.sort(key=lambda e: e.depth)
@@ -867,16 +986,15 @@ class FederationHandler(BaseHandler):
# We store these one at a time since each event depends on the
# previous to work out the state.
# TODO: We can probably do something more clever here.
- yield self._handle_new_event(dest, event, backfilled=True)
+ await self._handle_new_event(dest, event, backfilled=True)
return events
- @defer.inlineCallbacks
- def maybe_backfill(self, room_id, current_depth):
+ async def maybe_backfill(self, room_id, current_depth):
"""Checks the database to see if we should backfill before paginating,
and if so do.
"""
- extremities = yield self.store.get_oldest_events_with_depth_in_room(room_id)
+ extremities = await self.store.get_oldest_events_with_depth_in_room(room_id)
if not extremities:
logger.debug("Not backfilling as no extremeties found.")
@@ -908,15 +1026,17 @@ class FederationHandler(BaseHandler):
# state *before* the event, ignoring the special casing certain event
# types have.
- forward_events = yield self.store.get_successor_events(list(extremities))
+ forward_events = await self.store.get_successor_events(list(extremities))
- extremities_events = yield self.store.get_events(
- forward_events, check_redacted=False, get_prev_content=False
+ extremities_events = await self.store.get_events(
+ forward_events,
+ redact_behaviour=EventRedactBehaviour.AS_IS,
+ get_prev_content=False,
)
# We set `check_history_visibility_only` as we might otherwise get false
# positives from users having been erased.
- filtered_extremities = yield filter_events_for_server(
+ filtered_extremities = await filter_events_for_server(
self.storage,
self.server_name,
list(extremities_events.values()),
@@ -946,7 +1066,7 @@ class FederationHandler(BaseHandler):
# First we try hosts that are already in the room
# TODO: HEURISTIC ALERT.
- curr_state = yield self.state_handler.get_current_state(room_id)
+ curr_state = await self.state_handler.get_current_state(room_id)
def get_domains_from_state(state):
"""Get joined domains from state
@@ -985,12 +1105,11 @@ class FederationHandler(BaseHandler):
domain for domain, depth in curr_domains if domain != self.server_name
]
- @defer.inlineCallbacks
- def try_backfill(domains):
+ async def try_backfill(domains):
# TODO: Should we try multiple of these at a time?
for dom in domains:
try:
- yield self.backfill(
+ await self.backfill(
dom, room_id, limit=100, extremities=extremities
)
# If this succeeded then we probably already have the
@@ -1021,7 +1140,7 @@ class FederationHandler(BaseHandler):
return False
- success = yield try_backfill(likely_domains)
+ success = await try_backfill(likely_domains)
if success:
return True
@@ -1035,7 +1154,7 @@ class FederationHandler(BaseHandler):
logger.debug("calling resolve_state_groups in _maybe_backfill")
resolve = preserve_fn(self.state_handler.resolve_state_groups_for_events)
- states = yield make_deferred_yieldable(
+ states = await make_deferred_yieldable(
defer.gatherResults(
[resolve(room_id, [e]) for e in event_ids], consumeErrors=True
)
@@ -1045,7 +1164,7 @@ class FederationHandler(BaseHandler):
# event_ids.
states = dict(zip(event_ids, [s.state for s in states]))
- state_map = yield self.store.get_events(
+ state_map = await self.store.get_events(
[e_id for ids in itervalues(states) for e_id in itervalues(ids)],
get_prev_content=False,
)
@@ -1061,7 +1180,7 @@ class FederationHandler(BaseHandler):
for e_id, _ in sorted_extremeties_tuple:
likely_domains = get_domains_from_state(states[e_id])
- success = yield try_backfill(
+ success = await try_backfill(
[dom for dom, _ in likely_domains if dom not in tried_domains]
)
if success:
@@ -1210,7 +1329,7 @@ class FederationHandler(BaseHandler):
# Check whether this room is the result of an upgrade of a room we already know
# about. If so, migrate over user information
predecessor = yield self.store.get_room_predecessor(room_id)
- if not predecessor:
+ if not predecessor or not isinstance(predecessor.get("room_id"), str):
return
old_room_id = predecessor["room_id"]
logger.debug(
@@ -1238,8 +1357,7 @@ class FederationHandler(BaseHandler):
return True
- @defer.inlineCallbacks
- def _handle_queued_pdus(self, room_queue):
+ async def _handle_queued_pdus(self, room_queue):
"""Process PDUs which got queued up while we were busy send_joining.
Args:
@@ -1255,7 +1373,7 @@ class FederationHandler(BaseHandler):
p.room_id,
)
with nested_logging_context(p.event_id):
- yield self.on_receive_pdu(origin, p, sent_to_us_directly=True)
+ await self.on_receive_pdu(origin, p, sent_to_us_directly=True)
except Exception as e:
logger.warning(
"Error handling queued PDU %s from %s: %s", p.event_id, origin, e
@@ -1453,7 +1571,7 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks
def do_remotely_reject_invite(self, target_hosts, room_id, user_id, content):
origin, event, event_format_version = yield self._make_and_verify_event(
- target_hosts, room_id, user_id, "leave", content=content,
+ target_hosts, room_id, user_id, "leave", content=content
)
# Mark as outlier as we don't have any state for this event; we're not
# even in the room.
@@ -2814,7 +2932,7 @@ class FederationHandler(BaseHandler):
room_id=room_id, user_id=user.to_string(), change="joined"
)
else:
- return user_joined_room(self.distributor, user, room_id)
+ return defer.succeed(user_joined_room(self.distributor, user, room_id))
@defer.inlineCallbacks
def get_room_complexity(self, remote_room_hosts, room_id):
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index 81dce96f4b..44ec3e66ae 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -26,7 +26,7 @@ from synapse.streams.config import PaginationConfig
from synapse.types import StreamToken, UserID
from synapse.util import unwrapFirstError
from synapse.util.async_helpers import concurrently_execute
-from synapse.util.caches.snapshot_cache import SnapshotCache
+from synapse.util.caches.response_cache import ResponseCache
from synapse.visibility import filter_events_for_client
from ._base import BaseHandler
@@ -41,7 +41,7 @@ class InitialSyncHandler(BaseHandler):
self.state = hs.get_state_handler()
self.clock = hs.get_clock()
self.validator = EventValidator()
- self.snapshot_cache = SnapshotCache()
+ self.snapshot_cache = ResponseCache(hs, "initial_sync_cache")
self._event_serializer = hs.get_event_client_serializer()
self.storage = hs.get_storage()
self.state_store = self.storage.state
@@ -79,21 +79,17 @@ class InitialSyncHandler(BaseHandler):
as_client_event,
include_archived,
)
- now_ms = self.clock.time_msec()
- result = self.snapshot_cache.get(now_ms, key)
- if result is not None:
- return result
- return self.snapshot_cache.set(
- now_ms,
+ return self.snapshot_cache.wrap(
key,
- self._snapshot_all_rooms(
- user_id, pagin_config, as_client_event, include_archived
- ),
+ self._snapshot_all_rooms,
+ user_id,
+ pagin_config,
+ as_client_event,
+ include_archived,
)
- @defer.inlineCallbacks
- def _snapshot_all_rooms(
+ async def _snapshot_all_rooms(
self,
user_id=None,
pagin_config=None,
@@ -105,7 +101,7 @@ class InitialSyncHandler(BaseHandler):
if include_archived:
memberships.append(Membership.LEAVE)
- room_list = yield self.store.get_rooms_for_user_where_membership_is(
+ room_list = await self.store.get_rooms_for_user_where_membership_is(
user_id=user_id, membership_list=memberships
)
@@ -113,33 +109,32 @@ class InitialSyncHandler(BaseHandler):
rooms_ret = []
- now_token = yield self.hs.get_event_sources().get_current_token()
+ now_token = await self.hs.get_event_sources().get_current_token()
presence_stream = self.hs.get_event_sources().sources["presence"]
pagination_config = PaginationConfig(from_token=now_token)
- presence, _ = yield presence_stream.get_pagination_rows(
+ presence, _ = await presence_stream.get_pagination_rows(
user, pagination_config.get_source_config("presence"), None
)
receipt_stream = self.hs.get_event_sources().sources["receipt"]
- receipt, _ = yield receipt_stream.get_pagination_rows(
+ receipt, _ = await receipt_stream.get_pagination_rows(
user, pagination_config.get_source_config("receipt"), None
)
- tags_by_room = yield self.store.get_tags_for_user(user_id)
+ tags_by_room = await self.store.get_tags_for_user(user_id)
- account_data, account_data_by_room = yield self.store.get_account_data_for_user(
+ account_data, account_data_by_room = await self.store.get_account_data_for_user(
user_id
)
- public_room_ids = yield self.store.get_public_room_ids()
+ public_room_ids = await self.store.get_public_room_ids()
limit = pagin_config.limit
if limit is None:
limit = 10
- @defer.inlineCallbacks
- def handle_room(event):
+ async def handle_room(event):
d = {
"room_id": event.room_id,
"membership": event.membership,
@@ -152,8 +147,8 @@ class InitialSyncHandler(BaseHandler):
time_now = self.clock.time_msec()
d["inviter"] = event.sender
- invite_event = yield self.store.get_event(event.event_id)
- d["invite"] = yield self._event_serializer.serialize_event(
+ invite_event = await self.store.get_event(event.event_id)
+ d["invite"] = await self._event_serializer.serialize_event(
invite_event, time_now, as_client_event
)
@@ -177,7 +172,7 @@ class InitialSyncHandler(BaseHandler):
lambda states: states[event.event_id]
)
- (messages, token), current_state = yield make_deferred_yieldable(
+ (messages, token), current_state = await make_deferred_yieldable(
defer.gatherResults(
[
run_in_background(
@@ -191,7 +186,7 @@ class InitialSyncHandler(BaseHandler):
)
).addErrback(unwrapFirstError)
- messages = yield filter_events_for_client(
+ messages = await filter_events_for_client(
self.storage, user_id, messages
)
@@ -201,7 +196,7 @@ class InitialSyncHandler(BaseHandler):
d["messages"] = {
"chunk": (
- yield self._event_serializer.serialize_events(
+ await self._event_serializer.serialize_events(
messages, time_now=time_now, as_client_event=as_client_event
)
),
@@ -209,7 +204,7 @@ class InitialSyncHandler(BaseHandler):
"end": end_token.to_string(),
}
- d["state"] = yield self._event_serializer.serialize_events(
+ d["state"] = await self._event_serializer.serialize_events(
current_state.values(),
time_now=time_now,
as_client_event=as_client_event,
@@ -232,7 +227,7 @@ class InitialSyncHandler(BaseHandler):
except Exception:
logger.exception("Failed to get snapshot")
- yield concurrently_execute(handle_room, room_list, 10)
+ await concurrently_execute(handle_room, room_list, 10)
account_data_events = []
for account_data_type, content in account_data.items():
@@ -256,8 +251,7 @@ class InitialSyncHandler(BaseHandler):
return ret
- @defer.inlineCallbacks
- def room_initial_sync(self, requester, room_id, pagin_config=None):
+ async def room_initial_sync(self, requester, room_id, pagin_config=None):
"""Capture the a snapshot of a room. If user is currently a member of
the room this will be what is currently in the room. If the user left
the room this will be what was in the room when they left.
@@ -274,32 +268,32 @@ class InitialSyncHandler(BaseHandler):
A JSON serialisable dict with the snapshot of the room.
"""
- blocked = yield self.store.is_room_blocked(room_id)
+ blocked = await self.store.is_room_blocked(room_id)
if blocked:
raise SynapseError(403, "This room has been blocked on this server")
user_id = requester.user.to_string()
- membership, member_event_id = yield self._check_in_room_or_world_readable(
+ membership, member_event_id = await self._check_in_room_or_world_readable(
room_id, user_id
)
is_peeking = member_event_id is None
if membership == Membership.JOIN:
- result = yield self._room_initial_sync_joined(
+ result = await self._room_initial_sync_joined(
user_id, room_id, pagin_config, membership, is_peeking
)
elif membership == Membership.LEAVE:
- result = yield self._room_initial_sync_parted(
+ result = await self._room_initial_sync_parted(
user_id, room_id, pagin_config, membership, member_event_id, is_peeking
)
account_data_events = []
- tags = yield self.store.get_tags_for_room(user_id, room_id)
+ tags = await self.store.get_tags_for_room(user_id, room_id)
if tags:
account_data_events.append({"type": "m.tag", "content": {"tags": tags}})
- account_data = yield self.store.get_account_data_for_room(user_id, room_id)
+ account_data = await self.store.get_account_data_for_room(user_id, room_id)
for account_data_type, content in account_data.items():
account_data_events.append({"type": account_data_type, "content": content})
@@ -307,11 +301,10 @@ class InitialSyncHandler(BaseHandler):
return result
- @defer.inlineCallbacks
- def _room_initial_sync_parted(
+ async def _room_initial_sync_parted(
self, user_id, room_id, pagin_config, membership, member_event_id, is_peeking
):
- room_state = yield self.state_store.get_state_for_events([member_event_id])
+ room_state = await self.state_store.get_state_for_events([member_event_id])
room_state = room_state[member_event_id]
@@ -319,13 +312,13 @@ class InitialSyncHandler(BaseHandler):
if limit is None:
limit = 10
- stream_token = yield self.store.get_stream_token_for_event(member_event_id)
+ stream_token = await self.store.get_stream_token_for_event(member_event_id)
- messages, token = yield self.store.get_recent_events_for_room(
+ messages, token = await self.store.get_recent_events_for_room(
room_id, limit=limit, end_token=stream_token
)
- messages = yield filter_events_for_client(
+ messages = await filter_events_for_client(
self.storage, user_id, messages, is_peeking=is_peeking
)
@@ -339,13 +332,13 @@ class InitialSyncHandler(BaseHandler):
"room_id": room_id,
"messages": {
"chunk": (
- yield self._event_serializer.serialize_events(messages, time_now)
+ await self._event_serializer.serialize_events(messages, time_now)
),
"start": start_token.to_string(),
"end": end_token.to_string(),
},
"state": (
- yield self._event_serializer.serialize_events(
+ await self._event_serializer.serialize_events(
room_state.values(), time_now
)
),
@@ -353,19 +346,18 @@ class InitialSyncHandler(BaseHandler):
"receipts": [],
}
- @defer.inlineCallbacks
- def _room_initial_sync_joined(
+ async def _room_initial_sync_joined(
self, user_id, room_id, pagin_config, membership, is_peeking
):
- current_state = yield self.state.get_current_state(room_id=room_id)
+ current_state = await self.state.get_current_state(room_id=room_id)
# TODO: These concurrently
time_now = self.clock.time_msec()
- state = yield self._event_serializer.serialize_events(
+ state = await self._event_serializer.serialize_events(
current_state.values(), time_now
)
- now_token = yield self.hs.get_event_sources().get_current_token()
+ now_token = await self.hs.get_event_sources().get_current_token()
limit = pagin_config.limit if pagin_config else None
if limit is None:
@@ -380,28 +372,26 @@ class InitialSyncHandler(BaseHandler):
presence_handler = self.hs.get_presence_handler()
- @defer.inlineCallbacks
- def get_presence():
+ async def get_presence():
# If presence is disabled, return an empty list
if not self.hs.config.use_presence:
return []
- states = yield presence_handler.get_states(
+ states = await presence_handler.get_states(
[m.user_id for m in room_members], as_event=True
)
return states
- @defer.inlineCallbacks
- def get_receipts():
- receipts = yield self.store.get_linearized_receipts_for_room(
+ async def get_receipts():
+ receipts = await self.store.get_linearized_receipts_for_room(
room_id, to_key=now_token.receipt_key
)
if not receipts:
receipts = []
return receipts
- presence, receipts, (messages, token) = yield make_deferred_yieldable(
+ presence, receipts, (messages, token) = await make_deferred_yieldable(
defer.gatherResults(
[
run_in_background(get_presence),
@@ -417,7 +407,7 @@ class InitialSyncHandler(BaseHandler):
).addErrback(unwrapFirstError)
)
- messages = yield filter_events_for_client(
+ messages = await filter_events_for_client(
self.storage, user_id, messages, is_peeking=is_peeking
)
@@ -430,7 +420,7 @@ class InitialSyncHandler(BaseHandler):
"room_id": room_id,
"messages": {
"chunk": (
- yield self._event_serializer.serialize_events(messages, time_now)
+ await self._event_serializer.serialize_events(messages, time_now)
),
"start": start_token.to_string(),
"end": end_token.to_string(),
@@ -444,18 +434,17 @@ class InitialSyncHandler(BaseHandler):
return ret
- @defer.inlineCallbacks
- def _check_in_room_or_world_readable(self, room_id, user_id):
+ async def _check_in_room_or_world_readable(self, room_id, user_id):
try:
# check_user_was_in_room will return the most recent membership
# event for the user if:
# * The user is a non-guest user, and was ever in the room
# * The user is a guest user, and has joined the room
# else it will throw.
- member_event = yield self.auth.check_user_was_in_room(room_id, user_id)
+ member_event = await self.auth.check_user_was_in_room(room_id, user_id)
return member_event.membership, member_event.event_id
except AuthError:
- visibility = yield self.state_handler.get_current_state(
+ visibility = await self.state_handler.get_current_state(
room_id, EventTypes.RoomHistoryVisibility, ""
)
if (
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 54fa216d83..bf9add7fe2 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -46,6 +46,7 @@ from synapse.events.validator import EventValidator
from synapse.logging.context import run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.http.send_event import ReplicationSendEventRestServlet
+from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
from synapse.storage.state import StateFilter
from synapse.types import RoomAlias, UserID, create_requester
from synapse.util.async_helpers import Linearizer
@@ -875,7 +876,7 @@ class EventCreationHandler(object):
if event.type == EventTypes.Redaction:
original_event = yield self.store.get_event(
event.redacts,
- check_redacted=False,
+ redact_behaviour=EventRedactBehaviour.AS_IS,
get_prev_content=False,
allow_rejected=False,
allow_none=True,
@@ -952,7 +953,7 @@ class EventCreationHandler(object):
if event.type == EventTypes.Redaction:
original_event = yield self.store.get_event(
event.redacts,
- check_redacted=False,
+ redact_behaviour=EventRedactBehaviour.AS_IS,
get_prev_content=False,
allow_rejected=False,
allow_none=True,
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 8514ddc600..00a6afc963 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -280,8 +280,7 @@ class PaginationHandler(object):
await self.storage.purge_events.purge_room(room_id)
- @defer.inlineCallbacks
- def get_messages(
+ async def get_messages(
self,
requester,
room_id=None,
@@ -307,7 +306,7 @@ class PaginationHandler(object):
room_token = pagin_config.from_token.room_key
else:
pagin_config.from_token = (
- yield self.hs.get_event_sources().get_current_token_for_pagination()
+ await self.hs.get_event_sources().get_current_token_for_pagination()
)
room_token = pagin_config.from_token.room_key
@@ -319,11 +318,11 @@ class PaginationHandler(object):
source_config = pagin_config.get_source_config("room")
- with (yield self.pagination_lock.read(room_id)):
+ with (await self.pagination_lock.read(room_id)):
(
membership,
member_event_id,
- ) = yield self.auth.check_in_room_or_world_readable(room_id, user_id)
+ ) = await self.auth.check_in_room_or_world_readable(room_id, user_id)
if source_config.direction == "b":
# if we're going backwards, we might need to backfill. This
@@ -331,7 +330,7 @@ class PaginationHandler(object):
if room_token.topological:
max_topo = room_token.topological
else:
- max_topo = yield self.store.get_max_topological_token(
+ max_topo = await self.store.get_max_topological_token(
room_id, room_token.stream
)
@@ -339,18 +338,18 @@ class PaginationHandler(object):
# If they have left the room then clamp the token to be before
# they left the room, to save the effort of loading from the
# database.
- leave_token = yield self.store.get_topological_token_for_event(
+ leave_token = await self.store.get_topological_token_for_event(
member_event_id
)
leave_token = RoomStreamToken.parse(leave_token)
if leave_token.topological < max_topo:
source_config.from_key = str(leave_token)
- yield self.hs.get_handlers().federation_handler.maybe_backfill(
+ await self.hs.get_handlers().federation_handler.maybe_backfill(
room_id, max_topo
)
- events, next_key = yield self.store.paginate_room_events(
+ events, next_key = await self.store.paginate_room_events(
room_id=room_id,
from_key=source_config.from_key,
to_key=source_config.to_key,
@@ -365,7 +364,7 @@ class PaginationHandler(object):
if event_filter:
events = event_filter.filter(events)
- events = yield filter_events_for_client(
+ events = await filter_events_for_client(
self.storage, user_id, events, is_peeking=(member_event_id is None)
)
@@ -385,19 +384,19 @@ class PaginationHandler(object):
(EventTypes.Member, event.sender) for event in events
)
- state_ids = yield self.state_store.get_state_ids_for_event(
+ state_ids = await self.state_store.get_state_ids_for_event(
events[0].event_id, state_filter=state_filter
)
if state_ids:
- state = yield self.store.get_events(list(state_ids.values()))
+ state = await self.store.get_events(list(state_ids.values()))
state = state.values()
time_now = self.clock.time_msec()
chunk = {
"chunk": (
- yield self._event_serializer.serialize_events(
+ await self._event_serializer.serialize_events(
events, time_now, as_client_event=as_client_event
)
),
@@ -406,7 +405,7 @@ class PaginationHandler(object):
}
if state:
- chunk["state"] = yield self._event_serializer.serialize_events(
+ chunk["state"] = await self._event_serializer.serialize_events(
state, time_now, as_client_event=as_client_event
)
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index cc9e6b9bd0..0082f85c26 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -13,20 +13,36 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+import re
+from typing import Tuple
import attr
import saml2
+import saml2.response
from saml2.client import Saml2Client
from synapse.api.errors import SynapseError
+from synapse.config import ConfigError
from synapse.http.servlet import parse_string
from synapse.rest.client.v1.login import SSOAuthHandler
-from synapse.types import UserID, map_username_to_mxid_localpart
+from synapse.types import (
+ UserID,
+ map_username_to_mxid_localpart,
+ mxid_localpart_allowed_characters,
+)
from synapse.util.async_helpers import Linearizer
logger = logging.getLogger(__name__)
+@attr.s
+class Saml2SessionData:
+ """Data we track about SAML2 sessions"""
+
+ # time the session was created, in milliseconds
+ creation_time = attr.ib()
+
+
class SamlHandler:
def __init__(self, hs):
self._saml_client = Saml2Client(hs.config.saml2_sp_config)
@@ -37,11 +53,14 @@ class SamlHandler:
self._datastore = hs.get_datastore()
self._hostname = hs.hostname
self._saml2_session_lifetime = hs.config.saml2_session_lifetime
- self._mxid_source_attribute = hs.config.saml2_mxid_source_attribute
self._grandfathered_mxid_source_attribute = (
hs.config.saml2_grandfathered_mxid_source_attribute
)
- self._mxid_mapper = hs.config.saml2_mxid_mapper
+
+ # plugin to do custom mapping from saml response to mxid
+ self._user_mapping_provider = hs.config.saml2_user_mapping_provider_class(
+ hs.config.saml2_user_mapping_provider_config
+ )
# identifier for the external_ids table
self._auth_provider_id = "saml"
@@ -118,22 +137,10 @@ class SamlHandler:
remote_user_id = saml2_auth.ava["uid"][0]
except KeyError:
logger.warning("SAML2 response lacks a 'uid' attestation")
- raise SynapseError(400, "uid not in SAML2 response")
-
- try:
- mxid_source = saml2_auth.ava[self._mxid_source_attribute][0]
- except KeyError:
- logger.warning(
- "SAML2 response lacks a '%s' attestation", self._mxid_source_attribute
- )
- raise SynapseError(
- 400, "%s not in SAML2 response" % (self._mxid_source_attribute,)
- )
+ raise SynapseError(400, "'uid' not in SAML2 response")
self._outstanding_requests_dict.pop(saml2_auth.in_response_to, None)
- displayName = saml2_auth.ava.get("displayName", [None])[0]
-
with (await self._mapping_lock.queue(self._auth_provider_id)):
# first of all, check if we already have a mapping for this user
logger.info(
@@ -173,22 +180,46 @@ class SamlHandler:
)
return registered_user_id
- # figure out a new mxid for this user
- base_mxid_localpart = self._mxid_mapper(mxid_source)
+ # Map saml response to user attributes using the configured mapping provider
+ for i in range(1000):
+ attribute_dict = self._user_mapping_provider.saml_response_to_user_attributes(
+ saml2_auth, i
+ )
+
+ logger.debug(
+ "Retrieved SAML attributes from user mapping provider: %s "
+ "(attempt %d)",
+ attribute_dict,
+ i,
+ )
+
+ localpart = attribute_dict.get("mxid_localpart")
+ if not localpart:
+ logger.error(
+ "SAML mapping provider plugin did not return a "
+ "mxid_localpart object"
+ )
+ raise SynapseError(500, "Error parsing SAML2 response")
- suffix = 0
- while True:
- localpart = base_mxid_localpart + (str(suffix) if suffix else "")
+ displayname = attribute_dict.get("displayname")
+
+ # Check if this mxid already exists
if not await self._datastore.get_users_by_id_case_insensitive(
UserID(localpart, self._hostname).to_string()
):
+ # This mxid is free
break
- suffix += 1
- logger.info("Allocating mxid for new user with localpart %s", localpart)
+ else:
+ # Unable to generate a username in 1000 iterations
+ # Break and return error to the user
+ raise SynapseError(
+ 500, "Unable to generate a Matrix ID from the SAML response"
+ )
registered_user_id = await self._registration_handler.register_user(
- localpart=localpart, default_display_name=displayName
+ localpart=localpart, default_display_name=displayname
)
+
await self._datastore.record_user_external_id(
self._auth_provider_id, remote_user_id, registered_user_id
)
@@ -205,9 +236,120 @@ class SamlHandler:
del self._outstanding_requests_dict[reqid]
+DOT_REPLACE_PATTERN = re.compile(
+ ("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),))
+)
+
+
+def dot_replace_for_mxid(username: str) -> str:
+ username = username.lower()
+ username = DOT_REPLACE_PATTERN.sub(".", username)
+
+ # regular mxids aren't allowed to start with an underscore either
+ username = re.sub("^_", "", username)
+ return username
+
+
+MXID_MAPPER_MAP = {
+ "hexencode": map_username_to_mxid_localpart,
+ "dotreplace": dot_replace_for_mxid,
+}
+
+
@attr.s
-class Saml2SessionData:
- """Data we track about SAML2 sessions"""
+class SamlConfig(object):
+ mxid_source_attribute = attr.ib()
+ mxid_mapper = attr.ib()
- # time the session was created, in milliseconds
- creation_time = attr.ib()
+
+class DefaultSamlMappingProvider(object):
+ __version__ = "0.0.1"
+
+ def __init__(self, parsed_config: SamlConfig):
+ """The default SAML user mapping provider
+
+ Args:
+ parsed_config: Module configuration
+ """
+ self._mxid_source_attribute = parsed_config.mxid_source_attribute
+ self._mxid_mapper = parsed_config.mxid_mapper
+
+ def saml_response_to_user_attributes(
+ self, saml_response: saml2.response.AuthnResponse, failures: int = 0,
+ ) -> dict:
+ """Maps some text from a SAML response to attributes of a new user
+
+ Args:
+ saml_response: A SAML auth response object
+
+ failures: How many times a call to this function with this
+ saml_response has resulted in a failure
+
+ Returns:
+ dict: A dict containing new user attributes. Possible keys:
+ * mxid_localpart (str): Required. The localpart of the user's mxid
+ * displayname (str): The displayname of the user
+ """
+ try:
+ mxid_source = saml_response.ava[self._mxid_source_attribute][0]
+ except KeyError:
+ logger.warning(
+ "SAML2 response lacks a '%s' attestation", self._mxid_source_attribute,
+ )
+ raise SynapseError(
+ 400, "%s not in SAML2 response" % (self._mxid_source_attribute,)
+ )
+
+ # Use the configured mapper for this mxid_source
+ base_mxid_localpart = self._mxid_mapper(mxid_source)
+
+ # Append suffix integer if last call to this function failed to produce
+ # a usable mxid
+ localpart = base_mxid_localpart + (str(failures) if failures else "")
+
+ # Retrieve the display name from the saml response
+ # If displayname is None, the mxid_localpart will be used instead
+ displayname = saml_response.ava.get("displayName", [None])[0]
+
+ return {
+ "mxid_localpart": localpart,
+ "displayname": displayname,
+ }
+
+ @staticmethod
+ def parse_config(config: dict) -> SamlConfig:
+ """Parse the dict provided by the homeserver's config
+ Args:
+ config: A dictionary containing configuration options for this provider
+ Returns:
+ SamlConfig: A custom config object for this module
+ """
+ # Parse config options and use defaults where necessary
+ mxid_source_attribute = config.get("mxid_source_attribute", "uid")
+ mapping_type = config.get("mxid_mapping", "hexencode")
+
+ # Retrieve the associating mapping function
+ try:
+ mxid_mapper = MXID_MAPPER_MAP[mapping_type]
+ except KeyError:
+ raise ConfigError(
+ "saml2_config.user_mapping_provider.config: '%s' is not a valid "
+ "mxid_mapping value" % (mapping_type,)
+ )
+
+ return SamlConfig(mxid_source_attribute, mxid_mapper)
+
+ @staticmethod
+ def get_saml_attributes(config: SamlConfig) -> Tuple[set, set]:
+ """Returns the required attributes of a SAML
+
+ Args:
+ config: A SamlConfig object containing configuration params for this provider
+
+ Returns:
+ tuple[set,set]: The first set equates to the saml auth response
+ attributes that are required for the module to function, whereas the
+ second set consists of those attributes which can be used if
+ available, but are not necessary
+ """
+ return {"uid", config.mxid_source_attribute}, {"displayName"}
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index 56ed262a1f..ef750d1497 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -21,7 +21,7 @@ from unpaddedbase64 import decode_base64, encode_base64
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
-from synapse.api.errors import SynapseError
+from synapse.api.errors import NotFoundError, SynapseError
from synapse.api.filtering import Filter
from synapse.storage.state import StateFilter
from synapse.visibility import filter_events_for_client
@@ -37,6 +37,7 @@ class SearchHandler(BaseHandler):
self._event_serializer = hs.get_event_client_serializer()
self.storage = hs.get_storage()
self.state_store = self.storage.state
+ self.auth = hs.get_auth()
@defer.inlineCallbacks
def get_old_rooms_from_upgraded_room(self, room_id):
@@ -53,23 +54,38 @@ class SearchHandler(BaseHandler):
room_id (str): id of the room to search through.
Returns:
- Deferred[iterable[unicode]]: predecessor room ids
+ Deferred[iterable[str]]: predecessor room ids
"""
historical_room_ids = []
- while True:
- predecessor = yield self.store.get_room_predecessor(room_id)
+ # The initial room must have been known for us to get this far
+ predecessor = yield self.store.get_room_predecessor(room_id)
- # If no predecessor, assume we've hit a dead end
+ while True:
if not predecessor:
+ # We have reached the end of the chain of predecessors
+ break
+
+ if not isinstance(predecessor.get("room_id"), str):
+ # This predecessor object is malformed. Exit here
+ break
+
+ predecessor_room_id = predecessor["room_id"]
+
+ # Don't add it to the list until we have checked that we are in the room
+ try:
+ next_predecessor_room = yield self.store.get_room_predecessor(
+ predecessor_room_id
+ )
+ except NotFoundError:
+ # The predecessor is not a known room, so we are done here
break
- # Add predecessor's room ID
- historical_room_ids.append(predecessor["room_id"])
+ historical_room_ids.append(predecessor_room_id)
- # Scan through the old room for further predecessors
- room_id = predecessor["room_id"]
+ # And repeat
+ predecessor = next_predecessor_room
return historical_room_ids
diff --git a/synapse/logging/context.py b/synapse/logging/context.py
index 2c1fb9ddac..6747f29e6a 100644
--- a/synapse/logging/context.py
+++ b/synapse/logging/context.py
@@ -23,6 +23,7 @@ them.
See doc/log_contexts.rst for details on how this works.
"""
+import inspect
import logging
import threading
import types
@@ -612,7 +613,8 @@ def run_in_background(f, *args, **kwargs):
def make_deferred_yieldable(deferred):
- """Given a deferred, make it follow the Synapse logcontext rules:
+ """Given a deferred (or coroutine), make it follow the Synapse logcontext
+ rules:
If the deferred has completed (or is not actually a Deferred), essentially
does nothing (just returns another completed deferred with the
@@ -624,6 +626,13 @@ def make_deferred_yieldable(deferred):
(This is more-or-less the opposite operation to run_in_background.)
"""
+ if inspect.isawaitable(deferred):
+ # If we're given a coroutine we convert it to a deferred so that we
+ # run it and find out if it immediately finishes, it it does then we
+ # don't need to fiddle with log contexts at all and can return
+ # immediately.
+ deferred = defer.ensureDeferred(deferred)
+
if not isinstance(deferred, defer.Deferred):
return deferred
diff --git a/synapse/server.py b/synapse/server.py
index 2db3dab221..5021068ce0 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -34,6 +34,7 @@ from synapse.api.filtering import Filtering
from synapse.api.ratelimiting import Ratelimiter
from synapse.appservice.api import ApplicationServiceApi
from synapse.appservice.scheduler import ApplicationServiceScheduler
+from synapse.config.homeserver import HomeServerConfig
from synapse.crypto import context_factory
from synapse.crypto.keyring import Keyring
from synapse.events.builder import EventBuilderFactory
@@ -97,6 +98,7 @@ from synapse.server_notices.worker_server_notices_sender import (
)
from synapse.state import StateHandler, StateResolutionHandler
from synapse.storage import DataStores, Storage
+from synapse.storage.engines import create_engine
from synapse.streams.events import EventSources
from synapse.util import Clock
from synapse.util.distributor import Distributor
@@ -209,16 +211,18 @@ class HomeServer(object):
# instantiated during setup() for future return by get_datastore()
DATASTORE_CLASS = abc.abstractproperty()
- def __init__(self, hostname, reactor=None, **kwargs):
+ def __init__(self, hostname: str, config: HomeServerConfig, reactor=None, **kwargs):
"""
Args:
hostname : The hostname for the server.
+ config: The full config for the homeserver.
"""
if not reactor:
from twisted.internet import reactor
self._reactor = reactor
self.hostname = hostname
+ self.config = config
self._building = {}
self._listening_services = []
self.start_time = None
@@ -229,6 +233,12 @@ class HomeServer(object):
self.admin_redaction_ratelimiter = Ratelimiter()
self.registration_ratelimiter = Ratelimiter()
+ self.database_engine = create_engine(config.database_config)
+ config.database_config.setdefault("args", {})[
+ "cp_openfun"
+ ] = self.database_engine.on_new_connection
+ self.db_config = config.database_config
+
self.datastores = None
# Other kwargs are explicit dependencies
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 139beef8ed..3e6d62eef1 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -32,6 +32,7 @@ from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.logging.utils import log_function
from synapse.state import v1, v2
+from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
from synapse.util.async_helpers import Linearizer
from synapse.util.caches import get_cache_factor_for
from synapse.util.caches.expiringcache import ExpiringCache
@@ -645,7 +646,7 @@ class StateResolutionStore(object):
return self.store.get_events(
event_ids,
- check_redacted=False,
+ redact_behaviour=EventRedactBehaviour.AS_IS,
get_prev_content=False,
allow_rejected=allow_rejected,
)
diff --git a/synapse/storage/data_stores/main/client_ips.py b/synapse/storage/data_stores/main/client_ips.py
index 320c5b0f07..add3037b69 100644
--- a/synapse/storage/data_stores/main/client_ips.py
+++ b/synapse/storage/data_stores/main/client_ips.py
@@ -451,16 +451,18 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
# Technically an access token might not be associated with
# a device so we need to check.
if device_id:
- self.db.simple_upsert_txn(
+ # this is always an update rather than an upsert: the row should
+ # already exist, and if it doesn't, that may be because it has been
+ # deleted, and we don't want to re-create it.
+ self.db.simple_update_txn(
txn,
table="devices",
keyvalues={"user_id": user_id, "device_id": device_id},
- values={
+ updatevalues={
"user_agent": user_agent,
"last_seen": last_seen,
"ip": ip,
},
- lock=False,
)
except Exception as e:
# Failed to upsert, log and continue
diff --git a/synapse/storage/data_stores/main/end_to_end_keys.py b/synapse/storage/data_stores/main/end_to_end_keys.py
index 38cd0ca9b8..e551606f9d 100644
--- a/synapse/storage/data_stores/main/end_to_end_keys.py
+++ b/synapse/storage/data_stores/main/end_to_end_keys.py
@@ -14,15 +14,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Dict, List
+
from six import iteritems
from canonicaljson import encode_canonical_json, json
+from twisted.enterprise.adbapi import Connection
from twisted.internet import defer
from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.util.caches.descriptors import cached
+from synapse.util.caches.descriptors import cached, cachedList
class EndToEndKeyWorkerStore(SQLBaseStore):
@@ -271,7 +274,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
Args:
txn (twisted.enterprise.adbapi.Connection): db connection
user_id (str): the user whose key is being requested
- key_type (str): the type of key that is being set: either 'master'
+ key_type (str): the type of key that is being requested: either 'master'
for a master key, 'self_signing' for a self-signing key, or
'user_signing' for a user-signing key
from_user_id (str): if specified, signatures made by this user on
@@ -316,8 +319,10 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
"""Returns a user's cross-signing key.
Args:
- user_id (str): the user whose self-signing key is being requested
- key_type (str): the type of cross-signing key to get
+ user_id (str): the user whose key is being requested
+ key_type (str): the type of key that is being requested: either 'master'
+ for a master key, 'self_signing' for a self-signing key, or
+ 'user_signing' for a user-signing key
from_user_id (str): if specified, signatures made by this user on
the self-signing key will be included in the result
@@ -332,6 +337,206 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
from_user_id,
)
+ @cached(num_args=1)
+ def _get_bare_e2e_cross_signing_keys(self, user_id):
+ """Dummy function. Only used to make a cache for
+ _get_bare_e2e_cross_signing_keys_bulk.
+ """
+ raise NotImplementedError()
+
+ @cachedList(
+ cached_method_name="_get_bare_e2e_cross_signing_keys",
+ list_name="user_ids",
+ num_args=1,
+ )
+ def _get_bare_e2e_cross_signing_keys_bulk(
+ self, user_ids: List[str]
+ ) -> Dict[str, Dict[str, dict]]:
+ """Returns the cross-signing keys for a set of users. The output of this
+ function should be passed to _get_e2e_cross_signing_signatures_txn if
+ the signatures for the calling user need to be fetched.
+
+ Args:
+ user_ids (list[str]): the users whose keys are being requested
+
+ Returns:
+ dict[str, dict[str, dict]]: mapping from user ID to key type to key
+ data. If a user's cross-signing keys were not found, either
+ their user ID will not be in the dict, or their user ID will map
+ to None.
+
+ """
+ return self.db.runInteraction(
+ "get_bare_e2e_cross_signing_keys_bulk",
+ self._get_bare_e2e_cross_signing_keys_bulk_txn,
+ user_ids,
+ )
+
+ def _get_bare_e2e_cross_signing_keys_bulk_txn(
+ self, txn: Connection, user_ids: List[str],
+ ) -> Dict[str, Dict[str, dict]]:
+ """Returns the cross-signing keys for a set of users. The output of this
+ function should be passed to _get_e2e_cross_signing_signatures_txn if
+ the signatures for the calling user need to be fetched.
+
+ Args:
+ txn (twisted.enterprise.adbapi.Connection): db connection
+ user_ids (list[str]): the users whose keys are being requested
+
+ Returns:
+ dict[str, dict[str, dict]]: mapping from user ID to key type to key
+ data. If a user's cross-signing keys were not found, their user
+ ID will not be in the dict.
+
+ """
+ result = {}
+
+ batch_size = 100
+ chunks = [
+ user_ids[i : i + batch_size] for i in range(0, len(user_ids), batch_size)
+ ]
+ for user_chunk in chunks:
+ sql = """
+ SELECT k.user_id, k.keytype, k.keydata, k.stream_id
+ FROM e2e_cross_signing_keys k
+ INNER JOIN (SELECT user_id, keytype, MAX(stream_id) AS stream_id
+ FROM e2e_cross_signing_keys
+ GROUP BY user_id, keytype) s
+ USING (user_id, stream_id, keytype)
+ WHERE k.user_id IN (%s)
+ """ % (
+ ",".join("?" for u in user_chunk),
+ )
+ query_params = []
+ query_params.extend(user_chunk)
+
+ txn.execute(sql, query_params)
+ rows = self.db.cursor_to_dict(txn)
+
+ for row in rows:
+ user_id = row["user_id"]
+ key_type = row["keytype"]
+ key = json.loads(row["keydata"])
+ user_info = result.setdefault(user_id, {})
+ user_info[key_type] = key
+
+ return result
+
+ def _get_e2e_cross_signing_signatures_txn(
+ self, txn: Connection, keys: Dict[str, Dict[str, dict]], from_user_id: str,
+ ) -> Dict[str, Dict[str, dict]]:
+ """Returns the cross-signing signatures made by a user on a set of keys.
+
+ Args:
+ txn (twisted.enterprise.adbapi.Connection): db connection
+ keys (dict[str, dict[str, dict]]): a map of user ID to key type to
+ key data. This dict will be modified to add signatures.
+ from_user_id (str): fetch the signatures made by this user
+
+ Returns:
+ dict[str, dict[str, dict]]: mapping from user ID to key type to key
+ data. The return value will be the same as the keys argument,
+ with the modifications included.
+ """
+
+ # find out what cross-signing keys (a.k.a. devices) we need to get
+ # signatures for. This is a map of (user_id, device_id) to key type
+ # (device_id is the key's public part).
+ devices = {}
+
+ for user_id, user_info in keys.items():
+ if user_info is None:
+ continue
+ for key_type, key in user_info.items():
+ device_id = None
+ for k in key["keys"].values():
+ device_id = k
+ devices[(user_id, device_id)] = key_type
+
+ device_list = list(devices)
+
+ # split into batches
+ batch_size = 100
+ chunks = [
+ device_list[i : i + batch_size]
+ for i in range(0, len(device_list), batch_size)
+ ]
+ for user_chunk in chunks:
+ sql = """
+ SELECT target_user_id, target_device_id, key_id, signature
+ FROM e2e_cross_signing_signatures
+ WHERE user_id = ?
+ AND (%s)
+ """ % (
+ " OR ".join(
+ "(target_user_id = ? AND target_device_id = ?)" for d in devices
+ )
+ )
+ query_params = [from_user_id]
+ for item in devices:
+ # item is a (user_id, device_id) tuple
+ query_params.extend(item)
+
+ txn.execute(sql, query_params)
+ rows = self.db.cursor_to_dict(txn)
+
+ # and add the signatures to the appropriate keys
+ for row in rows:
+ key_id = row["key_id"]
+ target_user_id = row["target_user_id"]
+ target_device_id = row["target_device_id"]
+ key_type = devices[(target_user_id, target_device_id)]
+ # We need to copy everything, because the result may have come
+ # from the cache. dict.copy only does a shallow copy, so we
+ # need to recursively copy the dicts that will be modified.
+ user_info = keys[target_user_id] = keys[target_user_id].copy()
+ target_user_key = user_info[key_type] = user_info[key_type].copy()
+ if "signatures" in target_user_key:
+ signatures = target_user_key["signatures"] = target_user_key[
+ "signatures"
+ ].copy()
+ if from_user_id in signatures:
+ user_sigs = signatures[from_user_id] = signatures[from_user_id]
+ user_sigs[key_id] = row["signature"]
+ else:
+ signatures[from_user_id] = {key_id: row["signature"]}
+ else:
+ target_user_key["signatures"] = {
+ from_user_id: {key_id: row["signature"]}
+ }
+
+ return keys
+
+ @defer.inlineCallbacks
+ def get_e2e_cross_signing_keys_bulk(
+ self, user_ids: List[str], from_user_id: str = None
+ ) -> defer.Deferred:
+ """Returns the cross-signing keys for a set of users.
+
+ Args:
+ user_ids (list[str]): the users whose keys are being requested
+ from_user_id (str): if specified, signatures made by this user on
+ the self-signing keys will be included in the result
+
+ Returns:
+ Deferred[dict[str, dict[str, dict]]]: map of user ID to key type to
+ key data. If a user's cross-signing keys were not found, either
+ their user ID will not be in the dict, or their user ID will map
+ to None.
+ """
+
+ result = yield self._get_bare_e2e_cross_signing_keys_bulk(user_ids)
+
+ if from_user_id:
+ result = yield self.db.runInteraction(
+ "get_e2e_cross_signing_signatures",
+ self._get_e2e_cross_signing_signatures_txn,
+ result,
+ from_user_id,
+ )
+
+ return result
+
def get_all_user_signature_changes_for_remotes(self, from_key, to_key):
"""Return a list of changes from the user signature stream to notify remotes.
Note that the user signature stream represents when a user signs their
@@ -520,6 +725,10 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
},
)
+ self._invalidate_cache_and_stream(
+ txn, self._get_bare_e2e_cross_signing_keys, (user_id,)
+ )
+
def set_e2e_cross_signing_key(self, user_id, key_type, key):
"""Set a user's cross-signing key.
diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py
index 9ee117ce0f..2c9142814c 100644
--- a/synapse/storage/data_stores/main/events_worker.py
+++ b/synapse/storage/data_stores/main/events_worker.py
@@ -19,8 +19,10 @@ import itertools
import logging
import threading
from collections import namedtuple
+from typing import List, Optional
from canonicaljson import json
+from constantly import NamedConstant, Names
from twisted.internet import defer
@@ -55,6 +57,16 @@ EVENT_QUEUE_TIMEOUT_S = 0.1 # Timeout when waiting for requests for events
_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event"))
+class EventRedactBehaviour(Names):
+ """
+ What to do when retrieving a redacted event from the database.
+ """
+
+ AS_IS = NamedConstant()
+ REDACT = NamedConstant()
+ BLOCK = NamedConstant()
+
+
class EventsWorkerStore(SQLBaseStore):
def __init__(self, database: Database, db_conn, hs):
super(EventsWorkerStore, self).__init__(database, db_conn, hs)
@@ -125,25 +137,27 @@ class EventsWorkerStore(SQLBaseStore):
@defer.inlineCallbacks
def get_event(
self,
- event_id,
- check_redacted=True,
- get_prev_content=False,
- allow_rejected=False,
- allow_none=False,
- check_room_id=None,
+ event_id: List[str],
+ redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
+ get_prev_content: bool = False,
+ allow_rejected: bool = False,
+ allow_none: bool = False,
+ check_room_id: Optional[str] = None,
):
"""Get an event from the database by event_id.
Args:
- event_id (str): The event_id of the event to fetch
- check_redacted (bool): If True, check if event has been redacted
- and redact it.
- get_prev_content (bool): If True and event is a state event,
+ event_id: The event_id of the event to fetch
+ redact_behaviour: Determine what to do with a redacted event. Possible values:
+ * AS_IS - Return the full event body with no redacted content
+ * REDACT - Return the event but with a redacted body
+ * DISALLOW - Do not return redacted events
+ get_prev_content: If True and event is a state event,
include the previous states content in the unsigned field.
- allow_rejected (bool): If True return rejected events.
- allow_none (bool): If True, return None if no event found, if
+ allow_rejected: If True return rejected events.
+ allow_none: If True, return None if no event found, if
False throw a NotFoundError
- check_room_id (str|None): if not None, check the room of the found event.
+ check_room_id: if not None, check the room of the found event.
If there is a mismatch, behave as per allow_none.
Returns:
@@ -154,7 +168,7 @@ class EventsWorkerStore(SQLBaseStore):
events = yield self.get_events_as_list(
[event_id],
- check_redacted=check_redacted,
+ redact_behaviour=redact_behaviour,
get_prev_content=get_prev_content,
allow_rejected=allow_rejected,
)
@@ -173,27 +187,30 @@ class EventsWorkerStore(SQLBaseStore):
@defer.inlineCallbacks
def get_events(
self,
- event_ids,
- check_redacted=True,
- get_prev_content=False,
- allow_rejected=False,
+ event_ids: List[str],
+ redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
+ get_prev_content: bool = False,
+ allow_rejected: bool = False,
):
"""Get events from the database
Args:
- event_ids (list): The event_ids of the events to fetch
- check_redacted (bool): If True, check if event has been redacted
- and redact it.
- get_prev_content (bool): If True and event is a state event,
+ event_ids: The event_ids of the events to fetch
+ redact_behaviour: Determine what to do with a redacted event. Possible
+ values:
+ * AS_IS - Return the full event body with no redacted content
+ * REDACT - Return the event but with a redacted body
+ * DISALLOW - Do not return redacted events
+ get_prev_content: If True and event is a state event,
include the previous states content in the unsigned field.
- allow_rejected (bool): If True return rejected events.
+ allow_rejected: If True return rejected events.
Returns:
Deferred : Dict from event_id to event.
"""
events = yield self.get_events_as_list(
event_ids,
- check_redacted=check_redacted,
+ redact_behaviour=redact_behaviour,
get_prev_content=get_prev_content,
allow_rejected=allow_rejected,
)
@@ -203,21 +220,23 @@ class EventsWorkerStore(SQLBaseStore):
@defer.inlineCallbacks
def get_events_as_list(
self,
- event_ids,
- check_redacted=True,
- get_prev_content=False,
- allow_rejected=False,
+ event_ids: List[str],
+ redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
+ get_prev_content: bool = False,
+ allow_rejected: bool = False,
):
"""Get events from the database and return in a list in the same order
as given by `event_ids` arg.
Args:
- event_ids (list): The event_ids of the events to fetch
- check_redacted (bool): If True, check if event has been redacted
- and redact it.
- get_prev_content (bool): If True and event is a state event,
+ event_ids: The event_ids of the events to fetch
+ redact_behaviour: Determine what to do with a redacted event. Possible values:
+ * AS_IS - Return the full event body with no redacted content
+ * REDACT - Return the event but with a redacted body
+ * DISALLOW - Do not return redacted events
+ get_prev_content: If True and event is a state event,
include the previous states content in the unsigned field.
- allow_rejected (bool): If True return rejected events.
+ allow_rejected: If True, return rejected events.
Returns:
Deferred[list[EventBase]]: List of events fetched from the database. The
@@ -319,10 +338,14 @@ class EventsWorkerStore(SQLBaseStore):
# Update the cache to save doing the checks again.
entry.event.internal_metadata.recheck_redaction = False
- if check_redacted and entry.redacted_event:
- event = entry.redacted_event
- else:
- event = entry.event
+ event = entry.event
+
+ if entry.redacted_event:
+ if redact_behaviour == EventRedactBehaviour.BLOCK:
+ # Skip this event
+ continue
+ elif redact_behaviour == EventRedactBehaviour.REDACT:
+ event = entry.redacted_event
events.append(event)
diff --git a/synapse/storage/data_stores/main/schema/full_schemas/README.md b/synapse/storage/data_stores/main/schema/full_schemas/README.md
new file mode 100644
index 0000000000..bbd3f18604
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/full_schemas/README.md
@@ -0,0 +1,13 @@
+# Building full schema dumps
+
+These schemas need to be made from a database that has had all background updates run.
+
+To do so, use `scripts-dev/make_full_schema.sh`. This will produce
+`full.sql.postgres ` and `full.sql.sqlite` files.
+
+Ensure postgres is installed and your user has the ability to run bash commands
+such as `createdb`.
+
+```
+./scripts-dev/make_full_schema.sh -p postgres_username -o output_dir/
+```
diff --git a/synapse/storage/data_stores/main/schema/full_schemas/README.txt b/synapse/storage/data_stores/main/schema/full_schemas/README.txt
deleted file mode 100644
index d3f6401344..0000000000
--- a/synapse/storage/data_stores/main/schema/full_schemas/README.txt
+++ /dev/null
@@ -1,19 +0,0 @@
-Building full schema dumps
-==========================
-
-These schemas need to be made from a database that has had all background updates run.
-
-Postgres
---------
-
-$ pg_dump --format=plain --schema-only --no-tablespaces --no-acl --no-owner $DATABASE_NAME| sed -e '/^--/d' -e 's/public\.//g' -e '/^SET /d' -e '/^SELECT /d' > full.sql.postgres
-
-SQLite
-------
-
-$ sqlite3 $DATABASE_FILE ".schema" > full.sql.sqlite
-
-After
------
-
-Delete the CREATE statements for "sqlite_stat1", "schema_version", "applied_schema_deltas", and "applied_module_schemas".
\ No newline at end of file
diff --git a/synapse/storage/data_stores/main/search.py b/synapse/storage/data_stores/main/search.py
index 4eec2fae5e..47ebb8a214 100644
--- a/synapse/storage/data_stores/main/search.py
+++ b/synapse/storage/data_stores/main/search.py
@@ -25,6 +25,7 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
+from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
from synapse.storage.database import Database
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
@@ -384,7 +385,7 @@ class SearchStore(SearchBackgroundUpdateStore):
"""
clauses = []
- search_query = search_query = _parse_query(self.database_engine, search_term)
+ search_query = _parse_query(self.database_engine, search_term)
args = []
@@ -453,7 +454,12 @@ class SearchStore(SearchBackgroundUpdateStore):
results = list(filter(lambda row: row["room_id"] in room_ids, results))
- events = yield self.get_events_as_list([r["event_id"] for r in results])
+ # We set redact_behaviour to BLOCK here to prevent redacted events being returned in
+ # search results (which is a data leak)
+ events = yield self.get_events_as_list(
+ [r["event_id"] for r in results],
+ redact_behaviour=EventRedactBehaviour.BLOCK,
+ )
event_map = {ev.event_id: ev for ev in events}
@@ -495,7 +501,7 @@ class SearchStore(SearchBackgroundUpdateStore):
"""
clauses = []
- search_query = search_query = _parse_query(self.database_engine, search_term)
+ search_query = _parse_query(self.database_engine, search_term)
args = []
@@ -600,7 +606,12 @@ class SearchStore(SearchBackgroundUpdateStore):
results = list(filter(lambda row: row["room_id"] in room_ids, results))
- events = yield self.get_events_as_list([r["event_id"] for r in results])
+ # We set redact_behaviour to BLOCK here to prevent redacted events being returned in
+ # search results (which is a data leak)
+ events = yield self.get_events_as_list(
+ [r["event_id"] for r in results],
+ redact_behaviour=EventRedactBehaviour.BLOCK,
+ )
event_map = {ev.event_id: ev for ev in events}
diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/data_stores/main/state.py
index 9ef7b48c74..dcc6b43cdf 100644
--- a/synapse/storage/data_stores/main/state.py
+++ b/synapse/storage/data_stores/main/state.py
@@ -278,7 +278,7 @@ class StateGroupWorkerStore(
@defer.inlineCallbacks
def get_room_predecessor(self, room_id):
- """Get the predecessor room of an upgraded room if one exists.
+ """Get the predecessor of an upgraded room if it exists.
Otherwise return None.
Args:
@@ -291,14 +291,22 @@ class StateGroupWorkerStore(
* room_id (str): The room ID of the predecessor room
* event_id (str): The ID of the tombstone event in the predecessor room
+ None if a predecessor key is not found, or is not a dictionary.
+
Raises:
- NotFoundError if the room is unknown
+ NotFoundError if the given room is unknown
"""
# Retrieve the room's create event
create_event = yield self.get_create_event_for_room(room_id)
- # Return predecessor if present
- return create_event.content.get("predecessor", None)
+ # Retrieve the predecessor key of the create event
+ predecessor = create_event.content.get("predecessor", None)
+
+ # Ensure the key is a dictionary
+ if not isinstance(predecessor, dict):
+ return None
+
+ return predecessor
@defer.inlineCallbacks
def get_create_event_for_room(self, room_id):
@@ -318,7 +326,7 @@ class StateGroupWorkerStore(
# If we can't find the create event, assume we've hit a dead end
if not create_id:
- raise NotFoundError("Unknown room %s" % (room_id))
+ raise NotFoundError("Unknown room %s" % (room_id,))
# Retrieve the room's create event and return
create_event = yield self.get_event(create_id)
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 84f5ae22c3..2e8f6543e5 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -271,7 +271,7 @@ class _CacheDescriptorBase(object):
else:
self.function_to_call = orig
- arg_spec = inspect.getargspec(orig)
+ arg_spec = inspect.getfullargspec(orig)
all_args = arg_spec.args
if "cache_context" in all_args:
diff --git a/synapse/util/caches/snapshot_cache.py b/synapse/util/caches/snapshot_cache.py
deleted file mode 100644
index 8318db8d2c..0000000000
--- a/synapse/util/caches/snapshot_cache.py
+++ /dev/null
@@ -1,94 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2015, 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from synapse.util.async_helpers import ObservableDeferred
-
-
-class SnapshotCache(object):
- """Cache for snapshots like the response of /initialSync.
- The response of initialSync only has to be a recent snapshot of the
- server state. It shouldn't matter to clients if it is a few minutes out
- of date.
-
- This caches a deferred response. Until the deferred completes it will be
- returned from the cache. This means that if the client retries the request
- while the response is still being computed, that original response will be
- used rather than trying to compute a new response.
-
- Once the deferred completes it will removed from the cache after 5 minutes.
- We delay removing it from the cache because a client retrying its request
- could race with us finishing computing the response.
-
- Rather than tracking precisely how long something has been in the cache we
- keep two generations of completed responses. Every 5 minutes discard the
- old generation, move the new generation to the old generation, and set the
- new generation to be empty. This means that a result will be in the cache
- somewhere between 5 and 10 minutes.
- """
-
- DURATION_MS = 5 * 60 * 1000 # Cache results for 5 minutes.
-
- def __init__(self):
- self.pending_result_cache = {} # Request that haven't finished yet.
- self.prev_result_cache = {} # The older requests that have finished.
- self.next_result_cache = {} # The newer requests that have finished.
- self.time_last_rotated_ms = 0
-
- def rotate(self, time_now_ms):
- # Rotate once if the cache duration has passed since the last rotation.
- if time_now_ms - self.time_last_rotated_ms >= self.DURATION_MS:
- self.prev_result_cache = self.next_result_cache
- self.next_result_cache = {}
- self.time_last_rotated_ms += self.DURATION_MS
-
- # Rotate again if the cache duration has passed twice since the last
- # rotation.
- if time_now_ms - self.time_last_rotated_ms >= self.DURATION_MS:
- self.prev_result_cache = self.next_result_cache
- self.next_result_cache = {}
- self.time_last_rotated_ms = time_now_ms
-
- def get(self, time_now_ms, key):
- self.rotate(time_now_ms)
- # This cache is intended to deduplicate requests, so we expect it to be
- # missed most of the time. So we just lookup the key in all of the
- # dictionaries rather than trying to short circuit the lookup if the
- # key is found.
- result = self.prev_result_cache.get(key)
- result = self.next_result_cache.get(key, result)
- result = self.pending_result_cache.get(key, result)
- if result is not None:
- return result.observe()
- else:
- return None
-
- def set(self, time_now_ms, key, deferred):
- self.rotate(time_now_ms)
-
- result = ObservableDeferred(deferred)
-
- self.pending_result_cache[key] = result
-
- def shuffle_along(r):
- # When the deferred completes we shuffle it along to the first
- # generation of the result cache. So that it will eventually
- # expire from the rotation of that cache.
- self.next_result_cache[key] = result
- self.pending_result_cache.pop(key, None)
- return r
-
- result.addBoth(shuffle_along)
-
- return result.observe()
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index fdfa2cbbc4..854eb6c024 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -183,10 +183,6 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
)
self.assertDictEqual(devices["master_keys"], {local_user: keys2["master_key"]})
- test_replace_master_key.skip = (
- "Disabled waiting on #https://github.com/matrix-org/synapse/pull/6486"
- )
-
@defer.inlineCallbacks
def test_reupload_signatures(self):
"""re-uploading a signature should not fail"""
@@ -507,7 +503,3 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
],
other_master_key["signatures"][local_user]["ed25519:" + usersigning_pubkey],
)
-
- test_upload_signatures.skip = (
- "Disabled waiting on #https://github.com/matrix-org/synapse/pull/6486"
- )
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index c0d0d2b44e..d0c997e385 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -391,9 +391,8 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
# Email config.
self.email_attempts = []
- def sendmail(*args, **kwargs):
+ async def sendmail(*args, **kwargs):
self.email_attempts.append((args, kwargs))
- return
config["email"] = {
"enable_notifs": True,
diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index fc279340d4..bf674dd184 100644
--- a/tests/storage/test_client_ips.py
+++ b/tests/storage/test_client_ips.py
@@ -37,9 +37,13 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
self.reactor.advance(12345678)
user_id = "@user:id"
+ device_id = "MY_DEVICE"
+
+ # Insert a user IP
+ self.get_success(self.store.store_device(user_id, device_id, "display name",))
self.get_success(
self.store.insert_client_ip(
- user_id, "access_token", "ip", "user_agent", "device_id"
+ user_id, "access_token", "ip", "user_agent", device_id
)
)
@@ -47,14 +51,14 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
self.reactor.advance(10)
result = self.get_success(
- self.store.get_last_client_ip_by_device(user_id, "device_id")
+ self.store.get_last_client_ip_by_device(user_id, device_id)
)
- r = result[(user_id, "device_id")]
+ r = result[(user_id, device_id)]
self.assertDictContainsSubset(
{
"user_id": user_id,
- "device_id": "device_id",
+ "device_id": device_id,
"ip": "ip",
"user_agent": "user_agent",
"last_seen": 12345678000,
@@ -209,14 +213,16 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
self.store.db.updates.do_next_background_update(100), by=0.1
)
- # Insert a user IP
user_id = "@user:id"
+ device_id = "MY_DEVICE"
+
+ # Insert a user IP
+ self.get_success(self.store.store_device(user_id, device_id, "display name",))
self.get_success(
self.store.insert_client_ip(
- user_id, "access_token", "ip", "user_agent", "device_id"
+ user_id, "access_token", "ip", "user_agent", device_id
)
)
-
# Force persisting to disk
self.reactor.advance(200)
@@ -224,7 +230,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
self.get_success(
self.store.db.simple_update(
table="devices",
- keyvalues={"user_id": user_id, "device_id": "device_id"},
+ keyvalues={"user_id": user_id, "device_id": device_id},
updatevalues={"last_seen": None, "ip": None, "user_agent": None},
desc="test_devices_last_seen_bg_update",
)
@@ -232,14 +238,14 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
# We should now get nulls when querying
result = self.get_success(
- self.store.get_last_client_ip_by_device(user_id, "device_id")
+ self.store.get_last_client_ip_by_device(user_id, device_id)
)
- r = result[(user_id, "device_id")]
+ r = result[(user_id, device_id)]
self.assertDictContainsSubset(
{
"user_id": user_id,
- "device_id": "device_id",
+ "device_id": device_id,
"ip": None,
"user_agent": None,
"last_seen": None,
@@ -272,14 +278,14 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
# We should now get the correct result again
result = self.get_success(
- self.store.get_last_client_ip_by_device(user_id, "device_id")
+ self.store.get_last_client_ip_by_device(user_id, device_id)
)
- r = result[(user_id, "device_id")]
+ r = result[(user_id, device_id)]
self.assertDictContainsSubset(
{
"user_id": user_id,
- "device_id": "device_id",
+ "device_id": device_id,
"ip": "ip",
"user_agent": "user_agent",
"last_seen": 0,
@@ -296,11 +302,14 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
self.store.db.updates.do_next_background_update(100), by=0.1
)
- # Insert a user IP
user_id = "@user:id"
+ device_id = "MY_DEVICE"
+
+ # Insert a user IP
+ self.get_success(self.store.store_device(user_id, device_id, "display name",))
self.get_success(
self.store.insert_client_ip(
- user_id, "access_token", "ip", "user_agent", "device_id"
+ user_id, "access_token", "ip", "user_agent", device_id
)
)
@@ -324,7 +333,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
"access_token": "access_token",
"ip": "ip",
"user_agent": "user_agent",
- "device_id": "device_id",
+ "device_id": device_id,
"last_seen": 0,
}
],
@@ -347,14 +356,14 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
# But we should still get the correct values for the device
result = self.get_success(
- self.store.get_last_client_ip_by_device(user_id, "device_id")
+ self.store.get_last_client_ip_by_device(user_id, device_id)
)
- r = result[(user_id, "device_id")]
+ r = result[(user_id, device_id)]
self.assertDictContainsSubset(
{
"user_id": user_id,
- "device_id": "device_id",
+ "device_id": device_id,
"ip": "ip",
"user_agent": "user_agent",
"last_seen": 0,
diff --git a/tests/test_federation.py b/tests/test_federation.py
index ad165d7295..68684460c6 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -1,6 +1,6 @@
from mock import Mock
-from twisted.internet.defer import maybeDeferred, succeed
+from twisted.internet.defer import ensureDeferred, maybeDeferred, succeed
from synapse.events import FrozenEvent
from synapse.logging.context import LoggingContext
@@ -70,8 +70,10 @@ class MessageAcceptTests(unittest.TestCase):
)
# Send the join, it should return None (which is not an error)
- d = self.handler.on_receive_pdu(
- "test.serv", join_event, sent_to_us_directly=True
+ d = ensureDeferred(
+ self.handler.on_receive_pdu(
+ "test.serv", join_event, sent_to_us_directly=True
+ )
)
self.reactor.advance(1)
self.assertEqual(self.successResultOf(d), None)
@@ -119,8 +121,10 @@ class MessageAcceptTests(unittest.TestCase):
)
with LoggingContext(request="lying_event"):
- d = self.handler.on_receive_pdu(
- "test.serv", lying_event, sent_to_us_directly=True
+ d = ensureDeferred(
+ self.handler.on_receive_pdu(
+ "test.serv", lying_event, sent_to_us_directly=True
+ )
)
# Step the reactor, so the database fetches come back
diff --git a/tests/test_types.py b/tests/test_types.py
index 9ab5f829b0..8d97c751ea 100644
--- a/tests/test_types.py
+++ b/tests/test_types.py
@@ -17,18 +17,15 @@ from synapse.api.errors import SynapseError
from synapse.types import GroupID, RoomAlias, UserID, map_username_to_mxid_localpart
from tests import unittest
-from tests.utils import TestHomeServer
-mock_homeserver = TestHomeServer(hostname="my.domain")
-
-class UserIDTestCase(unittest.TestCase):
+class UserIDTestCase(unittest.HomeserverTestCase):
def test_parse(self):
- user = UserID.from_string("@1234abcd:my.domain")
+ user = UserID.from_string("@1234abcd:test")
self.assertEquals("1234abcd", user.localpart)
- self.assertEquals("my.domain", user.domain)
- self.assertEquals(True, mock_homeserver.is_mine(user))
+ self.assertEquals("test", user.domain)
+ self.assertEquals(True, self.hs.is_mine(user))
def test_pase_empty(self):
with self.assertRaises(SynapseError):
@@ -48,13 +45,13 @@ class UserIDTestCase(unittest.TestCase):
self.assertTrue(userA != userB)
-class RoomAliasTestCase(unittest.TestCase):
+class RoomAliasTestCase(unittest.HomeserverTestCase):
def test_parse(self):
- room = RoomAlias.from_string("#channel:my.domain")
+ room = RoomAlias.from_string("#channel:test")
self.assertEquals("channel", room.localpart)
- self.assertEquals("my.domain", room.domain)
- self.assertEquals(True, mock_homeserver.is_mine(room))
+ self.assertEquals("test", room.domain)
+ self.assertEquals(True, self.hs.is_mine(room))
def test_build(self):
room = RoomAlias("channel", "my.domain")
diff --git a/tests/util/test_logcontext.py b/tests/util/test_logcontext.py
index 8b8455c8b7..281b32c4b8 100644
--- a/tests/util/test_logcontext.py
+++ b/tests/util/test_logcontext.py
@@ -179,6 +179,30 @@ class LoggingContextTestCase(unittest.TestCase):
nested_context = nested_logging_context(suffix="bar")
self.assertEqual(nested_context.request, "foo-bar")
+ @defer.inlineCallbacks
+ def test_make_deferred_yieldable_with_await(self):
+ # an async function which retuns an incomplete coroutine, but doesn't
+ # follow the synapse rules.
+
+ async def blocking_function():
+ d = defer.Deferred()
+ reactor.callLater(0, d.callback, None)
+ await d
+
+ sentinel_context = LoggingContext.current_context()
+
+ with LoggingContext() as context_one:
+ context_one.request = "one"
+
+ d1 = make_deferred_yieldable(blocking_function())
+ # make sure that the context was reset by make_deferred_yieldable
+ self.assertIs(LoggingContext.current_context(), sentinel_context)
+
+ yield d1
+
+ # now it should be restored
+ self._check_test_key("one")
+
# a function which returns a deferred which has been "called", but
# which had a function which returned another incomplete deferred on
diff --git a/tests/util/test_snapshot_cache.py b/tests/util/test_snapshot_cache.py
deleted file mode 100644
index 1a44f72425..0000000000
--- a/tests/util/test_snapshot_cache.py
+++ /dev/null
@@ -1,63 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2015, 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-from twisted.internet.defer import Deferred
-
-from synapse.util.caches.snapshot_cache import SnapshotCache
-
-from .. import unittest
-
-
-class SnapshotCacheTestCase(unittest.TestCase):
- def setUp(self):
- self.cache = SnapshotCache()
- self.cache.DURATION_MS = 1
-
- def test_get_set(self):
- # Check that getting a missing key returns None
- self.assertEquals(self.cache.get(0, "key"), None)
-
- # Check that setting a key with a deferred returns
- # a deferred that resolves when the initial deferred does
- d = Deferred()
- set_result = self.cache.set(0, "key", d)
- self.assertIsNotNone(set_result)
- self.assertFalse(set_result.called)
-
- # Check that getting the key before the deferred has resolved
- # returns a deferred that resolves when the initial deferred does.
- get_result_at_10 = self.cache.get(10, "key")
- self.assertIsNotNone(get_result_at_10)
- self.assertFalse(get_result_at_10.called)
-
- # Check that the returned deferreds resolve when the initial deferred
- # does.
- d.callback("v")
- self.assertTrue(set_result.called)
- self.assertTrue(get_result_at_10.called)
-
- # Check that getting the key after the deferred has resolved
- # before the cache expires returns a resolved deferred.
- get_result_at_11 = self.cache.get(11, "key")
- self.assertIsNotNone(get_result_at_11)
- if isinstance(get_result_at_11, Deferred):
- # The cache may return the actual result rather than a deferred
- self.assertTrue(get_result_at_11.called)
-
- # Check that getting the key after the deferred has resolved
- # after the cache expires returns None
- get_result_at_12 = self.cache.get(12, "key")
- self.assertIsNone(get_result_at_12)
diff --git a/tests/utils.py b/tests/utils.py
index c57da59191..585f305b9a 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -260,9 +260,7 @@ def setup_test_homeserver(
hs = homeserverToUse(
name,
config=config,
- db_config=config.database_config,
version_string="Synapse/tests",
- database_engine=db_engine,
tls_server_context_factory=Mock(),
tls_client_options_factory=Mock(),
reactor=reactor,
diff --git a/tox.ini b/tox.ini
index 903a245fb0..2ae82d674f 100644
--- a/tox.ini
+++ b/tox.ini
@@ -177,5 +177,17 @@ env =
MYPYPATH = stubs/
extras = all
commands = mypy \
+ synapse/config/ \
+ synapse/handlers/ui_auth \
synapse/logging/ \
- synapse/config/
+ synapse/module_api \
+ synapse/rest/consent \
+ synapse/rest/media/v0 \
+ synapse/rest/saml2 \
+ synapse/spam_checker_api \
+ synapse/storage/engines \
+ synapse/streams
+
+# To find all folders that pass mypy you run:
+#
+# find synapse/* -type d -not -name __pycache__ -exec bash -c "mypy '{}' > /dev/null" \; -print
|