diff options
-rw-r--r-- | synapse/groups/groups_server.py | 18 | ||||
-rw-r--r-- | synapse/handlers/__init__.py | 2 | ||||
-rw-r--r-- | synapse/handlers/groups_local.py | 16 | ||||
-rw-r--r-- | synapse/handlers/message.py | 3 | ||||
-rw-r--r-- | synapse/handlers/profile.py | 80 | ||||
-rw-r--r-- | synapse/handlers/register.py | 4 | ||||
-rw-r--r-- | synapse/handlers/room_member.py | 4 | ||||
-rw-r--r-- | synapse/rest/client/v1/profile.py | 18 | ||||
-rw-r--r-- | synapse/server.py | 5 | ||||
-rw-r--r-- | synapse/storage/_base.py | 51 | ||||
-rw-r--r-- | synapse/storage/profile.py | 98 | ||||
-rw-r--r-- | synapse/storage/schema/delta/43/profile_cache.sql | 28 | ||||
-rw-r--r-- | tests/handlers/test_profile.py | 4 | ||||
-rw-r--r-- | tests/handlers/test_register.py | 5 | ||||
-rw-r--r-- | tests/rest/client/v1/test_profile.py | 3 |
15 files changed, 292 insertions, 47 deletions
diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py index f25f327eb9..6bccae4bfb 100644 --- a/synapse/groups/groups_server.py +++ b/synapse/groups/groups_server.py @@ -503,6 +503,13 @@ class GroupsServerHandler(object): get_domain_from_id(user_id), group_id, user_id, content ) + user_profile = res.get("user_profile", {}) + yield self.store.add_remote_profile_cache( + user_id, + displayname=user_profile.get("displayname"), + avatar_url=user_profile.get("avatar_url"), + ) + if res["state"] == "join": if not self.hs.is_mine_id(user_id): remote_attestation = res["attestation"] @@ -627,6 +634,9 @@ class GroupsServerHandler(object): get_domain_from_id(user_id), group_id, user_id, {} ) + if not self.hs.is_mine_id(user_id): + yield self.store.maybe_delete_remote_profile_cache(user_id) + defer.returnValue({}) @defer.inlineCallbacks @@ -647,6 +657,7 @@ class GroupsServerHandler(object): avatar_url = profile.get("avatar_url") short_description = profile.get("short_description") long_description = profile.get("long_description") + user_profile = content.get("user_profile", {}) yield self.store.create_group( group_id, @@ -679,6 +690,13 @@ class GroupsServerHandler(object): remote_attestation=remote_attestation, ) + if not self.hs.is_mine_id(user_id): + yield self.store.add_remote_profile_cache( + user_id, + displayname=user_profile.get("displayname"), + avatar_url=user_profile.get("avatar_url"), + ) + defer.returnValue({ "group_id": group_id, }) diff --git a/synapse/handlers/__init__.py b/synapse/handlers/__init__.py index 5ad408f549..53213cdccf 100644 --- a/synapse/handlers/__init__.py +++ b/synapse/handlers/__init__.py @@ -20,7 +20,6 @@ from .room import ( from .room_member import RoomMemberHandler from .message import MessageHandler from .federation import FederationHandler -from .profile import ProfileHandler from .directory import DirectoryHandler from .admin import AdminHandler from .identity import IdentityHandler @@ -52,7 +51,6 @@ class Handlers(object): self.room_creation_handler = RoomCreationHandler(hs) self.room_member_handler = RoomMemberHandler(hs) self.federation_handler = FederationHandler(hs) - self.profile_handler = ProfileHandler(hs) self.directory_handler = DirectoryHandler(hs) self.admin_handler = AdminHandler(hs) self.identity_handler = IdentityHandler(hs) diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py index 274fed9278..1950c12f1f 100644 --- a/synapse/handlers/groups_local.py +++ b/synapse/handlers/groups_local.py @@ -56,6 +56,8 @@ class GroupsLocalHandler(object): self.notifier = hs.get_notifier() self.attestations = hs.get_groups_attestation_signing() + self.profile_handler = hs.get_profile_handler() + # Ensure attestations get renewed hs.get_groups_attestation_renewer() @@ -123,6 +125,7 @@ class GroupsLocalHandler(object): defer.returnValue(res) + @defer.inlineCallbacks def create_group(self, group_id, user_id, content): """Create a group """ @@ -130,13 +133,16 @@ class GroupsLocalHandler(object): logger.info("Asking to create group with ID: %r", group_id) if self.is_mine_id(group_id): - return self.groups_server_handler.create_group( + res = yield self.groups_server_handler.create_group( group_id, user_id, content ) + defer.returnValue(res) - return self.transport_client.create_group( + content["user_profile"] = yield self.profile_handler.get_profile(user_id) + res = yield self.transport_client.create_group( get_domain_from_id(group_id), group_id, user_id, content, - ) # TODO + ) + defer.returnValue(res) @defer.inlineCallbacks def get_users_in_group(self, group_id, requester_user_id): @@ -265,7 +271,9 @@ class GroupsLocalHandler(object): "groups_key", token, users=[user_id], ) - defer.returnValue({"state": "invite"}) + user_profile = yield self.profile_handler.get_profile(user_id) + + defer.returnValue({"state": "invite", "user_profile": user_profile}) @defer.inlineCallbacks def remove_user_from_group(self, group_id, user_id, requester_user_id, content): diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index be4f123c54..5b8f20b73c 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -47,6 +47,7 @@ class MessageHandler(BaseHandler): self.state = hs.get_state_handler() self.clock = hs.get_clock() self.validator = EventValidator() + self.profile_handler = hs.get_profile_handler() self.pagination_lock = ReadWriteLock() @@ -210,7 +211,7 @@ class MessageHandler(BaseHandler): if membership in {Membership.JOIN, Membership.INVITE}: # If event doesn't include a display name, add one. - profile = self.hs.get_handlers().profile_handler + profile = self.profile_handler content = builder.content try: diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 7abee98dea..c3cee38a43 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -19,14 +19,15 @@ from twisted.internet import defer import synapse.types from synapse.api.errors import SynapseError, AuthError, CodeMessageException -from synapse.types import UserID +from synapse.types import UserID, get_domain_from_id from ._base import BaseHandler - logger = logging.getLogger(__name__) class ProfileHandler(BaseHandler): + PROFILE_UPDATE_MS = 60 * 1000 + PROFILE_UPDATE_EVERY_MS = 24 * 60 * 60 * 1000 def __init__(self, hs): super(ProfileHandler, self).__init__(hs) @@ -36,6 +37,40 @@ class ProfileHandler(BaseHandler): "profile", self.on_profile_query ) + self.clock.looping_call(self._update_remote_profile_cache, self.PROFILE_UPDATE_MS) + + @defer.inlineCallbacks + def get_profile(self, user_id): + target_user = UserID.from_string(user_id) + if self.hs.is_mine(target_user): + displayname = yield self.store.get_profile_displayname( + target_user.localpart + ) + avatar_url = yield self.store.get_profile_avatar_url( + target_user.localpart + ) + + defer.returnValue({ + "displayname": displayname, + "avatar_url": avatar_url, + }) + else: + try: + result = yield self.federation.make_query( + destination=target_user.domain, + query_type="profile", + args={ + "user_id": user_id, + }, + ignore_backoff=True, + ) + defer.returnValue(result) + except CodeMessageException as e: + if e.code != 404: + logger.exception("Failed to get displayname") + + raise + @defer.inlineCallbacks def get_displayname(self, target_user): if self.hs.is_mine(target_user): @@ -182,3 +217,44 @@ class ProfileHandler(BaseHandler): "Failed to update join event for room %s - %s", room_id, str(e.message) ) + + def _update_remote_profile_cache(self): + """Called periodically to check profiles of remote users we haven't + checked in a while. + """ + entries = yield self.store.get_remote_profile_cache_entries_that_expire( + last_checked=self.clock.time_msec() - self.PROFILE_UPDATE_EVERY_MS + ) + + for user_id, displayname, avatar_url in entries: + is_subscribed = yield self.store.is_subscribed_remote_profile_for_user( + user_id, + ) + if not is_subscribed: + yield self.store.maybe_delete_remote_profile_cache(user_id) + continue + + try: + profile = yield self.federation.make_query( + destination=get_domain_from_id(user_id), + query_type="profile", + args={ + "user_id": user_id, + }, + ignore_backoff=True, + ) + except: + logger.exception("Failed to get avatar_url") + + yield self.store.update_remote_profile_cache( + user_id, displayname, avatar_url + ) + continue + + new_name = profile.get("displayname") + new_avatar = profile.get("avatar_url") + + # We always hit update to update the last_check timestamp + yield self.store.update_remote_profile_cache( + user_id, new_name, new_avatar + ) diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index ee3a2269a8..560fb36254 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -36,6 +36,7 @@ class RegistrationHandler(BaseHandler): super(RegistrationHandler, self).__init__(hs) self.auth = hs.get_auth() + self.profile_handler = hs.get_profile_handler() self.captcha_client = CaptchaServerHttpClient(hs) self._next_generated_user_id = None @@ -423,8 +424,7 @@ class RegistrationHandler(BaseHandler): if displayname is not None: logger.info("setting user display name: %s -> %s", user_id, displayname) - profile_handler = self.hs.get_handlers().profile_handler - yield profile_handler.set_displayname( + yield self.profile_handler.set_displayname( user, requester, displayname, by_admin=True, ) diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index b3f979b246..dadc19d45b 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -45,6 +45,8 @@ class RoomMemberHandler(BaseHandler): def __init__(self, hs): super(RoomMemberHandler, self).__init__(hs) + self.profile_handler = hs.get_profile_handler() + self.member_linearizer = Linearizer(name="member") self.clock = hs.get_clock() @@ -255,7 +257,7 @@ class RoomMemberHandler(BaseHandler): content["membership"] = Membership.JOIN - profile = self.hs.get_handlers().profile_handler + profile = self.profile_handler if not content_specified: content["displayname"] = yield profile.get_displayname(target) content["avatar_url"] = yield profile.get_avatar_url(target) diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py index 1a5045c9ec..d7edc34245 100644 --- a/synapse/rest/client/v1/profile.py +++ b/synapse/rest/client/v1/profile.py @@ -26,13 +26,13 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet): def __init__(self, hs): super(ProfileDisplaynameRestServlet, self).__init__(hs) - self.handlers = hs.get_handlers() + self.profile_handler = hs.get_profile_handler() @defer.inlineCallbacks def on_GET(self, request, user_id): user = UserID.from_string(user_id) - displayname = yield self.handlers.profile_handler.get_displayname( + displayname = yield self.profile_handler.get_displayname( user, ) @@ -55,7 +55,7 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet): except: defer.returnValue((400, "Unable to parse name")) - yield self.handlers.profile_handler.set_displayname( + yield self.profile_handler.set_displayname( user, requester, new_name, is_admin) defer.returnValue((200, {})) @@ -69,13 +69,13 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet): def __init__(self, hs): super(ProfileAvatarURLRestServlet, self).__init__(hs) - self.handlers = hs.get_handlers() + self.profile_handler = hs.get_profile_handler() @defer.inlineCallbacks def on_GET(self, request, user_id): user = UserID.from_string(user_id) - avatar_url = yield self.handlers.profile_handler.get_avatar_url( + avatar_url = yield self.profile_handler.get_avatar_url( user, ) @@ -97,7 +97,7 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet): except: defer.returnValue((400, "Unable to parse name")) - yield self.handlers.profile_handler.set_avatar_url( + yield self.profile_handler.set_avatar_url( user, requester, new_name, is_admin) defer.returnValue((200, {})) @@ -111,16 +111,16 @@ class ProfileRestServlet(ClientV1RestServlet): def __init__(self, hs): super(ProfileRestServlet, self).__init__(hs) - self.handlers = hs.get_handlers() + self.profile_handler = hs.get_profile_handler() @defer.inlineCallbacks def on_GET(self, request, user_id): user = UserID.from_string(user_id) - displayname = yield self.handlers.profile_handler.get_displayname( + displayname = yield self.profile_handler.get_displayname( user, ) - avatar_url = yield self.handlers.profile_handler.get_avatar_url( + avatar_url = yield self.profile_handler.get_avatar_url( user, ) diff --git a/synapse/server.py b/synapse/server.py index d0a6272766..5b892cc390 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -51,6 +51,7 @@ from synapse.handlers.receipts import ReceiptsHandler from synapse.handlers.read_marker import ReadMarkerHandler from synapse.handlers.user_directory import UserDirectoyHandler from synapse.handlers.groups_local import GroupsLocalHandler +from synapse.handlers.profile import ProfileHandler from synapse.groups.groups_server import GroupsServerHandler from synapse.groups.attestations import GroupAttestionRenewer, GroupAttestationSigning from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory @@ -114,6 +115,7 @@ class HomeServer(object): 'application_service_scheduler', 'application_service_handler', 'device_message_handler', + 'profile_handler', 'notifier', 'distributor', 'client_resource', @@ -258,6 +260,9 @@ class HomeServer(object): def build_initial_sync_handler(self): return InitialSyncHandler(self) + def build_profile_handler(self): + return ProfileHandler(self) + def build_event_sources(self): return EventSources(self) diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 6f54036d67..5124a833a5 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -743,6 +743,33 @@ class SQLBaseStore(object): txn.execute(sql, values) return cls.cursor_to_dict(txn) + def _simple_update(self, table, keyvalues, updatevalues, desc): + return self.runInteraction( + desc, + self._simple_update_txn, + table, keyvalues, updatevalues, + ) + + @staticmethod + def _simple_update_txn(txn, table, keyvalues, updatevalues): + if keyvalues: + where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.iterkeys()) + else: + where = "" + + update_sql = "UPDATE %s SET %s %s" % ( + table, + ", ".join("%s = ?" % (k,) for k in updatevalues), + where, + ) + + txn.execute( + update_sql, + updatevalues.values() + keyvalues.values() + ) + + return txn.rowcount + def _simple_update_one(self, table, keyvalues, updatevalues, desc="_simple_update_one"): """Executes an UPDATE query on the named table, setting new values for @@ -768,27 +795,13 @@ class SQLBaseStore(object): table, keyvalues, updatevalues, ) - @staticmethod - def _simple_update_one_txn(txn, table, keyvalues, updatevalues): - if keyvalues: - where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.iterkeys()) - else: - where = "" - - update_sql = "UPDATE %s SET %s %s" % ( - table, - ", ".join("%s = ?" % (k,) for k in updatevalues), - where, - ) - - txn.execute( - update_sql, - updatevalues.values() + keyvalues.values() - ) + @classmethod + def _simple_update_one_txn(cls, txn, table, keyvalues, updatevalues): + rowcount = cls._simple_update_txn(txn, table, keyvalues, updatevalues) - if txn.rowcount == 0: + if rowcount == 0: raise StoreError(404, "No row found") - if txn.rowcount > 1: + if rowcount > 1: raise StoreError(500, "More than one row matched") @staticmethod diff --git a/synapse/storage/profile.py b/synapse/storage/profile.py index 26a40905ae..dca6af8a77 100644 --- a/synapse/storage/profile.py +++ b/synapse/storage/profile.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from twisted.internet import defer + from ._base import SQLBaseStore @@ -55,3 +57,99 @@ class ProfileStore(SQLBaseStore): updatevalues={"avatar_url": new_avatar_url}, desc="set_profile_avatar_url", ) + + def get_from_remote_profile_cache(self, user_id): + return self._simple_select_one( + table="remote_profile_cache", + keyvalues={"user_id": user_id}, + retcols=("displayname", "avatar_url", "last_check"), + allow_none=True, + desc="get_from_remote_profile_cache", + ) + + def add_remote_profile_cache(self, user_id, displayname, avatar_url): + """Ensure we are caching the remote user's profiles. + + This should only be called when `is_subscribed_remote_profile_for_user` + would return true for the user. + """ + return self._simple_upsert( + table="remote_profile_cache", + keyvalues={"user_id": user_id}, + values={ + "displayname": displayname, + "avatar_url": avatar_url, + "last_check": self._clock.time_msec(), + }, + desc="add_remote_profile_cache", + ) + + def update_remote_profile_cache(self, user_id, displayname, avatar_url): + return self._simple_update( + table="remote_profile_cache", + keyvalues={"user_id": user_id}, + values={ + "displayname": displayname, + "avatar_url": avatar_url, + "last_check": self._clock.time_msec(), + }, + desc="update_remote_profile_cache", + ) + + @defer.inlineCallbacks + def maybe_delete_remote_profile_cache(self, user_id): + """Check if we still care about the remote user's profile, and if we + don't then remove their profile from the cache + """ + subscribed = yield self.is_subscribed_remote_profile_for_user(user_id) + if not subscribed: + yield self._simple_delete( + table="remote_profile_cache", + keyvalues={"user_id": user_id}, + desc="delete_remote_profile_cache", + ) + + def get_remote_profile_cache_entries_that_expire(self, last_checked): + """Get all users who haven't been checked since `last_checked` + """ + def _get_remote_profile_cache_entries_that_expire_txn(txn): + sql = """ + SELECT user_id, displayname, avatar_url + FROM remote_profile_cache + WHERE last_check < ? + """ + + txn.execute(sql, (last_checked,)) + + return self.cursor_to_dict(txn) + + return self.runInteraction( + "get_remote_profile_cache_entries_that_expire", + _get_remote_profile_cache_entries_that_expire_txn, + ) + + @defer.inlineCallbacks + def is_subscribed_remote_profile_for_user(self, user_id): + """Check whether we are interested in a remote user's profile. + """ + res = yield self._simple_select_one_onecol( + table="group_users", + keyvalues={"user_id": user_id}, + retcol="user_id", + allow_none=True, + desc="should_update_remote_profile_cache_for_user", + ) + + if res: + defer.returnValue(True) + + res = yield self._simple_select_one_onecol( + table="group_invites", + keyvalues={"user_id": user_id}, + retcol="user_id", + allow_none=True, + desc="should_update_remote_profile_cache_for_user", + ) + + if res: + defer.returnValue(True) diff --git a/synapse/storage/schema/delta/43/profile_cache.sql b/synapse/storage/schema/delta/43/profile_cache.sql new file mode 100644 index 0000000000..e5ddc84df0 --- /dev/null +++ b/synapse/storage/schema/delta/43/profile_cache.sql @@ -0,0 +1,28 @@ +/* Copyright 2017 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +-- A subset of remote users whose profiles we have cached. +-- Whether a user is in this table or not is defined by the storage function +-- `is_subscribed_remote_profile_for_user` +CREATE TABLE remote_profile_cache ( + user_id TEXT NOT NULL, + displayname TEXT, + avatar_url TEXT, + last_check BIGINT NOT NULL +); + +CREATE UNIQUE INDEX remote_profile_cache_user_id ON remote_profile_cache(user_id); +CREATE INDEX remote_profile_cache_time ON remote_profile_cache(last_check); diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index 2a203129ca..a5f47181d7 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -62,8 +62,6 @@ class ProfileTestCase(unittest.TestCase): self.ratelimiter = hs.get_ratelimiter() self.ratelimiter.send_message.return_value = (True, 0) - hs.handlers = ProfileHandlers(hs) - self.store = hs.get_datastore() self.frank = UserID.from_string("@1234ABCD:test") @@ -72,7 +70,7 @@ class ProfileTestCase(unittest.TestCase): yield self.store.create_profile(self.frank.localpart) - self.handler = hs.get_handlers().profile_handler + self.handler = hs.get_profile_handler() @defer.inlineCallbacks def test_get_my_name(self): diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index c8cf9a63ec..e990e45220 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -40,13 +40,14 @@ class RegistrationTestCase(unittest.TestCase): self.hs = yield setup_test_homeserver( handlers=None, http_client=None, - expire_access_token=True) + expire_access_token=True, + profile_handler=Mock(), + ) self.macaroon_generator = Mock( generate_access_token=Mock(return_value='secret')) self.hs.get_macaroon_generator = Mock(return_value=self.macaroon_generator) self.hs.handlers = RegistrationHandlers(self.hs) self.handler = self.hs.get_handlers().registration_handler - self.hs.get_handlers().profile_handler = Mock() @defer.inlineCallbacks def test_user_is_created_and_logged_in_if_doesnt_exist(self): diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py index 1e95e97538..dddcf51b69 100644 --- a/tests/rest/client/v1/test_profile.py +++ b/tests/rest/client/v1/test_profile.py @@ -46,6 +46,7 @@ class ProfileTestCase(unittest.TestCase): resource_for_client=self.mock_resource, federation=Mock(), replication_layer=Mock(), + profile_handler=self.mock_handler ) def _get_user_by_req(request=None, allow_guest=False): @@ -53,8 +54,6 @@ class ProfileTestCase(unittest.TestCase): hs.get_v1auth().get_user_by_req = _get_user_by_req - hs.get_handlers().profile_handler = self.mock_handler - profile.register_servlets(hs, self.mock_resource) @defer.inlineCallbacks |