diff --git a/jenkins-postgres.sh b/jenkins-postgres.sh
index 9ac86d2593..ae6b111591 100755
--- a/jenkins-postgres.sh
+++ b/jenkins-postgres.sh
@@ -25,7 +25,9 @@ rm .coverage* || echo "No coverage files to remove"
tox --notest -e py27
TOX_BIN=$WORKSPACE/.tox/py27/bin
+python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install
$TOX_BIN/pip install psycopg2
+$TOX_BIN/pip install lxml
: ${GIT_BRANCH:="origin/$(git rev-parse --abbrev-ref HEAD)"}
diff --git a/jenkins-sqlite.sh b/jenkins-sqlite.sh
index 345d01936c..9398d9db15 100755
--- a/jenkins-sqlite.sh
+++ b/jenkins-sqlite.sh
@@ -24,6 +24,8 @@ rm .coverage* || echo "No coverage files to remove"
tox --notest -e py27
TOX_BIN=$WORKSPACE/.tox/py27/bin
+python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install
+$TOX_BIN/pip install lxml
: ${GIT_BRANCH:="origin/$(git rev-parse --abbrev-ref HEAD)"}
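
Note: both jenkins scripts now install synapse's own dependencies before the
extra test packages. A minimal sketch of the entry point they rely on,
assuming python_dependencies.py prints one pip requirement specifier per line
when run as a script, so that it composes with `xargs -n1 pip install`:

    # sketch of synapse/python_dependencies.py's __main__ behaviour (assumed);
    # the REQUIREMENTS keys are pip requirement specifiers, per the hunk
    # further down in this diff.
    REQUIREMENTS = {
        "pysaml2>=3.0.0,<4.0.0": ["saml2>=3.0.0,<4.0.0"],
        "pymacaroons-pynacl": ["pymacaroons"],
        "blist": ["blist"],
    }

    if __name__ == "__main__":
        for requirement in REQUIREMENTS:
            print(requirement)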
diff --git a/jenkins.sh b/jenkins.sh
deleted file mode 100755
index b826d510c9..0000000000
--- a/jenkins.sh
+++ /dev/null
@@ -1,86 +0,0 @@
-#!/bin/bash
-
-set -eux
-
-: ${WORKSPACE:="$(pwd)"}
-
-export PYTHONDONTWRITEBYTECODE=yep
-export SYNAPSE_CACHE_FACTOR=1
-
-# Output test results as junit xml
-export TRIAL_FLAGS="--reporter=subunit"
-export TOXSUFFIX="| subunit-1to2 | subunit2junitxml --no-passthrough --output-to=results.xml"
-# Write coverage reports to a separate file for each process
-export COVERAGE_OPTS="-p"
-export DUMP_COVERAGE_COMMAND="coverage help"
-
-# Output flake8 violations to violations.flake8.log
-# Don't exit with non-0 status code on Jenkins,
-# so that the build steps continue and a later step can decided whether to
-# UNSTABLE or FAILURE this build.
-export PEP8SUFFIX="--output-file=violations.flake8.log || echo flake8 finished with status code \$?"
-
-rm .coverage* || echo "No coverage files to remove"
-
-tox
-
-: ${GIT_BRANCH:="origin/$(git rev-parse --abbrev-ref HEAD)"}
-
-TOX_BIN=$WORKSPACE/.tox/py27/bin
-
-if [[ ! -e .sytest-base ]]; then
- git clone https://github.com/matrix-org/sytest.git .sytest-base --mirror
-else
- (cd .sytest-base; git fetch -p)
-fi
-
-rm -rf sytest
-git clone .sytest-base sytest --shared
-cd sytest
-
-git checkout "${GIT_BRANCH}" || (echo >&2 "No ref ${GIT_BRANCH} found, falling back to develop" ; git checkout develop)
-
-: ${PERL5LIB:=$WORKSPACE/perl5/lib/perl5}
-: ${PERL_MB_OPT:=--install_base=$WORKSPACE/perl5}
-: ${PERL_MM_OPT:=INSTALL_BASE=$WORKSPACE/perl5}
-export PERL5LIB PERL_MB_OPT PERL_MM_OPT
-
-./install-deps.pl
-
-: ${PORT_BASE:=8000}
-
-echo >&2 "Running sytest with SQLite3";
-./run-tests.pl --coverage -O tap --synapse-directory $WORKSPACE \
- --python $TOX_BIN/python --all --port-base $PORT_BASE > results-sqlite3.tap
-
-RUN_POSTGRES=""
-
-for port in $(($PORT_BASE + 1)) $(($PORT_BASE + 2)); do
- if psql synapse_jenkins_$port <<< ""; then
- RUN_POSTGRES="$RUN_POSTGRES:$port"
- cat > localhost-$port/database.yaml << EOF
-name: psycopg2
-args:
- database: synapse_jenkins_$port
-EOF
- fi
-done
-
-# Run if both postgresql databases exist
-if test "$RUN_POSTGRES" = ":$(($PORT_BASE + 1)):$(($PORT_BASE + 2))"; then
- echo >&2 "Running sytest with PostgreSQL";
- $TOX_BIN/pip install psycopg2
- ./run-tests.pl --coverage -O tap --synapse-directory $WORKSPACE \
- --python $TOX_BIN/python --all --port-base $PORT_BASE > results-postgresql.tap
-else
- echo >&2 "Skipping running sytest with PostgreSQL, $RUN_POSTGRES"
-fi
-
-cd ..
-cp sytest/.coverage.* .
-
-# Combine the coverage reports
-echo "Combining:" .coverage.*
-$TOX_BIN/python -m coverage combine
-# Output coverage to coverage.xml
-$TOX_BIN/coverage xml -o coverage.xml
diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db
index 253a6ef6c7..efd04da2d6 100755
--- a/scripts/synapse_port_db
+++ b/scripts/synapse_port_db
@@ -214,6 +214,10 @@ class Porter(object):
self.progress.add_table(table, postgres_size, table_size)
+ if table == "event_search":
+ yield self.handle_search_table(postgres_size, table_size, next_chunk)
+ return
+
select = (
"SELECT rowid, * FROM %s WHERE rowid >= ? ORDER BY rowid LIMIT ?"
% (table,)
@@ -232,60 +236,95 @@ class Porter(object):
if rows:
next_chunk = rows[-1][0] + 1
- if table == "event_search":
- # We have to treat event_search differently since it has a
- # different structure in the two different databases.
- def insert(txn):
- sql = (
- "INSERT INTO event_search (event_id, room_id, key, sender, vector)"
- " VALUES (?,?,?,?,to_tsvector('english', ?))"
- )
+ self._convert_rows(table, headers, rows)
- rows_dict = [
- dict(zip(headers, row))
- for row in rows
- ]
-
- txn.executemany(sql, [
- (
- row["event_id"],
- row["room_id"],
- row["key"],
- row["sender"],
- row["value"],
- )
- for row in rows_dict
- ])
-
- self.postgres_store._simple_update_one_txn(
- txn,
- table="port_from_sqlite3",
- keyvalues={"table_name": table},
- updatevalues={"rowid": next_chunk},
- )
- else:
- self._convert_rows(table, headers, rows)
+ def insert(txn):
+ self.postgres_store.insert_many_txn(
+ txn, table, headers[1:], rows
+ )
- def insert(txn):
- self.postgres_store.insert_many_txn(
- txn, table, headers[1:], rows
- )
+ self.postgres_store._simple_update_one_txn(
+ txn,
+ table="port_from_sqlite3",
+ keyvalues={"table_name": table},
+ updatevalues={"rowid": next_chunk},
+ )
+
+ yield self.postgres_store.execute(insert)
+
+ postgres_size += len(rows)
+
+ self.progress.update(table, postgres_size)
+ else:
+ return
+
+ @defer.inlineCallbacks
+ def handle_search_table(self, postgres_size, table_size, next_chunk):
+ select = (
+ "SELECT es.rowid, es.*, e.origin_server_ts, e.stream_ordering"
+ " FROM event_search as es"
+ " INNER JOIN events AS e USING (event_id, room_id)"
+ " WHERE es.rowid >= ?"
+ " ORDER BY es.rowid LIMIT ?"
+ )
- self.postgres_store._simple_update_one_txn(
- txn,
- table="port_from_sqlite3",
- keyvalues={"table_name": table},
- updatevalues={"rowid": next_chunk},
+ while True:
+ def r(txn):
+ txn.execute(select, (next_chunk, self.batch_size,))
+ rows = txn.fetchall()
+ headers = [column[0] for column in txn.description]
+
+ return headers, rows
+
+ headers, rows = yield self.sqlite_store.runInteraction("select", r)
+
+ if rows:
+ next_chunk = rows[-1][0] + 1
+
+ # We have to treat event_search differently since it has a
+ # different structure in the two different databases.
+ def insert(txn):
+ sql = (
+ "INSERT INTO event_search (event_id, room_id, key,"
+ " sender, vector, origin_server_ts, stream_ordering)"
+ " VALUES (?,?,?,?,to_tsvector('english', ?),?,?)"
+ )
+
+ rows_dict = [
+ dict(zip(headers, row))
+ for row in rows
+ ]
+
+ txn.executemany(sql, [
+ (
+ row["event_id"],
+ row["room_id"],
+ row["key"],
+ row["sender"],
+ row["value"],
+ row["origin_server_ts"],
+ row["stream_ordering"],
)
+ for row in rows_dict
+ ])
+
+ self.postgres_store._simple_update_one_txn(
+ txn,
+ table="port_from_sqlite3",
+ keyvalues={"table_name": "event_search"},
+ updatevalues={"rowid": next_chunk},
+ )
yield self.postgres_store.execute(insert)
postgres_size += len(rows)
- self.progress.update(table, postgres_size)
+ self.progress.update("event_search", postgres_size)
+
else:
return
+
def setup_db(self, db_config, database_engine):
db_conn = database_engine.module.connect(
**{
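
Note: the porter pages through every sqlite table in rowid order, batch_size
rows at a time, persisting a resume point as it goes; event_search is now
split out so the extra origin_server_ts and stream_ordering columns can be
joined in from the events table. A hedged sketch of the chunking pattern in
isolation (the table name and hook functions are illustrative, not the
Porter's real API):

    def port_in_chunks(sqlite_conn, insert_batch, save_resume_point,
                       batch_size=1000, next_chunk=0):
        # insert_batch / save_resume_point stand in for writing to postgres
        # and to the port_from_sqlite3 bookkeeping table.
        cur = sqlite_conn.cursor()
        while True:
            cur.execute(
                "SELECT rowid, * FROM some_table WHERE rowid >= ?"
                " ORDER BY rowid LIMIT ?",
                (next_chunk, batch_size),
            )
            rows = cur.fetchall()
            if not rows:
                return
            next_chunk = rows[-1][0] + 1  # resume just past the last row seen
            insert_batch(rows)
            save_resume_point(next_chunk)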
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index d2085a9405..df675c0ed4 100755
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -16,12 +16,9 @@
import synapse
-import contextlib
import logging
import os
-import re
import sys
-import time
from synapse.config._base import ConfigError
from synapse.python_dependencies import (
@@ -35,18 +32,11 @@ from synapse.storage.prepare_database import UpgradeDatabaseException, prepare_d
from synapse.server import HomeServer
-
-from twisted.conch.manhole import ColoredManhole
-from twisted.conch.insults import insults
-from twisted.conch import manhole_ssh
-from twisted.cred import checkers, portal
-
-
from twisted.internet import reactor, task, defer
from twisted.application import service
from twisted.web.resource import Resource, EncodingResourceWrapper
from twisted.web.static import File
-from twisted.web.server import Site, GzipEncoderFactory, Request
+from twisted.web.server import GzipEncoderFactory
from synapse.http.server import RootRedirect
from synapse.rest.media.v0.content_repository import ContentRepoResource
from synapse.rest.media.v1.media_repository import MediaRepositoryResource
@@ -66,6 +56,10 @@ from synapse.federation.transport.server import TransportLayerServer
from synapse.util.rlimit import change_resource_limit
from synapse.util.versionstring import get_version_string
+from synapse.util.httpresourcetree import create_resource_tree
+from synapse.util.manhole import manhole
+
+from synapse.http.site import SynapseSite
from synapse import events
@@ -74,9 +68,6 @@ from daemonize import Daemonize
logger = logging.getLogger("synapse.app.homeserver")
-ACCESS_TOKEN_RE = re.compile(r'(\?.*access(_|%5[Ff])token=)[^&]*(.*)$')
-
-
def gz_wrap(r):
return EncodingResourceWrapper(r, [GzipEncoderFactory()])
@@ -174,7 +165,12 @@ class SynapseHomeServer(HomeServer):
if name == "replication":
resources[REPLICATION_PREFIX] = ReplicationResource(self)
- root_resource = create_resource_tree(resources)
+ if WEB_CLIENT_PREFIX in resources:
+ root_resource = RootRedirect(WEB_CLIENT_PREFIX)
+ else:
+ root_resource = Resource()
+
+ root_resource = create_resource_tree(resources, root_resource)
if tls:
reactor.listenSSL(
port,
@@ -207,24 +203,13 @@ class SynapseHomeServer(HomeServer):
if listener["type"] == "http":
self._listener_http(config, listener)
elif listener["type"] == "manhole":
- checker = checkers.InMemoryUsernamePasswordDatabaseDontUse(
- matrix="rabbithole"
- )
-
- rlm = manhole_ssh.TerminalRealm()
- rlm.chainedProtocolFactory = lambda: insults.ServerProtocol(
- ColoredManhole,
- {
- "__name__": "__console__",
- "hs": self,
- }
- )
-
- f = manhole_ssh.ConchFactory(portal.Portal(rlm, [checker]))
-
reactor.listenTCP(
listener["port"],
- f,
+ manhole(
+ username="matrix",
+ password="rabbithole",
+ globals={"hs": self},
+ ),
interface=listener.get("bind_address", '127.0.0.1')
)
else:
@@ -371,210 +356,6 @@ class SynapseService(service.Service):
return self._port.stopListening()
-class SynapseRequest(Request):
- def __init__(self, site, *args, **kw):
- Request.__init__(self, *args, **kw)
- self.site = site
- self.authenticated_entity = None
- self.start_time = 0
-
- def __repr__(self):
- # We overwrite this so that we don't log ``access_token``
- return '<%s at 0x%x method=%s uri=%s clientproto=%s site=%s>' % (
- self.__class__.__name__,
- id(self),
- self.method,
- self.get_redacted_uri(),
- self.clientproto,
- self.site.site_tag,
- )
-
- def get_redacted_uri(self):
- return ACCESS_TOKEN_RE.sub(
- r'\1<redacted>\3',
- self.uri
- )
-
- def get_user_agent(self):
- return self.requestHeaders.getRawHeaders("User-Agent", [None])[-1]
-
- def started_processing(self):
- self.site.access_logger.info(
- "%s - %s - Received request: %s %s",
- self.getClientIP(),
- self.site.site_tag,
- self.method,
- self.get_redacted_uri()
- )
- self.start_time = int(time.time() * 1000)
-
- def finished_processing(self):
-
- try:
- context = LoggingContext.current_context()
- ru_utime, ru_stime = context.get_resource_usage()
- db_txn_count = context.db_txn_count
- db_txn_duration = context.db_txn_duration
- except:
- ru_utime, ru_stime = (0, 0)
- db_txn_count, db_txn_duration = (0, 0)
-
- self.site.access_logger.info(
- "%s - %s - {%s}"
- " Processed request: %dms (%dms, %dms) (%dms/%d)"
- " %sB %s \"%s %s %s\" \"%s\"",
- self.getClientIP(),
- self.site.site_tag,
- self.authenticated_entity,
- int(time.time() * 1000) - self.start_time,
- int(ru_utime * 1000),
- int(ru_stime * 1000),
- int(db_txn_duration * 1000),
- int(db_txn_count),
- self.sentLength,
- self.code,
- self.method,
- self.get_redacted_uri(),
- self.clientproto,
- self.get_user_agent(),
- )
-
- @contextlib.contextmanager
- def processing(self):
- self.started_processing()
- yield
- self.finished_processing()
-
-
-class XForwardedForRequest(SynapseRequest):
- def __init__(self, *args, **kw):
- SynapseRequest.__init__(self, *args, **kw)
-
- """
- Add a layer on top of another request that only uses the value of an
- X-Forwarded-For header as the result of C{getClientIP}.
- """
- def getClientIP(self):
- """
- @return: The client address (the first address) in the value of the
- I{X-Forwarded-For header}. If the header is not present, return
- C{b"-"}.
- """
- return self.requestHeaders.getRawHeaders(
- b"x-forwarded-for", [b"-"])[0].split(b",")[0].strip()
-
-
-class SynapseRequestFactory(object):
- def __init__(self, site, x_forwarded_for):
- self.site = site
- self.x_forwarded_for = x_forwarded_for
-
- def __call__(self, *args, **kwargs):
- if self.x_forwarded_for:
- return XForwardedForRequest(self.site, *args, **kwargs)
- else:
- return SynapseRequest(self.site, *args, **kwargs)
-
-
-class SynapseSite(Site):
- """
- Subclass of a twisted http Site that does access logging with python's
- standard logging
- """
- def __init__(self, logger_name, site_tag, config, resource, *args, **kwargs):
- Site.__init__(self, resource, *args, **kwargs)
-
- self.site_tag = site_tag
-
- proxied = config.get("x_forwarded", False)
- self.requestFactory = SynapseRequestFactory(self, proxied)
- self.access_logger = logging.getLogger(logger_name)
-
- def log(self, request):
- pass
-
-
-def create_resource_tree(desired_tree, redirect_root_to_web_client=True):
- """Create the resource tree for this Home Server.
-
- This in unduly complicated because Twisted does not support putting
- child resources more than 1 level deep at a time.
-
- Args:
- web_client (bool): True to enable the web client.
- redirect_root_to_web_client (bool): True to redirect '/' to the
- location of the web client. This does nothing if web_client is not
- True.
- """
- if redirect_root_to_web_client and WEB_CLIENT_PREFIX in desired_tree:
- root_resource = RootRedirect(WEB_CLIENT_PREFIX)
- else:
- root_resource = Resource()
-
- # ideally we'd just use getChild and putChild but getChild doesn't work
- # unless you give it a Request object IN ADDITION to the name :/ So
- # instead, we'll store a copy of this mapping so we can actually add
- # extra resources to existing nodes. See self._resource_id for the key.
- resource_mappings = {}
- for full_path, res in desired_tree.items():
- logger.info("Attaching %s to path %s", res, full_path)
- last_resource = root_resource
- for path_seg in full_path.split('/')[1:-1]:
- if path_seg not in last_resource.listNames():
- # resource doesn't exist, so make a "dummy resource"
- child_resource = Resource()
- last_resource.putChild(path_seg, child_resource)
- res_id = _resource_id(last_resource, path_seg)
- resource_mappings[res_id] = child_resource
- last_resource = child_resource
- else:
- # we have an existing Resource, use that instead.
- res_id = _resource_id(last_resource, path_seg)
- last_resource = resource_mappings[res_id]
-
- # ===========================
- # now attach the actual desired resource
- last_path_seg = full_path.split('/')[-1]
-
- # if there is already a resource here, thieve its children and
- # replace it
- res_id = _resource_id(last_resource, last_path_seg)
- if res_id in resource_mappings:
- # there is a dummy resource at this path already, which needs
- # to be replaced with the desired resource.
- existing_dummy_resource = resource_mappings[res_id]
- for child_name in existing_dummy_resource.listNames():
- child_res_id = _resource_id(
- existing_dummy_resource, child_name
- )
- child_resource = resource_mappings[child_res_id]
- # steal the children
- res.putChild(child_name, child_resource)
-
- # finally, insert the desired resource in the right place
- last_resource.putChild(last_path_seg, res)
- res_id = _resource_id(last_resource, last_path_seg)
- resource_mappings[res_id] = res
-
- return root_resource
-
-
-def _resource_id(resource, path_seg):
- """Construct an arbitrary resource ID so you can retrieve the mapping
- later.
-
- If you want to represent resource A putChild resource B with path C,
- the mapping should looks like _resource_id(A,C) = B.
-
- Args:
- resource (Resource): The *parent* Resourceb
- path_seg (str): The name of the child Resource to be attached.
- Returns:
- str: A unique string which can be a key to the child Resource.
- """
- return "%s-%s" % (resource, path_seg)
-
-
def run(hs):
PROFILE_SYNAPSE = False
if PROFILE_SYNAPSE:
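
Note: the inline twisted.conch boilerplate is replaced by a manhole() helper
in the new synapse.util.manhole module, which is not part of this diff. A
sketch of what it presumably wraps, reconstructed from the code deleted
above:

    from twisted.conch import manhole_ssh
    from twisted.conch.insults import insults
    from twisted.conch.manhole import ColoredManhole
    from twisted.cred import checkers, portal

    def manhole(username, password, globals):
        # Build an SSH factory serving a colored python REPL whose namespace
        # includes the caller-supplied globals (e.g. {"hs": hs}).
        checker = checkers.InMemoryUsernamePasswordDatabaseDontUse(
            **{username: password}
        )
        rlm = manhole_ssh.TerminalRealm()
        rlm.chainedProtocolFactory = lambda: insults.ServerProtocol(
            ColoredManhole,
            dict(globals, __name__="__console__"),
        )
        return manhole_ssh.ConchFactory(portal.Portal(rlm, [checker]))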
diff --git a/synapse/app/pusher.py b/synapse/app/pusher.py
new file mode 100644
index 0000000000..b5339f030d
--- /dev/null
+++ b/synapse/app/pusher.py
@@ -0,0 +1,315 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import synapse
+
+from synapse.server import HomeServer
+from synapse.config._base import ConfigError
+from synapse.config.database import DatabaseConfig
+from synapse.config.logger import LoggingConfig
+from synapse.http.site import SynapseSite
+from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
+from synapse.replication.slave.storage.events import SlavedEventStore
+from synapse.replication.slave.storage.pushers import SlavedPusherStore
+from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
+from synapse.storage.engines import create_engine
+from synapse.storage import DataStore
+from synapse.util.async import sleep
+from synapse.util.httpresourcetree import create_resource_tree
+from synapse.util.logcontext import LoggingContext, preserve_fn
+from synapse.util.manhole import manhole
+from synapse.util.rlimit import change_resource_limit
+from synapse.util.versionstring import get_version_string
+
+from twisted.internet import reactor, defer
+from twisted.web.resource import Resource
+
+from daemonize import Daemonize
+
+import sys
+import logging
+
+logger = logging.getLogger("synapse.app.pusher")
+
+
+class SlaveConfig(DatabaseConfig):
+ def read_config(self, config):
+ self.replication_url = config["replication_url"]
+ self.server_name = config["server_name"]
+ self.use_insecure_ssl_client_just_for_testing_do_not_use = config.get(
+ "use_insecure_ssl_client_just_for_testing_do_not_use", False
+ )
+ self.user_agent_suffix = None
+ self.start_pushers = True
+ self.listeners = config["listeners"]
+ self.soft_file_limit = config.get("soft_file_limit")
+ self.daemonize = config.get("daemonize")
+ self.pid_file = self.abspath(config.get("pid_file"))
+
+ def default_config(self, server_name, **kwargs):
+ pid_file = self.abspath("pusher.pid")
+ return """\
+ # Slave configuration
+
+ # The replication listener on the synapse to talk to.
+ #replication_url: https://localhost:{replication_port}/_synapse/replication
+
+ server_name: "%(server_name)s"
+
+ listeners: []
+ # Enable a ssh manhole listener on the pusher.
+ # - type: manhole
+ # port: {manhole_port}
+ # bind_address: 127.0.0.1
+ # Enable a metric listener on the pusher.
+ # - type: http
+ # port: {metrics_port}
+ # bind_address: 127.0.0.1
+ # resources:
+ # - names: ["metrics"]
+ # compress: False
+
+ report_stats: False
+
+ daemonize: False
+
+ pid_file: %(pid_file)s
+
+ """ % locals()
+
+
+class PusherSlaveConfig(SlaveConfig, LoggingConfig):
+ pass
+
+
+class PusherSlaveStore(
+ SlavedEventStore, SlavedPusherStore, SlavedReceiptsStore
+):
+ update_pusher_last_stream_ordering_and_success = (
+ DataStore.update_pusher_last_stream_ordering_and_success.__func__
+ )
+
+ update_pusher_failing_since = (
+ DataStore.update_pusher_failing_since.__func__
+ )
+
+ update_pusher_last_stream_ordering = (
+ DataStore.update_pusher_last_stream_ordering.__func__
+ )
+
+
+class PusherServer(HomeServer):
+
+ def get_db_conn(self, run_new_connection=True):
+ # Any param beginning with cp_ is a parameter for adbapi, and should
+ # not be passed to the database engine.
+ db_params = {
+ k: v for k, v in self.db_config.get("args", {}).items()
+ if not k.startswith("cp_")
+ }
+ db_conn = self.database_engine.module.connect(**db_params)
+
+ if run_new_connection:
+ self.database_engine.on_new_connection(db_conn)
+ return db_conn
+
+ def setup(self):
+ logger.info("Setting up.")
+ self.datastore = PusherSlaveStore(self.get_db_conn(), self)
+ logger.info("Finished setting up.")
+
+ def remove_pusher(self, app_id, push_key, user_id):
+ http_client = self.get_simple_http_client()
+ replication_url = self.config.replication_url
+ url = replication_url + "/remove_pushers"
+ return http_client.post_json_get_json(url, {
+ "remove": [{
+ "app_id": app_id,
+ "push_key": push_key,
+ "user_id": user_id,
+ }]
+ })
+
+ def _listen_http(self, listener_config):
+ port = listener_config["port"]
+ bind_address = listener_config.get("bind_address", "")
+ site_tag = listener_config.get("tag", port)
+ resources = {}
+ for res in listener_config["resources"]:
+ for name in res["names"]:
+ if name == "metrics":
+ resources[METRICS_PREFIX] = MetricsResource(self)
+
+ root_resource = create_resource_tree(resources, Resource())
+ reactor.listenTCP(
+ port,
+ SynapseSite(
+ "synapse.access.http.%s" % (site_tag,),
+ site_tag,
+ listener_config,
+ root_resource,
+ ),
+ interface=bind_address
+ )
+ logger.info("Synapse pusher now listening on port %d", port)
+
+ def start_listening(self):
+ for listener in self.config.listeners:
+ if listener["type"] == "http":
+ self._listen_http(listener)
+ elif listener["type"] == "manhole":
+ reactor.listenTCP(
+ listener["port"],
+ manhole(
+ username="matrix",
+ password="rabbithole",
+ globals={"hs": self},
+ ),
+ interface=listener.get("bind_address", '127.0.0.1')
+ )
+ else:
+ logger.warn("Unrecognized listener type: %s", listener["type"])
+
+ @defer.inlineCallbacks
+ def replicate(self):
+ http_client = self.get_simple_http_client()
+ store = self.get_datastore()
+ replication_url = self.config.replication_url
+ pusher_pool = self.get_pusherpool()
+
+ def stop_pusher(user_id, app_id, pushkey):
+ key = "%s:%s" % (app_id, pushkey)
+ pushers_for_user = pusher_pool.pushers.get(user_id, {})
+ pusher = pushers_for_user.pop(key, None)
+ if pusher is None:
+ return
+ logger.info("Stopping pusher %r / %r", user_id, key)
+ pusher.on_stop()
+
+ def start_pusher(user_id, app_id, pushkey):
+ key = "%s:%s" % (app_id, pushkey)
+ logger.info("Starting pusher %r / %r", user_id, key)
+ return pusher_pool._refresh_pusher(app_id, pushkey, user_id)
+
+ @defer.inlineCallbacks
+ def poke_pushers(results):
+ pushers_rows = set(
+ map(tuple, results.get("pushers", {}).get("rows", []))
+ )
+ deleted_pushers_rows = set(
+ map(tuple, results.get("deleted_pushers", {}).get("rows", []))
+ )
+ for row in sorted(pushers_rows | deleted_pushers_rows):
+ if row in deleted_pushers_rows:
+ user_id, app_id, pushkey = row[1:4]
+ stop_pusher(user_id, app_id, pushkey)
+ elif row in pushers_rows:
+ user_id = row[1]
+ app_id = row[5]
+ pushkey = row[8]
+ yield start_pusher(user_id, app_id, pushkey)
+
+ stream = results.get("events")
+ if stream:
+ min_stream_id = stream["rows"][0][0]
+ max_stream_id = stream["position"]
+ preserve_fn(pusher_pool.on_new_notifications)(
+ min_stream_id, max_stream_id
+ )
+
+ stream = results.get("receipts")
+ if stream:
+ rows = stream["rows"]
+ affected_room_ids = set(row[1] for row in rows)
+ min_stream_id = rows[0][0]
+ max_stream_id = stream["position"]
+ preserve_fn(pusher_pool.on_new_receipts)(
+ min_stream_id, max_stream_id, affected_room_ids
+ )
+
+ while True:
+ try:
+ args = store.stream_positions()
+ args["timeout"] = 30000
+ result = yield http_client.get_json(replication_url, args=args)
+ yield store.process_replication(result)
+ poke_pushers(result)
+ except:
+ logger.exception("Error replicating from %r", replication_url)
+ sleep(30)
+
+
+def setup(config_options):
+ try:
+ config = PusherSlaveConfig.load_config(
+ "Synapse pusher", config_options
+ )
+ except ConfigError as e:
+ sys.stderr.write("\n" + e.message + "\n")
+ sys.exit(1)
+
+ if not config:
+ sys.exit(0)
+
+ config.setup_logging()
+
+ database_engine = create_engine(config.database_config)
+
+ ps = PusherServer(
+ config.server_name,
+ db_config=config.database_config,
+ config=config,
+ version_string=get_version_string("Synapse", synapse),
+ database_engine=database_engine,
+ )
+
+ ps.setup()
+ ps.start_listening()
+
+ change_resource_limit(ps.config.soft_file_limit)
+
+ def start():
+ ps.replicate()
+ ps.get_pusherpool().start()
+ ps.get_datastore().start_profiling()
+
+ reactor.callWhenRunning(start)
+
+ return ps
+
+
+if __name__ == '__main__':
+ with LoggingContext("main"):
+ ps = setup(sys.argv[1:])
+
+ if ps.config.daemonize:
+ def run():
+ with LoggingContext("run"):
+ change_resource_limit(ps.config.soft_file_limit)
+ reactor.run()
+
+ daemon = Daemonize(
+ app="synapse-pusher",
+ pid=ps.config.pid_file,
+ action=run,
+ auto_close_fds=False,
+ verbose=True,
+ logger=logger,
+ )
+
+ daemon.start()
+ else:
+ reactor.run()
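
Note: this worker shares the homeserver's database but drives pushers from
the replication stream alone; it is run against its own config file,
presumably with something like `python -m synapse.app.pusher -c pusher.yaml`
(the -c/--config-path flag comes from synapse's standard Config.load_config
argument parser, assumed unchanged here).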
diff --git a/synapse/config/jwt.py b/synapse/config/jwt.py
index 4cb092bbec..47f145c589 100644
--- a/synapse/config/jwt.py
+++ b/synapse/config/jwt.py
@@ -13,7 +13,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import Config
+from ._base import Config, ConfigError
+
+
+MISSING_JWT = (
+ """Missing jwt library. This is required for jwt login.
+
+ Install by running:
+ pip install pyjwt
+ """
+)
class JWTConfig(Config):
@@ -23,6 +32,12 @@ class JWTConfig(Config):
self.jwt_enabled = jwt_config.get("enabled", False)
self.jwt_secret = jwt_config["secret"]
self.jwt_algorithm = jwt_config["algorithm"]
+
+ try:
+ import jwt
+ jwt # To silence the "unused import" lint.
+ except ImportError:
+ raise ConfigError(MISSING_JWT)
else:
self.jwt_enabled = False
self.jwt_secret = None
@@ -30,6 +45,8 @@ class JWTConfig(Config):
def default_config(self, **kwargs):
return """\
+ # The JWT needs to contain a globally unique "sub" (subject) claim.
+ #
# jwt_config:
# enabled: true
# secret: "a secret"
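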
diff --git a/synapse/config/server.py b/synapse/config/server.py
index 19af39da70..04b9221908 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -33,6 +33,7 @@ class ServerConfig(Config):
if self.public_baseurl is not None:
if self.public_baseurl[-1] != '/':
self.public_baseurl += '/'
+ self.start_pushers = config.get("start_pushers", True)
self.listeners = config.get("listeners", [])
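
Note: together with the PusherPool change further down, this flag lets a
deployment hand push sending over to the standalone pusher worker,
presumably via a line like the following in the homeserver config:

    start_pushers: false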
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index b69f36aefe..ed2cda837f 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -232,7 +232,7 @@ class RoomMemberHandler(BaseHandler):
if old_membership == "ban" and action != "unban":
raise SynapseError(
403,
- "Cannot %s user who was is banned" % (action,),
+ "Cannot %s user who was banned" % (action,),
errcode=Codes.BAD_STATE
)
diff --git a/synapse/http/client.py b/synapse/http/client.py
index 6c89b20984..902ae7a203 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -462,5 +462,8 @@ class InsecureInterceptableContextFactory(ssl.ContextFactory):
self._context = SSL.Context(SSL.SSLv23_METHOD)
self._context.set_verify(VERIFY_NONE, lambda *_: None)
- def getContext(self, hostname, port):
+ def getContext(self, hostname=None, port=None):
return self._context
+
+ def creatorForNetloc(self, hostname, port):
+ return self
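
Note: newer twisted.web.client.Agent versions ask their TLS policy for
creatorForNetloc(hostname, port) (the IPolicyForHTTPS interface) rather than
calling getContext directly; returning self lets this object serve as both.
A hedged usage sketch:

    from twisted.internet import reactor
    from twisted.web.client import Agent

    from synapse.http.client import InsecureInterceptableContextFactory

    # Certificate verification is disabled; this factory exists only for the
    # use_insecure_ssl_client_just_for_testing_do_not_use mode.
    agent = Agent(reactor, contextFactory=InsecureInterceptableContextFactory())
    d = agent.request("GET", "https://localhost:8448/")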
diff --git a/synapse/http/server.py b/synapse/http/server.py
index b82196fd5e..f705abab94 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -74,7 +74,12 @@ response_db_txn_duration = metrics.register_distribution(
_next_request_id = 0
-def request_handler(request_handler):
+def request_handler(report_metrics=True):
+ """Decorator for ``wrap_request_handler``"""
+ return lambda request_handler: wrap_request_handler(request_handler, report_metrics)
+
+
+def wrap_request_handler(request_handler, report_metrics):
"""Wraps a method that acts as a request handler with the necessary logging
and exception handling.
@@ -96,7 +101,12 @@ def request_handler(request_handler):
global _next_request_id
request_id = "%s-%s" % (request.method, _next_request_id)
_next_request_id += 1
+
with LoggingContext(request_id) as request_context:
+ if report_metrics:
+ request_metrics = RequestMetrics()
+ request_metrics.start(self.clock)
+
request_context.request = request_id
with request.processing():
try:
@@ -133,6 +143,14 @@ def request_handler(request_handler):
},
send_cors=True
)
+ finally:
+ try:
+ if report_metrics:
+ request_metrics.stop(
+ self.clock, request, self.__class__.__name__
+ )
+ except:
+ pass
return wrapped_request_handler
@@ -197,19 +215,23 @@ class JsonResource(HttpServer, resource.Resource):
self._async_render(request)
return server.NOT_DONE_YET
- @request_handler
+ # Disable generic metric reporting because _async_render does its own:
+ # it dispatches to a callback, and it is the class name of that callback
+ # we want to report against rather than the JsonResource itself.
+ @request_handler(report_metrics=False)
@defer.inlineCallbacks
def _async_render(self, request):
""" This gets called from render() every time someone sends us a request.
This checks if anyone has registered a callback for that method and
path.
"""
- start = self.clock.time_msec()
if request.method == "OPTIONS":
self._send_response(request, 200, {})
return
- start_context = LoggingContext.current_context()
+ request_metrics = RequestMetrics()
+ request_metrics.start(self.clock)
# Loop through all the registered callbacks to check if the method
# and path regex match
@@ -241,40 +263,7 @@ class JsonResource(HttpServer, resource.Resource):
self._send_response(request, code, response)
try:
- context = LoggingContext.current_context()
-
- tag = ""
- if context:
- tag = context.tag
-
- if context != start_context:
- logger.warn(
- "Context have unexpectedly changed %r, %r",
- context, self.start_context
- )
- return
-
- incoming_requests_counter.inc(request.method, servlet_classname, tag)
-
- response_timer.inc_by(
- self.clock.time_msec() - start, request.method,
- servlet_classname, tag
- )
-
- ru_utime, ru_stime = context.get_resource_usage()
-
- response_ru_utime.inc_by(
- ru_utime, request.method, servlet_classname, tag
- )
- response_ru_stime.inc_by(
- ru_stime, request.method, servlet_classname, tag
- )
- response_db_txn_count.inc_by(
- context.db_txn_count, request.method, servlet_classname, tag
- )
- response_db_txn_duration.inc_by(
- context.db_txn_duration, request.method, servlet_classname, tag
- )
+ request_metrics.stop(self.clock, request, servlet_classname)
except:
pass
@@ -307,6 +296,48 @@ class JsonResource(HttpServer, resource.Resource):
)
+class RequestMetrics(object):
+ def start(self, clock):
+ self.start = clock.time_msec()
+ self.start_context = LoggingContext.current_context()
+
+ def stop(self, clock, request, servlet_classname):
+ context = LoggingContext.current_context()
+
+ tag = ""
+ if context:
+ tag = context.tag
+
+ if context != self.start_context:
+ logger.warn(
+ "Context have unexpectedly changed %r, %r",
+ context, self.start_context
+ )
+ return
+
+ incoming_requests_counter.inc(request.method, servlet_classname, tag)
+
+ response_timer.inc_by(
+ clock.time_msec() - self.start, request.method,
+ servlet_classname, tag
+ )
+
+ ru_utime, ru_stime = context.get_resource_usage()
+
+ response_ru_utime.inc_by(
+ ru_utime, request.method, servlet_classname, tag
+ )
+ response_ru_stime.inc_by(
+ ru_stime, request.method, servlet_classname, tag
+ )
+ response_db_txn_count.inc_by(
+ context.db_txn_count, request.method, servlet_classname, tag
+ )
+ response_db_txn_duration.inc_by(
+ context.db_txn_duration, request.method, servlet_classname, tag
+ )
+
+
class RootRedirect(resource.Resource):
"""Redirects the root '/' path to another path."""
diff --git a/synapse/http/site.py b/synapse/http/site.py
new file mode 100644
index 0000000000..4b09d7ee66
--- /dev/null
+++ b/synapse/http/site.py
@@ -0,0 +1,146 @@
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.util.logcontext import LoggingContext
+from twisted.web.server import Site, Request
+
+import contextlib
+import logging
+import re
+import time
+
+ACCESS_TOKEN_RE = re.compile(r'(\?.*access(_|%5[Ff])token=)[^&]*(.*)$')
+
+
+class SynapseRequest(Request):
+ def __init__(self, site, *args, **kw):
+ Request.__init__(self, *args, **kw)
+ self.site = site
+ self.authenticated_entity = None
+ self.start_time = 0
+
+ def __repr__(self):
+ # We overwrite this so that we don't log ``access_token``
+ return '<%s at 0x%x method=%s uri=%s clientproto=%s site=%s>' % (
+ self.__class__.__name__,
+ id(self),
+ self.method,
+ self.get_redacted_uri(),
+ self.clientproto,
+ self.site.site_tag,
+ )
+
+ def get_redacted_uri(self):
+ return ACCESS_TOKEN_RE.sub(
+ r'\1<redacted>\3',
+ self.uri
+ )
+
+ def get_user_agent(self):
+ return self.requestHeaders.getRawHeaders("User-Agent", [None])[-1]
+
+ def started_processing(self):
+ self.site.access_logger.info(
+ "%s - %s - Received request: %s %s",
+ self.getClientIP(),
+ self.site.site_tag,
+ self.method,
+ self.get_redacted_uri()
+ )
+ self.start_time = int(time.time() * 1000)
+
+ def finished_processing(self):
+
+ try:
+ context = LoggingContext.current_context()
+ ru_utime, ru_stime = context.get_resource_usage()
+ db_txn_count = context.db_txn_count
+ db_txn_duration = context.db_txn_duration
+ except:
+ ru_utime, ru_stime = (0, 0)
+ db_txn_count, db_txn_duration = (0, 0)
+
+ self.site.access_logger.info(
+ "%s - %s - {%s}"
+ " Processed request: %dms (%dms, %dms) (%dms/%d)"
+ " %sB %s \"%s %s %s\" \"%s\"",
+ self.getClientIP(),
+ self.site.site_tag,
+ self.authenticated_entity,
+ int(time.time() * 1000) - self.start_time,
+ int(ru_utime * 1000),
+ int(ru_stime * 1000),
+ int(db_txn_duration * 1000),
+ int(db_txn_count),
+ self.sentLength,
+ self.code,
+ self.method,
+ self.get_redacted_uri(),
+ self.clientproto,
+ self.get_user_agent(),
+ )
+
+ @contextlib.contextmanager
+ def processing(self):
+ self.started_processing()
+ yield
+ self.finished_processing()
+
+
+class XForwardedForRequest(SynapseRequest):
+ """
+ Add a layer on top of another request that only uses the value of an
+ X-Forwarded-For header as the result of C{getClientIP}.
+ """
+ def __init__(self, *args, **kw):
+ SynapseRequest.__init__(self, *args, **kw)
+
+ def getClientIP(self):
+ """
+ @return: The client address (the first address) in the value of the
+ I{X-Forwarded-For header}. If the header is not present, return
+ C{b"-"}.
+ """
+ return self.requestHeaders.getRawHeaders(
+ b"x-forwarded-for", [b"-"])[0].split(b",")[0].strip()
+
+
+class SynapseRequestFactory(object):
+ def __init__(self, site, x_forwarded_for):
+ self.site = site
+ self.x_forwarded_for = x_forwarded_for
+
+ def __call__(self, *args, **kwargs):
+ if self.x_forwarded_for:
+ return XForwardedForRequest(self.site, *args, **kwargs)
+ else:
+ return SynapseRequest(self.site, *args, **kwargs)
+
+
+class SynapseSite(Site):
+ """
+ Subclass of a twisted http Site that does access logging with python's
+ standard logging
+ """
+ def __init__(self, logger_name, site_tag, config, resource, *args, **kwargs):
+ Site.__init__(self, resource, *args, **kwargs)
+
+ self.site_tag = site_tag
+
+ proxied = config.get("x_forwarded", False)
+ self.requestFactory = SynapseRequestFactory(self, proxied)
+ self.access_logger = logging.getLogger(logger_name)
+
+ def log(self, request):
+ pass
diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index 6950a20632..3992804845 100644
--- a/synapse/push/httppusher.py
+++ b/synapse/push/httppusher.py
@@ -230,7 +230,7 @@ class HttpPusher(object):
"Pushkey %s was rejected: removing",
pk
)
- yield self.hs.get_pusherpool().remove_pusher(
+ yield self.hs.remove_pusher(
self.app_id, pk, self.user_id
)
defer.returnValue(True)
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index aa095f9d9b..6ef48d63f7 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -29,6 +29,7 @@ logger = logging.getLogger(__name__)
class PusherPool:
def __init__(self, _hs):
self.hs = _hs
+ self.start_pushers = _hs.config.start_pushers
self.store = self.hs.get_datastore()
self.clock = self.hs.get_clock()
self.pushers = {}
@@ -177,6 +178,9 @@ class PusherPool:
self._start_pushers([p])
def _start_pushers(self, pushers):
+ if not self.start_pushers:
+ logger.info("Not starting pushers because they are disabled in the config")
+ return
logger.info("Starting %d pushers", len(pushers))
for pusherdict in pushers:
try:
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index 618f3c43ab..e0a7a19777 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -36,7 +36,6 @@ REQUIREMENTS = {
"blist": ["blist"],
"pysaml2>=3.0.0,<4.0.0": ["saml2>=3.0.0,<4.0.0"],
"pymacaroons-pynacl": ["pymacaroons"],
- "pyjwt": ["jwt"],
}
CONDITIONAL_REQUIREMENTS = {
"web_client": {
diff --git a/synapse/replication/pusher_resource.py b/synapse/replication/pusher_resource.py
new file mode 100644
index 0000000000..9b01ab3c13
--- /dev/null
+++ b/synapse/replication/pusher_resource.py
@@ -0,0 +1,54 @@
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.http.server import respond_with_json_bytes, request_handler
+from synapse.http.servlet import parse_json_object_from_request
+
+from twisted.web.resource import Resource
+from twisted.web.server import NOT_DONE_YET
+from twisted.internet import defer
+
+
+class PusherResource(Resource):
+ """
+ HTTP endpoint for deleting rejected pushers
+ """
+
+ def __init__(self, hs):
+ Resource.__init__(self) # Resource is old-style, so no super()
+
+ self.version_string = hs.version_string
+ self.store = hs.get_datastore()
+ self.notifier = hs.get_notifier()
+ self.clock = hs.get_clock()
+
+ def render_POST(self, request):
+ self._async_render_POST(request)
+ return NOT_DONE_YET
+
+ @request_handler()
+ @defer.inlineCallbacks
+ def _async_render_POST(self, request):
+ content = parse_json_object_from_request(request)
+
+ for remove in content["remove"]:
+ yield self.store.delete_pusher_by_app_id_pushkey_user_id(
+ remove["app_id"],
+ remove["push_key"],
+ remove["user_id"],
+ )
+
+ self.notifier.on_new_replication_data()
+
+ respond_with_json_bytes(request, 200, "{}")
diff --git a/synapse/replication/resource.py b/synapse/replication/resource.py
index a543af68f8..ff78c60f13 100644
--- a/synapse/replication/resource.py
+++ b/synapse/replication/resource.py
@@ -15,6 +15,7 @@
from synapse.http.servlet import parse_integer, parse_string
from synapse.http.server import request_handler, finish_request
+from synapse.replication.pusher_resource import PusherResource
from twisted.web.resource import Resource
from twisted.web.server import NOT_DONE_YET
@@ -102,8 +103,6 @@ class ReplicationResource(Resource):
long-polling this replication API for new data on those streams.
"""
- isLeaf = True
-
def __init__(self, hs):
Resource.__init__(self) # Resource is old-style, so no super()
@@ -113,6 +112,9 @@ class ReplicationResource(Resource):
self.presence_handler = hs.get_handlers().presence_handler
self.typing_handler = hs.get_handlers().typing_notification_handler
self.notifier = hs.notifier
+ self.clock = hs.get_clock()
+
+ self.putChild("remove_pushers", PusherResource(hs))
def render_GET(self, request):
self._async_render_GET(request)
@@ -138,7 +140,7 @@ class ReplicationResource(Resource):
state_token,
))
- @request_handler
+ @request_handler()
@defer.inlineCallbacks
def _async_render_GET(self, request):
limit = parse_integer(request, "limit", 100)
@@ -343,7 +345,7 @@ class ReplicationResource(Resource):
"app_id", "app_display_name", "device_display_name", "pushkey",
"ts", "lang", "data"
))
- writer.write_header_and_rows("deleted", deleted, (
+ writer.write_header_and_rows("deleted_pushers", deleted, (
"position", "user_id", "app_id", "pushkey"
))
@@ -381,7 +383,7 @@ class _Writer(object):
position = rows[-1][0]
self.streams[name] = {
- "position": str(position),
+ "position": position if type(position) is int else str(position),
"field_names": fields,
"rows": rows,
}
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index cfc728a038..86f00b6ff5 100644
--- a/synapse/replication/slave/storage/events.py
+++ b/synapse/replication/slave/storage/events.py
@@ -21,6 +21,7 @@ from synapse.storage import DataStore
from synapse.storage.room import RoomStore
from synapse.storage.roommember import RoomMemberStore
from synapse.storage.event_federation import EventFederationStore
+from synapse.storage.event_push_actions import EventPushActionsStore
from synapse.storage.state import StateStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -68,7 +69,19 @@ class SlavedEventStore(BaseSlavedStore):
_get_current_state_for_key = StateStore.__dict__[
"_get_current_state_for_key"
]
+ get_invited_rooms_for_user = RoomMemberStore.__dict__[
+ "get_invited_rooms_for_user"
+ ]
+ get_unread_event_push_actions_by_room_for_user = (
+ EventPushActionsStore.__dict__["get_unread_event_push_actions_by_room_for_user"]
+ )
+ get_unread_push_actions_for_user_in_range = (
+ DataStore.get_unread_push_actions_for_user_in_range.__func__
+ )
+ get_push_action_users_in_range = (
+ DataStore.get_push_action_users_in_range.__func__
+ )
get_event = DataStore.get_event.__func__
get_current_state = DataStore.get_current_state.__func__
get_current_state_for_key = DataStore.get_current_state_for_key.__func__
@@ -82,6 +95,7 @@ class SlavedEventStore(BaseSlavedStore):
get_room_events_stream_for_room = (
DataStore.get_room_events_stream_for_room.__func__
)
+
_set_before_and_after = DataStore._set_before_and_after
_get_events = DataStore._get_events.__func__
@@ -104,7 +118,7 @@ class SlavedEventStore(BaseSlavedStore):
def stream_positions(self):
result = super(SlavedEventStore, self).stream_positions()
result["events"] = self._stream_id_gen.get_current_token()
- result["backfill"] = self._backfill_id_gen.get_current_token()
+ result["backfill"] = -self._backfill_id_gen.get_current_token()
return result
def process_replication(self, result):
@@ -122,7 +136,7 @@ class SlavedEventStore(BaseSlavedStore):
stream = result.get("backfill")
if stream:
- self._backfill_id_gen.advance(stream["position"])
+ self._backfill_id_gen.advance(-stream["position"])
for row in stream["rows"]:
self._process_replication_row(
row, backfilled=True, state_resets=state_resets
@@ -147,11 +161,11 @@ class SlavedEventStore(BaseSlavedStore):
internal = json.loads(row[1])
event_json = json.loads(row[2])
event = FrozenEvent(event_json, internal_metadata_dict=internal)
- self._invalidate_caches_for_event(
+ self.invalidate_caches_for_event(
event, backfilled, reset_state=position in state_resets
)
- def _invalidate_caches_for_event(self, event, backfilled, reset_state):
+ def invalidate_caches_for_event(self, event, backfilled, reset_state):
if reset_state:
self._get_current_state_for_key.invalidate_all()
self.get_rooms_for_user.invalidate_all()
@@ -163,6 +177,10 @@ class SlavedEventStore(BaseSlavedStore):
self.get_latest_event_ids_in_room.invalidate((event.room_id,))
+ self.get_unread_event_push_actions_by_room_for_user.invalidate_many(
+ (event.room_id,)
+ )
+
if not backfilled:
self._events_stream_cache.entity_has_changed(
event.room_id, event.internal_metadata.stream_ordering
@@ -182,6 +200,7 @@ class SlavedEventStore(BaseSlavedStore):
# self._membership_stream_cache.entity_has_changed(
# event.state_key, event.internal_metadata.stream_ordering
# )
+ self.get_invited_rooms_for_user.invalidate((event.state_key,))
if not event.is_state():
return
diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py
new file mode 100644
index 0000000000..8faddb2595
--- /dev/null
+++ b/synapse/replication/slave/storage/pushers.py
@@ -0,0 +1,52 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ._base import BaseSlavedStore
+from ._slaved_id_tracker import SlavedIdTracker
+
+from synapse.storage import DataStore
+
+
+class SlavedPusherStore(BaseSlavedStore):
+
+ def __init__(self, db_conn, hs):
+ super(SlavedPusherStore, self).__init__(db_conn, hs)
+ self._pushers_id_gen = SlavedIdTracker(
+ db_conn, "pushers", "id",
+ extra_tables=[("deleted_pushers", "stream_id")],
+ )
+
+ get_all_pushers = DataStore.get_all_pushers.__func__
+ get_pushers_by = DataStore.get_pushers_by.__func__
+ get_pushers_by_app_id_and_pushkey = (
+ DataStore.get_pushers_by_app_id_and_pushkey.__func__
+ )
+ _decode_pushers_rows = DataStore._decode_pushers_rows.__func__
+
+ def stream_positions(self):
+ result = super(SlavedPusherStore, self).stream_positions()
+ result["pushers"] = self._pushers_id_gen.get_current_token()
+ return result
+
+ def process_replication(self, result):
+ stream = result.get("pushers")
+ if stream:
+ self._pushers_id_gen.advance(stream["position"])
+
+ stream = result.get("deleted_pushers")
+ if stream:
+ self._pushers_id_gen.advance(stream["position"])
+
+ return super(SlavedPusherStore, self).process_replication(result)
diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py
new file mode 100644
index 0000000000..b55d5dfd08
--- /dev/null
+++ b/synapse/replication/slave/storage/receipts.py
@@ -0,0 +1,61 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ._base import BaseSlavedStore
+from ._slaved_id_tracker import SlavedIdTracker
+
+from synapse.storage import DataStore
+from synapse.storage.receipts import ReceiptsStore
+
+# So, um, we want to borrow a load of functions intended for reading from
+# a DataStore, but we don't want to take functions that either write to the
+# DataStore or are cached and don't have cache invalidation logic.
+#
+# Rather than write duplicate versions of those functions, or lift them to
+# a common base class, we're going to grab the underlying __func__ object from
+# the method descriptor on the DataStore and chuck them into our class.
+
+
+class SlavedReceiptsStore(BaseSlavedStore):
+
+ def __init__(self, db_conn, hs):
+ super(SlavedReceiptsStore, self).__init__(db_conn, hs)
+
+ self._receipts_id_gen = SlavedIdTracker(
+ db_conn, "receipts_linearized", "stream_id"
+ )
+
+ get_receipts_for_user = ReceiptsStore.__dict__["get_receipts_for_user"]
+
+ get_max_receipt_stream_id = DataStore.get_max_receipt_stream_id.__func__
+ get_all_updated_receipts = DataStore.get_all_updated_receipts.__func__
+
+ def stream_positions(self):
+ result = super(SlavedReceiptsStore, self).stream_positions()
+ result["receipts"] = self._receipts_id_gen.get_current_token()
+ return result
+
+ def process_replication(self, result):
+ stream = result.get("receipts")
+ if stream:
+ self._receipts_id_gen.advance(stream["position"])
+ for row in stream["rows"]:
+ room_id, receipt_type, user_id = row[1:4]
+ self.invalidate_caches_for_receipt(room_id, receipt_type, user_id)
+
+ return super(SlavedReceiptsStore, self).process_replication(result)
+
+ def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id):
+ self.get_receipts_for_user.invalidate((user_id, receipt_type))
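
Note: the borrowing idiom the comment above describes, in isolation (Python 2
semantics, matching the py27 tox targets used by the jenkins scripts).
Cached methods like get_receipts_for_user are instead taken via
ReceiptsStore.__dict__, which returns the cache descriptor itself so
invalidate() still works:

    class Base(object):
        def greet(self):
            return "hello from %s" % (type(self).__name__,)

    class Borrower(object):
        # Base.greet is an unbound method in Python 2; .__func__ unwraps the
        # plain function, which then binds as a normal method on Borrower.
        greet = Base.greet.__func__

    print(Borrower().greet())  # hello from Borrower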
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index d14ce3efa2..3b5544851b 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -33,9 +33,6 @@ from saml2.client import Saml2Client
import xml.etree.ElementTree as ET
-import jwt
-from jwt.exceptions import InvalidTokenError
-
logger = logging.getLogger(__name__)
@@ -224,16 +221,24 @@ class LoginRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def do_jwt_login(self, login_submission):
- token = login_submission['token']
+ token = login_submission.get("token", None)
if token is None:
- raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
+ raise LoginError(
+ 401, "Token field for JWT is missing",
+ errcode=Codes.UNAUTHORIZED
+ )
+
+ import jwt
+ from jwt.exceptions import InvalidTokenError
try:
payload = jwt.decode(token, self.jwt_secret, algorithms=[self.jwt_algorithm])
+ except jwt.ExpiredSignatureError:
+ raise LoginError(401, "JWT expired", errcode=Codes.UNAUTHORIZED)
except InvalidTokenError:
raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED)
- user = payload['user']
+ user = payload.get("sub", None)
if user is None:
raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED)
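
Note: a hedged sketch of minting a token this flow will accept, using the
same pyjwt library imported above; the secret and algorithm must match the
homeserver's jwt_config, and the "sub" claim carries the user:

    import jwt

    token = jwt.encode({"sub": "frank"}, "a secret", algorithm="HS256")
    # submit as the "token" field of the login submission that
    # do_jwt_login above receives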
diff --git a/synapse/rest/key/v1/server_key_resource.py b/synapse/rest/key/v1/server_key_resource.py
index 3db3838b7e..bd4fea5774 100644
--- a/synapse/rest/key/v1/server_key_resource.py
+++ b/synapse/rest/key/v1/server_key_resource.py
@@ -49,7 +49,6 @@ class LocalKey(Resource):
"""
def __init__(self, hs):
- self.hs = hs
self.version_string = hs.version_string
self.response_body = encode_canonical_json(
self.response_json_object(hs.config)
diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
index 9552016fec..7209d5a37d 100644
--- a/synapse/rest/key/v2/remote_key_resource.py
+++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -97,7 +97,7 @@ class RemoteKey(Resource):
self.async_render_GET(request)
return NOT_DONE_YET
- @request_handler
+ @request_handler()
@defer.inlineCallbacks
def async_render_GET(self, request):
if len(request.postpath) == 1:
@@ -122,7 +122,7 @@ class RemoteKey(Resource):
self.async_render_POST(request)
return NOT_DONE_YET
- @request_handler
+ @request_handler()
@defer.inlineCallbacks
def async_render_POST(self, request):
content = parse_json_object_from_request(request)
diff --git a/synapse/rest/media/v1/download_resource.py b/synapse/rest/media/v1/download_resource.py
index 510884262c..9f69620772 100644
--- a/synapse/rest/media/v1/download_resource.py
+++ b/synapse/rest/media/v1/download_resource.py
@@ -36,12 +36,13 @@ class DownloadResource(Resource):
self.server_name = hs.hostname
self.store = hs.get_datastore()
self.version_string = hs.version_string
+ self.clock = hs.get_clock()
def render_GET(self, request):
self._async_render_GET(request)
return NOT_DONE_YET
- @request_handler
+ @request_handler()
@defer.inlineCallbacks
def _async_render_GET(self, request):
server_name, media_id, name = parse_media_id(request)
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index fecdf8ed86..dc1e5fbdb3 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -45,7 +45,17 @@ class PreviewUrlResource(Resource):
def __init__(self, hs, media_repo):
Resource.__init__(self)
+
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.version_string = hs.version_string
+ self.filepaths = media_repo.filepaths
+ self.max_spider_size = hs.config.max_spider_size
+ self.server_name = hs.hostname
+ self.store = hs.get_datastore()
self.client = SpiderHttpClient(hs)
+ self.media_repo = media_repo
+
if hasattr(hs.config, "url_preview_url_blacklist"):
self.url_preview_url_blacklist = hs.config.url_preview_url_blacklist
@@ -60,18 +70,11 @@ class PreviewUrlResource(Resource):
self.downloads = {}
- self.auth = hs.get_auth()
- self.clock = hs.get_clock()
- self.version_string = hs.version_string
- self.filepaths = media_repo.filepaths
- self.max_spider_size = hs.config.max_spider_size
- self.server_name = hs.hostname
-
def render_GET(self, request):
self._async_render_GET(request)
return NOT_DONE_YET
- @request_handler
+ @request_handler()
@defer.inlineCallbacks
def _async_render_GET(self, request):
@@ -368,7 +371,7 @@ class PreviewUrlResource(Resource):
file_id = random_string(24)
fname = self.filepaths.local_media_filepath(file_id)
- self._makedirs(fname)
+ self.media_repo._makedirs(fname)
try:
with open(fname, "wb") as f:
diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py
index 234dd4261c..0b9e1de1a7 100644
--- a/synapse/rest/media/v1/thumbnail_resource.py
+++ b/synapse/rest/media/v1/thumbnail_resource.py
@@ -39,12 +39,13 @@ class ThumbnailResource(Resource):
self.dynamic_thumbnails = hs.config.dynamic_thumbnails
self.server_name = hs.hostname
self.version_string = hs.version_string
+ self.clock = hs.get_clock()
def render_GET(self, request):
self._async_render_GET(request)
return NOT_DONE_YET
- @request_handler
+ @request_handler()
@defer.inlineCallbacks
def _async_render_GET(self, request):
server_name, media_id, _ = parse_media_id(request)
diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py
index 299e1f6e56..b716d1d892 100644
--- a/synapse/rest/media/v1/upload_resource.py
+++ b/synapse/rest/media/v1/upload_resource.py
@@ -41,6 +41,7 @@ class UploadResource(Resource):
self.auth = hs.get_auth()
self.max_upload_size = hs.config.max_upload_size
self.version_string = hs.version_string
+ self.clock = hs.get_clock()
def render_POST(self, request):
self._async_render_POST(request)
@@ -50,7 +51,7 @@ class UploadResource(Resource):
respond_with_json(request, 200, {}, send_cors=True)
return NOT_DONE_YET
- @request_handler
+ @request_handler()
@defer.inlineCallbacks
def _async_render_POST(self, request):
requester = yield self.auth.get_user_by_req(request)
diff --git a/synapse/server.py b/synapse/server.py
index 368d615576..ee138de756 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -193,6 +193,9 @@ class HomeServer(object):
**self.db_config.get("args", {})
)
+ def remove_pusher(self, app_id, push_key, user_id):
+ return self.get_pusherpool().remove_pusher(app_id, push_key, user_id)
+
def _make_dependency_method(depname):
def _get(hs):
diff --git a/synapse/state.py b/synapse/state.py
index 58211f5feb..d0f76dc4f5 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -214,7 +214,7 @@ class StateHandler(object):
if self._state_cache is not None:
cache = self._state_cache.get(group_names, None)
- if cache and cache.state_group:
+ if cache:
cache.ts = self.clock.time_msec()
event_dict = yield self.store.get_events(cache.state.values())
@@ -230,22 +230,34 @@ class StateHandler(object):
(cache.state_group, state, prev_states)
)
+ logger.info("Resolving state for %s with %d groups", room_id, len(state_groups))
+
new_state, prev_states = self._resolve_events(
state_groups.values(), event_type, state_key
)
+ state_group = None
+ new_state_event_ids = frozenset(e.event_id for e in new_state.values())
+ for sg, events in state_groups.items():
+ if new_state_event_ids == frozenset(e.event_id for e in events):
+ state_group = sg
+ break
+
if self._state_cache is not None:
cache = _StateCacheEntry(
state={key: event.event_id for key, event in new_state.items()},
- state_group=None,
+ state_group=state_group,
ts=self.clock.time_msec()
)
self._state_cache[group_names] = cache
- defer.returnValue((None, new_state, prev_states))
+ defer.returnValue((state_group, new_state, prev_states))
def resolve_events(self, state_sets, event):
+ logger.info(
+ "Resolving state for %s with %d groups", event.room_id, len(state_sets)
+ )
if event.is_state():
return self._resolve_events(
state_sets, event.type, event.state_key
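
The hunk above changes resolution to reuse an existing state group whenever the resolved state turns out to be identical to one of the input groups, instead of always storing state_group=None. The check compares event-ID sets because the resolved state and a candidate group can hold the same events in different orders; a frozenset comparison is order-insensitive and duplicate-free. A toy illustration (the event IDs are invented):

    resolved_ids = ["$b:hs", "$a:hs"]
    group_42_ids = ["$a:hs", "$b:hs"]

    assert resolved_ids != group_42_ids                        # order differs
    assert frozenset(resolved_ids) == frozenset(group_42_ids)  # same members
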
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index 49904046cf..66a995157d 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -173,11 +173,12 @@ class BackgroundUpdateStore(SQLBaseStore):
logger.info(
"Updating %r. Updated %r items in %rms."
- " (total_rate=%r/ms, current_rate=%r/ms, total_updated=%r)",
+ " (total_rate=%r/ms, current_rate=%r/ms, total_updated=%r, batch_size=%r)",
update_name, items_updated, duration_ms,
performance.total_items_per_ms(),
performance.average_items_per_ms(),
performance.total_item_count,
+ batch_size,
)
performance.update(items_updated, duration_ms)
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index dd58e001dc..438eef6ba3 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -1145,6 +1145,12 @@ class EventsStore(SQLBaseStore):
current_backfill_id, current_forward_id, limit):
"""Get all the new events that have arrived at the server either as
new events or as backfilled events"""
+ have_backfill_events = last_backfill_id != current_backfill_id
+ have_forward_events = last_forward_id != current_forward_id
+
+ if not have_backfill_events and not have_forward_events:
+ return defer.succeed(AllNewEventsResult([], [], [], [], []))
+
def get_all_new_events_txn(txn):
sql = (
"SELECT e.stream_ordering, ej.internal_metadata, ej.json, eg.state_group"
@@ -1157,7 +1163,7 @@ class EventsStore(SQLBaseStore):
" ORDER BY e.stream_ordering ASC"
" LIMIT ?"
)
- if last_forward_id != current_forward_id:
+ if have_forward_events:
txn.execute(sql, (last_forward_id, current_forward_id, limit))
new_forward_events = txn.fetchall()
@@ -1201,7 +1207,7 @@ class EventsStore(SQLBaseStore):
" ORDER BY e.stream_ordering DESC"
" LIMIT ?"
)
- if last_backfill_id != current_backfill_id:
+ if have_backfill_events:
txn.execute(sql, (-last_backfill_id, -current_backfill_id, limit))
new_backfill_events = txn.fetchall()
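
The have_forward_events/have_backfill_events guard above is one instance of a pattern the pusher and receipts hunks below repeat: a replication getter must always return a Deferred, so when the stream token has not advanced it returns an already-fired Deferred rather than opening a database transaction. A minimal sketch of the pattern, with illustrative names that are not from this diff:

    from twisted.internet import defer

    def get_all_updated_things(self, last_id, current_id, limit):
        if last_id == current_id:
            # Nothing changed since the caller last polled; skip the txn.
            return defer.succeed([])

        def get_all_updated_things_txn(txn):
            txn.execute("SELECT ...")  # query elided in this sketch
            return txn.fetchall()

        return self.runInteraction(
            "get_all_updated_things", get_all_updated_things_txn
        )
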
diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py
index 5fb47d418a..d9afd7ec87 100644
--- a/synapse/storage/pusher.py
+++ b/synapse/storage/pusher.py
@@ -106,6 +106,9 @@ class PusherStore(SQLBaseStore):
return self._pushers_id_gen.get_current_token()
def get_all_updated_pushers(self, last_id, current_id, limit):
+ if last_id == current_id:
+ return defer.succeed(([], []))
+
def get_all_updated_pushers_txn(txn):
sql = (
"SELECT id, user_name, access_token, profile_tag, kind,"
diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py
index 3b8805593e..935fc503d9 100644
--- a/synapse/storage/receipts.py
+++ b/synapse/storage/receipts.py
@@ -391,6 +391,9 @@ class ReceiptsStore(SQLBaseStore):
)
def get_all_updated_receipts(self, last_id, current_id, limit=None):
+ if last_id == current_id:
+ return defer.succeed([])
+
def get_all_updated_receipts_txn(txn):
sql = (
"SELECT stream_id, room_id, receipt_type, user_id, event_id, data"
diff --git a/synapse/storage/room.py b/synapse/storage/room.py
index 9be977f387..70aa64fb31 100644
--- a/synapse/storage/room.py
+++ b/synapse/storage/room.py
@@ -169,20 +169,28 @@ class RoomStore(SQLBaseStore):
def _store_event_search_txn(self, txn, event, key, value):
if isinstance(self.database_engine, PostgresEngine):
sql = (
- "INSERT INTO event_search (event_id, room_id, key, vector)"
- " VALUES (?,?,?,to_tsvector('english', ?))"
+ "INSERT INTO event_search"
+ " (event_id, room_id, key, vector, stream_ordering, origin_server_ts)"
+ " VALUES (?,?,?,to_tsvector('english', ?),?,?)"
+ )
+ txn.execute(
+ sql,
+ (
+ event.event_id, event.room_id, key, value,
+ event.internal_metadata.stream_ordering,
+ event.origin_server_ts,
+ )
)
elif isinstance(self.database_engine, Sqlite3Engine):
sql = (
"INSERT INTO event_search (event_id, room_id, key, value)"
" VALUES (?,?,?,?)"
)
+ txn.execute(sql, (event.event_id, event.room_id, key, value,))
else:
# This should be unreachable.
raise Exception("Unrecognized database engine")
- txn.execute(sql, (event.event_id, event.room_id, key, value,))
-
@cachedInlineCallbacks()
def get_room_name_and_aliases(self, room_id):
def f(txn):
diff --git a/synapse/storage/schema/delta/31/search_update.py b/synapse/storage/schema/delta/31/search_update.py
new file mode 100644
index 0000000000..470ae0c005
--- /dev/null
+++ b/synapse/storage/schema/delta/31/search_update.py
@@ -0,0 +1,65 @@
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.storage.engines import PostgresEngine
+from synapse.storage.prepare_database import get_statements
+
+import logging
+import ujson
+
+logger = logging.getLogger(__name__)
+
+
+ALTER_TABLE = """
+ALTER TABLE event_search ADD COLUMN origin_server_ts BIGINT;
+ALTER TABLE event_search ADD COLUMN stream_ordering BIGINT;
+"""
+
+
+def run_create(cur, database_engine, *args, **kwargs):
+ if not isinstance(database_engine, PostgresEngine):
+ return
+
+ for statement in get_statements(ALTER_TABLE.splitlines()):
+ cur.execute(statement)
+
+ cur.execute("SELECT MIN(stream_ordering) FROM events")
+ rows = cur.fetchall()
+ min_stream_id = rows[0][0]
+
+ cur.execute("SELECT MAX(stream_ordering) FROM events")
+ rows = cur.fetchall()
+ max_stream_id = rows[0][0]
+
+ if min_stream_id is not None and max_stream_id is not None:
+ progress = {
+ "target_min_stream_id_inclusive": min_stream_id,
+ "max_stream_id_exclusive": max_stream_id + 1,
+ "rows_inserted": 0,
+ "have_added_indexes": False,
+ }
+ progress_json = ujson.dumps(progress)
+
+ sql = (
+ "INSERT into background_updates (update_name, progress_json)"
+ " VALUES (?, ?)"
+ )
+
+ sql = database_engine.convert_param_style(sql)
+
+ cur.execute(sql, ("event_search_order", progress_json))
+
+
+def run_upgrade(cur, database_engine, *args, **kwargs):
+ pass
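
The row written by this delta is what schedules the back-fill: the background-update machinery reads background_updates, matches update_name against handlers registered in code (the synapse/storage/search.py hunk below registers "event_search_order"), and calls the handler repeatedly with the stored progress until it reports completion. Annotating the progress payload, with hypothetical bounds for an events table spanning stream orderings 1 to 5000:

    progress = {
        "target_min_stream_id_inclusive": 1,  # oldest row still to back-fill
        "max_stream_id_exclusive": 5001,      # MAX(stream_ordering) + 1
        "rows_inserted": 0,                   # running total across batches
        "have_added_indexes": False,          # indexes are built on first run
    }
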
diff --git a/synapse/storage/search.py b/synapse/storage/search.py
index 59ac7f424c..0224299625 100644
--- a/synapse/storage/search.py
+++ b/synapse/storage/search.py
@@ -29,12 +29,17 @@ logger = logging.getLogger(__name__)
class SearchStore(BackgroundUpdateStore):
EVENT_SEARCH_UPDATE_NAME = "event_search"
+ EVENT_SEARCH_ORDER_UPDATE_NAME = "event_search_order"
def __init__(self, hs):
super(SearchStore, self).__init__(hs)
self.register_background_update_handler(
self.EVENT_SEARCH_UPDATE_NAME, self._background_reindex_search
)
+ self.register_background_update_handler(
+ self.EVENT_SEARCH_ORDER_UPDATE_NAME,
+ self._background_reindex_search_order
+ )
@defer.inlineCallbacks
def _background_reindex_search(self, progress, batch_size):
@@ -132,6 +137,82 @@ class SearchStore(BackgroundUpdateStore):
defer.returnValue(result)
@defer.inlineCallbacks
+ def _background_reindex_search_order(self, progress, batch_size):
+ target_min_stream_id = progress["target_min_stream_id_inclusive"]
+ max_stream_id = progress["max_stream_id_exclusive"]
+ rows_inserted = progress.get("rows_inserted", 0)
+ have_added_index = progress['have_added_indexes']
+
+ if not have_added_index:
+ def create_index(conn):
+ conn.rollback()
+ conn.set_session(autocommit=True)
+ c = conn.cursor()
+
+                # We create the indexes with NULLS FIRST so that when we
+                # search *backwards* we get the rows with a non-null
+                # origin_server_ts *first*.
+ c.execute(
+ "CREATE INDEX CONCURRENTLY event_search_room_order ON event_search("
+ "room_id, origin_server_ts NULLS FIRST, stream_ordering NULLS FIRST)"
+ )
+ c.execute(
+ "CREATE INDEX CONCURRENTLY event_search_order ON event_search("
+ "origin_server_ts NULLS FIRST, stream_ordering NULLS FIRST)"
+ )
+ conn.set_session(autocommit=False)
+
+ yield self.runWithConnection(create_index)
+
+ pg = dict(progress)
+ pg["have_added_indexes"] = True
+
+ yield self.runInteraction(
+ self.EVENT_SEARCH_ORDER_UPDATE_NAME,
+ self._background_update_progress_txn,
+ self.EVENT_SEARCH_ORDER_UPDATE_NAME, pg,
+ )
+
+ def reindex_search_txn(txn):
+ sql = (
+ "UPDATE event_search AS es SET stream_ordering = e.stream_ordering,"
+ " origin_server_ts = e.origin_server_ts"
+ " FROM events AS e"
+ " WHERE e.event_id = es.event_id"
+ " AND ? <= e.stream_ordering AND e.stream_ordering < ?"
+ " RETURNING es.stream_ordering"
+ )
+
+ min_stream_id = max_stream_id - batch_size
+ txn.execute(sql, (min_stream_id, max_stream_id))
+ rows = txn.fetchall()
+
+ if min_stream_id < target_min_stream_id:
+                # We've reached the end.
+ return len(rows), False
+
+ progress = {
+ "target_min_stream_id_inclusive": target_min_stream_id,
+ "max_stream_id_exclusive": min_stream_id,
+ "rows_inserted": rows_inserted + len(rows),
+ "have_added_indexes": True,
+ }
+
+ self._background_update_progress_txn(
+ txn, self.EVENT_SEARCH_ORDER_UPDATE_NAME, progress
+ )
+
+ return len(rows), True
+
+ num_rows, finished = yield self.runInteraction(
+ self.EVENT_SEARCH_ORDER_UPDATE_NAME, reindex_search_txn
+ )
+
+ if not finished:
+ yield self._end_background_update(self.EVENT_SEARCH_ORDER_UPDATE_NAME)
+
+ defer.returnValue(num_rows)
+
+ @defer.inlineCallbacks
def search_msgs(self, room_ids, search_term, keys):
"""Performs a full text search over events with given keys.
@@ -310,7 +391,6 @@ class SearchStore(BackgroundUpdateStore):
"SELECT ts_rank_cd(vector, to_tsquery('english', ?)) as rank,"
" origin_server_ts, stream_ordering, room_id, event_id"
" FROM event_search"
- " NATURAL JOIN events"
" WHERE vector @@ to_tsquery('english', ?) AND "
)
args = [search_query, search_query] + args
@@ -355,7 +435,15 @@ class SearchStore(BackgroundUpdateStore):
# We add an arbitrary limit here to ensure we don't try to pull the
# entire table from the database.
- sql += " ORDER BY origin_server_ts DESC, stream_ordering DESC LIMIT ?"
+ if isinstance(self.database_engine, PostgresEngine):
+ sql += (
+ " ORDER BY origin_server_ts DESC NULLS LAST,"
+ " stream_ordering DESC NULLS LAST LIMIT ?"
+ )
+ elif isinstance(self.database_engine, Sqlite3Engine):
+ sql += " ORDER BY origin_server_ts DESC, stream_ordering DESC LIMIT ?"
+ else:
+ raise Exception("Unrecognized database engine")
args.append(limit)
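
The engine split above exists because rows the back-fill has not reached yet still have a NULL origin_server_ts. Postgres defaults to NULLS FIRST for a DESC ordering, so it needs the explicit NULLS LAST to push unfilled rows to the end; SQLite treats NULL as smaller than every other value, so a plain DESC already places NULLs last. A quick self-contained check of the SQLite behaviour:

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE t (ts INTEGER)")
    conn.executemany("INSERT INTO t VALUES (?)", [(None,), (2,), (1,)])
    print(conn.execute("SELECT ts FROM t ORDER BY ts DESC").fetchall())
    # [(2,), (1,), (None,)]
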
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index c5d2a3a6df..5b743db67a 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -174,6 +174,12 @@ class StateStore(SQLBaseStore):
return [r[0] for r in results]
return self.runInteraction("get_current_state_for_key", f)
+ @cached(num_args=2, lru=True, max_entries=1000)
+ def _get_state_group_from_group(self, group, types):
+ raise NotImplementedError()
+
+ @cachedList(cached_method_name="_get_state_group_from_group",
+ list_name="groups", num_args=2, inlineCallbacks=True)
def _get_state_groups_from_groups(self, groups, types):
"""Returns dictionary state_group -> (dict of (type, state_key) -> event id)
"""
@@ -201,18 +207,23 @@ class StateStore(SQLBaseStore):
txn.execute(sql, args)
rows = self.cursor_to_dict(txn)
- results = {}
+ results = {group: {} for group in groups}
for row in rows:
key = (row["type"], row["state_key"])
- results.setdefault(row["state_group"], {})[key] = row["event_id"]
+ results[row["state_group"]][key] = row["event_id"]
return results
+ results = {}
+
chunks = [groups[i:i + 100] for i in xrange(0, len(groups), 100)]
for chunk in chunks:
- return self.runInteraction(
+ res = yield self.runInteraction(
"_get_state_groups_from_groups",
f, chunk
)
+ results.update(res)
+
+ defer.returnValue(results)
@defer.inlineCallbacks
def get_state_for_events(self, event_ids, types):
@@ -359,6 +370,8 @@ class StateStore(SQLBaseStore):
a `state_key` of None matches all state_keys. If `types` is None then
all events are returned.
"""
+ if types:
+ types = frozenset(types)
results = {}
missing_groups = []
if types is not None:
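
The _get_state_group_from_group stub paired with the cachedList wrapper above is the batched-cache idiom: the per-key method exists only to own a cache entry point and is never invoked directly, while the list method is called with just the cache misses and must return a dict covering every requested key so the decorator can back-fill the per-key cache. A hedged sketch of the shape, with illustrative names, assuming the decorators behave as they are used in this hunk:

    @cached(num_args=2, lru=True, max_entries=1000)
    def _get_one(self, key, types):
        # Cache stub: never called directly.
        raise NotImplementedError()

    @cachedList(cached_method_name="_get_one", list_name="keys",
                num_args=2, inlineCallbacks=True)
    def _get_many(self, keys, types):
        rows = yield self.runInteraction("get_many", get_many_txn, keys)
        # Cover every requested key so misses are cached as absent too.
        defer.returnValue({key: rows.get(key) for key in keys})
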
diff --git a/synapse/util/httpresourcetree.py b/synapse/util/httpresourcetree.py
new file mode 100644
index 0000000000..45be47159a
--- /dev/null
+++ b/synapse/util/httpresourcetree.py
@@ -0,0 +1,98 @@
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.web.resource import Resource
+
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+def create_resource_tree(desired_tree, root_resource):
+ """Create the resource tree for this Home Server.
+
+    This is unduly complicated because Twisted does not support putting
+    child resources more than 1 level deep at a time.
+
+    Args:
+        desired_tree (dict[str, twisted.web.resource.Resource]): A mapping
+            from full resource paths to the resources to attach at those
+            paths.
+        root_resource (twisted.web.resource.Resource): The root
+            resource to add the tree to.
+ Returns:
+ twisted.web.resource.Resource: the ``root_resource`` with a tree of
+ child resources added to it.
+ """
+
+ # ideally we'd just use getChild and putChild but getChild doesn't work
+ # unless you give it a Request object IN ADDITION to the name :/ So
+ # instead, we'll store a copy of this mapping so we can actually add
+    # extra resources to existing nodes. See _resource_id for the key.
+ resource_mappings = {}
+ for full_path, res in desired_tree.items():
+ logger.info("Attaching %s to path %s", res, full_path)
+ last_resource = root_resource
+ for path_seg in full_path.split('/')[1:-1]:
+ if path_seg not in last_resource.listNames():
+ # resource doesn't exist, so make a "dummy resource"
+ child_resource = Resource()
+ last_resource.putChild(path_seg, child_resource)
+ res_id = _resource_id(last_resource, path_seg)
+ resource_mappings[res_id] = child_resource
+ last_resource = child_resource
+ else:
+ # we have an existing Resource, use that instead.
+ res_id = _resource_id(last_resource, path_seg)
+ last_resource = resource_mappings[res_id]
+
+ # ===========================
+ # now attach the actual desired resource
+ last_path_seg = full_path.split('/')[-1]
+
+ # if there is already a resource here, thieve its children and
+ # replace it
+ res_id = _resource_id(last_resource, last_path_seg)
+ if res_id in resource_mappings:
+ # there is a dummy resource at this path already, which needs
+ # to be replaced with the desired resource.
+ existing_dummy_resource = resource_mappings[res_id]
+ for child_name in existing_dummy_resource.listNames():
+ child_res_id = _resource_id(
+ existing_dummy_resource, child_name
+ )
+ child_resource = resource_mappings[child_res_id]
+ # steal the children
+ res.putChild(child_name, child_resource)
+
+ # finally, insert the desired resource in the right place
+ last_resource.putChild(last_path_seg, res)
+ res_id = _resource_id(last_resource, last_path_seg)
+ resource_mappings[res_id] = res
+
+ return root_resource
+
+
+def _resource_id(resource, path_seg):
+ """Construct an arbitrary resource ID so you can retrieve the mapping
+ later.
+
+    If resource A has a child resource B attached at path C via putChild,
+    the mapping should look like _resource_id(A, C) = B.
+
+    Args:
+        resource (Resource): The *parent* Resource.
+ path_seg (str): The name of the child Resource to be attached.
+ Returns:
+ str: A unique string which can be a key to the child Resource.
+ """
+ return "%s-%s" % (resource, path_seg)
diff --git a/synapse/util/manhole.py b/synapse/util/manhole.py
new file mode 100644
index 0000000000..97e0f00b67
--- /dev/null
+++ b/synapse/util/manhole.py
@@ -0,0 +1,70 @@
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.conch.manhole import ColoredManhole
+from twisted.conch.insults import insults
+from twisted.conch import manhole_ssh
+from twisted.cred import checkers, portal
+from twisted.conch.ssh.keys import Key
+
+PUBLIC_KEY = (
+ "ssh-rsa AAAAB3NzaC1yc2EAAAABIwAAAGEArzJx8OYOnJmzf4tfBEvLi8DVPrJ3/c9k2I/Az"
+ "64fxjHf9imyRJbixtQhlH9lfNjUIx+4LmrJH5QNRsFporcHDKOTwTTYLh5KmRpslkYHRivcJS"
+ "kbh/C+BR3utDS555mV"
+)
+
+PRIVATE_KEY = """-----BEGIN RSA PRIVATE KEY-----
+MIIByAIBAAJhAK8ycfDmDpyZs3+LXwRLy4vA1T6yd/3PZNiPwM+uH8Yx3/YpskSW
+4sbUIZR/ZXzY1CMfuC5qyR+UDUbBaaK3Bwyjk8E02C4eSpkabJZGB0Yr3CUpG4fw
+vgUd7rQ0ueeZlQIBIwJgbh+1VZfr7WftK5lu7MHtqE1S1vPWZQYE3+VUn8yJADyb
+Z4fsZaCrzW9lkIqXkE3GIY+ojdhZhkO1gbG0118sIgphwSWKRxK0mvh6ERxKqIt1
+xJEJO74EykXZV4oNJ8sjAjEA3J9r2ZghVhGN6V8DnQrTk24Td0E8hU8AcP0FVP+8
+PQm/g/aXf2QQkQT+omdHVEJrAjEAy0pL0EBH6EVS98evDCBtQw22OZT52qXlAwZ2
+gyTriKFVoqjeEjt3SZKKqXHSApP/AjBLpF99zcJJZRq2abgYlf9lv1chkrWqDHUu
+DZttmYJeEfiFBBavVYIF1dOlZT0G8jMCMBc7sOSZodFnAiryP+Qg9otSBjJ3bQML
+pSTqy7c3a2AScC/YyOwkDaICHnnD3XyjMwIxALRzl0tQEKMXs6hH8ToUdlLROCrP
+EhQ0wahUTCk1gKA4uPD6TMTChavbh4K63OvbKg==
+-----END RSA PRIVATE KEY-----"""
+
+
+def manhole(username, password, globals):
+ """Starts a ssh listener with password authentication using
+ the given username and password. Clients connecting to the ssh
+ listener will find themselves in a colored python shell with
+ the supplied globals.
+
+ Args:
+ username(str): The username ssh clients should auth with.
+ password(str): The password ssh clients should auth with.
+ globals(dict): The variables to expose in the shell.
+
+ Returns:
+ twisted.internet.protocol.Factory: A factory to pass to ``listenTCP``
+ """
+
+ checker = checkers.InMemoryUsernamePasswordDatabaseDontUse(
+ **{username: password}
+ )
+
+ rlm = manhole_ssh.TerminalRealm()
+ rlm.chainedProtocolFactory = lambda: insults.ServerProtocol(
+ ColoredManhole,
+ dict(globals, __name__="__console__")
+ )
+
+ factory = manhole_ssh.ConchFactory(portal.Portal(rlm, [checker]))
+ factory.publicKeys['ssh-rsa'] = Key.fromString(PUBLIC_KEY)
+ factory.privateKeys['ssh-rsa'] = Key.fromString(PRIVATE_KEY)
+
+ return factory
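
A usage sketch for the factory above; the port, credentials and the hs variable are assumptions for illustration. The returned ConchFactory goes straight to listenTCP, and since the checker holds the password in memory it is best bound to localhost only:

    from twisted.internet import reactor
    from synapse.util.manhole import manhole

    reactor.listenTCP(
        9000,
        manhole("matrix", "rabbithole", {"hs": hs}),  # `hs` assumed in scope
        interface="127.0.0.1",
    )
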
diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py
index 983caafe8a..1f13cd0bc0 100644
--- a/tests/replication/slave/storage/_base.py
+++ b/tests/replication/slave/storage/_base.py
@@ -15,8 +15,6 @@
from twisted.internet import defer
from tests import unittest
-from synapse.replication.slave.storage.events import SlavedEventStore
-
from mock import Mock, NonCallableMock
from tests.utils import setup_test_homeserver
from synapse.replication.resource import ReplicationResource
@@ -38,7 +36,7 @@ class BaseSlavedStoreTestCase(unittest.TestCase):
self.replication = ReplicationResource(self.hs)
self.master_store = self.hs.get_datastore()
- self.slaved_store = SlavedEventStore(self.hs.get_db_conn(), self.hs)
+ self.slaved_store = self.STORE_TYPE(self.hs.get_db_conn(), self.hs)
self.event_id = 0
@defer.inlineCallbacks
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index baa4a26eb5..17587fda00 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -16,6 +16,7 @@ from ._base import BaseSlavedStoreTestCase
from synapse.events import FrozenEvent, _EventInternalMetadata
from synapse.events.snapshot import EventContext
+from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.storage.roommember import RoomsForUser
from twisted.internet import defer
@@ -43,6 +44,8 @@ def patch__eq__(cls):
class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
+ STORE_TYPE = SlavedEventStore
+
def setUp(self):
# Patch up the equality operator for events so that we can check
# whether lists of events match using assertEquals
@@ -251,6 +254,59 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
redacted = FrozenEvent(msg_dict, msg.internal_metadata.get_dict())
yield self.check("get_event", [msg.event_id], redacted)
+ @defer.inlineCallbacks
+ def test_invites(self):
+ yield self.check("get_invited_rooms_for_user", [USER_ID_2], [])
+ event = yield self.persist(
+ type="m.room.member", key=USER_ID_2, membership="invite"
+ )
+ yield self.replicate()
+ yield self.check("get_invited_rooms_for_user", [USER_ID_2], [RoomsForUser(
+ ROOM_ID, USER_ID, "invite", event.event_id,
+ event.internal_metadata.stream_ordering
+ )])
+
+ @defer.inlineCallbacks
+ def test_push_actions_for_user(self):
+ yield self.persist(type="m.room.create", creator=USER_ID)
+ yield self.persist(type="m.room.join", key=USER_ID, membership="join")
+ yield self.persist(
+ type="m.room.join", sender=USER_ID, key=USER_ID_2, membership="join"
+ )
+ event1 = yield self.persist(
+ type="m.room.message", msgtype="m.text", body="hello"
+ )
+ yield self.replicate()
+ yield self.check(
+ "get_unread_event_push_actions_by_room_for_user",
+ [ROOM_ID, USER_ID_2, event1.event_id],
+ {"highlight_count": 0, "notify_count": 0}
+ )
+
+ yield self.persist(
+ type="m.room.message", msgtype="m.text", body="world",
+ push_actions=[(USER_ID_2, ["notify"])],
+ )
+ yield self.replicate()
+ yield self.check(
+ "get_unread_event_push_actions_by_room_for_user",
+ [ROOM_ID, USER_ID_2, event1.event_id],
+ {"highlight_count": 0, "notify_count": 1}
+ )
+
+ yield self.persist(
+ type="m.room.message", msgtype="m.text", body="world",
+ push_actions=[(USER_ID_2, [
+ "notify", {"set_tweak": "highlight", "value": True}
+ ])],
+ )
+ yield self.replicate()
+ yield self.check(
+ "get_unread_event_push_actions_by_room_for_user",
+ [ROOM_ID, USER_ID_2, event1.event_id],
+ {"highlight_count": 1, "notify_count": 2}
+ )
+
event_id = 0
@defer.inlineCallbacks
@@ -258,6 +314,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
self, sender=USER_ID, room_id=ROOM_ID, type={}, key=None, internal={},
state=None, reset_state=False, backfill=False,
depth=None, prev_events=[], auth_events=[], prev_state=[], redacts=None,
+ push_actions=[],
**content
):
"""
@@ -290,6 +347,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
self.event_id += 1
context = EventContext(current_state=state)
+ context.push_actions = push_actions
ordering = None
if backfill:
diff --git a/tests/replication/slave/storage/test_receipts.py b/tests/replication/slave/storage/test_receipts.py
new file mode 100644
index 0000000000..6624fe4eea
--- /dev/null
+++ b/tests/replication/slave/storage/test_receipts.py
@@ -0,0 +1,39 @@
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ._base import BaseSlavedStoreTestCase
+
+from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
+
+from twisted.internet import defer
+
+USER_ID = "@feeling:blue"
+ROOM_ID = "!room:blue"
+EVENT_ID = "$event:blue"
+
+
+class SlavedReceiptTestCase(BaseSlavedStoreTestCase):
+
+ STORE_TYPE = SlavedReceiptsStore
+
+ @defer.inlineCallbacks
+ def test_receipt(self):
+ yield self.check("get_receipts_for_user", [USER_ID, "m.read"], {})
+ yield self.master_store.insert_receipt(
+ ROOM_ID, "m.read", USER_ID, [EVENT_ID], {}
+ )
+ yield self.replicate()
+ yield self.check("get_receipts_for_user", [USER_ID, "m.read"], {
+ ROOM_ID: EVENT_ID
+ })