diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index a7a8ec9b7b..c659004e8d 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -73,6 +73,9 @@ class LoggingTransaction(object):
def __setattr__(self, name, value):
setattr(self.txn, name, value)
+ def __iter__(self):
+ return self.txn.__iter__()
+
def execute(self, sql, *args):
self._do_execute(self.txn.execute, sql, *args)
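The `__iter__` delegation above is what allows the many `txn.fetchall()` call sites later in this diff to become `for row in txn`: iterating a DB-API cursor streams rows as they are fetched instead of materialising the whole result set as a list first. A minimal sketch of the two patterns, using the standard-library sqlite3 module purely for illustration:

    import sqlite3

    conn = sqlite3.connect(":memory:")
    txn = conn.cursor()
    txn.execute("CREATE TABLE t (a INTEGER, b TEXT)")
    txn.executemany("INSERT INTO t VALUES (?, ?)", [(1, "x"), (2, "y")])

    txn.execute("SELECT a, b FROM t")
    rows = txn.fetchall()   # builds an intermediate list of every row

    txn.execute("SELECT a, b FROM t")
    for a, b in txn:        # streams rows straight from the cursor
        print("%s %s" % (a, b))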
@@ -132,7 +135,7 @@ class PerformanceCounters(object):
def interval(self, interval_duration, limit=3):
counters = []
- for name, (count, cum_time) in self.current_counters.items():
+ for name, (count, cum_time) in self.current_counters.iteritems():
prev_count, prev_time = self.previous_counters.get(name, (0, 0))
counters.append((
(cum_time - prev_time) / interval_duration,
@@ -357,7 +360,7 @@ class SQLBaseStore(object):
"""
col_headers = list(intern(column[0]) for column in cursor.description)
results = list(
- dict(zip(col_headers, row)) for row in cursor.fetchall()
+ dict(zip(col_headers, row)) for row in cursor
)
return results
@@ -565,7 +568,7 @@ class SQLBaseStore(object):
@staticmethod
def _simple_select_onecol_txn(txn, table, keyvalues, retcol):
if keyvalues:
- where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.keys())
+ where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.iterkeys())
else:
where = ""
@@ -579,7 +582,7 @@ class SQLBaseStore(object):
txn.execute(sql, keyvalues.values())
- return [r[0] for r in txn.fetchall()]
+ return [r[0] for r in txn]
def _simple_select_onecol(self, table, keyvalues, retcol,
desc="_simple_select_onecol"):
@@ -712,7 +715,7 @@ class SQLBaseStore(object):
)
values.extend(iterable)
- for key, value in keyvalues.items():
+ for key, value in keyvalues.iteritems():
clauses.append("%s = ?" % (key,))
values.append(value)
@@ -753,7 +756,7 @@ class SQLBaseStore(object):
@staticmethod
def _simple_update_one_txn(txn, table, keyvalues, updatevalues):
if keyvalues:
- where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.keys())
+ where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.iterkeys())
else:
where = ""
@@ -840,6 +843,47 @@ class SQLBaseStore(object):
return txn.execute(sql, keyvalues.values())
+ def _simple_delete_many(self, table, column, iterable, keyvalues, desc):
+ return self.runInteraction(
+ desc, self._simple_delete_many_txn, table, column, iterable, keyvalues
+ )
+
+ @staticmethod
+ def _simple_delete_many_txn(txn, table, column, iterable, keyvalues):
+ """Executes a DELETE query on the named table.
+
+ Filters rows by whether the value of `column` is in `iterable`.
+
+ Args:
+ txn : Transaction object
+ table : string giving the table name
+ column : column name to test for inclusion against `iterable`
+ iterable : list
+ keyvalues : dict of column names and values to select the rows with
+ """
+ if not iterable:
+ return
+
+ sql = "DELETE FROM %s" % table
+
+ clauses = []
+ values = []
+ clauses.append(
+ "%s IN (%s)" % (column, ",".join("?" for _ in iterable))
+ )
+ values.extend(iterable)
+
+ for key, value in keyvalues.iteritems():
+ clauses.append("%s = ?" % (key,))
+ values.append(value)
+
+ if clauses:
+ sql = "%s WHERE %s" % (
+ sql,
+ " AND ".join(clauses),
+ )
+ return txn.execute(sql, values)
+
def _get_cache_dict(self, db_conn, table, entity_column, stream_column,
max_value, limit=100000):
# Fetch a mapping of room_id -> max stream position for "recent" rooms.
@@ -860,16 +904,16 @@ class SQLBaseStore(object):
txn = db_conn.cursor()
txn.execute(sql, (int(max_value),))
- rows = txn.fetchall()
- txn.close()
cache = {
row[0]: int(row[1])
- for row in rows
+ for row in txn
}
+ txn.close()
+
if cache:
- min_val = min(cache.values())
+ min_val = min(cache.itervalues())
else:
min_val = max_value
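To make the new helper concrete, this is the statement `_simple_delete_many_txn` above builds for a hypothetical call; the table, column and values here are illustrative, mirroring the clause-building logic line for line:

    table, column = "devices", "device_id"          # hypothetical values
    iterable = ["D1", "D2"]
    keyvalues = {"user_id": "@alice:example.com"}

    clauses = ["%s IN (%s)" % (column, ",".join("?" for _ in iterable))]
    values = list(iterable)
    for key, value in keyvalues.iteritems():
        clauses.append("%s = ?" % (key,))
        values.append(value)

    sql = "DELETE FROM %s WHERE %s" % (table, " AND ".join(clauses))
    # sql    == "DELETE FROM devices WHERE device_id IN (?,?) AND user_id = ?"
    # values == ["D1", "D2", "@alice:example.com"]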
diff --git a/synapse/storage/account_data.py b/synapse/storage/account_data.py
index 3fa226e92d..aa84ffc2b0 100644
--- a/synapse/storage/account_data.py
+++ b/synapse/storage/account_data.py
@@ -182,7 +182,7 @@ class AccountDataStore(SQLBaseStore):
txn.execute(sql, (user_id, stream_id))
global_account_data = {
- row[0]: json.loads(row[1]) for row in txn.fetchall()
+ row[0]: json.loads(row[1]) for row in txn
}
sql = (
@@ -193,7 +193,7 @@ class AccountDataStore(SQLBaseStore):
txn.execute(sql, (user_id, stream_id))
account_data_by_room = {}
- for row in txn.fetchall():
+ for row in txn:
room_account_data = account_data_by_room.setdefault(row[0], {})
room_account_data[row[1]] = json.loads(row[2])
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index 94b2bcc54a..813ad59e56 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import synapse.util.async
from ._base import SQLBaseStore
from . import engines
@@ -84,24 +85,14 @@ class BackgroundUpdateStore(SQLBaseStore):
self._background_update_performance = {}
self._background_update_queue = []
self._background_update_handlers = {}
- self._background_update_timer = None
@defer.inlineCallbacks
def start_doing_background_updates(self):
- assert self._background_update_timer is None, \
- "background updates already running"
-
logger.info("Starting background schema updates")
while True:
- sleep = defer.Deferred()
- self._background_update_timer = self._clock.call_later(
- self.BACKGROUND_UPDATE_INTERVAL_MS / 1000., sleep.callback, None
- )
- try:
- yield sleep
- finally:
- self._background_update_timer = None
+ yield synapse.util.async.sleep(
+ self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.)
try:
result = yield self.do_next_background_update(
diff --git a/synapse/storage/deviceinbox.py b/synapse/storage/deviceinbox.py
index 5c7db5e5f6..0b62b493d5 100644
--- a/synapse/storage/deviceinbox.py
+++ b/synapse/storage/deviceinbox.py
@@ -178,7 +178,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
)
txn.execute(sql, (user_id,))
message_json = ujson.dumps(messages_by_device["*"])
- for row in txn.fetchall():
+ for row in txn:
# Add the message for all devices for this user on this
# server.
device = row[0]
@@ -195,7 +195,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
# TODO: Maybe this needs to be done in batches if there are
# too many local devices for a given user.
txn.execute(sql, [user_id] + devices)
- for row in txn.fetchall():
+ for row in txn:
# Only insert into the local inbox if the device exists on
# this server
device = row[0]
@@ -251,7 +251,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
user_id, device_id, last_stream_id, current_stream_id, limit
))
messages = []
- for row in txn.fetchall():
+ for row in txn:
stream_pos = row[0]
messages.append(ujson.loads(row[1]))
if len(messages) < limit:
@@ -325,22 +325,25 @@ class DeviceInboxStore(BackgroundUpdateStore):
# we return.
upper_pos = min(current_pos, last_pos + limit)
sql = (
- "SELECT stream_id, user_id"
+ "SELECT max(stream_id), user_id"
" FROM device_inbox"
" WHERE ? < stream_id AND stream_id <= ?"
- " ORDER BY stream_id ASC"
+ " GROUP BY user_id"
)
txn.execute(sql, (last_pos, upper_pos))
rows = txn.fetchall()
sql = (
- "SELECT stream_id, destination"
+ "SELECT max(stream_id), destination"
" FROM device_federation_outbox"
" WHERE ? < stream_id AND stream_id <= ?"
- " ORDER BY stream_id ASC"
+ " GROUP BY destination"
)
txn.execute(sql, (last_pos, upper_pos))
- rows.extend(txn.fetchall())
+ rows.extend(txn)
+
+ # Order by ascending stream ordering
+ rows.sort()
return rows
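The two queries above now return one row per user (and per destination) carrying the maximum stream id in the window, rather than every row in the window, so the dropped ORDER BY is replaced by the final `rows.sort()`. Because the stream id is the first tuple element, sorting the combined list restores ascending stream ordering; a sketch with illustrative values:

    rows = [(7, "@alice:example.com"), (3, "@bob:example.com")]  # per-user maxima
    rows.extend([(5, "remote.example.com")])   # per-destination maxima
    rows.sort()   # tuples compare element-wise: sorted by stream id
    # -> [(3, "@bob:example.com"), (5, "remote.example.com"),
    #     (7, "@alice:example.com")]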
@@ -357,12 +360,12 @@ class DeviceInboxStore(BackgroundUpdateStore):
"""
Args:
destination(str): The name of the remote server.
- last_stream_id(int): The last position of the device message stream
+ last_stream_id(int|long): The last position of the device message stream
that the server sent up to.
- current_stream_id(int): The current position of the device
+ current_stream_id(int|long): The current position of the device
message stream.
Returns:
- Deferred ([dict], int): List of messages for the device and where
+ Deferred ([dict], int|long): List of messages for the device and where
in the stream the messages got to.
"""
@@ -384,7 +387,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
destination, last_stream_id, current_stream_id, limit
))
messages = []
- for row in txn.fetchall():
+ for row in txn:
stream_pos = row[0]
messages.append(ujson.loads(row[1]))
if len(messages) < limit:
diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py
index bd56ba2515..c8d5f5ba8b 100644
--- a/synapse/storage/devices.py
+++ b/synapse/storage/devices.py
@@ -108,6 +108,23 @@ class DeviceStore(SQLBaseStore):
desc="delete_device",
)
+ def delete_devices(self, user_id, device_ids):
+ """Deletes several devices.
+
+ Args:
+ user_id (str): The ID of the user which owns the devices
+ device_ids (list): The IDs of the devices to delete
+ Returns:
+ defer.Deferred
+ """
+ return self._simple_delete_many(
+ table="devices",
+ column="device_id",
+ iterable=device_ids,
+ keyvalues={"user_id": user_id},
+ desc="delete_devices",
+ )
+
def update_device(self, user_id, device_id, new_display_name=None):
"""Update a device.
@@ -291,7 +308,7 @@ class DeviceStore(SQLBaseStore):
"""Get stream of updates to send to remote servers
Returns:
- (now_stream_id, [ { updates }, .. ])
+ (int, list[dict]): current stream id and list of updates
"""
now_stream_id = self._device_list_id_gen.get_current_token()
@@ -312,17 +329,20 @@ class DeviceStore(SQLBaseStore):
SELECT user_id, device_id, max(stream_id) FROM device_lists_outbound_pokes
WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ?
GROUP BY user_id, device_id
+ LIMIT 20
"""
txn.execute(
sql, (destination, from_stream_id, now_stream_id, False)
)
- rows = txn.fetchall()
- if not rows:
+ # maps (user_id, device_id) -> stream_id
+ query_map = {(r[0], r[1]): r[2] for r in txn}
+ if not query_map:
return (now_stream_id, [])
- # maps (user_id, device_id) -> stream_id
- query_map = {(r[0], r[1]): r[2] for r in rows}
+ if len(query_map) >= 20:
+ now_stream_id = max(query_map.itervalues())
+
devices = self._get_e2e_device_keys_txn(
txn, query_map.keys(), include_all_devices=True
)
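The `LIMIT 20` above bounds the work done per call; when a full page comes back, `now_stream_id` is clamped to the highest stream id actually returned, so the caller's next poll resumes from that point rather than skipping pokes. A sketch with illustrative numbers:

    # A full page of 20 (user_id, device_id) pokes with stream ids
    # 100..119, while the global token is already at 150:
    query_map = {("@u:example.com", "DEV%d" % i): 100 + i for i in range(20)}
    now_stream_id = 150

    if len(query_map) >= 20:
        now_stream_id = max(query_map.itervalues())  # 119: resume here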
@@ -513,7 +533,7 @@ class DeviceStore(SQLBaseStore):
rows = yield self._execute("get_user_whose_devices_changed", None, sql, from_key)
defer.returnValue(set(row[0] for row in rows))
- def get_all_device_list_changes_for_remotes(self, from_key):
+ def get_all_device_list_changes_for_remotes(self, from_key, to_key):
"""Return a list of `(stream_id, user_id, destination)` which is the
combined list of changes to devices, and which destinations need to be
poked. `destination` may be None if no destinations need to be poked.
@@ -521,11 +541,11 @@ class DeviceStore(SQLBaseStore):
sql = """
SELECT stream_id, user_id, destination FROM device_lists_stream
LEFT JOIN device_lists_outbound_pokes USING (stream_id, user_id, device_id)
- WHERE stream_id > ?
+ WHERE ? < stream_id AND stream_id <= ?
"""
return self._execute(
"get_all_device_list_changes_for_remotes", None,
- sql, from_key,
+ sql, from_key, to_key
)
@defer.inlineCallbacks
diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py
index b9f1365f92..7cbc1470fd 100644
--- a/synapse/storage/end_to_end_keys.py
+++ b/synapse/storage/end_to_end_keys.py
@@ -14,6 +14,8 @@
# limitations under the License.
from twisted.internet import defer
+from synapse.api.errors import SynapseError
+
from canonicaljson import encode_canonical_json
import ujson as json
@@ -120,24 +122,63 @@ class EndToEndKeyStore(SQLBaseStore):
return result
+ @defer.inlineCallbacks
def add_e2e_one_time_keys(self, user_id, device_id, time_now, key_list):
+ """Insert some new one time keys for a device.
+
+ Checks whether any of the keys have already been inserted; if they have,
+ checks that the stored values match. If a stored value differs, raises
+ an error.
+ """
+
+ # First we check if we have already persisted any of the keys.
+ rows = yield self._simple_select_many_batch(
+ table="e2e_one_time_keys_json",
+ column="key_id",
+ iterable=[key_id for _, key_id, _ in key_list],
+ retcols=("algorithm", "key_id", "key_json",),
+ keyvalues={
+ "user_id": user_id,
+ "device_id": device_id,
+ },
+ desc="add_e2e_one_time_keys_check",
+ )
+
+ existing_key_map = {
+ (row["algorithm"], row["key_id"]): row["key_json"] for row in rows
+ }
+
+ new_keys = [] # Keys that we need to insert
+ for algorithm, key_id, json_bytes in key_list:
+ ex_bytes = existing_key_map.get((algorithm, key_id), None)
+ if ex_bytes:
+ if json_bytes != ex_bytes:
+ raise SynapseError(
+ 400, "One time key with key_id %r already exists" % (key_id,)
+ )
+ else:
+ new_keys.append((algorithm, key_id, json_bytes))
+
def _add_e2e_one_time_keys(txn):
- for (algorithm, key_id, json_bytes) in key_list:
- self._simple_upsert_txn(
- txn, table="e2e_one_time_keys_json",
- keyvalues={
+ # We are protected from a race between the lookup and the insertion
+ # by a unique constraint: if two calls to `add_e2e_one_time_keys`
+ # race, they will conflict and only one set will be inserted.
+ self._simple_insert_many_txn(
+ txn, table="e2e_one_time_keys_json",
+ values=[
+ {
"user_id": user_id,
"device_id": device_id,
"algorithm": algorithm,
"key_id": key_id,
- },
- values={
"ts_added_ms": time_now,
"key_json": json_bytes,
}
- )
- return self.runInteraction(
- "add_e2e_one_time_keys", _add_e2e_one_time_keys
+ for algorithm, key_id, json_bytes in new_keys
+ ],
+ )
+ yield self.runInteraction(
+ "add_e2e_one_time_keys_insert", _add_e2e_one_time_keys
)
def count_e2e_one_time_keys(self, user_id, device_id):
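A sketch of the semantics this gives `add_e2e_one_time_keys` (the `store` handle, timestamp and key values are hypothetical, and the calls are written as they would appear inside an inlineCallbacks generator): re-uploading an identical key is filtered out by the lookup and becomes a no-op, while re-uploading a different value under an existing key_id raises.

    now = 1490000000000   # hypothetical timestamp in ms
    key_list = [("signed_curve25519", "AAAAAA", '{"key": "abc"}')]

    # First upload inserts the key.
    yield store.add_e2e_one_time_keys("@alice:example.com", "DEV", now, key_list)

    # Identical re-upload: matches the stored value, nothing is inserted.
    yield store.add_e2e_one_time_keys("@alice:example.com", "DEV", now, key_list)

    # Same key_id, different value: raises SynapseError(400, ...).
    clash = [("signed_curve25519", "AAAAAA", '{"key": "xyz"}')]
    yield store.add_e2e_one_time_keys("@alice:example.com", "DEV", now, clash)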
@@ -153,7 +194,7 @@ class EndToEndKeyStore(SQLBaseStore):
)
txn.execute(sql, (user_id, device_id))
result = {}
- for algorithm, key_count in txn.fetchall():
+ for algorithm, key_count in txn:
result[algorithm] = key_count
return result
return self.runInteraction(
@@ -174,7 +215,7 @@ class EndToEndKeyStore(SQLBaseStore):
user_result = result.setdefault(user_id, {})
device_result = user_result.setdefault(device_id, {})
txn.execute(sql, (user_id, device_id, algorithm))
- for key_id, key_json in txn.fetchall():
+ for key_id, key_json in txn:
device_result[algorithm + ":" + key_id] = key_json
delete.append((user_id, device_id, algorithm, key_id))
sql = (
diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py
index 256e50dc20..519059c306 100644
--- a/synapse/storage/event_federation.py
+++ b/synapse/storage/event_federation.py
@@ -74,7 +74,7 @@ class EventFederationStore(SQLBaseStore):
base_sql % (",".join(["?"] * len(chunk)),),
chunk
)
- new_front.update([r[0] for r in txn.fetchall()])
+ new_front.update([r[0] for r in txn])
new_front -= results
@@ -110,7 +110,7 @@ class EventFederationStore(SQLBaseStore):
txn.execute(sql, (room_id, False,))
- return dict(txn.fetchall())
+ return dict(txn)
def _get_oldest_events_in_room_txn(self, txn, room_id):
return self._simple_select_onecol_txn(
@@ -201,19 +201,19 @@ class EventFederationStore(SQLBaseStore):
def _update_min_depth_for_room_txn(self, txn, room_id, depth):
min_depth = self._get_min_depth_interaction(txn, room_id)
- do_insert = depth < min_depth if min_depth else True
+ if min_depth and depth >= min_depth:
+ return
- if do_insert:
- self._simple_upsert_txn(
- txn,
- table="room_depth",
- keyvalues={
- "room_id": room_id,
- },
- values={
- "min_depth": depth,
- },
- )
+ self._simple_upsert_txn(
+ txn,
+ table="room_depth",
+ keyvalues={
+ "room_id": room_id,
+ },
+ values={
+ "min_depth": depth,
+ },
+ )
def _handle_mult_prev_events(self, txn, events):
"""
@@ -334,8 +334,7 @@ class EventFederationStore(SQLBaseStore):
def get_forward_extremeties_for_room_txn(txn):
txn.execute(sql, (stream_ordering, room_id))
- rows = txn.fetchall()
- return [event_id for event_id, in rows]
+ return [event_id for event_id, in txn]
return self.runInteraction(
"get_forward_extremeties_for_room",
@@ -436,7 +435,7 @@ class EventFederationStore(SQLBaseStore):
(room_id, event_id, False, limit - len(event_results))
)
- for row in txn.fetchall():
+ for row in txn:
if row[1] not in event_results:
queue.put((-row[0], row[1]))
@@ -482,7 +481,7 @@ class EventFederationStore(SQLBaseStore):
(room_id, event_id, False, limit - len(event_results))
)
- for e_id, in txn.fetchall():
+ for e_id, in txn:
new_front.add(e_id)
new_front -= earliest_events
diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/event_push_actions.py
index 14543b4269..d6d8723b4a 100644
--- a/synapse/storage/event_push_actions.py
+++ b/synapse/storage/event_push_actions.py
@@ -206,7 +206,7 @@ class EventPushActionsStore(SQLBaseStore):
" stream_ordering >= ? AND stream_ordering <= ?"
)
txn.execute(sql, (min_stream_ordering, max_stream_ordering))
- return [r[0] for r in txn.fetchall()]
+ return [r[0] for r in txn]
ret = yield self.runInteraction("get_push_action_users_in_range", f)
defer.returnValue(ret)
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index db01eb6d14..64fe937bdc 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -34,14 +34,16 @@ from canonicaljson import encode_canonical_json
from collections import deque, namedtuple, OrderedDict
from functools import wraps
-import synapse
import synapse.metrics
-
import logging
import math
import ujson as json
+# these are only included to make the type annotations work
+from synapse.events import EventBase # noqa: F401
+from synapse.events.snapshot import EventContext # noqa: F401
+
logger = logging.getLogger(__name__)
@@ -82,6 +84,11 @@ class _EventPeristenceQueue(object):
def add_to_queue(self, room_id, events_and_contexts, backfilled):
"""Add events to the queue, with the given persist_event options.
+
+ Args:
+ room_id (str):
+ events_and_contexts (list[(EventBase, EventContext)]):
+ backfilled (bool):
"""
queue = self._event_persist_queues.setdefault(room_id, deque())
if queue:
@@ -210,14 +217,14 @@ class EventsStore(SQLBaseStore):
partitioned.setdefault(event.room_id, []).append((event, ctx))
deferreds = []
- for room_id, evs_ctxs in partitioned.items():
+ for room_id, evs_ctxs in partitioned.iteritems():
d = preserve_fn(self._event_persist_queue.add_to_queue)(
room_id, evs_ctxs,
backfilled=backfilled,
)
deferreds.append(d)
- for room_id in partitioned.keys():
+ for room_id in partitioned:
self._maybe_start_persisting(room_id)
return preserve_context_over_deferred(
@@ -227,6 +234,17 @@ class EventsStore(SQLBaseStore):
@defer.inlineCallbacks
@log_function
def persist_event(self, event, context, backfilled=False):
+ """
+
+ Args:
+ event (EventBase):
+ context (EventContext):
+ backfilled (bool):
+
+ Returns:
+ Deferred: resolves to (int, int): the stream ordering of ``event``,
+ and the stream ordering of the latest persisted event
+ """
deferred = self._event_persist_queue.add_to_queue(
event.room_id, [(event, context)],
backfilled=backfilled,
@@ -253,6 +271,16 @@ class EventsStore(SQLBaseStore):
@defer.inlineCallbacks
def _persist_events(self, events_and_contexts, backfilled=False,
delete_existing=False):
+ """Persist events to db
+
+ Args:
+ events_and_contexts (list[(EventBase, EventContext)]):
+ backfilled (bool):
+ delete_existing (bool):
+
+ Returns:
+ Deferred: resolves when the events have been persisted
+ """
if not events_and_contexts:
return
@@ -295,7 +323,7 @@ class EventsStore(SQLBaseStore):
(event, context)
)
- for room_id, ev_ctx_rm in events_by_room.items():
+ for room_id, ev_ctx_rm in events_by_room.iteritems():
# Work out new extremities by recursively adding and removing
# the new events.
latest_event_ids = yield self.get_latest_event_ids_in_room(
@@ -400,6 +428,7 @@ class EventsStore(SQLBaseStore):
# Now we need to work out the different state sets for
# each state extremity
state_sets = []
+ state_groups = set()
missing_event_ids = []
was_updated = False
for event_id in new_latest_event_ids:
@@ -409,9 +438,17 @@ class EventsStore(SQLBaseStore):
if event_id == ev.event_id:
if ctx.current_state_ids is None:
raise Exception("Unknown current state")
- state_sets.append(ctx.current_state_ids)
- if ctx.delta_ids or hasattr(ev, "state_key"):
- was_updated = True
+
+ # If we've already seen the state group, don't bother adding
+ # it to the state sets again
+ if ctx.state_group not in state_groups:
+ state_sets.append(ctx.current_state_ids)
+ if ctx.delta_ids or hasattr(ev, "state_key"):
+ was_updated = True
+ if ctx.state_group:
+ # Add this as a seen state group (if it has a state
+ # group)
+ state_groups.add(ctx.state_group)
break
else:
# If we couldn't find it, then we'll need to pull
@@ -425,31 +462,57 @@ class EventsStore(SQLBaseStore):
missing_event_ids,
)
- groups = set(event_to_groups.values())
- group_to_state = yield self._get_state_for_groups(groups)
+ groups = set(event_to_groups.itervalues()) - state_groups
- state_sets.extend(group_to_state.values())
+ if groups:
+ group_to_state = yield self._get_state_for_groups(groups)
+ state_sets.extend(group_to_state.itervalues())
if not new_latest_event_ids:
current_state = {}
elif was_updated:
- current_state = yield resolve_events(
- state_sets,
- state_map_factory=lambda ev_ids: self.get_events(
- ev_ids, get_prev_content=False, check_redacted=False,
- ),
- )
+ if len(state_sets) == 1:
+ # If there is only one state set, then we know what the current
+ # state is.
+ current_state = state_sets[0]
+ else:
+ # We work out the current state by passing the state sets to the
+ # state resolution algorithm. It may ask for some events, including
+ # the events we have yet to persist, so we need a slightly more
+ # complicated event lookup function than simply looking the events
+ # up in the db.
+ events_map = {ev.event_id: ev for ev, _ in events_context}
+
+ @defer.inlineCallbacks
+ def get_events(ev_ids):
+ # We get the events by first looking at the list of events we
+ # are trying to persist, and then fetching the rest from the DB.
+ db = []
+ to_return = {}
+ for ev_id in ev_ids:
+ ev = events_map.get(ev_id, None)
+ if ev:
+ to_return[ev_id] = ev
+ else:
+ db.append(ev_id)
+
+ if db:
+ evs = yield self.get_events(
+ db, get_prev_content=False, check_redacted=False,
+ )
+ to_return.update(evs)
+ defer.returnValue(to_return)
+
+ current_state = yield resolve_events(
+ state_sets,
+ state_map_factory=get_events,
+ )
else:
return
- existing_state_rows = yield self._simple_select_list(
- table="current_state_events",
- keyvalues={"room_id": room_id},
- retcols=["event_id", "type", "state_key"],
- desc="_calculate_state_delta",
- )
+ existing_state = yield self.get_current_state_ids(room_id)
- existing_events = set(row["event_id"] for row in existing_state_rows)
+ existing_events = set(existing_state.itervalues())
new_events = set(ev_id for ev_id in current_state.itervalues())
changed_events = existing_events ^ new_events
@@ -457,9 +520,8 @@ class EventsStore(SQLBaseStore):
return
to_delete = {
- (row["type"], row["state_key"]): row["event_id"]
- for row in existing_state_rows
- if row["event_id"] in changed_events
+ key: ev_id for key, ev_id in existing_state.iteritems()
+ if ev_id in changed_events
}
events_to_insert = (new_events - existing_events)
to_insert = {
@@ -535,11 +597,91 @@ class EventsStore(SQLBaseStore):
and the rejections table. Things reading from those table will need to check
whether the event was rejected.
- If delete_existing is True then existing events will be purged from the
- database before insertion. This is useful when retrying due to IntegrityError.
+ Args:
+ txn (twisted.enterprise.adbapi.Connection): db connection
+ events_and_contexts (list[(EventBase, EventContext)]):
+ events to persist
+ backfilled (bool): True if the events were backfilled
+ delete_existing (bool): True to purge existing table rows for the
+ events from the database. This is useful when retrying due to
+ IntegrityError.
+ current_state_for_room (dict[str, (list[str], list[str])]):
+ The current-state delta for each room. For each room, a tuple
+ (to_delete, to_insert), being a list of event ids to be removed
+ from the current state, and a list of event ids to be added to
+ the current state.
+ new_forward_extremeties (dict[str, list[str]]):
+ The new forward extremities for each room. For each room, a
+ list of the event ids which are the forward extremities.
+
"""
+ self._update_current_state_txn(txn, current_state_for_room)
+
max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering
- for room_id, current_state_tuple in current_state_for_room.iteritems():
+ self._update_forward_extremities_txn(
+ txn,
+ new_forward_extremities=new_forward_extremeties,
+ max_stream_order=max_stream_order,
+ )
+
+ # Ensure that we don't have the same event twice.
+ events_and_contexts = self._filter_events_and_contexts_for_duplicates(
+ events_and_contexts,
+ )
+
+ self._update_room_depths_txn(
+ txn,
+ events_and_contexts=events_and_contexts,
+ backfilled=backfilled,
+ )
+
+ # _update_outliers_txn filters out any events which have already been
+ # persisted, and returns the filtered list.
+ events_and_contexts = self._update_outliers_txn(
+ txn,
+ events_and_contexts=events_and_contexts,
+ )
+
+ # From this point onwards the events are only events that we haven't
+ # seen before.
+
+ if delete_existing:
+ # For paranoia reasons, we go and delete all the existing entries
+ # for these events so we can reinsert them.
+ # This gets around any problems with some tables already having
+ # entries.
+ self._delete_existing_rows_txn(
+ txn,
+ events_and_contexts=events_and_contexts,
+ )
+
+ self._store_event_txn(
+ txn,
+ events_and_contexts=events_and_contexts,
+ )
+
+ # Insert into the state_groups, state_groups_state, and
+ # event_to_state_groups tables.
+ self._store_mult_state_groups_txn(txn, events_and_contexts)
+
+ # _store_rejected_events_txn filters out any events which were
+ # rejected, and returns the filtered list.
+ events_and_contexts = self._store_rejected_events_txn(
+ txn,
+ events_and_contexts=events_and_contexts,
+ )
+
+ # From this point onwards the events are only ones that weren't
+ # rejected.
+
+ self._update_metadata_tables_txn(
+ txn,
+ events_and_contexts=events_and_contexts,
+ backfilled=backfilled,
+ )
+
+ def _update_current_state_txn(self, txn, state_delta_by_room):
+ for room_id, current_state_tuple in state_delta_by_room.iteritems():
to_delete, to_insert = current_state_tuple
txn.executemany(
"DELETE FROM current_state_events WHERE event_id = ?",
@@ -585,7 +727,13 @@ class EventsStore(SQLBaseStore):
txn, self.get_users_in_room, (room_id,)
)
- for room_id, new_extrem in new_forward_extremeties.items():
+ self._invalidate_cache_and_stream(
+ txn, self.get_current_state_ids, (room_id,)
+ )
+
+ def _update_forward_extremities_txn(self, txn, new_forward_extremities,
+ max_stream_order):
+ for room_id, new_extrem in new_forward_extremities.iteritems():
self._simple_delete_txn(
txn,
table="event_forward_extremities",
@@ -603,7 +751,7 @@ class EventsStore(SQLBaseStore):
"event_id": ev_id,
"room_id": room_id,
}
- for room_id, new_extrem in new_forward_extremeties.items()
+ for room_id, new_extrem in new_forward_extremities.iteritems()
for ev_id in new_extrem
],
)
@@ -620,13 +768,22 @@ class EventsStore(SQLBaseStore):
"event_id": event_id,
"stream_ordering": max_stream_order,
}
- for room_id, new_extrem in new_forward_extremeties.items()
+ for room_id, new_extrem in new_forward_extremities.iteritems()
for event_id in new_extrem
]
)
- # Ensure that we don't have the same event twice.
- # Pick the earliest non-outlier if there is one, else the earliest one.
+ @classmethod
+ def _filter_events_and_contexts_for_duplicates(cls, events_and_contexts):
+ """Ensure that we don't have the same event twice.
+
+ Pick the earliest non-outlier if there is one, else the earliest one.
+
+ Args:
+ events_and_contexts (list[(EventBase, EventContext)]):
+ Returns:
+ list[(EventBase, EventContext)]: filtered list
+ """
new_events_and_contexts = OrderedDict()
for event, context in events_and_contexts:
prev_event_context = new_events_and_contexts.get(event.event_id)
@@ -639,9 +796,17 @@ class EventsStore(SQLBaseStore):
new_events_and_contexts[event.event_id] = (event, context)
else:
new_events_and_contexts[event.event_id] = (event, context)
+ return new_events_and_contexts.values()
- events_and_contexts = new_events_and_contexts.values()
+ def _update_room_depths_txn(self, txn, events_and_contexts, backfilled):
+ """Update min_depth for each room
+ Args:
+ txn (twisted.enterprise.adbapi.Connection): db connection
+ events_and_contexts (list[(EventBase, EventContext)]): events
+ we are persisting
+ backfilled (bool): True if the events were backfilled
+ """
depth_updates = {}
for event, context in events_and_contexts:
# Remove any existing cache entries for the event_ids
@@ -657,9 +822,24 @@ class EventsStore(SQLBaseStore):
event.depth, depth_updates.get(event.room_id, event.depth)
)
- for room_id, depth in depth_updates.items():
+ for room_id, depth in depth_updates.iteritems():
self._update_min_depth_for_room_txn(txn, room_id, depth)
+ def _update_outliers_txn(self, txn, events_and_contexts):
+ """Update any outliers with new event info.
+
+ This turns outliers into ex-outliers (unless the new event was
+ rejected).
+
+ Args:
+ txn (twisted.enterprise.adbapi.Connection): db connection
+ events_and_contexts (list[(EventBase, EventContext)]): events
+ we are persisting
+
+ Returns:
+ list[(EventBase, EventContext)]: new list, without events which
+ are already in the events table.
+ """
txn.execute(
"SELECT event_id, outlier FROM events WHERE event_id in (%s)" % (
",".join(["?"] * len(events_and_contexts)),
@@ -669,24 +849,21 @@ class EventsStore(SQLBaseStore):
have_persisted = {
event_id: outlier
- for event_id, outlier in txn.fetchall()
+ for event_id, outlier in txn
}
to_remove = set()
for event, context in events_and_contexts:
- if context.rejected:
- # If the event is rejected then we don't care if the event
- # was an outlier or not.
- if event.event_id in have_persisted:
- # If we have already seen the event then ignore it.
- to_remove.add(event)
- continue
-
if event.event_id not in have_persisted:
continue
to_remove.add(event)
+ if context.rejected:
+ # If the event is rejected then we don't care if the event
+ # was an outlier or not.
+ continue
+
outlier_persisted = have_persisted[event.event_id]
if not event.internal_metadata.is_outlier() and outlier_persisted:
# We received a copy of an event that we had already stored as
@@ -741,37 +918,19 @@ class EventsStore(SQLBaseStore):
# event isn't an outlier any more.
self._update_backward_extremeties(txn, [event])
- events_and_contexts = [
+ return [
ec for ec in events_and_contexts if ec[0] not in to_remove
]
+ @classmethod
+ def _delete_existing_rows_txn(cls, txn, events_and_contexts):
if not events_and_contexts:
- # Make sure we don't pass an empty list to functions that expect to
- # be storing at least one element.
+ # nothing to do here
return
- # From this point onwards the events are only events that we haven't
- # seen before.
-
- def event_dict(event):
- return {
- k: v
- for k, v in event.get_dict().items()
- if k not in [
- "redacted",
- "redacted_because",
- ]
- }
-
- if delete_existing:
- # For paranoia reasons, we go and delete all the existing entries
- # for these events so we can reinsert them.
- # This gets around any problems with some tables already having
- # entries.
-
- logger.info("Deleting existing")
+ logger.info("Deleting existing")
- for table in (
+ for table in (
"events",
"event_auth",
"event_json",
@@ -794,11 +953,30 @@ class EventsStore(SQLBaseStore):
"redactions",
"room_memberships",
"topics"
- ):
- txn.executemany(
- "DELETE FROM %s WHERE event_id = ?" % (table,),
- [(ev.event_id,) for ev, _ in events_and_contexts]
- )
+ ):
+ txn.executemany(
+ "DELETE FROM %s WHERE event_id = ?" % (table,),
+ [(ev.event_id,) for ev, _ in events_and_contexts]
+ )
+
+ def _store_event_txn(self, txn, events_and_contexts):
+ """Insert new events into the event and event_json tables
+
+ Args:
+ txn (twisted.enterprise.adbapi.Connection): db connection
+ events_and_contexts (list[(EventBase, EventContext)]): events
+ we are persisting
+ """
+
+ if not events_and_contexts:
+ # nothing to do here
+ return
+
+ def event_dict(event):
+ d = event.get_dict()
+ d.pop("redacted", None)
+ d.pop("redacted_because", None)
+ return d
self._simple_insert_many_txn(
txn,
@@ -842,6 +1020,19 @@ class EventsStore(SQLBaseStore):
],
)
+ def _store_rejected_events_txn(self, txn, events_and_contexts):
+ """Add rows to the 'rejections' table for received events which were
+ rejected
+
+ Args:
+ txn (twisted.enterprise.adbapi.Connection): db connection
+ events_and_contexts (list[(EventBase, EventContext)]): events
+ we are persisting
+
+ Returns:
+ list[(EventBase, EventContext)]: new list, without the rejected
+ events.
+ """
# Remove the rejected events from the list now that we've added them
# to the events table and the events_json table.
to_remove = set()
@@ -853,17 +1044,24 @@ class EventsStore(SQLBaseStore):
)
to_remove.add(event)
- events_and_contexts = [
+ return [
ec for ec in events_and_contexts if ec[0] not in to_remove
]
+ def _update_metadata_tables_txn(self, txn, events_and_contexts, backfilled):
+ """Update all the miscellaneous tables for new events
+
+ Args:
+ txn (twisted.enterprise.adbapi.Connection): db connection
+ events_and_contexts (list[(EventBase, EventContext)]): events
+ we are persisting
+ backfilled (bool): True if the events were backfilled
+ """
+
if not events_and_contexts:
- # Make sure we don't pass an empty list to functions that expect to
- # be storing at least one element.
+ # nothing to do here
return
- # From this point onwards the events are only ones that weren't rejected.
-
for event, context in events_and_contexts:
# Insert all the push actions into the event_push_actions table.
if context.push_actions:
@@ -892,10 +1090,6 @@ class EventsStore(SQLBaseStore):
],
)
- # Insert into the state_groups, state_groups_state, and
- # event_to_state_groups tables.
- self._store_mult_state_groups_txn(txn, events_and_contexts)
-
# Update the event_forward_extremities, event_backward_extremities and
# event_edges tables.
self._handle_mult_prev_events(
@@ -982,13 +1176,6 @@ class EventsStore(SQLBaseStore):
# Prefill the event cache
self._add_to_cache(txn, events_and_contexts)
- if backfilled:
- # Backfilled events come before the current state so we don't need
- # to update the current state table
- return
-
- return
-
def _add_to_cache(self, txn, events_and_contexts):
to_prefill = []
@@ -1584,6 +1771,94 @@ class EventsStore(SQLBaseStore):
"""The current minimum token that backfilled events have reached"""
return -self._backfill_id_gen.get_current_token()
+ def get_current_events_token(self):
+ """The current maximum token that events have reached"""
+ return self._stream_id_gen.get_current_token()
+
+ def get_all_new_forward_event_rows(self, last_id, current_id, limit):
+ if last_id == current_id:
+ return defer.succeed([])
+
+ def get_all_new_forward_event_rows(txn):
+ sql = (
+ "SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
+ " state_key, redacts"
+ " FROM events AS e"
+ " LEFT JOIN redactions USING (event_id)"
+ " LEFT JOIN state_events USING (event_id)"
+ " WHERE ? < stream_ordering AND stream_ordering <= ?"
+ " ORDER BY stream_ordering ASC"
+ " LIMIT ?"
+ )
+ txn.execute(sql, (last_id, current_id, limit))
+ new_event_updates = txn.fetchall()
+
+ if len(new_event_updates) == limit:
+ upper_bound = new_event_updates[-1][0]
+ else:
+ upper_bound = current_id
+
+ sql = (
+ "SELECT event_stream_ordering, e.event_id, e.room_id, e.type,"
+ " state_key, redacts"
+ " FROM events AS e"
+ " INNER JOIN ex_outlier_stream USING (event_id)"
+ " LEFT JOIN redactions USING (event_id)"
+ " LEFT JOIN state_events USING (event_id)"
+ " WHERE ? < event_stream_ordering"
+ " AND event_stream_ordering <= ?"
+ " ORDER BY event_stream_ordering DESC"
+ )
+ txn.execute(sql, (last_id, upper_bound))
+ new_event_updates.extend(txn)
+
+ return new_event_updates
+ return self.runInteraction(
+ "get_all_new_forward_event_rows", get_all_new_forward_event_rows
+ )
+
+ def get_all_new_backfill_event_rows(self, last_id, current_id, limit):
+ if last_id == current_id:
+ return defer.succeed([])
+
+ def get_all_new_backfill_event_rows(txn):
+ sql = (
+ "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type,"
+ " state_key, redacts"
+ " FROM events AS e"
+ " LEFT JOIN redactions USING (event_id)"
+ " LEFT JOIN state_events USING (event_id)"
+ " WHERE ? > stream_ordering AND stream_ordering >= ?"
+ " ORDER BY stream_ordering ASC"
+ " LIMIT ?"
+ )
+ txn.execute(sql, (-last_id, -current_id, limit))
+ new_event_updates = txn.fetchall()
+
+ if len(new_event_updates) == limit:
+ upper_bound = new_event_updates[-1][0]
+ else:
+ upper_bound = current_id
+
+ sql = (
+ "SELECT -event_stream_ordering, e.event_id, e.room_id, e.type,"
+ " state_key, redacts"
+ " FROM events AS e"
+ " INNER JOIN ex_outlier_stream USING (event_id)"
+ " LEFT JOIN redactions USING (event_id)"
+ " LEFT JOIN state_events USING (event_id)"
+ " WHERE ? > event_stream_ordering"
+ " AND event_stream_ordering >= ?"
+ " ORDER BY event_stream_ordering DESC"
+ )
+ txn.execute(sql, (-last_id, -upper_bound))
+ new_event_updates.extend(txn.fetchall())
+
+ return new_event_updates
+ return self.runInteraction(
+ "get_all_new_backfill_event_rows", get_all_new_backfill_event_rows
+ )
+
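Both functions above use the same truncation trick: if the first query fills the page, the upper bound fed to the ex_outlier_stream query (and handed back to the caller as the point to resume from) is the stream ordering of the last row returned, not the global current token. A sketch with illustrative values:

    limit = 2
    current_id = 200
    new_event_updates = [(101, "$event_a"), (102, "$event_b")]  # a full page

    if len(new_event_updates) == limit:
        upper_bound = new_event_updates[-1][0]  # 102: more rows remain beyond
    else:
        upper_bound = current_id                # caught up to the token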
@cached(num_args=5, max_entries=10)
def get_all_new_events(self, last_backfill_id, last_forward_id,
current_backfill_id, current_forward_id, limit):
@@ -1597,14 +1872,13 @@ class EventsStore(SQLBaseStore):
def get_all_new_events_txn(txn):
sql = (
- "SELECT e.stream_ordering, ej.internal_metadata, ej.json, eg.state_group"
- " FROM events as e"
- " JOIN event_json as ej"
- " ON e.event_id = ej.event_id AND e.room_id = ej.room_id"
- " LEFT JOIN event_to_state_groups as eg"
- " ON e.event_id = eg.event_id"
- " WHERE ? < e.stream_ordering AND e.stream_ordering <= ?"
- " ORDER BY e.stream_ordering ASC"
+ "SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
+ " state_key, redacts"
+ " FROM events AS e"
+ " LEFT JOIN redactions USING (event_id)"
+ " LEFT JOIN state_events USING (event_id)"
+ " WHERE ? < stream_ordering AND stream_ordering <= ?"
+ " ORDER BY stream_ordering ASC"
" LIMIT ?"
)
if have_forward_events:
@@ -1630,15 +1904,13 @@ class EventsStore(SQLBaseStore):
forward_ex_outliers = []
sql = (
- "SELECT -e.stream_ordering, ej.internal_metadata, ej.json,"
- " eg.state_group"
- " FROM events as e"
- " JOIN event_json as ej"
- " ON e.event_id = ej.event_id AND e.room_id = ej.room_id"
- " LEFT JOIN event_to_state_groups as eg"
- " ON e.event_id = eg.event_id"
- " WHERE ? > e.stream_ordering AND e.stream_ordering >= ?"
- " ORDER BY e.stream_ordering DESC"
+ "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type,"
+ " state_key, redacts"
+ " FROM events AS e"
+ " LEFT JOIN redactions USING (event_id)"
+ " LEFT JOIN state_events USING (event_id)"
+ " WHERE ? > stream_ordering AND stream_ordering >= ?"
+ " ORDER BY stream_ordering DESC"
" LIMIT ?"
)
if have_backfill_events:
@@ -1825,7 +2097,7 @@ class EventsStore(SQLBaseStore):
"state_key": key[1],
"event_id": state_id,
}
- for key, state_id in curr_state.items()
+ for key, state_id in curr_state.iteritems()
],
)
diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py
index 86b37b9ddd..3b5e0a4fb9 100644
--- a/synapse/storage/keys.py
+++ b/synapse/storage/keys.py
@@ -101,9 +101,10 @@ class KeyStore(SQLBaseStore):
key_ids
Args:
server_name (str): The name of the server.
- key_ids (list of str): List of key_ids to try and look up.
+ key_ids (iterable[str]): key_ids to try and look up.
Returns:
- (list of VerifyKey): The verification keys.
+ Deferred: resolves to dict[str, VerifyKey]: map from
+ key_id to verification key.
"""
keys = {}
for key_id in key_ids:
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index ed84db6b4b..6e623843d5 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -356,7 +356,7 @@ def _get_or_create_schema_state(txn, database_engine):
),
(current_version,)
)
- applied_deltas = [d for d, in txn.fetchall()]
+ applied_deltas = [d for d, in txn]
return current_version, applied_deltas, upgraded
return None
diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py
index 4d1590d2b4..9e9d3c2591 100644
--- a/synapse/storage/presence.py
+++ b/synapse/storage/presence.py
@@ -85,8 +85,8 @@ class PresenceStore(SQLBaseStore):
self.presence_stream_cache.entity_has_changed,
state.user_id, stream_id,
)
- self._invalidate_cache_and_stream(
- txn, self._get_presence_for_user, (state.user_id,)
+ txn.call_after(
+ self._get_presence_for_user.invalidate, (state.user_id,)
)
# Actually insert new rows
diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py
index 8cc9f0353b..34d2f82b7f 100644
--- a/synapse/storage/pusher.py
+++ b/synapse/storage/pusher.py
@@ -135,6 +135,48 @@ class PusherStore(SQLBaseStore):
"get_all_updated_pushers", get_all_updated_pushers_txn
)
+ def get_all_updated_pushers_rows(self, last_id, current_id, limit):
+ """Get all the pushers that have changed between the given tokens.
+
+ Returns:
+ Deferred(list(tuple)): each tuple consists of:
+ stream_id (int)
+ user_id (str)
+ app_id (str)
+ pushkey (str)
+ was_deleted (bool): whether the pusher was added/updated (False)
+ or deleted (True)
+ """
+
+ if last_id == current_id:
+ return defer.succeed([])
+
+ def get_all_updated_pushers_rows_txn(txn):
+ sql = (
+ "SELECT id, user_name, app_id, pushkey"
+ " FROM pushers"
+ " WHERE ? < id AND id <= ?"
+ " ORDER BY id ASC LIMIT ?"
+ )
+ txn.execute(sql, (last_id, current_id, limit))
+ results = [list(row) + [False] for row in txn]
+
+ sql = (
+ "SELECT stream_id, user_id, app_id, pushkey"
+ " FROM deleted_pushers"
+ " WHERE ? < stream_id AND stream_id <= ?"
+ " ORDER BY stream_id ASC LIMIT ?"
+ )
+ txn.execute(sql, (last_id, current_id, limit))
+
+ results.extend(list(row) + [True] for row in txn)
+ results.sort() # Sort so that they're ordered by stream id
+
+ return results
+ return self.runInteraction(
+ "get_all_updated_pushers_rows", get_all_updated_pushers_rows_txn
+ )
+
@cachedInlineCallbacks(num_args=1, max_entries=15000)
def get_if_user_has_pusher(self, user_id):
# This only exists for the cachedList decorator
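A sketch (illustrative values) of the row shape produced above: rows from `pushers` are tagged `False`, rows from `deleted_pushers` are tagged `True`, and the final sort interleaves the two streams into one list ordered by stream id:

    results = [[4, "@a:example.com", "app", "key1", False]]       # updates
    results.extend([[2, "@b:example.com", "app", "key2", True]])  # deletions
    results.sort()
    # -> the deletion at stream id 2 first, then the update at stream id 4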
diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py
index 5cf41501ea..6b0f8c2787 100644
--- a/synapse/storage/receipts.py
+++ b/synapse/storage/receipts.py
@@ -313,10 +313,9 @@ class ReceiptsStore(SQLBaseStore):
)
txn.execute(sql, (room_id, receipt_type, user_id))
- results = txn.fetchall()
- if results and topological_ordering:
- for to, so, _ in results:
+ if topological_ordering:
+ for to, so, _ in txn:
if int(to) > topological_ordering:
return False
elif int(to) == topological_ordering and int(so) >= stream_ordering:
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 26be6060c3..ec2c52ab93 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -209,7 +209,7 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
" WHERE lower(name) = lower(?)"
)
txn.execute(sql, (user_id,))
- return dict(txn.fetchall())
+ return dict(txn)
return self.runInteraction("get_users_by_id_case_insensitive", f)
diff --git a/synapse/storage/room.py b/synapse/storage/room.py
index 8a2fe2fdf5..e4c56cc175 100644
--- a/synapse/storage/room.py
+++ b/synapse/storage/room.py
@@ -396,7 +396,7 @@ class RoomStore(SQLBaseStore):
sql % ("AND appservice_id IS NULL",),
(stream_id,)
)
- return dict(txn.fetchall())
+ return dict(txn)
else:
# We want to get from all lists, so we need to aggregate the results
@@ -422,7 +422,7 @@ class RoomStore(SQLBaseStore):
results = {}
# A room is visible if its visible on any list.
- for room_id, visibility in txn.fetchall():
+ for room_id, visibility in txn:
results[room_id] = bool(visibility) or results.get(room_id, False)
return results
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 545d3d3a99..367dbbbcf6 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -129,17 +129,30 @@ class RoomMemberStore(SQLBaseStore):
with self._stream_id_gen.get_next() as stream_ordering:
yield self.runInteraction("locally_reject_invite", f, stream_ordering)
+ @cachedInlineCallbacks(max_entries=100000, iterable=True, cache_context=True)
+ def get_hosts_in_room(self, room_id, cache_context):
+ """Returns the set of all hosts currently in the room
+ """
+ user_ids = yield self.get_users_in_room(
+ room_id, on_invalidate=cache_context.invalidate,
+ )
+ hosts = frozenset(get_domain_from_id(user_id) for user_id in user_ids)
+ defer.returnValue(hosts)
+
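The host set above is derived from the joined user ids via `get_domain_from_id` (imported from synapse.types in this codebase), which returns everything after the first colon of a user id. An illustrative sketch:

    from synapse.types import get_domain_from_id

    user_ids = ["@alice:hs1.example.com", "@bob:hs2.example.com"]
    hosts = frozenset(get_domain_from_id(u) for u in user_ids)
    # -> frozenset(["hs1.example.com", "hs2.example.com"])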
@cached(max_entries=500000, iterable=True)
def get_users_in_room(self, room_id):
def f(txn):
-
- rows = self._get_members_rows_txn(
- txn,
- room_id=room_id,
- membership=Membership.JOIN,
+ sql = (
+ "SELECT m.user_id FROM room_memberships as m"
+ " INNER JOIN current_state_events as c"
+ " ON m.event_id = c.event_id "
+ " AND m.room_id = c.room_id "
+ " AND m.user_id = c.state_key"
+ " WHERE c.type = 'm.room.member' AND c.room_id = ? AND m.membership = ?"
)
- return [r["user_id"] for r in rows]
+ txn.execute(sql, (room_id, Membership.JOIN,))
+ return [r[0] for r in txn]
return self.runInteraction("get_users_in_room", f)
@cached()
@@ -246,52 +259,27 @@ class RoomMemberStore(SQLBaseStore):
return results
- def _get_members_rows_txn(self, txn, room_id, membership=None, user_id=None):
- where_clause = "c.room_id = ?"
- where_values = [room_id]
-
- if membership:
- where_clause += " AND m.membership = ?"
- where_values.append(membership)
-
- if user_id:
- where_clause += " AND m.user_id = ?"
- where_values.append(user_id)
-
- sql = (
- "SELECT m.* FROM room_memberships as m"
- " INNER JOIN current_state_events as c"
- " ON m.event_id = c.event_id "
- " AND m.room_id = c.room_id "
- " AND m.user_id = c.state_key"
- " WHERE c.type = 'm.room.member' AND %(where)s"
- ) % {
- "where": where_clause,
- }
-
- txn.execute(sql, where_values)
- rows = self.cursor_to_dict(txn)
-
- return rows
-
- @cached(max_entries=500000, iterable=True)
+ @cachedInlineCallbacks(max_entries=500000, iterable=True)
def get_rooms_for_user(self, user_id):
- return self.get_rooms_for_user_where_membership_is(
+ """Returns a set of room_ids the user is currently joined to
+ """
+ rooms = yield self.get_rooms_for_user_where_membership_is(
user_id, membership_list=[Membership.JOIN],
)
+ defer.returnValue(frozenset(r.room_id for r in rooms))
@cachedInlineCallbacks(max_entries=500000, cache_context=True, iterable=True)
def get_users_who_share_room_with_user(self, user_id, cache_context):
"""Returns the set of users who share a room with `user_id`
"""
- rooms = yield self.get_rooms_for_user(
+ room_ids = yield self.get_rooms_for_user(
user_id, on_invalidate=cache_context.invalidate,
)
user_who_share_room = set()
- for room in rooms:
+ for room_id in room_ids:
user_ids = yield self.get_users_in_room(
- room.room_id, on_invalidate=cache_context.invalidate,
+ room_id, on_invalidate=cache_context.invalidate,
)
user_who_share_room.update(user_ids)
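Note the contract change in `get_rooms_for_user` above: it now resolves to a frozenset of room id strings rather than RoomsForUser tuples, which is why `get_users_who_share_room_with_user` iterates `room_id` values directly. A before/after sketch for a caller (hypothetical `store` handle, inside an inlineCallbacks generator):

    # Before: rows with attributes
    # rooms = yield store.get_rooms_for_user("@alice:example.com")
    # room_ids = [room.room_id for room in rooms]

    # After: a frozenset of room ids
    room_ids = yield store.get_rooms_for_user("@alice:example.com")
    for room_id in room_ids:
        print(room_id)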
diff --git a/synapse/storage/signatures.py b/synapse/storage/signatures.py
index e1dca927d7..67d5d9969a 100644
--- a/synapse/storage/signatures.py
+++ b/synapse/storage/signatures.py
@@ -72,7 +72,7 @@ class SignatureStore(SQLBaseStore):
" WHERE event_id = ?"
)
txn.execute(query, (event_id, ))
- return {k: v for k, v in txn.fetchall()}
+ return {k: v for k, v in txn}
def _store_event_reference_hashes_txn(self, txn, events):
"""Store a hash for a PDU
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 84482d8285..fb23f6f462 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -14,7 +14,7 @@
# limitations under the License.
from ._base import SQLBaseStore
-from synapse.util.caches.descriptors import cached, cachedList
+from synapse.util.caches.descriptors import cached, cachedList, cachedInlineCallbacks
from synapse.util.caches import intern_string
from synapse.storage.engines import PostgresEngine
@@ -69,6 +69,18 @@ class StateStore(SQLBaseStore):
where_clause="type='m.room.member'",
)
+ @cachedInlineCallbacks(max_entries=100000, iterable=True)
+ def get_current_state_ids(self, room_id):
+ rows = yield self._simple_select_list(
+ table="current_state_events",
+ keyvalues={"room_id": room_id},
+ retcols=["event_id", "type", "state_key"],
+ desc="_calculate_state_delta",
+ )
+ defer.returnValue({
+ (r["type"], r["state_key"]): r["event_id"] for r in rows
+ })
+
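The mapping this resolves to is keyed by (event type, state key) pairs, which is exactly the shape the rewritten delta calculation in events.py consumes via `itervalues()`/`iteritems()`. An illustrative sketch (event ids hypothetical):

    current_state_ids = {
        ("m.room.member", "@alice:example.com"): "$membership_event_id",
        ("m.room.name", ""): "$name_event_id",
    }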
@defer.inlineCallbacks
def get_state_groups_ids(self, room_id, event_ids):
if not event_ids:
@@ -78,7 +90,7 @@ class StateStore(SQLBaseStore):
event_ids,
)
- groups = set(event_to_groups.values())
+ groups = set(event_to_groups.itervalues())
group_to_state = yield self._get_state_for_groups(groups)
defer.returnValue(group_to_state)
@@ -96,17 +108,18 @@ class StateStore(SQLBaseStore):
state_event_map = yield self.get_events(
[
- ev_id for group_ids in group_to_ids.values()
- for ev_id in group_ids.values()
+ ev_id for group_ids in group_to_ids.itervalues()
+ for ev_id in group_ids.itervalues()
],
get_prev_content=False
)
defer.returnValue({
group: [
- state_event_map[v] for v in event_id_map.values() if v in state_event_map
+ state_event_map[v] for v in event_id_map.itervalues()
+ if v in state_event_map
]
- for group, event_id_map in group_to_ids.items()
+ for group, event_id_map in group_to_ids.iteritems()
})
def _have_persisted_state_group_txn(self, txn, state_group):
@@ -124,6 +137,16 @@ class StateStore(SQLBaseStore):
continue
if context.current_state_ids is None:
+ # AFAIK, this can never happen
+ logger.error(
+ "Non-outlier event %s had current_state_ids==None",
+ event.event_id)
+ continue
+
+ # if the event was rejected, just give it the same state as its
+ # predecessor.
+ if context.rejected:
+ state_groups[event.event_id] = context.prev_group
continue
state_groups[event.event_id] = context.state_group
@@ -168,7 +191,7 @@ class StateStore(SQLBaseStore):
"state_key": key[1],
"event_id": state_id,
}
- for key, state_id in context.delta_ids.items()
+ for key, state_id in context.delta_ids.iteritems()
],
)
else:
@@ -183,7 +206,7 @@ class StateStore(SQLBaseStore):
"state_key": key[1],
"event_id": state_id,
}
- for key, state_id in context.current_state_ids.items()
+ for key, state_id in context.current_state_ids.iteritems()
],
)
@@ -195,7 +218,7 @@ class StateStore(SQLBaseStore):
"state_group": state_group_id,
"event_id": event_id,
}
- for event_id, state_group_id in state_groups.items()
+ for event_id, state_group_id in state_groups.iteritems()
],
)
@@ -319,10 +342,10 @@ class StateStore(SQLBaseStore):
args.extend(where_args)
txn.execute(sql % (where_clause,), args)
- rows = self.cursor_to_dict(txn)
- for row in rows:
- key = (row["type"], row["state_key"])
- results[group][key] = row["event_id"]
+ for row in txn:
+ typ, state_key, event_id = row
+ key = (typ, state_key)
+ results[group][key] = event_id
else:
if types is not None:
where_clause = "AND (%s)" % (
@@ -351,12 +374,11 @@ class StateStore(SQLBaseStore):
" WHERE state_group = ? %s" % (where_clause,),
args
)
- rows = txn.fetchall()
- results[group].update({
- (typ, state_key): event_id
- for typ, state_key, event_id in rows
+ results[group].update(
+ ((typ, state_key), event_id)
+ for typ, state_key, event_id in txn
if (typ, state_key) not in results[group]
- })
+ )
# If the lengths match then we must have all the types,
# so there is no need to walk further down the tree.
@@ -393,21 +415,21 @@ class StateStore(SQLBaseStore):
event_ids,
)
- groups = set(event_to_groups.values())
+ groups = set(event_to_groups.itervalues())
group_to_state = yield self._get_state_for_groups(groups, types)
state_event_map = yield self.get_events(
- [ev_id for sd in group_to_state.values() for ev_id in sd.values()],
+ [ev_id for sd in group_to_state.itervalues() for ev_id in sd.itervalues()],
get_prev_content=False
)
event_to_state = {
event_id: {
k: state_event_map[v]
- for k, v in group_to_state[group].items()
+ for k, v in group_to_state[group].iteritems()
if v in state_event_map
}
- for event_id, group in event_to_groups.items()
+ for event_id, group in event_to_groups.iteritems()
}
defer.returnValue({event: event_to_state[event] for event in event_ids})
@@ -430,12 +452,12 @@ class StateStore(SQLBaseStore):
event_ids,
)
- groups = set(event_to_groups.values())
+ groups = set(event_to_groups.itervalues())
group_to_state = yield self._get_state_for_groups(groups, types)
event_to_state = {
event_id: group_to_state[group]
- for event_id, group in event_to_groups.items()
+ for event_id, group in event_to_groups.iteritems()
}
defer.returnValue({event: event_to_state[event] for event in event_ids})
@@ -474,7 +496,7 @@ class StateStore(SQLBaseStore):
state_map = yield self.get_state_ids_for_events([event_id], types)
defer.returnValue(state_map[event_id])
- @cached(num_args=2, max_entries=10000)
+ @cached(num_args=2, max_entries=100000)
def _get_state_group_for_event(self, room_id, event_id):
return self._simple_select_one_onecol(
table="event_to_state_groups",
@@ -547,7 +569,7 @@ class StateStore(SQLBaseStore):
got_all = not (missing_types or types is None)
return {
- k: v for k, v in state_dict_ids.items()
+ k: v for k, v in state_dict_ids.iteritems()
if include(k[0], k[1])
}, missing_types, got_all
@@ -606,7 +628,7 @@ class StateStore(SQLBaseStore):
# Now we want to update the cache with all the things we fetched
# from the database.
- for group, group_state_dict in group_to_state_dict.items():
+ for group, group_state_dict in group_to_state_dict.iteritems():
if types:
# We deliberately put key -> None mappings into the cache to
# cache absence of the key, on the assumption that if we've
@@ -621,10 +643,10 @@ class StateStore(SQLBaseStore):
else:
state_dict = results[group]
- state_dict.update({
- (intern_string(k[0]), intern_string(k[1])): v
- for k, v in group_state_dict.items()
- })
+ state_dict.update(
+ ((intern_string(k[0]), intern_string(k[1])), v)
+ for k, v in group_state_dict.iteritems()
+ )
self._state_group_cache.update(
cache_seq_num,
@@ -635,10 +657,10 @@ class StateStore(SQLBaseStore):
# Remove all the entries with None values. The None values were just
# used for bookkeeping in the cache.
- for group, state_dict in results.items():
+ for group, state_dict in results.iteritems():
results[group] = {
key: event_id
- for key, event_id in state_dict.items()
+ for key, event_id in state_dict.iteritems()
if event_id
}
@@ -727,7 +749,7 @@ class StateStore(SQLBaseStore):
# of keys
delta_state = {
- key: value for key, value in curr_state.items()
+ key: value for key, value in curr_state.iteritems()
if prev_state.get(key, None) != value
}
@@ -767,7 +789,7 @@ class StateStore(SQLBaseStore):
"state_key": key[1],
"event_id": state_id,
}
- for key, state_id in delta_state.items()
+ for key, state_id in delta_state.iteritems()
],
)
diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py
index 200d124632..dddd5fc0e7 100644
--- a/synapse/storage/stream.py
+++ b/synapse/storage/stream.py
@@ -829,3 +829,6 @@ class StreamStore(SQLBaseStore):
updatevalues={"stream_id": stream_id},
desc="update_federation_out_pos",
)
+
+ def has_room_changed_since(self, room_id, stream_id):
+ return self._events_stream_cache.has_entity_changed(room_id, stream_id)
diff --git a/synapse/storage/tags.py b/synapse/storage/tags.py
index 5a2c1aa59b..bff73f3f04 100644
--- a/synapse/storage/tags.py
+++ b/synapse/storage/tags.py
@@ -95,7 +95,7 @@ class TagsStore(SQLBaseStore):
for stream_id, user_id, room_id in tag_ids:
txn.execute(sql, (user_id, room_id))
tags = []
- for tag, content in txn.fetchall():
+ for tag, content in txn:
tags.append(json.dumps(tag) + ":" + content)
tag_json = "{" + ",".join(tags) + "}"
results.append((stream_id, user_id, room_id, tag_json))
@@ -132,7 +132,7 @@ class TagsStore(SQLBaseStore):
" WHERE user_id = ? AND stream_id > ?"
)
txn.execute(sql, (user_id, stream_id))
- room_ids = [row[0] for row in txn.fetchall()]
+ room_ids = [row[0] for row in txn]
return room_ids
changed = self._account_data_stream_cache.has_entity_changed(
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 46cf93ff87..95031dc9ec 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -30,6 +30,17 @@ class IdGenerator(object):
def _load_current_id(db_conn, table, column, step=1):
+ """
+
+ Args:
+ db_conn (object):
+ table (str):
+ column (str):
+ step (int):
+
+ Returns:
+ int
+ """
cur = db_conn.cursor()
if step == 1:
cur.execute("SELECT MAX(%s) FROM %s" % (column, table,))
@@ -131,6 +142,9 @@ class StreamIdGenerator(object):
def get_current_token(self):
"""Returns the maximum stream id such that all stream ids less than or
equal to it have been successfully persisted.
+
+ Returns:
+ int
"""
with self._lock:
if self._unfinished_ids:
|