summary refs log tree commit diff
path: root/synapse/storage/account_data.py
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2019-04-17 19:44:40 +0100
committerErik Johnston <erik@matrix.org>2019-04-17 19:44:40 +0100
commitca90336a6935b36b5761244005b0f68b496d5d79 (patch)
tree6bbce5eafc0db3b24ccc3b59b051da850382ae09 /synapse/storage/account_data.py
parentAdd management endpoints for account validity (diff)
parentMerge pull request #5047 from matrix-org/babolivier/account_expiration (diff)
downloadsynapse-ca90336a6935b36b5761244005b0f68b496d5d79.tar.xz
Merge branch 'develop' of github.com:matrix-org/synapse into babolivier/account_expiration
Diffstat (limited to 'synapse/storage/account_data.py')
-rw-r--r--synapse/storage/account_data.py77
1 files changed, 35 insertions, 42 deletions
diff --git a/synapse/storage/account_data.py b/synapse/storage/account_data.py
index bbc3355c73..8394389073 100644
--- a/synapse/storage/account_data.py
+++ b/synapse/storage/account_data.py
@@ -41,7 +41,7 @@ class AccountDataWorkerStore(SQLBaseStore):
     def __init__(self, db_conn, hs):
         account_max = self.get_max_account_data_stream_id()
         self._account_data_stream_cache = StreamChangeCache(
-            "AccountDataAndTagsChangeCache", account_max,
+            "AccountDataAndTagsChangeCache", account_max
         )
 
         super(AccountDataWorkerStore, self).__init__(db_conn, hs)
@@ -68,8 +68,10 @@ class AccountDataWorkerStore(SQLBaseStore):
 
         def get_account_data_for_user_txn(txn):
             rows = self._simple_select_list_txn(
-                txn, "account_data", {"user_id": user_id},
-                ["account_data_type", "content"]
+                txn,
+                "account_data",
+                {"user_id": user_id},
+                ["account_data_type", "content"],
             )
 
             global_account_data = {
@@ -77,8 +79,10 @@ class AccountDataWorkerStore(SQLBaseStore):
             }
 
             rows = self._simple_select_list_txn(
-                txn, "room_account_data", {"user_id": user_id},
-                ["room_id", "account_data_type", "content"]
+                txn,
+                "room_account_data",
+                {"user_id": user_id},
+                ["room_id", "account_data_type", "content"],
             )
 
             by_room = {}
@@ -100,10 +104,7 @@ class AccountDataWorkerStore(SQLBaseStore):
         """
         result = yield self._simple_select_one_onecol(
             table="account_data",
-            keyvalues={
-                "user_id": user_id,
-                "account_data_type": data_type,
-            },
+            keyvalues={"user_id": user_id, "account_data_type": data_type},
             retcol="content",
             desc="get_global_account_data_by_type_for_user",
             allow_none=True,
@@ -124,10 +125,13 @@ class AccountDataWorkerStore(SQLBaseStore):
         Returns:
             A deferred dict of the room account_data
         """
