summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2016-05-10 13:27:41 +0100
committerErik Johnston <erik@matrix.org>2016-05-10 13:27:41 +0100
commite581754f0972b29ee2f64fb0b70881f6f9112108 (patch)
treee6791617720f3d92c417afd97d382eadb2123242 /synapse/storage
parentMerge pull request #773 from matrix-org/erikj/get_domian_from_id (diff)
parentMerge branch 'develop' of github.com:matrix-org/synapse into erikj/ignore_user (diff)
downloadsynapse-e581754f0972b29ee2f64fb0b70881f6f9112108.tar.xz
Merge pull request #763 from matrix-org/erikj/ignore_user
Implement basic ignore user API
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/account_data.py50
1 files changed, 50 insertions, 0 deletions
diff --git a/synapse/storage/account_data.py b/synapse/storage/account_data.py
index 7a7fbf1e52..ec7e8d40d2 100644
--- a/synapse/storage/account_data.py
+++ b/synapse/storage/account_data.py
@@ -16,6 +16,8 @@
 from ._base import SQLBaseStore
 from twisted.internet import defer
 
+from synapse.util.caches.descriptors import cached, cachedList, cachedInlineCallbacks
+
 import ujson as json
 import logging
 
@@ -24,6 +26,7 @@ logger = logging.getLogger(__name__)
 
 class AccountDataStore(SQLBaseStore):
 
+    @cached()
     def get_account_data_for_user(self, user_id):
         """Get all the client account_data for a user.
 
@@ -60,6 +63,47 @@ class AccountDataStore(SQLBaseStore):
             "get_account_data_for_user", get_account_data_for_user_txn
         )
 
+    @cachedInlineCallbacks(num_args=2)
+    def get_global_account_data_by_type_for_user(self, data_type, user_id):
+        """
+        Returns:
+            Deferred: A dict
+        """
+        result = yield self._simple_select_one_onecol(
+            table="account_data",
+            keyvalues={
+                "user_id": user_id,
+                "account_data_type": data_type,
+            },
+            retcol="content",
+            desc="get_global_account_data_by_type_for_user",
+            allow_none=True,
+        )
+
+        if result:
+            defer.returnValue(json.loads(result))
+        else:
+            defer.returnValue(None)
+
+    @cachedList(cached_method_name="get_global_account_data_by_type_for_user",
+                num_args=2, list_name="user_ids", inlineCallbacks=True)
+    def get_global_account_data_by_type_for_users(self, data_type, user_ids):
+        rows = yield self._simple_select_many_batch(
+            table="account_data",
+            column="user_id",
+            iterable=user_ids,
+            keyvalues={
+                "account_data_type": data_type,
+            },
+            retcols=("user_id", "content",),
+            desc="get_global_account_data_by_type_for_users",
+        )
+
+        defer.returnValue({
+            row["user_id"]: json.loads(row["content"]) if row["content"] else None
+            for row in rows
+        })
+
     def get_account_data_for_room(self, user_id, room_id):
         """Get all the client account_data for a user for a room.
 
@@ -193,6 +237,7 @@ class AccountDataStore(SQLBaseStore):
                 self._account_data_stream_cache.entity_has_changed,
                 user_id, next_id,
             )
+            txn.call_after(self.get_account_data_for_user.invalidate, (user_id,))
             self._update_max_stream_id(txn, next_id)
 
         with self._account_data_id_gen.get_next() as next_id:
@@ -232,6 +277,11 @@ class AccountDataStore(SQLBaseStore):
                 self._account_data_stream_cache.entity_has_changed,
                 user_id, next_id,
             )
+            txn.call_after(self.get_account_data_for_user.invalidate, (user_id,))
+            txn.call_after(
+                self.get_global_account_data_by_type_for_user.invalidate,
+                (account_data_type, user_id,)
+            )
             self._update_max_stream_id(txn, next_id)
 
         with self._account_data_id_gen.get_next() as next_id: