summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/__init__.py4
-rw-r--r--synapse/storage/_base.py45
-rw-r--r--synapse/storage/events.py27
-rw-r--r--synapse/storage/schema/delta/14/upgrade_appservice_db.py2
-rw-r--r--synapse/storage/schema/delta/20/pushers.py76
-rw-r--r--synapse/storage/state.py20
6 files changed, 138 insertions, 36 deletions
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 75af44d787..c137f47820 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -51,7 +51,7 @@ logger = logging.getLogger(__name__)
 
 # Remember to update this number every time a change is made to database
 # schema files, so the users will be informed on server restarts.
-SCHEMA_VERSION = 19
+SCHEMA_VERSION = 20
 
 dir_path = os.path.abspath(os.path.dirname(__file__))
 
@@ -348,7 +348,7 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files,
                         module_name, absolute_path, python_file
                     )
                 logger.debug("Running script %s", relative_path)
-                module.run_upgrade(cur)
+                module.run_upgrade(cur, database_engine)
             elif ext == ".sql":
                 # A plain old .sql file, just read and execute it
                 logger.debug("Applying schema %s", relative_path)
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 39884c2afe..8d33def6c6 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -127,7 +127,7 @@ class Cache(object):
         self.cache.clear()
 
 
-def cached(max_entries=1000, num_args=1, lru=False):
+class CacheDescriptor(object):
     """ A method decorator that applies a memoizing cache around the function.
 
     The function is presumed to take zero or more arguments, which are used in
@@ -141,25 +141,32 @@ def cached(max_entries=1000, num_args=1, lru=False):
     which can be used to insert values into the cache specifically, without
     calling the calculation function.
     """
-    def wrap(orig):
+    def __init__(self, orig, max_entries=1000, num_args=1, lru=False):
+        self.orig = orig
+
+        self.max_entries = max_entries
+        self.num_args = num_args
+        self.lru = lru
+
+    def __get__(self, obj, objtype=None):
         cache = Cache(
-            name=orig.__name__,
-            max_entries=max_entries,
-            keylen=num_args,
-            lru=lru,
+            name=self.orig.__name__,
+            max_entries=self.max_entries,
+            keylen=self.num_args,
+            lru=self.lru,
         )
 
-        @functools.wraps(orig)
+        @functools.wraps(self.orig)
         @defer.inlineCallbacks
-        def wrapped(self, *keyargs):
+        def wrapped(*keyargs):
             try:
-                cached_result = cache.get(*keyargs)
+                cached_result = cache.get(*keyargs[:self.num_args])
                 if DEBUG_CACHES:
-                    actual_result = yield orig(self, *keyargs)
+                    actual_result = yield self.orig(obj, *keyargs)
                     if actual_result != cached_result:
                         logger.error(
                             "Stale cache entry %s%r: cached: %r, actual %r",
-                            orig.__name__, keyargs,
+                            self.orig.__name__, keyargs,
                             cached_result, actual_result,
                         )
                         raise ValueError("Stale cache entry")
@@ -170,18 +177,28 @@ def cached(max_entries=1000, num_args=1, lru=False):
                 # while the SELECT is executing (SYN-369)
                 sequence = cache.sequence
 
-                ret = yield orig(self, *keyargs)
+                ret = yield self.orig(obj, *keyargs)
 
-                cache.update(sequence, *keyargs + (ret,))
+                cache.update(sequence, *keyargs[:self.num_args] + (ret,))
 
                 defer.returnValue(ret)
 
         wrapped.invalidate = cache.invalidate
         wrapped.invalidate_all = cache.invalidate_all
         wrapped.prefill = cache.prefill
+
+        obj.__dict__[self.orig.__name__] = wrapped
+
         return wrapped
 
-    return wrap
+
+def cached(max_entries=1000, num_args=1, lru=False):
+    return lambda orig: CacheDescriptor(
+        orig,
+        max_entries=max_entries,
+        num_args=num_args,
+        lru=lru
+    )
 
 
 class LoggingTransaction(object):
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index d2a010bd88..2caf0aae80 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -17,7 +17,7 @@ from _base import SQLBaseStore, _RollbackButIsFineException
 
 from twisted.internet import defer, reactor
 
-from synapse.events import FrozenEvent
+from synapse.events import FrozenEvent, USE_FROZEN_DICTS
 from synapse.events.utils import prune_event
 
 from synapse.util.logcontext import preserve_context_over_deferred
@@ -26,11 +26,11 @@ from synapse.api.constants import EventTypes
 from synapse.crypto.event_signing import compute_event_reference_hash
 
 from syutil.base64util import decode_base64
-from syutil.jsonutil import encode_canonical_json
+from syutil.jsonutil import encode_json
 from contextlib import contextmanager
 
 import logging
-import simplejson as json
+import ujson as json
 
 logger = logging.getLogger(__name__)
 
@@ -166,8 +166,9 @@ class EventsStore(SQLBaseStore):
             allow_none=True,
         )
 
-        metadata_json = encode_canonical_json(
-            event.internal_metadata.get_dict()
+        metadata_json = encode_json(
+            event.internal_metadata.get_dict(),
+            using_frozen_dicts=USE_FROZEN_DICTS
         ).decode("UTF-8")
 
         # If we have already persisted this event, we don't need to do any
@@ -235,12 +236,14 @@ class EventsStore(SQLBaseStore):
                 "event_id": event.event_id,
                 "room_id": event.room_id,
                 "internal_metadata": metadata_json,
-                "json": encode_canonical_json(event_dict).decode("UTF-8"),
+                "json": encode_json(
+                    event_dict, using_frozen_dicts=USE_FROZEN_DICTS
+                ).decode("UTF-8"),
             },
         )
 
