summary refs log tree commit diff
path: root/synapse/storage/_base.py
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2015-05-07 19:07:00 +0100
committerErik Johnston <erik@matrix.org>2015-05-07 19:07:00 +0100
commit89c0cd4accbf6d809cc9d3fdce4df4d8e4f39d35 (patch)
tree019dd15780bbd432e099c748fecd2a16b645b470 /synapse/storage/_base.py
parentMerge pull request #124 from matrix-org/hotfixes-v0.8.1-r4 (diff)
parentSlight rewording (diff)
downloadsynapse-89c0cd4accbf6d809cc9d3fdce4df4d8e4f39d35.tar.xz
Merge branch 'release-v0.9.0' of github.com:matrix-org/synapse v0.9.0
Diffstat (limited to 'synapse/storage/_base.py')
-rw-r--r--synapse/storage/_base.py543
1 files changed, 392 insertions, 151 deletions
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 9125bb1198..ee5587c721 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -22,6 +22,8 @@ from synapse.util.logcontext import PreserveLoggingContext, LoggingContext
 from synapse.util.lrucache import LruCache
 import synapse.metrics
 
+from util.id_generators import IdGenerator, StreamIdGenerator
+
 from twisted.internet import defer
 
 from collections import namedtuple, OrderedDict
@@ -29,12 +31,15 @@ import functools
 import simplejson as json
 import sys
 import time
+import threading
 
+DEBUG_CACHES = False
 
 logger = logging.getLogger(__name__)
 
 sql_logger = logging.getLogger("synapse.storage.SQL")
 transaction_logger = logging.getLogger("synapse.storage.txn")
+perf_logger = logging.getLogger("synapse.storage.TIME")
 
 
 metrics = synapse.metrics.get_metrics_for("synapse.storage")
@@ -53,14 +58,78 @@ cache_counter = metrics.register_cache(
 )
 
 
-# TODO(paul):
-#  * more generic key management
-#  * consider other eviction strategies - LRU?
-def cached(max_entries=1000):
+class Cache(object):
+
+    def __init__(self, name, max_entries=1000, keylen=1, lru=False):
+        if lru:
+            self.cache = LruCache(max_size=max_entries)
+            self.max_entries = None
+        else:
+            self.cache = OrderedDict()
+            self.max_entries = max_entries
+
+        self.name = name
+        self.keylen = keylen
+        self.sequence = 0
+        self.thread = None
+        caches_by_name[name] = self.cache
+
+    def check_thread(self):
+        expected_thread = self.thread
+        if expected_thread is None:
+            self.thread = threading.current_thread()
+        else:
+            if expected_thread is not threading.current_thread():
+                raise ValueError(
+                    "Cache objects can only be accessed from the main thread"
+                )
+
+    def get(self, *keyargs):
+        if len(keyargs) != self.keylen:
+            raise ValueError("Expected a key to have %d items", self.keylen)
+
+        if keyargs in self.cache:
+            cache_counter.inc_hits(self.name)
+            return self.cache[keyargs]
+
+        cache_counter.inc_misses(self.name)
+        raise KeyError()
+
+    def update(self, sequence, *args):
+        self.check_thread()
+        if self.sequence == sequence:
+            # Only update the cache if the caches sequence number matches the
+            # number that the cache had before the SELECT was started (SYN-369)
+            self.prefill(*args)
+
+    def prefill(self, *args):  # because I can't  *keyargs, value
+        keyargs = args[:-1]
+        value = args[-1]
+
+        if len(keyargs) != self.keylen:
+            raise ValueError("Expected a key to have %d items", self.keylen)
+
+        if self.max_entries is not None:
+            while len(self.cache) >= self.max_entries:
+                self.cache.popitem(last=False)
+
+        self.cache[keyargs] = value
+
+    def invalidate(self, *keyargs):
+        self.check_thread()
+        if len(keyargs) != self.keylen:
+            raise ValueError("Expected a key to have %d items", self.keylen)
+        # Increment the sequence number so that any SELECT statements that
+        # raced with the INSERT don't update the cache (SYN-369)
+        self.sequence += 1
+        self.cache.pop(keyargs, None)
+
+
+def cached(max_entries=1000, num_args=1, lru=False):
     """ A method decorator that applies a memoizing cache around the function.
 
-    The function is presumed to take one additional argument, which is used as
-    the key for the cache. Cache hits are served directly from the cache;
+    The function is presumed to take zero or more arguments, which are used in
+    a tuple as the key for the cache. Hits are served directly from the cache;
     misses use the function body to generate the value.
 
     The wrapped function has an additional member, a callable called
@@ -71,34 +140,42 @@ def cached(max_entries=1000):
     calling the calculation function.
     """
     def wrap(orig):
