diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index c8c76e58fe..39884c2afe 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -15,10 +15,8 @@
import logging
from synapse.api.errors import StoreError
-from synapse.events import FrozenEvent
-from synapse.events.utils import prune_event
from synapse.util.logutils import log_function
-from synapse.util.logcontext import PreserveLoggingContext, LoggingContext
+from synapse.util.logcontext import preserve_context_over_fn, LoggingContext
from synapse.util.lrucache import LruCache
import synapse.metrics
@@ -27,8 +25,8 @@ from util.id_generators import IdGenerator, StreamIdGenerator
from twisted.internet import defer
from collections import namedtuple, OrderedDict
+
import functools
-import simplejson as json
import sys
import time
import threading
@@ -48,7 +46,6 @@ sql_scheduling_timer = metrics.register_distribution("schedule_time")
sql_query_timer = metrics.register_distribution("query_time", labels=["verb"])
sql_txn_timer = metrics.register_distribution("transaction_time", labels=["desc"])
-sql_getevents_timer = metrics.register_distribution("getEvents_time", labels=["desc"])
caches_by_name = {}
cache_counter = metrics.register_cache(
@@ -307,6 +304,12 @@ class SQLBaseStore(object):
self._get_event_cache = Cache("*getEvent*", keylen=3, lru=True,
max_entries=hs.config.event_cache_size)
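+        # Coordination state for batching event fetches across threads
+        # (consumed by the event-fetching code elsewhere in this change):
+        # waiting requests queue up on _event_fetch_list, the number of
+        # threads currently draining it is tracked in _event_fetch_ongoing,
+        # and _event_fetch_lock guards both.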
+ self._event_fetch_lock = threading.Condition()
+ self._event_fetch_list = []
+ self._event_fetch_ongoing = 0
+
+ self._pending_ds = []
+
self.database_engine = hs.database_engine
self._stream_id_gen = StreamIdGenerator()
@@ -315,6 +318,7 @@ class SQLBaseStore(object):
self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self)
self._pushers_id_gen = IdGenerator("pushers", "id", self)
self._push_rule_id_gen = IdGenerator("push_rules", "id", self)
+ self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self)
def start_profiling(self):
self._previous_loop_ts = self._clock.time_msec()
@@ -345,6 +349,75 @@ class SQLBaseStore(object):
self._clock.looping_call(loop, 10000)
+ def _new_transaction(self, conn, desc, after_callbacks, func, *args, **kwargs):
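+        """Run func inside a new transaction on conn, committing on
+        success. Transient failures (connection loss, deadlock) are
+        retried a bounded number of times; after_callbacks is passed
+        through to the LoggingTransaction so callers can queue up
+        post-commit work.
+        """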
+ start = time.time() * 1000
+ txn_id = self._TXN_ID
+
+        # We don't really need these to be unique, so let's stop it from
+        # growing really large.
+ self._TXN_ID = (self._TXN_ID + 1) % (sys.maxint - 1)
+
+ name = "%s-%x" % (desc, txn_id, )
+
+ transaction_logger.debug("[TXN START] {%s}", name)
+
+ try:
+ i = 0
+ N = 5
+ while True:
+ try:
+ txn = conn.cursor()
+ txn = LoggingTransaction(
+ txn, name, self.database_engine, after_callbacks
+ )
+ r = func(txn, *args, **kwargs)
+ conn.commit()
+ return r
+ except self.database_engine.module.OperationalError as e:
+ # This can happen if the database disappears mid
+ # transaction.
+ logger.warn(
+ "[TXN OPERROR] {%s} %s %d/%d",
+ name, e, i, N
+ )
+ if i < N:
+ i += 1
+ try:
+ conn.rollback()
+ except self.database_engine.module.Error as e1:
+ logger.warn(
+ "[TXN EROLL] {%s} %s",
+ name, e1,
+ )
+ continue
+ raise
+ except self.database_engine.module.DatabaseError as e:
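+                    # Deadlocks are retried like connection errors; any
+                    # other DatabaseError is re-raised immediately.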
+ if self.database_engine.is_deadlock(e):
+ logger.warn("[TXN DEADLOCK] {%s} %d/%d", name, i, N)
+ if i < N:
+ i += 1
+ try:
+ conn.rollback()
+ except self.database_engine.module.Error as e1:
+ logger.warn(
+ "[TXN EROLL] {%s} %s",
+ name, e1,
+ )
+ continue
+ raise
+ except Exception as e:
+ logger.debug("[TXN FAIL] {%s} %s", name, e)
+ raise
+ finally:
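+                # Record timing metrics whether or not the transaction
+                # succeeded.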
+ end = time.time() * 1000
+ duration = end - start
+
+ transaction_logger.debug("[TXN END] {%s} %f", name, duration)
+
+ self._current_txn_total_time += duration
+ self._txn_perf_counters.update(desc, start, end)
+ sql_txn_timer.inc_by(duration, desc)
+
@defer.inlineCallbacks
def runInteraction(self, desc, func, *args, **kwargs):
"""Wraps the .runInteraction() method on the underlying db_pool."""
@@ -356,82 +429,50 @@ class SQLBaseStore(object):
def inner_func(conn, *args, **kwargs):
with LoggingContext("runInteraction") as context:
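+            # How long this job sat in the pool's queue before a
+            # database thread picked it up.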
+ sql_scheduling_timer.inc_by(time.time() * 1000 - start_time)
+
if self.database_engine.is_connection_closed(conn):
logger.debug("Reconnecting closed database connection")
conn.reconnect()
current_context.copy_to(context)
- start = time.time() * 1000
- txn_id = self._TXN_ID
+ return self._new_transaction(
+ conn, desc, after_callbacks, func, *args, **kwargs
+ )
- # We don't really need these to be unique, so lets stop it from
- # growing really large.
- self._TXN_ID = (self._TXN_ID + 1) % (sys.maxint - 1)
+ result = yield preserve_context_over_fn(
+ self._db_pool.runWithConnection,
+ inner_func, *args, **kwargs
+ )
- name = "%s-%x" % (desc, txn_id, )
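+        # after_callbacks only fire once the transaction has committed;
+        # a failed transaction raises out of the yield above instead.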
+ for after_callback, after_args in after_callbacks:
+ after_callback(*after_args)
+ defer.returnValue(result)
+
+ @defer.inlineCallbacks
+ def runWithConnection(self, func, *args, **kwargs):
+ """Wraps the .runInteraction() method on the underlying db_pool."""
+ current_context = LoggingContext.current_context()
+
+ start_time = time.time() * 1000
+ def inner_func(conn, *args, **kwargs):
+ with LoggingContext("runWithConnection") as context:
sql_scheduling_timer.inc_by(time.time() * 1000 - start_time)
- transaction_logger.debug("[TXN START] {%s}", name)
- try:
- i = 0
- N = 5
- while True:
- try:
- txn = conn.cursor()
- txn = LoggingTransaction(
- txn, name, self.database_engine, after_callbacks
- )
- return func(txn, *args, **kwargs)
- except self.database_engine.module.OperationalError as e:
- # This can happen if the database disappears mid
- # transaction.
- logger.warn(
- "[TXN OPERROR] {%s} %s %d/%d",
- name, e, i, N
- )
- if i < N:
- i += 1
- try:
- conn.rollback()
- except self.database_engine.module.Error as e1:
- logger.warn(
- "[TXN EROLL] {%s} %s",
- name, e1,
- )
- continue
- except self.database_engine.module.DatabaseError as e:
- if self.database_engine.is_deadlock(e):
- logger.warn("[TXN DEADLOCK] {%s} %d/%d", name, i, N)
- if i < N:
- i += 1
- try:
- conn.rollback()
- except self.database_engine.module.Error as e1:
- logger.warn(
- "[TXN EROLL] {%s} %s",
- name, e1,
- )
- continue
- raise
- except Exception as e:
- logger.debug("[TXN FAIL] {%s} %s", name, e)
- raise
- finally:
- end = time.time() * 1000
- duration = end - start
- transaction_logger.debug("[TXN END] {%s} %f", name, duration)
+ if self.database_engine.is_connection_closed(conn):
+ logger.debug("Reconnecting closed database connection")
+ conn.reconnect()
- self._current_txn_total_time += duration
- self._txn_perf_counters.update(desc, start, end)
- sql_txn_timer.inc_by(duration, desc)
+ current_context.copy_to(context)
+
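+                # Unlike runInteraction, func is handed the raw connection
+                # and is responsible for its own transaction management.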
+ return func(conn, *args, **kwargs)
+
+ result = yield preserve_context_over_fn(
+ self._db_pool.runWithConnection,
+ inner_func, *args, **kwargs
+ )
- with PreserveLoggingContext():
- result = yield self._db_pool.runWithConnection(
- inner_func, *args, **kwargs
- )
- for after_callback, after_args in after_callbacks:
- after_callback(*after_args)
defer.returnValue(result)
def cursor_to_dict(self, cursor):
@@ -871,158 +912,6 @@ class SQLBaseStore(object):
return self.runInteraction("_simple_max_id", func)
- def _get_events(self, event_ids, check_redacted=True,
- get_prev_content=False):
- return self.runInteraction(
- "_get_events", self._get_events_txn, event_ids,
- check_redacted=check_redacted, get_prev_content=get_prev_content,
- )
-
- def _get_events_txn(self, txn, event_ids, check_redacted=True,
- get_prev_content=False):
- if not event_ids:
- return []
-
- events = [
- self._get_event_txn(
- txn, event_id,
- check_redacted=check_redacted,
- get_prev_content=get_prev_content
- )
- for event_id in event_ids
- ]
-
- return [e for e in events if e]
-
- def _invalidate_get_event_cache(self, event_id):
- for check_redacted in (False, True):
- for get_prev_content in (False, True):
- self._get_event_cache.invalidate(event_id, check_redacted,
- get_prev_content)
-
- def _get_event_txn(self, txn, event_id, check_redacted=True,
- get_prev_content=False, allow_rejected=False):
-
- start_time = time.time() * 1000
-
- def update_counter(desc, last_time):
- curr_time = self._get_event_counters.update(desc, last_time)
- sql_getevents_timer.inc_by(curr_time - last_time, desc)
- return curr_time
-
- try:
- ret = self._get_event_cache.get(event_id, check_redacted, get_prev_content)
-
- if allow_rejected or not ret.rejected_reason:
- return ret
- else:
- return None
- except KeyError:
- pass
- finally:
- start_time = update_counter("event_cache", start_time)
-
- sql = (
- "SELECT e.internal_metadata, e.json, r.event_id, rej.reason "
- "FROM event_json as e "
- "LEFT JOIN redactions as r ON e.event_id = r.redacts "
- "LEFT JOIN rejections as rej on rej.event_id = e.event_id "
- "WHERE e.event_id = ? "
- "LIMIT 1 "
- )
-
- txn.execute(sql, (event_id,))
-
- res = txn.fetchone()
-
- if not res:
- return None
-
- internal_metadata, js, redacted, rejected_reason = res
-
- start_time = update_counter("select_event", start_time)
-
- result = self._get_event_from_row_txn(
- txn, internal_metadata, js, redacted,
- check_redacted=check_redacted,
- get_prev_content=get_prev_content,
- rejected_reason=rejected_reason,
- )
- self._get_event_cache.prefill(event_id, check_redacted, get_prev_content, result)
-
- if allow_rejected or not rejected_reason:
- return result
- else:
- return None
-
- def _get_event_from_row_txn(self, txn, internal_metadata, js, redacted,
- check_redacted=True, get_prev_content=False,
- rejected_reason=None):
-
- start_time = time.time() * 1000
-
- def update_counter(desc, last_time):
- curr_time = self._get_event_counters.update(desc, last_time)
- sql_getevents_timer.inc_by(curr_time - last_time, desc)
- return curr_time
-
- d = json.loads(js)
- start_time = update_counter("decode_json", start_time)
-
- internal_metadata = json.loads(internal_metadata)
- start_time = update_counter("decode_internal", start_time)
-
- ev = FrozenEvent(
- d,
- internal_metadata_dict=internal_metadata,
- rejected_reason=rejected_reason,
- )
- start_time = update_counter("build_frozen_event", start_time)
-
- if check_redacted and redacted:
- ev = prune_event(ev)
-
- ev.unsigned["redacted_by"] = redacted
- # Get the redaction event.
-
- because = self._get_event_txn(
- txn,
- redacted,
- check_redacted=False
- )
-
- if because:
- ev.unsigned["redacted_because"] = because
- start_time = update_counter("redact_event", start_time)
-
- if get_prev_content and "replaces_state" in ev.unsigned:
- prev = self._get_event_txn(
- txn,
- ev.unsigned["replaces_state"],
- get_prev_content=False,
- )
- if prev:
- ev.unsigned["prev_content"] = prev.get_dict()["content"]
- start_time = update_counter("get_prev_content", start_time)
-
- return ev
-
- def _parse_events(self, rows):
- return self.runInteraction(
- "_parse_events", self._parse_events_txn, rows
- )
-
- def _parse_events_txn(self, txn, rows):
- event_ids = [r["event_id"] for r in rows]
-
- return self._get_events_txn(txn, event_ids)
-
- def _has_been_redacted_txn(self, txn, event):
- sql = "SELECT event_id FROM redactions WHERE redacts = ?"
- txn.execute(sql, (event.event_id,))
- result = txn.fetchone()
- return result[0] if result else None
-
def get_next_stream_id(self):
with self._next_stream_id_lock:
i = self._next_stream_id
|