diff options
author | Mark Haines <mark.haines@matrix.org> | 2015-04-29 13:15:14 +0100 |
---|---|---|
committer | Mark Haines <mark.haines@matrix.org> | 2015-04-29 13:15:14 +0100 |
commit | 4ad8b451559682645c818f6c1180e5f9d42a7eeb (patch) | |
tree | 94fe0beaaaa4ed8c1bb9451f1274df819d9ce8d7 /synapse | |
parent | Update the query format used by keyring to match current key v2 spec (diff) | |
parent | Mention that postgres databases must have the correct charset encoding (diff) | |
download | synapse-4ad8b451559682645c818f6c1180e5f9d42a7eeb.tar.xz |
Merge branch 'develop' into key_distribution
Conflicts: synapse/config/homeserver.py
Diffstat (limited to 'synapse')
95 files changed, 3332 insertions, 1011 deletions
diff --git a/synapse/api/auth.py b/synapse/api/auth.py index a21120b313..77322a5c10 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -18,7 +18,7 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, Membership, JoinRules -from synapse.api.errors import AuthError, StoreError, Codes, SynapseError +from synapse.api.errors import AuthError, Codes, SynapseError from synapse.util.logutils import log_function from synapse.util.async import run_on_reactor from synapse.types import UserID, ClientInfo @@ -40,6 +40,7 @@ class Auth(object): self.hs = hs self.store = hs.get_datastore() self.state = hs.get_state_handler() + self.TOKEN_NOT_FOUND_HTTP_STATUS = 401 def check(self, event, auth_events): """ Checks if this event is correctly authed. @@ -222,6 +223,13 @@ class Auth(object): elif target_in_room: # the target is already in the room. raise AuthError(403, "%s is already in the room." % target_user_id) + else: + invite_level = self._get_named_level(auth_events, "invite", 0) + + if user_level < invite_level: + raise AuthError( + 403, "You cannot invite user %s." % target_user_id + ) elif Membership.JOIN == membership: # Joins are valid iff caller == target and they were: # invited: They are accepting the invitation @@ -362,7 +370,10 @@ class Auth(object): defer.returnValue((user, ClientInfo(device_id, token_id))) except KeyError: - raise AuthError(403, "Missing access token.") + raise AuthError( + self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.", + errcode=Codes.MISSING_TOKEN + ) @defer.inlineCallbacks def get_user_by_token(self, token): @@ -376,21 +387,20 @@ class Auth(object): Raises: AuthError if no user by that token exists or the token is invalid. """ - try: - ret = yield self.store.get_user_by_token(token) - if not ret: - raise StoreError(400, "Unknown token") - user_info = { - "admin": bool(ret.get("admin", False)), - "device_id": ret.get("device_id"), - "user": UserID.from_string(ret.get("name")), - "token_id": ret.get("token_id", None), - } + ret = yield self.store.get_user_by_token(token) + if not ret: + raise AuthError( + self.TOKEN_NOT_FOUND_HTTP_STATUS, "Unrecognised access token.", + errcode=Codes.UNKNOWN_TOKEN + ) + user_info = { + "admin": bool(ret.get("admin", False)), + "device_id": ret.get("device_id"), + "user": UserID.from_string(ret.get("name")), + "token_id": ret.get("token_id", None), + } - defer.returnValue(user_info) - except StoreError: - raise AuthError(403, "Unrecognised access token.", - errcode=Codes.UNKNOWN_TOKEN) + defer.returnValue(user_info) @defer.inlineCallbacks def get_appservice_by_req(self, request): @@ -398,11 +408,16 @@ class Auth(object): token = request.args["access_token"][0] service = yield self.store.get_app_service_by_token(token) if not service: - raise AuthError(403, "Unrecognised access token.", - errcode=Codes.UNKNOWN_TOKEN) + raise AuthError( + self.TOKEN_NOT_FOUND_HTTP_STATUS, + "Unrecognised access token.", + errcode=Codes.UNKNOWN_TOKEN + ) defer.returnValue(service) except KeyError: - raise AuthError(403, "Missing access token.") + raise AuthError( + self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token." + ) def is_server_admin(self, user): return self.store.is_server_admin(user) @@ -561,6 +576,7 @@ class Auth(object): ("ban", []), ("redact", []), ("kick", []), + ("invite", []), ] old_list = current_state.content.get("users") diff --git a/synapse/api/constants.py b/synapse/api/constants.py index b16bf4247d..d8a18ee87b 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -59,6 +59,9 @@ class LoginType(object): EMAIL_URL = u"m.login.email.url" EMAIL_IDENTITY = u"m.login.email.identity" RECAPTCHA = u"m.login.recaptcha" + DUMMY = u"m.login.dummy" + + # Only for C/S API v1 APPLICATION_SERVICE = u"m.login.application_service" SHARED_SECRET = u"org.matrix.login.shared_secret" diff --git a/synapse/api/errors.py b/synapse/api/errors.py index 72d2bd5b4c..0b3320e62c 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -31,6 +31,7 @@ class Codes(object): BAD_PAGINATION = "M_BAD_PAGINATION" UNKNOWN = "M_UNKNOWN" NOT_FOUND = "M_NOT_FOUND" + MISSING_TOKEN = "M_MISSING_TOKEN" UNKNOWN_TOKEN = "M_UNKNOWN_TOKEN" LIMIT_EXCEEDED = "M_LIMIT_EXCEEDED" CAPTCHA_NEEDED = "M_CAPTCHA_NEEDED" @@ -38,6 +39,7 @@ class Codes(object): MISSING_PARAM = "M_MISSING_PARAM" TOO_LARGE = "M_TOO_LARGE" EXCLUSIVE = "M_EXCLUSIVE" + THREEPID_AUTH_FAILED = "M_THREEPID_AUTH_FAILED" class CodeMessageException(RuntimeError): diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index e681941612..bedf797bd2 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -17,8 +17,9 @@ import sys sys.dont_write_bytecode = True +from synapse.storage.engines import create_engine from synapse.storage import ( - prepare_database, prepare_sqlite3_database, UpgradeDatabaseException, + are_all_users_on_domain, UpgradeDatabaseException, ) from synapse.server import HomeServer @@ -59,9 +60,9 @@ import os import re import resource import subprocess -import sqlite3 -logger = logging.getLogger(__name__) + +logger = logging.getLogger("synapse.app.homeserver") class SynapseHomeServer(HomeServer): @@ -108,13 +109,11 @@ class SynapseHomeServer(HomeServer): return None def build_db_pool(self): + name = self.db_config["name"] + return adbapi.ConnectionPool( - "sqlite3", self.get_db_name(), - check_same_thread=False, - cp_min=1, - cp_max=1, - cp_openfun=prepare_database, # Prepare the database for each conn - # so that :memory: sqlite works + name, + **self.db_config.get("args", {}) ) def create_resource_tree(self, redirect_root_to_web_client): @@ -247,6 +246,21 @@ class SynapseHomeServer(HomeServer): ) logger.info("Metrics now running on 127.0.0.1 port %d", config.metrics_port) + def run_startup_checks(self, db_conn, database_engine): + all_users_native = are_all_users_on_domain( + db_conn.cursor(), database_engine, self.hostname + ) + if not all_users_native: + sys.stderr.write( + "\n" + "******************************************************\n" + "Found users in database not native to %s!\n" + "You cannot changed a synapse server_name after it's been configured\n" + "******************************************************\n" + "\n" % (self.hostname,) + ) + sys.exit(1) + def get_version_string(): try: @@ -358,15 +372,20 @@ def setup(config_options): tls_context_factory = context_factory.ServerContextFactory(config) + database_engine = create_engine(config.database_config["name"]) + config.database_config["args"]["cp_openfun"] = database_engine.on_new_connection + hs = SynapseHomeServer( config.server_name, domain_with_port=domain_with_port, upload_dir=os.path.abspath("uploads"), db_name=config.database_path, + db_config=config.database_config, tls_context_factory=tls_context_factory, config=config, content_addr=config.content_addr, version_string=version_string, + database_engine=database_engine, ) hs.create_resource_tree( @@ -378,9 +397,17 @@ def setup(config_options): logger.info("Preparing database: %s...", db_name) try: - with sqlite3.connect(db_name) as db_conn: - prepare_sqlite3_database(db_conn) - prepare_database(db_conn) + db_conn = database_engine.module.connect( + **{ + k: v for k, v in config.database_config.get("args", {}).items() + if not k.startswith("cp_") + } + ) + + database_engine.prepare_database(db_conn) + hs.run_startup_checks(db_conn, database_engine) + + db_conn.commit() except UpgradeDatabaseException: sys.stderr.write( "\nFailed to upgrade database.\n" diff --git a/synapse/config/_base.py b/synapse/config/_base.py index 6fd086a471..b59f4e45e2 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -158,9 +158,10 @@ class Config(object): and value is not None): config[key] = value with open(config_args.config_path, "w") as config_file: - # TODO(paul) it would be lovely if we wrote out vim- and emacs- - # style mode markers into the file, to hint to people that - # this is a YAML file. + # TODO(mark/paul) We might want to output emacs-style mode + # markers as well as vim-style mode markers into the file, + # to further hint to people this is a YAML file. + config_file.write("# vim:ft=yaml\n") yaml.dump(config, config_file, default_flow_style=False) print ( "A config file has been generated in %s for server name" diff --git a/synapse/config/captcha.py b/synapse/config/captcha.py index 7e21c7414d..07fbfadc0f 100644 --- a/synapse/config/captcha.py +++ b/synapse/config/captcha.py @@ -20,6 +20,7 @@ class CaptchaConfig(Config): def __init__(self, args): super(CaptchaConfig, self).__init__(args) self.recaptcha_private_key = args.recaptcha_private_key + self.recaptcha_public_key = args.recaptcha_public_key self.enable_registration_captcha = args.enable_registration_captcha self.captcha_ip_origin_is_x_forwarded = ( args.captcha_ip_origin_is_x_forwarded @@ -31,8 +32,12 @@ class CaptchaConfig(Config): super(CaptchaConfig, cls).add_arguments(parser) group = parser.add_argument_group("recaptcha") group.add_argument( + "--recaptcha-public-key", type=str, default="YOUR_PUBLIC_KEY", + help="This Home Server's ReCAPTCHA public key." + ) + group.add_argument( "--recaptcha-private-key", type=str, default="YOUR_PRIVATE_KEY", - help="The matching private key for the web client's public key." + help="This Home Server's ReCAPTCHA private key." ) group.add_argument( "--enable-registration-captcha", type=bool, default=False, diff --git a/synapse/config/database.py b/synapse/config/database.py index 87efe54645..190d119df4 100644 --- a/synapse/config/database.py +++ b/synapse/config/database.py @@ -15,6 +15,7 @@ from ._base import Config import os +import yaml class DatabaseConfig(Config): @@ -26,18 +27,45 @@ class DatabaseConfig(Config): self.database_path = self.abspath(args.database_path) self.event_cache_size = self.parse_size(args.event_cache_size) + if args.database_config: + with open(args.database_config) as f: + self.database_config = yaml.safe_load(f) + else: + self.database_config = { + "name": "sqlite3", + "args": { + "database": self.database_path, + }, + } + + name = self.database_config.get("name", None) + if name == "psycopg2": + pass + elif name == "sqlite3": + self.database_config.setdefault("args", {}).update({ + "cp_min": 1, + "cp_max": 1, + "check_same_thread": False, + }) + else: + raise RuntimeError("Unsupported database type '%s'" % (name,)) + @classmethod def add_arguments(cls, parser): super(DatabaseConfig, cls).add_arguments(parser) db_group = parser.add_argument_group("database") db_group.add_argument( "-d", "--database-path", default="homeserver.db", - help="The database name." + metavar="SQLITE_DATABASE_PATH", help="The database name." ) db_group.add_argument( "--event-cache-size", default="100K", help="Number of events to cache in memory." ) + db_group.add_argument( + "--database-config", default=None, + help="Location of the database configuration file." + ) @classmethod def generate_config(cls, args, config_dir_path): diff --git a/synapse/config/email.py b/synapse/config/email.py deleted file mode 100644 index f0854f8c37..0000000000 --- a/synapse/config/email.py +++ /dev/null @@ -1,42 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014, 2015 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from ._base import Config - - -class EmailConfig(Config): - - def __init__(self, args): - super(EmailConfig, self).__init__(args) - self.email_from_address = args.email_from_address - self.email_smtp_server = args.email_smtp_server - - @classmethod - def add_arguments(cls, parser): - super(EmailConfig, cls).add_arguments(parser) - email_group = parser.add_argument_group("email") - email_group.add_argument( - "--email-from-address", - default="FROM@EXAMPLE.COM", - help="The address to send emails from (e.g. for password resets)." - ) - email_group.add_argument( - "--email-smtp-server", - default="", - help=( - "The SMTP server to send emails from (e.g. for password" - " resets)." - ) - ) diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py index 967a0f45d6..1c8ff38465 100644 --- a/synapse/config/homeserver.py +++ b/synapse/config/homeserver.py @@ -20,7 +20,6 @@ from .database import DatabaseConfig from .ratelimiting import RatelimitConfig from .repository import ContentRepositoryConfig from .captcha import CaptchaConfig -from .email import EmailConfig from .voip import VoipConfig from .registration import RegistrationConfig from .metrics import MetricsConfig @@ -30,7 +29,7 @@ from .key import KeyConfig class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig, RatelimitConfig, ContentRepositoryConfig, CaptchaConfig, - EmailConfig, VoipConfig, RegistrationConfig, + VoipConfig, RegistrationConfig, MetricsConfig, AppServiceConfig, KeyConfig,): pass diff --git a/synapse/config/logger.py b/synapse/config/logger.py index 63c8e36930..247b324816 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py @@ -78,7 +78,6 @@ class LoggingConfig(Config): handler.addFilter(LoggingContextFilter(request="")) logger.addHandler(handler) - logger.info("Test") else: with open(self.log_config, 'r') as f: logging.config.dictConfig(yaml.load(f)) diff --git a/synapse/handlers/__init__.py b/synapse/handlers/__init__.py index 0c51d615ec..685792dbdc 100644 --- a/synapse/handlers/__init__.py +++ b/synapse/handlers/__init__.py @@ -30,6 +30,8 @@ from .typing import TypingNotificationHandler from .admin import AdminHandler from .appservice import ApplicationServicesHandler from .sync import SyncHandler +from .auth import AuthHandler +from .identity import IdentityHandler class Handlers(object): @@ -64,3 +66,5 @@ class Handlers(object): ) ) self.sync_handler = SyncHandler(hs) + self.auth_handler = AuthHandler(hs) + self.identity_handler = IdentityHandler(hs) diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index 48816a242d..dffb033fbd 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -16,7 +16,6 @@ from twisted.internet import defer from synapse.api.errors import LimitExceededError, SynapseError -from synapse.util.async import run_on_reactor from synapse.crypto.event_signing import add_hashes_and_signatures from synapse.api.constants import Membership, EventTypes from synapse.types import UserID @@ -58,8 +57,6 @@ class BaseHandler(object): @defer.inlineCallbacks def _create_new_client_event(self, builder): - yield run_on_reactor() - latest_ret = yield self.store.get_latest_events_in_room( builder.room_id, ) @@ -101,8 +98,6 @@ class BaseHandler(object): @defer.inlineCallbacks def handle_new_client_event(self, event, context, extra_destinations=[], extra_users=[], suppress_auth=False): - yield run_on_reactor() - # We now need to go and hit out to wherever we need to hit out to. if not suppress_auth: @@ -143,7 +138,9 @@ class BaseHandler(object): ) # Don't block waiting on waking up all the listeners. - d = self.notifier.on_new_room_event(event, extra_users=extra_users) + notify_d = self.notifier.on_new_room_event( + event, extra_users=extra_users + ) def log_failure(f): logger.warn( @@ -151,8 +148,10 @@ class BaseHandler(object): event.event_id, f.value ) - d.addErrback(log_failure) + notify_d.addErrback(log_failure) - yield federation_handler.handle_new_event( + fed_d = federation_handler.handle_new_event( event, destinations=destinations, ) + + fed_d.addErrback(log_failure) diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py new file mode 100644 index 0000000000..2e8009d3c3 --- /dev/null +++ b/synapse/handlers/auth.py @@ -0,0 +1,277 @@ +# -*- coding: utf-8 -*- +# Copyright 2014, 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +from ._base import BaseHandler +from synapse.api.constants import LoginType +from synapse.types import UserID +from synapse.api.errors import LoginError, Codes +from synapse.http.client import SimpleHttpClient +from synapse.util.async import run_on_reactor + +from twisted.web.client import PartialDownloadError + +import logging +import bcrypt +import simplejson + +import synapse.util.stringutils as stringutils + + +logger = logging.getLogger(__name__) + + +class AuthHandler(BaseHandler): + + def __init__(self, hs): + super(AuthHandler, self).__init__(hs) + self.checkers = { + LoginType.PASSWORD: self._check_password_auth, + LoginType.RECAPTCHA: self._check_recaptcha, + LoginType.EMAIL_IDENTITY: self._check_email_identity, + LoginType.DUMMY: self._check_dummy_auth, + } + self.sessions = {} + + @defer.inlineCallbacks + def check_auth(self, flows, clientdict, clientip=None): + """ + Takes a dictionary sent by the client in the login / registration + protocol and handles the login flow. + + Args: + flows: list of list of stages + authdict: The dictionary from the client root level, not the + 'auth' key: this method prompts for auth if none is sent. + Returns: + A tuple of authed, dict, dict where authed is true if the client + has successfully completed an auth flow. If it is true, the first + dict contains the authenticated credentials of each stage. + + If authed is false, the first dictionary is the server response to + the login request and should be passed back to the client. + + In either case, the second dict contains the parameters for this + request (which may have been given only in a previous call). + """ + + authdict = None + sid = None + if clientdict and 'auth' in clientdict: + authdict = clientdict['auth'] + del clientdict['auth'] + if 'session' in authdict: + sid = authdict['session'] + sess = self._get_session_info(sid) + + if len(clientdict) > 0: + # This was designed to allow the client to omit the parameters + # and just supply the session in subsequent calls so it split + # auth between devices by just sharing the session, (eg. so you + # could continue registration from your phone having clicked the + # email auth link on there). It's probably too open to abuse + # because it lets unauthenticated clients store arbitrary objects + # on a home server. + # sess['clientdict'] = clientdict + # self._save_session(sess) + pass + elif 'clientdict' in sess: + clientdict = sess['clientdict'] + + if not authdict: + defer.returnValue( + (False, self._auth_dict_for_flows(flows, sess), clientdict) + ) + + if 'creds' not in sess: + sess['creds'] = {} + creds = sess['creds'] + + # check auth type currently being presented + if 'type' in authdict: + if authdict['type'] not in self.checkers: + raise LoginError(400, "", Codes.UNRECOGNIZED) + result = yield self.checkers[authdict['type']](authdict, clientip) + if result: + creds[authdict['type']] = result + self._save_session(sess) + + for f in flows: + if len(set(f) - set(creds.keys())) == 0: + logger.info("Auth completed with creds: %r", creds) + self._remove_session(sess) + defer.returnValue((True, creds, clientdict)) + + ret = self._auth_dict_for_flows(flows, sess) + ret['completed'] = creds.keys() + defer.returnValue((False, ret, clientdict)) + + @defer.inlineCallbacks + def add_oob_auth(self, stagetype, authdict, clientip): + """ + Adds the result of out-of-band authentication into an existing auth + session. Currently used for adding the result of fallback auth. + """ + if stagetype not in self.checkers: + raise LoginError(400, "", Codes.MISSING_PARAM) + if 'session' not in authdict: + raise LoginError(400, "", Codes.MISSING_PARAM) + + sess = self._get_session_info( + authdict['session'] + ) + if 'creds' not in sess: + sess['creds'] = {} + creds = sess['creds'] + + result = yield self.checkers[stagetype](authdict, clientip) + if result: + creds[stagetype] = result + self._save_session(sess) + defer.returnValue(True) + defer.returnValue(False) + + @defer.inlineCallbacks + def _check_password_auth(self, authdict, _): + if "user" not in authdict or "password" not in authdict: + raise LoginError(400, "", Codes.MISSING_PARAM) + + user = authdict["user"] + password = authdict["password"] + if not user.startswith('@'): + user = UserID.create(user, self.hs.hostname).to_string() + + user_info = yield self.store.get_user_by_id(user_id=user) + if not user_info: + logger.warn("Attempted to login as %s but they do not exist", user) + raise LoginError(401, "", errcode=Codes.UNAUTHORIZED) + + stored_hash = user_info[0]["password_hash"] + if bcrypt.checkpw(password, stored_hash): + defer.returnValue(user) + else: + logger.warn("Failed password login for user %s", user) + raise LoginError(401, "", errcode=Codes.UNAUTHORIZED) + + @defer.inlineCallbacks + def _check_recaptcha(self, authdict, clientip): + try: + user_response = authdict["response"] + except KeyError: + # Client tried to provide captcha but didn't give the parameter: + # bad request. + raise LoginError( + 400, "Captcha response is required", + errcode=Codes.CAPTCHA_NEEDED + ) + + logger.info( + "Submitting recaptcha response %s with remoteip %s", + user_response, clientip + ) + + # TODO: get this from the homeserver rather than creating a new one for + # each request + try: + client = SimpleHttpClient(self.hs) + data = yield client.post_urlencoded_get_json( + "https://www.google.com/recaptcha/api/siteverify", + args={ + 'secret': self.hs.config.recaptcha_private_key, + 'response': user_response, + 'remoteip': clientip, + } + ) + except PartialDownloadError as pde: + # Twisted is silly + data = pde.response + resp_body = simplejson.loads(data) + if 'success' in resp_body and resp_body['success']: + defer.returnValue(True) + raise LoginError(401, "", errcode=Codes.UNAUTHORIZED) + + @defer.inlineCallbacks + def _check_email_identity(self, authdict, _): + yield run_on_reactor() + + if 'threepid_creds' not in authdict: + raise LoginError(400, "Missing threepid_creds", Codes.MISSING_PARAM) + + threepid_creds = authdict['threepid_creds'] + identity_handler = self.hs.get_handlers().identity_handler + + logger.info("Getting validated threepid. threepidcreds: %r" % (threepid_creds,)) + threepid = yield identity_handler.threepid_from_creds(threepid_creds) + + if not threepid: + raise LoginError(401, "", errcode=Codes.UNAUTHORIZED) + + threepid['threepid_creds'] = authdict['threepid_creds'] + + defer.returnValue(threepid) + + @defer.inlineCallbacks + def _check_dummy_auth(self, authdict, _): + yield run_on_reactor() + defer.returnValue(True) + + def _get_params_recaptcha(self): + return {"public_key": self.hs.config.recaptcha_public_key} + + def _auth_dict_for_flows(self, flows, session): + public_flows = [] + for f in flows: + public_flows.append(f) + + get_params = { + LoginType.RECAPTCHA: self._get_params_recaptcha, + } + + params = {} + + for f in public_flows: + for stage in f: + if stage in get_params and stage not in params: + params[stage] = get_params[stage]() + + return { + "session": session['id'], + "flows": [{"stages": f} for f in public_flows], + "params": params + } + + def _get_session_info(self, session_id): + if session_id not in self.sessions: + session_id = None + + if not session_id: + # create a new session + while session_id is None or session_id in self.sessions: + session_id = stringutils.random_string(24) + self.sessions[session_id] = { + "id": session_id, + } + + return self.sessions[session_id] + + def _save_session(self, session): + # TODO: Persistent storage + logger.debug("Saving session %s", session) + self.sessions[session["id"]] = session + + def _remove_session(self, session): + logger.debug("Removing session %s", session) + del self.sessions[session["id"]] diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 8aceac28cf..98148c13d7 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -179,7 +179,7 @@ class FederationHandler(BaseHandler): # it's probably a good idea to mark it as not in retry-state # for sending (although this is a bit of a leap) retry_timings = yield self.store.get_destination_retry_timings(origin) - if (retry_timings and retry_timings.retry_last_ts): + if retry_timings and retry_timings["retry_last_ts"]: self.store.set_destination_retry_timings(origin, 0, 0) room = yield self.store.get_room(event.room_id) diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py new file mode 100644 index 0000000000..ad8246b58c --- /dev/null +++ b/synapse/handlers/identity.py @@ -0,0 +1,88 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for interacting with Identity Servers""" +from twisted.internet import defer + +from synapse.api.errors import ( + CodeMessageException +) +from ._base import BaseHandler +from synapse.http.client import SimpleHttpClient +from synapse.util.async import run_on_reactor + +import json +import logging + +logger = logging.getLogger(__name__) + + +class IdentityHandler(BaseHandler): + + def __init__(self, hs): + super(IdentityHandler, self).__init__(hs) + + @defer.inlineCallbacks + def threepid_from_creds(self, creds): + yield run_on_reactor() + + # TODO: get this from the homeserver rather than creating a new one for + # each request + http_client = SimpleHttpClient(self.hs) + # XXX: make this configurable! + # trustedIdServers = ['matrix.org', 'localhost:8090'] + trustedIdServers = ['matrix.org'] + if not creds['id_server'] in trustedIdServers: + logger.warn('%s is not a trusted ID server: rejecting 3pid ' + + 'credentials', creds['id_server']) + defer.returnValue(None) + + data = {} + try: + data = yield http_client.get_json( + "https://%s%s" % ( + creds['id_server'], + "/_matrix/identity/api/v1/3pid/getValidated3pid" + ), + {'sid': creds['sid'], 'client_secret': creds['client_secret']} + ) + except CodeMessageException as e: + data = json.loads(e.msg) + + if 'medium' in data: + defer.returnValue(data) + defer.returnValue(None) + + @defer.inlineCallbacks + def bind_threepid(self, creds, mxid): + yield run_on_reactor() + logger.debug("binding threepid %r to %s", creds, mxid) + http_client = SimpleHttpClient(self.hs) + data = None + try: + data = yield http_client.post_urlencoded_get_json( + "https://%s%s" % ( + creds['id_server'], "/_matrix/identity/api/v1/3pid/bind" + ), + { + 'sid': creds['sid'], + 'client_secret': creds['client_secret'], + 'mxid': mxid, + } + ) + logger.debug("bound threepid %r to %s", creds, mxid) + except CodeMessageException as e: + data = json.loads(e.msg) + defer.returnValue(data) diff --git a/synapse/handlers/login.py b/synapse/handlers/login.py index 7447800460..91d87d503d 100644 --- a/synapse/handlers/login.py +++ b/synapse/handlers/login.py @@ -16,13 +16,9 @@ from twisted.internet import defer from ._base import BaseHandler -from synapse.api.errors import LoginError, Codes, CodeMessageException -from synapse.http.client import SimpleHttpClient -from synapse.util.emailutils import EmailException -import synapse.util.emailutils as emailutils +from synapse.api.errors import LoginError, Codes import bcrypt -import json import logging logger = logging.getLogger(__name__) @@ -57,7 +53,7 @@ class LoginHandler(BaseHandler): logger.warn("Attempted to login as %s but they do not exist", user) raise LoginError(403, "", errcode=Codes.FORBIDDEN) - stored_hash = user_info[0]["password_hash"] + stored_hash = user_info["password_hash"] if bcrypt.checkpw(password, stored_hash): # generate an access token and store it. token = self.reg_handler._generate_token(user) @@ -69,48 +65,19 @@ class LoginHandler(BaseHandler): raise LoginError(403, "", errcode=Codes.FORBIDDEN) @defer.inlineCallbacks - def reset_password(self, user_id, email): - is_valid = yield self._check_valid_association(user_id, email) - logger.info("reset_password user=%s email=%s valid=%s", user_id, email, - is_valid) - if is_valid: - try: - # send an email out - emailutils.send_email( - smtp_server=self.hs.config.email_smtp_server, - from_addr=self.hs.config.email_from_address, - to_addr=email, - subject="Password Reset", - body="TODO." - ) - except EmailException as e: - logger.exception(e) + def set_password(self, user_id, newpassword, token_id=None): + password_hash = bcrypt.hashpw(newpassword, bcrypt.gensalt()) - @defer.inlineCallbacks - def _check_valid_association(self, user_id, email): - identity = yield self._query_email(email) - if identity and "mxid" in identity: - if identity["mxid"] == user_id: - defer.returnValue(True) - return - defer.returnValue(False) + yield self.store.user_set_password_hash(user_id, password_hash) + yield self.store.user_delete_access_tokens_apart_from(user_id, token_id) + yield self.hs.get_pusherpool().remove_pushers_by_user_access_token( + user_id, token_id + ) + yield self.store.flush_user(user_id) @defer.inlineCallbacks - def _query_email(self, email): - http_client = SimpleHttpClient(self.hs) - try: - data = yield http_client.get_json( - # TODO FIXME This should be configurable. - # XXX: ID servers need to use HTTPS - "http://%s%s" % ( - "matrix.org:8090", "/_matrix/identity/api/v1/lookup" - ), - { - 'medium': 'email', - 'address': email - } - ) - defer.returnValue(data) - except CodeMessageException as e: - data = json.loads(e.msg) - defer.returnValue(data) + def add_threepid(self, user_id, medium, address, validated_at): + yield self.store.user_add_threepid( + user_id, medium, address, validated_at, + self.hs.get_clock().time_msec() + ) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 7b9685be7f..9667bb8674 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -274,7 +274,8 @@ class MessageHandler(BaseHandler): if limit is None: limit = 10 - for event in room_list: + @defer.inlineCallbacks + def handle_room(event): d = { "room_id": event.room_id, "membership": event.membership, @@ -290,12 +291,19 @@ class MessageHandler(BaseHandler): rooms_ret.append(d) if event.membership != Membership.JOIN: - continue + return try: - messages, token = yield self.store.get_recent_events_for_room( - event.room_id, - limit=limit, - end_token=now_token.room_key, + (messages, token), current_state = yield defer.gatherResults( + [ + self.store.get_recent_events_for_room( + event.room_id, + limit=limit, + end_token=now_token.room_key, + ), + self.state_handler.get_current_state( + event.room_id + ), + ] ) start_token = now_token.copy_and_replace("room_key", token[0]) @@ -311,9 +319,6 @@ class MessageHandler(BaseHandler): "end": end_token.to_string(), } - current_state = yield self.state_handler.get_current_state( - event.room_id - ) d["state"] = [ serialize_event(c, time_now, as_client_event) for c in current_state.values() @@ -321,6 +326,11 @@ class MessageHandler(BaseHandler): except: logger.exception("Failed to get snapshot") + yield defer.gatherResults( + [handle_room(e) for e in room_list], + consumeErrors=True + ) + ret = { "rooms": rooms_ret, "presence": presence, diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 571eacd343..42cd528908 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -858,22 +858,24 @@ class PresenceEventSource(object): presence = self.hs.get_handlers().presence_handler cachemap = presence._user_cachemap + + max_serial = presence._user_cachemap_latest_serial + clock = self.clock - latest_serial = None + latest_serial = 0 updates = [] # TODO(paul): use a DeferredList ? How to limit concurrency. for observed_user in cachemap.keys(): cached = cachemap[observed_user] - if cached.serial <= from_key: + if cached.serial <= from_key or cached.serial > max_serial: continue if not (yield self.is_visible(observer_user, observed_user)): continue - if latest_serial is None or cached.serial > latest_serial: - latest_serial = cached.serial + latest_serial = max(cached.serial, latest_serial) updates.append(cached.make_event(user=observed_user, clock=clock)) # TODO(paul): limit @@ -882,6 +884,10 @@ class PresenceEventSource(object): if serial < from_key: break + if serial > max_serial: + continue + + latest_serial = max(latest_serial, serial) for u in user_ids: updates.append({ "type": "m.presence", diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index c25e321099..7b68585a17 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -18,18 +18,15 @@ from twisted.internet import defer from synapse.types import UserID from synapse.api.errors import ( - AuthError, Codes, SynapseError, RegistrationError, InvalidCaptchaError, - CodeMessageException + AuthError, Codes, SynapseError, RegistrationError, InvalidCaptchaError ) from ._base import BaseHandler import synapse.util.stringutils as stringutils from synapse.util.async import run_on_reactor -from synapse.http.client import SimpleHttpClient from synapse.http.client import CaptchaServerHttpClient import base64 import bcrypt -import json import logging import urllib @@ -45,6 +42,30 @@ class RegistrationHandler(BaseHandler): self.distributor.declare("registered_user") @defer.inlineCallbacks + def check_username(self, localpart): + yield run_on_reactor() + + if urllib.quote(localpart) != localpart: + raise SynapseError( + 400, + "User ID must only contain characters which do not" + " require URL encoding." + ) + + user = UserID(localpart, self.hs.hostname) + user_id = user.to_string() + + yield self.check_user_id_is_valid(user_id) + + u = yield self.store.get_user_by_id(user_id) + if u: + raise SynapseError( + 400, + "User ID already taken.", + errcode=Codes.USER_IN_USE, + ) + + @defer.inlineCallbacks def register(self, localpart=None, password=None): """Registers a new client on the server. @@ -64,18 +85,11 @@ class RegistrationHandler(BaseHandler): password_hash = bcrypt.hashpw(password, bcrypt.gensalt()) if localpart: - if localpart and urllib.quote(localpart) != localpart: - raise SynapseError( - 400, - "User ID must only contain characters which do not" - " require URL encoding." - ) + yield self.check_username(localpart) user = UserID(localpart, self.hs.hostname) user_id = user.to_string() - yield self.check_user_id_is_valid(user_id) - token = self._generate_token(user_id) yield self.store.register( user_id=user_id, @@ -157,7 +171,11 @@ class RegistrationHandler(BaseHandler): @defer.inlineCallbacks def check_recaptcha(self, ip, private_key, challenge, response): - """Checks a recaptcha is correct.""" + """ + Checks a recaptcha is correct. + + Used only by c/s api v1 + """ captcha_response = yield self._validate_captcha( ip, @@ -176,13 +194,18 @@ class RegistrationHandler(BaseHandler): @defer.inlineCallbacks def register_email(self, threepidCreds): - """Registers emails with an identity server.""" + """ + Registers emails with an identity server. + + Used only by c/s api v1 + """ for c in threepidCreds: logger.info("validating theeepidcred sid %s on id server %s", c['sid'], c['idServer']) try: - threepid = yield self._threepid_from_creds(c) + identity_handler = self.hs.get_handlers().identity_handler + threepid = yield identity_handler.threepid_from_creds(c) except: logger.exception("Couldn't validate 3pid") raise RegistrationError(400, "Couldn't validate 3pid") @@ -194,12 +217,16 @@ class RegistrationHandler(BaseHandler): @defer.inlineCallbacks def bind_emails(self, user_id, threepidCreds): - """Links emails with a user ID and informs an identity server.""" + """Links emails with a user ID and informs an identity server. + + Used only by c/s api v1 + """ # Now we have a matrix ID, bind it to the threepids we were given for c in threepidCreds: + identity_handler = self.hs.get_handlers().identity_handler # XXX: This should be a deferred list, shouldn't it? - yield self._bind_threepid(c, user_id) + yield identity_handler.bind_threepid(c, user_id) @defer.inlineCallbacks def check_user_id_is_valid(self, user_id): @@ -227,61 +254,11 @@ class RegistrationHandler(BaseHandler): return "-" + stringutils.random_string(18) @defer.inlineCallbacks - def _threepid_from_creds(self, creds): - # TODO: get this from the homeserver rather than creating a new one for - # each request - http_client = SimpleHttpClient(self.hs) - # XXX: make this configurable! - trustedIdServers = ['matrix.org:8090', 'matrix.org'] - if not creds['idServer'] in trustedIdServers: - logger.warn('%s is not a trusted ID server: rejecting 3pid ' + - 'credentials', creds['idServer']) - defer.returnValue(None) - - data = {} - try: - data = yield http_client.get_json( - # XXX: This should be HTTPS - "http://%s%s" % ( - creds['idServer'], - "/_matrix/identity/api/v1/3pid/getValidated3pid" - ), - {'sid': creds['sid'], 'clientSecret': creds['clientSecret']} - ) - except CodeMessageException as e: - data = json.loads(e.msg) - - if 'medium' in data: - defer.returnValue(data) - defer.returnValue(None) - - @defer.inlineCallbacks - def _bind_threepid(self, creds, mxid): - yield - logger.debug("binding threepid") - http_client = SimpleHttpClient(self.hs) - data = None - try: - data = yield http_client.post_urlencoded_get_json( - # XXX: Change when ID servers are all HTTPS - "http://%s%s" % ( - creds['idServer'], "/_matrix/identity/api/v1/3pid/bind" - ), - { - 'sid': creds['sid'], - 'clientSecret': creds['clientSecret'], - 'mxid': mxid, - } - ) - logger.debug("bound threepid") - except CodeMessageException as e: - data = json.loads(e.msg) - defer.returnValue(data) - - @defer.inlineCallbacks def _validate_captcha(self, ip_addr, private_key, challenge, response): """Validates the captcha provided. + Used only by c/s api v1 + Returns: dict: Containing 'valid'(bool) and 'error_url'(str) if invalid. @@ -299,6 +276,9 @@ class RegistrationHandler(BaseHandler): @defer.inlineCallbacks def _submit_captcha(self, ip_addr, private_key, challenge, response): + """ + Used only by c/s api v1 + """ # TODO: get this from the homeserver rather than creating a new one for # each request client = CaptchaServerHttpClient(self.hs) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index f9fc4a9c98..47456a28e9 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -213,7 +213,8 @@ class RoomCreationHandler(BaseHandler): "state_default": 50, "ban": 50, "kick": 50, - "redact": 50 + "redact": 50, + "invite": 0, }, ) @@ -311,25 +312,6 @@ class RoomMemberHandler(BaseHandler): defer.returnValue(chunk_data) @defer.inlineCallbacks - def get_room_member(self, room_id, member_user_id, auth_user_id): - """Retrieve a room member from a room. - - Args: - room_id : The room the member is in. - member_user_id : The member's user ID - auth_user_id : The user ID of the user making this request. - Returns: - The room member, or None if this member does not exist. - Raises: - SynapseError if something goes wrong. - """ - yield self.auth.check_joined_room(room_id, auth_user_id) - - member = yield self.store.get_room_member(user_id=member_user_id, - room_id=room_id) - defer.returnValue(member) - - @defer.inlineCallbacks def change_membership(self, event, context, do_auth=True): """ Change the membership status of a user in a room. diff --git a/synapse/http/client.py b/synapse/http/client.py index 2ae1c4d3a4..e8a5dedab4 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -200,6 +200,8 @@ class CaptchaServerHttpClient(SimpleHttpClient): """ Separate HTTP client for talking to google's captcha servers Only slightly special because accepts partial download responses + + used only by c/s api v1 """ @defer.inlineCallbacks diff --git a/synapse/http/server.py b/synapse/http/server.py index b3706889ab..05636e683b 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -131,10 +131,10 @@ class HttpServer(object): """ def register_path(self, method, path_pattern, callback): - """ Register a callback that get's fired if we receive a http request + """ Register a callback that gets fired if we receive a http request with the given method for a path that matches the given regex. - If the regex contains groups these get's passed to the calback via + If the regex contains groups these gets passed to the calback via an unpacked tuple. Args: @@ -153,6 +153,13 @@ class JsonResource(HttpServer, resource.Resource): Resources. Register callbacks via register_path() + + Callbacks can return a tuple of status code and a dict in which case the + the dict will automatically be sent to the client as a JSON object. + + The JsonResource is primarily intended for returning JSON, but callbacks + may send something other than JSON, they may do so by using the methods + on the request object and instead returning None. """ isLeaf = True @@ -185,9 +192,8 @@ class JsonResource(HttpServer, resource.Resource): interface=self.hs.config.bind_host ) - # Gets called by twisted def render(self, request): - """ This get's called by twisted every time someone sends us a request. + """ This gets called by twisted every time someone sends us a request. """ self._async_render(request) return server.NOT_DONE_YET @@ -195,7 +201,7 @@ class JsonResource(HttpServer, resource.Resource): @request_handler @defer.inlineCallbacks def _async_render(self, request): - """ This get's called by twisted every time someone sends us a request. + """ This gets called from render() every time someone sends us a request. This checks if anyone has registered a callback for that method and path. """ @@ -227,9 +233,11 @@ class JsonResource(HttpServer, resource.Resource): urllib.unquote(u).decode("UTF-8") for u in m.groups() ] - code, response = yield callback(request, *args) + callback_return = yield callback(request, *args) + if callback_return is not None: + code, response = callback_return + self._send_response(request, code, response) - self._send_response(request, code, response) response_timer.inc_by( self.clock.time_msec() - start, request.method, servlet_classname ) diff --git a/synapse/notifier.py b/synapse/notifier.py index d750a6fcf7..ea854482b5 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -98,7 +98,7 @@ class _NotificationListener(object): try: notifier.clock.cancel_call_later(self.timer) except: - logger.exception("Failed to cancel notifier timer") + logger.warn("Failed to cancel notifier timer") class Notifier(object): diff --git a/synapse/push/__init__.py b/synapse/push/__init__.py index 0727f772a5..5575c847f9 100644 --- a/synapse/push/__init__.py +++ b/synapse/push/__init__.py @@ -253,7 +253,8 @@ class Pusher(object): self.user_name, config, timeout=0) self.last_token = chunk['end'] self.store.update_pusher_last_token( - self.app_id, self.pushkey, self.last_token) + self.app_id, self.pushkey, self.user_name, self.last_token + ) logger.info("Pusher %s for user %s starting from token %s", self.pushkey, self.user_name, self.last_token) @@ -314,7 +315,7 @@ class Pusher(object): pk ) yield self.hs.get_pusherpool().remove_pusher( - self.app_id, pk + self.app_id, pk, self.user_name ) if not self.alive: @@ -326,6 +327,7 @@ class Pusher(object): self.store.update_pusher_last_token_and_success( self.app_id, self.pushkey, + self.user_name, self.last_token, self.clock.time_msec() ) @@ -334,6 +336,7 @@ class Pusher(object): self.store.update_pusher_failing_since( self.app_id, self.pushkey, + self.user_name, self.failing_since) else: if not self.failing_since: @@ -341,6 +344,7 @@ class Pusher(object): self.store.update_pusher_failing_since( self.app_id, self.pushkey, + self.user_name, self.failing_since ) @@ -358,6 +362,7 @@ class Pusher(object): self.store.update_pusher_last_token( self.app_id, self.pushkey, + self.user_name, self.last_token ) @@ -365,6 +370,7 @@ class Pusher(object): self.store.update_pusher_failing_since( self.app_id, self.pushkey, + self.user_name, self.failing_since ) else: diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index 90babd7224..58fe6ad747 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -19,10 +19,7 @@ from twisted.internet import defer from httppusher import HttpPusher from synapse.push import PusherConfigException -from syutil.jsonutil import encode_canonical_json - import logging -import simplejson as json logger = logging.getLogger(__name__) @@ -52,12 +49,10 @@ class PusherPool: @defer.inlineCallbacks def start(self): pushers = yield self.store.get_all_pushers() - for p in pushers: - p['data'] = json.loads(p['data']) self._start_pushers(pushers) @defer.inlineCallbacks - def add_pusher(self, user_name, profile_tag, kind, app_id, + def add_pusher(self, user_name, access_token, profile_tag, kind, app_id, app_display_name, device_display_name, pushkey, lang, data): # we try to create the pusher just to validate the config: it # will then get pulled out of the database, @@ -71,7 +66,7 @@ class PusherPool: "app_display_name": app_display_name, "device_display_name": device_display_name, "pushkey": pushkey, - "pushkey_ts": self.hs.get_clock().time_msec(), + "ts": self.hs.get_clock().time_msec(), "lang": lang, "data": data, "last_token": None, @@ -79,17 +74,50 @@ class PusherPool: "failing_since": None }) yield self._add_pusher_to_store( - user_name, profile_tag, kind, app_id, + user_name, access_token, profile_tag, kind, app_id, app_display_name, device_display_name, pushkey, lang, data ) @defer.inlineCallbacks - def _add_pusher_to_store(self, user_name, profile_tag, kind, app_id, - app_display_name, device_display_name, + def remove_pushers_by_app_id_and_pushkey_not_user(self, app_id, pushkey, + not_user_id): + to_remove = yield self.store.get_pushers_by_app_id_and_pushkey( + app_id, pushkey + ) + for p in to_remove: + if p['user_name'] != not_user_id: + logger.info( + "Removing pusher for app id %s, pushkey %s, user %s", + app_id, pushkey, p['user_name'] + ) + self.remove_pusher(p['app_id'], p['pushkey'], p['user_name']) + + @defer.inlineCallbacks + def remove_pushers_by_user_access_token(self, user_id, not_access_token_id): + all = yield self.store.get_all_pushers() + logger.info( + "Removing all pushers for user %s except access token %s", + user_id, not_access_token_id + ) + for p in all: + if ( + p['user_name'] == user_id and + p['access_token'] != not_access_token_id + ): + logger.info( + "Removing pusher for app id %s, pushkey %s, user %s", + p['app_id'], p['pushkey'], p['user_name'] + ) + self.remove_pusher(p['app_id'], p['pushkey'], p['user_name']) + + @defer.inlineCallbacks + def _add_pusher_to_store(self, user_name, access_token, profile_tag, kind, + app_id, app_display_name, device_display_name, pushkey, lang, data): yield self.store.add_pusher( user_name=user_name, + access_token=access_token, profile_tag=profile_tag, kind=kind, app_id=app_id, @@ -98,9 +126,9 @@ class PusherPool: pushkey=pushkey, pushkey_ts=self.hs.get_clock().time_msec(), lang=lang, - data=encode_canonical_json(data).decode("UTF-8"), + data=data, ) - self._refresh_pusher((app_id, pushkey)) + self._refresh_pusher(app_id, pushkey, user_name) def _create_pusher(self, pusherdict): if pusherdict['kind'] == 'http': @@ -112,7 +140,7 @@ class PusherPool: app_display_name=pusherdict['app_display_name'], device_display_name=pusherdict['device_display_name'], pushkey=pusherdict['pushkey'], - pushkey_ts=pusherdict['pushkey_ts'], + pushkey_ts=pusherdict['ts'], data=pusherdict['data'], last_token=pusherdict['last_token'], last_success=pusherdict['last_success'], @@ -125,30 +153,42 @@ class PusherPool: ) @defer.inlineCallbacks - def _refresh_pusher(self, app_id_pushkey): - p = yield self.store.get_pushers_by_app_id_and_pushkey( - app_id_pushkey + def _refresh_pusher(self, app_id, pushkey, user_name): + resultlist = yield self.store.get_pushers_by_app_id_and_pushkey( + app_id, pushkey ) - p['data'] = json.loads(p['data']) - self._start_pushers([p]) + p = None + for r in resultlist: + if r['user_name'] == user_name: + p = r + + if p: + + self._start_pushers([p]) def _start_pushers(self, pushers): logger.info("Starting %d pushers", len(pushers)) for pusherdict in pushers: p = self._create_pusher(pusherdict) if p: - fullid = "%s:%s" % (pusherdict['app_id'], pusherdict['pushkey']) + fullid = "%s:%s:%s" % ( + pusherdict['app_id'], + pusherdict['pushkey'], + pusherdict['user_name'] + ) if fullid in self.pushers: self.pushers[fullid].stop() self.pushers[fullid] = p p.start() @defer.inlineCallbacks - def remove_pusher(self, app_id, pushkey): - fullid = "%s:%s" % (app_id, pushkey) + def remove_pusher(self, app_id, pushkey, user_name): + fullid = "%s:%s:%s" % (app_id, pushkey, user_name) if fullid in self.pushers: logger.info("Stopping pusher %s", fullid) self.pushers[fullid].stop() del self.pushers[fullid] - yield self.store.delete_pusher_by_app_id_pushkey(app_id, pushkey) + yield self.store.delete_pusher_by_app_id_pushkey_user_name( + app_id, pushkey, user_name + ) diff --git a/synapse/rest/client/v1/base.py b/synapse/rest/client/v1/base.py index 72332bdb10..504a5e432f 100644 --- a/synapse/rest/client/v1/base.py +++ b/synapse/rest/client/v1/base.py @@ -48,5 +48,5 @@ class ClientV1RestServlet(RestServlet): self.hs = hs self.handlers = hs.get_handlers() self.builder_factory = hs.get_event_builder_factory() - self.auth = hs.get_auth() + self.auth = hs.get_v1auth() self.txns = HttpTransactionStore() diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py index 6045e86f34..c83287c028 100644 --- a/synapse/rest/client/v1/pusher.py +++ b/synapse/rest/client/v1/pusher.py @@ -27,7 +27,7 @@ class PusherRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_POST(self, request): - user, _ = yield self.auth.get_user_by_req(request) + user, client = yield self.auth.get_user_by_req(request) content = _parse_json(request) @@ -37,7 +37,7 @@ class PusherRestServlet(ClientV1RestServlet): and 'kind' in content and content['kind'] is None): yield pusher_pool.remove_pusher( - content['app_id'], content['pushkey'] + content['app_id'], content['pushkey'], user_name=user.to_string() ) defer.returnValue((200, {})) @@ -51,9 +51,21 @@ class PusherRestServlet(ClientV1RestServlet): raise SynapseError(400, "Missing parameters: "+','.join(missing), errcode=Codes.MISSING_PARAM) + append = False + if 'append' in content: + append = content['append'] + + if not append: + yield pusher_pool.remove_pushers_by_app_id_and_pushkey_not_user( + app_id=content['app_id'], + pushkey=content['pushkey'], + not_user_id=user.to_string() + ) + try: yield pusher_pool.add_pusher( user_name=user.to_string(), + access_token=client.token_id, profile_tag=content['profile_tag'], kind=content['kind'], app_id=content['app_id'], diff --git a/synapse/rest/client/v2_alpha/__init__.py b/synapse/rest/client/v2_alpha/__init__.py index bca65f2a6a..28d95b2729 100644 --- a/synapse/rest/client/v2_alpha/__init__.py +++ b/synapse/rest/client/v2_alpha/__init__.py @@ -15,7 +15,10 @@ from . import ( sync, - filter + filter, + account, + register, + auth ) from synapse.http.server import JsonResource @@ -32,3 +35,6 @@ class ClientV2AlphaRestResource(JsonResource): def register_servlets(client_resource, hs): sync.register_servlets(hs, client_resource) filter.register_servlets(hs, client_resource) + account.register_servlets(hs, client_resource) + register.register_servlets(hs, client_resource) + auth.register_servlets(hs, client_resource) diff --git a/synapse/rest/client/v2_alpha/_base.py b/synapse/rest/client/v2_alpha/_base.py index 22dc5cb862..4540e8dcf7 100644 --- a/synapse/rest/client/v2_alpha/_base.py +++ b/synapse/rest/client/v2_alpha/_base.py @@ -17,9 +17,11 @@ """ from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX +from synapse.api.errors import SynapseError import re import logging +import simplejson logger = logging.getLogger(__name__) @@ -36,3 +38,23 @@ def client_v2_pattern(path_regex): SRE_Pattern """ return re.compile("^" + CLIENT_V2_ALPHA_PREFIX + path_regex) + + +def parse_request_allow_empty(request): + content = request.content.read() + if content is None or content == '': + return None + try: + return simplejson.loads(content) + except simplejson.JSONDecodeError: + raise SynapseError(400, "Content not JSON.") + + +def parse_json_dict_from_request(request): + try: + content = simplejson.loads(request.content.read()) + if type(content) != dict: + raise SynapseError(400, "Content must be a JSON object.") + return content + except simplejson.JSONDecodeError: + raise SynapseError(400, "Content not JSON.") diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py new file mode 100644 index 0000000000..3e522ad39b --- /dev/null +++ b/synapse/rest/client/v2_alpha/account.py @@ -0,0 +1,159 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +from synapse.api.constants import LoginType +from synapse.api.errors import LoginError, SynapseError, Codes +from synapse.http.servlet import RestServlet +from synapse.util.async import run_on_reactor + +from ._base import client_v2_pattern, parse_json_dict_from_request + +import logging + + +logger = logging.getLogger(__name__) + + +class PasswordRestServlet(RestServlet): + PATTERN = client_v2_pattern("/account/password") + + def __init__(self, hs): + super(PasswordRestServlet, self).__init__() + self.hs = hs + self.auth = hs.get_auth() + self.auth_handler = hs.get_handlers().auth_handler + self.login_handler = hs.get_handlers().login_handler + + @defer.inlineCallbacks + def on_POST(self, request): + yield run_on_reactor() + + body = parse_json_dict_from_request(request) + + authed, result, params = yield self.auth_handler.check_auth([ + [LoginType.PASSWORD], + [LoginType.EMAIL_IDENTITY] + ], body) + + if not authed: + defer.returnValue((401, result)) + + user_id = None + + if LoginType.PASSWORD in result: + # if using password, they should also be logged in + auth_user, client = yield self.auth.get_user_by_req(request) + if auth_user.to_string() != result[LoginType.PASSWORD]: + raise LoginError(400, "", Codes.UNKNOWN) + user_id = auth_user.to_string() + elif LoginType.EMAIL_IDENTITY in result: + threepid = result[LoginType.EMAIL_IDENTITY] + if 'medium' not in threepid or 'address' not in threepid: + raise SynapseError(500, "Malformed threepid") + # if using email, we must know about the email they're authing with! + threepid_user = yield self.hs.get_datastore().get_user_by_threepid( + threepid['medium'], threepid['address'] + ) + if not threepid_user: + raise SynapseError(404, "Email address not found", Codes.NOT_FOUND) + user_id = threepid_user + else: + logger.error("Auth succeeded but no known type!", result.keys()) + raise SynapseError(500, "", Codes.UNKNOWN) + + if 'new_password' not in params: + raise SynapseError(400, "", Codes.MISSING_PARAM) + new_password = params['new_password'] + + yield self.login_handler.set_password( + user_id, new_password, None + ) + + defer.returnValue((200, {})) + + def on_OPTIONS(self, _): + return 200, {} + + +class ThreepidRestServlet(RestServlet): + PATTERN = client_v2_pattern("/account/3pid") + + def __init__(self, hs): + super(ThreepidRestServlet, self).__init__() + self.hs = hs + self.login_handler = hs.get_handlers().login_handler + self.identity_handler = hs.get_handlers().identity_handler + self.auth = hs.get_auth() + + @defer.inlineCallbacks + def on_GET(self, request): + yield run_on_reactor() + + auth_user, _ = yield self.auth.get_user_by_req(request) + + threepids = yield self.hs.get_datastore().user_get_threepids( + auth_user.to_string() + ) + + defer.returnValue((200, {'threepids': threepids})) + + @defer.inlineCallbacks + def on_POST(self, request): + yield run_on_reactor() + + body = parse_json_dict_from_request(request) + + if 'threePidCreds' not in body: + raise SynapseError(400, "Missing param", Codes.MISSING_PARAM) + threePidCreds = body['threePidCreds'] + + auth_user, client = yield self.auth.get_user_by_req(request) + + threepid = yield self.identity_handler.threepid_from_creds(threePidCreds) + + if not threepid: + raise SynapseError( + 400, "Failed to auth 3pid", Codes.THREEPID_AUTH_FAILED + ) + + for reqd in ['medium', 'address', 'validated_at']: + if reqd not in threepid: + logger.warn("Couldn't add 3pid: invalid response from ID sevrer") + raise SynapseError(500, "Invalid response from ID Server") + + yield self.login_handler.add_threepid( + auth_user.to_string(), + threepid['medium'], + threepid['address'], + threepid['validated_at'], + ) + + if 'bind' in body and body['bind']: + logger.debug( + "Binding emails %s to %s", + threepid, auth_user.to_string() + ) + yield self.identity_handler.bind_threepid( + threePidCreds, auth_user.to_string() + ) + + defer.returnValue((200, {})) + + +def register_servlets(hs, http_server): + PasswordRestServlet(hs).register(http_server) + ThreepidRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py new file mode 100644 index 0000000000..4c726f05f5 --- /dev/null +++ b/synapse/rest/client/v2_alpha/auth.py @@ -0,0 +1,190 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +from synapse.api.constants import LoginType +from synapse.api.errors import SynapseError +from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX +from synapse.http.servlet import RestServlet + +from ._base import client_v2_pattern + +import logging + + +logger = logging.getLogger(__name__) + +RECAPTCHA_TEMPLATE = """ +<html> +<head> +<title>Authentication</title> +<meta name='viewport' content='width=device-width, initial-scale=1, + user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'> +<script src="https://www.google.com/recaptcha/api.js" + async defer></script> +<script src="//code.jquery.com/jquery-1.11.2.min.js"></script> +<link rel="stylesheet" href="/_matrix/static/client/register/style.css"> +<script> +function captchaDone() { + $('#registrationForm').submit(); +} +</script> +</head> +<body> +<form id="registrationForm" method="post" action="%(myurl)s"> + <div> + <p> + Hello! We need to prevent computer programs and other automated + things from creating accounts on this server. + </p> + <p> + Please verify that you're not a robot. + </p> + <input type="hidden" name="session" value="%(session)s" /> + <div class="g-recaptcha" + data-sitekey="%(sitekey)s" + data-callback="captchaDone"> + </div> + <noscript> + <input type="submit" value="All Done" /> + </noscript> + </div> + </div> +</form> +</body> +</html> +""" + +SUCCESS_TEMPLATE = """ +<html> +<head> +<title>Success!</title> +<meta name='viewport' content='width=device-width, initial-scale=1, + user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'> +<link rel="stylesheet" href="/_matrix/static/client/register/style.css"> +<script> +if (window.onAuthDone != undefined) { + window.onAuthDone(); +} +</script> +</head> +<body> + <div> + <p>Thank you</p> + <p>You may now close this window and return to the application</p> + </div> +</body> +</html> +""" + + +class AuthRestServlet(RestServlet): + """ + Handles Client / Server API authentication in any situations where it + cannot be handled in the normal flow (with requests to the same endpoint). + Current use is for web fallback auth. + """ + PATTERN = client_v2_pattern("/auth/(?P<stagetype>[\w\.]*)/fallback/web") + + def __init__(self, hs): + super(AuthRestServlet, self).__init__() + self.hs = hs + self.auth = hs.get_auth() + self.auth_handler = hs.get_handlers().auth_handler + self.registration_handler = hs.get_handlers().registration_handler + + @defer.inlineCallbacks + def on_GET(self, request, stagetype): + yield + if stagetype == LoginType.RECAPTCHA: + if ('session' not in request.args or + len(request.args['session']) == 0): + raise SynapseError(400, "No session supplied") + + session = request.args["session"][0] + + html = RECAPTCHA_TEMPLATE % { + 'session': session, + 'myurl': "%s/auth/%s/fallback/web" % ( + CLIENT_V2_ALPHA_PREFIX, LoginType.RECAPTCHA + ), + 'sitekey': self.hs.config.recaptcha_public_key, + } + html_bytes = html.encode("utf8") + request.setResponseCode(200) + request.setHeader(b"Content-Type", b"text/html; charset=utf-8") + request.setHeader(b"Server", self.hs.version_string) + request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),)) + + request.write(html_bytes) + request.finish() + defer.returnValue(None) + else: + raise SynapseError(404, "Unknown auth stage type") + + @defer.inlineCallbacks + def on_POST(self, request, stagetype): + yield + if stagetype == "m.login.recaptcha": + if ('g-recaptcha-response' not in request.args or + len(request.args['g-recaptcha-response'])) == 0: + raise SynapseError(400, "No captcha response supplied") + if ('session' not in request.args or + len(request.args['session'])) == 0: + raise SynapseError(400, "No session supplied") + + session = request.args['session'][0] + + authdict = { + 'response': request.args['g-recaptcha-response'][0], + 'session': session, + } + + success = yield self.auth_handler.add_oob_auth( + LoginType.RECAPTCHA, + authdict, + self.hs.get_ip_from_request(request) + ) + + if success: + html = SUCCESS_TEMPLATE + else: + html = RECAPTCHA_TEMPLATE % { + 'session': session, + 'myurl': "%s/auth/%s/fallback/web" % ( + CLIENT_V2_ALPHA_PREFIX, LoginType.RECAPTCHA + ), + 'sitekey': self.hs.config.recaptcha_public_key, + } + html_bytes = html.encode("utf8") + request.setResponseCode(200) + request.setHeader(b"Content-Type", b"text/html; charset=utf-8") + request.setHeader(b"Server", self.hs.version_string) + request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),)) + + request.write(html_bytes) + request.finish() + + defer.returnValue(None) + else: + raise SynapseError(404, "Unknown auth stage type") + + def on_OPTIONS(self, _): + return 200, {} + + +def register_servlets(hs, http_server): + AuthRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py new file mode 100644 index 0000000000..3640fb4a29 --- /dev/null +++ b/synapse/rest/client/v2_alpha/register.py @@ -0,0 +1,183 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +from synapse.api.constants import LoginType +from synapse.api.errors import SynapseError, Codes +from synapse.http.servlet import RestServlet + +from ._base import client_v2_pattern, parse_request_allow_empty + +import logging +import hmac +from hashlib import sha1 +from synapse.util.async import run_on_reactor + + +# We ought to be using hmac.compare_digest() but on older pythons it doesn't +# exist. It's a _really minor_ security flaw to use plain string comparison +# because the timing attack is so obscured by all the other code here it's +# unlikely to make much difference +if hasattr(hmac, "compare_digest"): + compare_digest = hmac.compare_digest +else: + compare_digest = lambda a, b: a == b + + +logger = logging.getLogger(__name__) + + +class RegisterRestServlet(RestServlet): + PATTERN = client_v2_pattern("/register") + + def __init__(self, hs): + super(RegisterRestServlet, self).__init__() + self.hs = hs + self.auth = hs.get_auth() + self.auth_handler = hs.get_handlers().auth_handler + self.registration_handler = hs.get_handlers().registration_handler + self.identity_handler = hs.get_handlers().identity_handler + self.login_handler = hs.get_handlers().login_handler + + @defer.inlineCallbacks + def on_POST(self, request): + yield run_on_reactor() + + body = parse_request_allow_empty(request) + if 'password' not in body: + raise SynapseError(400, "", Codes.MISSING_PARAM) + + if 'username' in body: + desired_username = body['username'] + yield self.registration_handler.check_username(desired_username) + + is_using_shared_secret = False + is_application_server = False + + service = None + if 'access_token' in request.args: + service = yield self.auth.get_appservice_by_req(request) + + if self.hs.config.enable_registration_captcha: + flows = [ + [LoginType.RECAPTCHA], + [LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA] + ] + else: + flows = [ + [LoginType.DUMMY], + [LoginType.EMAIL_IDENTITY] + ] + + if service: + is_application_server = True + elif 'mac' in body: + # Check registration-specific shared secret auth + if 'username' not in body: + raise SynapseError(400, "", Codes.MISSING_PARAM) + self._check_shared_secret_auth( + body['username'], body['mac'] + ) + is_using_shared_secret = True + else: + authed, result, params = yield self.auth_handler.check_auth( + flows, body, self.hs.get_ip_from_request(request) + ) + + if not authed: + defer.returnValue((401, result)) + + can_register = ( + not self.hs.config.disable_registration + or is_application_server + or is_using_shared_secret + ) + if not can_register: + raise SynapseError(403, "Registration has been disabled") + + if 'password' not in params: + raise SynapseError(400, "", Codes.MISSING_PARAM) + desired_username = params['username'] if 'username' in params else None + new_password = params['password'] + + (user_id, token) = yield self.registration_handler.register( + localpart=desired_username, + password=new_password + ) + + if LoginType.EMAIL_IDENTITY in result: + threepid = result[LoginType.EMAIL_IDENTITY] + + for reqd in ['medium', 'address', 'validated_at']: + if reqd not in threepid: + logger.info("Can't add incomplete 3pid") + else: + yield self.login_handler.add_threepid( + user_id, + threepid['medium'], + threepid['address'], + threepid['validated_at'], + ) + + if 'bind_email' in params and params['bind_email']: + logger.info("bind_email specified: binding") + + emailThreepid = result[LoginType.EMAIL_IDENTITY] + threepid_creds = emailThreepid['threepid_creds'] + logger.debug("Binding emails %s to %s" % ( + emailThreepid, user_id + )) + yield self.identity_handler.bind_threepid(threepid_creds, user_id) + else: + logger.info("bind_email not specified: not binding email") + + result = { + "user_id": user_id, + "access_token": token, + "home_server": self.hs.hostname, + } + + defer.returnValue((200, result)) + + def on_OPTIONS(self, _): + return 200, {} + + def _check_shared_secret_auth(self, username, mac): + if not self.hs.config.registration_shared_secret: + raise SynapseError(400, "Shared secret registration is not enabled") + + user = username.encode("utf-8") + + # str() because otherwise hmac complains that 'unicode' does not + # have the buffer interface + got_mac = str(mac) + + want_mac = hmac.new( + key=self.hs.config.registration_shared_secret, + msg=user, + digestmod=sha1, + ).hexdigest() + + if compare_digest(want_mac, got_mac): + return True + else: + raise SynapseError( + 403, "HMAC incorrect", + ) + + +def register_servlets(hs, http_server): + RegisterRestServlet(hs).register(http_server) diff --git a/synapse/server.py b/synapse/server.py index a602b425e3..d61a228c36 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -65,6 +65,7 @@ class BaseHomeServer(object): 'replication_layer', 'datastore', 'handlers', + 'v1auth', 'auth', 'rest_servlet_factory', 'state_handler', @@ -182,6 +183,15 @@ class HomeServer(BaseHomeServer): def build_auth(self): return Auth(self) + def build_v1auth(self): + orf = Auth(self) + # Matrix spec makes no reference to what HTTP status code is returned, + # but the V1 API uses 403 where it means 401, and the webclient + # relies on this behaviour, so V1 gets its own copy of the auth + # with backwards compat behaviour. + orf.TOKEN_NOT_FOUND_HTTP_STATUS = 403 + return orf + def build_state_handler(self): return StateHandler(self) diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index f4dec70393..61215bbc7b 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -51,7 +51,7 @@ logger = logging.getLogger(__name__) # Remember to update this number every time a change is made to database # schema files, so the users will be informed on server restarts. -SCHEMA_VERSION = 15 +SCHEMA_VERSION = 16 dir_path = os.path.abspath(os.path.dirname(__file__)) @@ -104,14 +104,16 @@ class DataStore(RoomMemberStore, RoomStore, self.client_ip_last_seen.prefill(*key + (now,)) - yield self._simple_insert( + yield self._simple_upsert( "user_ips", - { - "user": user.to_string(), + keyvalues={ + "user_id": user.to_string(), "access_token": access_token, - "device_id": device_id, "ip": ip, "user_agent": user_agent, + }, + values={ + "device_id": device_id, "last_seen": now, }, desc="insert_client_ip", @@ -120,7 +122,7 @@ class DataStore(RoomMemberStore, RoomStore, def get_user_ip_and_agents(self, user): return self._simple_select_list( table="user_ips", - keyvalues={"user": user.to_string()}, + keyvalues={"user_id": user.to_string()}, retcols=[ "device_id", "access_token", "ip", "user_agent", "last_seen" ], @@ -148,21 +150,23 @@ class UpgradeDatabaseException(PrepareDatabaseException): pass -def prepare_database(db_conn): +def prepare_database(db_conn, database_engine): """Prepares a database for usage. Will either create all necessary tables or upgrade from an older schema version. """ try: cur = db_conn.cursor() - version_info = _get_or_create_schema_state(cur) + version_info = _get_or_create_schema_state(cur, database_engine) if version_info: user_version, delta_files, upgraded = version_info - _upgrade_existing_database(cur, user_version, delta_files, upgraded) + _upgrade_existing_database( + cur, user_version, delta_files, upgraded, database_engine + ) else: - _setup_new_database(cur) + _setup_new_database(cur, database_engine) - cur.execute("PRAGMA user_version = %d" % (SCHEMA_VERSION,)) + # cur.execute("PRAGMA user_version = %d" % (SCHEMA_VERSION,)) cur.close() db_conn.commit() @@ -171,7 +175,7 @@ def prepare_database(db_conn): raise -def _setup_new_database(cur): +def _setup_new_database(cur, database_engine): """Sets up the database by finding a base set of "full schemas" and then applying any necessary deltas. @@ -225,31 +229,30 @@ def _setup_new_database(cur): directory_entries = os.listdir(sql_dir) - sql_script = "BEGIN TRANSACTION;\n" for filename in fnmatch.filter(directory_entries, "*.sql"): sql_loc = os.path.join(sql_dir, filename) logger.debug("Applying schema %s", sql_loc) - sql_script += read_schema(sql_loc) - sql_script += "\n" - sql_script += "COMMIT TRANSACTION;" - cur.executescript(sql_script) + executescript(cur, sql_loc) cur.execute( - "INSERT OR REPLACE INTO schema_version (version, upgraded)" - " VALUES (?,?)", - (max_current_ver, False) + database_engine.convert_param_style( + "INSERT INTO schema_version (version, upgraded)" + " VALUES (?,?)" + ), + (max_current_ver, False,) ) _upgrade_existing_database( cur, current_version=max_current_ver, applied_delta_files=[], - upgraded=False + upgraded=False, + database_engine=database_engine, ) def _upgrade_existing_database(cur, current_version, applied_delta_files, - upgraded): + upgraded, database_engine): """Upgrades an existing database. Delta files can either be SQL stored in *.sql files, or python modules @@ -305,6 +308,8 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files, if not upgraded: start_ver += 1 + logger.debug("applied_delta_files: %s", applied_delta_files) + for v in range(start_ver, SCHEMA_VERSION + 1): logger.debug("Upgrading schema to v%d", v) @@ -321,6 +326,7 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files, directory_entries.sort() for file_name in directory_entries: relative_path = os.path.join(str(v), file_name) + logger.debug("Found file: %s", relative_path) if relative_path in applied_delta_files: continue @@ -342,9 +348,8 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files, module.run_upgrade(cur) elif ext == ".sql": # A plain old .sql file, just read and execute it - delta_schema = read_schema(absolute_path) logger.debug("Applying schema %s", relative_path) - cur.executescript(delta_schema) + executescript(cur, absolute_path) else: # Not a valid delta file. logger.warn( @@ -356,24 +361,82 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files, # Mark as done. cur.execute( - "INSERT INTO applied_schema_deltas (version, file)" - " VALUES (?,?)", + database_engine.convert_param_style( + "INSERT INTO applied_schema_deltas (version, file)" + " VALUES (?,?)", + ), (v, relative_path) ) cur.execute( - "INSERT OR REPLACE INTO schema_version (version, upgraded)" - " VALUES (?,?)", + database_engine.convert_param_style( + "REPLACE INTO schema_version (version, upgraded)" + " VALUES (?,?)", + ), (v, True) ) -def _get_or_create_schema_state(txn): +def get_statements(f): + statement_buffer = "" + in_comment = False # If we're in a /* ... */ style comment + + for line in f: + line = line.strip() + + if in_comment: + # Check if this line contains an end to the comment + comments = line.split("*/", 1) + if len(comments) == 1: + continue + line = comments[1] + in_comment = False + + # Remove inline block comments + line = re.sub(r"/\*.*\*/", " ", line) + + # Does this line start a comment? + comments = line.split("/*", 1) + if len(comments) > 1: + line = comments[0] + in_comment = True + + # Deal with line comments + line = line.split("--", 1)[0] + line = line.split("//", 1)[0] + + # Find *all* semicolons. We need to treat first and last entry + # specially. + statements = line.split(";") + + # We must prepend statement_buffer to the first statement + first_statement = "%s %s" % ( + statement_buffer.strip(), + statements[0].strip() + ) + statements[0] = first_statement + + # Every entry, except the last, is a full statement + for statement in statements[:-1]: + yield statement.strip() + + # The last entry did *not* end in a semicolon, so we store it for the + # next semicolon we find + statement_buffer = statements[-1].strip() + + +def executescript(txn, schema_path): + with open(schema_path, 'r') as f: + for statement in get_statements(f): + txn.execute(statement) + + +def _get_or_create_schema_state(txn, database_engine): + # Bluntly try creating the schema_version tables. schema_path = os.path.join( dir_path, "schema", "schema_version.sql", ) - create_schema = read_schema(schema_path) - txn.executescript(create_schema) + executescript(txn, schema_path) txn.execute("SELECT version, upgraded FROM schema_version") row = txn.fetchone() @@ -382,10 +445,13 @@ def _get_or_create_schema_state(txn): if current_version: txn.execute( - "SELECT file FROM applied_schema_deltas WHERE version >= ?", + database_engine.convert_param_style( + "SELECT file FROM applied_schema_deltas WHERE version >= ?" + ), (current_version,) ) - return current_version, txn.fetchall(), upgraded + applied_deltas = [d for d, in txn.fetchall()] + return current_version, applied_deltas, upgraded return None @@ -417,7 +483,19 @@ def prepare_sqlite3_database(db_conn): if row and row[0]: db_conn.execute( - "INSERT OR REPLACE INTO schema_version (version, upgraded)" + "REPLACE INTO schema_version (version, upgraded)" " VALUES (?,?)", (row[0], False) ) + + +def are_all_users_on_domain(txn, database_engine, domain): + sql = database_engine.convert_param_style( + "SELECT COUNT(*) FROM users WHERE name NOT LIKE ?" + ) + pat = "%:" + domain + txn.execute(sql, (pat,)) + num_not_matching = txn.fetchall()[0][0] + if num_not_matching == 0: + return True + return False diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index e3e67d8e0d..6017c2a6e8 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -22,6 +22,8 @@ from synapse.util.logcontext import PreserveLoggingContext, LoggingContext from synapse.util.lrucache import LruCache import synapse.metrics +from util.id_generators import IdGenerator, StreamIdGenerator + from twisted.internet import defer from collections import namedtuple, OrderedDict @@ -145,11 +147,12 @@ class LoggingTransaction(object): """An object that almost-transparently proxies for the 'txn' object passed to the constructor. Adds logging and metrics to the .execute() method.""" - __slots__ = ["txn", "name"] + __slots__ = ["txn", "name", "database_engine"] - def __init__(self, txn, name): + def __init__(self, txn, name, database_engine): object.__setattr__(self, "txn", txn) object.__setattr__(self, "name", name) + object.__setattr__(self, "database_engine", database_engine) def __getattr__(self, name): return getattr(self.txn, name) @@ -161,26 +164,32 @@ class LoggingTransaction(object): # TODO(paul): Maybe use 'info' and 'debug' for values? sql_logger.debug("[SQL] {%s} %s", self.name, sql) - try: - if args and args[0]: - values = args[0] + sql = self.database_engine.convert_param_style(sql) + + if args and args[0]: + args = list(args) + args[0] = [ + self.database_engine.encode_parameter(a) for a in args[0] + ] + try: sql_logger.debug( - "[SQL values] {%s} " + ", ".join(("<%r>",) * len(values)), + "[SQL values] {%s} " + ", ".join(("<%r>",) * len(args[0])), self.name, - *values + *args[0] ) - except: - # Don't let logging failures stop SQL from working - pass + except: + # Don't let logging failures stop SQL from working + pass start = time.time() * 1000 + try: return self.txn.execute( sql, *args, **kwargs ) - except: - logger.exception("[SQL FAIL] {%s}", self.name) - raise + except Exception as e: + logger.debug("[SQL FAIL] {%s} %s", self.name, e) + raise finally: msecs = (time.time() * 1000) - start sql_logger.debug("[SQL time] {%s} %f", self.name, msecs) @@ -245,6 +254,14 @@ class SQLBaseStore(object): self._get_event_cache = Cache("*getEvent*", keylen=3, lru=True, max_entries=hs.config.event_cache_size) + self.database_engine = hs.database_engine + + self._stream_id_gen = StreamIdGenerator() + self._transaction_id_gen = IdGenerator("sent_transactions", "id", self) + self._state_groups_id_gen = IdGenerator("state_groups", "id", self) + self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self) + self._pushers_id_gen = IdGenerator("pushers", "id", self) + def start_profiling(self): self._previous_loop_ts = self._clock.time_msec() @@ -281,8 +298,11 @@ class SQLBaseStore(object): start_time = time.time() * 1000 - def inner_func(txn, *args, **kwargs): + def inner_func(conn, *args, **kwargs): with LoggingContext("runInteraction") as context: + if self.database_engine.is_connection_closed(conn): + conn.reconnect() + current_context.copy_to(context) start = time.time() * 1000 txn_id = self._TXN_ID @@ -296,9 +316,48 @@ class SQLBaseStore(object): sql_scheduling_timer.inc_by(time.time() * 1000 - start_time) transaction_logger.debug("[TXN START] {%s}", name) try: - return func(LoggingTransaction(txn, name), *args, **kwargs) - except: - logger.exception("[TXN FAIL] {%s}", name) + i = 0 + N = 5 + while True: + try: + txn = conn.cursor() + return func( + LoggingTransaction(txn, name, self.database_engine), + *args, **kwargs + ) + except self.database_engine.module.OperationalError as e: + # This can happen if the database disappears mid + # transaction. + logger.warn( + "[TXN OPERROR] {%s} %s %d/%d", + name, e, i, N + ) + if i < N: + i += 1 + try: + conn.rollback() + except self.database_engine.module.Error as e1: + logger.warn( + "[TXN EROLL] {%s} %s", + name, e1, + ) + continue + except self.database_engine.module.DatabaseError as e: + if self.database_engine.is_deadlock(e): + logger.warn("[TXN DEADLOCK] {%s} %d/%d", name, i, N) + if i < N: + i += 1 + try: + conn.rollback() + except self.database_engine.module.Error as e1: + logger.warn( + "[TXN EROLL] {%s} %s", + name, e1, + ) + continue + raise + except Exception as e: + logger.debug("[TXN FAIL] {%s} %s", name, e) raise finally: end = time.time() * 1000 @@ -311,7 +370,7 @@ class SQLBaseStore(object): sql_txn_timer.inc_by(duration, desc) with PreserveLoggingContext(): - result = yield self._db_pool.runInteraction( + result = yield self._db_pool.runWithConnection( inner_func, *args, **kwargs ) defer.returnValue(result) @@ -342,11 +401,11 @@ class SQLBaseStore(object): The result of decoder(results) """ def interaction(txn): - cursor = txn.execute(query, args) + txn.execute(query, args) if decoder: - return decoder(cursor) + return decoder(txn) else: - return cursor.fetchall() + return txn.fetchall() return self.runInteraction(desc, interaction) @@ -356,27 +415,29 @@ class SQLBaseStore(object): # "Simple" SQL API methods that operate on a single table with no JOINs, # no complex WHERE clauses, just a dict of values for columns. - def _simple_insert(self, table, values, or_replace=False, or_ignore=False, + @defer.inlineCallbacks + def _simple_insert(self, table, values, or_ignore=False, desc="_simple_insert"): """Executes an INSERT query on the named table. Args: table : string giving the table name values : dict of new column names and values for them - or_replace : bool; if True performs an INSERT OR REPLACE """ - return self.runInteraction( - desc, - self._simple_insert_txn, table, values, or_replace=or_replace, - or_ignore=or_ignore, - ) + try: + yield self.runInteraction( + desc, + self._simple_insert_txn, table, values, + ) + except self.database_engine.module.IntegrityError: + # We have to do or_ignore flag at this layer, since we can't reuse + # a cursor after we receive an error from the db. + if not or_ignore: + raise @log_function - def _simple_insert_txn(self, txn, table, values, or_replace=False, - or_ignore=False): - sql = "%s INTO %s (%s) VALUES(%s)" % ( - ("INSERT OR REPLACE" if or_replace else - "INSERT OR IGNORE" if or_ignore else "INSERT"), + def _simple_insert_txn(self, txn, table, values): + sql = "INSERT INTO %s (%s) VALUES(%s)" % ( table, ", ".join(k for k in values), ", ".join("?" for k in values) @@ -388,22 +449,26 @@ class SQLBaseStore(object): ) txn.execute(sql, values.values()) - return txn.lastrowid - def _simple_upsert(self, table, keyvalues, values, desc="_simple_upsert"): + def _simple_upsert(self, table, keyvalues, values, + insertion_values={}, desc="_simple_upsert"): """ Args: table (str): The table to upsert into keyvalues (dict): The unique key tables and their new values values (dict): The nonunique columns and their new values + insertion_values (dict): key/values to use when inserting Returns: A deferred """ return self.runInteraction( desc, - self._simple_upsert_txn, table, keyvalues, values + self._simple_upsert_txn, table, keyvalues, values, insertion_values, ) - def _simple_upsert_txn(self, txn, table, keyvalues, values): + def _simple_upsert_txn(self, txn, table, keyvalues, values, insertion_values={}): + # We need to lock the table :( + self.database_engine.lock_table(txn, table) + # Try to update sql = "UPDATE %s SET %s WHERE %s" % ( table, @@ -422,6 +487,7 @@ class SQLBaseStore(object): allvalues = {} allvalues.update(keyvalues) allvalues.update(values) + allvalues.update(insertion_values) sql = "INSERT INTO %s (%s) VALUES (%s)" % ( table, @@ -489,8 +555,7 @@ class SQLBaseStore(object): def _simple_select_onecol_txn(self, txn, table, keyvalues, retcol): sql = ( - "SELECT %(retcol)s FROM %(table)s WHERE %(where)s " - "ORDER BY rowid asc" + "SELECT %(retcol)s FROM %(table)s WHERE %(where)s" ) % { "retcol": retcol, "table": table, @@ -548,14 +613,14 @@ class SQLBaseStore(object): retcols : list of strings giving the names of the columns to return """ if keyvalues: - sql = "SELECT %s FROM %s WHERE %s ORDER BY rowid asc" % ( + sql = "SELECT %s FROM %s WHERE %s" % ( ", ".join(retcols), table, " AND ".join("%s = ?" % (k, ) for k in keyvalues) ) txn.execute(sql, keyvalues.values()) else: - sql = "SELECT %s FROM %s ORDER BY rowid asc" % ( + sql = "SELECT %s FROM %s" % ( ", ".join(retcols), table ) @@ -607,10 +672,10 @@ class SQLBaseStore(object): def _simple_select_one_txn(self, txn, table, keyvalues, retcols, allow_none=False): - select_sql = "SELECT %s FROM %s WHERE %s ORDER BY rowid asc" % ( + select_sql = "SELECT %s FROM %s WHERE %s" % ( ", ".join(retcols), table, - " AND ".join("%s = ?" % (k) for k in keyvalues) + " AND ".join("%s = ?" % (k,) for k in keyvalues) ) txn.execute(select_sql, keyvalues.values()) @@ -648,6 +713,11 @@ class SQLBaseStore(object): updatevalues=updatevalues, ) + # if txn.rowcount == 0: + # raise StoreError(404, "No row found") + if txn.rowcount > 1: + raise StoreError(500, "More than one row matched") + return ret return self.runInteraction(desc, func) @@ -860,6 +930,12 @@ class SQLBaseStore(object): result = txn.fetchone() return result[0] if result else None + def get_next_stream_id(self): + with self._next_stream_id_lock: + i = self._next_stream_id + self._next_stream_id += 1 + return i + class _RollbackButIsFineException(Exception): """ This exception is used to rollback a transaction without implying @@ -883,7 +959,7 @@ class Table(object): _select_where_clause = "SELECT %s FROM %s WHERE %s" _select_clause = "SELECT %s FROM %s" - _insert_clause = "INSERT OR REPLACE INTO %s (%s) VALUES (%s)" + _insert_clause = "REPLACE INTO %s (%s) VALUES (%s)" @classmethod def select_statement(cls, where_clause=None): diff --git a/synapse/storage/appservice.py b/synapse/storage/appservice.py index f8cbb3f323..63d1af4e86 100644 --- a/synapse/storage/appservice.py +++ b/synapse/storage/appservice.py @@ -366,11 +366,11 @@ class ApplicationServiceTransactionStore(SQLBaseStore): new_txn_id = max(highest_txn_id, last_txn_id) + 1 # Insert new txn into txn table - event_ids = [e.event_id for e in events] + event_ids = json.dumps([e.event_id for e in events]) txn.execute( "INSERT INTO application_services_txns(as_id, txn_id, event_ids) " "VALUES(?,?,?)", - (service.id, new_txn_id, json.dumps(event_ids)) + (service.id, new_txn_id, event_ids) ) return AppServiceTransaction( service=service, id=new_txn_id, events=events diff --git a/synapse/storage/directory.py b/synapse/storage/directory.py index 0199539fea..2b2bdf8615 100644 --- a/synapse/storage/directory.py +++ b/synapse/storage/directory.py @@ -21,8 +21,6 @@ from twisted.internet import defer from collections import namedtuple -import sqlite3 - RoomAliasMapping = namedtuple( "RoomAliasMapping", @@ -91,7 +89,7 @@ class DirectoryStore(SQLBaseStore): }, desc="create_room_alias_association", ) - except sqlite3.IntegrityError: + except self.database_engine.module.IntegrityError: raise SynapseError( 409, "Room alias %s already exists" % room_alias.to_string() ) @@ -120,12 +118,12 @@ class DirectoryStore(SQLBaseStore): defer.returnValue(room_id) def _delete_room_alias_txn(self, txn, room_alias): - cursor = txn.execute( + txn.execute( "SELECT room_id FROM room_aliases WHERE room_alias = ?", (room_alias.to_string(),) ) - res = cursor.fetchone() + res = txn.fetchone() if res: room_id = res[0] else: diff --git a/synapse/storage/engines/__init__.py b/synapse/storage/engines/__init__.py new file mode 100644 index 0000000000..eb76df7f01 --- /dev/null +++ b/synapse/storage/engines/__init__.py @@ -0,0 +1,37 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .postgres import PostgresEngine +from .sqlite3 import Sqlite3Engine + +import importlib + + +SUPPORTED_MODULE = { + "sqlite3": Sqlite3Engine, + "psycopg2": PostgresEngine, +} + + +def create_engine(name): + engine_class = SUPPORTED_MODULE.get(name, None) + + if engine_class: + module = importlib.import_module(name) + return engine_class(module) + + raise RuntimeError( + "Unsupported database engine '%s'" % (name,) + ) diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py new file mode 100644 index 0000000000..b8cca9b187 --- /dev/null +++ b/synapse/storage/engines/postgres.py @@ -0,0 +1,47 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.storage import prepare_database + + +class PostgresEngine(object): + def __init__(self, database_module): + self.module = database_module + self.module.extensions.register_type(self.module.extensions.UNICODE) + + def convert_param_style(self, sql): + return sql.replace("?", "%s") + + def encode_parameter(self, param): + return param + + def on_new_connection(self, db_conn): + db_conn.set_isolation_level( + self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ + ) + + def prepare_database(self, db_conn): + prepare_database(db_conn, self) + + def is_deadlock(self, error): + if isinstance(error, self.module.DatabaseError): + return error.pgcode in ["40001", "40P01"] + return False + + def is_connection_closed(self, conn): + return bool(conn) + + def lock_table(self, txn, table): + txn.execute("LOCK TABLE %s in EXCLUSIVE MODE" % (table,)) diff --git a/synapse/storage/engines/sqlite3.py b/synapse/storage/engines/sqlite3.py new file mode 100644 index 0000000000..f62d5d1205 --- /dev/null +++ b/synapse/storage/engines/sqlite3.py @@ -0,0 +1,43 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.storage import prepare_database, prepare_sqlite3_database + + +class Sqlite3Engine(object): + def __init__(self, database_module): + self.module = database_module + + def convert_param_style(self, sql): + return sql + + def encode_parameter(self, param): + return param + + def on_new_connection(self, db_conn): + self.prepare_database(db_conn) + + def prepare_database(self, db_conn): + prepare_sqlite3_database(db_conn) + prepare_database(db_conn, self) + + def is_deadlock(self, error): + return False + + def is_connection_closed(self, conn): + return False + + def lock_table(self, txn, table): + return diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py index 032334bfd6..54a3c9d805 100644 --- a/synapse/storage/event_federation.py +++ b/synapse/storage/event_federation.py @@ -153,7 +153,7 @@ class EventFederationStore(SQLBaseStore): results = self._get_prev_events_and_state( txn, event_id, - is_state=1, + is_state=True, ) return [(e_id, h, ) for e_id, h, _ in results] @@ -164,7 +164,7 @@ class EventFederationStore(SQLBaseStore): } if is_state is not None: - keyvalues["is_state"] = is_state + keyvalues["is_state"] = bool(is_state) res = self._simple_select_list_txn( txn, @@ -242,7 +242,6 @@ class EventFederationStore(SQLBaseStore): "room_id": room_id, "min_depth": depth, }, - or_replace=True, ) def _handle_prev_events(self, txn, outlier, event_id, prev_events, @@ -260,9 +259,8 @@ class EventFederationStore(SQLBaseStore): "event_id": event_id, "prev_event_id": e_id, "room_id": room_id, - "is_state": 0, + "is_state": False, }, - or_ignore=True, ) # Update the extremities table if this is not an outlier. @@ -281,19 +279,19 @@ class EventFederationStore(SQLBaseStore): # We only insert as a forward extremity the new event if there are # no other events that reference it as a prev event query = ( - "INSERT OR IGNORE INTO %(table)s (event_id, room_id) " - "SELECT ?, ? WHERE NOT EXISTS (" - "SELECT 1 FROM %(event_edges)s WHERE " - "prev_event_id = ? " - ")" - ) % { - "table": "event_forward_extremities", - "event_edges": "event_edges", - } + "SELECT 1 FROM event_edges WHERE prev_event_id = ?" + ) - logger.debug("query: %s", query) + txn.execute(query, (event_id,)) - txn.execute(query, (event_id, room_id, event_id)) + if not txn.fetchone(): + query = ( + "INSERT INTO event_forward_extremities" + " (event_id, room_id)" + " VALUES (?, ?)" + ) + + txn.execute(query, (event_id, room_id)) # Insert all the prev_events as a backwards thing, they'll get # deleted in a second if they're incorrect anyway. @@ -306,7 +304,6 @@ class EventFederationStore(SQLBaseStore): "event_id": e_id, "room_id": room_id, }, - or_ignore=True, ) # Also delete from the backwards extremities table all ones that @@ -400,7 +397,7 @@ class EventFederationStore(SQLBaseStore): query = ( "SELECT prev_event_id FROM event_edges " - "WHERE room_id = ? AND event_id = ? AND is_state = 0 " + "WHERE room_id = ? AND event_id = ? AND is_state = ? " "LIMIT ?" ) @@ -409,7 +406,7 @@ class EventFederationStore(SQLBaseStore): for event_id in front: txn.execute( query, - (room_id, event_id, limit - len(event_results)) + (room_id, event_id, False, limit - len(event_results)) ) for e_id, in txn.fetchall(): diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 2425f57f5f..a3c260ddc4 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -52,7 +52,6 @@ class EventsStore(SQLBaseStore): is_new_state=is_new_state, current_state=current_state, ) - self.get_room_events_max_id.invalidate() except _RollbackButIsFineException: pass @@ -96,12 +95,22 @@ class EventsStore(SQLBaseStore): # Remove the any existing cache entries for the event_id self._invalidate_get_event_cache(event.event_id) + if stream_ordering is None: + with self._stream_id_gen.get_next_txn(txn) as stream_ordering: + return self._persist_event_txn( + txn, event, context, backfilled, + stream_ordering=stream_ordering, + is_new_state=is_new_state, + current_state=current_state, + ) + # We purposefully do this first since if we include a `current_state` # key, we *want* to update the `current_state_events` table if current_state: - txn.execute( - "DELETE FROM current_state_events WHERE room_id = ?", - (event.room_id,) + self._simple_delete_txn( + txn, + table="current_state_events", + keyvalues={"room_id": event.room_id}, ) for s in current_state: @@ -114,7 +123,6 @@ class EventsStore(SQLBaseStore): "type": s.type, "state_key": s.state_key, }, - or_replace=True, ) if event.is_state() and is_new_state: @@ -128,7 +136,6 @@ class EventsStore(SQLBaseStore): "type": event.type, "state_key": event.state_key, }, - or_replace=True, ) for prev_state_id, _ in event.prev_state: @@ -151,14 +158,6 @@ class EventsStore(SQLBaseStore): event.depth ) - self._handle_prev_events( - txn, - outlier=outlier, - event_id=event.event_id, - prev_events=event.prev_events, - room_id=event.room_id, - ) - have_persisted = self._simple_select_one_onecol_txn( txn, table="event_json", @@ -169,7 +168,7 @@ class EventsStore(SQLBaseStore): metadata_json = encode_canonical_json( event.internal_metadata.get_dict() - ) + ).decode("UTF-8") # If we have already persisted this event, we don't need to do any # more processing. @@ -185,23 +184,29 @@ class EventsStore(SQLBaseStore): ) txn.execute( sql, - (metadata_json.decode("UTF-8"), event.event_id,) + (metadata_json, event.event_id,) ) sql = ( - "UPDATE events SET outlier = 0" + "UPDATE events SET outlier = ?" " WHERE event_id = ?" ) txn.execute( sql, - (event.event_id,) + (False, event.event_id,) ) return + self._handle_prev_events( + txn, + outlier=outlier, + event_id=event.event_id, + prev_events=event.prev_events, + room_id=event.room_id, + ) + if event.type == EventTypes.Member: self._store_room_member_txn(txn, event) - elif event.type == EventTypes.Feedback: - self._store_feedback_txn(txn, event) elif event.type == EventTypes.Name: self._store_room_name_txn(txn, event) elif event.type == EventTypes.Topic: @@ -224,10 +229,9 @@ class EventsStore(SQLBaseStore): values={ "event_id": event.event_id, "room_id": event.room_id, - "internal_metadata": metadata_json.decode("UTF-8"), + "internal_metadata": metadata_json, "json": encode_canonical_json(event_dict).decode("UTF-8"), }, - or_replace=True, ) content = encode_canonical_json( @@ -245,9 +249,6 @@ class EventsStore(SQLBaseStore): "depth": event.depth, } - if stream_ordering is not None: - vals["stream_ordering"] = stream_ordering - unrec = { k: v for k, v in event.get_dict().items() @@ -264,25 +265,53 @@ class EventsStore(SQLBaseStore): unrec ).decode("UTF-8") - try: - self._simple_insert_txn( - txn, - "events", - vals, - or_replace=(not outlier), - or_ignore=bool(outlier), - ) - except: - logger.warn( - "Failed to persist, probably duplicate: %s", - event.event_id, - exc_info=True, + sql = ( + "INSERT INTO events" + " (stream_ordering, topological_ordering, event_id, type," + " room_id, content, processed, outlier, depth)" + " VALUES (?,?,?,?,?,?,?,?,?)" + ) + + txn.execute( + sql, + ( + stream_ordering, event.depth, event.event_id, event.type, + event.room_id, content, True, outlier, event.depth ) - raise _RollbackButIsFineException("_persist_event") + ) if context.rejected: self._store_rejections_txn(txn, event.event_id, context.rejected) + for hash_alg, hash_base64 in event.hashes.items(): + hash_bytes = decode_base64(hash_base64) + self._store_event_content_hash_txn( + txn, event.event_id, hash_alg, hash_bytes, + ) + + for prev_event_id, prev_hashes in event.prev_events: + for alg, hash_base64 in prev_hashes.items(): + hash_bytes = decode_base64(hash_base64) + self._store_prev_event_hash_txn( + txn, event.event_id, prev_event_id, alg, hash_bytes + ) + + for auth_id, _ in event.auth_events: + self._simple_insert_txn( + txn, + table="event_auth", + values={ + "event_id": event.event_id, + "room_id": event.room_id, + "auth_id": auth_id, + }, + ) + + (ref_alg, ref_hash_bytes) = compute_event_reference_hash(event) + self._store_event_reference_hash_txn( + txn, event.event_id, ref_alg, ref_hash_bytes + ) + if event.is_state(): vals = { "event_id": event.event_id, @@ -301,18 +330,6 @@ class EventsStore(SQLBaseStore): vals, ) - if is_new_state and not context.rejected: - self._simple_insert_txn( - txn, - "current_state_events", - { - "event_id": event.event_id, - "room_id": event.room_id, - "type": event.type, - "state_key": event.state_key, - }, - ) - for e_id, h in event.prev_state: self._simple_insert_txn( txn, @@ -321,39 +338,24 @@ class EventsStore(SQLBaseStore): "event_id": event.event_id, "prev_event_id": e_id, "room_id": event.room_id, - "is_state": 1, + "is_state": True, }, ) - for hash_alg, hash_base64 in event.hashes.items(): - hash_bytes = decode_base64(hash_base64) - self._store_event_content_hash_txn( - txn, event.event_id, hash_alg, hash_bytes, - ) - - for prev_event_id, prev_hashes in event.prev_events: - for alg, hash_base64 in prev_hashes.items(): - hash_bytes = decode_base64(hash_base64) - self._store_prev_event_hash_txn( - txn, event.event_id, prev_event_id, alg, hash_bytes + if is_new_state and not context.rejected: + self._simple_upsert_txn( + txn, + "current_state_events", + keyvalues={ + "room_id": event.room_id, + "type": event.type, + "state_key": event.state_key, + }, + values={ + "event_id": event.event_id, + } ) - for auth_id, _ in event.auth_events: - self._simple_insert_txn( - txn, - table="event_auth", - values={ - "event_id": event.event_id, - "room_id": event.room_id, - "auth_id": auth_id, - }, - ) - - (ref_alg, ref_hash_bytes) = compute_event_reference_hash(event) - self._store_event_reference_hash_txn( - txn, event.event_id, ref_alg, ref_hash_bytes - ) - def _store_redaction(self, txn, event): # invalidate the cache for the redacted event self._invalidate_get_event_cache(event.redacts) diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py index 22b158d71e..cbe9339ccf 100644 --- a/synapse/storage/keys.py +++ b/synapse/storage/keys.py @@ -57,16 +57,18 @@ class KeyStore(SQLBaseStore): OpenSSL.crypto.FILETYPE_ASN1, tls_certificate ) fingerprint = hashlib.sha256(tls_certificate_bytes).hexdigest() - return self._simple_insert( + return self._simple_upsert( table="server_tls_certificates", - values={ + keyvalues={ "server_name": server_name, "fingerprint": fingerprint, + }, + values={ "from_server": from_server, "ts_added_ms": time_now_ms, "tls_certificate": buffer(tls_certificate_bytes), }, - or_ignore=True, + desc="store_server_certificate", ) @defer.inlineCallbacks @@ -107,16 +109,18 @@ class KeyStore(SQLBaseStore): ts_now_ms (int): The time now in milliseconds verification_key (VerifyKey): The NACL verify key. """ - return self._simple_insert( + return self._simple_upsert( table="server_signature_keys", - values={ + keyvalues={ "server_name": server_name, "key_id": "%s:%s" % (verify_key.alg, verify_key.version), + }, + values={ "from_server": from_server, "ts_added_ms": time_now_ms, "verify_key": buffer(verify_key.encode()), }, - or_ignore=True, + desc="store_server_verify_key", ) def store_server_keys_json(self, server_name, key_id, from_server, diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py index 87fba55439..22ec94bc16 100644 --- a/synapse/storage/presence.py +++ b/synapse/storage/presence.py @@ -57,6 +57,7 @@ class PresenceStore(SQLBaseStore): values={"observed_user_id": observed_localpart, "observer_user_id": observer_userid}, desc="allow_presence_visible", + or_ignore=True, ) def disallow_presence_visible(self, observed_localpart, observer_userid): diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py index c47bdc2861..ee7718d5ed 100644 --- a/synapse/storage/push_rule.py +++ b/synapse/storage/push_rule.py @@ -154,7 +154,7 @@ class PushRuleStore(SQLBaseStore): txn.execute(sql, (user_name, priority_class, new_rule_priority)) # now insert the new rule - sql = "INSERT OR REPLACE INTO "+PushRuleTable.table_name+" (" + sql = "INSERT INTO "+PushRuleTable.table_name+" (" sql += ",".join(new_rule.keys())+") VALUES (" sql += ", ".join(["?" for _ in new_rule.keys()])+")" @@ -183,7 +183,7 @@ class PushRuleStore(SQLBaseStore): new_rule['priority_class'] = priority_class new_rule['priority'] = new_prio - sql = "INSERT OR REPLACE INTO "+PushRuleTable.table_name+" (" + sql = "INSERT INTO "+PushRuleTable.table_name+" (" sql += ",".join(new_rule.keys())+") VALUES (" sql += ", ".join(["?" for _ in new_rule.keys()])+")" diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py index 000502b4ff..2582a1da66 100644 --- a/synapse/storage/pusher.py +++ b/synapse/storage/pusher.py @@ -13,13 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -import collections - from ._base import SQLBaseStore, Table from twisted.internet import defer from synapse.api.errors import StoreError +from syutil.jsonutil import encode_canonical_json + import logging logger = logging.getLogger(__name__) @@ -27,93 +27,55 @@ logger = logging.getLogger(__name__) class PusherStore(SQLBaseStore): @defer.inlineCallbacks - def get_pushers_by_app_id_and_pushkey(self, app_id_and_pushkey): + def get_pushers_by_app_id_and_pushkey(self, app_id, pushkey): sql = ( - "SELECT id, user_name, kind, profile_tag, app_id," - "app_display_name, device_display_name, pushkey, ts, data, " - "last_token, last_success, failing_since " - "FROM pushers " + "SELECT * FROM pushers " "WHERE app_id = ? AND pushkey = ?" ) - rows = yield self._execute( - "get_pushers_by_app_id_and_pushkey", None, sql, - app_id_and_pushkey[0], app_id_and_pushkey[1] + rows = yield self._execute_and_decode( + "get_pushers_by_app_id_and_pushkey", + sql, + app_id, pushkey ) - ret = [ - { - "id": r[0], - "user_name": r[1], - "kind": r[2], - "profile_tag": r[3], - "app_id": r[4], - "app_display_name": r[5], - "device_display_name": r[6], - "pushkey": r[7], - "pushkey_ts": r[8], - "data": r[9], - "last_token": r[10], - "last_success": r[11], - "failing_since": r[12] - } - for r in rows - ] - - defer.returnValue(ret[0]) + defer.returnValue(rows) @defer.inlineCallbacks def get_all_pushers(self): sql = ( - "SELECT id, user_name, kind, profile_tag, app_id," - "app_display_name, device_display_name, pushkey, ts, data, " - "last_token, last_success, failing_since " - "FROM pushers" + "SELECT * FROM pushers" ) - rows = yield self._execute("get_all_pushers", None, sql) - - ret = [ - { - "id": r[0], - "user_name": r[1], - "kind": r[2], - "profile_tag": r[3], - "app_id": r[4], - "app_display_name": r[5], - "device_display_name": r[6], - "pushkey": r[7], - "pushkey_ts": r[8], - "data": r[9], - "last_token": r[10], - "last_success": r[11], - "failing_since": r[12] - } - for r in rows - ] - - defer.returnValue(ret) + rows = yield self._execute_and_decode("get_all_pushers", sql) + + defer.returnValue(rows) @defer.inlineCallbacks - def add_pusher(self, user_name, profile_tag, kind, app_id, + def add_pusher(self, user_name, access_token, profile_tag, kind, app_id, app_display_name, device_display_name, pushkey, pushkey_ts, lang, data): try: + next_id = yield self._pushers_id_gen.get_next() yield self._simple_upsert( PushersTable.table_name, dict( app_id=app_id, pushkey=pushkey, + user_name=user_name, ), dict( - user_name=user_name, + access_token=access_token, kind=kind, profile_tag=profile_tag, app_display_name=app_display_name, device_display_name=device_display_name, ts=pushkey_ts, lang=lang, - data=data + data=encode_canonical_json(data).decode("UTF-8"), + ), + insertion_values=dict( + id=next_id, ), desc="add_pusher", ) @@ -122,37 +84,38 @@ class PusherStore(SQLBaseStore): raise StoreError(500, "Problem creating pusher.") @defer.inlineCallbacks - def delete_pusher_by_app_id_pushkey(self, app_id, pushkey): + def delete_pusher_by_app_id_pushkey_user_name(self, app_id, pushkey, user_name): yield self._simple_delete_one( PushersTable.table_name, - {"app_id": app_id, "pushkey": pushkey}, - desc="delete_pusher_by_app_id_pushkey", + {"app_id": app_id, "pushkey": pushkey, 'user_name': user_name}, + desc="delete_pusher_by_app_id_pushkey_user_name", ) @defer.inlineCallbacks - def update_pusher_last_token(self, app_id, pushkey, last_token): + def update_pusher_last_token(self, app_id, pushkey, user_name, last_token): yield self._simple_update_one( PushersTable.table_name, - {'app_id': app_id, 'pushkey': pushkey}, + {'app_id': app_id, 'pushkey': pushkey, 'user_name': user_name}, {'last_token': last_token}, desc="update_pusher_last_token", ) @defer.inlineCallbacks - def update_pusher_last_token_and_success(self, app_id, pushkey, + def update_pusher_last_token_and_success(self, app_id, pushkey, user_name, last_token, last_success): yield self._simple_update_one( PushersTable.table_name, - {'app_id': app_id, 'pushkey': pushkey}, + {'app_id': app_id, 'pushkey': pushkey, 'user_name': user_name}, {'last_token': last_token, 'last_success': last_success}, desc="update_pusher_last_token_and_success", ) @defer.inlineCallbacks - def update_pusher_failing_since(self, app_id, pushkey, failing_since): + def update_pusher_failing_since(self, app_id, pushkey, user_name, + failing_since): yield self._simple_update_one( PushersTable.table_name, - {'app_id': app_id, 'pushkey': pushkey}, + {'app_id': app_id, 'pushkey': pushkey, 'user_name': user_name}, {'failing_since': failing_since}, desc="update_pusher_failing_since", ) @@ -160,21 +123,3 @@ class PusherStore(SQLBaseStore): class PushersTable(Table): table_name = "pushers" - - fields = [ - "id", - "user_name", - "kind", - "profile_tag", - "app_id", - "app_display_name", - "device_display_name", - "pushkey", - "pushkey_ts", - "data", - "last_token", - "last_success", - "failing_since" - ] - - EntryType = collections.namedtuple("PusherEntry", fields) diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index f24154f146..a986c4816e 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -15,8 +15,6 @@ from twisted.internet import defer -from sqlite3 import IntegrityError - from synapse.api.errors import StoreError, Codes from ._base import SQLBaseStore, cached @@ -39,17 +37,13 @@ class RegistrationStore(SQLBaseStore): Raises: StoreError if there was a problem adding this. """ - row = yield self._simple_select_one( - "users", {"name": user_id}, ["id"], - desc="add_access_token_to_user", - ) - if not row: - raise StoreError(400, "Bad user ID supplied.") - row_id = row["id"] + next_id = yield self._access_tokens_id_gen.get_next() + yield self._simple_insert( "access_tokens", { - "user_id": row_id, + "id": next_id, + "user_id": user_id, "token": token }, desc="add_access_token_to_user", @@ -74,32 +68,71 @@ class RegistrationStore(SQLBaseStore): def _register(self, txn, user_id, token, password_hash): now = int(self.clock.time()) + next_id = self._access_tokens_id_gen.get_next_txn(txn) + try: txn.execute("INSERT INTO users(name, password_hash, creation_ts) " "VALUES (?,?,?)", [user_id, password_hash, now]) - except IntegrityError: + except self.database_engine.module.IntegrityError: raise StoreError( 400, "User ID already taken.", errcode=Codes.USER_IN_USE ) # it's possible for this to get a conflict, but only for a single user # since tokens are namespaced based on their user ID - txn.execute("INSERT INTO access_tokens(user_id, token) " + - "VALUES (?,?)", [txn.lastrowid, token]) + txn.execute( + "INSERT INTO access_tokens(id, user_id, token)" + " VALUES (?,?,?)", + (next_id, user_id, token,) + ) def get_user_by_id(self, user_id): - query = ("SELECT users.name, users.password_hash FROM users" - " WHERE users.name = ?") - return self._execute( - "get_user_by_id", self.cursor_to_dict, query, user_id + return self._simple_select_one( + table="users", + keyvalues={ + "name": user_id, + }, + retcols=["name", "password_hash"], + allow_none=True, + ) + + @defer.inlineCallbacks + def user_set_password_hash(self, user_id, password_hash): + """ + NB. This does *not* evict any cache because the one use for this + removes most of the entries subsequently anyway so it would be + pointless. Use flush_user separately. + """ + yield self._simple_update_one('users', { + 'name': user_id + }, { + 'password_hash': password_hash + }) + + @defer.inlineCallbacks + def user_delete_access_tokens_apart_from(self, user_id, token_id): + rows = yield self.get_user_by_id(user_id) + if len(rows) == 0: + raise Exception("No such user!") + + yield self._execute( + "delete_access_tokens_apart_from", None, + "DELETE FROM access_tokens WHERE user_id = ? AND id != ?", + rows[0]['id'], token_id + ) + + @defer.inlineCallbacks + def flush_user(self, user_id): + rows = yield self._execute( + 'flush_user', None, + "SELECT token FROM access_tokens WHERE user_id = ?", + user_id ) + for r in rows: + self.get_user_by_token.invalidate(r) @cached() - # TODO(paul): Currently there's no code to invalidate this cache. That - # means if/when we ever add internal ways to invalidate access tokens or - # change whether a user is a server admin, those will need to invoke - # store.get_user_by_token.invalidate(token) def get_user_by_token(self, token): """Get a user from the given access token. @@ -134,13 +167,49 @@ class RegistrationStore(SQLBaseStore): "SELECT users.name, users.admin," " access_tokens.device_id, access_tokens.id as token_id" " FROM users" - " INNER JOIN access_tokens on users.id = access_tokens.user_id" + " INNER JOIN access_tokens on users.name = access_tokens.user_id" " WHERE token = ?" ) - cursor = txn.execute(sql, (token,)) - rows = self.cursor_to_dict(cursor) + txn.execute(sql, (token,)) + rows = self.cursor_to_dict(txn) if rows: return rows[0] - raise StoreError(404, "Token not found.") + return None + + @defer.inlineCallbacks + def user_add_threepid(self, user_id, medium, address, validated_at, added_at): + yield self._simple_upsert("user_threepids", { + "user": user_id, + "medium": medium, + "address": address, + }, { + "validated_at": validated_at, + "added_at": added_at, + }) + + @defer.inlineCallbacks + def user_get_threepids(self, user_id): + ret = yield self._simple_select_list( + "user_threepids", { + "user": user_id + }, + ['medium', 'address', 'validated_at', 'added_at'], + 'user_get_threepids' + ) + defer.returnValue(ret) + + @defer.inlineCallbacks + def get_user_by_threepid(self, medium, address): + ret = yield self._simple_select_one( + "user_threepids", + { + "medium": medium, + "address": address + }, + ['user'], True, 'get_user_by_threepid' + ) + if ret: + defer.returnValue(ret['user']) + defer.returnValue(None) diff --git a/synapse/storage/room.py b/synapse/storage/room.py index be3e28c2ea..48ebb33057 100644 --- a/synapse/storage/room.py +++ b/synapse/storage/room.py @@ -72,6 +72,7 @@ class RoomStore(SQLBaseStore): keyvalues={"room_id": room_id}, retcols=RoomsTable.fields, desc="get_room", + allow_none=True, ) @defer.inlineCallbacks @@ -102,24 +103,37 @@ class RoomStore(SQLBaseStore): "ON c.event_id = room_names.event_id " ) - # We use non printing ascii character US () as a seperator + # We use non printing ascii character US (\x1F) as a separator sql = ( - "SELECT r.room_id, n.name, t.topic, " - "group_concat(a.room_alias, '') " - "FROM rooms AS r " - "LEFT JOIN (%(topic)s) AS t ON t.room_id = r.room_id " - "LEFT JOIN (%(name)s) AS n ON n.room_id = r.room_id " - "INNER JOIN room_aliases AS a ON a.room_id = r.room_id " - "WHERE r.is_public = ? " - "GROUP BY r.room_id " + "SELECT r.room_id, max(n.name), max(t.topic)" + " FROM rooms AS r" + " LEFT JOIN (%(topic)s) AS t ON t.room_id = r.room_id" + " LEFT JOIN (%(name)s) AS n ON n.room_id = r.room_id" + " WHERE r.is_public = ?" + " GROUP BY r.room_id" ) % { "topic": topic_subquery, "name": name_subquery, } - c = txn.execute(sql, (is_public,)) + txn.execute(sql, (is_public,)) - return c.fetchall() + rows = txn.fetchall() + + for i, row in enumerate(rows): + room_id = row[0] + aliases = self._simple_select_onecol_txn( + txn, + table="room_aliases", + keyvalues={ + "room_id": room_id + }, + retcol="room_alias", + ) + + rows[i] = list(row) + [aliases] + + return rows rows = yield self.runInteraction( "get_rooms", f @@ -130,9 +144,10 @@ class RoomStore(SQLBaseStore): "room_id": r[0], "name": r[1], "topic": r[2], - "aliases": r[3].split(""), + "aliases": r[3], } for r in rows + if r[3] # We only return rooms that have at least one alias. ] defer.returnValue(ret) diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 52c37c76f5..8ea5756d61 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -40,7 +40,6 @@ class RoomMemberStore(SQLBaseStore): """ try: target_user_id = event.state_key - domain = UserID.from_string(target_user_id).domain except: logger.exception( "Failed to parse target_user_id=%s", target_user_id @@ -65,42 +64,8 @@ class RoomMemberStore(SQLBaseStore): } ) - # Update room hosts table - if event.membership == Membership.JOIN: - sql = ( - "INSERT OR IGNORE INTO room_hosts (room_id, host) " - "VALUES (?, ?)" - ) - txn.execute(sql, (event.room_id, domain)) - elif event.membership != Membership.INVITE: - # Check if this was the last person to have left. - member_events = self._get_members_query_txn( - txn, - where_clause=("c.room_id = ? AND m.membership = ?" - " AND m.user_id != ?"), - where_values=(event.room_id, Membership.JOIN, target_user_id,) - ) - - joined_domains = set() - for e in member_events: - try: - joined_domains.add( - UserID.from_string(e.state_key).domain - ) - except: - # FIXME: How do we deal with invalid user ids in the db? - logger.exception("Invalid user_id: %s", event.state_key) - - if domain not in joined_domains: - sql = ( - "DELETE FROM room_hosts WHERE room_id = ? AND host = ?" - ) - - txn.execute(sql, (event.room_id, domain)) - self.get_rooms_for_user.invalidate(target_user_id) - @defer.inlineCallbacks def get_room_member(self, user_id, room_id): """Retrieve the current state of a room member. @@ -110,41 +75,27 @@ class RoomMemberStore(SQLBaseStore): Returns: Deferred: Results in a MembershipEvent or None. """ - rows = yield self._get_members_by_dict({ - "e.room_id": room_id, - "m.user_id": user_id, - }) + def f(txn): + events = self._get_members_events_txn( + txn, + room_id, + user_id=user_id, + ) - defer.returnValue(rows[0] if rows else None) + return events[0] if events else None - def _get_room_member(self, txn, user_id, room_id): - sql = ( - "SELECT e.* FROM events as e" - " INNER JOIN room_memberships as m" - " ON e.event_id = m.event_id" - " INNER JOIN current_state_events as c" - " ON m.event_id = c.event_id" - " WHERE m.user_id = ? and e.room_id = ?" - " LIMIT 1" - ) - txn.execute(sql, (user_id, room_id)) - rows = self.cursor_to_dict(txn) - if rows: - return self._parse_events_txn(txn, rows)[0] - else: - return None + return self.runInteraction("get_room_member", f) def get_users_in_room(self, room_id): def f(txn): - sql = ( - "SELECT m.user_id FROM room_memberships as m" - " INNER JOIN current_state_events as c" - " ON m.event_id = c.event_id" - " WHERE m.membership = ? AND m.room_id = ?" + + rows = self._get_members_rows_txn( + txn, + room_id=room_id, + membership=Membership.JOIN, ) - txn.execute(sql, (Membership.JOIN, room_id)) - return [r[0] for r in txn.fetchall()] + return [r["user_id"] for r in rows] return self.runInteraction("get_users_in_room", f) def get_room_members(self, room_id, membership=None): @@ -159,11 +110,14 @@ class RoomMemberStore(SQLBaseStore): list of namedtuples representing the members in this room. """ - where = {"m.room_id": room_id} - if membership: - where["m.membership"] = membership + def f(txn): + return self._get_members_events_txn( + txn, + room_id, + membership=membership, + ) - return self._get_members_by_dict(where) + return self.runInteraction("get_room_members", f) def get_rooms_for_user_where_membership_is(self, user_id, membership_list): """ Get all the rooms for this user where the membership for this user @@ -209,32 +163,55 @@ class RoomMemberStore(SQLBaseStore): ] def get_joined_hosts_for_room(self, room_id): - return self._simple_select_onecol( - "room_hosts", - {"room_id": room_id}, - "host", - desc="get_joined_hosts_for_room", + return self.runInteraction( + "get_joined_hosts_for_room", + self._get_joined_hosts_for_room_txn, + room_id, + ) + + def _get_joined_hosts_for_room_txn(self, txn, room_id): + rows = self._get_members_rows_txn( + txn, + room_id, membership=Membership.JOIN + ) + + joined_domains = set( + UserID.from_string(r["user_id"]).domain + for r in rows ) - def _get_members_by_dict(self, where_dict): - clause = " AND ".join("%s = ?" % k for k in where_dict.keys()) - vals = where_dict.values() - return self._get_members_query(clause, vals) + return joined_domains def _get_members_query(self, where_clause, where_values): return self.runInteraction( - "get_members_query", self._get_members_query_txn, + "get_members_query", self._get_members_events_txn, where_clause, where_values ) - def _get_members_query_txn(self, txn, where_clause, where_values): + def _get_members_events_txn(self, txn, room_id, membership=None, user_id=None): + rows = self._get_members_rows_txn( + txn, + room_id, membership, user_id, + ) + return self._get_events_txn(txn, [r["event_id"] for r in rows]) + + def _get_members_rows_txn(self, txn, room_id, membership=None, user_id=None): + where_clause = "c.room_id = ?" + where_values = [room_id] + + if membership: + where_clause += " AND m.membership = ?" + where_values.append(membership) + + if user_id: + where_clause += " AND m.user_id = ?" + where_values.append(user_id) + sql = ( - "SELECT e.* FROM events as e " - "INNER JOIN room_memberships as m " - "ON e.event_id = m.event_id " - "INNER JOIN current_state_events as c " - "ON m.event_id = c.event_id " - "WHERE %(where)s " + "SELECT m.* FROM room_memberships as m" + " INNER JOIN current_state_events as c" + " ON m.event_id = c.event_id" + " WHERE %(where)s" ) % { "where": where_clause, } @@ -242,8 +219,7 @@ class RoomMemberStore(SQLBaseStore): txn.execute(sql, where_values) rows = self.cursor_to_dict(txn) - results = self._parse_events_txn(txn, rows) - return results + return rows @cached() def get_rooms_for_user(self, user_id): diff --git a/synapse/storage/schema/delta/12/v12.sql b/synapse/storage/schema/delta/12/v12.sql index b87ef1fe79..878c36260a 100644 --- a/synapse/storage/schema/delta/12/v12.sql +++ b/synapse/storage/schema/delta/12/v12.sql @@ -17,26 +17,25 @@ CREATE TABLE IF NOT EXISTS rejections( event_id TEXT NOT NULL, reason TEXT NOT NULL, last_check TEXT NOT NULL, - CONSTRAINT ev_id UNIQUE (event_id) ON CONFLICT REPLACE + UNIQUE (event_id) ); -- Push notification endpoints that users have configured CREATE TABLE IF NOT EXISTS pushers ( id INTEGER PRIMARY KEY AUTOINCREMENT, user_name TEXT NOT NULL, - profile_tag varchar(32) NOT NULL, - kind varchar(8) NOT NULL, - app_id varchar(64) NOT NULL, - app_display_name varchar(64) NOT NULL, - device_display_name varchar(128) NOT NULL, - pushkey blob NOT NULL, - ts BIGINT NOT NULL, - lang varchar(8), - data blob, + profile_tag VARCHAR(32) NOT NULL, + kind VARCHAR(8) NOT NULL, + app_id VARCHAR(64) NOT NULL, + app_display_name VARCHAR(64) NOT NULL, + device_display_name VARCHAR(128) NOT NULL, + pushkey VARBINARY(512) NOT NULL, + ts BIGINT UNSIGNED NOT NULL, + lang VARCHAR(8), + data LONGBLOB, last_token TEXT, - last_success BIGINT, - failing_since BIGINT, - FOREIGN KEY(user_name) REFERENCES users(name), + last_success BIGINT UNSIGNED, + failing_since BIGINT UNSIGNED, UNIQUE (app_id, pushkey) ); @@ -55,13 +54,10 @@ CREATE INDEX IF NOT EXISTS push_rules_user_name on push_rules (user_name); CREATE TABLE IF NOT EXISTS user_filters( user_id TEXT, - filter_id INTEGER, - filter_json TEXT, - FOREIGN KEY(user_id) REFERENCES users(id) + filter_id BIGINT UNSIGNED, + filter_json LONGBLOB ); CREATE INDEX IF NOT EXISTS user_filters_by_user_id_filter_id ON user_filters( - user_id, filter_id + user_id, filter_id ); - -PRAGMA user_version = 12; diff --git a/synapse/storage/schema/delta/13/v13.sql b/synapse/storage/schema/delta/13/v13.sql index e491ad5aec..3265924013 100644 --- a/synapse/storage/schema/delta/13/v13.sql +++ b/synapse/storage/schema/delta/13/v13.sql @@ -19,16 +19,13 @@ CREATE TABLE IF NOT EXISTS application_services( token TEXT, hs_token TEXT, sender TEXT, - UNIQUE(token) ON CONFLICT ROLLBACK + UNIQUE(token) ); CREATE TABLE IF NOT EXISTS application_services_regex( id INTEGER PRIMARY KEY AUTOINCREMENT, - as_id INTEGER NOT NULL, + as_id BIGINT UNSIGNED NOT NULL, namespace INTEGER, /* enum[room_id|room_alias|user_id] */ regex TEXT, FOREIGN KEY(as_id) REFERENCES application_services(id) ); - - - diff --git a/synapse/storage/schema/delta/15/appservice_txns.sql b/synapse/storage/schema/delta/15/appservice_txns.sql index 2b27e2a429..db2e720393 100644 --- a/synapse/storage/schema/delta/15/appservice_txns.sql +++ b/synapse/storage/schema/delta/15/appservice_txns.sql @@ -15,16 +15,17 @@ CREATE TABLE IF NOT EXISTS application_services_state( as_id TEXT PRIMARY KEY, - state TEXT, - last_txn TEXT + state VARCHAR(5), + last_txn INTEGER ); CREATE TABLE IF NOT EXISTS application_services_txns( as_id TEXT NOT NULL, txn_id INTEGER NOT NULL, event_ids TEXT NOT NULL, - UNIQUE(as_id, txn_id) ON CONFLICT ROLLBACK + UNIQUE(as_id, txn_id) ); - - +CREATE INDEX IF NOT EXISTS application_services_txns_id ON application_services_txns ( + as_id +); diff --git a/synapse/storage/schema/delta/15/presence_indices.sql b/synapse/storage/schema/delta/15/presence_indices.sql new file mode 100644 index 0000000000..6b8d0f1ca7 --- /dev/null +++ b/synapse/storage/schema/delta/15/presence_indices.sql @@ -0,0 +1,2 @@ + +CREATE INDEX IF NOT EXISTS presence_list_user_id ON presence_list (user_id); diff --git a/synapse/storage/schema/delta/15/v15.sql b/synapse/storage/schema/delta/15/v15.sql new file mode 100644 index 0000000000..f5b2a08ca4 --- /dev/null +++ b/synapse/storage/schema/delta/15/v15.sql @@ -0,0 +1,25 @@ +-- Drop, copy & recreate pushers table to change unique key +-- Also add access_token column at the same time +CREATE TABLE IF NOT EXISTS pushers2 ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_name TEXT NOT NULL, + access_token INTEGER DEFAULT NULL, + profile_tag varchar(32) NOT NULL, + kind varchar(8) NOT NULL, + app_id varchar(64) NOT NULL, + app_display_name varchar(64) NOT NULL, + device_display_name varchar(128) NOT NULL, + pushkey blob NOT NULL, + ts BIGINT NOT NULL, + lang varchar(8), + data blob, + last_token TEXT, + last_success BIGINT, + failing_since BIGINT, + FOREIGN KEY(user_name) REFERENCES users(name), + UNIQUE (app_id, pushkey, user_name) +); +INSERT INTO pushers2 (id, user_name, profile_tag, kind, app_id, app_display_name, device_display_name, pushkey, ts, lang, data, last_token, last_success, failing_since) + SELECT id, user_name, profile_tag, kind, app_id, app_display_name, device_display_name, pushkey, ts, lang, data, last_token, last_success, failing_since FROM pushers; +DROP TABLE pushers; +ALTER TABLE pushers2 RENAME TO pushers; diff --git a/synapse/storage/schema/delta/16/events_order_index.sql b/synapse/storage/schema/delta/16/events_order_index.sql new file mode 100644 index 0000000000..a48f215170 --- /dev/null +++ b/synapse/storage/schema/delta/16/events_order_index.sql @@ -0,0 +1,4 @@ +CREATE INDEX events_order ON events (topological_ordering, stream_ordering); +CREATE INDEX events_order_room ON events ( + room_id, topological_ordering, stream_ordering +); diff --git a/synapse/storage/schema/delta/16/remote_media_cache_index.sql b/synapse/storage/schema/delta/16/remote_media_cache_index.sql new file mode 100644 index 0000000000..7a15265cb1 --- /dev/null +++ b/synapse/storage/schema/delta/16/remote_media_cache_index.sql @@ -0,0 +1,2 @@ +CREATE INDEX IF NOT EXISTS remote_media_cache_thumbnails_media_id + ON remote_media_cache_thumbnails (media_id); \ No newline at end of file diff --git a/synapse/storage/schema/delta/16/remove_duplicates.sql b/synapse/storage/schema/delta/16/remove_duplicates.sql new file mode 100644 index 0000000000..65c97b5e2f --- /dev/null +++ b/synapse/storage/schema/delta/16/remove_duplicates.sql @@ -0,0 +1,9 @@ + + +DELETE FROM event_to_state_groups WHERE state_group not in ( + SELECT MAX(state_group) FROM event_to_state_groups GROUP BY event_id +); + +DELETE FROM event_to_state_groups WHERE rowid not in ( + SELECT MIN(rowid) FROM event_to_state_groups GROUP BY event_id +); diff --git a/synapse/storage/schema/delta/16/room_alias_index.sql b/synapse/storage/schema/delta/16/room_alias_index.sql new file mode 100644 index 0000000000..f82486132b --- /dev/null +++ b/synapse/storage/schema/delta/16/room_alias_index.sql @@ -0,0 +1,3 @@ + +CREATE INDEX IF NOT EXISTS room_aliases_id ON room_aliases(room_id); +CREATE INDEX IF NOT EXISTS room_alias_servers_alias ON room_alias_servers(room_alias); diff --git a/synapse/storage/schema/delta/16/unique_constraints.sql b/synapse/storage/schema/delta/16/unique_constraints.sql new file mode 100644 index 0000000000..fecf11118c --- /dev/null +++ b/synapse/storage/schema/delta/16/unique_constraints.sql @@ -0,0 +1,80 @@ + +-- We can use SQLite features here, since other db support was only added in v16 + +-- +DELETE FROM current_state_events WHERE rowid not in ( + SELECT MIN(rowid) FROM current_state_events GROUP BY event_id +); + +DROP INDEX IF EXISTS current_state_events_event_id; +CREATE UNIQUE INDEX current_state_events_event_id ON current_state_events(event_id); + +-- +DELETE FROM room_memberships WHERE rowid not in ( + SELECT MIN(rowid) FROM room_memberships GROUP BY event_id +); + +DROP INDEX IF EXISTS room_memberships_event_id; +CREATE UNIQUE INDEX room_memberships_event_id ON room_memberships(event_id); + +-- +DELETE FROM feedback WHERE rowid not in ( + SELECT MIN(rowid) FROM feedback GROUP BY event_id +); + +DROP INDEX IF EXISTS feedback_event_id; +CREATE UNIQUE INDEX feedback_event_id ON feedback(event_id); + +-- +DELETE FROM topics WHERE rowid not in ( + SELECT MIN(rowid) FROM topics GROUP BY event_id +); + +DROP INDEX IF EXISTS topics_event_id; +CREATE UNIQUE INDEX topics_event_id ON topics(event_id); + +-- +DELETE FROM room_names WHERE rowid not in ( + SELECT MIN(rowid) FROM room_names GROUP BY event_id +); + +DROP INDEX IF EXISTS room_names_id; +CREATE UNIQUE INDEX room_names_id ON room_names(event_id); + +-- +DELETE FROM presence WHERE rowid not in ( + SELECT MIN(rowid) FROM presence GROUP BY user_id +); + +DROP INDEX IF EXISTS presence_id; +CREATE UNIQUE INDEX presence_id ON presence(user_id); + +-- +DELETE FROM presence_allow_inbound WHERE rowid not in ( + SELECT MIN(rowid) FROM presence_allow_inbound + GROUP BY observed_user_id, observer_user_id +); + +DROP INDEX IF EXISTS presence_allow_inbound_observers; +CREATE UNIQUE INDEX presence_allow_inbound_observers ON presence_allow_inbound( + observed_user_id, observer_user_id +); + +-- +DELETE FROM presence_list WHERE rowid not in ( + SELECT MIN(rowid) FROM presence_list + GROUP BY user_id, observed_user_id +); + +DROP INDEX IF EXISTS presence_list_observers; +CREATE UNIQUE INDEX presence_list_observers ON presence_list( + user_id, observed_user_id +); + +-- +DELETE FROM room_aliases WHERE rowid not in ( + SELECT MIN(rowid) FROM room_aliases GROUP BY room_alias +); + +DROP INDEX IF EXISTS room_aliases_id; +CREATE INDEX room_aliases_id ON room_aliases(room_id); diff --git a/synapse/storage/schema/delta/16/users.sql b/synapse/storage/schema/delta/16/users.sql new file mode 100644 index 0000000000..cd0709250d --- /dev/null +++ b/synapse/storage/schema/delta/16/users.sql @@ -0,0 +1,56 @@ +-- Convert `access_tokens`.user from rowids to user strings. +-- MUST BE DONE BEFORE REMOVING ID COLUMN FROM USERS TABLE BELOW +CREATE TABLE IF NOT EXISTS new_access_tokens( + id BIGINT UNSIGNED PRIMARY KEY, + user_id TEXT NOT NULL, + device_id TEXT, + token TEXT NOT NULL, + last_used BIGINT UNSIGNED, + UNIQUE(token) +); + +INSERT INTO new_access_tokens + SELECT a.id, u.name, a.device_id, a.token, a.last_used + FROM access_tokens as a + INNER JOIN users as u ON u.id = a.user_id; + +DROP TABLE access_tokens; + +ALTER TABLE new_access_tokens RENAME TO access_tokens; + +-- Remove ID column from `users` table +CREATE TABLE IF NOT EXISTS new_users( + name TEXT, + password_hash TEXT, + creation_ts BIGINT UNSIGNED, + admin BOOL DEFAULT 0 NOT NULL, + UNIQUE(name) +); + +INSERT INTO new_users SELECT name, password_hash, creation_ts, admin FROM users; + +DROP TABLE users; + +ALTER TABLE new_users RENAME TO users; + + +-- Remove UNIQUE constraint from `user_ips` table +CREATE TABLE IF NOT EXISTS new_user_ips ( + user_id TEXT NOT NULL, + access_token TEXT NOT NULL, + device_id TEXT, + ip TEXT NOT NULL, + user_agent TEXT NOT NULL, + last_seen BIGINT UNSIGNED NOT NULL +); + +INSERT INTO new_user_ips + SELECT user, access_token, device_id, ip, user_agent, last_seen FROM user_ips; + +DROP TABLE user_ips; + +ALTER TABLE new_user_ips RENAME TO user_ips; + +CREATE INDEX IF NOT EXISTS user_ips_user ON user_ips(user_id); +CREATE INDEX IF NOT EXISTS user_ips_user_ip ON user_ips(user_id, access_token, ip); + diff --git a/synapse/storage/schema/full_schemas/11/event_edges.sql b/synapse/storage/schema/full_schemas/11/event_edges.sql index 1e766d6db2..f7020f7793 100644 --- a/synapse/storage/schema/full_schemas/11/event_edges.sql +++ b/synapse/storage/schema/full_schemas/11/event_edges.sql @@ -16,52 +16,52 @@ CREATE TABLE IF NOT EXISTS event_forward_extremities( event_id TEXT NOT NULL, room_id TEXT NOT NULL, - CONSTRAINT uniqueness UNIQUE (event_id, room_id) ON CONFLICT REPLACE + UNIQUE (event_id, room_id) ); -CREATE INDEX IF NOT EXISTS ev_extrem_room ON event_forward_extremities(room_id); -CREATE INDEX IF NOT EXISTS ev_extrem_id ON event_forward_extremities(event_id); +CREATE INDEX ev_extrem_room ON event_forward_extremities(room_id); +CREATE INDEX ev_extrem_id ON event_forward_extremities(event_id); CREATE TABLE IF NOT EXISTS event_backward_extremities( event_id TEXT NOT NULL, room_id TEXT NOT NULL, - CONSTRAINT uniqueness UNIQUE (event_id, room_id) ON CONFLICT REPLACE + UNIQUE (event_id, room_id) ); -CREATE INDEX IF NOT EXISTS ev_b_extrem_room ON event_backward_extremities(room_id); -CREATE INDEX IF NOT EXISTS ev_b_extrem_id ON event_backward_extremities(event_id); +CREATE INDEX ev_b_extrem_room ON event_backward_extremities(room_id); +CREATE INDEX ev_b_extrem_id ON event_backward_extremities(event_id); CREATE TABLE IF NOT EXISTS event_edges( event_id TEXT NOT NULL, prev_event_id TEXT NOT NULL, room_id TEXT NOT NULL, - is_state INTEGER NOT NULL, - CONSTRAINT uniqueness UNIQUE (event_id, prev_event_id, room_id, is_state) + is_state BOOL NOT NULL, + UNIQUE (event_id, prev_event_id, room_id, is_state) ); -CREATE INDEX IF NOT EXISTS ev_edges_id ON event_edges(event_id); -CREATE INDEX IF NOT EXISTS ev_edges_prev_id ON event_edges(prev_event_id); +CREATE INDEX ev_edges_id ON event_edges(event_id); +CREATE INDEX ev_edges_prev_id ON event_edges(prev_event_id); CREATE TABLE IF NOT EXISTS room_depth( room_id TEXT NOT NULL, min_depth INTEGER NOT NULL, - CONSTRAINT uniqueness UNIQUE (room_id) + UNIQUE (room_id) ); -CREATE INDEX IF NOT EXISTS room_depth_room ON room_depth(room_id); +CREATE INDEX room_depth_room ON room_depth(room_id); create TABLE IF NOT EXISTS event_destinations( event_id TEXT NOT NULL, destination TEXT NOT NULL, - delivered_ts INTEGER DEFAULT 0, -- or 0 if not delivered - CONSTRAINT uniqueness UNIQUE (event_id, destination) ON CONFLICT REPLACE + delivered_ts BIGINT DEFAULT 0, -- or 0 if not delivered + UNIQUE (event_id, destination) ); -CREATE INDEX IF NOT EXISTS event_destinations_id ON event_destinations(event_id); +CREATE INDEX event_destinations_id ON event_destinations(event_id); CREATE TABLE IF NOT EXISTS state_forward_extremities( @@ -69,21 +69,21 @@ CREATE TABLE IF NOT EXISTS state_forward_extremities( room_id TEXT NOT NULL, type TEXT NOT NULL, state_key TEXT NOT NULL, - CONSTRAINT uniqueness UNIQUE (event_id, room_id) ON CONFLICT REPLACE + UNIQUE (event_id, room_id) ); -CREATE INDEX IF NOT EXISTS st_extrem_keys ON state_forward_extremities( +CREATE INDEX st_extrem_keys ON state_forward_extremities( room_id, type, state_key ); -CREATE INDEX IF NOT EXISTS st_extrem_id ON state_forward_extremities(event_id); +CREATE INDEX st_extrem_id ON state_forward_extremities(event_id); CREATE TABLE IF NOT EXISTS event_auth( event_id TEXT NOT NULL, auth_id TEXT NOT NULL, room_id TEXT NOT NULL, - CONSTRAINT uniqueness UNIQUE (event_id, auth_id, room_id) + UNIQUE (event_id, auth_id, room_id) ); -CREATE INDEX IF NOT EXISTS evauth_edges_id ON event_auth(event_id); -CREATE INDEX IF NOT EXISTS evauth_edges_auth_id ON event_auth(auth_id); \ No newline at end of file +CREATE INDEX evauth_edges_id ON event_auth(event_id); +CREATE INDEX evauth_edges_auth_id ON event_auth(auth_id); diff --git a/synapse/storage/schema/full_schemas/11/event_signatures.sql b/synapse/storage/schema/full_schemas/11/event_signatures.sql index c28c39c48a..636b2d3353 100644 --- a/synapse/storage/schema/full_schemas/11/event_signatures.sql +++ b/synapse/storage/schema/full_schemas/11/event_signatures.sql @@ -16,50 +16,40 @@ CREATE TABLE IF NOT EXISTS event_content_hashes ( event_id TEXT, algorithm TEXT, - hash BLOB, - CONSTRAINT uniqueness UNIQUE (event_id, algorithm) + hash bytea, + UNIQUE (event_id, algorithm) ); -CREATE INDEX IF NOT EXISTS event_content_hashes_id ON event_content_hashes( - event_id -); +CREATE INDEX event_content_hashes_id ON event_content_hashes(event_id); CREATE TABLE IF NOT EXISTS event_reference_hashes ( event_id TEXT, algorithm TEXT, - hash BLOB, - CONSTRAINT uniqueness UNIQUE (event_id, algorithm) + hash bytea, + UNIQUE (event_id, algorithm) ); -CREATE INDEX IF NOT EXISTS event_reference_hashes_id ON event_reference_hashes ( - event_id -); +CREATE INDEX event_reference_hashes_id ON event_reference_hashes(event_id); CREATE TABLE IF NOT EXISTS event_signatures ( event_id TEXT, signature_name TEXT, key_id TEXT, - signature BLOB, - CONSTRAINT uniqueness UNIQUE (event_id, signature_name, key_id) + signature bytea, + UNIQUE (event_id, signature_name, key_id) ); -CREATE INDEX IF NOT EXISTS event_signatures_id ON event_signatures ( - event_id -); +CREATE INDEX event_signatures_id ON event_signatures(event_id); CREATE TABLE IF NOT EXISTS event_edge_hashes( event_id TEXT, prev_event_id TEXT, algorithm TEXT, - hash BLOB, - CONSTRAINT uniqueness UNIQUE ( - event_id, prev_event_id, algorithm - ) + hash bytea, + UNIQUE (event_id, prev_event_id, algorithm) ); -CREATE INDEX IF NOT EXISTS event_edge_hashes_id ON event_edge_hashes( - event_id -); +CREATE INDEX event_edge_hashes_id ON event_edge_hashes(event_id); diff --git a/synapse/storage/schema/full_schemas/11/im.sql b/synapse/storage/schema/full_schemas/11/im.sql index dd00c1cd2f..1901654ac2 100644 --- a/synapse/storage/schema/full_schemas/11/im.sql +++ b/synapse/storage/schema/full_schemas/11/im.sql @@ -15,7 +15,7 @@ CREATE TABLE IF NOT EXISTS events( stream_ordering INTEGER PRIMARY KEY AUTOINCREMENT, - topological_ordering INTEGER NOT NULL, + topological_ordering BIGINT NOT NULL, event_id TEXT NOT NULL, type TEXT NOT NULL, room_id TEXT NOT NULL, @@ -23,26 +23,24 @@ CREATE TABLE IF NOT EXISTS events( unrecognized_keys TEXT, processed BOOL NOT NULL, outlier BOOL NOT NULL, - depth INTEGER DEFAULT 0 NOT NULL, - CONSTRAINT ev_uniq UNIQUE (event_id) + depth BIGINT DEFAULT 0 NOT NULL, + UNIQUE (event_id) ); -CREATE INDEX IF NOT EXISTS events_event_id ON events (event_id); -CREATE INDEX IF NOT EXISTS events_stream_ordering ON events (stream_ordering); -CREATE INDEX IF NOT EXISTS events_topological_ordering ON events (topological_ordering); -CREATE INDEX IF NOT EXISTS events_room_id ON events (room_id); +CREATE INDEX events_stream_ordering ON events (stream_ordering); +CREATE INDEX events_topological_ordering ON events (topological_ordering); +CREATE INDEX events_room_id ON events (room_id); CREATE TABLE IF NOT EXISTS event_json( event_id TEXT NOT NULL, room_id TEXT NOT NULL, - internal_metadata NOT NULL, - json BLOB NOT NULL, - CONSTRAINT ev_j_uniq UNIQUE (event_id) + internal_metadata TEXT NOT NULL, + json TEXT NOT NULL, + UNIQUE (event_id) ); -CREATE INDEX IF NOT EXISTS event_json_id ON event_json(event_id); -CREATE INDEX IF NOT EXISTS event_json_room_id ON event_json(room_id); +CREATE INDEX event_json_room_id ON event_json(room_id); CREATE TABLE IF NOT EXISTS state_events( @@ -50,13 +48,13 @@ CREATE TABLE IF NOT EXISTS state_events( room_id TEXT NOT NULL, type TEXT NOT NULL, state_key TEXT NOT NULL, - prev_state TEXT + prev_state TEXT, + UNIQUE (event_id) ); -CREATE UNIQUE INDEX IF NOT EXISTS state_events_event_id ON state_events (event_id); -CREATE INDEX IF NOT EXISTS state_events_room_id ON state_events (room_id); -CREATE INDEX IF NOT EXISTS state_events_type ON state_events (type); -CREATE INDEX IF NOT EXISTS state_events_state_key ON state_events (state_key); +CREATE INDEX state_events_room_id ON state_events (room_id); +CREATE INDEX state_events_type ON state_events (type); +CREATE INDEX state_events_state_key ON state_events (state_key); CREATE TABLE IF NOT EXISTS current_state_events( @@ -64,13 +62,13 @@ CREATE TABLE IF NOT EXISTS current_state_events( room_id TEXT NOT NULL, type TEXT NOT NULL, state_key TEXT NOT NULL, - CONSTRAINT curr_uniq UNIQUE (room_id, type, state_key) ON CONFLICT REPLACE + UNIQUE (room_id, type, state_key) ); -CREATE INDEX IF NOT EXISTS curr_events_event_id ON current_state_events (event_id); -CREATE INDEX IF NOT EXISTS current_state_events_room_id ON current_state_events (room_id); -CREATE INDEX IF NOT EXISTS current_state_events_type ON current_state_events (type); -CREATE INDEX IF NOT EXISTS current_state_events_state_key ON current_state_events (state_key); +CREATE INDEX curr_events_event_id ON current_state_events (event_id); +CREATE INDEX current_state_events_room_id ON current_state_events (room_id); +CREATE INDEX current_state_events_type ON current_state_events (type); +CREATE INDEX current_state_events_state_key ON current_state_events (state_key); CREATE TABLE IF NOT EXISTS room_memberships( event_id TEXT NOT NULL, @@ -80,9 +78,9 @@ CREATE TABLE IF NOT EXISTS room_memberships( membership TEXT NOT NULL ); -CREATE INDEX IF NOT EXISTS room_memberships_event_id ON room_memberships (event_id); -CREATE INDEX IF NOT EXISTS room_memberships_room_id ON room_memberships (room_id); -CREATE INDEX IF NOT EXISTS room_memberships_user_id ON room_memberships (user_id); +CREATE INDEX room_memberships_event_id ON room_memberships (event_id); +CREATE INDEX room_memberships_room_id ON room_memberships (room_id); +CREATE INDEX room_memberships_user_id ON room_memberships (user_id); CREATE TABLE IF NOT EXISTS feedback( event_id TEXT NOT NULL, @@ -98,8 +96,8 @@ CREATE TABLE IF NOT EXISTS topics( topic TEXT NOT NULL ); -CREATE INDEX IF NOT EXISTS topics_event_id ON topics(event_id); -CREATE INDEX IF NOT EXISTS topics_room_id ON topics(room_id); +CREATE INDEX topics_event_id ON topics(event_id); +CREATE INDEX topics_room_id ON topics(room_id); CREATE TABLE IF NOT EXISTS room_names( event_id TEXT NOT NULL, @@ -107,19 +105,19 @@ CREATE TABLE IF NOT EXISTS room_names( name TEXT NOT NULL ); -CREATE INDEX IF NOT EXISTS room_names_event_id ON room_names(event_id); -CREATE INDEX IF NOT EXISTS room_names_room_id ON room_names(room_id); +CREATE INDEX room_names_event_id ON room_names(event_id); +CREATE INDEX room_names_room_id ON room_names(room_id); CREATE TABLE IF NOT EXISTS rooms( room_id TEXT PRIMARY KEY NOT NULL, - is_public INTEGER, + is_public BOOL, creator TEXT ); CREATE TABLE IF NOT EXISTS room_hosts( room_id TEXT NOT NULL, host TEXT NOT NULL, - CONSTRAINT room_hosts_uniq UNIQUE (room_id, host) ON CONFLICT IGNORE + UNIQUE (room_id, host) ); -CREATE INDEX IF NOT EXISTS room_hosts_room_id ON room_hosts (room_id); +CREATE INDEX room_hosts_room_id ON room_hosts (room_id); diff --git a/synapse/storage/schema/full_schemas/11/keys.sql b/synapse/storage/schema/full_schemas/11/keys.sql index a9e0a4fe0d..afc142045e 100644 --- a/synapse/storage/schema/full_schemas/11/keys.sql +++ b/synapse/storage/schema/full_schemas/11/keys.sql @@ -16,16 +16,16 @@ CREATE TABLE IF NOT EXISTS server_tls_certificates( server_name TEXT, -- Server name. fingerprint TEXT, -- Certificate fingerprint. from_server TEXT, -- Which key server the certificate was fetched from. - ts_added_ms INTEGER, -- When the certifcate was added. - tls_certificate BLOB, -- DER encoded x509 certificate. - CONSTRAINT uniqueness UNIQUE (server_name, fingerprint) + ts_added_ms BIGINT, -- When the certifcate was added. + tls_certificate bytea, -- DER encoded x509 certificate. + UNIQUE (server_name, fingerprint) ); CREATE TABLE IF NOT EXISTS server_signature_keys( server_name TEXT, -- Server name. key_id TEXT, -- Key version. from_server TEXT, -- Which key server the key was fetched form. - ts_added_ms INTEGER, -- When the key was added. - verify_key BLOB, -- NACL verification key. - CONSTRAINT uniqueness UNIQUE (server_name, key_id) + ts_added_ms BIGINT, -- When the key was added. + verify_key bytea, -- NACL verification key. + UNIQUE (server_name, key_id) ); diff --git a/synapse/storage/schema/full_schemas/11/media_repository.sql b/synapse/storage/schema/full_schemas/11/media_repository.sql index afdf48cbfb..e927e581d1 100644 --- a/synapse/storage/schema/full_schemas/11/media_repository.sql +++ b/synapse/storage/schema/full_schemas/11/media_repository.sql @@ -17,10 +17,10 @@ CREATE TABLE IF NOT EXISTS local_media_repository ( media_id TEXT, -- The id used to refer to the media. media_type TEXT, -- The MIME-type of the media. media_length INTEGER, -- Length of the media in bytes. - created_ts INTEGER, -- When the content was uploaded in ms. + created_ts BIGINT, -- When the content was uploaded in ms. upload_name TEXT, -- The name the media was uploaded with. user_id TEXT, -- The user who uploaded the file. - CONSTRAINT uniqueness UNIQUE (media_id) + UNIQUE (media_id) ); CREATE TABLE IF NOT EXISTS local_media_repository_thumbnails ( @@ -30,23 +30,23 @@ CREATE TABLE IF NOT EXISTS local_media_repository_thumbnails ( thumbnail_type TEXT, -- The MIME-type of the thumbnail. thumbnail_method TEXT, -- The method used to make the thumbnail. thumbnail_length INTEGER, -- The length of the thumbnail in bytes. - CONSTRAINT uniqueness UNIQUE ( + UNIQUE ( media_id, thumbnail_width, thumbnail_height, thumbnail_type ) ); -CREATE INDEX IF NOT EXISTS local_media_repository_thumbnails_media_id +CREATE INDEX local_media_repository_thumbnails_media_id ON local_media_repository_thumbnails (media_id); CREATE TABLE IF NOT EXISTS remote_media_cache ( media_origin TEXT, -- The remote HS the media came from. media_id TEXT, -- The id used to refer to the media on that server. media_type TEXT, -- The MIME-type of the media. - created_ts INTEGER, -- When the content was uploaded in ms. + created_ts BIGINT, -- When the content was uploaded in ms. upload_name TEXT, -- The name the media was uploaded with. media_length INTEGER, -- Length of the media in bytes. filesystem_id TEXT, -- The name used to store the media on disk. - CONSTRAINT uniqueness UNIQUE (media_origin, media_id) + UNIQUE (media_origin, media_id) ); CREATE TABLE IF NOT EXISTS remote_media_cache_thumbnails ( @@ -58,11 +58,8 @@ CREATE TABLE IF NOT EXISTS remote_media_cache_thumbnails ( thumbnail_type TEXT, -- The MIME-type of the thumbnail. thumbnail_length INTEGER, -- The length of the thumbnail in bytes. filesystem_id TEXT, -- The name used to store the media on disk. - CONSTRAINT uniqueness UNIQUE ( + UNIQUE ( media_origin, media_id, thumbnail_width, thumbnail_height, - thumbnail_type, thumbnail_type - ) + thumbnail_type + ) ); - -CREATE INDEX IF NOT EXISTS remote_media_cache_thumbnails_media_id - ON local_media_repository_thumbnails (media_id); diff --git a/synapse/storage/schema/full_schemas/11/presence.sql b/synapse/storage/schema/full_schemas/11/presence.sql index f9f8db9697..d8d82e9fe3 100644 --- a/synapse/storage/schema/full_schemas/11/presence.sql +++ b/synapse/storage/schema/full_schemas/11/presence.sql @@ -13,26 +13,23 @@ * limitations under the License. */ CREATE TABLE IF NOT EXISTS presence( - user_id INTEGER NOT NULL, - state INTEGER, + user_id TEXT NOT NULL, + state VARCHAR(20), status_msg TEXT, - mtime INTEGER, -- miliseconds since last state change - FOREIGN KEY(user_id) REFERENCES users(id) + mtime BIGINT -- miliseconds since last state change ); -- For each of /my/ users which possibly-remote users are allowed to see their -- presence state CREATE TABLE IF NOT EXISTS presence_allow_inbound( - observed_user_id INTEGER NOT NULL, - observer_user_id TEXT, -- a UserID, - FOREIGN KEY(observed_user_id) REFERENCES users(id) + observed_user_id TEXT NOT NULL, + observer_user_id TEXT NOT NULL -- a UserID, ); -- For each of /my/ users (watcher), which possibly-remote users are they -- watching? CREATE TABLE IF NOT EXISTS presence_list( - user_id INTEGER NOT NULL, - observed_user_id TEXT, -- a UserID, - accepted BOOLEAN, - FOREIGN KEY(user_id) REFERENCES users(id) + user_id TEXT NOT NULL, + observed_user_id TEXT NOT NULL, -- a UserID, + accepted BOOLEAN NOT NULL ); diff --git a/synapse/storage/schema/full_schemas/11/profiles.sql b/synapse/storage/schema/full_schemas/11/profiles.sql index f06a528b4d..26e4204437 100644 --- a/synapse/storage/schema/full_schemas/11/profiles.sql +++ b/synapse/storage/schema/full_schemas/11/profiles.sql @@ -13,8 +13,7 @@ * limitations under the License. */ CREATE TABLE IF NOT EXISTS profiles( - user_id INTEGER NOT NULL, + user_id TEXT NOT NULL, displayname TEXT, - avatar_url TEXT, - FOREIGN KEY(user_id) REFERENCES users(id) + avatar_url TEXT ); diff --git a/synapse/storage/schema/full_schemas/11/redactions.sql b/synapse/storage/schema/full_schemas/11/redactions.sql index 5011d95db8..69621955d4 100644 --- a/synapse/storage/schema/full_schemas/11/redactions.sql +++ b/synapse/storage/schema/full_schemas/11/redactions.sql @@ -15,8 +15,8 @@ CREATE TABLE IF NOT EXISTS redactions ( event_id TEXT NOT NULL, redacts TEXT NOT NULL, - CONSTRAINT ev_uniq UNIQUE (event_id) + UNIQUE (event_id) ); -CREATE INDEX IF NOT EXISTS redactions_event_id ON redactions (event_id); -CREATE INDEX IF NOT EXISTS redactions_redacts ON redactions (redacts); +CREATE INDEX redactions_event_id ON redactions (event_id); +CREATE INDEX redactions_redacts ON redactions (redacts); diff --git a/synapse/storage/schema/full_schemas/11/room_aliases.sql b/synapse/storage/schema/full_schemas/11/room_aliases.sql index 0d2df01603..5027b1e3f6 100644 --- a/synapse/storage/schema/full_schemas/11/room_aliases.sql +++ b/synapse/storage/schema/full_schemas/11/room_aliases.sql @@ -22,6 +22,3 @@ CREATE TABLE IF NOT EXISTS room_alias_servers( room_alias TEXT NOT NULL, server TEXT NOT NULL ); - - - diff --git a/synapse/storage/schema/full_schemas/11/state.sql b/synapse/storage/schema/full_schemas/11/state.sql index 1fe8f1e430..ffd164ab71 100644 --- a/synapse/storage/schema/full_schemas/11/state.sql +++ b/synapse/storage/schema/full_schemas/11/state.sql @@ -30,18 +30,11 @@ CREATE TABLE IF NOT EXISTS state_groups_state( CREATE TABLE IF NOT EXISTS event_to_state_groups( event_id TEXT NOT NULL, state_group INTEGER NOT NULL, - CONSTRAINT event_to_state_groups_uniq UNIQUE (event_id) + UNIQUE (event_id) ); -CREATE INDEX IF NOT EXISTS state_groups_id ON state_groups(id); +CREATE INDEX state_groups_id ON state_groups(id); -CREATE INDEX IF NOT EXISTS state_groups_state_id ON state_groups_state( - state_group -); -CREATE INDEX IF NOT EXISTS state_groups_state_tuple ON state_groups_state( - room_id, type, state_key -); - -CREATE INDEX IF NOT EXISTS event_to_state_groups_id ON event_to_state_groups( - event_id -); \ No newline at end of file +CREATE INDEX state_groups_state_id ON state_groups_state(state_group); +CREATE INDEX state_groups_state_tuple ON state_groups_state(room_id, type, state_key); +CREATE INDEX event_to_state_groups_id ON event_to_state_groups(event_id); diff --git a/synapse/storage/schema/full_schemas/11/transactions.sql b/synapse/storage/schema/full_schemas/11/transactions.sql index 2d30f99b06..cc5b54f5aa 100644 --- a/synapse/storage/schema/full_schemas/11/transactions.sql +++ b/synapse/storage/schema/full_schemas/11/transactions.sql @@ -14,17 +14,16 @@ */ -- Stores what transaction ids we have received and what our response was CREATE TABLE IF NOT EXISTS received_transactions( - transaction_id TEXT, - origin TEXT, - ts INTEGER, + transaction_id TEXT, + origin TEXT, + ts BIGINT, response_code INTEGER, - response_json TEXT, - has_been_referenced BOOL default 0, -- Whether thishas been referenced by a prev_tx - CONSTRAINT uniquesss UNIQUE (transaction_id, origin) ON CONFLICT REPLACE + response_json bytea, + has_been_referenced SMALLINT DEFAULT 0, -- Whether thishas been referenced by a prev_tx + UNIQUE (transaction_id, origin) ); -CREATE UNIQUE INDEX IF NOT EXISTS transactions_txid ON received_transactions(transaction_id, origin); -CREATE INDEX IF NOT EXISTS transactions_have_ref ON received_transactions(origin, has_been_referenced);-- WHERE has_been_referenced = 0; +CREATE INDEX transactions_have_ref ON received_transactions(origin, has_been_referenced);-- WHERE has_been_referenced = 0; -- Stores what transactions we've sent, what their response was (if we got one) and whether we have @@ -35,17 +34,14 @@ CREATE TABLE IF NOT EXISTS sent_transactions( destination TEXT, response_code INTEGER DEFAULT 0, response_json TEXT, - ts INTEGER + ts BIGINT ); -CREATE INDEX IF NOT EXISTS sent_transaction_dest ON sent_transactions(destination); -CREATE INDEX IF NOT EXISTS sent_transaction_dest_referenced ON sent_transactions( - destination -); -CREATE INDEX IF NOT EXISTS sent_transaction_txn_id ON sent_transactions(transaction_id); +CREATE INDEX sent_transaction_dest ON sent_transactions(destination); +CREATE INDEX sent_transaction_txn_id ON sent_transactions(transaction_id); -- So that we can do an efficient look up of all transactions that have yet to be successfully -- sent. -CREATE INDEX IF NOT EXISTS sent_transaction_sent ON sent_transactions(response_code); +CREATE INDEX sent_transaction_sent ON sent_transactions(response_code); -- For sent transactions only. @@ -56,13 +52,12 @@ CREATE TABLE IF NOT EXISTS transaction_id_to_pdu( pdu_origin TEXT ); -CREATE INDEX IF NOT EXISTS transaction_id_to_pdu_tx ON transaction_id_to_pdu(transaction_id, destination); -CREATE INDEX IF NOT EXISTS transaction_id_to_pdu_dest ON transaction_id_to_pdu(destination); -CREATE INDEX IF NOT EXISTS transaction_id_to_pdu_index ON transaction_id_to_pdu(transaction_id, destination); +CREATE INDEX transaction_id_to_pdu_tx ON transaction_id_to_pdu(transaction_id, destination); +CREATE INDEX transaction_id_to_pdu_dest ON transaction_id_to_pdu(destination); -- To track destination health CREATE TABLE IF NOT EXISTS destinations( destination TEXT PRIMARY KEY, - retry_last_ts INTEGER, + retry_last_ts BIGINT, retry_interval INTEGER ); diff --git a/synapse/storage/schema/full_schemas/11/users.sql b/synapse/storage/schema/full_schemas/11/users.sql index 08ccfdac0a..eec3da3c35 100644 --- a/synapse/storage/schema/full_schemas/11/users.sql +++ b/synapse/storage/schema/full_schemas/11/users.sql @@ -16,19 +16,18 @@ CREATE TABLE IF NOT EXISTS users( id INTEGER PRIMARY KEY AUTOINCREMENT, name TEXT, password_hash TEXT, - creation_ts INTEGER, - admin BOOL DEFAULT 0 NOT NULL, - UNIQUE(name) ON CONFLICT ROLLBACK + creation_ts BIGINT, + admin SMALLINT DEFAULT 0 NOT NULL, + UNIQUE(name) ); CREATE TABLE IF NOT EXISTS access_tokens( id INTEGER PRIMARY KEY AUTOINCREMENT, - user_id INTEGER NOT NULL, + user_id TEXT NOT NULL, device_id TEXT, token TEXT NOT NULL, - last_used INTEGER, - FOREIGN KEY(user_id) REFERENCES users(id), - UNIQUE(token) ON CONFLICT ROLLBACK + last_used BIGINT, + UNIQUE(token) ); CREATE TABLE IF NOT EXISTS user_ips ( @@ -37,9 +36,8 @@ CREATE TABLE IF NOT EXISTS user_ips ( device_id TEXT, ip TEXT NOT NULL, user_agent TEXT NOT NULL, - last_seen INTEGER NOT NULL, - CONSTRAINT user_ip UNIQUE (user, access_token, ip, user_agent) ON CONFLICT REPLACE + last_seen BIGINT NOT NULL, + UNIQUE (user, access_token, ip, user_agent) ); -CREATE INDEX IF NOT EXISTS user_ips_user ON user_ips(user); - +CREATE INDEX user_ips_user ON user_ips(user); diff --git a/synapse/storage/schema/full_schemas/16/application_services.sql b/synapse/storage/schema/full_schemas/16/application_services.sql new file mode 100644 index 0000000000..d382d63fbd --- /dev/null +++ b/synapse/storage/schema/full_schemas/16/application_services.sql @@ -0,0 +1,48 @@ +/* Copyright 2015 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE IF NOT EXISTS application_services( + id BIGINT PRIMARY KEY, + url TEXT, + token TEXT, + hs_token TEXT, + sender TEXT, + UNIQUE(token) +); + +CREATE TABLE IF NOT EXISTS application_services_regex( + id BIGINT PRIMARY KEY, + as_id BIGINT NOT NULL, + namespace INTEGER, /* enum[room_id|room_alias|user_id] */ + regex TEXT, + FOREIGN KEY(as_id) REFERENCES application_services(id) +); + +CREATE TABLE IF NOT EXISTS application_services_state( + as_id TEXT PRIMARY KEY, + state VARCHAR(5), + last_txn INTEGER +); + +CREATE TABLE IF NOT EXISTS application_services_txns( + as_id TEXT NOT NULL, + txn_id INTEGER NOT NULL, + event_ids TEXT NOT NULL, + UNIQUE(as_id, txn_id) +); + +CREATE INDEX application_services_txns_id ON application_services_txns ( + as_id +); diff --git a/synapse/storage/schema/full_schemas/16/event_edges.sql b/synapse/storage/schema/full_schemas/16/event_edges.sql new file mode 100644 index 0000000000..f7020f7793 --- /dev/null +++ b/synapse/storage/schema/full_schemas/16/event_edges.sql @@ -0,0 +1,89 @@ +/* Copyright 2014, 2015 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE IF NOT EXISTS event_forward_extremities( + event_id TEXT NOT NULL, + room_id TEXT NOT NULL, + UNIQUE (event_id, room_id) +); + +CREATE INDEX ev_extrem_room ON event_forward_extremities(room_id); +CREATE INDEX ev_extrem_id ON event_forward_extremities(event_id); + + +CREATE TABLE IF NOT EXISTS event_backward_extremities( + event_id TEXT NOT NULL, + room_id TEXT NOT NULL, + UNIQUE (event_id, room_id) +); + +CREATE INDEX ev_b_extrem_room ON event_backward_extremities(room_id); +CREATE INDEX ev_b_extrem_id ON event_backward_extremities(event_id); + + +CREATE TABLE IF NOT EXISTS event_edges( + event_id TEXT NOT NULL, + prev_event_id TEXT NOT NULL, + room_id TEXT NOT NULL, + is_state BOOL NOT NULL, + UNIQUE (event_id, prev_event_id, room_id, is_state) +); + +CREATE INDEX ev_edges_id ON event_edges(event_id); +CREATE INDEX ev_edges_prev_id ON event_edges(prev_event_id); + + +CREATE TABLE IF NOT EXISTS room_depth( + room_id TEXT NOT NULL, + min_depth INTEGER NOT NULL, + UNIQUE (room_id) +); + +CREATE INDEX room_depth_room ON room_depth(room_id); + + +create TABLE IF NOT EXISTS event_destinations( + event_id TEXT NOT NULL, + destination TEXT NOT NULL, + delivered_ts BIGINT DEFAULT 0, -- or 0 if not delivered + UNIQUE (event_id, destination) +); + +CREATE INDEX event_destinations_id ON event_destinations(event_id); + + +CREATE TABLE IF NOT EXISTS state_forward_extremities( + event_id TEXT NOT NULL, + room_id TEXT NOT NULL, + type TEXT NOT NULL, + state_key TEXT NOT NULL, + UNIQUE (event_id, room_id) +); + +CREATE INDEX st_extrem_keys ON state_forward_extremities( + room_id, type, state_key +); +CREATE INDEX st_extrem_id ON state_forward_extremities(event_id); + + +CREATE TABLE IF NOT EXISTS event_auth( + event_id TEXT NOT NULL, + auth_id TEXT NOT NULL, + room_id TEXT NOT NULL, + UNIQUE (event_id, auth_id, room_id) +); + +CREATE INDEX evauth_edges_id ON event_auth(event_id); +CREATE INDEX evauth_edges_auth_id ON event_auth(auth_id); diff --git a/synapse/storage/schema/full_schemas/16/event_signatures.sql b/synapse/storage/schema/full_schemas/16/event_signatures.sql new file mode 100644 index 0000000000..636b2d3353 --- /dev/null +++ b/synapse/storage/schema/full_schemas/16/event_signatures.sql @@ -0,0 +1,55 @@ +/* Copyright 2014, 2015 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE IF NOT EXISTS event_content_hashes ( + event_id TEXT, + algorithm TEXT, + hash bytea, + UNIQUE (event_id, algorithm) +); + +CREATE INDEX event_content_hashes_id ON event_content_hashes(event_id); + + +CREATE TABLE IF NOT EXISTS event_reference_hashes ( + event_id TEXT, + algorithm TEXT, + hash bytea, + UNIQUE (event_id, algorithm) +); + +CREATE INDEX event_reference_hashes_id ON event_reference_hashes(event_id); + + +CREATE TABLE IF NOT EXISTS event_signatures ( + event_id TEXT, + signature_name TEXT, + key_id TEXT, + signature bytea, + UNIQUE (event_id, signature_name, key_id) +); + +CREATE INDEX event_signatures_id ON event_signatures(event_id); + + +CREATE TABLE IF NOT EXISTS event_edge_hashes( + event_id TEXT, + prev_event_id TEXT, + algorithm TEXT, + hash bytea, + UNIQUE (event_id, prev_event_id, algorithm) +); + +CREATE INDEX event_edge_hashes_id ON event_edge_hashes(event_id); diff --git a/synapse/storage/schema/full_schemas/16/im.sql b/synapse/storage/schema/full_schemas/16/im.sql new file mode 100644 index 0000000000..576653a3c9 --- /dev/null +++ b/synapse/storage/schema/full_schemas/16/im.sql @@ -0,0 +1,128 @@ +/* Copyright 2014, 2015 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE IF NOT EXISTS events( + stream_ordering INTEGER PRIMARY KEY, + topological_ordering BIGINT NOT NULL, + event_id TEXT NOT NULL, + type TEXT NOT NULL, + room_id TEXT NOT NULL, + content TEXT NOT NULL, + unrecognized_keys TEXT, + processed BOOL NOT NULL, + outlier BOOL NOT NULL, + depth BIGINT DEFAULT 0 NOT NULL, + UNIQUE (event_id) +); + +CREATE INDEX events_stream_ordering ON events (stream_ordering); +CREATE INDEX events_topological_ordering ON events (topological_ordering); +CREATE INDEX events_order ON events (topological_ordering, stream_ordering); +CREATE INDEX events_room_id ON events (room_id); +CREATE INDEX events_order_room ON events ( + room_id, topological_ordering, stream_ordering +); + + +CREATE TABLE IF NOT EXISTS event_json( + event_id TEXT NOT NULL, + room_id TEXT NOT NULL, + internal_metadata TEXT NOT NULL, + json TEXT NOT NULL, + UNIQUE (event_id) +); + +CREATE INDEX event_json_room_id ON event_json(room_id); + + +CREATE TABLE IF NOT EXISTS state_events( + event_id TEXT NOT NULL, + room_id TEXT NOT NULL, + type TEXT NOT NULL, + state_key TEXT NOT NULL, + prev_state TEXT, + UNIQUE (event_id) +); + +CREATE INDEX state_events_room_id ON state_events (room_id); +CREATE INDEX state_events_type ON state_events (type); +CREATE INDEX state_events_state_key ON state_events (state_key); + + +CREATE TABLE IF NOT EXISTS current_state_events( + event_id TEXT NOT NULL, + room_id TEXT NOT NULL, + type TEXT NOT NULL, + state_key TEXT NOT NULL, + UNIQUE (event_id), + UNIQUE (room_id, type, state_key) +); + +CREATE INDEX current_state_events_room_id ON current_state_events (room_id); +CREATE INDEX current_state_events_type ON current_state_events (type); +CREATE INDEX current_state_events_state_key ON current_state_events (state_key); + +CREATE TABLE IF NOT EXISTS room_memberships( + event_id TEXT NOT NULL, + user_id TEXT NOT NULL, + sender TEXT NOT NULL, + room_id TEXT NOT NULL, + membership TEXT NOT NULL, + UNIQUE (event_id) +); + +CREATE INDEX room_memberships_room_id ON room_memberships (room_id); +CREATE INDEX room_memberships_user_id ON room_memberships (user_id); + +CREATE TABLE IF NOT EXISTS feedback( + event_id TEXT NOT NULL, + feedback_type TEXT, + target_event_id TEXT, + sender TEXT, + room_id TEXT, + UNIQUE (event_id) +); + +CREATE TABLE IF NOT EXISTS topics( + event_id TEXT NOT NULL, + room_id TEXT NOT NULL, + topic TEXT NOT NULL, + UNIQUE (event_id) +); + +CREATE INDEX topics_room_id ON topics(room_id); + +CREATE TABLE IF NOT EXISTS room_names( + event_id TEXT NOT NULL, + room_id TEXT NOT NULL, + name TEXT NOT NULL, + UNIQUE (event_id) +); + +CREATE INDEX room_names_room_id ON room_names(room_id); + +CREATE TABLE IF NOT EXISTS rooms( + room_id TEXT PRIMARY KEY NOT NULL, + is_public BOOL, + creator TEXT +); + +CREATE TABLE IF NOT EXISTS room_hosts( + room_id TEXT NOT NULL, + host TEXT NOT NULL, + UNIQUE (room_id, host) +); + +CREATE INDEX room_hosts_room_id ON room_hosts (room_id); diff --git a/synapse/storage/schema/full_schemas/16/keys.sql b/synapse/storage/schema/full_schemas/16/keys.sql new file mode 100644 index 0000000000..afc142045e --- /dev/null +++ b/synapse/storage/schema/full_schemas/16/keys.sql @@ -0,0 +1,31 @@ +/* Copyright 2014, 2015 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +CREATE TABLE IF NOT EXISTS server_tls_certificates( + server_name TEXT, -- Server name. + fingerprint TEXT, -- Certificate fingerprint. + from_server TEXT, -- Which key server the certificate was fetched from. + ts_added_ms BIGINT, -- When the certifcate was added. + tls_certificate bytea, -- DER encoded x509 certificate. + UNIQUE (server_name, fingerprint) +); + +CREATE TABLE IF NOT EXISTS server_signature_keys( + server_name TEXT, -- Server name. + key_id TEXT, -- Key version. + from_server TEXT, -- Which key server the key was fetched form. + ts_added_ms BIGINT, -- When the key was added. + verify_key bytea, -- NACL verification key. + UNIQUE (server_name, key_id) +); diff --git a/synapse/storage/schema/full_schemas/16/media_repository.sql b/synapse/storage/schema/full_schemas/16/media_repository.sql new file mode 100644 index 0000000000..dacbda40ca --- /dev/null +++ b/synapse/storage/schema/full_schemas/16/media_repository.sql @@ -0,0 +1,68 @@ +/* Copyright 2014, 2015 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE IF NOT EXISTS local_media_repository ( + media_id TEXT, -- The id used to refer to the media. + media_type TEXT, -- The MIME-type of the media. + media_length INTEGER, -- Length of the media in bytes. + created_ts BIGINT, -- When the content was uploaded in ms. + upload_name TEXT, -- The name the media was uploaded with. + user_id TEXT, -- The user who uploaded the file. + UNIQUE (media_id) +); + +CREATE TABLE IF NOT EXISTS local_media_repository_thumbnails ( + media_id TEXT, -- The id used to refer to the media. + thumbnail_width INTEGER, -- The width of the thumbnail in pixels. + thumbnail_height INTEGER, -- The height of the thumbnail in pixels. + thumbnail_type TEXT, -- The MIME-type of the thumbnail. + thumbnail_method TEXT, -- The method used to make the thumbnail. + thumbnail_length INTEGER, -- The length of the thumbnail in bytes. + UNIQUE ( + media_id, thumbnail_width, thumbnail_height, thumbnail_type + ) +); + +CREATE INDEX local_media_repository_thumbnails_media_id + ON local_media_repository_thumbnails (media_id); + +CREATE TABLE IF NOT EXISTS remote_media_cache ( + media_origin TEXT, -- The remote HS the media came from. + media_id TEXT, -- The id used to refer to the media on that server. + media_type TEXT, -- The MIME-type of the media. + created_ts BIGINT, -- When the content was uploaded in ms. + upload_name TEXT, -- The name the media was uploaded with. + media_length INTEGER, -- Length of the media in bytes. + filesystem_id TEXT, -- The name used to store the media on disk. + UNIQUE (media_origin, media_id) +); + +CREATE TABLE IF NOT EXISTS remote_media_cache_thumbnails ( + media_origin TEXT, -- The remote HS the media came from. + media_id TEXT, -- The id used to refer to the media. + thumbnail_width INTEGER, -- The width of the thumbnail in pixels. + thumbnail_height INTEGER, -- The height of the thumbnail in pixels. + thumbnail_method TEXT, -- The method used to make the thumbnail + thumbnail_type TEXT, -- The MIME-type of the thumbnail. + thumbnail_length INTEGER, -- The length of the thumbnail in bytes. + filesystem_id TEXT, -- The name used to store the media on disk. + UNIQUE ( + media_origin, media_id, thumbnail_width, thumbnail_height, + thumbnail_type + ) +); + +CREATE INDEX remote_media_cache_thumbnails_media_id + ON remote_media_cache_thumbnails (media_id); diff --git a/synapse/storage/schema/full_schemas/16/presence.sql b/synapse/storage/schema/full_schemas/16/presence.sql new file mode 100644 index 0000000000..80088413ba --- /dev/null +++ b/synapse/storage/schema/full_schemas/16/presence.sql @@ -0,0 +1,40 @@ +/* Copyright 2014, 2015 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +CREATE TABLE IF NOT EXISTS presence( + user_id TEXT NOT NULL, + state VARCHAR(20), + status_msg TEXT, + mtime BIGINT, -- miliseconds since last state change + UNIQUE (user_id) +); + +-- For each of /my/ users which possibly-remote users are allowed to see their +-- presence state +CREATE TABLE IF NOT EXISTS presence_allow_inbound( + observed_user_id TEXT NOT NULL, + observer_user_id TEXT NOT NULL, -- a UserID, + UNIQUE (observed_user_id, observer_user_id) +); + +-- For each of /my/ users (watcher), which possibly-remote users are they +-- watching? +CREATE TABLE IF NOT EXISTS presence_list( + user_id TEXT NOT NULL, + observed_user_id TEXT NOT NULL, -- a UserID, + accepted BOOLEAN NOT NULL, + UNIQUE (user_id, observed_user_id) +); + +CREATE INDEX presence_list_user_id ON presence_list (user_id); diff --git a/synapse/storage/schema/full_schemas/16/profiles.sql b/synapse/storage/schema/full_schemas/16/profiles.sql new file mode 100644 index 0000000000..934be86520 --- /dev/null +++ b/synapse/storage/schema/full_schemas/16/profiles.sql @@ -0,0 +1,20 @@ +/* Copyright 2014, 2015 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +CREATE TABLE IF NOT EXISTS profiles( + user_id TEXT NOT NULL, + displayname TEXT, + avatar_url TEXT, + UNIQUE(user_id) +); diff --git a/synapse/storage/schema/full_schemas/16/push.sql b/synapse/storage/schema/full_schemas/16/push.sql new file mode 100644 index 0000000000..db6e05cbdf --- /dev/null +++ b/synapse/storage/schema/full_schemas/16/push.sql @@ -0,0 +1,73 @@ +/* Copyright 2015 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE IF NOT EXISTS rejections( + event_id TEXT NOT NULL, + reason TEXT NOT NULL, + last_check TEXT NOT NULL, + UNIQUE (event_id) +); + +-- Push notification endpoints that users have configured +CREATE TABLE IF NOT EXISTS pushers ( + id BIGINT PRIMARY KEY, + user_name TEXT NOT NULL, + profile_tag VARCHAR(32) NOT NULL, + kind VARCHAR(8) NOT NULL, + app_id VARCHAR(64) NOT NULL, + app_display_name VARCHAR(64) NOT NULL, + device_display_name VARCHAR(128) NOT NULL, + pushkey bytea NOT NULL, + ts BIGINT NOT NULL, + lang VARCHAR(8), + data bytea, + last_token TEXT, + last_success BIGINT, + failing_since BIGINT, + UNIQUE (app_id, pushkey) +); + +CREATE TABLE IF NOT EXISTS push_rules ( + id BIGINT PRIMARY KEY, + user_name TEXT NOT NULL, + rule_id TEXT NOT NULL, + priority_class SMALLINT NOT NULL, + priority INTEGER NOT NULL DEFAULT 0, + conditions TEXT NOT NULL, + actions TEXT NOT NULL, + UNIQUE(user_name, rule_id) +); + +CREATE INDEX push_rules_user_name on push_rules (user_name); + +CREATE TABLE IF NOT EXISTS user_filters( + user_id TEXT, + filter_id BIGINT, + filter_json bytea +); + +CREATE INDEX user_filters_by_user_id_filter_id ON user_filters( + user_id, filter_id +); + +CREATE TABLE IF NOT EXISTS push_rules_enable ( + id BIGINT PRIMARY KEY, + user_name TEXT NOT NULL, + rule_id TEXT NOT NULL, + enabled SMALLINT, + UNIQUE(user_name, rule_id) +); + +CREATE INDEX push_rules_enable_user_name on push_rules_enable (user_name); diff --git a/synapse/storage/schema/full_schemas/16/redactions.sql b/synapse/storage/schema/full_schemas/16/redactions.sql new file mode 100644 index 0000000000..69621955d4 --- /dev/null +++ b/synapse/storage/schema/full_schemas/16/redactions.sql @@ -0,0 +1,22 @@ +/* Copyright 2014, 2015 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +CREATE TABLE IF NOT EXISTS redactions ( + event_id TEXT NOT NULL, + redacts TEXT NOT NULL, + UNIQUE (event_id) +); + +CREATE INDEX redactions_event_id ON redactions (event_id); +CREATE INDEX redactions_redacts ON redactions (redacts); diff --git a/synapse/storage/schema/full_schemas/16/room_aliases.sql b/synapse/storage/schema/full_schemas/16/room_aliases.sql new file mode 100644 index 0000000000..412bb97fad --- /dev/null +++ b/synapse/storage/schema/full_schemas/16/room_aliases.sql @@ -0,0 +1,29 @@ +/* Copyright 2014, 2015 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE IF NOT EXISTS room_aliases( + room_alias TEXT NOT NULL, + room_id TEXT NOT NULL, + UNIQUE (room_alias) +); + +CREATE INDEX room_aliases_id ON room_aliases(room_id); + +CREATE TABLE IF NOT EXISTS room_alias_servers( + room_alias TEXT NOT NULL, + server TEXT NOT NULL +); + +CREATE INDEX room_alias_servers_alias ON room_alias_servers(room_alias); diff --git a/synapse/storage/schema/full_schemas/16/state.sql b/synapse/storage/schema/full_schemas/16/state.sql new file mode 100644 index 0000000000..705cac6ce9 --- /dev/null +++ b/synapse/storage/schema/full_schemas/16/state.sql @@ -0,0 +1,40 @@ +/* Copyright 2014, 2015 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE IF NOT EXISTS state_groups( + id BIGINT PRIMARY KEY, + room_id TEXT NOT NULL, + event_id TEXT NOT NULL +); + +CREATE TABLE IF NOT EXISTS state_groups_state( + state_group BIGINT NOT NULL, + room_id TEXT NOT NULL, + type TEXT NOT NULL, + state_key TEXT NOT NULL, + event_id TEXT NOT NULL +); + +CREATE TABLE IF NOT EXISTS event_to_state_groups( + event_id TEXT NOT NULL, + state_group BIGINT NOT NULL, + UNIQUE (event_id) +); + +CREATE INDEX state_groups_id ON state_groups(id); + +CREATE INDEX state_groups_state_id ON state_groups_state(state_group); +CREATE INDEX state_groups_state_tuple ON state_groups_state(room_id, type, state_key); +CREATE INDEX event_to_state_groups_id ON event_to_state_groups(event_id); diff --git a/synapse/storage/schema/full_schemas/16/transactions.sql b/synapse/storage/schema/full_schemas/16/transactions.sql new file mode 100644 index 0000000000..1ab77cdb63 --- /dev/null +++ b/synapse/storage/schema/full_schemas/16/transactions.sql @@ -0,0 +1,63 @@ +/* Copyright 2014, 2015 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +-- Stores what transaction ids we have received and what our response was +CREATE TABLE IF NOT EXISTS received_transactions( + transaction_id TEXT, + origin TEXT, + ts BIGINT, + response_code INTEGER, + response_json bytea, + has_been_referenced smallint default 0, -- Whether thishas been referenced by a prev_tx + UNIQUE (transaction_id, origin) +); + +CREATE INDEX transactions_have_ref ON received_transactions(origin, has_been_referenced);-- WHERE has_been_referenced = 0; + + +-- Stores what transactions we've sent, what their response was (if we got one) and whether we have +-- since referenced the transaction in another outgoing transaction +CREATE TABLE IF NOT EXISTS sent_transactions( + id BIGINT PRIMARY KEY, -- This is used to apply insertion ordering + transaction_id TEXT, + destination TEXT, + response_code INTEGER DEFAULT 0, + response_json TEXT, + ts BIGINT +); + +CREATE INDEX sent_transaction_dest ON sent_transactions(destination); +CREATE INDEX sent_transaction_txn_id ON sent_transactions(transaction_id); +-- So that we can do an efficient look up of all transactions that have yet to be successfully +-- sent. +CREATE INDEX sent_transaction_sent ON sent_transactions(response_code); + + +-- For sent transactions only. +CREATE TABLE IF NOT EXISTS transaction_id_to_pdu( + transaction_id INTEGER, + destination TEXT, + pdu_id TEXT, + pdu_origin TEXT, + UNIQUE (transaction_id, destination) +); + +CREATE INDEX transaction_id_to_pdu_dest ON transaction_id_to_pdu(destination); + +-- To track destination health +CREATE TABLE IF NOT EXISTS destinations( + destination TEXT PRIMARY KEY, + retry_last_ts BIGINT, + retry_interval INTEGER +); diff --git a/synapse/storage/schema/full_schemas/16/users.sql b/synapse/storage/schema/full_schemas/16/users.sql new file mode 100644 index 0000000000..d2fa3122da --- /dev/null +++ b/synapse/storage/schema/full_schemas/16/users.sql @@ -0,0 +1,42 @@ +/* Copyright 2014, 2015 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +CREATE TABLE IF NOT EXISTS users( + name TEXT, + password_hash TEXT, + creation_ts BIGINT, + admin SMALLINT DEFAULT 0 NOT NULL, + UNIQUE(name) +); + +CREATE TABLE IF NOT EXISTS access_tokens( + id BIGINT PRIMARY KEY, + user_id TEXT NOT NULL, + device_id TEXT, + token TEXT NOT NULL, + last_used BIGINT, + UNIQUE(token) +); + +CREATE TABLE IF NOT EXISTS user_ips ( + user_id TEXT NOT NULL, + access_token TEXT NOT NULL, + device_id TEXT, + ip TEXT NOT NULL, + user_agent TEXT NOT NULL, + last_seen BIGINT NOT NULL +); + +CREATE INDEX user_ips_user ON user_ips(user_id); +CREATE INDEX user_ips_user_ip ON user_ips(user_id, access_token, ip); diff --git a/synapse/storage/schema/schema_version.sql b/synapse/storage/schema/schema_version.sql index 0431e2d051..d682608aa0 100644 --- a/synapse/storage/schema/schema_version.sql +++ b/synapse/storage/schema/schema_version.sql @@ -14,17 +14,14 @@ */ CREATE TABLE IF NOT EXISTS schema_version( - Lock char(1) NOT NULL DEFAULT 'X', -- Makes sure this table only has one row. + Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, -- Makes sure this table only has one row. version INTEGER NOT NULL, upgraded BOOL NOT NULL, -- Whether we reached this version from an upgrade or an initial schema. - CONSTRAINT schema_version_lock_x CHECK (Lock='X') - CONSTRAINT schema_version_lock_uniq UNIQUE (Lock) + CHECK (Lock='X') ); CREATE TABLE IF NOT EXISTS applied_schema_deltas( version INTEGER NOT NULL, file TEXT NOT NULL, - CONSTRAINT schema_deltas_ver_file UNIQUE (version, file) ON CONFLICT IGNORE + UNIQUE(version, file) ); - -CREATE INDEX IF NOT EXISTS schema_deltas_ver ON applied_schema_deltas(version); diff --git a/synapse/storage/signatures.py b/synapse/storage/signatures.py index d0d53770f2..f051828630 100644 --- a/synapse/storage/signatures.py +++ b/synapse/storage/signatures.py @@ -56,7 +56,6 @@ class SignatureStore(SQLBaseStore): "algorithm": algorithm, "hash": buffer(hash_bytes), }, - or_ignore=True, ) def get_event_reference_hashes(self, event_ids): @@ -100,7 +99,7 @@ class SignatureStore(SQLBaseStore): " WHERE event_id = ?" ) txn.execute(query, (event_id, )) - return dict(txn.fetchall()) + return {k: v for k, v in txn.fetchall()} def _store_event_reference_hash_txn(self, txn, event_id, algorithm, hash_bytes): @@ -119,7 +118,6 @@ class SignatureStore(SQLBaseStore): "algorithm": algorithm, "hash": buffer(hash_bytes), }, - or_ignore=True, ) def _get_event_signatures_txn(self, txn, event_id): @@ -164,7 +162,6 @@ class SignatureStore(SQLBaseStore): "key_id": key_id, "signature": buffer(signature_bytes), }, - or_ignore=True, ) def _get_prev_event_hashes_txn(self, txn, event_id): @@ -198,5 +195,4 @@ class SignatureStore(SQLBaseStore): "algorithm": algorithm, "hash": buffer(hash_bytes), }, - or_ignore=True, ) diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 58dbf2802b..553ba9dd1f 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -17,6 +17,8 @@ from ._base import SQLBaseStore from twisted.internet import defer +from synapse.util.stringutils import random_string + import logging logger = logging.getLogger(__name__) @@ -91,14 +93,15 @@ class StateStore(SQLBaseStore): state_group = context.state_group if not state_group: - state_group = self._simple_insert_txn( + state_group = self._state_groups_id_gen.get_next_txn(txn) + self._simple_insert_txn( txn, table="state_groups", values={ + "id": state_group, "room_id": event.room_id, "event_id": event.event_id, }, - or_ignore=True, ) for state in state_events.values(): @@ -112,7 +115,6 @@ class StateStore(SQLBaseStore): "state_key": state.state_key, "event_id": state.event_id, }, - or_ignore=True, ) self._simple_insert_txn( @@ -122,7 +124,6 @@ class StateStore(SQLBaseStore): "state_group": state_group, "event_id": event.event_id, }, - or_replace=True, ) @defer.inlineCallbacks @@ -154,3 +155,7 @@ class StateStore(SQLBaseStore): events = yield self._parse_events(results) defer.returnValue(events) + + +def _make_group_id(clock): + return str(int(clock.time_msec())) + random_string(5) diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index 66f307e640..df6de7cbcd 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -35,7 +35,7 @@ what sort order was used: from twisted.internet import defer -from ._base import SQLBaseStore, cached +from ._base import SQLBaseStore from synapse.api.constants import EventTypes from synapse.api.errors import SynapseError from synapse.util.logutils import log_function @@ -110,7 +110,7 @@ class _StreamToken(namedtuple("_StreamToken", "topological stream")): if self.topological is None: return "(%d < %s)" % (self.stream, "stream_ordering") else: - return "(%d < %s OR (%d == %s AND %d < %s))" % ( + return "(%d < %s OR (%d = %s AND %d < %s))" % ( self.topological, "topological_ordering", self.topological, "topological_ordering", self.stream, "stream_ordering", @@ -120,7 +120,7 @@ class _StreamToken(namedtuple("_StreamToken", "topological stream")): if self.topological is None: return "(%d >= %s)" % (self.stream, "stream_ordering") else: - return "(%d > %s OR (%d == %s AND %d >= %s))" % ( + return "(%d > %s OR (%d = %s AND %d >= %s))" % ( self.topological, "topological_ordering", self.topological, "topological_ordering", self.stream, "stream_ordering", @@ -240,7 +240,7 @@ class StreamStore(SQLBaseStore): sql = ( "SELECT e.event_id, e.stream_ordering FROM events AS e WHERE " - "(e.outlier = 0 AND (room_id IN (%(current)s)) OR " + "(e.outlier = ? AND (room_id IN (%(current)s)) OR " "(event_id IN (%(invites)s))) " "AND e.stream_ordering > ? AND e.stream_ordering <= ? " "ORDER BY stream_ordering ASC LIMIT %(limit)d " @@ -251,7 +251,7 @@ class StreamStore(SQLBaseStore): } def f(txn): - txn.execute(sql, (user_id, user_id, from_id.stream, to_id.stream,)) + txn.execute(sql, (False, user_id, user_id, from_id.stream, to_id.stream,)) rows = self.cursor_to_dict(txn) @@ -283,7 +283,7 @@ class StreamStore(SQLBaseStore): # Tokens really represent positions between elements, but we use # the convention of pointing to the event before the gap. Hence # we have a bit of asymmetry when it comes to equalities. - args = [room_id] + args = [False, room_id] if direction == 'b': order = "DESC" bounds = _StreamToken.parse(from_key).upper_bound() @@ -307,7 +307,7 @@ class StreamStore(SQLBaseStore): sql = ( "SELECT * FROM events" - " WHERE outlier = 0 AND room_id = ? AND %(bounds)s" + " WHERE outlier = ? AND room_id = ? AND %(bounds)s" " ORDER BY topological_ordering %(order)s," " stream_ordering %(order)s %(limit)s" ) % { @@ -358,7 +358,7 @@ class StreamStore(SQLBaseStore): sql = ( "SELECT stream_ordering, topological_ordering, event_id" " FROM events" - " WHERE room_id = ? AND stream_ordering <= ? AND outlier = 0" + " WHERE room_id = ? AND stream_ordering <= ? AND outlier = ?" " ORDER BY topological_ordering DESC, stream_ordering DESC" " LIMIT ?" ) @@ -368,17 +368,17 @@ class StreamStore(SQLBaseStore): "SELECT stream_ordering, topological_ordering, event_id" " FROM events" " WHERE room_id = ? AND stream_ordering > ?" - " AND stream_ordering <= ? AND outlier = 0" + " AND stream_ordering <= ? AND outlier = ?" " ORDER BY topological_ordering DESC, stream_ordering DESC" " LIMIT ?" ) def get_recent_events_for_room_txn(txn): if from_token is None: - txn.execute(sql, (room_id, end_token.stream, limit,)) + txn.execute(sql, (room_id, end_token.stream, False, limit,)) else: txn.execute(sql, ( - room_id, from_token.stream, end_token.stream, limit + room_id, from_token.stream, end_token.stream, False, limit )) rows = self.cursor_to_dict(txn) @@ -413,12 +413,10 @@ class StreamStore(SQLBaseStore): "get_recent_events_for_room", get_recent_events_for_room_txn ) - @cached(num_args=0) + @defer.inlineCallbacks def get_room_events_max_id(self): - return self.runInteraction( - "get_room_events_max_id", - self._get_room_events_max_id_txn - ) + token = yield self._stream_id_gen.get_max_token(self) + defer.returnValue("s%d" % (token,)) @defer.inlineCallbacks def _get_min_token(self): @@ -433,27 +431,6 @@ class StreamStore(SQLBaseStore): defer.returnValue(self.min_token) - def get_next_stream_id(self): - with self._next_stream_id_lock: - i = self._next_stream_id - self._next_stream_id += 1 - return i - - def _get_room_events_max_id_txn(self, txn): - txn.execute( - "SELECT MAX(stream_ordering) as m FROM events" - ) - - res = self.cursor_to_dict(txn) - - logger.debug("get_room_events_max_id: %s", res) - - if not res or not res[0] or not res[0]["m"]: - return "s0" - - key = res[0]["m"] - return "s%d" % (key,) - @staticmethod def _set_before_and_after(events, rows): for event, row in zip(events, rows): diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py index b777395e06..89dd7d8947 100644 --- a/synapse/storage/transactions.py +++ b/synapse/storage/transactions.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import SQLBaseStore, Table, cached +from ._base import SQLBaseStore, cached from collections import namedtuple @@ -76,22 +76,18 @@ class TransactionStore(SQLBaseStore): response_json (str) """ - return self.runInteraction( - "set_received_txn_response", - self._set_received_txn_response, - transaction_id, origin, code, response_dict + return self._simple_insert( + table=ReceivedTransactionsTable.table_name, + values={ + "transaction_id": transaction_id, + "origin": origin, + "response_code": code, + "response_json": response_dict, + }, + or_ignore=True, + desc="set_received_txn_response", ) - def _set_received_txn_response(self, txn, transaction_id, origin, code, - response_json): - query = ( - "UPDATE %s " - "SET response_code = ?, response_json = ? " - "WHERE transaction_id = ? AND origin = ?" - ) % ReceivedTransactionsTable.table_name - - txn.execute(query, (code, response_json, transaction_id, origin)) - def prep_send_transaction(self, transaction_id, destination, origin_server_ts): """Persists an outgoing transaction and calculates the values for the @@ -118,41 +114,38 @@ class TransactionStore(SQLBaseStore): def _prep_send_transaction(self, txn, transaction_id, destination, origin_server_ts): + next_id = self._transaction_id_gen.get_next_txn(txn) + # First we find out what the prev_txns should be. # Since we know that we are only sending one transaction at a time, # we can simply take the last one. - query = "%s ORDER BY id DESC LIMIT 1" % ( - SentTransactions.select_statement("destination = ?"), - ) + query = ( + "SELECT * FROM sent_transactions" + " WHERE destination = ?" + " ORDER BY id DESC LIMIT 1" + ) - results = txn.execute(query, (destination,)) - results = SentTransactions.decode_results(results) + txn.execute(query, (destination,)) + results = self.cursor_to_dict(txn) - prev_txns = [r.transaction_id for r in results] + prev_txns = [r["transaction_id"] for r in results] # Actually add the new transaction to the sent_transactions table. - query = SentTransactions.insert_statement() - txn.execute(query, SentTransactions.EntryType( - None, - transaction_id=transaction_id, - destination=destination, - ts=origin_server_ts, - response_code=0, - response_json=None - )) - - # Update the tx id -> pdu id mapping - - # values = [ - # (transaction_id, destination, pdu[0], pdu[1]) - # for pdu in pdu_list - # ] - # - # logger.debug("Inserting: %s", repr(values)) - # - # query = TransactionsToPduTable.insert_statement() - # txn.executemany(query, values) + self._simple_insert_txn( + txn, + table=SentTransactions.table_name, + values={ + "id": next_id, + "transaction_id": transaction_id, + "destination": destination, + "ts": origin_server_ts, + "response_code": 0, + "response_json": None, + } + ) + + # TODO Update the tx id -> pdu id mapping return prev_txns @@ -171,15 +164,20 @@ class TransactionStore(SQLBaseStore): transaction_id, destination, code, response_dict ) - def _delivered_txn(cls, txn, transaction_id, destination, + def _delivered_txn(self, txn, transaction_id, destination, code, response_json): - query = ( - "UPDATE %s " - "SET response_code = ?, response_json = ? " - "WHERE transaction_id = ? AND destination = ?" - ) % SentTransactions.table_name - - txn.execute(query, (code, response_json, transaction_id, destination)) + self._simple_update_one_txn( + txn, + table=SentTransactions.table_name, + keyvalues={ + "transaction_id": transaction_id, + "destination": destination, + }, + updatevalues={ + "response_code": code, + "response_json": None, # For now, don't persist response_json + } + ) def get_transactions_after(self, transaction_id, destination): """Get all transactions after a given local transaction_id. @@ -189,25 +187,26 @@ class TransactionStore(SQLBaseStore): destination (str) Returns: - list: A list of `ReceivedTransactionsTable.EntryType` + list: A list of dicts """ return self.runInteraction( "get_transactions_after", self._get_transactions_after, transaction_id, destination ) - def _get_transactions_after(cls, txn, transaction_id, destination): - where = ( - "destination = ? AND id > (select id FROM %s WHERE " - "transaction_id = ? AND destination = ?)" - ) % ( - SentTransactions.table_name + def _get_transactions_after(self, txn, transaction_id, destination): + query = ( + "SELECT * FROM sent_transactions" + " WHERE destination = ? AND id >" + " (" + " SELECT id FROM sent_transactions" + " WHERE transaction_id = ? AND destination = ?" + " )" ) - query = SentTransactions.select_statement(where) txn.execute(query, (destination, transaction_id, destination)) - return ReceivedTransactionsTable.decode_results(txn.fetchall()) + return self.cursor_to_dict(txn) @cached() def get_destination_retry_timings(self, destination): @@ -218,22 +217,27 @@ class TransactionStore(SQLBaseStore): Returns: None if not retrying - Otherwise a DestinationsTable.EntryType for the retry scheme + Otherwise a dict for the retry scheme """ return self.runInteraction( "get_destination_retry_timings", self._get_destination_retry_timings, destination) - def _get_destination_retry_timings(cls, txn, destination): - query = DestinationsTable.select_statement("destination = ?") - txn.execute(query, (destination,)) - result = txn.fetchall() - if result: - result = DestinationsTable.decode_single_result(result) - if result.retry_last_ts > 0: - return result - else: - return None + def _get_destination_retry_timings(self, txn, destination): + result = self._simple_select_one_txn( + txn, + table=DestinationsTable.table_name, + keyvalues={ + "destination": destination, + }, + retcols=DestinationsTable.fields, + allow_none=True, + ) + + if result and result["retry_last_ts"] > 0: + return result + else: + return None def set_destination_retry_timings(self, destination, retry_last_ts, retry_interval): @@ -249,11 +253,11 @@ class TransactionStore(SQLBaseStore): # As this is the new value, we might as well prefill the cache self.get_destination_retry_timings.prefill( destination, - DestinationsTable.EntryType( - destination, - retry_last_ts, - retry_interval - ) + { + "destination": destination, + "retry_last_ts": retry_last_ts, + "retry_interval": retry_interval + }, ) # XXX: we could chose to not bother persisting this if our cache thinks @@ -266,22 +270,38 @@ class TransactionStore(SQLBaseStore): retry_interval, ) - def _set_destination_retry_timings(cls, txn, destination, + def _set_destination_retry_timings(self, txn, destination, retry_last_ts, retry_interval): - query = ( - "INSERT OR REPLACE INTO %s " - "(destination, retry_last_ts, retry_interval) " - "VALUES (?, ?, ?) " - ) % DestinationsTable.table_name + "UPDATE destinations" + " SET retry_last_ts = ?, retry_interval = ?" + " WHERE destination = ?" + ) + + txn.execute( + query, + ( + retry_last_ts, retry_interval, destination, + ) + ) - txn.execute(query, (destination, retry_last_ts, retry_interval)) + if txn.rowcount == 0: + # destination wasn't already in table. Insert it. + self._simple_insert_txn( + txn, + table="destinations", + values={ + "destination": destination, + "retry_last_ts": retry_last_ts, + "retry_interval": retry_interval, + } + ) def get_destinations_needing_retry(self): """Get all destinations which are due a retry for sending a transaction. Returns: - list: A list of `DestinationsTable.EntryType` + list: A list of dicts """ return self.runInteraction( @@ -289,14 +309,17 @@ class TransactionStore(SQLBaseStore): self._get_destinations_needing_retry ) - def _get_destinations_needing_retry(cls, txn): - where = "retry_last_ts > 0 and retry_next_ts < now()" - query = DestinationsTable.select_statement(where) - txn.execute(query) - return DestinationsTable.decode_results(txn.fetchall()) + def _get_destinations_needing_retry(self, txn): + query = ( + "SELECT * FROM destinations" + " WHERE retry_last_ts > 0 and retry_next_ts < ?" + ) + + txn.execute(query, (self._clock.time_msec(),)) + return self.cursor_to_dict(txn) -class ReceivedTransactionsTable(Table): +class ReceivedTransactionsTable(object): table_name = "received_transactions" fields = [ @@ -308,10 +331,8 @@ class ReceivedTransactionsTable(Table): "has_been_referenced", ] - EntryType = namedtuple("ReceivedTransactionsEntry", fields) - -class SentTransactions(Table): +class SentTransactions(object): table_name = "sent_transactions" fields = [ @@ -326,7 +347,7 @@ class SentTransactions(Table): EntryType = namedtuple("SentTransactionsEntry", fields) -class TransactionsToPduTable(Table): +class TransactionsToPduTable(object): table_name = "transaction_id_to_pdu" fields = [ @@ -336,10 +357,8 @@ class TransactionsToPduTable(Table): "pdu_origin", ] - EntryType = namedtuple("TransactionsToPduEntry", fields) - -class DestinationsTable(Table): +class DestinationsTable(object): table_name = "destinations" fields = [ @@ -347,5 +366,3 @@ class DestinationsTable(Table): "retry_last_ts", "retry_interval", ] - - EntryType = namedtuple("DestinationsEntry", fields) diff --git a/synapse/storage/util/__init__.py b/synapse/storage/util/__init__.py new file mode 100644 index 0000000000..c488b10d3c --- /dev/null +++ b/synapse/storage/util/__init__.py @@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- +# Copyright 2014, 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py new file mode 100644 index 0000000000..9d461d5e96 --- /dev/null +++ b/synapse/storage/util/id_generators.py @@ -0,0 +1,131 @@ +# -*- coding: utf-8 -*- +# Copyright 2014, 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +from collections import deque +import contextlib +import threading + + +class IdGenerator(object): + def __init__(self, table, column, store): + self.table = table + self.column = column + self.store = store + self._lock = threading.Lock() + self._next_id = None + + @defer.inlineCallbacks + def get_next(self): + with self._lock: + if not self._next_id: + res = yield self.store._execute_and_decode( + "IdGenerator_%s" % (self.table,), + "SELECT MAX(%s) as mx FROM %s" % (self.column, self.table,) + ) + + self._next_id = (res and res[0] and res[0]["mx"]) or 1 + + i = self._next_id + self._next_id += 1 + defer.returnValue(i) + + def get_next_txn(self, txn): + with self._lock: + if self._next_id: + i = self._next_id + self._next_id += 1 + return i + else: + txn.execute( + "SELECT MAX(%s) FROM %s" % (self.column, self.table,) + ) + + val, = txn.fetchone() + cur = val or 0 + cur += 1 + self._next_id = cur + 1 + + return cur + + +class StreamIdGenerator(object): + """Used to generate new stream ids when persisting events while keeping + track of which transactions have been completed. + + This allows us to get the "current" stream id, i.e. the stream id such that + all ids less than or equal to it have completed. This handles the fact that + persistence of events can complete out of order. + + Usage: + with stream_id_gen.get_next_txn(txn) as stream_id: + # ... persist event ... + """ + def __init__(self): + self._lock = threading.Lock() + + self._current_max = None + self._unfinished_ids = deque() + + def get_next_txn(self, txn): + """ + Usage: + with stream_id_gen.get_next_txn(txn) as stream_id: + # ... persist event ... + """ + with self._lock: + if not self._current_max: + self._compute_current_max(txn) + + self._current_max += 1 + next_id = self._current_max + + self._unfinished_ids.append(next_id) + + @contextlib.contextmanager + def manager(): + try: + yield next_id + finally: + with self._lock: + self._unfinished_ids.remove(next_id) + + return manager() + + @defer.inlineCallbacks + def get_max_token(self, store): + """Returns the maximum stream id such that all stream ids less than or + equal to it have been successfully persisted. + """ + with self._lock: + if self._unfinished_ids: + defer.returnValue(self._unfinished_ids[0] - 1) + + if not self._current_max: + yield store.runInteraction( + "_compute_current_max", + self._compute_current_max, + ) + + defer.returnValue(self._current_max) + + def _compute_current_max(self, txn): + txn.execute("SELECT MAX(stream_ordering) FROM events") + val, = txn.fetchone() + + self._current_max = int(val) if val else 1 + + return self._current_max diff --git a/synapse/util/lrucache.py b/synapse/util/lrucache.py index 2f7b615f78..96163c90f1 100644 --- a/synapse/util/lrucache.py +++ b/synapse/util/lrucache.py @@ -14,6 +14,10 @@ # limitations under the License. +from functools import wraps +import threading + + class LruCache(object): """Least-recently-used cache.""" # TODO(mjark) Add mutex for linked list for thread safety. @@ -24,6 +28,16 @@ class LruCache(object): PREV, NEXT, KEY, VALUE = 0, 1, 2, 3 + lock = threading.Lock() + + def synchronized(f): + @wraps(f) + def inner(*args, **kwargs): + with lock: + return f(*args, **kwargs) + + return inner + def add_node(key, value): prev_node = list_root next_node = prev_node[NEXT] @@ -51,6 +65,7 @@ class LruCache(object): next_node[PREV] = prev_node cache.pop(node[KEY], None) + @synchronized def cache_get(key, default=None): node = cache.get(key, None) if node is not None: @@ -59,6 +74,7 @@ class LruCache(object): else: return default + @synchronized def cache_set(key, value): node = cache.get(key, None) if node is not None: @@ -69,6 +85,7 @@ class LruCache(object): if len(cache) > max_size: delete_node(list_root[PREV]) + @synchronized def cache_set_default(key, value): node = cache.get(key, None) if node is not None: @@ -79,6 +96,7 @@ class LruCache(object): delete_node(list_root[PREV]) return value + @synchronized def cache_pop(key, default=None): node = cache.get(key, None) if node: @@ -87,9 +105,11 @@ class LruCache(object): else: return default + @synchronized def cache_len(): return len(cache) + @synchronized def cache_contains(key): return key in cache diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py index 4e82232796..a42138f556 100644 --- a/synapse/util/retryutils.py +++ b/synapse/util/retryutils.py @@ -60,7 +60,7 @@ def get_retry_limiter(destination, clock, store, **kwargs): if retry_timings: retry_last_ts, retry_interval = ( - retry_timings.retry_last_ts, retry_timings.retry_interval + retry_timings["retry_last_ts"], retry_timings["retry_interval"] ) now = int(clock.time_msec()) |