diff options
Diffstat (limited to 'synapse/storage/account_data.py')
-rw-r--r-- | synapse/storage/account_data.py | 212 |
1 files changed, 151 insertions, 61 deletions
diff --git a/synapse/storage/account_data.py b/synapse/storage/account_data.py index aa84ffc2b0..bbc3355c73 100644 --- a/synapse/storage/account_data.py +++ b/synapse/storage/account_data.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,18 +14,46 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import SQLBaseStore -from twisted.internet import defer +import abc +import logging -from synapse.util.caches.descriptors import cached, cachedList, cachedInlineCallbacks +from canonicaljson import json -import ujson as json -import logging +from twisted.internet import defer + +from synapse.storage._base import SQLBaseStore +from synapse.storage.util.id_generators import StreamIdGenerator +from synapse.util.caches.descriptors import cached, cachedInlineCallbacks +from synapse.util.caches.stream_change_cache import StreamChangeCache logger = logging.getLogger(__name__) -class AccountDataStore(SQLBaseStore): +class AccountDataWorkerStore(SQLBaseStore): + """This is an abstract base class where subclasses must implement + `get_max_account_data_stream_id` which can be called in the initializer. + """ + + # This ABCMeta metaclass ensures that we cannot be instantiated without + # the abstract methods being implemented. + __metaclass__ = abc.ABCMeta + + def __init__(self, db_conn, hs): + account_max = self.get_max_account_data_stream_id() + self._account_data_stream_cache = StreamChangeCache( + "AccountDataAndTagsChangeCache", account_max, + ) + + super(AccountDataWorkerStore, self).__init__(db_conn, hs) + + @abc.abstractmethod + def get_max_account_data_stream_id(self): + """Get the current max stream ID for account data stream + + Returns: + int + """ + raise NotImplementedError() @cached() def get_account_data_for_user(self, user_id): @@ -63,7 +92,7 @@ class AccountDataStore(SQLBaseStore): "get_account_data_for_user", get_account_data_for_user_txn ) - @cachedInlineCallbacks(num_args=2) + @cachedInlineCallbacks(num_args=2, max_entries=5000) def get_global_account_data_by_type_for_user(self, data_type, user_id): """ Returns: @@ -85,25 +114,7 @@ class AccountDataStore(SQLBaseStore): else: defer.returnValue(None) - @cachedList(cached_method_name="get_global_account_data_by_type_for_user", - num_args=2, list_name="user_ids", inlineCallbacks=True) - def get_global_account_data_by_type_for_users(self, data_type, user_ids): - rows = yield self._simple_select_many_batch( - table="account_data", - column="user_id", - iterable=user_ids, - keyvalues={ - "account_data_type": data_type, - }, - retcols=("user_id", "content",), - desc="get_global_account_data_by_type_for_users", - ) - - defer.returnValue({ - row["user_id"]: json.loads(row["content"]) if row["content"] else None - for row in rows - }) - + @cached(num_args=2) def get_account_data_for_room(self, user_id, room_id): """Get all the client account_data for a user for a room. @@ -127,6 +138,38 @@ class AccountDataStore(SQLBaseStore): "get_account_data_for_room", get_account_data_for_room_txn ) + @cached(num_args=3, max_entries=5000) + def get_account_data_for_room_and_type(self, user_id, room_id, account_data_type): + """Get the client account_data of given type for a user for a room. + + Args: + user_id(str): The user to get the account_data for. + room_id(str): The room to get the account_data for. + account_data_type (str): The account data type to get. + Returns: + A deferred of the room account_data for that type, or None if + there isn't any set. + """ + def get_account_data_for_room_and_type_txn(txn): + content_json = self._simple_select_one_onecol_txn( + txn, + table="room_account_data", + keyvalues={ + "user_id": user_id, + "room_id": room_id, + "account_data_type": account_data_type, + }, + retcol="content", + allow_none=True + ) + + return json.loads(content_json) if content_json else None + + return self.runInteraction( + "get_account_data_for_room_and_type", + get_account_data_for_room_and_type_txn, + ) + def get_all_updated_account_data(self, last_global_id, last_room_id, current_id, limit): """Get all the client account_data that has changed on the server @@ -209,6 +252,36 @@ class AccountDataStore(SQLBaseStore): "get_updated_account_data_for_user", get_updated_account_data_for_user_txn ) + @cachedInlineCallbacks(num_args=2, cache_context=True, max_entries=5000) + def is_ignored_by(self, ignored_user_id, ignorer_user_id, cache_context): + ignored_account_data = yield self.get_global_account_data_by_type_for_user( + "m.ignored_user_list", ignorer_user_id, + on_invalidate=cache_context.invalidate, + ) + if not ignored_account_data: + defer.returnValue(False) + + defer.returnValue( + ignored_user_id in ignored_account_data.get("ignored_users", {}) + ) + + +class AccountDataStore(AccountDataWorkerStore): + def __init__(self, db_conn, hs): + self._account_data_id_gen = StreamIdGenerator( + db_conn, "account_data_max_stream_id", "stream_id" + ) + + super(AccountDataStore, self).__init__(db_conn, hs) + + def get_max_account_data_stream_id(self): + """Get the current max stream id for the private user data stream + + Returns: + A deferred int. + """ + return self._account_data_id_gen.get_current_token() + @defer.inlineCallbacks def add_account_data_to_room(self, user_id, room_id, account_data_type, content): """Add some account_data to a room for a user. @@ -222,9 +295,12 @@ class AccountDataStore(SQLBaseStore): """ content_json = json.dumps(content) - def add_account_data_txn(txn, next_id): - self._simple_upsert_txn( - txn, + with self._account_data_id_gen.get_next() as next_id: + # no need to lock here as room_account_data has a unique constraint + # on (user_id, room_id, account_data_type) so _simple_upsert will + # retry if there is a conflict. + yield self._simple_upsert( + desc="add_room_account_data", table="room_account_data", keyvalues={ "user_id": user_id, @@ -234,18 +310,23 @@ class AccountDataStore(SQLBaseStore): values={ "stream_id": next_id, "content": content_json, - } - ) - txn.call_after( - self._account_data_stream_cache.entity_has_changed, - user_id, next_id, + }, + lock=False, ) - txn.call_after(self.get_account_data_for_user.invalidate, (user_id,)) - self._update_max_stream_id(txn, next_id) - with self._account_data_id_gen.get_next() as next_id: - yield self.runInteraction( - "add_room_account_data", add_account_data_txn, next_id + # it's theoretically possible for the above to succeed and the + # below to fail - in which case we might reuse a stream id on + # restart, and the above update might not get propagated. That + # doesn't sound any worse than the whole update getting lost, + # which is what would happen if we combined the two into one + # transaction. + yield self._update_max_stream_id(next_id) + + self._account_data_stream_cache.entity_has_changed(user_id, next_id) + self.get_account_data_for_user.invalidate((user_id,)) + self.get_account_data_for_room.invalidate((user_id, room_id,)) + self.get_account_data_for_room_and_type.prefill( + (user_id, room_id, account_data_type,), content, ) result = self._account_data_id_gen.get_current_token() @@ -263,9 +344,12 @@ class AccountDataStore(SQLBaseStore): """ content_json = json.dumps(content) - def add_account_data_txn(txn, next_id): - self._simple_upsert_txn( - txn, + with self._account_data_id_gen.get_next() as next_id: + # no need to lock here as account_data has a unique constraint on + # (user_id, account_data_type) so _simple_upsert will retry if + # there is a conflict. + yield self._simple_upsert( + desc="add_user_account_data", table="account_data", keyvalues={ "user_id": user_id, @@ -274,37 +358,43 @@ class AccountDataStore(SQLBaseStore): values={ "stream_id": next_id, "content": content_json, - } + }, + lock=False, ) - txn.call_after( - self._account_data_stream_cache.entity_has_changed, + + # it's theoretically possible for the above to succeed and the + # below to fail - in which case we might reuse a stream id on + # restart, and the above update might not get propagated. That + # doesn't sound any worse than the whole update getting lost, + # which is what would happen if we combined the two into one + # transaction. + yield self._update_max_stream_id(next_id) + + self._account_data_stream_cache.entity_has_changed( user_id, next_id, ) - txn.call_after(self.get_account_data_for_user.invalidate, (user_id,)) - txn.call_after( - self.get_global_account_data_by_type_for_user.invalidate, + self.get_account_data_for_user.invalidate((user_id,)) + self.get_global_account_data_by_type_for_user.invalidate( (account_data_type, user_id,) ) - self._update_max_stream_id(txn, next_id) - - with self._account_data_id_gen.get_next() as next_id: - yield self.runInteraction( - "add_user_account_data", add_account_data_txn, next_id - ) result = self._account_data_id_gen.get_current_token() defer.returnValue(result) - def _update_max_stream_id(self, txn, next_id): + def _update_max_stream_id(self, next_id): """Update the max stream_id Args: - txn: The database cursor next_id(int): The the revision to advance to. """ - update_max_id_sql = ( - "UPDATE account_data_max_stream_id" - " SET stream_id = ?" - " WHERE stream_id < ?" + def _update(txn): + update_max_id_sql = ( + "UPDATE account_data_max_stream_id" + " SET stream_id = ?" + " WHERE stream_id < ?" + ) + txn.execute(update_max_id_sql, (next_id, next_id)) + return self.runInteraction( + "update_account_data_max_stream_id", + _update, ) - txn.execute(update_max_id_sql, (next_id, next_id)) |