-        content = encode_canonical_json(
-            event.content
+        content = encode_json(
+            event.content, using_frozen_dicts=USE_FROZEN_DICTS
         ).decode("UTF-8")
 
         vals = {
@@ -266,8 +269,8 @@ class EventsStore(SQLBaseStore):
             ]
         }
 
-        vals["unrecognized_keys"] = encode_canonical_json(
-            unrec
+        vals["unrecognized_keys"] = encode_json(
+            unrec, using_frozen_dicts=USE_FROZEN_DICTS
         ).decode("UTF-8")
 
         sql = (
@@ -733,7 +736,8 @@ class EventsStore(SQLBaseStore):
 
             because = yield self.get_event(
                 redaction_id,
-                check_redacted=False
+                check_redacted=False,
+                allow_none=True,
             )
 
             if because:
@@ -743,6 +747,7 @@ class EventsStore(SQLBaseStore):
             prev = yield self.get_event(
                 ev.unsigned["replaces_state"],
                 get_prev_content=False,
+                allow_none=True,
             )
             if prev:
                 ev.unsigned["prev_content"] = prev.get_dict()["content"]
diff --git a/synapse/storage/schema/delta/14/upgrade_appservice_db.py b/synapse/storage/schema/delta/14/upgrade_appservice_db.py
index 9f3a4dd4c5..61232f9757 100644
--- a/synapse/storage/schema/delta/14/upgrade_appservice_db.py
+++ b/synapse/storage/schema/delta/14/upgrade_appservice_db.py
@@ -18,7 +18,7 @@ import logging
 logger = logging.getLogger(__name__)
 
 
-def run_upgrade(cur):
+def run_upgrade(cur, *args, **kwargs):
     cur.execute("SELECT id, regex FROM application_services_regex")
     for row in cur.fetchall():
         try:
diff --git a/synapse/storage/schema/delta/20/pushers.py b/synapse/storage/schema/delta/20/pushers.py
new file mode 100644
index 0000000000..543e57bbe2
--- /dev/null
+++ b/synapse/storage/schema/delta/20/pushers.py
@@ -0,0 +1,76 @@
+# Copyright 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+Main purpose of this upgrade is to change the unique key on the
+pushers table again (it was missed when the v16 full schema was
+made) but this also changes the pushkey and data columns to text.
+When selecting a bytea column into a text column, postgres inserts
+the hex encoded data, and there's no portable way of getting the
+UTF-8 bytes, so we have to do it in Python.
+"""
+
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+def run_upgrade(cur, database_engine, *args, **kwargs):
+    logger.info("Porting pushers table...")
+    cur.execute("""
+        CREATE TABLE IF NOT EXISTS pushers2 (
+          id BIGINT PRIMARY KEY,
+          user_name TEXT NOT NULL,
+          access_token BIGINT DEFAULT NULL,
+          profile_tag VARCHAR(32) NOT NULL,
+          kind VARCHAR(8) NOT NULL,
+          app_id VARCHAR(64) NOT NULL,
+          app_display_name VARCHAR(64) NOT NULL,
+          device_display_name VARCHAR(128) NOT NULL,
+          pushkey TEXT NOT NULL,
+          ts BIGINT NOT NULL,
+          lang VARCHAR(8),
+          data TEXT,
+          last_token TEXT,
+          last_success BIGINT,
+          failing_since BIGINT,
+          UNIQUE (app_id, pushkey, user_name)
+        )
+    """)
+    cur.execute("""SELECT
+        id, user_name, access_token, profile_tag, kind,
+        app_id, app_display_name, device_display_name,
+        pushkey, ts, lang, data, last_token, last_success,
+        failing_since
+        FROM pushers
+    """)
+    count = 0
+    for row in cur.fetchall():
+        row = list(row)
+        row[8] = bytes(row[8]).decode("utf-8")
+        row[11] = bytes(row[11]).decode("utf-8")
+        cur.execute(database_engine.convert_param_style("""
+            INSERT into pushers2 (
+            id, user_name, access_token, profile_tag, kind,
+            app_id, app_display_name, device_display_name,
+            pushkey, ts, lang, data, last_token, last_success,
+            failing_since
+            ) values (%s)""" % (','.join(['?' for _ in range(len(row))]))),
+            row
+        )
+        count += 1
+    cur.execute("DROP TABLE pushers")
+    cur.execute("ALTER TABLE pushers2 RENAME TO pushers")
+    logger.info("Moved %d pushers to new table", count)
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index b24de34f23..f2b17f29ea 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -81,19 +81,23 @@ class StateStore(SQLBaseStore):
             f,
         )
 
-        @defer.inlineCallbacks
-        def c(vals):
-            vals[:] = yield self._get_events(vals, get_prev_content=False)
-
-        yield defer.gatherResults(
+        state_list = yield defer.gatherResults(
             [
-                c(vals)
-                for vals in states.values()
+                self._fetch_events_for_group(group, vals)
+                for group, vals in states.items()
             ],
             consumeErrors=True,
         )
 
-        defer.returnValue(states)
+        defer.returnValue(dict(state_list))
+
+    @cached(num_args=1)
+    def _fetch_events_for_group(self, state_group, events):
+        return self._get_events(
+            events, get_prev_content=False
+        ).addCallback(
+            lambda evs: (state_group, evs)
+        )
 
     def _store_state_groups_txn(self, txn, event, context):
         if context.current_state is None: