diff --git a/synapse/__init__.py b/synapse/__init__.py
index 56c10a84e9..c89f444f4e 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -16,4 +16,4 @@
""" This is a reference implementation of a Matrix home server.
"""
-__version__ = "0.8.1-r4"
+__version__ = "0.9.0"
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 64f605b962..d5bf0be85c 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -18,9 +18,8 @@
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership, JoinRules
-from synapse.api.errors import AuthError, StoreError, Codes, SynapseError
+from synapse.api.errors import AuthError, Codes, SynapseError
from synapse.util.logutils import log_function
-from synapse.util.async import run_on_reactor
from synapse.types import UserID, ClientInfo
import logging
@@ -40,6 +39,7 @@ class Auth(object):
self.hs = hs
self.store = hs.get_datastore()
self.state = hs.get_state_handler()
+ self.TOKEN_NOT_FOUND_HTTP_STATUS = 401
def check(self, event, auth_events):
""" Checks if this event is correctly authed.
@@ -64,7 +64,10 @@ class Auth(object):
if event.type == EventTypes.Aliases:
return True
- logger.debug("Auth events: %s", auth_events)
+ logger.debug(
+ "Auth events: %s",
+ [a.event_id for a in auth_events.values()]
+ )
if event.type == EventTypes.Member:
allowed = self.is_membership_change_allowed(
@@ -183,18 +186,10 @@ class Auth(object):
else:
join_rule = JoinRules.INVITE
- user_level = self._get_power_level_from_event_state(
- event,
- event.user_id,
- auth_events,
- )
+ user_level = self._get_user_power_level(event.user_id, auth_events)
- ban_level, kick_level, redact_level = (
- self._get_ops_level_from_event_state(
- event,
- auth_events,
- )
- )
+ # FIXME (erikj): What should we do here as the default?
+ ban_level = self._get_named_level(auth_events, "ban", 50)
logger.debug(
"is_membership_change_allowed: %s",
@@ -210,28 +205,33 @@ class Auth(object):
}
)
- if ban_level:
- ban_level = int(ban_level)
- else:
- ban_level = 50 # FIXME (erikj): What should we do here?
+ if Membership.JOIN != membership:
+ # JOIN is the only action you can perform if you're not in the room
+ if not caller_in_room: # caller isn't joined
+ raise AuthError(
+ 403,
+ "%s not in room %s." % (event.user_id, event.room_id,)
+ )
if Membership.INVITE == membership:
# TODO (erikj): We should probably handle this more intelligently
# PRIVATE join rules.
# Invites are valid iff caller is in the room and target isn't.
- if not caller_in_room: # caller isn't joined
- raise AuthError(
- 403,
- "%s not in room %s." % (event.user_id, event.room_id,)
- )
- elif target_banned:
+ if target_banned:
raise AuthError(
403, "%s is banned from the room" % (target_user_id,)
)
elif target_in_room: # the target is already in the room.
raise AuthError(403, "%s is already in the room." %
target_user_id)
+ else:
+ invite_level = self._get_named_level(auth_events, "invite", 0)
+
+ if user_level < invite_level:
+ raise AuthError(
+ 403, "You cannot invite user %s." % target_user_id
+ )
elif Membership.JOIN == membership:
# Joins are valid iff caller == target and they were:
# invited: They are accepting the invitation
@@ -251,21 +251,12 @@ class Auth(object):
raise AuthError(403, "You are not allowed to join this room")
elif Membership.LEAVE == membership:
# TODO (erikj): Implement kicks.
-
- if not caller_in_room: # trying to leave a room you aren't joined
- raise AuthError(
- 403,
- "%s not in room %s." % (target_user_id, event.room_id,)
- )
- elif target_banned and user_level < ban_level:
+ if target_banned and user_level < ban_level:
raise AuthError(
403, "You cannot unban user &s." % (target_user_id,)
)
elif target_user_id != event.user_id:
- if kick_level:
- kick_level = int(kick_level)
- else:
- kick_level = 50 # FIXME (erikj): What should we do here?
+ kick_level = self._get_named_level(auth_events, "kick", 50)
if user_level < kick_level:
raise AuthError(
@@ -279,34 +270,42 @@ class Auth(object):
return True
- def _get_power_level_from_event_state(self, event, user_id, auth_events):
+ def _get_power_level_event(self, auth_events):
key = (EventTypes.PowerLevels, "", )
- power_level_event = auth_events.get(key)
- level = None
+ return auth_events.get(key)
+
+ def _get_user_power_level(self, user_id, auth_events):
+ power_level_event = self._get_power_level_event(auth_events)
+
if power_level_event:
level = power_level_event.content.get("users", {}).get(user_id)
if not level:
level = power_level_event.content.get("users_default", 0)
+
+ if level is None:
+ return 0
+ else:
+ return int(level)
else:
key = (EventTypes.Create, "", )
create_event = auth_events.get(key)
if (create_event is not None and
create_event.content["creator"] == user_id):
return 100
+ else:
+ return 0
- return level
+ def _get_named_level(self, auth_events, name, default):
+ power_level_event = self._get_power_level_event(auth_events)
- def _get_ops_level_from_event_state(self, event, auth_events):
- key = (EventTypes.PowerLevels, "", )
- power_level_event = auth_events.get(key)
+ if not power_level_event:
+ return default
- if power_level_event:
- return (
- power_level_event.content.get("ban", 50),
- power_level_event.content.get("kick", 50),
- power_level_event.content.get("redact", 50),
- )
- return None, None, None,
+ level = power_level_event.content.get(name, None)
+ if level is not None:
+ return int(level)
+ else:
+ return default
@defer.inlineCallbacks
def get_user_by_req(self, request):
@@ -363,7 +362,7 @@ class Auth(object):
default=[""]
)[0]
if user and access_token and ip_addr:
- yield self.store.insert_client_ip(
+ self.store.insert_client_ip(
user=user,
access_token=access_token,
device_id=user_info["device_id"],
@@ -373,7 +372,10 @@ class Auth(object):
defer.returnValue((user, ClientInfo(device_id, token_id)))
except KeyError:
- raise AuthError(403, "Missing access token.")
+ raise AuthError(
+ self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.",
+ errcode=Codes.MISSING_TOKEN
+ )
@defer.inlineCallbacks
def get_user_by_token(self, token):
@@ -387,21 +389,20 @@ class Auth(object):
Raises:
AuthError if no user by that token exists or the token is invalid.
"""
- try:
- ret = yield self.store.get_user_by_token(token)
- if not ret:
- raise StoreError(400, "Unknown token")
- user_info = {
- "admin": bool(ret.get("admin", False)),
- "device_id": ret.get("device_id"),
- "user": UserID.from_string(ret.get("name")),
- "token_id": ret.get("token_id", None),
- }
+ ret = yield self.store.get_user_by_token(token)
+ if not ret:
+ raise AuthError(
+ self.TOKEN_NOT_FOUND_HTTP_STATUS, "Unrecognised access token.",
+ errcode=Codes.UNKNOWN_TOKEN
+ )
+ user_info = {
+ "admin": bool(ret.get("admin", False)),
+ "device_id": ret.get("device_id"),
+ "user": UserID.from_string(ret.get("name")),
+ "token_id": ret.get("token_id", None),
+ }
- defer.returnValue(user_info)
- except StoreError:
- raise AuthError(403, "Unrecognised access token.",
- errcode=Codes.UNKNOWN_TOKEN)
+ defer.returnValue(user_info)
@defer.inlineCallbacks
def get_appservice_by_req(self, request):
@@ -409,19 +410,22 @@ class Auth(object):
token = request.args["access_token"][0]
service = yield self.store.get_app_service_by_token(token)
if not service:
- raise AuthError(403, "Unrecognised access token.",
- errcode=Codes.UNKNOWN_TOKEN)
+ raise AuthError(
+ self.TOKEN_NOT_FOUND_HTTP_STATUS,
+ "Unrecognised access token.",
+ errcode=Codes.UNKNOWN_TOKEN
+ )
defer.returnValue(service)
except KeyError:
- raise AuthError(403, "Missing access token.")
+ raise AuthError(
+ self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token."
+ )
def is_server_admin(self, user):
return self.store.is_server_admin(user)
@defer.inlineCallbacks
def add_auth_events(self, builder, context):
- yield run_on_reactor()
-
auth_ids = self.compute_auth_events(builder, context.current_state)
auth_events_entries = yield self.store.add_event_hashes(
@@ -486,7 +490,7 @@ class Auth(object):
send_level = send_level_event.content.get("events", {}).get(
event.type
)
- if not send_level:
+ if send_level is None:
if hasattr(event, "state_key"):
send_level = send_level_event.content.get(
"state_default", 50
@@ -501,16 +505,7 @@ class Auth(object):
else:
send_level = 0
- user_level = self._get_power_level_from_event_state(
- event,
- event.user_id,
- auth_events,
- )
-
- if user_level:
- user_level = int(user_level)
- else:
- user_level = 0
+ user_level = self._get_user_power_level(event.user_id, auth_events)
if user_level < send_level:
raise AuthError(
@@ -542,16 +537,9 @@ class Auth(object):
return True
def _check_redaction(self, event, auth_events):
- user_level = self._get_power_level_from_event_state(
- event,
- event.user_id,
- auth_events,
- )
+ user_level = self._get_user_power_level(event.user_id, auth_events)
- _, _, redact_level = self._get_ops_level_from_event_state(
- event,
- auth_events,
- )
+ redact_level = self._get_named_level(auth_events, "redact", 50)
if user_level < redact_level:
raise AuthError(
@@ -579,11 +567,7 @@ class Auth(object):
if not current_state:
return
- user_level = self._get_power_level_from_event_state(
- event,
- event.user_id,
- auth_events,
- )
+ user_level = self._get_user_power_level(event.user_id, auth_events)
# Check other levels:
levels_to_check = [
@@ -592,6 +576,7 @@ class Auth(object):
("ban", []),
("redact", []),
("kick", []),
+ ("invite", []),
]
old_list = current_state.content.get("users")
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index b16bf4247d..d8a18ee87b 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -59,6 +59,9 @@ class LoginType(object):
EMAIL_URL = u"m.login.email.url"
EMAIL_IDENTITY = u"m.login.email.identity"
RECAPTCHA = u"m.login.recaptcha"
+ DUMMY = u"m.login.dummy"
+
+ # Only for C/S API v1
APPLICATION_SERVICE = u"m.login.application_service"
SHARED_SECRET = u"org.matrix.login.shared_secret"
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index eddd889778..0b3320e62c 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -31,13 +31,15 @@ class Codes(object):
BAD_PAGINATION = "M_BAD_PAGINATION"
UNKNOWN = "M_UNKNOWN"
NOT_FOUND = "M_NOT_FOUND"
+ MISSING_TOKEN = "M_MISSING_TOKEN"
UNKNOWN_TOKEN = "M_UNKNOWN_TOKEN"
LIMIT_EXCEEDED = "M_LIMIT_EXCEEDED"
CAPTCHA_NEEDED = "M_CAPTCHA_NEEDED"
CAPTCHA_INVALID = "M_CAPTCHA_INVALID"
- MISSING_PARAM = "M_MISSING_PARAM",
- TOO_LARGE = "M_TOO_LARGE",
+ MISSING_PARAM = "M_MISSING_PARAM"
+ TOO_LARGE = "M_TOO_LARGE"
EXCLUSIVE = "M_EXCLUSIVE"
+ THREEPID_AUTH_FAILED = "M_THREEPID_AUTH_FAILED"
class CodeMessageException(RuntimeError):
diff --git a/synapse/api/urls.py b/synapse/api/urls.py
index 3d43674625..15c8558ea7 100644
--- a/synapse/api/urls.py
+++ b/synapse/api/urls.py
@@ -22,5 +22,6 @@ STATIC_PREFIX = "/_matrix/static"
WEB_CLIENT_PREFIX = "/_matrix/client"
CONTENT_REPO_PREFIX = "/_matrix/content"
SERVER_KEY_PREFIX = "/_matrix/key/v1"
+SERVER_KEY_V2_PREFIX = "/_matrix/key/v2"
MEDIA_PREFIX = "/_matrix/media/v1"
APP_SERVICE_PREFIX = "/_matrix/appservice/v1"
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 500cae05fb..d8d0df7e41 100755
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -16,14 +16,18 @@
import sys
sys.dont_write_bytecode = True
+from synapse.python_dependencies import check_requirements
+
+if __name__ == '__main__':
+ check_requirements()
+from synapse.storage.engines import create_engine, IncorrectDatabaseSetup
from synapse.storage import (
- prepare_database, prepare_sqlite3_database, UpgradeDatabaseException,
+ are_all_users_on_domain, UpgradeDatabaseException,
)
from synapse.server import HomeServer
-from synapse.python_dependencies import check_requirements
from twisted.internet import reactor
from twisted.application import service
@@ -31,16 +35,17 @@ from twisted.enterprise import adbapi
from twisted.web.resource import Resource
from twisted.web.static import File
from twisted.web.server import Site
+from twisted.web.http import proxiedLogFormatter, combinedLogFormatter
from synapse.http.server import JsonResource, RootRedirect
-from synapse.rest.appservice.v1 import AppServiceRestResource
from synapse.rest.media.v0.content_repository import ContentRepoResource
from synapse.rest.media.v1.media_repository import MediaRepositoryResource
-from synapse.http.server_key_resource import LocalKey
+from synapse.rest.key.v1.server_key_resource import LocalKey
+from synapse.rest.key.v2 import KeyApiV2Resource
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
from synapse.api.urls import (
CLIENT_PREFIX, FEDERATION_PREFIX, WEB_CLIENT_PREFIX, CONTENT_REPO_PREFIX,
- SERVER_KEY_PREFIX, MEDIA_PREFIX, CLIENT_V2_ALPHA_PREFIX, APP_SERVICE_PREFIX,
- STATIC_PREFIX
+ SERVER_KEY_PREFIX, MEDIA_PREFIX, CLIENT_V2_ALPHA_PREFIX, STATIC_PREFIX,
+ SERVER_KEY_V2_PREFIX,
)
from synapse.config.homeserver import HomeServerConfig
from synapse.crypto import context_factory
@@ -59,9 +64,9 @@ import os
import re
import resource
import subprocess
-import sqlite3
-logger = logging.getLogger(__name__)
+
+logger = logging.getLogger("synapse.app.homeserver")
class SynapseHomeServer(HomeServer):
@@ -78,9 +83,6 @@ class SynapseHomeServer(HomeServer):
def build_resource_for_federation(self):
return JsonResource(self)
- def build_resource_for_app_services(self):
- return AppServiceRestResource(self)
-
def build_resource_for_web_client(self):
import syweb
syweb_path = os.path.dirname(syweb.__file__)
@@ -101,6 +103,9 @@ class SynapseHomeServer(HomeServer):
def build_resource_for_server_key(self):
return LocalKey(self)
+ def build_resource_for_server_key_v2(self):
+ return KeyApiV2Resource(self)
+
def build_resource_for_metrics(self):
if self.get_config().enable_metrics:
return MetricsResource(self)
@@ -108,13 +113,11 @@ class SynapseHomeServer(HomeServer):
return None
def build_db_pool(self):
+ name = self.db_config["name"]
+
return adbapi.ConnectionPool(
- "sqlite3", self.get_db_name(),
- check_same_thread=False,
- cp_min=1,
- cp_max=1,
- cp_openfun=prepare_database, # Prepare the database for each conn
- # so that :memory: sqlite works
+ name,
+ **self.db_config.get("args", {})
)
def create_resource_tree(self, redirect_root_to_web_client):
@@ -140,8 +143,8 @@ class SynapseHomeServer(HomeServer):
(FEDERATION_PREFIX, self.get_resource_for_federation()),
(CONTENT_REPO_PREFIX, self.get_resource_for_content_repo()),
(SERVER_KEY_PREFIX, self.get_resource_for_server_key()),
+ (SERVER_KEY_V2_PREFIX, self.get_resource_for_server_key_v2()),
(MEDIA_PREFIX, self.get_resource_for_media_repository()),
- (APP_SERVICE_PREFIX, self.get_resource_for_app_services()),
(STATIC_PREFIX, self.get_resource_for_static_content()),
]
@@ -226,7 +229,11 @@ class SynapseHomeServer(HomeServer):
if not config.no_tls and config.bind_port is not None:
reactor.listenSSL(
config.bind_port,
- Site(self.root_resource),
+ SynapseSite(
+ "synapse.access.https",
+ config,
+ self.root_resource,
+ ),
self.tls_context_factory,
interface=config.bind_host
)
@@ -235,7 +242,11 @@ class SynapseHomeServer(HomeServer):
if config.unsecure_port is not None:
reactor.listenTCP(
config.unsecure_port,
- Site(self.root_resource),
+ SynapseSite(
+ "synapse.access.http",
+ config,
+ self.root_resource,
+ ),
interface=config.bind_host
)
logger.info("Synapse now listening on port %d", config.unsecure_port)
@@ -243,10 +254,43 @@ class SynapseHomeServer(HomeServer):
metrics_resource = self.get_resource_for_metrics()
if metrics_resource and config.metrics_port is not None:
reactor.listenTCP(
- config.metrics_port, Site(metrics_resource), interface="127.0.0.1",
+ config.metrics_port,
+ SynapseSite(
+ "synapse.access.metrics",
+ config,
+ metrics_resource,
+ ),
+ interface="127.0.0.1",
)
logger.info("Metrics now running on 127.0.0.1 port %d", config.metrics_port)
+ def run_startup_checks(self, db_conn, database_engine):
+ all_users_native = are_all_users_on_domain(
+ db_conn.cursor(), database_engine, self.hostname
+ )
+ if not all_users_native:
+ quit_with_error(
+ "Found users in database not native to %s!\n"
+ "You cannot changed a synapse server_name after it's been configured"
+ % (self.hostname,)
+ )
+
+ try:
+ database_engine.check_database(db_conn.cursor())
+ except IncorrectDatabaseSetup as e:
+ quit_with_error(e.message)
+
+
+def quit_with_error(error_string):
+ message_lines = error_string.split("\n")
+ line_length = max([len(l) for l in message_lines]) + 2
+ sys.stderr.write("*" * line_length + '\n')
+ for line in message_lines:
+ if line.strip():
+ sys.stderr.write(" %s\n" % (line.strip(),))
+ sys.stderr.write("*" * line_length + '\n')
+ sys.exit(1)
+
def get_version_string():
try:
@@ -358,29 +402,39 @@ def setup(config_options):
tls_context_factory = context_factory.ServerContextFactory(config)
+ database_engine = create_engine(config.database_config["name"])
+ config.database_config["args"]["cp_openfun"] = database_engine.on_new_connection
+
hs = SynapseHomeServer(
config.server_name,
domain_with_port=domain_with_port,
upload_dir=os.path.abspath("uploads"),
- db_name=config.database_path,
+ db_config=config.database_config,
tls_context_factory=tls_context_factory,
config=config,
content_addr=config.content_addr,
version_string=version_string,
+ database_engine=database_engine,
)
hs.create_resource_tree(
redirect_root_to_web_client=True,
)
- db_name = hs.get_db_name()
-
- logger.info("Preparing database: %s...", db_name)
+ logger.info("Preparing database: %r...", config.database_config)
try:
- with sqlite3.connect(db_name) as db_conn:
- prepare_sqlite3_database(db_conn)
- prepare_database(db_conn)
+ db_conn = database_engine.module.connect(
+ **{
+ k: v for k, v in config.database_config.get("args", {}).items()
+ if not k.startswith("cp_")
+ }
+ )
+
+ database_engine.prepare_database(db_conn)
+ hs.run_startup_checks(db_conn, database_engine)
+
+ db_conn.commit()
except UpgradeDatabaseException:
sys.stderr.write(
"\nFailed to upgrade database.\n"
@@ -389,7 +443,7 @@ def setup(config_options):
)
sys.exit(1)
- logger.info("Database prepared in %s.", db_name)
+ logger.info("Database prepared in %r.", config.database_config)
if config.manhole:
f = twisted.manhole.telnet.ShellFactory()
@@ -423,6 +477,24 @@ class SynapseService(service.Service):
return self._port.stopListening()
+class SynapseSite(Site):
+ """
+ Subclass of a twisted http Site that does access logging with python's
+ standard logging
+ """
+ def __init__(self, logger_name, config, resource, *args, **kwargs):
+ Site.__init__(self, resource, *args, **kwargs)
+ if config.captcha_ip_origin_is_x_forwarded:
+ self._log_formatter = proxiedLogFormatter
+ else:
+ self._log_formatter = combinedLogFormatter
+ self.access_logger = logging.getLogger(logger_name)
+
+ def log(self, request):
+ line = self._log_formatter(self._logDateTime, request)
+ self.access_logger.info(line)
+
+
def run(hs):
def in_thread():
diff --git a/synapse/app/synctl.py b/synapse/app/synctl.py
index 3a70a248dc..0a2b0d6fcd 100755
--- a/synapse/app/synctl.py
+++ b/synapse/app/synctl.py
@@ -18,15 +18,18 @@ import sys
import os
import subprocess
import signal
+import yaml
SYNAPSE = ["python", "-B", "-m", "synapse.app.homeserver"]
CONFIGFILE = "homeserver.yaml"
-PIDFILE = "homeserver.pid"
GREEN = "\x1b[1;32m"
NORMAL = "\x1b[m"
+CONFIG = yaml.load(open(CONFIGFILE))
+PIDFILE = CONFIG["pid_file"]
+
def start():
if not os.path.exists(CONFIGFILE):
@@ -40,7 +43,7 @@ def start():
sys.exit(1)
print "Starting ...",
args = SYNAPSE
- args.extend(["--daemonize", "-c", CONFIGFILE, "--pid-file", PIDFILE])
+ args.extend(["--daemonize", "-c", CONFIGFILE])
subprocess.check_call(args)
print GREEN + "started" + NORMAL
diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py
index a268a6bcc4..63a18b802b 100644
--- a/synapse/appservice/__init__.py
+++ b/synapse/appservice/__init__.py
@@ -20,6 +20,50 @@ import re
logger = logging.getLogger(__name__)
+class ApplicationServiceState(object):
+ DOWN = "down"
+ UP = "up"
+
+
+class AppServiceTransaction(object):
+ """Represents an application service transaction."""
+
+ def __init__(self, service, id, events):
+ self.service = service
+ self.id = id
+ self.events = events
+
+ def send(self, as_api):
+ """Sends this transaction using the provided AS API interface.
+
+ Args:
+ as_api(ApplicationServiceApi): The API to use to send.
+ Returns:
+ A Deferred which resolves to True if the transaction was sent.
+ """
+ return as_api.push_bulk(
+ service=self.service,
+ events=self.events,
+ txn_id=self.id
+ )
+
+ def complete(self, store):
+ """Completes this transaction as successful.
+
+ Marks this transaction ID on the application service and removes the
+ transaction contents from the database.
+
+ Args:
+ store: The database store to operate on.
+ Returns:
+ A Deferred which resolves to True if the transaction was completed.
+ """
+ return store.complete_appservice_txn(
+ service=self.service,
+ txn_id=self.id
+ )
+
+
class ApplicationService(object):
"""Defines an application service. This definition is mostly what is
provided to the /register AS API.
@@ -35,13 +79,13 @@ class ApplicationService(object):
NS_LIST = [NS_USERS, NS_ALIASES, NS_ROOMS]
def __init__(self, token, url=None, namespaces=None, hs_token=None,
- sender=None, txn_id=None):
+ sender=None, id=None):
self.token = token
self.url = url
self.hs_token = hs_token
self.sender = sender
self.namespaces = self._check_namespaces(namespaces)
- self.txn_id = txn_id
+ self.id = id
def _check_namespaces(self, namespaces):
# Sanity check that it is of the form:
@@ -51,7 +95,7 @@ class ApplicationService(object):
# rooms: [ {regex: "[A-z]+.*", exclusive: true}, ...],
# }
if not namespaces:
- return None
+ namespaces = {}
for ns in ApplicationService.NS_LIST:
if ns not in namespaces:
@@ -155,7 +199,10 @@ class ApplicationService(object):
return self._matches_user(event, member_list)
def is_interested_in_user(self, user_id):
- return self._matches_regex(user_id, ApplicationService.NS_USERS)
+ return (
+ self._matches_regex(user_id, ApplicationService.NS_USERS)
+ or user_id == self.sender
+ )
def is_interested_in_alias(self, alias):
return self._matches_regex(alias, ApplicationService.NS_ALIASES)
@@ -164,7 +211,10 @@ class ApplicationService(object):
return self._matches_regex(room_id, ApplicationService.NS_ROOMS)
def is_exclusive_user(self, user_id):
- return self._is_exclusive(ApplicationService.NS_USERS, user_id)
+ return (
+ self._is_exclusive(ApplicationService.NS_USERS, user_id)
+ or user_id == self.sender
+ )
def is_exclusive_alias(self, alias):
return self._is_exclusive(ApplicationService.NS_ALIASES, alias)
diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index c2179f8d55..2a9becccb3 100644
--- a/synapse/appservice/api.py
+++ b/synapse/appservice/api.py
@@ -72,14 +72,19 @@ class ApplicationServiceApi(SimpleHttpClient):
defer.returnValue(False)
@defer.inlineCallbacks
- def push_bulk(self, service, events):
+ def push_bulk(self, service, events, txn_id=None):
events = self._serialize(events)
+ if txn_id is None:
+ logger.warning("push_bulk: Missing txn ID sending events to %s",
+ service.url)
+ txn_id = str(0)
+ txn_id = str(txn_id)
+
uri = service.url + ("/transactions/%s" %
- urllib.quote(str(0))) # TODO txn_ids
- response = None
+ urllib.quote(txn_id))
try:
- response = yield self.put_json(
+ yield self.put_json(
uri=uri,
json_body={
"events": events
@@ -87,9 +92,8 @@ class ApplicationServiceApi(SimpleHttpClient):
args={
"access_token": service.hs_token
})
- if response: # just an empty json object
- # TODO: Mark txn as sent successfully
- defer.returnValue(True)
+ defer.returnValue(True)
+ return
except CodeMessageException as e:
logger.warning("push_bulk to %s received %s", uri, e.code)
except Exception as ex:
@@ -97,8 +101,8 @@ class ApplicationServiceApi(SimpleHttpClient):
defer.returnValue(False)
@defer.inlineCallbacks
- def push(self, service, event):
- response = yield self.push_bulk(service, [event])
+ def push(self, service, event, txn_id=None):
+ response = yield self.push_bulk(service, [event], txn_id)
defer.returnValue(response)
def _serialize(self, events):
diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py
new file mode 100644
index 0000000000..59b0b1f4ac
--- /dev/null
+++ b/synapse/appservice/scheduler.py
@@ -0,0 +1,254 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This module controls the reliability for application service transactions.
+
+The nominal flow through this module looks like:
+ __________
+1---ASa[e]-->| Service |--> Queue ASa[f]
+2----ASb[e]->| Queuer |
+3--ASa[f]--->|__________|-----------+ ASa[e], ASb[e]
+ V
+ -````````- +------------+
+ |````````|<--StoreTxn-|Transaction |
+ |Database| | Controller |---> SEND TO AS
+ `--------` +------------+
+What happens on SEND TO AS depends on the state of the Application Service:
+ - If the AS is marked as DOWN, do nothing.
+ - If the AS is marked as UP, send the transaction.
+ * SUCCESS : Increment where the AS is up to txn-wise and nuke the txn
+ contents from the db.
+ * FAILURE : Marked AS as DOWN and start Recoverer.
+
+Recoverer attempts to recover ASes who have died. The flow for this looks like:
+ ,--------------------- backoff++ --------------.
+ V |
+ START ---> Wait exp ------> Get oldest txn ID from ----> FAILURE
+ backoff DB and try to send it
+ ^ |___________
+Mark AS as | V
+UP & quit +---------- YES SUCCESS
+ | | |
+ NO <--- Have more txns? <------ Mark txn success & nuke <-+
+ from db; incr AS pos.
+ Reset backoff.
+
+This is all tied together by the AppServiceScheduler which DIs the required
+components.
+"""
+
+from synapse.appservice import ApplicationServiceState
+from twisted.internet import defer
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class AppServiceScheduler(object):
+ """ Public facing API for this module. Does the required DI to tie the
+ components together. This also serves as the "event_pool", which in this
+ case is a simple array.
+ """
+
+ def __init__(self, clock, store, as_api):
+ self.clock = clock
+ self.store = store
+ self.as_api = as_api
+
+ def create_recoverer(service, callback):
+ return _Recoverer(clock, store, as_api, service, callback)
+
+ self.txn_ctrl = _TransactionController(
+ clock, store, as_api, create_recoverer
+ )
+ self.queuer = _ServiceQueuer(self.txn_ctrl)
+
+ @defer.inlineCallbacks
+ def start(self):
+ logger.info("Starting appservice scheduler")
+ # check for any DOWN ASes and start recoverers for them.
+ recoverers = yield _Recoverer.start(
+ self.clock, self.store, self.as_api, self.txn_ctrl.on_recovered
+ )
+ self.txn_ctrl.add_recoverers(recoverers)
+
+ def submit_event_for_as(self, service, event):
+ self.queuer.enqueue(service, event)
+
+
+class _ServiceQueuer(object):
+ """Queues events for the same application service together, sending
+ transactions as soon as possible. Once a transaction is sent successfully,
+ this schedules any other events in the queue to run.
+ """
+
+ def __init__(self, txn_ctrl):
+ self.queued_events = {} # dict of {service_id: [events]}
+ self.pending_requests = {} # dict of {service_id: Deferred}
+ self.txn_ctrl = txn_ctrl
+
+ def enqueue(self, service, event):
+ # if this service isn't being sent something
+ if not self.pending_requests.get(service.id):
+ self._send_request(service, [event])
+ else:
+ # add to queue for this service
+ if service.id not in self.queued_events:
+ self.queued_events[service.id] = []
+ self.queued_events[service.id].append(event)
+
+ def _send_request(self, service, events):
+ # send request and add callbacks
+ d = self.txn_ctrl.send(service, events)
+ d.addBoth(self._on_request_finish)
+ d.addErrback(self._on_request_fail)
+ self.pending_requests[service.id] = d
+
+ def _on_request_finish(self, service):
+ self.pending_requests[service.id] = None
+ # if there are queued events, then send them.
+ if (service.id in self.queued_events
+ and len(self.queued_events[service.id]) > 0):
+ self._send_request(service, self.queued_events[service.id])
+ self.queued_events[service.id] = []
+
+ def _on_request_fail(self, err):
+ logger.error("AS request failed: %s", err)
+
+
+class _TransactionController(object):
+
+ def __init__(self, clock, store, as_api, recoverer_fn):
+ self.clock = clock
+ self.store = store
+ self.as_api = as_api
+ self.recoverer_fn = recoverer_fn
+ # keep track of how many recoverers there are
+ self.recoverers = []
+
+ @defer.inlineCallbacks
+ def send(self, service, events):
+ try:
+ txn = yield self.store.create_appservice_txn(
+ service=service,
+ events=events
+ )
+ service_is_up = yield self._is_service_up(service)
+ if service_is_up:
+ sent = yield txn.send(self.as_api)
+ if sent:
+ txn.complete(self.store)
+ else:
+ self._start_recoverer(service)
+ except Exception as e:
+ logger.exception(e)
+ self._start_recoverer(service)
+ # request has finished
+ defer.returnValue(service)
+
+ @defer.inlineCallbacks
+ def on_recovered(self, recoverer):
+ self.recoverers.remove(recoverer)
+ logger.info("Successfully recovered application service AS ID %s",
+ recoverer.service.id)
+ logger.info("Remaining active recoverers: %s", len(self.recoverers))
+ yield self.store.set_appservice_state(
+ recoverer.service,
+ ApplicationServiceState.UP
+ )
+
+ def add_recoverers(self, recoverers):
+ for r in recoverers:
+ self.recoverers.append(r)
+ if len(recoverers) > 0:
+ logger.info("New active recoverers: %s", len(self.recoverers))
+
+ @defer.inlineCallbacks
+ def _start_recoverer(self, service):
+ yield self.store.set_appservice_state(
+ service,
+ ApplicationServiceState.DOWN
+ )
+ logger.info(
+ "Application service falling behind. Starting recoverer. AS ID %s",
+ service.id
+ )
+ recoverer = self.recoverer_fn(service, self.on_recovered)
+ self.add_recoverers([recoverer])
+ recoverer.recover()
+
+ @defer.inlineCallbacks
+ def _is_service_up(self, service):
+ state = yield self.store.get_appservice_state(service)
+ defer.returnValue(state == ApplicationServiceState.UP or state is None)
+
+
+class _Recoverer(object):
+
+ @staticmethod
+ @defer.inlineCallbacks
+ def start(clock, store, as_api, callback):
+ services = yield store.get_appservices_by_state(
+ ApplicationServiceState.DOWN
+ )
+ recoverers = [
+ _Recoverer(clock, store, as_api, s, callback) for s in services
+ ]
+ for r in recoverers:
+ logger.info("Starting recoverer for AS ID %s which was marked as "
+ "DOWN", r.service.id)
+ r.recover()
+ defer.returnValue(recoverers)
+
+ def __init__(self, clock, store, as_api, service, callback):
+ self.clock = clock
+ self.store = store
+ self.as_api = as_api
+ self.service = service
+ self.callback = callback
+ self.backoff_counter = 1
+
+ def recover(self):
+ self.clock.call_later((2 ** self.backoff_counter), self.retry)
+
+ def _backoff(self):
+ # cap the backoff to be around 18h => (2^16) = 65536 secs
+ if self.backoff_counter < 16:
+ self.backoff_counter += 1
+ self.recover()
+
+ @defer.inlineCallbacks
+ def retry(self):
+ try:
+ txn = yield self.store.get_oldest_unsent_txn(self.service)
+ if txn:
+ logger.info("Retrying transaction %s for AS ID %s",
+ txn.id, txn.service.id)
+ sent = yield txn.send(self.as_api)
+ if sent:
+ yield txn.complete(self.store)
+ # reset the backoff counter and retry immediately
+ self.backoff_counter = 1
+ yield self.retry()
+ else:
+ self._backoff()
+ else:
+ self._set_service_recovered()
+ except Exception as e:
+ logger.exception(e)
+ self._backoff()
+
+ def _set_service_recovered(self):
+ self.callback(self)
diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index 87cdbf1d30..2807abbc90 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -14,9 +14,10 @@
# limitations under the License.
import argparse
-import sys
import os
import yaml
+import sys
+from textwrap import dedent
class ConfigError(Exception):
@@ -24,18 +25,35 @@ class ConfigError(Exception):
class Config(object):
- def __init__(self, args):
- pass
@staticmethod
- def parse_size(string):
+ def parse_size(value):
+ if isinstance(value, int) or isinstance(value, long):
+ return value
sizes = {"K": 1024, "M": 1024 * 1024}
size = 1
- suffix = string[-1]
+ suffix = value[-1]
if suffix in sizes:
- string = string[:-1]
+ value = value[:-1]
size = sizes[suffix]
- return int(string) * size
+ return int(value) * size
+
+ @staticmethod
+ def parse_duration(value):
+ if isinstance(value, int) or isinstance(value, long):
+ return value
+ second = 1000
+ hour = 60 * 60 * second
+ day = 24 * hour
+ week = 7 * day
+ year = 365 * day
+ sizes = {"s": second, "h": hour, "d": day, "w": week, "y": year}
+ size = 1
+ suffix = value[-1]
+ if suffix in sizes:
+ value = value[:-1]
+ size = sizes[suffix]
+ return int(value) * size
@staticmethod
def abspath(file_path):
@@ -86,83 +104,130 @@ class Config(object):
with open(file_path) as file_stream:
return yaml.load(file_stream)
- @classmethod
- def add_arguments(cls, parser):
- pass
+ def invoke_all(self, name, *args, **kargs):
+ results = []
+ for cls in type(self).mro():
+ if name in cls.__dict__:
+ results.append(getattr(cls, name)(self, *args, **kargs))
+ return results
- @classmethod
- def generate_config(cls, args, config_dir_path):
- pass
+ def generate_config(self, config_dir_path, server_name):
+ default_config = "# vim:ft=yaml\n"
+
+ default_config += "\n\n".join(dedent(conf) for conf in self.invoke_all(
+ "default_config", config_dir_path, server_name
+ ))
+
+ config = yaml.load(default_config)
+
+ return default_config, config
@classmethod
def load_config(cls, description, argv, generate_section=None):
+ obj = cls()
+
config_parser = argparse.ArgumentParser(add_help=False)
config_parser.add_argument(
"-c", "--config-path",
+ action="append",
metavar="CONFIG_FILE",
help="Specify config file"
)
config_parser.add_argument(
"--generate-config",
action="store_true",
- help="Generate config file"
+ help="Generate a config file for the server name"
+ )
+ config_parser.add_argument(
+ "-H", "--server-name",
+ help="The server name to generate a config file for"
)
config_args, remaining_args = config_parser.parse_known_args(argv)
if config_args.generate_config:
if not config_args.config_path:
config_parser.error(
- "Must specify where to generate the config file"
+ "Must supply a config file.\nA config file can be automatically"
+ " generated using \"--generate-config -h SERVER_NAME"
+ " -c CONFIG-FILE\""
+ )
+
+ config_dir_path = os.path.dirname(config_args.config_path[0])
+ config_dir_path = os.path.abspath(config_dir_path)
+
+ server_name = config_args.server_name
+ if not server_name:
+ print "Most specify a server_name to a generate config for."
+ sys.exit(1)
+ (config_path,) = config_args.config_path
+ if not os.path.exists(config_dir_path):
+ os.makedirs(config_dir_path)
+ if os.path.exists(config_path):
+ print "Config file %r already exists" % (config_path,)
+ yaml_config = cls.read_config_file(config_path)
+ yaml_name = yaml_config["server_name"]
+ if server_name != yaml_name:
+ print (
+ "Config file %r has a different server_name: "
+ " %r != %r" % (config_path, server_name, yaml_name)
+ )
+ sys.exit(1)
+ config_bytes, config = obj.generate_config(
+ config_dir_path, server_name
)
- config_dir_path = os.path.dirname(config_args.config_path)
- if os.path.exists(config_args.config_path):
- defaults = cls.read_config_file(config_args.config_path)
- else:
- defaults = {}
- else:
- if config_args.config_path:
- defaults = cls.read_config_file(config_args.config_path)
- else:
- defaults = {}
+ config.update(yaml_config)
+ print "Generating any missing keys for %r" % (server_name,)
+ obj.invoke_all("generate_files", config)
+ sys.exit(0)
+ with open(config_path, "wb") as config_file:
+ config_bytes, config = obj.generate_config(
+ config_dir_path, server_name
+ )
+ obj.invoke_all("generate_files", config)
+ config_file.write(config_bytes)
+ print (
+ "A config file has been generated in %s for server name"
+ " '%s' with corresponding SSL keys and self-signed"
+ " certificates. Please review this file and customise it to"
+ " your needs."
+ ) % (config_path, server_name)
+ print (
+ "If this server name is incorrect, you will need to regenerate"
+ " the SSL certificates"
+ )
+ sys.exit(0)
parser = argparse.ArgumentParser(
parents=[config_parser],
description=description,
formatter_class=argparse.RawDescriptionHelpFormatter,
)
- cls.add_arguments(parser)
- parser.set_defaults(**defaults)
+ obj.invoke_all("add_arguments", parser)
args = parser.parse_args(remaining_args)
- if config_args.generate_config:
- config_dir_path = os.path.dirname(config_args.config_path)
- config_dir_path = os.path.abspath(config_dir_path)
- if not os.path.exists(config_dir_path):
- os.makedirs(config_dir_path)
- cls.generate_config(args, config_dir_path)
- config = {}
- for key, value in vars(args).items():
- if (key not in set(["config_path", "generate_config"])
- and value is not None):
- config[key] = value
- with open(config_args.config_path, "w") as config_file:
- # TODO(paul) it would be lovely if we wrote out vim- and emacs-
- # style mode markers into the file, to hint to people that
- # this is a YAML file.
- yaml.dump(config, config_file, default_flow_style=False)
- print (
- "A config file has been generated in %s for server name"
- " '%s' with corresponding SSL keys and self-signed"
- " certificates. Please review this file and customise it to"
- " your needs."
- ) % (
- config_args.config_path, config['server_name']
+ if not config_args.config_path:
+ config_parser.error(
+ "Must supply a config file.\nA config file can be automatically"
+ " generated using \"--generate-config -h SERVER_NAME"
+ " -c CONFIG-FILE\""
)
- print (
- "If this server name is incorrect, you will need to regenerate"
- " the SSL certificates"
- )
- sys.exit(0)
- return cls(args)
+ config_dir_path = os.path.dirname(config_args.config_path[0])
+ config_dir_path = os.path.abspath(config_dir_path)
+
+ specified_config = {}
+ for config_path in config_args.config_path:
+ yaml_config = cls.read_config_file(config_path)
+ specified_config.update(yaml_config)
+
+ server_name = specified_config["server_name"]
+ _, config = obj.generate_config(config_dir_path, server_name)
+ config.pop("log_config")
+ config.update(specified_config)
+
+ obj.invoke_all("read_config", config)
+
+ obj.invoke_all("read_arguments", args)
+
+ return obj
diff --git a/synapse/config/appservice.py b/synapse/config/appservice.py
new file mode 100644
index 0000000000..38f41933b7
--- /dev/null
+++ b/synapse/config/appservice.py
@@ -0,0 +1,27 @@
+# Copyright 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ._base import Config
+
+
+class AppServiceConfig(Config):
+
+ def read_config(self, config):
+ self.app_service_config_files = config.get("app_service_config_files", [])
+
+ def default_config(cls, config_dir_path, server_name):
+ return """\
+ # A list of application service config file to use
+ app_service_config_files: []
+ """
diff --git a/synapse/config/captcha.py b/synapse/config/captcha.py
index 7e21c7414d..d8fe577e34 100644
--- a/synapse/config/captcha.py
+++ b/synapse/config/captcha.py
@@ -17,35 +17,35 @@ from ._base import Config
class CaptchaConfig(Config):
- def __init__(self, args):
- super(CaptchaConfig, self).__init__(args)
- self.recaptcha_private_key = args.recaptcha_private_key
- self.enable_registration_captcha = args.enable_registration_captcha
+ def read_config(self, config):
+ self.recaptcha_private_key = config["recaptcha_private_key"]
+ self.recaptcha_public_key = config["recaptcha_public_key"]
+ self.enable_registration_captcha = config["enable_registration_captcha"]
+ # XXX: This is used for more than just captcha
self.captcha_ip_origin_is_x_forwarded = (
- args.captcha_ip_origin_is_x_forwarded
- )
- self.captcha_bypass_secret = args.captcha_bypass_secret
-
- @classmethod
- def add_arguments(cls, parser):
- super(CaptchaConfig, cls).add_arguments(parser)
- group = parser.add_argument_group("recaptcha")
- group.add_argument(
- "--recaptcha-private-key", type=str, default="YOUR_PRIVATE_KEY",
- help="The matching private key for the web client's public key."
- )
- group.add_argument(
- "--enable-registration-captcha", type=bool, default=False,
- help="Enables ReCaptcha checks when registering, preventing signup"
- + " unless a captcha is answered. Requires a valid ReCaptcha "
- + "public/private key."
- )
- group.add_argument(
- "--captcha_ip_origin_is_x_forwarded", type=bool, default=False,
- help="When checking captchas, use the X-Forwarded-For (XFF) header"
- + " as the client IP and not the actual client IP."
- )
- group.add_argument(
- "--captcha_bypass_secret", type=str,
- help="A secret key used to bypass the captcha test entirely."
+ config["captcha_ip_origin_is_x_forwarded"]
)
+ self.captcha_bypass_secret = config.get("captcha_bypass_secret")
+
+ def default_config(self, config_dir_path, server_name):
+ return """\
+ ## Captcha ##
+
+ # This Home Server's ReCAPTCHA public key.
+ recaptcha_private_key: "YOUR_PUBLIC_KEY"
+
+ # This Home Server's ReCAPTCHA private key.
+ recaptcha_public_key: "YOUR_PRIVATE_KEY"
+
+ # Enables ReCaptcha checks when registering, preventing signup
+ # unless a captcha is answered. Requires a valid ReCaptcha
+ # public/private key.
+ enable_registration_captcha: False
+
+ # When checking captchas, use the X-Forwarded-For (XFF) header
+ # as the client IP and not the actual client IP.
+ captcha_ip_origin_is_x_forwarded: False
+
+ # A secret key used to bypass the captcha test entirely.
+ #captcha_bypass_secret: "YOUR_SECRET_HERE"
+ """
diff --git a/synapse/config/database.py b/synapse/config/database.py
index 87efe54645..f0611e8884 100644
--- a/synapse/config/database.py
+++ b/synapse/config/database.py
@@ -14,32 +14,66 @@
# limitations under the License.
from ._base import Config
-import os
class DatabaseConfig(Config):
- def __init__(self, args):
- super(DatabaseConfig, self).__init__(args)
- if args.database_path == ":memory:":
- self.database_path = ":memory:"
+
+ def read_config(self, config):
+ self.event_cache_size = self.parse_size(
+ config.get("event_cache_size", "10K")
+ )
+
+ self.database_config = config.get("database")
+
+ if self.database_config is None:
+ self.database_config = {
+ "name": "sqlite3",
+ "args": {},
+ }
+
+ name = self.database_config.get("name", None)
+ if name == "psycopg2":
+ pass
+ elif name == "sqlite3":
+ self.database_config.setdefault("args", {}).update({
+ "cp_min": 1,
+ "cp_max": 1,
+ "check_same_thread": False,
+ })
else:
- self.database_path = self.abspath(args.database_path)
- self.event_cache_size = self.parse_size(args.event_cache_size)
+ raise RuntimeError("Unsupported database type '%s'" % (name,))
+
+ self.set_databasepath(config.get("database_path"))
+
+ def default_config(self, config, config_dir_path):
+ database_path = self.abspath("homeserver.db")
+ return """\
+ # Database configuration
+ database:
+ # The database engine name
+ name: "sqlite3"
+ # Arguments to pass to the engine
+ args:
+ # Path to the database
+ database: "%(database_path)s"
- @classmethod
- def add_arguments(cls, parser):
- super(DatabaseConfig, cls).add_arguments(parser)
+ # Number of events to cache in memory.
+ event_cache_size: "10K"
+ """ % locals()
+
+ def read_arguments(self, args):
+ self.set_databasepath(args.database_path)
+
+ def set_databasepath(self, database_path):
+ if database_path != ":memory:":
+ database_path = self.abspath(database_path)
+ if self.database_config.get("name", None) == "sqlite3":
+ if database_path is not None:
+ self.database_config["args"]["database"] = database_path
+
+ def add_arguments(self, parser):
db_group = parser.add_argument_group("database")
db_group.add_argument(
- "-d", "--database-path", default="homeserver.db",
- help="The database name."
+ "-d", "--database-path", metavar="SQLITE_DATABASE_PATH",
+ help="The path to a sqlite database to use."
)
- db_group.add_argument(
- "--event-cache-size", default="100K",
- help="Number of events to cache in memory."
- )
-
- @classmethod
- def generate_config(cls, args, config_dir_path):
- super(DatabaseConfig, cls).generate_config(args, config_dir_path)
- args.database_path = os.path.abspath(args.database_path)
diff --git a/synapse/config/email.py b/synapse/config/email.py
deleted file mode 100644
index f0854f8c37..0000000000
--- a/synapse/config/email.py
+++ /dev/null
@@ -1,42 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2014, 2015 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from ._base import Config
-
-
-class EmailConfig(Config):
-
- def __init__(self, args):
- super(EmailConfig, self).__init__(args)
- self.email_from_address = args.email_from_address
- self.email_smtp_server = args.email_smtp_server
-
- @classmethod
- def add_arguments(cls, parser):
- super(EmailConfig, cls).add_arguments(parser)
- email_group = parser.add_argument_group("email")
- email_group.add_argument(
- "--email-from-address",
- default="FROM@EXAMPLE.COM",
- help="The address to send emails from (e.g. for password resets)."
- )
- email_group.add_argument(
- "--email-smtp-server",
- default="",
- help=(
- "The SMTP server to send emails from (e.g. for password"
- " resets)."
- )
- )
diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py
index 241afdf872..fe0ccb6eb7 100644
--- a/synapse/config/homeserver.py
+++ b/synapse/config/homeserver.py
@@ -20,19 +20,22 @@ from .database import DatabaseConfig
from .ratelimiting import RatelimitConfig
from .repository import ContentRepositoryConfig
from .captcha import CaptchaConfig
-from .email import EmailConfig
from .voip import VoipConfig
from .registration import RegistrationConfig
from .metrics import MetricsConfig
+from .appservice import AppServiceConfig
+from .key import KeyConfig
class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
RatelimitConfig, ContentRepositoryConfig, CaptchaConfig,
- EmailConfig, VoipConfig, RegistrationConfig,
- MetricsConfig,):
+ VoipConfig, RegistrationConfig,
+ MetricsConfig, AppServiceConfig, KeyConfig,):
pass
if __name__ == '__main__':
import sys
- HomeServerConfig.load_config("Generate config", sys.argv[1:], "HomeServer")
+ sys.stdout.write(
+ HomeServerConfig().generate_config(sys.argv[1], sys.argv[2])[0]
+ )
diff --git a/synapse/config/key.py b/synapse/config/key.py
new file mode 100644
index 0000000000..0494c0cb77
--- /dev/null
+++ b/synapse/config/key.py
@@ -0,0 +1,133 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from ._base import Config, ConfigError
+import syutil.crypto.signing_key
+from syutil.crypto.signing_key import (
+ is_signing_algorithm_supported, decode_verify_key_bytes
+)
+from syutil.base64util import decode_base64
+from synapse.util.stringutils import random_string
+
+
+class KeyConfig(Config):
+
+ def read_config(self, config):
+ self.signing_key = self.read_signing_key(config["signing_key_path"])
+ self.old_signing_keys = self.read_old_signing_keys(
+ config["old_signing_keys"]
+ )
+ self.key_refresh_interval = self.parse_duration(
+ config["key_refresh_interval"]
+ )
+ self.perspectives = self.read_perspectives(
+ config["perspectives"]
+ )
+
+ def default_config(self, config_dir_path, server_name):
+ base_key_name = os.path.join(config_dir_path, server_name)
+ return """\
+ ## Signing Keys ##
+
+ # Path to the signing key to sign messages with
+ signing_key_path: "%(base_key_name)s.signing.key"
+
+ # The keys that the server used to sign messages with but won't use
+ # to sign new messages. E.g. it has lost its private key
+ old_signing_keys: {}
+ # "ed25519:auto":
+ # # Base64 encoded public key
+ # key: "The public part of your old signing key."
+ # # Millisecond POSIX timestamp when the key expired.
+ # expired_ts: 123456789123
+
+ # How long key response published by this server is valid for.
+ # Used to set the valid_until_ts in /key/v2 APIs.
+ # Determines how quickly servers will query to check which keys
+ # are still valid.
+ key_refresh_interval: "1d" # 1 Day.
+
+ # The trusted servers to download signing keys from.
+ perspectives:
+ servers:
+ "matrix.org":
+ verify_keys:
+ "ed25519:auto":
+ key: "Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw"
+ """ % locals()
+
+ def read_perspectives(self, perspectives_config):
+ servers = {}
+ for server_name, server_config in perspectives_config["servers"].items():
+ for key_id, key_data in server_config["verify_keys"].items():
+ if is_signing_algorithm_supported(key_id):
+ key_base64 = key_data["key"]
+ key_bytes = decode_base64(key_base64)
+ verify_key = decode_verify_key_bytes(key_id, key_bytes)
+ servers.setdefault(server_name, {})[key_id] = verify_key
+ return servers
+
+ def read_signing_key(self, signing_key_path):
+ signing_keys = self.read_file(signing_key_path, "signing_key")
+ try:
+ return syutil.crypto.signing_key.read_signing_keys(
+ signing_keys.splitlines(True)
+ )
+ except Exception:
+ raise ConfigError(
+ "Error reading signing_key."
+ " Try running again with --generate-config"
+ )
+
+ def read_old_signing_keys(self, old_signing_keys):
+ keys = {}
+ for key_id, key_data in old_signing_keys.items():
+ if is_signing_algorithm_supported(key_id):
+ key_base64 = key_data["key"]
+ key_bytes = decode_base64(key_base64)
+ verify_key = decode_verify_key_bytes(key_id, key_bytes)
+ verify_key.expired_ts = key_data["expired_ts"]
+ keys[key_id] = verify_key
+ else:
+ raise ConfigError(
+ "Unsupported signing algorithm for old key: %r" % (key_id,)
+ )
+ return keys
+
+ def generate_files(self, config):
+ signing_key_path = config["signing_key_path"]
+ if not os.path.exists(signing_key_path):
+ with open(signing_key_path, "w") as signing_key_file:
+ key_id = "a_" + random_string(4)
+ syutil.crypto.signing_key.write_signing_keys(
+ signing_key_file,
+ (syutil.crypto.signing_key.generate_signing_key(key_id),),
+ )
+ else:
+ signing_keys = self.read_file(signing_key_path, "signing_key")
+ if len(signing_keys.split("\n")[0].split()) == 1:
+ # handle keys in the old format.
+ key_id = "a_" + random_string(4)
+ key = syutil.crypto.signing_key.decode_signing_key_base64(
+ syutil.crypto.signing_key.NACL_ED25519,
+ key_id,
+ signing_keys.split("\n")[0]
+ )
+ with open(signing_key_path, "w") as signing_key_file:
+ syutil.crypto.signing_key.write_signing_keys(
+ signing_key_file,
+ (key,),
+ )
diff --git a/synapse/config/logger.py b/synapse/config/logger.py
index 63c8e36930..fa542623b7 100644
--- a/synapse/config/logger.py
+++ b/synapse/config/logger.py
@@ -19,25 +19,88 @@ from twisted.python.log import PythonLoggingObserver
import logging
import logging.config
import yaml
+from string import Template
+import os
+
+
+DEFAULT_LOG_CONFIG = Template("""
+version: 1
+
+formatters:
+ precise:
+ format: '%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s\
+- %(message)s'
+
+filters:
+ context:
+ (): synapse.util.logcontext.LoggingContextFilter
+ request: ""
+
+handlers:
+ file:
+ class: logging.handlers.RotatingFileHandler
+ formatter: precise
+ filename: ${log_file}
+ maxBytes: 104857600
+ backupCount: 10
+ filters: [context]
+ level: INFO
+ console:
+ class: logging.StreamHandler
+ formatter: precise
+
+loggers:
+ synapse:
+ level: INFO
+
+ synapse.storage.SQL:
+ level: INFO
+
+root:
+ level: INFO
+ handlers: [file, console]
+""")
class LoggingConfig(Config):
- def __init__(self, args):
- super(LoggingConfig, self).__init__(args)
- self.verbosity = int(args.verbose) if args.verbose else None
- self.log_config = self.abspath(args.log_config)
- self.log_file = self.abspath(args.log_file)
- @classmethod
+ def read_config(self, config):
+ self.verbosity = config.get("verbose", 0)
+ self.log_config = self.abspath(config.get("log_config"))
+ self.log_file = self.abspath(config.get("log_file"))
+
+ def default_config(self, config_dir_path, server_name):
+ log_file = self.abspath("homeserver.log")
+ log_config = self.abspath(
+ os.path.join(config_dir_path, server_name + ".log.config")
+ )
+ return """
+ # Logging verbosity level.
+ verbose: 0
+
+ # File to write logging to
+ log_file: "%(log_file)s"
+
+ # A yaml python logging config file
+ log_config: "%(log_config)s"
+ """ % locals()
+
+ def read_arguments(self, args):
+ if args.verbose is not None:
+ self.verbosity = args.verbose
+ if args.log_config is not None:
+ self.log_config = args.log_config
+ if args.log_file is not None:
+ self.log_file = args.log_file
+
def add_arguments(cls, parser):
- super(LoggingConfig, cls).add_arguments(parser)
logging_group = parser.add_argument_group("logging")
logging_group.add_argument(
'-v', '--verbose', dest="verbose", action='count',
help="The verbosity level."
)
logging_group.add_argument(
- '-f', '--log-file', dest="log_file", default="homeserver.log",
+ '-f', '--log-file', dest="log_file",
help="File to log to."
)
logging_group.add_argument(
@@ -45,6 +108,14 @@ class LoggingConfig(Config):
help="Python logging config file"
)
+ def generate_files(self, config):
+ log_config = config.get("log_config")
+ if log_config and not os.path.exists(log_config):
+ with open(log_config, "wb") as log_config_file:
+ log_config_file.write(
+ DEFAULT_LOG_CONFIG.substitute(log_file=config["log_file"])
+ )
+
def setup_logging(self):
log_format = (
"%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s"
@@ -78,7 +149,6 @@ class LoggingConfig(Config):
handler.addFilter(LoggingContextFilter(request=""))
logger.addHandler(handler)
- logger.info("Test")
else:
with open(self.log_config, 'r') as f:
logging.config.dictConfig(yaml.load(f))
diff --git a/synapse/config/metrics.py b/synapse/config/metrics.py
index 901a429c76..71a1b1d189 100644
--- a/synapse/config/metrics.py
+++ b/synapse/config/metrics.py
@@ -17,20 +17,17 @@ from ._base import Config
class MetricsConfig(Config):
- def __init__(self, args):
- super(MetricsConfig, self).__init__(args)
- self.enable_metrics = args.enable_metrics
- self.metrics_port = args.metrics_port
+ def read_config(self, config):
+ self.enable_metrics = config["enable_metrics"]
+ self.metrics_port = config.get("metrics_port")
- @classmethod
- def add_arguments(cls, parser):
- super(MetricsConfig, cls).add_arguments(parser)
- metrics_group = parser.add_argument_group("metrics")
- metrics_group.add_argument(
- '--enable-metrics', dest="enable_metrics", action="store_true",
- help="Enable collection and rendering of performance metrics"
- )
- metrics_group.add_argument(
- '--metrics-port', metavar="PORT", type=int,
- help="Separate port to accept metrics requests on (on localhost)"
- )
+ def default_config(self, config_dir_path, server_name):
+ return """\
+ ## Metrics ###
+
+ # Enable collection and rendering of performance metrics
+ enable_metrics: False
+
+ # Separate port to accept metrics requests on (on localhost)
+ # metrics_port: 8081
+ """
diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py
index 862c07ef8c..76d9970e5b 100644
--- a/synapse/config/ratelimiting.py
+++ b/synapse/config/ratelimiting.py
@@ -17,56 +17,42 @@ from ._base import Config
class RatelimitConfig(Config):
- def __init__(self, args):
- super(RatelimitConfig, self).__init__(args)
- self.rc_messages_per_second = args.rc_messages_per_second
- self.rc_message_burst_count = args.rc_message_burst_count
+ def read_config(self, config):
+ self.rc_messages_per_second = config["rc_messages_per_second"]
+ self.rc_message_burst_count = config["rc_message_burst_count"]
- self.federation_rc_window_size = args.federation_rc_window_size
- self.federation_rc_sleep_limit = args.federation_rc_sleep_limit
- self.federation_rc_sleep_delay = args.federation_rc_sleep_delay
- self.federation_rc_reject_limit = args.federation_rc_reject_limit
- self.federation_rc_concurrent = args.federation_rc_concurrent
+ self.federation_rc_window_size = config["federation_rc_window_size"]
+ self.federation_rc_sleep_limit = config["federation_rc_sleep_limit"]
+ self.federation_rc_sleep_delay = config["federation_rc_sleep_delay"]
+ self.federation_rc_reject_limit = config["federation_rc_reject_limit"]
+ self.federation_rc_concurrent = config["federation_rc_concurrent"]
- @classmethod
- def add_arguments(cls, parser):
- super(RatelimitConfig, cls).add_arguments(parser)
- rc_group = parser.add_argument_group("ratelimiting")
- rc_group.add_argument(
- "--rc-messages-per-second", type=float, default=0.2,
- help="number of messages a client can send per second"
- )
- rc_group.add_argument(
- "--rc-message-burst-count", type=float, default=10,
- help="number of message a client can send before being throttled"
- )
+ def default_config(self, config_dir_path, server_name):
+ return """\
+ ## Ratelimiting ##
- rc_group.add_argument(
- "--federation-rc-window-size", type=int, default=10000,
- help="The federation window size in milliseconds",
- )
+ # Number of messages a client can send per second
+ rc_messages_per_second: 0.2
- rc_group.add_argument(
- "--federation-rc-sleep-limit", type=int, default=10,
- help="The number of federation requests from a single server"
- " in a window before the server will delay processing the"
- " request.",
- )
+ # Number of message a client can send before being throttled
+ rc_message_burst_count: 10.0
- rc_group.add_argument(
- "--federation-rc-sleep-delay", type=int, default=500,
- help="The duration in milliseconds to delay processing events from"
- " remote servers by if they go over the sleep limit.",
- )
+ # The federation window size in milliseconds
+ federation_rc_window_size: 1000
- rc_group.add_argument(
- "--federation-rc-reject-limit", type=int, default=50,
- help="The maximum number of concurrent federation requests allowed"
- " from a single server",
- )
+ # The number of federation requests from a single server in a window
+ # before the server will delay processing the request.
+ federation_rc_sleep_limit: 10
- rc_group.add_argument(
- "--federation-rc-concurrent", type=int, default=3,
- help="The number of federation requests to concurrently process"
- " from a single server",
- )
+ # The duration in milliseconds to delay processing events from
+ # remote servers by if they go over the sleep limit.
+ federation_rc_sleep_delay: 500
+
+ # The maximum number of concurrent federation requests allowed
+ # from a single server
+ federation_rc_reject_limit: 50
+
+ # The number of federation requests to concurrently process from a
+ # single server
+ federation_rc_concurrent: 3
+ """
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index 4401e774d1..b39989a87f 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -17,44 +17,44 @@ from ._base import Config
from synapse.util.stringutils import random_string_with_symbols
-import distutils.util
+from distutils.util import strtobool
class RegistrationConfig(Config):
- def __init__(self, args):
- super(RegistrationConfig, self).__init__(args)
-
- # `args.disable_registration` may either be a bool or a string depending
- # on if the option was given a value (e.g. --disable-registration=false
- # would set `args.disable_registration` to "false" not False.)
- self.disable_registration = bool(
- distutils.util.strtobool(str(args.disable_registration))
+ def read_config(self, config):
+ self.disable_registration = not bool(
+ strtobool(str(config["enable_registration"]))
)
- self.registration_shared_secret = args.registration_shared_secret
+ if "disable_registration" in config:
+ self.disable_registration = bool(
+ strtobool(str(config["disable_registration"]))
+ )
- @classmethod
- def add_arguments(cls, parser):
- super(RegistrationConfig, cls).add_arguments(parser)
- reg_group = parser.add_argument_group("registration")
+ self.registration_shared_secret = config.get("registration_shared_secret")
+ def default_config(self, config_dir, server_name):
+ registration_shared_secret = random_string_with_symbols(50)
+ return """\
+ ## Registration ##
+
+ # Enable registration for new users.
+ enable_registration: True
+
+ # If set, allows registration by anyone who also has the shared
+ # secret, even if registration is otherwise disabled.
+ registration_shared_secret: "%(registration_shared_secret)s"
+ """ % locals()
+
+ def add_arguments(self, parser):
+ reg_group = parser.add_argument_group("registration")
reg_group.add_argument(
- "--disable-registration",
- const=True,
- default=True,
- nargs='?',
- help="Disable registration of new users.",
- )
- reg_group.add_argument(
- "--registration-shared-secret", type=str,
- help="If set, allows registration by anyone who also has the shared"
- " secret, even if registration is otherwise disabled.",
+ "--enable-registration", action="store_true", default=None,
+ help="Enable registration for new users."
)
- @classmethod
- def generate_config(cls, args, config_dir_path):
- if args.disable_registration is None:
- args.disable_registration = True
-
- if args.registration_shared_secret is None:
- args.registration_shared_secret = random_string_with_symbols(50)
+ def read_arguments(self, args):
+ if args.enable_registration is not None:
+ self.disable_registration = not bool(
+ strtobool(str(args.enable_registration))
+ )
diff --git a/synapse/config/repository.py b/synapse/config/repository.py
index e1827f05e4..adaf4e4bb2 100644
--- a/synapse/config/repository.py
+++ b/synapse/config/repository.py
@@ -17,32 +17,20 @@ from ._base import Config
class ContentRepositoryConfig(Config):
- def __init__(self, args):
- super(ContentRepositoryConfig, self).__init__(args)
- self.max_upload_size = self.parse_size(args.max_upload_size)
- self.max_image_pixels = self.parse_size(args.max_image_pixels)
- self.media_store_path = self.ensure_directory(args.media_store_path)
+ def read_config(self, config):
+ self.max_upload_size = self.parse_size(config["max_upload_size"])
+ self.max_image_pixels = self.parse_size(config["max_image_pixels"])
+ self.media_store_path = self.ensure_directory(config["media_store_path"])
- def parse_size(self, string):
- sizes = {"K": 1024, "M": 1024 * 1024}
- size = 1
- suffix = string[-1]
- if suffix in sizes:
- string = string[:-1]
- size = sizes[suffix]
- return int(string) * size
+ def default_config(self, config_dir_path, server_name):
+ media_store = self.default_path("media_store")
+ return """
+ # Directory where uploaded images and attachments are stored.
+ media_store_path: "%(media_store)s"
- @classmethod
- def add_arguments(cls, parser):
- super(ContentRepositoryConfig, cls).add_arguments(parser)
- db_group = parser.add_argument_group("content_repository")
- db_group.add_argument(
- "--max-upload-size", default="10M"
- )
- db_group.add_argument(
- "--media-store-path", default=cls.default_path("media_store")
- )
- db_group.add_argument(
- "--max-image-pixels", default="32M",
- help="Maximum number of pixels that will be thumbnailed"
- )
+ # The largest allowed upload size in bytes
+ max_upload_size: "10M"
+
+ # Maximum number of pixels that will be thumbnailed
+ max_image_pixels: "32M"
+ """ % locals()
diff --git a/synapse/config/server.py b/synapse/config/server.py
index 58a828cc4c..78195b3a4f 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -13,116 +13,92 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import os
-from ._base import Config, ConfigError
-import syutil.crypto.signing_key
+from ._base import Config
class ServerConfig(Config):
- def __init__(self, args):
- super(ServerConfig, self).__init__(args)
- self.server_name = args.server_name
- self.signing_key = self.read_signing_key(args.signing_key_path)
- self.bind_port = args.bind_port
- self.bind_host = args.bind_host
- self.unsecure_port = args.unsecure_port
- self.daemonize = args.daemonize
- self.pid_file = self.abspath(args.pid_file)
- self.web_client = args.web_client
- self.manhole = args.manhole
- self.soft_file_limit = args.soft_file_limit
-
- if not args.content_addr:
- host = args.server_name
+
+ def read_config(self, config):
+ self.server_name = config["server_name"]
+ self.bind_port = config["bind_port"]
+ self.bind_host = config["bind_host"]
+ self.unsecure_port = config["unsecure_port"]
+ self.manhole = config.get("manhole")
+ self.pid_file = self.abspath(config.get("pid_file"))
+ self.web_client = config["web_client"]
+ self.soft_file_limit = config["soft_file_limit"]
+ self.daemonize = config.get("daemonize")
+
+ # Attempt to guess the content_addr for the v0 content repostitory
+ content_addr = config.get("content_addr")
+ if not content_addr:
+ host = self.server_name
if ':' not in host:
- host = "%s:%d" % (host, args.unsecure_port)
+ host = "%s:%d" % (host, self.unsecure_port)
else:
host = host.split(':')[0]
- host = "%s:%d" % (host, args.unsecure_port)
- args.content_addr = "http://%s" % (host,)
+ host = "%s:%d" % (host, self.unsecure_port)
+ content_addr = "http://%s" % (host,)
+
+ self.content_addr = content_addr
+
+ def default_config(self, config_dir_path, server_name):
+ if ":" in server_name:
+ bind_port = int(server_name.split(":")[1])
+ unsecure_port = bind_port - 400
+ else:
+ bind_port = 8448
+ unsecure_port = 8008
+
+ pid_file = self.abspath("homeserver.pid")
+ return """\
+ ## Server ##
+
+ # The domain name of the server, with optional explicit port.
+ # This is used by remote servers to connect to this server,
+ # e.g. matrix.org, localhost:8080, etc.
+ server_name: "%(server_name)s"
+
+ # The port to listen for HTTPS requests on.
+ # For when matrix traffic is sent directly to synapse.
+ bind_port: %(bind_port)s
- self.content_addr = args.content_addr
+ # The port to listen for HTTP requests on.
+ # For when matrix traffic passes through loadbalancer that unwraps TLS.
+ unsecure_port: %(unsecure_port)s
- @classmethod
- def add_arguments(cls, parser):
- super(ServerConfig, cls).add_arguments(parser)
+ # Local interface to listen on.
+ # The empty string will cause synapse to listen on all interfaces.
+ bind_host: ""
+
+ # When running as a daemon, the file to store the pid in
+ pid_file: %(pid_file)s
+
+ # Whether to serve a web client from the HTTP/HTTPS root resource.
+ web_client: True
+
+ # Set the soft limit on the number of file descriptors synapse can use
+ # Zero is used to indicate synapse should set the soft limit to the
+ # hard limit.
+ soft_file_limit: 0
+
+ # Turn on the twisted telnet manhole service on localhost on the given
+ # port.
+ #manhole: 9000
+ """ % locals()
+
+ def read_arguments(self, args):
+ if args.manhole is not None:
+ self.manhole = args.manhole
+ if args.daemonize is not None:
+ self.daemonize = args.daemonize
+
+ def add_arguments(self, parser):
server_group = parser.add_argument_group("server")
- server_group.add_argument(
- "-H", "--server-name", default="localhost",
- help="The domain name of the server, with optional explicit port. "
- "This is used by remote servers to connect to this server, "
- "e.g. matrix.org, localhost:8080, etc."
- )
- server_group.add_argument("--signing-key-path",
- help="The signing key to sign messages with")
- server_group.add_argument("-p", "--bind-port", metavar="PORT",
- type=int, help="https port to listen on",
- default=8448)
- server_group.add_argument("--unsecure-port", metavar="PORT",
- type=int, help="http port to listen on",
- default=8008)
- server_group.add_argument("--bind-host", default="",
- help="Local interface to listen on")
server_group.add_argument("-D", "--daemonize", action='store_true',
+ default=None,
help="Daemonize the home server")
- server_group.add_argument('--pid-file', default="homeserver.pid",
- help="When running as a daemon, the file to"
- " store the pid in")
- server_group.add_argument('--web_client', default=True, type=bool,
- help="Whether or not to serve a web client")
server_group.add_argument("--manhole", metavar="PORT", dest="manhole",
type=int,
help="Turn on the twisted telnet manhole"
" service on the given port.")
- server_group.add_argument("--content-addr", default=None,
- help="The host and scheme to use for the "
- "content repository")
- server_group.add_argument("--soft-file-limit", type=int, default=0,
- help="Set the soft limit on the number of "
- "file descriptors synapse can use. "
- "Zero is used to indicate synapse "
- "should set the soft limit to the hard"
- "limit.")
-
- def read_signing_key(self, signing_key_path):
- signing_keys = self.read_file(signing_key_path, "signing_key")
- try:
- return syutil.crypto.signing_key.read_signing_keys(
- signing_keys.splitlines(True)
- )
- except Exception:
- raise ConfigError(
- "Error reading signing_key."
- " Try running again with --generate-config"
- )
-
- @classmethod
- def generate_config(cls, args, config_dir_path):
- super(ServerConfig, cls).generate_config(args, config_dir_path)
- base_key_name = os.path.join(config_dir_path, args.server_name)
-
- args.pid_file = os.path.abspath(args.pid_file)
-
- if not args.signing_key_path:
- args.signing_key_path = base_key_name + ".signing.key"
-
- if not os.path.exists(args.signing_key_path):
- with open(args.signing_key_path, "w") as signing_key_file:
- syutil.crypto.signing_key.write_signing_keys(
- signing_key_file,
- (syutil.crypto.signing_key.generate_singing_key("auto"),),
- )
- else:
- signing_keys = cls.read_file(args.signing_key_path, "signing_key")
- if len(signing_keys.split("\n")[0].split()) == 1:
- # handle keys in the old format.
- key = syutil.crypto.signing_key.decode_signing_key_base64(
- syutil.crypto.signing_key.NACL_ED25519,
- "auto",
- signing_keys.split("\n")[0]
- )
- with open(args.signing_key_path, "w") as signing_key_file:
- syutil.crypto.signing_key.write_signing_keys(
- signing_key_file,
- (key,),
- )
diff --git a/synapse/config/tls.py b/synapse/config/tls.py
index 034f9a7bf0..ecb2d42c1f 100644
--- a/synapse/config/tls.py
+++ b/synapse/config/tls.py
@@ -23,37 +23,44 @@ GENERATE_DH_PARAMS = False
class TlsConfig(Config):
- def __init__(self, args):
- super(TlsConfig, self).__init__(args)
+ def read_config(self, config):
self.tls_certificate = self.read_tls_certificate(
- args.tls_certificate_path
+ config.get("tls_certificate_path")
)
- self.no_tls = args.no_tls
+ self.no_tls = config.get("no_tls", False)
if self.no_tls:
self.tls_private_key = None
else:
self.tls_private_key = self.read_tls_private_key(
- args.tls_private_key_path
+ config.get("tls_private_key_path")
)
self.tls_dh_params_path = self.check_file(
- args.tls_dh_params_path, "tls_dh_params"
+ config.get("tls_dh_params_path"), "tls_dh_params"
)
- @classmethod
- def add_arguments(cls, parser):
- super(TlsConfig, cls).add_arguments(parser)
- tls_group = parser.add_argument_group("tls")
- tls_group.add_argument("--tls-certificate-path",
- help="PEM encoded X509 certificate for TLS")
- tls_group.add_argument("--tls-private-key-path",
- help="PEM encoded private key for TLS")
- tls_group.add_argument("--tls-dh-params-path",
- help="PEM dh parameters for ephemeral keys")
- tls_group.add_argument("--no-tls", action='store_true',
- help="Don't bind to the https port.")
+ def default_config(self, config_dir_path, server_name):
+ base_key_name = os.path.join(config_dir_path, server_name)
+
+ tls_certificate_path = base_key_name + ".tls.crt"
+ tls_private_key_path = base_key_name + ".tls.key"
+ tls_dh_params_path = base_key_name + ".tls.dh"
+
+ return """\
+ # PEM encoded X509 certificate for TLS
+ tls_certificate_path: "%(tls_certificate_path)s"
+
+ # PEM encoded private key for TLS
+ tls_private_key_path: "%(tls_private_key_path)s"
+
+ # PEM dh parameters for ephemeral keys
+ tls_dh_params_path: "%(tls_dh_params_path)s"
+
+ # Don't bind to the https port
+ no_tls: False
+ """ % locals()
def read_tls_certificate(self, cert_path):
cert_pem = self.read_file(cert_path, "tls_certificate")
@@ -63,22 +70,13 @@ class TlsConfig(Config):
private_key_pem = self.read_file(private_key_path, "tls_private_key")
return crypto.load_privatekey(crypto.FILETYPE_PEM, private_key_pem)
- @classmethod
- def generate_config(cls, args, config_dir_path):
- super(TlsConfig, cls).generate_config(args, config_dir_path)
- base_key_name = os.path.join(config_dir_path, args.server_name)
-
- if args.tls_certificate_path is None:
- args.tls_certificate_path = base_key_name + ".tls.crt"
-
- if args.tls_private_key_path is None:
- args.tls_private_key_path = base_key_name + ".tls.key"
-
- if args.tls_dh_params_path is None:
- args.tls_dh_params_path = base_key_name + ".tls.dh"
+ def generate_files(self, config):
+ tls_certificate_path = config["tls_certificate_path"]
+ tls_private_key_path = config["tls_private_key_path"]
+ tls_dh_params_path = config["tls_dh_params_path"]
- if not os.path.exists(args.tls_private_key_path):
- with open(args.tls_private_key_path, "w") as private_key_file:
+ if not os.path.exists(tls_private_key_path):
+ with open(tls_private_key_path, "w") as private_key_file:
tls_private_key = crypto.PKey()
tls_private_key.generate_key(crypto.TYPE_RSA, 2048)
private_key_pem = crypto.dump_privatekey(
@@ -86,17 +84,17 @@ class TlsConfig(Config):
)
private_key_file.write(private_key_pem)
else:
- with open(args.tls_private_key_path) as private_key_file:
+ with open(tls_private_key_path) as private_key_file:
private_key_pem = private_key_file.read()
tls_private_key = crypto.load_privatekey(
crypto.FILETYPE_PEM, private_key_pem
)
- if not os.path.exists(args.tls_certificate_path):
- with open(args.tls_certificate_path, "w") as certifcate_file:
+ if not os.path.exists(tls_certificate_path):
+ with open(tls_certificate_path, "w") as certifcate_file:
cert = crypto.X509()
subject = cert.get_subject()
- subject.CN = args.server_name
+ subject.CN = config["server_name"]
cert.set_serial_number(1000)
cert.gmtime_adj_notBefore(0)
@@ -110,16 +108,16 @@ class TlsConfig(Config):
certifcate_file.write(cert_pem)
- if not os.path.exists(args.tls_dh_params_path):
+ if not os.path.exists(tls_dh_params_path):
if GENERATE_DH_PARAMS:
subprocess.check_call([
"openssl", "dhparam",
"-outform", "PEM",
- "-out", args.tls_dh_params_path,
+ "-out", tls_dh_params_path,
"2048"
])
else:
- with open(args.tls_dh_params_path, "w") as dh_params_file:
+ with open(tls_dh_params_path, "w") as dh_params_file:
dh_params_file.write(
"2048-bit DH parameters taken from rfc3526\n"
"-----BEGIN DH PARAMETERS-----\n"
diff --git a/synapse/config/voip.py b/synapse/config/voip.py
index 65162d21b7..a1707223d3 100644
--- a/synapse/config/voip.py
+++ b/synapse/config/voip.py
@@ -17,28 +17,21 @@ from ._base import Config
class VoipConfig(Config):
- def __init__(self, args):
- super(VoipConfig, self).__init__(args)
- self.turn_uris = args.turn_uris
- self.turn_shared_secret = args.turn_shared_secret
- self.turn_user_lifetime = args.turn_user_lifetime
-
- @classmethod
- def add_arguments(cls, parser):
- super(VoipConfig, cls).add_arguments(parser)
- group = parser.add_argument_group("voip")
- group.add_argument(
- "--turn-uris", type=str, default=None, action='append',
- help="The public URIs of the TURN server to give to clients"
- )
- group.add_argument(
- "--turn-shared-secret", type=str, default=None,
- help=(
- "The shared secret used to compute passwords for the TURN"
- " server"
- )
- )
- group.add_argument(
- "--turn-user-lifetime", type=int, default=(1000 * 60 * 60),
- help="How long generated TURN credentials last, in ms"
- )
+ def read_config(self, config):
+ self.turn_uris = config.get("turn_uris", [])
+ self.turn_shared_secret = config["turn_shared_secret"]
+ self.turn_user_lifetime = self.parse_duration(config["turn_user_lifetime"])
+
+ def default_config(self, config_dir_path, server_name):
+ return """\
+ ## Turn ##
+
+ # The public URIs of the TURN server to give to clients
+ turn_uris: []
+
+ # The shared secret used to compute passwords for the TURN server
+ turn_shared_secret: "YOUR_SHARED_SECRET"
+
+ # How long generated TURN credentials last
+ turn_user_lifetime: "1h"
+ """
diff --git a/synapse/crypto/keyclient.py b/synapse/crypto/keyclient.py
index 74008347c3..4911f0896b 100644
--- a/synapse/crypto/keyclient.py
+++ b/synapse/crypto/keyclient.py
@@ -25,12 +25,15 @@ import logging
logger = logging.getLogger(__name__)
+KEY_API_V1 = b"/_matrix/key/v1/"
+
@defer.inlineCallbacks
-def fetch_server_key(server_name, ssl_context_factory):
+def fetch_server_key(server_name, ssl_context_factory, path=KEY_API_V1):
"""Fetch the keys for a remote server."""
factory = SynapseKeyClientFactory()
+ factory.path = path
endpoint = matrix_federation_endpoint(
reactor, server_name, ssl_context_factory, timeout=30
)
@@ -42,13 +45,19 @@ def fetch_server_key(server_name, ssl_context_factory):
server_response, server_certificate = yield protocol.remote_key
defer.returnValue((server_response, server_certificate))
return
+ except SynapseKeyClientError as e:
+ logger.exception("Error getting key for %r" % (server_name,))
+ if e.status.startswith("4"):
+ # Don't retry for 4xx responses.
+ raise IOError("Cannot get key for %r" % server_name)
except Exception as e:
logger.exception(e)
- raise IOError("Cannot get key for %s" % server_name)
+ raise IOError("Cannot get key for %r" % server_name)
class SynapseKeyClientError(Exception):
"""The key wasn't retrieved from the remote server."""
+ status = None
pass
@@ -66,17 +75,30 @@ class SynapseKeyClientProtocol(HTTPClient):
def connectionMade(self):
self.host = self.transport.getHost()
logger.debug("Connected to %s", self.host)
- self.sendCommand(b"GET", b"/_matrix/key/v1/")
+ self.sendCommand(b"GET", self.path)
self.endHeaders()
self.timer = reactor.callLater(
self.timeout,
self.on_timeout
)
+ def errback(self, error):
+ if not self.remote_key.called:
+ self.remote_key.errback(error)
+
+ def callback(self, result):
+ if not self.remote_key.called:
+ self.remote_key.callback(result)
+
def handleStatus(self, version, status, message):
if status != b"200":
# logger.info("Non-200 response from %s: %s %s",
# self.transport.getHost(), status, message)
+ error = SynapseKeyClientError(
+ "Non-200 response %r from %r" % (status, self.host)
+ )
+ error.status = status
+ self.errback(error)
self.transport.abortConnection()
def handleResponse(self, response_body_bytes):
@@ -89,15 +111,18 @@ class SynapseKeyClientProtocol(HTTPClient):
return
certificate = self.transport.getPeerCertificate()
- self.remote_key.callback((json_response, certificate))
+ self.callback((json_response, certificate))
self.transport.abortConnection()
self.timer.cancel()
def on_timeout(self):
logger.debug("Timeout waiting for response from %s", self.host)
- self.remote_key.errback(IOError("Timeout waiting for response"))
+ self.errback(IOError("Timeout waiting for response"))
self.transport.abortConnection()
class SynapseKeyClientFactory(Factory):
- protocol = SynapseKeyClientProtocol
+ def protocol(self):
+ protocol = SynapseKeyClientProtocol()
+ protocol.path = self.path
+ return protocol
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index f4db7b8a05..8709394b97 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -15,7 +15,9 @@
from synapse.crypto.keyclient import fetch_server_key
from twisted.internet import defer
-from syutil.crypto.jsonsign import verify_signed_json, signature_ids
+from syutil.crypto.jsonsign import (
+ verify_signed_json, signature_ids, sign_json, encode_canonical_json
+)
from syutil.crypto.signing_key import (
is_signing_algorithm_supported, decode_verify_key_bytes
)
@@ -24,8 +26,12 @@ from synapse.api.errors import SynapseError, Codes
from synapse.util.retryutils import get_retry_limiter
+from synapse.util.async import create_observer
+
from OpenSSL import crypto
+import urllib
+import hashlib
import logging
@@ -36,8 +42,13 @@ class Keyring(object):
def __init__(self, hs):
self.store = hs.get_datastore()
self.clock = hs.get_clock()
+ self.client = hs.get_http_client()
+ self.config = hs.get_config()
+ self.perspective_servers = self.config.perspectives
self.hs = hs
+ self.key_downloads = {}
+
@defer.inlineCallbacks
def verify_json_for_server(self, server_name, json_object):
logger.debug("Verifying for %s", server_name)
@@ -85,19 +96,56 @@ class Keyring(object):
@defer.inlineCallbacks
def get_server_verify_key(self, server_name, key_ids):
"""Finds a verification key for the server with one of the key ids.
+ Trys to fetch the key from a trusted perspective server first.
Args:
- server_name (str): The name of the server to fetch a key for.
+ server_name(str): The name of the server to fetch a key for.
keys_ids (list of str): The key_ids to check for.
"""
-
- # Check the datastore to see if we have one cached.
cached = yield self.store.get_server_verify_keys(server_name, key_ids)
if cached:
defer.returnValue(cached[0])
return
- # Try to fetch the key from the remote server.
+ download = self.key_downloads.get(server_name)
+
+ if download is None:
+ download = self._get_server_verify_key_impl(server_name, key_ids)
+ self.key_downloads[server_name] = download
+
+ @download.addBoth
+ def callback(ret):
+ del self.key_downloads[server_name]
+ return ret
+
+ r = yield create_observer(download)
+ defer.returnValue(r)
+
+ @defer.inlineCallbacks
+ def _get_server_verify_key_impl(self, server_name, key_ids):
+ keys = None
+
+ perspective_results = []
+ for perspective_name, perspective_keys in self.perspective_servers.items():
+ @defer.inlineCallbacks
+ def get_key():
+ try:
+ result = yield self.get_server_verify_key_v2_indirect(
+ server_name, key_ids, perspective_name, perspective_keys
+ )
+ defer.returnValue(result)
+ except:
+ logging.info(
+ "Unable to getting key %r for %r from %r",
+ key_ids, server_name, perspective_name,
+ )
+ perspective_results.append(get_key())
+
+ perspective_results = yield defer.gatherResults(perspective_results)
+
+ for results in perspective_results:
+ if results is not None:
+ keys = results
limiter = yield get_retry_limiter(
server_name,
@@ -106,10 +154,234 @@ class Keyring(object):
)
with limiter:
+ if keys is None:
+ try:
+ keys = yield self.get_server_verify_key_v2_direct(
+ server_name, key_ids
+ )
+ except:
+ pass
+
+ keys = yield self.get_server_verify_key_v1_direct(
+ server_name, key_ids
+ )
+
+ for key_id in key_ids:
+ if key_id in keys:
+ defer.returnValue(keys[key_id])
+ return
+ raise ValueError("No verification key found for given key ids")
+
+ @defer.inlineCallbacks
+ def get_server_verify_key_v2_indirect(self, server_name, key_ids,
+ perspective_name,
+ perspective_keys):
+ limiter = yield get_retry_limiter(
+ perspective_name, self.clock, self.store
+ )
+
+ with limiter:
+ # TODO(mark): Set the minimum_valid_until_ts to that needed by
+ # the events being validated or the current time if validating
+ # an incoming request.
+ responses = yield self.client.post_json(
+ destination=perspective_name,
+ path=b"/_matrix/key/v2/query",
+ data={
+ u"server_keys": {
+ server_name: {
+ key_id: {
+ u"minimum_valid_until_ts": 0
+ } for key_id in key_ids
+ }
+ }
+ },
+ )
+
+ keys = {}
+
+ for response in responses:
+ if (u"signatures" not in response
+ or perspective_name not in response[u"signatures"]):
+ raise ValueError(
+ "Key response not signed by perspective server"
+ " %r" % (perspective_name,)
+ )
+
+ verified = False
+ for key_id in response[u"signatures"][perspective_name]:
+ if key_id in perspective_keys:
+ verify_signed_json(
+ response,
+ perspective_name,
+ perspective_keys[key_id]
+ )
+ verified = True
+
+ if not verified:
+ logging.info(
+ "Response from perspective server %r not signed with a"
+ " known key, signed with: %r, known keys: %r",
+ perspective_name,
+ list(response[u"signatures"][perspective_name]),
+ list(perspective_keys)
+ )
+ raise ValueError(
+ "Response not signed with a known key for perspective"
+ " server %r" % (perspective_name,)
+ )
+
+ response_keys = yield self.process_v2_response(
+ server_name, perspective_name, response
+ )
+
+ keys.update(response_keys)
+
+ yield self.store_keys(
+ server_name=server_name,
+ from_server=perspective_name,
+ verify_keys=keys,
+ )
+
+ defer.returnValue(keys)
+
+ @defer.inlineCallbacks
+ def get_server_verify_key_v2_direct(self, server_name, key_ids):
+
+ keys = {}
+
+ for requested_key_id in key_ids:
+ if requested_key_id in keys:
+ continue
+
(response, tls_certificate) = yield fetch_server_key(
- server_name, self.hs.tls_context_factory
+ server_name, self.hs.tls_context_factory,
+ path=(b"/_matrix/key/v2/server/%s" % (
+ urllib.quote(requested_key_id),
+ )).encode("ascii"),
+ )
+
+ if (u"signatures" not in response
+ or server_name not in response[u"signatures"]):
+ raise ValueError("Key response not signed by remote server")
+
+ if "tls_fingerprints" not in response:
+ raise ValueError("Key response missing TLS fingerprints")
+
+ certificate_bytes = crypto.dump_certificate(
+ crypto.FILETYPE_ASN1, tls_certificate
+ )
+ sha256_fingerprint = hashlib.sha256(certificate_bytes).digest()
+ sha256_fingerprint_b64 = encode_base64(sha256_fingerprint)
+
+ response_sha256_fingerprints = set()
+ for fingerprint in response[u"tls_fingerprints"]:
+ if u"sha256" in fingerprint:
+ response_sha256_fingerprints.add(fingerprint[u"sha256"])
+
+ if sha256_fingerprint_b64 not in response_sha256_fingerprints:
+ raise ValueError("TLS certificate not allowed by fingerprints")
+
+ response_keys = yield self.process_v2_response(
+ server_name=server_name,
+ from_server=server_name,
+ requested_id=requested_key_id,
+ response_json=response,
+ )
+
+ keys.update(response_keys)
+
+ yield self.store_keys(
+ server_name=server_name,
+ from_server=server_name,
+ verify_keys=keys,
+ )
+
+ defer.returnValue(keys)
+
+ @defer.inlineCallbacks
+ def process_v2_response(self, server_name, from_server, response_json,
+ requested_id=None):
+ time_now_ms = self.clock.time_msec()
+ response_keys = {}
+ verify_keys = {}
+ for key_id, key_data in response_json["verify_keys"].items():
+ if is_signing_algorithm_supported(key_id):
+ key_base64 = key_data["key"]
+ key_bytes = decode_base64(key_base64)
+ verify_key = decode_verify_key_bytes(key_id, key_bytes)
+ verify_key.time_added = time_now_ms
+ verify_keys[key_id] = verify_key
+
+ old_verify_keys = {}
+ for key_id, key_data in response_json["old_verify_keys"].items():
+ if is_signing_algorithm_supported(key_id):
+ key_base64 = key_data["key"]
+ key_bytes = decode_base64(key_base64)
+ verify_key = decode_verify_key_bytes(key_id, key_bytes)
+ verify_key.expired = key_data["expired_ts"]
+ verify_key.time_added = time_now_ms
+ old_verify_keys[key_id] = verify_key
+
+ for key_id in response_json["signatures"][server_name]:
+ if key_id not in response_json["verify_keys"]:
+ raise ValueError(
+ "Key response must include verification keys for all"
+ " signatures"
+ )
+ if key_id in verify_keys:
+ verify_signed_json(
+ response_json,
+ server_name,
+ verify_keys[key_id]
+ )
+
+ signed_key_json = sign_json(
+ response_json,
+ self.config.server_name,
+ self.config.signing_key[0],
+ )
+
+ signed_key_json_bytes = encode_canonical_json(signed_key_json)
+ ts_valid_until_ms = signed_key_json[u"valid_until_ts"]
+
+ updated_key_ids = set()
+ if requested_id is not None:
+ updated_key_ids.add(requested_id)
+ updated_key_ids.update(verify_keys)
+ updated_key_ids.update(old_verify_keys)
+
+ response_keys.update(verify_keys)
+ response_keys.update(old_verify_keys)
+
+ for key_id in updated_key_ids:
+ yield self.store.store_server_keys_json(
+ server_name=server_name,
+ key_id=key_id,
+ from_server=server_name,
+ ts_now_ms=time_now_ms,
+ ts_expires_ms=ts_valid_until_ms,
+ key_json_bytes=signed_key_json_bytes,
)
+ defer.returnValue(response_keys)
+
+ raise ValueError("No verification key found for given key ids")
+
+ @defer.inlineCallbacks
+ def get_server_verify_key_v1_direct(self, server_name, key_ids):
+ """Finds a verification key for the server with one of the key ids.
+ Args:
+ server_name (str): The name of the server to fetch a key for.
+ keys_ids (list of str): The key_ids to check for.
+ """
+
+ # Try to fetch the key from the remote server.
+
+ (response, tls_certificate) = yield fetch_server_key(
+ server_name, self.hs.tls_context_factory
+ )
+
# Check the response.
x509_certificate_bytes = crypto.dump_certificate(
@@ -128,11 +400,16 @@ class Keyring(object):
if encode_base64(x509_certificate_bytes) != tls_certificate_b64:
raise ValueError("TLS certificate doesn't match")
+ # Cache the result in the datastore.
+
+ time_now_ms = self.clock.time_msec()
+
verify_keys = {}
for key_id, key_base64 in response["verify_keys"].items():
if is_signing_algorithm_supported(key_id):
key_bytes = decode_base64(key_base64)
verify_key = decode_verify_key_bytes(key_id, key_bytes)
+ verify_key.time_added = time_now_ms
verify_keys[key_id] = verify_key
for key_id in response["signatures"][server_name]:
@@ -148,10 +425,6 @@ class Keyring(object):
verify_keys[key_id]
)
- # Cache the result in the datastore.
-
- time_now_ms = self.clock.time_msec()
-
yield self.store.store_server_certificate(
server_name,
server_name,
@@ -159,14 +432,26 @@ class Keyring(object):
tls_certificate,
)
+ yield self.store_keys(
+ server_name=server_name,
+ from_server=server_name,
+ verify_keys=verify_keys,
+ )
+
+ defer.returnValue(verify_keys)
+
+ @defer.inlineCallbacks
+ def store_keys(self, server_name, from_server, verify_keys):
+ """Store a collection of verify keys for a given server
+ Args:
+ server_name(str): The name of the server the keys are for.
+ from_server(str): The server the keys were downloaded from.
+ verify_keys(dict): A mapping of key_id to VerifyKey.
+ Returns:
+ A deferred that completes when the keys are stored.
+ """
for key_id, key in verify_keys.items():
+ # TODO(markjh): Store whether the keys have expired.
yield self.store.store_server_verify_key(
- server_name, server_name, time_now_ms, key
+ server_name, server_name, key.time_added, key
)
-
- for key_id in key_ids:
- if key_id in verify_keys:
- defer.returnValue(verify_keys[key_id])
- return
-
- raise ValueError("No verification key found for given key ids")
diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py
index 64e08223b0..e4495ccf12 100644
--- a/synapse/events/__init__.py
+++ b/synapse/events/__init__.py
@@ -46,9 +46,10 @@ def _event_dict_property(key):
class EventBase(object):
def __init__(self, event_dict, signatures={}, unsigned={},
- internal_metadata_dict={}):
+ internal_metadata_dict={}, rejected_reason=None):
self.signatures = signatures
self.unsigned = unsigned
+ self.rejected_reason = rejected_reason
self._event_dict = event_dict
@@ -109,7 +110,7 @@ class EventBase(object):
class FrozenEvent(EventBase):
- def __init__(self, event_dict, internal_metadata_dict={}):
+ def __init__(self, event_dict, internal_metadata_dict={}, rejected_reason=None):
event_dict = dict(event_dict)
# Signatures is a dict of dicts, and this is faster than doing a
@@ -128,6 +129,7 @@ class FrozenEvent(EventBase):
signatures=signatures,
unsigned=unsigned,
internal_metadata_dict=internal_metadata_dict,
+ rejected_reason=rejected_reason,
)
@staticmethod
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 6811a0e3d1..904c7c0945 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -491,7 +491,7 @@ class FederationClient(FederationBase):
]
signed_events = yield self._check_sigs_and_hash_and_fetch(
- destination, events, outlier=True
+ destination, events, outlier=False
)
have_gotten_all_from_destination = True
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 25c0014f97..2b46188c91 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -417,13 +417,13 @@ class FederationServer(FederationBase):
pdu.internal_metadata.outlier = True
elif min_depth and pdu.depth > min_depth:
if get_missing and prevs - seen:
- latest_tuples = yield self.store.get_latest_events_in_room(
+ latest = yield self.store.get_latest_event_ids_in_room(
pdu.room_id
)
# We add the prev events that we have seen to the latest
# list to ensure the remote server doesn't give them to us
- latest = set(e_id for e_id, _, _ in latest_tuples)
+ latest = set(latest)
latest |= seen
missing_events = yield self.get_missing_events(
diff --git a/synapse/federation/persistence.py b/synapse/federation/persistence.py
index 76a9dcd777..1a7cc02f92 100644
--- a/synapse/federation/persistence.py
+++ b/synapse/federation/persistence.py
@@ -23,8 +23,6 @@ from twisted.internet import defer
from synapse.util.logutils import log_function
-from syutil.jsonutil import encode_canonical_json
-
import logging
@@ -71,7 +69,7 @@ class TransactionActions(object):
transaction.transaction_id,
transaction.origin,
code,
- encode_canonical_json(response)
+ response,
)
@defer.inlineCallbacks
@@ -101,5 +99,5 @@ class TransactionActions(object):
transaction.transaction_id,
transaction.destination,
response_code,
- encode_canonical_json(response_dict)
+ response_dict,
)
diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py
index 4dccd93d0e..ca04822fb3 100644
--- a/synapse/federation/transaction_queue.py
+++ b/synapse/federation/transaction_queue.py
@@ -104,7 +104,6 @@ class TransactionQueue(object):
return not destination.startswith("localhost")
@defer.inlineCallbacks
- @log_function
def enqueue_pdu(self, pdu, destinations, order):
# We loop through all destinations to see whether we already have
# a transaction in progress. If we do, stick it in the pending_pdus
diff --git a/synapse/handlers/__init__.py b/synapse/handlers/__init__.py
index 8d345bf936..685792dbdc 100644
--- a/synapse/handlers/__init__.py
+++ b/synapse/handlers/__init__.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from synapse.appservice.scheduler import AppServiceScheduler
from synapse.appservice.api import ApplicationServiceApi
from .register import RegistrationHandler
from .room import (
@@ -29,6 +30,8 @@ from .typing import TypingNotificationHandler
from .admin import AdminHandler
from .appservice import ApplicationServicesHandler
from .sync import SyncHandler
+from .auth import AuthHandler
+from .identity import IdentityHandler
class Handlers(object):
@@ -54,7 +57,14 @@ class Handlers(object):
self.directory_handler = DirectoryHandler(hs)
self.typing_notification_handler = TypingNotificationHandler(hs)
self.admin_handler = AdminHandler(hs)
+ asapi = ApplicationServiceApi(hs)
self.appservice_handler = ApplicationServicesHandler(
- hs, ApplicationServiceApi(hs)
+ hs, asapi, AppServiceScheduler(
+ clock=hs.get_clock(),
+ store=hs.get_datastore(),
+ as_api=asapi
+ )
)
self.sync_handler = SyncHandler(hs)
+ self.auth_handler = AuthHandler(hs)
+ self.identity_handler = IdentityHandler(hs)
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index 48816a242d..4b3f4eadab 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -16,7 +16,6 @@
from twisted.internet import defer
from synapse.api.errors import LimitExceededError, SynapseError
-from synapse.util.async import run_on_reactor
from synapse.crypto.event_signing import add_hashes_and_signatures
from synapse.api.constants import Membership, EventTypes
from synapse.types import UserID
@@ -58,8 +57,6 @@ class BaseHandler(object):
@defer.inlineCallbacks
def _create_new_client_event(self, builder):
- yield run_on_reactor()
-
latest_ret = yield self.store.get_latest_events_in_room(
builder.room_id,
)
@@ -101,8 +98,6 @@ class BaseHandler(object):
@defer.inlineCallbacks
def handle_new_client_event(self, event, context, extra_destinations=[],
extra_users=[], suppress_auth=False):
- yield run_on_reactor()
-
# We now need to go and hit out to wherever we need to hit out to.
if not suppress_auth:
@@ -143,7 +138,9 @@ class BaseHandler(object):
)
# Don't block waiting on waking up all the listeners.
- d = self.notifier.on_new_room_event(event, extra_users=extra_users)
+ notify_d = self.notifier.on_new_room_event(
+ event, extra_users=extra_users
+ )
def log_failure(f):
logger.warn(
@@ -151,8 +148,8 @@ class BaseHandler(object):
event.event_id, f.value
)
- d.addErrback(log_failure)
+ notify_d.addErrback(log_failure)
- yield federation_handler.handle_new_event(
+ federation_handler.handle_new_event(
event, destinations=destinations,
)
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 2c488a46f6..355ab317df 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -16,57 +16,36 @@
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
-from synapse.api.errors import Codes, StoreError, SynapseError
from synapse.appservice import ApplicationService
from synapse.types import UserID
-import synapse.util.stringutils as stringutils
import logging
logger = logging.getLogger(__name__)
+def log_failure(failure):
+ logger.error(
+ "Application Services Failure",
+ exc_info=(
+ failure.type,
+ failure.value,
+ failure.getTracebackObject()
+ )
+ )
+
+
# NB: Purposefully not inheriting BaseHandler since that contains way too much
# setup code which this handler does not need or use. This makes testing a lot
# easier.
class ApplicationServicesHandler(object):
- def __init__(self, hs, appservice_api):
+ def __init__(self, hs, appservice_api, appservice_scheduler):
self.store = hs.get_datastore()
self.hs = hs
self.appservice_api = appservice_api
-
- @defer.inlineCallbacks
- def register(self, app_service):
- logger.info("Register -> %s", app_service)
- # check the token is recognised
- try:
- stored_service = yield self.store.get_app_service_by_token(
- app_service.token
- )
- if not stored_service:
- raise StoreError(404, "Application service not found")
- except StoreError:
- raise SynapseError(
- 403, "Unrecognised application services token. "
- "Consult the home server admin.",
- errcode=Codes.FORBIDDEN
- )
-
- app_service.hs_token = self._generate_hs_token()
-
- # create a sender for this application service which is used when
- # creating rooms, etc..
- account = yield self.hs.get_handlers().registration_handler.register()
- app_service.sender = account[0]
-
- yield self.store.update_app_service(app_service)
- defer.returnValue(app_service)
-
- @defer.inlineCallbacks
- def unregister(self, token):
- logger.info("Unregister as_token=%s", token)
- yield self.store.unregister_app_service(token)
+ self.scheduler = appservice_scheduler
+ self.started_scheduler = False
@defer.inlineCallbacks
def notify_interested_services(self, event):
@@ -90,9 +69,13 @@ class ApplicationServicesHandler(object):
if event.type == EventTypes.Member:
yield self._check_user_exists(event.state_key)
- # Fork off pushes to these services - XXX First cut, best effort
+ if not self.started_scheduler:
+ self.scheduler.start().addErrback(log_failure)
+ self.started_scheduler = True
+
+ # Fork off pushes to these services
for service in services:
- self.appservice_api.push(service, event)
+ self.scheduler.submit_event_for_as(service, event)
@defer.inlineCallbacks
def query_user_exists(self, user_id):
@@ -197,7 +180,14 @@ class ApplicationServicesHandler(object):
return
user_info = yield self.store.get_user_by_id(user_id)
- defer.returnValue(len(user_info) == 0)
+ if not user_info:
+ defer.returnValue(False)
+ return
+
+ # user not found; could be the AS though, so check.
+ services = yield self.store.get_app_services()
+ service_list = [s for s in services if s.sender == user_id]
+ defer.returnValue(len(service_list) == 0)
@defer.inlineCallbacks
def _check_user_exists(self, user_id):
@@ -206,6 +196,3 @@ class ApplicationServicesHandler(object):
exists = yield self.query_user_exists(user_id)
defer.returnValue(exists)
defer.returnValue(True)
-
- def _generate_hs_token(self):
- return stringutils.random_string(24)
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
new file mode 100644
index 0000000000..4e2e50345e
--- /dev/null
+++ b/synapse/handlers/auth.py
@@ -0,0 +1,277 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014, 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+from ._base import BaseHandler
+from synapse.api.constants import LoginType
+from synapse.types import UserID
+from synapse.api.errors import LoginError, Codes
+from synapse.http.client import SimpleHttpClient
+from synapse.util.async import run_on_reactor
+
+from twisted.web.client import PartialDownloadError
+
+import logging
+import bcrypt
+import simplejson
+
+import synapse.util.stringutils as stringutils
+
+
+logger = logging.getLogger(__name__)
+
+
+class AuthHandler(BaseHandler):
+
+ def __init__(self, hs):
+ super(AuthHandler, self).__init__(hs)
+ self.checkers = {
+ LoginType.PASSWORD: self._check_password_auth,
+ LoginType.RECAPTCHA: self._check_recaptcha,
+ LoginType.EMAIL_IDENTITY: self._check_email_identity,
+ LoginType.DUMMY: self._check_dummy_auth,
+ }
+ self.sessions = {}
+
+ @defer.inlineCallbacks
+ def check_auth(self, flows, clientdict, clientip=None):
+ """
+ Takes a dictionary sent by the client in the login / registration
+ protocol and handles the login flow.
+
+ Args:
+ flows: list of list of stages
+ authdict: The dictionary from the client root level, not the
+ 'auth' key: this method prompts for auth if none is sent.
+ Returns:
+ A tuple of authed, dict, dict where authed is true if the client
+ has successfully completed an auth flow. If it is true, the first
+ dict contains the authenticated credentials of each stage.
+
+ If authed is false, the first dictionary is the server response to
+ the login request and should be passed back to the client.
+
+ In either case, the second dict contains the parameters for this
+ request (which may have been given only in a previous call).
+ """
+
+ authdict = None
+ sid = None
+ if clientdict and 'auth' in clientdict:
+ authdict = clientdict['auth']
+ del clientdict['auth']
+ if 'session' in authdict:
+ sid = authdict['session']
+ sess = self._get_session_info(sid)
+
+ if len(clientdict) > 0:
+ # This was designed to allow the client to omit the parameters
+ # and just supply the session in subsequent calls so it split
+ # auth between devices by just sharing the session, (eg. so you
+ # could continue registration from your phone having clicked the
+ # email auth link on there). It's probably too open to abuse
+ # because it lets unauthenticated clients store arbitrary objects
+ # on a home server.
+ # sess['clientdict'] = clientdict
+ # self._save_session(sess)
+ pass
+ elif 'clientdict' in sess:
+ clientdict = sess['clientdict']
+
+ if not authdict:
+ defer.returnValue(
+ (False, self._auth_dict_for_flows(flows, sess), clientdict)
+ )
+
+ if 'creds' not in sess:
+ sess['creds'] = {}
+ creds = sess['creds']
+
+ # check auth type currently being presented
+ if 'type' in authdict:
+ if authdict['type'] not in self.checkers:
+ raise LoginError(400, "", Codes.UNRECOGNIZED)
+ result = yield self.checkers[authdict['type']](authdict, clientip)
+ if result:
+ creds[authdict['type']] = result
+ self._save_session(sess)
+
+ for f in flows:
+ if len(set(f) - set(creds.keys())) == 0:
+ logger.info("Auth completed with creds: %r", creds)
+ self._remove_session(sess)
+ defer.returnValue((True, creds, clientdict))
+
+ ret = self._auth_dict_for_flows(flows, sess)
+ ret['completed'] = creds.keys()
+ defer.returnValue((False, ret, clientdict))
+
+ @defer.inlineCallbacks
+ def add_oob_auth(self, stagetype, authdict, clientip):
+ """
+ Adds the result of out-of-band authentication into an existing auth
+ session. Currently used for adding the result of fallback auth.
+ """
+ if stagetype not in self.checkers:
+ raise LoginError(400, "", Codes.MISSING_PARAM)
+ if 'session' not in authdict:
+ raise LoginError(400, "", Codes.MISSING_PARAM)
+
+ sess = self._get_session_info(
+ authdict['session']
+ )
+ if 'creds' not in sess:
+ sess['creds'] = {}
+ creds = sess['creds']
+
+ result = yield self.checkers[stagetype](authdict, clientip)
+ if result:
+ creds[stagetype] = result
+ self._save_session(sess)
+ defer.returnValue(True)
+ defer.returnValue(False)
+
+ @defer.inlineCallbacks
+ def _check_password_auth(self, authdict, _):
+ if "user" not in authdict or "password" not in authdict:
+ raise LoginError(400, "", Codes.MISSING_PARAM)
+
+ user = authdict["user"]
+ password = authdict["password"]
+ if not user.startswith('@'):
+ user = UserID.create(user, self.hs.hostname).to_string()
+
+ user_info = yield self.store.get_user_by_id(user_id=user)
+ if not user_info:
+ logger.warn("Attempted to login as %s but they do not exist", user)
+ raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
+
+ stored_hash = user_info["password_hash"]
+ if bcrypt.checkpw(password, stored_hash):
+ defer.returnValue(user)
+ else:
+ logger.warn("Failed password login for user %s", user)
+ raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
+
+ @defer.inlineCallbacks
+ def _check_recaptcha(self, authdict, clientip):
+ try:
+ user_response = authdict["response"]
+ except KeyError:
+ # Client tried to provide captcha but didn't give the parameter:
+ # bad request.
+ raise LoginError(
+ 400, "Captcha response is required",
+ errcode=Codes.CAPTCHA_NEEDED
+ )
+
+ logger.info(
+ "Submitting recaptcha response %s with remoteip %s",
+ user_response, clientip
+ )
+
+ # TODO: get this from the homeserver rather than creating a new one for
+ # each request
+ try:
+ client = SimpleHttpClient(self.hs)
+ data = yield client.post_urlencoded_get_json(
+ "https://www.google.com/recaptcha/api/siteverify",
+ args={
+ 'secret': self.hs.config.recaptcha_private_key,
+ 'response': user_response,
+ 'remoteip': clientip,
+ }
+ )
+ except PartialDownloadError as pde:
+ # Twisted is silly
+ data = pde.response
+ resp_body = simplejson.loads(data)
+ if 'success' in resp_body and resp_body['success']:
+ defer.returnValue(True)
+ raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
+
+ @defer.inlineCallbacks
+ def _check_email_identity(self, authdict, _):
+ yield run_on_reactor()
+
+ if 'threepid_creds' not in authdict:
+ raise LoginError(400, "Missing threepid_creds", Codes.MISSING_PARAM)
+
+ threepid_creds = authdict['threepid_creds']
+ identity_handler = self.hs.get_handlers().identity_handler
+
+ logger.info("Getting validated threepid. threepidcreds: %r" % (threepid_creds,))
+ threepid = yield identity_handler.threepid_from_creds(threepid_creds)
+
+ if not threepid:
+ raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
+
+ threepid['threepid_creds'] = authdict['threepid_creds']
+
+ defer.returnValue(threepid)
+
+ @defer.inlineCallbacks
+ def _check_dummy_auth(self, authdict, _):
+ yield run_on_reactor()
+ defer.returnValue(True)
+
+ def _get_params_recaptcha(self):
+ return {"public_key": self.hs.config.recaptcha_public_key}
+
+ def _auth_dict_for_flows(self, flows, session):
+ public_flows = []
+ for f in flows:
+ public_flows.append(f)
+
+ get_params = {
+ LoginType.RECAPTCHA: self._get_params_recaptcha,
+ }
+
+ params = {}
+
+ for f in public_flows:
+ for stage in f:
+ if stage in get_params and stage not in params:
+ params[stage] = get_params[stage]()
+
+ return {
+ "session": session['id'],
+ "flows": [{"stages": f} for f in public_flows],
+ "params": params
+ }
+
+ def _get_session_info(self, session_id):
+ if session_id not in self.sessions:
+ session_id = None
+
+ if not session_id:
+ # create a new session
+ while session_id is None or session_id in self.sessions:
+ session_id = stringutils.random_string(24)
+ self.sessions[session_id] = {
+ "id": session_id,
+ }
+
+ return self.sessions[session_id]
+
+ def _save_session(self, session):
+ # TODO: Persistent storage
+ logger.debug("Saving session %s", session)
+ self.sessions[session["id"]] = session
+
+ def _remove_session(self, session):
+ logger.debug("Removing session %s", session)
+ del self.sessions[session["id"]]
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 15ba417e06..85e2757227 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -73,8 +73,6 @@ class FederationHandler(BaseHandler):
# When joining a room we need to queue any events for that room up
self.room_queues = {}
- @log_function
- @defer.inlineCallbacks
def handle_new_event(self, event, destinations):
""" Takes in an event from the client to server side, that has already
been authed and handled by the state module, and sends it to any
@@ -89,9 +87,7 @@ class FederationHandler(BaseHandler):
processing.
"""
- yield run_on_reactor()
-
- self.replication_layer.send_pdu(event, destinations)
+ return self.replication_layer.send_pdu(event, destinations)
@log_function
@defer.inlineCallbacks
@@ -179,7 +175,7 @@ class FederationHandler(BaseHandler):
# it's probably a good idea to mark it as not in retry-state
# for sending (although this is a bit of a leap)
retry_timings = yield self.store.get_destination_retry_timings(origin)
- if (retry_timings and retry_timings.retry_last_ts):
+ if retry_timings and retry_timings["retry_last_ts"]:
self.store.set_destination_retry_timings(origin, 0, 0)
room = yield self.store.get_room(event.room_id)
@@ -201,10 +197,18 @@ class FederationHandler(BaseHandler):
target_user = UserID.from_string(target_user_id)
extra_users.append(target_user)
- yield self.notifier.on_new_room_event(
+ d = self.notifier.on_new_room_event(
event, extra_users=extra_users
)
+ def log_failure(f):
+ logger.warn(
+ "Failed to notify about %s: %s",
+ event.event_id, f.value
+ )
+
+ d.addErrback(log_failure)
+
if event.type == EventTypes.Member:
if event.membership == Membership.JOIN:
user = UserID.from_string(event.state_key)
@@ -427,10 +431,18 @@ class FederationHandler(BaseHandler):
auth_events=auth_events,
)
- yield self.notifier.on_new_room_event(
+ d = self.notifier.on_new_room_event(
new_event, extra_users=[joinee]
)
+ def log_failure(f):
+ logger.warn(
+ "Failed to notify about %s: %s",
+ new_event.event_id, f.value
+ )
+
+ d.addErrback(log_failure)
+
logger.debug("Finished joining %s to %s", joinee, room_id)
finally:
room_queue = self.room_queues[room_id]
@@ -500,10 +512,18 @@ class FederationHandler(BaseHandler):
target_user = UserID.from_string(target_user_id)
extra_users.append(target_user)
- yield self.notifier.on_new_room_event(
+ d = self.notifier.on_new_room_event(
event, extra_users=extra_users
)
+ def log_failure(f):
+ logger.warn(
+ "Failed to notify about %s: %s",
+ event.event_id, f.value
+ )
+
+ d.addErrback(log_failure)
+
if event.type == EventTypes.Member:
if event.content["membership"] == Membership.JOIN:
user = UserID.from_string(event.state_key)
@@ -574,10 +594,18 @@ class FederationHandler(BaseHandler):
)
target_user = UserID.from_string(event.state_key)
- yield self.notifier.on_new_room_event(
+ d = self.notifier.on_new_room_event(
event, extra_users=[target_user],
)
+ def log_failure(f):
+ logger.warn(
+ "Failed to notify about %s: %s",
+ event.event_id, f.value
+ )
+
+ d.addErrback(log_failure)
+
defer.returnValue(event)
@defer.inlineCallbacks
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
new file mode 100644
index 0000000000..6200e10775
--- /dev/null
+++ b/synapse/handlers/identity.py
@@ -0,0 +1,119 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Utilities for interacting with Identity Servers"""
+from twisted.internet import defer
+
+from synapse.api.errors import (
+ CodeMessageException
+)
+from ._base import BaseHandler
+from synapse.http.client import SimpleHttpClient
+from synapse.util.async import run_on_reactor
+from synapse.api.errors import SynapseError
+
+import json
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class IdentityHandler(BaseHandler):
+
+ def __init__(self, hs):
+ super(IdentityHandler, self).__init__(hs)
+
+ @defer.inlineCallbacks
+ def threepid_from_creds(self, creds):
+ yield run_on_reactor()
+
+ # TODO: get this from the homeserver rather than creating a new one for
+ # each request
+ http_client = SimpleHttpClient(self.hs)
+ # XXX: make this configurable!
+ # trustedIdServers = ['matrix.org', 'localhost:8090']
+ trustedIdServers = ['matrix.org']
+
+ if 'id_server' in creds:
+ id_server = creds['id_server']
+ elif 'idServer' in creds:
+ id_server = creds['idServer']
+ else:
+ raise SynapseError(400, "No id_server in creds")
+
+ if 'client_secret' in creds:
+ client_secret = creds['client_secret']
+ elif 'clientSecret' in creds:
+ client_secret = creds['clientSecret']
+ else:
+ raise SynapseError(400, "No client_secret in creds")
+
+ if id_server not in trustedIdServers:
+ logger.warn('%s is not a trusted ID server: rejecting 3pid ' +
+ 'credentials', id_server)
+ defer.returnValue(None)
+
+ data = {}
+ try:
+ data = yield http_client.get_json(
+ "https://%s%s" % (
+ id_server,
+ "/_matrix/identity/api/v1/3pid/getValidated3pid"
+ ),
+ {'sid': creds['sid'], 'client_secret': client_secret}
+ )
+ except CodeMessageException as e:
+ data = json.loads(e.msg)
+
+ if 'medium' in data:
+ defer.returnValue(data)
+ defer.returnValue(None)
+
+ @defer.inlineCallbacks
+ def bind_threepid(self, creds, mxid):
+ yield run_on_reactor()
+ logger.debug("binding threepid %r to %s", creds, mxid)
+ http_client = SimpleHttpClient(self.hs)
+ data = None
+
+ if 'id_server' in creds:
+ id_server = creds['id_server']
+ elif 'idServer' in creds:
+ id_server = creds['idServer']
+ else:
+ raise SynapseError(400, "No id_server in creds")
+
+ if 'client_secret' in creds:
+ client_secret = creds['client_secret']
+ elif 'clientSecret' in creds:
+ client_secret = creds['clientSecret']
+ else:
+ raise SynapseError(400, "No client_secret in creds")
+
+ try:
+ data = yield http_client.post_urlencoded_get_json(
+ "https://%s%s" % (
+ id_server, "/_matrix/identity/api/v1/3pid/bind"
+ ),
+ {
+ 'sid': creds['sid'],
+ 'client_secret': client_secret,
+ 'mxid': mxid,
+ }
+ )
+ logger.debug("bound threepid %r to %s", creds, mxid)
+ except CodeMessageException as e:
+ data = json.loads(e.msg)
+ defer.returnValue(data)
diff --git a/synapse/handlers/login.py b/synapse/handlers/login.py
index 7447800460..91d87d503d 100644
--- a/synapse/handlers/login.py
+++ b/synapse/handlers/login.py
@@ -16,13 +16,9 @@
from twisted.internet import defer
from ._base import BaseHandler
-from synapse.api.errors import LoginError, Codes, CodeMessageException
-from synapse.http.client import SimpleHttpClient
-from synapse.util.emailutils import EmailException
-import synapse.util.emailutils as emailutils
+from synapse.api.errors import LoginError, Codes
import bcrypt
-import json
import logging
logger = logging.getLogger(__name__)
@@ -57,7 +53,7 @@ class LoginHandler(BaseHandler):
logger.warn("Attempted to login as %s but they do not exist", user)
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
- stored_hash = user_info[0]["password_hash"]
+ stored_hash = user_info["password_hash"]
if bcrypt.checkpw(password, stored_hash):
# generate an access token and store it.
token = self.reg_handler._generate_token(user)
@@ -69,48 +65,19 @@ class LoginHandler(BaseHandler):
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
@defer.inlineCallbacks
- def reset_password(self, user_id, email):
- is_valid = yield self._check_valid_association(user_id, email)
- logger.info("reset_password user=%s email=%s valid=%s", user_id, email,
- is_valid)
- if is_valid:
- try:
- # send an email out
- emailutils.send_email(
- smtp_server=self.hs.config.email_smtp_server,
- from_addr=self.hs.config.email_from_address,
- to_addr=email,
- subject="Password Reset",
- body="TODO."
- )
- except EmailException as e:
- logger.exception(e)
+ def set_password(self, user_id, newpassword, token_id=None):
+ password_hash = bcrypt.hashpw(newpassword, bcrypt.gensalt())
- @defer.inlineCallbacks
- def _check_valid_association(self, user_id, email):
- identity = yield self._query_email(email)
- if identity and "mxid" in identity:
- if identity["mxid"] == user_id:
- defer.returnValue(True)
- return
- defer.returnValue(False)
+ yield self.store.user_set_password_hash(user_id, password_hash)
+ yield self.store.user_delete_access_tokens_apart_from(user_id, token_id)
+ yield self.hs.get_pusherpool().remove_pushers_by_user_access_token(
+ user_id, token_id
+ )
+ yield self.store.flush_user(user_id)
@defer.inlineCallbacks
- def _query_email(self, email):
- http_client = SimpleHttpClient(self.hs)
- try:
- data = yield http_client.get_json(
- # TODO FIXME This should be configurable.
- # XXX: ID servers need to use HTTPS
- "http://%s%s" % (
- "matrix.org:8090", "/_matrix/identity/api/v1/lookup"
- ),
- {
- 'medium': 'email',
- 'address': email
- }
- )
- defer.returnValue(data)
- except CodeMessageException as e:
- data = json.loads(e.msg)
- defer.returnValue(data)
+ def add_threepid(self, user_id, medium, address, validated_at):
+ yield self.store.user_add_threepid(
+ user_id, medium, address, validated_at,
+ self.hs.get_clock().time_msec()
+ )
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 7b9685be7f..22e19af17f 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -267,14 +267,14 @@ class MessageHandler(BaseHandler):
user, pagination_config.get_source_config("presence"), None
)
- public_rooms = yield self.store.get_rooms(is_public=True)
- public_room_ids = [r["room_id"] for r in public_rooms]
+ public_room_ids = yield self.store.get_public_room_ids()
limit = pagin_config.limit
if limit is None:
limit = 10
- for event in room_list:
+ @defer.inlineCallbacks
+ def handle_room(event):
d = {
"room_id": event.room_id,
"membership": event.membership,
@@ -290,12 +290,19 @@ class MessageHandler(BaseHandler):
rooms_ret.append(d)
if event.membership != Membership.JOIN:
- continue
+ return
try:
- messages, token = yield self.store.get_recent_events_for_room(
- event.room_id,
- limit=limit,
- end_token=now_token.room_key,
+ (messages, token), current_state = yield defer.gatherResults(
+ [
+ self.store.get_recent_events_for_room(
+ event.room_id,
+ limit=limit,
+ end_token=now_token.room_key,
+ ),
+ self.state_handler.get_current_state(
+ event.room_id
+ ),
+ ]
)
start_token = now_token.copy_and_replace("room_key", token[0])
@@ -311,9 +318,6 @@ class MessageHandler(BaseHandler):
"end": end_token.to_string(),
}
- current_state = yield self.state_handler.get_current_state(
- event.room_id
- )
d["state"] = [
serialize_event(c, time_now, as_client_event)
for c in current_state.values()
@@ -321,6 +325,11 @@ class MessageHandler(BaseHandler):
except:
logger.exception("Failed to get snapshot")
+ yield defer.gatherResults(
+ [handle_room(e) for e in room_list],
+ consumeErrors=True
+ )
+
ret = {
"rooms": rooms_ret,
"presence": presence,
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 731df00648..9e15610401 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -33,6 +33,13 @@ logger = logging.getLogger(__name__)
metrics = synapse.metrics.get_metrics_for(__name__)
+# Don't bother bumping "last active" time if it differs by less than 60 seconds
+LAST_ACTIVE_GRANULARITY = 60*1000
+
+# Keep no more than this number of offline serial revisions
+MAX_OFFLINE_SERIALS = 1000
+
+
# TODO(paul): Maybe there's one of these I can steal from somewhere
def partition(l, func):
"""Partition the list by the result of func applied to each element."""
@@ -131,6 +138,9 @@ class PresenceHandler(BaseHandler):
self._remote_sendmap = {}
# map remote users to sets of local users who're interested in them
self._remote_recvmap = {}
+ # list of (serial, set of(userids)) tuples, ordered by serial, latest
+ # first
+ self._remote_offline_serials = []
# map any user to a UserPresenceCache
self._user_cachemap = {}
@@ -282,6 +292,10 @@ class PresenceHandler(BaseHandler):
if now is None:
now = self.clock.time_msec()
+ prev_state = self._get_or_make_usercache(user)
+ if now - prev_state.state.get("last_active", 0) < LAST_ACTIVE_GRANULARITY:
+ return
+
self.changed_presencelike_data(user, {"last_active": now})
def changed_presencelike_data(self, user, state):
@@ -706,8 +720,24 @@ class PresenceHandler(BaseHandler):
statuscache=statuscache,
)
+ user_id = user.to_string()
+
if state["presence"] == PresenceState.OFFLINE:
+ self._remote_offline_serials.insert(
+ 0,
+ (self._user_cachemap_latest_serial, set([user_id]))
+ )
+ while len(self._remote_offline_serials) > MAX_OFFLINE_SERIALS:
+ self._remote_offline_serials.pop() # remove the oldest
del self._user_cachemap[user]
+ else:
+ # Remove the user from remote_offline_serials now that they're
+ # no longer offline
+ for idx, elem in enumerate(self._remote_offline_serials):
+ (_, user_ids) = elem
+ user_ids.discard(user_id)
+ if not user_ids:
+ self._remote_offline_serials.pop(idx)
for poll in content.get("poll", []):
user = UserID.from_string(poll)
@@ -829,26 +859,47 @@ class PresenceEventSource(object):
presence = self.hs.get_handlers().presence_handler
cachemap = presence._user_cachemap
+ max_serial = presence._user_cachemap_latest_serial
+
+ clock = self.clock
+ latest_serial = 0
+
updates = []
# TODO(paul): use a DeferredList ? How to limit concurrency.
for observed_user in cachemap.keys():
cached = cachemap[observed_user]
- if cached.serial <= from_key:
+ if cached.serial <= from_key or cached.serial > max_serial:
continue
- if (yield self.is_visible(observer_user, observed_user)):
- updates.append((observed_user, cached))
+ if not (yield self.is_visible(observer_user, observed_user)):
+ continue
+
+ latest_serial = max(cached.serial, latest_serial)
+ updates.append(cached.make_event(user=observed_user, clock=clock))
# TODO(paul): limit
- if updates:
- clock = self.clock
+ for serial, user_ids in presence._remote_offline_serials:
+ if serial <= from_key:
+ break
- latest_serial = max([x[1].serial for x in updates])
- data = [x[1].make_event(user=x[0], clock=clock) for x in updates]
+ if serial > max_serial:
+ continue
- defer.returnValue((data, latest_serial))
+ latest_serial = max(latest_serial, serial)
+ for u in user_ids:
+ updates.append({
+ "type": "m.presence",
+ "content": {"user_id": u, "presence": PresenceState.OFFLINE},
+ })
+ # TODO(paul): For the v2 API we want to tell the client their from_key
+ # is too old if we fell off the end of the _remote_offline_serials
+ # list, and get them to invalidate+resync. In v1 we have no such
+ # concept so this is a best-effort result.
+
+ if updates:
+ defer.returnValue((updates, latest_serial))
else:
defer.returnValue(([], presence._user_cachemap_latest_serial))
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index c25e321099..7b68585a17 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -18,18 +18,15 @@ from twisted.internet import defer
from synapse.types import UserID
from synapse.api.errors import (
- AuthError, Codes, SynapseError, RegistrationError, InvalidCaptchaError,
- CodeMessageException
+ AuthError, Codes, SynapseError, RegistrationError, InvalidCaptchaError
)
from ._base import BaseHandler
import synapse.util.stringutils as stringutils
from synapse.util.async import run_on_reactor
-from synapse.http.client import SimpleHttpClient
from synapse.http.client import CaptchaServerHttpClient
import base64
import bcrypt
-import json
import logging
import urllib
@@ -45,6 +42,30 @@ class RegistrationHandler(BaseHandler):
self.distributor.declare("registered_user")
@defer.inlineCallbacks
+ def check_username(self, localpart):
+ yield run_on_reactor()
+
+ if urllib.quote(localpart) != localpart:
+ raise SynapseError(
+ 400,
+ "User ID must only contain characters which do not"
+ " require URL encoding."
+ )
+
+ user = UserID(localpart, self.hs.hostname)
+ user_id = user.to_string()
+
+ yield self.check_user_id_is_valid(user_id)
+
+ u = yield self.store.get_user_by_id(user_id)
+ if u:
+ raise SynapseError(
+ 400,
+ "User ID already taken.",
+ errcode=Codes.USER_IN_USE,
+ )
+
+ @defer.inlineCallbacks
def register(self, localpart=None, password=None):
"""Registers a new client on the server.
@@ -64,18 +85,11 @@ class RegistrationHandler(BaseHandler):
password_hash = bcrypt.hashpw(password, bcrypt.gensalt())
if localpart:
- if localpart and urllib.quote(localpart) != localpart:
- raise SynapseError(
- 400,
- "User ID must only contain characters which do not"
- " require URL encoding."
- )
+ yield self.check_username(localpart)
user = UserID(localpart, self.hs.hostname)
user_id = user.to_string()
- yield self.check_user_id_is_valid(user_id)
-
token = self._generate_token(user_id)
yield self.store.register(
user_id=user_id,
@@ -157,7 +171,11 @@ class RegistrationHandler(BaseHandler):
@defer.inlineCallbacks
def check_recaptcha(self, ip, private_key, challenge, response):
- """Checks a recaptcha is correct."""
+ """
+ Checks a recaptcha is correct.
+
+ Used only by c/s api v1
+ """
captcha_response = yield self._validate_captcha(
ip,
@@ -176,13 +194,18 @@ class RegistrationHandler(BaseHandler):
@defer.inlineCallbacks
def register_email(self, threepidCreds):
- """Registers emails with an identity server."""
+ """
+ Registers emails with an identity server.
+
+ Used only by c/s api v1
+ """
for c in threepidCreds:
logger.info("validating theeepidcred sid %s on id server %s",
c['sid'], c['idServer'])
try:
- threepid = yield self._threepid_from_creds(c)
+ identity_handler = self.hs.get_handlers().identity_handler
+ threepid = yield identity_handler.threepid_from_creds(c)
except:
logger.exception("Couldn't validate 3pid")
raise RegistrationError(400, "Couldn't validate 3pid")
@@ -194,12 +217,16 @@ class RegistrationHandler(BaseHandler):
@defer.inlineCallbacks
def bind_emails(self, user_id, threepidCreds):
- """Links emails with a user ID and informs an identity server."""
+ """Links emails with a user ID and informs an identity server.
+
+ Used only by c/s api v1
+ """
# Now we have a matrix ID, bind it to the threepids we were given
for c in threepidCreds:
+ identity_handler = self.hs.get_handlers().identity_handler
# XXX: This should be a deferred list, shouldn't it?
- yield self._bind_threepid(c, user_id)
+ yield identity_handler.bind_threepid(c, user_id)
@defer.inlineCallbacks
def check_user_id_is_valid(self, user_id):
@@ -227,61 +254,11 @@ class RegistrationHandler(BaseHandler):
return "-" + stringutils.random_string(18)
@defer.inlineCallbacks
- def _threepid_from_creds(self, creds):
- # TODO: get this from the homeserver rather than creating a new one for
- # each request
- http_client = SimpleHttpClient(self.hs)
- # XXX: make this configurable!
- trustedIdServers = ['matrix.org:8090', 'matrix.org']
- if not creds['idServer'] in trustedIdServers:
- logger.warn('%s is not a trusted ID server: rejecting 3pid ' +
- 'credentials', creds['idServer'])
- defer.returnValue(None)
-
- data = {}
- try:
- data = yield http_client.get_json(
- # XXX: This should be HTTPS
- "http://%s%s" % (
- creds['idServer'],
- "/_matrix/identity/api/v1/3pid/getValidated3pid"
- ),
- {'sid': creds['sid'], 'clientSecret': creds['clientSecret']}
- )
- except CodeMessageException as e:
- data = json.loads(e.msg)
-
- if 'medium' in data:
- defer.returnValue(data)
- defer.returnValue(None)
-
- @defer.inlineCallbacks
- def _bind_threepid(self, creds, mxid):
- yield
- logger.debug("binding threepid")
- http_client = SimpleHttpClient(self.hs)
- data = None
- try:
- data = yield http_client.post_urlencoded_get_json(
- # XXX: Change when ID servers are all HTTPS
- "http://%s%s" % (
- creds['idServer'], "/_matrix/identity/api/v1/3pid/bind"
- ),
- {
- 'sid': creds['sid'],
- 'clientSecret': creds['clientSecret'],
- 'mxid': mxid,
- }
- )
- logger.debug("bound threepid")
- except CodeMessageException as e:
- data = json.loads(e.msg)
- defer.returnValue(data)
-
- @defer.inlineCallbacks
def _validate_captcha(self, ip_addr, private_key, challenge, response):
"""Validates the captcha provided.
+ Used only by c/s api v1
+
Returns:
dict: Containing 'valid'(bool) and 'error_url'(str) if invalid.
@@ -299,6 +276,9 @@ class RegistrationHandler(BaseHandler):
@defer.inlineCallbacks
def _submit_captcha(self, ip_addr, private_key, challenge, response):
+ """
+ Used only by c/s api v1
+ """
# TODO: get this from the homeserver rather than creating a new one for
# each request
client = CaptchaServerHttpClient(self.hs)
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 823affc380..cfa2e38ed2 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -124,7 +124,7 @@ class RoomCreationHandler(BaseHandler):
msg_handler = self.hs.get_handlers().message_handler
for event in creation_events:
- yield msg_handler.create_and_send_event(event)
+ yield msg_handler.create_and_send_event(event, ratelimit=False)
if "name" in config:
name = config["name"]
@@ -134,7 +134,7 @@ class RoomCreationHandler(BaseHandler):
"sender": user_id,
"state_key": "",
"content": {"name": name},
- })
+ }, ratelimit=False)
if "topic" in config:
topic = config["topic"]
@@ -144,7 +144,7 @@ class RoomCreationHandler(BaseHandler):
"sender": user_id,
"state_key": "",
"content": {"topic": topic},
- })
+ }, ratelimit=False)
for invitee in invite_list:
yield msg_handler.create_and_send_event({
@@ -153,7 +153,7 @@ class RoomCreationHandler(BaseHandler):
"room_id": room_id,
"sender": user_id,
"content": {"membership": Membership.INVITE},
- })
+ }, ratelimit=False)
result = {"room_id": room_id}
@@ -213,7 +213,8 @@ class RoomCreationHandler(BaseHandler):
"state_default": 50,
"ban": 50,
"kick": 50,
- "redact": 50
+ "redact": 50,
+ "invite": 0,
},
)
@@ -311,25 +312,6 @@ class RoomMemberHandler(BaseHandler):
defer.returnValue(chunk_data)
@defer.inlineCallbacks
- def get_room_member(self, room_id, member_user_id, auth_user_id):
- """Retrieve a room member from a room.
-
- Args:
- room_id : The room the member is in.
- member_user_id : The member's user ID
- auth_user_id : The user ID of the user making this request.
- Returns:
- The room member, or None if this member does not exist.
- Raises:
- SynapseError if something goes wrong.
- """
- yield self.auth.check_joined_room(room_id, auth_user_id)
-
- member = yield self.store.get_room_member(user_id=member_user_id,
- room_id=room_id)
- defer.returnValue(member)
-
- @defer.inlineCallbacks
def change_membership(self, event, context, do_auth=True):
""" Change the membership status of a user in a room.
@@ -547,11 +529,19 @@ class RoomListHandler(BaseHandler):
@defer.inlineCallbacks
def get_public_room_list(self):
chunk = yield self.store.get_rooms(is_public=True)
- for room in chunk:
- joined_users = yield self.store.get_users_in_room(
- room_id=room["room_id"],
- )
- room["num_joined_members"] = len(joined_users)
+ results = yield defer.gatherResults(
+ [
+ self.store.get_users_in_room(
+ room_id=room["room_id"],
+ )
+ for room in chunk
+ ],
+ consumeErrors=True,
+ )
+
+ for i, room in enumerate(chunk):
+ room["num_joined_members"] = len(results[i])
+
# FIXME (erikj): START is no longer a valid value
defer.returnValue({"start": "START", "end": "END", "chunk": chunk})
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index c2762f92c7..c0b2bd7db0 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -223,6 +223,7 @@ class TypingNotificationEventSource(object):
def __init__(self, hs):
self.hs = hs
self._handler = None
+ self._room_member_handler = None
def handler(self):
# Avoid cyclic dependency in handler setup
@@ -230,6 +231,11 @@ class TypingNotificationEventSource(object):
self._handler = self.hs.get_handlers().typing_notification_handler
return self._handler
+ def room_member_handler(self):
+ if not self._room_member_handler:
+ self._room_member_handler = self.hs.get_handlers().room_member_handler
+ return self._room_member_handler
+
def _make_event_for(self, room_id):
typing = self.handler()._room_typing[room_id]
return {
@@ -240,19 +246,25 @@ class TypingNotificationEventSource(object):
},
}
+ @defer.inlineCallbacks
def get_new_events_for_user(self, user, from_key, limit):
from_key = int(from_key)
handler = self.handler()
+ joined_room_ids = (
+ yield self.room_member_handler().get_joined_rooms_for_user(user)
+ )
+
events = []
for room_id in handler._room_serials:
+ if room_id not in joined_room_ids:
+ continue
if handler._room_serials[room_id] <= from_key:
continue
- # TODO: check if user is in room
events.append(self._make_event_for(room_id))
- return (events, handler._latest_room_serial)
+ defer.returnValue((events, handler._latest_room_serial))
def get_current_key(self):
return self.handler()._latest_room_serial
diff --git a/synapse/http/client.py b/synapse/http/client.py
index 2ae1c4d3a4..e8a5dedab4 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -200,6 +200,8 @@ class CaptchaServerHttpClient(SimpleHttpClient):
"""
Separate HTTP client for talking to google's captcha servers
Only slightly special because accepts partial download responses
+
+ used only by c/s api v1
"""
@defer.inlineCallbacks
diff --git a/synapse/http/server.py b/synapse/http/server.py
index dee49b9e18..93ecbd7589 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -24,7 +24,7 @@ from syutil.jsonutil import (
encode_canonical_json, encode_pretty_printed_json
)
-from twisted.internet import defer, reactor
+from twisted.internet import defer
from twisted.web import server, resource
from twisted.web.server import NOT_DONE_YET
from twisted.web.util import redirectTo
@@ -51,16 +51,90 @@ response_timer = metrics.register_distribution(
labels=["method", "servlet"]
)
+_next_request_id = 0
+
+
+def request_handler(request_handler):
+ """Wraps a method that acts as a request handler with the necessary logging
+ and exception handling.
+
+ The method must have a signature of "handle_foo(self, request)". The
+ argument "self" must have "version_string" and "clock" attributes. The
+ argument "request" must be a twisted HTTP request.
+
+ The method must return a deferred. If the deferred succeeds we assume that
+ a response has been sent. If the deferred fails with a SynapseError we use
+ it to send a JSON response with the appropriate HTTP reponse code. If the
+ deferred fails with any other type of error we send a 500 reponse.
+
+ We insert a unique request-id into the logging context for this request and
+ log the response and duration for this request.
+ """
+
+ @defer.inlineCallbacks
+ def wrapped_request_handler(self, request):
+ global _next_request_id
+ request_id = "%s-%s" % (request.method, _next_request_id)
+ _next_request_id += 1
+ with LoggingContext(request_id) as request_context:
+ request_context.request = request_id
+ code = None
+ start = self.clock.time_msec()
+ try:
+ logger.info(
+ "Received request: %s %s",
+ request.method, request.path
+ )
+ yield request_handler(self, request)
+ code = request.code
+ except CodeMessageException as e:
+ code = e.code
+ if isinstance(e, SynapseError):
+ logger.info(
+ "%s SynapseError: %s - %s", request, code, e.msg
+ )
+ else:
+ logger.exception(e)
+ outgoing_responses_counter.inc(request.method, str(code))
+ respond_with_json(
+ request, code, cs_exception(e), send_cors=True,
+ pretty_print=_request_user_agent_is_curl(request),
+ version_string=self.version_string,
+ )
+ except:
+ code = 500
+ logger.exception(
+ "Failed handle request %s.%s on %r: %r",
+ request_handler.__module__,
+ request_handler.__name__,
+ self,
+ request
+ )
+ respond_with_json(
+ request,
+ 500,
+ {"error": "Internal server error"},
+ send_cors=True
+ )
+ finally:
+ code = str(code) if code else "-"
+ end = self.clock.time_msec()
+ logger.info(
+ "Processed request: %dms %s %s %s",
+ end-start, code, request.method, request.path
+ )
+ return wrapped_request_handler
+
class HttpServer(object):
""" Interface for registering callbacks on a HTTP server
"""
def register_path(self, method, path_pattern, callback):
- """ Register a callback that get's fired if we receive a http request
+ """ Register a callback that gets fired if we receive a http request
with the given method for a path that matches the given regex.
- If the regex contains groups these get's passed to the calback via
+ If the regex contains groups these gets passed to the calback via
an unpacked tuple.
Args:
@@ -79,6 +153,13 @@ class JsonResource(HttpServer, resource.Resource):
Resources.
Register callbacks via register_path()
+
+ Callbacks can return a tuple of status code and a dict in which case the
+ the dict will automatically be sent to the client as a JSON object.
+
+ The JsonResource is primarily intended for returning JSON, but callbacks
+ may send something other than JSON, they may do so by using the methods
+ on the request object and instead returning None.
"""
isLeaf = True
@@ -98,118 +179,60 @@ class JsonResource(HttpServer, resource.Resource):
self._PathEntry(path_pattern, callback)
)
- def start_listening(self, port):
- """ Registers the http server with the twisted reactor.
-
- Args:
- port (int): The port to listen on.
-
- """
- reactor.listenTCP(
- port,
- server.Site(self),
- interface=self.hs.config.bind_host
- )
-
- # Gets called by twisted
def render(self, request):
- """ This get's called by twisted every time someone sends us a request.
+ """ This gets called by twisted every time someone sends us a request.
"""
- self._async_render_with_logging_context(request)
+ self._async_render(request)
return server.NOT_DONE_YET
- _request_id = 0
-
- @defer.inlineCallbacks
- def _async_render_with_logging_context(self, request):
- request_id = "%s-%s" % (request.method, JsonResource._request_id)
- JsonResource._request_id += 1
- with LoggingContext(request_id) as request_context:
- request_context.request = request_id
- yield self._async_render(request)
-
+ @request_handler
@defer.inlineCallbacks
def _async_render(self, request):
- """ This get's called by twisted every time someone sends us a request.
+ """ This gets called from render() every time someone sends us a request.
This checks if anyone has registered a callback for that method and
path.
"""
- code = None
start = self.clock.time_msec()
- try:
- # Just say yes to OPTIONS.
- if request.method == "OPTIONS":
- self._send_response(request, 200, {})
- return
-
- # Loop through all the registered callbacks to check if the method
- # and path regex match
- for path_entry in self.path_regexs.get(request.method, []):
- m = path_entry.pattern.match(request.path)
- if not m:
- continue
-
- # We found a match! Trigger callback and then return the
- # returned response. We pass both the request and any
- # matched groups from the regex to the callback.
-
- callback = path_entry.callback
-
- servlet_instance = getattr(callback, "__self__", None)
- if servlet_instance is not None:
- servlet_classname = servlet_instance.__class__.__name__
- else:
- servlet_classname = "%r" % callback
- incoming_requests_counter.inc(request.method, servlet_classname)
-
- args = [
- urllib.unquote(u).decode("UTF-8") for u in m.groups()
- ]
-
- logger.info(
- "Received request: %s %s",
- request.method, request.path
- )
+ if request.method == "OPTIONS":
+ self._send_response(request, 200, {})
+ return
+ # Loop through all the registered callbacks to check if the method
+ # and path regex match
+ for path_entry in self.path_regexs.get(request.method, []):
+ m = path_entry.pattern.match(request.path)
+ if not m:
+ continue
+
+ # We found a match! Trigger callback and then return the
+ # returned response. We pass both the request and any
+ # matched groups from the regex to the callback.
+
+ callback = path_entry.callback
+
+ servlet_instance = getattr(callback, "__self__", None)
+ if servlet_instance is not None:
+ servlet_classname = servlet_instance.__class__.__name__
+ else:
+ servlet_classname = "%r" % callback
+ incoming_requests_counter.inc(request.method, servlet_classname)
- code, response = yield callback(request, *args)
+ args = [
+ urllib.unquote(u).decode("UTF-8") for u in m.groups()
+ ]
+ callback_return = yield callback(request, *args)
+ if callback_return is not None:
+ code, response = callback_return
self._send_response(request, code, response)
- response_timer.inc_by(
- self.clock.time_msec() - start, request.method, servlet_classname
- )
- return
-
- # Huh. No one wanted to handle that? Fiiiiiine. Send 400.
- raise UnrecognizedRequestError()
- except CodeMessageException as e:
- if isinstance(e, SynapseError):
- logger.info("%s SynapseError: %s - %s", request, e.code, e.msg)
- else:
- logger.exception(e)
-
- code = e.code
- self._send_response(
- request,
- code,
- cs_exception(e),
- response_code_message=e.response_code_message
+ response_timer.inc_by(
+ self.clock.time_msec() - start, request.method, servlet_classname
)
- except Exception as e:
- logger.exception(e)
- self._send_response(
- request,
- 500,
- {"error": "Internal server error"}
- )
- finally:
- code = str(code) if code else "-"
- end = self.clock.time_msec()
- logger.info(
- "Processed request: %dms %s %s %s",
- end-start, code, request.method, request.path
- )
+ return
+
+ # Huh. No one wanted to handle that? Fiiiiiine. Send 400.
+ raise UnrecognizedRequestError()
def _send_response(self, request, code, response_json_object,
response_code_message=None):
@@ -229,20 +252,10 @@ class JsonResource(HttpServer, resource.Resource):
request, code, response_json_object,
send_cors=True,
response_code_message=response_code_message,
- pretty_print=self._request_user_agent_is_curl,
+ pretty_print=_request_user_agent_is_curl(request),
version_string=self.version_string,
)
- @staticmethod
- def _request_user_agent_is_curl(request):
- user_agents = request.requestHeaders.getRawHeaders(
- "User-Agent", default=[]
- )
- for user_agent in user_agents:
- if "curl" in user_agent:
- return True
- return False
-
class RootRedirect(resource.Resource):
"""Redirects the root '/' path to another path."""
@@ -263,8 +276,8 @@ class RootRedirect(resource.Resource):
def respond_with_json(request, code, json_object, send_cors=False,
response_code_message=None, pretty_print=False,
version_string=""):
- if not pretty_print:
- json_bytes = encode_pretty_printed_json(json_object)
+ if pretty_print:
+ json_bytes = encode_pretty_printed_json(json_object) + "\n"
else:
json_bytes = encode_canonical_json(json_object)
@@ -304,3 +317,13 @@ def respond_with_json_bytes(request, code, json_bytes, send_cors=False,
request.write(json_bytes)
request.finish()
return NOT_DONE_YET
+
+
+def _request_user_agent_is_curl(request):
+ user_agents = request.requestHeaders.getRawHeaders(
+ "User-Agent", default=[]
+ )
+ for user_agent in user_agents:
+ if "curl" in user_agent:
+ return True
+ return False
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index 265559a3ea..9cda17fcf8 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -23,6 +23,61 @@ import logging
logger = logging.getLogger(__name__)
+def parse_integer(request, name, default=None, required=False):
+ if name in request.args:
+ try:
+ return int(request.args[name][0])
+ except:
+ message = "Query parameter %r must be an integer" % (name,)
+ raise SynapseError(400, message)
+ else:
+ if required:
+ message = "Missing integer query parameter %r" % (name,)
+ raise SynapseError(400, message)
+ else:
+ return default
+
+
+def parse_boolean(request, name, default=None, required=False):
+ if name in request.args:
+ try:
+ return {
+ "true": True,
+ "false": False,
+ }[request.args[name][0]]
+ except:
+ message = (
+ "Boolean query parameter %r must be one of"
+ " ['true', 'false']"
+ ) % (name,)
+ raise SynapseError(400, message)
+ else:
+ if required:
+ message = "Missing boolean query parameter %r" % (name,)
+ raise SynapseError(400, message)
+ else:
+ return default
+
+
+def parse_string(request, name, default=None, required=False,
+ allowed_values=None, param_type="string"):
+ if name in request.args:
+ value = request.args[name][0]
+ if allowed_values is not None and value not in allowed_values:
+ message = "Query parameter %r must be one of [%s]" % (
+ name, ", ".join(repr(v) for v in allowed_values)
+ )
+ raise SynapseError(message)
+ else:
+ return value
+ else:
+ if required:
+ message = "Missing %s query parameter %r" % (param_type, name)
+ raise SynapseError(400, message)
+ else:
+ return default
+
+
class RestServlet(object):
""" A Synapse REST Servlet.
@@ -56,58 +111,3 @@ class RestServlet(object):
http_server.register_path(method, pattern, method_handler)
else:
raise NotImplementedError("RestServlet must register something.")
-
- @staticmethod
- def parse_integer(request, name, default=None, required=False):
- if name in request.args:
- try:
- return int(request.args[name][0])
- except:
- message = "Query parameter %r must be an integer" % (name,)
- raise SynapseError(400, message)
- else:
- if required:
- message = "Missing integer query parameter %r" % (name,)
- raise SynapseError(400, message)
- else:
- return default
-
- @staticmethod
- def parse_boolean(request, name, default=None, required=False):
- if name in request.args:
- try:
- return {
- "true": True,
- "false": False,
- }[request.args[name][0]]
- except:
- message = (
- "Boolean query parameter %r must be one of"
- " ['true', 'false']"
- ) % (name,)
- raise SynapseError(400, message)
- else:
- if required:
- message = "Missing boolean query parameter %r" % (name,)
- raise SynapseError(400, message)
- else:
- return default
-
- @staticmethod
- def parse_string(request, name, default=None, required=False,
- allowed_values=None, param_type="string"):
- if name in request.args:
- value = request.args[name][0]
- if allowed_values is not None and value not in allowed_values:
- message = "Query parameter %r must be one of [%s]" % (
- name, ", ".join(repr(v) for v in allowed_values)
- )
- raise SynapseError(message)
- else:
- return value
- else:
- if required:
- message = "Missing %s query parameter %r" % (param_type, name)
- raise SynapseError(400, message)
- else:
- return default
diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py
index dffb8a4861..9233ea3da9 100644
--- a/synapse/metrics/__init__.py
+++ b/synapse/metrics/__init__.py
@@ -18,6 +18,8 @@ from __future__ import absolute_import
import logging
from resource import getrusage, getpagesize, RUSAGE_SELF
+import os
+import stat
from .metric import (
CounterMetric, CallbackMetric, DistributionMetric, CacheMetric
@@ -109,3 +111,36 @@ resource_metrics.register_callback("stime", lambda: rusage.ru_stime * 1000)
# pages
resource_metrics.register_callback("maxrss", lambda: rusage.ru_maxrss * PAGE_SIZE)
+
+TYPES = {
+ stat.S_IFSOCK: "SOCK",
+ stat.S_IFLNK: "LNK",
+ stat.S_IFREG: "REG",
+ stat.S_IFBLK: "BLK",
+ stat.S_IFDIR: "DIR",
+ stat.S_IFCHR: "CHR",
+ stat.S_IFIFO: "FIFO",
+}
+
+
+def _process_fds():
+ counts = {(k,): 0 for k in TYPES.values()}
+ counts[("other",)] = 0
+
+ for fd in os.listdir("/proc/self/fd"):
+ try:
+ s = os.stat("/proc/self/fd/%s" % (fd))
+ fmt = stat.S_IFMT(s.st_mode)
+ if fmt in TYPES:
+ t = TYPES[fmt]
+ else:
+ t = "other"
+
+ counts[(t,)] += 1
+ except OSError:
+ # the dirh itself used by listdir() is usually missing by now
+ pass
+
+ return counts
+
+get_metrics_for("process").register_callback("fds", _process_fds, labels=["type"])
diff --git a/synapse/notifier.py b/synapse/notifier.py
index 7121d659d0..78eb28e4b2 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -17,7 +17,6 @@ from twisted.internet import defer
from synapse.util.logutils import log_function
from synapse.util.logcontext import PreserveLoggingContext
-from synapse.util.async import run_on_reactor
from synapse.types import StreamToken
import synapse.metrics
@@ -59,10 +58,11 @@ class _NotificationListener(object):
self.limit = limit
self.timeout = timeout
self.deferred = deferred
-
self.rooms = rooms
+ self.timer = None
- self.pending_notifications = []
+ def notified(self):
+ return self.deferred.called
def notify(self, notifier, events, start_token, end_token):
""" Inform whoever is listening about the new events. This will
@@ -78,16 +78,27 @@ class _NotificationListener(object):
except defer.AlreadyCalledError:
pass
+ # Should the following be done be using intrusively linked lists?
+ # -- erikj
+
for room in self.rooms:
lst = notifier.room_to_listeners.get(room, set())
lst.discard(self)
notifier.user_to_listeners.get(self.user, set()).discard(self)
+
if self.appservice:
notifier.appservice_to_listeners.get(
self.appservice, set()
).discard(self)
+ # Cancel the timeout for this notifer if one exists.
+ if self.timer is not None:
+ try:
+ notifier.clock.cancel_call_later(self.timer)
+ except:
+ logger.warn("Failed to cancel notifier timer")
+
class Notifier(object):
""" This class is responsible for notifying any listeners when there are
@@ -150,8 +161,6 @@ class Notifier(object):
listening to the room, and any listeners for the users in the
`extra_users` param.
"""
- yield run_on_reactor()
-
# poke any interested application service.
self.hs.get_handlers().appservice_handler.notify_interested_services(
event
@@ -161,10 +170,18 @@ class Notifier(object):
room_source = self.event_sources.sources["room"]
- listeners = self.room_to_listeners.get(room_id, set()).copy()
+ room_listeners = self.room_to_listeners.get(room_id, set())
+
+ _discard_if_notified(room_listeners)
+
+ listeners = room_listeners.copy()
for user in extra_users:
- listeners |= self.user_to_listeners.get(user, set()).copy()
+ user_listeners = self.user_to_listeners.get(user, set())
+
+ _discard_if_notified(user_listeners)
+
+ listeners |= user_listeners
for appservice in self.appservice_to_listeners:
# TODO (kegan): Redundant appservice listener checks?
@@ -173,9 +190,13 @@ class Notifier(object):
# receive *invites* for users they are interested in. Does this
# make the room_to_listeners check somewhat obselete?
if appservice.is_interested(event):
- listeners |= self.appservice_to_listeners.get(
+ app_listeners = self.appservice_to_listeners.get(
appservice, set()
- ).copy()
+ )
+
+ _discard_if_notified(app_listeners)
+
+ listeners |= app_listeners
logger.debug("on_new_room_event listeners %s", listeners)
@@ -216,8 +237,6 @@ class Notifier(object):
Will wake up all listeners for the given users and rooms.
"""
- yield run_on_reactor()
-
# TODO(paul): This is horrible, having to manually list every event
# source here individually
presence_source = self.event_sources.sources["presence"]
@@ -226,10 +245,18 @@ class Notifier(object):
listeners = set()
for user in users:
- listeners |= self.user_to_listeners.get(user, set()).copy()
+ user_listeners = self.user_to_listeners.get(user, set())
+
+ _discard_if_notified(user_listeners)
+
+ listeners |= user_listeners
for room in rooms:
- listeners |= self.room_to_listeners.get(room, set()).copy()
+ room_listeners = self.room_to_listeners.get(room, set())
+
+ _discard_if_notified(room_listeners)
+
+ listeners |= room_listeners
@defer.inlineCallbacks
def notify(listener):
@@ -300,14 +327,20 @@ class Notifier(object):
self._register_with_keys(listener[0])
result = yield callback()
+ timer = [None]
+
if timeout:
timed_out = [False]
def _timeout_listener():
timed_out[0] = True
+ timer[0] = None
listener[0].notify(self, [], from_token, from_token)
- self.clock.call_later(timeout/1000., _timeout_listener)
+ # We create multiple notification listeners so we have to manage
+ # canceling the timeout ourselves.
+ timer[0] = self.clock.call_later(timeout/1000., _timeout_listener)
+
while not result and not timed_out[0]:
yield deferred
deferred = defer.Deferred()
@@ -322,6 +355,12 @@ class Notifier(object):
self._register_with_keys(listener[0])
result = yield callback()
+ if timer[0] is not None:
+ try:
+ self.clock.cancel_call_later(timer[0])
+ except:
+ logger.exception("Failed to cancel notifer timer")
+
defer.returnValue(result)
def get_events_for(self, user, rooms, pagination_config, timeout):
@@ -360,6 +399,8 @@ class Notifier(object):
def _timeout_listener():
# TODO (erikj): We should probably set to_token to the current
# max rather than reusing from_token.
+ # Remove the timer from the listener so we don't try to cancel it.
+ listener.timer = None
listener.notify(
self,
[],
@@ -375,8 +416,11 @@ class Notifier(object):
if not timeout:
_timeout_listener()
else:
- self.clock.call_later(timeout/1000.0, _timeout_listener)
-
+ # Only add the timer if the listener hasn't been notified
+ if not listener.notified():
+ listener.timer = self.clock.call_later(
+ timeout/1000.0, _timeout_listener
+ )
return
@log_function
@@ -427,3 +471,17 @@ class Notifier(object):
listeners = self.room_to_listeners.setdefault(room_id, set())
listeners |= new_listeners
+
+ for l in new_listeners:
+ l.rooms.add(room_id)
+
+
+def _discard_if_notified(listener_set):
+ """Remove any 'stale' listeners from the given set.
+ """
+ to_discard = set()
+ for l in listener_set:
+ if l.notified():
+ to_discard.add(l)
+
+ listener_set -= to_discard
diff --git a/synapse/push/__init__.py b/synapse/push/__init__.py
index 0727f772a5..5575c847f9 100644
--- a/synapse/push/__init__.py
+++ b/synapse/push/__init__.py
@@ -253,7 +253,8 @@ class Pusher(object):
self.user_name, config, timeout=0)
self.last_token = chunk['end']
self.store.update_pusher_last_token(
- self.app_id, self.pushkey, self.last_token)
+ self.app_id, self.pushkey, self.user_name, self.last_token
+ )
logger.info("Pusher %s for user %s starting from token %s",
self.pushkey, self.user_name, self.last_token)
@@ -314,7 +315,7 @@ class Pusher(object):
pk
)
yield self.hs.get_pusherpool().remove_pusher(
- self.app_id, pk
+ self.app_id, pk, self.user_name
)
if not self.alive:
@@ -326,6 +327,7 @@ class Pusher(object):
self.store.update_pusher_last_token_and_success(
self.app_id,
self.pushkey,
+ self.user_name,
self.last_token,
self.clock.time_msec()
)
@@ -334,6 +336,7 @@ class Pusher(object):
self.store.update_pusher_failing_since(
self.app_id,
self.pushkey,
+ self.user_name,
self.failing_since)
else:
if not self.failing_since:
@@ -341,6 +344,7 @@ class Pusher(object):
self.store.update_pusher_failing_since(
self.app_id,
self.pushkey,
+ self.user_name,
self.failing_since
)
@@ -358,6 +362,7 @@ class Pusher(object):
self.store.update_pusher_last_token(
self.app_id,
self.pushkey,
+ self.user_name,
self.last_token
)
@@ -365,6 +370,7 @@ class Pusher(object):
self.store.update_pusher_failing_since(
self.app_id,
self.pushkey,
+ self.user_name,
self.failing_since
)
else:
diff --git a/synapse/push/baserules.py b/synapse/push/baserules.py
index 60fd35fbfb..f3d1cf5c5f 100644
--- a/synapse/push/baserules.py
+++ b/synapse/push/baserules.py
@@ -1,3 +1,17 @@
+# Copyright 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from synapse.push.rulekinds import PRIORITY_CLASS_MAP, PRIORITY_CLASS_INVERSE_MAP
@@ -112,7 +126,25 @@ def make_base_prepend_override_rules():
def make_base_append_override_rules():
return [
{
- 'rule_id': 'global/override/.m.rule.call',
+ 'rule_id': 'global/override/.m.rule.suppress_notices',
+ 'conditions': [
+ {
+ 'kind': 'event_match',
+ 'key': 'content.msgtype',
+ 'pattern': 'm.notice',
+ }
+ ],
+ 'actions': [
+ 'dont_notify',
+ ]
+ }
+ ]
+
+
+def make_base_append_underride_rules(user):
+ return [
+ {
+ 'rule_id': 'global/underride/.m.rule.call',
'conditions': [
{
'kind': 'event_match',
@@ -132,19 +164,6 @@ def make_base_append_override_rules():
]
},
{
- 'rule_id': 'global/override/.m.rule.suppress_notices',
- 'conditions': [
- {
- 'kind': 'event_match',
- 'key': 'content.msgtype',
- 'pattern': 'm.notice',
- }
- ],
- 'actions': [
- 'dont_notify',
- ]
- },
- {
'rule_id': 'global/override/.m.rule.contains_display_name',
'conditions': [
{
@@ -162,7 +181,7 @@ def make_base_append_override_rules():
]
},
{
- 'rule_id': 'global/override/.m.rule.room_one_to_one',
+ 'rule_id': 'global/underride/.m.rule.room_one_to_one',
'conditions': [
{
'kind': 'room_member_count',
@@ -179,12 +198,7 @@ def make_base_append_override_rules():
'value': False
}
]
- }
- ]
-
-
-def make_base_append_underride_rules(user):
- return [
+ },
{
'rule_id': 'global/underride/.m.rule.invite_for_me',
'conditions': [
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index 90babd7224..0ab2f65972 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -19,10 +19,7 @@ from twisted.internet import defer
from httppusher import HttpPusher
from synapse.push import PusherConfigException
-from syutil.jsonutil import encode_canonical_json
-
import logging
-import simplejson as json
logger = logging.getLogger(__name__)
@@ -52,12 +49,10 @@ class PusherPool:
@defer.inlineCallbacks
def start(self):
pushers = yield self.store.get_all_pushers()
- for p in pushers:
- p['data'] = json.loads(p['data'])
self._start_pushers(pushers)
@defer.inlineCallbacks
- def add_pusher(self, user_name, profile_tag, kind, app_id,
+ def add_pusher(self, user_name, access_token, profile_tag, kind, app_id,
app_display_name, device_display_name, pushkey, lang, data):
# we try to create the pusher just to validate the config: it
# will then get pulled out of the database,
@@ -71,7 +66,7 @@ class PusherPool:
"app_display_name": app_display_name,
"device_display_name": device_display_name,
"pushkey": pushkey,
- "pushkey_ts": self.hs.get_clock().time_msec(),
+ "ts": self.hs.get_clock().time_msec(),
"lang": lang,
"data": data,
"last_token": None,
@@ -79,17 +74,50 @@ class PusherPool:
"failing_since": None
})
yield self._add_pusher_to_store(
- user_name, profile_tag, kind, app_id,
+ user_name, access_token, profile_tag, kind, app_id,
app_display_name, device_display_name,
pushkey, lang, data
)
@defer.inlineCallbacks
- def _add_pusher_to_store(self, user_name, profile_tag, kind, app_id,
- app_display_name, device_display_name,
+ def remove_pushers_by_app_id_and_pushkey_not_user(self, app_id, pushkey,
+ not_user_id):
+ to_remove = yield self.store.get_pushers_by_app_id_and_pushkey(
+ app_id, pushkey
+ )
+ for p in to_remove:
+ if p['user_name'] != not_user_id:
+ logger.info(
+ "Removing pusher for app id %s, pushkey %s, user %s",
+ app_id, pushkey, p['user_name']
+ )
+ self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])
+
+ @defer.inlineCallbacks
+ def remove_pushers_by_user_access_token(self, user_id, not_access_token_id):
+ all = yield self.store.get_all_pushers()
+ logger.info(
+ "Removing all pushers for user %s except access token %s",
+ user_id, not_access_token_id
+ )
+ for p in all:
+ if (
+ p['user_name'] == user_id and
+ p['access_token'] != not_access_token_id
+ ):
+ logger.info(
+ "Removing pusher for app id %s, pushkey %s, user %s",
+ p['app_id'], p['pushkey'], p['user_name']
+ )
+ self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])
+
+ @defer.inlineCallbacks
+ def _add_pusher_to_store(self, user_name, access_token, profile_tag, kind,
+ app_id, app_display_name, device_display_name,
pushkey, lang, data):
yield self.store.add_pusher(
user_name=user_name,
+ access_token=access_token,
profile_tag=profile_tag,
kind=kind,
app_id=app_id,
@@ -98,9 +126,9 @@ class PusherPool:
pushkey=pushkey,
pushkey_ts=self.hs.get_clock().time_msec(),
lang=lang,
- data=encode_canonical_json(data).decode("UTF-8"),
+ data=data,
)
- self._refresh_pusher((app_id, pushkey))
+ self._refresh_pusher(app_id, pushkey, user_name)
def _create_pusher(self, pusherdict):
if pusherdict['kind'] == 'http':
@@ -112,7 +140,7 @@ class PusherPool:
app_display_name=pusherdict['app_display_name'],
device_display_name=pusherdict['device_display_name'],
pushkey=pusherdict['pushkey'],
- pushkey_ts=pusherdict['pushkey_ts'],
+ pushkey_ts=pusherdict['ts'],
data=pusherdict['data'],
last_token=pusherdict['last_token'],
last_success=pusherdict['last_success'],
@@ -125,30 +153,48 @@ class PusherPool:
)
@defer.inlineCallbacks
- def _refresh_pusher(self, app_id_pushkey):
- p = yield self.store.get_pushers_by_app_id_and_pushkey(
- app_id_pushkey
+ def _refresh_pusher(self, app_id, pushkey, user_name):
+ resultlist = yield self.store.get_pushers_by_app_id_and_pushkey(
+ app_id, pushkey
)
- p['data'] = json.loads(p['data'])
- self._start_pushers([p])
+ p = None
+ for r in resultlist:
+ if r['user_name'] == user_name:
+ p = r
+
+ if p:
+
+ self._start_pushers([p])
def _start_pushers(self, pushers):
logger.info("Starting %d pushers", len(pushers))
for pusherdict in pushers:
- p = self._create_pusher(pusherdict)
+ try:
+ p = self._create_pusher(pusherdict)
+ except PusherConfigException:
+ logger.exception("Couldn't start a pusher: caught PusherConfigException")
+ continue
if p:
- fullid = "%s:%s" % (pusherdict['app_id'], pusherdict['pushkey'])
+ fullid = "%s:%s:%s" % (
+ pusherdict['app_id'],
+ pusherdict['pushkey'],
+ pusherdict['user_name']
+ )
if fullid in self.pushers:
self.pushers[fullid].stop()
self.pushers[fullid] = p
p.start()
+ logger.info("Started pushers")
+
@defer.inlineCallbacks
- def remove_pusher(self, app_id, pushkey):
- fullid = "%s:%s" % (app_id, pushkey)
+ def remove_pusher(self, app_id, pushkey, user_name):
+ fullid = "%s:%s:%s" % (app_id, pushkey, user_name)
if fullid in self.pushers:
logger.info("Stopping pusher %s", fullid)
self.pushers[fullid].stop()
del self.pushers[fullid]
- yield self.store.delete_pusher_by_app_id_pushkey(app_id, pushkey)
+ yield self.store.delete_pusher_by_app_id_pushkey_user_name(
+ app_id, pushkey, user_name
+ )
diff --git a/synapse/push/rulekinds.py b/synapse/push/rulekinds.py
index 660aa4e10e..4c591aa638 100644
--- a/synapse/push/rulekinds.py
+++ b/synapse/push/rulekinds.py
@@ -1,3 +1,17 @@
+# Copyright 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
PRIORITY_CLASS_MAP = {
'underride': 1,
'sender': 2,
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index 6b6d5508b8..b1baad81c4 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -1,10 +1,24 @@
+# Copyright 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import logging
from distutils.version import LooseVersion
logger = logging.getLogger(__name__)
REQUIREMENTS = {
- "syutil>=0.0.3": ["syutil"],
+ "syutil>=0.0.6": ["syutil>=0.0.6"],
"Twisted==14.0.2": ["twisted==14.0.2"],
"service_identity>=1.0.0": ["service_identity>=1.0.0"],
"pyopenssl>=0.14": ["OpenSSL>=0.14"],
@@ -19,7 +33,7 @@ REQUIREMENTS = {
}
CONDITIONAL_REQUIREMENTS = {
"web_client": {
- "matrix_angular_sdk>=0.6.5": ["syweb>=0.6.5"],
+ "matrix_angular_sdk>=0.6.6": ["syweb>=0.6.6"],
}
}
@@ -43,13 +57,13 @@ DEPENDENCY_LINKS = [
),
github_link(
project="matrix-org/syutil",
- version="v0.0.3",
- egg="syutil-0.0.3",
+ version="v0.0.6",
+ egg="syutil-0.0.6",
),
github_link(
project="matrix-org/matrix-angular-sdk",
- version="v0.6.5",
- egg="matrix_angular_sdk-0.6.5",
+ version="v0.6.6",
+ egg="matrix_angular_sdk-0.6.6",
),
]
diff --git a/synapse/rest/appservice/v1/base.py b/synapse/rest/appservice/v1/base.py
deleted file mode 100644
index 65d5bcf9be..0000000000
--- a/synapse/rest/appservice/v1/base.py
+++ /dev/null
@@ -1,48 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2015 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""This module contains base REST classes for constructing client v1 servlets.
-"""
-
-from synapse.http.servlet import RestServlet
-from synapse.api.urls import APP_SERVICE_PREFIX
-import re
-
-import logging
-
-
-logger = logging.getLogger(__name__)
-
-
-def as_path_pattern(path_regex):
- """Creates a regex compiled appservice path with the correct path
- prefix.
-
- Args:
- path_regex (str): The regex string to match. This should NOT have a ^
- as this will be prefixed.
- Returns:
- SRE_Pattern
- """
- return re.compile("^" + APP_SERVICE_PREFIX + path_regex)
-
-
-class AppServiceRestServlet(RestServlet):
- """A base Synapse REST Servlet for the application services version 1 API.
- """
-
- def __init__(self, hs):
- self.hs = hs
- self.handler = hs.get_handlers().appservice_handler
diff --git a/synapse/rest/appservice/v1/register.py b/synapse/rest/appservice/v1/register.py
deleted file mode 100644
index ea24d88f79..0000000000
--- a/synapse/rest/appservice/v1/register.py
+++ /dev/null
@@ -1,99 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2015 OpenMarket Ltd
-#
-# Licensensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""This module contains REST servlets to do with registration: /register"""
-from twisted.internet import defer
-
-from base import AppServiceRestServlet, as_path_pattern
-from synapse.api.errors import CodeMessageException, SynapseError
-from synapse.storage.appservice import ApplicationService
-
-import json
-import logging
-
-logger = logging.getLogger(__name__)
-
-
-class RegisterRestServlet(AppServiceRestServlet):
- """Handles AS registration with the home server.
- """
-
- PATTERN = as_path_pattern("/register$")
-
- @defer.inlineCallbacks
- def on_POST(self, request):
- params = _parse_json(request)
-
- # sanity check required params
- try:
- as_token = params["as_token"]
- as_url = params["url"]
- if (not isinstance(as_token, basestring) or
- not isinstance(as_url, basestring)):
- raise ValueError
- except (KeyError, ValueError):
- raise SynapseError(
- 400, "Missed required keys: as_token(str) / url(str)."
- )
-
- try:
- app_service = ApplicationService(
- as_token, as_url, params["namespaces"]
- )
- except ValueError as e:
- raise SynapseError(400, e.message)
-
- app_service = yield self.handler.register(app_service)
- hs_token = app_service.hs_token
-
- defer.returnValue((200, {
- "hs_token": hs_token
- }))
-
-
-class UnregisterRestServlet(AppServiceRestServlet):
- """Handles AS registration with the home server.
- """
-
- PATTERN = as_path_pattern("/unregister$")
-
- def on_POST(self, request):
- params = _parse_json(request)
- try:
- as_token = params["as_token"]
- if not isinstance(as_token, basestring):
- raise ValueError
- except (KeyError, ValueError):
- raise SynapseError(400, "Missing required key: as_token(str)")
-
- yield self.handler.unregister(as_token)
-
- raise CodeMessageException(500, "Not implemented")
-
-
-def _parse_json(request):
- try:
- content = json.loads(request.content.read())
- if type(content) != dict:
- raise SynapseError(400, "Content must be a JSON object.")
- return content
- except ValueError as e:
- logger.warn(e)
- raise SynapseError(400, "Content not JSON.")
-
-
-def register_servlets(hs, http_server):
- RegisterRestServlet(hs).register(http_server)
- UnregisterRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v1/base.py b/synapse/rest/client/v1/base.py
index 72332bdb10..504a5e432f 100644
--- a/synapse/rest/client/v1/base.py
+++ b/synapse/rest/client/v1/base.py
@@ -48,5 +48,5 @@ class ClientV1RestServlet(RestServlet):
self.hs = hs
self.handlers = hs.get_handlers()
self.builder_factory = hs.get_event_builder_factory()
- self.auth = hs.get_auth()
+ self.auth = hs.get_v1auth()
self.txns = HttpTransactionStore()
diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py
index 6045e86f34..c83287c028 100644
--- a/synapse/rest/client/v1/pusher.py
+++ b/synapse/rest/client/v1/pusher.py
@@ -27,7 +27,7 @@ class PusherRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_POST(self, request):
- user, _ = yield self.auth.get_user_by_req(request)
+ user, client = yield self.auth.get_user_by_req(request)
content = _parse_json(request)
@@ -37,7 +37,7 @@ class PusherRestServlet(ClientV1RestServlet):
and 'kind' in content and
content['kind'] is None):
yield pusher_pool.remove_pusher(
- content['app_id'], content['pushkey']
+ content['app_id'], content['pushkey'], user_name=user.to_string()
)
defer.returnValue((200, {}))
@@ -51,9 +51,21 @@ class PusherRestServlet(ClientV1RestServlet):
raise SynapseError(400, "Missing parameters: "+','.join(missing),
errcode=Codes.MISSING_PARAM)
+ append = False
+ if 'append' in content:
+ append = content['append']
+
+ if not append:
+ yield pusher_pool.remove_pushers_by_app_id_and_pushkey_not_user(
+ app_id=content['app_id'],
+ pushkey=content['pushkey'],
+ not_user_id=user.to_string()
+ )
+
try:
yield pusher_pool.add_pusher(
user_name=user.to_string(),
+ access_token=client.token_id,
profile_tag=content['profile_tag'],
kind=content['kind'],
app_id=content['app_id'],
diff --git a/synapse/rest/client/v2_alpha/__init__.py b/synapse/rest/client/v2_alpha/__init__.py
index bca65f2a6a..28d95b2729 100644
--- a/synapse/rest/client/v2_alpha/__init__.py
+++ b/synapse/rest/client/v2_alpha/__init__.py
@@ -15,7 +15,10 @@
from . import (
sync,
- filter
+ filter,
+ account,
+ register,
+ auth
)
from synapse.http.server import JsonResource
@@ -32,3 +35,6 @@ class ClientV2AlphaRestResource(JsonResource):
def register_servlets(client_resource, hs):
sync.register_servlets(hs, client_resource)
filter.register_servlets(hs, client_resource)
+ account.register_servlets(hs, client_resource)
+ register.register_servlets(hs, client_resource)
+ auth.register_servlets(hs, client_resource)
diff --git a/synapse/rest/client/v2_alpha/_base.py b/synapse/rest/client/v2_alpha/_base.py
index 22dc5cb862..4540e8dcf7 100644
--- a/synapse/rest/client/v2_alpha/_base.py
+++ b/synapse/rest/client/v2_alpha/_base.py
@@ -17,9 +17,11 @@
"""
from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX
+from synapse.api.errors import SynapseError
import re
import logging
+import simplejson
logger = logging.getLogger(__name__)
@@ -36,3 +38,23 @@ def client_v2_pattern(path_regex):
SRE_Pattern
"""
return re.compile("^" + CLIENT_V2_ALPHA_PREFIX + path_regex)
+
+
+def parse_request_allow_empty(request):
+ content = request.content.read()
+ if content is None or content == '':
+ return None
+ try:
+ return simplejson.loads(content)
+ except simplejson.JSONDecodeError:
+ raise SynapseError(400, "Content not JSON.")
+
+
+def parse_json_dict_from_request(request):
+ try:
+ content = simplejson.loads(request.content.read())
+ if type(content) != dict:
+ raise SynapseError(400, "Content must be a JSON object.")
+ return content
+ except simplejson.JSONDecodeError:
+ raise SynapseError(400, "Content not JSON.")
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
new file mode 100644
index 0000000000..b082140f1f
--- /dev/null
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -0,0 +1,159 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+from synapse.api.constants import LoginType
+from synapse.api.errors import LoginError, SynapseError, Codes
+from synapse.http.servlet import RestServlet
+from synapse.util.async import run_on_reactor
+
+from ._base import client_v2_pattern, parse_json_dict_from_request
+
+import logging
+
+
+logger = logging.getLogger(__name__)
+
+
+class PasswordRestServlet(RestServlet):
+ PATTERN = client_v2_pattern("/account/password")
+
+ def __init__(self, hs):
+ super(PasswordRestServlet, self).__init__()
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.auth_handler = hs.get_handlers().auth_handler
+ self.login_handler = hs.get_handlers().login_handler
+
+ @defer.inlineCallbacks
+ def on_POST(self, request):
+ yield run_on_reactor()
+
+ body = parse_json_dict_from_request(request)
+
+ authed, result, params = yield self.auth_handler.check_auth([
+ [LoginType.PASSWORD],
+ [LoginType.EMAIL_IDENTITY]
+ ], body)
+
+ if not authed:
+ defer.returnValue((401, result))
+
+ user_id = None
+
+ if LoginType.PASSWORD in result:
+ # if using password, they should also be logged in
+ auth_user, client = yield self.auth.get_user_by_req(request)
+ if auth_user.to_string() != result[LoginType.PASSWORD]:
+ raise LoginError(400, "", Codes.UNKNOWN)
+ user_id = auth_user.to_string()
+ elif LoginType.EMAIL_IDENTITY in result:
+ threepid = result[LoginType.EMAIL_IDENTITY]
+ if 'medium' not in threepid or 'address' not in threepid:
+ raise SynapseError(500, "Malformed threepid")
+ # if using email, we must know about the email they're authing with!
+ threepid_user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
+ threepid['medium'], threepid['address']
+ )
+ if not threepid_user_id:
+ raise SynapseError(404, "Email address not found", Codes.NOT_FOUND)
+ user_id = threepid_user_id
+ else:
+ logger.error("Auth succeeded but no known type!", result.keys())
+ raise SynapseError(500, "", Codes.UNKNOWN)
+
+ if 'new_password' not in params:
+ raise SynapseError(400, "", Codes.MISSING_PARAM)
+ new_password = params['new_password']
+
+ yield self.login_handler.set_password(
+ user_id, new_password, None
+ )
+
+ defer.returnValue((200, {}))
+
+ def on_OPTIONS(self, _):
+ return 200, {}
+
+
+class ThreepidRestServlet(RestServlet):
+ PATTERN = client_v2_pattern("/account/3pid")
+
+ def __init__(self, hs):
+ super(ThreepidRestServlet, self).__init__()
+ self.hs = hs
+ self.login_handler = hs.get_handlers().login_handler
+ self.identity_handler = hs.get_handlers().identity_handler
+ self.auth = hs.get_auth()
+
+ @defer.inlineCallbacks
+ def on_GET(self, request):
+ yield run_on_reactor()
+
+ auth_user, _ = yield self.auth.get_user_by_req(request)
+
+ threepids = yield self.hs.get_datastore().user_get_threepids(
+ auth_user.to_string()
+ )
+
+ defer.returnValue((200, {'threepids': threepids}))
+
+ @defer.inlineCallbacks
+ def on_POST(self, request):
+ yield run_on_reactor()
+
+ body = parse_json_dict_from_request(request)
+
+ if 'threePidCreds' not in body:
+ raise SynapseError(400, "Missing param", Codes.MISSING_PARAM)
+ threePidCreds = body['threePidCreds']
+
+ auth_user, client = yield self.auth.get_user_by_req(request)
+
+ threepid = yield self.identity_handler.threepid_from_creds(threePidCreds)
+
+ if not threepid:
+ raise SynapseError(
+ 400, "Failed to auth 3pid", Codes.THREEPID_AUTH_FAILED
+ )
+
+ for reqd in ['medium', 'address', 'validated_at']:
+ if reqd not in threepid:
+ logger.warn("Couldn't add 3pid: invalid response from ID sevrer")
+ raise SynapseError(500, "Invalid response from ID Server")
+
+ yield self.login_handler.add_threepid(
+ auth_user.to_string(),
+ threepid['medium'],
+ threepid['address'],
+ threepid['validated_at'],
+ )
+
+ if 'bind' in body and body['bind']:
+ logger.debug(
+ "Binding emails %s to %s",
+ threepid, auth_user.to_string()
+ )
+ yield self.identity_handler.bind_threepid(
+ threePidCreds, auth_user.to_string()
+ )
+
+ defer.returnValue((200, {}))
+
+
+def register_servlets(hs, http_server):
+ PasswordRestServlet(hs).register(http_server)
+ ThreepidRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py
new file mode 100644
index 0000000000..4c726f05f5
--- /dev/null
+++ b/synapse/rest/client/v2_alpha/auth.py
@@ -0,0 +1,190 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+from synapse.api.constants import LoginType
+from synapse.api.errors import SynapseError
+from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX
+from synapse.http.servlet import RestServlet
+
+from ._base import client_v2_pattern
+
+import logging
+
+
+logger = logging.getLogger(__name__)
+
+RECAPTCHA_TEMPLATE = """
+<html>
+<head>
+<title>Authentication</title>
+<meta name='viewport' content='width=device-width, initial-scale=1,
+ user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'>
+<script src="https://www.google.com/recaptcha/api.js"
+ async defer></script>
+<script src="//code.jquery.com/jquery-1.11.2.min.js"></script>
+<link rel="stylesheet" href="/_matrix/static/client/register/style.css">
+<script>
+function captchaDone() {
+ $('#registrationForm').submit();
+}
+</script>
+</head>
+<body>
+<form id="registrationForm" method="post" action="%(myurl)s">
+ <div>
+ <p>
+ Hello! We need to prevent computer programs and other automated
+ things from creating accounts on this server.
+ </p>
+ <p>
+ Please verify that you're not a robot.
+ </p>
+ <input type="hidden" name="session" value="%(session)s" />
+ <div class="g-recaptcha"
+ data-sitekey="%(sitekey)s"
+ data-callback="captchaDone">
+ </div>
+ <noscript>
+ <input type="submit" value="All Done" />
+ </noscript>
+ </div>
+ </div>
+</form>
+</body>
+</html>
+"""
+
+SUCCESS_TEMPLATE = """
+<html>
+<head>
+<title>Success!</title>
+<meta name='viewport' content='width=device-width, initial-scale=1,
+ user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'>
+<link rel="stylesheet" href="/_matrix/static/client/register/style.css">
+<script>
+if (window.onAuthDone != undefined) {
+ window.onAuthDone();
+}
+</script>
+</head>
+<body>
+ <div>
+ <p>Thank you</p>
+ <p>You may now close this window and return to the application</p>
+ </div>
+</body>
+</html>
+"""
+
+
+class AuthRestServlet(RestServlet):
+ """
+ Handles Client / Server API authentication in any situations where it
+ cannot be handled in the normal flow (with requests to the same endpoint).
+ Current use is for web fallback auth.
+ """
+ PATTERN = client_v2_pattern("/auth/(?P<stagetype>[\w\.]*)/fallback/web")
+
+ def __init__(self, hs):
+ super(AuthRestServlet, self).__init__()
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.auth_handler = hs.get_handlers().auth_handler
+ self.registration_handler = hs.get_handlers().registration_handler
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, stagetype):
+ yield
+ if stagetype == LoginType.RECAPTCHA:
+ if ('session' not in request.args or
+ len(request.args['session']) == 0):
+ raise SynapseError(400, "No session supplied")
+
+ session = request.args["session"][0]
+
+ html = RECAPTCHA_TEMPLATE % {
+ 'session': session,
+ 'myurl': "%s/auth/%s/fallback/web" % (
+ CLIENT_V2_ALPHA_PREFIX, LoginType.RECAPTCHA
+ ),
+ 'sitekey': self.hs.config.recaptcha_public_key,
+ }
+ html_bytes = html.encode("utf8")
+ request.setResponseCode(200)
+ request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
+ request.setHeader(b"Server", self.hs.version_string)
+ request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),))
+
+ request.write(html_bytes)
+ request.finish()
+ defer.returnValue(None)
+ else:
+ raise SynapseError(404, "Unknown auth stage type")
+
+ @defer.inlineCallbacks
+ def on_POST(self, request, stagetype):
+ yield
+ if stagetype == "m.login.recaptcha":
+ if ('g-recaptcha-response' not in request.args or
+ len(request.args['g-recaptcha-response'])) == 0:
+ raise SynapseError(400, "No captcha response supplied")
+ if ('session' not in request.args or
+ len(request.args['session'])) == 0:
+ raise SynapseError(400, "No session supplied")
+
+ session = request.args['session'][0]
+
+ authdict = {
+ 'response': request.args['g-recaptcha-response'][0],
+ 'session': session,
+ }
+
+ success = yield self.auth_handler.add_oob_auth(
+ LoginType.RECAPTCHA,
+ authdict,
+ self.hs.get_ip_from_request(request)
+ )
+
+ if success:
+ html = SUCCESS_TEMPLATE
+ else:
+ html = RECAPTCHA_TEMPLATE % {
+ 'session': session,
+ 'myurl': "%s/auth/%s/fallback/web" % (
+ CLIENT_V2_ALPHA_PREFIX, LoginType.RECAPTCHA
+ ),
+ 'sitekey': self.hs.config.recaptcha_public_key,
+ }
+ html_bytes = html.encode("utf8")
+ request.setResponseCode(200)
+ request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
+ request.setHeader(b"Server", self.hs.version_string)
+ request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),))
+
+ request.write(html_bytes)
+ request.finish()
+
+ defer.returnValue(None)
+ else:
+ raise SynapseError(404, "Unknown auth stage type")
+
+ def on_OPTIONS(self, _):
+ return 200, {}
+
+
+def register_servlets(hs, http_server):
+ AuthRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
new file mode 100644
index 0000000000..3640fb4a29
--- /dev/null
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -0,0 +1,183 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+from synapse.api.constants import LoginType
+from synapse.api.errors import SynapseError, Codes
+from synapse.http.servlet import RestServlet
+
+from ._base import client_v2_pattern, parse_request_allow_empty
+
+import logging
+import hmac
+from hashlib import sha1
+from synapse.util.async import run_on_reactor
+
+
+# We ought to be using hmac.compare_digest() but on older pythons it doesn't
+# exist. It's a _really minor_ security flaw to use plain string comparison
+# because the timing attack is so obscured by all the other code here it's
+# unlikely to make much difference
+if hasattr(hmac, "compare_digest"):
+ compare_digest = hmac.compare_digest
+else:
+ compare_digest = lambda a, b: a == b
+
+
+logger = logging.getLogger(__name__)
+
+
+class RegisterRestServlet(RestServlet):
+ PATTERN = client_v2_pattern("/register")
+
+ def __init__(self, hs):
+ super(RegisterRestServlet, self).__init__()
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.auth_handler = hs.get_handlers().auth_handler
+ self.registration_handler = hs.get_handlers().registration_handler
+ self.identity_handler = hs.get_handlers().identity_handler
+ self.login_handler = hs.get_handlers().login_handler
+
+ @defer.inlineCallbacks
+ def on_POST(self, request):
+ yield run_on_reactor()
+
+ body = parse_request_allow_empty(request)
+ if 'password' not in body:
+ raise SynapseError(400, "", Codes.MISSING_PARAM)
+
+ if 'username' in body:
+ desired_username = body['username']
+ yield self.registration_handler.check_username(desired_username)
+
+ is_using_shared_secret = False
+ is_application_server = False
+
+ service = None
+ if 'access_token' in request.args:
+ service = yield self.auth.get_appservice_by_req(request)
+
+ if self.hs.config.enable_registration_captcha:
+ flows = [
+ [LoginType.RECAPTCHA],
+ [LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA]
+ ]
+ else:
+ flows = [
+ [LoginType.DUMMY],
+ [LoginType.EMAIL_IDENTITY]
+ ]
+
+ if service:
+ is_application_server = True
+ elif 'mac' in body:
+ # Check registration-specific shared secret auth
+ if 'username' not in body:
+ raise SynapseError(400, "", Codes.MISSING_PARAM)
+ self._check_shared_secret_auth(
+ body['username'], body['mac']
+ )
+ is_using_shared_secret = True
+ else:
+ authed, result, params = yield self.auth_handler.check_auth(
+ flows, body, self.hs.get_ip_from_request(request)
+ )
+
+ if not authed:
+ defer.returnValue((401, result))
+
+ can_register = (
+ not self.hs.config.disable_registration
+ or is_application_server
+ or is_using_shared_secret
+ )
+ if not can_register:
+ raise SynapseError(403, "Registration has been disabled")
+
+ if 'password' not in params:
+ raise SynapseError(400, "", Codes.MISSING_PARAM)
+ desired_username = params['username'] if 'username' in params else None
+ new_password = params['password']
+
+ (user_id, token) = yield self.registration_handler.register(
+ localpart=desired_username,
+ password=new_password
+ )
+
+ if LoginType.EMAIL_IDENTITY in result:
+ threepid = result[LoginType.EMAIL_IDENTITY]
+
+ for reqd in ['medium', 'address', 'validated_at']:
+ if reqd not in threepid:
+ logger.info("Can't add incomplete 3pid")
+ else:
+ yield self.login_handler.add_threepid(
+ user_id,
+ threepid['medium'],
+ threepid['address'],
+ threepid['validated_at'],
+ )
+
+ if 'bind_email' in params and params['bind_email']:
+ logger.info("bind_email specified: binding")
+
+ emailThreepid = result[LoginType.EMAIL_IDENTITY]
+ threepid_creds = emailThreepid['threepid_creds']
+ logger.debug("Binding emails %s to %s" % (
+ emailThreepid, user_id
+ ))
+ yield self.identity_handler.bind_threepid(threepid_creds, user_id)
+ else:
+ logger.info("bind_email not specified: not binding email")
+
+ result = {
+ "user_id": user_id,
+ "access_token": token,
+ "home_server": self.hs.hostname,
+ }
+
+ defer.returnValue((200, result))
+
+ def on_OPTIONS(self, _):
+ return 200, {}
+
+ def _check_shared_secret_auth(self, username, mac):
+ if not self.hs.config.registration_shared_secret:
+ raise SynapseError(400, "Shared secret registration is not enabled")
+
+ user = username.encode("utf-8")
+
+ # str() because otherwise hmac complains that 'unicode' does not
+ # have the buffer interface
+ got_mac = str(mac)
+
+ want_mac = hmac.new(
+ key=self.hs.config.registration_shared_secret,
+ msg=user,
+ digestmod=sha1,
+ ).hexdigest()
+
+ if compare_digest(want_mac, got_mac):
+ return True
+ else:
+ raise SynapseError(
+ 403, "HMAC incorrect",
+ )
+
+
+def register_servlets(hs, http_server):
+ RegisterRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py
index 3056ec45cf..f2fd0b9f32 100644
--- a/synapse/rest/client/v2_alpha/sync.py
+++ b/synapse/rest/client/v2_alpha/sync.py
@@ -15,7 +15,9 @@
from twisted.internet import defer
-from synapse.http.servlet import RestServlet
+from synapse.http.servlet import (
+ RestServlet, parse_string, parse_integer, parse_boolean
+)
from synapse.handlers.sync import SyncConfig
from synapse.types import StreamToken
from synapse.events.utils import (
@@ -87,20 +89,20 @@ class SyncRestServlet(RestServlet):
def on_GET(self, request):
user, client = yield self.auth.get_user_by_req(request)
- timeout = self.parse_integer(request, "timeout", default=0)
- limit = self.parse_integer(request, "limit", required=True)
- gap = self.parse_boolean(request, "gap", default=True)
- sort = self.parse_string(
+ timeout = parse_integer(request, "timeout", default=0)
+ limit = parse_integer(request, "limit", required=True)
+ gap = parse_boolean(request, "gap", default=True)
+ sort = parse_string(
request, "sort", default="timeline,asc",
allowed_values=self.ALLOWED_SORT
)
- since = self.parse_string(request, "since")
- set_presence = self.parse_string(
+ since = parse_string(request, "since")
+ set_presence = parse_string(
request, "set_presence", default="online",
allowed_values=self.ALLOWED_PRESENCE
)
- backfill = self.parse_boolean(request, "backfill", default=False)
- filter_id = self.parse_string(request, "filter", default=None)
+ backfill = parse_boolean(request, "backfill", default=False)
+ filter_id = parse_string(request, "filter", default=None)
logger.info(
"/sync: user=%r, timeout=%r, limit=%r, gap=%r, sort=%r, since=%r,"
diff --git a/synapse/rest/appservice/__init__.py b/synapse/rest/key/__init__.py
index 1a84d94cd9..1a84d94cd9 100644
--- a/synapse/rest/appservice/__init__.py
+++ b/synapse/rest/key/__init__.py
diff --git a/synapse/rest/key/v1/__init__.py b/synapse/rest/key/v1/__init__.py
new file mode 100644
index 0000000000..1a84d94cd9
--- /dev/null
+++ b/synapse/rest/key/v1/__init__.py
@@ -0,0 +1,14 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/synapse/http/server_key_resource.py b/synapse/rest/key/v1/server_key_resource.py
index 71e9a51f5c..71e9a51f5c 100644
--- a/synapse/http/server_key_resource.py
+++ b/synapse/rest/key/v1/server_key_resource.py
diff --git a/synapse/rest/appservice/v1/__init__.py b/synapse/rest/key/v2/__init__.py
index a7877609ad..1c14791b09 100644
--- a/synapse/rest/appservice/v1/__init__.py
+++ b/synapse/rest/key/v2/__init__.py
@@ -12,18 +12,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from . import register
-from synapse.http.server import JsonResource
+from twisted.web.resource import Resource
+from .local_key_resource import LocalKey
+from .remote_key_resource import RemoteKey
-class AppServiceRestResource(JsonResource):
- """A resource for version 1 of the matrix application service API."""
-
+class KeyApiV2Resource(Resource):
def __init__(self, hs):
- JsonResource.__init__(self, hs)
- self.register_servlets(self, hs)
-
- @staticmethod
- def register_servlets(appservice_resource, hs):
- register.register_servlets(hs, appservice_resource)
+ Resource.__init__(self)
+ self.putChild("server", LocalKey(hs))
+ self.putChild("query", RemoteKey(hs))
diff --git a/synapse/rest/key/v2/local_key_resource.py b/synapse/rest/key/v2/local_key_resource.py
new file mode 100644
index 0000000000..33cbd7cf8e
--- /dev/null
+++ b/synapse/rest/key/v2/local_key_resource.py
@@ -0,0 +1,125 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014, 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from twisted.web.resource import Resource
+from synapse.http.server import respond_with_json_bytes
+from syutil.crypto.jsonsign import sign_json
+from syutil.base64util import encode_base64
+from syutil.jsonutil import encode_canonical_json
+from hashlib import sha256
+from OpenSSL import crypto
+import logging
+
+
+logger = logging.getLogger(__name__)
+
+
+class LocalKey(Resource):
+ """HTTP resource containing encoding the TLS X.509 certificate and NACL
+ signature verification keys for this server::
+
+ GET /_matrix/key/v2/server/a.key.id HTTP/1.1
+
+ HTTP/1.1 200 OK
+ Content-Type: application/json
+ {
+ "valid_until_ts": # integer posix timestamp when this result expires.
+ "server_name": "this.server.example.com"
+ "verify_keys": {
+ "algorithm:version": {
+ "key": # base64 encoded NACL verification key.
+ }
+ },
+ "old_verify_keys": {
+ "algorithm:version": {
+ "expired_ts": # integer posix timestamp when the key expired.
+ "key": # base64 encoded NACL verification key.
+ }
+ }
+ "tls_certificate": # base64 ASN.1 DER encoded X.509 tls cert.
+ "signatures": {
+ "this.server.example.com": {
+ "algorithm:version": # NACL signature for this server
+ }
+ }
+ }
+ """
+
+ isLeaf = True
+
+ def __init__(self, hs):
+ self.version_string = hs.version_string
+ self.config = hs.config
+ self.clock = hs.clock
+ self.update_response_body(self.clock.time_msec())
+ Resource.__init__(self)
+
+ def update_response_body(self, time_now_msec):
+ refresh_interval = self.config.key_refresh_interval
+ self.valid_until_ts = int(time_now_msec + refresh_interval)
+ self.response_body = encode_canonical_json(self.response_json_object())
+
+ def response_json_object(self):
+ verify_keys = {}
+ for key in self.config.signing_key:
+ verify_key_bytes = key.verify_key.encode()
+ key_id = "%s:%s" % (key.alg, key.version)
+ verify_keys[key_id] = {
+ u"key": encode_base64(verify_key_bytes)
+ }
+
+ old_verify_keys = {}
+ for key in self.config.old_signing_keys:
+ key_id = "%s:%s" % (key.alg, key.version)
+ verify_key_bytes = key.encode()
+ old_verify_keys[key_id] = {
+ u"key": encode_base64(verify_key_bytes),
+ u"expired_ts": key.expired,
+ }
+
+ x509_certificate_bytes = crypto.dump_certificate(
+ crypto.FILETYPE_ASN1,
+ self.config.tls_certificate
+ )
+
+ sha256_fingerprint = sha256(x509_certificate_bytes).digest()
+
+ json_object = {
+ u"valid_until_ts": self.valid_until_ts,
+ u"server_name": self.config.server_name,
+ u"verify_keys": verify_keys,
+ u"old_verify_keys": old_verify_keys,
+ u"tls_fingerprints": [{
+ u"sha256": encode_base64(sha256_fingerprint),
+ }]
+ }
+ for key in self.config.signing_key:
+ json_object = sign_json(
+ json_object,
+ self.config.server_name,
+ key,
+ )
+ return json_object
+
+ def render_GET(self, request):
+ time_now = self.clock.time_msec()
+ # Update the expiry time if less than half the interval remains.
+ if time_now + self.config.key_refresh_interval / 2 > self.valid_until_ts:
+ self.update_response_body(time_now)
+ return respond_with_json_bytes(
+ request, 200, self.response_body,
+ version_string=self.version_string
+ )
diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
new file mode 100644
index 0000000000..e434847b45
--- /dev/null
+++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -0,0 +1,242 @@
+# Copyright 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.http.server import request_handler, respond_with_json_bytes
+from synapse.http.servlet import parse_integer
+from synapse.api.errors import SynapseError, Codes
+
+from twisted.web.resource import Resource
+from twisted.web.server import NOT_DONE_YET
+from twisted.internet import defer
+
+
+from io import BytesIO
+import json
+import logging
+logger = logging.getLogger(__name__)
+
+
+class RemoteKey(Resource):
+ """HTTP resource for retreiving the TLS certificate and NACL signature
+ verification keys for a collection of servers. Checks that the reported
+ X.509 TLS certificate matches the one used in the HTTPS connection. Checks
+ that the NACL signature for the remote server is valid. Returns a dict of
+ JSON signed by both the remote server and by this server.
+
+ Supports individual GET APIs and a bulk query POST API.
+
+ Requsts:
+
+ GET /_matrix/key/v2/query/remote.server.example.com HTTP/1.1
+
+ GET /_matrix/key/v2/query/remote.server.example.com/a.key.id HTTP/1.1
+
+ POST /_matrix/v2/query HTTP/1.1
+ Content-Type: application/json
+ {
+ "server_keys": {
+ "remote.server.example.com": {
+ "a.key.id": {
+ "minimum_valid_until_ts": 1234567890123
+ }
+ }
+ }
+ }
+
+ Response:
+
+ HTTP/1.1 200 OK
+ Content-Type: application/json
+ {
+ "server_keys": [
+ {
+ "server_name": "remote.server.example.com"
+ "valid_until_ts": # posix timestamp
+ "verify_keys": {
+ "a.key.id": { # The identifier for a key.
+ key: "" # base64 encoded verification key.
+ }
+ }
+ "old_verify_keys": {
+ "an.old.key.id": { # The identifier for an old key.
+ key: "", # base64 encoded key
+ "expired_ts": 0, # when the key stop being used.
+ }
+ }
+ "tls_fingerprints": [
+ { "sha256": # fingerprint }
+ ]
+ "signatures": {
+ "remote.server.example.com": {...}
+ "this.server.example.com": {...}
+ }
+ }
+ ]
+ }
+ """
+
+ isLeaf = True
+
+ def __init__(self, hs):
+ self.keyring = hs.get_keyring()
+ self.store = hs.get_datastore()
+ self.version_string = hs.version_string
+ self.clock = hs.get_clock()
+
+ def render_GET(self, request):
+ self.async_render_GET(request)
+ return NOT_DONE_YET
+
+ @request_handler
+ @defer.inlineCallbacks
+ def async_render_GET(self, request):
+ if len(request.postpath) == 1:
+ server, = request.postpath
+ query = {server: {}}
+ elif len(request.postpath) == 2:
+ server, key_id = request.postpath
+ minimum_valid_until_ts = parse_integer(
+ request, "minimum_valid_until_ts"
+ )
+ arguments = {}
+ if minimum_valid_until_ts is not None:
+ arguments["minimum_valid_until_ts"] = minimum_valid_until_ts
+ query = {server: {key_id: arguments}}
+ else:
+ raise SynapseError(
+ 404, "Not found %r" % request.postpath, Codes.NOT_FOUND
+ )
+ yield self.query_keys(request, query, query_remote_on_cache_miss=True)
+
+ def render_POST(self, request):
+ self.async_render_POST(request)
+ return NOT_DONE_YET
+
+ @request_handler
+ @defer.inlineCallbacks
+ def async_render_POST(self, request):
+ try:
+ content = json.loads(request.content.read())
+ if type(content) != dict:
+ raise ValueError()
+ except ValueError:
+ raise SynapseError(
+ 400, "Content must be JSON object.", errcode=Codes.NOT_JSON
+ )
+
+ query = content["server_keys"]
+
+ yield self.query_keys(request, query, query_remote_on_cache_miss=True)
+
+ @defer.inlineCallbacks
+ def query_keys(self, request, query, query_remote_on_cache_miss=False):
+ logger.info("Handling query for keys %r", query)
+ store_queries = []
+ for server_name, key_ids in query.items():
+ if not key_ids:
+ key_ids = (None,)
+ for key_id in key_ids:
+ store_queries.append((server_name, key_id, None))
+
+ cached = yield self.store.get_server_keys_json(store_queries)
+
+ json_results = set()
+
+ time_now_ms = self.clock.time_msec()
+
+ cache_misses = dict()
+ for (server_name, key_id, from_server), results in cached.items():
+ results = [
+ (result["ts_added_ms"], result) for result in results
+ ]
+
+ if not results and key_id is not None:
+ cache_misses.setdefault(server_name, set()).add(key_id)
+ continue
+
+ if key_id is not None:
+ ts_added_ms, most_recent_result = max(results)
+ ts_valid_until_ms = most_recent_result["ts_valid_until_ms"]
+ req_key = query.get(server_name, {}).get(key_id, {})
+ req_valid_until = req_key.get("minimum_valid_until_ts")
+ miss = False
+ if req_valid_until is not None:
+ if ts_valid_until_ms < req_valid_until:
+ logger.debug(
+ "Cached response for %r/%r is older than requested"
+ ": valid_until (%r) < minimum_valid_until (%r)",
+ server_name, key_id,
+ ts_valid_until_ms, req_valid_until
+ )
+ miss = True
+ else:
+ logger.debug(
+ "Cached response for %r/%r is newer than requested"
+ ": valid_until (%r) >= minimum_valid_until (%r)",
+ server_name, key_id,
+ ts_valid_until_ms, req_valid_until
+ )
+ elif (ts_added_ms + ts_valid_until_ms) / 2 < time_now_ms:
+ logger.debug(
+ "Cached response for %r/%r is too old"
+ ": (added (%r) + valid_until (%r)) / 2 < now (%r)",
+ server_name, key_id,
+ ts_added_ms, ts_valid_until_ms, time_now_ms
+ )
+ # We more than half way through the lifetime of the
+ # response. We should fetch a fresh copy.
+ miss = True
+ else:
+ logger.debug(
+ "Cached response for %r/%r is still valid"
+ ": (added (%r) + valid_until (%r)) / 2 < now (%r)",
+ server_name, key_id,
+ ts_added_ms, ts_valid_until_ms, time_now_ms
+ )
+
+ if miss:
+ cache_misses.setdefault(server_name, set()).add(key_id)
+ json_results.add(bytes(most_recent_result["key_json"]))
+ else:
+ for ts_added, result in results:
+ json_results.add(bytes(result["key_json"]))
+
+ if cache_misses and query_remote_on_cache_miss:
+ for server_name, key_ids in cache_misses.items():
+ try:
+ yield self.keyring.get_server_verify_key_v2_direct(
+ server_name, key_ids
+ )
+ except:
+ logger.exception("Failed to get key for %r", server_name)
+ pass
+ yield self.query_keys(
+ request, query, query_remote_on_cache_miss=False
+ )
+ else:
+ result_io = BytesIO()
+ result_io.write(b"{\"server_keys\":")
+ sep = b"["
+ for json_bytes in json_results:
+ result_io.write(sep)
+ result_io.write(json_bytes)
+ sep = b","
+ if sep == b"[":
+ result_io.write(sep)
+ result_io.write(b"]}")
+
+ respond_with_json_bytes(
+ request, 200, result_io.getvalue(),
+ version_string=self.version_string
+ )
diff --git a/synapse/rest/media/v1/base_resource.py b/synapse/rest/media/v1/base_resource.py
index b10cbddb81..08c8d75af4 100644
--- a/synapse/rest/media/v1/base_resource.py
+++ b/synapse/rest/media/v1/base_resource.py
@@ -18,13 +18,15 @@ from .thumbnailer import Thumbnailer
from synapse.http.server import respond_with_json
from synapse.util.stringutils import random_string
from synapse.api.errors import (
- cs_exception, CodeMessageException, cs_error, Codes, SynapseError
+ cs_error, Codes, SynapseError
)
from twisted.internet import defer
from twisted.web.resource import Resource
from twisted.protocols.basic import FileSender
+from synapse.util.async import create_observer
+
import os
import logging
@@ -32,6 +34,18 @@ import logging
logger = logging.getLogger(__name__)
+def parse_media_id(request):
+ try:
+ server_name, media_id = request.postpath
+ return (server_name, media_id)
+ except:
+ raise SynapseError(
+ 404,
+ "Invalid media id token %r" % (request.postpath,),
+ Codes.UNKNOWN,
+ )
+
+
class BaseMediaResource(Resource):
isLeaf = True
@@ -45,74 +59,9 @@ class BaseMediaResource(Resource):
self.max_upload_size = hs.config.max_upload_size
self.max_image_pixels = hs.config.max_image_pixels
self.filepaths = filepaths
+ self.version_string = hs.version_string
self.downloads = {}
- @staticmethod
- def catch_errors(request_handler):
- @defer.inlineCallbacks
- def wrapped_request_handler(self, request):
- try:
- yield request_handler(self, request)
- except CodeMessageException as e:
- logger.info("Responding with error: %r", e)
- respond_with_json(
- request, e.code, cs_exception(e), send_cors=True
- )
- except:
- logger.exception(
- "Failed handle request %s.%s on %r",
- request_handler.__module__,
- request_handler.__name__,
- self,
- )
- respond_with_json(
- request,
- 500,
- {"error": "Internal server error"},
- send_cors=True
- )
- return wrapped_request_handler
-
- @staticmethod
- def _parse_media_id(request):
- try:
- server_name, media_id = request.postpath
- return (server_name, media_id)
- except:
- raise SynapseError(
- 404,
- "Invalid media id token %r" % (request.postpath,),
- Codes.UNKNOWN,
- )
-
- @staticmethod
- def _parse_integer(request, arg_name, default=None):
- try:
- if default is None:
- return int(request.args[arg_name][0])
- else:
- return int(request.args.get(arg_name, [default])[0])
- except:
- raise SynapseError(
- 400,
- "Missing integer argument %r" % (arg_name,),
- Codes.UNKNOWN,
- )
-
- @staticmethod
- def _parse_string(request, arg_name, default=None):
- try:
- if default is None:
- return request.args[arg_name][0]
- else:
- return request.args.get(arg_name, [default])[0]
- except:
- raise SynapseError(
- 400,
- "Missing string argument %r" % (arg_name,),
- Codes.UNKNOWN,
- )
-
def _respond_404(self, request):
respond_with_json(
request, 404,
@@ -140,7 +89,7 @@ class BaseMediaResource(Resource):
def callback(media_info):
del self.downloads[key]
return media_info
- return download
+ return create_observer(download)
@defer.inlineCallbacks
def _get_remote_media_impl(self, server_name, media_id):
diff --git a/synapse/rest/media/v1/download_resource.py b/synapse/rest/media/v1/download_resource.py
index c585bb11f7..0fe6abf647 100644
--- a/synapse/rest/media/v1/download_resource.py
+++ b/synapse/rest/media/v1/download_resource.py
@@ -13,7 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from .base_resource import BaseMediaResource
+from .base_resource import BaseMediaResource, parse_media_id
+from synapse.http.server import request_handler
from twisted.web.server import NOT_DONE_YET
from twisted.internet import defer
@@ -28,15 +29,10 @@ class DownloadResource(BaseMediaResource):
self._async_render_GET(request)
return NOT_DONE_YET
- @BaseMediaResource.catch_errors
+ @request_handler
@defer.inlineCallbacks
def _async_render_GET(self, request):
- try:
- server_name, media_id = request.postpath
- except:
- self._respond_404(request)
- return
-
+ server_name, media_id = parse_media_id(request)
if server_name == self.server_name:
yield self._respond_local_file(request, media_id)
else:
diff --git a/synapse/rest/media/v1/identicon_resource.py b/synapse/rest/media/v1/identicon_resource.py
index 912856386a..603859d5d4 100644
--- a/synapse/rest/media/v1/identicon_resource.py
+++ b/synapse/rest/media/v1/identicon_resource.py
@@ -1,3 +1,17 @@
+# Copyright 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from pydenticon import Generator
from twisted.web.resource import Resource
diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py
index 84f5e3463c..1dadd880b2 100644
--- a/synapse/rest/media/v1/thumbnail_resource.py
+++ b/synapse/rest/media/v1/thumbnail_resource.py
@@ -14,7 +14,9 @@
# limitations under the License.
-from .base_resource import BaseMediaResource
+from .base_resource import BaseMediaResource, parse_media_id
+from synapse.http.servlet import parse_string, parse_integer
+from synapse.http.server import request_handler
from twisted.web.server import NOT_DONE_YET
from twisted.internet import defer
@@ -31,14 +33,14 @@ class ThumbnailResource(BaseMediaResource):
self._async_render_GET(request)
return NOT_DONE_YET
- @BaseMediaResource.catch_errors
+ @request_handler
@defer.inlineCallbacks
def _async_render_GET(self, request):
- server_name, media_id = self._parse_media_id(request)
- width = self._parse_integer(request, "width")
- height = self._parse_integer(request, "height")
- method = self._parse_string(request, "method", "scale")
- m_type = self._parse_string(request, "type", "image/png")
+ server_name, media_id = parse_media_id(request)
+ width = parse_integer(request, "width")
+ height = parse_integer(request, "height")
+ method = parse_string(request, "method", "scale")
+ m_type = parse_string(request, "type", "image/png")
if server_name == self.server_name:
yield self._respond_local_thumbnail(
diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py
index e5aba3af4c..cc571976a5 100644
--- a/synapse/rest/media/v1/upload_resource.py
+++ b/synapse/rest/media/v1/upload_resource.py
@@ -13,12 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.http.server import respond_with_json
+from synapse.http.server import respond_with_json, request_handler
from synapse.util.stringutils import random_string
-from synapse.api.errors import (
- cs_exception, SynapseError, CodeMessageException
-)
+from synapse.api.errors import SynapseError
from twisted.web.server import NOT_DONE_YET
from twisted.internet import defer
@@ -69,53 +67,42 @@ class UploadResource(BaseMediaResource):
defer.returnValue("mxc://%s/%s" % (self.server_name, media_id))
+ @request_handler
@defer.inlineCallbacks
def _async_render_POST(self, request):
- try:
- auth_user, client = yield self.auth.get_user_by_req(request)
- # TODO: The checks here are a bit late. The content will have
- # already been uploaded to a tmp file at this point
- content_length = request.getHeader("Content-Length")
- if content_length is None:
- raise SynapseError(
- msg="Request must specify a Content-Length", code=400
- )
- if int(content_length) > self.max_upload_size:
- raise SynapseError(
- msg="Upload request body is too large",
- code=413,
- )
-
- headers = request.requestHeaders
-
- if headers.hasHeader("Content-Type"):
- media_type = headers.getRawHeaders("Content-Type")[0]
- else:
- raise SynapseError(
- msg="Upload request missing 'Content-Type'",
- code=400,
- )
-
- # if headers.hasHeader("Content-Disposition"):
- # disposition = headers.getRawHeaders("Content-Disposition")[0]
- # TODO(markjh): parse content-dispostion
-
- content_uri = yield self.create_content(
- media_type, None, request.content.read(),
- content_length, auth_user
+ auth_user, client = yield self.auth.get_user_by_req(request)
+ # TODO: The checks here are a bit late. The content will have
+ # already been uploaded to a tmp file at this point
+ content_length = request.getHeader("Content-Length")
+ if content_length is None:
+ raise SynapseError(
+ msg="Request must specify a Content-Length", code=400
)
-
- respond_with_json(
- request, 200, {"content_uri": content_uri}, send_cors=True
+ if int(content_length) > self.max_upload_size:
+ raise SynapseError(
+ msg="Upload request body is too large",
+ code=413,
)
- except CodeMessageException as e:
- logger.exception(e)
- respond_with_json(request, e.code, cs_exception(e), send_cors=True)
- except:
- logger.exception("Failed to store file")
- respond_with_json(
- request,
- 500,
- {"error": "Internal server error"},
- send_cors=True
+
+ headers = request.requestHeaders
+
+ if headers.hasHeader("Content-Type"):
+ media_type = headers.getRawHeaders("Content-Type")[0]
+ else:
+ raise SynapseError(
+ msg="Upload request missing 'Content-Type'",
+ code=400,
)
+
+ # if headers.hasHeader("Content-Disposition"):
+ # disposition = headers.getRawHeaders("Content-Disposition")[0]
+ # TODO(markjh): parse content-dispostion
+
+ content_uri = yield self.create_content(
+ media_type, None, request.content.read(),
+ content_length, auth_user
+ )
+
+ respond_with_json(
+ request, 200, {"content_uri": content_uri}, send_cors=True
+ )
diff --git a/synapse/server.py b/synapse/server.py
index c7772244ba..8b3dc675cc 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -59,12 +59,12 @@ class BaseHomeServer(object):
'config',
'clock',
'http_client',
- 'db_name',
'db_pool',
'persistence_service',
'replication_layer',
'datastore',
'handlers',
+ 'v1auth',
'auth',
'rest_servlet_factory',
'state_handler',
@@ -78,8 +78,8 @@ class BaseHomeServer(object):
'resource_for_web_client',
'resource_for_content_repo',
'resource_for_server_key',
+ 'resource_for_server_key_v2',
'resource_for_media_repository',
- 'resource_for_app_services',
'resource_for_metrics',
'event_sources',
'ratelimiter',
@@ -182,6 +182,15 @@ class HomeServer(BaseHomeServer):
def build_auth(self):
return Auth(self)
+ def build_v1auth(self):
+ orf = Auth(self)
+ # Matrix spec makes no reference to what HTTP status code is returned,
+ # but the V1 API uses 403 where it means 401, and the webclient
+ # relies on this behaviour, so V1 gets its own copy of the auth
+ # with backwards compat behaviour.
+ orf.TOKEN_NOT_FOUND_HTTP_STATUS = 403
+ return orf
+
def build_state_handler(self):
return StateHandler(self)
diff --git a/synapse/state.py b/synapse/state.py
index ba2500d61c..9dddb77d5b 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -86,12 +86,7 @@ class StateHandler(object):
If `event_type` is specified, then the method returns only the one
event (or None) with that `event_type` and `state_key`.
"""
- events = yield self.store.get_latest_events_in_room(room_id)
-
- event_ids = [
- e_id
- for e_id, _, _ in events
- ]
+ event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
cache = None
if self._state_cache is not None:
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 4b16f445d6..0cc14fb692 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -14,13 +14,12 @@
# limitations under the License.
from twisted.internet import defer
-
-from synapse.util.logutils import log_function
-from synapse.api.constants import EventTypes
-
-from .appservice import ApplicationServiceStore
+from .appservice import (
+ ApplicationServiceStore, ApplicationServiceTransactionStore
+)
+from ._base import Cache
from .directory import DirectoryStore
-from .feedback import FeedbackStore
+from .events import EventsStore
from .presence import PresenceStore
from .profile import ProfileStore
from .registration import RegistrationStore
@@ -39,11 +38,6 @@ from .state import StateStore
from .signatures import SignatureStore
from .filtering import FilteringStore
-from syutil.base64util import decode_base64
-from syutil.jsonutil import encode_canonical_json
-
-from synapse.crypto.event_signing import compute_event_reference_hash
-
import fnmatch
import imp
@@ -57,20 +51,18 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts.
-SCHEMA_VERSION = 14
+SCHEMA_VERSION = 17
dir_path = os.path.abspath(os.path.dirname(__file__))
-
-class _RollbackButIsFineException(Exception):
- """ This exception is used to rollback a transaction without implying
- something went wrong.
- """
- pass
+# Number of msec of granularity to store the user IP 'last seen' time. Smaller
+# times give more inserts into the database even for readonly API hits
+# 120 seconds == 2 minutes
+LAST_SEEN_GRANULARITY = 120*1000
class DataStore(RoomMemberStore, RoomStore,
- RegistrationStore, StreamStore, ProfileStore, FeedbackStore,
+ RegistrationStore, StreamStore, ProfileStore,
PresenceStore, TransactionStore,
DirectoryStore, KeyStore, StateStore, SignatureStore,
ApplicationServiceStore,
@@ -79,7 +71,9 @@ class DataStore(RoomMemberStore, RoomStore,
RejectionsStore,
FilteringStore,
PusherStore,
- PushRuleStore
+ PushRuleStore,
+ ApplicationServiceTransactionStore,
+ EventsStore,
):
def __init__(self, hs):
@@ -89,474 +83,53 @@ class DataStore(RoomMemberStore, RoomStore,
self.min_token_deferred = self._get_min_token()
self.min_token = None
- @defer.inlineCallbacks
- @log_function
- def persist_event(self, event, context, backfilled=False,
- is_new_state=True, current_state=None):
- stream_ordering = None
- if backfilled:
- if not self.min_token_deferred.called:
- yield self.min_token_deferred
- self.min_token -= 1
- stream_ordering = self.min_token
-
- try:
- yield self.runInteraction(
- "persist_event",
- self._persist_event_txn,
- event=event,
- context=context,
- backfilled=backfilled,
- stream_ordering=stream_ordering,
- is_new_state=is_new_state,
- current_state=current_state,
- )
- except _RollbackButIsFineException:
- pass
-
- @defer.inlineCallbacks
- def get_event(self, event_id, check_redacted=True,
- get_prev_content=False, allow_rejected=False,
- allow_none=False):
- """Get an event from the database by event_id.
-
- Args:
- event_id (str): The event_id of the event to fetch
- check_redacted (bool): If True, check if event has been redacted
- and redact it.
- get_prev_content (bool): If True and event is a state event,
- include the previous states content in the unsigned field.
- allow_rejected (bool): If True return rejected events.
- allow_none (bool): If True, return None if no event found, if
- False throw an exception.
-
- Returns:
- Deferred : A FrozenEvent.
- """
- event = yield self.runInteraction(
- "get_event", self._get_event_txn,
- event_id,
- check_redacted=check_redacted,
- get_prev_content=get_prev_content,
- allow_rejected=allow_rejected,
- )
-
- if not event and not allow_none:
- raise RuntimeError("Could not find event %s" % (event_id,))
-
- defer.returnValue(event)
-
- @log_function
- def _persist_event_txn(self, txn, event, context, backfilled,
- stream_ordering=None, is_new_state=True,
- current_state=None):
-
- # Remove the any existing cache entries for the event_id
- self._get_event_cache.pop(event.event_id)
-
- # We purposefully do this first since if we include a `current_state`
- # key, we *want* to update the `current_state_events` table
- if current_state:
- txn.execute(
- "DELETE FROM current_state_events WHERE room_id = ?",
- (event.room_id,)
- )
-
- for s in current_state:
- self._simple_insert_txn(
- txn,
- "current_state_events",
- {
- "event_id": s.event_id,
- "room_id": s.room_id,
- "type": s.type,
- "state_key": s.state_key,
- },
- or_replace=True,
- )
-
- if event.is_state() and is_new_state:
- if not backfilled and not context.rejected:
- self._simple_insert_txn(
- txn,
- table="state_forward_extremities",
- values={
- "event_id": event.event_id,
- "room_id": event.room_id,
- "type": event.type,
- "state_key": event.state_key,
- },
- or_replace=True,
- )
-
- for prev_state_id, _ in event.prev_state:
- self._simple_delete_txn(
- txn,
- table="state_forward_extremities",
- keyvalues={
- "event_id": prev_state_id,
- }
- )
-
- outlier = event.internal_metadata.is_outlier()
-
- if not outlier:
- self._store_state_groups_txn(txn, event, context)
-
- self._update_min_depth_for_room_txn(
- txn,
- event.room_id,
- event.depth
- )
-
- self._handle_prev_events(
- txn,
- outlier=outlier,
- event_id=event.event_id,
- prev_events=event.prev_events,
- room_id=event.room_id,
- )
-
- have_persisted = self._simple_select_one_onecol_txn(
- txn,
- table="event_json",
- keyvalues={"event_id": event.event_id},
- retcol="event_id",
- allow_none=True,
- )
-
- metadata_json = encode_canonical_json(
- event.internal_metadata.get_dict()
- )
-
- # If we have already persisted this event, we don't need to do any
- # more processing.
- # The processing above must be done on every call to persist event,
- # since they might not have happened on previous calls. For example,
- # if we are persisting an event that we had persisted as an outlier,
- # but is no longer one.
- if have_persisted:
- if not outlier:
- sql = (
- "UPDATE event_json SET internal_metadata = ?"
- " WHERE event_id = ?"
- )
- txn.execute(
- sql,
- (metadata_json.decode("UTF-8"), event.event_id,)
- )
-
- sql = (
- "UPDATE events SET outlier = 0"
- " WHERE event_id = ?"
- )
- txn.execute(
- sql,
- (event.event_id,)
- )
- return
-
- if event.type == EventTypes.Member:
- self._store_room_member_txn(txn, event)
- elif event.type == EventTypes.Feedback:
- self._store_feedback_txn(txn, event)
- elif event.type == EventTypes.Name:
- self._store_room_name_txn(txn, event)
- elif event.type == EventTypes.Topic:
- self._store_room_topic_txn(txn, event)
- elif event.type == EventTypes.Redaction:
- self._store_redaction(txn, event)
-
- event_dict = {
- k: v
- for k, v in event.get_dict().items()
- if k not in [
- "redacted",
- "redacted_because",
- ]
- }
-
- self._simple_insert_txn(
- txn,
- table="event_json",
- values={
- "event_id": event.event_id,
- "room_id": event.room_id,
- "internal_metadata": metadata_json.decode("UTF-8"),
- "json": encode_canonical_json(event_dict).decode("UTF-8"),
- },
- or_replace=True,
- )
-
- content = encode_canonical_json(
- event.content
- ).decode("UTF-8")
-
- vals = {
- "topological_ordering": event.depth,
- "event_id": event.event_id,
- "type": event.type,
- "room_id": event.room_id,
- "content": content,
- "processed": True,
- "outlier": outlier,
- "depth": event.depth,
- }
-
- if stream_ordering is not None:
- vals["stream_ordering"] = stream_ordering
-
- unrec = {
- k: v
- for k, v in event.get_dict().items()
- if k not in vals.keys() and k not in [
- "redacted",
- "redacted_because",
- "signatures",
- "hashes",
- "prev_events",
- ]
- }
-
- vals["unrecognized_keys"] = encode_canonical_json(
- unrec
- ).decode("UTF-8")
-
- try:
- self._simple_insert_txn(
- txn,
- "events",
- vals,
- or_replace=(not outlier),
- or_ignore=bool(outlier),
- )
- except:
- logger.warn(
- "Failed to persist, probably duplicate: %s",
- event.event_id,
- exc_info=True,
- )
- raise _RollbackButIsFineException("_persist_event")
-
- if context.rejected:
- self._store_rejections_txn(txn, event.event_id, context.rejected)
-
- if event.is_state():
- vals = {
- "event_id": event.event_id,
- "room_id": event.room_id,
- "type": event.type,
- "state_key": event.state_key,
- }
-
- # TODO: How does this work with backfilling?
- if hasattr(event, "replaces_state"):
- vals["prev_state"] = event.replaces_state
-
- self._simple_insert_txn(
- txn,
- "state_events",
- vals,
- or_replace=True,
- )
-
- if is_new_state and not context.rejected:
- self._simple_insert_txn(
- txn,
- "current_state_events",
- {
- "event_id": event.event_id,
- "room_id": event.room_id,
- "type": event.type,
- "state_key": event.state_key,
- },
- or_replace=True,
- )
-
- for e_id, h in event.prev_state:
- self._simple_insert_txn(
- txn,
- table="event_edges",
- values={
- "event_id": event.event_id,
- "prev_event_id": e_id,
- "room_id": event.room_id,
- "is_state": 1,
- },
- or_ignore=True,
- )
-
- for hash_alg, hash_base64 in event.hashes.items():
- hash_bytes = decode_base64(hash_base64)
- self._store_event_content_hash_txn(
- txn, event.event_id, hash_alg, hash_bytes,
- )
-
- for prev_event_id, prev_hashes in event.prev_events:
- for alg, hash_base64 in prev_hashes.items():
- hash_bytes = decode_base64(hash_base64)
- self._store_prev_event_hash_txn(
- txn, event.event_id, prev_event_id, alg, hash_bytes
- )
-
- for auth_id, _ in event.auth_events:
- self._simple_insert_txn(
- txn,
- table="event_auth",
- values={
- "event_id": event.event_id,
- "room_id": event.room_id,
- "auth_id": auth_id,
- },
- or_ignore=True,
- )
-
- (ref_alg, ref_hash_bytes) = compute_event_reference_hash(event)
- self._store_event_reference_hash_txn(
- txn, event.event_id, ref_alg, ref_hash_bytes
- )
-
- def _store_redaction(self, txn, event):
- # invalidate the cache for the redacted event
- self._get_event_cache.pop(event.redacts)
- txn.execute(
- "INSERT OR IGNORE INTO redactions "
- "(event_id, redacts) VALUES (?,?)",
- (event.event_id, event.redacts)
+ self.client_ip_last_seen = Cache(
+ name="client_ip_last_seen",
+ keylen=4,
)
@defer.inlineCallbacks
- def get_current_state(self, room_id, event_type=None, state_key=""):
- del_sql = (
- "SELECT event_id FROM redactions WHERE redacts = e.event_id "
- "LIMIT 1"
- )
-
- sql = (
- "SELECT e.*, (%(redacted)s) AS redacted FROM events as e "
- "INNER JOIN current_state_events as c ON e.event_id = c.event_id "
- "INNER JOIN state_events as s ON e.event_id = s.event_id "
- "WHERE c.room_id = ? "
- ) % {
- "redacted": del_sql,
- }
-
- if event_type and state_key is not None:
- sql += " AND s.type = ? AND s.state_key = ? "
- args = (room_id, event_type, state_key)
- elif event_type:
- sql += " AND s.type = ?"
- args = (room_id, event_type)
- else:
- args = (room_id, )
-
- results = yield self._execute_and_decode("get_current_state", sql, *args)
-
- events = yield self._parse_events(results)
- defer.returnValue(events)
-
- @defer.inlineCallbacks
- def get_room_name_and_aliases(self, room_id):
- del_sql = (
- "SELECT event_id FROM redactions WHERE redacts = e.event_id "
- "LIMIT 1"
- )
-
- sql = (
- "SELECT e.*, (%(redacted)s) AS redacted FROM events as e "
- "INNER JOIN current_state_events as c ON e.event_id = c.event_id "
- "INNER JOIN state_events as s ON e.event_id = s.event_id "
- "WHERE c.room_id = ? "
- ) % {
- "redacted": del_sql,
- }
-
- sql += " AND ((s.type = 'm.room.name' AND s.state_key = '')"
- sql += " OR s.type = 'm.room.aliases')"
- args = (room_id,)
-
- results = yield self._execute_and_decode("get_current_state", sql, *args)
-
- events = yield self._parse_events(results)
-
- name = None
- aliases = []
-
- for e in events:
- if e.type == 'm.room.name':
- if 'name' in e.content:
- name = e.content['name']
- elif e.type == 'm.room.aliases':
- if 'aliases' in e.content:
- aliases.extend(e.content['aliases'])
-
- defer.returnValue((name, aliases))
-
- @defer.inlineCallbacks
- def _get_min_token(self):
- row = yield self._execute(
- "_get_min_token", None, "SELECT MIN(stream_ordering) FROM events"
- )
+ def insert_client_ip(self, user, access_token, device_id, ip, user_agent):
+ now = int(self._clock.time_msec())
+ key = (user.to_string(), access_token, device_id, ip)
- self.min_token = row[0][0] if row and row[0] and row[0][0] else -1
- self.min_token = min(self.min_token, -1)
+ try:
+ last_seen = self.client_ip_last_seen.get(*key)
+ except KeyError:
+ last_seen = None
- logger.debug("min_token is: %s", self.min_token)
+ # Rate-limited inserts
+ if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY:
+ defer.returnValue(None)
- defer.returnValue(self.min_token)
+ self.client_ip_last_seen.prefill(*key + (now,))
- def insert_client_ip(self, user, access_token, device_id, ip, user_agent):
- return self._simple_insert(
+ # It's safe not to lock here: a) no unique constraint,
+ # b) LAST_SEEN_GRANULARITY makes concurrent updates incredibly unlikely
+ yield self._simple_upsert(
"user_ips",
- {
- "user": user.to_string(),
+ keyvalues={
+ "user_id": user.to_string(),
"access_token": access_token,
- "device_id": device_id,
"ip": ip,
"user_agent": user_agent,
- "last_seen": int(self._clock.time_msec()),
- }
+ },
+ values={
+ "device_id": device_id,
+ "last_seen": now,
+ },
+ desc="insert_client_ip",
+ lock=False,
)
def get_user_ip_and_agents(self, user):
return self._simple_select_list(
table="user_ips",
- keyvalues={"user": user.to_string()},
+ keyvalues={"user_id": user.to_string()},
retcols=[
"device_id", "access_token", "ip", "user_agent", "last_seen"
],
- )
-
- def have_events(self, event_ids):
- """Given a list of event ids, check if we have already processed them.
-
- Returns:
- dict: Has an entry for each event id we already have seen. Maps to
- the rejected reason string if we rejected the event, else maps to
- None.
- """
- if not event_ids:
- return defer.succeed({})
-
- def f(txn):
- sql = (
- "SELECT e.event_id, reason FROM events as e "
- "LEFT JOIN rejections as r ON e.event_id = r.event_id "
- "WHERE e.event_id = ?"
- )
-
- res = {}
- for event_id in event_ids:
- txn.execute(sql, (event_id,))
- row = txn.fetchone()
- if row:
- _, rejected = row
- res[event_id] = rejected
-
- return res
-
- return self.runInteraction(
- "have_events", f,
+ desc="get_user_ip_and_agents",
)
@@ -580,21 +153,23 @@ class UpgradeDatabaseException(PrepareDatabaseException):
pass
-def prepare_database(db_conn):
+def prepare_database(db_conn, database_engine):
"""Prepares a database for usage. Will either create all necessary tables
or upgrade from an older schema version.
"""
try:
cur = db_conn.cursor()
- version_info = _get_or_create_schema_state(cur)
+ version_info = _get_or_create_schema_state(cur, database_engine)
if version_info:
user_version, delta_files, upgraded = version_info
- _upgrade_existing_database(cur, user_version, delta_files, upgraded)
+ _upgrade_existing_database(
+ cur, user_version, delta_files, upgraded, database_engine
+ )
else:
- _setup_new_database(cur)
+ _setup_new_database(cur, database_engine)
- cur.execute("PRAGMA user_version = %d" % (SCHEMA_VERSION,))
+ # cur.execute("PRAGMA user_version = %d" % (SCHEMA_VERSION,))
cur.close()
db_conn.commit()
@@ -603,7 +178,7 @@ def prepare_database(db_conn):
raise
-def _setup_new_database(cur):
+def _setup_new_database(cur, database_engine):
"""Sets up the database by finding a base set of "full schemas" and then
applying any necessary deltas.
@@ -657,31 +232,30 @@ def _setup_new_database(cur):
directory_entries = os.listdir(sql_dir)
- sql_script = "BEGIN TRANSACTION;\n"
for filename in fnmatch.filter(directory_entries, "*.sql"):
sql_loc = os.path.join(sql_dir, filename)
logger.debug("Applying schema %s", sql_loc)
- sql_script += read_schema(sql_loc)
- sql_script += "\n"
- sql_script += "COMMIT TRANSACTION;"
- cur.executescript(sql_script)
+ executescript(cur, sql_loc)
cur.execute(
- "INSERT OR REPLACE INTO schema_version (version, upgraded)"
- " VALUES (?,?)",
- (max_current_ver, False)
+ database_engine.convert_param_style(
+ "INSERT INTO schema_version (version, upgraded)"
+ " VALUES (?,?)"
+ ),
+ (max_current_ver, False,)
)
_upgrade_existing_database(
cur,
current_version=max_current_ver,
applied_delta_files=[],
- upgraded=False
+ upgraded=False,
+ database_engine=database_engine,
)
def _upgrade_existing_database(cur, current_version, applied_delta_files,
- upgraded):
+ upgraded, database_engine):
"""Upgrades an existing database.
Delta files can either be SQL stored in *.sql files, or python modules
@@ -737,6 +311,8 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files,
if not upgraded:
start_ver += 1
+ logger.debug("applied_delta_files: %s", applied_delta_files)
+
for v in range(start_ver, SCHEMA_VERSION + 1):
logger.debug("Upgrading schema to v%d", v)
@@ -753,6 +329,7 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files,
directory_entries.sort()
for file_name in directory_entries:
relative_path = os.path.join(str(v), file_name)
+ logger.debug("Found file: %s", relative_path)
if relative_path in applied_delta_files:
continue
@@ -774,9 +351,8 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files,
module.run_upgrade(cur)
elif ext == ".sql":
# A plain old .sql file, just read and execute it
- delta_schema = read_schema(absolute_path)
logger.debug("Applying schema %s", relative_path)
- cur.executescript(delta_schema)
+ executescript(cur, absolute_path)
else:
# Not a valid delta file.
logger.warn(
@@ -788,24 +364,83 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files,
# Mark as done.
cur.execute(
- "INSERT INTO applied_schema_deltas (version, file)"
- " VALUES (?,?)",
+ database_engine.convert_param_style(
+ "INSERT INTO applied_schema_deltas (version, file)"
+ " VALUES (?,?)",
+ ),
(v, relative_path)
)
+ cur.execute("DELETE FROM schema_version")
cur.execute(
- "INSERT OR REPLACE INTO schema_version (version, upgraded)"
- " VALUES (?,?)",
+ database_engine.convert_param_style(
+ "INSERT INTO schema_version (version, upgraded)"
+ " VALUES (?,?)",
+ ),
(v, True)
)
-def _get_or_create_schema_state(txn):
+def get_statements(f):
+ statement_buffer = ""
+ in_comment = False # If we're in a /* ... */ style comment
+
+ for line in f:
+ line = line.strip()
+
+ if in_comment:
+ # Check if this line contains an end to the comment
+ comments = line.split("*/", 1)
+ if len(comments) == 1:
+ continue
+ line = comments[1]
+ in_comment = False
+
+ # Remove inline block comments
+ line = re.sub(r"/\*.*\*/", " ", line)
+
+ # Does this line start a comment?
+ comments = line.split("/*", 1)
+ if len(comments) > 1:
+ line = comments[0]
+ in_comment = True
+
+ # Deal with line comments
+ line = line.split("--", 1)[0]
+ line = line.split("//", 1)[0]
+
+ # Find *all* semicolons. We need to treat first and last entry
+ # specially.
+ statements = line.split(";")
+
+ # We must prepend statement_buffer to the first statement
+ first_statement = "%s %s" % (
+ statement_buffer.strip(),
+ statements[0].strip()
+ )
+ statements[0] = first_statement
+
+ # Every entry, except the last, is a full statement
+ for statement in statements[:-1]:
+ yield statement.strip()
+
+ # The last entry did *not* end in a semicolon, so we store it for the
+ # next semicolon we find
+ statement_buffer = statements[-1].strip()
+
+
+def executescript(txn, schema_path):
+ with open(schema_path, 'r') as f:
+ for statement in get_statements(f):
+ txn.execute(statement)
+
+
+def _get_or_create_schema_state(txn, database_engine):
+ # Bluntly try creating the schema_version tables.
schema_path = os.path.join(
dir_path, "schema", "schema_version.sql",
)
- create_schema = read_schema(schema_path)
- txn.executescript(create_schema)
+ executescript(txn, schema_path)
txn.execute("SELECT version, upgraded FROM schema_version")
row = txn.fetchone()
@@ -814,10 +449,13 @@ def _get_or_create_schema_state(txn):
if current_version:
txn.execute(
- "SELECT file FROM applied_schema_deltas WHERE version >= ?",
+ database_engine.convert_param_style(
+ "SELECT file FROM applied_schema_deltas WHERE version >= ?"
+ ),
(current_version,)
)
- return current_version, txn.fetchall(), upgraded
+ applied_deltas = [d for d, in txn.fetchall()]
+ return current_version, applied_deltas, upgraded
return None
@@ -849,7 +487,19 @@ def prepare_sqlite3_database(db_conn):
if row and row[0]:
db_conn.execute(
- "INSERT OR REPLACE INTO schema_version (version, upgraded)"
+ "REPLACE INTO schema_version (version, upgraded)"
" VALUES (?,?)",
(row[0], False)
)
+
+
+def are_all_users_on_domain(txn, database_engine, domain):
+ sql = database_engine.convert_param_style(
+ "SELECT COUNT(*) FROM users WHERE name NOT LIKE ?"
+ )
+ pat = "%:" + domain
+ txn.execute(sql, (pat,))
+ num_not_matching = txn.fetchall()[0][0]
+ if num_not_matching == 0:
+ return True
+ return False
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 9125bb1198..ee5587c721 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -22,6 +22,8 @@ from synapse.util.logcontext import PreserveLoggingContext, LoggingContext
from synapse.util.lrucache import LruCache
import synapse.metrics
+from util.id_generators import IdGenerator, StreamIdGenerator
+
from twisted.internet import defer
from collections import namedtuple, OrderedDict
@@ -29,12 +31,15 @@ import functools
import simplejson as json
import sys
import time
+import threading
+DEBUG_CACHES = False
logger = logging.getLogger(__name__)
sql_logger = logging.getLogger("synapse.storage.SQL")
transaction_logger = logging.getLogger("synapse.storage.txn")
+perf_logger = logging.getLogger("synapse.storage.TIME")
metrics = synapse.metrics.get_metrics_for("synapse.storage")
@@ -53,14 +58,78 @@ cache_counter = metrics.register_cache(
)
-# TODO(paul):
-# * more generic key management
-# * consider other eviction strategies - LRU?
-def cached(max_entries=1000):
+class Cache(object):
+
+ def __init__(self, name, max_entries=1000, keylen=1, lru=False):
+ if lru:
+ self.cache = LruCache(max_size=max_entries)
+ self.max_entries = None
+ else:
+ self.cache = OrderedDict()
+ self.max_entries = max_entries
+
+ self.name = name
+ self.keylen = keylen
+ self.sequence = 0
+ self.thread = None
+ caches_by_name[name] = self.cache
+
+ def check_thread(self):
+ expected_thread = self.thread
+ if expected_thread is None:
+ self.thread = threading.current_thread()
+ else:
+ if expected_thread is not threading.current_thread():
+ raise ValueError(
+ "Cache objects can only be accessed from the main thread"
+ )
+
+ def get(self, *keyargs):
+ if len(keyargs) != self.keylen:
+ raise ValueError("Expected a key to have %d items", self.keylen)
+
+ if keyargs in self.cache:
+ cache_counter.inc_hits(self.name)
+ return self.cache[keyargs]
+
+ cache_counter.inc_misses(self.name)
+ raise KeyError()
+
+ def update(self, sequence, *args):
+ self.check_thread()
+ if self.sequence == sequence:
+ # Only update the cache if the caches sequence number matches the
+ # number that the cache had before the SELECT was started (SYN-369)
+ self.prefill(*args)
+
+ def prefill(self, *args): # because I can't *keyargs, value
+ keyargs = args[:-1]
+ value = args[-1]
+
+ if len(keyargs) != self.keylen:
+ raise ValueError("Expected a key to have %d items", self.keylen)
+
+ if self.max_entries is not None:
+ while len(self.cache) >= self.max_entries:
+ self.cache.popitem(last=False)
+
+ self.cache[keyargs] = value
+
+ def invalidate(self, *keyargs):
+ self.check_thread()
+ if len(keyargs) != self.keylen:
+ raise ValueError("Expected a key to have %d items", self.keylen)
+ # Increment the sequence number so that any SELECT statements that
+ # raced with the INSERT don't update the cache (SYN-369)
+ self.sequence += 1
+ self.cache.pop(keyargs, None)
+
+
+def cached(max_entries=1000, num_args=1, lru=False):
""" A method decorator that applies a memoizing cache around the function.
- The function is presumed to take one additional argument, which is used as
- the key for the cache. Cache hits are served directly from the cache;
+ The function is presumed to take zero or more arguments, which are used in
+ a tuple as the key for the cache. Hits are served directly from the cache;
misses use the function body to generate the value.
The wrapped function has an additional member, a callable called
@@ -71,34 +140,42 @@ def cached(max_entries=1000):
calling the calculation function.
"""
def wrap(orig):
- cache = OrderedDict()
- name = orig.__name__
-
- caches_by_name[name] = cache
-
- def prefill(key, value):
- while len(cache) > max_entries:
- cache.popitem(last=False)
-
- cache[key] = value
+ cache = Cache(
+ name=orig.__name__,
+ max_entries=max_entries,
+ keylen=num_args,
+ lru=lru,
+ )
@functools.wraps(orig)
@defer.inlineCallbacks
- def wrapped(self, key):
- if key in cache:
- cache_counter.inc_hits(name)
- defer.returnValue(cache[key])
-
- cache_counter.inc_misses(name)
- ret = yield orig(self, key)
- prefill(key, ret)
- defer.returnValue(ret)
-
- def invalidate(key):
- cache.pop(key, None)
-
- wrapped.invalidate = invalidate
- wrapped.prefill = prefill
+ def wrapped(self, *keyargs):
+ try:
+ cached_result = cache.get(*keyargs)
+ if DEBUG_CACHES:
+ actual_result = yield orig(self, *keyargs)
+ if actual_result != cached_result:
+ logger.error(
+ "Stale cache entry %s%r: cached: %r, actual %r",
+ orig.__name__, keyargs,
+ cached_result, actual_result,
+ )
+ raise ValueError("Stale cache entry")
+ defer.returnValue(cached_result)
+ except KeyError:
+ # Get the sequence number of the cache before reading from the
+ # database so that we can tell if the cache is invalidated
+ # while the SELECT is executing (SYN-369)
+ sequence = cache.sequence
+
+ ret = yield orig(self, *keyargs)
+
+ cache.update(sequence, *keyargs + (ret,))
+
+ defer.returnValue(ret)
+
+ wrapped.invalidate = cache.invalidate
+ wrapped.prefill = cache.prefill
return wrapped
return wrap
@@ -108,11 +185,20 @@ class LoggingTransaction(object):
"""An object that almost-transparently proxies for the 'txn' object
passed to the constructor. Adds logging and metrics to the .execute()
method."""
- __slots__ = ["txn", "name"]
+ __slots__ = ["txn", "name", "database_engine", "after_callbacks"]
- def __init__(self, txn, name):
+ def __init__(self, txn, name, database_engine, after_callbacks):
object.__setattr__(self, "txn", txn)
object.__setattr__(self, "name", name)
+ object.__setattr__(self, "database_engine", database_engine)
+ object.__setattr__(self, "after_callbacks", after_callbacks)
+
+ def call_after(self, callback, *args):
+ """Call the given callback on the main twisted thread after the
+ transaction has finished. Used to invalidate the caches on the
+ correct thread.
+ """
+ self.after_callbacks.append((callback, args))
def __getattr__(self, name):
return getattr(self.txn, name)
@@ -120,30 +206,37 @@ class LoggingTransaction(object):
def __setattr__(self, name, value):
setattr(self.txn, name, value)
- def execute(self, sql, *args, **kwargs):
+ def execute(self, sql, *args):
+ self._do_execute(self.txn.execute, sql, *args)
+
+ def executemany(self, sql, *args):
+ self._do_execute(self.txn.executemany, sql, *args)
+
+ def _do_execute(self, func, sql, *args):
# TODO(paul): Maybe use 'info' and 'debug' for values?
sql_logger.debug("[SQL] {%s} %s", self.name, sql)
- try:
- if args and args[0]:
- values = args[0]
+ sql = self.database_engine.convert_param_style(sql)
+
+ if args:
+ try:
sql_logger.debug(
- "[SQL values] {%s} " + ", ".join(("<%r>",) * len(values)),
- self.name,
- *values
+ "[SQL values] {%s} %r",
+ self.name, args[0]
)
- except:
- # Don't let logging failures stop SQL from working
- pass
+ except:
+ # Don't let logging failures stop SQL from working
+ pass
start = time.time() * 1000
+
try:
- return self.txn.execute(
- sql, *args, **kwargs
+ return func(
+ sql, *args
)
- except:
- logger.exception("[SQL FAIL] {%s}", self.name)
- raise
+ except Exception as e:
+ logger.debug("[SQL FAIL] {%s} %s", self.name, e)
+ raise
finally:
msecs = (time.time() * 1000) - start
sql_logger.debug("[SQL time] {%s} %f", self.name, msecs)
@@ -205,10 +298,16 @@ class SQLBaseStore(object):
self._txn_perf_counters = PerformanceCounters()
self._get_event_counters = PerformanceCounters()
- self._get_event_cache = LruCache(hs.config.event_cache_size)
+ self._get_event_cache = Cache("*getEvent*", keylen=3, lru=True,
+ max_entries=hs.config.event_cache_size)
+
+ self.database_engine = hs.database_engine
- # Pretend the getEventCache is just another named cache
- caches_by_name["*getEvent*"] = self._get_event_cache
+ self._stream_id_gen = StreamIdGenerator()
+ self._transaction_id_gen = IdGenerator("sent_transactions", "id", self)
+ self._state_groups_id_gen = IdGenerator("state_groups", "id", self)
+ self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self)
+ self._pushers_id_gen = IdGenerator("pushers", "id", self)
def start_profiling(self):
self._previous_loop_ts = self._clock.time_msec()
@@ -232,7 +331,7 @@ class SQLBaseStore(object):
time_now - time_then, limit=3
)
- logger.info(
+ perf_logger.info(
"Total database time: %.3f%% {%s} {%s}",
ratio * 100, top_three_counters, top_3_event_counters
)
@@ -246,8 +345,14 @@ class SQLBaseStore(object):
start_time = time.time() * 1000
- def inner_func(txn, *args, **kwargs):
+ after_callbacks = []
+
+ def inner_func(conn, *args, **kwargs):
with LoggingContext("runInteraction") as context:
+ if self.database_engine.is_connection_closed(conn):
+ logger.debug("Reconnecting closed database connection")
+ conn.reconnect()
+
current_context.copy_to(context)
start = time.time() * 1000
txn_id = self._TXN_ID
@@ -261,9 +366,48 @@ class SQLBaseStore(object):
sql_scheduling_timer.inc_by(time.time() * 1000 - start_time)
transaction_logger.debug("[TXN START] {%s}", name)
try:
- return func(LoggingTransaction(txn, name), *args, **kwargs)
- except:
- logger.exception("[TXN FAIL] {%s}", name)
+ i = 0
+ N = 5
+ while True:
+ try:
+ txn = conn.cursor()
+ txn = LoggingTransaction(
+ txn, name, self.database_engine, after_callbacks
+ )
+ return func(txn, *args, **kwargs)
+ except self.database_engine.module.OperationalError as e:
+ # This can happen if the database disappears mid
+ # transaction.
+ logger.warn(
+ "[TXN OPERROR] {%s} %s %d/%d",
+ name, e, i, N
+ )
+ if i < N:
+ i += 1
+ try:
+ conn.rollback()
+ except self.database_engine.module.Error as e1:
+ logger.warn(
+ "[TXN EROLL] {%s} %s",
+ name, e1,
+ )
+ continue
+ except self.database_engine.module.DatabaseError as e:
+ if self.database_engine.is_deadlock(e):
+ logger.warn("[TXN DEADLOCK] {%s} %d/%d", name, i, N)
+ if i < N:
+ i += 1
+ try:
+ conn.rollback()
+ except self.database_engine.module.Error as e1:
+ logger.warn(
+ "[TXN EROLL] {%s} %s",
+ name, e1,
+ )
+ continue
+ raise
+ except Exception as e:
+ logger.debug("[TXN FAIL] {%s} %s", name, e)
raise
finally:
end = time.time() * 1000
@@ -276,9 +420,11 @@ class SQLBaseStore(object):
sql_txn_timer.inc_by(duration, desc)
with PreserveLoggingContext():
- result = yield self._db_pool.runInteraction(
+ result = yield self._db_pool.runWithConnection(
inner_func, *args, **kwargs
)
+ for after_callback, after_args in after_callbacks:
+ after_callback(*after_args)
defer.returnValue(result)
def cursor_to_dict(self, cursor):
@@ -307,11 +453,11 @@ class SQLBaseStore(object):
The result of decoder(results)
"""
def interaction(txn):
- cursor = txn.execute(query, args)
+ txn.execute(query, args)
if decoder:
- return decoder(cursor)
+ return decoder(txn)
else:
- return cursor.fetchall()
+ return txn.fetchall()
return self.runInteraction(desc, interaction)
@@ -321,53 +467,94 @@ class SQLBaseStore(object):
# "Simple" SQL API methods that operate on a single table with no JOINs,
# no complex WHERE clauses, just a dict of values for columns.
- def _simple_insert(self, table, values, or_replace=False, or_ignore=False):
+ @defer.inlineCallbacks
+ def _simple_insert(self, table, values, or_ignore=False,
+ desc="_simple_insert"):
"""Executes an INSERT query on the named table.
Args:
table : string giving the table name
values : dict of new column names and values for them
- or_replace : bool; if True performs an INSERT OR REPLACE
"""
- return self.runInteraction(
- "_simple_insert",
- self._simple_insert_txn, table, values, or_replace=or_replace,
- or_ignore=or_ignore,
- )
+ try:
+ yield self.runInteraction(
+ desc,
+ self._simple_insert_txn, table, values,
+ )
+ except self.database_engine.module.IntegrityError:
+ # We have to do or_ignore flag at this layer, since we can't reuse
+ # a cursor after we receive an error from the db.
+ if not or_ignore:
+ raise
@log_function
- def _simple_insert_txn(self, txn, table, values, or_replace=False,
- or_ignore=False):
- sql = "%s INTO %s (%s) VALUES(%s)" % (
- ("INSERT OR REPLACE" if or_replace else
- "INSERT OR IGNORE" if or_ignore else "INSERT"),
+ def _simple_insert_txn(self, txn, table, values):
+ keys, vals = zip(*values.items())
+
+ sql = "INSERT INTO %s (%s) VALUES(%s)" % (
table,
- ", ".join(k for k in values),
- ", ".join("?" for k in values)
+ ", ".join(k for k in keys),
+ ", ".join("?" for _ in keys)
)
- logger.debug(
- "[SQL] %s Args=%s",
- sql, values.values(),
+ txn.execute(sql, vals)
+
+ def _simple_insert_many_txn(self, txn, table, values):
+ if not values:
+ return
+
+ # This is a *slight* abomination to get a list of tuples of key names
+ # and a list of tuples of value names.
+ #
+ # i.e. [{"a": 1, "b": 2}, {"c": 3, "d": 4}]
+ # => [("a", "b",), ("c", "d",)] and [(1, 2,), (3, 4,)]
+ #
+ # The sort is to ensure that we don't rely on dictionary iteration
+ # order.
+ keys, vals = zip(*[
+ zip(
+ *(sorted(i.items(), key=lambda kv: kv[0]))
+ )
+ for i in values
+ if i
+ ])
+
+ for k in keys:
+ if k != keys[0]:
+ raise RuntimeError(
+ "All items must have the same keys"
+ )
+
+ sql = "INSERT INTO %s (%s) VALUES(%s)" % (
+ table,
+ ", ".join(k for k in keys[0]),
+ ", ".join("?" for _ in keys[0])
)
- txn.execute(sql, values.values())
- return txn.lastrowid
+ txn.executemany(sql, vals)
- def _simple_upsert(self, table, keyvalues, values):
+ def _simple_upsert(self, table, keyvalues, values,
+ insertion_values={}, desc="_simple_upsert", lock=True):
"""
Args:
table (str): The table to upsert into
keyvalues (dict): The unique key tables and their new values
values (dict): The nonunique columns and their new values
+ insertion_values (dict): key/values to use when inserting
Returns: A deferred
"""
return self.runInteraction(
- "_simple_upsert",
- self._simple_upsert_txn, table, keyvalues, values
+ desc,
+ self._simple_upsert_txn, table, keyvalues, values, insertion_values,
+ lock
)
- def _simple_upsert_txn(self, txn, table, keyvalues, values):
+ def _simple_upsert_txn(self, txn, table, keyvalues, values, insertion_values={},
+ lock=True):
+ # We need to lock the table :(, unless we're *really* careful
+ if lock:
+ self.database_engine.lock_table(txn, table)
+
# Try to update
sql = "UPDATE %s SET %s WHERE %s" % (
table,
@@ -386,6 +573,7 @@ class SQLBaseStore(object):
allvalues = {}
allvalues.update(keyvalues)
allvalues.update(values)
+ allvalues.update(insertion_values)
sql = "INSERT INTO %s (%s) VALUES (%s)" % (
table,
@@ -399,7 +587,7 @@ class SQLBaseStore(object):
txn.execute(sql, allvalues.values())
def _simple_select_one(self, table, keyvalues, retcols,
- allow_none=False):
+ allow_none=False, desc="_simple_select_one"):
"""Executes a SELECT query on the named table, which is expected to
return a single row, returning a single column from it.
@@ -411,12 +599,15 @@ class SQLBaseStore(object):
allow_none : If true, return None instead of failing if the SELECT
statement returns no rows
"""
- return self._simple_selectupdate_one(
- table, keyvalues, retcols=retcols, allow_none=allow_none
+ return self.runInteraction(
+ desc,
+ self._simple_select_one_txn,
+ table, keyvalues, retcols, allow_none,
)
def _simple_select_one_onecol(self, table, keyvalues, retcol,
- allow_none=False):
+ allow_none=False,
+ desc="_simple_select_one_onecol"):
"""Executes a SELECT query on the named table, which is expected to
return a single row, returning a single column from it."
@@ -426,7 +617,7 @@ class SQLBaseStore(object):
retcol : string giving the name of the column to return
"""
return self.runInteraction(
- "_simple_select_one_onecol",
+ desc,
self._simple_select_one_onecol_txn,
table, keyvalues, retcol, allow_none=allow_none,
)
@@ -450,8 +641,7 @@ class SQLBaseStore(object):
def _simple_select_onecol_txn(self, txn, table, keyvalues, retcol):
sql = (
- "SELECT %(retcol)s FROM %(table)s WHERE %(where)s "
- "ORDER BY rowid asc"
+ "SELECT %(retcol)s FROM %(table)s WHERE %(where)s"
) % {
"retcol": retcol,
"table": table,
@@ -462,7 +652,8 @@ class SQLBaseStore(object):
return [r[0] for r in txn.fetchall()]
- def _simple_select_onecol(self, table, keyvalues, retcol):
+ def _simple_select_onecol(self, table, keyvalues, retcol,
+ desc="_simple_select_onecol"):
"""Executes a SELECT query on the named table, which returns a list
comprising of the values of the named column from the selected rows.
@@ -475,12 +666,13 @@ class SQLBaseStore(object):
Deferred: Results in a list
"""
return self.runInteraction(
- "_simple_select_onecol",
+ desc,
self._simple_select_onecol_txn,
table, keyvalues, retcol
)
- def _simple_select_list(self, table, keyvalues, retcols):
+ def _simple_select_list(self, table, keyvalues, retcols,
+ desc="_simple_select_list"):
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.
@@ -491,7 +683,7 @@ class SQLBaseStore(object):
retcols : list of strings giving the names of the columns to return
"""
return self.runInteraction(
- "_simple_select_list",
+ desc,
self._simple_select_list_txn,
table, keyvalues, retcols
)
@@ -507,14 +699,14 @@ class SQLBaseStore(object):
retcols : list of strings giving the names of the columns to return
"""
if keyvalues:
- sql = "SELECT %s FROM %s WHERE %s ORDER BY rowid asc" % (
+ sql = "SELECT %s FROM %s WHERE %s" % (
", ".join(retcols),
table,
" AND ".join("%s = ?" % (k, ) for k in keyvalues)
)
txn.execute(sql, keyvalues.values())
else:
- sql = "SELECT %s FROM %s ORDER BY rowid asc" % (
+ sql = "SELECT %s FROM %s" % (
", ".join(retcols),
table
)
@@ -523,7 +715,7 @@ class SQLBaseStore(object):
return self.cursor_to_dict(txn)
def _simple_update_one(self, table, keyvalues, updatevalues,
- retcols=None):
+ desc="_simple_update_one"):
"""Executes an UPDATE query on the named table, setting new values for
columns in a row matching the key values.
@@ -541,56 +733,81 @@ class SQLBaseStore(object):
get-and-set. This can be used to implement compare-and-set by putting
the update column in the 'keyvalues' dict as well.
"""
- return self._simple_selectupdate_one(table, keyvalues, updatevalues,
- retcols=retcols)
+ return self.runInteraction(
+ desc,
+ self._simple_update_one_txn,
+ table, keyvalues, updatevalues,
+ )
- def _simple_selectupdate_one(self, table, keyvalues, updatevalues=None,
- retcols=None, allow_none=False):
- """ Combined SELECT then UPDATE."""
- if retcols:
- select_sql = "SELECT %s FROM %s WHERE %s ORDER BY rowid asc" % (
- ", ".join(retcols),
- table,
- " AND ".join("%s = ?" % (k) for k in keyvalues)
- )
+ def _simple_update_one_txn(self, txn, table, keyvalues, updatevalues):
+ update_sql = "UPDATE %s SET %s WHERE %s" % (
+ table,
+ ", ".join("%s = ?" % (k,) for k in updatevalues),
+ " AND ".join("%s = ?" % (k,) for k in keyvalues)
+ )
- if updatevalues:
- update_sql = "UPDATE %s SET %s WHERE %s" % (
- table,
- ", ".join("%s = ?" % (k,) for k in updatevalues),
- " AND ".join("%s = ?" % (k,) for k in keyvalues)
- )
+ txn.execute(
+ update_sql,
+ updatevalues.values() + keyvalues.values()
+ )
+
+ if txn.rowcount == 0:
+ raise StoreError(404, "No row found")
+ if txn.rowcount > 1:
+ raise StoreError(500, "More than one row matched")
+
+ def _simple_select_one_txn(self, txn, table, keyvalues, retcols,
+ allow_none=False):
+ select_sql = "SELECT %s FROM %s WHERE %s" % (
+ ", ".join(retcols),
+ table,
+ " AND ".join("%s = ?" % (k,) for k in keyvalues)
+ )
+ txn.execute(select_sql, keyvalues.values())
+
+ row = txn.fetchone()
+ if not row:
+ if allow_none:
+ return None
+ raise StoreError(404, "No row found")
+ if txn.rowcount > 1:
+ raise StoreError(500, "More than one row matched")
+
+ return dict(zip(retcols, row))
+
+ def _simple_selectupdate_one(self, table, keyvalues, updatevalues=None,
+ retcols=None, allow_none=False,
+ desc="_simple_selectupdate_one"):
+ """ Combined SELECT then UPDATE."""
def func(txn):
ret = None
if retcols:
- txn.execute(select_sql, keyvalues.values())
-
- row = txn.fetchone()
- if not row:
- if allow_none:
- return None
- raise StoreError(404, "No row found")
- if txn.rowcount > 1:
- raise StoreError(500, "More than one row matched")
-
- ret = dict(zip(retcols, row))
+ ret = self._simple_select_one_txn(
+ txn,
+ table=table,
+ keyvalues=keyvalues,
+ retcols=retcols,
+ allow_none=allow_none,
+ )
if updatevalues:
- txn.execute(
- update_sql,
- updatevalues.values() + keyvalues.values()
+ self._simple_update_one_txn(
+ txn,
+ table=table,
+ keyvalues=keyvalues,
+ updatevalues=updatevalues,
)
- if txn.rowcount == 0:
- raise StoreError(404, "No row found")
+ # if txn.rowcount == 0:
+ # raise StoreError(404, "No row found")
if txn.rowcount > 1:
raise StoreError(500, "More than one row matched")
return ret
- return self.runInteraction("_simple_selectupdate_one", func)
+ return self.runInteraction(desc, func)
- def _simple_delete_one(self, table, keyvalues):
+ def _simple_delete_one(self, table, keyvalues, desc="_simple_delete_one"):
"""Executes a DELETE query on the named table, expecting to delete a
single row.
@@ -609,9 +826,9 @@ class SQLBaseStore(object):
raise StoreError(404, "No row found")
if txn.rowcount > 1:
raise StoreError(500, "more than one row matched")
- return self.runInteraction("_simple_delete_one", func)
+ return self.runInteraction(desc, func)
- def _simple_delete(self, table, keyvalues):
+ def _simple_delete(self, table, keyvalues, desc="_simple_delete"):
"""Executes a DELETE query on the named table.
Args:
@@ -619,7 +836,7 @@ class SQLBaseStore(object):
keyvalues : dict of column names and values to select the row with
"""
- return self.runInteraction("_simple_delete", self._simple_delete_txn)
+ return self.runInteraction(desc, self._simple_delete_txn)
def _simple_delete_txn(self, txn, table, keyvalues):
sql = "DELETE FROM %s WHERE %s" % (
@@ -670,6 +887,12 @@ class SQLBaseStore(object):
return [e for e in events if e]
+ def _invalidate_get_event_cache(self, event_id):
+ for check_redacted in (False, True):
+ for get_prev_content in (False, True):
+ self._get_event_cache.invalidate(event_id, check_redacted,
+ get_prev_content)
+
def _get_event_txn(self, txn, event_id, check_redacted=True,
get_prev_content=False, allow_rejected=False):
@@ -680,16 +903,14 @@ class SQLBaseStore(object):
sql_getevents_timer.inc_by(curr_time - last_time, desc)
return curr_time
- cache = self._get_event_cache.setdefault(event_id, {})
-
try:
- # Separate cache entries for each way to invoke _get_event_txn
- ret = cache[(check_redacted, get_prev_content, allow_rejected)]
+ ret = self._get_event_cache.get(event_id, check_redacted, get_prev_content)
- cache_counter.inc_hits("*getEvent*")
- return ret
+ if allow_rejected or not ret.rejected_reason:
+ return ret
+ else:
+ return None
except KeyError:
- cache_counter.inc_misses("*getEvent*")
pass
finally:
start_time = update_counter("event_cache", start_time)
@@ -714,19 +935,22 @@ class SQLBaseStore(object):
start_time = update_counter("select_event", start_time)
+ result = self._get_event_from_row_txn(
+ txn, internal_metadata, js, redacted,
+ check_redacted=check_redacted,
+ get_prev_content=get_prev_content,
+ rejected_reason=rejected_reason,
+ )
+ self._get_event_cache.prefill(event_id, check_redacted, get_prev_content, result)
+
if allow_rejected or not rejected_reason:
- result = self._get_event_from_row_txn(
- txn, internal_metadata, js, redacted,
- check_redacted=check_redacted,
- get_prev_content=get_prev_content,
- )
- cache[(check_redacted, get_prev_content, allow_rejected)] = result
return result
else:
return None
def _get_event_from_row_txn(self, txn, internal_metadata, js, redacted,
- check_redacted=True, get_prev_content=False):
+ check_redacted=True, get_prev_content=False,
+ rejected_reason=None):
start_time = time.time() * 1000
@@ -741,7 +965,11 @@ class SQLBaseStore(object):
internal_metadata = json.loads(internal_metadata)
start_time = update_counter("decode_internal", start_time)
- ev = FrozenEvent(d, internal_metadata_dict=internal_metadata)
+ ev = FrozenEvent(
+ d,
+ internal_metadata_dict=internal_metadata,
+ rejected_reason=rejected_reason,
+ )
start_time = update_counter("build_frozen_event", start_time)
if check_redacted and redacted:
@@ -788,6 +1016,19 @@ class SQLBaseStore(object):
result = txn.fetchone()
return result[0] if result else None
+ def get_next_stream_id(self):
+ with self._next_stream_id_lock:
+ i = self._next_stream_id
+ self._next_stream_id += 1
+ return i
+
+
+class _RollbackButIsFineException(Exception):
+ """ This exception is used to rollback a transaction without implying
+ something went wrong.
+ """
+ pass
+
class Table(object):
""" A base class used to store information about a particular table.
@@ -804,7 +1045,7 @@ class Table(object):
_select_where_clause = "SELECT %s FROM %s WHERE %s"
_select_clause = "SELECT %s FROM %s"
- _insert_clause = "INSERT OR REPLACE INTO %s (%s) VALUES (%s)"
+ _insert_clause = "REPLACE INTO %s (%s) VALUES (%s)"
@classmethod
def select_statement(cls, where_clause=None):
diff --git a/synapse/storage/appservice.py b/synapse/storage/appservice.py
index 850676ce6c..39b7881c40 100644
--- a/synapse/storage/appservice.py
+++ b/synapse/storage/appservice.py
@@ -13,154 +13,35 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-import simplejson
+import urllib
+import yaml
from simplejson import JSONDecodeError
+import simplejson as json
from twisted.internet import defer
from synapse.api.constants import Membership
-from synapse.api.errors import StoreError
-from synapse.appservice import ApplicationService
+from synapse.appservice import ApplicationService, AppServiceTransaction
from synapse.storage.roommember import RoomsForUser
+from synapse.types import UserID
from ._base import SQLBaseStore
logger = logging.getLogger(__name__)
-def log_failure(failure):
- logger.error("Failed to detect application services: %s", failure.value)
- logger.error(failure.getTraceback())
-
-
class ApplicationServiceStore(SQLBaseStore):
def __init__(self, hs):
super(ApplicationServiceStore, self).__init__(hs)
+ self.hostname = hs.hostname
self.services_cache = []
- self.cache_defer = self._populate_cache()
- self.cache_defer.addErrback(log_failure)
-
- @defer.inlineCallbacks
- def unregister_app_service(self, token):
- """Unregisters this service.
-
- This removes all AS specific regex and the base URL. The token is the
- only thing preserved for future registration attempts.
- """
- yield self.cache_defer # make sure the cache is ready
- yield self.runInteraction(
- "unregister_app_service",
- self._unregister_app_service_txn,
- token,
- )
- # update cache TODO: Should this be in the txn?
- for service in self.services_cache:
- if service.token == token:
- service.url = None
- service.namespaces = None
- service.hs_token = None
-
- def _unregister_app_service_txn(self, txn, token):
- # kill the url to prevent pushes
- txn.execute(
- "UPDATE application_services SET url=NULL WHERE token=?",
- (token,)
- )
-
- # cleanup regex
- as_id = self._get_as_id_txn(txn, token)
- if not as_id:
- logger.warning(
- "unregister_app_service_txn: Failed to find as_id for token=",
- token
- )
- return False
-
- txn.execute(
- "DELETE FROM application_services_regex WHERE as_id=?",
- (as_id,)
+ self._populate_appservice_cache(
+ hs.config.app_service_config_files
)
- return True
- @defer.inlineCallbacks
- def update_app_service(self, service):
- """Update an application service, clobbering what was previously there.
-
- Args:
- service(ApplicationService): The updated service.
- """
- yield self.cache_defer # make sure the cache is ready
-
- # NB: There is no "insert" since we provide no public-facing API to
- # allocate new ASes. It relies on the server admin inserting the AS
- # token into the database manually.
-
- if not service.token or not service.url:
- raise StoreError(400, "Token and url must be specified.")
-
- if not service.hs_token:
- raise StoreError(500, "No HS token")
-
- yield self.runInteraction(
- "update_app_service",
- self._update_app_service_txn,
- service
- )
-
- # update cache TODO: Should this be in the txn?
- for (index, cache_service) in enumerate(self.services_cache):
- if service.token == cache_service.token:
- self.services_cache[index] = service
- logger.info("Updated: %s", service)
- return
- # new entry
- self.services_cache.append(service)
- logger.info("Updated(new): %s", service)
-
- def _update_app_service_txn(self, txn, service):
- as_id = self._get_as_id_txn(txn, service.token)
- if not as_id:
- logger.warning(
- "update_app_service_txn: Failed to find as_id for token=",
- service.token
- )
- return False
-
- txn.execute(
- "UPDATE application_services SET url=?, hs_token=?, sender=? "
- "WHERE id=?",
- (service.url, service.hs_token, service.sender, as_id,)
- )
- # cleanup regex
- txn.execute(
- "DELETE FROM application_services_regex WHERE as_id=?",
- (as_id,)
- )
- for (ns_int, ns_str) in enumerate(ApplicationService.NS_LIST):
- if ns_str in service.namespaces:
- for regex_obj in service.namespaces[ns_str]:
- txn.execute(
- "INSERT INTO application_services_regex("
- "as_id, namespace, regex) values(?,?,?)",
- (as_id, ns_int, simplejson.dumps(regex_obj))
- )
- return True
-
- def _get_as_id_txn(self, txn, token):
- cursor = txn.execute(
- "SELECT id FROM application_services WHERE token=?",
- (token,)
- )
- res = cursor.fetchone()
- if res:
- return res[0]
-
- @defer.inlineCallbacks
def get_app_services(self):
- yield self.cache_defer # make sure the cache is ready
- defer.returnValue(self.services_cache)
+ return defer.succeed(self.services_cache)
- @defer.inlineCallbacks
def get_app_service_by_user_id(self, user_id):
"""Retrieve an application service from their user ID.
@@ -174,37 +55,23 @@ class ApplicationServiceStore(SQLBaseStore):
Returns:
synapse.appservice.ApplicationService or None.
"""
-
- yield self.cache_defer # make sure the cache is ready
-
for service in self.services_cache:
if service.sender == user_id:
- defer.returnValue(service)
- return
- defer.returnValue(None)
+ return defer.succeed(service)
+ return defer.succeed(None)
- @defer.inlineCallbacks
- def get_app_service_by_token(self, token, from_cache=True):
+ def get_app_service_by_token(self, token):
"""Get the application service with the given appservice token.
Args:
token (str): The application service token.
- from_cache (bool): True to get this service from the cache, False to
- check the database.
- Raises:
- StoreError if there was a problem retrieving this service.
+ Returns:
+ synapse.appservice.ApplicationService or None.
"""
- yield self.cache_defer # make sure the cache is ready
-
- if from_cache:
- for service in self.services_cache:
- if service.token == token:
- defer.returnValue(service)
- return
- defer.returnValue(None)
-
- # TODO: The from_cache=False impl
- # TODO: This should be JOINed with the application_services_regex table.
+ for service in self.services_cache:
+ if service.token == token:
+ return defer.succeed(service)
+ return defer.succeed(None)
def get_app_service_rooms(self, service):
"""Get a list of RoomsForUser for this application service.
@@ -277,12 +144,7 @@ class ApplicationServiceStore(SQLBaseStore):
return rooms_for_user_matching_user_id
- @defer.inlineCallbacks
- def _populate_cache(self):
- """Populates the ApplicationServiceCache from the database."""
- sql = ("SELECT * FROM application_services LEFT JOIN "
- "application_services_regex ON application_services.id = "
- "application_services_regex.as_id")
+ def _parse_services_dict(self, results):
# SQL results in the form:
# [
# {
@@ -296,12 +158,14 @@ class ApplicationServiceStore(SQLBaseStore):
# }
# ]
services = {}
- results = yield self._execute_and_decode("_populate_cache", sql)
for res in results:
as_token = res["token"]
+ if as_token is None:
+ continue
if as_token not in services:
# add the service
services[as_token] = {
+ "id": res["id"],
"url": res["url"],
"token": as_token,
"hs_token": res["hs_token"],
@@ -319,20 +183,289 @@ class ApplicationServiceStore(SQLBaseStore):
try:
services[as_token]["namespaces"][
ApplicationService.NS_LIST[ns_int]].append(
- simplejson.loads(res["regex"])
+ json.loads(res["regex"])
)
except IndexError:
logger.error("Bad namespace enum '%s'. %s", ns_int, res)
except JSONDecodeError:
logger.error("Bad regex object '%s'", res["regex"])
- # TODO get last successful txn id f.e. service
+ service_list = []
for service in services.values():
- logger.info("Found application service: %s", service)
- self.services_cache.append(ApplicationService(
+ service_list.append(ApplicationService(
token=service["token"],
url=service["url"],
namespaces=service["namespaces"],
hs_token=service["hs_token"],
- sender=service["sender"]
+ sender=service["sender"],
+ id=service["id"]
))
+ return service_list
+
+ def _load_appservice(self, as_info):
+ required_string_fields = [
+ "url", "as_token", "hs_token", "sender_localpart"
+ ]
+ for field in required_string_fields:
+ if not isinstance(as_info.get(field), basestring):
+ raise KeyError("Required string field: '%s'", field)
+
+ localpart = as_info["sender_localpart"]
+ if urllib.quote(localpart) != localpart:
+ raise ValueError(
+ "sender_localpart needs characters which are not URL encoded."
+ )
+ user = UserID(localpart, self.hostname)
+ user_id = user.to_string()
+
+ # namespace checks
+ if not isinstance(as_info.get("namespaces"), dict):
+ raise KeyError("Requires 'namespaces' object.")
+ for ns in ApplicationService.NS_LIST:
+ # specific namespaces are optional
+ if ns in as_info["namespaces"]:
+ # expect a list of dicts with exclusive and regex keys
+ for regex_obj in as_info["namespaces"][ns]:
+ if not isinstance(regex_obj, dict):
+ raise ValueError(
+ "Expected namespace entry in %s to be an object,"
+ " but got %s", ns, regex_obj
+ )
+ if not isinstance(regex_obj.get("regex"), basestring):
+ raise ValueError(
+ "Missing/bad type 'regex' key in %s", regex_obj
+ )
+ if not isinstance(regex_obj.get("exclusive"), bool):
+ raise ValueError(
+ "Missing/bad type 'exclusive' key in %s", regex_obj
+ )
+ return ApplicationService(
+ token=as_info["as_token"],
+ url=as_info["url"],
+ namespaces=as_info["namespaces"],
+ hs_token=as_info["hs_token"],
+ sender=user_id,
+ id=as_info["as_token"] # the token is the only unique thing here
+ )
+
+ def _populate_appservice_cache(self, config_files):
+ """Populates a cache of Application Services from the config files."""
+ if not isinstance(config_files, list):
+ logger.warning(
+ "Expected %s to be a list of AS config files.", config_files
+ )
+ return
+
+ for config_file in config_files:
+ try:
+ with open(config_file, 'r') as f:
+ appservice = self._load_appservice(yaml.load(f))
+ logger.info("Loaded application service: %s", appservice)
+ self.services_cache.append(appservice)
+ except Exception as e:
+ logger.error("Failed to load appservice from '%s'", config_file)
+ logger.exception(e)
+
+
+class ApplicationServiceTransactionStore(SQLBaseStore):
+
+ def __init__(self, hs):
+ super(ApplicationServiceTransactionStore, self).__init__(hs)
+
+ @defer.inlineCallbacks
+ def get_appservices_by_state(self, state):
+ """Get a list of application services based on their state.
+
+ Args:
+ state(ApplicationServiceState): The state to filter on.
+ Returns:
+ A Deferred which resolves to a list of ApplicationServices, which
+ may be empty.
+ """
+ results = yield self._simple_select_list(
+ "application_services_state",
+ dict(state=state),
+ ["as_id"]
+ )
+ # NB: This assumes this class is linked with ApplicationServiceStore
+ as_list = yield self.get_app_services()
+ services = []
+
+ for res in results:
+ for service in as_list:
+ if service.id == res["as_id"]:
+ services.append(service)
+ defer.returnValue(services)
+
+ @defer.inlineCallbacks
+ def get_appservice_state(self, service):
+ """Get the application service state.
+
+ Args:
+ service(ApplicationService): The service whose state to set.
+ Returns:
+ A Deferred which resolves to ApplicationServiceState.
+ """
+ result = yield self._simple_select_one(
+ "application_services_state",
+ dict(as_id=service.id),
+ ["state"],
+ allow_none=True
+ )
+ if result:
+ defer.returnValue(result.get("state"))
+ return
+ defer.returnValue(None)
+
+ def set_appservice_state(self, service, state):
+ """Set the application service state.
+
+ Args:
+ service(ApplicationService): The service whose state to set.
+ state(ApplicationServiceState): The connectivity state to apply.
+ Returns:
+ A Deferred which resolves when the state was set successfully.
+ """
+ return self._simple_upsert(
+ "application_services_state",
+ dict(as_id=service.id),
+ dict(state=state)
+ )
+
+ def create_appservice_txn(self, service, events):
+ """Atomically creates a new transaction for this application service
+ with the given list of events.
+
+ Args:
+ service(ApplicationService): The service who the transaction is for.
+ events(list<Event>): A list of events to put in the transaction.
+ Returns:
+ AppServiceTransaction: A new transaction.
+ """
+ return self.runInteraction(
+ "create_appservice_txn",
+ self._create_appservice_txn,
+ service, events
+ )
+
+ def _create_appservice_txn(self, txn, service, events):
+ # work out new txn id (highest txn id for this service += 1)
+ # The highest id may be the last one sent (in which case it is last_txn)
+ # or it may be the highest in the txns list (which are waiting to be/are
+ # being sent)
+ last_txn_id = self._get_last_txn(txn, service.id)
+
+ txn.execute(
+ "SELECT MAX(txn_id) FROM application_services_txns WHERE as_id=?",
+ (service.id,)
+ )
+ highest_txn_id = txn.fetchone()[0]
+ if highest_txn_id is None:
+ highest_txn_id = 0
+
+ new_txn_id = max(highest_txn_id, last_txn_id) + 1
+
+ # Insert new txn into txn table
+ event_ids = json.dumps([e.event_id for e in events])
+ txn.execute(
+ "INSERT INTO application_services_txns(as_id, txn_id, event_ids) "
+ "VALUES(?,?,?)",
+ (service.id, new_txn_id, event_ids)
+ )
+ return AppServiceTransaction(
+ service=service, id=new_txn_id, events=events
+ )
+
+ def complete_appservice_txn(self, txn_id, service):
+ """Completes an application service transaction.
+
+ Args:
+ txn_id(str): The transaction ID being completed.
+ service(ApplicationService): The application service which was sent
+ this transaction.
+ Returns:
+ A Deferred which resolves if this transaction was stored
+ successfully.
+ """
+ return self.runInteraction(
+ "complete_appservice_txn",
+ self._complete_appservice_txn,
+ txn_id, service
+ )
+
+ def _complete_appservice_txn(self, txn, txn_id, service):
+ txn_id = int(txn_id)
+
+ # Debugging query: Make sure the txn being completed is EXACTLY +1 from
+ # what was there before. If it isn't, we've got problems (e.g. the AS
+ # has probably missed some events), so whine loudly but still continue,
+ # since it shouldn't fail completion of the transaction.
+ last_txn_id = self._get_last_txn(txn, service.id)
+ if (last_txn_id + 1) != txn_id:
+ logger.error(
+ "appservice: Completing a transaction which has an ID > 1 from "
+ "the last ID sent to this AS. We've either dropped events or "
+ "sent it to the AS out of order. FIX ME. last_txn=%s "
+ "completing_txn=%s service_id=%s", last_txn_id, txn_id,
+ service.id
+ )
+
+ # Set current txn_id for AS to 'txn_id'
+ self._simple_upsert_txn(
+ txn, "application_services_state", dict(as_id=service.id),
+ dict(last_txn=txn_id)
+ )
+
+ # Delete txn
+ self._simple_delete_txn(
+ txn, "application_services_txns",
+ dict(txn_id=txn_id, as_id=service.id)
+ )
+
+ def get_oldest_unsent_txn(self, service):
+ """Get the oldest transaction which has not been sent for this
+ service.
+
+ Args:
+ service(ApplicationService): The app service to get the oldest txn.
+ Returns:
+ A Deferred which resolves to an AppServiceTransaction or
+ None.
+ """
+ return self.runInteraction(
+ "get_oldest_unsent_appservice_txn",
+ self._get_oldest_unsent_txn,
+ service
+ )
+
+ def _get_oldest_unsent_txn(self, txn, service):
+ # Monotonically increasing txn ids, so just select the smallest
+ # one in the txns table (we delete them when they are sent)
+ txn.execute(
+ "SELECT * FROM application_services_txns WHERE as_id=?"
+ " ORDER BY txn_id ASC LIMIT 1",
+ (service.id,)
+ )
+ rows = self.cursor_to_dict(txn)
+ if not rows:
+ return None
+
+ entry = rows[0]
+
+ event_ids = json.loads(entry["event_ids"])
+ events = self._get_events_txn(txn, event_ids)
+
+ return AppServiceTransaction(
+ service=service, id=entry["txn_id"], events=events
+ )
+
+ def _get_last_txn(self, txn, service_id):
+ txn.execute(
+ "SELECT last_txn FROM application_services_state WHERE as_id=?",
+ (service_id,)
+ )
+ last_txn_id = txn.fetchone()
+ if last_txn_id is None or last_txn_id[0] is None: # no row exists
+ return 0
+ else:
+ return int(last_txn_id[0]) # select 'last_txn' col
diff --git a/synapse/storage/directory.py b/synapse/storage/directory.py
index 68b7d59693..2b2bdf8615 100644
--- a/synapse/storage/directory.py
+++ b/synapse/storage/directory.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import SQLBaseStore
+from ._base import SQLBaseStore, cached
from synapse.api.errors import SynapseError
@@ -21,8 +21,6 @@ from twisted.internet import defer
from collections import namedtuple
-import sqlite3
-
RoomAliasMapping = namedtuple(
"RoomAliasMapping",
@@ -48,6 +46,7 @@ class DirectoryStore(SQLBaseStore):
{"room_alias": room_alias.to_string()},
"room_id",
allow_none=True,
+ desc="get_association_from_room_alias",
)
if not room_id:
@@ -58,6 +57,7 @@ class DirectoryStore(SQLBaseStore):
"room_alias_servers",
{"room_alias": room_alias.to_string()},
"server",
+ desc="get_association_from_room_alias",
)
if not servers:
@@ -87,8 +87,9 @@ class DirectoryStore(SQLBaseStore):
"room_alias": room_alias.to_string(),
"room_id": room_id,
},
+ desc="create_room_alias_association",
)
- except sqlite3.IntegrityError:
+ except self.database_engine.module.IntegrityError:
raise SynapseError(
409, "Room alias %s already exists" % room_alias.to_string()
)
@@ -100,23 +101,29 @@ class DirectoryStore(SQLBaseStore):
{
"room_alias": room_alias.to_string(),
"server": server,
- }
+ },
+ desc="create_room_alias_association",
)
+ self.get_aliases_for_room.invalidate(room_id)
+ @defer.inlineCallbacks
def delete_room_alias(self, room_alias):
- return self.runInteraction(
+ room_id = yield self.runInteraction(
"delete_room_alias",
self._delete_room_alias_txn,
room_alias,
)
+ self.get_aliases_for_room.invalidate(room_id)
+ defer.returnValue(room_id)
+
def _delete_room_alias_txn(self, txn, room_alias):
- cursor = txn.execute(
+ txn.execute(
"SELECT room_id FROM room_aliases WHERE room_alias = ?",
(room_alias.to_string(),)
)
- res = cursor.fetchone()
+ res = txn.fetchone()
if res:
room_id = res[0]
else:
@@ -134,9 +141,11 @@ class DirectoryStore(SQLBaseStore):
return room_id
+ @cached()
def get_aliases_for_room(self, room_id):
return self._simple_select_onecol(
"room_aliases",
{"room_id": room_id},
"room_alias",
+ desc="get_aliases_for_room",
)
diff --git a/synapse/storage/engines/__init__.py b/synapse/storage/engines/__init__.py
new file mode 100644
index 0000000000..bd3c8f9452
--- /dev/null
+++ b/synapse/storage/engines/__init__.py
@@ -0,0 +1,41 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ._base import IncorrectDatabaseSetup
+from .postgres import PostgresEngine
+from .sqlite3 import Sqlite3Engine
+
+import importlib
+
+
+SUPPORTED_MODULE = {
+ "sqlite3": Sqlite3Engine,
+ "psycopg2": PostgresEngine,
+}
+
+
+def create_engine(name):
+ engine_class = SUPPORTED_MODULE.get(name, None)
+
+ if engine_class:
+ module = importlib.import_module(name)
+ return engine_class(module)
+
+ raise RuntimeError(
+ "Unsupported database engine '%s'" % (name,)
+ )
+
+
+__all__ = ["create_engine", "IncorrectDatabaseSetup"]
diff --git a/synapse/storage/engines/_base.py b/synapse/storage/engines/_base.py
new file mode 100644
index 0000000000..0b549d314b
--- /dev/null
+++ b/synapse/storage/engines/_base.py
@@ -0,0 +1,18 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+class IncorrectDatabaseSetup(RuntimeError):
+ pass
diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py
new file mode 100644
index 0000000000..a323028546
--- /dev/null
+++ b/synapse/storage/engines/postgres.py
@@ -0,0 +1,56 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.storage import prepare_database
+
+from ._base import IncorrectDatabaseSetup
+
+
+class PostgresEngine(object):
+ def __init__(self, database_module):
+ self.module = database_module
+ self.module.extensions.register_type(self.module.extensions.UNICODE)
+
+ def check_database(self, txn):
+ txn.execute("SHOW SERVER_ENCODING")
+ rows = txn.fetchall()
+ if rows and rows[0][0] != "UTF8":
+ raise IncorrectDatabaseSetup(
+ "Database has incorrect encoding: '%s' instead of 'UTF8'\n"
+ "See docs/postgres.rst for more information."
+ % (rows[0][0],)
+ )
+
+ def convert_param_style(self, sql):
+ return sql.replace("?", "%s")
+
+ def on_new_connection(self, db_conn):
+ db_conn.set_isolation_level(
+ self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ
+ )
+
+ def prepare_database(self, db_conn):
+ prepare_database(db_conn, self)
+
+ def is_deadlock(self, error):
+ if isinstance(error, self.module.DatabaseError):
+ return error.pgcode in ["40001", "40P01"]
+ return False
+
+ def is_connection_closed(self, conn):
+ return bool(conn.closed)
+
+ def lock_table(self, txn, table):
+ txn.execute("LOCK TABLE %s in EXCLUSIVE MODE" % (table,))
diff --git a/synapse/storage/engines/sqlite3.py b/synapse/storage/engines/sqlite3.py
new file mode 100644
index 0000000000..ff13d8006a
--- /dev/null
+++ b/synapse/storage/engines/sqlite3.py
@@ -0,0 +1,43 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.storage import prepare_database, prepare_sqlite3_database
+
+
+class Sqlite3Engine(object):
+ def __init__(self, database_module):
+ self.module = database_module
+
+ def check_database(self, txn):
+ pass
+
+ def convert_param_style(self, sql):
+ return sql
+
+ def on_new_connection(self, db_conn):
+ self.prepare_database(db_conn)
+
+ def prepare_database(self, db_conn):
+ prepare_sqlite3_database(db_conn)
+ prepare_database(db_conn, self)
+
+ def is_deadlock(self, error):
+ return False
+
+ def is_connection_closed(self, conn):
+ return False
+
+ def lock_table(self, txn, table):
+ return
diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py
index 032334bfd6..74b4e23590 100644
--- a/synapse/storage/event_federation.py
+++ b/synapse/storage/event_federation.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import SQLBaseStore
+from ._base import SQLBaseStore, cached
from syutil.base64util import encode_base64
import logging
@@ -96,11 +96,23 @@ class EventFederationStore(SQLBaseStore):
room_id,
)
+ @cached()
+ def get_latest_event_ids_in_room(self, room_id):
+ return self._simple_select_onecol(
+ table="event_forward_extremities",
+ keyvalues={
+ "room_id": room_id,
+ },
+ retcol="event_id",
+ desc="get_latest_event_ids_in_room",
+ )
+
def _get_latest_events_in_room(self, txn, room_id):
sql = (
"SELECT e.event_id, e.depth FROM events as e "
"INNER JOIN event_forward_extremities as f "
"ON e.event_id = f.event_id "
+ "AND e.room_id = f.room_id "
"WHERE f.room_id = ?"
)
@@ -153,7 +165,7 @@ class EventFederationStore(SQLBaseStore):
results = self._get_prev_events_and_state(
txn,
event_id,
- is_state=1,
+ is_state=True,
)
return [(e_id, h, ) for e_id, h, _ in results]
@@ -164,7 +176,7 @@ class EventFederationStore(SQLBaseStore):
}
if is_state is not None:
- keyvalues["is_state"] = is_state
+ keyvalues["is_state"] = bool(is_state)
res = self._simple_select_list_txn(
txn,
@@ -242,7 +254,6 @@ class EventFederationStore(SQLBaseStore):
"room_id": room_id,
"min_depth": depth,
},
- or_replace=True,
)
def _handle_prev_events(self, txn, outlier, event_id, prev_events,
@@ -251,19 +262,19 @@ class EventFederationStore(SQLBaseStore):
For the given event, update the event edges table and forward and
backward extremities tables.
"""
- for e_id, _ in prev_events:
- # TODO (erikj): This could be done as a bulk insert
- self._simple_insert_txn(
- txn,
- table="event_edges",
- values={
+ self._simple_insert_many_txn(
+ txn,
+ table="event_edges",
+ values=[
+ {
"event_id": event_id,
"prev_event_id": e_id,
"room_id": room_id,
- "is_state": 0,
- },
- or_ignore=True,
- )
+ "is_state": False,
+ }
+ for e_id, _ in prev_events
+ ],
+ )
# Update the extremities table if this is not an outlier.
if not outlier:
@@ -281,33 +292,33 @@ class EventFederationStore(SQLBaseStore):
# We only insert as a forward extremity the new event if there are
# no other events that reference it as a prev event
query = (
- "INSERT OR IGNORE INTO %(table)s (event_id, room_id) "
- "SELECT ?, ? WHERE NOT EXISTS ("
- "SELECT 1 FROM %(event_edges)s WHERE "
- "prev_event_id = ? "
- ")"
- ) % {
- "table": "event_forward_extremities",
- "event_edges": "event_edges",
- }
+ "SELECT 1 FROM event_edges WHERE prev_event_id = ?"
+ )
- logger.debug("query: %s", query)
+ txn.execute(query, (event_id,))
+
+ if not txn.fetchone():
+ query = (
+ "INSERT INTO event_forward_extremities"
+ " (event_id, room_id)"
+ " VALUES (?, ?)"
+ )
- txn.execute(query, (event_id, room_id, event_id))
+ txn.execute(query, (event_id, room_id))
# Insert all the prev_events as a backwards thing, they'll get
# deleted in a second if they're incorrect anyway.
- for e_id, _ in prev_events:
- # TODO (erikj): This could be done as a bulk insert
- self._simple_insert_txn(
- txn,
- table="event_backward_extremities",
- values={
+ self._simple_insert_many_txn(
+ txn,
+ table="event_backward_extremities",
+ values=[
+ {
"event_id": e_id,
"room_id": room_id,
- },
- or_ignore=True,
- )
+ }
+ for e_id, _ in prev_events
+ ],
+ )
# Also delete from the backwards extremities table all ones that
# reference events that we have already seen
@@ -321,6 +332,10 @@ class EventFederationStore(SQLBaseStore):
)
txn.execute(query)
+ txn.call_after(
+ self.get_latest_event_ids_in_room.invalidate, room_id
+ )
+
def get_backfill_events(self, room_id, event_list, limit):
"""Get a list of Events for a given topic that occurred before (and
including) the events in event_list. Return a list of max size `limit`
@@ -400,7 +415,7 @@ class EventFederationStore(SQLBaseStore):
query = (
"SELECT prev_event_id FROM event_edges "
- "WHERE room_id = ? AND event_id = ? AND is_state = 0 "
+ "WHERE room_id = ? AND event_id = ? AND is_state = ? "
"LIMIT ?"
)
@@ -409,7 +424,7 @@ class EventFederationStore(SQLBaseStore):
for event_id in front:
txn.execute(
query,
- (room_id, event_id, limit - len(event_results))
+ (room_id, event_id, False, limit - len(event_results))
)
for e_id, in txn.fetchall():
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
new file mode 100644
index 0000000000..38395c66ab
--- /dev/null
+++ b/synapse/storage/events.py
@@ -0,0 +1,391 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014, 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from _base import SQLBaseStore, _RollbackButIsFineException
+
+from twisted.internet import defer
+
+from synapse.util.logutils import log_function
+from synapse.api.constants import EventTypes
+from synapse.crypto.event_signing import compute_event_reference_hash
+
+from syutil.base64util import decode_base64
+from syutil.jsonutil import encode_canonical_json
+
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class EventsStore(SQLBaseStore):
+ @defer.inlineCallbacks
+ @log_function
+ def persist_event(self, event, context, backfilled=False,
+ is_new_state=True, current_state=None):
+ stream_ordering = None
+ if backfilled:
+ if not self.min_token_deferred.called:
+ yield self.min_token_deferred
+ self.min_token -= 1
+ stream_ordering = self.min_token
+
+ try:
+ yield self.runInteraction(
+ "persist_event",
+ self._persist_event_txn,
+ event=event,
+ context=context,
+ backfilled=backfilled,
+ stream_ordering=stream_ordering,
+ is_new_state=is_new_state,
+ current_state=current_state,
+ )
+ except _RollbackButIsFineException:
+ pass
+
+ @defer.inlineCallbacks
+ def get_event(self, event_id, check_redacted=True,
+ get_prev_content=False, allow_rejected=False,
+ allow_none=False):
+ """Get an event from the database by event_id.
+
+ Args:
+ event_id (str): The event_id of the event to fetch
+ check_redacted (bool): If True, check if event has been redacted
+ and redact it.
+ get_prev_content (bool): If True and event is a state event,
+ include the previous states content in the unsigned field.
+ allow_rejected (bool): If True return rejected events.
+ allow_none (bool): If True, return None if no event found, if
+ False throw an exception.
+
+ Returns:
+ Deferred : A FrozenEvent.
+ """
+ event = yield self.runInteraction(
+ "get_event", self._get_event_txn,
+ event_id,
+ check_redacted=check_redacted,
+ get_prev_content=get_prev_content,
+ allow_rejected=allow_rejected,
+ )
+
+ if not event and not allow_none:
+ raise RuntimeError("Could not find event %s" % (event_id,))
+
+ defer.returnValue(event)
+
+ @log_function
+ def _persist_event_txn(self, txn, event, context, backfilled,
+ stream_ordering=None, is_new_state=True,
+ current_state=None):
+
+ # Remove the any existing cache entries for the event_id
+ txn.call_after(self._invalidate_get_event_cache, event.event_id)
+
+ if stream_ordering is None:
+ with self._stream_id_gen.get_next_txn(txn) as stream_ordering:
+ return self._persist_event_txn(
+ txn, event, context, backfilled,
+ stream_ordering=stream_ordering,
+ is_new_state=is_new_state,
+ current_state=current_state,
+ )
+
+ # We purposefully do this first since if we include a `current_state`
+ # key, we *want* to update the `current_state_events` table
+ if current_state:
+ self._simple_delete_txn(
+ txn,
+ table="current_state_events",
+ keyvalues={"room_id": event.room_id},
+ )
+
+ for s in current_state:
+ if s.type == EventTypes.Member:
+ txn.call_after(
+ self.get_rooms_for_user.invalidate, s.state_key
+ )
+ txn.call_after(
+ self.get_joined_hosts_for_room.invalidate, s.room_id
+ )
+ self._simple_insert_txn(
+ txn,
+ "current_state_events",
+ {
+ "event_id": s.event_id,
+ "room_id": s.room_id,
+ "type": s.type,
+ "state_key": s.state_key,
+ }
+ )
+
+ outlier = event.internal_metadata.is_outlier()
+
+ if not outlier:
+ self._store_state_groups_txn(txn, event, context)
+
+ self._update_min_depth_for_room_txn(
+ txn,
+ event.room_id,
+ event.depth
+ )
+
+ have_persisted = self._simple_select_one_onecol_txn(
+ txn,
+ table="event_json",
+ keyvalues={"event_id": event.event_id},
+ retcol="event_id",
+ allow_none=True,
+ )
+
+ metadata_json = encode_canonical_json(
+ event.internal_metadata.get_dict()
+ ).decode("UTF-8")
+
+ # If we have already persisted this event, we don't need to do any
+ # more processing.
+ # The processing above must be done on every call to persist event,
+ # since they might not have happened on previous calls. For example,
+ # if we are persisting an event that we had persisted as an outlier,
+ # but is no longer one.
+ if have_persisted:
+ if not outlier:
+ sql = (
+ "UPDATE event_json SET internal_metadata = ?"
+ " WHERE event_id = ?"
+ )
+ txn.execute(
+ sql,
+ (metadata_json, event.event_id,)
+ )
+
+ sql = (
+ "UPDATE events SET outlier = ?"
+ " WHERE event_id = ?"
+ )
+ txn.execute(
+ sql,
+ (False, event.event_id,)
+ )
+ return
+
+ self._handle_prev_events(
+ txn,
+ outlier=outlier,
+ event_id=event.event_id,
+ prev_events=event.prev_events,
+ room_id=event.room_id,
+ )
+
+ if event.type == EventTypes.Member:
+ self._store_room_member_txn(txn, event)
+ elif event.type == EventTypes.Name:
+ self._store_room_name_txn(txn, event)
+ elif event.type == EventTypes.Topic:
+ self._store_room_topic_txn(txn, event)
+ elif event.type == EventTypes.Redaction:
+ self._store_redaction(txn, event)
+
+ event_dict = {
+ k: v
+ for k, v in event.get_dict().items()
+ if k not in [
+ "redacted",
+ "redacted_because",
+ ]
+ }
+
+ self._simple_insert_txn(
+ txn,
+ table="event_json",
+ values={
+ "event_id": event.event_id,
+ "room_id": event.room_id,
+ "internal_metadata": metadata_json,
+ "json": encode_canonical_json(event_dict).decode("UTF-8"),
+ },
+ )
+
+ content = encode_canonical_json(
+ event.content
+ ).decode("UTF-8")
+
+ vals = {
+ "topological_ordering": event.depth,
+ "event_id": event.event_id,
+ "type": event.type,
+ "room_id": event.room_id,
+ "content": content,
+ "processed": True,
+ "outlier": outlier,
+ "depth": event.depth,
+ }
+
+ unrec = {
+ k: v
+ for k, v in event.get_dict().items()
+ if k not in vals.keys() and k not in [
+ "redacted",
+ "redacted_because",
+ "signatures",
+ "hashes",
+ "prev_events",
+ ]
+ }
+
+ vals["unrecognized_keys"] = encode_canonical_json(
+ unrec
+ ).decode("UTF-8")
+
+ sql = (
+ "INSERT INTO events"
+ " (stream_ordering, topological_ordering, event_id, type,"
+ " room_id, content, processed, outlier, depth)"
+ " VALUES (?,?,?,?,?,?,?,?,?)"
+ )
+
+ txn.execute(
+ sql,
+ (
+ stream_ordering, event.depth, event.event_id, event.type,
+ event.room_id, content, True, outlier, event.depth
+ )
+ )
+
+ if context.rejected:
+ self._store_rejections_txn(
+ txn, event.event_id, context.rejected
+ )
+
+ for hash_alg, hash_base64 in event.hashes.items():
+ hash_bytes = decode_base64(hash_base64)
+ self._store_event_content_hash_txn(
+ txn, event.event_id, hash_alg, hash_bytes,
+ )
+
+ for prev_event_id, prev_hashes in event.prev_events:
+ for alg, hash_base64 in prev_hashes.items():
+ hash_bytes = decode_base64(hash_base64)
+ self._store_prev_event_hash_txn(
+ txn, event.event_id, prev_event_id, alg,
+ hash_bytes
+ )
+
+ self._simple_insert_many_txn(
+ txn,
+ table="event_auth",
+ values=[
+ {
+ "event_id": event.event_id,
+ "room_id": event.room_id,
+ "auth_id": auth_id,
+ }
+ for auth_id, _ in event.auth_events
+ ],
+ )
+
+ (ref_alg, ref_hash_bytes) = compute_event_reference_hash(event)
+ self._store_event_reference_hash_txn(
+ txn, event.event_id, ref_alg, ref_hash_bytes
+ )
+
+ if event.is_state():
+ vals = {
+ "event_id": event.event_id,
+ "room_id": event.room_id,
+ "type": event.type,
+ "state_key": event.state_key,
+ }
+
+ # TODO: How does this work with backfilling?
+ if hasattr(event, "replaces_state"):
+ vals["prev_state"] = event.replaces_state
+
+ self._simple_insert_txn(
+ txn,
+ "state_events",
+ vals,
+ )
+
+ self._simple_insert_many_txn(
+ txn,
+ table="event_edges",
+ values=[
+ {
+ "event_id": event.event_id,
+ "prev_event_id": e_id,
+ "room_id": event.room_id,
+ "is_state": True,
+ }
+ for e_id, h in event.prev_state
+ ],
+ )
+
+ if is_new_state and not context.rejected:
+ self._simple_upsert_txn(
+ txn,
+ "current_state_events",
+ keyvalues={
+ "room_id": event.room_id,
+ "type": event.type,
+ "state_key": event.state_key,
+ },
+ values={
+ "event_id": event.event_id,
+ }
+ )
+
+ return
+
+ def _store_redaction(self, txn, event):
+ # invalidate the cache for the redacted event
+ txn.call_after(self._invalidate_get_event_cache, event.redacts)
+ txn.execute(
+ "INSERT INTO redactions (event_id, redacts) VALUES (?,?)",
+ (event.event_id, event.redacts)
+ )
+
+ def have_events(self, event_ids):
+ """Given a list of event ids, check if we have already processed them.
+
+ Returns:
+ dict: Has an entry for each event id we already have seen. Maps to
+ the rejected reason string if we rejected the event, else maps to
+ None.
+ """
+ if not event_ids:
+ return defer.succeed({})
+
+ def f(txn):
+ sql = (
+ "SELECT e.event_id, reason FROM events as e "
+ "LEFT JOIN rejections as r ON e.event_id = r.event_id "
+ "WHERE e.event_id = ?"
+ )
+
+ res = {}
+ for event_id in event_ids:
+ txn.execute(sql, (event_id,))
+ row = txn.fetchone()
+ if row:
+ _, rejected = row
+ res[event_id] = rejected
+
+ return res
+
+ return self.runInteraction(
+ "have_events", f,
+ )
diff --git a/synapse/storage/feedback.py b/synapse/storage/feedback.py
deleted file mode 100644
index 8eab769b71..0000000000
--- a/synapse/storage/feedback.py
+++ /dev/null
@@ -1,47 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2014, 2015 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from twisted.internet import defer
-
-from ._base import SQLBaseStore
-
-
-class FeedbackStore(SQLBaseStore):
-
- def _store_feedback_txn(self, txn, event):
- self._simple_insert_txn(txn, "feedback", {
- "event_id": event.event_id,
- "feedback_type": event.content["type"],
- "room_id": event.room_id,
- "target_event_id": event.content["target_event_id"],
- "sender": event.user_id,
- })
-
- @defer.inlineCallbacks
- def get_feedback_for_event(self, event_id):
- sql = (
- "SELECT events.* FROM events INNER JOIN feedback "
- "ON events.event_id = feedback.event_id "
- "WHERE feedback.target_event_id = ? "
- )
-
- rows = yield self._execute_and_decode("get_feedback_for_event", sql, event_id)
-
- defer.returnValue(
- [
- (yield self._parse_events(r))
- for r in rows
- ]
- )
diff --git a/synapse/storage/filtering.py b/synapse/storage/filtering.py
index 457a11fd02..8800116570 100644
--- a/synapse/storage/filtering.py
+++ b/synapse/storage/filtering.py
@@ -31,6 +31,7 @@ class FilteringStore(SQLBaseStore):
},
retcol="filter_json",
allow_none=False,
+ desc="get_user_filter",
)
defer.returnValue(json.loads(def_json))
diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py
index 09d1e63657..5bdf497b93 100644
--- a/synapse/storage/keys.py
+++ b/synapse/storage/keys.py
@@ -57,16 +57,18 @@ class KeyStore(SQLBaseStore):
OpenSSL.crypto.FILETYPE_ASN1, tls_certificate
)
fingerprint = hashlib.sha256(tls_certificate_bytes).hexdigest()
- return self._simple_insert(
+ return self._simple_upsert(
table="server_tls_certificates",
- values={
+ keyvalues={
"server_name": server_name,
"fingerprint": fingerprint,
+ },
+ values={
"from_server": from_server,
"ts_added_ms": time_now_ms,
"tls_certificate": buffer(tls_certificate_bytes),
},
- or_ignore=True,
+ desc="store_server_certificate",
)
@defer.inlineCallbacks
@@ -107,14 +109,85 @@ class KeyStore(SQLBaseStore):
ts_now_ms (int): The time now in milliseconds
verification_key (VerifyKey): The NACL verify key.
"""
- return self._simple_insert(
+ return self._simple_upsert(
table="server_signature_keys",
- values={
+ keyvalues={
"server_name": server_name,
"key_id": "%s:%s" % (verify_key.alg, verify_key.version),
+ },
+ values={
"from_server": from_server,
"ts_added_ms": time_now_ms,
"verify_key": buffer(verify_key.encode()),
},
- or_ignore=True,
+ desc="store_server_verify_key",
+ )
+
+ def store_server_keys_json(self, server_name, key_id, from_server,
+ ts_now_ms, ts_expires_ms, key_json_bytes):
+ """Stores the JSON bytes for a set of keys from a server
+ The JSON should be signed by the originating server, the intermediate
+ server, and by this server. Updates the value for the
+ (server_name, key_id, from_server) triplet if one already existed.
+ Args:
+ server_name (str): The name of the server.
+ key_id (str): The identifer of the key this JSON is for.
+ from_server (str): The server this JSON was fetched from.
+ ts_now_ms (int): The time now in milliseconds.
+ ts_valid_until_ms (int): The time when this json stops being valid.
+ key_json (bytes): The encoded JSON.
+ """
+ return self._simple_upsert(
+ table="server_keys_json",
+ keyvalues={
+ "server_name": server_name,
+ "key_id": key_id,
+ "from_server": from_server,
+ },
+ values={
+ "server_name": server_name,
+ "key_id": key_id,
+ "from_server": from_server,
+ "ts_added_ms": ts_now_ms,
+ "ts_valid_until_ms": ts_expires_ms,
+ "key_json": buffer(key_json_bytes),
+ },
+ )
+
+ def get_server_keys_json(self, server_keys):
+ """Retrive the key json for a list of server_keys and key ids.
+ If no keys are found for a given server, key_id and source then
+ that server, key_id, and source triplet entry will be an empty list.
+ The JSON is returned as a byte array so that it can be efficiently
+ used in an HTTP response.
+ Args:
+ server_keys (list): List of (server_name, key_id, source) triplets.
+ Returns:
+ Dict mapping (server_name, key_id, source) triplets to dicts with
+ "ts_valid_until_ms" and "key_json" keys.
+ """
+ def _get_server_keys_json_txn(txn):
+ results = {}
+ for server_name, key_id, from_server in server_keys:
+ keyvalues = {"server_name": server_name}
+ if key_id is not None:
+ keyvalues["key_id"] = key_id
+ if from_server is not None:
+ keyvalues["from_server"] = from_server
+ rows = self._simple_select_list_txn(
+ txn,
+ "server_keys_json",
+ keyvalues=keyvalues,
+ retcols=(
+ "key_id",
+ "from_server",
+ "ts_added_ms",
+ "ts_valid_until_ms",
+ "key_json",
+ ),
+ )
+ results[(server_name, key_id, from_server)] = rows
+ return results
+ return self.runInteraction(
+ "get_server_keys_json", _get_server_keys_json_txn
)
diff --git a/synapse/storage/media_repository.py b/synapse/storage/media_repository.py
index 7101d2beec..7bf57234f6 100644
--- a/synapse/storage/media_repository.py
+++ b/synapse/storage/media_repository.py
@@ -32,6 +32,7 @@ class MediaRepositoryStore(SQLBaseStore):
{"media_id": media_id},
("media_type", "media_length", "upload_name", "created_ts"),
allow_none=True,
+ desc="get_local_media",
)
def store_local_media(self, media_id, media_type, time_now_ms, upload_name,
@@ -45,7 +46,8 @@ class MediaRepositoryStore(SQLBaseStore):
"upload_name": upload_name,
"media_length": media_length,
"user_id": user_id.to_string(),
- }
+ },
+ desc="store_local_media",
)
def get_local_media_thumbnails(self, media_id):
@@ -55,7 +57,8 @@ class MediaRepositoryStore(SQLBaseStore):
(
"thumbnail_width", "thumbnail_height", "thumbnail_method",
"thumbnail_type", "thumbnail_length",
- )
+ ),
+ desc="get_local_media_thumbnails",
)
def store_local_thumbnail(self, media_id, thumbnail_width,
@@ -70,7 +73,8 @@ class MediaRepositoryStore(SQLBaseStore):
"thumbnail_method": thumbnail_method,
"thumbnail_type": thumbnail_type,
"thumbnail_length": thumbnail_length,
- }
+ },
+ desc="store_local_thumbnail",
)
def get_cached_remote_media(self, origin, media_id):
@@ -82,6 +86,7 @@ class MediaRepositoryStore(SQLBaseStore):
"filesystem_id",
),
allow_none=True,
+ desc="get_cached_remote_media",
)
def store_cached_remote_media(self, origin, media_id, media_type,
@@ -97,7 +102,8 @@ class MediaRepositoryStore(SQLBaseStore):
"created_ts": time_now_ms,
"upload_name": upload_name,
"filesystem_id": filesystem_id,
- }
+ },
+ desc="store_cached_remote_media",
)
def get_remote_media_thumbnails(self, origin, media_id):
@@ -107,7 +113,8 @@ class MediaRepositoryStore(SQLBaseStore):
(
"thumbnail_width", "thumbnail_height", "thumbnail_method",
"thumbnail_type", "thumbnail_length", "filesystem_id",
- )
+ ),
+ desc="get_remote_media_thumbnails",
)
def store_remote_media_thumbnail(self, origin, media_id, filesystem_id,
@@ -125,5 +132,6 @@ class MediaRepositoryStore(SQLBaseStore):
"thumbnail_type": thumbnail_type,
"thumbnail_length": thumbnail_length,
"filesystem_id": filesystem_id,
- }
+ },
+ desc="store_remote_media_thumbnail",
)
diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py
index 1dcd34723b..22ec94bc16 100644
--- a/synapse/storage/presence.py
+++ b/synapse/storage/presence.py
@@ -21,6 +21,7 @@ class PresenceStore(SQLBaseStore):
return self._simple_insert(
table="presence",
values={"user_id": user_localpart},
+ desc="create_presence",
)
def has_presence_state(self, user_localpart):
@@ -29,6 +30,7 @@ class PresenceStore(SQLBaseStore):
keyvalues={"user_id": user_localpart},
retcols=["user_id"],
allow_none=True,
+ desc="has_presence_state",
)
def get_presence_state(self, user_localpart):
@@ -36,6 +38,7 @@ class PresenceStore(SQLBaseStore):
table="presence",
keyvalues={"user_id": user_localpart},
retcols=["state", "status_msg", "mtime"],
+ desc="get_presence_state",
)
def set_presence_state(self, user_localpart, new_state):
@@ -45,7 +48,7 @@ class PresenceStore(SQLBaseStore):
updatevalues={"state": new_state["state"],
"status_msg": new_state["status_msg"],
"mtime": self._clock.time_msec()},
- retcols=["state"],
+ desc="set_presence_state",
)
def allow_presence_visible(self, observed_localpart, observer_userid):
@@ -53,6 +56,8 @@ class PresenceStore(SQLBaseStore):
table="presence_allow_inbound",
values={"observed_user_id": observed_localpart,
"observer_user_id": observer_userid},
+ desc="allow_presence_visible",
+ or_ignore=True,
)
def disallow_presence_visible(self, observed_localpart, observer_userid):
@@ -60,6 +65,7 @@ class PresenceStore(SQLBaseStore):
table="presence_allow_inbound",
keyvalues={"observed_user_id": observed_localpart,
"observer_user_id": observer_userid},
+ desc="disallow_presence_visible",
)
def is_presence_visible(self, observed_localpart, observer_userid):
@@ -69,6 +75,7 @@ class PresenceStore(SQLBaseStore):
"observer_user_id": observer_userid},
retcols=["observed_user_id"],
allow_none=True,
+ desc="is_presence_visible",
)
def add_presence_list_pending(self, observer_localpart, observed_userid):
@@ -77,6 +84,7 @@ class PresenceStore(SQLBaseStore):
values={"user_id": observer_localpart,
"observed_user_id": observed_userid,
"accepted": False},
+ desc="add_presence_list_pending",
)
def set_presence_list_accepted(self, observer_localpart, observed_userid):
@@ -85,6 +93,7 @@ class PresenceStore(SQLBaseStore):
keyvalues={"user_id": observer_localpart,
"observed_user_id": observed_userid},
updatevalues={"accepted": True},
+ desc="set_presence_list_accepted",
)
def get_presence_list(self, observer_localpart, accepted=None):
@@ -96,6 +105,7 @@ class PresenceStore(SQLBaseStore):
table="presence_list",
keyvalues=keyvalues,
retcols=["observed_user_id", "accepted"],
+ desc="get_presence_list",
)
def del_presence_list(self, observer_localpart, observed_userid):
@@ -103,4 +113,5 @@ class PresenceStore(SQLBaseStore):
table="presence_list",
keyvalues={"user_id": observer_localpart,
"observed_user_id": observed_userid},
+ desc="del_presence_list",
)
diff --git a/synapse/storage/profile.py b/synapse/storage/profile.py
index 153c7ad027..a6e52cb248 100644
--- a/synapse/storage/profile.py
+++ b/synapse/storage/profile.py
@@ -21,6 +21,7 @@ class ProfileStore(SQLBaseStore):
return self._simple_insert(
table="profiles",
values={"user_id": user_localpart},
+ desc="create_profile",
)
def get_profile_displayname(self, user_localpart):
@@ -28,6 +29,7 @@ class ProfileStore(SQLBaseStore):
table="profiles",
keyvalues={"user_id": user_localpart},
retcol="displayname",
+ desc="get_profile_displayname",
)
def set_profile_displayname(self, user_localpart, new_displayname):
@@ -35,6 +37,7 @@ class ProfileStore(SQLBaseStore):
table="profiles",
keyvalues={"user_id": user_localpart},
updatevalues={"displayname": new_displayname},
+ desc="set_profile_displayname",
)
def get_profile_avatar_url(self, user_localpart):
@@ -42,6 +45,7 @@ class ProfileStore(SQLBaseStore):
table="profiles",
keyvalues={"user_id": user_localpart},
retcol="avatar_url",
+ desc="get_profile_avatar_url",
)
def set_profile_avatar_url(self, user_localpart, new_avatar_url):
@@ -49,4 +53,5 @@ class ProfileStore(SQLBaseStore):
table="profiles",
keyvalues={"user_id": user_localpart},
updatevalues={"avatar_url": new_avatar_url},
+ desc="set_profile_avatar_url",
)
diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py
index d769db2c78..ee7718d5ed 100644
--- a/synapse/storage/push_rule.py
+++ b/synapse/storage/push_rule.py
@@ -50,7 +50,8 @@ class PushRuleStore(SQLBaseStore):
results = yield self._simple_select_list(
PushRuleEnableTable.table_name,
{'user_name': user_name},
- PushRuleEnableTable.fields
+ PushRuleEnableTable.fields,
+ desc="get_push_rules_enabled_for_user",
)
defer.returnValue(
{r['rule_id']: False if r['enabled'] == 0 else True for r in results}
@@ -153,7 +154,7 @@ class PushRuleStore(SQLBaseStore):
txn.execute(sql, (user_name, priority_class, new_rule_priority))
# now insert the new rule
- sql = "INSERT OR REPLACE INTO "+PushRuleTable.table_name+" ("
+ sql = "INSERT INTO "+PushRuleTable.table_name+" ("
sql += ",".join(new_rule.keys())+") VALUES ("
sql += ", ".join(["?" for _ in new_rule.keys()])+")"
@@ -182,7 +183,7 @@ class PushRuleStore(SQLBaseStore):
new_rule['priority_class'] = priority_class
new_rule['priority'] = new_prio
- sql = "INSERT OR REPLACE INTO "+PushRuleTable.table_name+" ("
+ sql = "INSERT INTO "+PushRuleTable.table_name+" ("
sql += ",".join(new_rule.keys())+") VALUES ("
sql += ", ".join(["?" for _ in new_rule.keys()])+")"
@@ -201,7 +202,8 @@ class PushRuleStore(SQLBaseStore):
"""
yield self._simple_delete_one(
PushRuleTable.table_name,
- {'user_name': user_name, 'rule_id': rule_id}
+ {'user_name': user_name, 'rule_id': rule_id},
+ desc="delete_push_rule",
)
@defer.inlineCallbacks
@@ -209,7 +211,8 @@ class PushRuleStore(SQLBaseStore):
yield self._simple_upsert(
PushRuleEnableTable.table_name,
{'user_name': user_name, 'rule_id': rule_id},
- {'enabled': enabled}
+ {'enabled': enabled},
+ desc="set_push_rule_enabled",
)
diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py
index 587dada68f..08ea62681b 100644
--- a/synapse/storage/pusher.py
+++ b/synapse/storage/pusher.py
@@ -13,162 +13,141 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import collections
-
from ._base import SQLBaseStore, Table
from twisted.internet import defer
from synapse.api.errors import StoreError
+from syutil.jsonutil import encode_canonical_json
+
import logging
+import simplejson as json
+import types
logger = logging.getLogger(__name__)
class PusherStore(SQLBaseStore):
+ def _decode_pushers_rows(self, rows):
+ for r in rows:
+ dataJson = r['data']
+ r['data'] = None
+ try:
+ if isinstance(dataJson, types.BufferType):
+ dataJson = str(dataJson).decode("UTF8")
+
+ r['data'] = json.loads(dataJson)
+ except Exception as e:
+ logger.warn(
+ "Invalid JSON in data for pusher %d: %s, %s",
+ r['id'], dataJson, e.message,
+ )
+ pass
+
+ if isinstance(r['pushkey'], types.BufferType):
+ r['pushkey'] = str(r['pushkey']).decode("UTF8")
+
+ return rows
+
@defer.inlineCallbacks
- def get_pushers_by_app_id_and_pushkey(self, app_id_and_pushkey):
- sql = (
- "SELECT id, user_name, kind, profile_tag, app_id,"
- "app_display_name, device_display_name, pushkey, ts, data, "
- "last_token, last_success, failing_since "
- "FROM pushers "
- "WHERE app_id = ? AND pushkey = ?"
- )
+ def get_pushers_by_app_id_and_pushkey(self, app_id, pushkey):
+ def r(txn):
+ sql = (
+ "SELECT * FROM pushers"
+ " WHERE app_id = ? AND pushkey = ?"
+ )
- rows = yield self._execute(
- "get_pushers_by_app_id_and_pushkey", None, sql,
- app_id_and_pushkey[0], app_id_and_pushkey[1]
+ txn.execute(sql, (app_id, pushkey,))
+ rows = self.cursor_to_dict(txn)
+
+ return self._decode_pushers_rows(rows)
+
+ rows = yield self.runInteraction(
+ "get_pushers_by_app_id_and_pushkey", r
)
- ret = [
- {
- "id": r[0],
- "user_name": r[1],
- "kind": r[2],
- "profile_tag": r[3],
- "app_id": r[4],
- "app_display_name": r[5],
- "device_display_name": r[6],
- "pushkey": r[7],
- "pushkey_ts": r[8],
- "data": r[9],
- "last_token": r[10],
- "last_success": r[11],
- "failing_since": r[12]
- }
- for r in rows
- ]
-
- defer.returnValue(ret[0])
+ defer.returnValue(rows)
@defer.inlineCallbacks
def get_all_pushers(self):
- sql = (
- "SELECT id, user_name, kind, profile_tag, app_id,"
- "app_display_name, device_display_name, pushkey, ts, data, "
- "last_token, last_success, failing_since "
- "FROM pushers"
- )
+ def get_pushers(txn):
+ txn.execute("SELECT * FROM pushers")
+ rows = self.cursor_to_dict(txn)
- rows = yield self._execute("get_all_pushers", None, sql)
-
- ret = [
- {
- "id": r[0],
- "user_name": r[1],
- "kind": r[2],
- "profile_tag": r[3],
- "app_id": r[4],
- "app_display_name": r[5],
- "device_display_name": r[6],
- "pushkey": r[7],
- "pushkey_ts": r[8],
- "data": r[9],
- "last_token": r[10],
- "last_success": r[11],
- "failing_since": r[12]
- }
- for r in rows
- ]
-
- defer.returnValue(ret)
+ return self._decode_pushers_rows(rows)
+
+ rows = yield self.runInteraction("get_all_pushers", get_pushers)
+ defer.returnValue(rows)
@defer.inlineCallbacks
- def add_pusher(self, user_name, profile_tag, kind, app_id,
+ def add_pusher(self, user_name, access_token, profile_tag, kind, app_id,
app_display_name, device_display_name,
pushkey, pushkey_ts, lang, data):
try:
+ next_id = yield self._pushers_id_gen.get_next()
yield self._simple_upsert(
PushersTable.table_name,
dict(
app_id=app_id,
pushkey=pushkey,
+ user_name=user_name,
),
dict(
- user_name=user_name,
+ access_token=access_token,
kind=kind,
profile_tag=profile_tag,
app_display_name=app_display_name,
device_display_name=device_display_name,
ts=pushkey_ts,
lang=lang,
- data=data
- ))
+ data=encode_canonical_json(data),
+ ),
+ insertion_values=dict(
+ id=next_id,
+ ),
+ desc="add_pusher",
+ )
except Exception as e:
logger.error("create_pusher with failed: %s", e)
raise StoreError(500, "Problem creating pusher.")
@defer.inlineCallbacks
- def delete_pusher_by_app_id_pushkey(self, app_id, pushkey):
+ def delete_pusher_by_app_id_pushkey_user_name(self, app_id, pushkey, user_name):
yield self._simple_delete_one(
PushersTable.table_name,
- dict(app_id=app_id, pushkey=pushkey)
+ {"app_id": app_id, "pushkey": pushkey, 'user_name': user_name},
+ desc="delete_pusher_by_app_id_pushkey_user_name",
)
@defer.inlineCallbacks
- def update_pusher_last_token(self, app_id, pushkey, last_token):
+ def update_pusher_last_token(self, app_id, pushkey, user_name, last_token):
yield self._simple_update_one(
PushersTable.table_name,
- {'app_id': app_id, 'pushkey': pushkey},
- {'last_token': last_token}
+ {'app_id': app_id, 'pushkey': pushkey, 'user_name': user_name},
+ {'last_token': last_token},
+ desc="update_pusher_last_token",
)
@defer.inlineCallbacks
- def update_pusher_last_token_and_success(self, app_id, pushkey,
+ def update_pusher_last_token_and_success(self, app_id, pushkey, user_name,
last_token, last_success):
yield self._simple_update_one(
PushersTable.table_name,
- {'app_id': app_id, 'pushkey': pushkey},
- {'last_token': last_token, 'last_success': last_success}
+ {'app_id': app_id, 'pushkey': pushkey, 'user_name': user_name},
+ {'last_token': last_token, 'last_success': last_success},
+ desc="update_pusher_last_token_and_success",
)
@defer.inlineCallbacks
- def update_pusher_failing_since(self, app_id, pushkey, failing_since):
+ def update_pusher_failing_since(self, app_id, pushkey, user_name,
+ failing_since):
yield self._simple_update_one(
PushersTable.table_name,
- {'app_id': app_id, 'pushkey': pushkey},
- {'failing_since': failing_since}
+ {'app_id': app_id, 'pushkey': pushkey, 'user_name': user_name},
+ {'failing_since': failing_since},
+ desc="update_pusher_failing_since",
)
class PushersTable(Table):
table_name = "pushers"
-
- fields = [
- "id",
- "user_name",
- "kind",
- "profile_tag",
- "app_id",
- "app_display_name",
- "device_display_name",
- "pushkey",
- "pushkey_ts",
- "data",
- "last_token",
- "last_success",
- "failing_since"
- ]
-
- EntryType = collections.namedtuple("PusherEntry", fields)
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 3c2f1d6a15..90e2606be2 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -15,8 +15,6 @@
from twisted.internet import defer
-from sqlite3 import IntegrityError
-
from synapse.api.errors import StoreError, Codes
from ._base import SQLBaseStore, cached
@@ -39,16 +37,16 @@ class RegistrationStore(SQLBaseStore):
Raises:
StoreError if there was a problem adding this.
"""
- row = yield self._simple_select_one("users", {"name": user_id}, ["id"])
- if not row:
- raise StoreError(400, "Bad user ID supplied.")
- row_id = row["id"]
+ next_id = yield self._access_tokens_id_gen.get_next()
+
yield self._simple_insert(
"access_tokens",
{
- "user_id": row_id,
+ "id": next_id,
+ "user_id": user_id,
"token": token
- }
+ },
+ desc="add_access_token_to_user",
)
@defer.inlineCallbacks
@@ -70,32 +68,72 @@ class RegistrationStore(SQLBaseStore):
def _register(self, txn, user_id, token, password_hash):
now = int(self.clock.time())
+ next_id = self._access_tokens_id_gen.get_next_txn(txn)
+
try:
txn.execute("INSERT INTO users(name, password_hash, creation_ts) "
"VALUES (?,?,?)",
[user_id, password_hash, now])
- except IntegrityError:
+ except self.database_engine.module.IntegrityError:
raise StoreError(
400, "User ID already taken.", errcode=Codes.USER_IN_USE
)
# it's possible for this to get a conflict, but only for a single user
# since tokens are namespaced based on their user ID
- txn.execute("INSERT INTO access_tokens(user_id, token) " +
- "VALUES (?,?)", [txn.lastrowid, token])
+ txn.execute(
+ "INSERT INTO access_tokens(id, user_id, token)"
+ " VALUES (?,?,?)",
+ (next_id, user_id, token,)
+ )
def get_user_by_id(self, user_id):
- query = ("SELECT users.name, users.password_hash FROM users"
- " WHERE users.name = ?")
- return self._execute(
- "get_user_by_id", self.cursor_to_dict, query, user_id
+ return self._simple_select_one(
+ table="users",
+ keyvalues={
+ "name": user_id,
+ },
+ retcols=["name", "password_hash"],
+ allow_none=True,
+ )
+
+ @defer.inlineCallbacks
+ def user_set_password_hash(self, user_id, password_hash):
+ """
+ NB. This does *not* evict any cache because the one use for this
+ removes most of the entries subsequently anyway so it would be
+ pointless. Use flush_user separately.
+ """
+ yield self._simple_update_one('users', {
+ 'name': user_id
+ }, {
+ 'password_hash': password_hash
+ })
+
+ @defer.inlineCallbacks
+ def user_delete_access_tokens_apart_from(self, user_id, token_id):
+ yield self.runInteraction(
+ "user_delete_access_tokens_apart_from",
+ self._user_delete_access_tokens_apart_from, user_id, token_id
)
+ def _user_delete_access_tokens_apart_from(self, txn, user_id, token_id):
+ txn.execute(
+ "DELETE FROM access_tokens WHERE user_id = ? AND id != ?",
+ (user_id, token_id)
+ )
+
+ @defer.inlineCallbacks
+ def flush_user(self, user_id):
+ rows = yield self._execute(
+ 'flush_user', None,
+ "SELECT token FROM access_tokens WHERE user_id = ?",
+ user_id
+ )
+ for r in rows:
+ self.get_user_by_token.invalidate(r)
+
@cached()
- # TODO(paul): Currently there's no code to invalidate this cache. That
- # means if/when we ever add internal ways to invalidate access tokens or
- # change whether a user is a server admin, those will need to invoke
- # store.get_user_by_token.invalidate(token)
def get_user_by_token(self, token):
"""Get a user from the given access token.
@@ -120,6 +158,7 @@ class RegistrationStore(SQLBaseStore):
keyvalues={"name": user.to_string()},
retcol="admin",
allow_none=True,
+ desc="is_server_admin",
)
defer.returnValue(res if res else False)
@@ -129,13 +168,49 @@ class RegistrationStore(SQLBaseStore):
"SELECT users.name, users.admin,"
" access_tokens.device_id, access_tokens.id as token_id"
" FROM users"
- " INNER JOIN access_tokens on users.id = access_tokens.user_id"
+ " INNER JOIN access_tokens on users.name = access_tokens.user_id"
" WHERE token = ?"
)
- cursor = txn.execute(sql, (token,))
- rows = self.cursor_to_dict(cursor)
+ txn.execute(sql, (token,))
+ rows = self.cursor_to_dict(txn)
if rows:
return rows[0]
- raise StoreError(404, "Token not found.")
+ return None
+
+ @defer.inlineCallbacks
+ def user_add_threepid(self, user_id, medium, address, validated_at, added_at):
+ yield self._simple_upsert("user_threepids", {
+ "user_id": user_id,
+ "medium": medium,
+ "address": address,
+ }, {
+ "validated_at": validated_at,
+ "added_at": added_at,
+ })
+
+ @defer.inlineCallbacks
+ def user_get_threepids(self, user_id):
+ ret = yield self._simple_select_list(
+ "user_threepids", {
+ "user_id": user_id
+ },
+ ['medium', 'address', 'validated_at', 'added_at'],
+ 'user_get_threepids'
+ )
+ defer.returnValue(ret)
+
+ @defer.inlineCallbacks
+ def get_user_id_by_threepid(self, medium, address):
+ ret = yield self._simple_select_one(
+ "user_threepids",
+ {
+ "medium": medium,
+ "address": address
+ },
+ ['user_id'], True, 'get_user_id_by_threepid'
+ )
+ if ret:
+ defer.returnValue(ret['user_id'])
+ defer.returnValue(None)
diff --git a/synapse/storage/rejections.py b/synapse/storage/rejections.py
index 4e1a9a2783..0838eb3d12 100644
--- a/synapse/storage/rejections.py
+++ b/synapse/storage/rejections.py
@@ -29,7 +29,7 @@ class RejectionsStore(SQLBaseStore):
"event_id": event_id,
"reason": reason,
"last_check": self._clock.time_msec(),
- }
+ },
)
def get_rejection_reason(self, event_id):
@@ -40,4 +40,5 @@ class RejectionsStore(SQLBaseStore):
"event_id": event_id,
},
allow_none=True,
+ desc="get_rejection_reason",
)
diff --git a/synapse/storage/room.py b/synapse/storage/room.py
index 549c9af393..f956377632 100644
--- a/synapse/storage/room.py
+++ b/synapse/storage/room.py
@@ -15,11 +15,9 @@
from twisted.internet import defer
-from sqlite3 import IntegrityError
-
from synapse.api.errors import StoreError
-from ._base import SQLBaseStore, Table
+from ._base import SQLBaseStore
import collections
import logging
@@ -27,8 +25,9 @@ import logging
logger = logging.getLogger(__name__)
-OpsLevel = collections.namedtuple("OpsLevel", (
- "ban_level", "kick_level", "redact_level")
+OpsLevel = collections.namedtuple(
+ "OpsLevel",
+ ("ban_level", "kick_level", "redact_level",)
)
@@ -47,13 +46,15 @@ class RoomStore(SQLBaseStore):
StoreError if the room could not be stored.
"""
try:
- yield self._simple_insert(RoomsTable.table_name, dict(
- room_id=room_id,
- creator=room_creator_user_id,
- is_public=is_public
- ))
- except IntegrityError:
- raise StoreError(409, "Room ID in use.")
+ yield self._simple_insert(
+ RoomsTable.table_name,
+ {
+ "room_id": room_id,
+ "creator": room_creator_user_id,
+ "is_public": is_public,
+ },
+ desc="store_room",
+ )
except Exception as e:
logger.error("store_room with room_id=%s failed: %s", room_id, e)
raise StoreError(500, "Problem creating room.")
@@ -66,9 +67,22 @@ class RoomStore(SQLBaseStore):
Returns:
A namedtuple containing the room information, or an empty list.
"""
- query = RoomsTable.select_statement("room_id=?")
- return self._execute(
- "get_room", RoomsTable.decode_single_result, query, room_id,
+ return self._simple_select_one(
+ table=RoomsTable.table_name,
+ keyvalues={"room_id": room_id},
+ retcols=RoomsTable.fields,
+ desc="get_room",
+ allow_none=True,
+ )
+
+ def get_public_room_ids(self):
+ return self._simple_select_onecol(
+ table="rooms",
+ keyvalues={
+ "is_public": True,
+ },
+ retcol="room_id",
+ desc="get_public_room_ids",
)
@defer.inlineCallbacks
@@ -99,24 +113,37 @@ class RoomStore(SQLBaseStore):
"ON c.event_id = room_names.event_id "
)
- # We use non printing ascii character US () as a seperator
+ # We use non printing ascii character US (\x1F) as a separator
sql = (
- "SELECT r.room_id, n.name, t.topic, "
- "group_concat(a.room_alias, '') "
- "FROM rooms AS r "
- "LEFT JOIN (%(topic)s) AS t ON t.room_id = r.room_id "
- "LEFT JOIN (%(name)s) AS n ON n.room_id = r.room_id "
- "INNER JOIN room_aliases AS a ON a.room_id = r.room_id "
- "WHERE r.is_public = ? "
- "GROUP BY r.room_id "
+ "SELECT r.room_id, max(n.name), max(t.topic)"
+ " FROM rooms AS r"
+ " LEFT JOIN (%(topic)s) AS t ON t.room_id = r.room_id"
+ " LEFT JOIN (%(name)s) AS n ON n.room_id = r.room_id"
+ " WHERE r.is_public = ?"
+ " GROUP BY r.room_id"
) % {
"topic": topic_subquery,
"name": name_subquery,
}
- c = txn.execute(sql, (is_public,))
+ txn.execute(sql, (is_public,))
+
+ rows = txn.fetchall()
- return c.fetchall()
+ for i, row in enumerate(rows):
+ room_id = row[0]
+ aliases = self._simple_select_onecol_txn(
+ txn,
+ table="room_aliases",
+ keyvalues={
+ "room_id": room_id
+ },
+ retcol="room_alias",
+ )
+
+ rows[i] = list(row) + [aliases]
+
+ return rows
rows = yield self.runInteraction(
"get_rooms", f
@@ -127,9 +154,10 @@ class RoomStore(SQLBaseStore):
"room_id": r[0],
"name": r[1],
"topic": r[2],
- "aliases": r[3].split(""),
+ "aliases": r[3],
}
for r in rows
+ if r[3] # We only return rooms that have at least one alias.
]
defer.returnValue(ret)
@@ -143,7 +171,7 @@ class RoomStore(SQLBaseStore):
"event_id": event.event_id,
"room_id": event.room_id,
"topic": event.content["topic"],
- }
+ },
)
def _store_room_name_txn(self, txn, event):
@@ -158,8 +186,39 @@ class RoomStore(SQLBaseStore):
}
)
+ @defer.inlineCallbacks
+ def get_room_name_and_aliases(self, room_id):
+ def f(txn):
+ sql = (
+ "SELECT event_id FROM current_state_events "
+ "WHERE room_id = ? "
+ )
+
+ sql += " AND ((type = 'm.room.name' AND state_key = '')"
+ sql += " OR type = 'm.room.aliases')"
+
+ txn.execute(sql, (room_id,))
+ results = self.cursor_to_dict(txn)
+
+ return self._parse_events_txn(txn, results)
+
+ events = yield self.runInteraction("get_room_name_and_aliases", f)
-class RoomsTable(Table):
+ name = None
+ aliases = []
+
+ for e in events:
+ if e.type == 'm.room.name':
+ if 'name' in e.content:
+ name = e.content['name']
+ elif e.type == 'm.room.aliases':
+ if 'aliases' in e.content:
+ aliases.extend(e.content['aliases'])
+
+ defer.returnValue((name, aliases))
+
+
+class RoomsTable(object):
table_name = "rooms"
fields = [
@@ -167,5 +226,3 @@ class RoomsTable(Table):
"is_public",
"creator"
]
-
- EntryType = collections.namedtuple("RoomEntry", fields)
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 65ffb4627f..839c74f63a 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -40,7 +40,6 @@ class RoomMemberStore(SQLBaseStore):
"""
try:
target_user_id = event.state_key
- domain = UserID.from_string(target_user_id).domain
except:
logger.exception(
"Failed to parse target_user_id=%s", target_user_id
@@ -65,42 +64,9 @@ class RoomMemberStore(SQLBaseStore):
}
)
- # Update room hosts table
- if event.membership == Membership.JOIN:
- sql = (
- "INSERT OR IGNORE INTO room_hosts (room_id, host) "
- "VALUES (?, ?)"
- )
- txn.execute(sql, (event.room_id, domain))
- elif event.membership != Membership.INVITE:
- # Check if this was the last person to have left.
- member_events = self._get_members_query_txn(
- txn,
- where_clause=("c.room_id = ? AND m.membership = ?"
- " AND m.user_id != ?"),
- where_values=(event.room_id, Membership.JOIN, target_user_id,)
- )
-
- joined_domains = set()
- for e in member_events:
- try:
- joined_domains.add(
- UserID.from_string(e.state_key).domain
- )
- except:
- # FIXME: How do we deal with invalid user ids in the db?
- logger.exception("Invalid user_id: %s", event.state_key)
-
- if domain not in joined_domains:
- sql = (
- "DELETE FROM room_hosts WHERE room_id = ? AND host = ?"
- )
+ txn.call_after(self.get_rooms_for_user.invalidate, target_user_id)
+ txn.call_after(self.get_joined_hosts_for_room.invalidate, event.room_id)
- txn.execute(sql, (event.room_id, domain))
-
- self.get_rooms_for_user.invalidate(target_user_id)
-
- @defer.inlineCallbacks
def get_room_member(self, user_id, room_id):
"""Retrieve the current state of a room member.
@@ -110,41 +76,27 @@ class RoomMemberStore(SQLBaseStore):
Returns:
Deferred: Results in a MembershipEvent or None.
"""
- rows = yield self._get_members_by_dict({
- "e.room_id": room_id,
- "m.user_id": user_id,
- })
+ def f(txn):
+ events = self._get_members_events_txn(
+ txn,
+ room_id,
+ user_id=user_id,
+ )
- defer.returnValue(rows[0] if rows else None)
+ return events[0] if events else None
- def _get_room_member(self, txn, user_id, room_id):
- sql = (
- "SELECT e.* FROM events as e"
- " INNER JOIN room_memberships as m"
- " ON e.event_id = m.event_id"
- " INNER JOIN current_state_events as c"
- " ON m.event_id = c.event_id"
- " WHERE m.user_id = ? and e.room_id = ?"
- " LIMIT 1"
- )
- txn.execute(sql, (user_id, room_id))
- rows = self.cursor_to_dict(txn)
- if rows:
- return self._parse_events_txn(txn, rows)[0]
- else:
- return None
+ return self.runInteraction("get_room_member", f)
def get_users_in_room(self, room_id):
def f(txn):
- sql = (
- "SELECT m.user_id FROM room_memberships as m"
- " INNER JOIN current_state_events as c"
- " ON m.event_id = c.event_id"
- " WHERE m.membership = ? AND m.room_id = ?"
+
+ rows = self._get_members_rows_txn(
+ txn,
+ room_id=room_id,
+ membership=Membership.JOIN,
)
- txn.execute(sql, (Membership.JOIN, room_id))
- return [r[0] for r in txn.fetchall()]
+ return [r["user_id"] for r in rows]
return self.runInteraction("get_users_in_room", f)
def get_room_members(self, room_id, membership=None):
@@ -159,11 +111,14 @@ class RoomMemberStore(SQLBaseStore):
list of namedtuples representing the members in this room.
"""
- where = {"m.room_id": room_id}
- if membership:
- where["m.membership"] = membership
+ def f(txn):
+ return self._get_members_events_txn(
+ txn,
+ room_id,
+ membership=membership,
+ )
- return self._get_members_by_dict(where)
+ return self.runInteraction("get_room_members", f)
def get_rooms_for_user_where_membership_is(self, user_id, membership_list):
""" Get all the rooms for this user where the membership for this user
@@ -199,7 +154,9 @@ class RoomMemberStore(SQLBaseStore):
"SELECT m.room_id, m.sender, m.membership"
" FROM room_memberships as m"
" INNER JOIN current_state_events as c"
- " ON m.event_id = c.event_id"
+ " ON m.event_id = c.event_id "
+ " AND m.room_id = c.room_id "
+ " AND m.user_id = c.state_key"
" WHERE %s"
) % (where_clause,)
@@ -208,32 +165,59 @@ class RoomMemberStore(SQLBaseStore):
RoomsForUser(**r) for r in self.cursor_to_dict(txn)
]
+ @cached()
def get_joined_hosts_for_room(self, room_id):
- return self._simple_select_onecol(
- "room_hosts",
- {"room_id": room_id},
- "host"
+ return self.runInteraction(
+ "get_joined_hosts_for_room",
+ self._get_joined_hosts_for_room_txn,
+ room_id,
+ )
+
+ def _get_joined_hosts_for_room_txn(self, txn, room_id):
+ rows = self._get_members_rows_txn(
+ txn,
+ room_id, membership=Membership.JOIN
+ )
+
+ joined_domains = set(
+ UserID.from_string(r["user_id"]).domain
+ for r in rows
)
- def _get_members_by_dict(self, where_dict):
- clause = " AND ".join("%s = ?" % k for k in where_dict.keys())
- vals = where_dict.values()
- return self._get_members_query(clause, vals)
+ return joined_domains
def _get_members_query(self, where_clause, where_values):
return self.runInteraction(
- "get_members_query", self._get_members_query_txn,
+ "get_members_query", self._get_members_events_txn,
where_clause, where_values
)
- def _get_members_query_txn(self, txn, where_clause, where_values):
+ def _get_members_events_txn(self, txn, room_id, membership=None, user_id=None):
+ rows = self._get_members_rows_txn(
+ txn,
+ room_id, membership, user_id,
+ )
+ return self._get_events_txn(txn, [r["event_id"] for r in rows])
+
+ def _get_members_rows_txn(self, txn, room_id, membership=None, user_id=None):
+ where_clause = "c.room_id = ?"
+ where_values = [room_id]
+
+ if membership:
+ where_clause += " AND m.membership = ?"
+ where_values.append(membership)
+
+ if user_id:
+ where_clause += " AND m.user_id = ?"
+ where_values.append(user_id)
+
sql = (
- "SELECT e.* FROM events as e "
- "INNER JOIN room_memberships as m "
- "ON e.event_id = m.event_id "
- "INNER JOIN current_state_events as c "
- "ON m.event_id = c.event_id "
- "WHERE %(where)s "
+ "SELECT m.* FROM room_memberships as m"
+ " INNER JOIN current_state_events as c"
+ " ON m.event_id = c.event_id "
+ " AND m.room_id = c.room_id "
+ " AND m.user_id = c.state_key"
+ " WHERE %(where)s"
) % {
"where": where_clause,
}
@@ -241,8 +225,7 @@ class RoomMemberStore(SQLBaseStore):
txn.execute(sql, where_values)
rows = self.cursor_to_dict(txn)
- results = self._parse_events_txn(txn, rows)
- return results
+ return rows
@cached()
def get_rooms_for_user(self, user_id):
diff --git a/synapse/storage/schema/delta/12/v12.sql b/synapse/storage/schema/delta/12/v12.sql
index b87ef1fe79..878c36260a 100644
--- a/synapse/storage/schema/delta/12/v12.sql
+++ b/synapse/storage/schema/delta/12/v12.sql
@@ -17,26 +17,25 @@ CREATE TABLE IF NOT EXISTS rejections(
event_id TEXT NOT NULL,
reason TEXT NOT NULL,
last_check TEXT NOT NULL,
- CONSTRAINT ev_id UNIQUE (event_id) ON CONFLICT REPLACE
+ UNIQUE (event_id)
);
-- Push notification endpoints that users have configured
CREATE TABLE IF NOT EXISTS pushers (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_name TEXT NOT NULL,
- profile_tag varchar(32) NOT NULL,
- kind varchar(8) NOT NULL,
- app_id varchar(64) NOT NULL,
- app_display_name varchar(64) NOT NULL,
- device_display_name varchar(128) NOT NULL,
- pushkey blob NOT NULL,
- ts BIGINT NOT NULL,
- lang varchar(8),
- data blob,
+ profile_tag VARCHAR(32) NOT NULL,
+ kind VARCHAR(8) NOT NULL,
+ app_id VARCHAR(64) NOT NULL,
+ app_display_name VARCHAR(64) NOT NULL,
+ device_display_name VARCHAR(128) NOT NULL,
+ pushkey VARBINARY(512) NOT NULL,
+ ts BIGINT UNSIGNED NOT NULL,
+ lang VARCHAR(8),
+ data LONGBLOB,
last_token TEXT,
- last_success BIGINT,
- failing_since BIGINT,
- FOREIGN KEY(user_name) REFERENCES users(name),
+ last_success BIGINT UNSIGNED,
+ failing_since BIGINT UNSIGNED,
UNIQUE (app_id, pushkey)
);
@@ -55,13 +54,10 @@ CREATE INDEX IF NOT EXISTS push_rules_user_name on push_rules (user_name);
CREATE TABLE IF NOT EXISTS user_filters(
user_id TEXT,
- filter_id INTEGER,
- filter_json TEXT,
- FOREIGN KEY(user_id) REFERENCES users(id)
+ filter_id BIGINT UNSIGNED,
+ filter_json LONGBLOB
);
CREATE INDEX IF NOT EXISTS user_filters_by_user_id_filter_id ON user_filters(
- user_id, filter_id
+ user_id, filter_id
);
-
-PRAGMA user_version = 12;
diff --git a/synapse/storage/schema/delta/13/v13.sql b/synapse/storage/schema/delta/13/v13.sql
index e491ad5aec..3265924013 100644
--- a/synapse/storage/schema/delta/13/v13.sql
+++ b/synapse/storage/schema/delta/13/v13.sql
@@ -19,16 +19,13 @@ CREATE TABLE IF NOT EXISTS application_services(
token TEXT,
hs_token TEXT,
sender TEXT,
- UNIQUE(token) ON CONFLICT ROLLBACK
+ UNIQUE(token)
);
CREATE TABLE IF NOT EXISTS application_services_regex(
id INTEGER PRIMARY KEY AUTOINCREMENT,
- as_id INTEGER NOT NULL,
+ as_id BIGINT UNSIGNED NOT NULL,
namespace INTEGER, /* enum[room_id|room_alias|user_id] */
regex TEXT,
FOREIGN KEY(as_id) REFERENCES application_services(id)
);
-
-
-
diff --git a/synapse/storage/schema/delta/14/upgrade_appservice_db.py b/synapse/storage/schema/delta/14/upgrade_appservice_db.py
index 847b1c5b89..9f3a4dd4c5 100644
--- a/synapse/storage/schema/delta/14/upgrade_appservice_db.py
+++ b/synapse/storage/schema/delta/14/upgrade_appservice_db.py
@@ -1,3 +1,17 @@
+# Copyright 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import json
import logging
diff --git a/synapse/storage/schema/delta/14/v14.sql b/synapse/storage/schema/delta/14/v14.sql
index 0212726448..1d09ad7a15 100644
--- a/synapse/storage/schema/delta/14/v14.sql
+++ b/synapse/storage/schema/delta/14/v14.sql
@@ -1,3 +1,17 @@
+/* Copyright 2015 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
CREATE TABLE IF NOT EXISTS push_rules_enable (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_name TEXT NOT NULL,
diff --git a/synapse/storage/schema/delta/15/appservice_txns.sql b/synapse/storage/schema/delta/15/appservice_txns.sql
new file mode 100644
index 0000000000..db2e720393
--- /dev/null
+++ b/synapse/storage/schema/delta/15/appservice_txns.sql
@@ -0,0 +1,31 @@
+/* Copyright 2015 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE IF NOT EXISTS application_services_state(
+ as_id TEXT PRIMARY KEY,
+ state VARCHAR(5),
+ last_txn INTEGER
+);
+
+CREATE TABLE IF NOT EXISTS application_services_txns(
+ as_id TEXT NOT NULL,
+ txn_id INTEGER NOT NULL,
+ event_ids TEXT NOT NULL,
+ UNIQUE(as_id, txn_id)
+);
+
+CREATE INDEX IF NOT EXISTS application_services_txns_id ON application_services_txns (
+ as_id
+);
diff --git a/synapse/storage/schema/delta/15/presence_indices.sql b/synapse/storage/schema/delta/15/presence_indices.sql
new file mode 100644
index 0000000000..6b8d0f1ca7
--- /dev/null
+++ b/synapse/storage/schema/delta/15/presence_indices.sql
@@ -0,0 +1,2 @@
+
+CREATE INDEX IF NOT EXISTS presence_list_user_id ON presence_list (user_id);
diff --git a/synapse/storage/schema/delta/15/v15.sql b/synapse/storage/schema/delta/15/v15.sql
new file mode 100644
index 0000000000..f5b2a08ca4
--- /dev/null
+++ b/synapse/storage/schema/delta/15/v15.sql
@@ -0,0 +1,25 @@
+-- Drop, copy & recreate pushers table to change unique key
+-- Also add access_token column at the same time
+CREATE TABLE IF NOT EXISTS pushers2 (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ user_name TEXT NOT NULL,
+ access_token INTEGER DEFAULT NULL,
+ profile_tag varchar(32) NOT NULL,
+ kind varchar(8) NOT NULL,
+ app_id varchar(64) NOT NULL,
+ app_display_name varchar(64) NOT NULL,
+ device_display_name varchar(128) NOT NULL,
+ pushkey blob NOT NULL,
+ ts BIGINT NOT NULL,
+ lang varchar(8),
+ data blob,
+ last_token TEXT,
+ last_success BIGINT,
+ failing_since BIGINT,
+ FOREIGN KEY(user_name) REFERENCES users(name),
+ UNIQUE (app_id, pushkey, user_name)
+);
+INSERT INTO pushers2 (id, user_name, profile_tag, kind, app_id, app_display_name, device_display_name, pushkey, ts, lang, data, last_token, last_success, failing_since)
+ SELECT id, user_name, profile_tag, kind, app_id, app_display_name, device_display_name, pushkey, ts, lang, data, last_token, last_success, failing_since FROM pushers;
+DROP TABLE pushers;
+ALTER TABLE pushers2 RENAME TO pushers;
diff --git a/synapse/storage/schema/delta/16/events_order_index.sql b/synapse/storage/schema/delta/16/events_order_index.sql
new file mode 100644
index 0000000000..a48f215170
--- /dev/null
+++ b/synapse/storage/schema/delta/16/events_order_index.sql
@@ -0,0 +1,4 @@
+CREATE INDEX events_order ON events (topological_ordering, stream_ordering);
+CREATE INDEX events_order_room ON events (
+ room_id, topological_ordering, stream_ordering
+);
diff --git a/synapse/storage/schema/delta/16/remote_media_cache_index.sql b/synapse/storage/schema/delta/16/remote_media_cache_index.sql
new file mode 100644
index 0000000000..7a15265cb1
--- /dev/null
+++ b/synapse/storage/schema/delta/16/remote_media_cache_index.sql
@@ -0,0 +1,2 @@
+CREATE INDEX IF NOT EXISTS remote_media_cache_thumbnails_media_id
+ ON remote_media_cache_thumbnails (media_id);
\ No newline at end of file
diff --git a/synapse/storage/schema/delta/16/remove_duplicates.sql b/synapse/storage/schema/delta/16/remove_duplicates.sql
new file mode 100644
index 0000000000..65c97b5e2f
--- /dev/null
+++ b/synapse/storage/schema/delta/16/remove_duplicates.sql
@@ -0,0 +1,9 @@
+
+
+DELETE FROM event_to_state_groups WHERE state_group not in (
+ SELECT MAX(state_group) FROM event_to_state_groups GROUP BY event_id
+);
+
+DELETE FROM event_to_state_groups WHERE rowid not in (
+ SELECT MIN(rowid) FROM event_to_state_groups GROUP BY event_id
+);
diff --git a/synapse/storage/schema/delta/16/room_alias_index.sql b/synapse/storage/schema/delta/16/room_alias_index.sql
new file mode 100644
index 0000000000..f82486132b
--- /dev/null
+++ b/synapse/storage/schema/delta/16/room_alias_index.sql
@@ -0,0 +1,3 @@
+
+CREATE INDEX IF NOT EXISTS room_aliases_id ON room_aliases(room_id);
+CREATE INDEX IF NOT EXISTS room_alias_servers_alias ON room_alias_servers(room_alias);
diff --git a/synapse/storage/schema/delta/16/unique_constraints.sql b/synapse/storage/schema/delta/16/unique_constraints.sql
new file mode 100644
index 0000000000..fecf11118c
--- /dev/null
+++ b/synapse/storage/schema/delta/16/unique_constraints.sql
@@ -0,0 +1,80 @@
+
+-- We can use SQLite features here, since other db support was only added in v16
+
+--
+DELETE FROM current_state_events WHERE rowid not in (
+ SELECT MIN(rowid) FROM current_state_events GROUP BY event_id
+);
+
+DROP INDEX IF EXISTS current_state_events_event_id;
+CREATE UNIQUE INDEX current_state_events_event_id ON current_state_events(event_id);
+
+--
+DELETE FROM room_memberships WHERE rowid not in (
+ SELECT MIN(rowid) FROM room_memberships GROUP BY event_id
+);
+
+DROP INDEX IF EXISTS room_memberships_event_id;
+CREATE UNIQUE INDEX room_memberships_event_id ON room_memberships(event_id);
+
+--
+DELETE FROM feedback WHERE rowid not in (
+ SELECT MIN(rowid) FROM feedback GROUP BY event_id
+);
+
+DROP INDEX IF EXISTS feedback_event_id;
+CREATE UNIQUE INDEX feedback_event_id ON feedback(event_id);
+
+--
+DELETE FROM topics WHERE rowid not in (
+ SELECT MIN(rowid) FROM topics GROUP BY event_id
+);
+
+DROP INDEX IF EXISTS topics_event_id;
+CREATE UNIQUE INDEX topics_event_id ON topics(event_id);
+
+--
+DELETE FROM room_names WHERE rowid not in (
+ SELECT MIN(rowid) FROM room_names GROUP BY event_id
+);
+
+DROP INDEX IF EXISTS room_names_id;
+CREATE UNIQUE INDEX room_names_id ON room_names(event_id);
+
+--
+DELETE FROM presence WHERE rowid not in (
+ SELECT MIN(rowid) FROM presence GROUP BY user_id
+);
+
+DROP INDEX IF EXISTS presence_id;
+CREATE UNIQUE INDEX presence_id ON presence(user_id);
+
+--
+DELETE FROM presence_allow_inbound WHERE rowid not in (
+ SELECT MIN(rowid) FROM presence_allow_inbound
+ GROUP BY observed_user_id, observer_user_id
+);
+
+DROP INDEX IF EXISTS presence_allow_inbound_observers;
+CREATE UNIQUE INDEX presence_allow_inbound_observers ON presence_allow_inbound(
+ observed_user_id, observer_user_id
+);
+
+--
+DELETE FROM presence_list WHERE rowid not in (
+ SELECT MIN(rowid) FROM presence_list
+ GROUP BY user_id, observed_user_id
+);
+
+DROP INDEX IF EXISTS presence_list_observers;
+CREATE UNIQUE INDEX presence_list_observers ON presence_list(
+ user_id, observed_user_id
+);
+
+--
+DELETE FROM room_aliases WHERE rowid not in (
+ SELECT MIN(rowid) FROM room_aliases GROUP BY room_alias
+);
+
+DROP INDEX IF EXISTS room_aliases_id;
+CREATE INDEX room_aliases_id ON room_aliases(room_id);
diff --git a/synapse/storage/schema/delta/16/users.sql b/synapse/storage/schema/delta/16/users.sql
new file mode 100644
index 0000000000..cd0709250d
--- /dev/null
+++ b/synapse/storage/schema/delta/16/users.sql
@@ -0,0 +1,56 @@
+-- Convert `access_tokens`.user from rowids to user strings.
+-- MUST BE DONE BEFORE REMOVING ID COLUMN FROM USERS TABLE BELOW
+CREATE TABLE IF NOT EXISTS new_access_tokens(
+ id BIGINT UNSIGNED PRIMARY KEY,
+ user_id TEXT NOT NULL,
+ device_id TEXT,
+ token TEXT NOT NULL,
+ last_used BIGINT UNSIGNED,
+ UNIQUE(token)
+);
+
+INSERT INTO new_access_tokens
+ SELECT a.id, u.name, a.device_id, a.token, a.last_used
+ FROM access_tokens as a
+ INNER JOIN users as u ON u.id = a.user_id;
+
+DROP TABLE access_tokens;
+
+ALTER TABLE new_access_tokens RENAME TO access_tokens;
+
+-- Remove ID column from `users` table
+CREATE TABLE IF NOT EXISTS new_users(
+ name TEXT,
+ password_hash TEXT,
+ creation_ts BIGINT UNSIGNED,
+ admin BOOL DEFAULT 0 NOT NULL,
+ UNIQUE(name)
+);
+
+INSERT INTO new_users SELECT name, password_hash, creation_ts, admin FROM users;
+
+DROP TABLE users;
+
+ALTER TABLE new_users RENAME TO users;
+
+
+-- Remove UNIQUE constraint from `user_ips` table
+CREATE TABLE IF NOT EXISTS new_user_ips (
+ user_id TEXT NOT NULL,
+ access_token TEXT NOT NULL,
+ device_id TEXT,
+ ip TEXT NOT NULL,
+ user_agent TEXT NOT NULL,
+ last_seen BIGINT UNSIGNED NOT NULL
+);
+
+INSERT INTO new_user_ips
+ SELECT user, access_token, device_id, ip, user_agent, last_seen FROM user_ips;
+
+DROP TABLE user_ips;
+
+ALTER TABLE new_user_ips RENAME TO user_ips;
+
+CREATE INDEX IF NOT EXISTS user_ips_user ON user_ips(user_id);
+CREATE INDEX IF NOT EXISTS user_ips_user_ip ON user_ips(user_id, access_token, ip);
+
diff --git a/synapse/storage/schema/delta/17/drop_indexes.sql b/synapse/storage/schema/delta/17/drop_indexes.sql
new file mode 100644
index 0000000000..8eb3325a6b
--- /dev/null
+++ b/synapse/storage/schema/delta/17/drop_indexes.sql
@@ -0,0 +1,18 @@
+/* Copyright 2015 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+DROP INDEX IF EXISTS sent_transaction_dest;
+DROP INDEX IF EXISTS sent_transaction_sent;
+DROP INDEX IF EXISTS user_ips_user;
diff --git a/synapse/storage/schema/delta/17/server_keys.sql b/synapse/storage/schema/delta/17/server_keys.sql
new file mode 100644
index 0000000000..513c30a717
--- /dev/null
+++ b/synapse/storage/schema/delta/17/server_keys.sql
@@ -0,0 +1,24 @@
+/* Copyright 2015 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE IF NOT EXISTS server_keys_json (
+ server_name TEXT, -- Server name.
+ key_id TEXT, -- Requested key id.
+ from_server TEXT, -- Which server the keys were fetched from.
+ ts_added_ms INTEGER, -- When the keys were fetched
+ ts_valid_until_ms INTEGER, -- When this version of the keys exipires.
+ key_json bytea, -- JSON certificate for the remote server.
+ CONSTRAINT uniqueness UNIQUE (server_name, key_id, from_server)
+);
diff --git a/synapse/storage/schema/delta/17/user_threepids.sql b/synapse/storage/schema/delta/17/user_threepids.sql
new file mode 100644
index 0000000000..c17715ac80
--- /dev/null
+++ b/synapse/storage/schema/delta/17/user_threepids.sql
@@ -0,0 +1,9 @@
+CREATE TABLE user_threepids (
+ user_id TEXT NOT NULL,
+ medium TEXT NOT NULL,
+ address TEXT NOT NULL,
+ validated_at BIGINT NOT NULL,
+ added_at BIGINT NOT NULL,
+ CONSTRAINT user_medium_address UNIQUE (user_id, medium, address)
+);
+CREATE INDEX user_threepids_user_id ON user_threepids(user_id);
diff --git a/synapse/storage/schema/full_schemas/11/event_edges.sql b/synapse/storage/schema/full_schemas/11/event_edges.sql
index 1e766d6db2..f7020f7793 100644
--- a/synapse/storage/schema/full_schemas/11/event_edges.sql
+++ b/synapse/storage/schema/full_schemas/11/event_edges.sql
@@ -16,52 +16,52 @@
CREATE TABLE IF NOT EXISTS event_forward_extremities(
event_id TEXT NOT NULL,
room_id TEXT NOT NULL,
- CONSTRAINT uniqueness UNIQUE (event_id, room_id) ON CONFLICT REPLACE
+ UNIQUE (event_id, room_id)
);
-CREATE INDEX IF NOT EXISTS ev_extrem_room ON event_forward_extremities(room_id);
-CREATE INDEX IF NOT EXISTS ev_extrem_id ON event_forward_extremities(event_id);
+CREATE INDEX ev_extrem_room ON event_forward_extremities(room_id);
+CREATE INDEX ev_extrem_id ON event_forward_extremities(event_id);
CREATE TABLE IF NOT EXISTS event_backward_extremities(
event_id TEXT NOT NULL,
room_id TEXT NOT NULL,
- CONSTRAINT uniqueness UNIQUE (event_id, room_id) ON CONFLICT REPLACE
+ UNIQUE (event_id, room_id)
);
-CREATE INDEX IF NOT EXISTS ev_b_extrem_room ON event_backward_extremities(room_id);
-CREATE INDEX IF NOT EXISTS ev_b_extrem_id ON event_backward_extremities(event_id);
+CREATE INDEX ev_b_extrem_room ON event_backward_extremities(room_id);
+CREATE INDEX ev_b_extrem_id ON event_backward_extremities(event_id);
CREATE TABLE IF NOT EXISTS event_edges(
event_id TEXT NOT NULL,
prev_event_id TEXT NOT NULL,
room_id TEXT NOT NULL,
- is_state INTEGER NOT NULL,
- CONSTRAINT uniqueness UNIQUE (event_id, prev_event_id, room_id, is_state)
+ is_state BOOL NOT NULL,
+ UNIQUE (event_id, prev_event_id, room_id, is_state)
);
-CREATE INDEX IF NOT EXISTS ev_edges_id ON event_edges(event_id);
-CREATE INDEX IF NOT EXISTS ev_edges_prev_id ON event_edges(prev_event_id);
+CREATE INDEX ev_edges_id ON event_edges(event_id);
+CREATE INDEX ev_edges_prev_id ON event_edges(prev_event_id);
CREATE TABLE IF NOT EXISTS room_depth(
room_id TEXT NOT NULL,
min_depth INTEGER NOT NULL,
- CONSTRAINT uniqueness UNIQUE (room_id)
+ UNIQUE (room_id)
);
-CREATE INDEX IF NOT EXISTS room_depth_room ON room_depth(room_id);
+CREATE INDEX room_depth_room ON room_depth(room_id);
create TABLE IF NOT EXISTS event_destinations(
event_id TEXT NOT NULL,
destination TEXT NOT NULL,
- delivered_ts INTEGER DEFAULT 0, -- or 0 if not delivered
- CONSTRAINT uniqueness UNIQUE (event_id, destination) ON CONFLICT REPLACE
+ delivered_ts BIGINT DEFAULT 0, -- or 0 if not delivered
+ UNIQUE (event_id, destination)
);
-CREATE INDEX IF NOT EXISTS event_destinations_id ON event_destinations(event_id);
+CREATE INDEX event_destinations_id ON event_destinations(event_id);
CREATE TABLE IF NOT EXISTS state_forward_extremities(
@@ -69,21 +69,21 @@ CREATE TABLE IF NOT EXISTS state_forward_extremities(
room_id TEXT NOT NULL,
type TEXT NOT NULL,
state_key TEXT NOT NULL,
- CONSTRAINT uniqueness UNIQUE (event_id, room_id) ON CONFLICT REPLACE
+ UNIQUE (event_id, room_id)
);
-CREATE INDEX IF NOT EXISTS st_extrem_keys ON state_forward_extremities(
+CREATE INDEX st_extrem_keys ON state_forward_extremities(
room_id, type, state_key
);
-CREATE INDEX IF NOT EXISTS st_extrem_id ON state_forward_extremities(event_id);
+CREATE INDEX st_extrem_id ON state_forward_extremities(event_id);
CREATE TABLE IF NOT EXISTS event_auth(
event_id TEXT NOT NULL,
auth_id TEXT NOT NULL,
room_id TEXT NOT NULL,
- CONSTRAINT uniqueness UNIQUE (event_id, auth_id, room_id)
+ UNIQUE (event_id, auth_id, room_id)
);
-CREATE INDEX IF NOT EXISTS evauth_edges_id ON event_auth(event_id);
-CREATE INDEX IF NOT EXISTS evauth_edges_auth_id ON event_auth(auth_id);
\ No newline at end of file
+CREATE INDEX evauth_edges_id ON event_auth(event_id);
+CREATE INDEX evauth_edges_auth_id ON event_auth(auth_id);
diff --git a/synapse/storage/schema/full_schemas/11/event_signatures.sql b/synapse/storage/schema/full_schemas/11/event_signatures.sql
index c28c39c48a..636b2d3353 100644
--- a/synapse/storage/schema/full_schemas/11/event_signatures.sql
+++ b/synapse/storage/schema/full_schemas/11/event_signatures.sql
@@ -16,50 +16,40 @@
CREATE TABLE IF NOT EXISTS event_content_hashes (
event_id TEXT,
algorithm TEXT,
- hash BLOB,
- CONSTRAINT uniqueness UNIQUE (event_id, algorithm)
+ hash bytea,
+ UNIQUE (event_id, algorithm)
);
-CREATE INDEX IF NOT EXISTS event_content_hashes_id ON event_content_hashes(
- event_id
-);
+CREATE INDEX event_content_hashes_id ON event_content_hashes(event_id);
CREATE TABLE IF NOT EXISTS event_reference_hashes (
event_id TEXT,
algorithm TEXT,
- hash BLOB,
- CONSTRAINT uniqueness UNIQUE (event_id, algorithm)
+ hash bytea,
+ UNIQUE (event_id, algorithm)
);
-CREATE INDEX IF NOT EXISTS event_reference_hashes_id ON event_reference_hashes (
- event_id
-);
+CREATE INDEX event_reference_hashes_id ON event_reference_hashes(event_id);
CREATE TABLE IF NOT EXISTS event_signatures (
event_id TEXT,
signature_name TEXT,
key_id TEXT,
- signature BLOB,
- CONSTRAINT uniqueness UNIQUE (event_id, signature_name, key_id)
+ signature bytea,
+ UNIQUE (event_id, signature_name, key_id)
);
-CREATE INDEX IF NOT EXISTS event_signatures_id ON event_signatures (
- event_id
-);
+CREATE INDEX event_signatures_id ON event_signatures(event_id);
CREATE TABLE IF NOT EXISTS event_edge_hashes(
event_id TEXT,
prev_event_id TEXT,
algorithm TEXT,
- hash BLOB,
- CONSTRAINT uniqueness UNIQUE (
- event_id, prev_event_id, algorithm
- )
+ hash bytea,
+ UNIQUE (event_id, prev_event_id, algorithm)
);
-CREATE INDEX IF NOT EXISTS event_edge_hashes_id ON event_edge_hashes(
- event_id
-);
+CREATE INDEX event_edge_hashes_id ON event_edge_hashes(event_id);
diff --git a/synapse/storage/schema/full_schemas/11/im.sql b/synapse/storage/schema/full_schemas/11/im.sql
index dd00c1cd2f..1901654ac2 100644
--- a/synapse/storage/schema/full_schemas/11/im.sql
+++ b/synapse/storage/schema/full_schemas/11/im.sql
@@ -15,7 +15,7 @@
CREATE TABLE IF NOT EXISTS events(
stream_ordering INTEGER PRIMARY KEY AUTOINCREMENT,
- topological_ordering INTEGER NOT NULL,
+ topological_ordering BIGINT NOT NULL,
event_id TEXT NOT NULL,
type TEXT NOT NULL,
room_id TEXT NOT NULL,
@@ -23,26 +23,24 @@ CREATE TABLE IF NOT EXISTS events(
unrecognized_keys TEXT,
processed BOOL NOT NULL,
outlier BOOL NOT NULL,
- depth INTEGER DEFAULT 0 NOT NULL,
- CONSTRAINT ev_uniq UNIQUE (event_id)
+ depth BIGINT DEFAULT 0 NOT NULL,
+ UNIQUE (event_id)
);
-CREATE INDEX IF NOT EXISTS events_event_id ON events (event_id);
-CREATE INDEX IF NOT EXISTS events_stream_ordering ON events (stream_ordering);
-CREATE INDEX IF NOT EXISTS events_topological_ordering ON events (topological_ordering);
-CREATE INDEX IF NOT EXISTS events_room_id ON events (room_id);
+CREATE INDEX events_stream_ordering ON events (stream_ordering);
+CREATE INDEX events_topological_ordering ON events (topological_ordering);
+CREATE INDEX events_room_id ON events (room_id);
CREATE TABLE IF NOT EXISTS event_json(
event_id TEXT NOT NULL,
room_id TEXT NOT NULL,
- internal_metadata NOT NULL,
- json BLOB NOT NULL,
- CONSTRAINT ev_j_uniq UNIQUE (event_id)
+ internal_metadata TEXT NOT NULL,
+ json TEXT NOT NULL,
+ UNIQUE (event_id)
);
-CREATE INDEX IF NOT EXISTS event_json_id ON event_json(event_id);
-CREATE INDEX IF NOT EXISTS event_json_room_id ON event_json(room_id);
+CREATE INDEX event_json_room_id ON event_json(room_id);
CREATE TABLE IF NOT EXISTS state_events(
@@ -50,13 +48,13 @@ CREATE TABLE IF NOT EXISTS state_events(
room_id TEXT NOT NULL,
type TEXT NOT NULL,
state_key TEXT NOT NULL,
- prev_state TEXT
+ prev_state TEXT,
+ UNIQUE (event_id)
);
-CREATE UNIQUE INDEX IF NOT EXISTS state_events_event_id ON state_events (event_id);
-CREATE INDEX IF NOT EXISTS state_events_room_id ON state_events (room_id);
-CREATE INDEX IF NOT EXISTS state_events_type ON state_events (type);
-CREATE INDEX IF NOT EXISTS state_events_state_key ON state_events (state_key);
+CREATE INDEX state_events_room_id ON state_events (room_id);
+CREATE INDEX state_events_type ON state_events (type);
+CREATE INDEX state_events_state_key ON state_events (state_key);
CREATE TABLE IF NOT EXISTS current_state_events(
@@ -64,13 +62,13 @@ CREATE TABLE IF NOT EXISTS current_state_events(
room_id TEXT NOT NULL,
type TEXT NOT NULL,
state_key TEXT NOT NULL,
- CONSTRAINT curr_uniq UNIQUE (room_id, type, state_key) ON CONFLICT REPLACE
+ UNIQUE (room_id, type, state_key)
);
-CREATE INDEX IF NOT EXISTS curr_events_event_id ON current_state_events (event_id);
-CREATE INDEX IF NOT EXISTS current_state_events_room_id ON current_state_events (room_id);
-CREATE INDEX IF NOT EXISTS current_state_events_type ON current_state_events (type);
-CREATE INDEX IF NOT EXISTS current_state_events_state_key ON current_state_events (state_key);
+CREATE INDEX curr_events_event_id ON current_state_events (event_id);
+CREATE INDEX current_state_events_room_id ON current_state_events (room_id);
+CREATE INDEX current_state_events_type ON current_state_events (type);
+CREATE INDEX current_state_events_state_key ON current_state_events (state_key);
CREATE TABLE IF NOT EXISTS room_memberships(
event_id TEXT NOT NULL,
@@ -80,9 +78,9 @@ CREATE TABLE IF NOT EXISTS room_memberships(
membership TEXT NOT NULL
);
-CREATE INDEX IF NOT EXISTS room_memberships_event_id ON room_memberships (event_id);
-CREATE INDEX IF NOT EXISTS room_memberships_room_id ON room_memberships (room_id);
-CREATE INDEX IF NOT EXISTS room_memberships_user_id ON room_memberships (user_id);
+CREATE INDEX room_memberships_event_id ON room_memberships (event_id);
+CREATE INDEX room_memberships_room_id ON room_memberships (room_id);
+CREATE INDEX room_memberships_user_id ON room_memberships (user_id);
CREATE TABLE IF NOT EXISTS feedback(
event_id TEXT NOT NULL,
@@ -98,8 +96,8 @@ CREATE TABLE IF NOT EXISTS topics(
topic TEXT NOT NULL
);
-CREATE INDEX IF NOT EXISTS topics_event_id ON topics(event_id);
-CREATE INDEX IF NOT EXISTS topics_room_id ON topics(room_id);
+CREATE INDEX topics_event_id ON topics(event_id);
+CREATE INDEX topics_room_id ON topics(room_id);
CREATE TABLE IF NOT EXISTS room_names(
event_id TEXT NOT NULL,
@@ -107,19 +105,19 @@ CREATE TABLE IF NOT EXISTS room_names(
name TEXT NOT NULL
);
-CREATE INDEX IF NOT EXISTS room_names_event_id ON room_names(event_id);
-CREATE INDEX IF NOT EXISTS room_names_room_id ON room_names(room_id);
+CREATE INDEX room_names_event_id ON room_names(event_id);
+CREATE INDEX room_names_room_id ON room_names(room_id);
CREATE TABLE IF NOT EXISTS rooms(
room_id TEXT PRIMARY KEY NOT NULL,
- is_public INTEGER,
+ is_public BOOL,
creator TEXT
);
CREATE TABLE IF NOT EXISTS room_hosts(
room_id TEXT NOT NULL,
host TEXT NOT NULL,
- CONSTRAINT room_hosts_uniq UNIQUE (room_id, host) ON CONFLICT IGNORE
+ UNIQUE (room_id, host)
);
-CREATE INDEX IF NOT EXISTS room_hosts_room_id ON room_hosts (room_id);
+CREATE INDEX room_hosts_room_id ON room_hosts (room_id);
diff --git a/synapse/storage/schema/full_schemas/11/keys.sql b/synapse/storage/schema/full_schemas/11/keys.sql
index a9e0a4fe0d..afc142045e 100644
--- a/synapse/storage/schema/full_schemas/11/keys.sql
+++ b/synapse/storage/schema/full_schemas/11/keys.sql
@@ -16,16 +16,16 @@ CREATE TABLE IF NOT EXISTS server_tls_certificates(
server_name TEXT, -- Server name.
fingerprint TEXT, -- Certificate fingerprint.
from_server TEXT, -- Which key server the certificate was fetched from.
- ts_added_ms INTEGER, -- When the certifcate was added.
- tls_certificate BLOB, -- DER encoded x509 certificate.
- CONSTRAINT uniqueness UNIQUE (server_name, fingerprint)
+ ts_added_ms BIGINT, -- When the certifcate was added.
+ tls_certificate bytea, -- DER encoded x509 certificate.
+ UNIQUE (server_name, fingerprint)
);
CREATE TABLE IF NOT EXISTS server_signature_keys(
server_name TEXT, -- Server name.
key_id TEXT, -- Key version.
from_server TEXT, -- Which key server the key was fetched form.
- ts_added_ms INTEGER, -- When the key was added.
- verify_key BLOB, -- NACL verification key.
- CONSTRAINT uniqueness UNIQUE (server_name, key_id)
+ ts_added_ms BIGINT, -- When the key was added.
+ verify_key bytea, -- NACL verification key.
+ UNIQUE (server_name, key_id)
);
diff --git a/synapse/storage/schema/full_schemas/11/media_repository.sql b/synapse/storage/schema/full_schemas/11/media_repository.sql
index afdf48cbfb..e927e581d1 100644
--- a/synapse/storage/schema/full_schemas/11/media_repository.sql
+++ b/synapse/storage/schema/full_schemas/11/media_repository.sql
@@ -17,10 +17,10 @@ CREATE TABLE IF NOT EXISTS local_media_repository (
media_id TEXT, -- The id used to refer to the media.
media_type TEXT, -- The MIME-type of the media.
media_length INTEGER, -- Length of the media in bytes.
- created_ts INTEGER, -- When the content was uploaded in ms.
+ created_ts BIGINT, -- When the content was uploaded in ms.
upload_name TEXT, -- The name the media was uploaded with.
user_id TEXT, -- The user who uploaded the file.
- CONSTRAINT uniqueness UNIQUE (media_id)
+ UNIQUE (media_id)
);
CREATE TABLE IF NOT EXISTS local_media_repository_thumbnails (
@@ -30,23 +30,23 @@ CREATE TABLE IF NOT EXISTS local_media_repository_thumbnails (
thumbnail_type TEXT, -- The MIME-type of the thumbnail.
thumbnail_method TEXT, -- The method used to make the thumbnail.
thumbnail_length INTEGER, -- The length of the thumbnail in bytes.
- CONSTRAINT uniqueness UNIQUE (
+ UNIQUE (
media_id, thumbnail_width, thumbnail_height, thumbnail_type
)
);
-CREATE INDEX IF NOT EXISTS local_media_repository_thumbnails_media_id
+CREATE INDEX local_media_repository_thumbnails_media_id
ON local_media_repository_thumbnails (media_id);
CREATE TABLE IF NOT EXISTS remote_media_cache (
media_origin TEXT, -- The remote HS the media came from.
media_id TEXT, -- The id used to refer to the media on that server.
media_type TEXT, -- The MIME-type of the media.
- created_ts INTEGER, -- When the content was uploaded in ms.
+ created_ts BIGINT, -- When the content was uploaded in ms.
upload_name TEXT, -- The name the media was uploaded with.
media_length INTEGER, -- Length of the media in bytes.
filesystem_id TEXT, -- The name used to store the media on disk.
- CONSTRAINT uniqueness UNIQUE (media_origin, media_id)
+ UNIQUE (media_origin, media_id)
);
CREATE TABLE IF NOT EXISTS remote_media_cache_thumbnails (
@@ -58,11 +58,8 @@ CREATE TABLE IF NOT EXISTS remote_media_cache_thumbnails (
thumbnail_type TEXT, -- The MIME-type of the thumbnail.
thumbnail_length INTEGER, -- The length of the thumbnail in bytes.
filesystem_id TEXT, -- The name used to store the media on disk.
- CONSTRAINT uniqueness UNIQUE (
+ UNIQUE (
media_origin, media_id, thumbnail_width, thumbnail_height,
- thumbnail_type, thumbnail_type
- )
+ thumbnail_type
+ )
);
-
-CREATE INDEX IF NOT EXISTS remote_media_cache_thumbnails_media_id
- ON local_media_repository_thumbnails (media_id);
diff --git a/synapse/storage/schema/full_schemas/11/presence.sql b/synapse/storage/schema/full_schemas/11/presence.sql
index f9f8db9697..d8d82e9fe3 100644
--- a/synapse/storage/schema/full_schemas/11/presence.sql
+++ b/synapse/storage/schema/full_schemas/11/presence.sql
@@ -13,26 +13,23 @@
* limitations under the License.
*/
CREATE TABLE IF NOT EXISTS presence(
- user_id INTEGER NOT NULL,
- state INTEGER,
+ user_id TEXT NOT NULL,
+ state VARCHAR(20),
status_msg TEXT,
- mtime INTEGER, -- miliseconds since last state change
- FOREIGN KEY(user_id) REFERENCES users(id)
+ mtime BIGINT -- miliseconds since last state change
);
-- For each of /my/ users which possibly-remote users are allowed to see their
-- presence state
CREATE TABLE IF NOT EXISTS presence_allow_inbound(
- observed_user_id INTEGER NOT NULL,
- observer_user_id TEXT, -- a UserID,
- FOREIGN KEY(observed_user_id) REFERENCES users(id)
+ observed_user_id TEXT NOT NULL,
+ observer_user_id TEXT NOT NULL -- a UserID,
);
-- For each of /my/ users (watcher), which possibly-remote users are they
-- watching?
CREATE TABLE IF NOT EXISTS presence_list(
- user_id INTEGER NOT NULL,
- observed_user_id TEXT, -- a UserID,
- accepted BOOLEAN,
- FOREIGN KEY(user_id) REFERENCES users(id)
+ user_id TEXT NOT NULL,
+ observed_user_id TEXT NOT NULL, -- a UserID,
+ accepted BOOLEAN NOT NULL
);
diff --git a/synapse/storage/schema/full_schemas/11/profiles.sql b/synapse/storage/schema/full_schemas/11/profiles.sql
index f06a528b4d..26e4204437 100644
--- a/synapse/storage/schema/full_schemas/11/profiles.sql
+++ b/synapse/storage/schema/full_schemas/11/profiles.sql
@@ -13,8 +13,7 @@
* limitations under the License.
*/
CREATE TABLE IF NOT EXISTS profiles(
- user_id INTEGER NOT NULL,
+ user_id TEXT NOT NULL,
displayname TEXT,
- avatar_url TEXT,
- FOREIGN KEY(user_id) REFERENCES users(id)
+ avatar_url TEXT
);
diff --git a/synapse/storage/schema/full_schemas/11/redactions.sql b/synapse/storage/schema/full_schemas/11/redactions.sql
index 5011d95db8..69621955d4 100644
--- a/synapse/storage/schema/full_schemas/11/redactions.sql
+++ b/synapse/storage/schema/full_schemas/11/redactions.sql
@@ -15,8 +15,8 @@
CREATE TABLE IF NOT EXISTS redactions (
event_id TEXT NOT NULL,
redacts TEXT NOT NULL,
- CONSTRAINT ev_uniq UNIQUE (event_id)
+ UNIQUE (event_id)
);
-CREATE INDEX IF NOT EXISTS redactions_event_id ON redactions (event_id);
-CREATE INDEX IF NOT EXISTS redactions_redacts ON redactions (redacts);
+CREATE INDEX redactions_event_id ON redactions (event_id);
+CREATE INDEX redactions_redacts ON redactions (redacts);
diff --git a/synapse/storage/schema/full_schemas/11/room_aliases.sql b/synapse/storage/schema/full_schemas/11/room_aliases.sql
index 0d2df01603..5027b1e3f6 100644
--- a/synapse/storage/schema/full_schemas/11/room_aliases.sql
+++ b/synapse/storage/schema/full_schemas/11/room_aliases.sql
@@ -22,6 +22,3 @@ CREATE TABLE IF NOT EXISTS room_alias_servers(
room_alias TEXT NOT NULL,
server TEXT NOT NULL
);
-
-
-
diff --git a/synapse/storage/schema/full_schemas/11/state.sql b/synapse/storage/schema/full_schemas/11/state.sql
index 1fe8f1e430..ffd164ab71 100644
--- a/synapse/storage/schema/full_schemas/11/state.sql
+++ b/synapse/storage/schema/full_schemas/11/state.sql
@@ -30,18 +30,11 @@ CREATE TABLE IF NOT EXISTS state_groups_state(
CREATE TABLE IF NOT EXISTS event_to_state_groups(
event_id TEXT NOT NULL,
state_group INTEGER NOT NULL,
- CONSTRAINT event_to_state_groups_uniq UNIQUE (event_id)
+ UNIQUE (event_id)
);
-CREATE INDEX IF NOT EXISTS state_groups_id ON state_groups(id);
+CREATE INDEX state_groups_id ON state_groups(id);
-CREATE INDEX IF NOT EXISTS state_groups_state_id ON state_groups_state(
- state_group
-);
-CREATE INDEX IF NOT EXISTS state_groups_state_tuple ON state_groups_state(
- room_id, type, state_key
-);
-
-CREATE INDEX IF NOT EXISTS event_to_state_groups_id ON event_to_state_groups(
- event_id
-);
\ No newline at end of file
+CREATE INDEX state_groups_state_id ON state_groups_state(state_group);
+CREATE INDEX state_groups_state_tuple ON state_groups_state(room_id, type, state_key);
+CREATE INDEX event_to_state_groups_id ON event_to_state_groups(event_id);
diff --git a/synapse/storage/schema/full_schemas/11/transactions.sql b/synapse/storage/schema/full_schemas/11/transactions.sql
index 2d30f99b06..cc5b54f5aa 100644
--- a/synapse/storage/schema/full_schemas/11/transactions.sql
+++ b/synapse/storage/schema/full_schemas/11/transactions.sql
@@ -14,17 +14,16 @@
*/
-- Stores what transaction ids we have received and what our response was
CREATE TABLE IF NOT EXISTS received_transactions(
- transaction_id TEXT,
- origin TEXT,
- ts INTEGER,
+ transaction_id TEXT,
+ origin TEXT,
+ ts BIGINT,
response_code INTEGER,
- response_json TEXT,
- has_been_referenced BOOL default 0, -- Whether thishas been referenced by a prev_tx
- CONSTRAINT uniquesss UNIQUE (transaction_id, origin) ON CONFLICT REPLACE
+ response_json bytea,
+ has_been_referenced SMALLINT DEFAULT 0, -- Whether thishas been referenced by a prev_tx
+ UNIQUE (transaction_id, origin)
);
-CREATE UNIQUE INDEX IF NOT EXISTS transactions_txid ON received_transactions(transaction_id, origin);
-CREATE INDEX IF NOT EXISTS transactions_have_ref ON received_transactions(origin, has_been_referenced);-- WHERE has_been_referenced = 0;
+CREATE INDEX transactions_have_ref ON received_transactions(origin, has_been_referenced);-- WHERE has_been_referenced = 0;
-- Stores what transactions we've sent, what their response was (if we got one) and whether we have
@@ -35,17 +34,14 @@ CREATE TABLE IF NOT EXISTS sent_transactions(
destination TEXT,
response_code INTEGER DEFAULT 0,
response_json TEXT,
- ts INTEGER
+ ts BIGINT
);
-CREATE INDEX IF NOT EXISTS sent_transaction_dest ON sent_transactions(destination);
-CREATE INDEX IF NOT EXISTS sent_transaction_dest_referenced ON sent_transactions(
- destination
-);
-CREATE INDEX IF NOT EXISTS sent_transaction_txn_id ON sent_transactions(transaction_id);
+CREATE INDEX sent_transaction_dest ON sent_transactions(destination);
+CREATE INDEX sent_transaction_txn_id ON sent_transactions(transaction_id);
-- So that we can do an efficient look up of all transactions that have yet to be successfully
-- sent.
-CREATE INDEX IF NOT EXISTS sent_transaction_sent ON sent_transactions(response_code);
+CREATE INDEX sent_transaction_sent ON sent_transactions(response_code);
-- For sent transactions only.
@@ -56,13 +52,12 @@ CREATE TABLE IF NOT EXISTS transaction_id_to_pdu(
pdu_origin TEXT
);
-CREATE INDEX IF NOT EXISTS transaction_id_to_pdu_tx ON transaction_id_to_pdu(transaction_id, destination);
-CREATE INDEX IF NOT EXISTS transaction_id_to_pdu_dest ON transaction_id_to_pdu(destination);
-CREATE INDEX IF NOT EXISTS transaction_id_to_pdu_index ON transaction_id_to_pdu(transaction_id, destination);
+CREATE INDEX transaction_id_to_pdu_tx ON transaction_id_to_pdu(transaction_id, destination);
+CREATE INDEX transaction_id_to_pdu_dest ON transaction_id_to_pdu(destination);
-- To track destination health
CREATE TABLE IF NOT EXISTS destinations(
destination TEXT PRIMARY KEY,
- retry_last_ts INTEGER,
+ retry_last_ts BIGINT,
retry_interval INTEGER
);
diff --git a/synapse/storage/schema/full_schemas/11/users.sql b/synapse/storage/schema/full_schemas/11/users.sql
index 08ccfdac0a..eec3da3c35 100644
--- a/synapse/storage/schema/full_schemas/11/users.sql
+++ b/synapse/storage/schema/full_schemas/11/users.sql
@@ -16,19 +16,18 @@ CREATE TABLE IF NOT EXISTS users(
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT,
password_hash TEXT,
- creation_ts INTEGER,
- admin BOOL DEFAULT 0 NOT NULL,
- UNIQUE(name) ON CONFLICT ROLLBACK
+ creation_ts BIGINT,
+ admin SMALLINT DEFAULT 0 NOT NULL,
+ UNIQUE(name)
);
CREATE TABLE IF NOT EXISTS access_tokens(
id INTEGER PRIMARY KEY AUTOINCREMENT,
- user_id INTEGER NOT NULL,
+ user_id TEXT NOT NULL,
device_id TEXT,
token TEXT NOT NULL,
- last_used INTEGER,
- FOREIGN KEY(user_id) REFERENCES users(id),
- UNIQUE(token) ON CONFLICT ROLLBACK
+ last_used BIGINT,
+ UNIQUE(token)
);
CREATE TABLE IF NOT EXISTS user_ips (
@@ -37,9 +36,8 @@ CREATE TABLE IF NOT EXISTS user_ips (
device_id TEXT,
ip TEXT NOT NULL,
user_agent TEXT NOT NULL,
- last_seen INTEGER NOT NULL,
- CONSTRAINT user_ip UNIQUE (user, access_token, ip, user_agent) ON CONFLICT REPLACE
+ last_seen BIGINT NOT NULL,
+ UNIQUE (user, access_token, ip, user_agent)
);
-CREATE INDEX IF NOT EXISTS user_ips_user ON user_ips(user);
-
+CREATE INDEX user_ips_user ON user_ips(user);
diff --git a/synapse/storage/schema/full_schemas/16/application_services.sql b/synapse/storage/schema/full_schemas/16/application_services.sql
new file mode 100644
index 0000000000..d382d63fbd
--- /dev/null
+++ b/synapse/storage/schema/full_schemas/16/application_services.sql
@@ -0,0 +1,48 @@
+/* Copyright 2015 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE IF NOT EXISTS application_services(
+ id BIGINT PRIMARY KEY,
+ url TEXT,
+ token TEXT,
+ hs_token TEXT,
+ sender TEXT,
+ UNIQUE(token)
+);
+
+CREATE TABLE IF NOT EXISTS application_services_regex(
+ id BIGINT PRIMARY KEY,
+ as_id BIGINT NOT NULL,
+ namespace INTEGER, /* enum[room_id|room_alias|user_id] */
+ regex TEXT,
+ FOREIGN KEY(as_id) REFERENCES application_services(id)
+);
+
+CREATE TABLE IF NOT EXISTS application_services_state(
+ as_id TEXT PRIMARY KEY,
+ state VARCHAR(5),
+ last_txn INTEGER
+);
+
+CREATE TABLE IF NOT EXISTS application_services_txns(
+ as_id TEXT NOT NULL,
+ txn_id INTEGER NOT NULL,
+ event_ids TEXT NOT NULL,
+ UNIQUE(as_id, txn_id)
+);
+
+CREATE INDEX application_services_txns_id ON application_services_txns (
+ as_id
+);
diff --git a/synapse/storage/schema/full_schemas/16/event_edges.sql b/synapse/storage/schema/full_schemas/16/event_edges.sql
new file mode 100644
index 0000000000..f7020f7793
--- /dev/null
+++ b/synapse/storage/schema/full_schemas/16/event_edges.sql
@@ -0,0 +1,89 @@
+/* Copyright 2014, 2015 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE IF NOT EXISTS event_forward_extremities(
+ event_id TEXT NOT NULL,
+ room_id TEXT NOT NULL,
+ UNIQUE (event_id, room_id)
+);
+
+CREATE INDEX ev_extrem_room ON event_forward_extremities(room_id);
+CREATE INDEX ev_extrem_id ON event_forward_extremities(event_id);
+
+
+CREATE TABLE IF NOT EXISTS event_backward_extremities(
+ event_id TEXT NOT NULL,
+ room_id TEXT NOT NULL,
+ UNIQUE (event_id, room_id)
+);
+
+CREATE INDEX ev_b_extrem_room ON event_backward_extremities(room_id);
+CREATE INDEX ev_b_extrem_id ON event_backward_extremities(event_id);
+
+
+CREATE TABLE IF NOT EXISTS event_edges(
+ event_id TEXT NOT NULL,
+ prev_event_id TEXT NOT NULL,
+ room_id TEXT NOT NULL,
+ is_state BOOL NOT NULL,
+ UNIQUE (event_id, prev_event_id, room_id, is_state)
+);
+
+CREATE INDEX ev_edges_id ON event_edges(event_id);
+CREATE INDEX ev_edges_prev_id ON event_edges(prev_event_id);
+
+
+CREATE TABLE IF NOT EXISTS room_depth(
+ room_id TEXT NOT NULL,
+ min_depth INTEGER NOT NULL,
+ UNIQUE (room_id)
+);
+
+CREATE INDEX room_depth_room ON room_depth(room_id);
+
+
+create TABLE IF NOT EXISTS event_destinations(
+ event_id TEXT NOT NULL,
+ destination TEXT NOT NULL,
+ delivered_ts BIGINT DEFAULT 0, -- or 0 if not delivered
+ UNIQUE (event_id, destination)
+);
+
+CREATE INDEX event_destinations_id ON event_destinations(event_id);
+
+
+CREATE TABLE IF NOT EXISTS state_forward_extremities(
+ event_id TEXT NOT NULL,
+ room_id TEXT NOT NULL,
+ type TEXT NOT NULL,
+ state_key TEXT NOT NULL,
+ UNIQUE (event_id, room_id)
+);
+
+CREATE INDEX st_extrem_keys ON state_forward_extremities(
+ room_id, type, state_key
+);
+CREATE INDEX st_extrem_id ON state_forward_extremities(event_id);
+
+
+CREATE TABLE IF NOT EXISTS event_auth(
+ event_id TEXT NOT NULL,
+ auth_id TEXT NOT NULL,
+ room_id TEXT NOT NULL,
+ UNIQUE (event_id, auth_id, room_id)
+);
+
+CREATE INDEX evauth_edges_id ON event_auth(event_id);
+CREATE INDEX evauth_edges_auth_id ON event_auth(auth_id);
diff --git a/synapse/storage/schema/full_schemas/16/event_signatures.sql b/synapse/storage/schema/full_schemas/16/event_signatures.sql
new file mode 100644
index 0000000000..636b2d3353
--- /dev/null
+++ b/synapse/storage/schema/full_schemas/16/event_signatures.sql
@@ -0,0 +1,55 @@
+/* Copyright 2014, 2015 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE IF NOT EXISTS event_content_hashes (
+ event_id TEXT,
+ algorithm TEXT,
+ hash bytea,
+ UNIQUE (event_id, algorithm)
+);
+
+CREATE INDEX event_content_hashes_id ON event_content_hashes(event_id);
+
+
+CREATE TABLE IF NOT EXISTS event_reference_hashes (
+ event_id TEXT,
+ algorithm TEXT,
+ hash bytea,
+ UNIQUE (event_id, algorithm)
+);
+
+CREATE INDEX event_reference_hashes_id ON event_reference_hashes(event_id);
+
+
+CREATE TABLE IF NOT EXISTS event_signatures (
+ event_id TEXT,
+ signature_name TEXT,
+ key_id TEXT,
+ signature bytea,
+ UNIQUE (event_id, signature_name, key_id)
+);
+
+CREATE INDEX event_signatures_id ON event_signatures(event_id);
+
+
+CREATE TABLE IF NOT EXISTS event_edge_hashes(
+ event_id TEXT,
+ prev_event_id TEXT,
+ algorithm TEXT,
+ hash bytea,
+ UNIQUE (event_id, prev_event_id, algorithm)
+);
+
+CREATE INDEX event_edge_hashes_id ON event_edge_hashes(event_id);
diff --git a/synapse/storage/schema/full_schemas/16/im.sql b/synapse/storage/schema/full_schemas/16/im.sql
new file mode 100644
index 0000000000..576653a3c9
--- /dev/null
+++ b/synapse/storage/schema/full_schemas/16/im.sql
@@ -0,0 +1,128 @@
+/* Copyright 2014, 2015 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE IF NOT EXISTS events(
+ stream_ordering INTEGER PRIMARY KEY,
+ topological_ordering BIGINT NOT NULL,
+ event_id TEXT NOT NULL,
+ type TEXT NOT NULL,
+ room_id TEXT NOT NULL,
+ content TEXT NOT NULL,
+ unrecognized_keys TEXT,
+ processed BOOL NOT NULL,
+ outlier BOOL NOT NULL,
+ depth BIGINT DEFAULT 0 NOT NULL,
+ UNIQUE (event_id)
+);
+
+CREATE INDEX events_stream_ordering ON events (stream_ordering);
+CREATE INDEX events_topological_ordering ON events (topological_ordering);
+CREATE INDEX events_order ON events (topological_ordering, stream_ordering);
+CREATE INDEX events_room_id ON events (room_id);
+CREATE INDEX events_order_room ON events (
+ room_id, topological_ordering, stream_ordering
+);
+
+
+CREATE TABLE IF NOT EXISTS event_json(
+ event_id TEXT NOT NULL,
+ room_id TEXT NOT NULL,
+ internal_metadata TEXT NOT NULL,
+ json TEXT NOT NULL,
+ UNIQUE (event_id)
+);
+
+CREATE INDEX event_json_room_id ON event_json(room_id);
+
+
+CREATE TABLE IF NOT EXISTS state_events(
+ event_id TEXT NOT NULL,
+ room_id TEXT NOT NULL,
+ type TEXT NOT NULL,
+ state_key TEXT NOT NULL,
+ prev_state TEXT,
+ UNIQUE (event_id)
+);
+
+CREATE INDEX state_events_room_id ON state_events (room_id);
+CREATE INDEX state_events_type ON state_events (type);
+CREATE INDEX state_events_state_key ON state_events (state_key);
+
+
+CREATE TABLE IF NOT EXISTS current_state_events(
+ event_id TEXT NOT NULL,
+ room_id TEXT NOT NULL,
+ type TEXT NOT NULL,
+ state_key TEXT NOT NULL,
+ UNIQUE (event_id),
+ UNIQUE (room_id, type, state_key)
+);
+
+CREATE INDEX current_state_events_room_id ON current_state_events (room_id);
+CREATE INDEX current_state_events_type ON current_state_events (type);
+CREATE INDEX current_state_events_state_key ON current_state_events (state_key);
+
+CREATE TABLE IF NOT EXISTS room_memberships(
+ event_id TEXT NOT NULL,
+ user_id TEXT NOT NULL,
+ sender TEXT NOT NULL,
+ room_id TEXT NOT NULL,
+ membership TEXT NOT NULL,
+ UNIQUE (event_id)
+);
+
+CREATE INDEX room_memberships_room_id ON room_memberships (room_id);
+CREATE INDEX room_memberships_user_id ON room_memberships (user_id);
+
+CREATE TABLE IF NOT EXISTS feedback(
+ event_id TEXT NOT NULL,
+ feedback_type TEXT,
+ target_event_id TEXT,
+ sender TEXT,
+ room_id TEXT,
+ UNIQUE (event_id)
+);
+
+CREATE TABLE IF NOT EXISTS topics(
+ event_id TEXT NOT NULL,
+ room_id TEXT NOT NULL,
+ topic TEXT NOT NULL,
+ UNIQUE (event_id)
+);
+
+CREATE INDEX topics_room_id ON topics(room_id);
+
+CREATE TABLE IF NOT EXISTS room_names(
+ event_id TEXT NOT NULL,
+ room_id TEXT NOT NULL,
+ name TEXT NOT NULL,
+ UNIQUE (event_id)
+);
+
+CREATE INDEX room_names_room_id ON room_names(room_id);
+
+CREATE TABLE IF NOT EXISTS rooms(
+ room_id TEXT PRIMARY KEY NOT NULL,
+ is_public BOOL,
+ creator TEXT
+);
+
+CREATE TABLE IF NOT EXISTS room_hosts(
+ room_id TEXT NOT NULL,
+ host TEXT NOT NULL,
+ UNIQUE (room_id, host)
+);
+
+CREATE INDEX room_hosts_room_id ON room_hosts (room_id);
diff --git a/synapse/storage/schema/full_schemas/16/keys.sql b/synapse/storage/schema/full_schemas/16/keys.sql
new file mode 100644
index 0000000000..afc142045e
--- /dev/null
+++ b/synapse/storage/schema/full_schemas/16/keys.sql
@@ -0,0 +1,31 @@
+/* Copyright 2014, 2015 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+CREATE TABLE IF NOT EXISTS server_tls_certificates(
+ server_name TEXT, -- Server name.
+ fingerprint TEXT, -- Certificate fingerprint.
+ from_server TEXT, -- Which key server the certificate was fetched from.
+ ts_added_ms BIGINT, -- When the certifcate was added.
+ tls_certificate bytea, -- DER encoded x509 certificate.
+ UNIQUE (server_name, fingerprint)
+);
+
+CREATE TABLE IF NOT EXISTS server_signature_keys(
+ server_name TEXT, -- Server name.
+ key_id TEXT, -- Key version.
+ from_server TEXT, -- Which key server the key was fetched form.
+ ts_added_ms BIGINT, -- When the key was added.
+ verify_key bytea, -- NACL verification key.
+ UNIQUE (server_name, key_id)
+);
diff --git a/synapse/storage/schema/full_schemas/16/media_repository.sql b/synapse/storage/schema/full_schemas/16/media_repository.sql
new file mode 100644
index 0000000000..dacbda40ca
--- /dev/null
+++ b/synapse/storage/schema/full_schemas/16/media_repository.sql
@@ -0,0 +1,68 @@
+/* Copyright 2014, 2015 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE IF NOT EXISTS local_media_repository (
+ media_id TEXT, -- The id used to refer to the media.
+ media_type TEXT, -- The MIME-type of the media.
+ media_length INTEGER, -- Length of the media in bytes.
+ created_ts BIGINT, -- When the content was uploaded in ms.
+ upload_name TEXT, -- The name the media was uploaded with.
+ user_id TEXT, -- The user who uploaded the file.
+ UNIQUE (media_id)
+);
+
+CREATE TABLE IF NOT EXISTS local_media_repository_thumbnails (
+ media_id TEXT, -- The id used to refer to the media.
+ thumbnail_width INTEGER, -- The width of the thumbnail in pixels.
+ thumbnail_height INTEGER, -- The height of the thumbnail in pixels.
+ thumbnail_type TEXT, -- The MIME-type of the thumbnail.
+ thumbnail_method TEXT, -- The method used to make the thumbnail.
+ thumbnail_length INTEGER, -- The length of the thumbnail in bytes.
+ UNIQUE (
+ media_id, thumbnail_width, thumbnail_height, thumbnail_type
+ )
+);
+
+CREATE INDEX local_media_repository_thumbnails_media_id
+ ON local_media_repository_thumbnails (media_id);
+
+CREATE TABLE IF NOT EXISTS remote_media_cache (
+ media_origin TEXT, -- The remote HS the media came from.
+ media_id TEXT, -- The id used to refer to the media on that server.
+ media_type TEXT, -- The MIME-type of the media.
+ created_ts BIGINT, -- When the content was uploaded in ms.
+ upload_name TEXT, -- The name the media was uploaded with.
+ media_length INTEGER, -- Length of the media in bytes.
+ filesystem_id TEXT, -- The name used to store the media on disk.
+ UNIQUE (media_origin, media_id)
+);
+
+CREATE TABLE IF NOT EXISTS remote_media_cache_thumbnails (
+ media_origin TEXT, -- The remote HS the media came from.
+ media_id TEXT, -- The id used to refer to the media.
+ thumbnail_width INTEGER, -- The width of the thumbnail in pixels.
+ thumbnail_height INTEGER, -- The height of the thumbnail in pixels.
+ thumbnail_method TEXT, -- The method used to make the thumbnail
+ thumbnail_type TEXT, -- The MIME-type of the thumbnail.
+ thumbnail_length INTEGER, -- The length of the thumbnail in bytes.
+ filesystem_id TEXT, -- The name used to store the media on disk.
+ UNIQUE (
+ media_origin, media_id, thumbnail_width, thumbnail_height,
+ thumbnail_type
+ )
+);
+
+CREATE INDEX remote_media_cache_thumbnails_media_id
+ ON remote_media_cache_thumbnails (media_id);
diff --git a/synapse/storage/schema/full_schemas/16/presence.sql b/synapse/storage/schema/full_schemas/16/presence.sql
new file mode 100644
index 0000000000..80088413ba
--- /dev/null
+++ b/synapse/storage/schema/full_schemas/16/presence.sql
@@ -0,0 +1,40 @@
+/* Copyright 2014, 2015 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+CREATE TABLE IF NOT EXISTS presence(
+ user_id TEXT NOT NULL,
+ state VARCHAR(20),
+ status_msg TEXT,
+ mtime BIGINT, -- miliseconds since last state change
+ UNIQUE (user_id)
+);
+
+-- For each of /my/ users which possibly-remote users are allowed to see their
+-- presence state
+CREATE TABLE IF NOT EXISTS presence_allow_inbound(
+ observed_user_id TEXT NOT NULL,
+ observer_user_id TEXT NOT NULL, -- a UserID,
+ UNIQUE (observed_user_id, observer_user_id)
+);
+
+-- For each of /my/ users (watcher), which possibly-remote users are they
+-- watching?
+CREATE TABLE IF NOT EXISTS presence_list(
+ user_id TEXT NOT NULL,
+ observed_user_id TEXT NOT NULL, -- a UserID,
+ accepted BOOLEAN NOT NULL,
+ UNIQUE (user_id, observed_user_id)
+);
+
+CREATE INDEX presence_list_user_id ON presence_list (user_id);
diff --git a/synapse/storage/schema/full_schemas/16/profiles.sql b/synapse/storage/schema/full_schemas/16/profiles.sql
new file mode 100644
index 0000000000..934be86520
--- /dev/null
+++ b/synapse/storage/schema/full_schemas/16/profiles.sql
@@ -0,0 +1,20 @@
+/* Copyright 2014, 2015 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+CREATE TABLE IF NOT EXISTS profiles(
+ user_id TEXT NOT NULL,
+ displayname TEXT,
+ avatar_url TEXT,
+ UNIQUE(user_id)
+);
diff --git a/synapse/storage/schema/full_schemas/16/push.sql b/synapse/storage/schema/full_schemas/16/push.sql
new file mode 100644
index 0000000000..9387f920f0
--- /dev/null
+++ b/synapse/storage/schema/full_schemas/16/push.sql
@@ -0,0 +1,74 @@
+/* Copyright 2015 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE IF NOT EXISTS rejections(
+ event_id TEXT NOT NULL,
+ reason TEXT NOT NULL,
+ last_check TEXT NOT NULL,
+ UNIQUE (event_id)
+);
+
+-- Push notification endpoints that users have configured
+CREATE TABLE IF NOT EXISTS pushers (
+ id BIGINT PRIMARY KEY,
+ user_name TEXT NOT NULL,
+ access_token BIGINT DEFAULT NULL,
+ profile_tag VARCHAR(32) NOT NULL,
+ kind VARCHAR(8) NOT NULL,
+ app_id VARCHAR(64) NOT NULL,
+ app_display_name VARCHAR(64) NOT NULL,
+ device_display_name VARCHAR(128) NOT NULL,
+ pushkey bytea NOT NULL,
+ ts BIGINT NOT NULL,
+ lang VARCHAR(8),
+ data bytea,
+ last_token TEXT,
+ last_success BIGINT,
+ failing_since BIGINT,
+ UNIQUE (app_id, pushkey)
+);
+
+CREATE TABLE IF NOT EXISTS push_rules (
+ id BIGINT PRIMARY KEY,
+ user_name TEXT NOT NULL,
+ rule_id TEXT NOT NULL,
+ priority_class SMALLINT NOT NULL,
+ priority INTEGER NOT NULL DEFAULT 0,
+ conditions TEXT NOT NULL,
+ actions TEXT NOT NULL,
+ UNIQUE(user_name, rule_id)
+);
+
+CREATE INDEX push_rules_user_name on push_rules (user_name);
+
+CREATE TABLE IF NOT EXISTS user_filters(
+ user_id TEXT,
+ filter_id BIGINT,
+ filter_json bytea
+);
+
+CREATE INDEX user_filters_by_user_id_filter_id ON user_filters(
+ user_id, filter_id
+);
+
+CREATE TABLE IF NOT EXISTS push_rules_enable (
+ id BIGINT PRIMARY KEY,
+ user_name TEXT NOT NULL,
+ rule_id TEXT NOT NULL,
+ enabled SMALLINT,
+ UNIQUE(user_name, rule_id)
+);
+
+CREATE INDEX push_rules_enable_user_name on push_rules_enable (user_name);
diff --git a/synapse/storage/schema/full_schemas/16/redactions.sql b/synapse/storage/schema/full_schemas/16/redactions.sql
new file mode 100644
index 0000000000..69621955d4
--- /dev/null
+++ b/synapse/storage/schema/full_schemas/16/redactions.sql
@@ -0,0 +1,22 @@
+/* Copyright 2014, 2015 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+CREATE TABLE IF NOT EXISTS redactions (
+ event_id TEXT NOT NULL,
+ redacts TEXT NOT NULL,
+ UNIQUE (event_id)
+);
+
+CREATE INDEX redactions_event_id ON redactions (event_id);
+CREATE INDEX redactions_redacts ON redactions (redacts);
diff --git a/synapse/storage/schema/full_schemas/16/room_aliases.sql b/synapse/storage/schema/full_schemas/16/room_aliases.sql
new file mode 100644
index 0000000000..412bb97fad
--- /dev/null
+++ b/synapse/storage/schema/full_schemas/16/room_aliases.sql
@@ -0,0 +1,29 @@
+/* Copyright 2014, 2015 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE IF NOT EXISTS room_aliases(
+ room_alias TEXT NOT NULL,
+ room_id TEXT NOT NULL,
+ UNIQUE (room_alias)
+);
+
+CREATE INDEX room_aliases_id ON room_aliases(room_id);
+
+CREATE TABLE IF NOT EXISTS room_alias_servers(
+ room_alias TEXT NOT NULL,
+ server TEXT NOT NULL
+);
+
+CREATE INDEX room_alias_servers_alias ON room_alias_servers(room_alias);
diff --git a/synapse/storage/schema/full_schemas/16/state.sql b/synapse/storage/schema/full_schemas/16/state.sql
new file mode 100644
index 0000000000..705cac6ce9
--- /dev/null
+++ b/synapse/storage/schema/full_schemas/16/state.sql
@@ -0,0 +1,40 @@
+/* Copyright 2014, 2015 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE IF NOT EXISTS state_groups(
+ id BIGINT PRIMARY KEY,
+ room_id TEXT NOT NULL,
+ event_id TEXT NOT NULL
+);
+
+CREATE TABLE IF NOT EXISTS state_groups_state(
+ state_group BIGINT NOT NULL,
+ room_id TEXT NOT NULL,
+ type TEXT NOT NULL,
+ state_key TEXT NOT NULL,
+ event_id TEXT NOT NULL
+);
+
+CREATE TABLE IF NOT EXISTS event_to_state_groups(
+ event_id TEXT NOT NULL,
+ state_group BIGINT NOT NULL,
+ UNIQUE (event_id)
+);
+
+CREATE INDEX state_groups_id ON state_groups(id);
+
+CREATE INDEX state_groups_state_id ON state_groups_state(state_group);
+CREATE INDEX state_groups_state_tuple ON state_groups_state(room_id, type, state_key);
+CREATE INDEX event_to_state_groups_id ON event_to_state_groups(event_id);
diff --git a/synapse/storage/schema/full_schemas/16/transactions.sql b/synapse/storage/schema/full_schemas/16/transactions.sql
new file mode 100644
index 0000000000..1ab77cdb63
--- /dev/null
+++ b/synapse/storage/schema/full_schemas/16/transactions.sql
@@ -0,0 +1,63 @@
+/* Copyright 2014, 2015 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+-- Stores what transaction ids we have received and what our response was
+CREATE TABLE IF NOT EXISTS received_transactions(
+ transaction_id TEXT,
+ origin TEXT,
+ ts BIGINT,
+ response_code INTEGER,
+ response_json bytea,
+ has_been_referenced smallint default 0, -- Whether thishas been referenced by a prev_tx
+ UNIQUE (transaction_id, origin)
+);
+
+CREATE INDEX transactions_have_ref ON received_transactions(origin, has_been_referenced);-- WHERE has_been_referenced = 0;
+
+
+-- Stores what transactions we've sent, what their response was (if we got one) and whether we have
+-- since referenced the transaction in another outgoing transaction
+CREATE TABLE IF NOT EXISTS sent_transactions(
+ id BIGINT PRIMARY KEY, -- This is used to apply insertion ordering
+ transaction_id TEXT,
+ destination TEXT,
+ response_code INTEGER DEFAULT 0,
+ response_json TEXT,
+ ts BIGINT
+);
+
+CREATE INDEX sent_transaction_dest ON sent_transactions(destination);
+CREATE INDEX sent_transaction_txn_id ON sent_transactions(transaction_id);
+-- So that we can do an efficient look up of all transactions that have yet to be successfully
+-- sent.
+CREATE INDEX sent_transaction_sent ON sent_transactions(response_code);
+
+
+-- For sent transactions only.
+CREATE TABLE IF NOT EXISTS transaction_id_to_pdu(
+ transaction_id INTEGER,
+ destination TEXT,
+ pdu_id TEXT,
+ pdu_origin TEXT,
+ UNIQUE (transaction_id, destination)
+);
+
+CREATE INDEX transaction_id_to_pdu_dest ON transaction_id_to_pdu(destination);
+
+-- To track destination health
+CREATE TABLE IF NOT EXISTS destinations(
+ destination TEXT PRIMARY KEY,
+ retry_last_ts BIGINT,
+ retry_interval INTEGER
+);
diff --git a/synapse/storage/schema/full_schemas/16/users.sql b/synapse/storage/schema/full_schemas/16/users.sql
new file mode 100644
index 0000000000..d2fa3122da
--- /dev/null
+++ b/synapse/storage/schema/full_schemas/16/users.sql
@@ -0,0 +1,42 @@
+/* Copyright 2014, 2015 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+CREATE TABLE IF NOT EXISTS users(
+ name TEXT,
+ password_hash TEXT,
+ creation_ts BIGINT,
+ admin SMALLINT DEFAULT 0 NOT NULL,
+ UNIQUE(name)
+);
+
+CREATE TABLE IF NOT EXISTS access_tokens(
+ id BIGINT PRIMARY KEY,
+ user_id TEXT NOT NULL,
+ device_id TEXT,
+ token TEXT NOT NULL,
+ last_used BIGINT,
+ UNIQUE(token)
+);
+
+CREATE TABLE IF NOT EXISTS user_ips (
+ user_id TEXT NOT NULL,
+ access_token TEXT NOT NULL,
+ device_id TEXT,
+ ip TEXT NOT NULL,
+ user_agent TEXT NOT NULL,
+ last_seen BIGINT NOT NULL
+);
+
+CREATE INDEX user_ips_user ON user_ips(user_id);
+CREATE INDEX user_ips_user_ip ON user_ips(user_id, access_token, ip);
diff --git a/synapse/storage/schema/schema_version.sql b/synapse/storage/schema/schema_version.sql
index 0431e2d051..d682608aa0 100644
--- a/synapse/storage/schema/schema_version.sql
+++ b/synapse/storage/schema/schema_version.sql
@@ -14,17 +14,14 @@
*/
CREATE TABLE IF NOT EXISTS schema_version(
- Lock char(1) NOT NULL DEFAULT 'X', -- Makes sure this table only has one row.
+ Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, -- Makes sure this table only has one row.
version INTEGER NOT NULL,
upgraded BOOL NOT NULL, -- Whether we reached this version from an upgrade or an initial schema.
- CONSTRAINT schema_version_lock_x CHECK (Lock='X')
- CONSTRAINT schema_version_lock_uniq UNIQUE (Lock)
+ CHECK (Lock='X')
);
CREATE TABLE IF NOT EXISTS applied_schema_deltas(
version INTEGER NOT NULL,
file TEXT NOT NULL,
- CONSTRAINT schema_deltas_ver_file UNIQUE (version, file) ON CONFLICT IGNORE
+ UNIQUE(version, file)
);
-
-CREATE INDEX IF NOT EXISTS schema_deltas_ver ON applied_schema_deltas(version);
diff --git a/synapse/storage/signatures.py b/synapse/storage/signatures.py
index d0d53770f2..f051828630 100644
--- a/synapse/storage/signatures.py
+++ b/synapse/storage/signatures.py
@@ -56,7 +56,6 @@ class SignatureStore(SQLBaseStore):
"algorithm": algorithm,
"hash": buffer(hash_bytes),
},
- or_ignore=True,
)
def get_event_reference_hashes(self, event_ids):
@@ -100,7 +99,7 @@ class SignatureStore(SQLBaseStore):
" WHERE event_id = ?"
)
txn.execute(query, (event_id, ))
- return dict(txn.fetchall())
+ return {k: v for k, v in txn.fetchall()}
def _store_event_reference_hash_txn(self, txn, event_id, algorithm,
hash_bytes):
@@ -119,7 +118,6 @@ class SignatureStore(SQLBaseStore):
"algorithm": algorithm,
"hash": buffer(hash_bytes),
},
- or_ignore=True,
)
def _get_event_signatures_txn(self, txn, event_id):
@@ -164,7 +162,6 @@ class SignatureStore(SQLBaseStore):
"key_id": key_id,
"signature": buffer(signature_bytes),
},
- or_ignore=True,
)
def _get_prev_event_hashes_txn(self, txn, event_id):
@@ -198,5 +195,4 @@ class SignatureStore(SQLBaseStore):
"algorithm": algorithm,
"hash": buffer(hash_bytes),
},
- or_ignore=True,
)
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 456e4bd45d..dbc0e49c1f 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -15,6 +15,10 @@
from ._base import SQLBaseStore
+from twisted.internet import defer
+
+from synapse.util.stringutils import random_string
+
import logging
logger = logging.getLogger(__name__)
@@ -89,29 +93,31 @@ class StateStore(SQLBaseStore):
state_group = context.state_group
if not state_group:
- state_group = self._simple_insert_txn(
+ state_group = self._state_groups_id_gen.get_next_txn(txn)
+ self._simple_insert_txn(
txn,
table="state_groups",
values={
+ "id": state_group,
"room_id": event.room_id,
"event_id": event.event_id,
},
- or_ignore=True,
)
- for state in state_events.values():
- self._simple_insert_txn(
- txn,
- table="state_groups_state",
- values={
+ self._simple_insert_many_txn(
+ txn,
+ table="state_groups_state",
+ values=[
+ {
"state_group": state_group,
"room_id": state.room_id,
"type": state.type,
"state_key": state.state_key,
"event_id": state.event_id,
- },
- or_ignore=True,
- )
+ }
+ for state in state_events.values()
+ ],
+ )
self._simple_insert_txn(
txn,
@@ -120,5 +126,33 @@ class StateStore(SQLBaseStore):
"state_group": state_group,
"event_id": event.event_id,
},
- or_replace=True,
)
+
+ @defer.inlineCallbacks
+ def get_current_state(self, room_id, event_type=None, state_key=""):
+ def f(txn):
+ sql = (
+ "SELECT event_id FROM current_state_events"
+ " WHERE room_id = ? "
+ )
+
+ if event_type and state_key is not None:
+ sql += " AND type = ? AND state_key = ? "
+ args = (room_id, event_type, state_key)
+ elif event_type:
+ sql += " AND type = ?"
+ args = (room_id, event_type)
+ else:
+ args = (room_id, )
+
+ txn.execute(sql, args)
+ results = self.cursor_to_dict(txn)
+
+ return self._parse_events_txn(txn, results)
+
+ events = yield self.runInteraction("get_current_state", f)
+ defer.returnValue(events)
+
+
+def _make_group_id(clock):
+ return str(int(clock.time_msec())) + random_string(5)
diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py
index 09bc522210..280d4ad605 100644
--- a/synapse/storage/stream.py
+++ b/synapse/storage/stream.py
@@ -110,7 +110,7 @@ class _StreamToken(namedtuple("_StreamToken", "topological stream")):
if self.topological is None:
return "(%d < %s)" % (self.stream, "stream_ordering")
else:
- return "(%d < %s OR (%d == %s AND %d < %s))" % (
+ return "(%d < %s OR (%d = %s AND %d < %s))" % (
self.topological, "topological_ordering",
self.topological, "topological_ordering",
self.stream, "stream_ordering",
@@ -120,7 +120,7 @@ class _StreamToken(namedtuple("_StreamToken", "topological stream")):
if self.topological is None:
return "(%d >= %s)" % (self.stream, "stream_ordering")
else:
- return "(%d > %s OR (%d == %s AND %d >= %s))" % (
+ return "(%d > %s OR (%d = %s AND %d >= %s))" % (
self.topological, "topological_ordering",
self.topological, "topological_ordering",
self.stream, "stream_ordering",
@@ -149,7 +149,8 @@ class StreamStore(SQLBaseStore):
# select all the events between from/to with a sensible limit
sql = (
"SELECT e.event_id, e.room_id, e.type, s.state_key, "
- "e.stream_ordering FROM events AS e LEFT JOIN state_events as s ON "
+ "e.stream_ordering FROM events AS e "
+ "LEFT JOIN state_events as s ON "
"e.event_id = s.event_id "
"WHERE e.stream_ordering > ? AND e.stream_ordering <= ? "
"ORDER BY stream_ordering ASC LIMIT %(limit)d "
@@ -214,8 +215,9 @@ class StreamStore(SQLBaseStore):
current_room_membership_sql = (
"SELECT m.room_id FROM room_memberships as m "
- "INNER JOIN current_state_events as c ON m.event_id = c.event_id "
- "WHERE m.user_id = ? AND m.membership = 'join'"
+ " INNER JOIN current_state_events as c"
+ " ON m.event_id = c.event_id AND c.state_key = m.user_id"
+ " WHERE m.user_id = ? AND m.membership = 'join'"
)
# We also want to get any membership events about that user, e.g.
@@ -240,7 +242,7 @@ class StreamStore(SQLBaseStore):
sql = (
"SELECT e.event_id, e.stream_ordering FROM events AS e WHERE "
- "(e.outlier = 0 AND (room_id IN (%(current)s)) OR "
+ "(e.outlier = ? AND (room_id IN (%(current)s)) OR "
"(event_id IN (%(invites)s))) "
"AND e.stream_ordering > ? AND e.stream_ordering <= ? "
"ORDER BY stream_ordering ASC LIMIT %(limit)d "
@@ -251,7 +253,7 @@ class StreamStore(SQLBaseStore):
}
def f(txn):
- txn.execute(sql, (user_id, user_id, from_id.stream, to_id.stream,))
+ txn.execute(sql, (False, user_id, user_id, from_id.stream, to_id.stream,))
rows = self.cursor_to_dict(txn)
@@ -283,7 +285,7 @@ class StreamStore(SQLBaseStore):
# Tokens really represent positions between elements, but we use
# the convention of pointing to the event before the gap. Hence
# we have a bit of asymmetry when it comes to equalities.
- args = [room_id]
+ args = [False, room_id]
if direction == 'b':
order = "DESC"
bounds = _StreamToken.parse(from_key).upper_bound()
@@ -307,7 +309,7 @@ class StreamStore(SQLBaseStore):
sql = (
"SELECT * FROM events"
- " WHERE outlier = 0 AND room_id = ? AND %(bounds)s"
+ " WHERE outlier = ? AND room_id = ? AND %(bounds)s"
" ORDER BY topological_ordering %(order)s,"
" stream_ordering %(order)s %(limit)s"
) % {
@@ -358,7 +360,7 @@ class StreamStore(SQLBaseStore):
sql = (
"SELECT stream_ordering, topological_ordering, event_id"
" FROM events"
- " WHERE room_id = ? AND stream_ordering <= ? AND outlier = 0"
+ " WHERE room_id = ? AND stream_ordering <= ? AND outlier = ?"
" ORDER BY topological_ordering DESC, stream_ordering DESC"
" LIMIT ?"
)
@@ -368,17 +370,17 @@ class StreamStore(SQLBaseStore):
"SELECT stream_ordering, topological_ordering, event_id"
" FROM events"
" WHERE room_id = ? AND stream_ordering > ?"
- " AND stream_ordering <= ? AND outlier = 0"
+ " AND stream_ordering <= ? AND outlier = ?"
" ORDER BY topological_ordering DESC, stream_ordering DESC"
" LIMIT ?"
)
def get_recent_events_for_room_txn(txn):
if from_token is None:
- txn.execute(sql, (room_id, end_token.stream, limit,))
+ txn.execute(sql, (room_id, end_token.stream, False, limit,))
else:
txn.execute(sql, (
- room_id, from_token.stream, end_token.stream, limit
+ room_id, from_token.stream, end_token.stream, False, limit
))
rows = self.cursor_to_dict(txn)
@@ -413,26 +415,23 @@ class StreamStore(SQLBaseStore):
"get_recent_events_for_room", get_recent_events_for_room_txn
)
+ @defer.inlineCallbacks
def get_room_events_max_id(self):
- return self.runInteraction(
- "get_room_events_max_id",
- self._get_room_events_max_id_txn
- )
+ token = yield self._stream_id_gen.get_max_token(self)
+ defer.returnValue("s%d" % (token,))
- def _get_room_events_max_id_txn(self, txn):
- txn.execute(
- "SELECT MAX(stream_ordering) as m FROM events"
+ @defer.inlineCallbacks
+ def _get_min_token(self):
+ row = yield self._execute(
+ "_get_min_token", None, "SELECT MIN(stream_ordering) FROM events"
)
- res = self.cursor_to_dict(txn)
-
- logger.debug("get_room_events_max_id: %s", res)
+ self.min_token = row[0][0] if row and row[0] and row[0][0] else -1
+ self.min_token = min(self.min_token, -1)
- if not res or not res[0] or not res[0]["m"]:
- return "s0"
+ logger.debug("min_token is: %s", self.min_token)
- key = res[0]["m"]
- return "s%d" % (key,)
+ defer.returnValue(self.min_token)
@staticmethod
def _set_before_and_after(events, rows):
diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py
index 0b8a3b7a07..624da4a9dc 100644
--- a/synapse/storage/transactions.py
+++ b/synapse/storage/transactions.py
@@ -13,10 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import SQLBaseStore, Table, cached
+from ._base import SQLBaseStore, cached
from collections import namedtuple
+from syutil.jsonutil import encode_canonical_json
import logging
logger = logging.getLogger(__name__)
@@ -46,15 +47,19 @@ class TransactionStore(SQLBaseStore):
)
def _get_received_txn_response(self, txn, transaction_id, origin):
- where_clause = "transaction_id = ? AND origin = ?"
- query = ReceivedTransactionsTable.select_statement(where_clause)
-
- txn.execute(query, (transaction_id, origin))
-
- results = ReceivedTransactionsTable.decode_results(txn.fetchall())
+ result = self._simple_select_one_txn(
+ txn,
+ table=ReceivedTransactionsTable.table_name,
+ keyvalues={
+ "transaction_id": transaction_id,
+ "origin": origin,
+ },
+ retcols=ReceivedTransactionsTable.fields,
+ allow_none=True,
+ )
- if results and results[0].response_code:
- return (results[0].response_code, results[0].response_json)
+ if result and result.response_code:
+ return result["response_code"], result["response_json"]
else:
return None
@@ -72,22 +77,18 @@ class TransactionStore(SQLBaseStore):
response_json (str)
"""
- return self.runInteraction(
- "set_received_txn_response",
- self._set_received_txn_response,
- transaction_id, origin, code, response_dict
+ return self._simple_insert(
+ table=ReceivedTransactionsTable.table_name,
+ values={
+ "transaction_id": transaction_id,
+ "origin": origin,
+ "response_code": code,
+ "response_json": buffer(encode_canonical_json(response_dict)),
+ },
+ or_ignore=True,
+ desc="set_received_txn_response",
)
- def _set_received_txn_response(self, txn, transaction_id, origin, code,
- response_json):
- query = (
- "UPDATE %s "
- "SET response_code = ?, response_json = ? "
- "WHERE transaction_id = ? AND origin = ?"
- ) % ReceivedTransactionsTable.table_name
-
- txn.execute(query, (code, response_json, transaction_id, origin))
-
def prep_send_transaction(self, transaction_id, destination,
origin_server_ts):
"""Persists an outgoing transaction and calculates the values for the
@@ -114,41 +115,38 @@ class TransactionStore(SQLBaseStore):
def _prep_send_transaction(self, txn, transaction_id, destination,
origin_server_ts):
+ next_id = self._transaction_id_gen.get_next_txn(txn)
+
# First we find out what the prev_txns should be.
# Since we know that we are only sending one transaction at a time,
# we can simply take the last one.
- query = "%s ORDER BY id DESC LIMIT 1" % (
- SentTransactions.select_statement("destination = ?"),
- )
+ query = (
+ "SELECT * FROM sent_transactions"
+ " WHERE destination = ?"
+ " ORDER BY id DESC LIMIT 1"
+ )
- results = txn.execute(query, (destination,))
- results = SentTransactions.decode_results(results)
+ txn.execute(query, (destination,))
+ results = self.cursor_to_dict(txn)
- prev_txns = [r.transaction_id for r in results]
+ prev_txns = [r["transaction_id"] for r in results]
# Actually add the new transaction to the sent_transactions table.
- query = SentTransactions.insert_statement()
- txn.execute(query, SentTransactions.EntryType(
- None,
- transaction_id=transaction_id,
- destination=destination,
- ts=origin_server_ts,
- response_code=0,
- response_json=None
- ))
-
- # Update the tx id -> pdu id mapping
-
- # values = [
- # (transaction_id, destination, pdu[0], pdu[1])
- # for pdu in pdu_list
- # ]
- #
- # logger.debug("Inserting: %s", repr(values))
- #
- # query = TransactionsToPduTable.insert_statement()
- # txn.executemany(query, values)
+ self._simple_insert_txn(
+ txn,
+ table=SentTransactions.table_name,
+ values={
+ "id": next_id,
+ "transaction_id": transaction_id,
+ "destination": destination,
+ "ts": origin_server_ts,
+ "response_code": 0,
+ "response_json": None,
+ }
+ )
+
+ # TODO Update the tx id -> pdu id mapping
return prev_txns
@@ -164,18 +162,24 @@ class TransactionStore(SQLBaseStore):
return self.runInteraction(
"delivered_txn",
self._delivered_txn,
- transaction_id, destination, code, response_dict
+ transaction_id, destination, code,
+ buffer(encode_canonical_json(response_dict)),
)
- def _delivered_txn(cls, txn, transaction_id, destination,
+ def _delivered_txn(self, txn, transaction_id, destination,
code, response_json):
- query = (
- "UPDATE %s "
- "SET response_code = ?, response_json = ? "
- "WHERE transaction_id = ? AND destination = ?"
- ) % SentTransactions.table_name
-
- txn.execute(query, (code, response_json, transaction_id, destination))
+ self._simple_update_one_txn(
+ txn,
+ table=SentTransactions.table_name,
+ keyvalues={
+ "transaction_id": transaction_id,
+ "destination": destination,
+ },
+ updatevalues={
+ "response_code": code,
+ "response_json": None, # For now, don't persist response_json
+ }
+ )
def get_transactions_after(self, transaction_id, destination):
"""Get all transactions after a given local transaction_id.
@@ -185,25 +189,26 @@ class TransactionStore(SQLBaseStore):
destination (str)
Returns:
- list: A list of `ReceivedTransactionsTable.EntryType`
+ list: A list of dicts
"""
return self.runInteraction(
"get_transactions_after",
self._get_transactions_after, transaction_id, destination
)
- def _get_transactions_after(cls, txn, transaction_id, destination):
- where = (
- "destination = ? AND id > (select id FROM %s WHERE "
- "transaction_id = ? AND destination = ?)"
- ) % (
- SentTransactions.table_name
+ def _get_transactions_after(self, txn, transaction_id, destination):
+ query = (
+ "SELECT * FROM sent_transactions"
+ " WHERE destination = ? AND id >"
+ " ("
+ " SELECT id FROM sent_transactions"
+ " WHERE transaction_id = ? AND destination = ?"
+ " )"
)
- query = SentTransactions.select_statement(where)
txn.execute(query, (destination, transaction_id, destination))
- return ReceivedTransactionsTable.decode_results(txn.fetchall())
+ return self.cursor_to_dict(txn)
@cached()
def get_destination_retry_timings(self, destination):
@@ -214,22 +219,27 @@ class TransactionStore(SQLBaseStore):
Returns:
None if not retrying
- Otherwise a DestinationsTable.EntryType for the retry scheme
+ Otherwise a dict for the retry scheme
"""
return self.runInteraction(
"get_destination_retry_timings",
self._get_destination_retry_timings, destination)
- def _get_destination_retry_timings(cls, txn, destination):
- query = DestinationsTable.select_statement("destination = ?")
- txn.execute(query, (destination,))
- result = txn.fetchall()
- if result:
- result = DestinationsTable.decode_single_result(result)
- if result.retry_last_ts > 0:
- return result
- else:
- return None
+ def _get_destination_retry_timings(self, txn, destination):
+ result = self._simple_select_one_txn(
+ txn,
+ table=DestinationsTable.table_name,
+ keyvalues={
+ "destination": destination,
+ },
+ retcols=DestinationsTable.fields,
+ allow_none=True,
+ )
+
+ if result and result["retry_last_ts"] > 0:
+ return result
+ else:
+ return None
def set_destination_retry_timings(self, destination,
retry_last_ts, retry_interval):
@@ -245,11 +255,11 @@ class TransactionStore(SQLBaseStore):
# As this is the new value, we might as well prefill the cache
self.get_destination_retry_timings.prefill(
destination,
- DestinationsTable.EntryType(
- destination,
- retry_last_ts,
- retry_interval
- )
+ {
+ "destination": destination,
+ "retry_last_ts": retry_last_ts,
+ "retry_interval": retry_interval
+ },
)
# XXX: we could chose to not bother persisting this if our cache thinks
@@ -262,22 +272,38 @@ class TransactionStore(SQLBaseStore):
retry_interval,
)
- def _set_destination_retry_timings(cls, txn, destination,
+ def _set_destination_retry_timings(self, txn, destination,
retry_last_ts, retry_interval):
-
query = (
- "INSERT OR REPLACE INTO %s "
- "(destination, retry_last_ts, retry_interval) "
- "VALUES (?, ?, ?) "
- ) % DestinationsTable.table_name
+ "UPDATE destinations"
+ " SET retry_last_ts = ?, retry_interval = ?"
+ " WHERE destination = ?"
+ )
+
+ txn.execute(
+ query,
+ (
+ retry_last_ts, retry_interval, destination,
+ )
+ )
- txn.execute(query, (destination, retry_last_ts, retry_interval))
+ if txn.rowcount == 0:
+ # destination wasn't already in table. Insert it.
+ self._simple_insert_txn(
+ txn,
+ table="destinations",
+ values={
+ "destination": destination,
+ "retry_last_ts": retry_last_ts,
+ "retry_interval": retry_interval,
+ }
+ )
def get_destinations_needing_retry(self):
"""Get all destinations which are due a retry for sending a transaction.
Returns:
- list: A list of `DestinationsTable.EntryType`
+ list: A list of dicts
"""
return self.runInteraction(
@@ -285,14 +311,17 @@ class TransactionStore(SQLBaseStore):
self._get_destinations_needing_retry
)
- def _get_destinations_needing_retry(cls, txn):
- where = "retry_last_ts > 0 and retry_next_ts < now()"
- query = DestinationsTable.select_statement(where)
- txn.execute(query)
- return DestinationsTable.decode_results(txn.fetchall())
+ def _get_destinations_needing_retry(self, txn):
+ query = (
+ "SELECT * FROM destinations"
+ " WHERE retry_last_ts > 0 and retry_next_ts < ?"
+ )
+ txn.execute(query, (self._clock.time_msec(),))
+ return self.cursor_to_dict(txn)
-class ReceivedTransactionsTable(Table):
+
+class ReceivedTransactionsTable(object):
table_name = "received_transactions"
fields = [
@@ -304,10 +333,8 @@ class ReceivedTransactionsTable(Table):
"has_been_referenced",
]
- EntryType = namedtuple("ReceivedTransactionsEntry", fields)
-
-class SentTransactions(Table):
+class SentTransactions(object):
table_name = "sent_transactions"
fields = [
@@ -322,7 +349,7 @@ class SentTransactions(Table):
EntryType = namedtuple("SentTransactionsEntry", fields)
-class TransactionsToPduTable(Table):
+class TransactionsToPduTable(object):
table_name = "transaction_id_to_pdu"
fields = [
@@ -332,10 +359,8 @@ class TransactionsToPduTable(Table):
"pdu_origin",
]
- EntryType = namedtuple("TransactionsToPduEntry", fields)
-
-class DestinationsTable(Table):
+class DestinationsTable(object):
table_name = "destinations"
fields = [
@@ -343,5 +368,3 @@ class DestinationsTable(Table):
"retry_last_ts",
"retry_interval",
]
-
- EntryType = namedtuple("DestinationsEntry", fields)
diff --git a/synapse/storage/util/__init__.py b/synapse/storage/util/__init__.py
new file mode 100644
index 0000000000..c488b10d3c
--- /dev/null
+++ b/synapse/storage/util/__init__.py
@@ -0,0 +1,14 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014, 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
new file mode 100644
index 0000000000..e40eb8a8c4
--- /dev/null
+++ b/synapse/storage/util/id_generators.py
@@ -0,0 +1,131 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014, 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+from collections import deque
+import contextlib
+import threading
+
+
+class IdGenerator(object):
+ def __init__(self, table, column, store):
+ self.table = table
+ self.column = column
+ self.store = store
+ self._lock = threading.Lock()
+ self._next_id = None
+
+ @defer.inlineCallbacks
+ def get_next(self):
+ if self._next_id is None:
+ yield self.store.runInteraction(
+ "IdGenerator_%s" % (self.table,),
+ self.get_next_txn,
+ )
+
+ with self._lock:
+ i = self._next_id
+ self._next_id += 1
+ defer.returnValue(i)
+
+ def get_next_txn(self, txn):
+ with self._lock:
+ if self._next_id:
+ i = self._next_id
+ self._next_id += 1
+ return i
+ else:
+ txn.execute(
+ "SELECT MAX(%s) FROM %s" % (self.column, self.table,)
+ )
+
+ val, = txn.fetchone()
+ cur = val or 0
+ cur += 1
+ self._next_id = cur + 1
+
+ return cur
+
+
+class StreamIdGenerator(object):
+ """Used to generate new stream ids when persisting events while keeping
+ track of which transactions have been completed.
+
+ This allows us to get the "current" stream id, i.e. the stream id such that
+ all ids less than or equal to it have completed. This handles the fact that
+ persistence of events can complete out of order.
+
+ Usage:
+ with stream_id_gen.get_next_txn(txn) as stream_id:
+ # ... persist event ...
+ """
+ def __init__(self):
+ self._lock = threading.Lock()
+
+ self._current_max = None
+ self._unfinished_ids = deque()
+
+ def get_next_txn(self, txn):
+ """
+ Usage:
+ with stream_id_gen.get_next_txn(txn) as stream_id:
+ # ... persist event ...
+ """
+ if not self._current_max:
+ self._get_or_compute_current_max(txn)
+
+ with self._lock:
+ self._current_max += 1
+ next_id = self._current_max
+
+ self._unfinished_ids.append(next_id)
+
+ @contextlib.contextmanager
+ def manager():
+ try:
+ yield next_id
+ finally:
+ with self._lock:
+ self._unfinished_ids.remove(next_id)
+
+ return manager()
+
+ @defer.inlineCallbacks
+ def get_max_token(self, store):
+ """Returns the maximum stream id such that all stream ids less than or
+ equal to it have been successfully persisted.
+ """
+ if not self._current_max:
+ yield store.runInteraction(
+ "_compute_current_max",
+ self._get_or_compute_current_max,
+ )
+
+ with self._lock:
+ if self._unfinished_ids:
+ defer.returnValue(self._unfinished_ids[0] - 1)
+
+ defer.returnValue(self._current_max)
+
+ def _get_or_compute_current_max(self, txn):
+ with self._lock:
+ txn.execute("SELECT MAX(stream_ordering) FROM events")
+ rows = txn.fetchall()
+ val, = rows[0]
+
+ self._current_max = int(val) if val else 1
+
+ return self._current_max
diff --git a/synapse/util/async.py b/synapse/util/async.py
index c4fe5d522f..d8febdb90c 100644
--- a/synapse/util/async.py
+++ b/synapse/util/async.py
@@ -32,3 +32,22 @@ def run_on_reactor():
iteration of the main loop
"""
return sleep(0)
+
+
+def create_observer(deferred):
+ """Creates a deferred that observes the result or failure of the given
+ deferred *without* affecting the given deferred.
+ """
+ d = defer.Deferred()
+
+ def callback(r):
+ d.callback(r)
+ return r
+
+ def errback(f):
+ d.errback(f)
+ return f
+
+ deferred.addCallbacks(callback, errback)
+
+ return d
diff --git a/synapse/util/lrucache.py b/synapse/util/lrucache.py
index 65d5792907..96163c90f1 100644
--- a/synapse/util/lrucache.py
+++ b/synapse/util/lrucache.py
@@ -14,6 +14,10 @@
# limitations under the License.
+from functools import wraps
+import threading
+
+
class LruCache(object):
"""Least-recently-used cache."""
# TODO(mjark) Add mutex for linked list for thread safety.
@@ -24,6 +28,16 @@ class LruCache(object):
PREV, NEXT, KEY, VALUE = 0, 1, 2, 3
+ lock = threading.Lock()
+
+ def synchronized(f):
+ @wraps(f)
+ def inner(*args, **kwargs):
+ with lock:
+ return f(*args, **kwargs)
+
+ return inner
+
def add_node(key, value):
prev_node = list_root
next_node = prev_node[NEXT]
@@ -51,6 +65,7 @@ class LruCache(object):
next_node[PREV] = prev_node
cache.pop(node[KEY], None)
+ @synchronized
def cache_get(key, default=None):
node = cache.get(key, None)
if node is not None:
@@ -59,6 +74,7 @@ class LruCache(object):
else:
return default
+ @synchronized
def cache_set(key, value):
node = cache.get(key, None)
if node is not None:
@@ -69,6 +85,7 @@ class LruCache(object):
if len(cache) > max_size:
delete_node(list_root[PREV])
+ @synchronized
def cache_set_default(key, value):
node = cache.get(key, None)
if node is not None:
@@ -79,6 +96,7 @@ class LruCache(object):
delete_node(list_root[PREV])
return value
+ @synchronized
def cache_pop(key, default=None):
node = cache.get(key, None)
if node:
@@ -87,15 +105,21 @@ class LruCache(object):
else:
return default
+ @synchronized
def cache_len():
return len(cache)
+ @synchronized
+ def cache_contains(key):
+ return key in cache
+
self.sentinel = object()
self.get = cache_get
self.set = cache_set
self.setdefault = cache_set_default
self.pop = cache_pop
self.len = cache_len
+ self.contains = cache_contains
def __getitem__(self, key):
result = self.get(key, self.sentinel)
@@ -114,3 +138,6 @@ class LruCache(object):
def __len__(self):
return self.len()
+
+ def __contains__(self, key):
+ return self.contains(key)
diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py
index 4e82232796..a42138f556 100644
--- a/synapse/util/retryutils.py
+++ b/synapse/util/retryutils.py
@@ -60,7 +60,7 @@ def get_retry_limiter(destination, clock, store, **kwargs):
if retry_timings:
retry_last_ts, retry_interval = (
- retry_timings.retry_last_ts, retry_timings.retry_interval
+ retry_timings["retry_last_ts"], retry_timings["retry_interval"]
)
now = int(clock.time_msec())
|