summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/4917.misc1
-rw-r--r--synapse/handlers/state_deltas.py70
-rw-r--r--synapse/handlers/user_directory.py51
-rw-r--r--synapse/storage/state_deltas.py74
-rw-r--r--synapse/storage/user_directory.py66
5 files changed, 152 insertions, 110 deletions
diff --git a/changelog.d/4917.misc b/changelog.d/4917.misc
new file mode 100644
index 0000000000..338d8a9a0c
--- /dev/null
+++ b/changelog.d/4917.misc
@@ -0,0 +1 @@
+Refactor out the state deltas portion of the user directory store and handler.
diff --git a/synapse/handlers/state_deltas.py b/synapse/handlers/state_deltas.py
new file mode 100644
index 0000000000..b268bbcb2c
--- /dev/null
+++ b/synapse/handlers/state_deltas.py
@@ -0,0 +1,70 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from twisted.internet import defer
+
+logger = logging.getLogger(__name__)
+
+
+class StateDeltasHandler(object):
+
+    def __init__(self, hs):
+        self.store = hs.get_datastore()
+
+    @defer.inlineCallbacks
+    def _get_key_change(self, prev_event_id, event_id, key_name, public_value):
+        """Given two events check if the `key_name` field in content changed
+        from not matching `public_value` to doing so.
+
+        For example, check if `history_visibility` (`key_name`) changed from
+        `shared` to `world_readable` (`public_value`).
+
+        Returns:
+            None if the field in the events either both match `public_value`
+            or if neither do, i.e. there has been no change.
+            True if it didnt match `public_value` but now does
+            False if it did match `public_value` but now doesn't
+        """
+        prev_event = None
+        event = None
+        if prev_event_id:
+            prev_event = yield self.store.get_event(prev_event_id, allow_none=True)
+
+        if event_id:
+            event = yield self.store.get_event(event_id, allow_none=True)
+
+        if not event and not prev_event:
+            logger.debug("Neither event exists: %r %r", prev_event_id, event_id)
+            defer.returnValue(None)
+
+        prev_value = None
+        value = None
+
+        if prev_event:
+            prev_value = prev_event.content.get(key_name)
+
+        if event:
+            value = event.content.get(key_name)
+
+        logger.debug("prev_value: %r -> value: %r", prev_value, value)
+
+        if value == public_value and prev_value != public_value:
+            defer.returnValue(True)
+        elif value != public_value and prev_value == public_value:
+            defer.returnValue(False)
+        else:
+            defer.returnValue(None)
diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
index 7dc0e236e7..b689979b4b 100644
--- a/synapse/handlers/user_directory.py
+++ b/synapse/handlers/user_directory.py
@@ -21,6 +21,7 @@ from twisted.internet import defer
 
 import synapse.metrics
 from synapse.api.constants import EventTypes, JoinRules, Membership
+from synapse.handlers.state_deltas import StateDeltasHandler
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.roommember import ProfileInfo
 from synapse.types import get_localpart_from_id
@@ -29,7 +30,7 @@ from synapse.util.metrics import Measure
 logger = logging.getLogger(__name__)
 
 
-class UserDirectoryHandler(object):
+class UserDirectoryHandler(StateDeltasHandler):
     """Handles querying of and keeping updated the user_directory.
 
     N.B.: ASSUMES IT IS THE ONLY THING THAT MODIFIES THE USER DIRECTORY
@@ -41,6 +42,8 @@ class UserDirectoryHandler(object):
     """
 
     def __init__(self, hs):
+        super(UserDirectoryHandler, self).__init__(hs)
+
         self.store = hs.get_datastore()
         self.state = hs.get_state_handler()
         self.server_name = hs.hostname
@@ -360,7 +363,7 @@ class UserDirectoryHandler(object):
 
     @defer.inlineCallbacks
     def _handle_remove_user(self, room_id, user_id):
-        """Called when we might need to remove user to directory
+        """Called when we might need to remove user from directory
 
         Args:
             room_id (str): room_id that user left or stopped being public that
@@ -402,47 +405,3 @@ class UserDirectoryHandler(object):
 
         if prev_name != new_name or prev_avatar != new_avatar:
             yield self.store.update_profile_in_user_dir(user_id, new_name, new_avatar)
-
-    @defer.inlineCallbacks
-    def _get_key_change(self, prev_event_id, event_id, key_name, public_value):
-        """Given two events check if the `key_name` field in content changed
-        from not matching `public_value` to doing so.
-
-        For example, check if `history_visibility` (`key_name`) changed from
-        `shared` to `world_readable` (`public_value`).
-
-        Returns:
-            None if the field in the events either both match `public_value`
-            or if neither do, i.e. there has been no change.
-            True if it didnt match `public_value` but now does
-            False if it did match `public_value` but now doesn't
-        """
-        prev_event = None
-        event = None
-        if prev_event_id:
-            prev_event = yield self.store.get_event(prev_event_id, allow_none=True)
-
-        if event_id:
-            event = yield self.store.get_event(event_id, allow_none=True)
-
-        if not event and not prev_event:
-            logger.debug("Neither event exists: %r %r", prev_event_id, event_id)
-            defer.returnValue(None)
-
-        prev_value = None
-        value = None
-
-        if prev_event:
-            prev_value = prev_event.content.get(key_name)
-
-        if event:
-            value = event.content.get(key_name)
-
-        logger.debug("prev_value: %r -> value: %r", prev_value, value)
-
-        if value == public_value and prev_value != public_value:
-            defer.returnValue(True)
-        elif value != public_value and prev_value == public_value:
-            defer.returnValue(False)
-        else:
-            defer.returnValue(None)
diff --git a/synapse/storage/state_deltas.py b/synapse/storage/state_deltas.py
new file mode 100644
index 0000000000..57bc45cdb9
--- /dev/null
+++ b/synapse/storage/state_deltas.py
@@ -0,0 +1,74 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 Vector Creations Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from synapse.storage._base import SQLBaseStore
+
+logger = logging.getLogger(__name__)
+
+
+class StateDeltasStore(SQLBaseStore):
+
+    def get_current_state_deltas(self, prev_stream_id):
+        prev_stream_id = int(prev_stream_id)
+        if not self._curr_state_delta_stream_cache.has_any_entity_changed(prev_stream_id):
+            return []
+
+        def get_current_state_deltas_txn(txn):
+            # First we calculate the max stream id that will give us less than
+            # N results.
+            # We arbitarily limit to 100 stream_id entries to ensure we don't
+            # select toooo many.
+            sql = """
+                SELECT stream_id, count(*)
+                FROM current_state_delta_stream
+                WHERE stream_id > ?
+                GROUP BY stream_id
+                ORDER BY stream_id ASC
+                LIMIT 100
+            """
+            txn.execute(sql, (prev_stream_id,))
+
+            total = 0
+            max_stream_id = prev_stream_id
+            for max_stream_id, count in txn:
+                total += count
+                if total > 100:
+                    # We arbitarily limit to 100 entries to ensure we don't
+                    # select toooo many.
+                    break
+
+            # Now actually get the deltas
+            sql = """
+                SELECT stream_id, room_id, type, state_key, event_id, prev_event_id
+                FROM current_state_delta_stream
+                WHERE ? < stream_id AND stream_id <= ?
+                ORDER BY stream_id ASC
+            """
+            txn.execute(sql, (prev_stream_id, max_stream_id,))
+            return self.cursor_to_dict(txn)
+
+        return self.runInteraction(
+            "get_current_state_deltas", get_current_state_deltas_txn
+        )
+
+    def get_max_stream_id_in_current_state_deltas(self):
+        return self._simple_select_one_onecol(
+            table="current_state_delta_stream",
+            keyvalues={},
+            retcol="COALESCE(MAX(stream_id), -1)",
+            desc="get_max_stream_id_in_current_state_deltas",
+        )
diff --git a/synapse/storage/user_directory.py b/synapse/storage/user_directory.py
index d360e857d1..65bdb1b4a5 100644
--- a/synapse/storage/user_directory.py
+++ b/synapse/storage/user_directory.py
@@ -22,6 +22,7 @@ from synapse.api.constants import EventTypes, JoinRules
 from synapse.storage.background_updates import BackgroundUpdateStore
 from synapse.storage.engines import PostgresEngine, Sqlite3Engine
 from synapse.storage.state import StateFilter
+from synapse.storage.state_deltas import StateDeltasStore
 from synapse.types import get_domain_from_id, get_localpart_from_id
 from synapse.util.caches.descriptors import cached
 
@@ -31,7 +32,7 @@ logger = logging.getLogger(__name__)
 TEMP_TABLE = "_temp_populate_user_directory"
 
 
-class UserDirectoryStore(BackgroundUpdateStore):
+class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
 
     # How many records do we calculate before sending it to
     # add_users_who_share_private_rooms?
@@ -488,16 +489,6 @@ class UserDirectoryStore(BackgroundUpdateStore):
 
         defer.returnValue(user_ids)
 
-    @defer.inlineCallbacks
-    def get_all_local_users(self):
-        """Get all local users
-        """
-        sql = """
-            SELECT name FROM users
-        """
-        rows = yield self._execute("get_all_local_users", None, sql)
-        defer.returnValue([name for name, in rows])
-
     def add_users_who_share_private_room(self, room_id, user_id_tuples):
         """Insert entries into the users_who_share_private_rooms table. The first
         user should be a local user.
@@ -675,59 +666,6 @@ class UserDirectoryStore(BackgroundUpdateStore):
             desc="update_user_directory_stream_pos",
         )
 
-    def get_current_state_deltas(self, prev_stream_id):
-        prev_stream_id = int(prev_stream_id)
-        if not self._curr_state_delta_stream_cache.has_any_entity_changed(
-            prev_stream_id
-        ):
-            return []
-
-        def get_current_state_deltas_txn(txn):
-            # First we calculate the max stream id that will give us less than
-            # N results.
-            # We arbitarily limit to 100 stream_id entries to ensure we don't
-            # select toooo many.
-            sql = """
-                SELECT stream_id, count(*)
-                FROM current_state_delta_stream
-                WHERE stream_id > ?
-                GROUP BY stream_id
-                ORDER BY stream_id ASC
-                LIMIT 100
-            """
-            txn.execute(sql, (prev_stream_id,))
-
-            total = 0
-            max_stream_id = prev_stream_id
-            for max_stream_id, count in txn:
-                total += count
-                if total > 100:
-                    # We arbitarily limit to 100 entries to ensure we don't
-                    # select toooo many.
-                    break
-
-            # Now actually get the deltas
-            sql = """
-                SELECT stream_id, room_id, type, state_key, event_id, prev_event_id
-                FROM current_state_delta_stream
-                WHERE ? < stream_id AND stream_id <= ?
-                ORDER BY stream_id ASC
-            """
-            txn.execute(sql, (prev_stream_id, max_stream_id))
-            return self.cursor_to_dict(txn)
-
-        return self.runInteraction(
-            "get_current_state_deltas", get_current_state_deltas_txn
-        )
-
-    def get_max_stream_id_in_current_state_deltas(self):
-        return self._simple_select_one_onecol(
-            table="current_state_delta_stream",
-            keyvalues={},
-            retcol="COALESCE(MAX(stream_id), -1)",
-            desc="get_max_stream_id_in_current_state_deltas",
-        )
-
     @defer.inlineCallbacks
     def search_user_dir(self, user_id, search_term, limit):
         """Searches for users in directory