diff --git a/AUTHORS.rst b/AUTHORS.rst
index 3a457cd9fc..54ced67000 100644
--- a/AUTHORS.rst
+++ b/AUTHORS.rst
@@ -38,3 +38,10 @@ Brabo <brabo at riseup.net>
Ivan Shapovalov <intelfx100 at gmail.com>
* contrib/systemd: a sample systemd unit file and a logger configuration
+
+Eric Myhre <hash at exultant.us>
+ * Fix bug where ``media_store_path`` config option was ignored by v0 content
+ repository API.
+
+Muthu Subramanian <muthu.subramanian.karunanidhi at ericsson.com>
+ * Add SAML2 support for registration and logins.
diff --git a/CHANGES.rst b/CHANGES.rst
index 1ca2407a73..6a5fce899a 100644
--- a/CHANGES.rst
+++ b/CHANGES.rst
@@ -1,3 +1,54 @@
+Changes in synapse v0.9.3 (2015-07-01)
+======================================
+
+No changes from v0.9.3 Release Candidate 1.
+
+Changes in synapse v0.9.3-rc1 (2015-06-23)
+==========================================
+
+General:
+
+* Fix a memory leak in the notifier. (SYN-412)
+* Improve performance of room initial sync. (SYN-418)
+* General improvements to logging.
+* Remove ``access_token`` query params from ``INFO`` level logging.
+
+Configuration:
+
+* Add support for specifying and configuring multiple listeners. (SYN-389)
+
+Application services:
+
+* Fix bug where synapse failed to send user queries to application services.
+
+Changes in synapse v0.9.2-r2 (2015-06-15)
+=========================================
+
+Fix packaging so that schema delta python files get included in the package.
+
+Changes in synapse v0.9.2 (2015-06-12)
+======================================
+
+General:
+
+* Use ultrajson for json (de)serialisation when a canonical encoding is not
+ required. Ultrajson is significantly faster than simplejson in certain
+ circumstances.
+* Use connection pools for outgoing HTTP connections.
+* Process thumbnails on separate threads.
+
+Configuration:
+
+* Add option, ``gzip_responses``, to disable HTTP response compression.
+
+Federation:
+
+* Improve resilience of backfill by ensuring we fetch any missing auth events.
+* Improve performance of backfill and joining remote rooms by removing
+  unnecessary computations, such as re-handling events we had already
+  processed and attempting to compute the current state for outliers.
+
+
Changes in synapse v0.9.1 (2015-05-26)
======================================
diff --git a/MANIFEST.in b/MANIFEST.in
index 8243a942ee..a9b543af82 100644
--- a/MANIFEST.in
+++ b/MANIFEST.in
@@ -5,6 +5,7 @@ include *.rst
include demo/README
recursive-include synapse/storage/schema *.sql
+recursive-include synapse/storage/schema *.py
recursive-include demo *.dh
recursive-include demo *.py
diff --git a/README.rst b/README.rst
index 259fbaf459..5ff53f2df7 100644
--- a/README.rst
+++ b/README.rst
@@ -101,36 +101,40 @@ header files for python C extensions.
Installing prerequisites on Ubuntu or Debian::
- $ sudo apt-get install build-essential python2.7-dev libffi-dev \
- python-pip python-setuptools sqlite3 \
- libssl-dev python-virtualenv libjpeg-dev
+ sudo apt-get install build-essential python2.7-dev libffi-dev \
+ python-pip python-setuptools sqlite3 \
+ libssl-dev python-virtualenv libjpeg-dev
Installing prerequisites on ArchLinux::
- $ sudo pacman -S base-devel python2 python-pip \
- python-setuptools python-virtualenv sqlite3
+ sudo pacman -S base-devel python2 python-pip \
+ python-setuptools python-virtualenv sqlite3
Installing prerequisites on Mac OS X::
- $ xcode-select --install
- $ sudo pip install virtualenv
+ xcode-select --install
+ sudo easy_install pip
+ sudo pip install virtualenv
To install the synapse homeserver run::
- $ virtualenv -p python2.7 ~/.synapse
- $ source ~/.synapse/bin/activate
- $ pip install --process-dependency-links https://github.com/matrix-org/synapse/tarball/master
+ virtualenv -p python2.7 ~/.synapse
+ source ~/.synapse/bin/activate
+ pip install --process-dependency-links https://github.com/matrix-org/synapse/tarball/master
This installs synapse, along with the libraries it uses, into a virtual
-environment under ``~/.synapse``.
+environment under ``~/.synapse``. Feel free to pick a different directory
+if you prefer.
+
+In case of problems, please see the Troubleshooting_ section below.
Alternatively, Silvio Fricke has contributed a Dockerfile to automate the
above in Docker at https://registry.hub.docker.com/u/silviof/docker-matrix/.
To set up your homeserver, run (in your virtualenv, as before)::
- $ cd ~/.synapse
- $ python -m synapse.app.homeserver \
+ cd ~/.synapse
+ python -m synapse.app.homeserver \
--server-name machine.my.domain.name \
--config-path homeserver.yaml \
--generate-config
@@ -189,9 +193,9 @@ Running Synapse
To actually run your new homeserver, pick a working directory for Synapse to run
(e.g. ``~/.synapse``), and::
- $ cd ~/.synapse
- $ source ./bin/activate
- $ synctl start
+ cd ~/.synapse
+ source ./bin/activate
+ synctl start
Platform Specific Instructions
==============================
@@ -209,12 +213,12 @@ defaults to python 3, but synapse currently assumes python 2.7 by default:
pip may be outdated (6.0.7-1 and needs to be upgraded to 6.0.8-1 )::
- $ sudo pip2.7 install --upgrade pip
+ sudo pip2.7 install --upgrade pip
You also may need to explicitly specify python 2.7 again during the install
request::
- $ pip2.7 install --process-dependency-links \
+ pip2.7 install --process-dependency-links \
https://github.com/matrix-org/synapse/tarball/master
If you encounter an error with lib bcrypt causing an Wrong ELF Class:
@@ -222,13 +226,13 @@ ELFCLASS32 (x64 Systems), you may need to reinstall py-bcrypt to correctly
compile it under the right architecture. (This should not be needed if
installing under virtualenv)::
- $ sudo pip2.7 uninstall py-bcrypt
- $ sudo pip2.7 install py-bcrypt
+ sudo pip2.7 uninstall py-bcrypt
+ sudo pip2.7 install py-bcrypt
During setup of Synapse you need to call python2.7 directly again::
- $ cd ~/.synapse
- $ python2.7 -m synapse.app.homeserver \
+ cd ~/.synapse
+ python2.7 -m synapse.app.homeserver \
--server-name machine.my.domain.name \
--config-path homeserver.yaml \
--generate-config
@@ -276,22 +280,22 @@ Synapse requires pip 1.7 or later, so if your OS provides too old a version and
you get errors about ``error: no such option: --process-dependency-links`` you
may need to manually upgrade it::
- $ sudo pip install --upgrade pip
+ sudo pip install --upgrade pip
If pip crashes mid-installation for reason (e.g. lost terminal), pip may
refuse to run until you remove the temporary installation directory it
created. To reset the installation::
- $ rm -rf /tmp/pip_install_matrix
+ rm -rf /tmp/pip_install_matrix
pip seems to leak *lots* of memory during installation. For instance, a Linux
host with 512MB of RAM may run out of memory whilst installing Twisted. If this
happens, you will have to individually install the dependencies which are
failing, e.g.::
- $ pip install twisted
+ pip install twisted
-On OSX, if you encounter clang: error: unknown argument: '-mno-fused-madd' you
+On OS X, if you encounter clang: error: unknown argument: '-mno-fused-madd' you
will need to export CFLAGS=-Qunused-arguments.
Troubleshooting Running
@@ -307,10 +311,11 @@ correctly, causing all tests to fail with errors about missing "sodium.h". To
fix try re-installing from PyPI or directly from
(https://github.com/pyca/pynacl)::
- $ # Install from PyPI
- $ pip install --user --upgrade --force pynacl
- $ # Install from github
- $ pip install --user https://github.com/pyca/pynacl/tarball/master
+ # Install from PyPI
+ pip install --user --upgrade --force pynacl
+
+ # Install from github
+ pip install --user https://github.com/pyca/pynacl/tarball/master
ArchLinux
~~~~~~~~~
@@ -318,7 +323,7 @@ ArchLinux
If running `$ synctl start` fails with 'returned non-zero exit status 1',
you will need to explicitly call Python2.7 - either running as::
- $ python2.7 -m synapse.app.homeserver --daemonize -c homeserver.yaml
+ python2.7 -m synapse.app.homeserver --daemonize -c homeserver.yaml
...or by editing synctl with the correct python executable.
@@ -328,16 +333,16 @@ Synapse Development
To check out a synapse for development, clone the git repo into a working
directory of your choice::
- $ git clone https://github.com/matrix-org/synapse.git
- $ cd synapse
+ git clone https://github.com/matrix-org/synapse.git
+ cd synapse
Synapse has a number of external dependencies, that are easiest
to install using pip and a virtualenv::
- $ virtualenv env
- $ source env/bin/activate
- $ python synapse/python_dependencies.py | xargs -n1 pip install
- $ pip install setuptools_trial mock
+ virtualenv env
+ source env/bin/activate
+ python synapse/python_dependencies.py | xargs -n1 pip install
+ pip install setuptools_trial mock
This will run a process of downloading and installing all the needed
dependencies into a virtual env.
@@ -345,7 +350,7 @@ dependencies into a virtual env.
Once this is done, you may wish to run Synapse's unit tests, to
check that everything is installed as it should be::
- $ python setup.py test
+ python setup.py test
This should end with a 'PASSED' result::
@@ -386,11 +391,11 @@ IDs:
For the first form, simply pass the required hostname (of the machine) as the
--server-name parameter::
- $ python -m synapse.app.homeserver \
+ python -m synapse.app.homeserver \
--server-name machine.my.domain.name \
--config-path homeserver.yaml \
--generate-config
- $ python -m synapse.app.homeserver --config-path homeserver.yaml
+ python -m synapse.app.homeserver --config-path homeserver.yaml
Alternatively, you can run ``synctl start`` to guide you through the process.
@@ -407,11 +412,11 @@ record would then look something like::
At this point, you should then run the homeserver with the hostname of this
SRV record, as that is the name other machines will expect it to have::
- $ python -m synapse.app.homeserver \
+ python -m synapse.app.homeserver \
--server-name YOURDOMAIN \
--config-path homeserver.yaml \
--generate-config
- $ python -m synapse.app.homeserver --config-path homeserver.yaml
+ python -m synapse.app.homeserver --config-path homeserver.yaml
You may additionally want to pass one or more "-v" options, in order to
@@ -425,7 +430,7 @@ private federation (``localhost:8080``, ``localhost:8081`` and
``localhost:8082``) which you can then access through the webclient running at
http://localhost:8080. Simply run::
- $ demo/start.sh
+ demo/start.sh
This is mainly useful just for development purposes.
@@ -499,10 +504,10 @@ Building Internal API Documentation
Before building internal API documentation install sphinx and
sphinxcontrib-napoleon::
- $ pip install sphinx
- $ pip install sphinxcontrib-napoleon
+ pip install sphinx
+ pip install sphinxcontrib-napoleon
Building internal API documentation::
- $ python setup.py build_sphinx
+ python setup.py build_sphinx
diff --git a/demo/clean.sh b/demo/clean.sh
index c5dabd4767..418ca9457e 100755
--- a/demo/clean.sh
+++ b/demo/clean.sh
@@ -11,7 +11,9 @@ if [ -f $PID_FILE ]; then
exit 1
fi
-find "$DIR" -name "*.log" -delete
-find "$DIR" -name "*.db" -delete
+for port in 8080 8081 8082; do
+ rm -rf $DIR/$port
+ rm -rf $DIR/media_store.$port
+done
rm -rf $DIR/etc
diff --git a/demo/start.sh b/demo/start.sh
index b9cc14b9d2..b5dea5e176 100755
--- a/demo/start.sh
+++ b/demo/start.sh
@@ -8,14 +8,6 @@ cd "$DIR/.."
mkdir -p demo/etc
-# Check the --no-rate-limit param
-PARAMS=""
-if [ $# -eq 1 ]; then
- if [ $1 = "--no-rate-limit" ]; then
- PARAMS="--rc-messages-per-second 1000 --rc-message-burst-count 1000"
- fi
-fi
-
export PYTHONPATH=$(readlink -f $(pwd))
@@ -35,6 +27,15 @@ for port in 8080 8081 8082; do
-H "localhost:$https_port" \
--config-path "$DIR/etc/$port.config" \
+ # Check script parameters
+ if [ $# -eq 1 ]; then
+ if [ $1 = "--no-rate-limit" ]; then
+ # Set high limits in config file to disable rate limiting
+ perl -p -i -e 's/rc_messages_per_second.*/rc_messages_per_second: 1000/g' $DIR/etc/$port.config
+ perl -p -i -e 's/rc_message_burst_count.*/rc_message_burst_count: 1000/g' $DIR/etc/$port.config
+ fi
+ fi
+
python -m synapse.app.homeserver \
--config-path "$DIR/etc/$port.config" \
-D \
diff --git a/synapse/__init__.py b/synapse/__init__.py
index 4720d99848..96e37308d6 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -16,4 +16,4 @@
""" This is a reference implementation of a Matrix home server.
"""
-__version__ = "0.9.1"
+__version__ = "0.9.3"
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index d5bf0be85c..a7f428a96c 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -29,7 +29,7 @@ logger = logging.getLogger(__name__)
AuthEventTypes = (
EventTypes.Create, EventTypes.Member, EventTypes.PowerLevels,
- EventTypes.JoinRules,
+ EventTypes.JoinRules, EventTypes.RoomHistoryVisibility,
)
@@ -44,6 +44,11 @@ class Auth(object):
def check(self, event, auth_events):
""" Checks if this event is correctly authed.
+ Args:
+ event: the event being checked.
+ auth_events (dict: event-key -> event): the existing room state.
+
Returns:
True if the auth checks pass.
"""
@@ -187,6 +192,9 @@ class Auth(object):
join_rule = JoinRules.INVITE
user_level = self._get_user_power_level(event.user_id, auth_events)
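+        # Also fetch the target's power level: the kick and ban checks below
+        # require the target to have a strictly lower level than the sender.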
+ target_level = self._get_user_power_level(
+ target_user_id, auth_events
+ )
# FIXME (erikj): What should we do here as the default?
ban_level = self._get_named_level(auth_events, "ban", 50)
@@ -258,12 +266,12 @@ class Auth(object):
elif target_user_id != event.user_id:
kick_level = self._get_named_level(auth_events, "kick", 50)
- if user_level < kick_level:
+ if user_level < kick_level or user_level <= target_level:
raise AuthError(
403, "You cannot kick user %s." % target_user_id
)
elif Membership.BAN == membership:
- if user_level < ban_level:
+ if user_level < ban_level or user_level <= target_level:
raise AuthError(403, "You don't have permission to ban")
else:
raise AuthError(500, "Unknown membership %s" % membership)
@@ -316,7 +324,7 @@ class Auth(object):
Returns:
tuple : of UserID and device string:
User ID object of the user making the request
- Client ID object of the client instance the user is using
+ ClientInfo object of the client instance the user is using
Raises:
AuthError if no user by that token exists or the token is invalid.
"""
@@ -349,7 +357,7 @@ class Auth(object):
)
return
except KeyError:
- pass # normal users won't have this query parameter set
+ pass # normal users won't have the user_id query parameter set.
user_info = yield self.get_user_by_token(access_token)
user = user_info["user"]
@@ -370,6 +378,8 @@ class Auth(object):
user_agent=user_agent
)
+ request.authenticated_entity = user.to_string()
+
defer.returnValue((user, ClientInfo(device_id, token_id)))
except KeyError:
raise AuthError(
@@ -516,23 +526,22 @@ class Auth(object):
# Check state_key
if hasattr(event, "state_key"):
- if not event.state_key.startswith("_"):
- if event.state_key.startswith("@"):
- if event.state_key != event.user_id:
+ if event.state_key.startswith("@"):
+ if event.state_key != event.user_id:
+ raise AuthError(
+ 403,
+ "You are not allowed to set others state"
+ )
+ else:
+ sender_domain = UserID.from_string(
+ event.user_id
+ ).domain
+
+ if sender_domain != event.state_key:
raise AuthError(
403,
"You are not allowed to set others state"
)
- else:
- sender_domain = UserID.from_string(
- event.user_id
- ).domain
-
- if sender_domain != event.state_key:
- raise AuthError(
- 403,
- "You are not allowed to set others state"
- )
return True
@@ -571,25 +580,26 @@ class Auth(object):
# Check other levels:
levels_to_check = [
- ("users_default", []),
- ("events_default", []),
- ("ban", []),
- ("redact", []),
- ("kick", []),
- ("invite", []),
+ ("users_default", None),
+ ("events_default", None),
+ ("state_default", None),
+ ("ban", None),
+ ("redact", None),
+ ("kick", None),
+ ("invite", None),
]
old_list = current_state.content.get("users")
for user in set(old_list.keys() + user_list.keys()):
levels_to_check.append(
- (user, ["users"])
+ (user, "users")
)
old_list = current_state.content.get("events")
new_list = event.content.get("events")
for ev_id in set(old_list.keys() + new_list.keys()):
levels_to_check.append(
- (ev_id, ["events"])
+ (ev_id, "events")
)
old_state = current_state.content
@@ -597,12 +607,10 @@ class Auth(object):
for level_to_check, dir in levels_to_check:
old_loc = old_state
- for d in dir:
- old_loc = old_loc.get(d, {})
-
new_loc = new_state
- for d in dir:
- new_loc = new_loc.get(d, {})
+ if dir:
+ old_loc = old_loc.get(dir, {})
+ new_loc = new_loc.get(dir, {})
if level_to_check in old_loc:
old_level = int(old_loc[level_to_check])
@@ -618,6 +626,14 @@ class Auth(object):
if new_level == old_level:
continue
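+            # A user may not change the level of another user whose current
+            # level equals their own (e.g. removing another admin's op level).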
+ if dir == "users" and level_to_check != event.user_id:
+ if old_level == user_level:
+ raise AuthError(
+ 403,
+ "You don't have permission to remove ops level equal "
+ "to your own"
+ )
+
if old_level > user_level or new_level > user_level:
raise AuthError(
403,
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index d8a18ee87b..7156ee4e7d 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -75,6 +75,8 @@ class EventTypes(object):
Redaction = "m.room.redaction"
Feedback = "m.room.message.feedback"
+ RoomHistoryVisibility = "m.room.history_visibility"
+
# These are used for validation
Message = "m.room.message"
Topic = "m.room.topic"
@@ -85,3 +87,8 @@ class RejectedReason(object):
AUTH_ERROR = "auth_error"
REPLACED = "replaced"
NOT_ANCESTOR = "not_ancestor"
+
+
+class RoomCreationPreset(object):
+ PRIVATE_CHAT = "private_chat"
+ PUBLIC_CHAT = "public_chat"
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index d93afdc1c2..f04493f92a 100755
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -34,8 +34,7 @@ from twisted.application import service
from twisted.enterprise import adbapi
from twisted.web.resource import Resource, EncodingResourceWrapper
from twisted.web.static import File
-from twisted.web.server import Site, GzipEncoderFactory
-from twisted.web.http import proxiedLogFormatter, combinedLogFormatter
+from twisted.web.server import Site, GzipEncoderFactory, Request
from synapse.http.server import JsonResource, RootRedirect
from synapse.rest.media.v0.content_repository import ContentRepoResource
from synapse.rest.media.v1.media_repository import MediaRepositoryResource
@@ -61,11 +60,13 @@ import twisted.manhole.telnet
import synapse
+import contextlib
import logging
import os
import re
import resource
import subprocess
+import time
logger = logging.getLogger("synapse.app.homeserver")
@@ -87,10 +88,10 @@ class SynapseHomeServer(HomeServer):
return MatrixFederationHttpClient(self)
def build_resource_for_client(self):
- return gz_wrap(ClientV1RestResource(self))
+ return ClientV1RestResource(self)
def build_resource_for_client_v2_alpha(self):
- return gz_wrap(ClientV2AlphaRestResource(self))
+ return ClientV2AlphaRestResource(self)
def build_resource_for_federation(self):
return JsonResource(self)
@@ -113,7 +114,7 @@ class SynapseHomeServer(HomeServer):
def build_resource_for_content_repo(self):
return ContentRepoResource(
- self, self.upload_dir, self.auth, self.content_addr
+ self, self.config.uploads_path, self.auth, self.content_addr
)
def build_resource_for_media_repository(self):
@@ -139,152 +140,105 @@ class SynapseHomeServer(HomeServer):
**self.db_config.get("args", {})
)
- def create_resource_tree(self, redirect_root_to_web_client):
- """Create the resource tree for this Home Server.
+ def _listener_http(self, config, listener_config):
+ port = listener_config["port"]
+ bind_address = listener_config.get("bind_address", "")
+ tls = listener_config.get("tls", False)
+ site_tag = listener_config.get("tag", port)
- This in unduly complicated because Twisted does not support putting
- child resources more than 1 level deep at a time.
-
- Args:
- web_client (bool): True to enable the web client.
- redirect_root_to_web_client (bool): True to redirect '/' to the
- location of the web client. This does nothing if web_client is not
- True.
- """
- config = self.get_config()
- web_client = config.web_client
-
- # list containing (path_str, Resource) e.g:
- # [ ("/aaa/bbb/cc", Resource1), ("/aaa/dummy", Resource2) ]
- desired_tree = [
- (CLIENT_PREFIX, self.get_resource_for_client()),
- (CLIENT_V2_ALPHA_PREFIX, self.get_resource_for_client_v2_alpha()),
- (FEDERATION_PREFIX, self.get_resource_for_federation()),
- (CONTENT_REPO_PREFIX, self.get_resource_for_content_repo()),
- (SERVER_KEY_PREFIX, self.get_resource_for_server_key()),
- (SERVER_KEY_V2_PREFIX, self.get_resource_for_server_key_v2()),
- (MEDIA_PREFIX, self.get_resource_for_media_repository()),
- (STATIC_PREFIX, self.get_resource_for_static_content()),
- ]
-
- if web_client:
- logger.info("Adding the web client.")
- desired_tree.append((WEB_CLIENT_PREFIX,
- self.get_resource_for_web_client()))
-
- if web_client and redirect_root_to_web_client:
- self.root_resource = RootRedirect(WEB_CLIENT_PREFIX)
- else:
- self.root_resource = Resource()
+ if tls and config.no_tls:
+ return
metrics_resource = self.get_resource_for_metrics()
- if config.metrics_port is None and metrics_resource is not None:
- desired_tree.append((METRICS_PREFIX, metrics_resource))
-
- # ideally we'd just use getChild and putChild but getChild doesn't work
- # unless you give it a Request object IN ADDITION to the name :/ So
- # instead, we'll store a copy of this mapping so we can actually add
- # extra resources to existing nodes. See self._resource_id for the key.
- resource_mappings = {}
- for full_path, res in desired_tree:
- logger.info("Attaching %s to path %s", res, full_path)
- last_resource = self.root_resource
- for path_seg in full_path.split('/')[1:-1]:
- if path_seg not in last_resource.listNames():
- # resource doesn't exist, so make a "dummy resource"
- child_resource = Resource()
- last_resource.putChild(path_seg, child_resource)
- res_id = self._resource_id(last_resource, path_seg)
- resource_mappings[res_id] = child_resource
- last_resource = child_resource
- else:
- # we have an existing Resource, use that instead.
- res_id = self._resource_id(last_resource, path_seg)
- last_resource = resource_mappings[res_id]
-
- # ===========================
- # now attach the actual desired resource
- last_path_seg = full_path.split('/')[-1]
-
- # if there is already a resource here, thieve its children and
- # replace it
- res_id = self._resource_id(last_resource, last_path_seg)
- if res_id in resource_mappings:
- # there is a dummy resource at this path already, which needs
- # to be replaced with the desired resource.
- existing_dummy_resource = resource_mappings[res_id]
- for child_name in existing_dummy_resource.listNames():
- child_res_id = self._resource_id(existing_dummy_resource,
- child_name)
- child_resource = resource_mappings[child_res_id]
- # steal the children
- res.putChild(child_name, child_resource)
-
- # finally, insert the desired resource in the right place
- last_resource.putChild(last_path_seg, res)
- res_id = self._resource_id(last_resource, last_path_seg)
- resource_mappings[res_id] = res
-
- return self.root_resource
-
- def _resource_id(self, resource, path_seg):
- """Construct an arbitrary resource ID so you can retrieve the mapping
- later.
-
- If you want to represent resource A putChild resource B with path C,
- the mapping should looks like _resource_id(A,C) = B.
-
- Args:
- resource (Resource): The *parent* Resource
- path_seg (str): The name of the child Resource to be attached.
- Returns:
- str: A unique string which can be a key to the child Resource.
- """
- return "%s-%s" % (resource, path_seg)
- def start_listening(self):
- config = self.get_config()
-
- if not config.no_tls and config.bind_port is not None:
+ resources = {}
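+        # Build a mapping from URL prefix to Resource, based on the resource
+        # names this listener is configured to serve.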
+ for res in listener_config["resources"]:
+ for name in res["names"]:
+ if name == "client":
+ if res["compress"]:
+ client_v1 = gz_wrap(self.get_resource_for_client())
+ client_v2 = gz_wrap(self.get_resource_for_client_v2_alpha())
+ else:
+ client_v1 = self.get_resource_for_client()
+ client_v2 = self.get_resource_for_client_v2_alpha()
+
+ resources.update({
+ CLIENT_PREFIX: client_v1,
+ CLIENT_V2_ALPHA_PREFIX: client_v2,
+ })
+
+ if name == "federation":
+ resources.update({
+ FEDERATION_PREFIX: self.get_resource_for_federation(),
+ })
+
+ if name in ["static", "client"]:
+ resources.update({
+ STATIC_PREFIX: self.get_resource_for_static_content(),
+ })
+
+ if name in ["media", "federation", "client"]:
+ resources.update({
+ MEDIA_PREFIX: self.get_resource_for_media_repository(),
+ CONTENT_REPO_PREFIX: self.get_resource_for_content_repo(),
+ })
+
+ if name in ["keys", "federation"]:
+ resources.update({
+ SERVER_KEY_PREFIX: self.get_resource_for_server_key(),
+ SERVER_KEY_V2_PREFIX: self.get_resource_for_server_key_v2(),
+ })
+
+ if name == "webclient":
+ resources[WEB_CLIENT_PREFIX] = self.get_resource_for_web_client()
+
+ if name == "metrics" and metrics_resource:
+ resources[METRICS_PREFIX] = metrics_resource
+
+ root_resource = create_resource_tree(resources)
+ if tls:
reactor.listenSSL(
- config.bind_port,
+ port,
SynapseSite(
- "synapse.access.https",
- config,
- self.root_resource,
+ "synapse.access.https.%s" % (site_tag,),
+ site_tag,
+ listener_config,
+ root_resource,
),
self.tls_context_factory,
- interface=config.bind_host
+ interface=bind_address
)
- logger.info("Synapse now listening on port %d", config.bind_port)
-
- if config.unsecure_port is not None:
+ else:
reactor.listenTCP(
- config.unsecure_port,
+ port,
SynapseSite(
- "synapse.access.http",
- config,
- self.root_resource,
+ "synapse.access.http.%s" % (site_tag,),
+ site_tag,
+ listener_config,
+ root_resource,
),
- interface=config.bind_host
+ interface=bind_address
)
- logger.info("Synapse now listening on port %d", config.unsecure_port)
+ logger.info("Synapse now listening on port %d", port)
- metrics_resource = self.get_resource_for_metrics()
- if metrics_resource and config.metrics_port is not None:
- reactor.listenTCP(
- config.metrics_port,
- SynapseSite(
- "synapse.access.metrics",
- config,
- metrics_resource,
- ),
- interface=config.metrics_bind_host,
- )
- logger.info(
- "Metrics now running on %s port %d",
- config.metrics_bind_host, config.metrics_port,
- )
+ def start_listening(self):
+ config = self.get_config()
+
+ for listener in config.listeners:
+ if listener["type"] == "http":
+ self._listener_http(config, listener)
+ elif listener["type"] == "manhole":
+ f = twisted.manhole.telnet.ShellFactory()
+ f.username = "matrix"
+ f.password = "rabbithole"
+ f.namespace['hs'] = self
+ reactor.listenTCP(
+ listener["port"],
+ f,
+ interface=listener.get("bind_address", '127.0.0.1')
+ )
+ else:
+ logger.warn("Unrecognized listener type: %s", listener["type"])
def run_startup_checks(self, db_conn, database_engine):
all_users_native = are_all_users_on_domain(
@@ -419,11 +373,6 @@ def setup(config_options):
events.USE_FROZEN_DICTS = config.use_frozen_dicts
- if re.search(":[0-9]+$", config.server_name):
- domain_with_port = config.server_name
- else:
- domain_with_port = "%s:%s" % (config.server_name, config.bind_port)
-
tls_context_factory = context_factory.ServerContextFactory(config)
database_engine = create_engine(config.database_config["name"])
@@ -431,8 +380,6 @@ def setup(config_options):
hs = SynapseHomeServer(
config.server_name,
- domain_with_port=domain_with_port,
- upload_dir=os.path.abspath("uploads"),
db_config=config.database_config,
tls_context_factory=tls_context_factory,
config=config,
@@ -441,10 +388,6 @@ def setup(config_options):
database_engine=database_engine,
)
- hs.create_resource_tree(
- redirect_root_to_web_client=True,
- )
-
logger.info("Preparing database: %r...", config.database_config)
try:
@@ -469,13 +412,6 @@ def setup(config_options):
logger.info("Database prepared in %r.", config.database_config)
- if config.manhole:
- f = twisted.manhole.telnet.ShellFactory()
- f.username = "matrix"
- f.password = "rabbithole"
- f.namespace['hs'] = hs
- reactor.listenTCP(config.manhole, f, interface='127.0.0.1')
-
hs.start_listening()
hs.get_pusherpool().start()
@@ -501,22 +437,194 @@ class SynapseService(service.Service):
return self._port.stopListening()
+class SynapseRequest(Request):
+ def __init__(self, site, *args, **kw):
+ Request.__init__(self, *args, **kw)
+ self.site = site
+ self.authenticated_entity = None
+ self.start_time = 0
+
+ def __repr__(self):
+        # We override this so that we don't log the ``access_token`` value
+ return '<%s at 0x%x method=%s uri=%s clientproto=%s site=%s>' % (
+ self.__class__.__name__,
+ id(self),
+ self.method,
+ self.get_redacted_uri(),
+ self.clientproto,
+ self.site.site_tag,
+ )
+
+ def get_redacted_uri(self):
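+        # Strip the value of any access_token query parameter so that it
+        # never appears in the logs.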
+ return re.sub(
+ r'(\?.*access_token=)[^&]*(.*)$',
+ r'\1<redacted>\2',
+ self.uri
+ )
+
+ def get_user_agent(self):
+ return self.requestHeaders.getRawHeaders("User-Agent", [None])[-1]
+
+ def started_processing(self):
+ self.site.access_logger.info(
+ "%s - %s - Received request: %s %s",
+ self.getClientIP(),
+ self.site.site_tag,
+ self.method,
+ self.get_redacted_uri()
+ )
+ self.start_time = int(time.time() * 1000)
+
+ def finished_processing(self):
+ self.site.access_logger.info(
+ "%s - %s - {%s}"
+ " Processed request: %dms %sB %s \"%s %s %s\" \"%s\"",
+ self.getClientIP(),
+ self.site.site_tag,
+ self.authenticated_entity,
+ int(time.time() * 1000) - self.start_time,
+ self.sentLength,
+ self.code,
+ self.method,
+ self.get_redacted_uri(),
+ self.clientproto,
+ self.get_user_agent(),
+ )
+
+ @contextlib.contextmanager
+ def processing(self):
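+        # Log at the start and end of request processing; intended to be
+        # used as ``with request.processing(): ...`` around handling.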
+ self.started_processing()
+ yield
+ self.finished_processing()
+
+
+class XForwardedForRequest(SynapseRequest):
+    """
+    A SynapseRequest that takes the client IP from the X-Forwarded-For
+    header, for use when synapse runs behind a reverse proxy.
+    """
+    def __init__(self, *args, **kw):
+        SynapseRequest.__init__(self, *args, **kw)
+
+ def getClientIP(self):
+ """
+ @return: The client address (the first address) in the value of the
+ I{X-Forwarded-For header}. If the header is not present, return
+ C{b"-"}.
+ """
+ return self.requestHeaders.getRawHeaders(
+ b"x-forwarded-for", [b"-"])[0].split(b",")[0].strip()
+
+
+class SynapseRequestFactory(object):
+ def __init__(self, site, x_forwarded_for):
+ self.site = site
+ self.x_forwarded_for = x_forwarded_for
+
+ def __call__(self, *args, **kwargs):
+ if self.x_forwarded_for:
+ return XForwardedForRequest(self.site, *args, **kwargs)
+ else:
+ return SynapseRequest(self.site, *args, **kwargs)
+
+
class SynapseSite(Site):
"""
Subclass of a twisted http Site that does access logging with python's
standard logging
"""
- def __init__(self, logger_name, config, resource, *args, **kwargs):
+ def __init__(self, logger_name, site_tag, config, resource, *args, **kwargs):
Site.__init__(self, resource, *args, **kwargs)
- if config.captcha_ip_origin_is_x_forwarded:
- self._log_formatter = proxiedLogFormatter
- else:
- self._log_formatter = combinedLogFormatter
+
+ self.site_tag = site_tag
+
+ proxied = config.get("x_forwarded", False)
+ self.requestFactory = SynapseRequestFactory(self, proxied)
self.access_logger = logging.getLogger(logger_name)
def log(self, request):
- line = self._log_formatter(self._logDateTime, request)
- self.access_logger.info(line)
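+        # Access logging is now done in SynapseRequest.started_processing and
+        # finished_processing, so the default per-request log line is dropped.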
+ pass
+
+
+def create_resource_tree(desired_tree, redirect_root_to_web_client=True):
+ """Create the resource tree for this Home Server.
+
+    This is unduly complicated because Twisted does not support putting
+ child resources more than 1 level deep at a time.
+
+ Args:
+        desired_tree (dict): mapping from path (e.g. "/foo/bar") to the
+            Resource to attach at that path.
+        redirect_root_to_web_client (bool): True to redirect '/' to the
+            web client's location, if the web client is in desired_tree.
+ """
+ if redirect_root_to_web_client and WEB_CLIENT_PREFIX in desired_tree:
+ root_resource = RootRedirect(WEB_CLIENT_PREFIX)
+ else:
+ root_resource = Resource()
+
+ # ideally we'd just use getChild and putChild but getChild doesn't work
+ # unless you give it a Request object IN ADDITION to the name :/ So
+ # instead, we'll store a copy of this mapping so we can actually add
+    # extra resources to existing nodes. See _resource_id for the key.
+ resource_mappings = {}
+ for full_path, res in desired_tree.items():
+ logger.info("Attaching %s to path %s", res, full_path)
+ last_resource = root_resource
+ for path_seg in full_path.split('/')[1:-1]:
+ if path_seg not in last_resource.listNames():
+ # resource doesn't exist, so make a "dummy resource"
+ child_resource = Resource()
+ last_resource.putChild(path_seg, child_resource)
+ res_id = _resource_id(last_resource, path_seg)
+ resource_mappings[res_id] = child_resource
+ last_resource = child_resource
+ else:
+ # we have an existing Resource, use that instead.
+ res_id = _resource_id(last_resource, path_seg)
+ last_resource = resource_mappings[res_id]
+
+ # ===========================
+ # now attach the actual desired resource
+ last_path_seg = full_path.split('/')[-1]
+
+ # if there is already a resource here, thieve its children and
+ # replace it
+ res_id = _resource_id(last_resource, last_path_seg)
+ if res_id in resource_mappings:
+ # there is a dummy resource at this path already, which needs
+ # to be replaced with the desired resource.
+ existing_dummy_resource = resource_mappings[res_id]
+ for child_name in existing_dummy_resource.listNames():
+ child_res_id = _resource_id(
+ existing_dummy_resource, child_name
+ )
+ child_resource = resource_mappings[child_res_id]
+ # steal the children
+ res.putChild(child_name, child_resource)
+
+ # finally, insert the desired resource in the right place
+ last_resource.putChild(last_path_seg, res)
+ res_id = _resource_id(last_resource, last_path_seg)
+ resource_mappings[res_id] = res
+
+ return root_resource
+
+
+def _resource_id(resource, path_seg):
+ """Construct an arbitrary resource ID so you can retrieve the mapping
+ later.
+
+    If resource A has resource B attached at path C (via putChild), the
+    mapping should look like _resource_id(A, C) = B.
+
+ Args:
+ resource (Resource): The *parent* Resource
+ path_seg (str): The name of the child Resource to be attached.
+ Returns:
+ str: A unique string which can be a key to the child Resource.
+ """
+ return "%s-%s" % (resource, path_seg)
def run(hs):
@@ -549,7 +657,8 @@ def run(hs):
if hs.config.daemonize:
- print hs.config.pid_file
+ if hs.config.print_pidfile:
+ print hs.config.pid_file
daemon = Daemonize(
app="synapse-homeserver",
diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index d4163d6272..73f6959959 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -139,63 +139,59 @@ class Config(object):
help="Generate a config file for the server name"
)
config_parser.add_argument(
+ "--generate-keys",
+ action="store_true",
+ help="Generate any missing key files then exit"
+ )
+ config_parser.add_argument(
"-H", "--server-name",
help="The server name to generate a config file for"
)
config_args, remaining_args = config_parser.parse_known_args(argv)
+ generate_keys = config_args.generate_keys
+
if config_args.generate_config:
if not config_args.config_path:
config_parser.error(
"Must supply a config file.\nA config file can be automatically"
- " generated using \"--generate-config -h SERVER_NAME"
+ " generated using \"--generate-config -H SERVER_NAME"
" -c CONFIG-FILE\""
)
-
- config_dir_path = os.path.dirname(config_args.config_path[0])
- config_dir_path = os.path.abspath(config_dir_path)
-
- server_name = config_args.server_name
- if not server_name:
- print "Must specify a server_name to a generate config for."
- sys.exit(1)
(config_path,) = config_args.config_path
- if not os.path.exists(config_dir_path):
- os.makedirs(config_dir_path)
- if os.path.exists(config_path):
- print "Config file %r already exists" % (config_path,)
- yaml_config = cls.read_config_file(config_path)
- yaml_name = yaml_config["server_name"]
- if server_name != yaml_name:
- print (
- "Config file %r has a different server_name: "
- " %r != %r" % (config_path, server_name, yaml_name)
- )
+ if not os.path.exists(config_path):
+ config_dir_path = os.path.dirname(config_path)
+ config_dir_path = os.path.abspath(config_dir_path)
+
+ server_name = config_args.server_name
+ if not server_name:
+ print "Must specify a server_name to a generate config for."
sys.exit(1)
- config_bytes, config = obj.generate_config(
- config_dir_path, server_name
+ if not os.path.exists(config_dir_path):
+ os.makedirs(config_dir_path)
+ with open(config_path, "wb") as config_file:
+ config_bytes, config = obj.generate_config(
+ config_dir_path, server_name
+ )
+ obj.invoke_all("generate_files", config)
+ config_file.write(config_bytes)
+ print (
+ "A config file has been generated in %r for server name"
+ " %r with corresponding SSL keys and self-signed"
+ " certificates. Please review this file and customise it"
+ " to your needs."
+ ) % (config_path, server_name)
+ print (
+ "If this server name is incorrect, you will need to"
+ " regenerate the SSL certificates"
)
- config.update(yaml_config)
- print "Generating any missing keys for %r" % (server_name,)
- obj.invoke_all("generate_files", config)
sys.exit(0)
- with open(config_path, "wb") as config_file:
- config_bytes, config = obj.generate_config(
- config_dir_path, server_name
- )
- obj.invoke_all("generate_files", config)
- config_file.write(config_bytes)
+ else:
print (
- "A config file has been generated in %s for server name"
- " '%s' with corresponding SSL keys and self-signed"
- " certificates. Please review this file and customise it to"
- " your needs."
- ) % (config_path, server_name)
- print (
- "If this server name is incorrect, you will need to regenerate"
- " the SSL certificates"
- )
- sys.exit(0)
+ "Config file %r already exists. Generating any missing key"
+ " files."
+ ) % (config_path,)
+ generate_keys = True
parser = argparse.ArgumentParser(
parents=[config_parser],
@@ -209,11 +205,11 @@ class Config(object):
if not config_args.config_path:
config_parser.error(
"Must supply a config file.\nA config file can be automatically"
- " generated using \"--generate-config -h SERVER_NAME"
+ " generated using \"--generate-config -H SERVER_NAME"
" -c CONFIG-FILE\""
)
- config_dir_path = os.path.dirname(config_args.config_path[0])
+ config_dir_path = os.path.dirname(config_args.config_path[-1])
config_dir_path = os.path.abspath(config_dir_path)
specified_config = {}
@@ -226,6 +222,10 @@ class Config(object):
config.pop("log_config")
config.update(specified_config)
+ if generate_keys:
+ obj.invoke_all("generate_files", config)
+ sys.exit(0)
+
obj.invoke_all("read_config", config)
obj.invoke_all("read_arguments", args)
diff --git a/synapse/config/captcha.py b/synapse/config/captcha.py
index ba221121cb..15a132b4e3 100644
--- a/synapse/config/captcha.py
+++ b/synapse/config/captcha.py
@@ -21,10 +21,6 @@ class CaptchaConfig(Config):
self.recaptcha_private_key = config["recaptcha_private_key"]
self.recaptcha_public_key = config["recaptcha_public_key"]
self.enable_registration_captcha = config["enable_registration_captcha"]
- # XXX: This is used for more than just captcha
- self.captcha_ip_origin_is_x_forwarded = (
- config["captcha_ip_origin_is_x_forwarded"]
- )
self.captcha_bypass_secret = config.get("captcha_bypass_secret")
self.recaptcha_siteverify_api = config["recaptcha_siteverify_api"]
@@ -33,20 +29,16 @@ class CaptchaConfig(Config):
## Captcha ##
-    # This Home Server's ReCAPTCHA public key.
-    recaptcha_private_key: "YOUR_PUBLIC_KEY"
+    # This Home Server's ReCAPTCHA private key.
+    recaptcha_private_key: "YOUR_PRIVATE_KEY"
-    # This Home Server's ReCAPTCHA private key.
-    recaptcha_public_key: "YOUR_PRIVATE_KEY"
+    # This Home Server's ReCAPTCHA public key.
+    recaptcha_public_key: "YOUR_PUBLIC_KEY"
# Enables ReCaptcha checks when registering, preventing signup
# unless a captcha is answered. Requires a valid ReCaptcha
# public/private key.
enable_registration_captcha: False
- # When checking captchas, use the X-Forwarded-For (XFF) header
- # as the client IP and not the actual client IP.
- captcha_ip_origin_is_x_forwarded: False
-
# A secret key used to bypass the captcha test entirely.
#captcha_bypass_secret: "YOUR_SECRET_HERE"
diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py
index fe0ccb6eb7..d77f045406 100644
--- a/synapse/config/homeserver.py
+++ b/synapse/config/homeserver.py
@@ -25,12 +25,13 @@ from .registration import RegistrationConfig
from .metrics import MetricsConfig
from .appservice import AppServiceConfig
from .key import KeyConfig
+from .saml2 import SAML2Config
class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
RatelimitConfig, ContentRepositoryConfig, CaptchaConfig,
- VoipConfig, RegistrationConfig,
- MetricsConfig, AppServiceConfig, KeyConfig,):
+ VoipConfig, RegistrationConfig, MetricsConfig,
+ AppServiceConfig, KeyConfig, SAML2Config, ):
pass
diff --git a/synapse/config/metrics.py b/synapse/config/metrics.py
index 0cfb30ce7f..ae5a691527 100644
--- a/synapse/config/metrics.py
+++ b/synapse/config/metrics.py
@@ -28,10 +28,4 @@ class MetricsConfig(Config):
# Enable collection and rendering of performance metrics
enable_metrics: False
-
- # Separate port to accept metrics requests on
- # metrics_port: 8081
-
- # Which host to bind the metric listener to
- # metrics_bind_host: 127.0.0.1
"""
diff --git a/synapse/config/repository.py b/synapse/config/repository.py
index adaf4e4bb2..6891abd71d 100644
--- a/synapse/config/repository.py
+++ b/synapse/config/repository.py
@@ -21,13 +21,18 @@ class ContentRepositoryConfig(Config):
self.max_upload_size = self.parse_size(config["max_upload_size"])
self.max_image_pixels = self.parse_size(config["max_image_pixels"])
self.media_store_path = self.ensure_directory(config["media_store_path"])
+ self.uploads_path = self.ensure_directory(config["uploads_path"])
def default_config(self, config_dir_path, server_name):
media_store = self.default_path("media_store")
+ uploads_path = self.default_path("uploads")
return """
# Directory where uploaded images and attachments are stored.
media_store_path: "%(media_store)s"
+ # Directory where in-progress uploads are stored.
+ uploads_path: "%(uploads_path)s"
+
# The largest allowed upload size in bytes
max_upload_size: "10M"
diff --git a/synapse/config/saml2.py b/synapse/config/saml2.py
new file mode 100644
index 0000000000..1532036876
--- /dev/null
+++ b/synapse/config/saml2.py
@@ -0,0 +1,54 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015 Ericsson
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ._base import Config
+
+
+class SAML2Config(Config):
+ """SAML2 Configuration
+    Synapse uses the pysaml2 library to provide SAML2 support.
+
+ config_path: Path to the sp_conf.py configuration file
+ idp_redirect_url: Identity provider URL which will redirect
+ the user back to /login/saml2 with proper info.
+
+    An example sp_conf.py file is available at:
+ https://github.com/rohe/pysaml2/blob/master/example/sp-repoze/sp_conf.py.example
+
+ More information: https://pythonhosted.org/pysaml2/howto/config.html
+ """
+
+ def read_config(self, config):
+ saml2_config = config.get("saml2_config", None)
+ if saml2_config:
+ self.saml2_enabled = True
+ self.saml2_config_path = saml2_config["config_path"]
+ self.saml2_idp_redirect_url = saml2_config["idp_redirect_url"]
+ else:
+ self.saml2_enabled = False
+ self.saml2_config_path = None
+ self.saml2_idp_redirect_url = None
+
+ def default_config(self, config_dir_path, server_name):
+ return """
+ # Enable SAML2 for registration and login. Uses pysaml2
+ # config_path: Path to the sp_conf.py configuration file
+ # idp_redirect_url: Identity provider URL which will redirect
+ # the user back to /login/saml2 with proper info.
+ # See pysaml2 docs for format of config.
+ #saml2_config:
+ # config_path: "%s/sp_conf.py"
+ # idp_redirect_url: "http://%s/idp"
+ """ % (config_dir_path, server_name)
diff --git a/synapse/config/server.py b/synapse/config/server.py
index 48a26c65d9..f9a3b5f15b 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -20,25 +20,98 @@ class ServerConfig(Config):
def read_config(self, config):
self.server_name = config["server_name"]
- self.bind_port = config["bind_port"]
- self.bind_host = config["bind_host"]
- self.unsecure_port = config["unsecure_port"]
- self.manhole = config.get("manhole")
self.pid_file = self.abspath(config.get("pid_file"))
self.web_client = config["web_client"]
self.soft_file_limit = config["soft_file_limit"]
self.daemonize = config.get("daemonize")
+ self.print_pidfile = config.get("print_pidfile")
self.use_frozen_dicts = config.get("use_frozen_dicts", True)
+ self.listeners = config.get("listeners", [])
+
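+        # For backwards compatibility, translate any legacy bind_port,
+        # unsecure_port, manhole and metrics_port options into entries in
+        # the listeners list.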
+ bind_port = config.get("bind_port")
+ if bind_port:
+ self.listeners = []
+ bind_host = config.get("bind_host", "")
+ gzip_responses = config.get("gzip_responses", True)
+
+ names = ["client", "webclient"] if self.web_client else ["client"]
+
+ self.listeners.append({
+ "port": bind_port,
+ "bind_address": bind_host,
+ "tls": True,
+ "type": "http",
+ "resources": [
+ {
+ "names": names,
+ "compress": gzip_responses,
+ },
+ {
+ "names": ["federation"],
+ "compress": False,
+ }
+ ]
+ })
+
+ unsecure_port = config.get("unsecure_port", bind_port - 400)
+ if unsecure_port:
+ self.listeners.append({
+ "port": unsecure_port,
+ "bind_address": bind_host,
+ "tls": False,
+ "type": "http",
+ "resources": [
+ {
+ "names": names,
+ "compress": gzip_responses,
+ },
+ {
+ "names": ["federation"],
+ "compress": False,
+ }
+ ]
+ })
+
+ manhole = config.get("manhole")
+ if manhole:
+ self.listeners.append({
+ "port": manhole,
+ "bind_address": "127.0.0.1",
+ "type": "manhole",
+ })
+
+ metrics_port = config.get("metrics_port")
+ if metrics_port:
+ self.listeners.append({
+ "port": metrics_port,
+ "bind_address": config.get("metrics_bind_host", "127.0.0.1"),
+ "tls": False,
+ "type": "http",
+ "resources": [
+ {
+ "names": ["metrics"],
+ "compress": False,
+ },
+ ]
+ })
+
# Attempt to guess the content_addr for the v0 content repostitory
content_addr = config.get("content_addr")
if not content_addr:
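+            # Base the guess on the first unencrypted HTTP listener, since
+            # the v0 content repository is served over plain HTTP.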
+ for listener in self.listeners:
+ if listener["type"] == "http" and not listener.get("tls", False):
+ unsecure_port = listener["port"]
+ break
+ else:
+ raise RuntimeError("Could not determine 'content_addr'")
+
host = self.server_name
if ':' not in host:
- host = "%s:%d" % (host, self.unsecure_port)
+ host = "%s:%d" % (host, unsecure_port)
else:
host = host.split(':')[0]
- host = "%s:%d" % (host, self.unsecure_port)
+ host = "%s:%d" % (host, unsecure_port)
content_addr = "http://%s" % (host,)
self.content_addr = content_addr
@@ -60,18 +133,6 @@ class ServerConfig(Config):
# e.g. matrix.org, localhost:8080, etc.
server_name: "%(server_name)s"
- # The port to listen for HTTPS requests on.
- # For when matrix traffic is sent directly to synapse.
- bind_port: %(bind_port)s
-
- # The port to listen for HTTP requests on.
- # For when matrix traffic passes through loadbalancer that unwraps TLS.
- unsecure_port: %(unsecure_port)s
-
- # Local interface to listen on.
- # The empty string will cause synapse to listen on all interfaces.
- bind_host: ""
-
# When running as a daemon, the file to store the pid in
pid_file: %(pid_file)s
@@ -83,9 +144,64 @@ class ServerConfig(Config):
# hard limit.
soft_file_limit: 0
- # Turn on the twisted telnet manhole service on localhost on the given
- # port.
- #manhole: 9000
+ # List of ports that Synapse should listen on, their purpose and their
+ # configuration.
+ listeners:
+ # Main HTTPS listener
+ # For when matrix traffic is sent directly to synapse.
+ -
+ # The port to listen for HTTPS requests on.
+ port: %(bind_port)s
+
+ # Local interface to listen on.
+ # The empty string will cause synapse to listen on all interfaces.
+ bind_address: ''
+
+ # This is a 'http' listener, allows us to specify 'resources'.
+ type: http
+
+ tls: true
+
+ # Use the X-Forwarded-For (XFF) header as the client IP and not the
+ # actual client IP.
+ x_forwarded: false
+
+ # List of HTTP resources to serve on this listener.
+ resources:
+ -
+ # List of resources to host on this listener.
+ names:
+ - client # The client-server APIs, both v1 and v2
+ - webclient # The bundled webclient.
+
+ # Should synapse compress HTTP responses to clients that support it?
+ # This should be disabled if running synapse behind a load balancer
+ # that can do automatic compression.
+ compress: true
+
+ - names: [federation] # Federation APIs
+ compress: false
+
+      # Unsecure HTTP listener, for when matrix traffic passes through a
+      # load balancer that unwraps TLS.
+ - port: %(unsecure_port)s
+ tls: false
+ bind_address: ''
+ type: http
+
+ x_forwarded: false
+
+ resources:
+ - names: [client, webclient]
+ compress: true
+ - names: [federation]
+ compress: false
+
+ # Turn on the twisted telnet manhole service on localhost on the given
+ # port.
+ # - port: 9000
+ # bind_address: 127.0.0.1
+ # type: manhole
""" % locals()
def read_arguments(self, args):
@@ -93,12 +209,18 @@ class ServerConfig(Config):
self.manhole = args.manhole
if args.daemonize is not None:
self.daemonize = args.daemonize
+ if args.print_pidfile is not None:
+ self.print_pidfile = args.print_pidfile
def add_arguments(self, parser):
server_group = parser.add_argument_group("server")
server_group.add_argument("-D", "--daemonize", action='store_true',
default=None,
help="Daemonize the home server")
+ server_group.add_argument("--print-pidfile", action='store_true',
+ default=None,
+ help="Print the path to the pidfile just"
+ " before daemonizing")
server_group.add_argument("--manhole", metavar="PORT", dest="manhole",
type=int,
help="Turn on the twisted telnet manhole"
diff --git a/synapse/config/tls.py b/synapse/config/tls.py
index ecb2d42c1f..4751d39bc9 100644
--- a/synapse/config/tls.py
+++ b/synapse/config/tls.py
@@ -27,6 +27,7 @@ class TlsConfig(Config):
self.tls_certificate = self.read_tls_certificate(
config.get("tls_certificate_path")
)
+ self.tls_certificate_file = config.get("tls_certificate_path")
self.no_tls = config.get("no_tls", False)
@@ -49,7 +50,11 @@ class TlsConfig(Config):
tls_dh_params_path = base_key_name + ".tls.dh"
return """\
- # PEM encoded X509 certificate for TLS
+ # PEM encoded X509 certificate for TLS.
+ # You can replace the self-signed certificate that synapse
+ # autogenerates on launch with your own SSL certificate + key pair
+ # if you like. Any required intermediary certificates can be
+ # appended after the primary certificate in hierarchical order.
tls_certificate_path: "%(tls_certificate_path)s"
# PEM encoded private key for TLS
@@ -91,7 +96,7 @@ class TlsConfig(Config):
)
if not os.path.exists(tls_certificate_path):
- with open(tls_certificate_path, "w") as certifcate_file:
+ with open(tls_certificate_path, "w") as certificate_file:
cert = crypto.X509()
subject = cert.get_subject()
subject.CN = config["server_name"]
@@ -106,7 +111,7 @@ class TlsConfig(Config):
cert_pem = crypto.dump_certificate(crypto.FILETYPE_PEM, cert)
- certifcate_file.write(cert_pem)
+ certificate_file.write(cert_pem)
if not os.path.exists(tls_dh_params_path):
if GENERATE_DH_PARAMS:
diff --git a/synapse/crypto/context_factory.py b/synapse/crypto/context_factory.py
index 2f8618a0df..c4390f3b2b 100644
--- a/synapse/crypto/context_factory.py
+++ b/synapse/crypto/context_factory.py
@@ -35,9 +35,9 @@ class ServerContextFactory(ssl.ContextFactory):
_ecCurve = _OpenSSLECCurve(_defaultCurveName)
_ecCurve.addECKeyToContext(context)
except:
- logger.exception("Failed to enable eliptic curve for TLS")
+ logger.exception("Failed to enable elliptic curve for TLS")
context.set_options(SSL.OP_NO_SSLv2 | SSL.OP_NO_SSLv3)
- context.use_certificate(config.tls_certificate)
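+        # Load the certificate from its file so that any intermediary
+        # certificates appended to it are served as part of the chain.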
+ context.use_certificate_chain_file(config.tls_certificate_file)
if not config.no_tls:
context.use_privatekey(config.tls_private_key)
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index aff69c5f83..aa74d4d0cb 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -25,11 +25,13 @@ from syutil.base64util import decode_base64, encode_base64
from synapse.api.errors import SynapseError, Codes
from synapse.util.retryutils import get_retry_limiter
+from synapse.util import unwrapFirstError
from synapse.util.async import ObservableDeferred
from OpenSSL import crypto
+from collections import namedtuple
import urllib
import hashlib
import logging
@@ -38,6 +40,9 @@ import logging
logger = logging.getLogger(__name__)
+KeyGroup = namedtuple("KeyGroup", ("server_name", "group_id", "key_ids"))
+
+
class Keyring(object):
def __init__(self, hs):
self.store = hs.get_datastore()
@@ -49,141 +54,325 @@ class Keyring(object):
self.key_downloads = {}
- @defer.inlineCallbacks
def verify_json_for_server(self, server_name, json_object):
- logger.debug("Verifying for %s", server_name)
- key_ids = signature_ids(json_object, server_name)
- if not key_ids:
- raise SynapseError(
- 400,
- "Not signed with a supported algorithm",
- Codes.UNAUTHORIZED,
- )
- try:
- verify_key = yield self.get_server_verify_key(server_name, key_ids)
- except IOError as e:
- logger.warn(
- "Got IOError when downloading keys for %s: %s %s",
- server_name, type(e).__name__, str(e.message),
- )
- raise SynapseError(
- 502,
- "Error downloading keys for %s" % (server_name,),
- Codes.UNAUTHORIZED,
- )
- except Exception as e:
- logger.warn(
- "Got Exception when downloading keys for %s: %s %s",
- server_name, type(e).__name__, str(e.message),
- )
- raise SynapseError(
- 401,
- "No key for %s with id %s" % (server_name, key_ids),
- Codes.UNAUTHORIZED,
- )
+ return self.verify_json_objects_for_server(
+ [(server_name, json_object)]
+ )[0]
- try:
- verify_signed_json(json_object, server_name, verify_key)
- except:
- raise SynapseError(
- 401,
- "Invalid signature for server %s with key %s:%s" % (
- server_name, verify_key.alg, verify_key.version
- ),
- Codes.UNAUTHORIZED,
+ def verify_json_objects_for_server(self, server_and_json):
+ """Bulk verfies signatures of json objects, bulk fetching keys as
+ necessary.
+
+ Args:
+ server_and_json (list): List of pairs of (server_name, json_object)
+
+ Returns:
+ list of deferreds indicating success or failure to verify each
+ json object's signature for the given server_name.
+ """
+ group_id_to_json = {}
+ group_id_to_group = {}
+ group_ids = []
+
+ next_group_id = 0
+ deferreds = {}
+
+ for server_name, json_object in server_and_json:
+ logger.debug("Verifying for %s", server_name)
+ group_id = next_group_id
+ next_group_id += 1
+ group_ids.append(group_id)
+
+ key_ids = signature_ids(json_object, server_name)
+ if not key_ids:
+ deferreds[group_id] = defer.fail(SynapseError(
+ 400,
+ "Not signed with a supported algorithm",
+ Codes.UNAUTHORIZED,
+ ))
+ else:
+ deferreds[group_id] = defer.Deferred()
+
+ group = KeyGroup(server_name, group_id, key_ids)
+
+ group_id_to_group[group_id] = group
+ group_id_to_json[group_id] = json_object
+
+ @defer.inlineCallbacks
+ def handle_key_deferred(group, deferred):
+ server_name = group.server_name
+ try:
+ _, _, key_id, verify_key = yield deferred
+ except IOError as e:
+ logger.warn(
+ "Got IOError when downloading keys for %s: %s %s",
+ server_name, type(e).__name__, str(e.message),
+ )
+ raise SynapseError(
+ 502,
+ "Error downloading keys for %s" % (server_name,),
+ Codes.UNAUTHORIZED,
+ )
+ except Exception as e:
+ logger.exception(
+ "Got Exception when downloading keys for %s: %s %s",
+ server_name, type(e).__name__, str(e.message),
+ )
+ raise SynapseError(
+ 401,
+ "No key for %s with id %s" % (server_name, key_ids),
+ Codes.UNAUTHORIZED,
+ )
+
+ json_object = group_id_to_json[group.group_id]
+
+ try:
+ verify_signed_json(json_object, server_name, verify_key)
+ except:
+ raise SynapseError(
+ 401,
+ "Invalid signature for server %s with key %s:%s" % (
+ server_name, verify_key.alg, verify_key.version
+ ),
+ Codes.UNAUTHORIZED,
+ )
+
+ server_to_deferred = {
+ server_name: defer.Deferred()
+ for server_name, _ in server_and_json
+ }
+
+ # We want to wait for any previous lookups to complete before
+ # proceeding.
+ wait_on_deferred = self.wait_for_previous_lookups(
+ [server_name for server_name, _ in server_and_json],
+ server_to_deferred,
+ )
+
+ # Actually start fetching keys.
+ wait_on_deferred.addBoth(
+ lambda _: self.get_server_verify_keys(group_id_to_group, deferreds)
+ )
+
+ # When we've finished fetching all the keys for a given server_name,
+ # resolve the deferred passed to `wait_for_previous_lookups` so that
+ # any lookups waiting will proceed.
+ server_to_gids = {}
+
+ def remove_deferreds(res, server_name, group_id):
+ server_to_gids[server_name].discard(group_id)
+ if not server_to_gids[server_name]:
+ server_to_deferred.pop(server_name).callback(None)
+ return res
+
+ for g_id, deferred in deferreds.items():
+ server_name = group_id_to_group[g_id].server_name
+ server_to_gids.setdefault(server_name, set()).add(g_id)
+ deferred.addBoth(remove_deferreds, server_name, g_id)
+
+ # Pass those keys to handle_key_deferred so that the json object
+ # signatures can be verified
+ return [
+ handle_key_deferred(
+ group_id_to_group[g_id],
+ deferreds[g_id],
)
+ for g_id in group_ids
+ ]
@defer.inlineCallbacks
- def get_server_verify_key(self, server_name, key_ids):
- """Finds a verification key for the server with one of the key ids.
- Trys to fetch the key from a trusted perspective server first.
+ def wait_for_previous_lookups(self, server_names, server_to_deferred):
+ """Waits for any previous key lookups for the given servers to finish.
+
Args:
- server_name(str): The name of the server to fetch a key for.
- keys_ids (list of str): The key_ids to check for.
+ server_names (list): list of server_names we want to lookup
+ server_to_deferred (dict): server_name to deferred which gets
+ resolved once we've finished looking up keys for that server
+ """
+ while True:
+ wait_on = [
+ self.key_downloads[server_name]
+ for server_name in server_names
+ if server_name in self.key_downloads
+ ]
+ if wait_on:
+ yield defer.DeferredList(wait_on)
+ else:
+ break
+
+ for server_name, deferred in server_to_deferred.items():
+ self.key_downloads[server_name] = ObservableDeferred(deferred)
+
+ def get_server_verify_keys(self, group_id_to_group, group_id_to_deferred):
+ """Takes a dict of KeyGroups and tries to find at least one key for
+ each group.
"""
- cached = yield self.store.get_server_verify_keys(server_name, key_ids)
- if cached:
- defer.returnValue(cached[0])
- return
+ # These are functions that produce keys given a list of key ids
+ key_fetch_fns = (
+ self.get_keys_from_store, # First try the local store
+ self.get_keys_from_perspectives, # Then try via perspectives
+ self.get_keys_from_server, # Then try directly
+ )
+
+ @defer.inlineCallbacks
+ def do_iterations():
+ merged_results = {}
+
+ missing_keys = {}
+ for group in group_id_to_group.values():
+ missing_keys.setdefault(group.server_name, set()).update(
+ group.key_ids
+ )
+
+ for fn in key_fetch_fns:
+ results = yield fn(missing_keys.items())
+ merged_results.update(results)
+
+ # We now need to figure out which groups we have keys for
+ # and which we don't
+ missing_groups = {}
+ for group in group_id_to_group.values():
+ # Skip groups resolved by an earlier fetch function so that
+ # we never fire the same deferred twice.
+ if group_id_to_deferred[group.group_id].called:
+ continue
+ for key_id in group.key_ids:
+ if key_id in merged_results[group.server_name]:
+ group_id_to_deferred[group.group_id].callback((
+ group.group_id,
+ group.server_name,
+ key_id,
+ merged_results[group.server_name][key_id],
+ ))
+ break
+ else:
+ missing_groups.setdefault(
+ group.server_name, []
+ ).append(group)
+
+ if not missing_groups:
+ break
+
+ missing_keys = {
+ server_name: set(
+ key_id for group in groups for key_id in group.key_ids
+ )
+ for server_name, groups in missing_groups.items()
+ }
- download = self.key_downloads.get(server_name)
+ for groups in missing_groups.values():
+ for group in groups:
+ group_id_to_deferred[group.group_id].errback(SynapseError(
+ 401,
+ "No key for %s with id %s" % (
+ group.server_name, group.key_ids,
+ ),
+ Codes.UNAUTHORIZED,
+ ))
- if download is None:
- download = self._get_server_verify_key_impl(server_name, key_ids)
- download = ObservableDeferred(
- download,
- consumeErrors=True
- )
- self.key_downloads[server_name] = download
+ def on_err(err):
+ for deferred in group_id_to_deferred.values():
+ if not deferred.called:
+ deferred.errback(err)
- @download.addBoth
- def callback(ret):
- del self.key_downloads[server_name]
- return ret
+ do_iterations().addErrback(on_err)
- r = yield download.observe()
- defer.returnValue(r)
+ return group_id_to_deferred
@defer.inlineCallbacks
- def _get_server_verify_key_impl(self, server_name, key_ids):
- keys = None
+ def get_keys_from_store(self, server_name_and_key_ids):
+ res = yield defer.gatherResults(
+ [
+ self.store.get_server_verify_keys(server_name, key_ids)
+ for server_name, key_ids in server_name_and_key_ids
+ ],
+ consumeErrors=True,
+ ).addErrback(unwrapFirstError)
+
+ defer.returnValue(dict(zip(
+ [server_name for server_name, _ in server_name_and_key_ids],
+ res
+ )))
+ @defer.inlineCallbacks
+ def get_keys_from_perspectives(self, server_name_and_key_ids):
@defer.inlineCallbacks
def get_key(perspective_name, perspective_keys):
try:
result = yield self.get_server_verify_key_v2_indirect(
- server_name, key_ids, perspective_name, perspective_keys
+ server_name_and_key_ids, perspective_name, perspective_keys
)
defer.returnValue(result)
except Exception as e:
- logging.info(
- "Unable to getting key %r for %r from %r: %s %s",
- key_ids, server_name, perspective_name,
+ logger.exception(
+ "Unable to get key from %r: %s %s",
+ perspective_name,
type(e).__name__, str(e.message),
)
+ defer.returnValue({})
- perspective_results = yield defer.gatherResults([
- get_key(p_name, p_keys)
- for p_name, p_keys in self.perspective_servers.items()
- ])
+ results = yield defer.gatherResults(
+ [
+ get_key(p_name, p_keys)
+ for p_name, p_keys in self.perspective_servers.items()
+ ],
+ consumeErrors=True,
+ ).addErrback(unwrapFirstError)
- for results in perspective_results:
- if results is not None:
- keys = results
+ union_of_keys = {}
+ for result in results:
+ for server_name, keys in result.items():
+ union_of_keys.setdefault(server_name, {}).update(keys)
- limiter = yield get_retry_limiter(
- server_name,
- self.clock,
- self.store,
- )
+ defer.returnValue(union_of_keys)
- with limiter:
- if not keys:
+ @defer.inlineCallbacks
+ def get_keys_from_server(self, server_name_and_key_ids):
+ @defer.inlineCallbacks
+ def get_key(server_name, key_ids):
+ limiter = yield get_retry_limiter(
+ server_name,
+ self.clock,
+ self.store,
+ )
+ with limiter:
+ keys = None
try:
keys = yield self.get_server_verify_key_v2_direct(
server_name, key_ids
)
except Exception as e:
- logging.info(
- "Unable to getting key %r for %r directly: %s %s",
+ logger.info(
+ "Unable to get key %r for %r directly: %s %s",
key_ids, server_name,
type(e).__name__, str(e.message),
)
- if not keys:
- keys = yield self.get_server_verify_key_v1_direct(
- server_name, key_ids
- )
+ if not keys:
+ keys = yield self.get_server_verify_key_v1_direct(
+ server_name, key_ids
+ )
+
+ keys = {server_name: keys}
+
+ defer.returnValue(keys)
+
+ results = yield defer.gatherResults(
+ [
+ get_key(server_name, key_ids)
+ for server_name, key_ids in server_name_and_key_ids
+ ],
+ consumeErrors=True,
+ ).addErrback(unwrapFirstError)
- for key_id in key_ids:
- if key_id in keys:
- defer.returnValue(keys[key_id])
- return
- raise ValueError("No verification key found for given key ids")
+ merged = {}
+ for result in results:
+ merged.update(result)
+
+ defer.returnValue({
+ server_name: keys
+ for server_name, keys in merged.items()
+ if keys
+ })
@defer.inlineCallbacks
- def get_server_verify_key_v2_indirect(self, server_name, key_ids,
+ def get_server_verify_key_v2_indirect(self, server_names_and_key_ids,
perspective_name,
perspective_keys):
limiter = yield get_retry_limiter(
@@ -204,6 +393,7 @@ class Keyring(object):
u"minimum_valid_until_ts": 0
} for key_id in key_ids
}
+ for server_name, key_ids in server_names_and_key_ids
}
},
)
@@ -243,23 +433,29 @@ class Keyring(object):
" server %r" % (perspective_name,)
)
- response_keys = yield self.process_v2_response(
- server_name, perspective_name, response
+ processed_response = yield self.process_v2_response(
+ perspective_name, response
)
- keys.update(response_keys)
+ for server_name, response_keys in processed_response.items():
+ keys.setdefault(server_name, {}).update(response_keys)
- yield self.store_keys(
- server_name=server_name,
- from_server=perspective_name,
- verify_keys=keys,
- )
+ yield defer.gatherResults(
+ [
+ self.store_keys(
+ server_name=server_name,
+ from_server=perspective_name,
+ verify_keys=response_keys,
+ )
+ for server_name, response_keys in keys.items()
+ ],
+ consumeErrors=True
+ ).addErrback(unwrapFirstError)
defer.returnValue(keys)
@defer.inlineCallbacks
def get_server_verify_key_v2_direct(self, server_name, key_ids):
-
keys = {}
for requested_key_id in key_ids:
@@ -295,25 +491,30 @@ class Keyring(object):
raise ValueError("TLS certificate not allowed by fingerprints")
response_keys = yield self.process_v2_response(
- server_name=server_name,
from_server=server_name,
- requested_id=requested_key_id,
+ requested_ids=[requested_key_id],
response_json=response,
)
keys.update(response_keys)
- yield self.store_keys(
- server_name=server_name,
- from_server=server_name,
- verify_keys=keys,
- )
+ yield defer.gatherResults(
+ [
+ self.store_keys(
+ server_name=key_server_name,
+ from_server=server_name,
+ verify_keys=verify_keys,
+ )
+ for key_server_name, verify_keys in keys.items()
+ ],
+ consumeErrors=True
+ ).addErrback(unwrapFirstError)
defer.returnValue(keys)
@defer.inlineCallbacks
- def process_v2_response(self, server_name, from_server, response_json,
- requested_id=None):
+ def process_v2_response(self, from_server, response_json,
+ requested_ids=[]):
time_now_ms = self.clock.time_msec()
response_keys = {}
verify_keys = {}
@@ -335,6 +536,8 @@ class Keyring(object):
verify_key.time_added = time_now_ms
old_verify_keys[key_id] = verify_key
+ results = {}
+ server_name = response_json["server_name"]
for key_id in response_json["signatures"].get(server_name, {}):
if key_id not in response_json["verify_keys"]:
raise ValueError(
@@ -357,28 +560,31 @@ class Keyring(object):
signed_key_json_bytes = encode_canonical_json(signed_key_json)
ts_valid_until_ms = signed_key_json[u"valid_until_ts"]
- updated_key_ids = set()
- if requested_id is not None:
- updated_key_ids.add(requested_id)
+ updated_key_ids = set(requested_ids)
updated_key_ids.update(verify_keys)
updated_key_ids.update(old_verify_keys)
response_keys.update(verify_keys)
response_keys.update(old_verify_keys)
- for key_id in updated_key_ids:
- yield self.store.store_server_keys_json(
- server_name=server_name,
- key_id=key_id,
- from_server=server_name,
- ts_now_ms=time_now_ms,
- ts_expires_ms=ts_valid_until_ms,
- key_json_bytes=signed_key_json_bytes,
- )
+ yield defer.gatherResults(
+ [
+ self.store.store_server_keys_json(
+ server_name=server_name,
+ key_id=key_id,
+ from_server=server_name,
+ ts_now_ms=time_now_ms,
+ ts_expires_ms=ts_valid_until_ms,
+ key_json_bytes=signed_key_json_bytes,
+ )
+ for key_id in updated_key_ids
+ ],
+ consumeErrors=True,
+ ).addErrback(unwrapFirstError)
- defer.returnValue(response_keys)
+ results[server_name] = response_keys
- raise ValueError("No verification key found for given key ids")
+ defer.returnValue(results)
@defer.inlineCallbacks
def get_server_verify_key_v1_direct(self, server_name, key_ids):
@@ -462,8 +668,13 @@ class Keyring(object):
Returns:
A deferred that completes when the keys are stored.
"""
- for key_id, key in verify_keys.items():
- # TODO(markjh): Store whether the keys have expired.
- yield self.store.store_server_verify_key(
- server_name, server_name, key.time_added, key
- )
+ # TODO(markjh): Store whether the keys have expired.
+ yield defer.gatherResults(
+ [
+ self.store.store_server_verify_key(
+ server_name, server_name, key.time_added, key
+ )
+ for key_id, key in verify_keys.items()
+ ],
+ consumeErrors=True,
+ ).addErrback(unwrapFirstError)
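
A note on the bulk flow above: ``verify_json_objects_for_server`` gives every
(server_name, json_object) pair its own group id and deferred, and
``get_server_verify_keys`` then walks the key sources in order until each
group has at least one usable key. A minimal synchronous sketch of that
iteration; the fetcher callables here are hypothetical stand-ins for the
store/perspectives/direct lookups::

    def resolve_key_groups(groups, key_fetch_fns):
        # groups: {group_id: (server_name, [key_id, ...])}
        # key_fetch_fns: callables taking {server_name: set(key_ids)} and
        # returning {server_name: {key_id: verify_key}}
        resolved = {}
        merged = {}
        missing = dict(groups)
        for fetch in key_fetch_fns:
            wanted = {}
            for server_name, key_ids in missing.values():
                wanted.setdefault(server_name, set()).update(key_ids)
            for server_name, keys in fetch(wanted).items():
                merged.setdefault(server_name, {}).update(keys)
            for group_id, (server_name, key_ids) in list(missing.items()):
                for key_id in key_ids:
                    if key_id in merged.get(server_name, {}):
                        resolved[group_id] = (server_name, key_id,
                                              merged[server_name][key_id])
                        missing.pop(group_id)
                        break
            if not missing:
                break
        return resolved, missing

    fetch_from_store = lambda wanted: {"example.org": {"ed25519:a1": "<key>"}}
    resolved, unresolved = resolve_key_groups(
        {1: ("example.org", ["ed25519:a1"])}, [fetch_from_store],
    )
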
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index 1aa952150e..7bd78343f0 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -74,6 +74,8 @@ def prune_event(event):
)
elif event_type == EventTypes.Aliases:
add_fields("aliases")
+ elif event_type == EventTypes.RoomHistoryVisibility:
+ add_fields("history_visibility")
allowed_fields = {
k: v
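
For context, ``prune_event`` redacts an event down to a per-type whitelist of
content keys; the hunk above adds ``history_visibility`` to that whitelist so
a redacted visibility event still carries its setting. A rough sketch of the
rule on plain dicts, using the standard Matrix event type strings as
assumptions::

    KEEP_CONTENT_KEYS = {
        "m.room.aliases": ["aliases"],
        "m.room.history_visibility": ["history_visibility"],
    }

    def prune_content(event_type, content):
        allowed = KEEP_CONTENT_KEYS.get(event_type, [])
        return dict((k, v) for k, v in content.items() if k in allowed)

    # The visibility setting survives redaction; everything else is dropped.
    pruned = prune_content(
        "m.room.history_visibility",
        {"history_visibility": "shared", "reason": "spam"},
    )
    assert pruned == {"history_visibility": "shared"}
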
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index f0430b2cb1..bdfa247604 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -18,8 +18,6 @@ from twisted.internet import defer
from synapse.events.utils import prune_event
-from syutil.jsonutil import encode_canonical_json
-
from synapse.crypto.event_signing import check_event_content_hash
from synapse.api.errors import SynapseError
@@ -34,7 +32,8 @@ logger = logging.getLogger(__name__)
class FederationBase(object):
@defer.inlineCallbacks
- def _check_sigs_and_hash_and_fetch(self, origin, pdus, outlier=False):
+ def _check_sigs_and_hash_and_fetch(self, origin, pdus, outlier=False,
+ include_none=False):
"""Takes a list of PDUs and checks the signatures and hashs of each
one. If a PDU fails its signature check then we check if we have it in
the database and if not then request if from the originating server of
@@ -52,85 +51,108 @@ class FederationBase(object):
Returns:
Deferred : A list of PDUs that have valid signatures and hashes.
"""
+ deferreds = self._check_sigs_and_hashes(pdus)
- signed_pdus = []
+ def callback(pdu):
+ return pdu
- @defer.inlineCallbacks
- def do(pdu):
- try:
- new_pdu = yield self._check_sigs_and_hash(pdu)
- signed_pdus.append(new_pdu)
- except SynapseError:
- # FIXME: We should handle signature failures more gracefully.
+ def errback(failure, pdu):
+ failure.trap(SynapseError)
+ return None
+ def try_local_db(res, pdu):
+ if not res:
# Check local db.
- new_pdu = yield self.store.get_event(
+ return self.store.get_event(
pdu.event_id,
allow_rejected=True,
allow_none=True,
)
- if new_pdu:
- signed_pdus.append(new_pdu)
- return
-
- # Check pdu.origin
- if pdu.origin != origin:
- try:
- new_pdu = yield self.get_pdu(
- destinations=[pdu.origin],
- event_id=pdu.event_id,
- outlier=outlier,
- timeout=10000,
- )
-
- if new_pdu:
- signed_pdus.append(new_pdu)
- return
- except:
- pass
-
+ return res
+
+ def try_remote(res, pdu):
+ if not res and pdu.origin != origin:
+ return self.get_pdu(
+ destinations=[pdu.origin],
+ event_id=pdu.event_id,
+ outlier=outlier,
+ timeout=10000,
+ ).addErrback(lambda e: None)
+ return res
+
+ def warn(res, pdu):
+ if not res:
logger.warn(
"Failed to find copy of %s with valid signature",
pdu.event_id,
)
+ return res
+
+ for pdu, deferred in zip(pdus, deferreds):
+ deferred.addCallbacks(
+ callback, errback, errbackArgs=[pdu]
+ ).addCallback(
+ try_local_db, pdu
+ ).addCallback(
+ try_remote, pdu
+ ).addCallback(
+ warn, pdu
+ )
- yield defer.gatherResults(
- [do(pdu) for pdu in pdus],
+ valid_pdus = yield defer.gatherResults(
+ deferreds,
consumeErrors=True
).addErrback(unwrapFirstError)
- defer.returnValue(signed_pdus)
+ if include_none:
+ defer.returnValue(valid_pdus)
+ else:
+ defer.returnValue([p for p in valid_pdus if p])
- @defer.inlineCallbacks
def _check_sigs_and_hash(self, pdu):
- """Throws a SynapseError if the PDU does not have the correct
+ return self._check_sigs_and_hashes([pdu])[0]
+
+ def _check_sigs_and_hashes(self, pdus):
+ """Throws a SynapseError if a PDU does not have the correct
signatures.
Returns:
FrozenEvent: Either the given event or it redacted if it failed the
content hash check.
"""
- # Check signatures are correct.
- redacted_event = prune_event(pdu)
- redacted_pdu_json = redacted_event.get_pdu_json()
- try:
- yield self.keyring.verify_json_for_server(
- pdu.origin, redacted_pdu_json
- )
- except SynapseError:
+ redacted_pdus = [
+ prune_event(pdu)
+ for pdu in pdus
+ ]
+
+ deferreds = self.keyring.verify_json_objects_for_server([
+ (p.origin, p.get_pdu_json())
+ for p in redacted_pdus
+ ])
+
+ def callback(_, pdu, redacted):
+ if not check_event_content_hash(pdu):
+ logger.warn(
+ "Event content has been tampered, redacting %s: %s",
+ pdu.event_id, pdu.get_pdu_json()
+ )
+ return redacted
+ return pdu
+
+ def errback(failure, pdu):
+ failure.trap(SynapseError)
logger.warn(
- "Signature check failed for %s redacted to %s",
- encode_canonical_json(pdu.get_pdu_json()),
- encode_canonical_json(redacted_pdu_json),
+ "Signature check failed for %s",
+ pdu.event_id,
)
- raise
+ return failure
- if not check_event_content_hash(pdu):
- logger.warn(
- "Event content has been tampered, redacting %s, %s",
- pdu.event_id, encode_canonical_json(pdu.get_dict())
+ for deferred, pdu, redacted in zip(deferreds, pdus, redacted_pdus):
+ deferred.addCallbacks(
+ callback, errback,
+ callbackArgs=[pdu, redacted],
+ errbackArgs=[pdu],
)
- defer.returnValue(redacted_event)
- defer.returnValue(pdu)
+ return deferreds
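
The rewrite above gives each PDU its own deferred with a chain of fallbacks:
verify the signature; on failure, look for a trusted copy in the local
database; failing that, ask the event's claimed origin (when it differs from
the sending server); finally warn and drop. A synchronous sketch of that
chain, with ``check_sig``, ``load_local`` and ``fetch_remote`` as hypothetical
stand-ins::

    def find_valid_copy(pdu, origin, check_sig, load_local, fetch_remote):
        try:
            return check_sig(pdu)       # valid as received
        except ValueError:              # stand-in for the signature error
            event = load_local(pdu["event_id"])
            if event is not None:
                return event            # we already hold a trusted copy
            if pdu["origin"] != origin:
                try:
                    return fetch_remote(pdu["origin"], pdu["event_id"])
                except IOError:
                    pass
            return None                 # caller logs a warning and drops it
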
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index d3b46b24c1..7736d14fb5 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -30,6 +30,7 @@ import synapse.metrics
from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination
+import copy
import itertools
import logging
import random
@@ -167,7 +168,7 @@ class FederationClient(FederationBase):
# FIXME: We should handle signature failures more gracefully.
pdus[:] = yield defer.gatherResults(
- [self._check_sigs_and_hash(pdu) for pdu in pdus],
+ self._check_sigs_and_hashes(pdus),
consumeErrors=True,
).addErrback(unwrapFirstError)
@@ -230,7 +231,7 @@ class FederationClient(FederationBase):
pdu = pdu_list[0]
# Check signatures are correct.
- pdu = yield self._check_sigs_and_hash(pdu)
+ pdu = yield self._check_sigs_and_hashes([pdu])[0]
break
@@ -327,6 +328,9 @@ class FederationClient(FederationBase):
@defer.inlineCallbacks
def make_join(self, destinations, room_id, user_id):
for destination in destinations:
+ if destination == self.server_name:
+ continue
+
try:
ret = yield self.transport_layer.make_join(
destination, room_id, user_id
@@ -353,6 +357,9 @@ class FederationClient(FederationBase):
@defer.inlineCallbacks
def send_join(self, destinations, pdu):
for destination in destinations:
+ if destination == self.server_name:
+ continue
+
try:
time_now = self._clock.time_msec()
_, content = yield self.transport_layer.send_join(
@@ -374,17 +381,39 @@ class FederationClient(FederationBase):
for p in content.get("auth_chain", [])
]
- signed_state, signed_auth = yield defer.gatherResults(
- [
- self._check_sigs_and_hash_and_fetch(
- destination, state, outlier=True
- ),
- self._check_sigs_and_hash_and_fetch(
- destination, auth_chain, outlier=True
- )
- ],
- consumeErrors=True
- ).addErrback(unwrapFirstError)
+ pdus = {
+ p.event_id: p
+ for p in itertools.chain(state, auth_chain)
+ }
+
+ valid_pdus = yield self._check_sigs_and_hash_and_fetch(
+ destination, pdus.values(),
+ outlier=True,
+ )
+
+ valid_pdus_map = {
+ p.event_id: p
+ for p in valid_pdus
+ }
+
+ # NB: We *need* to copy to ensure that we don't have multiple
+ # references being passed on: mutating one shared event object
+ # would affect every other holder of it.
+ signed_state = [
+ copy.copy(valid_pdus_map[p.event_id])
+ for p in state
+ if p.event_id in valid_pdus_map
+ ]
+
+ signed_auth = [
+ valid_pdus_map[p.event_id]
+ for p in auth_chain
+ if p.event_id in valid_pdus_map
+ ]
+
+ # NB: The internal_metadata must be copied too, for the same
+ # reason: it must not be shared between references.
+ for s in signed_state:
+ s.internal_metadata = copy.deepcopy(s.internal_metadata)
auth_chain.sort(key=lambda e: e.depth)
@@ -396,7 +425,7 @@ class FederationClient(FederationBase):
except CodeMessageException:
raise
except Exception as e:
- logger.warn(
+ logger.exception(
"Failed to send_join via %s: %s",
destination, e.message
)
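
``send_join`` now verifies the union of the returned state and auth chain in
a single bulk call, then rebuilds both lists from the surviving event ids;
the shallow copies matter because the same event object can appear in both
lists. A sketch of the reassembly on plain dicts::

    import copy

    def split_valid(state, auth_chain, valid_ids):
        valid_ids = set(valid_ids)
        pdus = dict((p["event_id"], p) for p in state + auth_chain)
        signed_state = [
            copy.copy(pdus[p["event_id"]])
            for p in state if p["event_id"] in valid_ids
        ]
        signed_auth = [
            pdus[p["event_id"]]
            for p in auth_chain if p["event_id"] in valid_ids
        ]
        return signed_state, signed_auth
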
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index af87805f34..bad93c6b2f 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -93,6 +93,9 @@ class TransportLayerServer(object):
yield self.keyring.verify_json_for_server(origin, json_request)
+ logger.info("Request from %s", origin)
+ request.authenticated_entity = origin
+
defer.returnValue((origin, content))
@log_function
diff --git a/synapse/handlers/__init__.py b/synapse/handlers/__init__.py
index 685792dbdc..dc5b6ef79d 100644
--- a/synapse/handlers/__init__.py
+++ b/synapse/handlers/__init__.py
@@ -32,6 +32,7 @@ from .appservice import ApplicationServicesHandler
from .sync import SyncHandler
from .auth import AuthHandler
from .identity import IdentityHandler
+from .receipts import ReceiptsHandler
class Handlers(object):
@@ -57,6 +58,7 @@ class Handlers(object):
self.directory_handler = DirectoryHandler(hs)
self.typing_notification_handler = TypingNotificationHandler(hs)
self.admin_handler = AdminHandler(hs)
+ self.receipts_handler = ReceiptsHandler(hs)
asapi = ApplicationServiceApi(hs)
self.appservice_handler = ApplicationServicesHandler(
hs, asapi, AppServiceScheduler(
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index 833ff41377..d6c064b398 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -78,7 +78,9 @@ class BaseHandler(object):
context = yield state_handler.compute_event_context(builder)
if builder.is_state():
- builder.prev_state = context.prev_state_events
+ builder.prev_state = yield self.store.add_event_hashes(
+ context.prev_state_events
+ )
yield self.auth.add_auth_events(builder, context)
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 8269482e47..1240e51649 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -177,7 +177,7 @@ class ApplicationServicesHandler(object):
return
user_info = yield self.store.get_user_by_id(user_id)
- if not user_info:
+ if user_info:
defer.returnValue(False)
return
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 63071653a3..1ecf7fef17 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -85,8 +85,10 @@ class AuthHandler(BaseHandler):
# email auth link on there). It's probably too open to abuse
# because it lets unauthenticated clients store arbitrary objects
# on a home server.
- # sess['clientdict'] = clientdict
- # self._save_session(sess)
+ # Revisit: Assuming the REST APIs do sensible validation, the data
+ # isn't arbitrary.
+ sess['clientdict'] = clientdict
+ self._save_session(sess)
pass
elif 'clientdict' in sess:
clientdict = sess['clientdict']
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 46ce3699d7..f7155fd8d3 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -31,6 +31,8 @@ from synapse.crypto.event_signing import (
)
from synapse.types import UserID
+from synapse.events.utils import prune_event
+
from synapse.util.retryutils import NotRetryingDestination
from twisted.internet import defer
@@ -138,26 +140,29 @@ class FederationHandler(BaseHandler):
if state and auth_chain is not None:
# If we have any state or auth_chain given to us by the replication
# layer, then we should handle them (if we haven't before.)
+
+ event_infos = []
+
for e in itertools.chain(auth_chain, state):
if e.event_id in seen_ids:
continue
-
e.internal_metadata.outlier = True
- try:
- auth_ids = [e_id for e_id, _ in e.auth_events]
- auth = {
- (e.type, e.state_key): e for e in auth_chain
- if e.event_id in auth_ids
- }
- yield self._handle_new_event(
- origin, e, auth_events=auth
- )
- seen_ids.add(e.event_id)
- except:
- logger.exception(
- "Failed to handle state event %s",
- e.event_id,
- )
+ auth_ids = [e_id for e_id, _ in e.auth_events]
+ auth = {
+ (e.type, e.state_key): e for e in auth_chain
+ if e.event_id in auth_ids
+ }
+ event_infos.append({
+ "event": e,
+ "auth_events": auth,
+ })
+ seen_ids.add(e.event_id)
+
+ yield self._handle_new_events(
+ origin,
+ event_infos,
+ outliers=True
+ )
try:
_, event_stream_id, max_stream_id = yield self._handle_new_event(
@@ -222,6 +227,56 @@ class FederationHandler(BaseHandler):
"user_joined_room", user=user, room_id=event.room_id
)
+ @defer.inlineCallbacks
+ def _filter_events_for_server(self, server_name, room_id, events):
+ states = yield self.store.get_state_for_events(
+ room_id, [e.event_id for e in events],
+ )
+
+ events_and_states = zip(events, states)
+
+ def redact_disallowed(event_and_state):
+ event, state = event_and_state
+
+ if not state:
+ return event
+
+ history = state.get((EventTypes.RoomHistoryVisibility, ''), None)
+ if history:
+ visibility = history.content.get("history_visibility", "shared")
+ if visibility in ["invited", "joined"]:
+ # We now loop through all state events looking for
+ # membership states for the requesting server to determine
+ # if the server is either in the room or has been invited
+ # into the room.
+ for ev in state.values():
+ if ev.type != EventTypes.Member:
+ continue
+ try:
+ domain = UserID.from_string(ev.state_key).domain
+ except:
+ continue
+
+ if domain != server_name:
+ continue
+
+ memtype = ev.membership
+ if memtype == Membership.JOIN:
+ return event
+ elif memtype == Membership.INVITE:
+ if visibility == "invited":
+ return event
+ else:
+ return prune_event(event)
+
+ return event
+
+ res = map(redact_disallowed, events_and_states)
+
+ logger.info("_filter_events_for_server %r", res)
+
+ defer.returnValue(res)
+
@log_function
@defer.inlineCallbacks
def backfill(self, dest, room_id, limit, extremities=[]):
@@ -247,9 +302,15 @@ class FederationHandler(BaseHandler):
if set(e_id for e_id, _ in ev.prev_events) - event_ids
]
+ logger.info(
+ "backfill: Got %d events with %d edges",
+ len(events), len(edges),
+ )
+
# For each edge get the current state.
auth_events = {}
+ state_events = {}
events_to_state = {}
for e_id in edges:
state, auth = yield self.replication_layer.get_state_for_room(
@@ -258,27 +319,57 @@ class FederationHandler(BaseHandler):
event_id=e_id
)
auth_events.update({a.event_id: a for a in auth})
+ auth_events.update({s.event_id: s for s in state})
+ state_events.update({s.event_id: s for s in state})
events_to_state[e_id] = state
- yield defer.gatherResults(
- [
- self._handle_new_event(dest, a)
- for a in auth_events.values()
- ],
- consumeErrors=True,
- ).addErrback(unwrapFirstError)
+ seen_events = yield self.store.have_events(
+ set(auth_events.keys()) | set(state_events.keys())
+ )
+
+ all_events = events + state_events.values() + auth_events.values()
+ required_auth = set(
+ a_id for event in all_events for a_id, _ in event.auth_events
+ )
- yield defer.gatherResults(
+ missing_auth = required_auth - set(auth_events)
+ results = yield defer.gatherResults(
[
- self._handle_new_event(
- dest, event_map[e_id],
- state=events_to_state[e_id],
- backfilled=True,
+ self.replication_layer.get_pdu(
+ [dest],
+ event_id,
+ outlier=True,
+ timeout=10000,
)
- for e_id in events_to_state
+ for event_id in missing_auth
],
consumeErrors=True
).addErrback(unwrapFirstError)
+ auth_events.update({a.event_id: a for a in results})
+
+ ev_infos = []
+ for a in auth_events.values():
+ if a.event_id in seen_events:
+ continue
+ ev_infos.append({
+ "event": a,
+ "auth_events": {
+ (auth_events[a_id].type, auth_events[a_id].state_key):
+ auth_events[a_id]
+ for a_id, _ in a.auth_events
+ }
+ })
+
+ for e_id in events_to_state:
+ ev_infos.append({
+ "event": event_map[e_id],
+ "state": events_to_state[e_id],
+ "auth_events": {
+ (auth_events[a_id].type, auth_events[a_id].state_key):
+ auth_events[a_id]
+ for a_id, _ in event_map[e_id].auth_events
+ }
+ })
events.sort(key=lambda e: e.depth)
@@ -286,10 +377,14 @@ class FederationHandler(BaseHandler):
if event in events_to_state:
continue
- yield self._handle_new_event(
- dest, event,
- backfilled=True,
- )
+ ev_infos.append({
+ "event": event,
+ })
+
+ yield self._handle_new_events(
+ dest, ev_infos,
+ backfilled=True,
+ )
defer.returnValue(events)
@@ -555,32 +650,22 @@ class FederationHandler(BaseHandler):
# FIXME
pass
- yield self._handle_auth_events(
- origin, [e for e in auth_chain if e.event_id != event.event_id]
- )
-
- @defer.inlineCallbacks
- def handle_state(e):
+ ev_infos = []
+ for e in itertools.chain(state, auth_chain):
if e.event_id == event.event_id:
- return
+ continue
e.internal_metadata.outlier = True
- try:
- auth_ids = [e_id for e_id, _ in e.auth_events]
- auth = {
+ auth_ids = [e_id for e_id, _ in e.auth_events]
+ ev_infos.append({
+ "event": e,
+ "auth_events": {
(e.type, e.state_key): e for e in auth_chain
if e.event_id in auth_ids
}
- yield self._handle_new_event(
- origin, e, auth_events=auth
- )
- except:
- logger.exception(
- "Failed to handle state event %s",
- e.event_id,
- )
+ })
- yield defer.DeferredList([handle_state(e) for e in state])
+ yield self._handle_new_events(origin, ev_infos, outliers=True)
auth_ids = [e_id for e_id, _ in event.auth_events]
auth_events = {
@@ -837,6 +922,8 @@ class FederationHandler(BaseHandler):
limit
)
+ events = yield self._filter_events_for_server(origin, room_id, events)
+
defer.returnValue(events)
@defer.inlineCallbacks
@@ -895,24 +982,62 @@ class FederationHandler(BaseHandler):
def _handle_new_event(self, origin, event, state=None, backfilled=False,
current_state=None, auth_events=None):
- logger.debug(
- "_handle_new_event: %s, sigs: %s",
- event.event_id, event.signatures,
+ outlier = event.internal_metadata.is_outlier()
+
+ context = yield self._prep_event(
+ origin, event,
+ state=state,
+ backfilled=backfilled,
+ current_state=current_state,
+ auth_events=auth_events,
)
- context = yield self.state_handler.compute_event_context(
- event, old_state=state
+ event_stream_id, max_stream_id = yield self.store.persist_event(
+ event,
+ context=context,
+ backfilled=backfilled,
+ is_new_state=(not outlier and not backfilled),
+ current_state=current_state,
)
- if not auth_events:
- auth_events = context.current_state
+ defer.returnValue((context, event_stream_id, max_stream_id))
- logger.debug(
- "_handle_new_event: %s, auth_events: %s",
- event.event_id, auth_events,
+ @defer.inlineCallbacks
+ def _handle_new_events(self, origin, event_infos, backfilled=False,
+ outliers=False):
+ contexts = yield defer.gatherResults(
+ [
+ self._prep_event(
+ origin,
+ ev_info["event"],
+ state=ev_info.get("state"),
+ backfilled=backfilled,
+ auth_events=ev_info.get("auth_events"),
+ )
+ for ev_info in event_infos
+ ]
)
- is_new_state = not event.internal_metadata.is_outlier()
+ yield self.store.persist_events(
+ [
+ (ev_info["event"], context)
+ for ev_info, context in itertools.izip(event_infos, contexts)
+ ],
+ backfilled=backfilled,
+ is_new_state=(not outliers and not backfilled),
+ )
+
+ @defer.inlineCallbacks
+ def _prep_event(self, origin, event, state=None, backfilled=False,
+ current_state=None, auth_events=None):
+ outlier = event.internal_metadata.is_outlier()
+
+ context = yield self.state_handler.compute_event_context(
+ event, old_state=state, outlier=outlier,
+ )
+
+ if not auth_events:
+ auth_events = context.current_state
# This is a hack to fix some old rooms where the initial join event
# didn't reference the create event in its auth events.
@@ -937,26 +1062,7 @@ class FederationHandler(BaseHandler):
context.rejected = RejectedReason.AUTH_ERROR
- # FIXME: Don't store as rejected with AUTH_ERROR if we haven't
- # seen all the auth events.
- yield self.store.persist_event(
- event,
- context=context,
- backfilled=backfilled,
- is_new_state=False,
- current_state=current_state,
- )
- raise
-
- event_stream_id, max_stream_id = yield self.store.persist_event(
- event,
- context=context,
- backfilled=backfilled,
- is_new_state=(is_new_state and not backfilled),
- current_state=current_state,
- )
-
- defer.returnValue((context, event_stream_id, max_stream_id))
+ defer.returnValue(context)
@defer.inlineCallbacks
def on_query_auth(self, origin, event_id, remote_auth_chain, rejects,
@@ -1019,14 +1125,24 @@ class FederationHandler(BaseHandler):
@log_function
def do_auth(self, origin, event, context, auth_events):
# Check if we have all the auth events.
- have_events = yield self.store.have_events(
- [e_id for e_id, _ in event.auth_events]
- )
-
+ current_state = set(e.event_id for e in auth_events.values())
event_auth_events = set(e_id for e_id, _ in event.auth_events)
+
+ if event_auth_events - current_state:
+ have_events = yield self.store.have_events(
+ event_auth_events - current_state
+ )
+ else:
+ have_events = {}
+
+ have_events.update({
+ e.event_id: ""
+ for e in auth_events.values()
+ })
+
seen_events = set(have_events.keys())
- missing_auth = event_auth_events - seen_events
+ missing_auth = event_auth_events - seen_events - current_state
if missing_auth:
logger.info("Missing auth: %s", missing_auth)
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index 6200e10775..c1095708a0 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -44,7 +44,7 @@ class IdentityHandler(BaseHandler):
http_client = SimpleHttpClient(self.hs)
# XXX: make this configurable!
# trustedIdServers = ['matrix.org', 'localhost:8090']
- trustedIdServers = ['matrix.org']
+ trustedIdServers = ['matrix.org', 'vector.im']
if 'id_server' in creds:
id_server = creds['id_server']
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 867fdbefb0..9d6d4f0978 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -113,11 +113,21 @@ class MessageHandler(BaseHandler):
"room_key", next_key
)
+ if not events:
+ defer.returnValue({
+ "chunk": [],
+ "start": pagin_config.from_token.to_string(),
+ "end": next_token.to_string(),
+ })
+
+ events = yield self._filter_events_for_client(user_id, room_id, events)
+
time_now = self.clock.time_msec()
chunk = {
"chunk": [
- serialize_event(e, time_now, as_client_event) for e in events
+ serialize_event(e, time_now, as_client_event)
+ for e in events
],
"start": pagin_config.from_token.to_string(),
"end": next_token.to_string(),
@@ -126,6 +136,52 @@ class MessageHandler(BaseHandler):
defer.returnValue(chunk)
@defer.inlineCallbacks
+ def _filter_events_for_client(self, user_id, room_id, events):
+ states = yield self.store.get_state_for_events(
+ room_id, [e.event_id for e in events],
+ )
+
+ events_and_states = zip(events, states)
+
+ def allowed(event_and_state):
+ event, state = event_and_state
+
+ if event.type == EventTypes.RoomHistoryVisibility:
+ return True
+
+ membership_ev = state.get((EventTypes.Member, user_id), None)
+ if membership_ev:
+ membership = membership_ev.membership
+ else:
+ membership = Membership.LEAVE
+
+ if membership == Membership.JOIN:
+ return True
+
+ history = state.get((EventTypes.RoomHistoryVisibility, ''), None)
+ if history:
+ visibility = history.content.get("history_visibility", "shared")
+ else:
+ visibility = "shared"
+
+ if visibility == "public":
+ return True
+ elif visibility == "shared":
+ return True
+ elif visibility == "joined":
+ return membership == Membership.JOIN
+ elif visibility == "invited":
+ return membership == Membership.INVITE
+
+ return True
+
+ events_and_states = filter(allowed, events_and_states)
+ defer.returnValue([
+ ev
+ for ev, _ in events_and_states
+ ])
+
+ @defer.inlineCallbacks
def create_and_send_event(self, event_dict, ratelimit=True,
client=None, txn_id=None):
""" Given a dict from a client, create and handle a new event.
@@ -278,6 +334,11 @@ class MessageHandler(BaseHandler):
user, pagination_config.get_source_config("presence"), None
)
+ receipt_stream = self.hs.get_event_sources().sources["receipt"]
+ receipt, _ = yield receipt_stream.get_pagination_rows(
+ user, pagination_config.get_source_config("receipt"), None
+ )
+
public_room_ids = yield self.store.get_public_room_ids()
limit = pagin_config.limit
@@ -316,6 +377,10 @@ class MessageHandler(BaseHandler):
]
).addErrback(unwrapFirstError)
+ messages = yield self._filter_events_for_client(
+ user_id, event.room_id, messages
+ )
+
start_token = now_token.copy_and_replace("room_key", token[0])
end_token = now_token.copy_and_replace("room_key", token[1])
time_now = self.clock.time_msec()
@@ -344,7 +409,8 @@ class MessageHandler(BaseHandler):
ret = {
"rooms": rooms_ret,
"presence": presence,
- "end": now_token.to_string()
+ "receipts": receipt,
+ "end": now_token.to_string(),
}
defer.returnValue(ret)
@@ -380,15 +446,6 @@ class MessageHandler(BaseHandler):
if limit is None:
limit = 10
- messages, token = yield self.store.get_recent_events_for_room(
- room_id,
- limit=limit,
- end_token=now_token.room_key,
- )
-
- start_token = now_token.copy_and_replace("room_key", token[0])
- end_token = now_token.copy_and_replace("room_key", token[1])
-
room_members = [
m for m in current_state.values()
if m.type == EventTypes.Member
@@ -396,19 +453,45 @@ class MessageHandler(BaseHandler):
]
presence_handler = self.hs.get_handlers().presence_handler
- presence = []
- for m in room_members:
- try:
- member_presence = yield presence_handler.get_state(
- target_user=UserID.from_string(m.user_id),
- auth_user=auth_user,
- as_event=True,
- )
- presence.append(member_presence)
- except SynapseError:
- logger.exception(
- "Failed to get member presence of %r", m.user_id
+
+ @defer.inlineCallbacks
+ def get_presence():
+ presence_defs = yield defer.DeferredList(
+ [
+ presence_handler.get_state(
+ target_user=UserID.from_string(m.user_id),
+ auth_user=auth_user,
+ as_event=True,
+ check_auth=False,
+ )
+ for m in room_members
+ ],
+ consumeErrors=True,
+ )
+
+ defer.returnValue([p for success, p in presence_defs if success])
+
+ receipts_handler = self.hs.get_handlers().receipts_handler
+
+ presence, receipts, (messages, token) = yield defer.gatherResults(
+ [
+ get_presence(),
+ receipts_handler.get_receipts_for_room(room_id, now_token.receipt_key),
+ self.store.get_recent_events_for_room(
+ room_id,
+ limit=limit,
+ end_token=now_token.room_key,
)
+ ],
+ consumeErrors=True,
+ ).addErrback(unwrapFirstError)
+
+ messages = yield self._filter_events_for_client(
+ user_id, room_id, messages
+ )
+
+ start_token = now_token.copy_and_replace("room_key", token[0])
+ end_token = now_token.copy_and_replace("room_key", token[1])
time_now = self.clock.time_msec()
@@ -421,5 +504,6 @@ class MessageHandler(BaseHandler):
"end": end_token.to_string(),
},
"state": state,
- "presence": presence
+ "presence": presence,
+ "receipts": receipts,
})
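
The client-side filter added here (and duplicated in the sync handler later
in this patch) reduces to a small rule per event: joined users see
everything, ``public``/``shared`` history is always visible, and
``joined``/``invited`` history is gated on the user's membership in the room
state at that event. As a truth-table sketch::

    def event_visible(visibility, membership):
        if membership == "join":
            return True
        if visibility in ("public", "shared"):
            return True
        if visibility == "joined":
            return membership == "join"
        if visibility == "invited":
            return membership == "invite"
        return True  # unrecognised settings fail open, as in the handler

    assert event_visible("shared", "leave")
    assert not event_visible("joined", "invite")
    assert event_visible("invited", "invite")
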
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 023ad33ab0..341a516da2 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -191,24 +191,24 @@ class PresenceHandler(BaseHandler):
defer.returnValue(False)
@defer.inlineCallbacks
- def get_state(self, target_user, auth_user, as_event=False):
+ def get_state(self, target_user, auth_user, as_event=False, check_auth=True):
if self.hs.is_mine(target_user):
- visible = yield self.is_presence_visible(
- observer_user=auth_user,
- observed_user=target_user
- )
+ if check_auth:
+ visible = yield self.is_presence_visible(
+ observer_user=auth_user,
+ observed_user=target_user
+ )
- if not visible:
- raise SynapseError(404, "Presence information not visible")
- state = yield self.store.get_presence_state(target_user.localpart)
- if "mtime" in state:
- del state["mtime"]
- state["presence"] = state.pop("state")
+ if not visible:
+ raise SynapseError(404, "Presence information not visible")
if target_user in self._user_cachemap:
- cached_state = self._user_cachemap[target_user].get_state()
- if "last_active" in cached_state:
- state["last_active"] = cached_state["last_active"]
+ state = self._user_cachemap[target_user].get_state()
+ else:
+ state = yield self.store.get_presence_state(target_user.localpart)
+ if "mtime" in state:
+ del state["mtime"]
+ state["presence"] = state.pop("state")
else:
# TODO(paul): Have remote server send us permissions set
state = self._get_or_offline_usercache(target_user).get_state()
@@ -992,7 +992,7 @@ class PresenceHandler(BaseHandler):
room_ids([str]): List of room_ids to notify.
"""
with PreserveLoggingContext():
- self.notifier.on_new_user_event(
+ self.notifier.on_new_event(
"presence_key",
self._user_cachemap_latest_serial,
users_to_push,
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
new file mode 100644
index 0000000000..415dd339f6
--- /dev/null
+++ b/synapse/handlers/receipts.py
@@ -0,0 +1,212 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ._base import BaseHandler
+
+from twisted.internet import defer
+
+from synapse.util.logcontext import PreserveLoggingContext
+
+import logging
+
+
+logger = logging.getLogger(__name__)
+
+
+class ReceiptsHandler(BaseHandler):
+ def __init__(self, hs):
+ super(ReceiptsHandler, self).__init__(hs)
+
+ self.hs = hs
+ self.federation = hs.get_replication_layer()
+ self.federation.register_edu_handler(
+ "m.receipt", self._received_remote_receipt
+ )
+ self.clock = self.hs.get_clock()
+
+ self._receipt_cache = None
+
+ @defer.inlineCallbacks
+ def received_client_receipt(self, room_id, receipt_type, user_id,
+ event_id):
+ """Called when a client tells us a local user has read up to the given
+ event_id in the room.
+ """
+ receipt = {
+ "room_id": room_id,
+ "receipt_type": receipt_type,
+ "user_id": user_id,
+ "event_ids": [event_id],
+ "data": {
+ "ts": int(self.clock.time_msec()),
+ }
+ }
+
+ is_new = yield self._handle_new_receipts([receipt])
+
+ if is_new:
+ self._push_remotes([receipt])
+
+ @defer.inlineCallbacks
+ def _received_remote_receipt(self, origin, content):
+ """Called when we receive an EDU of type m.receipt from a remote HS.
+ """
+ receipts = [
+ {
+ "room_id": room_id,
+ "receipt_type": receipt_type,
+ "user_id": user_id,
+ "event_ids": user_values["event_ids"],
+ "data": user_values.get("data", {}),
+ }
+ for room_id, room_values in content.items()
+ for receipt_type, users in room_values.items()
+ for user_id, user_values in users.items()
+ ]
+
+ yield self._handle_new_receipts(receipts)
+
+ @defer.inlineCallbacks
+ def _handle_new_receipts(self, receipts):
+ """Takes a list of receipts, stores them and informs the notifier.
+ """
+ for receipt in receipts:
+ room_id = receipt["room_id"]
+ receipt_type = receipt["receipt_type"]
+ user_id = receipt["user_id"]
+ event_ids = receipt["event_ids"]
+ data = receipt["data"]
+
+ res = yield self.store.insert_receipt(
+ room_id, receipt_type, user_id, event_ids, data
+ )
+
+ if not res:
+ # res will be None if this read receipt is 'old'
+ defer.returnValue(False)
+
+ stream_id, max_persisted_id = res
+
+ with PreserveLoggingContext():
+ self.notifier.on_new_event(
+ "receipt_key", max_persisted_id, rooms=[room_id]
+ )
+
+ defer.returnValue(True)
+
+ @defer.inlineCallbacks
+ def _push_remotes(self, receipts):
+ """Given a list of receipts, works out which remote servers should be
+ poked and pokes them.
+ """
+ # TODO: Some of this stuff should be coalesced.
+ for receipt in receipts:
+ room_id = receipt["room_id"]
+ receipt_type = receipt["receipt_type"]
+ user_id = receipt["user_id"]
+ event_ids = receipt["event_ids"]
+ data = receipt["data"]
+
+ remotedomains = set()
+
+ rm_handler = self.hs.get_handlers().room_member_handler
+ yield rm_handler.fetch_room_distributions_into(
+ room_id, localusers=None, remotedomains=remotedomains
+ )
+
+ logger.debug("Sending receipt to: %r", remotedomains)
+
+ for domain in remotedomains:
+ self.federation.send_edu(
+ destination=domain,
+ edu_type="m.receipt",
+ content={
+ room_id: {
+ receipt_type: {
+ user_id: {
+ "event_ids": event_ids,
+ "data": data,
+ }
+ }
+ },
+ },
+ )
+
+ @defer.inlineCallbacks
+ def get_receipts_for_room(self, room_id, to_key):
+ """Gets all receipts for a room, upto the given key.
+ """
+ result = yield self.store.get_linearized_receipts_for_room(
+ room_id,
+ to_key=to_key,
+ )
+
+ if not result:
+ defer.returnValue([])
+
+ event = {
+ "type": "m.receipt",
+ "room_id": room_id,
+ "content": result,
+ }
+
+ defer.returnValue([event])
+
+
+class ReceiptEventSource(object):
+ def __init__(self, hs):
+ self.store = hs.get_datastore()
+
+ @defer.inlineCallbacks
+ def get_new_events_for_user(self, user, from_key, limit):
+ defer.returnValue(([], from_key))
+ from_key = int(from_key)
+ to_key = yield self.get_current_key()
+
+ if from_key == to_key:
+ defer.returnValue(([], to_key))
+
+ rooms = yield self.store.get_rooms_for_user(user.to_string())
+ rooms = [room.room_id for room in rooms]
+ events = yield self.store.get_linearized_receipts_for_rooms(
+ rooms,
+ from_key=from_key,
+ to_key=to_key,
+ )
+
+ defer.returnValue((events, to_key))
+
+ def get_current_key(self, direction='f'):
+ return self.store.get_max_receipt_stream_id()
+
+ @defer.inlineCallbacks
+ def get_pagination_rows(self, user, config, key):
+ to_key = int(config.from_key)
+ defer.returnValue(([], to_key))
+
+ if config.to_key:
+ from_key = int(config.to_key)
+ else:
+ from_key = None
+
+ rooms = yield self.store.get_rooms_for_user(user.to_string())
+ rooms = [room.room_id for room in rooms]
+ events = yield self.store.get_linearized_receipts_for_rooms(
+ rooms,
+ from_key=from_key,
+ to_key=to_key,
+ )
+
+ defer.returnValue((events, to_key))
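
Two things worth noting in the new handler. First, both ``ReceiptEventSource``
methods return an empty result before their query logic runs, which reads as
the receipt stream being deliberately left switched off for now. Second, the
``m.receipt`` EDU nests receipts as room -> receipt type -> user, and
``_received_remote_receipt`` flattens that into one dict per receipt; a
sketch with a hypothetical payload::

    def flatten_receipt_edu(content):
        return [
            {
                "room_id": room_id,
                "receipt_type": receipt_type,
                "user_id": user_id,
                "event_ids": user_values["event_ids"],
                "data": user_values.get("data", {}),
            }
            for room_id, room_values in content.items()
            for receipt_type, users in room_values.items()
            for user_id, user_values in users.items()
        ]

    edu = {
        "!room:example.org": {
            "m.read": {
                "@alice:example.org": {
                    "event_ids": ["$someevent:example.org"],
                    "data": {"ts": 1435500000000},
                },
            },
        },
    }
    receipts = flatten_receipt_edu(edu)  # one flat dict per receipt
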
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 7b68585a17..f81d75017d 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -73,7 +73,8 @@ class RegistrationHandler(BaseHandler):
localpart : The local part of the user ID to register. If None,
one will be randomly generated.
password (str) : The password to assign to this user so they can
- login again.
+ login again. This can be None which means they cannot login again
+ via a password (e.g. the user is an application service user).
Returns:
A tuple of (user_id, access_token).
Raises:
@@ -193,6 +194,35 @@ class RegistrationHandler(BaseHandler):
logger.info("Valid captcha entered from %s", ip)
@defer.inlineCallbacks
+ def register_saml2(self, localpart):
+ """
+ Registers a user whose login is handled via SAML2-based auth.
+ """
+ if urllib.quote(localpart) != localpart:
+ raise SynapseError(
+ 400,
+ "User ID must only contain characters which do not"
+ " require URL encoding."
+ )
+ user = UserID(localpart, self.hs.hostname)
+ user_id = user.to_string()
+
+ yield self.check_user_id_is_valid(user_id)
+ token = self._generate_token(user_id)
+ try:
+ yield self.store.register(
+ user_id=user_id,
+ token=token,
+ password_hash=None
+ )
+ yield self.distributor.fire("registered_user", user)
+ except Exception as e:
+ yield self.store.add_access_token_to_user(user_id, token)
+ # Ignore Registration errors
+ logger.exception(e)
+ defer.returnValue((user_id, token))
+
+ @defer.inlineCallbacks
def register_email(self, threepidCreds):
"""
Registers emails with an identity server.
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 4bd027d9bb..7511d294f3 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -19,12 +19,15 @@ from twisted.internet import defer
from ._base import BaseHandler
from synapse.types import UserID, RoomAlias, RoomID
-from synapse.api.constants import EventTypes, Membership, JoinRules
+from synapse.api.constants import (
+ EventTypes, Membership, JoinRules, RoomCreationPreset,
+)
from synapse.api.errors import StoreError, SynapseError
from synapse.util import stringutils, unwrapFirstError
from synapse.util.async import run_on_reactor
from synapse.events.utils import serialize_event
+from collections import OrderedDict
import logging
import string
@@ -33,6 +36,19 @@ logger = logging.getLogger(__name__)
class RoomCreationHandler(BaseHandler):
+ PRESETS_DICT = {
+ RoomCreationPreset.PRIVATE_CHAT: {
+ "join_rules": JoinRules.INVITE,
+ "history_visibility": "invited",
+ "original_invitees_have_ops": False,
+ },
+ RoomCreationPreset.PUBLIC_CHAT: {
+ "join_rules": JoinRules.PUBLIC,
+ "history_visibility": "shared",
+ "original_invitees_have_ops": False,
+ },
+ }
+
@defer.inlineCallbacks
def create_room(self, user_id, room_id, config):
""" Creates a new room.
@@ -121,9 +137,25 @@ class RoomCreationHandler(BaseHandler):
servers=[self.hs.hostname],
)
+ preset_config = config.get(
+ "preset",
+ RoomCreationPreset.PUBLIC_CHAT
+ if is_public
+ else RoomCreationPreset.PRIVATE_CHAT
+ )
+
+ raw_initial_state = config.get("initial_state", [])
+
+ initial_state = OrderedDict()
+ for val in raw_initial_state:
+ initial_state[(val["type"], val.get("state_key", ""))] = val["content"]
+
user = UserID.from_string(user_id)
creation_events = self._create_events_for_new_room(
- user, room_id, is_public=is_public
+ user, room_id,
+ preset_config=preset_config,
+ invite_list=invite_list,
+ initial_state=initial_state,
)
msg_handler = self.hs.get_handlers().message_handler
@@ -170,7 +202,10 @@ class RoomCreationHandler(BaseHandler):
defer.returnValue(result)
- def _create_events_for_new_room(self, creator, room_id, is_public=False):
+ def _create_events_for_new_room(self, creator, room_id, preset_config,
+ invite_list, initial_state):
+ config = RoomCreationHandler.PRESETS_DICT[preset_config]
+
creator_id = creator.to_string()
event_keys = {
@@ -203,9 +238,10 @@ class RoomCreationHandler(BaseHandler):
},
)
- power_levels_event = create(
- etype=EventTypes.PowerLevels,
- content={
+ returned_events = [creation_event, join_event]
+
+ if (EventTypes.PowerLevels, '') not in initial_state:
+ power_level_content = {
"users": {
creator.to_string(): 100,
},
@@ -213,6 +249,7 @@ class RoomCreationHandler(BaseHandler):
"events": {
EventTypes.Name: 100,
EventTypes.PowerLevels: 100,
+ EventTypes.RoomHistoryVisibility: 100,
},
"events_default": 0,
"state_default": 50,
@@ -220,21 +257,43 @@ class RoomCreationHandler(BaseHandler):
"kick": 50,
"redact": 50,
"invite": 0,
- },
- )
+ }
- join_rule = JoinRules.PUBLIC if is_public else JoinRules.INVITE
- join_rules_event = create(
- etype=EventTypes.JoinRules,
- content={"join_rule": join_rule},
- )
+ if config["original_invitees_have_ops"]:
+ for invitee in invite_list:
+ power_level_content["users"][invitee] = 100
- return [
- creation_event,
- join_event,
- power_levels_event,
- join_rules_event,
- ]
+ power_levels_event = create(
+ etype=EventTypes.PowerLevels,
+ content=power_level_content,
+ )
+
+ returned_events.append(power_levels_event)
+
+ if (EventTypes.JoinRules, '') not in initial_state:
+ join_rules_event = create(
+ etype=EventTypes.JoinRules,
+ content={"join_rule": config["join_rules"]},
+ )
+
+ returned_events.append(join_rules_event)
+
+ if (EventTypes.RoomHistoryVisibility, '') not in initial_state:
+ history_event = create(
+ etype=EventTypes.RoomHistoryVisibility,
+ content={"history_visibility": config["history_visibility"]}
+ )
+
+ returned_events.append(history_event)
+
+ for (etype, state_key), content in initial_state.items():
+ returned_events.append(create(
+ etype=etype,
+ state_key=state_key,
+ content=content,
+ ))
+
+ return returned_events
class RoomMemberHandler(BaseHandler):
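
Room creation above now derives its default state events from a preset, but a
default is only emitted when the caller's ``initial_state`` does not already
contain an event of that type, so explicit state always wins. A trimmed
sketch of that precedence; the ``m.room.*`` strings stand in for the
``EventTypes`` constants::

    PRESETS = {
        "private_chat": {"join_rules": "invite",
                         "history_visibility": "invited"},
        "public_chat": {"join_rules": "public",
                        "history_visibility": "shared"},
    }

    def default_state_events(preset, initial_state):
        config = PRESETS[preset]
        events = []
        if ("m.room.join_rules", "") not in initial_state:
            events.append({
                "type": "m.room.join_rules",
                "content": {"join_rule": config["join_rules"]},
            })
        if ("m.room.history_visibility", "") not in initial_state:
            events.append({
                "type": "m.room.history_visibility",
                "content": {
                    "history_visibility": config["history_visibility"],
                },
            })
        return events

    # A caller-supplied join rule suppresses the preset's default:
    overridden = {("m.room.join_rules", ""): {"join_rule": "public"}}
    assert len(default_state_events("private_chat", overridden)) == 1
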
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index bd8c603681..6cff6230c1 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -293,6 +293,51 @@ class SyncHandler(BaseHandler):
))
@defer.inlineCallbacks
+ def _filter_events_for_client(self, user_id, room_id, events):
+ states = yield self.store.get_state_for_events(
+ room_id, [e.event_id for e in events],
+ )
+
+ events_and_states = zip(events, states)
+
+ def allowed(event_and_state):
+ event, state = event_and_state
+
+ if event.type == EventTypes.RoomHistoryVisibility:
+ return True
+
+ membership_ev = state.get((EventTypes.Member, user_id), None)
+ if membership_ev:
+ membership = membership_ev.membership
+ else:
+ membership = Membership.LEAVE
+
+ if membership == Membership.JOIN:
+ return True
+
+ history = state.get((EventTypes.RoomHistoryVisibility, ''), None)
+ if history:
+ visibility = history.content.get("history_visibility", "shared")
+ else:
+ visibility = "shared"
+
+ if visibility == "public":
+ return True
+ elif visibility == "shared":
+ return True
+ elif visibility == "joined":
+ return membership == Membership.JOIN
+ elif visibility == "invited":
+ return membership == Membership.INVITE
+
+ return True
+
+ events_and_states = filter(allowed, events_and_states)
+ defer.returnValue([
+ ev
+ for ev, _ in events_and_states
+ ])
+
+ @defer.inlineCallbacks
def load_filtered_recents(self, room_id, sync_config, now_token,
since_token=None):
limited = True
@@ -313,6 +358,9 @@ class SyncHandler(BaseHandler):
(room_key, _) = keys
end_key = "s" + room_key.split('-')[-1]
loaded_recents = sync_config.filter.filter_room_events(events)
+ loaded_recents = yield self._filter_events_for_client(
+ sync_config.user.to_string(), room_id, loaded_recents,
+ )
loaded_recents.extend(recents)
recents = loaded_recents
if len(events) <= load_limit:
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index a9895292c2..026bd2b9d4 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -218,7 +218,7 @@ class TypingNotificationHandler(BaseHandler):
self._room_serials[room_id] = self._latest_room_serial
with PreserveLoggingContext():
- self.notifier.on_new_user_event(
+ self.notifier.on_new_event(
"typing_key", self._latest_room_serial, rooms=[room_id]
)
diff --git a/synapse/http/client.py b/synapse/http/client.py
index e746f2416e..49737d55da 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -61,21 +61,31 @@ class SimpleHttpClient(object):
self.agent = Agent(reactor, pool=pool)
self.version_string = hs.version_string
- def request(self, method, *args, **kwargs):
+ def request(self, method, uri, *args, **kwargs):
# A small wrapper around self.agent.request() so we can easily attach
# counters to it
outgoing_requests_counter.inc(method)
d = preserve_context_over_fn(
self.agent.request,
- method, *args, **kwargs
+ method, uri, *args, **kwargs
)
+ logger.info("Sending request %s %s", method, uri)
+
def _cb(response):
incoming_responses_counter.inc(method, response.code)
+ logger.info(
+ "Received response to %s %s: %s",
+ method, uri, response.code
+ )
return response
def _eb(failure):
incoming_responses_counter.inc(method, "ERR")
+ logger.info(
+ "Error sending request to %s %s: %s %s",
+ method, uri, failure.type, failure.getErrorMessage()
+ )
return failure
d.addCallbacks(_cb, _eb)
@@ -84,7 +94,9 @@ class SimpleHttpClient(object):
@defer.inlineCallbacks
def post_urlencoded_get_json(self, uri, args={}):
+ # TODO: Do we ever want to log message contents?
logger.debug("post_urlencoded_get_json args: %s", args)
+
query_bytes = urllib.urlencode(args, True)
response = yield self.request(
@@ -97,7 +109,7 @@ class SimpleHttpClient(object):
bodyProducer=FileBodyProducer(StringIO(query_bytes))
)
- body = yield readBody(response)
+ body = yield preserve_context_over_fn(readBody, response)
defer.returnValue(json.loads(body))
@@ -105,7 +117,7 @@ class SimpleHttpClient(object):
def post_json_get_json(self, uri, post_json):
json_str = encode_canonical_json(post_json)
- logger.info("HTTP POST %s -> %s", json_str, uri)
+ logger.debug("HTTP POST %s -> %s", json_str, uri)
response = yield self.request(
"POST",
@@ -116,7 +128,7 @@ class SimpleHttpClient(object):
bodyProducer=FileBodyProducer(StringIO(json_str))
)
- body = yield readBody(response)
+ body = yield preserve_context_over_fn(readBody, response)
defer.returnValue(json.loads(body))
@@ -149,7 +161,7 @@ class SimpleHttpClient(object):
})
)
- body = yield readBody(response)
+ body = yield preserve_context_over_fn(readBody, response)
if 200 <= response.code < 300:
defer.returnValue(json.loads(body))
@@ -192,7 +204,7 @@ class SimpleHttpClient(object):
bodyProducer=FileBodyProducer(StringIO(json_str))
)
- body = yield readBody(response)
+ body = yield preserve_context_over_fn(readBody, response)
if 200 <= response.code < 300:
defer.returnValue(json.loads(body))
@@ -226,7 +238,7 @@ class CaptchaServerHttpClient(SimpleHttpClient):
)
try:
- body = yield readBody(response)
+ body = yield preserve_context_over_fn(readBody, response)
defer.returnValue(body)
except PartialDownloadError as e:
# twisted dislikes google's response, no content length.
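
The instrumentation added to ``request`` above is the usual paired-callbacks
pattern: one callback bumps a counter and logs on success, its twin does the
same on failure, and both pass the result through untouched. A generic sketch
with hypothetical ``counters``/``log`` stand-ins::

    from twisted.internet import defer

    class _Counters(object):
        def inc(self, *labels):  # hypothetical metrics stand-in
            pass

    def count_and_log(d, method, uri, counters, log):
        def _cb(response):
            counters.inc(method, getattr(response, "code", "?"))
            log("Received response to %s %s" % (method, uri))
            return response  # pass the result through unchanged

        def _eb(failure):
            counters.inc(method, "ERR")
            log("Error sending request %s %s: %s" % (
                method, uri, failure.getErrorMessage(),
            ))
            return failure    # propagate the failure onwards

        return d.addCallbacks(_cb, _eb)

    count_and_log(defer.succeed(None), "GET", "/_matrix/key/v2/server",
                  _Counters(), lambda msg: None)
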
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index ec5b06ddca..4d74bd5d78 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -35,11 +35,13 @@ from syutil.crypto.jsonsign import sign_json
import simplejson as json
import logging
+import sys
import urllib
import urlparse
logger = logging.getLogger(__name__)
+outbound_logger = logging.getLogger("synapse.http.outbound")
metrics = synapse.metrics.get_metrics_for(__name__)
@@ -86,6 +88,7 @@ class MatrixFederationHttpClient(object):
)
self.clock = hs.get_clock()
self.version_string = hs.version_string
+ self._next_id = 1
def _create_url(self, destination, path_bytes, param_bytes, query_bytes):
return urlparse.urlunparse(
@@ -106,16 +109,12 @@ class MatrixFederationHttpClient(object):
destination, path_bytes, param_bytes, query_bytes
)
- logger.info("Sending request to %s: %s %s",
- destination, method, url_bytes)
+ txn_id = "%s-O-%s" % (method, self._next_id)
+ self._next_id = (self._next_id + 1) % (sys.maxint - 1)
- logger.debug(
- "Types: %s",
- [
- type(destination), type(method), type(path_bytes),
- type(param_bytes),
- type(query_bytes)
- ]
+ outbound_logger.info(
+ "{%s} [%s] Sending request: %s %s",
+ txn_id, destination, method, url_bytes
)
# XXX: Would be much nicer to retry only at the transaction-layer
@@ -126,66 +125,80 @@ class MatrixFederationHttpClient(object):
("", "", path_bytes, param_bytes, query_bytes, "")
)
- while True:
- producer = None
- if body_callback:
- producer = body_callback(method, http_url_bytes, headers_dict)
-
- try:
- request_deferred = preserve_context_over_fn(
- self.agent.request,
- method,
- url_bytes,
- Headers(headers_dict),
- producer
- )
+ log_result = None
+ try:
+ while True:
+ producer = None
+ if body_callback:
+ producer = body_callback(method, http_url_bytes, headers_dict)
+
+ try:
+ def send_request():
+ request_deferred = preserve_context_over_fn(
+ self.agent.request,
+ method,
+ url_bytes,
+ Headers(headers_dict),
+ producer
+ )
+
+ return self.clock.time_bound_deferred(
+ request_deferred,
+ time_out=timeout/1000. if timeout else 60,
+ )
+
+ response = yield preserve_context_over_fn(
+ send_request,
+ )
- response = yield self.clock.time_bound_deferred(
- request_deferred,
- time_out=timeout/1000. if timeout else 60,
- )
+ log_result = "%d %s" % (response.code, response.phrase,)
+ break
+ except Exception as e:
+ if not retry_on_dns_fail and isinstance(e, DNSLookupError):
+ logger.warn(
+ "DNS Lookup failed to %s with %s",
+ destination,
+ e
+ )
+ log_result = "DNS Lookup failed to %s with %s" % (
+ destination, e
+ )
+ raise
- logger.debug("Got response to %s", method)
- break
- except Exception as e:
- if not retry_on_dns_fail and isinstance(e, DNSLookupError):
logger.warn(
- "DNS Lookup failed to %s with %s",
+ "{%s} Sending request failed to %s: %s %s: %s - %s",
+ txn_id,
destination,
- e
+ method,
+ url_bytes,
+ type(e).__name__,
+ _flatten_response_never_received(e),
)
- raise
-
- logger.warn(
- "Sending request failed to %s: %s %s: %s - %s",
- destination,
- method,
- url_bytes,
- type(e).__name__,
- _flatten_response_never_received(e),
- )
- if retries_left and not timeout:
- yield sleep(2 ** (5 - retries_left))
- retries_left -= 1
- else:
- raise
-
- logger.info(
- "Received response %d %s for %s: %s %s",
- response.code,
- response.phrase,
- destination,
- method,
- url_bytes
- )
+ log_result = "%s - %s" % (
+ type(e).__name__, _flatten_response_never_received(e),
+ )
+
+ if retries_left and not timeout:
+ yield sleep(2 ** (5 - retries_left))
+ retries_left -= 1
+ else:
+ raise
+ finally:
+ outbound_logger.info(
+ "{%s} [%s] Result: %s",
+ txn_id,
+ destination,
+ log_result,
+ )
if 200 <= response.code < 300:
pass
else:
# :'(
# Update transactions table?
- body = yield readBody(response)
+ body = yield preserve_context_over_fn(readBody, response)
raise HttpResponseException(
response.code, response.phrase, body
)
@@ -265,10 +278,7 @@ class MatrixFederationHttpClient(object):
"Content-Type not application/json"
)
- logger.debug("Getting resp body")
- body = yield readBody(response)
- logger.debug("Got resp body")
-
+ body = yield preserve_context_over_fn(readBody, response)
defer.returnValue(json.loads(body))
@defer.inlineCallbacks
@@ -311,9 +321,7 @@ class MatrixFederationHttpClient(object):
"Content-Type not application/json"
)
- logger.debug("Getting resp body")
- body = yield readBody(response)
- logger.debug("Got resp body")
+ body = yield preserve_context_over_fn(readBody, response)
defer.returnValue(json.loads(body))
@@ -371,9 +379,7 @@ class MatrixFederationHttpClient(object):
"Content-Type not application/json"
)
- logger.debug("Getting resp body")
- body = yield readBody(response)
- logger.debug("Got resp body")
+ body = yield preserve_context_over_fn(readBody, response)
defer.returnValue(json.loads(body))
@@ -416,7 +422,10 @@ class MatrixFederationHttpClient(object):
headers = dict(response.headers.getAllRawHeaders())
try:
- length = yield _readBodyToFile(response, output_stream, max_size)
+ length = yield preserve_context_over_fn(
+ _readBodyToFile,
+ response, output_stream, max_size
+ )
except:
logger.exception("Failed to download body")
raise
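
Two mechanics are worth noting in the federation-client hunks: every outbound request now carries a wrapping transaction id for log correlation, and failed sends back off exponentially before retrying. A sketch of both in isolation, assuming the same constants as the patch::

    import sys


    def make_txn_id(method, counter):
        # "<METHOD>-O-<n>", with the counter wrapping below sys.maxint
        # (Python 2, matching the source)
        txn_id = "%s-O-%s" % (method, counter)
        return txn_id, (counter + 1) % (sys.maxint - 1)


    def backoff_delays(retries_left=4):
        # sleep 2 ** (5 - retries_left) seconds between attempts
        delays = []
        while retries_left:
            delays.append(2 ** (5 - retries_left))
            retries_left -= 1
        return delays  # [2, 4, 8, 16]
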
diff --git a/synapse/http/server.py b/synapse/http/server.py
index ae8f3b3972..b60e905a62 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -79,53 +79,39 @@ def request_handler(request_handler):
_next_request_id += 1
with LoggingContext(request_id) as request_context:
request_context.request = request_id
- code = None
- start = self.clock.time_msec()
- try:
- logger.info(
- "Received request: %s %s",
- request.method, request.path
- )
- d = request_handler(self, request)
- with PreserveLoggingContext():
- yield d
- code = request.code
- except CodeMessageException as e:
- code = e.code
- if isinstance(e, SynapseError):
- logger.info(
- "%s SynapseError: %s - %s", request, code, e.msg
+ with request.processing():
+ try:
+ d = request_handler(self, request)
+ with PreserveLoggingContext():
+ yield d
+ except CodeMessageException as e:
+ code = e.code
+ if isinstance(e, SynapseError):
+ logger.info(
+ "%s SynapseError: %s - %s", request, code, e.msg
+ )
+ else:
+ logger.exception(e)
+ outgoing_responses_counter.inc(request.method, str(code))
+ respond_with_json(
+ request, code, cs_exception(e), send_cors=True,
+ pretty_print=_request_user_agent_is_curl(request),
+ version_string=self.version_string,
+ )
+ except:
+ logger.exception(
+ "Failed handle request %s.%s on %r: %r",
+ request_handler.__module__,
+ request_handler.__name__,
+ self,
+ request
+ )
+ respond_with_json(
+ request,
+ 500,
+ {"error": "Internal server error"},
+ send_cors=True
)
- else:
- logger.exception(e)
- outgoing_responses_counter.inc(request.method, str(code))
- respond_with_json(
- request, code, cs_exception(e), send_cors=True,
- pretty_print=_request_user_agent_is_curl(request),
- version_string=self.version_string,
- )
- except:
- code = 500
- logger.exception(
- "Failed handle request %s.%s on %r: %r",
- request_handler.__module__,
- request_handler.__name__,
- self,
- request
- )
- respond_with_json(
- request,
- 500,
- {"error": "Internal server error"},
- send_cors=True
- )
- finally:
- code = str(code) if code else "-"
- end = self.clock.time_msec()
- logger.info(
- "Processed request: %dms %s %s %s",
- end-start, code, request.method, request.path
- )
return wrapped_request_handler
@@ -221,7 +207,7 @@ class JsonResource(HttpServer, resource.Resource):
incoming_requests_counter.inc(request.method, servlet_classname)
args = [
- urllib.unquote(u).decode("UTF-8") for u in m.groups()
+ urllib.unquote(u).decode("UTF-8") if u else u for u in m.groups()
]
callback_return = yield callback(request, *args)
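
The rewritten ``request_handler`` decorator above gives each request its own id and logging context, and converts exceptions into JSON error responses instead of letting them propagate. A much-simplified, synchronous sketch of that shape (hypothetical names, no Twisted)::

    import logging

    logger = logging.getLogger(__name__)

    _next_request_id = 0


    def request_handler(handler):
        def wrapped(self, request):
            global _next_request_id
            request_id = "%s-%s" % (request.method, _next_request_id)
            _next_request_id += 1
            try:
                return handler(self, request)
            except Exception:
                logger.exception("Failed to handle request %s", request_id)
                return (500, {"error": "Internal server error"})
        return wrapped
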
diff --git a/synapse/notifier.py b/synapse/notifier.py
index 078abfc56d..dbd8efe9fb 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -16,7 +16,7 @@
from twisted.internet import defer
from synapse.util.logutils import log_function
-from synapse.util.async import run_on_reactor
+from synapse.util.async import run_on_reactor, ObservableDeferred
from synapse.types import StreamToken
import synapse.metrics
@@ -45,21 +45,11 @@ class _NotificationListener(object):
The events stream handler will have yielded to the deferred, so to
notify the handler it is sufficient to resolve the deferred.
"""
+ __slots__ = ["deferred"]
def __init__(self, deferred):
self.deferred = deferred
- def notified(self):
- return self.deferred.called
-
- def notify(self, token):
- """ Inform whoever is listening about the new events.
- """
- try:
- self.deferred.callback(token)
- except defer.AlreadyCalledError:
- pass
-
class _NotifierUserStream(object):
"""This represents a user connected to the event stream.
@@ -75,11 +65,12 @@ class _NotifierUserStream(object):
appservice=None):
self.user = str(user)
self.appservice = appservice
- self.listeners = set()
self.rooms = set(rooms)
self.current_token = current_token
self.last_notified_ms = time_now_ms
+ self.notify_deferred = ObservableDeferred(defer.Deferred())
+
def notify(self, stream_key, stream_id, time_now_ms):
"""Notify any listeners for this user of a new event from an
event source.
@@ -91,12 +82,10 @@ class _NotifierUserStream(object):
self.current_token = self.current_token.copy_and_advance(
stream_key, stream_id
)
- if self.listeners:
- self.last_notified_ms = time_now_ms
- listeners = self.listeners
- self.listeners = set()
- for listener in listeners:
- listener.notify(self.current_token)
+ self.last_notified_ms = time_now_ms
+        notify_deferred = self.notify_deferred
+        self.notify_deferred = ObservableDeferred(defer.Deferred())
+        notify_deferred.callback(self.current_token)
def remove(self, notifier):
""" Remove this listener from all the indexes in the Notifier
@@ -114,6 +103,18 @@ class _NotifierUserStream(object):
self.appservice, set()
).discard(self)
+ def count_listeners(self):
+ return len(self.notify_deferred.observers())
+
+ def new_listener(self, token):
+ """Returns a deferred that is resolved when there is a new token
+ greater than the given token.
+ """
+ if self.current_token.is_after(token):
+ return _NotificationListener(defer.succeed(self.current_token))
+ else:
+ return _NotificationListener(self.notify_deferred.observe())
+
class Notifier(object):
""" This class is responsible for notifying any listeners when there are
@@ -158,7 +159,7 @@ class Notifier(object):
for x in self.appservice_to_user_streams.values():
all_user_streams |= x
- return sum(len(stream.listeners) for stream in all_user_streams)
+ return sum(stream.count_listeners() for stream in all_user_streams)
metrics.register_callback("listeners", count_listeners)
metrics.register_callback(
@@ -220,16 +221,7 @@ class Notifier(object):
event
)
- room_id = event.room_id
-
- room_user_streams = self.room_to_user_streams.get(room_id, set())
-
- user_streams = room_user_streams.copy()
-
- for user in extra_users:
- user_stream = self.user_to_user_stream.get(str(user))
- if user_stream is not None:
- user_streams.add(user_stream)
+ app_streams = set()
for appservice in self.appservice_to_user_streams:
# TODO (kegan): Redundant appservice listener checks?
@@ -241,24 +233,20 @@ class Notifier(object):
app_user_streams = self.appservice_to_user_streams.get(
appservice, set()
)
- user_streams |= app_user_streams
+ app_streams |= app_user_streams
- logger.debug("on_new_room_event listeners %s", user_streams)
-
- time_now_ms = self.clock.time_msec()
- for user_stream in user_streams:
- try:
- user_stream.notify(
- "room_key", "s%d" % (room_stream_id,), time_now_ms
- )
- except:
- logger.exception("Failed to notify listener")
+ self.on_new_event(
+ "room_key", room_stream_id,
+ users=extra_users,
+ rooms=[event.room_id],
+ extra_streams=app_streams,
+ )
@defer.inlineCallbacks
@log_function
- def on_new_user_event(self, stream_key, new_token, users=[], rooms=[]):
- """ Used to inform listeners that something has happend
- presence/user event wise.
+ def on_new_event(self, stream_key, new_token, users=[], rooms=[],
+ extra_streams=set()):
+ """ Used to inform listeners that something has happend event wise.
Will wake up all listeners for the given users and rooms.
"""
@@ -282,14 +270,10 @@ class Notifier(object):
@defer.inlineCallbacks
def wait_for_events(self, user, rooms, timeout, callback,
- from_token=StreamToken("s0", "0", "0")):
+ from_token=StreamToken("s0", "0", "0", "0")):
"""Wait until the callback returns a non empty response or the
timeout fires.
"""
-
- deferred = defer.Deferred()
- time_now_ms = self.clock.time_msec()
-
user = str(user)
user_stream = self.user_to_user_stream.get(user)
if user_stream is None:
@@ -302,55 +286,44 @@ class Notifier(object):
rooms=rooms,
appservice=appservice,
current_token=current_token,
- time_now_ms=time_now_ms,
+ time_now_ms=self.clock.time_msec(),
)
self._register_with_keys(user_stream)
+
+ result = None
+ if timeout:
+ # Will be set to a _NotificationListener that we'll be waiting on.
+ # Allows us to cancel it.
+ listener = None
+
+ def timed_out():
+ if listener:
+ listener.deferred.cancel()
+ timer = self.clock.call_later(timeout/1000., timed_out)
+
+ prev_token = from_token
+ while not result:
+ try:
+ current_token = user_stream.current_token
+
+ result = yield callback(prev_token, current_token)
+ if result:
+ break
+
+ # Now we wait for the _NotifierUserStream to be told there
+ # is a new token.
+ # We need to supply the token we supplied to callback so
+ # that we don't miss any current_token updates.
+ prev_token = current_token
+ listener = user_stream.new_listener(prev_token)
+ yield listener.deferred
+ except defer.CancelledError:
+ break
+
+ self.clock.cancel_call_later(timer, ignore_errs=True)
else:
current_token = user_stream.current_token
-
- listener = [_NotificationListener(deferred)]
-
- if timeout and not current_token.is_after(from_token):
- user_stream.listeners.add(listener[0])
-
- if current_token.is_after(from_token):
result = yield callback(from_token, current_token)
- else:
- result = None
-
- timer = [None]
-
- if result:
- user_stream.listeners.discard(listener[0])
- defer.returnValue(result)
- return
-
- if timeout:
- timed_out = [False]
-
- def _timeout_listener():
- timed_out[0] = True
- timer[0] = None
- user_stream.listeners.discard(listener[0])
- listener[0].notify(current_token)
-
- # We create multiple notification listeners so we have to manage
- # canceling the timeout ourselves.
- timer[0] = self.clock.call_later(timeout/1000., _timeout_listener)
-
- while not result and not timed_out[0]:
- new_token = yield deferred
- deferred = defer.Deferred()
- listener[0] = _NotificationListener(deferred)
- user_stream.listeners.add(listener[0])
- result = yield callback(current_token, new_token)
- current_token = new_token
-
- if timer[0] is not None:
- try:
- self.clock.cancel_call_later(timer[0])
- except:
- logger.exception("Failed to cancel notifer timer")
defer.returnValue(result)
@@ -368,6 +341,9 @@ class Notifier(object):
@defer.inlineCallbacks
def check_for_updates(before_token, after_token):
+ if not after_token.is_after(before_token):
+ defer.returnValue(None)
+
events = []
end_token = from_token
for name, source in self.event_sources.sources.items():
@@ -376,10 +352,10 @@ class Notifier(object):
after_id = getattr(after_token, keyname)
if before_id == after_id:
continue
- stuff, new_key = yield source.get_new_events_for_user(
+ new_events, new_key = yield source.get_new_events_for_user(
user, getattr(from_token, keyname), limit,
)
- events.extend(stuff)
+ events.extend(new_events)
end_token = end_token.copy_and_replace(keyname, new_key)
if events:
@@ -402,7 +378,7 @@ class Notifier(object):
expired_streams = []
expire_before_ts = time_now_ms - self.UNUSED_STREAM_EXPIRY_MS
for stream in self.user_to_user_stream.values():
- if stream.listeners:
+ if stream.count_listeners():
continue
if stream.last_notified_ms < expire_before_ts:
expired_streams.append(stream)
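
The notifier hunks replace the per-user listener set with a single ``ObservableDeferred``: every waiter observes the same deferred, and notifying a stream means firing that deferred once and swapping in a fresh one for the next round. A minimal sketch of the pattern, using only the ``observe()``/``callback()`` interface visible in the patch (stream keys and expiry omitted)::

    from twisted.internet import defer

    from synapse.util.async import ObservableDeferred


    class UserStream(object):
        def __init__(self, current_token):
            self.current_token = current_token
            self.notify_deferred = ObservableDeferred(defer.Deferred())

        def notify(self, new_token):
            self.current_token = new_token
            # swap in a fresh deferred first, then wake every observer of
            # the old one exactly once
            notify_deferred = self.notify_deferred
            self.notify_deferred = ObservableDeferred(defer.Deferred())
            notify_deferred.callback(new_token)

        def wait(self):
            # each caller gets an independent deferred that fires on the
            # next notify()
            return self.notify_deferred.observe()
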
diff --git a/synapse/push/__init__.py b/synapse/push/__init__.py
index 8059fff1b2..36f450c31d 100644
--- a/synapse/push/__init__.py
+++ b/synapse/push/__init__.py
@@ -24,6 +24,7 @@ import baserules
import logging
import simplejson as json
import re
+import random
logger = logging.getLogger(__name__)
@@ -256,134 +257,154 @@ class Pusher(object):
logger.info("Pusher %s for user %s starting from token %s",
self.pushkey, self.user_name, self.last_token)
+ wait = 0
while self.alive:
- from_tok = StreamToken.from_string(self.last_token)
- config = PaginationConfig(from_token=from_tok, limit='1')
- chunk = yield self.evStreamHandler.get_stream(
- self.user_name, config,
- timeout=100*365*24*60*60*1000, affect_presence=False
- )
+ try:
+ if wait > 0:
+ yield synapse.util.async.sleep(wait)
+ yield self.get_and_dispatch()
+ wait = 0
+ except:
+ if wait == 0:
+ wait = 1
+ else:
+ wait = min(wait * 2, 1800)
+ logger.exception(
+ "Exception in pusher loop for pushkey %s. Pausing for %ds",
+ self.pushkey, wait
+ )
- # limiting to 1 may get 1 event plus 1 presence event, so
- # pick out the actual event
- single_event = None
- for c in chunk['chunk']:
- if 'event_id' in c: # Hmmm...
- single_event = c
- break
- if not single_event:
- self.last_token = chunk['end']
- continue
+ @defer.inlineCallbacks
+ def get_and_dispatch(self):
+ from_tok = StreamToken.from_string(self.last_token)
+ config = PaginationConfig(from_token=from_tok, limit='1')
+ timeout = (300 + random.randint(-60, 60)) * 1000
+ chunk = yield self.evStreamHandler.get_stream(
+ self.user_name, config,
+ timeout=timeout, affect_presence=False
+ )
- if not self.alive:
- continue
+ # limiting to 1 may get 1 event plus 1 presence event, so
+ # pick out the actual event
+ single_event = None
+ for c in chunk['chunk']:
+ if 'event_id' in c: # Hmmm...
+ single_event = c
+ break
+ if not single_event:
+ self.last_token = chunk['end']
+ logger.debug("Event stream timeout for pushkey %s", self.pushkey)
+ return
- processed = False
- actions = yield self._actions_for_event(single_event)
- tweaks = _tweaks_for_actions(actions)
+ if not self.alive:
+ return
- if len(actions) == 0:
- logger.warn("Empty actions! Using default action.")
- actions = Pusher.DEFAULT_ACTIONS
+ processed = False
+ actions = yield self._actions_for_event(single_event)
+ tweaks = _tweaks_for_actions(actions)
- if 'notify' not in actions and 'dont_notify' not in actions:
- logger.warn("Neither notify nor dont_notify in actions: adding default")
- actions.extend(Pusher.DEFAULT_ACTIONS)
+ if len(actions) == 0:
+ logger.warn("Empty actions! Using default action.")
+ actions = Pusher.DEFAULT_ACTIONS
- if 'dont_notify' in actions:
- logger.debug(
- "%s for %s: dont_notify",
- single_event['event_id'], self.user_name
- )
+ if 'notify' not in actions and 'dont_notify' not in actions:
+ logger.warn("Neither notify nor dont_notify in actions: adding default")
+ actions.extend(Pusher.DEFAULT_ACTIONS)
+
+ if 'dont_notify' in actions:
+ logger.debug(
+ "%s for %s: dont_notify",
+ single_event['event_id'], self.user_name
+ )
+ processed = True
+ else:
+ rejected = yield self.dispatch_push(single_event, tweaks)
+ self.has_unread = True
+ if isinstance(rejected, list) or isinstance(rejected, tuple):
processed = True
- else:
- rejected = yield self.dispatch_push(single_event, tweaks)
- self.has_unread = True
- if isinstance(rejected, list) or isinstance(rejected, tuple):
- processed = True
- for pk in rejected:
- if pk != self.pushkey:
- # for sanity, we only remove the pushkey if it
- # was the one we actually sent...
- logger.warn(
- ("Ignoring rejected pushkey %s because we"
- " didn't send it"), pk
- )
- else:
- logger.info(
- "Pushkey %s was rejected: removing",
- pk
- )
- yield self.hs.get_pusherpool().remove_pusher(
- self.app_id, pk, self.user_name
- )
-
- if not self.alive:
- continue
+ for pk in rejected:
+ if pk != self.pushkey:
+ # for sanity, we only remove the pushkey if it
+ # was the one we actually sent...
+ logger.warn(
+ ("Ignoring rejected pushkey %s because we"
+ " didn't send it"), pk
+ )
+ else:
+ logger.info(
+ "Pushkey %s was rejected: removing",
+ pk
+ )
+ yield self.hs.get_pusherpool().remove_pusher(
+ self.app_id, pk, self.user_name
+ )
+
+ if not self.alive:
+ return
+
+ if processed:
+ self.backoff_delay = Pusher.INITIAL_BACKOFF
+ self.last_token = chunk['end']
+ self.store.update_pusher_last_token_and_success(
+ self.app_id,
+ self.pushkey,
+ self.user_name,
+ self.last_token,
+ self.clock.time_msec()
+ )
+ if self.failing_since:
+ self.failing_since = None
+ self.store.update_pusher_failing_since(
+ self.app_id,
+ self.pushkey,
+ self.user_name,
+ self.failing_since)
+ else:
+ if not self.failing_since:
+ self.failing_since = self.clock.time_msec()
+ self.store.update_pusher_failing_since(
+ self.app_id,
+ self.pushkey,
+ self.user_name,
+ self.failing_since
+ )
- if processed:
+ if (self.failing_since and
+ self.failing_since <
+ self.clock.time_msec() - Pusher.GIVE_UP_AFTER):
+ # we really only give up so that if the URL gets
+ # fixed, we don't suddenly deliver a load
+ # of old notifications.
+ logger.warn("Giving up on a notification to user %s, "
+ "pushkey %s",
+ self.user_name, self.pushkey)
self.backoff_delay = Pusher.INITIAL_BACKOFF
self.last_token = chunk['end']
- self.store.update_pusher_last_token_and_success(
+ self.store.update_pusher_last_token(
+ self.app_id,
+ self.pushkey,
+ self.user_name,
+ self.last_token
+ )
+
+ self.failing_since = None
+ self.store.update_pusher_failing_since(
self.app_id,
self.pushkey,
self.user_name,
- self.last_token,
- self.clock.time_msec()
+ self.failing_since
)
- if self.failing_since:
- self.failing_since = None
- self.store.update_pusher_failing_since(
- self.app_id,
- self.pushkey,
- self.user_name,
- self.failing_since)
else:
- if not self.failing_since:
- self.failing_since = self.clock.time_msec()
- self.store.update_pusher_failing_since(
- self.app_id,
- self.pushkey,
- self.user_name,
- self.failing_since
- )
-
- if (self.failing_since and
- self.failing_since <
- self.clock.time_msec() - Pusher.GIVE_UP_AFTER):
- # we really only give up so that if the URL gets
- # fixed, we don't suddenly deliver a load
- # of old notifications.
- logger.warn("Giving up on a notification to user %s, "
- "pushkey %s",
- self.user_name, self.pushkey)
- self.backoff_delay = Pusher.INITIAL_BACKOFF
- self.last_token = chunk['end']
- self.store.update_pusher_last_token(
- self.app_id,
- self.pushkey,
- self.user_name,
- self.last_token
- )
-
- self.failing_since = None
- self.store.update_pusher_failing_since(
- self.app_id,
- self.pushkey,
- self.user_name,
- self.failing_since
- )
- else:
- logger.warn("Failed to dispatch push for user %s "
- "(failing for %dms)."
- "Trying again in %dms",
- self.user_name,
- self.clock.time_msec() - self.failing_since,
- self.backoff_delay)
- yield synapse.util.async.sleep(self.backoff_delay / 1000.0)
- self.backoff_delay *= 2
- if self.backoff_delay > Pusher.MAX_BACKOFF:
- self.backoff_delay = Pusher.MAX_BACKOFF
+ logger.warn("Failed to dispatch push for user %s "
+ "(failing for %dms)."
+ "Trying again in %dms",
+ self.user_name,
+ self.clock.time_msec() - self.failing_since,
+ self.backoff_delay)
+ yield synapse.util.async.sleep(self.backoff_delay / 1000.0)
+ self.backoff_delay *= 2
+ if self.backoff_delay > Pusher.MAX_BACKOFF:
+ self.backoff_delay = Pusher.MAX_BACKOFF
def stop(self):
self.alive = False
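
The pusher loop above wraps the long-poll in a retry loop: any exception pauses the pusher, with the pause doubling on each consecutive failure up to half an hour. The schedule in isolation::

    def next_wait(wait):
        # first failure waits 1s; each further failure doubles the wait,
        # capped at 1800s (30 minutes)
        if wait == 0:
            return 1
        return min(wait * 2, 1800)

    # consecutive failures: 1, 2, 4, ... seconds, never more than 1800
    assert [next_wait(w) for w in (0, 1, 2, 900, 1800)] == [1, 2, 4, 1800, 1800]
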
diff --git a/synapse/push/baserules.py b/synapse/push/baserules.py
index f3d1cf5c5f..1f015a7f2e 100644
--- a/synapse/push/baserules.py
+++ b/synapse/push/baserules.py
@@ -164,7 +164,7 @@ def make_base_append_underride_rules(user):
]
},
{
- 'rule_id': 'global/override/.m.rule.contains_display_name',
+ 'rule_id': 'global/underride/.m.rule.contains_display_name',
'conditions': [
{
'kind': 'contains_display_name'
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index c76e16c28b..534c4a6698 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -31,6 +31,8 @@ REQUIREMENTS = {
"pillow": ["PIL"],
"pydenticon": ["pydenticon"],
"ujson": ["ujson"],
+ "blist": ["blist"],
+ "pysaml2": ["saml2"],
}
CONDITIONAL_REQUIREMENTS = {
"web_client": {
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index b2257b749d..998d4d44c6 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -20,14 +20,32 @@ from synapse.types import UserID
from base import ClientV1RestServlet, client_path_pattern
import simplejson as json
+import urllib
+
+import logging
+from saml2 import BINDING_HTTP_POST
+from saml2 import config
+from saml2.client import Saml2Client
+
+
+logger = logging.getLogger(__name__)
class LoginRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/login$")
PASS_TYPE = "m.login.password"
+ SAML2_TYPE = "m.login.saml2"
+
+ def __init__(self, hs):
+ super(LoginRestServlet, self).__init__(hs)
+ self.idp_redirect_url = hs.config.saml2_idp_redirect_url
+ self.saml2_enabled = hs.config.saml2_enabled
def on_GET(self, request):
- return (200, {"flows": [{"type": LoginRestServlet.PASS_TYPE}]})
+ flows = [{"type": LoginRestServlet.PASS_TYPE}]
+ if self.saml2_enabled:
+ flows.append({"type": LoginRestServlet.SAML2_TYPE})
+ return (200, {"flows": flows})
def on_OPTIONS(self, request):
return (200, {})
@@ -39,6 +57,16 @@ class LoginRestServlet(ClientV1RestServlet):
if login_submission["type"] == LoginRestServlet.PASS_TYPE:
result = yield self.do_password_login(login_submission)
defer.returnValue(result)
+ elif self.saml2_enabled and (login_submission["type"] ==
+ LoginRestServlet.SAML2_TYPE):
+ relay_state = ""
+ if "relay_state" in login_submission:
+                relay_state = "&RelayState=" + urllib.quote(
+ login_submission["relay_state"])
+ result = {
+ "uri": "%s%s" % (self.idp_redirect_url, relay_state)
+ }
+ defer.returnValue((200, result))
else:
raise SynapseError(400, "Bad login type.")
except KeyError:
@@ -94,6 +122,49 @@ class PasswordResetRestServlet(ClientV1RestServlet):
)
+class SAML2RestServlet(ClientV1RestServlet):
+ PATTERN = client_path_pattern("/login/saml2")
+
+ def __init__(self, hs):
+ super(SAML2RestServlet, self).__init__(hs)
+ self.sp_config = hs.config.saml2_config_path
+
+ @defer.inlineCallbacks
+ def on_POST(self, request):
+ saml2_auth = None
+ try:
+ conf = config.SPConfig()
+ conf.load_file(self.sp_config)
+ SP = Saml2Client(conf)
+ saml2_auth = SP.parse_authn_request_response(
+ request.args['SAMLResponse'][0], BINDING_HTTP_POST)
+        except Exception as e:  # Not authenticated
+ logger.exception(e)
+ if saml2_auth and saml2_auth.status_ok() and not saml2_auth.not_signed:
+ username = saml2_auth.name_id.text
+ handler = self.handlers.registration_handler
+ (user_id, token) = yield handler.register_saml2(username)
+ # Forward to the RelayState callback along with ava
+ if 'RelayState' in request.args:
+ request.redirect(urllib.unquote(
+ request.args['RelayState'][0]) +
+ '?status=authenticated&access_token=' +
+ token + '&user_id=' + user_id + '&ava=' +
+ urllib.quote(json.dumps(saml2_auth.ava)))
+ request.finish()
+ defer.returnValue(None)
+ defer.returnValue((200, {"status": "authenticated",
+ "user_id": user_id, "token": token,
+ "ava": saml2_auth.ava}))
+ elif 'RelayState' in request.args:
+ request.redirect(urllib.unquote(
+ request.args['RelayState'][0]) +
+ '?status=not_authenticated')
+ request.finish()
+ defer.returnValue(None)
+ defer.returnValue((200, {"status": "not_authenticated"}))
+
+
def _parse_json(request):
try:
content = json.loads(request.content.read())
@@ -106,4 +177,6 @@ def _parse_json(request):
def register_servlets(hs, http_server):
LoginRestServlet(hs).register(http_server)
+ if hs.config.saml2_enabled:
+ SAML2RestServlet(hs).register(http_server)
# TODO PasswordResetRestServlet(hs).register(http_server)
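
For SAML2 logins the servlet does not authenticate locally; it hands the client a redirect to the IdP, optionally carrying an urlencoded ``RelayState``. A sketch of the URI construction used above, where ``idp_redirect_url`` is whatever ``saml2_idp_redirect_url`` is set to in the config::

    import urllib  # Python 2, matching the module above


    def saml2_login_uri(idp_redirect_url, relay_state=None):
        # mirrors the response body built for m.login.saml2 submissions
        suffix = ""
        if relay_state:
            suffix = "&RelayState=" + urllib.quote(relay_state)
        return "%s%s" % (idp_redirect_url, suffix)
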
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 0346afb1b4..b4a70cba99 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -412,6 +412,8 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
if "user_id" not in content:
raise SynapseError(400, "Missing user_id key.")
state_key = content["user_id"]
+ # make sure it looks like a user ID; it'll throw if it's invalid.
+ UserID.from_string(state_key)
if membership_action == "kick":
membership_action = "leave"
diff --git a/synapse/rest/client/v1/transactions.py b/synapse/rest/client/v1/transactions.py
index d933fea18a..b861069b89 100644
--- a/synapse/rest/client/v1/transactions.py
+++ b/synapse/rest/client/v1/transactions.py
@@ -39,10 +39,10 @@ class HttpTransactionStore(object):
A tuple of (HTTP response code, response content) or None.
"""
try:
- logger.debug("get_response Key: %s TxnId: %s", key, txn_id)
+ logger.debug("get_response TxnId: %s", txn_id)
(last_txn_id, response) = self.transactions[key]
if txn_id == last_txn_id:
- logger.info("get_response: Returning a response for %s", key)
+ logger.info("get_response: Returning a response for %s", txn_id)
return response
except KeyError:
pass
@@ -58,7 +58,7 @@ class HttpTransactionStore(object):
txn_id (str): The transaction ID for this request.
response (tuple): A tuple of (HTTP response code, response content)
"""
- logger.debug("store_response Key: %s TxnId: %s", key, txn_id)
+ logger.debug("store_response TxnId: %s", txn_id)
self.transactions[key] = (txn_id, response)
def store_client_transaction(self, request, txn_id, response):
diff --git a/synapse/rest/client/v2_alpha/__init__.py b/synapse/rest/client/v2_alpha/__init__.py
index 7d1aff4307..33f961e898 100644
--- a/synapse/rest/client/v2_alpha/__init__.py
+++ b/synapse/rest/client/v2_alpha/__init__.py
@@ -18,7 +18,9 @@ from . import (
filter,
account,
register,
- auth
+ auth,
+ receipts,
+ keys,
)
from synapse.http.server import JsonResource
@@ -38,3 +40,5 @@ class ClientV2AlphaRestResource(JsonResource):
account.register_servlets(hs, client_resource)
register.register_servlets(hs, client_resource)
auth.register_servlets(hs, client_resource)
+ receipts.register_servlets(hs, client_resource)
+ keys.register_servlets(hs, client_resource)
diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py
new file mode 100644
index 0000000000..5f3a6207b5
--- /dev/null
+++ b/synapse/rest/client/v2_alpha/keys.py
@@ -0,0 +1,276 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+from synapse.api.errors import SynapseError
+from synapse.http.servlet import RestServlet
+from syutil.jsonutil import encode_canonical_json
+
+from ._base import client_v2_pattern
+
+import simplejson as json
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class KeyUploadServlet(RestServlet):
+ """
+ POST /keys/upload/<device_id> HTTP/1.1
+ Content-Type: application/json
+
+ {
+ "device_keys": {
+ "user_id": "<user_id>",
+ "device_id": "<device_id>",
+ "valid_until_ts": <millisecond_timestamp>,
+ "algorithms": [
+ "m.olm.curve25519-aes-sha256",
+        ],
+        "keys": {
+            "<algorithm>:<device_id>": "<key_base64>",
+        },
+        "signatures": {
+            "<user_id>": {
+ "<algorithm>:<device_id>": "<signature_base64>"
+ } } },
+ "one_time_keys": {
+ "<algorithm>:<key_id>": "<key_base64>"
+ },
+ }
+ """
+ PATTERN = client_v2_pattern("/keys/upload/(?P<device_id>[^/]*)")
+
+ def __init__(self, hs):
+ super(KeyUploadServlet, self).__init__()
+ self.store = hs.get_datastore()
+ self.clock = hs.get_clock()
+ self.auth = hs.get_auth()
+
+ @defer.inlineCallbacks
+ def on_POST(self, request, device_id):
+ auth_user, client_info = yield self.auth.get_user_by_req(request)
+ user_id = auth_user.to_string()
+ # TODO: Check that the device_id matches that in the authentication
+ # or derive the device_id from the authentication instead.
+ try:
+ body = json.loads(request.content.read())
+ except:
+ raise SynapseError(400, "Invalid key JSON")
+ time_now = self.clock.time_msec()
+
+ # TODO: Validate the JSON to make sure it has the right keys.
+ device_keys = body.get("device_keys", None)
+ if device_keys:
+ logger.info(
+ "Updating device_keys for device %r for user %r at %d",
+ device_id, auth_user, time_now
+ )
+ # TODO: Sign the JSON with the server key
+ yield self.store.set_e2e_device_keys(
+ user_id, device_id, time_now,
+ encode_canonical_json(device_keys)
+ )
+
+ one_time_keys = body.get("one_time_keys", None)
+ if one_time_keys:
+ logger.info(
+ "Adding %d one_time_keys for device %r for user %r at %d",
+ len(one_time_keys), device_id, user_id, time_now
+ )
+ key_list = []
+ for key_id, key_json in one_time_keys.items():
+ algorithm, key_id = key_id.split(":")
+ key_list.append((
+ algorithm, key_id, encode_canonical_json(key_json)
+ ))
+
+ yield self.store.add_e2e_one_time_keys(
+ user_id, device_id, time_now, key_list
+ )
+
+ result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
+ defer.returnValue((200, {"one_time_key_counts": result}))
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, device_id):
+ auth_user, client_info = yield self.auth.get_user_by_req(request)
+ user_id = auth_user.to_string()
+
+ result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
+ defer.returnValue((200, {"one_time_key_counts": result}))
+
+
+class KeyQueryServlet(RestServlet):
+ """
+ GET /keys/query/<user_id> HTTP/1.1
+
+ GET /keys/query/<user_id>/<device_id> HTTP/1.1
+
+ POST /keys/query HTTP/1.1
+ Content-Type: application/json
+ {
+ "device_keys": {
+ "<user_id>": ["<device_id>"]
+ } }
+
+ HTTP/1.1 200 OK
+ {
+ "device_keys": {
+ "<user_id>": {
+ "<device_id>": {
+ "user_id": "<user_id>", // Duplicated to be signed
+ "device_id": "<device_id>", // Duplicated to be signed
+ "valid_until_ts": <millisecond_timestamp>,
+ "algorithms": [ // List of supported algorithms
+ "m.olm.curve25519-aes-sha256",
+ ],
+ "keys": { // Must include a ed25519 signing key
+ "<algorithm>:<key_id>": "<key_base64>",
+ },
+ "signatures:" {
+ // Must be signed with device's ed25519 key
+ "<user_id>/<device_id>": {
+ "<algorithm>:<key_id>": "<signature_base64>"
+ }
+ // Must be signed by this server.
+ "<server_name>": {
+ "<algorithm>:<key_id>": "<signature_base64>"
+ } } } } } }
+ """
+
+ PATTERN = client_v2_pattern(
+ "/keys/query(?:"
+ "/(?P<user_id>[^/]*)(?:"
+ "/(?P<device_id>[^/]*)"
+ ")?"
+ ")?"
+ )
+
+ def __init__(self, hs):
+ super(KeyQueryServlet, self).__init__()
+ self.store = hs.get_datastore()
+ self.auth = hs.get_auth()
+
+ @defer.inlineCallbacks
+ def on_POST(self, request, user_id, device_id):
+ logger.debug("onPOST")
+ yield self.auth.get_user_by_req(request)
+ try:
+ body = json.loads(request.content.read())
+ except:
+ raise SynapseError(400, "Invalid key JSON")
+ query = []
+ for user_id, device_ids in body.get("device_keys", {}).items():
+ if not device_ids:
+ query.append((user_id, None))
+ else:
+ for device_id in device_ids:
+ query.append((user_id, device_id))
+ results = yield self.store.get_e2e_device_keys(query)
+ defer.returnValue(self.json_result(request, results))
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, user_id, device_id):
+ auth_user, client_info = yield self.auth.get_user_by_req(request)
+ auth_user_id = auth_user.to_string()
+ if not user_id:
+ user_id = auth_user_id
+ if not device_id:
+ device_id = None
+ # Returns a map of user_id->device_id->json_bytes.
+ results = yield self.store.get_e2e_device_keys([(user_id, device_id)])
+ defer.returnValue(self.json_result(request, results))
+
+ def json_result(self, request, results):
+ json_result = {}
+ for user_id, device_keys in results.items():
+ for device_id, json_bytes in device_keys.items():
+ json_result.setdefault(user_id, {})[device_id] = json.loads(
+ json_bytes
+ )
+ return (200, {"device_keys": json_result})
+
+
+class OneTimeKeyServlet(RestServlet):
+ """
+ GET /keys/claim/<user-id>/<device-id>/<algorithm> HTTP/1.1
+
+ POST /keys/claim HTTP/1.1
+ {
+ "one_time_keys": {
+ "<user_id>": {
+ "<device_id>": "<algorithm>"
+ } } }
+
+ HTTP/1.1 200 OK
+ {
+ "one_time_keys": {
+ "<user_id>": {
+ "<device_id>": {
+ "<algorithm>:<key_id>": "<key_base64>"
+ } } } }
+
+ """
+ PATTERN = client_v2_pattern(
+ "/keys/claim(?:/?|(?:/"
+ "(?P<user_id>[^/]*)/(?P<device_id>[^/]*)/(?P<algorithm>[^/]*)"
+ ")?)"
+ )
+
+ def __init__(self, hs):
+ super(OneTimeKeyServlet, self).__init__()
+ self.store = hs.get_datastore()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, user_id, device_id, algorithm):
+ yield self.auth.get_user_by_req(request)
+ results = yield self.store.claim_e2e_one_time_keys(
+ [(user_id, device_id, algorithm)]
+ )
+ defer.returnValue(self.json_result(request, results))
+
+ @defer.inlineCallbacks
+ def on_POST(self, request, user_id, device_id, algorithm):
+ yield self.auth.get_user_by_req(request)
+ try:
+ body = json.loads(request.content.read())
+ except:
+ raise SynapseError(400, "Invalid key JSON")
+ query = []
+ for user_id, device_keys in body.get("one_time_keys", {}).items():
+ for device_id, algorithm in device_keys.items():
+ query.append((user_id, device_id, algorithm))
+ results = yield self.store.claim_e2e_one_time_keys(query)
+ defer.returnValue(self.json_result(request, results))
+
+ def json_result(self, request, results):
+ json_result = {}
+ for user_id, device_keys in results.items():
+ for device_id, keys in device_keys.items():
+ for key_id, json_bytes in keys.items():
+ json_result.setdefault(user_id, {})[device_id] = {
+ key_id: json.loads(json_bytes)
+ }
+ return (200, {"one_time_keys": json_result})
+
+
+def register_servlets(hs, http_server):
+ KeyUploadServlet(hs).register(http_server)
+ KeyQueryServlet(hs).register(http_server)
+ OneTimeKeyServlet(hs).register(http_server)
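
In ``KeyUploadServlet`` each one-time key arrives named ``<algorithm>:<key_id>``, and its JSON value is canonicalised before it is handed to the store so that signatures stay stable. The transformation in isolation::

    from syutil.jsonutil import encode_canonical_json


    def key_list_from_upload(one_time_keys):
        key_list = []
        for key_id, key_json in one_time_keys.items():
            # keys arrive named "<algorithm>:<key_id>"
            algorithm, key_id = key_id.split(":")
            key_list.append(
                (algorithm, key_id, encode_canonical_json(key_json))
            )
        return key_list
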
diff --git a/synapse/rest/client/v2_alpha/receipts.py b/synapse/rest/client/v2_alpha/receipts.py
new file mode 100644
index 0000000000..40406e2ede
--- /dev/null
+++ b/synapse/rest/client/v2_alpha/receipts.py
@@ -0,0 +1,55 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+from synapse.http.servlet import RestServlet
+from ._base import client_v2_pattern
+
+import logging
+
+
+logger = logging.getLogger(__name__)
+
+
+class ReceiptRestServlet(RestServlet):
+ PATTERN = client_v2_pattern(
+ "/rooms/(?P<room_id>[^/]*)"
+ "/receipt/(?P<receipt_type>[^/]*)"
+ "/(?P<event_id>[^/]*)$"
+ )
+
+ def __init__(self, hs):
+ super(ReceiptRestServlet, self).__init__()
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.receipts_handler = hs.get_handlers().receipts_handler
+
+ @defer.inlineCallbacks
+ def on_POST(self, request, room_id, receipt_type, event_id):
+ user, client = yield self.auth.get_user_by_req(request)
+
+ yield self.receipts_handler.received_client_receipt(
+ room_id,
+ receipt_type,
+ user_id=user.to_string(),
+ event_id=event_id
+ )
+
+ defer.returnValue((200, {}))
+
+
+def register_servlets(hs, http_server):
+ ReceiptRestServlet(hs).register(http_server)
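
The servlet above acknowledges an event when a client POSTs an empty JSON body to the receipt path. A sketch of the path it matches; the ``/_matrix/client/v2_alpha`` prefix is an assumption about how ``client_v2_pattern`` expands::

    def receipt_path(room_id, event_id, receipt_type="m.read"):
        # assumed expansion of client_v2_pattern; POST an empty JSON body
        return "/_matrix/client/v2_alpha/rooms/%s/receipt/%s/%s" % (
            room_id, receipt_type, event_id,
        )
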
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index 72dfb876c5..b5926f9ca6 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -19,7 +19,7 @@ from synapse.api.constants import LoginType
from synapse.api.errors import SynapseError, Codes
from synapse.http.servlet import RestServlet
-from ._base import client_v2_pattern, parse_request_allow_empty
+from ._base import client_v2_pattern, parse_json_dict_from_request
import logging
import hmac
@@ -55,21 +55,55 @@ class RegisterRestServlet(RestServlet):
@defer.inlineCallbacks
def on_POST(self, request):
yield run_on_reactor()
-
- body = parse_request_allow_empty(request)
- if 'password' not in body:
- raise SynapseError(400, "", Codes.MISSING_PARAM)
-
+ body = parse_json_dict_from_request(request)
+
+ # we do basic sanity checks here because the auth layer will store these
+ # in sessions. Pull out the username/password provided to us.
+ desired_password = None
+ if 'password' in body:
+ if (not isinstance(body['password'], basestring) or
+ len(body['password']) > 512):
+ raise SynapseError(400, "Invalid password")
+ desired_password = body["password"]
+
+ desired_username = None
if 'username' in body:
+ if (not isinstance(body['username'], basestring) or
+ len(body['username']) > 512):
+ raise SynapseError(400, "Invalid username")
desired_username = body['username']
- yield self.registration_handler.check_username(desired_username)
- is_using_shared_secret = False
- is_application_server = False
-
- service = None
+ appservice = None
if 'access_token' in request.args:
- service = yield self.auth.get_appservice_by_req(request)
+ appservice = yield self.auth.get_appservice_by_req(request)
+
+ # fork off as soon as possible for ASes and shared secret auth which
+ # have completely different registration flows to normal users
+
+ # == Application Service Registration ==
+ if appservice:
+ result = yield self._do_appservice_registration(
+ desired_username, request.args["access_token"][0]
+ )
+ defer.returnValue((200, result)) # we throw for non 200 responses
+ return
+
+ # == Shared Secret Registration == (e.g. create new user scripts)
+ if 'mac' in body:
+ # FIXME: Should we really be determining if this is shared secret
+ # auth based purely on the 'mac' key?
+ result = yield self._do_shared_secret_registration(
+ desired_username, desired_password, body["mac"]
+ )
+ defer.returnValue((200, result)) # we throw for non 200 responses
+ return
+
+ # == Normal User Registration == (everyone else)
+ if self.hs.config.disable_registration:
+ raise SynapseError(403, "Registration has been disabled")
+
+ if desired_username is not None:
+ yield self.registration_handler.check_username(desired_username)
if self.hs.config.enable_registration_captcha:
flows = [
@@ -82,39 +116,20 @@ class RegisterRestServlet(RestServlet):
[LoginType.EMAIL_IDENTITY]
]
- result = None
- if service:
- is_application_server = True
- params = body
- elif 'mac' in body:
- # Check registration-specific shared secret auth
- if 'username' not in body:
- raise SynapseError(400, "", Codes.MISSING_PARAM)
- self._check_shared_secret_auth(
- body['username'], body['mac']
- )
- is_using_shared_secret = True
- params = body
- else:
- authed, result, params = yield self.auth_handler.check_auth(
- flows, body, self.hs.get_ip_from_request(request)
- )
-
- if not authed:
- defer.returnValue((401, result))
-
- can_register = (
- not self.hs.config.disable_registration
- or is_application_server
- or is_using_shared_secret
+ authed, result, params = yield self.auth_handler.check_auth(
+ flows, body, self.hs.get_ip_from_request(request)
)
- if not can_register:
- raise SynapseError(403, "Registration has been disabled")
+ if not authed:
+ defer.returnValue((401, result))
+ return
+
+ # NB: This may be from the auth handler and NOT from the POST
if 'password' not in params:
- raise SynapseError(400, "", Codes.MISSING_PARAM)
- desired_username = params['username'] if 'username' in params else None
- new_password = params['password']
+ raise SynapseError(400, "Missing password.", Codes.MISSING_PARAM)
+
+ desired_username = params.get("username", None)
+ new_password = params.get("password", None)
(user_id, token) = yield self.registration_handler.register(
localpart=desired_username,
@@ -147,18 +162,21 @@ class RegisterRestServlet(RestServlet):
else:
logger.info("bind_email not specified: not binding email")
- result = {
- "user_id": user_id,
- "access_token": token,
- "home_server": self.hs.hostname,
- }
-
+ result = self._create_registration_details(user_id, token)
defer.returnValue((200, result))
def on_OPTIONS(self, _):
return 200, {}
- def _check_shared_secret_auth(self, username, mac):
+ @defer.inlineCallbacks
+ def _do_appservice_registration(self, username, as_token):
+ (user_id, token) = yield self.registration_handler.appservice_register(
+ username, as_token
+ )
+ defer.returnValue(self._create_registration_details(user_id, token))
+
+ @defer.inlineCallbacks
+ def _do_shared_secret_registration(self, username, password, mac):
if not self.hs.config.registration_shared_secret:
raise SynapseError(400, "Shared secret registration is not enabled")
@@ -174,13 +192,23 @@ class RegisterRestServlet(RestServlet):
digestmod=sha1,
).hexdigest()
- if compare_digest(want_mac, got_mac):
- return True
- else:
+ if not compare_digest(want_mac, got_mac):
raise SynapseError(
403, "HMAC incorrect",
)
+ (user_id, token) = yield self.registration_handler.register(
+ localpart=username, password=password
+ )
+ defer.returnValue(self._create_registration_details(user_id, token))
+
+ def _create_registration_details(self, user_id, token):
+ return {
+ "user_id": user_id,
+ "access_token": token,
+ "home_server": self.hs.hostname,
+ }
+
def register_servlets(hs, http_server):
RegisterRestServlet(hs).register(http_server)
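
Shared-secret registration above boils down to an HMAC-SHA1 of the username keyed by ``registration_shared_secret``, compared in constant time against the ``mac`` supplied in the request body. A standalone sketch; the ``compare_digest`` fallback is for illustration only and is not constant-time::

    import hmac
    from hashlib import sha1

    # hmac.compare_digest needs Python >= 2.7.7
    compare_digest = getattr(hmac, "compare_digest", lambda a, b: a == b)


    def check_registration_mac(shared_secret, username, got_mac):
        want_mac = hmac.new(
            key=shared_secret,
            msg=username,
            digestmod=sha1,
        ).hexdigest()
        return compare_digest(want_mac, got_mac)
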
diff --git a/synapse/rest/media/v1/base_resource.py b/synapse/rest/media/v1/base_resource.py
index 4af5f73878..84e1961a21 100644
--- a/synapse/rest/media/v1/base_resource.py
+++ b/synapse/rest/media/v1/base_resource.py
@@ -15,20 +15,23 @@
from .thumbnailer import Thumbnailer
+from synapse.http.matrixfederationclient import MatrixFederationHttpClient
from synapse.http.server import respond_with_json
from synapse.util.stringutils import random_string
from synapse.api.errors import (
cs_error, Codes, SynapseError
)
-from twisted.internet import defer
+from twisted.internet import defer, threads
from twisted.web.resource import Resource
from twisted.protocols.basic import FileSender
from synapse.util.async import ObservableDeferred
+from synapse.util.stringutils import is_ascii
import os
+import cgi
import logging
logger = logging.getLogger(__name__)
@@ -36,8 +39,13 @@ logger = logging.getLogger(__name__)
def parse_media_id(request):
try:
- server_name, media_id = request.postpath
- return (server_name, media_id)
+ # This allows users to append e.g. /test.png to the URL. Useful for
+ # clients that parse the URL to see content type.
+ server_name, media_id = request.postpath[:2]
+ if len(request.postpath) > 2 and is_ascii(request.postpath[-1]):
+ return server_name, media_id, request.postpath[-1]
+ else:
+ return server_name, media_id, None
except:
raise SynapseError(
404,
@@ -52,7 +60,7 @@ class BaseMediaResource(Resource):
def __init__(self, hs, filepaths):
Resource.__init__(self)
self.auth = hs.get_auth()
- self.client = hs.get_http_client()
+ self.client = MatrixFederationHttpClient(hs)
self.clock = hs.get_clock()
self.server_name = hs.hostname
self.store = hs.get_datastore()
@@ -127,12 +135,21 @@ class BaseMediaResource(Resource):
media_type = headers["Content-Type"][0]
time_now_ms = self.clock.time_msec()
+ content_disposition = headers.get("Content-Disposition", None)
+ if content_disposition:
+ _, params = cgi.parse_header(content_disposition[0],)
+ upload_name = params.get("filename", None)
+ if upload_name and not is_ascii(upload_name):
+ upload_name = None
+ else:
+ upload_name = None
+
yield self.store.store_cached_remote_media(
origin=server_name,
media_id=media_id,
media_type=media_type,
time_now_ms=self.clock.time_msec(),
- upload_name=None,
+ upload_name=upload_name,
media_length=length,
filesystem_id=file_id,
)
@@ -143,7 +160,7 @@ class BaseMediaResource(Resource):
media_info = {
"media_type": media_type,
"media_length": length,
- "upload_name": None,
+ "upload_name": upload_name,
"created_ts": time_now_ms,
"filesystem_id": file_id,
}
@@ -156,11 +173,16 @@ class BaseMediaResource(Resource):
@defer.inlineCallbacks
def _respond_with_file(self, request, media_type, file_path,
- file_size=None):
+ file_size=None, upload_name=None):
logger.debug("Responding with %r", file_path)
if os.path.isfile(file_path):
request.setHeader(b"Content-Type", media_type.encode("UTF-8"))
+ if upload_name:
+ request.setHeader(
+ b"Content-Disposition",
+ b"inline; filename=%s" % (upload_name.encode("utf-8"),),
+ )
# cache for at least a day.
# XXX: we might want to turn this off for data we don't want to
@@ -222,43 +244,52 @@ class BaseMediaResource(Resource):
)
return
- scales = set()
- crops = set()
- for r_width, r_height, r_method, r_type in requirements:
- if r_method == "scale":
- t_width, t_height = thumbnailer.aspect(r_width, r_height)
- scales.add((
- min(m_width, t_width), min(m_height, t_height), r_type,
+ local_thumbnails = []
+
+ def generate_thumbnails():
+ scales = set()
+ crops = set()
+ for r_width, r_height, r_method, r_type in requirements:
+ if r_method == "scale":
+ t_width, t_height = thumbnailer.aspect(r_width, r_height)
+ scales.add((
+ min(m_width, t_width), min(m_height, t_height), r_type,
+ ))
+ elif r_method == "crop":
+ crops.add((r_width, r_height, r_type))
+
+ for t_width, t_height, t_type in scales:
+ t_method = "scale"
+ t_path = self.filepaths.local_media_thumbnail(
+ media_id, t_width, t_height, t_type, t_method
+ )
+ self._makedirs(t_path)
+ t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)
+
+ local_thumbnails.append((
+ media_id, t_width, t_height, t_type, t_method, t_len
))
- elif r_method == "crop":
- crops.add((r_width, r_height, r_type))
- for t_width, t_height, t_type in scales:
- t_method = "scale"
- t_path = self.filepaths.local_media_thumbnail(
- media_id, t_width, t_height, t_type, t_method
- )
- self._makedirs(t_path)
- t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)
- yield self.store.store_local_thumbnail(
- media_id, t_width, t_height, t_type, t_method, t_len
- )
+ for t_width, t_height, t_type in crops:
+ if (t_width, t_height, t_type) in scales:
+ # If the aspect ratio of the cropped thumbnail matches a purely
+ # scaled one then there is no point in calculating a separate
+ # thumbnail.
+ continue
+ t_method = "crop"
+ t_path = self.filepaths.local_media_thumbnail(
+ media_id, t_width, t_height, t_type, t_method
+ )
+ self._makedirs(t_path)
+ t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
+ local_thumbnails.append((
+ media_id, t_width, t_height, t_type, t_method, t_len
+ ))
- for t_width, t_height, t_type in crops:
- if (t_width, t_height, t_type) in scales:
- # If the aspect ratio of the cropped thumbnail matches a purely
- # scaled one then there is no point in calculating a separate
- # thumbnail.
- continue
- t_method = "crop"
- t_path = self.filepaths.local_media_thumbnail(
- media_id, t_width, t_height, t_type, t_method
- )
- self._makedirs(t_path)
- t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
- yield self.store.store_local_thumbnail(
- media_id, t_width, t_height, t_type, t_method, t_len
- )
+ yield threads.deferToThread(generate_thumbnails)
+
+ for l in local_thumbnails:
+ yield self.store.store_local_thumbnail(*l)
defer.returnValue({
"width": m_width,
@@ -273,57 +304,65 @@ class BaseMediaResource(Resource):
if not requirements:
return
+ remote_thumbnails = []
+
input_path = self.filepaths.remote_media_filepath(server_name, file_id)
thumbnailer = Thumbnailer(input_path)
m_width = thumbnailer.width
m_height = thumbnailer.height
- if m_width * m_height >= self.max_image_pixels:
- logger.info(
- "Image too large to thumbnail %r x %r > %r",
- m_width, m_height, self.max_image_pixels
- )
- return
-
- scales = set()
- crops = set()
- for r_width, r_height, r_method, r_type in requirements:
- if r_method == "scale":
- t_width, t_height = thumbnailer.aspect(r_width, r_height)
- scales.add((
- min(m_width, t_width), min(m_height, t_height), r_type,
- ))
- elif r_method == "crop":
- crops.add((r_width, r_height, r_type))
+ def generate_thumbnails():
+ if m_width * m_height >= self.max_image_pixels:
+ logger.info(
+ "Image too large to thumbnail %r x %r > %r",
+ m_width, m_height, self.max_image_pixels
+ )
+ return
+
+ scales = set()
+ crops = set()
+ for r_width, r_height, r_method, r_type in requirements:
+ if r_method == "scale":
+ t_width, t_height = thumbnailer.aspect(r_width, r_height)
+ scales.add((
+ min(m_width, t_width), min(m_height, t_height), r_type,
+ ))
+ elif r_method == "crop":
+ crops.add((r_width, r_height, r_type))
+
+ for t_width, t_height, t_type in scales:
+ t_method = "scale"
+ t_path = self.filepaths.remote_media_thumbnail(
+ server_name, file_id, t_width, t_height, t_type, t_method
+ )
+ self._makedirs(t_path)
+ t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)
+ remote_thumbnails.append([
+ server_name, media_id, file_id,
+ t_width, t_height, t_type, t_method, t_len
+ ])
+
+ for t_width, t_height, t_type in crops:
+ if (t_width, t_height, t_type) in scales:
+ # If the aspect ratio of the cropped thumbnail matches a purely
+ # scaled one then there is no point in calculating a separate
+ # thumbnail.
+ continue
+ t_method = "crop"
+ t_path = self.filepaths.remote_media_thumbnail(
+ server_name, file_id, t_width, t_height, t_type, t_method
+ )
+ self._makedirs(t_path)
+ t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
+ remote_thumbnails.append([
+ server_name, media_id, file_id,
+ t_width, t_height, t_type, t_method, t_len
+ ])
- for t_width, t_height, t_type in scales:
- t_method = "scale"
- t_path = self.filepaths.remote_media_thumbnail(
- server_name, file_id, t_width, t_height, t_type, t_method
- )
- self._makedirs(t_path)
- t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)
- yield self.store.store_remote_media_thumbnail(
- server_name, media_id, file_id,
- t_width, t_height, t_type, t_method, t_len
- )
+ yield threads.deferToThread(generate_thumbnails)
- for t_width, t_height, t_type in crops:
- if (t_width, t_height, t_type) in scales:
- # If the aspect ratio of the cropped thumbnail matches a purely
- # scaled one then there is no point in calculating a separate
- # thumbnail.
- continue
- t_method = "crop"
- t_path = self.filepaths.remote_media_thumbnail(
- server_name, file_id, t_width, t_height, t_type, t_method
- )
- self._makedirs(t_path)
- t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
- yield self.store.store_remote_media_thumbnail(
- server_name, media_id, file_id,
- t_width, t_height, t_type, t_method, t_len
- )
+ for r in remote_thumbnails:
+ yield self.store.store_remote_media_thumbnail(*r)
defer.returnValue({
"width": m_width,
diff --git a/synapse/rest/media/v1/download_resource.py b/synapse/rest/media/v1/download_resource.py
index 0fe6abf647..ab384e5388 100644
--- a/synapse/rest/media/v1/download_resource.py
+++ b/synapse/rest/media/v1/download_resource.py
@@ -32,14 +32,16 @@ class DownloadResource(BaseMediaResource):
@request_handler
@defer.inlineCallbacks
def _async_render_GET(self, request):
- server_name, media_id = parse_media_id(request)
+ server_name, media_id, name = parse_media_id(request)
if server_name == self.server_name:
- yield self._respond_local_file(request, media_id)
+ yield self._respond_local_file(request, media_id, name)
else:
- yield self._respond_remote_file(request, server_name, media_id)
+ yield self._respond_remote_file(
+ request, server_name, media_id, name
+ )
@defer.inlineCallbacks
- def _respond_local_file(self, request, media_id):
+ def _respond_local_file(self, request, media_id, name):
media_info = yield self.store.get_local_media(media_id)
if not media_info:
self._respond_404(request)
@@ -47,24 +49,28 @@ class DownloadResource(BaseMediaResource):
media_type = media_info["media_type"]
media_length = media_info["media_length"]
+ upload_name = name if name else media_info["upload_name"]
file_path = self.filepaths.local_media_filepath(media_id)
yield self._respond_with_file(
- request, media_type, file_path, media_length
+ request, media_type, file_path, media_length,
+ upload_name=upload_name,
)
@defer.inlineCallbacks
- def _respond_remote_file(self, request, server_name, media_id):
+ def _respond_remote_file(self, request, server_name, media_id, name):
media_info = yield self._get_remote_media(server_name, media_id)
media_type = media_info["media_type"]
media_length = media_info["media_length"]
filesystem_id = media_info["filesystem_id"]
+ upload_name = name if name else media_info["upload_name"]
file_path = self.filepaths.remote_media_filepath(
server_name, filesystem_id
)
yield self._respond_with_file(
- request, media_type, file_path, media_length
+ request, media_type, file_path, media_length,
+ upload_name=upload_name,
)
diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py
index 1dadd880b2..61f88e486e 100644
--- a/synapse/rest/media/v1/thumbnail_resource.py
+++ b/synapse/rest/media/v1/thumbnail_resource.py
@@ -36,7 +36,7 @@ class ThumbnailResource(BaseMediaResource):
@request_handler
@defer.inlineCallbacks
def _async_render_GET(self, request):
- server_name, media_id = parse_media_id(request)
+ server_name, media_id, _ = parse_media_id(request)
width = parse_integer(request, "width")
height = parse_integer(request, "height")
method = parse_string(request, "method", "scale")
@@ -162,11 +162,12 @@ class ThumbnailResource(BaseMediaResource):
t_method = info["thumbnail_method"]
if t_method == "scale" or t_method == "crop":
aspect_quality = abs(d_w * t_h - d_h * t_w)
+ min_quality = 0 if d_w <= t_w and d_h <= t_h else 1
size_quality = abs((d_w - t_w) * (d_h - t_h))
type_quality = desired_type != info["thumbnail_type"]
length_quality = info["thumbnail_length"]
info_list.append((
- aspect_quality, size_quality, type_quality,
+ aspect_quality, min_quality, size_quality, type_quality,
length_quality, info
))
if info_list:
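
The new ``min_quality`` term sorts ahead of size and type, so candidates at
least as large as the requested dimensions always beat ones that would need
upscaling. A sketch of the ranking rule, over candidate rows shaped like the
thumbnail info dicts::

    def pick_thumbnail(candidates, d_w, d_h, desired_type):
        info_list = []
        for info in candidates:
            t_w = info["thumbnail_width"]
            t_h = info["thumbnail_height"]
            aspect_quality = abs(d_w * t_h - d_h * t_w)
            # 0 if the candidate is at least as big as requested, else 1.
            min_quality = 0 if d_w <= t_w and d_h <= t_h else 1
            size_quality = abs((d_w - t_w) * (d_h - t_h))
            type_quality = desired_type != info["thumbnail_type"]
            length_quality = info["thumbnail_length"]
            info_list.append((
                aspect_quality, min_quality, size_quality,
                type_quality, length_quality, info,
            ))
        if not info_list:
            return None
        # Lexicographic tuple comparison picks the "best" candidate.
        return min(info_list, key=lambda t: t[:-1])[-1]
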
diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py
index 28404f2b7b..1e965c363a 100644
--- a/synapse/rest/media/v1/thumbnailer.py
+++ b/synapse/rest/media/v1/thumbnailer.py
@@ -82,7 +82,7 @@ class Thumbnailer(object):
def save_image(self, output_image, output_type, output_path):
output_bytes_io = BytesIO()
- output_image.save(output_bytes_io, self.FORMATS[output_type], quality=70)
+ output_image.save(output_bytes_io, self.FORMATS[output_type], quality=80)
output_bytes = output_bytes_io.getvalue()
with open(output_path, "wb") as output_file:
output_file.write(output_bytes)
diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py
index cc571976a5..cdd1d44e07 100644
--- a/synapse/rest/media/v1/upload_resource.py
+++ b/synapse/rest/media/v1/upload_resource.py
@@ -15,7 +15,7 @@
from synapse.http.server import respond_with_json, request_handler
-from synapse.util.stringutils import random_string
+from synapse.util.stringutils import random_string, is_ascii
from synapse.api.errors import SynapseError
from twisted.web.server import NOT_DONE_YET
@@ -84,6 +84,12 @@ class UploadResource(BaseMediaResource):
code=413,
)
+ upload_name = request.args.get("filename", None)
+ if upload_name:
+ upload_name = upload_name[0]
+ if upload_name and not is_ascii(upload_name):
+ raise SynapseError(400, "filename must be ascii")
+
headers = request.requestHeaders
if headers.hasHeader("Content-Type"):
@@ -99,7 +105,7 @@ class UploadResource(BaseMediaResource):
# TODO(markjh): parse content-dispostion
content_uri = yield self.create_content(
- media_type, None, request.content.read(),
+ media_type, upload_name, request.content.read(),
content_length, auth_user
)
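
Uploads may now carry a ``filename`` query parameter, which becomes the
media's upload name after an ASCII check. A sketch of the validation step,
using an inline ASCII test in place of synapse's ``is_ascii`` helper::

    def extract_upload_name(args):
        # args mirrors Twisted's request.args, e.g. {"filename": ["cat.png"]}
        upload_name = args.get("filename", None)
        if upload_name:
            upload_name = upload_name[0]
        if upload_name and not all(ord(c) < 128 for c in upload_name):
            raise ValueError("filename must be ascii")
        return upload_name
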
diff --git a/synapse/server.py b/synapse/server.py
index 8b3dc675cc..4d1fb1cbf6 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -132,16 +132,8 @@ class BaseHomeServer(object):
setattr(BaseHomeServer, "get_%s" % (depname), _get)
def get_ip_from_request(self, request):
- # May be an X-Forwarding-For header depending on config
- ip_addr = request.getClientIP()
- if self.config.captcha_ip_origin_is_x_forwarded:
- # use the header
- if request.requestHeaders.hasHeader("X-Forwarded-For"):
- ip_addr = request.requestHeaders.getRawHeaders(
- "X-Forwarded-For"
- )[0]
-
- return ip_addr
+ # X-Forwarded-For is handled by our custom request type.
+ return request.getClientIP()
def is_mine(self, domain_specific_string):
return domain_specific_string.domain == self.hostname
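
``get_ip_from_request`` can shrink to ``request.getClientIP()`` because the
``X-Forwarded-For`` handling now lives in the request class itself. One way
such a class might look (illustrative only, not synapse's actual
implementation)::

    from twisted.web.server import Request

    class XForwardedRequest(Request):
        def getClientIP(self):
            # Trust the first X-Forwarded-For entry when present; this is
            # only safe behind a proxy you control.
            forwarded = self.requestHeaders.getRawHeaders("x-forwarded-for")
            if forwarded:
                return forwarded[0].split(",")[0].strip()
            return Request.getClientIP(self)
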
diff --git a/synapse/state.py b/synapse/state.py
index 9dddb77d5b..80da90a72c 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -106,7 +106,7 @@ class StateHandler(object):
defer.returnValue(state)
@defer.inlineCallbacks
- def compute_event_context(self, event, old_state=None):
+ def compute_event_context(self, event, old_state=None, outlier=False):
""" Fills out the context with the `current state` of the graph. The
`current state` here is defined to be the state of the event graph
just before the event - i.e. it never includes `event`
@@ -119,9 +119,23 @@ class StateHandler(object):
Returns:
an EventContext
"""
+ yield run_on_reactor()
+
context = EventContext()
- yield run_on_reactor()
+ if outlier:
+ # If this is an outlier, then we know it shouldn't have any current
+ # state. Certainly store.get_current_state won't return any, and
+ # persisting the event won't store the state group.
+ if old_state:
+ context.current_state = {
+ (s.type, s.state_key): s for s in old_state
+ }
+ else:
+ context.current_state = {}
+ context.prev_state_events = []
+ context.state_group = None
+ defer.returnValue(context)
if old_state:
context.current_state = {
@@ -155,10 +169,6 @@ class StateHandler(object):
context.current_state = curr_state
context.state_group = group if not event.is_state() else None
- prev_state = yield self.store.add_event_hashes(
- prev_state
- )
-
if event.is_state():
key = (event.type, event.state_key)
if key in context.current_state:
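
The outlier branch short-circuits before any database lookups: an outlier
never gets a state group, and its current state is whatever the caller
supplied. A toy version of that early return, with a namedtuple standing in
for synapse's ``EventContext``::

    from collections import namedtuple

    Context = namedtuple(
        "Context", ["current_state", "prev_state_events", "state_group"]
    )

    def outlier_context(old_state=None):
        # State is keyed by (type, state_key); outliers get no state group.
        if old_state:
            current_state = {(s.type, s.state_key): s for s in old_state}
        else:
            current_state = {}
        return Context(current_state, [], None)
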
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 75af44d787..c6ce65b4cc 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -37,6 +37,9 @@ from .rejections import RejectionsStore
from .state import StateStore
from .signatures import SignatureStore
from .filtering import FilteringStore
+from .end_to_end_keys import EndToEndKeyStore
+
+from .receipts import ReceiptsStore
import fnmatch
@@ -51,7 +54,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts.
-SCHEMA_VERSION = 19
+SCHEMA_VERSION = 21
dir_path = os.path.abspath(os.path.dirname(__file__))
@@ -74,6 +77,8 @@ class DataStore(RoomMemberStore, RoomStore,
PushRuleStore,
ApplicationServiceTransactionStore,
EventsStore,
+ ReceiptsStore,
+ EndToEndKeyStore,
):
def __init__(self, hs):
@@ -94,7 +99,7 @@ class DataStore(RoomMemberStore, RoomStore,
key = (user.to_string(), access_token, device_id, ip)
try:
- last_seen = self.client_ip_last_seen.get(*key)
+ last_seen = self.client_ip_last_seen.get(key)
except KeyError:
last_seen = None
@@ -102,7 +107,7 @@ class DataStore(RoomMemberStore, RoomStore,
if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY:
defer.returnValue(None)
- self.client_ip_last_seen.prefill(*key + (now,))
+ self.client_ip_last_seen.prefill(key, now)
# It's safe not to lock here: a) no unique constraint,
# b) LAST_SEEN_GRANULARITY makes concurrent updates incredibly unlikely
@@ -348,7 +353,12 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files,
module_name, absolute_path, python_file
)
logger.debug("Running script %s", relative_path)
- module.run_upgrade(cur)
+ module.run_upgrade(cur, database_engine)
+ elif ext == ".pyc":
+ # Sometimes .pyc files turn up anyway even though we've
+ # disabled their generation; e.g. from distribution package
+                    # installers. Silently skip them.
+ pass
elif ext == ".sql":
# A plain old .sql file, just read and execute it
logger.debug("Applying schema %s", relative_path)
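
The delta runner now dispatches purely on file extension, so stray
byte-compiled files no longer abort an upgrade, and Python deltas receive
the database engine. A condensed sketch of the dispatch, with
``load_module`` and ``execute_statements`` as stand-ins for the real module
loading and SQL-splitting machinery::

    import os

    def apply_delta(path, cur, database_engine, load_module, execute_statements):
        ext = os.path.splitext(path)[1]
        if ext == ".py":
            # Python deltas now take the database engine as well.
            load_module(path).run_upgrade(cur, database_engine)
        elif ext == ".pyc":
            pass  # byte-code left by package installers: skip silently
        elif ext == ".sql":
            with open(path) as f:
                execute_statements(cur, f.read())
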
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 39884c2afe..73eea157a4 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -15,6 +15,7 @@
import logging
from synapse.api.errors import StoreError
+from synapse.util.async import ObservableDeferred
from synapse.util.logutils import log_function
from synapse.util.logcontext import preserve_context_over_fn, LoggingContext
from synapse.util.lrucache import LruCache
@@ -27,6 +28,7 @@ from twisted.internet import defer
from collections import namedtuple, OrderedDict
import functools
+import inspect
import sys
import time
import threading
@@ -55,9 +57,12 @@ cache_counter = metrics.register_cache(
)
+_CacheSentinel = object()
+
+
class Cache(object):
- def __init__(self, name, max_entries=1000, keylen=1, lru=False):
+ def __init__(self, name, max_entries=1000, keylen=1, lru=True):
if lru:
self.cache = LruCache(max_size=max_entries)
self.max_entries = None
@@ -81,45 +86,44 @@ class Cache(object):
"Cache objects can only be accessed from the main thread"
)
- def get(self, *keyargs):
- if len(keyargs) != self.keylen:
- raise ValueError("Expected a key to have %d items", self.keylen)
-
- if keyargs in self.cache:
+ def get(self, key, default=_CacheSentinel):
+ val = self.cache.get(key, _CacheSentinel)
+ if val is not _CacheSentinel:
cache_counter.inc_hits(self.name)
- return self.cache[keyargs]
+ return val
cache_counter.inc_misses(self.name)
- raise KeyError()
- def update(self, sequence, *args):
+ if default is _CacheSentinel:
+ raise KeyError()
+ else:
+ return default
+
+ def update(self, sequence, key, value):
self.check_thread()
if self.sequence == sequence:
            # Only update the cache if the cache's sequence number matches the
# number that the cache had before the SELECT was started (SYN-369)
- self.prefill(*args)
-
- def prefill(self, *args): # because I can't *keyargs, value
- keyargs = args[:-1]
- value = args[-1]
-
- if len(keyargs) != self.keylen:
- raise ValueError("Expected a key to have %d items", self.keylen)
+ self.prefill(key, value)
+ def prefill(self, key, value):
if self.max_entries is not None:
while len(self.cache) >= self.max_entries:
self.cache.popitem(last=False)
- self.cache[keyargs] = value
+ self.cache[key] = value
- def invalidate(self, *keyargs):
+ def invalidate(self, key):
self.check_thread()
- if len(keyargs) != self.keylen:
- raise ValueError("Expected a key to have %d items", self.keylen)
+ if not isinstance(key, tuple):
+ raise TypeError(
+ "The cache key must be a tuple not %r" % (type(key),)
+ )
+
# Increment the sequence number so that any SELECT statements that
# raced with the INSERT don't update the cache (SYN-369)
self.sequence += 1
- self.cache.pop(keyargs, None)
+ self.cache.pop(key, None)
def invalidate_all(self):
self.check_thread()
@@ -127,9 +131,12 @@ class Cache(object):
self.cache.clear()
-def cached(max_entries=1000, num_args=1, lru=False):
+class CacheDescriptor(object):
""" A method decorator that applies a memoizing cache around the function.
+ This caches deferreds, rather than the results themselves. Deferreds that
+ fail are removed from the cache.
+
The function is presumed to take zero or more arguments, which are used in
a tuple as the key for the cache. Hits are served directly from the cache;
misses use the function body to generate the value.
@@ -141,47 +148,108 @@ def cached(max_entries=1000, num_args=1, lru=False):
which can be used to insert values into the cache specifically, without
calling the calculation function.
"""
- def wrap(orig):
- cache = Cache(
- name=orig.__name__,
- max_entries=max_entries,
- keylen=num_args,
- lru=lru,
+ def __init__(self, orig, max_entries=1000, num_args=1, lru=True,
+ inlineCallbacks=False):
+ self.orig = orig
+
+ if inlineCallbacks:
+ self.function_to_call = defer.inlineCallbacks(orig)
+ else:
+ self.function_to_call = orig
+
+ self.max_entries = max_entries
+ self.num_args = num_args
+ self.lru = lru
+
+ self.arg_names = inspect.getargspec(orig).args[1:num_args+1]
+
+ if len(self.arg_names) < self.num_args:
+ raise Exception(
+ "Not enough explicit positional arguments to key off of for %r."
+                " (@cached cannot key off of *args or **kwargs)"
+ % (orig.__name__,)
+ )
+
+ self.cache = Cache(
+ name=self.orig.__name__,
+ max_entries=self.max_entries,
+ keylen=self.num_args,
+ lru=self.lru,
)
- @functools.wraps(orig)
- @defer.inlineCallbacks
- def wrapped(self, *keyargs):
+ def __get__(self, obj, objtype=None):
+
+ @functools.wraps(self.orig)
+ def wrapped(*args, **kwargs):
+ arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
+ cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names)
try:
- cached_result = cache.get(*keyargs)
+ cached_result_d = self.cache.get(cache_key)
+
+ observer = cached_result_d.observe()
if DEBUG_CACHES:
- actual_result = yield orig(self, *keyargs)
- if actual_result != cached_result:
- logger.error(
- "Stale cache entry %s%r: cached: %r, actual %r",
- orig.__name__, keyargs,
- cached_result, actual_result,
- )
- raise ValueError("Stale cache entry")
- defer.returnValue(cached_result)
+ @defer.inlineCallbacks
+ def check_result(cached_result):
+ actual_result = yield self.function_to_call(obj, *args, **kwargs)
+ if actual_result != cached_result:
+ logger.error(
+ "Stale cache entry %s%r: cached: %r, actual %r",
+ self.orig.__name__, cache_key,
+ cached_result, actual_result,
+ )
+ raise ValueError("Stale cache entry")
+ defer.returnValue(cached_result)
+ observer.addCallback(check_result)
+
+ return observer
except KeyError:
# Get the sequence number of the cache before reading from the
# database so that we can tell if the cache is invalidated
# while the SELECT is executing (SYN-369)
- sequence = cache.sequence
+ sequence = self.cache.sequence
+
+ ret = defer.maybeDeferred(
+ self.function_to_call,
+ obj, *args, **kwargs
+ )
+
+ def onErr(f):
+ self.cache.invalidate(cache_key)
+ return f
+
+ ret.addErrback(onErr)
+
+ ret = ObservableDeferred(ret, consumeErrors=True)
+ self.cache.update(sequence, cache_key, ret)
- ret = yield orig(self, *keyargs)
+ return ret.observe()
- cache.update(sequence, *keyargs + (ret,))
+ wrapped.invalidate = self.cache.invalidate
+ wrapped.invalidate_all = self.cache.invalidate_all
+ wrapped.prefill = self.cache.prefill
- defer.returnValue(ret)
+ obj.__dict__[self.orig.__name__] = wrapped
- wrapped.invalidate = cache.invalidate
- wrapped.invalidate_all = cache.invalidate_all
- wrapped.prefill = cache.prefill
return wrapped
- return wrap
+
+def cached(max_entries=1000, num_args=1, lru=True):
+ return lambda orig: CacheDescriptor(
+ orig,
+ max_entries=max_entries,
+ num_args=num_args,
+ lru=lru
+ )
+
+
+def cachedInlineCallbacks(max_entries=1000, num_args=1, lru=False):
+ return lambda orig: CacheDescriptor(
+ orig,
+ max_entries=max_entries,
+ num_args=num_args,
+ lru=lru,
+ inlineCallbacks=True,
+ )
class LoggingTransaction(object):
@@ -312,13 +380,14 @@ class SQLBaseStore(object):
self.database_engine = hs.database_engine
- self._stream_id_gen = StreamIdGenerator()
+ self._stream_id_gen = StreamIdGenerator("events", "stream_ordering")
self._transaction_id_gen = IdGenerator("sent_transactions", "id", self)
self._state_groups_id_gen = IdGenerator("state_groups", "id", self)
self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self)
self._pushers_id_gen = IdGenerator("pushers", "id", self)
self._push_rule_id_gen = IdGenerator("push_rules", "id", self)
self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self)
+ self._receipts_id_gen = StreamIdGenerator("receipts_linearized", "stream_id")
def start_profiling(self):
self._previous_loop_ts = self._clock.time_msec()
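
Cache keys are now explicit tuples end to end: ``@cached`` builds the key
from the decorated method's named arguments, stores an ``ObservableDeferred``
so concurrent callers share a single lookup, and ``invalidate``/``prefill``
take the same tuple. A usage sketch, assuming the decorator above is
importable::

    from twisted.internet import defer

    class AliasStore(object):
        @cached(num_args=1)
        def get_aliases_for_room(self, room_id):
            # Only runs on a miss; the deferred itself is what gets cached.
            return defer.succeed(["#room:example.org"])

    store = AliasStore()
    d1 = store.get_aliases_for_room("!abc:example.org")  # miss
    d2 = store.get_aliases_for_room("!abc:example.org")  # hit, shared result
    store.get_aliases_for_room.invalidate(("!abc:example.org",))
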
diff --git a/synapse/storage/directory.py b/synapse/storage/directory.py
index 2b2bdf8615..f3947bbe89 100644
--- a/synapse/storage/directory.py
+++ b/synapse/storage/directory.py
@@ -104,7 +104,7 @@ class DirectoryStore(SQLBaseStore):
},
desc="create_room_alias_association",
)
- self.get_aliases_for_room.invalidate(room_id)
+ self.get_aliases_for_room.invalidate((room_id,))
@defer.inlineCallbacks
def delete_room_alias(self, room_alias):
@@ -114,7 +114,7 @@ class DirectoryStore(SQLBaseStore):
room_alias,
)
- self.get_aliases_for_room.invalidate(room_id)
+ self.get_aliases_for_room.invalidate((room_id,))
defer.returnValue(room_id)
def _delete_room_alias_txn(self, txn, room_alias):
diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py
new file mode 100644
index 0000000000..325740d7d0
--- /dev/null
+++ b/synapse/storage/end_to_end_keys.py
@@ -0,0 +1,125 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from _base import SQLBaseStore
+
+
+class EndToEndKeyStore(SQLBaseStore):
+ def set_e2e_device_keys(self, user_id, device_id, time_now, json_bytes):
+ return self._simple_upsert(
+ table="e2e_device_keys_json",
+ keyvalues={
+ "user_id": user_id,
+ "device_id": device_id,
+ },
+ values={
+ "ts_added_ms": time_now,
+ "key_json": json_bytes,
+ }
+ )
+
+ def get_e2e_device_keys(self, query_list):
+ """Fetch a list of device keys.
+ Args:
+ query_list(list): List of pairs of user_ids and device_ids.
+ Returns:
+            Dict mapping from user_id to a dict mapping from device_id to
+            key JSON byte strings.
+ """
+ def _get_e2e_device_keys(txn):
+ result = {}
+ for user_id, device_id in query_list:
+ user_result = result.setdefault(user_id, {})
+ keyvalues = {"user_id": user_id}
+ if device_id:
+ keyvalues["device_id"] = device_id
+ rows = self._simple_select_list_txn(
+ txn, table="e2e_device_keys_json",
+ keyvalues=keyvalues,
+ retcols=["device_id", "key_json"]
+ )
+ for row in rows:
+ user_result[row["device_id"]] = row["key_json"]
+ return result
+ return self.runInteraction("get_e2e_device_keys", _get_e2e_device_keys)
+
+ def add_e2e_one_time_keys(self, user_id, device_id, time_now, key_list):
+ def _add_e2e_one_time_keys(txn):
+ for (algorithm, key_id, json_bytes) in key_list:
+ self._simple_upsert_txn(
+ txn, table="e2e_one_time_keys_json",
+ keyvalues={
+ "user_id": user_id,
+ "device_id": device_id,
+ "algorithm": algorithm,
+ "key_id": key_id,
+ },
+ values={
+ "ts_added_ms": time_now,
+ "key_json": json_bytes,
+ }
+ )
+ return self.runInteraction(
+ "add_e2e_one_time_keys", _add_e2e_one_time_keys
+ )
+
+ def count_e2e_one_time_keys(self, user_id, device_id):
+ """ Count the number of one time keys the server has for a device
+ Returns:
+ Dict mapping from algorithm to number of keys for that algorithm.
+ """
+ def _count_e2e_one_time_keys(txn):
+ sql = (
+ "SELECT algorithm, COUNT(key_id) FROM e2e_one_time_keys_json"
+ " WHERE user_id = ? AND device_id = ?"
+ " GROUP BY algorithm"
+ )
+ txn.execute(sql, (user_id, device_id))
+ result = {}
+ for algorithm, key_count in txn.fetchall():
+ result[algorithm] = key_count
+ return result
+ return self.runInteraction(
+ "count_e2e_one_time_keys", _count_e2e_one_time_keys
+ )
+
+ def claim_e2e_one_time_keys(self, query_list):
+ """Take a list of one time keys out of the database"""
+ def _claim_e2e_one_time_keys(txn):
+ sql = (
+ "SELECT key_id, key_json FROM e2e_one_time_keys_json"
+ " WHERE user_id = ? AND device_id = ? AND algorithm = ?"
+ " LIMIT 1"
+ )
+ result = {}
+ delete = []
+ for user_id, device_id, algorithm in query_list:
+ user_result = result.setdefault(user_id, {})
+ device_result = user_result.setdefault(device_id, {})
+ txn.execute(sql, (user_id, device_id, algorithm))
+ for key_id, key_json in txn.fetchall():
+ device_result[algorithm + ":" + key_id] = key_json
+ delete.append((user_id, device_id, algorithm, key_id))
+ sql = (
+ "DELETE FROM e2e_one_time_keys_json"
+ " WHERE user_id = ? AND device_id = ? AND algorithm = ?"
+ " AND key_id = ?"
+ )
+ for user_id, device_id, algorithm, key_id in delete:
+ txn.execute(sql, (user_id, device_id, algorithm, key_id))
+ return result
+ return self.runInteraction(
+ "claim_e2e_one_time_keys", _claim_e2e_one_time_keys
+ )
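
``claim_e2e_one_time_keys`` pops at most one key per
``(user_id, device_id, algorithm)`` triple, returning it keyed as
``algorithm:key_id`` and deleting the row so the same key can never be
claimed twice. For example, a query list of
``[("@alice:hs", "DEV1", "curve25519")]`` might yield::

    {
        "@alice:hs": {
            "DEV1": {"curve25519:AAAAAQ": "<key json>"}
        }
    }
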
diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py
index 1ba073884b..910b6598a7 100644
--- a/synapse/storage/event_federation.py
+++ b/synapse/storage/event_federation.py
@@ -49,14 +49,22 @@ class EventFederationStore(SQLBaseStore):
results = set()
base_sql = (
- "SELECT auth_id FROM event_auth WHERE event_id = ?"
+ "SELECT auth_id FROM event_auth WHERE event_id IN (%s)"
)
front = set(event_ids)
while front:
new_front = set()
- for f in front:
- txn.execute(base_sql, (f,))
+ front_list = list(front)
+ chunks = [
+ front_list[x:x+100]
+                for x in xrange(0, len(front_list), 100)
+ ]
+ for chunk in chunks:
+ txn.execute(
+ base_sql % (",".join(["?"] * len(chunk)),),
+ chunk
+ )
new_front.update([r[0] for r in txn.fetchall()])
new_front -= results
@@ -274,8 +282,7 @@ class EventFederationStore(SQLBaseStore):
},
)
- def _handle_prev_events(self, txn, outlier, event_id, prev_events,
- room_id):
+ def _handle_mult_prev_events(self, txn, events):
"""
For the given event, update the event edges table and forward and
backward extremities tables.
@@ -285,70 +292,77 @@ class EventFederationStore(SQLBaseStore):
table="event_edges",
values=[
{
- "event_id": event_id,
+ "event_id": ev.event_id,
"prev_event_id": e_id,
- "room_id": room_id,
+ "room_id": ev.room_id,
"is_state": False,
}
- for e_id, _ in prev_events
+ for ev in events
+ for e_id, _ in ev.prev_events
],
)
- # Update the extremities table if this is not an outlier.
- if not outlier:
- for e_id, _ in prev_events:
- # TODO (erikj): This could be done as a bulk insert
- self._simple_delete_txn(
- txn,
- table="event_forward_extremities",
- keyvalues={
- "event_id": e_id,
- "room_id": room_id,
- }
- )
+ events_by_room = {}
+ for ev in events:
+ events_by_room.setdefault(ev.room_id, []).append(ev)
- # We only insert as a forward extremity the new event if there are
- # no other events that reference it as a prev event
- query = (
- "SELECT 1 FROM event_edges WHERE prev_event_id = ?"
- )
+ for room_id, room_events in events_by_room.items():
+ prevs = [
+ e_id for ev in room_events for e_id, _ in ev.prev_events
+ if not ev.internal_metadata.is_outlier()
+ ]
+ if prevs:
+ txn.execute(
+ "DELETE FROM event_forward_extremities"
+ " WHERE room_id = ?"
+ " AND event_id in (%s)" % (
+ ",".join(["?"] * len(prevs)),
+ ),
+ [room_id] + prevs,
+ )
- txn.execute(query, (event_id,))
+ query = (
+ "INSERT INTO event_forward_extremities (event_id, room_id)"
+ " SELECT ?, ? WHERE NOT EXISTS ("
+ " SELECT 1 FROM event_edges WHERE prev_event_id = ?"
+ " )"
+ )
- if not txn.fetchone():
- query = (
- "INSERT INTO event_forward_extremities"
- " (event_id, room_id)"
- " VALUES (?, ?)"
- )
+ txn.executemany(
+ query,
+ [(ev.event_id, ev.room_id, ev.event_id) for ev in events]
+ )
- txn.execute(query, (event_id, room_id))
-
- query = (
- "INSERT INTO event_backward_extremities (event_id, room_id)"
- " SELECT ?, ? WHERE NOT EXISTS ("
- " SELECT 1 FROM event_backward_extremities"
- " WHERE event_id = ? AND room_id = ?"
- " )"
- " AND NOT EXISTS ("
- " SELECT 1 FROM events WHERE event_id = ? AND room_id = ? "
- " AND outlier = ?"
- " )"
- )
+ query = (
+ "INSERT INTO event_backward_extremities (event_id, room_id)"
+ " SELECT ?, ? WHERE NOT EXISTS ("
+ " SELECT 1 FROM event_backward_extremities"
+ " WHERE event_id = ? AND room_id = ?"
+ " )"
+ " AND NOT EXISTS ("
+ " SELECT 1 FROM events WHERE event_id = ? AND room_id = ? "
+ " AND outlier = ?"
+ " )"
+ )
- txn.executemany(query, [
- (e_id, room_id, e_id, room_id, e_id, room_id, False)
- for e_id, _ in prev_events
- ])
+ txn.executemany(query, [
+ (e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id, False)
+ for ev in events for e_id, _ in ev.prev_events
+ if not ev.internal_metadata.is_outlier()
+ ])
- query = (
- "DELETE FROM event_backward_extremities"
- " WHERE event_id = ? AND room_id = ?"
- )
- txn.execute(query, (event_id, room_id))
+ query = (
+ "DELETE FROM event_backward_extremities"
+ " WHERE event_id = ? AND room_id = ?"
+ )
+ txn.executemany(
+ query,
+ [(ev.event_id, ev.room_id) for ev in events]
+ )
+ for room_id in events_by_room:
txn.call_after(
- self.get_latest_event_ids_in_room.invalidate, room_id
+ self.get_latest_event_ids_in_room.invalidate, (room_id,)
)
def get_backfill_events(self, room_id, event_list, limit):
@@ -400,10 +414,12 @@ class EventFederationStore(SQLBaseStore):
keyvalues={
"event_id": event_id,
},
- retcol="depth"
+ retcol="depth",
+ allow_none=True,
)
- queue.put((-depth, event_id))
+ if depth:
+ queue.put((-depth, event_id))
while not queue.empty() and len(event_results) < limit:
try:
@@ -489,4 +505,4 @@ class EventFederationStore(SQLBaseStore):
query = "DELETE FROM event_forward_extremities WHERE room_id = ?"
txn.execute(query, (room_id,))
- txn.call_after(self.get_latest_event_ids_in_room.invalidate, room_id)
+ txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,))
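
Walking the auth chain previously issued one query per event; it now expands
an ``IN (%s)`` template with up to 100 placeholders per round trip. A
generic sketch of that batching::

    def fetch_in_chunks(txn, sql_template, ids, chunk_size=100):
        # sql_template contains a single "%s" that is replaced by one "?"
        # placeholder per id in the current chunk.
        rows = []
        for i in range(0, len(ids), chunk_size):
            chunk = ids[i:i + chunk_size]
            txn.execute(
                sql_template % (",".join(["?"] * len(chunk)),),
                chunk,
            )
            rows.extend(txn.fetchall())
        return rows
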
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 20a8d81794..5b64918024 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -23,9 +23,7 @@ from synapse.events.utils import prune_event
from synapse.util.logcontext import preserve_context_over_deferred
from synapse.util.logutils import log_function
from synapse.api.constants import EventTypes
-from synapse.crypto.event_signing import compute_event_reference_hash
-from syutil.base64util import decode_base64
from syutil.jsonutil import encode_json
from contextlib import contextmanager
@@ -47,6 +45,48 @@ EVENT_QUEUE_TIMEOUT_S = 0.1 # Timeout when waiting for requests for events
class EventsStore(SQLBaseStore):
@defer.inlineCallbacks
+ def persist_events(self, events_and_contexts, backfilled=False,
+ is_new_state=True):
+ if not events_and_contexts:
+ return
+
+ if backfilled:
+ if not self.min_token_deferred.called:
+ yield self.min_token_deferred
+ start = self.min_token - 1
+ self.min_token -= len(events_and_contexts) + 1
+ stream_orderings = range(start, self.min_token, -1)
+
+ @contextmanager
+ def stream_ordering_manager():
+ yield stream_orderings
+ stream_ordering_manager = stream_ordering_manager()
+ else:
+ stream_ordering_manager = yield self._stream_id_gen.get_next_mult(
+ self, len(events_and_contexts)
+ )
+
+ with stream_ordering_manager as stream_orderings:
+ for (event, _), stream in zip(events_and_contexts, stream_orderings):
+ event.internal_metadata.stream_ordering = stream
+
+ chunks = [
+ events_and_contexts[x:x+100]
+ for x in xrange(0, len(events_and_contexts), 100)
+ ]
+
+ for chunk in chunks:
+ # We can't easily parallelize these since different chunks
+ # might contain the same event. :(
+ yield self.runInteraction(
+ "persist_events",
+ self._persist_events_txn,
+ events_and_contexts=chunk,
+ backfilled=backfilled,
+ is_new_state=is_new_state,
+ )
+
+ @defer.inlineCallbacks
@log_function
def persist_event(self, event, context, backfilled=False,
is_new_state=True, current_state=None):
@@ -67,13 +107,13 @@ class EventsStore(SQLBaseStore):
try:
with stream_ordering_manager as stream_ordering:
+ event.internal_metadata.stream_ordering = stream_ordering
yield self.runInteraction(
"persist_event",
self._persist_event_txn,
event=event,
context=context,
backfilled=backfilled,
- stream_ordering=stream_ordering,
is_new_state=is_new_state,
current_state=current_state,
)
@@ -116,19 +156,14 @@ class EventsStore(SQLBaseStore):
@log_function
def _persist_event_txn(self, txn, event, context, backfilled,
- stream_ordering=None, is_new_state=True,
- current_state=None):
-
- # Remove the any existing cache entries for the event_id
- txn.call_after(self._invalidate_get_event_cache, event.event_id)
-
+ is_new_state=True, current_state=None):
# We purposefully do this first since if we include a `current_state`
# key, we *want* to update the `current_state_events` table
if current_state:
txn.call_after(self.get_current_state_for_key.invalidate_all)
txn.call_after(self.get_rooms_for_user.invalidate_all)
- txn.call_after(self.get_users_in_room.invalidate, event.room_id)
- txn.call_after(self.get_joined_hosts_for_room.invalidate, event.room_id)
+ txn.call_after(self.get_users_in_room.invalidate, (event.room_id,))
+ txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,))
            txn.call_after(
                self.get_room_name_and_aliases.invalidate, (event.room_id,)
            )
self._simple_delete_txn(
@@ -149,37 +184,78 @@ class EventsStore(SQLBaseStore):
}
)
- outlier = event.internal_metadata.is_outlier()
+ return self._persist_events_txn(
+ txn,
+ [(event, context)],
+ backfilled=backfilled,
+ is_new_state=is_new_state,
+ )
- if not outlier:
- self._update_min_depth_for_room_txn(
- txn,
- event.room_id,
- event.depth
+ @log_function
+ def _persist_events_txn(self, txn, events_and_contexts, backfilled,
+ is_new_state=True):
+
+        # Remove any existing cache entries for the event_ids
+ for event, _ in events_and_contexts:
+ txn.call_after(self._invalidate_get_event_cache, event.event_id)
+
+ depth_updates = {}
+ for event, _ in events_and_contexts:
+ if event.internal_metadata.is_outlier():
+ continue
+ depth_updates[event.room_id] = max(
+ event.depth, depth_updates.get(event.room_id, event.depth)
)
- have_persisted = self._simple_select_one_txn(
- txn,
- table="events",
- keyvalues={"event_id": event.event_id},
- retcols=["event_id", "outlier"],
- allow_none=True,
+ for room_id, depth in depth_updates.items():
+ self._update_min_depth_for_room_txn(txn, room_id, depth)
+
+ txn.execute(
+ "SELECT event_id, outlier FROM events WHERE event_id in (%s)" % (
+ ",".join(["?"] * len(events_and_contexts)),
+ ),
+ [event.event_id for event, _ in events_and_contexts]
)
+ have_persisted = {
+ event_id: outlier
+ for event_id, outlier in txn.fetchall()
+ }
+
+ event_map = {}
+ to_remove = set()
+ for event, context in events_and_contexts:
+ # Handle the case of the list including the same event multiple
+            # times. The tricky case is when they differ in whether they
+            # are an outlier.
+ if event.event_id in event_map:
+ other = event_map[event.event_id]
+
+ if not other.internal_metadata.is_outlier():
+ to_remove.add(event)
+ continue
+ elif not event.internal_metadata.is_outlier():
+ to_remove.add(event)
+ continue
+ else:
+ to_remove.add(other)
- metadata_json = encode_json(
- event.internal_metadata.get_dict(),
- using_frozen_dicts=USE_FROZEN_DICTS
- ).decode("UTF-8")
-
- # If we have already persisted this event, we don't need to do any
- # more processing.
- # The processing above must be done on every call to persist event,
- # since they might not have happened on previous calls. For example,
- # if we are persisting an event that we had persisted as an outlier,
- # but is no longer one.
- if have_persisted:
- if not outlier and have_persisted["outlier"]:
- self._store_state_groups_txn(txn, event, context)
+ event_map[event.event_id] = event
+
+ if event.event_id not in have_persisted:
+ continue
+
+ to_remove.add(event)
+
+ outlier_persisted = have_persisted[event.event_id]
+ if not event.internal_metadata.is_outlier() and outlier_persisted:
+ self._store_state_groups_txn(
+ txn, event, context,
+ )
+
+ metadata_json = encode_json(
+ event.internal_metadata.get_dict(),
+ using_frozen_dicts=USE_FROZEN_DICTS
+ ).decode("UTF-8")
sql = (
"UPDATE event_json SET internal_metadata = ?"
@@ -198,94 +274,91 @@ class EventsStore(SQLBaseStore):
sql,
(False, event.event_id,)
)
- return
-
- if not outlier:
- self._store_state_groups_txn(txn, event, context)
- self._handle_prev_events(
- txn,
- outlier=outlier,
- event_id=event.event_id,
- prev_events=event.prev_events,
- room_id=event.room_id,
+ events_and_contexts = filter(
+ lambda ec: ec[0] not in to_remove,
+ events_and_contexts
)
- if event.type == EventTypes.Member:
- self._store_room_member_txn(txn, event)
- elif event.type == EventTypes.Name:
- self._store_room_name_txn(txn, event)
- elif event.type == EventTypes.Topic:
- self._store_room_topic_txn(txn, event)
- elif event.type == EventTypes.Redaction:
- self._store_redaction(txn, event)
-
- event_dict = {
- k: v
- for k, v in event.get_dict().items()
- if k not in [
- "redacted",
- "redacted_because",
- ]
- }
+ if not events_and_contexts:
+ return
- self._simple_insert_txn(
+ self._store_mult_state_groups_txn(txn, [
+ (event, context)
+ for event, context in events_and_contexts
+ if not event.internal_metadata.is_outlier()
+ ])
+
+ self._handle_mult_prev_events(
txn,
- table="event_json",
- values={
- "event_id": event.event_id,
- "room_id": event.room_id,
- "internal_metadata": metadata_json,
- "json": encode_json(
- event_dict, using_frozen_dicts=USE_FROZEN_DICTS
- ).decode("UTF-8"),
- },
+ events=[event for event, _ in events_and_contexts],
)
- content = encode_json(
- event.content, using_frozen_dicts=USE_FROZEN_DICTS
- ).decode("UTF-8")
-
- vals = {
- "topological_ordering": event.depth,
- "event_id": event.event_id,
- "type": event.type,
- "room_id": event.room_id,
- "content": content,
- "processed": True,
- "outlier": outlier,
- "depth": event.depth,
- }
+ for event, _ in events_and_contexts:
+ if event.type == EventTypes.Name:
+ self._store_room_name_txn(txn, event)
+ elif event.type == EventTypes.Topic:
+ self._store_room_topic_txn(txn, event)
+ elif event.type == EventTypes.Redaction:
+ self._store_redaction(txn, event)
- unrec = {
- k: v
- for k, v in event.get_dict().items()
- if k not in vals.keys() and k not in [
- "redacted",
- "redacted_because",
- "signatures",
- "hashes",
- "prev_events",
+ self._store_room_members_txn(
+ txn,
+ [
+ event
+ for event, _ in events_and_contexts
+ if event.type == EventTypes.Member
]
- }
+ )
- vals["unrecognized_keys"] = encode_json(
- unrec, using_frozen_dicts=USE_FROZEN_DICTS
- ).decode("UTF-8")
+ def event_dict(event):
+ return {
+ k: v
+ for k, v in event.get_dict().items()
+ if k not in [
+ "redacted",
+ "redacted_because",
+ ]
+ }
- sql = (
- "INSERT INTO events"
- " (stream_ordering, topological_ordering, event_id, type,"
- " room_id, content, processed, outlier, depth)"
- " VALUES (?,?,?,?,?,?,?,?,?)"
+ self._simple_insert_many_txn(
+ txn,
+ table="event_json",
+ values=[
+ {
+ "event_id": event.event_id,
+ "room_id": event.room_id,
+ "internal_metadata": encode_json(
+ event.internal_metadata.get_dict(),
+ using_frozen_dicts=USE_FROZEN_DICTS
+ ).decode("UTF-8"),
+ "json": encode_json(
+ event_dict(event), using_frozen_dicts=USE_FROZEN_DICTS
+ ).decode("UTF-8"),
+ }
+ for event, _ in events_and_contexts
+ ],
)
- txn.execute(
- sql,
- (
- stream_ordering, event.depth, event.event_id, event.type,
- event.room_id, content, True, outlier, event.depth
- )
+ self._simple_insert_many_txn(
+ txn,
+ table="events",
+ values=[
+ {
+ "stream_ordering": event.internal_metadata.stream_ordering,
+ "topological_ordering": event.depth,
+ "depth": event.depth,
+ "event_id": event.event_id,
+ "room_id": event.room_id,
+ "type": event.type,
+ "processed": True,
+ "outlier": event.internal_metadata.is_outlier(),
+ "content": encode_json(
+ event.content, using_frozen_dicts=USE_FROZEN_DICTS
+ ).decode("UTF-8"),
+ }
+ for event, _ in events_and_contexts
+ ],
)
if context.rejected:
@@ -293,20 +366,6 @@ class EventsStore(SQLBaseStore):
txn, event.event_id, context.rejected
)
- for hash_alg, hash_base64 in event.hashes.items():
- hash_bytes = decode_base64(hash_base64)
- self._store_event_content_hash_txn(
- txn, event.event_id, hash_alg, hash_bytes,
- )
-
- for prev_event_id, prev_hashes in event.prev_events:
- for alg, hash_base64 in prev_hashes.items():
- hash_bytes = decode_base64(hash_base64)
- self._store_prev_event_hash_txn(
- txn, event.event_id, prev_event_id, alg,
- hash_bytes
- )
-
self._simple_insert_many_txn(
txn,
table="event_auth",
@@ -316,16 +375,22 @@ class EventsStore(SQLBaseStore):
"room_id": event.room_id,
"auth_id": auth_id,
}
+ for event, _ in events_and_contexts
for auth_id, _ in event.auth_events
],
)
- (ref_alg, ref_hash_bytes) = compute_event_reference_hash(event)
- self._store_event_reference_hash_txn(
- txn, event.event_id, ref_alg, ref_hash_bytes
+ self._store_event_reference_hashes_txn(
+ txn, [event for event, _ in events_and_contexts]
)
- if event.is_state():
+ state_events_and_contexts = filter(
+ lambda i: i[0].is_state(),
+ events_and_contexts,
+ )
+
+ state_values = []
+ for event, context in state_events_and_contexts:
vals = {
"event_id": event.event_id,
"room_id": event.room_id,
@@ -337,51 +402,55 @@ class EventsStore(SQLBaseStore):
if hasattr(event, "replaces_state"):
vals["prev_state"] = event.replaces_state
- self._simple_insert_txn(
- txn,
- "state_events",
- vals,
- )
+ state_values.append(vals)
- self._simple_insert_many_txn(
- txn,
- table="event_edges",
- values=[
- {
- "event_id": event.event_id,
- "prev_event_id": e_id,
- "room_id": event.room_id,
- "is_state": True,
- }
- for e_id, h in event.prev_state
- ],
- )
+ self._simple_insert_many_txn(
+ txn,
+ table="state_events",
+ values=state_values,
+ )
- if is_new_state and not context.rejected:
- txn.call_after(
- self.get_current_state_for_key.invalidate,
- event.room_id, event.type, event.state_key
- )
+ self._simple_insert_many_txn(
+ txn,
+ table="event_edges",
+ values=[
+ {
+ "event_id": event.event_id,
+ "prev_event_id": prev_id,
+ "room_id": event.room_id,
+ "is_state": True,
+ }
+ for event, _ in state_events_and_contexts
+ for prev_id, _ in event.prev_state
+ ],
+ )
- if (event.type == EventTypes.Name
- or event.type == EventTypes.Aliases):
+ if is_new_state:
+            for event, context in state_events_and_contexts:
+ if not context.rejected:
txn.call_after(
- self.get_room_name_and_aliases.invalidate,
- event.room_id
+ self.get_current_state_for_key.invalidate,
+ (event.room_id, event.type, event.state_key,)
)
- self._simple_upsert_txn(
- txn,
- "current_state_events",
- keyvalues={
- "room_id": event.room_id,
- "type": event.type,
- "state_key": event.state_key,
- },
- values={
- "event_id": event.event_id,
- }
- )
+ if event.type in [EventTypes.Name, EventTypes.Aliases]:
+ txn.call_after(
+ self.get_room_name_and_aliases.invalidate,
+ (event.room_id,)
+ )
+
+ self._simple_upsert_txn(
+ txn,
+ "current_state_events",
+ keyvalues={
+ "room_id": event.room_id,
+ "type": event.type,
+ "state_key": event.state_key,
+ },
+ values={
+ "event_id": event.event_id,
+ }
+ )
return
@@ -498,8 +567,9 @@ class EventsStore(SQLBaseStore):
def _invalidate_get_event_cache(self, event_id):
for check_redacted in (False, True):
for get_prev_content in (False, True):
- self._get_event_cache.invalidate(event_id, check_redacted,
- get_prev_content)
+ self._get_event_cache.invalidate(
+ (event_id, check_redacted, get_prev_content)
+ )
def _get_event_txn(self, txn, event_id, check_redacted=True,
get_prev_content=False, allow_rejected=False):
@@ -520,7 +590,7 @@ class EventsStore(SQLBaseStore):
for event_id in events:
try:
ret = self._get_event_cache.get(
- event_id, check_redacted, get_prev_content
+ (event_id, check_redacted, get_prev_content,)
)
if allow_rejected or not ret.rejected_reason:
@@ -736,7 +806,8 @@ class EventsStore(SQLBaseStore):
because = yield self.get_event(
redaction_id,
- check_redacted=False
+ check_redacted=False,
+ allow_none=True,
)
if because:
@@ -746,12 +817,13 @@ class EventsStore(SQLBaseStore):
prev = yield self.get_event(
ev.unsigned["replaces_state"],
get_prev_content=False,
+ allow_none=True,
)
if prev:
ev.unsigned["prev_content"] = prev.get_dict()["content"]
self._get_event_cache.prefill(
- ev.event_id, check_redacted, get_prev_content, ev
+ (ev.event_id, check_redacted, get_prev_content), ev
)
defer.returnValue(ev)
@@ -808,7 +880,7 @@ class EventsStore(SQLBaseStore):
ev.unsigned["prev_content"] = prev.get_dict()["content"]
self._get_event_cache.prefill(
- ev.event_id, check_redacted, get_prev_content, ev
+ (ev.event_id, check_redacted, get_prev_content), ev
)
return ev
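
``persist_events`` hands backfilled events stream orderings counting *down*
from the current minimum token, so they sort before everything already in
the room, while live events draw increasing ids from the stream generator.
A small illustration of the backfill arithmetic::

    def backfill_stream_orderings(min_token, n):
        # n backfilled events get orderings min_token-1 .. min_token-n,
        # and the minimum token moves down past them.
        start = min_token - 1
        new_min = min_token - n - 1
        return list(range(start, new_min, -1)), new_min

    orderings, new_min = backfill_stream_orderings(0, 3)
    # orderings == [-1, -2, -3], new_min == -4
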
diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py
index 5bdf497b93..49b8e37cfd 100644
--- a/synapse/storage/keys.py
+++ b/synapse/storage/keys.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from _base import SQLBaseStore
+from _base import SQLBaseStore, cachedInlineCallbacks
from twisted.internet import defer
@@ -71,6 +71,24 @@ class KeyStore(SQLBaseStore):
desc="store_server_certificate",
)
+ @cachedInlineCallbacks()
+ def get_all_server_verify_keys(self, server_name):
+ rows = yield self._simple_select_list(
+ table="server_signature_keys",
+ keyvalues={
+ "server_name": server_name,
+ },
+ retcols=["key_id", "verify_key"],
+ desc="get_all_server_verify_keys",
+ )
+
+ defer.returnValue({
+ row["key_id"]: decode_verify_key_bytes(
+ row["key_id"], str(row["verify_key"])
+ )
+ for row in rows
+ })
+
@defer.inlineCallbacks
def get_server_verify_keys(self, server_name, key_ids):
"""Retrieve the NACL verification key for a given server for the given
@@ -81,24 +99,14 @@ class KeyStore(SQLBaseStore):
Returns:
            (dict of key_id to VerifyKey): The verification keys.
"""
- sql = (
- "SELECT key_id, verify_key FROM server_signature_keys"
- " WHERE server_name = ?"
- " AND key_id in (" + ",".join("?" for key_id in key_ids) + ")"
- )
-
- rows = yield self._execute_and_decode(
- "get_server_verify_keys", sql, server_name, *key_ids
- )
-
- keys = []
- for row in rows:
- key_id = row["key_id"]
- key_bytes = row["verify_key"]
- key = decode_verify_key_bytes(key_id, str(key_bytes))
- keys.append(key)
- defer.returnValue(keys)
+ keys = yield self.get_all_server_verify_keys(server_name)
+ defer.returnValue({
+ k: keys[k]
+ for k in key_ids
+ if k in keys and keys[k]
+ })
+ @defer.inlineCallbacks
def store_server_verify_key(self, server_name, from_server, time_now_ms,
verify_key):
"""Stores a NACL verification key for the given server.
@@ -109,7 +117,7 @@ class KeyStore(SQLBaseStore):
ts_now_ms (int): The time now in milliseconds
verification_key (VerifyKey): The NACL verify key.
"""
- return self._simple_upsert(
+ yield self._simple_upsert(
table="server_signature_keys",
keyvalues={
"server_name": server_name,
@@ -123,6 +131,8 @@ class KeyStore(SQLBaseStore):
desc="store_server_verify_key",
)
+ self.get_all_server_verify_keys.invalidate((server_name,))
+
def store_server_keys_json(self, server_name, key_id, from_server,
ts_now_ms, ts_expires_ms, key_json_bytes):
"""Stores the JSON bytes for a set of keys from a server
@@ -152,6 +162,7 @@ class KeyStore(SQLBaseStore):
"ts_valid_until_ms": ts_expires_ms,
"key_json": buffer(key_json_bytes),
},
+ desc="store_server_keys_json",
)
def get_server_keys_json(self, server_keys):
diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py
index fefcf6bce0..576cf670cc 100644
--- a/synapse/storage/presence.py
+++ b/synapse/storage/presence.py
@@ -98,7 +98,7 @@ class PresenceStore(SQLBaseStore):
updatevalues={"accepted": True},
desc="set_presence_list_accepted",
)
- self.get_presence_list_accepted.invalidate(observer_localpart)
+ self.get_presence_list_accepted.invalidate((observer_localpart,))
defer.returnValue(result)
def get_presence_list(self, observer_localpart, accepted=None):
@@ -133,4 +133,4 @@ class PresenceStore(SQLBaseStore):
"observed_user_id": observed_userid},
desc="del_presence_list",
)
- self.get_presence_list_accepted.invalidate(observer_localpart)
+ self.get_presence_list_accepted.invalidate((observer_localpart,))
diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py
index 4cac118d17..9b88ca7b39 100644
--- a/synapse/storage/push_rule.py
+++ b/synapse/storage/push_rule.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import SQLBaseStore, cached
+from ._base import SQLBaseStore, cachedInlineCallbacks
from twisted.internet import defer
import logging
@@ -23,8 +23,7 @@ logger = logging.getLogger(__name__)
class PushRuleStore(SQLBaseStore):
- @cached()
- @defer.inlineCallbacks
+ @cachedInlineCallbacks()
def get_push_rules_for_user(self, user_name):
rows = yield self._simple_select_list(
table=PushRuleTable.table_name,
@@ -41,8 +40,7 @@ class PushRuleStore(SQLBaseStore):
defer.returnValue(rows)
- @cached()
- @defer.inlineCallbacks
+ @cachedInlineCallbacks()
def get_push_rules_enabled_for_user(self, user_name):
results = yield self._simple_select_list(
table=PushRuleEnableTable.table_name,
@@ -153,11 +151,11 @@ class PushRuleStore(SQLBaseStore):
txn.execute(sql, (user_name, priority_class, new_rule_priority))
txn.call_after(
- self.get_push_rules_for_user.invalidate, user_name
+ self.get_push_rules_for_user.invalidate, (user_name,)
)
txn.call_after(
- self.get_push_rules_enabled_for_user.invalidate, user_name
+ self.get_push_rules_enabled_for_user.invalidate, (user_name,)
)
self._simple_insert_txn(
@@ -189,10 +187,10 @@ class PushRuleStore(SQLBaseStore):
new_rule['priority'] = new_prio
txn.call_after(
- self.get_push_rules_for_user.invalidate, user_name
+ self.get_push_rules_for_user.invalidate, (user_name,)
)
txn.call_after(
- self.get_push_rules_enabled_for_user.invalidate, user_name
+ self.get_push_rules_enabled_for_user.invalidate, (user_name,)
)
self._simple_insert_txn(
@@ -218,8 +216,8 @@ class PushRuleStore(SQLBaseStore):
desc="delete_push_rule",
)
- self.get_push_rules_for_user.invalidate(user_name)
- self.get_push_rules_enabled_for_user.invalidate(user_name)
+ self.get_push_rules_for_user.invalidate((user_name,))
+ self.get_push_rules_enabled_for_user.invalidate((user_name,))
@defer.inlineCallbacks
def set_push_rule_enabled(self, user_name, rule_id, enabled):
@@ -240,10 +238,10 @@ class PushRuleStore(SQLBaseStore):
{'id': new_id},
)
txn.call_after(
- self.get_push_rules_for_user.invalidate, user_name
+ self.get_push_rules_for_user.invalidate, (user_name,)
)
txn.call_after(
- self.get_push_rules_enabled_for_user.invalidate, user_name
+ self.get_push_rules_enabled_for_user.invalidate, (user_name,)
)
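
``cachedInlineCallbacks`` exists because stacking ``@cached`` on top of
``@defer.inlineCallbacks`` made the cache wrap an already-wrapped function;
the combined decorator applies ``inlineCallbacks`` inside the cache
machinery instead. A usage sketch, assuming the helper is importable from
``synapse.storage._base``::

    from twisted.internet import defer

    class RuleStore(object):
        @cachedInlineCallbacks()
        def get_rules(self, user_name):
            # The generator is wrapped by inlineCallbacks internally, and
            # the resulting deferred is cached under the (user_name,) key.
            rows = yield defer.succeed([{"rule_id": "%s.default" % user_name}])
            defer.returnValue(rows)

    store = RuleStore()
    d = store.get_rules("@alice:example.org")
    store.get_rules.invalidate(("@alice:example.org",))
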
diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py
new file mode 100644
index 0000000000..b79d6683ca
--- /dev/null
+++ b/synapse/storage/receipts.py
@@ -0,0 +1,347 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014, 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ._base import SQLBaseStore, cachedInlineCallbacks
+
+from twisted.internet import defer
+
+from synapse.util import unwrapFirstError
+
+from blist import sorteddict
+import logging
+import ujson as json
+
+
+logger = logging.getLogger(__name__)
+
+
+class ReceiptsStore(SQLBaseStore):
+ def __init__(self, hs):
+ super(ReceiptsStore, self).__init__(hs)
+
+ self._receipts_stream_cache = _RoomStreamChangeCache()
+
+ @defer.inlineCallbacks
+ def get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
+ """Get receipts for multiple rooms for sending to clients.
+
+ Args:
+ room_ids (list): List of room_ids.
+            to_key (int): Max stream id to fetch receipts up to.
+ from_key (int): Min stream id to fetch receipts from. None fetches
+ from the start.
+
+ Returns:
+ list: A list of receipts.
+ """
+ room_ids = set(room_ids)
+
+ if from_key:
+ room_ids = yield self._receipts_stream_cache.get_rooms_changed(
+ self, room_ids, from_key
+ )
+
+ results = yield defer.gatherResults(
+ [
+ self.get_linearized_receipts_for_room(
+ room_id, to_key, from_key=from_key
+ )
+ for room_id in room_ids
+ ],
+ consumeErrors=True,
+ ).addErrback(unwrapFirstError)
+
+ defer.returnValue([ev for res in results for ev in res])
+
+ @defer.inlineCallbacks
+ def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None):
+ """Get receipts for a single room for sending to clients.
+
+ Args:
+            room_id (str): The room id.
+            to_key (int): Max stream id to fetch receipts up to.
+ from_key (int): Min stream id to fetch receipts from. None fetches
+ from the start.
+
+ Returns:
+ list: A list of receipts.
+ """
+ def f(txn):
+ if from_key:
+ sql = (
+ "SELECT * FROM receipts_linearized WHERE"
+ " room_id = ? AND stream_id > ? AND stream_id <= ?"
+ )
+
+ txn.execute(
+ sql,
+ (room_id, from_key, to_key)
+ )
+ else:
+ sql = (
+ "SELECT * FROM receipts_linearized WHERE"
+ " room_id = ? AND stream_id <= ?"
+ )
+
+ txn.execute(
+ sql,
+ (room_id, to_key)
+ )
+
+ rows = self.cursor_to_dict(txn)
+
+ return rows
+
+ rows = yield self.runInteraction(
+ "get_linearized_receipts_for_room", f
+ )
+
+ if not rows:
+ defer.returnValue([])
+
+ content = {}
+ for row in rows:
+ content.setdefault(
+ row["event_id"], {}
+ ).setdefault(
+ row["receipt_type"], {}
+ )[row["user_id"]] = json.loads(row["data"])
+
+ defer.returnValue([{
+ "type": "m.receipt",
+ "room_id": room_id,
+ "content": content,
+ }])
+
+ def get_max_receipt_stream_id(self):
+ return self._receipts_id_gen.get_max_token(self)
+
+ @cachedInlineCallbacks()
+ def get_graph_receipts_for_room(self, room_id):
+ """Get receipts for sending to remote servers.
+ """
+ rows = yield self._simple_select_list(
+ table="receipts_graph",
+ keyvalues={"room_id": room_id},
+ retcols=["receipt_type", "user_id", "event_id"],
+ desc="get_linearized_receipts_for_room",
+ )
+
+ result = {}
+ for row in rows:
+ result.setdefault(
+ row["user_id"], {}
+ ).setdefault(
+ row["receipt_type"], []
+ ).append(row["event_id"])
+
+ defer.returnValue(result)
+
+ def insert_linearized_receipt_txn(self, txn, room_id, receipt_type,
+ user_id, event_id, data, stream_id):
+
+ # We don't want to clobber receipts for more recent events, so we
+ # have to compare orderings of existing receipts
+ sql = (
+ "SELECT topological_ordering, stream_ordering, event_id FROM events"
+ " INNER JOIN receipts_linearized as r USING (event_id, room_id)"
+ " WHERE r.room_id = ? AND r.receipt_type = ? AND r.user_id = ?"
+ )
+
+ txn.execute(sql, (room_id, receipt_type, user_id))
+ results = txn.fetchall()
+
+ if results:
+ res = self._simple_select_one_txn(
+ txn,
+ table="events",
+ retcols=["topological_ordering", "stream_ordering"],
+ keyvalues={"event_id": event_id},
+ )
+ topological_ordering = int(res["topological_ordering"])
+ stream_ordering = int(res["stream_ordering"])
+
+ for to, so, _ in results:
+ if int(to) > topological_ordering:
+ return False
+ elif int(to) == topological_ordering and int(so) >= stream_ordering:
+ return False
+
+ self._simple_delete_txn(
+ txn,
+ table="receipts_linearized",
+ keyvalues={
+ "room_id": room_id,
+ "receipt_type": receipt_type,
+ "user_id": user_id,
+ }
+ )
+
+ self._simple_insert_txn(
+ txn,
+ table="receipts_linearized",
+ values={
+ "stream_id": stream_id,
+ "room_id": room_id,
+ "receipt_type": receipt_type,
+ "user_id": user_id,
+ "event_id": event_id,
+ "data": json.dumps(data),
+ }
+ )
+
+ return True
+
+ @defer.inlineCallbacks
+ def insert_receipt(self, room_id, receipt_type, user_id, event_ids, data):
+ """Insert a receipt, either from local client or remote server.
+
+ Automatically does conversion between linearized and graph
+ representations.
+ """
+ if not event_ids:
+ return
+
+ if len(event_ids) == 1:
+ linearized_event_id = event_ids[0]
+ else:
+            # we need to map points in the graph to a linearized form.
+ # TODO: Make this better.
+ def graph_to_linear(txn):
+ query = (
+                    "SELECT event_id FROM events WHERE room_id = ?"
+                    " AND stream_ordering IN ("
+                    " SELECT max(stream_ordering) FROM events WHERE event_id IN (%s)"
+ ")"
+ ) % (",".join(["?"] * len(event_ids)))
+
+ txn.execute(query, [room_id] + event_ids)
+ rows = txn.fetchall()
+ if rows:
+ return rows[0][0]
+ else:
+ raise RuntimeError("Unrecognized event_ids: %r" % (event_ids,))
+
+ linearized_event_id = yield self.runInteraction(
+ "insert_receipt_conv", graph_to_linear
+ )
+
+ stream_id_manager = yield self._receipts_id_gen.get_next(self)
+ with stream_id_manager as stream_id:
+ yield self._receipts_stream_cache.room_has_changed(
+ self, room_id, stream_id
+ )
+ have_persisted = yield self.runInteraction(
+ "insert_linearized_receipt",
+ self.insert_linearized_receipt_txn,
+ room_id, receipt_type, user_id, linearized_event_id,
+ data,
+ stream_id=stream_id,
+ )
+
+ if not have_persisted:
+ defer.returnValue(None)
+
+ yield self.insert_graph_receipt(
+ room_id, receipt_type, user_id, event_ids, data
+ )
+
+        max_persisted_id = yield self._receipts_id_gen.get_max_token(self)
+ defer.returnValue((stream_id, max_persisted_id))
+
+ def insert_graph_receipt(self, room_id, receipt_type, user_id, event_ids,
+ data):
+ return self.runInteraction(
+ "insert_graph_receipt",
+ self.insert_graph_receipt_txn,
+ room_id, receipt_type, user_id, event_ids, data
+ )
+
+ def insert_graph_receipt_txn(self, txn, room_id, receipt_type,
+ user_id, event_ids, data):
+ self._simple_delete_txn(
+ txn,
+ table="receipts_graph",
+ keyvalues={
+ "room_id": room_id,
+ "receipt_type": receipt_type,
+ "user_id": user_id,
+ }
+ )
+ self._simple_insert_txn(
+ txn,
+ table="receipts_graph",
+ values={
+ "room_id": room_id,
+ "receipt_type": receipt_type,
+ "user_id": user_id,
+ "event_ids": json.dumps(event_ids),
+ "data": json.dumps(data),
+ }
+ )
+
+
+class _RoomStreamChangeCache(object):
+ """Keeps track of the stream_id of the latest change in rooms.
+
+ Given a list of rooms and stream key, it will give a subset of rooms that
+ may have changed since that key. If the key is too old then the cache
+ will simply return all rooms.
+ """
+ def __init__(self, size_of_cache=10000):
+ self._size_of_cache = size_of_cache
+ self._room_to_key = {}
+ self._cache = sorteddict()
+ self._earliest_key = None
+
+ @defer.inlineCallbacks
+ def get_rooms_changed(self, store, room_ids, key):
+ """Returns subset of room ids that have had new receipts since the
+ given key. If the key is too old it will just return the given list.
+ """
+ if key > (yield self._get_earliest_key(store)):
+ keys = self._cache.keys()
+ i = keys.bisect_right(key)
+
+ result = set(
+ self._cache[k] for k in keys[i:]
+ ).intersection(room_ids)
+ else:
+ result = room_ids
+
+ defer.returnValue(result)
+
+ @defer.inlineCallbacks
+ def room_has_changed(self, store, room_id, key):
+ """Informs the cache that the room has been changed at the given key.
+ """
+ if key > (yield self._get_earliest_key(store)):
+ old_key = self._room_to_key.get(room_id, None)
+ if old_key:
+ key = max(key, old_key)
+ self._cache.pop(old_key, None)
+            self._room_to_key[room_id] = key
+            self._cache[key] = room_id
+
+ while len(self._cache) > self._size_of_cache:
+ k, r = self._cache.popitem()
+ self._earliest_key = max(k, self._earliest_key)
+ self._room_to_key.pop(r, None)
+
+ @defer.inlineCallbacks
+ def _get_earliest_key(self, store):
+ if self._earliest_key is None:
+ self._earliest_key = yield store.get_max_receipt_stream_id()
+ self._earliest_key = int(self._earliest_key)
+
+ defer.returnValue(self._earliest_key)
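
``_RoomStreamChangeCache`` lets receipt polling skip rooms that cannot have
changed: it remembers the latest stream id per room, and
``get_rooms_changed`` intersects the caller's room set with everything newer
than the supplied key. A small demonstration against a stub store::

    from twisted.internet import defer

    class StubStore(object):
        def get_max_receipt_stream_id(self):
            # Bootstraps the cache's earliest known key.
            return defer.succeed(10)

    @defer.inlineCallbacks
    def demo():
        store = StubStore()
        cache = _RoomStreamChangeCache()
        yield cache.room_has_changed(store, "!a:hs", 11)
        yield cache.room_has_changed(store, "!b:hs", 12)
        # Only rooms with changes after stream id 11 come back.
        changed = yield cache.get_rooms_changed(store, {"!a:hs", "!b:hs"}, 11)
        defer.returnValue(changed)  # set(["!b:hs"])
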
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 90e2606be2..4eaa088b36 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -131,7 +131,7 @@ class RegistrationStore(SQLBaseStore):
user_id
)
for r in rows:
- self.get_user_by_token.invalidate(r)
+ self.get_user_by_token.invalidate((r,))
@cached()
def get_user_by_token(self, token):
diff --git a/synapse/storage/room.py b/synapse/storage/room.py
index 4612a8aa83..dd5bc2c8fb 100644
--- a/synapse/storage/room.py
+++ b/synapse/storage/room.py
@@ -17,7 +17,7 @@ from twisted.internet import defer
from synapse.api.errors import StoreError
-from ._base import SQLBaseStore, cached
+from ._base import SQLBaseStore, cachedInlineCallbacks
import collections
import logging
@@ -186,8 +186,7 @@ class RoomStore(SQLBaseStore):
}
)
- @cached()
- @defer.inlineCallbacks
+ @cachedInlineCallbacks()
def get_room_name_and_aliases(self, room_id):
def f(txn):
sql = (
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index d36a6c18a8..9f14f38f24 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -35,38 +35,28 @@ RoomsForUser = namedtuple(
class RoomMemberStore(SQLBaseStore):
- def _store_room_member_txn(self, txn, event):
+ def _store_room_members_txn(self, txn, events):
"""Store a room member in the database.
"""
- try:
- target_user_id = event.state_key
- except:
- logger.exception(
- "Failed to parse target_user_id=%s", target_user_id
- )
- raise
-
- logger.debug(
- "_store_room_member_txn: target_user_id=%s, membership=%s",
- target_user_id,
- event.membership,
- )
-
- self._simple_insert_txn(
+ self._simple_insert_many_txn(
txn,
- "room_memberships",
- {
- "event_id": event.event_id,
- "user_id": target_user_id,
- "sender": event.user_id,
- "room_id": event.room_id,
- "membership": event.membership,
- }
+ table="room_memberships",
+ values=[
+ {
+ "event_id": event.event_id,
+ "user_id": event.state_key,
+ "sender": event.user_id,
+ "room_id": event.room_id,
+ "membership": event.membership,
+ }
+ for event in events
+ ]
)
- txn.call_after(self.get_rooms_for_user.invalidate, target_user_id)
- txn.call_after(self.get_joined_hosts_for_room.invalidate, event.room_id)
- txn.call_after(self.get_users_in_room.invalidate, event.room_id)
+ for event in events:
+ txn.call_after(self.get_rooms_for_user.invalidate, (event.state_key,))
+ txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,))
+ txn.call_after(self.get_users_in_room.invalidate, (event.room_id,))
def get_room_member(self, user_id, room_id):
"""Retrieve the current state of a room member.
@@ -88,7 +78,7 @@ class RoomMemberStore(SQLBaseStore):
lambda events: events[0] if events else None
)
- @cached()
+ @cached(max_entries=5000)
def get_users_in_room(self, room_id):
def f(txn):
@@ -164,7 +154,7 @@ class RoomMemberStore(SQLBaseStore):
RoomsForUser(**r) for r in self.cursor_to_dict(txn)
]
- @cached()
+ @cached(max_entries=5000)
def get_joined_hosts_for_room(self, room_id):
return self.runInteraction(
"get_joined_hosts_for_room",
diff --git a/synapse/storage/schema/delta/14/upgrade_appservice_db.py b/synapse/storage/schema/delta/14/upgrade_appservice_db.py
index 9f3a4dd4c5..61232f9757 100644
--- a/synapse/storage/schema/delta/14/upgrade_appservice_db.py
+++ b/synapse/storage/schema/delta/14/upgrade_appservice_db.py
@@ -18,7 +18,7 @@ import logging
logger = logging.getLogger(__name__)
-def run_upgrade(cur):
+def run_upgrade(cur, *args, **kwargs):
cur.execute("SELECT id, regex FROM application_services_regex")
for row in cur.fetchall():
try:
diff --git a/synapse/storage/schema/delta/20/dummy.sql b/synapse/storage/schema/delta/20/dummy.sql
new file mode 100644
index 0000000000..e0ac49d1ec
--- /dev/null
+++ b/synapse/storage/schema/delta/20/dummy.sql
@@ -0,0 +1 @@
+SELECT 1;
diff --git a/synapse/storage/schema/delta/20/pushers.py b/synapse/storage/schema/delta/20/pushers.py
new file mode 100644
index 0000000000..543e57bbe2
--- /dev/null
+++ b/synapse/storage/schema/delta/20/pushers.py
@@ -0,0 +1,76 @@
+# Copyright 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+Main purpose of this upgrade is to change the unique key on the
+pushers table again (it was missed when the v16 full schema was
+made) but this also changes the pushkey and data columns to text.
+When selecting a bytea column into a text column, postgres inserts
+the hex encoded data, and there's no portable way of getting the
+UTF-8 bytes, so we have to do it in Python.
+"""
+
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+def run_upgrade(cur, database_engine, *args, **kwargs):
+ logger.info("Porting pushers table...")
+ cur.execute("""
+ CREATE TABLE IF NOT EXISTS pushers2 (
+ id BIGINT PRIMARY KEY,
+ user_name TEXT NOT NULL,
+ access_token BIGINT DEFAULT NULL,
+ profile_tag VARCHAR(32) NOT NULL,
+ kind VARCHAR(8) NOT NULL,
+ app_id VARCHAR(64) NOT NULL,
+ app_display_name VARCHAR(64) NOT NULL,
+ device_display_name VARCHAR(128) NOT NULL,
+ pushkey TEXT NOT NULL,
+ ts BIGINT NOT NULL,
+ lang VARCHAR(8),
+ data TEXT,
+ last_token TEXT,
+ last_success BIGINT,
+ failing_since BIGINT,
+ UNIQUE (app_id, pushkey, user_name)
+ )
+ """)
+ cur.execute("""SELECT
+ id, user_name, access_token, profile_tag, kind,
+ app_id, app_display_name, device_display_name,
+ pushkey, ts, lang, data, last_token, last_success,
+ failing_since
+ FROM pushers
+ """)
+ count = 0
+ for row in cur.fetchall():
+ row = list(row)
+        row[8] = bytes(row[8]).decode("utf-8")
+        # The data column is nullable, so guard against NULL before decoding.
+        row[11] = bytes(row[11]).decode("utf-8") if row[11] is not None else None
+ cur.execute(database_engine.convert_param_style("""
+ INSERT into pushers2 (
+ id, user_name, access_token, profile_tag, kind,
+ app_id, app_display_name, device_display_name,
+ pushkey, ts, lang, data, last_token, last_success,
+ failing_since
+ ) values (%s)""" % (','.join(['?' for _ in range(len(row))]))),
+ row
+ )
+ count += 1
+ cur.execute("DROP TABLE pushers")
+ cur.execute("ALTER TABLE pushers2 RENAME TO pushers")
+ logger.info("Moved %d pushers to new table", count)
diff --git a/synapse/storage/schema/delta/21/end_to_end_keys.sql b/synapse/storage/schema/delta/21/end_to_end_keys.sql
new file mode 100644
index 0000000000..8b4a380d11
--- /dev/null
+++ b/synapse/storage/schema/delta/21/end_to_end_keys.sql
@@ -0,0 +1,34 @@
+/* Copyright 2015 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+CREATE TABLE IF NOT EXISTS e2e_device_keys_json (
+ user_id TEXT NOT NULL, -- The user these keys are for.
+ device_id TEXT NOT NULL, -- Which of the user's devices these keys are for.
+ ts_added_ms BIGINT NOT NULL, -- When the keys were uploaded.
+ key_json TEXT NOT NULL, -- The keys for the device as a JSON blob.
+ CONSTRAINT e2e_device_keys_json_uniqueness UNIQUE (user_id, device_id)
+);
+
+
+CREATE TABLE IF NOT EXISTS e2e_one_time_keys_json (
+ user_id TEXT NOT NULL, -- The user this one-time key is for.
+ device_id TEXT NOT NULL, -- The device this one-time key is for.
+ algorithm TEXT NOT NULL, -- Which algorithm this one-time key is for.
+ key_id TEXT NOT NULL, -- An id for suppressing duplicate uploads.
+ ts_added_ms BIGINT NOT NULL, -- When this key was uploaded.
+ key_json TEXT NOT NULL, -- The key as a JSON blob.
+ CONSTRAINT e2e_one_time_keys_json_uniqueness UNIQUE (user_id, device_id, algorithm, key_id)
+);
diff --git a/synapse/storage/schema/delta/21/receipts.sql b/synapse/storage/schema/delta/21/receipts.sql
new file mode 100644
index 0000000000..2f64d609fc
--- /dev/null
+++ b/synapse/storage/schema/delta/21/receipts.sql
@@ -0,0 +1,38 @@
+/* Copyright 2015 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+CREATE TABLE IF NOT EXISTS receipts_graph(
+ room_id TEXT NOT NULL,
+ receipt_type TEXT NOT NULL,
+ user_id TEXT NOT NULL,
+ event_ids TEXT NOT NULL,
+ data TEXT NOT NULL,
+ CONSTRAINT receipts_graph_uniqueness UNIQUE (room_id, receipt_type, user_id)
+);
+
+CREATE TABLE IF NOT EXISTS receipts_linearized (
+ stream_id BIGINT NOT NULL,
+ room_id TEXT NOT NULL,
+ receipt_type TEXT NOT NULL,
+ user_id TEXT NOT NULL,
+ event_id TEXT NOT NULL,
+ data TEXT NOT NULL,
+ CONSTRAINT receipts_linearized_uniqueness UNIQUE (room_id, receipt_type, user_id)
+);
+
+CREATE INDEX receipts_linearized_id ON receipts_linearized(
+ stream_id
+);
diff --git a/synapse/storage/signatures.py b/synapse/storage/signatures.py
index f051828630..4f15e534b4 100644
--- a/synapse/storage/signatures.py
+++ b/synapse/storage/signatures.py
@@ -18,6 +18,7 @@ from twisted.internet import defer
from _base import SQLBaseStore
from syutil.base64util import encode_base64
+from synapse.crypto.event_signing import compute_event_reference_hash
class SignatureStore(SQLBaseStore):
@@ -101,23 +102,26 @@ class SignatureStore(SQLBaseStore):
txn.execute(query, (event_id, ))
return {k: v for k, v in txn.fetchall()}
- def _store_event_reference_hash_txn(self, txn, event_id, algorithm,
- hash_bytes):
+ def _store_event_reference_hashes_txn(self, txn, events):
"""Store a hash for a PDU
Args:
txn (cursor):
- event_id (str): Id for the Event.
- algorithm (str): Hashing algorithm.
- hash_bytes (bytes): Hash function output bytes.
+ events (list): list of Events.
"""
- self._simple_insert_txn(
+
+ vals = []
+ for event in events:
+ ref_alg, ref_hash_bytes = compute_event_reference_hash(event)
+ vals.append({
+ "event_id": event.event_id,
+ "algorithm": ref_alg,
+ "hash": buffer(ref_hash_bytes),
+ })
+
+ self._simple_insert_many_txn(
txn,
- "event_reference_hashes",
- {
- "event_id": event_id,
- "algorithm": algorithm,
- "hash": buffer(hash_bytes),
- },
+ table="event_reference_hashes",
+ values=vals,
)
def _get_event_signatures_txn(self, txn, event_id):
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index b24de34f23..7ce51b9bdc 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import SQLBaseStore, cached
+from ._base import SQLBaseStore, cachedInlineCallbacks
from twisted.internet import defer
@@ -81,31 +81,41 @@ class StateStore(SQLBaseStore):
f,
)
- @defer.inlineCallbacks
- def c(vals):
- vals[:] = yield self._get_events(vals, get_prev_content=False)
-
- yield defer.gatherResults(
+ state_list = yield defer.gatherResults(
[
- c(vals)
- for vals in states.values()
+ self._fetch_events_for_group(group, vals)
+ for group, vals in states.items()
],
consumeErrors=True,
)
- defer.returnValue(states)
+ defer.returnValue(dict(state_list))
+
+ def _fetch_events_for_group(self, key, events):
+ return self._get_events(
+ events, get_prev_content=False
+ ).addCallback(
+ lambda evs: (key, evs)
+ )
def _store_state_groups_txn(self, txn, event, context):
- if context.current_state is None:
- return
+ return self._store_mult_state_groups_txn(txn, [(event, context)])
- state_events = dict(context.current_state)
+ def _store_mult_state_groups_txn(self, txn, events_and_contexts):
+ state_groups = {}
+ for event, context in events_and_contexts:
+ if context.current_state is None:
+ continue
- if event.is_state():
- state_events[(event.type, event.state_key)] = event
+ if context.state_group is not None:
+ state_groups[event.event_id] = context.state_group
+ continue
+
+ state_events = dict(context.current_state)
+
+ if event.is_state():
+ state_events[(event.type, event.state_key)] = event
- state_group = context.state_group
- if not state_group:
state_group = self._state_groups_id_gen.get_next_txn(txn)
self._simple_insert_txn(
txn,
@@ -131,14 +141,19 @@ class StateStore(SQLBaseStore):
for state in state_events.values()
],
)
+ state_groups[event.event_id] = state_group
- self._simple_insert_txn(
+ self._simple_insert_many_txn(
txn,
table="event_to_state_groups",
- values={
- "state_group": state_group,
- "event_id": event.event_id,
- },
+ values=[
+ {
+ "state_group": state_groups[event.event_id],
+ "event_id": event.event_id,
+ }
+ for event, context in events_and_contexts
+ if context.current_state is not None
+ ],
)
@defer.inlineCallbacks
@@ -173,8 +188,7 @@ class StateStore(SQLBaseStore):
events = yield self._get_events(event_ids, get_prev_content=False)
defer.returnValue(events)
- @cached(num_args=3)
- @defer.inlineCallbacks
+ @cachedInlineCallbacks(num_args=3)
def get_current_state_for_key(self, room_id, event_type, state_key):
def f(txn):
sql = (
@@ -190,6 +204,65 @@ class StateStore(SQLBaseStore):
events = yield self._get_events(event_ids, get_prev_content=False)
defer.returnValue(events)
+ @defer.inlineCallbacks
+ def get_state_for_events(self, room_id, event_ids):
+ def f(txn):
+ groups = set()
+ event_to_group = {}
+ for event_id in event_ids:
+ # TODO: Remove this loop.
+ group = self._simple_select_one_onecol_txn(
+ txn,
+ table="event_to_state_groups",
+ keyvalues={"event_id": event_id},
+ retcol="state_group",
+ allow_none=True,
+ )
+ if group:
+ event_to_group[event_id] = group
+ groups.add(group)
+
+ group_to_state_ids = {}
+ for group in groups:
+ state_ids = self._simple_select_onecol_txn(
+ txn,
+ table="state_groups_state",
+ keyvalues={"state_group": group},
+ retcol="event_id",
+ )
+
+ group_to_state_ids[group] = state_ids
+
+ return event_to_group, group_to_state_ids
+
+ res = yield self.runInteraction(
+ "annotate_events_with_state_groups",
+ f,
+ )
+
+ event_to_group, group_to_state_ids = res
+
+ state_list = yield defer.gatherResults(
+ [
+ self._fetch_events_for_group(group, vals)
+ for group, vals in group_to_state_ids.items()
+ ],
+ consumeErrors=True,
+ )
+
+ state_dict = {
+ group: {
+ (ev.type, ev.state_key): ev
+ for ev in state
+ }
+ for group, state in state_list
+ }
+
+ defer.returnValue([
+ state_dict.get(event_to_group.get(event, None), None)
+ for event in event_ids
+ ])
+
def _make_group_id(clock):
return str(int(clock.time_msec())) + random_string(5)
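
``get_state_for_events`` resolves each event id to its state group, fetches
each group's state events once, and returns one state dict per requested
event keyed by ``(type, state_key)``, or ``None`` where no group is
recorded. A hedged usage sketch (``log_memberships`` and its arguments are
hypothetical)::

    import logging

    from twisted.internet import defer

    logger = logging.getLogger(__name__)

    @defer.inlineCallbacks
    def log_memberships(store, room_id, event_ids):
        # One state dict (or None) comes back per event id, in order.
        states = yield store.get_state_for_events(room_id, event_ids)
        for event_id, state in zip(event_ids, states):
            if state is None:
                continue   # no state group recorded for this event
            members = [state_key for (ev_type, state_key) in state
                       if ev_type == "m.room.member"]
            logger.info("%s sees members %s", event_id, members)
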
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 89d1643f10..e956df62c7 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -72,7 +72,10 @@ class StreamIdGenerator(object):
with stream_id_gen.get_next_txn(txn) as stream_id:
# ... persist event ...
"""
- def __init__(self):
+ def __init__(self, table, column):
+ self.table = table
+ self.column = column
+
self._lock = threading.Lock()
self._current_max = None
@@ -108,6 +111,37 @@ class StreamIdGenerator(object):
defer.returnValue(manager())
@defer.inlineCallbacks
+ def get_next_mult(self, store, n):
+ """
+ Usage:
+            with (yield stream_id_gen.get_next_mult(store, n)) as stream_ids:
+ # ... persist events ...
+ """
+ if not self._current_max:
+ yield store.runInteraction(
+ "_compute_current_max",
+ self._get_or_compute_current_max,
+ )
+
+ with self._lock:
+ next_ids = range(self._current_max + 1, self._current_max + n + 1)
+ self._current_max += n
+
+ for next_id in next_ids:
+ self._unfinished_ids.append(next_id)
+
+ @contextlib.contextmanager
+ def manager():
+ try:
+ yield next_ids
+ finally:
+ with self._lock:
+ for next_id in next_ids:
+ self._unfinished_ids.remove(next_id)
+
+ defer.returnValue(manager())
+
+ @defer.inlineCallbacks
def get_max_token(self, store):
"""Returns the maximum stream id such that all stream ids less than or
equal to it have been successfully persisted.
@@ -126,7 +160,7 @@ class StreamIdGenerator(object):
def _get_or_compute_current_max(self, txn):
with self._lock:
- txn.execute("SELECT MAX(stream_ordering) FROM events")
+ txn.execute("SELECT MAX(%s) FROM %s" % (self.column, self.table))
rows = txn.fetchall()
val, = rows[0]
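
``get_next_mult`` reserves a contiguous block of ``n`` stream ids under the
lock and, like ``get_next``, hands them out through a context manager so
``get_max_token`` only advances once they have all been persisted. Usage,
per the corrected docstring (``persist_events`` here is a hypothetical
caller)::

    from twisted.internet import defer

    @defer.inlineCallbacks
    def persist_events(store, stream_id_gen, events):
        n = len(events)
        with (yield stream_id_gen.get_next_mult(store, n)) as stream_ids:
            for stream_id, event in zip(stream_ids, events):
                pass   # ... persist each event under its reserved id ...
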
diff --git a/synapse/streams/events.py b/synapse/streams/events.py
index dff7970bea..aaa3609aa5 100644
--- a/synapse/streams/events.py
+++ b/synapse/streams/events.py
@@ -20,6 +20,7 @@ from synapse.types import StreamToken
from synapse.handlers.presence import PresenceEventSource
from synapse.handlers.room import RoomEventSource
from synapse.handlers.typing import TypingNotificationEventSource
+from synapse.handlers.receipts import ReceiptEventSource
class NullSource(object):
@@ -43,6 +44,7 @@ class EventSources(object):
"room": RoomEventSource,
"presence": PresenceEventSource,
"typing": TypingNotificationEventSource,
+ "receipt": ReceiptEventSource,
}
def __init__(self, hs):
@@ -62,7 +64,10 @@ class EventSources(object):
),
typing_key=(
yield self.sources["typing"].get_current_key()
- )
+ ),
+ receipt_key=(
+ yield self.sources["receipt"].get_current_key()
+ ),
)
defer.returnValue(token)
diff --git a/synapse/types.py b/synapse/types.py
index 1b21160c57..e190374cbd 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -100,7 +100,7 @@ class EventID(DomainSpecificString):
class StreamToken(
namedtuple(
"Token",
- ("room_key", "presence_key", "typing_key")
+ ("room_key", "presence_key", "typing_key", "receipt_key")
)
):
_SEPARATOR = "_"
@@ -109,6 +109,9 @@ class StreamToken(
def from_string(cls, string):
try:
keys = string.split(cls._SEPARATOR)
+ if len(keys) == len(cls._fields) - 1:
+ # i.e. old token from before receipt_key
+ keys.append("0")
return cls(*keys)
except:
raise SynapseError(400, "Invalid Token")
@@ -131,6 +134,7 @@ class StreamToken(
(other_token.room_stream_id < self.room_stream_id)
or (int(other_token.presence_key) < int(self.presence_key))
or (int(other_token.typing_key) < int(self.typing_key))
+ or (int(other_token.receipt_key) < int(self.receipt_key))
)
def copy_and_advance(self, key, new_value):
@@ -174,7 +178,7 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
Live tokens start with an "s" followed by the "stream_ordering" id of the
event it comes after. Historic tokens start with a "t" followed by the
- "topological_ordering" id of the event it comes after, follewed by "-",
+ "topological_ordering" id of the event it comes after, followed by "-",
followed by the "stream_ordering" id of the event it comes after.
"""
__slots__ = []
@@ -207,4 +211,5 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
return "s%d" % (self.stream,)
+# token_id is the primary key ID of the access token, not the access token itself.
ClientInfo = namedtuple("ClientInfo", ("device_id", "token_id"))
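
Appending a ``"0"`` when a three-part token is parsed means tokens minted
before receipts existed still round-trip: the missing ``receipt_key`` simply
defaults to the start of the receipts stream. Roughly (a standalone
illustration, not the parser itself)::

    FIELDS = ("room_key", "presence_key", "typing_key", "receipt_key")

    def parse_token(string):
        keys = string.split("_")
        if len(keys) == len(FIELDS) - 1:
            keys.append("0")    # old token from before receipt_key
        return dict(zip(FIELDS, keys))

    # parse_token("s12_3_7")   -> receipt_key defaults to "0"
    # parse_token("s12_3_7_9") -> receipt_key preserved as "9"
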
diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py
index 260714ccc2..07ff25cef3 100644
--- a/synapse/util/__init__.py
+++ b/synapse/util/__init__.py
@@ -91,8 +91,12 @@ class Clock(object):
with PreserveLoggingContext():
return reactor.callLater(delay, wrapped_callback, *args, **kwargs)
- def cancel_call_later(self, timer):
- timer.cancel()
+ def cancel_call_later(self, timer, ignore_errs=False):
+ try:
+ timer.cancel()
+ except:
+ if not ignore_errs:
+ raise
def time_bound_deferred(self, given_deferred, time_out):
if given_deferred.called:
diff --git a/synapse/util/async.py b/synapse/util/async.py
index 1c2044e5b4..7bf2d38bb8 100644
--- a/synapse/util/async.py
+++ b/synapse/util/async.py
@@ -38,6 +38,9 @@ class ObservableDeferred(object):
deferred.
If consumeErrors is true errors will be captured from the origin deferred.
+
+ Cancelling or otherwise resolving an observer will not affect the original
+ ObservableDeferred.
"""
__slots__ = ["_deferred", "_observers", "_result"]
@@ -45,10 +48,10 @@ class ObservableDeferred(object):
def __init__(self, deferred, consumeErrors=False):
object.__setattr__(self, "_deferred", deferred)
object.__setattr__(self, "_result", None)
- object.__setattr__(self, "_observers", [])
+ object.__setattr__(self, "_observers", set())
def callback(r):
- self._result = (True, r)
+ object.__setattr__(self, "_result", (True, r))
while self._observers:
try:
self._observers.pop().callback(r)
@@ -57,7 +60,7 @@ class ObservableDeferred(object):
return r
def errback(f):
- self._result = (False, f)
+ object.__setattr__(self, "_result", (False, f))
while self._observers:
try:
self._observers.pop().errback(f)
@@ -74,14 +77,28 @@ class ObservableDeferred(object):
def observe(self):
if not self._result:
d = defer.Deferred()
- self._observers.append(d)
+
+ def remove(r):
+ self._observers.discard(d)
+ return r
+ d.addBoth(remove)
+
+ self._observers.add(d)
return d
else:
success, res = self._result
return defer.succeed(res) if success else defer.fail(res)
+ def observers(self):
+ return self._observers
+
def __getattr__(self, name):
return getattr(self._deferred, name)
def __setattr__(self, name, value):
setattr(self._deferred, name, value)
+
+ def __repr__(self):
+ return "<ObservableDeferred object at %s, result=%r, _deferred=%r>" % (
+ id(self), self._result, self._deferred,
+ )
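
The point of switching the observers to a set with a self-removing
``addBoth`` is that a consumer can cancel its own observer without
disturbing the shared deferred or any other observer. For instance (a sketch
against the class above)::

    from twisted.internet import defer

    underlying = defer.Deferred()
    observable = ObservableDeferred(underlying)

    first = observable.observe()
    second = observable.observe()

    first.cancel()   # removes only `first` from the observer set
    first.addErrback(lambda f: f.trap(defer.CancelledError))

    results = []
    second.addCallback(results.append)
    underlying.callback("done")   # `second` still fires: results == ["done"]
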
diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py
index a92d518b43..7e6062c1b8 100644
--- a/synapse/util/logcontext.py
+++ b/synapse/util/logcontext.py
@@ -140,6 +140,37 @@ class PreserveLoggingContext(object):
)
+class _PreservingContextDeferred(defer.Deferred):
+ """A deferred that ensures that all callbacks and errbacks are called with
+ the given logging context.
+ """
+ def __init__(self, context):
+ self._log_context = context
+ defer.Deferred.__init__(self)
+
+ def addCallbacks(self, callback, errback=None,
+ callbackArgs=None, callbackKeywords=None,
+ errbackArgs=None, errbackKeywords=None):
+ callback = self._wrap_callback(callback)
+ errback = self._wrap_callback(errback)
+ return defer.Deferred.addCallbacks(
+ self, callback,
+ errback=errback,
+ callbackArgs=callbackArgs,
+ callbackKeywords=callbackKeywords,
+ errbackArgs=errbackArgs,
+ errbackKeywords=errbackKeywords,
+ )
+
+ def _wrap_callback(self, f):
+ def g(res, *args, **kwargs):
+ with PreserveLoggingContext():
+ LoggingContext.thread_local.current_context = self._log_context
+ res = f(res, *args, **kwargs)
+ return res
+ return g
+
+
def preserve_context_over_fn(fn, *args, **kwargs):
"""Takes a function and invokes it with the given arguments, but removes
and restores the current logging context while doing so.
@@ -160,24 +191,7 @@ def preserve_context_over_deferred(deferred):
"""Given a deferred wrap it such that any callbacks added later to it will
be invoked with the current context.
"""
- d = defer.Deferred()
-
current_context = LoggingContext.current_context()
-
- def cb(res):
- with PreserveLoggingContext():
- LoggingContext.thread_local.current_context = current_context
- res = d.callback(res)
- return res
-
- def eb(failure):
- with PreserveLoggingContext():
- LoggingContext.thread_local.current_context = current_context
- res = d.errback(failure)
- return res
-
- if deferred.called:
- return deferred
-
- deferred.addCallbacks(cb, eb)
+ d = _PreservingContextDeferred(current_context)
+ deferred.chainDeferred(d)
return d
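
``_PreservingContextDeferred`` bakes the context switch into every callback
as it is added, instead of splicing a single pair of callbacks in after the
fact. The calling pattern is unchanged (``fetch`` here is any hypothetical
function returning a deferred)::

    from twisted.internet import defer

    @defer.inlineCallbacks
    def handle_request(fetch):
        d = preserve_context_over_deferred(fetch())
        result = yield d
        # Code after the yield runs in the logging context that was
        # current when preserve_context_over_deferred was called.
        defer.returnValue(result)
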
diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py
index 52e66beaee..7a1e96af37 100644
--- a/synapse/util/stringutils.py
+++ b/synapse/util/stringutils.py
@@ -33,3 +33,12 @@ def random_string_with_symbols(length):
return ''.join(
random.choice(_string_with_symbols) for _ in xrange(length)
)
+
+
+def is_ascii(s):
+    try:
+        s.encode("ascii")
+    except (UnicodeDecodeError, UnicodeEncodeError):
+        # str input raises UnicodeDecodeError, unicode input raises
+        # UnicodeEncodeError; either way the string is not plain ASCII.
+        return False
+    else:
+        return True
diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py
index 06cb1dd4cf..9e95d1e532 100644
--- a/tests/handlers/test_appservice.py
+++ b/tests/handlers/test_appservice.py
@@ -58,6 +58,49 @@ class AppServiceHandlerTestCase(unittest.TestCase):
)
@defer.inlineCallbacks
+ def test_query_user_exists_unknown_user(self):
+ user_id = "@someone:anywhere"
+ services = [self._mkservice(is_interested=True)]
+ services[0].is_interested_in_user = Mock(return_value=True)
+ self.mock_store.get_app_services = Mock(return_value=services)
+ self.mock_store.get_user_by_id = Mock(return_value=None)
+
+ event = Mock(
+ sender=user_id,
+ type="m.room.message",
+ room_id="!foo:bar"
+ )
+ self.mock_as_api.push = Mock()
+ self.mock_as_api.query_user = Mock()
+ yield self.handler.notify_interested_services(event)
+ self.mock_as_api.query_user.assert_called_once_with(
+ services[0], user_id
+ )
+
+ @defer.inlineCallbacks
+ def test_query_user_exists_known_user(self):
+ user_id = "@someone:anywhere"
+ services = [self._mkservice(is_interested=True)]
+ services[0].is_interested_in_user = Mock(return_value=True)
+ self.mock_store.get_app_services = Mock(return_value=services)
+ self.mock_store.get_user_by_id = Mock(return_value={
+ "name": user_id
+ })
+
+ event = Mock(
+ sender=user_id,
+ type="m.room.message",
+ room_id="!foo:bar"
+ )
+ self.mock_as_api.push = Mock()
+ self.mock_as_api.query_user = Mock()
+ yield self.handler.notify_interested_services(event)
+ self.assertFalse(
+ self.mock_as_api.query_user.called,
+ "query_user called when it shouldn't have been."
+ )
+
+ @defer.inlineCallbacks
def test_query_room_alias_exists(self):
room_alias_str = "#foo:bar"
room_alias = Mock()
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index f3821242bc..d392c23015 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -100,7 +100,7 @@ class FederationTestCase(unittest.TestCase):
return defer.succeed({})
self.datastore.have_events.side_effect = have_events
- def annotate(ev, old_state=None):
+ def annotate(ev, old_state=None, outlier=False):
context = Mock()
context.current_state = {}
context.auth_events = {}
@@ -120,7 +120,7 @@ class FederationTestCase(unittest.TestCase):
)
self.state_handler.compute_event_context.assert_called_once_with(
- ANY, old_state=None,
+ ANY, old_state=None, outlier=False
)
self.auth.check.assert_called_once_with(ANY, auth_events={})
diff --git a/tests/handlers/test_room.py b/tests/handlers/test_room.py
index a2d7635995..2a7553f982 100644
--- a/tests/handlers/test_room.py
+++ b/tests/handlers/test_room.py
@@ -42,6 +42,7 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
"get_room",
"store_room",
"get_latest_events_in_room",
+ "add_event_hashes",
]),
resource_for_federation=NonCallableMock(),
http_client=NonCallableMock(spec_set=[]),
@@ -88,6 +89,7 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
self.ratelimiter.send_message.return_value = (True, 0)
self.datastore.persist_event.return_value = (1,1)
+ self.datastore.add_event_hashes.return_value = []
@defer.inlineCallbacks
def test_invite(self):
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 7ccbe2ea9c..41bb08b7ca 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -66,8 +66,8 @@ class TypingNotificationsTestCase(unittest.TestCase):
self.mock_federation_resource = MockHttpResource()
- mock_notifier = Mock(spec=["on_new_user_event"])
- self.on_new_user_event = mock_notifier.on_new_user_event
+ mock_notifier = Mock(spec=["on_new_event"])
+ self.on_new_event = mock_notifier.on_new_event
self.auth = Mock(spec=[])
@@ -182,7 +182,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
timeout=20000,
)
- self.on_new_user_event.assert_has_calls([
+ self.on_new_event.assert_has_calls([
call('typing_key', 1, rooms=[self.room_id]),
])
@@ -245,7 +245,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
)
)
- self.on_new_user_event.assert_has_calls([
+ self.on_new_event.assert_has_calls([
call('typing_key', 1, rooms=[self.room_id]),
])
@@ -299,7 +299,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
room_id=self.room_id,
)
- self.on_new_user_event.assert_has_calls([
+ self.on_new_event.assert_has_calls([
call('typing_key', 1, rooms=[self.room_id]),
])
@@ -331,10 +331,10 @@ class TypingNotificationsTestCase(unittest.TestCase):
timeout=10000,
)
- self.on_new_user_event.assert_has_calls([
+ self.on_new_event.assert_has_calls([
call('typing_key', 1, rooms=[self.room_id]),
])
- self.on_new_user_event.reset_mock()
+ self.on_new_event.reset_mock()
self.assertEquals(self.event_source.get_current_key(), 1)
events = yield self.event_source.get_new_events_for_user(self.u_apple, 0, None)
@@ -351,7 +351,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
self.clock.advance_time(11)
- self.on_new_user_event.assert_has_calls([
+ self.on_new_event.assert_has_calls([
call('typing_key', 2, rooms=[self.room_id]),
])
@@ -377,10 +377,10 @@ class TypingNotificationsTestCase(unittest.TestCase):
timeout=10000,
)
- self.on_new_user_event.assert_has_calls([
+ self.on_new_event.assert_has_calls([
call('typing_key', 3, rooms=[self.room_id]),
])
- self.on_new_user_event.reset_mock()
+ self.on_new_event.reset_mock()
self.assertEquals(self.event_source.get_current_key(), 3)
events = yield self.event_source.get_new_events_for_user(self.u_apple, 0, None)
diff --git a/tests/rest/client/v1/test_events.py b/tests/rest/client/v1/test_events.py
index 445272e323..ac3b0b58ac 100644
--- a/tests/rest/client/v1/test_events.py
+++ b/tests/rest/client/v1/test_events.py
@@ -183,7 +183,17 @@ class EventStreamPermissionsTestCase(RestTestCase):
)
self.assertEquals(200, code, msg=str(response))
- self.assertEquals(0, len(response["chunk"]))
+ # We may get a presence event for ourselves down
+ self.assertEquals(
+ 0,
+ len([
+ c for c in response["chunk"]
+ if not (
+ c.get("type") == "m.presence"
+ and c["content"].get("user_id") == self.user_id
+ )
+ ])
+ )
# joined room (expect all content for room)
yield self.join(room=room_id, user=self.user_id, tok=self.token)
diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py
index 4b32c7a203..089a71568c 100644
--- a/tests/rest/client/v1/test_presence.py
+++ b/tests/rest/client/v1/test_presence.py
@@ -357,7 +357,7 @@ class PresenceEventStreamTestCase(unittest.TestCase):
# all be ours
# I'll already get my own presence state change
- self.assertEquals({"start": "0_1_0", "end": "0_1_0", "chunk": []},
+ self.assertEquals({"start": "0_1_0_0", "end": "0_1_0_0", "chunk": []},
response
)
@@ -376,7 +376,7 @@ class PresenceEventStreamTestCase(unittest.TestCase):
"/events?from=s0_1_0&timeout=0", None)
self.assertEquals(200, code)
- self.assertEquals({"start": "s0_1_0", "end": "s0_2_0", "chunk": [
+ self.assertEquals({"start": "s0_1_0_0", "end": "s0_2_0_0", "chunk": [
{"type": "m.presence",
"content": {
"user_id": "@banana:test",
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
new file mode 100644
index 0000000000..66fd25964d
--- /dev/null
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -0,0 +1,134 @@
+from synapse.rest.client.v2_alpha.register import RegisterRestServlet
+from synapse.api.errors import SynapseError
+from twisted.internet import defer
+from mock import Mock, MagicMock
+from tests import unittest
+import json
+
+
+class RegisterRestServletTestCase(unittest.TestCase):
+
+ def setUp(self):
+ # do the dance to hook up request data to self.request_data
+ self.request_data = ""
+ self.request = Mock(
+ content=Mock(read=Mock(side_effect=lambda: self.request_data)),
+ )
+ self.request.args = {}
+
+ self.appservice = None
+ self.auth = Mock(get_appservice_by_req=Mock(
+ side_effect=lambda x: defer.succeed(self.appservice))
+ )
+
+ self.auth_result = (False, None, None)
+ self.auth_handler = Mock(
+ check_auth=Mock(side_effect=lambda x,y,z: self.auth_result)
+ )
+ self.registration_handler = Mock()
+ self.identity_handler = Mock()
+ self.login_handler = Mock()
+
+ # do the dance to hook it up to the hs global
+ self.handlers = Mock(
+ auth_handler=self.auth_handler,
+ registration_handler=self.registration_handler,
+ identity_handler=self.identity_handler,
+ login_handler=self.login_handler
+ )
+ self.hs = Mock()
+ self.hs.hostname = "superbig~testing~thing.com"
+ self.hs.get_auth = Mock(return_value=self.auth)
+ self.hs.get_handlers = Mock(return_value=self.handlers)
+ self.hs.config.disable_registration = False
+
+ # init the thing we're testing
+ self.servlet = RegisterRestServlet(self.hs)
+
+ @defer.inlineCallbacks
+ def test_POST_appservice_registration_valid(self):
+ user_id = "@kermit:muppet"
+ token = "kermits_access_token"
+ self.request.args = {
+ "access_token": "i_am_an_app_service"
+ }
+ self.request_data = json.dumps({
+ "username": "kermit"
+ })
+ self.appservice = {
+ "id": "1234"
+ }
+ self.registration_handler.appservice_register = Mock(
+ return_value=(user_id, token)
+ )
+ result = yield self.servlet.on_POST(self.request)
+ self.assertEquals(result, (200, {
+ "user_id": user_id,
+ "access_token": token,
+ "home_server": self.hs.hostname
+ }))
+
+ @defer.inlineCallbacks
+ def test_POST_appservice_registration_invalid(self):
+ self.request.args = {
+ "access_token": "i_am_an_app_service"
+ }
+ self.request_data = json.dumps({
+ "username": "kermit"
+ })
+ self.appservice = None # no application service exists
+ result = yield self.servlet.on_POST(self.request)
+ self.assertEquals(result, (401, None))
+
+ def test_POST_bad_password(self):
+ self.request_data = json.dumps({
+ "username": "kermit",
+ "password": 666
+ })
+ d = self.servlet.on_POST(self.request)
+ return self.assertFailure(d, SynapseError)
+
+ def test_POST_bad_username(self):
+ self.request_data = json.dumps({
+ "username": 777,
+ "password": "monkey"
+ })
+ d = self.servlet.on_POST(self.request)
+ return self.assertFailure(d, SynapseError)
+
+ @defer.inlineCallbacks
+ def test_POST_user_valid(self):
+ user_id = "@kermit:muppet"
+ token = "kermits_access_token"
+ self.request_data = json.dumps({
+ "username": "kermit",
+ "password": "monkey"
+ })
+ self.registration_handler.check_username = Mock(return_value=True)
+ self.auth_result = (True, None, {
+ "username": "kermit",
+ "password": "monkey"
+ })
+ self.registration_handler.register = Mock(return_value=(user_id, token))
+
+ result = yield self.servlet.on_POST(self.request)
+ self.assertEquals(result, (200, {
+ "user_id": user_id,
+ "access_token": token,
+ "home_server": self.hs.hostname
+ }))
+
+ def test_POST_disabled_registration(self):
+ self.hs.config.disable_registration = True
+ self.request_data = json.dumps({
+ "username": "kermit",
+ "password": "monkey"
+ })
+ self.registration_handler.check_username = Mock(return_value=True)
+ self.auth_result = (True, None, {
+ "username": "kermit",
+ "password": "monkey"
+ })
+ self.registration_handler.register = Mock(return_value=("@user:id", "t"))
+ d = self.servlet.on_POST(self.request)
+ return self.assertFailure(d, SynapseError)
\ No newline at end of file
diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py
index 96caf8c4c1..abee2f631d 100644
--- a/tests/storage/test__base.py
+++ b/tests/storage/test__base.py
@@ -17,6 +17,8 @@
from tests import unittest
from twisted.internet import defer
+from synapse.util.async import ObservableDeferred
+
from synapse.storage._base import Cache, cached
@@ -40,12 +42,12 @@ class CacheTestCase(unittest.TestCase):
self.assertEquals(self.cache.get("foo"), 123)
def test_invalidate(self):
- self.cache.prefill("foo", 123)
- self.cache.invalidate("foo")
+ self.cache.prefill(("foo",), 123)
+ self.cache.invalidate(("foo",))
failed = False
try:
- self.cache.get("foo")
+ self.cache.get(("foo",))
except KeyError:
failed = True
@@ -96,87 +98,102 @@ class CacheDecoratorTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_passthrough(self):
- @cached()
- def func(self, key):
- return key
+ class A(object):
+ @cached()
+ def func(self, key):
+ return key
- self.assertEquals((yield func(self, "foo")), "foo")
- self.assertEquals((yield func(self, "bar")), "bar")
+ a = A()
+
+ self.assertEquals((yield a.func("foo")), "foo")
+ self.assertEquals((yield a.func("bar")), "bar")
@defer.inlineCallbacks
def test_hit(self):
callcount = [0]
- @cached()
- def func(self, key):
- callcount[0] += 1
- return key
+ class A(object):
+ @cached()
+ def func(self, key):
+ callcount[0] += 1
+ return key
- yield func(self, "foo")
+ a = A()
+ yield a.func("foo")
self.assertEquals(callcount[0], 1)
- self.assertEquals((yield func(self, "foo")), "foo")
+ self.assertEquals((yield a.func("foo")), "foo")
self.assertEquals(callcount[0], 1)
@defer.inlineCallbacks
def test_invalidate(self):
callcount = [0]
- @cached()
- def func(self, key):
- callcount[0] += 1
- return key
+ class A(object):
+ @cached()
+ def func(self, key):
+ callcount[0] += 1
+ return key
- yield func(self, "foo")
+ a = A()
+ yield a.func("foo")
self.assertEquals(callcount[0], 1)
- func.invalidate("foo")
+ a.func.invalidate(("foo",))
- yield func(self, "foo")
+ yield a.func("foo")
self.assertEquals(callcount[0], 2)
def test_invalidate_missing(self):
- @cached()
- def func(self, key):
- return key
+ class A(object):
+ @cached()
+ def func(self, key):
+ return key
- func.invalidate("what")
+ A().func.invalidate(("what",))
@defer.inlineCallbacks
def test_max_entries(self):
callcount = [0]
- @cached(max_entries=10)
- def func(self, key):
- callcount[0] += 1
- return key
+ class A(object):
+ @cached(max_entries=10)
+ def func(self, key):
+ callcount[0] += 1
+ return key
- for k in range(0,12):
- yield func(self, k)
+ a = A()
+
+ for k in range(0, 12):
+ yield a.func(k)
self.assertEquals(callcount[0], 12)
# There must have been at least 2 evictions, meaning if we calculate
# all 12 values again, we must get called at least 2 more times
for k in range(0,12):
- yield func(self, k)
+ yield a.func(k)
self.assertTrue(callcount[0] >= 14,
msg="Expected callcount >= 14, got %d" % (callcount[0]))
- @defer.inlineCallbacks
def test_prefill(self):
callcount = [0]
- @cached()
- def func(self, key):
- callcount[0] += 1
- return key
+ d = defer.succeed(123)
+
+ class A(object):
+ @cached()
+ def func(self, key):
+ callcount[0] += 1
+ return d
+
+ a = A()
- func.prefill("foo", 123)
+ a.func.prefill(("foo",), ObservableDeferred(d))
- self.assertEquals((yield func(self, "foo")), 123)
+ self.assertEquals(a.func("foo").result, d.result)
self.assertEquals(callcount[0], 0)
diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
index 78f6004204..2702291178 100644
--- a/tests/storage/test_registration.py
+++ b/tests/storage/test_registration.py
@@ -46,7 +46,7 @@ class RegistrationStoreTestCase(unittest.TestCase):
(yield self.store.get_user_by_id(self.user_id))
)
- result = yield self.store.get_user_by_token(self.tokens[1])
+ result = yield self.store.get_user_by_token(self.tokens[0])
self.assertDictContainsSubset(
{
diff --git a/tests/test_distributor.py b/tests/test_distributor.py
index 6a0095d850..8ed48cfb70 100644
--- a/tests/test_distributor.py
+++ b/tests/test_distributor.py
@@ -73,8 +73,8 @@ class DistributorTestCase(unittest.TestCase):
yield d
self.assertTrue(d.called)
- observers[0].assert_called_once("Go")
- observers[1].assert_called_once("Go")
+ observers[0].assert_called_once_with("Go")
+ observers[1].assert_called_once_with("Go")
self.assertEquals(mock_logger.warning.call_count, 1)
self.assertIsInstance(mock_logger.warning.call_args[0][0],
diff --git a/tests/utils.py b/tests/utils.py
index 3b5c335911..eb035cf48f 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -114,6 +114,8 @@ class MockHttpResource(HttpServer):
mock_request.method = http_method
mock_request.uri = path
+ mock_request.getClientIP.return_value = "-"
+
mock_request.requestHeaders.getRawHeaders.return_value=[
"X-Matrix origin=test,key=,sig="
]
|