From 53ace904b2b8e2444bc518b2498947c869c8c13b Mon Sep 17 00:00:00 2001 From: Matthew Hodgson Date: Tue, 5 Dec 2017 01:29:25 +0000 Subject: total WIP skeleton for /room_keys API --- synapse/storage/e2e_room_keys.py | 133 ++++++++++++++++++++++ synapse/storage/schema/delta/46/e2e_room_keys.sql | 40 +++++++ 2 files changed, 173 insertions(+) create mode 100644 synapse/storage/e2e_room_keys.py create mode 100644 synapse/storage/schema/delta/46/e2e_room_keys.sql (limited to 'synapse/storage') diff --git a/synapse/storage/e2e_room_keys.py b/synapse/storage/e2e_room_keys.py new file mode 100644 index 0000000000..9f6d47e1b6 --- /dev/null +++ b/synapse/storage/e2e_room_keys.py @@ -0,0 +1,133 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from twisted.internet import defer + +from synapse.util.caches.descriptors import cached + +from canonicaljson import encode_canonical_json +import ujson as json + +from ._base import SQLBaseStore + + +class EndToEndRoomKeyStore(SQLBaseStore): + + @defer.inlineCallbacks + def get_e2e_room_key(self, user_id, version, room_id, session_id): + + row = yield self._simple_select_one( + table="e2e_room_keys", + keyvalues={ + "user_id": user_id, + "version": version, + "room_id": room_id, + "session_id": session_id, + }, + retcols=( + "first_message_index", + "forwarded_count", + "is_verified", + "session_data", + ), + desc="get_e2e_room_key", + ) + + defer.returnValue(row); + + def set_e2e_room_key(self, user_id, version, room_id, session_id, room_key): + + def _set_e2e_room_key_txn(txn): + + self._simple_upsert( + txn, + table="e2e_room_keys", + keyvalues={ + "user_id": user_id, + "room_id": room_id, + "session_id": session_id, + } + values=[ + { + "version": version, + "first_message_index": room_key['first_message_index'], + "forwarded_count": room_key['forwarded_count'], + "is_verified": room_key['is_verified'], + "session_data": room_key['session_data'], + } + ], + lock=False, + ) + + return True + + return self.runInteraction( + "set_e2e_room_key", _set_e2e_room_key_txn + ) + + + def set_e2e_room_keys(self, user_id, version, room_keys): + + def _set_e2e_room_keys_txn(txn): + + self._simple_insert_many_txn( + txn, + table="e2e_room_keys", + values=[ + { + "user_id": user_id, + "room_id": room_id, + "session_id": session_id, + "version": version, + "first_message_index": room_keys['rooms'][room_id]['sessions'][session_id]['first_message_index'], + "forwarded_count": room_keys['rooms'][room_id]['sessions'][session_id]['forwarded_count'], + "is_verified": room_keys['rooms'][room_id]['sessions'][session_id]['is_verified'], + "session_data": room_keys['rooms'][room_id]['sessions'][session_id]['session_data'], + } + for session_id in room_keys['rooms'][room_id]['sessions'] + for room_id in room_keys['rooms'] + ] + ) + + return True + + return self.runInteraction( + "set_e2e_room_keys", _set_e2e_room_keys_txn + ) + + @defer.inlineCallbacks + def get_e2e_room_keys(self, user_id, version, room_id, session_id): + + keyvalues={ + "user_id": user_id, + "version": version, + } + if room_id: keyvalues['room_id'] = room_id + if session_id: keyvalues['session_id'] = session_id + + rows = yield self._simple_select_list( + table="e2e_room_keys", + keyvalues=keyvalues, + retcols=( + "first_message_index", + "forwarded_count", + "is_verified", + "session_data", + ), + desc="get_e2e_room_keys", + ) + + sessions = {} + sessions['rooms'][roomId]['sessions'][session_id] = row for row in rows; + defer.returnValue(sessions); diff --git a/synapse/storage/schema/delta/46/e2e_room_keys.sql b/synapse/storage/schema/delta/46/e2e_room_keys.sql new file mode 100644 index 0000000000..51b826e8b3 --- /dev/null +++ b/synapse/storage/schema/delta/46/e2e_room_keys.sql @@ -0,0 +1,40 @@ +/* Copyright 2017 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- users' optionally backed up encrypted e2e sessions +CREATE TABLE e2e_room_keys ( + user_id TEXT NOT NULL, + room_id TEXT NOT NULL, + session_id TEXT NOT NULL, + version INT NOT NULL, + first_message_index INT, + forwarded_count INT, + is_verified BOOLEAN, + session_data TEXT NOT NULL +); + +CREATE UNIQUE INDEX e2e_room_keys_user_idx ON e2e_room_keys(user_id); +CREATE UNIQUE INDEX e2e_room_keys_room_idx ON e2e_room_keys(room_id); +CREATE UNIQUE INDEX e2e_room_keys_session_idx ON e2e_room_keys(session_id); + +-- the versioning metadata about versions of users' encrypted e2e session backups +CREATE TABLE e2e_room_key_versions ( + user_id TEXT NOT NULL, + version INT NOT NULL, + algorithm TEXT NOT NULL, + dummy_session_data TEXT NOT NULL +); + +CREATE UNIQUE INDEX e2e_room_key_user_idx ON e2e_room_keys(user_id); -- cgit 1.4.1 From 0bc4627a731d0edd437905a5b07e85421c7553d8 Mon Sep 17 00:00:00 2001 From: Matthew Hodgson Date: Tue, 5 Dec 2017 17:54:48 +0000 Subject: interim WIP checkin; doesn't build yet --- synapse/handlers/e2e_room_keys.py | 46 +++++++++++++++++++++---------- synapse/rest/client/v2_alpha/room_keys.py | 37 ++++++++++++++++++++++--- synapse/storage/e2e_room_keys.py | 20 ++++++++++++++ 3 files changed, 84 insertions(+), 19 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py index 78c838a829..15e3beb5ed 100644 --- a/synapse/handlers/e2e_room_keys.py +++ b/synapse/handlers/e2e_room_keys.py @@ -20,8 +20,7 @@ from canonicaljson import encode_canonical_json from twisted.internet import defer from synapse.api.errors import SynapseError, CodeMessageException -from synapse.types import get_domain_from_id -from synapse.util.logcontext import preserve_fn, make_deferred_yieldable +from synapse.util.async import Linearizer from synapse.util.retryutils import NotRetryingDestination logger = logging.getLogger(__name__) @@ -30,31 +29,48 @@ logger = logging.getLogger(__name__) class E2eRoomKeysHandler(object): def __init__(self, hs): self.store = hs.get_datastore() + self._upload_linearizer = async.Linearizer("upload_room_keys_lock") @defer.inlineCallbacks def get_room_keys(self, user_id, version, room_id, session_id): results = yield self.store.get_e2e_room_keys(user_id, version, room_id, session_id) defer.returnValue(results) + @defer.inlineCallbacks + def delete_room_keys(self, user_id, version, room_id, session_id): + yield self.store.delete_e2e_room_keys(user_id, version, room_id, session_id) + @defer.inlineCallbacks def upload_room_keys(self, user_id, version, room_keys): # TODO: Validate the JSON to make sure it has the right keys. - # go through the room_keys - for room_id in room_keys['rooms']: - for session_id in room_keys['rooms'][room_id]['sessions']: - session = room_keys['rooms'][room_id]['sessions'][session_id] - - # get a lock + # XXX: perhaps we should use a finer grained lock here? + with (yield self._upload_linearizer.queue(user_id): - # get the room_key for this particular row - yield self.store.get_e2e_room_key() + # go through the room_keys + for room_id in room_keys['rooms']: + for session_id in room_keys['rooms'][room_id]['sessions']: + room_key = room_keys['rooms'][room_id]['sessions'][session_id] - # check whether we merge or not - if() + # get the room_key for this particular row + current_room_key = yield self.store.get_e2e_room_key( + user_id, version, room_id, session_id + ) - # if so, we set it - yield self.store.set_e2e_room_key() + # check whether we merge or not. spelling it out with if/elifs rather than + # lots of booleans for legibility. + replace = False + if current_room_key: + if room_key['is_verified'] and not current_room_key['is_verified']: + replace = True + elif room_key['first_message_index'] < current_room_key['first_message_index']: + replace = True + elif room_key['forwarded_count'] < room_key['forwarded_count']: + replace = True - # release the lock + # if so, we set the new room_key + if replace: + yield self.store.set_e2e_room_key( + user_id, version, room_id, session_id, room_key + ) diff --git a/synapse/rest/client/v2_alpha/room_keys.py b/synapse/rest/client/v2_alpha/room_keys.py index 9b93001919..7291018a48 100644 --- a/synapse/rest/client/v2_alpha/room_keys.py +++ b/synapse/rest/client/v2_alpha/room_keys.py @@ -28,7 +28,7 @@ from ._base import client_v2_patterns logger = logging.getLogger(__name__) -class RoomKeysUploadServlet(RestServlet): +class RoomKeysServlet(RestServlet): PATTERNS = client_v2_patterns("/room_keys/keys(/(?P[^/]+))?(/(?P[^/]+))?$") def __init__(self, hs): @@ -41,16 +41,45 @@ class RoomKeysUploadServlet(RestServlet): self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler() @defer.inlineCallbacks - def on_POST(self, request, room_id, session_id): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + def on_PUT(self, request, room_id, session_id): + requester = yield self.auth.get_user_by_req(request, allow_guest=False) user_id = requester.user.to_string() body = parse_json_object_from_request(request) + version = request.args.get("version", None) + + if session_id: + body = { "sessions": { session_id : body } } + + if room_id: + body = { "rooms": { room_id : body } } result = yield self.e2e_room_keys_handler.upload_room_keys( user_id, version, body ) defer.returnValue((200, result)) + @defer.inlineCallbacks + def on_GET(self, request, room_id, session_id): + requester = yield self.auth.get_user_by_req(request, allow_guest=False) + user_id = requester.user.to_string() + version = request.args.get("version", None) + + room_keys = yield self.e2e_room_keys_handler.get_room_keys( + user_id, version, room_id, session_id + ) + defer.returnValue((200, room_keys)) + + @defer.inlineCallbacks + def on_DELETE(self, request, room_id, session_id): + requester = yield self.auth.get_user_by_req(request, allow_guest=False) + user_id = requester.user.to_string() + version = request.args.get("version", None) + + yield self.e2e_room_keys_handler.delete_room_keys( + user_id, version, room_id, session_id + ) + defer.returnValue((200, {})) + def register_servlets(hs, http_server): - RoomKeysUploadServlet(hs).register(http_server) + RoomKeysServlet(hs).register(http_server) diff --git a/synapse/storage/e2e_room_keys.py b/synapse/storage/e2e_room_keys.py index 9f6d47e1b6..903dc083f8 100644 --- a/synapse/storage/e2e_room_keys.py +++ b/synapse/storage/e2e_room_keys.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + from twisted.internet import defer from synapse.util.caches.descriptors import cached @@ -77,6 +78,9 @@ class EndToEndRoomKeyStore(SQLBaseStore): ) + # XXX: this isn't currently used and isn't tested anywhere + # it could be used in future for bulk-uploading new versions of room_keys + # for a user or something though. def set_e2e_room_keys(self, user_id, version, room_keys): def _set_e2e_room_keys_txn(txn): @@ -131,3 +135,19 @@ class EndToEndRoomKeyStore(SQLBaseStore): sessions = {} sessions['rooms'][roomId]['sessions'][session_id] = row for row in rows; defer.returnValue(sessions); + + @defer.inlineCallbacks + def delete_e2e_room_keys(self, user_id, version, room_id, session_id): + + keyvalues={ + "user_id": user_id, + "version": version, + } + if room_id: keyvalues['room_id'] = room_id + if session_id: keyvalues['session_id'] = session_id + + yield self._simple_delete( + table="e2e_room_keys", + keyvalues=keyvalues, + desc="delete_e2e_room_keys", + ) -- cgit 1.4.1 From 6b8c07abc293bd222051e769550753bc0fd6f667 Mon Sep 17 00:00:00 2001 From: Matthew Hodgson Date: Tue, 5 Dec 2017 21:44:25 +0000 Subject: make it work and fix pep8 --- synapse/handlers/e2e_room_keys.py | 69 +++++++++------ synapse/rest/__init__.py | 2 + synapse/rest/client/v2_alpha/room_keys.py | 33 ++++--- synapse/server.py | 5 ++ synapse/storage/__init__.py | 2 + synapse/storage/e2e_room_keys.py | 103 +++++++++++++--------- synapse/storage/schema/delta/46/e2e_room_keys.sql | 2 +- 7 files changed, 134 insertions(+), 82 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py index 15e3beb5ed..93f4ad5194 100644 --- a/synapse/handlers/e2e_room_keys.py +++ b/synapse/handlers/e2e_room_keys.py @@ -13,15 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -import ujson as json import logging -from canonicaljson import encode_canonical_json from twisted.internet import defer -from synapse.api.errors import SynapseError, CodeMessageException +from synapse.api.errors import StoreError from synapse.util.async import Linearizer -from synapse.util.retryutils import NotRetryingDestination logger = logging.getLogger(__name__) @@ -29,11 +26,13 @@ logger = logging.getLogger(__name__) class E2eRoomKeysHandler(object): def __init__(self, hs): self.store = hs.get_datastore() - self._upload_linearizer = async.Linearizer("upload_room_keys_lock") + self._upload_linearizer = Linearizer("upload_room_keys_lock") @defer.inlineCallbacks def get_room_keys(self, user_id, version, room_id, session_id): - results = yield self.store.get_e2e_room_keys(user_id, version, room_id, session_id) + results = yield self.store.get_e2e_room_keys( + user_id, version, room_id, session_id + ) defer.returnValue(results) @defer.inlineCallbacks @@ -46,31 +45,49 @@ class E2eRoomKeysHandler(object): # TODO: Validate the JSON to make sure it has the right keys. # XXX: perhaps we should use a finer grained lock here? - with (yield self._upload_linearizer.queue(user_id): + with (yield self._upload_linearizer.queue(user_id)): # go through the room_keys for room_id in room_keys['rooms']: for session_id in room_keys['rooms'][room_id]['sessions']: room_key = room_keys['rooms'][room_id]['sessions'][session_id] - # get the room_key for this particular row - current_room_key = yield self.store.get_e2e_room_key( - user_id, version, room_id, session_id + yield self._upload_room_key( + user_id, version, room_id, session_id, room_key ) - # check whether we merge or not. spelling it out with if/elifs rather than - # lots of booleans for legibility. - replace = False - if current_room_key: - if room_key['is_verified'] and not current_room_key['is_verified']: - replace = True - elif room_key['first_message_index'] < current_room_key['first_message_index']: - replace = True - elif room_key['forwarded_count'] < room_key['forwarded_count']: - replace = True - - # if so, we set the new room_key - if replace: - yield self.store.set_e2e_room_key( - user_id, version, room_id, session_id, room_key - ) + @defer.inlineCallbacks + def _upload_room_key(self, user_id, version, room_id, session_id, room_key): + # get the room_key for this particular row + current_room_key = None + try: + current_room_key = yield self.store.get_e2e_room_key( + user_id, version, room_id, session_id + ) + except StoreError as e: + if e.code == 404: + pass + else: + raise + + # check whether we merge or not. spelling it out with if/elifs rather + # than lots of booleans for legibility. + upsert = True + if current_room_key: + if room_key['is_verified'] and not current_room_key['is_verified']: + pass + elif ( + room_key['first_message_index'] < + current_room_key['first_message_index'] + ): + pass + elif room_key['forwarded_count'] < room_key['forwarded_count']: + pass + else: + upsert = False + + # if so, we set the new room_key + if upsert: + yield self.store.set_e2e_room_key( + user_id, version, room_id, session_id, room_key + ) diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py index 3418f06fd6..4856822a5d 100644 --- a/synapse/rest/__init__.py +++ b/synapse/rest/__init__.py @@ -46,6 +46,7 @@ from synapse.rest.client.v2_alpha import ( receipts, register, report_event, + room_keys, sendtodevice, sync, tags, @@ -102,6 +103,7 @@ class ClientRestResource(JsonResource): auth.register_servlets(hs, client_resource) receipts.register_servlets(hs, client_resource) read_marker.register_servlets(hs, client_resource) + room_keys.register_servlets(hs, client_resource) keys.register_servlets(hs, client_resource) tokenrefresh.register_servlets(hs, client_resource) tags.register_servlets(hs, client_resource) diff --git a/synapse/rest/client/v2_alpha/room_keys.py b/synapse/rest/client/v2_alpha/room_keys.py index 7291018a48..010aed98f9 100644 --- a/synapse/rest/client/v2_alpha/room_keys.py +++ b/synapse/rest/client/v2_alpha/room_keys.py @@ -17,26 +17,25 @@ import logging from twisted.internet import defer -from synapse.api.errors import SynapseError from synapse.http.servlet import ( - RestServlet, parse_json_object_from_request, parse_integer + RestServlet, parse_json_object_from_request ) -from synapse.http.servlet import parse_string -from synapse.types import StreamToken from ._base import client_v2_patterns logger = logging.getLogger(__name__) class RoomKeysServlet(RestServlet): - PATTERNS = client_v2_patterns("/room_keys/keys(/(?P[^/]+))?(/(?P[^/]+))?$") + PATTERNS = client_v2_patterns( + "/room_keys/keys(/(?P[^/]+))?(/(?P[^/]+))?$" + ) def __init__(self, hs): """ Args: hs (synapse.server.HomeServer): server """ - super(RoomKeysUploadServlet, self).__init__() + super(RoomKeysServlet, self).__init__() self.auth = hs.get_auth() self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler() @@ -45,24 +44,32 @@ class RoomKeysServlet(RestServlet): requester = yield self.auth.get_user_by_req(request, allow_guest=False) user_id = requester.user.to_string() body = parse_json_object_from_request(request) - version = request.args.get("version", None) + version = request.args.get("version")[0] if session_id: - body = { "sessions": { session_id : body } } + body = { + "sessions": { + session_id: body + } + } if room_id: - body = { "rooms": { room_id : body } } + body = { + "rooms": { + room_id: body + } + } - result = yield self.e2e_room_keys_handler.upload_room_keys( + yield self.e2e_room_keys_handler.upload_room_keys( user_id, version, body ) - defer.returnValue((200, result)) + defer.returnValue((200, {})) @defer.inlineCallbacks def on_GET(self, request, room_id, session_id): requester = yield self.auth.get_user_by_req(request, allow_guest=False) user_id = requester.user.to_string() - version = request.args.get("version", None) + version = request.args.get("version")[0] room_keys = yield self.e2e_room_keys_handler.get_room_keys( user_id, version, room_id, session_id @@ -73,7 +80,7 @@ class RoomKeysServlet(RestServlet): def on_DELETE(self, request, room_id, session_id): requester = yield self.auth.get_user_by_req(request, allow_guest=False) user_id = requester.user.to_string() - version = request.args.get("version", None) + version = request.args.get("version")[0] yield self.e2e_room_keys_handler.delete_room_keys( user_id, version, room_id, session_id diff --git a/synapse/server.py b/synapse/server.py index 140be9ebe8..706cb1361f 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -49,6 +49,7 @@ from synapse.handlers.deactivate_account import DeactivateAccountHandler from synapse.handlers.device import DeviceHandler from synapse.handlers.devicemessage import DeviceMessageHandler from synapse.handlers.e2e_keys import E2eKeysHandler +from synapse.handlers.e2e_room_keys import E2eRoomKeysHandler from synapse.handlers.events import EventHandler, EventStreamHandler from synapse.handlers.groups_local import GroupsLocalHandler from synapse.handlers.initial_sync import InitialSyncHandler @@ -127,6 +128,7 @@ class HomeServer(object): 'auth_handler', 'device_handler', 'e2e_keys_handler', + 'e2e_room_keys_handler', 'event_handler', 'event_stream_handler', 'initial_sync_handler', @@ -288,6 +290,9 @@ class HomeServer(object): def build_e2e_keys_handler(self): return E2eKeysHandler(self) + def build_e2e_room_keys_handler(self): + return E2eRoomKeysHandler(self) + def build_application_service_api(self): return ApplicationServiceApi(self) diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 134e4a80f1..69cb28268a 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -30,6 +30,7 @@ from .appservice import ApplicationServiceStore, ApplicationServiceTransactionSt from .client_ips import ClientIpStore from .deviceinbox import DeviceInboxStore from .directory import DirectoryStore +from .e2e_room_keys import EndToEndRoomKeyStore from .end_to_end_keys import EndToEndKeyStore from .engines import PostgresEngine from .event_federation import EventFederationStore @@ -76,6 +77,7 @@ class DataStore(RoomMemberStore, RoomStore, ApplicationServiceTransactionStore, ReceiptsStore, EndToEndKeyStore, + EndToEndRoomKeyStore, SearchStore, TagsStore, AccountDataStore, diff --git a/synapse/storage/e2e_room_keys.py b/synapse/storage/e2e_room_keys.py index 903dc083f8..5982710bd5 100644 --- a/synapse/storage/e2e_room_keys.py +++ b/synapse/storage/e2e_room_keys.py @@ -15,11 +15,6 @@ from twisted.internet import defer -from synapse.util.caches.descriptors import cached - -from canonicaljson import encode_canonical_json -import ujson as json - from ._base import SQLBaseStore @@ -45,29 +40,27 @@ class EndToEndRoomKeyStore(SQLBaseStore): desc="get_e2e_room_key", ) - defer.returnValue(row); + defer.returnValue(row) def set_e2e_room_key(self, user_id, version, room_id, session_id, room_key): def _set_e2e_room_key_txn(txn): - self._simple_upsert( + self._simple_upsert_txn( txn, table="e2e_room_keys", keyvalues={ "user_id": user_id, "room_id": room_id, - "session_id": session_id, - } - values=[ - { - "version": version, - "first_message_index": room_key['first_message_index'], - "forwarded_count": room_key['forwarded_count'], - "is_verified": room_key['is_verified'], - "session_data": room_key['session_data'], - } - ], + "session_id": session_id, + }, + values={ + "version": version, + "first_message_index": room_key['first_message_index'], + "forwarded_count": room_key['forwarded_count'], + "is_verified": room_key['is_verified'], + "session_data": room_key['session_data'], + }, lock=False, ) @@ -77,7 +70,6 @@ class EndToEndRoomKeyStore(SQLBaseStore): "set_e2e_room_key", _set_e2e_room_key_txn ) - # XXX: this isn't currently used and isn't tested anywhere # it could be used in future for bulk-uploading new versions of room_keys # for a user or something though. @@ -85,23 +77,27 @@ class EndToEndRoomKeyStore(SQLBaseStore): def _set_e2e_room_keys_txn(txn): + values = [] + for room_id in room_keys['rooms']: + for session_id in room_keys['rooms'][room_id]['sessions']: + session = room_keys['rooms'][room_id]['sessions'][session_id] + values.append( + { + "user_id": user_id, + "room_id": room_id, + "session_id": session_id, + "version": version, + "first_message_index": session['first_message_index'], + "forwarded_count": session['forwarded_count'], + "is_verified": session['is_verified'], + "session_data": session['session_data'], + } + ) + self._simple_insert_many_txn( txn, table="e2e_room_keys", - values=[ - { - "user_id": user_id, - "room_id": room_id, - "session_id": session_id, - "version": version, - "first_message_index": room_keys['rooms'][room_id]['sessions'][session_id]['first_message_index'], - "forwarded_count": room_keys['rooms'][room_id]['sessions'][session_id]['forwarded_count'], - "is_verified": room_keys['rooms'][room_id]['sessions'][session_id]['is_verified'], - "session_data": room_keys['rooms'][room_id]['sessions'][session_id]['session_data'], - } - for session_id in room_keys['rooms'][room_id]['sessions'] - for room_id in room_keys['rooms'] - ] + values=values ) return True @@ -113,17 +109,22 @@ class EndToEndRoomKeyStore(SQLBaseStore): @defer.inlineCallbacks def get_e2e_room_keys(self, user_id, version, room_id, session_id): - keyvalues={ + keyvalues = { "user_id": user_id, "version": version, } - if room_id: keyvalues['room_id'] = room_id - if session_id: keyvalues['session_id'] = session_id + if room_id: + keyvalues['room_id'] = room_id + if session_id: + keyvalues['session_id'] = session_id rows = yield self._simple_select_list( table="e2e_room_keys", keyvalues=keyvalues, retcols=( + "user_id", + "room_id", + "session_id", "first_message_index", "forwarded_count", "is_verified", @@ -132,19 +133,37 @@ class EndToEndRoomKeyStore(SQLBaseStore): desc="get_e2e_room_keys", ) - sessions = {} - sessions['rooms'][roomId]['sessions'][session_id] = row for row in rows; - defer.returnValue(sessions); + # perlesque autovivification from https://stackoverflow.com/a/19829714/6764493 + class AutoVivification(dict): + def __getitem__(self, item): + try: + return dict.__getitem__(self, item) + except KeyError: + value = self[item] = type(self)() + return value + + sessions = AutoVivification() + for row in rows: + sessions['rooms'][row['room_id']]['sessions'][row['session_id']] = { + "first_message_index": row["first_message_index"], + "forwarded_count": row["forwarded_count"], + "is_verified": row["is_verified"], + "session_data": row["session_data"], + } + + defer.returnValue(sessions) @defer.inlineCallbacks def delete_e2e_room_keys(self, user_id, version, room_id, session_id): - keyvalues={ + keyvalues = { "user_id": user_id, "version": version, } - if room_id: keyvalues['room_id'] = room_id - if session_id: keyvalues['session_id'] = session_id + if room_id: + keyvalues['room_id'] = room_id + if session_id: + keyvalues['session_id'] = session_id yield self._simple_delete( table="e2e_room_keys", diff --git a/synapse/storage/schema/delta/46/e2e_room_keys.sql b/synapse/storage/schema/delta/46/e2e_room_keys.sql index 51b826e8b3..6b344c5ad7 100644 --- a/synapse/storage/schema/delta/46/e2e_room_keys.sql +++ b/synapse/storage/schema/delta/46/e2e_room_keys.sql @@ -29,7 +29,7 @@ CREATE UNIQUE INDEX e2e_room_keys_user_idx ON e2e_room_keys(user_id); CREATE UNIQUE INDEX e2e_room_keys_room_idx ON e2e_room_keys(room_id); CREATE UNIQUE INDEX e2e_room_keys_session_idx ON e2e_room_keys(session_id); --- the versioning metadata about versions of users' encrypted e2e session backups +-- the metadata for each generation of encrypted e2e session backups CREATE TABLE e2e_room_key_versions ( user_id TEXT NOT NULL, version INT NOT NULL, -- cgit 1.4.1 From 8ae64b270f7742abfe4cf0b8d140c81993464ea4 Mon Sep 17 00:00:00 2001 From: Matthew Hodgson Date: Wed, 6 Dec 2017 01:02:57 +0000 Subject: implement /room_keys/version too (untested) --- synapse/api/errors.py | 25 ++++++++++ synapse/handlers/e2e_room_keys.py | 47 +++++++++++++++++-- synapse/rest/client/v2_alpha/room_keys.py | 47 +++++++++++++++++++ synapse/storage/e2e_room_keys.py | 56 +++++++++++++++++++++++ synapse/storage/schema/delta/46/e2e_room_keys.sql | 2 +- 5 files changed, 171 insertions(+), 6 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/api/errors.py b/synapse/api/errors.py index b41d595059..8c97e91ba1 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -56,6 +56,7 @@ class Codes(object): CONSENT_NOT_GIVEN = "M_CONSENT_NOT_GIVEN" CANNOT_LEAVE_SERVER_NOTICE_ROOM = "M_CANNOT_LEAVE_SERVER_NOTICE_ROOM" MAU_LIMIT_EXCEEDED = "M_MAU_LIMIT_EXCEEDED" + WRONG_ROOM_KEYS_VERSION = "M_WRONG_ROOM_KEYS_VERSION" class CodeMessageException(RuntimeError): @@ -285,6 +286,30 @@ class LimitExceededError(SynapseError): ) +class RoomKeysVersionError(SynapseError): + """A client has tried to upload to a non-current version of the room_keys store + """ + def __init__(self, code=403, msg="Wrong room_keys version", current_version=None, + errcode=Codes.WRONG_ROOM_KEYS_VERSION): + super(RoomKeysVersionError, self).__init__(code, msg, errcode) + self.current_version = current_version + + def error_dict(self): + return cs_error( + self.msg, + self.errcode, + current_version=self.current_version, + ) + + +def cs_exception(exception): + if isinstance(exception, CodeMessageException): + return exception.error_dict() + else: + logger.error("Unknown exception type: %s", type(exception)) + return {} + + def cs_error(msg, code=Codes.UNKNOWN, **kwargs): """ Utility method for constructing an error response for client-server interactions. diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py index 93f4ad5194..4333ca610c 100644 --- a/synapse/handlers/e2e_room_keys.py +++ b/synapse/handlers/e2e_room_keys.py @@ -17,7 +17,7 @@ import logging from twisted.internet import defer -from synapse.api.errors import StoreError +from synapse.api.errors import StoreError, SynapseError, RoomKeysVersionError from synapse.util.async import Linearizer logger = logging.getLogger(__name__) @@ -30,10 +30,13 @@ class E2eRoomKeysHandler(object): @defer.inlineCallbacks def get_room_keys(self, user_id, version, room_id, session_id): - results = yield self.store.get_e2e_room_keys( - user_id, version, room_id, session_id - ) - defer.returnValue(results) + # we deliberately take the lock to get keys so that changing the version + # works atomically + with (yield self._upload_linearizer.queue(user_id)): + results = yield self.store.get_e2e_room_keys( + user_id, version, room_id, session_id + ) + defer.returnValue(results) @defer.inlineCallbacks def delete_room_keys(self, user_id, version, room_id, session_id): @@ -44,6 +47,16 @@ class E2eRoomKeysHandler(object): # TODO: Validate the JSON to make sure it has the right keys. + # Check that the version we're trying to upload is the current version + try: + version_info = yield self.get_version_info(user_id, version) + except StoreError as e: + if e.code == 404: + raise SynapseError(404, "Version '%d' not found" % (version,)) + + if version_info.version != version: + raise RoomKeysVersionError(current_version=version_info.version) + # XXX: perhaps we should use a finer grained lock here? with (yield self._upload_linearizer.queue(user_id)): @@ -91,3 +104,27 @@ class E2eRoomKeysHandler(object): yield self.store.set_e2e_room_key( user_id, version, room_id, session_id, room_key ) + + @defer.inlineCallbacks + def create_version(self, user_id, version, version_info): + + # TODO: Validate the JSON to make sure it has the right keys. + + # lock everyone out until we've switched version + with (yield self._upload_linearizer.queue(user_id)): + yield self.store.create_version( + user_id, version, version_info + ) + + @defer.inlineCallbacks + def get_version_info(self, user_id, version): + with (yield self._upload_linearizer.queue(user_id)): + results = yield self.store.get_e2e_room_key_version( + user_id, version + ) + defer.returnValue(results) + + @defer.inlineCallbacks + def delete_version(self, user_id, version): + with (yield self._upload_linearizer.queue(user_id)): + yield self.store.delete_e2e_room_key_version(user_id, version) diff --git a/synapse/rest/client/v2_alpha/room_keys.py b/synapse/rest/client/v2_alpha/room_keys.py index be82eccb2b..4d76e1d824 100644 --- a/synapse/rest/client/v2_alpha/room_keys.py +++ b/synapse/rest/client/v2_alpha/room_keys.py @@ -221,5 +221,52 @@ class RoomKeysServlet(RestServlet): defer.returnValue((200, {})) +class RoomKeysVersionServlet(RestServlet): + PATTERNS = client_v2_patterns( + "/room_keys/version(/(?P[^/]+))?$" + ) + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + super(RoomKeysVersionServlet, self).__init__() + self.auth = hs.get_auth() + self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler() + + @defer.inlineCallbacks + def on_POST(self, request, version): + requester = yield self.auth.get_user_by_req(request, allow_guest=False) + user_id = requester.user.to_string() + info = parse_json_object_from_request(request) + + new_version = yield self.e2e_room_keys_handler.create_version( + user_id, version, info + ) + defer.returnValue((200, {"version": new_version})) + + @defer.inlineCallbacks + def on_GET(self, request, version): + requester = yield self.auth.get_user_by_req(request, allow_guest=False) + user_id = requester.user.to_string() + + info = yield self.e2e_room_keys_handler.get_version_info( + user_id, version + ) + defer.returnValue((200, info)) + + @defer.inlineCallbacks + def on_DELETE(self, request, version): + requester = yield self.auth.get_user_by_req(request, allow_guest=False) + user_id = requester.user.to_string() + + yield self.e2e_room_keys_handler.delete_version( + user_id, version + ) + defer.returnValue((200, {})) + + def register_servlets(hs, http_server): RoomKeysServlet(hs).register(http_server) + RoomKeysVersionServlet(hs).register(http_server) diff --git a/synapse/storage/e2e_room_keys.py b/synapse/storage/e2e_room_keys.py index 5982710bd5..994878acf6 100644 --- a/synapse/storage/e2e_room_keys.py +++ b/synapse/storage/e2e_room_keys.py @@ -170,3 +170,59 @@ class EndToEndRoomKeyStore(SQLBaseStore): keyvalues=keyvalues, desc="delete_e2e_room_keys", ) + + @defer.inlineCallbacks + def get_e2e_room_key_version(self, user_id, version): + + row = yield self._simple_select_one( + table="e2e_room_key_versions", + keyvalues={ + "user_id": user_id, + "version": version, + }, + retcols=( + "user_id", + "version", + "algorithm", + "auth_data", + ), + desc="get_e2e_room_key_version_info", + ) + + defer.returnValue(row) + + def create_e2e_room_key_version(self, user_id, version, info): + + def _create_e2e_room_key_version_txn(txn): + + self._simple_insert_txn( + txn, + table="e2e_room_key_versions", + values={ + "user_id": user_id, + "version": version, + "algorithm": info["algorithm"], + "auth_data": info["auth_data"], + }, + lock=False, + ) + + return True + + return self.runInteraction( + "create_e2e_room_key_version_txn", _create_e2e_room_key_version_txn + ) + + @defer.inlineCallbacks + def delete_e2e_room_key_version(self, user_id, version): + + keyvalues = { + "user_id": user_id, + "version": version, + } + + yield self._simple_delete( + table="e2e_room_key_versions", + keyvalues=keyvalues, + desc="delete_e2e_room_key_version", + ) diff --git a/synapse/storage/schema/delta/46/e2e_room_keys.sql b/synapse/storage/schema/delta/46/e2e_room_keys.sql index 6b344c5ad7..463f828c66 100644 --- a/synapse/storage/schema/delta/46/e2e_room_keys.sql +++ b/synapse/storage/schema/delta/46/e2e_room_keys.sql @@ -34,7 +34,7 @@ CREATE TABLE e2e_room_key_versions ( user_id TEXT NOT NULL, version INT NOT NULL, algorithm TEXT NOT NULL, - dummy_session_data TEXT NOT NULL + auth_data TEXT NOT NULL ); CREATE UNIQUE INDEX e2e_room_key_user_idx ON e2e_room_keys(user_id); -- cgit 1.4.1 From 69e51c7ba48a84b48ab64c8c290a232d14193a18 Mon Sep 17 00:00:00 2001 From: Matthew Hodgson Date: Wed, 6 Dec 2017 10:02:49 +0100 Subject: make /room_keys/version work --- synapse/handlers/e2e_room_keys.py | 18 +++++++++++------- synapse/rest/client/v2_alpha/room_keys.py | 9 ++++++++- synapse/storage/e2e_room_keys.py | 22 +++++++++++++++++----- synapse/storage/schema/delta/46/e2e_room_keys.sql | 4 ++-- 4 files changed, 38 insertions(+), 15 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py index 4333ca610c..bd58be6558 100644 --- a/synapse/handlers/e2e_room_keys.py +++ b/synapse/handlers/e2e_room_keys.py @@ -48,13 +48,16 @@ class E2eRoomKeysHandler(object): # TODO: Validate the JSON to make sure it has the right keys. # Check that the version we're trying to upload is the current version + try: version_info = yield self.get_version_info(user_id, version) except StoreError as e: if e.code == 404: - raise SynapseError(404, "Version '%d' not found" % (version,)) + raise SynapseError(404, "Version '%s' not found" % (version,)) + else: + raise e - if version_info.version != version: + if version_info['version'] != version: raise RoomKeysVersionError(current_version=version_info.version) # XXX: perhaps we should use a finer grained lock here? @@ -81,7 +84,7 @@ class E2eRoomKeysHandler(object): if e.code == 404: pass else: - raise + raise e # check whether we merge or not. spelling it out with if/elifs rather # than lots of booleans for legibility. @@ -106,20 +109,21 @@ class E2eRoomKeysHandler(object): ) @defer.inlineCallbacks - def create_version(self, user_id, version, version_info): + def create_version(self, user_id, version_info): # TODO: Validate the JSON to make sure it has the right keys. # lock everyone out until we've switched version with (yield self._upload_linearizer.queue(user_id)): - yield self.store.create_version( - user_id, version, version_info + new_version = yield self.store.create_e2e_room_key_version( + user_id, version_info ) + defer.returnValue(new_version) @defer.inlineCallbacks def get_version_info(self, user_id, version): with (yield self._upload_linearizer.queue(user_id)): - results = yield self.store.get_e2e_room_key_version( + results = yield self.store.get_e2e_room_key_version_info( user_id, version ) defer.returnValue(results) diff --git a/synapse/rest/client/v2_alpha/room_keys.py b/synapse/rest/client/v2_alpha/room_keys.py index 4d76e1d824..128b732fb1 100644 --- a/synapse/rest/client/v2_alpha/room_keys.py +++ b/synapse/rest/client/v2_alpha/room_keys.py @@ -17,6 +17,7 @@ import logging from twisted.internet import defer +from synapse.api.errors import SynapseError from synapse.http.servlet import ( RestServlet, parse_json_object_from_request ) @@ -237,15 +238,21 @@ class RoomKeysVersionServlet(RestServlet): @defer.inlineCallbacks def on_POST(self, request, version): + if version: + raise SynapseError(405, "Cannot POST to a specific version") + requester = yield self.auth.get_user_by_req(request, allow_guest=False) user_id = requester.user.to_string() info = parse_json_object_from_request(request) new_version = yield self.e2e_room_keys_handler.create_version( - user_id, version, info + user_id, info ) defer.returnValue((200, {"version": new_version})) + # we deliberately don't have a PUT /version, as these things really should + # be immutable to avoid people footgunning + @defer.inlineCallbacks def on_GET(self, request, version): requester = yield self.auth.get_user_by_req(request, allow_guest=False) diff --git a/synapse/storage/e2e_room_keys.py b/synapse/storage/e2e_room_keys.py index 994878acf6..8efca11a8c 100644 --- a/synapse/storage/e2e_room_keys.py +++ b/synapse/storage/e2e_room_keys.py @@ -172,7 +172,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): ) @defer.inlineCallbacks - def get_e2e_room_key_version(self, user_id, version): + def get_e2e_room_key_version_info(self, user_id, version): row = yield self._simple_select_one( table="e2e_room_key_versions", @@ -191,23 +191,35 @@ class EndToEndRoomKeyStore(SQLBaseStore): defer.returnValue(row) - def create_e2e_room_key_version(self, user_id, version, info): + def create_e2e_room_key_version(self, user_id, info): + """Atomically creates a new version of this user's e2e_room_keys store + with the given version info. + """ def _create_e2e_room_key_version_txn(txn): + txn.execute( + "SELECT MAX(version) FROM e2e_room_key_versions WHERE user_id=?", + (user_id,) + ) + current_version = txn.fetchone()[0] + if current_version is None: + current_version = 0 + + new_version = current_version + 1 + self._simple_insert_txn( txn, table="e2e_room_key_versions", values={ "user_id": user_id, - "version": version, + "version": new_version, "algorithm": info["algorithm"], "auth_data": info["auth_data"], }, - lock=False, ) - return True + return new_version return self.runInteraction( "create_e2e_room_key_version_txn", _create_e2e_room_key_version_txn diff --git a/synapse/storage/schema/delta/46/e2e_room_keys.sql b/synapse/storage/schema/delta/46/e2e_room_keys.sql index 463f828c66..0d2a85fbe6 100644 --- a/synapse/storage/schema/delta/46/e2e_room_keys.sql +++ b/synapse/storage/schema/delta/46/e2e_room_keys.sql @@ -18,7 +18,7 @@ CREATE TABLE e2e_room_keys ( user_id TEXT NOT NULL, room_id TEXT NOT NULL, session_id TEXT NOT NULL, - version INT NOT NULL, + version TEXT NOT NULL, first_message_index INT, forwarded_count INT, is_verified BOOLEAN, @@ -32,7 +32,7 @@ CREATE UNIQUE INDEX e2e_room_keys_session_idx ON e2e_room_keys(session_id); -- the metadata for each generation of encrypted e2e session backups CREATE TABLE e2e_room_key_versions ( user_id TEXT NOT NULL, - version INT NOT NULL, + version TEXT NOT NULL, algorithm TEXT NOT NULL, auth_data TEXT NOT NULL ); -- cgit 1.4.1 From 0abb205b47158a4160ddceb317c0245d640b6e3f Mon Sep 17 00:00:00 2001 From: Matthew Hodgson Date: Mon, 18 Dec 2017 01:52:46 +0000 Subject: blindly incorporate PR review - needs testing & fixing --- synapse/api/errors.py | 11 ++- synapse/handlers/e2e_room_keys.py | 88 +++++++++++++++-------- synapse/rest/client/v2_alpha/room_keys.py | 2 + synapse/storage/e2e_room_keys.py | 69 ++++++++---------- synapse/storage/schema/delta/46/e2e_room_keys.sql | 8 +-- 5 files changed, 99 insertions(+), 79 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/api/errors.py b/synapse/api/errors.py index 8c97e91ba1..d37bcb4082 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -289,9 +289,14 @@ class LimitExceededError(SynapseError): class RoomKeysVersionError(SynapseError): """A client has tried to upload to a non-current version of the room_keys store """ - def __init__(self, code=403, msg="Wrong room_keys version", current_version=None, - errcode=Codes.WRONG_ROOM_KEYS_VERSION): - super(RoomKeysVersionError, self).__init__(code, msg, errcode) + def __init__(self, current_version): + """ + Args: + current_version (str): the current version of the store they should have used + """ + super(RoomKeysVersionError, self).__init__( + 403, "Wrong room_keys version", Codes.WRONG_ROOM_KEYS_VERSION + ) self.current_version = current_version def error_dict(self): diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py index bd58be6558..dda31fdd24 100644 --- a/synapse/handlers/e2e_room_keys.py +++ b/synapse/handlers/e2e_room_keys.py @@ -24,8 +24,21 @@ logger = logging.getLogger(__name__) class E2eRoomKeysHandler(object): + """ + Implements an optional realtime backup mechanism for encrypted E2E megolm room keys. + This gives a way for users to store and recover their megolm keys if they lose all + their clients. It should also extend easily to future room key mechanisms. + The actual payload of the encrypted keys is completely opaque to the handler. + """ + def __init__(self, hs): self.store = hs.get_datastore() + + # Used to lock whenever a client is uploading key data. This prevents collisions + # between clients trying to upload the details of a new session, given all + # clients belonging to a user will receive and try to upload a new session at + # roughly the same time. Also used to lock out uploads when the key is being + # changed. self._upload_linearizer = Linearizer("upload_room_keys_lock") @defer.inlineCallbacks @@ -40,33 +53,34 @@ class E2eRoomKeysHandler(object): @defer.inlineCallbacks def delete_room_keys(self, user_id, version, room_id, session_id): - yield self.store.delete_e2e_room_keys(user_id, version, room_id, session_id) + # lock for consistency with uploading + with (yield self._upload_linearizer.queue(user_id)): + yield self.store.delete_e2e_room_keys(user_id, version, room_id, session_id) @defer.inlineCallbacks def upload_room_keys(self, user_id, version, room_keys): # TODO: Validate the JSON to make sure it has the right keys. - # Check that the version we're trying to upload is the current version - - try: - version_info = yield self.get_version_info(user_id, version) - except StoreError as e: - if e.code == 404: - raise SynapseError(404, "Version '%s' not found" % (version,)) - else: - raise e - - if version_info['version'] != version: - raise RoomKeysVersionError(current_version=version_info.version) - # XXX: perhaps we should use a finer grained lock here? with (yield self._upload_linearizer.queue(user_id)): - - # go through the room_keys - for room_id in room_keys['rooms']: - for session_id in room_keys['rooms'][room_id]['sessions']: - room_key = room_keys['rooms'][room_id]['sessions'][session_id] + # Check that the version we're trying to upload is the current version + try: + version_info = yield self.get_version_info(user_id, version) + except StoreError as e: + if e.code == 404: + raise SynapseError(404, "Version '%s' not found" % (version,)) + else: + raise e + + if version_info['version'] != version: + raise RoomKeysVersionError(current_version=version_info.version) + + # go through the room_keys. + # XXX: this should/could be done concurrently, given we're in a lock. + for room_id, room in room_keys['rooms'].iteritems(): + for session_id, session in room['sessions'].iteritems(): + room_key = session[session_id] yield self._upload_room_key( user_id, version, room_id, session_id, room_key @@ -86,10 +100,29 @@ class E2eRoomKeysHandler(object): else: raise e - # check whether we merge or not. spelling it out with if/elifs rather - # than lots of booleans for legibility. - upsert = True + if _should_replace_room_key(current_room_key, room_key): + yield self.store.set_e2e_room_key( + user_id, version, room_id, session_id, room_key + ) + + def _should_replace_room_key(current_room_key, room_key): + """ + Determine whether to replace the current_room_key in our backup for this + session (if any) with a new room_key that has been uploaded. + + Args: + current_room_key (dict): Optional, the current room_key dict if any + room_key (dict): The new room_key dict which may or may not be fit to + replace the current_room_key + + Returns: + True if current_room_key should be replaced by room_key in the backup + """ + if current_room_key: + # spelt out with if/elifs rather than nested boolean expressions + # purely for legibility. + if room_key['is_verified'] and not current_room_key['is_verified']: pass elif ( @@ -97,16 +130,11 @@ class E2eRoomKeysHandler(object): current_room_key['first_message_index'] ): pass - elif room_key['forwarded_count'] < room_key['forwarded_count']: + elif room_key['forwarded_count'] < current_room_key['forwarded_count']: pass else: - upsert = False - - # if so, we set the new room_key - if upsert: - yield self.store.set_e2e_room_key( - user_id, version, room_id, session_id, room_key - ) + return False + return True @defer.inlineCallbacks def create_version(self, user_id, version_info): diff --git a/synapse/rest/client/v2_alpha/room_keys.py b/synapse/rest/client/v2_alpha/room_keys.py index 128b732fb1..70b7b4573f 100644 --- a/synapse/rest/client/v2_alpha/room_keys.py +++ b/synapse/rest/client/v2_alpha/room_keys.py @@ -68,6 +68,8 @@ class RoomKeysServlet(RestServlet): * lower forwarded_count always wins over higher forwarded_count We trust the clients not to lie and corrupt their own backups. + It also means that if your access_token is stolen, the attacker could + delete your backup. POST /room_keys/keys/!abc:matrix.org/c0ff33?version=1 HTTP/1.1 Content-Type: application/json diff --git a/synapse/storage/e2e_room_keys.py b/synapse/storage/e2e_room_keys.py index 8efca11a8c..c11417c415 100644 --- a/synapse/storage/e2e_room_keys.py +++ b/synapse/storage/e2e_room_keys.py @@ -44,30 +44,21 @@ class EndToEndRoomKeyStore(SQLBaseStore): def set_e2e_room_key(self, user_id, version, room_id, session_id, room_key): - def _set_e2e_room_key_txn(txn): - - self._simple_upsert_txn( - txn, - table="e2e_room_keys", - keyvalues={ - "user_id": user_id, - "room_id": room_id, - "session_id": session_id, - }, - values={ - "version": version, - "first_message_index": room_key['first_message_index'], - "forwarded_count": room_key['forwarded_count'], - "is_verified": room_key['is_verified'], - "session_data": room_key['session_data'], - }, - lock=False, - ) - - return True - - return self.runInteraction( - "set_e2e_room_key", _set_e2e_room_key_txn + yield self._simple_upsert( + table="e2e_room_keys", + keyvalues={ + "user_id": user_id, + "room_id": room_id, + "session_id": session_id, + }, + values={ + "version": version, + "first_message_index": room_key['first_message_index'], + "forwarded_count": room_key['forwarded_count'], + "is_verified": room_key['is_verified'], + "session_data": room_key['session_data'], + }, + lock=False, ) # XXX: this isn't currently used and isn't tested anywhere @@ -107,7 +98,9 @@ class EndToEndRoomKeyStore(SQLBaseStore): ) @defer.inlineCallbacks - def get_e2e_room_keys(self, user_id, version, room_id, session_id): + def get_e2e_room_keys( + self, user_id, version, room_id=room_id, session_id=session_id + ): keyvalues = { "user_id": user_id, @@ -115,8 +108,8 @@ class EndToEndRoomKeyStore(SQLBaseStore): } if room_id: keyvalues['room_id'] = room_id - if session_id: - keyvalues['session_id'] = session_id + if session_id: + keyvalues['session_id'] = session_id rows = yield self._simple_select_list( table="e2e_room_keys", @@ -133,18 +126,10 @@ class EndToEndRoomKeyStore(SQLBaseStore): desc="get_e2e_room_keys", ) - # perlesque autovivification from https://stackoverflow.com/a/19829714/6764493 - class AutoVivification(dict): - def __getitem__(self, item): - try: - return dict.__getitem__(self, item) - except KeyError: - value = self[item] = type(self)() - return value - - sessions = AutoVivification() + sessions = {} for row in rows: - sessions['rooms'][row['room_id']]['sessions'][row['session_id']] = { + room_entry = sessions['rooms'].setdefault(row['room_id'], {"sessions": {}}) + room_entry['sessions'][row['session_id']] = { "first_message_index": row["first_message_index"], "forwarded_count": row["forwarded_count"], "is_verified": row["is_verified"], @@ -154,7 +139,9 @@ class EndToEndRoomKeyStore(SQLBaseStore): defer.returnValue(sessions) @defer.inlineCallbacks - def delete_e2e_room_keys(self, user_id, version, room_id, session_id): + def delete_e2e_room_keys( + self, user_id, version, room_id=room_id, session_id=session_id + ): keyvalues = { "user_id": user_id, @@ -162,8 +149,8 @@ class EndToEndRoomKeyStore(SQLBaseStore): } if room_id: keyvalues['room_id'] = room_id - if session_id: - keyvalues['session_id'] = session_id + if session_id: + keyvalues['session_id'] = session_id yield self._simple_delete( table="e2e_room_keys", diff --git a/synapse/storage/schema/delta/46/e2e_room_keys.sql b/synapse/storage/schema/delta/46/e2e_room_keys.sql index 0d2a85fbe6..16499ac34c 100644 --- a/synapse/storage/schema/delta/46/e2e_room_keys.sql +++ b/synapse/storage/schema/delta/46/e2e_room_keys.sql @@ -25,16 +25,14 @@ CREATE TABLE e2e_room_keys ( session_data TEXT NOT NULL ); -CREATE UNIQUE INDEX e2e_room_keys_user_idx ON e2e_room_keys(user_id); -CREATE UNIQUE INDEX e2e_room_keys_room_idx ON e2e_room_keys(room_id); -CREATE UNIQUE INDEX e2e_room_keys_session_idx ON e2e_room_keys(session_id); +CREATE UNIQUE INDEX e2e_room_keys_idx ON e2e_room_keys(user_id, room_id, session_id); -- the metadata for each generation of encrypted e2e session backups -CREATE TABLE e2e_room_key_versions ( +CREATE TABLE e2e_room_keys_versions ( user_id TEXT NOT NULL, version TEXT NOT NULL, algorithm TEXT NOT NULL, auth_data TEXT NOT NULL ); -CREATE UNIQUE INDEX e2e_room_key_user_idx ON e2e_room_keys(user_id); +CREATE UNIQUE INDEX e2e_room_keys_versions_user_idx ON e2e_room_keys_versions(user_id); -- cgit 1.4.1 From cac02537998718c05a561918269745161378dd6f Mon Sep 17 00:00:00 2001 From: Matthew Hodgson Date: Mon, 18 Dec 2017 01:58:53 +0000 Subject: rename room_key_version table correctly, and fix opt args --- synapse/handlers/e2e_room_keys.py | 6 +++--- synapse/storage/e2e_room_keys.py | 26 +++++++++++++------------- 2 files changed, 16 insertions(+), 16 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py index dda31fdd24..87be081b1c 100644 --- a/synapse/handlers/e2e_room_keys.py +++ b/synapse/handlers/e2e_room_keys.py @@ -143,7 +143,7 @@ class E2eRoomKeysHandler(object): # lock everyone out until we've switched version with (yield self._upload_linearizer.queue(user_id)): - new_version = yield self.store.create_e2e_room_key_version( + new_version = yield self.store.create_e2e_room_keys_version( user_id, version_info ) defer.returnValue(new_version) @@ -151,7 +151,7 @@ class E2eRoomKeysHandler(object): @defer.inlineCallbacks def get_version_info(self, user_id, version): with (yield self._upload_linearizer.queue(user_id)): - results = yield self.store.get_e2e_room_key_version_info( + results = yield self.store.get_e2e_room_keys_version_info( user_id, version ) defer.returnValue(results) @@ -159,4 +159,4 @@ class E2eRoomKeysHandler(object): @defer.inlineCallbacks def delete_version(self, user_id, version): with (yield self._upload_linearizer.queue(user_id)): - yield self.store.delete_e2e_room_key_version(user_id, version) + yield self.store.delete_e2e_room_keys_version(user_id, version) diff --git a/synapse/storage/e2e_room_keys.py b/synapse/storage/e2e_room_keys.py index c11417c415..7e1cb13e74 100644 --- a/synapse/storage/e2e_room_keys.py +++ b/synapse/storage/e2e_room_keys.py @@ -99,7 +99,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): @defer.inlineCallbacks def get_e2e_room_keys( - self, user_id, version, room_id=room_id, session_id=session_id + self, user_id, version, room_id=None, session_id=None ): keyvalues = { @@ -140,7 +140,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): @defer.inlineCallbacks def delete_e2e_room_keys( - self, user_id, version, room_id=room_id, session_id=session_id + self, user_id, version, room_id=None, session_id=None ): keyvalues = { @@ -159,10 +159,10 @@ class EndToEndRoomKeyStore(SQLBaseStore): ) @defer.inlineCallbacks - def get_e2e_room_key_version_info(self, user_id, version): + def get_e2e_room_keys_version_info(self, user_id, version): row = yield self._simple_select_one( - table="e2e_room_key_versions", + table="e2e_room_keys_versions", keyvalues={ "user_id": user_id, "version": version, @@ -173,20 +173,20 @@ class EndToEndRoomKeyStore(SQLBaseStore): "algorithm", "auth_data", ), - desc="get_e2e_room_key_version_info", + desc="get_e2e_room_keys_version_info", ) defer.returnValue(row) - def create_e2e_room_key_version(self, user_id, info): + def create_e2e_room_keys_version(self, user_id, info): """Atomically creates a new version of this user's e2e_room_keys store with the given version info. """ - def _create_e2e_room_key_version_txn(txn): + def _create_e2e_room_keys_version_txn(txn): txn.execute( - "SELECT MAX(version) FROM e2e_room_key_versions WHERE user_id=?", + "SELECT MAX(version) FROM e2e_room_keys_versions WHERE user_id=?", (user_id,) ) current_version = txn.fetchone()[0] @@ -197,7 +197,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): self._simple_insert_txn( txn, - table="e2e_room_key_versions", + table="e2e_room_keys_versions", values={ "user_id": user_id, "version": new_version, @@ -209,11 +209,11 @@ class EndToEndRoomKeyStore(SQLBaseStore): return new_version return self.runInteraction( - "create_e2e_room_key_version_txn", _create_e2e_room_key_version_txn + "create_e2e_room_keys_version_txn", _create_e2e_room_keys_version_txn ) @defer.inlineCallbacks - def delete_e2e_room_key_version(self, user_id, version): + def delete_e2e_room_keys_version(self, user_id, version): keyvalues = { "user_id": user_id, @@ -221,7 +221,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): } yield self._simple_delete( - table="e2e_room_key_versions", + table="e2e_room_keys_versions", keyvalues=keyvalues, - desc="delete_e2e_room_key_version", + desc="delete_e2e_room_keys_version", ) -- cgit 1.4.1 From 8d14598e90396c52c9394cf807861921e053d8da Mon Sep 17 00:00:00 2001 From: Matthew Hodgson Date: Sun, 24 Dec 2017 16:44:18 +0000 Subject: add storage docstring; remove unused set_e2e_room_keys --- synapse/storage/e2e_room_keys.py | 119 +++++++++++++++++++++++++++------------ 1 file changed, 83 insertions(+), 36 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/e2e_room_keys.py b/synapse/storage/e2e_room_keys.py index 7e1cb13e74..45d2c9b433 100644 --- a/synapse/storage/e2e_room_keys.py +++ b/synapse/storage/e2e_room_keys.py @@ -22,6 +22,22 @@ class EndToEndRoomKeyStore(SQLBaseStore): @defer.inlineCallbacks def get_e2e_room_key(self, user_id, version, room_id, session_id): + """Get the encrypted E2E room key for a given session from a given + backup version of room_keys. We only store the 'best' room key for a given + session at a given time, as determined by the handler. + + Args: + user_id(str): the user whose backup we're querying + version(str): the version ID of the backup for the set of keys we're querying + room_id(str): the ID of the room whose keys we're querying. + This is a bit redundant as it's implied by the session_id, but + we include for consistency with the rest of the API. + session_id(str): the session whose room_key we're querying. + + Returns: + A deferred dict giving the session_data and message metadata for + this room key. + """ row = yield self._simple_select_one( table="e2e_room_keys", @@ -43,6 +59,17 @@ class EndToEndRoomKeyStore(SQLBaseStore): defer.returnValue(row) def set_e2e_room_key(self, user_id, version, room_id, session_id, room_key): + """Replaces or inserts the encrypted E2E room key for a given session in + a given backup + + Args: + user_id(str): the user whose backup we're setting + version(str): the version ID of the backup we're updating + room_id(str): the ID of the room whose keys we're setting + session_id(str): the session whose room_key we're setting + Raises: + StoreError if stuff goes wrong, probably + """ yield self._simple_upsert( table="e2e_room_keys", @@ -61,46 +88,27 @@ class EndToEndRoomKeyStore(SQLBaseStore): lock=False, ) - # XXX: this isn't currently used and isn't tested anywhere - # it could be used in future for bulk-uploading new versions of room_keys - # for a user or something though. - def set_e2e_room_keys(self, user_id, version, room_keys): - - def _set_e2e_room_keys_txn(txn): - - values = [] - for room_id in room_keys['rooms']: - for session_id in room_keys['rooms'][room_id]['sessions']: - session = room_keys['rooms'][room_id]['sessions'][session_id] - values.append( - { - "user_id": user_id, - "room_id": room_id, - "session_id": session_id, - "version": version, - "first_message_index": session['first_message_index'], - "forwarded_count": session['forwarded_count'], - "is_verified": session['is_verified'], - "session_data": session['session_data'], - } - ) - - self._simple_insert_many_txn( - txn, - table="e2e_room_keys", - values=values - ) - - return True - - return self.runInteraction( - "set_e2e_room_keys", _set_e2e_room_keys_txn - ) - @defer.inlineCallbacks def get_e2e_room_keys( self, user_id, version, room_id=None, session_id=None ): + """Bulk get the E2E room keys for a given backup, optionally filtered to a given + room, or a given session. + + Args: + user_id(str): the user whose backup we're querying + version(str): the version ID of the backup for the set of keys we're querying + room_id(str): Optional. the ID of the room whose keys we're querying, if any. + If not specified, we return the keys for all the rooms in the backup. + session_id(str): Optional. the session whose room_key we're querying, if any. + If specified, we also require the room_id to be specified. + If not specified, we return all the keys in this version of + the backup (or for the specified room) + + Returns: + A deferred list of dicts giving the session_data and message metadata for + these room keys. + """ keyvalues = { "user_id": user_id, @@ -142,6 +150,22 @@ class EndToEndRoomKeyStore(SQLBaseStore): def delete_e2e_room_keys( self, user_id, version, room_id=None, session_id=None ): + """Bulk delete the E2E room keys for a given backup, optionally filtered to a given + room or a given session. + + Args: + user_id(str): the user whose backup we're deleting from + version(str): the version ID of the backup for the set of keys we're deleting + room_id(str): Optional. the ID of the room whose keys we're deleting, if any. + If not specified, we delete the keys for all the rooms in the backup. + session_id(str): Optional. the session whose room_key we're querying, if any. + If specified, we also require the room_id to be specified. + If not specified, we delete all the keys in this version of + the backup (or for the specified room) + + Returns: + A deferred of the deletion transaction + """ keyvalues = { "user_id": user_id, @@ -160,6 +184,15 @@ class EndToEndRoomKeyStore(SQLBaseStore): @defer.inlineCallbacks def get_e2e_room_keys_version_info(self, user_id, version): + """Get info etadata about a given version of our room_keys backup + + Args: + user_id(str): the user whose backup we're querying + version(str): the version ID of the backup we're querying about + + Returns: + A deferred dict giving the info metadata for this backup version + """ row = yield self._simple_select_one( table="e2e_room_keys_versions", @@ -181,6 +214,13 @@ class EndToEndRoomKeyStore(SQLBaseStore): def create_e2e_room_keys_version(self, user_id, info): """Atomically creates a new version of this user's e2e_room_keys store with the given version info. + + Args: + user_id(str): the user whose backup we're creating a version + info(dict): the info about the backup version to be created + + Returns: + A deferred string for the newly created version ID """ def _create_e2e_room_keys_version_txn(txn): @@ -214,6 +254,13 @@ class EndToEndRoomKeyStore(SQLBaseStore): @defer.inlineCallbacks def delete_e2e_room_keys_version(self, user_id, version): + """Delete a given backup version of the user's room keys. + Doesn't delete their actual key data. + + Args: + user_id(str): the user whose backup version we're deleting + version(str): the ID of the backup version we're deleting + """ keyvalues = { "user_id": user_id, -- cgit 1.4.1 From 9f0791b7bd030a70182cd9b33bbe2f78dba705dd Mon Sep 17 00:00:00 2001 From: Matthew Hodgson Date: Wed, 27 Dec 2017 23:35:10 +0000 Subject: add a tonne of docstring; make upload_room_keys properly assert version --- synapse/storage/e2e_room_keys.py | 51 ++++++++++++++++++++++++---------------- 1 file changed, 31 insertions(+), 20 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/e2e_room_keys.py b/synapse/storage/e2e_room_keys.py index 45d2c9b433..e04e6a3690 100644 --- a/synapse/storage/e2e_room_keys.py +++ b/synapse/storage/e2e_room_keys.py @@ -67,6 +67,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): version(str): the version ID of the backup we're updating room_id(str): the ID of the room whose keys we're setting session_id(str): the session whose room_key we're setting + room_key(dict): the room_key being set Raises: StoreError if stuff goes wrong, probably """ @@ -182,35 +183,46 @@ class EndToEndRoomKeyStore(SQLBaseStore): desc="delete_e2e_room_keys", ) - @defer.inlineCallbacks - def get_e2e_room_keys_version_info(self, user_id, version): - """Get info etadata about a given version of our room_keys backup + def get_e2e_room_keys_version_info(self, user_id, version=None): + """Get info metadata about a version of our room_keys backup. Args: user_id(str): the user whose backup we're querying - version(str): the version ID of the backup we're querying about - + version(str): Optional. the version ID of the backup we're querying about + If missing, we return the information about the current version. + Raises: + StoreError: with code 404 if there are no e2e_room_keys_versions present Returns: A deferred dict giving the info metadata for this backup version """ - row = yield self._simple_select_one( - table="e2e_room_keys_versions", - keyvalues={ - "user_id": user_id, - "version": version, - }, - retcols=( - "user_id", - "version", - "algorithm", - "auth_data", - ), + def _get_e2e_room_keys_version_info_txn(txn): + if version is None: + txn.execute( + "SELECT MAX(version) FROM e2e_room_keys_versions WHERE user_id=?", + (user_id,) + ) + version = txn.fetchone()[0] + + return self._simple_select_one_txn( + table="e2e_room_keys_versions", + keyvalues={ + "user_id": user_id, + "version": version, + }, + retcols=( + "user_id", + "version", + "algorithm", + "auth_data", + ), + ) + + return self.runInteraction( desc="get_e2e_room_keys_version_info", + _get_e2e_room_keys_version_info_txn ) - defer.returnValue(row) - def create_e2e_room_keys_version(self, user_id, info): """Atomically creates a new version of this user's e2e_room_keys store with the given version info. @@ -224,7 +236,6 @@ class EndToEndRoomKeyStore(SQLBaseStore): """ def _create_e2e_room_keys_version_txn(txn): - txn.execute( "SELECT MAX(version) FROM e2e_room_keys_versions WHERE user_id=?", (user_id,) -- cgit 1.4.1 From 234611f3472b229d9e3a9ec7a27c51446f6a61ba Mon Sep 17 00:00:00 2001 From: Matthew Hodgson Date: Wed, 27 Dec 2017 23:42:08 +0000 Subject: fix typos --- synapse/handlers/e2e_room_keys.py | 3 ++- synapse/storage/e2e_room_keys.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py index 2fa025bfc7..6446c3c6c3 100644 --- a/synapse/handlers/e2e_room_keys.py +++ b/synapse/handlers/e2e_room_keys.py @@ -110,7 +110,8 @@ class E2eRoomKeysHandler(object): # XXX: perhaps we should use a finer grained lock here? with (yield self._upload_linearizer.queue(user_id)): # Check that the version we're trying to upload is the current version - version_info = yield self.get_current_version_info(user_id) + try: + version_info = yield self.get_current_version_info(user_id) except StoreError as e: if e.code == 404: raise SynapseError(404, "Version '%s' not found" % (version,)) diff --git a/synapse/storage/e2e_room_keys.py b/synapse/storage/e2e_room_keys.py index e04e6a3690..b51faa1204 100644 --- a/synapse/storage/e2e_room_keys.py +++ b/synapse/storage/e2e_room_keys.py @@ -219,7 +219,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): ) return self.runInteraction( - desc="get_e2e_room_keys_version_info", + "get_e2e_room_keys_version_info", _get_e2e_room_keys_version_info_txn ) -- cgit 1.4.1 From 982edca38026e2bb9085e77b4c25af60c4793f34 Mon Sep 17 00:00:00 2001 From: Matthew Hodgson Date: Wed, 27 Dec 2017 23:58:51 +0000 Subject: fix flakes --- synapse/handlers/e2e_room_keys.py | 4 ++-- synapse/storage/e2e_room_keys.py | 6 ++++-- 2 files changed, 6 insertions(+), 4 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py index 6446c3c6c3..fdae69c600 100644 --- a/synapse/handlers/e2e_room_keys.py +++ b/synapse/handlers/e2e_room_keys.py @@ -136,7 +136,7 @@ class E2eRoomKeysHandler(object): """Upload a given room_key for a given room and session into a given version of the backup. Merges the key with any which might already exist. - Args: + Args: user_id(str): the user whose backup we're setting version(str): the version ID of the backup we're updating room_id(str): the ID of the room whose keys we're setting @@ -199,7 +199,7 @@ class E2eRoomKeysHandler(object): backup version for the user's keys; previous backups will no longer be writeable to. - Args: + Args: user_id(str): the user whose backup version we're creating version_info(dict): metadata about the new version being created diff --git a/synapse/storage/e2e_room_keys.py b/synapse/storage/e2e_room_keys.py index b51faa1204..7ab75070a2 100644 --- a/synapse/storage/e2e_room_keys.py +++ b/synapse/storage/e2e_room_keys.py @@ -202,13 +202,15 @@ class EndToEndRoomKeyStore(SQLBaseStore): "SELECT MAX(version) FROM e2e_room_keys_versions WHERE user_id=?", (user_id,) ) - version = txn.fetchone()[0] + this_version = txn.fetchone()[0] + else: + this_version = version return self._simple_select_one_txn( table="e2e_room_keys_versions", keyvalues={ "user_id": user_id, - "version": version, + "version": this_version, }, retcols=( "user_id", -- cgit 1.4.1 From b5eee511c73fdd9b5d1a7453433513791192a250 Mon Sep 17 00:00:00 2001 From: Matthew Hodgson Date: Sun, 31 Dec 2017 14:11:02 +0000 Subject: don't needlessly return user_id --- synapse/storage/e2e_room_keys.py | 1 - 1 file changed, 1 deletion(-) (limited to 'synapse/storage') diff --git a/synapse/storage/e2e_room_keys.py b/synapse/storage/e2e_room_keys.py index 7ab75070a2..3c720f3b3e 100644 --- a/synapse/storage/e2e_room_keys.py +++ b/synapse/storage/e2e_room_keys.py @@ -213,7 +213,6 @@ class EndToEndRoomKeyStore(SQLBaseStore): "version": this_version, }, retcols=( - "user_id", "version", "algorithm", "auth_data", -- cgit 1.4.1 From 15d513f16fd272fb3763a42f05a8f296d4e1abf0 Mon Sep 17 00:00:00 2001 From: Matthew Hodgson Date: Sun, 31 Dec 2017 14:35:25 +0000 Subject: fix idiocies and so make tests pass --- synapse/storage/e2e_room_keys.py | 5 +++-- synapse/storage/schema/delta/46/e2e_room_keys.sql | 2 +- tests/handlers/test_e2e_room_keys.py | 19 +++++++++++-------- 3 files changed, 15 insertions(+), 11 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/e2e_room_keys.py b/synapse/storage/e2e_room_keys.py index 3c720f3b3e..e4d56b7c37 100644 --- a/synapse/storage/e2e_room_keys.py +++ b/synapse/storage/e2e_room_keys.py @@ -207,6 +207,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): this_version = version return self._simple_select_one_txn( + txn, table="e2e_room_keys_versions", keyvalues={ "user_id": user_id, @@ -243,9 +244,9 @@ class EndToEndRoomKeyStore(SQLBaseStore): ) current_version = txn.fetchone()[0] if current_version is None: - current_version = 0 + current_version = '0' - new_version = current_version + 1 + new_version = str(int(current_version) + 1) self._simple_insert_txn( txn, diff --git a/synapse/storage/schema/delta/46/e2e_room_keys.sql b/synapse/storage/schema/delta/46/e2e_room_keys.sql index 16499ac34c..4531fd56ee 100644 --- a/synapse/storage/schema/delta/46/e2e_room_keys.sql +++ b/synapse/storage/schema/delta/46/e2e_room_keys.sql @@ -35,4 +35,4 @@ CREATE TABLE e2e_room_keys_versions ( auth_data TEXT NOT NULL ); -CREATE UNIQUE INDEX e2e_room_keys_versions_user_idx ON e2e_room_keys_versions(user_id); +CREATE UNIQUE INDEX e2e_room_keys_versions_idx ON e2e_room_keys_versions(user_id, version); diff --git a/tests/handlers/test_e2e_room_keys.py b/tests/handlers/test_e2e_room_keys.py index 9d3bef6db2..6f4b3a147a 100644 --- a/tests/handlers/test_e2e_room_keys.py +++ b/tests/handlers/test_e2e_room_keys.py @@ -37,7 +37,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): handlers=None, replication_layer=mock.Mock(), ) - self.handler = synapse.handlers.e2e_keys.E2eRoomKeysHandler(self.hs) + self.handler = synapse.handlers.e2e_room_keys.E2eRoomKeysHandler(self.hs) @defer.inlineCallbacks @@ -46,6 +46,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): if there is no version. """ local_user = "@boris:" + self.hs.hostname + res = None try: res = yield self.handler.get_version_info(local_user); except errors.SynapseError as e: @@ -58,6 +59,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): if it doesn't exist. """ local_user = "@boris:" + self.hs.hostname + res = None try: res = yield self.handler.get_version_info(local_user, "mrflibble"); except errors.SynapseError as e: @@ -69,7 +71,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): """Check that we can create and then retrieve versions. """ local_user = "@boris:" + self.hs.hostname - res = yield self.handler.create_version(user_id, { + res = yield self.handler.create_version(local_user, { "algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data", }); @@ -92,7 +94,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): }); # upload a new one... - res = yield self.handler.create_version(user_id, { + res = yield self.handler.create_version(local_user, { "algorithm": "m.megolm_backup.v1", "auth_data": "second_version_auth_data", }); @@ -111,7 +113,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): """Check that we can create and then delete versions. """ local_user = "@boris:" + self.hs.hostname - res = yield self.handler.create_version(user_id, { + res = yield self.handler.create_version(local_user, { "algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data", }); @@ -121,21 +123,22 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): yield self.handler.delete_version(local_user, "1"); # check that it's gone + res = None try: res = yield self.handler.get_version_info(local_user, "1"); except errors.SynapseError as e: self.assertEqual(e.code, 404); - self.assertEqual(res, None); + self.assertEqual(res, None); @defer.inlineCallbacks def test_get_room_keys(self): - pass + yield None @defer.inlineCallbacks def test_upload_room_keys(self): - pass + yield None @defer.inlineCallbacks def test_delete_room_keys(self): - pass + yield None -- cgit 1.4.1 From fe87890b18f57f0268bd65aeca881e7817bbe9e4 Mon Sep 17 00:00:00 2001 From: Matthew Hodgson Date: Sun, 31 Dec 2017 17:47:11 +0000 Subject: implement remaining tests and make them work --- synapse/handlers/e2e_room_keys.py | 35 +++- synapse/rest/client/v2_alpha/room_keys.py | 6 + synapse/storage/e2e_room_keys.py | 3 +- tests/handlers/test_e2e_room_keys.py | 276 +++++++++++++++++++++++++++--- 4 files changed, 287 insertions(+), 33 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py index f08d80da3e..09c2888db6 100644 --- a/synapse/handlers/e2e_room_keys.py +++ b/synapse/handlers/e2e_room_keys.py @@ -58,6 +58,10 @@ class E2eRoomKeysHandler(object): results = yield self.store.get_e2e_room_keys( user_id, version, room_id, session_id ) + + if results['rooms'] == {}: + raise SynapseError(404, "No room_keys found") + defer.returnValue(results) @defer.inlineCallbacks @@ -109,9 +113,10 @@ class E2eRoomKeysHandler(object): # XXX: perhaps we should use a finer grained lock here? with (yield self._upload_linearizer.queue(user_id)): + # Check that the version we're trying to upload is the current version try: - version_info = yield self.get_version_info(user_id) + version_info = yield self._get_version_info_unlocked(user_id) except StoreError as e: if e.code == 404: raise SynapseError(404, "Version '%s' not found" % (version,)) @@ -119,16 +124,23 @@ class E2eRoomKeysHandler(object): raise e if version_info['version'] != version: - raise RoomKeysVersionError(current_version=version_info.version) + # Check that the version we're trying to upload actually exists + try: + version_info = yield self._get_version_info_unlocked(user_id, version) + # if we get this far, the version must exist + raise RoomKeysVersionError(current_version=version_info['version']) + except StoreError as e: + if e.code == 404: + raise SynapseError(404, "Version '%s' not found" % (version,)) + else: + raise e # go through the room_keys. # XXX: this should/could be done concurrently, given we're in a lock. for room_id, room in room_keys['rooms'].iteritems(): for session_id, session in room['sessions'].iteritems(): - room_key = session[session_id] - yield self._upload_room_key( - user_id, version, room_id, session_id, room_key + user_id, version, room_id, session_id, session ) @defer.inlineCallbacks @@ -242,8 +254,17 @@ class E2eRoomKeysHandler(object): """ with (yield self._upload_linearizer.queue(user_id)): - results = yield self.store.get_e2e_room_keys_version_info(user_id) - defer.returnValue(results) + res = yield self._get_version_info_unlocked(user_id, version) + defer.returnValue(res) + + @defer.inlineCallbacks + def _get_version_info_unlocked(self, user_id, version=None): + """Get the info about a given version of the user's backup + without obtaining the upload_linearizer lock. For params see get_version_info + """ + + results = yield self.store.get_e2e_room_keys_version_info(user_id, version) + defer.returnValue(results) @defer.inlineCallbacks def delete_version(self, user_id, version): diff --git a/synapse/rest/client/v2_alpha/room_keys.py b/synapse/rest/client/v2_alpha/room_keys.py index ca69ced1e3..8f10e4e1cd 100644 --- a/synapse/rest/client/v2_alpha/room_keys.py +++ b/synapse/rest/client/v2_alpha/room_keys.py @@ -204,6 +204,12 @@ class RoomKeysServlet(RestServlet): room_keys = yield self.e2e_room_keys_handler.get_room_keys( user_id, version, room_id, session_id ) + + if session_id: + room_keys = room_keys['rooms'][room_id]['sessions'][session_id] + elif room_id: + room_keys = room_keys['rooms'][room_id] + defer.returnValue((200, room_keys)) @defer.inlineCallbacks diff --git a/synapse/storage/e2e_room_keys.py b/synapse/storage/e2e_room_keys.py index e4d56b7c37..8e8e4e457c 100644 --- a/synapse/storage/e2e_room_keys.py +++ b/synapse/storage/e2e_room_keys.py @@ -58,6 +58,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): defer.returnValue(row) + @defer.inlineCallbacks def set_e2e_room_key(self, user_id, version, room_id, session_id, room_key): """Replaces or inserts the encrypted E2E room key for a given session in a given backup @@ -135,7 +136,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): desc="get_e2e_room_keys", ) - sessions = {} + sessions = { 'rooms': {} } for row in rows: room_entry = sessions['rooms'].setdefault(row['room_id'], {"sessions": {}}) room_entry['sessions'][row['session_id']] = { diff --git a/tests/handlers/test_e2e_room_keys.py b/tests/handlers/test_e2e_room_keys.py index afe6ecf27b..3cbfd6f9d0 100644 --- a/tests/handlers/test_e2e_room_keys.py +++ b/tests/handlers/test_e2e_room_keys.py @@ -17,6 +17,7 @@ import mock from synapse.api import errors from twisted.internet import defer +import copy import synapse.api.errors import synapse.handlers.e2e_room_keys @@ -25,6 +26,22 @@ import synapse.storage from tests import unittest, utils +# sample room_key data for use in the tests +room_keys = { + "rooms": { + "!abc:matrix.org": { + "sessions": { + "c0ff33": { + "first_message_index": 1, + "forwarded_count": 1, + "is_verified": False, + "session_data": "SSBBTSBBIEZJU0gK" + } + } + } + } +} + class E2eRoomKeysHandlerTestCase(unittest.TestCase): def __init__(self, *args, **kwargs): super(E2eRoomKeysHandlerTestCase, self).__init__(*args, **kwargs) @@ -38,46 +55,44 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): replication_layer=mock.Mock(), ) self.handler = synapse.handlers.e2e_room_keys.E2eRoomKeysHandler(self.hs) + self.local_user = "@boris:" + self.hs.hostname; @defer.inlineCallbacks def test_get_missing_current_version_info(self): """Check that we get a 404 if we ask for info about the current version if there is no version. """ - local_user = "@boris:" + self.hs.hostname res = None try: - res = yield self.handler.get_version_info(local_user) + yield self.handler.get_version_info(self.local_user) except errors.SynapseError as e: - self.assertEqual(e.code, 404) - self.assertEqual(res, None) + res = e.code + self.assertEqual(res, 404) @defer.inlineCallbacks def test_get_missing_version_info(self): """Check that we get a 404 if we ask for info about a specific version if it doesn't exist. """ - local_user = "@boris:" + self.hs.hostname res = None try: - res = yield self.handler.get_version_info(local_user, "mrflibble") + yield self.handler.get_version_info(self.local_user, "bogus_version") except errors.SynapseError as e: - self.assertEqual(e.code, 404) - self.assertEqual(res, None) + res = e.code + self.assertEqual(res, 404) @defer.inlineCallbacks def test_create_version(self): """Check that we can create and then retrieve versions. """ - local_user = "@boris:" + self.hs.hostname - res = yield self.handler.create_version(local_user, { + res = yield self.handler.create_version(self.local_user, { "algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data", }) self.assertEqual(res, "1") # check we can retrieve it as the current version - res = yield self.handler.get_version_info(local_user) + res = yield self.handler.get_version_info(self.local_user) self.assertDictEqual(res, { "version": "1", "algorithm": "m.megolm_backup.v1", @@ -85,7 +100,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): }) # check we can retrieve it as a specific version - res = yield self.handler.get_version_info(local_user, "1") + res = yield self.handler.get_version_info(self.local_user, "1") self.assertDictEqual(res, { "version": "1", "algorithm": "m.megolm_backup.v1", @@ -93,14 +108,14 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): }) # upload a new one... - res = yield self.handler.create_version(local_user, { + res = yield self.handler.create_version(self.local_user, { "algorithm": "m.megolm_backup.v1", "auth_data": "second_version_auth_data", }) self.assertEqual(res, "2") # check we can retrieve it as the current version - res = yield self.handler.get_version_info(local_user) + res = yield self.handler.get_version_info(self.local_user) self.assertDictEqual(res, { "version": "2", "algorithm": "m.megolm_backup.v1", @@ -111,32 +126,243 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): def test_delete_version(self): """Check that we can create and then delete versions. """ - local_user = "@boris:" + self.hs.hostname - res = yield self.handler.create_version(local_user, { + res = yield self.handler.create_version(self.local_user, { "algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data", }) self.assertEqual(res, "1") # check we can delete it - yield self.handler.delete_version(local_user, "1") + yield self.handler.delete_version(self.local_user, "1") # check that it's gone res = None try: - res = yield self.handler.get_version_info(local_user, "1") + yield self.handler.get_version_info(self.local_user, "1") except errors.SynapseError as e: - self.assertEqual(e.code, 404) - self.assertEqual(res, None) + res = e.code + self.assertEqual(res, 404) @defer.inlineCallbacks - def test_get_room_keys(self): - yield None + def test_get_missing_room_keys(self): + """Check that we get a 404 on querying missing room_keys + """ + res = None + try: + yield self.handler.get_room_keys(self.local_user, "bogus_version") + except errors.SynapseError as e: + res = e.code + self.assertEqual(res, 404) + + # check we also get a 404 even if the version is valid + version = yield self.handler.create_version(self.local_user, { + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }) + self.assertEqual(version, "1") + + res = None + try: + yield self.handler.get_room_keys(self.local_user, version) + except errors.SynapseError as e: + res = e.code + self.assertEqual(res, 404) + + # TODO: test the locking semantics when uploading room_keys, + # although this is probably best done in sytest @defer.inlineCallbacks - def test_upload_room_keys(self): - yield None + def test_upload_room_keys_no_versions(self): + """Check that we get a 404 on uploading keys when no versions are defined + """ + res = None + try: + yield self.handler.upload_room_keys(self.local_user, "no_version", room_keys) + except errors.SynapseError as e: + res = e.code + self.assertEqual(res, 404) + + @defer.inlineCallbacks + def test_upload_room_keys_bogus_version(self): + """Check that we get a 404 on uploading keys when an nonexistent version is specified + """ + version = yield self.handler.create_version(self.local_user, { + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }) + self.assertEqual(version, "1") + + res = None + try: + yield self.handler.upload_room_keys(self.local_user, "bogus_version", room_keys) + except errors.SynapseError as e: + res = e.code + self.assertEqual(res, 404) + + @defer.inlineCallbacks + def test_upload_room_keys_wrong_version(self): + """Check that we get a 403 on uploading keys for an old version + """ + version = yield self.handler.create_version(self.local_user, { + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }) + self.assertEqual(version, "1") + + version = yield self.handler.create_version(self.local_user, { + "algorithm": "m.megolm_backup.v1", + "auth_data": "second_version_auth_data", + }) + self.assertEqual(version, "2") + + res = None + try: + yield self.handler.upload_room_keys(self.local_user, "1", room_keys) + except errors.SynapseError as e: + res = e.code + self.assertEqual(res, 403) + + @defer.inlineCallbacks + def test_upload_room_keys_insert(self): + """Check that we can insert and retrieve keys for a session + """ + version = yield self.handler.create_version(self.local_user, { + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }) + self.assertEqual(version, "1") + + yield self.handler.upload_room_keys(self.local_user, version, room_keys) + + res = yield self.handler.get_room_keys(self.local_user, version) + self.assertDictEqual(res, room_keys) + + # check getting room_keys for a given room + res = yield self.handler.get_room_keys( + self.local_user, + version, + room_id="!abc:matrix.org" + ) + self.assertDictEqual(res, room_keys) + + # check getting room_keys for a given session_id + res = yield self.handler.get_room_keys( + self.local_user, + version, + room_id="!abc:matrix.org", + session_id="c0ff33", + ) + self.assertDictEqual(res, room_keys) + + @defer.inlineCallbacks + def test_upload_room_keys_merge(self): + """Check that we can upload a new room_key for an existing session and + have it correctly merged""" + version = yield self.handler.create_version(self.local_user, { + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }) + self.assertEqual(version, "1") + + yield self.handler.upload_room_keys(self.local_user, version, room_keys) + + new_room_keys = copy.deepcopy(room_keys) + + # test that increasing the message_index doesn't replace the existing session + new_room_keys['rooms']['!abc:matrix.org']['sessions']['c0ff33']['first_message_index'] = 2 + new_room_keys['rooms']['!abc:matrix.org']['sessions']['c0ff33']['session_data'] = 'new' + yield self.handler.upload_room_keys(self.local_user, version, new_room_keys) + + res = yield self.handler.get_room_keys(self.local_user, version) + self.assertEqual( + res['rooms']['!abc:matrix.org']['sessions']['c0ff33']['session_data'], + "SSBBTSBBIEZJU0gK" + ) + + # test that marking the session as verified however /does/ replace it + new_room_keys['rooms']['!abc:matrix.org']['sessions']['c0ff33']['is_verified'] = True + yield self.handler.upload_room_keys(self.local_user, version, new_room_keys) + + res = yield self.handler.get_room_keys(self.local_user, version) + self.assertEqual( + res['rooms']['!abc:matrix.org']['sessions']['c0ff33']['session_data'], + "new" + ) + + # test that a session with a higher forwarded_count doesn't replace one + # with a lower forwarding count + new_room_keys['rooms']['!abc:matrix.org']['sessions']['c0ff33']['forwarded_count'] = 2 + new_room_keys['rooms']['!abc:matrix.org']['sessions']['c0ff33']['session_data'] = 'other' + yield self.handler.upload_room_keys(self.local_user, version, new_room_keys) + + res = yield self.handler.get_room_keys(self.local_user, version) + self.assertEqual( + res['rooms']['!abc:matrix.org']['sessions']['c0ff33']['session_data'], + "new" + ) + + # TODO: check edge cases as well as the common variations here @defer.inlineCallbacks def test_delete_room_keys(self): - yield None + """Check that we can insert and delete keys for a session + """ + version = yield self.handler.create_version(self.local_user, { + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }) + self.assertEqual(version, "1") + + # check for bulk-delete + yield self.handler.upload_room_keys(self.local_user, version, room_keys) + yield self.handler.delete_room_keys(self.local_user, version) + res = None + try: + yield self.handler.get_room_keys( + self.local_user, + version, + room_id="!abc:matrix.org", + session_id="c0ff33", + ) + except errors.SynapseError as e: + res = e.code + self.assertEqual(res, 404) + + # check for bulk-delete per room + yield self.handler.upload_room_keys(self.local_user, version, room_keys) + yield self.handler.delete_room_keys( + self.local_user, + version, + room_id="!abc:matrix.org", + ) + res = None + try: + yield self.handler.get_room_keys( + self.local_user, + version, + room_id="!abc:matrix.org", + session_id="c0ff33", + ) + except errors.SynapseError as e: + res = e.code + self.assertEqual(res, 404) + + # check for bulk-delete per session + yield self.handler.upload_room_keys(self.local_user, version, room_keys) + yield self.handler.delete_room_keys( + self.local_user, + version, + room_id="!abc:matrix.org", + session_id="c0ff33", + ) + res = None + try: + yield self.handler.get_room_keys( + self.local_user, + version, + room_id="!abc:matrix.org", + session_id="c0ff33", + ) + except errors.SynapseError as e: + res = e.code + self.assertEqual(res, 404) -- cgit 1.4.1 From edc427a35187e462458d188039e761ffdb7e486d Mon Sep 17 00:00:00 2001 From: Matthew Hodgson Date: Sun, 31 Dec 2017 17:50:55 +0000 Subject: flake8 --- synapse/storage/e2e_room_keys.py | 2 +- tests/handlers/test_e2e_room_keys.py | 21 +++++++++++++-------- 2 files changed, 14 insertions(+), 9 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/e2e_room_keys.py b/synapse/storage/e2e_room_keys.py index 8e8e4e457c..c2f226396d 100644 --- a/synapse/storage/e2e_room_keys.py +++ b/synapse/storage/e2e_room_keys.py @@ -136,7 +136,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): desc="get_e2e_room_keys", ) - sessions = { 'rooms': {} } + sessions = {'rooms': {}} for row in rows: room_entry = sessions['rooms'].setdefault(row['room_id'], {"sessions": {}}) room_entry['sessions'][row['session_id']] = { diff --git a/tests/handlers/test_e2e_room_keys.py b/tests/handlers/test_e2e_room_keys.py index 3cbfd6f9d0..6e43543ed9 100644 --- a/tests/handlers/test_e2e_room_keys.py +++ b/tests/handlers/test_e2e_room_keys.py @@ -42,6 +42,7 @@ room_keys = { } } + class E2eRoomKeysHandlerTestCase(unittest.TestCase): def __init__(self, *args, **kwargs): super(E2eRoomKeysHandlerTestCase, self).__init__(*args, **kwargs) @@ -55,7 +56,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): replication_layer=mock.Mock(), ) self.handler = synapse.handlers.e2e_room_keys.E2eRoomKeysHandler(self.hs) - self.local_user = "@boris:" + self.hs.hostname; + self.local_user = "@boris:" + self.hs.hostname @defer.inlineCallbacks def test_get_missing_current_version_info(self): @@ -184,7 +185,8 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): @defer.inlineCallbacks def test_upload_room_keys_bogus_version(self): - """Check that we get a 404 on uploading keys when an nonexistent version is specified + """Check that we get a 404 on uploading keys when an nonexistent version + is specified """ version = yield self.handler.create_version(self.local_user, { "algorithm": "m.megolm_backup.v1", @@ -194,7 +196,9 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): res = None try: - yield self.handler.upload_room_keys(self.local_user, "bogus_version", room_keys) + yield self.handler.upload_room_keys( + self.local_user, "bogus_version", room_keys + ) except errors.SynapseError as e: res = e.code self.assertEqual(res, 404) @@ -267,10 +271,11 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): yield self.handler.upload_room_keys(self.local_user, version, room_keys) new_room_keys = copy.deepcopy(room_keys) + new_room_key = new_room_keys['rooms']['!abc:matrix.org']['sessions']['c0ff33'] # test that increasing the message_index doesn't replace the existing session - new_room_keys['rooms']['!abc:matrix.org']['sessions']['c0ff33']['first_message_index'] = 2 - new_room_keys['rooms']['!abc:matrix.org']['sessions']['c0ff33']['session_data'] = 'new' + new_room_key['first_message_index'] = 2 + new_room_key['session_data'] = 'new' yield self.handler.upload_room_keys(self.local_user, version, new_room_keys) res = yield self.handler.get_room_keys(self.local_user, version) @@ -280,7 +285,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): ) # test that marking the session as verified however /does/ replace it - new_room_keys['rooms']['!abc:matrix.org']['sessions']['c0ff33']['is_verified'] = True + new_room_key['is_verified'] = True yield self.handler.upload_room_keys(self.local_user, version, new_room_keys) res = yield self.handler.get_room_keys(self.local_user, version) @@ -291,8 +296,8 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): # test that a session with a higher forwarded_count doesn't replace one # with a lower forwarding count - new_room_keys['rooms']['!abc:matrix.org']['sessions']['c0ff33']['forwarded_count'] = 2 - new_room_keys['rooms']['!abc:matrix.org']['sessions']['c0ff33']['session_data'] = 'other' + new_room_key['forwarded_count'] = 2 + new_room_key['session_data'] = 'other' yield self.handler.upload_room_keys(self.local_user, version, new_room_keys) res = yield self.handler.get_room_keys(self.local_user, version) -- cgit 1.4.1 From 66a4ca1d28c2c2e22a0343e6db0f5a2bce9ec987 Mon Sep 17 00:00:00 2001 From: Matthew Hodgson Date: Sun, 7 Jan 2018 23:45:55 +0000 Subject: 404 nicely if you try to interact with a missing current version --- synapse/storage/e2e_room_keys.py | 51 +++++++++++++++++++++++++----------- tests/handlers/test_e2e_room_keys.py | 22 ++++++++++++++++ 2 files changed, 57 insertions(+), 16 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/e2e_room_keys.py b/synapse/storage/e2e_room_keys.py index c2f226396d..d82f223e86 100644 --- a/synapse/storage/e2e_room_keys.py +++ b/synapse/storage/e2e_room_keys.py @@ -184,6 +184,17 @@ class EndToEndRoomKeyStore(SQLBaseStore): desc="delete_e2e_room_keys", ) + @staticmethod + def _get_current_version(txn, user_id): + txn.execute( + "SELECT MAX(version) FROM e2e_room_keys_versions WHERE user_id=?", + (user_id,) + ) + row = txn.fetchone() + if not row: + raise StoreError(404, 'No current backup version') + return row[0] + def get_e2e_room_keys_version_info(self, user_id, version=None): """Get info metadata about a version of our room_keys backup. @@ -199,11 +210,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): def _get_e2e_room_keys_version_info_txn(txn): if version is None: - txn.execute( - "SELECT MAX(version) FROM e2e_room_keys_versions WHERE user_id=?", - (user_id,) - ) - this_version = txn.fetchone()[0] + this_version = self._get_current_version(txn, user_id) else: this_version = version @@ -266,23 +273,35 @@ class EndToEndRoomKeyStore(SQLBaseStore): "create_e2e_room_keys_version_txn", _create_e2e_room_keys_version_txn ) - @defer.inlineCallbacks - def delete_e2e_room_keys_version(self, user_id, version): + def delete_e2e_room_keys_version(self, user_id, version=None): """Delete a given backup version of the user's room keys. Doesn't delete their actual key data. Args: user_id(str): the user whose backup version we're deleting - version(str): the ID of the backup version we're deleting + version(str): Optional. the version ID of the backup version we're deleting + If missing, we delete the current backup version info. + Raises: + StoreError: with code 404 if there are no e2e_room_keys_versions present, + or if the version requested doesn't exist. """ - keyvalues = { - "user_id": user_id, - "version": version, - } + def _delete_e2e_room_keys_version_txn(txn): + if version is None: + this_version = self._get_current_version(txn, user_id) + else: + this_version = version - yield self._simple_delete( - table="e2e_room_keys_versions", - keyvalues=keyvalues, - desc="delete_e2e_room_keys_version", + return self._simple_delete_one_txn( + txn, + table="e2e_room_keys_versions", + keyvalues={ + "user_id": user_id, + "version": this_version, + }, + ) + + return self.runInteraction( + "delete_e2e_room_keys_version", + _delete_e2e_room_keys_version_txn ) diff --git a/tests/handlers/test_e2e_room_keys.py b/tests/handlers/test_e2e_room_keys.py index 6e43543ed9..8bfffb5c0e 100644 --- a/tests/handlers/test_e2e_room_keys.py +++ b/tests/handlers/test_e2e_room_keys.py @@ -123,6 +123,28 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): "auth_data": "second_version_auth_data", }) + @defer.inlineCallbacks + def test_delete_missing_version(self): + """Check that we get a 404 on deleting nonexistent versions + """ + res = None + try: + yield self.handler.delete_version(self.local_user, "1") + except errors.SynapseError as e: + res = e.code + self.assertEqual(res, 404) + + @defer.inlineCallbacks + def test_delete_missing_current_version(self): + """Check that we get a 404 on deleting nonexistent current version + """ + res = None + try: + yield self.handler.delete_version(self.local_user) + except errors.SynapseError as e: + res = e.code + self.assertEqual(res, 404) + @defer.inlineCallbacks def test_delete_version(self): """Check that we can create and then delete versions. -- cgit 1.4.1 From f0cede5556a54ece6cd5289856bb230f8009c4fa Mon Sep 17 00:00:00 2001 From: Matthew Hodgson Date: Sun, 7 Jan 2018 23:58:32 +0000 Subject: missing import --- synapse/storage/e2e_room_keys.py | 1 + 1 file changed, 1 insertion(+) (limited to 'synapse/storage') diff --git a/synapse/storage/e2e_room_keys.py b/synapse/storage/e2e_room_keys.py index d82f223e86..089989fcfa 100644 --- a/synapse/storage/e2e_room_keys.py +++ b/synapse/storage/e2e_room_keys.py @@ -14,6 +14,7 @@ # limitations under the License. from twisted.internet import defer +from synapse.api.errors import StoreError from ._base import SQLBaseStore -- cgit 1.4.1 From 8550a7e9c2fcf9e46c717127d14c79010350a998 Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Tue, 21 Aug 2018 10:38:00 -0400 Subject: allow auth_data to be any JSON instead of a string --- synapse/storage/e2e_room_keys.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/e2e_room_keys.py b/synapse/storage/e2e_room_keys.py index 089989fcfa..c7b1fad21e 100644 --- a/synapse/storage/e2e_room_keys.py +++ b/synapse/storage/e2e_room_keys.py @@ -15,6 +15,7 @@ from twisted.internet import defer from synapse.api.errors import StoreError +import simplejson as json from ._base import SQLBaseStore @@ -215,7 +216,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): else: this_version = version - return self._simple_select_one_txn( + result = self._simple_select_one_txn( txn, table="e2e_room_keys_versions", keyvalues={ @@ -228,6 +229,8 @@ class EndToEndRoomKeyStore(SQLBaseStore): "auth_data", ), ) + result["auth_data"] = json.loads(result["auth_data"]) + return result return self.runInteraction( "get_e2e_room_keys_version_info", @@ -264,7 +267,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): "user_id": user_id, "version": new_version, "algorithm": info["algorithm"], - "auth_data": info["auth_data"], + "auth_data": json.dumps(info["auth_data"]), }, ) -- cgit 1.4.1 From 42a394caa2884096e5afe1f1c2c75680792769d8 Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Tue, 21 Aug 2018 14:51:34 -0400 Subject: allow session_data to be any JSON instead of just a string --- synapse/storage/e2e_room_keys.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/e2e_room_keys.py b/synapse/storage/e2e_room_keys.py index c7b1fad21e..b695570a7b 100644 --- a/synapse/storage/e2e_room_keys.py +++ b/synapse/storage/e2e_room_keys.py @@ -58,6 +58,8 @@ class EndToEndRoomKeyStore(SQLBaseStore): desc="get_e2e_room_key", ) + row["session_data"] = json.loads(row["session_data"]); + defer.returnValue(row) @defer.inlineCallbacks @@ -87,7 +89,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): "first_message_index": room_key['first_message_index'], "forwarded_count": room_key['forwarded_count'], "is_verified": room_key['is_verified'], - "session_data": room_key['session_data'], + "session_data": json.dumps(room_key['session_data']), }, lock=False, ) @@ -145,7 +147,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): "first_message_index": row["first_message_index"], "forwarded_count": row["forwarded_count"], "is_verified": row["is_verified"], - "session_data": row["session_data"], + "session_data": json.loads(row["session_data"]), } defer.returnValue(sessions) -- cgit 1.4.1 From 3801b8aa035594972c400c8bd036894a388c4ab3 Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Thu, 6 Sep 2018 11:23:16 -0400 Subject: try to make flake8 and isort happy --- synapse/api/errors.py | 1 + synapse/handlers/e2e_room_keys.py | 2 +- synapse/rest/client/v2_alpha/room_keys.py | 5 ++++- synapse/storage/e2e_room_keys.py | 6 ++++-- tests/handlers/test_e2e_room_keys.py | 9 +++++---- 5 files changed, 15 insertions(+), 8 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/api/errors.py b/synapse/api/errors.py index 3002c95dd1..140dbfe8b8 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -326,6 +326,7 @@ class RoomKeysVersionError(SynapseError): ) self.current_version = current_version + class IncompatibleRoomVersionError(SynapseError): """A server is trying to join a room whose version it does not support.""" diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py index c09816b372..2c330382cf 100644 --- a/synapse/handlers/e2e_room_keys.py +++ b/synapse/handlers/e2e_room_keys.py @@ -17,7 +17,7 @@ import logging from twisted.internet import defer -from synapse.api.errors import StoreError, SynapseError, RoomKeysVersionError +from synapse.api.errors import RoomKeysVersionError, StoreError, SynapseError from synapse.util.async_helpers import Linearizer logger = logging.getLogger(__name__) diff --git a/synapse/rest/client/v2_alpha/room_keys.py b/synapse/rest/client/v2_alpha/room_keys.py index 9f0172e6f5..1ed18e986f 100644 --- a/synapse/rest/client/v2_alpha/room_keys.py +++ b/synapse/rest/client/v2_alpha/room_keys.py @@ -19,8 +19,11 @@ from twisted.internet import defer from synapse.api.errors import SynapseError from synapse.http.servlet import ( - RestServlet, parse_json_object_from_request, parse_string + RestServlet, + parse_json_object_from_request, + parse_string, ) + from ._base import client_v2_patterns logger = logging.getLogger(__name__) diff --git a/synapse/storage/e2e_room_keys.py b/synapse/storage/e2e_room_keys.py index b695570a7b..969f4aef9c 100644 --- a/synapse/storage/e2e_room_keys.py +++ b/synapse/storage/e2e_room_keys.py @@ -13,9 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +import simplejson as json + from twisted.internet import defer + from synapse.api.errors import StoreError -import simplejson as json from ._base import SQLBaseStore @@ -58,7 +60,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): desc="get_e2e_room_key", ) - row["session_data"] = json.loads(row["session_data"]); + row["session_data"] = json.loads(row["session_data"]) defer.returnValue(row) diff --git a/tests/handlers/test_e2e_room_keys.py b/tests/handlers/test_e2e_room_keys.py index 7fa4264441..9e08eac0a5 100644 --- a/tests/handlers/test_e2e_room_keys.py +++ b/tests/handlers/test_e2e_room_keys.py @@ -14,17 +14,18 @@ # See the License for the specific language governing permissions and # limitations under the License. +import copy + import mock -from synapse.api import errors + from twisted.internet import defer -import copy import synapse.api.errors import synapse.handlers.e2e_room_keys - import synapse.storage -from tests import unittest, utils +from synapse.api import errors +from tests import unittest, utils # sample room_key data for use in the tests room_keys = { -- cgit 1.4.1 From bc74925c5b94f0c02603297d4ccb7e05008d5124 Mon Sep 17 00:00:00 2001 From: David Baker Date: Thu, 13 Sep 2018 17:02:59 +0100 Subject: WIP e2e key backups Continues from uhoreg's branch This just fixed the errcode on /room_keys/version if no backup and updates the schema delta to be on the latest so it gets run --- synapse/rest/client/v2_alpha/room_keys.py | 14 ++++++--- synapse/storage/schema/delta/46/e2e_room_keys.sql | 38 ----------------------- synapse/storage/schema/delta/51/e2e_room_keys.sql | 38 +++++++++++++++++++++++ 3 files changed, 48 insertions(+), 42 deletions(-) delete mode 100644 synapse/storage/schema/delta/46/e2e_room_keys.sql create mode 100644 synapse/storage/schema/delta/51/e2e_room_keys.sql (limited to 'synapse/storage') diff --git a/synapse/rest/client/v2_alpha/room_keys.py b/synapse/rest/client/v2_alpha/room_keys.py index 1ed18e986f..ea114bc8b4 100644 --- a/synapse/rest/client/v2_alpha/room_keys.py +++ b/synapse/rest/client/v2_alpha/room_keys.py @@ -17,7 +17,7 @@ import logging from twisted.internet import defer -from synapse.api.errors import SynapseError +from synapse.api.errors import SynapseError, Codes from synapse.http.servlet import ( RestServlet, parse_json_object_from_request, @@ -324,9 +324,15 @@ class RoomKeysVersionServlet(RestServlet): requester = yield self.auth.get_user_by_req(request, allow_guest=False) user_id = requester.user.to_string() - info = yield self.e2e_room_keys_handler.get_version_info( - user_id, version - ) + try: + info = yield self.e2e_room_keys_handler.get_version_info( + user_id, version + ) + except SynapseError as e: + if e.code == 404: + e.errcode = Codes.NOT_FOUND + e.msg = "No backup found" + raise e defer.returnValue((200, info)) @defer.inlineCallbacks diff --git a/synapse/storage/schema/delta/46/e2e_room_keys.sql b/synapse/storage/schema/delta/46/e2e_room_keys.sql deleted file mode 100644 index 4531fd56ee..0000000000 --- a/synapse/storage/schema/delta/46/e2e_room_keys.sql +++ /dev/null @@ -1,38 +0,0 @@ -/* Copyright 2017 New Vector Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- users' optionally backed up encrypted e2e sessions -CREATE TABLE e2e_room_keys ( - user_id TEXT NOT NULL, - room_id TEXT NOT NULL, - session_id TEXT NOT NULL, - version TEXT NOT NULL, - first_message_index INT, - forwarded_count INT, - is_verified BOOLEAN, - session_data TEXT NOT NULL -); - -CREATE UNIQUE INDEX e2e_room_keys_idx ON e2e_room_keys(user_id, room_id, session_id); - --- the metadata for each generation of encrypted e2e session backups -CREATE TABLE e2e_room_keys_versions ( - user_id TEXT NOT NULL, - version TEXT NOT NULL, - algorithm TEXT NOT NULL, - auth_data TEXT NOT NULL -); - -CREATE UNIQUE INDEX e2e_room_keys_versions_idx ON e2e_room_keys_versions(user_id, version); diff --git a/synapse/storage/schema/delta/51/e2e_room_keys.sql b/synapse/storage/schema/delta/51/e2e_room_keys.sql new file mode 100644 index 0000000000..4531fd56ee --- /dev/null +++ b/synapse/storage/schema/delta/51/e2e_room_keys.sql @@ -0,0 +1,38 @@ +/* Copyright 2017 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- users' optionally backed up encrypted e2e sessions +CREATE TABLE e2e_room_keys ( + user_id TEXT NOT NULL, + room_id TEXT NOT NULL, + session_id TEXT NOT NULL, + version TEXT NOT NULL, + first_message_index INT, + forwarded_count INT, + is_verified BOOLEAN, + session_data TEXT NOT NULL +); + +CREATE UNIQUE INDEX e2e_room_keys_idx ON e2e_room_keys(user_id, room_id, session_id); + +-- the metadata for each generation of encrypted e2e session backups +CREATE TABLE e2e_room_keys_versions ( + user_id TEXT NOT NULL, + version TEXT NOT NULL, + algorithm TEXT NOT NULL, + auth_data TEXT NOT NULL +); + +CREATE UNIQUE INDEX e2e_room_keys_versions_idx ON e2e_room_keys_versions(user_id, version); -- cgit 1.4.1 From 85a43f4167b5e775f7c3afe96d71137818561272 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Mon, 17 Sep 2018 13:19:00 +0100 Subject: Return a 404 when deleting unknown room alias As per https://github.com/matrix-org/matrix-doc/issues/1675 Fixes https://github.com/matrix-org/synapse/issues/2782 --- synapse/handlers/directory.py | 19 ++++++++++++++++--- synapse/storage/directory.py | 1 - 2 files changed, 16 insertions(+), 4 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index ef866da1b6..c745e6740b 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -20,7 +20,14 @@ import string from twisted.internet import defer from synapse.api.constants import EventTypes -from synapse.api.errors import AuthError, CodeMessageException, Codes, SynapseError +from synapse.api.errors import ( + AuthError, + CodeMessageException, + Codes, + NotFoundError, + StoreError, + SynapseError, +) from synapse.types import RoomAlias, UserID, get_domain_from_id from ._base import BaseHandler @@ -109,7 +116,13 @@ class DirectoryHandler(BaseHandler): def delete_association(self, requester, user_id, room_alias): # association deletion for human users - can_delete = yield self._user_can_delete_alias(room_alias, user_id) + try: + can_delete = yield self._user_can_delete_alias(room_alias, user_id) + except StoreError as e: + if e.code == 404: + raise NotFoundError("Unknown room alias") + raise + if not can_delete: raise AuthError( 403, "You don't have permission to delete the alias.", @@ -320,7 +333,7 @@ class DirectoryHandler(BaseHandler): def _user_can_delete_alias(self, alias, user_id): creator = yield self.store.get_room_alias_creator(alias.to_string()) - if creator and creator == user_id: + if creator == user_id: defer.returnValue(True) is_admin = yield self.auth.is_server_admin(UserID.from_string(user_id)) diff --git a/synapse/storage/directory.py b/synapse/storage/directory.py index 808194236a..cfb687cb53 100644 --- a/synapse/storage/directory.py +++ b/synapse/storage/directory.py @@ -75,7 +75,6 @@ class DirectoryWorkerStore(SQLBaseStore): }, retcol="creator", desc="get_room_alias_creator", - allow_none=True ) @cached(max_entries=5000) -- cgit 1.4.1 From b9158ac2bf46e15107162be136964c17bd8a948c Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 19 Sep 2018 14:16:34 +0100 Subject: Remove get_destination_retry_timings cache Currently we rely on the master to invalidate this cache promptly. However, after having moved most federation endpoints off of master this no longer happens, causing outbound fedeariont to get blackholed. Fixes #3798 --- synapse/storage/transactions.py | 5 ----- 1 file changed, 5 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py index 0c42bd3322..7d51f9612a 100644 --- a/synapse/storage/transactions.py +++ b/synapse/storage/transactions.py @@ -156,7 +156,6 @@ class TransactionStore(SQLBaseStore): """ pass - @cached(max_entries=10000) def get_destination_retry_timings(self, destination): """Gets the current retry timings (if any) for a given destination. @@ -212,10 +211,6 @@ class TransactionStore(SQLBaseStore): retry_last_ts, retry_interval): self.database_engine.lock_table(txn, "destinations") - self._invalidate_cache_and_stream( - txn, self.get_destination_retry_timings, (destination,) - ) - # We need to be careful here as the data may have changed from under us # due to a worker setting the timings. -- cgit 1.4.1 From 392a54128c03128255ee538e17ed9bd87b95d4ca Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 19 Sep 2018 14:37:49 +0100 Subject: pep8 --- synapse/storage/transactions.py | 1 - 1 file changed, 1 deletion(-) (limited to 'synapse/storage') diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py index 7d51f9612a..449bdc1e49 100644 --- a/synapse/storage/transactions.py +++ b/synapse/storage/transactions.py @@ -23,7 +23,6 @@ from canonicaljson import encode_canonical_json from twisted.internet import defer from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.util.caches.descriptors import cached from ._base import SQLBaseStore, db_to_json -- cgit 1.4.1 From bbab6ebfd9257ba9ad5e576c4b9b8bdea1c166cf Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 19 Sep 2018 14:45:14 +0100 Subject: Fix up changelog and remove spurious comment --- changelog.d/3914.bugfix | 2 +- synapse/storage/transactions.py | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) (limited to 'synapse/storage') diff --git a/changelog.d/3914.bugfix b/changelog.d/3914.bugfix index 60fb8ec1ae..27e6bad590 100644 --- a/changelog.d/3914.bugfix +++ b/changelog.d/3914.bugfix @@ -1 +1 @@ -Fix bug where outboud federation would stop talking to some servers when using workers +Fix bug where outbound federation would stop talking to some servers when using workers diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py index 449bdc1e49..baf0379a68 100644 --- a/synapse/storage/transactions.py +++ b/synapse/storage/transactions.py @@ -196,8 +196,6 @@ class TransactionStore(SQLBaseStore): retry_interval (int) - how long until next retry in ms """ - # XXX: we could chose to not bother persisting this if our cache thinks - # this is a NOOP return self.runInteraction( "set_destination_retry_timings", self._set_destination_retry_timings, -- cgit 1.4.1 From 1f3f5fcf52efdf5215dc6c7cd3805b1357bf8236 Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Thu, 20 Sep 2018 20:14:34 +1000 Subject: Fix client IPs being broken on Python 3 (#3908) --- .travis.yml | 5 + changelog.d/3908.bugfix | 1 + synapse/http/site.py | 2 +- synapse/storage/client_ips.py | 34 ++++--- tests/server.py | 8 +- tests/storage/test_client_ips.py | 202 ++++++++++++++++++++++++++++++++------- tests/unittest.py | 7 +- tests/utils.py | 27 +++++- tox.ini | 10 ++ 9 files changed, 238 insertions(+), 58 deletions(-) create mode 100644 changelog.d/3908.bugfix (limited to 'synapse/storage') diff --git a/.travis.yml b/.travis.yml index b3ee66da8f..b6faca4b92 100644 --- a/.travis.yml +++ b/.travis.yml @@ -31,6 +31,11 @@ matrix: - python: 3.6 env: TOX_ENV=py36 + - python: 3.6 + env: TOX_ENV=py36-postgres TRIAL_FLAGS="-j 4" + services: + - postgresql + - python: 3.6 env: TOX_ENV=check_isort diff --git a/changelog.d/3908.bugfix b/changelog.d/3908.bugfix new file mode 100644 index 0000000000..518aee6c4d --- /dev/null +++ b/changelog.d/3908.bugfix @@ -0,0 +1 @@ +Fix adding client IPs to the database failing on Python 3. \ No newline at end of file diff --git a/synapse/http/site.py b/synapse/http/site.py index d41d672934..50be2de3bb 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py @@ -308,7 +308,7 @@ class XForwardedForRequest(SynapseRequest): C{b"-"}. """ return self.requestHeaders.getRawHeaders( - b"x-forwarded-for", [b"-"])[0].split(b",")[0].strip() + b"x-forwarded-for", [b"-"])[0].split(b",")[0].strip().decode('ascii') class SynapseRequestFactory(object): diff --git a/synapse/storage/client_ips.py b/synapse/storage/client_ips.py index 8fc678fa67..9ad17b7c25 100644 --- a/synapse/storage/client_ips.py +++ b/synapse/storage/client_ips.py @@ -119,21 +119,25 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): for entry in iteritems(to_update): (user_id, access_token, ip), (user_agent, device_id, last_seen) = entry - self._simple_upsert_txn( - txn, - table="user_ips", - keyvalues={ - "user_id": user_id, - "access_token": access_token, - "ip": ip, - "user_agent": user_agent, - "device_id": device_id, - }, - values={ - "last_seen": last_seen, - }, - lock=False, - ) + try: + self._simple_upsert_txn( + txn, + table="user_ips", + keyvalues={ + "user_id": user_id, + "access_token": access_token, + "ip": ip, + "user_agent": user_agent, + "device_id": device_id, + }, + values={ + "last_seen": last_seen, + }, + lock=False, + ) + except Exception as e: + # Failed to upsert, log and continue + logger.error("Failed to insert client IP %r: %r", entry, e) @defer.inlineCallbacks def get_last_client_ip_by_device(self, user_id, device_id): diff --git a/tests/server.py b/tests/server.py index ccea3baa55..7bee58dff1 100644 --- a/tests/server.py +++ b/tests/server.py @@ -98,7 +98,7 @@ class FakeSite: return FakeLogger() -def make_request(method, path, content=b"", access_token=None): +def make_request(method, path, content=b"", access_token=None, request=SynapseRequest): """ Make a web request using the given method and path, feed it the content, and return the Request and the Channel underneath. @@ -120,14 +120,16 @@ def make_request(method, path, content=b"", access_token=None): site = FakeSite() channel = FakeChannel() - req = SynapseRequest(site, channel) + req = request(site, channel) req.process = lambda: b"" req.content = BytesIO(content) if access_token: req.requestHeaders.addRawHeader(b"Authorization", b"Bearer " + access_token) - req.requestHeaders.addRawHeader(b"X-Forwarded-For", b"127.0.0.1") + if content: + req.requestHeaders.addRawHeader(b"Content-Type", b"application/json") + req.requestReceived(method, path, b"1.1") return req, channel diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index c9b02a062b..2ffbb9f14f 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,35 +13,45 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +import hashlib +import hmac +import json + from mock import Mock from twisted.internet import defer -import tests.unittest -import tests.utils +from synapse.http.site import XForwardedForRequest +from synapse.rest.client.v1 import admin, login + +from tests import unittest -class ClientIpStoreTestCase(tests.unittest.TestCase): - def __init__(self, *args, **kwargs): - super(ClientIpStoreTestCase, self).__init__(*args, **kwargs) - self.store = None # type: synapse.storage.DataStore - self.clock = None # type: tests.utils.MockClock +class ClientIpStoreTestCase(unittest.HomeserverTestCase): + def make_homeserver(self, reactor, clock): + hs = self.setup_test_homeserver() + return hs - @defer.inlineCallbacks - def setUp(self): - self.hs = yield tests.utils.setup_test_homeserver(self.addCleanup) + def prepare(self, hs, reactor, clock): self.store = self.hs.get_datastore() - self.clock = self.hs.get_clock() - @defer.inlineCallbacks def test_insert_new_client_ip(self): - self.clock.now = 12345678 + self.reactor.advance(12345678) + user_id = "@user:id" - yield self.store.insert_client_ip( - user_id, "access_token", "ip", "user_agent", "device_id" + self.get_success( + self.store.insert_client_ip( + user_id, "access_token", "ip", "user_agent", "device_id" + ) ) - result = yield self.store.get_last_client_ip_by_device(user_id, "device_id") + # Trigger the storage loop + self.reactor.advance(10) + + result = self.get_success( + self.store.get_last_client_ip_by_device(user_id, "device_id") + ) r = result[(user_id, "device_id")] self.assertDictContainsSubset( @@ -55,18 +66,18 @@ class ClientIpStoreTestCase(tests.unittest.TestCase): r, ) - @defer.inlineCallbacks def test_disabled_monthly_active_user(self): self.hs.config.limit_usage_by_mau = False self.hs.config.max_mau_value = 50 user_id = "@user:server" - yield self.store.insert_client_ip( - user_id, "access_token", "ip", "user_agent", "device_id" + self.get_success( + self.store.insert_client_ip( + user_id, "access_token", "ip", "user_agent", "device_id" + ) ) - active = yield self.store.user_last_seen_monthly_active(user_id) + active = self.get_success(self.store.user_last_seen_monthly_active(user_id)) self.assertFalse(active) - @defer.inlineCallbacks def test_adding_monthly_active_user_when_full(self): self.hs.config.limit_usage_by_mau = True self.hs.config.max_mau_value = 50 @@ -76,38 +87,159 @@ class ClientIpStoreTestCase(tests.unittest.TestCase): self.store.get_monthly_active_count = Mock( return_value=defer.succeed(lots_of_users) ) - yield self.store.insert_client_ip( - user_id, "access_token", "ip", "user_agent", "device_id" + self.get_success( + self.store.insert_client_ip( + user_id, "access_token", "ip", "user_agent", "device_id" + ) ) - active = yield self.store.user_last_seen_monthly_active(user_id) + active = self.get_success(self.store.user_last_seen_monthly_active(user_id)) self.assertFalse(active) - @defer.inlineCallbacks def test_adding_monthly_active_user_when_space(self): self.hs.config.limit_usage_by_mau = True self.hs.config.max_mau_value = 50 user_id = "@user:server" - active = yield self.store.user_last_seen_monthly_active(user_id) + active = self.get_success(self.store.user_last_seen_monthly_active(user_id)) self.assertFalse(active) - yield self.store.insert_client_ip( - user_id, "access_token", "ip", "user_agent", "device_id" + # Trigger the saving loop + self.reactor.advance(10) + + self.get_success( + self.store.insert_client_ip( + user_id, "access_token", "ip", "user_agent", "device_id" + ) ) - active = yield self.store.user_last_seen_monthly_active(user_id) + active = self.get_success(self.store.user_last_seen_monthly_active(user_id)) self.assertTrue(active) - @defer.inlineCallbacks def test_updating_monthly_active_user_when_space(self): self.hs.config.limit_usage_by_mau = True self.hs.config.max_mau_value = 50 user_id = "@user:server" - yield self.store.register(user_id=user_id, token="123", password_hash=None) + self.get_success( + self.store.register(user_id=user_id, token="123", password_hash=None) + ) - active = yield self.store.user_last_seen_monthly_active(user_id) + active = self.get_success(self.store.user_last_seen_monthly_active(user_id)) self.assertFalse(active) - yield self.store.insert_client_ip( - user_id, "access_token", "ip", "user_agent", "device_id" + # Trigger the saving loop + self.reactor.advance(10) + + self.get_success( + self.store.insert_client_ip( + user_id, "access_token", "ip", "user_agent", "device_id" + ) ) - active = yield self.store.user_last_seen_monthly_active(user_id) + active = self.get_success(self.store.user_last_seen_monthly_active(user_id)) self.assertTrue(active) + + +class ClientIpAuthTestCase(unittest.HomeserverTestCase): + + servlets = [admin.register_servlets, login.register_servlets] + + def make_homeserver(self, reactor, clock): + hs = self.setup_test_homeserver() + return hs + + def prepare(self, hs, reactor, clock): + self.hs.config.registration_shared_secret = u"shared" + self.store = self.hs.get_datastore() + + # Create the user + request, channel = self.make_request("GET", "/_matrix/client/r0/admin/register") + self.render(request) + nonce = channel.json_body["nonce"] + + want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) + want_mac.update(nonce.encode('ascii') + b"\x00bob\x00abc123\x00admin") + want_mac = want_mac.hexdigest() + + body = json.dumps( + { + "nonce": nonce, + "username": "bob", + "password": "abc123", + "admin": True, + "mac": want_mac, + } + ) + request, channel = self.make_request( + "POST", "/_matrix/client/r0/admin/register", body.encode('utf8') + ) + self.render(request) + + self.assertEqual(channel.code, 200) + self.user_id = channel.json_body["user_id"] + + def test_request_with_xforwarded(self): + """ + The IP in X-Forwarded-For is entered into the client IPs table. + """ + self._runtest( + {b"X-Forwarded-For": b"127.9.0.1"}, + "127.9.0.1", + {"request": XForwardedForRequest}, + ) + + def test_request_from_getPeer(self): + """ + The IP returned by getPeer is entered into the client IPs table, if + there's no X-Forwarded-For header. + """ + self._runtest({}, "127.0.0.1", {}) + + def _runtest(self, headers, expected_ip, make_request_args): + device_id = "bleb" + + body = json.dumps( + { + "type": "m.login.password", + "user": "bob", + "password": "abc123", + "device_id": device_id, + } + ) + request, channel = self.make_request( + "POST", "/_matrix/client/r0/login", body.encode('utf8'), **make_request_args + ) + self.render(request) + self.assertEqual(channel.code, 200) + access_token = channel.json_body["access_token"].encode('ascii') + + # Advance to a known time + self.reactor.advance(123456 - self.reactor.seconds()) + + request, channel = self.make_request( + "GET", + "/_matrix/client/r0/admin/users/" + self.user_id, + body.encode('utf8'), + access_token=access_token, + **make_request_args + ) + request.requestHeaders.addRawHeader(b"User-Agent", b"Mozzila pizza") + + # Add the optional headers + for h, v in headers.items(): + request.requestHeaders.addRawHeader(h, v) + self.render(request) + + # Advance so the save loop occurs + self.reactor.advance(100) + + result = self.get_success( + self.store.get_last_client_ip_by_device(self.user_id, device_id) + ) + r = result[(self.user_id, device_id)] + self.assertDictContainsSubset( + { + "user_id": self.user_id, + "device_id": device_id, + "ip": expected_ip, + "user_agent": "Mozzila pizza", + "last_seen": 123456100, + }, + r, + ) diff --git a/tests/unittest.py b/tests/unittest.py index 56f3dca394..ef905e6389 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -26,6 +26,7 @@ from twisted.internet.defer import Deferred from twisted.trial import unittest from synapse.http.server import JsonResource +from synapse.http.site import SynapseRequest from synapse.server import HomeServer from synapse.types import UserID, create_requester from synapse.util.logcontext import LoggingContextFilter @@ -237,7 +238,9 @@ class HomeserverTestCase(TestCase): Function to optionally be overridden in subclasses. """ - def make_request(self, method, path, content=b""): + def make_request( + self, method, path, content=b"", access_token=None, request=SynapseRequest + ): """ Create a SynapseRequest at the path using the method and containing the given content. @@ -255,7 +258,7 @@ class HomeserverTestCase(TestCase): if isinstance(content, dict): content = json.dumps(content).encode('utf8') - return make_request(method, path, content) + return make_request(method, path, content, access_token, request) def render(self, request): """ diff --git a/tests/utils.py b/tests/utils.py index 215226debf..aaed1149c3 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -16,7 +16,9 @@ import atexit import hashlib import os +import time import uuid +import warnings from inspect import getcallargs from mock import Mock, patch @@ -237,20 +239,41 @@ def setup_test_homeserver( else: # We need to do cleanup on PostgreSQL def cleanup(): + import psycopg2 + # Close all the db pools hs.get_db_pool().close() + dropped = False + # Drop the test database db_conn = db_engine.module.connect( database=POSTGRES_BASE_DB, user=POSTGRES_USER ) db_conn.autocommit = True cur = db_conn.cursor() - cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,)) - db_conn.commit() + + # Try a few times to drop the DB. Some things may hold on to the + # database for a few more seconds due to flakiness, preventing + # us from dropping it when the test is over. If we can't drop + # it, warn and move on. + for x in range(5): + try: + cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,)) + db_conn.commit() + dropped = True + except psycopg2.OperationalError as e: + warnings.warn( + "Couldn't drop old db: " + str(e), category=UserWarning + ) + time.sleep(0.5) + cur.close() db_conn.close() + if not dropped: + warnings.warn("Failed to drop old DB.", category=UserWarning) + if not LEAVE_DB: # Register the cleanup hook cleanup_func(cleanup) diff --git a/tox.ini b/tox.ini index 88875c3318..e4db563b4b 100644 --- a/tox.ini +++ b/tox.ini @@ -70,6 +70,16 @@ usedevelop=true [testenv:py36] usedevelop=true +[testenv:py36-postgres] +usedevelop=true +deps = + {[base]deps} + psycopg2 +setenv = + {[base]setenv} + SYNAPSE_POSTGRES = 1 + + [testenv:packaging] deps = check-manifest -- cgit 1.4.1 From fdd1a62e8d09ddccbe685fe7d7840990a9c06241 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 21 Sep 2018 14:55:47 +0100 Subject: Add a five minute cache to get_destination_retry_timings Hopefully helps with #3931 --- synapse/storage/transactions.py | 23 ++++++++++++++++++++++- synapse/util/caches/expiringcache.py | 13 +++++++++++++ 2 files changed, 35 insertions(+), 1 deletion(-) (limited to 'synapse/storage') diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py index baf0379a68..ab54977a75 100644 --- a/synapse/storage/transactions.py +++ b/synapse/storage/transactions.py @@ -23,6 +23,7 @@ from canonicaljson import encode_canonical_json from twisted.internet import defer from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.util.caches.expiringcache import ExpiringCache from ._base import SQLBaseStore, db_to_json @@ -49,6 +50,8 @@ _UpdateTransactionRow = namedtuple( ) ) +SENTINEL = object() + class TransactionStore(SQLBaseStore): """A collection of queries for handling PDUs. @@ -59,6 +62,12 @@ class TransactionStore(SQLBaseStore): self._clock.looping_call(self._start_cleanup_transactions, 30 * 60 * 1000) + self._destination_retry_cache = ExpiringCache( + cache_name="get_destination_retry_timings", + clock=self._clock, + expiry_ms=5 * 60 * 1000, + ) + def get_received_txn_response(self, transaction_id, origin): """For an incoming transaction from a given origin, check if we have already responded to it. If so, return the response code and response @@ -155,6 +164,7 @@ class TransactionStore(SQLBaseStore): """ pass + @defer.inlineCallbacks def get_destination_retry_timings(self, destination): """Gets the current retry timings (if any) for a given destination. @@ -165,10 +175,20 @@ class TransactionStore(SQLBaseStore): None if not retrying Otherwise a dict for the retry scheme """ - return self.runInteraction( + + result = self._destination_retry_cache.get(destination, SENTINEL) + if result is not SENTINEL: + defer.returnValue(result) + + result = yield self.runInteraction( "get_destination_retry_timings", self._get_destination_retry_timings, destination) + # We don't hugely care about race conditions between getting and + # invalidating the cache, since we time out fairly quickly anyway. + self._destination_retry_cache[destination] = result + defer.returnValue(result) + def _get_destination_retry_timings(self, txn, destination): result = self._simple_select_one_txn( txn, @@ -196,6 +216,7 @@ class TransactionStore(SQLBaseStore): retry_interval (int) - how long until next retry in ms """ + self._destination_retry_cache.pop(destination) return self.runInteraction( "set_destination_retry_timings", self._set_destination_retry_timings, diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py index 48ca2d634d..346669e4ce 100644 --- a/synapse/util/caches/expiringcache.py +++ b/synapse/util/caches/expiringcache.py @@ -24,6 +24,9 @@ from synapse.util.caches import register_cache logger = logging.getLogger(__name__) +SENTINEL = object() + + class ExpiringCache(object): def __init__(self, cache_name, clock, max_len=0, expiry_ms=0, reset_expiry_on_get=False, iterable=False): @@ -102,6 +105,16 @@ class ExpiringCache(object): return entry.value + def pop(self, key, default=None): + value = self._cache.pop(key, SENTINEL) + if value is SENTINEL: + return default + + if self.iterable: + self._size_estimate -= len(value.value) + + return value + def __contains__(self, key): return key in self._cache -- cgit 1.4.1 From 28781b65e7cb8b0d1fc409d0d78253662b49fe06 Mon Sep 17 00:00:00 2001 From: Neil Johnson Date: Wed, 26 Sep 2018 16:16:41 +0100 Subject: fix #3854 --- synapse/storage/monthly_active_users.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) (limited to 'synapse/storage') diff --git a/synapse/storage/monthly_active_users.py b/synapse/storage/monthly_active_users.py index 59580949f1..0fe8c8e24c 100644 --- a/synapse/storage/monthly_active_users.py +++ b/synapse/storage/monthly_active_users.py @@ -172,6 +172,10 @@ class MonthlyActiveUsersStore(SQLBaseStore): Deferred[bool]: True if a new entry was created, False if an existing one was updated. """ + # Am consciously deciding to lock the table on the basis that is ought + # never be a big table and alternative approaches (batching multiple + # upserts into a single txn) introduced a lot of extra complexity. + # See https://github.com/matrix-org/synapse/issues/3854 for more is_insert = yield self._simple_upsert( desc="upsert_monthly_active_user", table="monthly_active_users", @@ -181,7 +185,6 @@ class MonthlyActiveUsersStore(SQLBaseStore): values={ "timestamp": int(self._clock.time_msec()), }, - lock=False, ) if is_insert: self.user_last_seen_monthly_active.invalidate((user_id,)) -- cgit 1.4.1 From ae6ad4cf4190dbd2dadd72fc8131e4c139f8eb4a Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Thu, 27 Sep 2018 11:22:25 +0100 Subject: docstrings and unittests for storage.state (#3958) I spent ages trying to figure out how I was going mad... --- changelog.d/3958.misc | 1 + synapse/storage/state.py | 30 ++++++++++++++++++++++-------- tests/storage/test_state.py | 39 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 62 insertions(+), 8 deletions(-) create mode 100644 changelog.d/3958.misc (limited to 'synapse/storage') diff --git a/changelog.d/3958.misc b/changelog.d/3958.misc new file mode 100644 index 0000000000..5931d06dcf --- /dev/null +++ b/changelog.d/3958.misc @@ -0,0 +1 @@ +Fix docstrings and add tests for state store methods diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 4b971efdba..3f4cbd61c4 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -255,7 +255,17 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): ) @defer.inlineCallbacks - def get_state_groups_ids(self, room_id, event_ids): + def get_state_groups_ids(self, _room_id, event_ids): + """Get the event IDs of all the state for the state groups for the given events + + Args: + _room_id (str): id of the room for these events + event_ids (iterable[str]): ids of the events + + Returns: + Deferred[dict[int, dict[tuple[str, str], str]]]: + dict of state_group_id -> (dict of (type, state_key) -> event id) + """ if not event_ids: defer.returnValue({}) @@ -270,7 +280,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): @defer.inlineCallbacks def get_state_ids_for_group(self, state_group): - """Get the state IDs for the given state group + """Get the event IDs of all the state in the given state group Args: state_group (int) @@ -286,7 +296,9 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): def get_state_groups(self, room_id, event_ids): """ Get the state groups for the given list of event_ids - The return value is a dict mapping group names to lists of events. + Returns: + Deferred[dict[int, list[EventBase]]]: + dict of state_group_id -> list of state events. """ if not event_ids: defer.returnValue({}) @@ -324,7 +336,9 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): member events (if True), or to exclude member events (if False) Returns: - dictionary state_group -> (dict of (type, state_key) -> event id) + Returns: + Deferred[dict[int, dict[tuple[str, str], str]]]: + dict of state_group_id -> (dict of (type, state_key) -> event id) """ results = {} @@ -732,8 +746,8 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): If None, `types` filtering is applied to all events. Returns: - Deferred[dict[int, dict[(type, state_key), EventBase]]] - a dictionary mapping from state group to state dictionary. + Deferred[dict[int, dict[tuple[str, str], str]]]: + dict of state_group_id -> (dict of (type, state_key) -> event id) """ if types is not None: non_member_types = [t for t in types if t[0] != EventTypes.Member] @@ -788,8 +802,8 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): If None, `types` filtering is applied to all events. Returns: - Deferred[dict[int, dict[(type, state_key), EventBase]]] - a dictionary mapping from state group to state dictionary. + Deferred[dict[int, dict[tuple[str, str], str]]]: + dict of state_group_id -> (dict of (type, state_key) -> event id) """ if types: types = frozenset(types) diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index b910965932..b9c5b39d59 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -74,6 +74,45 @@ class StateStoreTestCase(tests.unittest.TestCase): self.assertEqual(s1[t].event_id, s2[t].event_id) self.assertEqual(len(s1), len(s2)) + @defer.inlineCallbacks + def test_get_state_groups_ids(self): + e1 = yield self.inject_state_event( + self.room, self.u_alice, EventTypes.Create, '', {} + ) + e2 = yield self.inject_state_event( + self.room, self.u_alice, EventTypes.Name, '', {"name": "test room"} + ) + + state_group_map = yield self.store.get_state_groups_ids(self.room, [e2.event_id]) + self.assertEqual(len(state_group_map), 1) + state_map = list(state_group_map.values())[0] + self.assertDictEqual( + state_map, + { + (EventTypes.Create, ''): e1.event_id, + (EventTypes.Name, ''): e2.event_id, + }, + ) + + @defer.inlineCallbacks + def test_get_state_groups(self): + e1 = yield self.inject_state_event( + self.room, self.u_alice, EventTypes.Create, '', {} + ) + e2 = yield self.inject_state_event( + self.room, self.u_alice, EventTypes.Name, '', {"name": "test room"} + ) + + state_group_map = yield self.store.get_state_groups( + self.room, [e2.event_id]) + self.assertEqual(len(state_group_map), 1) + state_list = list(state_group_map.values())[0] + + self.assertEqual( + {ev.event_id for ev in state_list}, + {e1.event_id, e2.event_id}, + ) + @defer.inlineCallbacks def test_get_state_for_event(self): -- cgit 1.4.1 From 5b68f29f48acfa45e671af1fb325d0c0d532f3d6 Mon Sep 17 00:00:00 2001 From: Matthew Hodgson Date: Sat, 29 Sep 2018 02:14:40 +0100 Subject: fix thinkos --- synapse/handlers/register.py | 12 ++++++------ synapse/storage/directory.py | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index a358bfc723..05e8f4ea73 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -515,10 +515,10 @@ class RegistrationHandler(BaseHandler): def _join_user_to_room(self, requester, room_identifier): # try to create the room if we're the first user on the server - if self.config.autocreate_auto_join_rooms: + if self.hs.config.autocreate_auto_join_rooms: count = yield self.store.count_all_users() if count == 1 and RoomAlias.is_valid(room_identifier): - room_creation_handler = hs.get_room_creation_handler() + room_creation_handler = self.hs.get_room_creation_handler() info = yield room_creation_handler.create_room( requester, config={ @@ -528,11 +528,11 @@ class RegistrationHandler(BaseHandler): ) room_id = info["room_id"] - directory_handler = hs.get_handlers().directory_handler + directory_handler = self.hs.get_handlers().directory_handler + room_alias = RoomAlias.from_string(room_identifier) yield directory_handler.create_association( - self, - requester.user, - room_identifier, + requester.user.to_string(), + room_alias, room_id ) diff --git a/synapse/storage/directory.py b/synapse/storage/directory.py index cfb687cb53..61a029a53c 100644 --- a/synapse/storage/directory.py +++ b/synapse/storage/directory.py @@ -90,7 +90,7 @@ class DirectoryWorkerStore(SQLBaseStore): class DirectoryStore(DirectoryWorkerStore): @defer.inlineCallbacks def create_room_alias_association(self, room_alias, room_id, servers, creator=None): - """ Creates an associatin between a room alias and room_id/servers + """ Creates an association between a room alias and room_id/servers Args: room_alias (RoomAlias) -- cgit 1.4.1 From 7258d081a5523040dddf5a99ff745acf4fe238bf Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 2 Oct 2018 15:47:57 +0100 Subject: Fix bug when invalidating destination retry timings --- synapse/storage/transactions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'synapse/storage') diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py index ab54977a75..a3032cdce9 100644 --- a/synapse/storage/transactions.py +++ b/synapse/storage/transactions.py @@ -216,7 +216,7 @@ class TransactionStore(SQLBaseStore): retry_interval (int) - how long until next retry in ms """ - self._destination_retry_cache.pop(destination) + self._destination_retry_cache.pop(destination, None) return self.runInteraction( "set_destination_retry_timings", self._set_destination_retry_timings, -- cgit 1.4.1 From ae61ade8919085ad0c90e0b54c4e8338998ee64a Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 2 Oct 2018 23:33:29 +0100 Subject: Fix bug in forward_extremity update logic An event does not stop being a forward_extremity just because an outlier or rejected event refers to it. --- changelog.d/3995.bug | 1 + synapse/storage/events.py | 111 ++++++++++++++++++++++++++++++++-------------- 2 files changed, 79 insertions(+), 33 deletions(-) create mode 100644 changelog.d/3995.bug (limited to 'synapse/storage') diff --git a/changelog.d/3995.bug b/changelog.d/3995.bug new file mode 100644 index 0000000000..2adc36756b --- /dev/null +++ b/changelog.d/3995.bug @@ -0,0 +1 @@ +Fix bug in event persistence logic which caused 'NoneType is not iterable' \ No newline at end of file diff --git a/synapse/storage/events.py b/synapse/storage/events.py index e7487311ce..046174edb3 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -38,6 +38,7 @@ from synapse.storage.background_updates import BackgroundUpdateStore from synapse.storage.event_federation import EventFederationStore from synapse.storage.events_worker import EventsWorkerStore from synapse.types import RoomStreamToken, get_domain_from_id +from synapse.util import batch_iter from synapse.util.async_helpers import ObservableDeferred from synapse.util.caches.descriptors import cached, cachedInlineCallbacks from synapse.util.frozenutils import frozendict_json_encoder @@ -386,12 +387,10 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore ) for room_id, ev_ctx_rm in iteritems(events_by_room): - # Work out new extremities by recursively adding and removing - # the new events. latest_event_ids = yield self.get_latest_event_ids_in_room( room_id ) - new_latest_event_ids = yield self._calculate_new_extremeties( + new_latest_event_ids = yield self._calculate_new_extremities( room_id, ev_ctx_rm, latest_event_ids ) @@ -400,6 +399,12 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore # No change in extremities, so no change in state continue + # there should always be at least one forward extremity. + # (except during the initial persistence of the send_join + # results, in which case there will be no existing + # extremities, so we'll `continue` above and skip this bit.) + assert new_latest_event_ids, "No forward extremities left!" + new_forward_extremeties[room_id] = new_latest_event_ids len_1 = ( @@ -517,44 +522,88 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore ) @defer.inlineCallbacks - def _calculate_new_extremeties(self, room_id, event_contexts, latest_event_ids): - """Calculates the new forward extremeties for a room given events to + def _calculate_new_extremities(self, room_id, event_contexts, latest_event_ids): + """Calculates the new forward extremities for a room given events to persist. Assumes that we are only persisting events for one room at a time. """ - new_latest_event_ids = set(latest_event_ids) - # First, add all the new events to the list - new_latest_event_ids.update( - event.event_id for event, ctx in event_contexts + + # we're only interested in new events which aren't outliers and which aren't + # being rejected. + new_events = [ + event for event, ctx in event_contexts if not event.internal_metadata.is_outlier() and not ctx.rejected + ] + + # start with the existing forward extremities + result = set(latest_event_ids) + + # add all the new events to the list + result.update( + event.event_id for event in new_events ) - # Now remove all events that are referenced by the to-be-added events - new_latest_event_ids.difference_update( + + # Now remove all events which are prev_events of any of the new events + result.difference_update( e_id - for event, ctx in event_contexts + for event in new_events for e_id, _ in event.prev_events - if not event.internal_metadata.is_outlier() and not ctx.rejected ) - # And finally remove any events that are referenced by previously added - # events. - rows = yield self._simple_select_many_batch( - table="event_edges", - column="prev_event_id", - iterable=list(new_latest_event_ids), - retcols=["prev_event_id"], - keyvalues={ - "is_state": False, - }, - desc="_calculate_new_extremeties", - ) + # Finally, remove any events which are prev_events of any existing events. + existing_prevs = yield self._get_events_which_are_prevs(result) + result.difference_update(existing_prevs) - new_latest_event_ids.difference_update( - row["prev_event_id"] for row in rows - ) + if not result: + logger.warn( + "Forward extremity list A+B-C-D is now empty in %s. " + "Old extremities (A): %s, new events (B): %s, " + "existing events which are reffed by new events (C): %s, " + "new events which are reffed by existing events (D): %s", + room_id, latest_event_ids, new_events, + [e_id for event in new_events for e_id, _ in event.prev_events], + existing_prevs, + ) + defer.returnValue(result) - defer.returnValue(new_latest_event_ids) + @defer.inlineCallbacks + def _get_events_which_are_prevs(self, event_ids): + """Filter the supplied list of event_ids to get those which are prev_events of + existing (non-outlier) events. + + Args: + event_ids (Iterable[str]): event ids to filter + + Returns: + Deferred[List[str]]: filtered event ids + """ + results = [] + + def _get_events(txn, batch): + sql = """ + SELECT prev_event_id + FROM event_edges + INNER JOIN events USING (event_id) + LEFT JOIN rejections USING (event_id) + WHERE + prev_event_id IN (%s) + AND rejections.event_id IS NULL + """ % ( + ",".join("?" for _ in batch), + ) + + txn.execute(sql, batch) + results.extend(r[0] for r in txn) + + for chunk in batch_iter(event_ids, 100): + yield self.runInteraction( + "_get_events_which_are_prevs", + _get_events, + chunk, + ) + + defer.returnValue(results) @defer.inlineCallbacks def _get_new_state_after_events(self, room_id, events_context, old_latest_event_ids, @@ -586,10 +635,6 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore the new current state is only returned if we've already calculated it. """ - - if not new_latest_event_ids: - return - # map from state_group to ((type, key) -> event_id) state map state_groups_map = {} -- cgit 1.4.1 From 3e39783d5d7ca5da7b3619d0328f6aeec48854de Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 2 Oct 2018 23:44:14 +0100 Subject: remove debugging --- synapse/storage/events.py | 10 ---------- 1 file changed, 10 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 046174edb3..8822dc7bcb 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -555,16 +555,6 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore existing_prevs = yield self._get_events_which_are_prevs(result) result.difference_update(existing_prevs) - if not result: - logger.warn( - "Forward extremity list A+B-C-D is now empty in %s. " - "Old extremities (A): %s, new events (B): %s, " - "existing events which are reffed by new events (C): %s, " - "new events which are reffed by existing events (D): %s", - room_id, latest_event_ids, new_events, - [e_id for event in new_events for e_id, _ in event.prev_events], - existing_prevs, - ) defer.returnValue(result) @defer.inlineCallbacks -- cgit 1.4.1 From 9693625e556df1af66ba376d49411064c2d0f47e Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 3 Oct 2018 10:19:41 +0100 Subject: actually exclude outliers --- synapse/storage/events.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'synapse/storage') diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 8822dc7bcb..03cedf3a75 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -560,7 +560,7 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore @defer.inlineCallbacks def _get_events_which_are_prevs(self, event_ids): """Filter the supplied list of event_ids to get those which are prev_events of - existing (non-outlier) events. + existing (non-outlier/rejected) events. Args: event_ids (Iterable[str]): event ids to filter @@ -578,6 +578,7 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore LEFT JOIN rejections USING (event_id) WHERE prev_event_id IN (%s) + AND NOT events.outlier AND rejections.event_id IS NULL """ % ( ",".join("?" for _ in batch), -- cgit 1.4.1 From 17d585753f1df2b2c2b13ddb8171e174cef97aac Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 4 Oct 2018 15:18:52 +0100 Subject: Delete unreferened state groups during purge --- synapse/storage/events.py | 33 +++++++++++++++++++++++++------ synapse/storage/state.py | 50 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 77 insertions(+), 6 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/events.py b/synapse/storage/events.py index e7487311ce..0fb190530a 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -2025,6 +2025,7 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore logger.info("[purge] finding state groups which depend on redundant" " state groups") remaining_state_groups = [] + unreferenced_state_groups = 0 for i in range(0, len(state_rows), 100): chunk = [sg for sg, in state_rows[i:i + 100]] # look for state groups whose prev_state_group is one we are about @@ -2037,13 +2038,33 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore retcols=["state_group"], keyvalues={}, ) - remaining_state_groups.extend( - row["state_group"] for row in rows - # exclude state groups we are about to delete: no point in - # updating them - if row["state_group"] not in state_groups_to_delete - ) + for row in rows: + sg = row["state_group"] + + if sg in state_groups_to_delete: + # exclude state groups we are about to delete: no point in + # updating them + continue + + if not self._is_state_group_referenced(txn, sg): + # Let's also delete unreferenced state groups while we're + # here, since otherwise we'd need to de-delta them + state_groups_to_delete.add(sg) + unreferenced_state_groups += 1 + continue + + remaining_state_groups.append(sg) + + logger.info( + "[purge] found %i extra unreferenced state groups to delete", + unreferenced_state_groups, + ) + + logger.info( + "[purge] de-delta-ing %i remaining state groups", + len(remaining_state_groups), + ) # Now we turn the state groups that reference to-be-deleted state # groups to non delta versions. diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 3f4cbd61c4..b88c7dc091 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -1041,6 +1041,56 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): return count + def _is_state_group_referenced(self, txn, state_group): + """Checks if a given state group is referenced, or is safe to delete. + + A state groups is referenced if it or any of its descendants are + pointed at by an event. (A descendant is a group which has the given + state_group as a prev group) + """ + + # We check this by doing a depth first search to look for any + # descendant referenced by `event_to_state_groups`. + + # State groups we need to check, contains state groups that are + # descendants of `state_group` + state_groups_to_search = [state_group] + + # Set of state groups we've already checked + state_groups_searched = set() + + while state_groups_to_search: + state_group = state_groups_to_search.pop() # Next state group to check + + is_referenced = self._simple_select_one_onecol_txn( + txn, + table="event_to_state_groups", + keyvalues={"state_group": state_group}, + retcol="event_id", + allow_none=True, + ) + if is_referenced: + # A descendant is referenced by event_to_state_groups, so + # original state group is referenced. + return True + + state_groups_searched.add(state_group) + + # Find all children of current state group and add to search + references = self._simple_select_onecol_txn( + txn, + table="state_group_edges", + keyvalues={"prev_state_group": state_group}, + retcol="state_group", + ) + state_groups_to_search.extend(references) + + # Lets be paranoid and check for cycles + if state_groups_searched.intersection(references): + raise Exception("State group %s has cyclic dependency", state_group) + + return False + class StateStore(StateGroupWorkerStore, BackgroundUpdateStore): """ Keeps track of the state at a given event. -- cgit 1.4.1 From 4917ff55234718c0e650c6dc2a1117304465b9be Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 4 Oct 2018 15:24:01 +0100 Subject: Add state_group index to event_to_state_groups This is needed to efficiently check for unreferenced state groups during purge. --- synapse/storage/prepare_database.py | 2 +- .../delta/52/add_event_to_state_group_index.sql | 19 +++++++++++++++++++ synapse/storage/state.py | 7 +++++++ 3 files changed, 27 insertions(+), 1 deletion(-) create mode 100644 synapse/storage/schema/delta/52/add_event_to_state_group_index.sql (limited to 'synapse/storage') diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index b364719312..bd740e1e45 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -25,7 +25,7 @@ logger = logging.getLogger(__name__) # Remember to update this number every time a change is made to database # schema files, so the users will be informed on server restarts. -SCHEMA_VERSION = 51 +SCHEMA_VERSION = 52 dir_path = os.path.abspath(os.path.dirname(__file__)) diff --git a/synapse/storage/schema/delta/52/add_event_to_state_group_index.sql b/synapse/storage/schema/delta/52/add_event_to_state_group_index.sql new file mode 100644 index 0000000000..91e03d13e1 --- /dev/null +++ b/synapse/storage/schema/delta/52/add_event_to_state_group_index.sql @@ -0,0 +1,19 @@ +/* Copyright 2018 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- This is needed to efficiently check for unreferenced state groups during +-- purge. Added events_to_state_group(state_group) index +INSERT into background_updates (update_name, progress_json) + VALUES ('event_to_state_groups_sg_index', '{}'); diff --git a/synapse/storage/state.py b/synapse/storage/state.py index b88c7dc091..3f08f447a3 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -1114,6 +1114,7 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore): STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication" STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index" CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx" + EVENT_STATE_GROUP_INDEX_UPDATE_NAME = "event_to_state_groups_sg_index" def __init__(self, db_conn, hs): super(StateStore, self).__init__(db_conn, hs) @@ -1132,6 +1133,12 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore): columns=["state_key"], where_clause="type='m.room.member'", ) + self.register_background_index_update( + self.EVENT_STATE_GROUP_INDEX_UPDATE_NAME, + index_name="event_to_state_groups_sg_index", + table="event_to_state_groups", + columns=["state_group"], + ) def _store_event_state_mappings_txn(self, txn, events_and_contexts): state_groups = {} -- cgit 1.4.1 From 497444f1fdd2c39906179b1dde8c67415e465398 Mon Sep 17 00:00:00 2001 From: David Baker Date: Fri, 5 Oct 2018 15:08:36 +0100 Subject: Don't reuse backup versions Since we don't actually delete the keys, just mark the versions as deleted in the db rather than actually deleting them, then we won't reuse versions. Fixes https://github.com/vector-im/riot-web/issues/7448 --- synapse/storage/e2e_room_keys.py | 9 +++++++-- synapse/storage/schema/delta/51/e2e_room_keys.sql | 3 ++- 2 files changed, 9 insertions(+), 3 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/e2e_room_keys.py b/synapse/storage/e2e_room_keys.py index 969f4aef9c..4d439bb164 100644 --- a/synapse/storage/e2e_room_keys.py +++ b/synapse/storage/e2e_room_keys.py @@ -193,7 +193,8 @@ class EndToEndRoomKeyStore(SQLBaseStore): @staticmethod def _get_current_version(txn, user_id): txn.execute( - "SELECT MAX(version) FROM e2e_room_keys_versions WHERE user_id=?", + "SELECT MAX(version) FROM e2e_room_keys_versions " + "WHERE user_id=? AND deleted=0", (user_id,) ) row = txn.fetchone() @@ -226,6 +227,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): keyvalues={ "user_id": user_id, "version": this_version, + "deleted": 0, }, retcols=( "version", @@ -300,13 +302,16 @@ class EndToEndRoomKeyStore(SQLBaseStore): else: this_version = version - return self._simple_delete_one_txn( + return self._simple_update_one_txn( txn, table="e2e_room_keys_versions", keyvalues={ "user_id": user_id, "version": this_version, }, + updatevalues={ + "deleted": 1, + } ) return self.runInteraction( diff --git a/synapse/storage/schema/delta/51/e2e_room_keys.sql b/synapse/storage/schema/delta/51/e2e_room_keys.sql index 4531fd56ee..c0e66a697d 100644 --- a/synapse/storage/schema/delta/51/e2e_room_keys.sql +++ b/synapse/storage/schema/delta/51/e2e_room_keys.sql @@ -32,7 +32,8 @@ CREATE TABLE e2e_room_keys_versions ( user_id TEXT NOT NULL, version TEXT NOT NULL, algorithm TEXT NOT NULL, - auth_data TEXT NOT NULL + auth_data TEXT NOT NULL, + deleted SMALLINT DEFAULT 0 NOT NULL ); CREATE UNIQUE INDEX e2e_room_keys_versions_idx ON e2e_room_keys_versions(user_id, version); -- cgit 1.4.1 From 8ddd0f273c501b83b75a58e379d73aef3cfab35e Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 12 Oct 2018 09:55:41 +0100 Subject: Comments on get_all_new_events_stream just some docstrings to clarify the behaviour here --- synapse/storage/stream.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) (limited to 'synapse/storage') diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index 4c296d72c0..d6cfdba519 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -630,7 +630,21 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): @defer.inlineCallbacks def get_all_new_events_stream(self, from_id, current_id, limit): - """Get all new events""" + """Get all new events + + Returns all events with from_id < stream_ordering <= current_id. + + Args: + from_id (int): the stream_ordering of the last event we processed + current_id (int): the stream_ordering of the most recently processed event + limit (int): the maximum number of events to return + + Returns: + Deferred[Tuple[int, list[FrozenEvent]]]: A tuple of (next_id, events), where + `next_id` is the next value to pass as `from_id` (it will either be the + stream_ordering of the last returned event, or, if fewer than `limit` events + were found, `current_id`. + """ def get_all_new_events_stream_txn(txn): sql = ( -- cgit 1.4.1 From 306361b31bc0fb229de8db29209b9da38374d6f2 Mon Sep 17 00:00:00 2001 From: David Baker Date: Fri, 12 Oct 2018 11:48:56 +0100 Subject: Misc PR feedback bits --- synapse/storage/e2e_room_keys.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/e2e_room_keys.py b/synapse/storage/e2e_room_keys.py index 4d439bb164..f25ded2295 100644 --- a/synapse/storage/e2e_room_keys.py +++ b/synapse/storage/e2e_room_keys.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import simplejson as json +import json from twisted.internet import defer @@ -76,7 +76,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): session_id(str): the session whose room_key we're setting room_key(dict): the room_key being set Raises: - StoreError if stuff goes wrong, probably + StoreError """ yield self._simple_upsert( -- cgit 1.4.1 From 67a1e315cc7f8b270ec87e0ab6e58aa2adaee32d Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 12 Oct 2018 13:49:48 +0100 Subject: Fix up comments --- synapse/storage/state.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 3f08f447a3..f7cf5c86c9 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -1044,9 +1044,9 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): def _is_state_group_referenced(self, txn, state_group): """Checks if a given state group is referenced, or is safe to delete. - A state groups is referenced if it or any of its descendants are - pointed at by an event. (A descendant is a group which has the given - state_group as a prev group) + A state group is referenced if it or any of its descendants are + pointed at by an event. (A descendant is a state_group whose chain of + prev_groups includes the given state_group.) """ # We check this by doing a depth first search to look for any -- cgit 1.4.1 From e238013c44714174f1bf9a2b9b1f576728f40784 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 3 Sep 2018 15:13:17 +0100 Subject: Add v2 state res algorithm. We hook this up to the vdh test room version. --- synapse/state/__init__.py | 87 ++++++-- synapse/state/v2.py | 544 ++++++++++++++++++++++++++++++++++++++++++++++ synapse/storage/events.py | 9 +- 3 files changed, 616 insertions(+), 24 deletions(-) create mode 100644 synapse/state/v2.py (limited to 'synapse/storage') diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index b22495c1f9..836e137296 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -19,13 +19,14 @@ from collections import namedtuple from six import iteritems, itervalues +import attr from frozendict import frozendict from twisted.internet import defer from synapse.api.constants import EventTypes, RoomVersions from synapse.events.snapshot import EventContext -from synapse.state import v1 +from synapse.state import v1, v2 from synapse.util.async_helpers import Linearizer from synapse.util.caches import get_cache_factor_for from synapse.util.caches.expiringcache import ExpiringCache @@ -372,15 +373,10 @@ class StateHandler(object): result = yield self._state_resolution_handler.resolve_state_groups( room_id, room_version, state_groups_ids, None, - self._state_map_factory, + state_res_store=StateResolutionStore(self.store), ) defer.returnValue(result) - def _state_map_factory(self, ev_ids): - return self.store.get_events( - ev_ids, get_prev_content=False, check_redacted=False, - ) - @defer.inlineCallbacks def resolve_events(self, room_version, state_sets, event): logger.info( @@ -401,7 +397,7 @@ class StateHandler(object): new_state = yield resolve_events_with_factory( room_version, state_set_ids, event_map=state_map, - state_map_factory=self._state_map_factory + state_res_store=StateResolutionStore(self.store), ) new_state = { @@ -436,7 +432,7 @@ class StateResolutionHandler(object): @defer.inlineCallbacks @log_function def resolve_state_groups( - self, room_id, room_version, state_groups_ids, event_map, state_map_factory, + self, room_id, room_version, state_groups_ids, event_map, state_res_store, ): """Resolves conflicts between a set of state groups @@ -454,9 +450,11 @@ class StateResolutionHandler(object): a dict from event_id to event, for any events that we happen to have in flight (eg, those currently being persisted). This will be used as a starting point fof finding the state we need; any missing - events will be requested via state_map_factory. + events will be requested via state_res_store. + + If None, all events will be fetched via state_res_store. - If None, all events will be fetched via state_map_factory. + state_res_store (StateResolutionStore) Returns: Deferred[_StateCacheEntry]: resolved state @@ -502,7 +500,7 @@ class StateResolutionHandler(object): room_version, list(itervalues(state_groups_ids)), event_map=event_map, - state_map_factory=state_map_factory, + state_res_store=state_res_store, ) # if the new state matches any of the input state groups, we can @@ -583,7 +581,7 @@ def _make_state_cache_entry( ) -def resolve_events_with_factory(room_version, state_sets, event_map, state_map_factory): +def resolve_events_with_factory(room_version, state_sets, event_map, state_res_store): """ Args: room_version(str): Version of the room @@ -599,17 +597,19 @@ def resolve_events_with_factory(room_version, state_sets, event_map, state_map_f If None, all events will be fetched via state_map_factory. - state_map_factory(func): will be called - with a list of event_ids that are needed, and should return with - a Deferred of dict of event_id to event. + state_res_store (StateResolutionStore) Returns Deferred[dict[(str, str), str]]: a map from (type, state_key) to event_id. """ - if room_version in (RoomVersions.V1, RoomVersions.VDH_TEST,): + if room_version == RoomVersions.V1: return v1.resolve_events_with_factory( - state_sets, event_map, state_map_factory, + state_sets, event_map, state_res_store.get_events, + ) + elif room_version == RoomVersions.VDH_TEST: + return v2.resolve_events_with_factory( + state_sets, event_map, state_res_store, ) else: # This should only happen if we added a version but forgot to add it to @@ -617,3 +617,54 @@ def resolve_events_with_factory(room_version, state_sets, event_map, state_map_f raise Exception( "No state resolution algorithm defined for version %r" % (room_version,) ) + + +@attr.s +class StateResolutionStore(object): + """Interface that allows state resolution algorithms to access the database + in well defined way. + + Args: + store (DataStore) + """ + + store = attr.ib() + + def get_events(self, event_ids, allow_rejected=False): + """Get events from the database + + Args: + event_ids (list): The event_ids of the events to fetch + allow_rejected (bool): If True return rejected events. + + Returns: + Deferred[dict[str, FrozenEvent]]: Dict from event_id to event. + """ + + return self.store.get_events( + event_ids, + check_redacted=False, + get_prev_content=False, + allow_rejected=allow_rejected, + ) + + def get_auth_chain(self, event_ids): + """Gets the full auth chain for a set of events (including rejected + events). + + Includes the given event IDs in the result. + + Note that: + 1. All events must be state events. + 2. For v1 rooms this may not have the full auth chain in the + presence of rejected events + + Args: + event_ids (list): The event IDs of the events to fetch the auth + chain for. Must be state events. + + Returns: + Deferred[list[str]]: List of event IDs of the auth chain. + """ + + return self.store.get_auth_chain_ids(event_ids, include_given=True) diff --git a/synapse/state/v2.py b/synapse/state/v2.py new file mode 100644 index 0000000000..d034b4f38e --- /dev/null +++ b/synapse/state/v2.py @@ -0,0 +1,544 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import heapq +import itertools +import logging + +from six import iteritems, itervalues + +from twisted.internet import defer + +from synapse import event_auth +from synapse.api.constants import EventTypes +from synapse.api.errors import AuthError + +logger = logging.getLogger(__name__) + + +@defer.inlineCallbacks +def resolve_events_with_factory(state_sets, event_map, state_res_store): + """Resolves the state using the v2 state resolution algorithm + + Args: + state_sets(list): List of dicts of (type, state_key) -> event_id, + which are the different state groups to resolve. + + event_map(dict[str,FrozenEvent]|None): + a dict from event_id to event, for any events that we happen to + have in flight (eg, those currently being persisted). This will be + used as a starting point fof finding the state we need; any missing + events will be requested via state_res_store. + + If None, all events will be fetched via state_res_store. + + state_res_store (StateResolutionStore) + + Returns + Deferred[dict[(str, str), str]]: + a map from (type, state_key) to event_id. + """ + + logger.debug("Computing conflicted state") + + # First split up the un/conflicted state + unconflicted_state, conflicted_state = _seperate(state_sets) + + if not conflicted_state: + defer.returnValue(unconflicted_state) + + logger.debug("%d conflicted state entries", len(conflicted_state)) + logger.debug("Calculating auth chain difference") + + # Also fetch all auth events that appear in only some of the state sets' + # auth chains. + auth_diff = yield _get_auth_chain_difference( + state_sets, event_map, state_res_store, + ) + + full_conflicted_set = set(itertools.chain( + itertools.chain.from_iterable(itervalues(conflicted_state)), + auth_diff, + )) + + events = yield state_res_store.get_events([ + eid for eid in full_conflicted_set + if eid not in event_map + ], allow_rejected=True) + event_map.update(events) + + full_conflicted_set = set(eid for eid in full_conflicted_set if eid in event_map) + + logger.debug("%d full_conflicted_set entries", len(full_conflicted_set)) + + # Get and sort all the power events (kicks/bans/etc) + power_events = ( + eid for eid in full_conflicted_set + if _is_power_event(event_map[eid]) + ) + + sorted_power_events = yield _reverse_topological_power_sort( + power_events, + event_map, + state_res_store, + full_conflicted_set, + ) + + logger.debug("sorted %d power events", len(sorted_power_events)) + + # Now sequentially auth each one + resolved_state = yield _iterative_auth_checks( + sorted_power_events, unconflicted_state, event_map, + state_res_store, + ) + + logger.debug("resolved power events") + + # OK, so we've now resolved the power events. Now sort the remaining + # events using the mainline of the resolved power level. + + leftover_events = [ + ev_id + for ev_id in full_conflicted_set + if ev_id not in sorted_power_events + ] + + logger.debug("sorting %d remaining events", len(leftover_events)) + + pl = resolved_state.get((EventTypes.PowerLevels, ""), None) + leftover_events = yield _mainline_sort( + leftover_events, pl, event_map, state_res_store, + ) + + logger.debug("resolving remaining events") + + resolved_state = yield _iterative_auth_checks( + leftover_events, resolved_state, event_map, + state_res_store, + ) + + logger.debug("resolved") + + # We make sure that unconflicted state always still applies. + resolved_state.update(unconflicted_state) + + logger.debug("done") + + defer.returnValue(resolved_state) + + +@defer.inlineCallbacks +def _get_power_level_for_sender(event_id, event_map, state_res_store): + """Return the power level of the sender of the given event according to + their auth events. + + Args: + event_id (str) + event_map (dict[str,FrozenEvent]) + state_res_store (StateResolutionStore) + + Returns: + Deferred[int] + """ + event = yield _get_event(event_id, event_map, state_res_store) + + pl = None + for aid, _ in event.auth_events: + aev = yield _get_event(aid, event_map, state_res_store) + if (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""): + pl = aev + break + + if pl is None: + # Couldn't find power level. Check if they're the creator of the room + for aid, _ in event.auth_events: + aev = yield _get_event(aid, event_map, state_res_store) + if (aev.type, aev.state_key) == (EventTypes.Create, ""): + if aev.content.get("creator") == event.sender: + defer.returnValue(100) + break + defer.returnValue(0) + + level = pl.content.get("users", {}).get(event.sender) + if level is None: + level = pl.content.get("users_default", 0) + + if level is None: + defer.returnValue(0) + else: + defer.returnValue(int(level)) + + +@defer.inlineCallbacks +def _get_auth_chain_difference(state_sets, event_map, state_res_store): + """Compare the auth chains of each state set and return the set of events + that only appear in some but not all of the auth chains. + + Args: + state_sets (list) + event_map (dict[str,FrozenEvent]) + state_res_store (StateResolutionStore) + + Returns: + Deferred[set[str]]: Set of event IDs + """ + common = set(itervalues(state_sets[0])).intersection( + *(itervalues(s) for s in state_sets[1:]) + ) + + auth_sets = [] + for state_set in state_sets: + auth_ids = set( + eid + for key, eid in iteritems(state_set) + if (key[0] in ( + EventTypes.Member, + EventTypes.ThirdPartyInvite, + ) or key in ( + (EventTypes.PowerLevels, ''), + (EventTypes.Create, ''), + (EventTypes.JoinRules, ''), + )) and eid not in common + ) + + auth_chain = yield state_res_store.get_auth_chain(auth_ids) + auth_ids.update(auth_chain) + + auth_sets.append(auth_ids) + + intersection = set(auth_sets[0]).intersection(*auth_sets[1:]) + union = set().union(*auth_sets) + + defer.returnValue(union - intersection) + + +def _seperate(state_sets): + """Return the unconflicted and conflicted state. This is different than in + the original algorithm, as this defines a key to be conflicted if one of + the state sets doesn't have that key. + + Args: + state_sets (list) + + Returns: + tuple[dict, dict]: A tuple of unconflicted and conflicted state. The + conflicted state dict is a map from type/state_key to set of event IDs + """ + unconflicted_state = {} + conflicted_state = {} + + for key in set(itertools.chain.from_iterable(state_sets)): + event_ids = set(state_set.get(key) for state_set in state_sets) + if len(event_ids) == 1: + unconflicted_state[key] = event_ids.pop() + else: + event_ids.discard(None) + conflicted_state[key] = event_ids + + return unconflicted_state, conflicted_state + + +def _is_power_event(event): + """Return whether or not the event is a "power event", as defined by the + v2 state resolution algorithm + + Args: + event (FrozenEvent) + + Returns: + boolean + """ + if (event.type, event.state_key) in ( + (EventTypes.PowerLevels, ""), + (EventTypes.JoinRules, ""), + (EventTypes.Create, ""), + ): + return True + + if event.type == EventTypes.Member: + if event.membership in ('leave', 'ban'): + return event.sender != event.state_key + + return False + + +@defer.inlineCallbacks +def _add_event_and_auth_chain_to_graph(graph, event_id, event_map, + state_res_store, auth_diff): + """Helper function for _reverse_topological_power_sort that add the event + and its auth chain (that is in the auth diff) to the graph + + Args: + graph (dict[str, set[str]]): A map from event ID to the events auth + event IDs + event_id (str): Event to add to the graph + event_map (dict[str,FrozenEvent]) + state_res_store (StateResolutionStore) + auth_diff (set[str]): Set of event IDs that are in the auth difference. + """ + + state = [event_id] + while state: + eid = state.pop() + graph.setdefault(eid, set()) + + event = yield _get_event(eid, event_map, state_res_store) + for aid, _ in event.auth_events: + if aid in auth_diff: + if aid not in graph: + state.append(aid) + + graph.setdefault(eid, set()).add(aid) + + +@defer.inlineCallbacks +def _reverse_topological_power_sort(event_ids, event_map, state_res_store, auth_diff): + """Returns a list of the event_ids sorted by reverse topological ordering, + and then by power level and origin_server_ts + + Args: + event_ids (list[str]): The events to sort + event_map (dict[str,FrozenEvent]) + state_res_store (StateResolutionStore) + auth_diff (set[str]): Set of event IDs that are in the auth difference. + + Returns: + Deferred[list[str]]: The sorted list + """ + + graph = {} + for event_id in event_ids: + yield _add_event_and_auth_chain_to_graph( + graph, event_id, event_map, state_res_store, auth_diff, + ) + + event_to_pl = {} + for event_id in graph: + pl = yield _get_power_level_for_sender(event_id, event_map, state_res_store) + event_to_pl[event_id] = pl + + def _get_power_order(event_id): + ev = event_map[event_id] + pl = event_to_pl[event_id] + + return -pl, ev.origin_server_ts, event_id + + # Note: graph is modified during the sort + it = lexicographical_topological_sort( + graph, + key=_get_power_order, + ) + sorted_events = list(it) + + defer.returnValue(sorted_events) + + +@defer.inlineCallbacks +def _iterative_auth_checks(event_ids, base_state, event_map, state_res_store): + """Sequentially apply auth checks to each event in given list, updating the + state as it goes along. + + Args: + event_ids (list[str]): Ordered list of events to apply auth checks to + base_state (dict[tuple[str, str], str]): The set of state to start with + event_map (dict[str,FrozenEvent]) + state_res_store (StateResolutionStore) + + Returns: + Deferred[dict[tuple[str, str], str]]: Returns the final updated state + """ + resolved_state = base_state.copy() + + for event_id in event_ids: + event = event_map[event_id] + + auth_events = {} + for aid, _ in event.auth_events: + ev = yield _get_event(aid, event_map, state_res_store) + + if ev.rejected_reason is None: + auth_events[(ev.type, ev.state_key)] = ev + + for key in event_auth.auth_types_for_event(event): + if key in resolved_state: + ev_id = resolved_state[key] + ev = yield _get_event(ev_id, event_map, state_res_store) + + if ev.rejected_reason is None: + auth_events[key] = event_map[ev_id] + + try: + event_auth.check( + event, auth_events, + do_sig_check=False, + do_size_check=False + ) + + resolved_state[(event.type, event.state_key)] = event_id + except AuthError: + pass + + defer.returnValue(resolved_state) + + +@defer.inlineCallbacks +def _mainline_sort(event_ids, resolved_power_event_id, event_map, + state_res_store): + """Returns a sorted list of event_ids sorted by mainline ordering based on + the given event resolved_power_event_id + + Args: + event_ids (list[str]): Events to sort + resolved_power_event_id (str): The final resolved power level event ID + event_map (dict[str,FrozenEvent]) + state_res_store (StateResolutionStore) + + Returns: + Deferred[list[str]]: The sorted list + """ + mainline = [] + pl = resolved_power_event_id + while pl: + mainline.append(pl) + pl_ev = yield _get_event(pl, event_map, state_res_store) + auth_events = pl_ev.auth_events + pl = None + for aid, _ in auth_events: + ev = yield _get_event(aid, event_map, state_res_store) + if (ev.type, ev.state_key) == (EventTypes.PowerLevels, ""): + pl = aid + break + + mainline_map = {ev_id: i + 1 for i, ev_id in enumerate(reversed(mainline))} + + event_ids = list(event_ids) + + order_map = {} + for ev_id in event_ids: + depth = yield _get_mainline_depth_for_event( + event_map[ev_id], mainline_map, + event_map, state_res_store, + ) + order_map[ev_id] = (depth, event_map[ev_id].origin_server_ts, ev_id) + + event_ids.sort(key=lambda ev_id: order_map[ev_id]) + + defer.returnValue(event_ids) + + +@defer.inlineCallbacks +def _get_mainline_depth_for_event(event, mainline_map, event_map, state_res_store): + """Get the mainline depths for the given event based on the mainline map + + Args: + event (FrozenEvent) + mainline_map (dict[str, int]): Map from event_id to mainline depth for + events in the mainline. + event_map (dict[str,FrozenEvent]) + state_res_store (StateResolutionStore) + + Returns: + Deferred[int] + """ + + # We do an iterative search, replacing `event with the power level in its + # auth events (if any) + while event: + depth = mainline_map.get(event.event_id) + if depth is not None: + defer.returnValue(depth) + + auth_events = event.auth_events + event = None + + for aid, _ in auth_events: + aev = yield _get_event(aid, event_map, state_res_store) + if (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""): + event = aev + break + + # Didn't find a power level auth event, so we just return 0 + defer.returnValue(0) + + +@defer.inlineCallbacks +def _get_event(event_id, event_map, state_res_store): + """Helper function to look up event in event_map, falling back to looking + it up in the store + + Args: + event_id (str) + event_map (dict[str,FrozenEvent]) + state_res_store (StateResolutionStore) + + Returns: + Deferred[FrozenEvent] + """ + if event_id not in event_map: + events = yield state_res_store.get_events([event_id], allow_rejected=True) + event_map.update(events) + defer.returnValue(event_map[event_id]) + + +def lexicographical_topological_sort(graph, key): + """Performs a lexicographic reverse topological sort on the graph. + + This returns a reverse topological sort (i.e. if node A references B then B + appears before A in the sort), with ties broken lexicographically based on + return value of the `key` function. + + NOTE: `graph` is modified during the sort. + + Args: + graph (dict[str, set[str]]): A representation of the graph where each + node is a key in the dict and its value are the nodes edges. + key (func): A function that takes a node and returns a value that is + comparable and used to order nodes + + Yields: + str: The next node in the topological sort + """ + + # Note, this is basically Kahn's algorithm except we look at nodes with no + # outgoing edges, c.f. + # https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm + outdegree_map = graph + reverse_graph = {} + + # Lists of nodes with zero out degree. Is actually a tuple of + # `(key(node), node)` so that sorting does the right thing + zero_outdegree = [] + + for node, edges in iteritems(graph): + if len(edges) == 0: + zero_outdegree.append((key(node), node)) + + reverse_graph.setdefault(node, set()) + for edge in edges: + reverse_graph.setdefault(edge, set()).add(node) + + # heapq is a built in implementation of a sorted queue. + heapq.heapify(zero_outdegree) + + while zero_outdegree: + _, node = heapq.heappop(zero_outdegree) + + for parent in reverse_graph[node]: + out = outdegree_map[parent] + out.discard(node) + if len(out) == 0: + heapq.heappush(zero_outdegree, (key(parent), parent)) + + yield node diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 03cedf3a75..fc88edcb39 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -34,6 +34,7 @@ from synapse.api.errors import SynapseError from synapse.events import EventBase # noqa: F401 from synapse.events.snapshot import EventContext # noqa: F401 from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.state import StateResolutionStore from synapse.storage.background_updates import BackgroundUpdateStore from synapse.storage.event_federation import EventFederationStore from synapse.storage.events_worker import EventsWorkerStore @@ -731,11 +732,6 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore # Ok, we need to defer to the state handler to resolve our state sets. - def get_events(ev_ids): - return self.get_events( - ev_ids, get_prev_content=False, check_redacted=False, - ) - state_groups = { sg: state_groups_map[sg] for sg in new_state_groups } @@ -745,7 +741,8 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore logger.debug("calling resolve_state_groups from preserve_events") res = yield self._state_resolution_handler.resolve_state_groups( - room_id, room_version, state_groups, events_map, get_events + room_id, room_version, state_groups, events_map, + state_res_store=StateResolutionStore(self) ) defer.returnValue((res.state, None)) -- cgit 1.4.1 From 4a28d3d36fa21446e60a2a5142626a385e1f27af Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 16 Oct 2018 14:01:53 +0100 Subject: Update event_auth table for rejected events --- synapse/storage/events.py | 36 +++++++++++++++++++++--------------- 1 file changed, 21 insertions(+), 15 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/events.py b/synapse/storage/events.py index fc88edcb39..c780f55277 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -851,6 +851,27 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore # Insert into event_to_state_groups. self._store_event_state_mappings_txn(txn, events_and_contexts) + # We want to store event_auth mappings for rejected events, as they're + # used in state res v2. + # This is only necessary if the rejected event appears in an accepted + # event's auth chain, but its easier for now just to store them (and + # it doesn't take much storage compared to storing the entire event + # anyway). + self._simple_insert_many_txn( + txn, + table="event_auth", + values=[ + { + "event_id": event.event_id, + "room_id": event.room_id, + "auth_id": auth_id, + } + for event, _ in events_and_contexts + for auth_id, _ in event.auth_events + if event.is_state() + ], + ) + # _store_rejected_events_txn filters out any events which were # rejected, and returns the filtered list. events_and_contexts = self._store_rejected_events_txn( @@ -1326,21 +1347,6 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore txn, event.room_id, event.redacts ) - self._simple_insert_many_txn( - txn, - table="event_auth", - values=[ - { - "event_id": event.event_id, - "room_id": event.room_id, - "auth_id": auth_id, - } - for event, _ in events_and_contexts - for auth_id, _ in event.auth_events - if event.is_state() - ], - ) - # Update the event_forward_extremities, event_backward_extremities and # event_edges tables. self._handle_mult_prev_events( -- cgit 1.4.1 From fc0f13dd036cec4e41f5969d021d9dd10d6e5016 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 16 Oct 2018 20:37:16 +0100 Subject: Fix incorrect truncation in get_missing_events It's quite important that get_missing_events returns the *latest* events in the room; however we were pulling event ids out of the database until we got *at least* 10, and then taking the *earliest* of the results. We also shouldn't really be relying on depth, and should be checking the room_id. --- changelog.d/4045.bugfix | 1 + synapse/federation/federation_server.py | 8 +++---- synapse/federation/transport/server.py | 2 -- synapse/handlers/federation.py | 12 +++++------ synapse/storage/event_federation.py | 38 ++++++++++++++------------------- 5 files changed, 26 insertions(+), 35 deletions(-) create mode 100644 changelog.d/4045.bugfix (limited to 'synapse/storage') diff --git a/changelog.d/4045.bugfix b/changelog.d/4045.bugfix new file mode 100644 index 0000000000..fa50eb5aff --- /dev/null +++ b/changelog.d/4045.bugfix @@ -0,0 +1 @@ +Fix bug which made get_missing_events return too few events \ No newline at end of file diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 819e8f7331..4efe95faa4 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -507,19 +507,19 @@ class FederationServer(FederationBase): @defer.inlineCallbacks @log_function def on_get_missing_events(self, origin, room_id, earliest_events, - latest_events, limit, min_depth): + latest_events, limit): with (yield self._server_linearizer.queue((origin, room_id))): origin_host, _ = parse_server_name(origin) yield self.check_server_matches_acl(origin_host, room_id) logger.info( "on_get_missing_events: earliest_events: %r, latest_events: %r," - " limit: %d, min_depth: %d", - earliest_events, latest_events, limit, min_depth + " limit: %d", + earliest_events, latest_events, limit, ) missing_events = yield self.handler.on_get_missing_events( - origin, room_id, earliest_events, latest_events, limit, min_depth + origin, room_id, earliest_events, latest_events, limit, ) if len(missing_events) < 5: diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index 2f874b4838..7288d49074 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -560,7 +560,6 @@ class FederationGetMissingEventsServlet(BaseFederationServlet): @defer.inlineCallbacks def on_POST(self, origin, content, query, room_id): limit = int(content.get("limit", 10)) - min_depth = int(content.get("min_depth", 0)) earliest_events = content.get("earliest_events", []) latest_events = content.get("latest_events", []) @@ -569,7 +568,6 @@ class FederationGetMissingEventsServlet(BaseFederationServlet): room_id=room_id, earliest_events=earliest_events, latest_events=latest_events, - min_depth=min_depth, limit=limit, ) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 45d955e6f5..cab57a8849 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -309,8 +309,8 @@ class FederationHandler(BaseHandler): if sent_to_us_directly: logger.warn( - "[%s %s] Failed to fetch %d prev events: rejecting", - room_id, event_id, len(prevs - seen), + "[%s %s] Rejecting: failed to fetch %d prev events: %s", + room_id, event_id, len(prevs - seen), shortstr(prevs - seen) ) raise FederationError( "ERROR", @@ -452,8 +452,8 @@ class FederationHandler(BaseHandler): latest |= seen logger.info( - "[%s %s]: Requesting %d prev_events: %s", - room_id, event_id, len(prevs - seen), shortstr(prevs - seen) + "[%s %s]: Requesting missing events between %s and %s", + room_id, event_id, shortstr(latest), event_id, ) # XXX: we set timeout to 10s to help workaround @@ -1852,7 +1852,7 @@ class FederationHandler(BaseHandler): @defer.inlineCallbacks def on_get_missing_events(self, origin, room_id, earliest_events, - latest_events, limit, min_depth): + latest_events, limit): in_room = yield self.auth.check_host_in_room( room_id, origin @@ -1861,14 +1861,12 @@ class FederationHandler(BaseHandler): raise AuthError(403, "Host not in room.") limit = min(limit, 20) - min_depth = max(min_depth, 0) missing_events = yield self.store.get_missing_events( room_id=room_id, earliest_events=earliest_events, latest_events=latest_events, limit=limit, - min_depth=min_depth, ) missing_events = yield filter_events_for_server( diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py index 24345b20a6..3faca2a042 100644 --- a/synapse/storage/event_federation.py +++ b/synapse/storage/event_federation.py @@ -376,33 +376,25 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, @defer.inlineCallbacks def get_missing_events(self, room_id, earliest_events, latest_events, - limit, min_depth): + limit): ids = yield self.runInteraction( "get_missing_events", self._get_missing_events, - room_id, earliest_events, latest_events, limit, min_depth + room_id, earliest_events, latest_events, limit, ) - events = yield self._get_events(ids) - - events = sorted( - [ev for ev in events if ev.depth >= min_depth], - key=lambda e: e.depth, - ) - - defer.returnValue(events[:limit]) + defer.returnValue(events) def _get_missing_events(self, txn, room_id, earliest_events, latest_events, - limit, min_depth): - - earliest_events = set(earliest_events) - front = set(latest_events) - earliest_events + limit): - event_results = set() + seen_events = set(earliest_events) + front = set(latest_events) - seen_events + event_results = [] query = ( "SELECT prev_event_id FROM event_edges " - "WHERE event_id = ? AND is_state = ? " + "WHERE room_id = ? AND event_id = ? AND is_state = ? " "LIMIT ?" ) @@ -411,18 +403,20 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, for event_id in front: txn.execute( query, - (event_id, False, limit - len(event_results)) + (room_id, event_id, False, limit - len(event_results)) ) - for e_id, in txn: - new_front.add(e_id) + new_results = set(t[0] for t in txn) - seen_events - new_front -= earliest_events - new_front -= event_results + new_front |= new_results + seen_events |= new_results + event_results.extend(new_results) front = new_front - event_results |= new_front + # we built the list working backwards from latest_events; we now need to + # reverse it so that the events are approximately chronological. + event_results.reverse() return event_results -- cgit 1.4.1 From 47a9da28caa6a4f27d2df31f043971a2c9c7b555 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 12 Oct 2018 20:43:18 +0100 Subject: Batch process handling state groups --- synapse/storage/events.py | 81 +++++++++------------------------ synapse/storage/state.py | 112 +++++++++++++++++++++++++++++----------------- 2 files changed, 92 insertions(+), 101 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 0fb190530a..e4d0f8b1a9 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -37,6 +37,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.background_updates import BackgroundUpdateStore from synapse.storage.event_federation import EventFederationStore from synapse.storage.events_worker import EventsWorkerStore +from synapse.storage.state import StateGroupWorkerStore from synapse.types import RoomStreamToken, get_domain_from_id from synapse.util.async_helpers import ObservableDeferred from synapse.util.caches.descriptors import cached, cachedInlineCallbacks @@ -203,7 +204,8 @@ def _retry_on_integrity_error(func): # inherits from EventFederationStore so that we can call _update_backward_extremities # and _handle_mult_prev_events (though arguably those could both be moved in here) -class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore): +class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore, + BackgroundUpdateStore): EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts" EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url" @@ -1995,70 +1997,29 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore logger.info("[purge] finding redundant state groups") - # Get all state groups that are only referenced by events that are - # to be deleted. - # This works by first getting state groups that we may want to delete, - # joining against event_to_state_groups to get events that use that - # state group, then left joining against events_to_purge again. Any - # state group where the left join produce *no nulls* are referenced - # only by events that are going to be purged. + # Get all state groups that are referenced by events that are to be + # deleted. We then go and check if they are referenced by other events + # or state groups, and if not we delete them. txn.execute(""" - SELECT state_group FROM - ( - SELECT DISTINCT state_group FROM events_to_purge - INNER JOIN event_to_state_groups USING (event_id) - ) AS sp - INNER JOIN event_to_state_groups USING (state_group) - LEFT JOIN events_to_purge AS ep USING (event_id) - GROUP BY state_group - HAVING SUM(CASE WHEN ep.event_id IS NULL THEN 1 ELSE 0 END) = 0 + SELECT DISTINCT state_group FROM events_to_purge + INNER JOIN event_to_state_groups USING (event_id) """) - state_rows = txn.fetchall() - logger.info("[purge] found %i redundant state groups", len(state_rows)) - - # make a set of the redundant state groups, so that we can look them up - # efficiently - state_groups_to_delete = set([sg for sg, in state_rows]) - - # Now we get all the state groups that rely on these state groups - logger.info("[purge] finding state groups which depend on redundant" - " state groups") - remaining_state_groups = [] - unreferenced_state_groups = 0 - for i in range(0, len(state_rows), 100): - chunk = [sg for sg, in state_rows[i:i + 100]] - # look for state groups whose prev_state_group is one we are about - # to delete - rows = self._simple_select_many_txn( - txn, - table="state_group_edges", - column="prev_state_group", - iterable=chunk, - retcols=["state_group"], - keyvalues={}, - ) - - for row in rows: - sg = row["state_group"] - - if sg in state_groups_to_delete: - # exclude state groups we are about to delete: no point in - # updating them - continue + referenced_state_groups = set(sg for sg, in txn) + logger.info( + "[purge] found %i referenced state groups", + len(referenced_state_groups), + ) - if not self._is_state_group_referenced(txn, sg): - # Let's also delete unreferenced state groups while we're - # here, since otherwise we'd need to de-delta them - state_groups_to_delete.add(sg) - unreferenced_state_groups += 1 - continue + logger.info("[purge] finding state groups that can be deleted") - remaining_state_groups.append(sg) + state_groups_to_delete, remaining_state_groups = self._find_unreferenced_groups( + txn, referenced_state_groups, + ) logger.info( - "[purge] found %i extra unreferenced state groups to delete", - unreferenced_state_groups, + "[purge] found %i state groups to delete", + len(state_groups_to_delete), ) logger.info( @@ -2109,11 +2070,11 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore logger.info("[purge] removing redundant state groups") txn.executemany( "DELETE FROM state_groups_state WHERE state_group = ?", - state_rows + ((sg,) for sg in state_groups_to_delete), ) txn.executemany( "DELETE FROM state_groups WHERE id = ?", - state_rows + ((sg,) for sg in state_groups_to_delete), ) logger.info("[purge] removing events from event_to_state_groups") diff --git a/synapse/storage/state.py b/synapse/storage/state.py index f7cf5c86c9..0f86311ed4 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -1041,55 +1041,85 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): return count - def _is_state_group_referenced(self, txn, state_group): - """Checks if a given state group is referenced, or is safe to delete. + def _find_unreferenced_groups(self, txn, state_groups): + """Used when purging history to figure out which state groups can be + deleted and which need to be de-delta'ed (due to one of its prev groups + being scheduled for deletion). - A state group is referenced if it or any of its descendants are - pointed at by an event. (A descendant is a state_group whose chain of - prev_groups includes the given state_group.) - """ - - # We check this by doing a depth first search to look for any - # descendant referenced by `event_to_state_groups`. - - # State groups we need to check, contains state groups that are - # descendants of `state_group` - state_groups_to_search = [state_group] - - # Set of state groups we've already checked - state_groups_searched = set() - - while state_groups_to_search: - state_group = state_groups_to_search.pop() # Next state group to check + Args: + txn + state_groups (set[int]): Set of state groups referenced by events + that are going to be deleted. - is_referenced = self._simple_select_one_onecol_txn( + Returns: + tuple[set[int], set[int]]: The set of state groups that can be + deleted and the set of state groups that need to be de-delta'ed + """ + # Graph of state group -> previous group + graph = {} + + # Set of events that we have found to be referenced by events + referenced_groups = set() + + # Set of state groups we've already seen + state_groups_seen = set(state_groups) + + # Set of state groups to handle next. + next_to_search = set(state_groups) + while next_to_search: + # We bound size of groups we're looking up at once, to stop the + # SQL query getting too big + if len(next_to_search) < 100: + current_search = next_to_search + next_to_search = set() + else: + lst = list(next_to_search) + current_search = set(lst[:100]) + next_to_search = set(lst[100:]) + + # Check if state groups are referenced + sql = """ + SELECT state_group, count(*) FROM event_to_state_groups + LEFT JOIN events_to_purge AS ep USING (event_id) + WHERE state_group IN (%s) AND ep.event_id IS NULL + GROUP BY state_group + """ % (",".join("?" for _ in current_search),) + txn.execute(sql, list(current_search)) + + referenced = set(sg for sg, cnt in txn if cnt > 0) + referenced_groups |= referenced + + # We don't continue iterating up the state group graphs for state + # groups that are referenced. + current_search -= referenced + + rows = self._simple_select_many_txn( txn, - table="event_to_state_groups", - keyvalues={"state_group": state_group}, - retcol="event_id", - allow_none=True, + table="state_group_edges", + column="prev_state_group", + iterable=current_search, + keyvalues={}, + retcols=("prev_state_group", "state_group",), ) - if is_referenced: - # A descendant is referenced by event_to_state_groups, so - # original state group is referenced. - return True - state_groups_searched.add(state_group) + next_to_search.update(row["state_group"] for row in rows) + # We don't bother re-handling groups we've already seen + next_to_search -= state_groups_seen + state_groups_seen |= next_to_search - # Find all children of current state group and add to search - references = self._simple_select_onecol_txn( - txn, - table="state_group_edges", - keyvalues={"prev_state_group": state_group}, - retcol="state_group", - ) - state_groups_to_search.extend(references) + for row in rows: + # Note: Each state group can have at most one prev group + graph[row["state_group"]] = row["prev_state_group"] + + to_delete = state_groups_seen - referenced_groups - # Lets be paranoid and check for cycles - if state_groups_searched.intersection(references): - raise Exception("State group %s has cyclic dependency", state_group) + to_dedelta = set() + for sg in referenced_groups: + prev_sg = graph.get(sg) + if prev_sg and prev_sg in to_delete: + to_dedelta.add(sg) - return False + return to_delete, to_dedelta class StateStore(StateGroupWorkerStore, BackgroundUpdateStore): -- cgit 1.4.1 From 67f7b9cb50c246226b94027e3c1d4b958ff9f840 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 19 Oct 2018 16:06:59 +0100 Subject: pep8 --- synapse/storage/events.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'synapse/storage') diff --git a/synapse/storage/events.py b/synapse/storage/events.py index af822fb69d..379b5c514f 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -2041,7 +2041,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore INNER JOIN event_to_state_groups USING (event_id) """) - referenced_state_groups = set(sg for sg, in txn) + referenced_state_groups = set(sg for sg, in txn) logger.info( "[purge] found %i referenced state groups", len(referenced_state_groups), -- cgit 1.4.1 From e1728dfcbe585edfb590bce50adeaab341a70db8 Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Sat, 20 Oct 2018 11:16:55 +1100 Subject: Make scripts/ and scripts-dev/ pass pyflakes (and the rest of the codebase on py3) (#4068) --- .travis.yml | 2 +- changelog.d/4068.misc | 1 + scripts-dev/check_auth.py | 36 ++-- scripts-dev/check_event_hash.py | 32 ++-- scripts-dev/check_signature.py | 36 ++-- scripts-dev/convert_server_keys.py | 40 +++-- scripts-dev/definitions.py | 54 +++--- scripts-dev/dump_macaroon.py | 13 +- scripts-dev/federation_client.py | 99 +++++------ scripts-dev/hash_history.py | 62 ++++--- scripts-dev/list_url_patterns.py | 16 +- scripts-dev/tail-synapse.py | 24 +-- scripts/hash_password | 5 +- scripts/move_remote_media_to_new_store.py | 36 ++-- scripts/register_new_matrix_user | 76 ++++---- scripts/synapse_port_db | 272 +++++++++++++---------------- synapse/app/homeserver.py | 2 +- synapse/config/__main__.py | 2 +- synapse/config/_base.py | 121 +++++++------ synapse/storage/_base.py | 4 +- synapse/storage/keys.py | 2 +- synapse/storage/pusher.py | 2 +- synapse/storage/signatures.py | 2 +- synapse/storage/transactions.py | 2 +- synapse/util/caches/stream_change_cache.py | 4 +- synctl | 80 ++++----- tox.ini | 4 +- 27 files changed, 511 insertions(+), 518 deletions(-) create mode 100644 changelog.d/4068.misc (limited to 'synapse/storage') diff --git a/.travis.yml b/.travis.yml index 2077f6af72..7ee1a278db 100644 --- a/.travis.yml +++ b/.travis.yml @@ -14,7 +14,7 @@ matrix: - python: 2.7 env: TOX_ENV=packaging - - python: 2.7 + - python: 3.6 env: TOX_ENV=pep8 - python: 2.7 diff --git a/changelog.d/4068.misc b/changelog.d/4068.misc new file mode 100644 index 0000000000..db6c4ade59 --- /dev/null +++ b/changelog.d/4068.misc @@ -0,0 +1 @@ +Make the Python scripts in the top-level scripts folders meet pep8 and pass flake8. diff --git a/scripts-dev/check_auth.py b/scripts-dev/check_auth.py index 4fa8792a5f..b3d11f49ec 100644 --- a/scripts-dev/check_auth.py +++ b/scripts-dev/check_auth.py @@ -1,21 +1,20 @@ -from synapse.events import FrozenEvent -from synapse.api.auth import Auth - -from mock import Mock +from __future__ import print_function import argparse import itertools import json import sys +from mock import Mock + +from synapse.api.auth import Auth +from synapse.events import FrozenEvent + def check_auth(auth, auth_chain, events): auth_chain.sort(key=lambda e: e.depth) - auth_map = { - e.event_id: e - for e in auth_chain - } + auth_map = {e.event_id: e for e in auth_chain} create_events = {} for e in auth_chain: @@ -25,31 +24,26 @@ def check_auth(auth, auth_chain, events): for e in itertools.chain(auth_chain, events): auth_events_list = [auth_map[i] for i, _ in e.auth_events] - auth_events = { - (e.type, e.state_key): e - for e in auth_events_list - } + auth_events = {(e.type, e.state_key): e for e in auth_events_list} auth_events[("m.room.create", "")] = create_events[e.room_id] try: auth.check(e, auth_events=auth_events) except Exception as ex: - print "Failed:", e.event_id, e.type, e.state_key - print "Auth_events:", auth_events - print ex - print json.dumps(e.get_dict(), sort_keys=True, indent=4) + print("Failed:", e.event_id, e.type, e.state_key) + print("Auth_events:", auth_events) + print(ex) + print(json.dumps(e.get_dict(), sort_keys=True, indent=4)) # raise - print "Success:", e.event_id, e.type, e.state_key + print("Success:", e.event_id, e.type, e.state_key) + if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument( - 'json', - nargs='?', - type=argparse.FileType('r'), - default=sys.stdin, + 'json', nargs='?', type=argparse.FileType('r'), default=sys.stdin ) args = parser.parse_args() diff --git a/scripts-dev/check_event_hash.py b/scripts-dev/check_event_hash.py index 7ccae34d48..8535f99697 100644 --- a/scripts-dev/check_event_hash.py +++ b/scripts-dev/check_event_hash.py @@ -1,10 +1,15 @@ -from synapse.crypto.event_signing import * -from unpaddedbase64 import encode_base64 - import argparse import hashlib -import sys import json +import logging +import sys + +from unpaddedbase64 import encode_base64 + +from synapse.crypto.event_signing import ( + check_event_content_hash, + compute_event_reference_hash, +) class dictobj(dict): @@ -24,27 +29,26 @@ class dictobj(dict): def main(): parser = argparse.ArgumentParser() - parser.add_argument("input_json", nargs="?", type=argparse.FileType('r'), - default=sys.stdin) + parser.add_argument( + "input_json", nargs="?", type=argparse.FileType('r'), default=sys.stdin + ) args = parser.parse_args() logging.basicConfig() event_json = dictobj(json.load(args.input_json)) - algorithms = { - "sha256": hashlib.sha256, - } + algorithms = {"sha256": hashlib.sha256} for alg_name in event_json.hashes: if check_event_content_hash(event_json, algorithms[alg_name]): - print "PASS content hash %s" % (alg_name,) + print("PASS content hash %s" % (alg_name,)) else: - print "FAIL content hash %s" % (alg_name,) + print("FAIL content hash %s" % (alg_name,)) for algorithm in algorithms.values(): name, h_bytes = compute_event_reference_hash(event_json, algorithm) - print "Reference hash %s: %s" % (name, encode_base64(h_bytes)) + print("Reference hash %s: %s" % (name, encode_base64(h_bytes))) -if __name__=="__main__": - main() +if __name__ == "__main__": + main() diff --git a/scripts-dev/check_signature.py b/scripts-dev/check_signature.py index 079577908a..612f17ca7f 100644 --- a/scripts-dev/check_signature.py +++ b/scripts-dev/check_signature.py @@ -1,15 +1,15 @@ -from signedjson.sign import verify_signed_json -from signedjson.key import decode_verify_key_bytes, write_signing_keys -from unpaddedbase64 import decode_base64 - -import urllib2 +import argparse import json +import logging import sys +import urllib2 + import dns.resolver -import pprint -import argparse -import logging +from signedjson.key import decode_verify_key_bytes, write_signing_keys +from signedjson.sign import verify_signed_json +from unpaddedbase64 import decode_base64 + def get_targets(server_name): if ":" in server_name: @@ -23,6 +23,7 @@ def get_targets(server_name): except dns.resolver.NXDOMAIN: yield (server_name, 8448) + def get_server_keys(server_name, target, port): url = "https://%s:%i/_matrix/key/v1" % (target, port) keys = json.load(urllib2.urlopen(url)) @@ -33,12 +34,14 @@ def get_server_keys(server_name, target, port): verify_keys[key_id] = verify_key return verify_keys + def main(): parser = argparse.ArgumentParser() parser.add_argument("signature_name") - parser.add_argument("input_json", nargs="?", type=argparse.FileType('r'), - default=sys.stdin) + parser.add_argument( + "input_json", nargs="?", type=argparse.FileType('r'), default=sys.stdin + ) args = parser.parse_args() logging.basicConfig() @@ -48,24 +51,23 @@ def main(): for target, port in get_targets(server_name): try: keys = get_server_keys(server_name, target, port) - print "Using keys from https://%s:%s/_matrix/key/v1" % (target, port) + print("Using keys from https://%s:%s/_matrix/key/v1" % (target, port)) write_signing_keys(sys.stdout, keys.values()) break - except: + except Exception: logging.exception("Error talking to %s:%s", target, port) json_to_check = json.load(args.input_json) - print "Checking JSON:" + print("Checking JSON:") for key_id in json_to_check["signatures"][args.signature_name]: try: key = keys[key_id] verify_signed_json(json_to_check, args.signature_name, key) - print "PASS %s" % (key_id,) - except: + print("PASS %s" % (key_id,)) + except Exception: logging.exception("Check for key %s failed" % (key_id,)) - print "FAIL %s" % (key_id,) + print("FAIL %s" % (key_id,)) if __name__ == '__main__': main() - diff --git a/scripts-dev/convert_server_keys.py b/scripts-dev/convert_server_keys.py index 151551f22c..dde8596697 100644 --- a/scripts-dev/convert_server_keys.py +++ b/scripts-dev/convert_server_keys.py @@ -1,13 +1,21 @@ -import psycopg2 -import yaml -import sys +import hashlib import json +import sys import time -import hashlib -from unpaddedbase64 import encode_base64 + +import six + +import psycopg2 +import yaml +from canonicaljson import encode_canonical_json from signedjson.key import read_signing_keys from signedjson.sign import sign_json -from canonicaljson import encode_canonical_json +from unpaddedbase64 import encode_base64 + +if six.PY2: + db_type = six.moves.builtins.buffer +else: + db_type = memoryview def select_v1_keys(connection): @@ -39,7 +47,9 @@ def select_v2_json(connection): cursor.close() results = {} for server_name, key_id, key_json in rows: - results.setdefault(server_name, {})[key_id] = json.loads(str(key_json).decode("utf-8")) + results.setdefault(server_name, {})[key_id] = json.loads( + str(key_json).decode("utf-8") + ) return results @@ -47,10 +57,7 @@ def convert_v1_to_v2(server_name, valid_until, keys, certificate): return { "old_verify_keys": {}, "server_name": server_name, - "verify_keys": { - key_id: {"key": key} - for key_id, key in keys.items() - }, + "verify_keys": {key_id: {"key": key} for key_id, key in keys.items()}, "valid_until_ts": valid_until, "tls_fingerprints": [fingerprint(certificate)], } @@ -65,7 +72,7 @@ def rows_v2(server, json): valid_until = json["valid_until_ts"] key_json = encode_canonical_json(json) for key_id in json["verify_keys"]: - yield (server, key_id, "-", valid_until, valid_until, buffer(key_json)) + yield (server, key_id, "-", valid_until, valid_until, db_type(key_json)) def main(): @@ -87,7 +94,7 @@ def main(): result = {} for server in keys: - if not server in json: + if server not in json: v2_json = convert_v1_to_v2( server, valid_until, keys[server], certificates[server] ) @@ -96,10 +103,7 @@ def main(): yaml.safe_dump(result, sys.stdout, default_flow_style=False) - rows = list( - row for server, json in result.items() - for row in rows_v2(server, json) - ) + rows = list(row for server, json in result.items() for row in rows_v2(server, json)) cursor = connection.cursor() cursor.executemany( @@ -107,7 +111,7 @@ def main(): " server_name, key_id, from_server," " ts_added_ms, ts_valid_until_ms, key_json" ") VALUES (%s, %s, %s, %s, %s, %s)", - rows + rows, ) connection.commit() diff --git a/scripts-dev/definitions.py b/scripts-dev/definitions.py index 47dac7772d..1deb0fe2b7 100755 --- a/scripts-dev/definitions.py +++ b/scripts-dev/definitions.py @@ -1,8 +1,16 @@ #! /usr/bin/python +from __future__ import print_function + +import argparse import ast +import os +import re +import sys + import yaml + class DefinitionVisitor(ast.NodeVisitor): def __init__(self): super(DefinitionVisitor, self).__init__() @@ -42,15 +50,18 @@ def non_empty(defs): functions = {name: non_empty(f) for name, f in defs['def'].items()} classes = {name: non_empty(f) for name, f in defs['class'].items()} result = {} - if functions: result['def'] = functions - if classes: result['class'] = classes + if functions: + result['def'] = functions + if classes: + result['class'] = classes names = defs['names'] uses = [] for name in names.get('Load', ()): if name not in names.get('Param', ()) and name not in names.get('Store', ()): uses.append(name) uses.extend(defs['attrs']) - if uses: result['uses'] = uses + if uses: + result['uses'] = uses result['names'] = names result['attrs'] = defs['attrs'] return result @@ -95,7 +106,6 @@ def used_names(prefix, item, defs, names): if __name__ == '__main__': - import sys, os, argparse, re parser = argparse.ArgumentParser(description='Find definitions.') parser.add_argument( @@ -105,24 +115,28 @@ if __name__ == '__main__': "--ignore", action="append", metavar="REGEXP", help="Ignore a pattern" ) parser.add_argument( - "--pattern", action="append", metavar="REGEXP", - help="Search for a pattern" + "--pattern", action="append", metavar="REGEXP", help="Search for a pattern" ) parser.add_argument( - "directories", nargs='+', metavar="DIR", - help="Directories to search for definitions" + "directories", + nargs='+', + metavar="DIR", + help="Directories to search for definitions", ) parser.add_argument( - "--referrers", default=0, type=int, - help="Include referrers up to the given depth" + "--referrers", + default=0, + type=int, + help="Include referrers up to the given depth", ) parser.add_argument( - "--referred", default=0, type=int, - help="Include referred down to the given depth" + "--referred", + default=0, + type=int, + help="Include referred down to the given depth", ) parser.add_argument( - "--format", default="yaml", - help="Output format, one of 'yaml' or 'dot'" + "--format", default="yaml", help="Output format, one of 'yaml' or 'dot'" ) args = parser.parse_args() @@ -162,7 +176,7 @@ if __name__ == '__main__': for used_by in entry.get("used", ()): referrers.add(used_by) for name, definition in names.items(): - if not name in referrers: + if name not in referrers: continue if ignore and any(pattern.match(name) for pattern in ignore): continue @@ -176,7 +190,7 @@ if __name__ == '__main__': for uses in entry.get("uses", ()): referred.add(uses) for name, definition in names.items(): - if not name in referred: + if name not in referred: continue if ignore and any(pattern.match(name) for pattern in ignore): continue @@ -185,12 +199,12 @@ if __name__ == '__main__': if args.format == 'yaml': yaml.dump(result, sys.stdout, default_flow_style=False) elif args.format == 'dot': - print "digraph {" + print("digraph {") for name, entry in result.items(): - print name + print(name) for used_by in entry.get("used", ()): if used_by in result: - print used_by, "->", name - print "}" + print(used_by, "->", name) + print("}") else: raise ValueError("Unknown format %r" % (args.format)) diff --git a/scripts-dev/dump_macaroon.py b/scripts-dev/dump_macaroon.py index fcc5568835..22b30fa78e 100755 --- a/scripts-dev/dump_macaroon.py +++ b/scripts-dev/dump_macaroon.py @@ -1,8 +1,11 @@ #!/usr/bin/env python2 -import pymacaroons +from __future__ import print_function + import sys +import pymacaroons + if len(sys.argv) == 1: sys.stderr.write("usage: %s macaroon [key]\n" % (sys.argv[0],)) sys.exit(1) @@ -11,14 +14,14 @@ macaroon_string = sys.argv[1] key = sys.argv[2] if len(sys.argv) > 2 else None macaroon = pymacaroons.Macaroon.deserialize(macaroon_string) -print macaroon.inspect() +print(macaroon.inspect()) -print "" +print("") verifier = pymacaroons.Verifier() verifier.satisfy_general(lambda c: True) try: verifier.verify(macaroon, key) - print "Signature is correct" + print("Signature is correct") except Exception as e: - print str(e) + print(str(e)) diff --git a/scripts-dev/federation_client.py b/scripts-dev/federation_client.py index d2acc7654d..2566ce7cef 100755 --- a/scripts-dev/federation_client.py +++ b/scripts-dev/federation_client.py @@ -18,21 +18,21 @@ from __future__ import print_function import argparse +import base64 +import json +import sys from urlparse import urlparse, urlunparse import nacl.signing -import json -import base64 import requests -import sys - -from requests.adapters import HTTPAdapter import srvlookup import yaml +from requests.adapters import HTTPAdapter # uncomment the following to enable debug logging of http requests -#from httplib import HTTPConnection -#HTTPConnection.debuglevel = 1 +# from httplib import HTTPConnection +# HTTPConnection.debuglevel = 1 + def encode_base64(input_bytes): """Encode bytes as a base64 string without any padding.""" @@ -58,15 +58,15 @@ def decode_base64(input_string): def encode_canonical_json(value): return json.dumps( - value, - # Encode code-points outside of ASCII as UTF-8 rather than \u escapes - ensure_ascii=False, - # Remove unecessary white space. - separators=(',',':'), - # Sort the keys of dictionaries. - sort_keys=True, - # Encode the resulting unicode as UTF-8 bytes. - ).encode("UTF-8") + value, + # Encode code-points outside of ASCII as UTF-8 rather than \u escapes + ensure_ascii=False, + # Remove unecessary white space. + separators=(',', ':'), + # Sort the keys of dictionaries. + sort_keys=True, + # Encode the resulting unicode as UTF-8 bytes. + ).encode("UTF-8") def sign_json(json_object, signing_key, signing_name): @@ -88,6 +88,7 @@ def sign_json(json_object, signing_key, signing_name): NACL_ED25519 = "ed25519" + def decode_signing_key_base64(algorithm, version, key_base64): """Decode a base64 encoded signing key Args: @@ -143,14 +144,12 @@ def request_json(method, origin_name, origin_key, destination, path, content): authorization_headers = [] for key, sig in signed_json["signatures"][origin_name].items(): - header = "X-Matrix origin=%s,key=\"%s\",sig=\"%s\"" % ( - origin_name, key, sig, - ) + header = "X-Matrix origin=%s,key=\"%s\",sig=\"%s\"" % (origin_name, key, sig) authorization_headers.append(bytes(header)) - print ("Authorization: %s" % header, file=sys.stderr) + print("Authorization: %s" % header, file=sys.stderr) dest = "matrix://%s%s" % (destination, path) - print ("Requesting %s" % dest, file=sys.stderr) + print("Requesting %s" % dest, file=sys.stderr) s = requests.Session() s.mount("matrix://", MatrixConnectionAdapter()) @@ -158,10 +157,7 @@ def request_json(method, origin_name, origin_key, destination, path, content): result = s.request( method=method, url=dest, - headers={ - "Host": destination, - "Authorization": authorization_headers[0] - }, + headers={"Host": destination, "Authorization": authorization_headers[0]}, verify=False, data=content, ) @@ -171,50 +167,50 @@ def request_json(method, origin_name, origin_key, destination, path, content): def main(): parser = argparse.ArgumentParser( - description= - "Signs and sends a federation request to a matrix homeserver", + description="Signs and sends a federation request to a matrix homeserver" ) parser.add_argument( - "-N", "--server-name", + "-N", + "--server-name", help="Name to give as the local homeserver. If unspecified, will be " - "read from the config file.", + "read from the config file.", ) parser.add_argument( - "-k", "--signing-key-path", + "-k", + "--signing-key-path", help="Path to the file containing the private ed25519 key to sign the " - "request with.", + "request with.", ) parser.add_argument( - "-c", "--config", + "-c", + "--config", default="homeserver.yaml", help="Path to server config file. Ignored if --server-name and " - "--signing-key-path are both given.", + "--signing-key-path are both given.", ) parser.add_argument( - "-d", "--destination", + "-d", + "--destination", default="matrix.org", help="name of the remote homeserver. We will do SRV lookups and " - "connect appropriately.", + "connect appropriately.", ) parser.add_argument( - "-X", "--method", + "-X", + "--method", help="HTTP method to use for the request. Defaults to GET if --data is" - "unspecified, POST if it is." + "unspecified, POST if it is.", ) - parser.add_argument( - "--body", - help="Data to send as the body of the HTTP request" - ) + parser.add_argument("--body", help="Data to send as the body of the HTTP request") parser.add_argument( - "path", - help="request path. We will add '/_matrix/federation/v1/' to this." + "path", help="request path. We will add '/_matrix/federation/v1/' to this." ) args = parser.parse_args() @@ -227,13 +223,15 @@ def main(): result = request_json( args.method, - args.server_name, key, args.destination, + args.server_name, + key, + args.destination, "/_matrix/federation/v1/" + args.path, content=args.body, ) json.dump(result, sys.stdout) - print ("") + print("") def read_args_from_config(args): @@ -253,7 +251,7 @@ class MatrixConnectionAdapter(HTTPAdapter): return s, 8448 if ":" in s: - out = s.rsplit(":",1) + out = s.rsplit(":", 1) try: port = int(out[1]) except ValueError: @@ -263,7 +261,7 @@ class MatrixConnectionAdapter(HTTPAdapter): try: srv = srvlookup.lookup("matrix", "tcp", s)[0] return srv.host, srv.port - except: + except Exception: return s, 8448 def get_connection(self, url, proxies=None): @@ -272,10 +270,9 @@ class MatrixConnectionAdapter(HTTPAdapter): (host, port) = self.lookup(parsed.netloc) netloc = "%s:%d" % (host, port) print("Connecting to %s" % (netloc,), file=sys.stderr) - url = urlunparse(( - "https", netloc, parsed.path, parsed.params, parsed.query, - parsed.fragment, - )) + url = urlunparse( + ("https", netloc, parsed.path, parsed.params, parsed.query, parsed.fragment) + ) return super(MatrixConnectionAdapter, self).get_connection(url, proxies) diff --git a/scripts-dev/hash_history.py b/scripts-dev/hash_history.py index 616d6a10e7..514d80fa60 100644 --- a/scripts-dev/hash_history.py +++ b/scripts-dev/hash_history.py @@ -1,23 +1,31 @@ -from synapse.storage.pdu import PduStore -from synapse.storage.signatures import SignatureStore -from synapse.storage._base import SQLBaseStore -from synapse.federation.units import Pdu -from synapse.crypto.event_signing import ( - add_event_pdu_content_hash, compute_pdu_event_reference_hash -) -from synapse.api.events.utils import prune_pdu -from unpaddedbase64 import encode_base64, decode_base64 -from canonicaljson import encode_canonical_json +from __future__ import print_function + import sqlite3 import sys +from unpaddedbase64 import decode_base64, encode_base64 + +from synapse.crypto.event_signing import ( + add_event_pdu_content_hash, + compute_pdu_event_reference_hash, +) +from synapse.federation.units import Pdu +from synapse.storage._base import SQLBaseStore +from synapse.storage.pdu import PduStore +from synapse.storage.signatures import SignatureStore + + class Store(object): _get_pdu_tuples = PduStore.__dict__["_get_pdu_tuples"] _get_pdu_content_hashes_txn = SignatureStore.__dict__["_get_pdu_content_hashes_txn"] _get_prev_pdu_hashes_txn = SignatureStore.__dict__["_get_prev_pdu_hashes_txn"] - _get_pdu_origin_signatures_txn = SignatureStore.__dict__["_get_pdu_origin_signatures_txn"] + _get_pdu_origin_signatures_txn = SignatureStore.__dict__[ + "_get_pdu_origin_signatures_txn" + ] _store_pdu_content_hash_txn = SignatureStore.__dict__["_store_pdu_content_hash_txn"] - _store_pdu_reference_hash_txn = SignatureStore.__dict__["_store_pdu_reference_hash_txn"] + _store_pdu_reference_hash_txn = SignatureStore.__dict__[ + "_store_pdu_reference_hash_txn" + ] _store_prev_pdu_hash_txn = SignatureStore.__dict__["_store_prev_pdu_hash_txn"] _simple_insert_txn = SQLBaseStore.__dict__["_simple_insert_txn"] @@ -26,9 +34,7 @@ store = Store() def select_pdus(cursor): - cursor.execute( - "SELECT pdu_id, origin FROM pdus ORDER BY depth ASC" - ) + cursor.execute("SELECT pdu_id, origin FROM pdus ORDER BY depth ASC") ids = cursor.fetchall() @@ -41,23 +47,30 @@ def select_pdus(cursor): for pdu in pdus: try: if pdu.prev_pdus: - print "PROCESS", pdu.pdu_id, pdu.origin, pdu.prev_pdus + print("PROCESS", pdu.pdu_id, pdu.origin, pdu.prev_pdus) for pdu_id, origin, hashes in pdu.prev_pdus: ref_alg, ref_hsh = reference_hashes[(pdu_id, origin)] hashes[ref_alg] = encode_base64(ref_hsh) - store._store_prev_pdu_hash_txn(cursor, pdu.pdu_id, pdu.origin, pdu_id, origin, ref_alg, ref_hsh) - print "SUCCESS", pdu.pdu_id, pdu.origin, pdu.prev_pdus + store._store_prev_pdu_hash_txn( + cursor, pdu.pdu_id, pdu.origin, pdu_id, origin, ref_alg, ref_hsh + ) + print("SUCCESS", pdu.pdu_id, pdu.origin, pdu.prev_pdus) pdu = add_event_pdu_content_hash(pdu) ref_alg, ref_hsh = compute_pdu_event_reference_hash(pdu) reference_hashes[(pdu.pdu_id, pdu.origin)] = (ref_alg, ref_hsh) - store._store_pdu_reference_hash_txn(cursor, pdu.pdu_id, pdu.origin, ref_alg, ref_hsh) + store._store_pdu_reference_hash_txn( + cursor, pdu.pdu_id, pdu.origin, ref_alg, ref_hsh + ) for alg, hsh_base64 in pdu.hashes.items(): - print alg, hsh_base64 - store._store_pdu_content_hash_txn(cursor, pdu.pdu_id, pdu.origin, alg, decode_base64(hsh_base64)) + print(alg, hsh_base64) + store._store_pdu_content_hash_txn( + cursor, pdu.pdu_id, pdu.origin, alg, decode_base64(hsh_base64) + ) + + except Exception: + print("FAILED_", pdu.pdu_id, pdu.origin, pdu.prev_pdus) - except: - print "FAILED_", pdu.pdu_id, pdu.origin, pdu.prev_pdus def main(): conn = sqlite3.connect(sys.argv[1]) @@ -65,5 +78,6 @@ def main(): select_pdus(cursor) conn.commit() -if __name__=='__main__': + +if __name__ == '__main__': main() diff --git a/scripts-dev/list_url_patterns.py b/scripts-dev/list_url_patterns.py index 58d40c4ff4..da027be26e 100755 --- a/scripts-dev/list_url_patterns.py +++ b/scripts-dev/list_url_patterns.py @@ -1,18 +1,17 @@ #! /usr/bin/python -import ast import argparse +import ast import os import sys + import yaml PATTERNS_V1 = [] PATTERNS_V2 = [] -RESULT = { - "v1": PATTERNS_V1, - "v2": PATTERNS_V2, -} +RESULT = {"v1": PATTERNS_V1, "v2": PATTERNS_V2} + class CallVisitor(ast.NodeVisitor): def visit_Call(self, node): @@ -21,7 +20,6 @@ class CallVisitor(ast.NodeVisitor): else: return - if name == "client_path_patterns": PATTERNS_V1.append(node.args[0].s) elif name == "client_v2_patterns": @@ -42,8 +40,10 @@ def find_patterns_in_file(filepath): parser = argparse.ArgumentParser(description='Find url patterns.') parser.add_argument( - "directories", nargs='+', metavar="DIR", - help="Directories to search for definitions" + "directories", + nargs='+', + metavar="DIR", + help="Directories to search for definitions", ) args = parser.parse_args() diff --git a/scripts-dev/tail-synapse.py b/scripts-dev/tail-synapse.py index 18be711e92..1a36b94038 100644 --- a/scripts-dev/tail-synapse.py +++ b/scripts-dev/tail-synapse.py @@ -1,8 +1,9 @@ -import requests import collections +import json import sys import time -import json + +import requests Entry = collections.namedtuple("Entry", "name position rows") @@ -30,11 +31,11 @@ def parse_response(content): def replicate(server, streams): - return parse_response(requests.get( - server + "/_synapse/replication", - verify=False, - params=streams - ).content) + return parse_response( + requests.get( + server + "/_synapse/replication", verify=False, params=streams + ).content + ) def main(): @@ -45,7 +46,7 @@ def main(): try: streams = { row.name: row.position - for row in replicate(server, {"streams":"-1"})["streams"].rows + for row in replicate(server, {"streams": "-1"})["streams"].rows } except requests.exceptions.ConnectionError as e: time.sleep(0.1) @@ -53,8 +54,8 @@ def main(): while True: try: results = replicate(server, streams) - except: - sys.stdout.write("connection_lost("+ repr(streams) + ")\n") + except Exception: + sys.stdout.write("connection_lost(" + repr(streams) + ")\n") break for update in results.values(): for row in update.rows: @@ -62,6 +63,5 @@ def main(): streams[update.name] = update.position - -if __name__=='__main__': +if __name__ == '__main__': main() diff --git a/scripts/hash_password b/scripts/hash_password index 215ab25cfe..a62bb5aa83 100755 --- a/scripts/hash_password +++ b/scripts/hash_password @@ -1,12 +1,10 @@ #!/usr/bin/env python import argparse - +import getpass import sys import bcrypt -import getpass - import yaml bcrypt_rounds=12 @@ -52,4 +50,3 @@ if __name__ == "__main__": password = prompt_for_pass() print bcrypt.hashpw(password + password_pepper, bcrypt.gensalt(bcrypt_rounds)) - diff --git a/scripts/move_remote_media_to_new_store.py b/scripts/move_remote_media_to_new_store.py index 7914ead889..e630936f78 100755 --- a/scripts/move_remote_media_to_new_store.py +++ b/scripts/move_remote_media_to_new_store.py @@ -36,12 +36,9 @@ from __future__ import print_function import argparse import logging - -import sys - import os - import shutil +import sys from synapse.rest.media.v1.filepath import MediaFilePaths @@ -77,24 +74,23 @@ def move_media(origin_server, file_id, src_paths, dest_paths): if not os.path.exists(original_file): logger.warn( "Original for %s/%s (%s) does not exist", - origin_server, file_id, original_file, + origin_server, + file_id, + original_file, ) else: mkdir_and_move( - original_file, - dest_paths.remote_media_filepath(origin_server, file_id), + original_file, dest_paths.remote_media_filepath(origin_server, file_id) ) # now look for thumbnails - original_thumb_dir = src_paths.remote_media_thumbnail_dir( - origin_server, file_id, - ) + original_thumb_dir = src_paths.remote_media_thumbnail_dir(origin_server, file_id) if not os.path.exists(original_thumb_dir): return mkdir_and_move( original_thumb_dir, - dest_paths.remote_media_thumbnail_dir(origin_server, file_id) + dest_paths.remote_media_thumbnail_dir(origin_server, file_id), ) @@ -109,24 +105,16 @@ def mkdir_and_move(original_file, dest_file): if __name__ == "__main__": parser = argparse.ArgumentParser( - description=__doc__, - formatter_class = argparse.RawDescriptionHelpFormatter, - ) - parser.add_argument( - "-v", action='store_true', help='enable debug logging') - parser.add_argument( - "src_repo", - help="Path to source content repo", - ) - parser.add_argument( - "dest_repo", - help="Path to source content repo", + description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter ) + parser.add_argument("-v", action='store_true', help='enable debug logging') + parser.add_argument("src_repo", help="Path to source content repo") + parser.add_argument("dest_repo", help="Path to source content repo") args = parser.parse_args() logging_config = { "level": logging.DEBUG if args.v else logging.INFO, - "format": "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s" + "format": "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s", } logging.basicConfig(**logging_config) diff --git a/scripts/register_new_matrix_user b/scripts/register_new_matrix_user index 91bdb3a25b..89143c5d59 100755 --- a/scripts/register_new_matrix_user +++ b/scripts/register_new_matrix_user @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function import argparse import getpass @@ -22,19 +23,23 @@ import hmac import json import sys import urllib2 + +from six import input + import yaml def request_registration(user, password, server_location, shared_secret, admin=False): req = urllib2.Request( "%s/_matrix/client/r0/admin/register" % (server_location,), - headers={'Content-Type': 'application/json'} + headers={'Content-Type': 'application/json'}, ) try: if sys.version_info[:3] >= (2, 7, 9): # As of version 2.7.9, urllib2 now checks SSL certs import ssl + f = urllib2.urlopen(req, context=ssl.SSLContext(ssl.PROTOCOL_SSLv23)) else: f = urllib2.urlopen(req) @@ -42,18 +47,15 @@ def request_registration(user, password, server_location, shared_secret, admin=F f.close() nonce = json.loads(body)["nonce"] except urllib2.HTTPError as e: - print "ERROR! Received %d %s" % (e.code, e.reason,) + print("ERROR! Received %d %s" % (e.code, e.reason)) if 400 <= e.code < 500: if e.info().type == "application/json": resp = json.load(e) if "error" in resp: - print resp["error"] + print(resp["error"]) sys.exit(1) - mac = hmac.new( - key=shared_secret, - digestmod=hashlib.sha1, - ) + mac = hmac.new(key=shared_secret, digestmod=hashlib.sha1) mac.update(nonce) mac.update("\x00") @@ -75,30 +77,31 @@ def request_registration(user, password, server_location, shared_secret, admin=F server_location = server_location.rstrip("/") - print "Sending registration request..." + print("Sending registration request...") req = urllib2.Request( "%s/_matrix/client/r0/admin/register" % (server_location,), data=json.dumps(data), - headers={'Content-Type': 'application/json'} + headers={'Content-Type': 'application/json'}, ) try: if sys.version_info[:3] >= (2, 7, 9): # As of version 2.7.9, urllib2 now checks SSL certs import ssl + f = urllib2.urlopen(req, context=ssl.SSLContext(ssl.PROTOCOL_SSLv23)) else: f = urllib2.urlopen(req) f.read() f.close() - print "Success." + print("Success.") except urllib2.HTTPError as e: - print "ERROR! Received %d %s" % (e.code, e.reason,) + print("ERROR! Received %d %s" % (e.code, e.reason)) if 400 <= e.code < 500: if e.info().type == "application/json": resp = json.load(e) if "error" in resp: - print resp["error"] + print(resp["error"]) sys.exit(1) @@ -106,35 +109,35 @@ def register_new_user(user, password, server_location, shared_secret, admin): if not user: try: default_user = getpass.getuser() - except: + except Exception: default_user = None if default_user: - user = raw_input("New user localpart [%s]: " % (default_user,)) + user = input("New user localpart [%s]: " % (default_user,)) if not user: user = default_user else: - user = raw_input("New user localpart: ") + user = input("New user localpart: ") if not user: - print "Invalid user name" + print("Invalid user name") sys.exit(1) if not password: password = getpass.getpass("Password: ") if not password: - print "Password cannot be blank." + print("Password cannot be blank.") sys.exit(1) confirm_password = getpass.getpass("Confirm password: ") if password != confirm_password: - print "Passwords do not match" + print("Passwords do not match") sys.exit(1) if admin is None: - admin = raw_input("Make admin [no]: ") + admin = input("Make admin [no]: ") if admin in ("y", "yes", "true"): admin = True else: @@ -146,42 +149,51 @@ def register_new_user(user, password, server_location, shared_secret, admin): if __name__ == "__main__": parser = argparse.ArgumentParser( description="Used to register new users with a given home server when" - " registration has been disabled. The home server must be" - " configured with the 'registration_shared_secret' option" - " set.", + " registration has been disabled. The home server must be" + " configured with the 'registration_shared_secret' option" + " set." ) parser.add_argument( - "-u", "--user", + "-u", + "--user", default=None, help="Local part of the new user. Will prompt if omitted.", ) parser.add_argument( - "-p", "--password", + "-p", + "--password", default=None, help="New password for user. Will prompt if omitted.", ) admin_group = parser.add_mutually_exclusive_group() admin_group.add_argument( - "-a", "--admin", + "-a", + "--admin", action="store_true", - help="Register new user as an admin. Will prompt if --no-admin is not set either.", + help=( + "Register new user as an admin. " + "Will prompt if --no-admin is not set either." + ), ) admin_group.add_argument( "--no-admin", action="store_true", - help="Register new user as a regular user. Will prompt if --admin is not set either.", + help=( + "Register new user as a regular user. " + "Will prompt if --admin is not set either." + ), ) group = parser.add_mutually_exclusive_group(required=True) group.add_argument( - "-c", "--config", + "-c", + "--config", type=argparse.FileType('r'), help="Path to server config file. Used to read in shared secret.", ) group.add_argument( - "-k", "--shared-secret", - help="Shared secret as defined in server config file.", + "-k", "--shared-secret", help="Shared secret as defined in server config file." ) parser.add_argument( @@ -189,7 +201,7 @@ if __name__ == "__main__": default="https://localhost:8448", nargs='?', help="URL to use to talk to the home server. Defaults to " - " 'https://localhost:8448'.", + " 'https://localhost:8448'.", ) args = parser.parse_args() @@ -198,7 +210,7 @@ if __name__ == "__main__": config = yaml.safe_load(args.config) secret = config.get("registration_shared_secret", None) if not secret: - print "No 'registration_shared_secret' defined in config." + print("No 'registration_shared_secret' defined in config.") sys.exit(1) else: secret = args.shared_secret diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db index b9b828c154..2f6e69e552 100755 --- a/scripts/synapse_port_db +++ b/scripts/synapse_port_db @@ -15,23 +15,23 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer, reactor -from twisted.enterprise import adbapi - -from synapse.storage._base import LoggingTransaction, SQLBaseStore -from synapse.storage.engines import create_engine -from synapse.storage.prepare_database import prepare_database - import argparse import curses import logging import sys import time import traceback -import yaml from six import string_types +import yaml + +from twisted.enterprise import adbapi +from twisted.internet import defer, reactor + +from synapse.storage._base import LoggingTransaction, SQLBaseStore +from synapse.storage.engines import create_engine +from synapse.storage.prepare_database import prepare_database logger = logging.getLogger("synapse_port_db") @@ -105,6 +105,7 @@ class Store(object): *All* database interactions should go through this object. """ + def __init__(self, db_pool, engine): self.db_pool = db_pool self.database_engine = engine @@ -135,7 +136,8 @@ class Store(object): txn = conn.cursor() return func( LoggingTransaction(txn, desc, self.database_engine, [], []), - *args, **kwargs + *args, + **kwargs ) except self.database_engine.module.DatabaseError as e: if self.database_engine.is_deadlock(e): @@ -158,22 +160,20 @@ class Store(object): def r(txn): txn.execute(sql, args) return txn.fetchall() + return self.runInteraction("execute_sql", r) def insert_many_txn(self, txn, table, headers, rows): sql = "INSERT INTO %s (%s) VALUES (%s)" % ( table, ", ".join(k for k in headers), - ", ".join("%s" for _ in headers) + ", ".join("%s" for _ in headers), ) try: txn.executemany(sql, rows) - except: - logger.exception( - "Failed to insert: %s", - table, - ) + except Exception: + logger.exception("Failed to insert: %s", table) raise @@ -206,7 +206,7 @@ class Porter(object): "table_name": table, "forward_rowid": 1, "backward_rowid": 0, - } + }, ) forward_chunk = 1 @@ -221,10 +221,10 @@ class Porter(object): table, forward_chunk, backward_chunk ) else: + def delete_all(txn): txn.execute( - "DELETE FROM port_from_sqlite3 WHERE table_name = %s", - (table,) + "DELETE FROM port_from_sqlite3 WHERE table_name = %s", (table,) ) txn.execute("TRUNCATE %s CASCADE" % (table,)) @@ -232,11 +232,7 @@ class Porter(object): yield self.postgres_store._simple_insert( table="port_from_sqlite3", - values={ - "table_name": table, - "forward_rowid": 1, - "backward_rowid": 0, - } + values={"table_name": table, "forward_rowid": 1, "backward_rowid": 0}, ) forward_chunk = 1 @@ -251,12 +247,16 @@ class Porter(object): ) @defer.inlineCallbacks - def handle_table(self, table, postgres_size, table_size, forward_chunk, - backward_chunk): + def handle_table( + self, table, postgres_size, table_size, forward_chunk, backward_chunk + ): logger.info( "Table %s: %i/%i (rows %i-%i) already ported", - table, postgres_size, table_size, - backward_chunk+1, forward_chunk-1, + table, + postgres_size, + table_size, + backward_chunk + 1, + forward_chunk - 1, ) if not table_size: @@ -271,7 +271,9 @@ class Porter(object): return if table in ( - "user_directory", "user_directory_search", "users_who_share_rooms", + "user_directory", + "user_directory_search", + "users_who_share_rooms", "users_in_pubic_room", ): # We don't port these tables, as they're a faff and we can regenreate @@ -283,37 +285,35 @@ class Porter(object): # We need to make sure there is a single row, `(X, null), as that is # what synapse expects to be there. yield self.postgres_store._simple_insert( - table=table, - values={"stream_id": None}, + table=table, values={"stream_id": None} ) self.progress.update(table, table_size) # Mark table as done return forward_select = ( - "SELECT rowid, * FROM %s WHERE rowid >= ? ORDER BY rowid LIMIT ?" - % (table,) + "SELECT rowid, * FROM %s WHERE rowid >= ? ORDER BY rowid LIMIT ?" % (table,) ) backward_select = ( - "SELECT rowid, * FROM %s WHERE rowid <= ? ORDER BY rowid LIMIT ?" - % (table,) + "SELECT rowid, * FROM %s WHERE rowid <= ? ORDER BY rowid LIMIT ?" % (table,) ) do_forward = [True] do_backward = [True] while True: + def r(txn): forward_rows = [] backward_rows = [] if do_forward[0]: - txn.execute(forward_select, (forward_chunk, self.batch_size,)) + txn.execute(forward_select, (forward_chunk, self.batch_size)) forward_rows = txn.fetchall() if not forward_rows: do_forward[0] = False if do_backward[0]: - txn.execute(backward_select, (backward_chunk, self.batch_size,)) + txn.execute(backward_select, (backward_chunk, self.batch_size)) backward_rows = txn.fetchall() if not backward_rows: do_backward[0] = False @@ -325,9 +325,7 @@ class Porter(object): return headers, forward_rows, backward_rows - headers, frows, brows = yield self.sqlite_store.runInteraction( - "select", r - ) + headers, frows, brows = yield self.sqlite_store.runInteraction("select", r) if frows or brows: if frows: @@ -339,9 +337,7 @@ class Porter(object): rows = self._convert_rows(table, headers, rows) def insert(txn): - self.postgres_store.insert_many_txn( - txn, table, headers[1:], rows - ) + self.postgres_store.insert_many_txn(txn, table, headers[1:], rows) self.postgres_store._simple_update_one_txn( txn, @@ -362,8 +358,9 @@ class Porter(object): return @defer.inlineCallbacks - def handle_search_table(self, postgres_size, table_size, forward_chunk, - backward_chunk): + def handle_search_table( + self, postgres_size, table_size, forward_chunk, backward_chunk + ): select = ( "SELECT es.rowid, es.*, e.origin_server_ts, e.stream_ordering" " FROM event_search as es" @@ -373,8 +370,9 @@ class Porter(object): ) while True: + def r(txn): - txn.execute(select, (forward_chunk, self.batch_size,)) + txn.execute(select, (forward_chunk, self.batch_size)) rows = txn.fetchall() headers = [column[0] for column in txn.description] @@ -402,18 +400,21 @@ class Porter(object): else: rows_dict.append(d) - txn.executemany(sql, [ - ( - row["event_id"], - row["room_id"], - row["key"], - row["sender"], - row["value"], - row["origin_server_ts"], - row["stream_ordering"], - ) - for row in rows_dict - ]) + txn.executemany( + sql, + [ + ( + row["event_id"], + row["room_id"], + row["key"], + row["sender"], + row["value"], + row["origin_server_ts"], + row["stream_ordering"], + ) + for row in rows_dict + ], + ) self.postgres_store._simple_update_one_txn( txn, @@ -437,7 +438,8 @@ class Porter(object): def setup_db(self, db_config, database_engine): db_conn = database_engine.module.connect( **{ - k: v for k, v in db_config.get("args", {}).items() + k: v + for k, v in db_config.get("args", {}).items() if not k.startswith("cp_") } ) @@ -450,13 +452,11 @@ class Porter(object): def run(self): try: sqlite_db_pool = adbapi.ConnectionPool( - self.sqlite_config["name"], - **self.sqlite_config["args"] + self.sqlite_config["name"], **self.sqlite_config["args"] ) postgres_db_pool = adbapi.ConnectionPool( - self.postgres_config["name"], - **self.postgres_config["args"] + self.postgres_config["name"], **self.postgres_config["args"] ) sqlite_engine = create_engine(sqlite_config) @@ -465,9 +465,7 @@ class Porter(object): self.sqlite_store = Store(sqlite_db_pool, sqlite_engine) self.postgres_store = Store(postgres_db_pool, postgres_engine) - yield self.postgres_store.execute( - postgres_engine.check_database - ) + yield self.postgres_store.execute(postgres_engine.check_database) # Step 1. Set up databases. self.progress.set_state("Preparing SQLite3") @@ -477,6 +475,7 @@ class Porter(object): self.setup_db(postgres_config, postgres_engine) self.progress.set_state("Creating port tables") + def create_port_table(txn): txn.execute( "CREATE TABLE IF NOT EXISTS port_from_sqlite3 (" @@ -501,9 +500,7 @@ class Porter(object): ) try: - yield self.postgres_store.runInteraction( - "alter_table", alter_table - ) + yield self.postgres_store.runInteraction("alter_table", alter_table) except Exception as e: pass @@ -514,11 +511,7 @@ class Porter(object): # Step 2. Get tables. self.progress.set_state("Fetching tables") sqlite_tables = yield self.sqlite_store._simple_select_onecol( - table="sqlite_master", - keyvalues={ - "type": "table", - }, - retcol="name", + table="sqlite_master", keyvalues={"type": "table"}, retcol="name" ) postgres_tables = yield self.postgres_store._simple_select_onecol( @@ -545,18 +538,14 @@ class Porter(object): # Step 4. Do the copying. self.progress.set_state("Copying to postgres") yield defer.gatherResults( - [ - self.handle_table(*res) - for res in setup_res - ], - consumeErrors=True, + [self.handle_table(*res) for res in setup_res], consumeErrors=True ) # Step 5. Do final post-processing yield self._setup_state_group_id_seq() self.progress.done() - except: + except Exception: global end_error_exec_info end_error_exec_info = sys.exc_info() logger.exception("") @@ -566,9 +555,7 @@ class Porter(object): def _convert_rows(self, table, headers, rows): bool_col_names = BOOLEAN_COLUMNS.get(table, []) - bool_cols = [ - i for i, h in enumerate(headers) if h in bool_col_names - ] + bool_cols = [i for i, h in enumerate(headers) if h in bool_col_names] class BadValueException(Exception): pass @@ -577,18 +564,21 @@ class Porter(object): if j in bool_cols: return bool(col) elif isinstance(col, string_types) and "\0" in col: - logger.warn("DROPPING ROW: NUL value in table %s col %s: %r", table, headers[j], col) - raise BadValueException(); + logger.warn( + "DROPPING ROW: NUL value in table %s col %s: %r", + table, + headers[j], + col, + ) + raise BadValueException() return col outrows = [] for i, row in enumerate(rows): try: - outrows.append(tuple( - conv(j, col) - for j, col in enumerate(row) - if j > 0 - )) + outrows.append( + tuple(conv(j, col) for j, col in enumerate(row) if j > 0) + ) except BadValueException: pass @@ -616,9 +606,7 @@ class Porter(object): return headers, [r for r in rows if r[ts_ind] < yesterday] - headers, rows = yield self.sqlite_store.runInteraction( - "select", r, - ) + headers, rows = yield self.sqlite_store.runInteraction("select", r) rows = self._convert_rows("sent_transactions", headers, rows) @@ -639,7 +627,7 @@ class Porter(object): txn.execute( "SELECT rowid FROM sent_transactions WHERE ts >= ?" " ORDER BY rowid ASC LIMIT 1", - (yesterday,) + (yesterday,), ) rows = txn.fetchall() @@ -657,21 +645,17 @@ class Porter(object): "table_name": "sent_transactions", "forward_rowid": next_chunk, "backward_rowid": 0, - } + }, ) def get_sent_table_size(txn): txn.execute( - "SELECT count(*) FROM sent_transactions" - " WHERE ts >= ?", - (yesterday,) + "SELECT count(*) FROM sent_transactions" " WHERE ts >= ?", (yesterday,) ) size, = txn.fetchone() return int(size) - remaining_count = yield self.sqlite_store.execute( - get_sent_table_size - ) + remaining_count = yield self.sqlite_store.execute(get_sent_table_size) total_count = remaining_count + inserted_rows @@ -680,13 +664,11 @@ class Porter(object): @defer.inlineCallbacks def _get_remaining_count_to_port(self, table, forward_chunk, backward_chunk): frows = yield self.sqlite_store.execute_sql( - "SELECT count(*) FROM %s WHERE rowid >= ?" % (table,), - forward_chunk, + "SELECT count(*) FROM %s WHERE rowid >= ?" % (table,), forward_chunk ) brows = yield self.sqlite_store.execute_sql( - "SELECT count(*) FROM %s WHERE rowid <= ?" % (table,), - backward_chunk, + "SELECT count(*) FROM %s WHERE rowid <= ?" % (table,), backward_chunk ) defer.returnValue(frows[0][0] + brows[0][0]) @@ -694,7 +676,7 @@ class Porter(object): @defer.inlineCallbacks def _get_already_ported_count(self, table): rows = yield self.postgres_store.execute_sql( - "SELECT count(*) FROM %s" % (table,), + "SELECT count(*) FROM %s" % (table,) ) defer.returnValue(rows[0][0]) @@ -717,22 +699,21 @@ class Porter(object): def _setup_state_group_id_seq(self): def r(txn): txn.execute("SELECT MAX(id) FROM state_groups") - next_id = txn.fetchone()[0]+1 - txn.execute( - "ALTER SEQUENCE state_group_id_seq RESTART WITH %s", - (next_id,), - ) + next_id = txn.fetchone()[0] + 1 + txn.execute("ALTER SEQUENCE state_group_id_seq RESTART WITH %s", (next_id,)) + return self.postgres_store.runInteraction("setup_state_group_id_seq", r) ############################################## -###### The following is simply UI stuff ###### +# The following is simply UI stuff ############################################## class Progress(object): """Used to report progress of the port """ + def __init__(self): self.tables = {} @@ -758,6 +739,7 @@ class Progress(object): class CursesProgress(Progress): """Reports progress to a curses window """ + def __init__(self, stdscr): self.stdscr = stdscr @@ -801,7 +783,7 @@ class CursesProgress(Progress): duration = int(now) - int(self.start_time) minutes, seconds = divmod(duration, 60) - duration_str = '%02dm %02ds' % (minutes, seconds,) + duration_str = '%02dm %02ds' % (minutes, seconds) if self.finished: status = "Time spent: %s (Done!)" % (duration_str,) @@ -814,16 +796,12 @@ class CursesProgress(Progress): est_remaining_str = '%02dm %02ds remaining' % divmod(est_remaining, 60) else: est_remaining_str = "Unknown" - status = ( - "Time spent: %s (est. remaining: %s)" - % (duration_str, est_remaining_str,) + status = "Time spent: %s (est. remaining: %s)" % ( + duration_str, + est_remaining_str, ) - self.stdscr.addstr( - 0, 0, - status, - curses.A_BOLD, - ) + self.stdscr.addstr(0, 0, status, curses.A_BOLD) max_len = max([len(t) for t in self.tables.keys()]) @@ -831,9 +809,7 @@ class CursesProgress(Progress): middle_space = 1 items = self.tables.items() - items.sort( - key=lambda i: (i[1]["perc"], i[0]), - ) + items.sort(key=lambda i: (i[1]["perc"], i[0])) for i, (table, data) in enumerate(items): if i + 2 >= rows: @@ -844,9 +820,7 @@ class CursesProgress(Progress): color = curses.color_pair(2) if perc == 100 else curses.color_pair(1) self.stdscr.addstr( - i + 2, left_margin + max_len - len(table), - table, - curses.A_BOLD | color, + i + 2, left_margin + max_len - len(table), table, curses.A_BOLD | color ) size = 20 @@ -857,15 +831,13 @@ class CursesProgress(Progress): ) self.stdscr.addstr( - i + 2, left_margin + max_len + middle_space, + i + 2, + left_margin + max_len + middle_space, "%s %3d%% (%d/%d)" % (progress, perc, data["num_done"], data["total"]), ) if self.finished: - self.stdscr.addstr( - rows - 1, 0, - "Press any key to exit...", - ) + self.stdscr.addstr(rows - 1, 0, "Press any key to exit...") self.stdscr.refresh() self.last_update = time.time() @@ -877,29 +849,25 @@ class CursesProgress(Progress): def set_state(self, state): self.stdscr.clear() - self.stdscr.addstr( - 0, 0, - state + "...", - curses.A_BOLD, - ) + self.stdscr.addstr(0, 0, state + "...", curses.A_BOLD) self.stdscr.refresh() class TerminalProgress(Progress): """Just prints progress to the terminal """ + def update(self, table, num_done): super(TerminalProgress, self).update(table, num_done) data = self.tables[table] - print "%s: %d%% (%d/%d)" % ( - table, data["perc"], - data["num_done"], data["total"], + print( + "%s: %d%% (%d/%d)" % (table, data["perc"], data["num_done"], data["total"]) ) def set_state(self, state): - print state + "..." + print(state + "...") ############################################## @@ -909,34 +877,38 @@ class TerminalProgress(Progress): if __name__ == "__main__": parser = argparse.ArgumentParser( description="A script to port an existing synapse SQLite database to" - " a new PostgreSQL database." + " a new PostgreSQL database." ) parser.add_argument("-v", action='store_true') parser.add_argument( - "--sqlite-database", required=True, + "--sqlite-database", + required=True, help="The snapshot of the SQLite database file. This must not be" - " currently used by a running synapse server" + " currently used by a running synapse server", ) parser.add_argument( - "--postgres-config", type=argparse.FileType('r'), required=True, - help="The database config file for the PostgreSQL database" + "--postgres-config", + type=argparse.FileType('r'), + required=True, + help="The database config file for the PostgreSQL database", ) parser.add_argument( - "--curses", action='store_true', - help="display a curses based progress UI" + "--curses", action='store_true', help="display a curses based progress UI" ) parser.add_argument( - "--batch-size", type=int, default=1000, + "--batch-size", + type=int, + default=1000, help="The number of rows to select from the SQLite table each" - " iteration [default=1000]", + " iteration [default=1000]", ) args = parser.parse_args() logging_config = { "level": logging.DEBUG if args.v else logging.INFO, - "format": "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s" + "format": "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s", } if args.curses: diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index e3f0d99a3f..84f8fe2d9e 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -568,7 +568,7 @@ def run(hs): clock.call_later(5 * 60, start_phone_stats_home) if hs.config.daemonize and hs.config.print_pidfile: - print (hs.config.pid_file) + print(hs.config.pid_file) _base.start_reactor( "synapse-homeserver", diff --git a/synapse/config/__main__.py b/synapse/config/__main__.py index 8fccf573ee..79fe9c3dac 100644 --- a/synapse/config/__main__.py +++ b/synapse/config/__main__.py @@ -28,7 +28,7 @@ if __name__ == "__main__": sys.stderr.write("\n" + str(e) + "\n") sys.exit(1) - print (getattr(config, key)) + print(getattr(config, key)) sys.exit(0) else: sys.stderr.write("Unknown command %r\n" % (action,)) diff --git a/synapse/config/_base.py b/synapse/config/_base.py index 3d2e90dd5b..14dae65ea0 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -106,10 +106,7 @@ class Config(object): @classmethod def check_file(cls, file_path, config_name): if file_path is None: - raise ConfigError( - "Missing config for %s." - % (config_name,) - ) + raise ConfigError("Missing config for %s." % (config_name,)) try: os.stat(file_path) except OSError as e: @@ -128,9 +125,7 @@ class Config(object): if e.errno != errno.EEXIST: raise if not os.path.isdir(dir_path): - raise ConfigError( - "%s is not a directory" % (dir_path,) - ) + raise ConfigError("%s is not a directory" % (dir_path,)) return dir_path @classmethod @@ -156,21 +151,20 @@ class Config(object): return results def generate_config( - self, - config_dir_path, - server_name, - is_generating_file, - report_stats=None, + self, config_dir_path, server_name, is_generating_file, report_stats=None ): default_config = "# vim:ft=yaml\n" - default_config += "\n\n".join(dedent(conf) for conf in self.invoke_all( - "default_config", - config_dir_path=config_dir_path, - server_name=server_name, - is_generating_file=is_generating_file, - report_stats=report_stats, - )) + default_config += "\n\n".join( + dedent(conf) + for conf in self.invoke_all( + "default_config", + config_dir_path=config_dir_path, + server_name=server_name, + is_generating_file=is_generating_file, + report_stats=report_stats, + ) + ) config = yaml.load(default_config) @@ -178,23 +172,22 @@ class Config(object): @classmethod def load_config(cls, description, argv): - config_parser = argparse.ArgumentParser( - description=description, - ) + config_parser = argparse.ArgumentParser(description=description) config_parser.add_argument( - "-c", "--config-path", + "-c", + "--config-path", action="append", metavar="CONFIG_FILE", help="Specify config file. Can be given multiple times and" - " may specify directories containing *.yaml files." + " may specify directories containing *.yaml files.", ) config_parser.add_argument( "--keys-directory", metavar="DIRECTORY", help="Where files such as certs and signing keys are stored when" - " their location is given explicitly in the config." - " Defaults to the directory containing the last config file", + " their location is given explicitly in the config." + " Defaults to the directory containing the last config file", ) config_args = config_parser.parse_args(argv) @@ -203,9 +196,7 @@ class Config(object): obj = cls() obj.read_config_files( - config_files, - keys_directory=config_args.keys_directory, - generate_keys=False, + config_files, keys_directory=config_args.keys_directory, generate_keys=False ) return obj @@ -213,38 +204,38 @@ class Config(object): def load_or_generate_config(cls, description, argv): config_parser = argparse.ArgumentParser(add_help=False) config_parser.add_argument( - "-c", "--config-path", + "-c", + "--config-path", action="append", metavar="CONFIG_FILE", help="Specify config file. Can be given multiple times and" - " may specify directories containing *.yaml files." + " may specify directories containing *.yaml files.", ) config_parser.add_argument( "--generate-config", action="store_true", - help="Generate a config file for the server name" + help="Generate a config file for the server name", ) config_parser.add_argument( "--report-stats", action="store", help="Whether the generated config reports anonymized usage statistics", - choices=["yes", "no"] + choices=["yes", "no"], ) config_parser.add_argument( "--generate-keys", action="store_true", - help="Generate any missing key files then exit" + help="Generate any missing key files then exit", ) config_parser.add_argument( "--keys-directory", metavar="DIRECTORY", help="Used with 'generate-*' options to specify where files such as" - " certs and signing keys should be stored in, unless explicitly" - " specified in the config." + " certs and signing keys should be stored in, unless explicitly" + " specified in the config.", ) config_parser.add_argument( - "-H", "--server-name", - help="The server name to generate a config file for" + "-H", "--server-name", help="The server name to generate a config file for" ) config_args, remaining_args = config_parser.parse_known_args(argv) @@ -257,8 +248,8 @@ class Config(object): if config_args.generate_config: if config_args.report_stats is None: config_parser.error( - "Please specify either --report-stats=yes or --report-stats=no\n\n" + - MISSING_REPORT_STATS_SPIEL + "Please specify either --report-stats=yes or --report-stats=no\n\n" + + MISSING_REPORT_STATS_SPIEL ) if not config_files: config_parser.error( @@ -287,26 +278,32 @@ class Config(object): config_dir_path=config_dir_path, server_name=server_name, report_stats=(config_args.report_stats == "yes"), - is_generating_file=True + is_generating_file=True, ) obj.invoke_all("generate_files", config) config_file.write(config_str) - print(( - "A config file has been generated in %r for server name" - " %r with corresponding SSL keys and self-signed" - " certificates. Please review this file and customise it" - " to your needs." - ) % (config_path, server_name)) + print( + ( + "A config file has been generated in %r for server name" + " %r with corresponding SSL keys and self-signed" + " certificates. Please review this file and customise it" + " to your needs." + ) + % (config_path, server_name) + ) print( "If this server name is incorrect, you will need to" " regenerate the SSL certificates" ) return else: - print(( - "Config file %r already exists. Generating any missing key" - " files." - ) % (config_path,)) + print( + ( + "Config file %r already exists. Generating any missing key" + " files." + ) + % (config_path,) + ) generate_keys = True parser = argparse.ArgumentParser( @@ -338,8 +335,7 @@ class Config(object): return obj - def read_config_files(self, config_files, keys_directory=None, - generate_keys=False): + def read_config_files(self, config_files, keys_directory=None, generate_keys=False): if not keys_directory: keys_directory = os.path.dirname(config_files[-1]) @@ -364,8 +360,9 @@ class Config(object): if "report_stats" not in config: raise ConfigError( - MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS + "\n" + - MISSING_REPORT_STATS_SPIEL + MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS + + "\n" + + MISSING_REPORT_STATS_SPIEL ) if generate_keys: @@ -399,16 +396,16 @@ def find_config_files(search_paths): for entry in os.listdir(config_path): entry_path = os.path.join(config_path, entry) if not os.path.isfile(entry_path): - print ( - "Found subdirectory in config directory: %r. IGNORING." - ) % (entry_path, ) + err = "Found subdirectory in config directory: %r. IGNORING." + print(err % (entry_path,)) continue if not entry.endswith(".yaml"): - print ( - "Found file in config directory that does not" - " end in '.yaml': %r. IGNORING." - ) % (entry_path, ) + err = ( + "Found file in config directory that does not end in " + "'.yaml': %r. IGNORING." + ) + print(err % (entry_path,)) continue files.append(entry_path) diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index be61147b9b..d9d0255d0b 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -18,7 +18,7 @@ import threading import time from six import PY2, iteritems, iterkeys, itervalues -from six.moves import intern, range +from six.moves import builtins, intern, range from canonicaljson import json from prometheus_client import Histogram @@ -1233,7 +1233,7 @@ def db_to_json(db_content): # psycopg2 on Python 2 returns buffer objects, which we need to cast to # bytes to decode - if PY2 and isinstance(db_content, buffer): + if PY2 and isinstance(db_content, builtins.buffer): db_content = bytes(db_content) # Decode it to a Unicode string before feeding it to json.loads, so we diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py index a1331c1a61..8af17921e3 100644 --- a/synapse/storage/keys.py +++ b/synapse/storage/keys.py @@ -32,7 +32,7 @@ logger = logging.getLogger(__name__) # py2 sqlite has buffer hardcoded as only binary type, so we must use it, # despite being deprecated and removed in favor of memoryview if six.PY2: - db_binary_type = buffer + db_binary_type = six.moves.builtins.buffer else: db_binary_type = memoryview diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py index c7987bfcdd..2743b52bad 100644 --- a/synapse/storage/pusher.py +++ b/synapse/storage/pusher.py @@ -29,7 +29,7 @@ from ._base import SQLBaseStore logger = logging.getLogger(__name__) if six.PY2: - db_binary_type = buffer + db_binary_type = six.moves.builtins.buffer else: db_binary_type = memoryview diff --git a/synapse/storage/signatures.py b/synapse/storage/signatures.py index 5623391f6e..158e9dbe7b 100644 --- a/synapse/storage/signatures.py +++ b/synapse/storage/signatures.py @@ -27,7 +27,7 @@ from ._base import SQLBaseStore # py2 sqlite has buffer hardcoded as only binary type, so we must use it, # despite being deprecated and removed in favor of memoryview if six.PY2: - db_binary_type = buffer + db_binary_type = six.moves.builtins.buffer else: db_binary_type = memoryview diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py index a3032cdce9..d8bf953ec0 100644 --- a/synapse/storage/transactions.py +++ b/synapse/storage/transactions.py @@ -30,7 +30,7 @@ from ._base import SQLBaseStore, db_to_json # py2 sqlite has buffer hardcoded as only binary type, so we must use it, # despite being deprecated and removed in favor of memoryview if six.PY2: - db_binary_type = buffer + db_binary_type = six.moves.builtins.buffer else: db_binary_type = memoryview diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py index f2bde74dc5..625aedc940 100644 --- a/synapse/util/caches/stream_change_cache.py +++ b/synapse/util/caches/stream_change_cache.py @@ -15,6 +15,8 @@ import logging +from six import integer_types + from sortedcontainers import SortedDict from synapse.util import caches @@ -47,7 +49,7 @@ class StreamChangeCache(object): def has_entity_changed(self, entity, stream_pos): """Returns True if the entity may have been updated since stream_pos """ - assert type(stream_pos) is int or type(stream_pos) is long + assert type(stream_pos) in integer_types if stream_pos < self._earliest_known_stream_pos: self.metrics.inc_misses() diff --git a/synctl b/synctl index bb8cb084cc..7e79b05c39 100755 --- a/synctl +++ b/synctl @@ -76,8 +76,7 @@ def start(configfile): try: subprocess.check_call(args) - write("started synapse.app.homeserver(%r)" % - (configfile,), colour=GREEN) + write("started synapse.app.homeserver(%r)" % (configfile,), colour=GREEN) except subprocess.CalledProcessError as e: write( "error starting (exit code: %d); see above for logs" % e.returncode, @@ -86,21 +85,15 @@ def start(configfile): def start_worker(app, configfile, worker_configfile): - args = [ - sys.executable, "-B", - "-m", app, - "-c", configfile, - "-c", worker_configfile - ] + args = [sys.executable, "-B", "-m", app, "-c", configfile, "-c", worker_configfile] try: subprocess.check_call(args) write("started %s(%r)" % (app, worker_configfile), colour=GREEN) except subprocess.CalledProcessError as e: write( - "error starting %s(%r) (exit code: %d); see above for logs" % ( - app, worker_configfile, e.returncode, - ), + "error starting %s(%r) (exit code: %d); see above for logs" + % (app, worker_configfile, e.returncode), colour=RED, ) @@ -120,9 +113,9 @@ def stop(pidfile, app): abort("Cannot stop %s: Unknown error" % (app,)) -Worker = collections.namedtuple("Worker", [ - "app", "configfile", "pidfile", "cache_factor", "cache_factors", -]) +Worker = collections.namedtuple( + "Worker", ["app", "configfile", "pidfile", "cache_factor", "cache_factors"] +) def main(): @@ -141,24 +134,20 @@ def main(): help="the homeserver config file, defaults to homeserver.yaml", ) parser.add_argument( - "-w", "--worker", - metavar="WORKERCONFIG", - help="start or stop a single worker", + "-w", "--worker", metavar="WORKERCONFIG", help="start or stop a single worker" ) parser.add_argument( - "-a", "--all-processes", + "-a", + "--all-processes", metavar="WORKERCONFIGDIR", help="start or stop all the workers in the given directory" - " and the main synapse process", + " and the main synapse process", ) options = parser.parse_args() if options.worker and options.all_processes: - write( - 'Cannot use "--worker" with "--all-processes"', - stream=sys.stderr - ) + write('Cannot use "--worker" with "--all-processes"', stream=sys.stderr) sys.exit(1) configfile = options.configfile @@ -167,9 +156,7 @@ def main(): write( "No config file found\n" "To generate a config file, run '%s -c %s --generate-config" - " --server-name='\n" % ( - " ".join(SYNAPSE), options.configfile - ), + " --server-name='\n" % (" ".join(SYNAPSE), options.configfile), stream=sys.stderr, ) sys.exit(1) @@ -194,8 +181,7 @@ def main(): worker_configfile = options.worker if not os.path.exists(worker_configfile): write( - "No worker config found at %r" % (worker_configfile,), - stream=sys.stderr, + "No worker config found at %r" % (worker_configfile,), stream=sys.stderr ) sys.exit(1) worker_configfiles.append(worker_configfile) @@ -211,9 +197,9 @@ def main(): stream=sys.stderr, ) sys.exit(1) - worker_configfiles.extend(sorted(glob.glob( - os.path.join(worker_configdir, "*.yaml") - ))) + worker_configfiles.extend( + sorted(glob.glob(os.path.join(worker_configdir, "*.yaml"))) + ) workers = [] for worker_configfile in worker_configfiles: @@ -223,14 +209,12 @@ def main(): if worker_app == "synapse.app.homeserver": # We need to special case all of this to pick up options that may # be set in the main config file or in this worker config file. - worker_pidfile = ( - worker_config.get("pid_file") - or pidfile + worker_pidfile = worker_config.get("pid_file") or pidfile + worker_cache_factor = ( + worker_config.get("synctl_cache_factor") or cache_factor ) - worker_cache_factor = worker_config.get("synctl_cache_factor") or cache_factor worker_cache_factors = ( - worker_config.get("synctl_cache_factors") - or cache_factors + worker_config.get("synctl_cache_factors") or cache_factors ) daemonize = worker_config.get("daemonize") or config.get("daemonize") assert daemonize, "Main process must have daemonize set to true" @@ -239,19 +223,27 @@ def main(): for key in worker_config: if key == "worker_app": # But we allow worker_app continue - assert not key.startswith("worker_"), \ - "Main process cannot use worker_* config" + assert not key.startswith( + "worker_" + ), "Main process cannot use worker_* config" else: worker_pidfile = worker_config["worker_pid_file"] worker_daemonize = worker_config["worker_daemonize"] assert worker_daemonize, "In config %r: expected '%s' to be True" % ( - worker_configfile, "worker_daemonize") + worker_configfile, + "worker_daemonize", + ) worker_cache_factor = worker_config.get("synctl_cache_factor") worker_cache_factors = worker_config.get("synctl_cache_factors", {}) - workers.append(Worker( - worker_app, worker_configfile, worker_pidfile, worker_cache_factor, - worker_cache_factors, - )) + workers.append( + Worker( + worker_app, + worker_configfile, + worker_pidfile, + worker_cache_factor, + worker_cache_factors, + ) + ) action = options.action diff --git a/tox.ini b/tox.ini index 87b5e4782d..04d2f721bf 100644 --- a/tox.ini +++ b/tox.ini @@ -108,10 +108,10 @@ commands = [testenv:pep8] skip_install = True -basepython = python2.7 +basepython = python3.6 deps = flake8 -commands = /bin/sh -c "flake8 synapse tests {env:PEP8SUFFIX:}" +commands = /bin/sh -c "flake8 synapse tests scripts scripts-dev scripts/register_new_matrix_user scripts/synapse_port_db synctl {env:PEP8SUFFIX:}" [testenv:check_isort] skip_install = True -- cgit 1.4.1 From 6105c6101fa0f4560daf7d23dedcfd217530b02a Mon Sep 17 00:00:00 2001 From: Neil Johnson Date: Tue, 23 Oct 2018 15:24:58 +0100 Subject: fix race condiftion in calling initialise_reserved_users --- changelog.d/4081.bugfix | 2 ++ synapse/app/homeserver.py | 8 ------ synapse/storage/monthly_active_users.py | 46 +++++++++++++++++++++--------- synapse/storage/registration.py | 16 ++++++++--- tests/storage/test_monthly_active_users.py | 10 +++++-- 5 files changed, 55 insertions(+), 27 deletions(-) create mode 100644 changelog.d/4081.bugfix (limited to 'synapse/storage') diff --git a/changelog.d/4081.bugfix b/changelog.d/4081.bugfix new file mode 100644 index 0000000000..f275acb61c --- /dev/null +++ b/changelog.d/4081.bugfix @@ -0,0 +1,2 @@ +Fix race condition in populating reserved users + diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 0b85b377e3..593e1e75db 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -553,14 +553,6 @@ def run(hs): generate_monthly_active_users, ) - # XXX is this really supposed to be a background process? it looks - # like it needs to complete before some of the other stuff runs. - run_as_background_process( - "initialise_reserved_users", - hs.get_datastore().initialise_reserved_users, - hs.config.mau_limits_reserved_threepids, - ) - start_generate_monthly_active_users() if hs.config.limit_usage_by_mau: clock.looping_call(start_generate_monthly_active_users, 5 * 60 * 1000) diff --git a/synapse/storage/monthly_active_users.py b/synapse/storage/monthly_active_users.py index 0fe8c8e24c..26e577814a 100644 --- a/synapse/storage/monthly_active_users.py +++ b/synapse/storage/monthly_active_users.py @@ -33,19 +33,28 @@ class MonthlyActiveUsersStore(SQLBaseStore): self._clock = hs.get_clock() self.hs = hs self.reserved_users = () + self.initialise_reserved_users( + dbconn.cursor(), hs.config.mau_limits_reserved_threepids + ) - @defer.inlineCallbacks - def initialise_reserved_users(self, threepids): - store = self.hs.get_datastore() + def initialise_reserved_users(self, txn, threepids): + """ + Ensures that reserved threepids are accounted for in the MAU table, should + be called on start up. + + Arguments: + threepids []: List of threepid dicts to reserve + """ reserved_user_list = [] # Do not add more reserved users than the total allowable number for tp in threepids[:self.hs.config.max_mau_value]: - user_id = yield store.get_user_id_by_threepid( + user_id = self.get_user_id_by_threepid_txn( + txn, tp["medium"], tp["address"] ) if user_id: - yield self.upsert_monthly_active_user(user_id) + self.upsert_monthly_active_user_txn(txn, user_id) reserved_user_list.append(user_id) else: logger.warning( @@ -55,8 +64,7 @@ class MonthlyActiveUsersStore(SQLBaseStore): @defer.inlineCallbacks def reap_monthly_active_users(self): - """ - Cleans out monthly active user table to ensure that no stale + """Cleans out monthly active user table to ensure that no stale entries exist. Returns: @@ -165,19 +173,33 @@ class MonthlyActiveUsersStore(SQLBaseStore): @defer.inlineCallbacks def upsert_monthly_active_user(self, user_id): + """Updates or inserts monthly active user member + Arguments: + user_id (str): user to add/update + """ + is_insert = yield self.runInteraction( + "upsert_monthly_active_user", self.upsert_monthly_active_user_txn, + user_id + ) + if is_insert: + self.user_last_seen_monthly_active.invalidate((user_id,)) + self.get_monthly_active_count.invalidate(()) + + def upsert_monthly_active_user_txn(self, txn, user_id): """ Updates or inserts monthly active user member Arguments: + txn (cursor): user_id (str): user to add/update - Deferred[bool]: True if a new entry was created, False if an + bool: True if a new entry was created, False if an existing one was updated. """ # Am consciously deciding to lock the table on the basis that is ought # never be a big table and alternative approaches (batching multiple # upserts into a single txn) introduced a lot of extra complexity. # See https://github.com/matrix-org/synapse/issues/3854 for more - is_insert = yield self._simple_upsert( - desc="upsert_monthly_active_user", + is_insert = self._simple_upsert_txn( + txn, table="monthly_active_users", keyvalues={ "user_id": user_id, @@ -186,9 +208,7 @@ class MonthlyActiveUsersStore(SQLBaseStore): "timestamp": int(self._clock.time_msec()), }, ) - if is_insert: - self.user_last_seen_monthly_active.invalidate((user_id,)) - self.get_monthly_active_count.invalidate(()) + return is_insert @cached(num_args=1) def user_last_seen_monthly_active(self, user_id): diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 26b429e307..01931f29cc 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -474,17 +474,25 @@ class RegistrationStore(RegistrationWorkerStore, @defer.inlineCallbacks def get_user_id_by_threepid(self, medium, address): - ret = yield self._simple_select_one( + user_id = yield self.runInteraction( + "get_user_id_by_threepid", self.get_user_id_by_threepid_txn, + medium, address + ) + defer.returnValue(user_id) + + def get_user_id_by_threepid_txn(self, txn, medium, address): + ret = self._simple_select_one_txn( + txn, "user_threepids", { "medium": medium, "address": address }, - ['user_id'], True, 'get_user_id_by_threepid' + ['user_id'], True ) if ret: - defer.returnValue(ret['user_id']) - defer.returnValue(None) + return ret['user_id'] + return None def user_delete_threepid(self, user_id, medium, address): return self._simple_delete( diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py index 686f12a0dc..0c17745ae7 100644 --- a/tests/storage/test_monthly_active_users.py +++ b/tests/storage/test_monthly_active_users.py @@ -52,7 +52,10 @@ class MonthlyActiveUsersTestCase(HomeserverTestCase): now = int(self.hs.get_clock().time_msec()) self.store.user_add_threepid(user1, "email", user1_email, now, now) self.store.user_add_threepid(user2, "email", user2_email, now, now) - self.store.initialise_reserved_users(threepids) + + self.store.runInteraction( + "initialise", self.store.initialise_reserved_users, threepids + ) self.pump() active_count = self.store.get_monthly_active_count() @@ -199,7 +202,10 @@ class MonthlyActiveUsersTestCase(HomeserverTestCase): {'medium': 'email', 'address': user2_email}, ] self.hs.config.mau_limits_reserved_threepids = threepids - self.store.initialise_reserved_users(threepids) + self.store.runInteraction( + "initialise", self.store.initialise_reserved_users, threepids + ) + self.pump() count = self.store.get_registered_reserved_users_count() self.assertEquals(self.get_success(count), 0) -- cgit 1.4.1 From 329d18b39cbc1bc9eba8ae9de1fcf734d0cf1a78 Mon Sep 17 00:00:00 2001 From: Neil Johnson Date: Tue, 23 Oct 2018 15:27:20 +0100 Subject: remove white space --- changelog.d/4081.bugfix | 1 - synapse/storage/monthly_active_users.py | 3 +-- 2 files changed, 1 insertion(+), 3 deletions(-) (limited to 'synapse/storage') diff --git a/changelog.d/4081.bugfix b/changelog.d/4081.bugfix index f275acb61c..13dad58842 100644 --- a/changelog.d/4081.bugfix +++ b/changelog.d/4081.bugfix @@ -1,2 +1 @@ Fix race condition in populating reserved users - diff --git a/synapse/storage/monthly_active_users.py b/synapse/storage/monthly_active_users.py index 26e577814a..cf15f8c5ba 100644 --- a/synapse/storage/monthly_active_users.py +++ b/synapse/storage/monthly_active_users.py @@ -38,8 +38,7 @@ class MonthlyActiveUsersStore(SQLBaseStore): ) def initialise_reserved_users(self, txn, threepids): - """ - Ensures that reserved threepids are accounted for in the MAU table, should + """Ensures that reserved threepids are accounted for in the MAU table, should be called on start up. Arguments: -- cgit 1.4.1 From ef771cc4c2d988e5188ba7e75df9adebb0ebafe1 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 24 Oct 2018 10:35:01 +0100 Subject: Fix a number of flake8 errors Broadly three things here: * disable W504 which seems a bit whacko * remove a bunch of `as e` expressions from exception handlers that don't use them * use `r""` for strings which include backslashes Also, we don't use pep8 any more, so we can get rid of the duplicate config there. --- changelog.d/4082.misc | 1 + scripts-dev/tail-synapse.py | 2 +- scripts/synapse_port_db | 3 ++- setup.cfg | 17 ++++++++--------- synapse/config/repository.py | 2 +- synapse/crypto/keyclient.py | 2 +- synapse/federation/federation_server.py | 2 +- synapse/rest/client/v2_alpha/auth.py | 2 +- synapse/rest/media/v1/preview_url_resource.py | 2 +- synapse/storage/registration.py | 2 +- tests/config/test_generate.py | 2 +- tests/events/test_utils.py | 4 ++-- 12 files changed, 21 insertions(+), 20 deletions(-) create mode 100644 changelog.d/4082.misc (limited to 'synapse/storage') diff --git a/changelog.d/4082.misc b/changelog.d/4082.misc new file mode 100644 index 0000000000..a81faf5e9b --- /dev/null +++ b/changelog.d/4082.misc @@ -0,0 +1 @@ +Clean up some bits of code which were flagged by the linter diff --git a/scripts-dev/tail-synapse.py b/scripts-dev/tail-synapse.py index 1a36b94038..7c9985d9f0 100644 --- a/scripts-dev/tail-synapse.py +++ b/scripts-dev/tail-synapse.py @@ -48,7 +48,7 @@ def main(): row.name: row.position for row in replicate(server, {"streams": "-1"})["streams"].rows } - except requests.exceptions.ConnectionError as e: + except requests.exceptions.ConnectionError: time.sleep(0.1) while True: diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db index 2f6e69e552..3c7b606323 100755 --- a/scripts/synapse_port_db +++ b/scripts/synapse_port_db @@ -501,7 +501,8 @@ class Porter(object): try: yield self.postgres_store.runInteraction("alter_table", alter_table) - except Exception as e: + except Exception: + # On Error Resume Next pass yield self.postgres_store.runInteraction( diff --git a/setup.cfg b/setup.cfg index 52feaa9cc7..b6b4aa740d 100644 --- a/setup.cfg +++ b/setup.cfg @@ -14,17 +14,16 @@ ignore = pylint.cfg tox.ini -[pep8] -max-line-length = 90 -# W503 requires that binary operators be at the end, not start, of lines. Erik -# doesn't like it. E203 is contrary to PEP8. E731 is silly. -ignore = W503,E203,E731 - [flake8] -# note that flake8 inherits the "ignore" settings from "pep8" (because it uses -# pep8 to do those checks), but not the "max-line-length" setting max-line-length = 90 -ignore=W503,E203,E731 + +# see https://pycodestyle.readthedocs.io/en/latest/intro.html#error-codes +# for error codes. The ones we ignore are: +# W503: line break before binary operator +# W504: line break after binary operator +# E203: whitespace before ':' (which is contrary to pep8?) +# E731: do not assign a lambda expression, use a def +ignore=W503,W504,E203,E731 [isort] line_length = 89 diff --git a/synapse/config/repository.py b/synapse/config/repository.py index fc909c1fac..06c62ab62c 100644 --- a/synapse/config/repository.py +++ b/synapse/config/repository.py @@ -178,7 +178,7 @@ class ContentRepositoryConfig(Config): def default_config(self, **kwargs): media_store = self.default_path("media_store") uploads_path = self.default_path("uploads") - return """ + return r""" # Directory where uploaded images and attachments are stored. media_store_path: "%(media_store)s" diff --git a/synapse/crypto/keyclient.py b/synapse/crypto/keyclient.py index 57d4665e84..080c81f14b 100644 --- a/synapse/crypto/keyclient.py +++ b/synapse/crypto/keyclient.py @@ -55,7 +55,7 @@ def fetch_server_key(server_name, tls_client_options_factory, path=KEY_API_V1): raise IOError("Cannot get key for %r" % server_name) except (ConnectError, DomainError) as e: logger.warn("Error getting key for %r: %s", server_name, e) - except Exception as e: + except Exception: logger.exception("Error getting key for %r", server_name) raise IOError("Cannot get key for %r" % server_name) diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 4efe95faa4..af0107a46e 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -800,7 +800,7 @@ class FederationHandlerRegistry(object): yield handler(origin, content) except SynapseError as e: logger.info("Failed to handle edu %r: %r", edu_type, e) - except Exception as e: + except Exception: logger.exception("Failed to handle edu %r", edu_type) def on_query(self, query_type, args): diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py index bd8b5f4afa..693b303881 100644 --- a/synapse/rest/client/v2_alpha/auth.py +++ b/synapse/rest/client/v2_alpha/auth.py @@ -99,7 +99,7 @@ class AuthRestServlet(RestServlet): cannot be handled in the normal flow (with requests to the same endpoint). Current use is for web fallback auth. """ - PATTERNS = client_v2_patterns("/auth/(?P[\w\.]*)/fallback/web") + PATTERNS = client_v2_patterns(r"/auth/(?P[\w\.]*)/fallback/web") def __init__(self, hs): super(AuthRestServlet, self).__init__() diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index 8c892ff187..1a7bfd6b56 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -674,7 +674,7 @@ def summarize_paragraphs(text_nodes, min_size=200, max_size=500): # This splits the paragraph into words, but keeping the # (preceeding) whitespace intact so we can easily concat # words back together. - for match in re.finditer("\s*\S+", description): + for match in re.finditer(r"\s*\S+", description): word = match.group() # Keep adding words while the total length is less than diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 26b429e307..2dd14aba1c 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -567,7 +567,7 @@ class RegistrationStore(RegistrationWorkerStore, def _find_next_generated_user_id(txn): txn.execute("SELECT name FROM users") - regex = re.compile("^@(\d+):") + regex = re.compile(r"^@(\d+):") found = set() diff --git a/tests/config/test_generate.py b/tests/config/test_generate.py index f88d28a19d..0c23068bcf 100644 --- a/tests/config/test_generate.py +++ b/tests/config/test_generate.py @@ -67,6 +67,6 @@ class ConfigGenerationTestCase(unittest.TestCase): with open(log_config_file) as f: config = f.read() # find the 'filename' line - matches = re.findall("^\s*filename:\s*(.*)$", config, re.M) + matches = re.findall(r"^\s*filename:\s*(.*)$", config, re.M) self.assertEqual(1, len(matches)) self.assertEqual(matches[0], expected) diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py index ff217ca8b9..d0cc492deb 100644 --- a/tests/events/test_utils.py +++ b/tests/events/test_utils.py @@ -156,7 +156,7 @@ class SerializeEventTestCase(unittest.TestCase): room_id="!foo:bar", content={"key.with.dots": {}}, ), - ["content.key\.with\.dots"], + [r"content.key\.with\.dots"], ), {"content": {"key.with.dots": {}}}, ) @@ -172,7 +172,7 @@ class SerializeEventTestCase(unittest.TestCase): "nested.dot.key": {"leaf.key": 42, "not_me_either": 1}, }, ), - ["content.nested\.dot\.key.leaf\.key"], + [r"content.nested\.dot\.key.leaf\.key"], ), {"content": {"nested.dot.key": {"leaf.key": 42}}}, ) -- cgit 1.4.1 From ea69a84bbb2fc9c1da6db0384a98f577f3bc95a7 Mon Sep 17 00:00:00 2001 From: Neil Johnson Date: Wed, 24 Oct 2018 17:18:08 +0100 Subject: fix style inconsistencies --- changelog.d/4081.bugfix | 3 ++- synapse/storage/monthly_active_users.py | 43 +++++++++++++++++++----------- synapse/storage/registration.py | 19 +++++++++++++ tests/storage/test_monthly_active_users.py | 4 +-- 4 files changed, 51 insertions(+), 18 deletions(-) (limited to 'synapse/storage') diff --git a/changelog.d/4081.bugfix b/changelog.d/4081.bugfix index 13dad58842..cfe4b3e9d9 100644 --- a/changelog.d/4081.bugfix +++ b/changelog.d/4081.bugfix @@ -1 +1,2 @@ -Fix race condition in populating reserved users +Fix race condition where config defined reserved users were not being added to +the monthly active user list prior to the homeserver reactor firing up diff --git a/synapse/storage/monthly_active_users.py b/synapse/storage/monthly_active_users.py index cf15f8c5ba..9a5c3b7eda 100644 --- a/synapse/storage/monthly_active_users.py +++ b/synapse/storage/monthly_active_users.py @@ -33,21 +33,23 @@ class MonthlyActiveUsersStore(SQLBaseStore): self._clock = hs.get_clock() self.hs = hs self.reserved_users = () - self.initialise_reserved_users( - dbconn.cursor(), hs.config.mau_limits_reserved_threepids + # Do not add more reserved users than the total allowable number + self._initialise_reserved_users( + dbconn.cursor(), + hs.config.mau_limits_reserved_threepids[:self.hs.config.max_mau_value], ) - def initialise_reserved_users(self, txn, threepids): + def _initialise_reserved_users(self, txn, threepids): """Ensures that reserved threepids are accounted for in the MAU table, should be called on start up. - Arguments: - threepids []: List of threepid dicts to reserve + Args: + txn (cursor): + threepids (list[dict]): List of threepid dicts to reserve """ reserved_user_list = [] - # Do not add more reserved users than the total allowable number - for tp in threepids[:self.hs.config.max_mau_value]: + for tp in threepids: user_id = self.get_user_id_by_threepid_txn( txn, tp["medium"], tp["address"] @@ -172,26 +174,36 @@ class MonthlyActiveUsersStore(SQLBaseStore): @defer.inlineCallbacks def upsert_monthly_active_user(self, user_id): - """Updates or inserts monthly active user member - Arguments: + """Updates or inserts the user into the monthly active user table, which + is used to track the current MAU usage of the server + + Args: user_id (str): user to add/update """ is_insert = yield self.runInteraction( "upsert_monthly_active_user", self.upsert_monthly_active_user_txn, user_id ) + # Considered pushing cache invalidation down into txn method, but + # did not because txn is not a LoggingTransaction. This means I could not + # call txn.call_after(). Therefore cache is altered in background thread + # and calls from elsewhere to user_last_seen_monthly_active and + # get_monthly_active_count fail with ValueError in + # synapse/util/caches/descriptors.py#check_thread if is_insert: self.user_last_seen_monthly_active.invalidate((user_id,)) self.get_monthly_active_count.invalidate(()) def upsert_monthly_active_user_txn(self, txn, user_id): - """ - Updates or inserts monthly active user member - Arguments: - txn (cursor): - user_id (str): user to add/update + """Updates or inserts monthly active user member + + Args: + txn (cursor): + user_id (str): user to add/update + + Returns: bool: True if a new entry was created, False if an - existing one was updated. + existing one was updated. """ # Am consciously deciding to lock the table on the basis that is ought # never be a big table and alternative approaches (batching multiple @@ -207,6 +219,7 @@ class MonthlyActiveUsersStore(SQLBaseStore): "timestamp": int(self._clock.time_msec()), }, ) + return is_insert @cached(num_args=1) diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 0f970850e8..80d76bf9d7 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -474,6 +474,15 @@ class RegistrationStore(RegistrationWorkerStore, @defer.inlineCallbacks def get_user_id_by_threepid(self, medium, address): + """Returns user id from threepid + + Args: + medium (str): threepid medium e.g. email + address (str): threepid address e.g. me@example.com + + Returns: + Deferred[str|None]: user id or None if no user id/threepid mapping exists + """ user_id = yield self.runInteraction( "get_user_id_by_threepid", self.get_user_id_by_threepid_txn, medium, address @@ -481,6 +490,16 @@ class RegistrationStore(RegistrationWorkerStore, defer.returnValue(user_id) def get_user_id_by_threepid_txn(self, txn, medium, address): + """Returns user id from threepid + + Args: + txn (cursor): + medium (str): threepid medium e.g. email + address (str): threepid address e.g. me@example.com + + Returns: + str|None: user id or None if no user id/threepid mapping exists + """ ret = self._simple_select_one_txn( txn, "user_threepids", diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py index 0c17745ae7..832e379a83 100644 --- a/tests/storage/test_monthly_active_users.py +++ b/tests/storage/test_monthly_active_users.py @@ -54,7 +54,7 @@ class MonthlyActiveUsersTestCase(HomeserverTestCase): self.store.user_add_threepid(user2, "email", user2_email, now, now) self.store.runInteraction( - "initialise", self.store.initialise_reserved_users, threepids + "initialise", self.store._initialise_reserved_users, threepids ) self.pump() @@ -203,7 +203,7 @@ class MonthlyActiveUsersTestCase(HomeserverTestCase): ] self.hs.config.mau_limits_reserved_threepids = threepids self.store.runInteraction( - "initialise", self.store.initialise_reserved_users, threepids + "initialise", self.store._initialise_reserved_users, threepids ) self.pump() -- cgit 1.4.1 From f8fe98812be74e37432f11508c07488079bf949d Mon Sep 17 00:00:00 2001 From: Neil Johnson Date: Thu, 25 Oct 2018 14:58:59 +0100 Subject: improve comments --- synapse/storage/monthly_active_users.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/monthly_active_users.py b/synapse/storage/monthly_active_users.py index 9a5c3b7eda..01963c879f 100644 --- a/synapse/storage/monthly_active_users.py +++ b/synapse/storage/monthly_active_users.py @@ -184,18 +184,18 @@ class MonthlyActiveUsersStore(SQLBaseStore): "upsert_monthly_active_user", self.upsert_monthly_active_user_txn, user_id ) - # Considered pushing cache invalidation down into txn method, but - # did not because txn is not a LoggingTransaction. This means I could not - # call txn.call_after(). Therefore cache is altered in background thread - # and calls from elsewhere to user_last_seen_monthly_active and - # get_monthly_active_count fail with ValueError in - # synapse/util/caches/descriptors.py#check_thread + if is_insert: self.user_last_seen_monthly_active.invalidate((user_id,)) self.get_monthly_active_count.invalidate(()) def upsert_monthly_active_user_txn(self, txn, user_id): """Updates or inserts monthly active user member + Note that, after calling this method, it will generally be necessary + to invalidate the caches on user_last_seen_monthly_active and + get_monthly_active_count. We can't do that here, because we are running + in a database thread rather than the main thread, and we can't call + txn.call_after because txn may not be a LoggingTransaction. Args: txn (cursor): -- cgit 1.4.1 From fcbd488e9a60a70a81cd7a94a6adebe6e06c525a Mon Sep 17 00:00:00 2001 From: Neil Johnson Date: Thu, 25 Oct 2018 16:13:43 +0100 Subject: add new line --- synapse/storage/monthly_active_users.py | 1 + 1 file changed, 1 insertion(+) (limited to 'synapse/storage') diff --git a/synapse/storage/monthly_active_users.py b/synapse/storage/monthly_active_users.py index 01963c879f..cf4104dc2e 100644 --- a/synapse/storage/monthly_active_users.py +++ b/synapse/storage/monthly_active_users.py @@ -191,6 +191,7 @@ class MonthlyActiveUsersStore(SQLBaseStore): def upsert_monthly_active_user_txn(self, txn, user_id): """Updates or inserts monthly active user member + Note that, after calling this method, it will generally be necessary to invalidate the caches on user_last_seen_monthly_active and get_monthly_active_count. We can't do that here, because we are running -- cgit 1.4.1 From cb53ce9d6429252d5ee012f5a476cc834251c27d Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 25 Oct 2018 17:49:55 +0100 Subject: Refactor state group lookup to reduce DB hits (#4011) Currently when fetching state groups from the data store we make two hits two the database: once for members and once for non-members (unless request is filtered to one or the other). This adds needless load to the datbase, so this PR refactors the lookup to make only a single database hit. --- changelog.d/4011.misc | 1 + synapse/handlers/initial_sync.py | 4 +- synapse/handlers/message.py | 20 +- synapse/handlers/pagination.py | 15 +- synapse/handlers/room.py | 22 +- synapse/handlers/sync.py | 97 ++--- synapse/rest/client/v1/room.py | 3 +- synapse/storage/events.py | 2 +- synapse/storage/state.py | 845 ++++++++++++++++++++++++--------------- synapse/visibility.py | 15 +- tests/storage/test_state.py | 175 +++++--- 11 files changed, 714 insertions(+), 485 deletions(-) create mode 100644 changelog.d/4011.misc (limited to 'synapse/storage') diff --git a/changelog.d/4011.misc b/changelog.d/4011.misc new file mode 100644 index 0000000000..ad7768c4cd --- /dev/null +++ b/changelog.d/4011.misc @@ -0,0 +1 @@ +Reduce database load when fetching state groups diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index e009395207..563bb3cea3 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -156,7 +156,7 @@ class InitialSyncHandler(BaseHandler): room_end_token = "s%d" % (event.stream_ordering,) deferred_room_state = run_in_background( self.store.get_state_for_events, - [event.event_id], None, + [event.event_id], ) deferred_room_state.addCallback( lambda states: states[event.event_id] @@ -301,7 +301,7 @@ class InitialSyncHandler(BaseHandler): def _room_initial_sync_parted(self, user_id, room_id, pagin_config, membership, member_event_id, is_peeking): room_state = yield self.store.get_state_for_events( - [member_event_id], None + [member_event_id], ) room_state = room_state[member_event_id] diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 6c4fcfb10a..969e588e73 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -35,6 +35,7 @@ from synapse.crypto.event_signing import add_hashes_and_signatures from synapse.events.utils import serialize_event from synapse.events.validator import EventValidator from synapse.replication.http.send_event import ReplicationSendEventRestServlet +from synapse.storage.state import StateFilter from synapse.types import RoomAlias, UserID from synapse.util.async_helpers import Linearizer from synapse.util.frozenutils import frozendict_json_encoder @@ -80,7 +81,7 @@ class MessageHandler(object): elif membership == Membership.LEAVE: key = (event_type, state_key) room_state = yield self.store.get_state_for_events( - [membership_event_id], [key] + [membership_event_id], StateFilter.from_types([key]) ) data = room_state[membership_event_id].get(key) @@ -88,7 +89,7 @@ class MessageHandler(object): @defer.inlineCallbacks def get_state_events( - self, user_id, room_id, types=None, filtered_types=None, + self, user_id, room_id, state_filter=StateFilter.all(), at_token=None, is_guest=False, ): """Retrieve all state events for a given room. If the user is @@ -100,13 +101,8 @@ class MessageHandler(object): Args: user_id(str): The user requesting state events. room_id(str): The room ID to get all state events from. - types(list[(str, str|None)]|None): List of (type, state_key) tuples - which are used to filter the state fetched. If `state_key` is None, - all events are returned of the given type. - May be None, which matches any key. - filtered_types(list[str]|None): Only apply filtering via `types` to this - list of event types. Other types of events are returned unfiltered. - If None, `types` filtering is applied to all events. + state_filter (StateFilter): The state filter used to fetch state + from the database. at_token(StreamToken|None): the stream token of the at which we are requesting the stats. If the user is not allowed to view the state as of that stream token, we raise a 403 SynapseError. If None, returns the current @@ -139,7 +135,7 @@ class MessageHandler(object): event = last_events[0] if visible_events: room_state = yield self.store.get_state_for_events( - [event.event_id], types, filtered_types=filtered_types, + [event.event_id], state_filter=state_filter, ) room_state = room_state[event.event_id] else: @@ -158,12 +154,12 @@ class MessageHandler(object): if membership == Membership.JOIN: state_ids = yield self.store.get_filtered_current_state_ids( - room_id, types, filtered_types=filtered_types, + room_id, state_filter=state_filter, ) room_state = yield self.store.get_events(state_ids.values()) elif membership == Membership.LEAVE: room_state = yield self.store.get_state_for_events( - [membership_event_id], types, filtered_types=filtered_types, + [membership_event_id], state_filter=state_filter, ) room_state = room_state[membership_event_id] diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index a155b6e938..43f81bd607 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -21,6 +21,7 @@ from twisted.python.failure import Failure from synapse.api.constants import EventTypes, Membership from synapse.api.errors import SynapseError from synapse.events.utils import serialize_event +from synapse.storage.state import StateFilter from synapse.types import RoomStreamToken from synapse.util.async_helpers import ReadWriteLock from synapse.util.logcontext import run_in_background @@ -255,16 +256,14 @@ class PaginationHandler(object): if event_filter and event_filter.lazy_load_members(): # TODO: remove redundant members - types = [ - (EventTypes.Member, state_key) - for state_key in set( - event.sender # FIXME: we also care about invite targets etc. - for event in events - ) - ] + # FIXME: we also care about invite targets etc. + state_filter = StateFilter.from_types( + (EventTypes.Member, event.sender) + for event in events + ) state_ids = yield self.store.get_state_ids_for_event( - events[0].event_id, types=types, + events[0].event_id, state_filter=state_filter, ) if state_ids: diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index ab1571b27b..3ba92bdb4c 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -33,6 +33,7 @@ from synapse.api.constants import ( RoomCreationPreset, ) from synapse.api.errors import AuthError, Codes, StoreError, SynapseError +from synapse.storage.state import StateFilter from synapse.types import RoomAlias, RoomID, RoomStreamToken, StreamToken, UserID from synapse.util import stringutils from synapse.visibility import filter_events_for_client @@ -489,23 +490,24 @@ class RoomContextHandler(object): else: last_event_id = event_id - types = None - filtered_types = None if event_filter and event_filter.lazy_load_members(): - members = set(ev.sender for ev in itertools.chain( - results["events_before"], - (results["event"],), - results["events_after"], - )) - filtered_types = [EventTypes.Member] - types = [(EventTypes.Member, member) for member in members] + state_filter = StateFilter.from_lazy_load_member_list( + ev.sender + for ev in itertools.chain( + results["events_before"], + (results["event"],), + results["events_after"], + ) + ) + else: + state_filter = StateFilter.all() # XXX: why do we return the state as of the last event rather than the # first? Shouldn't we be consistent with /sync? # https://github.com/matrix-org/matrix-doc/issues/687 state = yield self.store.get_state_for_events( - [last_event_id], types, filtered_types=filtered_types, + [last_event_id], state_filter=state_filter, ) results["state"] = list(state[last_event_id].values()) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 351892a94f..09739f2862 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -27,6 +27,7 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, Membership from synapse.push.clientformat import format_push_rules_for_user from synapse.storage.roommember import MemberSummary +from synapse.storage.state import StateFilter from synapse.types import RoomStreamToken from synapse.util.async_helpers import concurrently_execute from synapse.util.caches.expiringcache import ExpiringCache @@ -469,25 +470,20 @@ class SyncHandler(object): )) @defer.inlineCallbacks - def get_state_after_event(self, event, types=None, filtered_types=None): + def get_state_after_event(self, event, state_filter=StateFilter.all()): """ Get the room state after the given event Args: event(synapse.events.EventBase): event of interest - types(list[(str, str|None)]|None): List of (type, state_key) tuples - which are used to filter the state fetched. If `state_key` is None, - all events are returned of the given type. - May be None, which matches any key. - filtered_types(list[str]|None): Only apply filtering via `types` to this - list of event types. Other types of events are returned unfiltered. - If None, `types` filtering is applied to all events. + state_filter (StateFilter): The state filter used to fetch state + from the database. Returns: A Deferred map from ((type, state_key)->Event) """ state_ids = yield self.store.get_state_ids_for_event( - event.event_id, types, filtered_types=filtered_types, + event.event_id, state_filter=state_filter, ) if event.is_state(): state_ids = state_ids.copy() @@ -495,18 +491,14 @@ class SyncHandler(object): defer.returnValue(state_ids) @defer.inlineCallbacks - def get_state_at(self, room_id, stream_position, types=None, filtered_types=None): + def get_state_at(self, room_id, stream_position, state_filter=StateFilter.all()): """ Get the room state at a particular stream position Args: room_id(str): room for which to get state stream_position(StreamToken): point at which to get state - types(list[(str, str|None)]|None): List of (type, state_key) tuples - which are used to filter the state fetched. If `state_key` is None, - all events are returned of the given type. - filtered_types(list[str]|None): Only apply filtering via `types` to this - list of event types. Other types of events are returned unfiltered. - If None, `types` filtering is applied to all events. + state_filter (StateFilter): The state filter used to fetch state + from the database. Returns: A Deferred map from ((type, state_key)->Event) @@ -522,7 +514,7 @@ class SyncHandler(object): if last_events: last_event = last_events[-1] state = yield self.get_state_after_event( - last_event, types, filtered_types=filtered_types, + last_event, state_filter=state_filter, ) else: @@ -563,10 +555,11 @@ class SyncHandler(object): last_event = last_events[-1] state_ids = yield self.store.get_state_ids_for_event( - last_event.event_id, [ + last_event.event_id, + state_filter=StateFilter.from_types([ (EventTypes.Name, ''), (EventTypes.CanonicalAlias, ''), - ] + ]), ) # this is heavily cached, thus: fast. @@ -717,8 +710,7 @@ class SyncHandler(object): with Measure(self.clock, "compute_state_delta"): - types = None - filtered_types = None + members_to_fetch = None lazy_load_members = sync_config.filter_collection.lazy_load_members() include_redundant_members = ( @@ -729,16 +721,21 @@ class SyncHandler(object): # We only request state for the members needed to display the # timeline: - types = [ - (EventTypes.Member, state_key) - for state_key in set( - event.sender # FIXME: we also care about invite targets etc. - for event in batch.events - ) - ] + members_to_fetch = set( + event.sender # FIXME: we also care about invite targets etc. + for event in batch.events + ) + + if full_state: + # always make sure we LL ourselves so we know we're in the room + # (if we are) to fix https://github.com/vector-im/riot-web/issues/7209 + # We only need apply this on full state syncs given we disabled + # LL for incr syncs in #3840. + members_to_fetch.add(sync_config.user.to_string()) - # only apply the filtering to room members - filtered_types = [EventTypes.Member] + state_filter = StateFilter.from_lazy_load_member_list(members_to_fetch) + else: + state_filter = StateFilter.all() timeline_state = { (event.type, event.state_key): event.event_id @@ -746,28 +743,19 @@ class SyncHandler(object): } if full_state: - if lazy_load_members: - # always make sure we LL ourselves so we know we're in the room - # (if we are) to fix https://github.com/vector-im/riot-web/issues/7209 - # We only need apply this on full state syncs given we disabled - # LL for incr syncs in #3840. - types.append((EventTypes.Member, sync_config.user.to_string())) - if batch: current_state_ids = yield self.store.get_state_ids_for_event( - batch.events[-1].event_id, types=types, - filtered_types=filtered_types, + batch.events[-1].event_id, state_filter=state_filter, ) state_ids = yield self.store.get_state_ids_for_event( - batch.events[0].event_id, types=types, - filtered_types=filtered_types, + batch.events[0].event_id, state_filter=state_filter, ) else: current_state_ids = yield self.get_state_at( - room_id, stream_position=now_token, types=types, - filtered_types=filtered_types, + room_id, stream_position=now_token, + state_filter=state_filter, ) state_ids = current_state_ids @@ -781,8 +769,7 @@ class SyncHandler(object): ) elif batch.limited: state_at_timeline_start = yield self.store.get_state_ids_for_event( - batch.events[0].event_id, types=types, - filtered_types=filtered_types, + batch.events[0].event_id, state_filter=state_filter, ) # for now, we disable LL for gappy syncs - see @@ -797,17 +784,15 @@ class SyncHandler(object): # members to just be ones which were timeline senders, which then ensures # all of the rest get included in the state block (if we need to know # about them). - types = None - filtered_types = None + state_filter = StateFilter.all() state_at_previous_sync = yield self.get_state_at( - room_id, stream_position=since_token, types=types, - filtered_types=filtered_types, + room_id, stream_position=since_token, + state_filter=state_filter, ) current_state_ids = yield self.store.get_state_ids_for_event( - batch.events[-1].event_id, types=types, - filtered_types=filtered_types, + batch.events[-1].event_id, state_filter=state_filter, ) state_ids = _calculate_state( @@ -821,7 +806,7 @@ class SyncHandler(object): else: state_ids = {} if lazy_load_members: - if types and batch.events: + if members_to_fetch and batch.events: # We're returning an incremental sync, with no # "gap" since the previous sync, so normally there would be # no state to return. @@ -831,8 +816,12 @@ class SyncHandler(object): # timeline here, and then dedupe any redundant ones below. state_ids = yield self.store.get_state_ids_for_event( - batch.events[0].event_id, types=types, - filtered_types=None, # we only want members! + batch.events[0].event_id, + # we only want members! + state_filter=StateFilter.from_types( + (EventTypes.Member, member) + for member in members_to_fetch + ), ) if lazy_load_members and not include_redundant_members: diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 663934efd0..fcfe7857f6 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -33,6 +33,7 @@ from synapse.http.servlet import ( parse_json_object_from_request, parse_string, ) +from synapse.storage.state import StateFilter from synapse.streams.config import PaginationConfig from synapse.types import RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID, UserID @@ -409,7 +410,7 @@ class RoomMemberListRestServlet(ClientV1RestServlet): room_id=room_id, user_id=requester.user.to_string(), at_token=at_token, - types=[(EventTypes.Member, None)], + state_filter=StateFilter.from_types([(EventTypes.Member, None)]), ) chunk = [] diff --git a/synapse/storage/events.py b/synapse/storage/events.py index c780f55277..8881b009df 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -2089,7 +2089,7 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore for sg in remaining_state_groups: logger.info("[purge] de-delta-ing remaining state group %s", sg) curr_state = self._get_state_groups_from_groups_txn( - txn, [sg], types=None + txn, [sg], ) curr_state = curr_state[sg] diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 3f4cbd61c4..ef65929bb2 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -19,6 +19,8 @@ from collections import namedtuple from six import iteritems, itervalues from six.moves import range +import attr + from twisted.internet import defer from synapse.api.constants import EventTypes @@ -48,6 +50,318 @@ class _GetStateGroupDelta(namedtuple("_GetStateGroupDelta", ("prev_group", "delt return len(self.delta_ids) if self.delta_ids else 0 +@attr.s(slots=True) +class StateFilter(object): + """A filter used when querying for state. + + Attributes: + types (dict[str, set[str]|None]): Map from type to set of state keys (or + None). This specifies which state_keys for the given type to fetch + from the DB. If None then all events with that type are fetched. If + the set is empty then no events with that type are fetched. + include_others (bool): Whether to fetch events with types that do not + appear in `types`. + """ + + types = attr.ib() + include_others = attr.ib(default=False) + + def __attrs_post_init__(self): + # If `include_others` is set we canonicalise the filter by removing + # wildcards from the types dictionary + if self.include_others: + self.types = { + k: v for k, v in iteritems(self.types) + if v is not None + } + + @staticmethod + def all(): + """Creates a filter that fetches everything. + + Returns: + StateFilter + """ + return StateFilter(types={}, include_others=True) + + @staticmethod + def none(): + """Creates a filter that fetches nothing. + + Returns: + StateFilter + """ + return StateFilter(types={}, include_others=False) + + @staticmethod + def from_types(types): + """Creates a filter that only fetches the given types + + Args: + types (Iterable[tuple[str, str|None]]): A list of type and state + keys to fetch. A state_key of None fetches everything for + that type + + Returns: + StateFilter + """ + type_dict = {} + for typ, s in types: + if typ in type_dict: + if type_dict[typ] is None: + continue + + if s is None: + type_dict[typ] = None + continue + + type_dict.setdefault(typ, set()).add(s) + + return StateFilter(types=type_dict) + + @staticmethod + def from_lazy_load_member_list(members): + """Creates a filter that returns all non-member events, plus the member + events for the given users + + Args: + members (iterable[str]): Set of user IDs + + Returns: + StateFilter + """ + return StateFilter( + types={EventTypes.Member: set(members)}, + include_others=True, + ) + + def return_expanded(self): + """Creates a new StateFilter where type wild cards have been removed + (except for memberships). The returned filter is a superset of the + current one, i.e. anything that passes the current filter will pass + the returned filter. + + This helps the caching as the DictionaryCache knows if it has *all* the + state, but does not know if it has all of the keys of a particular type, + which makes wildcard lookups expensive unless we have a complete cache. + Hence, if we are doing a wildcard lookup, populate the cache fully so + that we can do an efficient lookup next time. + + Note that since we have two caches, one for membership events and one for + other events, we can be a bit more clever than simply returning + `StateFilter.all()` if `has_wildcards()` is True. + + We return a StateFilter where: + 1. the list of membership events to return is the same + 2. if there is a wildcard that matches non-member events we + return all non-member events + + Returns: + StateFilter + """ + + if self.is_full(): + # If we're going to return everything then there's nothing to do + return self + + if not self.has_wildcards(): + # If there are no wild cards, there's nothing to do + return self + + if EventTypes.Member in self.types: + get_all_members = self.types[EventTypes.Member] is None + else: + get_all_members = self.include_others + + has_non_member_wildcard = self.include_others or any( + state_keys is None + for t, state_keys in iteritems(self.types) + if t != EventTypes.Member + ) + + if not has_non_member_wildcard: + # If there are no non-member wild cards we can just return ourselves + return self + + if get_all_members: + # We want to return everything. + return StateFilter.all() + else: + # We want to return all non-members, but only particular + # memberships + return StateFilter( + types={EventTypes.Member: self.types[EventTypes.Member]}, + include_others=True, + ) + + def make_sql_filter_clause(self): + """Converts the filter to an SQL clause. + + For example: + + f = StateFilter.from_types([("m.room.create", "")]) + clause, args = f.make_sql_filter_clause() + clause == "(type = ? AND state_key = ?)" + args == ['m.room.create', ''] + + + Returns: + tuple[str, list]: The SQL string (may be empty) and arguments. An + empty SQL string is returned when the filter matches everything + (i.e. is "full"). + """ + + where_clause = "" + where_args = [] + + if self.is_full(): + return where_clause, where_args + + if not self.include_others and not self.types: + # i.e. this is an empty filter, so we need to return a clause that + # will match nothing + return "1 = 2", [] + + # First we build up a lost of clauses for each type/state_key combo + clauses = [] + for etype, state_keys in iteritems(self.types): + if state_keys is None: + clauses.append("(type = ?)") + where_args.append(etype) + continue + + for state_key in state_keys: + clauses.append("(type = ? AND state_key = ?)") + where_args.extend((etype, state_key)) + + # This will match anything that appears in `self.types` + where_clause = " OR ".join(clauses) + + # If we want to include stuff that's not in the types dict then we add + # a `OR type NOT IN (...)` clause to the end. + if self.include_others: + if where_clause: + where_clause += " OR " + + where_clause += "type NOT IN (%s)" % ( + ",".join(["?"] * len(self.types)), + ) + where_args.extend(self.types) + + return where_clause, where_args + + def max_entries_returned(self): + """Returns the maximum number of entries this filter will return if + known, otherwise returns None. + + For example a simple state filter asking for `("m.room.create", "")` + will return 1, whereas the default state filter will return None. + + This is used to bail out early if the right number of entries have been + fetched. + """ + if self.has_wildcards(): + return None + + return len(self.concrete_types()) + + def filter_state(self, state_dict): + """Returns the state filtered with by this StateFilter + + Args: + state (dict[tuple[str, str], Any]): The state map to filter + + Returns: + dict[tuple[str, str], Any]: The filtered state map + """ + if self.is_full(): + return dict(state_dict) + + filtered_state = {} + for k, v in iteritems(state_dict): + typ, state_key = k + if typ in self.types: + state_keys = self.types[typ] + if state_keys is None or state_key in state_keys: + filtered_state[k] = v + elif self.include_others: + filtered_state[k] = v + + return filtered_state + + def is_full(self): + """Whether this filter fetches everything or not + + Returns: + bool + """ + return self.include_others and not self.types + + def has_wildcards(self): + """Whether the filter includes wildcards or is attempting to fetch + specific state. + + Returns: + bool + """ + + return ( + self.include_others + or any( + state_keys is None + for state_keys in itervalues(self.types) + ) + ) + + def concrete_types(self): + """Returns a list of concrete type/state_keys (i.e. not None) that + will be fetched. This will be a complete list if `has_wildcards` + returns False, but otherwise will be a subset (or even empty). + + Returns: + list[tuple[str,str]] + """ + return [ + (t, s) + for t, state_keys in iteritems(self.types) + if state_keys is not None + for s in state_keys + ] + + def get_member_split(self): + """Return the filter split into two: one which assumes it's exclusively + matching against member state, and one which assumes it's matching + against non member state. + + This is useful due to the returned filters giving correct results for + `is_full()`, `has_wildcards()`, etc, when operating against maps that + either exclusively contain member events or only contain non-member + events. (Which is the case when dealing with the member vs non-member + state caches). + + Returns: + tuple[StateFilter, StateFilter]: The member and non member filters + """ + + if EventTypes.Member in self.types: + state_keys = self.types[EventTypes.Member] + if state_keys is None: + member_filter = StateFilter.all() + else: + member_filter = StateFilter({EventTypes.Member: state_keys}) + elif self.include_others: + member_filter = StateFilter.all() + else: + member_filter = StateFilter.none() + + non_member_filter = StateFilter( + types={k: v for k, v in iteritems(self.types) if k != EventTypes.Member}, + include_others=self.include_others, + ) + + return member_filter, non_member_filter + + # this inherits from EventsWorkerStore because it calls self.get_events class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): """The parts of StateGroupStore that can be called from workers. @@ -152,61 +466,41 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): ) # FIXME: how should this be cached? - def get_filtered_current_state_ids(self, room_id, types, filtered_types=None): + def get_filtered_current_state_ids(self, room_id, state_filter=StateFilter.all()): """Get the current state event of a given type for a room based on the current_state_events table. This may not be as up-to-date as the result of doing a fresh state resolution as per state_handler.get_current_state + Args: room_id (str) - types (list[(Str, (Str|None))]): List of (type, state_key) tuples - which are used to filter the state fetched. `state_key` may be - None, which matches any `state_key` - filtered_types (list[Str]|None): List of types to apply the above filter to. + state_filter (StateFilter): The state filter used to fetch state + from the database. + Returns: - deferred: dict of (type, state_key) -> event + Deferred[dict[tuple[str, str], str]]: Map from type/state_key to + event ID. """ - include_other_types = False if filtered_types is None else True - def _get_filtered_current_state_ids_txn(txn): results = {} - sql = """SELECT type, state_key, event_id FROM current_state_events - WHERE room_id = ? %s""" - # Turns out that postgres doesn't like doing a list of OR's and - # is about 1000x slower, so we just issue a query for each specific - # type seperately. - if types: - clause_to_args = [ - ( - "AND type = ? AND state_key = ?", - (etype, state_key) - ) if state_key is not None else ( - "AND type = ?", - (etype,) - ) - for etype, state_key in types - ] - - if include_other_types: - unique_types = set(filtered_types) - clause_to_args.append( - ( - "AND type <> ? " * len(unique_types), - list(unique_types) - ) - ) - else: - # If types is None we fetch all the state, and so just use an - # empty where clause with no extra args. - clause_to_args = [("", [])] - for where_clause, where_args in clause_to_args: - args = [room_id] - args.extend(where_args) - txn.execute(sql % (where_clause,), args) - for row in txn: - typ, state_key, event_id = row - key = (intern_string(typ), intern_string(state_key)) - results[key] = event_id + sql = """ + SELECT type, state_key, event_id FROM current_state_events + WHERE room_id = ? + """ + + where_clause, where_args = state_filter.make_sql_filter_clause() + + if where_clause: + sql += " AND (%s)" % (where_clause,) + + args = [room_id] + args.extend(where_args) + txn.execute(sql, args) + for row in txn: + typ, state_key, event_id = row + key = (intern_string(typ), intern_string(state_key)) + results[key] = event_id + return results return self.runInteraction( @@ -322,20 +616,14 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): }) @defer.inlineCallbacks - def _get_state_groups_from_groups(self, groups, types, members=None): + def _get_state_groups_from_groups(self, groups, state_filter): """Returns the state groups for a given set of groups, filtering on types of state events. Args: groups(list[int]): list of state group IDs to query - types (Iterable[str, str|None]|None): list of 2-tuples of the form - (`type`, `state_key`), where a `state_key` of `None` matches all - state_keys for the `type`. If None, all types are returned. - members (bool|None): If not None, then, in addition to any filtering - implied by types, the results are also filtered to only include - member events (if True), or to exclude member events (if False) - - Returns: + state_filter (StateFilter): The state filter used to fetch state + from the database. Returns: Deferred[dict[int, dict[tuple[str, str], str]]]: dict of state_group_id -> (dict of (type, state_key) -> event id) @@ -346,19 +634,23 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): for chunk in chunks: res = yield self.runInteraction( "_get_state_groups_from_groups", - self._get_state_groups_from_groups_txn, chunk, types, members, + self._get_state_groups_from_groups_txn, chunk, state_filter, ) results.update(res) defer.returnValue(results) def _get_state_groups_from_groups_txn( - self, txn, groups, types=None, members=None, + self, txn, groups, state_filter=StateFilter.all(), ): results = {group: {} for group in groups} - if types is not None: - types = list(set(types)) # deduplicate types list + where_clause, where_args = state_filter.make_sql_filter_clause() + + # Unless the filter clause is empty, we're going to append it after an + # existing where clause + if where_clause: + where_clause = " AND (%s)" % (where_clause,) if isinstance(self.database_engine, PostgresEngine): # Temporarily disable sequential scans in this transaction. This is @@ -374,79 +666,33 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): # group for the given type, state_key. # This may return multiple rows per (type, state_key), but last_value # should be the same. - sql = (""" + sql = """ WITH RECURSIVE state(state_group) AS ( VALUES(?::bigint) UNION ALL SELECT prev_state_group FROM state_group_edges e, state s WHERE s.state_group = e.state_group ) - SELECT type, state_key, last_value(event_id) OVER ( + SELECT DISTINCT type, state_key, last_value(event_id) OVER ( PARTITION BY type, state_key ORDER BY state_group ASC ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING ) AS event_id FROM state_groups_state WHERE state_group IN ( SELECT state_group FROM state ) - %s - """) + """ - if members is True: - sql += " AND type = '%s'" % (EventTypes.Member,) - elif members is False: - sql += " AND type <> '%s'" % (EventTypes.Member,) - - # Turns out that postgres doesn't like doing a list of OR's and - # is about 1000x slower, so we just issue a query for each specific - # type seperately. - if types is not None: - clause_to_args = [ - ( - "AND type = ? AND state_key = ?", - (etype, state_key) - ) if state_key is not None else ( - "AND type = ?", - (etype,) - ) - for etype, state_key in types - ] - else: - # If types is None we fetch all the state, and so just use an - # empty where clause with no extra args. - clause_to_args = [("", [])] - - for where_clause, where_args in clause_to_args: - for group in groups: - args = [group] - args.extend(where_args) + for group in groups: + args = [group] + args.extend(where_args) - txn.execute(sql % (where_clause,), args) - for row in txn: - typ, state_key, event_id = row - key = (typ, state_key) - results[group][key] = event_id + txn.execute(sql + where_clause, args) + for row in txn: + typ, state_key, event_id = row + key = (typ, state_key) + results[group][key] = event_id else: - where_args = [] - where_clauses = [] - wildcard_types = False - if types is not None: - for typ in types: - if typ[1] is None: - where_clauses.append("(type = ?)") - where_args.append(typ[0]) - wildcard_types = True - else: - where_clauses.append("(type = ? AND state_key = ?)") - where_args.extend([typ[0], typ[1]]) - - where_clause = "AND (%s)" % (" OR ".join(where_clauses)) - else: - where_clause = "" - - if members is True: - where_clause += " AND type = '%s'" % EventTypes.Member - elif members is False: - where_clause += " AND type <> '%s'" % EventTypes.Member + max_entries_returned = state_filter.max_entries_returned() # We don't use WITH RECURSIVE on sqlite3 as there are distributions # that ship with an sqlite3 version that doesn't support it (e.g. wheezy) @@ -460,12 +706,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): # without the right indices (which we can't add until # after we finish deduping state, which requires this func) args = [next_group] - if types: - args.extend(where_args) + args.extend(where_args) txn.execute( "SELECT type, state_key, event_id FROM state_groups_state" - " WHERE state_group = ? %s" % (where_clause,), + " WHERE state_group = ? " + where_clause, args ) results[group].update( @@ -481,9 +726,8 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): # wildcards (i.e. Nones) in which case we have to do an exhaustive # search if ( - types is not None and - not wildcard_types and - len(results[group]) == len(types) + max_entries_returned is not None and + len(results[group]) == max_entries_returned ): break @@ -498,20 +742,14 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): return results @defer.inlineCallbacks - def get_state_for_events(self, event_ids, types, filtered_types=None): + def get_state_for_events(self, event_ids, state_filter=StateFilter.all()): """Given a list of event_ids and type tuples, return a list of state - dicts for each event. The state dicts will only have the type/state_keys - that are in the `types` list. + dicts for each event. Args: event_ids (list[string]) - types (list[(str, str|None)]|None): List of (type, state_key) tuples - which are used to filter the state fetched. If `state_key` is None, - all events are returned of the given type. - May be None, which matches any key. - filtered_types(list[str]|None): Only apply filtering via `types` to this - list of event types. Other types of events are returned unfiltered. - If None, `types` filtering is applied to all events. + state_filter (StateFilter): The state filter used to fetch state + from the database. Returns: deferred: A dict of (event_id) -> (type, state_key) -> [state_events] @@ -521,7 +759,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): ) groups = set(itervalues(event_to_groups)) - group_to_state = yield self._get_state_for_groups(groups, types, filtered_types) + group_to_state = yield self._get_state_for_groups(groups, state_filter) state_event_map = yield self.get_events( [ev_id for sd in itervalues(group_to_state) for ev_id in itervalues(sd)], @@ -540,20 +778,15 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): defer.returnValue({event: event_to_state[event] for event in event_ids}) @defer.inlineCallbacks - def get_state_ids_for_events(self, event_ids, types=None, filtered_types=None): + def get_state_ids_for_events(self, event_ids, state_filter=StateFilter.all()): """ Get the state dicts corresponding to a list of events, containing the event_ids of the state events (as opposed to the events themselves) Args: event_ids(list(str)): events whose state should be returned - types(list[(str, str|None)]|None): List of (type, state_key) tuples - which are used to filter the state fetched. If `state_key` is None, - all events are returned of the given type. - May be None, which matches any key. - filtered_types(list[str]|None): Only apply filtering via `types` to this - list of event types. Other types of events are returned unfiltered. - If None, `types` filtering is applied to all events. + state_filter (StateFilter): The state filter used to fetch state + from the database. Returns: A deferred dict from event_id -> (type, state_key) -> event_id @@ -563,7 +796,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): ) groups = set(itervalues(event_to_groups)) - group_to_state = yield self._get_state_for_groups(groups, types, filtered_types) + group_to_state = yield self._get_state_for_groups(groups, state_filter) event_to_state = { event_id: group_to_state[group] @@ -573,45 +806,35 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): defer.returnValue({event: event_to_state[event] for event in event_ids}) @defer.inlineCallbacks - def get_state_for_event(self, event_id, types=None, filtered_types=None): + def get_state_for_event(self, event_id, state_filter=StateFilter.all()): """ Get the state dict corresponding to a particular event Args: event_id(str): event whose state should be returned - types(list[(str, str|None)]|None): List of (type, state_key) tuples - which are used to filter the state fetched. If `state_key` is None, - all events are returned of the given type. - May be None, which matches any key. - filtered_types(list[str]|None): Only apply filtering via `types` to this - list of event types. Other types of events are returned unfiltered. - If None, `types` filtering is applied to all events. + state_filter (StateFilter): The state filter used to fetch state + from the database. Returns: A deferred dict from (type, state_key) -> state_event """ - state_map = yield self.get_state_for_events([event_id], types, filtered_types) + state_map = yield self.get_state_for_events([event_id], state_filter) defer.returnValue(state_map[event_id]) @defer.inlineCallbacks - def get_state_ids_for_event(self, event_id, types=None, filtered_types=None): + def get_state_ids_for_event(self, event_id, state_filter=StateFilter.all()): """ Get the state dict corresponding to a particular event Args: event_id(str): event whose state should be returned - types(list[(str, str|None)]|None): List of (type, state_key) tuples - which are used to filter the state fetched. If `state_key` is None, - all events are returned of the given type. - May be None, which matches any key. - filtered_types(list[str]|None): Only apply filtering via `types` to this - list of event types. Other types of events are returned unfiltered. - If None, `types` filtering is applied to all events. + state_filter (StateFilter): The state filter used to fetch state + from the database. Returns: A deferred dict from (type, state_key) -> state_event """ - state_map = yield self.get_state_ids_for_events([event_id], types, filtered_types) + state_map = yield self.get_state_ids_for_events([event_id], state_filter) defer.returnValue(state_map[event_id]) @cached(max_entries=50000) @@ -642,18 +865,14 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): defer.returnValue({row["event_id"]: row["state_group"] for row in rows}) - def _get_some_state_from_cache(self, cache, group, types, filtered_types=None): + def _get_state_for_group_using_cache(self, cache, group, state_filter): """Checks if group is in cache. See `_get_state_for_groups` Args: cache(DictionaryCache): the state group cache to use group(int): The state group to lookup - types(list[str, str|None]): List of 2-tuples of the form - (`type`, `state_key`), where a `state_key` of `None` matches all - state_keys for the `type`. - filtered_types(list[str]|None): Only apply filtering via `types` to this - list of event types. Other types of events are returned unfiltered. - If None, `types` filtering is applied to all events. + state_filter (StateFilter): The state filter used to fetch state + from the database. Returns 2-tuple (`state_dict`, `got_all`). `got_all` is a bool indicating if we successfully retrieved all @@ -662,124 +881,102 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): """ is_all, known_absent, state_dict_ids = cache.get(group) - type_to_key = {} + if is_all or state_filter.is_full(): + # Either we have everything or want everything, either way + # `is_all` tells us whether we've gotten everything. + return state_filter.filter_state(state_dict_ids), is_all # tracks whether any of our requested types are missing from the cache missing_types = False - for typ, state_key in types: - key = (typ, state_key) - - if ( - state_key is None or - (filtered_types is not None and typ not in filtered_types) - ): - type_to_key[typ] = None - # we mark the type as missing from the cache because - # when the cache was populated it might have been done with a - # restricted set of state_keys, so the wildcard will not work - # and the cache may be incomplete. - missing_types = True - else: - if type_to_key.get(typ, object()) is not None: - type_to_key.setdefault(typ, set()).add(state_key) - + if state_filter.has_wildcards(): + # We don't know if we fetched all the state keys for the types in + # the filter that are wildcards, so we have to assume that we may + # have missed some. + missing_types = True + else: + # There aren't any wild cards, so `concrete_types()` returns the + # complete list of event types we're wanting. + for key in state_filter.concrete_types(): if key not in state_dict_ids and key not in known_absent: missing_types = True + break - sentinel = object() - - def include(typ, state_key): - valid_state_keys = type_to_key.get(typ, sentinel) - if valid_state_keys is sentinel: - return filtered_types is not None and typ not in filtered_types - if valid_state_keys is None: - return True - if state_key in valid_state_keys: - return True - return False - - got_all = is_all - if not got_all: - # the cache is incomplete. We may still have got all the results we need, if - # we don't have any wildcards in the match list. - if not missing_types and filtered_types is None: - got_all = True - - return { - k: v for k, v in iteritems(state_dict_ids) - if include(k[0], k[1]) - }, got_all - - def _get_all_state_from_cache(self, cache, group): - """Checks if group is in cache. See `_get_state_for_groups` - - Returns 2-tuple (`state_dict`, `got_all`). `got_all` is a bool - indicating if we successfully retrieved all requests state from the - cache, if False we need to query the DB for the missing state. - - Args: - cache(DictionaryCache): the state group cache to use - group: The state group to lookup - """ - is_all, _, state_dict_ids = cache.get(group) - - return state_dict_ids, is_all + return state_filter.filter_state(state_dict_ids), not missing_types @defer.inlineCallbacks - def _get_state_for_groups(self, groups, types=None, filtered_types=None): + def _get_state_for_groups(self, groups, state_filter=StateFilter.all()): """Gets the state at each of a list of state groups, optionally filtering by type/state_key Args: groups (iterable[int]): list of state groups for which we want to get the state. - types (None|iterable[(str, None|str)]): - indicates the state type/keys required. If None, the whole - state is fetched and returned. - - Otherwise, each entry should be a `(type, state_key)` tuple to - include in the response. A `state_key` of None is a wildcard - meaning that we require all state with that type. - filtered_types(list[str]|None): Only apply filtering via `types` to this - list of event types. Other types of events are returned unfiltered. - If None, `types` filtering is applied to all events. - + state_filter (StateFilter): The state filter used to fetch state + from the database. Returns: Deferred[dict[int, dict[tuple[str, str], str]]]: dict of state_group_id -> (dict of (type, state_key) -> event id) """ - if types is not None: - non_member_types = [t for t in types if t[0] != EventTypes.Member] - if filtered_types is not None and EventTypes.Member not in filtered_types: - # we want all of the membership events - member_types = None - else: - member_types = [t for t in types if t[0] == EventTypes.Member] - - else: - non_member_types = None - member_types = None + member_filter, non_member_filter = state_filter.get_member_split() - non_member_state = yield self._get_state_for_groups_using_cache( - groups, self._state_group_cache, non_member_types, filtered_types, + # Now we look them up in the member and non-member caches + non_member_state, incomplete_groups_nm, = ( + yield self._get_state_for_groups_using_cache( + groups, self._state_group_cache, + state_filter=non_member_filter, + ) ) - # XXX: we could skip this entirely if member_types is [] - member_state = yield self._get_state_for_groups_using_cache( - # we set filtered_types=None as member_state only ever contain members. - groups, self._state_group_members_cache, member_types, None, + + member_state, incomplete_groups_m, = ( + yield self._get_state_for_groups_using_cache( + groups, self._state_group_members_cache, + state_filter=member_filter, + ) ) - state = non_member_state + state = dict(non_member_state) for group in groups: state[group].update(member_state[group]) + # Now fetch any missing groups from the database + + incomplete_groups = incomplete_groups_m | incomplete_groups_nm + + if not incomplete_groups: + defer.returnValue(state) + + cache_sequence_nm = self._state_group_cache.sequence + cache_sequence_m = self._state_group_members_cache.sequence + + # Help the cache hit ratio by expanding the filter a bit + db_state_filter = state_filter.return_expanded() + + group_to_state_dict = yield self._get_state_groups_from_groups( + list(incomplete_groups), + state_filter=db_state_filter, + ) + + # Now lets update the caches + self._insert_into_cache( + group_to_state_dict, + db_state_filter, + cache_seq_num_members=cache_sequence_m, + cache_seq_num_non_members=cache_sequence_nm, + ) + + # And finally update the result dict, by filtering out any extra + # stuff we pulled out of the database. + for group, group_state_dict in iteritems(group_to_state_dict): + # We just replace any existing entries, as we will have loaded + # everything we need from the database anyway. + state[group] = state_filter.filter_state(group_state_dict) + defer.returnValue(state) - @defer.inlineCallbacks def _get_state_for_groups_using_cache( - self, groups, cache, types=None, filtered_types=None + self, groups, cache, state_filter, ): """Gets the state at each of a list of state groups, optionally filtering by type/state_key, querying from a specific cache. @@ -790,89 +987,85 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): cache (DictionaryCache): the cache of group ids to state dicts which we will pass through - either the normal state cache or the specific members state cache. - types (None|iterable[(str, None|str)]): - indicates the state type/keys required. If None, the whole - state is fetched and returned. - - Otherwise, each entry should be a `(type, state_key)` tuple to - include in the response. A `state_key` of None is a wildcard - meaning that we require all state with that type. - filtered_types(list[str]|None): Only apply filtering via `types` to this - list of event types. Other types of events are returned unfiltered. - If None, `types` filtering is applied to all events. + state_filter (StateFilter): The state filter used to fetch state + from the database. Returns: - Deferred[dict[int, dict[tuple[str, str], str]]]: - dict of state_group_id -> (dict of (type, state_key) -> event id) + tuple[dict[int, dict[tuple[str, str], str]], set[int]]: Tuple of + dict of state_group_id -> (dict of (type, state_key) -> event id) + of entries in the cache, and the state group ids either missing + from the cache or incomplete. """ - if types: - types = frozenset(types) results = {} - missing_groups = [] - if types is not None: - for group in set(groups): - state_dict_ids, got_all = self._get_some_state_from_cache( - cache, group, types, filtered_types - ) - results[group] = state_dict_ids + incomplete_groups = set() + for group in set(groups): + state_dict_ids, got_all = self._get_state_for_group_using_cache( + cache, group, state_filter + ) + results[group] = state_dict_ids - if not got_all: - missing_groups.append(group) - else: - for group in set(groups): - state_dict_ids, got_all = self._get_all_state_from_cache( - cache, group - ) + if not got_all: + incomplete_groups.add(group) - results[group] = state_dict_ids + return results, incomplete_groups - if not got_all: - missing_groups.append(group) + def _insert_into_cache(self, group_to_state_dict, state_filter, + cache_seq_num_members, cache_seq_num_non_members): + """Inserts results from querying the database into the relevant cache. - if missing_groups: - # Okay, so we have some missing_types, let's fetch them. - cache_seq_num = cache.sequence + Args: + group_to_state_dict (dict): The new entries pulled from database. + Map from state group to state dict + state_filter (StateFilter): The state filter used to fetch state + from the database. + cache_seq_num_members (int): Sequence number of member cache since + last lookup in cache + cache_seq_num_non_members (int): Sequence number of member cache since + last lookup in cache + """ - # the DictionaryCache knows if it has *all* the state, but - # does not know if it has all of the keys of a particular type, - # which makes wildcard lookups expensive unless we have a complete - # cache. Hence, if we are doing a wildcard lookup, populate the - # cache fully so that we can do an efficient lookup next time. + # We need to work out which types we've fetched from the DB for the + # member vs non-member caches. This should be as accurate as possible, + # but can be an underestimate (e.g. when we have wild cards) - if filtered_types or (types and any(k is None for (t, k) in types)): - types_to_fetch = None - else: - types_to_fetch = types + member_filter, non_member_filter = state_filter.get_member_split() + if member_filter.is_full(): + # We fetched all member events + member_types = None + else: + # `concrete_types()` will only return a subset when there are wild + # cards in the filter, but that's fine. + member_types = member_filter.concrete_types() - group_to_state_dict = yield self._get_state_groups_from_groups( - missing_groups, types_to_fetch, cache == self._state_group_members_cache, - ) + if non_member_filter.is_full(): + # We fetched all non member events + non_member_types = None + else: + non_member_types = non_member_filter.concrete_types() + + for group, group_state_dict in iteritems(group_to_state_dict): + state_dict_members = {} + state_dict_non_members = {} - for group, group_state_dict in iteritems(group_to_state_dict): - state_dict = results[group] - - # update the result, filtering by `types`. - if types: - for k, v in iteritems(group_state_dict): - (typ, _) = k - if ( - (k in types or (typ, None) in types) or - (filtered_types and typ not in filtered_types) - ): - state_dict[k] = v + for k, v in iteritems(group_state_dict): + if k[0] == EventTypes.Member: + state_dict_members[k] = v else: - state_dict.update(group_state_dict) - - # update the cache with all the things we fetched from the - # database. - cache.update( - cache_seq_num, - key=group, - value=group_state_dict, - fetched_keys=types_to_fetch, - ) + state_dict_non_members[k] = v - defer.returnValue(results) + self._state_group_members_cache.update( + cache_seq_num_members, + key=group, + value=state_dict_members, + fetched_keys=member_types, + ) + + self._state_group_cache.update( + cache_seq_num_non_members, + key=group, + value=state_dict_non_members, + fetched_keys=non_member_types, + ) def store_state_group(self, event_id, room_id, prev_group, delta_ids, current_state_ids): @@ -1181,12 +1374,12 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore): continue prev_state = self._get_state_groups_from_groups_txn( - txn, [prev_group], types=None + txn, [prev_group], ) prev_state = prev_state[prev_group] curr_state = self._get_state_groups_from_groups_txn( - txn, [state_group], types=None + txn, [state_group], ) curr_state = curr_state[state_group] diff --git a/synapse/visibility.py b/synapse/visibility.py index 43f48196be..0281a7c919 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -23,6 +23,7 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, Membership from synapse.events.utils import prune_event +from synapse.storage.state import StateFilter from synapse.types import get_domain_from_id logger = logging.getLogger(__name__) @@ -72,7 +73,7 @@ def filter_events_for_client(store, user_id, events, is_peeking=False, ) event_id_to_state = yield store.get_state_for_events( frozenset(e.event_id for e in events), - types=types, + state_filter=StateFilter.from_types(types), ) ignore_dict_content = yield store.get_global_account_data_by_type_for_user( @@ -273,8 +274,8 @@ def filter_events_for_server(store, server_name, events): # need to check membership (as we know the server is in the room). event_to_state_ids = yield store.get_state_ids_for_events( frozenset(e.event_id for e in events), - types=( - (EventTypes.RoomHistoryVisibility, ""), + state_filter=StateFilter.from_types( + types=((EventTypes.RoomHistoryVisibility, ""),), ) ) @@ -314,9 +315,11 @@ def filter_events_for_server(store, server_name, events): # of the history vis and membership state at those events. event_to_state_ids = yield store.get_state_ids_for_events( frozenset(e.event_id for e in events), - types=( - (EventTypes.RoomHistoryVisibility, ""), - (EventTypes.Member, None), + state_filter=StateFilter.from_types( + types=( + (EventTypes.RoomHistoryVisibility, ""), + (EventTypes.Member, None), + ), ) ) diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index b9c5b39d59..086a39d834 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -18,6 +18,7 @@ import logging from twisted.internet import defer from synapse.api.constants import EventTypes, Membership +from synapse.storage.state import StateFilter from synapse.types import RoomID, UserID import tests.unittest @@ -148,7 +149,7 @@ class StateStoreTestCase(tests.unittest.TestCase): # check we get the full state as of the final event state = yield self.store.get_state_for_event( - e5.event_id, None, filtered_types=None + e5.event_id, ) self.assertIsNotNone(e4) @@ -166,33 +167,35 @@ class StateStoreTestCase(tests.unittest.TestCase): # check we can filter to the m.room.name event (with a '' state key) state = yield self.store.get_state_for_event( - e5.event_id, [(EventTypes.Name, '')], filtered_types=None + e5.event_id, StateFilter.from_types([(EventTypes.Name, '')]) ) self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state) # check we can filter to the m.room.name event (with a wildcard None state key) state = yield self.store.get_state_for_event( - e5.event_id, [(EventTypes.Name, None)], filtered_types=None + e5.event_id, StateFilter.from_types([(EventTypes.Name, None)]) ) self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state) # check we can grab the m.room.member events (with a wildcard None state key) state = yield self.store.get_state_for_event( - e5.event_id, [(EventTypes.Member, None)], filtered_types=None + e5.event_id, StateFilter.from_types([(EventTypes.Member, None)]) ) self.assertStateMapEqual( {(e3.type, e3.state_key): e3, (e5.type, e5.state_key): e5}, state ) - # check we can use filtered_types to grab a specific room member - # without filtering out the other event types + # check we can grab a specific room member without filtering out the + # other event types state = yield self.store.get_state_for_event( e5.event_id, - [(EventTypes.Member, self.u_alice.to_string())], - filtered_types=[EventTypes.Member], + state_filter=StateFilter( + types={EventTypes.Member: {self.u_alice.to_string()}}, + include_others=True, + ) ) self.assertStateMapEqual( @@ -204,10 +207,12 @@ class StateStoreTestCase(tests.unittest.TestCase): state, ) - # check that types=[], filtered_types=[EventTypes.Member] - # doesn't return all members + # check that we can grab everything except members state = yield self.store.get_state_for_event( - e5.event_id, [], filtered_types=[EventTypes.Member] + e5.event_id, state_filter=StateFilter( + types={EventTypes.Member: set()}, + include_others=True, + ), ) self.assertStateMapEqual( @@ -215,16 +220,21 @@ class StateStoreTestCase(tests.unittest.TestCase): ) ####################################################### - # _get_some_state_from_cache tests against a full cache + # _get_state_for_group_using_cache tests against a full cache ####################################################### room_id = self.room.to_string() group_ids = yield self.store.get_state_groups_ids(room_id, [e5.event_id]) group = list(group_ids.keys())[0] - # test _get_some_state_from_cache correctly filters out members with types=[] - (state_dict, is_all) = yield self.store._get_some_state_from_cache( - self.store._state_group_cache, group, [], filtered_types=[EventTypes.Member] + # test _get_state_for_group_using_cache correctly filters out members + # with types=[] + (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( + self.store._state_group_cache, group, + state_filter=StateFilter( + types={EventTypes.Member: set()}, + include_others=True, + ), ) self.assertEqual(is_all, True) @@ -236,22 +246,27 @@ class StateStoreTestCase(tests.unittest.TestCase): state_dict, ) - (state_dict, is_all) = yield self.store._get_some_state_from_cache( + (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( self.store._state_group_members_cache, group, - [], - filtered_types=[EventTypes.Member], + state_filter=StateFilter( + types={EventTypes.Member: set()}, + include_others=True, + ), ) self.assertEqual(is_all, True) self.assertDictEqual({}, state_dict) - # test _get_some_state_from_cache correctly filters in members with wildcard types - (state_dict, is_all) = yield self.store._get_some_state_from_cache( + # test _get_state_for_group_using_cache correctly filters in members + # with wildcard types + (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( self.store._state_group_cache, group, - [(EventTypes.Member, None)], - filtered_types=[EventTypes.Member], + state_filter=StateFilter( + types={EventTypes.Member: None}, + include_others=True, + ), ) self.assertEqual(is_all, True) @@ -263,11 +278,13 @@ class StateStoreTestCase(tests.unittest.TestCase): state_dict, ) - (state_dict, is_all) = yield self.store._get_some_state_from_cache( + (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( self.store._state_group_members_cache, group, - [(EventTypes.Member, None)], - filtered_types=[EventTypes.Member], + state_filter=StateFilter( + types={EventTypes.Member: None}, + include_others=True, + ), ) self.assertEqual(is_all, True) @@ -280,12 +297,15 @@ class StateStoreTestCase(tests.unittest.TestCase): state_dict, ) - # test _get_some_state_from_cache correctly filters in members with specific types - (state_dict, is_all) = yield self.store._get_some_state_from_cache( + # test _get_state_for_group_using_cache correctly filters in members + # with specific types + (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( self.store._state_group_cache, group, - [(EventTypes.Member, e5.state_key)], - filtered_types=[EventTypes.Member], + state_filter=StateFilter( + types={EventTypes.Member: {e5.state_key}}, + include_others=True, + ), ) self.assertEqual(is_all, True) @@ -297,23 +317,27 @@ class StateStoreTestCase(tests.unittest.TestCase): state_dict, ) - (state_dict, is_all) = yield self.store._get_some_state_from_cache( + (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( self.store._state_group_members_cache, group, - [(EventTypes.Member, e5.state_key)], - filtered_types=[EventTypes.Member], + state_filter=StateFilter( + types={EventTypes.Member: {e5.state_key}}, + include_others=True, + ), ) self.assertEqual(is_all, True) self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict) - # test _get_some_state_from_cache correctly filters in members with specific types - # and no filtered_types - (state_dict, is_all) = yield self.store._get_some_state_from_cache( + # test _get_state_for_group_using_cache correctly filters in members + # with specific types + (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( self.store._state_group_members_cache, group, - [(EventTypes.Member, e5.state_key)], - filtered_types=None, + state_filter=StateFilter( + types={EventTypes.Member: {e5.state_key}}, + include_others=False, + ), ) self.assertEqual(is_all, True) @@ -357,42 +381,54 @@ class StateStoreTestCase(tests.unittest.TestCase): ############################################ # test that things work with a partial cache - # test _get_some_state_from_cache correctly filters out members with types=[] + # test _get_state_for_group_using_cache correctly filters out members + # with types=[] room_id = self.room.to_string() - (state_dict, is_all) = yield self.store._get_some_state_from_cache( - self.store._state_group_cache, group, [], filtered_types=[EventTypes.Member] + (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( + self.store._state_group_cache, group, + state_filter=StateFilter( + types={EventTypes.Member: set()}, + include_others=True, + ), ) self.assertEqual(is_all, False) self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict) room_id = self.room.to_string() - (state_dict, is_all) = yield self.store._get_some_state_from_cache( + (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( self.store._state_group_members_cache, group, - [], - filtered_types=[EventTypes.Member], + state_filter=StateFilter( + types={EventTypes.Member: set()}, + include_others=True, + ), ) self.assertEqual(is_all, True) self.assertDictEqual({}, state_dict) - # test _get_some_state_from_cache correctly filters in members wildcard types - (state_dict, is_all) = yield self.store._get_some_state_from_cache( + # test _get_state_for_group_using_cache correctly filters in members + # wildcard types + (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( self.store._state_group_cache, group, - [(EventTypes.Member, None)], - filtered_types=[EventTypes.Member], + state_filter=StateFilter( + types={EventTypes.Member: None}, + include_others=True, + ), ) self.assertEqual(is_all, False) self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict) - (state_dict, is_all) = yield self.store._get_some_state_from_cache( + (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( self.store._state_group_members_cache, group, - [(EventTypes.Member, None)], - filtered_types=[EventTypes.Member], + state_filter=StateFilter( + types={EventTypes.Member: None}, + include_others=True, + ), ) self.assertEqual(is_all, True) @@ -404,44 +440,53 @@ class StateStoreTestCase(tests.unittest.TestCase): state_dict, ) - # test _get_some_state_from_cache correctly filters in members with specific types - (state_dict, is_all) = yield self.store._get_some_state_from_cache( + # test _get_state_for_group_using_cache correctly filters in members + # with specific types + (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( self.store._state_group_cache, group, - [(EventTypes.Member, e5.state_key)], - filtered_types=[EventTypes.Member], + state_filter=StateFilter( + types={EventTypes.Member: {e5.state_key}}, + include_others=True, + ), ) self.assertEqual(is_all, False) self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict) - (state_dict, is_all) = yield self.store._get_some_state_from_cache( + (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( self.store._state_group_members_cache, group, - [(EventTypes.Member, e5.state_key)], - filtered_types=[EventTypes.Member], + state_filter=StateFilter( + types={EventTypes.Member: {e5.state_key}}, + include_others=True, + ), ) self.assertEqual(is_all, True) self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict) - # test _get_some_state_from_cache correctly filters in members with specific types - # and no filtered_types - (state_dict, is_all) = yield self.store._get_some_state_from_cache( + # test _get_state_for_group_using_cache correctly filters in members + # with specific types + (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( self.store._state_group_cache, group, - [(EventTypes.Member, e5.state_key)], - filtered_types=None, + state_filter=StateFilter( + types={EventTypes.Member: {e5.state_key}}, + include_others=False, + ), ) self.assertEqual(is_all, False) self.assertDictEqual({}, state_dict) - (state_dict, is_all) = yield self.store._get_some_state_from_cache( + (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( self.store._state_group_members_cache, group, - [(EventTypes.Member, e5.state_key)], - filtered_types=None, + state_filter=StateFilter( + types={EventTypes.Member: {e5.state_key}}, + include_others=False, + ), ) self.assertEqual(is_all, True) -- cgit 1.4.1 From 4cda300058ba68f97c032923ebf429f437eddd8e Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 12 Oct 2018 11:13:40 +0100 Subject: preserve room visibility --- synapse/handlers/room.py | 8 +++++--- synapse/storage/room.py | 2 +- 2 files changed, 6 insertions(+), 4 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 3cce6f6150..2f9eb8ef4c 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -32,7 +32,7 @@ from synapse.api.constants import ( JoinRules, RoomCreationPreset, ) -from synapse.api.errors import AuthError, Codes, StoreError, SynapseError +from synapse.api.errors import AuthError, Codes, NotFoundError, StoreError, SynapseError from synapse.storage.state import StateFilter from synapse.types import RoomAlias, RoomID, RoomStreamToken, StreamToken, UserID from synapse.util import stringutils @@ -97,9 +97,11 @@ class RoomCreationHandler(BaseHandler): with (yield self._upgrade_linearizer.queue(old_room_id)): # start by allocating a new room id - is_public = False # XXX fixme + r = yield self.store.get_room(old_room_id) + if r is None: + raise NotFoundError("Unknown room id %s" % (old_room_id,)) new_room_id = yield self._generate_room_id( - creator_id=user_id, is_public=is_public, + creator_id=user_id, is_public=r["is_public"], ) # we create and auth the tombstone event before properly creating the new diff --git a/synapse/storage/room.py b/synapse/storage/room.py index 61013b8919..41c65e112a 100644 --- a/synapse/storage/room.py +++ b/synapse/storage/room.py @@ -47,7 +47,7 @@ class RoomWorkerStore(SQLBaseStore): Args: room_id (str): The ID of the room to retrieve. Returns: - A namedtuple containing the room information, or an empty list. + A dict containing the room information, or None if the room is unknown. """ return self._simple_select_one( table="rooms", -- cgit 1.4.1 From b2399f6281d7cd11e7762b683bdd5a4f0c24927e Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 29 Oct 2018 14:01:11 +0000 Subject: Make SQL a bit cleaner --- synapse/storage/state.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 59a50a5df9..45afd42b3f 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -1272,14 +1272,13 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): # Check if state groups are referenced sql = """ - SELECT state_group, count(*) FROM event_to_state_groups + SELECT DISTINCT state_group FROM event_to_state_groups LEFT JOIN events_to_purge AS ep USING (event_id) WHERE state_group IN (%s) AND ep.event_id IS NULL - GROUP BY state_group """ % (",".join("?" for _ in current_search),) txn.execute(sql, list(current_search)) - referenced = set(sg for sg, cnt in txn if cnt > 0) + referenced = set(sg for sg, in txn) referenced_groups |= referenced # We don't continue iterating up the state group graphs for state -- cgit 1.4.1 From f4f223aa4455bea3aa642c23ae957932b1168ba3 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 29 Oct 2018 14:01:49 +0000 Subject: Don't make temporary list --- synapse/storage/state.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 45afd42b3f..947d3fc177 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -1266,9 +1266,8 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): current_search = next_to_search next_to_search = set() else: - lst = list(next_to_search) - current_search = set(lst[:100]) - next_to_search = set(lst[100:]) + current_search = set(islice(next_to_search, 100)) + next_to_search -= current_search # Check if state groups are referenced sql = """ -- cgit 1.4.1 From 664b192a3b4aa57597d9832361b025bc078ea87a Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 29 Oct 2018 14:21:43 +0000 Subject: Fix set operations thinko --- synapse/storage/state.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 947d3fc177..dfec57c045 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -1293,10 +1293,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): retcols=("prev_state_group", "state_group",), ) - next_to_search.update(row["state_group"] for row in rows) + prevs = set(row["state_group"] for row in rows) # We don't bother re-handling groups we've already seen - next_to_search -= state_groups_seen - state_groups_seen |= next_to_search + prevs -= state_groups_seen + next_to_search |= prevs + state_groups_seen |= prevs for row in rows: # Note: Each state group can have at most one prev group -- cgit 1.4.1 From ad88460e0d2e0a7c2cf39ec5539d5c4ff030bbbd Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 29 Oct 2018 14:23:34 +0000 Subject: Move _find_unreferenced_groups --- synapse/storage/events.py | 85 +++++++++++++++++++++++++++++++++++++++++++++-- synapse/storage/state.py | 79 ------------------------------------------- 2 files changed, 83 insertions(+), 81 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/events.py b/synapse/storage/events.py index e791922733..919e855f3b 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -2052,8 +2052,10 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore logger.info("[purge] finding state groups that can be deleted") - state_groups_to_delete, remaining_state_groups = self._find_unreferenced_groups( - txn, referenced_state_groups, + state_groups_to_delete, remaining_state_groups = ( + self._find_unreferenced_groups_during_purge( + txn, referenced_state_groups, + ) ) logger.info( @@ -2209,6 +2211,85 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore logger.info("[purge] done") + def _find_unreferenced_groups_during_purge(self, txn, state_groups): + """Used when purging history to figure out which state groups can be + deleted and which need to be de-delta'ed (due to one of its prev groups + being scheduled for deletion). + + Args: + txn + state_groups (set[int]): Set of state groups referenced by events + that are going to be deleted. + + Returns: + tuple[set[int], set[int]]: The set of state groups that can be + deleted and the set of state groups that need to be de-delta'ed + """ + # Graph of state group -> previous group + graph = {} + + # Set of events that we have found to be referenced by events + referenced_groups = set() + + # Set of state groups we've already seen + state_groups_seen = set(state_groups) + + # Set of state groups to handle next. + next_to_search = set(state_groups) + while next_to_search: + # We bound size of groups we're looking up at once, to stop the + # SQL query getting too big + if len(next_to_search) < 100: + current_search = next_to_search + next_to_search = set() + else: + current_search = set(itertools.islice(next_to_search, 100)) + next_to_search -= current_search + + # Check if state groups are referenced + sql = """ + SELECT DISTINCT state_group FROM event_to_state_groups + LEFT JOIN events_to_purge AS ep USING (event_id) + WHERE state_group IN (%s) AND ep.event_id IS NULL + """ % (",".join("?" for _ in current_search),) + txn.execute(sql, list(current_search)) + + referenced = set(sg for sg, in txn) + referenced_groups |= referenced + + # We don't continue iterating up the state group graphs for state + # groups that are referenced. + current_search -= referenced + + rows = self._simple_select_many_txn( + txn, + table="state_group_edges", + column="prev_state_group", + iterable=current_search, + keyvalues={}, + retcols=("prev_state_group", "state_group",), + ) + + prevs = set(row["state_group"] for row in rows) + # We don't bother re-handling groups we've already seen + prevs -= state_groups_seen + next_to_search |= prevs + state_groups_seen |= prevs + + for row in rows: + # Note: Each state group can have at most one prev group + graph[row["state_group"]] = row["prev_state_group"] + + to_delete = state_groups_seen - referenced_groups + + to_dedelta = set() + for sg in referenced_groups: + prev_sg = graph.get(sg) + if prev_sg and prev_sg in to_delete: + to_dedelta.add(sg) + + return to_delete, to_dedelta + @defer.inlineCallbacks def is_event_after(self, event_id1, event_id2): """Returns True if event_id1 is after event_id2 in the stream diff --git a/synapse/storage/state.py b/synapse/storage/state.py index dfec57c045..d737bd6778 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -1234,85 +1234,6 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): return count - def _find_unreferenced_groups(self, txn, state_groups): - """Used when purging history to figure out which state groups can be - deleted and which need to be de-delta'ed (due to one of its prev groups - being scheduled for deletion). - - Args: - txn - state_groups (set[int]): Set of state groups referenced by events - that are going to be deleted. - - Returns: - tuple[set[int], set[int]]: The set of state groups that can be - deleted and the set of state groups that need to be de-delta'ed - """ - # Graph of state group -> previous group - graph = {} - - # Set of events that we have found to be referenced by events - referenced_groups = set() - - # Set of state groups we've already seen - state_groups_seen = set(state_groups) - - # Set of state groups to handle next. - next_to_search = set(state_groups) - while next_to_search: - # We bound size of groups we're looking up at once, to stop the - # SQL query getting too big - if len(next_to_search) < 100: - current_search = next_to_search - next_to_search = set() - else: - current_search = set(islice(next_to_search, 100)) - next_to_search -= current_search - - # Check if state groups are referenced - sql = """ - SELECT DISTINCT state_group FROM event_to_state_groups - LEFT JOIN events_to_purge AS ep USING (event_id) - WHERE state_group IN (%s) AND ep.event_id IS NULL - """ % (",".join("?" for _ in current_search),) - txn.execute(sql, list(current_search)) - - referenced = set(sg for sg, in txn) - referenced_groups |= referenced - - # We don't continue iterating up the state group graphs for state - # groups that are referenced. - current_search -= referenced - - rows = self._simple_select_many_txn( - txn, - table="state_group_edges", - column="prev_state_group", - iterable=current_search, - keyvalues={}, - retcols=("prev_state_group", "state_group",), - ) - - prevs = set(row["state_group"] for row in rows) - # We don't bother re-handling groups we've already seen - prevs -= state_groups_seen - next_to_search |= prevs - state_groups_seen |= prevs - - for row in rows: - # Note: Each state group can have at most one prev group - graph[row["state_group"]] = row["prev_state_group"] - - to_delete = state_groups_seen - referenced_groups - - to_dedelta = set() - for sg in referenced_groups: - prev_sg = graph.get(sg) - if prev_sg and prev_sg in to_delete: - to_dedelta.add(sg) - - return to_delete, to_dedelta - class StateStore(StateGroupWorkerStore, BackgroundUpdateStore): """ Keeps track of the state at a given event. -- cgit 1.4.1 From 88e5ffe6fe816e54a5471728e93fde63353d9a70 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 29 Oct 2018 17:34:34 +0000 Subject: Deduplicate device updates sent over replication We currently send several kHz of device list updates over replication occisonally, which often causes the replications streams to lag and then get dropped. A lot of those updates will actually be duplicates, since we don't send e.g. device_ids across replication, so let's deduplicate it when we pull them out of the database. --- synapse/storage/devices.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) (limited to 'synapse/storage') diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index d10ff9e4b9..62497ab63f 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -589,10 +589,14 @@ class DeviceStore(SQLBaseStore): combined list of changes to devices, and which destinations need to be poked. `destination` may be None if no destinations need to be poked. """ + # We do a group by here as there can be a large number of duplicate + # entries, since we throw away device IDs. sql = """ - SELECT stream_id, user_id, destination FROM device_lists_stream + SELECT MAX(stream_id) AS stream_id, user_id, destination + FROM device_lists_stream LEFT JOIN device_lists_outbound_pokes USING (stream_id, user_id, device_id) WHERE ? < stream_id AND stream_id <= ? + GROUP BY user_id, destination """ return self._execute( "get_all_device_list_changes_for_remotes", None, -- cgit 1.4.1 From 50e328d1e7a61268dbf274f10798ea5994c6c25a Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Thu, 1 Nov 2018 19:01:29 +0000 Subject: Remove redundant database locks for device list updates We can rely on the application-level per-user linearizer. --- synapse/storage/devices.py | 45 ++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 40 insertions(+), 5 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index 62497ab63f..bc3f575b1e 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -239,7 +239,19 @@ class DeviceStore(SQLBaseStore): def update_remote_device_list_cache_entry(self, user_id, device_id, content, stream_id): - """Updates a single user's device in the cache. + """Updates a single device in the cache of a remote user's devicelist. + + Note: assumes that we are the only thread that can be updating this user's + device list. + + Args: + user_id (str): User to update device list for + device_id (str): ID of decivice being updated + content (dict): new data on this device + stream_id (int): the version of the device list + + Returns: + Deferred[None] """ return self.runInteraction( "update_remote_device_list_cache_entry", @@ -272,7 +284,11 @@ class DeviceStore(SQLBaseStore): }, values={ "content": json.dumps(content), - } + }, + + # we don't need to lock, because we assume we are the only thread + # updating this user's devices. + lock=False, ) txn.call_after(self._get_cached_user_device.invalidate, (user_id, device_id,)) @@ -289,11 +305,26 @@ class DeviceStore(SQLBaseStore): }, values={ "stream_id": stream_id, - } + }, + + # again, we can assume we are the only thread updating this user's + # extremity. + lock=False, ) def update_remote_device_list_cache(self, user_id, devices, stream_id): - """Replace the cache of the remote user's devices. + """Replace the entire cache of the remote user's devices. + + Note: assumes that we are the only thread that can be updating this user's + device list. + + Args: + user_id (str): User to update device list for + devices (list[dict]): list of device objects supplied over federation + stream_id (int): the version of the device list + + Returns: + Deferred[None] """ return self.runInteraction( "update_remote_device_list_cache", @@ -338,7 +369,11 @@ class DeviceStore(SQLBaseStore): }, values={ "stream_id": stream_id, - } + }, + + # we don't need to lock, because we can assume we are the only thread + # updating this user's extremity. + lock=False, ) def get_devices_by_remote(self, destination, from_stream_id): -- cgit 1.4.1 From 350f654e7b80ff19d1fdf3861093cef09ccf3ab1 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Thu, 1 Nov 2018 19:10:33 +0000 Subject: Add unique indexes to a couple of tables The indexes on device_lists_remote_extremeties can be unique, and they therefore should, to ensure that the db remains consistent. --- synapse/storage/devices.py | 49 +++++++++++++++++++++- .../schema/delta/40/device_list_streams.sql | 9 ++-- .../delta/52/device_list_streams_unique_idx.sql | 36 ++++++++++++++++ 3 files changed, 88 insertions(+), 6 deletions(-) create mode 100644 synapse/storage/schema/delta/52/device_list_streams_unique_idx.sql (limited to 'synapse/storage') diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index bc3f575b1e..ecdab34e7d 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -22,14 +22,19 @@ from twisted.internet import defer from synapse.api.errors import StoreError from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.storage.background_updates import BackgroundUpdateStore from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList -from ._base import Cache, SQLBaseStore, db_to_json +from ._base import Cache, db_to_json logger = logging.getLogger(__name__) +DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES = ( + "drop_device_list_streams_non_unique_indexes" +) -class DeviceStore(SQLBaseStore): + +class DeviceStore(BackgroundUpdateStore): def __init__(self, db_conn, hs): super(DeviceStore, self).__init__(db_conn, hs) @@ -52,6 +57,30 @@ class DeviceStore(SQLBaseStore): columns=["user_id", "device_id"], ) + # create a unique index on device_lists_remote_cache + self.register_background_index_update( + "device_lists_remote_cache_unique_idx", + index_name="device_lists_remote_cache_unique_id", + table="device_lists_remote_cache", + columns=["user_id", "device_id"], + unique=True, + ) + + # And one on device_lists_remote_extremeties + self.register_background_index_update( + "device_lists_remote_extremeties_unique_idx", + index_name="device_lists_remote_extremeties_unique_idx", + table="device_lists_remote_extremeties", + columns=["user_id"], + unique=True, + ) + + # once they complete, we can remove the old non-unique indexes. + self.register_background_update_handler( + DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES, + self._drop_device_list_streams_non_unique_indexes, + ) + @defer.inlineCallbacks def store_device(self, user_id, device_id, initial_device_display_name): @@ -757,3 +786,19 @@ class DeviceStore(SQLBaseStore): "_prune_old_outbound_device_pokes", _prune_txn, ) + + @defer.inlineCallbacks + def _drop_device_list_streams_non_unique_indexes(self, progress, batch_size): + def f(conn): + txn = conn.cursor() + txn.execute( + "DROP INDEX IF EXISTS device_lists_remote_cache_id" + ) + txn.execute( + "DROP INDEX IF EXISTS device_lists_remote_extremeties_id" + ) + txn.close() + + yield self.runWithConnection(f) + yield self._end_background_update(DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES) + defer.returnValue(1) diff --git a/synapse/storage/schema/delta/40/device_list_streams.sql b/synapse/storage/schema/delta/40/device_list_streams.sql index 54841b3843..dd6dcb65f1 100644 --- a/synapse/storage/schema/delta/40/device_list_streams.sql +++ b/synapse/storage/schema/delta/40/device_list_streams.sql @@ -20,9 +20,6 @@ CREATE TABLE device_lists_remote_cache ( content TEXT NOT NULL ); -CREATE INDEX device_lists_remote_cache_id ON device_lists_remote_cache(user_id, device_id); - - -- The last update we got for a user. Empty if we're not receiving updates for -- that user. CREATE TABLE device_lists_remote_extremeties ( @@ -30,7 +27,11 @@ CREATE TABLE device_lists_remote_extremeties ( stream_id TEXT NOT NULL ); -CREATE INDEX device_lists_remote_extremeties_id ON device_lists_remote_extremeties(user_id, stream_id); +-- we used to create non-unique indexes on these tables, but as of update 52 we create +-- unique indexes concurrently: +-- +-- CREATE INDEX device_lists_remote_cache_id ON device_lists_remote_cache(user_id, device_id); +-- CREATE INDEX device_lists_remote_extremeties_id ON device_lists_remote_extremeties(user_id, stream_id); -- Stream of device lists updates. Includes both local and remotes diff --git a/synapse/storage/schema/delta/52/device_list_streams_unique_idx.sql b/synapse/storage/schema/delta/52/device_list_streams_unique_idx.sql new file mode 100644 index 0000000000..bfa49e6f92 --- /dev/null +++ b/synapse/storage/schema/delta/52/device_list_streams_unique_idx.sql @@ -0,0 +1,36 @@ +/* Copyright 2018 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- register a background update which will create a unique index on +-- device_lists_remote_cache +INSERT into background_updates (update_name, progress_json) + VALUES ('device_lists_remote_cache_unique_idx', '{}'); + +-- and one on device_lists_remote_extremeties +INSERT into background_updates (update_name, progress_json, depends_on) + VALUES ( + 'device_lists_remote_extremeties_unique_idx', '{}', + + -- doesn't really depend on this, but we need to make sure both happen + -- before we drop the old indexes. + 'device_lists_remote_cache_unique_idx' + ); + +-- once they complete, we can drop the old indexes. +INSERT into background_updates (update_name, progress_json, depends_on) + VALUES ( + 'drop_device_list_streams_non_unique_indexes', '{}', + 'device_lists_remote_extremeties_unique_idx' + ); -- cgit 1.4.1