-        cache = OrderedDict()
-        name = orig.__name__
-
-        caches_by_name[name] = cache
-
-        def prefill(key, value):
-            while len(cache) > max_entries:
-                cache.popitem(last=False)
-
-            cache[key] = value
+        cache = Cache(
+            name=orig.__name__,
+            max_entries=max_entries,
+            keylen=num_args,
+            lru=lru,
+        )
 
         @functools.wraps(orig)
         @defer.inlineCallbacks
-        def wrapped(self, key):
-            if key in cache:
-                cache_counter.inc_hits(name)
-                defer.returnValue(cache[key])
-
-            cache_counter.inc_misses(name)
-            ret = yield orig(self, key)
-            prefill(key, ret)
-            defer.returnValue(ret)
-
-        def invalidate(key):
-            cache.pop(key, None)
-
-        wrapped.invalidate = invalidate
-        wrapped.prefill = prefill
+        def wrapped(self, *keyargs):
+            try:
+                cached_result = cache.get(*keyargs)
+                if DEBUG_CACHES:
+                    actual_result = yield orig(self, *keyargs)
+                    if actual_result != cached_result:
+                        logger.error(
+                            "Stale cache entry %s%r: cached: %r, actual %r",
+                            orig.__name__, keyargs,
+                            cached_result, actual_result,
+                        )
+                        raise ValueError("Stale cache entry")
+                defer.returnValue(cached_result)
+            except KeyError:
+                # Get the sequence number of the cache before reading from the
+                # database so that we can tell if the cache is invalidated
+                # while the SELECT is executing (SYN-369)
+                sequence = cache.sequence
+
+                ret = yield orig(self, *keyargs)
+
+                cache.update(sequence, *keyargs + (ret,))
+
+                defer.returnValue(ret)
+
+        wrapped.invalidate = cache.invalidate
+        wrapped.prefill = cache.prefill
         return wrapped
 
     return wrap
@@ -108,11 +185,20 @@ class LoggingTransaction(object):
     """An object that almost-transparently proxies for the 'txn' object
     passed to the constructor. Adds logging and metrics to the .execute()
     method."""
-    __slots__ = ["txn", "name"]
+    __slots__ = ["txn", "name", "database_engine", "after_callbacks"]
 
-    def __init__(self, txn, name):
+    def __init__(self, txn, name, database_engine, after_callbacks):
         object.__setattr__(self, "txn", txn)
         object.__setattr__(self, "name", name)
+        object.__setattr__(self, "database_engine", database_engine)
+        object.__setattr__(self, "after_callbacks", after_callbacks)
+
+    def call_after(self, callback, *args):
+        """Call the given callback on the main twisted thread after the
+        transaction has finished. Used to invalidate the caches on the
+        correct thread.
+        """
+        self.after_callbacks.append((callback, args))
 
     def __getattr__(self, name):
         return getattr(self.txn, name)
@@ -120,30 +206,37 @@ class LoggingTransaction(object):
     def __setattr__(self, name, value):
         setattr(self.txn, name, value)
 
