summary refs log tree commit diff
path: root/synapse/storage/_base.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/_base.py')
-rw-r--r--synapse/storage/_base.py331
1 files changed, 110 insertions, 221 deletions
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index c8c76e58fe..39884c2afe 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -15,10 +15,8 @@
 import logging
 
 from synapse.api.errors import StoreError
-from synapse.events import FrozenEvent
-from synapse.events.utils import prune_event
 from synapse.util.logutils import log_function
-from synapse.util.logcontext import PreserveLoggingContext, LoggingContext
+from synapse.util.logcontext import preserve_context_over_fn, LoggingContext
 from synapse.util.lrucache import LruCache
 import synapse.metrics
 
@@ -27,8 +25,8 @@ from util.id_generators import IdGenerator, StreamIdGenerator
 from twisted.internet import defer
 
 from collections import namedtuple, OrderedDict
+
 import functools
-import simplejson as json
 import sys
 import time
 import threading
@@ -48,7 +46,6 @@ sql_scheduling_timer = metrics.register_distribution("schedule_time")
 
 sql_query_timer = metrics.register_distribution("query_time", labels=["verb"])
 sql_txn_timer = metrics.register_distribution("transaction_time", labels=["desc"])
-sql_getevents_timer = metrics.register_distribution("getEvents_time", labels=["desc"])
 
 caches_by_name = {}
 cache_counter = metrics.register_cache(
@@ -307,6 +304,12 @@ class SQLBaseStore(object):
         self._get_event_cache = Cache("*getEvent*", keylen=3, lru=True,
                                       max_entries=hs.config.event_cache_size)
 
+        self._event_fetch_lock = threading.Condition()
+        self._event_fetch_list = []
+        self._event_fetch_ongoing = 0
+
+        self._pending_ds = []
+
         self.database_engine = hs.database_engine
 
         self._stream_id_gen = StreamIdGenerator()
@@ -315,6 +318,7 @@ class SQLBaseStore(object):
         self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self)
         self._pushers_id_gen = IdGenerator("pushers", "id", self)
         self._push_rule_id_gen = IdGenerator("push_rules", "id", self)
+        self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self)
 
     def start_profiling(self):
         self._previous_loop_ts = self._clock.time_msec()
@@ -345,6 +349,75 @@ class SQLBaseStore(object):
 
         self._clock.looping_call(loop, 10000)
 
+    def _new_transaction(self, conn, desc, after_callbacks, func, *args, **kwargs):
+        start = time.time() * 1000
+        txn_id = self._TXN_ID
+
+        # We don't really need these to be unique, so lets stop it from
+        # growing really large.
+        self._TXN_ID = (self._TXN_ID + 1) % (sys.maxint - 1)
+
+        name = "%s-%x" % (desc, txn_id, )
+
+        transaction_logger.debug("[TXN START] {%s}", name)
+
+        try:
+            i = 0
+            N = 5
+            while True:
+                try:
+                    txn = conn.cursor()
+                    txn = LoggingTransaction(
+                        txn, name, self.database_engine, after_callbacks
+                    )
+                    r = func(txn, *args, **kwargs)
+                    conn.commit()
+                    return r
+                except self.database_engine.module.OperationalError as e:
+                    # This can happen if the database disappears mid
+                    # transaction.
+                    logger.warn(
+                        "[TXN OPERROR] {%s} %s %d/%d",
+                        name, e, i, N
+                    )
+                    if i < N:
+                        i += 1
+                        try:
+                            conn.rollback()
+                        except self.database_engine.module.Error as e1:
+                            logger.warn(
+                                "[TXN EROLL] {%s} %s",
+                                name, e1,
+                            )
+                        continue
+                    raise
+                except self.database_engine.module.DatabaseError as e:
+                    if self.database_engine.is_deadlock(e):
+                        logger.warn("[TXN DEADLOCK] {%s} %d/%d", name, i, N)
+                        if i < N:
+                            i += 1
+                            try:
+                                conn.rollback()
+                            except self.database_engine.module.Error as e1:
+                                logger.warn(
+                                    "[TXN EROLL] {%s} %s",
+                                    name, e1,
+                                )
+                            continue
+                    raise
+        except Exception as e:
+            logger.debug("[TXN FAIL] {%s} %s", name, e)
+            raise
+        finally:
+            end = time.time() * 1000
+            duration = end - start
+
+            transaction_logger.debug("[TXN END] {%s} %f", name, duration)
+
+            self._current_txn_total_time += duration
+            self._txn_perf_counters.update(desc, start, end)
+            sql_txn_timer.inc_by(duration, desc)
+
     @defer.inlineCallbacks
     def runInteraction(self, desc, func, *args, **kwargs):
         """Wraps the .runInteraction() method on the underlying db_pool."""
