diff --git a/jenkins.sh b/jenkins.sh
index e2bb706c7f..ed3bdf80fa 100755
--- a/jenkins.sh
+++ b/jenkins.sh
@@ -26,7 +26,7 @@ TOX_BIN=$WORKSPACE/.tox/py27/bin
if [[ ! -e .sytest-base ]]; then
git clone https://github.com/matrix-org/sytest.git .sytest-base --mirror
else
- (cd .sytest-base; git fetch)
+ (cd .sytest-base; git fetch -p)
fi
rm -rf sytest
@@ -52,7 +52,7 @@ RUN_POSTGRES=""
for port in $(($PORT_BASE + 1)) $(($PORT_BASE + 2)); do
if psql synapse_jenkins_$port <<< ""; then
- RUN_POSTGRES=$RUN_POSTGRES:$port
+ RUN_POSTGRES="$RUN_POSTGRES:$port"
cat > localhost-$port/database.yaml << EOF
name: psycopg2
args:
@@ -62,7 +62,7 @@ EOF
done
# Run if both postgresql databases exist
-if test $RUN_POSTGRES = ":$(($PORT_BASE + 1)):$(($PORT_BASE + 2))"; then
+if test "$RUN_POSTGRES" = ":$(($PORT_BASE + 1)):$(($PORT_BASE + 2))"; then
echo >&2 "Running sytest with PostgreSQL";
$TOX_BIN/pip install psycopg2
./run-tests.pl --coverage -O tap --synapse-directory $WORKSPACE \
diff --git a/scripts-dev/dump_macaroon.py b/scripts-dev/dump_macaroon.py
new file mode 100755
index 0000000000..6e45be75d6
--- /dev/null
+++ b/scripts-dev/dump_macaroon.py
@@ -0,0 +1,24 @@
+#!/usr/bin/env python2
+
+import pymacaroons
+import sys
+
+if len(sys.argv) == 1:
+ sys.stderr.write("usage: %s macaroon [key]\n" % (sys.argv[0],))
+ sys.exit(1)
+
+macaroon_string = sys.argv[1]
+key = sys.argv[2] if len(sys.argv) > 2 else None
+
+macaroon = pymacaroons.Macaroon.deserialize(macaroon_string)
+print macaroon.inspect()
+
+print ""
+
+verifier = pymacaroons.Verifier()
+verifier.satisfy_general(lambda c: True)
+try:
+ verifier.verify(macaroon, key)
+ print "Signature is correct"
+except Exception as e:
+ print e.message
diff --git a/scripts-dev/list_url_patterns.py b/scripts-dev/list_url_patterns.py
new file mode 100755
index 0000000000..58d40c4ff4
--- /dev/null
+++ b/scripts-dev/list_url_patterns.py
@@ -0,0 +1,62 @@
+#! /usr/bin/python
+
+import ast
+import argparse
+import os
+import sys
+import yaml
+
+PATTERNS_V1 = []
+PATTERNS_V2 = []
+
+RESULT = {
+ "v1": PATTERNS_V1,
+ "v2": PATTERNS_V2,
+}
+
+class CallVisitor(ast.NodeVisitor):
+ def visit_Call(self, node):
+ if isinstance(node.func, ast.Name):
+ name = node.func.id
+ else:
+ return
+
+
+ if name == "client_path_patterns":
+ PATTERNS_V1.append(node.args[0].s)
+ elif name == "client_v2_patterns":
+ PATTERNS_V2.append(node.args[0].s)
+
+
+def find_patterns_in_code(input_code):
+ input_ast = ast.parse(input_code)
+ visitor = CallVisitor()
+ visitor.visit(input_ast)
+
+
+def find_patterns_in_file(filepath):
+ with open(filepath) as f:
+ find_patterns_in_code(f.read())
+
+
+parser = argparse.ArgumentParser(description='Find url patterns.')
+
+parser.add_argument(
+ "directories", nargs='+', metavar="DIR",
+ help="Directories to search for definitions"
+)
+
+args = parser.parse_args()
+
+
+for directory in args.directories:
+ for root, dirs, files in os.walk(directory):
+ for filename in files:
+ if filename.endswith(".py"):
+ filepath = os.path.join(root, filename)
+ find_patterns_in_file(filepath)
+
+PATTERNS_V1.sort()
+PATTERNS_V2.sort()
+
+yaml.dump(RESULT, sys.stdout, default_flow_style=False)
diff --git a/setup.cfg b/setup.cfg
index ba027c7d13..f8cc13c840 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -16,3 +16,4 @@ ignore =
[flake8]
max-line-length = 90
+ignore = W503 ; W503 requires that binary operators be at the end, not start, of lines. Erik doesn't like it.
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index b5536e8565..e2f84c4d57 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -24,6 +24,7 @@ from synapse.api.constants import EventTypes, Membership, JoinRules
from synapse.api.errors import AuthError, Codes, SynapseError, EventSizeError
from synapse.types import Requester, RoomID, UserID, EventID
from synapse.util.logutils import log_function
+from synapse.util.logcontext import preserve_context_over_fn
from unpaddedbase64 import decode_base64
import logging
@@ -529,7 +530,8 @@ class Auth(object):
default=[""]
)[0]
if user and access_token and ip_addr:
- self.store.insert_client_ip(
+ preserve_context_over_fn(
+ self.store.insert_client_ip,
user=user,
access_token=access_token,
ip=ip_addr,
@@ -574,7 +576,7 @@ class Auth(object):
raise AuthError(
403,
"Application service has not registered this user"
- )
+ )
defer.returnValue(user_id)
@defer.inlineCallbacks
@@ -696,6 +698,7 @@ class Auth(object):
def _look_up_user_by_access_token(self, token):
ret = yield self.store.get_user_by_access_token(token)
if not ret:
+ logger.warn("Unrecognised access token - not in store: %s" % (token,))
raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Unrecognised access token.",
errcode=Codes.UNKNOWN_TOKEN
@@ -713,6 +716,7 @@ class Auth(object):
token = request.args["access_token"][0]
service = yield self.store.get_app_service_by_token(token)
if not service:
+ logger.warn("Unrecognised appservice access token: %s" % (token,))
raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS,
"Unrecognised access token.",
diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py
index 6c13ada5df..6eff83e5f8 100644
--- a/synapse/api/filtering.py
+++ b/synapse/api/filtering.py
@@ -15,6 +15,8 @@
from synapse.api.errors import SynapseError
from synapse.types import UserID, RoomID
+import ujson as json
+
class Filtering(object):
@@ -149,6 +151,9 @@ class FilterCollection(object):
"include_leave", False
)
+ def __repr__(self):
+ return "<FilterCollection %s>" % (json.dumps(self._filter_json),)
+
def get_filter_json(self):
return self._filter_json
diff --git a/synapse/api/urls.py b/synapse/api/urls.py
index 19824f9a02..0fd9b7f244 100644
--- a/synapse/api/urls.py
+++ b/synapse/api/urls.py
@@ -23,5 +23,6 @@ WEB_CLIENT_PREFIX = "/_matrix/client"
CONTENT_REPO_PREFIX = "/_matrix/content"
SERVER_KEY_PREFIX = "/_matrix/key/v1"
SERVER_KEY_V2_PREFIX = "/_matrix/key/v2"
-MEDIA_PREFIX = "/_matrix/media/v1"
+MEDIA_PREFIX = "/_matrix/media/r0"
+LEGACY_MEDIA_PREFIX = "/_matrix/media/v1"
APP_SERVICE_PREFIX = "/_matrix/appservice/v1"
diff --git a/synapse/app/__init__.py b/synapse/app/__init__.py
index bfebb0f644..1bc4279807 100644
--- a/synapse/app/__init__.py
+++ b/synapse/app/__init__.py
@@ -12,3 +12,22 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
+import sys
+sys.dont_write_bytecode = True
+
+from synapse.python_dependencies import (
+ check_requirements, MissingRequirementError
+) # NOQA
+
+try:
+ check_requirements()
+except MissingRequirementError as e:
+ message = "\n".join([
+ "Missing Requirement: %s" % (e.message,),
+ "To install run:",
+ " pip install --upgrade --force \"%s\"" % (e.dependency,),
+ "",
+ ])
+ sys.stderr.writelines(message)
+ sys.exit(1)
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 6928d9d3e4..2b4be7bdd0 100755
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -14,27 +14,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import synapse
+
+import contextlib
+import logging
+import os
+import re
+import resource
+import subprocess
import sys
-from synapse.rest import ClientRestResource
+import time
+from synapse.config._base import ConfigError
-sys.dont_write_bytecode = True
from synapse.python_dependencies import (
- check_requirements, DEPENDENCY_LINKS, MissingRequirementError
+ check_requirements, DEPENDENCY_LINKS
)
-if __name__ == '__main__':
- try:
- check_requirements()
- except MissingRequirementError as e:
- message = "\n".join([
- "Missing Requirement: %s" % (e.message,),
- "To install run:",
- " pip install --upgrade --force \"%s\"" % (e.dependency,),
- "",
- ])
- sys.stderr.writelines(message)
- sys.exit(1)
-
+from synapse.rest import ClientRestResource
from synapse.storage.engines import create_engine, IncorrectDatabaseSetup
from synapse.storage import are_all_users_on_domain
from synapse.storage.prepare_database import UpgradeDatabaseException
@@ -50,41 +46,29 @@ from twisted.cred import checkers, portal
from twisted.internet import reactor, task, defer
from twisted.application import service
-from twisted.enterprise import adbapi
from twisted.web.resource import Resource, EncodingResourceWrapper
from twisted.web.static import File
from twisted.web.server import Site, GzipEncoderFactory, Request
-from synapse.http.server import JsonResource, RootRedirect
+from synapse.http.server import RootRedirect
from synapse.rest.media.v0.content_repository import ContentRepoResource
from synapse.rest.media.v1.media_repository import MediaRepositoryResource
from synapse.rest.key.v1.server_key_resource import LocalKey
from synapse.rest.key.v2 import KeyApiV2Resource
-from synapse.http.matrixfederationclient import MatrixFederationHttpClient
from synapse.api.urls import (
FEDERATION_PREFIX, WEB_CLIENT_PREFIX, CONTENT_REPO_PREFIX,
- SERVER_KEY_PREFIX, MEDIA_PREFIX, STATIC_PREFIX,
+ SERVER_KEY_PREFIX, LEGACY_MEDIA_PREFIX, MEDIA_PREFIX, STATIC_PREFIX,
SERVER_KEY_V2_PREFIX,
)
from synapse.config.homeserver import HomeServerConfig
from synapse.crypto import context_factory
from synapse.util.logcontext import LoggingContext
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
+from synapse.federation.transport.server import TransportLayerServer
from synapse import events
from daemonize import Daemonize
-import synapse
-
-import contextlib
-import logging
-import os
-import re
-import resource
-import subprocess
-import time
-
-
logger = logging.getLogger("synapse.app.homeserver")
@@ -95,80 +79,37 @@ def gz_wrap(r):
return EncodingResourceWrapper(r, [GzipEncoderFactory()])
-class SynapseHomeServer(HomeServer):
-
- def build_http_client(self):
- return MatrixFederationHttpClient(self)
-
- def build_client_resource(self):
- return ClientRestResource(self)
-
- def build_resource_for_federation(self):
- return JsonResource(self)
-
- def build_resource_for_web_client(self):
- webclient_path = self.get_config().web_client_location
- if not webclient_path:
- try:
- import syweb
- except ImportError:
- quit_with_error(
- "Could not find a webclient.\n\n"
- "Please either install the matrix-angular-sdk or configure\n"
- "the location of the source to serve via the configuration\n"
- "option `web_client_location`\n\n"
- "To install the `matrix-angular-sdk` via pip, run:\n\n"
- " pip install '%(dep)s'\n"
- "\n"
- "You can also disable hosting of the webclient via the\n"
- "configuration option `web_client`\n"
- % {"dep": DEPENDENCY_LINKS["matrix-angular-sdk"]}
- )
- syweb_path = os.path.dirname(syweb.__file__)
- webclient_path = os.path.join(syweb_path, "webclient")
- # GZip is disabled here due to
- # https://twistedmatrix.com/trac/ticket/7678
- # (It can stay enabled for the API resources: they call
- # write() with the whole body and then finish() straight
- # after and so do not trigger the bug.
- # GzipFile was removed in commit 184ba09
- # return GzipFile(webclient_path) # TODO configurable?
- return File(webclient_path) # TODO configurable?
-
- def build_resource_for_static_content(self):
- # This is old and should go away: not going to bother adding gzip
- return File(
- os.path.join(os.path.dirname(synapse.__file__), "static")
- )
-
- def build_resource_for_content_repo(self):
- return ContentRepoResource(
- self, self.config.uploads_path, self.auth, self.content_addr
- )
-
- def build_resource_for_media_repository(self):
- return MediaRepositoryResource(self)
-
- def build_resource_for_server_key(self):
- return LocalKey(self)
-
- def build_resource_for_server_key_v2(self):
- return KeyApiV2Resource(self)
-
- def build_resource_for_metrics(self):
- if self.get_config().enable_metrics:
- return MetricsResource(self)
- else:
- return None
+def build_resource_for_web_client(hs):
+ webclient_path = hs.get_config().web_client_location
+ if not webclient_path:
+ try:
+ import syweb
+ except ImportError:
+ quit_with_error(
+ "Could not find a webclient.\n\n"
+ "Please either install the matrix-angular-sdk or configure\n"
+ "the location of the source to serve via the configuration\n"
+ "option `web_client_location`\n\n"
+ "To install the `matrix-angular-sdk` via pip, run:\n\n"
+ " pip install '%(dep)s'\n"
+ "\n"
+ "You can also disable hosting of the webclient via the\n"
+ "configuration option `web_client`\n"
+ % {"dep": DEPENDENCY_LINKS["matrix-angular-sdk"]}
+ )
+ syweb_path = os.path.dirname(syweb.__file__)
+ webclient_path = os.path.join(syweb_path, "webclient")
+ # GZip is disabled here due to
+ # https://twistedmatrix.com/trac/ticket/7678
+ # (It can stay enabled for the API resources: they call
+ # write() with the whole body and then finish() straight
+ # after and so do not trigger the bug.
+ # GzipFile was removed in commit 184ba09
+ # return GzipFile(webclient_path) # TODO configurable?
+ return File(webclient_path) # TODO configurable?
- def build_db_pool(self):
- name = self.db_config["name"]
-
- return adbapi.ConnectionPool(
- name,
- **self.db_config.get("args", {})
- )
+class SynapseHomeServer(HomeServer):
def _listener_http(self, config, listener_config):
port = listener_config["port"]
bind_address = listener_config.get("bind_address", "")
@@ -178,13 +119,11 @@ class SynapseHomeServer(HomeServer):
if tls and config.no_tls:
return
- metrics_resource = self.get_resource_for_metrics()
-
resources = {}
for res in listener_config["resources"]:
for name in res["names"]:
if name == "client":
- client_resource = self.get_client_resource()
+ client_resource = ClientRestResource(self)
if res["compress"]:
client_resource = gz_wrap(client_resource)
@@ -198,31 +137,37 @@ class SynapseHomeServer(HomeServer):
if name == "federation":
resources.update({
- FEDERATION_PREFIX: self.get_resource_for_federation(),
+ FEDERATION_PREFIX: TransportLayerServer(self),
})
if name in ["static", "client"]:
resources.update({
- STATIC_PREFIX: self.get_resource_for_static_content(),
+ STATIC_PREFIX: File(
+ os.path.join(os.path.dirname(synapse.__file__), "static")
+ ),
})
if name in ["media", "federation", "client"]:
+ media_repo = MediaRepositoryResource(self)
resources.update({
- MEDIA_PREFIX: self.get_resource_for_media_repository(),
- CONTENT_REPO_PREFIX: self.get_resource_for_content_repo(),
+ MEDIA_PREFIX: media_repo,
+ LEGACY_MEDIA_PREFIX: media_repo,
+ CONTENT_REPO_PREFIX: ContentRepoResource(
+ self, self.config.uploads_path, self.auth, self.content_addr
+ ),
})
if name in ["keys", "federation"]:
resources.update({
- SERVER_KEY_PREFIX: self.get_resource_for_server_key(),
- SERVER_KEY_V2_PREFIX: self.get_resource_for_server_key_v2(),
+ SERVER_KEY_PREFIX: LocalKey(self),
+ SERVER_KEY_V2_PREFIX: KeyApiV2Resource(self),
})
if name == "webclient":
- resources[WEB_CLIENT_PREFIX] = self.get_resource_for_web_client()
+ resources[WEB_CLIENT_PREFIX] = build_resource_for_web_client(self)
- if name == "metrics" and metrics_resource:
- resources[METRICS_PREFIX] = metrics_resource
+ if name == "metrics" and self.get_config().enable_metrics:
+ resources[METRICS_PREFIX] = MetricsResource(self)
root_resource = create_resource_tree(resources)
if tls:
@@ -296,6 +241,18 @@ class SynapseHomeServer(HomeServer):
except IncorrectDatabaseSetup as e:
quit_with_error(e.message)
+ def get_db_conn(self):
+ # Any param beginning with cp_ is a parameter for adbapi, and should
+ # not be passed to the database engine.
+ db_params = {
+ k: v for k, v in self.db_config.get("args", {}).items()
+ if not k.startswith("cp_")
+ }
+ db_conn = self.database_engine.module.connect(**db_params)
+
+ self.database_engine.on_new_connection(db_conn)
+ return db_conn
+
def quit_with_error(error_string):
message_lines = error_string.split("\n")
@@ -396,11 +353,20 @@ def setup(config_options):
Returns:
HomeServer
"""
- config = HomeServerConfig.load_config(
- "Synapse Homeserver",
- config_options,
- generate_section="Homeserver"
- )
+ try:
+ config = HomeServerConfig.load_config(
+ "Synapse Homeserver",
+ config_options,
+ generate_section="Homeserver"
+ )
+ except ConfigError as e:
+ sys.stderr.write("\n" + e.message + "\n")
+ sys.exit(1)
+
+ if not config:
+ # If a config isn't returned, and an exception isn't raised, we're just
+ # generating config files and shouldn't try to continue.
+ sys.exit(0)
config.setup_logging()
@@ -432,13 +398,7 @@ def setup(config_options):
logger.info("Preparing database: %s...", config.database_config['name'])
try:
- db_conn = database_engine.module.connect(
- **{
- k: v for k, v in config.database_config.get("args", {}).items()
- if not k.startswith("cp_")
- }
- )
-
+ db_conn = hs.get_db_conn()
database_engine.prepare_database(db_conn)
hs.run_startup_checks(db_conn, database_engine)
@@ -453,13 +413,17 @@ def setup(config_options):
logger.info("Database prepared in %s.", config.database_config['name'])
+ hs.setup()
hs.start_listening()
- hs.get_pusherpool().start()
- hs.get_state_handler().start_caching()
- hs.get_datastore().start_profiling()
- hs.get_datastore().start_doing_background_updates()
- hs.get_replication_layer().start_get_pdu_cache()
+ def start():
+ hs.get_pusherpool().start()
+ hs.get_state_handler().start_caching()
+ hs.get_datastore().start_profiling()
+ hs.get_datastore().start_doing_background_updates()
+ hs.get_replication_layer().start_get_pdu_cache()
+
+ reactor.callWhenRunning(start)
return hs
@@ -675,7 +639,7 @@ def _resource_id(resource, path_seg):
the mapping should looks like _resource_id(A,C) = B.
Args:
- resource (Resource): The *parent* Resource
+ resource (Resource): The *parent* Resourceb
path_seg (str): The name of the child Resource to be attached.
Returns:
str: A unique string which can be a key to the child Resource.
@@ -722,8 +686,8 @@ def run(hs):
stats["uptime_seconds"] = uptime
stats["total_users"] = yield hs.get_datastore().count_all_users()
- all_rooms = yield hs.get_datastore().get_rooms(False)
- stats["total_room_count"] = len(all_rooms)
+ room_count = yield hs.get_datastore().get_room_count()
+ stats["total_room_count"] = room_count
stats["daily_active_users"] = yield hs.get_datastore().count_daily_users()
daily_messages = yield hs.get_datastore().count_daily_messages()
@@ -745,6 +709,8 @@ def run(hs):
phone_home_task.start(60 * 60 * 24, now=False)
def in_thread():
+ # Uncomment to enable tracing of log context changes.
+ # sys.settrace(logcontext_tracer)
with LoggingContext("run"):
change_resource_limit(hs.config.soft_file_limit)
reactor.run()
diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index e1c07028e8..bc90605324 100644
--- a/synapse/appservice/api.py
+++ b/synapse/appservice/api.py
@@ -29,7 +29,7 @@ class ApplicationServiceApi(SimpleHttpClient):
pushing.
"""
- def __init__(self, hs):
+ def __init__(self, hs):
super(ApplicationServiceApi, self).__init__(hs)
self.clock = hs.get_clock()
diff --git a/synapse/config/__main__.py b/synapse/config/__main__.py
index ea9e7907a6..0a3b70e11f 100644
--- a/synapse/config/__main__.py
+++ b/synapse/config/__main__.py
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from synapse.config._base import ConfigError
if __name__ == "__main__":
import sys
@@ -21,7 +22,11 @@ if __name__ == "__main__":
if action == "read":
key = sys.argv[2]
- config = HomeServerConfig.load_config("", sys.argv[3:])
+ try:
+ config = HomeServerConfig.load_config("", sys.argv[3:])
+ except ConfigError as e:
+ sys.stderr.write("\n" + e.message + "\n")
+ sys.exit(1)
print getattr(config, key)
sys.exit(0)
diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index a9304a11ba..15d78ff33a 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -17,7 +17,6 @@ import argparse
import errno
import os
import yaml
-import sys
from textwrap import dedent
@@ -136,13 +135,20 @@ class Config(object):
results.append(getattr(cls, name)(self, *args, **kargs))
return results
- def generate_config(self, config_dir_path, server_name, report_stats=None):
+ def generate_config(
+ self,
+ config_dir_path,
+ server_name,
+ is_generating_file,
+ report_stats=None,
+ ):
default_config = "# vim:ft=yaml\n"
default_config += "\n\n".join(dedent(conf) for conf in self.invoke_all(
"default_config",
config_dir_path=config_dir_path,
server_name=server_name,
+ is_generating_file=is_generating_file,
report_stats=report_stats,
))
@@ -244,8 +250,10 @@ class Config(object):
server_name = config_args.server_name
if not server_name:
- print "Must specify a server_name to a generate config for."
- sys.exit(1)
+ raise ConfigError(
+ "Must specify a server_name to a generate config for."
+ " Pass -H server.name."
+ )
if not os.path.exists(config_dir_path):
os.makedirs(config_dir_path)
with open(config_path, "wb") as config_file:
@@ -253,6 +261,7 @@ class Config(object):
config_dir_path=config_dir_path,
server_name=server_name,
report_stats=(config_args.report_stats == "yes"),
+ is_generating_file=True
)
obj.invoke_all("generate_files", config)
config_file.write(config_bytes)
@@ -266,7 +275,7 @@ class Config(object):
"If this server name is incorrect, you will need to"
" regenerate the SSL certificates"
)
- sys.exit(0)
+ return
else:
print (
"Config file %r already exists. Generating any missing key"
@@ -302,25 +311,25 @@ class Config(object):
specified_config.update(yaml_config)
if "server_name" not in specified_config:
- sys.stderr.write("\n" + MISSING_SERVER_NAME + "\n")
- sys.exit(1)
+ raise ConfigError(MISSING_SERVER_NAME)
server_name = specified_config["server_name"]
_, config = obj.generate_config(
config_dir_path=config_dir_path,
- server_name=server_name
+ server_name=server_name,
+ is_generating_file=False,
)
config.pop("log_config")
config.update(specified_config)
if "report_stats" not in config:
- sys.stderr.write(
- "\n" + MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS + "\n" +
- MISSING_REPORT_STATS_SPIEL + "\n")
- sys.exit(1)
+ raise ConfigError(
+ MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS + "\n" +
+ MISSING_REPORT_STATS_SPIEL
+ )
if generate_keys:
obj.invoke_all("generate_files", config)
- sys.exit(0)
+ return
obj.invoke_all("read_config", config)
diff --git a/synapse/config/key.py b/synapse/config/key.py
index ac90cd3fc1..a072aec714 100644
--- a/synapse/config/key.py
+++ b/synapse/config/key.py
@@ -22,8 +22,14 @@ from signedjson.key import (
read_signing_keys, write_signing_keys, NACL_ED25519
)
from unpaddedbase64 import decode_base64
+from synapse.util.stringutils import random_string_with_symbols
import os
+import hashlib
+import logging
+
+
+logger = logging.getLogger(__name__)
class KeyConfig(Config):
@@ -40,9 +46,29 @@ class KeyConfig(Config):
config["perspectives"]
)
- def default_config(self, config_dir_path, server_name, **kwargs):
+ self.macaroon_secret_key = config.get(
+ "macaroon_secret_key", self.registration_shared_secret
+ )
+
+ if not self.macaroon_secret_key:
+ # Unfortunately, there are people out there that don't have this
+ # set. Lets just be "nice" and derive one from their secret key.
+ logger.warn("Config is missing missing macaroon_secret_key")
+ seed = self.signing_key[0].seed
+ self.macaroon_secret_key = hashlib.sha256(seed)
+
+ def default_config(self, config_dir_path, server_name, is_generating_file=False,
+ **kwargs):
base_key_name = os.path.join(config_dir_path, server_name)
+
+ if is_generating_file:
+ macaroon_secret_key = random_string_with_symbols(50)
+ else:
+ macaroon_secret_key = None
+
return """\
+ macaroon_secret_key: "%(macaroon_secret_key)s"
+
## Signing Keys ##
# Path to the signing key to sign messages with
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index d3f4b9d543..ab062d528c 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -23,22 +23,23 @@ from distutils.util import strtobool
class RegistrationConfig(Config):
def read_config(self, config):
- self.disable_registration = not bool(
+ self.enable_registration = bool(
strtobool(str(config["enable_registration"]))
)
if "disable_registration" in config:
- self.disable_registration = bool(
+ self.enable_registration = not bool(
strtobool(str(config["disable_registration"]))
)
self.registration_shared_secret = config.get("registration_shared_secret")
- self.macaroon_secret_key = config.get("macaroon_secret_key")
+
self.bcrypt_rounds = config.get("bcrypt_rounds", 12)
+ self.trusted_third_party_id_servers = config["trusted_third_party_id_servers"]
self.allow_guest_access = config.get("allow_guest_access", False)
def default_config(self, **kwargs):
registration_shared_secret = random_string_with_symbols(50)
- macaroon_secret_key = random_string_with_symbols(50)
+
return """\
## Registration ##
@@ -49,8 +50,6 @@ class RegistrationConfig(Config):
# secret, even if registration is otherwise disabled.
registration_shared_secret: "%(registration_shared_secret)s"
- macaroon_secret_key: "%(macaroon_secret_key)s"
-
# Set the number of bcrypt rounds used to generate password hash.
# Larger numbers increase the work factor needed to generate the hash.
# The default number of rounds is 12.
@@ -60,6 +59,12 @@ class RegistrationConfig(Config):
# participate in rooms hosted on this server which have been made
# accessible to anonymous users.
allow_guest_access: False
+
+ # The list of identity servers trusted to verify third party
+ # identifiers by this server.
+ trusted_third_party_id_servers:
+ - matrix.org
+ - vector.im
""" % locals()
def add_arguments(self, parser):
@@ -71,6 +76,6 @@ class RegistrationConfig(Config):
def read_arguments(self, args):
if args.enable_registration is not None:
- self.disable_registration = not bool(
+ self.enable_registration = bool(
strtobool(str(args.enable_registration))
)
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index cddec0b2bc..d08ee0aa91 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -18,6 +18,10 @@ from synapse.api.errors import SynapseError, Codes
from synapse.util.retryutils import get_retry_limiter
from synapse.util import unwrapFirstError
from synapse.util.async import ObservableDeferred
+from synapse.util.logcontext import (
+ preserve_context_over_deferred, preserve_context_over_fn, PreserveLoggingContext,
+ preserve_fn
+)
from twisted.internet import defer
@@ -142,40 +146,43 @@ class Keyring(object):
for server_name, _ in server_and_json
}
- # We want to wait for any previous lookups to complete before
- # proceeding.
- wait_on_deferred = self.wait_for_previous_lookups(
- [server_name for server_name, _ in server_and_json],
- server_to_deferred,
- )
+ with PreserveLoggingContext():
- # Actually start fetching keys.
- wait_on_deferred.addBoth(
- lambda _: self.get_server_verify_keys(group_id_to_group, deferreds)
- )
+ # We want to wait for any previous lookups to complete before
+ # proceeding.
+ wait_on_deferred = self.wait_for_previous_lookups(
+ [server_name for server_name, _ in server_and_json],
+ server_to_deferred,
+ )
- # When we've finished fetching all the keys for a given server_name,
- # resolve the deferred passed to `wait_for_previous_lookups` so that
- # any lookups waiting will proceed.
- server_to_gids = {}
+ # Actually start fetching keys.
+ wait_on_deferred.addBoth(
+ lambda _: self.get_server_verify_keys(group_id_to_group, deferreds)
+ )
+
+ # When we've finished fetching all the keys for a given server_name,
+ # resolve the deferred passed to `wait_for_previous_lookups` so that
+ # any lookups waiting will proceed.
+ server_to_gids = {}
- def remove_deferreds(res, server_name, group_id):
- server_to_gids[server_name].discard(group_id)
- if not server_to_gids[server_name]:
- d = server_to_deferred.pop(server_name, None)
- if d:
- d.callback(None)
- return res
+ def remove_deferreds(res, server_name, group_id):
+ server_to_gids[server_name].discard(group_id)
+ if not server_to_gids[server_name]:
+ d = server_to_deferred.pop(server_name, None)
+ if d:
+ d.callback(None)
+ return res
- for g_id, deferred in deferreds.items():
- server_name = group_id_to_group[g_id].server_name
- server_to_gids.setdefault(server_name, set()).add(g_id)
- deferred.addBoth(remove_deferreds, server_name, g_id)
+ for g_id, deferred in deferreds.items():
+ server_name = group_id_to_group[g_id].server_name
+ server_to_gids.setdefault(server_name, set()).add(g_id)
+ deferred.addBoth(remove_deferreds, server_name, g_id)
# Pass those keys to handle_key_deferred so that the json object
# signatures can be verified
return [
- handle_key_deferred(
+ preserve_context_over_fn(
+ handle_key_deferred,
group_id_to_group[g_id],
deferreds[g_id],
)
@@ -198,12 +205,13 @@ class Keyring(object):
if server_name in self.key_downloads
]
if wait_on:
- yield defer.DeferredList(wait_on)
+ with PreserveLoggingContext():
+ yield defer.DeferredList(wait_on)
else:
break
for server_name, deferred in server_to_deferred.items():
- d = ObservableDeferred(deferred)
+ d = ObservableDeferred(preserve_context_over_deferred(deferred))
self.key_downloads[server_name] = d
def rm(r, server_name):
@@ -244,12 +252,13 @@ class Keyring(object):
for group in group_id_to_group.values():
for key_id in group.key_ids:
if key_id in merged_results[group.server_name]:
- group_id_to_deferred[group.group_id].callback((
- group.group_id,
- group.server_name,
- key_id,
- merged_results[group.server_name][key_id],
- ))
+ with PreserveLoggingContext():
+ group_id_to_deferred[group.group_id].callback((
+ group.group_id,
+ group.server_name,
+ key_id,
+ merged_results[group.server_name][key_id],
+ ))
break
else:
missing_groups.setdefault(
@@ -504,7 +513,7 @@ class Keyring(object):
yield defer.gatherResults(
[
- self.store_keys(
+ preserve_fn(self.store_keys)(
server_name=key_server_name,
from_server=server_name,
verify_keys=verify_keys,
@@ -573,7 +582,7 @@ class Keyring(object):
yield defer.gatherResults(
[
- self.store.store_server_keys_json(
+ preserve_fn(self.store.store_server_keys_json)(
server_name=server_name,
key_id=key_id,
from_server=server_name,
@@ -675,7 +684,7 @@ class Keyring(object):
# TODO(markjh): Store whether the keys have expired.
yield defer.gatherResults(
[
- self.store.store_server_verify_key(
+ preserve_fn(self.store.store_server_verify_key)(
server_name, server_name, key.time_added, key
)
for key_id, key in verify_keys.items()
diff --git a/synapse/federation/__init__.py b/synapse/federation/__init__.py
index 0bfb79d09f..979fdf2431 100644
--- a/synapse/federation/__init__.py
+++ b/synapse/federation/__init__.py
@@ -17,15 +17,10 @@
"""
from .replication import ReplicationLayer
-from .transport import TransportLayer
+from .transport.client import TransportLayerClient
def initialize_http_replication(homeserver):
- transport = TransportLayer(
- homeserver,
- homeserver.hostname,
- server=homeserver.get_resource_for_federation(),
- client=homeserver.get_http_client()
- )
+ transport = TransportLayerClient(homeserver)
return ReplicationLayer(homeserver, transport)
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index c6259f9dc8..e30e2da58d 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -57,7 +57,7 @@ class FederationClient(FederationBase):
cache_name="get_pdu_cache",
clock=self._clock,
max_len=1000,
- expiry_ms=120*1000,
+ expiry_ms=120 * 1000,
reset_expiry_on_get=False,
)
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index a97aa0c94a..90718192dd 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -126,10 +126,8 @@ class FederationServer(FederationBase):
results = []
for pdu in pdu_list:
- d = self._handle_new_pdu(transaction.origin, pdu)
-
try:
- yield d
+ yield self._handle_new_pdu(transaction.origin, pdu)
results.append({})
except FederationError as e:
self.send_failure(e, transaction.origin)
diff --git a/synapse/federation/replication.py b/synapse/federation/replication.py
index 6e0be8ef15..3e062a5eab 100644
--- a/synapse/federation/replication.py
+++ b/synapse/federation/replication.py
@@ -54,8 +54,6 @@ class ReplicationLayer(FederationClient, FederationServer):
self.keyring = hs.get_keyring()
self.transport_layer = transport_layer
- self.transport_layer.register_received_handler(self)
- self.transport_layer.register_request_handler(self)
self.federation_client = self
diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py
index 622adad3ae..1928da03b3 100644
--- a/synapse/federation/transaction_queue.py
+++ b/synapse/federation/transaction_queue.py
@@ -103,7 +103,6 @@ class TransactionQueue(object):
else:
return not destination.startswith("localhost")
- @defer.inlineCallbacks
def enqueue_pdu(self, pdu, destinations, order):
# We loop through all destinations to see whether we already have
# a transaction in progress. If we do, stick it in the pending_pdus
@@ -141,8 +140,6 @@ class TransactionQueue(object):
deferreds.append(deferred)
- yield defer.DeferredList(deferreds, consumeErrors=True)
-
# NO inlineCallbacks
def enqueue_edu(self, edu):
destination = edu.destination
diff --git a/synapse/federation/transport/__init__.py b/synapse/federation/transport/__init__.py
index 155a7d5870..d9fcc520a0 100644
--- a/synapse/federation/transport/__init__.py
+++ b/synapse/federation/transport/__init__.py
@@ -20,55 +20,3 @@ By default this is done over HTTPS (and all home servers are required to
support HTTPS), however individual pairings of servers may decide to
communicate over a different (albeit still reliable) protocol.
"""
-
-from .server import TransportLayerServer
-from .client import TransportLayerClient
-
-from synapse.util.ratelimitutils import FederationRateLimiter
-
-
-class TransportLayer(TransportLayerServer, TransportLayerClient):
- """This is a basic implementation of the transport layer that translates
- transactions and other requests to/from HTTP.
-
- Attributes:
- server_name (str): Local home server host
-
- server (synapse.http.server.HttpServer): the http server to
- register listeners on
-
- client (synapse.http.client.HttpClient): the http client used to
- send requests
-
- request_handler (TransportRequestHandler): The handler to fire when we
- receive requests for data.
-
- received_handler (TransportReceivedHandler): The handler to fire when
- we receive data.
- """
-
- def __init__(self, homeserver, server_name, server, client):
- """
- Args:
- server_name (str): Local home server host
- server (synapse.protocol.http.HttpServer): the http server to
- register listeners on
- client (synapse.protocol.http.HttpClient): the http client used to
- send requests
- """
- self.keyring = homeserver.get_keyring()
- self.clock = homeserver.get_clock()
- self.server_name = server_name
- self.server = server
- self.client = client
- self.request_handler = None
- self.received_handler = None
-
- self.ratelimiter = FederationRateLimiter(
- self.clock,
- window_size=homeserver.config.federation_rc_window_size,
- sleep_limit=homeserver.config.federation_rc_sleep_limit,
- sleep_msec=homeserver.config.federation_rc_sleep_delay,
- reject_limit=homeserver.config.federation_rc_reject_limit,
- concurrent_requests=homeserver.config.federation_rc_concurrent,
- )
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 949d01dea8..2b5d40ea7f 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -28,6 +28,10 @@ logger = logging.getLogger(__name__)
class TransportLayerClient(object):
"""Sends federation HTTP requests to other servers"""
+ def __init__(self, hs):
+ self.server_name = hs.hostname
+ self.client = hs.get_http_client()
+
@log_function
def get_room_state(self, destination, room_id, event_id):
""" Requests all state for a given room from the given server at the
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 8dca0a7f6b..65e054f7dd 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -17,7 +17,8 @@ from twisted.internet import defer
from synapse.api.urls import FEDERATION_PREFIX as PREFIX
from synapse.api.errors import Codes, SynapseError
-from synapse.util.logutils import log_function
+from synapse.http.server import JsonResource
+from synapse.util.ratelimitutils import FederationRateLimiter
import functools
import logging
@@ -28,9 +29,41 @@ import re
logger = logging.getLogger(__name__)
-class TransportLayerServer(object):
+class TransportLayerServer(JsonResource):
"""Handles incoming federation HTTP requests"""
+ def __init__(self, hs):
+ self.hs = hs
+ self.clock = hs.get_clock()
+
+ super(TransportLayerServer, self).__init__(hs)
+
+ self.authenticator = Authenticator(hs)
+ self.ratelimiter = FederationRateLimiter(
+ self.clock,
+ window_size=hs.config.federation_rc_window_size,
+ sleep_limit=hs.config.federation_rc_sleep_limit,
+ sleep_msec=hs.config.federation_rc_sleep_delay,
+ reject_limit=hs.config.federation_rc_reject_limit,
+ concurrent_requests=hs.config.federation_rc_concurrent,
+ )
+
+ self.register_servlets()
+
+ def register_servlets(self):
+ register_servlets(
+ self.hs,
+ resource=self,
+ ratelimiter=self.ratelimiter,
+ authenticator=self.authenticator,
+ )
+
+
+class Authenticator(object):
+ def __init__(self, hs):
+ self.keyring = hs.get_keyring()
+ self.server_name = hs.hostname
+
# A method just so we can pass 'self' as the authenticator to the Servlets
@defer.inlineCallbacks
def authenticate_request(self, request):
@@ -98,37 +131,9 @@ class TransportLayerServer(object):
defer.returnValue((origin, content))
- @log_function
- def register_received_handler(self, handler):
- """ Register a handler that will be fired when we receive data.
-
- Args:
- handler (TransportReceivedHandler)
- """
- FederationSendServlet(
- handler,
- authenticator=self,
- ratelimiter=self.ratelimiter,
- server_name=self.server_name,
- ).register(self.server)
-
- @log_function
- def register_request_handler(self, handler):
- """ Register a handler that will be fired when we get asked for data.
-
- Args:
- handler (TransportRequestHandler)
- """
- for servletclass in SERVLET_CLASSES:
- servletclass(
- handler,
- authenticator=self,
- ratelimiter=self.ratelimiter,
- ).register(self.server)
-
class BaseFederationServlet(object):
- def __init__(self, handler, authenticator, ratelimiter):
+ def __init__(self, handler, authenticator, ratelimiter, server_name):
self.handler = handler
self.authenticator = authenticator
self.ratelimiter = ratelimiter
@@ -172,7 +177,9 @@ class FederationSendServlet(BaseFederationServlet):
PATH = "/send/([^/]*)/"
def __init__(self, handler, server_name, **kwargs):
- super(FederationSendServlet, self).__init__(handler, **kwargs)
+ super(FederationSendServlet, self).__init__(
+ handler, server_name=server_name, **kwargs
+ )
self.server_name = server_name
# This is when someone is trying to send us a bunch of data.
@@ -432,6 +439,7 @@ class On3pidBindServlet(BaseFederationServlet):
SERVLET_CLASSES = (
+ FederationSendServlet,
FederationPullServlet,
FederationEventServlet,
FederationStateServlet,
@@ -451,3 +459,13 @@ SERVLET_CLASSES = (
FederationThirdPartyInviteExchangeServlet,
On3pidBindServlet,
)
+
+
+def register_servlets(hs, resource, authenticator, ratelimiter):
+ for servletclass in SERVLET_CLASSES:
+ servletclass(
+ handler=hs.get_replication_layer(),
+ authenticator=authenticator,
+ ratelimiter=ratelimiter,
+ server_name=hs.hostname,
+ ).register(resource)
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index 744a9ee507..fa83d3e464 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -147,7 +147,7 @@ class BaseHandler(object):
)
if not allowed:
raise LimitExceededError(
- retry_after_ms=int(1000*(time_allowed - time_now)),
+ retry_after_ms=int(1000 * (time_allowed - time_now)),
)
@defer.inlineCallbacks
@@ -293,19 +293,11 @@ class BaseHandler(object):
with PreserveLoggingContext():
# Don't block waiting on waking up all the listeners.
- notify_d = self.notifier.on_new_room_event(
+ self.notifier.on_new_room_event(
event, event_stream_id, max_stream_id,
extra_users=extra_users
)
- def log_failure(f):
- logger.warn(
- "Failed to notify about %s: %s",
- event.event_id, f.value
- )
-
- notify_d.addErrback(log_failure)
-
# If invite, remove room_state from unsigned before sending.
event.unsigned.pop("invite_room_state", None)
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index 691564c651..4efecb1ffd 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -175,8 +175,8 @@ class DirectoryHandler(BaseHandler):
# If this server is in the list of servers, return it first.
if self.server_name in servers:
servers = (
- [self.server_name]
- + [s for s in servers if s != self.server_name]
+ [self.server_name] +
+ [s for s in servers if s != self.server_name]
)
else:
servers = list(servers)
diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index 254b483da6..4933c31c19 100644
--- a/synapse/handlers/events.py
+++ b/synapse/handlers/events.py
@@ -18,6 +18,7 @@ from twisted.internet import defer
from synapse.util.logutils import log_function
from synapse.types import UserID
from synapse.events.utils import serialize_event
+from synapse.util.logcontext import preserve_context_over_fn
from ._base import BaseHandler
@@ -29,11 +30,17 @@ logger = logging.getLogger(__name__)
def started_user_eventstream(distributor, user):
- return distributor.fire("started_user_eventstream", user)
+ return preserve_context_over_fn(
+ distributor.fire,
+ "started_user_eventstream", user
+ )
def stopped_user_eventstream(distributor, user):
- return distributor.fire("stopped_user_eventstream", user)
+ return preserve_context_over_fn(
+ distributor.fire,
+ "stopped_user_eventstream", user
+ )
class EventStreamHandler(BaseHandler):
@@ -130,7 +137,7 @@ class EventStreamHandler(BaseHandler):
# Add some randomness to this value to try and mitigate against
# thundering herds on restart.
- timeout = random.randint(int(timeout*0.9), int(timeout*1.1))
+ timeout = random.randint(int(timeout * 0.9), int(timeout * 1.1))
events, tokens = yield self.notifier.get_events_for(
auth_user, pagin_config, timeout,
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 6c19d6ae8c..b78b0502d9 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -221,19 +221,11 @@ class FederationHandler(BaseHandler):
extra_users.append(target_user)
with PreserveLoggingContext():
- d = self.notifier.on_new_room_event(
+ self.notifier.on_new_room_event(
event, event_stream_id, max_stream_id,
extra_users=extra_users
)
- def log_failure(f):
- logger.warn(
- "Failed to notify about %s: %s",
- event.event_id, f.value
- )
-
- d.addErrback(log_failure)
-
if event.type == EventTypes.Member:
if event.membership == Membership.JOIN:
prev_state = context.current_state.get((event.type, event.state_key))
@@ -643,19 +635,11 @@ class FederationHandler(BaseHandler):
)
with PreserveLoggingContext():
- d = self.notifier.on_new_room_event(
+ self.notifier.on_new_room_event(
event, event_stream_id, max_stream_id,
extra_users=[joinee]
)
- def log_failure(f):
- logger.warn(
- "Failed to notify about %s: %s",
- event.event_id, f.value
- )
-
- d.addErrback(log_failure)
-
logger.debug("Finished joining %s to %s", joinee, room_id)
finally:
room_queue = self.room_queues[room_id]
@@ -730,18 +714,10 @@ class FederationHandler(BaseHandler):
extra_users.append(target_user)
with PreserveLoggingContext():
- d = self.notifier.on_new_room_event(
+ self.notifier.on_new_room_event(
event, event_stream_id, max_stream_id, extra_users=extra_users
)
- def log_failure(f):
- logger.warn(
- "Failed to notify about %s: %s",
- event.event_id, f.value
- )
-
- d.addErrback(log_failure)
-
if event.type == EventTypes.Member:
if event.content["membership"] == Membership.JOIN:
user = UserID.from_string(event.state_key)
@@ -811,19 +787,11 @@ class FederationHandler(BaseHandler):
target_user = UserID.from_string(event.state_key)
with PreserveLoggingContext():
- d = self.notifier.on_new_room_event(
+ self.notifier.on_new_room_event(
event, event_stream_id, max_stream_id,
extra_users=[target_user],
)
- def log_failure(f):
- logger.warn(
- "Failed to notify about %s: %s",
- event.event_id, f.value
- )
-
- d.addErrback(log_failure)
-
defer.returnValue(event)
@defer.inlineCallbacks
@@ -948,18 +916,10 @@ class FederationHandler(BaseHandler):
extra_users.append(target_user)
with PreserveLoggingContext():
- d = self.notifier.on_new_room_event(
+ self.notifier.on_new_room_event(
event, event_stream_id, max_stream_id, extra_users=extra_users
)
- def log_failure(f):
- logger.warn(
- "Failed to notify about %s: %s",
- event.event_id, f.value
- )
-
- d.addErrback(log_failure)
-
new_pdu = event
destinations = set()
@@ -1186,7 +1146,13 @@ class FederationHandler(BaseHandler):
try:
self.auth.check(e, auth_events=auth_for_e)
- except AuthError as err:
+ except SynapseError as err:
+ # we may get SynapseErrors here as well as AuthErrors. For
+ # instance, there are a couple of (ancient) events in some
+ # rooms whose senders do not have the correct sigil; these
+ # cause SynapseErrors in auth.check. We don't want to give up
+ # the attempt to federate altogether in such cases.
+
logger.warn(
"Rejecting %s because %s",
e.event_id, err.msg
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index 819ec57c4f..656ce124f9 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -36,14 +36,15 @@ class IdentityHandler(BaseHandler):
self.http_client = hs.get_simple_http_client()
+ self.trusted_id_servers = set(hs.config.trusted_third_party_id_servers)
+ self.trust_any_id_server_just_for_testing_do_not_use = (
+ hs.config.use_insecure_ssl_client_just_for_testing_do_not_use
+ )
+
@defer.inlineCallbacks
def threepid_from_creds(self, creds):
yield run_on_reactor()
- # XXX: make this configurable!
- # trustedIdServers = ['matrix.org', 'localhost:8090']
- trustedIdServers = ['matrix.org', 'vector.im']
-
if 'id_server' in creds:
id_server = creds['id_server']
elif 'idServer' in creds:
@@ -58,10 +59,19 @@ class IdentityHandler(BaseHandler):
else:
raise SynapseError(400, "No client_secret in creds")
- if id_server not in trustedIdServers:
- logger.warn('%s is not a trusted ID server: rejecting 3pid ' +
- 'credentials', id_server)
- defer.returnValue(None)
+ if id_server not in self.trusted_id_servers:
+ if self.trust_any_id_server_just_for_testing_do_not_use:
+ logger.warn(
+ "Trusting untrustworthy ID server %r even though it isn't"
+ " in the trusted id list for testing because"
+ " 'use_insecure_ssl_client_just_for_testing_do_not_use'"
+ " is set in the config",
+ id_server,
+ )
+ else:
+ logger.warn('%s is not a trusted ID server: rejecting 3pid ' +
+ 'credentials', id_server)
+ defer.returnValue(None)
data = {}
try:
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index ff800f8af1..82c8cb5f0c 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -16,7 +16,7 @@
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
-from synapse.api.errors import SynapseError, AuthError, Codes
+from synapse.api.errors import AuthError, Codes
from synapse.streams.config import PaginationConfig
from synapse.events.utils import serialize_event
from synapse.events.validator import EventValidator
@@ -105,8 +105,6 @@ class MessageHandler(BaseHandler):
room_token = pagin_config.from_token.room_key
room_token = RoomStreamToken.parse(room_token)
- if room_token.topological is None:
- raise SynapseError(400, "Invalid token")
pagin_config.from_token = pagin_config.from_token.copy_and_replace(
"room_key", str(room_token)
@@ -117,27 +115,31 @@ class MessageHandler(BaseHandler):
membership, member_event_id = yield self._check_in_room_or_world_readable(
room_id, user_id
)
- if membership == Membership.LEAVE:
- # If they have left the room then clamp the token to be before
- # they left the room.
- leave_token = yield self.store.get_topological_token_for_event(
- member_event_id
+
+ if source_config.direction == 'b':
+ # if we're going backwards, we might need to backfill. This
+ # requires that we have a topo token.
+ if room_token.topological:
+ max_topo = room_token.topological
+ else:
+ max_topo = yield self.store.get_max_topological_token_for_stream_and_room(
+ room_id, room_token.stream
+ )
+
+ if membership == Membership.LEAVE:
+ # If they have left the room then clamp the token to be before
+ # they left the room, to save the effort of loading from the
+ # database.
+ leave_token = yield self.store.get_topological_token_for_event(
+ member_event_id
+ )
+ leave_token = RoomStreamToken.parse(leave_token)
+ if leave_token.topological < max_topo:
+ source_config.from_key = str(leave_token)
+
+ yield self.hs.get_handlers().federation_handler.maybe_backfill(
+ room_id, max_topo
)
- leave_token = RoomStreamToken.parse(leave_token)
- if leave_token.topological < room_token.topological:
- source_config.from_key = str(leave_token)
-
- if source_config.direction == "f":
- if source_config.to_key is None:
- source_config.to_key = str(leave_token)
- else:
- to_token = RoomStreamToken.parse(source_config.to_key)
- if leave_token.topological < to_token.topological:
- source_config.to_key = str(leave_token)
-
- yield self.hs.get_handlers().federation_handler.maybe_backfill(
- room_id, room_token.topological
- )
events, next_key = yield data_source.get_pagination_rows(
requester.user, source_config, room_id
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index d36eb3b8d7..b61394f2b5 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -34,7 +34,7 @@ metrics = synapse.metrics.get_metrics_for(__name__)
# Don't bother bumping "last active" time if it differs by less than 60 seconds
-LAST_ACTIVE_GRANULARITY = 60*1000
+LAST_ACTIVE_GRANULARITY = 60 * 1000
# Keep no more than this number of offline serial revisions
MAX_OFFLINE_SERIALS = 1000
@@ -378,9 +378,9 @@ class PresenceHandler(BaseHandler):
was_polling = target_user in self._user_cachemap
if now_online and not was_polling:
- self.start_polling_presence(target_user, state=state)
+ yield self.start_polling_presence(target_user, state=state)
elif not now_online and was_polling:
- self.stop_polling_presence(target_user)
+ yield self.stop_polling_presence(target_user)
# TODO(paul): perform a presence push as part of start/stop poll so
# we don't have to do this all the time
@@ -394,7 +394,8 @@ class PresenceHandler(BaseHandler):
if now - prev_state.state.get("last_active", 0) < LAST_ACTIVE_GRANULARITY:
return
- self.changed_presencelike_data(user, {"last_active": now})
+ with PreserveLoggingContext():
+ self.changed_presencelike_data(user, {"last_active": now})
def get_joined_rooms_for_user(self, user):
"""Get the list of rooms a user is joined to.
@@ -466,11 +467,12 @@ class PresenceHandler(BaseHandler):
local_user, room_ids=[room_id], add_to_cache=False
)
- self.push_update_to_local_and_remote(
- observed_user=local_user,
- users_to_push=[user],
- statuscache=statuscache,
- )
+ with PreserveLoggingContext():
+ self.push_update_to_local_and_remote(
+ observed_user=local_user,
+ users_to_push=[user],
+ statuscache=statuscache,
+ )
@defer.inlineCallbacks
def send_presence_invite(self, observer_user, observed_user):
@@ -556,7 +558,7 @@ class PresenceHandler(BaseHandler):
observer_user.localpart, observed_user.to_string()
)
- self.start_polling_presence(
+ yield self.start_polling_presence(
observer_user, target_user=observed_user
)
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index c11b98d0b7..24c850ae9b 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -21,7 +21,6 @@ from synapse.api.errors import (
AuthError, Codes, SynapseError, RegistrationError, InvalidCaptchaError
)
from ._base import BaseHandler
-import synapse.util.stringutils as stringutils
from synapse.util.async import run_on_reactor
from synapse.http.client import CaptchaServerHttpClient
@@ -45,6 +44,8 @@ class RegistrationHandler(BaseHandler):
self.distributor.declare("registered_user")
self.captcha_client = CaptchaServerHttpClient(hs)
+ self._next_generated_user_id = None
+
@defer.inlineCallbacks
def check_username(self, localpart, guest_access_token=None):
yield run_on_reactor()
@@ -91,7 +92,7 @@ class RegistrationHandler(BaseHandler):
Args:
localpart : The local part of the user ID to register. If None,
- one will be randomly generated.
+ one will be generated.
password (str) : The password to assign to this user so they can
login again. This can be None which means they cannot login again
via a password (e.g. the user is an application service user).
@@ -108,6 +109,18 @@ class RegistrationHandler(BaseHandler):
if localpart:
yield self.check_username(localpart, guest_access_token=guest_access_token)
+ was_guest = guest_access_token is not None
+
+ if not was_guest:
+ try:
+ int(localpart)
+ raise RegistrationError(
+ 400,
+ "Numeric user IDs are reserved for guest users."
+ )
+ except ValueError:
+ pass
+
user = UserID(localpart, self.hs.hostname)
user_id = user.to_string()
@@ -118,38 +131,36 @@ class RegistrationHandler(BaseHandler):
user_id=user_id,
token=token,
password_hash=password_hash,
- was_guest=guest_access_token is not None,
+ was_guest=was_guest,
make_guest=make_guest,
)
yield registered_user(self.distributor, user)
else:
- # autogen a random user ID
+ # autogen a sequential user ID
attempts = 0
- user_id = None
token = None
- while not user_id:
+ user = None
+ while not user:
+ localpart = yield self._generate_user_id(attempts > 0)
+ user = UserID(localpart, self.hs.hostname)
+ user_id = user.to_string()
+ yield self.check_user_id_is_valid(user_id)
+ if generate_token:
+ token = self.auth_handler().generate_access_token(user_id)
try:
- localpart = self._generate_user_id()
- user = UserID(localpart, self.hs.hostname)
- user_id = user.to_string()
- yield self.check_user_id_is_valid(user_id)
- if generate_token:
- token = self.auth_handler().generate_access_token(user_id)
yield self.store.register(
user_id=user_id,
token=token,
- password_hash=password_hash)
-
- yield registered_user(self.distributor, user)
+ password_hash=password_hash,
+ make_guest=make_guest
+ )
except SynapseError:
# if user id is taken, just generate another
user_id = None
token = None
attempts += 1
- if attempts > 5:
- raise RegistrationError(
- 500, "Cannot generate user ID.")
+ yield registered_user(self.distributor, user)
# We used to generate default identicons here, but nowadays
# we want clients to generate their own as part of their branding
@@ -175,7 +186,7 @@ class RegistrationHandler(BaseHandler):
token=token,
password_hash=""
)
- registered_user(self.distributor, user)
+ yield registered_user(self.distributor, user)
defer.returnValue((user_id, token))
@defer.inlineCallbacks
@@ -211,7 +222,7 @@ class RegistrationHandler(BaseHandler):
400,
"User ID must only contain characters which do not"
" require URL encoding."
- )
+ )
user = UserID(localpart, self.hs.hostname)
user_id = user.to_string()
@@ -281,8 +292,16 @@ class RegistrationHandler(BaseHandler):
errcode=Codes.EXCLUSIVE
)
- def _generate_user_id(self):
- return "-" + stringutils.random_string(18)
+ @defer.inlineCallbacks
+ def _generate_user_id(self, reseed=False):
+ if reseed or self._next_generated_user_id is None:
+ self._next_generated_user_id = (
+ yield self.store.find_next_generated_user_id_localpart()
+ )
+
+ id = self._next_generated_user_id
+ self._next_generated_user_id += 1
+ defer.returnValue(str(id))
@defer.inlineCallbacks
def _validate_captcha(self, ip_addr, private_key, challenge, response):
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 58e2d25f97..a8e3a9029c 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -18,13 +18,14 @@ from twisted.internet import defer
from ._base import BaseHandler
-from synapse.types import UserID, RoomAlias, RoomID
+from synapse.types import UserID, RoomAlias, RoomID, RoomStreamToken
from synapse.api.constants import (
EventTypes, Membership, JoinRules, RoomCreationPreset,
)
from synapse.api.errors import AuthError, StoreError, SynapseError, Codes
from synapse.util import stringutils, unwrapFirstError
from synapse.util.async import run_on_reactor
+from synapse.util.logcontext import preserve_context_over_fn
from signedjson.sign import verify_signed_json
from signedjson.key import decode_verify_key_bytes
@@ -46,11 +47,17 @@ def collect_presencelike_data(distributor, user, content):
def user_left_room(distributor, user, room_id):
- return distributor.fire("user_left_room", user=user, room_id=room_id)
+ return preserve_context_over_fn(
+ distributor.fire,
+ "user_left_room", user=user, room_id=room_id
+ )
def user_joined_room(distributor, user, room_id):
- return distributor.fire("user_joined_room", user=user, room_id=room_id)
+ return preserve_context_over_fn(
+ distributor.fire,
+ "user_joined_room", user=user, room_id=room_id
+ )
class RoomCreationHandler(BaseHandler):
@@ -876,39 +883,71 @@ class RoomListHandler(BaseHandler):
@defer.inlineCallbacks
def get_public_room_list(self):
- chunk = yield self.store.get_rooms(is_public=True)
-
- room_members = yield defer.gatherResults(
- [
- self.store.get_users_in_room(room["room_id"])
- for room in chunk
- ],
- consumeErrors=True,
- ).addErrback(unwrapFirstError)
-
- avatar_urls = yield defer.gatherResults(
- [
- self.get_room_avatar_url(room["room_id"])
- for room in chunk
- ],
- consumeErrors=True,
- ).addErrback(unwrapFirstError)
-
- for i, room in enumerate(chunk):
- room["num_joined_members"] = len(room_members[i])
- if avatar_urls[i]:
- room["avatar_url"] = avatar_urls[i]
+ room_ids = yield self.store.get_public_room_ids()
+
+ @defer.inlineCallbacks
+ def handle_room(room_id):
+ aliases = yield self.store.get_aliases_for_room(room_id)
+ if not aliases:
+ defer.returnValue(None)
+
+ state = yield self.state_handler.get_current_state(room_id)
+
+ result = {"aliases": aliases, "room_id": room_id}
+
+ name_event = state.get((EventTypes.Name, ""), None)
+ if name_event:
+ name = name_event.content.get("name", None)
+ if name:
+ result["name"] = name
+
+ topic_event = state.get((EventTypes.Topic, ""), None)
+ if topic_event:
+ topic = topic_event.content.get("topic", None)
+ if topic:
+ result["topic"] = topic
+
+ canonical_event = state.get((EventTypes.CanonicalAlias, ""), None)
+ if canonical_event:
+ canonical_alias = canonical_event.content.get("alias", None)
+ if canonical_alias:
+ result["canonical_alias"] = canonical_alias
+
+ visibility_event = state.get((EventTypes.RoomHistoryVisibility, ""), None)
+ visibility = None
+ if visibility_event:
+ visibility = visibility_event.content.get("history_visibility", None)
+ result["world_readable"] = visibility == "world_readable"
+
+ guest_event = state.get((EventTypes.GuestAccess, ""), None)
+ guest = None
+ if guest_event:
+ guest = guest_event.content.get("guest_access", None)
+ result["guest_can_join"] = guest == "can_join"
+
+ avatar_event = state.get(("m.room.avatar", ""), None)
+ if avatar_event:
+ avatar_url = avatar_event.content.get("url", None)
+ if avatar_url:
+ result["avatar_url"] = avatar_url
+
+ result["num_joined_members"] = sum(
+ 1 for (event_type, _), ev in state.items()
+ if event_type == EventTypes.Member and ev.membership == Membership.JOIN
+ )
- # FIXME (erikj): START is no longer a valid value
- defer.returnValue({"start": "START", "end": "END", "chunk": chunk})
+ defer.returnValue(result)
- @defer.inlineCallbacks
- def get_room_avatar_url(self, room_id):
- event = yield self.hs.get_state_handler().get_current_state(
- room_id, "m.room.avatar"
- )
- if event and "url" in event.content:
- defer.returnValue(event.content["url"])
+ result = []
+ for chunk in (room_ids[i:i + 10] for i in xrange(0, len(room_ids), 10)):
+ chunk_result = yield defer.gatherResults([
+ handle_room(room_id)
+ for room_id in chunk
+ ], consumeErrors=True).addErrback(unwrapFirstError)
+ result.extend(v for v in chunk_result if v)
+
+ # FIXME (erikj): START is no longer a valid value
+ defer.returnValue({"start": "START", "end": "END", "chunk": result})
class RoomContextHandler(BaseHandler):
@@ -927,7 +966,7 @@ class RoomContextHandler(BaseHandler):
Returns:
dict, or None if the event isn't found
"""
- before_limit = math.floor(limit/2.)
+ before_limit = math.floor(limit / 2.)
after_limit = limit - before_limit
now_token = yield self.hs.get_event_sources().get_current_token()
@@ -997,6 +1036,11 @@ class RoomEventSource(object):
to_key = yield self.get_current_key()
+ from_token = RoomStreamToken.parse(from_key)
+ if from_token.topological:
+ logger.warn("Stream has topological part!!!! %r", from_key)
+ from_key = "s%s" % (from_token.stream,)
+
app_service = yield self.store.get_app_service_by_user_id(
user.to_string()
)
@@ -1008,15 +1052,30 @@ class RoomEventSource(object):
limit=limit,
)
else:
- events, end_key = yield self.store.get_room_events_stream(
- user_id=user.to_string(),
+ room_events = yield self.store.get_membership_changes_for_user(
+ user.to_string(), from_key, to_key
+ )
+
+ room_to_events = yield self.store.get_room_events_stream_for_rooms(
+ room_ids=room_ids,
from_key=from_key,
to_key=to_key,
- limit=limit,
- room_ids=room_ids,
- is_guest=is_guest,
+ limit=limit or 10,
)
+ events = list(room_events)
+ events.extend(e for evs, _ in room_to_events.values() for e in evs)
+
+ events.sort(key=lambda e: e.internal_metadata.order)
+
+ if limit:
+ events[:] = events[:limit]
+
+ if events:
+ end_key = events[-1].internal_metadata.after
+ else:
+ end_key = to_key
+
defer.returnValue((events, end_key))
def get_current_key(self, direction='f'):
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 328c049b03..ddeed27965 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -18,11 +18,14 @@ from ._base import BaseHandler
from synapse.streams.config import PaginationConfig
from synapse.api.constants import Membership, EventTypes
from synapse.util import unwrapFirstError
+from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
+from synapse.util.metrics import Measure
from twisted.internet import defer
import collections
import logging
+import itertools
logger = logging.getLogger(__name__)
@@ -72,7 +75,7 @@ class JoinedSyncResult(collections.namedtuple("JoinedSyncResult", [
)
-class ArchivedSyncResult(collections.namedtuple("JoinedSyncResult", [
+class ArchivedSyncResult(collections.namedtuple("ArchivedSyncResult", [
"room_id", # str
"timeline", # TimelineBatch
"state", # dict[(str, str), FrozenEvent]
@@ -139,6 +142,15 @@ class SyncHandler(BaseHandler):
A Deferred SyncResult.
"""
+ context = LoggingContext.current_context()
+ if context:
+ if since_token is None:
+ context.tag = "initial_sync"
+ elif full_state:
+ context.tag = "full_state_sync"
+ else:
+ context.tag = "incremental_sync"
+
if timeout == 0 or since_token is None or full_state:
# we are going to return immediately, so don't bother calling
# notifier.wait_for_events.
@@ -167,18 +179,6 @@ class SyncHandler(BaseHandler):
else:
return self.incremental_sync_with_gap(sync_config, since_token)
- def last_read_event_id_for_room_and_user(self, room_id, user_id, ephemeral_by_room):
- if room_id not in ephemeral_by_room:
- return None
- for e in ephemeral_by_room[room_id]:
- if e['type'] != 'm.receipt':
- continue
- for receipt_event_id, val in e['content'].items():
- if 'm.read' in val:
- if user_id in val['m.read']:
- return receipt_event_id
- return None
-
@defer.inlineCallbacks
def full_state_sync(self, sync_config, timeline_since_token):
"""Get a sync for a client which is starting without any state.
@@ -230,15 +230,16 @@ class SyncHandler(BaseHandler):
deferreds = []
for event in room_list:
if event.membership == Membership.JOIN:
- room_sync_deferred = self.full_state_sync_for_joined_room(
- room_id=event.room_id,
- sync_config=sync_config,
- now_token=now_token,
- timeline_since_token=timeline_since_token,
- ephemeral_by_room=ephemeral_by_room,
- tags_by_room=tags_by_room,
- account_data_by_room=account_data_by_room,
- )
+ with PreserveLoggingContext(LoggingContext.current_context()):
+ room_sync_deferred = self.full_state_sync_for_joined_room(
+ room_id=event.room_id,
+ sync_config=sync_config,
+ now_token=now_token,
+ timeline_since_token=timeline_since_token,
+ ephemeral_by_room=ephemeral_by_room,
+ tags_by_room=tags_by_room,
+ account_data_by_room=account_data_by_room,
+ )
room_sync_deferred.addCallback(joined.append)
deferreds.append(room_sync_deferred)
elif event.membership == Membership.INVITE:
@@ -251,15 +252,16 @@ class SyncHandler(BaseHandler):
leave_token = now_token.copy_and_replace(
"room_key", "s%d" % (event.stream_ordering,)
)
- room_sync_deferred = self.full_state_sync_for_archived_room(
- sync_config=sync_config,
- room_id=event.room_id,
- leave_event_id=event.event_id,
- leave_token=leave_token,
- timeline_since_token=timeline_since_token,
- tags_by_room=tags_by_room,
- account_data_by_room=account_data_by_room,
- )
+ with PreserveLoggingContext(LoggingContext.current_context()):
+ room_sync_deferred = self.full_state_sync_for_archived_room(
+ sync_config=sync_config,
+ room_id=event.room_id,
+ leave_event_id=event.event_id,
+ leave_token=leave_token,
+ timeline_since_token=timeline_since_token,
+ tags_by_room=tags_by_room,
+ account_data_by_room=account_data_by_room,
+ )
room_sync_deferred.addCallback(archived.append)
deferreds.append(room_sync_deferred)
@@ -298,46 +300,18 @@ class SyncHandler(BaseHandler):
room_id, sync_config, now_token, since_token=timeline_since_token
)
- notifs = yield self.unread_notifs_for_room_id(
- room_id, sync_config, ephemeral_by_room
+ room_sync = yield self.incremental_sync_with_gap_for_room(
+ room_id, sync_config,
+ now_token=now_token,
+ since_token=timeline_since_token,
+ ephemeral_by_room=ephemeral_by_room,
+ tags_by_room=tags_by_room,
+ account_data_by_room=account_data_by_room,
+ batch=batch,
+ full_state=True,
)
- unread_notifications = {}
- if notifs is not None:
- unread_notifications["notification_count"] = len(notifs)
- unread_notifications["highlight_count"] = len([
- 1 for notif in notifs if _action_has_highlight(notif["actions"])
- ])
-
- current_state = yield self.get_state_at(room_id, now_token)
-
- current_state = {
- (e.type, e.state_key): e
- for e in sync_config.filter_collection.filter_room_state(
- current_state.values()
- )
- }
-
- account_data = self.account_data_for_room(
- room_id, tags_by_room, account_data_by_room
- )
-
- account_data = sync_config.filter_collection.filter_room_account_data(
- account_data
- )
-
- ephemeral = sync_config.filter_collection.filter_room_ephemeral(
- ephemeral_by_room.get(room_id, [])
- )
-
- defer.returnValue(JoinedSyncResult(
- room_id=room_id,
- timeline=batch,
- state=current_state,
- ephemeral=ephemeral,
- account_data=account_data,
- unread_notifications=unread_notifications,
- ))
+ defer.returnValue(room_sync)
def account_data_for_user(self, account_data):
account_data_events = []
@@ -382,91 +356,68 @@ class SyncHandler(BaseHandler):
typing events for that room.
"""
- typing_key = since_token.typing_key if since_token else "0"
-
- rooms = yield self.store.get_rooms_for_user(sync_config.user.to_string())
- room_ids = [room.room_id for room in rooms]
-
- typing_source = self.event_sources.sources["typing"]
- typing, typing_key = yield typing_source.get_new_events(
- user=sync_config.user,
- from_key=typing_key,
- limit=sync_config.filter_collection.ephemeral_limit(),
- room_ids=room_ids,
- is_guest=sync_config.is_guest,
- )
- now_token = now_token.copy_and_replace("typing_key", typing_key)
-
- ephemeral_by_room = {}
+ with Measure(self.clock, "ephemeral_by_room"):
+ typing_key = since_token.typing_key if since_token else "0"
- for event in typing:
- # we want to exclude the room_id from the event, but modifying the
- # result returned by the event source is poor form (it might cache
- # the object)
- room_id = event["room_id"]
- event_copy = {k: v for (k, v) in event.iteritems()
- if k != "room_id"}
- ephemeral_by_room.setdefault(room_id, []).append(event_copy)
+ rooms = yield self.store.get_rooms_for_user(sync_config.user.to_string())
+ room_ids = [room.room_id for room in rooms]
- receipt_key = since_token.receipt_key if since_token else "0"
-
- receipt_source = self.event_sources.sources["receipt"]
- receipts, receipt_key = yield receipt_source.get_new_events(
- user=sync_config.user,
- from_key=receipt_key,
- limit=sync_config.filter_collection.ephemeral_limit(),
- room_ids=room_ids,
- is_guest=sync_config.is_guest,
- )
- now_token = now_token.copy_and_replace("receipt_key", receipt_key)
+ typing_source = self.event_sources.sources["typing"]
+ typing, typing_key = yield typing_source.get_new_events(
+ user=sync_config.user,
+ from_key=typing_key,
+ limit=sync_config.filter_collection.ephemeral_limit(),
+ room_ids=room_ids,
+ is_guest=sync_config.is_guest,
+ )
+ now_token = now_token.copy_and_replace("typing_key", typing_key)
+
+ ephemeral_by_room = {}
+
+ for event in typing:
+ # we want to exclude the room_id from the event, but modifying the
+ # result returned by the event source is poor form (it might cache
+ # the object)
+ room_id = event["room_id"]
+ event_copy = {k: v for (k, v) in event.iteritems()
+ if k != "room_id"}
+ ephemeral_by_room.setdefault(room_id, []).append(event_copy)
+
+ receipt_key = since_token.receipt_key if since_token else "0"
+
+ receipt_source = self.event_sources.sources["receipt"]
+ receipts, receipt_key = yield receipt_source.get_new_events(
+ user=sync_config.user,
+ from_key=receipt_key,
+ limit=sync_config.filter_collection.ephemeral_limit(),
+ room_ids=room_ids,
+ is_guest=sync_config.is_guest,
+ )
+ now_token = now_token.copy_and_replace("receipt_key", receipt_key)
- for event in receipts:
- room_id = event["room_id"]
- # exclude room id, as above
- event_copy = {k: v for (k, v) in event.iteritems()
- if k != "room_id"}
- ephemeral_by_room.setdefault(room_id, []).append(event_copy)
+ for event in receipts:
+ room_id = event["room_id"]
+ # exclude room id, as above
+ event_copy = {k: v for (k, v) in event.iteritems()
+ if k != "room_id"}
+ ephemeral_by_room.setdefault(room_id, []).append(event_copy)
defer.returnValue((now_token, ephemeral_by_room))
- @defer.inlineCallbacks
def full_state_sync_for_archived_room(self, room_id, sync_config,
leave_event_id, leave_token,
timeline_since_token, tags_by_room,
account_data_by_room):
"""Sync a room for a client which is starting without any state
Returns:
- A Deferred JoinedSyncResult.
+ A Deferred ArchivedSyncResult.
"""
- batch = yield self.load_filtered_recents(
- room_id, sync_config, leave_token, since_token=timeline_since_token
- )
-
- leave_state = yield self.store.get_state_for_event(leave_event_id)
-
- leave_state = {
- (e.type, e.state_key): e
- for e in sync_config.filter_collection.filter_room_state(
- leave_state.values()
- )
- }
-
- account_data = self.account_data_for_room(
- room_id, tags_by_room, account_data_by_room
- )
-
- account_data = sync_config.filter_collection.filter_room_account_data(
- account_data
+ return self.incremental_sync_for_archived_room(
+ sync_config, room_id, leave_event_id, timeline_since_token, tags_by_room,
+ account_data_by_room, full_state=True, leave_token=leave_token,
)
- defer.returnValue(ArchivedSyncResult(
- room_id=room_id,
- timeline=batch,
- state=leave_state,
- account_data=account_data,
- ))
-
@defer.inlineCallbacks
def incremental_sync_with_gap(self, sync_config, since_token):
""" Get the incremental delta needed to bring the client up to
@@ -489,13 +440,6 @@ class SyncHandler(BaseHandler):
)
now_token = now_token.copy_and_replace("presence_key", presence_key)
- # We now fetch all ephemeral events for this room in order to get
- # this users current read receipt. This could almost certainly be
- # optimised.
- _, all_ephemeral_by_room = yield self.ephemeral_by_room(
- sync_config, now_token
- )
-
now_token, ephemeral_by_room = yield self.ephemeral_by_room(
sync_config, now_token, since_token
)
@@ -512,154 +456,126 @@ class SyncHandler(BaseHandler):
sync_config.user
)
- timeline_limit = sync_config.filter_collection.timeline_limit()
+ user_id = sync_config.user.to_string()
- room_events, _ = yield self.store.get_room_events_stream(
- sync_config.user.to_string(),
- from_key=since_token.room_key,
- to_key=now_token.room_key,
- limit=timeline_limit + 1,
- )
+ timeline_limit = sync_config.filter_collection.timeline_limit()
tags_by_room = yield self.store.get_updated_tags(
- sync_config.user.to_string(),
+ user_id,
since_token.account_data_key,
)
account_data, account_data_by_room = (
yield self.store.get_updated_account_data_for_user(
- sync_config.user.to_string(),
+ user_id,
since_token.account_data_key,
)
)
- joined = []
- archived = []
- if len(room_events) <= timeline_limit:
- # There is no gap in any of the rooms. Therefore we can just
- # partition the new events by room and return them.
- logger.debug("Got %i events for incremental sync - not limited",
- len(room_events))
-
- invite_events = []
- leave_events = []
- events_by_room_id = {}
- for event in room_events:
- events_by_room_id.setdefault(event.room_id, []).append(event)
- if event.room_id not in joined_room_ids:
- if (event.type == EventTypes.Member
- and event.state_key == sync_config.user.to_string()):
- if event.membership == Membership.INVITE:
- invite_events.append(event)
- elif event.membership in (Membership.LEAVE, Membership.BAN):
- leave_events.append(event)
-
- for room_id in joined_room_ids:
- recents = events_by_room_id.get(room_id, [])
- logger.debug("Events for room %s: %r", room_id, recents)
- state = {
- (event.type, event.state_key): event
- for event in recents if event.is_state()}
- limited = False
-
- if recents:
- prev_batch = now_token.copy_and_replace(
- "room_key", recents[0].internal_metadata.before
- )
- else:
- prev_batch = now_token
-
- just_joined = yield self.check_joined_room(sync_config, state)
- if just_joined:
- logger.debug("User has just joined %s: needs full state",
- room_id)
- state = yield self.get_state_at(room_id, now_token)
- # the timeline is inherently limited if we've just joined
- limited = True
-
- recents = sync_config.filter_collection.filter_room_timeline(recents)
-
- state = {
- (e.type, e.state_key): e
- for e in sync_config.filter_collection.filter_room_state(
- state.values()
- )
- }
-
- acc_data = self.account_data_for_room(
- room_id, tags_by_room, account_data_by_room
- )
+ # Get a list of membership change events that have happened.
+ rooms_changed = yield self.store.get_membership_changes_for_user(
+ user_id, since_token.room_key, now_token.room_key
+ )
- acc_data = sync_config.filter_collection.filter_room_account_data(
- acc_data
- )
+ mem_change_events_by_room_id = {}
+ for event in rooms_changed:
+ mem_change_events_by_room_id.setdefault(event.room_id, []).append(event)
- ephemeral = sync_config.filter_collection.filter_room_ephemeral(
- ephemeral_by_room.get(room_id, [])
- )
+ newly_joined_rooms = []
+ archived = []
+ invited = []
+ for room_id, events in mem_change_events_by_room_id.items():
+ non_joins = [e for e in events if e.membership != Membership.JOIN]
+ has_join = len(non_joins) != len(events)
+
+ # We want to figure out if we joined the room at some point since
+ # the last sync (even if we have since left). This is to make sure
+ # we do send down the room, and with full state, where necessary
+ if room_id in joined_room_ids or has_join:
+ old_state = yield self.get_state_at(room_id, since_token)
+ old_mem_ev = old_state.get((EventTypes.Member, user_id), None)
+ if not old_mem_ev or old_mem_ev.membership != Membership.JOIN:
+ newly_joined_rooms.append(room_id)
+
+ if room_id in joined_room_ids:
+ continue
+
+ if not non_joins:
+ continue
- room_sync = JoinedSyncResult(
- room_id=room_id,
- timeline=TimelineBatch(
- events=recents,
- prev_batch=prev_batch,
- limited=limited,
- ),
- state=state,
- ephemeral=ephemeral,
- account_data=acc_data,
- unread_notifications={},
+ # Only bother if we're still currently invited
+ should_invite = non_joins[-1].membership == Membership.INVITE
+ if should_invite:
+ room_sync = InvitedSyncResult(room_id, invite=non_joins[-1])
+ if room_sync:
+ invited.append(room_sync)
+
+ # Always include leave/ban events. Just take the last one.
+ # TODO: How do we handle ban -> leave in same batch?
+ leave_events = [
+ e for e in non_joins
+ if e.membership in (Membership.LEAVE, Membership.BAN)
+ ]
+
+ if leave_events:
+ leave_event = leave_events[-1]
+ room_sync = yield self.incremental_sync_for_archived_room(
+ sync_config, room_id, leave_event.event_id, since_token,
+ tags_by_room, account_data_by_room,
+ full_state=room_id in newly_joined_rooms
)
- logger.debug("Result for room %s: %r", room_id, room_sync)
-
if room_sync:
- notifs = yield self.unread_notifs_for_room_id(
- room_id, sync_config, all_ephemeral_by_room
- )
+ archived.append(room_sync)
- if notifs is not None:
- notif_dict = room_sync.unread_notifications
- notif_dict["notification_count"] = len(notifs)
- notif_dict["highlight_count"] = len([
- 1 for notif in notifs
- if _action_has_highlight(notif["actions"])
- ])
+ # Get all events for rooms we're currently joined to.
+ room_to_events = yield self.store.get_room_events_stream_for_rooms(
+ room_ids=joined_room_ids,
+ from_key=since_token.room_key,
+ to_key=now_token.room_key,
+ limit=timeline_limit + 1,
+ )
- joined.append(room_sync)
+ joined = []
+ # We loop through all room ids, even if there are no new events, in case
+ # there are non room events taht we need to notify about.
+ for room_id in joined_room_ids:
+ room_entry = room_to_events.get(room_id, None)
- else:
- logger.debug("Got %i events for incremental sync - hit limit",
- len(room_events))
+ if room_entry:
+ events, start_key = room_entry
- invite_events = yield self.store.get_invites_for_user(
- sync_config.user.to_string()
- )
+ prev_batch_token = now_token.copy_and_replace("room_key", start_key)
- leave_events = yield self.store.get_leave_and_ban_events_for_user(
- sync_config.user.to_string()
- )
+ newly_joined_room = room_id in newly_joined_rooms
+ full_state = newly_joined_room
- for room_id in joined_room_ids:
- room_sync = yield self.incremental_sync_with_gap_for_room(
- room_id, sync_config, since_token, now_token,
- ephemeral_by_room, tags_by_room, account_data_by_room,
- all_ephemeral_by_room=all_ephemeral_by_room,
+ batch = yield self.load_filtered_recents(
+ room_id, sync_config, prev_batch_token,
+ since_token=since_token,
+ recents=events,
+ newly_joined_room=newly_joined_room,
)
- if room_sync:
- joined.append(room_sync)
+ else:
+ batch = TimelineBatch(
+ events=[],
+ prev_batch=since_token,
+ limited=False,
+ )
+ full_state = False
- for leave_event in leave_events:
- room_sync = yield self.incremental_sync_for_archived_room(
- sync_config, leave_event, since_token, tags_by_room,
- account_data_by_room
+ room_sync = yield self.incremental_sync_with_gap_for_room(
+ room_id=room_id,
+ sync_config=sync_config,
+ since_token=since_token,
+ now_token=now_token,
+ ephemeral_by_room=ephemeral_by_room,
+ tags_by_room=tags_by_room,
+ account_data_by_room=account_data_by_room,
+ batch=batch,
+ full_state=full_state,
)
if room_sync:
- archived.append(room_sync)
-
- invited = [
- InvitedSyncResult(room_id=event.room_id, invite=event)
- for event in invite_events
- ]
+ joined.append(room_sync)
account_data_for_user = sync_config.filter_collection.filter_account_data(
self.account_data_for_user(account_data)
@@ -680,51 +596,73 @@ class SyncHandler(BaseHandler):
@defer.inlineCallbacks
def load_filtered_recents(self, room_id, sync_config, now_token,
- since_token=None):
+ since_token=None, recents=None, newly_joined_room=False):
"""
:returns a Deferred TimelineBatch
"""
- limited = True
- recents = []
- filtering_factor = 2
- timeline_limit = sync_config.filter_collection.timeline_limit()
- load_limit = max(timeline_limit * filtering_factor, 100)
- max_repeat = 3 # Only try a few times per room, otherwise
- room_key = now_token.room_key
- end_key = room_key
-
- while limited and len(recents) < timeline_limit and max_repeat:
- events, keys = yield self.store.get_recent_events_for_room(
- room_id,
- limit=load_limit + 1,
- from_token=since_token.room_key if since_token else None,
- end_token=end_key,
- )
- room_key, _ = keys
- end_key = "s" + room_key.split('-')[-1]
- loaded_recents = sync_config.filter_collection.filter_room_timeline(events)
- loaded_recents = yield self._filter_events_for_client(
- sync_config.user.to_string(),
- loaded_recents,
- is_peeking=sync_config.is_guest,
- )
- loaded_recents.extend(recents)
- recents = loaded_recents
- if len(events) <= load_limit:
+ with Measure(self.clock, "load_filtered_recents"):
+ filtering_factor = 2
+ timeline_limit = sync_config.filter_collection.timeline_limit()
+ load_limit = max(timeline_limit * filtering_factor, 10)
+ max_repeat = 5 # Only try a few times per room, otherwise
+ room_key = now_token.room_key
+ end_key = room_key
+
+ if recents is None or newly_joined_room or timeline_limit < len(recents):
+ limited = True
+ else:
limited = False
- max_repeat -= 1
- if len(recents) > timeline_limit:
- limited = True
- recents = recents[-timeline_limit:]
- room_key = recents[0].internal_metadata.before
+ if recents is not None:
+ recents = sync_config.filter_collection.filter_room_timeline(recents)
+ recents = yield self._filter_events_for_client(
+ sync_config.user.to_string(),
+ recents,
+ is_peeking=sync_config.is_guest,
+ )
+ else:
+ recents = []
+
+ since_key = None
+ if since_token and not newly_joined_room:
+ since_key = since_token.room_key
+
+ while limited and len(recents) < timeline_limit and max_repeat:
+ events, end_key = yield self.store.get_room_events_stream_for_room(
+ room_id,
+ limit=load_limit + 1,
+ from_key=since_key,
+ to_key=end_key,
+ )
+ loaded_recents = sync_config.filter_collection.filter_room_timeline(
+ events
+ )
+ loaded_recents = yield self._filter_events_for_client(
+ sync_config.user.to_string(),
+ loaded_recents,
+ is_peeking=sync_config.is_guest,
+ )
+ loaded_recents.extend(recents)
+ recents = loaded_recents
+
+ if len(events) <= load_limit:
+ limited = False
+ break
+ max_repeat -= 1
- prev_batch_token = now_token.copy_and_replace(
- "room_key", room_key
- )
+ if len(recents) > timeline_limit:
+ limited = True
+ recents = recents[-timeline_limit:]
+ room_key = recents[0].internal_metadata.before
+
+ prev_batch_token = now_token.copy_and_replace(
+ "room_key", room_key
+ )
defer.returnValue(TimelineBatch(
- events=recents, prev_batch=prev_batch_token, limited=limited
+ events=recents,
+ prev_batch=prev_batch_token,
+ limited=limited or newly_joined_room
))
@defer.inlineCallbacks
@@ -732,62 +670,12 @@ class SyncHandler(BaseHandler):
since_token, now_token,
ephemeral_by_room, tags_by_room,
account_data_by_room,
- all_ephemeral_by_room):
- """ Get the incremental delta needed to bring the client up to date for
- the room. Gives the client the most recent events and the changes to
- state.
- Returns:
- A Deferred JoinedSyncResult
- """
- logger.debug("Doing incremental sync for room %s between %s and %s",
- room_id, since_token, now_token)
-
- # TODO(mjark): Check for redactions we might have missed.
-
- batch = yield self.load_filtered_recents(
- room_id, sync_config, now_token, since_token,
- )
-
- logger.debug("Recents %r", batch)
-
- if batch.limited:
- current_state = yield self.get_state_at(room_id, now_token)
-
- state_at_previous_sync = yield self.get_state_at(
- room_id, stream_position=since_token
- )
-
- state = yield self.compute_state_delta(
- since_token=since_token,
- previous_state=state_at_previous_sync,
- current_state=current_state,
- )
- else:
- state = {
- (event.type, event.state_key): event
- for event in batch.events if event.is_state()
- }
-
- just_joined = yield self.check_joined_room(sync_config, state)
- if just_joined:
- state = yield self.get_state_at(room_id, now_token)
-
- notifs = yield self.unread_notifs_for_room_id(
- room_id, sync_config, all_ephemeral_by_room
+ batch, full_state=False):
+ state = yield self.compute_state_delta(
+ room_id, batch, sync_config, since_token, now_token,
+ full_state=full_state
)
- unread_notifications = {}
- if notifs is not None:
- unread_notifications["notification_count"] = len(notifs)
- unread_notifications["highlight_count"] = len([
- 1 for notif in notifs if _action_has_highlight(notif["actions"])
- ])
-
- state = {
- (e.type, e.state_key): e
- for e in sync_config.filter_collection.filter_room_state(state.values())
- }
-
account_data = self.account_data_for_room(
room_id, tags_by_room, account_data_by_room
)
@@ -800,6 +688,7 @@ class SyncHandler(BaseHandler):
ephemeral_by_room.get(room_id, [])
)
+ unread_notifications = {}
room_sync = JoinedSyncResult(
room_id=room_id,
timeline=batch,
@@ -809,58 +698,53 @@ class SyncHandler(BaseHandler):
unread_notifications=unread_notifications,
)
+ if room_sync:
+ notifs = yield self.unread_notifs_for_room_id(
+ room_id, sync_config
+ )
+
+ if notifs is not None:
+ unread_notifications["notification_count"] = notifs["notify_count"]
+ unread_notifications["highlight_count"] = notifs["highlight_count"]
+
logger.debug("Room sync: %r", room_sync)
defer.returnValue(room_sync)
@defer.inlineCallbacks
- def incremental_sync_for_archived_room(self, sync_config, leave_event,
+ def incremental_sync_for_archived_room(self, sync_config, room_id, leave_event_id,
since_token, tags_by_room,
- account_data_by_room):
+ account_data_by_room, full_state,
+ leave_token=None):
""" Get the incremental delta needed to bring the client up to date for
the archived room.
Returns:
A Deferred ArchivedSyncResult
"""
- stream_token = yield self.store.get_stream_token_for_event(
- leave_event.event_id
- )
+ if not leave_token:
+ stream_token = yield self.store.get_stream_token_for_event(
+ leave_event_id
+ )
- leave_token = since_token.copy_and_replace("room_key", stream_token)
+ leave_token = since_token.copy_and_replace("room_key", stream_token)
- if since_token.is_after(leave_token):
+ if since_token and since_token.is_after(leave_token):
defer.returnValue(None)
batch = yield self.load_filtered_recents(
- leave_event.room_id, sync_config, leave_token, since_token,
+ room_id, sync_config, leave_token, since_token,
)
logger.debug("Recents %r", batch)
- state_events_at_leave = yield self.store.get_state_for_event(
- leave_event.event_id
- )
-
- state_at_previous_sync = yield self.get_state_at(
- leave_event.room_id, stream_position=since_token
- )
-
state_events_delta = yield self.compute_state_delta(
- since_token=since_token,
- previous_state=state_at_previous_sync,
- current_state=state_events_at_leave,
+ room_id, batch, sync_config, since_token, leave_token,
+ full_state=full_state
)
- state_events_delta = {
- (e.type, e.state_key): e
- for e in sync_config.filter_collection.filter_room_state(
- state_events_delta.values()
- )
- }
-
account_data = self.account_data_for_room(
- leave_event.room_id, tags_by_room, account_data_by_room
+ room_id, tags_by_room, account_data_by_room
)
account_data = sync_config.filter_collection.filter_room_account_data(
@@ -868,7 +752,7 @@ class SyncHandler(BaseHandler):
)
room_sync = ArchivedSyncResult(
- room_id=leave_event.room_id,
+ room_id=room_id,
timeline=batch,
state=state_events_delta,
account_data=account_data,
@@ -912,15 +796,19 @@ class SyncHandler(BaseHandler):
state = {}
defer.returnValue(state)
- def compute_state_delta(self, since_token, previous_state, current_state):
- """ Works out the differnce in state between the current state and the
- state the client got when it last performed a sync.
-
- :param str since_token: the point we are comparing against
- :param dict[(str,str), synapse.events.FrozenEvent] previous_state: the
- state to compare to
- :param dict[(str,str), synapse.events.FrozenEvent] current_state: the
- new state
+ @defer.inlineCallbacks
+ def compute_state_delta(self, room_id, batch, sync_config, since_token, now_token,
+ full_state):
+ """ Works out the differnce in state between the start of the timeline
+ and the previous sync.
+
+ :param str room_id
+ :param TimelineBatch batch: The timeline batch for the room that will
+ be sent to the user.
+ :param sync_config
+ :param str since_token: Token of the end of the previous batch. May be None.
+ :param str now_token: Token of the end of the current batch.
+ :param bool full_state: Whether to force returning the full state.
:returns A new event dictionary
"""
@@ -929,12 +817,53 @@ class SyncHandler(BaseHandler):
# updates even if they occured logically before the previous event.
# TODO(mjark) Check for new redactions in the state events.
- state_delta = {}
- for key, event in current_state.iteritems():
- if (key not in previous_state or
- previous_state[key].event_id != event.event_id):
- state_delta[key] = event
- return state_delta
+ with Measure(self.clock, "compute_state_delta"):
+ if full_state:
+ if batch:
+ state = yield self.store.get_state_for_event(
+ batch.events[0].event_id
+ )
+ else:
+ state = yield self.get_state_at(
+ room_id, stream_position=now_token
+ )
+
+ timeline_state = {
+ (event.type, event.state_key): event
+ for event in batch.events if event.is_state()
+ }
+
+ state = _calculate_state(
+ timeline_contains=timeline_state,
+ timeline_start=state,
+ previous={},
+ )
+ elif batch.limited:
+ state_at_previous_sync = yield self.get_state_at(
+ room_id, stream_position=since_token
+ )
+
+ state_at_timeline_start = yield self.store.get_state_for_event(
+ batch.events[0].event_id
+ )
+
+ timeline_state = {
+ (event.type, event.state_key): event
+ for event in batch.events if event.is_state()
+ }
+
+ state = _calculate_state(
+ timeline_contains=timeline_state,
+ timeline_start=state_at_timeline_start,
+ previous=state_at_previous_sync,
+ )
+ else:
+ state = {}
+
+ defer.returnValue({
+ (e.type, e.state_key): e
+ for e in sync_config.filter_collection.filter_room_state(state.values())
+ })
def check_joined_room(self, sync_config, state_delta):
"""
@@ -955,21 +884,24 @@ class SyncHandler(BaseHandler):
return False
@defer.inlineCallbacks
- def unread_notifs_for_room_id(self, room_id, sync_config, ephemeral_by_room):
- last_unread_event_id = self.last_read_event_id_for_room_and_user(
- room_id, sync_config.user.to_string(), ephemeral_by_room
- )
-
- notifs = []
- if last_unread_event_id:
- notifs = yield self.store.get_unread_event_push_actions_by_room_for_user(
- room_id, sync_config.user.to_string(), last_unread_event_id
+ def unread_notifs_for_room_id(self, room_id, sync_config):
+ with Measure(self.clock, "unread_notifs_for_room_id"):
+ last_unread_event_id = yield self.store.get_last_receipt_event_id_for_user(
+ user_id=sync_config.user.to_string(),
+ room_id=room_id,
+ receipt_type="m.read"
)
- defer.returnValue(notifs)
- # There is no new information in this period, so your notification
- # count is whatever it was last time.
- defer.returnValue(None)
+ notifs = []
+ if last_unread_event_id:
+ notifs = yield self.store.get_unread_event_push_actions_by_room_for_user(
+ room_id, sync_config.user.to_string(), last_unread_event_id
+ )
+ defer.returnValue(notifs)
+
+ # There is no new information in this period, so your notification
+ # count is whatever it was last time.
+ defer.returnValue(None)
def _action_has_highlight(actions):
@@ -981,3 +913,37 @@ def _action_has_highlight(actions):
pass
return False
+
+
+def _calculate_state(timeline_contains, timeline_start, previous):
+ """Works out what state to include in a sync response.
+
+ Args:
+ timeline_contains (dict): state in the timeline
+ timeline_start (dict): state at the start of the timeline
+ previous (dict): state at the end of the previous sync (or empty dict
+ if this is an initial sync)
+
+ Returns:
+ dict
+ """
+ event_id_to_state = {
+ e.event_id: e
+ for e in itertools.chain(
+ timeline_contains.values(),
+ previous.values(),
+ timeline_start.values(),
+ )
+ }
+
+ tc_ids = set(e.event_id for e in timeline_contains.values())
+ p_ids = set(e.event_id for e in previous.values())
+ ts_ids = set(e.event_id for e in timeline_start.values())
+
+ state_ids = (ts_ids - p_ids) - tc_ids
+
+ evs = (event_id_to_state[e] for e in state_ids)
+ return {
+ (e.type, e.state_key): e
+ for e in evs
+ }
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index 43bf600913..b16d0017df 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -19,6 +19,7 @@ from ._base import BaseHandler
from synapse.api.errors import SynapseError, AuthError
from synapse.util.logcontext import PreserveLoggingContext
+from synapse.util.metrics import Measure
from synapse.types import UserID
import logging
@@ -222,6 +223,7 @@ class TypingNotificationHandler(BaseHandler):
class TypingNotificationEventSource(object):
def __init__(self, hs):
self.hs = hs
+ self.clock = hs.get_clock()
self._handler = None
self._room_member_handler = None
@@ -247,19 +249,20 @@ class TypingNotificationEventSource(object):
}
def get_new_events(self, from_key, room_ids, **kwargs):
- from_key = int(from_key)
- handler = self.handler()
+ with Measure(self.clock, "typing.get_new_events"):
+ from_key = int(from_key)
+ handler = self.handler()
- events = []
- for room_id in room_ids:
- if room_id not in handler._room_serials:
- continue
- if handler._room_serials[room_id] <= from_key:
- continue
+ events = []
+ for room_id in room_ids:
+ if room_id not in handler._room_serials:
+ continue
+ if handler._room_serials[room_id] <= from_key:
+ continue
- events.append(self._make_event_for(room_id))
+ events.append(self._make_event_for(room_id))
- return events, handler._latest_room_serial
+ return events, handler._latest_room_serial
def get_current_key(self):
return self.handler()._latest_room_serial
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index da13e32e78..c3589534f8 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -152,7 +152,7 @@ class MatrixFederationHttpClient(object):
return self.clock.time_bound_deferred(
request_deferred,
- time_out=timeout/1000. if timeout else 60,
+ time_out=timeout / 1000. if timeout else 60,
)
response = yield preserve_context_over_fn(
diff --git a/synapse/http/server.py b/synapse/http/server.py
index 10d1fcd3f6..a90e2e1125 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -41,7 +41,7 @@ metrics = synapse.metrics.get_metrics_for(__name__)
incoming_requests_counter = metrics.register_counter(
"requests",
- labels=["method", "servlet"],
+ labels=["method", "servlet", "tag"],
)
outgoing_responses_counter = metrics.register_counter(
"responses",
@@ -50,23 +50,23 @@ outgoing_responses_counter = metrics.register_counter(
response_timer = metrics.register_distribution(
"response_time",
- labels=["method", "servlet"]
+ labels=["method", "servlet", "tag"]
)
response_ru_utime = metrics.register_distribution(
- "response_ru_utime", labels=["method", "servlet"]
+ "response_ru_utime", labels=["method", "servlet", "tag"]
)
response_ru_stime = metrics.register_distribution(
- "response_ru_stime", labels=["method", "servlet"]
+ "response_ru_stime", labels=["method", "servlet", "tag"]
)
response_db_txn_count = metrics.register_distribution(
- "response_db_txn_count", labels=["method", "servlet"]
+ "response_db_txn_count", labels=["method", "servlet", "tag"]
)
response_db_txn_duration = metrics.register_distribution(
- "response_db_txn_duration", labels=["method", "servlet"]
+ "response_db_txn_duration", labels=["method", "servlet", "tag"]
)
@@ -99,9 +99,8 @@ def request_handler(request_handler):
request_context.request = request_id
with request.processing():
try:
- d = request_handler(self, request)
- with PreserveLoggingContext():
- yield d
+ with PreserveLoggingContext(request_context):
+ yield request_handler(self, request)
except CodeMessageException as e:
code = e.code
if isinstance(e, SynapseError):
@@ -208,6 +207,9 @@ class JsonResource(HttpServer, resource.Resource):
if request.method == "OPTIONS":
self._send_response(request, 200, {})
return
+
+ start_context = LoggingContext.current_context()
+
# Loop through all the registered callbacks to check if the method
# and path regex match
for path_entry in self.path_regexs.get(request.method, []):
@@ -226,7 +228,6 @@ class JsonResource(HttpServer, resource.Resource):
servlet_classname = servlet_instance.__class__.__name__
else:
servlet_classname = "%r" % callback
- incoming_requests_counter.inc(request.method, servlet_classname)
args = [
urllib.unquote(u).decode("UTF-8") if u else u for u in m.groups()
@@ -237,21 +238,40 @@ class JsonResource(HttpServer, resource.Resource):
code, response = callback_return
self._send_response(request, code, response)
- response_timer.inc_by(
- self.clock.time_msec() - start, request.method, servlet_classname
- )
-
try:
context = LoggingContext.current_context()
+
+ tag = ""
+ if context:
+ tag = context.tag
+
+ if context != start_context:
+ logger.warn(
+ "Context have unexpectedly changed %r, %r",
+ context, self.start_context
+ )
+ return
+
+ incoming_requests_counter.inc(request.method, servlet_classname, tag)
+
+ response_timer.inc_by(
+ self.clock.time_msec() - start, request.method,
+ servlet_classname, tag
+ )
+
ru_utime, ru_stime = context.get_resource_usage()
- response_ru_utime.inc_by(ru_utime, request.method, servlet_classname)
- response_ru_stime.inc_by(ru_stime, request.method, servlet_classname)
+ response_ru_utime.inc_by(
+ ru_utime, request.method, servlet_classname, tag
+ )
+ response_ru_stime.inc_by(
+ ru_stime, request.method, servlet_classname, tag
+ )
response_db_txn_count.inc_by(
- context.db_txn_count, request.method, servlet_classname
+ context.db_txn_count, request.method, servlet_classname, tag
)
response_db_txn_duration.inc_by(
- context.db_txn_duration, request.method, servlet_classname
+ context.db_txn_duration, request.method, servlet_classname, tag
)
except:
pass
diff --git a/synapse/notifier.py b/synapse/notifier.py
index 6eaa65071e..560866b26e 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -18,10 +18,13 @@ from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError
from synapse.util.logutils import log_function
-from synapse.util.async import run_on_reactor, ObservableDeferred
+from synapse.util.async import ObservableDeferred
+from synapse.util.logcontext import PreserveLoggingContext
from synapse.types import StreamToken
import synapse.metrics
+from collections import namedtuple
+
import logging
@@ -71,7 +74,8 @@ class _NotifierUserStream(object):
self.current_token = current_token
self.last_notified_ms = time_now_ms
- self.notify_deferred = ObservableDeferred(defer.Deferred())
+ with PreserveLoggingContext():
+ self.notify_deferred = ObservableDeferred(defer.Deferred())
def notify(self, stream_key, stream_id, time_now_ms):
"""Notify any listeners for this user of a new event from an
@@ -86,8 +90,10 @@ class _NotifierUserStream(object):
)
self.last_notified_ms = time_now_ms
noify_deferred = self.notify_deferred
- self.notify_deferred = ObservableDeferred(defer.Deferred())
- noify_deferred.callback(self.current_token)
+
+ with PreserveLoggingContext():
+ self.notify_deferred = ObservableDeferred(defer.Deferred())
+ noify_deferred.callback(self.current_token)
def remove(self, notifier):
""" Remove this listener from all the indexes in the Notifier
@@ -118,6 +124,11 @@ class _NotifierUserStream(object):
return _NotificationListener(self.notify_deferred.observe())
+class EventStreamResult(namedtuple("EventStreamResult", ("events", "tokens"))):
+ def __nonzero__(self):
+ return bool(self.events)
+
+
class Notifier(object):
""" This class is responsible for notifying any listeners when there are
new events available for it.
@@ -177,8 +188,6 @@ class Notifier(object):
lambda: count(bool, self.appservice_to_user_streams.values()),
)
- @log_function
- @defer.inlineCallbacks
def on_new_room_event(self, event, room_stream_id, max_room_stream_id,
extra_users=[]):
""" Used by handlers to inform the notifier something has happened
@@ -192,12 +201,11 @@ class Notifier(object):
until all previous events have been persisted before notifying
the client streams.
"""
- yield run_on_reactor()
-
- self.pending_new_room_events.append((
- room_stream_id, event, extra_users
- ))
- self._notify_pending_new_room_events(max_room_stream_id)
+ with PreserveLoggingContext():
+ self.pending_new_room_events.append((
+ room_stream_id, event, extra_users
+ ))
+ self._notify_pending_new_room_events(max_room_stream_id)
def _notify_pending_new_room_events(self, max_room_stream_id):
"""Notify for the room events that were queued waiting for a previous
@@ -244,31 +252,29 @@ class Notifier(object):
extra_streams=app_streams,
)
- @defer.inlineCallbacks
- @log_function
def on_new_event(self, stream_key, new_token, users=[], rooms=[],
extra_streams=set()):
""" Used to inform listeners that something has happend event wise.
Will wake up all listeners for the given users and rooms.
"""
- yield run_on_reactor()
- user_streams = set()
+ with PreserveLoggingContext():
+ user_streams = set()
- for user in users:
- user_stream = self.user_to_user_stream.get(str(user))
- if user_stream is not None:
- user_streams.add(user_stream)
+ for user in users:
+ user_stream = self.user_to_user_stream.get(str(user))
+ if user_stream is not None:
+ user_streams.add(user_stream)
- for room in rooms:
- user_streams |= self.room_to_user_streams.get(room, set())
+ for room in rooms:
+ user_streams |= self.room_to_user_streams.get(room, set())
- time_now_ms = self.clock.time_msec()
- for user_stream in user_streams:
- try:
- user_stream.notify(stream_key, new_token, time_now_ms)
- except:
- logger.exception("Failed to notify listener")
+ time_now_ms = self.clock.time_msec()
+ for user_stream in user_streams:
+ try:
+ user_stream.notify(stream_key, new_token, time_now_ms)
+ except:
+ logger.exception("Failed to notify listener")
@defer.inlineCallbacks
def wait_for_events(self, user_id, timeout, callback, room_ids=None,
@@ -301,7 +307,7 @@ class Notifier(object):
def timed_out():
if listener:
listener.deferred.cancel()
- timer = self.clock.call_later(timeout/1000., timed_out)
+ timer = self.clock.call_later(timeout / 1000., timed_out)
prev_token = from_token
while not result:
@@ -318,7 +324,8 @@ class Notifier(object):
# that we don't miss any current_token updates.
prev_token = current_token
listener = user_stream.new_listener(prev_token)
- yield listener.deferred
+ with PreserveLoggingContext():
+ yield listener.deferred
except defer.CancelledError:
break
@@ -356,7 +363,7 @@ class Notifier(object):
@defer.inlineCallbacks
def check_for_updates(before_token, after_token):
if not after_token.is_after(before_token):
- defer.returnValue(None)
+ defer.returnValue(EventStreamResult([], (from_token, from_token)))
events = []
end_token = from_token
@@ -369,6 +376,7 @@ class Notifier(object):
continue
if only_keys and name not in only_keys:
continue
+
new_events, new_key = yield source.get_new_events(
user=user,
from_key=getattr(from_token, keyname),
@@ -388,10 +396,7 @@ class Notifier(object):
events.extend(new_events)
end_token = end_token.copy_and_replace(keyname, new_key)
- if events:
- defer.returnValue((events, (from_token, end_token)))
- else:
- defer.returnValue(None)
+ defer.returnValue(EventStreamResult(events, (from_token, end_token)))
user_id_for_stream = user.to_string()
if is_peeking:
@@ -415,9 +420,6 @@ class Notifier(object):
from_token=from_token,
)
- if result is None:
- result = ([], (from_token, from_token))
-
defer.returnValue(result)
@defer.inlineCallbacks
diff --git a/synapse/push/__init__.py b/synapse/push/__init__.py
index 9bc0b356f4..8da2d8716c 100644
--- a/synapse/push/__init__.py
+++ b/synapse/push/__init__.py
@@ -17,6 +17,8 @@ from twisted.internet import defer
from synapse.streams.config import PaginationConfig
from synapse.types import StreamToken
+from synapse.util.logcontext import LoggingContext
+from synapse.util.metrics import Measure
import synapse.util.async
import push_rule_evaluator as push_rule_evaluator
@@ -27,6 +29,16 @@ import random
logger = logging.getLogger(__name__)
+_NEXT_ID = 1
+
+
+def _get_next_id():
+ global _NEXT_ID
+ _id = _NEXT_ID
+ _NEXT_ID += 1
+ return _id
+
+
# Pushers could now be moved to pull out of the event_push_actions table instead
# of listening on the event stream: this would avoid them having to run the
# rules again.
@@ -57,6 +69,8 @@ class Pusher(object):
self.alive = True
self.badge = None
+ self.name = "Pusher-%d" % (_get_next_id(),)
+
# The last value of last_active_time that we saw
self.last_last_active_time = 0
self.has_unread = True
@@ -86,38 +100,46 @@ class Pusher(object):
@defer.inlineCallbacks
def start(self):
- if not self.last_token:
- # First-time setup: get a token to start from (we can't
- # just start from no token, ie. 'now'
- # because we need the result to be reproduceable in case
- # we fail to dispatch the push)
- config = PaginationConfig(from_token=None, limit='1')
- chunk = yield self.evStreamHandler.get_stream(
- self.user_id, config, timeout=0, affect_presence=False
- )
- self.last_token = chunk['end']
- self.store.update_pusher_last_token(
- self.app_id, self.pushkey, self.user_id, self.last_token
- )
- logger.info("Pusher %s for user %s starting from token %s",
- self.pushkey, self.user_id, self.last_token)
-
- wait = 0
- while self.alive:
- try:
- if wait > 0:
- yield synapse.util.async.sleep(wait)
- yield self.get_and_dispatch()
- wait = 0
- except:
- if wait == 0:
- wait = 1
- else:
- wait = min(wait * 2, 1800)
- logger.exception(
- "Exception in pusher loop for pushkey %s. Pausing for %ds",
- self.pushkey, wait
+ with LoggingContext(self.name):
+ if not self.last_token:
+ # First-time setup: get a token to start from (we can't
+ # just start from no token, ie. 'now'
+ # because we need the result to be reproduceable in case
+ # we fail to dispatch the push)
+ config = PaginationConfig(from_token=None, limit='1')
+ chunk = yield self.evStreamHandler.get_stream(
+ self.user_id, config, timeout=0, affect_presence=False
)
+ self.last_token = chunk['end']
+ yield self.store.update_pusher_last_token(
+ self.app_id, self.pushkey, self.user_id, self.last_token
+ )
+ logger.info("New pusher %s for user %s starting from token %s",
+ self.pushkey, self.user_id, self.last_token)
+
+ else:
+ logger.info(
+ "Old pusher %s for user %s starting",
+ self.pushkey, self.user_id,
+ )
+
+ wait = 0
+ while self.alive:
+ try:
+ if wait > 0:
+ yield synapse.util.async.sleep(wait)
+ with Measure(self.clock, "push"):
+ yield self.get_and_dispatch()
+ wait = 0
+ except:
+ if wait == 0:
+ wait = 1
+ else:
+ wait = min(wait * 2, 1800)
+ logger.exception(
+ "Exception in pusher loop for pushkey %s. Pausing for %ds",
+ self.pushkey, wait
+ )
@defer.inlineCallbacks
def get_and_dispatch(self):
@@ -316,7 +338,7 @@ class Pusher(object):
r.room_id, self.user_id, last_unread_event_id
)
)
- badge += len(notifs)
+ badge += notifs["notify_count"]
defer.returnValue(badge)
diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py
index dca018af95..2a2b4437dc 100644
--- a/synapse/push/push_rule_evaluator.py
+++ b/synapse/push/push_rule_evaluator.py
@@ -304,7 +304,7 @@ def _flatten_dict(d, prefix=[], result={}):
if isinstance(value, basestring):
result[".".join(prefix + [key])] = value.lower()
elif hasattr(value, "items"):
- _flatten_dict(value, prefix=(prefix+[key]), result=result)
+ _flatten_dict(value, prefix=(prefix + [key]), result=result)
return result
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index d1b7c0802f..d7dcb2de4b 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -18,6 +18,7 @@ from twisted.internet import defer
from httppusher import HttpPusher
from synapse.push import PusherConfigException
+from synapse.util.logcontext import preserve_fn
import logging
@@ -76,7 +77,7 @@ class PusherPool:
"Removing pusher for app id %s, pushkey %s, user %s",
app_id, pushkey, p['user_name']
)
- self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])
+ yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])
@defer.inlineCallbacks
def remove_pushers_by_user(self, user_id):
@@ -91,7 +92,7 @@ class PusherPool:
"Removing pusher for app id %s, pushkey %s, user %s",
p['app_id'], p['pushkey'], p['user_name']
)
- self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])
+ yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])
@defer.inlineCallbacks
def _add_pusher_to_store(self, user_id, access_token, profile_tag, kind,
@@ -110,7 +111,7 @@ class PusherPool:
lang=lang,
data=data,
)
- self._refresh_pusher(app_id, pushkey, user_id)
+ yield self._refresh_pusher(app_id, pushkey, user_id)
def _create_pusher(self, pusherdict):
if pusherdict['kind'] == 'http':
@@ -166,7 +167,7 @@ class PusherPool:
if fullid in self.pushers:
self.pushers[fullid].stop()
self.pushers[fullid] = p
- p.start()
+ preserve_fn(p.start)()
logger.info("Started pushers")
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index d37dc11470..a87628ed85 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -22,7 +22,7 @@ REQUIREMENTS = {
"unpaddedbase64>=1.0.1": ["unpaddedbase64>=1.0.1"],
"canonicaljson>=1.0.0": ["canonicaljson>=1.0.0"],
"signedjson>=1.0.0": ["signedjson>=1.0.0"],
- "pynacl>=0.3.0": ["nacl>=0.3.0", "nacl.bindings"],
+ "pynacl==0.3.0": ["nacl==0.3.0", "nacl.bindings"],
"service_identity>=1.0.0": ["service_identity>=1.0.0"],
"Twisted>=15.1.0": ["twisted>=15.1.0"],
"pyopenssl>=0.14": ["OpenSSL>=0.14"],
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 07836709fb..7199113dac 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -89,7 +89,7 @@ class LoginRestServlet(ClientV1RestServlet):
LoginRestServlet.SAML2_TYPE):
relay_state = ""
if "relay_state" in login_submission:
- relay_state = "&RelayState="+urllib.quote(
+ relay_state = "&RelayState=" + urllib.quote(
login_submission["relay_state"])
result = {
"uri": "%s%s" % (self.idp_redirect_url, relay_state)
diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py
index b15defdd07..3c5a212920 100644
--- a/synapse/rest/client/v1/profile.py
+++ b/synapse/rest/client/v1/profile.py
@@ -33,7 +33,11 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet):
user,
)
- defer.returnValue((200, {"displayname": displayname}))
+ ret = {}
+ if displayname is not None:
+ ret["displayname"] = displayname
+
+ defer.returnValue((200, ret))
@defer.inlineCallbacks
def on_PUT(self, request, user_id):
@@ -66,7 +70,11 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet):
user,
)
- defer.returnValue((200, {"avatar_url": avatar_url}))
+ ret = {}
+ if avatar_url is not None:
+ ret["avatar_url"] = avatar_url
+
+ defer.returnValue((200, ret))
@defer.inlineCallbacks
def on_PUT(self, request, user_id):
@@ -102,10 +110,13 @@ class ProfileRestServlet(ClientV1RestServlet):
user,
)
- defer.returnValue((200, {
- "displayname": displayname,
- "avatar_url": avatar_url
- }))
+ ret = {}
+ if displayname is not None:
+ ret["displayname"] = displayname
+ if avatar_url is not None:
+ ret["avatar_url"] = avatar_url
+
+ defer.returnValue((200, ret))
def register_servlets(hs, http_server):
diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py
index cb3ec23872..96633a176c 100644
--- a/synapse/rest/client/v1/push_rule.py
+++ b/synapse/rest/client/v1/push_rule.py
@@ -66,11 +66,12 @@ class PushRuleRestServlet(ClientV1RestServlet):
raise SynapseError(400, e.message)
before = request.args.get("before", None)
- if before and len(before):
- before = before[0]
+ if before:
+ before = _namespaced_rule_id(spec, before[0])
+
after = request.args.get("after", None)
- if after and len(after):
- after = after[0]
+ if after:
+ after = _namespaced_rule_id(spec, after[0])
try:
yield self.hs.get_datastore().add_push_rule(
@@ -452,11 +453,15 @@ def _strip_device_condition(rule):
def _namespaced_rule_id_from_spec(spec):
+ return _namespaced_rule_id(spec, spec['rule_id'])
+
+
+def _namespaced_rule_id(spec, rule_id):
if spec['scope'] == 'global':
scope = 'global'
else:
scope = 'device/%s' % (spec['profile_tag'])
- return "%s/%s/%s" % (scope, spec['template'], spec['rule_id'])
+ return "%s/%s/%s" % (scope, spec['template'], rule_id)
def _rule_id_from_namespaced(in_rule_id):
diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py
index e218ed215c..5547f1b112 100644
--- a/synapse/rest/client/v1/pusher.py
+++ b/synapse/rest/client/v1/pusher.py
@@ -52,7 +52,7 @@ class PusherRestServlet(ClientV1RestServlet):
if i not in content:
missing.append(i)
if len(missing):
- raise SynapseError(400, "Missing parameters: "+','.join(missing),
+ raise SynapseError(400, "Missing parameters: " + ','.join(missing),
errcode=Codes.MISSING_PARAM)
logger.debug("set pushkey %s to kind %s", content['pushkey'], content['kind'])
@@ -83,7 +83,7 @@ class PusherRestServlet(ClientV1RestServlet):
data=content['data']
)
except PusherConfigException as pce:
- raise SynapseError(400, "Config Error: "+pce.message,
+ raise SynapseError(400, "Config Error: " + pce.message,
errcode=Codes.MISSING_PARAM)
defer.returnValue((200, {}))
diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1/register.py
index 5378a9a938..6d6d03c34c 100644
--- a/synapse/rest/client/v1/register.py
+++ b/synapse/rest/client/v1/register.py
@@ -38,7 +38,8 @@ logger = logging.getLogger(__name__)
if hasattr(hmac, "compare_digest"):
compare_digest = hmac.compare_digest
else:
- compare_digest = lambda a, b: a == b
+ def compare_digest(a, b):
+ return a == b
class RegisterRestServlet(ClientV1RestServlet):
@@ -58,7 +59,7 @@ class RegisterRestServlet(ClientV1RestServlet):
# }
# TODO: persistent storage
self.sessions = {}
- self.disable_registration = hs.config.disable_registration
+ self.enable_registration = hs.config.enable_registration
def on_GET(self, request):
if self.hs.config.enable_registration_captcha:
@@ -112,7 +113,7 @@ class RegisterRestServlet(ClientV1RestServlet):
is_using_shared_secret = login_type == LoginType.SHARED_SECRET
can_register = (
- not self.disable_registration
+ self.enable_registration
or is_application_server
or is_using_shared_secret
)
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index c7ea15c624..81bfe377bd 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -429,8 +429,6 @@ class RoomEventContext(ClientV1RestServlet):
serialize_event(event, time_now) for event in results["state"]
]
- logger.info("Responding with %r", results)
-
defer.returnValue((200, results))
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index d507172704..a614b79d45 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -116,9 +116,10 @@ class ThreepidRestServlet(RestServlet):
body = parse_json_dict_from_request(request)
- if 'threePidCreds' not in body:
+ threePidCreds = body.get('threePidCreds')
+ threePidCreds = body.get('three_pid_creds', threePidCreds)
+ if threePidCreds is None:
raise SynapseError(400, "Missing param", Codes.MISSING_PARAM)
- threePidCreds = body['threePidCreds']
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
diff --git a/synapse/rest/client/v2_alpha/account_data.py b/synapse/rest/client/v2_alpha/account_data.py
index 985efe2a62..1456881c1a 100644
--- a/synapse/rest/client/v2_alpha/account_data.py
+++ b/synapse/rest/client/v2_alpha/account_data.py
@@ -57,7 +57,7 @@ class AccountDataServlet(RestServlet):
user_id, account_data_type, body
)
- yield self.notifier.on_new_event(
+ self.notifier.on_new_event(
"account_data_key", max_id, users=[user_id]
)
@@ -99,7 +99,7 @@ class RoomAccountDataServlet(RestServlet):
user_id, room_id, account_data_type, body
)
- yield self.notifier.on_new_event(
+ self.notifier.on_new_event(
"account_data_key", max_id, users=[user_id]
)
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index c4d025b465..ec5c21fa1f 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -34,7 +34,8 @@ from synapse.util.async import run_on_reactor
if hasattr(hmac, "compare_digest"):
compare_digest = hmac.compare_digest
else:
- compare_digest = lambda a, b: a == b
+ def compare_digest(a, b):
+ return a == b
logger = logging.getLogger(__name__)
@@ -116,7 +117,7 @@ class RegisterRestServlet(RestServlet):
return
# == Normal User Registration == (everyone else)
- if self.hs.config.disable_registration:
+ if not self.hs.config.enable_registration:
raise SynapseError(403, "Registration has been disabled")
guest_access_token = body.get("guest_access_token", None)
@@ -152,6 +153,7 @@ class RegisterRestServlet(RestServlet):
desired_username = params.get("username", None)
new_password = params.get("password", None)
+ guest_access_token = params.get("guest_access_token", None)
(user_id, token) = yield self.registration_handler.register(
localpart=desired_username,
diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py
index 07b5b5dfd5..140ce2704b 100644
--- a/synapse/rest/client/v2_alpha/sync.py
+++ b/synapse/rest/client/v2_alpha/sync.py
@@ -20,7 +20,6 @@ from synapse.http.servlet import (
)
from synapse.handlers.sync import SyncConfig
from synapse.types import StreamToken
-from synapse.events import FrozenEvent
from synapse.events.utils import (
serialize_event, format_event_for_client_v2_without_room_id,
)
@@ -287,9 +286,6 @@ class SyncRestServlet(RestServlet):
state_dict = room.state
timeline_events = room.timeline.events
- state_dict = SyncRestServlet._rollback_state_for_timeline(
- state_dict, timeline_events)
-
state_events = state_dict.values()
serialized_state = [serialize(e) for e in state_events]
@@ -314,77 +310,6 @@ class SyncRestServlet(RestServlet):
return result
- @staticmethod
- def _rollback_state_for_timeline(state, timeline):
- """
- Wind the state dictionary backwards, so that it represents the
- state at the start of the timeline, rather than at the end.
-
- :param dict[(str, str), synapse.events.EventBase] state: the
- state dictionary. Will be updated to the state before the timeline.
- :param list[synapse.events.EventBase] timeline: the event timeline
- :return: updated state dictionary
- """
-
- result = state.copy()
-
- for timeline_event in reversed(timeline):
- if not timeline_event.is_state():
- continue
-
- event_key = (timeline_event.type, timeline_event.state_key)
-
- logger.debug("Considering %s for removal", event_key)
-
- state_event = result.get(event_key)
- if (state_event is None or
- state_event.event_id != timeline_event.event_id):
- # the event in the timeline isn't present in the state
- # dictionary.
- #
- # the most likely cause for this is that there was a fork in
- # the event graph, and the state is no longer valid. Really,
- # the event shouldn't be in the timeline. We're going to ignore
- # it for now, however.
- logger.debug("Found state event %r in timeline which doesn't "
- "match state dictionary", timeline_event)
- continue
-
- prev_event_id = timeline_event.unsigned.get("replaces_state", None)
-
- prev_content = timeline_event.unsigned.get('prev_content')
- prev_sender = timeline_event.unsigned.get('prev_sender')
- # Empircally it seems possible for the event to have a
- # "replaces_state" key but not a prev_content or prev_sender
- # markjh conjectures that it could be due to the server not
- # having a copy of that event.
- # If this is the case the we ignore the previous event. This will
- # cause the displayname calculations on the client to be incorrect
- if prev_event_id is None or not prev_content or not prev_sender:
- logger.debug(
- "Removing %r from the state dict, as it is missing"
- " prev_content (prev_event_id=%r)",
- timeline_event.event_id, prev_event_id
- )
- del result[event_key]
- else:
- logger.debug(
- "Replacing %r with %r in state dict",
- timeline_event.event_id, prev_event_id
- )
- result[event_key] = FrozenEvent({
- "type": timeline_event.type,
- "state_key": timeline_event.state_key,
- "content": prev_content,
- "sender": prev_sender,
- "event_id": prev_event_id,
- "room_id": timeline_event.room_id,
- })
-
- logger.debug("New value: %r", result.get(event_key))
-
- return result
-
def register_servlets(hs, http_server):
SyncRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/tags.py b/synapse/rest/client/v2_alpha/tags.py
index 42f2203f3d..79c436a8cf 100644
--- a/synapse/rest/client/v2_alpha/tags.py
+++ b/synapse/rest/client/v2_alpha/tags.py
@@ -80,7 +80,7 @@ class TagServlet(RestServlet):
max_id = yield self.store.add_tag_to_room(user_id, room_id, tag, body)
- yield self.notifier.on_new_event(
+ self.notifier.on_new_event(
"account_data_key", max_id, users=[user_id]
)
@@ -94,7 +94,7 @@ class TagServlet(RestServlet):
max_id = yield self.store.remove_tag_from_room(user_id, room_id, tag)
- yield self.notifier.on_new_event(
+ self.notifier.on_new_event(
"account_data_key", max_id, users=[user_id]
)
diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py
index 349ef6b396..ca5468c402 100644
--- a/synapse/rest/client/versions.py
+++ b/synapse/rest/client/versions.py
@@ -26,9 +26,7 @@ class VersionsRestServlet(RestServlet):
def on_GET(self, request):
return (200, {
- "versions": [
- "r0.0.1",
- ]
+ "versions": ["r0.0.1"]
})
diff --git a/synapse/rest/media/v1/base_resource.py b/synapse/rest/media/v1/base_resource.py
index bdc65f0198..58d56ec7a4 100644
--- a/synapse/rest/media/v1/base_resource.py
+++ b/synapse/rest/media/v1/base_resource.py
@@ -28,6 +28,7 @@ from twisted.protocols.basic import FileSender
from synapse.util.async import ObservableDeferred
from synapse.util.stringutils import is_ascii
+from synapse.util.logcontext import preserve_context_over_fn
import os
@@ -276,7 +277,8 @@ class BaseMediaResource(Resource):
)
self._makedirs(t_path)
- t_len = yield threads.deferToThread(
+ t_len = yield preserve_context_over_fn(
+ threads.deferToThread,
self._generate_thumbnail,
input_path, t_path, t_width, t_height, t_method, t_type
)
@@ -298,7 +300,8 @@ class BaseMediaResource(Resource):
)
self._makedirs(t_path)
- t_len = yield threads.deferToThread(
+ t_len = yield preserve_context_over_fn(
+ threads.deferToThread,
self._generate_thumbnail,
input_path, t_path, t_width, t_height, t_method, t_type
)
@@ -372,7 +375,7 @@ class BaseMediaResource(Resource):
media_id, t_width, t_height, t_type, t_method, t_len
))
- yield threads.deferToThread(generate_thumbnails)
+ yield preserve_context_over_fn(threads.deferToThread, generate_thumbnails)
for l in local_thumbnails:
yield self.store.store_local_thumbnail(*l)
@@ -445,7 +448,7 @@ class BaseMediaResource(Resource):
t_width, t_height, t_type, t_method, t_len
])
- yield threads.deferToThread(generate_thumbnails)
+ yield preserve_context_over_fn(threads.deferToThread, generate_thumbnails)
for r in remote_thumbnails:
yield self.store.store_remote_media_thumbnail(*r)
diff --git a/synapse/server.py b/synapse/server.py
index 4a5796b982..368d615576 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -20,8 +20,10 @@
# Imports required for the default HomeServer() implementation
from twisted.web.client import BrowserLikePolicyForHTTPS
+from twisted.enterprise import adbapi
+
from synapse.federation import initialize_http_replication
-from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory
+from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory
from synapse.notifier import Notifier
from synapse.api.auth import Auth
from synapse.handlers import Handlers
@@ -36,8 +38,15 @@ from synapse.push.pusherpool import PusherPool
from synapse.events.builder import EventBuilderFactory
from synapse.api.filtering import Filtering
+from synapse.http.matrixfederationclient import MatrixFederationHttpClient
+
+import logging
+
-class BaseHomeServer(object):
+logger = logging.getLogger(__name__)
+
+
+class HomeServer(object):
"""A basic homeserver object without lazy component builders.
This will need all of the components it requires to either be passed as
@@ -98,39 +107,18 @@ class BaseHomeServer(object):
self.hostname = hostname
self._building = {}
+ self.clock = Clock()
+ self.distributor = Distributor()
+ self.ratelimiter = Ratelimiter()
+
# Other kwargs are explicit dependencies
for depname in kwargs:
setattr(self, depname, kwargs[depname])
- @classmethod
- def _make_dependency_method(cls, depname):
- def _get(self):
- if hasattr(self, depname):
- return getattr(self, depname)
-
- if hasattr(self, "build_%s" % (depname)):
- # Prevent cyclic dependencies from deadlocking
- if depname in self._building:
- raise ValueError("Cyclic dependency while building %s" % (
- depname,
- ))
- self._building[depname] = 1
-
- builder = getattr(self, "build_%s" % (depname))
- dep = builder()
- setattr(self, depname, dep)
-
- del self._building[depname]
-
- return dep
-
- raise NotImplementedError(
- "%s has no %s nor a builder for it" % (
- type(self).__name__, depname,
- )
- )
-
- setattr(BaseHomeServer, "get_%s" % (depname), _get)
+ def setup(self):
+ logger.info("Setting up.")
+ self.datastore = DataStore(self.get_db_conn(), self)
+ logger.info("Finished setting up.")
def get_ip_from_request(self, request):
# X-Forwarded-For is handled by our custom request type.
@@ -142,33 +130,9 @@ class BaseHomeServer(object):
def is_mine_id(self, string):
return string.split(":", 1)[1] == self.hostname
-# Build magic accessors for every dependency
-for depname in BaseHomeServer.DEPENDENCIES:
- BaseHomeServer._make_dependency_method(depname)
-
-
-class HomeServer(BaseHomeServer):
- """A homeserver object that will construct most of its dependencies as
- required.
-
- It still requires the following to be specified by the caller:
- resource_for_client
- resource_for_web_client
- resource_for_federation
- resource_for_content_repo
- http_client
- db_pool
- """
-
- def build_clock(self):
- return Clock()
-
def build_replication_layer(self):
return initialize_http_replication(self)
- def build_datastore(self):
- return DataStore(self)
-
def build_handlers(self):
return Handlers(self)
@@ -179,10 +143,9 @@ class HomeServer(BaseHomeServer):
return Auth(self)
def build_http_client_context_factory(self):
- config = self.get_config()
return (
InsecureInterceptableContextFactory()
- if config.use_insecure_ssl_client_just_for_testing_do_not_use
+ if self.config.use_insecure_ssl_client_just_for_testing_do_not_use
else BrowserLikePolicyForHTTPS()
)
@@ -201,15 +164,9 @@ class HomeServer(BaseHomeServer):
def build_state_handler(self):
return StateHandler(self)
- def build_distributor(self):
- return Distributor()
-
def build_event_sources(self):
return EventSources(self)
- def build_ratelimiter(self):
- return Ratelimiter()
-
def build_keyring(self):
return Keyring(self)
@@ -224,3 +181,55 @@ class HomeServer(BaseHomeServer):
def build_pusherpool(self):
return PusherPool(self)
+
+ def build_http_client(self):
+ return MatrixFederationHttpClient(self)
+
+ def build_db_pool(self):
+ name = self.db_config["name"]
+
+ return adbapi.ConnectionPool(
+ name,
+ **self.db_config.get("args", {})
+ )
+
+
+def _make_dependency_method(depname):
+ def _get(hs):
+ try:
+ return getattr(hs, depname)
+ except AttributeError:
+ pass
+
+ try:
+ builder = getattr(hs, "build_%s" % (depname))
+ except AttributeError:
+ builder = None
+
+ if builder:
+ # Prevent cyclic dependencies from deadlocking
+ if depname in hs._building:
+ raise ValueError("Cyclic dependency while building %s" % (
+ depname,
+ ))
+ hs._building[depname] = 1
+
+ dep = builder()
+ setattr(hs, depname, dep)
+
+ del hs._building[depname]
+
+ return dep
+
+ raise NotImplementedError(
+ "%s has no %s nor a builder for it" % (
+ type(hs).__name__, depname,
+ )
+ )
+
+ setattr(HomeServer, "get_%s" % (depname), _get)
+
+
+# Build magic accessors for every dependency
+for depname in HomeServer.DEPENDENCIES:
+ _make_dependency_method(depname)
diff --git a/synapse/state.py b/synapse/state.py
index 0acf309fe0..b9a1387520 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -63,7 +63,7 @@ class StateHandler(object):
cache_name="state_cache",
clock=self.clock,
max_len=SIZE_OF_CACHE,
- expiry_ms=EVICTION_TIMEOUT_SECONDS*1000,
+ expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000,
reset_expiry_on_get=True,
)
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 7a3f6c4662..5a9e7720d9 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -45,6 +45,10 @@ from .search import SearchStore
from .tags import TagsStore
from .account_data import AccountDataStore
+from util.id_generators import IdGenerator, StreamIdGenerator
+
+from synapse.util.caches.stream_change_cache import StreamChangeCache
+
import logging
@@ -55,7 +59,7 @@ logger = logging.getLogger(__name__)
# Number of msec of granularity to store the user IP 'last seen' time. Smaller
# times give more inserts into the database even for readonly API hits
# 120 seconds == 2 minutes
-LAST_SEEN_GRANULARITY = 120*1000
+LAST_SEEN_GRANULARITY = 120 * 1000
class DataStore(RoomMemberStore, RoomStore,
@@ -79,18 +83,97 @@ class DataStore(RoomMemberStore, RoomStore,
EventPushActionsStore
):
- def __init__(self, hs):
- super(DataStore, self).__init__(hs)
+ def __init__(self, db_conn, hs):
self.hs = hs
+ self.database_engine = hs.database_engine
- self.min_token_deferred = self._get_min_token()
- self.min_token = None
+ cur = db_conn.cursor()
+ try:
+ cur.execute("SELECT MIN(stream_ordering) FROM events",)
+ rows = cur.fetchall()
+ self.min_stream_token = rows[0][0] if rows and rows[0] and rows[0][0] else -1
+ self.min_stream_token = min(self.min_stream_token, -1)
+ finally:
+ cur.close()
self.client_ip_last_seen = Cache(
name="client_ip_last_seen",
keylen=4,
)
+ self._stream_id_gen = StreamIdGenerator(
+ db_conn, "events", "stream_ordering"
+ )
+ self._receipts_id_gen = StreamIdGenerator(
+ db_conn, "receipts_linearized", "stream_id"
+ )
+ self._account_data_id_gen = StreamIdGenerator(
+ db_conn, "account_data_max_stream_id", "stream_id"
+ )
+
+ self._transaction_id_gen = IdGenerator("sent_transactions", "id", self)
+ self._state_groups_id_gen = IdGenerator("state_groups", "id", self)
+ self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self)
+ self._refresh_tokens_id_gen = IdGenerator("refresh_tokens", "id", self)
+ self._pushers_id_gen = IdGenerator("pushers", "id", self)
+ self._push_rule_id_gen = IdGenerator("push_rules", "id", self)
+ self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self)
+
+ events_max = self._stream_id_gen.get_max_token(None)
+ event_cache_prefill, min_event_val = self._get_cache_dict(
+ db_conn, "events",
+ entity_column="room_id",
+ stream_column="stream_ordering",
+ max_value=events_max,
+ )
+ self._events_stream_cache = StreamChangeCache(
+ "EventsRoomStreamChangeCache", min_event_val,
+ prefilled_cache=event_cache_prefill,
+ )
+
+ self._membership_stream_cache = StreamChangeCache(
+ "MembershipStreamChangeCache", events_max,
+ )
+
+ account_max = self._account_data_id_gen.get_max_token(None)
+ self._account_data_stream_cache = StreamChangeCache(
+ "AccountDataAndTagsChangeCache", account_max,
+ )
+
+ super(DataStore, self).__init__(hs)
+
+ def _get_cache_dict(self, db_conn, table, entity_column, stream_column, max_value):
+ # Fetch a mapping of room_id -> max stream position for "recent" rooms.
+ # It doesn't really matter how many we get, the StreamChangeCache will
+ # do the right thing to ensure it respects the max size of cache.
+ sql = (
+ "SELECT %(entity)s, MAX(%(stream)s) FROM %(table)s"
+ " WHERE %(stream)s > ? - 100000"
+ " GROUP BY %(entity)s"
+ ) % {
+ "table": table,
+ "entity": entity_column,
+ "stream": stream_column,
+ }
+
+ sql = self.database_engine.convert_param_style(sql)
+
+ txn = db_conn.cursor()
+ txn.execute(sql, (int(max_value),))
+ rows = txn.fetchall()
+
+ cache = {
+ row[0]: int(row[1])
+ for row in rows
+ }
+
+ if cache:
+ min_val = min(cache.values())
+ else:
+ min_val = max_value
+
+ return cache, min_val
+
@defer.inlineCallbacks
def insert_client_ip(self, user, access_token, ip, user_agent):
now = int(self._clock.time_msec())
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 90d7aee94a..2e97ac84a8 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -15,13 +15,11 @@
import logging
from synapse.api.errors import StoreError
-from synapse.util.logutils import log_function
-from synapse.util.logcontext import preserve_context_over_fn, LoggingContext
+from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
from synapse.util.caches.dictionary_cache import DictionaryCache
from synapse.util.caches.descriptors import Cache
import synapse.metrics
-from util.id_generators import IdGenerator, StreamIdGenerator
from twisted.internet import defer
@@ -175,16 +173,6 @@ class SQLBaseStore(object):
self.database_engine = hs.database_engine
- self._stream_id_gen = StreamIdGenerator("events", "stream_ordering")
- self._transaction_id_gen = IdGenerator("sent_transactions", "id", self)
- self._state_groups_id_gen = IdGenerator("state_groups", "id", self)
- self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self)
- self._refresh_tokens_id_gen = IdGenerator("refresh_tokens", "id", self)
- self._pushers_id_gen = IdGenerator("pushers", "id", self)
- self._push_rule_id_gen = IdGenerator("push_rules", "id", self)
- self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self)
- self._receipts_id_gen = StreamIdGenerator("receipts_linearized", "stream_id")
-
def start_profiling(self):
self._previous_loop_ts = self._clock.time_msec()
@@ -197,7 +185,7 @@ class SQLBaseStore(object):
time_then = self._previous_loop_ts
self._previous_loop_ts = time_now
- ratio = (curr - prev)/(time_now - time_then)
+ ratio = (curr - prev) / (time_now - time_then)
top_three_counters = self._txn_perf_counters.interval(
time_now - time_then, limit=3
@@ -310,10 +298,10 @@ class SQLBaseStore(object):
func, *args, **kwargs
)
- result = yield preserve_context_over_fn(
- self._db_pool.runWithConnection,
- inner_func, *args, **kwargs
- )
+ with PreserveLoggingContext():
+ result = yield self._db_pool.runWithConnection(
+ inner_func, *args, **kwargs
+ )
for after_callback, after_args in after_callbacks:
after_callback(*after_args)
@@ -338,14 +326,15 @@ class SQLBaseStore(object):
return func(conn, *args, **kwargs)
- result = yield preserve_context_over_fn(
- self._db_pool.runWithConnection,
- inner_func, *args, **kwargs
- )
+ with PreserveLoggingContext():
+ result = yield self._db_pool.runWithConnection(
+ inner_func, *args, **kwargs
+ )
defer.returnValue(result)
- def cursor_to_dict(self, cursor):
+ @staticmethod
+ def cursor_to_dict(cursor):
"""Converts a SQL cursor into an list of dicts.
Args:
@@ -402,8 +391,8 @@ class SQLBaseStore(object):
if not or_ignore:
raise
- @log_function
- def _simple_insert_txn(self, txn, table, values):
+ @staticmethod
+ def _simple_insert_txn(txn, table, values):
keys, vals = zip(*values.items())
sql = "INSERT INTO %s (%s) VALUES(%s)" % (
@@ -414,7 +403,8 @@ class SQLBaseStore(object):
txn.execute(sql, vals)
- def _simple_insert_many_txn(self, txn, table, values):
+ @staticmethod
+ def _simple_insert_many_txn(txn, table, values):
if not values:
return
@@ -537,9 +527,10 @@ class SQLBaseStore(object):
table, keyvalues, retcol, allow_none=allow_none,
)
- def _simple_select_one_onecol_txn(self, txn, table, keyvalues, retcol,
+ @classmethod
+ def _simple_select_one_onecol_txn(cls, txn, table, keyvalues, retcol,
allow_none=False):
- ret = self._simple_select_onecol_txn(
+ ret = cls._simple_select_onecol_txn(
txn,
table=table,
keyvalues=keyvalues,
@@ -554,7 +545,8 @@ class SQLBaseStore(object):
else:
raise StoreError(404, "No row found")
- def _simple_select_onecol_txn(self, txn, table, keyvalues, retcol):
+ @staticmethod
+ def _simple_select_onecol_txn(txn, table, keyvalues, retcol):
sql = (
"SELECT %(retcol)s FROM %(table)s WHERE %(where)s"
) % {
@@ -603,7 +595,8 @@ class SQLBaseStore(object):
table, keyvalues, retcols
)
- def _simple_select_list_txn(self, txn, table, keyvalues, retcols):
+ @classmethod
+ def _simple_select_list_txn(cls, txn, table, keyvalues, retcols):
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.
@@ -627,7 +620,7 @@ class SQLBaseStore(object):
)
txn.execute(sql)
- return self.cursor_to_dict(txn)
+ return cls.cursor_to_dict(txn)
@defer.inlineCallbacks
def _simple_select_many_batch(self, table, column, iterable, retcols,
@@ -650,7 +643,10 @@ class SQLBaseStore(object):
if not iterable:
defer.returnValue(results)
- chunks = [iterable[i:i+batch_size] for i in xrange(0, len(iterable), batch_size)]
+ chunks = [
+ iterable[i:i + batch_size]
+ for i in xrange(0, len(iterable), batch_size)
+ ]
for chunk in chunks:
rows = yield self.runInteraction(
desc,
@@ -662,7 +658,8 @@ class SQLBaseStore(object):
defer.returnValue(results)
- def _simple_select_many_txn(self, txn, table, column, iterable, keyvalues, retcols):
+ @classmethod
+ def _simple_select_many_txn(cls, txn, table, column, iterable, keyvalues, retcols):
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.
@@ -699,7 +696,7 @@ class SQLBaseStore(object):
)
txn.execute(sql, values)
- return self.cursor_to_dict(txn)
+ return cls.cursor_to_dict(txn)
def _simple_update_one(self, table, keyvalues, updatevalues,
desc="_simple_update_one"):
@@ -726,7 +723,8 @@ class SQLBaseStore(object):
table, keyvalues, updatevalues,
)
- def _simple_update_one_txn(self, txn, table, keyvalues, updatevalues):
+ @staticmethod
+ def _simple_update_one_txn(txn, table, keyvalues, updatevalues):
update_sql = "UPDATE %s SET %s WHERE %s" % (
table,
", ".join("%s = ?" % (k,) for k in updatevalues),
@@ -743,7 +741,8 @@ class SQLBaseStore(object):
if txn.rowcount > 1:
raise StoreError(500, "More than one row matched")
- def _simple_select_one_txn(self, txn, table, keyvalues, retcols,
+ @staticmethod
+ def _simple_select_one_txn(txn, table, keyvalues, retcols,
allow_none=False):
select_sql = "SELECT %s FROM %s WHERE %s" % (
", ".join(retcols),
@@ -784,7 +783,8 @@ class SQLBaseStore(object):
raise StoreError(500, "more than one row matched")
return self.runInteraction(desc, func)
- def _simple_delete_txn(self, txn, table, keyvalues):
+ @staticmethod
+ def _simple_delete_txn(txn, table, keyvalues):
sql = "DELETE FROM %s WHERE %s" % (
table,
" AND ".join("%s = ?" % (k, ) for k in keyvalues)
diff --git a/synapse/storage/account_data.py b/synapse/storage/account_data.py
index 9c6597e012..b8387fc500 100644
--- a/synapse/storage/account_data.py
+++ b/synapse/storage/account_data.py
@@ -83,7 +83,7 @@ class AccountDataStore(SQLBaseStore):
"get_account_data_for_room", get_account_data_for_room_txn
)
- def get_updated_account_data_for_user(self, user_id, stream_id):
+ def get_updated_account_data_for_user(self, user_id, stream_id, room_ids=None):
"""Get all the client account_data for a that's changed.
Args:
@@ -120,6 +120,12 @@ class AccountDataStore(SQLBaseStore):
return (global_account_data, account_data_by_room)
+ changed = self._account_data_stream_cache.has_entity_changed(
+ user_id, int(stream_id)
+ )
+ if not changed:
+ return ({}, {})
+
return self.runInteraction(
"get_updated_account_data_for_user", get_updated_account_data_for_user_txn
)
@@ -151,6 +157,10 @@ class AccountDataStore(SQLBaseStore):
"content": content_json,
}
)
+ txn.call_after(
+ self._account_data_stream_cache.entity_has_changed,
+ user_id, next_id,
+ )
self._update_max_stream_id(txn, next_id)
with (yield self._account_data_id_gen.get_next(self)) as next_id:
@@ -186,6 +196,10 @@ class AccountDataStore(SQLBaseStore):
"content": content_json,
}
)
+ txn.call_after(
+ self._account_data_stream_cache.entity_has_changed,
+ user_id, next_id,
+ )
self._update_max_stream_id(txn, next_id)
with (yield self._account_data_id_gen.get_next(self)) as next_id:
diff --git a/synapse/storage/appservice.py b/synapse/storage/appservice.py
index b5aa55c0a3..1100c67714 100644
--- a/synapse/storage/appservice.py
+++ b/synapse/storage/appservice.py
@@ -276,7 +276,8 @@ class ApplicationServiceTransactionStore(SQLBaseStore):
"application_services_state",
dict(as_id=service.id),
["state"],
- allow_none=True
+ allow_none=True,
+ desc="get_appservice_state",
)
if result:
defer.returnValue(result.get("state"))
diff --git a/synapse/storage/engines/sqlite3.py b/synapse/storage/engines/sqlite3.py
index 400c10103c..91fac33b8b 100644
--- a/synapse/storage/engines/sqlite3.py
+++ b/synapse/storage/engines/sqlite3.py
@@ -54,7 +54,7 @@ class Sqlite3Engine(object):
def _parse_match_info(buf):
bufsize = len(buf)
- return [struct.unpack('@I', buf[i:i+4])[0] for i in range(0, bufsize, 4)]
+ return [struct.unpack('@I', buf[i:i + 4])[0] for i in range(0, bufsize, 4)]
def _rank(raw_match_info):
diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py
index 5f32eec6f8..ce2c794025 100644
--- a/synapse/storage/event_federation.py
+++ b/synapse/storage/event_federation.py
@@ -58,7 +58,7 @@ class EventFederationStore(SQLBaseStore):
new_front = set()
front_list = list(front)
chunks = [
- front_list[x:x+100]
+ front_list[x:x + 100]
for x in xrange(0, len(front), 100)
]
for chunk in chunks:
diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/event_push_actions.py
index a05c4f84cf..d0a969f50b 100644
--- a/synapse/storage/event_push_actions.py
+++ b/synapse/storage/event_push_actions.py
@@ -37,7 +37,11 @@ class EventPushActionsStore(SQLBaseStore):
'event_id': event.event_id,
'user_id': uid,
'profile_tag': profile_tag,
- 'actions': json.dumps(actions)
+ 'actions': json.dumps(actions),
+ 'stream_ordering': event.internal_metadata.stream_ordering,
+ 'topological_ordering': event.depth,
+ 'notif': 1,
+ 'highlight': 1 if _action_has_highlight(actions) else 0,
})
def f(txn):
@@ -68,32 +72,34 @@ class EventPushActionsStore(SQLBaseStore):
)
results = txn.fetchall()
if len(results) == 0:
- return []
+ return {"notify_count": 0, "highlight_count": 0}
stream_ordering = results[0][0]
topological_ordering = results[0][1]
sql = (
- "SELECT ea.event_id, ea.actions"
- " FROM event_push_actions ea, events e"
- " WHERE ea.room_id = e.room_id"
- " AND ea.event_id = e.event_id"
- " AND ea.user_id = ?"
- " AND ea.room_id = ?"
+ "SELECT sum(notif), sum(highlight)"
+ " FROM event_push_actions ea"
+ " WHERE"
+ " user_id = ?"
+ " AND room_id = ?"
" AND ("
- " e.topological_ordering > ?"
- " OR (e.topological_ordering = ? AND e.stream_ordering > ?)"
+ " topological_ordering > ?"
+ " OR (topological_ordering = ? AND stream_ordering > ?)"
")"
)
txn.execute(sql, (
user_id, room_id,
topological_ordering, topological_ordering, stream_ordering
- )
- )
- return [
- {"event_id": row[0], "actions": json.loads(row[1])}
- for row in txn.fetchall()
- ]
+ ))
+ row = txn.fetchone()
+ if row:
+ return {
+ "notify_count": row[0] or 0,
+ "highlight_count": row[1] or 0,
+ }
+ else:
+ return {"notify_count": 0, "highlight_count": 0}
ret = yield self.runInteraction(
"get_unread_event_push_actions_by_room",
@@ -117,3 +123,14 @@ class EventPushActionsStore(SQLBaseStore):
"remove_push_actions_for_event_id",
f
)
+
+
+def _action_has_highlight(actions):
+ for action in actions:
+ try:
+ if action.get("set_tweak", None) == "highlight":
+ return action.get("value", True)
+ except AttributeError:
+ pass
+
+ return False
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index ba368a3eca..c6ed54721c 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -19,7 +19,7 @@ from twisted.internet import defer, reactor
from synapse.events import FrozenEvent, USE_FROZEN_DICTS
from synapse.events.utils import prune_event
-from synapse.util.logcontext import preserve_context_over_deferred
+from synapse.util.logcontext import preserve_fn, PreserveLoggingContext
from synapse.util.logutils import log_function
from synapse.api.constants import EventTypes
@@ -66,11 +66,9 @@ class EventsStore(SQLBaseStore):
return
if backfilled:
- if not self.min_token_deferred.called:
- yield self.min_token_deferred
- start = self.min_token - 1
- self.min_token -= len(events_and_contexts) + 1
- stream_orderings = range(start, self.min_token, -1)
+ start = self.min_stream_token - 1
+ self.min_stream_token -= len(events_and_contexts) + 1
+ stream_orderings = range(start, self.min_stream_token, -1)
@contextmanager
def stream_ordering_manager():
@@ -86,7 +84,7 @@ class EventsStore(SQLBaseStore):
event.internal_metadata.stream_ordering = stream
chunks = [
- events_and_contexts[x:x+100]
+ events_and_contexts[x:x + 100]
for x in xrange(0, len(events_and_contexts), 100)
]
@@ -107,10 +105,8 @@ class EventsStore(SQLBaseStore):
is_new_state=True, current_state=None):
stream_ordering = None
if backfilled:
- if not self.min_token_deferred.called:
- yield self.min_token_deferred
- self.min_token -= 1
- stream_ordering = self.min_token
+ self.min_stream_token -= 1
+ stream_ordering = self.min_stream_token
if stream_ordering is None:
stream_ordering_manager = yield self._stream_id_gen.get_next(self)
@@ -214,6 +210,12 @@ class EventsStore(SQLBaseStore):
for event, _ in events_and_contexts:
txn.call_after(self._invalidate_get_event_cache, event.event_id)
+ if not backfilled:
+ txn.call_after(
+ self._events_stream_cache.entity_has_changed,
+ event.room_id, event.internal_metadata.stream_ordering,
+ )
+
depth_updates = {}
for event, _ in events_and_contexts:
if event.internal_metadata.is_outlier():
@@ -662,14 +664,16 @@ class EventsStore(SQLBaseStore):
for ids, d in lst:
if not d.called:
try:
- d.callback([
- res[i]
- for i in ids
- if i in res
- ])
+ with PreserveLoggingContext():
+ d.callback([
+ res[i]
+ for i in ids
+ if i in res
+ ])
except:
logger.exception("Failed to callback")
- reactor.callFromThread(fire, event_list, row_dict)
+ with PreserveLoggingContext():
+ reactor.callFromThread(fire, event_list, row_dict)
except Exception as e:
logger.exception("do_fetch")
@@ -677,10 +681,12 @@ class EventsStore(SQLBaseStore):
def fire(evs):
for _, d in evs:
if not d.called:
- d.errback(e)
+ with PreserveLoggingContext():
+ d.errback(e)
if event_list:
- reactor.callFromThread(fire, event_list)
+ with PreserveLoggingContext():
+ reactor.callFromThread(fire, event_list)
@defer.inlineCallbacks
def _enqueue_events(self, events, check_redacted=True,
@@ -707,18 +713,20 @@ class EventsStore(SQLBaseStore):
should_start = False
if should_start:
- self.runWithConnection(
- self._do_fetch
- )
+ with PreserveLoggingContext():
+ self.runWithConnection(
+ self._do_fetch
+ )
- rows = yield preserve_context_over_deferred(events_d)
+ with PreserveLoggingContext():
+ rows = yield events_d
if not allow_rejected:
rows[:] = [r for r in rows if not r["rejects"]]
res = yield defer.gatherResults(
[
- self._get_event_from_row(
+ preserve_fn(self._get_event_from_row)(
row["internal_metadata"], row["json"], row["redacts"],
check_redacted=check_redacted,
get_prev_content=get_prev_content,
@@ -738,7 +746,7 @@ class EventsStore(SQLBaseStore):
rows = []
N = 200
for i in range(1 + len(events) / N):
- evs = events[i*N:(i + 1)*N]
+ evs = events[i * N:(i + 1) * N]
if not evs:
break
@@ -753,7 +761,7 @@ class EventsStore(SQLBaseStore):
" LEFT JOIN rejections as rej USING (event_id)"
" LEFT JOIN redactions as r ON e.event_id = r.redacts"
" WHERE e.event_id IN (%s)"
- ) % (",".join(["?"]*len(evs)),)
+ ) % (",".join(["?"] * len(evs)),)
txn.execute(sql, evs)
rows.extend(self.cursor_to_dict(txn))
diff --git a/synapse/storage/filtering.py b/synapse/storage/filtering.py
index f8fc9bdddc..5248736816 100644
--- a/synapse/storage/filtering.py
+++ b/synapse/storage/filtering.py
@@ -16,12 +16,13 @@
from twisted.internet import defer
from ._base import SQLBaseStore
+from synapse.util.caches.descriptors import cachedInlineCallbacks
import simplejson as json
class FilteringStore(SQLBaseStore):
- @defer.inlineCallbacks
+ @cachedInlineCallbacks(num_args=2)
def get_user_filter(self, user_localpart, filter_id):
def_json = yield self._simple_select_one_onecol(
table="user_filters",
diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py
index 8022b8cfc6..fd05bfe54e 100644
--- a/synapse/storage/keys.py
+++ b/synapse/storage/keys.py
@@ -39,6 +39,7 @@ class KeyStore(SQLBaseStore):
table="server_tls_certificates",
keyvalues={"server_name": server_name},
retcols=("tls_certificate",),
+ desc="get_server_certificate",
)
tls_certificate = OpenSSL.crypto.load_certificate(
OpenSSL.crypto.FILETYPE_ASN1, tls_certificate_bytes,
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index c1f5f99789..850736c85e 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts.
-SCHEMA_VERSION = 28
+SCHEMA_VERSION = 29
dir_path = os.path.abspath(os.path.dirname(__file__))
@@ -211,7 +211,7 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files,
logger.debug("applied_delta_files: %s", applied_delta_files)
for v in range(start_ver, SCHEMA_VERSION + 1):
- logger.debug("Upgrading schema to v%d", v)
+ logger.info("Upgrading schema to v%d", v)
delta_dir = os.path.join(dir_path, "schema", "delta", str(v))
diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py
index 9b3aecaf8c..ef525f34c5 100644
--- a/synapse/storage/presence.py
+++ b/synapse/storage/presence.py
@@ -68,8 +68,9 @@ class PresenceStore(SQLBaseStore):
for row in rows
})
+ @defer.inlineCallbacks
def set_presence_state(self, user_localpart, new_state):
- res = self._simple_update_one(
+ res = yield self._simple_update_one(
table="presence",
keyvalues={"user_id": user_localpart},
updatevalues={"state": new_state["state"],
@@ -79,7 +80,7 @@ class PresenceStore(SQLBaseStore):
)
self.get_presence_state.invalidate((user_localpart,))
- return res
+ defer.returnValue(res)
def allow_presence_visible(self, observed_localpart, observer_userid):
return self._simple_insert(
diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py
index 1f51c90ee5..f9a48171ba 100644
--- a/synapse/storage/push_rule.py
+++ b/synapse/storage/push_rule.py
@@ -130,7 +130,8 @@ class PushRuleStore(SQLBaseStore):
def _add_push_rule_relative_txn(self, txn, user_id, **kwargs):
after = kwargs.pop("after", None)
- relative_to_rule = kwargs.pop("before", after)
+ before = kwargs.pop("before", None)
+ relative_to_rule = before or after
res = self._simple_select_one_txn(
txn,
diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py
index c4232bdc65..4202a6b3dc 100644
--- a/synapse/storage/receipts.py
+++ b/synapse/storage/receipts.py
@@ -15,11 +15,10 @@
from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList, cached
-from synapse.util.caches import cache_counter, caches_by_name
+from synapse.util.caches.stream_change_cache import StreamChangeCache
from twisted.internet import defer
-from blist import sorteddict
import logging
import ujson as json
@@ -31,7 +30,9 @@ class ReceiptsStore(SQLBaseStore):
def __init__(self, hs):
super(ReceiptsStore, self).__init__(hs)
- self._receipts_stream_cache = _RoomStreamChangeCache()
+ self._receipts_stream_cache = StreamChangeCache(
+ "ReceiptsRoomChangeCache", self._receipts_id_gen.get_max_token(None)
+ )
@cached(num_args=2)
def get_receipts_for_room(self, room_id, receipt_type):
@@ -45,6 +46,20 @@ class ReceiptsStore(SQLBaseStore):
desc="get_receipts_for_room",
)
+ @cached(num_args=3)
+ def get_last_receipt_event_id_for_user(self, user_id, room_id, receipt_type):
+ return self._simple_select_one_onecol(
+ table="receipts_linearized",
+ keyvalues={
+ "room_id": room_id,
+ "receipt_type": receipt_type,
+ "user_id": user_id
+ },
+ retcol="event_id",
+ desc="get_own_receipt_for_user",
+ allow_none=True,
+ )
+
@cachedInlineCallbacks(num_args=2)
def get_receipts_for_user(self, user_id, receipt_type):
def f(txn):
@@ -76,8 +91,8 @@ class ReceiptsStore(SQLBaseStore):
room_ids = set(room_ids)
if from_key:
- room_ids = yield self._receipts_stream_cache.get_rooms_changed(
- self, room_ids, from_key
+ room_ids = yield self._receipts_stream_cache.get_entities_changed(
+ room_ids, from_key
)
results = yield self._get_linearized_receipts_for_rooms(
@@ -220,6 +235,16 @@ class ReceiptsStore(SQLBaseStore):
# FIXME: This shouldn't invalidate the whole cache
txn.call_after(self.get_linearized_receipts_for_room.invalidate_all)
+ txn.call_after(
+ self._receipts_stream_cache.entity_has_changed,
+ room_id, stream_id
+ )
+
+ txn.call_after(
+ self.get_last_receipt_event_id_for_user.invalidate,
+ (user_id, room_id, receipt_type)
+ )
+
# We don't want to clobber receipts for more recent events, so we
# have to compare orderings of existing receipts
sql = (
@@ -307,9 +332,6 @@ class ReceiptsStore(SQLBaseStore):
stream_id_manager = yield self._receipts_id_gen.get_next(self)
with stream_id_manager as stream_id:
- yield self._receipts_stream_cache.room_has_changed(
- self, room_id, stream_id
- )
have_persisted = yield self.runInteraction(
"insert_linearized_receipt",
self.insert_linearized_receipt_txn,
@@ -368,63 +390,3 @@ class ReceiptsStore(SQLBaseStore):
"data": json.dumps(data),
}
)
-
-
-class _RoomStreamChangeCache(object):
- """Keeps track of the stream_id of the latest change in rooms.
-
- Given a list of rooms and stream key, it will give a subset of rooms that
- may have changed since that key. If the key is too old then the cache
- will simply return all rooms.
- """
- def __init__(self, size_of_cache=10000):
- self._size_of_cache = size_of_cache
- self._room_to_key = {}
- self._cache = sorteddict()
- self._earliest_key = None
- self.name = "ReceiptsRoomChangeCache"
- caches_by_name[self.name] = self._cache
-
- @defer.inlineCallbacks
- def get_rooms_changed(self, store, room_ids, key):
- """Returns subset of room ids that have had new receipts since the
- given key. If the key is too old it will just return the given list.
- """
- if key > (yield self._get_earliest_key(store)):
- keys = self._cache.keys()
- i = keys.bisect_right(key)
-
- result = set(
- self._cache[k] for k in keys[i:]
- ).intersection(room_ids)
-
- cache_counter.inc_hits(self.name)
- else:
- result = room_ids
- cache_counter.inc_misses(self.name)
-
- defer.returnValue(result)
-
- @defer.inlineCallbacks
- def room_has_changed(self, store, room_id, key):
- """Informs the cache that the room has been changed at the given key.
- """
- if key > (yield self._get_earliest_key(store)):
- old_key = self._room_to_key.get(room_id, None)
- if old_key:
- key = max(key, old_key)
- self._cache.pop(old_key, None)
- self._cache[key] = room_id
-
- while len(self._cache) > self._size_of_cache:
- k, r = self._cache.popitem()
- self._earliest_key = max(k, self._earliest_key)
- self._room_to_key.pop(r, None)
-
- @defer.inlineCallbacks
- def _get_earliest_key(self, store):
- if self._earliest_key is None:
- self._earliest_key = yield store.get_max_receipt_stream_id()
- self._earliest_key = int(self._earliest_key)
-
- defer.returnValue(self._earliest_key)
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 70cde0d04d..967c732bda 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import re
+
from twisted.internet import defer
from synapse.api.errors import StoreError, Codes
@@ -134,6 +136,7 @@ class RegistrationStore(SQLBaseStore):
},
retcols=["name", "password_hash", "is_guest"],
allow_none=True,
+ desc="get_user_by_id",
)
def get_users_by_id_case_insensitive(self, user_id):
@@ -350,3 +353,37 @@ class RegistrationStore(SQLBaseStore):
ret = yield self.runInteraction("count_users", _count_users)
defer.returnValue(ret)
+
+ @defer.inlineCallbacks
+ def find_next_generated_user_id_localpart(self):
+ """
+ Gets the localpart of the next generated user ID.
+
+ Generated user IDs are integers, and we aim for them to be as small as
+ we can. Unfortunately, it's possible some of them are already taken by
+ existing users, and there may be gaps in the already taken range. This
+ function returns the start of the first allocatable gap. This is to
+ avoid the case of ID 10000000 being pre-allocated, so us wasting the
+ first (and shortest) many generated user IDs.
+ """
+ def _find_next_generated_user_id(txn):
+ txn.execute("SELECT name FROM users")
+ rows = self.cursor_to_dict(txn)
+
+ regex = re.compile("^@(\d+):")
+
+ found = set()
+
+ for r in rows:
+ user_id = r["name"]
+ match = regex.search(user_id)
+ if match:
+ found.add(int(match.group(1)))
+ for i in xrange(len(found) + 1):
+ if i not in found:
+ return i
+
+ defer.returnValue((yield self.runInteraction(
+ "find_next_generated_user_id",
+ _find_next_generated_user_id
+ )))
diff --git a/synapse/storage/room.py b/synapse/storage/room.py
index dc09a3aaba..46ab38a313 100644
--- a/synapse/storage/room.py
+++ b/synapse/storage/room.py
@@ -87,90 +87,20 @@ class RoomStore(SQLBaseStore):
desc="get_public_room_ids",
)
- @defer.inlineCallbacks
- def get_rooms(self, is_public):
- """Retrieve a list of all public rooms.
-
- Args:
- is_public (bool): True if the rooms returned should be public.
- Returns:
- A list of room dicts containing at least a "room_id" key, a
- "topic" key if one is set, and a "name" key if one is set
+ def get_room_count(self):
+ """Retrieve a list of all rooms
"""
def f(txn):
- def subquery(table_name, column_name=None):
- column_name = column_name or table_name
- return (
- "SELECT %(table_name)s.event_id as event_id, "
- "%(table_name)s.room_id as room_id, %(column_name)s "
- "FROM %(table_name)s "
- "INNER JOIN current_state_events as c "
- "ON c.event_id = %(table_name)s.event_id " % {
- "column_name": column_name,
- "table_name": table_name,
- }
- )
-
- sql = (
- "SELECT"
- " r.room_id,"
- " max(n.name),"
- " max(t.topic),"
- " max(v.history_visibility),"
- " max(g.guest_access)"
- " FROM rooms AS r"
- " LEFT JOIN (%(topic)s) AS t ON t.room_id = r.room_id"
- " LEFT JOIN (%(name)s) AS n ON n.room_id = r.room_id"
- " LEFT JOIN (%(history_visibility)s) AS v ON v.room_id = r.room_id"
- " LEFT JOIN (%(guest_access)s) AS g ON g.room_id = r.room_id"
- " WHERE r.is_public = ?"
- " GROUP BY r.room_id" % {
- "topic": subquery("topics", "topic"),
- "name": subquery("room_names", "name"),
- "history_visibility": subquery("history_visibility"),
- "guest_access": subquery("guest_access"),
- }
- )
-
- txn.execute(sql, (is_public,))
-
- rows = txn.fetchall()
-
- for i, row in enumerate(rows):
- room_id = row[0]
- aliases = self._simple_select_onecol_txn(
- txn,
- table="room_aliases",
- keyvalues={
- "room_id": room_id
- },
- retcol="room_alias",
- )
+ sql = "SELECT count(*) FROM rooms"
+ txn.execute(sql)
+ row = txn.fetchone()
+ return row[0] or 0
- rows[i] = list(row) + [aliases]
-
- return rows
-
- rows = yield self.runInteraction(
+ return self.runInteraction(
"get_rooms", f
)
- ret = [
- {
- "room_id": r[0],
- "name": r[1],
- "topic": r[2],
- "world_readable": r[3] == "world_readable",
- "guest_can_join": r[4] == "can_join",
- "aliases": r[5],
- }
- for r in rows
- if r[5] # We only return rooms that have at least one alias.
- ]
-
- defer.returnValue(ret)
-
def _store_room_topic_txn(self, txn, event):
if hasattr(event, "content") and "topic" in event.content:
self._simple_insert_txn(
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index edfecced05..3065b0c1a5 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -58,6 +58,10 @@ class RoomMemberStore(SQLBaseStore):
txn.call_after(self.get_rooms_for_user.invalidate, (event.state_key,))
txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,))
txn.call_after(self.get_users_in_room.invalidate, (event.room_id,))
+ txn.call_after(
+ self._membership_stream_cache.entity_has_changed,
+ event.state_key, event.internal_metadata.stream_ordering
+ )
def get_room_member(self, user_id, room_id):
"""Retrieve the current state of a room member.
@@ -241,7 +245,7 @@ class RoomMemberStore(SQLBaseStore):
return rows
- @cached()
+ @cached(max_entries=5000)
def get_rooms_for_user(self, user_id):
return self.get_rooms_for_user_where_membership_is(
user_id, membership_list=[Membership.JOIN],
diff --git a/synapse/storage/schema/delta/28/events_room_stream.sql b/synapse/storage/schema/delta/28/events_room_stream.sql
new file mode 100644
index 0000000000..200c35e6e2
--- /dev/null
+++ b/synapse/storage/schema/delta/28/events_room_stream.sql
@@ -0,0 +1,16 @@
+/* Copyright 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+*/
+
+CREATE INDEX events_room_stream on events(room_id, stream_ordering);
diff --git a/synapse/storage/schema/delta/28/public_roms_index.sql b/synapse/storage/schema/delta/28/public_roms_index.sql
new file mode 100644
index 0000000000..ba62a974a4
--- /dev/null
+++ b/synapse/storage/schema/delta/28/public_roms_index.sql
@@ -0,0 +1,16 @@
+/* Copyright 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+*/
+
+CREATE INDEX public_room_index on rooms(is_public);
diff --git a/synapse/storage/schema/delta/29/push_actions.sql b/synapse/storage/schema/delta/29/push_actions.sql
new file mode 100644
index 0000000000..7e7b09820a
--- /dev/null
+++ b/synapse/storage/schema/delta/29/push_actions.sql
@@ -0,0 +1,31 @@
+/* Copyright 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ALTER TABLE event_push_actions ADD COLUMN topological_ordering BIGINT;
+ALTER TABLE event_push_actions ADD COLUMN stream_ordering BIGINT;
+ALTER TABLE event_push_actions ADD COLUMN notif SMALLINT;
+ALTER TABLE event_push_actions ADD COLUMN highlight SMALLINT;
+
+UPDATE event_push_actions SET stream_ordering = (
+ SELECT stream_ordering FROM events WHERE event_id = event_push_actions.event_id
+), topological_ordering = (
+ SELECT topological_ordering FROM events WHERE event_id = event_push_actions.event_id
+);
+
+UPDATE event_push_actions SET notif = 1, highlight = 0;
+
+CREATE INDEX event_push_actions_rm_tokens on event_push_actions(
+ user_id, room_id, topological_ordering, stream_ordering
+);
diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py
index 02b1913e26..367ffc9543 100644
--- a/synapse/storage/stream.py
+++ b/synapse/storage/stream.py
@@ -39,7 +39,7 @@ from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cachedInlineCallbacks
from synapse.api.constants import EventTypes
from synapse.types import RoomStreamToken
-from synapse.util.logutils import log_function
+from synapse.util.logcontext import preserve_fn
import logging
@@ -77,7 +77,6 @@ def upper_bound(token):
class StreamStore(SQLBaseStore):
-
@defer.inlineCallbacks
def get_appservice_room_stream(self, service, from_key, to_key, limit=0):
# NB this lives here instead of appservice.py so we can reuse the
@@ -157,7 +156,147 @@ class StreamStore(SQLBaseStore):
results = yield self.runInteraction("get_appservice_room_stream", f)
defer.returnValue(results)
- @log_function
+ @defer.inlineCallbacks
+ def get_room_events_stream_for_rooms(self, room_ids, from_key, to_key, limit=0):
+ from_id = RoomStreamToken.parse_stream_token(from_key).stream
+
+ room_ids = yield self._events_stream_cache.get_entities_changed(
+ room_ids, from_id
+ )
+
+ if not room_ids:
+ defer.returnValue({})
+
+ results = {}
+ room_ids = list(room_ids)
+ for rm_ids in (room_ids[i:i + 20] for i in xrange(0, len(room_ids), 20)):
+ res = yield defer.gatherResults([
+ preserve_fn(self.get_room_events_stream_for_room)(
+ room_id, from_key, to_key, limit,
+ )
+ for room_id in room_ids
+ ])
+ results.update(dict(zip(rm_ids, res)))
+
+ defer.returnValue(results)
+
+ @defer.inlineCallbacks
+ def get_room_events_stream_for_room(self, room_id, from_key, to_key, limit=0):
+ if from_key is not None:
+ from_id = RoomStreamToken.parse_stream_token(from_key).stream
+ else:
+ from_id = None
+ to_id = RoomStreamToken.parse_stream_token(to_key).stream
+
+ if from_key == to_key:
+ defer.returnValue(([], from_key))
+
+ if from_id:
+ has_changed = yield self._events_stream_cache.has_entity_changed(
+ room_id, from_id
+ )
+
+ if not has_changed:
+ defer.returnValue(([], from_key))
+
+ def f(txn):
+ if from_id is not None:
+ sql = (
+ "SELECT event_id, stream_ordering FROM events WHERE"
+ " room_id = ?"
+ " AND not outlier"
+ " AND stream_ordering > ? AND stream_ordering <= ?"
+ " ORDER BY stream_ordering DESC LIMIT ?"
+ )
+ txn.execute(sql, (room_id, from_id, to_id, limit))
+ else:
+ sql = (
+ "SELECT event_id, stream_ordering FROM events WHERE"
+ " room_id = ?"
+ " AND not outlier"
+ " AND stream_ordering <= ?"
+ " ORDER BY stream_ordering DESC LIMIT ?"
+ )
+ txn.execute(sql, (room_id, to_id, limit))
+
+ rows = self.cursor_to_dict(txn)
+
+ return rows
+
+ rows = yield self.runInteraction("get_room_events_stream_for_room", f)
+
+ ret = yield self._get_events(
+ [r["event_id"] for r in rows],
+ get_prev_content=True
+ )
+
+ self._set_before_and_after(ret, rows, topo_order=False)
+
+ ret.reverse()
+
+ if rows:
+ key = "s%d" % min(r["stream_ordering"] for r in rows)
+ else:
+ # Assume we didn't get anything because there was nothing to
+ # get.
+ key = from_key
+
+ defer.returnValue((ret, key))
+
+ @defer.inlineCallbacks
+ def get_membership_changes_for_user(self, user_id, from_key, to_key):
+ if from_key is not None:
+ from_id = RoomStreamToken.parse_stream_token(from_key).stream
+ else:
+ from_id = None
+ to_id = RoomStreamToken.parse_stream_token(to_key).stream
+
+ if from_key == to_key:
+ defer.returnValue([])
+
+ if from_id:
+ has_changed = self._membership_stream_cache.has_entity_changed(
+ user_id, int(from_id)
+ )
+ if not has_changed:
+ defer.returnValue([])
+
+ def f(txn):
+ if from_id is not None:
+ sql = (
+ "SELECT m.event_id, stream_ordering FROM events AS e,"
+ " room_memberships AS m"
+ " WHERE e.event_id = m.event_id"
+ " AND m.user_id = ?"
+ " AND e.stream_ordering > ? AND e.stream_ordering <= ?"
+ " ORDER BY e.stream_ordering ASC"
+ )
+ txn.execute(sql, (user_id, from_id, to_id,))
+ else:
+ sql = (
+ "SELECT m.event_id, stream_ordering FROM events AS e,"
+ " room_memberships AS m"
+ " WHERE e.event_id = m.event_id"
+ " AND m.user_id = ?"
+ " AND stream_ordering <= ?"
+ " ORDER BY stream_ordering ASC"
+ )
+ txn.execute(sql, (user_id, to_id,))
+ rows = self.cursor_to_dict(txn)
+
+ return rows
+
+ rows = yield self.runInteraction("get_membership_changes_for_user", f)
+
+ ret = yield self._get_events(
+ [r["event_id"] for r in rows],
+ get_prev_content=True
+ )
+
+ self._set_before_and_after(ret, rows, topo_order=False)
+
+ defer.returnValue(ret)
+
def get_room_events_stream(
self,
user_id,
@@ -174,7 +313,8 @@ class StreamStore(SQLBaseStore):
"SELECT c.room_id FROM history_visibility AS h"
" INNER JOIN current_state_events AS c"
" ON h.event_id = c.event_id"
- " WHERE c.room_id IN (%s) AND h.history_visibility = 'world_readable'" % (
+ " WHERE c.room_id IN (%s)"
+ " AND h.history_visibility = 'world_readable'" % (
",".join(map(lambda _: "?", room_ids))
)
)
@@ -187,11 +327,6 @@ class StreamStore(SQLBaseStore):
" WHERE m.user_id = ? AND m.membership = 'join'"
)
current_room_membership_args = [user_id]
- if room_ids:
- current_room_membership_sql += " AND m.room_id in (%s)" % (
- ",".join(map(lambda _: "?", room_ids))
- )
- current_room_membership_args = [user_id] + room_ids
# We also want to get any membership events about that user, e.g.
# invites or leave notifications.
@@ -430,10 +565,23 @@ class StreamStore(SQLBaseStore):
table="events",
keyvalues={"event_id": event_id},
retcols=("stream_ordering", "topological_ordering"),
+ desc="get_topological_token_for_event",
).addCallback(lambda row: "t%d-%d" % (
row["topological_ordering"], row["stream_ordering"],)
)
+ def get_max_topological_token_for_stream_and_room(self, room_id, stream_key):
+ sql = (
+ "SELECT max(topological_ordering) FROM events"
+ " WHERE room_id = ? AND stream_ordering < ?"
+ )
+ return self._execute(
+ "get_max_topological_token_for_stream_and_room", None,
+ sql, room_id, stream_key,
+ ).addCallback(
+ lambda r: r[0][0] if r else 0
+ )
+
def _get_max_topological_txn(self, txn):
txn.execute(
"SELECT MAX(topological_ordering) FROM events"
@@ -444,27 +592,21 @@ class StreamStore(SQLBaseStore):
rows = txn.fetchall()
return rows[0][0] if rows else 0
- @defer.inlineCallbacks
- def _get_min_token(self):
- row = yield self._execute(
- "_get_min_token", None, "SELECT MIN(stream_ordering) FROM events"
- )
-
- self.min_token = row[0][0] if row and row[0] and row[0][0] else -1
- self.min_token = min(self.min_token, -1)
-
- logger.debug("min_token is: %s", self.min_token)
-
- defer.returnValue(self.min_token)
-
@staticmethod
- def _set_before_and_after(events, rows):
+ def _set_before_and_after(events, rows, topo_order=True):
for event, row in zip(events, rows):
stream = row["stream_ordering"]
- topo = event.depth
+ if topo_order:
+ topo = event.depth
+ else:
+ topo = None
internal = event.internal_metadata
internal.before = str(RoomStreamToken(topo, stream - 1))
internal.after = str(RoomStreamToken(topo, stream))
+ internal.order = (
+ int(topo) if topo else 0,
+ int(stream),
+ )
@defer.inlineCallbacks
def get_events_around(self, room_id, event_id, before_limit, after_limit):
diff --git a/synapse/storage/tags.py b/synapse/storage/tags.py
index ed9c91e5ea..e1a9c0c261 100644
--- a/synapse/storage/tags.py
+++ b/synapse/storage/tags.py
@@ -16,7 +16,6 @@
from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cached
from twisted.internet import defer
-from .util.id_generators import StreamIdGenerator
import ujson as json
import logging
@@ -25,13 +24,6 @@ logger = logging.getLogger(__name__)
class TagsStore(SQLBaseStore):
- def __init__(self, hs):
- super(TagsStore, self).__init__(hs)
-
- self._account_data_id_gen = StreamIdGenerator(
- "account_data_max_stream_id", "stream_id"
- )
-
def get_max_account_data_stream_id(self):
"""Get the current max stream id for the private user data stream
@@ -87,6 +79,12 @@ class TagsStore(SQLBaseStore):
room_ids = [row[0] for row in txn.fetchall()]
return room_ids
+ changed = self._account_data_stream_cache.has_entity_changed(
+ user_id, int(stream_id)
+ )
+ if not changed:
+ defer.returnValue({})
+
room_ids = yield self.runInteraction(
"get_updated_tags", get_updated_tags_txn
)
@@ -184,6 +182,11 @@ class TagsStore(SQLBaseStore):
next_id(int): The the revision to advance to.
"""
+ txn.call_after(
+ self._account_data_stream_cache.entity_has_changed,
+ user_id, next_id
+ )
+
update_max_id_sql = (
"UPDATE account_data_max_stream_id"
" SET stream_id = ?"
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index f58bf7fd2c..5c522f4ab9 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -72,28 +72,24 @@ class StreamIdGenerator(object):
with stream_id_gen.get_next_txn(txn) as stream_id:
# ... persist event ...
"""
- def __init__(self, table, column):
+ def __init__(self, db_conn, table, column):
self.table = table
self.column = column
self._lock = threading.Lock()
- self._current_max = None
+ cur = db_conn.cursor()
+ self._current_max = self._get_or_compute_current_max(cur)
+ cur.close()
+
self._unfinished_ids = deque()
- @defer.inlineCallbacks
def get_next(self, store):
"""
Usage:
with yield stream_id_gen.get_next as stream_id:
# ... persist event ...
"""
- if not self._current_max:
- yield store.runInteraction(
- "_compute_current_max",
- self._get_or_compute_current_max,
- )
-
with self._lock:
self._current_max += 1
next_id = self._current_max
@@ -108,21 +104,14 @@ class StreamIdGenerator(object):
with self._lock:
self._unfinished_ids.remove(next_id)
- defer.returnValue(manager())
+ return manager()
- @defer.inlineCallbacks
def get_next_mult(self, store, n):
"""
Usage:
with yield stream_id_gen.get_next(store, n) as stream_ids:
# ... persist events ...
"""
- if not self._current_max:
- yield store.runInteraction(
- "_compute_current_max",
- self._get_or_compute_current_max,
- )
-
with self._lock:
next_ids = range(self._current_max + 1, self._current_max + n + 1)
self._current_max += n
@@ -139,24 +128,17 @@ class StreamIdGenerator(object):
for next_id in next_ids:
self._unfinished_ids.remove(next_id)
- defer.returnValue(manager())
+ return manager()
- @defer.inlineCallbacks
def get_max_token(self, store):
"""Returns the maximum stream id such that all stream ids less than or
equal to it have been successfully persisted.
"""
- if not self._current_max:
- yield store.runInteraction(
- "_compute_current_max",
- self._get_or_compute_current_max,
- )
-
with self._lock:
if self._unfinished_ids:
- defer.returnValue(self._unfinished_ids[0] - 1)
+ return self._unfinished_ids[0] - 1
- defer.returnValue(self._current_max)
+ return self._current_max
def _get_or_compute_current_max(self, txn):
with self._lock:
diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py
index f1fe963adf..133671e238 100644
--- a/synapse/util/__init__.py
+++ b/synapse/util/__init__.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
+from synapse.util.logcontext import PreserveLoggingContext
from twisted.internet import defer, reactor, task
@@ -46,7 +46,7 @@ class Clock(object):
def looping_call(self, f, msec):
l = task.LoopingCall(f)
- l.start(msec/1000.0, now=False)
+ l.start(msec / 1000.0, now=False)
return l
def stop_looping_call(self, loop):
@@ -61,10 +61,8 @@ class Clock(object):
*args: Postional arguments to pass to function.
**kwargs: Key arguments to pass to function.
"""
- current_context = LoggingContext.current_context()
-
def wrapped_callback(*args, **kwargs):
- with PreserveLoggingContext(current_context):
+ with PreserveLoggingContext():
callback(*args, **kwargs)
with PreserveLoggingContext():
diff --git a/synapse/util/async.py b/synapse/util/async.py
index 200edd404c..640fae3890 100644
--- a/synapse/util/async.py
+++ b/synapse/util/async.py
@@ -16,13 +16,16 @@
from twisted.internet import defer, reactor
-from .logcontext import preserve_context_over_deferred
+from .logcontext import PreserveLoggingContext
+@defer.inlineCallbacks
def sleep(seconds):
d = defer.Deferred()
- reactor.callLater(seconds, d.callback, seconds)
- return preserve_context_over_deferred(d)
+ with PreserveLoggingContext():
+ reactor.callLater(seconds, d.callback, seconds)
+ res = yield d
+ defer.returnValue(res)
def run_on_reactor():
@@ -54,6 +57,7 @@ class ObservableDeferred(object):
object.__setattr__(self, "_result", (True, r))
while self._observers:
try:
+ # TODO: Handle errors here.
self._observers.pop().callback(r)
except:
pass
@@ -63,6 +67,7 @@ class ObservableDeferred(object):
object.__setattr__(self, "_result", (False, f))
while self._observers:
try:
+ # TODO: Handle errors here.
self._observers.pop().errback(f)
except:
pass
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 88e56e3302..277854ccbc 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -18,6 +18,9 @@ from synapse.util.async import ObservableDeferred
from synapse.util import unwrapFirstError
from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.treecache import TreeCache
+from synapse.util.logcontext import (
+ PreserveLoggingContext, preserve_context_over_deferred, preserve_context_over_fn
+)
from . import caches_by_name, DEBUG_CACHES, cache_counter
@@ -149,7 +152,7 @@ class CacheDescriptor(object):
self.lru = lru
self.tree = tree
- self.arg_names = inspect.getargspec(orig).args[1:num_args+1]
+ self.arg_names = inspect.getargspec(orig).args[1:num_args + 1]
if len(self.arg_names) < self.num_args:
raise Exception(
@@ -190,7 +193,7 @@ class CacheDescriptor(object):
defer.returnValue(cached_result)
observer.addCallback(check_result)
- return observer
+ return preserve_context_over_deferred(observer)
except KeyError:
# Get the sequence number of the cache before reading from the
# database so that we can tell if the cache is invalidated
@@ -198,6 +201,7 @@ class CacheDescriptor(object):
sequence = self.cache.sequence
ret = defer.maybeDeferred(
+ preserve_context_over_fn,
self.function_to_call,
obj, *args, **kwargs
)
@@ -211,7 +215,7 @@ class CacheDescriptor(object):
ret = ObservableDeferred(ret, consumeErrors=True)
self.cache.update(sequence, cache_key, ret)
- return ret.observe()
+ return preserve_context_over_deferred(ret.observe())
wrapped.invalidate = self.cache.invalidate
wrapped.invalidate_all = self.cache.invalidate_all
@@ -250,7 +254,7 @@ class CacheListDescriptor(object):
self.num_args = num_args
self.list_name = list_name
- self.arg_names = inspect.getargspec(orig).args[1:num_args+1]
+ self.arg_names = inspect.getargspec(orig).args[1:num_args + 1]
self.list_pos = self.arg_names.index(self.list_name)
self.cache = cache
@@ -299,6 +303,7 @@ class CacheListDescriptor(object):
args_to_call[self.list_name] = missing
ret_d = defer.maybeDeferred(
+ preserve_context_over_fn,
self.function_to_call,
**args_to_call
)
@@ -308,7 +313,8 @@ class CacheListDescriptor(object):
# We need to create deferreds for each arg in the list so that
# we can insert the new deferred into the cache.
for arg in missing:
- observer = ret_d.observe()
+ with PreserveLoggingContext():
+ observer = ret_d.observe()
observer.addCallback(lambda r, arg: r.get(arg, None), arg)
observer = ObservableDeferred(observer)
@@ -327,10 +333,10 @@ class CacheListDescriptor(object):
cached[arg] = res
- return defer.gatherResults(
+ return preserve_context_over_deferred(defer.gatherResults(
cached.values(),
consumeErrors=True,
- ).addErrback(unwrapFirstError).addCallback(lambda res: dict(res))
+ ).addErrback(unwrapFirstError).addCallback(lambda res: dict(res)))
obj.__dict__[self.orig.__name__] = wrapped
diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py
index 494226f5ea..62cae99649 100644
--- a/synapse/util/caches/expiringcache.py
+++ b/synapse/util/caches/expiringcache.py
@@ -55,7 +55,7 @@ class ExpiringCache(object):
def f():
self._prune_cache()
- self._clock.looping_call(f, self._expiry_ms/2)
+ self._clock.looping_call(f, self._expiry_ms / 2)
def __setitem__(self, key, value):
now = self._clock.time_msec()
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index e6a66dc041..f7423f2fab 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -37,7 +37,7 @@ class LruCache(object):
"""
def __init__(self, max_size, keylen=1, cache_type=dict):
cache = cache_type()
- self.size = 0
+ self.cache = cache # Used for introspection.
list_root = []
list_root[:] = [list_root, list_root, None, None]
@@ -60,7 +60,6 @@ class LruCache(object):
prev_node[NEXT] = node
next_node[PREV] = node
cache[key] = node
- self.size += 1
def move_node_to_front(node):
prev_node = node[PREV]
@@ -79,7 +78,6 @@ class LruCache(object):
next_node = node[NEXT]
prev_node[NEXT] = next_node
next_node[PREV] = prev_node
- self.size -= 1
@synchronized
def cache_get(key, default=None):
@@ -98,7 +96,7 @@ class LruCache(object):
node[VALUE] = value
else:
add_node(key, value)
- if self.size > max_size:
+ if len(cache) > max_size:
todelete = list_root[PREV]
delete_node(todelete)
cache.pop(todelete[KEY], None)
@@ -110,7 +108,7 @@ class LruCache(object):
return node[VALUE]
else:
add_node(key, value)
- if self.size > max_size:
+ if len(cache) > max_size:
todelete = list_root[PREV]
delete_node(todelete)
cache.pop(todelete[KEY], None)
@@ -145,7 +143,7 @@ class LruCache(object):
@synchronized
def cache_len():
- return self.size
+ return len(cache)
@synchronized
def cache_contains(key):
diff --git a/synapse/util/caches/snapshot_cache.py b/synapse/util/caches/snapshot_cache.py
index b1e40417fd..d03678b8c8 100644
--- a/synapse/util/caches/snapshot_cache.py
+++ b/synapse/util/caches/snapshot_cache.py
@@ -87,7 +87,8 @@ class SnapshotCache(object):
# expire from the rotation of that cache.
self.next_result_cache[key] = result
self.pending_result_cache.pop(key, None)
+ return r
- result.observe().addBoth(shuffle_along)
+ result.addBoth(shuffle_along)
return result.observe()
diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py
new file mode 100644
index 0000000000..b37f1c0725
--- /dev/null
+++ b/synapse/util/caches/stream_change_cache.py
@@ -0,0 +1,105 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.util.caches import cache_counter, caches_by_name
+
+
+from blist import sorteddict
+import logging
+
+
+logger = logging.getLogger(__name__)
+
+
+class StreamChangeCache(object):
+ """Keeps track of the stream positions of the latest change in a set of entities.
+
+ Typically the entity will be a room or user id.
+
+ Given a list of entities and a stream position, it will give a subset of
+ entities that may have changed since that position. If position key is too
+ old then the cache will simply return all given entities.
+ """
+ def __init__(self, name, current_stream_pos, max_size=10000, prefilled_cache={}):
+ self._max_size = max_size
+ self._entity_to_key = {}
+ self._cache = sorteddict()
+ self._earliest_known_stream_pos = current_stream_pos
+ self.name = name
+ caches_by_name[self.name] = self._cache
+
+ for entity, stream_pos in prefilled_cache.items():
+ self.entity_has_changed(entity, stream_pos)
+
+ def has_entity_changed(self, entity, stream_pos):
+ """Returns True if the entity may have been updated since stream_pos
+ """
+ assert type(stream_pos) is int
+
+ if stream_pos < self._earliest_known_stream_pos:
+ cache_counter.inc_misses(self.name)
+ return True
+
+ latest_entity_change_pos = self._entity_to_key.get(entity, None)
+ if latest_entity_change_pos is None:
+ cache_counter.inc_hits(self.name)
+ return False
+
+ if stream_pos < latest_entity_change_pos:
+ cache_counter.inc_misses(self.name)
+ return True
+
+ cache_counter.inc_hits(self.name)
+ return False
+
+ def get_entities_changed(self, entities, stream_pos):
+ """Returns subset of entities that have had new things since the
+ given position. If the position is too old it will just return the given list.
+ """
+ assert type(stream_pos) is int
+
+ if stream_pos >= self._earliest_known_stream_pos:
+ keys = self._cache.keys()
+ i = keys.bisect_right(stream_pos)
+
+ result = set(
+ self._cache[k] for k in keys[i:]
+ ).intersection(entities)
+
+ cache_counter.inc_hits(self.name)
+ else:
+ result = entities
+ cache_counter.inc_misses(self.name)
+
+ return result
+
+ def entity_has_changed(self, entity, stream_pos):
+ """Informs the cache that the entity has been changed at the given
+ position.
+ """
+ assert type(stream_pos) is int
+
+ if stream_pos > self._earliest_known_stream_pos:
+ old_pos = self._entity_to_key.get(entity, None)
+ if old_pos is not None:
+ stream_pos = max(stream_pos, old_pos)
+ self._cache.pop(old_pos, None)
+ self._cache[stream_pos] = entity
+ self._entity_to_key[entity] = stream_pos
+
+ while len(self._cache) > self._max_size:
+ k, r = self._cache.popitem()
+ self._earliest_known_stream_pos = max(k, self._earliest_known_stream_pos)
+ self._entity_to_key.pop(r, None)
diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py
index 3b58860910..03bc1401b7 100644
--- a/synapse/util/caches/treecache.py
+++ b/synapse/util/caches/treecache.py
@@ -8,6 +8,7 @@ class TreeCache(object):
Keys must be tuples.
"""
def __init__(self):
+ self.size = 0
self.root = {}
def __setitem__(self, key, value):
@@ -20,7 +21,8 @@ class TreeCache(object):
node = self.root
for k in key[:-1]:
node = node.setdefault(k, {})
- node[key[-1]] = value
+ node[key[-1]] = _Entry(value)
+ self.size += 1
def get(self, key, default=None):
node = self.root
@@ -28,9 +30,10 @@ class TreeCache(object):
node = node.get(k, None)
if node is None:
return default
- return node.get(key[-1], default)
+ return node.get(key[-1], _Entry(default)).value
def clear(self):
+ self.size = 0
self.root = {}
def pop(self, key, default=None):
@@ -55,6 +58,35 @@ class TreeCache(object):
if n:
break
- node_and_keys[i+1][0].pop(k)
+ node_and_keys[i + 1][0].pop(k)
+ popped, cnt = _strip_and_count_entires(popped)
+ self.size -= cnt
return popped
+
+ def __len__(self):
+ return self.size
+
+
+class _Entry(object):
+ __slots__ = ["value"]
+
+ def __init__(self, value):
+ self.value = value
+
+
+def _strip_and_count_entires(d):
+ """Takes an _Entry or dict with leaves of _Entry's, and either returns the
+ value or a dictionary with _Entry's replaced by their values.
+
+ Also returns the count of _Entry's
+ """
+ if isinstance(d, dict):
+ cnt = 0
+ for key, value in d.items():
+ v, n = _strip_and_count_entires(value)
+ d[key] = v
+ cnt += n
+ return d, cnt
+ else:
+ return d.value, 1
diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py
index 4ebfebf701..8875813de4 100644
--- a/synapse/util/distributor.py
+++ b/synapse/util/distributor.py
@@ -15,9 +15,7 @@
from twisted.internet import defer
-from synapse.util.logcontext import (
- PreserveLoggingContext, preserve_context_over_deferred,
-)
+from synapse.util.logcontext import PreserveLoggingContext
from synapse.util import unwrapFirstError
@@ -97,6 +95,7 @@ class Signal(object):
Each observer callable may return a Deferred."""
self.observers.append(observer)
+ @defer.inlineCallbacks
def fire(self, *args, **kwargs):
"""Invokes every callable in the observer list, passing in the args and
kwargs. Exceptions thrown by observers are logged but ignored. It is
@@ -116,6 +115,7 @@ class Signal(object):
failure.getTracebackObject()))
if not self.suppress_failures:
return failure
+
return defer.maybeDeferred(observer, *args, **kwargs).addErrback(eb)
with PreserveLoggingContext():
@@ -124,8 +124,11 @@ class Signal(object):
for observer in self.observers
]
- d = defer.gatherResults(deferreds, consumeErrors=True)
+ res = yield defer.gatherResults(
+ deferreds, consumeErrors=True
+ ).addErrback(unwrapFirstError)
- d.addErrback(unwrapFirstError)
+ defer.returnValue(res)
- return preserve_context_over_deferred(d)
+ def __repr__(self):
+ return "<Signal name=%r>" % (self.name,)
diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py
index 0595c0fa4f..b22a36336b 100644
--- a/synapse/util/logcontext.py
+++ b/synapse/util/logcontext.py
@@ -47,7 +47,8 @@ class LoggingContext(object):
"""
__slots__ = [
- "parent_context", "name", "usage_start", "usage_end", "main_thread", "__dict__"
+ "parent_context", "name", "usage_start", "usage_end", "main_thread",
+ "__dict__", "tag", "alive",
]
thread_local = threading.local()
@@ -72,6 +73,9 @@ class LoggingContext(object):
def add_database_transaction(self, duration_ms):
pass
+ def __nonzero__(self):
+ return False
+
sentinel = Sentinel()
def __init__(self, name=None):
@@ -83,6 +87,8 @@ class LoggingContext(object):
self.db_txn_duration = 0.
self.usage_start = None
self.main_thread = threading.current_thread()
+ self.tag = ""
+ self.alive = True
def __str__(self):
return "%s@%x" % (self.name, id(self))
@@ -101,6 +107,7 @@ class LoggingContext(object):
The context that was previously active
"""
current = cls.current_context()
+
if current is not context:
current.stop()
cls.thread_local.current_context = context
@@ -112,6 +119,7 @@ class LoggingContext(object):
if self.parent_context is not None:
raise Exception("Attempt to enter logging context multiple times")
self.parent_context = self.set_current_context(self)
+ self.alive = True
return self
def __exit__(self, type, value, traceback):
@@ -131,6 +139,7 @@ class LoggingContext(object):
self
)
self.parent_context = None
+ self.alive = False
def __getattr__(self, name):
"""Delegate member lookup to parent context"""
@@ -208,7 +217,7 @@ class PreserveLoggingContext(object):
exited. Used to restore the context after a function using
@defer.inlineCallbacks is resumed by a callback from the reactor."""
- __slots__ = ["current_context", "new_context"]
+ __slots__ = ["current_context", "new_context", "has_parent"]
def __init__(self, new_context=LoggingContext.sentinel):
self.new_context = new_context
@@ -219,12 +228,27 @@ class PreserveLoggingContext(object):
self.new_context
)
+ if self.current_context:
+ self.has_parent = self.current_context.parent_context is not None
+ if not self.current_context.alive:
+ logger.debug(
+ "Entering dead context: %s",
+ self.current_context,
+ )
+
def __exit__(self, type, value, traceback):
"""Restores the current logging context"""
- LoggingContext.set_current_context(self.current_context)
+ context = LoggingContext.set_current_context(self.current_context)
+
+ if context != self.new_context:
+ logger.debug(
+ "Unexpected logging context: %s is not %s",
+ context, self.new_context,
+ )
+
if self.current_context is not LoggingContext.sentinel:
- if self.current_context.parent_context is None:
- logger.warn(
+ if not self.current_context.alive:
+ logger.debug(
"Restoring dead context: %s",
self.current_context,
)
@@ -284,3 +308,74 @@ def preserve_context_over_deferred(deferred):
d = _PreservingContextDeferred(current_context)
deferred.chainDeferred(d)
return d
+
+
+def preserve_fn(f):
+ """Ensures that function is called with correct context and that context is
+ restored after return. Useful for wrapping functions that return a deferred
+ which you don't yield on.
+ """
+ current = LoggingContext.current_context()
+
+ def g(*args, **kwargs):
+ with PreserveLoggingContext(current):
+ return f(*args, **kwargs)
+
+ return g
+
+
+# modules to ignore in `logcontext_tracer`
+_to_ignore = [
+ "synapse.util.logcontext",
+ "synapse.http.server",
+ "synapse.storage._base",
+ "synapse.util.async",
+]
+
+
+def logcontext_tracer(frame, event, arg):
+ """A tracer that logs whenever a logcontext "unexpectedly" changes within
+ a function. Probably inaccurate.
+
+ Use by calling `sys.settrace(logcontext_tracer)` in the main thread.
+ """
+ if event == 'call':
+ name = frame.f_globals["__name__"]
+ if name.startswith("synapse"):
+ if name == "synapse.util.logcontext":
+ if frame.f_code.co_name in ["__enter__", "__exit__"]:
+ tracer = frame.f_back.f_trace
+ if tracer:
+ tracer.just_changed = True
+
+ tracer = frame.f_trace
+ if tracer:
+ return tracer
+
+ if not any(name.startswith(ig) for ig in _to_ignore):
+ return LineTracer()
+
+
+class LineTracer(object):
+ __slots__ = ["context", "just_changed"]
+
+ def __init__(self):
+ self.context = LoggingContext.current_context()
+ self.just_changed = False
+
+ def __call__(self, frame, event, arg):
+ if event in 'line':
+ if self.just_changed:
+ self.context = LoggingContext.current_context()
+ self.just_changed = False
+ else:
+ c = LoggingContext.current_context()
+ if c != self.context:
+ logger.info(
+ "Context changed! %s -> %s, %s, %s",
+ self.context, c,
+ frame.f_code.co_filename, frame.f_lineno
+ )
+ self.context = c
+
+ return self
diff --git a/synapse/util/logutils.py b/synapse/util/logutils.py
index d5b1a37eff..3a83828d25 100644
--- a/synapse/util/logutils.py
+++ b/synapse/util/logutils.py
@@ -111,7 +111,7 @@ def time_function(f):
_log_debug_as_f(
f,
"[FUNC END] {%s-%d} %f",
- (func_name, id, end-start,),
+ (func_name, id, end - start,),
)
return r
@@ -168,3 +168,38 @@ def trace_function(f):
wrapped.__name__ = func_name
return wrapped
+
+
+def get_previous_frames():
+ s = inspect.currentframe().f_back.f_back
+ to_return = []
+ while s:
+ if s.f_globals["__name__"].startswith("synapse"):
+ filename, lineno, function, _, _ = inspect.getframeinfo(s)
+ args_string = inspect.formatargvalues(*inspect.getargvalues(s))
+
+ to_return.append("{{ %s:%d %s - Args: %s }}" % (
+ filename, lineno, function, args_string
+ ))
+
+ s = s.f_back
+
+ return ", ". join(to_return)
+
+
+def get_previous_frame(ignore=[]):
+ s = inspect.currentframe().f_back.f_back
+
+ while s:
+ if s.f_globals["__name__"].startswith("synapse"):
+ if not any(s.f_globals["__name__"].startswith(ig) for ig in ignore):
+ filename, lineno, function, _, _ = inspect.getframeinfo(s)
+ args_string = inspect.formatargvalues(*inspect.getargvalues(s))
+
+ return "{{ %s:%d %s - Args: %s }}" % (
+ filename, lineno, function, args_string
+ )
+
+ s = s.f_back
+
+ return None
diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py
new file mode 100644
index 0000000000..c51b641125
--- /dev/null
+++ b/synapse/util/metrics.py
@@ -0,0 +1,97 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from synapse.util.logcontext import LoggingContext
+import synapse.metrics
+
+import logging
+
+
+logger = logging.getLogger(__name__)
+
+
+metrics = synapse.metrics.get_metrics_for(__name__)
+
+block_timer = metrics.register_distribution(
+ "block_timer",
+ labels=["block_name"]
+)
+
+block_ru_utime = metrics.register_distribution(
+ "block_ru_utime", labels=["block_name"]
+)
+
+block_ru_stime = metrics.register_distribution(
+ "block_ru_stime", labels=["block_name"]
+)
+
+block_db_txn_count = metrics.register_distribution(
+ "block_db_txn_count", labels=["block_name"]
+)
+
+block_db_txn_duration = metrics.register_distribution(
+ "block_db_txn_duration", labels=["block_name"]
+)
+
+
+class Measure(object):
+ __slots__ = [
+ "clock", "name", "start_context", "start", "new_context", "ru_utime",
+ "ru_stime", "db_txn_count", "db_txn_duration"
+ ]
+
+ def __init__(self, clock, name):
+ self.clock = clock
+ self.name = name
+ self.start_context = None
+ self.start = None
+
+ def __enter__(self):
+ self.start = self.clock.time_msec()
+ self.start_context = LoggingContext.current_context()
+ if self.start_context:
+ self.ru_utime, self.ru_stime = self.start_context.get_resource_usage()
+ self.db_txn_count = self.start_context.db_txn_count
+ self.db_txn_duration = self.start_context.db_txn_duration
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ if exc_type is not None or not self.start_context:
+ return
+
+ duration = self.clock.time_msec() - self.start
+ block_timer.inc_by(duration, self.name)
+
+ context = LoggingContext.current_context()
+
+ if context != self.start_context:
+ logger.warn(
+ "Context have unexpectedly changed from '%s' to '%s'. (%r)",
+ context, self.start_context, self.name
+ )
+ return
+
+ if not context:
+ logger.warn("Expected context. (%r)", self.name)
+ return
+
+ ru_utime, ru_stime = context.get_resource_usage()
+
+ block_ru_utime.inc_by(ru_utime - self.ru_utime, self.name)
+ block_ru_stime.inc_by(ru_stime - self.ru_stime, self.name)
+ block_db_txn_count.inc_by(context.db_txn_count - self.db_txn_count, self.name)
+ block_db_txn_duration.inc_by(
+ context.db_txn_duration - self.db_txn_duration, self.name
+ )
diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py
index c37d6f12e3..4076eed269 100644
--- a/synapse/util/ratelimitutils.py
+++ b/synapse/util/ratelimitutils.py
@@ -18,6 +18,7 @@ from twisted.internet import defer
from synapse.api.errors import LimitExceededError
from synapse.util.async import sleep
+from synapse.util.logcontext import preserve_fn
import collections
import contextlib
@@ -163,7 +164,7 @@ class _PerHostRatelimiter(object):
"Ratelimit [%s]: sleeping req",
id(request_id),
)
- ret_defer = sleep(self.sleep_msec/1000.0)
+ ret_defer = preserve_fn(sleep)(self.sleep_msec / 1000.0)
self.sleeping_requests.add(request_id)
diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py
index 1a4e439d30..ceb0089268 100644
--- a/tests/api/test_filtering.py
+++ b/tests/api/test_filtering.py
@@ -382,19 +382,20 @@ class FilteringTestCase(unittest.TestCase):
"types": ["m.*"]
}
}
- user = UserID.from_string("@" + user_localpart + ":test")
+
filter_id = yield self.datastore.add_user_filter(
- user_localpart=user_localpart,
+ user_localpart=user_localpart + "2",
user_filter=user_filter_json,
)
event = MockEvent(
+ event_id="$asdasd:localhost",
sender="@foo:bar",
type="custom.avatar.3d.crazy",
)
events = [event]
user_filter = yield self.filtering.get_user_filter(
- user_localpart=user_localpart,
+ user_localpart=user_localpart + "2",
filter_id=filter_id,
)
diff --git a/tests/config/__init__.py b/tests/config/__init__.py
new file mode 100644
index 0000000000..b7df13c9ee
--- /dev/null
+++ b/tests/config/__init__.py
@@ -0,0 +1,14 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/tests/config/test_generate.py b/tests/config/test_generate.py
new file mode 100644
index 0000000000..4329d73974
--- /dev/null
+++ b/tests/config/test_generate.py
@@ -0,0 +1,50 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os.path
+import shutil
+import tempfile
+from synapse.config.homeserver import HomeServerConfig
+from tests import unittest
+
+
+class ConfigGenerationTestCase(unittest.TestCase):
+
+ def setUp(self):
+ self.dir = tempfile.mkdtemp()
+ print self.dir
+ self.file = os.path.join(self.dir, "homeserver.yaml")
+
+ def tearDown(self):
+ shutil.rmtree(self.dir)
+
+ def test_generate_config_generates_files(self):
+ HomeServerConfig.load_config("", [
+ "--generate-config",
+ "-c", self.file,
+ "--report-stats=yes",
+ "-H", "lemurs.win"
+ ])
+
+ self.assertSetEqual(
+ set([
+ "homeserver.yaml",
+ "lemurs.win.log.config",
+ "lemurs.win.signing.key",
+ "lemurs.win.tls.crt",
+ "lemurs.win.tls.dh",
+ "lemurs.win.tls.key",
+ ]),
+ set(os.listdir(self.dir))
+ )
diff --git a/tests/config/test_load.py b/tests/config/test_load.py
new file mode 100644
index 0000000000..fbbbf93fef
--- /dev/null
+++ b/tests/config/test_load.py
@@ -0,0 +1,78 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os.path
+import shutil
+import tempfile
+import yaml
+from synapse.config.homeserver import HomeServerConfig
+from tests import unittest
+
+
+class ConfigLoadingTestCase(unittest.TestCase):
+
+ def setUp(self):
+ self.dir = tempfile.mkdtemp()
+ print self.dir
+ self.file = os.path.join(self.dir, "homeserver.yaml")
+
+ def tearDown(self):
+ shutil.rmtree(self.dir)
+
+ def test_load_fails_if_server_name_missing(self):
+ self.generate_config_and_remove_lines_containing("server_name")
+ with self.assertRaises(Exception):
+ HomeServerConfig.load_config("", ["-c", self.file])
+
+ def test_generates_and_loads_macaroon_secret_key(self):
+ self.generate_config()
+
+ with open(self.file,
+ "r") as f:
+ raw = yaml.load(f)
+ self.assertIn("macaroon_secret_key", raw)
+
+ config = HomeServerConfig.load_config("", ["-c", self.file])
+ self.assertTrue(
+ hasattr(config, "macaroon_secret_key"),
+ "Want config to have attr macaroon_secret_key"
+ )
+ if len(config.macaroon_secret_key) < 5:
+ self.fail(
+ "Want macaroon secret key to be string of at least length 5,"
+ "was: %r" % (config.macaroon_secret_key,)
+ )
+
+ def test_load_succeeds_if_macaroon_secret_key_missing(self):
+ self.generate_config_and_remove_lines_containing("macaroon")
+ config1 = HomeServerConfig.load_config("", ["-c", self.file])
+ config2 = HomeServerConfig.load_config("", ["-c", self.file])
+ self.assertEqual(config1.macaroon_secret_key, config2.macaroon_secret_key)
+
+ def generate_config(self):
+ HomeServerConfig.load_config("", [
+ "--generate-config",
+ "-c", self.file,
+ "--report-stats=yes",
+ "-H", "lemurs.win"
+ ])
+
+ def generate_config_and_remove_lines_containing(self, needle):
+ self.generate_config()
+
+ with open(self.file, "r") as f:
+ contents = f.readlines()
+ contents = [l for l in contents if needle not in l]
+ with open(self.file, "w") as f:
+ f.write("".join(contents))
diff --git a/tests/federation/__init__.py b/tests/federation/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
--- a/tests/federation/__init__.py
+++ /dev/null
diff --git a/tests/federation/test_federation.py b/tests/federation/test_federation.py
deleted file mode 100644
index f2c2ee4127..0000000000
--- a/tests/federation/test_federation.py
+++ /dev/null
@@ -1,303 +0,0 @@
-# Copyright 2014-2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# trial imports
-from twisted.internet import defer
-from tests import unittest
-
-# python imports
-from mock import Mock, ANY
-
-from ..utils import MockHttpResource, MockClock, setup_test_homeserver
-
-from synapse.federation import initialize_http_replication
-from synapse.events import FrozenEvent
-
-
-def make_pdu(prev_pdus=[], **kwargs):
- """Provide some default fields for making a PduTuple."""
- pdu_fields = {
- "state_key": None,
- "prev_events": prev_pdus,
- }
- pdu_fields.update(kwargs)
-
- return FrozenEvent(pdu_fields)
-
-
-class FederationTestCase(unittest.TestCase):
- @defer.inlineCallbacks
- def setUp(self):
- self.mock_resource = MockHttpResource()
- self.mock_http_client = Mock(spec=[
- "get_json",
- "put_json",
- ])
- self.mock_persistence = Mock(spec=[
- "prep_send_transaction",
- "delivered_txn",
- "get_received_txn_response",
- "set_received_txn_response",
- "get_destination_retry_timings",
- "get_auth_chain",
- ])
- self.mock_persistence.get_received_txn_response.return_value = (
- defer.succeed(None)
- )
-
- retry_timings_res = {
- "destination": "",
- "retry_last_ts": 0,
- "retry_interval": 0,
- }
- self.mock_persistence.get_destination_retry_timings.return_value = (
- defer.succeed(retry_timings_res)
- )
- self.mock_persistence.get_auth_chain.return_value = []
- self.clock = MockClock()
- hs = yield setup_test_homeserver(
- resource_for_federation=self.mock_resource,
- http_client=self.mock_http_client,
- datastore=self.mock_persistence,
- clock=self.clock,
- keyring=Mock(),
- )
- self.federation = initialize_http_replication(hs)
- self.distributor = hs.get_distributor()
-
- @defer.inlineCallbacks
- def test_get_state(self):
- mock_handler = Mock(spec=[
- "get_state_for_pdu",
- ])
-
- self.federation.set_handler(mock_handler)
-
- mock_handler.get_state_for_pdu.return_value = defer.succeed([])
-
- # Empty context initially
- (code, response) = yield self.mock_resource.trigger(
- "GET",
- "/_matrix/federation/v1/state/my-context/",
- None
- )
- self.assertEquals(200, code)
- self.assertFalse(response["pdus"])
-
- # Now lets give the context some state
- mock_handler.get_state_for_pdu.return_value = (
- defer.succeed([
- make_pdu(
- event_id="the-pdu-id",
- origin="red",
- user_id="@a:red",
- room_id="my-context",
- type="m.topic",
- origin_server_ts=123456789000,
- depth=1,
- content={"topic": "The topic"},
- state_key="",
- power_level=1000,
- prev_state="last-pdu-id",
- ),
- ])
- )
-
- (code, response) = yield self.mock_resource.trigger(
- "GET",
- "/_matrix/federation/v1/state/my-context/",
- None
- )
- self.assertEquals(200, code)
- self.assertEquals(1, len(response["pdus"]))
-
- @defer.inlineCallbacks
- def test_get_pdu(self):
- mock_handler = Mock(spec=[
- "get_persisted_pdu",
- ])
-
- self.federation.set_handler(mock_handler)
-
- mock_handler.get_persisted_pdu.return_value = (
- defer.succeed(None)
- )
-
- (code, response) = yield self.mock_resource.trigger(
- "GET",
- "/_matrix/federation/v1/event/abc123def456/",
- None
- )
- self.assertEquals(404, code)
-
- # Now insert such a PDU
- mock_handler.get_persisted_pdu.return_value = (
- defer.succeed(
- make_pdu(
- event_id="abc123def456",
- origin="red",
- user_id="@a:red",
- room_id="my-context",
- type="m.text",
- origin_server_ts=123456789001,
- depth=1,
- content={"text": "Here is the message"},
- )
- )
- )
-
- (code, response) = yield self.mock_resource.trigger(
- "GET",
- "/_matrix/federation/v1/event/abc123def456/",
- None
- )
- self.assertEquals(200, code)
- self.assertEquals(1, len(response["pdus"]))
- self.assertEquals("m.text", response["pdus"][0]["type"])
-
- @defer.inlineCallbacks
- def test_send_pdu(self):
- self.mock_http_client.put_json.return_value = defer.succeed(
- (200, "OK")
- )
-
- pdu = make_pdu(
- event_id="abc123def456",
- origin="red",
- user_id="@a:red",
- room_id="my-context",
- type="m.text",
- origin_server_ts=123456789001,
- depth=1,
- content={"text": "Here is the message"},
- )
-
- yield self.federation.send_pdu(pdu, ["remote"])
-
- self.mock_http_client.put_json.assert_called_with(
- "remote",
- path="/_matrix/federation/v1/send/1000000/",
- data={
- "origin_server_ts": 1000000,
- "origin": "test",
- "pdus": [
- pdu.get_pdu_json(),
- ],
- 'pdu_failures': [],
- },
- json_data_callback=ANY,
- long_retries=True,
- )
-
- @defer.inlineCallbacks
- def test_send_edu(self):
- self.mock_http_client.put_json.return_value = defer.succeed(
- (200, "OK")
- )
-
- yield self.federation.send_edu(
- destination="remote",
- edu_type="m.test",
- content={"testing": "content here"},
- )
-
- # MockClock ensures we can guess these timestamps
- self.mock_http_client.put_json.assert_called_with(
- "remote",
- path="/_matrix/federation/v1/send/1000000/",
- data={
- "origin": "test",
- "origin_server_ts": 1000000,
- "pdus": [],
- "edus": [
- {
- "edu_type": "m.test",
- "content": {"testing": "content here"},
- }
- ],
- 'pdu_failures': [],
- },
- json_data_callback=ANY,
- long_retries=True,
- )
-
- @defer.inlineCallbacks
- def test_recv_edu(self):
- recv_observer = Mock()
- recv_observer.return_value = defer.succeed(())
-
- self.federation.register_edu_handler("m.test", recv_observer)
-
- yield self.mock_resource.trigger(
- "PUT",
- "/_matrix/federation/v1/send/1001000/",
- """{
- "origin": "remote",
- "origin_server_ts": 1001000,
- "pdus": [],
- "edus": [
- {
- "origin": "remote",
- "destination": "test",
- "edu_type": "m.test",
- "content": {"testing": "reply here"}
- }
- ]
- }"""
- )
-
- recv_observer.assert_called_with(
- "remote", {"testing": "reply here"}
- )
-
- @defer.inlineCallbacks
- def test_send_query(self):
- self.mock_http_client.get_json.return_value = defer.succeed(
- {"your": "response"}
- )
-
- response = yield self.federation.make_query(
- destination="remote",
- query_type="a-question",
- args={"one": "1", "two": "2"},
- )
-
- self.assertEquals({"your": "response"}, response)
-
- self.mock_http_client.get_json.assert_called_with(
- destination="remote",
- path="/_matrix/federation/v1/query/a-question",
- args={"one": "1", "two": "2"},
- retry_on_dns_fail=True,
- )
-
- @defer.inlineCallbacks
- def test_recv_query(self):
- recv_handler = Mock()
- recv_handler.return_value = defer.succeed({"another": "response"})
-
- self.federation.register_query_handler("a-question", recv_handler)
-
- code, response = yield self.mock_resource.trigger(
- "GET",
- "/_matrix/federation/v1/query/a-question?three=3&four=4",
- None
- )
-
- self.assertEquals(200, code)
- self.assertEquals({"another": "response"}, response)
-
- recv_handler.assert_called_with(
- {"three": "3", "four": "4"}
- )
diff --git a/tests/rest/client/v1/test_events.py b/tests/rest/client/v1/test_events.py
index b260e269ac..e9698bfdc9 100644
--- a/tests/rest/client/v1/test_events.py
+++ b/tests/rest/client/v1/test_events.py
@@ -122,7 +122,7 @@ class EventStreamPermissionsTestCase(RestTestCase):
self.ratelimiter = hs.get_ratelimiter()
self.ratelimiter.send_message.return_value = (True, 0)
hs.config.enable_registration_captcha = False
- hs.config.disable_registration = False
+ hs.config.enable_registration = True
hs.get_handlers().federation_handler = Mock()
diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py
index 90b911f879..8d7cfd79ab 100644
--- a/tests/rest/client/v1/test_presence.py
+++ b/tests/rest/client/v1/test_presence.py
@@ -280,6 +280,15 @@ class PresenceEventStreamTestCase(unittest.TestCase):
}
EventSources.SOURCE_TYPES["presence"] = PresenceEventSource
+ clock = Mock(spec=[
+ "call_later",
+ "cancel_call_later",
+ "time_msec",
+ "looping_call",
+ ])
+
+ clock.time_msec.return_value = 1000000
+
hs = yield setup_test_homeserver(
http_client=None,
resource_for_client=self.mock_resource,
@@ -289,16 +298,9 @@ class PresenceEventStreamTestCase(unittest.TestCase):
"get_presence_list",
"get_rooms_for_user",
]),
- clock=Mock(spec=[
- "call_later",
- "cancel_call_later",
- "time_msec",
- "looping_call",
- ]),
+ clock=clock,
)
- hs.get_clock().time_msec.return_value = 1000000
-
def _get_user_by_req(req=None, allow_guest=False):
return Requester(UserID.from_string(myid), "", False)
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index cd03106e88..ad5dd3bd6e 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -1045,8 +1045,13 @@ class RoomMessageListTestCase(RestTestCase):
self.assertTrue("end" in response)
@defer.inlineCallbacks
- def test_stream_token_is_rejected(self):
+ def test_stream_token_is_accepted_for_fwd_pagianation(self):
+ token = "s0_0_0_0_0"
(code, response) = yield self.mock_resource.trigger_get(
- "/rooms/%s/messages?access_token=x&from=s0_0_0_0" %
- self.room_id)
- self.assertEquals(400, code)
+ "/rooms/%s/messages?access_token=x&from=%s" %
+ (self.room_id, token))
+ self.assertEquals(200, code)
+ self.assertTrue("start" in response)
+ self.assertEquals(token, response['start'])
+ self.assertTrue("chunk" in response)
+ self.assertTrue("end" in response)
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index f9a2b22485..df0841b0b1 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -41,7 +41,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
self.hs.hostname = "superbig~testing~thing.com"
self.hs.get_auth = Mock(return_value=self.auth)
self.hs.get_handlers = Mock(return_value=self.handlers)
- self.hs.config.disable_registration = False
+ self.hs.config.enable_registration = True
# init the thing we're testing
self.servlet = RegisterRestServlet(self.hs)
@@ -120,7 +120,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
}))
def test_POST_disabled_registration(self):
- self.hs.config.disable_registration = True
+ self.hs.config.enable_registration = False
self.request_data = json.dumps({
"username": "kermit",
"password": "monkey"
diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index 5abecdf6e0..ed8af10d87 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -439,7 +439,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
f2 = self._write_config(suffix="2")
config = Mock(app_service_config_files=[f1, f2])
- hs = yield setup_test_homeserver(config=config)
+ hs = yield setup_test_homeserver(config=config, datastore=Mock())
ApplicationServiceStore(hs)
@@ -449,7 +449,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
f2 = self._write_config(id="id", suffix="2")
config = Mock(app_service_config_files=[f1, f2])
- hs = yield setup_test_homeserver(config=config)
+ hs = yield setup_test_homeserver(config=config, datastore=Mock())
with self.assertRaises(ConfigError) as cm:
ApplicationServiceStore(hs)
@@ -465,7 +465,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
f2 = self._write_config(as_token="as_token", suffix="2")
config = Mock(app_service_config_files=[f1, f2])
- hs = yield setup_test_homeserver(config=config)
+ hs = yield setup_test_homeserver(config=config, datastore=Mock())
with self.assertRaises(ConfigError) as cm:
ApplicationServiceStore(hs)
diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
index a35efcc71e..7b3b4c13bc 100644
--- a/tests/storage/test_registration.py
+++ b/tests/storage/test_registration.py
@@ -18,7 +18,6 @@ from tests import unittest
from twisted.internet import defer
from synapse.api.errors import StoreError
-from synapse.storage.registration import RegistrationStore
from synapse.util import stringutils
from tests.utils import setup_test_homeserver
@@ -31,7 +30,7 @@ class RegistrationStoreTestCase(unittest.TestCase):
hs = yield setup_test_homeserver()
self.db_pool = hs.get_db_pool()
- self.store = RegistrationStore(hs)
+ self.store = hs.get_datastore()
self.user_id = "@my-user:test"
self.tokens = ["AbCdEfGhIjKlMnOpQrStUvWxYz",
diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py
index 7fdbfc60f1..0baaf3df21 100644
--- a/tests/storage/test_room.py
+++ b/tests/storage/test_room.py
@@ -51,32 +51,6 @@ class RoomStoreTestCase(unittest.TestCase):
(yield self.store.get_room(self.room.to_string()))
)
- @defer.inlineCallbacks
- def test_get_rooms(self):
- # get_rooms does an INNER JOIN on the room_aliases table :(
-
- rooms = yield self.store.get_rooms(is_public=True)
- # Should be empty before we add the alias
- self.assertEquals([], rooms)
-
- yield self.store.create_room_alias_association(
- room_alias=self.alias,
- room_id=self.room.to_string(),
- servers=["test"]
- )
-
- rooms = yield self.store.get_rooms(is_public=True)
-
- self.assertEquals(1, len(rooms))
- self.assertEquals({
- "name": None,
- "room_id": self.room.to_string(),
- "topic": None,
- "aliases": [self.alias.to_string()],
- "world_readable": False,
- "guest_can_join": False,
- }, rooms[0])
-
class RoomEventsStoreTestCase(unittest.TestCase):
diff --git a/tests/test_types.py b/tests/test_types.py
index b9534329e6..24d61dbe54 100644
--- a/tests/test_types.py
+++ b/tests/test_types.py
@@ -16,10 +16,10 @@
from tests import unittest
from synapse.api.errors import SynapseError
-from synapse.server import BaseHomeServer
+from synapse.server import HomeServer
from synapse.types import UserID, RoomAlias
-mock_homeserver = BaseHomeServer(hostname="my.domain")
+mock_homeserver = HomeServer(hostname="my.domain")
class UserIDTestCase(unittest.TestCase):
@@ -34,7 +34,6 @@ class UserIDTestCase(unittest.TestCase):
with self.assertRaises(SynapseError):
UserID.from_string("")
-
def test_build(self):
user = UserID("5678efgh", "my.domain")
diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py
index 2cd3d26454..bab366fb7f 100644
--- a/tests/util/test_lrucache.py
+++ b/tests/util/test_lrucache.py
@@ -19,6 +19,7 @@ from .. import unittest
from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.treecache import TreeCache
+
class LruCacheTestCase(unittest.TestCase):
def test_get_set(self):
@@ -72,3 +73,9 @@ class LruCacheTestCase(unittest.TestCase):
self.assertEquals(cache.get(("vehicles", "car")), "vroom")
self.assertEquals(cache.get(("vehicles", "train")), "chuff")
# Man from del_multi say "Yes".
+
+ def test_clear(self):
+ cache = LruCache(1)
+ cache["key"] = 1
+ cache.clear()
+ self.assertEquals(len(cache), 0)
diff --git a/tests/util/test_treecache.py b/tests/util/test_treecache.py
index 9946ceb3f1..1efbeb6b33 100644
--- a/tests/util/test_treecache.py
+++ b/tests/util/test_treecache.py
@@ -25,6 +25,7 @@ class TreeCacheTestCase(unittest.TestCase):
cache[("b",)] = "B"
self.assertEquals(cache.get(("a",)), "A")
self.assertEquals(cache.get(("b",)), "B")
+ self.assertEquals(len(cache), 2)
def test_pop_onelevel(self):
cache = TreeCache()
@@ -33,6 +34,7 @@ class TreeCacheTestCase(unittest.TestCase):
self.assertEquals(cache.pop(("a",)), "A")
self.assertEquals(cache.pop(("a",)), None)
self.assertEquals(cache.get(("b",)), "B")
+ self.assertEquals(len(cache), 1)
def test_get_set_twolevel(self):
cache = TreeCache()
@@ -42,6 +44,7 @@ class TreeCacheTestCase(unittest.TestCase):
self.assertEquals(cache.get(("a", "a")), "AA")
self.assertEquals(cache.get(("a", "b")), "AB")
self.assertEquals(cache.get(("b", "a")), "BA")
+ self.assertEquals(len(cache), 3)
def test_pop_twolevel(self):
cache = TreeCache()
@@ -53,6 +56,7 @@ class TreeCacheTestCase(unittest.TestCase):
self.assertEquals(cache.get(("a", "b")), "AB")
self.assertEquals(cache.pop(("b", "a")), "BA")
self.assertEquals(cache.pop(("b", "a")), None)
+ self.assertEquals(len(cache), 1)
def test_pop_mixedlevel(self):
cache = TreeCache()
@@ -64,3 +68,11 @@ class TreeCacheTestCase(unittest.TestCase):
self.assertEquals(cache.get(("a", "a")), None)
self.assertEquals(cache.get(("a", "b")), None)
self.assertEquals(cache.get(("b", "a")), "BA")
+ self.assertEquals(len(cache), 1)
+
+ def test_clear(self):
+ cache = TreeCache()
+ cache[("a",)] = "A"
+ cache[("b",)] = "B"
+ cache.clear()
+ self.assertEquals(len(cache), 0)
diff --git a/tests/utils.py b/tests/utils.py
index 358b5b72b7..3b1eb50d8d 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -19,6 +19,8 @@ from synapse.api.constants import EventTypes
from synapse.storage.prepare_database import prepare_database
from synapse.storage.engines import create_engine
from synapse.server import HomeServer
+from synapse.federation.transport import server
+from synapse.util.ratelimitutils import FederationRateLimiter
from synapse.util.logcontext import LoggingContext
@@ -44,9 +46,10 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
config = Mock()
config.signing_key = [MockKey()]
config.event_cache_size = 1
- config.disable_registration = False
+ config.enable_registration = True
config.macaroon_secret_key = "not even a little secret"
config.server_name = "server.under.test"
+ config.trusted_third_party_id_servers = []
if "clock" not in kargs:
kargs["clock"] = MockClock()
@@ -58,8 +61,10 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
name, db_pool=db_pool, config=config,
version_string="Synapse/tests",
database_engine=create_engine("sqlite3"),
+ get_db_conn=db_pool.get_db_conn,
**kargs
)
+ hs.setup()
else:
hs = HomeServer(
name, db_pool=None, datastore=datastore, config=config,
@@ -80,6 +85,22 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
hs.build_handlers = swap_out_hash_for_testing(hs.build_handlers)
+ fed = kargs.get("resource_for_federation", None)
+ if fed:
+ server.register_servlets(
+ hs,
+ resource=fed,
+ authenticator=server.Authenticator(hs),
+ ratelimiter=FederationRateLimiter(
+ hs.get_clock(),
+ window_size=hs.config.federation_rc_window_size,
+ sleep_limit=hs.config.federation_rc_sleep_limit,
+ sleep_msec=hs.config.federation_rc_sleep_delay,
+ reject_limit=hs.config.federation_rc_reject_limit,
+ concurrent_requests=hs.config.federation_rc_concurrent
+ ),
+ )
+
defer.returnValue(hs)
@@ -262,6 +283,12 @@ class SQLiteMemoryDbPool(ConnectionPool, object):
lambda conn: prepare_database(conn, engine)
)
+ def get_db_conn(self):
+ conn = self.connect()
+ engine = create_engine("sqlite3")
+ prepare_database(conn, engine)
+ return conn
+
class MemoryDataStore(object):
diff --git a/tox.ini b/tox.ini
index bd313a4f36..6f01599af7 100644
--- a/tox.ini
+++ b/tox.ini
@@ -11,7 +11,7 @@ deps =
setenv =
PYTHONDONTWRITEBYTECODE = no_byte_code
commands =
- /bin/bash -c "coverage run {env:COVERAGE_OPTS:} --source={toxinidir}/synapse \
+ /bin/bash -c "find {toxinidir} -name '*.pyc' -delete ; coverage run {env:COVERAGE_OPTS:} --source={toxinidir}/synapse \
{envbindir}/trial {env:TRIAL_FLAGS:} {posargs:tests} {env:TOXSUFFIX:}"
{env:DUMP_COVERAGE_COMMAND:coverage report -m}
|