Diffstat (limited to 'synapse/storage/_base.py')
-rw-r--r-- | synapse/storage/_base.py | 331
1 file changed, 110 insertions, 221 deletions
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index c8c76e58fe..39884c2afe 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -15,10 +15,8 @@
 import logging
 
 from synapse.api.errors import StoreError
-from synapse.events import FrozenEvent
-from synapse.events.utils import prune_event
 from synapse.util.logutils import log_function
-from synapse.util.logcontext import PreserveLoggingContext, LoggingContext
+from synapse.util.logcontext import preserve_context_over_fn, LoggingContext
 from synapse.util.lrucache import LruCache
 import synapse.metrics
 
@@ -27,8 +25,8 @@ from util.id_generators import IdGenerator, StreamIdGenerator
 from twisted.internet import defer
 
 from collections import namedtuple, OrderedDict
+
 import functools
-import simplejson as json
 import sys
 import time
 import threading
@@ -48,7 +46,6 @@ sql_scheduling_timer = metrics.register_distribution("schedule_time")
 
 sql_query_timer = metrics.register_distribution("query_time", labels=["verb"])
 sql_txn_timer = metrics.register_distribution("transaction_time", labels=["desc"])
-sql_getevents_timer = metrics.register_distribution("getEvents_time", labels=["desc"])
 
 caches_by_name = {}
 cache_counter = metrics.register_cache(
@@ -307,6 +304,12 @@ class SQLBaseStore(object):
         self._get_event_cache = Cache("*getEvent*", keylen=3, lru=True,
                                       max_entries=hs.config.event_cache_size)
 
+        self._event_fetch_lock = threading.Condition()
+        self._event_fetch_list = []
+        self._event_fetch_ongoing = 0
+
+        self._pending_ds = []
+
         self.database_engine = hs.database_engine
 
         self._stream_id_gen = StreamIdGenerator()
@@ -315,6 +318,7 @@ class SQLBaseStore(object):
         self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self)
         self._pushers_id_gen = IdGenerator("pushers", "id", self)
         self._push_rule_id_gen = IdGenerator("push_rules", "id", self)
+        self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self)
 
     def start_profiling(self):
         self._previous_loop_ts = self._clock.time_msec()
@@ -345,6 +349,75 @@ class SQLBaseStore(object):
 
         self._clock.looping_call(loop, 10000)
 
+    def _new_transaction(self, conn, desc, after_callbacks, func, *args, **kwargs):
+        start = time.time() * 1000
+        txn_id = self._TXN_ID
+
+        # We don't really need these to be unique, so lets stop it from
+        # growing really large.
+        self._TXN_ID = (self._TXN_ID + 1) % (sys.maxint - 1)
+
+        name = "%s-%x" % (desc, txn_id, )
+
+        transaction_logger.debug("[TXN START] {%s}", name)
+
+        try:
+            i = 0
+            N = 5
+            while True:
+                try:
+                    txn = conn.cursor()
+                    txn = LoggingTransaction(
+                        txn, name, self.database_engine, after_callbacks
+                    )
+                    r = func(txn, *args, **kwargs)
+                    conn.commit()
+                    return r
+                except self.database_engine.module.OperationalError as e:
+                    # This can happen if the database disappears mid
+                    # transaction.
+                    logger.warn(
+                        "[TXN OPERROR] {%s} %s %d/%d",
+                        name, e, i, N
+                    )
+                    if i < N:
+                        i += 1
+                        try:
+                            conn.rollback()
+                        except self.database_engine.module.Error as e1:
+                            logger.warn(
+                                "[TXN EROLL] {%s} %s",
+                                name, e1,
+                            )
+                        continue
+                    raise
+                except self.database_engine.module.DatabaseError as e:
+                    if self.database_engine.is_deadlock(e):
+                        logger.warn("[TXN DEADLOCK] {%s} %d/%d", name, i, N)
+                        if i < N:
+                            i += 1
+                            try:
+                                conn.rollback()
+                            except self.database_engine.module.Error as e1:
+                                logger.warn(
+                                    "[TXN EROLL] {%s} %s",
+                                    name, e1,
+                                )
+                            continue
+                    raise
+                except Exception as e:
+                    logger.debug("[TXN FAIL] {%s} %s", name, e)
+                    raise
+        finally:
+            end = time.time() * 1000
+            duration = end - start
+
+            transaction_logger.debug("[TXN END] {%s} %f", name, duration)
+
+            self._current_txn_total_time += duration
+            self._txn_perf_counters.update(desc, start, end)
+            sql_txn_timer.inc_by(duration, desc)
+
     @defer.inlineCallbacks
     def runInteraction(self, desc, func, *args, **kwargs):
         """Wraps the .runInteraction() method on the underlying db_pool."""
