summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/__init__.py21
-rw-r--r--synapse/storage/account_data.py76
-rw-r--r--synapse/storage/push_rule.py72
-rw-r--r--synapse/storage/pusher.py11
-rw-r--r--synapse/storage/tags.py16
5 files changed, 131 insertions, 65 deletions
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index e1c4fe086e..0f136f8a06 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -104,9 +105,6 @@ class DataStore(RoomMemberStore, RoomStore,
             db_conn, "events", "stream_ordering", step=-1,
             extra_tables=[("ex_outlier_stream", "event_stream_ordering")]
         )
-        self._account_data_id_gen = StreamIdGenerator(
-            db_conn, "account_data_max_stream_id", "stream_id"
-        )
         self._presence_id_gen = StreamIdGenerator(
             db_conn, "presence_stream", "stream_id"
         )
@@ -159,11 +157,6 @@ class DataStore(RoomMemberStore, RoomStore,
             "MembershipStreamChangeCache", events_max,
         )
 
-        account_max = self._account_data_id_gen.get_current_token()
-        self._account_data_stream_cache = StreamChangeCache(
-            "AccountDataAndTagsChangeCache", account_max,
-        )
-
         self._presence_on_startup = self._get_active_presence(db_conn)
 
         presence_cache_prefill, min_presence_val = self._get_cache_dict(
@@ -177,18 +170,6 @@ class DataStore(RoomMemberStore, RoomStore,
             prefilled_cache=presence_cache_prefill
         )
 
-        push_rules_prefill, push_rules_id = self._get_cache_dict(
-            db_conn, "push_rules_stream",
-            entity_column="user_id",
-            stream_column="stream_id",
-            max_value=self._push_rules_stream_id_gen.get_current_token()[0],
-        )
-
-        self.push_rules_stream_cache = StreamChangeCache(
-            "PushRulesStreamChangeCache", push_rules_id,
-            prefilled_cache=push_rules_prefill,
-        )
-
         max_device_inbox_id = self._device_inbox_id_gen.get_current_token()
         device_inbox_prefill, min_device_inbox_id = self._get_cache_dict(
             db_conn, "device_inbox",
diff --git a/synapse/storage/account_data.py b/synapse/storage/account_data.py
index 56a0bde549..466194e96f 100644
--- a/synapse/storage/account_data.py
+++ b/synapse/storage/account_data.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,18 +14,46 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ._base import SQLBaseStore
 from twisted.internet import defer
 
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.util.id_generators import StreamIdGenerator
+
+from synapse.util.caches.stream_change_cache import StreamChangeCache
 from synapse.util.caches.descriptors import cached, cachedList, cachedInlineCallbacks
 
+import abc
 import ujson as json
 import logging
 
 logger = logging.getLogger(__name__)
 
 
-class AccountDataStore(SQLBaseStore):
+class AccountDataWorkerStore(SQLBaseStore):
+    """This is an abstract base class where subclasses must implement
+    `get_max_account_data_stream_id` which can be called in the initializer.
+    """
+
+    # This ABCMeta metaclass ensures that we cannot be instantiated without
+    # the abstract methods being implemented.
+    __metaclass__ = abc.ABCMeta
+
+    def __init__(self, db_conn, hs):
+        account_max = self.get_max_account_data_stream_id()
+        self._account_data_stream_cache = StreamChangeCache(
+            "AccountDataAndTagsChangeCache", account_max,
+        )
+
+        super(AccountDataWorkerStore, self).__init__(db_conn, hs)
+
+    @abc.abstractmethod
+    def get_max_account_data_stream_id(self):
+        """Get the current max stream ID for account data stream
+
+        Returns:
+            int
+        """
+        raise NotImplementedError()
 
     @cached()
     def get_account_data_for_user(self, user_id):
@@ -209,6 +238,36 @@ class AccountDataStore(SQLBaseStore):
             "get_updated_account_data_for_user", get_updated_account_data_for_user_txn
         )
 
+    @cachedInlineCallbacks(num_args=2, cache_context=True, max_entries=5000)
+    def is_ignored_by(self, ignored_user_id, ignorer_user_id, cache_context):
+        ignored_account_data = yield self.get_global_account_data_by_type_for_user(
+            "m.ignored_user_list", ignorer_user_id,
+            on_invalidate=cache_context.invalidate,
+        )
+        if not ignored_account_data:
+            defer.returnValue(False)
+
+        defer.returnValue(
+            ignored_user_id in ignored_account_data.get("ignored_users", {})
+        )
+
+
+class AccountDataStore(AccountDataWorkerStore):
+    def __init__(self, db_conn, hs):
+        self._account_data_id_gen = StreamIdGenerator(
+            db_conn, "account_data_max_stream_id", "stream_id"
+        )
+
+        super(AccountDataStore, self).__init__(db_conn, hs)
+
+    def get_max_account_data_stream_id(self):
+        """Get the current max stream id for the private user data stream
+
+        Returns:
+            A deferred int.
+        """
+        return self._account_data_id_gen.get_current_token()
+
     @defer.inlineCallbacks
     def add_account_data_to_room(self, user_id, room_id, account_data_type, content):
         """Add some account_data to a room for a user.
@@ -321,16 +380,3 @@ class AccountDataStore(SQLBaseStore):
             "update_account_data_max_stream_id",
             _update,
         )
-
-    @cachedInlineCallbacks(num_args=2, cache_context=True, max_entries=5000)
-    def is_ignored_by(self, ignored_user_id, ignorer_user_id, cache_context):
-        ignored_account_data = yield self.get_global_account_data_by_type_for_user(
-            "m.ignored_user_list", ignorer_user_id,
-            on_invalidate=cache_context.invalidate,
-        )
-        if not ignored_account_data:
-            defer.returnValue(False)
-
-        defer.returnValue(
-            ignored_user_id in ignored_account_data.get("ignored_users", {})
-        )
diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py
index 8758b1c0c7..583efb7bdf 100644
--- a/synapse/storage/push_rule.py
+++ b/synapse/storage/push_rule.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -15,10 +16,12 @@
 
 from ._base import SQLBaseStore
 from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
+from synapse.util.caches.stream_change_cache import StreamChangeCache
 from synapse.push.baserules import list_with_base_rules
 from synapse.api.constants import EventTypes
 from twisted.internet import defer
 
+import abc
 import logging
 import simplejson as json
 
@@ -48,7 +51,39 @@ def _load_rules(rawrules, enabled_map):
     return rules
 
 
-class PushRuleStore(SQLBaseStore):
+class PushRulesWorkerStore(SQLBaseStore):
+    """This is an abstract base class where subclasses must implement
+    `get_max_push_rules_stream_id` which can be called in the initializer.
+    """
+
+    # This ABCMeta metaclass ensures that we cannot be instantiated without
+    # the abstract methods being implemented.
+    __metaclass__ = abc.ABCMeta
+
+    def __init__(self, db_conn, hs):
+        super(PushRulesWorkerStore, self).__init__(db_conn, hs)
+
+        push_rules_prefill, push_rules_id = self._get_cache_dict(
+            db_conn, "push_rules_stream",
+            entity_column="user_id",
+            stream_column="stream_id",
+            max_value=self.get_max_push_rules_stream_id(),
+        )
+
+        self.push_rules_stream_cache = StreamChangeCache(
+            "PushRulesStreamChangeCache", push_rules_id,
+            prefilled_cache=push_rules_prefill,
+        )
+
+    @abc.abstractmethod
+    def get_max_push_rules_stream_id(self):
+        """Get the position of the push rules stream.
+
+        Returns:
+            int
+        """
+        raise NotImplementedError()
+
     @cachedInlineCallbacks(max_entries=5000)
     def get_push_rules_for_user(self, user_id):
         rows = yield self._simple_select_list(
@@ -89,6 +124,24 @@ class PushRuleStore(SQLBaseStore):
             r['rule_id']: False if r['enabled'] == 0 else True for r in results
         })
 
+    def have_push_rules_changed_for_user(self, user_id, last_id):
+        if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id):
+            return defer.succeed(False)
+        else:
+            def have_push_rules_changed_txn(txn):
+                sql = (
+                    "SELECT COUNT(stream_id) FROM push_rules_stream"
+                    " WHERE user_id = ? AND ? < stream_id"
+                )
+                txn.execute(sql, (user_id, last_id))
+                count, = txn.fetchone()
+                return bool(count)
+            return self.runInteraction(
+                "have_push_rules_changed", have_push_rules_changed_txn
+            )
+
+
+class PushRuleStore(PushRulesWorkerStore):
     @cachedList(cached_method_name="get_push_rules_for_user",
                 list_name="user_ids", num_args=1, inlineCallbacks=True)
     def bulk_get_push_rules(self, user_ids):
@@ -526,21 +579,8 @@ class PushRuleStore(SQLBaseStore):
         room stream ordering it corresponds to."""
         return self._push_rules_stream_id_gen.get_current_token()
 
-    def have_push_rules_changed_for_user(self, user_id, last_id):
-        if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id):
-            return defer.succeed(False)
-        else:
-            def have_push_rules_changed_txn(txn):
-                sql = (
-                    "SELECT COUNT(stream_id) FROM push_rules_stream"
-                    " WHERE user_id = ? AND ? < stream_id"
-                )
-                txn.execute(sql, (user_id, last_id))
-                count, = txn.fetchone()
-                return bool(count)
-            return self.runInteraction(
-                "have_push_rules_changed", have_push_rules_changed_txn
-            )
+    def get_max_push_rules_stream_id(self):
+        return self.get_push_rules_stream_token()[0]
 
 
 class RuleNotFoundException(Exception):
diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py
index 3d8b4d5d5b..f4af3e4caa 100644
--- a/synapse/storage/pusher.py
+++ b/synapse/storage/pusher.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -27,7 +28,7 @@ import types
 logger = logging.getLogger(__name__)
 
 
-class PusherStore(SQLBaseStore):
+class PusherWorkerStore(SQLBaseStore):
     def _decode_pushers_rows(self, rows):
         for r in rows:
             dataJson = r['data']
@@ -102,9 +103,6 @@ class PusherStore(SQLBaseStore):
         rows = yield self.runInteraction("get_all_pushers", get_pushers)
         defer.returnValue(rows)
 
-    def get_pushers_stream_token(self):
-        return self._pushers_id_gen.get_current_token()
-
     def get_all_updated_pushers(self, last_id, current_id, limit):
         if last_id == current_id:
             return defer.succeed(([], []))
@@ -177,6 +175,11 @@ class PusherStore(SQLBaseStore):
             "get_all_updated_pushers_rows", get_all_updated_pushers_rows_txn
         )
 
+
+class PusherStore(PusherWorkerStore):
+    def get_pushers_stream_token(self):
+        return self._pushers_id_gen.get_current_token()
+
     @cachedInlineCallbacks(num_args=1, max_entries=15000)
     def get_if_user_has_pusher(self, user_id):
         # This only exists for the cachedList decorator
diff --git a/synapse/storage/tags.py b/synapse/storage/tags.py
index bff73f3f04..fc46bf7bb3 100644
--- a/synapse/storage/tags.py
+++ b/synapse/storage/tags.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,7 +14,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ._base import SQLBaseStore
+from synapse.storage.account_data import AccountDataWorkerStore
+
 from synapse.util.caches.descriptors import cached
 from twisted.internet import defer
 
@@ -23,15 +25,7 @@ import logging
 logger = logging.getLogger(__name__)
 
 
-class TagsStore(SQLBaseStore):
-    def get_max_account_data_stream_id(self):
-        """Get the current max stream id for the private user data stream
-
-        Returns:
-            A deferred int.
-        """
-        return self._account_data_id_gen.get_current_token()
-
+class TagsWorkerStore(AccountDataWorkerStore):
     @cached()
     def get_tags_for_user(self, user_id):
         """Get all the tags for a user.
@@ -170,6 +164,8 @@ class TagsStore(SQLBaseStore):
             row["tag"]: json.loads(row["content"]) for row in rows
         })
 
+
+class TagsStore(TagsWorkerStore):
     @defer.inlineCallbacks
     def add_tag_to_room(self, user_id, room_id, tag, content):
         """Add a tag to a room for a user.