Diffstat (limited to 'synapse/storage')
 -rw-r--r--  synapse/storage/_base.py                12
 -rw-r--r--  synapse/storage/event_federation.py     66
 -rw-r--r--  synapse/storage/events.py               53
 -rw-r--r--  synapse/storage/push_rule.py           112
 -rw-r--r--  synapse/storage/stream.py              138
 -rw-r--r--  synapse/storage/util/id_generators.py   12
6 files changed, 191 insertions(+), 202 deletions(-)
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py

index ee5587c721..c9fe5a3555 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -18,7 +18,7 @@
 from synapse.api.errors import StoreError
 from synapse.events import FrozenEvent
 from synapse.events.utils import prune_event
 from synapse.util.logutils import log_function
-from synapse.util.logcontext import PreserveLoggingContext, LoggingContext
+from synapse.util.logcontext import preserve_context_over_fn, LoggingContext
 from synapse.util.lrucache import LruCache
 import synapse.metrics
@@ -308,6 +308,7 @@ class SQLBaseStore(object):
         self._state_groups_id_gen = IdGenerator("state_groups", "id", self)
         self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self)
         self._pushers_id_gen = IdGenerator("pushers", "id", self)
+        self._push_rule_id_gen = IdGenerator("push_rules", "id", self)

     def start_profiling(self):
         self._previous_loop_ts = self._clock.time_msec()
@@ -419,10 +420,11 @@ class SQLBaseStore(object):
                     self._txn_perf_counters.update(desc, start, end)
                     sql_txn_timer.inc_by(duration, desc)

-        with PreserveLoggingContext():
-            result = yield self._db_pool.runWithConnection(
-                inner_func, *args, **kwargs
-            )
+        result = yield preserve_context_over_fn(
+            self._db_pool.runWithConnection,
+            inner_func, *args, **kwargs
+        )
+
         for after_callback, after_args in after_callbacks:
             after_callback(*after_args)
         defer.returnValue(result)
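The runInteraction path above only runs its queued "after callbacks" (typically cache invalidations) once runWithConnection has returned, i.e. after the transaction has committed, so readers never see an invalidated cache for data that might still roll back. A minimal standalone sketch of that pattern, with illustrative names rather than Synapse's actual API:

def run_interaction(conn, func, *args):
    # Run `func` inside a transaction; fire queued callbacks only on commit.
    after_callbacks = []        # `func` appends (callback, args) pairs here,
                                # e.g. cache invalidations
    cursor = conn.cursor()
    try:
        result = func(cursor, after_callbacks, *args)
        conn.commit()
    except Exception:
        conn.rollback()
        raise                   # callbacks never fire for a rolled-back txn
    for callback, cb_args in after_callbacks:
        callback(*cb_args)      # safe: the new rows are now visible
    return result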
diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py
index 74b4e23590..a1982dfbb5 100644
--- a/synapse/storage/event_federation.py
+++ b/synapse/storage/event_federation.py
@@ -79,6 +79,28 @@ class EventFederationStore(SQLBaseStore):
             room_id,
         )

+    def get_oldest_events_with_depth_in_room(self, room_id):
+        return self.runInteraction(
+            "get_oldest_events_with_depth_in_room",
+            self.get_oldest_events_with_depth_in_room_txn,
+            room_id,
+        )
+
+    def get_oldest_events_with_depth_in_room_txn(self, txn, room_id):
+        sql = (
+            "SELECT b.event_id, MAX(e.depth) FROM events as e"
+            " INNER JOIN event_edges as g"
+            " ON g.event_id = e.event_id AND g.room_id = e.room_id"
+            " INNER JOIN event_backward_extremities as b"
+            " ON g.prev_event_id = b.event_id AND g.room_id = b.room_id"
+            " WHERE b.room_id = ? AND g.is_state is ?"
+            " GROUP BY b.event_id"
+        )
+
+        txn.execute(sql, (room_id, False,))
+
+        return dict(txn.fetchall())
+
     def _get_oldest_events_in_room_txn(self, txn, room_id):
         return self._simple_select_onecol_txn(
             txn,
@@ -247,11 +269,13 @@ class EventFederationStore(SQLBaseStore):
         do_insert = depth < min_depth if min_depth else True

         if do_insert:
-            self._simple_insert_txn(
+            self._simple_upsert_txn(
                 txn,
                 table="room_depth",
-                values={
+                keyvalues={
                     "room_id": room_id,
+                },
+                values={
                     "min_depth": depth,
                 },
             )
@@ -306,31 +330,27 @@ class EventFederationStore(SQLBaseStore):
         txn.execute(query, (event_id, room_id))

-        # Insert all the prev_events as a backwards thing, they'll get
-        # deleted in a second if they're incorrect anyway.
-        self._simple_insert_many_txn(
-            txn,
-            table="event_backward_extremities",
-            values=[
-                {
-                    "event_id": e_id,
-                    "room_id": room_id,
-                }
-                for e_id, _ in prev_events
-            ],
+        query = (
+            "INSERT INTO event_backward_extremities (event_id, room_id)"
+            " SELECT ?, ? WHERE NOT EXISTS ("
+            " SELECT 1 FROM event_backward_extremities"
+            " WHERE event_id = ? AND room_id = ?"
+            " )"
+            " AND NOT EXISTS ("
+            " SELECT 1 FROM events WHERE event_id = ? AND room_id = ?"
+            " )"
         )

-        # Also delete from the backwards extremities table all ones that
-        # reference events that we have already seen
+        txn.executemany(query, [
+            (e_id, room_id, e_id, room_id, e_id, room_id, )
+            for e_id, _ in prev_events
+        ])
+
         query = (
-            "DELETE FROM event_backward_extremities WHERE EXISTS ("
-            "SELECT 1 FROM events "
-            "WHERE "
-            "event_backward_extremities.event_id = events.event_id "
-            "AND not events.outlier "
-            ")"
+            "DELETE FROM event_backward_extremities"
+            " WHERE event_id = ? AND room_id = ?"
        )
-        txn.execute(query)
+        txn.execute(query, (event_id, room_id))

         txn.call_after(
             self.get_latest_event_ids_in_room.invalidate, room_id
         )
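The backward-extremities rewrite above replaces an optimistic insert-many followed by a blanket DELETE with a single conditional INSERT: the SELECT ?, ? WHERE NOT EXISTS form only materialises a row when neither subquery finds a match, and it works the same on SQLite and PostgreSQL without needing upsert support. A runnable sketch of the idiom against a toy schema (the frontier/seen tables are illustrative, not Synapse's):

import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript("""
    CREATE TABLE frontier (event_id TEXT, room_id TEXT);
    CREATE TABLE seen (event_id TEXT, room_id TEXT);
    INSERT INTO seen VALUES ('$b', '!room');
""")

sql = (
    "INSERT INTO frontier (event_id, room_id)"
    " SELECT ?, ? WHERE NOT EXISTS ("
    "  SELECT 1 FROM frontier WHERE event_id = ? AND room_id = ?"
    " ) AND NOT EXISTS ("
    "  SELECT 1 FROM seen WHERE event_id = ? AND room_id = ?"
    " )"
)
# '$a' is unseen, so it is added; '$b' is already seen, so it is skipped
conn.executemany(sql, [(e, '!room') * 3 for e in ('$a', '$b')])
print(conn.execute("SELECT event_id FROM frontier").fetchall())  # [('$a',)]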
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 38395c66ab..a5a6869079 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -23,6 +23,7 @@
 from synapse.crypto.event_signing import compute_event_reference_hash
 from syutil.base64util import decode_base64
 from syutil.jsonutil import encode_canonical_json

+from contextlib import contextmanager

 import logging
@@ -41,17 +42,25 @@ class EventsStore(SQLBaseStore):
             self.min_token -= 1
             stream_ordering = self.min_token

+        if stream_ordering is None:
+            stream_ordering_manager = yield self._stream_id_gen.get_next(self)
+        else:
+            @contextmanager
+            def stream_ordering_manager():
+                yield stream_ordering
+
         try:
-            yield self.runInteraction(
-                "persist_event",
-                self._persist_event_txn,
-                event=event,
-                context=context,
-                backfilled=backfilled,
-                stream_ordering=stream_ordering,
-                is_new_state=is_new_state,
-                current_state=current_state,
-            )
+            with stream_ordering_manager as stream_ordering:
+                yield self.runInteraction(
+                    "persist_event",
+                    self._persist_event_txn,
+                    event=event,
+                    context=context,
+                    backfilled=backfilled,
+                    stream_ordering=stream_ordering,
+                    is_new_state=is_new_state,
+                    current_state=current_state,
+                )
         except _RollbackButIsFineException:
             pass
@@ -95,15 +104,6 @@ class EventsStore(SQLBaseStore):
         # Remove the any existing cache entries for the event_id
         txn.call_after(self._invalidate_get_event_cache, event.event_id)

-        if stream_ordering is None:
-            with self._stream_id_gen.get_next_txn(txn) as stream_ordering:
-                return self._persist_event_txn(
-                    txn, event, context, backfilled,
-                    stream_ordering=stream_ordering,
-                    is_new_state=is_new_state,
-                    current_state=current_state,
-                )
-
         # We purposefully do this first since if we include a `current_state`
         # key, we *want* to update the `current_state_events` table
         if current_state:
@@ -135,19 +135,17 @@ class EventsStore(SQLBaseStore):
         outlier = event.internal_metadata.is_outlier()

         if not outlier:
-            self._store_state_groups_txn(txn, event, context)
-
             self._update_min_depth_for_room_txn(
                 txn,
                 event.room_id,
                 event.depth
             )

-        have_persisted = self._simple_select_one_onecol_txn(
+        have_persisted = self._simple_select_one_txn(
             txn,
-            table="event_json",
+            table="events",
             keyvalues={"event_id": event.event_id},
-            retcol="event_id",
+            retcols=["event_id", "outlier"],
             allow_none=True,
         )
@@ -162,7 +160,9 @@ class EventsStore(SQLBaseStore):
         # if we are persisting an event that we had persisted as an outlier,
         # but is no longer one.
         if have_persisted:
-            if not outlier:
+            if not outlier and have_persisted["outlier"]:
+                self._store_state_groups_txn(txn, event, context)
+
                 sql = (
                     "UPDATE event_json SET internal_metadata = ?"
                     " WHERE event_id = ?"
@@ -182,6 +182,9 @@ class EventsStore(SQLBaseStore):
             )
             return

+        if not outlier:
+            self._store_state_groups_txn(txn, event, context)
+
         self._handle_prev_events(
             txn,
             outlier=outlier,
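The persist_event change funnels both code paths through one with block: when no stream_ordering is supplied, the id generator's get_next yields a context manager that allocates a fresh ordering; when the caller already has one, it is wrapped in a trivial @contextmanager so it can be consumed the same way. A standalone sketch of that shape (ordering_manager and id_gen are assumed names; note that the decorated generator function has to be instantiated before it can back a with statement):

from contextlib import contextmanager

def ordering_manager(id_gen, fixed=None):
    # Either path returns a ready-to-use context manager.
    if fixed is None:
        return id_gen.get_next()   # assumed to yield a fresh ordering
    @contextmanager
    def fixed_manager():
        yield fixed                # nothing to allocate or release
    return fixed_manager()         # instantiate the decorated factory

# with ordering_manager(id_gen) as stream_ordering:
#     ...persist the event under stream_ordering...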
diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py
index ee7718d5ed..34805e276e 100644
--- a/synapse/storage/push_rule.py
+++ b/synapse/storage/push_rule.py
@@ -19,7 +19,6 @@
 from ._base import SQLBaseStore, Table
 from twisted.internet import defer

 import logging
-import copy
 import simplejson as json

 logger = logging.getLogger(__name__)
@@ -28,46 +27,45 @@ logger = logging.getLogger(__name__)
 class PushRuleStore(SQLBaseStore):
     @defer.inlineCallbacks
     def get_push_rules_for_user(self, user_name):
-        sql = (
-            "SELECT "+",".join(PushRuleTable.fields)+" "
-            "FROM "+PushRuleTable.table_name+" "
-            "WHERE user_name = ? "
-            "ORDER BY priority_class DESC, priority DESC"
+        rows = yield self._simple_select_list(
+            table=PushRuleTable.table_name,
+            keyvalues={
+                "user_name": user_name,
+            },
+            retcols=PushRuleTable.fields,
         )

-        rows = yield self._execute("get_push_rules_for_user", None, sql, user_name)
-
-        dicts = []
-        for r in rows:
-            d = {}
-            for i, f in enumerate(PushRuleTable.fields):
-                d[f] = r[i]
-            dicts.append(d)
+        rows.sort(
+            key=lambda row: (-int(row["priority_class"]), -int(row["priority"]))
+        )

-        defer.returnValue(dicts)
+        defer.returnValue(rows)

     @defer.inlineCallbacks
     def get_push_rules_enabled_for_user(self, user_name):
         results = yield self._simple_select_list(
-            PushRuleEnableTable.table_name,
-            {'user_name': user_name},
-            PushRuleEnableTable.fields,
+            table=PushRuleEnableTable.table_name,
+            keyvalues={
+                'user_name': user_name
+            },
+            retcols=PushRuleEnableTable.fields,
             desc="get_push_rules_enabled_for_user",
         )
-        defer.returnValue(
-            {r['rule_id']: False if r['enabled'] == 0 else True for r in results}
-        )
+        defer.returnValue({
+            r['rule_id']: False if r['enabled'] == 0 else True for r in results
+        })

     @defer.inlineCallbacks
     def add_push_rule(self, before, after, **kwargs):
-        vals = copy.copy(kwargs)
+        vals = kwargs
         if 'conditions' in vals:
             vals['conditions'] = json.dumps(vals['conditions'])
         if 'actions' in vals:
             vals['actions'] = json.dumps(vals['actions'])
+
         # we could check the rest of the keys are valid column names
         # but sqlite will do that anyway so I think it's just pointless.
-        if 'id' in vals:
-            del vals['id']
+        vals.pop("id", None)

         if before or after:
             ret = yield self.runInteraction(
@@ -87,39 +85,39 @@ class PushRuleStore(SQLBaseStore):
         defer.returnValue(ret)

     def _add_push_rule_relative_txn(self, txn, user_name, **kwargs):
-        after = None
-        relative_to_rule = None
-        if 'after' in kwargs and kwargs['after']:
-            after = kwargs['after']
-            relative_to_rule = after
-        if 'before' in kwargs and kwargs['before']:
-            relative_to_rule = kwargs['before']
-
-        # get the priority of the rule we're inserting after/before
-        sql = (
-            "SELECT priority_class, priority FROM ? "
-            "WHERE user_name = ? and rule_id = ?" % (PushRuleTable.table_name,)
+        after = kwargs.pop("after", None)
+        relative_to_rule = kwargs.pop("before", after)
+
+        res = self._simple_select_one_txn(
+            txn,
+            table=PushRuleTable.table_name,
+            keyvalues={
+                "user_name": user_name,
+                "rule_id": relative_to_rule,
+            },
+            retcols=["priority_class", "priority"],
+            allow_none=True,
         )
-        txn.execute(sql, (user_name, relative_to_rule))
-        res = txn.fetchall()
+
         if not res:
             raise RuleNotFoundException(
                 "before/after rule not found: %s" % (relative_to_rule,)
             )
-        priority_class, base_rule_priority = res[0]
+
+        priority_class = res["priority_class"]
+        base_rule_priority = res["priority"]

         if 'priority_class' in kwargs and kwargs['priority_class'] != priority_class:
             raise InconsistentRuleException(
                 "Given priority class does not match class of relative rule"
             )

-        new_rule = copy.copy(kwargs)
-        if 'before' in new_rule:
-            del new_rule['before']
-        if 'after' in new_rule:
-            del new_rule['after']
+        new_rule = kwargs
+        new_rule.pop("before", None)
+        new_rule.pop("after", None)
         new_rule['priority_class'] = priority_class
         new_rule['user_name'] = user_name
+        new_rule['id'] = self._push_rule_id_gen.get_next_txn(txn)

         # check if the priority before/after is free
         new_rule_priority = base_rule_priority
@@ -153,12 +151,11 @@ class PushRuleStore(SQLBaseStore):
         txn.execute(sql, (user_name, priority_class, new_rule_priority))

-        # now insert the new rule
-        sql = "INSERT INTO "+PushRuleTable.table_name+" ("
-        sql += ",".join(new_rule.keys())+") VALUES ("
-        sql += ", ".join(["?" for _ in new_rule.keys()])+")"
-
-        txn.execute(sql, new_rule.values())
+        self._simple_insert_txn(
+            txn,
+            table=PushRuleTable.table_name,
+            values=new_rule,
+        )

     def _add_push_rule_highest_priority_txn(self, txn, user_name,
                                             priority_class, **kwargs):
@@ -176,18 +173,17 @@ class PushRuleStore(SQLBaseStore):
         new_prio = highest_prio + 1

         # and insert the new rule
-        new_rule = copy.copy(kwargs)
-        if 'id' in new_rule:
-            del new_rule['id']
+        new_rule = kwargs
+        new_rule['id'] = self._push_rule_id_gen.get_next_txn(txn)
         new_rule['user_name'] = user_name
         new_rule['priority_class'] = priority_class
         new_rule['priority'] = new_prio

-        sql = "INSERT INTO "+PushRuleTable.table_name+" ("
-        sql += ",".join(new_rule.keys())+") VALUES ("
-        sql += ", ".join(["?" for _ in new_rule.keys()])+")"
-
-        txn.execute(sql, new_rule.values())
+        self._simple_insert_txn(
+            txn,
+            table=PushRuleTable.table_name,
+            values=new_rule,
+        )

     @defer.inlineCallbacks
     def delete_push_rule(self, user_name, rule_id):
@@ -211,7 +207,7 @@ class PushRuleStore(SQLBaseStore):
         yield self._simple_upsert(
             PushRuleEnableTable.table_name,
             {'user_name': user_name, 'rule_id': rule_id},
-            {'enabled': enabled},
+            {'enabled': 1 if enabled else 0},
             desc="set_push_rule_enabled",
         )
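The deleted relative-insert query bound the table name through a "?" placeholder, which DB-API drivers reject: placeholders bind values, never identifiers, so table and column names must already be part of the SQL text — which is what routing the lookup through _simple_select_one_txn arranges. A short demonstration against a toy SQLite table:

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE push_rules (user_name TEXT, priority INTEGER)")
conn.execute("INSERT INTO push_rules VALUES ('@u:hs', 5)")

try:
    # a "?" can never stand in for a table name
    conn.execute("SELECT priority FROM ? WHERE user_name = ?",
                 ("push_rules", "@u:hs"))
except sqlite3.OperationalError as e:
    print("rejected:", e)

# identifiers come from trusted constants; only values go through placeholders
TABLE = "push_rules"
row = conn.execute(
    "SELECT priority FROM %s WHERE user_name = ?" % (TABLE,),
    ("@u:hs",),
).fetchone()
print(row)  # (5,)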
diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py
index 280d4ad605..8045e17fd7 100644
--- a/synapse/storage/stream.py
+++ b/synapse/storage/stream.py
@@ -37,11 +37,9 @@
 from twisted.internet import defer

 from ._base import SQLBaseStore
 from synapse.api.constants import EventTypes
-from synapse.api.errors import SynapseError
+from synapse.types import RoomStreamToken
 from synapse.util.logutils import log_function

-from collections import namedtuple
-
 import logging
@@ -55,76 +53,26 @@
 _STREAM_TOKEN = "stream"
 _TOPOLOGICAL_TOKEN = "topological"


-class _StreamToken(namedtuple("_StreamToken", "topological stream")):
-    """Tokens are positions between events. The token "s1" comes after event 1.
-
-            s0    s1
-            |     |
-        [0] V [1] V [2]
-
-    Tokens can either be a point in the live event stream or a cursor going
-    through historic events.
-
-    When traversing the live event stream events are ordered by when they
-    arrived at the homeserver.
-
-    When traversing historic events the events are ordered by their depth in
-    the event graph "topological_ordering" and then by when they arrived at the
-    homeserver "stream_ordering".
-
-    Live tokens start with an "s" followed by the "stream_ordering" id of the
-    event it comes after. Historic tokens start with a "t" followed by the
-    "topological_ordering" id of the event it comes after, follewed by "-",
-    followed by the "stream_ordering" id of the event it comes after.
-    """
-    __slots__ = []
-
-    @classmethod
-    def parse(cls, string):
-        try:
-            if string[0] == 's':
-                return cls(topological=None, stream=int(string[1:]))
-            if string[0] == 't':
-                parts = string[1:].split('-', 1)
-                return cls(topological=int(parts[0]), stream=int(parts[1]))
-        except:
-            pass
-        raise SynapseError(400, "Invalid token %r" % (string,))
-
-    @classmethod
-    def parse_stream_token(cls, string):
-        try:
-            if string[0] == 's':
-                return cls(topological=None, stream=int(string[1:]))
-        except:
-            pass
-        raise SynapseError(400, "Invalid token %r" % (string,))
-
-    def __str__(self):
-        if self.topological is not None:
-            return "t%d-%d" % (self.topological, self.stream)
-        else:
-            return "s%d" % (self.stream,)
-
-    def lower_bound(self):
-        if self.topological is None:
-            return "(%d < %s)" % (self.stream, "stream_ordering")
-        else:
-            return "(%d < %s OR (%d = %s AND %d < %s))" % (
-                self.topological, "topological_ordering",
-                self.topological, "topological_ordering",
-                self.stream, "stream_ordering",
-            )
-
-    def upper_bound(self):
-        if self.topological is None:
-            return "(%d >= %s)" % (self.stream, "stream_ordering")
-        else:
-            return "(%d > %s OR (%d = %s AND %d >= %s))" % (
-                self.topological, "topological_ordering",
-                self.topological, "topological_ordering",
-                self.stream, "stream_ordering",
-            )
+def lower_bound(token):
+    if token.topological is None:
+        return "(%d < %s)" % (token.stream, "stream_ordering")
+    else:
+        return "(%d < %s OR (%d = %s AND %d < %s))" % (
+            token.topological, "topological_ordering",
+            token.topological, "topological_ordering",
+            token.stream, "stream_ordering",
+        )
+
+
+def upper_bound(token):
+    if token.topological is None:
+        return "(%d >= %s)" % (token.stream, "stream_ordering")
+    else:
+        return "(%d > %s OR (%d = %s AND %d >= %s))" % (
+            token.topological, "topological_ordering",
+            token.topological, "topological_ordering",
+            token.stream, "stream_ordering",
+        )


 class StreamStore(SQLBaseStore):
@@ -139,8 +87,8 @@ class StreamStore(SQLBaseStore):
             limit = MAX_STREAM_SIZE

         # From and to keys should be integers from ordering.
-        from_id = _StreamToken.parse_stream_token(from_key)
-        to_id = _StreamToken.parse_stream_token(to_key)
+        from_id = RoomStreamToken.parse_stream_token(from_key)
+        to_id = RoomStreamToken.parse_stream_token(to_key)

         if from_key == to_key:
             defer.returnValue(([], to_key))
@@ -234,8 +182,8 @@ class StreamStore(SQLBaseStore):
             limit = MAX_STREAM_SIZE

         # From and to keys should be integers from ordering.
-        from_id = _StreamToken.parse_stream_token(from_key)
-        to_id = _StreamToken.parse_stream_token(to_key)
+        from_id = RoomStreamToken.parse_stream_token(from_key)
+        to_id = RoomStreamToken.parse_stream_token(to_key)

         if from_key == to_key:
             return defer.succeed(([], to_key))
@@ -288,17 +236,17 @@ class StreamStore(SQLBaseStore):
         args = [False, room_id]
         if direction == 'b':
             order = "DESC"
-            bounds = _StreamToken.parse(from_key).upper_bound()
+            bounds = upper_bound(RoomStreamToken.parse(from_key))
             if to_key:
                 bounds = "%s AND %s" % (
-                    bounds, _StreamToken.parse(to_key).lower_bound()
+                    bounds, lower_bound(RoomStreamToken.parse(to_key))
                 )
         else:
             order = "ASC"
-            bounds = _StreamToken.parse(from_key).lower_bound()
+            bounds = lower_bound(RoomStreamToken.parse(from_key))
             if to_key:
                 bounds = "%s AND %s" % (
-                    bounds, _StreamToken.parse(to_key).upper_bound()
+                    bounds, upper_bound(RoomStreamToken.parse(to_key))
                 )

         if int(limit) > 0:
@@ -333,7 +281,7 @@ class StreamStore(SQLBaseStore):
                 # when we are going backwards so we subtract one from the
                 # stream part.
                 toke -= 1
-            next_token = str(_StreamToken(topo, toke))
+            next_token = str(RoomStreamToken(topo, toke))
         else:
             # TODO (erikj): We should work out what to do here instead.
             next_token = to_key if to_key else from_key
@@ -354,7 +302,7 @@ class StreamStore(SQLBaseStore):
                            with_feedback=False, from_token=None):
         # TODO (erikj): Handle compressed feedback

-        end_token = _StreamToken.parse_stream_token(end_token)
+        end_token = RoomStreamToken.parse_stream_token(end_token)

         if from_token is None:
             sql = (
@@ -365,7 +313,7 @@ class StreamStore(SQLBaseStore):
                 " LIMIT ?"
             )
         else:
-            from_token = _StreamToken.parse_stream_token(from_token)
+            from_token = RoomStreamToken.parse_stream_token(from_token)
             sql = (
                 "SELECT stream_ordering, topological_ordering, event_id"
                 " FROM events"
@@ -395,7 +343,7 @@ class StreamStore(SQLBaseStore):
             # stream part.
             topo = rows[0]["topological_ordering"]
             toke = rows[0]["stream_ordering"] - 1
-            start_token = str(_StreamToken(topo, toke))
+            start_token = str(RoomStreamToken(topo, toke))

             token = (start_token, str(end_token))
         else:
@@ -416,9 +364,25 @@ class StreamStore(SQLBaseStore):
         )

     @defer.inlineCallbacks
-    def get_room_events_max_id(self):
+    def get_room_events_max_id(self, direction='f'):
         token = yield self._stream_id_gen.get_max_token(self)
-        defer.returnValue("s%d" % (token,))
+        if direction != 'b':
+            defer.returnValue("s%d" % (token,))
+        else:
+            topo = yield self.runInteraction(
+                "_get_max_topological_txn", self._get_max_topological_txn
+            )
+            defer.returnValue("t%d-%d" % (topo, token))
+
+    def _get_max_topological_txn(self, txn):
+        txn.execute(
+            "SELECT MAX(topological_ordering) FROM events"
+            " WHERE outlier = ?",
+            (False,)
+        )
+
+        rows = txn.fetchall()
+        return rows[0][0] if rows else 0

     @defer.inlineCallbacks
     def _get_min_token(self):
@@ -439,5 +403,5 @@ class StreamStore(SQLBaseStore):
         stream = row["stream_ordering"]
         topo = event.depth
         internal = event.internal_metadata
-        internal.before = str(_StreamToken(topo, stream - 1))
-        internal.after = str(_StreamToken(topo, stream))
+        internal.before = str(RoomStreamToken(topo, stream - 1))
+        internal.after = str(RoomStreamToken(topo, stream))
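The bound helpers operate on the token grammar documented in the removed docstring: live tokens are "s<stream_ordering>" and historic tokens are "t<topological_ordering>-<stream_ordering>". A condensed standalone sketch of parsing plus lower_bound (the real RoomStreamToken now lives in synapse.types; this Token class is only illustrative):

from collections import namedtuple

class Token(namedtuple("Token", "topological stream")):
    @classmethod
    def parse(cls, s):
        # "s42" -> live token; "t5-42" -> historic token
        if s and s[0] == "s":
            return cls(topological=None, stream=int(s[1:]))
        if s and s[0] == "t":
            topo, _, stream = s[1:].partition("-")
            return cls(topological=int(topo), stream=int(stream))
        raise ValueError("Invalid token %r" % (s,))

def lower_bound(token):
    # plain comparison for live tokens; topological comparison with a
    # stream tie-break for historic ones, mirroring the helper above
    if token.topological is None:
        return "(%d < stream_ordering)" % (token.stream,)
    return (
        "(%d < topological_ordering"
        " OR (%d = topological_ordering AND %d < stream_ordering))"
    ) % (token.topological, token.topological, token.stream)

print(lower_bound(Token.parse("s42")))
print(lower_bound(Token.parse("t5-42")))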
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index e40eb8a8c4..89d1643f10 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -78,14 +78,18 @@ class StreamIdGenerator(object):
         self._current_max = None
         self._unfinished_ids = deque()

-    def get_next_txn(self, txn):
+    @defer.inlineCallbacks
+    def get_next(self, store):
         """
         Usage:
-            with stream_id_gen.get_next_txn(txn) as stream_id:
+            with yield stream_id_gen.get_next as stream_id:
                 # ... persist event ...
         """
         if not self._current_max:
-            self._get_or_compute_current_max(txn)
+            yield store.runInteraction(
+                "_compute_current_max",
+                self._get_or_compute_current_max,
+            )

         with self._lock:
             self._current_max += 1
@@ -101,7 +105,7 @@ class StreamIdGenerator(object):
             with self._lock:
                 self._unfinished_ids.remove(next_id)

-        return manager()
+        defer.returnValue(manager())

     @defer.inlineCallbacks
     def get_max_token(self, store):
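Read together with the events.py change, get_next now returns a Deferred that fires with a context manager, so (despite the docstring's shorthand) a caller in an inlineCallbacks function would do roughly `manager = yield stream_id_gen.get_next(store)` followed by `with manager as stream_id:`. A synchronous sketch of the generator's bookkeeping, with the Twisted machinery stripped out (illustrative, not Synapse's exact code):

import threading
from collections import deque
from contextlib import contextmanager

class StreamIdGenerator(object):
    def __init__(self, current_max=0):
        self._lock = threading.Lock()
        self._current_max = current_max
        self._unfinished_ids = deque()  # ids handed out but not yet persisted

    def get_next(self):
        # allocate the next id under the lock ...
        with self._lock:
            self._current_max += 1
            next_id = self._current_max
            self._unfinished_ids.append(next_id)

        # ... and hand back a context manager that marks the id finished
        # on exit, even if persisting raised
        @contextmanager
        def manager():
            try:
                yield next_id
            finally:
                with self._lock:
                    self._unfinished_ids.remove(next_id)

        return manager()

gen = StreamIdGenerator()
with gen.get_next() as stream_id:
    print("persisting under", stream_id)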