diff options
Diffstat (limited to 'synapse/storage')
-rw-r--r-- | synapse/storage/__init__.py | 17 | ||||
-rw-r--r-- | synapse/storage/_base.py | 34 | ||||
-rw-r--r-- | synapse/storage/appservice.py | 26 | ||||
-rw-r--r-- | synapse/storage/client_ips.py | 7 | ||||
-rw-r--r-- | synapse/storage/devices.py | 45 | ||||
-rw-r--r-- | synapse/storage/end_to_end_keys.py | 35 | ||||
-rw-r--r-- | synapse/storage/event_federation.py | 91 | ||||
-rw-r--r-- | synapse/storage/events.py | 44 | ||||
-rw-r--r-- | synapse/storage/prepare_database.py | 2 | ||||
-rw-r--r-- | synapse/storage/push_rule.py | 4 | ||||
-rw-r--r-- | synapse/storage/receipts.py | 4 | ||||
-rw-r--r-- | synapse/storage/roommember.py | 159 | ||||
-rw-r--r-- | synapse/storage/schema/delta/42/current_state_delta.sql | 26 | ||||
-rw-r--r-- | synapse/storage/schema/delta/42/device_list_last_id.sql | 33 | ||||
-rw-r--r-- | synapse/storage/schema/delta/42/event_auth_state_only.sql | 17 | ||||
-rw-r--r-- | synapse/storage/schema/delta/42/user_dir.py | 84 | ||||
-rw-r--r-- | synapse/storage/state.py | 104 | ||||
-rw-r--r-- | synapse/storage/user_directory.py | 506 |
18 files changed, 1097 insertions, 141 deletions
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 2970df138b..f119c5a758 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -49,6 +49,7 @@ from .tags import TagsStore from .account_data import AccountDataStore from .openid import OpenIdStore from .client_ips import ClientIpStore +from .user_directory import UserDirectoryStore from .util.id_generators import IdGenerator, StreamIdGenerator, ChainedIdGenerator from .engines import PostgresEngine @@ -86,6 +87,7 @@ class DataStore(RoomMemberStore, RoomStore, ClientIpStore, DeviceStore, DeviceInboxStore, + UserDirectoryStore, ): def __init__(self, db_conn, hs): @@ -221,11 +223,24 @@ class DataStore(RoomMemberStore, RoomStore, "DeviceListFederationStreamChangeCache", device_list_max, ) + curr_state_delta_prefill, min_curr_state_delta_id = self._get_cache_dict( + db_conn, "current_state_delta_stream", + entity_column="room_id", + stream_column="stream_id", + max_value=events_max, # As we share the stream id with events token + limit=1000, + ) + self._curr_state_delta_stream_cache = StreamChangeCache( + "_curr_state_delta_stream_cache", min_curr_state_delta_id, + prefilled_cache=curr_state_delta_prefill, + ) + cur = LoggingTransaction( db_conn.cursor(), name="_find_stream_orderings_for_times_txn", database_engine=self.database_engine, - after_callbacks=[] + after_callbacks=[], + final_callbacks=[], ) self._find_stream_orderings_for_times_txn(cur) cur.close() diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 58b73af7d2..51730a88bf 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -52,13 +52,17 @@ class LoggingTransaction(object): """An object that almost-transparently proxies for the 'txn' object passed to the constructor. Adds logging and metrics to the .execute() method.""" - __slots__ = ["txn", "name", "database_engine", "after_callbacks"] + __slots__ = [ + "txn", "name", "database_engine", "after_callbacks", "final_callbacks", + ] - def __init__(self, txn, name, database_engine, after_callbacks): + def __init__(self, txn, name, database_engine, after_callbacks, + final_callbacks): object.__setattr__(self, "txn", txn) object.__setattr__(self, "name", name) object.__setattr__(self, "database_engine", database_engine) object.__setattr__(self, "after_callbacks", after_callbacks) + object.__setattr__(self, "final_callbacks", final_callbacks) def call_after(self, callback, *args, **kwargs): """Call the given callback on the main twisted thread after the @@ -67,6 +71,9 @@ class LoggingTransaction(object): """ self.after_callbacks.append((callback, args, kwargs)) + def call_finally(self, callback, *args, **kwargs): + self.final_callbacks.append((callback, args, kwargs)) + def __getattr__(self, name): return getattr(self.txn, name) @@ -217,8 +224,8 @@ class SQLBaseStore(object): self._clock.looping_call(loop, 10000) - def _new_transaction(self, conn, desc, after_callbacks, logging_context, - func, *args, **kwargs): + def _new_transaction(self, conn, desc, after_callbacks, final_callbacks, + logging_context, func, *args, **kwargs): start = time.time() * 1000 txn_id = self._TXN_ID @@ -237,7 +244,8 @@ class SQLBaseStore(object): try: txn = conn.cursor() txn = LoggingTransaction( - txn, name, self.database_engine, after_callbacks + txn, name, self.database_engine, after_callbacks, + final_callbacks, ) r = func(txn, *args, **kwargs) conn.commit() @@ -298,6 +306,7 @@ class SQLBaseStore(object): start_time = time.time() * 1000 after_callbacks = [] + final_callbacks = [] def inner_func(conn, *args, **kwargs): with LoggingContext("runInteraction") as context: @@ -309,7 +318,7 @@ class SQLBaseStore(object): current_context.copy_to(context) return self._new_transaction( - conn, desc, after_callbacks, current_context, + conn, desc, after_callbacks, final_callbacks, current_context, func, *args, **kwargs ) @@ -318,9 +327,13 @@ class SQLBaseStore(object): result = yield self._db_pool.runWithConnection( inner_func, *args, **kwargs ) - finally: + for after_callback, after_args, after_kwargs in after_callbacks: after_callback(*after_args, **after_kwargs) + finally: + for after_callback, after_args, after_kwargs in final_callbacks: + after_callback(*after_args, **after_kwargs) + defer.returnValue(result) @defer.inlineCallbacks @@ -425,6 +438,11 @@ class SQLBaseStore(object): txn.execute(sql, vals) + def _simple_insert_many(self, table, values, desc): + return self.runInteraction( + desc, self._simple_insert_many_txn, table, values + ) + @staticmethod def _simple_insert_many_txn(txn, table, values): if not values: @@ -936,7 +954,7 @@ class SQLBaseStore(object): # __exit__ called after the transaction finishes. ctx = self._cache_id_gen.get_next() stream_id = ctx.__enter__() - txn.call_after(ctx.__exit__, None, None, None) + txn.call_finally(ctx.__exit__, None, None, None) txn.call_after(self.hs.get_notifier().on_new_replication_data) self._simple_insert_txn( diff --git a/synapse/storage/appservice.py b/synapse/storage/appservice.py index 514570561f..532df736a5 100644 --- a/synapse/storage/appservice.py +++ b/synapse/storage/appservice.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +import re import simplejson as json from twisted.internet import defer @@ -36,16 +37,31 @@ class ApplicationServiceStore(SQLBaseStore): hs.config.app_service_config_files ) + # We precompie a regex constructed from all the regexes that the AS's + # have registered for exclusive users. + exclusive_user_regexes = [ + regex.pattern + for service in self.services_cache + for regex in service.get_exlusive_user_regexes() + ] + if exclusive_user_regexes: + exclusive_user_regex = "|".join("(" + r + ")" for r in exclusive_user_regexes) + self.exclusive_user_regex = re.compile(exclusive_user_regex) + else: + # We handle this case specially otherwise the constructed regex + # will always match + self.exclusive_user_regex = None + def get_app_services(self): return self.services_cache def get_if_app_services_interested_in_user(self, user_id): - """Check if the user is one associated with an app service + """Check if the user is one associated with an app service (exclusively) """ - for service in self.services_cache: - if service.is_interested_in_user(user_id): - return True - return False + if self.exclusive_user_regex: + return bool(self.exclusive_user_regex.match(user_id)) + else: + return False def get_app_service_by_user_id(self, user_id): """Retrieve an application service from their user ID. diff --git a/synapse/storage/client_ips.py b/synapse/storage/client_ips.py index 747d2df622..014ab635b7 100644 --- a/synapse/storage/client_ips.py +++ b/synapse/storage/client_ips.py @@ -20,6 +20,8 @@ from twisted.internet import defer from ._base import Cache from . import background_updates +import os + logger = logging.getLogger(__name__) # Number of msec of granularity to store the user IP 'last seen' time. Smaller @@ -28,12 +30,15 @@ logger = logging.getLogger(__name__) LAST_SEEN_GRANULARITY = 120 * 1000 +CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1)) + + class ClientIpStore(background_updates.BackgroundUpdateStore): def __init__(self, hs): self.client_ip_last_seen = Cache( name="client_ip_last_seen", keylen=4, - max_entries=5000, + max_entries=50000 * CACHE_SIZE_FACTOR, ) super(ClientIpStore, self).__init__(hs) diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index d9936c88bb..bb27fd1f70 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -368,7 +368,7 @@ class DeviceStore(SQLBaseStore): prev_sent_id_sql = """ SELECT coalesce(max(stream_id), 0) as stream_id - FROM device_lists_outbound_pokes + FROM device_lists_outbound_last_success WHERE destination = ? AND user_id = ? AND stream_id <= ? """ @@ -510,32 +510,43 @@ class DeviceStore(SQLBaseStore): ) def _mark_as_sent_devices_by_remote_txn(self, txn, destination, stream_id): - # First we DELETE all rows such that only the latest row for each - # (destination, user_id is left. We do this by selecting first and - # deleting. + # We update the device_lists_outbound_last_success with the successfully + # poked users. We do the join to see which users need to be inserted and + # which updated. sql = """ - SELECT user_id, coalesce(max(stream_id), 0) FROM device_lists_outbound_pokes - WHERE destination = ? AND stream_id <= ? + SELECT user_id, coalesce(max(o.stream_id), 0), (max(s.stream_id) IS NOT NULL) + FROM device_lists_outbound_pokes as o + LEFT JOIN device_lists_outbound_last_success as s + USING (destination, user_id) + WHERE destination = ? AND o.stream_id <= ? GROUP BY user_id - HAVING count(*) > 1 """ txn.execute(sql, (destination, stream_id,)) rows = txn.fetchall() sql = """ - DELETE FROM device_lists_outbound_pokes - WHERE destination = ? AND user_id = ? AND stream_id < ? + UPDATE device_lists_outbound_last_success + SET stream_id = ? + WHERE destination = ? AND user_id = ? """ txn.executemany( - sql, ((destination, row[0], row[1],) for row in rows) + sql, ((row[1], destination, row[0],) for row in rows if row[2]) ) - # Mark everything that is left as sent sql = """ - UPDATE device_lists_outbound_pokes SET sent = ? + INSERT INTO device_lists_outbound_last_success + (destination, user_id, stream_id) VALUES (?, ?, ?) + """ + txn.executemany( + sql, ((destination, row[0], row[1],) for row in rows if not row[2]) + ) + + # Delete all sent outbound pokes + sql = """ + DELETE FROM device_lists_outbound_pokes WHERE destination = ? AND stream_id <= ? """ - txn.execute(sql, (True, destination, stream_id,)) + txn.execute(sql, (destination, stream_id,)) @defer.inlineCallbacks def get_user_whose_devices_changed(self, from_key): @@ -670,6 +681,14 @@ class DeviceStore(SQLBaseStore): ) ) + # Since we've deleted unsent deltas, we need to remove the entry + # of last successful sent so that the prev_ids are correctly set. + sql = """ + DELETE FROM device_lists_outbound_last_success + WHERE destination = ? AND user_id = ? + """ + txn.executemany(sql, ((row[0], row[1]) for row in rows)) + logger.info("Pruned %d device list outbound pokes", txn.rowcount) return self.runInteraction( diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py index e00f31da2b..2cebb203c6 100644 --- a/synapse/storage/end_to_end_keys.py +++ b/synapse/storage/end_to_end_keys.py @@ -185,8 +185,8 @@ class EndToEndKeyStore(SQLBaseStore): for algorithm, key_id, json_bytes in new_keys ], ) - txn.call_after( - self.count_e2e_one_time_keys.invalidate, (user_id, device_id,) + self._invalidate_cache_and_stream( + txn, self.count_e2e_one_time_keys, (user_id, device_id,) ) yield self.runInteraction( "add_e2e_one_time_keys_insert", _add_e2e_one_time_keys @@ -237,24 +237,29 @@ class EndToEndKeyStore(SQLBaseStore): ) for user_id, device_id, algorithm, key_id in delete: txn.execute(sql, (user_id, device_id, algorithm, key_id)) - txn.call_after( - self.count_e2e_one_time_keys.invalidate, (user_id, device_id,) + self._invalidate_cache_and_stream( + txn, self.count_e2e_one_time_keys, (user_id, device_id,) ) return result return self.runInteraction( "claim_e2e_one_time_keys", _claim_e2e_one_time_keys ) - @defer.inlineCallbacks def delete_e2e_keys_by_device(self, user_id, device_id): - yield self._simple_delete( - table="e2e_device_keys_json", - keyvalues={"user_id": user_id, "device_id": device_id}, - desc="delete_e2e_device_keys_by_device" - ) - yield self._simple_delete( - table="e2e_one_time_keys_json", - keyvalues={"user_id": user_id, "device_id": device_id}, - desc="delete_e2e_one_time_keys_by_device" + def delete_e2e_keys_by_device_txn(txn): + self._simple_delete_txn( + txn, + table="e2e_device_keys_json", + keyvalues={"user_id": user_id, "device_id": device_id}, + ) + self._simple_delete_txn( + txn, + table="e2e_one_time_keys_json", + keyvalues={"user_id": user_id, "device_id": device_id}, + ) + self._invalidate_cache_and_stream( + txn, self.count_e2e_one_time_keys, (user_id, device_id,) + ) + return self.runInteraction( + "delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn ) - self.count_e2e_one_time_keys.invalidate((user_id, device_id,)) diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py index 519059c306..e8133de2fa 100644 --- a/synapse/storage/event_federation.py +++ b/synapse/storage/event_federation.py @@ -37,25 +37,55 @@ class EventFederationStore(SQLBaseStore): and backfilling from another server respectively. """ + EVENT_AUTH_STATE_ONLY = "event_auth_state_only" + def __init__(self, hs): super(EventFederationStore, self).__init__(hs) + self.register_background_update_handler( + self.EVENT_AUTH_STATE_ONLY, + self._background_delete_non_state_event_auth, + ) + hs.get_clock().looping_call( self._delete_old_forward_extrem_cache, 60 * 60 * 1000 ) - def get_auth_chain(self, event_ids): - return self.get_auth_chain_ids(event_ids).addCallback(self._get_events) + def get_auth_chain(self, event_ids, include_given=False): + """Get auth events for given event_ids. The events *must* be state events. + + Args: + event_ids (list): state events + include_given (bool): include the given events in result + + Returns: + list of events + """ + return self.get_auth_chain_ids( + event_ids, include_given=include_given, + ).addCallback(self._get_events) + + def get_auth_chain_ids(self, event_ids, include_given=False): + """Get auth events for given event_ids. The events *must* be state events. + + Args: + event_ids (list): state events + include_given (bool): include the given events in result - def get_auth_chain_ids(self, event_ids): + Returns: + list of event_ids + """ return self.runInteraction( "get_auth_chain_ids", self._get_auth_chain_ids_txn, - event_ids + event_ids, include_given ) - def _get_auth_chain_ids_txn(self, txn, event_ids): - results = set() + def _get_auth_chain_ids_txn(self, txn, event_ids, include_given): + if include_given: + results = set(event_ids) + else: + results = set() base_sql = ( "SELECT auth_id FROM event_auth WHERE event_id IN (%s)" @@ -504,3 +534,52 @@ class EventFederationStore(SQLBaseStore): txn.execute(query, (room_id,)) txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,)) + + @defer.inlineCallbacks + def _background_delete_non_state_event_auth(self, progress, batch_size): + def delete_event_auth(txn): + target_min_stream_id = progress.get("target_min_stream_id_inclusive") + max_stream_id = progress.get("max_stream_id_exclusive") + + if not target_min_stream_id or not max_stream_id: + txn.execute("SELECT COALESCE(MIN(stream_ordering), 0) FROM events") + rows = txn.fetchall() + target_min_stream_id = rows[0][0] + + txn.execute("SELECT COALESCE(MAX(stream_ordering), 0) FROM events") + rows = txn.fetchall() + max_stream_id = rows[0][0] + + min_stream_id = max_stream_id - batch_size + + sql = """ + DELETE FROM event_auth + WHERE event_id IN ( + SELECT event_id FROM events + LEFT JOIN state_events USING (room_id, event_id) + WHERE ? <= stream_ordering AND stream_ordering < ? + AND state_key IS null + ) + """ + + txn.execute(sql, (min_stream_id, max_stream_id,)) + + new_progress = { + "target_min_stream_id_inclusive": target_min_stream_id, + "max_stream_id_exclusive": min_stream_id, + } + + self._background_update_progress_txn( + txn, self.EVENT_AUTH_STATE_ONLY, new_progress + ) + + return min_stream_id >= target_min_stream_id + + result = yield self.runInteraction( + self.EVENT_AUTH_STATE_ONLY, delete_event_auth + ) + + if not result: + yield self._end_background_update(self.EVENT_AUTH_STATE_ONLY) + + defer.returnValue(batch_size) diff --git a/synapse/storage/events.py b/synapse/storage/events.py index f29d71589d..f60ed889d5 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -647,9 +647,10 @@ class EventsStore(SQLBaseStore): list of the event ids which are the forward extremities. """ - self._update_current_state_txn(txn, current_state_for_room) - max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering + + self._update_current_state_txn(txn, current_state_for_room, max_stream_order) + self._update_forward_extremities_txn( txn, new_forward_extremities=new_forward_extremeties, @@ -712,7 +713,7 @@ class EventsStore(SQLBaseStore): backfilled=backfilled, ) - def _update_current_state_txn(self, txn, state_delta_by_room): + def _update_current_state_txn(self, txn, state_delta_by_room, max_stream_order): for room_id, current_state_tuple in state_delta_by_room.iteritems(): to_delete, to_insert, _ = current_state_tuple txn.executemany( @@ -734,6 +735,29 @@ class EventsStore(SQLBaseStore): ], ) + state_deltas = {key: None for key in to_delete} + state_deltas.update(to_insert) + + self._simple_insert_many_txn( + txn, + table="current_state_delta_stream", + values=[ + { + "stream_id": max_stream_order, + "room_id": room_id, + "type": key[0], + "state_key": key[1], + "event_id": ev_id, + "prev_event_id": to_delete.get(key, None), + } + for key, ev_id in state_deltas.iteritems() + ] + ) + + self._curr_state_delta_stream_cache.entity_has_changed( + room_id, max_stream_order, + ) + # Invalidate the various caches # Figure out the changes of membership to invalidate the @@ -742,11 +766,7 @@ class EventsStore(SQLBaseStore): # and which we have added, then we invlidate the caches for all # those users. members_changed = set( - state_key for ev_type, state_key in to_delete.iterkeys() - if ev_type == EventTypes.Member - ) - members_changed.update( - state_key for ev_type, state_key in to_insert.iterkeys() + state_key for ev_type, state_key in state_deltas if ev_type == EventTypes.Member ) @@ -755,6 +775,11 @@ class EventsStore(SQLBaseStore): txn, self.get_rooms_for_user, (member,) ) + for host in set(get_domain_from_id(u) for u in members_changed): + self._invalidate_cache_and_stream( + txn, self.is_host_joined, (room_id, host) + ) + self._invalidate_cache_and_stream( txn, self.get_users_in_room, (room_id,) ) @@ -1119,6 +1144,7 @@ class EventsStore(SQLBaseStore): } for event, _ in events_and_contexts for auth_id, _ in event.auth_events + if event.is_state() ], ) @@ -1418,7 +1444,7 @@ class EventsStore(SQLBaseStore): ] rows = self._new_transaction( - conn, "do_fetch", [], None, self._fetch_event_rows, event_ids + conn, "do_fetch", [], [], None, self._fetch_event_rows, event_ids ) row_dict = { diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index 6e623843d5..eaba699e29 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -25,7 +25,7 @@ logger = logging.getLogger(__name__) # Remember to update this number every time a change is made to database # schema files, so the users will be informed on server restarts. -SCHEMA_VERSION = 41 +SCHEMA_VERSION = 42 dir_path = os.path.abspath(os.path.dirname(__file__)) diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py index 0a819d32c5..8758b1c0c7 100644 --- a/synapse/storage/push_rule.py +++ b/synapse/storage/push_rule.py @@ -49,7 +49,7 @@ def _load_rules(rawrules, enabled_map): class PushRuleStore(SQLBaseStore): - @cachedInlineCallbacks() + @cachedInlineCallbacks(max_entries=5000) def get_push_rules_for_user(self, user_id): rows = yield self._simple_select_list( table="push_rules", @@ -73,7 +73,7 @@ class PushRuleStore(SQLBaseStore): defer.returnValue(rules) - @cachedInlineCallbacks() + @cachedInlineCallbacks(max_entries=5000) def get_push_rules_enabled_for_user(self, user_id): results = yield self._simple_select_list( table="push_rules_enable", diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py index efb90c3c91..f42b8014c7 100644 --- a/synapse/storage/receipts.py +++ b/synapse/storage/receipts.py @@ -45,7 +45,9 @@ class ReceiptsStore(SQLBaseStore): return # Returns an ObservableDeferred - res = self.get_users_with_read_receipts_in_room.cache.get((room_id,), None) + res = self.get_users_with_read_receipts_in_room.cache.get( + room_id, None, update_metrics=False, + ) if res: if isinstance(res, defer.Deferred) and res.called: diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 0829ae5bee..457ca288d0 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -18,6 +18,7 @@ from twisted.internet import defer from collections import namedtuple from ._base import SQLBaseStore +from synapse.util.async import Linearizer from synapse.util.caches import intern_string from synapse.util.caches.descriptors import cached, cachedInlineCallbacks from synapse.util.stringutils import to_ascii @@ -392,7 +393,8 @@ class RoomMemberStore(SQLBaseStore): context=context, ) - def get_joined_users_from_state(self, room_id, state_group, state_ids): + def get_joined_users_from_state(self, room_id, state_entry): + state_group = state_entry.state_group if not state_group: # If state_group is None it means it has yet to be assigned a # state group, i.e. we need to make sure that calls with a state_group @@ -401,7 +403,7 @@ class RoomMemberStore(SQLBaseStore): state_group = object() return self._get_joined_users_from_context( - room_id, state_group, state_ids, + room_id, state_group, state_entry.state, context=state_entry, ) @cachedInlineCallbacks(num_args=2, cache_context=True, iterable=True, @@ -499,42 +501,40 @@ class RoomMemberStore(SQLBaseStore): defer.returnValue(users_in_room) - def is_host_joined(self, room_id, host, state_group, state_ids): - if not state_group: - # If state_group is None it means it has yet to be assigned a - # state group, i.e. we need to make sure that calls with a state_group - # of None don't hit previous cached calls with a None state_group. - # To do this we set the state_group to a new object as object() != object() - state_group = object() + @cachedInlineCallbacks(max_entries=10000) + def is_host_joined(self, room_id, host): + if '%' in host or '_' in host: + raise Exception("Invalid host name") - return self._is_host_joined( - room_id, host, state_group, state_ids - ) + sql = """ + SELECT state_key FROM current_state_events AS c + INNER JOIN room_memberships USING (event_id) + WHERE membership = 'join' + AND type = 'm.room.member' + AND c.room_id = ? + AND state_key LIKE ? + LIMIT 1 + """ - @cachedInlineCallbacks(num_args=3) - def _is_host_joined(self, room_id, host, state_group, current_state_ids): - # We don't use `state_group`, its there so that we can cache based - # on it. However, its important that its never None, since two current_state's - # with a state_group of None are likely to be different. - # See bulk_get_push_rules_for_room for how we work around this. - assert state_group is not None + # We do need to be careful to ensure that host doesn't have any wild cards + # in it, but we checked above for known ones and we'll check below that + # the returned user actually has the correct domain. + like_clause = "%:" + host - for (etype, state_key), event_id in current_state_ids.items(): - if etype == EventTypes.Member: - try: - if get_domain_from_id(state_key) != host: - continue - except: - logger.warn("state_key not user_id: %s", state_key) - continue + rows = yield self._execute("is_host_joined", None, sql, room_id, like_clause) - event = yield self.get_event(event_id, allow_none=True) - if event and event.content["membership"] == Membership.JOIN: - defer.returnValue(True) + if not rows: + defer.returnValue(False) - defer.returnValue(False) + user_id = rows[0][0] + if get_domain_from_id(user_id) != host: + # This can only happen if the host name has something funky in it + raise Exception("Invalid host name") - def get_joined_hosts(self, room_id, state_group, state_ids): + defer.returnValue(True) + + def get_joined_hosts(self, room_id, state_entry): + state_group = state_entry.state_group if not state_group: # If state_group is None it means it has yet to be assigned a # state group, i.e. we need to make sure that calls with a state_group @@ -543,33 +543,20 @@ class RoomMemberStore(SQLBaseStore): state_group = object() return self._get_joined_hosts( - room_id, state_group, state_ids + room_id, state_group, state_entry.state, state_entry=state_entry, ) @cachedInlineCallbacks(num_args=2, max_entries=10000, iterable=True) - def _get_joined_hosts(self, room_id, state_group, current_state_ids): + # @defer.inlineCallbacks + def _get_joined_hosts(self, room_id, state_group, current_state_ids, state_entry): # We don't use `state_group`, its there so that we can cache based # on it. However, its important that its never None, since two current_state's # with a state_group of None are likely to be different. # See bulk_get_push_rules_for_room for how we work around this. assert state_group is not None - joined_hosts = set() - for etype, state_key in current_state_ids: - if etype == EventTypes.Member: - try: - host = get_domain_from_id(state_key) - except: - logger.warn("state_key not user_id: %s", state_key) - continue - - if host in joined_hosts: - continue - - event_id = current_state_ids[(etype, state_key)] - event = yield self.get_event(event_id, allow_none=True) - if event and event.content["membership"] == Membership.JOIN: - joined_hosts.add(intern_string(host)) + cache = self._get_joined_hosts_cache(room_id) + joined_hosts = yield cache.get_destinations(state_entry) defer.returnValue(joined_hosts) @@ -647,3 +634,75 @@ class RoomMemberStore(SQLBaseStore): yield self._end_background_update(_MEMBERSHIP_PROFILE_UPDATE_NAME) defer.returnValue(result) + + @cached(max_entries=10000, iterable=True) + def _get_joined_hosts_cache(self, room_id): + return _JoinedHostsCache(self, room_id) + + +class _JoinedHostsCache(object): + """Cache for joined hosts in a room that is optimised to handle updates + via state deltas. + """ + + def __init__(self, store, room_id): + self.store = store + self.room_id = room_id + + self.hosts_to_joined_users = {} + + self.state_group = object() + + self.linearizer = Linearizer("_JoinedHostsCache") + + self._len = 0 + + @defer.inlineCallbacks + def get_destinations(self, state_entry): + """Get set of destinations for a state entry + + Args: + state_entry(synapse.state._StateCacheEntry) + """ + if state_entry.state_group == self.state_group: + defer.returnValue(frozenset(self.hosts_to_joined_users)) + + with (yield self.linearizer.queue(())): + if state_entry.state_group == self.state_group: + pass + elif state_entry.prev_group == self.state_group: + for (typ, state_key), event_id in state_entry.delta_ids.iteritems(): + if typ != EventTypes.Member: + continue + + host = intern_string(get_domain_from_id(state_key)) + user_id = state_key + known_joins = self.hosts_to_joined_users.setdefault(host, set()) + + event = yield self.store.get_event(event_id) + if event.membership == Membership.JOIN: + known_joins.add(user_id) + else: + known_joins.discard(user_id) + + if not known_joins: + self.hosts_to_joined_users.pop(host, None) + else: + joined_users = yield self.store.get_joined_users_from_state( + self.room_id, state_entry, + ) + + self.hosts_to_joined_users = {} + for user_id in joined_users: + host = intern_string(get_domain_from_id(user_id)) + self.hosts_to_joined_users.setdefault(host, set()).add(user_id) + + if state_entry.state_group: + self.state_group = state_entry.state_group + else: + self.state_group = object() + self._len = sum(len(v) for v in self.hosts_to_joined_users.itervalues()) + defer.returnValue(frozenset(self.hosts_to_joined_users)) + + def __len__(self): + return self._len diff --git a/synapse/storage/schema/delta/42/current_state_delta.sql b/synapse/storage/schema/delta/42/current_state_delta.sql new file mode 100644 index 0000000000..d28851aff8 --- /dev/null +++ b/synapse/storage/schema/delta/42/current_state_delta.sql @@ -0,0 +1,26 @@ +/* Copyright 2017 Vector Creations Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +CREATE TABLE current_state_delta_stream ( + stream_id BIGINT NOT NULL, + room_id TEXT NOT NULL, + type TEXT NOT NULL, + state_key TEXT NOT NULL, + event_id TEXT, -- Is null if the key was removed + prev_event_id TEXT -- Is null if the key was added +); + +CREATE INDEX current_state_delta_stream_idx ON current_state_delta_stream(stream_id); diff --git a/synapse/storage/schema/delta/42/device_list_last_id.sql b/synapse/storage/schema/delta/42/device_list_last_id.sql new file mode 100644 index 0000000000..9ab8c14fa3 --- /dev/null +++ b/synapse/storage/schema/delta/42/device_list_last_id.sql @@ -0,0 +1,33 @@ +/* Copyright 2017 Vector Creations Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +-- Table of last stream_id that we sent to destination for user_id. This is +-- used to fill out the `prev_id` fields of outbound device list updates. +CREATE TABLE device_lists_outbound_last_success ( + destination TEXT NOT NULL, + user_id TEXT NOT NULL, + stream_id BIGINT NOT NULL +); + +INSERT INTO device_lists_outbound_last_success + SELECT destination, user_id, coalesce(max(stream_id), 0) as stream_id + FROM device_lists_outbound_pokes + WHERE sent = (1 = 1) -- sqlite doesn't have inbuilt boolean values + GROUP BY destination, user_id; + +CREATE INDEX device_lists_outbound_last_success_idx ON device_lists_outbound_last_success( + destination, user_id, stream_id +); diff --git a/synapse/storage/schema/delta/42/event_auth_state_only.sql b/synapse/storage/schema/delta/42/event_auth_state_only.sql new file mode 100644 index 0000000000..b8821ac759 --- /dev/null +++ b/synapse/storage/schema/delta/42/event_auth_state_only.sql @@ -0,0 +1,17 @@ +/* Copyright 2017 Vector Creations Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +INSERT INTO background_updates (update_name, progress_json) VALUES + ('event_auth_state_only', '{}'); diff --git a/synapse/storage/schema/delta/42/user_dir.py b/synapse/storage/schema/delta/42/user_dir.py new file mode 100644 index 0000000000..ea6a18196d --- /dev/null +++ b/synapse/storage/schema/delta/42/user_dir.py @@ -0,0 +1,84 @@ +# Copyright 2017 Vector Creations Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from synapse.storage.prepare_database import get_statements +from synapse.storage.engines import PostgresEngine, Sqlite3Engine + +logger = logging.getLogger(__name__) + + +BOTH_TABLES = """ +CREATE TABLE user_directory_stream_pos ( + Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, -- Makes sure this table only has one row. + stream_id BIGINT, + CHECK (Lock='X') +); + +INSERT INTO user_directory_stream_pos (stream_id) VALUES (null); + +CREATE TABLE user_directory ( + user_id TEXT NOT NULL, + room_id TEXT NOT NULL, -- A room_id that we know the user is joined to + display_name TEXT, + avatar_url TEXT +); + +CREATE INDEX user_directory_room_idx ON user_directory(room_id); +CREATE UNIQUE INDEX user_directory_user_idx ON user_directory(user_id); + +CREATE TABLE users_in_pubic_room ( + user_id TEXT NOT NULL, + room_id TEXT NOT NULL -- A room_id that we know is public +); + +CREATE INDEX users_in_pubic_room_room_idx ON users_in_pubic_room(room_id); +CREATE UNIQUE INDEX users_in_pubic_room_user_idx ON users_in_pubic_room(user_id); +""" + + +POSTGRES_TABLE = """ +CREATE TABLE user_directory_search ( + user_id TEXT NOT NULL, + vector tsvector +); + +CREATE INDEX user_directory_search_fts_idx ON user_directory_search USING gin(vector); +CREATE UNIQUE INDEX user_directory_search_user_idx ON user_directory_search(user_id); +""" + + +SQLITE_TABLE = """ +CREATE VIRTUAL TABLE user_directory_search + USING fts4 ( user_id, value ); +""" + + +def run_create(cur, database_engine, *args, **kwargs): + for statement in get_statements(BOTH_TABLES.splitlines()): + cur.execute(statement) + + if isinstance(database_engine, PostgresEngine): + for statement in get_statements(POSTGRES_TABLE.splitlines()): + cur.execute(statement) + elif isinstance(database_engine, Sqlite3Engine): + for statement in get_statements(SQLITE_TABLE.splitlines()): + cur.execute(statement) + else: + raise Exception("Unrecognized database engine") + + +def run_upgrade(*args, **kwargs): + pass diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 85acf2ad1e..d1e679719b 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -20,6 +20,7 @@ from synapse.util.stringutils import to_ascii from synapse.storage.engines import PostgresEngine from twisted.internet import defer +from collections import namedtuple import logging @@ -29,6 +30,16 @@ logger = logging.getLogger(__name__) MAX_STATE_DELTA_HOPS = 100 +class _GetStateGroupDelta(namedtuple("_GetStateGroupDelta", ("prev_group", "delta_ids"))): + """Return type of get_state_group_delta that implements __len__, which lets + us use the itrable flag when caching + """ + __slots__ = [] + + def __len__(self): + return len(self.delta_ids) if self.delta_ids else 0 + + class StateStore(SQLBaseStore): """ Keeps track of the state at a given event. @@ -98,6 +109,46 @@ class StateStore(SQLBaseStore): _get_current_state_ids_txn, ) + @cached(max_entries=10000, iterable=True) + def get_state_group_delta(self, state_group): + """Given a state group try to return a previous group and a delta between + the old and the new. + + Returns: + (prev_group, delta_ids), where both may be None. + """ + def _get_state_group_delta_txn(txn): + prev_group = self._simple_select_one_onecol_txn( + txn, + table="state_group_edges", + keyvalues={ + "state_group": state_group, + }, + retcol="prev_state_group", + allow_none=True, + ) + + if not prev_group: + return _GetStateGroupDelta(None, None) + + delta_ids = self._simple_select_list_txn( + txn, + table="state_groups_state", + keyvalues={ + "state_group": state_group, + }, + retcols=("type", "state_key", "event_id",) + ) + + return _GetStateGroupDelta(prev_group, { + (row["type"], row["state_key"]): row["event_id"] + for row in delta_ids + }) + return self.runInteraction( + "get_state_group_delta", + _get_state_group_delta_txn, + ) + @defer.inlineCallbacks def get_state_groups_ids(self, room_id, event_ids): if not event_ids: @@ -184,6 +235,19 @@ class StateStore(SQLBaseStore): # We persist as a delta if we can, while also ensuring the chain # of deltas isn't tooo long, as otherwise read performance degrades. if context.prev_group: + is_in_db = self._simple_select_one_onecol_txn( + txn, + table="state_groups", + keyvalues={"id": context.prev_group}, + retcol="id", + allow_none=True, + ) + if not is_in_db: + raise Exception( + "Trying to persist state with unpersisted prev_group: %r" + % (context.prev_group,) + ) + potential_hops = self._count_state_group_hops_txn( txn, context.prev_group ) @@ -563,20 +627,22 @@ class StateStore(SQLBaseStore): where a `state_key` of `None` matches all state_keys for the `type`. """ - is_all, state_dict_ids = self._state_group_cache.get(group) + is_all, known_absent, state_dict_ids = self._state_group_cache.get(group) type_to_key = {} missing_types = set() + for typ, state_key in types: + key = (typ, state_key) if state_key is None: type_to_key[typ] = None - missing_types.add((typ, state_key)) + missing_types.add(key) else: if type_to_key.get(typ, object()) is not None: type_to_key.setdefault(typ, set()).add(state_key) - if (typ, state_key) not in state_dict_ids: - missing_types.add((typ, state_key)) + if key not in state_dict_ids and key not in known_absent: + missing_types.add(key) sentinel = object() @@ -590,7 +656,7 @@ class StateStore(SQLBaseStore): return True return False - got_all = not (missing_types or types is None) + got_all = is_all or not missing_types return { k: v for k, v in state_dict_ids.iteritems() @@ -607,7 +673,7 @@ class StateStore(SQLBaseStore): Args: group: The state group to lookup """ - is_all, state_dict_ids = self._state_group_cache.get(group) + is_all, _, state_dict_ids = self._state_group_cache.get(group) return state_dict_ids, is_all @@ -624,7 +690,7 @@ class StateStore(SQLBaseStore): missing_groups = [] if types is not None: for group in set(groups): - state_dict_ids, missing_types, got_all = self._get_some_state_from_cache( + state_dict_ids, _, got_all = self._get_some_state_from_cache( group, types ) results[group] = state_dict_ids @@ -653,19 +719,7 @@ class StateStore(SQLBaseStore): # Now we want to update the cache with all the things we fetched # from the database. for group, group_state_dict in group_to_state_dict.iteritems(): - if types: - # We delibrately put key -> None mappings into the cache to - # cache absence of the key, on the assumption that if we've - # explicitly asked for some types then we will probably ask - # for them again. - state_dict = { - (intern_string(etype), intern_string(state_key)): None - for (etype, state_key) in types - } - state_dict.update(results[group]) - results[group] = state_dict - else: - state_dict = results[group] + state_dict = results[group] state_dict.update( ((intern_string(k[0]), intern_string(k[1])), to_ascii(v)) @@ -677,17 +731,9 @@ class StateStore(SQLBaseStore): key=group, value=state_dict, full=(types is None), + known_absent=types, ) - # Remove all the entries with None values. The None values were just - # used for bookkeeping in the cache. - for group, state_dict in results.iteritems(): - results[group] = { - key: event_id - for key, event_id in state_dict.iteritems() - if event_id - } - defer.returnValue(results) def get_next_state_group(self): diff --git a/synapse/storage/user_directory.py b/synapse/storage/user_directory.py new file mode 100644 index 0000000000..137aca2881 --- /dev/null +++ b/synapse/storage/user_directory.py @@ -0,0 +1,506 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 Vector Creations Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +from ._base import SQLBaseStore +from synapse.util.caches.descriptors import cached, cachedInlineCallbacks +from synapse.api.constants import EventTypes, JoinRules +from synapse.storage.engines import PostgresEngine, Sqlite3Engine +from synapse.types import get_domain_from_id, get_localpart_from_id + +import re + + +class UserDirectoryStore(SQLBaseStore): + + @cachedInlineCallbacks(cache_context=True) + def is_room_world_readable_or_publicly_joinable(self, room_id, cache_context): + """Check if the room is either world_readable or publically joinable + """ + current_state_ids = yield self.get_current_state_ids( + room_id, on_invalidate=cache_context.invalidate + ) + + join_rules_id = current_state_ids.get((EventTypes.JoinRules, "")) + if join_rules_id: + join_rule_ev = yield self.get_event(join_rules_id, allow_none=True) + if join_rule_ev: + if join_rule_ev.content.get("join_rule") == JoinRules.PUBLIC: + defer.returnValue(True) + + hist_vis_id = current_state_ids.get((EventTypes.RoomHistoryVisibility, "")) + if hist_vis_id: + hist_vis_ev = yield self.get_event(hist_vis_id, allow_none=True) + if hist_vis_ev: + if hist_vis_ev.content.get("history_visibility") == "world_readable": + defer.returnValue(True) + + defer.returnValue(False) + + @defer.inlineCallbacks + def add_users_to_public_room(self, room_id, user_ids): + """Add user to the list of users in public rooms + + Args: + room_id (str): A room_id that all users are in that is world_readable + or publically joinable + user_ids (list(str)): Users to add + """ + yield self._simple_insert_many( + table="users_in_pubic_room", + values=[ + { + "user_id": user_id, + "room_id": room_id, + } + for user_id in user_ids + ], + desc="add_users_to_public_room" + ) + for user_id in user_ids: + self.get_user_in_public_room.invalidate((user_id,)) + + def add_profiles_to_user_dir(self, room_id, users_with_profile): + """Add profiles to the user directory + + Args: + room_id (str): A room_id that all users are joined to + users_with_profile (dict): Users to add to directory in the form of + mapping of user_id -> ProfileInfo + """ + if isinstance(self.database_engine, PostgresEngine): + # We weight the loclpart most highly, then display name and finally + # server name + sql = """ + INSERT INTO user_directory_search(user_id, vector) + VALUES (?, + setweight(to_tsvector('english', ?), 'A') + || setweight(to_tsvector('english', ?), 'D') + || setweight(to_tsvector('english', COALESCE(?, '')), 'B') + ) + """ + args = ( + ( + user_id, get_localpart_from_id(user_id), get_domain_from_id(user_id), + profile.display_name, + ) + for user_id, profile in users_with_profile.iteritems() + ) + elif isinstance(self.database_engine, Sqlite3Engine): + sql = """ + INSERT INTO user_directory_search(user_id, value) + VALUES (?,?) + """ + args = ( + ( + user_id, + "%s %s" % (user_id, p.display_name,) if p.display_name else user_id + ) + for user_id, p in users_with_profile.iteritems() + ) + else: + # This should be unreachable. + raise Exception("Unrecognized database engine") + + def _add_profiles_to_user_dir_txn(txn): + txn.executemany(sql, args) + self._simple_insert_many_txn( + txn, + table="user_directory", + values=[ + { + "user_id": user_id, + "room_id": room_id, + "display_name": profile.display_name, + "avatar_url": profile.avatar_url, + } + for user_id, profile in users_with_profile.iteritems() + ] + ) + for user_id in users_with_profile: + txn.call_after( + self.get_user_in_directory.invalidate, (user_id,) + ) + + return self.runInteraction( + "add_profiles_to_user_dir", _add_profiles_to_user_dir_txn + ) + + @defer.inlineCallbacks + def update_user_in_user_dir(self, user_id, room_id): + yield self._simple_update_one( + table="user_directory", + keyvalues={"user_id": user_id}, + updatevalues={"room_id": room_id}, + desc="update_user_in_user_dir", + ) + self.get_user_in_directory.invalidate((user_id,)) + + def update_profile_in_user_dir(self, user_id, display_name, avatar_url, room_id): + def _update_profile_in_user_dir_txn(txn): + new_entry = self._simple_upsert_txn( + txn, + table="user_directory", + keyvalues={"user_id": user_id}, + insertion_values={"room_id": room_id}, + values={"display_name": display_name, "avatar_url": avatar_url}, + lock=False, # We're only inserter + ) + + if isinstance(self.database_engine, PostgresEngine): + # We weight the loclpart most highly, then display name and finally + # server name + if new_entry: + sql = """ + INSERT INTO user_directory_search(user_id, vector) + VALUES (?, + setweight(to_tsvector('english', ?), 'A') + || setweight(to_tsvector('english', ?), 'D') + || setweight(to_tsvector('english', COALESCE(?, '')), 'B') + ) + """ + txn.execute( + sql, + ( + user_id, get_localpart_from_id(user_id), + get_domain_from_id(user_id), display_name, + ) + ) + else: + sql = """ + UPDATE user_directory_search + SET vector = setweight(to_tsvector('english', ?), 'A') + || setweight(to_tsvector('english', ?), 'D') + || setweight(to_tsvector('english', COALESCE(?, '')), 'B') + WHERE user_id = ? + """ + txn.execute( + sql, + ( + get_localpart_from_id(user_id), get_domain_from_id(user_id), + display_name, user_id, + ) + ) + elif isinstance(self.database_engine, Sqlite3Engine): + value = "%s %s" % (user_id, display_name,) if display_name else user_id + self._simple_upsert_txn( + txn, + table="user_directory_search", + keyvalues={"user_id": user_id}, + values={"value": value}, + lock=False, # We're only inserter + ) + else: + # This should be unreachable. + raise Exception("Unrecognized database engine") + + txn.call_after(self.get_user_in_directory.invalidate, (user_id,)) + + return self.runInteraction( + "update_profile_in_user_dir", _update_profile_in_user_dir_txn + ) + + @defer.inlineCallbacks + def update_user_in_public_user_list(self, user_id, room_id): + yield self._simple_update_one( + table="users_in_pubic_room", + keyvalues={"user_id": user_id}, + updatevalues={"room_id": room_id}, + desc="update_user_in_public_user_list", + ) + self.get_user_in_public_room.invalidate((user_id,)) + + def remove_from_user_dir(self, user_id): + def _remove_from_user_dir_txn(txn): + self._simple_delete_txn( + txn, + table="user_directory", + keyvalues={"user_id": user_id}, + ) + self._simple_delete_txn( + txn, + table="user_directory_search", + keyvalues={"user_id": user_id}, + ) + self._simple_delete_txn( + txn, + table="users_in_pubic_room", + keyvalues={"user_id": user_id}, + ) + txn.call_after( + self.get_user_in_directory.invalidate, (user_id,) + ) + txn.call_after( + self.get_user_in_public_room.invalidate, (user_id,) + ) + return self.runInteraction( + "remove_from_user_dir", _remove_from_user_dir_txn, + ) + + @defer.inlineCallbacks + def remove_from_user_in_public_room(self, user_id): + yield self._simple_delete( + table="users_in_pubic_room", + keyvalues={"user_id": user_id}, + desc="remove_from_user_in_public_room", + ) + self.get_user_in_public_room.invalidate((user_id,)) + + def get_users_in_public_due_to_room(self, room_id): + """Get all user_ids that are in the room directory becuase they're + in the given room_id + """ + return self._simple_select_onecol( + table="users_in_pubic_room", + keyvalues={"room_id": room_id}, + retcol="user_id", + desc="get_users_in_public_due_to_room", + ) + + def get_users_in_dir_due_to_room(self, room_id): + """Get all user_ids that are in the room directory becuase they're + in the given room_id + """ + return self._simple_select_onecol( + table="user_directory", + keyvalues={"room_id": room_id}, + retcol="user_id", + desc="get_users_in_dir_due_to_room", + ) + + def get_all_rooms(self): + """Get all room_ids we've ever known about + """ + return self._simple_select_onecol( + table="current_state_events", + keyvalues={}, + retcol="DISTINCT room_id", + desc="get_all_rooms", + ) + + def delete_all_from_user_dir(self): + """Delete the entire user directory + """ + def _delete_all_from_user_dir_txn(txn): + txn.execute("DELETE FROM user_directory") + txn.execute("DELETE FROM user_directory_search") + txn.execute("DELETE FROM users_in_pubic_room") + txn.call_after(self.get_user_in_directory.invalidate_all) + txn.call_after(self.get_user_in_public_room.invalidate_all) + return self.runInteraction( + "delete_all_from_user_dir", _delete_all_from_user_dir_txn + ) + + @cached() + def get_user_in_directory(self, user_id): + return self._simple_select_one( + table="user_directory", + keyvalues={"user_id": user_id}, + retcols=("room_id", "display_name", "avatar_url",), + allow_none=True, + desc="get_user_in_directory", + ) + + @cached() + def get_user_in_public_room(self, user_id): + return self._simple_select_one( + table="users_in_pubic_room", + keyvalues={"user_id": user_id}, + retcols=("room_id",), + allow_none=True, + desc="get_user_in_public_room", + ) + + def get_user_directory_stream_pos(self): + return self._simple_select_one_onecol( + table="user_directory_stream_pos", + keyvalues={}, + retcol="stream_id", + desc="get_user_directory_stream_pos", + ) + + def update_user_directory_stream_pos(self, stream_id): + return self._simple_update_one( + table="user_directory_stream_pos", + keyvalues={}, + updatevalues={"stream_id": stream_id}, + desc="update_user_directory_stream_pos", + ) + + def get_current_state_deltas(self, prev_stream_id): + prev_stream_id = int(prev_stream_id) + if not self._curr_state_delta_stream_cache.has_any_entity_changed(prev_stream_id): + return [] + + def get_current_state_deltas_txn(txn): + # First we calculate the max stream id that will give us less than + # N results. + # We arbitarily limit to 100 stream_id entries to ensure we don't + # select toooo many. + sql = """ + SELECT stream_id, count(*) + FROM current_state_delta_stream + WHERE stream_id > ? + GROUP BY stream_id + ORDER BY stream_id ASC + LIMIT 100 + """ + txn.execute(sql, (prev_stream_id,)) + + total = 0 + max_stream_id = prev_stream_id + for max_stream_id, count in txn: + total += count + if total > 100: + # We arbitarily limit to 100 entries to ensure we don't + # select toooo many. + break + + # Now actually get the deltas + sql = """ + SELECT stream_id, room_id, type, state_key, event_id, prev_event_id + FROM current_state_delta_stream + WHERE ? < stream_id AND stream_id <= ? + ORDER BY stream_id ASC + """ + txn.execute(sql, (prev_stream_id, max_stream_id,)) + return self.cursor_to_dict(txn) + + return self.runInteraction( + "get_current_state_deltas", get_current_state_deltas_txn + ) + + def get_max_stream_id_in_current_state_deltas(self): + return self._simple_select_one_onecol( + table="current_state_delta_stream", + keyvalues={}, + retcol="COALESCE(MAX(stream_id), -1)", + desc="get_max_stream_id_in_current_state_deltas", + ) + + @defer.inlineCallbacks + def search_user_dir(self, search_term, limit): + """Searches for users in directory + + Returns: + dict of the form:: + + { + "limited": <bool>, # whether there were more results or not + "results": [ # Ordered by best match first + { + "user_id": <user_id>, + "display_name": <display_name>, + "avatar_url": <avatar_url> + } + ] + } + """ + if isinstance(self.database_engine, PostgresEngine): + full_query, exact_query, prefix_query = _parse_query_postgres(search_term) + + # We order by rank and then if they have profile info + # The ranking algorithm is hand tweaked for "best" results. Broadly + # the idea is we give a higher weight to exact matches. + # The array of numbers are the weights for the various part of the + # search: (domain, _, display name, localpart) + sql = """ + SELECT user_id, display_name, avatar_url + FROM user_directory_search + INNER JOIN user_directory USING (user_id) + INNER JOIN users_in_pubic_room USING (user_id) + WHERE vector @@ to_tsquery('english', ?) + ORDER BY + 2 * ts_rank_cd( + '{0.1, 0.1, 0.9, 1.0}', + vector, + to_tsquery('english', ?), + 8 + ) + + ts_rank_cd( + '{0.1, 0.1, 0.9, 1.0}', + vector, + to_tsquery('english', ?), + 8 + ) + DESC, + display_name IS NULL, + avatar_url IS NULL + LIMIT ? + """ + args = (full_query, exact_query, prefix_query, limit + 1,) + elif isinstance(self.database_engine, Sqlite3Engine): + search_query = _parse_query_sqlite(search_term) + + sql = """ + SELECT user_id, display_name, avatar_url + FROM user_directory_search + INNER JOIN user_directory USING (user_id) + INNER JOIN users_in_pubic_room USING (user_id) + WHERE value MATCH ? + ORDER BY + rank(matchinfo(user_directory_search)) DESC, + display_name IS NULL, + avatar_url IS NULL + LIMIT ? + """ + args = (search_query, limit + 1) + else: + # This should be unreachable. + raise Exception("Unrecognized database engine") + + results = yield self._execute( + "search_user_dir", self.cursor_to_dict, sql, *args + ) + + limited = len(results) > limit + + defer.returnValue({ + "limited": limited, + "results": results, + }) + + +def _parse_query_sqlite(search_term): + """Takes a plain unicode string from the user and converts it into a form + that can be passed to database. + We use this so that we can add prefix matching, which isn't something + that is supported by default. + + We specifically add both a prefix and non prefix matching term so that + exact matches get ranked higher. + """ + + # Pull out the individual words, discarding any non-word characters. + results = re.findall(r"([\w\-]+)", search_term, re.UNICODE) + return " & ".join("(%s* | %s)" % (result, result,) for result in results) + + +def _parse_query_postgres(search_term): + """Takes a plain unicode string from the user and converts it into a form + that can be passed to database. + We use this so that we can add prefix matching, which isn't something + that is supported by default. + """ + + # Pull out the individual words, discarding any non-word characters. + results = re.findall(r"([\w\-]+)", search_term, re.UNICODE) + + both = " & ".join("(%s:* | %s)" % (result, result,) for result in results) + exact = " & ".join("%s" % (result,) for result in results) + prefix = " & ".join("%s:*" % (result,) for result in results) + + return both, exact, prefix |