diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index ee5587c721..c9fe5a3555 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -18,7 +18,7 @@ from synapse.api.errors import StoreError
from synapse.events import FrozenEvent
from synapse.events.utils import prune_event
from synapse.util.logutils import log_function
-from synapse.util.logcontext import PreserveLoggingContext, LoggingContext
+from synapse.util.logcontext import preserve_context_over_fn, LoggingContext
from synapse.util.lrucache import LruCache
import synapse.metrics
@@ -308,6 +308,7 @@ class SQLBaseStore(object):
self._state_groups_id_gen = IdGenerator("state_groups", "id", self)
self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self)
self._pushers_id_gen = IdGenerator("pushers", "id", self)
+ self._push_rule_id_gen = IdGenerator("push_rules", "id", self)
def start_profiling(self):
self._previous_loop_ts = self._clock.time_msec()
@@ -419,10 +420,11 @@ class SQLBaseStore(object):
self._txn_perf_counters.update(desc, start, end)
sql_txn_timer.inc_by(duration, desc)
- with PreserveLoggingContext():
- result = yield self._db_pool.runWithConnection(
- inner_func, *args, **kwargs
- )
+ result = yield preserve_context_over_fn(
+ self._db_pool.runWithConnection,
+ inner_func, *args, **kwargs
+ )
+
for after_callback, after_args in after_callbacks:
after_callback(*after_args)
defer.returnValue(result)
diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py
index 74b4e23590..a1982dfbb5 100644
--- a/synapse/storage/event_federation.py
+++ b/synapse/storage/event_federation.py
@@ -79,6 +79,28 @@ class EventFederationStore(SQLBaseStore):
room_id,
)
+ def get_oldest_events_with_depth_in_room(self, room_id):
+ return self.runInteraction(
+ "get_oldest_events_with_depth_in_room",
+ self.get_oldest_events_with_depth_in_room_txn,
+ room_id,
+ )
+
+ def get_oldest_events_with_depth_in_room_txn(self, txn, room_id):
+ sql = (
+ "SELECT b.event_id, MAX(e.depth) FROM events as e"
+ " INNER JOIN event_edges as g"
+ " ON g.event_id = e.event_id AND g.room_id = e.room_id"
+ " INNER JOIN event_backward_extremities as b"
+ " ON g.prev_event_id = b.event_id AND g.room_id = b.room_id"
+ " WHERE b.room_id = ? AND g.is_state is ?"
+ " GROUP BY b.event_id"
+ )
+
+ txn.execute(sql, (room_id, False,))
+
+ return dict(txn.fetchall())
+
def _get_oldest_events_in_room_txn(self, txn, room_id):
return self._simple_select_onecol_txn(
txn,
@@ -247,11 +269,13 @@ class EventFederationStore(SQLBaseStore):
do_insert = depth < min_depth if min_depth else True
if do_insert:
- self._simple_insert_txn(
+ self._simple_upsert_txn(
txn,
table="room_depth",
- values={
+ keyvalues={
"room_id": room_id,
+ },
+ values={
"min_depth": depth,
},
)
@@ -306,31 +330,27 @@ class EventFederationStore(SQLBaseStore):
txn.execute(query, (event_id, room_id))
- # Insert all the prev_events as a backwards thing, they'll get
- # deleted in a second if they're incorrect anyway.
- self._simple_insert_many_txn(
- txn,
- table="event_backward_extremities",
- values=[
- {
- "event_id": e_id,
- "room_id": room_id,
- }
- for e_id, _ in prev_events
- ],
+ query = (
+ "INSERT INTO event_backward_extremities (event_id, room_id)"
+ " SELECT ?, ? WHERE NOT EXISTS ("
+ " SELECT 1 FROM event_backward_extremities"
+ " WHERE event_id = ? AND room_id = ?"
+ " )"
+ " AND NOT EXISTS ("
+ " SELECT 1 FROM events WHERE event_id = ? AND room_id = ?"
+ " )"
)
- # Also delete from the backwards extremities table all ones that
- # reference events that we have already seen
+ txn.executemany(query, [
+ (e_id, room_id, e_id, room_id, e_id, room_id, )
+ for e_id, _ in prev_events
+ ])
+
query = (
- "DELETE FROM event_backward_extremities WHERE EXISTS ("
- "SELECT 1 FROM events "
- "WHERE "
- "event_backward_extremities.event_id = events.event_id "
- "AND not events.outlier "
- ")"
+ "DELETE FROM event_backward_extremities"
+ " WHERE event_id = ? AND room_id = ?"
)
- txn.execute(query)
+ txn.execute(query, (event_id, room_id))
txn.call_after(
self.get_latest_event_ids_in_room.invalidate, room_id
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 38395c66ab..a5a6869079 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -23,6 +23,7 @@ from synapse.crypto.event_signing import compute_event_reference_hash
from syutil.base64util import decode_base64
from syutil.jsonutil import encode_canonical_json
+from contextlib import contextmanager
import logging
@@ -41,17 +42,25 @@ class EventsStore(SQLBaseStore):
self.min_token -= 1
stream_ordering = self.min_token
+ if stream_ordering is None:
+ stream_ordering_manager = yield self._stream_id_gen.get_next(self)
+ else:
+ @contextmanager
+ def stream_ordering_manager():
+ yield stream_ordering
+
try:
- yield self.runInteraction(
- "persist_event",
- self._persist_event_txn,
- event=event,
- context=context,
- backfilled=backfilled,
- stream_ordering=stream_ordering,
- is_new_state=is_new_state,
- current_state=current_state,
- )
+ with stream_ordering_manager as stream_ordering:
+ yield self.runInteraction(
+ "persist_event",
+ self._persist_event_txn,
+ event=event,
+ context=context,
+ backfilled=backfilled,
+ stream_ordering=stream_ordering,
+ is_new_state=is_new_state,
+ current_state=current_state,
+ )
except _RollbackButIsFineException:
pass
@@ -95,15 +104,6 @@ class EventsStore(SQLBaseStore):
# Remove the any existing cache entries for the event_id
txn.call_after(self._invalidate_get_event_cache, event.event_id)
- if stream_ordering is None:
- with self._stream_id_gen.get_next_txn(txn) as stream_ordering:
- return self._persist_event_txn(
- txn, event, context, backfilled,
- stream_ordering=stream_ordering,
- is_new_state=is_new_state,
- current_state=current_state,
- )
-
# We purposefully do this first since if we include a `current_state`
# key, we *want* to update the `current_state_events` table
if current_state:
@@ -135,19 +135,17 @@ class EventsStore(SQLBaseStore):
outlier = event.internal_metadata.is_outlier()
if not outlier:
- self._store_state_groups_txn(txn, event, context)
-
self._update_min_depth_for_room_txn(
txn,
event.room_id,
event.depth
)
- have_persisted = self._simple_select_one_onecol_txn(
+ have_persisted = self._simple_select_one_txn(
txn,
- table="event_json",
+ table="events",
keyvalues={"event_id": event.event_id},
- retcol="event_id",
+ retcols=["event_id", "outlier"],
allow_none=True,
)
@@ -162,7 +160,9 @@ class EventsStore(SQLBaseStore):
# if we are persisting an event that we had persisted as an outlier,
# but is no longer one.
if have_persisted:
- if not outlier:
+ if not outlier and have_persisted["outlier"]:
+ self._store_state_groups_txn(txn, event, context)
+
sql = (
"UPDATE event_json SET internal_metadata = ?"
" WHERE event_id = ?"
@@ -182,6 +182,9 @@ class EventsStore(SQLBaseStore):
)
return
+ if not outlier:
+ self._store_state_groups_txn(txn, event, context)
+
self._handle_prev_events(
txn,
outlier=outlier,
diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py
index ee7718d5ed..34805e276e 100644
--- a/synapse/storage/push_rule.py
+++ b/synapse/storage/push_rule.py
@@ -19,7 +19,6 @@ from ._base import SQLBaseStore, Table
from twisted.internet import defer
import logging
-import copy
import simplejson as json
logger = logging.getLogger(__name__)
@@ -28,46 +27,45 @@ logger = logging.getLogger(__name__)
class PushRuleStore(SQLBaseStore):
@defer.inlineCallbacks
def get_push_rules_for_user(self, user_name):
- sql = (
- "SELECT "+",".join(PushRuleTable.fields)+" "
- "FROM "+PushRuleTable.table_name+" "
- "WHERE user_name = ? "
- "ORDER BY priority_class DESC, priority DESC"
+ rows = yield self._simple_select_list(
+ table=PushRuleTable.table_name,
+ keyvalues={
+ "user_name": user_name,
+ },
+ retcols=PushRuleTable.fields,
)
- rows = yield self._execute("get_push_rules_for_user", None, sql, user_name)
- dicts = []
- for r in rows:
- d = {}
- for i, f in enumerate(PushRuleTable.fields):
- d[f] = r[i]
- dicts.append(d)
+ rows.sort(
+ key=lambda row: (-int(row["priority_class"]), -int(row["priority"]))
+ )
- defer.returnValue(dicts)
+ defer.returnValue(rows)
@defer.inlineCallbacks
def get_push_rules_enabled_for_user(self, user_name):
results = yield self._simple_select_list(
- PushRuleEnableTable.table_name,
- {'user_name': user_name},
- PushRuleEnableTable.fields,
+ table=PushRuleEnableTable.table_name,
+ keyvalues={
+ 'user_name': user_name
+ },
+ retcols=PushRuleEnableTable.fields,
desc="get_push_rules_enabled_for_user",
)
- defer.returnValue(
- {r['rule_id']: False if r['enabled'] == 0 else True for r in results}
- )
+ defer.returnValue({
+ r['rule_id']: False if r['enabled'] == 0 else True for r in results
+ })
@defer.inlineCallbacks
def add_push_rule(self, before, after, **kwargs):
- vals = copy.copy(kwargs)
+ vals = kwargs
if 'conditions' in vals:
vals['conditions'] = json.dumps(vals['conditions'])
if 'actions' in vals:
vals['actions'] = json.dumps(vals['actions'])
+
# we could check the rest of the keys are valid column names
# but sqlite will do that anyway so I think it's just pointless.
- if 'id' in vals:
- del vals['id']
+ vals.pop("id", None)
if before or after:
ret = yield self.runInteraction(
@@ -87,39 +85,39 @@ class PushRuleStore(SQLBaseStore):
defer.returnValue(ret)
def _add_push_rule_relative_txn(self, txn, user_name, **kwargs):
- after = None
- relative_to_rule = None
- if 'after' in kwargs and kwargs['after']:
- after = kwargs['after']
- relative_to_rule = after
- if 'before' in kwargs and kwargs['before']:
- relative_to_rule = kwargs['before']
-
- # get the priority of the rule we're inserting after/before
- sql = (
- "SELECT priority_class, priority FROM ? "
- "WHERE user_name = ? and rule_id = ?" % (PushRuleTable.table_name,)
+ after = kwargs.pop("after", None)
+ relative_to_rule = kwargs.pop("before", after)
+
+ res = self._simple_select_one_txn(
+ txn,
+ table=PushRuleTable.table_name,
+ keyvalues={
+ "user_name": user_name,
+ "rule_id": relative_to_rule,
+ },
+ retcols=["priority_class", "priority"],
+ allow_none=True,
)
- txn.execute(sql, (user_name, relative_to_rule))
- res = txn.fetchall()
+
if not res:
raise RuleNotFoundException(
"before/after rule not found: %s" % (relative_to_rule,)
)
- priority_class, base_rule_priority = res[0]
+
+ priority_class = res["priority_class"]
+ base_rule_priority = res["priority"]
if 'priority_class' in kwargs and kwargs['priority_class'] != priority_class:
raise InconsistentRuleException(
"Given priority class does not match class of relative rule"
)
- new_rule = copy.copy(kwargs)
- if 'before' in new_rule:
- del new_rule['before']
- if 'after' in new_rule:
- del new_rule['after']
+ new_rule = kwargs
+ new_rule.pop("before", None)
+ new_rule.pop("after", None)
new_rule['priority_class'] = priority_class
new_rule['user_name'] = user_name
+ new_rule['id'] = self._push_rule_id_gen.get_next_txn(txn)
# check if the priority before/after is free
new_rule_priority = base_rule_priority
@@ -153,12 +151,11 @@ class PushRuleStore(SQLBaseStore):
txn.execute(sql, (user_name, priority_class, new_rule_priority))
- # now insert the new rule
- sql = "INSERT INTO "+PushRuleTable.table_name+" ("
- sql += ",".join(new_rule.keys())+") VALUES ("
- sql += ", ".join(["?" for _ in new_rule.keys()])+")"
-
- txn.execute(sql, new_rule.values())
+ self._simple_insert_txn(
+ txn,
+ table=PushRuleTable.table_name,
+ values=new_rule,
+ )
def _add_push_rule_highest_priority_txn(self, txn, user_name,
priority_class, **kwargs):
@@ -176,18 +173,17 @@ class PushRuleStore(SQLBaseStore):
new_prio = highest_prio + 1
# and insert the new rule
- new_rule = copy.copy(kwargs)
- if 'id' in new_rule:
- del new_rule['id']
+ new_rule = kwargs
+ new_rule['id'] = self._push_rule_id_gen.get_next_txn(txn)
new_rule['user_name'] = user_name
new_rule['priority_class'] = priority_class
new_rule['priority'] = new_prio
- sql = "INSERT INTO "+PushRuleTable.table_name+" ("
- sql += ",".join(new_rule.keys())+") VALUES ("
- sql += ", ".join(["?" for _ in new_rule.keys()])+")"
-
- txn.execute(sql, new_rule.values())
+ self._simple_insert_txn(
+ txn,
+ table=PushRuleTable.table_name,
+ values=new_rule,
+ )
@defer.inlineCallbacks
def delete_push_rule(self, user_name, rule_id):
@@ -211,7 +207,7 @@ class PushRuleStore(SQLBaseStore):
yield self._simple_upsert(
PushRuleEnableTable.table_name,
{'user_name': user_name, 'rule_id': rule_id},
- {'enabled': enabled},
+ {'enabled': 1 if enabled else 0},
desc="set_push_rule_enabled",
)
diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py
index 280d4ad605..8045e17fd7 100644
--- a/synapse/storage/stream.py
+++ b/synapse/storage/stream.py
@@ -37,11 +37,9 @@ from twisted.internet import defer
from ._base import SQLBaseStore
from synapse.api.constants import EventTypes
-from synapse.api.errors import SynapseError
+from synapse.types import RoomStreamToken
from synapse.util.logutils import log_function
-from collections import namedtuple
-
import logging
@@ -55,76 +53,26 @@ _STREAM_TOKEN = "stream"
_TOPOLOGICAL_TOKEN = "topological"
-class _StreamToken(namedtuple("_StreamToken", "topological stream")):
- """Tokens are positions between events. The token "s1" comes after event 1.
-
- s0 s1
- | |
- [0] V [1] V [2]
-
- Tokens can either be a point in the live event stream or a cursor going
- through historic events.
-
- When traversing the live event stream events are ordered by when they
- arrived at the homeserver.
-
- When traversing historic events the events are ordered by their depth in
- the event graph "topological_ordering" and then by when they arrived at the
- homeserver "stream_ordering".
-
- Live tokens start with an "s" followed by the "stream_ordering" id of the
- event it comes after. Historic tokens start with a "t" followed by the
- "topological_ordering" id of the event it comes after, follewed by "-",
- followed by the "stream_ordering" id of the event it comes after.
- """
- __slots__ = []
-
- @classmethod
- def parse(cls, string):
- try:
- if string[0] == 's':
- return cls(topological=None, stream=int(string[1:]))
- if string[0] == 't':
- parts = string[1:].split('-', 1)
- return cls(topological=int(parts[0]), stream=int(parts[1]))
- except:
- pass
- raise SynapseError(400, "Invalid token %r" % (string,))
-
- @classmethod
- def parse_stream_token(cls, string):
- try:
- if string[0] == 's':
- return cls(topological=None, stream=int(string[1:]))
- except:
- pass
- raise SynapseError(400, "Invalid token %r" % (string,))
-
- def __str__(self):
- if self.topological is not None:
- return "t%d-%d" % (self.topological, self.stream)
- else:
- return "s%d" % (self.stream,)
+def lower_bound(token):
+ if token.topological is None:
+ return "(%d < %s)" % (token.stream, "stream_ordering")
+ else:
+ return "(%d < %s OR (%d = %s AND %d < %s))" % (
+ token.topological, "topological_ordering",
+ token.topological, "topological_ordering",
+ token.stream, "stream_ordering",
+ )
- def lower_bound(self):
- if self.topological is None:
- return "(%d < %s)" % (self.stream, "stream_ordering")
- else:
- return "(%d < %s OR (%d = %s AND %d < %s))" % (
- self.topological, "topological_ordering",
- self.topological, "topological_ordering",
- self.stream, "stream_ordering",
- )
- def upper_bound(self):
- if self.topological is None:
- return "(%d >= %s)" % (self.stream, "stream_ordering")
- else:
- return "(%d > %s OR (%d = %s AND %d >= %s))" % (
- self.topological, "topological_ordering",
- self.topological, "topological_ordering",
- self.stream, "stream_ordering",
- )
+def upper_bound(token):
+ if token.topological is None:
+ return "(%d >= %s)" % (token.stream, "stream_ordering")
+ else:
+ return "(%d > %s OR (%d = %s AND %d >= %s))" % (
+ token.topological, "topological_ordering",
+ token.topological, "topological_ordering",
+ token.stream, "stream_ordering",
+ )
class StreamStore(SQLBaseStore):
@@ -139,8 +87,8 @@ class StreamStore(SQLBaseStore):
limit = MAX_STREAM_SIZE
# From and to keys should be integers from ordering.
- from_id = _StreamToken.parse_stream_token(from_key)
- to_id = _StreamToken.parse_stream_token(to_key)
+ from_id = RoomStreamToken.parse_stream_token(from_key)
+ to_id = RoomStreamToken.parse_stream_token(to_key)
if from_key == to_key:
defer.returnValue(([], to_key))
@@ -234,8 +182,8 @@ class StreamStore(SQLBaseStore):
limit = MAX_STREAM_SIZE
# From and to keys should be integers from ordering.
- from_id = _StreamToken.parse_stream_token(from_key)
- to_id = _StreamToken.parse_stream_token(to_key)
+ from_id = RoomStreamToken.parse_stream_token(from_key)
+ to_id = RoomStreamToken.parse_stream_token(to_key)
if from_key == to_key:
return defer.succeed(([], to_key))
@@ -288,17 +236,17 @@ class StreamStore(SQLBaseStore):
args = [False, room_id]
if direction == 'b':
order = "DESC"
- bounds = _StreamToken.parse(from_key).upper_bound()
+ bounds = upper_bound(RoomStreamToken.parse(from_key))
if to_key:
bounds = "%s AND %s" % (
- bounds, _StreamToken.parse(to_key).lower_bound()
+ bounds, lower_bound(RoomStreamToken.parse(to_key))
)
else:
order = "ASC"
- bounds = _StreamToken.parse(from_key).lower_bound()
+ bounds = lower_bound(RoomStreamToken.parse(from_key))
if to_key:
bounds = "%s AND %s" % (
- bounds, _StreamToken.parse(to_key).upper_bound()
+ bounds, upper_bound(RoomStreamToken.parse(to_key))
)
if int(limit) > 0:
@@ -333,7 +281,7 @@ class StreamStore(SQLBaseStore):
# when we are going backwards so we subtract one from the
# stream part.
toke -= 1
- next_token = str(_StreamToken(topo, toke))
+ next_token = str(RoomStreamToken(topo, toke))
else:
# TODO (erikj): We should work out what to do here instead.
next_token = to_key if to_key else from_key
@@ -354,7 +302,7 @@ class StreamStore(SQLBaseStore):
with_feedback=False, from_token=None):
# TODO (erikj): Handle compressed feedback
- end_token = _StreamToken.parse_stream_token(end_token)
+ end_token = RoomStreamToken.parse_stream_token(end_token)
if from_token is None:
sql = (
@@ -365,7 +313,7 @@ class StreamStore(SQLBaseStore):
" LIMIT ?"
)
else:
- from_token = _StreamToken.parse_stream_token(from_token)
+ from_token = RoomStreamToken.parse_stream_token(from_token)
sql = (
"SELECT stream_ordering, topological_ordering, event_id"
" FROM events"
@@ -395,7 +343,7 @@ class StreamStore(SQLBaseStore):
# stream part.
topo = rows[0]["topological_ordering"]
toke = rows[0]["stream_ordering"] - 1
- start_token = str(_StreamToken(topo, toke))
+ start_token = str(RoomStreamToken(topo, toke))
token = (start_token, str(end_token))
else:
@@ -416,9 +364,25 @@ class StreamStore(SQLBaseStore):
)
@defer.inlineCallbacks
- def get_room_events_max_id(self):
+ def get_room_events_max_id(self, direction='f'):
token = yield self._stream_id_gen.get_max_token(self)
- defer.returnValue("s%d" % (token,))
+ if direction != 'b':
+ defer.returnValue("s%d" % (token,))
+ else:
+ topo = yield self.runInteraction(
+ "_get_max_topological_txn", self._get_max_topological_txn
+ )
+ defer.returnValue("t%d-%d" % (topo, token))
+
+ def _get_max_topological_txn(self, txn):
+ txn.execute(
+ "SELECT MAX(topological_ordering) FROM events"
+ " WHERE outlier = ?",
+ (False,)
+ )
+
+ rows = txn.fetchall()
+ return rows[0][0] if rows else 0
@defer.inlineCallbacks
def _get_min_token(self):
@@ -439,5 +403,5 @@ class StreamStore(SQLBaseStore):
stream = row["stream_ordering"]
topo = event.depth
internal = event.internal_metadata
- internal.before = str(_StreamToken(topo, stream - 1))
- internal.after = str(_StreamToken(topo, stream))
+ internal.before = str(RoomStreamToken(topo, stream - 1))
+ internal.after = str(RoomStreamToken(topo, stream))
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index e40eb8a8c4..89d1643f10 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -78,14 +78,18 @@ class StreamIdGenerator(object):
self._current_max = None
self._unfinished_ids = deque()
- def get_next_txn(self, txn):
+ @defer.inlineCallbacks
+ def get_next(self, store):
"""
Usage:
- with stream_id_gen.get_next_txn(txn) as stream_id:
+ with yield stream_id_gen.get_next as stream_id:
# ... persist event ...
"""
if not self._current_max:
- self._get_or_compute_current_max(txn)
+ yield store.runInteraction(
+ "_compute_current_max",
+ self._get_or_compute_current_max,
+ )
with self._lock:
self._current_max += 1
@@ -101,7 +105,7 @@ class StreamIdGenerator(object):
with self._lock:
self._unfinished_ids.remove(next_id)
- return manager()
+ defer.returnValue(manager())
@defer.inlineCallbacks
def get_max_token(self, store):
|