diff options
Diffstat (limited to 'synapse/storage')
-rw-r--r-- | synapse/storage/__init__.py | 2 | ||||
-rw-r--r-- | synapse/storage/_base.py | 331 | ||||
-rw-r--r-- | synapse/storage/engines/postgres.py | 2 | ||||
-rw-r--r-- | synapse/storage/engines/sqlite3.py | 2 | ||||
-rw-r--r-- | synapse/storage/event_federation.py | 176 | ||||
-rw-r--r-- | synapse/storage/events.py | 487 | ||||
-rw-r--r-- | synapse/storage/presence.py | 35 | ||||
-rw-r--r-- | synapse/storage/push_rule.py | 32 | ||||
-rw-r--r-- | synapse/storage/roommember.py | 39 | ||||
-rw-r--r-- | synapse/storage/schema/delta/19/event_index.sql | 19 | ||||
-rw-r--r-- | synapse/storage/state.py | 28 | ||||
-rw-r--r-- | synapse/storage/stream.py | 181 | ||||
-rw-r--r-- | synapse/storage/util/id_generators.py | 12 |
13 files changed, 873 insertions, 473 deletions
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 7cb91a0be9..75af44d787 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -51,7 +51,7 @@ logger = logging.getLogger(__name__) # Remember to update this number every time a change is made to database # schema files, so the users will be informed on server restarts. -SCHEMA_VERSION = 18 +SCHEMA_VERSION = 19 dir_path = os.path.abspath(os.path.dirname(__file__)) diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index c8c76e58fe..39884c2afe 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -15,10 +15,8 @@ import logging from synapse.api.errors import StoreError -from synapse.events import FrozenEvent -from synapse.events.utils import prune_event from synapse.util.logutils import log_function -from synapse.util.logcontext import PreserveLoggingContext, LoggingContext +from synapse.util.logcontext import preserve_context_over_fn, LoggingContext from synapse.util.lrucache import LruCache import synapse.metrics @@ -27,8 +25,8 @@ from util.id_generators import IdGenerator, StreamIdGenerator from twisted.internet import defer from collections import namedtuple, OrderedDict + import functools -import simplejson as json import sys import time import threading @@ -48,7 +46,6 @@ sql_scheduling_timer = metrics.register_distribution("schedule_time") sql_query_timer = metrics.register_distribution("query_time", labels=["verb"]) sql_txn_timer = metrics.register_distribution("transaction_time", labels=["desc"]) -sql_getevents_timer = metrics.register_distribution("getEvents_time", labels=["desc"]) caches_by_name = {} cache_counter = metrics.register_cache( @@ -307,6 +304,12 @@ class SQLBaseStore(object): self._get_event_cache = Cache("*getEvent*", keylen=3, lru=True, max_entries=hs.config.event_cache_size) + self._event_fetch_lock = threading.Condition() + self._event_fetch_list = [] + self._event_fetch_ongoing = 0 + + self._pending_ds = [] + self.database_engine = hs.database_engine self._stream_id_gen = StreamIdGenerator() @@ -315,6 +318,7 @@ class SQLBaseStore(object): self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self) self._pushers_id_gen = IdGenerator("pushers", "id", self) self._push_rule_id_gen = IdGenerator("push_rules", "id", self) + self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self) def start_profiling(self): self._previous_loop_ts = self._clock.time_msec() @@ -345,6 +349,75 @@ class SQLBaseStore(object): self._clock.looping_call(loop, 10000) + def _new_transaction(self, conn, desc, after_callbacks, func, *args, **kwargs): + start = time.time() * 1000 + txn_id = self._TXN_ID + + # We don't really need these to be unique, so lets stop it from + # growing really large. + self._TXN_ID = (self._TXN_ID + 1) % (sys.maxint - 1) + + name = "%s-%x" % (desc, txn_id, ) + + transaction_logger.debug("[TXN START] {%s}", name) + + try: + i = 0 + N = 5 + while True: + try: + txn = conn.cursor() + txn = LoggingTransaction( + txn, name, self.database_engine, after_callbacks + ) + r = func(txn, *args, **kwargs) + conn.commit() + return r + except self.database_engine.module.OperationalError as e: + # This can happen if the database disappears mid + # transaction. + logger.warn( + "[TXN OPERROR] {%s} %s %d/%d", + name, e, i, N + ) + if i < N: + i += 1 + try: + conn.rollback() + except self.database_engine.module.Error as e1: + logger.warn( + "[TXN EROLL] {%s} %s", + name, e1, + ) + continue + raise + except self.database_engine.module.DatabaseError as e: + if self.database_engine.is_deadlock(e): + logger.warn("[TXN DEADLOCK] {%s} %d/%d", name, i, N) + if i < N: + i += 1 + try: + conn.rollback() + except self.database_engine.module.Error as e1: + logger.warn( + "[TXN EROLL] {%s} %s", + name, e1, + ) + continue + raise + except Exception as e: + logger.debug("[TXN FAIL] {%s} %s", name, e) + raise + finally: + end = time.time() * 1000 + duration = end - start + + transaction_logger.debug("[TXN END] {%s} %f", name, duration) + + self._current_txn_total_time += duration + self._txn_perf_counters.update(desc, start, end) + sql_txn_timer.inc_by(duration, desc) + @defer.inlineCallbacks def runInteraction(self, desc, func, *args, **kwargs): """Wraps the .runInteraction() method on the underlying db_pool.""" @@ -356,82 +429,50 @@ class SQLBaseStore(object): def inner_func(conn, *args, **kwargs): with LoggingContext("runInteraction") as context: + sql_scheduling_timer.inc_by(time.time() * 1000 - start_time) + if self.database_engine.is_connection_closed(conn): logger.debug("Reconnecting closed database connection") conn.reconnect() current_context.copy_to(context) - start = time.time() * 1000 - txn_id = self._TXN_ID + return self._new_transaction( + conn, desc, after_callbacks, func, *args, **kwargs + ) - # We don't really need these to be unique, so lets stop it from - # growing really large. - self._TXN_ID = (self._TXN_ID + 1) % (sys.maxint - 1) + result = yield preserve_context_over_fn( + self._db_pool.runWithConnection, + inner_func, *args, **kwargs + ) - name = "%s-%x" % (desc, txn_id, ) + for after_callback, after_args in after_callbacks: + after_callback(*after_args) + defer.returnValue(result) + + @defer.inlineCallbacks + def runWithConnection(self, func, *args, **kwargs): + """Wraps the .runInteraction() method on the underlying db_pool.""" + current_context = LoggingContext.current_context() + + start_time = time.time() * 1000 + def inner_func(conn, *args, **kwargs): + with LoggingContext("runWithConnection") as context: sql_scheduling_timer.inc_by(time.time() * 1000 - start_time) - transaction_logger.debug("[TXN START] {%s}", name) - try: - i = 0 - N = 5 - while True: - try: - txn = conn.cursor() - txn = LoggingTransaction( - txn, name, self.database_engine, after_callbacks - ) - return func(txn, *args, **kwargs) - except self.database_engine.module.OperationalError as e: - # This can happen if the database disappears mid - # transaction. - logger.warn( - "[TXN OPERROR] {%s} %s %d/%d", - name, e, i, N - ) - if i < N: - i += 1 - try: - conn.rollback() - except self.database_engine.module.Error as e1: - logger.warn( - "[TXN EROLL] {%s} %s", - name, e1, - ) - continue - except self.database_engine.module.DatabaseError as e: - if self.database_engine.is_deadlock(e): - logger.warn("[TXN DEADLOCK] {%s} %d/%d", name, i, N) - if i < N: - i += 1 - try: - conn.rollback() - except self.database_engine.module.Error as e1: - logger.warn( - "[TXN EROLL] {%s} %s", - name, e1, - ) - continue - raise - except Exception as e: - logger.debug("[TXN FAIL] {%s} %s", name, e) - raise - finally: - end = time.time() * 1000 - duration = end - start - transaction_logger.debug("[TXN END] {%s} %f", name, duration) + if self.database_engine.is_connection_closed(conn): + logger.debug("Reconnecting closed database connection") + conn.reconnect() - self._current_txn_total_time += duration - self._txn_perf_counters.update(desc, start, end) - sql_txn_timer.inc_by(duration, desc) + current_context.copy_to(context) + + return func(conn, *args, **kwargs) + + result = yield preserve_context_over_fn( + self._db_pool.runWithConnection, + inner_func, *args, **kwargs + ) - with PreserveLoggingContext(): - result = yield self._db_pool.runWithConnection( - inner_func, *args, **kwargs - ) - for after_callback, after_args in after_callbacks: - after_callback(*after_args) defer.returnValue(result) def cursor_to_dict(self, cursor): @@ -871,158 +912,6 @@ class SQLBaseStore(object): return self.runInteraction("_simple_max_id", func) - def _get_events(self, event_ids, check_redacted=True, - get_prev_content=False): - return self.runInteraction( - "_get_events", self._get_events_txn, event_ids, - check_redacted=check_redacted, get_prev_content=get_prev_content, - ) - - def _get_events_txn(self, txn, event_ids, check_redacted=True, - get_prev_content=False): - if not event_ids: - return [] - - events = [ - self._get_event_txn( - txn, event_id, - check_redacted=check_redacted, - get_prev_content=get_prev_content - ) - for event_id in event_ids - ] - - return [e for e in events if e] - - def _invalidate_get_event_cache(self, event_id): - for check_redacted in (False, True): - for get_prev_content in (False, True): - self._get_event_cache.invalidate(event_id, check_redacted, - get_prev_content) - - def _get_event_txn(self, txn, event_id, check_redacted=True, - get_prev_content=False, allow_rejected=False): - - start_time = time.time() * 1000 - - def update_counter(desc, last_time): - curr_time = self._get_event_counters.update(desc, last_time) - sql_getevents_timer.inc_by(curr_time - last_time, desc) - return curr_time - - try: - ret = self._get_event_cache.get(event_id, check_redacted, get_prev_content) - - if allow_rejected or not ret.rejected_reason: - return ret - else: - return None - except KeyError: - pass - finally: - start_time = update_counter("event_cache", start_time) - - sql = ( - "SELECT e.internal_metadata, e.json, r.event_id, rej.reason " - "FROM event_json as e " - "LEFT JOIN redactions as r ON e.event_id = r.redacts " - "LEFT JOIN rejections as rej on rej.event_id = e.event_id " - "WHERE e.event_id = ? " - "LIMIT 1 " - ) - - txn.execute(sql, (event_id,)) - - res = txn.fetchone() - - if not res: - return None - - internal_metadata, js, redacted, rejected_reason = res - - start_time = update_counter("select_event", start_time) - - result = self._get_event_from_row_txn( - txn, internal_metadata, js, redacted, - check_redacted=check_redacted, - get_prev_content=get_prev_content, - rejected_reason=rejected_reason, - ) - self._get_event_cache.prefill(event_id, check_redacted, get_prev_content, result) - - if allow_rejected or not rejected_reason: - return result - else: - return None - - def _get_event_from_row_txn(self, txn, internal_metadata, js, redacted, - check_redacted=True, get_prev_content=False, - rejected_reason=None): - - start_time = time.time() * 1000 - - def update_counter(desc, last_time): - curr_time = self._get_event_counters.update(desc, last_time) - sql_getevents_timer.inc_by(curr_time - last_time, desc) - return curr_time - - d = json.loads(js) - start_time = update_counter("decode_json", start_time) - - internal_metadata = json.loads(internal_metadata) - start_time = update_counter("decode_internal", start_time) - - ev = FrozenEvent( - d, - internal_metadata_dict=internal_metadata, - rejected_reason=rejected_reason, - ) - start_time = update_counter("build_frozen_event", start_time) - - if check_redacted and redacted: - ev = prune_event(ev) - - ev.unsigned["redacted_by"] = redacted - # Get the redaction event. - - because = self._get_event_txn( - txn, - redacted, - check_redacted=False - ) - - if because: - ev.unsigned["redacted_because"] = because - start_time = update_counter("redact_event", start_time) - - if get_prev_content and "replaces_state" in ev.unsigned: - prev = self._get_event_txn( - txn, - ev.unsigned["replaces_state"], - get_prev_content=False, - ) - if prev: - ev.unsigned["prev_content"] = prev.get_dict()["content"] - start_time = update_counter("get_prev_content", start_time) - - return ev - - def _parse_events(self, rows): - return self.runInteraction( - "_parse_events", self._parse_events_txn, rows - ) - - def _parse_events_txn(self, txn, rows): - event_ids = [r["event_id"] for r in rows] - - return self._get_events_txn(txn, event_ids) - - def _has_been_redacted_txn(self, txn, event): - sql = "SELECT event_id FROM redactions WHERE redacts = ?" - txn.execute(sql, (event.event_id,)) - result = txn.fetchone() - return result[0] if result else None - def get_next_stream_id(self): with self._next_stream_id_lock: i = self._next_stream_id diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py index a323028546..4a855ffd56 100644 --- a/synapse/storage/engines/postgres.py +++ b/synapse/storage/engines/postgres.py @@ -19,6 +19,8 @@ from ._base import IncorrectDatabaseSetup class PostgresEngine(object): + single_threaded = False + def __init__(self, database_module): self.module = database_module self.module.extensions.register_type(self.module.extensions.UNICODE) diff --git a/synapse/storage/engines/sqlite3.py b/synapse/storage/engines/sqlite3.py index ff13d8006a..d18e2808d1 100644 --- a/synapse/storage/engines/sqlite3.py +++ b/synapse/storage/engines/sqlite3.py @@ -17,6 +17,8 @@ from synapse.storage import prepare_database, prepare_sqlite3_database class Sqlite3Engine(object): + single_threaded = True + def __init__(self, database_module): self.module = database_module diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py index 74b4e23590..1ba073884b 100644 --- a/synapse/storage/event_federation.py +++ b/synapse/storage/event_federation.py @@ -13,10 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +from twisted.internet import defer + from ._base import SQLBaseStore, cached from syutil.base64util import encode_base64 import logging +from Queue import PriorityQueue, Empty logger = logging.getLogger(__name__) @@ -33,16 +36,7 @@ class EventFederationStore(SQLBaseStore): """ def get_auth_chain(self, event_ids): - return self.runInteraction( - "get_auth_chain", - self._get_auth_chain_txn, - event_ids - ) - - def _get_auth_chain_txn(self, txn, event_ids): - results = self._get_auth_chain_ids_txn(txn, event_ids) - - return self._get_events_txn(txn, results) + return self.get_auth_chain_ids(event_ids).addCallback(self._get_events) def get_auth_chain_ids(self, event_ids): return self.runInteraction( @@ -79,6 +73,28 @@ class EventFederationStore(SQLBaseStore): room_id, ) + def get_oldest_events_with_depth_in_room(self, room_id): + return self.runInteraction( + "get_oldest_events_with_depth_in_room", + self.get_oldest_events_with_depth_in_room_txn, + room_id, + ) + + def get_oldest_events_with_depth_in_room_txn(self, txn, room_id): + sql = ( + "SELECT b.event_id, MAX(e.depth) FROM events as e" + " INNER JOIN event_edges as g" + " ON g.event_id = e.event_id AND g.room_id = e.room_id" + " INNER JOIN event_backward_extremities as b" + " ON g.prev_event_id = b.event_id AND g.room_id = b.room_id" + " WHERE b.room_id = ? AND g.is_state is ?" + " GROUP BY b.event_id" + ) + + txn.execute(sql, (room_id, False,)) + + return dict(txn.fetchall()) + def _get_oldest_events_in_room_txn(self, txn, room_id): return self._simple_select_onecol_txn( txn, @@ -247,11 +263,13 @@ class EventFederationStore(SQLBaseStore): do_insert = depth < min_depth if min_depth else True if do_insert: - self._simple_insert_txn( + self._simple_upsert_txn( txn, table="room_depth", - values={ + keyvalues={ "room_id": room_id, + }, + values={ "min_depth": depth, }, ) @@ -306,31 +324,28 @@ class EventFederationStore(SQLBaseStore): txn.execute(query, (event_id, room_id)) - # Insert all the prev_events as a backwards thing, they'll get - # deleted in a second if they're incorrect anyway. - self._simple_insert_many_txn( - txn, - table="event_backward_extremities", - values=[ - { - "event_id": e_id, - "room_id": room_id, - } - for e_id, _ in prev_events - ], + query = ( + "INSERT INTO event_backward_extremities (event_id, room_id)" + " SELECT ?, ? WHERE NOT EXISTS (" + " SELECT 1 FROM event_backward_extremities" + " WHERE event_id = ? AND room_id = ?" + " )" + " AND NOT EXISTS (" + " SELECT 1 FROM events WHERE event_id = ? AND room_id = ? " + " AND outlier = ?" + " )" ) - # Also delete from the backwards extremities table all ones that - # reference events that we have already seen + txn.executemany(query, [ + (e_id, room_id, e_id, room_id, e_id, room_id, False) + for e_id, _ in prev_events + ]) + query = ( - "DELETE FROM event_backward_extremities WHERE EXISTS (" - "SELECT 1 FROM events " - "WHERE " - "event_backward_extremities.event_id = events.event_id " - "AND not events.outlier " - ")" + "DELETE FROM event_backward_extremities" + " WHERE event_id = ? AND room_id = ?" ) - txn.execute(query) + txn.execute(query, (event_id, room_id)) txn.call_after( self.get_latest_event_ids_in_room.invalidate, room_id @@ -349,6 +364,10 @@ class EventFederationStore(SQLBaseStore): return self.runInteraction( "get_backfill_events", self._get_backfill_events, room_id, event_list, limit + ).addCallback( + self._get_events + ).addCallback( + lambda l: sorted(l, key=lambda e: -e.depth) ) def _get_backfill_events(self, txn, room_id, event_list, limit): @@ -357,54 +376,75 @@ class EventFederationStore(SQLBaseStore): room_id, repr(event_list), limit ) - event_results = event_list + event_results = set() - front = event_list + # We want to make sure that we do a breadth-first, "depth" ordered + # search. query = ( - "SELECT prev_event_id FROM event_edges " - "WHERE room_id = ? AND event_id = ? " - "LIMIT ?" + "SELECT depth, prev_event_id FROM event_edges" + " INNER JOIN events" + " ON prev_event_id = events.event_id" + " AND event_edges.room_id = events.room_id" + " WHERE event_edges.room_id = ? AND event_edges.event_id = ?" + " AND event_edges.is_state = ?" + " LIMIT ?" ) - # We iterate through all event_ids in `front` to select their previous - # events. These are dumped in `new_front`. - # We continue until we reach the limit *or* new_front is empty (i.e., - # we've run out of things to select - while front and len(event_results) < limit: + queue = PriorityQueue() - new_front = [] - for event_id in front: - logger.debug( - "_backfill_interaction: id=%s", - event_id - ) + for event_id in event_list: + depth = self._simple_select_one_onecol_txn( + txn, + table="events", + keyvalues={ + "event_id": event_id, + }, + retcol="depth" + ) - txn.execute( - query, - (room_id, event_id, limit - len(event_results)) - ) + queue.put((-depth, event_id)) - for row in txn.fetchall(): - logger.debug( - "_backfill_interaction: got id=%s", - *row - ) - new_front.append(row[0]) + while not queue.empty() and len(event_results) < limit: + try: + _, event_id = queue.get_nowait() + except Empty: + break - front = new_front - event_results += new_front + if event_id in event_results: + continue + + event_results.add(event_id) + + txn.execute( + query, + (room_id, event_id, False, limit - len(event_results)) + ) + + for row in txn.fetchall(): + if row[1] not in event_results: + queue.put((-row[0], row[1])) - return self._get_events_txn(txn, event_results) + return event_results + @defer.inlineCallbacks def get_missing_events(self, room_id, earliest_events, latest_events, limit, min_depth): - return self.runInteraction( + ids = yield self.runInteraction( "get_missing_events", self._get_missing_events, room_id, earliest_events, latest_events, limit, min_depth ) + events = yield self._get_events(ids) + + events = sorted( + [ev for ev in events if ev.depth >= min_depth], + key=lambda e: e.depth, + ) + + defer.returnValue(events[:limit]) + def _get_missing_events(self, txn, room_id, earliest_events, latest_events, limit, min_depth): @@ -436,14 +476,7 @@ class EventFederationStore(SQLBaseStore): front = new_front event_results |= new_front - events = self._get_events_txn(txn, event_results) - - events = sorted( - [ev for ev in events if ev.depth >= min_depth], - key=lambda e: e.depth, - ) - - return events[:limit] + return event_results def clean_room_for_join(self, room_id): return self.runInteraction( @@ -456,3 +489,4 @@ class EventFederationStore(SQLBaseStore): query = "DELETE FROM event_forward_extremities WHERE room_id = ?" txn.execute(query, (room_id,)) + txn.call_after(self.get_latest_event_ids_in_room.invalidate, room_id) diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 1304219e86..d2a010bd88 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -15,20 +15,36 @@ from _base import SQLBaseStore, _RollbackButIsFineException -from twisted.internet import defer +from twisted.internet import defer, reactor +from synapse.events import FrozenEvent +from synapse.events.utils import prune_event + +from synapse.util.logcontext import preserve_context_over_deferred from synapse.util.logutils import log_function from synapse.api.constants import EventTypes from synapse.crypto.event_signing import compute_event_reference_hash from syutil.base64util import decode_base64 from syutil.jsonutil import encode_canonical_json +from contextlib import contextmanager import logging +import simplejson as json logger = logging.getLogger(__name__) +# These values are used in the `enqueus_event` and `_do_fetch` methods to +# control how we batch/bulk fetch events from the database. +# The values are plucked out of thing air to make initial sync run faster +# on jki.re +# TODO: Make these configurable. +EVENT_QUEUE_THREADS = 3 # Max number of threads that will fetch events +EVENT_QUEUE_ITERATIONS = 3 # No. times we block waiting for requests for events +EVENT_QUEUE_TIMEOUT_S = 0.1 # Timeout when waiting for requests for events + + class EventsStore(SQLBaseStore): @defer.inlineCallbacks @log_function @@ -41,20 +57,32 @@ class EventsStore(SQLBaseStore): self.min_token -= 1 stream_ordering = self.min_token + if stream_ordering is None: + stream_ordering_manager = yield self._stream_id_gen.get_next(self) + else: + @contextmanager + def stream_ordering_manager(): + yield stream_ordering + stream_ordering_manager = stream_ordering_manager() + try: - yield self.runInteraction( - "persist_event", - self._persist_event_txn, - event=event, - context=context, - backfilled=backfilled, - stream_ordering=stream_ordering, - is_new_state=is_new_state, - current_state=current_state, - ) + with stream_ordering_manager as stream_ordering: + yield self.runInteraction( + "persist_event", + self._persist_event_txn, + event=event, + context=context, + backfilled=backfilled, + stream_ordering=stream_ordering, + is_new_state=is_new_state, + current_state=current_state, + ) except _RollbackButIsFineException: pass + max_persisted_id = yield self._stream_id_gen.get_max_token(self) + defer.returnValue((stream_ordering, max_persisted_id)) + @defer.inlineCallbacks def get_event(self, event_id, check_redacted=True, get_prev_content=False, allow_rejected=False, @@ -74,18 +102,17 @@ class EventsStore(SQLBaseStore): Returns: Deferred : A FrozenEvent. """ - event = yield self.runInteraction( - "get_event", self._get_event_txn, - event_id, + events = yield self._get_events( + [event_id], check_redacted=check_redacted, get_prev_content=get_prev_content, allow_rejected=allow_rejected, ) - if not event and not allow_none: + if not events and not allow_none: raise RuntimeError("Could not find event %s" % (event_id,)) - defer.returnValue(event) + defer.returnValue(events[0] if events else None) @log_function def _persist_event_txn(self, txn, event, context, backfilled, @@ -95,15 +122,6 @@ class EventsStore(SQLBaseStore): # Remove the any existing cache entries for the event_id txn.call_after(self._invalidate_get_event_cache, event.event_id) - if stream_ordering is None: - with self._stream_id_gen.get_next_txn(txn) as stream_ordering: - return self._persist_event_txn( - txn, event, context, backfilled, - stream_ordering=stream_ordering, - is_new_state=is_new_state, - current_state=current_state, - ) - # We purposefully do this first since if we include a `current_state` # key, we *want* to update the `current_state_events` table if current_state: @@ -134,19 +152,17 @@ class EventsStore(SQLBaseStore): outlier = event.internal_metadata.is_outlier() if not outlier: - self._store_state_groups_txn(txn, event, context) - self._update_min_depth_for_room_txn( txn, event.room_id, event.depth ) - have_persisted = self._simple_select_one_onecol_txn( + have_persisted = self._simple_select_one_txn( txn, - table="event_json", + table="events", keyvalues={"event_id": event.event_id}, - retcol="event_id", + retcols=["event_id", "outlier"], allow_none=True, ) @@ -161,7 +177,9 @@ class EventsStore(SQLBaseStore): # if we are persisting an event that we had persisted as an outlier, # but is no longer one. if have_persisted: - if not outlier: + if not outlier and have_persisted["outlier"]: + self._store_state_groups_txn(txn, event, context) + sql = ( "UPDATE event_json SET internal_metadata = ?" " WHERE event_id = ?" @@ -181,6 +199,9 @@ class EventsStore(SQLBaseStore): ) return + if not outlier: + self._store_state_groups_txn(txn, event, context) + self._handle_prev_events( txn, outlier=outlier, @@ -400,3 +421,407 @@ class EventsStore(SQLBaseStore): return self.runInteraction( "have_events", f, ) + + @defer.inlineCallbacks + def _get_events(self, event_ids, check_redacted=True, + get_prev_content=False, allow_rejected=False): + if not event_ids: + defer.returnValue([]) + + event_map = self._get_events_from_cache( + event_ids, + check_redacted=check_redacted, + get_prev_content=get_prev_content, + allow_rejected=allow_rejected, + ) + + missing_events_ids = [e for e in event_ids if e not in event_map] + + if not missing_events_ids: + defer.returnValue([ + event_map[e_id] for e_id in event_ids + if e_id in event_map and event_map[e_id] + ]) + + missing_events = yield self._enqueue_events( + missing_events_ids, + check_redacted=check_redacted, + get_prev_content=get_prev_content, + allow_rejected=allow_rejected, + ) + + event_map.update(missing_events) + + defer.returnValue([ + event_map[e_id] for e_id in event_ids + if e_id in event_map and event_map[e_id] + ]) + + def _get_events_txn(self, txn, event_ids, check_redacted=True, + get_prev_content=False, allow_rejected=False): + if not event_ids: + return [] + + event_map = self._get_events_from_cache( + event_ids, + check_redacted=check_redacted, + get_prev_content=get_prev_content, + allow_rejected=allow_rejected, + ) + + missing_events_ids = [e for e in event_ids if e not in event_map] + + if not missing_events_ids: + return [ + event_map[e_id] for e_id in event_ids + if e_id in event_map and event_map[e_id] + ] + + missing_events = self._fetch_events_txn( + txn, + missing_events_ids, + check_redacted=check_redacted, + get_prev_content=get_prev_content, + allow_rejected=allow_rejected, + ) + + event_map.update(missing_events) + + return [ + event_map[e_id] for e_id in event_ids + if e_id in event_map and event_map[e_id] + ] + + def _invalidate_get_event_cache(self, event_id): + for check_redacted in (False, True): + for get_prev_content in (False, True): + self._get_event_cache.invalidate(event_id, check_redacted, + get_prev_content) + + def _get_event_txn(self, txn, event_id, check_redacted=True, + get_prev_content=False, allow_rejected=False): + + events = self._get_events_txn( + txn, [event_id], + check_redacted=check_redacted, + get_prev_content=get_prev_content, + allow_rejected=allow_rejected, + ) + + return events[0] if events else None + + def _get_events_from_cache(self, events, check_redacted, get_prev_content, + allow_rejected): + event_map = {} + + for event_id in events: + try: + ret = self._get_event_cache.get( + event_id, check_redacted, get_prev_content + ) + + if allow_rejected or not ret.rejected_reason: + event_map[event_id] = ret + else: + event_map[event_id] = None + except KeyError: + pass + + return event_map + + def _do_fetch(self, conn): + """Takes a database connection and waits for requests for events from + the _event_fetch_list queue. + """ + event_list = [] + i = 0 + while True: + try: + with self._event_fetch_lock: + event_list = self._event_fetch_list + self._event_fetch_list = [] + + if not event_list: + single_threaded = self.database_engine.single_threaded + if single_threaded or i > EVENT_QUEUE_ITERATIONS: + self._event_fetch_ongoing -= 1 + return + else: + self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S) + i += 1 + continue + i = 0 + + event_id_lists = zip(*event_list)[0] + event_ids = [ + item for sublist in event_id_lists for item in sublist + ] + + rows = self._new_transaction( + conn, "do_fetch", [], self._fetch_event_rows, event_ids + ) + + row_dict = { + r["event_id"]: r + for r in rows + } + + # We only want to resolve deferreds from the main thread + def fire(lst, res): + for ids, d in lst: + if not d.called: + try: + d.callback([ + res[i] + for i in ids + if i in res + ]) + except: + logger.exception("Failed to callback") + reactor.callFromThread(fire, event_list, row_dict) + except Exception as e: + logger.exception("do_fetch") + + # We only want to resolve deferreds from the main thread + def fire(evs): + for _, d in evs: + if not d.called: + d.errback(e) + + if event_list: + reactor.callFromThread(fire, event_list) + + @defer.inlineCallbacks + def _enqueue_events(self, events, check_redacted=True, + get_prev_content=False, allow_rejected=False): + """Fetches events from the database using the _event_fetch_list. This + allows batch and bulk fetching of events - it allows us to fetch events + without having to create a new transaction for each request for events. + """ + if not events: + defer.returnValue({}) + + events_d = defer.Deferred() + with self._event_fetch_lock: + self._event_fetch_list.append( + (events, events_d) + ) + + self._event_fetch_lock.notify() + + if self._event_fetch_ongoing < EVENT_QUEUE_THREADS: + self._event_fetch_ongoing += 1 + should_start = True + else: + should_start = False + + if should_start: + self.runWithConnection( + self._do_fetch + ) + + rows = yield preserve_context_over_deferred(events_d) + + if not allow_rejected: + rows[:] = [r for r in rows if not r["rejects"]] + + res = yield defer.gatherResults( + [ + self._get_event_from_row( + row["internal_metadata"], row["json"], row["redacts"], + check_redacted=check_redacted, + get_prev_content=get_prev_content, + rejected_reason=row["rejects"], + ) + for row in rows + ], + consumeErrors=True + ) + + defer.returnValue({ + e.event_id: e + for e in res if e + }) + + def _fetch_event_rows(self, txn, events): + rows = [] + N = 200 + for i in range(1 + len(events) / N): + evs = events[i*N:(i + 1)*N] + if not evs: + break + + sql = ( + "SELECT " + " e.event_id as event_id, " + " e.internal_metadata," + " e.json," + " r.redacts as redacts," + " rej.event_id as rejects " + " FROM event_json as e" + " LEFT JOIN rejections as rej USING (event_id)" + " LEFT JOIN redactions as r ON e.event_id = r.redacts" + " WHERE e.event_id IN (%s)" + ) % (",".join(["?"]*len(evs)),) + + txn.execute(sql, evs) + rows.extend(self.cursor_to_dict(txn)) + + return rows + + def _fetch_events_txn(self, txn, events, check_redacted=True, + get_prev_content=False, allow_rejected=False): + if not events: + return {} + + rows = self._fetch_event_rows( + txn, events, + ) + + if not allow_rejected: + rows[:] = [r for r in rows if not r["rejects"]] + + res = [ + self._get_event_from_row_txn( + txn, + row["internal_metadata"], row["json"], row["redacts"], + check_redacted=check_redacted, + get_prev_content=get_prev_content, + rejected_reason=row["rejects"], + ) + for row in rows + ] + + return { + r.event_id: r + for r in res + } + + @defer.inlineCallbacks + def _get_event_from_row(self, internal_metadata, js, redacted, + check_redacted=True, get_prev_content=False, + rejected_reason=None): + d = json.loads(js) + internal_metadata = json.loads(internal_metadata) + + if rejected_reason: + rejected_reason = yield self._simple_select_one_onecol( + table="rejections", + keyvalues={"event_id": rejected_reason}, + retcol="reason", + desc="_get_event_from_row", + ) + + ev = FrozenEvent( + d, + internal_metadata_dict=internal_metadata, + rejected_reason=rejected_reason, + ) + + if check_redacted and redacted: + ev = prune_event(ev) + + redaction_id = yield self._simple_select_one_onecol( + table="redactions", + keyvalues={"redacts": ev.event_id}, + retcol="event_id", + desc="_get_event_from_row", + ) + + ev.unsigned["redacted_by"] = redaction_id + # Get the redaction event. + + because = yield self.get_event( + redaction_id, + check_redacted=False + ) + + if because: + ev.unsigned["redacted_because"] = because + + if get_prev_content and "replaces_state" in ev.unsigned: + prev = yield self.get_event( + ev.unsigned["replaces_state"], + get_prev_content=False, + ) + if prev: + ev.unsigned["prev_content"] = prev.get_dict()["content"] + + self._get_event_cache.prefill( + ev.event_id, check_redacted, get_prev_content, ev + ) + + defer.returnValue(ev) + + def _get_event_from_row_txn(self, txn, internal_metadata, js, redacted, + check_redacted=True, get_prev_content=False, + rejected_reason=None): + d = json.loads(js) + internal_metadata = json.loads(internal_metadata) + + if rejected_reason: + rejected_reason = self._simple_select_one_onecol_txn( + txn, + table="rejections", + keyvalues={"event_id": rejected_reason}, + retcol="reason", + ) + + ev = FrozenEvent( + d, + internal_metadata_dict=internal_metadata, + rejected_reason=rejected_reason, + ) + + if check_redacted and redacted: + ev = prune_event(ev) + + redaction_id = self._simple_select_one_onecol_txn( + txn, + table="redactions", + keyvalues={"redacts": ev.event_id}, + retcol="event_id", + ) + + ev.unsigned["redacted_by"] = redaction_id + # Get the redaction event. + + because = self._get_event_txn( + txn, + redaction_id, + check_redacted=False + ) + + if because: + ev.unsigned["redacted_because"] = because + + if get_prev_content and "replaces_state" in ev.unsigned: + prev = self._get_event_txn( + txn, + ev.unsigned["replaces_state"], + get_prev_content=False, + ) + if prev: + ev.unsigned["prev_content"] = prev.get_dict()["content"] + + self._get_event_cache.prefill( + ev.event_id, check_redacted, get_prev_content, ev + ) + + return ev + + def _parse_events(self, rows): + return self.runInteraction( + "_parse_events", self._parse_events_txn, rows + ) + + def _parse_events_txn(self, txn, rows): + event_ids = [r["event_id"] for r in rows] + + return self._get_events_txn(txn, event_ids) + + def _has_been_redacted_txn(self, txn, event): + sql = "SELECT event_id FROM redactions WHERE redacts = ?" + txn.execute(sql, (event.event_id,)) + result = txn.fetchone() + return result[0] if result else None diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py index 22ec94bc16..fefcf6bce0 100644 --- a/synapse/storage/presence.py +++ b/synapse/storage/presence.py @@ -13,7 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import SQLBaseStore +from ._base import SQLBaseStore, cached + +from twisted.internet import defer class PresenceStore(SQLBaseStore): @@ -87,31 +89,48 @@ class PresenceStore(SQLBaseStore): desc="add_presence_list_pending", ) + @defer.inlineCallbacks def set_presence_list_accepted(self, observer_localpart, observed_userid): - return self._simple_update_one( + result = yield self._simple_update_one( table="presence_list", keyvalues={"user_id": observer_localpart, "observed_user_id": observed_userid}, updatevalues={"accepted": True}, desc="set_presence_list_accepted", ) + self.get_presence_list_accepted.invalidate(observer_localpart) + defer.returnValue(result) def get_presence_list(self, observer_localpart, accepted=None): - keyvalues = {"user_id": observer_localpart} - if accepted is not None: - keyvalues["accepted"] = accepted + if accepted: + return self.get_presence_list_accepted(observer_localpart) + else: + keyvalues = {"user_id": observer_localpart} + if accepted is not None: + keyvalues["accepted"] = accepted + return self._simple_select_list( + table="presence_list", + keyvalues=keyvalues, + retcols=["observed_user_id", "accepted"], + desc="get_presence_list", + ) + + @cached() + def get_presence_list_accepted(self, observer_localpart): return self._simple_select_list( table="presence_list", - keyvalues=keyvalues, + keyvalues={"user_id": observer_localpart, "accepted": True}, retcols=["observed_user_id", "accepted"], - desc="get_presence_list", + desc="get_presence_list_accepted", ) + @defer.inlineCallbacks def del_presence_list(self, observer_localpart, observed_userid): - return self._simple_delete_one( + yield self._simple_delete_one( table="presence_list", keyvalues={"user_id": observer_localpart, "observed_user_id": observed_userid}, desc="del_presence_list", ) + self.get_presence_list_accepted.invalidate(observer_localpart) diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py index 88ee21b089..4cac118d17 100644 --- a/synapse/storage/push_rule.py +++ b/synapse/storage/push_rule.py @@ -23,6 +23,7 @@ logger = logging.getLogger(__name__) class PushRuleStore(SQLBaseStore): + @cached() @defer.inlineCallbacks def get_push_rules_for_user(self, user_name): rows = yield self._simple_select_list( @@ -31,6 +32,7 @@ class PushRuleStore(SQLBaseStore): "user_name": user_name, }, retcols=PushRuleTable.fields, + desc="get_push_rules_enabled_for_user", ) rows.sort( @@ -151,6 +153,10 @@ class PushRuleStore(SQLBaseStore): txn.execute(sql, (user_name, priority_class, new_rule_priority)) txn.call_after( + self.get_push_rules_for_user.invalidate, user_name + ) + + txn.call_after( self.get_push_rules_enabled_for_user.invalidate, user_name ) @@ -183,6 +189,9 @@ class PushRuleStore(SQLBaseStore): new_rule['priority'] = new_prio txn.call_after( + self.get_push_rules_for_user.invalidate, user_name + ) + txn.call_after( self.get_push_rules_enabled_for_user.invalidate, user_name ) @@ -208,17 +217,34 @@ class PushRuleStore(SQLBaseStore): {'user_name': user_name, 'rule_id': rule_id}, desc="delete_push_rule", ) + + self.get_push_rules_for_user.invalidate(user_name) self.get_push_rules_enabled_for_user.invalidate(user_name) @defer.inlineCallbacks def set_push_rule_enabled(self, user_name, rule_id, enabled): - yield self._simple_upsert( + ret = yield self.runInteraction( + "_set_push_rule_enabled_txn", + self._set_push_rule_enabled_txn, + user_name, rule_id, enabled + ) + defer.returnValue(ret) + + def _set_push_rule_enabled_txn(self, txn, user_name, rule_id, enabled): + new_id = self._push_rules_enable_id_gen.get_next_txn(txn) + self._simple_upsert_txn( + txn, PushRuleEnableTable.table_name, {'user_name': user_name, 'rule_id': rule_id}, {'enabled': 1 if enabled else 0}, - desc="set_push_rule_enabled", + {'id': new_id}, + ) + txn.call_after( + self.get_push_rules_for_user.invalidate, user_name + ) + txn.call_after( + self.get_push_rules_enabled_for_user.invalidate, user_name ) - self.get_push_rules_enabled_for_user.invalidate(user_name) class RuleNotFoundException(Exception): diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 3691eade05..d36a6c18a8 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -77,16 +77,16 @@ class RoomMemberStore(SQLBaseStore): Returns: Deferred: Results in a MembershipEvent or None. """ - def f(txn): - events = self._get_members_events_txn( - txn, - room_id, - user_id=user_id, - ) - - return events[0] if events else None - - return self.runInteraction("get_room_member", f) + return self.runInteraction( + "get_room_member", + self._get_members_events_txn, + room_id, + user_id=user_id, + ).addCallback( + self._get_events + ).addCallback( + lambda events: events[0] if events else None + ) @cached() def get_users_in_room(self, room_id): @@ -112,15 +112,12 @@ class RoomMemberStore(SQLBaseStore): Returns: list of namedtuples representing the members in this room. """ - - def f(txn): - return self._get_members_events_txn( - txn, - room_id, - membership=membership, - ) - - return self.runInteraction("get_room_members", f) + return self.runInteraction( + "get_room_members", + self._get_members_events_txn, + room_id, + membership=membership, + ).addCallback(self._get_events) def get_rooms_for_user_where_membership_is(self, user_id, membership_list): """ Get all the rooms for this user where the membership for this user @@ -192,14 +189,14 @@ class RoomMemberStore(SQLBaseStore): return self.runInteraction( "get_members_query", self._get_members_events_txn, where_clause, where_values - ) + ).addCallbacks(self._get_events) def _get_members_events_txn(self, txn, room_id, membership=None, user_id=None): rows = self._get_members_rows_txn( txn, room_id, membership, user_id, ) - return self._get_events_txn(txn, [r["event_id"] for r in rows]) + return [r["event_id"] for r in rows] def _get_members_rows_txn(self, txn, room_id, membership=None, user_id=None): where_clause = "c.room_id = ?" diff --git a/synapse/storage/schema/delta/19/event_index.sql b/synapse/storage/schema/delta/19/event_index.sql new file mode 100644 index 0000000000..3881fc9897 --- /dev/null +++ b/synapse/storage/schema/delta/19/event_index.sql @@ -0,0 +1,19 @@ +/* Copyright 2015 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +CREATE INDEX events_order_topo_stream_room ON events( + topological_ordering, stream_ordering, room_id +); diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 6df7350552..b24de34f23 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -43,6 +43,7 @@ class StateStore(SQLBaseStore): * `state_groups_state`: Maps state group to state events. """ + @defer.inlineCallbacks def get_state_groups(self, event_ids): """ Get the state groups for the given list of event_ids @@ -71,17 +72,29 @@ class StateStore(SQLBaseStore): retcol="event_id", ) - state = self._get_events_txn(txn, state_ids) - - res[group] = state + res[group] = state_ids return res - return self.runInteraction( + states = yield self.runInteraction( "get_state_groups", f, ) + @defer.inlineCallbacks + def c(vals): + vals[:] = yield self._get_events(vals, get_prev_content=False) + + yield defer.gatherResults( + [ + c(vals) + for vals in states.values() + ], + consumeErrors=True, + ) + + defer.returnValue(states) + def _store_state_groups_txn(self, txn, event, context): if context.current_state is None: return @@ -152,11 +165,12 @@ class StateStore(SQLBaseStore): args = (room_id, ) txn.execute(sql, args) - results = self.cursor_to_dict(txn) + results = txn.fetchall() - return self._parse_events_txn(txn, results) + return [r[0] for r in results] - events = yield self.runInteraction("get_current_state", f) + event_ids = yield self.runInteraction("get_current_state", f) + events = yield self._get_events(event_ids, get_prev_content=False) defer.returnValue(events) @cached(num_args=3) diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index 280d4ad605..af45fc5619 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -37,11 +37,9 @@ from twisted.internet import defer from ._base import SQLBaseStore from synapse.api.constants import EventTypes -from synapse.api.errors import SynapseError +from synapse.types import RoomStreamToken from synapse.util.logutils import log_function -from collections import namedtuple - import logging @@ -55,76 +53,26 @@ _STREAM_TOKEN = "stream" _TOPOLOGICAL_TOKEN = "topological" -class _StreamToken(namedtuple("_StreamToken", "topological stream")): - """Tokens are positions between events. The token "s1" comes after event 1. - - s0 s1 - | | - [0] V [1] V [2] - - Tokens can either be a point in the live event stream or a cursor going - through historic events. - - When traversing the live event stream events are ordered by when they - arrived at the homeserver. - - When traversing historic events the events are ordered by their depth in - the event graph "topological_ordering" and then by when they arrived at the - homeserver "stream_ordering". - - Live tokens start with an "s" followed by the "stream_ordering" id of the - event it comes after. Historic tokens start with a "t" followed by the - "topological_ordering" id of the event it comes after, follewed by "-", - followed by the "stream_ordering" id of the event it comes after. - """ - __slots__ = [] - - @classmethod - def parse(cls, string): - try: - if string[0] == 's': - return cls(topological=None, stream=int(string[1:])) - if string[0] == 't': - parts = string[1:].split('-', 1) - return cls(topological=int(parts[0]), stream=int(parts[1])) - except: - pass - raise SynapseError(400, "Invalid token %r" % (string,)) - - @classmethod - def parse_stream_token(cls, string): - try: - if string[0] == 's': - return cls(topological=None, stream=int(string[1:])) - except: - pass - raise SynapseError(400, "Invalid token %r" % (string,)) - - def __str__(self): - if self.topological is not None: - return "t%d-%d" % (self.topological, self.stream) - else: - return "s%d" % (self.stream,) +def lower_bound(token): + if token.topological is None: + return "(%d < %s)" % (token.stream, "stream_ordering") + else: + return "(%d < %s OR (%d = %s AND %d < %s))" % ( + token.topological, "topological_ordering", + token.topological, "topological_ordering", + token.stream, "stream_ordering", + ) - def lower_bound(self): - if self.topological is None: - return "(%d < %s)" % (self.stream, "stream_ordering") - else: - return "(%d < %s OR (%d = %s AND %d < %s))" % ( - self.topological, "topological_ordering", - self.topological, "topological_ordering", - self.stream, "stream_ordering", - ) - def upper_bound(self): - if self.topological is None: - return "(%d >= %s)" % (self.stream, "stream_ordering") - else: - return "(%d > %s OR (%d = %s AND %d >= %s))" % ( - self.topological, "topological_ordering", - self.topological, "topological_ordering", - self.stream, "stream_ordering", - ) +def upper_bound(token): + if token.topological is None: + return "(%d >= %s)" % (token.stream, "stream_ordering") + else: + return "(%d > %s OR (%d = %s AND %d >= %s))" % ( + token.topological, "topological_ordering", + token.topological, "topological_ordering", + token.stream, "stream_ordering", + ) class StreamStore(SQLBaseStore): @@ -139,8 +87,8 @@ class StreamStore(SQLBaseStore): limit = MAX_STREAM_SIZE # From and to keys should be integers from ordering. - from_id = _StreamToken.parse_stream_token(from_key) - to_id = _StreamToken.parse_stream_token(to_key) + from_id = RoomStreamToken.parse_stream_token(from_key) + to_id = RoomStreamToken.parse_stream_token(to_key) if from_key == to_key: defer.returnValue(([], to_key)) @@ -234,8 +182,8 @@ class StreamStore(SQLBaseStore): limit = MAX_STREAM_SIZE # From and to keys should be integers from ordering. - from_id = _StreamToken.parse_stream_token(from_key) - to_id = _StreamToken.parse_stream_token(to_key) + from_id = RoomStreamToken.parse_stream_token(from_key) + to_id = RoomStreamToken.parse_stream_token(to_key) if from_key == to_key: return defer.succeed(([], to_key)) @@ -276,7 +224,7 @@ class StreamStore(SQLBaseStore): return self.runInteraction("get_room_events_stream", f) - @log_function + @defer.inlineCallbacks def paginate_room_events(self, room_id, from_key, to_key=None, direction='b', limit=-1, with_feedback=False): @@ -288,17 +236,17 @@ class StreamStore(SQLBaseStore): args = [False, room_id] if direction == 'b': order = "DESC" - bounds = _StreamToken.parse(from_key).upper_bound() + bounds = upper_bound(RoomStreamToken.parse(from_key)) if to_key: bounds = "%s AND %s" % ( - bounds, _StreamToken.parse(to_key).lower_bound() + bounds, lower_bound(RoomStreamToken.parse(to_key)) ) else: order = "ASC" - bounds = _StreamToken.parse(from_key).lower_bound() + bounds = lower_bound(RoomStreamToken.parse(from_key)) if to_key: bounds = "%s AND %s" % ( - bounds, _StreamToken.parse(to_key).upper_bound() + bounds, upper_bound(RoomStreamToken.parse(to_key)) ) if int(limit) > 0: @@ -333,28 +281,30 @@ class StreamStore(SQLBaseStore): # when we are going backwards so we subtract one from the # stream part. toke -= 1 - next_token = str(_StreamToken(topo, toke)) + next_token = str(RoomStreamToken(topo, toke)) else: # TODO (erikj): We should work out what to do here instead. next_token = to_key if to_key else from_key - events = self._get_events_txn( - txn, - [r["event_id"] for r in rows], - get_prev_content=True - ) + return rows, next_token, - self._set_before_and_after(events, rows) + rows, token = yield self.runInteraction("paginate_room_events", f) + + events = yield self._get_events( + [r["event_id"] for r in rows], + get_prev_content=True + ) - return events, next_token, + self._set_before_and_after(events, rows) - return self.runInteraction("paginate_room_events", f) + defer.returnValue((events, token)) + @defer.inlineCallbacks def get_recent_events_for_room(self, room_id, limit, end_token, with_feedback=False, from_token=None): # TODO (erikj): Handle compressed feedback - end_token = _StreamToken.parse_stream_token(end_token) + end_token = RoomStreamToken.parse_stream_token(end_token) if from_token is None: sql = ( @@ -365,7 +315,7 @@ class StreamStore(SQLBaseStore): " LIMIT ?" ) else: - from_token = _StreamToken.parse_stream_token(from_token) + from_token = RoomStreamToken.parse_stream_token(from_token) sql = ( "SELECT stream_ordering, topological_ordering, event_id" " FROM events" @@ -395,30 +345,49 @@ class StreamStore(SQLBaseStore): # stream part. topo = rows[0]["topological_ordering"] toke = rows[0]["stream_ordering"] - 1 - start_token = str(_StreamToken(topo, toke)) + start_token = str(RoomStreamToken(topo, toke)) token = (start_token, str(end_token)) else: token = (str(end_token), str(end_token)) - events = self._get_events_txn( - txn, - [r["event_id"] for r in rows], - get_prev_content=True - ) - - self._set_before_and_after(events, rows) + return rows, token - return events, token - - return self.runInteraction( + rows, token = yield self.runInteraction( "get_recent_events_for_room", get_recent_events_for_room_txn ) + logger.debug("stream before") + events = yield self._get_events( + [r["event_id"] for r in rows], + get_prev_content=True + ) + logger.debug("stream after") + + self._set_before_and_after(events, rows) + + defer.returnValue((events, token)) + @defer.inlineCallbacks - def get_room_events_max_id(self): + def get_room_events_max_id(self, direction='f'): token = yield self._stream_id_gen.get_max_token(self) - defer.returnValue("s%d" % (token,)) + if direction != 'b': + defer.returnValue("s%d" % (token,)) + else: + topo = yield self.runInteraction( + "_get_max_topological_txn", self._get_max_topological_txn + ) + defer.returnValue("t%d-%d" % (topo, token)) + + def _get_max_topological_txn(self, txn): + txn.execute( + "SELECT MAX(topological_ordering) FROM events" + " WHERE outlier = ?", + (False,) + ) + + rows = txn.fetchall() + return rows[0][0] if rows else 0 @defer.inlineCallbacks def _get_min_token(self): @@ -439,5 +408,5 @@ class StreamStore(SQLBaseStore): stream = row["stream_ordering"] topo = event.depth internal = event.internal_metadata - internal.before = str(_StreamToken(topo, stream - 1)) - internal.after = str(_StreamToken(topo, stream)) + internal.before = str(RoomStreamToken(topo, stream - 1)) + internal.after = str(RoomStreamToken(topo, stream)) diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index e40eb8a8c4..89d1643f10 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -78,14 +78,18 @@ class StreamIdGenerator(object): self._current_max = None self._unfinished_ids = deque() - def get_next_txn(self, txn): + @defer.inlineCallbacks + def get_next(self, store): """ Usage: - with stream_id_gen.get_next_txn(txn) as stream_id: + with yield stream_id_gen.get_next as stream_id: # ... persist event ... """ if not self._current_max: - self._get_or_compute_current_max(txn) + yield store.runInteraction( + "_compute_current_max", + self._get_or_compute_current_max, + ) with self._lock: self._current_max += 1 @@ -101,7 +105,7 @@ class StreamIdGenerator(object): with self._lock: self._unfinished_ids.remove(next_id) - return manager() + defer.returnValue(manager()) @defer.inlineCallbacks def get_max_token(self, store): |