@@ -356,82 +429,50 @@ class SQLBaseStore(object):
         def inner_func(conn, *args, **kwargs):
             with LoggingContext("runInteraction") as context:
+                sql_scheduling_timer.inc_by(time.time() * 1000 - start_time)
+
                 if self.database_engine.is_connection_closed(conn):
                     logger.debug("Reconnecting closed database connection")
                     conn.reconnect()
 
                 current_context.copy_to(context)
 
-                start = time.time() * 1000
-                txn_id = self._TXN_ID
+                return self._new_transaction(
+                    conn, desc, after_callbacks, func, *args, **kwargs
+                )
 
-                # We don't really need these to be unique, so lets stop it from
-                # growing really large.
-                self._TXN_ID = (self._TXN_ID + 1) % (sys.maxint - 1)
+        result = yield preserve_context_over_fn(
+            self._db_pool.runWithConnection,
+            inner_func, *args, **kwargs
+        )
 
-                name = "%s-%x" % (desc, txn_id, )
+        for after_callback, after_args in after_callbacks:
+            after_callback(*after_args)
+        defer.returnValue(result)
+
+    @defer.inlineCallbacks
+    def runWithConnection(self, func, *args, **kwargs):
+        """Wraps the .runInteraction() method on the underlying db_pool."""
+        current_context = LoggingContext.current_context()
+
+        start_time = time.time() * 1000
 
+        def inner_func(conn, *args, **kwargs):
+            with LoggingContext("runWithConnection") as context:
                 sql_scheduling_timer.inc_by(time.time() * 1000 - start_time)
-                transaction_logger.debug("[TXN START] {%s}", name)
 
-                try:
-                    i = 0
-                    N = 5
-                    while True:
-                        try:
-                            txn = conn.cursor()
-                            txn = LoggingTransaction(
-                                txn, name, self.database_engine, after_callbacks
-                            )
-                            return func(txn, *args, **kwargs)
-                        except self.database_engine.module.OperationalError as e:
-                            # This can happen if the database disappears mid
-                            # transaction.
-                            logger.warn(
-                                "[TXN OPERROR] {%s} %s %d/%d",
-                                name, e, i, N
-                            )
-                            if i < N:
-                                i += 1
-                                try:
-                                    conn.rollback()
-                                except self.database_engine.module.Error as e1:
-                                    logger.warn(
-                                        "[TXN EROLL] {%s} %s",
-                                        name, e1,
-                                    )
-                                continue
-                        except self.database_engine.module.DatabaseError as e:
-                            if self.database_engine.is_deadlock(e):
-                                logger.warn("[TXN DEADLOCK] {%s} %d/%d", name, i, N)
-                                if i < N:
-                                    i += 1
-                                    try:
-                                        conn.rollback()
-                                    except self.database_engine.module.Error as e1:
-                                        logger.warn(
-                                            "[TXN EROLL] {%s} %s",
-                                            name, e1,
-                                        )
-                                    continue
-                            raise
-                        except Exception as e:
-                            logger.debug("[TXN FAIL] {%s} %s", name, e)
-                            raise
-                finally:
-                    end = time.time() * 1000
-                    duration = end - start
-                    transaction_logger.debug("[TXN END] {%s} %f", name, duration)
+                if self.database_engine.is_connection_closed(conn):
+                    logger.debug("Reconnecting closed database connection")
+                    conn.reconnect()
 
-                    self._current_txn_total_time += duration
-                    self._txn_perf_counters.update(desc, start, end)
-                    sql_txn_timer.inc_by(duration, desc)
+                current_context.copy_to(context)
+
+                return func(conn, *args, **kwargs)
+
+        result = yield preserve_context_over_fn(
+            self._db_pool.runWithConnection,
+            inner_func, *args, **kwargs
+        )
 
-        with PreserveLoggingContext():
-            result = yield self._db_pool.runWithConnection(
-                inner_func, *args, **kwargs
-            )
-        for after_callback, after_args in after_callbacks:
-            after_callback(*after_args)
         defer.returnValue(result)
 
     def cursor_to_dict(self, cursor):
@@ -871,158 +912,6 @@ class SQLBaseStore(object):
 
         return self.runInteraction("_simple_max_id", func)
 
