diff --git a/CHANGES.rst b/CHANGES.rst
index f89542a2bb..da31af9606 100644
--- a/CHANGES.rst
+++ b/CHANGES.rst
@@ -1,3 +1,12 @@
+Changes in synapse v0.8.1 (2015-03-18)
+======================================
+
+* Disable registration by default. New users can be added using the command
+ ``register_new_matrix_user`` or by enabling registration in the config.
+* Add metrics to synapse. To enable metrics use config options
+ ``enable_metrics`` and ``metrics_port``.
+* Fix bug where banning only kicked the user.
+
Changes in synapse v0.8.0 (2015-03-06)
======================================
diff --git a/README.rst b/README.rst
index e4664ea768..874753762d 100644
--- a/README.rst
+++ b/README.rst
@@ -128,6 +128,17 @@ To set up your homeserver, run (in your virtualenv, as before)::
Substituting your host and domain name as appropriate.
+By default, registration of new users is disabled. You can either enable
+registration in the config (it is then recommended to also set up CAPTCHA), or
+you can use the command line to register new users::
+
+ $ source ~/.synapse/bin/activate
+ $ register_new_matrix_user -c homeserver.yaml https://localhost:8448
+ New user localpart: erikj
+ Password:
+ Confirm password:
+ Success!
+
For reliable VoIP calls to be routed via this homeserver, you MUST configure
a TURN server. See docs/turn-howto.rst for details.
diff --git a/register_new_matrix_user b/register_new_matrix_user
new file mode 100755
index 0000000000..daddadc302
--- /dev/null
+++ b/register_new_matrix_user
@@ -0,0 +1,149 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Copyright 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+import getpass
+import hashlib
+import hmac
+import json
+import sys
+import urllib2
+import yaml
+
+
+def request_registration(user, password, server_location, shared_secret):
+ mac = hmac.new(
+ key=shared_secret,
+ msg=user,
+ digestmod=hashlib.sha1,
+ ).hexdigest()
+
+ data = {
+ "user": user,
+ "password": password,
+ "mac": mac,
+ "type": "org.matrix.login.shared_secret",
+ }
+
+ server_location = server_location.rstrip("/")
+
+ print "Sending registration request..."
+
+ req = urllib2.Request(
+ "%s/_matrix/client/api/v1/register" % (server_location,),
+ data=json.dumps(data),
+ headers={'Content-Type': 'application/json'}
+ )
+ try:
+ f = urllib2.urlopen(req)
+ f.read()
+ f.close()
+ print "Success."
+ except urllib2.HTTPError as e:
+ print "ERROR! Received %d %s" % (e.code, e.reason,)
+ if 400 <= e.code < 500:
+ if e.info().type == "application/json":
+ resp = json.load(e)
+ if "error" in resp:
+ print resp["error"]
+ sys.exit(1)
+
+
+def register_new_user(user, password, server_location, shared_secret):
+ if not user:
+ try:
+ default_user = getpass.getuser()
+ except:
+ default_user = None
+
+ if default_user:
+ user = raw_input("New user localpart [%s]: " % (default_user,))
+ if not user:
+ user = default_user
+ else:
+ user = raw_input("New user localpart: ")
+
+ if not user:
+ print "Invalid user name"
+ sys.exit(1)
+
+ if not password:
+ password = getpass.getpass("Password: ")
+
+ if not password:
+ print "Password cannot be blank."
+ sys.exit(1)
+
+ confirm_password = getpass.getpass("Confirm password: ")
+
+ if password != confirm_password:
+ print "Passwords do not match"
+ sys.exit(1)
+
+ request_registration(user, password, server_location, shared_secret)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(
+ description="Used to register new users with a given home server when"
+ " registration has been disabled. The home server must be"
+ " configured with the 'registration_shared_secret' option"
+ " set.",
+ )
+ parser.add_argument(
+ "-u", "--user",
+ default=None,
+ help="Local part of the new user. Will prompt if omitted.",
+ )
+ parser.add_argument(
+ "-p", "--password",
+ default=None,
+ help="New password for user. Will prompt if omitted.",
+ )
+
+ group = parser.add_mutually_exclusive_group(required=True)
+ group.add_argument(
+ "-c", "--config",
+ type=argparse.FileType('r'),
+ help="Path to server config file. Used to read in shared secret.",
+ )
+
+ group.add_argument(
+ "-k", "--shared-secret",
+ help="Shared secret as defined in server config file.",
+ )
+
+ parser.add_argument(
+ "server_url",
+ default="https://localhost:8448",
+ nargs='?',
+ help="URL to use to talk to the home server. Defaults to "
+ " 'https://localhost:8448'.",
+ )
+
+ args = parser.parse_args()
+
+ if "config" in args and args.config:
+ config = yaml.safe_load(args.config)
+ secret = config.get("registration_shared_secret", None)
+ if not secret:
+ print "No 'registration_shared_secret' defined in config."
+ sys.exit(1)
+ else:
+ secret = args.shared_secret
+
+ register_new_user(args.user, args.password, args.server_url, secret)
diff --git a/setup.py b/setup.py
index 2d812fa389..ab24159be7 100755
--- a/setup.py
+++ b/setup.py
@@ -45,7 +45,7 @@ setup(
version=version,
packages=find_packages(exclude=["tests", "tests.*"]),
description="Reference Synapse Home Server",
- install_requires=dependencies["REQUIREMENTS"].keys(),
+ install_requires=dependencies['requirements'](include_conditional=True).keys(),
setup_requires=[
"Twisted==14.0.2", # Here to override setuptools_trial's dependency on Twisted>=2.4.0
"setuptools_trial",
@@ -55,5 +55,5 @@ setup(
include_package_data=True,
zip_safe=False,
long_description=long_description,
- scripts=["synctl"],
+ scripts=["synctl", "register_new_matrix_user"],
)
diff --git a/synapse/__init__.py b/synapse/__init__.py
index f46a6df1fb..749a60329c 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -16,4 +16,4 @@
""" This is a reference implementation of a Matrix home server.
"""
-__version__ = "0.8.0"
+__version__ = "0.8.1-r2"
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index b176db8ce1..64f605b962 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -28,6 +28,12 @@ import logging
logger = logging.getLogger(__name__)
+AuthEventTypes = (
+ EventTypes.Create, EventTypes.Member, EventTypes.PowerLevels,
+ EventTypes.JoinRules,
+)
+
+
class Auth(object):
def __init__(self, hs):
@@ -166,6 +172,7 @@ class Auth(object):
target = auth_events.get(key)
target_in_room = target and target.membership == Membership.JOIN
+ target_banned = target and target.membership == Membership.BAN
key = (EventTypes.JoinRules, "", )
join_rule_event = auth_events.get(key)
@@ -194,6 +201,7 @@ class Auth(object):
{
"caller_in_room": caller_in_room,
"caller_invited": caller_invited,
+ "target_banned": target_banned,
"target_in_room": target_in_room,
"membership": membership,
"join_rule": join_rule,
@@ -202,6 +210,11 @@ class Auth(object):
}
)
+ if ban_level:
+ ban_level = int(ban_level)
+ else:
+ ban_level = 50 # FIXME (erikj): What should we do here?
+
if Membership.INVITE == membership:
# TODO (erikj): We should probably handle this more intelligently
# PRIVATE join rules.
@@ -212,6 +225,10 @@ class Auth(object):
403,
"%s not in room %s." % (event.user_id, event.room_id,)
)
+ elif target_banned:
+ raise AuthError(
+ 403, "%s is banned from the room" % (target_user_id,)
+ )
elif target_in_room: # the target is already in the room.
raise AuthError(403, "%s is already in the room." %
target_user_id)
@@ -221,6 +238,8 @@ class Auth(object):
# joined: It's a NOOP
if event.user_id != target_user_id:
raise AuthError(403, "Cannot force another user to join.")
+ elif target_banned:
+ raise AuthError(403, "You are banned from this room")
elif join_rule == JoinRules.PUBLIC:
pass
elif join_rule == JoinRules.INVITE:
@@ -238,6 +257,10 @@ class Auth(object):
403,
"%s not in room %s." % (target_user_id, event.room_id,)
)
+ elif target_banned and user_level < ban_level:
+ raise AuthError(
+ 403, "You cannot unban user &s." % (target_user_id,)
+ )
elif target_user_id != event.user_id:
if kick_level:
kick_level = int(kick_level)
@@ -249,11 +272,6 @@ class Auth(object):
403, "You cannot kick user %s." % target_user_id
)
elif Membership.BAN == membership:
- if ban_level:
- ban_level = int(ban_level)
- else:
- ban_level = 50 # FIXME (erikj): What should we do here?
-
if user_level < ban_level:
raise AuthError(403, "You don't have permission to ban")
else:
@@ -370,7 +388,7 @@ class Auth(object):
AuthError if no user by that token exists or the token is invalid.
"""
try:
- ret = yield self.store.get_user_by_token(token=token)
+ ret = yield self.store.get_user_by_token(token)
if not ret:
raise StoreError(400, "Unknown token")
user_info = {
@@ -412,12 +430,6 @@ class Auth(object):
builder.auth_events = auth_events_entries
- context.auth_events = {
- k: v
- for k, v in context.current_state.items()
- if v.event_id in auth_ids
- }
-
def compute_auth_events(self, event, current_state):
if event.type == EventTypes.Create:
return []
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index 420f963d91..b16bf4247d 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -60,6 +60,7 @@ class LoginType(object):
EMAIL_IDENTITY = u"m.login.email.identity"
RECAPTCHA = u"m.login.recaptcha"
APPLICATION_SERVICE = u"m.login.application_service"
+ SHARED_SECRET = u"org.matrix.login.shared_secret"
class EventTypes(object):
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 15c454af76..500cae05fb 100755
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -60,7 +60,6 @@ import re
import resource
import subprocess
import sqlite3
-import syweb
logger = logging.getLogger(__name__)
@@ -83,6 +82,7 @@ class SynapseHomeServer(HomeServer):
return AppServiceRestResource(self)
def build_resource_for_web_client(self):
+ import syweb
syweb_path = os.path.dirname(syweb.__file__)
webclient_path = os.path.join(syweb_path, "webclient")
return File(webclient_path) # TODO configurable?
@@ -130,7 +130,7 @@ class SynapseHomeServer(HomeServer):
True.
"""
config = self.get_config()
- web_client = config.webclient
+ web_client = config.web_client
# list containing (path_str, Resource) e.g:
# [ ("/aaa/bbb/cc", Resource1), ("/aaa/dummy", Resource2) ]
@@ -343,7 +343,8 @@ def setup(config_options):
config.setup_logging()
- check_requirements()
+ # check any extra requirements we have now we have a config
+ check_requirements(config)
version_string = get_version_string()
@@ -450,6 +451,7 @@ def run(hs):
def main():
with LoggingContext("main"):
+ # check base requirements
check_requirements()
hs = setup(sys.argv[1:])
run(hs)
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index cca8ab5676..4401e774d1 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -15,19 +15,46 @@
from ._base import Config
+from synapse.util.stringutils import random_string_with_symbols
+
+import distutils.util
+
class RegistrationConfig(Config):
def __init__(self, args):
super(RegistrationConfig, self).__init__(args)
- self.disable_registration = args.disable_registration
+
+ # `args.disable_registration` may either be a bool or a string depending
+ # on if the option was given a value (e.g. --disable-registration=false
+ # would set `args.disable_registration` to "false" not False.)
+ self.disable_registration = bool(
+ distutils.util.strtobool(str(args.disable_registration))
+ )
+ self.registration_shared_secret = args.registration_shared_secret
@classmethod
def add_arguments(cls, parser):
super(RegistrationConfig, cls).add_arguments(parser)
reg_group = parser.add_argument_group("registration")
+
reg_group.add_argument(
"--disable-registration",
- action='store_true',
- help="Disable registration of new users."
+ const=True,
+ default=True,
+ nargs='?',
+ help="Disable registration of new users.",
)
+ reg_group.add_argument(
+ "--registration-shared-secret", type=str,
+ help="If set, allows registration by anyone who also has the shared"
+ " secret, even if registration is otherwise disabled.",
+ )
+
+ @classmethod
+ def generate_config(cls, args, config_dir_path):
+ if args.disable_registration is None:
+ args.disable_registration = True
+
+ if args.registration_shared_secret is None:
+ args.registration_shared_secret = random_string_with_symbols(50)
diff --git a/synapse/config/server.py b/synapse/config/server.py
index b042d4eed9..58a828cc4c 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -28,7 +28,7 @@ class ServerConfig(Config):
self.unsecure_port = args.unsecure_port
self.daemonize = args.daemonize
self.pid_file = self.abspath(args.pid_file)
- self.webclient = True
+ self.web_client = args.web_client
self.manhole = args.manhole
self.soft_file_limit = args.soft_file_limit
@@ -68,6 +68,8 @@ class ServerConfig(Config):
server_group.add_argument('--pid-file', default="homeserver.pid",
help="When running as a daemon, the file to"
" store the pid in")
+ server_group.add_argument('--web_client', default=True, type=bool,
+ help="Whether or not to serve a web client")
server_group.add_argument("--manhole", metavar="PORT", dest="manhole",
type=int,
help="Turn on the twisted telnet manhole"
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index 7e98bdef28..4ecadf0879 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -16,8 +16,7 @@
class EventContext(object):
- def __init__(self, current_state=None, auth_events=None):
+ def __init__(self, current_state=None):
self.current_state = current_state
- self.auth_events = auth_events
self.state_group = None
self.rejected = False
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 7838a81362..2bfe0f3c9b 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -361,4 +361,5 @@ SERVLET_CLASSES = (
FederationInviteServlet,
FederationQueryAuthServlet,
FederationGetMissingEventsServlet,
+ FederationEventAuthServlet,
)
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index 1773fa20aa..48816a242d 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -90,8 +90,8 @@ class BaseHandler(object):
event = builder.build()
logger.debug(
- "Created event %s with auth_events: %s, current state: %s",
- event.event_id, context.auth_events, context.current_state,
+ "Created event %s with current state: %s",
+ event.event_id, context.current_state,
)
defer.returnValue(
@@ -106,7 +106,7 @@ class BaseHandler(object):
# We now need to go and hit out to wherever we need to hit out to.
if not suppress_auth:
- self.auth.check(event, auth_events=context.auth_events)
+ self.auth.check(event, auth_events=context.current_state)
yield self.store.persist_event(event, context=context)
@@ -142,7 +142,16 @@ class BaseHandler(object):
"Failed to get destination from event %s", s.event_id
)
- yield self.notifier.on_new_room_event(event, extra_users=extra_users)
+ # Don't block waiting on waking up all the listeners.
+ d = self.notifier.on_new_room_event(event, extra_users=extra_users)
+
+ def log_failure(f):
+ logger.warn(
+ "Failed to notify about %s: %s",
+ event.event_id, f.value
+ )
+
+ d.addErrback(log_failure)
yield federation_handler.handle_new_event(
event, destinations=destinations,
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index ae4e9b316d..15ba417e06 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -290,6 +290,8 @@ class FederationHandler(BaseHandler):
"""
logger.debug("Joining %s to %s", joinee, room_id)
+ yield self.store.clean_room_for_join(room_id)
+
origin, pdu = yield self.replication_layer.make_join(
target_hosts,
room_id,
@@ -464,11 +466,9 @@ class FederationHandler(BaseHandler):
builder=builder,
)
- self.auth.check(event, auth_events=context.auth_events)
-
- pdu = event
+ self.auth.check(event, auth_events=context.current_state)
- defer.returnValue(pdu)
+ defer.returnValue(event)
@defer.inlineCallbacks
@log_function
@@ -705,7 +705,7 @@ class FederationHandler(BaseHandler):
)
if not auth_events:
- auth_events = context.auth_events
+ auth_events = context.current_state
logger.debug(
"_handle_new_event: %s, auth_events: %s",
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 731df00648..bbc7a0f200 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -33,6 +33,10 @@ logger = logging.getLogger(__name__)
metrics = synapse.metrics.get_metrics_for(__name__)
+# Don't bother bumping "last active" time if it differs by less than 60 seconds
+LAST_ACTIVE_GRANULARITY = 60*1000
+
+
# TODO(paul): Maybe there's one of these I can steal from somewhere
def partition(l, func):
"""Partition the list by the result of func applied to each element."""
@@ -282,6 +286,10 @@ class PresenceHandler(BaseHandler):
if now is None:
now = self.clock.time_msec()
+ prev_state = self._get_or_make_usercache(user)
+ if now - prev_state.state.get("last_active", 0) < LAST_ACTIVE_GRANULARITY:
+ return
+
self.changed_presencelike_data(user, {"last_active": now})
def changed_presencelike_data(self, user, state):
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index cda4a8502a..c25e321099 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -31,6 +31,7 @@ import base64
import bcrypt
import json
import logging
+import urllib
logger = logging.getLogger(__name__)
@@ -63,6 +64,13 @@ class RegistrationHandler(BaseHandler):
password_hash = bcrypt.hashpw(password, bcrypt.gensalt())
if localpart:
+ if localpart and urllib.quote(localpart) != localpart:
+ raise SynapseError(
+ 400,
+ "User ID must only contain characters which do not"
+ " require URL encoding."
+ )
+
user = UserID(localpart, self.hs.hostname)
user_id = user.to_string()
diff --git a/synapse/http/server.py b/synapse/http/server.py
index f1376ee243..dee49b9e18 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -46,6 +46,11 @@ outgoing_responses_counter = metrics.register_counter(
labels=["method", "code"],
)
+response_timer = metrics.register_distribution(
+ "response_time",
+ labels=["method", "servlet"]
+)
+
class HttpServer(object):
""" Interface for registering callbacks on a HTTP server
@@ -169,6 +174,10 @@ class JsonResource(HttpServer, resource.Resource):
code, response = yield callback(request, *args)
self._send_response(request, code, response)
+ response_timer.inc_by(
+ self.clock.time_msec() - start, request.method, servlet_classname
+ )
+
return
# Huh. No one wanted to handle that? Fiiiiiine. Send 400.
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index a4eb6c817c..265559a3ea 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -51,8 +51,8 @@ class RestServlet(object):
pattern = self.PATTERN
for method in ("GET", "PUT", "POST", "OPTIONS", "DELETE"):
- if hasattr(self, "on_%s" % (method)):
- method_handler = getattr(self, "on_%s" % (method))
+ if hasattr(self, "on_%s" % (method,)):
+ method_handler = getattr(self, "on_%s" % (method,))
http_server.register_path(method, pattern, method_handler)
else:
raise NotImplementedError("RestServlet must register something.")
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index 8a5849d960..6b6d5508b8 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -5,7 +5,6 @@ logger = logging.getLogger(__name__)
REQUIREMENTS = {
"syutil>=0.0.3": ["syutil"],
- "matrix_angular_sdk>=0.6.5": ["syweb>=0.6.5"],
"Twisted==14.0.2": ["twisted==14.0.2"],
"service_identity>=1.0.0": ["service_identity>=1.0.0"],
"pyopenssl>=0.14": ["OpenSSL>=0.14"],
@@ -18,6 +17,19 @@ REQUIREMENTS = {
"pillow": ["PIL"],
"pydenticon": ["pydenticon"],
}
+CONDITIONAL_REQUIREMENTS = {
+ "web_client": {
+ "matrix_angular_sdk>=0.6.5": ["syweb>=0.6.5"],
+ }
+}
+
+
+def requirements(config=None, include_conditional=False):
+ reqs = REQUIREMENTS.copy()
+ for key, req in CONDITIONAL_REQUIREMENTS.items():
+ if (config and getattr(config, key)) or include_conditional:
+ reqs.update(req)
+ return reqs
def github_link(project, version, egg):
@@ -46,10 +58,11 @@ class MissingRequirementError(Exception):
pass
-def check_requirements():
+def check_requirements(config=None):
"""Checks that all the modules needed by synapse have been correctly
installed and are at the correct version"""
- for dependency, module_requirements in REQUIREMENTS.items():
+ for dependency, module_requirements in (
+ requirements(config, include_conditional=False).items()):
for module_requirement in module_requirements:
if ">=" in module_requirement:
module_name, required_version = module_requirement.split(">=")
@@ -110,7 +123,7 @@ def list_requirements():
egg = link.split("#egg=")[1]
linked.append(egg.split('-')[0])
result.append(link)
- for requirement in REQUIREMENTS:
+ for requirement in requirements(include_conditional=True):
is_linked = False
for link in linked:
if requirement.replace('-', '_').startswith(link):
diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1/register.py
index f5acfb945f..a56834e365 100644
--- a/synapse/rest/client/v1/register.py
+++ b/synapse/rest/client/v1/register.py
@@ -27,7 +27,6 @@ from hashlib import sha1
import hmac
import simplejson as json
import logging
-import urllib
logger = logging.getLogger(__name__)
@@ -110,14 +109,22 @@ class RegisterRestServlet(ClientV1RestServlet):
login_type = register_json["type"]
is_application_server = login_type == LoginType.APPLICATION_SERVICE
- if self.disable_registration and not is_application_server:
+ is_using_shared_secret = login_type == LoginType.SHARED_SECRET
+
+ can_register = (
+ not self.disable_registration
+ or is_application_server
+ or is_using_shared_secret
+ )
+ if not can_register:
raise SynapseError(403, "Registration has been disabled")
stages = {
LoginType.RECAPTCHA: self._do_recaptcha,
LoginType.PASSWORD: self._do_password,
LoginType.EMAIL_IDENTITY: self._do_email_identity,
- LoginType.APPLICATION_SERVICE: self._do_app_service
+ LoginType.APPLICATION_SERVICE: self._do_app_service,
+ LoginType.SHARED_SECRET: self._do_shared_secret,
}
session_info = self._get_session_info(request, session)
@@ -255,14 +262,11 @@ class RegisterRestServlet(ClientV1RestServlet):
)
password = register_json["password"].encode("utf-8")
- desired_user_id = (register_json["user"].encode("utf-8")
- if "user" in register_json else None)
- if (desired_user_id
- and urllib.quote(desired_user_id) != desired_user_id):
- raise SynapseError(
- 400,
- "User ID must only contain characters which do not " +
- "require URL encoding.")
+ desired_user_id = (
+ register_json["user"].encode("utf-8")
+ if "user" in register_json else None
+ )
+
handler = self.handlers.registration_handler
(user_id, token) = yield handler.register(
localpart=desired_user_id,
@@ -304,6 +308,51 @@ class RegisterRestServlet(ClientV1RestServlet):
"home_server": self.hs.hostname,
})
+ @defer.inlineCallbacks
+ def _do_shared_secret(self, request, register_json, session):
+ yield run_on_reactor()
+
+ if not isinstance(register_json.get("mac", None), basestring):
+ raise SynapseError(400, "Expected mac.")
+ if not isinstance(register_json.get("user", None), basestring):
+ raise SynapseError(400, "Expected 'user' key.")
+ if not isinstance(register_json.get("password", None), basestring):
+ raise SynapseError(400, "Expected 'password' key.")
+
+ if not self.hs.config.registration_shared_secret:
+ raise SynapseError(400, "Shared secret registration is not enabled")
+
+ user = register_json["user"].encode("utf-8")
+
+ # str() because otherwise hmac complains that 'unicode' does not
+ # have the buffer interface
+ got_mac = str(register_json["mac"])
+
+ want_mac = hmac.new(
+ key=self.hs.config.registration_shared_secret,
+ msg=user,
+ digestmod=sha1,
+ ).hexdigest()
+
+ password = register_json["password"].encode("utf-8")
+
+ if compare_digest(want_mac, got_mac):
+ handler = self.handlers.registration_handler
+ user_id, token = yield handler.register(
+ localpart=user,
+ password=password,
+ )
+ self._remove_session(session)
+ defer.returnValue({
+ "user_id": user_id,
+ "access_token": token,
+ "home_server": self.hs.hostname,
+ })
+ else:
+ raise SynapseError(
+ 403, "HMAC incorrect",
+ )
+
def _parse_json(request):
try:
diff --git a/synapse/state.py b/synapse/state.py
index 80cced351d..ba2500d61c 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -21,6 +21,7 @@ from synapse.util.async import run_on_reactor
from synapse.util.expiringcache import ExpiringCache
from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError
+from synapse.api.auth import AuthEventTypes
from synapse.events.snapshot import EventContext
from collections import namedtuple
@@ -38,12 +39,6 @@ def _get_state_key_from_event(event):
KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key"))
-AuthEventTypes = (
- EventTypes.Create, EventTypes.Member, EventTypes.PowerLevels,
- EventTypes.JoinRules,
-)
-
-
SIZE_OF_CACHE = 1000
EVICTION_TIMEOUT_SECONDS = 20
@@ -139,18 +134,6 @@ class StateHandler(object):
}
context.state_group = None
- if hasattr(event, "auth_events") and event.auth_events:
- auth_ids = self.hs.get_auth().compute_auth_events(
- event, context.current_state
- )
- context.auth_events = {
- k: v
- for k, v in context.current_state.items()
- if v.event_id in auth_ids
- }
- else:
- context.auth_events = {}
-
if event.is_state():
key = (event.type, event.state_key)
if key in context.current_state:
@@ -187,18 +170,6 @@ class StateHandler(object):
replaces = context.current_state[key]
event.unsigned["replaces_state"] = replaces.event_id
- if hasattr(event, "auth_events") and event.auth_events:
- auth_ids = self.hs.get_auth().compute_auth_events(
- event, context.current_state
- )
- context.auth_events = {
- k: v
- for k, v in context.current_state.items()
- if v.event_id in auth_ids
- }
- else:
- context.auth_events = {}
-
context.prev_state_events = prev_state
defer.returnValue(context)
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index e752b035e6..f4dec70393 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -14,15 +14,12 @@
# limitations under the License.
from twisted.internet import defer
-
-from synapse.util.logutils import log_function
-from synapse.api.constants import EventTypes
-
from .appservice import (
ApplicationServiceStore, ApplicationServiceTransactionStore
)
+from ._base import Cache
from .directory import DirectoryStore
-from .feedback import FeedbackStore
+from .events import EventsStore
from .presence import PresenceStore
from .profile import ProfileStore
from .registration import RegistrationStore
@@ -41,11 +38,6 @@ from .state import StateStore
from .signatures import SignatureStore
from .filtering import FilteringStore
-from syutil.base64util import decode_base64
-from syutil.jsonutil import encode_canonical_json
-
-from synapse.crypto.event_signing import compute_event_reference_hash
-
import fnmatch
import imp
@@ -63,16 +55,14 @@ SCHEMA_VERSION = 15
dir_path = os.path.abspath(os.path.dirname(__file__))
-
-class _RollbackButIsFineException(Exception):
- """ This exception is used to rollback a transaction without implying
- something went wrong.
- """
- pass
+# Number of msec of granularity to store the user IP 'last seen' time. Smaller
+# times give more inserts into the database even for readonly API hits
+# 120 seconds == 2 minutes
+LAST_SEEN_GRANULARITY = 120*1000
class DataStore(RoomMemberStore, RoomStore,
- RegistrationStore, StreamStore, ProfileStore, FeedbackStore,
+ RegistrationStore, StreamStore, ProfileStore,
PresenceStore, TransactionStore,
DirectoryStore, KeyStore, StateStore, SignatureStore,
ApplicationServiceStore,
@@ -83,6 +73,7 @@ class DataStore(RoomMemberStore, RoomStore,
PusherStore,
PushRuleStore,
ApplicationServiceTransactionStore,
+ EventsStore,
):
def __init__(self, hs):
@@ -92,424 +83,28 @@ class DataStore(RoomMemberStore, RoomStore,
self.min_token_deferred = self._get_min_token()
self.min_token = None
- @defer.inlineCallbacks
- @log_function
- def persist_event(self, event, context, backfilled=False,
- is_new_state=True, current_state=None):
- stream_ordering = None
- if backfilled:
- if not self.min_token_deferred.called:
- yield self.min_token_deferred
- self.min_token -= 1
- stream_ordering = self.min_token
-
- try:
- yield self.runInteraction(
- "persist_event",
- self._persist_event_txn,
- event=event,
- context=context,
- backfilled=backfilled,
- stream_ordering=stream_ordering,
- is_new_state=is_new_state,
- current_state=current_state,
- )
- except _RollbackButIsFineException:
- pass
-
- @defer.inlineCallbacks
- def get_event(self, event_id, check_redacted=True,
- get_prev_content=False, allow_rejected=False,
- allow_none=False):
- """Get an event from the database by event_id.
-
- Args:
- event_id (str): The event_id of the event to fetch
- check_redacted (bool): If True, check if event has been redacted
- and redact it.
- get_prev_content (bool): If True and event is a state event,
- include the previous states content in the unsigned field.
- allow_rejected (bool): If True return rejected events.
- allow_none (bool): If True, return None if no event found, if
- False throw an exception.
-
- Returns:
- Deferred : A FrozenEvent.
- """
- event = yield self.runInteraction(
- "get_event", self._get_event_txn,
- event_id,
- check_redacted=check_redacted,
- get_prev_content=get_prev_content,
- allow_rejected=allow_rejected,
- )
-
- if not event and not allow_none:
- raise RuntimeError("Could not find event %s" % (event_id,))
-
- defer.returnValue(event)
-
- @log_function
- def _persist_event_txn(self, txn, event, context, backfilled,
- stream_ordering=None, is_new_state=True,
- current_state=None):
-
- # Remove the any existing cache entries for the event_id
- self._get_event_cache.pop(event.event_id)
-
- # We purposefully do this first since if we include a `current_state`
- # key, we *want* to update the `current_state_events` table
- if current_state:
- txn.execute(
- "DELETE FROM current_state_events WHERE room_id = ?",
- (event.room_id,)
- )
-
- for s in current_state:
- self._simple_insert_txn(
- txn,
- "current_state_events",
- {
- "event_id": s.event_id,
- "room_id": s.room_id,
- "type": s.type,
- "state_key": s.state_key,
- },
- or_replace=True,
- )
-
- if event.is_state() and is_new_state:
- if not backfilled and not context.rejected:
- self._simple_insert_txn(
- txn,
- table="state_forward_extremities",
- values={
- "event_id": event.event_id,
- "room_id": event.room_id,
- "type": event.type,
- "state_key": event.state_key,
- },
- or_replace=True,
- )
-
- for prev_state_id, _ in event.prev_state:
- self._simple_delete_txn(
- txn,
- table="state_forward_extremities",
- keyvalues={
- "event_id": prev_state_id,
- }
- )
-
- outlier = event.internal_metadata.is_outlier()
-
- if not outlier:
- self._store_state_groups_txn(txn, event, context)
-
- self._update_min_depth_for_room_txn(
- txn,
- event.room_id,
- event.depth
- )
-
- self._handle_prev_events(
- txn,
- outlier=outlier,
- event_id=event.event_id,
- prev_events=event.prev_events,
- room_id=event.room_id,
- )
-
- have_persisted = self._simple_select_one_onecol_txn(
- txn,
- table="event_json",
- keyvalues={"event_id": event.event_id},
- retcol="event_id",
- allow_none=True,
- )
-
- metadata_json = encode_canonical_json(
- event.internal_metadata.get_dict()
- )
-
- # If we have already persisted this event, we don't need to do any
- # more processing.
- # The processing above must be done on every call to persist event,
- # since they might not have happened on previous calls. For example,
- # if we are persisting an event that we had persisted as an outlier,
- # but is no longer one.
- if have_persisted:
- if not outlier:
- sql = (
- "UPDATE event_json SET internal_metadata = ?"
- " WHERE event_id = ?"
- )
- txn.execute(
- sql,
- (metadata_json.decode("UTF-8"), event.event_id,)
- )
-
- sql = (
- "UPDATE events SET outlier = 0"
- " WHERE event_id = ?"
- )
- txn.execute(
- sql,
- (event.event_id,)
- )
- return
-
- if event.type == EventTypes.Member:
- self._store_room_member_txn(txn, event)
- elif event.type == EventTypes.Feedback:
- self._store_feedback_txn(txn, event)
- elif event.type == EventTypes.Name:
- self._store_room_name_txn(txn, event)
- elif event.type == EventTypes.Topic:
- self._store_room_topic_txn(txn, event)
- elif event.type == EventTypes.Redaction:
- self._store_redaction(txn, event)
-
- event_dict = {
- k: v
- for k, v in event.get_dict().items()
- if k not in [
- "redacted",
- "redacted_because",
- ]
- }
-
- self._simple_insert_txn(
- txn,
- table="event_json",
- values={
- "event_id": event.event_id,
- "room_id": event.room_id,
- "internal_metadata": metadata_json.decode("UTF-8"),
- "json": encode_canonical_json(event_dict).decode("UTF-8"),
- },
- or_replace=True,
- )
-
- content = encode_canonical_json(
- event.content
- ).decode("UTF-8")
-
- vals = {
- "topological_ordering": event.depth,
- "event_id": event.event_id,
- "type": event.type,
- "room_id": event.room_id,
- "content": content,
- "processed": True,
- "outlier": outlier,
- "depth": event.depth,
- }
-
- if stream_ordering is not None:
- vals["stream_ordering"] = stream_ordering
-
- unrec = {
- k: v
- for k, v in event.get_dict().items()
- if k not in vals.keys() and k not in [
- "redacted",
- "redacted_because",
- "signatures",
- "hashes",
- "prev_events",
- ]
- }
-
- vals["unrecognized_keys"] = encode_canonical_json(
- unrec
- ).decode("UTF-8")
-
- try:
- self._simple_insert_txn(
- txn,
- "events",
- vals,
- or_replace=(not outlier),
- or_ignore=bool(outlier),
- )
- except:
- logger.warn(
- "Failed to persist, probably duplicate: %s",
- event.event_id,
- exc_info=True,
- )
- raise _RollbackButIsFineException("_persist_event")
-
- if context.rejected:
- self._store_rejections_txn(txn, event.event_id, context.rejected)
-
- if event.is_state():
- vals = {
- "event_id": event.event_id,
- "room_id": event.room_id,
- "type": event.type,
- "state_key": event.state_key,
- }
-
- # TODO: How does this work with backfilling?
- if hasattr(event, "replaces_state"):
- vals["prev_state"] = event.replaces_state
-
- self._simple_insert_txn(
- txn,
- "state_events",
- vals,
- or_replace=True,
- )
-
- if is_new_state and not context.rejected:
- self._simple_insert_txn(
- txn,
- "current_state_events",
- {
- "event_id": event.event_id,
- "room_id": event.room_id,
- "type": event.type,
- "state_key": event.state_key,
- },
- or_replace=True,
- )
-
- for e_id, h in event.prev_state:
- self._simple_insert_txn(
- txn,
- table="event_edges",
- values={
- "event_id": event.event_id,
- "prev_event_id": e_id,
- "room_id": event.room_id,
- "is_state": 1,
- },
- or_ignore=True,
- )
-
- for hash_alg, hash_base64 in event.hashes.items():
- hash_bytes = decode_base64(hash_base64)
- self._store_event_content_hash_txn(
- txn, event.event_id, hash_alg, hash_bytes,
- )
-
- for prev_event_id, prev_hashes in event.prev_events:
- for alg, hash_base64 in prev_hashes.items():
- hash_bytes = decode_base64(hash_base64)
- self._store_prev_event_hash_txn(
- txn, event.event_id, prev_event_id, alg, hash_bytes
- )
-
- for auth_id, _ in event.auth_events:
- self._simple_insert_txn(
- txn,
- table="event_auth",
- values={
- "event_id": event.event_id,
- "room_id": event.room_id,
- "auth_id": auth_id,
- },
- or_ignore=True,
- )
-
- (ref_alg, ref_hash_bytes) = compute_event_reference_hash(event)
- self._store_event_reference_hash_txn(
- txn, event.event_id, ref_alg, ref_hash_bytes
- )
-
- def _store_redaction(self, txn, event):
- # invalidate the cache for the redacted event
- self._get_event_cache.pop(event.redacts)
- txn.execute(
- "INSERT OR IGNORE INTO redactions "
- "(event_id, redacts) VALUES (?,?)",
- (event.event_id, event.redacts)
- )
-
- @defer.inlineCallbacks
- def get_current_state(self, room_id, event_type=None, state_key=""):
- del_sql = (
- "SELECT event_id FROM redactions WHERE redacts = e.event_id "
- "LIMIT 1"
- )
-
- sql = (
- "SELECT e.*, (%(redacted)s) AS redacted FROM events as e "
- "INNER JOIN current_state_events as c ON e.event_id = c.event_id "
- "INNER JOIN state_events as s ON e.event_id = s.event_id "
- "WHERE c.room_id = ? "
- ) % {
- "redacted": del_sql,
- }
-
- if event_type and state_key is not None:
- sql += " AND s.type = ? AND s.state_key = ? "
- args = (room_id, event_type, state_key)
- elif event_type:
- sql += " AND s.type = ?"
- args = (room_id, event_type)
- else:
- args = (room_id, )
-
- results = yield self._execute_and_decode("get_current_state", sql, *args)
-
- events = yield self._parse_events(results)
- defer.returnValue(events)
-
- @defer.inlineCallbacks
- def get_room_name_and_aliases(self, room_id):
- del_sql = (
- "SELECT event_id FROM redactions WHERE redacts = e.event_id "
- "LIMIT 1"
+ self.client_ip_last_seen = Cache(
+ name="client_ip_last_seen",
+ keylen=4,
)
- sql = (
- "SELECT e.*, (%(redacted)s) AS redacted FROM events as e "
- "INNER JOIN current_state_events as c ON e.event_id = c.event_id "
- "INNER JOIN state_events as s ON e.event_id = s.event_id "
- "WHERE c.room_id = ? "
- ) % {
- "redacted": del_sql,
- }
-
- sql += " AND ((s.type = 'm.room.name' AND s.state_key = '')"
- sql += " OR s.type = 'm.room.aliases')"
- args = (room_id,)
-
- results = yield self._execute_and_decode("get_current_state", sql, *args)
-
- events = yield self._parse_events(results)
-
- name = None
- aliases = []
-
- for e in events:
- if e.type == 'm.room.name':
- if 'name' in e.content:
- name = e.content['name']
- elif e.type == 'm.room.aliases':
- if 'aliases' in e.content:
- aliases.extend(e.content['aliases'])
-
- defer.returnValue((name, aliases))
-
@defer.inlineCallbacks
- def _get_min_token(self):
- row = yield self._execute(
- "_get_min_token", None, "SELECT MIN(stream_ordering) FROM events"
- )
+ def insert_client_ip(self, user, access_token, device_id, ip, user_agent):
+ now = int(self._clock.time_msec())
+ key = (user.to_string(), access_token, device_id, ip)
- self.min_token = row[0][0] if row and row[0] and row[0][0] else -1
- self.min_token = min(self.min_token, -1)
+ try:
+ last_seen = self.client_ip_last_seen.get(*key)
+ except KeyError:
+ last_seen = None
- logger.debug("min_token is: %s", self.min_token)
+ # Rate-limited inserts
+ if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY:
+ defer.returnValue(None)
- defer.returnValue(self.min_token)
+ self.client_ip_last_seen.prefill(*key + (now,))
- def insert_client_ip(self, user, access_token, device_id, ip, user_agent):
- return self._simple_insert(
+ yield self._simple_insert(
"user_ips",
{
"user": user.to_string(),
@@ -517,8 +112,9 @@ class DataStore(RoomMemberStore, RoomStore,
"device_id": device_id,
"ip": ip,
"user_agent": user_agent,
- "last_seen": int(self._clock.time_msec()),
- }
+ "last_seen": now,
+ },
+ desc="insert_client_ip",
)
def get_user_ip_and_agents(self, user):
@@ -528,38 +124,7 @@ class DataStore(RoomMemberStore, RoomStore,
retcols=[
"device_id", "access_token", "ip", "user_agent", "last_seen"
],
- )
-
- def have_events(self, event_ids):
- """Given a list of event ids, check if we have already processed them.
-
- Returns:
- dict: Has an entry for each event id we already have seen. Maps to
- the rejected reason string if we rejected the event, else maps to
- None.
- """
- if not event_ids:
- return defer.succeed({})
-
- def f(txn):
- sql = (
- "SELECT e.event_id, reason FROM events as e "
- "LEFT JOIN rejections as r ON e.event_id = r.event_id "
- "WHERE e.event_id = ?"
- )
-
- res = {}
- for event_id in event_ids:
- txn.execute(sql, (event_id,))
- row = txn.fetchone()
- if row:
- _, rejected = row
- res[event_id] = rejected
-
- return res
-
- return self.runInteraction(
- "have_events", f,
+ desc="get_user_ip_and_agents",
)
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 40f2fc6d76..6fa63f052e 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -25,6 +25,7 @@ import synapse.metrics
from twisted.internet import defer
from collections import namedtuple, OrderedDict
+import functools
import simplejson as json
import sys
import time
@@ -38,6 +39,8 @@ transaction_logger = logging.getLogger("synapse.storage.txn")
metrics = synapse.metrics.get_metrics_for("synapse.storage")
+sql_scheduling_timer = metrics.register_distribution("schedule_time")
+
sql_query_timer = metrics.register_distribution("query_time", labels=["verb"])
sql_txn_timer = metrics.register_distribution("transaction_time", labels=["desc"])
sql_getevents_timer = metrics.register_distribution("getEvents_time", labels=["desc"])
@@ -50,14 +53,57 @@ cache_counter = metrics.register_cache(
)
-# TODO(paul):
-# * more generic key management
-# * consider other eviction strategies - LRU?
-def cached(max_entries=1000):
+class Cache(object):
+
+ def __init__(self, name, max_entries=1000, keylen=1, lru=False):
+ if lru:
+ self.cache = LruCache(max_size=max_entries)
+ self.max_entries = None
+ else:
+ self.cache = OrderedDict()
+ self.max_entries = max_entries
+
+ self.name = name
+ self.keylen = keylen
+
+ caches_by_name[name] = self.cache
+
+ def get(self, *keyargs):
+ if len(keyargs) != self.keylen:
+ raise ValueError("Expected a key to have %d items", self.keylen)
+
+ if keyargs in self.cache:
+ cache_counter.inc_hits(self.name)
+ return self.cache[keyargs]
+
+ cache_counter.inc_misses(self.name)
+ raise KeyError()
+
+ def prefill(self, *args): # because I can't *keyargs, value
+ keyargs = args[:-1]
+ value = args[-1]
+
+ if len(keyargs) != self.keylen:
+ raise ValueError("Expected a key to have %d items", self.keylen)
+
+ if self.max_entries is not None:
+ while len(self.cache) >= self.max_entries:
+ self.cache.popitem(last=False)
+
+ self.cache[keyargs] = value
+
+ def invalidate(self, *keyargs):
+ if len(keyargs) != self.keylen:
+ raise ValueError("Expected a key to have %d items", self.keylen)
+
+ self.cache.pop(keyargs, None)
+
+
+def cached(max_entries=1000, num_args=1, lru=False):
""" A method decorator that applies a memoizing cache around the function.
- The function is presumed to take one additional argument, which is used as
- the key for the cache. Cache hits are served directly from the cache;
+ The function is presumed to take zero or more arguments, which are used in
+ a tuple as the key for the cache. Hits are served directly from the cache;
misses use the function body to generate the value.
The wrapped function has an additional member, a callable called
@@ -68,33 +114,27 @@ def cached(max_entries=1000):
calling the calculation function.
"""
def wrap(orig):
- cache = OrderedDict()
- name = orig.__name__
-
- caches_by_name[name] = cache
-
- def prefill(key, value):
- while len(cache) > max_entries:
- cache.popitem(last=False)
-
- cache[key] = value
+ cache = Cache(
+ name=orig.__name__,
+ max_entries=max_entries,
+ keylen=num_args,
+ lru=lru,
+ )
+ @functools.wraps(orig)
@defer.inlineCallbacks
- def wrapped(self, key):
- if key in cache:
- cache_counter.inc_hits(name)
- defer.returnValue(cache[key])
+ def wrapped(self, *keyargs):
+ try:
+ defer.returnValue(cache.get(*keyargs))
+ except KeyError:
+ ret = yield orig(self, *keyargs)
- cache_counter.inc_misses(name)
- ret = yield orig(self, key)
- prefill(key, ret)
- defer.returnValue(ret)
+ cache.prefill(*keyargs + (ret,))
- def invalidate(key):
- cache.pop(key, None)
+ defer.returnValue(ret)
- wrapped.invalidate = invalidate
- wrapped.prefill = prefill
+ wrapped.invalidate = cache.invalidate
+ wrapped.prefill = cache.prefill
return wrapped
return wrap
@@ -240,6 +280,8 @@ class SQLBaseStore(object):
"""Wraps the .runInteraction() method on the underlying db_pool."""
current_context = LoggingContext.current_context()
+ start_time = time.time() * 1000
+
def inner_func(txn, *args, **kwargs):
with LoggingContext("runInteraction") as context:
current_context.copy_to(context)
@@ -252,6 +294,7 @@ class SQLBaseStore(object):
name = "%s-%x" % (desc, txn_id, )
+ sql_scheduling_timer.inc_by(time.time() * 1000 - start_time)
transaction_logger.debug("[TXN START] {%s}", name)
try:
return func(LoggingTransaction(txn, name), *args, **kwargs)
@@ -314,7 +357,8 @@ class SQLBaseStore(object):
# "Simple" SQL API methods that operate on a single table with no JOINs,
# no complex WHERE clauses, just a dict of values for columns.
- def _simple_insert(self, table, values, or_replace=False, or_ignore=False):
+ def _simple_insert(self, table, values, or_replace=False, or_ignore=False,
+ desc="_simple_insert"):
"""Executes an INSERT query on the named table.
Args:
@@ -323,7 +367,7 @@ class SQLBaseStore(object):
or_replace : bool; if True performs an INSERT OR REPLACE
"""
return self.runInteraction(
- "_simple_insert",
+ desc,
self._simple_insert_txn, table, values, or_replace=or_replace,
or_ignore=or_ignore,
)
@@ -347,7 +391,7 @@ class SQLBaseStore(object):
txn.execute(sql, values.values())
return txn.lastrowid
- def _simple_upsert(self, table, keyvalues, values):
+ def _simple_upsert(self, table, keyvalues, values, desc="_simple_upsert"):
"""
Args:
table (str): The table to upsert into
@@ -356,7 +400,7 @@ class SQLBaseStore(object):
Returns: A deferred
"""
return self.runInteraction(
- "_simple_upsert",
+ desc,
self._simple_upsert_txn, table, keyvalues, values
)
@@ -392,7 +436,7 @@ class SQLBaseStore(object):
txn.execute(sql, allvalues.values())
def _simple_select_one(self, table, keyvalues, retcols,
- allow_none=False):
+ allow_none=False, desc="_simple_select_one"):
"""Executes a SELECT query on the named table, which is expected to
return a single row, returning a single column from it.
@@ -404,12 +448,15 @@ class SQLBaseStore(object):
allow_none : If true, return None instead of failing if the SELECT
statement returns no rows
"""
- return self._simple_selectupdate_one(
- table, keyvalues, retcols=retcols, allow_none=allow_none
+ return self.runInteraction(
+ desc,
+ self._simple_select_one_txn,
+ table, keyvalues, retcols, allow_none,
)
def _simple_select_one_onecol(self, table, keyvalues, retcol,
- allow_none=False):
+ allow_none=False,
+ desc="_simple_select_one_onecol"):
"""Executes a SELECT query on the named table, which is expected to
return a single row, returning a single column from it."
@@ -419,7 +466,7 @@ class SQLBaseStore(object):
retcol : string giving the name of the column to return
"""
return self.runInteraction(
- "_simple_select_one_onecol",
+ desc,
self._simple_select_one_onecol_txn,
table, keyvalues, retcol, allow_none=allow_none,
)
@@ -455,7 +502,8 @@ class SQLBaseStore(object):
return [r[0] for r in txn.fetchall()]
- def _simple_select_onecol(self, table, keyvalues, retcol):
+ def _simple_select_onecol(self, table, keyvalues, retcol,
+ desc="_simple_select_onecol"):
"""Executes a SELECT query on the named table, which returns a list
comprising of the values of the named column from the selected rows.
@@ -468,12 +516,13 @@ class SQLBaseStore(object):
Deferred: Results in a list
"""
return self.runInteraction(
- "_simple_select_onecol",
+ desc,
self._simple_select_onecol_txn,
table, keyvalues, retcol
)
- def _simple_select_list(self, table, keyvalues, retcols):
+ def _simple_select_list(self, table, keyvalues, retcols,
+ desc="_simple_select_list"):
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.
@@ -484,7 +533,7 @@ class SQLBaseStore(object):
retcols : list of strings giving the names of the columns to return
"""
return self.runInteraction(
- "_simple_select_list",
+ desc,
self._simple_select_list_txn,
table, keyvalues, retcols
)
@@ -516,7 +565,7 @@ class SQLBaseStore(object):
return self.cursor_to_dict(txn)
def _simple_update_one(self, table, keyvalues, updatevalues,
- retcols=None):
+ desc="_simple_update_one"):
"""Executes an UPDATE query on the named table, setting new values for
columns in a row matching the key values.
@@ -534,56 +583,76 @@ class SQLBaseStore(object):
get-and-set. This can be used to implement compare-and-set by putting
the update column in the 'keyvalues' dict as well.
"""
- return self._simple_selectupdate_one(table, keyvalues, updatevalues,
- retcols=retcols)
+ return self.runInteraction(
+ desc,
+ self._simple_update_one_txn,
+ table, keyvalues, updatevalues,
+ )
- def _simple_selectupdate_one(self, table, keyvalues, updatevalues=None,
- retcols=None, allow_none=False):
- """ Combined SELECT then UPDATE."""
- if retcols:
- select_sql = "SELECT %s FROM %s WHERE %s ORDER BY rowid asc" % (
- ", ".join(retcols),
- table,
- " AND ".join("%s = ?" % (k) for k in keyvalues)
- )
+ def _simple_update_one_txn(self, txn, table, keyvalues, updatevalues):
+ update_sql = "UPDATE %s SET %s WHERE %s" % (
+ table,
+ ", ".join("%s = ?" % (k,) for k in updatevalues),
+ " AND ".join("%s = ?" % (k,) for k in keyvalues)
+ )
- if updatevalues:
- update_sql = "UPDATE %s SET %s WHERE %s" % (
- table,
- ", ".join("%s = ?" % (k,) for k in updatevalues),
- " AND ".join("%s = ?" % (k,) for k in keyvalues)
- )
+ txn.execute(
+ update_sql,
+ updatevalues.values() + keyvalues.values()
+ )
+
+ if txn.rowcount == 0:
+ raise StoreError(404, "No row found")
+ if txn.rowcount > 1:
+ raise StoreError(500, "More than one row matched")
+
+ def _simple_select_one_txn(self, txn, table, keyvalues, retcols,
+ allow_none=False):
+ select_sql = "SELECT %s FROM %s WHERE %s ORDER BY rowid asc" % (
+ ", ".join(retcols),
+ table,
+ " AND ".join("%s = ?" % (k) for k in keyvalues)
+ )
+
+ txn.execute(select_sql, keyvalues.values())
+ row = txn.fetchone()
+ if not row:
+ if allow_none:
+ return None
+ raise StoreError(404, "No row found")
+ if txn.rowcount > 1:
+ raise StoreError(500, "More than one row matched")
+
+ return dict(zip(retcols, row))
+
+ def _simple_selectupdate_one(self, table, keyvalues, updatevalues=None,
+ retcols=None, allow_none=False,
+ desc="_simple_selectupdate_one"):
+ """ Combined SELECT then UPDATE."""
def func(txn):
ret = None
if retcols:
- txn.execute(select_sql, keyvalues.values())
-
- row = txn.fetchone()
- if not row:
- if allow_none:
- return None
- raise StoreError(404, "No row found")
- if txn.rowcount > 1:
- raise StoreError(500, "More than one row matched")
-
- ret = dict(zip(retcols, row))
+ ret = self._simple_select_one_txn(
+ txn,
+ table=table,
+ keyvalues=keyvalues,
+ retcols=retcols,
+ allow_none=allow_none,
+ )
if updatevalues:
- txn.execute(
- update_sql,
- updatevalues.values() + keyvalues.values()
+ self._simple_update_one_txn(
+ txn,
+ table=table,
+ keyvalues=keyvalues,
+ updatevalues=updatevalues,
)
- if txn.rowcount == 0:
- raise StoreError(404, "No row found")
- if txn.rowcount > 1:
- raise StoreError(500, "More than one row matched")
-
return ret
- return self.runInteraction("_simple_selectupdate_one", func)
+ return self.runInteraction(desc, func)
- def _simple_delete_one(self, table, keyvalues):
+ def _simple_delete_one(self, table, keyvalues, desc="_simple_delete_one"):
"""Executes a DELETE query on the named table, expecting to delete a
single row.
@@ -602,9 +671,9 @@ class SQLBaseStore(object):
raise StoreError(404, "No row found")
if txn.rowcount > 1:
raise StoreError(500, "more than one row matched")
- return self.runInteraction("_simple_delete_one", func)
+ return self.runInteraction(desc, func)
- def _simple_delete(self, table, keyvalues):
+ def _simple_delete(self, table, keyvalues, desc="_simple_delete"):
"""Executes a DELETE query on the named table.
Args:
@@ -612,7 +681,7 @@ class SQLBaseStore(object):
keyvalues : dict of column names and values to select the row with
"""
- return self.runInteraction("_simple_delete", self._simple_delete_txn)
+ return self.runInteraction(desc, self._simple_delete_txn)
def _simple_delete_txn(self, txn, table, keyvalues):
sql = "DELETE FROM %s WHERE %s" % (
@@ -782,6 +851,13 @@ class SQLBaseStore(object):
return result[0] if result else None
+class _RollbackButIsFineException(Exception):
+ """ This exception is used to rollback a transaction without implying
+ something went wrong.
+ """
+ pass
+
+
class Table(object):
""" A base class used to store information about a particular table.
"""
diff --git a/synapse/storage/directory.py b/synapse/storage/directory.py
index 68b7d59693..0199539fea 100644
--- a/synapse/storage/directory.py
+++ b/synapse/storage/directory.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import SQLBaseStore
+from ._base import SQLBaseStore, cached
from synapse.api.errors import SynapseError
@@ -48,6 +48,7 @@ class DirectoryStore(SQLBaseStore):
{"room_alias": room_alias.to_string()},
"room_id",
allow_none=True,
+ desc="get_association_from_room_alias",
)
if not room_id:
@@ -58,6 +59,7 @@ class DirectoryStore(SQLBaseStore):
"room_alias_servers",
{"room_alias": room_alias.to_string()},
"server",
+ desc="get_association_from_room_alias",
)
if not servers:
@@ -87,6 +89,7 @@ class DirectoryStore(SQLBaseStore):
"room_alias": room_alias.to_string(),
"room_id": room_id,
},
+ desc="create_room_alias_association",
)
except sqlite3.IntegrityError:
raise SynapseError(
@@ -100,16 +103,22 @@ class DirectoryStore(SQLBaseStore):
{
"room_alias": room_alias.to_string(),
"server": server,
- }
+ },
+ desc="create_room_alias_association",
)
+ self.get_aliases_for_room.invalidate(room_id)
+ @defer.inlineCallbacks
def delete_room_alias(self, room_alias):
- return self.runInteraction(
+ room_id = yield self.runInteraction(
"delete_room_alias",
self._delete_room_alias_txn,
room_alias,
)
+ self.get_aliases_for_room.invalidate(room_id)
+ defer.returnValue(room_id)
+
def _delete_room_alias_txn(self, txn, room_alias):
cursor = txn.execute(
"SELECT room_id FROM room_aliases WHERE room_alias = ?",
@@ -134,9 +143,11 @@ class DirectoryStore(SQLBaseStore):
return room_id
+ @cached()
def get_aliases_for_room(self, room_id):
return self._simple_select_onecol(
"room_aliases",
{"room_id": room_id},
"room_alias",
+ desc="get_aliases_for_room",
)
diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py
index 2deda8ac50..032334bfd6 100644
--- a/synapse/storage/event_federation.py
+++ b/synapse/storage/event_federation.py
@@ -429,3 +429,15 @@ class EventFederationStore(SQLBaseStore):
)
return events[:limit]
+
+ def clean_room_for_join(self, room_id):
+ return self.runInteraction(
+ "clean_room_for_join",
+ self._clean_room_for_join_txn,
+ room_id,
+ )
+
+ def _clean_room_for_join_txn(self, txn, room_id):
+ query = "DELETE FROM event_forward_extremities WHERE room_id = ?"
+
+ txn.execute(query, (room_id,))
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
new file mode 100644
index 0000000000..a86230d92c
--- /dev/null
+++ b/synapse/storage/events.py
@@ -0,0 +1,395 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014, 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from _base import SQLBaseStore, _RollbackButIsFineException
+
+from twisted.internet import defer
+
+from synapse.util.logutils import log_function
+from synapse.api.constants import EventTypes
+from synapse.crypto.event_signing import compute_event_reference_hash
+
+from syutil.base64util import decode_base64
+from syutil.jsonutil import encode_canonical_json
+
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class EventsStore(SQLBaseStore):
+ @defer.inlineCallbacks
+ @log_function
+ def persist_event(self, event, context, backfilled=False,
+ is_new_state=True, current_state=None):
+ stream_ordering = None
+ if backfilled:
+ if not self.min_token_deferred.called:
+ yield self.min_token_deferred
+ self.min_token -= 1
+ stream_ordering = self.min_token
+
+ try:
+ yield self.runInteraction(
+ "persist_event",
+ self._persist_event_txn,
+ event=event,
+ context=context,
+ backfilled=backfilled,
+ stream_ordering=stream_ordering,
+ is_new_state=is_new_state,
+ current_state=current_state,
+ )
+ self.get_room_events_max_id.invalidate()
+ except _RollbackButIsFineException:
+ pass
+
+ @defer.inlineCallbacks
+ def get_event(self, event_id, check_redacted=True,
+ get_prev_content=False, allow_rejected=False,
+ allow_none=False):
+ """Get an event from the database by event_id.
+
+ Args:
+ event_id (str): The event_id of the event to fetch
+ check_redacted (bool): If True, check if event has been redacted
+ and redact it.
+ get_prev_content (bool): If True and event is a state event,
+ include the previous states content in the unsigned field.
+ allow_rejected (bool): If True return rejected events.
+ allow_none (bool): If True, return None if no event found, if
+ False throw an exception.
+
+ Returns:
+ Deferred : A FrozenEvent.
+ """
+ event = yield self.runInteraction(
+ "get_event", self._get_event_txn,
+ event_id,
+ check_redacted=check_redacted,
+ get_prev_content=get_prev_content,
+ allow_rejected=allow_rejected,
+ )
+
+ if not event and not allow_none:
+ raise RuntimeError("Could not find event %s" % (event_id,))
+
+ defer.returnValue(event)
+
+ @log_function
+ def _persist_event_txn(self, txn, event, context, backfilled,
+ stream_ordering=None, is_new_state=True,
+ current_state=None):
+
+ # Remove the any existing cache entries for the event_id
+ self._get_event_cache.pop(event.event_id)
+
+ # We purposefully do this first since if we include a `current_state`
+ # key, we *want* to update the `current_state_events` table
+ if current_state:
+ txn.execute(
+ "DELETE FROM current_state_events WHERE room_id = ?",
+ (event.room_id,)
+ )
+
+ for s in current_state:
+ self._simple_insert_txn(
+ txn,
+ "current_state_events",
+ {
+ "event_id": s.event_id,
+ "room_id": s.room_id,
+ "type": s.type,
+ "state_key": s.state_key,
+ },
+ or_replace=True,
+ )
+
+ if event.is_state() and is_new_state:
+ if not backfilled and not context.rejected:
+ self._simple_insert_txn(
+ txn,
+ table="state_forward_extremities",
+ values={
+ "event_id": event.event_id,
+ "room_id": event.room_id,
+ "type": event.type,
+ "state_key": event.state_key,
+ },
+ or_replace=True,
+ )
+
+ for prev_state_id, _ in event.prev_state:
+ self._simple_delete_txn(
+ txn,
+ table="state_forward_extremities",
+ keyvalues={
+ "event_id": prev_state_id,
+ }
+ )
+
+ outlier = event.internal_metadata.is_outlier()
+
+ if not outlier:
+ self._store_state_groups_txn(txn, event, context)
+
+ self._update_min_depth_for_room_txn(
+ txn,
+ event.room_id,
+ event.depth
+ )
+
+ self._handle_prev_events(
+ txn,
+ outlier=outlier,
+ event_id=event.event_id,
+ prev_events=event.prev_events,
+ room_id=event.room_id,
+ )
+
+ have_persisted = self._simple_select_one_onecol_txn(
+ txn,
+ table="event_json",
+ keyvalues={"event_id": event.event_id},
+ retcol="event_id",
+ allow_none=True,
+ )
+
+ metadata_json = encode_canonical_json(
+ event.internal_metadata.get_dict()
+ )
+
+ # If we have already persisted this event, we don't need to do any
+ # more processing.
+ # The processing above must be done on every call to persist event,
+ # since they might not have happened on previous calls. For example,
+ # if we are persisting an event that we had persisted as an outlier,
+ # but is no longer one.
+ if have_persisted:
+ if not outlier:
+ sql = (
+ "UPDATE event_json SET internal_metadata = ?"
+ " WHERE event_id = ?"
+ )
+ txn.execute(
+ sql,
+ (metadata_json.decode("UTF-8"), event.event_id,)
+ )
+
+ sql = (
+ "UPDATE events SET outlier = 0"
+ " WHERE event_id = ?"
+ )
+ txn.execute(
+ sql,
+ (event.event_id,)
+ )
+ return
+
+ if event.type == EventTypes.Member:
+ self._store_room_member_txn(txn, event)
+ elif event.type == EventTypes.Feedback:
+ self._store_feedback_txn(txn, event)
+ elif event.type == EventTypes.Name:
+ self._store_room_name_txn(txn, event)
+ elif event.type == EventTypes.Topic:
+ self._store_room_topic_txn(txn, event)
+ elif event.type == EventTypes.Redaction:
+ self._store_redaction(txn, event)
+
+ event_dict = {
+ k: v
+ for k, v in event.get_dict().items()
+ if k not in [
+ "redacted",
+ "redacted_because",
+ ]
+ }
+
+ self._simple_insert_txn(
+ txn,
+ table="event_json",
+ values={
+ "event_id": event.event_id,
+ "room_id": event.room_id,
+ "internal_metadata": metadata_json.decode("UTF-8"),
+ "json": encode_canonical_json(event_dict).decode("UTF-8"),
+ },
+ or_replace=True,
+ )
+
+ content = encode_canonical_json(
+ event.content
+ ).decode("UTF-8")
+
+ vals = {
+ "topological_ordering": event.depth,
+ "event_id": event.event_id,
+ "type": event.type,
+ "room_id": event.room_id,
+ "content": content,
+ "processed": True,
+ "outlier": outlier,
+ "depth": event.depth,
+ }
+
+ if stream_ordering is not None:
+ vals["stream_ordering"] = stream_ordering
+
+ unrec = {
+ k: v
+ for k, v in event.get_dict().items()
+ if k not in vals.keys() and k not in [
+ "redacted",
+ "redacted_because",
+ "signatures",
+ "hashes",
+ "prev_events",
+ ]
+ }
+
+ vals["unrecognized_keys"] = encode_canonical_json(
+ unrec
+ ).decode("UTF-8")
+
+ try:
+ self._simple_insert_txn(
+ txn,
+ "events",
+ vals,
+ or_replace=(not outlier),
+ or_ignore=bool(outlier),
+ )
+ except:
+ logger.warn(
+ "Failed to persist, probably duplicate: %s",
+ event.event_id,
+ exc_info=True,
+ )
+ raise _RollbackButIsFineException("_persist_event")
+
+ if context.rejected:
+ self._store_rejections_txn(txn, event.event_id, context.rejected)
+
+ if event.is_state():
+ vals = {
+ "event_id": event.event_id,
+ "room_id": event.room_id,
+ "type": event.type,
+ "state_key": event.state_key,
+ }
+
+ # TODO: How does this work with backfilling?
+ if hasattr(event, "replaces_state"):
+ vals["prev_state"] = event.replaces_state
+
+ self._simple_insert_txn(
+ txn,
+ "state_events",
+ vals,
+ )
+
+ if is_new_state and not context.rejected:
+ self._simple_insert_txn(
+ txn,
+ "current_state_events",
+ {
+ "event_id": event.event_id,
+ "room_id": event.room_id,
+ "type": event.type,
+ "state_key": event.state_key,
+ },
+ )
+
+ for e_id, h in event.prev_state:
+ self._simple_insert_txn(
+ txn,
+ table="event_edges",
+ values={
+ "event_id": event.event_id,
+ "prev_event_id": e_id,
+ "room_id": event.room_id,
+ "is_state": 1,
+ },
+ )
+
+ for hash_alg, hash_base64 in event.hashes.items():
+ hash_bytes = decode_base64(hash_base64)
+ self._store_event_content_hash_txn(
+ txn, event.event_id, hash_alg, hash_bytes,
+ )
+
+ for prev_event_id, prev_hashes in event.prev_events:
+ for alg, hash_base64 in prev_hashes.items():
+ hash_bytes = decode_base64(hash_base64)
+ self._store_prev_event_hash_txn(
+ txn, event.event_id, prev_event_id, alg, hash_bytes
+ )
+
+ for auth_id, _ in event.auth_events:
+ self._simple_insert_txn(
+ txn,
+ table="event_auth",
+ values={
+ "event_id": event.event_id,
+ "room_id": event.room_id,
+ "auth_id": auth_id,
+ },
+ )
+
+ (ref_alg, ref_hash_bytes) = compute_event_reference_hash(event)
+ self._store_event_reference_hash_txn(
+ txn, event.event_id, ref_alg, ref_hash_bytes
+ )
+
+ def _store_redaction(self, txn, event):
+ # invalidate the cache for the redacted event
+ self._get_event_cache.pop(event.redacts)
+ txn.execute(
+ "INSERT INTO redactions (event_id, redacts) VALUES (?,?)",
+ (event.event_id, event.redacts)
+ )
+
+ def have_events(self, event_ids):
+ """Given a list of event ids, check if we have already processed them.
+
+ Returns:
+ dict: Has an entry for each event id we already have seen. Maps to
+ the rejected reason string if we rejected the event, else maps to
+ None.
+ """
+ if not event_ids:
+ return defer.succeed({})
+
+ def f(txn):
+ sql = (
+ "SELECT e.event_id, reason FROM events as e "
+ "LEFT JOIN rejections as r ON e.event_id = r.event_id "
+ "WHERE e.event_id = ?"
+ )
+
+ res = {}
+ for event_id in event_ids:
+ txn.execute(sql, (event_id,))
+ row = txn.fetchone()
+ if row:
+ _, rejected = row
+ res[event_id] = rejected
+
+ return res
+
+ return self.runInteraction(
+ "have_events", f,
+ )
diff --git a/synapse/storage/feedback.py b/synapse/storage/feedback.py
deleted file mode 100644
index 8eab769b71..0000000000
--- a/synapse/storage/feedback.py
+++ /dev/null
@@ -1,47 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2014, 2015 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from twisted.internet import defer
-
-from ._base import SQLBaseStore
-
-
-class FeedbackStore(SQLBaseStore):
-
- def _store_feedback_txn(self, txn, event):
- self._simple_insert_txn(txn, "feedback", {
- "event_id": event.event_id,
- "feedback_type": event.content["type"],
- "room_id": event.room_id,
- "target_event_id": event.content["target_event_id"],
- "sender": event.user_id,
- })
-
- @defer.inlineCallbacks
- def get_feedback_for_event(self, event_id):
- sql = (
- "SELECT events.* FROM events INNER JOIN feedback "
- "ON events.event_id = feedback.event_id "
- "WHERE feedback.target_event_id = ? "
- )
-
- rows = yield self._execute_and_decode("get_feedback_for_event", sql, event_id)
-
- defer.returnValue(
- [
- (yield self._parse_events(r))
- for r in rows
- ]
- )
diff --git a/synapse/storage/filtering.py b/synapse/storage/filtering.py
index 457a11fd02..8800116570 100644
--- a/synapse/storage/filtering.py
+++ b/synapse/storage/filtering.py
@@ -31,6 +31,7 @@ class FilteringStore(SQLBaseStore):
},
retcol="filter_json",
allow_none=False,
+ desc="get_user_filter",
)
defer.returnValue(json.loads(def_json))
diff --git a/synapse/storage/media_repository.py b/synapse/storage/media_repository.py
index 7101d2beec..7bf57234f6 100644
--- a/synapse/storage/media_repository.py
+++ b/synapse/storage/media_repository.py
@@ -32,6 +32,7 @@ class MediaRepositoryStore(SQLBaseStore):
{"media_id": media_id},
("media_type", "media_length", "upload_name", "created_ts"),
allow_none=True,
+ desc="get_local_media",
)
def store_local_media(self, media_id, media_type, time_now_ms, upload_name,
@@ -45,7 +46,8 @@ class MediaRepositoryStore(SQLBaseStore):
"upload_name": upload_name,
"media_length": media_length,
"user_id": user_id.to_string(),
- }
+ },
+ desc="store_local_media",
)
def get_local_media_thumbnails(self, media_id):
@@ -55,7 +57,8 @@ class MediaRepositoryStore(SQLBaseStore):
(
"thumbnail_width", "thumbnail_height", "thumbnail_method",
"thumbnail_type", "thumbnail_length",
- )
+ ),
+ desc="get_local_media_thumbnails",
)
def store_local_thumbnail(self, media_id, thumbnail_width,
@@ -70,7 +73,8 @@ class MediaRepositoryStore(SQLBaseStore):
"thumbnail_method": thumbnail_method,
"thumbnail_type": thumbnail_type,
"thumbnail_length": thumbnail_length,
- }
+ },
+ desc="store_local_thumbnail",
)
def get_cached_remote_media(self, origin, media_id):
@@ -82,6 +86,7 @@ class MediaRepositoryStore(SQLBaseStore):
"filesystem_id",
),
allow_none=True,
+ desc="get_cached_remote_media",
)
def store_cached_remote_media(self, origin, media_id, media_type,
@@ -97,7 +102,8 @@ class MediaRepositoryStore(SQLBaseStore):
"created_ts": time_now_ms,
"upload_name": upload_name,
"filesystem_id": filesystem_id,
- }
+ },
+ desc="store_cached_remote_media",
)
def get_remote_media_thumbnails(self, origin, media_id):
@@ -107,7 +113,8 @@ class MediaRepositoryStore(SQLBaseStore):
(
"thumbnail_width", "thumbnail_height", "thumbnail_method",
"thumbnail_type", "thumbnail_length", "filesystem_id",
- )
+ ),
+ desc="get_remote_media_thumbnails",
)
def store_remote_media_thumbnail(self, origin, media_id, filesystem_id,
@@ -125,5 +132,6 @@ class MediaRepositoryStore(SQLBaseStore):
"thumbnail_type": thumbnail_type,
"thumbnail_length": thumbnail_length,
"filesystem_id": filesystem_id,
- }
+ },
+ desc="store_remote_media_thumbnail",
)
diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py
index 1dcd34723b..87fba55439 100644
--- a/synapse/storage/presence.py
+++ b/synapse/storage/presence.py
@@ -21,6 +21,7 @@ class PresenceStore(SQLBaseStore):
return self._simple_insert(
table="presence",
values={"user_id": user_localpart},
+ desc="create_presence",
)
def has_presence_state(self, user_localpart):
@@ -29,6 +30,7 @@ class PresenceStore(SQLBaseStore):
keyvalues={"user_id": user_localpart},
retcols=["user_id"],
allow_none=True,
+ desc="has_presence_state",
)
def get_presence_state(self, user_localpart):
@@ -36,6 +38,7 @@ class PresenceStore(SQLBaseStore):
table="presence",
keyvalues={"user_id": user_localpart},
retcols=["state", "status_msg", "mtime"],
+ desc="get_presence_state",
)
def set_presence_state(self, user_localpart, new_state):
@@ -45,7 +48,7 @@ class PresenceStore(SQLBaseStore):
updatevalues={"state": new_state["state"],
"status_msg": new_state["status_msg"],
"mtime": self._clock.time_msec()},
- retcols=["state"],
+ desc="set_presence_state",
)
def allow_presence_visible(self, observed_localpart, observer_userid):
@@ -53,6 +56,7 @@ class PresenceStore(SQLBaseStore):
table="presence_allow_inbound",
values={"observed_user_id": observed_localpart,
"observer_user_id": observer_userid},
+ desc="allow_presence_visible",
)
def disallow_presence_visible(self, observed_localpart, observer_userid):
@@ -60,6 +64,7 @@ class PresenceStore(SQLBaseStore):
table="presence_allow_inbound",
keyvalues={"observed_user_id": observed_localpart,
"observer_user_id": observer_userid},
+ desc="disallow_presence_visible",
)
def is_presence_visible(self, observed_localpart, observer_userid):
@@ -69,6 +74,7 @@ class PresenceStore(SQLBaseStore):
"observer_user_id": observer_userid},
retcols=["observed_user_id"],
allow_none=True,
+ desc="is_presence_visible",
)
def add_presence_list_pending(self, observer_localpart, observed_userid):
@@ -77,6 +83,7 @@ class PresenceStore(SQLBaseStore):
values={"user_id": observer_localpart,
"observed_user_id": observed_userid,
"accepted": False},
+ desc="add_presence_list_pending",
)
def set_presence_list_accepted(self, observer_localpart, observed_userid):
@@ -85,6 +92,7 @@ class PresenceStore(SQLBaseStore):
keyvalues={"user_id": observer_localpart,
"observed_user_id": observed_userid},
updatevalues={"accepted": True},
+ desc="set_presence_list_accepted",
)
def get_presence_list(self, observer_localpart, accepted=None):
@@ -96,6 +104,7 @@ class PresenceStore(SQLBaseStore):
table="presence_list",
keyvalues=keyvalues,
retcols=["observed_user_id", "accepted"],
+ desc="get_presence_list",
)
def del_presence_list(self, observer_localpart, observed_userid):
@@ -103,4 +112,5 @@ class PresenceStore(SQLBaseStore):
table="presence_list",
keyvalues={"user_id": observer_localpart,
"observed_user_id": observed_userid},
+ desc="del_presence_list",
)
diff --git a/synapse/storage/profile.py b/synapse/storage/profile.py
index 153c7ad027..a6e52cb248 100644
--- a/synapse/storage/profile.py
+++ b/synapse/storage/profile.py
@@ -21,6 +21,7 @@ class ProfileStore(SQLBaseStore):
return self._simple_insert(
table="profiles",
values={"user_id": user_localpart},
+ desc="create_profile",
)
def get_profile_displayname(self, user_localpart):
@@ -28,6 +29,7 @@ class ProfileStore(SQLBaseStore):
table="profiles",
keyvalues={"user_id": user_localpart},
retcol="displayname",
+ desc="get_profile_displayname",
)
def set_profile_displayname(self, user_localpart, new_displayname):
@@ -35,6 +37,7 @@ class ProfileStore(SQLBaseStore):
table="profiles",
keyvalues={"user_id": user_localpart},
updatevalues={"displayname": new_displayname},
+ desc="set_profile_displayname",
)
def get_profile_avatar_url(self, user_localpart):
@@ -42,6 +45,7 @@ class ProfileStore(SQLBaseStore):
table="profiles",
keyvalues={"user_id": user_localpart},
retcol="avatar_url",
+ desc="get_profile_avatar_url",
)
def set_profile_avatar_url(self, user_localpart, new_avatar_url):
@@ -49,4 +53,5 @@ class ProfileStore(SQLBaseStore):
table="profiles",
keyvalues={"user_id": user_localpart},
updatevalues={"avatar_url": new_avatar_url},
+ desc="set_profile_avatar_url",
)
diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py
index d769db2c78..c47bdc2861 100644
--- a/synapse/storage/push_rule.py
+++ b/synapse/storage/push_rule.py
@@ -50,7 +50,8 @@ class PushRuleStore(SQLBaseStore):
results = yield self._simple_select_list(
PushRuleEnableTable.table_name,
{'user_name': user_name},
- PushRuleEnableTable.fields
+ PushRuleEnableTable.fields,
+ desc="get_push_rules_enabled_for_user",
)
defer.returnValue(
{r['rule_id']: False if r['enabled'] == 0 else True for r in results}
@@ -201,7 +202,8 @@ class PushRuleStore(SQLBaseStore):
"""
yield self._simple_delete_one(
PushRuleTable.table_name,
- {'user_name': user_name, 'rule_id': rule_id}
+ {'user_name': user_name, 'rule_id': rule_id},
+ desc="delete_push_rule",
)
@defer.inlineCallbacks
@@ -209,7 +211,8 @@ class PushRuleStore(SQLBaseStore):
yield self._simple_upsert(
PushRuleEnableTable.table_name,
{'user_name': user_name, 'rule_id': rule_id},
- {'enabled': enabled}
+ {'enabled': enabled},
+ desc="set_push_rule_enabled",
)
diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py
index 587dada68f..000502b4ff 100644
--- a/synapse/storage/pusher.py
+++ b/synapse/storage/pusher.py
@@ -114,7 +114,9 @@ class PusherStore(SQLBaseStore):
ts=pushkey_ts,
lang=lang,
data=data
- ))
+ ),
+ desc="add_pusher",
+ )
except Exception as e:
logger.error("create_pusher with failed: %s", e)
raise StoreError(500, "Problem creating pusher.")
@@ -123,7 +125,8 @@ class PusherStore(SQLBaseStore):
def delete_pusher_by_app_id_pushkey(self, app_id, pushkey):
yield self._simple_delete_one(
PushersTable.table_name,
- dict(app_id=app_id, pushkey=pushkey)
+ {"app_id": app_id, "pushkey": pushkey},
+ desc="delete_pusher_by_app_id_pushkey",
)
@defer.inlineCallbacks
@@ -131,7 +134,8 @@ class PusherStore(SQLBaseStore):
yield self._simple_update_one(
PushersTable.table_name,
{'app_id': app_id, 'pushkey': pushkey},
- {'last_token': last_token}
+ {'last_token': last_token},
+ desc="update_pusher_last_token",
)
@defer.inlineCallbacks
@@ -140,7 +144,8 @@ class PusherStore(SQLBaseStore):
yield self._simple_update_one(
PushersTable.table_name,
{'app_id': app_id, 'pushkey': pushkey},
- {'last_token': last_token, 'last_success': last_success}
+ {'last_token': last_token, 'last_success': last_success},
+ desc="update_pusher_last_token_and_success",
)
@defer.inlineCallbacks
@@ -148,7 +153,8 @@ class PusherStore(SQLBaseStore):
yield self._simple_update_one(
PushersTable.table_name,
{'app_id': app_id, 'pushkey': pushkey},
- {'failing_since': failing_since}
+ {'failing_since': failing_since},
+ desc="update_pusher_failing_since",
)
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index adc8fc0794..f24154f146 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -19,7 +19,7 @@ from sqlite3 import IntegrityError
from synapse.api.errors import StoreError, Codes
-from ._base import SQLBaseStore
+from ._base import SQLBaseStore, cached
class RegistrationStore(SQLBaseStore):
@@ -39,7 +39,10 @@ class RegistrationStore(SQLBaseStore):
Raises:
StoreError if there was a problem adding this.
"""
- row = yield self._simple_select_one("users", {"name": user_id}, ["id"])
+ row = yield self._simple_select_one(
+ "users", {"name": user_id}, ["id"],
+ desc="add_access_token_to_user",
+ )
if not row:
raise StoreError(400, "Bad user ID supplied.")
row_id = row["id"]
@@ -48,7 +51,8 @@ class RegistrationStore(SQLBaseStore):
{
"user_id": row_id,
"token": token
- }
+ },
+ desc="add_access_token_to_user",
)
@defer.inlineCallbacks
@@ -91,6 +95,11 @@ class RegistrationStore(SQLBaseStore):
"get_user_by_id", self.cursor_to_dict, query, user_id
)
+ @cached()
+ # TODO(paul): Currently there's no code to invalidate this cache. That
+ # means if/when we ever add internal ways to invalidate access tokens or
+ # change whether a user is a server admin, those will need to invoke
+ # store.get_user_by_token.invalidate(token)
def get_user_by_token(self, token):
"""Get a user from the given access token.
@@ -115,6 +124,7 @@ class RegistrationStore(SQLBaseStore):
keyvalues={"name": user.to_string()},
retcol="admin",
allow_none=True,
+ desc="is_server_admin",
)
defer.returnValue(res if res else False)
diff --git a/synapse/storage/rejections.py b/synapse/storage/rejections.py
index 4e1a9a2783..0838eb3d12 100644
--- a/synapse/storage/rejections.py
+++ b/synapse/storage/rejections.py
@@ -29,7 +29,7 @@ class RejectionsStore(SQLBaseStore):
"event_id": event_id,
"reason": reason,
"last_check": self._clock.time_msec(),
- }
+ },
)
def get_rejection_reason(self, event_id):
@@ -40,4 +40,5 @@ class RejectionsStore(SQLBaseStore):
"event_id": event_id,
},
allow_none=True,
+ desc="get_rejection_reason",
)
diff --git a/synapse/storage/room.py b/synapse/storage/room.py
index 549c9af393..be3e28c2ea 100644
--- a/synapse/storage/room.py
+++ b/synapse/storage/room.py
@@ -15,11 +15,9 @@
from twisted.internet import defer
-from sqlite3 import IntegrityError
-
from synapse.api.errors import StoreError
-from ._base import SQLBaseStore, Table
+from ._base import SQLBaseStore
import collections
import logging
@@ -27,8 +25,9 @@ import logging
logger = logging.getLogger(__name__)
-OpsLevel = collections.namedtuple("OpsLevel", (
- "ban_level", "kick_level", "redact_level")
+OpsLevel = collections.namedtuple(
+ "OpsLevel",
+ ("ban_level", "kick_level", "redact_level",)
)
@@ -47,13 +46,15 @@ class RoomStore(SQLBaseStore):
StoreError if the room could not be stored.
"""
try:
- yield self._simple_insert(RoomsTable.table_name, dict(
- room_id=room_id,
- creator=room_creator_user_id,
- is_public=is_public
- ))
- except IntegrityError:
- raise StoreError(409, "Room ID in use.")
+ yield self._simple_insert(
+ RoomsTable.table_name,
+ {
+ "room_id": room_id,
+ "creator": room_creator_user_id,
+ "is_public": is_public,
+ },
+ desc="store_room",
+ )
except Exception as e:
logger.error("store_room with room_id=%s failed: %s", room_id, e)
raise StoreError(500, "Problem creating room.")
@@ -66,9 +67,11 @@ class RoomStore(SQLBaseStore):
Returns:
A namedtuple containing the room information, or an empty list.
"""
- query = RoomsTable.select_statement("room_id=?")
- return self._execute(
- "get_room", RoomsTable.decode_single_result, query, room_id,
+ return self._simple_select_one(
+ table=RoomsTable.table_name,
+ keyvalues={"room_id": room_id},
+ retcols=RoomsTable.fields,
+ desc="get_room",
)
@defer.inlineCallbacks
@@ -143,7 +146,7 @@ class RoomStore(SQLBaseStore):
"event_id": event.event_id,
"room_id": event.room_id,
"topic": event.content["topic"],
- }
+ },
)
def _store_room_name_txn(self, txn, event):
@@ -158,8 +161,45 @@ class RoomStore(SQLBaseStore):
}
)
+ @defer.inlineCallbacks
+ def get_room_name_and_aliases(self, room_id):
+ del_sql = (
+ "SELECT event_id FROM redactions WHERE redacts = e.event_id "
+ "LIMIT 1"
+ )
+
+ sql = (
+ "SELECT e.*, (%(redacted)s) AS redacted FROM events as e "
+ "INNER JOIN current_state_events as c ON e.event_id = c.event_id "
+ "INNER JOIN state_events as s ON e.event_id = s.event_id "
+ "WHERE c.room_id = ? "
+ ) % {
+ "redacted": del_sql,
+ }
+
+ sql += " AND ((s.type = 'm.room.name' AND s.state_key = '')"
+ sql += " OR s.type = 'm.room.aliases')"
+ args = (room_id,)
-class RoomsTable(Table):
+ results = yield self._execute_and_decode("get_current_state", sql, *args)
+
+ events = yield self._parse_events(results)
+
+ name = None
+ aliases = []
+
+ for e in events:
+ if e.type == 'm.room.name':
+ if 'name' in e.content:
+ name = e.content['name']
+ elif e.type == 'm.room.aliases':
+ if 'aliases' in e.content:
+ aliases.extend(e.content['aliases'])
+
+ defer.returnValue((name, aliases))
+
+
+class RoomsTable(object):
table_name = "rooms"
fields = [
@@ -167,5 +207,3 @@ class RoomsTable(Table):
"is_public",
"creator"
]
-
- EntryType = collections.namedtuple("RoomEntry", fields)
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 65ffb4627f..52c37c76f5 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -212,7 +212,8 @@ class RoomMemberStore(SQLBaseStore):
return self._simple_select_onecol(
"room_hosts",
{"room_id": room_id},
- "host"
+ "host",
+ desc="get_joined_hosts_for_room",
)
def _get_members_by_dict(self, where_dict):
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 71db16d0e5..58dbf2802b 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -15,6 +15,8 @@
from ._base import SQLBaseStore
+from twisted.internet import defer
+
import logging
logger = logging.getLogger(__name__)
@@ -82,7 +84,7 @@ class StateStore(SQLBaseStore):
if context.current_state is None:
return
- state_events = context.current_state
+ state_events = dict(context.current_state)
if event.is_state():
state_events[(event.type, event.state_key)] = event
@@ -122,3 +124,33 @@ class StateStore(SQLBaseStore):
},
or_replace=True,
)
+
+ @defer.inlineCallbacks
+ def get_current_state(self, room_id, event_type=None, state_key=""):
+ del_sql = (
+ "SELECT event_id FROM redactions WHERE redacts = e.event_id "
+ "LIMIT 1"
+ )
+
+ sql = (
+ "SELECT e.*, (%(redacted)s) AS redacted FROM events as e "
+ "INNER JOIN current_state_events as c ON e.event_id = c.event_id "
+ "INNER JOIN state_events as s ON e.event_id = s.event_id "
+ "WHERE c.room_id = ? "
+ ) % {
+ "redacted": del_sql,
+ }
+
+ if event_type and state_key is not None:
+ sql += " AND s.type = ? AND s.state_key = ? "
+ args = (room_id, event_type, state_key)
+ elif event_type:
+ sql += " AND s.type = ?"
+ args = (room_id, event_type)
+ else:
+ args = (room_id, )
+
+ results = yield self._execute_and_decode("get_current_state", sql, *args)
+
+ events = yield self._parse_events(results)
+ defer.returnValue(events)
diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py
index 09bc522210..66f307e640 100644
--- a/synapse/storage/stream.py
+++ b/synapse/storage/stream.py
@@ -35,7 +35,7 @@ what sort order was used:
from twisted.internet import defer
-from ._base import SQLBaseStore
+from ._base import SQLBaseStore, cached
from synapse.api.constants import EventTypes
from synapse.api.errors import SynapseError
from synapse.util.logutils import log_function
@@ -413,12 +413,32 @@ class StreamStore(SQLBaseStore):
"get_recent_events_for_room", get_recent_events_for_room_txn
)
+ @cached(num_args=0)
def get_room_events_max_id(self):
return self.runInteraction(
"get_room_events_max_id",
self._get_room_events_max_id_txn
)
+ @defer.inlineCallbacks
+ def _get_min_token(self):
+ row = yield self._execute(
+ "_get_min_token", None, "SELECT MIN(stream_ordering) FROM events"
+ )
+
+ self.min_token = row[0][0] if row and row[0] and row[0][0] else -1
+ self.min_token = min(self.min_token, -1)
+
+ logger.debug("min_token is: %s", self.min_token)
+
+ defer.returnValue(self.min_token)
+
+ def get_next_stream_id(self):
+ with self._next_stream_id_lock:
+ i = self._next_stream_id
+ self._next_stream_id += 1
+ return i
+
def _get_room_events_max_id_txn(self, txn):
txn.execute(
"SELECT MAX(stream_ordering) as m FROM events"
diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py
index 0b8a3b7a07..b777395e06 100644
--- a/synapse/storage/transactions.py
+++ b/synapse/storage/transactions.py
@@ -46,15 +46,19 @@ class TransactionStore(SQLBaseStore):
)
def _get_received_txn_response(self, txn, transaction_id, origin):
- where_clause = "transaction_id = ? AND origin = ?"
- query = ReceivedTransactionsTable.select_statement(where_clause)
-
- txn.execute(query, (transaction_id, origin))
-
- results = ReceivedTransactionsTable.decode_results(txn.fetchall())
+ result = self._simple_select_one_txn(
+ txn,
+ table=ReceivedTransactionsTable.table_name,
+ keyvalues={
+ "transaction_id": transaction_id,
+ "origin": origin,
+ },
+ retcols=ReceivedTransactionsTable.fields,
+ allow_none=True,
+ )
- if results and results[0].response_code:
- return (results[0].response_code, results[0].response_json)
+ if result and result.response_code:
+ return result["response_code"], result["response_json"]
else:
return None
diff --git a/synapse/util/lrucache.py b/synapse/util/lrucache.py
index 65d5792907..2f7b615f78 100644
--- a/synapse/util/lrucache.py
+++ b/synapse/util/lrucache.py
@@ -90,12 +90,16 @@ class LruCache(object):
def cache_len():
return len(cache)
+ def cache_contains(key):
+ return key in cache
+
self.sentinel = object()
self.get = cache_get
self.set = cache_set
self.setdefault = cache_set_default
self.pop = cache_pop
self.len = cache_len
+ self.contains = cache_contains
def __getitem__(self, key):
result = self.get(key, self.sentinel)
@@ -114,3 +118,6 @@ class LruCache(object):
def __len__(self):
return self.len()
+
+ def __contains__(self, key):
+ return self.contains(key)
diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py
index ea53a8085c..52e66beaee 100644
--- a/synapse/util/stringutils.py
+++ b/synapse/util/stringutils.py
@@ -16,6 +16,10 @@
import random
import string
+_string_with_symbols = (
+ string.digits + string.ascii_letters + ".,;:^&*-_+=#~@"
+)
+
def origin_from_ucid(ucid):
return ucid.split("@", 1)[1]
@@ -23,3 +27,9 @@ def origin_from_ucid(ucid):
def random_string(length):
return ''.join(random.choice(string.ascii_letters) for _ in xrange(length))
+
+
+def random_string_with_symbols(length):
+ return ''.join(
+ random.choice(_string_with_symbols) for _ in xrange(length)
+ )
diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py
index 55d22f665a..96caf8c4c1 100644
--- a/tests/storage/test__base.py
+++ b/tests/storage/test__base.py
@@ -17,7 +17,79 @@
from tests import unittest
from twisted.internet import defer
-from synapse.storage._base import cached
+from synapse.storage._base import Cache, cached
+
+
+class CacheTestCase(unittest.TestCase):
+
+ def setUp(self):
+ self.cache = Cache("test")
+
+ def test_empty(self):
+ failed = False
+ try:
+ self.cache.get("foo")
+ except KeyError:
+ failed = True
+
+ self.assertTrue(failed)
+
+ def test_hit(self):
+ self.cache.prefill("foo", 123)
+
+ self.assertEquals(self.cache.get("foo"), 123)
+
+ def test_invalidate(self):
+ self.cache.prefill("foo", 123)
+ self.cache.invalidate("foo")
+
+ failed = False
+ try:
+ self.cache.get("foo")
+ except KeyError:
+ failed = True
+
+ self.assertTrue(failed)
+
+ def test_eviction(self):
+ cache = Cache("test", max_entries=2)
+
+ cache.prefill(1, "one")
+ cache.prefill(2, "two")
+ cache.prefill(3, "three") # 1 will be evicted
+
+ failed = False
+ try:
+ cache.get(1)
+ except KeyError:
+ failed = True
+
+ self.assertTrue(failed)
+
+ cache.get(2)
+ cache.get(3)
+
+ def test_eviction_lru(self):
+ cache = Cache("test", max_entries=2, lru=True)
+
+ cache.prefill(1, "one")
+ cache.prefill(2, "two")
+
+ # Now access 1 again, thus causing 2 to be least-recently used
+ cache.get(1)
+
+ cache.prefill(3, "three")
+
+ failed = False
+ try:
+ cache.get(2)
+ except KeyError:
+ failed = True
+
+ self.assertTrue(failed)
+
+ cache.get(1)
+ cache.get(3)
class CacheDecoratorTestCase(unittest.TestCase):
diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py
index 55fbffa7a2..7f5845cf0c 100644
--- a/tests/storage/test_base.py
+++ b/tests/storage/test_base.py
@@ -180,7 +180,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.mock_txn.rowcount = 1
self.mock_txn.fetchone.return_value = ("Old Value",)
- ret = yield self.datastore._simple_update_one(
+ ret = yield self.datastore._simple_selectupdate_one(
table="tablename",
keyvalues={"keycol": "TheKey"},
updatevalues={"columname": "New Value"},
diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py
index c88dd446fb..ab7625a3ca 100644
--- a/tests/storage/test_room.py
+++ b/tests/storage/test_room.py
@@ -44,7 +44,7 @@ class RoomStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_get_room(self):
- self.assertObjectHasAttributes(
+ self.assertDictContainsSubset(
{"room_id": self.room.to_string(),
"creator": self.u_creator.to_string(),
"is_public": True},
|