diff --git a/synapse/storage/account_data.py b/synapse/storage/account_data.py
index 56a0bde549..f83ff0454a 100644
--- a/synapse/storage/account_data.py
+++ b/synapse/storage/account_data.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,18 +14,46 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import SQLBaseStore
from twisted.internet import defer
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.util.id_generators import StreamIdGenerator
+
+from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.caches.descriptors import cached, cachedList, cachedInlineCallbacks
-import ujson as json
+import abc
+import simplejson as json
import logging
logger = logging.getLogger(__name__)
-class AccountDataStore(SQLBaseStore):
+class AccountDataWorkerStore(SQLBaseStore):
+ """This is an abstract base class where subclasses must implement
+ `get_max_account_data_stream_id` which can be called in the initializer.
+ """
+
+ # This ABCMeta metaclass ensures that we cannot be instantiated without
+ # the abstract methods being implemented.
+ __metaclass__ = abc.ABCMeta
+
+ def __init__(self, db_conn, hs):
+ account_max = self.get_max_account_data_stream_id()
+ self._account_data_stream_cache = StreamChangeCache(
+ "AccountDataAndTagsChangeCache", account_max,
+ )
+
+ super(AccountDataWorkerStore, self).__init__(db_conn, hs)
+
+ @abc.abstractmethod
+ def get_max_account_data_stream_id(self):
+ """Get the current max stream ID for account data stream
+
+ Returns:
+ int
+ """
+ raise NotImplementedError()
@cached()
def get_account_data_for_user(self, user_id):
@@ -104,6 +133,7 @@ class AccountDataStore(SQLBaseStore):
for row in rows
})
+ @cached(num_args=2)
def get_account_data_for_room(self, user_id, room_id):
"""Get all the client account_data for a user for a room.
@@ -127,6 +157,38 @@ class AccountDataStore(SQLBaseStore):
"get_account_data_for_room", get_account_data_for_room_txn
)
+ @cached(num_args=3, max_entries=5000)
+ def get_account_data_for_room_and_type(self, user_id, room_id, account_data_type):
+ """Get the client account_data of given type for a user for a room.
+
+ Args:
+ user_id(str): The user to get the account_data for.
+ room_id(str): The room to get the account_data for.
+ account_data_type (str): The account data type to get.
+ Returns:
+ A deferred of the room account_data for that type, or None if
+ there isn't any set.
+ """
+ def get_account_data_for_room_and_type_txn(txn):
+ content_json = self._simple_select_one_onecol_txn(
+ txn,
+ table="room_account_data",
+ keyvalues={
+ "user_id": user_id,
+ "room_id": room_id,
+ "account_data_type": account_data_type,
+ },
+ retcol="content",
+ allow_none=True
+ )
+
+ return json.loads(content_json) if content_json else None
+
+ return self.runInteraction(
+ "get_account_data_for_room_and_type",
+ get_account_data_for_room_and_type_txn,
+ )
+
def get_all_updated_account_data(self, last_global_id, last_room_id,
current_id, limit):
"""Get all the client account_data that has changed on the server
@@ -209,6 +271,36 @@ class AccountDataStore(SQLBaseStore):
"get_updated_account_data_for_user", get_updated_account_data_for_user_txn
)
+ @cachedInlineCallbacks(num_args=2, cache_context=True, max_entries=5000)
+ def is_ignored_by(self, ignored_user_id, ignorer_user_id, cache_context):
+ ignored_account_data = yield self.get_global_account_data_by_type_for_user(
+ "m.ignored_user_list", ignorer_user_id,
+ on_invalidate=cache_context.invalidate,
+ )
+ if not ignored_account_data:
+ defer.returnValue(False)
+
+ defer.returnValue(
+ ignored_user_id in ignored_account_data.get("ignored_users", {})
+ )
+
+
+class AccountDataStore(AccountDataWorkerStore):
+ def __init__(self, db_conn, hs):
+ self._account_data_id_gen = StreamIdGenerator(
+ db_conn, "account_data_max_stream_id", "stream_id"
+ )
+
+ super(AccountDataStore, self).__init__(db_conn, hs)
+
+ def get_max_account_data_stream_id(self):
+ """Get the current max stream id for the private user data stream
+
+ Returns:
+ A deferred int.
+ """
+ return self._account_data_id_gen.get_current_token()
+
@defer.inlineCallbacks
def add_account_data_to_room(self, user_id, room_id, account_data_type, content):
"""Add some account_data to a room for a user.
@@ -251,6 +343,10 @@ class AccountDataStore(SQLBaseStore):
self._account_data_stream_cache.entity_has_changed(user_id, next_id)
self.get_account_data_for_user.invalidate((user_id,))
+ self.get_account_data_for_room.invalidate((user_id, room_id,))
+ self.get_account_data_for_room_and_type.prefill(
+ (user_id, room_id, account_data_type,), content,
+ )
result = self._account_data_id_gen.get_current_token()
defer.returnValue(result)
@@ -321,16 +417,3 @@ class AccountDataStore(SQLBaseStore):
"update_account_data_max_stream_id",
_update,
)
-
- @cachedInlineCallbacks(num_args=2, cache_context=True, max_entries=5000)
- def is_ignored_by(self, ignored_user_id, ignorer_user_id, cache_context):
- ignored_account_data = yield self.get_global_account_data_by_type_for_user(
- "m.ignored_user_list", ignorer_user_id,
- on_invalidate=cache_context.invalidate,
- )
- if not ignored_account_data:
- defer.returnValue(False)
-
- defer.returnValue(
- ignored_user_id in ignored_account_data.get("ignored_users", {})
- )
|