+
         def get_account_data_for_room_txn(txn):
             rows = self._simple_select_list_txn(
-                txn, "room_account_data", {"user_id": user_id, "room_id": room_id},
-                ["account_data_type", "content"]
+                txn,
+                "room_account_data",
+                {"user_id": user_id, "room_id": room_id},
+                ["account_data_type", "content"],
             )
 
             return {
@@ -150,6 +154,7 @@ class AccountDataWorkerStore(SQLBaseStore):
             A deferred of the room account_data for that type, or None if
             there isn't any set.
         """
+
         def get_account_data_for_room_and_type_txn(txn):
             content_json = self._simple_select_one_onecol_txn(
                 txn,
@@ -160,18 +165,18 @@ class AccountDataWorkerStore(SQLBaseStore):
                     "account_data_type": account_data_type,
                 },
                 retcol="content",
-                allow_none=True
+                allow_none=True,
             )
 
             return json.loads(content_json) if content_json else None
 
         return self.runInteraction(
-            "get_account_data_for_room_and_type",
-            get_account_data_for_room_and_type_txn,
+            "get_account_data_for_room_and_type", get_account_data_for_room_and_type_txn
         )
 
-    def get_all_updated_account_data(self, last_global_id, last_room_id,
-                                     current_id, limit):
+    def get_all_updated_account_data(
+        self, last_global_id, last_room_id, current_id, limit
+    ):
         """Get all the client account_data that has changed on the server
         Args:
             last_global_id(int): The position to fetch from for top level data
@@ -201,6 +206,7 @@ class AccountDataWorkerStore(SQLBaseStore):
             txn.execute(sql, (last_room_id, current_id, limit))
             room_results = txn.fetchall()
             return (global_results, room_results)
+
         return self.runInteraction(
             "get_all_updated_account_data_txn", get_updated_account_data_txn
         )
@@ -224,9 +230,7 @@ class AccountDataWorkerStore(SQLBaseStore):
 
             txn.execute(sql, (user_id, stream_id))
 
-            global_account_data = {
-                row[0]: json.loads(row[1]) for row in txn
-            }
+            global_account_data = {row[0]: json.loads(row[1]) for row in txn}
 
             sql = (
                 "SELECT room_id, account_data_type, content FROM room_account_data"
@@ -255,7 +259,8 @@ class AccountDataWorkerStore(SQLBaseStore):
     @cachedInlineCallbacks(num_args=2, cache_context=True, max_entries=5000)
     def is_ignored_by(self, ignored_user_id, ignorer_user_id, cache_context):
         ignored_account_data = yield self.get_global_account_data_by_type_for_user(
-            "m.ignored_user_list", ignorer_user_id,
+            "m.ignored_user_list",
+            ignorer_user_id,
             on_invalidate=cache_context.invalidate,
         )
         if not ignored_account_data:
@@ -307,10 +312,7 @@ class AccountDataStore(AccountDataWorkerStore):
                     "room_id": room_id,
                     "account_data_type": account_data_type,
                 },
-                values={
-                    "stream_id": next_id,
-                    "content": content_json,
-                },
+                values={"stream_id": next_id, "content": content_json},
                 lock=False,
             )
 
@@ -324,9 +326,9 @@ class AccountDataStore(AccountDataWorkerStore):
 
             self._account_data_stream_cache.entity_has_changed(user_id, next_id)
             self.get_account_data_for_user.invalidate((user_id,))
-            self.get_account_data_for_room.invalidate((user_id, room_id,))
+            self.get_account_data_for_room.invalidate((user_id, room_id))
             self.get_account_data_for_room_and_type.prefill(
-                (user_id, room_id, account_data_type,), content,
+                (user_id, room_id, account_data_type), content
             )
 
         result = self._account_data_id_gen.get_current_token()
@@ -351,14 +353,8 @@ class AccountDataStore(AccountDataWorkerStore):
             yield self._simple_upsert(
                 desc="add_user_account_data",
                 table="account_data",
-                keyvalues={
-                    "user_id": user_id,
-                    "account_data_type": account_data_type,
-                },
-                values={
-                    "stream_id": next_id,
-                    "content": content_json,
-                },
+                keyvalues={"user_id": user_id, "account_data_type": account_data_type},
+                values={"stream_id": next_id, "content": content_json},
                 lock=False,
             )
 
@@ -370,12 +366,10 @@ class AccountDataStore(AccountDataWorkerStore):
             # transaction.
             yield self._update_max_stream_id(next_id)
 
-            self._account_data_stream_cache.entity_has_changed(
-                user_id, next_id,
-            )
+            self._account_data_stream_cache.entity_has_changed(user_id, next_id)
             self.get_account_data_for_user.invalidate((user_id,))
             self.get_global_account_data_by_type_for_user.invalidate(
-                (account_data_type, user_id,)
+                (account_data_type, user_id)
             )
 
         result = self._account_data_id_gen.get_current_token()
@@ -387,6 +381,7 @@ class AccountDataStore(AccountDataWorkerStore):
         Args:
             next_id(int): The the revision to advance to.
         """
+
         def _update(txn):
             update_max_id_sql = (
                 "UPDATE account_data_max_stream_id"
@@ -394,7 +389,5 @@ class AccountDataStore(AccountDataWorkerStore):
                 " WHERE stream_id < ?"
             )
             txn.execute(update_max_id_sql, (next_id, next_id))
-        return self.runInteraction(
-            "update_account_data_max_stream_id",
-            _update,
-        )
+
+        return self.runInteraction("update_account_data_max_stream_id", _update)