diff --git a/CHANGES.rst b/CHANGES.rst
index e1d5e876dc..7ebb42b0fc 100644
--- a/CHANGES.rst
+++ b/CHANGES.rst
@@ -1,3 +1,125 @@
+Changes in synapse v0.17.0 (2016-08-08)
+=======================================
+
+This release contains significant security bug fixes regarding authenticating
+events received over federation. PLEASE UPGRADE.
+
+This release changes the LDAP configuration format in a backwards incompatible
+way, see PR #843 for details.
+
+
+Changes:
+
+* Add federation /version API (PR #990)
+* Make psutil dependency optional (PR #992)
+
+
+Bug fixes:
+
+* Fix URL preview API to exclude HTML comments in description (PR #988)
+* Fix error handling of remote joins (PR #991)
+
+
+Changes in synapse v0.17.0-rc4 (2016-08-05)
+===========================================
+
+Changes:
+
+* Change the way we summarize URLs when previewing (PR #973)
+* Add new ``/state_ids/`` federation API (PR #979)
+* Speed up processing of ``/state/`` response (PR #986)
+
+Bug fixes:
+
+* Fix event persistence when event has already been partially persisted
+ (PR #975, #983, #985)
+* Fix port script to also copy across backfilled events (PR #982)
+
+
+Changes in synapse v0.17.0-rc3 (2016-08-02)
+===========================================
+
+Changes:
+
+* Forbid non-ASes from registering users whose names begin with '_' (PR #958)
+* Add some basic admin API docs (PR #963)
+
+
+Bug fixes:
+
+* Send the correct host header when fetching keys (PR #941)
+* Fix joining a room that has missing auth events (PR #964)
+* Fix various push bugs (PR #966, #970)
+* Fix adding emails on registration (PR #968)
+
+
+Changes in synapse v0.17.0-rc2 (2016-08-02)
+===========================================
+
+(This release did not include the changes advertised and was identical to RC1)
+
+
+Changes in synapse v0.17.0-rc1 (2016-07-28)
+===========================================
+
+This release changes the LDAP configuration format in a backwards incompatible
+way, see PR #843 for details.
+
+
+Features:
+
+* Add purge_media_cache admin API (PR #902)
+* Add deactivate account admin API (PR #903)
+* Add optional pepper to password hashing (PR #907, #910 by KentShikama)
+* Add an admin option to shared secret registration (breaks backwards compat)
+ (PR #909)
+* Add purge local room history API (PR #911, #923, #924)
+* Add requestToken endpoints (PR #915)
+* Add an /account/deactivate endpoint (PR #921)
+* Add filter param to /messages. Add 'contains_url' to filter. (PR #922)
+* Add device_id support to /login (PR #929)
+* Add device_id support to /v2/register flow. (PR #937, #942)
+* Add GET /devices endpoint (PR #939, #944)
+* Add GET /device/{deviceId} (PR #943)
+* Add update and delete APIs for devices (PR #949)
+
+
+Changes:
+
+* Rewrite LDAP Authentication against ldap3 (PR #843 by mweinelt)
+* Linearize some federation endpoints based on (origin, room_id) (PR #879)
+* Remove the legacy v0 content upload API. (PR #888)
+* Use similar naming we use in email notifs for push (PR #894)
+* Optionally include password hash in createUser endpoint (PR #905 by
+ KentShikama)
+* Use a query that postgresql optimises better for get_events_around (PR #906)
+* Fall back to 'username' if 'user' is not given for appservice registration.
+ (PR #927 by Half-Shot)
+* Add metrics for psutil derived memory usage (PR #936)
+* Record device_id in client_ips (PR #938)
+* Send the correct host header when fetching keys (PR #941)
+* Log the hostname the reCAPTCHA was completed on (PR #946)
+* Make the device id on e2e key upload optional (PR #956)
+* Add r0.2.0 to the "supported versions" list (PR #960)
+* Don't include name of room for invites in push (PR #961)
+
+
+Bug fixes:
+
+* Fix substitution failure in mail template (PR #887)
+* Put most recent 20 messages in email notif (PR #892)
+* Ensure that the guest user is in the database when upgrading accounts
+ (PR #914)
+* Fix various edge cases in auth handling (PR #919)
+* Fix 500 ISE when sending alias event without a state_key (PR #925)
+* Fix bug where we stored rejections in the state_group, persist all
+ rejections (PR #948)
+* Fix lack of check of if the user is banned when handling 3pid invites
+ (PR #952)
+* Fix a couple of bugs in the transaction and keyring code (PR #954, #955)
+
+
+
Changes in synapse v0.16.1-r1 (2016-07-08)
==========================================
diff --git a/MANIFEST.in b/MANIFEST.in
index dfb7c9d28d..981698143f 100644
--- a/MANIFEST.in
+++ b/MANIFEST.in
@@ -14,6 +14,7 @@ recursive-include docs *
recursive-include res *
recursive-include scripts *
recursive-include scripts-dev *
+recursive-include synapse *.pyi
recursive-include tests *.py
recursive-include synapse/static *.css
@@ -23,5 +24,7 @@ recursive-include synapse/static *.js
exclude jenkins.sh
exclude jenkins*.sh
+exclude jenkins*
+recursive-exclude jenkins *.sh
prune demo/etc
diff --git a/README.rst b/README.rst
index 8cbd28fb8e..d658670835 100644
--- a/README.rst
+++ b/README.rst
@@ -445,7 +445,7 @@ You have two choices here, which will influence the form of your Matrix user
IDs:
1) Use the machine's own hostname as available on public DNS in the form of
- its A or AAAA records. This is easier to set up initially, perhaps for
+ its A records. This is easier to set up initially, perhaps for
testing, but lacks the flexibility of SRV.
2) Set up a SRV record for your domain name. This requires you create a SRV
diff --git a/UPGRADE.rst b/UPGRADE.rst
index 699f04c2c2..9f044719a0 100644
--- a/UPGRADE.rst
+++ b/UPGRADE.rst
@@ -27,7 +27,7 @@ running:
# Pull the latest version of the master branch.
git pull
# Update the versions of synapse's python dependencies.
- python synapse/python_dependencies.py | xargs -n1 pip install
+ python synapse/python_dependencies.py | xargs -n1 pip install --upgrade
Upgrading to v0.15.0
diff --git a/docs/admin_api/README.rst b/docs/admin_api/README.rst
new file mode 100644
index 0000000000..d4f564cfae
--- /dev/null
+++ b/docs/admin_api/README.rst
@@ -0,0 +1,12 @@
+Admin APIs
+==========
+
+This directory includes documentation for the various synapse specific admin
+APIs available.
+
+Only users that are server admins can use these APIs. A user can be marked as a
+server admin by updating the database directly, e.g.:
+
+``UPDATE users SET admin = 1 WHERE name = '@foo:bar.com'``
+
+Restarting may be required for the changes to register.
diff --git a/docs/admin_api/purge_history_api.rst b/docs/admin_api/purge_history_api.rst
new file mode 100644
index 0000000000..986efe40f9
--- /dev/null
+++ b/docs/admin_api/purge_history_api.rst
@@ -0,0 +1,15 @@
+Purge History API
+=================
+
+The purge history API allows server admins to purge historic events from their
+database, reclaiming disk space.
+
+Depending on the amount of history being purged a call to the API may take
+several minutes or longer. During this period users will not be able to
+paginate further back in the room from the point being purged from.
+
+The API is simply:
+
+``POST /_matrix/client/r0/admin/purge_history/<room_id>/<event_id>``
+
+including an ``access_token`` of a server admin.
diff --git a/docs/admin_api/purge_remote_media.rst b/docs/admin_api/purge_remote_media.rst
new file mode 100644
index 0000000000..b26c6a9e7b
--- /dev/null
+++ b/docs/admin_api/purge_remote_media.rst
@@ -0,0 +1,19 @@
+Purge Remote Media API
+======================
+
+The purge remote media API allows server admins to purge old cached remote
+media.
+
+The API is::
+
+ POST /_matrix/client/r0/admin/purge_media_cache
+
+ {
+ "before_ts": <unix_timestamp_in_ms>
+ }
+
+Which will remove all cached media that was last accessed before
+``<unix_timestamp_in_ms>``.
+
+If the user re-requests purged remote media, synapse will re-request the media
+from the originating server.
diff --git a/docs/code_style.rst b/docs/code_style.rst
index dc40a7ab7b..8d73d17beb 100644
--- a/docs/code_style.rst
+++ b/docs/code_style.rst
@@ -43,7 +43,10 @@ Basically, PEP8
together, or want to deliberately extend or preserve vertical/horizontal
space)
-Comments should follow the google code style. This is so that we can generate
-documentation with sphinx (http://sphinxcontrib-napoleon.readthedocs.org/en/latest/)
+Comments should follow the `google code style <http://google.github.io/styleguide/pyguide.html?showone=Comments#Comments>`_.
+This is so that we can generate documentation with
+`sphinx <http://sphinxcontrib-napoleon.readthedocs.org/en/latest/>`_. See the
+`examples <http://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html>`_
+in the sphinx documentation.
Code should pass pep8 --max-line-length=100 without any warnings.
diff --git a/docs/turn-howto.rst b/docs/turn-howto.rst
index e2c73458e2..04c0100715 100644
--- a/docs/turn-howto.rst
+++ b/docs/turn-howto.rst
@@ -9,31 +9,35 @@ the Home Server to generate credentials that are valid for use on the TURN
server through the use of a secret shared between the Home Server and the
TURN server.
-This document described how to install coturn
-(https://code.google.com/p/coturn/) which also supports the TURN REST API,
+This document describes how to install coturn
+(https://github.com/coturn/coturn) which also supports the TURN REST API,
and integrate it with synapse.
coturn Setup
============
+You may be able to setup coturn via your package manager, or set it up manually using the usual ``configure, make, make install`` process.
+
1. Check out coturn::
- svn checkout http://coturn.googlecode.com/svn/trunk/ coturn
+
+ git clone https://github.com/coturn/coturn.git coturn
cd coturn
2. Configure it::
+
./configure
- You may need to install libevent2: if so, you should do so
+ You may need to install ``libevent2``: if so, you should do so
in the way recommended by your operating system.
You can ignore warnings about lack of database support: a
database is unnecessary for this purpose.
3. Build and install it::
+
make
make install
- 4. Make a config file in /etc/turnserver.conf. You can customise
- a config file from turnserver.conf.default. The relevant
+ 4. Create or edit the config file in ``/etc/turnserver.conf``. The relevant
lines, with example values, are::
lt-cred-mech
@@ -41,7 +45,7 @@ coturn Setup
static-auth-secret=[your secret key here]
realm=turn.myserver.org
- See turnserver.conf.default for explanations of the options.
+ See turnserver.conf for explanations of the options.
One way to generate the static-auth-secret is with pwgen::
pwgen -s 64 1
@@ -54,6 +58,7 @@ coturn Setup
import your private key and certificate.
7. Start the turn server::
+
bin/turnserver -o
diff --git a/jenkins-dendron-postgres.sh b/jenkins-dendron-postgres.sh
index 7e6f24aa7d..68912a8967 100755
--- a/jenkins-dendron-postgres.sh
+++ b/jenkins-dendron-postgres.sh
@@ -4,83 +4,19 @@ set -eux
: ${WORKSPACE:="$(pwd)"}
+export WORKSPACE
export PYTHONDONTWRITEBYTECODE=yep
export SYNAPSE_CACHE_FACTOR=1
-# Output test results as junit xml
-export TRIAL_FLAGS="--reporter=subunit"
-export TOXSUFFIX="| subunit-1to2 | subunit2junitxml --no-passthrough --output-to=results.xml"
-# Write coverage reports to a separate file for each process
-export COVERAGE_OPTS="-p"
-export DUMP_COVERAGE_COMMAND="coverage help"
-
-# Output flake8 violations to violations.flake8.log
-# Don't exit with non-0 status code on Jenkins,
-# so that the build steps continue and a later step can decided whether to
-# UNSTABLE or FAILURE this build.
-export PEP8SUFFIX="--output-file=violations.flake8.log || echo flake8 finished with status code \$?"
-
-rm .coverage* || echo "No coverage files to remove"
-
-tox --notest -e py27
-
-TOX_BIN=$WORKSPACE/.tox/py27/bin
-python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install
-$TOX_BIN/pip install psycopg2
-$TOX_BIN/pip install lxml
-
-: ${GIT_BRANCH:="origin/$(git rev-parse --abbrev-ref HEAD)"}
-
-if [[ ! -e .dendron-base ]]; then
- git clone https://github.com/matrix-org/dendron.git .dendron-base --mirror
-else
- (cd .dendron-base; git fetch -p)
-fi
-
-rm -rf dendron
-git clone .dendron-base dendron --shared
-cd dendron
-
-: ${GOPATH:=${WORKSPACE}/.gopath}
-if [[ "${GOPATH}" != *:* ]]; then
- mkdir -p "${GOPATH}"
- export PATH="${GOPATH}/bin:${PATH}"
-fi
-export GOPATH
-
-git checkout "${GIT_BRANCH}" || (echo >&2 "No ref ${GIT_BRANCH} found, falling back to develop" ; git checkout develop)
-
-go get github.com/constabulary/gb/...
-gb generate
-gb build
-
-cd ..
-
-
-if [[ ! -e .sytest-base ]]; then
- git clone https://github.com/matrix-org/sytest.git .sytest-base --mirror
-else
- (cd .sytest-base; git fetch -p)
-fi
-
-rm -rf sytest
-git clone .sytest-base sytest --shared
-cd sytest
-
-git checkout "${GIT_BRANCH}" || (echo >&2 "No ref ${GIT_BRANCH} found, falling back to develop" ; git checkout develop)
-
-: ${PORT_BASE:=8000}
-
-./jenkins/prep_sytest_for_postgres.sh
-
-mkdir -p var
-
-echo >&2 "Running sytest with PostgreSQL";
-./jenkins/install_and_run.sh --python $TOX_BIN/python \
- --synapse-directory $WORKSPACE \
- --dendron $WORKSPACE/dendron/bin/dendron \
- --pusher \
- --synchrotron \
- --port-base $PORT_BASE
-
-cd ..
+./jenkins/prepare_synapse.sh
+./jenkins/clone.sh sytest https://github.com/matrix-org/sytest.git
+./jenkins/clone.sh dendron https://github.com/matrix-org/dendron.git
+./dendron/jenkins/build_dendron.sh
+./sytest/jenkins/prep_sytest_for_postgres.sh
+
+./sytest/jenkins/install_and_run.sh \
+ --synapse-directory $WORKSPACE \
+ --dendron $WORKSPACE/dendron/bin/dendron \
+ --pusher \
+ --synchrotron \
+ --federation-reader \
diff --git a/jenkins-postgres.sh b/jenkins-postgres.sh
index ae6b111591..f2ca8ccdff 100755
--- a/jenkins-postgres.sh
+++ b/jenkins-postgres.sh
@@ -4,60 +4,14 @@ set -eux
: ${WORKSPACE:="$(pwd)"}
+export WORKSPACE
export PYTHONDONTWRITEBYTECODE=yep
export SYNAPSE_CACHE_FACTOR=1
-# Output test results as junit xml
-export TRIAL_FLAGS="--reporter=subunit"
-export TOXSUFFIX="| subunit-1to2 | subunit2junitxml --no-passthrough --output-to=results.xml"
-# Write coverage reports to a separate file for each process
-export COVERAGE_OPTS="-p"
-export DUMP_COVERAGE_COMMAND="coverage help"
+./jenkins/prepare_synapse.sh
+./jenkins/clone.sh sytest https://github.com/matrix-org/sytest.git
-# Output flake8 violations to violations.flake8.log
-# Don't exit with non-0 status code on Jenkins,
-# so that the build steps continue and a later step can decided whether to
-# UNSTABLE or FAILURE this build.
-export PEP8SUFFIX="--output-file=violations.flake8.log || echo flake8 finished with status code \$?"
+./sytest/jenkins/prep_sytest_for_postgres.sh
-rm .coverage* || echo "No coverage files to remove"
-
-tox --notest -e py27
-
-TOX_BIN=$WORKSPACE/.tox/py27/bin
-python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install
-$TOX_BIN/pip install psycopg2
-$TOX_BIN/pip install lxml
-
-: ${GIT_BRANCH:="origin/$(git rev-parse --abbrev-ref HEAD)"}
-
-if [[ ! -e .sytest-base ]]; then
- git clone https://github.com/matrix-org/sytest.git .sytest-base --mirror
-else
- (cd .sytest-base; git fetch -p)
-fi
-
-rm -rf sytest
-git clone .sytest-base sytest --shared
-cd sytest
-
-git checkout "${GIT_BRANCH}" || (echo >&2 "No ref ${GIT_BRANCH} found, falling back to develop" ; git checkout develop)
-
-: ${PORT_BASE:=8000}
-
-./jenkins/prep_sytest_for_postgres.sh
-
-echo >&2 "Running sytest with PostgreSQL";
-./jenkins/install_and_run.sh --coverage \
- --python $TOX_BIN/python \
- --synapse-directory $WORKSPACE \
- --port-base $PORT_BASE
-
-cd ..
-cp sytest/.coverage.* .
-
-# Combine the coverage reports
-echo "Combining:" .coverage.*
-$TOX_BIN/python -m coverage combine
-# Output coverage to coverage.xml
-$TOX_BIN/coverage xml -o coverage.xml
+./sytest/jenkins/install_and_run.sh \
+ --synapse-directory $WORKSPACE \
diff --git a/jenkins-sqlite.sh b/jenkins-sqlite.sh
index 9398d9db15..84613d979c 100755
--- a/jenkins-sqlite.sh
+++ b/jenkins-sqlite.sh
@@ -4,54 +4,12 @@ set -eux
: ${WORKSPACE:="$(pwd)"}
+export WORKSPACE
export PYTHONDONTWRITEBYTECODE=yep
export SYNAPSE_CACHE_FACTOR=1
-# Output test results as junit xml
-export TRIAL_FLAGS="--reporter=subunit"
-export TOXSUFFIX="| subunit-1to2 | subunit2junitxml --no-passthrough --output-to=results.xml"
-# Write coverage reports to a separate file for each process
-export COVERAGE_OPTS="-p"
-export DUMP_COVERAGE_COMMAND="coverage help"
+./jenkins/prepare_synapse.sh
+./jenkins/clone.sh sytest https://github.com/matrix-org/sytest.git
-# Output flake8 violations to violations.flake8.log
-# Don't exit with non-0 status code on Jenkins,
-# so that the build steps continue and a later step can decided whether to
-# UNSTABLE or FAILURE this build.
-export PEP8SUFFIX="--output-file=violations.flake8.log || echo flake8 finished with status code \$?"
-
-rm .coverage* || echo "No coverage files to remove"
-
-tox --notest -e py27
-TOX_BIN=$WORKSPACE/.tox/py27/bin
-python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install
-$TOX_BIN/pip install lxml
-
-: ${GIT_BRANCH:="origin/$(git rev-parse --abbrev-ref HEAD)"}
-
-if [[ ! -e .sytest-base ]]; then
- git clone https://github.com/matrix-org/sytest.git .sytest-base --mirror
-else
- (cd .sytest-base; git fetch -p)
-fi
-
-rm -rf sytest
-git clone .sytest-base sytest --shared
-cd sytest
-
-git checkout "${GIT_BRANCH}" || (echo >&2 "No ref ${GIT_BRANCH} found, falling back to develop" ; git checkout develop)
-
-: ${PORT_BASE:=8500}
-./jenkins/install_and_run.sh --coverage \
- --python $TOX_BIN/python \
- --synapse-directory $WORKSPACE \
- --port-base $PORT_BASE
-
-cd ..
-cp sytest/.coverage.* .
-
-# Combine the coverage reports
-echo "Combining:" .coverage.*
-$TOX_BIN/python -m coverage combine
-# Output coverage to coverage.xml
-$TOX_BIN/coverage xml -o coverage.xml
+./sytest/jenkins/install_and_run.sh \
+ --synapse-directory $WORKSPACE \
diff --git a/jenkins-unittests.sh b/jenkins-unittests.sh
index 104d511994..6b0c296cff 100755
--- a/jenkins-unittests.sh
+++ b/jenkins-unittests.sh
@@ -22,4 +22,8 @@ export PEP8SUFFIX="--output-file=violations.flake8.log || echo flake8 finished w
rm .coverage* || echo "No coverage files to remove"
+tox --notest -e py27
+TOX_BIN=$WORKSPACE/.tox/py27/bin
+python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install
+
tox -e py27
diff --git a/jenkins/clone.sh b/jenkins/clone.sh
new file mode 100755
index 0000000000..ab30ac7782
--- /dev/null
+++ b/jenkins/clone.sh
@@ -0,0 +1,44 @@
+#! /bin/bash
+
+# This clones a project from github into a named subdirectory
+# If the project has a branch with the same name as this branch
+# then it will checkout that branch after cloning.
+# Otherwise it will checkout "origin/develop."
+# The first argument is the name of the directory to checkout
+# the branch into.
+# The second argument is the URL of the remote repository to checkout.
+# Usually something like https://github.com/matrix-org/sytest.git
+
+set -eux
+
+NAME=$1
+PROJECT=$2
+BASE=".$NAME-base"
+
+# Update our mirror.
+if [ ! -d ".$NAME-base" ]; then
+ # Create a local mirror of the source repository.
+ # This saves us from having to download the entire repository
+ # when this script is next run.
+ git clone "$PROJECT" "$BASE" --mirror
+else
+ # Fetch any updates from the source repository.
+ (cd "$BASE"; git fetch -p)
+fi
+
+# Remove the existing repository so that we have a clean copy
+rm -rf "$NAME"
+# Cloning with --shared means that we will share portions of the
+# .git directory with our local mirror.
+git clone "$BASE" "$NAME" --shared
+
+# Jenkins may have supplied us with the name of the branch in the
+# environment. Otherwise we will have to guess based on the current
+# commit.
+: ${GIT_BRANCH:="origin/$(git rev-parse --abbrev-ref HEAD)"}
+cd "$NAME"
+# check out the relevant branch
+git checkout "${GIT_BRANCH}" || (
+ echo >&2 "No ref ${GIT_BRANCH} found, falling back to develop"
+ git checkout "origin/develop"
+)
diff --git a/jenkins/prepare_synapse.sh b/jenkins/prepare_synapse.sh
new file mode 100755
index 0000000000..237223c81b
--- /dev/null
+++ b/jenkins/prepare_synapse.sh
@@ -0,0 +1,19 @@
+#! /bin/bash
+
+cd "`dirname $0`/.."
+
+TOX_DIR=$WORKSPACE/.tox
+
+mkdir -p $TOX_DIR
+
+if ! [ $TOX_DIR -ef .tox ]; then
+ ln -s "$TOX_DIR" .tox
+fi
+
+# set up the virtualenv
+tox -e py27 --notest -v
+
+TOX_BIN=$TOX_DIR/py27/bin
+python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install
+$TOX_BIN/pip install lxml
+$TOX_BIN/pip install psycopg2
diff --git a/res/templates/notif_mail.html b/res/templates/notif_mail.html
index 8aee68b591..535bea764d 100644
--- a/res/templates/notif_mail.html
+++ b/res/templates/notif_mail.html
@@ -36,7 +36,7 @@
<div class="debug">
Sending email at {{ reason.now|format_ts("%c") }} due to activity in room {{ reason.room_name }} because
an event was received at {{ reason.received_at|format_ts("%c") }}
- which is more than {{ "%.1f"|format(reason.delay_before_mail_ms / (60*1000)) }} (delay_before_mail_ms) mins ago,
+ which is more than {{ "%.1f"|format(reason.delay_before_mail_ms / (60*1000)) }} ({{ reason.delay_before_mail_ms }}) mins ago,
{% if reason.last_sent_ts %}
and the last time we sent a mail for this room was {{ reason.last_sent_ts|format_ts("%c") }},
which is more than {{ "%.1f"|format(reason.throttle_ms / (60*1000)) }} (current throttle_ms) mins ago.
diff --git a/scripts-dev/federation_client.py b/scripts-dev/federation_client.py
index ea62dceb36..d1ab42d3af 100644
--- a/scripts-dev/federation_client.py
+++ b/scripts-dev/federation_client.py
@@ -116,17 +116,19 @@ def get_json(origin_name, origin_key, destination, path):
authorization_headers = []
for key, sig in signed_json["signatures"][origin_name].items():
- authorization_headers.append(bytes(
- "X-Matrix origin=%s,key=\"%s\",sig=\"%s\"" % (
- origin_name, key, sig,
- )
- ))
+ header = "X-Matrix origin=%s,key=\"%s\",sig=\"%s\"" % (
+ origin_name, key, sig,
+ )
+ authorization_headers.append(bytes(header))
+ sys.stderr.write(header)
+ sys.stderr.write("\n")
result = requests.get(
lookup(destination, path),
headers={"Authorization": authorization_headers[0]},
verify=False,
)
+ sys.stderr.write("Status Code: %d\n" % (result.status_code,))
return result.json()
@@ -141,6 +143,7 @@ def main():
)
json.dump(result, sys.stdout)
+ print ""
if __name__ == "__main__":
main()
diff --git a/scripts/hash_password b/scripts/hash_password
index e784600989..215ab25cfe 100755
--- a/scripts/hash_password
+++ b/scripts/hash_password
@@ -1,10 +1,16 @@
#!/usr/bin/env python
import argparse
+
+import sys
+
import bcrypt
import getpass
+import yaml
+
bcrypt_rounds=12
+password_pepper = ""
def prompt_for_pass():
password = getpass.getpass("Password: ")
@@ -28,12 +34,22 @@ if __name__ == "__main__":
default=None,
help="New password for user. Will prompt if omitted.",
)
+ parser.add_argument(
+ "-c", "--config",
+ type=argparse.FileType('r'),
+ help="Path to server config file. Used to read in bcrypt_rounds and password_pepper.",
+ )
args = parser.parse_args()
+ if "config" in args and args.config:
+ config = yaml.safe_load(args.config)
+ bcrypt_rounds = config.get("bcrypt_rounds", bcrypt_rounds)
+ password_config = config.get("password_config", {})
+ password_pepper = password_config.get("pepper", password_pepper)
password = args.password
if not password:
password = prompt_for_pass()
- print bcrypt.hashpw(password, bcrypt.gensalt(bcrypt_rounds))
+ print bcrypt.hashpw(password + password_pepper, bcrypt.gensalt(bcrypt_rounds))
diff --git a/scripts/register_new_matrix_user b/scripts/register_new_matrix_user
index 27a6250b14..12ed20d623 100755
--- a/scripts/register_new_matrix_user
+++ b/scripts/register_new_matrix_user
@@ -25,18 +25,26 @@ import urllib2
import yaml
-def request_registration(user, password, server_location, shared_secret):
+def request_registration(user, password, server_location, shared_secret, admin=False):
mac = hmac.new(
key=shared_secret,
- msg=user,
digestmod=hashlib.sha1,
- ).hexdigest()
+ )
+
+ mac.update(user)
+ mac.update("\x00")
+ mac.update(password)
+ mac.update("\x00")
+ mac.update("admin" if admin else "notadmin")
+
+ mac = mac.hexdigest()
data = {
"user": user,
"password": password,
"mac": mac,
"type": "org.matrix.login.shared_secret",
+ "admin": admin,
}
server_location = server_location.rstrip("/")
@@ -68,7 +76,7 @@ def request_registration(user, password, server_location, shared_secret):
sys.exit(1)
-def register_new_user(user, password, server_location, shared_secret):
+def register_new_user(user, password, server_location, shared_secret, admin):
if not user:
try:
default_user = getpass.getuser()
@@ -99,7 +107,14 @@ def register_new_user(user, password, server_location, shared_secret):
print "Passwords do not match"
sys.exit(1)
- request_registration(user, password, server_location, shared_secret)
+ if not admin:
+ admin = raw_input("Make admin [no]: ")
+ if admin in ("y", "yes", "true"):
+ admin = True
+ else:
+ admin = False
+
+ request_registration(user, password, server_location, shared_secret, bool(admin))
if __name__ == "__main__":
@@ -119,6 +134,11 @@ if __name__ == "__main__":
default=None,
help="New password for user. Will prompt if omitted.",
)
+ parser.add_argument(
+ "-a", "--admin",
+ action="store_true",
+ help="Register new user as an admin. Will prompt if omitted.",
+ )
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument(
@@ -151,4 +171,4 @@ if __name__ == "__main__":
else:
secret = args.shared_secret
- register_new_user(args.user, args.password, args.server_url, secret)
+ register_new_user(args.user, args.password, args.server_url, secret, args.admin)
diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db
index efd04da2d6..66c61b0198 100755
--- a/scripts/synapse_port_db
+++ b/scripts/synapse_port_db
@@ -34,7 +34,7 @@ logger = logging.getLogger("synapse_port_db")
BOOLEAN_COLUMNS = {
- "events": ["processed", "outlier"],
+ "events": ["processed", "outlier", "contains_url"],
"rooms": ["is_public"],
"event_edges": ["is_state"],
"presence_list": ["accepted"],
@@ -92,8 +92,12 @@ class Store(object):
_simple_select_onecol_txn = SQLBaseStore.__dict__["_simple_select_onecol_txn"]
_simple_select_onecol = SQLBaseStore.__dict__["_simple_select_onecol"]
+ _simple_select_one = SQLBaseStore.__dict__["_simple_select_one"]
+ _simple_select_one_txn = SQLBaseStore.__dict__["_simple_select_one_txn"]
_simple_select_one_onecol = SQLBaseStore.__dict__["_simple_select_one_onecol"]
- _simple_select_one_onecol_txn = SQLBaseStore.__dict__["_simple_select_one_onecol_txn"]
+ _simple_select_one_onecol_txn = SQLBaseStore.__dict__[
+ "_simple_select_one_onecol_txn"
+ ]
_simple_update_one = SQLBaseStore.__dict__["_simple_update_one"]
_simple_update_one_txn = SQLBaseStore.__dict__["_simple_update_one_txn"]
@@ -158,31 +162,40 @@ class Porter(object):
def setup_table(self, table):
if table in APPEND_ONLY_TABLES:
# It's safe to just carry on inserting.
- next_chunk = yield self.postgres_store._simple_select_one_onecol(
+ row = yield self.postgres_store._simple_select_one(
table="port_from_sqlite3",
keyvalues={"table_name": table},
- retcol="rowid",
+ retcols=("forward_rowid", "backward_rowid"),
allow_none=True,
)
total_to_port = None
- if next_chunk is None:
+ if row is None:
if table == "sent_transactions":
- next_chunk, already_ported, total_to_port = (
+ forward_chunk, already_ported, total_to_port = (
yield self._setup_sent_transactions()
)
+ backward_chunk = 0
else:
yield self.postgres_store._simple_insert(
table="port_from_sqlite3",
- values={"table_name": table, "rowid": 1}
+ values={
+ "table_name": table,
+ "forward_rowid": 1,
+ "backward_rowid": 0,
+ }
)
- next_chunk = 1
+ forward_chunk = 1
+ backward_chunk = 0
already_ported = 0
+ else:
+ forward_chunk = row["forward_rowid"]
+ backward_chunk = row["backward_rowid"]
if total_to_port is None:
already_ported, total_to_port = yield self._get_total_count_to_port(
- table, next_chunk
+ table, forward_chunk, backward_chunk
)
else:
def delete_all(txn):
@@ -196,46 +209,85 @@ class Porter(object):
yield self.postgres_store._simple_insert(
table="port_from_sqlite3",
- values={"table_name": table, "rowid": 0}
+ values={
+ "table_name": table,
+ "forward_rowid": 1,
+ "backward_rowid": 0,
+ }
)
- next_chunk = 1
+ forward_chunk = 1
+ backward_chunk = 0
already_ported, total_to_port = yield self._get_total_count_to_port(
- table, next_chunk
+ table, forward_chunk, backward_chunk
)
- defer.returnValue((table, already_ported, total_to_port, next_chunk))
+ defer.returnValue(
+ (table, already_ported, total_to_port, forward_chunk, backward_chunk)
+ )
@defer.inlineCallbacks
- def handle_table(self, table, postgres_size, table_size, next_chunk):
+ def handle_table(self, table, postgres_size, table_size, forward_chunk,
+ backward_chunk):
if not table_size:
return
self.progress.add_table(table, postgres_size, table_size)
if table == "event_search":
- yield self.handle_search_table(postgres_size, table_size, next_chunk)
+ yield self.handle_search_table(
+ postgres_size, table_size, forward_chunk, backward_chunk
+ )
return
- select = (
+ forward_select = (
"SELECT rowid, * FROM %s WHERE rowid >= ? ORDER BY rowid LIMIT ?"
% (table,)
)
+ backward_select = (
+ "SELECT rowid, * FROM %s WHERE rowid <= ? ORDER BY rowid LIMIT ?"
+ % (table,)
+ )
+
+ do_forward = [True]
+ do_backward = [True]
+
while True:
def r(txn):
- txn.execute(select, (next_chunk, self.batch_size,))
- rows = txn.fetchall()
- headers = [column[0] for column in txn.description]
+ forward_rows = []
+ backward_rows = []
+ if do_forward[0]:
+ txn.execute(forward_select, (forward_chunk, self.batch_size,))
+ forward_rows = txn.fetchall()
+ if not forward_rows:
+ do_forward[0] = False
+
+ if do_backward[0]:
+ txn.execute(backward_select, (backward_chunk, self.batch_size,))
+ backward_rows = txn.fetchall()
+ if not backward_rows:
+ do_backward[0] = False
+
+ if forward_rows or backward_rows:
+ headers = [column[0] for column in txn.description]
+ else:
+ headers = None
- return headers, rows
+ return headers, forward_rows, backward_rows
- headers, rows = yield self.sqlite_store.runInteraction("select", r)
+ headers, frows, brows = yield self.sqlite_store.runInteraction(
+ "select", r
+ )
- if rows:
- next_chunk = rows[-1][0] + 1
+ if frows or brows:
+ if frows:
+ forward_chunk = max(row[0] for row in frows) + 1
+ if brows:
+ backward_chunk = min(row[0] for row in brows) - 1
+ rows = frows + brows
self._convert_rows(table, headers, rows)
def insert(txn):
@@ -247,7 +299,10 @@ class Porter(object):
txn,
table="port_from_sqlite3",
keyvalues={"table_name": table},
- updatevalues={"rowid": next_chunk},
+ updatevalues={
+ "forward_rowid": forward_chunk,
+ "backward_rowid": backward_chunk,
+ },
)
yield self.postgres_store.execute(insert)
@@ -259,7 +314,8 @@ class Porter(object):
return
@defer.inlineCallbacks
- def handle_search_table(self, postgres_size, table_size, next_chunk):
+ def handle_search_table(self, postgres_size, table_size, forward_chunk,
+ backward_chunk):
select = (
"SELECT es.rowid, es.*, e.origin_server_ts, e.stream_ordering"
" FROM event_search as es"
@@ -270,7 +326,7 @@ class Porter(object):
while True:
def r(txn):
- txn.execute(select, (next_chunk, self.batch_size,))
+ txn.execute(select, (forward_chunk, self.batch_size,))
rows = txn.fetchall()
headers = [column[0] for column in txn.description]
@@ -279,7 +335,7 @@ class Porter(object):
headers, rows = yield self.sqlite_store.runInteraction("select", r)
if rows:
- next_chunk = rows[-1][0] + 1
+ forward_chunk = rows[-1][0] + 1
# We have to treat event_search differently since it has a
# different structure in the two different databases.
@@ -312,7 +368,10 @@ class Porter(object):
txn,
table="port_from_sqlite3",
keyvalues={"table_name": "event_search"},
- updatevalues={"rowid": next_chunk},
+ updatevalues={
+ "forward_rowid": forward_chunk,
+ "backward_rowid": backward_chunk,
+ },
)
yield self.postgres_store.execute(insert)
@@ -324,7 +383,6 @@ class Porter(object):
else:
return
-
def setup_db(self, db_config, database_engine):
db_conn = database_engine.module.connect(
**{
@@ -395,10 +453,32 @@ class Porter(object):
txn.execute(
"CREATE TABLE port_from_sqlite3 ("
" table_name varchar(100) NOT NULL UNIQUE,"
- " rowid bigint NOT NULL"
+ " forward_rowid bigint NOT NULL,"
+ " backward_rowid bigint NOT NULL"
")"
)
+ # The old port script created a table with just a "rowid" column.
+ # We want people to be able to rerun this script from an old port
+ # so that they can pick up any missing events that were not
+ # ported across.
+ def alter_table(txn):
+ txn.execute(
+ "ALTER TABLE IF EXISTS port_from_sqlite3"
+ " RENAME rowid TO forward_rowid"
+ )
+ txn.execute(
+ "ALTER TABLE IF EXISTS port_from_sqlite3"
+ " ADD backward_rowid bigint NOT NULL DEFAULT 0"
+ )
+
+ try:
+ yield self.postgres_store.runInteraction(
+ "alter_table", alter_table
+ )
+ except Exception as e:
+ logger.info("Failed to create port table: %s", e)
+
try:
yield self.postgres_store.runInteraction(
"create_port_table", create_port_table
@@ -458,7 +538,7 @@ class Porter(object):
@defer.inlineCallbacks
def _setup_sent_transactions(self):
# Only save things from the last day
- yesterday = int(time.time()*1000) - 86400000
+ yesterday = int(time.time() * 1000) - 86400000
# And save the max transaction id from each destination
select = (
@@ -514,7 +594,11 @@ class Porter(object):
yield self.postgres_store._simple_insert(
table="port_from_sqlite3",
- values={"table_name": "sent_transactions", "rowid": next_chunk}
+ values={
+ "table_name": "sent_transactions",
+ "forward_rowid": next_chunk,
+ "backward_rowid": 0,
+ }
)
def get_sent_table_size(txn):
@@ -535,13 +619,18 @@ class Porter(object):
defer.returnValue((next_chunk, inserted_rows, total_count))
@defer.inlineCallbacks
- def _get_remaining_count_to_port(self, table, next_chunk):
- rows = yield self.sqlite_store.execute_sql(
+ def _get_remaining_count_to_port(self, table, forward_chunk, backward_chunk):
+ frows = yield self.sqlite_store.execute_sql(
"SELECT count(*) FROM %s WHERE rowid >= ?" % (table,),
- next_chunk,
+ forward_chunk,
)
- defer.returnValue(rows[0][0])
+ brows = yield self.sqlite_store.execute_sql(
+ "SELECT count(*) FROM %s WHERE rowid <= ?" % (table,),
+ backward_chunk,
+ )
+
+ defer.returnValue(frows[0][0] + brows[0][0])
@defer.inlineCallbacks
def _get_already_ported_count(self, table):
@@ -552,10 +641,10 @@ class Porter(object):
defer.returnValue(rows[0][0])
@defer.inlineCallbacks
- def _get_total_count_to_port(self, table, next_chunk):
+ def _get_total_count_to_port(self, table, forward_chunk, backward_chunk):
remaining, done = yield defer.gatherResults(
[
- self._get_remaining_count_to_port(table, next_chunk),
+ self._get_remaining_count_to_port(table, forward_chunk, backward_chunk),
self._get_already_ported_count(table),
],
consumeErrors=True,
@@ -686,7 +775,7 @@ class CursesProgress(Progress):
color = curses.color_pair(2) if perc == 100 else curses.color_pair(1)
self.stdscr.addstr(
- i+2, left_margin + max_len - len(table),
+ i + 2, left_margin + max_len - len(table),
table,
curses.A_BOLD | color,
)
@@ -694,18 +783,18 @@ class CursesProgress(Progress):
size = 20
progress = "[%s%s]" % (
- "#" * int(perc*size/100),
- " " * (size - int(perc*size/100)),
+ "#" * int(perc * size / 100),
+ " " * (size - int(perc * size / 100)),
)
self.stdscr.addstr(
- i+2, left_margin + max_len + middle_space,
+ i + 2, left_margin + max_len + middle_space,
"%s %3d%% (%d/%d)" % (progress, perc, data["num_done"], data["total"]),
)
if self.finished:
self.stdscr.addstr(
- rows-1, 0,
+ rows - 1, 0,
"Press any key to exit...",
)
diff --git a/setup.cfg b/setup.cfg
index 5ebce1c56b..da8eafbb39 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -16,7 +16,5 @@ ignore =
[flake8]
max-line-length = 90
-ignore = W503 ; W503 requires that binary operators be at the end, not start, of lines. Erik doesn't like it.
-
-[pep8]
-max-line-length = 90
+# W503 requires that binary operators be at the end, not start, of lines. Erik doesn't like it.
+ignore = W503
diff --git a/synapse/__init__.py b/synapse/__init__.py
index 2750ad3f7a..a63ee565cf 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -16,4 +16,4 @@
""" This is a reference implementation of a Matrix home server.
"""
-__version__ = "0.16.1-r1"
+__version__ = "0.17.0"
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index a4d658a9d0..59db76debc 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -13,22 +13,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
+
+import pymacaroons
from canonicaljson import encode_canonical_json
from signedjson.key import decode_verify_key_bytes
from signedjson.sign import verify_signed_json, SignatureVerifyException
-
from twisted.internet import defer
+from unpaddedbase64 import decode_base64
+import synapse.types
from synapse.api.constants import EventTypes, Membership, JoinRules
from synapse.api.errors import AuthError, Codes, SynapseError, EventSizeError
-from synapse.types import Requester, UserID, get_domain_from_id
-from synapse.util.logutils import log_function
+from synapse.types import UserID, get_domain_from_id
from synapse.util.logcontext import preserve_context_over_fn
+from synapse.util.logutils import log_function
from synapse.util.metrics import Measure
-from unpaddedbase64 import decode_base64
-
-import logging
-import pymacaroons
logger = logging.getLogger(__name__)
@@ -63,7 +63,7 @@ class Auth(object):
"user_id = ",
])
- def check(self, event, auth_events):
+ def check(self, event, auth_events, do_sig_check=True):
""" Checks if this event is correctly authed.
Args:
@@ -79,6 +79,13 @@ class Auth(object):
if not hasattr(event, "room_id"):
raise AuthError(500, "Event has no room_id: %s" % event)
+
+ sender_domain = get_domain_from_id(event.sender)
+
+ # Check the sender's domain has signed the event
+ if do_sig_check and not event.signatures.get(sender_domain):
+ raise AuthError(403, "Event not signed by sending server")
+
if auth_events is None:
# Oh, we don't know what the state of the room was, so we
# are trusting that this is allowed (at least for now)
@@ -86,6 +93,12 @@ class Auth(object):
return True
if event.type == EventTypes.Create:
+ room_id_domain = get_domain_from_id(event.room_id)
+ if room_id_domain != sender_domain:
+ raise AuthError(
+ 403,
+ "Creation event's room_id domain does not match sender's"
+ )
# FIXME
return True
@@ -108,6 +121,22 @@ class Auth(object):
# FIXME: Temp hack
if event.type == EventTypes.Aliases:
+ if not event.is_state():
+ raise AuthError(
+ 403,
+ "Alias event must be a state event",
+ )
+ if not event.state_key:
+ raise AuthError(
+ 403,
+ "Alias event must have non-empty state_key"
+ )
+ sender_domain = get_domain_from_id(event.sender)
+ if event.state_key != sender_domain:
+ raise AuthError(
+ 403,
+ "Alias event's state_key does not match sender's domain"
+ )
return True
logger.debug(
@@ -347,6 +376,10 @@ class Auth(object):
if Membership.INVITE == membership and "third_party_invite" in event.content:
if not self._verify_third_party_invite(event, auth_events):
raise AuthError(403, "You are not invited to this room.")
+ if target_banned:
+ raise AuthError(
+ 403, "%s is banned from the room" % (target_user_id,)
+ )
return True
if Membership.JOIN != membership:
@@ -537,9 +570,7 @@ class Auth(object):
Args:
request - An HTTP request with an access_token query parameter.
Returns:
- tuple of:
- UserID (str)
- Access token ID (str)
+ defer.Deferred: resolves to a ``synapse.types.Requester`` object
Raises:
AuthError if no user by that token exists or the token is invalid.
"""
@@ -548,9 +579,7 @@ class Auth(object):
user_id = yield self._get_appservice_user_id(request.args)
if user_id:
request.authenticated_entity = user_id
- defer.returnValue(
- Requester(UserID.from_string(user_id), "", False)
- )
+ defer.returnValue(synapse.types.create_requester(user_id))
access_token = request.args["access_token"][0]
user_info = yield self.get_user_by_access_token(access_token, rights)
@@ -558,6 +587,10 @@ class Auth(object):
token_id = user_info["token_id"]
is_guest = user_info["is_guest"]
+ # device_id may not be present if get_user_by_access_token has been
+ # stubbed out.
+ device_id = user_info.get("device_id")
+
ip_addr = self.hs.get_ip_from_request(request)
user_agent = request.requestHeaders.getRawHeaders(
"User-Agent",
@@ -569,7 +602,8 @@ class Auth(object):
user=user,
access_token=access_token,
ip=ip_addr,
- user_agent=user_agent
+ user_agent=user_agent,
+ device_id=device_id,
)
if is_guest and not allow_guest:
@@ -579,7 +613,8 @@ class Auth(object):
request.authenticated_entity = user.to_string()
- defer.returnValue(Requester(user, token_id, is_guest))
+ defer.returnValue(synapse.types.create_requester(
+ user, token_id, is_guest, device_id))
except KeyError:
raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.",
@@ -629,7 +664,10 @@ class Auth(object):
except AuthError:
# TODO(daniel): Remove this fallback when all existing access tokens
# have been re-issued as macaroons.
+ if self.hs.config.expire_access_token:
+ raise
ret = yield self._look_up_user_by_access_token(token)
+
defer.returnValue(ret)
@defer.inlineCallbacks
@@ -664,6 +702,7 @@ class Auth(object):
"user": user,
"is_guest": True,
"token_id": None,
+ "device_id": None,
}
elif rights == "delete_pusher":
# We don't store these tokens in the database
@@ -671,13 +710,20 @@ class Auth(object):
"user": user,
"is_guest": False,
"token_id": None,
+ "device_id": None,
}
else:
- # This codepath exists so that we can actually return a
- # token ID, because we use token IDs in place of device
- # identifiers throughout the codebase.
- # TODO(daniel): Remove this fallback when device IDs are
- # properly implemented.
+ # This codepath exists for several reasons:
+ # * so that we can actually return a token ID, which is used
+ # in some parts of the schema (where we probably ought to
+ # use device IDs instead)
+ # * the only way we currently have to invalidate an
+ # access_token is by removing it from the database, so we
+ # have to check here that it is still in the db
+ # * some attributes (notably device_id) aren't stored in the
+ # macaroon. They probably should be.
+ # TODO: build the dictionary from the macaroon once the
+ # above are fixed
ret = yield self._look_up_user_by_access_token(macaroon_str)
if ret["user"] != user:
logger.error(
@@ -751,10 +797,14 @@ class Auth(object):
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Unrecognised access token.",
errcode=Codes.UNKNOWN_TOKEN
)
+ # we use ret.get() below because *lots* of unit tests stub out
+ # get_user_by_access_token in a way where it only returns a couple of
+ # the fields.
user_info = {
"user": UserID.from_string(ret.get("name")),
"token_id": ret.get("token_id", None),
"is_guest": False,
+ "device_id": ret.get("device_id"),
}
defer.returnValue(user_info)
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index b106fbed6d..0041646858 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -42,8 +42,10 @@ class Codes(object):
TOO_LARGE = "M_TOO_LARGE"
EXCLUSIVE = "M_EXCLUSIVE"
THREEPID_AUTH_FAILED = "M_THREEPID_AUTH_FAILED"
- THREEPID_IN_USE = "THREEPID_IN_USE"
+ THREEPID_IN_USE = "M_THREEPID_IN_USE"
+ THREEPID_NOT_FOUND = "M_THREEPID_NOT_FOUND"
INVALID_USERNAME = "M_INVALID_USERNAME"
+ SERVER_NOT_TRUSTED = "M_SERVER_NOT_TRUSTED"
class CodeMessageException(RuntimeError):
diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py
index 4f5a4281fa..3b3ef70750 100644
--- a/synapse/api/filtering.py
+++ b/synapse/api/filtering.py
@@ -191,6 +191,17 @@ class Filter(object):
def __init__(self, filter_json):
self.filter_json = filter_json
+ self.types = self.filter_json.get("types", None)
+ self.not_types = self.filter_json.get("not_types", [])
+
+ self.rooms = self.filter_json.get("rooms", None)
+ self.not_rooms = self.filter_json.get("not_rooms", [])
+
+ self.senders = self.filter_json.get("senders", None)
+ self.not_senders = self.filter_json.get("not_senders", [])
+
+ self.contains_url = self.filter_json.get("contains_url", None)
+
def check(self, event):
"""Checks whether the filter matches the given event.
@@ -209,9 +220,10 @@ class Filter(object):
event.get("room_id", None),
sender,
event.get("type", None),
+ "url" in event.get("content", {})
)
- def check_fields(self, room_id, sender, event_type):
+ def check_fields(self, room_id, sender, event_type, contains_url):
"""Checks whether the filter matches the given event fields.
Returns:
@@ -225,15 +237,20 @@ class Filter(object):
for name, match_func in literal_keys.items():
not_name = "not_%s" % (name,)
- disallowed_values = self.filter_json.get(not_name, [])
+ disallowed_values = getattr(self, not_name)
if any(map(match_func, disallowed_values)):
return False
- allowed_values = self.filter_json.get(name, None)
+ allowed_values = getattr(self, name)
if allowed_values is not None:
if not any(map(match_func, allowed_values)):
return False
+ contains_url_filter = self.filter_json.get("contains_url")
+ if contains_url_filter is not None:
+ if contains_url_filter != contains_url:
+ return False
+
return True
def filter_rooms(self, room_ids):
diff --git a/synapse/app/__init__.py b/synapse/app/__init__.py
index 1bc4279807..9c2b627590 100644
--- a/synapse/app/__init__.py
+++ b/synapse/app/__init__.py
@@ -16,13 +16,11 @@
import sys
sys.dont_write_bytecode = True
-from synapse.python_dependencies import (
- check_requirements, MissingRequirementError
-) # NOQA
+from synapse import python_dependencies # noqa: E402
try:
- check_requirements()
-except MissingRequirementError as e:
+ python_dependencies.check_requirements()
+except python_dependencies.MissingRequirementError as e:
message = "\n".join([
"Missing Requirement: %s" % (e.message,),
"To install run:",
diff --git a/synapse/app/federation_reader.py b/synapse/app/federation_reader.py
new file mode 100644
index 0000000000..7355499ae2
--- /dev/null
+++ b/synapse/app/federation_reader.py
@@ -0,0 +1,206 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import synapse
+
+from synapse.config._base import ConfigError
+from synapse.config.homeserver import HomeServerConfig
+from synapse.config.logger import setup_logging
+from synapse.http.site import SynapseSite
+from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
+from synapse.replication.slave.storage._base import BaseSlavedStore
+from synapse.replication.slave.storage.events import SlavedEventStore
+from synapse.replication.slave.storage.keys import SlavedKeyStore
+from synapse.replication.slave.storage.room import RoomStore
+from synapse.replication.slave.storage.transactions import TransactionStore
+from synapse.replication.slave.storage.directory import DirectoryStore
+from synapse.server import HomeServer
+from synapse.storage.engines import create_engine
+from synapse.util.async import sleep
+from synapse.util.httpresourcetree import create_resource_tree
+from synapse.util.logcontext import LoggingContext
+from synapse.util.manhole import manhole
+from synapse.util.rlimit import change_resource_limit
+from synapse.util.versionstring import get_version_string
+from synapse.api.urls import FEDERATION_PREFIX
+from synapse.federation.transport.server import TransportLayerServer
+from synapse.crypto import context_factory
+
+
+from twisted.internet import reactor, defer
+from twisted.web.resource import Resource
+
+from daemonize import Daemonize
+
+import sys
+import logging
+import gc
+
+logger = logging.getLogger("synapse.app.federation_reader")
+
+
+class FederationReaderSlavedStore(
+ SlavedEventStore,
+ SlavedKeyStore,
+ RoomStore,
+ DirectoryStore,
+ TransactionStore,
+ BaseSlavedStore,
+):
+ pass
+
+
+class FederationReaderServer(HomeServer):
+ def get_db_conn(self, run_new_connection=True):
+ # Any param beginning with cp_ is a parameter for adbapi, and should
+ # not be passed to the database engine.
+ db_params = {
+ k: v for k, v in self.db_config.get("args", {}).items()
+ if not k.startswith("cp_")
+ }
+ db_conn = self.database_engine.module.connect(**db_params)
+
+ if run_new_connection:
+ self.database_engine.on_new_connection(db_conn)
+ return db_conn
+
+ def setup(self):
+ logger.info("Setting up.")
+ self.datastore = FederationReaderSlavedStore(self.get_db_conn(), self)
+ logger.info("Finished setting up.")
+
+ def _listen_http(self, listener_config):
+ port = listener_config["port"]
+ bind_address = listener_config.get("bind_address", "")
+ site_tag = listener_config.get("tag", port)
+ resources = {}
+ for res in listener_config["resources"]:
+ for name in res["names"]:
+ if name == "metrics":
+ resources[METRICS_PREFIX] = MetricsResource(self)
+ elif name == "federation":
+ resources.update({
+ FEDERATION_PREFIX: TransportLayerServer(self),
+ })
+
+ root_resource = create_resource_tree(resources, Resource())
+ reactor.listenTCP(
+ port,
+ SynapseSite(
+ "synapse.access.http.%s" % (site_tag,),
+ site_tag,
+ listener_config,
+ root_resource,
+ ),
+ interface=bind_address
+ )
+ logger.info("Synapse federation reader now listening on port %d", port)
+
+ def start_listening(self, listeners):
+ for listener in listeners:
+ if listener["type"] == "http":
+ self._listen_http(listener)
+ elif listener["type"] == "manhole":
+ reactor.listenTCP(
+ listener["port"],
+ manhole(
+ username="matrix",
+ password="rabbithole",
+ globals={"hs": self},
+ ),
+ interface=listener.get("bind_address", '127.0.0.1')
+ )
+ else:
+ logger.warn("Unrecognized listener type: %s", listener["type"])
+
+ @defer.inlineCallbacks
+ def replicate(self):
+ http_client = self.get_simple_http_client()
+ store = self.get_datastore()
+ replication_url = self.config.worker_replication_url
+
+ while True:
+ try:
+ args = store.stream_positions()
+ args["timeout"] = 30000
+ result = yield http_client.get_json(replication_url, args=args)
+ yield store.process_replication(result)
+ except:
+ logger.exception("Error replicating from %r", replication_url)
+ yield sleep(5)
+
+
+def start(config_options):
+ try:
+ config = HomeServerConfig.load_config(
+ "Synapse federation reader", config_options
+ )
+ except ConfigError as e:
+ sys.stderr.write("\n" + e.message + "\n")
+ sys.exit(1)
+
+ assert config.worker_app == "synapse.app.federation_reader"
+
+ setup_logging(config.worker_log_config, config.worker_log_file)
+
+ database_engine = create_engine(config.database_config)
+
+ tls_server_context_factory = context_factory.ServerContextFactory(config)
+
+ ss = FederationReaderServer(
+ config.server_name,
+ db_config=config.database_config,
+ tls_server_context_factory=tls_server_context_factory,
+ config=config,
+ version_string="Synapse/" + get_version_string(synapse),
+ database_engine=database_engine,
+ )
+
+ ss.setup()
+ ss.get_handlers()
+ ss.start_listening(config.worker_listeners)
+
+ def run():
+ with LoggingContext("run"):
+ logger.info("Running")
+ change_resource_limit(config.soft_file_limit)
+ if config.gc_thresholds:
+ gc.set_threshold(*config.gc_thresholds)
+ reactor.run()
+
+ def start():
+ ss.get_datastore().start_profiling()
+ ss.replicate()
+
+ reactor.callWhenRunning(start)
+
+ if config.worker_daemonize:
+ daemon = Daemonize(
+ app="synapse-federation-reader",
+ pid=config.worker_pid_file,
+ action=run,
+ auto_close_fds=False,
+ verbose=True,
+ logger=logger,
+ )
+ daemon.start()
+ else:
+ run()
+
+
+if __name__ == '__main__':
+ with LoggingContext("main"):
+ start(sys.argv[1:])
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 40ffd9bf0d..40e6f65236 100755
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -51,6 +51,7 @@ from synapse.api.urls import (
from synapse.config.homeserver import HomeServerConfig
from synapse.crypto import context_factory
from synapse.util.logcontext import LoggingContext
+from synapse.metrics import register_memory_metrics
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
from synapse.replication.resource import ReplicationResource, REPLICATION_PREFIX
from synapse.federation.transport.server import TransportLayerServer
@@ -147,7 +148,7 @@ class SynapseHomeServer(HomeServer):
MEDIA_PREFIX: media_repo,
LEGACY_MEDIA_PREFIX: media_repo,
CONTENT_REPO_PREFIX: ContentRepoResource(
- self, self.config.uploads_path, self.auth, self.content_addr
+ self, self.config.uploads_path
),
})
@@ -284,7 +285,7 @@ def setup(config_options):
# check any extra requirements we have now we have a config
check_requirements(config)
- version_string = get_version_string("Synapse", synapse)
+ version_string = "Synapse/" + get_version_string(synapse)
logger.info("Server hostname: %s", config.server_name)
logger.info("Server version: %s", version_string)
@@ -301,7 +302,6 @@ def setup(config_options):
db_config=config.database_config,
tls_server_context_factory=tls_server_context_factory,
config=config,
- content_addr=config.content_addr,
version_string=version_string,
database_engine=database_engine,
)
@@ -336,6 +336,8 @@ def setup(config_options):
hs.get_datastore().start_doing_background_updates()
hs.get_replication_layer().start_get_pdu_cache()
+ register_memory_metrics(hs)
+
reactor.callWhenRunning(start)
return hs
diff --git a/synapse/app/pusher.py b/synapse/app/pusher.py
index 4f1d18ab5f..c8dde0fcb8 100644
--- a/synapse/app/pusher.py
+++ b/synapse/app/pusher.py
@@ -273,7 +273,7 @@ def start(config_options):
config.server_name,
db_config=config.database_config,
config=config,
- version_string=get_version_string("Synapse", synapse),
+ version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine,
)
diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py
index 8cf5bbbb6d..215ccfd522 100644
--- a/synapse/app/synchrotron.py
+++ b/synapse/app/synchrotron.py
@@ -424,7 +424,7 @@ def start(config_options):
config.server_name,
db_config=config.database_config,
config=config,
- version_string=get_version_string("Synapse", synapse),
+ version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine,
application_service_handler=SynchrotronApplicationService(),
)
diff --git a/synapse/config/ldap.py b/synapse/config/ldap.py
index 9c14593a99..d83c2230be 100644
--- a/synapse/config/ldap.py
+++ b/synapse/config/ldap.py
@@ -13,40 +13,88 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import Config
+from ._base import Config, ConfigError
+
+
+MISSING_LDAP3 = (
+ "Missing ldap3 library. This is required for LDAP Authentication."
+)
+
+
+class LDAPMode(object):
+ SIMPLE = "simple",
+ SEARCH = "search",
+
+ LIST = (SIMPLE, SEARCH)
class LDAPConfig(Config):
def read_config(self, config):
- ldap_config = config.get("ldap_config", None)
- if ldap_config:
- self.ldap_enabled = ldap_config.get("enabled", False)
- self.ldap_server = ldap_config["server"]
- self.ldap_port = ldap_config["port"]
- self.ldap_tls = ldap_config.get("tls", False)
- self.ldap_search_base = ldap_config["search_base"]
- self.ldap_search_property = ldap_config["search_property"]
- self.ldap_email_property = ldap_config["email_property"]
- self.ldap_full_name_property = ldap_config["full_name_property"]
- else:
- self.ldap_enabled = False
- self.ldap_server = None
- self.ldap_port = None
- self.ldap_tls = False
- self.ldap_search_base = None
- self.ldap_search_property = None
- self.ldap_email_property = None
- self.ldap_full_name_property = None
+ ldap_config = config.get("ldap_config", {})
+
+ self.ldap_enabled = ldap_config.get("enabled", False)
+
+ if self.ldap_enabled:
+ # verify dependencies are available
+ try:
+ import ldap3
+ ldap3 # to stop unused lint
+ except ImportError:
+ raise ConfigError(MISSING_LDAP3)
+
+ self.ldap_mode = LDAPMode.SIMPLE
+
+ # verify config sanity
+ self.require_keys(ldap_config, [
+ "uri",
+ "base",
+ "attributes",
+ ])
+
+ self.ldap_uri = ldap_config["uri"]
+ self.ldap_start_tls = ldap_config.get("start_tls", False)
+ self.ldap_base = ldap_config["base"]
+ self.ldap_attributes = ldap_config["attributes"]
+
+ if "bind_dn" in ldap_config:
+ self.ldap_mode = LDAPMode.SEARCH
+ self.require_keys(ldap_config, [
+ "bind_dn",
+ "bind_password",
+ ])
+
+ self.ldap_bind_dn = ldap_config["bind_dn"]
+ self.ldap_bind_password = ldap_config["bind_password"]
+ self.ldap_filter = ldap_config.get("filter", None)
+
+ # verify attribute lookup
+ self.require_keys(ldap_config['attributes'], [
+ "uid",
+ "name",
+ "mail",
+ ])
+
+ def require_keys(self, config, required):
+ missing = [key for key in required if key not in config]
+ if missing:
+ raise ConfigError(
+ "LDAP enabled but missing required config values: {}".format(
+ ", ".join(missing)
+ )
+ )
def default_config(self, **kwargs):
return """\
# ldap_config:
# enabled: true
- # server: "ldap://localhost"
- # port: 389
- # tls: false
- # search_base: "ou=Users,dc=example,dc=com"
- # search_property: "cn"
- # email_property: "email"
- # full_name_property: "givenName"
+ # uri: "ldap://ldap.example.com:389"
+ # start_tls: true
+ # base: "ou=users,dc=example,dc=com"
+ # attributes:
+ # uid: "cn"
+ # mail: "email"
+ # name: "givenName"
+ # #bind_dn:
+ # #bind_password:
+ # #filter: "(objectClass=posixAccount)"
"""
diff --git a/synapse/config/password.py b/synapse/config/password.py
index dec801ef41..a4bd171399 100644
--- a/synapse/config/password.py
+++ b/synapse/config/password.py
@@ -23,10 +23,14 @@ class PasswordConfig(Config):
def read_config(self, config):
password_config = config.get("password_config", {})
self.password_enabled = password_config.get("enabled", True)
+ self.password_pepper = password_config.get("pepper", "")
def default_config(self, config_dir_path, server_name, **kwargs):
return """
# Enable password for login.
password_config:
enabled: true
+ # Uncomment and change to a secret random string for extra security.
+ # DO NOT CHANGE THIS AFTER INITIAL SETUP!
+ #pepper: ""
"""
diff --git a/synapse/config/server.py b/synapse/config/server.py
index 7840dc3ad6..51eaf423ce 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -107,26 +107,6 @@ class ServerConfig(Config):
]
})
- # Attempt to guess the content_addr for the v0 content repostitory
- content_addr = config.get("content_addr")
- if not content_addr:
- for listener in self.listeners:
- if listener["type"] == "http" and not listener.get("tls", False):
- unsecure_port = listener["port"]
- break
- else:
- raise RuntimeError("Could not determine 'content_addr'")
-
- host = self.server_name
- if ':' not in host:
- host = "%s:%d" % (host, unsecure_port)
- else:
- host = host.split(':')[0]
- host = "%s:%d" % (host, unsecure_port)
- content_addr = "http://%s" % (host,)
-
- self.content_addr = content_addr
-
def default_config(self, server_name, **kwargs):
if ":" in server_name:
bind_port = int(server_name.split(":")[1])
@@ -169,7 +149,6 @@ class ServerConfig(Config):
# room directory.
# secondary_directory_servers:
# - matrix.org
- # - vector.im
# List of ports that Synapse should listen on, their purpose and their
# configuration.
diff --git a/synapse/crypto/keyclient.py b/synapse/crypto/keyclient.py
index 54b83da9d8..c2bd64d6c2 100644
--- a/synapse/crypto/keyclient.py
+++ b/synapse/crypto/keyclient.py
@@ -77,10 +77,12 @@ class SynapseKeyClientProtocol(HTTPClient):
def __init__(self):
self.remote_key = defer.Deferred()
self.host = None
+ self._peer = None
def connectionMade(self):
- self.host = self.transport.getHost()
- logger.debug("Connected to %s", self.host)
+ self._peer = self.transport.getPeer()
+ logger.debug("Connected to %s", self._peer)
+
self.sendCommand(b"GET", self.path)
if self.host:
self.sendHeader(b"Host", self.host)
@@ -124,7 +126,10 @@ class SynapseKeyClientProtocol(HTTPClient):
self.timer.cancel()
def on_timeout(self):
- logger.debug("Timeout waiting for response from %s", self.host)
+ logger.debug(
+ "Timeout waiting for response from %s: %s",
+ self.host, self._peer,
+ )
self.errback(IOError("Timeout waiting for response"))
self.transport.abortConnection()
@@ -133,4 +138,5 @@ class SynapseKeyClientFactory(Factory):
def protocol(self):
protocol = SynapseKeyClientProtocol()
protocol.path = self.path
+ protocol.host = self.host
return protocol
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index d08ee0aa91..5012c10ee8 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -44,7 +44,21 @@ import logging
logger = logging.getLogger(__name__)
-KeyGroup = namedtuple("KeyGroup", ("server_name", "group_id", "key_ids"))
+VerifyKeyRequest = namedtuple("VerifyRequest", (
+ "server_name", "key_ids", "json_object", "deferred"
+))
+"""
+A request for a verify key to verify a JSON object.
+
+Attributes:
+ server_name(str): The name of the server to verify against.
+ key_ids(set(str)): The set of key_ids to that could be used to verify the
+ JSON object
+ json_object(dict): The JSON object to verify.
+ deferred(twisted.internet.defer.Deferred):
+ A deferred (server_name, key_id, verify_key) tuple that resolves when
+ a verify key has been fetched
+"""
class Keyring(object):
@@ -74,39 +88,32 @@ class Keyring(object):
list of deferreds indicating success or failure to verify each
json object's signature for the given server_name.
"""
- group_id_to_json = {}
- group_id_to_group = {}
- group_ids = []
-
- next_group_id = 0
- deferreds = {}
+ verify_requests = []
for server_name, json_object in server_and_json:
logger.debug("Verifying for %s", server_name)
- group_id = next_group_id
- next_group_id += 1
- group_ids.append(group_id)
key_ids = signature_ids(json_object, server_name)
if not key_ids:
- deferreds[group_id] = defer.fail(SynapseError(
+ deferred = defer.fail(SynapseError(
400,
"Not signed with a supported algorithm",
Codes.UNAUTHORIZED,
))
else:
- deferreds[group_id] = defer.Deferred()
+ deferred = defer.Deferred()
- group = KeyGroup(server_name, group_id, key_ids)
+ verify_request = VerifyKeyRequest(
+ server_name, key_ids, json_object, deferred
+ )
- group_id_to_group[group_id] = group
- group_id_to_json[group_id] = json_object
+ verify_requests.append(verify_request)
@defer.inlineCallbacks
- def handle_key_deferred(group, deferred):
- server_name = group.server_name
+ def handle_key_deferred(verify_request):
+ server_name = verify_request.server_name
try:
- _, _, key_id, verify_key = yield deferred
+ _, key_id, verify_key = yield verify_request.deferred
except IOError as e:
logger.warn(
"Got IOError when downloading keys for %s: %s %s",
@@ -128,7 +135,7 @@ class Keyring(object):
Codes.UNAUTHORIZED,
)
- json_object = group_id_to_json[group.group_id]
+ json_object = verify_request.json_object
try:
verify_signed_json(json_object, server_name, verify_key)
@@ -157,36 +164,34 @@ class Keyring(object):
# Actually start fetching keys.
wait_on_deferred.addBoth(
- lambda _: self.get_server_verify_keys(group_id_to_group, deferreds)
+ lambda _: self.get_server_verify_keys(verify_requests)
)
# When we've finished fetching all the keys for a given server_name,
# resolve the deferred passed to `wait_for_previous_lookups` so that
# any lookups waiting will proceed.
- server_to_gids = {}
+ server_to_request_ids = {}
- def remove_deferreds(res, server_name, group_id):
- server_to_gids[server_name].discard(group_id)
- if not server_to_gids[server_name]:
+ def remove_deferreds(res, server_name, verify_request):
+ request_id = id(verify_request)
+ server_to_request_ids[server_name].discard(request_id)
+ if not server_to_request_ids[server_name]:
d = server_to_deferred.pop(server_name, None)
if d:
d.callback(None)
return res
- for g_id, deferred in deferreds.items():
- server_name = group_id_to_group[g_id].server_name
- server_to_gids.setdefault(server_name, set()).add(g_id)
- deferred.addBoth(remove_deferreds, server_name, g_id)
+ for verify_request in verify_requests:
+ server_name = verify_request.server_name
+ request_id = id(verify_request)
+ server_to_request_ids.setdefault(server_name, set()).add(request_id)
+ deferred.addBoth(remove_deferreds, server_name, verify_request)
# Pass those keys to handle_key_deferred so that the json object
# signatures can be verified
return [
- preserve_context_over_fn(
- handle_key_deferred,
- group_id_to_group[g_id],
- deferreds[g_id],
- )
- for g_id in group_ids
+ preserve_context_over_fn(handle_key_deferred, verify_request)
+ for verify_request in verify_requests
]
@defer.inlineCallbacks
@@ -220,7 +225,7 @@ class Keyring(object):
d.addBoth(rm, server_name)
- def get_server_verify_keys(self, group_id_to_group, group_id_to_deferred):
+ def get_server_verify_keys(self, verify_requests):
"""Takes a dict of KeyGroups and tries to find at least one key for
each group.
"""
@@ -237,62 +242,64 @@ class Keyring(object):
merged_results = {}
missing_keys = {}
- for group in group_id_to_group.values():
- missing_keys.setdefault(group.server_name, set()).update(
- group.key_ids
+ for verify_request in verify_requests:
+ missing_keys.setdefault(verify_request.server_name, set()).update(
+ verify_request.key_ids
)
for fn in key_fetch_fns:
results = yield fn(missing_keys.items())
merged_results.update(results)
- # We now need to figure out which groups we have keys for
- # and which we don't
- missing_groups = {}
- for group in group_id_to_group.values():
- for key_id in group.key_ids:
- if key_id in merged_results[group.server_name]:
+ # We now need to figure out which verify requests we have keys
+ # for and which we don't
+ missing_keys = {}
+ requests_missing_keys = []
+ for verify_request in verify_requests:
+ server_name = verify_request.server_name
+ result_keys = merged_results[server_name]
+
+ if verify_request.deferred.called:
+ # We've already called this deferred, which probably
+ # means that we've already found a key for it.
+ continue
+
+ for key_id in verify_request.key_ids:
+ if key_id in result_keys:
with PreserveLoggingContext():
- group_id_to_deferred[group.group_id].callback((
- group.group_id,
- group.server_name,
+ verify_request.deferred.callback((
+ server_name,
key_id,
- merged_results[group.server_name][key_id],
+ result_keys[key_id],
))
break
else:
- missing_groups.setdefault(
- group.server_name, []
- ).append(group)
-
- if not missing_groups:
+ # The else block is only reached if the loop above
+ # doesn't break.
+ missing_keys.setdefault(server_name, set()).update(
+ verify_request.key_ids
+ )
+ requests_missing_keys.append(verify_request)
+
+ if not missing_keys:
break
- missing_keys = {
- server_name: set(
- key_id for group in groups for key_id in group.key_ids
- )
- for server_name, groups in missing_groups.items()
- }
-
- for group in missing_groups.values():
- group_id_to_deferred[group.group_id].errback(SynapseError(
+ for verify_request in requests_missing_keys.values():
+ verify_request.deferred.errback(SynapseError(
401,
"No key for %s with id %s" % (
- group.server_name, group.key_ids,
+ verify_request.server_name, verify_request.key_ids,
),
Codes.UNAUTHORIZED,
))
def on_err(err):
- for deferred in group_id_to_deferred.values():
- if not deferred.called:
- deferred.errback(err)
+ for verify_request in verify_requests:
+ if not verify_request.deferred.called:
+ verify_request.deferred.errback(err)
do_iterations().addErrback(on_err)
- return group_id_to_deferred
-
@defer.inlineCallbacks
def get_keys_from_store(self, server_name_and_key_ids):
res = yield defer.gatherResults(
@@ -447,7 +454,7 @@ class Keyring(object):
)
processed_response = yield self.process_v2_response(
- perspective_name, response
+ perspective_name, response, only_from_server=False
)
for server_name, response_keys in processed_response.items():
@@ -527,7 +534,7 @@ class Keyring(object):
@defer.inlineCallbacks
def process_v2_response(self, from_server, response_json,
- requested_ids=[]):
+ requested_ids=[], only_from_server=True):
time_now_ms = self.clock.time_msec()
response_keys = {}
verify_keys = {}
@@ -551,6 +558,13 @@ class Keyring(object):
results = {}
server_name = response_json["server_name"]
+ if only_from_server:
+ if server_name != from_server:
+ raise ValueError(
+ "Expected a response for server %r not %r" % (
+ from_server, server_name
+ )
+ )
for key_id in response_json["signatures"].get(server_name, {}):
if key_id not in response_json["verify_keys"]:
raise ValueError(
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index b06387051c..da95c2ad6d 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -236,9 +236,9 @@ class FederationClient(FederationBase):
# TODO: Rate limit the number of times we try and get the same event.
if self._get_pdu_cache:
- e = self._get_pdu_cache.get(event_id)
- if e:
- defer.returnValue(e)
+ ev = self._get_pdu_cache.get(event_id)
+ if ev:
+ defer.returnValue(ev)
pdu = None
for destination in destinations:
@@ -269,7 +269,7 @@ class FederationClient(FederationBase):
break
- except SynapseError:
+ except SynapseError as e:
logger.info(
"Failed to get PDU %s from %s because %s",
event_id, destination, e,
@@ -314,6 +314,42 @@ class FederationClient(FederationBase):
Deferred: Results in a list of PDUs.
"""
+ try:
+ # First we try and ask for just the IDs, as thats far quicker if
+ # we have most of the state and auth_chain already.
+ # However, this may 404 if the other side has an old synapse.
+ result = yield self.transport_layer.get_room_state_ids(
+ destination, room_id, event_id=event_id,
+ )
+
+ state_event_ids = result["pdu_ids"]
+ auth_event_ids = result.get("auth_chain_ids", [])
+
+ fetched_events, failed_to_fetch = yield self.get_events(
+ [destination], room_id, set(state_event_ids + auth_event_ids)
+ )
+
+ if failed_to_fetch:
+ logger.warn("Failed to get %r", failed_to_fetch)
+
+ event_map = {
+ ev.event_id: ev for ev in fetched_events
+ }
+
+ pdus = [event_map[e_id] for e_id in state_event_ids if e_id in event_map]
+ auth_chain = [
+ event_map[e_id] for e_id in auth_event_ids if e_id in event_map
+ ]
+
+ auth_chain.sort(key=lambda e: e.depth)
+
+ defer.returnValue((pdus, auth_chain))
+ except HttpResponseException as e:
+ if e.code == 400 or e.code == 404:
+ logger.info("Failed to use get_room_state_ids API, falling back")
+ else:
+ raise e
+
result = yield self.transport_layer.get_room_state(
destination, room_id, event_id=event_id,
)
@@ -327,12 +363,26 @@ class FederationClient(FederationBase):
for p in result.get("auth_chain", [])
]
+ seen_events = yield self.store.get_events([
+ ev.event_id for ev in itertools.chain(pdus, auth_chain)
+ ])
+
signed_pdus = yield self._check_sigs_and_hash_and_fetch(
- destination, pdus, outlier=True
+ destination,
+ [p for p in pdus if p.event_id not in seen_events],
+ outlier=True
+ )
+ signed_pdus.extend(
+ seen_events[p.event_id] for p in pdus if p.event_id in seen_events
)
signed_auth = yield self._check_sigs_and_hash_and_fetch(
- destination, auth_chain, outlier=True
+ destination,
+ [p for p in auth_chain if p.event_id not in seen_events],
+ outlier=True
+ )
+ signed_auth.extend(
+ seen_events[p.event_id] for p in auth_chain if p.event_id in seen_events
)
signed_auth.sort(key=lambda e: e.depth)
@@ -340,6 +390,67 @@ class FederationClient(FederationBase):
defer.returnValue((signed_pdus, signed_auth))
@defer.inlineCallbacks
+ def get_events(self, destinations, room_id, event_ids, return_local=True):
+ """Fetch events from some remote destinations, checking if we already
+ have them.
+
+ Args:
+ destinations (list)
+ room_id (str)
+ event_ids (list)
+ return_local (bool): Whether to include events we already have in
+ the DB in the returned list of events
+
+ Returns:
+ Deferred: A deferred resolving to a 2-tuple where the first is a list of
+ events and the second is a list of event ids that we failed to fetch.
+ """
+ if return_local:
+ seen_events = yield self.store.get_events(event_ids)
+ signed_events = seen_events.values()
+ else:
+ seen_events = yield self.store.have_events(event_ids)
+ signed_events = []
+
+ failed_to_fetch = set()
+
+ missing_events = set(event_ids)
+ for k in seen_events:
+ missing_events.discard(k)
+
+ if not missing_events:
+ defer.returnValue((signed_events, failed_to_fetch))
+
+ def random_server_list():
+ srvs = list(destinations)
+ random.shuffle(srvs)
+ return srvs
+
+ batch_size = 20
+ missing_events = list(missing_events)
+ for i in xrange(0, len(missing_events), batch_size):
+ batch = set(missing_events[i:i + batch_size])
+
+ deferreds = [
+ self.get_pdu(
+ destinations=random_server_list(),
+ event_id=e_id,
+ )
+ for e_id in batch
+ ]
+
+ res = yield defer.DeferredList(deferreds, consumeErrors=True)
+ for success, result in res:
+ if success:
+ signed_events.append(result)
+ batch.discard(result.event_id)
+
+ # We removed all events we successfully fetched from `batch`
+ failed_to_fetch.update(batch)
+
+ defer.returnValue((signed_events, failed_to_fetch))
+
+ @defer.inlineCallbacks
@log_function
def get_event_auth(self, destination, room_id, event_id):
res = yield self.transport_layer.get_event_auth(
@@ -414,14 +525,19 @@ class FederationClient(FederationBase):
(destination, self.event_from_pdu_json(pdu_dict))
)
break
- except CodeMessageException:
- raise
+ except CodeMessageException as e:
+ if not 500 <= e.code < 600:
+ raise
+ else:
+ logger.warn(
+ "Failed to make_%s via %s: %s",
+ membership, destination, e.message
+ )
except Exception as e:
logger.warn(
"Failed to make_%s via %s: %s",
membership, destination, e.message
)
- raise
raise RuntimeError("Failed to send to any server.")
@@ -493,8 +609,14 @@ class FederationClient(FederationBase):
"auth_chain": signed_auth,
"origin": destination,
})
- except CodeMessageException:
- raise
+ except CodeMessageException as e:
+ if not 500 <= e.code < 600:
+ raise
+ else:
+ logger.exception(
+ "Failed to send_join via %s: %s",
+ destination, e.message
+ )
except Exception as e:
logger.exception(
"Failed to send_join via %s: %s",
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 2a589524a4..aba19639c7 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -21,10 +21,11 @@ from .units import Transaction, Edu
from synapse.util.async import Linearizer
from synapse.util.logutils import log_function
+from synapse.util.caches.response_cache import ResponseCache
from synapse.events import FrozenEvent
import synapse.metrics
-from synapse.api.errors import FederationError, SynapseError
+from synapse.api.errors import AuthError, FederationError, SynapseError
from synapse.crypto.event_signing import compute_event_signature
@@ -48,7 +49,14 @@ class FederationServer(FederationBase):
def __init__(self, hs):
super(FederationServer, self).__init__(hs)
+ self.auth = hs.get_auth()
+
self._room_pdu_linearizer = Linearizer()
+ self._server_linearizer = Linearizer()
+
+ # We cache responses to state queries, as they take a while and often
+ # come in waves.
+ self._state_resp_cache = ResponseCache(hs, timeout_ms=30000)
def set_handler(self, handler):
"""Sets the handler that the replication layer will use to communicate
@@ -89,11 +97,14 @@ class FederationServer(FederationBase):
@defer.inlineCallbacks
@log_function
def on_backfill_request(self, origin, room_id, versions, limit):
- pdus = yield self.handler.on_backfill_request(
- origin, room_id, versions, limit
- )
+ with (yield self._server_linearizer.queue((origin, room_id))):
+ pdus = yield self.handler.on_backfill_request(
+ origin, room_id, versions, limit
+ )
- defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict()))
+ res = self._transaction_from_pdus(pdus).get_dict()
+
+ defer.returnValue((200, res))
@defer.inlineCallbacks
@log_function
@@ -184,32 +195,71 @@ class FederationServer(FederationBase):
@defer.inlineCallbacks
@log_function
def on_context_state_request(self, origin, room_id, event_id):
- if event_id:
- pdus = yield self.handler.get_state_for_pdu(
- origin, room_id, event_id,
- )
- auth_chain = yield self.store.get_auth_chain(
- [pdu.event_id for pdu in pdus]
- )
+ if not event_id:
+ raise NotImplementedError("Specify an event")
- for event in auth_chain:
- # We sign these again because there was a bug where we
- # incorrectly signed things the first time round
- if self.hs.is_mine_id(event.event_id):
- event.signatures.update(
- compute_event_signature(
- event,
- self.hs.hostname,
- self.hs.config.signing_key[0]
- )
- )
+ in_room = yield self.auth.check_host_in_room(room_id, origin)
+ if not in_room:
+ raise AuthError(403, "Host not in room.")
+
+ result = self._state_resp_cache.get((room_id, event_id))
+ if not result:
+ with (yield self._server_linearizer.queue((origin, room_id))):
+ resp = yield self._state_resp_cache.set(
+ (room_id, event_id),
+ self._on_context_state_request_compute(room_id, event_id)
+ )
else:
+ resp = yield result
+
+ defer.returnValue((200, resp))
+
+ @defer.inlineCallbacks
+ def on_state_ids_request(self, origin, room_id, event_id):
+ if not event_id:
raise NotImplementedError("Specify an event")
+ in_room = yield self.auth.check_host_in_room(room_id, origin)
+ if not in_room:
+ raise AuthError(403, "Host not in room.")
+
+ pdus = yield self.handler.get_state_for_pdu(
+ room_id, event_id,
+ )
+ auth_chain = yield self.store.get_auth_chain(
+ [pdu.event_id for pdu in pdus]
+ )
+
defer.returnValue((200, {
+ "pdu_ids": [pdu.event_id for pdu in pdus],
+ "auth_chain_ids": [pdu.event_id for pdu in auth_chain],
+ }))
+
+ @defer.inlineCallbacks
+ def _on_context_state_request_compute(self, room_id, event_id):
+ pdus = yield self.handler.get_state_for_pdu(
+ room_id, event_id,
+ )
+ auth_chain = yield self.store.get_auth_chain(
+ [pdu.event_id for pdu in pdus]
+ )
+
+ for event in auth_chain:
+ # We sign these again because there was a bug where we
+ # incorrectly signed things the first time round
+ if self.hs.is_mine_id(event.event_id):
+ event.signatures.update(
+ compute_event_signature(
+ event,
+ self.hs.hostname,
+ self.hs.config.signing_key[0]
+ )
+ )
+
+ defer.returnValue({
"pdus": [pdu.get_pdu_json() for pdu in pdus],
"auth_chain": [pdu.get_pdu_json() for pdu in auth_chain],
- }))
+ })
@defer.inlineCallbacks
@log_function
@@ -283,14 +333,16 @@ class FederationServer(FederationBase):
@defer.inlineCallbacks
def on_event_auth(self, origin, room_id, event_id):
- time_now = self._clock.time_msec()
- auth_pdus = yield self.handler.on_event_auth(event_id)
- defer.returnValue((200, {
- "auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus],
- }))
+ with (yield self._server_linearizer.queue((origin, room_id))):
+ time_now = self._clock.time_msec()
+ auth_pdus = yield self.handler.on_event_auth(event_id)
+ res = {
+ "auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus],
+ }
+ defer.returnValue((200, res))
@defer.inlineCallbacks
- def on_query_auth_request(self, origin, content, event_id):
+ def on_query_auth_request(self, origin, content, room_id, event_id):
"""
Content is a dict with keys::
auth_chain (list): A list of events that give the auth chain.
@@ -309,58 +361,41 @@ class FederationServer(FederationBase):
Returns:
Deferred: Results in `dict` with the same format as `content`
"""
- auth_chain = [
- self.event_from_pdu_json(e)
- for e in content["auth_chain"]
- ]
-
- signed_auth = yield self._check_sigs_and_hash_and_fetch(
- origin, auth_chain, outlier=True
- )
+ with (yield self._server_linearizer.queue((origin, room_id))):
+ auth_chain = [
+ self.event_from_pdu_json(e)
+ for e in content["auth_chain"]
+ ]
+
+ signed_auth = yield self._check_sigs_and_hash_and_fetch(
+ origin, auth_chain, outlier=True
+ )
- ret = yield self.handler.on_query_auth(
- origin,
- event_id,
- signed_auth,
- content.get("rejects", []),
- content.get("missing", []),
- )
+ ret = yield self.handler.on_query_auth(
+ origin,
+ event_id,
+ signed_auth,
+ content.get("rejects", []),
+ content.get("missing", []),
+ )
- time_now = self._clock.time_msec()
- send_content = {
- "auth_chain": [
- e.get_pdu_json(time_now)
- for e in ret["auth_chain"]
- ],
- "rejects": ret.get("rejects", []),
- "missing": ret.get("missing", []),
- }
+ time_now = self._clock.time_msec()
+ send_content = {
+ "auth_chain": [
+ e.get_pdu_json(time_now)
+ for e in ret["auth_chain"]
+ ],
+ "rejects": ret.get("rejects", []),
+ "missing": ret.get("missing", []),
+ }
defer.returnValue(
(200, send_content)
)
- @defer.inlineCallbacks
@log_function
def on_query_client_keys(self, origin, content):
- query = []
- for user_id, device_ids in content.get("device_keys", {}).items():
- if not device_ids:
- query.append((user_id, None))
- else:
- for device_id in device_ids:
- query.append((user_id, device_id))
-
- results = yield self.store.get_e2e_device_keys(query)
-
- json_result = {}
- for user_id, device_keys in results.items():
- for device_id, json_bytes in device_keys.items():
- json_result.setdefault(user_id, {})[device_id] = json.loads(
- json_bytes
- )
-
- defer.returnValue({"device_keys": json_result})
+ return self.on_query_request("client_keys", content)
@defer.inlineCallbacks
@log_function
@@ -386,21 +421,24 @@ class FederationServer(FederationBase):
@log_function
def on_get_missing_events(self, origin, room_id, earliest_events,
latest_events, limit, min_depth):
- logger.info(
- "on_get_missing_events: earliest_events: %r, latest_events: %r,"
- " limit: %d, min_depth: %d",
- earliest_events, latest_events, limit, min_depth
- )
- missing_events = yield self.handler.on_get_missing_events(
- origin, room_id, earliest_events, latest_events, limit, min_depth
- )
+ with (yield self._server_linearizer.queue((origin, room_id))):
+ logger.info(
+ "on_get_missing_events: earliest_events: %r, latest_events: %r,"
+ " limit: %d, min_depth: %d",
+ earliest_events, latest_events, limit, min_depth
+ )
+ missing_events = yield self.handler.on_get_missing_events(
+ origin, room_id, earliest_events, latest_events, limit, min_depth
+ )
- if len(missing_events) < 5:
- logger.info("Returning %d events: %r", len(missing_events), missing_events)
- else:
- logger.info("Returning %d events", len(missing_events))
+ if len(missing_events) < 5:
+ logger.info(
+ "Returning %d events: %r", len(missing_events), missing_events
+ )
+ else:
+ logger.info("Returning %d events", len(missing_events))
- time_now = self._clock.time_msec()
+ time_now = self._clock.time_msec()
defer.returnValue({
"events": [ev.get_pdu_json(time_now) for ev in missing_events],
@@ -567,7 +605,7 @@ class FederationServer(FederationBase):
origin, pdu.room_id, pdu.event_id,
)
except:
- logger.warn("Failed to get state for event: %s", pdu.event_id)
+ logger.exception("Failed to get state for event: %s", pdu.event_id)
yield self.handler.on_receive_pdu(
origin,
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index ebb698e278..3d088e43cb 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -55,6 +55,28 @@ class TransportLayerClient(object):
)
@log_function
+ def get_room_state_ids(self, destination, room_id, event_id):
+ """ Requests all state for a given room from the given server at the
+ given event. Returns the state's event_id's
+
+ Args:
+ destination (str): The host name of the remote home server we want
+ to get the state from.
+ context (str): The name of the context we want the state of
+ event_id (str): The event we want the context at.
+
+ Returns:
+ Deferred: Results in a dict received from the remote homeserver.
+ """
+ logger.debug("get_room_state_ids dest=%s, room=%s",
+ destination, room_id)
+
+ path = PREFIX + "/state_ids/%s/" % room_id
+ return self.client.get_json(
+ destination, path=path, args={"event_id": event_id},
+ )
+
+ @log_function
def get_event(self, destination, event_id, timeout=None):
""" Requests the pdu with give id and origin from the given server.
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 8a1965f45a..37c0d4fbc4 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -18,13 +18,14 @@ from twisted.internet import defer
from synapse.api.urls import FEDERATION_PREFIX as PREFIX
from synapse.api.errors import Codes, SynapseError
from synapse.http.server import JsonResource
-from synapse.http.servlet import parse_json_object_from_request, parse_string
+from synapse.http.servlet import parse_json_object_from_request
from synapse.util.ratelimitutils import FederationRateLimiter
+from synapse.util.versionstring import get_version_string
import functools
import logging
-import simplejson as json
import re
+import synapse
logger = logging.getLogger(__name__)
@@ -60,6 +61,16 @@ class TransportLayerServer(JsonResource):
)
+class AuthenticationError(SynapseError):
+ """There was a problem authenticating the request"""
+ pass
+
+
+class NoAuthenticationError(AuthenticationError):
+ """The request had no authentication information"""
+ pass
+
+
class Authenticator(object):
def __init__(self, hs):
self.keyring = hs.get_keyring()
@@ -67,7 +78,7 @@ class Authenticator(object):
# A method just so we can pass 'self' as the authenticator to the Servlets
@defer.inlineCallbacks
- def authenticate_request(self, request):
+ def authenticate_request(self, request, content):
json_request = {
"method": request.method,
"uri": request.uri,
@@ -75,17 +86,10 @@ class Authenticator(object):
"signatures": {},
}
- content = None
- origin = None
+ if content is not None:
+ json_request["content"] = content
- if request.method in ["PUT", "POST"]:
- # TODO: Handle other method types? other content types?
- try:
- content_bytes = request.content.read()
- content = json.loads(content_bytes)
- json_request["content"] = content
- except:
- raise SynapseError(400, "Unable to parse JSON", Codes.BAD_JSON)
+ origin = None
def parse_auth_header(header_str):
try:
@@ -103,14 +107,14 @@ class Authenticator(object):
sig = strip_quotes(param_dict["sig"])
return (origin, key, sig)
except:
- raise SynapseError(
+ raise AuthenticationError(
400, "Malformed Authorization header", Codes.UNAUTHORIZED
)
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
if not auth_headers:
- raise SynapseError(
+ raise NoAuthenticationError(
401, "Missing Authorization headers", Codes.UNAUTHORIZED,
)
@@ -121,7 +125,7 @@ class Authenticator(object):
json_request["signatures"].setdefault(origin, {})[key] = sig
if not json_request["signatures"]:
- raise SynapseError(
+ raise NoAuthenticationError(
401, "Missing Authorization headers", Codes.UNAUTHORIZED,
)
@@ -130,10 +134,12 @@ class Authenticator(object):
logger.info("Request from %s", origin)
request.authenticated_entity = origin
- defer.returnValue((origin, content))
+ defer.returnValue(origin)
class BaseFederationServlet(object):
+ REQUIRE_AUTH = True
+
def __init__(self, handler, authenticator, ratelimiter, server_name,
room_list_handler):
self.handler = handler
@@ -141,29 +147,46 @@ class BaseFederationServlet(object):
self.ratelimiter = ratelimiter
self.room_list_handler = room_list_handler
- def _wrap(self, code):
+ def _wrap(self, func):
authenticator = self.authenticator
ratelimiter = self.ratelimiter
@defer.inlineCallbacks
- @functools.wraps(code)
- def new_code(request, *args, **kwargs):
+ @functools.wraps(func)
+ def new_func(request, *args, **kwargs):
+ content = None
+ if request.method in ["PUT", "POST"]:
+ # TODO: Handle other method types? other content types?
+ content = parse_json_object_from_request(request)
+
try:
- (origin, content) = yield authenticator.authenticate_request(request)
+ origin = yield authenticator.authenticate_request(request, content)
+ except NoAuthenticationError:
+ origin = None
+ if self.REQUIRE_AUTH:
+ logger.exception("authenticate_request failed")
+ raise
+ except:
+ logger.exception("authenticate_request failed")
+ raise
+
+ if origin:
with ratelimiter.ratelimit(origin) as d:
yield d
- response = yield code(
+ response = yield func(
origin, content, request.args, *args, **kwargs
)
- except:
- logger.exception("authenticate_request failed")
- raise
+ else:
+ response = yield func(
+ origin, content, request.args, *args, **kwargs
+ )
+
defer.returnValue(response)
# Extra logic that functools.wraps() doesn't finish
- new_code.__self__ = code.__self__
+ new_func.__self__ = func.__self__
- return new_code
+ return new_func
def register(self, server):
pattern = re.compile("^" + PREFIX + self.PATH + "$")
@@ -271,6 +294,17 @@ class FederationStateServlet(BaseFederationServlet):
)
+class FederationStateIdsServlet(BaseFederationServlet):
+ PATH = "/state_ids/(?P<room_id>[^/]*)/"
+
+ def on_GET(self, origin, content, query, room_id):
+ return self.handler.on_state_ids_request(
+ origin,
+ room_id,
+ query.get("event_id", [None])[0],
+ )
+
+
class FederationBackfillServlet(BaseFederationServlet):
PATH = "/backfill/(?P<context>[^/]*)/"
@@ -367,10 +401,8 @@ class FederationThirdPartyInviteExchangeServlet(BaseFederationServlet):
class FederationClientKeysQueryServlet(BaseFederationServlet):
PATH = "/user/keys/query"
- @defer.inlineCallbacks
def on_POST(self, origin, content, query):
- response = yield self.handler.on_query_client_keys(origin, content)
- defer.returnValue((200, response))
+ return self.handler.on_query_client_keys(origin, content)
class FederationClientKeysClaimServlet(BaseFederationServlet):
@@ -388,7 +420,7 @@ class FederationQueryAuthServlet(BaseFederationServlet):
@defer.inlineCallbacks
def on_POST(self, origin, content, query, context, event_id):
new_content = yield self.handler.on_query_auth_request(
- origin, content, event_id
+ origin, content, context, event_id
)
defer.returnValue((200, new_content))
@@ -420,9 +452,10 @@ class FederationGetMissingEventsServlet(BaseFederationServlet):
class On3pidBindServlet(BaseFederationServlet):
PATH = "/3pid/onbind"
+ REQUIRE_AUTH = False
+
@defer.inlineCallbacks
- def on_POST(self, request):
- content = parse_json_object_from_request(request)
+ def on_POST(self, origin, content, query):
if "invites" in content:
last_exception = None
for invite in content["invites"]:
@@ -444,11 +477,6 @@ class On3pidBindServlet(BaseFederationServlet):
raise last_exception
defer.returnValue((200, {}))
- # Avoid doing remote HS authorization checks which are done by default by
- # BaseFederationServlet.
- def _wrap(self, code):
- return code
-
class OpenIdUserInfo(BaseFederationServlet):
"""
@@ -469,9 +497,11 @@ class OpenIdUserInfo(BaseFederationServlet):
PATH = "/openid/userinfo"
+ REQUIRE_AUTH = False
+
@defer.inlineCallbacks
- def on_GET(self, request):
- token = parse_string(request, "access_token")
+ def on_GET(self, origin, content, query):
+ token = query.get("access_token", [None])[0]
if token is None:
defer.returnValue((401, {
"errcode": "M_MISSING_TOKEN", "error": "Access Token required"
@@ -488,11 +518,6 @@ class OpenIdUserInfo(BaseFederationServlet):
defer.returnValue((200, {"sub": user_id}))
- # Avoid doing remote HS authorization checks which are done by default by
- # BaseFederationServlet.
- def _wrap(self, code):
- return code
-
class PublicRoomList(BaseFederationServlet):
"""
@@ -533,11 +558,26 @@ class PublicRoomList(BaseFederationServlet):
defer.returnValue((200, data))
+class FederationVersionServlet(BaseFederationServlet):
+ PATH = "/version"
+
+ REQUIRE_AUTH = False
+
+ def on_GET(self, origin, content, query):
+ return defer.succeed((200, {
+ "server": {
+ "name": "Synapse",
+ "version": get_version_string(synapse)
+ },
+ }))
+
+
SERVLET_CLASSES = (
FederationSendServlet,
FederationPullServlet,
FederationEventServlet,
FederationStateServlet,
+ FederationStateIdsServlet,
FederationBackfillServlet,
FederationQueryServlet,
FederationMakeJoinServlet,
@@ -555,6 +595,7 @@ SERVLET_CLASSES = (
On3pidBindServlet,
OpenIdUserInfo,
PublicRoomList,
+ FederationVersionServlet,
)
diff --git a/synapse/handlers/__init__.py b/synapse/handlers/__init__.py
index d28e07f0d9..1a50a2ec98 100644
--- a/synapse/handlers/__init__.py
+++ b/synapse/handlers/__init__.py
@@ -31,10 +31,21 @@ from .search import SearchHandler
class Handlers(object):
- """ A collection of all the event handlers.
+ """ Deprecated. A collection of handlers.
- There's no need to lazily create these; we'll just make them all eagerly
- at construction time.
+ At some point most of the classes whose name ended "Handler" were
+ accessed through this class.
+
+ However this makes it painful to unit test the handlers and to run cut
+ down versions of synapse that only use specific handlers because using a
+ single handler required creating all of the handlers. So some of the
+ handlers have been lifted out of the Handlers object and are now accessed
+ directly through the homeserver object itself.
+
+ Any new handlers should follow the new pattern of being accessed through
+ the homeserver object and should not be added to the Handlers object.
+
+ The remaining handlers should be moved out of the handlers object.
"""
def __init__(self, hs):
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index c904c6c500..11081a0cd5 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -13,14 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
+
from twisted.internet import defer
-from synapse.api.errors import LimitExceededError
+import synapse.types
from synapse.api.constants import Membership, EventTypes
-from synapse.types import UserID, Requester
-
-
-import logging
+from synapse.api.errors import LimitExceededError
+from synapse.types import UserID
logger = logging.getLogger(__name__)
@@ -31,11 +31,15 @@ class BaseHandler(object):
Common base class for the event handlers.
Attributes:
- store (synapse.storage.events.StateStore):
+ store (synapse.storage.DataStore):
state_handler (synapse.state.StateHandler):
"""
def __init__(self, hs):
+ """
+ Args:
+ hs (synapse.server.HomeServer):
+ """
self.store = hs.get_datastore()
self.auth = hs.get_auth()
self.notifier = hs.get_notifier()
@@ -120,7 +124,8 @@ class BaseHandler(object):
# and having homeservers have their own users leave keeps more
# of that decision-making and control local to the guest-having
# homeserver.
- requester = Requester(target_user, "", True)
+ requester = synapse.types.create_requester(
+ target_user, is_guest=True)
handler = self.hs.get_handlers().room_member_handler
yield handler.update_membership(
requester,
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index b38f81e999..2e138f328f 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -20,6 +20,7 @@ from synapse.api.constants import LoginType
from synapse.types import UserID
from synapse.api.errors import AuthError, LoginError, Codes, StoreError, SynapseError
from synapse.util.async import run_on_reactor
+from synapse.config.ldap import LDAPMode
from twisted.web.client import PartialDownloadError
@@ -28,6 +29,12 @@ import bcrypt
import pymacaroons
import simplejson
+try:
+ import ldap3
+except ImportError:
+ ldap3 = None
+ pass
+
import synapse.util.stringutils as stringutils
@@ -38,6 +45,10 @@ class AuthHandler(BaseHandler):
SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000
def __init__(self, hs):
+ """
+ Args:
+ hs (synapse.server.HomeServer):
+ """
super(AuthHandler, self).__init__(hs)
self.checkers = {
LoginType.PASSWORD: self._check_password_auth,
@@ -50,19 +61,23 @@ class AuthHandler(BaseHandler):
self.INVALID_TOKEN_HTTP_STATUS = 401
self.ldap_enabled = hs.config.ldap_enabled
- self.ldap_server = hs.config.ldap_server
- self.ldap_port = hs.config.ldap_port
- self.ldap_tls = hs.config.ldap_tls
- self.ldap_search_base = hs.config.ldap_search_base
- self.ldap_search_property = hs.config.ldap_search_property
- self.ldap_email_property = hs.config.ldap_email_property
- self.ldap_full_name_property = hs.config.ldap_full_name_property
-
- if self.ldap_enabled is True:
- import ldap
- logger.info("Import ldap version: %s", ldap.__version__)
+ if self.ldap_enabled:
+ if not ldap3:
+ raise RuntimeError(
+ 'Missing ldap3 library. This is required for LDAP Authentication.'
+ )
+ self.ldap_mode = hs.config.ldap_mode
+ self.ldap_uri = hs.config.ldap_uri
+ self.ldap_start_tls = hs.config.ldap_start_tls
+ self.ldap_base = hs.config.ldap_base
+ self.ldap_filter = hs.config.ldap_filter
+ self.ldap_attributes = hs.config.ldap_attributes
+ if self.ldap_mode == LDAPMode.SEARCH:
+ self.ldap_bind_dn = hs.config.ldap_bind_dn
+ self.ldap_bind_password = hs.config.ldap_bind_password
self.hs = hs # FIXME better possibility to access registrationHandler later?
+ self.device_handler = hs.get_device_handler()
@defer.inlineCallbacks
def check_auth(self, flows, clientdict, clientip):
@@ -220,7 +235,6 @@ class AuthHandler(BaseHandler):
sess = self._get_session_info(session_id)
return sess.setdefault('serverdict', {}).get(key, default)
- @defer.inlineCallbacks
def _check_password_auth(self, authdict, _):
if "user" not in authdict or "password" not in authdict:
raise LoginError(400, "", Codes.MISSING_PARAM)
@@ -230,11 +244,7 @@ class AuthHandler(BaseHandler):
if not user_id.startswith('@'):
user_id = UserID.create(user_id, self.hs.hostname).to_string()
- if not (yield self._check_password(user_id, password)):
- logger.warn("Failed password login for user %s", user_id)
- raise LoginError(403, "", errcode=Codes.FORBIDDEN)
-
- defer.returnValue(user_id)
+ return self._check_password(user_id, password)
@defer.inlineCallbacks
def _check_recaptcha(self, authdict, clientip):
@@ -270,8 +280,17 @@ class AuthHandler(BaseHandler):
data = pde.response
resp_body = simplejson.loads(data)
- if 'success' in resp_body and resp_body['success']:
- defer.returnValue(True)
+ if 'success' in resp_body:
+ # Note that we do NOT check the hostname here: we explicitly
+ # intend the CAPTCHA to be presented by whatever client the
+ # user is using, we just care that they have completed a CAPTCHA.
+ logger.info(
+ "%s reCAPTCHA from hostname %s",
+ "Successful" if resp_body['success'] else "Failed",
+ resp_body.get('hostname')
+ )
+ if resp_body['success']:
+ defer.returnValue(True)
raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
@defer.inlineCallbacks
@@ -338,67 +357,84 @@ class AuthHandler(BaseHandler):
return self.sessions[session_id]
- @defer.inlineCallbacks
- def login_with_password(self, user_id, password):
+ def validate_password_login(self, user_id, password):
"""
Authenticates the user with their username and password.
Used only by the v1 login API.
Args:
- user_id (str): User ID
+ user_id (str): complete @user:id
password (str): Password
Returns:
- A tuple of:
- The user's ID.
- The access token for the user's session.
- The refresh token for the user's session.
+ defer.Deferred: (str) canonical user id
Raises:
- StoreError if there was a problem storing the token.
+ StoreError if there was a problem accessing the database
LoginError if there was an authentication problem.
"""
-
- if not (yield self._check_password(user_id, password)):
- logger.warn("Failed password login for user %s", user_id)
- raise LoginError(403, "", errcode=Codes.FORBIDDEN)
-
- logger.info("Logging in user %s", user_id)
- access_token = yield self.issue_access_token(user_id)
- refresh_token = yield self.issue_refresh_token(user_id)
- defer.returnValue((user_id, access_token, refresh_token))
+ return self._check_password(user_id, password)
@defer.inlineCallbacks
- def get_login_tuple_for_user_id(self, user_id):
+ def get_login_tuple_for_user_id(self, user_id, device_id=None,
+ initial_display_name=None):
"""
Gets login tuple for the user with the given user ID.
+
+ Creates a new access/refresh token for the user.
+
The user is assumed to have been authenticated by some other
- machanism (e.g. CAS)
+ machanism (e.g. CAS), and the user_id converted to the canonical case.
+
+ The device will be recorded in the table if it is not there already.
Args:
- user_id (str): User ID
+ user_id (str): canonical User ID
+ device_id (str|None): the device ID to associate with the tokens.
+ None to leave the tokens unassociated with a device (deprecated:
+ we should always have a device ID)
+ initial_display_name (str): display name to associate with the
+ device if it needs re-registering
Returns:
A tuple of:
- The user's ID.
The access token for the user's session.
The refresh token for the user's session.
Raises:
StoreError if there was a problem storing the token.
LoginError if there was an authentication problem.
"""
- user_id, ignored = yield self._find_user_id_and_pwd_hash(user_id)
+ logger.info("Logging in user %s on device %s", user_id, device_id)
+ access_token = yield self.issue_access_token(user_id, device_id)
+ refresh_token = yield self.issue_refresh_token(user_id, device_id)
+
+ # the device *should* have been registered before we got here; however,
+ # it's possible we raced against a DELETE operation. The thing we
+ # really don't want is active access_tokens without a record of the
+ # device, so we double-check it here.
+ if device_id is not None:
+ yield self.device_handler.check_device_registered(
+ user_id, device_id, initial_display_name
+ )
- logger.info("Logging in user %s", user_id)
- access_token = yield self.issue_access_token(user_id)
- refresh_token = yield self.issue_refresh_token(user_id)
- defer.returnValue((user_id, access_token, refresh_token))
+ defer.returnValue((access_token, refresh_token))
@defer.inlineCallbacks
- def does_user_exist(self, user_id):
+ def check_user_exists(self, user_id):
+ """
+ Checks to see if a user with the given id exists. Will check case
+ insensitively, but return None if there are multiple inexact matches.
+
+ Args:
+ (str) user_id: complete @user:id
+
+ Returns:
+ defer.Deferred: (str) canonical_user_id, or None if zero or
+ multiple matches
+ """
try:
- yield self._find_user_id_and_pwd_hash(user_id)
- defer.returnValue(True)
+ res = yield self._find_user_id_and_pwd_hash(user_id)
+ defer.returnValue(res[0])
except LoginError:
- defer.returnValue(False)
+ defer.returnValue(None)
@defer.inlineCallbacks
def _find_user_id_and_pwd_hash(self, user_id):
@@ -428,84 +464,232 @@ class AuthHandler(BaseHandler):
@defer.inlineCallbacks
def _check_password(self, user_id, password):
- """
+ """Authenticate a user against the LDAP and local databases.
+
+ user_id is checked case insensitively against the local database, but
+ will throw if there are multiple inexact matches.
+
+ Args:
+ user_id (str): complete @user:id
Returns:
- True if the user_id successfully authenticated
+ (str) the canonical_user_id
+ Raises:
+ LoginError if the password was incorrect
"""
valid_ldap = yield self._check_ldap_password(user_id, password)
if valid_ldap:
- defer.returnValue(True)
+ defer.returnValue(user_id)
- valid_local_password = yield self._check_local_password(user_id, password)
- if valid_local_password:
- defer.returnValue(True)
-
- defer.returnValue(False)
+ result = yield self._check_local_password(user_id, password)
+ defer.returnValue(result)
@defer.inlineCallbacks
def _check_local_password(self, user_id, password):
- try:
- user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id)
- defer.returnValue(self.validate_hash(password, password_hash))
- except LoginError:
- defer.returnValue(False)
+ """Authenticate a user against the local password database.
+
+ user_id is checked case insensitively, but will throw if there are
+ multiple inexact matches.
+
+ Args:
+ user_id (str): complete @user:id
+ Returns:
+ (str) the canonical_user_id
+ Raises:
+ LoginError if the password was incorrect
+ """
+ user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id)
+ result = self.validate_hash(password, password_hash)
+ if not result:
+ logger.warn("Failed password login for user %s", user_id)
+ raise LoginError(403, "", errcode=Codes.FORBIDDEN)
+ defer.returnValue(user_id)
@defer.inlineCallbacks
def _check_ldap_password(self, user_id, password):
- if not self.ldap_enabled:
- logger.debug("LDAP not configured")
+ """ Attempt to authenticate a user against an LDAP Server
+ and register an account if none exists.
+
+ Returns:
+ True if authentication against LDAP was successful
+ """
+
+ if not ldap3 or not self.ldap_enabled:
defer.returnValue(False)
- import ldap
+ if self.ldap_mode not in LDAPMode.LIST:
+ raise RuntimeError(
+ 'Invalid ldap mode specified: {mode}'.format(
+ mode=self.ldap_mode
+ )
+ )
- logger.info("Authenticating %s with LDAP" % user_id)
try:
- ldap_url = "%s:%s" % (self.ldap_server, self.ldap_port)
- logger.debug("Connecting LDAP server at %s" % ldap_url)
- l = ldap.initialize(ldap_url)
- if self.ldap_tls:
- logger.debug("Initiating TLS")
- self._connection.start_tls_s()
-
- local_name = UserID.from_string(user_id).localpart
-
- dn = "%s=%s, %s" % (
- self.ldap_search_property,
- local_name,
- self.ldap_search_base)
- logger.debug("DN for LDAP authentication: %s" % dn)
-
- l.simple_bind_s(dn.encode('utf-8'), password.encode('utf-8'))
-
- if not (yield self.does_user_exist(user_id)):
- handler = self.hs.get_handlers().registration_handler
- user_id, access_token = (
- yield handler.register(localpart=local_name)
+ server = ldap3.Server(self.ldap_uri)
+ logger.debug(
+ "Attempting ldap connection with %s",
+ self.ldap_uri
+ )
+
+ localpart = UserID.from_string(user_id).localpart
+ if self.ldap_mode == LDAPMode.SIMPLE:
+ # bind with the the local users ldap credentials
+ bind_dn = "{prop}={value},{base}".format(
+ prop=self.ldap_attributes['uid'],
+ value=localpart,
+ base=self.ldap_base
+ )
+ conn = ldap3.Connection(server, bind_dn, password)
+ logger.debug(
+ "Established ldap connection in simple mode: %s",
+ conn
)
+ if self.ldap_start_tls:
+ conn.start_tls()
+ logger.debug(
+ "Upgraded ldap connection in simple mode through StartTLS: %s",
+ conn
+ )
+
+ conn.bind()
+
+ elif self.ldap_mode == LDAPMode.SEARCH:
+ # connect with preconfigured credentials and search for local user
+ conn = ldap3.Connection(
+ server,
+ self.ldap_bind_dn,
+ self.ldap_bind_password
+ )
+ logger.debug(
+ "Established ldap connection in search mode: %s",
+ conn
+ )
+
+ if self.ldap_start_tls:
+ conn.start_tls()
+ logger.debug(
+ "Upgraded ldap connection in search mode through StartTLS: %s",
+ conn
+ )
+
+ conn.bind()
+
+ # find matching dn
+ query = "({prop}={value})".format(
+ prop=self.ldap_attributes['uid'],
+ value=localpart
+ )
+ if self.ldap_filter:
+ query = "(&{query}{filter})".format(
+ query=query,
+ filter=self.ldap_filter
+ )
+ logger.debug("ldap search filter: %s", query)
+ result = conn.search(self.ldap_base, query)
+
+ if result and len(conn.response) == 1:
+ # found exactly one result
+ user_dn = conn.response[0]['dn']
+ logger.debug('ldap search found dn: %s', user_dn)
+
+ # unbind and reconnect, rebind with found dn
+ conn.unbind()
+ conn = ldap3.Connection(
+ server,
+ user_dn,
+ password,
+ auto_bind=True
+ )
+ else:
+ # found 0 or > 1 results, abort!
+ logger.warn(
+ "ldap search returned unexpected (%d!=1) amount of results",
+ len(conn.response)
+ )
+ defer.returnValue(False)
+
+ logger.info(
+ "User authenticated against ldap server: %s",
+ conn
+ )
+
+ # check for existing account, if none exists, create one
+ if not (yield self.check_user_exists(user_id)):
+ # query user metadata for account creation
+ query = "({prop}={value})".format(
+ prop=self.ldap_attributes['uid'],
+ value=localpart
+ )
+
+ if self.ldap_mode == LDAPMode.SEARCH and self.ldap_filter:
+ query = "(&{filter}{user_filter})".format(
+ filter=query,
+ user_filter=self.ldap_filter
+ )
+ logger.debug("ldap registration filter: %s", query)
+
+ result = conn.search(
+ search_base=self.ldap_base,
+ search_filter=query,
+ attributes=[
+ self.ldap_attributes['name'],
+ self.ldap_attributes['mail']
+ ]
+ )
+
+ if len(conn.response) == 1:
+ attrs = conn.response[0]['attributes']
+ mail = attrs[self.ldap_attributes['mail']][0]
+ name = attrs[self.ldap_attributes['name']][0]
+
+ # create account
+ registration_handler = self.hs.get_handlers().registration_handler
+ user_id, access_token = (
+ yield registration_handler.register(localpart=localpart)
+ )
+
+ # TODO: bind email, set displayname with data from ldap directory
+
+ logger.info(
+ "ldap registration successful: %d: %s (%s, %)",
+ user_id,
+ localpart,
+ name,
+ mail
+ )
+ else:
+ logger.warn(
+ "ldap registration failed: unexpected (%d!=1) amount of results",
+ len(result)
+ )
+ defer.returnValue(False)
+
defer.returnValue(True)
- except ldap.LDAPError, e:
- logger.warn("LDAP error: %s", e)
+ except ldap3.core.exceptions.LDAPException as e:
+ logger.warn("Error during ldap authentication: %s", e)
defer.returnValue(False)
@defer.inlineCallbacks
- def issue_access_token(self, user_id):
+ def issue_access_token(self, user_id, device_id=None):
access_token = self.generate_access_token(user_id)
- yield self.store.add_access_token_to_user(user_id, access_token)
+ yield self.store.add_access_token_to_user(user_id, access_token,
+ device_id)
defer.returnValue(access_token)
@defer.inlineCallbacks
- def issue_refresh_token(self, user_id):
+ def issue_refresh_token(self, user_id, device_id=None):
refresh_token = self.generate_refresh_token(user_id)
- yield self.store.add_refresh_token_to_user(user_id, refresh_token)
+ yield self.store.add_refresh_token_to_user(user_id, refresh_token,
+ device_id)
defer.returnValue(refresh_token)
- def generate_access_token(self, user_id, extra_caveats=None):
+ def generate_access_token(self, user_id, extra_caveats=None,
+ duration_in_ms=(60 * 60 * 1000)):
extra_caveats = extra_caveats or []
macaroon = self._generate_base_macaroon(user_id)
macaroon.add_first_party_caveat("type = access")
now = self.hs.get_clock().time_msec()
- expiry = now + (60 * 60 * 1000)
+ expiry = now + duration_in_ms
macaroon.add_first_party_caveat("time < %d" % (expiry,))
for caveat in extra_caveats:
macaroon.add_first_party_caveat(caveat)
@@ -613,7 +797,8 @@ class AuthHandler(BaseHandler):
Returns:
Hashed password (str).
"""
- return bcrypt.hashpw(password, bcrypt.gensalt(self.bcrypt_rounds))
+ return bcrypt.hashpw(password + self.hs.config.password_pepper,
+ bcrypt.gensalt(self.bcrypt_rounds))
def validate_hash(self, password, stored_hash):
"""Validates that self.hash(password) == stored_hash.
@@ -626,6 +811,7 @@ class AuthHandler(BaseHandler):
Whether self.hash(password) == stored_hash (bool).
"""
if stored_hash:
- return bcrypt.hashpw(password, stored_hash.encode('utf-8')) == stored_hash
+ return bcrypt.hashpw(password + self.hs.config.password_pepper,
+ stored_hash.encode('utf-8')) == stored_hash
else:
return False
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
new file mode 100644
index 0000000000..8d630c6b1a
--- /dev/null
+++ b/synapse/handlers/device.py
@@ -0,0 +1,181 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.api import errors
+from synapse.util import stringutils
+from twisted.internet import defer
+from ._base import BaseHandler
+
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class DeviceHandler(BaseHandler):
+ def __init__(self, hs):
+ super(DeviceHandler, self).__init__(hs)
+
+ @defer.inlineCallbacks
+ def check_device_registered(self, user_id, device_id,
+ initial_device_display_name=None):
+ """
+ If the given device has not been registered, register it with the
+ supplied display name.
+
+ If no device_id is supplied, we make one up.
+
+ Args:
+ user_id (str): @user:id
+ device_id (str | None): device id supplied by client
+ initial_device_display_name (str | None): device display name from
+ client
+ Returns:
+ str: device id (generated if none was supplied)
+ """
+ if device_id is not None:
+ yield self.store.store_device(
+ user_id=user_id,
+ device_id=device_id,
+ initial_device_display_name=initial_device_display_name,
+ ignore_if_known=True,
+ )
+ defer.returnValue(device_id)
+
+ # if the device id is not specified, we'll autogen one, but loop a few
+ # times in case of a clash.
+ attempts = 0
+ while attempts < 5:
+ try:
+ device_id = stringutils.random_string_with_symbols(16)
+ yield self.store.store_device(
+ user_id=user_id,
+ device_id=device_id,
+ initial_device_display_name=initial_device_display_name,
+ ignore_if_known=False,
+ )
+ defer.returnValue(device_id)
+ except errors.StoreError:
+ attempts += 1
+
+ raise errors.StoreError(500, "Couldn't generate a device ID.")
+
+ @defer.inlineCallbacks
+ def get_devices_by_user(self, user_id):
+ """
+ Retrieve the given user's devices
+
+ Args:
+ user_id (str):
+ Returns:
+ defer.Deferred: list[dict[str, X]]: info on each device
+ """
+
+ device_map = yield self.store.get_devices_by_user(user_id)
+
+ ips = yield self.store.get_last_client_ip_by_device(
+ devices=((user_id, device_id) for device_id in device_map.keys())
+ )
+
+ devices = device_map.values()
+ for device in devices:
+ _update_device_from_client_ips(device, ips)
+
+ defer.returnValue(devices)
+
+ @defer.inlineCallbacks
+ def get_device(self, user_id, device_id):
+ """ Retrieve the given device
+
+ Args:
+ user_id (str):
+ device_id (str):
+
+ Returns:
+ defer.Deferred: dict[str, X]: info on the device
+ Raises:
+ errors.NotFoundError: if the device was not found
+ """
+ try:
+ device = yield self.store.get_device(user_id, device_id)
+ except errors.StoreError:
+ raise errors.NotFoundError
+ ips = yield self.store.get_last_client_ip_by_device(
+ devices=((user_id, device_id),)
+ )
+ _update_device_from_client_ips(device, ips)
+ defer.returnValue(device)
+
+ @defer.inlineCallbacks
+ def delete_device(self, user_id, device_id):
+ """ Delete the given device
+
+ Args:
+ user_id (str):
+ device_id (str):
+
+ Returns:
+ defer.Deferred:
+ """
+
+ try:
+ yield self.store.delete_device(user_id, device_id)
+ except errors.StoreError, e:
+ if e.code == 404:
+ # no match
+ pass
+ else:
+ raise
+
+ yield self.store.user_delete_access_tokens(
+ user_id, device_id=device_id,
+ delete_refresh_tokens=True,
+ )
+
+ yield self.store.delete_e2e_keys_by_device(
+ user_id=user_id, device_id=device_id
+ )
+
+ @defer.inlineCallbacks
+ def update_device(self, user_id, device_id, content):
+ """ Update the given device
+
+ Args:
+ user_id (str):
+ device_id (str):
+ content (dict): body of update request
+
+ Returns:
+ defer.Deferred:
+ """
+
+ try:
+ yield self.store.update_device(
+ user_id,
+ device_id,
+ new_display_name=content.get("display_name")
+ )
+ except errors.StoreError, e:
+ if e.code == 404:
+ raise errors.NotFoundError()
+ else:
+ raise
+
+
+def _update_device_from_client_ips(device, client_ips):
+ ip = client_ips.get((device["user_id"], device["device_id"]), {})
+ device.update({
+ "last_seen_ts": ip.get("last_seen"),
+ "last_seen_ip": ip.get("ip"),
+ })
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
new file mode 100644
index 0000000000..2c7bfd91ed
--- /dev/null
+++ b/synapse/handlers/e2e_keys.py
@@ -0,0 +1,139 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import collections
+import json
+import logging
+
+from twisted.internet import defer
+
+from synapse.api import errors
+import synapse.types
+
+logger = logging.getLogger(__name__)
+
+
+class E2eKeysHandler(object):
+ def __init__(self, hs):
+ self.store = hs.get_datastore()
+ self.federation = hs.get_replication_layer()
+ self.is_mine_id = hs.is_mine_id
+ self.server_name = hs.hostname
+
+ # doesn't really work as part of the generic query API, because the
+ # query request requires an object POST, but we abuse the
+ # "query handler" interface.
+ self.federation.register_query_handler(
+ "client_keys", self.on_federation_query_client_keys
+ )
+
+ @defer.inlineCallbacks
+ def query_devices(self, query_body):
+ """ Handle a device key query from a client
+
+ {
+ "device_keys": {
+ "<user_id>": ["<device_id>"]
+ }
+ }
+ ->
+ {
+ "device_keys": {
+ "<user_id>": {
+ "<device_id>": {
+ ...
+ }
+ }
+ }
+ }
+ """
+ device_keys_query = query_body.get("device_keys", {})
+
+ # separate users by domain.
+ # make a map from domain to user_id to device_ids
+ queries_by_domain = collections.defaultdict(dict)
+ for user_id, device_ids in device_keys_query.items():
+ user = synapse.types.UserID.from_string(user_id)
+ queries_by_domain[user.domain][user_id] = device_ids
+
+ # do the queries
+ # TODO: do these in parallel
+ results = {}
+ for destination, destination_query in queries_by_domain.items():
+ if destination == self.server_name:
+ res = yield self.query_local_devices(destination_query)
+ else:
+ res = yield self.federation.query_client_keys(
+ destination, {"device_keys": destination_query}
+ )
+ res = res["device_keys"]
+ for user_id, keys in res.items():
+ if user_id in destination_query:
+ results[user_id] = keys
+
+ defer.returnValue((200, {"device_keys": results}))
+
+ @defer.inlineCallbacks
+ def query_local_devices(self, query):
+ """Get E2E device keys for local users
+
+ Args:
+ query (dict[string, list[string]|None): map from user_id to a list
+ of devices to query (None for all devices)
+
+ Returns:
+ defer.Deferred: (resolves to dict[string, dict[string, dict]]):
+ map from user_id -> device_id -> device details
+ """
+ local_query = []
+
+ result_dict = {}
+ for user_id, device_ids in query.items():
+ if not self.is_mine_id(user_id):
+ logger.warning("Request for keys for non-local user %s",
+ user_id)
+ raise errors.SynapseError(400, "Not a user here")
+
+ if not device_ids:
+ local_query.append((user_id, None))
+ else:
+ for device_id in device_ids:
+ local_query.append((user_id, device_id))
+
+ # make sure that each queried user appears in the result dict
+ result_dict[user_id] = {}
+
+ results = yield self.store.get_e2e_device_keys(local_query)
+
+ # Build the result structure, un-jsonify the results, and add the
+ # "unsigned" section
+ for user_id, device_keys in results.items():
+ for device_id, device_info in device_keys.items():
+ r = json.loads(device_info["key_json"])
+ r["unsigned"] = {}
+ display_name = device_info["device_display_name"]
+ if display_name is not None:
+ r["unsigned"]["device_display_name"] = display_name
+ result_dict[user_id][device_id] = r
+
+ defer.returnValue(result_dict)
+
+ @defer.inlineCallbacks
+ def on_federation_query_client_keys(self, query_body):
+ """ Handle a device key query from a federated server
+ """
+ device_keys_query = query_body.get("device_keys", {})
+ res = yield self.query_local_devices(device_keys_query)
+ defer.returnValue({"device_keys": res})
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 6c0bc7eafa..618cb53629 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -124,7 +124,7 @@ class FederationHandler(BaseHandler):
try:
event_stream_id, max_stream_id = yield self._persist_auth_tree(
- auth_chain, state, event
+ origin, auth_chain, state, event
)
except AuthError as e:
raise FederationError(
@@ -335,31 +335,58 @@ class FederationHandler(BaseHandler):
state_events.update({s.event_id: s for s in state})
events_to_state[e_id] = state
- seen_events = yield self.store.have_events(
- set(auth_events.keys()) | set(state_events.keys())
- )
-
- all_events = events + state_events.values() + auth_events.values()
required_auth = set(
- a_id for event in all_events for a_id, _ in event.auth_events
+ a_id
+ for event in events + state_events.values() + auth_events.values()
+ for a_id, _ in event.auth_events
)
-
+ auth_events.update({
+ e_id: event_map[e_id] for e_id in required_auth if e_id in event_map
+ })
missing_auth = required_auth - set(auth_events)
- if missing_auth:
+ failed_to_fetch = set()
+
+ # Try and fetch any missing auth events from both DB and remote servers.
+ # We repeatedly do this until we stop finding new auth events.
+ while missing_auth - failed_to_fetch:
logger.info("Missing auth for backfill: %r", missing_auth)
- results = yield defer.gatherResults(
- [
- self.replication_layer.get_pdu(
- [dest],
- event_id,
- outlier=True,
- timeout=10000,
- )
- for event_id in missing_auth
- ],
- consumeErrors=True
- ).addErrback(unwrapFirstError)
- auth_events.update({a.event_id: a for a in results})
+ ret_events = yield self.store.get_events(missing_auth - failed_to_fetch)
+ auth_events.update(ret_events)
+
+ required_auth.update(
+ a_id for event in ret_events.values() for a_id, _ in event.auth_events
+ )
+ missing_auth = required_auth - set(auth_events)
+
+ if missing_auth - failed_to_fetch:
+ logger.info(
+ "Fetching missing auth for backfill: %r",
+ missing_auth - failed_to_fetch
+ )
+
+ results = yield defer.gatherResults(
+ [
+ self.replication_layer.get_pdu(
+ [dest],
+ event_id,
+ outlier=True,
+ timeout=10000,
+ )
+ for event_id in missing_auth - failed_to_fetch
+ ],
+ consumeErrors=True
+ ).addErrback(unwrapFirstError)
+ auth_events.update({a.event_id: a for a in results})
+ required_auth.update(
+ a_id for event in results for a_id, _ in event.auth_events
+ )
+ missing_auth = required_auth - set(auth_events)
+
+ failed_to_fetch = missing_auth - set(auth_events)
+
+ seen_events = yield self.store.have_events(
+ set(auth_events.keys()) | set(state_events.keys())
+ )
ev_infos = []
for a in auth_events.values():
@@ -372,6 +399,7 @@ class FederationHandler(BaseHandler):
(auth_events[a_id].type, auth_events[a_id].state_key):
auth_events[a_id]
for a_id, _ in a.auth_events
+ if a_id in auth_events
}
})
@@ -383,6 +411,7 @@ class FederationHandler(BaseHandler):
(auth_events[a_id].type, auth_events[a_id].state_key):
auth_events[a_id]
for a_id, _ in event_map[e_id].auth_events
+ if a_id in auth_events
}
})
@@ -637,7 +666,7 @@ class FederationHandler(BaseHandler):
pass
event_stream_id, max_stream_id = yield self._persist_auth_tree(
- auth_chain, state, event
+ origin, auth_chain, state, event
)
with PreserveLoggingContext():
@@ -688,7 +717,9 @@ class FederationHandler(BaseHandler):
logger.warn("Failed to create join %r because %s", event, e)
raise e
- self.auth.check(event, auth_events=context.current_state)
+ # The remote hasn't signed it yet, obviously. We'll do the full checks
+ # when we get the event back in `on_send_join_request`
+ self.auth.check(event, auth_events=context.current_state, do_sig_check=False)
defer.returnValue(event)
@@ -918,7 +949,9 @@ class FederationHandler(BaseHandler):
)
try:
- self.auth.check(event, auth_events=context.current_state)
+ # The remote hasn't signed it yet, obviously. We'll do the full checks
+ # when we get the event back in `on_send_leave_request`
+ self.auth.check(event, auth_events=context.current_state, do_sig_check=False)
except AuthError as e:
logger.warn("Failed to create new leave %r because %s", event, e)
raise e
@@ -987,14 +1020,9 @@ class FederationHandler(BaseHandler):
defer.returnValue(None)
@defer.inlineCallbacks
- def get_state_for_pdu(self, origin, room_id, event_id, do_auth=True):
+ def get_state_for_pdu(self, room_id, event_id):
yield run_on_reactor()
- if do_auth:
- in_room = yield self.auth.check_host_in_room(room_id, origin)
- if not in_room:
- raise AuthError(403, "Host not in room.")
-
state_groups = yield self.store.get_state_groups(
room_id, [event_id]
)
@@ -1114,11 +1142,12 @@ class FederationHandler(BaseHandler):
backfilled=backfilled,
)
- # this intentionally does not yield: we don't care about the result
- # and don't need to wait for it.
- preserve_fn(self.hs.get_pusherpool().on_new_notifications)(
- event_stream_id, max_stream_id
- )
+ if not backfilled:
+ # this intentionally does not yield: we don't care about the result
+ # and don't need to wait for it.
+ preserve_fn(self.hs.get_pusherpool().on_new_notifications)(
+ event_stream_id, max_stream_id
+ )
defer.returnValue((context, event_stream_id, max_stream_id))
@@ -1150,11 +1179,19 @@ class FederationHandler(BaseHandler):
)
@defer.inlineCallbacks
- def _persist_auth_tree(self, auth_events, state, event):
+ def _persist_auth_tree(self, origin, auth_events, state, event):
"""Checks the auth chain is valid (and passes auth checks) for the
state and event. Then persists the auth chain and state atomically.
Persists the event seperately.
+ Will attempt to fetch missing auth events.
+
+ Args:
+ origin (str): Where the events came from
+ auth_events (list)
+ state (list)
+ event (Event)
+
Returns:
2-tuple of (event_stream_id, max_stream_id) from the persist_event
call for `event`
@@ -1167,7 +1204,7 @@ class FederationHandler(BaseHandler):
event_map = {
e.event_id: e
- for e in auth_events
+ for e in itertools.chain(auth_events, state, [event])
}
create_event = None
@@ -1176,10 +1213,29 @@ class FederationHandler(BaseHandler):
create_event = e
break
+ missing_auth_events = set()
+ for e in itertools.chain(auth_events, state, [event]):
+ for e_id, _ in e.auth_events:
+ if e_id not in event_map:
+ missing_auth_events.add(e_id)
+
+ for e_id in missing_auth_events:
+ m_ev = yield self.replication_layer.get_pdu(
+ [origin],
+ e_id,
+ outlier=True,
+ timeout=10000,
+ )
+ if m_ev and m_ev.event_id == e_id:
+ event_map[e_id] = m_ev
+ else:
+ logger.info("Failed to find auth event %r", e_id)
+
for e in itertools.chain(auth_events, state, [event]):
auth_for_e = {
(event_map[e_id].type, event_map[e_id].state_key): event_map[e_id]
for e_id, _ in e.auth_events
+ if e_id in event_map
}
if create_event:
auth_for_e[(EventTypes.Create, "")] = create_event
@@ -1413,7 +1469,7 @@ class FederationHandler(BaseHandler):
local_view = dict(auth_events)
remote_view = dict(auth_events)
remote_view.update({
- (d.type, d.state_key): d for d in different_events
+ (d.type, d.state_key): d for d in different_events if d
})
new_state, prev_state = self.state_handler.resolve_events(
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index 656ce124f9..559e5d5a71 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -21,7 +21,7 @@ from synapse.api.errors import (
)
from ._base import BaseHandler
from synapse.util.async import run_on_reactor
-from synapse.api.errors import SynapseError
+from synapse.api.errors import SynapseError, Codes
import json
import logging
@@ -41,6 +41,20 @@ class IdentityHandler(BaseHandler):
hs.config.use_insecure_ssl_client_just_for_testing_do_not_use
)
+ def _should_trust_id_server(self, id_server):
+ if id_server not in self.trusted_id_servers:
+ if self.trust_any_id_server_just_for_testing_do_not_use:
+ logger.warn(
+ "Trusting untrustworthy ID server %r even though it isn't"
+ " in the trusted id list for testing because"
+ " 'use_insecure_ssl_client_just_for_testing_do_not_use'"
+ " is set in the config",
+ id_server,
+ )
+ else:
+ return False
+ return True
+
@defer.inlineCallbacks
def threepid_from_creds(self, creds):
yield run_on_reactor()
@@ -59,19 +73,12 @@ class IdentityHandler(BaseHandler):
else:
raise SynapseError(400, "No client_secret in creds")
- if id_server not in self.trusted_id_servers:
- if self.trust_any_id_server_just_for_testing_do_not_use:
- logger.warn(
- "Trusting untrustworthy ID server %r even though it isn't"
- " in the trusted id list for testing because"
- " 'use_insecure_ssl_client_just_for_testing_do_not_use'"
- " is set in the config",
- id_server,
- )
- else:
- logger.warn('%s is not a trusted ID server: rejecting 3pid ' +
- 'credentials', id_server)
- defer.returnValue(None)
+ if not self._should_trust_id_server(id_server):
+ logger.warn(
+ '%s is not a trusted ID server: rejecting 3pid ' +
+ 'credentials', id_server
+ )
+ defer.returnValue(None)
data = {}
try:
@@ -129,6 +136,12 @@ class IdentityHandler(BaseHandler):
def requestEmailToken(self, id_server, email, client_secret, send_attempt, **kwargs):
yield run_on_reactor()
+ if not self._should_trust_id_server(id_server):
+ raise SynapseError(
+ 400, "Untrusted ID server '%s'" % id_server,
+ Codes.SERVER_NOT_TRUSTED
+ )
+
params = {
'email': email,
'client_secret': client_secret,
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 15caf1950a..dc76d34a52 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -26,7 +26,7 @@ from synapse.types import (
UserID, RoomAlias, RoomStreamToken, StreamToken, get_domain_from_id
)
from synapse.util import unwrapFirstError
-from synapse.util.async import concurrently_execute, run_on_reactor
+from synapse.util.async import concurrently_execute, run_on_reactor, ReadWriteLock
from synapse.util.caches.snapshot_cache import SnapshotCache
from synapse.util.logcontext import preserve_fn
from synapse.visibility import filter_events_for_client
@@ -50,9 +50,23 @@ class MessageHandler(BaseHandler):
self.validator = EventValidator()
self.snapshot_cache = SnapshotCache()
+ self.pagination_lock = ReadWriteLock()
+
+ @defer.inlineCallbacks
+ def purge_history(self, room_id, event_id):
+ event = yield self.store.get_event(event_id)
+
+ if event.room_id != room_id:
+ raise SynapseError(400, "Event is for wrong room.")
+
+ depth = event.depth
+
+ with (yield self.pagination_lock.write(room_id)):
+ yield self.store.delete_old_state(room_id, depth)
+
@defer.inlineCallbacks
def get_messages(self, requester, room_id=None, pagin_config=None,
- as_client_event=True):
+ as_client_event=True, event_filter=None):
"""Get messages in a room.
Args:
@@ -61,11 +75,11 @@ class MessageHandler(BaseHandler):
pagin_config (synapse.api.streams.PaginationConfig): The pagination
config rules to apply, if any.
as_client_event (bool): True to get events in client-server format.
+ event_filter (Filter): Filter to apply to results or None
Returns:
dict: Pagination API results
"""
user_id = requester.user.to_string()
- data_source = self.hs.get_event_sources().sources["room"]
if pagin_config.from_token:
room_token = pagin_config.from_token.room_key
@@ -85,42 +99,48 @@ class MessageHandler(BaseHandler):
source_config = pagin_config.get_source_config("room")
- membership, member_event_id = yield self._check_in_room_or_world_readable(
- room_id, user_id
- )
+ with (yield self.pagination_lock.read(room_id)):
+ membership, member_event_id = yield self._check_in_room_or_world_readable(
+ room_id, user_id
+ )
- if source_config.direction == 'b':
- # if we're going backwards, we might need to backfill. This
- # requires that we have a topo token.
- if room_token.topological:
- max_topo = room_token.topological
- else:
- max_topo = yield self.store.get_max_topological_token_for_stream_and_room(
- room_id, room_token.stream
- )
+ if source_config.direction == 'b':
+ # if we're going backwards, we might need to backfill. This
+ # requires that we have a topo token.
+ if room_token.topological:
+ max_topo = room_token.topological
+ else:
+ max_topo = yield self.store.get_max_topological_token(
+ room_id, room_token.stream
+ )
+
+ if membership == Membership.LEAVE:
+ # If they have left the room then clamp the token to be before
+ # they left the room, to save the effort of loading from the
+ # database.
+ leave_token = yield self.store.get_topological_token_for_event(
+ member_event_id
+ )
+ leave_token = RoomStreamToken.parse(leave_token)
+ if leave_token.topological < max_topo:
+ source_config.from_key = str(leave_token)
- if membership == Membership.LEAVE:
- # If they have left the room then clamp the token to be before
- # they left the room, to save the effort of loading from the
- # database.
- leave_token = yield self.store.get_topological_token_for_event(
- member_event_id
+ yield self.hs.get_handlers().federation_handler.maybe_backfill(
+ room_id, max_topo
)
- leave_token = RoomStreamToken.parse(leave_token)
- if leave_token.topological < max_topo:
- source_config.from_key = str(leave_token)
- yield self.hs.get_handlers().federation_handler.maybe_backfill(
- room_id, max_topo
+ events, next_key = yield self.store.paginate_room_events(
+ room_id=room_id,
+ from_key=source_config.from_key,
+ to_key=source_config.to_key,
+ direction=source_config.direction,
+ limit=source_config.limit,
+ event_filter=event_filter,
)
- events, next_key = yield data_source.get_pagination_rows(
- requester.user, source_config, room_id
- )
-
- next_token = pagin_config.from_token.copy_and_replace(
- "room_key", next_key
- )
+ next_token = pagin_config.from_token.copy_and_replace(
+ "room_key", next_key
+ )
if not events:
defer.returnValue({
@@ -129,6 +149,9 @@ class MessageHandler(BaseHandler):
"end": next_token.to_string(),
})
+ if event_filter:
+ events = event_filter.filter(events)
+
events = yield filter_events_for_client(
self.store,
user_id,
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index 711a6a567f..d9ac09078d 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -13,15 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
+
from twisted.internet import defer
+import synapse.types
from synapse.api.errors import SynapseError, AuthError, CodeMessageException
-from synapse.types import UserID, Requester
-
+from synapse.types import UserID
from ._base import BaseHandler
-import logging
-
logger = logging.getLogger(__name__)
@@ -165,7 +165,9 @@ class ProfileHandler(BaseHandler):
try:
# Assume the user isn't a guest because we don't let guests set
# profile or avatar data.
- requester = Requester(user, "", False)
+ # XXX why are we recreating `requester` here for each room?
+ # what was wrong with the `requester` we were passed?
+ requester = synapse.types.create_requester(user)
yield handler.update_membership(
requester,
user,
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 0b7517221d..dd75c4fecf 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -14,18 +14,19 @@
# limitations under the License.
"""Contains functions for registering clients."""
+import logging
+import urllib
+
from twisted.internet import defer
-from synapse.types import UserID, Requester
+import synapse.types
from synapse.api.errors import (
AuthError, Codes, SynapseError, RegistrationError, InvalidCaptchaError
)
-from ._base import BaseHandler
-from synapse.util.async import run_on_reactor
from synapse.http.client import CaptchaServerHttpClient
-
-import logging
-import urllib
+from synapse.types import UserID
+from synapse.util.async import run_on_reactor
+from ._base import BaseHandler
logger = logging.getLogger(__name__)
@@ -52,6 +53,13 @@ class RegistrationHandler(BaseHandler):
Codes.INVALID_USERNAME
)
+ if localpart[0] == '_':
+ raise SynapseError(
+ 400,
+ "User ID may not begin with _",
+ Codes.INVALID_USERNAME
+ )
+
user = UserID(localpart, self.hs.hostname)
user_id = user.to_string()
@@ -90,7 +98,8 @@ class RegistrationHandler(BaseHandler):
password=None,
generate_token=True,
guest_access_token=None,
- make_guest=False
+ make_guest=False,
+ admin=False,
):
"""Registers a new client on the server.
@@ -98,8 +107,13 @@ class RegistrationHandler(BaseHandler):
localpart : The local part of the user ID to register. If None,
one will be generated.
password (str) : The password to assign to this user so they can
- login again. This can be None which means they cannot login again
- via a password (e.g. the user is an application service user).
+ login again. This can be None which means they cannot login again
+ via a password (e.g. the user is an application service user).
+ generate_token (bool): Whether a new access token should be
+ generated. Having this be True should be considered deprecated,
+ since it offers no means of associating a device_id with the
+ access_token. Instead you should call auth_handler.issue_access_token
+ after registration.
Returns:
A tuple of (user_id, access_token).
Raises:
@@ -141,6 +155,7 @@ class RegistrationHandler(BaseHandler):
# If the user was a guest then they already have a profile
None if was_guest else user.localpart
),
+ admin=admin,
)
else:
# autogen a sequential user ID
@@ -194,15 +209,13 @@ class RegistrationHandler(BaseHandler):
user_id, allowed_appservice=service
)
- token = self.auth_handler().generate_access_token(user_id)
yield self.store.register(
user_id=user_id,
- token=token,
password_hash="",
appservice_id=service_id,
create_profile_with_localpart=user.localpart,
)
- defer.returnValue((user_id, token))
+ defer.returnValue(user_id)
@defer.inlineCallbacks
def check_recaptcha(self, ip, private_key, challenge, response):
@@ -358,7 +371,8 @@ class RegistrationHandler(BaseHandler):
defer.returnValue(data)
@defer.inlineCallbacks
- def get_or_create_user(self, localpart, displayname, duration_seconds):
+ def get_or_create_user(self, localpart, displayname, duration_in_ms,
+ password_hash=None):
"""Creates a new user if the user does not exist,
else revokes all previous access tokens and generates a new one.
@@ -387,14 +401,14 @@ class RegistrationHandler(BaseHandler):
user = UserID(localpart, self.hs.hostname)
user_id = user.to_string()
- token = self.auth_handler().generate_short_term_login_token(
- user_id, duration_seconds)
+ token = self.auth_handler().generate_access_token(
+ user_id, None, duration_in_ms)
if need_register:
yield self.store.register(
user_id=user_id,
token=token,
- password_hash=None,
+ password_hash=password_hash,
create_profile_with_localpart=user.localpart,
)
else:
@@ -404,8 +418,9 @@ class RegistrationHandler(BaseHandler):
if displayname is not None:
logger.info("setting user display name: %s -> %s", user_id, displayname)
profile_handler = self.hs.get_handlers().profile_handler
+ requester = synapse.types.create_requester(user)
yield profile_handler.set_displayname(
- user, Requester(user, token, False), displayname
+ user, requester, displayname
)
defer.returnValue((user_id, token))
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index ae44c7a556..bf6b1c1535 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -345,8 +345,8 @@ class RoomCreationHandler(BaseHandler):
class RoomListHandler(BaseHandler):
def __init__(self, hs):
super(RoomListHandler, self).__init__(hs)
- self.response_cache = ResponseCache()
- self.remote_list_request_cache = ResponseCache()
+ self.response_cache = ResponseCache(hs)
+ self.remote_list_request_cache = ResponseCache(hs)
self.remote_list_cache = {}
self.fetch_looping_call = hs.get_clock().looping_call(
self.fetch_all_remote_lists, REMOTE_ROOM_LIST_POLL_INTERVAL
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 7e616f44fd..8cec8fc4ed 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -14,24 +14,22 @@
# limitations under the License.
-from twisted.internet import defer
+import logging
-from ._base import BaseHandler
+from signedjson.key import decode_verify_key_bytes
+from signedjson.sign import verify_signed_json
+from twisted.internet import defer
+from unpaddedbase64 import decode_base64
-from synapse.types import UserID, RoomID, Requester
+import synapse.types
from synapse.api.constants import (
EventTypes, Membership,
)
from synapse.api.errors import AuthError, SynapseError, Codes
+from synapse.types import UserID, RoomID
from synapse.util.async import Linearizer
from synapse.util.distributor import user_left_room, user_joined_room
-
-from signedjson.sign import verify_signed_json
-from signedjson.key import decode_verify_key_bytes
-
-from unpaddedbase64 import decode_base64
-
-import logging
+from ._base import BaseHandler
logger = logging.getLogger(__name__)
@@ -315,7 +313,7 @@ class RoomMemberHandler(BaseHandler):
)
assert self.hs.is_mine(sender), "Sender must be our own: %s" % (sender,)
else:
- requester = Requester(target_user, None, False)
+ requester = synapse.types.create_requester(target_user)
message_handler = self.hs.get_handlers().message_handler
prev_event = message_handler.deduplicate_state_event(event, context)
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index be26a491ff..0ee4ebe504 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -138,7 +138,7 @@ class SyncHandler(object):
self.presence_handler = hs.get_presence_handler()
self.event_sources = hs.get_event_sources()
self.clock = hs.get_clock()
- self.response_cache = ResponseCache()
+ self.response_cache = ResponseCache(hs)
def wait_for_sync_for_user(self, sync_config, since_token=None, timeout=0,
full_state=False):
diff --git a/synapse/http/server.py b/synapse/http/server.py
index f705abab94..2b3c05a740 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -205,6 +205,7 @@ class JsonResource(HttpServer, resource.Resource):
def register_paths(self, method, path_patterns, callback):
for path_pattern in path_patterns:
+ logger.debug("Registering for %s %s", method, path_pattern.pattern)
self.path_regexs.setdefault(method, []).append(
self._PathEntry(path_pattern, callback)
)
diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py
index bdd7292a30..76d5998d75 100644
--- a/synapse/metrics/__init__.py
+++ b/synapse/metrics/__init__.py
@@ -27,7 +27,8 @@ import gc
from twisted.internet import reactor
from .metric import (
- CounterMetric, CallbackMetric, DistributionMetric, CacheMetric
+ CounterMetric, CallbackMetric, DistributionMetric, CacheMetric,
+ MemoryUsageMetric,
)
@@ -66,6 +67,21 @@ class Metrics(object):
return self._register(CacheMetric, *args, **kwargs)
+def register_memory_metrics(hs):
+ try:
+ import psutil
+ process = psutil.Process()
+ process.memory_info().rss
+ except (ImportError, AttributeError):
+ logger.warn(
+ "psutil is not installed or incorrect version."
+ " Disabling memory metrics."
+ )
+ return
+ metric = MemoryUsageMetric(hs, psutil)
+ all_metrics.append(metric)
+
+
def get_metrics_for(pkg_name):
""" Returns a Metrics instance for conveniently creating metrics
namespaced with the given name prefix. """
diff --git a/synapse/metrics/metric.py b/synapse/metrics/metric.py
index 341043952a..e81af29895 100644
--- a/synapse/metrics/metric.py
+++ b/synapse/metrics/metric.py
@@ -153,3 +153,43 @@ class CacheMetric(object):
"""%s:total{name="%s"} %d""" % (self.name, self.cache_name, total),
"""%s:size{name="%s"} %d""" % (self.name, self.cache_name, size),
]
+
+
+class MemoryUsageMetric(object):
+ """Keeps track of the current memory usage, using psutil.
+
+ The class will keep the current min/max/sum/counts of rss over the last
+ WINDOW_SIZE_SEC, by polling UPDATE_HZ times per second
+ """
+
+ UPDATE_HZ = 2 # number of times to get memory per second
+ WINDOW_SIZE_SEC = 30 # the size of the window in seconds
+
+ def __init__(self, hs, psutil):
+ clock = hs.get_clock()
+ self.memory_snapshots = []
+
+ self.process = psutil.Process()
+
+ clock.looping_call(self._update_curr_values, 1000 / self.UPDATE_HZ)
+
+ def _update_curr_values(self):
+ max_size = self.UPDATE_HZ * self.WINDOW_SIZE_SEC
+ self.memory_snapshots.append(self.process.memory_info().rss)
+ self.memory_snapshots[:] = self.memory_snapshots[-max_size:]
+
+ def render(self):
+ if not self.memory_snapshots:
+ return []
+
+ max_rss = max(self.memory_snapshots)
+ min_rss = min(self.memory_snapshots)
+ sum_rss = sum(self.memory_snapshots)
+ len_rss = len(self.memory_snapshots)
+
+ return [
+ "process_psutil_rss:max %d" % max_rss,
+ "process_psutil_rss:min %d" % min_rss,
+ "process_psutil_rss:total %d" % sum_rss,
+ "process_psutil_rss:count %d" % len_rss,
+ ]
diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py
index 12a3ec7fd8..6600c9cd55 100644
--- a/synapse/push/emailpusher.py
+++ b/synapse/push/emailpusher.py
@@ -14,6 +14,7 @@
# limitations under the License.
from twisted.internet import defer, reactor
+from twisted.internet.error import AlreadyCalled, AlreadyCancelled
import logging
@@ -92,7 +93,11 @@ class EmailPusher(object):
def on_stop(self):
if self.timed_call:
- self.timed_call.cancel()
+ try:
+ self.timed_call.cancel()
+ except (AlreadyCalled, AlreadyCancelled):
+ pass
+ self.timed_call = None
@defer.inlineCallbacks
def on_new_notifications(self, min_stream_ordering, max_stream_ordering):
@@ -140,9 +145,8 @@ class EmailPusher(object):
being run.
"""
start = 0 if INCLUDE_ALL_UNREAD_NOTIFS else self.last_stream_ordering
- unprocessed = yield self.store.get_unread_push_actions_for_user_in_range(
- self.user_id, start, self.max_stream_ordering
- )
+ fn = self.store.get_unread_push_actions_for_user_in_range_for_email
+ unprocessed = yield fn(self.user_id, start, self.max_stream_ordering)
soonest_due_at = None
@@ -190,7 +194,10 @@ class EmailPusher(object):
soonest_due_at = should_notify_at
if self.timed_call is not None:
- self.timed_call.cancel()
+ try:
+ self.timed_call.cancel()
+ except (AlreadyCalled, AlreadyCancelled):
+ pass
self.timed_call = None
if soonest_due_at is not None:
diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index 3992804845..feedb075e2 100644
--- a/synapse/push/httppusher.py
+++ b/synapse/push/httppusher.py
@@ -16,6 +16,7 @@
from synapse.push import PusherConfigException
from twisted.internet import defer, reactor
+from twisted.internet.error import AlreadyCalled, AlreadyCancelled
import logging
import push_rule_evaluator
@@ -38,6 +39,7 @@ class HttpPusher(object):
self.hs = hs
self.store = self.hs.get_datastore()
self.clock = self.hs.get_clock()
+ self.state_handler = self.hs.get_state_handler()
self.user_id = pusherdict['user_name']
self.app_id = pusherdict['app_id']
self.app_display_name = pusherdict['app_display_name']
@@ -108,7 +110,11 @@ class HttpPusher(object):
def on_stop(self):
if self.timed_call:
- self.timed_call.cancel()
+ try:
+ self.timed_call.cancel()
+ except (AlreadyCalled, AlreadyCancelled):
+ pass
+ self.timed_call = None
@defer.inlineCallbacks
def _process(self):
@@ -140,7 +146,8 @@ class HttpPusher(object):
run once per pusher.
"""
- unprocessed = yield self.store.get_unread_push_actions_for_user_in_range(
+ fn = self.store.get_unread_push_actions_for_user_in_range_for_http
+ unprocessed = yield fn(
self.user_id, self.last_stream_ordering, self.max_stream_ordering
)
@@ -237,7 +244,9 @@ class HttpPusher(object):
@defer.inlineCallbacks
def _build_notification_dict(self, event, tweaks, badge):
- ctx = yield push_tools.get_context_for_event(self.hs.get_datastore(), event)
+ ctx = yield push_tools.get_context_for_event(
+ self.state_handler, event, self.user_id
+ )
d = {
'notification': {
@@ -269,8 +278,8 @@ class HttpPusher(object):
if 'content' in event:
d['notification']['content'] = event.content
- if len(ctx['aliases']):
- d['notification']['room_alias'] = ctx['aliases'][0]
+ # We no longer send aliases separately, instead, we send the human
+ # readable name of the room, which may be an alias.
if 'sender_display_name' in ctx and len(ctx['sender_display_name']) > 0:
d['notification']['sender_display_name'] = ctx['sender_display_name']
if 'name' in ctx and len(ctx['name']) > 0:
diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py
index 89a3b5e90a..d555a33e9a 100644
--- a/synapse/push/push_tools.py
+++ b/synapse/push/push_tools.py
@@ -14,6 +14,9 @@
# limitations under the License.
from twisted.internet import defer
+from synapse.util.presentable_names import (
+ calculate_room_name, name_from_member_event
+)
@defer.inlineCallbacks
@@ -45,24 +48,21 @@ def get_badge_count(store, user_id):
@defer.inlineCallbacks
-def get_context_for_event(store, ev):
- name_aliases = yield store.get_room_name_and_aliases(
- ev.room_id
- )
+def get_context_for_event(state_handler, ev, user_id):
+ ctx = {}
- ctx = {'aliases': name_aliases[1]}
- if name_aliases[0] is not None:
- ctx['name'] = name_aliases[0]
+ room_state = yield state_handler.get_current_state(ev.room_id)
- their_member_events_for_room = yield store.get_current_state(
- room_id=ev.room_id,
- event_type='m.room.member',
- state_key=ev.user_id
+ # we no longer bother setting room_alias, and make room_name the
+ # human-readable name instead, be that m.room.name, an alias or
+ # a list of people in the room
+ name = calculate_room_name(
+ room_state, user_id, fallback_to_single_member=False
)
- for mev in their_member_events_for_room:
- if mev.content['membership'] == 'join' and 'displayname' in mev.content:
- dn = mev.content['displayname']
- if dn is not None:
- ctx['sender_display_name'] = dn
+ if name:
+ ctx['name'] = name
+
+ sender_state_event = room_state[("m.room.member", ev.sender)]
+ ctx['sender_display_name'] = name_from_member_event(sender_state_event)
defer.returnValue(ctx)
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index e0a7a19777..86e3d89154 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -48,6 +48,12 @@ CONDITIONAL_REQUIREMENTS = {
"Jinja2>=2.8": ["Jinja2>=2.8"],
"bleach>=1.4.2": ["bleach>=1.4.2"],
},
+ "ldap": {
+ "ldap3>=1.0": ["ldap3>=1.0"],
+ },
+ "psutil": {
+ "psutil>=2.0.0": ["psutil>=2.0.0"],
+ },
}
diff --git a/synapse/replication/slave/storage/directory.py b/synapse/replication/slave/storage/directory.py
new file mode 100644
index 0000000000..5fbe3a303a
--- /dev/null
+++ b/synapse/replication/slave/storage/directory.py
@@ -0,0 +1,23 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015, 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ._base import BaseSlavedStore
+from synapse.storage.directory import DirectoryStore
+
+
+class DirectoryStore(BaseSlavedStore):
+ get_aliases_for_room = DirectoryStore.__dict__[
+ "get_aliases_for_room"
+ ].orig
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index 877c68508c..f4f31f2d27 100644
--- a/synapse/replication/slave/storage/events.py
+++ b/synapse/replication/slave/storage/events.py
@@ -18,7 +18,6 @@ from ._slaved_id_tracker import SlavedIdTracker
from synapse.api.constants import EventTypes
from synapse.events import FrozenEvent
from synapse.storage import DataStore
-from synapse.storage.room import RoomStore
from synapse.storage.roommember import RoomMemberStore
from synapse.storage.event_federation import EventFederationStore
from synapse.storage.event_push_actions import EventPushActionsStore
@@ -64,7 +63,6 @@ class SlavedEventStore(BaseSlavedStore):
# Cached functions can't be accessed through a class instance so we need
# to reach inside the __dict__ to extract them.
- get_room_name_and_aliases = RoomStore.__dict__["get_room_name_and_aliases"]
get_rooms_for_user = RoomMemberStore.__dict__["get_rooms_for_user"]
get_users_in_room = RoomMemberStore.__dict__["get_users_in_room"]
get_latest_event_ids_in_room = EventFederationStore.__dict__[
@@ -95,8 +93,11 @@ class SlavedEventStore(BaseSlavedStore):
StreamStore.__dict__["get_recent_event_ids_for_room"]
)
- get_unread_push_actions_for_user_in_range = (
- DataStore.get_unread_push_actions_for_user_in_range.__func__
+ get_unread_push_actions_for_user_in_range_for_http = (
+ DataStore.get_unread_push_actions_for_user_in_range_for_http.__func__
+ )
+ get_unread_push_actions_for_user_in_range_for_email = (
+ DataStore.get_unread_push_actions_for_user_in_range_for_email.__func__
)
get_push_action_users_in_range = (
DataStore.get_push_action_users_in_range.__func__
@@ -144,6 +145,15 @@ class SlavedEventStore(BaseSlavedStore):
_get_events_around_txn = DataStore._get_events_around_txn.__func__
_get_some_state_from_cache = DataStore._get_some_state_from_cache.__func__
+ get_backfill_events = DataStore.get_backfill_events.__func__
+ _get_backfill_events = DataStore._get_backfill_events.__func__
+ get_missing_events = DataStore.get_missing_events.__func__
+ _get_missing_events = DataStore._get_missing_events.__func__
+
+ get_auth_chain = DataStore.get_auth_chain.__func__
+ get_auth_chain_ids = DataStore.get_auth_chain_ids.__func__
+ _get_auth_chain_ids_txn = DataStore._get_auth_chain_ids_txn.__func__
+
def stream_positions(self):
result = super(SlavedEventStore, self).stream_positions()
result["events"] = self._stream_id_gen.get_current_token()
@@ -202,7 +212,6 @@ class SlavedEventStore(BaseSlavedStore):
self.get_rooms_for_user.invalidate_all()
self.get_users_in_room.invalidate((event.room_id,))
# self.get_joined_hosts_for_room.invalidate((event.room_id,))
- self.get_room_name_and_aliases.invalidate((event.room_id,))
self._invalidate_get_event_cache(event.event_id)
@@ -246,9 +255,3 @@ class SlavedEventStore(BaseSlavedStore):
self._get_current_state_for_key.invalidate((
event.room_id, event.type, event.state_key
))
-
- if event.type in [EventTypes.Name, EventTypes.Aliases]:
- self.get_room_name_and_aliases.invalidate(
- (event.room_id,)
- )
- pass
diff --git a/synapse/replication/slave/storage/keys.py b/synapse/replication/slave/storage/keys.py
new file mode 100644
index 0000000000..dd2ae49e48
--- /dev/null
+++ b/synapse/replication/slave/storage/keys.py
@@ -0,0 +1,33 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015, 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ._base import BaseSlavedStore
+from synapse.storage import DataStore
+from synapse.storage.keys import KeyStore
+
+
+class SlavedKeyStore(BaseSlavedStore):
+ _get_server_verify_key = KeyStore.__dict__[
+ "_get_server_verify_key"
+ ]
+
+ get_server_verify_keys = DataStore.get_server_verify_keys.__func__
+ store_server_verify_key = DataStore.store_server_verify_key.__func__
+
+ get_server_certificate = DataStore.get_server_certificate.__func__
+ store_server_certificate = DataStore.store_server_certificate.__func__
+
+ get_server_keys_json = DataStore.get_server_keys_json.__func__
+ store_server_keys_json = DataStore.store_server_keys_json.__func__
diff --git a/synapse/replication/slave/storage/room.py b/synapse/replication/slave/storage/room.py
new file mode 100644
index 0000000000..d5bb0f98ea
--- /dev/null
+++ b/synapse/replication/slave/storage/room.py
@@ -0,0 +1,21 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015, 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ._base import BaseSlavedStore
+from synapse.storage import DataStore
+
+
+class RoomStore(BaseSlavedStore):
+ get_public_room_ids = DataStore.get_public_room_ids.__func__
diff --git a/synapse/replication/slave/storage/transactions.py b/synapse/replication/slave/storage/transactions.py
new file mode 100644
index 0000000000..6f2ba98af5
--- /dev/null
+++ b/synapse/replication/slave/storage/transactions.py
@@ -0,0 +1,30 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015, 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+from ._base import BaseSlavedStore
+from synapse.storage import DataStore
+from synapse.storage.transactions import TransactionStore
+
+
+class TransactionStore(BaseSlavedStore):
+ get_destination_retry_timings = TransactionStore.__dict__[
+ "get_destination_retry_timings"
+ ].orig
+ _get_destination_retry_timings = DataStore._get_destination_retry_timings.__func__
+
+ # For now, don't record the destination rety timings
+ def set_destination_retry_timings(*args, **kwargs):
+ return defer.succeed(None)
diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py
index 8b223e032b..14227f1cdb 100644
--- a/synapse/rest/__init__.py
+++ b/synapse/rest/__init__.py
@@ -46,6 +46,7 @@ from synapse.rest.client.v2_alpha import (
account_data,
report_event,
openid,
+ devices,
)
from synapse.http.server import JsonResource
@@ -90,3 +91,4 @@ class ClientRestResource(JsonResource):
account_data.register_servlets(hs, client_resource)
report_event.register_servlets(hs, client_resource)
openid.register_servlets(hs, client_resource)
+ devices.register_servlets(hs, client_resource)
diff --git a/synapse/rest/client/v1/admin.py b/synapse/rest/client/v1/admin.py
index aa05b3f023..b0cb31a448 100644
--- a/synapse/rest/client/v1/admin.py
+++ b/synapse/rest/client/v1/admin.py
@@ -46,5 +46,82 @@ class WhoisRestServlet(ClientV1RestServlet):
defer.returnValue((200, ret))
+class PurgeMediaCacheRestServlet(ClientV1RestServlet):
+ PATTERNS = client_path_patterns("/admin/purge_media_cache")
+
+ def __init__(self, hs):
+ self.media_repository = hs.get_media_repository()
+ super(PurgeMediaCacheRestServlet, self).__init__(hs)
+
+ @defer.inlineCallbacks
+ def on_POST(self, request):
+ requester = yield self.auth.get_user_by_req(request)
+ is_admin = yield self.auth.is_server_admin(requester.user)
+
+ if not is_admin:
+ raise AuthError(403, "You are not a server admin")
+
+ before_ts = request.args.get("before_ts", None)
+ if not before_ts:
+ raise SynapseError(400, "Missing 'before_ts' arg")
+
+ logger.info("before_ts: %r", before_ts[0])
+
+ try:
+ before_ts = int(before_ts[0])
+ except Exception:
+ raise SynapseError(400, "Invalid 'before_ts' arg")
+
+ ret = yield self.media_repository.delete_old_remote_media(before_ts)
+
+ defer.returnValue((200, ret))
+
+
+class PurgeHistoryRestServlet(ClientV1RestServlet):
+ PATTERNS = client_path_patterns(
+ "/admin/purge_history/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
+ )
+
+ @defer.inlineCallbacks
+ def on_POST(self, request, room_id, event_id):
+ requester = yield self.auth.get_user_by_req(request)
+ is_admin = yield self.auth.is_server_admin(requester.user)
+
+ if not is_admin:
+ raise AuthError(403, "You are not a server admin")
+
+ yield self.handlers.message_handler.purge_history(room_id, event_id)
+
+ defer.returnValue((200, {}))
+
+
+class DeactivateAccountRestServlet(ClientV1RestServlet):
+ PATTERNS = client_path_patterns("/admin/deactivate/(?P<target_user_id>[^/]*)")
+
+ def __init__(self, hs):
+ self.store = hs.get_datastore()
+ super(DeactivateAccountRestServlet, self).__init__(hs)
+
+ @defer.inlineCallbacks
+ def on_POST(self, request, target_user_id):
+ UserID.from_string(target_user_id)
+ requester = yield self.auth.get_user_by_req(request)
+ is_admin = yield self.auth.is_server_admin(requester.user)
+
+ if not is_admin:
+ raise AuthError(403, "You are not a server admin")
+
+ # FIXME: Theoretically there is a race here wherein user resets password
+ # using threepid.
+ yield self.store.user_delete_access_tokens(target_user_id)
+ yield self.store.user_delete_threepids(target_user_id)
+ yield self.store.user_set_password_hash(target_user_id, None)
+
+ defer.returnValue((200, {}))
+
+
def register_servlets(hs, http_server):
WhoisRestServlet(hs).register(http_server)
+ PurgeMediaCacheRestServlet(hs).register(http_server)
+ DeactivateAccountRestServlet(hs).register(http_server)
+ PurgeHistoryRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v1/base.py b/synapse/rest/client/v1/base.py
index 1c020b7e2c..96b49b01f2 100644
--- a/synapse/rest/client/v1/base.py
+++ b/synapse/rest/client/v1/base.py
@@ -52,6 +52,10 @@ class ClientV1RestServlet(RestServlet):
"""
def __init__(self, hs):
+ """
+ Args:
+ hs (synapse.server.HomeServer):
+ """
self.hs = hs
self.handlers = hs.get_handlers()
self.builder_factory = hs.get_event_builder_factory()
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 8df9d10efa..92fcae674a 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -59,6 +59,7 @@ class LoginRestServlet(ClientV1RestServlet):
self.servername = hs.config.server_name
self.http_client = hs.get_simple_http_client()
self.auth_handler = self.hs.get_auth_handler()
+ self.device_handler = self.hs.get_device_handler()
def on_GET(self, request):
flows = []
@@ -145,15 +146,23 @@ class LoginRestServlet(ClientV1RestServlet):
).to_string()
auth_handler = self.auth_handler
- user_id, access_token, refresh_token = yield auth_handler.login_with_password(
+ user_id = yield auth_handler.validate_password_login(
user_id=user_id,
- password=login_submission["password"])
-
+ password=login_submission["password"],
+ )
+ device_id = yield self._register_device(user_id, login_submission)
+ access_token, refresh_token = (
+ yield auth_handler.get_login_tuple_for_user_id(
+ user_id, device_id,
+ login_submission.get("initial_device_display_name")
+ )
+ )
result = {
"user_id": user_id, # may have changed
"access_token": access_token,
"refresh_token": refresh_token,
"home_server": self.hs.hostname,
+ "device_id": device_id,
}
defer.returnValue((200, result))
@@ -165,14 +174,19 @@ class LoginRestServlet(ClientV1RestServlet):
user_id = (
yield auth_handler.validate_short_term_login_token_and_get_user_id(token)
)
- user_id, access_token, refresh_token = (
- yield auth_handler.get_login_tuple_for_user_id(user_id)
+ device_id = yield self._register_device(user_id, login_submission)
+ access_token, refresh_token = (
+ yield auth_handler.get_login_tuple_for_user_id(
+ user_id, device_id,
+ login_submission.get("initial_device_display_name")
+ )
)
result = {
"user_id": user_id, # may have changed
"access_token": access_token,
"refresh_token": refresh_token,
"home_server": self.hs.hostname,
+ "device_id": device_id,
}
defer.returnValue((200, result))
@@ -196,13 +210,15 @@ class LoginRestServlet(ClientV1RestServlet):
user_id = UserID.create(user, self.hs.hostname).to_string()
auth_handler = self.auth_handler
- user_exists = yield auth_handler.does_user_exist(user_id)
- if user_exists:
- user_id, access_token, refresh_token = (
- yield auth_handler.get_login_tuple_for_user_id(user_id)
+ registered_user_id = yield auth_handler.check_user_exists(user_id)
+ if registered_user_id:
+ access_token, refresh_token = (
+ yield auth_handler.get_login_tuple_for_user_id(
+ registered_user_id
+ )
)
result = {
- "user_id": user_id, # may have changed
+ "user_id": registered_user_id, # may have changed
"access_token": access_token,
"refresh_token": refresh_token,
"home_server": self.hs.hostname,
@@ -245,18 +261,27 @@ class LoginRestServlet(ClientV1RestServlet):
user_id = UserID.create(user, self.hs.hostname).to_string()
auth_handler = self.auth_handler
- user_exists = yield auth_handler.does_user_exist(user_id)
- if user_exists:
- user_id, access_token, refresh_token = (
- yield auth_handler.get_login_tuple_for_user_id(user_id)
+ registered_user_id = yield auth_handler.check_user_exists(user_id)
+ if registered_user_id:
+ device_id = yield self._register_device(
+ registered_user_id, login_submission
+ )
+ access_token, refresh_token = (
+ yield auth_handler.get_login_tuple_for_user_id(
+ registered_user_id, device_id,
+ login_submission.get("initial_device_display_name")
+ )
)
result = {
- "user_id": user_id, # may have changed
+ "user_id": registered_user_id,
"access_token": access_token,
"refresh_token": refresh_token,
"home_server": self.hs.hostname,
}
else:
+ # TODO: we should probably check that the register isn't going
+ # to fonx/change our user_id before registering the device
+ device_id = yield self._register_device(user_id, login_submission)
user_id, access_token = (
yield self.handlers.registration_handler.register(localpart=user)
)
@@ -295,6 +320,26 @@ class LoginRestServlet(ClientV1RestServlet):
return (user, attributes)
+ def _register_device(self, user_id, login_submission):
+ """Register a device for a user.
+
+ This is called after the user's credentials have been validated, but
+ before the access token has been issued.
+
+ Args:
+ (str) user_id: full canonical @user:id
+ (object) login_submission: dictionary supplied to /login call, from
+ which we pull device_id and initial_device_name
+ Returns:
+ defer.Deferred: (str) device_id
+ """
+ device_id = login_submission.get("device_id")
+ initial_display_name = login_submission.get(
+ "initial_device_display_name")
+ return self.device_handler.check_device_registered(
+ user_id, device_id, initial_display_name
+ )
+
class SAML2RestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/login/saml2", releases=())
@@ -414,13 +459,13 @@ class CasTicketServlet(ClientV1RestServlet):
user_id = UserID.create(user, self.hs.hostname).to_string()
auth_handler = self.auth_handler
- user_exists = yield auth_handler.does_user_exist(user_id)
- if not user_exists:
- user_id, _ = (
+ registered_user_id = yield auth_handler.check_user_exists(user_id)
+ if not registered_user_id:
+ registered_user_id, _ = (
yield self.handlers.registration_handler.register(localpart=user)
)
- login_token = auth_handler.generate_short_term_login_token(user_id)
+ login_token = auth_handler.generate_short_term_login_token(registered_user_id)
redirect_url = self.add_login_token_to_redirect_url(client_redirect_url,
login_token)
request.redirect(redirect_url)
diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1/register.py
index e3f4fbb0bb..2383b9df86 100644
--- a/synapse/rest/client/v1/register.py
+++ b/synapse/rest/client/v1/register.py
@@ -52,6 +52,10 @@ class RegisterRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/register$", releases=(), include_in_unstable=False)
def __init__(self, hs):
+ """
+ Args:
+ hs (synapse.server.HomeServer): server
+ """
super(RegisterRestServlet, self).__init__(hs)
# sessions are stored as:
# self.sessions = {
@@ -60,6 +64,7 @@ class RegisterRestServlet(ClientV1RestServlet):
# TODO: persistent storage
self.sessions = {}
self.enable_registration = hs.config.enable_registration
+ self.auth_handler = hs.get_auth_handler()
def on_GET(self, request):
if self.hs.config.enable_registration_captcha:
@@ -299,9 +304,10 @@ class RegisterRestServlet(ClientV1RestServlet):
user_localpart = register_json["user"].encode("utf-8")
handler = self.handlers.registration_handler
- (user_id, token) = yield handler.appservice_register(
+ user_id = yield handler.appservice_register(
user_localpart, as_token
)
+ token = yield self.auth_handler.issue_access_token(user_id)
self._remove_session(session)
defer.returnValue({
"user_id": user_id,
@@ -324,6 +330,14 @@ class RegisterRestServlet(ClientV1RestServlet):
raise SynapseError(400, "Shared secret registration is not enabled")
user = register_json["user"].encode("utf-8")
+ password = register_json["password"].encode("utf-8")
+ admin = register_json.get("admin", None)
+
+ # Its important to check as we use null bytes as HMAC field separators
+ if "\x00" in user:
+ raise SynapseError(400, "Invalid user")
+ if "\x00" in password:
+ raise SynapseError(400, "Invalid password")
# str() because otherwise hmac complains that 'unicode' does not
# have the buffer interface
@@ -331,17 +345,21 @@ class RegisterRestServlet(ClientV1RestServlet):
want_mac = hmac.new(
key=self.hs.config.registration_shared_secret,
- msg=user,
digestmod=sha1,
- ).hexdigest()
-
- password = register_json["password"].encode("utf-8")
+ )
+ want_mac.update(user)
+ want_mac.update("\x00")
+ want_mac.update(password)
+ want_mac.update("\x00")
+ want_mac.update("admin" if admin else "notadmin")
+ want_mac = want_mac.hexdigest()
if compare_digest(want_mac, got_mac):
handler = self.handlers.registration_handler
user_id, token = yield handler.register(
localpart=user,
password=password,
+ admin=bool(admin),
)
self._remove_session(session)
defer.returnValue({
@@ -410,12 +428,15 @@ class CreateUserRestServlet(ClientV1RestServlet):
raise SynapseError(400, "Failed to parse 'duration_seconds'")
if duration_seconds > self.direct_user_creation_max_duration:
duration_seconds = self.direct_user_creation_max_duration
+ password_hash = user_json["password_hash"].encode("utf-8") \
+ if user_json.get("password_hash") else None
handler = self.handlers.registration_handler
user_id, token = yield handler.get_or_create_user(
localpart=localpart,
displayname=displayname,
- duration_seconds=duration_seconds
+ duration_in_ms=(duration_seconds * 1000),
+ password_hash=password_hash
)
defer.returnValue({
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 86fbe2747d..866a1e9120 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -20,12 +20,14 @@ from .base import ClientV1RestServlet, client_path_patterns
from synapse.api.errors import SynapseError, Codes, AuthError
from synapse.streams.config import PaginationConfig
from synapse.api.constants import EventTypes, Membership
+from synapse.api.filtering import Filter
from synapse.types import UserID, RoomID, RoomAlias
from synapse.events.utils import serialize_event
from synapse.http.servlet import parse_json_object_from_request
import logging
import urllib
+import ujson as json
logger = logging.getLogger(__name__)
@@ -327,12 +329,19 @@ class RoomMessageListRestServlet(ClientV1RestServlet):
request, default_limit=10,
)
as_client_event = "raw" not in request.args
+ filter_bytes = request.args.get("filter", None)
+ if filter_bytes:
+ filter_json = urllib.unquote(filter_bytes[-1]).decode("UTF-8")
+ event_filter = Filter(json.loads(filter_json))
+ else:
+ event_filter = None
handler = self.handlers.message_handler
msgs = yield handler.get_messages(
room_id=room_id,
requester=requester,
pagin_config=pagination_config,
- as_client_event=as_client_event
+ as_client_event=as_client_event,
+ event_filter=event_filter,
)
defer.returnValue((200, msgs))
diff --git a/synapse/rest/client/v2_alpha/_base.py b/synapse/rest/client/v2_alpha/_base.py
index b6faa2b0e6..20e765f48f 100644
--- a/synapse/rest/client/v2_alpha/_base.py
+++ b/synapse/rest/client/v2_alpha/_base.py
@@ -25,7 +25,9 @@ import logging
logger = logging.getLogger(__name__)
-def client_v2_patterns(path_regex, releases=(0,)):
+def client_v2_patterns(path_regex, releases=(0,),
+ v2_alpha=True,
+ unstable=True):
"""Creates a regex compiled client path with the correct client path
prefix.
@@ -35,9 +37,12 @@ def client_v2_patterns(path_regex, releases=(0,)):
Returns:
SRE_Pattern
"""
- patterns = [re.compile("^" + CLIENT_V2_ALPHA_PREFIX + path_regex)]
- unstable_prefix = CLIENT_V2_ALPHA_PREFIX.replace("/v2_alpha", "/unstable")
- patterns.append(re.compile("^" + unstable_prefix + path_regex))
+ patterns = []
+ if v2_alpha:
+ patterns.append(re.compile("^" + CLIENT_V2_ALPHA_PREFIX + path_regex))
+ if unstable:
+ unstable_prefix = CLIENT_V2_ALPHA_PREFIX.replace("/v2_alpha", "/unstable")
+ patterns.append(re.compile("^" + unstable_prefix + path_regex))
for release in releases:
new_prefix = CLIENT_V2_ALPHA_PREFIX.replace("/v2_alpha", "/r%d" % release)
patterns.append(re.compile("^" + new_prefix + path_regex))
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index 9a84873a5f..eb49ad62e9 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -28,8 +28,40 @@ import logging
logger = logging.getLogger(__name__)
+class PasswordRequestTokenRestServlet(RestServlet):
+ PATTERNS = client_v2_patterns("/account/password/email/requestToken$")
+
+ def __init__(self, hs):
+ super(PasswordRequestTokenRestServlet, self).__init__()
+ self.hs = hs
+ self.identity_handler = hs.get_handlers().identity_handler
+
+ @defer.inlineCallbacks
+ def on_POST(self, request):
+ body = parse_json_object_from_request(request)
+
+ required = ['id_server', 'client_secret', 'email', 'send_attempt']
+ absent = []
+ for k in required:
+ if k not in body:
+ absent.append(k)
+
+ if absent:
+ raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
+
+ existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
+ 'email', body['email']
+ )
+
+ if existingUid is None:
+ raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND)
+
+ ret = yield self.identity_handler.requestEmailToken(**body)
+ defer.returnValue((200, ret))
+
+
class PasswordRestServlet(RestServlet):
- PATTERNS = client_v2_patterns("/account/password")
+ PATTERNS = client_v2_patterns("/account/password$")
def __init__(self, hs):
super(PasswordRestServlet, self).__init__()
@@ -89,8 +121,83 @@ class PasswordRestServlet(RestServlet):
return 200, {}
+class DeactivateAccountRestServlet(RestServlet):
+ PATTERNS = client_v2_patterns("/account/deactivate$")
+
+ def __init__(self, hs):
+ self.hs = hs
+ self.store = hs.get_datastore()
+ self.auth = hs.get_auth()
+ self.auth_handler = hs.get_auth_handler()
+ super(DeactivateAccountRestServlet, self).__init__()
+
+ @defer.inlineCallbacks
+ def on_POST(self, request):
+ body = parse_json_object_from_request(request)
+
+ authed, result, params, _ = yield self.auth_handler.check_auth([
+ [LoginType.PASSWORD],
+ ], body, self.hs.get_ip_from_request(request))
+
+ if not authed:
+ defer.returnValue((401, result))
+
+ user_id = None
+ requester = None
+
+ if LoginType.PASSWORD in result:
+ # if using password, they should also be logged in
+ requester = yield self.auth.get_user_by_req(request)
+ user_id = requester.user.to_string()
+ if user_id != result[LoginType.PASSWORD]:
+ raise LoginError(400, "", Codes.UNKNOWN)
+ else:
+ logger.error("Auth succeeded but no known type!", result.keys())
+ raise SynapseError(500, "", Codes.UNKNOWN)
+
+ # FIXME: Theoretically there is a race here wherein user resets password
+ # using threepid.
+ yield self.store.user_delete_access_tokens(user_id)
+ yield self.store.user_delete_threepids(user_id)
+ yield self.store.user_set_password_hash(user_id, None)
+
+ defer.returnValue((200, {}))
+
+
+class ThreepidRequestTokenRestServlet(RestServlet):
+ PATTERNS = client_v2_patterns("/account/3pid/email/requestToken$")
+
+ def __init__(self, hs):
+ self.hs = hs
+ super(ThreepidRequestTokenRestServlet, self).__init__()
+ self.identity_handler = hs.get_handlers().identity_handler
+
+ @defer.inlineCallbacks
+ def on_POST(self, request):
+ body = parse_json_object_from_request(request)
+
+ required = ['id_server', 'client_secret', 'email', 'send_attempt']
+ absent = []
+ for k in required:
+ if k not in body:
+ absent.append(k)
+
+ if absent:
+ raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
+
+ existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
+ 'email', body['email']
+ )
+
+ if existingUid is not None:
+ raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
+
+ ret = yield self.identity_handler.requestEmailToken(**body)
+ defer.returnValue((200, ret))
+
+
class ThreepidRestServlet(RestServlet):
- PATTERNS = client_v2_patterns("/account/3pid")
+ PATTERNS = client_v2_patterns("/account/3pid$")
def __init__(self, hs):
super(ThreepidRestServlet, self).__init__()
@@ -157,5 +264,8 @@ class ThreepidRestServlet(RestServlet):
def register_servlets(hs, http_server):
+ PasswordRequestTokenRestServlet(hs).register(http_server)
PasswordRestServlet(hs).register(http_server)
+ DeactivateAccountRestServlet(hs).register(http_server)
+ ThreepidRequestTokenRestServlet(hs).register(http_server)
ThreepidRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py
new file mode 100644
index 0000000000..8fbd3d3dfc
--- /dev/null
+++ b/synapse/rest/client/v2_alpha/devices.py
@@ -0,0 +1,100 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015, 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from twisted.internet import defer
+
+from synapse.http import servlet
+from ._base import client_v2_patterns
+
+logger = logging.getLogger(__name__)
+
+
+class DevicesRestServlet(servlet.RestServlet):
+ PATTERNS = client_v2_patterns("/devices$", releases=[], v2_alpha=False)
+
+ def __init__(self, hs):
+ """
+ Args:
+ hs (synapse.server.HomeServer): server
+ """
+ super(DevicesRestServlet, self).__init__()
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.device_handler = hs.get_device_handler()
+
+ @defer.inlineCallbacks
+ def on_GET(self, request):
+ requester = yield self.auth.get_user_by_req(request)
+ devices = yield self.device_handler.get_devices_by_user(
+ requester.user.to_string()
+ )
+ defer.returnValue((200, {"devices": devices}))
+
+
+class DeviceRestServlet(servlet.RestServlet):
+ PATTERNS = client_v2_patterns("/devices/(?P<device_id>[^/]*)$",
+ releases=[], v2_alpha=False)
+
+ def __init__(self, hs):
+ """
+ Args:
+ hs (synapse.server.HomeServer): server
+ """
+ super(DeviceRestServlet, self).__init__()
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.device_handler = hs.get_device_handler()
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, device_id):
+ requester = yield self.auth.get_user_by_req(request)
+ device = yield self.device_handler.get_device(
+ requester.user.to_string(),
+ device_id,
+ )
+ defer.returnValue((200, device))
+
+ @defer.inlineCallbacks
+ def on_DELETE(self, request, device_id):
+ # XXX: it's not completely obvious we want to expose this endpoint.
+ # It allows the client to delete access tokens, which feels like a
+ # thing which merits extra auth. But if we want to do the interactive-
+ # auth dance, we should really make it possible to delete more than one
+ # device at a time.
+ requester = yield self.auth.get_user_by_req(request)
+ yield self.device_handler.delete_device(
+ requester.user.to_string(),
+ device_id,
+ )
+ defer.returnValue((200, {}))
+
+ @defer.inlineCallbacks
+ def on_PUT(self, request, device_id):
+ requester = yield self.auth.get_user_by_req(request)
+
+ body = servlet.parse_json_object_from_request(request)
+ yield self.device_handler.update_device(
+ requester.user.to_string(),
+ device_id,
+ body
+ )
+ defer.returnValue((200, {}))
+
+
+def register_servlets(hs, http_server):
+ DevicesRestServlet(hs).register(http_server)
+ DeviceRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py
index 89ab39491c..c5ff16adf3 100644
--- a/synapse/rest/client/v2_alpha/keys.py
+++ b/synapse/rest/client/v2_alpha/keys.py
@@ -13,24 +13,25 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
+
+import simplejson as json
+from canonicaljson import encode_canonical_json
from twisted.internet import defer
+import synapse.api.errors
+import synapse.server
+import synapse.types
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.types import UserID
-
-from canonicaljson import encode_canonical_json
-
from ._base import client_v2_patterns
-import logging
-import simplejson as json
-
logger = logging.getLogger(__name__)
class KeyUploadServlet(RestServlet):
"""
- POST /keys/upload/<device_id> HTTP/1.1
+ POST /keys/upload HTTP/1.1
Content-Type: application/json
{
@@ -53,23 +54,45 @@ class KeyUploadServlet(RestServlet):
},
}
"""
- PATTERNS = client_v2_patterns("/keys/upload/(?P<device_id>[^/]*)", releases=())
+ PATTERNS = client_v2_patterns("/keys/upload(/(?P<device_id>[^/]+))?$",
+ releases=())
def __init__(self, hs):
+ """
+ Args:
+ hs (synapse.server.HomeServer): server
+ """
super(KeyUploadServlet, self).__init__()
self.store = hs.get_datastore()
self.clock = hs.get_clock()
self.auth = hs.get_auth()
+ self.device_handler = hs.get_device_handler()
@defer.inlineCallbacks
def on_POST(self, request, device_id):
requester = yield self.auth.get_user_by_req(request)
+
user_id = requester.user.to_string()
- # TODO: Check that the device_id matches that in the authentication
- # or derive the device_id from the authentication instead.
body = parse_json_object_from_request(request)
+ if device_id is not None:
+ # passing the device_id here is deprecated; however, we allow it
+ # for now for compatibility with older clients.
+ if (requester.device_id is not None and
+ device_id != requester.device_id):
+ logger.warning("Client uploading keys for a different device "
+ "(logged in as %s, uploading for %s)",
+ requester.device_id, device_id)
+ else:
+ device_id = requester.device_id
+
+ if device_id is None:
+ raise synapse.api.errors.SynapseError(
+ 400,
+ "To upload keys, you must pass device_id when authenticating"
+ )
+
time_now = self.clock.time_msec()
# TODO: Validate the JSON to make sure it has the right keys.
@@ -102,13 +125,12 @@ class KeyUploadServlet(RestServlet):
user_id, device_id, time_now, key_list
)
- result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
- defer.returnValue((200, {"one_time_key_counts": result}))
-
- @defer.inlineCallbacks
- def on_GET(self, request, device_id):
- requester = yield self.auth.get_user_by_req(request)
- user_id = requester.user.to_string()
+ # the device should have been registered already, but it may have been
+ # deleted due to a race with a DELETE request. Or we may be using an
+ # old access_token without an associated device_id. Either way, we
+ # need to double-check the device is registered to avoid ending up with
+ # keys without a corresponding device.
+ self.device_handler.check_device_registered(user_id, device_id)
result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
defer.returnValue((200, {"one_time_key_counts": result}))
@@ -162,17 +184,19 @@ class KeyQueryServlet(RestServlet):
)
def __init__(self, hs):
+ """
+ Args:
+ hs (synapse.server.HomeServer):
+ """
super(KeyQueryServlet, self).__init__()
- self.store = hs.get_datastore()
self.auth = hs.get_auth()
- self.federation = hs.get_replication_layer()
- self.is_mine = hs.is_mine
+ self.e2e_keys_handler = hs.get_e2e_keys_handler()
@defer.inlineCallbacks
def on_POST(self, request, user_id, device_id):
yield self.auth.get_user_by_req(request)
body = parse_json_object_from_request(request)
- result = yield self.handle_request(body)
+ result = yield self.e2e_keys_handler.query_devices(body)
defer.returnValue(result)
@defer.inlineCallbacks
@@ -181,45 +205,11 @@ class KeyQueryServlet(RestServlet):
auth_user_id = requester.user.to_string()
user_id = user_id if user_id else auth_user_id
device_ids = [device_id] if device_id else []
- result = yield self.handle_request(
+ result = yield self.e2e_keys_handler.query_devices(
{"device_keys": {user_id: device_ids}}
)
defer.returnValue(result)
- @defer.inlineCallbacks
- def handle_request(self, body):
- local_query = []
- remote_queries = {}
- for user_id, device_ids in body.get("device_keys", {}).items():
- user = UserID.from_string(user_id)
- if self.is_mine(user):
- if not device_ids:
- local_query.append((user_id, None))
- else:
- for device_id in device_ids:
- local_query.append((user_id, device_id))
- else:
- remote_queries.setdefault(user.domain, {})[user_id] = list(
- device_ids
- )
- results = yield self.store.get_e2e_device_keys(local_query)
-
- json_result = {}
- for user_id, device_keys in results.items():
- for device_id, json_bytes in device_keys.items():
- json_result.setdefault(user_id, {})[device_id] = json.loads(
- json_bytes
- )
-
- for destination, device_keys in remote_queries.items():
- remote_result = yield self.federation.query_client_keys(
- destination, {"device_keys": device_keys}
- )
- for user_id, keys in remote_result["device_keys"].items():
- if user_id in device_keys:
- json_result[user_id] = keys
- defer.returnValue((200, {"device_keys": json_result}))
-
class OneTimeKeyServlet(RestServlet):
"""
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index 2088c316d1..943f5676a3 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -41,17 +41,59 @@ else:
logger = logging.getLogger(__name__)
+class RegisterRequestTokenRestServlet(RestServlet):
+ PATTERNS = client_v2_patterns("/register/email/requestToken$")
+
+ def __init__(self, hs):
+ """
+ Args:
+ hs (synapse.server.HomeServer): server
+ """
+ super(RegisterRequestTokenRestServlet, self).__init__()
+ self.hs = hs
+ self.identity_handler = hs.get_handlers().identity_handler
+
+ @defer.inlineCallbacks
+ def on_POST(self, request):
+ body = parse_json_object_from_request(request)
+
+ required = ['id_server', 'client_secret', 'email', 'send_attempt']
+ absent = []
+ for k in required:
+ if k not in body:
+ absent.append(k)
+
+ if len(absent) > 0:
+ raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
+
+ existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
+ 'email', body['email']
+ )
+
+ if existingUid is not None:
+ raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
+
+ ret = yield self.identity_handler.requestEmailToken(**body)
+ defer.returnValue((200, ret))
+
+
class RegisterRestServlet(RestServlet):
- PATTERNS = client_v2_patterns("/register")
+ PATTERNS = client_v2_patterns("/register$")
def __init__(self, hs):
+ """
+ Args:
+ hs (synapse.server.HomeServer): server
+ """
super(RegisterRestServlet, self).__init__()
+
self.hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.auth_handler = hs.get_auth_handler()
self.registration_handler = hs.get_handlers().registration_handler
self.identity_handler = hs.get_handlers().identity_handler
+ self.device_handler = hs.get_device_handler()
@defer.inlineCallbacks
def on_POST(self, request):
@@ -70,10 +112,6 @@ class RegisterRestServlet(RestServlet):
"Do not understand membership kind: %s" % (kind,)
)
- if '/register/email/requestToken' in request.path:
- ret = yield self.onEmailTokenRequest(request)
- defer.returnValue(ret)
-
body = parse_json_object_from_request(request)
# we do basic sanity checks here because the auth layer will store these
@@ -104,11 +142,12 @@ class RegisterRestServlet(RestServlet):
# Set the desired user according to the AS API (which uses the
# 'user' key not 'username'). Since this is a new addition, we'll
# fallback to 'username' if they gave one.
- if isinstance(body.get("user"), basestring):
- desired_username = body["user"]
- result = yield self._do_appservice_registration(
- desired_username, request.args["access_token"][0]
- )
+ desired_username = body.get("user", desired_username)
+
+ if isinstance(desired_username, basestring):
+ result = yield self._do_appservice_registration(
+ desired_username, request.args["access_token"][0], body
+ )
defer.returnValue((200, result)) # we throw for non 200 responses
return
@@ -117,7 +156,7 @@ class RegisterRestServlet(RestServlet):
# FIXME: Should we really be determining if this is shared secret
# auth based purely on the 'mac' key?
result = yield self._do_shared_secret_registration(
- desired_username, desired_password, body["mac"]
+ desired_username, desired_password, body
)
defer.returnValue((200, result)) # we throw for non 200 responses
return
@@ -157,12 +196,12 @@ class RegisterRestServlet(RestServlet):
[LoginType.EMAIL_IDENTITY]
]
- authed, result, params, session_id = yield self.auth_handler.check_auth(
+ authed, auth_result, params, session_id = yield self.auth_handler.check_auth(
flows, body, self.hs.get_ip_from_request(request)
)
if not authed:
- defer.returnValue((401, result))
+ defer.returnValue((401, auth_result))
return
if registered_user_id is not None:
@@ -170,106 +209,58 @@ class RegisterRestServlet(RestServlet):
"Already registered user ID %r for this session",
registered_user_id
)
- access_token = yield self.auth_handler.issue_access_token(registered_user_id)
- refresh_token = yield self.auth_handler.issue_refresh_token(
- registered_user_id
+ # don't re-register the email address
+ add_email = False
+ else:
+ # NB: This may be from the auth handler and NOT from the POST
+ if 'password' not in params:
+ raise SynapseError(400, "Missing password.",
+ Codes.MISSING_PARAM)
+
+ desired_username = params.get("username", None)
+ new_password = params.get("password", None)
+ guest_access_token = params.get("guest_access_token", None)
+
+ (registered_user_id, _) = yield self.registration_handler.register(
+ localpart=desired_username,
+ password=new_password,
+ guest_access_token=guest_access_token,
+ generate_token=False,
)
- defer.returnValue((200, {
- "user_id": registered_user_id,
- "access_token": access_token,
- "home_server": self.hs.hostname,
- "refresh_token": refresh_token,
- }))
-
- # NB: This may be from the auth handler and NOT from the POST
- if 'password' not in params:
- raise SynapseError(400, "Missing password.", Codes.MISSING_PARAM)
-
- desired_username = params.get("username", None)
- new_password = params.get("password", None)
- guest_access_token = params.get("guest_access_token", None)
-
- (user_id, token) = yield self.registration_handler.register(
- localpart=desired_username,
- password=new_password,
- guest_access_token=guest_access_token,
- )
- # remember that we've now registered that user account, and with what
- # user ID (since the user may not have specified)
- self.auth_handler.set_session_data(
- session_id, "registered_user_id", user_id
+ # remember that we've now registered that user account, and with
+ # what user ID (since the user may not have specified)
+ self.auth_handler.set_session_data(
+ session_id, "registered_user_id", registered_user_id
+ )
+
+ add_email = True
+
+ return_dict = yield self._create_registration_details(
+ registered_user_id, params
)
- if result and LoginType.EMAIL_IDENTITY in result:
- threepid = result[LoginType.EMAIL_IDENTITY]
-
- for reqd in ['medium', 'address', 'validated_at']:
- if reqd not in threepid:
- logger.info("Can't add incomplete 3pid")
- else:
- yield self.auth_handler.add_threepid(
- user_id,
- threepid['medium'],
- threepid['address'],
- threepid['validated_at'],
- )
-
- # And we add an email pusher for them by default, but only
- # if email notifications are enabled (so people don't start
- # getting mail spam where they weren't before if email
- # notifs are set up on a home server)
- if (
- self.hs.config.email_enable_notifs and
- self.hs.config.email_notif_for_new_users
- ):
- # Pull the ID of the access token back out of the db
- # It would really make more sense for this to be passed
- # up when the access token is saved, but that's quite an
- # invasive change I'd rather do separately.
- user_tuple = yield self.store.get_user_by_access_token(
- token
- )
-
- yield self.hs.get_pusherpool().add_pusher(
- user_id=user_id,
- access_token=user_tuple["token_id"],
- kind="email",
- app_id="m.email",
- app_display_name="Email Notifications",
- device_display_name=threepid["address"],
- pushkey=threepid["address"],
- lang=None, # We don't know a user's language here
- data={},
- )
-
- if 'bind_email' in params and params['bind_email']:
- logger.info("bind_email specified: binding")
-
- emailThreepid = result[LoginType.EMAIL_IDENTITY]
- threepid_creds = emailThreepid['threepid_creds']
- logger.debug("Binding emails %s to %s" % (
- emailThreepid, user_id
- ))
- yield self.identity_handler.bind_threepid(threepid_creds, user_id)
- else:
- logger.info("bind_email not specified: not binding email")
-
- result = yield self._create_registration_details(user_id, token)
- defer.returnValue((200, result))
+ if add_email and auth_result and LoginType.EMAIL_IDENTITY in auth_result:
+ threepid = auth_result[LoginType.EMAIL_IDENTITY]
+ yield self._register_email_threepid(
+ registered_user_id, threepid, return_dict["access_token"],
+ params.get("bind_email")
+ )
+
+ defer.returnValue((200, return_dict))
def on_OPTIONS(self, _):
return 200, {}
@defer.inlineCallbacks
- def _do_appservice_registration(self, username, as_token):
- (user_id, token) = yield self.registration_handler.appservice_register(
+ def _do_appservice_registration(self, username, as_token, body):
+ user_id = yield self.registration_handler.appservice_register(
username, as_token
)
- defer.returnValue((yield self._create_registration_details(user_id, token)))
+ defer.returnValue((yield self._create_registration_details(user_id, body)))
@defer.inlineCallbacks
- def _do_shared_secret_registration(self, username, password, mac):
+ def _do_shared_secret_registration(self, username, password, body):
if not self.hs.config.registration_shared_secret:
raise SynapseError(400, "Shared secret registration is not enabled")
@@ -277,7 +268,7 @@ class RegisterRestServlet(RestServlet):
# str() because otherwise hmac complains that 'unicode' does not
# have the buffer interface
- got_mac = str(mac)
+ got_mac = str(body["mac"])
want_mac = hmac.new(
key=self.hs.config.registration_shared_secret,
@@ -290,43 +281,132 @@ class RegisterRestServlet(RestServlet):
403, "HMAC incorrect",
)
- (user_id, token) = yield self.registration_handler.register(
- localpart=username, password=password
+ (user_id, _) = yield self.registration_handler.register(
+ localpart=username, password=password, generate_token=False,
)
- defer.returnValue((yield self._create_registration_details(user_id, token)))
- @defer.inlineCallbacks
- def _create_registration_details(self, user_id, token):
- refresh_token = yield self.auth_handler.issue_refresh_token(user_id)
- defer.returnValue({
- "user_id": user_id,
- "access_token": token,
- "home_server": self.hs.hostname,
- "refresh_token": refresh_token,
- })
+ result = yield self._create_registration_details(user_id, body)
+ defer.returnValue(result)
@defer.inlineCallbacks
- def onEmailTokenRequest(self, request):
- body = parse_json_object_from_request(request)
+ def _register_email_threepid(self, user_id, threepid, token, bind_email):
+ """Add an email address as a 3pid identifier
+
+ Also adds an email pusher for the email address, if configured in the
+ HS config
+
+ Also optionally binds emails to the given user_id on the identity server
+
+ Args:
+ user_id (str): id of user
+ threepid (object): m.login.email.identity auth response
+ token (str): access_token for the user
+ bind_email (bool): true if the client requested the email to be
+ bound at the identity server
+ Returns:
+ defer.Deferred:
+ """
+ reqd = ('medium', 'address', 'validated_at')
+ if any(x not in threepid for x in reqd):
+ logger.info("Can't add incomplete 3pid")
+ defer.returnValue()
+
+ yield self.auth_handler.add_threepid(
+ user_id,
+ threepid['medium'],
+ threepid['address'],
+ threepid['validated_at'],
+ )
- required = ['id_server', 'client_secret', 'email', 'send_attempt']
- absent = []
- for k in required:
- if k not in body:
- absent.append(k)
+ # And we add an email pusher for them by default, but only
+ # if email notifications are enabled (so people don't start
+ # getting mail spam where they weren't before if email
+ # notifs are set up on a home server)
+ if (self.hs.config.email_enable_notifs and
+ self.hs.config.email_notif_for_new_users):
+ # Pull the ID of the access token back out of the db
+ # It would really make more sense for this to be passed
+ # up when the access token is saved, but that's quite an
+ # invasive change I'd rather do separately.
+ user_tuple = yield self.store.get_user_by_access_token(
+ token
+ )
+ token_id = user_tuple["token_id"]
+
+ yield self.hs.get_pusherpool().add_pusher(
+ user_id=user_id,
+ access_token=token_id,
+ kind="email",
+ app_id="m.email",
+ app_display_name="Email Notifications",
+ device_display_name=threepid["address"],
+ pushkey=threepid["address"],
+ lang=None, # We don't know a user's language here
+ data={},
+ )
- if len(absent) > 0:
- raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
+ if bind_email:
+ logger.info("bind_email specified: binding")
+ logger.debug("Binding emails %s to %s" % (
+ threepid, user_id
+ ))
+ yield self.identity_handler.bind_threepid(
+ threepid['threepid_creds'], user_id
+ )
+ else:
+ logger.info("bind_email not specified: not binding email")
- existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
- 'email', body['email']
+ @defer.inlineCallbacks
+ def _create_registration_details(self, user_id, params):
+ """Complete registration of newly-registered user
+
+ Allocates device_id if one was not given; also creates access_token
+ and refresh_token.
+
+ Args:
+ (str) user_id: full canonical @user:id
+ (object) params: registration parameters, from which we pull
+ device_id and initial_device_name
+ Returns:
+ defer.Deferred: (object) dictionary for response from /register
+ """
+ device_id = yield self._register_device(user_id, params)
+
+ access_token, refresh_token = (
+ yield self.auth_handler.get_login_tuple_for_user_id(
+ user_id, device_id=device_id,
+ initial_display_name=params.get("initial_device_display_name")
+ )
)
- if existingUid is not None:
- raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
+ defer.returnValue({
+ "user_id": user_id,
+ "access_token": access_token,
+ "home_server": self.hs.hostname,
+ "refresh_token": refresh_token,
+ "device_id": device_id,
+ })
- ret = yield self.identity_handler.requestEmailToken(**body)
- defer.returnValue((200, ret))
+ def _register_device(self, user_id, params):
+ """Register a device for a user.
+
+ This is called after the user's credentials have been validated, but
+ before the access token has been issued.
+
+ Args:
+ (str) user_id: full canonical @user:id
+ (object) params: registration parameters, from which we pull
+ device_id and initial_device_name
+ Returns:
+ defer.Deferred: (str) device_id
+ """
+ # register the user's device
+ device_id = params.get("device_id")
+ initial_display_name = params.get("initial_device_display_name")
+ device_id = self.device_handler.check_device_registered(
+ user_id, device_id, initial_display_name
+ )
+ return device_id
@defer.inlineCallbacks
def _do_guest_registration(self):
@@ -336,7 +416,11 @@ class RegisterRestServlet(RestServlet):
generate_token=False,
make_guest=True
)
- access_token = self.auth_handler.generate_access_token(user_id, ["guest = true"])
+ access_token = self.auth_handler.generate_access_token(
+ user_id, ["guest = true"]
+ )
+ # XXX the "guest" caveat is not copied by /tokenrefresh. That's ok
+ # so long as we don't return a refresh_token here.
defer.returnValue((200, {
"user_id": user_id,
"access_token": access_token,
@@ -345,4 +429,5 @@ class RegisterRestServlet(RestServlet):
def register_servlets(hs, http_server):
+ RegisterRequestTokenRestServlet(hs).register(http_server)
RegisterRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/tokenrefresh.py b/synapse/rest/client/v2_alpha/tokenrefresh.py
index 8270e8787f..0d312c91d4 100644
--- a/synapse/rest/client/v2_alpha/tokenrefresh.py
+++ b/synapse/rest/client/v2_alpha/tokenrefresh.py
@@ -39,9 +39,13 @@ class TokenRefreshRestServlet(RestServlet):
try:
old_refresh_token = body["refresh_token"]
auth_handler = self.hs.get_auth_handler()
- (user_id, new_refresh_token) = yield self.store.exchange_refresh_token(
- old_refresh_token, auth_handler.generate_refresh_token)
- new_access_token = yield auth_handler.issue_access_token(user_id)
+ refresh_result = yield self.store.exchange_refresh_token(
+ old_refresh_token, auth_handler.generate_refresh_token
+ )
+ (user_id, new_refresh_token, device_id) = refresh_result
+ new_access_token = yield auth_handler.issue_access_token(
+ user_id, device_id
+ )
defer.returnValue((200, {
"access_token": new_access_token,
"refresh_token": new_refresh_token,
diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py
index ca5468c402..e984ea47db 100644
--- a/synapse/rest/client/versions.py
+++ b/synapse/rest/client/versions.py
@@ -26,7 +26,11 @@ class VersionsRestServlet(RestServlet):
def on_GET(self, request):
return (200, {
- "versions": ["r0.0.1"]
+ "versions": [
+ "r0.0.1",
+ "r0.1.0",
+ "r0.2.0",
+ ]
})
diff --git a/synapse/rest/media/v0/content_repository.py b/synapse/rest/media/v0/content_repository.py
index d9fc045fc6..956bd5da75 100644
--- a/synapse/rest/media/v0/content_repository.py
+++ b/synapse/rest/media/v0/content_repository.py
@@ -15,14 +15,12 @@
from synapse.http.server import respond_with_json_bytes, finish_request
-from synapse.util.stringutils import random_string
from synapse.api.errors import (
- cs_exception, SynapseError, CodeMessageException, Codes, cs_error
+ Codes, cs_error
)
from twisted.protocols.basic import FileSender
from twisted.web import server, resource
-from twisted.internet import defer
import base64
import simplejson as json
@@ -50,64 +48,10 @@ class ContentRepoResource(resource.Resource):
"""
isLeaf = True
- def __init__(self, hs, directory, auth, external_addr):
+ def __init__(self, hs, directory):
resource.Resource.__init__(self)
self.hs = hs
self.directory = directory
- self.auth = auth
- self.external_addr = external_addr.rstrip('/')
- self.max_upload_size = hs.config.max_upload_size
-
- if not os.path.isdir(self.directory):
- os.mkdir(self.directory)
- logger.info("ContentRepoResource : Created %s directory.",
- self.directory)
-
- @defer.inlineCallbacks
- def map_request_to_name(self, request):
- # auth the user
- requester = yield self.auth.get_user_by_req(request)
-
- # namespace all file uploads on the user
- prefix = base64.urlsafe_b64encode(
- requester.user.to_string()
- ).replace('=', '')
-
- # use a random string for the main portion
- main_part = random_string(24)
-
- # suffix with a file extension if we can make one. This is nice to
- # provide a hint to clients on the file information. We will also reuse
- # this info to spit back the content type to the client.
- suffix = ""
- if request.requestHeaders.hasHeader("Content-Type"):
- content_type = request.requestHeaders.getRawHeaders(
- "Content-Type")[0]
- suffix = "." + base64.urlsafe_b64encode(content_type)
- if (content_type.split("/")[0].lower() in
- ["image", "video", "audio"]):
- file_ext = content_type.split("/")[-1]
- # be a little paranoid and only allow a-z
- file_ext = re.sub("[^a-z]", "", file_ext)
- suffix += "." + file_ext
-
- file_name = prefix + main_part + suffix
- file_path = os.path.join(self.directory, file_name)
- logger.info("User %s is uploading a file to path %s",
- request.user.user_id.to_string(),
- file_path)
-
- # keep trying to make a non-clashing file, with a sensible max attempts
- attempts = 0
- while os.path.exists(file_path):
- main_part = random_string(24)
- file_name = prefix + main_part + suffix
- file_path = os.path.join(self.directory, file_name)
- attempts += 1
- if attempts > 25: # really? Really?
- raise SynapseError(500, "Unable to create file.")
-
- defer.returnValue(file_path)
def render_GET(self, request):
# no auth here on purpose, to allow anyone to view, even across home
@@ -155,58 +99,6 @@ class ContentRepoResource(resource.Resource):
return server.NOT_DONE_YET
- def render_POST(self, request):
- self._async_render(request)
- return server.NOT_DONE_YET
-
def render_OPTIONS(self, request):
respond_with_json_bytes(request, 200, {}, send_cors=True)
return server.NOT_DONE_YET
-
- @defer.inlineCallbacks
- def _async_render(self, request):
- try:
- # TODO: The checks here are a bit late. The content will have
- # already been uploaded to a tmp file at this point
- content_length = request.getHeader("Content-Length")
- if content_length is None:
- raise SynapseError(
- msg="Request must specify a Content-Length", code=400
- )
- if int(content_length) > self.max_upload_size:
- raise SynapseError(
- msg="Upload request body is too large",
- code=413,
- )
-
- fname = yield self.map_request_to_name(request)
-
- # TODO I have a suspicious feeling this is just going to block
- with open(fname, "wb") as f:
- f.write(request.content.read())
-
- # FIXME (erikj): These should use constants.
- file_name = os.path.basename(fname)
- # FIXME: we can't assume what the repo's public mounted path is
- # ...plus self-signed SSL won't work to remote clients anyway
- # ...and we can't assume that it's SSL anyway, as we might want to
- # serve it via the non-SSL listener...
- url = "%s/_matrix/content/%s" % (
- self.external_addr, file_name
- )
-
- respond_with_json_bytes(request, 200,
- json.dumps({"content_token": url}),
- send_cors=True)
-
- except CodeMessageException as e:
- logger.exception(e)
- respond_with_json_bytes(request, e.code,
- json.dumps(cs_exception(e)))
- except Exception as e:
- logger.error("Failed to store file: %s" % e)
- respond_with_json_bytes(
- request,
- 500,
- json.dumps({"error": "Internal server error"}),
- send_cors=True)
diff --git a/synapse/rest/media/v1/filepath.py b/synapse/rest/media/v1/filepath.py
index 422ab86fb3..0137458f71 100644
--- a/synapse/rest/media/v1/filepath.py
+++ b/synapse/rest/media/v1/filepath.py
@@ -65,3 +65,9 @@ class MediaFilePaths(object):
file_id[0:2], file_id[2:4], file_id[4:],
file_name
)
+
+ def remote_media_thumbnail_dir(self, server_name, file_id):
+ return os.path.join(
+ self.base_path, "remote_thumbnail", server_name,
+ file_id[0:2], file_id[2:4], file_id[4:],
+ )
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index 2468c3ac42..692e078419 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -30,11 +30,13 @@ from synapse.api.errors import SynapseError
from twisted.internet import defer, threads
-from synapse.util.async import ObservableDeferred
+from synapse.util.async import Linearizer
from synapse.util.stringutils import is_ascii
from synapse.util.logcontext import preserve_context_over_fn
import os
+import errno
+import shutil
import cgi
import logging
@@ -43,8 +45,11 @@ import urlparse
logger = logging.getLogger(__name__)
+UPDATE_RECENTLY_ACCESSED_REMOTES_TS = 60 * 1000
+
+
class MediaRepository(object):
- def __init__(self, hs, filepaths):
+ def __init__(self, hs):
self.auth = hs.get_auth()
self.client = MatrixFederationHttpClient(hs)
self.clock = hs.get_clock()
@@ -52,11 +57,28 @@ class MediaRepository(object):
self.store = hs.get_datastore()
self.max_upload_size = hs.config.max_upload_size
self.max_image_pixels = hs.config.max_image_pixels
- self.filepaths = filepaths
- self.downloads = {}
+ self.filepaths = MediaFilePaths(hs.config.media_store_path)
self.dynamic_thumbnails = hs.config.dynamic_thumbnails
self.thumbnail_requirements = hs.config.thumbnail_requirements
+ self.remote_media_linearizer = Linearizer()
+
+ self.recently_accessed_remotes = set()
+
+ self.clock.looping_call(
+ self._update_recently_accessed_remotes,
+ UPDATE_RECENTLY_ACCESSED_REMOTES_TS
+ )
+
+ @defer.inlineCallbacks
+ def _update_recently_accessed_remotes(self):
+ media = self.recently_accessed_remotes
+ self.recently_accessed_remotes = set()
+
+ yield self.store.update_cached_last_access_time(
+ media, self.clock.time_msec()
+ )
+
@staticmethod
def _makedirs(filepath):
dirname = os.path.dirname(filepath)
@@ -93,22 +115,12 @@ class MediaRepository(object):
defer.returnValue("mxc://%s/%s" % (self.server_name, media_id))
+ @defer.inlineCallbacks
def get_remote_media(self, server_name, media_id):
key = (server_name, media_id)
- download = self.downloads.get(key)
- if download is None:
- download = self._get_remote_media_impl(server_name, media_id)
- download = ObservableDeferred(
- download,
- consumeErrors=True
- )
- self.downloads[key] = download
-
- @download.addBoth
- def callback(media_info):
- del self.downloads[key]
- return media_info
- return download.observe()
+ with (yield self.remote_media_linearizer.queue(key)):
+ media_info = yield self._get_remote_media_impl(server_name, media_id)
+ defer.returnValue(media_info)
@defer.inlineCallbacks
def _get_remote_media_impl(self, server_name, media_id):
@@ -119,6 +131,11 @@ class MediaRepository(object):
media_info = yield self._download_remote_file(
server_name, media_id
)
+ else:
+ self.recently_accessed_remotes.add((server_name, media_id))
+ yield self.store.update_cached_last_access_time(
+ [(server_name, media_id)], self.clock.time_msec()
+ )
defer.returnValue(media_info)
@defer.inlineCallbacks
@@ -416,6 +433,41 @@ class MediaRepository(object):
"height": m_height,
})
+ @defer.inlineCallbacks
+ def delete_old_remote_media(self, before_ts):
+ old_media = yield self.store.get_remote_media_before(before_ts)
+
+ deleted = 0
+
+ for media in old_media:
+ origin = media["media_origin"]
+ media_id = media["media_id"]
+ file_id = media["filesystem_id"]
+ key = (origin, media_id)
+
+ logger.info("Deleting: %r", key)
+
+ with (yield self.remote_media_linearizer.queue(key)):
+ full_path = self.filepaths.remote_media_filepath(origin, file_id)
+ try:
+ os.remove(full_path)
+ except OSError as e:
+ logger.warn("Failed to remove file: %r", full_path)
+ if e.errno == errno.ENOENT:
+ pass
+ else:
+ continue
+
+ thumbnail_dir = self.filepaths.remote_media_thumbnail_dir(
+ origin, file_id
+ )
+ shutil.rmtree(thumbnail_dir, ignore_errors=True)
+
+ yield self.store.delete_remote_media(origin, media_id)
+ deleted += 1
+
+ defer.returnValue({"deleted": deleted})
+
class MediaRepositoryResource(Resource):
"""File uploading and downloading.
@@ -464,9 +516,8 @@ class MediaRepositoryResource(Resource):
def __init__(self, hs):
Resource.__init__(self)
- filepaths = MediaFilePaths(hs.config.media_store_path)
- media_repo = MediaRepository(hs, filepaths)
+ media_repo = hs.get_media_repository()
self.putChild("upload", UploadResource(hs, media_repo))
self.putChild("download", DownloadResource(hs, media_repo))
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index 74c64f1371..bdd0e60c5b 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -29,6 +29,8 @@ from synapse.http.server import (
from synapse.util.async import ObservableDeferred
from synapse.util.stringutils import is_ascii
+from copy import deepcopy
+
import os
import re
import fnmatch
@@ -329,20 +331,24 @@ class PreviewUrlResource(Resource):
# ...or if they are within a <script/> or <style/> tag.
# This is a very very very coarse approximation to a plain text
# render of the page.
- text_nodes = tree.xpath("//text()[not(ancestor::header | ancestor::nav | "
- "ancestor::aside | ancestor::footer | "
- "ancestor::script | ancestor::style)]" +
- "[ancestor::body]")
- text = ''
- for text_node in text_nodes:
- if len(text) < 500:
- text += text_node + ' '
- else:
- break
- text = re.sub(r'[\t ]+', ' ', text)
- text = re.sub(r'[\t \r\n]*[\r\n]+', '\n', text)
- text = text.strip()[:500]
- og['og:description'] = text if text else None
+
+ # We don't just use XPATH here as that is slow on some machines.
+
+ # We clone `tree` as we modify it.
+ cloned_tree = deepcopy(tree.find("body"))
+
+ TAGS_TO_REMOVE = ("header", "nav", "aside", "footer", "script", "style",)
+ for el in cloned_tree.iter(TAGS_TO_REMOVE):
+ el.getparent().remove(el)
+
+ # Split all the text nodes into paragraphs (by splitting on new
+ # lines)
+ text_nodes = (
+ re.sub(r'\s+', '\n', el.text).strip()
+ for el in cloned_tree.iter()
+ if el.text and isinstance(el.tag, basestring) # Removes comments
+ )
+ og['og:description'] = summarize_paragraphs(text_nodes)
# TODO: delete the url downloads to stop diskfilling,
# as we only ever cared about its OG
@@ -450,3 +456,56 @@ class PreviewUrlResource(Resource):
content_type.startswith("application/xhtml")
):
return True
+
+
+def summarize_paragraphs(text_nodes, min_size=200, max_size=500):
+ # Try to get a summary of between 200 and 500 words, respecting
+ # first paragraph and then word boundaries.
+ # TODO: Respect sentences?
+
+ description = ''
+
+ # Keep adding paragraphs until we get to the MIN_SIZE.
+ for text_node in text_nodes:
+ if len(description) < min_size:
+ text_node = re.sub(r'[\t \r\n]+', ' ', text_node)
+ description += text_node + '\n\n'
+ else:
+ break
+
+ description = description.strip()
+ description = re.sub(r'[\t ]+', ' ', description)
+ description = re.sub(r'[\t \r\n]*[\r\n]+', '\n\n', description)
+
+ # If the concatenation of paragraphs to get above MIN_SIZE
+ # took us over MAX_SIZE, then we need to truncate mid paragraph
+ if len(description) > max_size:
+ new_desc = ""
+
+ # This splits the paragraph into words, but keeping the
+ # (preceeding) whitespace intact so we can easily concat
+ # words back together.
+ for match in re.finditer("\s*\S+", description):
+ word = match.group()
+
+ # Keep adding words while the total length is less than
+ # MAX_SIZE.
+ if len(word) + len(new_desc) < max_size:
+ new_desc += word
+ else:
+ # At this point the next word *will* take us over
+ # MAX_SIZE, but we also want to ensure that its not
+ # a huge word. If it is add it anyway and we'll
+ # truncate later.
+ if len(new_desc) < min_size:
+ new_desc += word
+ break
+
+ # Double check that we're not over the limit
+ if len(new_desc) > max_size:
+ new_desc = new_desc[:max_size]
+
+ # We always add an ellipsis because at the very least
+ # we chopped mid paragraph.
+ description = new_desc.strip() + "…"
+ return description if description else None
diff --git a/synapse/server.py b/synapse/server.py
index dd4b81c658..6bb4988309 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -19,37 +19,38 @@
# partial one for unit test mocking.
# Imports required for the default HomeServer() implementation
-from twisted.web.client import BrowserLikePolicyForHTTPS
+import logging
+
from twisted.enterprise import adbapi
+from twisted.web.client import BrowserLikePolicyForHTTPS
-from synapse.appservice.scheduler import ApplicationServiceScheduler
+from synapse.api.auth import Auth
+from synapse.api.filtering import Filtering
+from synapse.api.ratelimiting import Ratelimiter
from synapse.appservice.api import ApplicationServiceApi
+from synapse.appservice.scheduler import ApplicationServiceScheduler
+from synapse.crypto.keyring import Keyring
+from synapse.events.builder import EventBuilderFactory
from synapse.federation import initialize_http_replication
-from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory
-from synapse.notifier import Notifier
-from synapse.api.auth import Auth
from synapse.handlers import Handlers
+from synapse.handlers.appservice import ApplicationServicesHandler
+from synapse.handlers.auth import AuthHandler
+from synapse.handlers.device import DeviceHandler
+from synapse.handlers.e2e_keys import E2eKeysHandler
from synapse.handlers.presence import PresenceHandler
+from synapse.handlers.room import RoomListHandler
from synapse.handlers.sync import SyncHandler
from synapse.handlers.typing import TypingHandler
-from synapse.handlers.room import RoomListHandler
-from synapse.handlers.auth import AuthHandler
-from synapse.handlers.appservice import ApplicationServicesHandler
+from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory
+from synapse.http.matrixfederationclient import MatrixFederationHttpClient
+from synapse.notifier import Notifier
+from synapse.push.pusherpool import PusherPool
+from synapse.rest.media.v1.media_repository import MediaRepository
from synapse.state import StateHandler
from synapse.storage import DataStore
+from synapse.streams.events import EventSources
from synapse.util import Clock
from synapse.util.distributor import Distributor
-from synapse.streams.events import EventSources
-from synapse.api.ratelimiting import Ratelimiter
-from synapse.crypto.keyring import Keyring
-from synapse.push.pusherpool import PusherPool
-from synapse.events.builder import EventBuilderFactory
-from synapse.api.filtering import Filtering
-
-from synapse.http.matrixfederationclient import MatrixFederationHttpClient
-
-import logging
-
logger = logging.getLogger(__name__)
@@ -91,6 +92,8 @@ class HomeServer(object):
'typing_handler',
'room_list_handler',
'auth_handler',
+ 'device_handler',
+ 'e2e_keys_handler',
'application_service_api',
'application_service_scheduler',
'application_service_handler',
@@ -113,6 +116,7 @@ class HomeServer(object):
'filtering',
'http_client_context_factory',
'simple_http_client',
+ 'media_repository',
]
def __init__(self, hostname, **kwargs):
@@ -195,6 +199,12 @@ class HomeServer(object):
def build_auth_handler(self):
return AuthHandler(self)
+ def build_device_handler(self):
+ return DeviceHandler(self)
+
+ def build_e2e_keys_handler(self):
+ return E2eKeysHandler(self)
+
def build_application_service_api(self):
return ApplicationServiceApi(self)
@@ -233,6 +243,9 @@ class HomeServer(object):
**self.db_config.get("args", {})
)
+ def build_media_repository(self):
+ return MediaRepository(self)
+
def remove_pusher(self, app_id, push_key, user_id):
return self.get_pusherpool().remove_pusher(app_id, push_key, user_id)
diff --git a/synapse/server.pyi b/synapse/server.pyi
new file mode 100644
index 0000000000..c0aa868c4f
--- /dev/null
+++ b/synapse/server.pyi
@@ -0,0 +1,25 @@
+import synapse.handlers
+import synapse.handlers.auth
+import synapse.handlers.device
+import synapse.handlers.e2e_keys
+import synapse.storage
+import synapse.state
+
+class HomeServer(object):
+ def get_auth_handler(self) -> synapse.handlers.auth.AuthHandler:
+ pass
+
+ def get_datastore(self) -> synapse.storage.DataStore:
+ pass
+
+ def get_device_handler(self) -> synapse.handlers.device.DeviceHandler:
+ pass
+
+ def get_e2e_keys_handler(self) -> synapse.handlers.e2e_keys.E2eKeysHandler:
+ pass
+
+ def get_handlers(self) -> synapse.handlers.Handlers:
+ pass
+
+ def get_state_handler(self) -> synapse.state.StateHandler:
+ pass
diff --git a/synapse/state.py b/synapse/state.py
index d0f76dc4f5..ef1bc470be 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -379,7 +379,8 @@ class StateHandler(object):
try:
# FIXME: hs.get_auth() is bad style, but we need to do it to
# get around circular deps.
- self.hs.get_auth().check(event, auth_events)
+ # The signatures have already been checked at this point
+ self.hs.get_auth().check(event, auth_events, do_sig_check=False)
prev_event = event
except AuthError:
return prev_event
@@ -391,7 +392,8 @@ class StateHandler(object):
try:
# FIXME: hs.get_auth() is bad style, but we need to do it to
# get around circular deps.
- self.hs.get_auth().check(event, auth_events)
+ # The signatures have already been checked at this point
+ self.hs.get_auth().check(event, auth_events, do_sig_check=False)
return event
except AuthError:
pass
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index e93c3de66c..73fb334dd6 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -14,6 +14,8 @@
# limitations under the License.
from twisted.internet import defer
+
+from synapse.storage.devices import DeviceStore
from .appservice import (
ApplicationServiceStore, ApplicationServiceTransactionStore
)
@@ -80,6 +82,7 @@ class DataStore(RoomMemberStore, RoomStore,
EventPushActionsStore,
OpenIdStore,
ClientIpStore,
+ DeviceStore,
):
def __init__(self, db_conn, hs):
@@ -92,7 +95,8 @@ class DataStore(RoomMemberStore, RoomStore,
extra_tables=[("local_invites", "stream_id")]
)
self._backfill_id_gen = StreamIdGenerator(
- db_conn, "events", "stream_ordering", step=-1
+ db_conn, "events", "stream_ordering", step=-1,
+ extra_tables=[("ex_outlier_stream", "event_stream_ordering")]
)
self._receipts_id_gen = StreamIdGenerator(
db_conn, "receipts_linearized", "stream_id"
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 32c6677d47..0117fdc639 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -597,10 +597,13 @@ class SQLBaseStore(object):
more rows, returning the result as a list of dicts.
Args:
- table : string giving the table name
- keyvalues : dict of column names and values to select the rows with,
- or None to not apply a WHERE clause.
- retcols : list of strings giving the names of the columns to return
+ table (str): the table name
+ keyvalues (dict[str, Any] | None):
+ column names and values to select the rows with, or None to not
+ apply a WHERE clause.
+ retcols (iterable[str]): the names of the columns to return
+ Returns:
+ defer.Deferred: resolves to list[dict[str, Any]]
"""
return self.runInteraction(
desc,
@@ -615,9 +618,11 @@ class SQLBaseStore(object):
Args:
txn : Transaction object
- table : string giving the table name
- keyvalues : dict of column names and values to select the rows with
- retcols : list of strings giving the names of the columns to return
+ table (str): the table name
+ keyvalues (dict[str, T] | None):
+ column names and values to select the rows with, or None to not
+ apply a WHERE clause.
+ retcols (iterable[str]): the names of the columns to return
"""
if keyvalues:
sql = "SELECT %s FROM %s WHERE %s" % (
@@ -807,6 +812,11 @@ class SQLBaseStore(object):
if txn.rowcount > 1:
raise StoreError(500, "more than one row matched")
+ def _simple_delete(self, table, keyvalues, desc):
+ return self.runInteraction(
+ desc, self._simple_delete_txn, table, keyvalues
+ )
+
@staticmethod
def _simple_delete_txn(txn, table, keyvalues):
sql = "DELETE FROM %s WHERE %s" % (
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index 66a995157d..30d0e4c5dc 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -14,6 +14,7 @@
# limitations under the License.
from ._base import SQLBaseStore
+from . import engines
from twisted.internet import defer
@@ -87,10 +88,12 @@ class BackgroundUpdateStore(SQLBaseStore):
@defer.inlineCallbacks
def start_doing_background_updates(self):
- while True:
- if self._background_update_timer is not None:
- return
+ assert self._background_update_timer is None, \
+ "background updates already running"
+
+ logger.info("Starting background schema updates")
+ while True:
sleep = defer.Deferred()
self._background_update_timer = self._clock.call_later(
self.BACKGROUND_UPDATE_INTERVAL_MS / 1000., sleep.callback, None
@@ -101,22 +104,23 @@ class BackgroundUpdateStore(SQLBaseStore):
self._background_update_timer = None
try:
- result = yield self.do_background_update(
+ result = yield self.do_next_background_update(
self.BACKGROUND_UPDATE_DURATION_MS
)
except:
logger.exception("Error doing update")
-
- if result is None:
- logger.info(
- "No more background updates to do."
- " Unscheduling background update task."
- )
- return
+ else:
+ if result is None:
+ logger.info(
+ "No more background updates to do."
+ " Unscheduling background update task."
+ )
+ defer.returnValue(None)
@defer.inlineCallbacks
- def do_background_update(self, desired_duration_ms):
- """Does some amount of work on a background update
+ def do_next_background_update(self, desired_duration_ms):
+ """Does some amount of work on the next queued background update
+
Args:
desired_duration_ms(float): How long we want to spend
updating.
@@ -135,11 +139,21 @@ class BackgroundUpdateStore(SQLBaseStore):
self._background_update_queue.append(update['update_name'])
if not self._background_update_queue:
+ # no work left to do
defer.returnValue(None)
+ # pop from the front, and add back to the back
update_name = self._background_update_queue.pop(0)
self._background_update_queue.append(update_name)
+ res = yield self._do_background_update(update_name, desired_duration_ms)
+ defer.returnValue(res)
+
+ @defer.inlineCallbacks
+ def _do_background_update(self, update_name, desired_duration_ms):
+ logger.info("Starting update batch on background update '%s'",
+ update_name)
+
update_handler = self._background_update_handlers[update_name]
performance = self._background_update_performance.get(update_name)
@@ -202,6 +216,64 @@ class BackgroundUpdateStore(SQLBaseStore):
"""
self._background_update_handlers[update_name] = update_handler
+ def register_background_index_update(self, update_name, index_name,
+ table, columns):
+ """Helper for store classes to do a background index addition
+
+ To use:
+
+ 1. use a schema delta file to add a background update. Example:
+ INSERT INTO background_updates (update_name, progress_json) VALUES
+ ('my_new_index', '{}');
+
+ 2. In the Store constructor, call this method
+
+ Args:
+ update_name (str): update_name to register for
+ index_name (str): name of index to add
+ table (str): table to add index to
+ columns (list[str]): columns/expressions to include in index
+ """
+
+ # if this is postgres, we add the indexes concurrently. Otherwise
+ # we fall back to doing it inline
+ if isinstance(self.database_engine, engines.PostgresEngine):
+ conc = True
+ else:
+ conc = False
+
+ sql = "CREATE INDEX %(conc)s %(name)s ON %(table)s (%(columns)s)" \
+ % {
+ "conc": "CONCURRENTLY" if conc else "",
+ "name": index_name,
+ "table": table,
+ "columns": ", ".join(columns),
+ }
+
+ def create_index_concurrently(conn):
+ conn.rollback()
+ # postgres insists on autocommit for the index
+ conn.set_session(autocommit=True)
+ c = conn.cursor()
+ c.execute(sql)
+ conn.set_session(autocommit=False)
+
+ def create_index(conn):
+ c = conn.cursor()
+ c.execute(sql)
+
+ @defer.inlineCallbacks
+ def updater(progress, batch_size):
+ logger.info("Adding index %s to %s", index_name, table)
+ if conc:
+ yield self.runWithConnection(create_index_concurrently)
+ else:
+ yield self.runWithConnection(create_index)
+ yield self._end_background_update(update_name)
+ defer.returnValue(1)
+
+ self.register_background_update_handler(update_name, updater)
+
def start_background_update(self, update_name, progress):
"""Starts a background update running.
diff --git a/synapse/storage/client_ips.py b/synapse/storage/client_ips.py
index a90990e006..71e5ea112f 100644
--- a/synapse/storage/client_ips.py
+++ b/synapse/storage/client_ips.py
@@ -13,10 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import SQLBaseStore, Cache
+import logging
from twisted.internet import defer
+from ._base import Cache
+from . import background_updates
+
+logger = logging.getLogger(__name__)
# Number of msec of granularity to store the user IP 'last seen' time. Smaller
# times give more inserts into the database even for readonly API hits
@@ -24,8 +28,7 @@ from twisted.internet import defer
LAST_SEEN_GRANULARITY = 120 * 1000
-class ClientIpStore(SQLBaseStore):
-
+class ClientIpStore(background_updates.BackgroundUpdateStore):
def __init__(self, hs):
self.client_ip_last_seen = Cache(
name="client_ip_last_seen",
@@ -34,8 +37,15 @@ class ClientIpStore(SQLBaseStore):
super(ClientIpStore, self).__init__(hs)
+ self.register_background_index_update(
+ "user_ips_device_index",
+ index_name="user_ips_device_id",
+ table="user_ips",
+ columns=["user_id", "device_id", "last_seen"],
+ )
+
@defer.inlineCallbacks
- def insert_client_ip(self, user, access_token, ip, user_agent):
+ def insert_client_ip(self, user, access_token, ip, user_agent, device_id):
now = int(self._clock.time_msec())
key = (user.to_string(), access_token, ip)
@@ -59,6 +69,7 @@ class ClientIpStore(SQLBaseStore):
"access_token": access_token,
"ip": ip,
"user_agent": user_agent,
+ "device_id": device_id,
},
values={
"last_seen": now,
@@ -66,3 +77,69 @@ class ClientIpStore(SQLBaseStore):
desc="insert_client_ip",
lock=False,
)
+
+ @defer.inlineCallbacks
+ def get_last_client_ip_by_device(self, devices):
+ """For each device_id listed, give the user_ip it was last seen on
+
+ Args:
+ devices (iterable[(str, str)]): list of (user_id, device_id) pairs
+
+ Returns:
+ defer.Deferred: resolves to a dict, where the keys
+ are (user_id, device_id) tuples. The values are also dicts, with
+ keys giving the column names
+ """
+
+ res = yield self.runInteraction(
+ "get_last_client_ip_by_device",
+ self._get_last_client_ip_by_device_txn,
+ retcols=(
+ "user_id",
+ "access_token",
+ "ip",
+ "user_agent",
+ "device_id",
+ "last_seen",
+ ),
+ devices=devices
+ )
+
+ ret = {(d["user_id"], d["device_id"]): d for d in res}
+ defer.returnValue(ret)
+
+ @classmethod
+ def _get_last_client_ip_by_device_txn(cls, txn, devices, retcols):
+ where_clauses = []
+ bindings = []
+ for (user_id, device_id) in devices:
+ if device_id is None:
+ where_clauses.append("(user_id = ? AND device_id IS NULL)")
+ bindings.extend((user_id, ))
+ else:
+ where_clauses.append("(user_id = ? AND device_id = ?)")
+ bindings.extend((user_id, device_id))
+
+ inner_select = (
+ "SELECT MAX(last_seen) mls, user_id, device_id FROM user_ips "
+ "WHERE %(where)s "
+ "GROUP BY user_id, device_id"
+ ) % {
+ "where": " OR ".join(where_clauses),
+ }
+
+ sql = (
+ "SELECT %(retcols)s FROM user_ips "
+ "JOIN (%(inner_select)s) ips ON"
+ " user_ips.last_seen = ips.mls AND"
+ " user_ips.user_id = ips.user_id AND"
+ " (user_ips.device_id = ips.device_id OR"
+ " (user_ips.device_id IS NULL AND ips.device_id IS NULL)"
+ " )"
+ ) % {
+ "retcols": ",".join("user_ips." + c for c in retcols),
+ "inner_select": inner_select,
+ }
+
+ txn.execute(sql, bindings)
+ return cls.cursor_to_dict(txn)
diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py
new file mode 100644
index 0000000000..afd6530cab
--- /dev/null
+++ b/synapse/storage/devices.py
@@ -0,0 +1,137 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from twisted.internet import defer
+
+from synapse.api.errors import StoreError
+from ._base import SQLBaseStore
+
+logger = logging.getLogger(__name__)
+
+
+class DeviceStore(SQLBaseStore):
+ @defer.inlineCallbacks
+ def store_device(self, user_id, device_id,
+ initial_device_display_name,
+ ignore_if_known=True):
+ """Ensure the given device is known; add it to the store if not
+
+ Args:
+ user_id (str): id of user associated with the device
+ device_id (str): id of device
+ initial_device_display_name (str): initial displayname of the
+ device
+ ignore_if_known (bool): ignore integrity errors which mean the
+ device is already known
+ Returns:
+ defer.Deferred
+ Raises:
+ StoreError: if ignore_if_known is False and the device was already
+ known
+ """
+ try:
+ yield self._simple_insert(
+ "devices",
+ values={
+ "user_id": user_id,
+ "device_id": device_id,
+ "display_name": initial_device_display_name
+ },
+ desc="store_device",
+ or_ignore=ignore_if_known,
+ )
+ except Exception as e:
+ logger.error("store_device with device_id=%s failed: %s",
+ device_id, e)
+ raise StoreError(500, "Problem storing device.")
+
+ def get_device(self, user_id, device_id):
+ """Retrieve a device.
+
+ Args:
+ user_id (str): The ID of the user which owns the device
+ device_id (str): The ID of the device to retrieve
+ Returns:
+ defer.Deferred for a dict containing the device information
+ Raises:
+ StoreError: if the device is not found
+ """
+ return self._simple_select_one(
+ table="devices",
+ keyvalues={"user_id": user_id, "device_id": device_id},
+ retcols=("user_id", "device_id", "display_name"),
+ desc="get_device",
+ )
+
+ def delete_device(self, user_id, device_id):
+ """Delete a device.
+
+ Args:
+ user_id (str): The ID of the user which owns the device
+ device_id (str): The ID of the device to delete
+ Returns:
+ defer.Deferred
+ """
+ return self._simple_delete_one(
+ table="devices",
+ keyvalues={"user_id": user_id, "device_id": device_id},
+ desc="delete_device",
+ )
+
+ def update_device(self, user_id, device_id, new_display_name=None):
+ """Update a device.
+
+ Args:
+ user_id (str): The ID of the user which owns the device
+ device_id (str): The ID of the device to update
+ new_display_name (str|None): new displayname for device; None
+ to leave unchanged
+ Raises:
+ StoreError: if the device is not found
+ Returns:
+ defer.Deferred
+ """
+ updates = {}
+ if new_display_name is not None:
+ updates["display_name"] = new_display_name
+ if not updates:
+ return defer.succeed(None)
+ return self._simple_update_one(
+ table="devices",
+ keyvalues={"user_id": user_id, "device_id": device_id},
+ updatevalues=updates,
+ desc="update_device",
+ )
+
+ @defer.inlineCallbacks
+ def get_devices_by_user(self, user_id):
+ """Retrieve all of a user's registered devices.
+
+ Args:
+ user_id (str):
+ Returns:
+ defer.Deferred: resolves to a dict from device_id to a dict
+ containing "device_id", "user_id" and "display_name" for each
+ device.
+ """
+ devices = yield self._simple_select_list(
+ table="devices",
+ keyvalues={"user_id": user_id},
+ retcols=("user_id", "device_id", "display_name"),
+ desc="get_devices_by_user"
+ )
+
+ defer.returnValue({d["device_id"]: d for d in devices})
diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py
index 2e89066515..385d607056 100644
--- a/synapse/storage/end_to_end_keys.py
+++ b/synapse/storage/end_to_end_keys.py
@@ -12,6 +12,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import collections
+
+import twisted.internet.defer
from ._base import SQLBaseStore
@@ -36,24 +39,49 @@ class EndToEndKeyStore(SQLBaseStore):
query_list(list): List of pairs of user_ids and device_ids.
Returns:
Dict mapping from user-id to dict mapping from device_id to
- key json byte strings.
+ dict containing "key_json", "device_display_name".
"""
- def _get_e2e_device_keys(txn):
- result = {}
- for user_id, device_id in query_list:
- user_result = result.setdefault(user_id, {})
- keyvalues = {"user_id": user_id}
- if device_id:
- keyvalues["device_id"] = device_id
- rows = self._simple_select_list_txn(
- txn, table="e2e_device_keys_json",
- keyvalues=keyvalues,
- retcols=["device_id", "key_json"]
- )
- for row in rows:
- user_result[row["device_id"]] = row["key_json"]
- return result
- return self.runInteraction("get_e2e_device_keys", _get_e2e_device_keys)
+ if not query_list:
+ return {}
+
+ return self.runInteraction(
+ "get_e2e_device_keys", self._get_e2e_device_keys_txn, query_list
+ )
+
+ def _get_e2e_device_keys_txn(self, txn, query_list):
+ query_clauses = []
+ query_params = []
+
+ for (user_id, device_id) in query_list:
+ query_clause = "k.user_id = ?"
+ query_params.append(user_id)
+
+ if device_id:
+ query_clause += " AND k.device_id = ?"
+ query_params.append(device_id)
+
+ query_clauses.append(query_clause)
+
+ sql = (
+ "SELECT k.user_id, k.device_id, "
+ " d.display_name AS device_display_name, "
+ " k.key_json"
+ " FROM e2e_device_keys_json k"
+ " LEFT JOIN devices d ON d.user_id = k.user_id"
+ " AND d.device_id = k.device_id"
+ " WHERE %s"
+ ) % (
+ " OR ".join("(" + q + ")" for q in query_clauses)
+ )
+
+ txn.execute(sql, query_params)
+ rows = self.cursor_to_dict(txn)
+
+ result = collections.defaultdict(dict)
+ for row in rows:
+ result[row["user_id"]][row["device_id"]] = row
+
+ return result
def add_e2e_one_time_keys(self, user_id, device_id, time_now, key_list):
def _add_e2e_one_time_keys(txn):
@@ -123,3 +151,16 @@ class EndToEndKeyStore(SQLBaseStore):
return self.runInteraction(
"claim_e2e_one_time_keys", _claim_e2e_one_time_keys
)
+
+ @twisted.internet.defer.inlineCallbacks
+ def delete_e2e_keys_by_device(self, user_id, device_id):
+ yield self._simple_delete(
+ table="e2e_device_keys_json",
+ keyvalues={"user_id": user_id, "device_id": device_id},
+ desc="delete_e2e_device_keys_by_device"
+ )
+ yield self._simple_delete(
+ table="e2e_one_time_keys_json",
+ keyvalues={"user_id": user_id, "device_id": device_id},
+ desc="delete_e2e_one_time_keys_by_device"
+ )
diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/event_push_actions.py
index 940e11d7a2..df4000d0da 100644
--- a/synapse/storage/event_push_actions.py
+++ b/synapse/storage/event_push_actions.py
@@ -16,6 +16,8 @@
from ._base import SQLBaseStore
from twisted.internet import defer
from synapse.util.caches.descriptors import cachedInlineCallbacks
+from synapse.types import RoomStreamToken
+from .stream import lower_bound
import logging
import ujson as json
@@ -73,6 +75,9 @@ class EventPushActionsStore(SQLBaseStore):
stream_ordering = results[0][0]
topological_ordering = results[0][1]
+ token = RoomStreamToken(
+ topological_ordering, stream_ordering
+ )
sql = (
"SELECT sum(notif), sum(highlight)"
@@ -80,15 +85,10 @@ class EventPushActionsStore(SQLBaseStore):
" WHERE"
" user_id = ?"
" AND room_id = ?"
- " AND ("
- " topological_ordering > ?"
- " OR (topological_ordering = ? AND stream_ordering > ?)"
- ")"
- )
- txn.execute(sql, (
- user_id, room_id,
- topological_ordering, topological_ordering, stream_ordering
- ))
+ " AND %s"
+ ) % (lower_bound(token, self.database_engine, inclusive=False),)
+
+ txn.execute(sql, (user_id, room_id))
row = txn.fetchone()
if row:
return {
@@ -117,24 +117,42 @@ class EventPushActionsStore(SQLBaseStore):
defer.returnValue(ret)
@defer.inlineCallbacks
- def get_unread_push_actions_for_user_in_range(self, user_id,
- min_stream_ordering,
- max_stream_ordering=None,
- limit=20):
+ def get_unread_push_actions_for_user_in_range_for_http(
+ self, user_id, min_stream_ordering, max_stream_ordering, limit=20
+ ):
+ """Get a list of the most recent unread push actions for a given user,
+ within the given stream ordering range. Called by the httppusher.
+
+ Args:
+ user_id (str): The user to fetch push actions for.
+ min_stream_ordering(int): The exclusive lower bound on the
+ stream ordering of event push actions to fetch.
+ max_stream_ordering(int): The inclusive upper bound on the
+ stream ordering of event push actions to fetch.
+ limit (int): The maximum number of rows to return.
+ Returns:
+ A promise which resolves to a list of dicts with the keys "event_id",
+ "room_id", "stream_ordering", "actions".
+ The list will be ordered by ascending stream_ordering.
+ The list will have between 0~limit entries.
+ """
+ # find rooms that have a read receipt in them and return the next
+ # push actions
def get_after_receipt(txn):
+ # find rooms that have a read receipt in them and return the next
+ # push actions
sql = (
- "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions, "
- "e.received_ts "
- "FROM ("
- " SELECT room_id, user_id, "
- " max(topological_ordering) as topological_ordering, "
- " max(stream_ordering) as stream_ordering "
- " FROM events"
- " NATURAL JOIN receipts_linearized WHERE receipt_type = 'm.read'"
- " GROUP BY room_id, user_id"
+ "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions"
+ " FROM ("
+ " SELECT room_id,"
+ " MAX(topological_ordering) as topological_ordering,"
+ " MAX(stream_ordering) as stream_ordering"
+ " FROM events"
+ " INNER JOIN receipts_linearized USING (room_id, event_id)"
+ " WHERE receipt_type = 'm.read' AND user_id = ?"
+ " GROUP BY room_id"
") AS rl,"
" event_push_actions AS ep"
- " INNER JOIN events AS e USING (room_id, event_id)"
" WHERE"
" ep.room_id = rl.room_id"
" AND ("
@@ -144,46 +162,163 @@ class EventPushActionsStore(SQLBaseStore):
" AND ep.stream_ordering > rl.stream_ordering"
" )"
" )"
- " AND ep.stream_ordering > ?"
" AND ep.user_id = ?"
- " AND ep.user_id = rl.user_id"
+ " AND ep.stream_ordering > ?"
+ " AND ep.stream_ordering <= ?"
+ " ORDER BY ep.stream_ordering ASC LIMIT ?"
)
- args = [min_stream_ordering, user_id]
- if max_stream_ordering is not None:
- sql += " AND ep.stream_ordering <= ?"
- args.append(max_stream_ordering)
- sql += " ORDER BY ep.stream_ordering ASC LIMIT ?"
- args.append(limit)
+ args = [
+ user_id, user_id,
+ min_stream_ordering, max_stream_ordering, limit,
+ ]
txn.execute(sql, args)
return txn.fetchall()
after_read_receipt = yield self.runInteraction(
- "get_unread_push_actions_for_user_in_range", get_after_receipt
+ "get_unread_push_actions_for_user_in_range_http_arr", get_after_receipt
)
+ # There are rooms with push actions in them but you don't have a read receipt in
+ # them e.g. rooms you've been invited to, so get push actions for rooms which do
+ # not have read receipts in them too.
def get_no_receipt(txn):
sql = (
"SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,"
" e.received_ts"
" FROM event_push_actions AS ep"
- " JOIN events e ON ep.room_id = e.room_id AND ep.event_id = e.event_id"
- " WHERE ep.room_id not in ("
- " SELECT room_id FROM events NATURAL JOIN receipts_linearized"
+ " INNER JOIN events AS e USING (room_id, event_id)"
+ " WHERE"
+ " ep.room_id NOT IN ("
+ " SELECT room_id FROM receipts_linearized"
+ " WHERE receipt_type = 'm.read' AND user_id = ?"
+ " GROUP BY room_id"
+ " )"
+ " AND ep.user_id = ?"
+ " AND ep.stream_ordering > ?"
+ " AND ep.stream_ordering <= ?"
+ " ORDER BY ep.stream_ordering ASC LIMIT ?"
+ )
+ args = [
+ user_id, user_id,
+ min_stream_ordering, max_stream_ordering, limit,
+ ]
+ txn.execute(sql, args)
+ return txn.fetchall()
+ no_read_receipt = yield self.runInteraction(
+ "get_unread_push_actions_for_user_in_range_http_nrr", get_no_receipt
+ )
+
+ notifs = [
+ {
+ "event_id": row[0],
+ "room_id": row[1],
+ "stream_ordering": row[2],
+ "actions": json.loads(row[3]),
+ } for row in after_read_receipt + no_read_receipt
+ ]
+
+ # Now sort it so it's ordered correctly, since currently it will
+ # contain results from the first query, correctly ordered, followed
+ # by results from the second query, but we want them all ordered
+ # by stream_ordering, oldest first.
+ notifs.sort(key=lambda r: r['stream_ordering'])
+
+ # Take only up to the limit. We have to stop at the limit because
+ # one of the subqueries may have hit the limit.
+ defer.returnValue(notifs[:limit])
+
+ @defer.inlineCallbacks
+ def get_unread_push_actions_for_user_in_range_for_email(
+ self, user_id, min_stream_ordering, max_stream_ordering, limit=20
+ ):
+ """Get a list of the most recent unread push actions for a given user,
+ within the given stream ordering range. Called by the emailpusher
+
+ Args:
+ user_id (str): The user to fetch push actions for.
+ min_stream_ordering(int): The exclusive lower bound on the
+ stream ordering of event push actions to fetch.
+ max_stream_ordering(int): The inclusive upper bound on the
+ stream ordering of event push actions to fetch.
+ limit (int): The maximum number of rows to return.
+ Returns:
+ A promise which resolves to a list of dicts with the keys "event_id",
+ "room_id", "stream_ordering", "actions", "received_ts".
+ The list will be ordered by descending received_ts.
+ The list will have between 0~limit entries.
+ """
+ # find rooms that have a read receipt in them and return the most recent
+ # push actions
+ def get_after_receipt(txn):
+ sql = (
+ "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,"
+ " e.received_ts"
+ " FROM ("
+ " SELECT room_id,"
+ " MAX(topological_ordering) as topological_ordering,"
+ " MAX(stream_ordering) as stream_ordering"
+ " FROM events"
+ " INNER JOIN receipts_linearized USING (room_id, event_id)"
" WHERE receipt_type = 'm.read' AND user_id = ?"
" GROUP BY room_id"
- ") AND ep.user_id = ? AND ep.stream_ordering > ?"
+ ") AS rl,"
+ " event_push_actions AS ep"
+ " INNER JOIN events AS e USING (room_id, event_id)"
+ " WHERE"
+ " ep.room_id = rl.room_id"
+ " AND ("
+ " ep.topological_ordering > rl.topological_ordering"
+ " OR ("
+ " ep.topological_ordering = rl.topological_ordering"
+ " AND ep.stream_ordering > rl.stream_ordering"
+ " )"
+ " )"
+ " AND ep.user_id = ?"
+ " AND ep.stream_ordering > ?"
+ " AND ep.stream_ordering <= ?"
+ " ORDER BY ep.stream_ordering DESC LIMIT ?"
)
- args = [user_id, user_id, min_stream_ordering]
- if max_stream_ordering is not None:
- sql += " AND ep.stream_ordering <= ?"
- args.append(max_stream_ordering)
- sql += " ORDER BY ep.stream_ordering ASC"
+ args = [
+ user_id, user_id,
+ min_stream_ordering, max_stream_ordering, limit,
+ ]
+ txn.execute(sql, args)
+ return txn.fetchall()
+ after_read_receipt = yield self.runInteraction(
+ "get_unread_push_actions_for_user_in_range_email_arr", get_after_receipt
+ )
+
+ # There are rooms with push actions in them but you don't have a read receipt in
+ # them e.g. rooms you've been invited to, so get push actions for rooms which do
+ # not have read receipts in them too.
+ def get_no_receipt(txn):
+ sql = (
+ "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,"
+ " e.received_ts"
+ " FROM event_push_actions AS ep"
+ " INNER JOIN events AS e USING (room_id, event_id)"
+ " WHERE"
+ " ep.room_id NOT IN ("
+ " SELECT room_id FROM receipts_linearized"
+ " WHERE receipt_type = 'm.read' AND user_id = ?"
+ " GROUP BY room_id"
+ " )"
+ " AND ep.user_id = ?"
+ " AND ep.stream_ordering > ?"
+ " AND ep.stream_ordering <= ?"
+ " ORDER BY ep.stream_ordering DESC LIMIT ?"
+ )
+ args = [
+ user_id, user_id,
+ min_stream_ordering, max_stream_ordering, limit,
+ ]
txn.execute(sql, args)
return txn.fetchall()
no_read_receipt = yield self.runInteraction(
- "get_unread_push_actions_for_user_in_range", get_no_receipt
+ "get_unread_push_actions_for_user_in_range_email_nrr", get_no_receipt
)
- defer.returnValue([
+ # Make a list of dicts from the two sets of results.
+ notifs = [
{
"event_id": row[0],
"room_id": row[1],
@@ -191,7 +326,16 @@ class EventPushActionsStore(SQLBaseStore):
"actions": json.loads(row[3]),
"received_ts": row[4],
} for row in after_read_receipt + no_read_receipt
- ])
+ ]
+
+ # Now sort it so it's ordered correctly, since currently it will
+ # contain results from the first query, correctly ordered, followed
+ # by results from the second query, but we want them all ordered
+ # by received_ts (most recent first)
+ notifs.sort(key=lambda r: -(r['received_ts'] or 0))
+
+ # Now return the first `limit`
+ defer.returnValue(notifs[:limit])
@defer.inlineCallbacks
def get_time_of_last_push_action_before(self, stream_ordering):
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 6d978ffcd5..d2feee8dbb 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -23,9 +23,11 @@ from synapse.util.async import ObservableDeferred
from synapse.util.logcontext import preserve_fn, PreserveLoggingContext
from synapse.util.logutils import log_function
from synapse.api.constants import EventTypes
+from synapse.api.errors import SynapseError
from canonicaljson import encode_canonical_json
-from collections import deque, namedtuple
+from collections import deque, namedtuple, OrderedDict
+from functools import wraps
import synapse
import synapse.metrics
@@ -149,8 +151,29 @@ class _EventPeristenceQueue(object):
_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event"))
+def _retry_on_integrity_error(func):
+ """Wraps a database function so that it gets retried on IntegrityError,
+ with `delete_existing=True` passed in.
+
+ Args:
+ func: function that returns a Deferred and accepts a `delete_existing` arg
+ """
+ @wraps(func)
+ @defer.inlineCallbacks
+ def f(self, *args, **kwargs):
+ try:
+ res = yield func(self, *args, **kwargs)
+ except self.database_engine.module.IntegrityError:
+ logger.exception("IntegrityError, retrying.")
+ res = yield func(self, *args, delete_existing=True, **kwargs)
+ defer.returnValue(res)
+
+ return f
+
+
class EventsStore(SQLBaseStore):
EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts"
+ EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url"
def __init__(self, hs):
super(EventsStore, self).__init__(hs)
@@ -158,6 +181,10 @@ class EventsStore(SQLBaseStore):
self.register_background_update_handler(
self.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts
)
+ self.register_background_update_handler(
+ self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME,
+ self._background_reindex_fields_sender,
+ )
self._event_persist_queue = _EventPeristenceQueue()
@@ -223,8 +250,10 @@ class EventsStore(SQLBaseStore):
self._event_persist_queue.handle_queue(room_id, persisting_queue)
+ @_retry_on_integrity_error
@defer.inlineCallbacks
- def _persist_events(self, events_and_contexts, backfilled=False):
+ def _persist_events(self, events_and_contexts, backfilled=False,
+ delete_existing=False):
if not events_and_contexts:
return
@@ -267,12 +296,15 @@ class EventsStore(SQLBaseStore):
self._persist_events_txn,
events_and_contexts=chunk,
backfilled=backfilled,
+ delete_existing=delete_existing,
)
persist_event_counter.inc_by(len(chunk))
+ @_retry_on_integrity_error
@defer.inlineCallbacks
@log_function
- def _persist_event(self, event, context, current_state=None, backfilled=False):
+ def _persist_event(self, event, context, current_state=None, backfilled=False,
+ delete_existing=False):
try:
with self._stream_id_gen.get_next() as stream_ordering:
with self._state_groups_id_gen.get_next() as state_group_id:
@@ -285,6 +317,7 @@ class EventsStore(SQLBaseStore):
context=context,
current_state=current_state,
backfilled=backfilled,
+ delete_existing=delete_existing,
)
persist_event_counter.inc()
except _RollbackButIsFineException:
@@ -317,7 +350,7 @@ class EventsStore(SQLBaseStore):
)
if not events and not allow_none:
- raise RuntimeError("Could not find event %s" % (event_id,))
+ raise SynapseError(404, "Could not find event %s" % (event_id,))
defer.returnValue(events[0] if events else None)
@@ -347,7 +380,8 @@ class EventsStore(SQLBaseStore):
defer.returnValue({e.event_id: e for e in events})
@log_function
- def _persist_event_txn(self, txn, event, context, current_state, backfilled=False):
+ def _persist_event_txn(self, txn, event, context, current_state, backfilled=False,
+ delete_existing=False):
# We purposefully do this first since if we include a `current_state`
# key, we *want* to update the `current_state_events` table
if current_state:
@@ -355,7 +389,6 @@ class EventsStore(SQLBaseStore):
txn.call_after(self.get_rooms_for_user.invalidate_all)
txn.call_after(self.get_users_in_room.invalidate, (event.room_id,))
txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,))
- txn.call_after(self.get_room_name_and_aliases.invalidate, (event.room_id,))
# Add an entry to the current_state_resets table to record the point
# where we clobbered the current state
@@ -388,10 +421,38 @@ class EventsStore(SQLBaseStore):
txn,
[(event, context)],
backfilled=backfilled,
+ delete_existing=delete_existing,
)
@log_function
- def _persist_events_txn(self, txn, events_and_contexts, backfilled):
+ def _persist_events_txn(self, txn, events_and_contexts, backfilled,
+ delete_existing=False):
+ """Insert some number of room events into the necessary database tables.
+
+ Rejected events are only inserted into the events table, the events_json table,
+ and the rejections table. Things reading from those table will need to check
+ whether the event was rejected.
+
+ If delete_existing is True then existing events will be purged from the
+ database before insertion. This is useful when retrying due to IntegrityError.
+ """
+ # Ensure that we don't have the same event twice.
+ # Pick the earliest non-outlier if there is one, else the earliest one.
+ new_events_and_contexts = OrderedDict()
+ for event, context in events_and_contexts:
+ prev_event_context = new_events_and_contexts.get(event.event_id)
+ if prev_event_context:
+ if not event.internal_metadata.is_outlier():
+ if prev_event_context[0].internal_metadata.is_outlier():
+ # To ensure correct ordering we pop, as OrderedDict is
+ # ordered by first insertion.
+ new_events_and_contexts.pop(event.event_id, None)
+ new_events_and_contexts[event.event_id] = (event, context)
+ else:
+ new_events_and_contexts[event.event_id] = (event, context)
+
+ events_and_contexts = new_events_and_contexts.values()
+
depth_updates = {}
for event, context in events_and_contexts:
# Remove the any existing cache entries for the event_ids
@@ -402,21 +463,11 @@ class EventsStore(SQLBaseStore):
event.room_id, event.internal_metadata.stream_ordering,
)
- if not event.internal_metadata.is_outlier():
+ if not event.internal_metadata.is_outlier() and not context.rejected:
depth_updates[event.room_id] = max(
event.depth, depth_updates.get(event.room_id, event.depth)
)
- if context.push_actions:
- self._set_push_actions_for_event_and_users_txn(
- txn, event, context.push_actions
- )
-
- if event.type == EventTypes.Redaction and event.redacts is not None:
- self._remove_push_actions_for_event_id_txn(
- txn, event.room_id, event.redacts
- )
-
for room_id, depth in depth_updates.items():
self._update_min_depth_for_room_txn(txn, room_id, depth)
@@ -426,30 +477,21 @@ class EventsStore(SQLBaseStore):
),
[event.event_id for event, _ in events_and_contexts]
)
+
have_persisted = {
event_id: outlier
for event_id, outlier in txn.fetchall()
}
- event_map = {}
to_remove = set()
for event, context in events_and_contexts:
- # Handle the case of the list including the same event multiple
- # times. The tricky thing here is when they differ by whether
- # they are an outlier.
- if event.event_id in event_map:
- other = event_map[event.event_id]
-
- if not other.internal_metadata.is_outlier():
- to_remove.add(event)
- continue
- elif not event.internal_metadata.is_outlier():
+ if context.rejected:
+ # If the event is rejected then we don't care if the event
+ # was an outlier or not.
+ if event.event_id in have_persisted:
+ # If we have already seen the event then ignore it.
to_remove.add(event)
- continue
- else:
- to_remove.add(other)
-
- event_map[event.event_id] = event
+ continue
if event.event_id not in have_persisted:
continue
@@ -458,6 +500,12 @@ class EventsStore(SQLBaseStore):
outlier_persisted = have_persisted[event.event_id]
if not event.internal_metadata.is_outlier() and outlier_persisted:
+ # We received a copy of an event that we had already stored as
+ # an outlier in the database. We now have some state at that
+ # so we need to update the state_groups table with that state.
+
+ # insert into the state_group, state_groups_state and
+ # event_to_state_groups tables.
self._store_mult_state_groups_txn(txn, ((event, context),))
metadata_json = encode_json(
@@ -473,6 +521,8 @@ class EventsStore(SQLBaseStore):
(metadata_json, event.event_id,)
)
+ # Add an entry to the ex_outlier_stream table to replicate the
+ # change in outlier status to our workers.
stream_order = event.internal_metadata.stream_ordering
state_group_id = context.state_group or context.new_state_group_id
self._simple_insert_txn(
@@ -494,6 +544,8 @@ class EventsStore(SQLBaseStore):
(False, event.event_id,)
)
+ # Update the event_backward_extremities table now that this
+ # event isn't an outlier any more.
self._update_extremeties(txn, [event])
events_and_contexts = [
@@ -501,38 +553,12 @@ class EventsStore(SQLBaseStore):
]
if not events_and_contexts:
+ # Make sure we don't pass an empty list to functions that expect to
+ # be storing at least one element.
return
- self._store_mult_state_groups_txn(txn, events_and_contexts)
-
- self._handle_mult_prev_events(
- txn,
- events=[event for event, _ in events_and_contexts],
- )
-
- for event, _ in events_and_contexts:
- if event.type == EventTypes.Name:
- self._store_room_name_txn(txn, event)
- elif event.type == EventTypes.Topic:
- self._store_room_topic_txn(txn, event)
- elif event.type == EventTypes.Message:
- self._store_room_message_txn(txn, event)
- elif event.type == EventTypes.Redaction:
- self._store_redaction(txn, event)
- elif event.type == EventTypes.RoomHistoryVisibility:
- self._store_history_visibility_txn(txn, event)
- elif event.type == EventTypes.GuestAccess:
- self._store_guest_access_txn(txn, event)
-
- self._store_room_members_txn(
- txn,
- [
- event
- for event, _ in events_and_contexts
- if event.type == EventTypes.Member
- ],
- backfilled=backfilled,
- )
+ # From this point onwards the events are only events that we haven't
+ # seen before.
def event_dict(event):
return {
@@ -544,6 +570,43 @@ class EventsStore(SQLBaseStore):
]
}
+ if delete_existing:
+ # For paranoia reasons, we go and delete all the existing entries
+ # for these events so we can reinsert them.
+ # This gets around any problems with some tables already having
+ # entries.
+
+ logger.info("Deleting existing")
+
+ for table in (
+ "events",
+ "event_auth",
+ "event_json",
+ "event_content_hashes",
+ "event_destinations",
+ "event_edge_hashes",
+ "event_edges",
+ "event_forward_extremities",
+ "event_push_actions",
+ "event_reference_hashes",
+ "event_search",
+ "event_signatures",
+ "event_to_state_groups",
+ "guest_access",
+ "history_visibility",
+ "local_invites",
+ "room_names",
+ "state_events",
+ "rejections",
+ "redactions",
+ "room_memberships",
+ "state_events"
+ ):
+ txn.executemany(
+ "DELETE FROM %s WHERE event_id = ?" % (table,),
+ [(ev.event_id,) for ev, _ in events_and_contexts]
+ )
+
self._simple_insert_many_txn(
txn,
table="event_json",
@@ -576,15 +639,51 @@ class EventsStore(SQLBaseStore):
"content": encode_json(event.content).decode("UTF-8"),
"origin_server_ts": int(event.origin_server_ts),
"received_ts": self._clock.time_msec(),
+ "sender": event.sender,
+ "contains_url": (
+ "url" in event.content
+ and isinstance(event.content["url"], basestring)
+ ),
}
for event, _ in events_and_contexts
],
)
- if context.rejected:
- self._store_rejections_txn(
- txn, event.event_id, context.rejected
- )
+ # Remove the rejected events from the list now that we've added them
+ # to the events table and the events_json table.
+ to_remove = set()
+ for event, context in events_and_contexts:
+ if context.rejected:
+ # Insert the event_id into the rejections table
+ self._store_rejections_txn(
+ txn, event.event_id, context.rejected
+ )
+ to_remove.add(event)
+
+ events_and_contexts = [
+ ec for ec in events_and_contexts if ec[0] not in to_remove
+ ]
+
+ if not events_and_contexts:
+ # Make sure we don't pass an empty list to functions that expect to
+ # be storing at least one element.
+ return
+
+ # From this point onwards the events are only ones that weren't rejected.
+
+ for event, context in events_and_contexts:
+ # Insert all the push actions into the event_push_actions table.
+ if context.push_actions:
+ self._set_push_actions_for_event_and_users_txn(
+ txn, event, context.push_actions
+ )
+
+ if event.type == EventTypes.Redaction and event.redacts is not None:
+ # Remove the entries in the event_push_actions table for the
+ # redacted event.
+ self._remove_push_actions_for_event_id_txn(
+ txn, event.room_id, event.redacts
+ )
self._simple_insert_many_txn(
txn,
@@ -600,6 +699,49 @@ class EventsStore(SQLBaseStore):
],
)
+ # Insert into the state_groups, state_groups_state, and
+ # event_to_state_groups tables.
+ self._store_mult_state_groups_txn(txn, events_and_contexts)
+
+ # Update the event_forward_extremities, event_backward_extremities and
+ # event_edges tables.
+ self._handle_mult_prev_events(
+ txn,
+ events=[event for event, _ in events_and_contexts],
+ )
+
+ for event, _ in events_and_contexts:
+ if event.type == EventTypes.Name:
+ # Insert into the room_names and event_search tables.
+ self._store_room_name_txn(txn, event)
+ elif event.type == EventTypes.Topic:
+ # Insert into the topics table and event_search table.
+ self._store_room_topic_txn(txn, event)
+ elif event.type == EventTypes.Message:
+ # Insert into the event_search table.
+ self._store_room_message_txn(txn, event)
+ elif event.type == EventTypes.Redaction:
+ # Insert into the redactions table.
+ self._store_redaction(txn, event)
+ elif event.type == EventTypes.RoomHistoryVisibility:
+ # Insert into the event_search table.
+ self._store_history_visibility_txn(txn, event)
+ elif event.type == EventTypes.GuestAccess:
+ # Insert into the event_search table.
+ self._store_guest_access_txn(txn, event)
+
+ # Insert into the room_memberships table.
+ self._store_room_members_txn(
+ txn,
+ [
+ event
+ for event, _ in events_and_contexts
+ if event.type == EventTypes.Member
+ ],
+ backfilled=backfilled,
+ )
+
+ # Insert event_reference_hashes table.
self._store_event_reference_hashes_txn(
txn, [event for event, _ in events_and_contexts]
)
@@ -644,6 +786,7 @@ class EventsStore(SQLBaseStore):
],
)
+ # Prefill the event cache
self._add_to_cache(txn, events_and_contexts)
if backfilled:
@@ -656,22 +799,11 @@ class EventsStore(SQLBaseStore):
# Outlier events shouldn't clobber the current state.
continue
- if context.rejected:
- # If the event failed it's auth checks then it shouldn't
- # clobbler the current state.
- continue
-
txn.call_after(
self._get_current_state_for_key.invalidate,
(event.room_id, event.type, event.state_key,)
)
- if event.type in [EventTypes.Name, EventTypes.Aliases]:
- txn.call_after(
- self.get_room_name_and_aliases.invalidate,
- (event.room_id,)
- )
-
self._simple_upsert_txn(
txn,
"current_state_events",
@@ -1122,6 +1254,78 @@ class EventsStore(SQLBaseStore):
defer.returnValue(ret)
@defer.inlineCallbacks
+ def _background_reindex_fields_sender(self, progress, batch_size):
+ target_min_stream_id = progress["target_min_stream_id_inclusive"]
+ max_stream_id = progress["max_stream_id_exclusive"]
+ rows_inserted = progress.get("rows_inserted", 0)
+
+ INSERT_CLUMP_SIZE = 1000
+
+ def reindex_txn(txn):
+ sql = (
+ "SELECT stream_ordering, event_id, json FROM events"
+ " INNER JOIN event_json USING (event_id)"
+ " WHERE ? <= stream_ordering AND stream_ordering < ?"
+ " ORDER BY stream_ordering DESC"
+ " LIMIT ?"
+ )
+
+ txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size))
+
+ rows = txn.fetchall()
+ if not rows:
+ return 0
+
+ min_stream_id = rows[-1][0]
+
+ update_rows = []
+ for row in rows:
+ try:
+ event_id = row[1]
+ event_json = json.loads(row[2])
+ sender = event_json["sender"]
+ content = event_json["content"]
+
+ contains_url = "url" in content
+ if contains_url:
+ contains_url &= isinstance(content["url"], basestring)
+ except (KeyError, AttributeError):
+ # If the event is missing a necessary field then
+ # skip over it.
+ continue
+
+ update_rows.append((sender, contains_url, event_id))
+
+ sql = (
+ "UPDATE events SET sender = ?, contains_url = ? WHERE event_id = ?"
+ )
+
+ for index in range(0, len(update_rows), INSERT_CLUMP_SIZE):
+ clump = update_rows[index:index + INSERT_CLUMP_SIZE]
+ txn.executemany(sql, clump)
+
+ progress = {
+ "target_min_stream_id_inclusive": target_min_stream_id,
+ "max_stream_id_exclusive": min_stream_id,
+ "rows_inserted": rows_inserted + len(rows)
+ }
+
+ self._background_update_progress_txn(
+ txn, self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, progress
+ )
+
+ return len(rows)
+
+ result = yield self.runInteraction(
+ self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, reindex_txn
+ )
+
+ if not result:
+ yield self._end_background_update(self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME)
+
+ defer.returnValue(result)
+
+ @defer.inlineCallbacks
def _background_reindex_origin_server_ts(self, progress, batch_size):
target_min_stream_id = progress["target_min_stream_id_inclusive"]
max_stream_id = progress["max_stream_id_exclusive"]
@@ -1288,6 +1492,162 @@ class EventsStore(SQLBaseStore):
)
return self.runInteraction("get_all_new_events", get_all_new_events_txn)
+ def delete_old_state(self, room_id, topological_ordering):
+ return self.runInteraction(
+ "delete_old_state",
+ self._delete_old_state_txn, room_id, topological_ordering
+ )
+
+ def _delete_old_state_txn(self, txn, room_id, topological_ordering):
+ """Deletes old room state
+ """
+
+ # Tables that should be pruned:
+ # event_auth
+ # event_backward_extremities
+ # event_content_hashes
+ # event_destinations
+ # event_edge_hashes
+ # event_edges
+ # event_forward_extremities
+ # event_json
+ # event_push_actions
+ # event_reference_hashes
+ # event_search
+ # event_signatures
+ # event_to_state_groups
+ # events
+ # rejections
+ # room_depth
+ # state_groups
+ # state_groups_state
+
+ # First ensure that we're not about to delete all the forward extremeties
+ txn.execute(
+ "SELECT e.event_id, e.depth FROM events as e "
+ "INNER JOIN event_forward_extremities as f "
+ "ON e.event_id = f.event_id "
+ "AND e.room_id = f.room_id "
+ "WHERE f.room_id = ?",
+ (room_id,)
+ )
+ rows = txn.fetchall()
+ max_depth = max(row[0] for row in rows)
+
+ if max_depth <= topological_ordering:
+ # We need to ensure we don't delete all the events from the datanase
+ # otherwise we wouldn't be able to send any events (due to not
+ # having any backwards extremeties)
+ raise SynapseError(
+ 400, "topological_ordering is greater than forward extremeties"
+ )
+
+ txn.execute(
+ "SELECT event_id, state_key FROM events"
+ " LEFT JOIN state_events USING (room_id, event_id)"
+ " WHERE room_id = ? AND topological_ordering < ?",
+ (room_id, topological_ordering,)
+ )
+ event_rows = txn.fetchall()
+
+ # We calculate the new entries for the backward extremeties by finding
+ # all events that point to events that are to be purged
+ txn.execute(
+ "SELECT DISTINCT e.event_id FROM events as e"
+ " INNER JOIN event_edges as ed ON e.event_id = ed.prev_event_id"
+ " INNER JOIN events as e2 ON e2.event_id = ed.event_id"
+ " WHERE e.room_id = ? AND e.topological_ordering < ?"
+ " AND e2.topological_ordering >= ?",
+ (room_id, topological_ordering, topological_ordering)
+ )
+ new_backwards_extrems = txn.fetchall()
+
+ txn.execute(
+ "DELETE FROM event_backward_extremities WHERE room_id = ?",
+ (room_id,)
+ )
+
+ # Update backward extremeties
+ txn.executemany(
+ "INSERT INTO event_backward_extremities (room_id, event_id)"
+ " VALUES (?, ?)",
+ [
+ (room_id, event_id) for event_id, in new_backwards_extrems
+ ]
+ )
+
+ # Get all state groups that are only referenced by events that are
+ # to be deleted.
+ txn.execute(
+ "SELECT state_group FROM event_to_state_groups"
+ " INNER JOIN events USING (event_id)"
+ " WHERE state_group IN ("
+ " SELECT DISTINCT state_group FROM events"
+ " INNER JOIN event_to_state_groups USING (event_id)"
+ " WHERE room_id = ? AND topological_ordering < ?"
+ " )"
+ " GROUP BY state_group HAVING MAX(topological_ordering) < ?",
+ (room_id, topological_ordering, topological_ordering)
+ )
+ state_rows = txn.fetchall()
+ txn.executemany(
+ "DELETE FROM state_groups_state WHERE state_group = ?",
+ state_rows
+ )
+ txn.executemany(
+ "DELETE FROM state_groups WHERE id = ?",
+ state_rows
+ )
+ # Delete all non-state
+ txn.executemany(
+ "DELETE FROM event_to_state_groups WHERE event_id = ?",
+ [(event_id,) for event_id, _ in event_rows]
+ )
+
+ txn.execute(
+ "UPDATE room_depth SET min_depth = ? WHERE room_id = ?",
+ (topological_ordering, room_id,)
+ )
+
+ # Delete all remote non-state events
+ to_delete = [
+ (event_id,) for event_id, state_key in event_rows
+ if state_key is None and not self.hs.is_mine_id(event_id)
+ ]
+ for table in (
+ "events",
+ "event_json",
+ "event_auth",
+ "event_content_hashes",
+ "event_destinations",
+ "event_edge_hashes",
+ "event_edges",
+ "event_forward_extremities",
+ "event_push_actions",
+ "event_reference_hashes",
+ "event_search",
+ "event_signatures",
+ "rejections",
+ ):
+ txn.executemany(
+ "DELETE FROM %s WHERE event_id = ?" % (table,),
+ to_delete
+ )
+
+ txn.executemany(
+ "DELETE FROM events WHERE event_id = ?",
+ to_delete
+ )
+ # Mark all state and own events as outliers
+ txn.executemany(
+ "UPDATE events SET outlier = ?"
+ " WHERE event_id = ?",
+ [
+ (True, event_id,) for event_id, state_key in event_rows
+ if state_key is not None or self.hs.is_mine_id(event_id)
+ ]
+ )
+
AllNewEventsResult = namedtuple("AllNewEventsResult", [
"new_forward_events", "new_backfill_events",
diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py
index a495a8a7d9..86b37b9ddd 100644
--- a/synapse/storage/keys.py
+++ b/synapse/storage/keys.py
@@ -22,6 +22,10 @@ import OpenSSL
from signedjson.key import decode_verify_key_bytes
import hashlib
+import logging
+
+logger = logging.getLogger(__name__)
+
class KeyStore(SQLBaseStore):
"""Persistence for signature verification keys and tls X.509 certificates
@@ -74,22 +78,22 @@ class KeyStore(SQLBaseStore):
)
@cachedInlineCallbacks()
- def get_all_server_verify_keys(self, server_name):
- rows = yield self._simple_select_list(
+ def _get_server_verify_key(self, server_name, key_id):
+ verify_key_bytes = yield self._simple_select_one_onecol(
table="server_signature_keys",
keyvalues={
"server_name": server_name,
+ "key_id": key_id,
},
- retcols=["key_id", "verify_key"],
- desc="get_all_server_verify_keys",
+ retcol="verify_key",
+ desc="_get_server_verify_key",
+ allow_none=True,
)
- defer.returnValue({
- row["key_id"]: decode_verify_key_bytes(
- row["key_id"], str(row["verify_key"])
- )
- for row in rows
- })
+ if verify_key_bytes:
+ defer.returnValue(decode_verify_key_bytes(
+ key_id, str(verify_key_bytes)
+ ))
@defer.inlineCallbacks
def get_server_verify_keys(self, server_name, key_ids):
@@ -101,12 +105,12 @@ class KeyStore(SQLBaseStore):
Returns:
(list of VerifyKey): The verification keys.
"""
- keys = yield self.get_all_server_verify_keys(server_name)
- defer.returnValue({
- k: keys[k]
- for k in key_ids
- if k in keys and keys[k]
- })
+ keys = {}
+ for key_id in key_ids:
+ key = yield self._get_server_verify_key(server_name, key_id)
+ if key:
+ keys[key_id] = key
+ defer.returnValue(keys)
@defer.inlineCallbacks
def store_server_verify_key(self, server_name, from_server, time_now_ms,
@@ -133,8 +137,6 @@ class KeyStore(SQLBaseStore):
desc="store_server_verify_key",
)
- self.get_all_server_verify_keys.invalidate((server_name,))
-
def store_server_keys_json(self, server_name, key_id, from_server,
ts_now_ms, ts_expires_ms, key_json_bytes):
"""Stores the JSON bytes for a set of keys from a server
diff --git a/synapse/storage/media_repository.py b/synapse/storage/media_repository.py
index a820fcf07f..4c0f82353d 100644
--- a/synapse/storage/media_repository.py
+++ b/synapse/storage/media_repository.py
@@ -157,10 +157,25 @@ class MediaRepositoryStore(SQLBaseStore):
"created_ts": time_now_ms,
"upload_name": upload_name,
"filesystem_id": filesystem_id,
+ "last_access_ts": time_now_ms,
},
desc="store_cached_remote_media",
)
+ def update_cached_last_access_time(self, origin_id_tuples, time_ts):
+ def update_cache_txn(txn):
+ sql = (
+ "UPDATE remote_media_cache SET last_access_ts = ?"
+ " WHERE media_origin = ? AND media_id = ?"
+ )
+
+ txn.executemany(sql, (
+ (time_ts, media_origin, media_id)
+ for media_origin, media_id in origin_id_tuples
+ ))
+
+ return self.runInteraction("update_cached_last_access_time", update_cache_txn)
+
def get_remote_media_thumbnails(self, origin, media_id):
return self._simple_select_list(
"remote_media_cache_thumbnails",
@@ -190,3 +205,32 @@ class MediaRepositoryStore(SQLBaseStore):
},
desc="store_remote_media_thumbnail",
)
+
+ def get_remote_media_before(self, before_ts):
+ sql = (
+ "SELECT media_origin, media_id, filesystem_id"
+ " FROM remote_media_cache"
+ " WHERE last_access_ts < ?"
+ )
+
+ return self._execute(
+ "get_remote_media_before", self.cursor_to_dict, sql, before_ts
+ )
+
+ def delete_remote_media(self, media_origin, media_id):
+ def delete_remote_media_txn(txn):
+ self._simple_delete_txn(
+ txn,
+ "remote_media_cache",
+ keyvalues={
+ "media_origin": media_origin, "media_id": media_id
+ },
+ )
+ self._simple_delete_txn(
+ txn,
+ "remote_media_cache_thumbnails",
+ keyvalues={
+ "media_origin": media_origin, "media_id": media_id
+ },
+ )
+ return self.runInteraction("delete_remote_media", delete_remote_media_txn)
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index c8487c8838..8801669a6b 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts.
-SCHEMA_VERSION = 32
+SCHEMA_VERSION = 33
dir_path = os.path.abspath(os.path.dirname(__file__))
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 3de9e0f709..7e7d32eb66 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -18,25 +18,40 @@ import re
from twisted.internet import defer
from synapse.api.errors import StoreError, Codes
-
-from ._base import SQLBaseStore
+from synapse.storage import background_updates
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
-class RegistrationStore(SQLBaseStore):
+class RegistrationStore(background_updates.BackgroundUpdateStore):
def __init__(self, hs):
super(RegistrationStore, self).__init__(hs)
self.clock = hs.get_clock()
+ self.register_background_index_update(
+ "access_tokens_device_index",
+ index_name="access_tokens_device_id",
+ table="access_tokens",
+ columns=["user_id", "device_id"],
+ )
+
+ self.register_background_index_update(
+ "refresh_tokens_device_index",
+ index_name="refresh_tokens_device_id",
+ table="refresh_tokens",
+ columns=["user_id", "device_id"],
+ )
+
@defer.inlineCallbacks
- def add_access_token_to_user(self, user_id, token):
+ def add_access_token_to_user(self, user_id, token, device_id=None):
"""Adds an access token for the given user.
Args:
user_id (str): The user ID.
token (str): The new access token to add.
+ device_id (str): ID of the device to associate with the access
+ token
Raises:
StoreError if there was a problem adding this.
"""
@@ -47,18 +62,21 @@ class RegistrationStore(SQLBaseStore):
{
"id": next_id,
"user_id": user_id,
- "token": token
+ "token": token,
+ "device_id": device_id,
},
desc="add_access_token_to_user",
)
@defer.inlineCallbacks
- def add_refresh_token_to_user(self, user_id, token):
+ def add_refresh_token_to_user(self, user_id, token, device_id=None):
"""Adds a refresh token for the given user.
Args:
user_id (str): The user ID.
token (str): The new refresh token to add.
+ device_id (str): ID of the device to associate with the access
+ token
Raises:
StoreError if there was a problem adding this.
"""
@@ -69,20 +87,23 @@ class RegistrationStore(SQLBaseStore):
{
"id": next_id,
"user_id": user_id,
- "token": token
+ "token": token,
+ "device_id": device_id,
},
desc="add_refresh_token_to_user",
)
@defer.inlineCallbacks
- def register(self, user_id, token, password_hash,
+ def register(self, user_id, token=None, password_hash=None,
was_guest=False, make_guest=False, appservice_id=None,
- create_profile_with_localpart=None):
+ create_profile_with_localpart=None, admin=False):
"""Attempts to register an account.
Args:
user_id (str): The desired user ID to register.
- token (str): The desired access token to use for this user.
+ token (str): The desired access token to use for this user. If this
+ is not None, the given access token is associated with the user
+ id.
password_hash (str): Optional. The password hash for this user.
was_guest (bool): Optional. Whether this is a guest account being
upgraded to a non-guest account.
@@ -104,6 +125,7 @@ class RegistrationStore(SQLBaseStore):
make_guest,
appservice_id,
create_profile_with_localpart,
+ admin
)
self.get_user_by_id.invalidate((user_id,))
self.is_guest.invalidate((user_id,))
@@ -118,6 +140,7 @@ class RegistrationStore(SQLBaseStore):
make_guest,
appservice_id,
create_profile_with_localpart,
+ admin,
):
now = int(self.clock.time())
@@ -125,29 +148,48 @@ class RegistrationStore(SQLBaseStore):
try:
if was_guest:
- txn.execute("UPDATE users SET"
- " password_hash = ?,"
- " upgrade_ts = ?,"
- " is_guest = ?"
- " WHERE name = ?",
- [password_hash, now, 1 if make_guest else 0, user_id])
+ # Ensure that the guest user actually exists
+ # ``allow_none=False`` makes this raise an exception
+ # if the row isn't in the database.
+ self._simple_select_one_txn(
+ txn,
+ "users",
+ keyvalues={
+ "name": user_id,
+ "is_guest": 1,
+ },
+ retcols=("name",),
+ allow_none=False,
+ )
+
+ self._simple_update_one_txn(
+ txn,
+ "users",
+ keyvalues={
+ "name": user_id,
+ "is_guest": 1,
+ },
+ updatevalues={
+ "password_hash": password_hash,
+ "upgrade_ts": now,
+ "is_guest": 1 if make_guest else 0,
+ "appservice_id": appservice_id,
+ "admin": 1 if admin else 0,
+ }
+ )
else:
- txn.execute("INSERT INTO users "
- "("
- " name,"
- " password_hash,"
- " creation_ts,"
- " is_guest,"
- " appservice_id"
- ") "
- "VALUES (?,?,?,?,?)",
- [
- user_id,
- password_hash,
- now,
- 1 if make_guest else 0,
- appservice_id,
- ])
+ self._simple_insert_txn(
+ txn,
+ "users",
+ values={
+ "name": user_id,
+ "password_hash": password_hash,
+ "creation_ts": now,
+ "is_guest": 1 if make_guest else 0,
+ "appservice_id": appservice_id,
+ "admin": 1 if admin else 0,
+ }
+ )
except self.database_engine.module.IntegrityError:
raise StoreError(
400, "User ID already taken.", errcode=Codes.USER_IN_USE
@@ -209,16 +251,37 @@ class RegistrationStore(SQLBaseStore):
self.get_user_by_id.invalidate((user_id,))
@defer.inlineCallbacks
- def user_delete_access_tokens(self, user_id, except_token_ids=[]):
- def f(txn):
- sql = "SELECT token FROM access_tokens WHERE user_id = ?"
+ def user_delete_access_tokens(self, user_id, except_token_ids=[],
+ device_id=None,
+ delete_refresh_tokens=False):
+ """
+ Invalidate access/refresh tokens belonging to a user
+
+ Args:
+ user_id (str): ID of user the tokens belong to
+ except_token_ids (list[str]): list of access_tokens which should
+ *not* be deleted
+ device_id (str|None): ID of device the tokens are associated with.
+ If None, tokens associated with any device (or no device) will
+ be deleted
+ delete_refresh_tokens (bool): True to delete refresh tokens as
+ well as access tokens.
+ Returns:
+ defer.Deferred:
+ """
+ def f(txn, table, except_tokens, call_after_delete):
+ sql = "SELECT token FROM %s WHERE user_id = ?" % table
clauses = [user_id]
- if except_token_ids:
+ if device_id is not None:
+ sql += " AND device_id = ?"
+ clauses.append(device_id)
+
+ if except_tokens:
sql += " AND id NOT IN (%s)" % (
- ",".join(["?" for _ in except_token_ids]),
+ ",".join(["?" for _ in except_tokens]),
)
- clauses += except_token_ids
+ clauses += except_tokens
txn.execute(sql, clauses)
@@ -227,16 +290,33 @@ class RegistrationStore(SQLBaseStore):
n = 100
chunks = [rows[i:i + n] for i in xrange(0, len(rows), n)]
for chunk in chunks:
- for row in chunk:
- txn.call_after(self.get_user_by_access_token.invalidate, (row[0],))
+ if call_after_delete:
+ for row in chunk:
+ txn.call_after(call_after_delete, (row[0],))
txn.execute(
- "DELETE FROM access_tokens WHERE token in (%s)" % (
+ "DELETE FROM %s WHERE token in (%s)" % (
+ table,
",".join(["?" for _ in chunk]),
), [r[0] for r in chunk]
)
- yield self.runInteraction("user_delete_access_tokens", f)
+ # delete refresh tokens first, to stop new access tokens being
+ # allocated while our backs are turned
+ if delete_refresh_tokens:
+ yield self.runInteraction(
+ "user_delete_access_tokens", f,
+ table="refresh_tokens",
+ except_tokens=[],
+ call_after_delete=None,
+ )
+
+ yield self.runInteraction(
+ "user_delete_access_tokens", f,
+ table="access_tokens",
+ except_tokens=except_token_ids,
+ call_after_delete=self.get_user_by_access_token.invalidate,
+ )
def delete_access_token(self, access_token):
def f(txn):
@@ -259,9 +339,8 @@ class RegistrationStore(SQLBaseStore):
Args:
token (str): The access token of a user.
Returns:
- dict: Including the name (user_id) and the ID of their access token.
- Raises:
- StoreError if no user was found.
+ defer.Deferred: None, if the token did not match, otherwise dict
+ including the keys `name`, `is_guest`, `device_id`, `token_id`.
"""
return self.runInteraction(
"get_user_by_access_token",
@@ -270,18 +349,18 @@ class RegistrationStore(SQLBaseStore):
)
def exchange_refresh_token(self, refresh_token, token_generator):
- """Exchange a refresh token for a new access token and refresh token.
+ """Exchange a refresh token for a new one.
Doing so invalidates the old refresh token - refresh tokens are single
use.
Args:
- token (str): The refresh token of a user.
+ refresh_token (str): The refresh token of a user.
token_generator (fn: str -> str): Function which, when given a
user ID, returns a unique refresh token for that user. This
function must never return the same value twice.
Returns:
- tuple of (user_id, refresh_token)
+ tuple of (user_id, new_refresh_token, device_id)
Raises:
StoreError if no user was found with that refresh token.
"""
@@ -293,12 +372,13 @@ class RegistrationStore(SQLBaseStore):
)
def _exchange_refresh_token(self, txn, old_token, token_generator):
- sql = "SELECT user_id FROM refresh_tokens WHERE token = ?"
+ sql = "SELECT user_id, device_id FROM refresh_tokens WHERE token = ?"
txn.execute(sql, (old_token,))
rows = self.cursor_to_dict(txn)
if not rows:
raise StoreError(403, "Did not recognize refresh token")
user_id = rows[0]["user_id"]
+ device_id = rows[0]["device_id"]
# TODO(danielwh): Maybe perform a validation on the macaroon that
# macaroon.user_id == user_id.
@@ -307,7 +387,7 @@ class RegistrationStore(SQLBaseStore):
sql = "UPDATE refresh_tokens SET token = ? WHERE token = ?"
txn.execute(sql, (new_token, old_token,))
- return user_id, new_token
+ return user_id, new_token, device_id
@defer.inlineCallbacks
def is_server_admin(self, user):
@@ -335,7 +415,8 @@ class RegistrationStore(SQLBaseStore):
def _query_for_auth(self, txn, token):
sql = (
- "SELECT users.name, users.is_guest, access_tokens.id as token_id"
+ "SELECT users.name, users.is_guest, access_tokens.id as token_id,"
+ " access_tokens.device_id"
" FROM users"
" INNER JOIN access_tokens on users.name = access_tokens.user_id"
" WHERE token = ?"
@@ -384,6 +465,15 @@ class RegistrationStore(SQLBaseStore):
defer.returnValue(ret['user_id'])
defer.returnValue(None)
+ def user_delete_threepids(self, user_id):
+ return self._simple_delete(
+ "user_threepids",
+ keyvalues={
+ "user_id": user_id,
+ },
+ desc="user_delete_threepids",
+ )
+
@defer.inlineCallbacks
def count_all_users(self):
"""Counts all users registered on the homeserver."""
diff --git a/synapse/storage/room.py b/synapse/storage/room.py
index 97f9f1929c..8251f58670 100644
--- a/synapse/storage/room.py
+++ b/synapse/storage/room.py
@@ -18,7 +18,6 @@ from twisted.internet import defer
from synapse.api.errors import StoreError
from ._base import SQLBaseStore
-from synapse.util.caches.descriptors import cachedInlineCallbacks
from .engines import PostgresEngine, Sqlite3Engine
import collections
@@ -192,49 +191,6 @@ class RoomStore(SQLBaseStore):
# This should be unreachable.
raise Exception("Unrecognized database engine")
- @cachedInlineCallbacks()
- def get_room_name_and_aliases(self, room_id):
- def get_room_name(txn):
- sql = (
- "SELECT name FROM room_names"
- " INNER JOIN current_state_events USING (room_id, event_id)"
- " WHERE room_id = ?"
- " LIMIT 1"
- )
-
- txn.execute(sql, (room_id,))
- rows = txn.fetchall()
- if rows:
- return rows[0][0]
- else:
- return None
-
- return [row[0] for row in txn.fetchall()]
-
- def get_room_aliases(txn):
- sql = (
- "SELECT content FROM current_state_events"
- " INNER JOIN events USING (room_id, event_id)"
- " WHERE room_id = ?"
- )
- txn.execute(sql, (room_id,))
- return [row[0] for row in txn.fetchall()]
-
- name = yield self.runInteraction("get_room_name", get_room_name)
- alias_contents = yield self.runInteraction("get_room_aliases", get_room_aliases)
-
- aliases = []
-
- for c in alias_contents:
- try:
- content = json.loads(c)
- except:
- continue
-
- aliases.extend(content.get('aliases', []))
-
- defer.returnValue((name, aliases))
-
def add_event_report(self, room_id, event_id, user_id, reason, content,
received_ts):
next_id = self._event_reports_id_gen.get_next()
diff --git a/synapse/storage/schema/delta/33/access_tokens_device_index.sql b/synapse/storage/schema/delta/33/access_tokens_device_index.sql
new file mode 100644
index 0000000000..61ad3fe3e8
--- /dev/null
+++ b/synapse/storage/schema/delta/33/access_tokens_device_index.sql
@@ -0,0 +1,17 @@
+/* Copyright 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+INSERT INTO background_updates (update_name, progress_json) VALUES
+ ('access_tokens_device_index', '{}');
diff --git a/synapse/storage/schema/delta/33/devices.sql b/synapse/storage/schema/delta/33/devices.sql
new file mode 100644
index 0000000000..eca7268d82
--- /dev/null
+++ b/synapse/storage/schema/delta/33/devices.sql
@@ -0,0 +1,21 @@
+/* Copyright 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE devices (
+ user_id TEXT NOT NULL,
+ device_id TEXT NOT NULL,
+ display_name TEXT,
+ CONSTRAINT device_uniqueness UNIQUE (user_id, device_id)
+);
diff --git a/synapse/storage/schema/delta/33/devices_for_e2e_keys.sql b/synapse/storage/schema/delta/33/devices_for_e2e_keys.sql
new file mode 100644
index 0000000000..aa4a3b9f2f
--- /dev/null
+++ b/synapse/storage/schema/delta/33/devices_for_e2e_keys.sql
@@ -0,0 +1,19 @@
+/* Copyright 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- make sure that we have a device record for each set of E2E keys, so that the
+-- user can delete them if they like.
+INSERT INTO devices
+ SELECT user_id, device_id, NULL FROM e2e_device_keys_json;
diff --git a/synapse/storage/schema/delta/33/devices_for_e2e_keys_clear_unknown_device.sql b/synapse/storage/schema/delta/33/devices_for_e2e_keys_clear_unknown_device.sql
new file mode 100644
index 0000000000..6671573398
--- /dev/null
+++ b/synapse/storage/schema/delta/33/devices_for_e2e_keys_clear_unknown_device.sql
@@ -0,0 +1,20 @@
+/* Copyright 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- a previous version of the "devices_for_e2e_keys" delta set all the device
+-- names to "unknown device". This wasn't terribly helpful
+UPDATE devices
+ SET display_name = NULL
+ WHERE display_name = 'unknown device';
diff --git a/synapse/storage/schema/delta/33/event_fields.py b/synapse/storage/schema/delta/33/event_fields.py
new file mode 100644
index 0000000000..83066cccc9
--- /dev/null
+++ b/synapse/storage/schema/delta/33/event_fields.py
@@ -0,0 +1,60 @@
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.storage.prepare_database import get_statements
+
+import logging
+import ujson
+
+logger = logging.getLogger(__name__)
+
+
+ALTER_TABLE = """
+ALTER TABLE events ADD COLUMN sender TEXT;
+ALTER TABLE events ADD COLUMN contains_url BOOLEAN;
+"""
+
+
+def run_create(cur, database_engine, *args, **kwargs):
+ for statement in get_statements(ALTER_TABLE.splitlines()):
+ cur.execute(statement)
+
+ cur.execute("SELECT MIN(stream_ordering) FROM events")
+ rows = cur.fetchall()
+ min_stream_id = rows[0][0]
+
+ cur.execute("SELECT MAX(stream_ordering) FROM events")
+ rows = cur.fetchall()
+ max_stream_id = rows[0][0]
+
+ if min_stream_id is not None and max_stream_id is not None:
+ progress = {
+ "target_min_stream_id_inclusive": min_stream_id,
+ "max_stream_id_exclusive": max_stream_id + 1,
+ "rows_inserted": 0,
+ }
+ progress_json = ujson.dumps(progress)
+
+ sql = (
+ "INSERT into background_updates (update_name, progress_json)"
+ " VALUES (?, ?)"
+ )
+
+ sql = database_engine.convert_param_style(sql)
+
+ cur.execute(sql, ("event_fields_sender_url", progress_json))
+
+
+def run_upgrade(cur, database_engine, *args, **kwargs):
+ pass
diff --git a/synapse/storage/schema/delta/33/refreshtoken_device.sql b/synapse/storage/schema/delta/33/refreshtoken_device.sql
new file mode 100644
index 0000000000..290bd6da86
--- /dev/null
+++ b/synapse/storage/schema/delta/33/refreshtoken_device.sql
@@ -0,0 +1,16 @@
+/* Copyright 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ALTER TABLE refresh_tokens ADD COLUMN device_id TEXT;
diff --git a/synapse/storage/schema/delta/33/refreshtoken_device_index.sql b/synapse/storage/schema/delta/33/refreshtoken_device_index.sql
new file mode 100644
index 0000000000..bb225dafbf
--- /dev/null
+++ b/synapse/storage/schema/delta/33/refreshtoken_device_index.sql
@@ -0,0 +1,17 @@
+/* Copyright 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+INSERT INTO background_updates (update_name, progress_json) VALUES
+ ('refresh_tokens_device_index', '{}');
diff --git a/synapse/storage/schema/delta/33/remote_media_ts.py b/synapse/storage/schema/delta/33/remote_media_ts.py
new file mode 100644
index 0000000000..55ae43f395
--- /dev/null
+++ b/synapse/storage/schema/delta/33/remote_media_ts.py
@@ -0,0 +1,31 @@
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import time
+
+
+ALTER_TABLE = "ALTER TABLE remote_media_cache ADD COLUMN last_access_ts BIGINT"
+
+
+def run_create(cur, database_engine, *args, **kwargs):
+ cur.execute(ALTER_TABLE)
+
+
+def run_upgrade(cur, database_engine, *args, **kwargs):
+ cur.execute(
+ database_engine.convert_param_style(
+ "UPDATE remote_media_cache SET last_access_ts = ?"
+ ),
+ (int(time.time() * 1000),)
+ )
diff --git a/synapse/storage/schema/delta/33/user_ips_index.sql b/synapse/storage/schema/delta/33/user_ips_index.sql
new file mode 100644
index 0000000000..473f75a78e
--- /dev/null
+++ b/synapse/storage/schema/delta/33/user_ips_index.sql
@@ -0,0 +1,17 @@
+/* Copyright 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+INSERT INTO background_updates (update_name, progress_json) VALUES
+ ('user_ips_device_index', '{}');
diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py
index b9ad965fd6..862c5c3ea1 100644
--- a/synapse/storage/stream.py
+++ b/synapse/storage/stream.py
@@ -40,6 +40,7 @@ from synapse.util.caches.descriptors import cached
from synapse.api.constants import EventTypes
from synapse.types import RoomStreamToken
from synapse.util.logcontext import preserve_fn
+from synapse.storage.engines import PostgresEngine, Sqlite3Engine
import logging
@@ -54,26 +55,92 @@ _STREAM_TOKEN = "stream"
_TOPOLOGICAL_TOKEN = "topological"
-def lower_bound(token):
+def lower_bound(token, engine, inclusive=False):
+ inclusive = "=" if inclusive else ""
if token.topological is None:
- return "(%d < %s)" % (token.stream, "stream_ordering")
+ return "(%d <%s %s)" % (token.stream, inclusive, "stream_ordering")
else:
- return "(%d < %s OR (%d = %s AND %d < %s))" % (
+ if isinstance(engine, PostgresEngine):
+ # Postgres doesn't optimise ``(x < a) OR (x=a AND y<b)`` as well
+ # as it optimises ``(x,y) < (a,b)`` on multicolumn indexes. So we
+ # use the later form when running against postgres.
+ return "((%d,%d) <%s (%s,%s))" % (
+ token.topological, token.stream, inclusive,
+ "topological_ordering", "stream_ordering",
+ )
+ return "(%d < %s OR (%d = %s AND %d <%s %s))" % (
token.topological, "topological_ordering",
token.topological, "topological_ordering",
- token.stream, "stream_ordering",
+ token.stream, inclusive, "stream_ordering",
)
-def upper_bound(token):
+def upper_bound(token, engine, inclusive=True):
+ inclusive = "=" if inclusive else ""
if token.topological is None:
- return "(%d >= %s)" % (token.stream, "stream_ordering")
+ return "(%d >%s %s)" % (token.stream, inclusive, "stream_ordering")
else:
- return "(%d > %s OR (%d = %s AND %d >= %s))" % (
+ if isinstance(engine, PostgresEngine):
+ # Postgres doesn't optimise ``(x > a) OR (x=a AND y>b)`` as well
+ # as it optimises ``(x,y) > (a,b)`` on multicolumn indexes. So we
+ # use the later form when running against postgres.
+ return "((%d,%d) >%s (%s,%s))" % (
+ token.topological, token.stream, inclusive,
+ "topological_ordering", "stream_ordering",
+ )
+ return "(%d > %s OR (%d = %s AND %d >%s %s))" % (
token.topological, "topological_ordering",
token.topological, "topological_ordering",
- token.stream, "stream_ordering",
+ token.stream, inclusive, "stream_ordering",
+ )
+
+
+def filter_to_clause(event_filter):
+ # NB: This may create SQL clauses that don't optimise well (and we don't
+ # have indices on all possible clauses). E.g. it may create
+ # "room_id == X AND room_id != X", which postgres doesn't optimise.
+
+ if not event_filter:
+ return "", []
+
+ clauses = []
+ args = []
+
+ if event_filter.types:
+ clauses.append(
+ "(%s)" % " OR ".join("type = ?" for _ in event_filter.types)
+ )
+ args.extend(event_filter.types)
+
+ for typ in event_filter.not_types:
+ clauses.append("type != ?")
+ args.append(typ)
+
+ if event_filter.senders:
+ clauses.append(
+ "(%s)" % " OR ".join("sender = ?" for _ in event_filter.senders)
)
+ args.extend(event_filter.senders)
+
+ for sender in event_filter.not_senders:
+ clauses.append("sender != ?")
+ args.append(sender)
+
+ if event_filter.rooms:
+ clauses.append(
+ "(%s)" % " OR ".join("room_id = ?" for _ in event_filter.rooms)
+ )
+ args.extend(event_filter.rooms)
+
+ for room_id in event_filter.not_rooms:
+ clauses.append("room_id != ?")
+ args.append(room_id)
+
+ if event_filter.contains_url:
+ clauses.append("contains_url = ?")
+ args.append(event_filter.contains_url)
+
+ return " AND ".join(clauses), args
class StreamStore(SQLBaseStore):
@@ -301,25 +368,35 @@ class StreamStore(SQLBaseStore):
@defer.inlineCallbacks
def paginate_room_events(self, room_id, from_key, to_key=None,
- direction='b', limit=-1):
+ direction='b', limit=-1, event_filter=None):
# Tokens really represent positions between elements, but we use
# the convention of pointing to the event before the gap. Hence
# we have a bit of asymmetry when it comes to equalities.
args = [False, room_id]
if direction == 'b':
order = "DESC"
- bounds = upper_bound(RoomStreamToken.parse(from_key))
+ bounds = upper_bound(
+ RoomStreamToken.parse(from_key), self.database_engine
+ )
if to_key:
- bounds = "%s AND %s" % (
- bounds, lower_bound(RoomStreamToken.parse(to_key))
- )
+ bounds = "%s AND %s" % (bounds, lower_bound(
+ RoomStreamToken.parse(to_key), self.database_engine
+ ))
else:
order = "ASC"
- bounds = lower_bound(RoomStreamToken.parse(from_key))
+ bounds = lower_bound(
+ RoomStreamToken.parse(from_key), self.database_engine
+ )
if to_key:
- bounds = "%s AND %s" % (
- bounds, upper_bound(RoomStreamToken.parse(to_key))
- )
+ bounds = "%s AND %s" % (bounds, upper_bound(
+ RoomStreamToken.parse(to_key), self.database_engine
+ ))
+
+ filter_clause, filter_args = filter_to_clause(event_filter)
+
+ if filter_clause:
+ bounds += " AND " + filter_clause
+ args.extend(filter_args)
if int(limit) > 0:
args.append(int(limit))
@@ -487,13 +564,13 @@ class StreamStore(SQLBaseStore):
row["topological_ordering"], row["stream_ordering"],)
)
- def get_max_topological_token_for_stream_and_room(self, room_id, stream_key):
+ def get_max_topological_token(self, room_id, stream_key):
sql = (
"SELECT max(topological_ordering) FROM events"
" WHERE room_id = ? AND stream_ordering < ?"
)
return self._execute(
- "get_max_topological_token_for_stream_and_room", None,
+ "get_max_topological_token", None,
sql, room_id, stream_key,
).addCallback(
lambda r: r[0][0] if r else 0
@@ -586,32 +663,60 @@ class StreamStore(SQLBaseStore):
retcols=["stream_ordering", "topological_ordering"],
)
- stream_ordering = results["stream_ordering"]
- topological_ordering = results["topological_ordering"]
-
- query_before = (
- "SELECT topological_ordering, stream_ordering, event_id FROM events"
- " WHERE room_id = ? AND (topological_ordering < ?"
- " OR (topological_ordering = ? AND stream_ordering < ?))"
- " ORDER BY topological_ordering DESC, stream_ordering DESC"
- " LIMIT ?"
+ token = RoomStreamToken(
+ results["topological_ordering"],
+ results["stream_ordering"],
)
- query_after = (
- "SELECT topological_ordering, stream_ordering, event_id FROM events"
- " WHERE room_id = ? AND (topological_ordering > ?"
- " OR (topological_ordering = ? AND stream_ordering > ?))"
- " ORDER BY topological_ordering ASC, stream_ordering ASC"
- " LIMIT ?"
- )
+ if isinstance(self.database_engine, Sqlite3Engine):
+ # SQLite3 doesn't optimise ``(x < a) OR (x = a AND y < b)``
+ # So we give pass it to SQLite3 as the UNION ALL of the two queries.
+
+ query_before = (
+ "SELECT topological_ordering, stream_ordering, event_id FROM events"
+ " WHERE room_id = ? AND topological_ordering < ?"
+ " UNION ALL"
+ " SELECT topological_ordering, stream_ordering, event_id FROM events"
+ " WHERE room_id = ? AND topological_ordering = ? AND stream_ordering < ?"
+ " ORDER BY topological_ordering DESC, stream_ordering DESC LIMIT ?"
+ )
+ before_args = (
+ room_id, token.topological,
+ room_id, token.topological, token.stream,
+ before_limit,
+ )
- txn.execute(
- query_before,
- (
- room_id, topological_ordering, topological_ordering,
- stream_ordering, before_limit,
+ query_after = (
+ "SELECT topological_ordering, stream_ordering, event_id FROM events"
+ " WHERE room_id = ? AND topological_ordering > ?"
+ " UNION ALL"
+ " SELECT topological_ordering, stream_ordering, event_id FROM events"
+ " WHERE room_id = ? AND topological_ordering = ? AND stream_ordering > ?"
+ " ORDER BY topological_ordering ASC, stream_ordering ASC LIMIT ?"
)
- )
+ after_args = (
+ room_id, token.topological,
+ room_id, token.topological, token.stream,
+ after_limit,
+ )
+ else:
+ query_before = (
+ "SELECT topological_ordering, stream_ordering, event_id FROM events"
+ " WHERE room_id = ? AND %s"
+ " ORDER BY topological_ordering DESC, stream_ordering DESC LIMIT ?"
+ ) % (upper_bound(token, self.database_engine, inclusive=False),)
+
+ before_args = (room_id, before_limit)
+
+ query_after = (
+ "SELECT topological_ordering, stream_ordering, event_id FROM events"
+ " WHERE room_id = ? AND %s"
+ " ORDER BY topological_ordering ASC, stream_ordering ASC LIMIT ?"
+ ) % (lower_bound(token, self.database_engine, inclusive=False),)
+
+ after_args = (room_id, after_limit)
+
+ txn.execute(query_before, before_args)
rows = self.cursor_to_dict(txn)
events_before = [r["event_id"] for r in rows]
@@ -623,17 +728,11 @@ class StreamStore(SQLBaseStore):
))
else:
start_token = str(RoomStreamToken(
- topological_ordering,
- stream_ordering - 1,
+ token.topological,
+ token.stream - 1,
))
- txn.execute(
- query_after,
- (
- room_id, topological_ordering, topological_ordering,
- stream_ordering, after_limit,
- )
- )
+ txn.execute(query_after, after_args)
rows = self.cursor_to_dict(txn)
events_after = [r["event_id"] for r in rows]
@@ -644,10 +743,7 @@ class StreamStore(SQLBaseStore):
rows[-1]["stream_ordering"],
))
else:
- end_token = str(RoomStreamToken(
- topological_ordering,
- stream_ordering,
- ))
+ end_token = str(token)
return {
"before": {
diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py
index 6c7481a728..6258ff1725 100644
--- a/synapse/storage/transactions.py
+++ b/synapse/storage/transactions.py
@@ -24,6 +24,7 @@ from collections import namedtuple
import itertools
import logging
+import ujson as json
logger = logging.getLogger(__name__)
@@ -101,7 +102,7 @@ class TransactionStore(SQLBaseStore):
)
if result and result["response_code"]:
- return result["response_code"], result["response_json"]
+ return result["response_code"], json.loads(str(result["response_json"]))
else:
return None
diff --git a/synapse/types.py b/synapse/types.py
index f639651a73..5349b0c450 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -18,7 +18,38 @@ from synapse.api.errors import SynapseError
from collections import namedtuple
-Requester = namedtuple("Requester", ["user", "access_token_id", "is_guest"])
+Requester = namedtuple("Requester",
+ ["user", "access_token_id", "is_guest", "device_id"])
+"""
+Represents the user making a request
+
+Attributes:
+ user (UserID): id of the user making the request
+ access_token_id (int|None): *ID* of the access token used for this
+ request, or None if it came via the appservice API or similar
+ is_guest (bool): True if the user making this request is a guest user
+ device_id (str|None): device_id which was set at authentication time
+"""
+
+
+def create_requester(user_id, access_token_id=None, is_guest=False,
+ device_id=None):
+ """
+ Create a new ``Requester`` object
+
+ Args:
+ user_id (str|UserID): id of the user making the request
+ access_token_id (int|None): *ID* of the access token used for this
+ request, or None if it came via the appservice API or similar
+ is_guest (bool): True if the user making this request is a guest user
+ device_id (str|None): device_id which was set at authentication time
+
+ Returns:
+ Requester
+ """
+ if not isinstance(user_id, UserID):
+ user_id = UserID.from_string(user_id)
+ return Requester(user_id, access_token_id, is_guest, device_id)
def get_domain_from_id(string):
diff --git a/synapse/util/async.py b/synapse/util/async.py
index 40be7fe7e3..c84b23ff46 100644
--- a/synapse/util/async.py
+++ b/synapse/util/async.py
@@ -194,3 +194,85 @@ class Linearizer(object):
self.key_to_defer.pop(key, None)
defer.returnValue(_ctx_manager())
+
+
+class ReadWriteLock(object):
+ """A deferred style read write lock.
+
+ Example:
+
+ with (yield read_write_lock.read("test_key")):
+ # do some work
+ """
+
+ # IMPLEMENTATION NOTES
+ #
+ # We track the most recent queued reader and writer deferreds (which get
+ # resolved when they release the lock).
+ #
+ # Read: We know its safe to acquire a read lock when the latest writer has
+ # been resolved. The new reader is appeneded to the list of latest readers.
+ #
+ # Write: We know its safe to acquire the write lock when both the latest
+ # writers and readers have been resolved. The new writer replaces the latest
+ # writer.
+
+ def __init__(self):
+ # Latest readers queued
+ self.key_to_current_readers = {}
+
+ # Latest writer queued
+ self.key_to_current_writer = {}
+
+ @defer.inlineCallbacks
+ def read(self, key):
+ new_defer = defer.Deferred()
+
+ curr_readers = self.key_to_current_readers.setdefault(key, set())
+ curr_writer = self.key_to_current_writer.get(key, None)
+
+ curr_readers.add(new_defer)
+
+ # We wait for the latest writer to finish writing. We can safely ignore
+ # any existing readers... as they're readers.
+ yield curr_writer
+
+ @contextmanager
+ def _ctx_manager():
+ try:
+ yield
+ finally:
+ new_defer.callback(None)
+ self.key_to_current_readers.get(key, set()).discard(new_defer)
+
+ defer.returnValue(_ctx_manager())
+
+ @defer.inlineCallbacks
+ def write(self, key):
+ new_defer = defer.Deferred()
+
+ curr_readers = self.key_to_current_readers.get(key, set())
+ curr_writer = self.key_to_current_writer.get(key, None)
+
+ # We wait on all latest readers and writer.
+ to_wait_on = list(curr_readers)
+ if curr_writer:
+ to_wait_on.append(curr_writer)
+
+ # We can clear the list of current readers since the new writer waits
+ # for them to finish.
+ curr_readers.clear()
+ self.key_to_current_writer[key] = new_defer
+
+ yield defer.gatherResults(to_wait_on)
+
+ @contextmanager
+ def _ctx_manager():
+ try:
+ yield
+ finally:
+ new_defer.callback(None)
+ if self.key_to_current_writer[key] == new_defer:
+ self.key_to_current_writer.pop(key)
+
+ defer.returnValue(_ctx_manager())
diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py
index 36686b479e..00af539880 100644
--- a/synapse/util/caches/response_cache.py
+++ b/synapse/util/caches/response_cache.py
@@ -24,9 +24,12 @@ class ResponseCache(object):
used rather than trying to compute a new response.
"""
- def __init__(self):
+ def __init__(self, hs, timeout_ms=0):
self.pending_result_cache = {} # Requests that haven't finished yet.
+ self.clock = hs.get_clock()
+ self.timeout_sec = timeout_ms / 1000.
+
def get(self, key):
result = self.pending_result_cache.get(key)
if result is not None:
@@ -39,7 +42,13 @@ class ResponseCache(object):
self.pending_result_cache[key] = result
def remove(r):
- self.pending_result_cache.pop(key, None)
+ if self.timeout_sec:
+ self.clock.call_later(
+ self.timeout_sec,
+ self.pending_result_cache.pop, key, None,
+ )
+ else:
+ self.pending_result_cache.pop(key, None)
return r
result.addBoth(remove)
diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py
index e1f374807e..0b944d3e63 100644
--- a/synapse/util/metrics.py
+++ b/synapse/util/metrics.py
@@ -84,7 +84,7 @@ class Measure(object):
if context != self.start_context:
logger.warn(
- "Context have unexpectedly changed from '%s' to '%s'. (%r)",
+ "Context has unexpectedly changed from '%s' to '%s'. (%r)",
context, self.start_context, self.name
)
return
diff --git a/synapse/util/presentable_names.py b/synapse/util/presentable_names.py
index a6866f6117..f68676e9e7 100644
--- a/synapse/util/presentable_names.py
+++ b/synapse/util/presentable_names.py
@@ -25,7 +25,8 @@ ALIAS_RE = re.compile(r"^#.*:.+$")
ALL_ALONE = "Empty Room"
-def calculate_room_name(room_state, user_id, fallback_to_members=True):
+def calculate_room_name(room_state, user_id, fallback_to_members=True,
+ fallback_to_single_member=True):
"""
Works out a user-facing name for the given room as per Matrix
spec recommendations.
@@ -82,7 +83,10 @@ def calculate_room_name(room_state, user_id, fallback_to_members=True):
):
if ("m.room.member", my_member_event.sender) in room_state:
inviter_member_event = room_state[("m.room.member", my_member_event.sender)]
- return "Invite from %s" % (name_from_member_event(inviter_member_event),)
+ if fallback_to_single_member:
+ return "Invite from %s" % (name_from_member_event(inviter_member_event),)
+ else:
+ return None
else:
return "Room Invite"
@@ -129,6 +133,8 @@ def calculate_room_name(room_state, user_id, fallback_to_members=True):
return name_from_member_event(all_members[0])
else:
return ALL_ALONE
+ elif len(other_members) == 1 and not fallback_to_single_member:
+ return None
else:
return descriptor_from_member_events(other_members)
diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py
index 43cf11f3f6..49527f4d21 100644
--- a/synapse/util/retryutils.py
+++ b/synapse/util/retryutils.py
@@ -128,7 +128,7 @@ class RetryDestinationLimiter(object):
)
valid_err_code = False
- if exc_type is CodeMessageException:
+ if exc_type is not None and issubclass(exc_type, CodeMessageException):
valid_err_code = 0 <= exc_val.code < 500
if exc_type is None or valid_err_code:
diff --git a/synapse/util/versionstring.py b/synapse/util/versionstring.py
index a4f156cb3b..52086df465 100644
--- a/synapse/util/versionstring.py
+++ b/synapse/util/versionstring.py
@@ -21,7 +21,7 @@ import logging
logger = logging.getLogger(__name__)
-def get_version_string(name, module):
+def get_version_string(module):
try:
null = open(os.devnull, 'w')
cwd = os.path.dirname(os.path.abspath(module.__file__))
@@ -74,11 +74,11 @@ def get_version_string(name, module):
)
return (
- "%s/%s (%s)" % (
- name, module.__version__, git_version,
+ "%s (%s)" % (
+ module.__version__, git_version,
)
).encode("ascii")
except Exception as e:
logger.info("Failed to check for git repository: %s", e)
- return ("%s/%s" % (name, module.__version__,)).encode("ascii")
+ return module.__version__.encode("ascii")
diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index ad269af0ec..e91723ca3d 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -45,6 +45,7 @@ class AuthTestCase(unittest.TestCase):
user_info = {
"name": self.test_user,
"token_id": "ditto",
+ "device_id": "device",
}
self.store.get_user_by_access_token = Mock(return_value=user_info)
@@ -143,7 +144,10 @@ class AuthTestCase(unittest.TestCase):
# TODO(danielwh): Remove this mock when we remove the
# get_user_by_access_token fallback.
self.store.get_user_by_access_token = Mock(
- return_value={"name": "@baldrick:matrix.org"}
+ return_value={
+ "name": "@baldrick:matrix.org",
+ "device_id": "device",
+ }
)
user_id = "@baldrick:matrix.org"
@@ -158,6 +162,10 @@ class AuthTestCase(unittest.TestCase):
user = user_info["user"]
self.assertEqual(UserID.from_string(user_id), user)
+ # TODO: device_id should come from the macaroon, but currently comes
+ # from the db.
+ self.assertEqual(user_info["device_id"], "device")
+
@defer.inlineCallbacks
def test_get_guest_user_from_macaroon(self):
user_id = "@baldrick:matrix.org"
@@ -281,7 +289,7 @@ class AuthTestCase(unittest.TestCase):
macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("type = access")
macaroon.add_first_party_caveat("user_id = %s" % (user,))
- macaroon.add_first_party_caveat("time < 1") # ms
+ macaroon.add_first_party_caveat("time < -2000") # ms
self.hs.clock.now = 5000 # seconds
self.hs.config.expire_access_token = True
@@ -293,3 +301,32 @@ class AuthTestCase(unittest.TestCase):
yield self.auth.get_user_from_macaroon(macaroon.serialize())
self.assertEqual(401, cm.exception.code)
self.assertIn("Invalid macaroon", cm.exception.msg)
+
+ @defer.inlineCallbacks
+ def test_get_user_from_macaroon_with_valid_duration(self):
+ # TODO(danielwh): Remove this mock when we remove the
+ # get_user_by_access_token fallback.
+ self.store.get_user_by_access_token = Mock(
+ return_value={"name": "@baldrick:matrix.org"}
+ )
+
+ self.store.get_user_by_access_token = Mock(
+ return_value={"name": "@baldrick:matrix.org"}
+ )
+
+ user_id = "@baldrick:matrix.org"
+ macaroon = pymacaroons.Macaroon(
+ location=self.hs.config.server_name,
+ identifier="key",
+ key=self.hs.config.macaroon_secret_key)
+ macaroon.add_first_party_caveat("gen = 1")
+ macaroon.add_first_party_caveat("type = access")
+ macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
+ macaroon.add_first_party_caveat("time < 900000000") # ms
+
+ self.hs.clock.now = 5000 # seconds
+ self.hs.config.expire_access_token = True
+
+ user_info = yield self.auth.get_user_from_macaroon(macaroon.serialize())
+ user = user_info["user"]
+ self.assertEqual(UserID.from_string(user_id), user)
diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py
new file mode 100644
index 0000000000..85a970a6c9
--- /dev/null
+++ b/tests/handlers/test_device.py
@@ -0,0 +1,184 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+import synapse.api.errors
+import synapse.handlers.device
+
+import synapse.storage
+from synapse import types
+from tests import unittest, utils
+
+user1 = "@boris:aaa"
+user2 = "@theresa:bbb"
+
+
+class DeviceTestCase(unittest.TestCase):
+ def __init__(self, *args, **kwargs):
+ super(DeviceTestCase, self).__init__(*args, **kwargs)
+ self.store = None # type: synapse.storage.DataStore
+ self.handler = None # type: synapse.handlers.device.DeviceHandler
+ self.clock = None # type: utils.MockClock
+
+ @defer.inlineCallbacks
+ def setUp(self):
+ hs = yield utils.setup_test_homeserver(handlers=None)
+ self.handler = synapse.handlers.device.DeviceHandler(hs)
+ self.store = hs.get_datastore()
+ self.clock = hs.get_clock()
+
+ @defer.inlineCallbacks
+ def test_device_is_created_if_doesnt_exist(self):
+ res = yield self.handler.check_device_registered(
+ user_id="boris",
+ device_id="fco",
+ initial_device_display_name="display name"
+ )
+ self.assertEqual(res, "fco")
+
+ dev = yield self.handler.store.get_device("boris", "fco")
+ self.assertEqual(dev["display_name"], "display name")
+
+ @defer.inlineCallbacks
+ def test_device_is_preserved_if_exists(self):
+ res1 = yield self.handler.check_device_registered(
+ user_id="boris",
+ device_id="fco",
+ initial_device_display_name="display name"
+ )
+ self.assertEqual(res1, "fco")
+
+ res2 = yield self.handler.check_device_registered(
+ user_id="boris",
+ device_id="fco",
+ initial_device_display_name="new display name"
+ )
+ self.assertEqual(res2, "fco")
+
+ dev = yield self.handler.store.get_device("boris", "fco")
+ self.assertEqual(dev["display_name"], "display name")
+
+ @defer.inlineCallbacks
+ def test_device_id_is_made_up_if_unspecified(self):
+ device_id = yield self.handler.check_device_registered(
+ user_id="theresa",
+ device_id=None,
+ initial_device_display_name="display"
+ )
+
+ dev = yield self.handler.store.get_device("theresa", device_id)
+ self.assertEqual(dev["display_name"], "display")
+
+ @defer.inlineCallbacks
+ def test_get_devices_by_user(self):
+ yield self._record_users()
+
+ res = yield self.handler.get_devices_by_user(user1)
+ self.assertEqual(3, len(res))
+ device_map = {
+ d["device_id"]: d for d in res
+ }
+ self.assertDictContainsSubset({
+ "user_id": user1,
+ "device_id": "xyz",
+ "display_name": "display 0",
+ "last_seen_ip": None,
+ "last_seen_ts": None,
+ }, device_map["xyz"])
+ self.assertDictContainsSubset({
+ "user_id": user1,
+ "device_id": "fco",
+ "display_name": "display 1",
+ "last_seen_ip": "ip1",
+ "last_seen_ts": 1000000,
+ }, device_map["fco"])
+ self.assertDictContainsSubset({
+ "user_id": user1,
+ "device_id": "abc",
+ "display_name": "display 2",
+ "last_seen_ip": "ip3",
+ "last_seen_ts": 3000000,
+ }, device_map["abc"])
+
+ @defer.inlineCallbacks
+ def test_get_device(self):
+ yield self._record_users()
+
+ res = yield self.handler.get_device(user1, "abc")
+ self.assertDictContainsSubset({
+ "user_id": user1,
+ "device_id": "abc",
+ "display_name": "display 2",
+ "last_seen_ip": "ip3",
+ "last_seen_ts": 3000000,
+ }, res)
+
+ @defer.inlineCallbacks
+ def test_delete_device(self):
+ yield self._record_users()
+
+ # delete the device
+ yield self.handler.delete_device(user1, "abc")
+
+ # check the device was deleted
+ with self.assertRaises(synapse.api.errors.NotFoundError):
+ yield self.handler.get_device(user1, "abc")
+
+ # we'd like to check the access token was invalidated, but that's a
+ # bit of a PITA.
+
+ @defer.inlineCallbacks
+ def test_update_device(self):
+ yield self._record_users()
+
+ update = {"display_name": "new display"}
+ yield self.handler.update_device(user1, "abc", update)
+
+ res = yield self.handler.get_device(user1, "abc")
+ self.assertEqual(res["display_name"], "new display")
+
+ @defer.inlineCallbacks
+ def test_update_unknown_device(self):
+ update = {"display_name": "new_display"}
+ with self.assertRaises(synapse.api.errors.NotFoundError):
+ yield self.handler.update_device("user_id", "unknown_device_id",
+ update)
+
+ @defer.inlineCallbacks
+ def _record_users(self):
+ # check this works for both devices which have a recorded client_ip,
+ # and those which don't.
+ yield self._record_user(user1, "xyz", "display 0")
+ yield self._record_user(user1, "fco", "display 1", "token1", "ip1")
+ yield self._record_user(user1, "abc", "display 2", "token2", "ip2")
+ yield self._record_user(user1, "abc", "display 2", "token3", "ip3")
+
+ yield self._record_user(user2, "def", "dispkay", "token4", "ip4")
+
+ @defer.inlineCallbacks
+ def _record_user(self, user_id, device_id, display_name,
+ access_token=None, ip=None):
+ device_id = yield self.handler.check_device_registered(
+ user_id=user_id,
+ device_id=device_id,
+ initial_device_display_name=display_name
+ )
+
+ if ip is not None:
+ yield self.store.insert_client_ip(
+ types.UserID.from_string(user_id),
+ access_token, ip, "user_agent", device_id)
+ self.clock.advance_time(1000)
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
new file mode 100644
index 0000000000..878a54dc34
--- /dev/null
+++ b/tests/handlers/test_e2e_keys.py
@@ -0,0 +1,46 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import mock
+from twisted.internet import defer
+
+import synapse.api.errors
+import synapse.handlers.e2e_keys
+
+import synapse.storage
+from tests import unittest, utils
+
+
+class E2eKeysHandlerTestCase(unittest.TestCase):
+ def __init__(self, *args, **kwargs):
+ super(E2eKeysHandlerTestCase, self).__init__(*args, **kwargs)
+ self.hs = None # type: synapse.server.HomeServer
+ self.handler = None # type: synapse.handlers.e2e_keys.E2eKeysHandler
+
+ @defer.inlineCallbacks
+ def setUp(self):
+ self.hs = yield utils.setup_test_homeserver(
+ handlers=None,
+ replication_layer=mock.Mock(),
+ )
+ self.handler = synapse.handlers.e2e_keys.E2eKeysHandler(self.hs)
+
+ @defer.inlineCallbacks
+ def test_query_local_devices_no_devices(self):
+ """If the user has no devices, we expect an empty list.
+ """
+ local_user = "@boris:" + self.hs.hostname
+ res = yield self.handler.query_local_devices({local_user: None})
+ self.assertDictEqual(res, {local_user: {}})
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index 4f2c14e4ff..f1f664275f 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -19,11 +19,12 @@ from twisted.internet import defer
from mock import Mock, NonCallableMock
+import synapse.types
from synapse.api.errors import AuthError
from synapse.handlers.profile import ProfileHandler
from synapse.types import UserID
-from tests.utils import setup_test_homeserver, requester_for_user
+from tests.utils import setup_test_homeserver
class ProfileHandlers(object):
@@ -86,7 +87,7 @@ class ProfileTestCase(unittest.TestCase):
def test_set_my_name(self):
yield self.handler.set_displayname(
self.frank,
- requester_for_user(self.frank),
+ synapse.types.create_requester(self.frank),
"Frank Jr."
)
@@ -99,7 +100,7 @@ class ProfileTestCase(unittest.TestCase):
def test_set_my_name_noauth(self):
d = self.handler.set_displayname(
self.frank,
- requester_for_user(self.bob),
+ synapse.types.create_requester(self.bob),
"Frank Jr."
)
@@ -144,7 +145,8 @@ class ProfileTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_set_my_avatar(self):
yield self.handler.set_avatar_url(
- self.frank, requester_for_user(self.frank), "http://my.server/pic.gif"
+ self.frank, synapse.types.create_requester(self.frank),
+ "http://my.server/pic.gif"
)
self.assertEquals(
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index 69a5e5b1d4..a7de3c7c17 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -42,12 +42,12 @@ class RegistrationTestCase(unittest.TestCase):
http_client=None,
expire_access_token=True)
self.auth_handler = Mock(
- generate_short_term_login_token=Mock(return_value='secret'))
+ generate_access_token=Mock(return_value='secret'))
self.hs.handlers = RegistrationHandlers(self.hs)
self.handler = self.hs.get_handlers().registration_handler
self.hs.get_handlers().profile_handler = Mock()
self.mock_handler = Mock(spec=[
- "generate_short_term_login_token",
+ "generate_access_token",
])
self.hs.get_auth_handler = Mock(return_value=self.auth_handler)
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index 17587fda00..f33e6f60fb 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -59,47 +59,6 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
[unpatch() for unpatch in self.unpatches]
@defer.inlineCallbacks
- def test_room_name_and_aliases(self):
- create = yield self.persist(type="m.room.create", key="", creator=USER_ID)
- yield self.persist(type="m.room.member", key=USER_ID, membership="join")
- yield self.persist(type="m.room.name", key="", name="name1")
- yield self.persist(
- type="m.room.aliases", key="blue", aliases=["#1:blue"]
- )
- yield self.replicate()
- yield self.check(
- "get_room_name_and_aliases", (ROOM_ID,), ("name1", ["#1:blue"])
- )
-
- # Set the room name.
- yield self.persist(type="m.room.name", key="", name="name2")
- yield self.replicate()
- yield self.check(
- "get_room_name_and_aliases", (ROOM_ID,), ("name2", ["#1:blue"])
- )
-
- # Set the room aliases.
- yield self.persist(
- type="m.room.aliases", key="blue", aliases=["#2:blue"]
- )
- yield self.replicate()
- yield self.check(
- "get_room_name_and_aliases", (ROOM_ID,), ("name2", ["#2:blue"])
- )
-
- # Leave and join the room clobbering the state.
- yield self.persist(type="m.room.member", key=USER_ID, membership="leave")
- yield self.persist(
- type="m.room.member", key=USER_ID, membership="join",
- reset_state=[create]
- )
- yield self.replicate()
-
- yield self.check(
- "get_room_name_and_aliases", (ROOM_ID,), (None, [])
- )
-
- @defer.inlineCallbacks
def test_room_members(self):
create = yield self.persist(type="m.room.create", key="", creator=USER_ID)
yield self.replicate()
diff --git a/tests/replication/test_resource.py b/tests/replication/test_resource.py
index 842e3d29d7..e70ac6f14d 100644
--- a/tests/replication/test_resource.py
+++ b/tests/replication/test_resource.py
@@ -13,15 +13,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.replication.resource import ReplicationResource
-from synapse.types import Requester, UserID
+import contextlib
+import json
+from mock import Mock, NonCallableMock
from twisted.internet import defer
+
+import synapse.types
+from synapse.replication.resource import ReplicationResource
+from synapse.types import UserID
from tests import unittest
-from tests.utils import setup_test_homeserver, requester_for_user
-from mock import Mock, NonCallableMock
-import json
-import contextlib
+from tests.utils import setup_test_homeserver
class ReplicationResourceCase(unittest.TestCase):
@@ -61,7 +63,7 @@ class ReplicationResourceCase(unittest.TestCase):
def test_events_and_state(self):
get = self.get(events="-1", state="-1", timeout="0")
yield self.hs.get_handlers().room_creation_handler.create_room(
- Requester(self.user, "", False), {}
+ synapse.types.create_requester(self.user), {}
)
code, body = yield get
self.assertEquals(code, 200)
@@ -144,7 +146,7 @@ class ReplicationResourceCase(unittest.TestCase):
def send_text_message(self, room_id, message):
handler = self.hs.get_handlers().message_handler
event = yield handler.create_and_send_nonmember_event(
- requester_for_user(self.user),
+ synapse.types.create_requester(self.user),
{
"type": "m.room.message",
"content": {"body": "message", "msgtype": "m.text"},
@@ -157,7 +159,7 @@ class ReplicationResourceCase(unittest.TestCase):
@defer.inlineCallbacks
def create_room(self):
result = yield self.hs.get_handlers().room_creation_handler.create_room(
- Requester(self.user, "", False), {}
+ synapse.types.create_requester(self.user), {}
)
defer.returnValue(result["room_id"])
diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py
index af02fce8fb..1e95e97538 100644
--- a/tests/rest/client/v1/test_profile.py
+++ b/tests/rest/client/v1/test_profile.py
@@ -14,17 +14,14 @@
# limitations under the License.
"""Tests REST events for /profile paths."""
-from tests import unittest
-from twisted.internet import defer
-
from mock import Mock
+from twisted.internet import defer
-from ....utils import MockHttpResource, setup_test_homeserver
-
+import synapse.types
from synapse.api.errors import SynapseError, AuthError
-from synapse.types import Requester, UserID
-
from synapse.rest.client.v1 import profile
+from tests import unittest
+from ....utils import MockHttpResource, setup_test_homeserver
myid = "@1234ABCD:test"
PATH_PREFIX = "/_matrix/client/api/v1"
@@ -52,7 +49,7 @@ class ProfileTestCase(unittest.TestCase):
)
def _get_user_by_req(request=None, allow_guest=False):
- return Requester(UserID.from_string(myid), "", False)
+ return synapse.types.create_requester(myid)
hs.get_v1auth().get_user_by_req = _get_user_by_req
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index cda0a2b27c..8ac56a1fb2 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -30,6 +30,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
self.registration_handler = Mock()
self.identity_handler = Mock()
self.login_handler = Mock()
+ self.device_handler = Mock()
# do the dance to hook it up to the hs global
self.handlers = Mock(
@@ -42,6 +43,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
self.hs.get_auth = Mock(return_value=self.auth)
self.hs.get_handlers = Mock(return_value=self.handlers)
self.hs.get_auth_handler = Mock(return_value=self.auth_handler)
+ self.hs.get_device_handler = Mock(return_value=self.device_handler)
self.hs.config.enable_registration = True
# init the thing we're testing
@@ -61,13 +63,18 @@ class RegisterRestServletTestCase(unittest.TestCase):
"id": "1234"
}
self.registration_handler.appservice_register = Mock(
- return_value=(user_id, token)
+ return_value=user_id
)
+ self.auth_handler.get_login_tuple_for_user_id = Mock(
+ return_value=(token, "kermits_refresh_token")
+ )
+
(code, result) = yield self.servlet.on_POST(self.request)
self.assertEquals(code, 200)
det_data = {
"user_id": user_id,
"access_token": token,
+ "refresh_token": "kermits_refresh_token",
"home_server": self.hs.hostname
}
self.assertDictContainsSubset(det_data, result)
@@ -105,26 +112,37 @@ class RegisterRestServletTestCase(unittest.TestCase):
def test_POST_user_valid(self):
user_id = "@kermit:muppet"
token = "kermits_access_token"
+ device_id = "frogfone"
self.request_data = json.dumps({
"username": "kermit",
- "password": "monkey"
+ "password": "monkey",
+ "device_id": device_id,
})
self.registration_handler.check_username = Mock(return_value=True)
self.auth_result = (True, None, {
"username": "kermit",
"password": "monkey"
}, None)
- self.registration_handler.register = Mock(return_value=(user_id, token))
+ self.registration_handler.register = Mock(return_value=(user_id, None))
+ self.auth_handler.get_login_tuple_for_user_id = Mock(
+ return_value=(token, "kermits_refresh_token")
+ )
+ self.device_handler.check_device_registered = \
+ Mock(return_value=device_id)
(code, result) = yield self.servlet.on_POST(self.request)
self.assertEquals(code, 200)
det_data = {
"user_id": user_id,
"access_token": token,
- "home_server": self.hs.hostname
+ "refresh_token": "kermits_refresh_token",
+ "home_server": self.hs.hostname,
+ "device_id": device_id,
}
self.assertDictContainsSubset(det_data, result)
self.assertIn("refresh_token", result)
+ self.auth_handler.get_login_tuple_for_user_id(
+ user_id, device_id=device_id, initial_device_display_name=None)
def test_POST_disabled_registration(self):
self.hs.config.enable_registration = False
diff --git a/tests/storage/event_injector.py b/tests/storage/event_injector.py
index f22ba8db89..38556da9a7 100644
--- a/tests/storage/event_injector.py
+++ b/tests/storage/event_injector.py
@@ -30,6 +30,7 @@ class EventInjector:
def create_room(self, room):
builder = self.event_builder_factory.new({
"type": EventTypes.Create,
+ "sender": "",
"room_id": room.to_string(),
"content": {},
})
diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py
index 6e4d9b1373..1286b4ce2d 100644
--- a/tests/storage/test_background_update.py
+++ b/tests/storage/test_background_update.py
@@ -10,7 +10,7 @@ class BackgroundUpdateTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
- hs = yield setup_test_homeserver()
+ hs = yield setup_test_homeserver() # type: synapse.server.HomeServer
self.store = hs.get_datastore()
self.clock = hs.get_clock()
@@ -20,11 +20,20 @@ class BackgroundUpdateTestCase(unittest.TestCase):
"test_update", self.update_handler
)
+ # run the real background updates, to get them out the way
+ # (perhaps we should run them as part of the test HS setup, since we
+ # run all of the other schema setup stuff there?)
+ while True:
+ res = yield self.store.do_next_background_update(1000)
+ if res is None:
+ break
+
@defer.inlineCallbacks
def test_do_background_update(self):
desired_count = 1000
duration_ms = 42
+ # first step: make a bit of progress
@defer.inlineCallbacks
def update(progress, count):
self.clock.advance_time_msec(count * duration_ms)
@@ -42,7 +51,7 @@ class BackgroundUpdateTestCase(unittest.TestCase):
yield self.store.start_background_update("test_update", {"my_key": 1})
self.update_handler.reset_mock()
- result = yield self.store.do_background_update(
+ result = yield self.store.do_next_background_update(
duration_ms * desired_count
)
self.assertIsNotNone(result)
@@ -50,15 +59,15 @@ class BackgroundUpdateTestCase(unittest.TestCase):
{"my_key": 1}, self.store.DEFAULT_BACKGROUND_BATCH_SIZE
)
+ # second step: complete the update
@defer.inlineCallbacks
def update(progress, count):
yield self.store._end_background_update("test_update")
defer.returnValue(count)
self.update_handler.side_effect = update
-
self.update_handler.reset_mock()
- result = yield self.store.do_background_update(
+ result = yield self.store.do_next_background_update(
duration_ms * desired_count
)
self.assertIsNotNone(result)
@@ -66,8 +75,9 @@ class BackgroundUpdateTestCase(unittest.TestCase):
{"my_key": 2}, desired_count
)
+ # third step: we don't expect to be called any more
self.update_handler.reset_mock()
- result = yield self.store.do_background_update(
+ result = yield self.store.do_next_background_update(
duration_ms * desired_count
)
self.assertIsNone(result)
diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
new file mode 100644
index 0000000000..1f0c0e7c37
--- /dev/null
+++ b/tests/storage/test_client_ips.py
@@ -0,0 +1,62 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+import synapse.server
+import synapse.storage
+import synapse.types
+import tests.unittest
+import tests.utils
+
+
+class ClientIpStoreTestCase(tests.unittest.TestCase):
+ def __init__(self, *args, **kwargs):
+ super(ClientIpStoreTestCase, self).__init__(*args, **kwargs)
+ self.store = None # type: synapse.storage.DataStore
+ self.clock = None # type: tests.utils.MockClock
+
+ @defer.inlineCallbacks
+ def setUp(self):
+ hs = yield tests.utils.setup_test_homeserver()
+ self.store = hs.get_datastore()
+ self.clock = hs.get_clock()
+
+ @defer.inlineCallbacks
+ def test_insert_new_client_ip(self):
+ self.clock.now = 12345678
+ user_id = "@user:id"
+ yield self.store.insert_client_ip(
+ synapse.types.UserID.from_string(user_id),
+ "access_token", "ip", "user_agent", "device_id",
+ )
+
+ # deliberately use an iterable here to make sure that the lookup
+ # method doesn't iterate it twice
+ device_list = iter(((user_id, "device_id"),))
+ result = yield self.store.get_last_client_ip_by_device(device_list)
+
+ r = result[(user_id, "device_id")]
+ self.assertDictContainsSubset(
+ {
+ "user_id": user_id,
+ "device_id": "device_id",
+ "access_token": "access_token",
+ "ip": "ip",
+ "user_agent": "user_agent",
+ "last_seen": 12345678000,
+ },
+ r
+ )
diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py
new file mode 100644
index 0000000000..f8725acea0
--- /dev/null
+++ b/tests/storage/test_devices.py
@@ -0,0 +1,105 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+import synapse.api.errors
+import tests.unittest
+import tests.utils
+
+
+class DeviceStoreTestCase(tests.unittest.TestCase):
+ def __init__(self, *args, **kwargs):
+ super(DeviceStoreTestCase, self).__init__(*args, **kwargs)
+ self.store = None # type: synapse.storage.DataStore
+
+ @defer.inlineCallbacks
+ def setUp(self):
+ hs = yield tests.utils.setup_test_homeserver()
+
+ self.store = hs.get_datastore()
+
+ @defer.inlineCallbacks
+ def test_store_new_device(self):
+ yield self.store.store_device(
+ "user_id", "device_id", "display_name"
+ )
+
+ res = yield self.store.get_device("user_id", "device_id")
+ self.assertDictContainsSubset({
+ "user_id": "user_id",
+ "device_id": "device_id",
+ "display_name": "display_name",
+ }, res)
+
+ @defer.inlineCallbacks
+ def test_get_devices_by_user(self):
+ yield self.store.store_device(
+ "user_id", "device1", "display_name 1"
+ )
+ yield self.store.store_device(
+ "user_id", "device2", "display_name 2"
+ )
+ yield self.store.store_device(
+ "user_id2", "device3", "display_name 3"
+ )
+
+ res = yield self.store.get_devices_by_user("user_id")
+ self.assertEqual(2, len(res.keys()))
+ self.assertDictContainsSubset({
+ "user_id": "user_id",
+ "device_id": "device1",
+ "display_name": "display_name 1",
+ }, res["device1"])
+ self.assertDictContainsSubset({
+ "user_id": "user_id",
+ "device_id": "device2",
+ "display_name": "display_name 2",
+ }, res["device2"])
+
+ @defer.inlineCallbacks
+ def test_update_device(self):
+ yield self.store.store_device(
+ "user_id", "device_id", "display_name 1"
+ )
+
+ res = yield self.store.get_device("user_id", "device_id")
+ self.assertEqual("display_name 1", res["display_name"])
+
+ # do a no-op first
+ yield self.store.update_device(
+ "user_id", "device_id",
+ )
+ res = yield self.store.get_device("user_id", "device_id")
+ self.assertEqual("display_name 1", res["display_name"])
+
+ # do the update
+ yield self.store.update_device(
+ "user_id", "device_id",
+ new_display_name="display_name 2",
+ )
+
+ # check it worked
+ res = yield self.store.get_device("user_id", "device_id")
+ self.assertEqual("display_name 2", res["display_name"])
+
+ @defer.inlineCallbacks
+ def test_update_unknown_device(self):
+ with self.assertRaises(synapse.api.errors.StoreError) as cm:
+ yield self.store.update_device(
+ "user_id", "unknown_device_id",
+ new_display_name="display_name 2",
+ )
+ self.assertEqual(404, cm.exception.code)
diff --git a/tests/storage/test_end_to_end_keys.py b/tests/storage/test_end_to_end_keys.py
new file mode 100644
index 0000000000..453bc61438
--- /dev/null
+++ b/tests/storage/test_end_to_end_keys.py
@@ -0,0 +1,90 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+import tests.unittest
+import tests.utils
+
+
+class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
+ def __init__(self, *args, **kwargs):
+ super(EndToEndKeyStoreTestCase, self).__init__(*args, **kwargs)
+ self.store = None # type: synapse.storage.DataStore
+
+ @defer.inlineCallbacks
+ def setUp(self):
+ hs = yield tests.utils.setup_test_homeserver()
+
+ self.store = hs.get_datastore()
+
+ @defer.inlineCallbacks
+ def test_key_without_device_name(self):
+ now = 1470174257070
+ json = '{ "key": "value" }'
+
+ yield self.store.set_e2e_device_keys(
+ "user", "device", now, json)
+
+ res = yield self.store.get_e2e_device_keys((("user", "device"),))
+ self.assertIn("user", res)
+ self.assertIn("device", res["user"])
+ dev = res["user"]["device"]
+ self.assertDictContainsSubset({
+ "key_json": json,
+ "device_display_name": None,
+ }, dev)
+
+ @defer.inlineCallbacks
+ def test_get_key_with_device_name(self):
+ now = 1470174257070
+ json = '{ "key": "value" }'
+
+ yield self.store.set_e2e_device_keys(
+ "user", "device", now, json)
+ yield self.store.store_device(
+ "user", "device", "display_name"
+ )
+
+ res = yield self.store.get_e2e_device_keys((("user", "device"),))
+ self.assertIn("user", res)
+ self.assertIn("device", res["user"])
+ dev = res["user"]["device"]
+ self.assertDictContainsSubset({
+ "key_json": json,
+ "device_display_name": "display_name",
+ }, dev)
+
+ @defer.inlineCallbacks
+ def test_multiple_devices(self):
+ now = 1470174257070
+
+ yield self.store.set_e2e_device_keys(
+ "user1", "device1", now, 'json11')
+ yield self.store.set_e2e_device_keys(
+ "user1", "device2", now, 'json12')
+ yield self.store.set_e2e_device_keys(
+ "user2", "device1", now, 'json21')
+ yield self.store.set_e2e_device_keys(
+ "user2", "device2", now, 'json22')
+
+ res = yield self.store.get_e2e_device_keys((("user1", "device1"),
+ ("user2", "device2")))
+ self.assertIn("user1", res)
+ self.assertIn("device1", res["user1"])
+ self.assertNotIn("device2", res["user1"])
+ self.assertIn("user2", res)
+ self.assertNotIn("device1", res["user2"])
+ self.assertIn("device2", res["user2"])
diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py
new file mode 100644
index 0000000000..e9044afa2e
--- /dev/null
+++ b/tests/storage/test_event_push_actions.py
@@ -0,0 +1,41 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+import tests.unittest
+import tests.utils
+
+USER_ID = "@user:example.com"
+
+
+class EventPushActionsStoreTestCase(tests.unittest.TestCase):
+
+ @defer.inlineCallbacks
+ def setUp(self):
+ hs = yield tests.utils.setup_test_homeserver()
+ self.store = hs.get_datastore()
+
+ @defer.inlineCallbacks
+ def test_get_unread_push_actions_for_user_in_range_for_http(self):
+ yield self.store.get_unread_push_actions_for_user_in_range_for_http(
+ USER_ID, 0, 1000, 20
+ )
+
+ @defer.inlineCallbacks
+ def test_get_unread_push_actions_for_user_in_range_for_email(self):
+ yield self.store.get_unread_push_actions_for_user_in_range_for_email(
+ USER_ID, 0, 1000, 20
+ )
diff --git a/tests/storage/test_events.py b/tests/storage/test_events.py
index 18a6cff0c7..3762b38e37 100644
--- a/tests/storage/test_events.py
+++ b/tests/storage/test_events.py
@@ -37,7 +37,7 @@ class EventsStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_count_daily_messages(self):
- self.db_pool.runQuery("DELETE FROM stats_reporting")
+ yield self.db_pool.runQuery("DELETE FROM stats_reporting")
self.hs.clock.now = 100
@@ -60,7 +60,7 @@ class EventsStoreTestCase(unittest.TestCase):
# it isn't old enough.
count = yield self.store.count_daily_messages()
self.assertIsNone(count)
- self._assert_stats_reporting(1, self.hs.clock.now)
+ yield self._assert_stats_reporting(1, self.hs.clock.now)
# Already reported yesterday, two new events from today.
yield self.event_injector.inject_message(room, user, "Yeah they are!")
@@ -68,21 +68,21 @@ class EventsStoreTestCase(unittest.TestCase):
self.hs.clock.now += 60 * 60 * 24
count = yield self.store.count_daily_messages()
self.assertEqual(2, count) # 2 since yesterday
- self._assert_stats_reporting(3, self.hs.clock.now) # 3 ever
+ yield self._assert_stats_reporting(3, self.hs.clock.now) # 3 ever
# Last reported too recently.
yield self.event_injector.inject_message(room, user, "Who could disagree?")
self.hs.clock.now += 60 * 60 * 22
count = yield self.store.count_daily_messages()
self.assertIsNone(count)
- self._assert_stats_reporting(4, self.hs.clock.now)
+ yield self._assert_stats_reporting(4, self.hs.clock.now)
# Last reported too long ago
yield self.event_injector.inject_message(room, user, "No one.")
self.hs.clock.now += 60 * 60 * 26
count = yield self.store.count_daily_messages()
self.assertIsNone(count)
- self._assert_stats_reporting(5, self.hs.clock.now)
+ yield self._assert_stats_reporting(5, self.hs.clock.now)
# And now let's actually report something
yield self.event_injector.inject_message(room, user, "Indeed.")
@@ -92,7 +92,7 @@ class EventsStoreTestCase(unittest.TestCase):
self.hs.clock.now += (60 * 60 * 24) + 50
count = yield self.store.count_daily_messages()
self.assertEqual(3, count)
- self._assert_stats_reporting(8, self.hs.clock.now)
+ yield self._assert_stats_reporting(8, self.hs.clock.now)
@defer.inlineCallbacks
def _get_last_stream_token(self):
diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
index b8384c98d8..f7d74dea8e 100644
--- a/tests/storage/test_registration.py
+++ b/tests/storage/test_registration.py
@@ -38,6 +38,7 @@ class RegistrationStoreTestCase(unittest.TestCase):
"BcDeFgHiJkLmNoPqRsTuVwXyZa"
]
self.pwhash = "{xx1}123456789"
+ self.device_id = "akgjhdjklgshg"
@defer.inlineCallbacks
def test_register(self):
@@ -64,13 +65,15 @@ class RegistrationStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_add_tokens(self):
yield self.store.register(self.user_id, self.tokens[0], self.pwhash)
- yield self.store.add_access_token_to_user(self.user_id, self.tokens[1])
+ yield self.store.add_access_token_to_user(self.user_id, self.tokens[1],
+ self.device_id)
result = yield self.store.get_user_by_access_token(self.tokens[1])
self.assertDictContainsSubset(
{
"name": self.user_id,
+ "device_id": self.device_id,
},
result
)
@@ -80,20 +83,24 @@ class RegistrationStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_exchange_refresh_token_valid(self):
uid = stringutils.random_string(32)
+ device_id = stringutils.random_string(16)
generator = TokenGenerator()
last_token = generator.generate(uid)
self.db_pool.runQuery(
- "INSERT INTO refresh_tokens(user_id, token) VALUES(?,?)",
- (uid, last_token,))
+ "INSERT INTO refresh_tokens(user_id, token, device_id) "
+ "VALUES(?,?,?)",
+ (uid, last_token, device_id))
- (found_user_id, refresh_token) = yield self.store.exchange_refresh_token(
- last_token, generator.generate)
+ (found_user_id, refresh_token, device_id) = \
+ yield self.store.exchange_refresh_token(last_token,
+ generator.generate)
self.assertEqual(uid, found_user_id)
rows = yield self.db_pool.runQuery(
- "SELECT token FROM refresh_tokens WHERE user_id = ?", (uid, ))
- self.assertEqual([(refresh_token,)], rows)
+ "SELECT token, device_id FROM refresh_tokens WHERE user_id = ?",
+ (uid, ))
+ self.assertEqual([(refresh_token, device_id)], rows)
# We issued token 1, then exchanged it for token 2
expected_refresh_token = u"%s-%d" % (uid, 2,)
self.assertEqual(expected_refresh_token, refresh_token)
@@ -121,6 +128,40 @@ class RegistrationStoreTestCase(unittest.TestCase):
with self.assertRaises(StoreError):
yield self.store.exchange_refresh_token(last_token, generator.generate)
+ @defer.inlineCallbacks
+ def test_user_delete_access_tokens(self):
+ # add some tokens
+ generator = TokenGenerator()
+ refresh_token = generator.generate(self.user_id)
+ yield self.store.register(self.user_id, self.tokens[0], self.pwhash)
+ yield self.store.add_access_token_to_user(self.user_id, self.tokens[1],
+ self.device_id)
+ yield self.store.add_refresh_token_to_user(self.user_id, refresh_token,
+ self.device_id)
+
+ # now delete some
+ yield self.store.user_delete_access_tokens(
+ self.user_id, device_id=self.device_id, delete_refresh_tokens=True)
+
+ # check they were deleted
+ user = yield self.store.get_user_by_access_token(self.tokens[1])
+ self.assertIsNone(user, "access token was not deleted by device_id")
+ with self.assertRaises(StoreError):
+ yield self.store.exchange_refresh_token(refresh_token,
+ generator.generate)
+
+ # check the one not associated with the device was not deleted
+ user = yield self.store.get_user_by_access_token(self.tokens[0])
+ self.assertEqual(self.user_id, user["name"])
+
+ # now delete the rest
+ yield self.store.user_delete_access_tokens(
+ self.user_id, delete_refresh_tokens=True)
+
+ user = yield self.store.get_user_by_access_token(self.tokens[0])
+ self.assertIsNone(user,
+ "access token was not deleted without device_id")
+
class TokenGenerator:
def __init__(self):
diff --git a/tests/test_preview.py b/tests/test_preview.py
new file mode 100644
index 0000000000..2a801173a0
--- /dev/null
+++ b/tests/test_preview.py
@@ -0,0 +1,139 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014-2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from . import unittest
+
+from synapse.rest.media.v1.preview_url_resource import summarize_paragraphs
+
+
+class PreviewTestCase(unittest.TestCase):
+
+ def test_long_summarize(self):
+ example_paras = [
+ """Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:
+ Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in
+ Troms county, Norway. The administrative centre of the municipality is
+ the city of Tromsø. Outside of Norway, Tromso and Tromsö are
+ alternative spellings of the city.Tromsø is considered the northernmost
+ city in the world with a population above 50,000. The most populous town
+ north of it is Alta, Norway, with a population of 14,272 (2013).""",
+
+ """Tromsø lies in Northern Norway. The municipality has a population of
+ (2015) 72,066, but with an annual influx of students it has over 75,000
+ most of the year. It is the largest urban area in Northern Norway and the
+ third largest north of the Arctic Circle (following Murmansk and Norilsk).
+ Most of Tromsø, including the city centre, is located on the island of
+ Tromsøya, 350 kilometres (217 mi) north of the Arctic Circle. In 2012,
+ Tromsøya had a population of 36,088. Substantial parts of the urban area
+ are also situated on the mainland to the east, and on parts of Kvaløya—a
+ large island to the west. Tromsøya is connected to the mainland by the Tromsø
+ Bridge and the Tromsøysund Tunnel, and to the island of Kvaløya by the
+ Sandnessund Bridge. Tromsø Airport connects the city to many destinations
+ in Europe. The city is warmer than most other places located on the same
+ latitude, due to the warming effect of the Gulf Stream.""",
+
+ """The city centre of Tromsø contains the highest number of old wooden
+ houses in Northern Norway, the oldest house dating from 1789. The Arctic
+ Cathedral, a modern church from 1965, is probably the most famous landmark
+ in Tromsø. The city is a cultural centre for its region, with several
+ festivals taking place in the summer. Some of Norway's best-known
+ musicians, Torbjørn Brundtland and Svein Berge of the electronica duo
+ Röyksopp and Lene Marlin grew up and started their careers in Tromsø.
+ Noted electronic musician Geir Jenssen also hails from Tromsø.""",
+ ]
+
+ desc = summarize_paragraphs(example_paras, min_size=200, max_size=500)
+
+ self.assertEquals(
+ desc,
+ "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:"
+ " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in"
+ " Troms county, Norway. The administrative centre of the municipality is"
+ " the city of Tromsø. Outside of Norway, Tromso and Tromsö are"
+ " alternative spellings of the city.Tromsø is considered the northernmost"
+ " city in the world with a population above 50,000. The most populous town"
+ " north of it is Alta, Norway, with a population of 14,272 (2013)."
+ )
+
+ desc = summarize_paragraphs(example_paras[1:], min_size=200, max_size=500)
+
+ self.assertEquals(
+ desc,
+ "Tromsø lies in Northern Norway. The municipality has a population of"
+ " (2015) 72,066, but with an annual influx of students it has over 75,000"
+ " most of the year. It is the largest urban area in Northern Norway and the"
+ " third largest north of the Arctic Circle (following Murmansk and Norilsk)."
+ " Most of Tromsø, including the city centre, is located on the island of"
+ " Tromsøya, 350 kilometres (217 mi) north of the Arctic Circle. In 2012,"
+ " Tromsøya had a population of 36,088. Substantial parts of the…"
+ )
+
+ def test_short_summarize(self):
+ example_paras = [
+ "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:"
+ " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in"
+ " Troms county, Norway.",
+
+ "Tromsø lies in Northern Norway. The municipality has a population of"
+ " (2015) 72,066, but with an annual influx of students it has over 75,000"
+ " most of the year.",
+
+ "The city centre of Tromsø contains the highest number of old wooden"
+ " houses in Northern Norway, the oldest house dating from 1789. The Arctic"
+ " Cathedral, a modern church from 1965, is probably the most famous landmark"
+ " in Tromsø.",
+ ]
+
+ desc = summarize_paragraphs(example_paras, min_size=200, max_size=500)
+
+ self.assertEquals(
+ desc,
+ "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:"
+ " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in"
+ " Troms county, Norway.\n"
+ "\n"
+ "Tromsø lies in Northern Norway. The municipality has a population of"
+ " (2015) 72,066, but with an annual influx of students it has over 75,000"
+ " most of the year."
+ )
+
+ def test_small_then_large_summarize(self):
+ example_paras = [
+ "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:"
+ " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in"
+ " Troms county, Norway.",
+
+ "Tromsø lies in Northern Norway. The municipality has a population of"
+ " (2015) 72,066, but with an annual influx of students it has over 75,000"
+ " most of the year."
+ " The city centre of Tromsø contains the highest number of old wooden"
+ " houses in Northern Norway, the oldest house dating from 1789. The Arctic"
+ " Cathedral, a modern church from 1965, is probably the most famous landmark"
+ " in Tromsø.",
+ ]
+
+ desc = summarize_paragraphs(example_paras, min_size=200, max_size=500)
+ self.assertEquals(
+ desc,
+ "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:"
+ " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in"
+ " Troms county, Norway.\n"
+ "\n"
+ "Tromsø lies in Northern Norway. The municipality has a population of"
+ " (2015) 72,066, but with an annual influx of students it has over 75,000"
+ " most of the year. The city centre of Tromsø contains the highest number"
+ " of old wooden houses in Northern Norway, the oldest house dating from"
+ " 1789. The Arctic Cathedral, a modern church…"
+ )
diff --git a/tests/unittest.py b/tests/unittest.py
index 5b22abfe74..38715972dd 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -17,13 +17,18 @@ from twisted.trial import unittest
import logging
-
# logging doesn't have a "don't log anything at all EVARRRR setting,
# but since the highest value is 50, 1000000 should do ;)
NEVER = 1000000
-logging.getLogger().addHandler(logging.StreamHandler())
+handler = logging.StreamHandler()
+handler.setFormatter(logging.Formatter(
+ "%(levelname)s:%(name)s:%(message)s [%(pathname)s:%(lineno)d]"
+))
+logging.getLogger().addHandler(handler)
logging.getLogger().setLevel(NEVER)
+logging.getLogger("synapse.storage.SQL").setLevel(NEVER)
+logging.getLogger("synapse.storage.txn").setLevel(NEVER)
def around(target):
@@ -70,8 +75,6 @@ class TestCase(unittest.TestCase):
return ret
logging.getLogger().setLevel(level)
- # Don't set SQL logging
- logging.getLogger("synapse.storage").setLevel(old_level)
return orig()
def assertObjectHasAttributes(self, attrs, obj):
diff --git a/tests/util/test_rwlock.py b/tests/util/test_rwlock.py
new file mode 100644
index 0000000000..1d745ae1a7
--- /dev/null
+++ b/tests/util/test_rwlock.py
@@ -0,0 +1,85 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from tests import unittest
+
+from synapse.util.async import ReadWriteLock
+
+
+class ReadWriteLockTestCase(unittest.TestCase):
+
+ def _assert_called_before_not_after(self, lst, first_false):
+ for i, d in enumerate(lst[:first_false]):
+ self.assertTrue(d.called, msg="%d was unexpectedly false" % i)
+
+ for i, d in enumerate(lst[first_false:]):
+ self.assertFalse(
+ d.called, msg="%d was unexpectedly true" % (i + first_false)
+ )
+
+ def test_rwlock(self):
+ rwlock = ReadWriteLock()
+
+ key = object()
+
+ ds = [
+ rwlock.read(key), # 0
+ rwlock.read(key), # 1
+ rwlock.write(key), # 2
+ rwlock.write(key), # 3
+ rwlock.read(key), # 4
+ rwlock.read(key), # 5
+ rwlock.write(key), # 6
+ ]
+
+ self._assert_called_before_not_after(ds, 2)
+
+ with ds[0].result:
+ self._assert_called_before_not_after(ds, 2)
+ self._assert_called_before_not_after(ds, 2)
+
+ with ds[1].result:
+ self._assert_called_before_not_after(ds, 2)
+ self._assert_called_before_not_after(ds, 3)
+
+ with ds[2].result:
+ self._assert_called_before_not_after(ds, 3)
+ self._assert_called_before_not_after(ds, 4)
+
+ with ds[3].result:
+ self._assert_called_before_not_after(ds, 4)
+ self._assert_called_before_not_after(ds, 6)
+
+ with ds[5].result:
+ self._assert_called_before_not_after(ds, 6)
+ self._assert_called_before_not_after(ds, 6)
+
+ with ds[4].result:
+ self._assert_called_before_not_after(ds, 6)
+ self._assert_called_before_not_after(ds, 7)
+
+ with ds[6].result:
+ pass
+
+ d = rwlock.write(key)
+ self.assertTrue(d.called)
+ with d.result:
+ pass
+
+ d = rwlock.read(key)
+ self.assertTrue(d.called)
+ with d.result:
+ pass
diff --git a/tests/utils.py b/tests/utils.py
index 6e41ae1ff6..915b934e94 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -20,7 +20,6 @@ from synapse.storage.prepare_database import prepare_database
from synapse.storage.engines import create_engine
from synapse.server import HomeServer
from synapse.federation.transport import server
-from synapse.types import Requester
from synapse.util.ratelimitutils import FederationRateLimiter
from synapse.util.logcontext import LoggingContext
@@ -56,6 +55,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
config.use_frozen_dicts = True
config.database_config = {"name": "sqlite3"}
+ config.ldap_enabled = False
if "clock" not in kargs:
kargs["clock"] = MockClock()
@@ -511,7 +511,3 @@ class DeferredMockCallable(object):
"call(%s)" % _format_call(c[0], c[1]) for c in calls
])
)
-
-
-def requester_for_user(user):
- return Requester(user, None, False)
|