From 54a87a7b086048e2131a6d1ed00e25836ffa2995 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 1 Apr 2019 10:24:38 +0100 Subject: Collect room-version variations into one place (#4969) Collect all the things that make room-versions different to one another into one place, so that it's easier to define new room versions. --- tests/storage/test_redaction.py | 9 +++++---- tests/storage/test_roommember.py | 5 +++-- tests/storage/test_state.py | 5 +++-- 3 files changed, 11 insertions(+), 8 deletions(-) (limited to 'tests/storage') diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py index 3957561b1e..0fc5019e9f 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py @@ -18,7 +18,8 @@ from mock import Mock from twisted.internet import defer -from synapse.api.constants import EventTypes, Membership, RoomVersions +from synapse.api.constants import EventTypes, Membership +from synapse.api.room_versions import RoomVersions from synapse.types import RoomID, UserID from tests import unittest @@ -51,7 +52,7 @@ class RedactionTestCase(unittest.TestCase): ): content = {"membership": membership} content.update(extra_content) - builder = self.event_builder_factory.new( + builder = self.event_builder_factory.for_room_version( RoomVersions.V1, { "type": EventTypes.Member, @@ -74,7 +75,7 @@ class RedactionTestCase(unittest.TestCase): def inject_message(self, room, user, body): self.depth += 1 - builder = self.event_builder_factory.new( + builder = self.event_builder_factory.for_room_version( RoomVersions.V1, { "type": EventTypes.Message, @@ -95,7 +96,7 @@ class RedactionTestCase(unittest.TestCase): @defer.inlineCallbacks def inject_redaction(self, room, event_id, user, reason): - builder = self.event_builder_factory.new( + builder = self.event_builder_factory.for_room_version( RoomVersions.V1, { "type": EventTypes.Redaction, diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index 7fa2f4fd70..063387863e 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -18,7 +18,8 @@ from mock import Mock from twisted.internet import defer -from synapse.api.constants import EventTypes, Membership, RoomVersions +from synapse.api.constants import EventTypes, Membership +from synapse.api.room_versions import RoomVersions from synapse.types import RoomID, UserID from tests import unittest @@ -49,7 +50,7 @@ class RoomMemberStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def inject_room_member(self, room, user, membership, replaces_state=None): - builder = self.event_builder_factory.new( + builder = self.event_builder_factory.for_room_version( RoomVersions.V1, { "type": EventTypes.Member, diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index 99cd3e09eb..78e260a7fa 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -17,7 +17,8 @@ import logging from twisted.internet import defer -from synapse.api.constants import EventTypes, Membership, RoomVersions +from synapse.api.constants import EventTypes, Membership +from synapse.api.room_versions import RoomVersions from synapse.storage.state import StateFilter from synapse.types import RoomID, UserID @@ -48,7 +49,7 @@ class StateStoreTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def inject_state_event(self, room, sender, typ, state_key, content): - builder = self.event_builder_factory.new( + builder = self.event_builder_factory.for_room_version( RoomVersions.V1, { "type": typ, -- cgit 1.5.1 From e8419554ffcb5ec41a1f5c22ebe89163f601170b Mon Sep 17 00:00:00 2001 From: Neil Johnson Date: Wed, 3 Apr 2019 11:11:15 +0100 Subject: Remove presence lists (#4989) Remove presence list support as per MSC 1819 --- changelog.d/4989.feature | 1 + synapse/handlers/presence.py | 167 +-------------------- synapse/replication/slave/storage/presence.py | 10 -- synapse/rest/client/v1/presence.py | 67 --------- synapse/storage/prepare_database.py | 2 +- synapse/storage/presence.py | 86 +---------- .../storage/schema/delta/54/drop_presence_list.sql | 16 ++ .../storage/schema/full_schemas/16/presence.sql | 12 +- tests/storage/test_presence.py | 118 --------------- 9 files changed, 23 insertions(+), 456 deletions(-) create mode 100644 changelog.d/4989.feature create mode 100644 synapse/storage/schema/delta/54/drop_presence_list.sql delete mode 100644 tests/storage/test_presence.py (limited to 'tests/storage') diff --git a/changelog.d/4989.feature b/changelog.d/4989.feature new file mode 100644 index 0000000000..a5138f5612 --- /dev/null +++ b/changelog.d/4989.feature @@ -0,0 +1 @@ +Remove presence list support as per MSC 1819. diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index e85c49742d..3b22a22a19 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -113,27 +113,6 @@ class PresenceHandler(object): federation_registry.register_edu_handler( "m.presence", self.incoming_presence ) - federation_registry.register_edu_handler( - "m.presence_invite", - lambda origin, content: self.invite_presence( - observed_user=UserID.from_string(content["observed_user"]), - observer_user=UserID.from_string(content["observer_user"]), - ) - ) - federation_registry.register_edu_handler( - "m.presence_accept", - lambda origin, content: self.accept_presence( - observed_user=UserID.from_string(content["observed_user"]), - observer_user=UserID.from_string(content["observer_user"]), - ) - ) - federation_registry.register_edu_handler( - "m.presence_deny", - lambda origin, content: self.deny_presence( - observed_user=UserID.from_string(content["observed_user"]), - observer_user=UserID.from_string(content["observer_user"]), - ) - ) active_presence = self.store.take_presence_startup_info() @@ -759,137 +738,6 @@ class PresenceHandler(object): yield self._update_states([prev_state.copy_and_replace(**new_fields)]) - @defer.inlineCallbacks - def get_presence_list(self, observer_user, accepted=None): - """Returns the presence for all users in their presence list. - """ - if not self.is_mine(observer_user): - raise SynapseError(400, "User is not hosted on this Home Server") - - presence_list = yield self.store.get_presence_list( - observer_user.localpart, accepted=accepted - ) - - results = yield self.get_states( - target_user_ids=[row["observed_user_id"] for row in presence_list], - as_event=False, - ) - - now = self.clock.time_msec() - results[:] = [format_user_presence_state(r, now) for r in results] - - is_accepted = { - row["observed_user_id"]: row["accepted"] for row in presence_list - } - - for result in results: - result.update({ - "accepted": is_accepted, - }) - - defer.returnValue(results) - - @defer.inlineCallbacks - def send_presence_invite(self, observer_user, observed_user): - """Sends a presence invite. - """ - yield self.store.add_presence_list_pending( - observer_user.localpart, observed_user.to_string() - ) - - if self.is_mine(observed_user): - yield self.invite_presence(observed_user, observer_user) - else: - yield self.federation.build_and_send_edu( - destination=observed_user.domain, - edu_type="m.presence_invite", - content={ - "observed_user": observed_user.to_string(), - "observer_user": observer_user.to_string(), - } - ) - - @defer.inlineCallbacks - def invite_presence(self, observed_user, observer_user): - """Handles new presence invites. - """ - if not self.is_mine(observed_user): - raise SynapseError(400, "User is not hosted on this Home Server") - - # TODO: Don't auto accept - if self.is_mine(observer_user): - yield self.accept_presence(observed_user, observer_user) - else: - self.federation.build_and_send_edu( - destination=observer_user.domain, - edu_type="m.presence_accept", - content={ - "observed_user": observed_user.to_string(), - "observer_user": observer_user.to_string(), - } - ) - - state_dict = yield self.get_state(observed_user, as_event=False) - state_dict = format_user_presence_state(state_dict, self.clock.time_msec()) - - self.federation.build_and_send_edu( - destination=observer_user.domain, - edu_type="m.presence", - content={ - "push": [state_dict] - } - ) - - @defer.inlineCallbacks - def accept_presence(self, observed_user, observer_user): - """Handles a m.presence_accept EDU. Mark a presence invite from a - local or remote user as accepted in a local user's presence list. - Starts polling for presence updates from the local or remote user. - Args: - observed_user(UserID): The user to update in the presence list. - observer_user(UserID): The owner of the presence list to update. - """ - yield self.store.set_presence_list_accepted( - observer_user.localpart, observed_user.to_string() - ) - - @defer.inlineCallbacks - def deny_presence(self, observed_user, observer_user): - """Handle a m.presence_deny EDU. Removes a local or remote user from a - local user's presence list. - Args: - observed_user(UserID): The local or remote user to remove from the - list. - observer_user(UserID): The local owner of the presence list. - Returns: - A Deferred. - """ - yield self.store.del_presence_list( - observer_user.localpart, observed_user.to_string() - ) - - # TODO(paul): Inform the user somehow? - - @defer.inlineCallbacks - def drop(self, observed_user, observer_user): - """Remove a local or remote user from a local user's presence list and - unsubscribe the local user from updates that user. - Args: - observed_user(UserId): The local or remote user to remove from the - list. - observer_user(UserId): The local owner of the presence list. - Returns: - A Deferred. - """ - if not self.is_mine(observer_user): - raise SynapseError(400, "User is not hosted on this Home Server") - - yield self.store.del_presence_list( - observer_user.localpart, observed_user.to_string() - ) - - # TODO: Inform the remote that we've dropped the presence list. - @defer.inlineCallbacks def is_visible(self, observed_user, observer_user): """Returns whether a user can see another user's presence. @@ -904,11 +752,7 @@ class PresenceHandler(object): if observer_room_ids & observed_room_ids: defer.returnValue(True) - accepted_observers = yield self.store.get_presence_list_observers_accepted( - observed_user.to_string() - ) - - defer.returnValue(observer_user.to_string() in accepted_observers) + defer.returnValue(False) @defer.inlineCallbacks def get_all_presence_updates(self, last_id, current_id): @@ -1204,10 +1048,7 @@ class PresenceEventSource(object): updates for """ user_id = user.to_string() - plist = yield self.store.get_presence_list_accepted( - user.localpart, on_invalidate=cache_context.invalidate, - ) - users_interested_in = set(row["observed_user_id"] for row in plist) + users_interested_in = set() users_interested_in.add(user_id) # So that we receive our own presence users_who_share_room = yield self.store.get_users_who_share_room_with_user( @@ -1412,10 +1253,6 @@ def get_interested_parties(store, states): for room_id in room_ids: room_ids_to_states.setdefault(room_id, []).append(state) - plist = yield store.get_presence_list_observers_accepted(state.user_id) - for u in plist: - users_to_states.setdefault(u, []).append(state) - # Always notify self users_to_states.setdefault(state.user_id, []).append(state) diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py index 9e530defe0..0ec1db25ce 100644 --- a/synapse/replication/slave/storage/presence.py +++ b/synapse/replication/slave/storage/presence.py @@ -39,16 +39,6 @@ class SlavedPresenceStore(BaseSlavedStore): _get_presence_for_user = PresenceStore.__dict__["_get_presence_for_user"] get_presence_for_users = PresenceStore.__dict__["get_presence_for_users"] - # XXX: This is a bit broken because we don't persist the accepted list in a - # way that can be replicated. This means that we don't have a way to - # invalidate the cache correctly. - get_presence_list_accepted = PresenceStore.__dict__[ - "get_presence_list_accepted" - ] - get_presence_list_observers_accepted = PresenceStore.__dict__[ - "get_presence_list_observers_accepted" - ] - def get_current_presence_token(self): return self._presence_id_gen.get_current_token() diff --git a/synapse/rest/client/v1/presence.py b/synapse/rest/client/v1/presence.py index b5a6d6aebf..045d5a20ac 100644 --- a/synapse/rest/client/v1/presence.py +++ b/synapse/rest/client/v1/presence.py @@ -93,72 +93,5 @@ class PresenceStatusRestServlet(ClientV1RestServlet): return (200, {}) -class PresenceListRestServlet(ClientV1RestServlet): - PATTERNS = client_path_patterns("/presence/list/(?P[^/]*)") - - def __init__(self, hs): - super(PresenceListRestServlet, self).__init__(hs) - self.presence_handler = hs.get_presence_handler() - - @defer.inlineCallbacks - def on_GET(self, request, user_id): - requester = yield self.auth.get_user_by_req(request) - user = UserID.from_string(user_id) - - if not self.hs.is_mine(user): - raise SynapseError(400, "User not hosted on this Home Server") - - if requester.user != user: - raise SynapseError(400, "Cannot get another user's presence list") - - presence = yield self.presence_handler.get_presence_list( - observer_user=user, accepted=True - ) - - defer.returnValue((200, presence)) - - @defer.inlineCallbacks - def on_POST(self, request, user_id): - requester = yield self.auth.get_user_by_req(request) - user = UserID.from_string(user_id) - - if not self.hs.is_mine(user): - raise SynapseError(400, "User not hosted on this Home Server") - - if requester.user != user: - raise SynapseError( - 400, "Cannot modify another user's presence list") - - content = parse_json_object_from_request(request) - - if "invite" in content: - for u in content["invite"]: - if not isinstance(u, string_types): - raise SynapseError(400, "Bad invite value.") - if len(u) == 0: - continue - invited_user = UserID.from_string(u) - yield self.presence_handler.send_presence_invite( - observer_user=user, observed_user=invited_user - ) - - if "drop" in content: - for u in content["drop"]: - if not isinstance(u, string_types): - raise SynapseError(400, "Bad drop value.") - if len(u) == 0: - continue - dropped_user = UserID.from_string(u) - yield self.presence_handler.drop( - observer_user=user, observed_user=dropped_user - ) - - defer.returnValue((200, {})) - - def on_OPTIONS(self, request): - return (200, {}) - - def register_servlets(hs, http_server): PresenceStatusRestServlet(hs).register(http_server) - PresenceListRestServlet(hs).register(http_server) diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index 81b4c57ad4..c1711bc8bd 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -25,7 +25,7 @@ logger = logging.getLogger(__name__) # Remember to update this number every time a change is made to database # schema files, so the users will be informed on server restarts. -SCHEMA_VERSION = 53 +SCHEMA_VERSION = 54 dir_path = os.path.abspath(os.path.dirname(__file__)) diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py index 089ea8c048..42ec8c6bb8 100644 --- a/synapse/storage/presence.py +++ b/synapse/storage/presence.py @@ -19,7 +19,7 @@ from twisted.internet import defer from synapse.api.constants import PresenceState from synapse.util import batch_iter -from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList +from synapse.util.caches.descriptors import cached, cachedList from ._base import SQLBaseStore @@ -205,87 +205,3 @@ class PresenceStore(SQLBaseStore): }, desc="disallow_presence_visible", ) - - def add_presence_list_pending(self, observer_localpart, observed_userid): - return self._simple_insert( - table="presence_list", - values={ - "user_id": observer_localpart, - "observed_user_id": observed_userid, - "accepted": False, - }, - desc="add_presence_list_pending", - ) - - def set_presence_list_accepted(self, observer_localpart, observed_userid): - def update_presence_list_txn(txn): - result = self._simple_update_one_txn( - txn, - table="presence_list", - keyvalues={ - "user_id": observer_localpart, - "observed_user_id": observed_userid, - }, - updatevalues={"accepted": True}, - ) - - self._invalidate_cache_and_stream( - txn, self.get_presence_list_accepted, (observer_localpart,) - ) - self._invalidate_cache_and_stream( - txn, self.get_presence_list_observers_accepted, (observed_userid,) - ) - - return result - - return self.runInteraction( - "set_presence_list_accepted", update_presence_list_txn - ) - - def get_presence_list(self, observer_localpart, accepted=None): - if accepted: - return self.get_presence_list_accepted(observer_localpart) - else: - keyvalues = {"user_id": observer_localpart} - if accepted is not None: - keyvalues["accepted"] = accepted - - return self._simple_select_list( - table="presence_list", - keyvalues=keyvalues, - retcols=["observed_user_id", "accepted"], - desc="get_presence_list", - ) - - @cached() - def get_presence_list_accepted(self, observer_localpart): - return self._simple_select_list( - table="presence_list", - keyvalues={"user_id": observer_localpart, "accepted": True}, - retcols=["observed_user_id", "accepted"], - desc="get_presence_list_accepted", - ) - - @cachedInlineCallbacks() - def get_presence_list_observers_accepted(self, observed_userid): - user_localparts = yield self._simple_select_onecol( - table="presence_list", - keyvalues={"observed_user_id": observed_userid, "accepted": True}, - retcol="user_id", - desc="get_presence_list_accepted", - ) - - defer.returnValue(["@%s:%s" % (u, self.hs.hostname) for u in user_localparts]) - - @defer.inlineCallbacks - def del_presence_list(self, observer_localpart, observed_userid): - yield self._simple_delete_one( - table="presence_list", - keyvalues={ - "user_id": observer_localpart, - "observed_user_id": observed_userid, - }, - desc="del_presence_list", - ) - self.get_presence_list_accepted.invalidate((observer_localpart,)) - self.get_presence_list_observers_accepted.invalidate((observed_userid,)) diff --git a/synapse/storage/schema/delta/54/drop_presence_list.sql b/synapse/storage/schema/delta/54/drop_presence_list.sql new file mode 100644 index 0000000000..e6ee70c623 --- /dev/null +++ b/synapse/storage/schema/delta/54/drop_presence_list.sql @@ -0,0 +1,16 @@ +/* Copyright 2019 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +DROP TABLE IF EXISTS presence_list; diff --git a/synapse/storage/schema/full_schemas/16/presence.sql b/synapse/storage/schema/full_schemas/16/presence.sql index 283136df20..0892c4cf96 100644 --- a/synapse/storage/schema/full_schemas/16/presence.sql +++ b/synapse/storage/schema/full_schemas/16/presence.sql @@ -28,13 +28,5 @@ CREATE TABLE IF NOT EXISTS presence_allow_inbound( UNIQUE (observed_user_id, observer_user_id) ); --- For each of /my/ users (watcher), which possibly-remote users are they --- watching? -CREATE TABLE IF NOT EXISTS presence_list( - user_id TEXT NOT NULL, - observed_user_id TEXT NOT NULL, -- a UserID, - accepted BOOLEAN NOT NULL, - UNIQUE (user_id, observed_user_id) -); - -CREATE INDEX presence_list_user_id ON presence_list (user_id); +-- We used to create a table called presence_list, but this is no longer used +-- and is removed in delta 54. \ No newline at end of file diff --git a/tests/storage/test_presence.py b/tests/storage/test_presence.py deleted file mode 100644 index c7a63f39b9..0000000000 --- a/tests/storage/test_presence.py +++ /dev/null @@ -1,118 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from twisted.internet import defer - -from synapse.types import UserID - -from tests import unittest -from tests.utils import setup_test_homeserver - - -class PresenceStoreTestCase(unittest.TestCase): - @defer.inlineCallbacks - def setUp(self): - hs = yield setup_test_homeserver(self.addCleanup) - - self.store = hs.get_datastore() - - self.u_apple = UserID.from_string("@apple:test") - self.u_banana = UserID.from_string("@banana:test") - - @defer.inlineCallbacks - def test_presence_list(self): - self.assertEquals( - [], - ( - yield self.store.get_presence_list( - observer_localpart=self.u_apple.localpart - ) - ), - ) - self.assertEquals( - [], - ( - yield self.store.get_presence_list( - observer_localpart=self.u_apple.localpart, accepted=True - ) - ), - ) - - yield self.store.add_presence_list_pending( - observer_localpart=self.u_apple.localpart, - observed_userid=self.u_banana.to_string(), - ) - - self.assertEquals( - [{"observed_user_id": "@banana:test", "accepted": 0}], - ( - yield self.store.get_presence_list( - observer_localpart=self.u_apple.localpart - ) - ), - ) - self.assertEquals( - [], - ( - yield self.store.get_presence_list( - observer_localpart=self.u_apple.localpart, accepted=True - ) - ), - ) - - yield self.store.set_presence_list_accepted( - observer_localpart=self.u_apple.localpart, - observed_userid=self.u_banana.to_string(), - ) - - self.assertEquals( - [{"observed_user_id": "@banana:test", "accepted": 1}], - ( - yield self.store.get_presence_list( - observer_localpart=self.u_apple.localpart - ) - ), - ) - self.assertEquals( - [{"observed_user_id": "@banana:test", "accepted": 1}], - ( - yield self.store.get_presence_list( - observer_localpart=self.u_apple.localpart, accepted=True - ) - ), - ) - - yield self.store.del_presence_list( - observer_localpart=self.u_apple.localpart, - observed_userid=self.u_banana.to_string(), - ) - - self.assertEquals( - [], - ( - yield self.store.get_presence_list( - observer_localpart=self.u_apple.localpart - ) - ), - ) - self.assertEquals( - [], - ( - yield self.store.get_presence_list( - observer_localpart=self.u_apple.localpart, accepted=True - ) - ), - ) -- cgit 1.5.1 From 0084309cd26d28fc8d2c152b4a7f13479fae16f9 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Mon, 8 Apr 2019 22:00:11 +0100 Subject: Rewrite test_keys as a HomeserverTestCase --- tests/storage/test_keys.py | 34 +++++++++++++++------------------- 1 file changed, 15 insertions(+), 19 deletions(-) (limited to 'tests/storage') diff --git a/tests/storage/test_keys.py b/tests/storage/test_keys.py index 0d2dc9f325..7170ae76c7 100644 --- a/tests/storage/test_keys.py +++ b/tests/storage/test_keys.py @@ -15,33 +15,29 @@ import signedjson.key -from twisted.internet import defer - import tests.unittest -import tests.utils - -class KeyStoreTestCase(tests.unittest.TestCase): +KEY_1 = signedjson.key.decode_verify_key_base64( + "ed25519", "key1", "fP5l4JzpZPq/zdbBg5xx6lQGAAOM9/3w94cqiJ5jPrw" +) +KEY_2 = signedjson.key.decode_verify_key_base64( + "ed25519", "key2", "Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw" +) - @defer.inlineCallbacks - def setUp(self): - hs = yield tests.utils.setup_test_homeserver(self.addCleanup) - self.store = hs.get_datastore() - @defer.inlineCallbacks +class KeyStoreTestCase(tests.unittest.HomeserverTestCase): def test_get_server_verify_keys(self): - key1 = signedjson.key.decode_verify_key_base64( - "ed25519", "key1", "fP5l4JzpZPq/zdbBg5xx6lQGAAOM9/3w94cqiJ5jPrw" - ) - key2 = signedjson.key.decode_verify_key_base64( - "ed25519", "key2", "Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw" - ) - yield self.store.store_server_verify_key("server1", "from_server", 0, key1) - yield self.store.store_server_verify_key("server1", "from_server", 0, key2) + store = self.hs.get_datastore() + + d = store.store_server_verify_key("server1", "from_server", 0, KEY_1) + self.get_success(d) + d = store.store_server_verify_key("server1", "from_server", 0, KEY_2) + self.get_success(d) - res = yield self.store.get_server_verify_keys( + d = store.get_server_verify_keys( "server1", ["ed25519:key1", "ed25519:key2", "ed25519:key3"] ) + res = self.get_success(d) self.assertEqual(len(res.keys()), 2) self.assertEqual(res["ed25519:key1"].version, "key1") -- cgit 1.5.1 From 18b69be00f9fa79cf2b237992ef1f0094d1dc453 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Mon, 8 Apr 2019 14:51:07 +0100 Subject: Rewrite Datastore.get_server_verify_keys Rewrite this so that it doesn't hammer the database. --- synapse/crypto/keyring.py | 38 +++++++++++------------- synapse/storage/keys.py | 74 ++++++++++++++++++++++++++++------------------ tests/storage/test_keys.py | 53 +++++++++++++++++++++++++++++++-- 3 files changed, 113 insertions(+), 52 deletions(-) (limited to 'tests/storage') diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index ede120b2a6..834b107705 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -301,13 +301,12 @@ class Keyring(object): # complete this VerifyKeyRequest. result_keys = results.get(server_name, {}) for key_id in verify_request.key_ids: - if key_id in result_keys: + key = result_keys.get(key_id) + if key: with PreserveLoggingContext(): - verify_request.deferred.callback(( - server_name, - key_id, - result_keys[key_id], - )) + verify_request.deferred.callback( + (server_name, key_id, key) + ) break else: # The else block is only reached if the loop above @@ -341,27 +340,24 @@ class Keyring(object): @defer.inlineCallbacks def get_keys_from_store(self, server_name_and_key_ids): """ - Args: - server_name_and_key_ids (list[(str, iterable[str])]): + server_name_and_key_ids (iterable(Tuple[str, iterable[str]]): list of (server_name, iterable[key_id]) tuples to fetch keys for Returns: - Deferred: resolves to dict[str, dict[str, VerifyKey]]: map from + Deferred: resolves to dict[str, dict[str, VerifyKey|None]]: map from server_name -> key_id -> VerifyKey """ - res = yield logcontext.make_deferred_yieldable(defer.gatherResults( - [ - run_in_background( - self.store.get_server_verify_keys, - server_name, key_ids, - ).addCallback(lambda ks, server: (server, ks), server_name) - for server_name, key_ids in server_name_and_key_ids - ], - consumeErrors=True, - ).addErrback(unwrapFirstError)) - - defer.returnValue(dict(res)) + keys_to_fetch = ( + (server_name, key_id) + for server_name, key_ids in server_name_and_key_ids + for key_id in key_ids + ) + res = yield self.store.get_server_verify_keys(keys_to_fetch) + keys = {} + for (server_name, key_id), key in res.items(): + keys.setdefault(server_name, {})[key_id] = key + defer.returnValue(keys) @defer.inlineCallbacks def get_keys_from_perspectives(self, server_name_and_key_ids): diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py index 47a9aa784b..7036541792 100644 --- a/synapse/storage/keys.py +++ b/synapse/storage/keys.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2019 New Vector Ltd. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,15 +14,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +import itertools import logging import six from signedjson.key import decode_verify_key_bytes -from twisted.internet import defer - -from synapse.util.caches.descriptors import cachedInlineCallbacks +from synapse.util import batch_iter +from synapse.util.caches.descriptors import cached, cachedList from ._base import SQLBaseStore @@ -38,36 +39,50 @@ else: class KeyStore(SQLBaseStore): """Persistence for signature verification keys """ - @cachedInlineCallbacks() - def _get_server_verify_key(self, server_name, key_id): - verify_key_bytes = yield self._simple_select_one_onecol( - table="server_signature_keys", - keyvalues={"server_name": server_name, "key_id": key_id}, - retcol="verify_key", - desc="_get_server_verify_key", - allow_none=True, - ) - if verify_key_bytes: - defer.returnValue(decode_verify_key_bytes(key_id, bytes(verify_key_bytes))) + @cached() + def _get_server_verify_key(self, server_name_and_key_id): + raise NotImplementedError() - @defer.inlineCallbacks - def get_server_verify_keys(self, server_name, key_ids): - """Retrieve the NACL verification key for a given server for the given - key_ids + @cachedList( + cached_method_name="_get_server_verify_key", list_name="server_name_and_key_ids" + ) + def get_server_verify_keys(self, server_name_and_key_ids): + """ Args: - server_name (str): The name of the server. - key_ids (iterable[str]): key_ids to try and look up. + server_name_and_key_ids (iterable[Tuple[str, str]]): + iterable of (server_name, key-id) tuples to fetch keys for + Returns: - Deferred: resolves to dict[str, VerifyKey]: map from - key_id to verification key. + Deferred: resolves to dict[Tuple[str, str], VerifyKey|None]: + map from (server_name, key_id) -> VerifyKey, or None if the key is + unknown """ keys = {} - for key_id in key_ids: - key = yield self._get_server_verify_key(server_name, key_id) - if key: - keys[key_id] = key - defer.returnValue(keys) + + def _get_keys(txn, batch): + """Processes a batch of keys to fetch, and adds the result to `keys`.""" + + # batch_iter always returns tuples so it's safe to do len(batch) + sql = ( + "SELECT server_name, key_id, verify_key FROM server_signature_keys " + "WHERE 1=0" + ) + " OR (server_name=? AND key_id=?)" * len(batch) + + txn.execute(sql, tuple(itertools.chain.from_iterable(batch))) + + for row in txn: + server_name, key_id, key_bytes = row + keys[(server_name, key_id)] = decode_verify_key_bytes( + key_id, bytes(key_bytes) + ) + + def _txn(txn): + for batch in batch_iter(server_name_and_key_ids, 50): + _get_keys(txn, batch) + return keys + + return self.runInteraction("get_server_verify_keys", _txn) def store_server_verify_key( self, server_name, from_server, time_now_ms, verify_key @@ -93,8 +108,11 @@ class KeyStore(SQLBaseStore): "verify_key": db_binary_type(verify_key.encode()), }, ) + # invalidate takes a tuple corresponding to the params of + # _get_server_verify_key. _get_server_verify_key only takes one + # param, which is itself the 2-tuple (server_name, key_id). txn.call_after( - self._get_server_verify_key.invalidate, (server_name, key_id) + self._get_server_verify_key.invalidate, ((server_name, key_id),) ) return self.runInteraction("store_server_verify_key", _txn) diff --git a/tests/storage/test_keys.py b/tests/storage/test_keys.py index 7170ae76c7..6bfaa00fe9 100644 --- a/tests/storage/test_keys.py +++ b/tests/storage/test_keys.py @@ -15,6 +15,8 @@ import signedjson.key +from twisted.internet.defer import Deferred + import tests.unittest KEY_1 = signedjson.key.decode_verify_key_base64( @@ -35,10 +37,55 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase): self.get_success(d) d = store.get_server_verify_keys( - "server1", ["ed25519:key1", "ed25519:key2", "ed25519:key3"] + [ + ("server1", "ed25519:key1"), + ("server1", "ed25519:key2"), + ("server1", "ed25519:key3"), + ] ) res = self.get_success(d) + self.assertEqual(len(res.keys()), 3) + self.assertEqual(res[("server1", "ed25519:key1")].version, "key1") + self.assertEqual(res[("server1", "ed25519:key2")].version, "key2") + + # non-existent result gives None + self.assertIsNone(res[("server1", "ed25519:key3")]) + + def test_cache(self): + """Check that updates correctly invalidate the cache.""" + + store = self.hs.get_datastore() + + key_id_1 = "ed25519:key1" + key_id_2 = "ed25519:key2" + + d = store.store_server_verify_key("srv1", "from_server", 0, KEY_1) + self.get_success(d) + d = store.store_server_verify_key("srv1", "from_server", 0, KEY_2) + self.get_success(d) + + d = store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)]) + res = self.get_success(d) + self.assertEqual(len(res.keys()), 2) + self.assertEqual(res[("srv1", key_id_1)], KEY_1) + self.assertEqual(res[("srv1", key_id_2)], KEY_2) + + # we should be able to look up the same thing again without a db hit + res = store.get_server_verify_keys([("srv1", key_id_1)]) + if isinstance(res, Deferred): + res = self.successResultOf(res) + self.assertEqual(len(res.keys()), 1) + self.assertEqual(res[("srv1", key_id_1)], KEY_1) + + new_key_2 = signedjson.key.get_verify_key( + signedjson.key.generate_signing_key("key2") + ) + d = store.store_server_verify_key("srv1", "from_server", 10, new_key_2) + self.get_success(d) + + d = store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)]) + res = self.get_success(d) self.assertEqual(len(res.keys()), 2) - self.assertEqual(res["ed25519:key1"].version, "key1") - self.assertEqual(res["ed25519:key2"].version, "key2") + self.assertEqual(res[("srv1", key_id_1)], KEY_1) + self.assertEqual(res[("srv1", key_id_2)], new_key_2) -- cgit 1.5.1