-    def _get_events(self, event_ids, check_redacted=True,
-                    get_prev_content=False):
-        return self.runInteraction(
-            "_get_events", self._get_events_txn, event_ids,
-            check_redacted=check_redacted, get_prev_content=get_prev_content,
-        )
-
-    def _get_events_txn(self, txn, event_ids, check_redacted=True,
-                        get_prev_content=False):
-        if not event_ids:
-            return []
-
-        events = [
-            self._get_event_txn(
-                txn, event_id,
-                check_redacted=check_redacted,
-                get_prev_content=get_prev_content
-            )
-            for event_id in event_ids
-        ]
-
-        return [e for e in events if e]
-
-    def _invalidate_get_event_cache(self, event_id):
-        for check_redacted in (False, True):
-            for get_prev_content in (False, True):
-                self._get_event_cache.invalidate(event_id, check_redacted,
-                                                 get_prev_content)
-
-    def _get_event_txn(self, txn, event_id, check_redacted=True,
-                       get_prev_content=False, allow_rejected=False):
-
-        start_time = time.time() * 1000
-
-        def update_counter(desc, last_time):
-            curr_time = self._get_event_counters.update(desc, last_time)
-            sql_getevents_timer.inc_by(curr_time - last_time, desc)
-            return curr_time
-
-        try:
-            ret = self._get_event_cache.get(event_id, check_redacted, get_prev_content)
-
-            if allow_rejected or not ret.rejected_reason:
-                return ret
-            else:
-                return None
-        except KeyError:
-            pass
-        finally:
-            start_time = update_counter("event_cache", start_time)
-
-        sql = (
-            "SELECT e.internal_metadata, e.json, r.event_id, rej.reason "
-            "FROM event_json as e "
-            "LEFT JOIN redactions as r ON e.event_id = r.redacts "
-            "LEFT JOIN rejections as rej on rej.event_id = e.event_id "
-            "WHERE e.event_id = ? "
-            "LIMIT 1 "
-        )
-
-        txn.execute(sql, (event_id,))
-
-        res = txn.fetchone()
-
-        if not res:
-            return None
-
-        internal_metadata, js, redacted, rejected_reason = res
-
-        start_time = update_counter("select_event", start_time)
-
-        result = self._get_event_from_row_txn(
-            txn, internal_metadata, js, redacted,
-            check_redacted=check_redacted,
-            get_prev_content=get_prev_content,
-            rejected_reason=rejected_reason,
-        )
-        self._get_event_cache.prefill(event_id, check_redacted, get_prev_content, result)
-
-        if allow_rejected or not rejected_reason:
-            return result
-        else:
-            return None
-
-    def _get_event_from_row_txn(self, txn, internal_metadata, js, redacted,
-                                check_redacted=True, get_prev_content=False,
-                                rejected_reason=None):
-
-        start_time = time.time() * 1000
-
-        def update_counter(desc, last_time):
-            curr_time = self._get_event_counters.update(desc, last_time)
-            sql_getevents_timer.inc_by(curr_time - last_time, desc)
-            return curr_time
-
-        d = json.loads(js)
-        start_time = update_counter("decode_json", start_time)
-
-        internal_metadata = json.loads(internal_metadata)
-        start_time = update_counter("decode_internal", start_time)
-
-        ev = FrozenEvent(
-            d,
-            internal_metadata_dict=internal_metadata,
-            rejected_reason=rejected_reason,
-        )
-        start_time = update_counter("build_frozen_event", start_time)
-
-        if check_redacted and redacted:
-            ev = prune_event(ev)
-
-            ev.unsigned["redacted_by"] = redacted
-            # Get the redaction event.
-
-            because = self._get_event_txn(
-                txn,
-                redacted,
-                check_redacted=False
-            )
-
-            if because:
-                ev.unsigned["redacted_because"] = because
-            start_time = update_counter("redact_event", start_time)
-
-        if get_prev_content and "replaces_state" in ev.unsigned:
-            prev = self._get_event_txn(
-                txn,
-                ev.unsigned["replaces_state"],
-                get_prev_content=False,
-            )
-            if prev:
-                ev.unsigned["prev_content"] = prev.get_dict()["content"]
-            start_time = update_counter("get_prev_content", start_time)
-
-        return ev
-
-    def _parse_events(self, rows):
-        return self.runInteraction(
-            "_parse_events", self._parse_events_txn, rows
-        )
-
-    def _parse_events_txn(self, txn, rows):
-        event_ids = [r["event_id"] for r in rows]
-
-        return self._get_events_txn(txn, event_ids)
-
-    def _has_been_redacted_txn(self, txn, event):
-        sql = "SELECT event_id FROM redactions WHERE redacts = ?"
-        txn.execute(sql, (event.event_id,))
-        result = txn.fetchone()
-        return result[0] if result else None
-
     def get_next_stream_id(self):
         with self._next_stream_id_lock:
             i = self._next_stream_id