@@ -356,82 +429,50 @@ class SQLBaseStore(object):
 
         def inner_func(conn, *args, **kwargs):
             with LoggingContext("runInteraction") as context:
+                sql_scheduling_timer.inc_by(time.time() * 1000 - start_time)
+
                 if self.database_engine.is_connection_closed(conn):
                     logger.debug("Reconnecting closed database connection")
                     conn.reconnect()
 
                 current_context.copy_to(context)
-                start = time.time() * 1000
-                txn_id = self._TXN_ID
+                return self._new_transaction(
+                    conn, desc, after_callbacks, func, *args, **kwargs
+                )
 
-                # We don't really need these to be unique, so lets stop it from
-                # growing really large.
-                self._TXN_ID = (self._TXN_ID + 1) % (sys.maxint - 1)
+        result = yield preserve_context_over_fn(
+            self._db_pool.runWithConnection,
+            inner_func, *args, **kwargs
+        )
 
-                name = "%s-%x" % (desc, txn_id, )
+        for after_callback, after_args in after_callbacks:
+            after_callback(*after_args)
+        defer.returnValue(result)
+
+    @defer.inlineCallbacks
+    def runWithConnection(self, func, *args, **kwargs):
+        """Wraps the .runInteraction() method on the underlying db_pool."""
+        current_context = LoggingContext.current_context()
+
+        start_time = time.time() * 1000
 
+        def inner_func(conn, *args, **kwargs):
+            with LoggingContext("runWithConnection") as context:
                 sql_scheduling_timer.inc_by(time.time() * 1000 - start_time)
-                transaction_logger.debug("[TXN START] {%s}", name)
-                try:
-                    i = 0
-                    N = 5
-                    while True:
-                        try:
-                            txn = conn.cursor()
-                            txn = LoggingTransaction(
-                                txn, name, self.database_engine, after_callbacks
-                            )
-                            return func(txn, *args, **kwargs)
-                        except self.database_engine.module.OperationalError as e:
-                            # This can happen if the database disappears mid
-                            # transaction.
-                            logger.warn(
-                                "[TXN OPERROR] {%s} %s %d/%d",
-                                name, e, i, N
-                            )
-                            if i < N:
-                                i += 1
-                                try:
-                                    conn.rollback()
-                                except self.database_engine.module.Error as e1:
-                                    logger.warn(
-                                        "[TXN EROLL] {%s} %s",
-                                        name, e1,
-                                    )
-                                continue
-                        except self.database_engine.module.DatabaseError as e:
-                            if self.database_engine.is_deadlock(e):
-                                logger.warn("[TXN DEADLOCK] {%s} %d/%d", name, i, N)
-                                if i < N:
-                                    i += 1
-                                    try:
-                                        conn.rollback()
-                                    except self.database_engine.module.Error as e1:
-                                        logger.warn(
-                                            "[TXN EROLL] {%s} %s",
-                                            name, e1,
-                                        )
-                                    continue
-                            raise
-                except Exception as e:
-                    logger.debug("[TXN FAIL] {%s} %s", name, e)
-                    raise
-                finally:
-                    end = time.time() * 1000
-                    duration = end - start
 
-                    transaction_logger.debug("[TXN END] {%s} %f", name, duration)
+                if self.database_engine.is_connection_closed(conn):
+                    logger.debug("Reconnecting closed database connection")
+                    conn.reconnect()
 
-                    self._current_txn_total_time += duration
-                    self._txn_perf_counters.update(desc, start, end)
-                    sql_txn_timer.inc_by(duration, desc)
+                current_context.copy_to(context)
+
+                return func(conn, *args, **kwargs)
+
+        result = yield preserve_context_over_fn(
+            self._db_pool.runWithConnection,
+            inner_func, *args, **kwargs
+        )
 
-        with PreserveLoggingContext():
-            result = yield self._db_pool.runWithConnection(
-                inner_func, *args, **kwargs
-            )
-        for after_callback, after_args in after_callbacks:
-            after_callback(*after_args)
         defer.returnValue(result)
 
     def cursor_to_dict(self, cursor):
@@ -871,158 +912,6 @@ class SQLBaseStore(object):
 
         return self.runInteraction("_simple_max_id", func)
 
-    def _get_events(self, event_ids, check_redacted=True,
-                    get_prev_content=False):
-        return self.runInteraction(
-            "_get_events", self._get_events_txn, event_ids,
-            check_redacted=check_redacted, get_prev_content=get_prev_content,
-        )
-
-    def _get_events_txn(self, txn, event_ids, check_redacted=True,
-                        get_prev_content=False):
-        if not event_ids:
-            return []
-
-        events = [
-            self._get_event_txn(
-                txn, event_id,
-                check_redacted=check_redacted,
-                get_prev_content=get_prev_content
-            )
-            for event_id in event_ids
-        ]
-
-        return [e for e in events if e]
-
-    def _invalidate_get_event_cache(self, event_id):
-        for check_redacted in (False, True):
-            for get_prev_content in (False, True):
-                self._get_event_cache.invalidate(event_id, check_redacted,
-                                                 get_prev_content)
-
-    def _get_event_txn(self, txn, event_id, check_redacted=True,
-                       get_prev_content=False, allow_rejected=False):
-
-        start_time = time.time() * 1000
-
-        def update_counter(desc, last_time):
-            curr_time = self._get_event_counters.update(desc, last_time)
-            sql_getevents_timer.inc_by(curr_time - last_time, desc)
-            return curr_time
-
-        try:
-            ret = self._get_event_cache.get(event_id, check_redacted, get_prev_content)
-
-            if allow_rejected or not ret.rejected_reason:
-                return ret
-            else:
-                return None
-        except KeyError:
-            pass
-        finally:
-            start_time = update_counter("event_cache", start_time)
-
-        sql = (
-            "SELECT e.internal_metadata, e.json, r.event_id, rej.reason "
-            "FROM event_json as e "
-            "LEFT JOIN redactions as r ON e.event_id = r.redacts "
-            "LEFT JOIN rejections as rej on rej.event_id = e.event_id  "
-            "WHERE e.event_id = ? "
-            "LIMIT 1 "
-        )
-
-        txn.execute(sql, (event_id,))
-
-        res = txn.fetchone()
-
-        if not res:
-            return None
-
-        internal_metadata, js, redacted, rejected_reason = res
-
-        start_time = update_counter("select_event", start_time)
-
-        result = self._get_event_from_row_txn(
-            txn, internal_metadata, js, redacted,
-            check_redacted=check_redacted,
-            get_prev_content=get_prev_content,
-            rejected_reason=rejected_reason,
-        )
-        self._get_event_cache.prefill(event_id, check_redacted, get_prev_content, result)
-
-        if allow_rejected or not rejected_reason:
-            return result
-        else:
-            return None
-
-    def _get_event_from_row_txn(self, txn, internal_metadata, js, redacted,
-                                check_redacted=True, get_prev_content=False,
-                                rejected_reason=None):
-
-        start_time = time.time() * 1000
-
-        def update_counter(desc, last_time):
-            curr_time = self._get_event_counters.update(desc, last_time)
-            sql_getevents_timer.inc_by(curr_time - last_time, desc)
-            return curr_time
-
-        d = json.loads(js)
-        start_time = update_counter("decode_json", start_time)
-
-        internal_metadata = json.loads(internal_metadata)
-        start_time = update_counter("decode_internal", start_time)
-
-        ev = FrozenEvent(
-            d,
-            internal_metadata_dict=internal_metadata,
-            rejected_reason=rejected_reason,
-        )
-        start_time = update_counter("build_frozen_event", start_time)
-
-        if check_redacted and redacted:
-            ev = prune_event(ev)
-
-            ev.unsigned["redacted_by"] = redacted
-            # Get the redaction event.
-
-            because = self._get_event_txn(
-                txn,
-                redacted,
-                check_redacted=False
-            )
-
-            if because:
-                ev.unsigned["redacted_because"] = because
-            start_time = update_counter("redact_event", start_time)
-
-        if get_prev_content and "replaces_state" in ev.unsigned:
-            prev = self._get_event_txn(
-                txn,
-                ev.unsigned["replaces_state"],
-                get_prev_content=False,
-            )
-            if prev:
-                ev.unsigned["prev_content"] = prev.get_dict()["content"]
-            start_time = update_counter("get_prev_content", start_time)
-
-        return ev
-
-    def _parse_events(self, rows):
-        return self.runInteraction(
-            "_parse_events", self._parse_events_txn, rows
-        )
-
-    def _parse_events_txn(self, txn, rows):
-        event_ids = [r["event_id"] for r in rows]
-
-        return self._get_events_txn(txn, event_ids)
-
-    def _has_been_redacted_txn(self, txn, event):
-        sql = "SELECT event_id FROM redactions WHERE redacts = ?"
-        txn.execute(sql, (event.event_id,))
-        result = txn.fetchone()
-        return result[0] if result else None
-
     def get_next_stream_id(self):
         with self._next_stream_id_lock:
             i = self._next_stream_id