-    def execute(self, sql, *args, **kwargs):
+    def execute(self, sql, *args):
+        self._do_execute(self.txn.execute, sql, *args)
+
+    def executemany(self, sql, *args):
+        self._do_execute(self.txn.executemany, sql, *args)
+
+    def _do_execute(self, func, sql, *args):
         # TODO(paul): Maybe use 'info' and 'debug' for values?
         sql_logger.debug("[SQL] {%s} %s", self.name, sql)
 
-        try:
-            if args and args[0]:
-                values = args[0]
+        sql = self.database_engine.convert_param_style(sql)
+
+        if args:
+            try:
                 sql_logger.debug(
-                    "[SQL values] {%s} " + ", ".join(("<%r>",) * len(values)),
-                    self.name,
-                    *values
+                    "[SQL values] {%s} %r",
+                    self.name, args[0]
                 )
-        except:
-            # Don't let logging failures stop SQL from working
-            pass
+            except:
+                # Don't let logging failures stop SQL from working
+                pass
 
         start = time.time() * 1000
+
         try:
-            return self.txn.execute(
-                sql, *args, **kwargs
+            return func(
+                sql, *args
             )
-        except:
-                logger.exception("[SQL FAIL] {%s}", self.name)
-                raise
+        except Exception as e:
+            logger.debug("[SQL FAIL] {%s} %s", self.name, e)
+            raise
         finally:
             msecs = (time.time() * 1000) - start
             sql_logger.debug("[SQL time] {%s} %f", self.name, msecs)
@@ -205,10 +298,16 @@ class SQLBaseStore(object):
         self._txn_perf_counters = PerformanceCounters()
         self._get_event_counters = PerformanceCounters()
 
-        self._get_event_cache = LruCache(hs.config.event_cache_size)
+        self._get_event_cache = Cache("*getEvent*", keylen=3, lru=True,
+                                      max_entries=hs.config.event_cache_size)
+
+        self.database_engine = hs.database_engine
 
-        # Pretend the getEventCache is just another named cache
-        caches_by_name["*getEvent*"] = self._get_event_cache
+        self._stream_id_gen = StreamIdGenerator()
+        self._transaction_id_gen = IdGenerator("sent_transactions", "id", self)
+        self._state_groups_id_gen = IdGenerator("state_groups", "id", self)
+        self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self)
+        self._pushers_id_gen = IdGenerator("pushers", "id", self)
 
     def start_profiling(self):
         self._previous_loop_ts = self._clock.time_msec()
@@ -232,7 +331,7 @@ class SQLBaseStore(object):
                 time_now - time_then, limit=3
             )
 
