Diffstat (limited to 'synapse/storage')
 -rw-r--r--  synapse/storage/background_updates.py | 83
 -rw-r--r--  synapse/storage/events.py             | 22
 -rw-r--r--  synapse/storage/receipts.py           | 11
 -rw-r--r--  synapse/storage/roommember.py         | 24
 -rw-r--r--  synapse/storage/state.py              | 14
 5 files changed, 103 insertions(+), 51 deletions(-)
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index 813ad59e56..d4cf0fc59b 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -228,46 +228,69 @@ class BackgroundUpdateStore(SQLBaseStore):
             columns (list[str]): columns/expressions to include in index
         """
 
-        # if this is postgres, we add the indexes concurrently. Otherwise
-        # we fall back to doing it inline
-        if isinstance(self.database_engine, engines.PostgresEngine):
-            conc = True
-        else:
-            conc = False
-            # We don't use partial indices on SQLite as it wasn't introduced
-            # until 3.8, and wheezy has 3.7
-            where_clause = None
-
-        sql = (
-            "CREATE INDEX %(conc)s %(name)s ON %(table)s (%(columns)s)"
-            " %(where_clause)s"
-        ) % {
-            "conc": "CONCURRENTLY" if conc else "",
-            "name": index_name,
-            "table": table,
-            "columns": ", ".join(columns),
-            "where_clause": "WHERE " + where_clause if where_clause else ""
-        }
-
-        def create_index_concurrently(conn):
+        def create_index_psql(conn):
             conn.rollback()
             # postgres insists on autocommit for the index
             conn.set_session(autocommit=True)
-            c = conn.cursor()
-            c.execute(sql)
-            conn.set_session(autocommit=False)
 
-        def create_index(conn):
+            try:
+                c = conn.cursor()
+
+                # If a previous attempt to create the index was interrupted,
+                # we may already have a half-built index. Let's just drop it
+                # before trying to create it again.
+
+                sql = "DROP INDEX IF EXISTS %s" % (index_name,)
+                logger.debug("[SQL] %s", sql)
+                c.execute(sql)
+
+                sql = (
+                    "CREATE INDEX CONCURRENTLY %(name)s ON %(table)s"
+                    " (%(columns)s) %(where_clause)s"
+                ) % {
+                    "name": index_name,
+                    "table": table,
+                    "columns": ", ".join(columns),
+                    "where_clause": "WHERE " + where_clause if where_clause else ""
+                }
+                logger.debug("[SQL] %s", sql)
+                c.execute(sql)
+            finally:
+                conn.set_session(autocommit=False)
+
+        def create_index_sqlite(conn):
+            # SQLite doesn't support concurrent creation of indexes.
+            #
+            # We don't use partial indices on SQLite as it wasn't introduced
+            # until 3.8, and wheezy has 3.7
+            #
+            # We assume that sqlite doesn't give us invalid indices; however
+            # we may still end up with the index existing but the
+            # background_updates not having been recorded if synapse got shut
+            # down at the wrong moment - hence we use IF NOT EXISTS. (SQLite
+            # has supported CREATE TABLE|INDEX IF NOT EXISTS since 3.3.0.)
+            sql = (
+                "CREATE INDEX IF NOT EXISTS %(name)s ON %(table)s"
+                " (%(columns)s)"
+            ) % {
+                "name": index_name,
+                "table": table,
+                "columns": ", ".join(columns),
+            }
+
             c = conn.cursor()
+            logger.debug("[SQL] %s", sql)
             c.execute(sql)
 
+        if isinstance(self.database_engine, engines.PostgresEngine):
+            runner = create_index_psql
+        else:
+            runner = create_index_sqlite
+
         @defer.inlineCallbacks
         def updater(progress, batch_size):
             logger.info("Adding index %s to %s", index_name, table)
-            if conc:
-                yield self.runWithConnection(create_index_concurrently)
-            else:
-                yield self.runWithConnection(create_index)
+            yield self.runWithConnection(runner)
             yield self._end_background_update(update_name)
             defer.returnValue(1)
 
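The Postgres branch above has to leave transaction mode first, because
CREATE INDEX CONCURRENTLY refuses to run inside a transaction block. A
minimal standalone sketch of the same pattern, assuming psycopg2 as the
driver and hypothetical index/table names:

    import psycopg2

    conn = psycopg2.connect("dbname=synapse")  # hypothetical DSN
    conn.rollback()                    # discard any transaction in progress
    conn.set_session(autocommit=True)  # CONCURRENTLY cannot run in a txn
    try:
        c = conn.cursor()
        # drop any half-built index left by an interrupted attempt, then
        # rebuild it without taking a write lock on the table
        c.execute("DROP INDEX IF EXISTS my_idx")
        c.execute("CREATE INDEX CONCURRENTLY my_idx ON my_table (my_col)")
    finally:
        conn.set_session(autocommit=False)

The SQLite branch settles for a plain CREATE INDEX IF NOT EXISTS: SQLite
has no CONCURRENTLY, and the 3.7 builds shipped with wheezy predate
partial-index support.
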
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 64fe937bdc..a3790419dd 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -2159,6 +2159,28 @@ class EventsStore(SQLBaseStore):
             ]
         )
 
+    @defer.inlineCallbacks
+    def is_event_after(self, event_id1, event_id2):
+        """Returns True if event_id1 is after event_id2 in the stream
+        """
+        to_1, so_1 = yield self._get_event_ordering(event_id1)
+        to_2, so_2 = yield self._get_event_ordering(event_id2)
+        defer.returnValue((to_1, so_1) > (to_2, so_2))
+
+    @defer.inlineCallbacks
+    def _get_event_ordering(self, event_id):
+        res = yield self._simple_select_one(
+            table="events",
+            retcols=["topological_ordering", "stream_ordering"],
+            keyvalues={"event_id": event_id},
+            allow_none=True
+        )
+
+        if not res:
+            raise SynapseError(404, "Could not find event %s" % (event_id,))
+
+        defer.returnValue((int(res["topological_ordering"]), int(res["stream_ordering"])))
+
 
 AllNewEventsResult = namedtuple("AllNewEventsResult", [
     "new_forward_events", "new_backfill_events",
diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py
index 6b0f8c2787..efb90c3c91 100644
--- a/synapse/storage/receipts.py
+++ b/synapse/storage/receipts.py
@@ -47,10 +47,13 @@ class ReceiptsStore(SQLBaseStore):
         # Returns an ObservableDeferred
         res = self.get_users_with_read_receipts_in_room.cache.get((room_id,), None)
 
-        if res and res.called and user_id in res.result:
-            # We'd only be adding to the set, so no point invalidating if the
-            # user is already there
-            return
+        if res:
+            if isinstance(res, defer.Deferred) and res.called:
+                res = res.result
+            if user_id in res:
+                # We'd only be adding to the set, so no point invalidating if the
+                # user is already there
+                return
 
         self.get_users_with_read_receipts_in_room.invalidate((room_id,))
 
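The rewritten check copes with the cache handing back either a plain
result or a Deferred that has already fired. A self-contained sketch of
the unwrapping, using only Twisted's defer module (the helper name is
ours, not Synapse's):

    from twisted.internet import defer

    def already_contains(res, user_id):
        # unwrap a fired Deferred into its result; an unfired one cannot
        # be inspected yet, so conservatively treat it as a miss
        if isinstance(res, defer.Deferred):
            if not res.called:
                return False
            res = res.result
        return res is not None and user_id in res
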
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 367dbbbcf6..7ad2198d96 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -19,6 +19,7 @@ from collections import namedtuple
 
 from ._base import SQLBaseStore
 from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
+from synapse.util.stringutils import to_ascii
 
 from synapse.api.constants import Membership, EventTypes
 from synapse.types import get_domain_from_id
@@ -35,6 +36,13 @@ RoomsForUser = namedtuple(
 )
 
 
+# We store this as a namedtuple, which saves about 3x the space compared
+# to using a dict.
+ProfileInfo = namedtuple(
+    "ProfileInfo", ("avatar_url", "display_name")
+)
+
+
 _MEMBERSHIP_PROFILE_UPDATE_NAME = "room_membership_profile_update"
 
 
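The "about 3x space" claim is plausible on CPython 2.7: a namedtuple
instance keeps its fields in a fixed-size tuple, while a dict drags a
hash table along. A rough way to see the per-container difference (exact
sizes vary by interpreter build):

    import sys
    from collections import namedtuple

    ProfileInfo = namedtuple("ProfileInfo", ("avatar_url", "display_name"))

    as_tuple = ProfileInfo(avatar_url=None, display_name=u"alice")
    as_dict = {"avatar_url": None, "display_name": u"alice"}

    # getsizeof measures the container only, not the values it references
    print(sys.getsizeof(as_tuple))  # e.g. 72 on 64-bit CPython 2.7
    print(sys.getsizeof(as_dict))   # e.g. 280 on the same build
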
@@ -422,20 +430,20 @@ class RoomMemberStore(SQLBaseStore):
         )
 
         users_in_room = {
-            row["user_id"]: {
-                "display_name": row["display_name"],
-                "avatar_url": row["avatar_url"],
-            }
+            to_ascii(row["user_id"]): ProfileInfo(
+                avatar_url=to_ascii(row["avatar_url"]),
+                display_name=to_ascii(row["display_name"]),
+            )
             for row in rows
         }
 
         if event is not None and event.type == EventTypes.Member:
             if event.membership == Membership.JOIN:
                 if event.event_id in member_event_ids:
-                    users_in_room[event.state_key] = {
-                        "display_name": event.content.get("displayname", None),
-                        "avatar_url": event.content.get("avatar_url", None),
-                    }
+                    users_in_room[to_ascii(event.state_key)] = ProfileInfo(
+                        display_name=to_ascii(event.content.get("displayname", None)),
+                        avatar_url=to_ascii(event.content.get("avatar_url", None)),
+                    )
 
         defer.returnValue(users_in_room)
 
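to_ascii itself is not part of this diff; from its call sites it is
assumed to down-convert pure-ASCII unicode strings to byte strings
(which are much cheaper on Python 2) and pass None and non-ASCII values
through untouched. A hypothetical stand-in:

    def to_ascii(s):
        # hypothetical sketch, not the actual synapse.util.stringutils code
        if s is None:
            return None
        try:
            return s.encode("ascii")
        except UnicodeEncodeError:
            return s
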
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index acd69944c4..a16afa8df5 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -16,6 +16,7 @@
 from ._base import SQLBaseStore
 from synapse.util.caches.descriptors import cached, cachedList
 from synapse.util.caches import intern_string
+from synapse.util.stringutils import to_ascii
 from synapse.storage.engines import PostgresEngine
 
 from twisted.internet import defer
@@ -89,7 +90,7 @@ class StateStore(SQLBaseStore):
             )
 
             return {
-                (r[0], r[1]): r[2] for r in txn
+                (intern_string(r[0]), intern_string(r[1])): to_ascii(r[2]) for r in txn
             }
 
         return self.runInteraction(
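
Interning pays off here because the same (type, state_key) strings recur
across many rows; collapsing each recurrence onto one shared object makes
the per-row cost a pointer instead of a fresh string. The same idea with
Python 2's built-in intern() standing in for Synapse's intern_string:

    a = intern("m.room.member")
    b = intern("".join(["m.room.", "member"]))  # equal, separately built
    assert a is b                  # interning yields one shared object
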
@@ -279,12 +280,7 @@ class StateStore(SQLBaseStore):
 
             return count
 
-    @cached(num_args=2, max_entries=100000, iterable=True)
-    def _get_state_group_from_group(self, group, types):
-        raise NotImplementedError()
-
-    @cachedList(cached_method_name="_get_state_group_from_group",
-                list_name="groups", num_args=2, inlineCallbacks=True)
+    @defer.inlineCallbacks
     def _get_state_groups_from_groups(self, groups, types):
         """Returns dictionary state_group -> (dict of (type, state_key) -> event id)
         """
@@ -512,7 +508,7 @@ class StateStore(SQLBaseStore):
         state_map = yield self.get_state_ids_for_events([event_id], types)
         defer.returnValue(state_map[event_id])
 
-    @cached(num_args=2, max_entries=100000)
+    @cached(num_args=2, max_entries=50000)
     def _get_state_group_for_event(self, room_id, event_id):
         return self._simple_select_one_onecol(
             table="event_to_state_groups",
@@ -660,7 +656,7 @@ class StateStore(SQLBaseStore):
                     state_dict = results[group]
 
                 state_dict.update(
-                    ((intern_string(k[0]), intern_string(k[1])), v)
+                    ((intern_string(k[0]), intern_string(k[1])), to_ascii(v))
                     for k, v in group_state_dict.iteritems()
                 )