From 13a6517d89c0619a938321640f331571eba0edc9 Mon Sep 17 00:00:00 2001
From: Daniel Wagner-Hall
Date: Thu, 20 Aug 2015 16:01:29 +0100
Subject: s/by_token/by_access_token/g

We're about to have two kinds of token, access and refresh
---
 synapse/storage/registration.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index bf803f2c6e..0e404afb7c 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -132,10 +132,10 @@ class RegistrationStore(SQLBaseStore):
             user_id
         )
         for r in rows:
-            self.get_user_by_token.invalidate((r,))
+            self.get_user_by_access_token.invalidate((r,))
 
     @cached()
-    def get_user_by_token(self, token):
+    def get_user_by_access_token(self, token):
         """Get a user from the given access token.
 
         Args:
@@ -147,7 +147,7 @@ class RegistrationStore(SQLBaseStore):
             StoreError if no user was found.
         """
         return self.runInteraction(
-            "get_user_by_token",
+            "get_user_by_access_token",
             self._query_for_auth,
             token
         )

From cecbd636e94f4e900ef6d246b62698ff1c8ee352 Mon Sep 17 00:00:00 2001
From: Daniel Wagner-Hall
Date: Thu, 20 Aug 2015 16:21:35 +0100
Subject: /tokenrefresh POST endpoint

This allows refresh tokens to be exchanged for (access_token, refresh_token).
It also starts issuing them on login, though no clients currently interpret
them.
---
 synapse/handlers/auth.py                           | 35 ++++++++++--
 synapse/rest/client/v1/login.py                    |  6 ++-
 synapse/rest/client/v2_alpha/__init__.py           |  2 +
 synapse/rest/client/v2_alpha/tokenrefresh.py       | 56 +++++++++++++++++++
 synapse/storage/__init__.py                        |  2 +-
 synapse/storage/_base.py                           |  1 +
 synapse/storage/registration.py                    | 62 ++++++++++++++++++++
 synapse/storage/schema/delta/23/refresh_tokens.sql | 21 ++++++++
 tests/storage/test_registration.py                 | 55 +++++++++++++++++++
 9 files changed, 232 insertions(+), 8 deletions(-)
 create mode 100644 synapse/rest/client/v2_alpha/tokenrefresh.py
 create mode 100644 synapse/storage/schema/delta/23/refresh_tokens.sql

diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 0bf917efdd..65bd8189dc 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -279,15 +279,18 @@ class AuthHandler(BaseHandler):
             user_id (str): User ID
             password (str): Password
         Returns:
-            The access token for the user's session.
+            A tuple of:
+                The access token for the user's session.
+                The refresh token for the user's session.
         Raises:
             StoreError if there was a problem storing the token.
             LoginError if there was an authentication problem.
""" yield self._check_password(user_id, password) logger.info("Logging in user %s", user_id) - token = yield self.issue_access_token(user_id) - defer.returnValue(token) + access_token = yield self.issue_access_token(user_id) + refresh_token = yield self.issue_refresh_token(user_id) + defer.returnValue((access_token, refresh_token)) @defer.inlineCallbacks def _check_password(self, user_id, password): @@ -304,11 +307,16 @@ class AuthHandler(BaseHandler): @defer.inlineCallbacks def issue_access_token(self, user_id): - reg_handler = self.hs.get_handlers().registration_handler - access_token = reg_handler.generate_access_token(user_id) + access_token = self.generate_access_token(user_id) yield self.store.add_access_token_to_user(user_id, access_token) defer.returnValue(access_token) + @defer.inlineCallbacks + def issue_refresh_token(self, user_id): + refresh_token = self.generate_refresh_token(user_id) + yield self.store.add_refresh_token_to_user(user_id, refresh_token) + defer.returnValue(refresh_token) + def generate_access_token(self, user_id): macaroon = pymacaroons.Macaroon( location = self.hs.config.server_name, @@ -323,6 +331,23 @@ class AuthHandler(BaseHandler): return macaroon.serialize() + def generate_refresh_token(self, user_id): + m = self._generate_base_macaroon(user_id) + m.add_first_party_caveat("type = refresh") + # Important to add a nonce, because otherwise every refresh token for a + # user will be the same. + m.add_first_party_caveat("nonce = %s" % stringutils.random_string_with_symbols(16)) + return m.serialize() + + def _generate_base_macaroon(self, user_id): + macaroon = pymacaroons.Macaroon( + location = self.hs.config.server_name, + identifier = "key", + key = self.hs.config.macaroon_secret_key) + macaroon.add_first_party_caveat("gen = 1") + macaroon.add_first_party_caveat("user_id = %s" % (user_id,)) + return macaroon + @defer.inlineCallbacks def set_password(self, user_id, newpassword): password_hash = bcrypt.hashpw(newpassword, bcrypt.gensalt()) diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 694072693d..b963a38618 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -78,13 +78,15 @@ class LoginRestServlet(ClientV1RestServlet): login_submission["user"] = UserID.create( login_submission["user"], self.hs.hostname).to_string() - token = yield self.handlers.auth_handler.login_with_password( + auth_handler = self.handlers.auth_handler + access_token, refresh_token = yield auth_handler.login_with_password( user_id=login_submission["user"], password=login_submission["password"]) result = { "user_id": login_submission["user"], # may have changed - "access_token": token, + "access_token": access_token, + "refresh_token": refresh_token, "home_server": self.hs.hostname, } diff --git a/synapse/rest/client/v2_alpha/__init__.py b/synapse/rest/client/v2_alpha/__init__.py index 33f961e898..5831ff0e62 100644 --- a/synapse/rest/client/v2_alpha/__init__.py +++ b/synapse/rest/client/v2_alpha/__init__.py @@ -21,6 +21,7 @@ from . 
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 694072693d..b963a38618 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -78,13 +78,15 @@ class LoginRestServlet(ClientV1RestServlet):
             login_submission["user"] = UserID.create(
                 login_submission["user"], self.hs.hostname).to_string()
 
-        token = yield self.handlers.auth_handler.login_with_password(
+        auth_handler = self.handlers.auth_handler
+        access_token, refresh_token = yield auth_handler.login_with_password(
             user_id=login_submission["user"],
             password=login_submission["password"])
 
         result = {
             "user_id": login_submission["user"],  # may have changed
-            "access_token": token,
+            "access_token": access_token,
+            "refresh_token": refresh_token,
             "home_server": self.hs.hostname,
         }
 
diff --git a/synapse/rest/client/v2_alpha/__init__.py b/synapse/rest/client/v2_alpha/__init__.py
index 33f961e898..5831ff0e62 100644
--- a/synapse/rest/client/v2_alpha/__init__.py
+++ b/synapse/rest/client/v2_alpha/__init__.py
@@ -21,6 +21,7 @@ from . import (
     auth,
     receipts,
     keys,
+    tokenrefresh,
 )
 
 from synapse.http.server import JsonResource
@@ -42,3 +43,4 @@ class ClientV2AlphaRestResource(JsonResource):
         auth.register_servlets(hs, client_resource)
         receipts.register_servlets(hs, client_resource)
         keys.register_servlets(hs, client_resource)
+        tokenrefresh.register_servlets(hs, client_resource)
diff --git a/synapse/rest/client/v2_alpha/tokenrefresh.py b/synapse/rest/client/v2_alpha/tokenrefresh.py
new file mode 100644
index 0000000000..901e777983
--- /dev/null
+++ b/synapse/rest/client/v2_alpha/tokenrefresh.py
@@ -0,0 +1,56 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+from synapse.api.errors import AuthError, StoreError, SynapseError
+from synapse.http.servlet import RestServlet
+
+from ._base import client_v2_pattern, parse_json_dict_from_request
+
+
+class TokenRefreshRestServlet(RestServlet):
+    """
+    Exchanges refresh tokens for a pair of an access token and a new refresh
+    token.
+    """
+    PATTERN = client_v2_pattern("/tokenrefresh")
+
+    def __init__(self, hs):
+        super(TokenRefreshRestServlet, self).__init__()
+        self.hs = hs
+        self.store = hs.get_datastore()
+
+    @defer.inlineCallbacks
+    def on_POST(self, request):
+        body = parse_json_dict_from_request(request)
+        try:
+            old_refresh_token = body["refresh_token"]
+            auth_handler = self.hs.get_handlers().auth_handler
+            (user_id, new_refresh_token) = yield self.store.exchange_refresh_token(
+                old_refresh_token, auth_handler.generate_refresh_token)
+            new_access_token = yield auth_handler.issue_access_token(user_id)
+            defer.returnValue((200, {
+                "access_token": new_access_token,
+                "refresh_token": new_refresh_token,
+            }))
+        except KeyError:
+            raise SynapseError(400, "Missing required key 'refresh_token'.")
+        except StoreError:
+            raise AuthError(403, "Did not recognize refresh token")
+
+
+def register_servlets(hs, http_server):
+    TokenRefreshRestServlet(hs).register(http_server)
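
Taken together with the login change above, the intended client flow looks roughly like this - a sketch with a hypothetical homeserver URL and the requests library standing in for any HTTP client (client_v2_pattern mounts the servlet under /_matrix/client/v2_alpha):

    import requests

    # 1. Log in once with a password; the response now carries both tokens.
    login = requests.post(
        "http://localhost:8008/_matrix/client/api/v1/login",
        json={"type": "m.login.password", "user": "@alice:example.com",
              "password": "s3cret"},
    ).json()

    # 2. Later, trade the (single-use) refresh token for a fresh pair.
    refreshed = requests.post(
        "http://localhost:8008/_matrix/client/v2_alpha/tokenrefresh",
        json={"refresh_token": login["refresh_token"]},
    ).json()

    access_token = refreshed["access_token"]    # use on authenticated calls
    refresh_token = refreshed["refresh_token"]  # keep for the next exchange
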
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index f154b1c8ae..53673b3bf5 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -54,7 +54,7 @@ logger = logging.getLogger(__name__)
 
 # Remember to update this number every time a change is made to database
 # schema files, so the users will be informed on server restarts.
-SCHEMA_VERSION = 22
+SCHEMA_VERSION = 23
 
 dir_path = os.path.abspath(os.path.dirname(__file__))
 
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 1444767a52..ce71389f02 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -181,6 +181,7 @@ class SQLBaseStore(object):
         self._transaction_id_gen = IdGenerator("sent_transactions", "id", self)
         self._state_groups_id_gen = IdGenerator("state_groups", "id", self)
         self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self)
+        self._refresh_tokens_id_gen = IdGenerator("refresh_tokens", "id", self)
         self._pushers_id_gen = IdGenerator("pushers", "id", self)
         self._push_rule_id_gen = IdGenerator("push_rules", "id", self)
         self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self)
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 0e404afb7c..f632306688 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -50,6 +50,28 @@ class RegistrationStore(SQLBaseStore):
             desc="add_access_token_to_user",
         )
 
+    @defer.inlineCallbacks
+    def add_refresh_token_to_user(self, user_id, token):
+        """Adds a refresh token for the given user.
+
+        Args:
+            user_id (str): The user ID.
+            token (str): The new refresh token to add.
+        Raises:
+            StoreError if there was a problem adding this.
+        """
+        next_id = yield self._refresh_tokens_id_gen.get_next()
+
+        yield self._simple_insert(
+            "refresh_tokens",
+            {
+                "id": next_id,
+                "user_id": user_id,
+                "token": token
+            },
+            desc="add_refresh_token_to_user",
+        )
+
     @defer.inlineCallbacks
     def register(self, user_id, token, password_hash):
         """Attempts to register an account.
@@ -152,6 +174,46 @@ class RegistrationStore(SQLBaseStore):
             token
         )
 
+    def exchange_refresh_token(self, refresh_token, token_generator):
+        """Exchange a refresh token for a new access token and refresh token.
+
+        Doing so invalidates the old refresh token - refresh tokens are single
+        use.
+
+        Args:
+            refresh_token (str): The refresh token of a user.
+            token_generator (fn: str -> str): Function which, when given a
+                user ID, returns a unique refresh token for that user. This
+                function must never return the same value twice.
+        Returns:
+            tuple of (user_id, refresh_token)
+        Raises:
+            StoreError if no user was found with that refresh token.
+        """
+        return self.runInteraction(
+            "exchange_refresh_token",
+            self._exchange_refresh_token,
+            refresh_token,
+            token_generator
+        )
+
+    def _exchange_refresh_token(self, txn, old_token, token_generator):
+        sql = "SELECT user_id FROM refresh_tokens WHERE token = ?"
+        txn.execute(sql, (old_token,))
+        rows = self.cursor_to_dict(txn)
+        if not rows:
+            raise StoreError(403, "Did not recognize refresh token")
+        user_id = rows[0]["user_id"]
+
+        # TODO(danielwh): Maybe perform a validation on the macaroon that
+        # macaroon.user_id == user_id.
+
+        new_token = token_generator(user_id)
+        sql = "UPDATE refresh_tokens SET token = ? WHERE token = ?"
+        txn.execute(sql, (new_token, old_token,))
+
+        return user_id, new_token
+
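
The UPDATE keyed on the old token is what makes refresh tokens single use: once the row is rewritten, a replayed old token matches nothing. A standalone sqlite3 sketch of that invariant (toy schema and values, independent of synapse):

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.execute(
        "CREATE TABLE refresh_tokens ("
        " id INTEGER PRIMARY KEY, token TEXT NOT NULL, user_id TEXT NOT NULL,"
        " UNIQUE (token))"
    )
    conn.execute(
        "INSERT INTO refresh_tokens (token, user_id) VALUES ('old', '@u:example.com')"
    )
    # The exchange: rotate the token in place.
    conn.execute("UPDATE refresh_tokens SET token = ? WHERE token = ?", ("new", "old"))
    # Replaying the old token now finds no row...
    assert conn.execute(
        "SELECT user_id FROM refresh_tokens WHERE token = 'old'"
    ).fetchone() is None
    # ...while the new one maps back to the same user.
    assert conn.execute(
        "SELECT user_id FROM refresh_tokens WHERE token = 'new'"
    ).fetchone() == ("@u:example.com",)
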
 
     @defer.inlineCallbacks
     def is_server_admin(self, user):
         res = yield self._simple_select_one_onecol(
diff --git a/synapse/storage/schema/delta/23/refresh_tokens.sql b/synapse/storage/schema/delta/23/refresh_tokens.sql
new file mode 100644
index 0000000000..46839e016c
--- /dev/null
+++ b/synapse/storage/schema/delta/23/refresh_tokens.sql
@@ -0,0 +1,21 @@
+/* Copyright 2015 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE IF NOT EXISTS refresh_tokens(
+    id INTEGER PRIMARY KEY AUTOINCREMENT,
+    token TEXT NOT NULL,
+    user_id TEXT NOT NULL,
+    UNIQUE (token)
+);
diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
index 7a24cf898a..a4f929796a 100644
--- a/tests/storage/test_registration.py
+++ b/tests/storage/test_registration.py
@@ -17,7 +17,9 @@
 from tests import unittest
 from twisted.internet import defer
 
+from synapse.api.errors import StoreError
 from synapse.storage.registration import RegistrationStore
+from synapse.util import stringutils
 
 from tests.utils import setup_test_homeserver
 
@@ -27,6 +29,7 @@ class RegistrationStoreTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def setUp(self):
         hs = yield setup_test_homeserver()
+        self.db_pool = hs.get_db_pool()
 
         self.store = RegistrationStore(hs)
 
@@ -77,3 +80,55 @@ class RegistrationStoreTestCase(unittest.TestCase):
 
         self.assertTrue("token_id" in result)
 
+    @defer.inlineCallbacks
+    def test_exchange_refresh_token_valid(self):
+        uid = stringutils.random_string(32)
+        generator = TokenGenerator()
+        last_token = generator.generate(uid)
+
+        self.db_pool.runQuery(
+            "INSERT INTO refresh_tokens(user_id, token) VALUES(?,?)",
+            (uid, last_token,))
+
+        (found_user_id, refresh_token) = yield self.store.exchange_refresh_token(
+            last_token, generator.generate)
+        self.assertEqual(uid, found_user_id)
+
+        rows = yield self.db_pool.runQuery(
+            "SELECT token FROM refresh_tokens WHERE user_id = ?", (uid, ))
+        self.assertEqual([(refresh_token,)], rows)
+        # We issued token 1, then exchanged it for token 2
+        expected_refresh_token = u"%s-%d" % (uid, 2,)
+        self.assertEqual(expected_refresh_token, refresh_token)
+
+    @defer.inlineCallbacks
+    def test_exchange_refresh_token_none(self):
+        uid = stringutils.random_string(32)
+        generator = TokenGenerator()
+        last_token = generator.generate(uid)
+
+        with self.assertRaises(StoreError):
+            yield self.store.exchange_refresh_token(last_token, generator.generate)
+
+    @defer.inlineCallbacks
+    def test_exchange_refresh_token_invalid(self):
+        uid = stringutils.random_string(32)
+        generator = TokenGenerator()
+        last_token = generator.generate(uid)
+        wrong_token = "%s-wrong" % (last_token,)
+
+        self.db_pool.runQuery(
+            "INSERT INTO refresh_tokens(user_id, token) VALUES(?,?)",
+            (uid, wrong_token,))
+
+        with self.assertRaises(StoreError):
+            yield self.store.exchange_refresh_token(last_token, generator.generate)
+
+
+class TokenGenerator:
+    def __init__(self):
+        self._last_issued_token = 0
+
+    def generate(self, user_id):
+        self._last_issued_token += 1
+        return u"%s-%d" % (user_id, self._last_issued_token,)

From 78323ccdb359404109bfcdd8b5bf6f641ba3ff9b Mon Sep 17 00:00:00 2001
From: Mark Haines
Date: Mon, 24 Aug 2015 16:17:38 +0100
Subject: Remove syutil dependency in favour of smaller single-purpose
 libraries
---
 synapse/config/key.py                      | 35 ++++++++++------------
 synapse/crypto/event_signing.py            |  9 ++++----
 synapse/crypto/keyring.py                  | 18 +++++++-------
 synapse/http/client.py                     |  3 ++-
 synapse/http/matrixfederationclient.py     |  4 ++--
 synapse/http/server.py                     | 12 +++++-----
 synapse/python_dependencies.py             |  9 +++-----
 synapse/rest/client/v1/voip.py             |  2 +-
 synapse/rest/client/v2_alpha/keys.py       |  3 ++-
 synapse/rest/key/v1/server_key_resource.py |  6 ++---
 synapse/rest/key/v2/local_key_resource.py  |  6 ++---
 synapse/storage/event_federation.py        |  2 +-
 synapse/storage/events.py                  | 23 ++++++++++----------
 synapse/storage/keys.py                    |  2 +-
 synapse/storage/pusher.py                  |  2 +-
 synapse/storage/signatures.py              |  2 +-
 synapse/storage/transactions.py            |  2 +-
 17 files changed, 69 insertions(+), 71 deletions(-)

diff --git a/synapse/config/key.py b/synapse/config/key.py
index 0494c0cb77..0f90bce04e 100644
--- a/synapse/config/key.py
+++ b/synapse/config/key.py
@@ -13,14 +13,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import os
 from ._base import Config, ConfigError
-import syutil.crypto.signing_key
-from syutil.crypto.signing_key import (
-    is_signing_algorithm_supported, decode_verify_key_bytes
-)
-from syutil.base64util import decode_base64
+
 from synapse.util.stringutils import random_string
+from signedjson.key import (
+    generate_signing_key, is_signing_algorithm_supported,
+    decode_signing_key_base64, decode_verify_key_bytes,
+    read_signing_keys, write_signing_keys, NACL_ED25519
+)
+from unpaddedbase64 import decode_base64
+
+import os
 
 
 class KeyConfig(Config):
@@ -83,9 +86,7 @@ class KeyConfig(Config):
     def read_signing_key(self, signing_key_path):
         signing_keys = self.read_file(signing_key_path, "signing_key")
         try:
-            return syutil.crypto.signing_key.read_signing_keys(
-                signing_keys.splitlines(True)
-            )
+            return read_signing_keys(signing_keys.splitlines(True))
         except Exception:
             raise ConfigError(
                 "Error reading signing_key."
@@ -112,22 +113,18 @@ class KeyConfig(Config):
         if not os.path.exists(signing_key_path):
             with open(signing_key_path, "w") as signing_key_file:
                 key_id = "a_" + random_string(4)
-                syutil.crypto.signing_key.write_signing_keys(
-                    signing_key_file,
-                    (syutil.crypto.signing_key.generate_signing_key(key_id),),
+                write_signing_keys(
+                    signing_key_file, (generate_signing_key(key_id),),
                 )
         else:
             signing_keys = self.read_file(signing_key_path, "signing_key")
             if len(signing_keys.split("\n")[0].split()) == 1:
                 # handle keys in the old format.
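
For orientation: the "old format" was a signing key file holding just a bare base64 seed, while the format read_signing_keys/write_signing_keys handle is three whitespace-separated fields per line - hence the one-field check above. A small sketch with a made-up key:

    # Hypothetical file contents, one key per line.
    new_style = "ed25519 a_AbCd 2mbSGOGhpnzXDAz9zFEqUF2Zw7X89BAXcsxJA7cEyRc"
    old_style = "2mbSGOGhpnzXDAz9zFEqUF2Zw7X89BAXcsxJA7cEyRc"
    assert len(new_style.split()) == 3
    assert len(old_style.split()) == 1  # the case the branch above rewrites
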
                key_id = "a_" + random_string(4)
-                key = syutil.crypto.signing_key.decode_signing_key_base64(
-                    syutil.crypto.signing_key.NACL_ED25519,
-                    key_id,
-                    signing_keys.split("\n")[0]
+                key = decode_signing_key_base64(
+                    NACL_ED25519, key_id, signing_keys.split("\n")[0]
                 )
                 with open(signing_key_path, "w") as signing_key_file:
-                    syutil.crypto.signing_key.write_signing_keys(
-                        signing_key_file,
-                        (key,),
+                    write_signing_keys(
+                        signing_key_file, (key,),
                     )
diff --git a/synapse/crypto/event_signing.py b/synapse/crypto/event_signing.py
index 6633b19565..64e40864af 100644
--- a/synapse/crypto/event_signing.py
+++ b/synapse/crypto/event_signing.py
@@ -15,11 +15,12 @@
 # limitations under the License.
 
 
-from synapse.events.utils import prune_event
-from syutil.jsonutil import encode_canonical_json
-from syutil.base64util import encode_base64, decode_base64
-from syutil.crypto.jsonsign import sign_json
 from synapse.api.errors import SynapseError, Codes
+from synapse.events.utils import prune_event
+
+from canonicaljson import encode_canonical_json
+from unpaddedbase64 import encode_base64, decode_base64
+from signedjson.sign import sign_json
 
 import hashlib
 import logging
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index aa74d4d0cb..a692cdbe55 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -14,20 +14,20 @@
 # limitations under the License.
 
 from synapse.crypto.keyclient import fetch_server_key
+from synapse.api.errors import SynapseError, Codes
+from synapse.util.retryutils import get_retry_limiter
+from synapse.util import unwrapFirstError
+from synapse.util.async import ObservableDeferred
+
 from twisted.internet import defer
-from syutil.crypto.jsonsign import (
+
+from signedjson.sign import (
     verify_signed_json, signature_ids, sign_json, encode_canonical_json
 )
-from syutil.crypto.signing_key import (
+from signedjson.key import (
     is_signing_algorithm_supported, decode_verify_key_bytes
 )
-from syutil.base64util import decode_base64, encode_base64
-from synapse.api.errors import SynapseError, Codes
-
-from synapse.util.retryutils import get_retry_limiter
-from synapse.util import unwrapFirstError
-
-from synapse.util.async import ObservableDeferred
+from unpaddedbase64 import decode_base64, encode_base64
 
 from OpenSSL import crypto
 
diff --git a/synapse/http/client.py b/synapse/http/client.py
index 49737d55da..4b8fd3d3a3 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -15,9 +15,10 @@
 
 from synapse.api.errors import CodeMessageException
 from synapse.util.logcontext import preserve_context_over_fn
-from syutil.jsonutil import encode_canonical_json
 import synapse.metrics
 
+from canonicaljson import encode_canonical_json
+
 from twisted.internet import defer, reactor
 from twisted.web.client import (
     Agent, readBody, FileBodyProducer, PartialDownloadError,
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 854e17a473..1c9e552788 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -25,13 +25,13 @@ from synapse.util.async import sleep
 from synapse.util.logcontext import preserve_context_over_fn
 import synapse.metrics
 
-from syutil.jsonutil import encode_canonical_json
+from canonicaljson import encode_canonical_json
 
 from synapse.api.errors import (
     SynapseError, Codes, HttpResponseException,
 )
 
-from syutil.crypto.jsonsign import sign_json
+from signedjson.sign import sign_json
 
 import simplejson as json
 import logging
diff --git a/synapse/http/server.py b/synapse/http/server.py
index b60e905a62..50feea6f1c 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -21,8 +21,8 @@ from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
 import synapse.metrics
 import synapse.events
 
-from syutil.jsonutil import (
-    encode_canonical_json, encode_pretty_printed_json, encode_json
+from canonicaljson import (
+    encode_canonical_json, encode_pretty_printed_json
 )
 
 from twisted.internet import defer
@@ -33,6 +33,7 @@ from twisted.web.util import redirectTo
 import collections
 import logging
 import urllib
+import ujson
 
 logger = logging.getLogger(__name__)
 
@@ -270,12 +271,11 @@ def respond_with_json(request, code, json_object, send_cors=False,
     if pretty_print:
         json_bytes = encode_pretty_printed_json(json_object) + "\n"
     else:
-        if canonical_json:
+        if canonical_json or synapse.events.USE_FROZEN_DICTS:
             json_bytes = encode_canonical_json(json_object)
         else:
-            json_bytes = encode_json(
-                json_object, using_frozen_dicts=synapse.events.USE_FROZEN_DICTS
-            )
+            # ujson doesn't like frozen_dicts.
+            json_bytes = ujson.dumps(json_object, ensure_ascii=False)
 
     return respond_with_json_bytes(
         request, code, json_bytes,
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index a87fdeb2a0..ced7ff0d7e 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -18,7 +18,9 @@ from distutils.version import LooseVersion
 logger = logging.getLogger(__name__)
 
 REQUIREMENTS = {
-    "syutil>=0.0.7": ["syutil>=0.0.7"],
+    "unpaddedbase64>=1.0.1": ["unpaddedbase64>=1.0.1"],
+    "canonicaljson>=1.0.0": ["canonicaljson>=1.0.0"],
+    "signedjson>=1.0.0": ["signedjson>=1.0.0"],
     "Twisted>=15.1.0": ["twisted>=15.1.0"],
     "service_identity>=1.0.0": ["service_identity>=1.0.0"],
     "pyopenssl>=0.14": ["OpenSSL>=0.14"],
@@ -54,11 +56,6 @@ def github_link(project, version, egg):
     return "https://github.com/%s/tarball/%s/#egg=%s" % (project, version, egg)
 
 DEPENDENCY_LINKS = [
-    github_link(
-        project="matrix-org/syutil",
-        version="v0.0.7",
-        egg="syutil-0.0.7",
-    ),
     github_link(
         project="matrix-org/matrix-angular-sdk",
         version="v0.6.6",
diff --git a/synapse/rest/client/v1/voip.py b/synapse/rest/client/v1/voip.py
index 11d08fbced..45ed5cceac 100644
--- a/synapse/rest/client/v1/voip.py
+++ b/synapse/rest/client/v1/voip.py
@@ -40,7 +40,7 @@ class VoipRestServlet(ClientV1RestServlet):
         username = "%d:%s" % (expiry, auth_user.to_string())
 
         mac = hmac.new(turnSecret, msg=username, digestmod=hashlib.sha1)
-        # We need to use standard base64 encoding here, *not* syutil's
-        # encode_base64 because we need to add the standard padding to get the
-        # same result as the TURN server.
+        # We need to use standard padded base64 encoding here, *not* an
+        # unpadded encoding, because we need the padding to get the same
+        # result as the TURN server.
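
The padding point in miniature (made-up secret and username): a SHA-1 HMAC digest is 20 bytes, so its base64 form carries one '=' of padding, which an unpadded encoder would drop - and the TURN server compares against the padded form.

    import base64
    import hashlib
    import hmac

    secret = b"turn-shared-secret"               # assumed value
    username = b"1440000000:@alice:example.com"  # "expiry:user_id", assumed
    digest = hmac.new(secret, msg=username, digestmod=hashlib.sha1).digest()

    padded = base64.b64encode(digest)    # what the TURN server computes
    assert padded.endswith(b"=")         # 20 bytes -> 28 chars incl. padding
    assert padded.rstrip(b"=") != padded # an unpadded variant would not match
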
         password = base64.b64encode(mac.digest())
 
diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py
index 718928eedd..21654fa2da 100644
--- a/synapse/rest/client/v2_alpha/keys.py
+++ b/synapse/rest/client/v2_alpha/keys.py
@@ -18,7 +18,8 @@ from twisted.internet import defer
 from synapse.api.errors import SynapseError
 from synapse.http.servlet import RestServlet
 from synapse.types import UserID
-from syutil.jsonutil import encode_canonical_json
+
+from canonicaljson import encode_canonical_json
 
 from ._base import client_v2_pattern
 
diff --git a/synapse/rest/key/v1/server_key_resource.py b/synapse/rest/key/v1/server_key_resource.py
index 71e9a51f5c..6df46969c4 100644
--- a/synapse/rest/key/v1/server_key_resource.py
+++ b/synapse/rest/key/v1/server_key_resource.py
@@ -16,9 +16,9 @@
 
 from twisted.web.resource import Resource
 from synapse.http.server import respond_with_json_bytes
-from syutil.crypto.jsonsign import sign_json
-from syutil.base64util import encode_base64
-from syutil.jsonutil import encode_canonical_json
+from signedjson.sign import sign_json
+from unpaddedbase64 import encode_base64
+from canonicaljson import encode_canonical_json
 from OpenSSL import crypto
 import logging
 
diff --git a/synapse/rest/key/v2/local_key_resource.py b/synapse/rest/key/v2/local_key_resource.py
index 33cbd7cf8e..ef7699d590 100644
--- a/synapse/rest/key/v2/local_key_resource.py
+++ b/synapse/rest/key/v2/local_key_resource.py
@@ -16,9 +16,9 @@
 
 from twisted.web.resource import Resource
 from synapse.http.server import respond_with_json_bytes
-from syutil.crypto.jsonsign import sign_json
-from syutil.base64util import encode_base64
-from syutil.jsonutil import encode_canonical_json
+from signedjson.sign import sign_json
+from unpaddedbase64 import encode_base64
+from canonicaljson import encode_canonical_json
 from hashlib import sha256
 from OpenSSL import crypto
 import logging
diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py
index 25cc84eb95..bc90e17c63 100644
--- a/synapse/storage/event_federation.py
+++ b/synapse/storage/event_federation.py
@@ -17,7 +17,7 @@ from twisted.internet import defer
 
 from ._base import SQLBaseStore
 from synapse.util.caches.descriptors import cached
-from syutil.base64util import encode_base64
+from unpaddedbase64 import encode_base64
 
 import logging
 from Queue import PriorityQueue, Empty
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index e3eabab13d..e7439321b8 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -24,7 +24,7 @@ from synapse.util.logcontext import preserve_context_over_deferred
 from synapse.util.logutils import log_function
 from synapse.api.constants import EventTypes
 
-from syutil.jsonutil import encode_json
+from canonicaljson import encode_canonical_json
 
 from contextlib import contextmanager
 import logging
@@ -33,6 +33,13 @@ import ujson as json
 logger = logging.getLogger(__name__)
 
 
+def encode_json(json_object):
+    if USE_FROZEN_DICTS:
+            # ujson doesn't like frozen_dicts
+        return encode_canonical_json(json_object)
+    else:
+        return json.dumps(json_object, ensure_ascii=False)
+
+
 # These values are used in the `enqueus_event` and `_do_fetch` methods to
 # control how we batch/bulk fetch events from the database.
 # The values are plucked out of thin air to make initial sync run faster
@@ -253,8 +260,7 @@ class EventsStore(SQLBaseStore):
         )
 
         metadata_json = encode_json(
-            event.internal_metadata.get_dict(),
-            using_frozen_dicts=USE_FROZEN_DICTS
+            event.internal_metadata.get_dict()
         ).decode("UTF-8")
 
         sql = (
@@ -329,12 +335,9 @@ class EventsStore(SQLBaseStore):
                     "event_id": event.event_id,
                     "room_id": event.room_id,
                     "internal_metadata": encode_json(
-                        event.internal_metadata.get_dict(),
-                        using_frozen_dicts=USE_FROZEN_DICTS
-                    ).decode("UTF-8"),
-                    "json": encode_json(
-                        event_dict(event), using_frozen_dicts=USE_FROZEN_DICTS
+                        event.internal_metadata.get_dict()
                     ).decode("UTF-8"),
+                    "json": encode_json(event_dict(event)).decode("UTF-8"),
                 }
                 for event, _ in events_and_contexts
             ],
@@ -353,9 +356,7 @@ class EventsStore(SQLBaseStore):
                     "type": event.type,
                     "processed": True,
                     "outlier": event.internal_metadata.is_outlier(),
-                    "content": encode_json(
-                        event.content, using_frozen_dicts=USE_FROZEN_DICTS
-                    ).decode("UTF-8"),
+                    "content": encode_json(event.content).decode("UTF-8"),
                 }
                 for event, _ in events_and_contexts
             ],
diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py
index ffd6daa880..344cacdc75 100644
--- a/synapse/storage/keys.py
+++ b/synapse/storage/keys.py
@@ -19,7 +19,7 @@ from synapse.util.caches.descriptors import cachedInlineCallbacks
 from twisted.internet import defer
 
 import OpenSSL
-from syutil.crypto.signing_key import decode_verify_key_bytes
+from signedjson.key import decode_verify_key_bytes
 import hashlib
 
diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py
index 08ea62681b..00b748f131 100644
--- a/synapse/storage/pusher.py
+++ b/synapse/storage/pusher.py
@@ -18,7 +18,7 @@ from twisted.internet import defer
 
 from synapse.api.errors import StoreError
 
-from syutil.jsonutil import encode_canonical_json
+from canonicaljson import encode_canonical_json
 
 import logging
 import simplejson as json
diff --git a/synapse/storage/signatures.py b/synapse/storage/signatures.py
index 4f15e534b4..ab57b92174 100644
--- a/synapse/storage/signatures.py
+++ b/synapse/storage/signatures.py
@@ -17,7 +17,7 @@ from twisted.internet import defer
 
 from _base import SQLBaseStore
 
-from syutil.base64util import encode_base64
+from unpaddedbase64 import encode_base64
 
 from synapse.crypto.event_signing import compute_event_reference_hash
 
diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py
index c8c7e6591a..15695e9831 100644
--- a/synapse/storage/transactions.py
+++ b/synapse/storage/transactions.py
@@ -18,7 +18,7 @@ from synapse.util.caches.descriptors import cached
 
 from collections import namedtuple
 
-from syutil.jsonutil import encode_canonical_json
+from canonicaljson import encode_canonical_json
 
 import logging
 
 logger = logging.getLogger(__name__)

From 01fc3943f1aa409766f8fc54037af9102c55c658 Mon Sep 17 00:00:00 2001
From: Mark Haines
Date: Mon, 24 Aug 2015 17:18:58 +0100
Subject: Fix indent
---
 synapse/storage/events.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index e7439321b8..fba837f461 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -35,7 +35,7 @@ logger = logging.getLogger(__name__)
 
 def encode_json(json_object):
     if USE_FROZEN_DICTS:
-            # ujson doesn't like frozen_dicts
+        # ujson doesn't like frozen_dicts
         return encode_canonical_json(json_object)
     else:
         return json.dumps(json_object, ensure_ascii=False)

From 
037481a033aeb80b86a2e7e074e29cfaf8c23ea8 Mon Sep 17 00:00:00 2001
From: Mark Haines
Date: Mon, 24 Aug 2015 17:48:57 +0100
Subject: Remove autoincrement since we're incrementing the ID in the storage
 layer
---
 synapse/storage/schema/delta/23/refresh_tokens.sql | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/synapse/storage/schema/delta/23/refresh_tokens.sql b/synapse/storage/schema/delta/23/refresh_tokens.sql
index 46839e016c..437b1ac1be 100644
--- a/synapse/storage/schema/delta/23/refresh_tokens.sql
+++ b/synapse/storage/schema/delta/23/refresh_tokens.sql
@@ -14,7 +14,7 @@
 */
 
 CREATE TABLE IF NOT EXISTS refresh_tokens(
-    id INTEGER PRIMARY KEY AUTOINCREMENT,
+    id INTEGER PRIMARY KEY,
     token TEXT NOT NULL,
     user_id TEXT NOT NULL,
     UNIQUE (token)
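
Two notes on this change, sketched below with plain sqlite3 (toy values, nothing synapse-specific): in SQLite an INTEGER PRIMARY KEY column aliases the rowid and is auto-assigned even without AUTOINCREMENT, and the storage layer supplies explicit ids from IdGenerator in any case.

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.execute(
        "CREATE TABLE refresh_tokens ("
        " id INTEGER PRIMARY KEY, token TEXT NOT NULL, user_id TEXT NOT NULL,"
        " UNIQUE (token))"
    )
    # Without an explicit id, SQLite still assigns one...
    conn.execute("INSERT INTO refresh_tokens (token, user_id) VALUES ('t1', '@u:x')")
    # ...and an id chosen by the application (as IdGenerator does) also works.
    conn.execute("INSERT INTO refresh_tokens VALUES (7, 't2', '@u:x')")
    ids = [r[0] for r in conn.execute("SELECT id FROM refresh_tokens ORDER BY id")]
    assert ids == [1, 7]
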
""" @@ -355,7 +355,7 @@ class Auth(object): request.authenticated_entity = user_id defer.returnValue( - (UserID.from_string(user_id), ClientInfo("", "")) + (UserID.from_string(user_id), "") ) return except KeyError: @@ -363,7 +363,6 @@ class Auth(object): user_info = yield self.get_user_by_access_token(access_token) user = user_info["user"] - device_id = user_info["device_id"] token_id = user_info["token_id"] ip_addr = self.hs.get_ip_from_request(request) @@ -375,14 +374,13 @@ class Auth(object): self.store.insert_client_ip( user=user, access_token=access_token, - device_id=user_info["device_id"], ip=ip_addr, user_agent=user_agent ) request.authenticated_entity = user.to_string() - defer.returnValue((user, ClientInfo(device_id, token_id))) + defer.returnValue((user, token_id,)) except KeyError: raise AuthError( self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.", @@ -396,7 +394,7 @@ class Auth(object): Args: token (str): The access token to get the user by. Returns: - dict : dict that includes the user, device_id, and whether the + dict : dict that includes the user and whether the user is a server admin. Raises: AuthError if no user by that token exists or the token is invalid. @@ -409,7 +407,6 @@ class Auth(object): ) user_info = { "admin": bool(ret.get("admin", False)), - "device_id": ret.get("device_id"), "user": UserID.from_string(ret.get("name")), "token_id": ret.get("token_id", None), } diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index 1c9e7152c7..d852a18555 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -34,6 +34,7 @@ class AdminHandler(BaseHandler): d = {} for r in res: + # Note that device_id is always None device = d.setdefault(r["device_id"], {}) session = device.setdefault(r["access_token"], []) session.append({ diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index f12465fa2c..23b779ad7c 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -183,7 +183,7 @@ class MessageHandler(BaseHandler): @defer.inlineCallbacks def create_and_send_event(self, event_dict, ratelimit=True, - client=None, txn_id=None): + token_id=None, txn_id=None): """ Given a dict from a client, create and handle a new event. 
         Creates an FrozenEvent object, filling out auth_events, prev_events,
@@ -217,11 +217,8 @@ class MessageHandler(BaseHandler):
             builder.content
         )
 
-        if client is not None:
-            if client.token_id is not None:
-                builder.internal_metadata.token_id = client.token_id
-            if client.device_id is not None:
-                builder.internal_metadata.device_id = client.device_id
+        if token_id is not None:
+            builder.internal_metadata.token_id = token_id
 
         if txn_id is not None:
             builder.internal_metadata.txn_id = txn_id
diff --git a/synapse/rest/client/v1/admin.py b/synapse/rest/client/v1/admin.py
index 2ce754b028..504b63eab4 100644
--- a/synapse/rest/client/v1/admin.py
+++ b/synapse/rest/client/v1/admin.py
@@ -31,7 +31,7 @@ class WhoisRestServlet(ClientV1RestServlet):
     @defer.inlineCallbacks
     def on_GET(self, request, user_id):
         target_user = UserID.from_string(user_id)
-        auth_user, client = yield self.auth.get_user_by_req(request)
+        auth_user, _ = yield self.auth.get_user_by_req(request)
         is_admin = yield self.auth.is_server_admin(auth_user)
 
         if not is_admin and target_user != auth_user:
diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py
index 6758a888b3..4dcda57c1b 100644
--- a/synapse/rest/client/v1/directory.py
+++ b/synapse/rest/client/v1/directory.py
@@ -69,7 +69,7 @@ class ClientDirectoryServer(ClientV1RestServlet):
 
         try:
             # try to auth as a user
-            user, client = yield self.auth.get_user_by_req(request)
+            user, _ = yield self.auth.get_user_by_req(request)
             try:
                 user_id = user.to_string()
                 yield dir_handler.create_association(
@@ -116,7 +116,7 @@ class ClientDirectoryServer(ClientV1RestServlet):
             # fallback to default user behaviour if they aren't an AS
             pass
 
-        user, client = yield self.auth.get_user_by_req(request)
+        user, _ = yield self.auth.get_user_by_req(request)
 
         is_admin = yield self.auth.is_server_admin(user)
         if not is_admin:
diff --git a/synapse/rest/client/v1/events.py b/synapse/rest/client/v1/events.py
index 77b7c25a03..582148b659 100644
--- a/synapse/rest/client/v1/events.py
+++ b/synapse/rest/client/v1/events.py
@@ -34,7 +34,7 @@ class EventStreamRestServlet(ClientV1RestServlet):
 
     @defer.inlineCallbacks
     def on_GET(self, request):
-        auth_user, client = yield self.auth.get_user_by_req(request)
+        auth_user, _ = yield self.auth.get_user_by_req(request)
         try:
             handler = self.handlers.event_stream_handler
             pagin_config = PaginationConfig.from_request(request)
@@ -71,7 +71,7 @@ class EventRestServlet(ClientV1RestServlet):
 
     @defer.inlineCallbacks
     def on_GET(self, request, event_id):
-        auth_user, client = yield self.auth.get_user_by_req(request)
+        auth_user, _ = yield self.auth.get_user_by_req(request)
         handler = self.handlers.event_handler
 
         event = yield handler.get_event(auth_user, event_id)
diff --git a/synapse/rest/client/v1/initial_sync.py b/synapse/rest/client/v1/initial_sync.py
index 4a259bba64..4ea4da653c 100644
--- a/synapse/rest/client/v1/initial_sync.py
+++ b/synapse/rest/client/v1/initial_sync.py
@@ -25,7 +25,7 @@ class InitialSyncRestServlet(ClientV1RestServlet):
 
     @defer.inlineCallbacks
     def on_GET(self, request):
-        user, client = yield self.auth.get_user_by_req(request)
+        user, _ = yield self.auth.get_user_by_req(request)
         with_feedback = "feedback" in request.args
         as_client_event = "raw" not in request.args
         pagination_config = PaginationConfig.from_request(request)
diff --git a/synapse/rest/client/v1/presence.py b/synapse/rest/client/v1/presence.py
index 78d4f2b128..a770efd841 100644
--- a/synapse/rest/client/v1/presence.py
+++ b/synapse/rest/client/v1/presence.py
@@ -32,7 +32,7 @@ class PresenceStatusRestServlet(ClientV1RestServlet):
     @defer.inlineCallbacks
     def on_GET(self, request, user_id):
-        auth_user, client = yield self.auth.get_user_by_req(request)
+        auth_user, _ = yield self.auth.get_user_by_req(request)
         user = UserID.from_string(user_id)
 
         state = yield self.handlers.presence_handler.get_state(
@@ -42,7 +42,7 @@ class PresenceStatusRestServlet(ClientV1RestServlet):
 
     @defer.inlineCallbacks
     def on_PUT(self, request, user_id):
-        auth_user, client = yield self.auth.get_user_by_req(request)
+        auth_user, _ = yield self.auth.get_user_by_req(request)
         user = UserID.from_string(user_id)
 
         state = {}
@@ -77,7 +77,7 @@ class PresenceListRestServlet(ClientV1RestServlet):
 
     @defer.inlineCallbacks
     def on_GET(self, request, user_id):
-        auth_user, client = yield self.auth.get_user_by_req(request)
+        auth_user, _ = yield self.auth.get_user_by_req(request)
         user = UserID.from_string(user_id)
 
         if not self.hs.is_mine(user):
@@ -97,7 +97,7 @@ class PresenceListRestServlet(ClientV1RestServlet):
 
     @defer.inlineCallbacks
     def on_POST(self, request, user_id):
-        auth_user, client = yield self.auth.get_user_by_req(request)
+        auth_user, _ = yield self.auth.get_user_by_req(request)
         user = UserID.from_string(user_id)
 
         if not self.hs.is_mine(user):
diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py
index 1e77eb49cf..fdde88a60d 100644
--- a/synapse/rest/client/v1/profile.py
+++ b/synapse/rest/client/v1/profile.py
@@ -37,7 +37,7 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet):
 
     @defer.inlineCallbacks
     def on_PUT(self, request, user_id):
-        auth_user, client = yield self.auth.get_user_by_req(request)
+        auth_user, _ = yield self.auth.get_user_by_req(request)
         user = UserID.from_string(user_id)
 
         try:
@@ -70,7 +70,7 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet):
 
     @defer.inlineCallbacks
     def on_PUT(self, request, user_id):
-        auth_user, client = yield self.auth.get_user_by_req(request)
+        auth_user, _ = yield self.auth.get_user_by_req(request)
         user = UserID.from_string(user_id)
 
         try:
diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py
index c83287c028..3aabc93b8b 100644
--- a/synapse/rest/client/v1/pusher.py
+++ b/synapse/rest/client/v1/pusher.py
@@ -27,7 +27,7 @@ class PusherRestServlet(ClientV1RestServlet):
 
     @defer.inlineCallbacks
     def on_POST(self, request):
-        user, client = yield self.auth.get_user_by_req(request)
+        user, token_id = yield self.auth.get_user_by_req(request)
 
         content = _parse_json(request)
 
@@ -65,7 +65,7 @@ class PusherRestServlet(ClientV1RestServlet):
         try:
             yield pusher_pool.add_pusher(
                 user_name=user.to_string(),
-                access_token=client.token_id,
+                access_token=token_id,
                 profile_tag=content['profile_tag'],
                 kind=content['kind'],
                 app_id=content['app_id'],
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index b4a70cba99..c9c27dd5a0 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -62,7 +62,7 @@ class RoomCreateRestServlet(ClientV1RestServlet):
 
     @defer.inlineCallbacks
     def on_POST(self, request):
-        auth_user, client = yield self.auth.get_user_by_req(request)
+        auth_user, _ = yield self.auth.get_user_by_req(request)
 
         room_config = self.get_room_config(request)
         info = yield self.make_room(room_config, auth_user, None)
@@ -125,7 +125,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
 
     @defer.inlineCallbacks
     def on_GET(self, request, room_id, event_type, state_key):
-        user, client = yield self.auth.get_user_by_req(request)
+        user, _ = yield self.auth.get_user_by_req(request)
 
         msg_handler = self.handlers.message_handler
         data = yield msg_handler.get_room_data(
@@ -143,7 +143,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
 
     @defer.inlineCallbacks
     def on_PUT(self, request, room_id, event_type, state_key, txn_id=None):
-        user, client = yield self.auth.get_user_by_req(request)
+        user, token_id = yield self.auth.get_user_by_req(request)
 
         content = _parse_json(request)
 
@@ -159,7 +159,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
 
         msg_handler = self.handlers.message_handler
         yield msg_handler.create_and_send_event(
-            event_dict, client=client, txn_id=txn_id,
+            event_dict, token_id=token_id, txn_id=txn_id,
         )
 
         defer.returnValue((200, {}))
@@ -175,7 +175,7 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
 
     @defer.inlineCallbacks
     def on_POST(self, request, room_id, event_type, txn_id=None):
-        user, client = yield self.auth.get_user_by_req(request)
+        user, token_id = yield self.auth.get_user_by_req(request)
         content = _parse_json(request)
 
         msg_handler = self.handlers.message_handler
@@ -186,7 +186,7 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
                 "room_id": room_id,
                 "sender": user.to_string(),
             },
-            client=client,
+            token_id=token_id,
             txn_id=txn_id,
         )
 
@@ -220,7 +220,7 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
 
     @defer.inlineCallbacks
     def on_POST(self, request, room_identifier, txn_id=None):
-        user, client = yield self.auth.get_user_by_req(request)
+        user, token_id = yield self.auth.get_user_by_req(request)
 
         # the identifier could be a room alias or a room id. Try one then the
         # other if it fails to parse, without swallowing other valid
@@ -250,7 +250,7 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
                     "sender": user.to_string(),
                     "state_key": user.to_string(),
                 },
-                client=client,
+                token_id=token_id,
                 txn_id=txn_id,
             )
 
@@ -289,7 +289,7 @@ class RoomMemberListRestServlet(ClientV1RestServlet):
     @defer.inlineCallbacks
     def on_GET(self, request, room_id):
         # TODO support Pagination stream API (limit/tokens)
-        user, client = yield self.auth.get_user_by_req(request)
+        user, _ = yield self.auth.get_user_by_req(request)
         handler = self.handlers.room_member_handler
         members = yield handler.get_room_members_as_pagination_chunk(
             room_id=room_id,
@@ -317,7 +317,7 @@ class RoomMessageListRestServlet(ClientV1RestServlet):
 
     @defer.inlineCallbacks
     def on_GET(self, request, room_id):
-        user, client = yield self.auth.get_user_by_req(request)
+        user, _ = yield self.auth.get_user_by_req(request)
         pagination_config = PaginationConfig.from_request(
             request, default_limit=10,
         )
@@ -341,7 +341,7 @@ class RoomStateRestServlet(ClientV1RestServlet):
 
     @defer.inlineCallbacks
     def on_GET(self, request, room_id):
-        user, client = yield self.auth.get_user_by_req(request)
+        user, _ = yield self.auth.get_user_by_req(request)
         handler = self.handlers.message_handler
         # Get all the current state for this room
         events = yield handler.get_state_events(
@@ -357,7 +357,7 @@ class RoomInitialSyncRestServlet(ClientV1RestServlet):
 
     @defer.inlineCallbacks
     def on_GET(self, request, room_id):
-        user, client = yield self.auth.get_user_by_req(request)
+        user, _ = yield self.auth.get_user_by_req(request)
         pagination_config = PaginationConfig.from_request(request)
         content = yield self.handlers.message_handler.room_initial_sync(
             room_id=room_id,
@@ -402,7 +402,7 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
 
     @defer.inlineCallbacks
     def on_POST(self, request, room_id, membership_action, txn_id=None):
-        user, client = yield self.auth.get_user_by_req(request)
+        user, token_id = yield self.auth.get_user_by_req(request)
 
         content = _parse_json(request)
 
@@ -427,7 +427,7 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
                 "sender": user.to_string(),
                 "state_key": state_key,
             },
-            client=client,
+            token_id=token_id,
             txn_id=txn_id,
         )
 
@@ -457,7 +457,7 @@ class RoomRedactEventRestServlet(ClientV1RestServlet):
 
     @defer.inlineCallbacks
     def on_POST(self, request, room_id, event_id, txn_id=None):
-        user, client = yield self.auth.get_user_by_req(request)
+        user, token_id = yield self.auth.get_user_by_req(request)
         content = _parse_json(request)
 
         msg_handler = self.handlers.message_handler
@@ -469,7 +469,7 @@ class RoomRedactEventRestServlet(ClientV1RestServlet):
                 "sender": user.to_string(),
                 "redacts": event_id,
             },
-            client=client,
+            token_id=token_id,
             txn_id=txn_id,
         )
 
@@ -497,7 +497,7 @@ class RoomTypingRestServlet(ClientV1RestServlet):
 
     @defer.inlineCallbacks
     def on_PUT(self, request, room_id, user_id):
-        auth_user, client = yield self.auth.get_user_by_req(request)
+        auth_user, _ = yield self.auth.get_user_by_req(request)
 
         room_id = urllib.unquote(room_id)
         target_user = UserID.from_string(urllib.unquote(user_id))
diff --git a/synapse/rest/client/v1/voip.py b/synapse/rest/client/v1/voip.py
index 11d08fbced..4ae2d81b70 100644
--- a/synapse/rest/client/v1/voip.py
+++ b/synapse/rest/client/v1/voip.py
@@ -28,7 +28,7 @@ class VoipRestServlet(ClientV1RestServlet):
 
     @defer.inlineCallbacks
     def on_GET(self, request):
-        auth_user, client = yield self.auth.get_user_by_req(request)
+        auth_user, _ = yield self.auth.get_user_by_req(request)
 
         turnUris = self.hs.config.turn_uris
         turnSecret = self.hs.config.turn_shared_secret
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index 522a312c9e..b5edffdb60 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -55,7 +55,7 @@ class PasswordRestServlet(RestServlet):
 
         if LoginType.PASSWORD in result:
             # if using password, they should also be logged in
-            auth_user, client = yield self.auth.get_user_by_req(request)
+            auth_user, _ = yield self.auth.get_user_by_req(request)
             if auth_user.to_string() != result[LoginType.PASSWORD]:
                 raise LoginError(400, "", Codes.UNKNOWN)
             user_id = auth_user.to_string()
@@ -119,7 +119,7 @@ class ThreepidRestServlet(RestServlet):
             raise SynapseError(400, "Missing param", Codes.MISSING_PARAM)
         threePidCreds = body['threePidCreds']
 
-        auth_user, client = yield self.auth.get_user_by_req(request)
+        auth_user, _ = yield self.auth.get_user_by_req(request)
 
         threepid = yield self.identity_handler.threepid_from_creds(threePidCreds)
 
diff --git a/synapse/rest/client/v2_alpha/filter.py b/synapse/rest/client/v2_alpha/filter.py
index 703250cea8..f8f91b63f5 100644
--- a/synapse/rest/client/v2_alpha/filter.py
+++ b/synapse/rest/client/v2_alpha/filter.py
@@ -40,7 +40,7 @@ class GetFilterRestServlet(RestServlet):
     @defer.inlineCallbacks
     def on_GET(self, request, user_id, filter_id):
         target_user = UserID.from_string(user_id)
-        auth_user, client = yield self.auth.get_user_by_req(request)
+        auth_user, _ = yield self.auth.get_user_by_req(request)
 
         if target_user != auth_user:
             raise AuthError(403, "Cannot get filters for other users")
@@ -76,7 +76,7 @@ class CreateFilterRestServlet(RestServlet):
     @defer.inlineCallbacks
     def on_POST(self, request, user_id):
         target_user = UserID.from_string(user_id)
-        auth_user, client = yield self.auth.get_user_by_req(request)
+        auth_user, _ = yield self.auth.get_user_by_req(request)
 
         if target_user != auth_user:
             raise AuthError(403, "Cannot create filters for other users")
diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py
index 718928eedd..ec1145454f 100644
--- a/synapse/rest/client/v2_alpha/keys.py
+++ b/synapse/rest/client/v2_alpha/keys.py
@@ -63,7 +63,7 @@ class KeyUploadServlet(RestServlet):
 
     @defer.inlineCallbacks
     def on_POST(self, request, device_id):
-        auth_user, client_info = yield self.auth.get_user_by_req(request)
+        auth_user, _ = yield self.auth.get_user_by_req(request)
         user_id = auth_user.to_string()
         # TODO: Check that the device_id matches that in the authentication
         # or derive the device_id from the authentication instead.
@@ -108,7 +108,7 @@ class KeyUploadServlet(RestServlet):
 
     @defer.inlineCallbacks
     def on_GET(self, request, device_id):
-        auth_user, client_info = yield self.auth.get_user_by_req(request)
+        auth_user, _ = yield self.auth.get_user_by_req(request)
         user_id = auth_user.to_string()
 
         result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
@@ -180,7 +180,7 @@ class KeyQueryServlet(RestServlet):
 
     @defer.inlineCallbacks
     def on_GET(self, request, user_id, device_id):
-        auth_user, client_info = yield self.auth.get_user_by_req(request)
+        auth_user, _ = yield self.auth.get_user_by_req(request)
         auth_user_id = auth_user.to_string()
         user_id = user_id if user_id else auth_user_id
         device_ids = [device_id] if device_id else []
diff --git a/synapse/rest/client/v2_alpha/receipts.py b/synapse/rest/client/v2_alpha/receipts.py
index 40406e2ede..52e99f54d5 100644
--- a/synapse/rest/client/v2_alpha/receipts.py
+++ b/synapse/rest/client/v2_alpha/receipts.py
@@ -39,7 +39,7 @@ class ReceiptRestServlet(RestServlet):
 
     @defer.inlineCallbacks
     def on_POST(self, request, room_id, receipt_type, event_id):
-        user, client = yield self.auth.get_user_by_req(request)
+        user, _ = yield self.auth.get_user_by_req(request)
 
         yield self.receipts_handler.received_client_receipt(
             room_id,
diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py
index f2fd0b9f32..83a257b969 100644
--- a/synapse/rest/client/v2_alpha/sync.py
+++ b/synapse/rest/client/v2_alpha/sync.py
@@ -87,7 +87,7 @@ class SyncRestServlet(RestServlet):
 
     @defer.inlineCallbacks
     def on_GET(self, request):
-        user, client = yield self.auth.get_user_by_req(request)
+        user, _ = yield self.auth.get_user_by_req(request)
 
         timeout = parse_integer(request, "timeout", default=0)
         limit = parse_integer(request, "limit", required=True)
diff --git a/synapse/rest/media/v0/content_repository.py b/synapse/rest/media/v0/content_repository.py
index e77a20fb2e..c28dc86cd7 100644
--- a/synapse/rest/media/v0/content_repository.py
+++ b/synapse/rest/media/v0/content_repository.py
@@ -66,7 +66,7 @@ class ContentRepoResource(resource.Resource):
     @defer.inlineCallbacks
     def map_request_to_name(self, request):
         # auth the user
-        auth_user, client = yield self.auth.get_user_by_req(request)
+        auth_user, _ = yield self.auth.get_user_by_req(request)
 
         # namespace all file uploads on the user
         prefix = base64.urlsafe_b64encode(
diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py
index cdd1d44e07..439d5a30a8 100644
--- a/synapse/rest/media/v1/upload_resource.py
+++ b/synapse/rest/media/v1/upload_resource.py
@@ -70,7 +70,7 @@ class UploadResource(BaseMediaResource):
     @request_handler
     @defer.inlineCallbacks
     def _async_render_POST(self, request):
-        auth_user, client = yield self.auth.get_user_by_req(request)
+        auth_user, _ = yield self.auth.get_user_by_req(request)
         # TODO: The checks here are a bit late. The content will have
         # already been uploaded to a tmp file at this point
         content_length = request.getHeader("Content-Length")
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 53673b3bf5..77cb1dbd81 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -94,9 +94,9 @@ class DataStore(RoomMemberStore, RoomStore,
         )
 
     @defer.inlineCallbacks
-    def insert_client_ip(self, user, access_token, device_id, ip, user_agent):
+    def insert_client_ip(self, user, access_token, ip, user_agent):
         now = int(self._clock.time_msec())
-        key = (user.to_string(), access_token, device_id, ip)
+        key = (user.to_string(), access_token, ip)
 
         try:
             last_seen = self.client_ip_last_seen.get(key)
@@ -120,7 +120,6 @@ class DataStore(RoomMemberStore, RoomStore,
                 "user_agent": user_agent,
             },
             values={
-                "device_id": device_id,
                 "last_seen": now,
             },
             desc="insert_client_ip",
@@ -132,7 +131,7 @@ class DataStore(RoomMemberStore, RoomStore,
             table="user_ips",
             keyvalues={"user_id": user.to_string()},
             retcols=[
-                "device_id", "access_token", "ip", "user_agent", "last_seen"
+                "access_token", "ip", "user_agent", "last_seen"
             ],
             desc="get_user_ip_and_agents",
         )
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index f632306688..240d14c4d0 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -163,7 +163,7 @@ class RegistrationStore(SQLBaseStore):
         Args:
             token (str): The access token of a user.
         Returns:
-            dict: Including the name (user_id), device_id and whether they are
+            dict: Including the name (user_id) and whether they are
                 an admin.
         Raises:
             StoreError if no user was found.
@@ -228,8 +228,7 @@ class RegistrationStore(SQLBaseStore):
 
     def _query_for_auth(self, txn, token):
         sql = (
-            "SELECT users.name, users.admin,"
-            " access_tokens.device_id, access_tokens.id as token_id"
+            "SELECT users.name, users.admin, access_tokens.id as token_id"
             " FROM users"
             " INNER JOIN access_tokens on users.name = access_tokens.user_id"
             " WHERE token = ?"
diff --git a/synapse/types.py b/synapse/types.py
index e190374cbd..9cffc33d27 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -209,7 +209,3 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
             return "t%d-%d" % (self.topological, self.stream)
         else:
             return "s%d" % (self.stream,)
-
-
-# token_id is the primary key ID of the access token, not the access token itself.
-ClientInfo = namedtuple("ClientInfo", ("device_id", "token_id"))
diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index 3343c635cc..777eb0395e 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -40,7 +40,6 @@ class AuthTestCase(unittest.TestCase):
         self.store.get_app_service_by_token = Mock(return_value=None)
         user_info = {
             "name": self.test_user,
-            "device_id": "nothing",
             "token_id": "ditto",
             "admin": False
         }
@@ -49,7 +48,7 @@ class AuthTestCase(unittest.TestCase):
         request = Mock(args={})
         request.args["access_token"] = [self.test_token]
         request.requestHeaders.getRawHeaders = Mock(return_value=[""])
-        (user, info) = yield self.auth.get_user_by_req(request)
+        (user, _) = yield self.auth.get_user_by_req(request)
         self.assertEquals(user.to_string(), self.test_user)
 
     def test_get_user_by_req_user_bad_token(self):
@@ -66,7 +65,6 @@ class AuthTestCase(unittest.TestCase):
         self.store.get_app_service_by_token = Mock(return_value=None)
         user_info = {
             "name": self.test_user,
-            "device_id": "nothing",
             "token_id": "ditto",
             "admin": False
         }
@@ -86,7 +84,7 @@ class AuthTestCase(unittest.TestCase):
         request = Mock(args={})
         request.args["access_token"] = [self.test_token]
         request.requestHeaders.getRawHeaders = Mock(return_value=[""])
-        (user, info) = yield self.auth.get_user_by_req(request)
+        (user, _) = yield self.auth.get_user_by_req(request)
         self.assertEquals(user.to_string(), self.test_user)
 
     def test_get_user_by_req_appservice_bad_token(self):
@@ -121,7 +119,7 @@ class AuthTestCase(unittest.TestCase):
         request.args["access_token"] = [self.test_token]
         request.args["user_id"] = [masquerading_user_id]
         request.requestHeaders.getRawHeaders = Mock(return_value=[""])
-        (user, info) = yield self.auth.get_user_by_req(request)
+        (user, _) = yield self.auth.get_user_by_req(request)
         self.assertEquals(user.to_string(), masquerading_user_id)
 
     def test_get_user_by_req_appservice_valid_token_bad_user_id(self):
diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py
index 0b78a82a66..4039a86d85 100644
--- a/tests/rest/client/v1/test_presence.py
+++ b/tests/rest/client/v1/test_presence.py
@@ -74,7 +74,6 @@ class PresenceStateTestCase(unittest.TestCase):
             return {
                 "user": UserID.from_string(myid),
                 "admin": False,
-                "device_id": None,
                 "token_id": 1,
             }
 
@@ -163,7 +162,6 @@ class PresenceListTestCase(unittest.TestCase):
             return {
                 "user": UserID.from_string(myid),
                 "admin": False,
-                "device_id": None,
                 "token_id": 1,
             }
 
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index 2e55cc08a1..dd1e67e0f9 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -58,7 +58,6 @@ class RoomPermissionsTestCase(RestTestCase):
             return {
                 "user": UserID.from_string(self.auth_user_id),
                 "admin": False,
-                "device_id": None,
                 "token_id": 1,
             }
         hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token
@@ -445,7 +444,6 @@ class RoomsMemberListTestCase(RestTestCase):
             return {
                 "user": UserID.from_string(self.auth_user_id),
                 "admin": False,
-                "device_id": None,
                 "token_id": 1,
             }
         hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token
@@ -525,7 +523,6 @@ class RoomsCreateTestCase(RestTestCase):
             return {
                 "user": UserID.from_string(self.auth_user_id),
                 "admin": False,
-                "device_id": None,
                 "token_id": 1,
             }
         hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token
@@ -618,7 +615,6 @@ class RoomTopicTestCase(RestTestCase):
             return {
                 "user": UserID.from_string(self.auth_user_id),
                 "admin": False,
-                "device_id":
None, "token_id": 1, } @@ -725,7 +721,6 @@ class RoomMemberStateTestCase(RestTestCase): return { "user": UserID.from_string(self.auth_user_id), "admin": False, - "device_id": None, "token_id": 1, } hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token @@ -852,7 +847,6 @@ class RoomMessagesTestCase(RestTestCase): return { "user": UserID.from_string(self.auth_user_id), "admin": False, - "device_id": None, "token_id": 1, } hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token @@ -949,7 +943,6 @@ class RoomInitialSyncTestCase(RestTestCase): return { "user": UserID.from_string(self.auth_user_id), "admin": False, - "device_id": None, "token_id": 1, } hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py index dc8bbaaf0e..0f70ce81dc 100644 --- a/tests/rest/client/v1/test_typing.py +++ b/tests/rest/client/v1/test_typing.py @@ -65,7 +65,6 @@ class RoomTypingTestCase(RestTestCase): return { "user": UserID.from_string(self.auth_user_id), "admin": False, - "device_id": None, "token_id": 1, } diff --git a/tests/rest/client/v2_alpha/__init__.py b/tests/rest/client/v2_alpha/__init__.py index 15568b36cd..badb59f080 100644 --- a/tests/rest/client/v2_alpha/__init__.py +++ b/tests/rest/client/v2_alpha/__init__.py @@ -47,7 +47,6 @@ class V2AlphaRestTestCase(unittest.TestCase): return { "user": UserID.from_string(self.USER_ID), "admin": False, - "device_id": None, "token_id": 1, } hs.get_auth().get_user_by_access_token = _get_user_by_access_token diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py index a4f929796a..54fe10d58f 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py @@ -54,7 +54,6 @@ class RegistrationStoreTestCase(unittest.TestCase): self.assertDictContainsSubset( { "admin": 0, - "device_id": None, "name": self.user_id, }, result @@ -72,7 +71,6 @@ class RegistrationStoreTestCase(unittest.TestCase): self.assertDictContainsSubset( { "admin": 0, - "device_id": None, "name": self.user_id, }, result diff --git a/tests/utils.py b/tests/utils.py index d0fba2252d..ff560ef342 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -282,7 +282,6 @@ class MemoryDataStore(object): return { "name": self.tokens_to_users[token], "admin": 0, - "device_id": None, } except: raise StoreError(400, "User does not exist.") @@ -380,7 +379,7 @@ class MemoryDataStore(object): def get_ops_levels(self, room_id): return defer.succeed((5, 5, 5)) - def insert_client_ip(self, user, device_id, access_token, ip, user_agent): + def insert_client_ip(self, user, access_token, ip, user_agent): return defer.succeed(None) -- cgit 1.5.1 From a9d8bd95e722e24c7ddd6b14a3714c1b2f737d4d Mon Sep 17 00:00:00 2001 From: Daniel Wagner-Hall Date: Tue, 25 Aug 2015 16:29:39 +0100 Subject: Stop looking up "admin", which we never read --- synapse/api/auth.py | 4 +--- synapse/storage/registration.py | 5 ++--- tests/api/test_auth.py | 2 -- tests/rest/client/v1/test_presence.py | 2 -- tests/rest/client/v1/test_rooms.py | 7 ------- tests/rest/client/v1/test_typing.py | 1 - tests/rest/client/v2_alpha/__init__.py | 1 - tests/storage/test_registration.py | 6 ++---- tests/utils.py | 1 - 9 files changed, 5 insertions(+), 24 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/api/auth.py b/synapse/api/auth.py index b41e34e658..65ee1452ce 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -392,8 +392,7 @@ class Auth(object): Args: token (str): 
The access token to get the user by. Returns: - dict : dict that includes the user and whether the - user is a server admin. + dict : dict that includes the user and the ID of their access token. Raises: AuthError if no user by that token exists or the token is invalid. """ @@ -404,7 +403,6 @@ class Auth(object): errcode=Codes.UNKNOWN_TOKEN ) user_info = { - "admin": bool(ret.get("admin", False)), "user": UserID.from_string(ret.get("name")), "token_id": ret.get("token_id", None), } diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 240d14c4d0..a2d0f7c4b1 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -163,8 +163,7 @@ class RegistrationStore(SQLBaseStore): Args: token (str): The access token of a user. Returns: - dict: Including the name (user_id) and whether they are - an admin. + dict: Including the name (user_id) and the ID of their access token. Raises: StoreError if no user was found. """ @@ -228,7 +227,7 @@ class RegistrationStore(SQLBaseStore): def _query_for_auth(self, txn, token): sql = ( - "SELECT users.name, users.admin, access_tokens.id as token_id" + "SELECT users.name, access_tokens.id as token_id" " FROM users" " INNER JOIN access_tokens on users.name = access_tokens.user_id" " WHERE token = ?" diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index 777eb0395e..22fc804331 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -41,7 +41,6 @@ class AuthTestCase(unittest.TestCase): user_info = { "name": self.test_user, "token_id": "ditto", - "admin": False } self.store.get_user_by_access_token = Mock(return_value=user_info) @@ -66,7 +65,6 @@ class AuthTestCase(unittest.TestCase): user_info = { "name": self.test_user, "token_id": "ditto", - "admin": False } self.store.get_user_by_access_token = Mock(return_value=user_info) diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py index 4039a86d85..91547bdd06 100644 --- a/tests/rest/client/v1/test_presence.py +++ b/tests/rest/client/v1/test_presence.py @@ -73,7 +73,6 @@ class PresenceStateTestCase(unittest.TestCase): def _get_user_by_access_token(token=None): return { "user": UserID.from_string(myid), - "admin": False, "token_id": 1, } @@ -161,7 +160,6 @@ class PresenceListTestCase(unittest.TestCase): def _get_user_by_access_token(token=None): return { "user": UserID.from_string(myid), - "admin": False, "token_id": 1, } diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index dd1e67e0f9..34ab47d02e 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -57,7 +57,6 @@ class RoomPermissionsTestCase(RestTestCase): def _get_user_by_access_token(token=None): return { "user": UserID.from_string(self.auth_user_id), - "admin": False, "token_id": 1, } hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token @@ -443,7 +442,6 @@ class RoomsMemberListTestCase(RestTestCase): def _get_user_by_access_token(token=None): return { "user": UserID.from_string(self.auth_user_id), - "admin": False, "token_id": 1, } hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token @@ -522,7 +520,6 @@ class RoomsCreateTestCase(RestTestCase): def _get_user_by_access_token(token=None): return { "user": UserID.from_string(self.auth_user_id), - "admin": False, "token_id": 1, } hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token @@ -614,7 +611,6 @@ class RoomTopicTestCase(RestTestCase): def _get_user_by_access_token(token=None): 
return { "user": UserID.from_string(self.auth_user_id), - "admin": False, "token_id": 1, } @@ -720,7 +716,6 @@ class RoomMemberStateTestCase(RestTestCase): def _get_user_by_access_token(token=None): return { "user": UserID.from_string(self.auth_user_id), - "admin": False, "token_id": 1, } hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token @@ -846,7 +841,6 @@ class RoomMessagesTestCase(RestTestCase): def _get_user_by_access_token(token=None): return { "user": UserID.from_string(self.auth_user_id), - "admin": False, "token_id": 1, } hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token @@ -942,7 +936,6 @@ class RoomInitialSyncTestCase(RestTestCase): def _get_user_by_access_token(token=None): return { "user": UserID.from_string(self.auth_user_id), - "admin": False, "token_id": 1, } hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py index 0f70ce81dc..1c4519406d 100644 --- a/tests/rest/client/v1/test_typing.py +++ b/tests/rest/client/v1/test_typing.py @@ -64,7 +64,6 @@ class RoomTypingTestCase(RestTestCase): def _get_user_by_access_token(token=None): return { "user": UserID.from_string(self.auth_user_id), - "admin": False, "token_id": 1, } diff --git a/tests/rest/client/v2_alpha/__init__.py b/tests/rest/client/v2_alpha/__init__.py index badb59f080..ef972a53aa 100644 --- a/tests/rest/client/v2_alpha/__init__.py +++ b/tests/rest/client/v2_alpha/__init__.py @@ -46,7 +46,6 @@ class V2AlphaRestTestCase(unittest.TestCase): def _get_user_by_access_token(token=None): return { "user": UserID.from_string(self.USER_ID), - "admin": False, "token_id": 1, } hs.get_auth().get_user_by_access_token = _get_user_by_access_token diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py index 54fe10d58f..0cce6c37df 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py @@ -53,8 +53,7 @@ class RegistrationStoreTestCase(unittest.TestCase): self.assertDictContainsSubset( { - "admin": 0, - "name": self.user_id, + "name": self.user_id, }, result ) @@ -70,8 +69,7 @@ class RegistrationStoreTestCase(unittest.TestCase): self.assertDictContainsSubset( { - "admin": 0, - "name": self.user_id, + "name": self.user_id, }, result ) diff --git a/tests/utils.py b/tests/utils.py index ff560ef342..3766a994f2 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -281,7 +281,6 @@ class MemoryDataStore(object): try: return { "name": self.tokens_to_users[token], - "admin": 0, } except: raise StoreError(400, "User does not exist.") -- cgit 1.5.1 From d3c0e488591386b7d24d23c6f6d3b237523fca89 Mon Sep 17 00:00:00 2001 From: Daniel Wagner-Hall Date: Wed, 26 Aug 2015 13:42:45 +0100 Subject: Merge erikj/user_dedup to develop --- synapse/handlers/auth.py | 39 +++++++++++++++++++++++++++++++-------- synapse/handlers/register.py | 4 ++-- synapse/rest/client/v1/login.py | 5 +++-- synapse/storage/registration.py | 14 ++++++++++++++ 4 files changed, 50 insertions(+), 12 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index c983d444e8..1ab19cd1a6 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -163,7 +163,8 @@ class AuthHandler(BaseHandler): if not user_id.startswith('@'): user_id = UserID.create(user_id, self.hs.hostname).to_string() - yield self._check_password(user_id, password) + user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id) + 
self._check_password(user_id, password, password_hash) defer.returnValue(user_id) @defer.inlineCallbacks @@ -280,27 +281,49 @@ class AuthHandler(BaseHandler): password (str): Password Returns: A tuple of: + The user's ID. The access token for the user's session. The refresh token for the user's session. Raises: StoreError if there was a problem storing the token. LoginError if there was an authentication problem. """ - yield self._check_password(user_id, password) + user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id) + self._check_password(user_id, password, password_hash) + logger.info("Logging in user %s", user_id) access_token = yield self.issue_access_token(user_id) refresh_token = yield self.issue_refresh_token(user_id) - defer.returnValue((access_token, refresh_token)) + defer.returnValue((user_id, access_token, refresh_token)) @defer.inlineCallbacks - def _check_password(self, user_id, password): - """Checks that user_id has passed password, raises LoginError if not.""" - user_info = yield self.store.get_user_by_id(user_id=user_id) - if not user_info: + def _find_user_id_and_pwd_hash(self, user_id): + """Checks to see if a user with the given id exists. Will check case + insensitively, but will throw if there are multiple inexact matches. + + Returns: + tuple: A 2-tuple of `(canonical_user_id, password_hash)` + """ + user_infos = yield self.store.get_users_by_id_case_insensitive(user_id) + if not user_infos: logger.warn("Attempted to login as %s but they do not exist", user_id) raise LoginError(403, "", errcode=Codes.FORBIDDEN) - stored_hash = user_info["password_hash"] + if len(user_infos) > 1: + if user_id not in user_infos: + logger.warn( + "Attempted to login as %s but it matches more than one user " + "inexactly: %r", + user_id, user_infos.keys() + ) + raise LoginError(403, "", errcode=Codes.FORBIDDEN) + + defer.returnValue((user_id, user_infos[user_id])) + else: + defer.returnValue(user_infos.popitem()) + + def _check_password(self, user_id, password, stored_hash): + """Checks that user_id has passed password, raises LoginError if not.""" if not bcrypt.checkpw(password, stored_hash): logger.warn("Failed password login for user %s", user_id) raise LoginError(403, "", errcode=Codes.FORBIDDEN) diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 3d1b6531c2..56d125f753 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -56,8 +56,8 @@ class RegistrationHandler(BaseHandler): yield self.check_user_id_is_valid(user_id) - u = yield self.store.get_user_by_id(user_id) - if u: + users = yield self.store.get_users_by_id_case_insensitive(user_id) + if users: raise SynapseError( 400, "User ID already taken.", diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 3a0707c2ee..e580f71964 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -83,10 +83,11 @@ class LoginRestServlet(ClientV1RestServlet): if not user_id.startswith('@'): user_id = UserID.create( - user_id, self.hs.hostname).to_string() + user_id, self.hs.hostname + ).to_string() auth_handler = self.handlers.auth_handler - access_token, refresh_token = yield auth_handler.login_with_password( + user_id, access_token, refresh_token = yield auth_handler.login_with_password( user_id=user_id, password=login_submission["password"]) diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index a2d0f7c4b1..c9ceb132ae 100644 --- a/synapse/storage/registration.py +++ 
b/synapse/storage/registration.py
@@ -120,6 +120,20 @@ class RegistrationStore(SQLBaseStore):
             allow_none=True,
         )
 
+    def get_users_by_id_case_insensitive(self, user_id):
+        """Gets users that match user_id case insensitively.
+        Returns a mapping of user_id -> password_hash.
+        """
+        def f(txn):
+            sql = (
+                "SELECT name, password_hash FROM users"
+                " WHERE lower(name) = lower(?)"
+            )
+            txn.execute(sql, (user_id,))
+            return dict(txn.fetchall())
+
+        return self.runInteraction("get_users_by_id_case_insensitive", f)
+
     @defer.inlineCallbacks
     def user_set_password_hash(self, user_id, password_hash):
         """
-- 
cgit 1.5.1


From 417485eefaff86206b5f961102f882b3fbe44651 Mon Sep 17 00:00:00 2001
From: Mark Haines
Date: Tue, 8 Sep 2015 18:14:54 +0100
Subject: Include the event_id and stream_ordering of membership events when
 looking up which rooms a user is in

---
 synapse/storage/roommember.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

(limited to 'synapse/storage')

diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 8eee2dfbcc..cd9eefbd9f 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -30,7 +30,7 @@ logger = logging.getLogger(__name__)
 
 RoomsForUser = namedtuple(
     "RoomsForUser",
-    ("room_id", "sender", "membership")
+    ("room_id", "sender", "membership", "event_id", "stream_ordering")
 )
 
 
@@ -141,9 +141,11 @@ class RoomMemberStore(SQLBaseStore):
             args.extend(membership_list)
 
             sql = (
-                "SELECT m.room_id, m.sender, m.membership"
+                "SELECT m.room_id, m.sender, m.membership, m.event_id, e.stream_ordering"
                 " FROM room_memberships as m"
                 " INNER JOIN current_state_events as c"
                 " ON m.event_id = c.event_id "
                 " AND m.room_id = c.room_id "
                 " AND m.user_id = c.state_key"
+                " INNER JOIN events as e "
+                " ON e.event_id = c.event_id "
-- 
cgit 1.5.1


From 89ae0166ded093be2343409cfe42f475dea83139 Mon Sep 17 00:00:00 2001
From: Mark Haines
Date: Wed, 9 Sep 2015 13:25:22 +0100
Subject: Allow room initialSync for users that have left the room, returning
 a snapshot of how the room was when they left it

---
 synapse/api/auth.py         | 49 ++++++++++++++++++++++++++
 synapse/handlers/message.py | 85 ++++++++++++++++++++++++++++++++++++++++-----
 synapse/storage/stream.py   | 15 ++++++++
 3 files changed, 140 insertions(+), 9 deletions(-)

(limited to 'synapse/storage')

diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 0c0d678562..9b614a12bb 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -104,6 +104,20 @@ class Auth(object):
 
     @defer.inlineCallbacks
     def check_joined_room(self, room_id, user_id, current_state=None):
+        """Check if the user is currently joined in the room
+        Args:
+            room_id(str): The room to check.
+            user_id(str): The user to check.
+            current_state(dict): Optional map of the current state of the room.
+                If provided then that map is used to check whether they are a
+                member of the room. Otherwise the current membership is
+                loaded from the database.
+        Raises:
+            AuthError if the user is not in the room.
+        Returns:
+            A deferred membership event for the user if the user is in
+            the room.
+        """
         if current_state:
             member = current_state.get(
                 (EventTypes.Member, user_id),
@@ -119,6 +133,41 @@ class Auth(object):
         self._check_joined_room(member, user_id, room_id)
         defer.returnValue(member)
 
+    @defer.inlineCallbacks
+    def check_user_was_in_room(self, room_id, user_id, current_state=None):
+        """Check if the user was in the room at some point.
+        Args:
+            room_id(str): The room to check.
+            user_id(str): The user to check.
+            current_state(dict): Optional map of the current state of the room.
+                If provided then that map is used to check whether they are a
+                member of the room. Otherwise the current membership is
+                loaded from the database.
+        Raises:
+            AuthError if the user was never in the room.
+        Returns:
+            A deferred membership event for the user if the user was in
+            the room.
+        """
+        if current_state:
+            member = current_state.get(
+                (EventTypes.Member, user_id),
+                None
+            )
+        else:
+            member = yield self.state.get_current_state(
+                room_id=room_id,
+                event_type=EventTypes.Member,
+                state_key=user_id
+            )
+
+        if not member:
+            raise AuthError(403, "User %s not in room %s" % (
+                user_id, room_id
+            ))
+
+        defer.returnValue(member)
+
     @defer.inlineCallbacks
     def check_host_in_room(self, room_id, host):
         curr_state = yield self.state.get_current_state(room_id)
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 5447c97e83..fc9a234333 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -22,7 +22,7 @@ from synapse.events.utils import serialize_event
 from synapse.events.validator import EventValidator
 from synapse.util import unwrapFirstError
 from synapse.util.logcontext import PreserveLoggingContext
-from synapse.types import UserID, RoomStreamToken
+from synapse.types import UserID, RoomStreamToken, StreamToken
 
 from ._base import BaseHandler
 
@@ -377,7 +377,6 @@ class MessageHandler(BaseHandler):
                 lambda states: states[event.event_id]
             )
-
         (messages, token), current_state = yield defer.gatherResults(
             [
                 self.store.get_recent_events_for_room(
@@ -434,13 +433,83 @@ class MessageHandler(BaseHandler):
 
     @defer.inlineCallbacks
     def room_initial_sync(self, user_id, room_id, pagin_config=None,
                           feedback=False):
-        current_state = yield self.state.get_current_state(
-            room_id=room_id,
+        """Capture a snapshot of a room. If the user is currently a member of
+        the room this will be what is currently in the room. If the user left
+        the room this will be what was in the room when they left.
+
+        Args:
+            user_id(str): The user to get a snapshot for.
+            room_id(str): The room to get a snapshot of.
+            pagin_config(synapse.api.streams.PaginationConfig): The pagination
+                config used to determine how many messages to return.
+        Raises:
+            AuthError if the user wasn't in the room.
+        Returns:
+            A JSON object with the snapshot of the room.
+ """ + + member_event = yield self.auth.check_user_was_in_room(room_id, user_id) + + if member_event.membership == Membership.JOIN: + result = yield self._room_initial_sync_joined( + user_id, room_id, pagin_config, member_event + ) + elif member_event.membership == Membership.LEAVE: + result = yield self._room_initial_sync_parted( + user_id, room_id, pagin_config, member_event + ) + defer.returnValue(result) + + @defer.inlineCallbacks + def _room_initial_sync_parted(self, user_id, room_id, pagin_config, + member_event): + room_state = yield self.store.get_state_for_events( + member_event.room_id, [member_event.event_id], None + ) + + room_state = room_state[member_event.event_id] + + limit = pagin_config.limit if pagin_config else None + if limit is None: + limit = 10 + + stream_token = yield self.store.get_stream_token_for_event( + member_event.event_id + ) + + messages, token = yield self.store.get_recent_events_for_room( + room_id, + limit=limit, + end_token=stream_token + ) + + messages = yield self._filter_events_for_client( + user_id, room_id, messages ) - yield self.auth.check_joined_room( - room_id, user_id, - current_state=current_state + start_token = StreamToken(token[0], 0, 0, 0) + end_token = StreamToken(token[1], 0, 0, 0) + + time_now = self.clock.time_msec() + + defer.returnValue({ + "membership": member_event.membership, + "room_id": room_id, + "messages": { + "chunk": [serialize_event(m, time_now) for m in messages], + "start": start_token.to_string(), + "end": end_token.to_string(), + }, + "state": [serialize_event(s, time_now) for s in room_state.values()], + "presence": [], + "receipts": [], + }) + + @defer.inlineCallbacks + def _room_initial_sync_joined(self, user_id, room_id, pagin_config, + member_event): + current_state = yield self.state.get_current_state( + room_id=room_id, ) # TODO(paul): I wish I was called with user objects not user_id @@ -454,8 +523,6 @@ class MessageHandler(BaseHandler): for x in current_state.values() ] - member_event = current_state.get((EventTypes.Member, user_id,)) - now_token = yield self.hs.get_event_sources().get_current_token() limit = pagin_config.limit if pagin_config else None diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index d7fe423f5a..0abfa86cd2 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -379,6 +379,21 @@ class StreamStore(SQLBaseStore): ) defer.returnValue("t%d-%d" % (topo, token)) + def get_stream_token_for_event(self, event_id): + """The stream token for an event + Args: + event_id(str): The id of the event to look up a stream token for. + Raises: + StoreError if the event wasn't in the database. + Returns: + A deferred "s%d" stream token. 
+ """ + return self._simple_select_one_onecol( + table="events", + keyvalues={"event_id": event_id}, + retcol="stream_ordering", + ).addCallback(lambda stream_ordering: "s%d" % (stream_ordering,)) + def _get_max_topological_txn(self, txn): txn.execute( "SELECT MAX(topological_ordering) FROM events" -- cgit 1.5.1 From 3c166a24c591afdc851de3c6c754c90471b1b0a9 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Wed, 9 Sep 2015 16:05:09 +0100 Subject: Remove undocumented and unimplemented 'feedback' parameter from the Client-Server API --- synapse/api/constants.py | 11 ----------- synapse/handlers/message.py | 21 +++------------------ synapse/handlers/room.py | 1 - synapse/rest/client/v1/initial_sync.py | 2 -- synapse/rest/client/v1/room.py | 2 -- synapse/storage/stream.py | 10 ++-------- 6 files changed, 5 insertions(+), 42 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 1423986c1e..3385664394 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -27,16 +27,6 @@ class Membership(object): LIST = (INVITE, JOIN, KNOCK, LEAVE, BAN) -class Feedback(object): - - """Represents the types of feedback a user can send in response to a - message.""" - - DELIVERED = u"delivered" - READ = u"read" - LIST = (DELIVERED, READ) - - class PresenceState(object): """Represents the presence state of a user.""" OFFLINE = u"offline" @@ -73,7 +63,6 @@ class EventTypes(object): PowerLevels = "m.room.power_levels" Aliases = "m.room.aliases" Redaction = "m.room.redaction" - Feedback = "m.room.message.feedback" RoomHistoryVisibility = "m.room.history_visibility" CanonicalAlias = "m.room.canonical_alias" diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 171e9d72ac..72ebac047f 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -71,7 +71,7 @@ class MessageHandler(BaseHandler): @defer.inlineCallbacks def get_messages(self, user_id=None, room_id=None, pagin_config=None, - feedback=False, as_client_event=True): + as_client_event=True): """Get messages in a room. Args: @@ -79,7 +79,6 @@ class MessageHandler(BaseHandler): room_id (str): The room they want messages from. pagin_config (synapse.api.streams.PaginationConfig): The pagination config rules to apply, if any. - feedback (bool): True to get compressed feedback with the messages as_client_event (bool): True to get events in client-server format. Returns: dict: Pagination API results @@ -264,17 +263,6 @@ class MessageHandler(BaseHandler): ) defer.returnValue(data) - @defer.inlineCallbacks - def get_feedback(self, event_id): - # yield self.auth.check_joined_room(room_id, user_id) - - # Pull out the feedback from the db - fb = yield self.store.get_feedback(event_id) - - if fb: - defer.returnValue(fb) - defer.returnValue(None) - @defer.inlineCallbacks def get_state_events(self, user_id, room_id): """Retrieve all state events for a given room. If the user is @@ -303,8 +291,7 @@ class MessageHandler(BaseHandler): ) @defer.inlineCallbacks - def snapshot_all_rooms(self, user_id=None, pagin_config=None, - feedback=False, as_client_event=True): + def snapshot_all_rooms(self, user_id=None, pagin_config=None, as_client_event=True): """Retrieve a snapshot of all rooms the user is invited or has joined. This snapshot may include messages for all rooms where the user is @@ -314,7 +301,6 @@ class MessageHandler(BaseHandler): user_id (str): The ID of the user making the request. 
            pagin_config (synapse.api.streams.PaginationConfig): The pagination
                 config used to determine how many messages *PER ROOM* to return.
-            feedback (bool): True to get feedback along with these messages.
             as_client_event (bool): True to get events in client-server format.
         Returns:
             A list of dicts with "room_id" and "membership" keys for all rooms
@@ -439,8 +425,7 @@ class MessageHandler(BaseHandler):
 
         defer.returnValue(ret)
 
     @defer.inlineCallbacks
-    def room_initial_sync(self, user_id, room_id, pagin_config=None,
-                          feedback=False):
+    def room_initial_sync(self, user_id, room_id, pagin_config=None):
         """Capture a snapshot of a room. If the user is currently a member of
         the room this will be what is currently in the room. If the user left
         the room this will be what was in the room when they left.
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 0ff816d53e..243623190f 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -610,7 +610,6 @@ class RoomEventSource(object):
             to_key=config.to_key,
             direction=config.direction,
             limit=config.limit,
-            with_feedback=True
         )
 
         defer.returnValue((events, next_key))
diff --git a/synapse/rest/client/v1/initial_sync.py b/synapse/rest/client/v1/initial_sync.py
index 4ea4da653c..bac68cc29f 100644
--- a/synapse/rest/client/v1/initial_sync.py
+++ b/synapse/rest/client/v1/initial_sync.py
@@ -26,14 +26,12 @@ class InitialSyncRestServlet(ClientV1RestServlet):
     @defer.inlineCallbacks
     def on_GET(self, request):
         user, _ = yield self.auth.get_user_by_req(request)
-        with_feedback = "feedback" in request.args
         as_client_event = "raw" not in request.args
         pagination_config = PaginationConfig.from_request(request)
         handler = self.handlers.message_handler
         content = yield handler.snapshot_all_rooms(
             user_id=user.to_string(),
             pagin_config=pagination_config,
-            feedback=with_feedback,
             as_client_event=as_client_event
         )
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index f4558b95a7..23871f161e 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -329,14 +329,12 @@ class RoomMessageListRestServlet(ClientV1RestServlet):
         pagination_config = PaginationConfig.from_request(
             request, default_limit=10,
         )
-        with_feedback = "feedback" in request.args
         as_client_event = "raw" not in request.args
         handler = self.handlers.message_handler
         msgs = yield handler.get_messages(
             room_id=room_id,
             user_id=user.to_string(),
             pagin_config=pagination_config,
-            feedback=with_feedback,
             as_client_event=as_client_event
         )
diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py
index 0abfa86cd2..5763c462af 100644
--- a/synapse/storage/stream.py
+++ b/synapse/storage/stream.py
@@ -159,9 +159,7 @@ class StreamStore(SQLBaseStore):
 
     @log_function
     def get_room_events_stream(self, user_id, from_key, to_key, room_id,
-                               limit=0, with_feedback=False):
-        # TODO (erikj): Handle compressed feedback
-
+                               limit=0):
         current_room_membership_sql = (
             "SELECT m.room_id FROM room_memberships as m "
             " INNER JOIN current_state_events as c"
@@ -227,10 +225,7 @@ class StreamStore(SQLBaseStore):
 
     @defer.inlineCallbacks
     def paginate_room_events(self, room_id, from_key, to_key=None,
-                             direction='b', limit=-1,
-                             with_feedback=False):
-        # TODO (erikj): Handle compressed feedback
-
+                             direction='b', limit=-1):
         # Tokens really represent positions between elements, but we use
         # the convention of pointing to the event before the gap. Hence
         # we have a bit of asymmetry when it comes to equalities.
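
The "positions between elements" convention in the comment above is easiest to
see on concrete tokens. A minimal sketch with hypothetical values, using the
RoomStreamToken namedtuple from synapse/types.py as it stands at this point in
the series:

    from synapse.types import RoomStreamToken

    # "t<topological>-<stream>" addresses a position in the topological
    # (backfill-aware) ordering of a room's history...
    topo = RoomStreamToken.parse("t3-42")
    assert (topo.topological, topo.stream) == (3, 42)

    # ...while "s<stream>" is a pure live-stream position; its topological
    # part parses as None, which is why get_messages() rejects such tokens
    # for pagination.
    live = RoomStreamToken.parse("s42")
    assert live.topological is None and live.stream == 42
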
@@ -302,7 +297,6 @@ class StreamStore(SQLBaseStore): @cachedInlineCallbacks(num_args=4) def get_recent_events_for_room(self, room_id, limit, end_token, from_token=None): - # TODO (erikj): Handle compressed feedback end_token = RoomStreamToken.parse_stream_token(end_token) -- cgit 1.5.1 From 09cb5c7d33c32e2cbf5a5b6f6f0e2780338491d2 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Wed, 9 Sep 2015 17:31:09 +0100 Subject: Allow users that have left a room to get the messages that happend in the room before they left --- synapse/handlers/message.py | 31 +++++++++++++++++++++++++++---- synapse/storage/stream.py | 19 ++++++++++++++++++- 2 files changed, 45 insertions(+), 5 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 72ebac047f..db89491b46 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -83,21 +83,44 @@ class MessageHandler(BaseHandler): Returns: dict: Pagination API results """ - yield self.auth.check_joined_room(room_id, user_id) + member_event = yield self.auth.check_user_was_in_room(room_id, user_id) data_source = self.hs.get_event_sources().sources["room"] - if not pagin_config.from_token: + if pagin_config.from_token: + room_token = pagin_config.from_token.room_key + else: pagin_config.from_token = ( yield self.hs.get_event_sources().get_current_token( direction='b' ) ) + room_token = pagin_config.from_token.room_key - room_token = RoomStreamToken.parse(pagin_config.from_token.room_key) + room_token = RoomStreamToken.parse(room_token) if room_token.topological is None: raise SynapseError(400, "Invalid token") + pagin_config.from_token = pagin_config.from_token.copy_and_replace( + "room_key", str(room_token) + ) + + source_config = pagin_config.get_source_config("room") + + if member_event.membership == Membership.LEAVE: + # If they have left the room then clamp the token to be before + # they left the room + leave_token = yield self.store.get_topological_token_for_event( + member_event.event_id + ) + leave_token = RoomStreamToken.parse(leave_token) + if leave_token.topological < room_token.topological: + source_config.from_key = str(leave_token) + + if source_config.direction == "f": + if source_config.to_key is None: + source_config.to_key = str(leave_token) + yield self.hs.get_handlers().federation_handler.maybe_backfill( room_id, room_token.topological ) @@ -105,7 +128,7 @@ class MessageHandler(BaseHandler): user = UserID.from_string(user_id) events, next_key = yield data_source.get_pagination_rows( - user, pagin_config.get_source_config("room"), room_id + user, source_config, room_id ) next_token = pagin_config.from_token.copy_and_replace( diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index 5763c462af..3cab06fdef 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -386,7 +386,24 @@ class StreamStore(SQLBaseStore): table="events", keyvalues={"event_id": event_id}, retcol="stream_ordering", - ).addCallback(lambda stream_ordering: "s%d" % (stream_ordering,)) + ).addCallback(lambda row: "s%d" % (row,)) + + def get_topological_token_for_event(self, event_id): + """The stream token for an event + Args: + event_id(str): The id of the event to look up a stream token for. + Raises: + StoreError if the event wasn't in the database. + Returns: + A deferred "t%d-%d" topological token. 
+ """ + return self._simple_select_one( + table="events", + keyvalues={"event_id": event_id}, + retcols=("stream_ordering", "topological_ordering"), + ).addCallback(lambda row: "t%d-%d" % ( + row["topological_ordering"], row["stream_ordering"],) + ) def _get_max_topological_txn(self, txn): txn.execute( -- cgit 1.5.1 From dffc9c4ae0eea0616cc017c7f858f8a923202075 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 14 Sep 2015 14:41:37 +0100 Subject: Drop unused index --- synapse/storage/schema/delta/23/drop_state_index.sql | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) create mode 100644 synapse/storage/schema/delta/23/drop_state_index.sql (limited to 'synapse/storage') diff --git a/synapse/storage/schema/delta/23/drop_state_index.sql b/synapse/storage/schema/delta/23/drop_state_index.sql new file mode 100644 index 0000000000..07d0ea5cb2 --- /dev/null +++ b/synapse/storage/schema/delta/23/drop_state_index.sql @@ -0,0 +1,16 @@ +/* Copyright 2015 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +DROP INDEX IF EXISTS state_groups_state_tuple; -- cgit 1.5.1 From 7213588083dd9a721b0cd623fe22b308f25f19a5 Mon Sep 17 00:00:00 2001 From: Daniel Wagner-Hall Date: Tue, 22 Sep 2015 12:57:40 +0100 Subject: Implement configurable stats reporting SYN-287 This requires that HS owners either opt in or out of stats reporting. When --generate-config is passed, --report-stats must be specified If an already-generated config is used, and doesn't have the report_stats key, it is requested to be set. 
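
For concreteness: the reporting task added to synapse/app/homeserver.py below
runs once every 24 hours (only when report_stats is enabled) and PUTs a flat
JSON object to https://matrix.org/report-usage-stats/push. An illustrative
sketch of the payload, with hypothetical values:

    stats = {
        "homeserver": "example.com",   # hs.config.server_name
        "timestamp": 1443000000,       # seconds since the epoch
        "uptime_seconds": 86400,
        "total_users": 12,
        "total_room_count": 3,
        "daily_active_users": 5,
        "daily_messages": 40,          # omitted until count_daily_messages()
                                       # has a day-old baseline to count from
    }
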
--- synapse/app/homeserver.py | 35 ++++++- synapse/app/synctl.py | 12 ++- synapse/config/_base.py | 45 +++++++- synapse/config/appservice.py | 2 +- synapse/config/captcha.py | 2 +- synapse/config/database.py | 2 +- synapse/config/key.py | 2 +- synapse/config/logger.py | 2 +- synapse/config/metrics.py | 8 +- synapse/config/ratelimiting.py | 2 +- synapse/config/registration.py | 2 +- synapse/config/repository.py | 2 +- synapse/config/saml2.py | 2 +- synapse/config/server.py | 2 +- synapse/config/tls.py | 2 +- synapse/config/voip.py | 2 +- synapse/storage/__init__.py | 20 +++- synapse/storage/events.py | 58 ++++++++++- synapse/storage/registration.py | 12 +++ .../storage/schema/delta/24/stats_reporting.sql | 22 ++++ tests/storage/event_injector.py | 81 ++++++++++++++ tests/storage/test_events.py | 116 +++++++++++++++++++++ tests/storage/test_room.py | 2 +- tests/storage/test_stream.py | 68 +++--------- 24 files changed, 425 insertions(+), 78 deletions(-) create mode 100644 synapse/storage/schema/delta/24/stats_reporting.sql create mode 100644 tests/storage/event_injector.py create mode 100644 tests/storage/test_events.py (limited to 'synapse/storage') diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 15c0a4a003..b4429bd4f3 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -42,7 +42,7 @@ from synapse.storage import ( from synapse.server import HomeServer -from twisted.internet import reactor +from twisted.internet import reactor, task, defer from twisted.application import service from twisted.enterprise import adbapi from twisted.web.resource import Resource, EncodingResourceWrapper @@ -677,6 +677,39 @@ def run(hs): ThreadPool._worker = profile(ThreadPool._worker) reactor.run = profile(reactor.run) + start_time = hs.get_clock().time() + + @defer.inlineCallbacks + def phone_stats_home(): + now = int(hs.get_clock().time()) + uptime = int(now - start_time) + if uptime < 0: + uptime = 0 + + stats = {} + stats["homeserver"] = hs.config.server_name + stats["timestamp"] = now + stats["uptime_seconds"] = uptime + stats["total_users"] = yield hs.get_datastore().count_all_users() + + all_rooms = yield hs.get_datastore().get_rooms(False) + stats["total_room_count"] = len(all_rooms) + + stats["daily_active_users"] = yield hs.get_datastore().count_daily_users() + daily_messages = yield hs.get_datastore().count_daily_messages() + if daily_messages is not None: + stats["daily_messages"] = daily_messages + + logger.info("Reporting stats to matrix.org: %s" % (stats,)) + hs.get_simple_http_client().put_json( + "https://matrix.org/report-usage-stats/push", + stats + ) + + if hs.config.report_stats: + phone_home_task = task.LoopingCall(phone_stats_home) + phone_home_task.start(60 * 60 * 24, now=False) + def in_thread(): with LoggingContext("run"): change_resource_limit(hs.config.soft_file_limit) diff --git a/synapse/app/synctl.py b/synapse/app/synctl.py index 1f7d543c31..6bcc437591 100755 --- a/synapse/app/synctl.py +++ b/synapse/app/synctl.py @@ -25,6 +25,7 @@ SYNAPSE = ["python", "-B", "-m", "synapse.app.homeserver"] CONFIGFILE = "homeserver.yaml" GREEN = "\x1b[1;32m" +RED = "\x1b[1;31m" NORMAL = "\x1b[m" if not os.path.exists(CONFIGFILE): @@ -45,8 +46,15 @@ def start(): print "Starting ...", args = SYNAPSE args.extend(["--daemonize", "-c", CONFIGFILE]) - subprocess.check_call(args) - print GREEN + "started" + NORMAL + try: + subprocess.check_call(args) + print GREEN + "started" + NORMAL + except subprocess.CalledProcessError as e: + print ( + RED + + "error 
starting (exit code: %d); see above for logs" % e.returncode + + NORMAL + ) def stop(): diff --git a/synapse/config/_base.py b/synapse/config/_base.py index 8a75c48733..b9983f72a2 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -26,6 +26,16 @@ class ConfigError(Exception): class Config(object): + stats_reporting_begging_spiel = ( + "We would really appreciate it if you could help our project out by " + "reporting anonymized usage statistics from your homeserver. Only very " + "basic aggregate data (e.g. number of users) will be reported, but it " + "helps us to track the growth of the Matrix community, and helps us to " + "make Matrix a success, as well as to convince other networks that they " + "should peer with us.\n" + "Thank you." + ) + @staticmethod def parse_size(value): if isinstance(value, int) or isinstance(value, long): @@ -111,11 +121,14 @@ class Config(object): results.append(getattr(cls, name)(self, *args, **kargs)) return results - def generate_config(self, config_dir_path, server_name): + def generate_config(self, config_dir_path, server_name, report_stats=None): default_config = "# vim:ft=yaml\n" default_config += "\n\n".join(dedent(conf) for conf in self.invoke_all( - "default_config", config_dir_path, server_name + "default_config", + config_dir_path=config_dir_path, + server_name=server_name, + report_stats=report_stats, )) config = yaml.load(default_config) @@ -139,6 +152,12 @@ class Config(object): action="store_true", help="Generate a config file for the server name" ) + config_parser.add_argument( + "--report-stats", + action="store", + help="Stuff", + choices=["yes", "no"] + ) config_parser.add_argument( "--generate-keys", action="store_true", @@ -189,6 +208,11 @@ class Config(object): config_files.append(config_path) if config_args.generate_config: + if config_args.report_stats is None: + config_parser.error( + "Please specify either --report-stats=yes or --report-stats=no\n\n" + + cls.stats_reporting_begging_spiel + ) if not config_files: config_parser.error( "Must supply a config file.\nA config file can be automatically" @@ -211,7 +235,9 @@ class Config(object): os.makedirs(config_dir_path) with open(config_path, "wb") as config_file: config_bytes, config = obj.generate_config( - config_dir_path, server_name + config_dir_path=config_dir_path, + server_name=server_name, + report_stats=(config_args.report_stats == "yes"), ) obj.invoke_all("generate_files", config) config_file.write(config_bytes) @@ -261,9 +287,20 @@ class Config(object): specified_config.update(yaml_config) server_name = specified_config["server_name"] - _, config = obj.generate_config(config_dir_path, server_name) + _, config = obj.generate_config( + config_dir_path=config_dir_path, + server_name=server_name + ) config.pop("log_config") config.update(specified_config) + if "report_stats" not in config: + sys.stderr.write( + "Please opt in or out of reporting anonymized homeserver usage " + "statistics, by setting the report_stats key in your config file " + " ( " + config_path + " ) " + + "to either True or False.\n\n" + + Config.stats_reporting_begging_spiel + "\n") + sys.exit(1) if generate_keys: obj.invoke_all("generate_files", config) diff --git a/synapse/config/appservice.py b/synapse/config/appservice.py index 38f41933b7..b8d301995e 100644 --- a/synapse/config/appservice.py +++ b/synapse/config/appservice.py @@ -20,7 +20,7 @@ class AppServiceConfig(Config): def read_config(self, config): self.app_service_config_files = config.get("app_service_config_files", []) - def 
default_config(cls, config_dir_path, server_name): + def default_config(cls, **kwargs): return """\ # A list of application service config file to use app_service_config_files: [] diff --git a/synapse/config/captcha.py b/synapse/config/captcha.py index 15a132b4e3..dd92fcd0dc 100644 --- a/synapse/config/captcha.py +++ b/synapse/config/captcha.py @@ -24,7 +24,7 @@ class CaptchaConfig(Config): self.captcha_bypass_secret = config.get("captcha_bypass_secret") self.recaptcha_siteverify_api = config["recaptcha_siteverify_api"] - def default_config(self, config_dir_path, server_name): + def default_config(self, **kwargs): return """\ ## Captcha ## diff --git a/synapse/config/database.py b/synapse/config/database.py index f0611e8884..baeda8f300 100644 --- a/synapse/config/database.py +++ b/synapse/config/database.py @@ -45,7 +45,7 @@ class DatabaseConfig(Config): self.set_databasepath(config.get("database_path")) - def default_config(self, config, config_dir_path): + def default_config(self, **kwargs): database_path = self.abspath("homeserver.db") return """\ # Database configuration diff --git a/synapse/config/key.py b/synapse/config/key.py index 23ac8a3fca..2c187065e5 100644 --- a/synapse/config/key.py +++ b/synapse/config/key.py @@ -40,7 +40,7 @@ class KeyConfig(Config): config["perspectives"] ) - def default_config(self, config_dir_path, server_name): + def default_config(self, config_dir_path, server_name, **kwargs): base_key_name = os.path.join(config_dir_path, server_name) return """\ ## Signing Keys ## diff --git a/synapse/config/logger.py b/synapse/config/logger.py index daca698d0c..bd0c17c861 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py @@ -70,7 +70,7 @@ class LoggingConfig(Config): self.log_config = self.abspath(config.get("log_config")) self.log_file = self.abspath(config.get("log_file")) - def default_config(self, config_dir_path, server_name): + def default_config(self, config_dir_path, server_name, **kwargs): log_file = self.abspath("homeserver.log") log_config = self.abspath( os.path.join(config_dir_path, server_name + ".log.config") diff --git a/synapse/config/metrics.py b/synapse/config/metrics.py index ae5a691527..825fec9a38 100644 --- a/synapse/config/metrics.py +++ b/synapse/config/metrics.py @@ -19,13 +19,15 @@ from ._base import Config class MetricsConfig(Config): def read_config(self, config): self.enable_metrics = config["enable_metrics"] + self.report_stats = config.get("report_stats", None) self.metrics_port = config.get("metrics_port") self.metrics_bind_host = config.get("metrics_bind_host", "127.0.0.1") - def default_config(self, config_dir_path, server_name): - return """\ + def default_config(self, report_stats=None, **kwargs): + suffix = "" if report_stats is None else "report_stats: %(report_stats)s\n" + return ("""\ ## Metrics ### # Enable collection and rendering of performance metrics enable_metrics: False - """ + """ + suffix) % locals() diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py index 76d9970e5b..611b598ec7 100644 --- a/synapse/config/ratelimiting.py +++ b/synapse/config/ratelimiting.py @@ -27,7 +27,7 @@ class RatelimitConfig(Config): self.federation_rc_reject_limit = config["federation_rc_reject_limit"] self.federation_rc_concurrent = config["federation_rc_concurrent"] - def default_config(self, config_dir_path, server_name): + def default_config(self, **kwargs): return """\ ## Ratelimiting ## diff --git a/synapse/config/registration.py b/synapse/config/registration.py index 62de4b399f..fa98eced34 
100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -34,7 +34,7 @@ class RegistrationConfig(Config): self.registration_shared_secret = config.get("registration_shared_secret") self.macaroon_secret_key = config.get("macaroon_secret_key") - def default_config(self, config_dir, server_name): + def default_config(self, **kwargs): registration_shared_secret = random_string_with_symbols(50) macaroon_secret_key = random_string_with_symbols(50) return """\ diff --git a/synapse/config/repository.py b/synapse/config/repository.py index 64644b9a7a..2fcf872449 100644 --- a/synapse/config/repository.py +++ b/synapse/config/repository.py @@ -60,7 +60,7 @@ class ContentRepositoryConfig(Config): config["thumbnail_sizes"] ) - def default_config(self, config_dir_path, server_name): + def default_config(self, **kwargs): media_store = self.default_path("media_store") uploads_path = self.default_path("uploads") return """ diff --git a/synapse/config/saml2.py b/synapse/config/saml2.py index 1532036876..4c6133cf22 100644 --- a/synapse/config/saml2.py +++ b/synapse/config/saml2.py @@ -41,7 +41,7 @@ class SAML2Config(Config): self.saml2_config_path = None self.saml2_idp_redirect_url = None - def default_config(self, config_dir_path, server_name): + def default_config(self, config_dir_path, server_name, **kwargs): return """ # Enable SAML2 for registration and login. Uses pysaml2 # config_path: Path to the sp_conf.py configuration file diff --git a/synapse/config/server.py b/synapse/config/server.py index a03e55c223..4d12d49857 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -117,7 +117,7 @@ class ServerConfig(Config): self.content_addr = content_addr - def default_config(self, config_dir_path, server_name): + def default_config(self, server_name, **kwargs): if ":" in server_name: bind_port = int(server_name.split(":")[1]) unsecure_port = bind_port - 400 diff --git a/synapse/config/tls.py b/synapse/config/tls.py index e6023a718d..0ac2698293 100644 --- a/synapse/config/tls.py +++ b/synapse/config/tls.py @@ -50,7 +50,7 @@ class TlsConfig(Config): "use_insecure_ssl_client_just_for_testing_do_not_use" ) - def default_config(self, config_dir_path, server_name): + def default_config(self, config_dir_path, server_name, **kwargs): base_key_name = os.path.join(config_dir_path, server_name) tls_certificate_path = base_key_name + ".tls.crt" diff --git a/synapse/config/voip.py b/synapse/config/voip.py index a1707223d3..a093354ccd 100644 --- a/synapse/config/voip.py +++ b/synapse/config/voip.py @@ -22,7 +22,7 @@ class VoipConfig(Config): self.turn_shared_secret = config["turn_shared_secret"] self.turn_user_lifetime = self.parse_duration(config["turn_user_lifetime"]) - def default_config(self, config_dir_path, server_name): + def default_config(self, **kwargs): return """\ ## Turn ## diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 77cb1dbd81..b64c90d631 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -54,7 +54,7 @@ logger = logging.getLogger(__name__) # Remember to update this number every time a change is made to database # schema files, so the users will be informed on server restarts. 
-SCHEMA_VERSION = 23 +SCHEMA_VERSION = 24 dir_path = os.path.abspath(os.path.dirname(__file__)) @@ -126,6 +126,24 @@ class DataStore(RoomMemberStore, RoomStore, lock=False, ) + @defer.inlineCallbacks + def count_daily_users(self): + def _count_users(txn): + txn.execute( + "SELECT COUNT(DISTINCT user_id) AS users" + " FROM user_ips" + " WHERE last_seen > ?", + # This is close enough to a day for our purposes. + (int(self._clock.time_msec()) - (1000 * 60 * 60 * 24),) + ) + rows = self.cursor_to_dict(txn) + if rows: + return rows[0]["users"] + return 0 + + ret = yield self.runInteraction("count_users", _count_users) + defer.returnValue(ret) + def get_user_ip_and_agents(self, user): return self._simple_select_list( table="user_ips", diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 0a477e3122..2b51db9940 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - from _base import SQLBaseStore, _RollbackButIsFineException from twisted.internet import defer, reactor @@ -28,6 +27,7 @@ from canonicaljson import encode_canonical_json from contextlib import contextmanager import logging +import math import ujson as json logger = logging.getLogger(__name__) @@ -905,3 +905,59 @@ class EventsStore(SQLBaseStore): txn.execute(sql, (event.event_id,)) result = txn.fetchone() return result[0] if result else None + + @defer.inlineCallbacks + def count_daily_messages(self): + def _count_messages(txn): + now = self.hs.get_clock().time() + + txn.execute( + "SELECT reported_stream_token, reported_time FROM stats_reporting" + ) + last_reported = self.cursor_to_dict(txn) + + txn.execute( + "SELECT stream_ordering" + " FROM events" + " ORDER BY stream_ordering DESC" + " LIMIT 1" + ) + now_reporting = self.cursor_to_dict(txn) + if not now_reporting: + return None + now_reporting = now_reporting[0]["stream_ordering"] + + txn.execute("DELETE FROM stats_reporting") + txn.execute( + "INSERT INTO stats_reporting" + " (reported_stream_token, reported_time)" + " VALUES (?, ?)", + (now_reporting, now,) + ) + + if not last_reported: + return None + + # Close enough to correct for our purposes. + yesterday = (now - 24 * 60 * 60) + if math.fabs(yesterday - last_reported[0]["reported_time"]) > 60 * 60: + return None + + txn.execute( + "SELECT COUNT(*) as messages" + " FROM events NATURAL JOIN event_json" + " WHERE json like '%m.room.message%'" + " AND stream_ordering > ?" 
+ " AND stream_ordering <= ?", + ( + last_reported[0]["reported_stream_token"], + now_reporting, + ) + ) + rows = self.cursor_to_dict(txn) + if not rows: + return None + return rows[0]["messages"] + + ret = yield self.runInteraction("count_messages", _count_messages) + defer.returnValue(ret) diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index c9ceb132ae..6d76237658 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -289,3 +289,15 @@ class RegistrationStore(SQLBaseStore): if ret: defer.returnValue(ret['user_id']) defer.returnValue(None) + + @defer.inlineCallbacks + def count_all_users(self): + def _count_users(txn): + txn.execute("SELECT COUNT(*) AS users FROM users") + rows = self.cursor_to_dict(txn) + if rows: + return rows[0]["users"] + return 0 + + ret = yield self.runInteraction("count_users", _count_users) + defer.returnValue(ret) diff --git a/synapse/storage/schema/delta/24/stats_reporting.sql b/synapse/storage/schema/delta/24/stats_reporting.sql new file mode 100644 index 0000000000..e9165d2917 --- /dev/null +++ b/synapse/storage/schema/delta/24/stats_reporting.sql @@ -0,0 +1,22 @@ +/* Copyright 2015 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Should only ever contain one row +CREATE TABLE IF NOT EXISTS stats_reporting( + -- The stream ordering token which was most recently reported as stats + reported_stream_token INTEGER, + -- The time (seconds since epoch) stats were most recently reported + reported_time BIGINT +); diff --git a/tests/storage/event_injector.py b/tests/storage/event_injector.py new file mode 100644 index 0000000000..42bd8928bd --- /dev/null +++ b/tests/storage/event_injector.py @@ -0,0 +1,81 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +from tests import unittest +from twisted.internet import defer + +from synapse.api.constants import EventTypes, Membership +from synapse.types import UserID, RoomID + +from tests.utils import setup_test_homeserver + +from mock import Mock + + +class EventInjector: + def __init__(self, hs): + self.hs = hs + self.store = hs.get_datastore() + self.message_handler = hs.get_handlers().message_handler + self.event_builder_factory = hs.get_event_builder_factory() + + @defer.inlineCallbacks + def create_room(self, room): + builder = self.event_builder_factory.new({ + "type": EventTypes.Create, + "room_id": room.to_string(), + "content": {}, + }) + + event, context = yield self.message_handler._create_new_client_event( + builder + ) + + yield self.store.persist_event(event, context) + + @defer.inlineCallbacks + def inject_room_member(self, room, user, membership): + builder = self.event_builder_factory.new({ + "type": EventTypes.Member, + "sender": user.to_string(), + "state_key": user.to_string(), + "room_id": room.to_string(), + "content": {"membership": membership}, + }) + + event, context = yield self.message_handler._create_new_client_event( + builder + ) + + yield self.store.persist_event(event, context) + + defer.returnValue(event) + + @defer.inlineCallbacks + def inject_message(self, room, user, body): + builder = self.event_builder_factory.new({ + "type": EventTypes.Message, + "sender": user.to_string(), + "state_key": user.to_string(), + "room_id": room.to_string(), + "content": {"body": body, "msgtype": u"message"}, + }) + + event, context = yield self.message_handler._create_new_client_event( + builder + ) + + yield self.store.persist_event(event, context) diff --git a/tests/storage/test_events.py b/tests/storage/test_events.py new file mode 100644 index 0000000000..313013009e --- /dev/null +++ b/tests/storage/test_events.py @@ -0,0 +1,116 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
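+#
+# Exercises EventsStore.count_daily_messages end to end: the test seeds
+# events through EventInjector, steps the homeserver clock, and asserts on
+# the bookkeeping row that count_daily_messages keeps in the stats_reporting
+# table.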
+import uuid +from mock.mock import Mock +from synapse.types import RoomID, UserID + +from tests import unittest +from twisted.internet import defer +from tests.storage.event_injector import EventInjector + +from tests.utils import setup_test_homeserver + + +class EventsStoreTestCase(unittest.TestCase): + + @defer.inlineCallbacks + def setUp(self): + self.hs = yield setup_test_homeserver( + resource_for_federation=Mock(), + http_client=None, + ) + self.store = self.hs.get_datastore() + self.db_pool = self.hs.get_db_pool() + self.message_handler = self.hs.get_handlers().message_handler + self.event_injector = EventInjector(self.hs) + + @defer.inlineCallbacks + def test_count_daily_messages(self): + self.db_pool.runQuery("DELETE FROM stats_reporting") + + self.hs.clock.now = 100 + + # Never reported before, and nothing which could be reported + count = yield self.store.count_daily_messages() + self.assertIsNone(count) + count = yield self.db_pool.runQuery("SELECT COUNT(*) FROM stats_reporting") + self.assertEqual([(0,)], count) + + # Create something to report + room = RoomID.from_string("!abc123:test") + user = UserID.from_string("@raccoonlover:test") + yield self.event_injector.create_room(room) + + self.base_event = yield self._get_last_stream_token() + + yield self.event_injector.inject_message(room, user, "Raccoons are really cute") + + # Never reported before, something could be reported, but isn't because + # it isn't old enough. + count = yield self.store.count_daily_messages() + self.assertIsNone(count) + self._assert_stats_reporting(1, self.hs.clock.now) + + # Already reported yesterday, two new events from today. + yield self.event_injector.inject_message(room, user, "Yeah they are!") + yield self.event_injector.inject_message(room, user, "Incredibly!") + self.hs.clock.now += 60 * 60 * 24 + count = yield self.store.count_daily_messages() + self.assertEqual(2, count) # 2 since yesterday + self._assert_stats_reporting(3, self.hs.clock.now) # 3 ever + + # Last reported too recently. 
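+ # (Only 22 hours will have elapsed since the previous report, and the
+ # freshness check in count_daily_messages tolerates roughly an hour either
+ # side of a full day, so it returns None while still recording the latest
+ # stream token and report time.)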
+ yield self.event_injector.inject_message(room, user, "Who could disagree?") + self.hs.clock.now += 60 * 60 * 22 + count = yield self.store.count_daily_messages() + self.assertIsNone(count) + self._assert_stats_reporting(4, self.hs.clock.now) + + # Last reported too long ago + yield self.event_injector.inject_message(room, user, "No one.") + self.hs.clock.now += 60 * 60 * 26 + count = yield self.store.count_daily_messages() + self.assertIsNone(count) + self._assert_stats_reporting(5, self.hs.clock.now) + + # And now let's actually report something + yield self.event_injector.inject_message(room, user, "Indeed.") + yield self.event_injector.inject_message(room, user, "Indeed.") + yield self.event_injector.inject_message(room, user, "Indeed.") + # A little over 24 hours is fine :) + self.hs.clock.now += (60 * 60 * 24) + 50 + count = yield self.store.count_daily_messages() + self.assertEqual(3, count) + self._assert_stats_reporting(8, self.hs.clock.now) + + @defer.inlineCallbacks + def _get_last_stream_token(self): + rows = yield self.db_pool.runQuery( + "SELECT stream_ordering" + " FROM events" + " ORDER BY stream_ordering DESC" + " LIMIT 1" + ) + if not rows: + defer.returnValue(0) + else: + defer.returnValue(rows[0][0]) + + @defer.inlineCallbacks + def _assert_stats_reporting(self, messages, time): + rows = yield self.db_pool.runQuery( + "SELECT reported_stream_token, reported_time FROM stats_reporting" + ) + self.assertEqual([(self.base_event + messages, time,)], rows) diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py index ab7625a3ca..caffce64e3 100644 --- a/tests/storage/test_room.py +++ b/tests/storage/test_room.py @@ -85,7 +85,7 @@ class RoomEventsStoreTestCase(unittest.TestCase): # Room events need the full datastore, for persist_event() and # get_room_state() self.store = hs.get_datastore() - self.event_factory = hs.get_event_factory(); + self.event_factory = hs.get_event_factory() self.room = RoomID.from_string("!abcde:test") diff --git a/tests/storage/test_stream.py b/tests/storage/test_stream.py index 0c9b89d765..a658a789aa 100644 --- a/tests/storage/test_stream.py +++ b/tests/storage/test_stream.py @@ -19,6 +19,7 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, Membership from synapse.types import UserID, RoomID +from tests.storage.event_injector import EventInjector from tests.utils import setup_test_homeserver @@ -36,6 +37,7 @@ class StreamStoreTestCase(unittest.TestCase): self.store = hs.get_datastore() self.event_builder_factory = hs.get_event_builder_factory() + self.event_injector = EventInjector(hs) self.handlers = hs.get_handlers() self.message_handler = self.handlers.message_handler @@ -45,60 +47,20 @@ class StreamStoreTestCase(unittest.TestCase): self.room1 = RoomID.from_string("!abc123:test") self.room2 = RoomID.from_string("!xyx987:test") - self.depth = 1 - - @defer.inlineCallbacks - def inject_room_member(self, room, user, membership): - self.depth += 1 - - builder = self.event_builder_factory.new({ - "type": EventTypes.Member, - "sender": user.to_string(), - "state_key": user.to_string(), - "room_id": room.to_string(), - "content": {"membership": membership}, - }) - - event, context = yield self.message_handler._create_new_client_event( - builder - ) - - yield self.store.persist_event(event, context) - - defer.returnValue(event) - - @defer.inlineCallbacks - def inject_message(self, room, user, body): - self.depth += 1 - - builder = self.event_builder_factory.new({ - "type": EventTypes.Message, - "sender": 
user.to_string(), - "state_key": user.to_string(), - "room_id": room.to_string(), - "content": {"body": body, "msgtype": u"message"}, - }) - - event, context = yield self.message_handler._create_new_client_event( - builder - ) - - yield self.store.persist_event(event, context) - @defer.inlineCallbacks def test_event_stream_get_other(self): # Both bob and alice joins the room - yield self.inject_room_member( + yield self.event_injector.inject_room_member( self.room1, self.u_alice, Membership.JOIN ) - yield self.inject_room_member( + yield self.event_injector.inject_room_member( self.room1, self.u_bob, Membership.JOIN ) # Initial stream key: start = yield self.store.get_room_events_max_id() - yield self.inject_message(self.room1, self.u_alice, u"test") + yield self.event_injector.inject_message(self.room1, self.u_alice, u"test") end = yield self.store.get_room_events_max_id() @@ -125,17 +87,17 @@ class StreamStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_event_stream_get_own(self): # Both bob and alice joins the room - yield self.inject_room_member( + yield self.event_injector.inject_room_member( self.room1, self.u_alice, Membership.JOIN ) - yield self.inject_room_member( + yield self.event_injector.inject_room_member( self.room1, self.u_bob, Membership.JOIN ) # Initial stream key: start = yield self.store.get_room_events_max_id() - yield self.inject_message(self.room1, self.u_alice, u"test") + yield self.event_injector.inject_message(self.room1, self.u_alice, u"test") end = yield self.store.get_room_events_max_id() @@ -162,22 +124,22 @@ class StreamStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_event_stream_join_leave(self): # Both bob and alice joins the room - yield self.inject_room_member( + yield self.event_injector.inject_room_member( self.room1, self.u_alice, Membership.JOIN ) - yield self.inject_room_member( + yield self.event_injector.inject_room_member( self.room1, self.u_bob, Membership.JOIN ) # Then bob leaves again. 
- yield self.inject_room_member( + yield self.event_injector.inject_room_member( self.room1, self.u_bob, Membership.LEAVE ) # Initial stream key: start = yield self.store.get_room_events_max_id() - yield self.inject_message(self.room1, self.u_alice, u"test") + yield self.event_injector.inject_message(self.room1, self.u_alice, u"test") end = yield self.store.get_room_events_max_id() @@ -193,17 +155,17 @@ class StreamStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_event_stream_prev_content(self): - yield self.inject_room_member( + yield self.event_injector.inject_room_member( self.room1, self.u_bob, Membership.JOIN ) - event1 = yield self.inject_room_member( + event1 = yield self.event_injector.inject_room_member( self.room1, self.u_alice, Membership.JOIN ) start = yield self.store.get_room_events_max_id() - event2 = yield self.inject_room_member( + event2 = yield self.event_injector.inject_room_member( self.room1, self.u_alice, Membership.JOIN, ) -- cgit 1.5.1 From eb011cd99ba03f40f4ed7a023b64f93dfa2cbdc9 Mon Sep 17 00:00:00 2001 From: Daniel Wagner-Hall Date: Tue, 22 Sep 2015 13:29:36 +0100 Subject: Add docstring --- synapse/storage/events.py | 6 ++++++ 1 file changed, 6 insertions(+) (limited to 'synapse/storage') diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 2b51db9940..46df6b4d6d 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -908,6 +908,12 @@ class EventsStore(SQLBaseStore): @defer.inlineCallbacks def count_daily_messages(self): + """ + Returns an estimate of the number of messages sent in the last day. + + If it has been significantly less or more than one day since the last + call to this function, it will return None. + """ def _count_messages(txn): now = self.hs.get_clock().time() -- cgit 1.5.1 From 6d59ffe1ce9a821d50d491f97bf05950198f6f53 Mon Sep 17 00:00:00 2001 From: Daniel Wagner-Hall Date: Tue, 22 Sep 2015 13:47:40 +0100 Subject: Add some docstrings --- synapse/storage/__init__.py | 3 +++ synapse/storage/registration.py | 1 + 2 files changed, 4 insertions(+) (limited to 'synapse/storage') diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index b64c90d631..340e59afcb 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -128,6 +128,9 @@ class DataStore(RoomMemberStore, RoomStore, @defer.inlineCallbacks def count_daily_users(self): + """ + Counts the number of users who used this homeserver in the last 24 hours. 
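+
+ A user counts as active here if they appear in the user_ips table with a
+ last_seen timestamp within the past 24 hours.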
+ """ def _count_users(txn): txn.execute( "SELECT COUNT(DISTINCT user_id) AS users" diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 6d76237658..b454dd5b3a 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -292,6 +292,7 @@ class RegistrationStore(SQLBaseStore): @defer.inlineCallbacks def count_all_users(self): + """Counts all users registered on the homeserver.""" def _count_users(txn): txn.execute("SELECT COUNT(*) AS users FROM users") rows = self.cursor_to_dict(txn) -- cgit 1.5.1 From 527d95dea0d55bd1932639e61387ef60d834134e Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Tue, 22 Sep 2015 18:14:15 +0100 Subject: synapse/storage/_base.py:Table was unused --- synapse/storage/_base.py | 128 ---------------------------------------------- synapse/storage/pusher.py | 4 +- 2 files changed, 2 insertions(+), 130 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 495ef087c9..c1b5423bd6 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -25,8 +25,6 @@ from util.id_generators import IdGenerator, StreamIdGenerator from twisted.internet import defer -from collections import namedtuple - import sys import time import threading @@ -791,129 +789,3 @@ class _RollbackButIsFineException(Exception): something went wrong. """ pass - - -class Table(object): - """ A base class used to store information about a particular table. - """ - - table_name = None - """ str: The name of the table """ - - fields = None - """ list: The field names """ - - EntryType = None - """ Type: A tuple type used to decode the results """ - - _select_where_clause = "SELECT %s FROM %s WHERE %s" - _select_clause = "SELECT %s FROM %s" - _insert_clause = "REPLACE INTO %s (%s) VALUES (%s)" - - @classmethod - def select_statement(cls, where_clause=None): - """ - Args: - where_clause (str): The WHERE clause to use. - - Returns: - str: An SQL statement to select rows from the table with the given - WHERE clause. - """ - if where_clause: - return cls._select_where_clause % ( - ", ".join(cls.fields), - cls.table_name, - where_clause - ) - else: - return cls._select_clause % ( - ", ".join(cls.fields), - cls.table_name, - ) - - @classmethod - def insert_statement(cls): - return cls._insert_clause % ( - cls.table_name, - ", ".join(cls.fields), - ", ".join(["?"] * len(cls.fields)), - ) - - @classmethod - def decode_single_result(cls, results): - """ Given an iterable of tuples, return a single instance of - `EntryType` or None if the iterable is empty - Args: - results (list): The results list to convert to `EntryType` - Returns: - EntryType: An instance of `EntryType` - """ - results = list(results) - if results: - return cls.EntryType(*results[0]) - else: - return None - - @classmethod - def decode_results(cls, results): - """ Given an iterable of tuples, return a list of `EntryType` - Args: - results (list): The results list to convert to `EntryType` - - Returns: - list: A list of `EntryType` - """ - return [cls.EntryType(*row) for row in results] - - @classmethod - def get_fields_string(cls, prefix=None): - if prefix: - to_join = ("%s.%s" % (prefix, f) for f in cls.fields) - else: - to_join = cls.fields - - return ", ".join(to_join) - - -class JoinHelper(object): - """ Used to help do joins on tables by looking at the tables' fields and - creating a list of unique fields to use with SELECTs and a namedtuple - to dump the results into. 
- - Attributes: - tables (list): List of `Table` classes - EntryType (type) - """ - - def __init__(self, *tables): - self.tables = tables - - res = [] - for table in self.tables: - res += [f for f in table.fields if f not in res] - - self.EntryType = namedtuple("JoinHelperEntry", res) - - def get_fields(self, **prefixes): - """Get a string representing a list of fields for use in SELECT - statements with the given prefixes applied to each. - - For example:: - - JoinHelper(PdusTable, StateTable).get_fields( - PdusTable="pdus", - StateTable="state" - ) - """ - res = [] - for field in self.EntryType._fields: - for table in self.tables: - if field in table.fields: - res.append("%s.%s" % (prefixes[table.__name__], field)) - break - - return ", ".join(res) - - def decode_results(self, rows): - return [self.EntryType(*row) for row in rows] diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py index 00b748f131..345c4e1104 100644 --- a/synapse/storage/pusher.py +++ b/synapse/storage/pusher.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import SQLBaseStore, Table +from ._base import SQLBaseStore from twisted.internet import defer from synapse.api.errors import StoreError @@ -149,5 +149,5 @@ class PusherStore(SQLBaseStore): ) -class PushersTable(Table): +class PushersTable(object): table_name = "pushers" -- cgit 1.5.1 From 7dd4f79c49e1a1ba4cf2edf8b45ed841a32a33b0 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Tue, 22 Sep 2015 18:37:07 +0100 Subject: synapse/storage/_base.py:_execute_and_decode was unused --- synapse/storage/_base.py | 3 --- 1 file changed, 3 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index c1b5423bd6..cf4ec30f48 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -374,9 +374,6 @@ class SQLBaseStore(object): return self.runInteraction(desc, interaction) - def _execute_and_decode(self, desc, query, *args): - return self._execute(desc, self.cursor_to_dict, query, *args) - # "Simple" SQL API methods that operate on a single table with no JOINs, # no complex WHERE clauses, just a dict of values for columns. -- cgit 1.5.1 From 3559a835a22be3ae474fc2a14363a59b89ec05b8 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Tue, 22 Sep 2015 18:39:46 +0100 Subject: synapse/storage/event_federation.py:_get_auth_events is unused --- synapse/storage/event_federation.py | 21 --------------------- 1 file changed, 21 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py index c1cabbaa60..7ed0d9ae1b 100644 --- a/synapse/storage/event_federation.py +++ b/synapse/storage/event_federation.py @@ -225,27 +225,6 @@ class EventFederationStore(SQLBaseStore): return results - def _get_auth_events(self, txn, event_id): - auth_ids = self._simple_select_onecol_txn( - txn, - table="event_auth", - keyvalues={ - "event_id": event_id, - }, - retcol="auth_id", - ) - - results = [] - for auth_id in auth_ids: - hashes = self._get_event_reference_hashes_txn(txn, auth_id) - prev_hashes = { - k: encode_base64(v) for k, v in hashes.items() - if k == "sha256" - } - results.append((auth_id, prev_hashes)) - - return results - def get_min_depth(self, room_id): """ For the given room, get the minimum depth we have seen for it.
""" -- cgit 1.5.1 From 1ee3d26432d87ff312350f21da982f646b5af49a Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Wed, 23 Sep 2015 10:30:03 +0100 Subject: synapse/storage/_base.py:_simple_selectupdate_one was unused --- synapse/storage/_base.py | 31 ------------------------------- tests/storage/test_base.py | 20 -------------------- 2 files changed, 51 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index cf4ec30f48..79021bde6b 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -686,37 +686,6 @@ class SQLBaseStore(object): return dict(zip(retcols, row)) - def _simple_selectupdate_one(self, table, keyvalues, updatevalues=None, - retcols=None, allow_none=False, - desc="_simple_selectupdate_one"): - """ Combined SELECT then UPDATE.""" - def func(txn): - ret = None - if retcols: - ret = self._simple_select_one_txn( - txn, - table=table, - keyvalues=keyvalues, - retcols=retcols, - allow_none=allow_none, - ) - - if updatevalues: - self._simple_update_one_txn( - txn, - table=table, - keyvalues=keyvalues, - updatevalues=updatevalues, - ) - - # if txn.rowcount == 0: - # raise StoreError(404, "No row found") - if txn.rowcount > 1: - raise StoreError(500, "More than one row matched") - - return ret - return self.runInteraction(desc, func) - def _simple_delete_one(self, table, keyvalues, desc="_simple_delete_one"): """Executes a DELETE query on the named table, expecting to delete a single row. diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py index 8573f18b55..1ddca1da4c 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py @@ -185,26 +185,6 @@ class SQLBaseStoreTestCase(unittest.TestCase): [3, 4, 1, 2] ) - @defer.inlineCallbacks - def test_update_one_with_return(self): - self.mock_txn.rowcount = 1 - self.mock_txn.fetchone.return_value = ("Old Value",) - - ret = yield self.datastore._simple_selectupdate_one( - table="tablename", - keyvalues={"keycol": "TheKey"}, - updatevalues={"columname": "New Value"}, - retcols=["columname"] - ) - - self.assertEquals({"columname": "Old Value"}, ret) - self.mock_txn.execute.assert_has_calls([ - call('SELECT columname FROM tablename WHERE keycol = ?', - ['TheKey']), - call("UPDATE tablename SET columname = ? WHERE keycol = ?", - ["New Value", "TheKey"]) - ]) - @defer.inlineCallbacks def test_delete_one(self): self.mock_txn.rowcount = 1 -- cgit 1.5.1 From 1d9036aff2be9ce089114b1389cc86a34ccd8490 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Wed, 23 Sep 2015 10:30:25 +0100 Subject: synapse/storage/_base.py:_simple_delete was unused --- synapse/storage/_base.py | 10 ---------- 1 file changed, 10 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 79021bde6b..33751f3092 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -707,16 +707,6 @@ class SQLBaseStore(object): raise StoreError(500, "more than one row matched") return self.runInteraction(desc, func) - def _simple_delete(self, table, keyvalues, desc="_simple_delete"): - """Executes a DELETE query on the named table. 
- - Args: - table : string giving the table name - keyvalues : dict of column names and values to select the row with - """ - - return self.runInteraction(desc, self._simple_delete_txn) - def _simple_delete_txn(self, txn, table, keyvalues): sql = "DELETE FROM %s WHERE %s" % ( table, -- cgit 1.5.1 From 396834f1c010899cae87aac45143bbd9911adb4f Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Wed, 23 Sep 2015 10:30:38 +0100 Subject: synapse/storage/_base.py:_simple_max_id was unused --- synapse/storage/_base.py | 18 ------------------ 1 file changed, 18 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 33751f3092..693784ad38 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -715,24 +715,6 @@ class SQLBaseStore(object): return txn.execute(sql, keyvalues.values()) - def _simple_max_id(self, table): - """Executes a SELECT query on the named table, expecting to return the - max value for the column "id". - - Args: - table : string giving the table name - """ - sql = "SELECT MAX(id) AS id FROM %s" % table - - def func(txn): - txn.execute(sql) - max_id = self.cursor_to_dict(txn)[0]["id"] - if max_id is None: - return 0 - return max_id - - return self.runInteraction("_simple_max_id", func) - def get_next_stream_id(self): with self._next_stream_id_lock: i = self._next_stream_id -- cgit 1.5.1 From c292dba70cca26751cdab90a6e50200828959642 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Wed, 23 Sep 2015 10:31:25 +0100 Subject: Remove unused functions from synapse/storage/event_federation.py --- synapse/storage/event_federation.py | 71 ------------------------------------- 1 file changed, 71 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py index 7ed0d9ae1b..6d4421dd8f 100644 --- a/synapse/storage/event_federation.py +++ b/synapse/storage/event_federation.py @@ -154,77 +154,6 @@ class EventFederationStore(SQLBaseStore): return results - def _get_latest_state_in_room(self, txn, room_id, type, state_key): - event_ids = self._simple_select_onecol_txn( - txn, - table="state_forward_extremities", - keyvalues={ - "room_id": room_id, - "type": type, - "state_key": state_key, - }, - retcol="event_id", - ) - - results = [] - for event_id in event_ids: - hashes = self._get_event_reference_hashes_txn(txn, event_id) - prev_hashes = { - k: encode_base64(v) for k, v in hashes.items() - if k == "sha256" - } - results.append((event_id, prev_hashes)) - - return results - - def _get_prev_events(self, txn, event_id): - results = self._get_prev_events_and_state( - txn, - event_id, - is_state=0, - ) - - return [(e_id, h, ) for e_id, h, _ in results] - - def _get_prev_state(self, txn, event_id): - results = self._get_prev_events_and_state( - txn, - event_id, - is_state=True, - ) - - return [(e_id, h, ) for e_id, h, _ in results] - - def _get_prev_events_and_state(self, txn, event_id, is_state=None): - keyvalues = { - "event_id": event_id, - } - - if is_state is not None: - keyvalues["is_state"] = bool(is_state) - - res = self._simple_select_list_txn( - txn, - table="event_edges", - keyvalues=keyvalues, - retcols=["prev_event_id", "is_state"], - ) - - hashes = self._get_prev_event_hashes_txn(txn, event_id) - - results = [] - for d in res: - edge_hash = self._get_event_reference_hashes_txn(txn, d["prev_event_id"]) - edge_hash.update(hashes.get(d["prev_event_id"], {})) - prev_hashes = { - k: encode_base64(v) - for k, v in edge_hash.items() - if k == "sha256" - } 
- results.append((d["prev_event_id"], prev_hashes, d["is_state"])) - - return results - def get_min_depth(self, room_id): """ For hte given room, get the minimum depth we have seen for it. """ -- cgit 1.5.1 From 92d8d724c5c34f0a83cbd8c5dce7f0c0c21a1568 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Wed, 23 Sep 2015 10:33:06 +0100 Subject: Remove unused functions from synapse/storage/events.py --- synapse/storage/events.py | 11 ----------- 1 file changed, 11 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 46df6b4d6d..416ef6af93 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -890,22 +890,11 @@ class EventsStore(SQLBaseStore): return ev - def _parse_events(self, rows): - return self.runInteraction( - "_parse_events", self._parse_events_txn, rows - ) - def _parse_events_txn(self, txn, rows): event_ids = [r["event_id"] for r in rows] return self._get_events_txn(txn, event_ids) - def _has_been_redacted_txn(self, txn, event): - sql = "SELECT event_id FROM redactions WHERE redacts = ?" - txn.execute(sql, (event.event_id,)) - result = txn.fetchone() - return result[0] if result else None - @defer.inlineCallbacks def count_daily_messages(self): """ -- cgit 1.5.1 From e51aa4be9691e7c6918d72810ee0b751f2e48797 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Wed, 23 Sep 2015 10:35:10 +0100 Subject: synapse/storage/roommember.py:_get_members_query was unused --- synapse/storage/roommember.py | 6 ------ 1 file changed, 6 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index cd9eefbd9f..e17cbe677a 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -178,12 +178,6 @@ class RoomMemberStore(SQLBaseStore): return joined_domains - def _get_members_query(self, where_clause, where_values): - return self.runInteraction( - "get_members_query", self._get_members_events_txn, - where_clause, where_values - ).addCallbacks(self._get_events) - def _get_members_events_txn(self, txn, room_id, membership=None, user_id=None): rows = self._get_members_rows_txn( txn, -- cgit 1.5.1 From 973ebb66bacfeece1f88f8ea71b86186f9c9163e Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Wed, 23 Sep 2015 10:36:33 +0100 Subject: Remove unused functions from synapse/storage/signatures.py --- synapse/storage/signatures.py | 112 ------------------------------------------ 1 file changed, 112 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/signatures.py b/synapse/storage/signatures.py index ab57b92174..b070be504d 100644 --- a/synapse/storage/signatures.py +++ b/synapse/storage/signatures.py @@ -24,41 +24,6 @@ from synapse.crypto.event_signing import compute_event_reference_hash class SignatureStore(SQLBaseStore): """Persistence for event signatures and hashes""" - def _get_event_content_hashes_txn(self, txn, event_id): - """Get all the hashes for a given Event. - Args: - txn (cursor): - event_id (str): Id for the Event. - Returns: - A dict of algorithm -> hash. - """ - query = ( - "SELECT algorithm, hash" - " FROM event_content_hashes" - " WHERE event_id = ?" - ) - txn.execute(query, (event_id, )) - return dict(txn.fetchall()) - - def _store_event_content_hash_txn(self, txn, event_id, algorithm, - hash_bytes): - """Store a hash for a Event - Args: - txn (cursor): - event_id (str): Id for the Event. - algorithm (str): Hashing algorithm. - hash_bytes (bytes): Hash function output bytes. 
- """ - self._simple_insert_txn( - txn, - "event_content_hashes", - { - "event_id": event_id, - "algorithm": algorithm, - "hash": buffer(hash_bytes), - }, - ) - def get_event_reference_hashes(self, event_ids): def f(txn): return [ @@ -123,80 +88,3 @@ class SignatureStore(SQLBaseStore): table="event_reference_hashes", values=vals, ) - - def _get_event_signatures_txn(self, txn, event_id): - """Get all the signatures for a given PDU. - Args: - txn (cursor): - event_id (str): Id for the Event. - Returns: - A dict of sig name -> dict(key_id -> signature_bytes) - """ - query = ( - "SELECT signature_name, key_id, signature" - " FROM event_signatures" - " WHERE event_id = ? " - ) - txn.execute(query, (event_id, )) - rows = txn.fetchall() - - res = {} - - for name, key, sig in rows: - res.setdefault(name, {})[key] = sig - - return res - - def _store_event_signature_txn(self, txn, event_id, signature_name, key_id, - signature_bytes): - """Store a signature from the origin server for a PDU. - Args: - txn (cursor): - event_id (str): Id for the Event. - origin (str): origin of the Event. - key_id (str): Id for the signing key. - signature (bytes): The signature. - """ - self._simple_insert_txn( - txn, - "event_signatures", - { - "event_id": event_id, - "signature_name": signature_name, - "key_id": key_id, - "signature": buffer(signature_bytes), - }, - ) - - def _get_prev_event_hashes_txn(self, txn, event_id): - """Get all the hashes for previous PDUs of a PDU - Args: - txn (cursor): - event_id (str): Id for the Event. - Returns: - dict of (pdu_id, origin) -> dict of algorithm -> hash_bytes. - """ - query = ( - "SELECT prev_event_id, algorithm, hash" - " FROM event_edge_hashes" - " WHERE event_id = ?" - ) - txn.execute(query, (event_id, )) - results = {} - for prev_event_id, algorithm, hash_bytes in txn.fetchall(): - hashes = results.setdefault(prev_event_id, {}) - hashes[algorithm] = hash_bytes - return results - - def _store_prev_event_hash_txn(self, txn, event_id, prev_event_id, - algorithm, hash_bytes): - self._simple_insert_txn( - txn, - "event_edge_hashes", - { - "event_id": event_id, - "prev_event_id": prev_event_id, - "algorithm": algorithm, - "hash": buffer(hash_bytes), - }, - ) -- cgit 1.5.1 From 1cd65a8d1e9fe30506aa7f19c60e6271a24e6ae9 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Wed, 23 Sep 2015 10:37:58 +0100 Subject: synapse/storage/state.py: _make_group_id was unused --- synapse/storage/state.py | 6 ------ 1 file changed, 6 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 9630efcfcc..e935b9443b 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -20,8 +20,6 @@ from synapse.util.caches.descriptors import ( from twisted.internet import defer -from synapse.util.stringutils import random_string - import logging logger = logging.getLogger(__name__) @@ -428,7 +426,3 @@ class StateStore(SQLBaseStore): } defer.returnValue(results) - - -def _make_group_id(clock): - return str(int(clock.time_msec())) + random_string(5) -- cgit 1.5.1 From cf1100887b454535e25dfeb67d649c2a4673eab7 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Thu, 24 Sep 2015 17:34:02 +0100 Subject: Fix order of ON constraints in _get_rooms_for_user_where_membership_is_txn --- synapse/storage/roommember.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index cd9eefbd9f..41c939efb1 100644 --- 
a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -142,12 +142,12 @@ class RoomMemberStore(SQLBaseStore): sql = ( "SELECT m.room_id, m.sender, m.membership, m.event_id, e.stream_ordering" - " FROM room_memberships as m" - " INNER JOIN current_state_events as c" - " ON e.event_id = c.event_id " - " INNER JOIN events as e " - " ON m.event_id = c.event_id " - " AND m.room_id = c.room_id " + " FROM current_state_events as c" + " INNER JOIN room_memberships as m" + " ON m.event_id = c.event_id" + " INNER JOIN events as e" + " ON e.event_id = c.event_id" + " AND m.room_id = c.room_id" " AND m.user_id = c.state_key" " WHERE %s" ) % (where_clause,) -- cgit 1.5.1 From c85c9125627a62c73711786723be12be30d7a81e Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 9 Oct 2015 15:48:31 +0100 Subject: Add basic full text search impl. --- synapse/api/constants.py | 19 +++++++ synapse/handlers/__init__.py | 2 + synapse/handlers/search.py | 95 ++++++++++++++++++++++++++++++++++ synapse/rest/client/v1/room.py | 17 ++++++ synapse/storage/__init__.py | 2 + synapse/storage/_base.py | 2 +- synapse/storage/schema/delta/24/fts.py | 57 ++++++++++++++++++++ synapse/storage/search.py | 75 +++++++++++++++++++++++++++ 8 files changed, 268 insertions(+), 1 deletion(-) create mode 100644 synapse/handlers/search.py create mode 100644 synapse/storage/schema/delta/24/fts.py create mode 100644 synapse/storage/search.py (limited to 'synapse/storage') diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 008ee64727..7c7f9ff957 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -84,3 +84,22 @@ class RoomCreationPreset(object): PRIVATE_CHAT = "private_chat" PUBLIC_CHAT = "public_chat" TRUSTED_PRIVATE_CHAT = "trusted_private_chat" + + +class SearchConstraintTypes(object): + FTS = "fts" + EXACT = "exact" + PREFIX = "prefix" + SUBSTRING = "substring" + RANGE = "range" + + +class KnownRoomEventKeys(object): + CONTENT_BODY = "content.body" + CONTENT_MSGTYPE = "content.msgtype" + CONTENT_NAME = "content.name" + CONTENT_TOPIC = "content.topic" + + SENDER = "sender" + ORIGIN_SERVER_TS = "origin_server_ts" + ROOM_ID = "room_id" diff --git a/synapse/handlers/__init__.py b/synapse/handlers/__init__.py index 8725c3c420..87b4d381c7 100644 --- a/synapse/handlers/__init__.py +++ b/synapse/handlers/__init__.py @@ -32,6 +32,7 @@ from .sync import SyncHandler from .auth import AuthHandler from .identity import IdentityHandler from .receipts import ReceiptsHandler +from .search import SearchHandler class Handlers(object): @@ -68,3 +69,4 @@ class Handlers(object): self.sync_handler = SyncHandler(hs) self.auth_handler = AuthHandler(hs) self.identity_handler = IdentityHandler(hs) + self.search_handler = SearchHandler(hs) diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py new file mode 100644 index 0000000000..8b997fc394 --- /dev/null +++ b/synapse/handlers/search.py @@ -0,0 +1,95 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +from ._base import BaseHandler + +from synapse.api.constants import KnownRoomEventKeys, SearchConstraintTypes +from synapse.api.errors import SynapseError +from synapse.events.utils import serialize_event + +import logging + + +logger = logging.getLogger(__name__) + + +KEYS_TO_ALLOWED_CONSTRAINT_TYPES = { + KnownRoomEventKeys.CONTENT_BODY: [SearchConstraintTypes.FTS], + KnownRoomEventKeys.CONTENT_MSGTYPE: [SearchConstraintTypes.EXACT], + KnownRoomEventKeys.CONTENT_NAME: [SearchConstraintTypes.FTS, SearchConstraintTypes.EXACT, SearchConstraintTypes.SUBSTRING], + KnownRoomEventKeys.CONTENT_TOPIC: [SearchConstraintTypes.FTS], + KnownRoomEventKeys.SENDER: [SearchConstraintTypes.EXACT], + KnownRoomEventKeys.ORIGIN_SERVER_TS: [SearchConstraintTypes.RANGE], + KnownRoomEventKeys.ROOM_ID: [SearchConstraintTypes.EXACT], +} + + +class RoomConstraint(object): + def __init__(self, search_type, keys, value): + self.search_type = search_type + self.keys = keys + self.value = value + + @classmethod + def from_dict(cls, d): + search_type = d["type"] + keys = d["keys"] + + for key in keys: + if key not in KEYS_TO_ALLOWED_CONSTRAINT_TYPES: + raise SynapseError(400, "Unrecognized key %r", key) + + if search_type not in KEYS_TO_ALLOWED_CONSTRAINT_TYPES[key]: + raise SynapseError(400, "Disallowed constraint type %r for key %r", search_type, key) + + return cls(search_type, keys, d["value"]) + + +class SearchHandler(BaseHandler): + + def __init__(self, hs): + super(SearchHandler, self).__init__(hs) + + @defer.inlineCallbacks + def search(self, content): + constraint_dicts = content["search_categories"]["room_events"]["constraints"] + constraints = [RoomConstraint.from_dict(c)for c in constraint_dicts] + + fts = False + for c in constraints: + if c.search_type == SearchConstraintTypes.FTS: + if fts: + raise SynapseError(400, "Only one constraint can be FTS") + fts = True + + res = yield self.hs.get_datastore().search_msgs(constraints) + + time_now = self.hs.get_clock().time_msec() + + results = [ + { + "rank": r["rank"], + "result": serialize_event(r["result"], time_now) + } + for r in res + ] + + logger.info("returning: %r", results) + + results.sort(key=lambda r: -r["rank"]) + + defer.returnValue(results) diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 23871f161e..35bd702a43 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -529,6 +529,22 @@ class RoomTypingRestServlet(ClientV1RestServlet): defer.returnValue((200, {})) +class SearchRestServlet(ClientV1RestServlet): + PATTERN = client_path_pattern( + "/search$" + ) + + @defer.inlineCallbacks + def on_POST(self, request): + auth_user, _ = yield self.auth.get_user_by_req(request) + + content = _parse_json(request) + + results = yield self.handlers.search_handler.search(content) + + defer.returnValue((200, results)) + + def _parse_json(request): try: content = json.loads(request.content.read()) @@ -585,3 +601,4 @@ def register_servlets(hs, http_server): RoomInitialSyncRestServlet(hs).register(http_server) RoomRedactEventRestServlet(hs).register(http_server) RoomTypingRestServlet(hs).register(http_server) + SearchRestServlet(hs).register(http_server) diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 340e59afcb..5f91ef77c0 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -40,6 +40,7 @@ from .filtering 
import FilteringStore from .end_to_end_keys import EndToEndKeyStore from .receipts import ReceiptsStore +from .search import SearchStore import fnmatch @@ -79,6 +80,7 @@ class DataStore(RoomMemberStore, RoomStore, EventsStore, ReceiptsStore, EndToEndKeyStore, + SearchStore, ): def __init__(self, hs): diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 693784ad38..218e708054 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -519,7 +519,7 @@ class SQLBaseStore(object): allow_none=False, desc="_simple_select_one_onecol"): """Executes a SELECT query on the named table, which is expected to - return a single row, returning a single column from it." + return a single row, returning a single column from it. Args: table : string giving the table name diff --git a/synapse/storage/schema/delta/24/fts.py b/synapse/storage/schema/delta/24/fts.py new file mode 100644 index 0000000000..5680332758 --- /dev/null +++ b/synapse/storage/schema/delta/24/fts.py @@ -0,0 +1,57 @@ +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from synapse.storage import get_statements +from synapse.storage.engines import PostgresEngine + +logger = logging.getLogger(__name__) + + +POSTGRES_SQL = """ +CREATE TABLE event_search ( + event_id TEXT, + room_id TEXT, + key TEXT, + vector tsvector +); + +INSERT INTO event_search SELECT + event_id, room_id, 'content.body', + to_tsvector('english', json::json->'content'->>'body') + FROM events NATURAL JOIN event_json WHERE type = 'm.room.message'; + +INSERT INTO event_search SELECT + event_id, room_id, 'content.name', + to_tsvector('english', json::json->'content'->>'name') + FROM events NATURAL JOIN event_json WHERE type = 'm.room.name'; + +INSERT INTO event_search SELECT + event_id, room_id, 'content.topic', + to_tsvector('english', json::json->'content'->>'topic') + FROM events NATURAL JOIN event_json WHERE type = 'm.room.topic'; + + +CREATE INDEX event_search_idx ON event_search USING gin(vector); +""" + + +def run_upgrade(cur, database_engine, *args, **kwargs): + if not isinstance(database_engine, PostgresEngine): + # We only support FTS for postgres currently. + return + + for statement in get_statements(POSTGRES_SQL.splitlines()): + cur.execute(statement) diff --git a/synapse/storage/search.py b/synapse/storage/search.py new file mode 100644 index 0000000000..eea4477765 --- /dev/null +++ b/synapse/storage/search.py @@ -0,0 +1,75 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +from _base import SQLBaseStore +from synapse.api.constants import KnownRoomEventKeys, SearchConstraintTypes + + +class SearchStore(SQLBaseStore): + @defer.inlineCallbacks + def search_msgs(self, constraints): + clauses = [] + args = [] + fts = None + + for c in constraints: + local_clauses = [] + if c.search_type == SearchConstraintTypes.FTS: + fts = c.value + for key in c.keys: + local_clauses.append("key = ?") + args.append(key) + elif c.search_type == SearchConstraintTypes.EXACT: + for key in c.keys: + if key == KnownRoomEventKeys.ROOM_ID: + for value in c.value: + local_clauses.append("room_id = ?") + args.append(value) + clauses.append( + "(%s)" % (" OR ".join(local_clauses),) + ) + + sql = ( + "SELECT ts_rank_cd(vector, query) AS rank, event_id" + " FROM plainto_tsquery('english', ?) as query, event_search" + " WHERE vector @@ query" + ) + + for clause in clauses: + sql += " AND " + clause + + sql += " ORDER BY rank DESC" + + results = yield self._execute( + "search_msgs", self.cursor_to_dict, sql, *([fts] + args) + ) + + events = yield self._get_events([r["event_id"] for r in results]) + + event_map = { + ev.event_id: ev + for ev in events + } + + defer.returnValue([ + { + "rank": r["rank"], + "result": event_map[r["event_id"]] + } + for r in results + if r["event_id"] in event_map + ]) -- cgit 1.5.1 From 61561b9df791ec90e287e535cc75831c2016bf36 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 12 Oct 2015 10:49:53 +0100 Subject: Keep FTS indexes up to date. Only search through rooms currently joined --- synapse/handlers/search.py | 31 ++++++++++++++++++++++--------- synapse/rest/client/v1/room.py | 2 +- synapse/storage/events.py | 2 ++ synapse/storage/room.py | 22 ++++++++++++++++++++++ synapse/storage/schema/delta/24/fts.py | 3 ++- synapse/storage/search.py | 7 ++++++- 6 files changed, 55 insertions(+), 12 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index 8b997fc394..b6bdb752e9 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -65,7 +65,7 @@ class SearchHandler(BaseHandler): super(SearchHandler, self).__init__(hs) @defer.inlineCallbacks - def search(self, content): + def search(self, user, content): constraint_dicts = content["search_categories"]["room_events"]["constraints"] constraints = [RoomConstraint.from_dict(c)for c in constraint_dicts] @@ -76,20 +76,33 @@ class SearchHandler(BaseHandler): raise SynapseError(400, "Only one constraint can be FTS") fts = True - res = yield self.hs.get_datastore().search_msgs(constraints) + rooms = yield self.store.get_rooms_for_user( + user.to_string(), + ) - time_now = self.hs.get_clock().time_msec() + # For some reason the list of events contains duplicates + # TODO(paul): work out why because I really don't think it should + room_ids = set(r.room_id for r in rooms) - results = [ - { + res = yield self.store.search_msgs(room_ids, constraints) + + time_now = self.clock.time_msec() + + results = { + r["result"].event_id: { "rank": r["rank"], "result": serialize_event(r["result"], time_now) } for r in res - ] + } logger.info("returning: %r", results) - results.sort(key=lambda r: -r["rank"]) - - defer.returnValue(results) + defer.returnValue({ + "search_categories": { + "room_events": { + "results": results, + "count": len(results) + } + } + }) diff --git a/synapse/rest/client/v1/room.py 
b/synapse/rest/client/v1/room.py index 35bd702a43..94adabca62 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -540,7 +540,7 @@ class SearchRestServlet(ClientV1RestServlet): content = _parse_json(request) - results = yield self.handlers.search_handler.search(content) + results = yield self.handlers.search_handler.search(auth_user, content) defer.returnValue((200, results)) diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 416ef6af93..e6c1abfc27 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -307,6 +307,8 @@ class EventsStore(SQLBaseStore): self._store_room_name_txn(txn, event) elif event.type == EventTypes.Topic: self._store_room_topic_txn(txn, event) + elif event.type == EventTypes.Message: + self._store_room_message_txn(txn, event) elif event.type == EventTypes.Redaction: self._store_redaction(txn, event) diff --git a/synapse/storage/room.py b/synapse/storage/room.py index 5e07b7e0e5..e4e830944a 100644 --- a/synapse/storage/room.py +++ b/synapse/storage/room.py @@ -175,6 +175,10 @@ class RoomStore(SQLBaseStore): }, ) + self._store_event_search_txn( + txn, event, "content.topic", event.content["topic"] + ) + def _store_room_name_txn(self, txn, event): if hasattr(event, "content") and "name" in event.content: self._simple_insert_txn( @@ -187,6 +191,24 @@ class RoomStore(SQLBaseStore): } ) + self._store_event_search_txn( + txn, event, "content.name", event.content["name"] + ) + + def _store_room_message_txn(self, txn, event): + if hasattr(event, "content") and "body" in event.content: + self._store_event_search_txn( + txn, event, "content.body", event.content["body"] + ) + + def _store_event_search_txn(self, txn, event, key, value): + sql = ( + "INSERT INTO event_search (event_id, room_id, key, vector)" + " VALUES (?,?,?,to_tsvector('english', ?))" + ) + + txn.execute(sql, (event.event_id, event.room_id, key, value,)) + @cachedInlineCallbacks() def get_room_name_and_aliases(self, room_id): def f(txn): diff --git a/synapse/storage/schema/delta/24/fts.py b/synapse/storage/schema/delta/24/fts.py index 5680332758..05f1605fdd 100644 --- a/synapse/storage/schema/delta/24/fts.py +++ b/synapse/storage/schema/delta/24/fts.py @@ -44,7 +44,8 @@ INSERT INTO event_search SELECT FROM events NATURAL JOIN event_json WHERE type = 'm.room.topic'; -CREATE INDEX event_search_idx ON event_search USING gin(vector); +CREATE INDEX event_search_fts_idx ON event_search USING gin(vector); +CREATE INDEX event_search_ev_idx ON event_search(event_id); """ diff --git a/synapse/storage/search.py b/synapse/storage/search.py index eea4477765..e66b5f9edc 100644 --- a/synapse/storage/search.py +++ b/synapse/storage/search.py @@ -21,11 +21,16 @@ from synapse.api.constants import KnownRoomEventKeys, SearchConstraintTypes class SearchStore(SQLBaseStore): @defer.inlineCallbacks - def search_msgs(self, constraints): + def search_msgs(self, room_ids, constraints): clauses = [] args = [] fts = None + clauses.append( + "room_id IN (%s)" % (",".join(["?"] * len(room_ids)),) + ) + args.extend(room_ids) + for c in constraints: local_clauses = [] if c.search_type == SearchConstraintTypes.FTS: -- cgit 1.5.1 From 927004e34905d4ad6a69576ee1799fe8019d8985 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 12 Oct 2015 15:06:14 +0100 Subject: Remove unused room_id parameter --- synapse/handlers/federation.py | 2 +- synapse/handlers/message.py | 10 ++++----- synapse/handlers/search.py | 50 +++++++++++++++++++++++++++++++++++++++++- synapse/handlers/sync.py 
| 2 +- synapse/storage/state.py | 11 +++++----- 5 files changed, 61 insertions(+), 14 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 3882ba79ed..a710bdcfdb 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -242,7 +242,7 @@ class FederationHandler(BaseHandler): @defer.inlineCallbacks def _filter_events_for_server(self, server_name, room_id, events): event_to_state = yield self.store.get_state_for_events( - room_id, frozenset(e.event_id for e in events), + frozenset(e.event_id for e in events), types=( (EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, None), diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 30949ff7a6..d2f0892f7a 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -164,7 +164,7 @@ class MessageHandler(BaseHandler): @defer.inlineCallbacks def _filter_events_for_client(self, user_id, room_id, events): event_id_to_state = yield self.store.get_state_for_events( - room_id, frozenset(e.event_id for e in events), + frozenset(e.event_id for e in events), types=( (EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, user_id), @@ -290,7 +290,7 @@ class MessageHandler(BaseHandler): elif member_event.membership == Membership.LEAVE: key = (event_type, state_key) room_state = yield self.store.get_state_for_events( - room_id, [member_event.event_id], [key] + [member_event.event_id], [key] ) data = room_state[member_event.event_id].get(key) @@ -314,7 +314,7 @@ class MessageHandler(BaseHandler): room_state = yield self.state_handler.get_current_state(room_id) elif member_event.membership == Membership.LEAVE: room_state = yield self.store.get_state_for_events( - room_id, [member_event.event_id], None + [member_event.event_id], None ) room_state = room_state[member_event.event_id] @@ -403,7 +403,7 @@ class MessageHandler(BaseHandler): elif event.membership == Membership.LEAVE: room_end_token = "s%d" % (event.stream_ordering,) deferred_room_state = self.store.get_state_for_events( - event.room_id, [event.event_id], None + [event.event_id], None ) deferred_room_state.addCallback( lambda states: states[event.event_id] @@ -496,7 +496,7 @@ class MessageHandler(BaseHandler): def _room_initial_sync_parted(self, user_id, room_id, pagin_config, member_event): room_state = yield self.store.get_state_for_events( - member_event.room_id, [member_event.event_id], None + [member_event.event_id], None ) room_state = room_state[member_event.event_id] diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index 9dc474aa56..71182a8fe0 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -17,7 +17,9 @@ from twisted.internet import defer from ._base import BaseHandler -from synapse.api.constants import KnownRoomEventKeys, SearchConstraintTypes +from synapse.api.constants import ( + EventTypes, KnownRoomEventKeys, Membership, SearchConstraintTypes +) from synapse.api.errors import SynapseError from synapse.events.utils import serialize_event @@ -71,6 +73,52 @@ class SearchHandler(BaseHandler): def __init__(self, hs): super(SearchHandler, self).__init__(hs) + @defer.inlineCallbacks + def _filter_events_for_client(self, user_id, room_id, events): + event_id_to_state = yield self.store.get_state_for_events( + frozenset(e.event_id for e in events), + types=( + (EventTypes.RoomHistoryVisibility, ""), + (EventTypes.Member, user_id), + ) + ) + + def allowed(event, state): + if event.type == 
EventTypes.RoomHistoryVisibility: + return True + + membership_ev = state.get((EventTypes.Member, user_id), None) + if membership_ev: + membership = membership_ev.membership + else: + membership = Membership.LEAVE + + if membership == Membership.JOIN: + return True + + history = state.get((EventTypes.RoomHistoryVisibility, ''), None) + if history: + visibility = history.content.get("history_visibility", "shared") + else: + visibility = "shared" + + if visibility == "public": + return True + elif visibility == "shared": + return True + elif visibility == "joined": + return membership == Membership.JOIN + elif visibility == "invited": + return membership == Membership.INVITE + + return True + + defer.returnValue([ + event + for event in events + if allowed(event, event_id_to_state[event.event_id]) + ]) + @defer.inlineCallbacks def search(self, user, content): constraint_dicts = content["search_categories"]["room_events"]["constraints"] diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 9914ff6f9c..a8940de166 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -312,7 +312,7 @@ class SyncHandler(BaseHandler): @defer.inlineCallbacks def _filter_events_for_client(self, user_id, room_id, events): event_id_to_state = yield self.store.get_state_for_events( - room_id, frozenset(e.event_id for e in events), + frozenset(e.event_id for e in events), types=( (EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, user_id), diff --git a/synapse/storage/state.py b/synapse/storage/state.py index e935b9443b..acfb322a53 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -54,7 +54,7 @@ class StateStore(SQLBaseStore): defer.returnValue({}) event_to_groups = yield self._get_state_group_for_events( - room_id, event_ids, + event_ids, ) groups = set(event_to_groups.values()) @@ -208,13 +208,12 @@ class StateStore(SQLBaseStore): ) @defer.inlineCallbacks - def get_state_for_events(self, room_id, event_ids, types): + def get_state_for_events(self, event_ids, types): """Given a list of event_ids and type tuples, return a list of state dicts for each event. The state dicts will only have the type/state_keys that are in the `types` list. Args: - room_id (str) event_ids (list) types (list): List of (type, state_key) tuples which are used to filter the state fetched. 
`state_key` may be None, which matches @@ -225,7 +224,7 @@ class StateStore(SQLBaseStore): The dicts are mappings from (type, state_key) -> state_events """ event_to_groups = yield self._get_state_group_for_events( - room_id, event_ids, + event_ids, ) groups = set(event_to_groups.values()) @@ -251,8 +250,8 @@ class StateStore(SQLBaseStore): ) @cachedList(cache=_get_state_group_for_event.cache, list_name="event_ids", - num_args=2) - def _get_state_group_for_events(self, room_id, event_ids): + num_args=1) + def _get_state_group_for_events(self, event_ids): """Returns mapping event_id -> state_group """ def f(txn): -- cgit 1.5.1 From ca53ad74250d94b8c9b6581e6cedef0a29520fc2 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 12 Oct 2015 15:52:55 +0100 Subject: Filter events to only those that the user is allowed to see --- synapse/handlers/search.py | 16 ++++++++++------ synapse/storage/search.py | 14 +++++++------- 2 files changed, 17 insertions(+), 13 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index 71182a8fe0..49b786dadb 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -74,7 +74,7 @@ class SearchHandler(BaseHandler): super(SearchHandler, self).__init__(hs) @defer.inlineCallbacks - def _filter_events_for_client(self, user_id, room_id, events): + def _filter_events_for_client(self, user_id, events): event_id_to_state = yield self.store.get_state_for_events( frozenset(e.event_id for e in events), types=( @@ -139,16 +139,20 @@ class SearchHandler(BaseHandler): # TODO(paul): work out why because I really don't think it should room_ids = set(r.room_id for r in rooms) - res = yield self.store.search_msgs(room_ids, constraints) + rank_map, event_map = yield self.store.search_msgs(room_ids, constraints) + + allowed_events = yield self._filter_events_for_client( + user.to_string(), event_map.values() + ) time_now = self.clock.time_msec() results = { - r["result"].event_id: { - "rank": r["rank"], - "result": serialize_event(r["result"], time_now) + e.event_id: { + "rank": rank_map[e.event_id], + "result": serialize_event(e, time_now) } - for r in res + for e in allowed_events } logger.info("returning: %r", results) diff --git a/synapse/storage/search.py b/synapse/storage/search.py index e66b5f9edc..238df38440 100644 --- a/synapse/storage/search.py +++ b/synapse/storage/search.py @@ -70,11 +70,11 @@ class SearchStore(SQLBaseStore): for ev in events } - defer.returnValue([ + defer.returnValue(( { - "rank": r["rank"], - "result": event_map[r["event_id"]] - } - for r in results - if r["event_id"] in event_map - ]) + r["event_id"]: r["rank"] + for r in results + if r["event_id"] in event_map + }, + event_map + )) -- cgit 1.5.1 From 1a40afa75693f0c2ae3b2eaac62ff9ca6bb02488 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 13 Oct 2015 10:36:25 +0100 Subject: Add sqlite schema --- synapse/storage/schema/delta/24/fts.py | 69 +++++++++++++++++++++++++++++++--- 1 file changed, 64 insertions(+), 5 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/schema/delta/24/fts.py b/synapse/storage/schema/delta/24/fts.py index 05f1605fdd..a806f4b8d3 100644 --- a/synapse/storage/schema/delta/24/fts.py +++ b/synapse/storage/schema/delta/24/fts.py @@ -15,7 +15,9 @@ import logging from synapse.storage import get_statements -from synapse.storage.engines import PostgresEngine +from synapse.storage.engines import PostgresEngine, Sqlite3Engine + +import ujson logger = logging.getLogger(__name__) @@
-46,13 +48,70 @@ INSERT INTO event_search SELECT CREATE INDEX event_search_fts_idx ON event_search USING gin(vector); CREATE INDEX event_search_ev_idx ON event_search(event_id); +CREATE INDEX event_search_ev_ridx ON event_search(room_id); """ +SQLITE_TABLE = ( + "CREATE VIRTUAL TABLE event_search USING fts3 ( event_id, room_id, key, value)" +) +SQLITE_INDEX = "CREATE INDEX event_search_ev_idx ON event_search(event_id)" + + def run_upgrade(cur, database_engine, *args, **kwargs): - if not isinstance(database_engine, PostgresEngine): - # We only support FTS for postgres currently. + if isinstance(database_engine, PostgresEngine): + for statement in get_statements(POSTGRES_SQL.splitlines()): + cur.execute(statement) return - for statement in get_statements(POSTGRES_SQL.splitlines()): - cur.execute(statement) + if isinstance(database_engine, Sqlite3Engine): + cur.execute(SQLITE_TABLE) + + rowid = -1 + while True: + cur.execute( + "SELECT rowid, json FROM event_json" + " WHERE rowid > ?" + " ORDER BY rowid ASC LIMIT 100", + (rowid,) + ) + + res = cur.fetchall() + + if not res: + break + + events = [ + ujson.loads(js) + for _, js in res + ] + + rowid = max(rid for rid, _ in res) + + rows = [] + for ev in events: + if ev["type"] == "m.room.message": + rows.append(( + ev["event_id"], ev["room_id"], "content.body", + ev["content"]["body"] + )) + if ev["type"] == "m.room.name": + rows.append(( + ev["event_id"], ev["room_id"], "content.name", + ev["content"]["name"] + )) + if ev["type"] == "m.room.topic": + rows.append(( + ev["event_id"], ev["room_id"], "content.topic", + ev["content"]["topic"] + )) + + if rows: + logger.info(rows) + cur.executemany( + "INSERT INTO event_search (event_id, room_id, key, value)" + " VALUES (?,?,?,?)", + rows + ) + + # cur.execute(SQLITE_INDEX) -- cgit 1.5.1 From 40b6a5aad1309fed9d1e32be387798dd46b2cf4f Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 13 Oct 2015 11:38:48 +0100 Subject: Split out the schema preparation and update logic into its own module --- synapse/storage/__init__.py | 378 +--------------------------------- synapse/storage/_schema_prepare.py | 395 ++++++++++++++++++++++++++++++++++++ synapse/storage/engines/postgres.py | 2 +- synapse/storage/engines/sqlite3.py | 4 +- 4 files changed, 402 insertions(+), 377 deletions(-) create mode 100644 synapse/storage/_schema_prepare.py (limited to 'synapse/storage') diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 340e59afcb..4be629bff8 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -41,23 +41,16 @@ from .end_to_end_keys import EndToEndKeyStore from .receipts import ReceiptsStore +from ._schema_prepare import UpgradeDatabaseException + +__all__ = [UpgradeDatabaseException] -import fnmatch -import imp import logging -import os -import re logger = logging.getLogger(__name__) -# Remember to update this number every time a change is made to database -# schema files, so the users will be informed on server restarts. -SCHEMA_VERSION = 24 - -dir_path = os.path.abspath(os.path.dirname(__file__)) - # Number of msec of granularity to store the user IP 'last seen' time. Smaller # times give more inserts into the database even for readonly API hits # 120 seconds == 2 minutes @@ -158,371 +151,6 @@ class DataStore(RoomMemberStore, RoomStore, ) -def read_schema(path): - """ Read the named database schema. - - Args: - path: Path of the database schema. - Returns: - A string containing the database schema. 
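The SQLite half of this delta backfills the new event_search table by walking event_json in batches of 100, seeking on SQLite's implicit rowid instead of paging with OFFSET, so each batch costs one index seek however far the scan has progressed. A self-contained sketch of that paging pattern against an in-memory database (the table layout and row contents here are made up; only the paging shape mirrors the delta):

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE event_json (json TEXT)")
    conn.executemany(
        "INSERT INTO event_json (json) VALUES (?)",
        [('{"i": %d}' % i,) for i in range(250)],
    )

    last_rowid = -1
    while True:
        # Every ordinary SQLite table carries an implicit rowid to seek on.
        res = conn.execute(
            "SELECT rowid, json FROM event_json"
            " WHERE rowid > ? ORDER BY rowid ASC LIMIT 100",
            (last_rowid,),
        ).fetchall()
        if not res:
            break
        last_rowid = max(rid for rid, _ in res)   # advance the cursor
        print("processed batch of %d" % len(res))  # 100, 100, 50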
- """ - with open(path) as schema_file: - return schema_file.read() - - -class PrepareDatabaseException(Exception): - pass - - -class UpgradeDatabaseException(PrepareDatabaseException): - pass - - -def prepare_database(db_conn, database_engine): - """Prepares a database for usage. Will either create all necessary tables - or upgrade from an older schema version. - """ - try: - cur = db_conn.cursor() - version_info = _get_or_create_schema_state(cur, database_engine) - - if version_info: - user_version, delta_files, upgraded = version_info - _upgrade_existing_database( - cur, user_version, delta_files, upgraded, database_engine - ) - else: - _setup_new_database(cur, database_engine) - - # cur.execute("PRAGMA user_version = %d" % (SCHEMA_VERSION,)) - - cur.close() - db_conn.commit() - except: - db_conn.rollback() - raise - - -def _setup_new_database(cur, database_engine): - """Sets up the database by finding a base set of "full schemas" and then - applying any necessary deltas. - - The "full_schemas" directory has subdirectories named after versions. This - function searches for the highest version less than or equal to - `SCHEMA_VERSION` and executes all .sql files in that directory. - - The function will then apply all deltas for all versions after the base - version. - - Example directory structure: - - schema/ - delta/ - ... - full_schemas/ - 3/ - test.sql - ... - 11/ - foo.sql - bar.sql - ... - - In the example foo.sql and bar.sql would be run, and then any delta files - for versions strictly greater than 11. - """ - current_dir = os.path.join(dir_path, "schema", "full_schemas") - directory_entries = os.listdir(current_dir) - - valid_dirs = [] - pattern = re.compile(r"^\d+(\.sql)?$") - for filename in directory_entries: - match = pattern.match(filename) - abs_path = os.path.join(current_dir, filename) - if match and os.path.isdir(abs_path): - ver = int(match.group(0)) - if ver <= SCHEMA_VERSION: - valid_dirs.append((ver, abs_path)) - else: - logger.warn("Unexpected entry in 'full_schemas': %s", filename) - - if not valid_dirs: - raise PrepareDatabaseException( - "Could not find a suitable base set of full schemas" - ) - - max_current_ver, sql_dir = max(valid_dirs, key=lambda x: x[0]) - - logger.debug("Initialising schema v%d", max_current_ver) - - directory_entries = os.listdir(sql_dir) - - for filename in fnmatch.filter(directory_entries, "*.sql"): - sql_loc = os.path.join(sql_dir, filename) - logger.debug("Applying schema %s", sql_loc) - executescript(cur, sql_loc) - - cur.execute( - database_engine.convert_param_style( - "INSERT INTO schema_version (version, upgraded)" - " VALUES (?,?)" - ), - (max_current_ver, False,) - ) - - _upgrade_existing_database( - cur, - current_version=max_current_ver, - applied_delta_files=[], - upgraded=False, - database_engine=database_engine, - ) - - -def _upgrade_existing_database(cur, current_version, applied_delta_files, - upgraded, database_engine): - """Upgrades an existing database. - - Delta files can either be SQL stored in *.sql files, or python modules - in *.py. - - There can be multiple delta files per version. Synapse will keep track of - which delta files have been applied, and will apply any that haven't been - even if there has been no version bump. This is useful for development - where orthogonal schema changes may happen on separate branches. - - Different delta files for the same version *must* be orthogonal and give - the same result when applied in any order. No guarantees are made on the - order of execution of these scripts. 
- - This is a no-op of current_version == SCHEMA_VERSION. - - Example directory structure: - - schema/ - delta/ - 11/ - foo.sql - ... - 12/ - foo.sql - bar.py - ... - full_schemas/ - ... - - In the example, if current_version is 11, then foo.sql will be run if and - only if `upgraded` is True. Then `foo.sql` and `bar.py` would be run in - some arbitrary order. - - Args: - cur (Cursor) - current_version (int): The current version of the schema. - applied_delta_files (list): A list of deltas that have already been - applied. - upgraded (bool): Whether the current version was generated by having - applied deltas or from full schema file. If `True` the function - will never apply delta files for the given `current_version`, since - the current_version wasn't generated by applying those delta files. - """ - - if current_version > SCHEMA_VERSION: - raise ValueError( - "Cannot use this database as it is too " + - "new for the server to understand" - ) - - start_ver = current_version - if not upgraded: - start_ver += 1 - - logger.debug("applied_delta_files: %s", applied_delta_files) - - for v in range(start_ver, SCHEMA_VERSION + 1): - logger.debug("Upgrading schema to v%d", v) - - delta_dir = os.path.join(dir_path, "schema", "delta", str(v)) - - try: - directory_entries = os.listdir(delta_dir) - except OSError: - logger.exception("Could not open delta dir for version %d", v) - raise UpgradeDatabaseException( - "Could not open delta dir for version %d" % (v,) - ) - - directory_entries.sort() - for file_name in directory_entries: - relative_path = os.path.join(str(v), file_name) - logger.debug("Found file: %s", relative_path) - if relative_path in applied_delta_files: - continue - - absolute_path = os.path.join( - dir_path, "schema", "delta", relative_path, - ) - root_name, ext = os.path.splitext(file_name) - if ext == ".py": - # This is a python upgrade module. We need to import into some - # package and then execute its `run_upgrade` function. - module_name = "synapse.storage.v%d_%s" % ( - v, root_name - ) - with open(absolute_path) as python_file: - module = imp.load_source( - module_name, absolute_path, python_file - ) - logger.debug("Running script %s", relative_path) - module.run_upgrade(cur, database_engine) - elif ext == ".pyc": - # Sometimes .pyc files turn up anyway even though we've - # disabled their generation; e.g. from distribution package - # installers. Silently skip it - pass - elif ext == ".sql": - # A plain old .sql file, just read and execute it - logger.debug("Applying schema %s", relative_path) - executescript(cur, absolute_path) - else: - # Not a valid delta file. - logger.warn( - "Found directory entry that did not end in .py or" - " .sql: %s", - relative_path, - ) - continue - - # Mark as done. - cur.execute( - database_engine.convert_param_style( - "INSERT INTO applied_schema_deltas (version, file)" - " VALUES (?,?)", - ), - (v, relative_path) - ) - - cur.execute("DELETE FROM schema_version") - cur.execute( - database_engine.convert_param_style( - "INSERT INTO schema_version (version, upgraded)" - " VALUES (?,?)", - ), - (v, True) - ) - - -def get_statements(f): - statement_buffer = "" - in_comment = False # If we're in a /* ... 
*/ style comment - - for line in f: - line = line.strip() - - if in_comment: - # Check if this line contains an end to the comment - comments = line.split("*/", 1) - if len(comments) == 1: - continue - line = comments[1] - in_comment = False - - # Remove inline block comments - line = re.sub(r"/\*.*\*/", " ", line) - - # Does this line start a comment? - comments = line.split("/*", 1) - if len(comments) > 1: - line = comments[0] - in_comment = True - - # Deal with line comments - line = line.split("--", 1)[0] - line = line.split("//", 1)[0] - - # Find *all* semicolons. We need to treat first and last entry - # specially. - statements = line.split(";") - - # We must prepend statement_buffer to the first statement - first_statement = "%s %s" % ( - statement_buffer.strip(), - statements[0].strip() - ) - statements[0] = first_statement - - # Every entry, except the last, is a full statement - for statement in statements[:-1]: - yield statement.strip() - - # The last entry did *not* end in a semicolon, so we store it for the - # next semicolon we find - statement_buffer = statements[-1].strip() - - -def executescript(txn, schema_path): - with open(schema_path, 'r') as f: - for statement in get_statements(f): - txn.execute(statement) - - -def _get_or_create_schema_state(txn, database_engine): - # Bluntly try creating the schema_version tables. - schema_path = os.path.join( - dir_path, "schema", "schema_version.sql", - ) - executescript(txn, schema_path) - - txn.execute("SELECT version, upgraded FROM schema_version") - row = txn.fetchone() - current_version = int(row[0]) if row else None - upgraded = bool(row[1]) if row else None - - if current_version: - txn.execute( - database_engine.convert_param_style( - "SELECT file FROM applied_schema_deltas WHERE version >= ?" - ), - (current_version,) - ) - applied_deltas = [d for d, in txn.fetchall()] - return current_version, applied_deltas, upgraded - - return None - - -def prepare_sqlite3_database(db_conn): - """This function should be called before `prepare_database` on sqlite3 - databases. - - Since we changed the way we store the current schema version and handle - updates to schemas, we need a way to upgrade from the old method to the - new. This only affects sqlite databases since they were the only ones - supported at the time. - """ - with db_conn: - schema_path = os.path.join( - dir_path, "schema", "schema_version.sql", - ) - create_schema = read_schema(schema_path) - db_conn.executescript(create_schema) - - c = db_conn.execute("SELECT * FROM schema_version") - rows = c.fetchall() - c.close() - - if not rows: - c = db_conn.execute("PRAGMA user_version") - row = c.fetchone() - c.close() - - if row and row[0]: - db_conn.execute( - "REPLACE INTO schema_version (version, upgraded)" - " VALUES (?,?)", - (row[0], False) - ) - - def are_all_users_on_domain(txn, database_engine, domain): sql = database_engine.convert_param_style( "SELECT COUNT(*) FROM users WHERE name NOT LIKE ?" diff --git a/synapse/storage/_schema_prepare.py b/synapse/storage/_schema_prepare.py new file mode 100644 index 0000000000..1ddf55be4d --- /dev/null +++ b/synapse/storage/_schema_prepare.py @@ -0,0 +1,395 @@ +# -*- coding: utf-8 -*- +# Copyright 2014, 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import fnmatch +import imp +import logging +import os +import re + + +logger = logging.getLogger(__name__) + + +# Remember to update this number every time a change is made to database +# schema files, so the users will be informed on server restarts. +SCHEMA_VERSION = 24 + +dir_path = os.path.abspath(os.path.dirname(__file__)) + + +def read_schema(path): + """ Read the named database schema. + + Args: + path: Path of the database schema. + Returns: + A string containing the database schema. + """ + with open(path) as schema_file: + return schema_file.read() + + +class PrepareDatabaseException(Exception): + pass + + +class UpgradeDatabaseException(PrepareDatabaseException): + pass + + +def prepare_database(db_conn, database_engine): + """Prepares a database for usage. Will either create all necessary tables + or upgrade from an older schema version. + """ + try: + cur = db_conn.cursor() + version_info = _get_or_create_schema_state(cur, database_engine) + + if version_info: + user_version, delta_files, upgraded = version_info + _upgrade_existing_database( + cur, user_version, delta_files, upgraded, database_engine + ) + else: + _setup_new_database(cur, database_engine) + + # cur.execute("PRAGMA user_version = %d" % (SCHEMA_VERSION,)) + + cur.close() + db_conn.commit() + except: + db_conn.rollback() + raise + + +def _setup_new_database(cur, database_engine): + """Sets up the database by finding a base set of "full schemas" and then + applying any necessary deltas. + + The "full_schemas" directory has subdirectories named after versions. This + function searches for the highest version less than or equal to + `SCHEMA_VERSION` and executes all .sql files in that directory. + + The function will then apply all deltas for all versions after the base + version. + + Example directory structure: + + schema/ + delta/ + ... + full_schemas/ + 3/ + test.sql + ... + 11/ + foo.sql + bar.sql + ... + + In the example foo.sql and bar.sql would be run, and then any delta files + for versions strictly greater than 11. 
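The directory scan in _setup_new_database boils down to: parse the version-named entries, keep those at or below SCHEMA_VERSION, and take the maximum as the base schema to execute before replaying later deltas. A small sketch of just that selection, over a made-up directory listing (the real pattern also tolerates a ".sql" suffix and checks that entries are directories):

    import re

    SCHEMA_VERSION = 24
    entries = ["3", "11", "16", "25", "README"]   # imaginary full_schemas/

    valid = [int(e) for e in entries if re.match(r"^\d+$", e)]
    candidates = [v for v in valid if v <= SCHEMA_VERSION]
    if not candidates:
        raise Exception("Could not find a suitable base set of full schemas")
    base_version = max(candidates)
    print(base_version)   # 16: run its .sql files, then deltas for 17..24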
+ """ + current_dir = os.path.join(dir_path, "schema", "full_schemas") + directory_entries = os.listdir(current_dir) + + valid_dirs = [] + pattern = re.compile(r"^\d+(\.sql)?$") + for filename in directory_entries: + match = pattern.match(filename) + abs_path = os.path.join(current_dir, filename) + if match and os.path.isdir(abs_path): + ver = int(match.group(0)) + if ver <= SCHEMA_VERSION: + valid_dirs.append((ver, abs_path)) + else: + logger.warn("Unexpected entry in 'full_schemas': %s", filename) + + if not valid_dirs: + raise PrepareDatabaseException( + "Could not find a suitable base set of full schemas" + ) + + max_current_ver, sql_dir = max(valid_dirs, key=lambda x: x[0]) + + logger.debug("Initialising schema v%d", max_current_ver) + + directory_entries = os.listdir(sql_dir) + + for filename in fnmatch.filter(directory_entries, "*.sql"): + sql_loc = os.path.join(sql_dir, filename) + logger.debug("Applying schema %s", sql_loc) + executescript(cur, sql_loc) + + cur.execute( + database_engine.convert_param_style( + "INSERT INTO schema_version (version, upgraded)" + " VALUES (?,?)" + ), + (max_current_ver, False,) + ) + + _upgrade_existing_database( + cur, + current_version=max_current_ver, + applied_delta_files=[], + upgraded=False, + database_engine=database_engine, + ) + + +def _upgrade_existing_database(cur, current_version, applied_delta_files, + upgraded, database_engine): + """Upgrades an existing database. + + Delta files can either be SQL stored in *.sql files, or python modules + in *.py. + + There can be multiple delta files per version. Synapse will keep track of + which delta files have been applied, and will apply any that haven't been + even if there has been no version bump. This is useful for development + where orthogonal schema changes may happen on separate branches. + + Different delta files for the same version *must* be orthogonal and give + the same result when applied in any order. No guarantees are made on the + order of execution of these scripts. + + This is a no-op of current_version == SCHEMA_VERSION. + + Example directory structure: + + schema/ + delta/ + 11/ + foo.sql + ... + 12/ + foo.sql + bar.py + ... + full_schemas/ + ... + + In the example, if current_version is 11, then foo.sql will be run if and + only if `upgraded` is True. Then `foo.sql` and `bar.py` would be run in + some arbitrary order. + + Args: + cur (Cursor) + current_version (int): The current version of the schema. + applied_delta_files (list): A list of deltas that have already been + applied. + upgraded (bool): Whether the current version was generated by having + applied deltas or from full schema file. If `True` the function + will never apply delta files for the given `current_version`, since + the current_version wasn't generated by applying those delta files. 
+ """ + + if current_version > SCHEMA_VERSION: + raise ValueError( + "Cannot use this database as it is too " + + "new for the server to understand" + ) + + start_ver = current_version + if not upgraded: + start_ver += 1 + + logger.debug("applied_delta_files: %s", applied_delta_files) + + for v in range(start_ver, SCHEMA_VERSION + 1): + logger.debug("Upgrading schema to v%d", v) + + delta_dir = os.path.join(dir_path, "schema", "delta", str(v)) + + try: + directory_entries = os.listdir(delta_dir) + except OSError: + logger.exception("Could not open delta dir for version %d", v) + raise UpgradeDatabaseException( + "Could not open delta dir for version %d" % (v,) + ) + + directory_entries.sort() + for file_name in directory_entries: + relative_path = os.path.join(str(v), file_name) + logger.debug("Found file: %s", relative_path) + if relative_path in applied_delta_files: + continue + + absolute_path = os.path.join( + dir_path, "schema", "delta", relative_path, + ) + root_name, ext = os.path.splitext(file_name) + if ext == ".py": + # This is a python upgrade module. We need to import into some + # package and then execute its `run_upgrade` function. + module_name = "synapse.storage.v%d_%s" % ( + v, root_name + ) + with open(absolute_path) as python_file: + module = imp.load_source( + module_name, absolute_path, python_file + ) + logger.debug("Running script %s", relative_path) + module.run_upgrade(cur, database_engine) + elif ext == ".pyc": + # Sometimes .pyc files turn up anyway even though we've + # disabled their generation; e.g. from distribution package + # installers. Silently skip it + pass + elif ext == ".sql": + # A plain old .sql file, just read and execute it + logger.debug("Applying schema %s", relative_path) + executescript(cur, absolute_path) + else: + # Not a valid delta file. + logger.warn( + "Found directory entry that did not end in .py or" + " .sql: %s", + relative_path, + ) + continue + + # Mark as done. + cur.execute( + database_engine.convert_param_style( + "INSERT INTO applied_schema_deltas (version, file)" + " VALUES (?,?)", + ), + (v, relative_path) + ) + + cur.execute("DELETE FROM schema_version") + cur.execute( + database_engine.convert_param_style( + "INSERT INTO schema_version (version, upgraded)" + " VALUES (?,?)", + ), + (v, True) + ) + + +def get_statements(f): + statement_buffer = "" + in_comment = False # If we're in a /* ... */ style comment + + for line in f: + line = line.strip() + + if in_comment: + # Check if this line contains an end to the comment + comments = line.split("*/", 1) + if len(comments) == 1: + continue + line = comments[1] + in_comment = False + + # Remove inline block comments + line = re.sub(r"/\*.*\*/", " ", line) + + # Does this line start a comment? + comments = line.split("/*", 1) + if len(comments) > 1: + line = comments[0] + in_comment = True + + # Deal with line comments + line = line.split("--", 1)[0] + line = line.split("//", 1)[0] + + # Find *all* semicolons. We need to treat first and last entry + # specially. 
+ statements = line.split(";") + + # We must prepend statement_buffer to the first statement + first_statement = "%s %s" % ( + statement_buffer.strip(), + statements[0].strip() + ) + statements[0] = first_statement + + # Every entry, except the last, is a full statement + for statement in statements[:-1]: + yield statement.strip() + + # The last entry did *not* end in a semicolon, so we store it for the + # next semicolon we find + statement_buffer = statements[-1].strip() + + +def executescript(txn, schema_path): + with open(schema_path, 'r') as f: + for statement in get_statements(f): + txn.execute(statement) + + +def _get_or_create_schema_state(txn, database_engine): + # Bluntly try creating the schema_version tables. + schema_path = os.path.join( + dir_path, "schema", "schema_version.sql", + ) + executescript(txn, schema_path) + + txn.execute("SELECT version, upgraded FROM schema_version") + row = txn.fetchone() + current_version = int(row[0]) if row else None + upgraded = bool(row[1]) if row else None + + if current_version: + txn.execute( + database_engine.convert_param_style( + "SELECT file FROM applied_schema_deltas WHERE version >= ?" + ), + (current_version,) + ) + applied_deltas = [d for d, in txn.fetchall()] + return current_version, applied_deltas, upgraded + + return None + + +def prepare_sqlite3_database(db_conn): + """This function should be called before `prepare_database` on sqlite3 + databases. + + Since we changed the way we store the current schema version and handle + updates to schemas, we need a way to upgrade from the old method to the + new. This only affects sqlite databases since they were the only ones + supported at the time. + """ + with db_conn: + schema_path = os.path.join( + dir_path, "schema", "schema_version.sql", + ) + create_schema = read_schema(schema_path) + db_conn.executescript(create_schema) + + c = db_conn.execute("SELECT * FROM schema_version") + rows = c.fetchall() + c.close() + + if not rows: + c = db_conn.execute("PRAGMA user_version") + row = c.fetchone() + c.close() + + if row and row[0]: + db_conn.execute( + "REPLACE INTO schema_version (version, upgraded)" + " VALUES (?,?)", + (row[0], False) + ) diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py index 4a855ffd56..949396044e 100644 --- a/synapse/storage/engines/postgres.py +++ b/synapse/storage/engines/postgres.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage import prepare_database +from synapse.storage._schema_prepare import prepare_database from ._base import IncorrectDatabaseSetup diff --git a/synapse/storage/engines/sqlite3.py b/synapse/storage/engines/sqlite3.py index d18e2808d1..a66815ef2d 100644 --- a/synapse/storage/engines/sqlite3.py +++ b/synapse/storage/engines/sqlite3.py @@ -13,7 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
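The get_statements generator carried into the new module above streams a schema file into individual statements: comments are stripped, partial statements are buffered across lines, and a statement is emitted at every semicolon. A cut-down, runnable illustration of that splitting behaviour (this mini version handles only `--` line comments; the real generator also tracks /* ... */ block comments and // comments):

    def mini_get_statements(lines):
        buf = ""
        for line in lines:
            line = line.split("--", 1)[0].strip()   # drop line comments
            parts = (buf + " " + line).split(";")
            for stmt in parts[:-1]:                 # each ';' ends a statement
                yield stmt.strip()
            buf = parts[-1].strip()                 # carry the tail over

    sql = [
        "CREATE TABLE foo (id INTEGER);  -- trailing comment",
        "INSERT INTO foo",
        "    VALUES (1); INSERT INTO foo VALUES (2);",
    ]
    print(list(mini_get_statements(sql)))
    # ['CREATE TABLE foo (id INTEGER)', 'INSERT INTO foo VALUES (1)',
    #  'INSERT INTO foo VALUES (2)']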
-from synapse.storage import prepare_database, prepare_sqlite3_database +from synapse.storage._schema_prepare import ( + prepare_database, prepare_sqlite3_database +) class Sqlite3Engine(object): -- cgit 1.5.1 From ec398af41c4d276abb02279efbcbb0aa08a4cbc8 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 13 Oct 2015 11:41:04 +0100 Subject: Expose error more nicely --- synapse/app/homeserver.py | 5 +- synapse/storage/__init__.py | 3 - synapse/storage/_schema_prepare.py | 395 ------------------------------------ synapse/storage/engines/postgres.py | 2 +- synapse/storage/engines/sqlite3.py | 2 +- synapse/storage/schema_prepare.py | 395 ++++++++++++++++++++++++++++++++++++ tests/utils.py | 2 +- 7 files changed, 400 insertions(+), 404 deletions(-) delete mode 100644 synapse/storage/_schema_prepare.py create mode 100644 synapse/storage/schema_prepare.py (limited to 'synapse/storage') diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 190b03e2f7..b284d07cf0 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -35,9 +35,8 @@ if __name__ == '__main__': from synapse.storage.engines import create_engine, IncorrectDatabaseSetup -from synapse.storage import ( - are_all_users_on_domain, UpgradeDatabaseException, -) +from synapse.storage import are_all_users_on_domain +from synapse.storage.schema_prepare import UpgradeDatabaseException from synapse.server import HomeServer diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 4be629bff8..48a0633746 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -41,9 +41,6 @@ from .end_to_end_keys import EndToEndKeyStore from .receipts import ReceiptsStore -from ._schema_prepare import UpgradeDatabaseException - -__all__ = [UpgradeDatabaseException] import logging diff --git a/synapse/storage/_schema_prepare.py b/synapse/storage/_schema_prepare.py deleted file mode 100644 index 1ddf55be4d..0000000000 --- a/synapse/storage/_schema_prepare.py +++ /dev/null @@ -1,395 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014, 2015 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import fnmatch -import imp -import logging -import os -import re - - -logger = logging.getLogger(__name__) - - -# Remember to update this number every time a change is made to database -# schema files, so the users will be informed on server restarts. -SCHEMA_VERSION = 24 - -dir_path = os.path.abspath(os.path.dirname(__file__)) - - -def read_schema(path): - """ Read the named database schema. - - Args: - path: Path of the database schema. - Returns: - A string containing the database schema. - """ - with open(path) as schema_file: - return schema_file.read() - - -class PrepareDatabaseException(Exception): - pass - - -class UpgradeDatabaseException(PrepareDatabaseException): - pass - - -def prepare_database(db_conn, database_engine): - """Prepares a database for usage. Will either create all necessary tables - or upgrade from an older schema version. 
- """ - try: - cur = db_conn.cursor() - version_info = _get_or_create_schema_state(cur, database_engine) - - if version_info: - user_version, delta_files, upgraded = version_info - _upgrade_existing_database( - cur, user_version, delta_files, upgraded, database_engine - ) - else: - _setup_new_database(cur, database_engine) - - # cur.execute("PRAGMA user_version = %d" % (SCHEMA_VERSION,)) - - cur.close() - db_conn.commit() - except: - db_conn.rollback() - raise - - -def _setup_new_database(cur, database_engine): - """Sets up the database by finding a base set of "full schemas" and then - applying any necessary deltas. - - The "full_schemas" directory has subdirectories named after versions. This - function searches for the highest version less than or equal to - `SCHEMA_VERSION` and executes all .sql files in that directory. - - The function will then apply all deltas for all versions after the base - version. - - Example directory structure: - - schema/ - delta/ - ... - full_schemas/ - 3/ - test.sql - ... - 11/ - foo.sql - bar.sql - ... - - In the example foo.sql and bar.sql would be run, and then any delta files - for versions strictly greater than 11. - """ - current_dir = os.path.join(dir_path, "schema", "full_schemas") - directory_entries = os.listdir(current_dir) - - valid_dirs = [] - pattern = re.compile(r"^\d+(\.sql)?$") - for filename in directory_entries: - match = pattern.match(filename) - abs_path = os.path.join(current_dir, filename) - if match and os.path.isdir(abs_path): - ver = int(match.group(0)) - if ver <= SCHEMA_VERSION: - valid_dirs.append((ver, abs_path)) - else: - logger.warn("Unexpected entry in 'full_schemas': %s", filename) - - if not valid_dirs: - raise PrepareDatabaseException( - "Could not find a suitable base set of full schemas" - ) - - max_current_ver, sql_dir = max(valid_dirs, key=lambda x: x[0]) - - logger.debug("Initialising schema v%d", max_current_ver) - - directory_entries = os.listdir(sql_dir) - - for filename in fnmatch.filter(directory_entries, "*.sql"): - sql_loc = os.path.join(sql_dir, filename) - logger.debug("Applying schema %s", sql_loc) - executescript(cur, sql_loc) - - cur.execute( - database_engine.convert_param_style( - "INSERT INTO schema_version (version, upgraded)" - " VALUES (?,?)" - ), - (max_current_ver, False,) - ) - - _upgrade_existing_database( - cur, - current_version=max_current_ver, - applied_delta_files=[], - upgraded=False, - database_engine=database_engine, - ) - - -def _upgrade_existing_database(cur, current_version, applied_delta_files, - upgraded, database_engine): - """Upgrades an existing database. - - Delta files can either be SQL stored in *.sql files, or python modules - in *.py. - - There can be multiple delta files per version. Synapse will keep track of - which delta files have been applied, and will apply any that haven't been - even if there has been no version bump. This is useful for development - where orthogonal schema changes may happen on separate branches. - - Different delta files for the same version *must* be orthogonal and give - the same result when applied in any order. No guarantees are made on the - order of execution of these scripts. - - This is a no-op of current_version == SCHEMA_VERSION. - - Example directory structure: - - schema/ - delta/ - 11/ - foo.sql - ... - 12/ - foo.sql - bar.py - ... - full_schemas/ - ... - - In the example, if current_version is 11, then foo.sql will be run if and - only if `upgraded` is True. Then `foo.sql` and `bar.py` would be run in - some arbitrary order. 
- - Args: - cur (Cursor) - current_version (int): The current version of the schema. - applied_delta_files (list): A list of deltas that have already been - applied. - upgraded (bool): Whether the current version was generated by having - applied deltas or from full schema file. If `True` the function - will never apply delta files for the given `current_version`, since - the current_version wasn't generated by applying those delta files. - """ - - if current_version > SCHEMA_VERSION: - raise ValueError( - "Cannot use this database as it is too " + - "new for the server to understand" - ) - - start_ver = current_version - if not upgraded: - start_ver += 1 - - logger.debug("applied_delta_files: %s", applied_delta_files) - - for v in range(start_ver, SCHEMA_VERSION + 1): - logger.debug("Upgrading schema to v%d", v) - - delta_dir = os.path.join(dir_path, "schema", "delta", str(v)) - - try: - directory_entries = os.listdir(delta_dir) - except OSError: - logger.exception("Could not open delta dir for version %d", v) - raise UpgradeDatabaseException( - "Could not open delta dir for version %d" % (v,) - ) - - directory_entries.sort() - for file_name in directory_entries: - relative_path = os.path.join(str(v), file_name) - logger.debug("Found file: %s", relative_path) - if relative_path in applied_delta_files: - continue - - absolute_path = os.path.join( - dir_path, "schema", "delta", relative_path, - ) - root_name, ext = os.path.splitext(file_name) - if ext == ".py": - # This is a python upgrade module. We need to import into some - # package and then execute its `run_upgrade` function. - module_name = "synapse.storage.v%d_%s" % ( - v, root_name - ) - with open(absolute_path) as python_file: - module = imp.load_source( - module_name, absolute_path, python_file - ) - logger.debug("Running script %s", relative_path) - module.run_upgrade(cur, database_engine) - elif ext == ".pyc": - # Sometimes .pyc files turn up anyway even though we've - # disabled their generation; e.g. from distribution package - # installers. Silently skip it - pass - elif ext == ".sql": - # A plain old .sql file, just read and execute it - logger.debug("Applying schema %s", relative_path) - executescript(cur, absolute_path) - else: - # Not a valid delta file. - logger.warn( - "Found directory entry that did not end in .py or" - " .sql: %s", - relative_path, - ) - continue - - # Mark as done. - cur.execute( - database_engine.convert_param_style( - "INSERT INTO applied_schema_deltas (version, file)" - " VALUES (?,?)", - ), - (v, relative_path) - ) - - cur.execute("DELETE FROM schema_version") - cur.execute( - database_engine.convert_param_style( - "INSERT INTO schema_version (version, upgraded)" - " VALUES (?,?)", - ), - (v, True) - ) - - -def get_statements(f): - statement_buffer = "" - in_comment = False # If we're in a /* ... */ style comment - - for line in f: - line = line.strip() - - if in_comment: - # Check if this line contains an end to the comment - comments = line.split("*/", 1) - if len(comments) == 1: - continue - line = comments[1] - in_comment = False - - # Remove inline block comments - line = re.sub(r"/\*.*\*/", " ", line) - - # Does this line start a comment? - comments = line.split("/*", 1) - if len(comments) > 1: - line = comments[0] - in_comment = True - - # Deal with line comments - line = line.split("--", 1)[0] - line = line.split("//", 1)[0] - - # Find *all* semicolons. We need to treat first and last entry - # specially. 
- statements = line.split(";") - - # We must prepend statement_buffer to the first statement - first_statement = "%s %s" % ( - statement_buffer.strip(), - statements[0].strip() - ) - statements[0] = first_statement - - # Every entry, except the last, is a full statement - for statement in statements[:-1]: - yield statement.strip() - - # The last entry did *not* end in a semicolon, so we store it for the - # next semicolon we find - statement_buffer = statements[-1].strip() - - -def executescript(txn, schema_path): - with open(schema_path, 'r') as f: - for statement in get_statements(f): - txn.execute(statement) - - -def _get_or_create_schema_state(txn, database_engine): - # Bluntly try creating the schema_version tables. - schema_path = os.path.join( - dir_path, "schema", "schema_version.sql", - ) - executescript(txn, schema_path) - - txn.execute("SELECT version, upgraded FROM schema_version") - row = txn.fetchone() - current_version = int(row[0]) if row else None - upgraded = bool(row[1]) if row else None - - if current_version: - txn.execute( - database_engine.convert_param_style( - "SELECT file FROM applied_schema_deltas WHERE version >= ?" - ), - (current_version,) - ) - applied_deltas = [d for d, in txn.fetchall()] - return current_version, applied_deltas, upgraded - - return None - - -def prepare_sqlite3_database(db_conn): - """This function should be called before `prepare_database` on sqlite3 - databases. - - Since we changed the way we store the current schema version and handle - updates to schemas, we need a way to upgrade from the old method to the - new. This only affects sqlite databases since they were the only ones - supported at the time. - """ - with db_conn: - schema_path = os.path.join( - dir_path, "schema", "schema_version.sql", - ) - create_schema = read_schema(schema_path) - db_conn.executescript(create_schema) - - c = db_conn.execute("SELECT * FROM schema_version") - rows = c.fetchall() - c.close() - - if not rows: - c = db_conn.execute("PRAGMA user_version") - row = c.fetchone() - c.close() - - if row and row[0]: - db_conn.execute( - "REPLACE INTO schema_version (version, upgraded)" - " VALUES (?,?)", - (row[0], False) - ) diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py index 949396044e..7e45dabf4c 100644 --- a/synapse/storage/engines/postgres.py +++ b/synapse/storage/engines/postgres.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage._schema_prepare import prepare_database +from synapse.storage.schema_prepare import prepare_database from ._base import IncorrectDatabaseSetup diff --git a/synapse/storage/engines/sqlite3.py b/synapse/storage/engines/sqlite3.py index a66815ef2d..0eeaa45d19 100644 --- a/synapse/storage/engines/sqlite3.py +++ b/synapse/storage/engines/sqlite3.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage._schema_prepare import ( +from synapse.storage.schema_prepare import ( prepare_database, prepare_sqlite3_database ) diff --git a/synapse/storage/schema_prepare.py b/synapse/storage/schema_prepare.py new file mode 100644 index 0000000000..1ddf55be4d --- /dev/null +++ b/synapse/storage/schema_prepare.py @@ -0,0 +1,395 @@ +# -*- coding: utf-8 -*- +# Copyright 2014, 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import fnmatch +import imp +import logging +import os +import re + + +logger = logging.getLogger(__name__) + + +# Remember to update this number every time a change is made to database +# schema files, so the users will be informed on server restarts. +SCHEMA_VERSION = 24 + +dir_path = os.path.abspath(os.path.dirname(__file__)) + + +def read_schema(path): + """ Read the named database schema. + + Args: + path: Path of the database schema. + Returns: + A string containing the database schema. + """ + with open(path) as schema_file: + return schema_file.read() + + +class PrepareDatabaseException(Exception): + pass + + +class UpgradeDatabaseException(PrepareDatabaseException): + pass + + +def prepare_database(db_conn, database_engine): + """Prepares a database for usage. Will either create all necessary tables + or upgrade from an older schema version. + """ + try: + cur = db_conn.cursor() + version_info = _get_or_create_schema_state(cur, database_engine) + + if version_info: + user_version, delta_files, upgraded = version_info + _upgrade_existing_database( + cur, user_version, delta_files, upgraded, database_engine + ) + else: + _setup_new_database(cur, database_engine) + + # cur.execute("PRAGMA user_version = %d" % (SCHEMA_VERSION,)) + + cur.close() + db_conn.commit() + except: + db_conn.rollback() + raise + + +def _setup_new_database(cur, database_engine): + """Sets up the database by finding a base set of "full schemas" and then + applying any necessary deltas. + + The "full_schemas" directory has subdirectories named after versions. This + function searches for the highest version less than or equal to + `SCHEMA_VERSION` and executes all .sql files in that directory. + + The function will then apply all deltas for all versions after the base + version. + + Example directory structure: + + schema/ + delta/ + ... + full_schemas/ + 3/ + test.sql + ... + 11/ + foo.sql + bar.sql + ... + + In the example foo.sql and bar.sql would be run, and then any delta files + for versions strictly greater than 11. 
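prepare_database, shown again above as it moves into the new module, is one transaction end to end: read (or create) the schema_version state, then either upgrade in place or build from the newest full schema, and commit only if every step applied. A condensed, runnable sketch of that control flow; the three helpers here are stand-ins for Synapse's real ones, not its API:

    import sqlite3

    def _get_version(cur):
        cur.execute(
            "CREATE TABLE IF NOT EXISTS schema_version"
            " (version INTEGER, upgraded BOOLEAN)"
        )
        cur.execute("SELECT version, upgraded FROM schema_version")
        row = cur.fetchone()
        return (row[0], [], bool(row[1])) if row else None

    def _setup(cur):
        # Stand-in for "execute the newest full schema, then its deltas".
        cur.execute("INSERT INTO schema_version VALUES (1, 0)")

    def _upgrade(cur, current, applied, upgraded):
        # Stand-in for "apply outstanding deltas and bump the version".
        cur.execute("UPDATE schema_version SET version = version + 1, upgraded = 1")

    def prepare_database(db_conn):
        try:
            cur = db_conn.cursor()
            version_info = _get_version(cur)   # None on a fresh database
            if version_info:
                _upgrade(cur, *version_info)
            else:
                _setup(cur)
            cur.close()
            db_conn.commit()                   # publish the schema atomically
        except Exception:
            db_conn.rollback()                 # leave the database untouched
            raise

    conn = sqlite3.connect(":memory:")
    prepare_database(conn)   # fresh: takes the _setup path
    prepare_database(conn)   # rerun: takes the _upgrade path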
+ """ + current_dir = os.path.join(dir_path, "schema", "full_schemas") + directory_entries = os.listdir(current_dir) + + valid_dirs = [] + pattern = re.compile(r"^\d+(\.sql)?$") + for filename in directory_entries: + match = pattern.match(filename) + abs_path = os.path.join(current_dir, filename) + if match and os.path.isdir(abs_path): + ver = int(match.group(0)) + if ver <= SCHEMA_VERSION: + valid_dirs.append((ver, abs_path)) + else: + logger.warn("Unexpected entry in 'full_schemas': %s", filename) + + if not valid_dirs: + raise PrepareDatabaseException( + "Could not find a suitable base set of full schemas" + ) + + max_current_ver, sql_dir = max(valid_dirs, key=lambda x: x[0]) + + logger.debug("Initialising schema v%d", max_current_ver) + + directory_entries = os.listdir(sql_dir) + + for filename in fnmatch.filter(directory_entries, "*.sql"): + sql_loc = os.path.join(sql_dir, filename) + logger.debug("Applying schema %s", sql_loc) + executescript(cur, sql_loc) + + cur.execute( + database_engine.convert_param_style( + "INSERT INTO schema_version (version, upgraded)" + " VALUES (?,?)" + ), + (max_current_ver, False,) + ) + + _upgrade_existing_database( + cur, + current_version=max_current_ver, + applied_delta_files=[], + upgraded=False, + database_engine=database_engine, + ) + + +def _upgrade_existing_database(cur, current_version, applied_delta_files, + upgraded, database_engine): + """Upgrades an existing database. + + Delta files can either be SQL stored in *.sql files, or python modules + in *.py. + + There can be multiple delta files per version. Synapse will keep track of + which delta files have been applied, and will apply any that haven't been + even if there has been no version bump. This is useful for development + where orthogonal schema changes may happen on separate branches. + + Different delta files for the same version *must* be orthogonal and give + the same result when applied in any order. No guarantees are made on the + order of execution of these scripts. + + This is a no-op of current_version == SCHEMA_VERSION. + + Example directory structure: + + schema/ + delta/ + 11/ + foo.sql + ... + 12/ + foo.sql + bar.py + ... + full_schemas/ + ... + + In the example, if current_version is 11, then foo.sql will be run if and + only if `upgraded` is True. Then `foo.sql` and `bar.py` would be run in + some arbitrary order. + + Args: + cur (Cursor) + current_version (int): The current version of the schema. + applied_delta_files (list): A list of deltas that have already been + applied. + upgraded (bool): Whether the current version was generated by having + applied deltas or from full schema file. If `True` the function + will never apply delta files for the given `current_version`, since + the current_version wasn't generated by applying those delta files. 
+ """ + + if current_version > SCHEMA_VERSION: + raise ValueError( + "Cannot use this database as it is too " + + "new for the server to understand" + ) + + start_ver = current_version + if not upgraded: + start_ver += 1 + + logger.debug("applied_delta_files: %s", applied_delta_files) + + for v in range(start_ver, SCHEMA_VERSION + 1): + logger.debug("Upgrading schema to v%d", v) + + delta_dir = os.path.join(dir_path, "schema", "delta", str(v)) + + try: + directory_entries = os.listdir(delta_dir) + except OSError: + logger.exception("Could not open delta dir for version %d", v) + raise UpgradeDatabaseException( + "Could not open delta dir for version %d" % (v,) + ) + + directory_entries.sort() + for file_name in directory_entries: + relative_path = os.path.join(str(v), file_name) + logger.debug("Found file: %s", relative_path) + if relative_path in applied_delta_files: + continue + + absolute_path = os.path.join( + dir_path, "schema", "delta", relative_path, + ) + root_name, ext = os.path.splitext(file_name) + if ext == ".py": + # This is a python upgrade module. We need to import into some + # package and then execute its `run_upgrade` function. + module_name = "synapse.storage.v%d_%s" % ( + v, root_name + ) + with open(absolute_path) as python_file: + module = imp.load_source( + module_name, absolute_path, python_file + ) + logger.debug("Running script %s", relative_path) + module.run_upgrade(cur, database_engine) + elif ext == ".pyc": + # Sometimes .pyc files turn up anyway even though we've + # disabled their generation; e.g. from distribution package + # installers. Silently skip it + pass + elif ext == ".sql": + # A plain old .sql file, just read and execute it + logger.debug("Applying schema %s", relative_path) + executescript(cur, absolute_path) + else: + # Not a valid delta file. + logger.warn( + "Found directory entry that did not end in .py or" + " .sql: %s", + relative_path, + ) + continue + + # Mark as done. + cur.execute( + database_engine.convert_param_style( + "INSERT INTO applied_schema_deltas (version, file)" + " VALUES (?,?)", + ), + (v, relative_path) + ) + + cur.execute("DELETE FROM schema_version") + cur.execute( + database_engine.convert_param_style( + "INSERT INTO schema_version (version, upgraded)" + " VALUES (?,?)", + ), + (v, True) + ) + + +def get_statements(f): + statement_buffer = "" + in_comment = False # If we're in a /* ... */ style comment + + for line in f: + line = line.strip() + + if in_comment: + # Check if this line contains an end to the comment + comments = line.split("*/", 1) + if len(comments) == 1: + continue + line = comments[1] + in_comment = False + + # Remove inline block comments + line = re.sub(r"/\*.*\*/", " ", line) + + # Does this line start a comment? + comments = line.split("/*", 1) + if len(comments) > 1: + line = comments[0] + in_comment = True + + # Deal with line comments + line = line.split("--", 1)[0] + line = line.split("//", 1)[0] + + # Find *all* semicolons. We need to treat first and last entry + # specially. 
+ statements = line.split(";") + + # We must prepend statement_buffer to the first statement + first_statement = "%s %s" % ( + statement_buffer.strip(), + statements[0].strip() + ) + statements[0] = first_statement + + # Every entry, except the last, is a full statement + for statement in statements[:-1]: + yield statement.strip() + + # The last entry did *not* end in a semicolon, so we store it for the + # next semicolon we find + statement_buffer = statements[-1].strip() + + +def executescript(txn, schema_path): + with open(schema_path, 'r') as f: + for statement in get_statements(f): + txn.execute(statement) + + +def _get_or_create_schema_state(txn, database_engine): + # Bluntly try creating the schema_version tables. + schema_path = os.path.join( + dir_path, "schema", "schema_version.sql", + ) + executescript(txn, schema_path) + + txn.execute("SELECT version, upgraded FROM schema_version") + row = txn.fetchone() + current_version = int(row[0]) if row else None + upgraded = bool(row[1]) if row else None + + if current_version: + txn.execute( + database_engine.convert_param_style( + "SELECT file FROM applied_schema_deltas WHERE version >= ?" + ), + (current_version,) + ) + applied_deltas = [d for d, in txn.fetchall()] + return current_version, applied_deltas, upgraded + + return None + + +def prepare_sqlite3_database(db_conn): + """This function should be called before `prepare_database` on sqlite3 + databases. + + Since we changed the way we store the current schema version and handle + updates to schemas, we need a way to upgrade from the old method to the + new. This only affects sqlite databases since they were the only ones + supported at the time. + """ + with db_conn: + schema_path = os.path.join( + dir_path, "schema", "schema_version.sql", + ) + create_schema = read_schema(schema_path) + db_conn.executescript(create_schema) + + c = db_conn.execute("SELECT * FROM schema_version") + rows = c.fetchall() + c.close() + + if not rows: + c = db_conn.execute("PRAGMA user_version") + row = c.fetchone() + c.close() + + if row and row[0]: + db_conn.execute( + "REPLACE INTO schema_version (version, upgraded)" + " VALUES (?,?)", + (row[0], False) + ) diff --git a/tests/utils.py b/tests/utils.py index dd19a16fc7..6eb575bd09 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -16,7 +16,7 @@ from synapse.http.server import HttpServer from synapse.api.errors import cs_error, CodeMessageException, StoreError from synapse.api.constants import EventTypes -from synapse.storage import prepare_database +from synapse.storage.schema_prepare import prepare_database from synapse.storage.engines import create_engine from synapse.server import HomeServer -- cgit 1.5.1 From cfd39d6b55fad5b176f1883e1bc87ed8e14acf42 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 13 Oct 2015 13:47:50 +0100 Subject: Add SQLite support --- synapse/storage/search.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/search.py b/synapse/storage/search.py index 238df38440..5843f80876 100644 --- a/synapse/storage/search.py +++ b/synapse/storage/search.py @@ -17,6 +17,7 @@ from twisted.internet import defer from _base import SQLBaseStore from synapse.api.constants import KnownRoomEventKeys, SearchConstraintTypes +from synapse.storage.engines import PostgresEngine class SearchStore(SQLBaseStore): @@ -48,11 +49,17 @@ class SearchStore(SQLBaseStore): "(%s)" % (" OR ".join(local_clauses),) ) - sql = ( - "SELECT ts_rank_cd(vector, query) AS rank, 
event_id" - " FROM plainto_tsquery('english', ?) as query, event_search" - " WHERE vector @@ query" - ) + if isinstance(self.database_engine, PostgresEngine): + sql = ( + "SELECT ts_rank_cd(vector, query) AS rank, event_id" + " FROM plainto_tsquery('english', ?) as query, event_search" + " WHERE vector @@ query" + ) + else: + sql = ( + "SELECT 0 as rank, event_id FROM event_search" + " WHERE value MATCH ?" + ) for clause in clauses: sql += " AND " + clause -- cgit 1.5.1 From 17c80c8a3d92acca5bda9b0fc7d9898547476563 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 13 Oct 2015 13:56:22 +0100 Subject: rename schema_prepare to prepare_database --- synapse/app/homeserver.py | 2 +- synapse/storage/engines/postgres.py | 2 +- synapse/storage/engines/sqlite3.py | 2 +- synapse/storage/prepare_database.py | 395 ++++++++++++++++++++++++++++++++++++ synapse/storage/schema_prepare.py | 395 ------------------------------------ tests/utils.py | 2 +- 6 files changed, 399 insertions(+), 399 deletions(-) create mode 100644 synapse/storage/prepare_database.py delete mode 100644 synapse/storage/schema_prepare.py (limited to 'synapse/storage') diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index b284d07cf0..af53acb369 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -36,7 +36,7 @@ if __name__ == '__main__': from synapse.storage.engines import create_engine, IncorrectDatabaseSetup from synapse.storage import are_all_users_on_domain -from synapse.storage.schema_prepare import UpgradeDatabaseException +from synapse.storage.prepare_database import UpgradeDatabaseException from synapse.server import HomeServer diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py index 7e45dabf4c..98d66e0a86 100644 --- a/synapse/storage/engines/postgres.py +++ b/synapse/storage/engines/postgres.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.schema_prepare import prepare_database +from synapse.storage.prepare_database import prepare_database from ._base import IncorrectDatabaseSetup diff --git a/synapse/storage/engines/sqlite3.py b/synapse/storage/engines/sqlite3.py index 0eeaa45d19..bad3b5c5ac 100644 --- a/synapse/storage/engines/sqlite3.py +++ b/synapse/storage/engines/sqlite3.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.schema_prepare import ( +from synapse.storage.prepare_database import ( prepare_database, prepare_sqlite3_database ) diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py new file mode 100644 index 0000000000..1ddf55be4d --- /dev/null +++ b/synapse/storage/prepare_database.py @@ -0,0 +1,395 @@ +# -*- coding: utf-8 -*- +# Copyright 2014, 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import fnmatch +import imp +import logging +import os +import re + + +logger = logging.getLogger(__name__) + + +# Remember to update this number every time a change is made to database +# schema files, so the users will be informed on server restarts. +SCHEMA_VERSION = 24 + +dir_path = os.path.abspath(os.path.dirname(__file__)) + + +def read_schema(path): + """ Read the named database schema. + + Args: + path: Path of the database schema. + Returns: + A string containing the database schema. + """ + with open(path) as schema_file: + return schema_file.read() + + +class PrepareDatabaseException(Exception): + pass + + +class UpgradeDatabaseException(PrepareDatabaseException): + pass + + +def prepare_database(db_conn, database_engine): + """Prepares a database for usage. Will either create all necessary tables + or upgrade from an older schema version. + """ + try: + cur = db_conn.cursor() + version_info = _get_or_create_schema_state(cur, database_engine) + + if version_info: + user_version, delta_files, upgraded = version_info + _upgrade_existing_database( + cur, user_version, delta_files, upgraded, database_engine + ) + else: + _setup_new_database(cur, database_engine) + + # cur.execute("PRAGMA user_version = %d" % (SCHEMA_VERSION,)) + + cur.close() + db_conn.commit() + except: + db_conn.rollback() + raise + + +def _setup_new_database(cur, database_engine): + """Sets up the database by finding a base set of "full schemas" and then + applying any necessary deltas. + + The "full_schemas" directory has subdirectories named after versions. This + function searches for the highest version less than or equal to + `SCHEMA_VERSION` and executes all .sql files in that directory. + + The function will then apply all deltas for all versions after the base + version. + + Example directory structure: + + schema/ + delta/ + ... + full_schemas/ + 3/ + test.sql + ... + 11/ + foo.sql + bar.sql + ... + + In the example foo.sql and bar.sql would be run, and then any delta files + for versions strictly greater than 11. + """ + current_dir = os.path.join(dir_path, "schema", "full_schemas") + directory_entries = os.listdir(current_dir) + + valid_dirs = [] + pattern = re.compile(r"^\d+(\.sql)?$") + for filename in directory_entries: + match = pattern.match(filename) + abs_path = os.path.join(current_dir, filename) + if match and os.path.isdir(abs_path): + ver = int(match.group(0)) + if ver <= SCHEMA_VERSION: + valid_dirs.append((ver, abs_path)) + else: + logger.warn("Unexpected entry in 'full_schemas': %s", filename) + + if not valid_dirs: + raise PrepareDatabaseException( + "Could not find a suitable base set of full schemas" + ) + + max_current_ver, sql_dir = max(valid_dirs, key=lambda x: x[0]) + + logger.debug("Initialising schema v%d", max_current_ver) + + directory_entries = os.listdir(sql_dir) + + for filename in fnmatch.filter(directory_entries, "*.sql"): + sql_loc = os.path.join(sql_dir, filename) + logger.debug("Applying schema %s", sql_loc) + executescript(cur, sql_loc) + + cur.execute( + database_engine.convert_param_style( + "INSERT INTO schema_version (version, upgraded)" + " VALUES (?,?)" + ), + (max_current_ver, False,) + ) + + _upgrade_existing_database( + cur, + current_version=max_current_ver, + applied_delta_files=[], + upgraded=False, + database_engine=database_engine, + ) + + +def _upgrade_existing_database(cur, current_version, applied_delta_files, + upgraded, database_engine): + """Upgrades an existing database. 
+ + Delta files can either be SQL stored in *.sql files, or python modules + in *.py. + + There can be multiple delta files per version. Synapse will keep track of + which delta files have been applied, and will apply any that haven't been + even if there has been no version bump. This is useful for development + where orthogonal schema changes may happen on separate branches. + + Different delta files for the same version *must* be orthogonal and give + the same result when applied in any order. No guarantees are made on the + order of execution of these scripts. + + This is a no-op of current_version == SCHEMA_VERSION. + + Example directory structure: + + schema/ + delta/ + 11/ + foo.sql + ... + 12/ + foo.sql + bar.py + ... + full_schemas/ + ... + + In the example, if current_version is 11, then foo.sql will be run if and + only if `upgraded` is True. Then `foo.sql` and `bar.py` would be run in + some arbitrary order. + + Args: + cur (Cursor) + current_version (int): The current version of the schema. + applied_delta_files (list): A list of deltas that have already been + applied. + upgraded (bool): Whether the current version was generated by having + applied deltas or from full schema file. If `True` the function + will never apply delta files for the given `current_version`, since + the current_version wasn't generated by applying those delta files. + """ + + if current_version > SCHEMA_VERSION: + raise ValueError( + "Cannot use this database as it is too " + + "new for the server to understand" + ) + + start_ver = current_version + if not upgraded: + start_ver += 1 + + logger.debug("applied_delta_files: %s", applied_delta_files) + + for v in range(start_ver, SCHEMA_VERSION + 1): + logger.debug("Upgrading schema to v%d", v) + + delta_dir = os.path.join(dir_path, "schema", "delta", str(v)) + + try: + directory_entries = os.listdir(delta_dir) + except OSError: + logger.exception("Could not open delta dir for version %d", v) + raise UpgradeDatabaseException( + "Could not open delta dir for version %d" % (v,) + ) + + directory_entries.sort() + for file_name in directory_entries: + relative_path = os.path.join(str(v), file_name) + logger.debug("Found file: %s", relative_path) + if relative_path in applied_delta_files: + continue + + absolute_path = os.path.join( + dir_path, "schema", "delta", relative_path, + ) + root_name, ext = os.path.splitext(file_name) + if ext == ".py": + # This is a python upgrade module. We need to import into some + # package and then execute its `run_upgrade` function. + module_name = "synapse.storage.v%d_%s" % ( + v, root_name + ) + with open(absolute_path) as python_file: + module = imp.load_source( + module_name, absolute_path, python_file + ) + logger.debug("Running script %s", relative_path) + module.run_upgrade(cur, database_engine) + elif ext == ".pyc": + # Sometimes .pyc files turn up anyway even though we've + # disabled their generation; e.g. from distribution package + # installers. Silently skip it + pass + elif ext == ".sql": + # A plain old .sql file, just read and execute it + logger.debug("Applying schema %s", relative_path) + executescript(cur, absolute_path) + else: + # Not a valid delta file. + logger.warn( + "Found directory entry that did not end in .py or" + " .sql: %s", + relative_path, + ) + continue + + # Mark as done. 
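+                # (Each applied delta is recorded as a (version, relative_path)
+                # row, e.g. (12, "12/foo.sql") in the directory layout from the
+                # docstring above, so later runs can skip it.)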
+ cur.execute( + database_engine.convert_param_style( + "INSERT INTO applied_schema_deltas (version, file)" + " VALUES (?,?)", + ), + (v, relative_path) + ) + + cur.execute("DELETE FROM schema_version") + cur.execute( + database_engine.convert_param_style( + "INSERT INTO schema_version (version, upgraded)" + " VALUES (?,?)", + ), + (v, True) + ) + + +def get_statements(f): + statement_buffer = "" + in_comment = False # If we're in a /* ... */ style comment + + for line in f: + line = line.strip() + + if in_comment: + # Check if this line contains an end to the comment + comments = line.split("*/", 1) + if len(comments) == 1: + continue + line = comments[1] + in_comment = False + + # Remove inline block comments + line = re.sub(r"/\*.*\*/", " ", line) + + # Does this line start a comment? + comments = line.split("/*", 1) + if len(comments) > 1: + line = comments[0] + in_comment = True + + # Deal with line comments + line = line.split("--", 1)[0] + line = line.split("//", 1)[0] + + # Find *all* semicolons. We need to treat first and last entry + # specially. + statements = line.split(";") + + # We must prepend statement_buffer to the first statement + first_statement = "%s %s" % ( + statement_buffer.strip(), + statements[0].strip() + ) + statements[0] = first_statement + + # Every entry, except the last, is a full statement + for statement in statements[:-1]: + yield statement.strip() + + # The last entry did *not* end in a semicolon, so we store it for the + # next semicolon we find + statement_buffer = statements[-1].strip() + + +def executescript(txn, schema_path): + with open(schema_path, 'r') as f: + for statement in get_statements(f): + txn.execute(statement) + + +def _get_or_create_schema_state(txn, database_engine): + # Bluntly try creating the schema_version tables. + schema_path = os.path.join( + dir_path, "schema", "schema_version.sql", + ) + executescript(txn, schema_path) + + txn.execute("SELECT version, upgraded FROM schema_version") + row = txn.fetchone() + current_version = int(row[0]) if row else None + upgraded = bool(row[1]) if row else None + + if current_version: + txn.execute( + database_engine.convert_param_style( + "SELECT file FROM applied_schema_deltas WHERE version >= ?" + ), + (current_version,) + ) + applied_deltas = [d for d, in txn.fetchall()] + return current_version, applied_deltas, upgraded + + return None + + +def prepare_sqlite3_database(db_conn): + """This function should be called before `prepare_database` on sqlite3 + databases. + + Since we changed the way we store the current schema version and handle + updates to schemas, we need a way to upgrade from the old method to the + new. This only affects sqlite databases since they were the only ones + supported at the time. 
+ """ + with db_conn: + schema_path = os.path.join( + dir_path, "schema", "schema_version.sql", + ) + create_schema = read_schema(schema_path) + db_conn.executescript(create_schema) + + c = db_conn.execute("SELECT * FROM schema_version") + rows = c.fetchall() + c.close() + + if not rows: + c = db_conn.execute("PRAGMA user_version") + row = c.fetchone() + c.close() + + if row and row[0]: + db_conn.execute( + "REPLACE INTO schema_version (version, upgraded)" + " VALUES (?,?)", + (row[0], False) + ) diff --git a/synapse/storage/schema_prepare.py b/synapse/storage/schema_prepare.py deleted file mode 100644 index 1ddf55be4d..0000000000 --- a/synapse/storage/schema_prepare.py +++ /dev/null @@ -1,395 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014, 2015 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import fnmatch -import imp -import logging -import os -import re - - -logger = logging.getLogger(__name__) - - -# Remember to update this number every time a change is made to database -# schema files, so the users will be informed on server restarts. -SCHEMA_VERSION = 24 - -dir_path = os.path.abspath(os.path.dirname(__file__)) - - -def read_schema(path): - """ Read the named database schema. - - Args: - path: Path of the database schema. - Returns: - A string containing the database schema. - """ - with open(path) as schema_file: - return schema_file.read() - - -class PrepareDatabaseException(Exception): - pass - - -class UpgradeDatabaseException(PrepareDatabaseException): - pass - - -def prepare_database(db_conn, database_engine): - """Prepares a database for usage. Will either create all necessary tables - or upgrade from an older schema version. - """ - try: - cur = db_conn.cursor() - version_info = _get_or_create_schema_state(cur, database_engine) - - if version_info: - user_version, delta_files, upgraded = version_info - _upgrade_existing_database( - cur, user_version, delta_files, upgraded, database_engine - ) - else: - _setup_new_database(cur, database_engine) - - # cur.execute("PRAGMA user_version = %d" % (SCHEMA_VERSION,)) - - cur.close() - db_conn.commit() - except: - db_conn.rollback() - raise - - -def _setup_new_database(cur, database_engine): - """Sets up the database by finding a base set of "full schemas" and then - applying any necessary deltas. - - The "full_schemas" directory has subdirectories named after versions. This - function searches for the highest version less than or equal to - `SCHEMA_VERSION` and executes all .sql files in that directory. - - The function will then apply all deltas for all versions after the base - version. - - Example directory structure: - - schema/ - delta/ - ... - full_schemas/ - 3/ - test.sql - ... - 11/ - foo.sql - bar.sql - ... - - In the example foo.sql and bar.sql would be run, and then any delta files - for versions strictly greater than 11. 
- """ - current_dir = os.path.join(dir_path, "schema", "full_schemas") - directory_entries = os.listdir(current_dir) - - valid_dirs = [] - pattern = re.compile(r"^\d+(\.sql)?$") - for filename in directory_entries: - match = pattern.match(filename) - abs_path = os.path.join(current_dir, filename) - if match and os.path.isdir(abs_path): - ver = int(match.group(0)) - if ver <= SCHEMA_VERSION: - valid_dirs.append((ver, abs_path)) - else: - logger.warn("Unexpected entry in 'full_schemas': %s", filename) - - if not valid_dirs: - raise PrepareDatabaseException( - "Could not find a suitable base set of full schemas" - ) - - max_current_ver, sql_dir = max(valid_dirs, key=lambda x: x[0]) - - logger.debug("Initialising schema v%d", max_current_ver) - - directory_entries = os.listdir(sql_dir) - - for filename in fnmatch.filter(directory_entries, "*.sql"): - sql_loc = os.path.join(sql_dir, filename) - logger.debug("Applying schema %s", sql_loc) - executescript(cur, sql_loc) - - cur.execute( - database_engine.convert_param_style( - "INSERT INTO schema_version (version, upgraded)" - " VALUES (?,?)" - ), - (max_current_ver, False,) - ) - - _upgrade_existing_database( - cur, - current_version=max_current_ver, - applied_delta_files=[], - upgraded=False, - database_engine=database_engine, - ) - - -def _upgrade_existing_database(cur, current_version, applied_delta_files, - upgraded, database_engine): - """Upgrades an existing database. - - Delta files can either be SQL stored in *.sql files, or python modules - in *.py. - - There can be multiple delta files per version. Synapse will keep track of - which delta files have been applied, and will apply any that haven't been - even if there has been no version bump. This is useful for development - where orthogonal schema changes may happen on separate branches. - - Different delta files for the same version *must* be orthogonal and give - the same result when applied in any order. No guarantees are made on the - order of execution of these scripts. - - This is a no-op of current_version == SCHEMA_VERSION. - - Example directory structure: - - schema/ - delta/ - 11/ - foo.sql - ... - 12/ - foo.sql - bar.py - ... - full_schemas/ - ... - - In the example, if current_version is 11, then foo.sql will be run if and - only if `upgraded` is True. Then `foo.sql` and `bar.py` would be run in - some arbitrary order. - - Args: - cur (Cursor) - current_version (int): The current version of the schema. - applied_delta_files (list): A list of deltas that have already been - applied. - upgraded (bool): Whether the current version was generated by having - applied deltas or from full schema file. If `True` the function - will never apply delta files for the given `current_version`, since - the current_version wasn't generated by applying those delta files. 
- """ - - if current_version > SCHEMA_VERSION: - raise ValueError( - "Cannot use this database as it is too " + - "new for the server to understand" - ) - - start_ver = current_version - if not upgraded: - start_ver += 1 - - logger.debug("applied_delta_files: %s", applied_delta_files) - - for v in range(start_ver, SCHEMA_VERSION + 1): - logger.debug("Upgrading schema to v%d", v) - - delta_dir = os.path.join(dir_path, "schema", "delta", str(v)) - - try: - directory_entries = os.listdir(delta_dir) - except OSError: - logger.exception("Could not open delta dir for version %d", v) - raise UpgradeDatabaseException( - "Could not open delta dir for version %d" % (v,) - ) - - directory_entries.sort() - for file_name in directory_entries: - relative_path = os.path.join(str(v), file_name) - logger.debug("Found file: %s", relative_path) - if relative_path in applied_delta_files: - continue - - absolute_path = os.path.join( - dir_path, "schema", "delta", relative_path, - ) - root_name, ext = os.path.splitext(file_name) - if ext == ".py": - # This is a python upgrade module. We need to import into some - # package and then execute its `run_upgrade` function. - module_name = "synapse.storage.v%d_%s" % ( - v, root_name - ) - with open(absolute_path) as python_file: - module = imp.load_source( - module_name, absolute_path, python_file - ) - logger.debug("Running script %s", relative_path) - module.run_upgrade(cur, database_engine) - elif ext == ".pyc": - # Sometimes .pyc files turn up anyway even though we've - # disabled their generation; e.g. from distribution package - # installers. Silently skip it - pass - elif ext == ".sql": - # A plain old .sql file, just read and execute it - logger.debug("Applying schema %s", relative_path) - executescript(cur, absolute_path) - else: - # Not a valid delta file. - logger.warn( - "Found directory entry that did not end in .py or" - " .sql: %s", - relative_path, - ) - continue - - # Mark as done. - cur.execute( - database_engine.convert_param_style( - "INSERT INTO applied_schema_deltas (version, file)" - " VALUES (?,?)", - ), - (v, relative_path) - ) - - cur.execute("DELETE FROM schema_version") - cur.execute( - database_engine.convert_param_style( - "INSERT INTO schema_version (version, upgraded)" - " VALUES (?,?)", - ), - (v, True) - ) - - -def get_statements(f): - statement_buffer = "" - in_comment = False # If we're in a /* ... */ style comment - - for line in f: - line = line.strip() - - if in_comment: - # Check if this line contains an end to the comment - comments = line.split("*/", 1) - if len(comments) == 1: - continue - line = comments[1] - in_comment = False - - # Remove inline block comments - line = re.sub(r"/\*.*\*/", " ", line) - - # Does this line start a comment? - comments = line.split("/*", 1) - if len(comments) > 1: - line = comments[0] - in_comment = True - - # Deal with line comments - line = line.split("--", 1)[0] - line = line.split("//", 1)[0] - - # Find *all* semicolons. We need to treat first and last entry - # specially. 
- statements = line.split(";") - - # We must prepend statement_buffer to the first statement - first_statement = "%s %s" % ( - statement_buffer.strip(), - statements[0].strip() - ) - statements[0] = first_statement - - # Every entry, except the last, is a full statement - for statement in statements[:-1]: - yield statement.strip() - - # The last entry did *not* end in a semicolon, so we store it for the - # next semicolon we find - statement_buffer = statements[-1].strip() - - -def executescript(txn, schema_path): - with open(schema_path, 'r') as f: - for statement in get_statements(f): - txn.execute(statement) - - -def _get_or_create_schema_state(txn, database_engine): - # Bluntly try creating the schema_version tables. - schema_path = os.path.join( - dir_path, "schema", "schema_version.sql", - ) - executescript(txn, schema_path) - - txn.execute("SELECT version, upgraded FROM schema_version") - row = txn.fetchone() - current_version = int(row[0]) if row else None - upgraded = bool(row[1]) if row else None - - if current_version: - txn.execute( - database_engine.convert_param_style( - "SELECT file FROM applied_schema_deltas WHERE version >= ?" - ), - (current_version,) - ) - applied_deltas = [d for d, in txn.fetchall()] - return current_version, applied_deltas, upgraded - - return None - - -def prepare_sqlite3_database(db_conn): - """This function should be called before `prepare_database` on sqlite3 - databases. - - Since we changed the way we store the current schema version and handle - updates to schemas, we need a way to upgrade from the old method to the - new. This only affects sqlite databases since they were the only ones - supported at the time. - """ - with db_conn: - schema_path = os.path.join( - dir_path, "schema", "schema_version.sql", - ) - create_schema = read_schema(schema_path) - db_conn.executescript(create_schema) - - c = db_conn.execute("SELECT * FROM schema_version") - rows = c.fetchall() - c.close() - - if not rows: - c = db_conn.execute("PRAGMA user_version") - row = c.fetchone() - c.close() - - if row and row[0]: - db_conn.execute( - "REPLACE INTO schema_version (version, upgraded)" - " VALUES (?,?)", - (row[0], False) - ) diff --git a/tests/utils.py b/tests/utils.py index 6eb575bd09..4da51291a4 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -16,7 +16,7 @@ from synapse.http.server import HttpServer from synapse.api.errors import cs_error, CodeMessageException, StoreError from synapse.api.constants import EventTypes -from synapse.storage.schema_prepare import prepare_database +from synapse.storage.prepare_database import prepare_database from synapse.storage.engines import create_engine from synapse.server import HomeServer -- cgit 1.5.1 From cacf0688c691517dab55c3cff294b6bac7f0d6e3 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Tue, 13 Oct 2015 14:08:38 +0100 Subject: Add a get_invites_for_user method to the storage to find out the rooms a user is invited to --- synapse/handlers/sync.py | 8 ++------ synapse/storage/roommember.py | 14 ++++++++++++++ 2 files changed, 16 insertions(+), 6 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index d9e55d8a58..380798b7ab 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -297,12 +297,8 @@ class SyncHandler(BaseHandler): if room_sync: joined.append(room_sync) else: - invites = yield self.store.get_rooms_for_user_where_membership_is( - user_id=sync_config.user.to_string(), - membership_list=[Membership.INVITE], - ) - invite_events = yield 
self.store.get_events(
-                [invite.event_id for invite in invites]
+            invite_events = yield self.store.get_invites_for_user(
+                sync_config.user.to_string()
             )
 
         for room_id in joined_room_ids:
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 8c40d9a8a6..dd98dcfda8 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -110,6 +110,20 @@ class RoomMemberStore(SQLBaseStore):
             membership=membership,
         ).addCallback(self._get_events)
 
+    def get_invites_for_user(self, user_id):
+        """ Get all the invite events for a user
+        Args:
+            user_id (str): The user ID.
+        Returns:
+            A deferred list of event objects.
+        """
+
+        return self.get_rooms_for_user_where_membership_is(
+            user_id, [Membership.INVITE]
+        ).addCallback(lambda invites: self._get_events([
+            invite.event_id for invite in invites
+        ]))
+
     def get_rooms_for_user_where_membership_is(self, user_id, membership_list):
         """ Get all the rooms for this user where the membership for this user
         matches one in the membership list.
-- cgit 1.5.1 


From 3e2a1297b513dc1fadb288c74684f6651a88016d Mon Sep 17 00:00:00 2001
From: Erik Johnston
Date: Tue, 13 Oct 2015 15:22:14 +0100
Subject: Remove constraints in preparation of using filters

---
 synapse/handlers/search.py | 61 ++++++++--------------------------------------
 synapse/storage/search.py  | 30 ++++++++---------------
 2 files changed, 20 insertions(+), 71 deletions(-)

(limited to 'synapse/storage')

diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index d5c395061c..8864a921fc 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -18,7 +18,7 @@ from twisted.internet import defer
 from ._base import BaseHandler
 
 from synapse.api.constants import (
-    EventTypes, KnownRoomEventKeys, Membership, SearchConstraintTypes
+    EventTypes, Membership,
 )
 from synapse.api.errors import SynapseError
 from synapse.events.utils import serialize_event
@@ -29,45 +29,6 @@ import logging
 logger = logging.getLogger(__name__)
 
 
-KEYS_TO_ALLOWED_CONSTRAINT_TYPES = {
-    KnownRoomEventKeys.CONTENT_BODY: [SearchConstraintTypes.FTS],
-    KnownRoomEventKeys.CONTENT_MSGTYPE: [SearchConstraintTypes.EXACT],
-    KnownRoomEventKeys.CONTENT_NAME: [
-        SearchConstraintTypes.FTS,
-        SearchConstraintTypes.EXACT,
-        SearchConstraintTypes.SUBSTRING,
-    ],
-    KnownRoomEventKeys.CONTENT_TOPIC: [SearchConstraintTypes.FTS],
-    KnownRoomEventKeys.SENDER: [SearchConstraintTypes.EXACT],
-    KnownRoomEventKeys.ORIGIN_SERVER_TS: [SearchConstraintTypes.RANGE],
-    KnownRoomEventKeys.ROOM_ID: [SearchConstraintTypes.EXACT],
-}
-
-
-class RoomConstraint(object):
-    def __init__(self, search_type, keys, value):
-        self.search_type = search_type
-        self.keys = keys
-        self.value = value
-
-    @classmethod
-    def from_dict(cls, d):
-        search_type = d["type"]
-        keys = d["keys"]
-
-        for key in keys:
-            if key not in KEYS_TO_ALLOWED_CONSTRAINT_TYPES:
-                raise SynapseError(400, "Unrecognized key %r", key)
-
-            if search_type not in KEYS_TO_ALLOWED_CONSTRAINT_TYPES[key]:
-                raise SynapseError(
-                    400,
-                    "Disallowed constraint type %r for key %r", search_type, key
-                )
-
-        return cls(search_type, keys, d["value"])
-
-
 class SearchHandler(BaseHandler):
 
     def __init__(self, hs):
@@ -121,22 +82,20 @@ class SearchHandler(BaseHandler):
 
     @defer.inlineCallbacks
     def search(self, user, content):
-        constraint_dicts = content["search_categories"]["room_events"]["constraints"]
-        constraints = [RoomConstraint.from_dict(c)for c in constraint_dicts]
-
-        fts = False
-        for c in constraints:
-            if c.search_type ==
SearchConstraintTypes.FTS: - if fts: - raise SynapseError(400, "Only one constraint can be FTS") - fts = True + try: + search_term = content["search_categories"]["room_events"]["search_term"] + keys = content["search_categories"]["room_events"]["keys"] + except KeyError: + raise SynapseError(400, "Invalid search query") rooms = yield self.store.get_rooms_for_user_where_membership_is( - user.to_string(), membership_list=[Membership.JOIN, Membership.LEAVE], + user.to_string(), + membership_list=[Membership.JOIN], + # membership_list=[Membership.JOIN, Membership.LEAVE, Membership.Ban], ) room_ids = set(r.room_id for r in rooms) - rank_map, event_map = yield self.store.search_msgs(room_ids, constraints) + rank_map, event_map = yield self.store.search_msgs(room_ids, search_term, keys) allowed_events = yield self._filter_events_for_client( user.to_string(), event_map.values() diff --git a/synapse/storage/search.py b/synapse/storage/search.py index 5843f80876..7a30ce25eb 100644 --- a/synapse/storage/search.py +++ b/synapse/storage/search.py @@ -16,38 +16,28 @@ from twisted.internet import defer from _base import SQLBaseStore -from synapse.api.constants import KnownRoomEventKeys, SearchConstraintTypes from synapse.storage.engines import PostgresEngine class SearchStore(SQLBaseStore): @defer.inlineCallbacks - def search_msgs(self, room_ids, constraints): + def search_msgs(self, room_ids, search_term, keys): clauses = [] args = [] - fts = None clauses.append( "room_id IN (%s)" % (",".join(["?"] * len(room_ids)),) ) args.extend(room_ids) - for c in constraints: - local_clauses = [] - if c.search_type == SearchConstraintTypes.FTS: - fts = c.value - for key in c.keys: - local_clauses.append("key = ?") - args.append(key) - elif c.search_type == SearchConstraintTypes.EXACT: - for key in c.keys: - if key == KnownRoomEventKeys.ROOM_ID: - for value in c.value: - local_clauses.append("room_id = ?") - args.append(value) - clauses.append( - "(%s)" % (" OR ".join(local_clauses),) - ) + local_clauses = [] + for key in keys: + local_clauses.append("key = ?") + args.append(key) + + clauses.append( + "(%s)" % (" OR ".join(local_clauses),) + ) if isinstance(self.database_engine, PostgresEngine): sql = ( @@ -67,7 +57,7 @@ class SearchStore(SQLBaseStore): sql += " ORDER BY rank DESC" results = yield self._execute( - "search_msgs", self.cursor_to_dict, sql, *([fts] + args) + "search_msgs", self.cursor_to_dict, sql, *([search_term] + args) ) events = yield self._get_events([r["event_id"] for r in results]) -- cgit 1.5.1 From 7ecd11accb68cc0f20e7ab84673df38413ba7cf7 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 13 Oct 2015 15:50:56 +0100 Subject: Add paranoia limit --- synapse/storage/search.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'synapse/storage') diff --git a/synapse/storage/search.py b/synapse/storage/search.py index 7a30ce25eb..1b987161e2 100644 --- a/synapse/storage/search.py +++ b/synapse/storage/search.py @@ -54,7 +54,7 @@ class SearchStore(SQLBaseStore): for clause in clauses: sql += " AND " + clause - sql += " ORDER BY rank DESC" + sql += " ORDER BY rank DESC LIMIT 500" results = yield self._execute( "search_msgs", self.cursor_to_dict, sql, *([search_term] + args) -- cgit 1.5.1 From 858634e1d0ca489bb546851d6ee052d548870b06 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 14 Oct 2015 09:29:08 +0100 Subject: Remove unused room_id arg --- synapse/handlers/federation.py | 2 +- synapse/handlers/message.py | 10 +++++----- synapse/handlers/sync.py | 2 +- 
synapse/storage/state.py | 10 +++++----- 4 files changed, 12 insertions(+), 12 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 3882ba79ed..a710bdcfdb 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -242,7 +242,7 @@ class FederationHandler(BaseHandler): @defer.inlineCallbacks def _filter_events_for_server(self, server_name, room_id, events): event_to_state = yield self.store.get_state_for_events( - room_id, frozenset(e.event_id for e in events), + frozenset(e.event_id for e in events), types=( (EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, None), diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index b70258697b..dfeeae76db 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -164,7 +164,7 @@ class MessageHandler(BaseHandler): @defer.inlineCallbacks def _filter_events_for_client(self, user_id, room_id, events): event_id_to_state = yield self.store.get_state_for_events( - room_id, frozenset(e.event_id for e in events), + frozenset(e.event_id for e in events), types=( (EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, user_id), @@ -290,7 +290,7 @@ class MessageHandler(BaseHandler): elif member_event.membership == Membership.LEAVE: key = (event_type, state_key) room_state = yield self.store.get_state_for_events( - room_id, [member_event.event_id], [key] + [member_event.event_id], [key] ) data = room_state[member_event.event_id].get(key) @@ -314,7 +314,7 @@ class MessageHandler(BaseHandler): room_state = yield self.state_handler.get_current_state(room_id) elif member_event.membership == Membership.LEAVE: room_state = yield self.store.get_state_for_events( - room_id, [member_event.event_id], None + [member_event.event_id], None ) room_state = room_state[member_event.event_id] @@ -406,7 +406,7 @@ class MessageHandler(BaseHandler): elif event.membership == Membership.LEAVE: room_end_token = "s%d" % (event.stream_ordering,) deferred_room_state = self.store.get_state_for_events( - event.room_id, [event.event_id], None + [event.event_id], None ) deferred_room_state.addCallback( lambda states: states[event.event_id] @@ -499,7 +499,7 @@ class MessageHandler(BaseHandler): def _room_initial_sync_parted(self, user_id, room_id, pagin_config, member_event): room_state = yield self.store.get_state_for_events( - member_event.room_id, [member_event.event_id], None + [member_event.event_id], None ) room_state = room_state[member_event.event_id] diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 9914ff6f9c..a8940de166 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -312,7 +312,7 @@ class SyncHandler(BaseHandler): @defer.inlineCallbacks def _filter_events_for_client(self, user_id, room_id, events): event_id_to_state = yield self.store.get_state_for_events( - room_id, frozenset(e.event_id for e in events), + frozenset(e.event_id for e in events), types=( (EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, user_id), diff --git a/synapse/storage/state.py b/synapse/storage/state.py index e935b9443b..6f2a50d585 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -54,7 +54,7 @@ class StateStore(SQLBaseStore): defer.returnValue({}) event_to_groups = yield self._get_state_group_for_events( - room_id, event_ids, + event_ids, ) groups = set(event_to_groups.values()) @@ -208,7 +208,7 @@ class StateStore(SQLBaseStore): ) @defer.inlineCallbacks - def 
get_state_for_events(self, room_id, event_ids, types): + def get_state_for_events(self, event_ids, types): """Given a list of event_ids and type tuples, return a list of state dicts for each event. The state dicts will only have the type/state_keys that are in the `types` list. @@ -225,7 +225,7 @@ class StateStore(SQLBaseStore): The dicts are mappings from (type, state_key) -> state_events """ event_to_groups = yield self._get_state_group_for_events( - room_id, event_ids, + event_ids, ) groups = set(event_to_groups.values()) @@ -251,8 +251,8 @@ class StateStore(SQLBaseStore): ) @cachedList(cache=_get_state_group_for_event.cache, list_name="event_ids", - num_args=2) - def _get_state_group_for_events(self, room_id, event_ids): + num_args=1) + def _get_state_group_for_events(self, event_ids): """Returns mapping event_id -> state_group """ def f(txn): -- cgit 1.5.1 From 99c7fbfef7729e6f3cceb9cea64f21d5a2c5b41f Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 14 Oct 2015 09:52:40 +0100 Subject: Fix to work with SQLite --- synapse/storage/room.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/room.py b/synapse/storage/room.py index e4e830944a..0527cee05d 100644 --- a/synapse/storage/room.py +++ b/synapse/storage/room.py @@ -19,6 +19,7 @@ from synapse.api.errors import StoreError from ._base import SQLBaseStore from synapse.util.caches.descriptors import cachedInlineCallbacks +from .engines import PostgresEngine import collections import logging @@ -202,10 +203,16 @@ class RoomStore(SQLBaseStore): ) def _store_event_search_txn(self, txn, event, key, value): - sql = ( - "INSERT INTO event_search (event_id, room_id, key, vector)" - " VALUES (?,?,?,to_tsvector('english', ?))" - ) + if isinstance(self.database_engine, PostgresEngine): + sql = ( + "INSERT INTO event_search (event_id, room_id, key, vector)" + " VALUES (?,?,?,to_tsvector('english', ?))" + ) + else: + sql = ( + "INSERT INTO event_search (event_id, room_id, key, value)" + " VALUES (?,?,?,?)" + ) txn.execute(sql, (event.event_id, event.room_id, key, value,)) -- cgit 1.5.1 From 6296590bf7fb4a8f43755d29ecb8a4819310e25d Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Fri, 16 Oct 2015 10:50:32 +0100 Subject: Encode the filter JSON as UTF-8 before storing in the database. Because we are using a binary column type to store the filter JSON. 
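
In isolation, the round trip this change is aiming for looks roughly like the
sketch below (standalone Python; the filter dict is illustrative and the
actual column handling lives in the storage layer):

    import json

    # On write: serialise the filter definition to UTF-8 bytes, because the
    # filter_json column is a binary type.
    user_filter = {"room": {"timeline": {"limit": 10}}}
    def_json = json.dumps(user_filter).encode("utf-8")

    # On read: decode the bytes back to text before parsing.
    assert json.loads(def_json.decode("utf-8")) == user_filter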
--- synapse/storage/filtering.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/filtering.py b/synapse/storage/filtering.py index 8800116570..700da8d8c3 100644 --- a/synapse/storage/filtering.py +++ b/synapse/storage/filtering.py @@ -34,10 +34,10 @@ class FilteringStore(SQLBaseStore): desc="get_user_filter", ) - defer.returnValue(json.loads(def_json)) + defer.returnValue(json.loads(def_json.decode("utf-8"))) def add_user_filter(self, user_localpart, user_filter): - def_json = json.dumps(user_filter) + def_json = json.dumps(user_filter).encode("utf-8") # Need an atomic transaction to SELECT the maximal ID so far then # INSERT a new one -- cgit 1.5.1 From 22a8c91448f710c20a6aee66ec2a452528f1d637 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 16 Oct 2015 11:19:44 +0100 Subject: Split up run_upgrade --- synapse/storage/schema/delta/24/fts.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/schema/delta/24/fts.py b/synapse/storage/schema/delta/24/fts.py index b45a5fd820..0c752d8426 100644 --- a/synapse/storage/schema/delta/24/fts.py +++ b/synapse/storage/schema/delta/24/fts.py @@ -55,16 +55,24 @@ CREATE INDEX event_search_ev_ridx ON event_search(room_id); SQLITE_TABLE = ( "CREATE VIRTUAL TABLE event_search USING fts3 ( event_id, room_id, key, value)" ) -SQLITE_INDEX = "CREATE INDEX event_search_ev_idx ON event_search(event_id)" def run_upgrade(cur, database_engine, *args, **kwargs): if isinstance(database_engine, PostgresEngine): - for statement in get_statements(POSTGRES_SQL.splitlines()): - cur.execute(statement) + run_postgres_upgrade(cur) return if isinstance(database_engine, Sqlite3Engine): + run_sqlite_upgrade(cur) + return + + +def run_postgres_upgrade(cur): + for statement in get_statements(POSTGRES_SQL.splitlines()): + cur.execute(statement) + + +def run_sqlite_upgrade(cur): cur.execute(SQLITE_TABLE) rowid = -1 @@ -113,5 +121,3 @@ def run_upgrade(cur, database_engine, *args, **kwargs): " VALUES (?,?,?,?)", rows ) - - # cur.execute(SQLITE_INDEX) -- cgit 1.5.1 From 73260ad01f067495e541a936eef4a14ba2fea5ec Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 16 Oct 2015 11:24:02 +0100 Subject: Comment on the LIMIT 500 --- synapse/storage/search.py | 2 ++ 1 file changed, 2 insertions(+) (limited to 'synapse/storage') diff --git a/synapse/storage/search.py b/synapse/storage/search.py index 1b987161e2..7d642e18ff 100644 --- a/synapse/storage/search.py +++ b/synapse/storage/search.py @@ -54,6 +54,8 @@ class SearchStore(SQLBaseStore): for clause in clauses: sql += " AND " + clause + # We add an arbitrary limit here to ensure we don't try to pull the + # entire table from the database. sql += " ORDER BY rank DESC LIMIT 500" results = yield self._execute( -- cgit 1.5.1 From 3cf9948b8d5956c05026ee734ccf65d203eb6d6b Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 16 Oct 2015 11:28:12 +0100 Subject: Add docstring --- synapse/storage/search.py | 11 +++++++++++ 1 file changed, 11 insertions(+) (limited to 'synapse/storage') diff --git a/synapse/storage/search.py b/synapse/storage/search.py index 7d642e18ff..6c10f9631e 100644 --- a/synapse/storage/search.py +++ b/synapse/storage/search.py @@ -22,6 +22,17 @@ from synapse.storage.engines import PostgresEngine class SearchStore(SQLBaseStore): @defer.inlineCallbacks def search_msgs(self, room_ids, search_term, keys): + """Performs a full text search over events with give keys. 
+ + Args: + room_ids (list): List of room ids to search in + search_term (str): Search term to search for + keys (list): List of keys to search in, currently supports + "content.body", "content.name", "content.body" + + Returns: + 2-tuple of (dict event_id -> rank, dict event_id -> event) + """ clauses = [] args = [] -- cgit 1.5.1 From edb998ba23cf74de624963f61ca9c897260a3e7e Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 16 Oct 2015 14:37:14 +0100 Subject: Explicitly check for Sqlite3Engine --- synapse/storage/search.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/search.py b/synapse/storage/search.py index 6c10f9631e..dd012fa565 100644 --- a/synapse/storage/search.py +++ b/synapse/storage/search.py @@ -16,7 +16,7 @@ from twisted.internet import defer from _base import SQLBaseStore -from synapse.storage.engines import PostgresEngine +from synapse.storage.engines import PostgresEngine, Sqlite3Engine class SearchStore(SQLBaseStore): @@ -56,11 +56,14 @@ class SearchStore(SQLBaseStore): " FROM plainto_tsquery('english', ?) as query, event_search" " WHERE vector @@ query" ) - else: + elif isinstance(self.database_engine, Sqlite3Engine): sql = ( "SELECT 0 as rank, event_id FROM event_search" " WHERE value MATCH ?" ) + else: + # This should be unreachable. + raise Exception("Unrecognized database engine") for clause in clauses: sql += " AND " + clause -- cgit 1.5.1 From fc012aa8dc7a3c05c1824a2c333d6c73ebc8726e Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Fri, 16 Oct 2015 15:28:43 +0100 Subject: Fix FilteringStore.get_user_filter to work with postgres --- synapse/storage/filtering.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'synapse/storage') diff --git a/synapse/storage/filtering.py b/synapse/storage/filtering.py index 700da8d8c3..fcd43c7fdd 100644 --- a/synapse/storage/filtering.py +++ b/synapse/storage/filtering.py @@ -34,7 +34,7 @@ class FilteringStore(SQLBaseStore): desc="get_user_filter", ) - defer.returnValue(json.loads(def_json.decode("utf-8"))) + defer.returnValue(json.loads(str(def_json).decode("utf-8"))) def add_user_filter(self, user_localpart, user_filter): def_json = json.dumps(user_filter).encode("utf-8") -- cgit 1.5.1 From f2d698cb52883d8d43faabefdc70e2ade9ebb8b8 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 16 Oct 2015 16:46:48 +0100 Subject: Typing --- synapse/storage/search.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/search.py b/synapse/storage/search.py index dd012fa565..a3c69c5ab3 100644 --- a/synapse/storage/search.py +++ b/synapse/storage/search.py @@ -22,13 +22,13 @@ from synapse.storage.engines import PostgresEngine, Sqlite3Engine class SearchStore(SQLBaseStore): @defer.inlineCallbacks def search_msgs(self, room_ids, search_term, keys): - """Performs a full text search over events with give keys. + """Performs a full text search over events with given keys. 
Args: room_ids (list): List of room ids to search in search_term (str): Search term to search for keys (list): List of keys to search in, currently supports - "content.body", "content.name", "content.body" + "content.body", "content.name", "content.topic" Returns: 2-tuple of (dict event_id -> rank, dict event_id -> event) -- cgit 1.5.1 From 46d39343d976a933c3f2dfd19e5e552c01c93bf4 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 16 Oct 2015 16:58:00 +0100 Subject: Explicitly check for Sqlite3Engine --- synapse/storage/room.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/room.py b/synapse/storage/room.py index 0527cee05d..13441fcdce 100644 --- a/synapse/storage/room.py +++ b/synapse/storage/room.py @@ -19,7 +19,7 @@ from synapse.api.errors import StoreError from ._base import SQLBaseStore from synapse.util.caches.descriptors import cachedInlineCallbacks -from .engines import PostgresEngine +from .engines import PostgresEngine, Sqlite3Engine import collections import logging @@ -208,11 +208,14 @@ class RoomStore(SQLBaseStore): "INSERT INTO event_search (event_id, room_id, key, vector)" " VALUES (?,?,?,to_tsvector('english', ?))" ) - else: + elif isinstance(self.database_engine, Sqlite3Engine): sql = ( "INSERT INTO event_search (event_id, room_id, key, value)" " VALUES (?,?,?,?)" ) + else: + # This should be unreachable. + raise Exception("Unrecognized database engine") txn.execute(sql, (event.event_id, event.room_id, key, value,)) -- cgit 1.5.1 From 68b7fc3e2ba0aae7813b0bae52370860b5cd9f26 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Mon, 19 Oct 2015 17:26:18 +0100 Subject: Add rooms that the user has left under archived in v2 sync. --- synapse/handlers/sync.py | 128 ++++++++++++++++++++++++++++++++++- synapse/rest/client/v2_alpha/sync.py | 29 ++++++-- synapse/storage/roommember.py | 13 ++++ 3 files changed, 161 insertions(+), 9 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index ee6b881de1..1891cd088c 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -61,18 +61,37 @@ class JoinedSyncResult(collections.namedtuple("JoinedSyncResult", [ return bool(self.timeline or self.state or self.ephemeral) +class ArchivedSyncResult(collections.namedtuple("JoinedSyncResult", [ + "room_id", + "timeline", + "state", +])): + __slots__ = [] + + def __nonzero__(self): + """Make the result appear empty if there are no updates. This is used + to tell if room needs to be part of the sync result. + """ + return bool(self.timeline or self.state) + + class InvitedSyncResult(collections.namedtuple("InvitedSyncResult", [ "room_id", "invite", ])): __slots__ = [] + def __nonzero__(self): + """Invited rooms should always be reported to the client""" + return True + class SyncResult(collections.namedtuple("SyncResult", [ "next_batch", # Token for the next sync "presence", # List of presence events for the user. "joined", # JoinedSyncResult for each joined room. "invited", # InvitedSyncResult for each invited room. + "archived", # ArchivedSyncResult for each archived room. 
])): __slots__ = [] @@ -156,11 +175,14 @@ class SyncHandler(BaseHandler): ) room_list = yield self.store.get_rooms_for_user_where_membership_is( user_id=sync_config.user.to_string(), - membership_list=[Membership.INVITE, Membership.JOIN] + membership_list=[ + Membership.INVITE, Membership.JOIN, Membership.LEAVE + ] ) joined = [] invited = [] + archived = [] for event in room_list: if event.membership == Membership.JOIN: room_sync = yield self.initial_sync_for_joined_room( @@ -173,11 +195,23 @@ class SyncHandler(BaseHandler): room_id=event.room_id, invite=invite, )) + elif event.membership == Membership.LEAVE: + leave_token = now_token.copy_and_replace( + "room_key", "s%d" % (event.stream_ordering,) + ) + room_sync = yield self.initial_sync_for_archived_room( + sync_config=sync_config, + room_id=event.room_id, + leave_event_id=event.event_id, + leave_token=leave_token, + ) + archived.append(room_sync) defer.returnValue(SyncResult( presence=presence, joined=joined, invited=invited, + archived=archived, next_batch=now_token, )) @@ -204,6 +238,28 @@ class SyncHandler(BaseHandler): ephemeral=[], )) + @defer.inlineCallbacks + def initial_sync_for_archived_room(self, room_id, sync_config, + leave_event_id, leave_token): + """Sync a room for a client which is starting without any state + Returns: + A Deferred JoinedSyncResult. + """ + + batch = yield self.load_filtered_recents( + room_id, sync_config, leave_token, + ) + + leave_state = yield self.store.get_state_for_events( + [leave_event_id], None + ) + + defer.returnValue(ArchivedSyncResult( + room_id=room_id, + timeline=batch, + state=leave_state[leave_event_id].values(), + )) + @defer.inlineCallbacks def incremental_sync_with_gap(self, sync_config, since_token): """ Get the incremental delta needed to bring the client up to @@ -257,18 +313,22 @@ class SyncHandler(BaseHandler): ) joined = [] + archived = [] if len(room_events) <= timeline_limit: # There is no gap in any of the rooms. Therefore we can just # partition the new events by room and return them. 
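            # (Because the stream query itself was not truncated, membership
            # events for rooms the user is not joined to can be picked out of
            # this same batch instead of being fetched separately.)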
invite_events = [] + leave_events = [] events_by_room_id = {} for event in room_events: events_by_room_id.setdefault(event.room_id, []).append(event) if event.room_id not in joined_room_ids: if (event.type == EventTypes.Member - and event.membership == Membership.INVITE and event.state_key == sync_config.user.to_string()): - invite_events.append(event) + if event.membership == Membership.INVITE: + invite_events.append(event) + elif event.membership == Membership.LEAVE: + leave_events.append(event) for room_id in joined_room_ids: recents = events_by_room_id.get(room_id, []) @@ -296,11 +356,16 @@ class SyncHandler(BaseHandler): ) if room_sync: joined.append(room_sync) + else: invite_events = yield self.store.get_invites_for_user( sync_config.user.to_string() ) + leave_events = yield self.store.get_leave_events_for_user( + sync_config.user.to_string() + ) + for room_id in joined_room_ids: room_sync = yield self.incremental_sync_with_gap_for_room( room_id, sync_config, since_token, now_token, @@ -309,6 +374,12 @@ class SyncHandler(BaseHandler): if room_sync: joined.append(room_sync) + for leave_event in leave_events: + room_sync = yield self.incremental_sync_for_archived_room( + sync_config, leave_event, since_token + ) + archived.append(room_sync) + invited = [ InvitedSyncResult(room_id=event.room_id, invite=event) for event in invite_events @@ -318,6 +389,7 @@ class SyncHandler(BaseHandler): presence=presence, joined=joined, invited=invited, + archived=archived, next_batch=now_token, )) @@ -416,6 +488,56 @@ class SyncHandler(BaseHandler): defer.returnValue(room_sync) + @defer.inlineCallbacks + def incremental_sync_for_archived_room(self, sync_config, leave_event, + since_token): + """ Get the incremental delta needed to bring the client up to date for + the archived room. + Returns: + A Deferred ArchivedSyncResult + """ + + stream_token = yield self.store.get_stream_token_for_event( + leave_event.event_id + ) + + leave_token = since_token.copy_and_replace("room_key", stream_token) + + batch = yield self.load_filtered_recents( + leave_event.room_id, sync_config, leave_token, since_token, + ) + + logging.debug("Recents %r", batch) + + # TODO(mjark): This seems racy since this isn't being passed a + # token to indicate what point in the stream this is + leave_state = yield self.store.get_state_for_events( + [leave_event.event_id], None + ) + + state_events_at_leave = leave_state[leave_event.event_id].values() + + state_at_previous_sync = yield self.get_state_at_previous_sync( + leave_event.room_id, since_token=since_token + ) + + state_events_delta = yield self.compute_state_delta( + since_token=since_token, + previous_state=state_at_previous_sync, + current_state=state_events_at_leave, + ) + + room_sync = ArchivedSyncResult( + room_id=leave_event.room_id, + timeline=batch, + state=state_events_delta, + ) + + logging.debug("Room sync: %r", room_sync) + + defer.returnValue(room_sync) + + @defer.inlineCallbacks def get_state_at_previous_sync(self, room_id, since_token): """ Get the room state at the previous sync the client made. 
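
For orientation, the response shape the servlet below is building toward is
roughly this (a sketch only; the keys are the ones used in this commit, the
values are illustrative, and "ephemeral" is deliberately absent for archived
rooms):

    sync_response = {
        "next_batch": "...",
        "rooms": {
            "joined": {},
            "invited": {},
            "archived": {
                "!room:example.com": {
                    "event_map": {},
                    "timeline": {"events": [], "prev_batch": "...", "limited": False},
                    "state": {"events": []},
                },
            },
        },
    }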
diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index fffecb24f5..73473a7e6b 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -136,6 +136,10 @@ class SyncRestServlet(RestServlet): sync_result.invited, filter, time_now, token_id ) + archived = self.encode_archived( + sync_result.archived, filter, time_now, token_id + ) + response_content = { "presence": self.encode_presence( sync_result.presence, filter, time_now @@ -143,7 +147,7 @@ class SyncRestServlet(RestServlet): "rooms": { "joined": joined, "invited": invited, - "archived": {}, + "archived": archived, }, "next_batch": sync_result.next_batch.to_string(), } @@ -182,14 +186,20 @@ class SyncRestServlet(RestServlet): return invited + def encode_archived(self, rooms, filter, time_now, token_id): + joined = {} + for room in rooms: + joined[room.room_id] = self.encode_room( + room, filter, time_now, token_id, joined=False + ) + + return joined + @staticmethod - def encode_room(room, filter, time_now, token_id): + def encode_room(room, filter, time_now, token_id, joined=True): event_map = {} state_events = filter.filter_room_state(room.state) - timeline_events = filter.filter_room_timeline(room.timeline.events) - ephemeral_events = filter.filter_room_ephemeral(room.ephemeral) state_event_ids = [] - timeline_event_ids = [] for event in state_events: # TODO(mjark): Respect formatting requirements in the filter. event_map[event.event_id] = serialize_event( @@ -198,6 +208,8 @@ class SyncRestServlet(RestServlet): ) state_event_ids.append(event.event_id) + timeline_events = filter.filter_room_timeline(room.timeline.events) + timeline_event_ids = [] for event in timeline_events: # TODO(mjark): Respect formatting requirements in the filter. event_map[event.event_id] = serialize_event( @@ -205,6 +217,7 @@ class SyncRestServlet(RestServlet): event_format=format_event_for_client_v2_without_event_id, ) timeline_event_ids.append(event.event_id) + result = { "event_map": event_map, "timeline": { @@ -213,8 +226,12 @@ class SyncRestServlet(RestServlet): "limited": room.timeline.limited, }, "state": {"events": state_event_ids}, - "ephemeral": {"events": ephemeral_events}, } + + if joined: + ephemeral_events = filter.filter_room_ephemeral(room.ephemeral) + result["ephemeral"] = {"events": ephemeral_events} + return result diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index dd98dcfda8..623400fd36 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -124,6 +124,19 @@ class RoomMemberStore(SQLBaseStore): invites.event_id for invite in invites ])) + def get_leave_events_for_user(self, user_id): + """ Get all the leave events for a user + Args: + user_id (str): The user ID. + Returns: + A deferred list of event objects. + """ + return self.get_rooms_for_user_where_membership_is( + user_id, [Membership.LEAVE] + ).addCallback(lambda leaves: self._get_events([ + leave.event_id for leave in leaves + ])) + def get_rooms_for_user_where_membership_is(self, user_id, membership_list): """ Get all the rooms for this user where the membership for this user matches one in the membership list. 
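
A minimal usage sketch of the two storage helpers this commit relies on (the
surrounding handler wiring is assumed for illustration, not Synapse's actual
code):

    from twisted.internet import defer

    @defer.inlineCallbacks
    def invites_and_leaves(store, user_id):
        # Both helpers resolve, via _get_events, to lists of event objects.
        invite_events = yield store.get_invites_for_user(user_id)
        leave_events = yield store.get_leave_events_for_user(user_id)
        defer.returnValue((invite_events, leave_events))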
-- cgit 1.5.1 From e3d75f564ab1d17eb4ee47314f78c41553a486f1 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Wed, 21 Oct 2015 11:15:48 +0100 Subject: Include banned rooms in the archived section of v2 sync --- synapse/handlers/sync.py | 15 +++++++++------ synapse/storage/roommember.py | 4 ++-- 2 files changed, 11 insertions(+), 8 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 5ca2606443..6cb756e476 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -175,9 +175,12 @@ class SyncHandler(BaseHandler): ) room_list = yield self.store.get_rooms_for_user_where_membership_is( user_id=sync_config.user.to_string(), - membership_list=[ - Membership.INVITE, Membership.JOIN, Membership.LEAVE - ] + membership_list=( + Membership.INVITE, + Membership.JOIN, + Membership.LEAVE, + Membership.BAN + ) ) joined = [] @@ -195,7 +198,7 @@ class SyncHandler(BaseHandler): room_id=event.room_id, invite=invite, )) - elif event.membership == Membership.LEAVE: + elif event.membership in (Membership.LEAVE, Membership.BAN): leave_token = now_token.copy_and_replace( "room_key", "s%d" % (event.stream_ordering,) ) @@ -327,7 +330,7 @@ class SyncHandler(BaseHandler): and event.state_key == sync_config.user.to_string()): if event.membership == Membership.INVITE: invite_events.append(event) - elif event.membership == Membership.LEAVE: + elif event.membership in (Membership.LEAVE, Membership.BAN): leave_events.append(event) for room_id in joined_room_ids: @@ -362,7 +365,7 @@ class SyncHandler(BaseHandler): sync_config.user.to_string() ) - leave_events = yield self.store.get_leave_events_for_user( + leave_events = yield self.store.get_leave_and_ban_events_for_user( sync_config.user.to_string() ) diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 623400fd36..ae1ad56d9a 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -124,7 +124,7 @@ class RoomMemberStore(SQLBaseStore): invites.event_id for invite in invites ])) - def get_leave_events_for_user(self, user_id): + def get_leave_and_ban_events_for_user(self, user_id): """ Get all the leave events for a user Args: user_id (str): The user ID. @@ -132,7 +132,7 @@ class RoomMemberStore(SQLBaseStore): A deferred list of event objects. """ return self.get_rooms_for_user_where_membership_is( - user_id, [Membership.LEAVE] + user_id, (Membership.LEAVE, Membership.BAN) ).addCallback(lambda leaves: self._get_events([ leave.event_id for leave in leaves ])) -- cgit 1.5.1 From 4d25bc6c92c3d219786892c5d510e0c4c4eb8b96 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 22 Oct 2015 11:12:28 +0100 Subject: Move FTS to delta 25 --- synapse/storage/prepare_database.py | 2 +- synapse/storage/schema/delta/24/fts.py | 123 --------------------------------- synapse/storage/schema/delta/25/fts.py | 123 +++++++++++++++++++++++++++++++++ 3 files changed, 124 insertions(+), 124 deletions(-) delete mode 100644 synapse/storage/schema/delta/24/fts.py create mode 100644 synapse/storage/schema/delta/25/fts.py (limited to 'synapse/storage') diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index 1ddf55be4d..1a74d6e360 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -25,7 +25,7 @@ logger = logging.getLogger(__name__) # Remember to update this number every time a change is made to database # schema files, so the users will be informed on server restarts. 
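 # (Bumping this constant is what makes the relocated delta take effect:
 # servers already at version 24 will run the schema/delta/25/ files on
 # their next restart.)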
-SCHEMA_VERSION = 24 +SCHEMA_VERSION = 25 dir_path = os.path.abspath(os.path.dirname(__file__)) diff --git a/synapse/storage/schema/delta/24/fts.py b/synapse/storage/schema/delta/24/fts.py deleted file mode 100644 index 0c752d8426..0000000000 --- a/synapse/storage/schema/delta/24/fts.py +++ /dev/null @@ -1,123 +0,0 @@ -# Copyright 2015 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging - -from synapse.storage.prepare_database import get_statements -from synapse.storage.engines import PostgresEngine, Sqlite3Engine - -import ujson - -logger = logging.getLogger(__name__) - - -POSTGRES_SQL = """ -CREATE TABLE event_search ( - event_id TEXT, - room_id TEXT, - key TEXT, - vector tsvector -); - -INSERT INTO event_search SELECT - event_id, room_id, 'content.body', - to_tsvector('english', json::json->'content'->>'body') - FROM events NATURAL JOIN event_json WHERE type = 'm.room.message'; - -INSERT INTO event_search SELECT - event_id, room_id, 'content.name', - to_tsvector('english', json::json->'content'->>'name') - FROM events NATURAL JOIN event_json WHERE type = 'm.room.name'; - -INSERT INTO event_search SELECT - event_id, room_id, 'content.topic', - to_tsvector('english', json::json->'content'->>'topic') - FROM events NATURAL JOIN event_json WHERE type = 'm.room.topic'; - - -CREATE INDEX event_search_fts_idx ON event_search USING gin(vector); -CREATE INDEX event_search_ev_idx ON event_search(event_id); -CREATE INDEX event_search_ev_ridx ON event_search(room_id); -""" - - -SQLITE_TABLE = ( - "CREATE VIRTUAL TABLE event_search USING fts3 ( event_id, room_id, key, value)" -) - - -def run_upgrade(cur, database_engine, *args, **kwargs): - if isinstance(database_engine, PostgresEngine): - run_postgres_upgrade(cur) - return - - if isinstance(database_engine, Sqlite3Engine): - run_sqlite_upgrade(cur) - return - - -def run_postgres_upgrade(cur): - for statement in get_statements(POSTGRES_SQL.splitlines()): - cur.execute(statement) - - -def run_sqlite_upgrade(cur): - cur.execute(SQLITE_TABLE) - - rowid = -1 - while True: - cur.execute( - "SELECT rowid, json FROM event_json" - " WHERE rowid > ?" 
- " ORDER BY rowid ASC LIMIT 100", - (rowid,) - ) - - res = cur.fetchall() - - if not res: - break - - events = [ - ujson.loads(js) - for _, js in res - ] - - rowid = max(rid for rid, _ in res) - - rows = [] - for ev in events: - if ev["type"] == "m.room.message": - rows.append(( - ev["event_id"], ev["room_id"], "content.body", - ev["content"]["body"] - )) - if ev["type"] == "m.room.name": - rows.append(( - ev["event_id"], ev["room_id"], "content.name", - ev["content"]["name"] - )) - if ev["type"] == "m.room.topic": - rows.append(( - ev["event_id"], ev["room_id"], "content.topic", - ev["content"]["topic"] - )) - - if rows: - logger.info(rows) - cur.executemany( - "INSERT INTO event_search (event_id, room_id, key, value)" - " VALUES (?,?,?,?)", - rows - ) diff --git a/synapse/storage/schema/delta/25/fts.py b/synapse/storage/schema/delta/25/fts.py new file mode 100644 index 0000000000..fd181fc630 --- /dev/null +++ b/synapse/storage/schema/delta/25/fts.py @@ -0,0 +1,123 @@ +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from synapse.storage.prepare_database import get_statements +from synapse.storage.engines import PostgresEngine, Sqlite3Engine + +import ujson + +logger = logging.getLogger(__name__) + + +POSTGRES_SQL = """ +CREATE TABLE IF NOT EXISTS event_search ( + event_id TEXT, + room_id TEXT, + key TEXT, + vector tsvector +); + +INSERT INTO event_search SELECT + event_id, room_id, 'content.body', + to_tsvector('english', json::json->'content'->>'body') + FROM events NATURAL JOIN event_json WHERE type = 'm.room.message'; + +INSERT INTO event_search SELECT + event_id, room_id, 'content.name', + to_tsvector('english', json::json->'content'->>'name') + FROM events NATURAL JOIN event_json WHERE type = 'm.room.name'; + +INSERT INTO event_search SELECT + event_id, room_id, 'content.topic', + to_tsvector('english', json::json->'content'->>'topic') + FROM events NATURAL JOIN event_json WHERE type = 'm.room.topic'; + + +CREATE INDEX event_search_fts_idx ON event_search USING gin(vector); +CREATE INDEX event_search_ev_idx ON event_search(event_id); +CREATE INDEX event_search_ev_ridx ON event_search(room_id); +""" + + +SQLITE_TABLE = ( + "CREATE VIRTUAL TABLE IF NOT EXISTS event_search USING fts3 ( event_id, room_id, key, value)" +) + + +def run_upgrade(cur, database_engine, *args, **kwargs): + if isinstance(database_engine, PostgresEngine): + run_postgres_upgrade(cur) + return + + if isinstance(database_engine, Sqlite3Engine): + run_sqlite_upgrade(cur) + return + + +def run_postgres_upgrade(cur): + for statement in get_statements(POSTGRES_SQL.splitlines()): + cur.execute(statement) + + +def run_sqlite_upgrade(cur): + cur.execute(SQLITE_TABLE) + + rowid = -1 + while True: + cur.execute( + "SELECT rowid, json FROM event_json" + " WHERE rowid > ?" 
+ " ORDER BY rowid ASC LIMIT 100", + (rowid,) + ) + + res = cur.fetchall() + + if not res: + break + + events = [ + ujson.loads(js) + for _, js in res + ] + + rowid = max(rid for rid, _ in res) + + rows = [] + for ev in events: + if ev["type"] == "m.room.message": + rows.append(( + ev["event_id"], ev["room_id"], "content.body", + ev["content"]["body"] + )) + if ev["type"] == "m.room.name": + rows.append(( + ev["event_id"], ev["room_id"], "content.name", + ev["content"]["name"] + )) + if ev["type"] == "m.room.topic": + rows.append(( + ev["event_id"], ev["room_id"], "content.topic", + ev["content"]["topic"] + )) + + if rows: + logger.info(rows) + cur.executemany( + "INSERT INTO event_search (event_id, room_id, key, value)" + " VALUES (?,?,?,?)", + rows + ) -- cgit 1.5.1 From f142898f529f89c5bd90710db2d0771894b0b33f Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 22 Oct 2015 11:18:01 +0100 Subject: PEP8 --- synapse/storage/schema/delta/25/fts.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'synapse/storage') diff --git a/synapse/storage/schema/delta/25/fts.py b/synapse/storage/schema/delta/25/fts.py index fd181fc630..ed3cc06557 100644 --- a/synapse/storage/schema/delta/25/fts.py +++ b/synapse/storage/schema/delta/25/fts.py @@ -53,7 +53,8 @@ CREATE INDEX event_search_ev_ridx ON event_search(room_id); SQLITE_TABLE = ( - "CREATE VIRTUAL TABLE IF NOT EXISTS event_search USING fts3 ( event_id, room_id, key, value)" + "CREATE VIRTUAL TABLE IF NOT EXISTS event_search" + " USING fts3 ( event_id, room_id, key, value)" ) -- cgit 1.5.1 From ba02bba88c6895bc9196d226a2bbc8047b93e28b Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 22 Oct 2015 13:25:22 +0100 Subject: Limit max number of SQL vars --- synapse/storage/search.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/search.py b/synapse/storage/search.py index a3c69c5ab3..810b5406ad 100644 --- a/synapse/storage/search.py +++ b/synapse/storage/search.py @@ -36,10 +36,12 @@ class SearchStore(SQLBaseStore): clauses = [] args = [] - clauses.append( - "room_id IN (%s)" % (",".join(["?"] * len(room_ids)),) - ) - args.extend(room_ids) + # Make sure we don't explode because the person is in too many rooms. 
+ if len(room_ids) > 500: + clauses.append( + "room_id IN (%s)" % (",".join(["?"] * len(room_ids)),) + ) + args.extend(room_ids) local_clauses = [] for key in keys: -- cgit 1.5.1 From 232beb3a3c28ccdc5388daa9396d5054b7768b12 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 22 Oct 2015 15:02:35 +0100 Subject: Use namedtuple as return value --- synapse/handlers/search.py | 4 +++- synapse/storage/search.py | 18 +++++++++++++++--- 2 files changed, 18 insertions(+), 4 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index f53e5d35ac..bdc79ffc55 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -66,7 +66,9 @@ class SearchHandler(BaseHandler): room_ids = filtr.filter_rooms(room_ids) - rank_map, event_map = yield self.store.search_msgs(room_ids, search_term, keys) + rank_map, event_map, _ = yield self.store.search_msgs( + room_ids, search_term, keys + ) filtered_events = filtr.filter(event_map.values()) diff --git a/synapse/storage/search.py b/synapse/storage/search.py index 810b5406ad..41451ade57 100644 --- a/synapse/storage/search.py +++ b/synapse/storage/search.py @@ -18,6 +18,17 @@ from twisted.internet import defer from _base import SQLBaseStore from synapse.storage.engines import PostgresEngine, Sqlite3Engine +from collections import namedtuple + +"""The result of a search. + +Fields: + rank_map (dict): Mapping event_id -> rank + event_map (dict): Mapping event_id -> event + pagination_token (str): Pagination token +""" +SearchResult = namedtuple("SearchResult", ("rank_map", "event_map", "pagination_token")) + class SearchStore(SQLBaseStore): @defer.inlineCallbacks @@ -31,7 +42,7 @@ class SearchStore(SQLBaseStore): "content.body", "content.name", "content.topic" Returns: - 2-tuple of (dict event_id -> rank, dict event_id -> event) + SearchResult """ clauses = [] args = [] @@ -85,11 +96,12 @@ class SearchStore(SQLBaseStore): for ev in events } - defer.returnValue(( + defer.returnValue(SearchResult( { r["event_id"]: r["rank"] for r in results if r["event_id"] in event_map }, - event_map + event_map, + None )) -- cgit 1.5.1 From fb0fecd0b9f430670ca6de1f4746601dd6cc24f8 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 22 Oct 2015 16:18:35 +0100 Subject: LESS THAN --- synapse/storage/search.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'synapse/storage') diff --git a/synapse/storage/search.py b/synapse/storage/search.py index 41451ade57..e658e07dc7 100644 --- a/synapse/storage/search.py +++ b/synapse/storage/search.py @@ -48,7 +48,8 @@ class SearchStore(SQLBaseStore): args = [] # Make sure we don't explode because the person is in too many rooms. - if len(room_ids) > 500: + # We filter the results regardless. + if len(room_ids) < 500: clauses.append( "room_id IN (%s)" % (",".join(["?"] * len(room_ids)),) ) -- cgit 1.5.1 From 671ac699f1d4672bed3817b3cafb7498df99c030 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 22 Oct 2015 16:54:56 +0100 Subject: Actually filter results --- synapse/storage/search.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/search.py b/synapse/storage/search.py index e658e07dc7..9608b5d6a7 100644 --- a/synapse/storage/search.py +++ b/synapse/storage/search.py @@ -48,7 +48,7 @@ class SearchStore(SQLBaseStore): args = [] # Make sure we don't explode because the person is in too many rooms. - # We filter the results regardless. + # We filter the results below regardless. 
if len(room_ids) < 500: clauses.append( "room_id IN (%s)" % (",".join(["?"] * len(room_ids)),) @@ -66,13 +66,13 @@ class SearchStore(SQLBaseStore): if isinstance(self.database_engine, PostgresEngine): sql = ( - "SELECT ts_rank_cd(vector, query) AS rank, event_id" + "SELECT ts_rank_cd(vector, query) AS rank, room_id, event_id" " FROM plainto_tsquery('english', ?) as query, event_search" " WHERE vector @@ query" ) elif isinstance(self.database_engine, Sqlite3Engine): sql = ( - "SELECT 0 as rank, event_id FROM event_search" + "SELECT 0 as rank, room_id, event_id FROM event_search" " WHERE value MATCH ?" ) else: @@ -90,6 +90,8 @@ class SearchStore(SQLBaseStore): "search_msgs", self.cursor_to_dict, sql, *([search_term] + args) ) + results = filter(lambda row: row["room_id"] in room_ids, results) + events = yield self._get_events([r["event_id"] for r in results]) event_map = { -- cgit 1.5.1 From 0c36098c1fe561629651e446589a644fed9188e4 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 23 Oct 2015 13:23:48 +0100 Subject: Implement rank function for SQLite FTS --- synapse/storage/engines/sqlite3.py | 27 +++++++++++++++++++++++++++ synapse/storage/schema/delta/25/fts.py | 2 +- synapse/storage/search.py | 3 ++- 3 files changed, 30 insertions(+), 2 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/engines/sqlite3.py b/synapse/storage/engines/sqlite3.py index bad3b5c5ac..a5a54ec011 100644 --- a/synapse/storage/engines/sqlite3.py +++ b/synapse/storage/engines/sqlite3.py @@ -17,6 +17,8 @@ from synapse.storage.prepare_database import ( prepare_database, prepare_sqlite3_database ) +import struct + class Sqlite3Engine(object): single_threaded = True @@ -32,6 +34,7 @@ class Sqlite3Engine(object): def on_new_connection(self, db_conn): self.prepare_database(db_conn) + db_conn.create_function("rank", 1, _rank) def prepare_database(self, db_conn): prepare_sqlite3_database(db_conn) @@ -45,3 +48,27 @@ class Sqlite3Engine(object): def lock_table(self, txn, table): return + + +# Following functions taken from: https://github.com/coleifer/peewee + +def _parse_match_info(buf): + bufsize = len(buf) + return [struct.unpack('@I', buf[i:i+4])[0] for i in range(0, bufsize, 4)] + + +def _rank(raw_match_info): + """Handle match_info called w/default args 'pcx' - based on the example rank + function http://sqlite.org/fts3.html#appendix_a + """ + match_info = _parse_match_info(raw_match_info) + score = 0.0 + p, c = match_info[:2] + for phrase_num in range(p): + phrase_info_idx = 2 + (phrase_num * c * 3) + for col_num in range(c): + col_idx = phrase_info_idx + (col_num * 3) + x1, x2 = match_info[col_idx:col_idx + 2] + if x1 > 0: + score += float(x1) / x2 + return score diff --git a/synapse/storage/schema/delta/25/fts.py b/synapse/storage/schema/delta/25/fts.py index ed3cc06557..3f0b02d119 100644 --- a/synapse/storage/schema/delta/25/fts.py +++ b/synapse/storage/schema/delta/25/fts.py @@ -54,7 +54,7 @@ CREATE INDEX event_search_ev_ridx ON event_search(room_id); SQLITE_TABLE = ( "CREATE VIRTUAL TABLE IF NOT EXISTS event_search" - " USING fts3 ( event_id, room_id, key, value)" + " USING fts4 ( event_id, room_id, key, value )" ) diff --git a/synapse/storage/search.py b/synapse/storage/search.py index 9608b5d6a7..cdf003502f 100644 --- a/synapse/storage/search.py +++ b/synapse/storage/search.py @@ -72,7 +72,8 @@ class SearchStore(SQLBaseStore): ) elif isinstance(self.database_engine, Sqlite3Engine): sql = ( - "SELECT 0 as rank, room_id, event_id FROM event_search" + "SELECT rank(matchinfo(event_search)) 
as rank, room_id, event_id" + " FROM event_search" " WHERE value MATCH ?" ) else: -- cgit 1.5.1 From 4cf633d5e9628a56cf9b400ee3e073fcdcb365f0 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 23 Oct 2015 15:41:36 +0100 Subject: Pull out sender when computing search results --- synapse/storage/schema/delta/25/fts.py | 31 +++++++++++++++++-------------- 1 file changed, 17 insertions(+), 14 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/schema/delta/25/fts.py b/synapse/storage/schema/delta/25/fts.py index 3f0b02d119..056487af30 100644 --- a/synapse/storage/schema/delta/25/fts.py +++ b/synapse/storage/schema/delta/25/fts.py @@ -26,22 +26,23 @@ POSTGRES_SQL = """ CREATE TABLE IF NOT EXISTS event_search ( event_id TEXT, room_id TEXT, + sender TEXT, key TEXT, vector tsvector ); INSERT INTO event_search SELECT - event_id, room_id, 'content.body', + event_id, room_id, json::json->>'sender', 'content.body', to_tsvector('english', json::json->'content'->>'body') FROM events NATURAL JOIN event_json WHERE type = 'm.room.message'; INSERT INTO event_search SELECT - event_id, room_id, 'content.name', + event_id, room_id, json::json->>'sender', 'content.name', to_tsvector('english', json::json->'content'->>'name') FROM events NATURAL JOIN event_json WHERE type = 'm.room.name'; INSERT INTO event_search SELECT - event_id, room_id, 'content.topic', + event_id, room_id, json::json->>'sender', 'content.topic', to_tsvector('english', json::json->'content'->>'topic') FROM events NATURAL JOIN event_json WHERE type = 'm.room.topic'; @@ -99,26 +100,28 @@ def run_sqlite_upgrade(cur): rows = [] for ev in events: - if ev["type"] == "m.room.message": + content = ev.get("content", {}) + body = content.get("body", None) + name = content.get("name", None) + topic = content.get("topic", None) + sender = ev.get("sender", None) + if ev["type"] == "m.room.message" and body: rows.append(( - ev["event_id"], ev["room_id"], "content.body", - ev["content"]["body"] + ev["event_id"], ev["room_id"], sender, "content.body", body )) - if ev["type"] == "m.room.name": + if ev["type"] == "m.room.name" and name: rows.append(( - ev["event_id"], ev["room_id"], "content.name", - ev["content"]["name"] + ev["event_id"], ev["room_id"], sender, "content.name", name )) - if ev["type"] == "m.room.topic": + if ev["type"] == "m.room.topic" and topic: rows.append(( - ev["event_id"], ev["room_id"], "content.topic", - ev["content"]["topic"] + ev["event_id"], ev["room_id"], sender, "content.topic", topic )) if rows: logger.info(rows) cur.executemany( - "INSERT INTO event_search (event_id, room_id, key, value)" - " VALUES (?,?,?,?)", + "INSERT INTO event_search (event_id, room_id, sender, key, value)" + " VALUES (?,?,?,?,?)", rows ) -- cgit 1.5.1 From 5cb298c934adf9c7f42d06ea1ac74ab2bc651c23 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 28 Oct 2015 13:45:56 +0000 Subject: Add room context api --- synapse/handlers/__init__.py | 3 +- synapse/handlers/room.py | 42 ++++++++++++++++ synapse/rest/client/v1/room.py | 36 +++++++++++++ synapse/storage/stream.py | 111 ++++++++++++++++++++++++++++++++++++++++- 4 files changed, 190 insertions(+), 2 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/handlers/__init__.py b/synapse/handlers/__init__.py index 87b4d381c7..6a2339f2eb 100644 --- a/synapse/handlers/__init__.py +++ b/synapse/handlers/__init__.py @@ -17,7 +17,7 @@ from synapse.appservice.scheduler import AppServiceScheduler from synapse.appservice.api import ApplicationServiceApi from .register import 
RegistrationHandler
from .room import (
-    RoomCreationHandler, RoomMemberHandler, RoomListHandler
+    RoomCreationHandler, RoomMemberHandler, RoomListHandler, RoomContextHandler,
)
from .message import MessageHandler
from .events import EventStreamHandler, EventHandler
@@ -70,3 +70,4 @@ class Handlers(object):
         self.auth_handler = AuthHandler(hs)
         self.identity_handler = IdentityHandler(hs)
         self.search_handler = SearchHandler(hs)
+        self.room_context_handler = RoomContextHandler(hs)
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 60f9fa58b0..dd93a5d04d 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -33,6 +33,7 @@ from collections import OrderedDict
 from unpaddedbase64 import decode_base64
 
 import logging
+import math
 import string
 
 logger = logging.getLogger(__name__)
@@ -747,6 +748,47 @@ class RoomListHandler(BaseHandler):
         defer.returnValue({"start": "START", "end": "END", "chunk": chunk})
 
 
+class RoomContextHandler(BaseHandler):
+    @defer.inlineCallbacks
+    def get_event_context(self, user, room_id, event_id, limit):
+        before_limit = math.floor(limit/2.)
+        after_limit = limit - before_limit
+
+        now_token = yield self.hs.get_event_sources().get_current_token()
+
+        results = yield self.store.get_events_around(
+            room_id, event_id, before_limit, after_limit
+        )
+
+        results["events_before"] = yield self._filter_events_for_client(
+            user.to_string(), results["events_before"]
+        )
+
+        results["events_after"] = yield self._filter_events_for_client(
+            user.to_string(), results["events_after"]
+        )
+
+        if results["events_after"]:
+            last_event_id = results["events_after"][-1].event_id
+        else:
+            last_event_id = event_id
+
+        state = yield self.store.get_state_for_events(
+            [last_event_id], None
+        )
+        results["state"] = state[last_event_id].values()
+
+        results["start"] = now_token.copy_and_replace(
+            "room_key", results["start"]
+        ).to_string()
+
+        results["end"] = now_token.copy_and_replace(
+            "room_key", results["end"]
+        ).to_string()
+
+        defer.returnValue(results)
+
+
 class RoomEventSource(object):
     def __init__(self, hs):
         self.store = hs.get_datastore()
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 4cee1c1599..2dcaee86cd 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -397,6 +397,41 @@ class RoomTriggerBackfill(ClientV1RestServlet):
         defer.returnValue((200, res))
 
 
+class RoomEventContext(ClientV1RestServlet):
+    PATTERN = client_path_pattern(
+        "/rooms/(?P<room_id>[^/]*)/context/(?P<event_id>[^/]*)$"
+    )
+
+    def __init__(self, hs):
+        super(RoomEventContext, self).__init__(hs)
+        self.clock = hs.get_clock()
+
+    @defer.inlineCallbacks
+    def on_GET(self, request, room_id, event_id):
+        user, _ = yield self.auth.get_user_by_req(request)
+
+        limit = int(request.args.get("limit", [10])[0])
+
+        results = yield self.handlers.room_context_handler.get_event_context(
+            user, room_id, event_id, limit,
+        )
+
+        time_now = self.clock.time_msec()
+        results["events_before"] = [
+            serialize_event(event, time_now) for event in results["events_before"]
+        ]
+        results["events_after"] = [
+            serialize_event(event, time_now) for event in results["events_after"]
+        ]
+        results["state"] = [
+            serialize_event(event, time_now) for event in results["state"]
+        ]
+
+        logger.info("Responding with %r", results)
+
+        defer.returnValue((200, results))
+
+
 # TODO: Needs unit testing
 class RoomMembershipRestServlet(ClientV1RestServlet):
@@ -628,3 +663,4 @@ def register_servlets(hs, http_server):
     RoomRedactEventRestServlet(hs).register(http_server)
RoomTypingRestServlet(hs).register(http_server) SearchRestServlet(hs).register(http_server) + RoomEventContext(hs).register(http_server) diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index 3cab06fdef..f2eecf52f9 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -23,7 +23,7 @@ paginate bacwards. This is implemented by keeping two ordering columns: stream_ordering and topological_ordering. Stream ordering is basically insertion/received order -(except for events from backfill requests). The topolgical_ordering is a +(except for events from backfill requests). The topological_ordering is a weak ordering of events based on the pdu graph. This means that we have to have two different types of tokens, depending on @@ -436,3 +436,112 @@ class StreamStore(SQLBaseStore): internal = event.internal_metadata internal.before = str(RoomStreamToken(topo, stream - 1)) internal.after = str(RoomStreamToken(topo, stream)) + + @defer.inlineCallbacks + def get_events_around(self, room_id, event_id, before_limit, after_limit): + results = yield self.runInteraction( + "get_events_around", self._get_events_around_txn, + room_id, event_id, before_limit, after_limit + ) + + events_before = yield self._get_events( + [e for e in results["before"]["event_ids"]], + get_prev_content=True + ) + + events_after = yield self._get_events( + [e for e in results["after"]["event_ids"]], + get_prev_content=True + ) + + defer.returnValue({ + "events_before": events_before, + "events_after": events_after, + "start": results["before"]["token"], + "end": results["after"]["token"], + }) + + def _get_events_around_txn(self, txn, room_id, event_id, before_limit, after_limit): + results = self._simple_select_one_txn( + txn, + "events", + keyvalues={ + "event_id": event_id, + "room_id": room_id, + }, + retcols=["stream_ordering", "topological_ordering"], + ) + + stream_ordering = results["stream_ordering"] + topological_ordering = results["topological_ordering"] + + query_before = ( + "SELECT topological_ordering, stream_ordering, event_id FROM events" + " WHERE room_id = ? AND (topological_ordering < ?" + " OR (topological_ordering = ? AND stream_ordering < ?))" + " ORDER BY topological_ordering DESC, stream_ordering DESC" + " LIMIT ?" + ) + + query_after = ( + "SELECT topological_ordering, stream_ordering, event_id FROM events" + " WHERE room_id = ? AND (topological_ordering > ?" + " OR (topological_ordering = ? AND stream_ordering > ?))" + " ORDER BY topological_ordering ASC, stream_ordering ASC" + " LIMIT ?" 
+ ) + + txn.execute( + query_before, + ( + room_id, topological_ordering, topological_ordering, + stream_ordering, before_limit, + ) + ) + + rows = self.cursor_to_dict(txn) + events_before = [r["event_id"] for r in rows] + + if rows: + start_token = str(RoomStreamToken( + rows[0]["topological_ordering"], + rows[0]["stream_ordering"] - 1, + )) + else: + start_token = str(RoomStreamToken( + topological_ordering, + stream_ordering - 1, + )) + + txn.execute( + query_after, + ( + room_id, topological_ordering, topological_ordering, + stream_ordering, after_limit, + ) + ) + + rows = self.cursor_to_dict(txn) + events_after = [r["event_id"] for r in rows] + + if rows: + end_token = str(RoomStreamToken( + rows[-1]["topological_ordering"], + rows[-1]["stream_ordering"], + )) + else: + end_token = str(RoomStreamToken( + topological_ordering, + stream_ordering, + )) + + return { + "before": { + "event_ids": events_before, + "token": start_token, + }, + "after": { + "event_ids": events_after, + "token": end_token, + }, + } -- cgit 1.5.1 From 56dbcd1524d855e0bea2a77e371b1732f82a9492 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 28 Oct 2015 14:05:50 +0000 Subject: Docs --- synapse/handlers/room.py | 13 +++++++++++++ synapse/storage/stream.py | 26 ++++++++++++++++++++++++++ 2 files changed, 39 insertions(+) (limited to 'synapse/storage') diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index dd93a5d04d..36878a6c20 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -751,6 +751,19 @@ class RoomListHandler(BaseHandler): class RoomContextHandler(BaseHandler): @defer.inlineCallbacks def get_event_context(self, user, room_id, event_id, limit): + """Retrieves events, pagination tokens and state around a given event + in a room. + + Args: + user (UserID) + room_id (str) + event_id (str) + limit (int): The maximum number of events to return in total + (excluding state). + + Returns: + dict + """ before_limit = math.floor(limit/2.) after_limit = limit - before_limit diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index f2eecf52f9..15d4c2bf68 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -439,6 +439,19 @@ class StreamStore(SQLBaseStore): @defer.inlineCallbacks def get_events_around(self, room_id, event_id, before_limit, after_limit): + """Retrieve events and pagination tokens around a given event in a + room. + + Args: + room_id (str) + event_id (str) + before_limit (int) + after_limit (int) + + Returns: + dict + """ + results = yield self.runInteraction( "get_events_around", self._get_events_around_txn, room_id, event_id, before_limit, after_limit @@ -462,6 +475,19 @@ class StreamStore(SQLBaseStore): }) def _get_events_around_txn(self, txn, room_id, event_id, before_limit, after_limit): + """Retrieves event_ids and pagination tokens around a given event in a + room. 
+ + Args: + room_id (str) + event_id (str) + before_limit (int) + after_limit (int) + + Returns: + dict + """ + results = self._simple_select_one_txn( txn, "events", -- cgit 1.5.1 From 892e70ec8404865b4b1e6cf4eef7a3ada7114149 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Wed, 28 Oct 2015 16:06:57 +0000 Subject: Add APIs for adding and removing tags from rooms --- synapse/rest/client/v2_alpha/__init__.py | 2 + synapse/rest/client/v2_alpha/tags.py | 89 +++++++++++++++ synapse/storage/__init__.py | 2 + synapse/storage/schema/delta/25/tags.sql | 37 ++++++ synapse/storage/tags.py | 190 +++++++++++++++++++++++++++++++ 5 files changed, 320 insertions(+) create mode 100644 synapse/rest/client/v2_alpha/tags.py create mode 100644 synapse/storage/schema/delta/25/tags.sql create mode 100644 synapse/storage/tags.py (limited to 'synapse/storage') diff --git a/synapse/rest/client/v2_alpha/__init__.py b/synapse/rest/client/v2_alpha/__init__.py index 5831ff0e62..a108132346 100644 --- a/synapse/rest/client/v2_alpha/__init__.py +++ b/synapse/rest/client/v2_alpha/__init__.py @@ -22,6 +22,7 @@ from . import ( receipts, keys, tokenrefresh, + tags, ) from synapse.http.server import JsonResource @@ -44,3 +45,4 @@ class ClientV2AlphaRestResource(JsonResource): receipts.register_servlets(hs, client_resource) keys.register_servlets(hs, client_resource) tokenrefresh.register_servlets(hs, client_resource) + tags.register_servlets(hs, client_resource) diff --git a/synapse/rest/client/v2_alpha/tags.py b/synapse/rest/client/v2_alpha/tags.py new file mode 100644 index 0000000000..c4a670c5a7 --- /dev/null +++ b/synapse/rest/client/v2_alpha/tags.py @@ -0,0 +1,89 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+from ._base import client_v2_pattern
+
+from synapse.http.servlet import RestServlet
+
+from twisted.internet import defer
+
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class TagListServlet(RestServlet):
+    """
+    GET /user/{user_id}/rooms/{room_id}/tags HTTP/1.1
+    """
+    PATTERN = client_v2_pattern(
+        "/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags"
+    )
+
+    def __init__(self, hs):
+        super(TagListServlet, self).__init__()
+        self.auth = hs.get_auth()
+        self.store = hs.get_datastore()
+
+    @defer.inlineCallbacks
+    def on_GET(self, request, user_id, room_id):
+        auth_user, _ = yield self.auth.get_user_by_req(request)
+        if user_id != auth_user.to_string():
+            raise AuthError(403, "Cannot get tags for other users.")
+
+        tags = yield self.store.get_tags_for_room(user_id, room_id)
+
+        defer.returnValue((200, {"tags": tags}))
+
+
+class TagServlet(RestServlet):
+    """
+    PUT /user/{user_id}/rooms/{room_id}/tags/{tag} HTTP/1.1
+    DELETE /user/{user_id}/rooms/{room_id}/tags/{tag} HTTP/1.1
+    """
+    PATTERN = client_v2_pattern(
+        "/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags/(?P<tag>[^/]*)"
+    )
+    def __init__(self, hs):
+        super(TagServlet, self).__init__()
+        self.auth = hs.get_auth()
+        self.store = hs.get_datastore()
+
+    @defer.inlineCallbacks
+    def on_PUT(self, request, user_id, room_id, tag):
+        auth_user, _ = yield self.auth.get_user_by_req(request)
+        if user_id != auth_user.to_string():
+            raise AuthError(403, "Cannot add tags for other users.")
+
+        yield self.store.add_tag_to_room(user_id, room_id, tag)
+
+        # TODO: poke the notifier.
+        defer.returnValue((200, {}))
+
+    @defer.inlineCallbacks
+    def on_DELETE(self, request, user_id, room_id, tag):
+        auth_user, _ = yield self.auth.get_user_by_req(request)
+        if user_id != auth_user.to_string():
+            raise AuthError(403, "Cannot add tags for other users.")
+
+        yield self.store.remove_tag_from_room(user_id, room_id, tag)
+
+        # TODO: poke the notifier.
+        defer.returnValue((200, {}))
+
+
+def register_servlets(hs, http_server):
+    TagListServlet(hs).register(http_server)
+    TagServlet(hs).register(http_server)
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index a1bd9c4ce9..e7443f2838 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -41,6 +41,7 @@ from .end_to_end_keys import EndToEndKeyStore
 from .receipts import ReceiptsStore
 from .search import SearchStore
+from .tags import TagsStore
 
 import logging
@@ -71,6 +72,7 @@ class DataStore(RoomMemberStore, RoomStore,
     ReceiptsStore,
     EndToEndKeyStore,
     SearchStore,
+    TagsStore,
 ):
 
     def __init__(self, hs):
diff --git a/synapse/storage/schema/delta/25/tags.sql b/synapse/storage/schema/delta/25/tags.sql
new file mode 100644
index 0000000000..168766dcf3
--- /dev/null
+++ b/synapse/storage/schema/delta/25/tags.sql
@@ -0,0 +1,37 @@
+/* Copyright 2015 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+CREATE TABLE IF NOT EXISTS room_tags(
+    user_id TEXT NOT NULL,
+    room_id TEXT NOT NULL,
+    tag TEXT NOT NULL, -- The name of the tag.
+ CONSTRAINT room_tag_uniqueness UNIQUE (user_id, room_id, tag) +); + +CREATE TABLE IF NOT EXISTS room_tags_revisions ( + user_id TEXT NOT NULL, + room_id TEXT NOT NULL, + stream_id BIGINT NOT NULL, -- The current version of the room tags. + CONSTRAINT room_tag_revisions_uniqueness UNIQUE (user_id, room_id) +); + +CREATE TABLE IF NOT EXISTS private_user_data_max_stream_id( + Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, -- Makes sure this table only has one row. + stream_id BIGINT NOT NULL, + CHECK (Lock='X') +); + +INSERT INTO private_user_data_max_stream_id (stream_id) VALUES (0); diff --git a/synapse/storage/tags.py b/synapse/storage/tags.py new file mode 100644 index 0000000000..507e60596f --- /dev/null +++ b/synapse/storage/tags.py @@ -0,0 +1,190 @@ +# -*- coding: utf-8 -*- +# Copyright 2014, 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ._base import SQLBaseStore +from synapse.util.caches.descriptors import cached +from twisted.internet import defer +from .util.id_generators import StreamIdGenerator + +import logging + +logger = logging.getLogger(__name__) + + +class TagsStore(SQLBaseStore): + def __init__(self, hs): + super(TagsStore, self).__init__(hs) + + self._private_user_data_id_gen = StreamIdGenerator( + "private_user_data_max_stream_id", "stream_id" + ) + + @cached() + def get_tags_for_user(self, user_id): + """Get all the tags for a user. + + + Args: + user_id(str): The user to get the tags for. + Returns: + A deferred dict mapping from room_id strings to lists of tag + strings. + """ + + deferred = self._simple_select_list( + "room_tags", {"user_id": user_id}, ["room_id", "tag"] + ) + + @deferred.addCallback + def tags_by_room(rows): + tags_by_room = {} + for row in rows: + tags_by_room.setdefault(row["room_id"], []).append(row["tag"]) + return tags_by_room + + return deferred + + @defer.inlineCallbacks + def get_updated_tags(self, user_id, stream_id): + """Get all the tags for the rooms where the tags have changed since the + given version + + Args: + user_id(str): The user to get the tags for. + stream_id(int): The earliest update to get for the user. + Returns: + A deferred dict mapping from room_id strings to lists of tag + strings for all the rooms that changed since the stream_id token. + """ + def get_updated_tags_txn(txn): + sql = ( + "SELECT room_id from room_tags_revisions" + " WHERE user_id = ? AND stream_id > ?" + ) + txn.execute(sql, (user_id, stream_id)) + room_ids = [row[0] for row in txn.fetchall()] + return room_ids + + room_ids = yield self.runInteraction( + "get_updated_tags", get_updated_tags_txn + ) + + results = {} + if room_ids: + tags_by_room = yield self.get_tags_for_user(self, user_id) + for room_id in rooms_ids: + results[room_id] = tags_by_room[room_id] + + defer.returnValue(results) + + def get_tags_for_room(self, user_id, room_id): + """Get all the tags for the given room + Args: + user_id(str): The user to get tags for + room_id(str): The room to get tags for + Returns: + A deferred list of string tags. 
+ """ + return self._simple_select_onecol( + table="room_tags", + keyvalues={"user_id": user_id, "room_id": room_id}, + retcol="tag", + desc="get_tags_for_room", + ) + + @defer.inlineCallbacks + def add_tag_to_room(self, user_id, room_id, tag): + """Add a tag to a room for a user. + Returns: + A deferred that completes once the tag has been added. + """ + def add_tag_txn(txn, next_id): + sql = ( + "INSERT INTO room_tags (user_id, room_id, tag)" + " VALUES (?, ?, ?)" + ) + try: + txn.execute(sql, (user_id, room_id, tag)) + except database_engine.module.IntegrityError as e: + # Return early if the row is already in the table + # and we don't need to bump the revision number of the + # private_user_data. + return + self._update_revision_txn(txn, user_id, room_id, next_id) + + with (yield self._private_user_data_id_gen.get_next(self)) as next_id: + yield self.runInteraction("add_tag", add_tag_txn, next_id) + + self.get_tags_for_user.invalidate((user_id,)) + + @defer.inlineCallbacks + def remove_tag_from_room(self, user_id, room_id, tag): + """Remove a tag from a room for a user. + Returns: + A deffered that completes once the tag has been removed + """ + def remove_tag_txn(txn, next_id): + sql = ( + "DELETE FROM room_tags " + " WHERE user_id = ? AND room_id = ? AND tag = ?" + ) + txn.execute(sql, (user_id, room_id, tag)) + self._update_revision_txn(txn, user_id, room_id, next_id) + + with (yield self._private_user_data_id_gen.get_next(self)) as next_id: + yield self.runInteraction("remove_tag", remove_tag_txn, next_id) + + self.get_tags_for_user.invalidate((user_id,)) + + def _update_revision_txn(self, txn, user_id, room_id, next_id): + """Update the latest revision of the tags for the given user and room. + + Args: + txn: The database cursor + user_id(str): The ID of the user. + room_id(str): The ID of the room. + next_id(int): The the revision to advance to. + """ + + update_max_id_sql = ( + "UPDATE private_user_data_max_stream_id" + " SET stream_id = ?" + " WHERE stream_id < ?" + ) + txn.execute(update_max_id_sql, (next_id, next_id)) + + update_sql = ( + "UPDATE room_tags_revisions" + " SET stream_id = ?" + " WHERE user_id = ?" + " AND room_id = ?" + ) + txn.execute(update_sql, (next_id, user_id, room_id)) + + if txn.rowcount == 0: + insert_sql = ( + "INSERT INTO room_tags_revisions (user_id, room_id, stream_id)" + " VALUES (?, ?, ?)" + ) + try: + txn.execute(insert_sql, (user_id, room_id, next_id)) + except database_engine.module.IntegrityError as e: + # Ignore insertion errors. It doesn't matter if the row wasn't + # inserted because if two updates happend concurrently the one + # with the higher stream_id will not be reported to a client + # unless the previous update has completed. It doesn't matter + # which stream_id ends up in the table, as long as it is higher + # than the id that the client has. 
+                pass
-- cgit 1.5.1


From a89b86dc473f5dc41a17207a17165b27d8f599ea Mon Sep 17 00:00:00 2001
From: Mark Haines
Date: Wed, 28 Oct 2015 16:45:57 +0000
Subject: Fix pyflakes errors

---
 synapse/rest/client/v2_alpha/tags.py | 2 ++
 synapse/storage/tags.py              | 6 +++---
 2 files changed, 5 insertions(+), 3 deletions(-)

(limited to 'synapse/storage')

diff --git a/synapse/rest/client/v2_alpha/tags.py b/synapse/rest/client/v2_alpha/tags.py
index c4a670c5a7..15c9347fc1 100644
--- a/synapse/rest/client/v2_alpha/tags.py
+++ b/synapse/rest/client/v2_alpha/tags.py
@@ -16,6 +16,7 @@ from ._base import client_v2_pattern
 
 from synapse.http.servlet import RestServlet
+from synapse.api.errors import AuthError
 
 from twisted.internet import defer
 
@@ -56,6 +57,7 @@ class TagServlet(RestServlet):
     PATTERN = client_v2_pattern(
         "/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags/(?P<tag>[^/]*)"
     )
+
     def __init__(self, hs):
         super(TagServlet, self).__init__()
         self.auth = hs.get_auth()
diff --git a/synapse/storage/tags.py b/synapse/storage/tags.py
index 507e60596f..64c65fc321 100644
--- a/synapse/storage/tags.py
+++ b/synapse/storage/tags.py
@@ -84,7 +84,7 @@ class TagsStore(SQLBaseStore):
         results = {}
         if room_ids:
             tags_by_room = yield self.get_tags_for_user(self, user_id)
-            for room_id in rooms_ids:
+            for room_id in room_ids:
                 results[room_id] = tags_by_room[room_id]
 
         defer.returnValue(results)
@@ -117,7 +117,7 @@ class TagsStore(SQLBaseStore):
             )
             try:
                 txn.execute(sql, (user_id, room_id, tag))
-            except database_engine.module.IntegrityError as e:
+            except self.database_engine.module.IntegrityError:
                 # Return early if the row is already in the table
                 # and we don't need to bump the revision number of the
                 # private_user_data.
@@ -180,7 +180,7 @@ class TagsStore(SQLBaseStore):
             )
             try:
                 txn.execute(insert_sql, (user_id, room_id, next_id))
-            except database_engine.module.IntegrityError as e:
+            except self.database_engine.module.IntegrityError:
                 # Ignore insertion errors. It doesn't matter if the row wasn't
                 # inserted because if two updates happend concurrently the one
                 # with the higher stream_id will not be reported to a client
-- cgit 1.5.1


From f40b0ed5e190a78ed6633505c4f437b6fddc41ee Mon Sep 17 00:00:00 2001
From: Mark Haines
Date: Thu, 29 Oct 2015 15:20:52 +0000
Subject: Inform the client of new room tags using v1 /events

---
 synapse/handlers/private_user_data.py | 46 +++++++++++++++++++++++++++++++++++
 synapse/notifier.py                   |  2 +-
 synapse/rest/client/v2_alpha/tags.py  | 14 ++++++++---
 synapse/storage/tags.py               | 16 +++++++++++-
 synapse/streams/events.py             |  5 ++++
 synapse/types.py                      | 22 ++++++++++-------
 6 files changed, 91 insertions(+), 14 deletions(-)
 create mode 100644 synapse/handlers/private_user_data.py

(limited to 'synapse/storage')

diff --git a/synapse/handlers/private_user_data.py b/synapse/handlers/private_user_data.py
new file mode 100644
index 0000000000..1778c71325
--- /dev/null
+++ b/synapse/handlers/private_user_data.py
@@ -0,0 +1,46 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +from twisted.internet import defer + + +class PrivateUserDataEventSource(object): + def __init__(self, hs): + self.store = hs.get_datastore() + + def get_current_key(self, direction='f'): + return self.store.get_max_private_user_data_stream_id() + + @defer.inlineCallbacks + def get_new_events_for_user(self, user, from_key, limit): + user_id = user.to_string() + last_stream_id = from_key + + current_stream_id = yield self.store.get_max_private_user_data_stream_id() + tags = yield self.store.get_updated_tags(user_id, last_stream_id) + + results = [] + for room_id, room_tags in tags.items(): + results.append({ + "type": "m.tag", + "content": {"tags": room_tags}, + "room_id": room_id, + }) + + defer.returnValue((results, current_stream_id)) + + @defer.inlineCallbacks + def get_pagination_rows(self, user, config, key): + defer.returnValue(([], config.to_id)) diff --git a/synapse/notifier.py b/synapse/notifier.py index f998fc83bf..a78ee3c1e7 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -270,7 +270,7 @@ class Notifier(object): @defer.inlineCallbacks def wait_for_events(self, user, rooms, timeout, callback, - from_token=StreamToken("s0", "0", "0", "0")): + from_token=StreamToken("s0", "0", "0", "0", "0")): """Wait until the callback returns a non empty response or the timeout fires. """ diff --git a/synapse/rest/client/v2_alpha/tags.py b/synapse/rest/client/v2_alpha/tags.py index 15c9347fc1..486add9909 100644 --- a/synapse/rest/client/v2_alpha/tags.py +++ b/synapse/rest/client/v2_alpha/tags.py @@ -62,6 +62,7 @@ class TagServlet(RestServlet): super(TagServlet, self).__init__() self.auth = hs.get_auth() self.store = hs.get_datastore() + self.notifier = hs.get_notifier() @defer.inlineCallbacks def on_PUT(self, request, user_id, room_id, tag): @@ -69,9 +70,12 @@ class TagServlet(RestServlet): if user_id != auth_user.to_string(): raise AuthError(403, "Cannot add tags for other users.") - yield self.store.add_tag_to_room(user_id, room_id, tag) + max_id = yield self.store.add_tag_to_room(user_id, room_id, tag) + + yield self.notifier.on_new_event( + "private_user_data_key", max_id, users=[user_id] + ) - # TODO: poke the notifier. defer.returnValue((200, {})) @defer.inlineCallbacks @@ -80,7 +84,11 @@ class TagServlet(RestServlet): if user_id != auth_user.to_string(): raise AuthError(403, "Cannot add tags for other users.") - yield self.store.remove_tag_from_room(user_id, room_id, tag) + max_id = yield self.store.remove_tag_from_room(user_id, room_id, tag) + + yield self.notifier.on_new_event( + "private_user_data_key", max_id, users=[user_id] + ) # TODO: poke the notifier. defer.returnValue((200, {})) diff --git a/synapse/storage/tags.py b/synapse/storage/tags.py index 64c65fc321..2d5c49144a 100644 --- a/synapse/storage/tags.py +++ b/synapse/storage/tags.py @@ -31,6 +31,14 @@ class TagsStore(SQLBaseStore): "private_user_data_max_stream_id", "stream_id" ) + def get_max_private_user_data_stream_id(self): + """Get the current max stream id for the private user data stream + + Returns: + A deferred int. + """ + return self._private_user_data_id_gen.get_max_token(self) + @cached() def get_tags_for_user(self, user_id): """Get all the tags for a user. 
@@ -83,7 +91,7 @@ class TagsStore(SQLBaseStore): results = {} if room_ids: - tags_by_room = yield self.get_tags_for_user(self, user_id) + tags_by_room = yield self.get_tags_for_user(user_id) for room_id in room_ids: results[room_id] = tags_by_room[room_id] @@ -129,6 +137,9 @@ class TagsStore(SQLBaseStore): self.get_tags_for_user.invalidate((user_id,)) + result = yield self._private_user_data_id_gen.get_max_token(self) + defer.returnValue(result) + @defer.inlineCallbacks def remove_tag_from_room(self, user_id, room_id, tag): """Remove a tag from a room for a user. @@ -148,6 +159,9 @@ class TagsStore(SQLBaseStore): self.get_tags_for_user.invalidate((user_id,)) + result = yield self._private_user_data_id_gen.get_max_token(self) + defer.returnValue(result) + def _update_revision_txn(self, txn, user_id, room_id, next_id): """Update the latest revision of the tags for the given user and room. diff --git a/synapse/streams/events.py b/synapse/streams/events.py index 699083ae12..f0d68b5bf2 100644 --- a/synapse/streams/events.py +++ b/synapse/streams/events.py @@ -21,6 +21,7 @@ from synapse.handlers.presence import PresenceEventSource from synapse.handlers.room import RoomEventSource from synapse.handlers.typing import TypingNotificationEventSource from synapse.handlers.receipts import ReceiptEventSource +from synapse.handlers.private_user_data import PrivateUserDataEventSource class EventSources(object): @@ -29,6 +30,7 @@ class EventSources(object): "presence": PresenceEventSource, "typing": TypingNotificationEventSource, "receipt": ReceiptEventSource, + "private_user_data": PrivateUserDataEventSource, } def __init__(self, hs): @@ -52,5 +54,8 @@ class EventSources(object): receipt_key=( yield self.sources["receipt"].get_current_key() ), + private_user_data_key=( + yield self.sources["private_user_data"].get_current_key() + ), ) defer.returnValue(token) diff --git a/synapse/types.py b/synapse/types.py index 9cffc33d27..8d3a8d88cc 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -98,10 +98,13 @@ class EventID(DomainSpecificString): class StreamToken( - namedtuple( - "Token", - ("room_key", "presence_key", "typing_key", "receipt_key") - ) + namedtuple("Token", ( + "room_key", + "presence_key", + "typing_key", + "receipt_key", + "private_user_data_key", + )) ): _SEPARATOR = "_" @@ -128,13 +131,14 @@ class StreamToken( else: return int(self.room_key[1:].split("-")[-1]) - def is_after(self, other_token): + def is_after(self, other): """Does this token contain events that the other doesn't?""" return ( - (other_token.room_stream_id < self.room_stream_id) - or (int(other_token.presence_key) < int(self.presence_key)) - or (int(other_token.typing_key) < int(self.typing_key)) - or (int(other_token.receipt_key) < int(self.receipt_key)) + (other.room_stream_id < self.room_stream_id) + or (int(other.presence_key) < int(self.presence_key)) + or (int(other.typing_key) < int(self.typing_key)) + or (int(other.receipt_key) < int(self.receipt_key)) + or (int(other.private_user_data_key) < int(self.private_user_data_key)) ) def copy_and_advance(self, key, new_value): -- cgit 1.5.1 From 621e84d9a0f06f65bd51171079fd18eb916ab81a Mon Sep 17 00:00:00 2001 From: Daniel Wagner-Hall Date: Fri, 30 Oct 2015 16:25:53 +0000 Subject: Add missing column --- synapse/storage/schema/delta/25/fts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'synapse/storage') diff --git a/synapse/storage/schema/delta/25/fts.py b/synapse/storage/schema/delta/25/fts.py index 056487af30..b7cd0ce3b8 100644 --- 
a/synapse/storage/schema/delta/25/fts.py +++ b/synapse/storage/schema/delta/25/fts.py @@ -55,7 +55,7 @@ CREATE INDEX event_search_ev_ridx ON event_search(room_id); SQLITE_TABLE = ( "CREATE VIRTUAL TABLE IF NOT EXISTS event_search" - " USING fts4 ( event_id, room_id, key, value )" + " USING fts4 ( event_id, room_id, sender, key, value )" ) -- cgit 1.5.1 From ddd8566f415ab3a6092aa4947e5d2aebf67109fc Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Mon, 2 Nov 2015 15:11:31 +0000 Subject: Store room tag content and return the content in the m.tag event --- synapse/handlers/message.py | 6 ++--- synapse/rest/client/v2_alpha/tags.py | 12 +++++++-- synapse/storage/schema/delta/25/tags.sql | 1 + synapse/storage/tags.py | 44 ++++++++++++++++++++------------ 4 files changed, 41 insertions(+), 22 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 8f156e5c84..0f947993d1 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -405,7 +405,6 @@ class MessageHandler(BaseHandler): tags = tags_by_room.get(event.room_id) if tags: private_user_data.append({ - "room_id": event.room_id, "type": "m.tag", "content": {"tags": tags}, }) @@ -466,7 +465,6 @@ class MessageHandler(BaseHandler): private_user_data.append({ "type": "m.tag", "content": {"tags": tags}, - "room_id": room_id, }) result["private_user_data"] = private_user_data @@ -499,8 +497,8 @@ class MessageHandler(BaseHandler): user_id, messages ) - start_token = StreamToken(token[0], 0, 0, 0) - end_token = StreamToken(token[1], 0, 0, 0) + start_token = StreamToken(token[0], 0, 0, 0, 0) + end_token = StreamToken(token[1], 0, 0, 0, 0) time_now = self.clock.time_msec() diff --git a/synapse/rest/client/v2_alpha/tags.py b/synapse/rest/client/v2_alpha/tags.py index 486add9909..4e3f917fc5 100644 --- a/synapse/rest/client/v2_alpha/tags.py +++ b/synapse/rest/client/v2_alpha/tags.py @@ -16,12 +16,14 @@ from ._base import client_v2_pattern from synapse.http.servlet import RestServlet -from synapse.api.errors import AuthError +from synapse.api.errors import AuthError, SynapseError from twisted.internet import defer import logging +import simplejson as json + logger = logging.getLogger(__name__) @@ -70,7 +72,13 @@ class TagServlet(RestServlet): if user_id != auth_user.to_string(): raise AuthError(403, "Cannot add tags for other users.") - max_id = yield self.store.add_tag_to_room(user_id, room_id, tag) + try: + content_bytes = request.content.read() + body = json.loads(content_bytes) + except: + raise SynapseError(400, "Invalid tag JSON") + + max_id = yield self.store.add_tag_to_room(user_id, room_id, tag, body) yield self.notifier.on_new_event( "private_user_data_key", max_id, users=[user_id] diff --git a/synapse/storage/schema/delta/25/tags.sql b/synapse/storage/schema/delta/25/tags.sql index 168766dcf3..527424c998 100644 --- a/synapse/storage/schema/delta/25/tags.sql +++ b/synapse/storage/schema/delta/25/tags.sql @@ -18,6 +18,7 @@ CREATE TABLE IF NOT EXISTS room_tags( user_id TEXT NOT NULL, room_id TEXT NOT NULL, tag TEXT NOT NULL, -- The name of the tag. + content TEXT NOT NULL, -- The JSON content of the tag. 
CONSTRAINT room_tag_uniqueness UNIQUE (user_id, room_id, tag) ); diff --git a/synapse/storage/tags.py b/synapse/storage/tags.py index 2d5c49144a..34aa38c06a 100644 --- a/synapse/storage/tags.py +++ b/synapse/storage/tags.py @@ -18,6 +18,7 @@ from synapse.util.caches.descriptors import cached from twisted.internet import defer from .util.id_generators import StreamIdGenerator +import ujson as json import logging logger = logging.getLogger(__name__) @@ -52,14 +53,15 @@ class TagsStore(SQLBaseStore): """ deferred = self._simple_select_list( - "room_tags", {"user_id": user_id}, ["room_id", "tag"] + "room_tags", {"user_id": user_id}, ["room_id", "tag", "content"] ) @deferred.addCallback def tags_by_room(rows): tags_by_room = {} for row in rows: - tags_by_room.setdefault(row["room_id"], []).append(row["tag"]) + room_tags = tags_by_room.setdefault(row["room_id"], {}) + room_tags[row["tag"]] = json.loads(row["content"]) return tags_by_room return deferred @@ -105,31 +107,41 @@ class TagsStore(SQLBaseStore): Returns: A deferred list of string tags. """ - return self._simple_select_onecol( + return self._simple_select_list( table="room_tags", keyvalues={"user_id": user_id, "room_id": room_id}, - retcol="tag", + retcols=("tag", "content"), desc="get_tags_for_room", - ) + ).addCallback(lambda rows: { + row["tag"]: json.loads(row["content"]) for row in rows + }) @defer.inlineCallbacks - def add_tag_to_room(self, user_id, room_id, tag): + def add_tag_to_room(self, user_id, room_id, tag, content): """Add a tag to a room for a user. + Args: + user_id(str): The user to add a tag for. + room_id(str): The room to add a tag for. + tag(str): The tag name to add. + content(dict): A json object to associate with the tag. Returns: A deferred that completes once the tag has been added. """ + content_json = json.dumps(content) + def add_tag_txn(txn, next_id): - sql = ( - "INSERT INTO room_tags (user_id, room_id, tag)" - " VALUES (?, ?, ?)" + self._simple_upsert_txn( + txn, + table="room_tags", + keyvalues={ + "user_id": user_id, + "room_id": room_id, + "tag": tag, + }, + values={ + "content": content_json, + } ) - try: - txn.execute(sql, (user_id, room_id, tag)) - except self.database_engine.module.IntegrityError: - # Return early if the row is already in the table - # and we don't need to bump the revision number of the - # private_user_data. 
- return self._update_revision_txn(txn, user_id, room_id, next_id) with (yield self._private_user_data_id_gen.get_next(self)) as next_id: -- cgit 1.5.1 From 771ca56c886dd08f707447cfff70acd3ba73e98c Mon Sep 17 00:00:00 2001 From: Daniel Wagner-Hall Date: Mon, 2 Nov 2015 15:31:57 +0000 Subject: Remove more unused parameters --- synapse/handlers/room.py | 1 - synapse/handlers/sync.py | 1 - synapse/storage/stream.py | 3 +-- tests/storage/test_redaction.py | 4 ---- tests/storage/test_stream.py | 4 ---- tests/utils.py | 2 +- 6 files changed, 2 insertions(+), 13 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 36878a6c20..9184dcd048 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -827,7 +827,6 @@ class RoomEventSource(object): user_id=user.to_string(), from_key=from_key, to_key=to_key, - room_id=None, limit=limit, ) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index eaa14f38df..4054efe555 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -342,7 +342,6 @@ class SyncHandler(BaseHandler): sync_config.user.to_string(), from_key=since_token.room_key, to_key=now_token.room_key, - room_id=None, limit=timeline_limit + 1, ) diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index 15d4c2bf68..c728013f4c 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -158,8 +158,7 @@ class StreamStore(SQLBaseStore): defer.returnValue(results) @log_function - def get_room_events_stream(self, user_id, from_key, to_key, room_id, - limit=0): + def get_room_events_stream(self, user_id, from_key, to_key, limit=0): current_room_membership_sql = ( "SELECT m.room_id FROM room_memberships as m " " INNER JOIN current_state_events as c" diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py index b57006fcb4..dbf9700e6a 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py @@ -120,7 +120,6 @@ class RedactionTestCase(unittest.TestCase): self.u_alice.to_string(), start, end, - None, # Is currently ignored ) self.assertEqual(1, len(results)) @@ -149,7 +148,6 @@ class RedactionTestCase(unittest.TestCase): self.u_alice.to_string(), start, end, - None, # Is currently ignored ) self.assertEqual(1, len(results)) @@ -199,7 +197,6 @@ class RedactionTestCase(unittest.TestCase): self.u_alice.to_string(), start, end, - None, # Is currently ignored ) self.assertEqual(1, len(results)) @@ -228,7 +225,6 @@ class RedactionTestCase(unittest.TestCase): self.u_alice.to_string(), start, end, - None, # Is currently ignored ) self.assertEqual(1, len(results)) diff --git a/tests/storage/test_stream.py b/tests/storage/test_stream.py index a658a789aa..e5c2c5cc8e 100644 --- a/tests/storage/test_stream.py +++ b/tests/storage/test_stream.py @@ -68,7 +68,6 @@ class StreamStoreTestCase(unittest.TestCase): self.u_bob.to_string(), start, end, - None, # Is currently ignored ) self.assertEqual(1, len(results)) @@ -105,7 +104,6 @@ class StreamStoreTestCase(unittest.TestCase): self.u_alice.to_string(), start, end, - None, # Is currently ignored ) self.assertEqual(1, len(results)) @@ -147,7 +145,6 @@ class StreamStoreTestCase(unittest.TestCase): self.u_bob.to_string(), start, end, - None, # Is currently ignored ) # We should not get the message, as it happened *after* bob left. 
@@ -175,7 +172,6 @@ class StreamStoreTestCase(unittest.TestCase): self.u_bob.to_string(), start, end, - None, # Is currently ignored ) # We should not get the message, as it happened *after* bob left. diff --git a/tests/utils.py b/tests/utils.py index 4da51291a4..ca2c33cf8e 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -335,7 +335,7 @@ class MemoryDataStore(object): ] def get_room_events_stream(self, user_id=None, from_key=None, to_key=None, - room_id=None, limit=0, with_feedback=False): + limit=0, with_feedback=False): return ([], from_key) # TODO def get_joined_hosts_for_room(self, room_id): -- cgit 1.5.1 From 5897e773fd50733fa448b36275130e5441d36f31 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Tue, 3 Nov 2015 14:27:35 +0000 Subject: Spell "deferred" more correctly --- synapse/storage/tags.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'synapse/storage') diff --git a/synapse/storage/tags.py b/synapse/storage/tags.py index 34aa38c06a..641ea250f0 100644 --- a/synapse/storage/tags.py +++ b/synapse/storage/tags.py @@ -156,7 +156,7 @@ class TagsStore(SQLBaseStore): def remove_tag_from_room(self, user_id, room_id, tag): """Remove a tag from a room for a user. Returns: - A deffered that completes once the tag has been removed + A deferred that completes once the tag has been removed """ def remove_tag_txn(txn, next_id): sql = ( -- cgit 1.5.1 From 7ce264ce5f1b2409081446bd8e1a4adc63675e06 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 3 Nov 2015 16:23:35 +0000 Subject: Fix broken cache for getting retry times. This meant we retried remote destinations way more frequently than we should --- synapse/federation/transaction_queue.py | 47 ++++++++++++++++---------------- synapse/storage/transactions.py | 48 +++++++++++---------------------- 2 files changed, 40 insertions(+), 55 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py index 32fa5e8c15..2a7dd343f3 100644 --- a/synapse/federation/transaction_queue.py +++ b/synapse/federation/transaction_queue.py @@ -202,19 +202,6 @@ class TransactionQueue(object): @defer.inlineCallbacks @log_function def _attempt_new_transaction(self, destination): - if destination in self.pending_transactions: - # XXX: pending_transactions can get stuck on by a never-ending - # request at which point pending_pdus_by_dest just keeps growing. - # we need application-layer timeouts of some flavour of these - # requests - logger.debug( - "TX [%s] Transaction already in progress", - destination - ) - return - - logger.debug("TX [%s] _attempt_new_transaction", destination) - # list of (pending_pdu, deferred, order) pending_pdus = self.pending_pdus_by_dest.pop(destination, []) pending_edus = self.pending_edus_by_dest.pop(destination, []) @@ -228,20 +215,34 @@ class TransactionQueue(object): logger.debug("TX [%s] Nothing to send", destination) return - # Sort based on the order field - pending_pdus.sort(key=lambda t: t[2]) - - pdus = [x[0] for x in pending_pdus] - edus = [x[0] for x in pending_edus] - failures = [x[0].get_dict() for x in pending_failures] - deferreds = [ - x[1] - for x in pending_pdus + pending_edus + pending_failures - ] + if destination in self.pending_transactions: + # XXX: pending_transactions can get stuck on by a never-ending + # request at which point pending_pdus_by_dest just keeps growing. 
+ # we need application-layer timeouts of some flavour of these + # requests + logger.debug( + "TX [%s] Transaction already in progress", + destination + ) + return + # NOTE: Nothing should be between the above check and the insertion below try: self.pending_transactions[destination] = 1 + logger.debug("TX [%s] _attempt_new_transaction", destination) + + # Sort based on the order field + pending_pdus.sort(key=lambda t: t[2]) + + pdus = [x[0] for x in pending_pdus] + edus = [x[0] for x in pending_edus] + failures = [x[0].get_dict() for x in pending_failures] + deferreds = [ + x[1] + for x in pending_pdus + pending_edus + pending_failures + ] + txn_id = str(self._next_txn_id) limiter = yield get_retry_limiter( diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py index 15695e9831..4e0d7c9774 100644 --- a/synapse/storage/transactions.py +++ b/synapse/storage/transactions.py @@ -253,16 +253,6 @@ class TransactionStore(SQLBaseStore): retry_interval (int) - how long until next retry in ms """ - # As this is the new value, we might as well prefill the cache - self.get_destination_retry_timings.prefill( - destination, - { - "destination": destination, - "retry_last_ts": retry_last_ts, - "retry_interval": retry_interval - }, - ) - # XXX: we could chose to not bother persisting this if our cache thinks # this is a NOOP return self.runInteraction( @@ -275,31 +265,25 @@ class TransactionStore(SQLBaseStore): def _set_destination_retry_timings(self, txn, destination, retry_last_ts, retry_interval): - query = ( - "UPDATE destinations" - " SET retry_last_ts = ?, retry_interval = ?" - " WHERE destination = ?" - ) + txn.call_after(self.get_destination_retry_timings.invalidate, (destination,)) - txn.execute( - query, - ( - retry_last_ts, retry_interval, destination, - ) + self._simple_upsert_txn( + txn, + "destinations", + keyvalues={ + "destination": destination, + }, + values={ + "retry_last_ts": retry_last_ts, + "retry_interval": retry_interval, + }, + insertion_values={ + "destination": destination, + "retry_last_ts": retry_last_ts, + "retry_interval": retry_interval, + } ) - if txn.rowcount == 0: - # destination wasn't already in table. Insert it. - self._simple_insert_txn( - txn, - table="destinations", - values={ - "destination": destination, - "retry_last_ts": retry_last_ts, - "retry_interval": retry_interval, - } - ) - def get_destinations_needing_retry(self): """Get all destinations which are due a retry for sending a transaction. -- cgit 1.5.1 From f522f50a08d48042d103c98dbc3cfd4872b7d981 Mon Sep 17 00:00:00 2001 From: Daniel Wagner-Hall Date: Wed, 4 Nov 2015 17:29:07 +0000 Subject: Allow guests to register and call /events?room_id= This follows the same flows-based flow as regular registration, but as the only implemented flow has no requirements, it auto-succeeds. In the future, other flows (e.g. captcha) may be required, so clients should treat this like the regular registration flow choices. 
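As a rough sketch of the token convention this builds on (illustrative only,
not part of the patch; it assumes pymacaroons' Macaroon.deserialize API and
mirrors the caveat walk added to synapse/api/auth.py below):

    import pymacaroons

    def parse_access_token(serialized):
        # Walk the first-party caveats the way _get_user_by_access_token
        # now does: a "user_id = ..." caveat names the user, and a literal
        # "guest = true" caveat marks the token as a guest token.
        macaroon = pymacaroons.Macaroon.deserialize(serialized)
        user_id, is_guest = None, False
        for caveat in macaroon.caveats:
            if caveat.caveat_id.startswith("user_id = "):
                user_id = caveat.caveat_id[len("user_id = "):]
            elif caveat.caveat_id == "guest = true":
                is_guest = True
        return user_id, is_guest

Guest tokens are then rejected with a 403 (Codes.GUEST_ACCESS_FORBIDDEN) on
any servlet that does not pass allow_guest=True to get_user_by_req.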
--- synapse/api/auth.py | 95 ++++++++++++++++------------- synapse/api/errors.py | 1 + synapse/config/registration.py | 6 ++ synapse/handlers/_base.py | 75 ++++++++++++++--------- synapse/handlers/auth.py | 5 +- synapse/handlers/message.py | 46 +++++++------- synapse/handlers/register.py | 12 ++-- synapse/rest/client/v1/admin.py | 2 +- synapse/rest/client/v1/directory.py | 4 +- synapse/rest/client/v1/events.py | 4 +- synapse/rest/client/v1/initial_sync.py | 2 +- synapse/rest/client/v1/presence.py | 8 +-- synapse/rest/client/v1/profile.py | 4 +- synapse/rest/client/v1/push_rule.py | 6 +- synapse/rest/client/v1/pusher.py | 2 +- synapse/rest/client/v1/room.py | 27 ++++---- synapse/rest/client/v1/voip.py | 2 +- synapse/rest/client/v2_alpha/account.py | 6 +- synapse/rest/client/v2_alpha/filter.py | 4 +- synapse/rest/client/v2_alpha/keys.py | 6 +- synapse/rest/client/v2_alpha/receipts.py | 2 +- synapse/rest/client/v2_alpha/register.py | 27 +++++++- synapse/rest/client/v2_alpha/sync.py | 2 +- synapse/rest/client/v2_alpha/tags.py | 6 +- synapse/rest/media/v0/content_repository.py | 2 +- synapse/rest/media/v1/upload_resource.py | 2 +- synapse/storage/registration.py | 15 ++--- tests/api/test_auth.py | 25 +++++++- tests/rest/client/v1/test_presence.py | 10 +-- tests/rest/client/v1/test_profile.py | 4 +- tests/rest/client/v1/test_rooms.py | 21 ++++--- tests/rest/client/v1/test_typing.py | 3 +- tests/rest/client/v2_alpha/__init__.py | 3 +- 33 files changed, 272 insertions(+), 167 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 88445fe999..dfbbc5a1cd 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -49,6 +49,7 @@ class Auth(object): self.TOKEN_NOT_FOUND_HTTP_STATUS = 401 self._KNOWN_CAVEAT_PREFIXES = set([ "gen = ", + "guest = ", "type = ", "time < ", "user_id = ", @@ -183,15 +184,11 @@ class Auth(object): defer.returnValue(member) @defer.inlineCallbacks - def check_user_was_in_room(self, room_id, user_id, current_state=None): + def check_user_was_in_room(self, room_id, user_id): """Check if the user was in the room at some point. Args: room_id(str): The room to check. user_id(str): The user to check. - current_state(dict): Optional map of the current state of the room. - If provided then that map is used to check whether they are a - member of the room. Otherwise the current membership is - loaded from the database. Raises: AuthError if the user was never in the room. Returns: @@ -199,17 +196,11 @@ class Auth(object): room. This will be the join event if they are currently joined to the room. This will be the leave event if they have left the room. """ - if current_state: - member = current_state.get( - (EventTypes.Member, user_id), - None - ) - else: - member = yield self.state.get_current_state( - room_id=room_id, - event_type=EventTypes.Member, - state_key=user_id - ) + member = yield self.state.get_current_state( + room_id=room_id, + event_type=EventTypes.Member, + state_key=user_id + ) membership = member.membership if member else None if membership not in (Membership.JOIN, Membership.LEAVE): @@ -497,7 +488,7 @@ class Auth(object): return default @defer.inlineCallbacks - def get_user_by_req(self, request): + def get_user_by_req(self, request, allow_guest=False): """ Get a registered user's ID. 
Args: @@ -535,7 +526,7 @@ class Auth(object): request.authenticated_entity = user_id - defer.returnValue((UserID.from_string(user_id), "")) + defer.returnValue((UserID.from_string(user_id), "", False)) return except KeyError: pass # normal users won't have the user_id query parameter set. @@ -543,6 +534,7 @@ class Auth(object): user_info = yield self._get_user_by_access_token(access_token) user = user_info["user"] token_id = user_info["token_id"] + is_guest = user_info["is_guest"] ip_addr = self.hs.get_ip_from_request(request) user_agent = request.requestHeaders.getRawHeaders( @@ -557,9 +549,14 @@ class Auth(object): user_agent=user_agent ) + if is_guest and not allow_guest: + raise AuthError( + 403, "Guest access not allowed", errcode=Codes.GUEST_ACCESS_FORBIDDEN + ) + request.authenticated_entity = user.to_string() - defer.returnValue((user, token_id,)) + defer.returnValue((user, token_id, is_guest,)) except KeyError: raise AuthError( self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.", @@ -592,31 +589,45 @@ class Auth(object): self._validate_macaroon(macaroon) user_prefix = "user_id = " + user = None + guest = False for caveat in macaroon.caveats: if caveat.caveat_id.startswith(user_prefix): user = UserID.from_string(caveat.caveat_id[len(user_prefix):]) - # This codepath exists so that we can actually return a - # token ID, because we use token IDs in place of device - # identifiers throughout the codebase. - # TODO(daniel): Remove this fallback when device IDs are - # properly implemented. - ret = yield self._look_up_user_by_access_token(macaroon_str) - if ret["user"] != user: - logger.error( - "Macaroon user (%s) != DB user (%s)", - user, - ret["user"] - ) - raise AuthError( - self.TOKEN_NOT_FOUND_HTTP_STATUS, - "User mismatch in macaroon", - errcode=Codes.UNKNOWN_TOKEN - ) - defer.returnValue(ret) - raise AuthError( - self.TOKEN_NOT_FOUND_HTTP_STATUS, "No user caveat in macaroon", - errcode=Codes.UNKNOWN_TOKEN - ) + elif caveat.caveat_id == "guest = true": + guest = True + + if user is None: + raise AuthError( + self.TOKEN_NOT_FOUND_HTTP_STATUS, "No user caveat in macaroon", + errcode=Codes.UNKNOWN_TOKEN + ) + + if guest: + ret = { + "user": user, + "is_guest": True, + "token_id": None, + } + else: + # This codepath exists so that we can actually return a + # token ID, because we use token IDs in place of device + # identifiers throughout the codebase. + # TODO(daniel): Remove this fallback when device IDs are + # properly implemented. 
+ ret = yield self._look_up_user_by_access_token(macaroon_str) + if ret["user"] != user: + logger.error( + "Macaroon user (%s) != DB user (%s)", + user, + ret["user"] + ) + raise AuthError( + self.TOKEN_NOT_FOUND_HTTP_STATUS, + "User mismatch in macaroon", + errcode=Codes.UNKNOWN_TOKEN + ) + defer.returnValue(ret) except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError): raise AuthError( self.TOKEN_NOT_FOUND_HTTP_STATUS, "Invalid macaroon passed.", @@ -629,6 +640,7 @@ class Auth(object): v.satisfy_exact("type = access") v.satisfy_general(lambda c: c.startswith("user_id = ")) v.satisfy_general(self._verify_expiry) + v.satisfy_exact("guest = true") v.verify(macaroon, self.hs.config.macaroon_secret_key) v = pymacaroons.Verifier() @@ -666,6 +678,7 @@ class Auth(object): user_info = { "user": UserID.from_string(ret.get("name")), "token_id": ret.get("token_id", None), + "is_guest": False, } defer.returnValue(user_info) diff --git a/synapse/api/errors.py b/synapse/api/errors.py index b3fea27d0e..d4037b3d55 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -33,6 +33,7 @@ class Codes(object): NOT_FOUND = "M_NOT_FOUND" MISSING_TOKEN = "M_MISSING_TOKEN" UNKNOWN_TOKEN = "M_UNKNOWN_TOKEN" + GUEST_ACCESS_FORBIDDEN = "M_GUEST_ACCESS_FORBIDDEN" LIMIT_EXCEEDED = "M_LIMIT_EXCEEDED" CAPTCHA_NEEDED = "M_CAPTCHA_NEEDED" CAPTCHA_INVALID = "M_CAPTCHA_INVALID" diff --git a/synapse/config/registration.py b/synapse/config/registration.py index f5ef36a9f4..dca391f7af 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -34,6 +34,7 @@ class RegistrationConfig(Config): self.registration_shared_secret = config.get("registration_shared_secret") self.macaroon_secret_key = config.get("macaroon_secret_key") self.bcrypt_rounds = config.get("bcrypt_rounds", 12) + self.allow_guest_access = config.get("allow_guest_access", False) def default_config(self, **kwargs): registration_shared_secret = random_string_with_symbols(50) @@ -54,6 +55,11 @@ class RegistrationConfig(Config): # Larger numbers increase the work factor needed to generate the hash. # The default number of rounds is 12. bcrypt_rounds: 12 + + # Allows users to register as guests without a password/email/etc, and + # participate in rooms hosted on this server which have been made + # accessible to anonymous users. + allow_guest_access: False """ % locals() def add_arguments(self, parser): diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index 6a26cb1879..6873a4575d 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -47,37 +47,23 @@ class BaseHandler(object): self.event_builder_factory = hs.get_event_builder_factory() @defer.inlineCallbacks - def _filter_events_for_client(self, user_id, events): - event_id_to_state = yield self.store.get_state_for_events( - frozenset(e.event_id for e in events), - types=( - (EventTypes.RoomHistoryVisibility, ""), - (EventTypes.Member, user_id), - ) - ) + def _filter_events_for_client(self, user_id, events, is_guest=False): + # Assumes that user has at some point joined the room if not is_guest. 
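# --- Editor's aside (illustrative sketch, not part of this patch) ----------
# The rewritten allowed() below reduces to a small decision table. This is a
# standalone, runnable approximation of the branches shown in this hunk,
# with membership values written out as the strings behind Membership.JOIN
# etc. (an assumption; the real constants live in synapse.api.constants):
def _visible_to(membership, visibility, is_guest, is_visibility_event=False):
    if visibility == "world_readable":
        return True                      # world-readable: everyone, guests included
    if is_guest:
        return False                     # guests may see nothing else
    if membership == "join":
        return True                      # joined members see everything
    if is_visibility_event:
        return True                      # non-guests always see visibility changes
    if visibility == "shared":
        return True                      # shared history: former members too
    if visibility == "joined":
        return membership == "join"
    return True                          # remaining branches fall through to allow

assert _visible_to("join", "joined", is_guest=False)
assert not _visible_to(None, "shared", is_guest=True)
# ----------------------------------------------------------------------------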
- def allowed(event, state): - if event.type == EventTypes.RoomHistoryVisibility: + def allowed(event, membership, visibility): + if visibility == "world_readable": return True - membership_ev = state.get((EventTypes.Member, user_id), None) - if membership_ev: - membership = membership_ev.membership - else: - membership = Membership.LEAVE + if is_guest: + return False if membership == Membership.JOIN: return True - history = state.get((EventTypes.RoomHistoryVisibility, ''), None) - if history: - visibility = history.content.get("history_visibility", "shared") - else: - visibility = "shared" + if event.type == EventTypes.RoomHistoryVisibility: + return not is_guest - if visibility == "public": - return True - elif visibility == "shared": + if visibility == "shared": return True elif visibility == "joined": return membership == Membership.JOIN @@ -86,11 +72,44 @@ class BaseHandler(object): return True - defer.returnValue([ - event - for event in events - if allowed(event, event_id_to_state[event.event_id]) - ]) + event_id_to_state = yield self.store.get_state_for_events( + frozenset(e.event_id for e in events), + types=( + (EventTypes.RoomHistoryVisibility, ""), + (EventTypes.Member, user_id), + ) + ) + + events_to_return = [] + for event in events: + state = event_id_to_state[event.event_id] + + membership_event = state.get((EventTypes.Member, user_id), None) + if membership_event: + membership = membership_event.membership + else: + membership = None + + visibility_event = state.get((EventTypes.RoomHistoryVisibility, ""), None) + if visibility_event: + visibility = visibility_event.content.get("history_visibility", "shared") + else: + visibility = "shared" + + should_include = allowed(event, membership, visibility) + if should_include: + events_to_return.append(event) + + if is_guest and len(events_to_return) < len(events): + # This indicates that some events in the requested range were not + # visible to guest users. To be safe, we reject the entire request, + # so that we don't have to worry about interpreting visibility + # boundaries. + raise AuthError(403, "User %s does not have permission" % ( + user_id + )) + + defer.returnValue(events_to_return) def ratelimit(self, user_id): time_now = self.clock.time() diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 055d395b20..1b11dbdffd 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -372,12 +372,15 @@ class AuthHandler(BaseHandler): yield self.store.add_refresh_token_to_user(user_id, refresh_token) defer.returnValue(refresh_token) - def generate_access_token(self, user_id): + def generate_access_token(self, user_id, extra_caveats=None): + extra_caveats = extra_caveats or [] macaroon = self._generate_base_macaroon(user_id) macaroon.add_first_party_caveat("type = access") now = self.hs.get_clock().time_msec() expiry = now + (60 * 60 * 1000) macaroon.add_first_party_caveat("time < %d" % (expiry,)) + for caveat in extra_caveats: + macaroon.add_first_party_caveat(caveat) return macaroon.serialize() def generate_refresh_token(self, user_id): diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 0f947993d1..687e1527f7 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -71,20 +71,20 @@ class MessageHandler(BaseHandler): @defer.inlineCallbacks def get_messages(self, user_id=None, room_id=None, pagin_config=None, - as_client_event=True): + as_client_event=True, is_guest=False): """Get messages in a room. Args: user_id (str): The user requesting messages. 
room_id (str): The room they want messages from. pagin_config (synapse.api.streams.PaginationConfig): The pagination - config rules to apply, if any. + config rules to apply, if any. as_client_event (bool): True to get events in client-server format. + is_guest (bool): Whether the requesting user is a guest (as opposed + to a fully registered user). Returns: dict: Pagination API results """ - member_event = yield self.auth.check_user_was_in_room(room_id, user_id) - data_source = self.hs.get_event_sources().sources["room"] if pagin_config.from_token: @@ -107,23 +107,27 @@ class MessageHandler(BaseHandler): source_config = pagin_config.get_source_config("room") - if member_event.membership == Membership.LEAVE: - # If they have left the room then clamp the token to be before - # they left the room - leave_token = yield self.store.get_topological_token_for_event( - member_event.event_id - ) - leave_token = RoomStreamToken.parse(leave_token) - if leave_token.topological < room_token.topological: - source_config.from_key = str(leave_token) - - if source_config.direction == "f": - if source_config.to_key is None: - source_config.to_key = str(leave_token) - else: - to_token = RoomStreamToken.parse(source_config.to_key) - if leave_token.topological < to_token.topological: + if not is_guest: + member_event = yield self.auth.check_user_was_in_room(room_id, user_id) + if member_event.membership == Membership.LEAVE: + # If they have left the room then clamp the token to be before + # they left the room. + # If they're a guest, we'll just 403 them if they're asking for + # events they can't see. + leave_token = yield self.store.get_topological_token_for_event( + member_event.event_id + ) + leave_token = RoomStreamToken.parse(leave_token) + if leave_token.topological < room_token.topological: + source_config.from_key = str(leave_token) + + if source_config.direction == "f": + if source_config.to_key is None: source_config.to_key = str(leave_token) + else: + to_token = RoomStreamToken.parse(source_config.to_key) + if leave_token.topological < to_token.topological: + source_config.to_key = str(leave_token) yield self.hs.get_handlers().federation_handler.maybe_backfill( room_id, room_token.topological @@ -146,7 +150,7 @@ class MessageHandler(BaseHandler): "end": next_token.to_string(), }) - events = yield self._filter_events_for_client(user_id, events) + events = yield self._filter_events_for_client(user_id, events, is_guest=is_guest) time_now = self.clock.time_msec() diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index ef4081e3fe..493a087031 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -64,7 +64,7 @@ class RegistrationHandler(BaseHandler): ) @defer.inlineCallbacks - def register(self, localpart=None, password=None): + def register(self, localpart=None, password=None, generate_token=True): """Registers a new client on the server. 
Args: @@ -89,7 +89,9 @@ class RegistrationHandler(BaseHandler): user = UserID(localpart, self.hs.hostname) user_id = user.to_string() - token = self.auth_handler().generate_access_token(user_id) + token = None + if generate_token: + token = self.auth_handler().generate_access_token(user_id) yield self.store.register( user_id=user_id, token=token, @@ -102,14 +104,14 @@ class RegistrationHandler(BaseHandler): attempts = 0 user_id = None token = None - while not user_id and not token: + while not user_id: try: localpart = self._generate_user_id() user = UserID(localpart, self.hs.hostname) user_id = user.to_string() yield self.check_user_id_is_valid(user_id) - - token = self.auth_handler().generate_access_token(user_id) + if generate_token: + token = self.auth_handler().generate_access_token(user_id) yield self.store.register( user_id=user_id, token=token, diff --git a/synapse/rest/client/v1/admin.py b/synapse/rest/client/v1/admin.py index 504b63eab4..bdde43864c 100644 --- a/synapse/rest/client/v1/admin.py +++ b/synapse/rest/client/v1/admin.py @@ -31,7 +31,7 @@ class WhoisRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request, user_id): target_user = UserID.from_string(user_id) - auth_user, _ = yield self.auth.get_user_by_req(request) + auth_user, _, _ = yield self.auth.get_user_by_req(request) is_admin = yield self.auth.is_server_admin(auth_user) if not is_admin and target_user != auth_user: diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py index 4dcda57c1b..240eedac75 100644 --- a/synapse/rest/client/v1/directory.py +++ b/synapse/rest/client/v1/directory.py @@ -69,7 +69,7 @@ class ClientDirectoryServer(ClientV1RestServlet): try: # try to auth as a user - user, _ = yield self.auth.get_user_by_req(request) + user, _, _ = yield self.auth.get_user_by_req(request) try: user_id = user.to_string() yield dir_handler.create_association( @@ -116,7 +116,7 @@ class ClientDirectoryServer(ClientV1RestServlet): # fallback to default user behaviour if they aren't an AS pass - user, _ = yield self.auth.get_user_by_req(request) + user, _, _ = yield self.auth.get_user_by_req(request) is_admin = yield self.auth.is_server_admin(user) if not is_admin: diff --git a/synapse/rest/client/v1/events.py b/synapse/rest/client/v1/events.py index 582148b659..4073b0d2d1 100644 --- a/synapse/rest/client/v1/events.py +++ b/synapse/rest/client/v1/events.py @@ -34,7 +34,7 @@ class EventStreamRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request): - auth_user, _ = yield self.auth.get_user_by_req(request) + auth_user, _, _ = yield self.auth.get_user_by_req(request) try: handler = self.handlers.event_stream_handler pagin_config = PaginationConfig.from_request(request) @@ -71,7 +71,7 @@ class EventRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request, event_id): - auth_user, _ = yield self.auth.get_user_by_req(request) + auth_user, _, _ = yield self.auth.get_user_by_req(request) handler = self.handlers.event_handler event = yield handler.get_event(auth_user, event_id) diff --git a/synapse/rest/client/v1/initial_sync.py b/synapse/rest/client/v1/initial_sync.py index 52c7943400..856a70f297 100644 --- a/synapse/rest/client/v1/initial_sync.py +++ b/synapse/rest/client/v1/initial_sync.py @@ -25,7 +25,7 @@ class InitialSyncRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request): - user, _ = yield self.auth.get_user_by_req(request) + user, _, _ = yield self.auth.get_user_by_req(request) 
as_client_event = "raw" not in request.args pagination_config = PaginationConfig.from_request(request) handler = self.handlers.message_handler diff --git a/synapse/rest/client/v1/presence.py b/synapse/rest/client/v1/presence.py index a770efd841..6fe5d19a22 100644 --- a/synapse/rest/client/v1/presence.py +++ b/synapse/rest/client/v1/presence.py @@ -32,7 +32,7 @@ class PresenceStatusRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request, user_id): - auth_user, _ = yield self.auth.get_user_by_req(request) + auth_user, _, _ = yield self.auth.get_user_by_req(request) user = UserID.from_string(user_id) state = yield self.handlers.presence_handler.get_state( @@ -42,7 +42,7 @@ class PresenceStatusRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_PUT(self, request, user_id): - auth_user, _ = yield self.auth.get_user_by_req(request) + auth_user, _, _ = yield self.auth.get_user_by_req(request) user = UserID.from_string(user_id) state = {} @@ -77,7 +77,7 @@ class PresenceListRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request, user_id): - auth_user, _ = yield self.auth.get_user_by_req(request) + auth_user, _, _ = yield self.auth.get_user_by_req(request) user = UserID.from_string(user_id) if not self.hs.is_mine(user): @@ -97,7 +97,7 @@ class PresenceListRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_POST(self, request, user_id): - auth_user, _ = yield self.auth.get_user_by_req(request) + auth_user, _, _ = yield self.auth.get_user_by_req(request) user = UserID.from_string(user_id) if not self.hs.is_mine(user): diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py index fdde88a60d..6b379e4e5f 100644 --- a/synapse/rest/client/v1/profile.py +++ b/synapse/rest/client/v1/profile.py @@ -37,7 +37,7 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_PUT(self, request, user_id): - auth_user, _ = yield self.auth.get_user_by_req(request) + auth_user, _, _ = yield self.auth.get_user_by_req(request) user = UserID.from_string(user_id) try: @@ -70,7 +70,7 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_PUT(self, request, user_id): - auth_user, _ = yield self.auth.get_user_by_req(request) + auth_user, _, _ = yield self.auth.get_user_by_req(request) user = UserID.from_string(user_id) try: diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py index bd759a2589..b0870db1ac 100644 --- a/synapse/rest/client/v1/push_rule.py +++ b/synapse/rest/client/v1/push_rule.py @@ -43,7 +43,7 @@ class PushRuleRestServlet(ClientV1RestServlet): except InvalidRuleException as e: raise SynapseError(400, e.message) - user, _ = yield self.auth.get_user_by_req(request) + user, _, _ = yield self.auth.get_user_by_req(request) if '/' in spec['rule_id'] or '\\' in spec['rule_id']: raise SynapseError(400, "rule_id may not contain slashes") @@ -92,7 +92,7 @@ class PushRuleRestServlet(ClientV1RestServlet): def on_DELETE(self, request): spec = _rule_spec_from_path(request.postpath) - user, _ = yield self.auth.get_user_by_req(request) + user, _, _ = yield self.auth.get_user_by_req(request) namespaced_rule_id = _namespaced_rule_id_from_spec(spec) @@ -109,7 +109,7 @@ class PushRuleRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request): - user, _ = yield self.auth.get_user_by_req(request) + user, _, _ = yield self.auth.get_user_by_req(request) # we build up the full structure and then decide which bits 
of it # to send which means doing unnecessary work sometimes but is diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py index 3aabc93b8b..a110c0a4f0 100644 --- a/synapse/rest/client/v1/pusher.py +++ b/synapse/rest/client/v1/pusher.py @@ -27,7 +27,7 @@ class PusherRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_POST(self, request): - user, token_id = yield self.auth.get_user_by_req(request) + user, token_id, _ = yield self.auth.get_user_by_req(request) content = _parse_json(request) diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 2dcaee86cd..0876e593c5 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -62,7 +62,7 @@ class RoomCreateRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_POST(self, request): - auth_user, _ = yield self.auth.get_user_by_req(request) + auth_user, _, _ = yield self.auth.get_user_by_req(request) room_config = self.get_room_config(request) info = yield self.make_room(room_config, auth_user, None) @@ -125,7 +125,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request, room_id, event_type, state_key): - user, _ = yield self.auth.get_user_by_req(request) + user, _, _ = yield self.auth.get_user_by_req(request) msg_handler = self.handlers.message_handler data = yield msg_handler.get_room_data( @@ -143,7 +143,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_PUT(self, request, room_id, event_type, state_key, txn_id=None): - user, token_id = yield self.auth.get_user_by_req(request) + user, token_id, _ = yield self.auth.get_user_by_req(request) content = _parse_json(request) @@ -175,7 +175,7 @@ class RoomSendEventRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_POST(self, request, room_id, event_type, txn_id=None): - user, token_id = yield self.auth.get_user_by_req(request) + user, token_id, _ = yield self.auth.get_user_by_req(request) content = _parse_json(request) msg_handler = self.handlers.message_handler @@ -220,7 +220,7 @@ class JoinRoomAliasServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_POST(self, request, room_identifier, txn_id=None): - user, token_id = yield self.auth.get_user_by_req(request) + user, token_id, _ = yield self.auth.get_user_by_req(request) # the identifier could be a room alias or a room id. 
Try one then the # other if it fails to parse, without swallowing other valid @@ -289,7 +289,7 @@ class RoomMemberListRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request, room_id): # TODO support Pagination stream API (limit/tokens) - user, _ = yield self.auth.get_user_by_req(request) + user, _, _ = yield self.auth.get_user_by_req(request) handler = self.handlers.message_handler events = yield handler.get_state_events( room_id=room_id, @@ -325,7 +325,7 @@ class RoomMessageListRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request, room_id): - user, _ = yield self.auth.get_user_by_req(request) + user, _, is_guest = yield self.auth.get_user_by_req(request, allow_guest=True) pagination_config = PaginationConfig.from_request( request, default_limit=10, ) @@ -334,6 +334,7 @@ class RoomMessageListRestServlet(ClientV1RestServlet): msgs = yield handler.get_messages( room_id=room_id, user_id=user.to_string(), + is_guest=is_guest, pagin_config=pagination_config, as_client_event=as_client_event ) @@ -347,7 +348,7 @@ class RoomStateRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request, room_id): - user, _ = yield self.auth.get_user_by_req(request) + user, _, _ = yield self.auth.get_user_by_req(request) handler = self.handlers.message_handler # Get all the current state for this room events = yield handler.get_state_events( @@ -363,7 +364,7 @@ class RoomInitialSyncRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request, room_id): - user, _ = yield self.auth.get_user_by_req(request) + user, _, _ = yield self.auth.get_user_by_req(request) pagination_config = PaginationConfig.from_request(request) content = yield self.handlers.message_handler.room_initial_sync( room_id=room_id, @@ -443,7 +444,7 @@ class RoomMembershipRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_POST(self, request, room_id, membership_action, txn_id=None): - user, token_id = yield self.auth.get_user_by_req(request) + user, token_id, _ = yield self.auth.get_user_by_req(request) content = _parse_json(request) @@ -524,7 +525,7 @@ class RoomRedactEventRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_POST(self, request, room_id, event_id, txn_id=None): - user, token_id = yield self.auth.get_user_by_req(request) + user, token_id, _ = yield self.auth.get_user_by_req(request) content = _parse_json(request) msg_handler = self.handlers.message_handler @@ -564,7 +565,7 @@ class RoomTypingRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_PUT(self, request, room_id, user_id): - auth_user, _ = yield self.auth.get_user_by_req(request) + auth_user, _, _ = yield self.auth.get_user_by_req(request) room_id = urllib.unquote(room_id) target_user = UserID.from_string(urllib.unquote(user_id)) @@ -597,7 +598,7 @@ class SearchRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_POST(self, request): - auth_user, _ = yield self.auth.get_user_by_req(request) + auth_user, _, _ = yield self.auth.get_user_by_req(request) content = _parse_json(request) diff --git a/synapse/rest/client/v1/voip.py b/synapse/rest/client/v1/voip.py index 0a863e1c61..eb7c57cade 100644 --- a/synapse/rest/client/v1/voip.py +++ b/synapse/rest/client/v1/voip.py @@ -28,7 +28,7 @@ class VoipRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request): - auth_user, _ = yield self.auth.get_user_by_req(request) + auth_user, _, _ = yield self.auth.get_user_by_req(request) turnUris = self.hs.config.turn_uris 
turnSecret = self.hs.config.turn_shared_secret diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index 4692ba413c..1970ad3458 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -55,7 +55,7 @@ class PasswordRestServlet(RestServlet): if LoginType.PASSWORD in result: # if using password, they should also be logged in - auth_user, _ = yield self.auth.get_user_by_req(request) + auth_user, _, _ = yield self.auth.get_user_by_req(request) if auth_user.to_string() != result[LoginType.PASSWORD]: raise LoginError(400, "", Codes.UNKNOWN) user_id = auth_user.to_string() @@ -102,7 +102,7 @@ class ThreepidRestServlet(RestServlet): def on_GET(self, request): yield run_on_reactor() - auth_user, _ = yield self.auth.get_user_by_req(request) + auth_user, _, _ = yield self.auth.get_user_by_req(request) threepids = yield self.hs.get_datastore().user_get_threepids( auth_user.to_string() @@ -120,7 +120,7 @@ class ThreepidRestServlet(RestServlet): raise SynapseError(400, "Missing param", Codes.MISSING_PARAM) threePidCreds = body['threePidCreds'] - auth_user, _ = yield self.auth.get_user_by_req(request) + auth_user, _, _ = yield self.auth.get_user_by_req(request) threepid = yield self.identity_handler.threepid_from_creds(threePidCreds) diff --git a/synapse/rest/client/v2_alpha/filter.py b/synapse/rest/client/v2_alpha/filter.py index f8f91b63f5..97956a4b91 100644 --- a/synapse/rest/client/v2_alpha/filter.py +++ b/synapse/rest/client/v2_alpha/filter.py @@ -40,7 +40,7 @@ class GetFilterRestServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request, user_id, filter_id): target_user = UserID.from_string(user_id) - auth_user, _ = yield self.auth.get_user_by_req(request) + auth_user, _, _ = yield self.auth.get_user_by_req(request) if target_user != auth_user: raise AuthError(403, "Cannot get filters for other users") @@ -76,7 +76,7 @@ class CreateFilterRestServlet(RestServlet): @defer.inlineCallbacks def on_POST(self, request, user_id): target_user = UserID.from_string(user_id) - auth_user, _ = yield self.auth.get_user_by_req(request) + auth_user, _, _ = yield self.auth.get_user_by_req(request) if target_user != auth_user: raise AuthError(403, "Cannot create filters for other users") diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py index a1f4423101..820d33336f 100644 --- a/synapse/rest/client/v2_alpha/keys.py +++ b/synapse/rest/client/v2_alpha/keys.py @@ -64,7 +64,7 @@ class KeyUploadServlet(RestServlet): @defer.inlineCallbacks def on_POST(self, request, device_id): - auth_user, _ = yield self.auth.get_user_by_req(request) + auth_user, _, _ = yield self.auth.get_user_by_req(request) user_id = auth_user.to_string() # TODO: Check that the device_id matches that in the authentication # or derive the device_id from the authentication instead. 
@@ -109,7 +109,7 @@ class KeyUploadServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request, device_id): - auth_user, _ = yield self.auth.get_user_by_req(request) + auth_user, _, _ = yield self.auth.get_user_by_req(request) user_id = auth_user.to_string() result = yield self.store.count_e2e_one_time_keys(user_id, device_id) @@ -181,7 +181,7 @@ class KeyQueryServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request, user_id, device_id): - auth_user, _ = yield self.auth.get_user_by_req(request) + auth_user, _, _ = yield self.auth.get_user_by_req(request) auth_user_id = auth_user.to_string() user_id = user_id if user_id else auth_user_id device_ids = [device_id] if device_id else [] diff --git a/synapse/rest/client/v2_alpha/receipts.py b/synapse/rest/client/v2_alpha/receipts.py index b107b7ce17..788acd4adb 100644 --- a/synapse/rest/client/v2_alpha/receipts.py +++ b/synapse/rest/client/v2_alpha/receipts.py @@ -40,7 +40,7 @@ class ReceiptRestServlet(RestServlet): @defer.inlineCallbacks def on_POST(self, request, room_id, receipt_type, event_id): - user, _ = yield self.auth.get_user_by_req(request) + user, _, _ = yield self.auth.get_user_by_req(request) if receipt_type != "m.read": raise SynapseError(400, "Receipt type must be 'm.read'") diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 1ba2f29711..f899376311 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -16,7 +16,7 @@ from twisted.internet import defer from synapse.api.constants import LoginType -from synapse.api.errors import SynapseError, Codes +from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError from synapse.http.servlet import RestServlet from ._base import client_v2_pattern, parse_json_dict_from_request @@ -55,6 +55,19 @@ class RegisterRestServlet(RestServlet): def on_POST(self, request): yield run_on_reactor() + kind = "user" + if "kind" in request.args: + kind = request.args["kind"][0] + + if kind == "guest": + ret = yield self._do_guest_registration() + defer.returnValue(ret) + return + elif kind != "user": + raise UnrecognizedRequestError( + "Do not understand membership kind: %s" % (kind,) + ) + if '/register/email/requestToken' in request.path: ret = yield self.onEmailTokenRequest(request) defer.returnValue(ret) @@ -236,6 +249,18 @@ class RegisterRestServlet(RestServlet): ret = yield self.identity_handler.requestEmailToken(**body) defer.returnValue((200, ret)) + @defer.inlineCallbacks + def _do_guest_registration(self): + if not self.hs.config.allow_guest_access: + defer.returnValue((403, "Guest access is disabled")) + user_id, _ = yield self.registration_handler.register(generate_token=False) + access_token = self.auth_handler.generate_access_token(user_id, ["guest = true"]) + defer.returnValue((200, { + "user_id": user_id, + "access_token": access_token, + "home_server": self.hs.hostname, + })) + def register_servlets(hs, http_server): RegisterRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index 32a1087c91..d24507effa 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -81,7 +81,7 @@ class SyncRestServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request): - user, token_id = yield self.auth.get_user_by_req(request) + user, token_id, _ = yield self.auth.get_user_by_req(request) timeout = parse_integer(request, "timeout", default=0) since = 
parse_string(request, "since") diff --git a/synapse/rest/client/v2_alpha/tags.py b/synapse/rest/client/v2_alpha/tags.py index dcfe6bd20e..35482ae6a6 100644 --- a/synapse/rest/client/v2_alpha/tags.py +++ b/synapse/rest/client/v2_alpha/tags.py @@ -42,7 +42,7 @@ class TagListServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request, user_id, room_id): - auth_user, _ = yield self.auth.get_user_by_req(request) + auth_user, _, _ = yield self.auth.get_user_by_req(request) if user_id != auth_user.to_string(): raise AuthError(403, "Cannot get tags for other users.") @@ -68,7 +68,7 @@ class TagServlet(RestServlet): @defer.inlineCallbacks def on_PUT(self, request, user_id, room_id, tag): - auth_user, _ = yield self.auth.get_user_by_req(request) + auth_user, _, _ = yield self.auth.get_user_by_req(request) if user_id != auth_user.to_string(): raise AuthError(403, "Cannot add tags for other users.") @@ -88,7 +88,7 @@ class TagServlet(RestServlet): @defer.inlineCallbacks def on_DELETE(self, request, user_id, room_id, tag): - auth_user, _ = yield self.auth.get_user_by_req(request) + auth_user, _, _ = yield self.auth.get_user_by_req(request) if user_id != auth_user.to_string(): raise AuthError(403, "Cannot add tags for other users.") diff --git a/synapse/rest/media/v0/content_repository.py b/synapse/rest/media/v0/content_repository.py index c28dc86cd7..e4fa8c4647 100644 --- a/synapse/rest/media/v0/content_repository.py +++ b/synapse/rest/media/v0/content_repository.py @@ -66,7 +66,7 @@ class ContentRepoResource(resource.Resource): @defer.inlineCallbacks def map_request_to_name(self, request): # auth the user - auth_user, _ = yield self.auth.get_user_by_req(request) + auth_user, _, _ = yield self.auth.get_user_by_req(request) # namespace all file uploads on the user prefix = base64.urlsafe_b64encode( diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py index 6abaf56b25..7d61596082 100644 --- a/synapse/rest/media/v1/upload_resource.py +++ b/synapse/rest/media/v1/upload_resource.py @@ -70,7 +70,7 @@ class UploadResource(BaseMediaResource): @request_handler @defer.inlineCallbacks def _async_render_POST(self, request): - auth_user, _ = yield self.auth.get_user_by_req(request) + auth_user, _, _ = yield self.auth.get_user_by_req(request) # TODO: The checks here are a bit late. 
The content will have # already been uploaded to a tmp file at this point content_length = request.getHeader("Content-Length") diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index b454dd5b3a..2e5eddd259 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -102,13 +102,14 @@ class RegistrationStore(SQLBaseStore): 400, "User ID already taken.", errcode=Codes.USER_IN_USE ) - # it's possible for this to get a conflict, but only for a single user - # since tokens are namespaced based on their user ID - txn.execute( - "INSERT INTO access_tokens(id, user_id, token)" - " VALUES (?,?,?)", - (next_id, user_id, token,) - ) + if token: + # it's possible for this to get a conflict, but only for a single user + # since tokens are namespaced based on their user ID + txn.execute( + "INSERT INTO access_tokens(id, user_id, token)" + " VALUES (?,?,?)", + (next_id, user_id, token,) + ) def get_user_by_id(self, user_id): return self._simple_select_one( diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index c96273480d..70d928defe 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -51,7 +51,7 @@ class AuthTestCase(unittest.TestCase): request = Mock(args={}) request.args["access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = Mock(return_value=[""]) - (user, _) = yield self.auth.get_user_by_req(request) + (user, _, _) = yield self.auth.get_user_by_req(request) self.assertEquals(user.to_string(), self.test_user) def test_get_user_by_req_user_bad_token(self): @@ -86,7 +86,7 @@ class AuthTestCase(unittest.TestCase): request = Mock(args={}) request.args["access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = Mock(return_value=[""]) - (user, _) = yield self.auth.get_user_by_req(request) + (user, _, _) = yield self.auth.get_user_by_req(request) self.assertEquals(user.to_string(), self.test_user) def test_get_user_by_req_appservice_bad_token(self): @@ -121,7 +121,7 @@ class AuthTestCase(unittest.TestCase): request.args["access_token"] = [self.test_token] request.args["user_id"] = [masquerading_user_id] request.requestHeaders.getRawHeaders = Mock(return_value=[""]) - (user, _) = yield self.auth.get_user_by_req(request) + (user, _, _) = yield self.auth.get_user_by_req(request) self.assertEquals(user.to_string(), masquerading_user_id) def test_get_user_by_req_appservice_valid_token_bad_user_id(self): @@ -158,6 +158,25 @@ class AuthTestCase(unittest.TestCase): user = user_info["user"] self.assertEqual(UserID.from_string(user_id), user) + @defer.inlineCallbacks + def test_get_guest_user_from_macaroon(self): + user_id = "@baldrick:matrix.org" + macaroon = pymacaroons.Macaroon( + location=self.hs.config.server_name, + identifier="key", + key=self.hs.config.macaroon_secret_key) + macaroon.add_first_party_caveat("gen = 1") + macaroon.add_first_party_caveat("type = access") + macaroon.add_first_party_caveat("user_id = %s" % (user_id,)) + macaroon.add_first_party_caveat("guest = true") + serialized = macaroon.serialize() + + user_info = yield self.auth._get_user_from_macaroon(serialized) + user = user_info["user"] + is_guest = user_info["is_guest"] + self.assertEqual(UserID.from_string(user_id), user) + self.assertTrue(is_guest) + @defer.inlineCallbacks def test_get_user_from_macaroon_user_db_mismatch(self): self.store.get_user_by_access_token = Mock( diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py index 0e3b922246..3e0f294630 100644 --- 
a/tests/rest/client/v1/test_presence.py +++ b/tests/rest/client/v1/test_presence.py @@ -86,10 +86,11 @@ class PresenceStateTestCase(unittest.TestCase): return defer.succeed([]) self.datastore.get_presence_list = get_presence_list - def _get_user_by_access_token(token=None): + def _get_user_by_access_token(token=None, allow_guest=False): return { "user": UserID.from_string(myid), "token_id": 1, + "is_guest": False, } hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token @@ -173,10 +174,11 @@ class PresenceListTestCase(unittest.TestCase): ) self.datastore.has_presence_state = has_presence_state - def _get_user_by_access_token(token=None): + def _get_user_by_access_token(token=None, allow_guest=False): return { "user": UserID.from_string(myid), "token_id": 1, + "is_guest": False, } hs.handlers.room_member_handler = Mock( @@ -291,8 +293,8 @@ class PresenceEventStreamTestCase(unittest.TestCase): hs.get_clock().time_msec.return_value = 1000000 - def _get_user_by_req(req=None): - return (UserID.from_string(myid), "") + def _get_user_by_req(req=None, allow_guest=False): + return (UserID.from_string(myid), "", False) hs.get_v1auth().get_user_by_req = _get_user_by_req diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py index 929e5e5dd4..adcc1d1969 100644 --- a/tests/rest/client/v1/test_profile.py +++ b/tests/rest/client/v1/test_profile.py @@ -52,8 +52,8 @@ class ProfileTestCase(unittest.TestCase): replication_layer=Mock(), ) - def _get_user_by_req(request=None): - return (UserID.from_string(myid), "") + def _get_user_by_req(request=None, allow_guest=False): + return (UserID.from_string(myid), "", False) hs.get_v1auth().get_user_by_req = _get_user_by_req diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 93896dd076..b43563fa4b 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -54,10 +54,11 @@ class RoomPermissionsTestCase(RestTestCase): hs.get_handlers().federation_handler = Mock() - def _get_user_by_access_token(token=None): + def _get_user_by_access_token(token=None, allow_guest=False): return { "user": UserID.from_string(self.auth_user_id), "token_id": 1, + "is_guest": False, } hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token @@ -439,10 +440,11 @@ class RoomsMemberListTestCase(RestTestCase): self.auth_user_id = self.user_id - def _get_user_by_access_token(token=None): + def _get_user_by_access_token(token=None, allow_guest=False): return { "user": UserID.from_string(self.auth_user_id), "token_id": 1, + "is_guest": False, } hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token @@ -517,10 +519,11 @@ class RoomsCreateTestCase(RestTestCase): hs.get_handlers().federation_handler = Mock() - def _get_user_by_access_token(token=None): + def _get_user_by_access_token(token=None, allow_guest=False): return { "user": UserID.from_string(self.auth_user_id), "token_id": 1, + "is_guest": False, } hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token @@ -608,10 +611,11 @@ class RoomTopicTestCase(RestTestCase): hs.get_handlers().federation_handler = Mock() - def _get_user_by_access_token(token=None): + def _get_user_by_access_token(token=None, allow_guest=False): return { "user": UserID.from_string(self.auth_user_id), "token_id": 1, + "is_guest": False, } hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token @@ -713,10 +717,11 @@ class RoomMemberStateTestCase(RestTestCase): hs.get_handlers().federation_handler = 
Mock() - def _get_user_by_access_token(token=None): + def _get_user_by_access_token(token=None, allow_guest=False): return { "user": UserID.from_string(self.auth_user_id), "token_id": 1, + "is_guest": False, } hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token @@ -838,10 +843,11 @@ class RoomMessagesTestCase(RestTestCase): hs.get_handlers().federation_handler = Mock() - def _get_user_by_access_token(token=None): + def _get_user_by_access_token(token=None, allow_guest=False): return { "user": UserID.from_string(self.auth_user_id), "token_id": 1, + "is_guest": False, } hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token @@ -933,10 +939,11 @@ class RoomInitialSyncTestCase(RestTestCase): hs.get_handlers().federation_handler = Mock() - def _get_user_by_access_token(token=None): + def _get_user_by_access_token(token=None, allow_guest=False): return { "user": UserID.from_string(self.auth_user_id), "token_id": 1, + "is_guest": False, } hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py index 6395ce79db..8433585616 100644 --- a/tests/rest/client/v1/test_typing.py +++ b/tests/rest/client/v1/test_typing.py @@ -61,10 +61,11 @@ class RoomTypingTestCase(RestTestCase): hs.get_handlers().federation_handler = Mock() - def _get_user_by_access_token(token=None): + def _get_user_by_access_token(token=None, allow_guest=False): return { "user": UserID.from_string(self.auth_user_id), "token_id": 1, + "is_guest": False, } hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token diff --git a/tests/rest/client/v2_alpha/__init__.py b/tests/rest/client/v2_alpha/__init__.py index f45570a1c0..fa9e17ec4f 100644 --- a/tests/rest/client/v2_alpha/__init__.py +++ b/tests/rest/client/v2_alpha/__init__.py @@ -43,10 +43,11 @@ class V2AlphaRestTestCase(unittest.TestCase): resource_for_federation=self.mock_resource, ) - def _get_user_by_access_token(token=None): + def _get_user_by_access_token(token=None, allow_guest=False): return { "user": UserID.from_string(self.USER_ID), "token_id": 1, + "is_guest": False, } hs.get_auth()._get_user_by_access_token = _get_user_by_access_token -- cgit 1.5.1 From 05c326d445fdd04ef06cb9a2a02cc0d6dde9935b Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 4 Nov 2015 17:57:44 +0000 Subject: Implement order and group by --- synapse/handlers/search.py | 126 +++++++++++++++++++++++++++++++++++++++------ synapse/storage/search.py | 96 ++++++++++++++++++++++++++++++++++ 2 files changed, 205 insertions(+), 17 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index 2718e9482e..28f5300dc9 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -46,15 +46,20 @@ class SearchHandler(BaseHandler): """ try: - search_term = content["search_categories"]["room_events"]["search_term"] - keys = content["search_categories"]["room_events"].get("keys", [ + room_cat = content["search_categories"]["room_events"] + search_term = room_cat["search_term"] + keys = room_cat.get("keys", [ "content.body", "content.name", "content.topic", ]) - filter_dict = content["search_categories"]["room_events"].get("filter", {}) - event_context = content["search_categories"]["room_events"].get( + filter_dict = room_cat.get("filter", {}) + order_by = room_cat.get("order_by", "rank") + event_context = room_cat.get( "event_context", None ) + group_by = room_cat.get("groupings", {}).get("group_by", {}) 
+ group_keys = [g["key"] for g in group_by] + if event_context is not None: before_limit = int(event_context.get( "before_limit", 5 @@ -65,6 +70,15 @@ class SearchHandler(BaseHandler): except KeyError: raise SynapseError(400, "Invalid search query") + if order_by not in ("rank", "recent"): + raise SynapseError(400, "Invalid order by: %r" % (order_by,)) + + if set(group_keys) - {"room_id", "sender"}: + raise SynapseError( + 400, + "Invalid group by keys: %r" % (set(group_keys) - {"room_id", "sender"},) + ) + search_filter = Filter(filter_dict) # TODO: Search through left rooms too @@ -77,18 +91,88 @@ class SearchHandler(BaseHandler): room_ids = search_filter.filter_rooms(room_ids) - rank_map, event_map, _ = yield self.store.search_msgs( - room_ids, search_term, keys - ) + rank_map = {} + allowed_events = [] + room_groups = {} + sender_group = {} - filtered_events = search_filter.filter(event_map.values()) + if order_by == "rank": + rank_map, event_map, _ = yield self.store.search_msgs( + room_ids, search_term, keys + ) - allowed_events = yield self._filter_events_for_client( - user.to_string(), filtered_events - ) + filtered_events = search_filter.filter(event_map.values()) - allowed_events.sort(key=lambda e: -rank_map[e.event_id]) - allowed_events = allowed_events[:search_filter.limit()] + events = yield self._filter_events_for_client( + user.to_string(), filtered_events + ) + + events.sort(key=lambda e: -rank_map[e.event_id]) + allowed_events = events[:search_filter.limit()] + + for e in allowed_events: + rm = room_groups.setdefault(e.room_id, { + "results": [], + "order": rank_map[e.event_id], + }) + rm["results"].append(e.event_id) + + s = sender_group.setdefault(e.sender, { + "results": [], + "order": rank_map[e.event_id], + }) + s["results"].append(e.event_id) + + elif order_by == "recent": + for room_id in room_ids: + room_events = [] + pagination_token = None + i = 0 + + while len(room_events) < search_filter.limit() and i < 5: + i += 5 + r_map, event_map, pagination_token = yield self.store.search_room( + room_id, search_term, keys, search_filter.limit() * 2, + pagination_token=pagination_token, + ) + rank_map.update(r_map) + + filtered_events = search_filter.filter(event_map.values()) + + events = yield self._filter_events_for_client( + user.to_string(), filtered_events + ) + + room_events.extend(events) + room_events = room_events[:search_filter.limit()] + + if len(event_map) < search_filter.limit() * 2: + break + + if room_events: + group = room_groups.setdefault(room_id, {}) + if pagination_token: + group["next_batch"] = pagination_token + + group["results"] = [e.event_id for e in room_events] + group["order"] = max( + e.origin_server_ts/1000 for e in room_events + if hasattr(e, "origin_server_ts") + ) + + allowed_events.extend(room_events) + + # Normalize the group ranks + if room_groups: + mx = max(g["order"] for g in room_groups.values()) + mn = min(g["order"] for g in room_groups.values()) + + for g in room_groups.values(): + g["order"] = (g["order"] - mn) * 1.0 / (mx - mn) + + else: + # We should never get here due to the guard earlier. 
+ raise NotImplementedError() if event_context is not None: now_token = yield self.hs.get_event_sources().get_current_token() @@ -144,11 +228,19 @@ class SearchHandler(BaseHandler): logger.info("Found %d results", len(results)) + rooms_cat_res = { + "results": results, + "count": len(results) + } + + if room_groups and "room_id" in group_keys: + rooms_cat_res.setdefault("groups", {})["room_id"] = room_groups + + if sender_group and "sender" in group_keys: + rooms_cat_res.setdefault("groups", {})["sender"] = sender_group + defer.returnValue({ "search_categories": { - "room_events": { - "results": results, - "count": len(results) - } + "room_events": rooms_cat_res } }) diff --git a/synapse/storage/search.py b/synapse/storage/search.py index cdf003502f..e37e56c1f2 100644 --- a/synapse/storage/search.py +++ b/synapse/storage/search.py @@ -20,6 +20,12 @@ from synapse.storage.engines import PostgresEngine, Sqlite3Engine from collections import namedtuple +import logging + + +logger = logging.getLogger(__name__) + + """The result of a search. Fields: @@ -109,3 +115,93 @@ class SearchStore(SQLBaseStore): event_map, None )) + + @defer.inlineCallbacks + def search_room(self, room_id, search_term, keys, limit, pagination_token=None): + """Performs a full text search over events with given keys. + + Args: + room_id (str): The room_id to search in + search_term (str): Search term to search for + keys (list): List of keys to search in, currently supports + "content.body", "content.name", "content.topic" + pagination_token (str): A pagination token previously returned + + Returns: + SearchResult + """ + clauses = [] + args = [search_term, room_id] + + local_clauses = [] + for key in keys: + local_clauses.append("key = ?") + args.append(key) + + clauses.append( + "(%s)" % (" OR ".join(local_clauses),) + ) + + if pagination_token: + topo, stream = pagination_token.split(",") + clauses.append( + "(topological_ordering < ?" + " OR (topological_ordering = ? AND stream_ordering < ?))" + ) + args.extend([topo, topo, stream]) + + if isinstance(self.database_engine, PostgresEngine): + sql = ( + "SELECT ts_rank_cd(vector, query) as rank," + " topological_ordering, stream_ordering, room_id, event_id" + " FROM plainto_tsquery('english', ?) as query, event_search" + " NATURAL JOIN events" + " WHERE vector @@ query AND room_id = ?" + ) + elif isinstance(self.database_engine, Sqlite3Engine): + sql = ( + "SELECT rank(matchinfo(event_search)) as rank, room_id, event_id," + " topological_ordering, stream_ordering" + " FROM event_search" + " NATURAL JOIN events" + " WHERE value MATCH ? AND room_id = ?" + ) + else: + # This should be unreachable. + raise Exception("Unrecognized database engine") + + for clause in clauses: + sql += " AND " + clause + + # We add an arbitrary limit here to ensure we don't try to pull the + # entire table from the database. + sql += " ORDER BY topological_ordering DESC, stream_ordering DESC LIMIT ?"
+ + args.append(limit) + + results = yield self._execute( + "search_rooms", self.cursor_to_dict, sql, *args + ) + + events = yield self._get_events([r["event_id"] for r in results]) + + event_map = { + ev.event_id: ev + for ev in events + } + + pagination_token = None + if results: + topo = results[-1]["topological_ordering"] + stream = results[-1]["stream_ordering"] + pagination_token = "%s,%s" % (topo, stream) + + defer.returnValue(SearchResult( + { + r["event_id"]: r["rank"] + for r in results + if r["event_id"] in event_map + }, + event_map, + pagination_token + )) -- cgit 1.5.1 From ca2f90742d5606f8fc5b7ddd3dd7244c377c1df8 Mon Sep 17 00:00:00 2001 From: Daniel Wagner-Hall Date: Thu, 5 Nov 2015 14:32:26 +0000 Subject: Open up /events to anonymous users for room events only Squash-merge of PR #345 from daniel/anonymousevents --- synapse/handlers/_base.py | 7 ++- synapse/handlers/events.py | 10 ++- synapse/handlers/message.py | 47 ++++++++++---- synapse/handlers/presence.py | 4 +- synapse/handlers/private_user_data.py | 2 +- synapse/handlers/receipts.py | 6 +- synapse/handlers/room.py | 11 +++- synapse/handlers/sync.py | 20 +++++- synapse/handlers/typing.py | 11 +--- synapse/notifier.py | 42 ++++++++++--- synapse/rest/client/v1/events.py | 13 +++- synapse/rest/client/v1/room.py | 6 +- synapse/storage/events.py | 2 + synapse/storage/room.py | 13 ++++ .../storage/schema/delta/25/history_visibility.sql | 26 ++++++++ synapse/storage/stream.py | 46 +++++++++++--- tests/handlers/test_presence.py | 71 ++++++++++++++++------ tests/handlers/test_typing.py | 30 +++++++-- tests/rest/client/v1/test_presence.py | 9 ++- tests/rest/client/v1/test_typing.py | 5 +- 20 files changed, 299 insertions(+), 82 deletions(-) create mode 100644 synapse/storage/schema/delta/25/history_visibility.sql (limited to 'synapse/storage') diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index 6873a4575d..a9e43052b7 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -47,7 +47,8 @@ class BaseHandler(object): self.event_builder_factory = hs.get_event_builder_factory() @defer.inlineCallbacks - def _filter_events_for_client(self, user_id, events, is_guest=False): + def _filter_events_for_client(self, user_id, events, is_guest=False, + require_all_visible_for_guests=True): # Assumes that user has at some point joined the room if not is_guest. def allowed(event, membership, visibility): @@ -100,7 +101,9 @@ class BaseHandler(object): if should_include: events_to_return.append(event) - if is_guest and len(events_to_return) < len(events): + if (require_all_visible_for_guests + and is_guest + and len(events_to_return) < len(events)): # This indicates that some events in the requested range were not # visible to guest users. To be safe, we reject the entire request, # so that we don't have to worry about interpreting visibility diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index 53c8ca3a26..0e4c0d4d06 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -100,7 +100,7 @@ class EventStreamHandler(BaseHandler): @log_function def get_stream(self, auth_user_id, pagin_config, timeout=0, as_client_event=True, affect_presence=True, - only_room_events=False): + only_room_events=False, room_id=None, is_guest=False): """Fetches the events stream for a given user. If `only_room_events` is `True` only room events will be returned. @@ -119,9 +119,15 @@ class EventStreamHandler(BaseHandler): # thundering herds on restart. 
timeout = random.randint(int(timeout*0.9), int(timeout*1.1)) + if is_guest: + yield self.distributor.fire( + "user_joined_room", user=auth_user, room_id=room_id + ) + events, tokens = yield self.notifier.get_events_for( auth_user, pagin_config, timeout, - only_room_events=only_room_events + only_room_events=only_room_events, + is_guest=is_guest, guest_room_id=room_id ) time_now = self.clock.time_msec() diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 687e1527f7..654ecd2b37 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -16,7 +16,7 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, Membership -from synapse.api.errors import SynapseError +from synapse.api.errors import SynapseError, AuthError, Codes from synapse.streams.config import PaginationConfig from synapse.events.utils import serialize_event from synapse.events.validator import EventValidator @@ -229,7 +229,7 @@ class MessageHandler(BaseHandler): @defer.inlineCallbacks def get_room_data(self, user_id=None, room_id=None, - event_type=None, state_key=""): + event_type=None, state_key="", is_guest=False): """ Get data from a room. Args: @@ -239,23 +239,42 @@ class MessageHandler(BaseHandler): Raises: SynapseError if something went wrong. """ - member_event = yield self.auth.check_user_was_in_room(room_id, user_id) + membership, membership_event_id = yield self._check_in_room_or_world_readable( + room_id, user_id, is_guest + ) - if member_event.membership == Membership.JOIN: + if membership == Membership.JOIN: data = yield self.state_handler.get_current_state( room_id, event_type, state_key ) - elif member_event.membership == Membership.LEAVE: + elif membership == Membership.LEAVE: key = (event_type, state_key) room_state = yield self.store.get_state_for_events( - [member_event.event_id], [key] + [membership_event_id], [key] ) - data = room_state[member_event.event_id].get(key) + data = room_state[membership_event_id].get(key) defer.returnValue(data) @defer.inlineCallbacks - def get_state_events(self, user_id, room_id): + def _check_in_room_or_world_readable(self, room_id, user_id, is_guest): + if is_guest: + visibility = yield self.state_handler.get_current_state( + room_id, EventTypes.RoomHistoryVisibility, "" + ) + if visibility.content["history_visibility"] == "world_readable": + defer.returnValue((Membership.JOIN, None)) + return + else: + raise AuthError( + 403, "Guest access not allowed", errcode=Codes.GUEST_ACCESS_FORBIDDEN + ) + else: + member_event = yield self.auth.check_user_was_in_room(room_id, user_id) + defer.returnValue((member_event.membership, member_event.event_id)) + + @defer.inlineCallbacks + def get_state_events(self, user_id, room_id, is_guest=False): """Retrieve all state events for a given room. If the user is joined to the room then return the current state. If the user has left the room return the state events from when they left. @@ -266,15 +285,17 @@ class MessageHandler(BaseHandler): Returns: A list of dicts representing state events. 
[{}, {}, {}] """ - member_event = yield self.auth.check_user_was_in_room(room_id, user_id) + membership, membership_event_id = yield self._check_in_room_or_world_readable( + room_id, user_id, is_guest + ) - if member_event.membership == Membership.JOIN: + if membership == Membership.JOIN: room_state = yield self.state_handler.get_current_state(room_id) - elif member_event.membership == Membership.LEAVE: + elif membership == Membership.LEAVE: room_state = yield self.store.get_state_for_events( - [member_event.event_id], None + [membership_event_id], None ) - room_state = room_state[member_event.event_id] + room_state = room_state[membership_event_id] now = self.clock.time_msec() defer.returnValue( diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index ce60642127..0b780cd528 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -1142,8 +1142,9 @@ class PresenceEventSource(object): @defer.inlineCallbacks @log_function - def get_new_events_for_user(self, user, from_key, limit): + def get_new_events(self, user, from_key, room_ids=None, **kwargs): from_key = int(from_key) + room_ids = room_ids or [] presence = self.hs.get_handlers().presence_handler cachemap = presence._user_cachemap @@ -1161,7 +1162,6 @@ class PresenceEventSource(object): user_ids_to_check |= set( UserID.from_string(p["observed_user_id"]) for p in presence_list ) - room_ids = yield presence.get_joined_rooms_for_user(user) for room_id in set(room_ids) & set(presence._room_serials): if presence._room_serials[room_id] > from_key: joined = yield presence.get_joined_users_for_room_id(room_id) diff --git a/synapse/handlers/private_user_data.py b/synapse/handlers/private_user_data.py index 1778c71325..1abe45ed7b 100644 --- a/synapse/handlers/private_user_data.py +++ b/synapse/handlers/private_user_data.py @@ -24,7 +24,7 @@ class PrivateUserDataEventSource(object): return self.store.get_max_private_user_data_stream_id() @defer.inlineCallbacks - def get_new_events_for_user(self, user, from_key, limit): + def get_new_events(self, user, from_key, **kwargs): user_id = user.to_string() last_stream_id = from_key diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index a47ae3df42..973f4d5cae 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -164,17 +164,15 @@ class ReceiptEventSource(object): self.store = hs.get_datastore() @defer.inlineCallbacks - def get_new_events_for_user(self, user, from_key, limit): + def get_new_events(self, from_key, room_ids, **kwargs): from_key = int(from_key) to_key = yield self.get_current_key() if from_key == to_key: defer.returnValue(([], to_key)) - rooms = yield self.store.get_rooms_for_user(user.to_string()) - rooms = [room.room_id for room in rooms] events = yield self.store.get_linearized_receipts_for_rooms( - rooms, + room_ids, from_key=from_key, to_key=to_key, ) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 9184dcd048..736ffe9066 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -807,7 +807,14 @@ class RoomEventSource(object): self.store = hs.get_datastore() @defer.inlineCallbacks - def get_new_events_for_user(self, user, from_key, limit): + def get_new_events( + self, + user, + from_key, + limit, + room_ids, + is_guest, + ): # We just ignore the key for now. 
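# Editor's note: an illustrative aside, not part of the patch. Across this
# commit each event source moves from get_new_events_for_user(user, from_key,
# limit) to a keyword-based get_new_events(), so the notifier can pass room
# scoping and guest flags to every source uniformly. A hypothetical sketch of
# the common shape the sources converge on (class name invented):
class ExampleEventSource(object):
    def get_new_events(self, user, from_key, limit=None, room_ids=None,
                       is_guest=False, **kwargs):
        """Return a deferred that fires with (events, new_key)."""
        raise NotImplementedError()

    def get_current_key(self):
        raise NotImplementedError()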
to_key = yield self.get_current_key() @@ -828,6 +835,8 @@ class RoomEventSource(object): from_key=from_key, to_key=to_key, limit=limit, + room_ids=room_ids, + is_guest=is_guest, ) defer.returnValue((events, end_key)) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 1c1ee34b1e..5294d96466 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -295,11 +295,16 @@ class SyncHandler(BaseHandler): typing_key = since_token.typing_key if since_token else "0" + rooms = yield self.store.get_rooms_for_user(sync_config.user.to_string()) + room_ids = [room.room_id for room in rooms] + typing_source = self.event_sources.sources["typing"] - typing, typing_key = yield typing_source.get_new_events_for_user( + typing, typing_key = yield typing_source.get_new_events( user=sync_config.user, from_key=typing_key, limit=sync_config.filter.ephemeral_limit(), + room_ids=room_ids, + is_guest=False, ) now_token = now_token.copy_and_replace("typing_key", typing_key) @@ -312,10 +317,13 @@ class SyncHandler(BaseHandler): receipt_key = since_token.receipt_key if since_token else "0" receipt_source = self.event_sources.sources["receipt"] - receipts, receipt_key = yield receipt_source.get_new_events_for_user( + receipts, receipt_key = yield receipt_source.get_new_events( user=sync_config.user, from_key=receipt_key, limit=sync_config.filter.ephemeral_limit(), + room_ids=room_ids, + # /sync doesn't support guest access, they can't get to this point in code + is_guest=False, ) now_token = now_token.copy_and_replace("receipt_key", receipt_key) @@ -360,11 +368,17 @@ class SyncHandler(BaseHandler): """ now_token = yield self.event_sources.get_current_token() + rooms = yield self.store.get_rooms_for_user(sync_config.user.to_string()) + room_ids = [room.room_id for room in rooms] + presence_source = self.event_sources.sources["presence"] - presence, presence_key = yield presence_source.get_new_events_for_user( + presence, presence_key = yield presence_source.get_new_events( user=sync_config.user, from_key=since_token.presence_key, limit=sync_config.filter.presence_limit(), + room_ids=room_ids, + # /sync doesn't support guest access, they can't get to this point in code + is_guest=False, ) now_token = now_token.copy_and_replace("presence_key", presence_key) diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index d7096aab8c..2846f3e6e8 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -246,17 +246,12 @@ class TypingNotificationEventSource(object): }, } - @defer.inlineCallbacks - def get_new_events_for_user(self, user, from_key, limit): + def get_new_events(self, from_key, room_ids, **kwargs): from_key = int(from_key) handler = self.handler() - joined_room_ids = ( - yield self.room_member_handler().get_joined_rooms_for_user(user) - ) - events = [] - for room_id in joined_room_ids: + for room_id in room_ids: if room_id not in handler._room_serials: continue if handler._room_serials[room_id] <= from_key: @@ -264,7 +259,7 @@ class TypingNotificationEventSource(object): events.append(self._make_event_for(room_id)) - defer.returnValue((events, handler._latest_room_serial)) + return events, handler._latest_room_serial def get_current_key(self): return self.handler()._latest_room_serial diff --git a/synapse/notifier.py b/synapse/notifier.py index b69da63d43..56c4c863b5 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -269,7 +269,7 @@ class Notifier(object): logger.exception("Failed to notify listener") @defer.inlineCallbacks - def 
wait_for_events(self, user, timeout, callback, + def wait_for_events(self, user, timeout, callback, room_ids=None, from_token=StreamToken("s0", "0", "0", "0", "0")): """Wait until the callback returns a non empty response or the timeout fires. @@ -279,11 +279,12 @@ class Notifier(object): if user_stream is None: appservice = yield self.store.get_app_service_by_user_id(user) current_token = yield self.event_sources.get_current_token() - rooms = yield self.store.get_rooms_for_user(user) - rooms = [room.room_id for room in rooms] + if room_ids is None: + rooms = yield self.store.get_rooms_for_user(user) + room_ids = [room.room_id for room in rooms] user_stream = _NotifierUserStream( user=user, - rooms=rooms, + rooms=room_ids, appservice=appservice, current_token=current_token, time_now_ms=self.clock.time_msec(), @@ -329,7 +330,8 @@ class Notifier(object): @defer.inlineCallbacks def get_events_for(self, user, pagination_config, timeout, - only_room_events=False): + only_room_events=False, + is_guest=False, guest_room_id=None): """ For the given user and rooms, return any new events for them. If there are no new events wait for up to `timeout` milliseconds for any new events to happen before returning. @@ -342,6 +344,16 @@ class Notifier(object): limit = pagination_config.limit + room_ids = [] + if is_guest: + # TODO(daniel): Deal with non-room events too + only_room_events = True + if guest_room_id: + room_ids = [guest_room_id] + else: + rooms = yield self.store.get_rooms_for_user(user.to_string()) + room_ids = [room.room_id for room in rooms] + @defer.inlineCallbacks def check_for_updates(before_token, after_token): if not after_token.is_after(before_token): @@ -357,9 +369,23 @@ class Notifier(object): continue if only_room_events and name != "room": continue - new_events, new_key = yield source.get_new_events_for_user( - user, getattr(from_token, keyname), limit, + new_events, new_key = yield source.get_new_events( + user=user, + from_key=getattr(from_token, keyname), + limit=limit, + is_guest=is_guest, + room_ids=room_ids, ) + + if is_guest: + room_member_handler = self.hs.get_handlers().room_member_handler + new_events = yield room_member_handler._filter_events_for_client( + user.to_string(), + new_events, + is_guest=is_guest, + require_all_visible_for_guests=False + ) + events.extend(new_events) end_token = end_token.copy_and_replace(keyname, new_key) @@ -369,7 +395,7 @@ class Notifier(object): defer.returnValue(None) result = yield self.wait_for_events( - user, timeout, check_for_updates, from_token=from_token + user, timeout, check_for_updates, room_ids=room_ids, from_token=from_token ) if result is None: diff --git a/synapse/rest/client/v1/events.py b/synapse/rest/client/v1/events.py index 4073b0d2d1..3e1750d1a1 100644 --- a/synapse/rest/client/v1/events.py +++ b/synapse/rest/client/v1/events.py @@ -34,7 +34,15 @@ class EventStreamRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request): - auth_user, _, _ = yield self.auth.get_user_by_req(request) + auth_user, _, is_guest = yield self.auth.get_user_by_req( + request, + allow_guest=True + ) + room_id = None + if is_guest: + if "room_id" not in request.args: + raise SynapseError(400, "Guest users must specify room_id param") + room_id = request.args["room_id"][0] try: handler = self.handlers.event_stream_handler pagin_config = PaginationConfig.from_request(request) @@ -49,7 +57,8 @@ class EventStreamRestServlet(ClientV1RestServlet): chunk = yield handler.get_stream( auth_user.to_string(), pagin_config, 
timeout=timeout, - as_client_event=as_client_event + as_client_event=as_client_event, affect_presence=(not is_guest), + room_id=room_id, is_guest=is_guest ) except: logger.exception("Event stream failed") diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 0876e593c5..afb802baec 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -125,7 +125,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request, room_id, event_type, state_key): - user, _, _ = yield self.auth.get_user_by_req(request) + user, _, is_guest = yield self.auth.get_user_by_req(request, allow_guest=True) msg_handler = self.handlers.message_handler data = yield msg_handler.get_room_data( @@ -133,6 +133,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet): room_id=room_id, event_type=event_type, state_key=state_key, + is_guest=is_guest, ) if not data: @@ -348,12 +349,13 @@ class RoomStateRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request, room_id): - user, _, _ = yield self.auth.get_user_by_req(request) + user, _, is_guest = yield self.auth.get_user_by_req(request, allow_guest=True) handler = self.handlers.message_handler # Get all the current state for this room events = yield handler.get_state_events( room_id=room_id, user_id=user.to_string(), + is_guest=is_guest, ) defer.returnValue((200, events)) diff --git a/synapse/storage/events.py b/synapse/storage/events.py index e6c1abfc27..59c9987202 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -311,6 +311,8 @@ class EventsStore(SQLBaseStore): self._store_room_message_txn(txn, event) elif event.type == EventTypes.Redaction: self._store_redaction(txn, event) + elif event.type == EventTypes.RoomHistoryVisibility: + self._store_history_visibility_txn(txn, event) self._store_room_members_txn( txn, diff --git a/synapse/storage/room.py b/synapse/storage/room.py index 13441fcdce..1c79626736 100644 --- a/synapse/storage/room.py +++ b/synapse/storage/room.py @@ -202,6 +202,19 @@ class RoomStore(SQLBaseStore): txn, event, "content.body", event.content["body"] ) + def _store_history_visibility_txn(self, txn, event): + if hasattr(event, "content") and "history_visibility" in event.content: + sql = ( + "INSERT INTO history_visibility" + " (event_id, room_id, history_visibility)" + " VALUES (?, ?, ?)" + ) + txn.execute(sql, ( + event.event_id, + event.room_id, + event.content["history_visibility"] + )) + def _store_event_search_txn(self, txn, event, key, value): if isinstance(self.database_engine, PostgresEngine): sql = ( diff --git a/synapse/storage/schema/delta/25/history_visibility.sql b/synapse/storage/schema/delta/25/history_visibility.sql new file mode 100644 index 0000000000..9f387ed69f --- /dev/null +++ b/synapse/storage/schema/delta/25/history_visibility.sql @@ -0,0 +1,26 @@ +/* Copyright 2015 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +/* + * This is a manual index of history_visibility content of state events, + * so that we can join on them in SELECT statements. + */ +CREATE TABLE IF NOT EXISTS history_visibility( + id INTEGER PRIMARY KEY, + event_id TEXT NOT NULL, + room_id TEXT NOT NULL, + history_visibility TEXT NOT NULL, + UNIQUE (event_id) +); diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index c728013f4c..be8ba76aae 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -158,13 +158,40 @@ class StreamStore(SQLBaseStore): defer.returnValue(results) @log_function - def get_room_events_stream(self, user_id, from_key, to_key, limit=0): - current_room_membership_sql = ( - "SELECT m.room_id FROM room_memberships as m " - " INNER JOIN current_state_events as c" - " ON m.event_id = c.event_id AND c.state_key = m.user_id" - " WHERE m.user_id = ? AND m.membership = 'join'" - ) + def get_room_events_stream( + self, + user_id, + from_key, + to_key, + limit=0, + is_guest=False, + room_ids=None + ): + room_ids = room_ids or [] + room_ids = [r for r in room_ids] + if is_guest: + current_room_membership_sql = ( + "SELECT c.room_id FROM history_visibility AS h" + " INNER JOIN current_state_events AS c" + " ON h.event_id = c.event_id" + " WHERE c.room_id IN (%s) AND h.history_visibility = 'world_readable'" % ( + ",".join(map(lambda _: "?", room_ids)) + ) + ) + current_room_membership_args = room_ids + else: + current_room_membership_sql = ( + "SELECT m.room_id FROM room_memberships as m " + " INNER JOIN current_state_events as c" + " ON m.event_id = c.event_id AND c.state_key = m.user_id" + " WHERE m.user_id = ? AND m.membership = 'join'" + ) + current_room_membership_args = [user_id] + if room_ids: + current_room_membership_sql += " AND m.room_id in (%s)" % ( + ",".join(map(lambda _: "?", room_ids)) + ) + current_room_membership_args = [user_id] + room_ids # We also want to get any membership events about that user, e.g. # invites or leave notifications. @@ -173,6 +200,7 @@ class StreamStore(SQLBaseStore): "INNER JOIN current_state_events as c ON m.event_id = c.event_id " "WHERE m.user_id = ? 
" ) + membership_args = [user_id] if limit: limit = max(limit, MAX_STREAM_SIZE) @@ -199,7 +227,9 @@ class StreamStore(SQLBaseStore): } def f(txn): - txn.execute(sql, (False, user_id, user_id, from_id.stream, to_id.stream,)) + args = ([False] + current_room_membership_args + membership_args + + [from_id.stream, to_id.stream]) + txn.execute(sql, args) rows = self.cursor_to_dict(txn) diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py index 29372d488a..10d4482cce 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py @@ -650,9 +650,30 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase): {"presence": ONLINE} ) + # Apple sees self-reflection even without room_id + (events, _) = yield self.event_source.get_new_events( + user=self.u_apple, + from_key=0, + ) + + self.assertEquals(self.event_source.get_current_key(), 1) + self.assertEquals(events, + [ + {"type": "m.presence", + "content": { + "user_id": "@apple:test", + "presence": ONLINE, + "last_active_ago": 0, + }}, + ], + msg="Presence event should be visible to self-reflection" + ) + # Apple sees self-reflection - (events, _) = yield self.event_source.get_new_events_for_user( - self.u_apple, 0, None + (events, _) = yield self.event_source.get_new_events( + user=self.u_apple, + from_key=0, + room_ids=[self.room_id], ) self.assertEquals(self.event_source.get_current_key(), 1) @@ -684,8 +705,10 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase): ) # Banana sees it because of presence subscription - (events, _) = yield self.event_source.get_new_events_for_user( - self.u_banana, 0, None + (events, _) = yield self.event_source.get_new_events( + user=self.u_banana, + from_key=0, + room_ids=[self.room_id], ) self.assertEquals(self.event_source.get_current_key(), 1) @@ -702,8 +725,10 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase): ) # Elderberry sees it because of same room - (events, _) = yield self.event_source.get_new_events_for_user( - self.u_elderberry, 0, None + (events, _) = yield self.event_source.get_new_events( + user=self.u_elderberry, + from_key=0, + room_ids=[self.room_id], ) self.assertEquals(self.event_source.get_current_key(), 1) @@ -720,8 +745,10 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase): ) # Durian is not in the room, should not see this event - (events, _) = yield self.event_source.get_new_events_for_user( - self.u_durian, 0, None + (events, _) = yield self.event_source.get_new_events( + user=self.u_durian, + from_key=0, + room_ids=[], ) self.assertEquals(self.event_source.get_current_key(), 1) @@ -767,8 +794,9 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase): "accepted": True}, ], presence) - (events, _) = yield self.event_source.get_new_events_for_user( - self.u_apple, 1, None + (events, _) = yield self.event_source.get_new_events( + user=self.u_apple, + from_key=1, ) self.assertEquals(self.event_source.get_current_key(), 2) @@ -858,8 +886,10 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase): ) ) - (events, _) = yield self.event_source.get_new_events_for_user( - self.u_apple, 0, None + (events, _) = yield self.event_source.get_new_events( + user=self.u_apple, + from_key=0, + room_ids=[self.room_id], ) self.assertEquals(self.event_source.get_current_key(), 1) @@ -905,8 +935,10 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase): self.assertEquals(self.event_source.get_current_key(), 1) - (events, _) = yield self.event_source.get_new_events_for_user( - self.u_apple, 0, None + 
(events, _) = yield self.event_source.get_new_events( + user=self.u_apple, + from_key=0, + room_ids=[self.room_id,] ) self.assertEquals(events, [ @@ -932,8 +964,10 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase): self.assertEquals(self.event_source.get_current_key(), 2) - (events, _) = yield self.event_source.get_new_events_for_user( - self.u_apple, 0, None + (events, _) = yield self.event_source.get_new_events( + user=self.u_apple, + from_key=0, + room_ids=[self.room_id,] ) self.assertEquals(events, [ @@ -966,8 +1000,9 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase): self.room_members.append(self.u_clementine) - (events, _) = yield self.event_source.get_new_events_for_user( - self.u_apple, 0, None + (events, _) = yield self.event_source.get_new_events( + user=self.u_apple, + from_key=0, ) self.assertEquals(self.event_source.get_current_key(), 1) diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 41bb08b7ca..2d7ba43561 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -187,7 +187,10 @@ class TypingNotificationsTestCase(unittest.TestCase): ]) self.assertEquals(self.event_source.get_current_key(), 1) - events = yield self.event_source.get_new_events_for_user(self.u_apple, 0, None) + events = yield self.event_source.get_new_events( + room_ids=[self.room_id], + from_key=0, + ) self.assertEquals( events[0], [ @@ -250,7 +253,10 @@ class TypingNotificationsTestCase(unittest.TestCase): ]) self.assertEquals(self.event_source.get_current_key(), 1) - events = yield self.event_source.get_new_events_for_user(self.u_apple, 0, None) + events = yield self.event_source.get_new_events( + room_ids=[self.room_id], + from_key=0 + ) self.assertEquals( events[0], [ @@ -306,7 +312,10 @@ class TypingNotificationsTestCase(unittest.TestCase): yield put_json.await_calls() self.assertEquals(self.event_source.get_current_key(), 1) - events = yield self.event_source.get_new_events_for_user(self.u_apple, 0, None) + events = yield self.event_source.get_new_events( + room_ids=[self.room_id], + from_key=0, + ) self.assertEquals( events[0], [ @@ -337,7 +346,10 @@ class TypingNotificationsTestCase(unittest.TestCase): self.on_new_event.reset_mock() self.assertEquals(self.event_source.get_current_key(), 1) - events = yield self.event_source.get_new_events_for_user(self.u_apple, 0, None) + events = yield self.event_source.get_new_events( + room_ids=[self.room_id], + from_key=0, + ) self.assertEquals( events[0], [ @@ -356,7 +368,10 @@ class TypingNotificationsTestCase(unittest.TestCase): ]) self.assertEquals(self.event_source.get_current_key(), 2) - events = yield self.event_source.get_new_events_for_user(self.u_apple, 1, None) + events = yield self.event_source.get_new_events( + room_ids=[self.room_id], + from_key=1, + ) self.assertEquals( events[0], [ @@ -383,7 +398,10 @@ class TypingNotificationsTestCase(unittest.TestCase): self.on_new_event.reset_mock() self.assertEquals(self.event_source.get_current_key(), 3) - events = yield self.event_source.get_new_events_for_user(self.u_apple, 0, None) + events = yield self.event_source.get_new_events( + room_ids=[self.room_id], + from_key=0, + ) self.assertEquals( events[0], [ diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py index 3e0f294630..7f29d73d95 100644 --- a/tests/rest/client/v1/test_presence.py +++ b/tests/rest/client/v1/test_presence.py @@ -47,7 +47,14 @@ class NullSource(object): def __init__(self, hs): pass - def get_new_events_for_user(self, 
user, from_key, limit): + def get_new_events( + self, + user, + from_key, + room_ids=None, + limit=None, + is_guest=None + ): return defer.succeed(([], from_key)) def get_current_key(self, direction='f'): diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py index 8433585616..61b9cc743b 100644 --- a/tests/rest/client/v1/test_typing.py +++ b/tests/rest/client/v1/test_typing.py @@ -116,7 +116,10 @@ class RoomTypingTestCase(RestTestCase): self.assertEquals(200, code) self.assertEquals(self.event_source.get_current_key(), 1) - events = yield self.event_source.get_new_events_for_user(self.user, 0, None) + events = yield self.event_source.get_new_events( + from_key=0, + room_ids=[self.room_id], + ) self.assertEquals( events[0], [ -- cgit 1.5.1 From 7301e05122e07f6513916e8a35bf05581de6521d Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 5 Nov 2015 14:34:37 +0000 Subject: Implement basic pagination for search results --- synapse/handlers/search.py | 78 +++++++++++++++++++++++++++++++++++------- synapse/rest/client/v1/room.py | 3 +- synapse/storage/search.py | 55 ++++++++++------------------- 3 files changed, 86 insertions(+), 50 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index 28f5300dc9..696780f34e 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -22,6 +22,8 @@ from synapse.api.filtering import Filter from synapse.api.errors import SynapseError from synapse.events.utils import serialize_event +from unpaddedbase64 import decode_base64, encode_base64 + import logging @@ -34,17 +36,32 @@ class SearchHandler(BaseHandler): super(SearchHandler, self).__init__(hs) @defer.inlineCallbacks - def search(self, user, content): + def search(self, user, content, batch=None): """Performs a full text search for a user. Args: user (UserID) content (dict): Search parameters + batch (str): The next_batch parameter. Used for pagination. 
Returns: dict to be returned to the client with results of search """ + batch_group = None + batch_group_key = None + batch_token = None + if batch: + try: + b = decode_base64(batch) + batch_group, batch_group_key, batch_token = b.split("\n") + + assert batch_group is not None + assert batch_group_key is not None + assert batch_token is not None + except: + raise SynapseError(400, "Invalid batch") + try: room_cat = content["search_categories"]["room_events"] search_term = room_cat["search_term"] @@ -91,17 +108,25 @@ class SearchHandler(BaseHandler): room_ids = search_filter.filter_rooms(room_ids) + if batch_group == "room_id": + room_ids = room_ids & {batch_group_key} + rank_map = {} allowed_events = [] room_groups = {} sender_group = {} + global_next_batch = None if order_by == "rank": - rank_map, event_map, _ = yield self.store.search_msgs( + results = yield self.store.search_msgs( room_ids, search_term, keys ) - filtered_events = search_filter.filter(event_map.values()) + results_map = {r["event"].event_id: r for r in results} + + rank_map.update({r["event"].event_id: r["rank"] for r in results}) + + filtered_events = search_filter.filter([r["event"] for r in results]) events = yield self._filter_events_for_client( user.to_string(), filtered_events @@ -126,18 +151,26 @@ class SearchHandler(BaseHandler): elif order_by == "recent": for room_id in room_ids: room_events = [] - pagination_token = None + if batch_group == "room_id" and batch_group_key == room_id: + pagination_token = batch_token + else: + pagination_token = None i = 0 while len(room_events) < search_filter.limit() and i < 5: i += 5 - r_map, event_map, pagination_token = yield self.store.search_room( + results = yield self.store.search_room( room_id, search_term, keys, search_filter.limit() * 2, pagination_token=pagination_token, ) - rank_map.update(r_map) - filtered_events = search_filter.filter(event_map.values()) + results_map = {r["event"].event_id: r for r in results} + + rank_map.update({r["event"].event_id: r["rank"] for r in results}) + + filtered_events = search_filter.filter([ + r["event"] for r in results + ]) events = yield self._filter_events_for_client( user.to_string(), filtered_events @@ -146,13 +179,26 @@ class SearchHandler(BaseHandler): room_events.extend(events) room_events = room_events[:search_filter.limit()] - if len(event_map) < search_filter.limit() * 2: + if len(results) < search_filter.limit() * 2: + pagination_token = None break + else: + pagination_token = results[-1]["pagination_token"] + + if room_events: + res = results_map[room_events[-1].event_id] + pagination_token = res["pagination_token"] if room_events: group = room_groups.setdefault(room_id, {}) if pagination_token: - group["next_batch"] = pagination_token + next_batch = encode_base64("%s\n%s\n%s" % ( + "room_id", room_id, pagination_token + )) + group["next_batch"] = next_batch + + if batch_token: + global_next_batch = next_batch group["results"] = [e.event_id for e in room_events] group["order"] = max( @@ -164,11 +210,14 @@ class SearchHandler(BaseHandler): # Normalize the group ranks if room_groups: - mx = max(g["order"] for g in room_groups.values()) - mn = min(g["order"] for g in room_groups.values()) + if len(room_groups) > 1: + mx = max(g["order"] for g in room_groups.values()) + mn = min(g["order"] for g in room_groups.values()) - for g in room_groups.values(): - g["order"] = (g["order"] - mn) * 1.0 / (mx - mn) + for g in room_groups.values(): + g["order"] = (g["order"] - mn) * 1.0 / (mx - mn) + else: + 
room_groups.values()[0]["order"] = 1 else: # We should never get here due to the guard earlier. @@ -239,6 +288,9 @@ class SearchHandler(BaseHandler): if sender_group and "sender" in group_keys: rooms_cat_res.setdefault("groups", {})["sender"] = sender_group + if global_next_batch: + rooms_cat_res["next_batch"] = global_next_batch + defer.returnValue({ "search_categories": { "room_events": rooms_cat_res diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 2dcaee86cd..8e28f12d29 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -601,7 +601,8 @@ class SearchRestServlet(ClientV1RestServlet): content = _parse_json(request) - results = yield self.handlers.search_handler.search(auth_user, content) + batch = request.args.get("next_batch", [None])[0] + results = yield self.handlers.search_handler.search(auth_user, content, batch) defer.returnValue((200, results)) diff --git a/synapse/storage/search.py b/synapse/storage/search.py index e37e56c1f2..7342e7bae6 100644 --- a/synapse/storage/search.py +++ b/synapse/storage/search.py @@ -18,24 +18,12 @@ from twisted.internet import defer from _base import SQLBaseStore from synapse.storage.engines import PostgresEngine, Sqlite3Engine -from collections import namedtuple - import logging logger = logging.getLogger(__name__) -"""The result of a search. - -Fields: - rank_map (dict): Mapping event_id -> rank - event_map (dict): Mapping event_id -> event - pagination_token (str): Pagination token -""" -SearchResult = namedtuple("SearchResult", ("rank_map", "event_map", "pagination_token")) - - class SearchStore(SQLBaseStore): @defer.inlineCallbacks def search_msgs(self, room_ids, search_term, keys): @@ -48,7 +36,7 @@ class SearchStore(SQLBaseStore): "content.body", "content.name", "content.topic" Returns: - SearchResult + list of dicts """ clauses = [] args = [] @@ -106,15 +94,14 @@ class SearchStore(SQLBaseStore): for ev in events } - defer.returnValue(SearchResult( + defer.returnValue([ { - r["event_id"]: r["rank"] - for r in results - if r["event_id"] in event_map - }, - event_map, - None - )) + "event": event_map[r["event_id"]], + "rank": r["rank"], + } + for r in results + if r["event_id"] in event_map + ]) @defer.inlineCallbacks def search_room(self, room_id, search_term, keys, limit, pagination_token=None): @@ -128,7 +115,7 @@ class SearchStore(SQLBaseStore): pagination_token (str): A pagination token previously returned Returns: - SearchResult + list of dicts """ clauses = [] args = [search_term, room_id] @@ -190,18 +177,14 @@ class SearchStore(SQLBaseStore): for ev in events } - pagination_token = None - if results: - topo = results[-1]["topological_ordering"] - stream = results[-1]["stream_ordering"] - pagination_token = "%s,%s" % (topo, stream) - - defer.returnValue(SearchResult( + defer.returnValue([ { - r["event_id"]: r["rank"] - for r in results - if r["event_id"] in event_map - }, - event_map, - pagination_token - )) + "event": event_map[r["event_id"]], + "rank": r["rank"], + "pagination_token": "%s,%s" % ( + r["topological_ordering"], r["stream_ordering"] + ), + } + for r in results + if r["event_id"] in event_map + ]) -- cgit 1.5.1 From 3640ddfbf6641a79b486d3847ae739c974c62c7a Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 5 Nov 2015 16:10:54 +0000 Subject: Error handling --- synapse/storage/search.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) (limited to 'synapse/storage') diff --git a/synapse/storage/search.py b/synapse/storage/search.py index 
7342e7bae6..3cea2011fa 100644 --- a/synapse/storage/search.py +++ b/synapse/storage/search.py @@ -16,6 +16,7 @@ from twisted.internet import defer from _base import SQLBaseStore +from synapse.api.errors import SynapseError from synapse.storage.engines import PostgresEngine, Sqlite3Engine import logging @@ -130,7 +131,13 @@ class SearchStore(SQLBaseStore): ) if pagination_token: - topo, stream = pagination_token.split(",") + try: + topo, stream = pagination_token.split(",") + topo = int(topo) + stream = int(stream) + except: + raise SynapseError(400, "Invalid pagination token") + clauses.append( "(topological_ordering < ?" " OR (topological_ordering = ? AND stream_ordering < ?))" -- cgit 1.5.1 From f2c4ee41b91f1828d516df871adcfbaed46d5407 Mon Sep 17 00:00:00 2001 From: Daniel Wagner-Hall Date: Fri, 6 Nov 2015 14:27:49 +0000 Subject: Remove accidentally added ID column --- synapse/storage/schema/delta/25/history_visibility.sql | 1 - 1 file changed, 1 deletion(-) (limited to 'synapse/storage') diff --git a/synapse/storage/schema/delta/25/history_visibility.sql b/synapse/storage/schema/delta/25/history_visibility.sql index 9f387ed69f..532cb05151 100644 --- a/synapse/storage/schema/delta/25/history_visibility.sql +++ b/synapse/storage/schema/delta/25/history_visibility.sql @@ -18,7 +18,6 @@ * so that we can join on them in SELECT statements. */ CREATE TABLE IF NOT EXISTS history_visibility( - id INTEGER PRIMARY KEY, event_id TEXT NOT NULL, room_id TEXT NOT NULL, history_visibility TEXT NOT NULL, -- cgit 1.5.1 From 767c20a869933204b6e545c6f1495cd7cd298a87 Mon Sep 17 00:00:00 2001 From: Matthew Hodgson Date: Fri, 6 Nov 2015 20:49:57 +0100 Subject: add a key existence check to tags_by_room to avoid /events 500'ing when testing against vector --- synapse/storage/tags.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'synapse/storage') diff --git a/synapse/storage/tags.py b/synapse/storage/tags.py index 641ea250f0..73babd53d9 100644 --- a/synapse/storage/tags.py +++ b/synapse/storage/tags.py @@ -95,7 +95,8 @@ class TagsStore(SQLBaseStore): if room_ids: tags_by_room = yield self.get_tags_for_user(user_id) for room_id in room_ids: - results[room_id] = tags_by_room[room_id] + if room_id in tags_by_room: + results[room_id] = tags_by_room[room_id] defer.returnValue(results) -- cgit 1.5.1 From dd40fb68e455b9a736f1871a19119eb0391cd0d5 Mon Sep 17 00:00:00 2001 From: Matthew Hodgson Date: Sun, 8 Nov 2015 16:04:37 +0000 Subject: fix comedy important missing comma breaking recent-ordered FTS on sqlite --- synapse/storage/search.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'synapse/storage') diff --git a/synapse/storage/search.py b/synapse/storage/search.py index 3cea2011fa..91a96d2a83 100644 --- a/synapse/storage/search.py +++ b/synapse/storage/search.py @@ -154,7 +154,7 @@ class SearchStore(SQLBaseStore): ) elif isinstance(self.database_engine, Sqlite3Engine): sql = ( - "SELECT rank(matchinfo(event_search)) as rank, room_id, event_id" + "SELECT rank(matchinfo(event_search)) as rank, room_id, event_id," " topological_ordering, stream_ordering" " FROM event_search" " NATURAL JOIN events" -- cgit 1.5.1 From c4135d85e116cecbc119c1243911a0b60a6452d7 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Mon, 9 Nov 2015 14:52:18 +0000 Subject: SYN-513: Include updates for rooms that have had all their tags deleted --- synapse/handlers/sync.py | 2 +- synapse/storage/tags.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) (limited to 'synapse/storage') diff --git 
a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 5294d96466..ff766e3af5 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -272,7 +272,7 @@ class SyncHandler(BaseHandler):
     def private_user_data_for_room(self, room_id, tags_by_room):
         private_user_data = []
         tags = tags_by_room.get(room_id)
-        if tags:
+        if tags is not None:
             private_user_data.append({
                 "type": "m.tag",
                 "content": {"tags": tags},
diff --git a/synapse/storage/tags.py b/synapse/storage/tags.py
index 641ea250f0..bf695b7800 100644
--- a/synapse/storage/tags.py
+++ b/synapse/storage/tags.py
@@ -95,7 +95,7 @@ class TagsStore(SQLBaseStore):
         if room_ids:
             tags_by_room = yield self.get_tags_for_user(user_id)
             for room_id in room_ids:
-                results[room_id] = tags_by_room[room_id]
+                results[room_id] = tags_by_room.get(room_id, {})
 
         defer.returnValue(results)
 
--
cgit 1.5.1


From c6a01f2ed0da9f0db169307e7d1649021f0c2123 Mon Sep 17 00:00:00 2001
From: Mark Haines
Date: Mon, 9 Nov 2015 14:37:28 +0000
Subject: Add storage module for tracking background updates.

The progress for each background update is stored as a JSON blob in the
database. Each background update is broken up into separate batches. The
batch size is automatically tuned to try to avoid blocking single-threaded
databases for too long.
---
 synapse/storage/background_updates.py              | 210 +++++++++++++++++++++
 .../schema/delta/25/00background_updates.sql       |  21 +++
 2 files changed, 231 insertions(+)
 create mode 100644 synapse/storage/background_updates.py
 create mode 100644 synapse/storage/schema/delta/25/00background_updates.sql

(limited to 'synapse/storage')

diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
new file mode 100644
index 0000000000..32a233c213
--- /dev/null
+++ b/synapse/storage/background_updates.py
@@ -0,0 +1,210 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014, 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ._base import SQLBaseStore
+
+from twisted.internet import defer
+
+import ujson as json
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class BackgroundUpdatePerformance(object):
+    """Tracks how long a background update is taking to update its items"""
+
+    def __init__(self, name):
+        self.name = name
+        self.total_item_count = 0
+        self.total_duration_ms = 0
+        self.avg_item_count = 0
+        self.avg_duration_ms = 0
+
+    def update(self, item_count, duration_ms):
+        """Update the stats after doing an update"""
+        self.total_item_count += item_count
+        self.total_duration_ms += duration_ms
+
+        # Exponential moving averages for the number of items updated and
+        # the duration.
+        self.avg_item_count += 0.1 * (item_count - self.avg_item_count)
+        self.avg_duration_ms += 0.1 * (duration_ms - self.avg_duration_ms)
+
+    def average_duration_ms_per_item(self):
+        """An estimate of how long it takes to do a single update.
+        Returns:
+            A duration in ms as a float
+        """
+        if self.total_item_count == 0:
+            return None
+        else:
+            # Use the exponential moving average so that we can adapt to
+            # changes in how long the update process takes.
+            return float(self.avg_duration_ms) / float(self.avg_item_count)
+
+
+class BackgroundUpdateStore(SQLBaseStore):
+    """ Background updates are updates to the database that run in the
+    background. Each update processes a batch of data at once. We attempt to
+    limit the impact of each update by monitoring how long each batch takes to
+    process and autotuning the batch size.
+    """
+
+    MINIMUM_BACKGROUND_BATCH_SIZE = 100
+    DEFAULT_BACKGROUND_BATCH_SIZE = 100
+
+    def __init__(self, hs):
+        super(BackgroundUpdateStore, self).__init__(hs)
+        self._background_update_performance = {}
+        self._background_update_queue = []
+        self._background_update_handlers = {}
+
+    @defer.inlineCallbacks
+    def do_background_update(self, desired_duration_ms):
+        """Does some amount of work on a background update
+        Args:
+            desired_duration_ms(float): How long we want to spend
+                updating.
+        Returns:
+            A deferred that completes once some amount of work is done.
+            The deferred will have a value of None if there is currently
+            no more work to do.
+        """
+        if not self._background_update_queue:
+            updates = yield self._simple_select_list(
+                "background_updates",
+                keyvalues=None,
+                retcols=("update_name",),
+            )
+            for update in updates:
+                self._background_update_queue.append(update['update_name'])
+
+        if not self._background_update_queue:
+            defer.returnValue(None)
+
+        update_name = self._background_update_queue.pop(0)
+        self._background_update_queue.append(update_name)
+
+        update_handler = self._background_update_handlers[update_name]
+
+        performance = self._background_update_performance.get(update_name)
+
+        if performance is None:
+            performance = BackgroundUpdatePerformance(update_name)
+            self._background_update_performance[update_name] = performance
+
+        duration_ms_per_item = performance.average_duration_ms_per_item()
+
+        if duration_ms_per_item is not None:
+            batch_size = int(desired_duration_ms / duration_ms_per_item)
+            # Clamp the batch size so that we always make progress
+            batch_size = max(batch_size, self.MINIMUM_BACKGROUND_BATCH_SIZE)
+        else:
+            batch_size = self.DEFAULT_BACKGROUND_BATCH_SIZE
+
+        progress_json = yield self._simple_select_one_onecol(
+            "background_updates",
+            keyvalues={"update_name": update_name},
+            retcol="progress_json"
+        )
+
+        progress = json.loads(progress_json)
+
+        time_start = self._clock.time_msec()
+        items_updated = yield update_handler(progress, batch_size)
+        time_stop = self._clock.time_msec()
+
+        duration_ms = time_stop - time_start
+
+        logger.info(
+            "Updating %s. Updated %r items in %rms",
+            update_name, items_updated, duration_ms
+        )
+
+        performance.update(items_updated, duration_ms)
+
+        defer.returnValue(len(self._background_update_performance))
+
+    def register_background_update_handler(self, update_name, update_handler):
+        """Register a handler for doing a background update.
+
+        The handler should take two arguments:
+
+        * A dict of the current progress
+        * An integer count of the number of items to update in this batch.
+
+        The handler should return a deferred integer count of items updated.
+        The handler is responsible for updating the progress of the update.
+
+        Args:
+            update_name(str): The name of the update that this code handles.
+            update_handler(function): The function that does the update.
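# Editor's note: an illustrative aside, not part of the patch. A hypothetical
# handler showing the contract documented above: it is registered under a
# name, receives the stored progress dict plus a tuned batch size, persists
# its own progress, and ends the update when it runs out of work. All names
# here ("example_update", "processed", _do_example_update) are invented.
# Registration would look like:
#     self.register_background_update_handler(
#         "example_update", self._do_example_update
#     )
from twisted.internet import defer

@defer.inlineCallbacks
def _do_example_update(self, progress, batch_size):
    processed = progress.get("processed", 0)

    def do_batch_txn(txn):
        # ... update up to batch_size items starting from `processed`,
        # then record the new high-water mark for the next batch ...
        self._background_update_progress_txn(
            txn, "example_update", {"processed": processed + batch_size}
        )
        return batch_size  # number of items actually updated this batch

    num_updated = yield self.runInteraction("example_update", do_batch_txn)
    if num_updated == 0:
        # Nothing left to do: remove the update from the queue and table.
        yield self._end_background_update("example_update")
    defer.returnValue(num_updated)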
+ """ + self._background_update_handlers[update_name] = update_handler + + def start_background_update(self, update_name, progress): + """Starts a background update running. + + Args: + update_name: The update to set running. + progress: The initial state of the progress of the update. + + Returns: + A deferred that completes once the task has been added to the + queue. + """ + # Clear the background update queue so that we will pick up the new + # task on the next iteration of do_background_update. + self._background_update_queue = [] + progress_json = json.dumps(progress) + + return self._simple_insert( + "background_updates", + {"update_name": update_name, "progress_json": progress_json} + ) + + def _end_background_update(self, update_name): + """Removes a completed background update task from the queue. + + Args: + update_name(str): The name of the completed task to remove + Returns: + A deferred that completes once the task is removed. + """ + self._background_update_queue = [ + name for name in self._background_update_queue if name != update_name + ] + return self._simple_delete_one( + "background_updates", keyvalues={"update_name": update_name} + ) + + def _background_update_progress_txn(self, txn, update_name, progress): + """Update the progress of a background update + + Args: + txn(cursor): The transaction. + update_name(str): The name of the background update task + progress(dict): The progress of the update. + """ + + progress_json = json.dumps(progress) + + self._simple_update_one_txn( + txn, + "background_updates", + keyvalues={"update_name": update_name}, + updatevalues={"progress_json": progress_json}, + ) diff --git a/synapse/storage/schema/delta/25/00background_updates.sql b/synapse/storage/schema/delta/25/00background_updates.sql new file mode 100644 index 0000000000..41a9b59b1b --- /dev/null +++ b/synapse/storage/schema/delta/25/00background_updates.sql @@ -0,0 +1,21 @@ +/* Copyright 2015 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +CREATE TABLE IF NOT EXISTS background_updates( + update_name TEXT NOT NULL, -- The name of the background update. + progress_json TEXT NOT NULL, -- The current progress of the update as JSON. 
+    CONSTRAINT background_updates_uniqueness UNIQUE (update_name)
+);
--
cgit 1.5.1


From 2ede7aa8a1da20b735797cc9230da7bbeb55efc3 Mon Sep 17 00:00:00 2001
From: Mark Haines
Date: Mon, 9 Nov 2015 19:29:32 +0000
Subject: Add background update task for reindexing event search

---
 synapse/storage/search.py | 98 ++++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 96 insertions(+), 2 deletions(-)

(limited to 'synapse/storage')

diff --git a/synapse/storage/search.py b/synapse/storage/search.py
index 3cea2011fa..f7c269865d 100644
--- a/synapse/storage/search.py
+++ b/synapse/storage/search.py
@@ -15,7 +15,7 @@
 
 from twisted.internet import defer
 
-from _base import SQLBaseStore
+from .background_updates import BackgroundUpdateStore
 from synapse.api.errors import SynapseError
 from synapse.storage.engines import PostgresEngine, Sqlite3Engine
 
@@ -25,7 +25,101 @@ import logging
 
 logger = logging.getLogger(__name__)
 
 
-class SearchStore(SQLBaseStore):
+class SearchStore(BackgroundUpdateStore):
+
+    EVENT_SEARCH_UPDATE_NAME = "event_search"
+
+
+    @defer.inlineCallbacks
+    def _background_reindex_search(self, progress, batch_size):
+        target_min_stream_id = progress["target_min_stream_id"]
+        max_stream_id = progress["max_stream_id"]
+        rows_inserted = progress.get("rows_inserted", 0)
+
+        INSERT_CLUMP_SIZE = 1000
+        TYPES = ["m.room.name", "m.room.message", "m.room.topic"]
+
+        def reindex_search_txn(txn):
+            sql = (
+                "SELECT stream_ordering, event_id FROM events"
+                " WHERE ? <= stream_ordering AND stream_ordering < ?"
+                " AND (%s)"
+                " ORDER BY stream_ordering DESC"
+                " LIMIT ?"
+            ) % (" OR ".join("type = '%s'" % (t,) for t in TYPES),)
+
+            txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size))
+
+            rows = txn.fetchall()
+            if not rows:
+                return None
+
+            min_stream_id = rows[-1][0]
+            event_ids = [row[1] for row in rows]
+
+            events = self._get_events_txn(txn, event_ids)
+
+            event_search_rows = []
+            for event in events:
+                try:
+                    event_id = event.event_id
+                    room_id = event.room_id
+                    content = event.content
+                    if event.type == "m.room.message":
+                        key = "content.body"
+                        value = content["body"]
+                    elif event.type == "m.room.topic":
+                        key = "content.topic"
+                        value = content["topic"]
+                    elif event.type == "m.room.name":
+                        key = "content.name"
+                        value = content["name"]
+                except Exception:
+                    # If the event is missing a necessary field then
+                    # skip over it.
+                    continue
+
+                event_search_rows.append((event_id, room_id, key, value))
+
+            if isinstance(self.database_engine, PostgresEngine):
+                sql = (
+                    "INSERT INTO event_search (event_id, room_id, key, vector)"
+                    " VALUES (?,?,?,to_tsvector('english', ?))"
+                )
+            elif isinstance(self.database_engine, Sqlite3Engine):
+                sql = (
+                    "INSERT INTO event_search (event_id, room_id, key, value)"
+                    " VALUES (?,?,?,?)"
+                )
+            else:
+                # This should be unreachable.
+ raise Exception("Unrecognized database engine") + + for index in range(0, len(event_search_rows), INSERT_CLUMP_SIZE): + clump = event_search_rows[index : index + INSERT_CLUMP_SIZE) + txn.execute_many(sql, clump) + + progress = { + "target_max_stream_id": target_min_stream_id, + "max_stream_id": min_stream_id, + "rows_inserted": rows_inserted + len(event_search_rows) + } + + self._background_update_progress_txn( + txn, self.EVENT_SEARCH_UPDATE_NAME, progress + ) + + return len(event_search_rows) + + result = yield self.runInteration( + self.EVENT_SEARCH_UPDATE_NAME, reindex_search_txn + ) + + if result is None: + yield _end_background_update(self.EVENT_SEARCH_UPDATE_NAME) + + defer.returnValue(result) + @defer.inlineCallbacks def search_msgs(self, room_ids, search_term, keys): """Performs a full text search over events with given keys. -- cgit 1.5.1 From a412b9a465cb6a64f44e3656d8645977f43c275f Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Tue, 10 Nov 2015 15:50:58 +0000 Subject: Run the background updates when starting synapse. --- synapse/app/homeserver.py | 1 + synapse/storage/background_updates.py | 57 ++++++++++++++++++++++++++++++----- synapse/storage/search.py | 11 +++++-- synapse/util/__init__.py | 8 +++++ 4 files changed, 67 insertions(+), 10 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index a77535a4ee..cd7a52ec07 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -439,6 +439,7 @@ def setup(config_options): hs.get_pusherpool().start() hs.get_state_handler().start_caching() hs.get_datastore().start_profiling() + hs.get_datastore().start_doing_background_updates() hs.get_replication_layer().start_get_pdu_cache() return hs diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py index 32a233c213..b6cdc6ec68 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py @@ -43,7 +43,7 @@ class BackgroundUpdatePerformance(object): self.avg_item_count += 0.1 * (item_count - self.avg_item_count) self.avg_duration_ms += 0.1 * (duration_ms - self.avg_duration_ms) - def average_duration_ms_per_item(self): + def average_items_per_ms(self): """An estimate of how long it takes to do a single update. Returns: A duration in ms as a float @@ -53,7 +53,17 @@ class BackgroundUpdatePerformance(object): else: # Use the exponential moving average so that we can adapt to # changes in how long the update process takes. - return float(self.avg_duration_ms) / float(self.avg_item_count) + return float(self.avg_item_count) / float(self.avg_duration_ms) + + def total_items_per_ms(self): + """An estimate of how long it takes to do a single update. 
+        Returns:
+            A rate in items per ms as a float
+        """
+        if self.total_item_count == 0:
+            return None
+        else:
+            return float(self.total_item_count) / float(self.total_duration_ms)
 
 
 class BackgroundUpdateStore(SQLBaseStore):
@@ -65,12 +75,41 @@ class BackgroundUpdateStore(SQLBaseStore):
 
     MINIMUM_BACKGROUND_BATCH_SIZE = 100
     DEFAULT_BACKGROUND_BATCH_SIZE = 100
+    BACKGROUND_UPDATE_INTERVAL_MS = 1000
+    BACKGROUND_UPDATE_DURATION_MS = 100
 
     def __init__(self, hs):
         super(BackgroundUpdateStore, self).__init__(hs)
         self._background_update_performance = {}
         self._background_update_queue = []
         self._background_update_handlers = {}
+        self._background_update_timer = None
+
+    @defer.inlineCallbacks
+    def start_doing_background_updates(self):
+        while True:
+            if self._background_update_timer is not None:
+                return
+
+            sleep = defer.Deferred()
+            self._background_update_timer = self._clock.call_later(
+                self.BACKGROUND_UPDATE_INTERVAL_MS / 1000., sleep.callback
+            )
+            try:
+                yield sleep
+            finally:
+                self._background_update_timer = None
+
+            result = yield self.do_background_update(
+                self.BACKGROUND_UPDATE_DURATION_MS
+            )
+
+            if result is None:
+                logger.info(
+                    "No more background updates to do."
+                    " Unscheduling background update task."
+                )
+                return
 
     @defer.inlineCallbacks
     def do_background_update(self, desired_duration_ms):
@@ -106,10 +145,10 @@ class BackgroundUpdateStore(SQLBaseStore):
             performance = BackgroundUpdatePerformance(update_name)
             self._background_update_performance[update_name] = performance
 
-        duration_ms_per_item = performance.average_duration_ms_per_item()
+        items_per_ms = performance.average_items_per_ms()
 
-        if duration_ms_per_item is not None:
-            batch_size = int(desired_duration_ms / duration_ms_per_item)
+        if items_per_ms is not None:
+            batch_size = int(desired_duration_ms * items_per_ms)
             # Clamp the batch size so that we always make progress
             batch_size = max(batch_size, self.MINIMUM_BACKGROUND_BATCH_SIZE)
         else:
             batch_size = self.DEFAULT_BACKGROUND_BATCH_SIZE
@@ -130,8 +169,12 @@ class BackgroundUpdateStore(SQLBaseStore):
         duration_ms = time_stop - time_start
 
         logger.info(
-            "Updating %s. Updated %r items in %rms",
-            update_name, items_updated, duration_ms
+            "Updating %s. Updated %r items in %rms."
+            " (total_rate=%r/ms, current_rate=%r/ms, total_updated=%r)",
+            update_name, items_updated, duration_ms,
+            performance.total_items_per_ms(),
+            performance.average_items_per_ms(),
+            performance.total_item_count,
         )
 
         performance.update(items_updated, duration_ms)
diff --git a/synapse/storage/search.py b/synapse/storage/search.py
index f7c269865d..d170c546b5 100644
--- a/synapse/storage/search.py
+++ b/synapse/storage/search.py
@@ -29,6 +29,11 @@ class SearchStore(BackgroundUpdateStore):
 
     EVENT_SEARCH_UPDATE_NAME = "event_search"
 
+    def __init__(self, hs):
+        super(SearchStore, self).__init__(hs)
+        self.register_background_update_handler(
+            self.EVENT_SEARCH_UPDATE_NAME, self._background_reindex_search
+        )
 
     @defer.inlineCallbacks
     def _background_reindex_search(self, progress, batch_size):
@@ -74,7 +79,7 @@ class SearchStore(BackgroundUpdateStore):
                 elif event.type == "m.room.name":
                     key = "content.name"
                     value = content["name"]
-            except Exception:
+            except (KeyError, AttributeError):
                 # If the event is missing a necessary field then
                 # skip over it.
                    continue
@@ -96,7 +101,7 @@ class SearchStore(BackgroundUpdateStore):
                raise Exception("Unrecognized database engine")
 
            for index in range(0, len(event_search_rows), INSERT_CLUMP_SIZE):
-                clump = event_search_rows[index : index + INSERT_CLUMP_SIZE)
+                clump = event_search_rows[index:index + INSERT_CLUMP_SIZE]
                txn.execute_many(sql, clump)
 
            progress = {
@@ -116,7 +121,7 @@
        )
 
        if result is None:
-            yield _end_background_update(self.EVENT_SEARCH_UPDATE_NAME)
+            yield self._end_background_update(self.EVENT_SEARCH_UPDATE_NAME)
 
        defer.returnValue(result)
 
diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py
index 1d123ccefc..d69c7cb991 100644
--- a/synapse/util/__init__.py
+++ b/synapse/util/__init__.py
@@ -53,6 +53,14 @@ class Clock(object):
        loop.stop()
 
    def call_later(self, delay, callback, *args, **kwargs):
+        """Call something later
+
+        Args:
+            delay(float): How long to wait in seconds.
+            callback(function): Function to call
+            *args: Positional arguments to pass to the function.
+            **kwargs: Keyword arguments to pass to the function.
+        """
        current_context = LoggingContext.current_context()
 
        def wrapped_callback(*args, **kwargs):
-- cgit 1.5.1


From 90b503216c40974dc4925a9bbb673779a039143c Mon Sep 17 00:00:00 2001
From: Mark Haines
Date: Tue, 10 Nov 2015 16:20:13 +0000
Subject: Use a background task to update databases to use the full text search

---
 synapse/storage/schema/delta/25/fts.py | 100 ++++++++-------------------------
 synapse/storage/search.py              |   8 +--
 2 files changed, 28 insertions(+), 80 deletions(-)

(limited to 'synapse/storage')

diff --git a/synapse/storage/schema/delta/25/fts.py b/synapse/storage/schema/delta/25/fts.py
index b7cd0ce3b8..e408d36da0 100644
--- a/synapse/storage/schema/delta/25/fts.py
+++ b/synapse/storage/schema/delta/25/fts.py
@@ -22,7 +22,7 @@ import ujson
 
 logger = logging.getLogger(__name__)
 
-POSTGRES_SQL = """
+POSTGRES_TABLE = """
 CREATE TABLE IF NOT EXISTS event_search (
    event_id TEXT,
    room_id TEXT,
    sender TEXT,
    key TEXT,
    vector tsvector
 );
 
-INSERT INTO event_search SELECT
-    event_id, room_id, json::json->>'sender', 'content.body',
-    to_tsvector('english', json::json->'content'->>'body')
-    FROM events NATURAL JOIN event_json WHERE type = 'm.room.message';
-
-INSERT INTO event_search SELECT
-    event_id, room_id, json::json->>'sender', 'content.name',
-    to_tsvector('english', json::json->'content'->>'name')
-    FROM events NATURAL JOIN event_json WHERE type = 'm.room.name';
-
-INSERT INTO event_search SELECT
-    event_id, room_id, json::json->>'sender', 'content.topic',
-    to_tsvector('english', json::json->'content'->>'topic')
-    FROM events NATURAL JOIN event_json WHERE type = 'm.room.topic';
-
-
 CREATE INDEX event_search_fts_idx ON event_search USING gin(vector);
 CREATE INDEX event_search_ev_idx ON event_search(event_id);
 CREATE INDEX event_search_ev_ridx ON event_search(room_id);
@@ -61,67 +45,31 @@ SQLITE_TABLE = (
 
 def run_upgrade(cur, database_engine, *args, **kwargs):
    if isinstance(database_engine, PostgresEngine):
-        run_postgres_upgrade(cur)
+        for statement in get_statements(POSTGRES_TABLE.splitlines()):
+            cur.execute(statement)
        return
 
    if isinstance(database_engine, Sqlite3Engine):
-        run_sqlite_upgrade(cur)
-        return
-
-
-def run_postgres_upgrade(cur):
-    for statement in get_statements(POSTGRES_SQL.splitlines()):
-        cur.execute(statement)
-
-
-def run_sqlite_upgrade(cur):
        cur.execute(SQLITE_TABLE)
+        return
 
-    rowid = -1
-    while True:
-        cur.execute(
-            "SELECT rowid, json FROM event_json"
-            " WHERE 
rowid > ?" - " ORDER BY rowid ASC LIMIT 100", - (rowid,) - ) - - res = cur.fetchall() - - if not res: - break - - events = [ - ujson.loads(js) - for _, js in res - ] - - rowid = max(rid for rid, _ in res) - - rows = [] - for ev in events: - content = ev.get("content", {}) - body = content.get("body", None) - name = content.get("name", None) - topic = content.get("topic", None) - sender = ev.get("sender", None) - if ev["type"] == "m.room.message" and body: - rows.append(( - ev["event_id"], ev["room_id"], sender, "content.body", body - )) - if ev["type"] == "m.room.name" and name: - rows.append(( - ev["event_id"], ev["room_id"], sender, "content.name", name - )) - if ev["type"] == "m.room.topic" and topic: - rows.append(( - ev["event_id"], ev["room_id"], sender, "content.topic", topic - )) - - if rows: - logger.info(rows) - cur.executemany( - "INSERT INTO event_search (event_id, room_id, sender, key, value)" - " VALUES (?,?,?,?,?)", - rows - ) + cur.execute("SELECT MIN(stream_ordering) FROM events") + rows = cur.fetchall() + min_stream_id = rows[0][0] + + cur.execute("SELECT MAX(stream_ordering) FROM events") + rows = cur.fetchall() + max_stream_id = rows[0][0] + + if min_stream_id is not None and max_stream_id is not None: + progress = { + "target_min_stream_id_inclusive": min_stream_id, + "max_stream_id_exclusive": max_stream_id + 1, + "rows_inserted": 0, + } + progress_json = ujson.dumps(progress) + + cur.execute( + "INSERT into background_updates (update_name, progress_json)" + " VALUES (?, ?)", ("event_search", progress_json) + ) diff --git a/synapse/storage/search.py b/synapse/storage/search.py index d170c546b5..3b19b118eb 100644 --- a/synapse/storage/search.py +++ b/synapse/storage/search.py @@ -37,8 +37,8 @@ class SearchStore(BackgroundUpdateStore): @defer.inlineCallbacks def _background_reindex_search(self, progress, batch_size): - target_min_stream_id = progress["target_min_stream_id"] - max_stream_id = progress["max_stream_id"] + target_min_stream_id = progress["target_min_stream_id_inclusive"] + max_stream_id = progress["max_stream_id_exclusive"] rows_inserted = progress.get("rows_inserted", 0) INSERT_CLUMP_SIZE = 1000 @@ -105,8 +105,8 @@ class SearchStore(BackgroundUpdateStore): txn.execute_many(sql, clump) progress = { - "target_max_stream_id": target_min_stream_id, - "max_stream_id": min_stream_id, + "target_min_stream_id_inclusive": target_min_stream_id, + "max_stream_id_exclusive": min_stream_id, "rows_inserted": rows_inserted + len(event_search_rows) } -- cgit 1.5.1 From cf437900e0c689aad40f3da89866cf84c0f7ef65 Mon Sep 17 00:00:00 2001 From: Daniel Wagner-Hall Date: Tue, 10 Nov 2015 17:10:27 +0000 Subject: Return world_readable and guest_can_join in /publicRooms --- synapse/storage/events.py | 2 + synapse/storage/room.py | 71 ++++++++++++++---------- synapse/storage/schema/delta/25/guest_access.sql | 25 +++++++++ tests/storage/test_room.py | 2 + 4 files changed, 71 insertions(+), 29 deletions(-) create mode 100644 synapse/storage/schema/delta/25/guest_access.sql (limited to 'synapse/storage') diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 59c9987202..4a365ff639 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -313,6 +313,8 @@ class EventsStore(SQLBaseStore): self._store_redaction(txn, event) elif event.type == EventTypes.RoomHistoryVisibility: self._store_history_visibility_txn(txn, event) + elif event.type == EventTypes.GuestAccess: + self._store_guest_access_txn(txn, event) self._store_room_members_txn( txn, diff --git 
a/synapse/storage/room.py b/synapse/storage/room.py index 1c79626736..4f08df478c 100644 --- a/synapse/storage/room.py +++ b/synapse/storage/room.py @@ -99,34 +99,39 @@ class RoomStore(SQLBaseStore): """ def f(txn): - topic_subquery = ( - "SELECT topics.event_id as event_id, " - "topics.room_id as room_id, topic " - "FROM topics " - "INNER JOIN current_state_events as c " - "ON c.event_id = topics.event_id " - ) - - name_subquery = ( - "SELECT room_names.event_id as event_id, " - "room_names.room_id as room_id, name " - "FROM room_names " - "INNER JOIN current_state_events as c " - "ON c.event_id = room_names.event_id " - ) + def subquery(table_name, column_name=None): + column_name = column_name or table_name + return ( + "SELECT %(table_name)s.event_id as event_id, " + "%(table_name)s.room_id as room_id, %(column_name)s " + "FROM %(table_name)s " + "INNER JOIN current_state_events as c " + "ON c.event_id = %(table_name)s.event_id " % { + "column_name": column_name, + "table_name": table_name, + } + ) - # We use non printing ascii character US (\x1F) as a separator sql = ( - "SELECT r.room_id, max(n.name), max(t.topic)" + "SELECT" + " r.room_id," + " max(n.name)," + " max(t.topic)," + " max(v.history_visibility)," + " max(g.guest_access)" " FROM rooms AS r" " LEFT JOIN (%(topic)s) AS t ON t.room_id = r.room_id" " LEFT JOIN (%(name)s) AS n ON n.room_id = r.room_id" + " LEFT JOIN (%(history_visibility)s) AS v ON v.room_id = r.room_id" + " LEFT JOIN (%(guest_access)s) AS g ON g.room_id = r.room_id" " WHERE r.is_public = ?" - " GROUP BY r.room_id" - ) % { - "topic": topic_subquery, - "name": name_subquery, - } + " GROUP BY r.room_id" % { + "topic": subquery("topics", "topic"), + "name": subquery("room_names", "name"), + "history_visibility": subquery("history_visibility"), + "guest_access": subquery("guest_access"), + } + ) txn.execute(sql, (is_public,)) @@ -156,10 +161,12 @@ class RoomStore(SQLBaseStore): "room_id": r[0], "name": r[1], "topic": r[2], - "aliases": r[3], + "world_readable": r[3] == "world_readable", + "guest_can_join": r[4] == "can_join", + "aliases": r[5], } for r in rows - if r[3] # We only return rooms that have at least one alias. + if r[5] # We only return rooms that have at least one alias. ] defer.returnValue(ret) @@ -203,16 +210,22 @@ class RoomStore(SQLBaseStore): ) def _store_history_visibility_txn(self, txn, event): - if hasattr(event, "content") and "history_visibility" in event.content: + self._store_content_index_txn(txn, event, "history_visibility") + + def _store_guest_access_txn(self, txn, event): + self._store_content_index_txn(txn, event, "guest_access") + + def _store_content_index_txn(self, txn, event, key): + if hasattr(event, "content") and key in event.content: sql = ( - "INSERT INTO history_visibility" - " (event_id, room_id, history_visibility)" - " VALUES (?, ?, ?)" + "INSERT INTO %(key)s" + " (event_id, room_id, %(key)s)" + " VALUES (?, ?, ?)" % {"key": key} ) txn.execute(sql, ( event.event_id, event.room_id, - event.content["history_visibility"] + event.content[key] )) def _store_event_search_txn(self, txn, event, key, value): diff --git a/synapse/storage/schema/delta/25/guest_access.sql b/synapse/storage/schema/delta/25/guest_access.sql new file mode 100644 index 0000000000..bdb90e7118 --- /dev/null +++ b/synapse/storage/schema/delta/25/guest_access.sql @@ -0,0 +1,25 @@ +/* Copyright 2015 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * This is a manual index of guest_access content of state events, + * so that we can join on them in SELECT statements. + */ +CREATE TABLE IF NOT EXISTS guest_access( + event_id TEXT NOT NULL, + room_id TEXT NOT NULL, + guest_access TEXT NOT NULL, + UNIQUE (event_id) +); diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py index caffce64e3..91c967548d 100644 --- a/tests/storage/test_room.py +++ b/tests/storage/test_room.py @@ -73,6 +73,8 @@ class RoomStoreTestCase(unittest.TestCase): "room_id": self.room.to_string(), "topic": None, "aliases": [self.alias.to_string()], + "world_readable": False, + "guest_can_join": False, }, rooms[0]) -- cgit 1.5.1 From 940a16119205115c02a3b23e9f3c67a08486bae0 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Wed, 11 Nov 2015 13:59:40 +0000 Subject: Fix the background update --- synapse/storage/background_updates.py | 13 ++++++++----- synapse/storage/schema/delta/25/fts.py | 7 +++---- synapse/storage/search.py | 16 ++++++++-------- 3 files changed, 19 insertions(+), 17 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py index b6cdc6ec68..45fccc2e5e 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py @@ -93,16 +93,19 @@ class BackgroundUpdateStore(SQLBaseStore): sleep = defer.Deferred() self._background_update_timer = self._clock.call_later( - self.BACKGROUND_UPDATE_INTERVAL_MS / 1000., sleep.callback + self.BACKGROUND_UPDATE_INTERVAL_MS / 1000., sleep.callback, None ) try: yield sleep finally: self._background_update_timer = None - result = yield self.do_background_update( - self.BACKGROUND_UPDATE_DURATION_MS - ) + try: + result = yield self.do_background_update( + self.BACKGROUND_UPDATE_DURATION_MS + ) + except: + logger.exception("Error doing update") if result is None: logger.info( @@ -169,7 +172,7 @@ class BackgroundUpdateStore(SQLBaseStore): duration_ms = time_stop - time_start logger.info( - "Updating %. Updated %r items in %rms." + "Updating %r. Updated %r items in %rms." 
" (total_rate=%r/ms, current_rate=%r/ms, total_updated=%r)", update_name, items_updated, duration_ms, performance.total_items_per_ms(), diff --git a/synapse/storage/schema/delta/25/fts.py b/synapse/storage/schema/delta/25/fts.py index e408d36da0..ce92f59025 100644 --- a/synapse/storage/schema/delta/25/fts.py +++ b/synapse/storage/schema/delta/25/fts.py @@ -47,11 +47,10 @@ def run_upgrade(cur, database_engine, *args, **kwargs): if isinstance(database_engine, PostgresEngine): for statement in get_statements(POSTGRES_TABLE.splitlines()): cur.execute(statement) - return - - if isinstance(database_engine, Sqlite3Engine): + elif isinstance(database_engine, Sqlite3Engine): cur.execute(SQLITE_TABLE) - return + else: + raise Exception("Unrecognized database engine") cur.execute("SELECT MIN(stream_ordering) FROM events") rows = cur.fetchall() diff --git a/synapse/storage/search.py b/synapse/storage/search.py index 3b19b118eb..3c0d671129 100644 --- a/synapse/storage/search.py +++ b/synapse/storage/search.py @@ -46,18 +46,18 @@ class SearchStore(BackgroundUpdateStore): def reindex_search_txn(txn): sql = ( - "SELECT stream_id, event_id FROM events" + "SELECT stream_ordering, event_id FROM events" " WHERE ? <= stream_ordering AND stream_ordering < ?" " AND (%s)" " ORDER BY stream_ordering DESC" " LIMIT ?" - ) % (" OR ".join("type = '%s'" % TYPES),) + ) % (" OR ".join("type = '%s'" % (t,) for t in TYPES),) - txn.execute(sql, target_min_stream_id, max_stream_id, batch_size) + txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size)) - rows = txn.fetch_all() + rows = txn.fetchall() if not rows: - return None + return 0 min_stream_id = rows[-1][0] event_ids = [row[1] for row in rows] @@ -102,7 +102,7 @@ class SearchStore(BackgroundUpdateStore): for index in range(0, len(event_search_rows), INSERT_CLUMP_SIZE): clump = event_search_rows[index:index + INSERT_CLUMP_SIZE] - txn.execute_many(sql, clump) + txn.executemany(sql, clump) progress = { "target_min_stream_id_inclusive": target_min_stream_id, @@ -116,11 +116,11 @@ class SearchStore(BackgroundUpdateStore): return len(event_search_rows) - result = yield self.runInteration( + result = yield self.runInteraction( self.EVENT_SEARCH_UPDATE_NAME, reindex_search_txn ) - if result is None: + if not result: yield self._end_background_update(self.EVENT_SEARCH_UPDATE_NAME) defer.returnValue(result) -- cgit 1.5.1 From e1627388d15fac534e22b80c2f7cb1a4e1afacc3 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Wed, 11 Nov 2015 17:14:56 +0000 Subject: Fix param style to work on both sqlite and postgres --- synapse/storage/schema/delta/25/fts.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/schema/delta/25/fts.py b/synapse/storage/schema/delta/25/fts.py index ce92f59025..5239d69073 100644 --- a/synapse/storage/schema/delta/25/fts.py +++ b/synapse/storage/schema/delta/25/fts.py @@ -68,7 +68,11 @@ def run_upgrade(cur, database_engine, *args, **kwargs): } progress_json = ujson.dumps(progress) - cur.execute( + sql = ( "INSERT into background_updates (update_name, progress_json)" - " VALUES (?, ?)", ("event_search", progress_json) + " VALUES (?, ?)" ) + + sql = database_engine.convert_param_style(sql) + + cur.execute(sql, ("event_search", progress_json)) -- cgit 1.5.1 From 39de87869c9ad966f382e597986e860d03f3fef2 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 12 Nov 2015 14:06:31 +0000 Subject: Fix bug where assumed dict was namedtuple --- synapse/storage/transactions.py | 2 +- 1 file 
changed, 1 insertion(+), 1 deletion(-) (limited to 'synapse/storage') diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py index 4e0d7c9774..ad099775eb 100644 --- a/synapse/storage/transactions.py +++ b/synapse/storage/transactions.py @@ -59,7 +59,7 @@ class TransactionStore(SQLBaseStore): allow_none=True, ) - if result and result.response_code: + if result and result["response_code"]: return result["response_code"], result["response_json"] else: return None -- cgit 1.5.1 From 14a9d805b959b5d2b26fea0d57794cb3ceb28958 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 12 Nov 2015 14:07:25 +0000 Subject: Use a (hopefully) more efficient SQL query for doing recency based room search --- synapse/storage/search.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/search.py b/synapse/storage/search.py index 2e88c51ad0..eac265b543 100644 --- a/synapse/storage/search.py +++ b/synapse/storage/search.py @@ -253,11 +253,13 @@ class SearchStore(BackgroundUpdateStore): ) elif isinstance(self.database_engine, Sqlite3Engine): sql = ( - "SELECT rank(matchinfo(event_search)) as rank, room_id, event_id," + "SELECT rank(matchinfo) as rank, room_id, event_id," " topological_ordering, stream_ordering" - " FROM event_search" - " NATURAL JOIN events" - " WHERE value MATCH ? AND room_id = ?" + " FROM (SELECT event_id, matchinfo(event_search) FROM event_search" + " WHERE value MATCH" + " )" + " CROSS JOIN events USING (event_id)" + " WHERE room_id = ?" ) else: # This should be unreachable. -- cgit 1.5.1 From 320408ef47eb373c2b8edad7ab3d08c3fa0cb459 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 12 Nov 2015 15:09:45 +0000 Subject: Fix SQL syntax --- synapse/storage/search.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/search.py b/synapse/storage/search.py index eac265b543..e1911e2480 100644 --- a/synapse/storage/search.py +++ b/synapse/storage/search.py @@ -255,8 +255,9 @@ class SearchStore(BackgroundUpdateStore): sql = ( "SELECT rank(matchinfo) as rank, room_id, event_id," " topological_ordering, stream_ordering" - " FROM (SELECT event_id, matchinfo(event_search) FROM event_search" - " WHERE value MATCH" + " FROM (SELECT key, event_id, matchinfo(event_search) as matchinfo" + " FROM event_search" + " WHERE value MATCH ?" " )" " CROSS JOIN events USING (event_id)" " WHERE room_id = ?" -- cgit 1.5.1 From 764e79d0514a0cce9dc456df0355108d40dbeda8 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 12 Nov 2015 15:19:56 +0000 Subject: Comment --- synapse/storage/search.py | 2 ++ 1 file changed, 2 insertions(+) (limited to 'synapse/storage') diff --git a/synapse/storage/search.py b/synapse/storage/search.py index e1911e2480..0b00ddb9db 100644 --- a/synapse/storage/search.py +++ b/synapse/storage/search.py @@ -252,6 +252,8 @@ class SearchStore(BackgroundUpdateStore): " WHERE vector @@ query AND room_id = ?" ) elif isinstance(self.database_engine, Sqlite3Engine): + # We use CROSS JOIN here to ensure we use the right indexes. 
+ # https://sqlite.org/optoverview.html#crossjoin sql = ( "SELECT rank(matchinfo) as rank, room_id, event_id," " topological_ordering, stream_ordering" -- cgit 1.5.1 From 8fd8e72cec8df94403441664d8eecc25fa8d363f Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 12 Nov 2015 15:33:47 +0000 Subject: Expand comment --- synapse/storage/search.py | 6 ++++++ 1 file changed, 6 insertions(+) (limited to 'synapse/storage') diff --git a/synapse/storage/search.py b/synapse/storage/search.py index 0b00ddb9db..dcc5ac65a3 100644 --- a/synapse/storage/search.py +++ b/synapse/storage/search.py @@ -254,6 +254,12 @@ class SearchStore(BackgroundUpdateStore): elif isinstance(self.database_engine, Sqlite3Engine): # We use CROSS JOIN here to ensure we use the right indexes. # https://sqlite.org/optoverview.html#crossjoin + # + # We want to use the full text search index on event_search to + # extract all possible matches first, then lookup those matches + # in the events table to get the topological ordering. We need + # to use the indexes in this order because sqlite refuses to + # MATCH unless it uses the full text search index sql = ( "SELECT rank(matchinfo) as rank, room_id, event_id," " topological_ordering, stream_ordering" -- cgit 1.5.1 From 3de46c7755ac5e061fc13fb916c13b841b65f2b5 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 12 Nov 2015 15:36:43 +0000 Subject: Trailing whitespace --- synapse/storage/search.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'synapse/storage') diff --git a/synapse/storage/search.py b/synapse/storage/search.py index dcc5ac65a3..380270b009 100644 --- a/synapse/storage/search.py +++ b/synapse/storage/search.py @@ -258,7 +258,7 @@ class SearchStore(BackgroundUpdateStore): # We want to use the full text search index on event_search to # extract all possible matches first, then lookup those matches # in the events table to get the topological ordering. We need - # to use the indexes in this order because sqlite refuses to + # to use the indexes in this order because sqlite refuses to # MATCH unless it uses the full text search index sql = ( "SELECT rank(matchinfo) as rank, room_id, event_id," -- cgit 1.5.1 From fddedd51d9ca44f385d412e8c2e3e8d99325ba67 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 10 Nov 2015 18:27:23 +0000 Subject: Fix a few race conditions in the state calculation Be a bit more careful about how we calculate the state to be returned by /sync. In a few places, it was possible for /sync to return slightly later state than that represented by the next_batch token and the timeline. In particular, the following cases were susceptible: * On a full state sync, for an active room * During a per-room incremental sync with a timeline gap * When the user has just joined a room. (Refactor check_joined_room to make it less magical) Also, use store.get_state_for_events() (and thus the existing stategroups) to calculate the state corresponding to a particular sync position, rather than state_handler.compute_event_context(), which recalculates from first principles (and tends to miss some state). 
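In essence, the fix swaps a "live" state read for a read at the sync
token. An illustrative before/after sketch, assembled from the handler
code in the diff below rather than quoted verbatim from the patch:

    # Before: read the room state from the state handler, which can race
    # ahead of the token that /sync is about to hand back to the client.
    current_state = yield self.state_handler.get_current_state(room_id)

    # After: resolve the state as of the token itself, by finding the last
    # event at or before the token and reusing its stored state group.
    current_state = yield self.get_state_at(room_id, now_token)
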
Merged from PR https://github.com/matrix-org/synapse/pull/372 --- synapse/handlers/sync.py | 123 ++++++++++++++++++++++++----------------------- synapse/storage/state.py | 14 ++++++ 2 files changed, 77 insertions(+), 60 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 8b154fa7e7..6dc9d0fb92 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -254,9 +254,7 @@ class SyncHandler(BaseHandler): room_id, sync_config, now_token, since_token=timeline_since_token ) - current_state = yield self.state_handler.get_current_state( - room_id - ) + current_state = yield self.get_state_at(room_id, now_token) defer.returnValue(JoinedSyncResult( room_id=room_id, @@ -353,14 +351,12 @@ class SyncHandler(BaseHandler): room_id, sync_config, leave_token, since_token=timeline_since_token ) - leave_state = yield self.store.get_state_for_events( - [leave_event_id], None - ) + leave_state = yield self.store.get_state_for_event(leave_event_id) defer.returnValue(ArchivedSyncResult( room_id=room_id, timeline=batch, - state=leave_state[leave_event_id], + state=leave_state, private_user_data=self.private_user_data_for_room( room_id, tags_by_room ), @@ -424,6 +420,9 @@ class SyncHandler(BaseHandler): if len(room_events) <= timeline_limit: # There is no gap in any of the rooms. Therefore we can just # partition the new events by room and return them. + logger.debug("Got %i events for incremental sync - not limited", + len(room_events)) + invite_events = [] leave_events = [] events_by_room_id = {} @@ -439,9 +438,11 @@ class SyncHandler(BaseHandler): for room_id in joined_room_ids: recents = events_by_room_id.get(room_id, []) + logger.debug("Events for room %s: %r", room_id, recents) state = { (event.type, event.state_key): event for event in recents if event.is_state()} + limited = False if recents: prev_batch = now_token.copy_and_replace( @@ -450,9 +451,13 @@ class SyncHandler(BaseHandler): else: prev_batch = now_token - state, limited = yield self.check_joined_room( - sync_config, room_id, state - ) + just_joined = yield self.check_joined_room(sync_config, state) + if just_joined: + logger.debug("User has just joined %s: needs full state", + room_id) + state = yield self.get_state_at(room_id, now_token) + # the timeline is inherently limited if we've just joined + limited = True room_sync = JoinedSyncResult( room_id=room_id, @@ -467,10 +472,15 @@ class SyncHandler(BaseHandler): room_id, tags_by_room ), ) + logger.debug("Result for room %s: %r", room_id, room_sync) + if room_sync: joined.append(room_sync) else: + logger.debug("Got %i events for incremental sync - hit limit", + len(room_events)) + invite_events = yield self.store.get_invites_for_user( sync_config.user.to_string() ) @@ -563,6 +573,8 @@ class SyncHandler(BaseHandler): Returns: A Deferred JoinedSyncResult """ + logger.debug("Doing incremental sync for room %s between %s and %s", + room_id, since_token, now_token) # TODO(mjark): Check for redactions we might have missed. 
@@ -572,30 +584,26 @@ class SyncHandler(BaseHandler): logging.debug("Recents %r", batch) - # TODO(mjark): This seems racy since this isn't being passed a - # token to indicate what point in the stream this is - current_state = yield self.state_handler.get_current_state( - room_id - ) + current_state = yield self.get_state_at(room_id, now_token) - state_at_previous_sync = yield self.get_state_at_previous_sync( - room_id, since_token=since_token + state_at_previous_sync = yield self.get_state_at( + room_id, stream_position=since_token ) - state_events_delta = yield self.compute_state_delta( + state = yield self.compute_state_delta( since_token=since_token, previous_state=state_at_previous_sync, current_state=current_state, ) - state_events_delta, _ = yield self.check_joined_room( - sync_config, room_id, state_events_delta - ) + just_joined = yield self.check_joined_room(sync_config, state) + if just_joined: + state = yield self.get_state_at(room_id, now_token) room_sync = JoinedSyncResult( room_id=room_id, timeline=batch, - state=state_events_delta, + state=state, ephemeral=ephemeral_by_room.get(room_id, []), private_user_data=self.private_user_data_for_room( room_id, tags_by_room @@ -627,16 +635,12 @@ class SyncHandler(BaseHandler): logging.debug("Recents %r", batch) - # TODO(mjark): This seems racy since this isn't being passed a - # token to indicate what point in the stream this is - leave_state = yield self.store.get_state_for_events( - [leave_event.event_id], None + state_events_at_leave = yield self.store.get_state_for_event( + leave_event.event_id ) - state_events_at_leave = leave_state[leave_event.event_id] - - state_at_previous_sync = yield self.get_state_at_previous_sync( - leave_event.room_id, since_token=since_token + state_at_previous_sync = yield self.get_state_at( + leave_event.room_id, stream_position=since_token ) state_events_delta = yield self.compute_state_delta( @@ -659,26 +663,36 @@ class SyncHandler(BaseHandler): defer.returnValue(room_sync) @defer.inlineCallbacks - def get_state_at_previous_sync(self, room_id, since_token): - """ Get the room state at the previous sync the client made. 
- Returns: - A Deferred map from ((type, state_key)->Event) + def get_state_after_event(self, event): + """ + Get the room state after the given event + + :param synapse.events.EventBase event: event of interest + :return: A Deferred map from ((type, state_key)->Event) + """ + state = yield self.store.get_state_for_event(event.event_id) + if event.is_state(): + state = state.copy() + state[(event.type, event.state_key)] = event + defer.returnValue(state) + + @defer.inlineCallbacks + def get_state_at(self, room_id, stream_position): + """ Get the room state at a particular stream position + :param str room_id: room for which to get state + :param StreamToken stream_position: point at which to get state + :returns: A Deferred map from ((type, state_key)->Event) """ last_events, token = yield self.store.get_recent_events_for_room( - room_id, end_token=since_token.room_key, limit=1, + room_id, end_token=stream_position.room_key, limit=1, ) if last_events: - last_event = last_events[0] - last_context = yield self.state_handler.compute_event_context( - last_event - ) - if last_event.is_state(): - state = last_context.current_state.copy() - state[(last_event.type, last_event.state_key)] = last_event - else: - state = last_context.current_state + last_event = last_events[-1] + state = yield self.get_state_after_event(last_event) + else: + # no events in this room - so presumably no state state = {} defer.returnValue(state) @@ -706,31 +720,20 @@ class SyncHandler(BaseHandler): state_delta[key] = event return state_delta - @defer.inlineCallbacks - def check_joined_room(self, sync_config, room_id, state_delta): + def check_joined_room(self, sync_config, state_delta): """ - Check if the user has just joined the given room. If so, return the - full state for the room, instead of the delta since the last sync. + Check if the user has just joined the given room (so should + be given the full state) :param sync_config: - :param room_id: :param dict[(str,str), synapse.events.FrozenEvent] state_delta: the difference in state since the last sync :returns A deferred Tuple (state_delta, limited) """ - joined = False - limited = False - join_event = state_delta.get(( EventTypes.Member, sync_config.user.to_string()), None) if join_event is not None: if join_event.content["membership"] == Membership.JOIN: - joined = True - - if joined: - state_delta = yield self.state_handler.get_current_state(room_id) - # the timeline is inherently limited if we've just joined - limited = True - - defer.returnValue((state_delta, limited)) + return True + return False diff --git a/synapse/storage/state.py b/synapse/storage/state.py index acfb322a53..80e9b63f50 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -237,6 +237,20 @@ class StateStore(SQLBaseStore): defer.returnValue({event: event_to_state[event] for event in event_ids}) + @defer.inlineCallbacks + def get_state_for_event(self, event_id, types=None): + """ + Get the state dict corresponding to a particular event + + :param str event_id: event whose state should be returned + :param list[(str, str)]|None types: List of (type, state_key) tuples + which are used to filter the state fetched. 
May be None, which + matches any key + :return: a deferred dict from (type, state_key) -> state_event + """ + state_map = yield self.get_state_for_events([event_id], types) + defer.returnValue(state_map[event_id]) + @cached(num_args=2, lru=True, max_entries=10000) def _get_state_group_for_event(self, room_id, event_id): return self._simple_select_one_onecol( -- cgit 1.5.1 From e4d622aaaf0df503f942d016a5bf798dd52899d1 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 10 Nov 2015 18:29:25 +0000 Subject: Implementation of state rollback in /sync Implementation of SPEC-254: roll back the state dictionary to how it looked at the start of the timeline. Merged PR https://github.com/matrix-org/synapse/pull/373 --- synapse/rest/client/v2_alpha/sync.py | 67 ++++++++++++++++++++++++++++++++++-- synapse/storage/events.py | 6 ++-- 2 files changed, 69 insertions(+), 4 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index 272a00bc85..efd8281558 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -20,6 +20,7 @@ from synapse.http.servlet import ( ) from synapse.handlers.sync import SyncConfig from synapse.types import StreamToken +from synapse.events import FrozenEvent from synapse.events.utils import ( serialize_event, format_event_for_client_v2_without_event_id, ) @@ -256,7 +257,13 @@ class SyncRestServlet(RestServlet): :rtype: dict[str, object] """ event_map = {} - state_events = filter.filter_room_state(room.state.values()) + state_dict = room.state + timeline_events = filter.filter_room_timeline(room.timeline.events) + + state_dict = SyncRestServlet._rollback_state_for_timeline( + state_dict, timeline_events) + + state_events = filter.filter_room_state(state_dict.values()) state_event_ids = [] for event in state_events: # TODO(mjark): Respect formatting requirements in the filter. @@ -266,7 +273,6 @@ class SyncRestServlet(RestServlet): ) state_event_ids.append(event.event_id) - timeline_events = filter.filter_room_timeline(room.timeline.events) timeline_event_ids = [] for event in timeline_events: # TODO(mjark): Respect formatting requirements in the filter. @@ -297,6 +303,63 @@ class SyncRestServlet(RestServlet): return result + @staticmethod + def _rollback_state_for_timeline(state, timeline): + """ + Wind the state dictionary backwards, so that it represents the + state at the start of the timeline, rather than at the end. + + :param dict[(str, str), synapse.events.EventBase] state: the + state dictionary. Will be updated to the state before the timeline. + :param list[synapse.events.EventBase] timeline: the event timeline + :return: updated state dictionary + """ + logger.debug("Processing state dict %r; timeline %r", state, + [e.get_dict() for e in timeline]) + + result = state.copy() + + for timeline_event in reversed(timeline): + if not timeline_event.is_state(): + continue + + event_key = (timeline_event.type, timeline_event.state_key) + + logger.debug("Considering %s for removal", event_key) + + state_event = result.get(event_key) + if (state_event is None or + state_event.event_id != timeline_event.event_id): + # the event in the timeline isn't present in the state + # dictionary. + # + # the most likely cause for this is that there was a fork in + # the event graph, and the state is no longer valid. Really, + # the event shouldn't be in the timeline. We're going to ignore + # it for now, however. 
+ logger.warn("Found state event %r in timeline which doesn't " + "match state dictionary", timeline_event) + continue + + prev_event_id = timeline_event.unsigned.get("replaces_state", None) + logger.debug("Replacing %s with %s in state dict", + timeline_event.event_id, prev_event_id) + + if prev_event_id is None: + del result[event_key] + else: + result[event_key] = FrozenEvent({ + "type": timeline_event.type, + "state_key": timeline_event.state_key, + "content": timeline_event.unsigned['prev_content'], + "sender": timeline_event.unsigned['prev_sender'], + "event_id": prev_event_id, + "room_id": timeline_event.room_id, + }) + logger.debug("New value: %r", result.get(event_key)) + + return result + def register_servlets(hs, http_server): SyncRestServlet(hs).register(http_server) diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 4a365ff639..5d35ca90b9 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -831,7 +831,8 @@ class EventsStore(SQLBaseStore): allow_none=True, ) if prev: - ev.unsigned["prev_content"] = prev.get_dict()["content"] + ev.unsigned["prev_content"] = prev.content + ev.unsigned["prev_sender"] = prev.sender self._get_event_cache.prefill( (ev.event_id, check_redacted, get_prev_content), ev @@ -888,7 +889,8 @@ class EventsStore(SQLBaseStore): get_prev_content=False, ) if prev: - ev.unsigned["prev_content"] = prev.get_dict()["content"] + ev.unsigned["prev_content"] = prev.content + ev.unsigned["prev_sender"] = prev.sender self._get_event_cache.prefill( (ev.event_id, check_redacted, get_prev_content), ev -- cgit 1.5.1