-            logger.info(
+            perf_logger.info(
                 "Total database time: %.3f%% {%s} {%s}",
                 ratio * 100, top_three_counters, top_3_event_counters
             )
@@ -246,8 +345,14 @@ class SQLBaseStore(object):
 
         start_time = time.time() * 1000
 
-        def inner_func(txn, *args, **kwargs):
+        after_callbacks = []
+
+        def inner_func(conn, *args, **kwargs):
             with LoggingContext("runInteraction") as context:
+                if self.database_engine.is_connection_closed(conn):
+                    logger.debug("Reconnecting closed database connection")
+                    conn.reconnect()
+
                 current_context.copy_to(context)
                 start = time.time() * 1000
                 txn_id = self._TXN_ID
@@ -261,9 +366,48 @@ class SQLBaseStore(object):
                 sql_scheduling_timer.inc_by(time.time() * 1000 - start_time)
                 transaction_logger.debug("[TXN START] {%s}", name)
                 try:
-                    return func(LoggingTransaction(txn, name), *args, **kwargs)
-                except:
-                    logger.exception("[TXN FAIL] {%s}", name)
+                    i = 0
+                    N = 5
+                    while True:
+                        try:
+                            txn = conn.cursor()
+                            txn = LoggingTransaction(
+                                txn, name, self.database_engine, after_callbacks
+                            )
+                            return func(txn, *args, **kwargs)
+                        except self.database_engine.module.OperationalError as e:
+                            # This can happen if the database disappears mid
+                            # transaction.
+                            logger.warn(
+                                "[TXN OPERROR] {%s} %s %d/%d",
+                                name, e, i, N
+                            )
+                            if i < N:
+                                i += 1
+                                try:
+                                    conn.rollback()
+                                except self.database_engine.module.Error as e1:
+                                    logger.warn(
+                                        "[TXN EROLL] {%s} %s",
+                                        name, e1,
+                                    )
+                                continue
+                        except self.database_engine.module.DatabaseError as e:
+                            if self.database_engine.is_deadlock(e):
+                                logger.warn("[TXN DEADLOCK] {%s} %d/%d", name, i, N)
+                                if i < N:
+                                    i += 1
+                                    try:
+                                        conn.rollback()
+                                    except self.database_engine.module.Error as e1:
+                                        logger.warn(
+                                            "[TXN EROLL] {%s} %s",
+                                            name, e1,
+                                        )
+                                    continue
+                            raise
+                except Exception as e:
+                    logger.debug("[TXN FAIL] {%s} %s", name, e)
                     raise
                 finally:
                     end = time.time() * 1000
@@ -276,9 +420,11 @@ class SQLBaseStore(object):
                     sql_txn_timer.inc_by(duration, desc)
 
         with PreserveLoggingContext():
-            result = yield self._db_pool.runInteraction(
+            result = yield self._db_pool.runWithConnection(
                 inner_func, *args, **kwargs
             )
+        for after_callback, after_args in after_callbacks:
+            after_callback(*after_args)
         defer.returnValue(result)
 
     def cursor_to_dict(self, cursor):
@@ -307,11 +453,11 @@ class SQLBaseStore(object):
             The result of decoder(results)
         """
         def interaction(txn):
-            cursor = txn.execute(query, args)
+            txn.execute(query, args)
             if decoder:
-                return decoder(cursor)
+                return decoder(txn)
             else:
-                return cursor.fetchall()
+                return txn.fetchall()
 
         return self.runInteraction(desc, interaction)
 
@@ -321,53 +467,94 @@ class SQLBaseStore(object):
     # "Simple" SQL API methods that operate on a single table with no JOINs,
     # no complex WHERE clauses, just a dict of values for columns.
 
-    def _simple_insert(self, table, values, or_replace=False, or_ignore=False):
+    @defer.inlineCallbacks
+    def _simple_insert(self, table, values, or_ignore=False,
+                       desc="_simple_insert"):
         """Executes an INSERT query on the named table.
 
         Args:
             table : string giving the table name
             values : dict of new column names and values for them
-            or_replace : bool; if True performs an INSERT OR REPLACE
         """
-        return self.runInteraction(
-            "_simple_insert",
-            self._simple_insert_txn, table, values, or_replace=or_replace,
-            or_ignore=or_ignore,
-        )
+        try:
+            yield self.runInteraction(
+                desc,
+                self._simple_insert_txn, table, values,
+            )
+        except self.database_engine.module.IntegrityError:
+            # We have to do or_ignore flag at this layer, since we can't reuse
+            # a cursor after we receive an error from the db.
+            if not or_ignore:
+                raise
 
     @log_function
-    def _simple_insert_txn(self, txn, table, values, or_replace=False,
-                           or_ignore=False):
-        sql = "%s INTO %s (%s) VALUES(%s)" % (
-            ("INSERT OR REPLACE" if or_replace else
-             "INSERT OR IGNORE" if or_ignore else "INSERT"),
+    def _simple_insert_txn(self, txn, table, values):
+        keys, vals = zip(*values.items())
+
+        sql = "INSERT INTO %s (%s) VALUES(%s)" % (
             table,
-            ", ".join(k for k in values),
-            ", ".join("?" for k in values)
+            ", ".join(k for k in keys),
+            ", ".join("?" for _ in keys)
         )
 
-        logger.debug(
-            "[SQL] %s Args=%s",
-            sql, values.values(),
+        txn.execute(sql, vals)
+
+    def _simple_insert_many_txn(self, txn, table, values):
+        if not values:
+            return
+
+        # This is a *slight* abomination to get a list of tuples of key names
+        # and a list of tuples of value names.
+        #
+        # i.e. [{"a": 1, "b": 2}, {"c": 3, "d": 4}]
+        #         => [("a", "b",), ("c", "d",)] and [(1, 2,), (3, 4,)]
+        #
+        # The sort is to ensure that we don't rely on dictionary iteration
+        # order.
+        keys, vals = zip(*[
+            zip(
+                *(sorted(i.items(), key=lambda kv: kv[0]))
+            )
+            for i in values
+            if i
+        ])
+
+        for k in keys:
+            if k != keys[0]:
+                raise RuntimeError(
+                    "All items must have the same keys"
+                )
+
+        sql = "INSERT INTO %s (%s) VALUES(%s)" % (
+            table,
+            ", ".join(k for k in keys[0]),
+            ", ".join("?" for _ in keys[0])
         )
 
-        txn.execute(sql, values.values())
-        return txn.lastrowid
+        txn.executemany(sql, vals)
 
-    def _simple_upsert(self, table, keyvalues, values):
+    def _simple_upsert(self, table, keyvalues, values,
+                       insertion_values={}, desc="_simple_upsert", lock=True):
         """
         Args:
             table (str): The table to upsert into
             keyvalues (dict): The unique key tables and their new values
             values (dict): The nonunique columns and their new values
+            insertion_values (dict): key/values to use when inserting
         Returns: A deferred
         """
         return self.runInteraction(
-            "_simple_upsert",
-            self._simple_upsert_txn, table, keyvalues, values
+            desc,
+            self._simple_upsert_txn, table, keyvalues, values, insertion_values,
+            lock
         )
 
-    def _simple_upsert_txn(self, txn, table, keyvalues, values):
+    def _simple_upsert_txn(self, txn, table, keyvalues, values, insertion_values={},
+                           lock=True):
+        # We need to lock the table :(, unless we're *really* careful
+        if lock:
+            self.database_engine.lock_table(txn, table)
+
         # Try to update
         sql = "UPDATE %s SET %s WHERE %s" % (
             table,
@@ -386,6 +573,7 @@ class SQLBaseStore(object):
             allvalues = {}
             allvalues.update(keyvalues)
             allvalues.update(values)
+            allvalues.update(insertion_values)
 
             sql = "INSERT INTO %s (%s) VALUES (%s)" % (
                 table,
@@ -399,7 +587,7 @@ class SQLBaseStore(object):
             txn.execute(sql, allvalues.values())
 
     def _simple_select_one(self, table, keyvalues, retcols,
-                           allow_none=False):
+                           allow_none=False, desc="_simple_select_one"):
         """Executes a SELECT query on the named table, which is expected to
         return a single row, returning a single column from it.
 
@@ -411,12 +599,15 @@ class SQLBaseStore(object):
             allow_none : If true, return None instead of failing if the SELECT
               statement returns no rows
         """
-        return self._simple_selectupdate_one(
-            table, keyvalues, retcols=retcols, allow_none=allow_none
+        return self.runInteraction(
+            desc,
+            self._simple_select_one_txn,
+            table, keyvalues, retcols, allow_none,
         )
 
     def _simple_select_one_onecol(self, table, keyvalues, retcol,
-                                  allow_none=False):
+                                  allow_none=False,
+                                  desc="_simple_select_one_onecol"):
         """Executes a SELECT query on the named table, which is expected to
         return a single row, returning a single column from it."
 
@@ -426,7 +617,7 @@ class SQLBaseStore(object):
             retcol : string giving the name of the column to return
         """
         return self.runInteraction(
-            "_simple_select_one_onecol",
+            desc,
             self._simple_select_one_onecol_txn,
             table, keyvalues, retcol, allow_none=allow_none,
         )
@@ -450,8 +641,7 @@ class SQLBaseStore(object):
 
     def _simple_select_onecol_txn(self, txn, table, keyvalues, retcol):
         sql = (
-            "SELECT %(retcol)s FROM %(table)s WHERE %(where)s "
-            "ORDER BY rowid asc"
+            "SELECT %(retcol)s FROM %(table)s WHERE %(where)s"
         ) % {
             "retcol": retcol,
             "table": table,
@@ -462,7 +652,8 @@ class SQLBaseStore(object):
 
         return [r[0] for r in txn.fetchall()]
 
-    def _simple_select_onecol(self, table, keyvalues, retcol):
+    def _simple_select_onecol(self, table, keyvalues, retcol,
+                              desc="_simple_select_onecol"):
         """Executes a SELECT query on the named table, which returns a list
         comprising of the values of the named column from the selected rows.
 
@@ -475,12 +666,13 @@ class SQLBaseStore(object):
             Deferred: Results in a list
         """
         return self.runInteraction(
-            "_simple_select_onecol",
+            desc,
             self._simple_select_onecol_txn,
             table, keyvalues, retcol
         )
 
-    def _simple_select_list(self, table, keyvalues, retcols):
+    def _simple_select_list(self, table, keyvalues, retcols,
+                            desc="_simple_select_list"):
         """Executes a SELECT query on the named table, which may return zero or
         more rows, returning the result as a list of dicts.
 
@@ -491,7 +683,7 @@ class SQLBaseStore(object):
             retcols : list of strings giving the names of the columns to return
         """
         return self.runInteraction(
-            "_simple_select_list",
+            desc,
             self._simple_select_list_txn,
             table, keyvalues, retcols
         )
@@ -507,14 +699,14 @@ class SQLBaseStore(object):
             retcols : list of strings giving the names of the columns to return
         """
         if keyvalues:
-            sql = "SELECT %s FROM %s WHERE %s ORDER BY rowid asc" % (
+            sql = "SELECT %s FROM %s WHERE %s" % (
                 ", ".join(retcols),
                 table,
                 " AND ".join("%s = ?" % (k, ) for k in keyvalues)
             )
             txn.execute(sql, keyvalues.values())
         else:
-            sql = "SELECT %s FROM %s ORDER BY rowid asc" % (
+            sql = "SELECT %s FROM %s" % (
                 ", ".join(retcols),
                 table
             )
@@ -523,7 +715,7 @@ class SQLBaseStore(object):
         return self.cursor_to_dict(txn)
 
     def _simple_update_one(self, table, keyvalues, updatevalues,
-                           retcols=None):
+                           desc="_simple_update_one"):
         """Executes an UPDATE query on the named table, setting new values for
         columns in a row matching the key values.
 
@@ -541,56 +733,81 @@ class SQLBaseStore(object):
         get-and-set.  This can be used to implement compare-and-set by putting
         the update column in the 'keyvalues' dict as well.
         """
-        return self._simple_selectupdate_one(table, keyvalues, updatevalues,
-                                             retcols=retcols)
+        return self.runInteraction(
+            desc,
+            self._simple_update_one_txn,
+            table, keyvalues, updatevalues,
+        )
 
-    def _simple_selectupdate_one(self, table, keyvalues, updatevalues=None,
-                                 retcols=None, allow_none=False):
-        """ Combined SELECT then UPDATE."""
-        if retcols:
-            select_sql = "SELECT %s FROM %s WHERE %s ORDER BY rowid asc" % (
-                ", ".join(retcols),
-                table,
-                " AND ".join("%s = ?" % (k) for k in keyvalues)
-            )
+    def _simple_update_one_txn(self, txn, table, keyvalues, updatevalues):
+        update_sql = "UPDATE %s SET %s WHERE %s" % (
+            table,
+            ", ".join("%s = ?" % (k,) for k in updatevalues),
+            " AND ".join("%s = ?" % (k,) for k in keyvalues)
+        )
 
-        if updatevalues:
-            update_sql = "UPDATE %s SET %s WHERE %s" % (
-                table,
-                ", ".join("%s = ?" % (k,) for k in updatevalues),
-                " AND ".join("%s = ?" % (k,) for k in keyvalues)
-            )
+        txn.execute(
+            update_sql,
+            updatevalues.values() + keyvalues.values()
+        )
+
+        if txn.rowcount == 0:
+            raise StoreError(404, "No row found")
+        if txn.rowcount > 1:
+            raise StoreError(500, "More than one row matched")
+
+    def _simple_select_one_txn(self, txn, table, keyvalues, retcols,
+                               allow_none=False):
+        select_sql = "SELECT %s FROM %s WHERE %s" % (
+            ", ".join(retcols),
+            table,
+            " AND ".join("%s = ?" % (k,) for k in keyvalues)
+        )
 
+        txn.execute(select_sql, keyvalues.values())
+
+        row = txn.fetchone()
+        if not row:
+            if allow_none:
+                return None
+            raise StoreError(404, "No row found")
+        if txn.rowcount > 1:
+            raise StoreError(500, "More than one row matched")
+
+        return dict(zip(retcols, row))
+
+    def _simple_selectupdate_one(self, table, keyvalues, updatevalues=None,
+                                 retcols=None, allow_none=False,
+                                 desc="_simple_selectupdate_one"):
+        """ Combined SELECT then UPDATE."""
         def func(txn):
             ret = None
             if retcols:
-                txn.execute(select_sql, keyvalues.values())
-
-                row = txn.fetchone()
-                if not row:
-                    if allow_none:
-                        return None
-                    raise StoreError(404, "No row found")
-                if txn.rowcount > 1:
-                    raise StoreError(500, "More than one row matched")
-
-                ret = dict(zip(retcols, row))
+                ret = self._simple_select_one_txn(
+                    txn,
+                    table=table,
+                    keyvalues=keyvalues,
+                    retcols=retcols,
+                    allow_none=allow_none,
+                )
 
             if updatevalues:
-                txn.execute(
-                    update_sql,
-                    updatevalues.values() + keyvalues.values()
+                self._simple_update_one_txn(
+                    txn,
+                    table=table,
+                    keyvalues=keyvalues,
+                    updatevalues=updatevalues,
                 )
 
-                if txn.rowcount == 0:
-                    raise StoreError(404, "No row found")
+                # if txn.rowcount == 0:
+                #     raise StoreError(404, "No row found")
                 if txn.rowcount > 1:
                     raise StoreError(500, "More than one row matched")
 
             return ret
-        return self.runInteraction("_simple_selectupdate_one", func)
+        return self.runInteraction(desc, func)
 
-    def _simple_delete_one(self, table, keyvalues):
+    def _simple_delete_one(self, table, keyvalues, desc="_simple_delete_one"):
         """Executes a DELETE query on the named table, expecting to delete a
         single row.
 
@@ -609,9 +826,9 @@ class SQLBaseStore(object):
                 raise StoreError(404, "No row found")
             if txn.rowcount > 1:
                 raise StoreError(500, "more than one row matched")
-        return self.runInteraction("_simple_delete_one", func)
+        return self.runInteraction(desc, func)
 
-    def _simple_delete(self, table, keyvalues):
+    def _simple_delete(self, table, keyvalues, desc="_simple_delete"):
         """Executes a DELETE query on the named table.
 
         Args:
@@ -619,7 +836,7 @@ class SQLBaseStore(object):
             keyvalues : dict of column names and values to select the row with
         """
 
-        return self.runInteraction("_simple_delete", self._simple_delete_txn)
+        return self.runInteraction(desc, self._simple_delete_txn)
 
     def _simple_delete_txn(self, txn, table, keyvalues):
         sql = "DELETE FROM %s WHERE %s" % (
@@ -670,6 +887,12 @@ class SQLBaseStore(object):
 
         return [e for e in events if e]
 
+    def _invalidate_get_event_cache(self, event_id):
+        for check_redacted in (False, True):
+            for get_prev_content in (False, True):
+                self._get_event_cache.invalidate(event_id, check_redacted,
+                                                 get_prev_content)
+
     def _get_event_txn(self, txn, event_id, check_redacted=True,
                        get_prev_content=False, allow_rejected=False):
 
@@ -680,16 +903,14 @@ class SQLBaseStore(object):
             sql_getevents_timer.inc_by(curr_time - last_time, desc)
             return curr_time
 
-        cache = self._get_event_cache.setdefault(event_id, {})
-
         try:
-            # Separate cache entries for each way to invoke _get_event_txn
-            ret = cache[(check_redacted, get_prev_content, allow_rejected)]
+            ret = self._get_event_cache.get(event_id, check_redacted, get_prev_content)
 
-            cache_counter.inc_hits("*getEvent*")
-            return ret
+            if allow_rejected or not ret.rejected_reason:
+                return ret
+            else:
+                return None
         except KeyError:
-            cache_counter.inc_misses("*getEvent*")
             pass
         finally:
             start_time = update_counter("event_cache", start_time)
@@ -714,19 +935,22 @@ class SQLBaseStore(object):
 
         start_time = update_counter("select_event", start_time)
 
+        result = self._get_event_from_row_txn(
+            txn, internal_metadata, js, redacted,
+            check_redacted=check_redacted,
+            get_prev_content=get_prev_content,
+            rejected_reason=rejected_reason,
+        )
+        self._get_event_cache.prefill(event_id, check_redacted, get_prev_content, result)
+
         if allow_rejected or not rejected_reason:
-            result = self._get_event_from_row_txn(
-                txn, internal_metadata, js, redacted,
-                check_redacted=check_redacted,
-                get_prev_content=get_prev_content,
-            )
-            cache[(check_redacted, get_prev_content, allow_rejected)] = result
             return result
         else:
             return None
 
     def _get_event_from_row_txn(self, txn, internal_metadata, js, redacted,
-                                check_redacted=True, get_prev_content=False):
+                                check_redacted=True, get_prev_content=False,
+                                rejected_reason=None):
 
         start_time = time.time() * 1000
 
@@ -741,7 +965,11 @@ class SQLBaseStore(object):
         internal_metadata = json.loads(internal_metadata)
         start_time = update_counter("decode_internal", start_time)
 
-        ev = FrozenEvent(d, internal_metadata_dict=internal_metadata)
+        ev = FrozenEvent(
+            d,
+            internal_metadata_dict=internal_metadata,
+            rejected_reason=rejected_reason,
+        )
         start_time = update_counter("build_frozen_event", start_time)
 
         if check_redacted and redacted:
@@ -788,6 +1016,19 @@ class SQLBaseStore(object):
         result = txn.fetchone()
         return result[0] if result else None
 
+    def get_next_stream_id(self):
+        with self._next_stream_id_lock:
+            i = self._next_stream_id
+            self._next_stream_id += 1
+            return i
+
+
+class _RollbackButIsFineException(Exception):
+    """ This exception is used to rollback a transaction without implying
+    something went wrong.
+    """
+    pass
+
 
 class Table(object):
     """ A base class used to store information about a particular table.
@@ -804,7 +1045,7 @@ class Table(object):
 
     _select_where_clause = "SELECT %s FROM %s WHERE %s"
     _select_clause = "SELECT %s FROM %s"
-    _insert_clause = "INSERT OR REPLACE INTO %s (%s) VALUES (%s)"
+    _insert_clause = "REPLACE INTO %s (%s) VALUES (%s)"
 
     @classmethod
     def select_statement(cls, where_clause=None):