diff options
Diffstat (limited to 'synapse')
44 files changed, 2104 insertions, 284 deletions
diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py index 7346206bb1..b989007314 100644 --- a/synapse/appservice/__init__.py +++ b/synapse/appservice/__init__.py @@ -241,6 +241,16 @@ class ApplicationService(object): def is_exclusive_room(self, room_id): return self._is_exclusive(ApplicationService.NS_ROOMS, room_id) + def get_exlusive_user_regexes(self): + """Get the list of regexes used to determine if a user is exclusively + registered by the AS + """ + return [ + regex_obj["regex"] + for regex_obj in self.namespaces[ApplicationService.NS_USERS] + if regex_obj["exclusive"] + ] + def is_rate_limited(self): return self.rate_limited diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py index a15198e05d..003eaba893 100644 --- a/synapse/federation/transaction_queue.py +++ b/synapse/federation/transaction_queue.py @@ -187,6 +187,7 @@ class TransactionQueue(object): prev_id for prev_id, _ in event.prev_events ], ) + destinations = set(destinations) if send_on_behalf_of is not None: # If we are sending the event on behalf of another server diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 52d97dfbf3..a333acc4aa 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -43,7 +43,6 @@ from synapse.events.utils import prune_event from synapse.util.retryutils import NotRetryingDestination -from synapse.push.action_generator import ActionGenerator from synapse.util.distributor import user_joined_room from twisted.internet import defer @@ -75,6 +74,7 @@ class FederationHandler(BaseHandler): self.state_handler = hs.get_state_handler() self.server_name = hs.hostname self.keyring = hs.get_keyring() + self.action_generator = hs.get_action_generator() self.replication_layer.set_handler(self) @@ -832,7 +832,11 @@ class FederationHandler(BaseHandler): @defer.inlineCallbacks def on_event_auth(self, event_id): - auth = yield self.store.get_auth_chain([event_id]) + event = yield self.store.get_event(event_id) + auth = yield self.store.get_auth_chain( + [auth_id for auth_id, _ in event.auth_events], + include_given=True + ) for event in auth: event.signatures.update( @@ -1047,9 +1051,7 @@ class FederationHandler(BaseHandler): yield user_joined_room(self.distributor, user, event.room_id) state_ids = context.prev_state_ids.values() - auth_chain = yield self.store.get_auth_chain(set( - [event.event_id] + state_ids - )) + auth_chain = yield self.store.get_auth_chain(state_ids) state = yield self.store.get_events(context.prev_state_ids.values()) @@ -1389,8 +1391,7 @@ class FederationHandler(BaseHandler): ) if not event.internal_metadata.is_outlier(): - action_generator = ActionGenerator(self.hs) - yield action_generator.handle_push_actions_for_event( + yield self.action_generator.handle_push_actions_for_event( event, context ) @@ -1599,7 +1600,11 @@ class FederationHandler(BaseHandler): pass # Now get the current auth_chain for the event. - local_auth_chain = yield self.store.get_auth_chain([event_id]) + event = yield self.store.get_event(event_id) + local_auth_chain = yield self.store.get_auth_chain( + [auth_id for auth_id, _ in event.auth_events], + include_given=True + ) # TODO: Check if we would now reject event_id. If so we need to tell # everyone. @@ -1792,7 +1797,9 @@ class FederationHandler(BaseHandler): auth_ids = yield self.auth.compute_auth_events( event, context.prev_state_ids ) - local_auth_chain = yield self.store.get_auth_chain(auth_ids) + local_auth_chain = yield self.store.get_auth_chain( + auth_ids, include_given=True + ) try: # 2. Get remote difference. diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 196925edad..a04f634c5c 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -20,7 +20,6 @@ from synapse.api.errors import AuthError, Codes, SynapseError from synapse.crypto.event_signing import add_hashes_and_signatures from synapse.events.utils import serialize_event from synapse.events.validator import EventValidator -from synapse.push.action_generator import ActionGenerator from synapse.types import ( UserID, RoomAlias, RoomStreamToken, ) @@ -54,6 +53,8 @@ class MessageHandler(BaseHandler): # This is to stop us from diverging history *too* much. self.limiter = Limiter(max_count=5) + self.action_generator = hs.get_action_generator() + @defer.inlineCallbacks def purge_history(self, room_id, event_id): event = yield self.store.get_event(event_id) @@ -590,8 +591,7 @@ class MessageHandler(BaseHandler): "Changing the room create event is forbidden", ) - action_generator = ActionGenerator(self.hs) - yield action_generator.handle_push_actions_for_event( + yield self.action_generator.handle_push_actions_for_event( event, context ) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index c0205da1a9..91c6c6be3c 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -117,6 +117,8 @@ class SyncResult(collections.namedtuple("SyncResult", [ "archived", # ArchivedSyncResult for each archived room. "to_device", # List of direct messages for the device. "device_lists", # List of user_ids whose devices have chanegd + "device_one_time_keys_count", # Dict of algorithm to count for one time keys + # for this device ])): __slots__ = [] @@ -550,6 +552,14 @@ class SyncHandler(object): sync_result_builder ) + device_id = sync_config.device_id + one_time_key_counts = {} + if device_id: + user_id = sync_config.user.to_string() + one_time_key_counts = yield self.store.count_e2e_one_time_keys( + user_id, device_id + ) + defer.returnValue(SyncResult( presence=sync_result_builder.presence, account_data=sync_result_builder.account_data, @@ -558,6 +568,7 @@ class SyncHandler(object): archived=sync_result_builder.archived, to_device=sync_result_builder.to_device, device_lists=device_lists, + device_one_time_keys_count=one_time_key_counts, next_batch=sync_result_builder.now_token, )) diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index 3b7818af5c..82dedbbc99 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -89,7 +89,7 @@ class TypingHandler(object): until = self._member_typing_until.get(member, None) if not until or until <= now: logger.info("Timing out typing for: %s", member.user_id) - preserve_fn(self._stopped_typing)(member) + self._stopped_typing(member) continue # Check if we need to resend a keep alive over federation for this @@ -147,7 +147,7 @@ class TypingHandler(object): # No point sending another notification defer.returnValue(None) - yield self._push_update( + self._push_update( member=member, typing=True, ) @@ -171,7 +171,7 @@ class TypingHandler(object): member = RoomMember(room_id=room_id, user_id=target_user_id) - yield self._stopped_typing(member) + self._stopped_typing(member) @defer.inlineCallbacks def user_left_room(self, user, room_id): @@ -180,7 +180,6 @@ class TypingHandler(object): member = RoomMember(room_id=room_id, user_id=user_id) yield self._stopped_typing(member) - @defer.inlineCallbacks def _stopped_typing(self, member): if member.user_id not in self._room_typing.get(member.room_id, set()): # No point @@ -189,16 +188,15 @@ class TypingHandler(object): self._member_typing_until.pop(member, None) self._member_last_federation_poke.pop(member, None) - yield self._push_update( + self._push_update( member=member, typing=False, ) - @defer.inlineCallbacks def _push_update(self, member, typing): if self.hs.is_mine_id(member.user_id): # Only send updates for changes to our own users. - yield self._push_remote(member, typing) + preserve_fn(self._push_remote)(member, typing) self._push_update_local( member=member, diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py new file mode 100644 index 0000000000..43eb1c78e9 --- /dev/null +++ b/synapse/handlers/user_directory.py @@ -0,0 +1,439 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 Vector Creations Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from twisted.internet import defer + +from synapse.api.constants import EventTypes, JoinRules, Membership +from synapse.storage.roommember import ProfileInfo +from synapse.util.metrics import Measure + + +logger = logging.getLogger(__name__) + + +class UserDirectoyHandler(object): + """Handles querying of and keeping updated the user_directory. + + N.B.: ASSUMES IT IS THE ONLY THING THAT MODIFIES THE USER DIRECTORY + + The user directory is filled with users who this server can see are joined to a + world_readable or publically joinable room. We keep a database table up to date + by streaming changes of the current state and recalculating whether users should + be in the directory or not when necessary. + + For each user in the directory we also store a room_id which is public and that the + user is joined to. This allows us to ignore history_visibility and join_rules changes + for that user in all other public rooms, as we know they'll still be in at least + one public room. + """ + + def __init__(self, hs): + self.store = hs.get_datastore() + self.state = hs.get_state_handler() + self.server_name = hs.hostname + self.clock = hs.get_clock() + self.notifier = hs.get_notifier() + + self.notifier.add_replication_callback(self.notify_new_event) + + # When start up for the first time we need to populate the user_directory. + # This is a set of user_id's we've inserted already + self.initially_handled_users = set() + self.initially_handled_users_in_public = set() + + # The current position in the current_state_delta stream + self.pos = None + + # Guard to ensure we only process deltas one at a time + self._is_processing = False + + # We kick this off so that we don't have to wait for a change before + # we start populating the user directory + self.clock.call_later(0, self.notify_new_event) + + def search_users(self, search_term, limit): + """Searches for users in directory + + Returns: + dict of the form:: + + { + "limited": <bool>, # whether there were more results or not + "results": [ # Ordered by best match first + { + "user_id": <user_id>, + "display_name": <display_name>, + "avatar_url": <avatar_url> + } + ] + } + """ + return self.store.search_user_dir(search_term, limit) + + @defer.inlineCallbacks + def notify_new_event(self): + """Called when there may be more deltas to process + """ + if self._is_processing: + return + + self._is_processing = True + try: + yield self._unsafe_process() + finally: + self._is_processing = False + + @defer.inlineCallbacks + def _unsafe_process(self): + # If self.pos is None then means we haven't fetched it from DB + if self.pos is None: + self.pos = yield self.store.get_user_directory_stream_pos() + + # If still None then we need to do the initial fill of directory + if self.pos is None: + yield self._do_initial_spam() + self.pos = yield self.store.get_user_directory_stream_pos() + + # Loop round handling deltas until we're up to date + while True: + with Measure(self.clock, "user_dir_delta"): + deltas = yield self.store.get_current_state_deltas(self.pos) + if not deltas: + return + + yield self._handle_deltas(deltas) + + self.pos = deltas[-1]["stream_id"] + yield self.store.update_user_directory_stream_pos(self.pos) + + @defer.inlineCallbacks + def _do_initial_spam(self): + """Populates the user_directory from the current state of the DB, used + when synapse first starts with user_directory support + """ + new_pos = yield self.store.get_max_stream_id_in_current_state_deltas() + + # Delete any existing entries just in case there are any + yield self.store.delete_all_from_user_dir() + + # We process by going through each existing room at a time. + room_ids = yield self.store.get_all_rooms() + + logger.info("Doing initial update of user directory. %d rooms", len(room_ids)) + num_processed_rooms = 1 + + for room_id in room_ids: + logger.info("Handling room %d/%d", num_processed_rooms, len(room_ids)) + yield self._handle_intial_room(room_id) + num_processed_rooms += 1 + + logger.info("Processed all rooms.") + + self.initially_handled_users = None + + yield self.store.update_user_directory_stream_pos(new_pos) + + @defer.inlineCallbacks + def _handle_intial_room(self, room_id): + """Called when we initially fill out user_directory one room at a time + """ + is_in_room = yield self.state.get_is_host_in_room(room_id, self.server_name) + if not is_in_room: + return + + is_public = yield self.store.is_room_world_readable_or_publicly_joinable(room_id) + + users_with_profile = yield self.state.get_current_user_in_room(room_id) + unhandled_users = set(users_with_profile) - self.initially_handled_users + + yield self.store.add_profiles_to_user_dir( + room_id, { + user_id: users_with_profile[user_id] for user_id in unhandled_users + } + ) + + self.initially_handled_users |= unhandled_users + + if is_public: + yield self.store.add_users_to_public_room( + room_id, + user_ids=unhandled_users - self.initially_handled_users_in_public + ) + self.initially_handled_users_in_public != unhandled_users + + @defer.inlineCallbacks + def _handle_deltas(self, deltas): + """Called with the state deltas to process + """ + for delta in deltas: + typ = delta["type"] + state_key = delta["state_key"] + room_id = delta["room_id"] + event_id = delta["event_id"] + prev_event_id = delta["prev_event_id"] + + logger.debug("Handling: %r %r, %s", typ, state_key, event_id) + + # For join rule and visibility changes we need to check if the room + # may have become public or not and add/remove the users in said room + if typ in (EventTypes.RoomHistoryVisibility, EventTypes.JoinRules): + yield self._handle_room_publicity_change( + room_id, prev_event_id, event_id, typ, + ) + elif typ == EventTypes.Member: + change = yield self._get_key_change( + prev_event_id, event_id, + key_name="membership", + public_value=Membership.JOIN, + ) + + if change is None: + # Handle any profile changes + yield self._handle_profile_change(state_key, prev_event_id, event_id) + continue + + if not change: + # Need to check if the server left the room entirely, if so + # we might need to remove all the users in that room + is_in_room = yield self.state.get_is_host_in_room( + room_id, self.server_name, + ) + if not is_in_room: + logger.debug("Server left room: %r", room_id) + # Fetch all the users that we marked as being in user + # directory due to being in the room and then check if + # need to remove those users or not + user_ids = yield self.store.get_users_in_dir_due_to_room(room_id) + for user_id in user_ids: + yield self._handle_remove_user(room_id, user_id) + return + else: + logger.debug("Server is still in room: %r", room_id) + + if change: # The user joined + event = yield self.store.get_event(event_id) + profile = ProfileInfo( + avatar_url=event.content.get("avatar_url"), + display_name=event.content.get("displayname"), + ) + + yield self._handle_new_user(room_id, state_key, profile) + else: # The user left + yield self._handle_remove_user(room_id, state_key) + else: + logger.debug("Ignoring irrelevant type: %r", typ) + + @defer.inlineCallbacks + def _handle_room_publicity_change(self, room_id, prev_event_id, event_id, typ): + """Handle a room having potentially changed from/to world_readable/publically + joinable. + + Args: + room_id (str) + prev_event_id (str|None): The previous event before the state change + event_id (str|None): The new event after the state change + typ (str): Type of the event + """ + logger.debug("Handling change for %s", typ) + + if typ == EventTypes.RoomHistoryVisibility: + change = yield self._get_key_change( + prev_event_id, event_id, + key_name="history_visibility", + public_value="world_readable", + ) + elif typ == EventTypes.JoinRules: + change = yield self._get_key_change( + prev_event_id, event_id, + key_name="join_rule", + public_value=JoinRules.PUBLIC, + ) + else: + raise Exception("Invalid event type") + # If change is None, no change. True => become world_readable/public, + # False => was world_readable/public + if change is None: + logger.debug("No change") + return + + # There's been a change to or from being world readable. + + is_public = yield self.store.is_room_world_readable_or_publicly_joinable( + room_id + ) + + logger.debug("Change: %r, is_public: %r", change, is_public) + + if change and not is_public: + # If we became world readable but room isn't currently public then + # we ignore the change + return + elif not change and is_public: + # If we stopped being world readable but are still public, + # ignore the change + return + + if change: + users_with_profile = yield self.state.get_current_user_in_room(room_id) + for user_id, profile in users_with_profile.iteritems(): + yield self._handle_new_user(room_id, user_id, profile) + else: + users = yield self.store.get_users_in_public_due_to_room(room_id) + for user_id in users: + yield self._handle_remove_user(room_id, user_id) + + @defer.inlineCallbacks + def _handle_new_user(self, room_id, user_id, profile): + """Called when we might need to add user to directory + + Args: + room_id (str): room_id that user joined or started being public that + user_id (str) + """ + logger.debug("Adding user to dir, %r", user_id) + + row = yield self.store.get_user_in_directory(user_id) + if not row: + yield self.store.add_profiles_to_user_dir(room_id, {user_id: profile}) + + is_public = yield self.store.is_room_world_readable_or_publicly_joinable( + room_id + ) + + if not is_public: + return + + row = yield self.store.get_user_in_public_room(user_id) + if not row: + yield self.store.add_users_to_public_room(room_id, [user_id]) + + @defer.inlineCallbacks + def _handle_remove_user(self, room_id, user_id): + """Called when we might need to remove user to directory + + Args: + room_id (str): room_id that user left or stopped being public that + user_id (str) + """ + logger.debug("Maybe removing user %r", user_id) + + row = yield self.store.get_user_in_directory(user_id) + update_user_dir = row and row["room_id"] == room_id + + row = yield self.store.get_user_in_public_room(user_id) + update_user_in_public = row and row["room_id"] == room_id + + if not update_user_in_public and not update_user_dir: + return + + # XXX: Make this faster? + rooms = yield self.store.get_rooms_for_user(user_id) + for j_room_id in rooms: + if not update_user_in_public and not update_user_dir: + break + + is_in_room = yield self.state.get_is_host_in_room( + j_room_id, self.server_name, + ) + + if not is_in_room: + continue + + if update_user_dir: + update_user_dir = False + yield self.store.update_user_in_user_dir(user_id, j_room_id) + + if update_user_in_public: + is_public = yield self.store.is_room_world_readable_or_publicly_joinable( + j_room_id + ) + + if is_public: + yield self.store.update_user_in_public_user_list(user_id, j_room_id) + update_user_in_public = False + + if update_user_dir: + yield self.store.remove_from_user_dir(user_id) + elif update_user_in_public: + yield self.store.remove_from_user_in_public_room(user_id) + + @defer.inlineCallbacks + def _handle_profile_change(self, user_id, prev_event_id, event_id): + """Check member event changes for any profile changes and update the + database if there are. + """ + if not prev_event_id or not event_id: + return + + prev_event = yield self.store.get_event(prev_event_id) + event = yield self.store.get_event(event_id) + + if event.membership != Membership.JOIN: + return + + prev_name = prev_event.content.get("displayname") + new_name = event.content.get("displayname") + + prev_avatar = prev_event.content.get("avatar_url") + new_avatar = event.content.get("avatar_url") + + if prev_name != new_name or prev_avatar != new_avatar: + yield self.store.update_profile_in_user_dir(user_id, new_name, new_avatar) + + @defer.inlineCallbacks + def _get_key_change(self, prev_event_id, event_id, key_name, public_value): + """Given two events check if the `key_name` field in content changed + from not matching `public_value` to doing so. + + For example, check if `history_visibility` (`key_name`) changed from + `shared` to `world_readable` (`public_value`). + + Returns: + None if the field in the events either both match `public_value` + or if neither do, i.e. there has been no change. + True if it didnt match `public_value` but now does + False if it did match `public_value` but now doesn't + """ + prev_event = None + event = None + if prev_event_id: + prev_event = yield self.store.get_event(prev_event_id, allow_none=True) + + if event_id: + event = yield self.store.get_event(event_id, allow_none=True) + + if not event and not prev_event: + logger.debug("Neither event exists: %r %r", prev_event_id, event_id) + defer.returnValue(None) + + prev_value = None + value = None + + if prev_event: + prev_value = prev_event.content.get(key_name) + + if event: + value = event.content.get(key_name) + + logger.debug("prev_value: %r -> value: %r", prev_value, value) + + if value == public_value and prev_value != public_value: + defer.returnValue(True) + elif value != public_value and prev_value == public_value: + defer.returnValue(False) + else: + defer.returnValue(None) diff --git a/synapse/notifier.py b/synapse/notifier.py index 48566187ab..385208b574 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -251,7 +251,8 @@ class Notifier(object): """Notify any user streams that are interested in this room event""" # poke any interested application service. preserve_fn(self.appservice_handler.notify_interested_services)( - room_stream_id) + room_stream_id + ) if self.federation_sender: preserve_fn(self.federation_sender.notify_new_events)( diff --git a/synapse/push/action_generator.py b/synapse/push/action_generator.py index 3f75d3f921..fe09d50d55 100644 --- a/synapse/push/action_generator.py +++ b/synapse/push/action_generator.py @@ -15,7 +15,7 @@ from twisted.internet import defer -from .bulk_push_rule_evaluator import evaluator_for_event +from .bulk_push_rule_evaluator import BulkPushRuleEvaluator from synapse.util.metrics import Measure @@ -24,11 +24,12 @@ import logging logger = logging.getLogger(__name__) -class ActionGenerator: +class ActionGenerator(object): def __init__(self, hs): self.hs = hs self.clock = hs.get_clock() self.store = hs.get_datastore() + self.bulk_evaluator = BulkPushRuleEvaluator(hs) # really we want to get all user ids and all profile tags too, # since we want the actions for each profile tag for every user and # also actions for a client with no profile tag for each user. @@ -38,16 +39,11 @@ class ActionGenerator: @defer.inlineCallbacks def handle_push_actions_for_event(self, event, context): - with Measure(self.clock, "evaluator_for_event"): - bulk_evaluator = yield evaluator_for_event( - event, self.hs, self.store, context - ) - with Measure(self.clock, "action_for_event_by_user"): - actions_by_user = yield bulk_evaluator.action_for_event_by_user( + actions_by_user = yield self.bulk_evaluator.action_for_event_by_user( event, context ) context.push_actions = [ - (uid, actions) for uid, actions in actions_by_user.items() + (uid, actions) for uid, actions in actions_by_user.iteritems() ] diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index f943ff640f..9a96e6fe8f 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -19,60 +19,83 @@ from twisted.internet import defer from .push_rule_evaluator import PushRuleEvaluatorForEvent -from synapse.api.constants import EventTypes from synapse.visibility import filter_events_for_clients_context +from synapse.api.constants import EventTypes, Membership +from synapse.util.caches.descriptors import cached +from synapse.util.async import Linearizer +from collections import namedtuple -logger = logging.getLogger(__name__) +logger = logging.getLogger(__name__) -@defer.inlineCallbacks -def evaluator_for_event(event, hs, store, context): - rules_by_user = yield store.bulk_get_push_rules_for_room( - event, context - ) - - # if this event is an invite event, we may need to run rules for the user - # who's been invited, otherwise they won't get told they've been invited - if event.type == 'm.room.member' and event.content['membership'] == 'invite': - invited_user = event.state_key - if invited_user and hs.is_mine_id(invited_user): - has_pusher = yield store.user_has_pusher(invited_user) - if has_pusher: - rules_by_user = dict(rules_by_user) - rules_by_user[invited_user] = yield store.get_push_rules_for_user( - invited_user - ) - defer.returnValue(BulkPushRuleEvaluator( - event.room_id, rules_by_user, store - )) +rules_by_room = {} -class BulkPushRuleEvaluator: +class BulkPushRuleEvaluator(object): + """Calculates the outcome of push rules for an event for all users in the + room at once. """ - Runs push rules for all users in a room. - This is faster than running PushRuleEvaluator for each user because it - fetches all the rules for all the users in one (batched) db query - rather than doing multiple queries per-user. It currently uses - the same logic to run the actual rules, but could be optimised further - (see https://matrix.org/jira/browse/SYN-562) - """ - def __init__(self, room_id, rules_by_user, store): - self.room_id = room_id - self.rules_by_user = rules_by_user - self.store = store + + def __init__(self, hs): + self.hs = hs + self.store = hs.get_datastore() + + @defer.inlineCallbacks + def _get_rules_for_event(self, event, context): + """This gets the rules for all users in the room at the time of the event, + as well as the push rules for the invitee if the event is an invite. + + Returns: + dict of user_id -> push_rules + """ + room_id = event.room_id + rules_for_room = self._get_rules_for_room(room_id) + + rules_by_user = yield rules_for_room.get_rules(event, context) + + # if this event is an invite event, we may need to run rules for the user + # who's been invited, otherwise they won't get told they've been invited + if event.type == 'm.room.member' and event.content['membership'] == 'invite': + invited = event.state_key + if invited and self.hs.is_mine_id(invited): + has_pusher = yield self.store.user_has_pusher(invited) + if has_pusher: + rules_by_user = dict(rules_by_user) + rules_by_user[invited] = yield self.store.get_push_rules_for_user( + invited + ) + + defer.returnValue(rules_by_user) + + @cached() + def _get_rules_for_room(self, room_id): + """Get the current RulesForRoom object for the given room id + + Returns: + RulesForRoom + """ + # It's important that RulesForRoom gets added to self._get_rules_for_room.cache + # before any lookup methods get called on it as otherwise there may be + # a race if invalidate_all gets called (which assumes its in the cache) + return RulesForRoom(self.hs, room_id, self._get_rules_for_room.cache) @defer.inlineCallbacks def action_for_event_by_user(self, event, context): + """Given an event and context, evaluate the push rules and return + the results + + Returns: + dict of user_id -> action + """ + rules_by_user = yield self._get_rules_for_event(event, context) actions_by_user = {} # None of these users can be peeking since this list of users comes # from the set of users in the room, so we know for sure they're all # actually in the room. - user_tuples = [ - (u, False) for u in self.rules_by_user.keys() - ] + user_tuples = [(u, False) for u in rules_by_user] filtered_by_user = yield filter_events_for_clients_context( self.store, user_tuples, [event], {event.event_id: context} @@ -86,7 +109,7 @@ class BulkPushRuleEvaluator: condition_cache = {} - for uid, rules in self.rules_by_user.items(): + for uid, rules in rules_by_user.iteritems(): display_name = None profile_info = room_members.get(uid) if profile_info: @@ -138,3 +161,240 @@ def _condition_checker(evaluator, conditions, uid, display_name, cache): return False return True + + +class RulesForRoom(object): + """Caches push rules for users in a room. + + This efficiently handles users joining/leaving the room by not invalidating + the entire cache for the room. + """ + + def __init__(self, hs, room_id, rules_for_room_cache): + """ + Args: + hs (HomeServer) + room_id (str) + rules_for_room_cache(Cache): The cache object that caches these + RoomsForUser objects. + """ + self.room_id = room_id + self.is_mine_id = hs.is_mine_id + self.store = hs.get_datastore() + + self.linearizer = Linearizer(name="rules_for_room") + + self.member_map = {} # event_id -> (user_id, state) + self.rules_by_user = {} # user_id -> rules + + # The last state group we updated the caches for. If the state_group of + # a new event comes along, we know that we can just return the cached + # result. + # On invalidation of the rules themselves (if the user changes them), + # we invalidate everything and set state_group to `object()` + self.state_group = object() + + # A sequence number to keep track of when we're allowed to update the + # cache. We bump the sequence number when we invalidate the cache. If + # the sequence number changes while we're calculating stuff we should + # not update the cache with it. + self.sequence = 0 + + # A cache of user_ids that we *know* aren't interesting, e.g. user_ids + # owned by AS's, or remote users, etc. (I.e. users we will never need to + # calculate push for) + # These never need to be invalidated as we will never set up push for + # them. + self.uninteresting_user_set = set() + + # We need to be clever on the invalidating caches callbacks, as + # otherwise the invalidation callback holds a reference to the object, + # potentially causing it to leak. + # To get around this we pass a function that on invalidations looks ups + # the RoomsForUser entry in the cache, rather than keeping a reference + # to self around in the callback. + self.invalidate_all_cb = _Invalidation(rules_for_room_cache, room_id) + + @defer.inlineCallbacks + def get_rules(self, event, context): + """Given an event context return the rules for all users who are + currently in the room. + """ + state_group = context.state_group + + with (yield self.linearizer.queue(())): + if state_group and self.state_group == state_group: + logger.debug("Using cached rules for %r", self.room_id) + defer.returnValue(self.rules_by_user) + + ret_rules_by_user = {} + missing_member_event_ids = {} + if state_group and self.state_group == context.prev_group: + # If we have a simple delta then we can reuse most of the previous + # results. + ret_rules_by_user = self.rules_by_user + current_state_ids = context.delta_ids + else: + current_state_ids = context.current_state_ids + + logger.debug( + "Looking for member changes in %r %r", state_group, current_state_ids + ) + + # Loop through to see which member events we've seen and have rules + # for and which we need to fetch + for key in current_state_ids: + typ, user_id = key + if typ != EventTypes.Member: + continue + + if user_id in self.uninteresting_user_set: + continue + + if not self.is_mine_id(user_id): + self.uninteresting_user_set.add(user_id) + continue + + if self.store.get_if_app_services_interested_in_user(user_id): + self.uninteresting_user_set.add(user_id) + continue + + event_id = current_state_ids[key] + + res = self.member_map.get(event_id, None) + if res: + user_id, state = res + if state == Membership.JOIN: + rules = self.rules_by_user.get(user_id, None) + if rules: + ret_rules_by_user[user_id] = rules + continue + + # If a user has left a room we remove their push rule. If they + # joined then we readd it later in _update_rules_with_member_event_ids + ret_rules_by_user.pop(user_id, None) + missing_member_event_ids[user_id] = event_id + + if missing_member_event_ids: + # If we have some memebr events we haven't seen, look them up + # and fetch push rules for them if appropriate. + logger.debug("Found new member events %r", missing_member_event_ids) + yield self._update_rules_with_member_event_ids( + ret_rules_by_user, missing_member_event_ids, state_group, event + ) + + if logger.isEnabledFor(logging.DEBUG): + logger.debug( + "Returning push rules for %r %r", + self.room_id, ret_rules_by_user.keys(), + ) + defer.returnValue(ret_rules_by_user) + + @defer.inlineCallbacks + def _update_rules_with_member_event_ids(self, ret_rules_by_user, member_event_ids, + state_group, event): + """Update the partially filled rules_by_user dict by fetching rules for + any newly joined users in the `member_event_ids` list. + + Args: + ret_rules_by_user (dict): Partiallly filled dict of push rules. Gets + updated with any new rules. + member_event_ids (list): List of event ids for membership events that + have happened since the last time we filled rules_by_user + state_group: The state group we are currently computing push rules + for. Used when updating the cache. + """ + sequence = self.sequence + + rows = yield self.store._simple_select_many_batch( + table="room_memberships", + column="event_id", + iterable=member_event_ids.values(), + retcols=('user_id', 'membership', 'event_id'), + keyvalues={}, + batch_size=500, + desc="_get_rules_for_member_event_ids", + ) + + members = { + row["event_id"]: (row["user_id"], row["membership"]) + for row in rows + } + + # If the event is a join event then it will be in current state evnts + # map but not in the DB, so we have to explicitly insert it. + if event.type == EventTypes.Member: + for event_id in member_event_ids.itervalues(): + if event_id == event.event_id: + members[event_id] = (event.state_key, event.membership) + + if logger.isEnabledFor(logging.DEBUG): + logger.debug("Found members %r: %r", self.room_id, members.values()) + + interested_in_user_ids = set( + user_id for user_id, membership in members.itervalues() + if membership == Membership.JOIN + ) + + logger.debug("Joined: %r", interested_in_user_ids) + + if_users_with_pushers = yield self.store.get_if_users_have_pushers( + interested_in_user_ids, + on_invalidate=self.invalidate_all_cb, + ) + + user_ids = set( + uid for uid, have_pusher in if_users_with_pushers.iteritems() if have_pusher + ) + + logger.debug("With pushers: %r", user_ids) + + users_with_receipts = yield self.store.get_users_with_read_receipts_in_room( + self.room_id, on_invalidate=self.invalidate_all_cb, + ) + + logger.debug("With receipts: %r", users_with_receipts) + + # any users with pushers must be ours: they have pushers + for uid in users_with_receipts: + if uid in interested_in_user_ids: + user_ids.add(uid) + + rules_by_user = yield self.store.bulk_get_push_rules( + user_ids, on_invalidate=self.invalidate_all_cb, + ) + + ret_rules_by_user.update( + item for item in rules_by_user.iteritems() if item[0] is not None + ) + + self.update_cache(sequence, members, ret_rules_by_user, state_group) + + def invalidate_all(self): + # Note: Don't hand this function directly to an invalidation callback + # as it keeps a reference to self and will stop this instance from being + # GC'd if it gets dropped from the rules_to_user cache. Instead use + # `self.invalidate_all_cb` + logger.debug("Invalidating RulesForRoom for %r", self.room_id) + self.sequence += 1 + self.state_group = object() + self.member_map = {} + self.rules_by_user = {} + + def update_cache(self, sequence, members, rules_by_user, state_group): + if sequence == self.sequence: + self.member_map.update(members) + self.rules_by_user = rules_by_user + self.state_group = state_group + + +class _Invalidation(namedtuple("_Invalidation", ("cache", "room_id"))): + # We rely on _CacheContext implementing __eq__ and __hash__ sensibly, + # which namedtuple does for us (i.e. two _CacheContext are the same if + # their caches and keys match). This is important in particular to + # dedupe when we add callbacks to lru cache nodes, otherwise the number + # of callbacks would grow. + def __call__(self): + rules = self.cache.get(self.room_id, None, update_metrics=False) + if rules: + rules.invalidate_all() diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py index c7afd11111..a69dda7b09 100644 --- a/synapse/push/emailpusher.py +++ b/synapse/push/emailpusher.py @@ -21,7 +21,6 @@ import logging from synapse.util.metrics import Measure from synapse.util.logcontext import LoggingContext -from mailer import Mailer logger = logging.getLogger(__name__) @@ -56,8 +55,10 @@ class EmailPusher(object): This shares quite a bit of code with httpusher: it would be good to factor out the common parts """ - def __init__(self, hs, pusherdict): + def __init__(self, hs, pusherdict, mailer): self.hs = hs + self.mailer = mailer + self.store = self.hs.get_datastore() self.clock = self.hs.get_clock() self.pusher_id = pusherdict['id'] @@ -73,16 +74,6 @@ class EmailPusher(object): self.processing = False - if self.hs.config.email_enable_notifs: - if 'data' in pusherdict and 'brand' in pusherdict['data']: - app_name = pusherdict['data']['brand'] - else: - app_name = self.hs.config.email_app_name - - self.mailer = Mailer(self.hs, app_name) - else: - self.mailer = None - @defer.inlineCallbacks def on_started(self): if self.mailer is not None: diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index f83aa7625c..b5cd9b426a 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -78,23 +78,17 @@ ALLOWED_ATTRS = { class Mailer(object): - def __init__(self, hs, app_name): + def __init__(self, hs, app_name, notif_template_html, notif_template_text): self.hs = hs + self.notif_template_html = notif_template_html + self.notif_template_text = notif_template_text + self.store = self.hs.get_datastore() self.macaroon_gen = self.hs.get_macaroon_generator() self.state_handler = self.hs.get_state_handler() - loader = jinja2.FileSystemLoader(self.hs.config.email_template_dir) self.app_name = app_name + logger.info("Created Mailer for app_name %s" % app_name) - env = jinja2.Environment(loader=loader) - env.filters["format_ts"] = format_ts_filter - env.filters["mxc_to_http"] = self.mxc_to_http_filter - self.notif_template_html = env.get_template( - self.hs.config.email_notif_template_html - ) - self.notif_template_text = env.get_template( - self.hs.config.email_notif_template_text - ) @defer.inlineCallbacks def send_notification_mail(self, app_id, user_id, email_address, @@ -481,28 +475,6 @@ class Mailer(object): urllib.urlencode(params), ) - def mxc_to_http_filter(self, value, width, height, resize_method="crop"): - if value[0:6] != "mxc://": - return "" - - serverAndMediaId = value[6:] - fragment = None - if '#' in serverAndMediaId: - (serverAndMediaId, fragment) = serverAndMediaId.split('#', 1) - fragment = "#" + fragment - - params = { - "width": width, - "height": height, - "method": resize_method, - } - return "%s_matrix/media/v1/thumbnail/%s?%s%s" % ( - self.hs.config.public_baseurl, - serverAndMediaId, - urllib.urlencode(params), - fragment or "", - ) - def safe_markup(raw_html): return jinja2.Markup(bleach.linkify(bleach.clean( @@ -543,3 +515,52 @@ def string_ordinal_total(s): def format_ts_filter(value, format): return time.strftime(format, time.localtime(value / 1000)) + + +def load_jinja2_templates(config): + """Load the jinja2 email templates from disk + + Returns: + (notif_template_html, notif_template_text) + """ + logger.info("loading jinja2") + + loader = jinja2.FileSystemLoader(config.email_template_dir) + env = jinja2.Environment(loader=loader) + env.filters["format_ts"] = format_ts_filter + env.filters["mxc_to_http"] = _create_mxc_to_http_filter(config) + + notif_template_html = env.get_template( + config.email_notif_template_html + ) + notif_template_text = env.get_template( + config.email_notif_template_text + ) + + return notif_template_html, notif_template_text + + +def _create_mxc_to_http_filter(config): + def mxc_to_http_filter(value, width, height, resize_method="crop"): + if value[0:6] != "mxc://": + return "" + + serverAndMediaId = value[6:] + fragment = None + if '#' in serverAndMediaId: + (serverAndMediaId, fragment) = serverAndMediaId.split('#', 1) + fragment = "#" + fragment + + params = { + "width": width, + "height": height, + "method": resize_method, + } + return "%s_matrix/media/v1/thumbnail/%s?%s%s" % ( + config.public_baseurl, + serverAndMediaId, + urllib.urlencode(params), + fragment or "", + ) + + return mxc_to_http_filter diff --git a/synapse/push/pusher.py b/synapse/push/pusher.py index de9c33b936..491f27bded 100644 --- a/synapse/push/pusher.py +++ b/synapse/push/pusher.py @@ -26,22 +26,54 @@ logger = logging.getLogger(__name__) # process works fine) try: from synapse.push.emailpusher import EmailPusher + from synapse.push.mailer import Mailer, load_jinja2_templates except: pass -def create_pusher(hs, pusherdict): - logger.info("trying to create_pusher for %r", pusherdict) +class PusherFactory(object): + def __init__(self, hs): + self.hs = hs - PUSHER_TYPES = { - "http": HttpPusher, - } + self.pusher_types = { + "http": HttpPusher, + } - logger.info("email enable notifs: %r", hs.config.email_enable_notifs) - if hs.config.email_enable_notifs: - PUSHER_TYPES["email"] = EmailPusher - logger.info("defined email pusher type") + logger.info("email enable notifs: %r", hs.config.email_enable_notifs) + if hs.config.email_enable_notifs: + self.mailers = {} # app_name -> Mailer - if pusherdict['kind'] in PUSHER_TYPES: - logger.info("found pusher") - return PUSHER_TYPES[pusherdict['kind']](hs, pusherdict) + templates = load_jinja2_templates(hs.config) + self.notif_template_html, self.notif_template_text = templates + + self.pusher_types["email"] = self._create_email_pusher + + logger.info("defined email pusher type") + + def create_pusher(self, pusherdict): + logger.info("trying to create_pusher for %r", pusherdict) + + if pusherdict['kind'] in self.pusher_types: + logger.info("found pusher") + return self.pusher_types[pusherdict['kind']](self.hs, pusherdict) + + def _create_email_pusher(self, _hs, pusherdict): + app_name = self._app_name_from_pusherdict(pusherdict) + mailer = self.mailers.get(app_name) + if not mailer: + mailer = Mailer( + hs=self.hs, + app_name=app_name, + notif_template_html=self.notif_template_html, + notif_template_text=self.notif_template_text, + ) + self.mailers[app_name] = mailer + return EmailPusher(self.hs, pusherdict, mailer) + + def _app_name_from_pusherdict(self, pusherdict): + if 'data' in pusherdict and 'brand' in pusherdict['data']: + app_name = pusherdict['data']['brand'] + else: + app_name = self.hs.config.email_app_name + + return app_name diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index 3837be523d..43cb6e9c01 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -16,7 +16,7 @@ from twisted.internet import defer -import pusher +from .pusher import PusherFactory from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred from synapse.util.async import run_on_reactor @@ -28,6 +28,7 @@ logger = logging.getLogger(__name__) class PusherPool: def __init__(self, _hs): self.hs = _hs + self.pusher_factory = PusherFactory(_hs) self.start_pushers = _hs.config.start_pushers self.store = self.hs.get_datastore() self.clock = self.hs.get_clock() @@ -48,7 +49,7 @@ class PusherPool: # will then get pulled out of the database, # recreated, added and started: this means we have only one # code path adding pushers. - pusher.create_pusher(self.hs, { + self.pusher_factory.create_pusher({ "id": None, "user_name": user_id, "kind": kind, @@ -186,7 +187,7 @@ class PusherPool: logger.info("Starting %d pushers", len(pushers)) for pusherdict in pushers: try: - p = pusher.create_pusher(self.hs, pusherdict) + p = self.pusher_factory.create_pusher(pusherdict) except: logger.exception("Couldn't start a pusher: caught Exception") continue diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py index 4d4a435471..7687867aee 100644 --- a/synapse/replication/slave/storage/devices.py +++ b/synapse/replication/slave/storage/devices.py @@ -16,6 +16,7 @@ from ._base import BaseSlavedStore from ._slaved_id_tracker import SlavedIdTracker from synapse.storage import DataStore +from synapse.storage.end_to_end_keys import EndToEndKeyStore from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -45,6 +46,7 @@ class SlavedDeviceStore(BaseSlavedStore): _mark_as_sent_devices_by_remote_txn = ( DataStore._mark_as_sent_devices_by_remote_txn.__func__ ) + count_e2e_one_time_keys = EndToEndKeyStore.__dict__["count_e2e_one_time_keys"] def stream_positions(self): result = super(SlavedDeviceStore, self).stream_positions() diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py index fcaf58b93b..6cd3a843df 100644 --- a/synapse/replication/slave/storage/events.py +++ b/synapse/replication/slave/storage/events.py @@ -108,6 +108,8 @@ class SlavedEventStore(BaseSlavedStore): get_current_state_ids = ( StateStore.__dict__["get_current_state_ids"] ) + get_state_group_delta = DataStore.get_state_group_delta.__func__ + _get_joined_hosts_cache = RoomMemberStore.__dict__["_get_joined_hosts_cache"] has_room_changed_since = DataStore.has_room_changed_since.__func__ get_unread_push_actions_for_user_in_range_for_http = ( diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py index aa8d874f96..3d809d181b 100644 --- a/synapse/rest/__init__.py +++ b/synapse/rest/__init__.py @@ -51,6 +51,7 @@ from synapse.rest.client.v2_alpha import ( devices, thirdparty, sendtodevice, + user_directory, ) from synapse.http.server import JsonResource @@ -100,3 +101,4 @@ class ClientRestResource(JsonResource): devices.register_servlets(hs, client_resource) thirdparty.register_servlets(hs, client_resource) sendtodevice.register_servlets(hs, client_resource) + user_directory.register_servlets(hs, client_resource) diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index 771e127ab9..83e209d18f 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -192,6 +192,7 @@ class SyncRestServlet(RestServlet): "invite": invited, "leave": archived, }, + "device_one_time_keys_count": sync_result.device_one_time_keys_count, "next_batch": sync_result.next_batch.to_string(), } diff --git a/synapse/rest/client/v2_alpha/user_directory.py b/synapse/rest/client/v2_alpha/user_directory.py new file mode 100644 index 0000000000..17d3dffc8f --- /dev/null +++ b/synapse/rest/client/v2_alpha/user_directory.py @@ -0,0 +1,75 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 Vector Creations Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from twisted.internet import defer + +from synapse.api.errors import SynapseError +from synapse.http.servlet import RestServlet, parse_json_object_from_request +from ._base import client_v2_patterns + +logger = logging.getLogger(__name__) + + +class UserDirectorySearchRestServlet(RestServlet): + PATTERNS = client_v2_patterns("/user_directory/search$") + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + super(UserDirectorySearchRestServlet, self).__init__() + self.hs = hs + self.auth = hs.get_auth() + self.user_directory_handler = hs.get_user_directory_handler() + + @defer.inlineCallbacks + def on_POST(self, request): + """Searches for users in directory + + Returns: + dict of the form:: + + { + "limited": <bool>, # whether there were more results or not + "results": [ # Ordered by best match first + { + "user_id": <user_id>, + "display_name": <display_name>, + "avatar_url": <avatar_url> + } + ] + } + """ + yield self.auth.get_user_by_req(request, allow_guest=False) + body = parse_json_object_from_request(request) + + limit = body.get("limit", 10) + limit = min(limit, 50) + + try: + search_term = body["search_term"] + except: + raise SynapseError(400, "`search_term` is required field") + + results = yield self.user_directory_handler.search_users(search_term, limit) + + defer.returnValue((200, results)) + + +def register_servlets(hs, http_server): + UserDirectorySearchRestServlet(hs).register(http_server) diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index 99760d622f..c680fddab5 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -434,6 +434,8 @@ def _calc_og(tree, media_uri): for el in _iterate_over_text(tree.find("body"), *TAGS_TO_REMOVE) ) og['og:description'] = summarize_paragraphs(text_nodes) + else: + og['og:description'] = summarize_paragraphs([og['og:description']]) # TODO: delete the url downloads to stop diskfilling, # as we only ever cared about its OG diff --git a/synapse/server.py b/synapse/server.py index 12754c89ae..a38e5179e0 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -49,9 +49,11 @@ from synapse.handlers.events import EventHandler, EventStreamHandler from synapse.handlers.initial_sync import InitialSyncHandler from synapse.handlers.receipts import ReceiptsHandler from synapse.handlers.read_marker import ReadMarkerHandler +from synapse.handlers.user_directory import UserDirectoyHandler from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory from synapse.http.matrixfederationclient import MatrixFederationHttpClient from synapse.notifier import Notifier +from synapse.push.action_generator import ActionGenerator from synapse.push.pusherpool import PusherPool from synapse.rest.media.v1.media_repository import MediaRepository from synapse.state import StateHandler @@ -135,6 +137,8 @@ class HomeServer(object): 'macaroon_generator', 'tcp_replication', 'read_marker_handler', + 'action_generator', + 'user_directory_handler', ] def __init__(self, hostname, **kwargs): @@ -299,6 +303,12 @@ class HomeServer(object): def build_tcp_replication(self): raise NotImplementedError() + def build_action_generator(self): + return ActionGenerator(self) + + def build_user_directory_handler(self): + return UserDirectoyHandler(self) + def remove_pusher(self, app_id, push_key, user_id): return self.get_pusherpool().remove_pusher(app_id, push_key, user_id) diff --git a/synapse/state.py b/synapse/state.py index 02fee47f39..a98145598f 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -170,9 +170,7 @@ class StateHandler(object): latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) logger.debug("calling resolve_state_groups from get_current_user_in_room") entry = yield self.resolve_state_groups(room_id, latest_event_ids) - joined_users = yield self.store.get_joined_users_from_state( - room_id, entry.state_id, entry.state - ) + joined_users = yield self.store.get_joined_users_from_state(room_id, entry) defer.returnValue(joined_users) @defer.inlineCallbacks @@ -181,12 +179,21 @@ class StateHandler(object): latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) logger.debug("calling resolve_state_groups from get_current_hosts_in_room") entry = yield self.resolve_state_groups(room_id, latest_event_ids) - joined_hosts = yield self.store.get_joined_hosts( - room_id, entry.state_id, entry.state - ) + joined_hosts = yield self.store.get_joined_hosts(room_id, entry) defer.returnValue(joined_hosts) @defer.inlineCallbacks + def get_is_host_in_room(self, room_id, host, latest_event_ids=None): + if not latest_event_ids: + latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) + logger.debug("calling resolve_state_groups from get_is_host_in_room") + entry = yield self.resolve_state_groups(room_id, latest_event_ids) + is_host_joined = yield self.store.is_host_joined( + room_id, host, entry.state_id, entry.state + ) + defer.returnValue(is_host_joined) + + @defer.inlineCallbacks def compute_event_context(self, event, old_state=None): """Build an EventContext structure for the event. @@ -195,12 +202,12 @@ class StateHandler(object): Returns: synapse.events.snapshot.EventContext: """ - context = EventContext() if event.internal_metadata.is_outlier(): # If this is an outlier, then we know it shouldn't have any current # state. Certainly store.get_current_state won't return any, and # persisting the event won't store the state group. + context = EventContext() if old_state: context.prev_state_ids = { (s.type, s.state_key): s.event_id for s in old_state @@ -219,6 +226,7 @@ class StateHandler(object): defer.returnValue(context) if old_state: + context = EventContext() context.prev_state_ids = { (s.type, s.state_key): s.event_id for s in old_state } @@ -239,19 +247,13 @@ class StateHandler(object): defer.returnValue(context) logger.debug("calling resolve_state_groups from compute_event_context") - if event.is_state(): - entry = yield self.resolve_state_groups( - event.room_id, [e for e, _ in event.prev_events], - event_type=event.type, - state_key=event.state_key, - ) - else: - entry = yield self.resolve_state_groups( - event.room_id, [e for e, _ in event.prev_events], - ) + entry = yield self.resolve_state_groups( + event.room_id, [e for e, _ in event.prev_events], + ) curr_state = entry.state + context = EventContext() context.prev_state_ids = curr_state if event.is_state(): context.state_group = self.store.get_next_state_group() @@ -264,11 +266,14 @@ class StateHandler(object): context.current_state_ids = dict(context.prev_state_ids) context.current_state_ids[key] = event.event_id - context.prev_group = entry.prev_group - context.delta_ids = entry.delta_ids - if context.delta_ids is not None: - context.delta_ids = dict(context.delta_ids) - context.delta_ids[key] = event.event_id + if entry.state_group: + context.prev_group = entry.state_group + context.delta_ids = { + key: event.event_id + } + elif entry.prev_group: + context.prev_group = entry.prev_group + context.delta_ids = entry.delta_ids else: if entry.state_group is None: entry.state_group = self.store.get_next_state_group() @@ -284,7 +289,7 @@ class StateHandler(object): @defer.inlineCallbacks @log_function - def resolve_state_groups(self, room_id, event_ids, event_type=None, state_key=""): + def resolve_state_groups(self, room_id, event_ids): """ Given a list of event_ids this method fetches the state at each event, resolves conflicts between them and returns them. @@ -309,11 +314,13 @@ class StateHandler(object): if len(group_names) == 1: name, state_list = state_groups_ids.items().pop() + prev_group, delta_ids = yield self.store.get_state_group_delta(name) + defer.returnValue(_StateCacheEntry( state=state_list, state_group=name, - prev_group=name, - delta_ids={}, + prev_group=prev_group, + delta_ids=delta_ids, )) with (yield self.resolve_linearizer.queue(group_names)): @@ -366,11 +373,11 @@ class StateHandler(object): prev_group = None delta_ids = None - for old_group, old_ids in state_groups_ids.items(): - if not set(new_state.iterkeys()) - set(old_ids.iterkeys()): + for old_group, old_ids in state_groups_ids.iteritems(): + if not set(new_state) - set(old_ids): n_delta_ids = { k: v - for k, v in new_state.items() + for k, v in new_state.iteritems() if old_ids.get(k) != v } if not delta_ids or len(n_delta_ids) < len(delta_ids): diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 349f96e24b..5e72985cda 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -49,6 +49,7 @@ from .tags import TagsStore from .account_data import AccountDataStore from .openid import OpenIdStore from .client_ips import ClientIpStore +from .user_directory import UserDirectoryStore from .util.id_generators import IdGenerator, StreamIdGenerator, ChainedIdGenerator from .engines import PostgresEngine @@ -86,6 +87,7 @@ class DataStore(RoomMemberStore, RoomStore, ClientIpStore, DeviceStore, DeviceInboxStore, + UserDirectoryStore, ): def __init__(self, db_conn, hs): @@ -221,6 +223,18 @@ class DataStore(RoomMemberStore, RoomStore, "DeviceListFederationStreamChangeCache", device_list_max, ) + curr_state_delta_prefill, min_curr_state_delta_id = self._get_cache_dict( + db_conn, "current_state_delta_stream", + entity_column="room_id", + stream_column="stream_id", + max_value=events_max, # As we share the stream id with events token + limit=1000, + ) + self._curr_state_delta_stream_cache = StreamChangeCache( + "_curr_state_delta_stream_cache", min_curr_state_delta_id, + prefilled_cache=curr_state_delta_prefill, + ) + cur = LoggingTransaction( db_conn.cursor(), name="_find_stream_orderings_for_times_txn", diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index f214b9d4c4..51730a88bf 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -438,6 +438,11 @@ class SQLBaseStore(object): txn.execute(sql, vals) + def _simple_insert_many(self, table, values, desc): + return self.runInteraction( + desc, self._simple_insert_many_txn, table, values + ) + @staticmethod def _simple_insert_many_txn(txn, table, values): if not values: diff --git a/synapse/storage/appservice.py b/synapse/storage/appservice.py index 514570561f..532df736a5 100644 --- a/synapse/storage/appservice.py +++ b/synapse/storage/appservice.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +import re import simplejson as json from twisted.internet import defer @@ -36,16 +37,31 @@ class ApplicationServiceStore(SQLBaseStore): hs.config.app_service_config_files ) + # We precompie a regex constructed from all the regexes that the AS's + # have registered for exclusive users. + exclusive_user_regexes = [ + regex.pattern + for service in self.services_cache + for regex in service.get_exlusive_user_regexes() + ] + if exclusive_user_regexes: + exclusive_user_regex = "|".join("(" + r + ")" for r in exclusive_user_regexes) + self.exclusive_user_regex = re.compile(exclusive_user_regex) + else: + # We handle this case specially otherwise the constructed regex + # will always match + self.exclusive_user_regex = None + def get_app_services(self): return self.services_cache def get_if_app_services_interested_in_user(self, user_id): - """Check if the user is one associated with an app service + """Check if the user is one associated with an app service (exclusively) """ - for service in self.services_cache: - if service.is_interested_in_user(user_id): - return True - return False + if self.exclusive_user_regex: + return bool(self.exclusive_user_regex.match(user_id)) + else: + return False def get_app_service_by_user_id(self, user_id): """Retrieve an application service from their user ID. diff --git a/synapse/storage/client_ips.py b/synapse/storage/client_ips.py index 747d2df622..014ab635b7 100644 --- a/synapse/storage/client_ips.py +++ b/synapse/storage/client_ips.py @@ -20,6 +20,8 @@ from twisted.internet import defer from ._base import Cache from . import background_updates +import os + logger = logging.getLogger(__name__) # Number of msec of granularity to store the user IP 'last seen' time. Smaller @@ -28,12 +30,15 @@ logger = logging.getLogger(__name__) LAST_SEEN_GRANULARITY = 120 * 1000 +CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1)) + + class ClientIpStore(background_updates.BackgroundUpdateStore): def __init__(self, hs): self.client_ip_last_seen = Cache( name="client_ip_last_seen", keylen=4, - max_entries=5000, + max_entries=50000 * CACHE_SIZE_FACTOR, ) super(ClientIpStore, self).__init__(hs) diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index d9936c88bb..bb27fd1f70 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -368,7 +368,7 @@ class DeviceStore(SQLBaseStore): prev_sent_id_sql = """ SELECT coalesce(max(stream_id), 0) as stream_id - FROM device_lists_outbound_pokes + FROM device_lists_outbound_last_success WHERE destination = ? AND user_id = ? AND stream_id <= ? """ @@ -510,32 +510,43 @@ class DeviceStore(SQLBaseStore): ) def _mark_as_sent_devices_by_remote_txn(self, txn, destination, stream_id): - # First we DELETE all rows such that only the latest row for each - # (destination, user_id is left. We do this by selecting first and - # deleting. + # We update the device_lists_outbound_last_success with the successfully + # poked users. We do the join to see which users need to be inserted and + # which updated. sql = """ - SELECT user_id, coalesce(max(stream_id), 0) FROM device_lists_outbound_pokes - WHERE destination = ? AND stream_id <= ? + SELECT user_id, coalesce(max(o.stream_id), 0), (max(s.stream_id) IS NOT NULL) + FROM device_lists_outbound_pokes as o + LEFT JOIN device_lists_outbound_last_success as s + USING (destination, user_id) + WHERE destination = ? AND o.stream_id <= ? GROUP BY user_id - HAVING count(*) > 1 """ txn.execute(sql, (destination, stream_id,)) rows = txn.fetchall() sql = """ - DELETE FROM device_lists_outbound_pokes - WHERE destination = ? AND user_id = ? AND stream_id < ? + UPDATE device_lists_outbound_last_success + SET stream_id = ? + WHERE destination = ? AND user_id = ? """ txn.executemany( - sql, ((destination, row[0], row[1],) for row in rows) + sql, ((row[1], destination, row[0],) for row in rows if row[2]) ) - # Mark everything that is left as sent sql = """ - UPDATE device_lists_outbound_pokes SET sent = ? + INSERT INTO device_lists_outbound_last_success + (destination, user_id, stream_id) VALUES (?, ?, ?) + """ + txn.executemany( + sql, ((destination, row[0], row[1],) for row in rows if not row[2]) + ) + + # Delete all sent outbound pokes + sql = """ + DELETE FROM device_lists_outbound_pokes WHERE destination = ? AND stream_id <= ? """ - txn.execute(sql, (True, destination, stream_id,)) + txn.execute(sql, (destination, stream_id,)) @defer.inlineCallbacks def get_user_whose_devices_changed(self, from_key): @@ -670,6 +681,14 @@ class DeviceStore(SQLBaseStore): ) ) + # Since we've deleted unsent deltas, we need to remove the entry + # of last successful sent so that the prev_ids are correctly set. + sql = """ + DELETE FROM device_lists_outbound_last_success + WHERE destination = ? AND user_id = ? + """ + txn.executemany(sql, ((row[0], row[1]) for row in rows)) + logger.info("Pruned %d device list outbound pokes", txn.rowcount) return self.runInteraction( diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py index e00f31da2b..2cebb203c6 100644 --- a/synapse/storage/end_to_end_keys.py +++ b/synapse/storage/end_to_end_keys.py @@ -185,8 +185,8 @@ class EndToEndKeyStore(SQLBaseStore): for algorithm, key_id, json_bytes in new_keys ], ) - txn.call_after( - self.count_e2e_one_time_keys.invalidate, (user_id, device_id,) + self._invalidate_cache_and_stream( + txn, self.count_e2e_one_time_keys, (user_id, device_id,) ) yield self.runInteraction( "add_e2e_one_time_keys_insert", _add_e2e_one_time_keys @@ -237,24 +237,29 @@ class EndToEndKeyStore(SQLBaseStore): ) for user_id, device_id, algorithm, key_id in delete: txn.execute(sql, (user_id, device_id, algorithm, key_id)) - txn.call_after( - self.count_e2e_one_time_keys.invalidate, (user_id, device_id,) + self._invalidate_cache_and_stream( + txn, self.count_e2e_one_time_keys, (user_id, device_id,) ) return result return self.runInteraction( "claim_e2e_one_time_keys", _claim_e2e_one_time_keys ) - @defer.inlineCallbacks def delete_e2e_keys_by_device(self, user_id, device_id): - yield self._simple_delete( - table="e2e_device_keys_json", - keyvalues={"user_id": user_id, "device_id": device_id}, - desc="delete_e2e_device_keys_by_device" - ) - yield self._simple_delete( - table="e2e_one_time_keys_json", - keyvalues={"user_id": user_id, "device_id": device_id}, - desc="delete_e2e_one_time_keys_by_device" + def delete_e2e_keys_by_device_txn(txn): + self._simple_delete_txn( + txn, + table="e2e_device_keys_json", + keyvalues={"user_id": user_id, "device_id": device_id}, + ) + self._simple_delete_txn( + txn, + table="e2e_one_time_keys_json", + keyvalues={"user_id": user_id, "device_id": device_id}, + ) + self._invalidate_cache_and_stream( + txn, self.count_e2e_one_time_keys, (user_id, device_id,) + ) + return self.runInteraction( + "delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn ) - self.count_e2e_one_time_keys.invalidate((user_id, device_id,)) diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py index 519059c306..e8133de2fa 100644 --- a/synapse/storage/event_federation.py +++ b/synapse/storage/event_federation.py @@ -37,25 +37,55 @@ class EventFederationStore(SQLBaseStore): and backfilling from another server respectively. """ + EVENT_AUTH_STATE_ONLY = "event_auth_state_only" + def __init__(self, hs): super(EventFederationStore, self).__init__(hs) + self.register_background_update_handler( + self.EVENT_AUTH_STATE_ONLY, + self._background_delete_non_state_event_auth, + ) + hs.get_clock().looping_call( self._delete_old_forward_extrem_cache, 60 * 60 * 1000 ) - def get_auth_chain(self, event_ids): - return self.get_auth_chain_ids(event_ids).addCallback(self._get_events) + def get_auth_chain(self, event_ids, include_given=False): + """Get auth events for given event_ids. The events *must* be state events. + + Args: + event_ids (list): state events + include_given (bool): include the given events in result + + Returns: + list of events + """ + return self.get_auth_chain_ids( + event_ids, include_given=include_given, + ).addCallback(self._get_events) + + def get_auth_chain_ids(self, event_ids, include_given=False): + """Get auth events for given event_ids. The events *must* be state events. + + Args: + event_ids (list): state events + include_given (bool): include the given events in result - def get_auth_chain_ids(self, event_ids): + Returns: + list of event_ids + """ return self.runInteraction( "get_auth_chain_ids", self._get_auth_chain_ids_txn, - event_ids + event_ids, include_given ) - def _get_auth_chain_ids_txn(self, txn, event_ids): - results = set() + def _get_auth_chain_ids_txn(self, txn, event_ids, include_given): + if include_given: + results = set(event_ids) + else: + results = set() base_sql = ( "SELECT auth_id FROM event_auth WHERE event_id IN (%s)" @@ -504,3 +534,52 @@ class EventFederationStore(SQLBaseStore): txn.execute(query, (room_id,)) txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,)) + + @defer.inlineCallbacks + def _background_delete_non_state_event_auth(self, progress, batch_size): + def delete_event_auth(txn): + target_min_stream_id = progress.get("target_min_stream_id_inclusive") + max_stream_id = progress.get("max_stream_id_exclusive") + + if not target_min_stream_id or not max_stream_id: + txn.execute("SELECT COALESCE(MIN(stream_ordering), 0) FROM events") + rows = txn.fetchall() + target_min_stream_id = rows[0][0] + + txn.execute("SELECT COALESCE(MAX(stream_ordering), 0) FROM events") + rows = txn.fetchall() + max_stream_id = rows[0][0] + + min_stream_id = max_stream_id - batch_size + + sql = """ + DELETE FROM event_auth + WHERE event_id IN ( + SELECT event_id FROM events + LEFT JOIN state_events USING (room_id, event_id) + WHERE ? <= stream_ordering AND stream_ordering < ? + AND state_key IS null + ) + """ + + txn.execute(sql, (min_stream_id, max_stream_id,)) + + new_progress = { + "target_min_stream_id_inclusive": target_min_stream_id, + "max_stream_id_exclusive": min_stream_id, + } + + self._background_update_progress_txn( + txn, self.EVENT_AUTH_STATE_ONLY, new_progress + ) + + return min_stream_id >= target_min_stream_id + + result = yield self.runInteraction( + self.EVENT_AUTH_STATE_ONLY, delete_event_auth + ) + + if not result: + yield self._end_background_update(self.EVENT_AUTH_STATE_ONLY) + + defer.returnValue(batch_size) diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 73283eb4c7..c80d181fc7 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -648,9 +648,10 @@ class EventsStore(SQLBaseStore): list of the event ids which are the forward extremities. """ - self._update_current_state_txn(txn, current_state_for_room) - max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering + + self._update_current_state_txn(txn, current_state_for_room, max_stream_order) + self._update_forward_extremities_txn( txn, new_forward_extremities=new_forward_extremeties, @@ -713,7 +714,7 @@ class EventsStore(SQLBaseStore): backfilled=backfilled, ) - def _update_current_state_txn(self, txn, state_delta_by_room): + def _update_current_state_txn(self, txn, state_delta_by_room, max_stream_order): for room_id, current_state_tuple in state_delta_by_room.iteritems(): to_delete, to_insert, _ = current_state_tuple txn.executemany( @@ -735,6 +736,29 @@ class EventsStore(SQLBaseStore): ], ) + state_deltas = {key: None for key in to_delete} + state_deltas.update(to_insert) + + self._simple_insert_many_txn( + txn, + table="current_state_delta_stream", + values=[ + { + "stream_id": max_stream_order, + "room_id": room_id, + "type": key[0], + "state_key": key[1], + "event_id": ev_id, + "prev_event_id": to_delete.get(key, None), + } + for key, ev_id in state_deltas.iteritems() + ] + ) + + self._curr_state_delta_stream_cache.entity_has_changed( + room_id, max_stream_order, + ) + # Invalidate the various caches # Figure out the changes of membership to invalidate the @@ -743,11 +767,7 @@ class EventsStore(SQLBaseStore): # and which we have added, then we invlidate the caches for all # those users. members_changed = set( - state_key for ev_type, state_key in to_delete.iterkeys() - if ev_type == EventTypes.Member - ) - members_changed.update( - state_key for ev_type, state_key in to_insert.iterkeys() + state_key for ev_type, state_key in state_deltas if ev_type == EventTypes.Member ) @@ -1120,6 +1140,7 @@ class EventsStore(SQLBaseStore): } for event, _ in events_and_contexts for auth_id, _ in event.auth_events + if event.is_state() ], ) diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index 6e623843d5..eaba699e29 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -25,7 +25,7 @@ logger = logging.getLogger(__name__) # Remember to update this number every time a change is made to database # schema files, so the users will be informed on server restarts. -SCHEMA_VERSION = 41 +SCHEMA_VERSION = 42 dir_path = os.path.abspath(os.path.dirname(__file__)) diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py index 0a819d32c5..8758b1c0c7 100644 --- a/synapse/storage/push_rule.py +++ b/synapse/storage/push_rule.py @@ -49,7 +49,7 @@ def _load_rules(rawrules, enabled_map): class PushRuleStore(SQLBaseStore): - @cachedInlineCallbacks() + @cachedInlineCallbacks(max_entries=5000) def get_push_rules_for_user(self, user_id): rows = yield self._simple_select_list( table="push_rules", @@ -73,7 +73,7 @@ class PushRuleStore(SQLBaseStore): defer.returnValue(rules) - @cachedInlineCallbacks() + @cachedInlineCallbacks(max_entries=5000) def get_push_rules_enabled_for_user(self, user_id): results = yield self._simple_select_list( table="push_rules_enable", diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py index efb90c3c91..f42b8014c7 100644 --- a/synapse/storage/receipts.py +++ b/synapse/storage/receipts.py @@ -45,7 +45,9 @@ class ReceiptsStore(SQLBaseStore): return # Returns an ObservableDeferred - res = self.get_users_with_read_receipts_in_room.cache.get((room_id,), None) + res = self.get_users_with_read_receipts_in_room.cache.get( + room_id, None, update_metrics=False, + ) if res: if isinstance(res, defer.Deferred) and res.called: diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 0829ae5bee..8656455f6e 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -18,6 +18,7 @@ from twisted.internet import defer from collections import namedtuple from ._base import SQLBaseStore +from synapse.util.async import Linearizer from synapse.util.caches import intern_string from synapse.util.caches.descriptors import cached, cachedInlineCallbacks from synapse.util.stringutils import to_ascii @@ -392,7 +393,8 @@ class RoomMemberStore(SQLBaseStore): context=context, ) - def get_joined_users_from_state(self, room_id, state_group, state_ids): + def get_joined_users_from_state(self, room_id, state_entry): + state_group = state_entry.state_group if not state_group: # If state_group is None it means it has yet to be assigned a # state group, i.e. we need to make sure that calls with a state_group @@ -401,7 +403,7 @@ class RoomMemberStore(SQLBaseStore): state_group = object() return self._get_joined_users_from_context( - room_id, state_group, state_ids, + room_id, state_group, state_entry.state, context=state_entry, ) @cachedInlineCallbacks(num_args=2, cache_context=True, iterable=True, @@ -534,7 +536,8 @@ class RoomMemberStore(SQLBaseStore): defer.returnValue(False) - def get_joined_hosts(self, room_id, state_group, state_ids): + def get_joined_hosts(self, room_id, state_entry): + state_group = state_entry.state_group if not state_group: # If state_group is None it means it has yet to be assigned a # state group, i.e. we need to make sure that calls with a state_group @@ -543,33 +546,20 @@ class RoomMemberStore(SQLBaseStore): state_group = object() return self._get_joined_hosts( - room_id, state_group, state_ids + room_id, state_group, state_entry.state, state_entry=state_entry, ) @cachedInlineCallbacks(num_args=2, max_entries=10000, iterable=True) - def _get_joined_hosts(self, room_id, state_group, current_state_ids): + # @defer.inlineCallbacks + def _get_joined_hosts(self, room_id, state_group, current_state_ids, state_entry): # We don't use `state_group`, its there so that we can cache based # on it. However, its important that its never None, since two current_state's # with a state_group of None are likely to be different. # See bulk_get_push_rules_for_room for how we work around this. assert state_group is not None - joined_hosts = set() - for etype, state_key in current_state_ids: - if etype == EventTypes.Member: - try: - host = get_domain_from_id(state_key) - except: - logger.warn("state_key not user_id: %s", state_key) - continue - - if host in joined_hosts: - continue - - event_id = current_state_ids[(etype, state_key)] - event = yield self.get_event(event_id, allow_none=True) - if event and event.content["membership"] == Membership.JOIN: - joined_hosts.add(intern_string(host)) + cache = self._get_joined_hosts_cache(room_id) + joined_hosts = yield cache.get_destinations(state_entry) defer.returnValue(joined_hosts) @@ -647,3 +637,75 @@ class RoomMemberStore(SQLBaseStore): yield self._end_background_update(_MEMBERSHIP_PROFILE_UPDATE_NAME) defer.returnValue(result) + + @cached(max_entries=10000, iterable=True) + def _get_joined_hosts_cache(self, room_id): + return _JoinedHostsCache(self, room_id) + + +class _JoinedHostsCache(object): + """Cache for joined hosts in a room that is optimised to handle updates + via state deltas. + """ + + def __init__(self, store, room_id): + self.store = store + self.room_id = room_id + + self.hosts_to_joined_users = {} + + self.state_group = object() + + self.linearizer = Linearizer("_JoinedHostsCache") + + self._len = 0 + + @defer.inlineCallbacks + def get_destinations(self, state_entry): + """Get set of destinations for a state entry + + Args: + state_entry(synapse.state._StateCacheEntry) + """ + if state_entry.state_group == self.state_group: + defer.returnValue(frozenset(self.hosts_to_joined_users)) + + with (yield self.linearizer.queue(())): + if state_entry.state_group == self.state_group: + pass + elif state_entry.prev_group == self.state_group: + for (typ, state_key), event_id in state_entry.delta_ids.iteritems(): + if typ != EventTypes.Member: + continue + + host = intern_string(get_domain_from_id(state_key)) + user_id = state_key + known_joins = self.hosts_to_joined_users.setdefault(host, set()) + + event = yield self.store.get_event(event_id) + if event.membership == Membership.JOIN: + known_joins.add(user_id) + else: + known_joins.discard(user_id) + + if not known_joins: + self.hosts_to_joined_users.pop(host, None) + else: + joined_users = yield self.store.get_joined_users_from_state( + self.room_id, state_entry, + ) + + self.hosts_to_joined_users = {} + for user_id in joined_users: + host = intern_string(get_domain_from_id(user_id)) + self.hosts_to_joined_users.setdefault(host, set()).add(user_id) + + if state_entry.state_group: + self.state_group = state_entry.state_group + else: + self.state_group = object() + self._len = sum(len(v) for v in self.hosts_to_joined_users.itervalues()) + defer.returnValue(frozenset(self.hosts_to_joined_users)) + + def __len__(self): + return self._len diff --git a/synapse/storage/schema/delta/42/current_state_delta.sql b/synapse/storage/schema/delta/42/current_state_delta.sql new file mode 100644 index 0000000000..d28851aff8 --- /dev/null +++ b/synapse/storage/schema/delta/42/current_state_delta.sql @@ -0,0 +1,26 @@ +/* Copyright 2017 Vector Creations Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +CREATE TABLE current_state_delta_stream ( + stream_id BIGINT NOT NULL, + room_id TEXT NOT NULL, + type TEXT NOT NULL, + state_key TEXT NOT NULL, + event_id TEXT, -- Is null if the key was removed + prev_event_id TEXT -- Is null if the key was added +); + +CREATE INDEX current_state_delta_stream_idx ON current_state_delta_stream(stream_id); diff --git a/synapse/storage/schema/delta/42/device_list_last_id.sql b/synapse/storage/schema/delta/42/device_list_last_id.sql new file mode 100644 index 0000000000..9ab8c14fa3 --- /dev/null +++ b/synapse/storage/schema/delta/42/device_list_last_id.sql @@ -0,0 +1,33 @@ +/* Copyright 2017 Vector Creations Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +-- Table of last stream_id that we sent to destination for user_id. This is +-- used to fill out the `prev_id` fields of outbound device list updates. +CREATE TABLE device_lists_outbound_last_success ( + destination TEXT NOT NULL, + user_id TEXT NOT NULL, + stream_id BIGINT NOT NULL +); + +INSERT INTO device_lists_outbound_last_success + SELECT destination, user_id, coalesce(max(stream_id), 0) as stream_id + FROM device_lists_outbound_pokes + WHERE sent = (1 = 1) -- sqlite doesn't have inbuilt boolean values + GROUP BY destination, user_id; + +CREATE INDEX device_lists_outbound_last_success_idx ON device_lists_outbound_last_success( + destination, user_id, stream_id +); diff --git a/synapse/storage/schema/delta/42/event_auth_state_only.sql b/synapse/storage/schema/delta/42/event_auth_state_only.sql new file mode 100644 index 0000000000..b8821ac759 --- /dev/null +++ b/synapse/storage/schema/delta/42/event_auth_state_only.sql @@ -0,0 +1,17 @@ +/* Copyright 2017 Vector Creations Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +INSERT INTO background_updates (update_name, progress_json) VALUES + ('event_auth_state_only', '{}'); diff --git a/synapse/storage/schema/delta/42/user_dir.py b/synapse/storage/schema/delta/42/user_dir.py new file mode 100644 index 0000000000..ea6a18196d --- /dev/null +++ b/synapse/storage/schema/delta/42/user_dir.py @@ -0,0 +1,84 @@ +# Copyright 2017 Vector Creations Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from synapse.storage.prepare_database import get_statements +from synapse.storage.engines import PostgresEngine, Sqlite3Engine + +logger = logging.getLogger(__name__) + + +BOTH_TABLES = """ +CREATE TABLE user_directory_stream_pos ( + Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, -- Makes sure this table only has one row. + stream_id BIGINT, + CHECK (Lock='X') +); + +INSERT INTO user_directory_stream_pos (stream_id) VALUES (null); + +CREATE TABLE user_directory ( + user_id TEXT NOT NULL, + room_id TEXT NOT NULL, -- A room_id that we know the user is joined to + display_name TEXT, + avatar_url TEXT +); + +CREATE INDEX user_directory_room_idx ON user_directory(room_id); +CREATE UNIQUE INDEX user_directory_user_idx ON user_directory(user_id); + +CREATE TABLE users_in_pubic_room ( + user_id TEXT NOT NULL, + room_id TEXT NOT NULL -- A room_id that we know is public +); + +CREATE INDEX users_in_pubic_room_room_idx ON users_in_pubic_room(room_id); +CREATE UNIQUE INDEX users_in_pubic_room_user_idx ON users_in_pubic_room(user_id); +""" + + +POSTGRES_TABLE = """ +CREATE TABLE user_directory_search ( + user_id TEXT NOT NULL, + vector tsvector +); + +CREATE INDEX user_directory_search_fts_idx ON user_directory_search USING gin(vector); +CREATE UNIQUE INDEX user_directory_search_user_idx ON user_directory_search(user_id); +""" + + +SQLITE_TABLE = """ +CREATE VIRTUAL TABLE user_directory_search + USING fts4 ( user_id, value ); +""" + + +def run_create(cur, database_engine, *args, **kwargs): + for statement in get_statements(BOTH_TABLES.splitlines()): + cur.execute(statement) + + if isinstance(database_engine, PostgresEngine): + for statement in get_statements(POSTGRES_TABLE.splitlines()): + cur.execute(statement) + elif isinstance(database_engine, Sqlite3Engine): + for statement in get_statements(SQLITE_TABLE.splitlines()): + cur.execute(statement) + else: + raise Exception("Unrecognized database engine") + + +def run_upgrade(*args, **kwargs): + pass diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 85acf2ad1e..c3eecbe824 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -98,6 +98,45 @@ class StateStore(SQLBaseStore): _get_current_state_ids_txn, ) + def get_state_group_delta(self, state_group): + """Given a state group try to return a previous group and a delta between + the old and the new. + + Returns: + (prev_group, delta_ids), where both may be None. + """ + def _get_state_group_delta_txn(txn): + prev_group = self._simple_select_one_onecol_txn( + txn, + table="state_group_edges", + keyvalues={ + "state_group": state_group, + }, + retcol="prev_state_group", + allow_none=True, + ) + + if not prev_group: + return None, None + + delta_ids = self._simple_select_list_txn( + txn, + table="state_groups_state", + keyvalues={ + "state_group": state_group, + }, + retcols=("type", "state_key", "event_id",) + ) + + return prev_group, { + (row["type"], row["state_key"]): row["event_id"] + for row in delta_ids + } + return self.runInteraction( + "get_state_group_delta", + _get_state_group_delta_txn, + ) + @defer.inlineCallbacks def get_state_groups_ids(self, room_id, event_ids): if not event_ids: @@ -563,20 +602,22 @@ class StateStore(SQLBaseStore): where a `state_key` of `None` matches all state_keys for the `type`. """ - is_all, state_dict_ids = self._state_group_cache.get(group) + is_all, known_absent, state_dict_ids = self._state_group_cache.get(group) type_to_key = {} missing_types = set() + for typ, state_key in types: + key = (typ, state_key) if state_key is None: type_to_key[typ] = None - missing_types.add((typ, state_key)) + missing_types.add(key) else: if type_to_key.get(typ, object()) is not None: type_to_key.setdefault(typ, set()).add(state_key) - if (typ, state_key) not in state_dict_ids: - missing_types.add((typ, state_key)) + if key not in state_dict_ids and key not in known_absent: + missing_types.add(key) sentinel = object() @@ -590,7 +631,7 @@ class StateStore(SQLBaseStore): return True return False - got_all = not (missing_types or types is None) + got_all = is_all or not missing_types return { k: v for k, v in state_dict_ids.iteritems() @@ -607,7 +648,7 @@ class StateStore(SQLBaseStore): Args: group: The state group to lookup """ - is_all, state_dict_ids = self._state_group_cache.get(group) + is_all, _, state_dict_ids = self._state_group_cache.get(group) return state_dict_ids, is_all @@ -624,7 +665,7 @@ class StateStore(SQLBaseStore): missing_groups = [] if types is not None: for group in set(groups): - state_dict_ids, missing_types, got_all = self._get_some_state_from_cache( + state_dict_ids, _, got_all = self._get_some_state_from_cache( group, types ) results[group] = state_dict_ids @@ -653,19 +694,7 @@ class StateStore(SQLBaseStore): # Now we want to update the cache with all the things we fetched # from the database. for group, group_state_dict in group_to_state_dict.iteritems(): - if types: - # We delibrately put key -> None mappings into the cache to - # cache absence of the key, on the assumption that if we've - # explicitly asked for some types then we will probably ask - # for them again. - state_dict = { - (intern_string(etype), intern_string(state_key)): None - for (etype, state_key) in types - } - state_dict.update(results[group]) - results[group] = state_dict - else: - state_dict = results[group] + state_dict = results[group] state_dict.update( ((intern_string(k[0]), intern_string(k[1])), to_ascii(v)) @@ -677,17 +706,9 @@ class StateStore(SQLBaseStore): key=group, value=state_dict, full=(types is None), + known_absent=types, ) - # Remove all the entries with None values. The None values were just - # used for bookkeeping in the cache. - for group, state_dict in results.iteritems(): - results[group] = { - key: event_id - for key, event_id in state_dict.iteritems() - if event_id - } - defer.returnValue(results) def get_next_state_group(self): diff --git a/synapse/storage/user_directory.py b/synapse/storage/user_directory.py new file mode 100644 index 0000000000..6a4bf63f0d --- /dev/null +++ b/synapse/storage/user_directory.py @@ -0,0 +1,461 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 Vector Creations Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +from ._base import SQLBaseStore +from synapse.util.caches.descriptors import cached, cachedInlineCallbacks +from synapse.api.constants import EventTypes, JoinRules +from synapse.storage.engines import PostgresEngine, Sqlite3Engine +from synapse.types import get_domain_from_id, get_localpart_from_id + +import re + + +class UserDirectoryStore(SQLBaseStore): + + @cachedInlineCallbacks(cache_context=True) + def is_room_world_readable_or_publicly_joinable(self, room_id, cache_context): + """Check if the room is either world_readable or publically joinable + """ + current_state_ids = yield self.get_current_state_ids( + room_id, on_invalidate=cache_context.invalidate + ) + + join_rules_id = current_state_ids.get((EventTypes.JoinRules, "")) + if join_rules_id: + join_rule_ev = yield self.get_event(join_rules_id, allow_none=True) + if join_rule_ev: + if join_rule_ev.content.get("join_rule") == JoinRules.PUBLIC: + defer.returnValue(True) + + hist_vis_id = current_state_ids.get((EventTypes.RoomHistoryVisibility, "")) + if hist_vis_id: + hist_vis_ev = yield self.get_event(hist_vis_id, allow_none=True) + if hist_vis_ev: + if hist_vis_ev.content.get("history_visibility") == "world_readable": + defer.returnValue(True) + + defer.returnValue(False) + + @defer.inlineCallbacks + def add_users_to_public_room(self, room_id, user_ids): + """Add user to the list of users in public rooms + + Args: + room_id (str): A room_id that all users are in that is world_readable + or publically joinable + user_ids (list(str)): Users to add + """ + yield self._simple_insert_many( + table="users_in_pubic_room", + values=[ + { + "user_id": user_id, + "room_id": room_id, + } + for user_id in user_ids + ], + desc="add_users_to_public_room" + ) + for user_id in user_ids: + self.get_user_in_public_room.invalidate((user_id,)) + + def add_profiles_to_user_dir(self, room_id, users_with_profile): + """Add profiles to the user directory + + Args: + room_id (str): A room_id that all users are joined to + users_with_profile (dict): Users to add to directory in the form of + mapping of user_id -> ProfileInfo + """ + if isinstance(self.database_engine, PostgresEngine): + # We weight the loclpart most highly, then display name and finally + # server name + sql = """ + INSERT INTO user_directory_search(user_id, vector) + VALUES (?, + setweight(to_tsvector('english', ?), 'A') + || setweight(to_tsvector('english', ?), 'D') + || setweight(to_tsvector('english', COALESCE(?, '')), 'B') + ) + """ + args = ( + ( + user_id, get_localpart_from_id(user_id), get_domain_from_id(user_id), + profile.display_name, + ) + for user_id, profile in users_with_profile.iteritems() + ) + elif isinstance(self.database_engine, Sqlite3Engine): + sql = """ + INSERT INTO user_directory_search(user_id, value) + VALUES (?,?) + """ + args = ( + ( + user_id, + "%s %s" % (user_id, p.display_name,) if p.display_name else user_id + ) + for user_id, p in users_with_profile.iteritems() + ) + else: + # This should be unreachable. + raise Exception("Unrecognized database engine") + + def _add_profiles_to_user_dir_txn(txn): + txn.executemany(sql, args) + self._simple_insert_many_txn( + txn, + table="user_directory", + values=[ + { + "user_id": user_id, + "room_id": room_id, + "display_name": profile.display_name, + "avatar_url": profile.avatar_url, + } + for user_id, profile in users_with_profile.iteritems() + ] + ) + for user_id in users_with_profile: + txn.call_after( + self.get_user_in_directory.invalidate, (user_id,) + ) + + return self.runInteraction( + "add_profiles_to_user_dir", _add_profiles_to_user_dir_txn + ) + + @defer.inlineCallbacks + def update_user_in_user_dir(self, user_id, room_id): + yield self._simple_update_one( + table="user_directory", + keyvalues={"user_id": user_id}, + updatevalues={"room_id": room_id}, + desc="update_user_in_user_dir", + ) + self.get_user_in_directory.invalidate((user_id,)) + + def update_profile_in_user_dir(self, user_id, display_name, avatar_url): + def _update_profile_in_user_dir_txn(txn): + self._simple_update_one_txn( + txn, + table="user_directory", + keyvalues={"user_id": user_id}, + updatevalues={"display_name": display_name, "avatar_url": avatar_url}, + ) + + if isinstance(self.database_engine, PostgresEngine): + # We weight the loclpart most highly, then display name and finally + # server name + sql = """ + UPDATE user_directory_search + SET vector = setweight(to_tsvector('english', ?), 'A') + || setweight(to_tsvector('english', ?), 'D') + || setweight(to_tsvector('english', COALESCE(?, '')), 'B') + WHERE user_id = ? + """ + args = ( + get_localpart_from_id(user_id), get_domain_from_id(user_id), + display_name, + user_id, + ) + elif isinstance(self.database_engine, Sqlite3Engine): + sql = """ + UPDATE user_directory_search + set value = ? + WHERE user_id = ? + """ + args = ( + "%s %s" % (user_id, display_name,) if display_name else user_id, + user_id, + ) + else: + # This should be unreachable. + raise Exception("Unrecognized database engine") + + txn.execute(sql, args) + + txn.call_after(self.get_user_in_directory.invalidate, (user_id,)) + + return self.runInteraction( + "update_profile_in_user_dir", _update_profile_in_user_dir_txn + ) + + @defer.inlineCallbacks + def update_user_in_public_user_list(self, user_id, room_id): + yield self._simple_update_one( + table="users_in_pubic_room", + keyvalues={"user_id": user_id}, + updatevalues={"room_id": room_id}, + desc="update_user_in_public_user_list", + ) + self.get_user_in_public_room.invalidate((user_id,)) + + def remove_from_user_dir(self, user_id): + def _remove_from_user_dir_txn(txn): + self._simple_delete_txn( + txn, + table="user_directory", + keyvalues={"user_id": user_id}, + ) + self._simple_delete_txn( + txn, + table="user_directory_search", + keyvalues={"user_id": user_id}, + ) + self._simple_delete_txn( + txn, + table="users_in_pubic_room", + keyvalues={"user_id": user_id}, + ) + txn.call_after( + self.get_user_in_directory.invalidate, (user_id,) + ) + txn.call_after( + self.get_user_in_public_room.invalidate, (user_id,) + ) + return self.runInteraction( + "remove_from_user_dir", _remove_from_user_dir_txn, + ) + + @defer.inlineCallbacks + def remove_from_user_in_public_room(self, user_id): + yield self._simple_delete( + table="users_in_pubic_room", + keyvalues={"user_id": user_id}, + desc="remove_from_user_in_public_room", + ) + self.get_user_in_public_room.invalidate((user_id,)) + + def get_users_in_public_due_to_room(self, room_id): + """Get all user_ids that are in the room directory becuase they're + in the given room_id + """ + return self._simple_select_onecol( + table="users_in_pubic_room", + keyvalues={"room_id": room_id}, + retcol="user_id", + desc="get_users_in_public_due_to_room", + ) + + def get_users_in_dir_due_to_room(self, room_id): + """Get all user_ids that are in the room directory becuase they're + in the given room_id + """ + return self._simple_select_onecol( + table="user_directory", + keyvalues={"room_id": room_id}, + retcol="user_id", + desc="get_users_in_dir_due_to_room", + ) + + def get_all_rooms(self): + """Get all room_ids we've ever known about + """ + return self._simple_select_onecol( + table="current_state_events", + keyvalues={}, + retcol="DISTINCT room_id", + desc="get_all_rooms", + ) + + def delete_all_from_user_dir(self): + """Delete the entire user directory + """ + def _delete_all_from_user_dir_txn(txn): + txn.execute("DELETE FROM user_directory") + txn.execute("DELETE FROM user_directory_search") + txn.execute("DELETE FROM users_in_pubic_room") + txn.call_after(self.get_user_in_directory.invalidate_all) + txn.call_after(self.get_user_in_public_room.invalidate_all) + return self.runInteraction( + "delete_all_from_user_dir", _delete_all_from_user_dir_txn + ) + + @cached() + def get_user_in_directory(self, user_id): + return self._simple_select_one( + table="user_directory", + keyvalues={"user_id": user_id}, + retcols=("room_id", "display_name", "avatar_url",), + allow_none=True, + desc="get_user_in_directory", + ) + + @cached() + def get_user_in_public_room(self, user_id): + return self._simple_select_one( + table="users_in_pubic_room", + keyvalues={"user_id": user_id}, + retcols=("room_id",), + allow_none=True, + desc="get_user_in_public_room", + ) + + def get_user_directory_stream_pos(self): + return self._simple_select_one_onecol( + table="user_directory_stream_pos", + keyvalues={}, + retcol="stream_id", + desc="get_user_directory_stream_pos", + ) + + def update_user_directory_stream_pos(self, stream_id): + return self._simple_update_one( + table="user_directory_stream_pos", + keyvalues={}, + updatevalues={"stream_id": stream_id}, + desc="update_user_directory_stream_pos", + ) + + def get_current_state_deltas(self, prev_stream_id): + prev_stream_id = int(prev_stream_id) + if not self._curr_state_delta_stream_cache.has_any_entity_changed(prev_stream_id): + return [] + + def get_current_state_deltas_txn(txn): + # First we calculate the max stream id that will give us less than + # N results. + # We arbitarily limit to 100 stream_id entries to ensure we don't + # select toooo many. + sql = """ + SELECT stream_id, count(*) + FROM current_state_delta_stream + WHERE stream_id > ? + GROUP BY stream_id + ORDER BY stream_id ASC + LIMIT 100 + """ + txn.execute(sql, (prev_stream_id,)) + + total = 0 + max_stream_id = prev_stream_id + for max_stream_id, count in txn: + total += count + if total > 100: + # We arbitarily limit to 100 entries to ensure we don't + # select toooo many. + break + + # Now actually get the deltas + sql = """ + SELECT stream_id, room_id, type, state_key, event_id, prev_event_id + FROM current_state_delta_stream + WHERE ? < stream_id AND stream_id <= ? + ORDER BY stream_id ASC + """ + txn.execute(sql, (prev_stream_id, max_stream_id,)) + return self.cursor_to_dict(txn) + + return self.runInteraction( + "get_current_state_deltas", get_current_state_deltas_txn + ) + + def get_max_stream_id_in_current_state_deltas(self): + return self._simple_select_one_onecol( + table="current_state_delta_stream", + keyvalues={}, + retcol="COALESCE(MAX(stream_id), -1)", + desc="get_max_stream_id_in_current_state_deltas", + ) + + @defer.inlineCallbacks + def search_user_dir(self, search_term, limit): + """Searches for users in directory + + Returns: + dict of the form:: + + { + "limited": <bool>, # whether there were more results or not + "results": [ # Ordered by best match first + { + "user_id": <user_id>, + "display_name": <display_name>, + "avatar_url": <avatar_url> + } + ] + } + """ + + search_query = _parse_query(self.database_engine, search_term) + + if isinstance(self.database_engine, PostgresEngine): + # We order by rank and then if they have profile info + sql = """ + SELECT user_id, display_name, avatar_url + FROM user_directory_search + INNER JOIN user_directory USING (user_id) + INNER JOIN users_in_pubic_room USING (user_id) + WHERE vector @@ to_tsquery('english', ?) + ORDER BY + ts_rank_cd(vector, to_tsquery('english', ?), 1) DESC, + display_name IS NULL, + avatar_url IS NULL + LIMIT ? + """ + args = (search_query, search_query, limit + 1,) + elif isinstance(self.database_engine, Sqlite3Engine): + sql = """ + SELECT user_id, display_name, avatar_url + FROM user_directory_search + INNER JOIN user_directory USING (user_id) + INNER JOIN users_in_pubic_room USING (user_id) + WHERE value MATCH ? + ORDER BY + rank(matchinfo(user_directory_search)) DESC, + display_name IS NULL, + avatar_url IS NULL + LIMIT ? + """ + args = (search_query, limit + 1) + else: + # This should be unreachable. + raise Exception("Unrecognized database engine") + + results = yield self._execute( + "search_user_dir", self.cursor_to_dict, sql, *args + ) + + limited = len(results) > limit + + defer.returnValue({ + "limited": limited, + "results": results, + }) + + +def _parse_query(database_engine, search_term): + """Takes a plain unicode string from the user and converts it into a form + that can be passed to database. + We use this so that we can add prefix matching, which isn't something + that is supported by default. + + We specifically add both a prefix and non prefix matching term so that + exact matches get ranked higher. + """ + + # Pull out the individual words, discarding any non-word characters. + results = re.findall(r"([\w\-]+)", search_term, re.UNICODE) + + if isinstance(database_engine, PostgresEngine): + return " & ".join("(%s:* | %s)" % (result, result,) for result in results) + elif isinstance(database_engine, Sqlite3Engine): + return " & ".join("(%s* | %s)" % (result, result,) for result in results) + else: + # This should be unreachable. + raise Exception("Unrecognized database engine") diff --git a/synapse/types.py b/synapse/types.py index 445bdcb4d7..111948540d 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -62,6 +62,13 @@ def get_domain_from_id(string): return string[idx + 1:] +def get_localpart_from_id(string): + idx = string.find(":") + if idx == -1: + raise SynapseError(400, "Invalid ID: %r" % (string,)) + return string[1:idx] + + class DomainSpecificString( namedtuple("DomainSpecificString", ("localpart", "domain")) ): diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 48dcbafeef..cbdff86596 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -404,6 +404,7 @@ class CacheDescriptor(_CacheDescriptorBase): wrapped.invalidate_all = cache.invalidate_all wrapped.cache = cache + wrapped.num_args = self.num_args obj.__dict__[self.orig.__name__] = wrapped @@ -451,8 +452,9 @@ class CacheListDescriptor(_CacheDescriptorBase): ) def __get__(self, obj, objtype=None): - - cache = getattr(obj, self.cached_method_name).cache + cached_method = getattr(obj, self.cached_method_name) + cache = cached_method.cache + num_args = cached_method.num_args @functools.wraps(self.orig) def wrapped(*args, **kwargs): @@ -469,12 +471,23 @@ class CacheListDescriptor(_CacheDescriptorBase): results = {} cached_defers = {} missing = [] - for arg in list_args: + + # If the cache takes a single arg then that is used as the key, + # otherwise a tuple is used. + if num_args == 1: + def cache_get(arg): + return cache.get(arg, callback=invalidate_callback) + else: key = list(keyargs) - key[self.list_pos] = arg + def cache_get(arg): + key[self.list_pos] = arg + return cache.get(tuple(key), callback=invalidate_callback) + + for arg in list_args: try: - res = cache.get(tuple(key), callback=invalidate_callback) + res = cache_get(arg) + if not isinstance(res, ObservableDeferred): results[arg] = res elif not res.has_succeeded(): @@ -505,17 +518,28 @@ class CacheListDescriptor(_CacheDescriptorBase): observer = ObservableDeferred(observer) - key = list(keyargs) - key[self.list_pos] = arg - cache.set( - tuple(key), observer, - callback=invalidate_callback - ) - - def invalidate(f, key): - cache.invalidate(key) - return f - observer.addErrback(invalidate, tuple(key)) + if num_args == 1: + cache.set( + arg, observer, + callback=invalidate_callback + ) + + def invalidate(f, key): + cache.invalidate(key) + return f + observer.addErrback(invalidate, arg) + else: + key = list(keyargs) + key[self.list_pos] = arg + cache.set( + tuple(key), observer, + callback=invalidate_callback + ) + + def invalidate(f, key): + cache.invalidate(key) + return f + observer.addErrback(invalidate, tuple(key)) res = observer.observe() res.addCallback(lambda r, arg: (arg, r), arg) diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py index cb6933c61c..d4105822b3 100644 --- a/synapse/util/caches/dictionary_cache.py +++ b/synapse/util/caches/dictionary_cache.py @@ -23,7 +23,17 @@ import logging logger = logging.getLogger(__name__) -class DictionaryEntry(namedtuple("DictionaryEntry", ("full", "value"))): +class DictionaryEntry(namedtuple("DictionaryEntry", ("full", "known_absent", "value"))): + """Returned when getting an entry from the cache + + Attributes: + full (bool): Whether the cache has the full or dict or just some keys. + If not full then not all requested keys will necessarily be present + in `value` + known_absent (set): Keys that were looked up in the dict and were not + there. + value (dict): The full or partial dict value + """ def __len__(self): return len(self.value) @@ -58,21 +68,31 @@ class DictionaryCache(object): ) def get(self, key, dict_keys=None): + """Fetch an entry out of the cache + + Args: + key + dict_key(list): If given a set of keys then return only those keys + that exist in the cache. + + Returns: + DictionaryEntry + """ entry = self.cache.get(key, self.sentinel) if entry is not self.sentinel: self.metrics.inc_hits() if dict_keys is None: - return DictionaryEntry(entry.full, dict(entry.value)) + return DictionaryEntry(entry.full, entry.known_absent, dict(entry.value)) else: - return DictionaryEntry(entry.full, { + return DictionaryEntry(entry.full, entry.known_absent, { k: entry.value[k] for k in dict_keys if k in entry.value }) self.metrics.inc_misses() - return DictionaryEntry(False, {}) + return DictionaryEntry(False, set(), {}) def invalidate(self, key): self.check_thread() @@ -87,19 +107,34 @@ class DictionaryCache(object): self.sequence += 1 self.cache.clear() - def update(self, sequence, key, value, full=False): + def update(self, sequence, key, value, full=False, known_absent=None): + """Updates the entry in the cache + + Args: + sequence + key + value (dict): The value to update the cache with. + full (bool): Whether the given value is the full dict, or just a + partial subset there of. If not full then any existing entries + for the key will be updated. + known_absent (set): Set of keys that we know don't exist in the full + dict. + """ self.check_thread() if self.sequence == sequence: # Only update the cache if the caches sequence number matches the # number that the cache had before the SELECT was started (SYN-369) + if known_absent is None: + known_absent = set() if full: - self._insert(key, value) + self._insert(key, value, known_absent) else: - self._update_or_insert(key, value) + self._update_or_insert(key, value, known_absent) - def _update_or_insert(self, key, value): - entry = self.cache.setdefault(key, DictionaryEntry(False, {})) + def _update_or_insert(self, key, value, known_absent): + entry = self.cache.setdefault(key, DictionaryEntry(False, set(), {})) entry.value.update(value) + entry.known_absent.update(known_absent) - def _insert(self, key, value): - self.cache[key] = DictionaryEntry(True, value) + def _insert(self, key, value, known_absent): + self.cache[key] = DictionaryEntry(True, known_absent, value) diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py index 70fe00ce0b..c498aee46c 100644 --- a/synapse/util/caches/stream_change_cache.py +++ b/synapse/util/caches/stream_change_cache.py @@ -89,6 +89,21 @@ class StreamChangeCache(object): return result + def has_any_entity_changed(self, stream_pos): + """Returns if any entity has changed + """ + assert type(stream_pos) is int + + if stream_pos >= self._earliest_known_stream_pos: + self.metrics.inc_hits() + if stream_pos >= max(self._cache): + return False + else: + return True + else: + self.metrics.inc_misses() + return True + def get_all_entities_changed(self, stream_pos): """Returns all entites that have had new things since the given position. If the position is too old it will return None. |