diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 75af44d787..c6ce65b4cc 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -37,6 +37,9 @@ from .rejections import RejectionsStore
from .state import StateStore
from .signatures import SignatureStore
from .filtering import FilteringStore
+from .end_to_end_keys import EndToEndKeyStore
+
+from .receipts import ReceiptsStore
import fnmatch
@@ -51,7 +54,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts.
-SCHEMA_VERSION = 19
+SCHEMA_VERSION = 21
dir_path = os.path.abspath(os.path.dirname(__file__))
@@ -74,6 +77,8 @@ class DataStore(RoomMemberStore, RoomStore,
PushRuleStore,
ApplicationServiceTransactionStore,
EventsStore,
+ ReceiptsStore,
+ EndToEndKeyStore,
):
def __init__(self, hs):
@@ -94,7 +99,7 @@ class DataStore(RoomMemberStore, RoomStore,
key = (user.to_string(), access_token, device_id, ip)
try:
- last_seen = self.client_ip_last_seen.get(*key)
+ last_seen = self.client_ip_last_seen.get(key)
except KeyError:
last_seen = None
@@ -102,7 +107,7 @@ class DataStore(RoomMemberStore, RoomStore,
if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY:
defer.returnValue(None)
- self.client_ip_last_seen.prefill(*key + (now,))
+ self.client_ip_last_seen.prefill(key, now)
# It's safe not to lock here: a) no unique constraint,
# b) LAST_SEEN_GRANULARITY makes concurrent updates incredibly unlikely
@@ -348,7 +353,12 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files,
module_name, absolute_path, python_file
)
logger.debug("Running script %s", relative_path)
- module.run_upgrade(cur)
+ module.run_upgrade(cur, database_engine)
+ elif ext == ".pyc":
+ # Sometimes .pyc files turn up anyway even though we've
+ # disabled their generation; e.g. from distribution package
+ # installers. Silently skip them
+ pass
elif ext == ".sql":
# A plain old .sql file, just read and execute it
logger.debug("Applying schema %s", relative_path)
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 39884c2afe..73eea157a4 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -15,6 +15,7 @@
import logging
from synapse.api.errors import StoreError
+from synapse.util.async import ObservableDeferred
from synapse.util.logutils import log_function
from synapse.util.logcontext import preserve_context_over_fn, LoggingContext
from synapse.util.lrucache import LruCache
@@ -27,6 +28,7 @@ from twisted.internet import defer
from collections import namedtuple, OrderedDict
import functools
+import inspect
import sys
import time
import threading
@@ -55,9 +57,12 @@ cache_counter = metrics.register_cache(
)
+_CacheSentinel = object()
+
+
class Cache(object):
- def __init__(self, name, max_entries=1000, keylen=1, lru=False):
+ def __init__(self, name, max_entries=1000, keylen=1, lru=True):
if lru:
self.cache = LruCache(max_size=max_entries)
self.max_entries = None
@@ -81,45 +86,44 @@ class Cache(object):
"Cache objects can only be accessed from the main thread"
)
- def get(self, *keyargs):
- if len(keyargs) != self.keylen:
- raise ValueError("Expected a key to have %d items", self.keylen)
-
- if keyargs in self.cache:
+ def get(self, key, default=_CacheSentinel):
+ val = self.cache.get(key, _CacheSentinel)
+ if val is not _CacheSentinel:
cache_counter.inc_hits(self.name)
- return self.cache[keyargs]
+ return val
cache_counter.inc_misses(self.name)
- raise KeyError()
- def update(self, sequence, *args):
+ if default is _CacheSentinel:
+ raise KeyError()
+ else:
+ return default
+
+ def update(self, sequence, key, value):
self.check_thread()
if self.sequence == sequence:
# Only update the cache if the caches sequence number matches the
# number that the cache had before the SELECT was started (SYN-369)
- self.prefill(*args)
-
- def prefill(self, *args): # because I can't *keyargs, value
- keyargs = args[:-1]
- value = args[-1]
-
- if len(keyargs) != self.keylen:
- raise ValueError("Expected a key to have %d items", self.keylen)
+ self.prefill(key, value)
+ def prefill(self, key, value):
if self.max_entries is not None:
while len(self.cache) >= self.max_entries:
self.cache.popitem(last=False)
- self.cache[keyargs] = value
+ self.cache[key] = value
- def invalidate(self, *keyargs):
+ def invalidate(self, key):
self.check_thread()
- if len(keyargs) != self.keylen:
- raise ValueError("Expected a key to have %d items", self.keylen)
+ if not isinstance(key, tuple):
+ raise TypeError(
+ "The cache key must be a tuple not %r" % (type(key),)
+ )
+
# Increment the sequence number so that any SELECT statements that
# raced with the INSERT don't update the cache (SYN-369)
self.sequence += 1
- self.cache.pop(keyargs, None)
+ self.cache.pop(key, None)
def invalidate_all(self):
self.check_thread()
@@ -127,9 +131,12 @@ class Cache(object):
self.cache.clear()
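Note the shape of the new Cache API: keys are always tuples (which is why the invalidation calls later in this patch become `invalidate((room_id,))`), and a private sentinel distinguishes "no default supplied" from "default is None". A minimal standalone sketch of that pattern, with the LRU eviction and sequence machinery elided:

    _CacheSentinel = object()

    class TupleKeyCache(object):
        """Illustrative only: mirrors the get/prefill/invalidate semantics."""
        def __init__(self):
            self.cache = {}

        def get(self, key, default=_CacheSentinel):
            val = self.cache.get(key, _CacheSentinel)
            if val is not _CacheSentinel:
                return val                      # hit
            if default is _CacheSentinel:
                raise KeyError(key)             # miss, no default supplied
            return default                      # miss, caller-supplied default

        def prefill(self, key, value):
            self.cache[key] = value

        def invalidate(self, key):
            if not isinstance(key, tuple):
                raise TypeError(
                    "The cache key must be a tuple not %r" % (type(key),)
                )
            self.cache.pop(key, None)

    # usage: single-argument keys are 1-tuples
    c = TupleKeyCache()
    c.prefill(("!room:example.com",), ["#alias:example.com"])
    assert c.get(("!room:example.com",)) == ["#alias:example.com"]
    c.invalidate(("!room:example.com",))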
-def cached(max_entries=1000, num_args=1, lru=False):
+class CacheDescriptor(object):
""" A method decorator that applies a memoizing cache around the function.
+ This caches deferreds, rather than the results themselves. Deferreds that
+ fail are removed from the cache.
+
The function is presumed to take zero or more arguments, which are used in
a tuple as the key for the cache. Hits are served directly from the cache;
misses use the function body to generate the value.
@@ -141,47 +148,108 @@ def cached(max_entries=1000, num_args=1, lru=False):
which can be used to insert values into the cache specifically, without
calling the calculation function.
"""
- def wrap(orig):
- cache = Cache(
- name=orig.__name__,
- max_entries=max_entries,
- keylen=num_args,
- lru=lru,
+ def __init__(self, orig, max_entries=1000, num_args=1, lru=True,
+ inlineCallbacks=False):
+ self.orig = orig
+
+ if inlineCallbacks:
+ self.function_to_call = defer.inlineCallbacks(orig)
+ else:
+ self.function_to_call = orig
+
+ self.max_entries = max_entries
+ self.num_args = num_args
+ self.lru = lru
+
+ self.arg_names = inspect.getargspec(orig).args[1:num_args+1]
+
+ if len(self.arg_names) < self.num_args:
+ raise Exception(
+ "Not enough explicit positional arguments to key off of for %r."
+ " (@cached cannot key off of *args or **kwars)"
+ % (orig.__name__,)
+ )
+
+ self.cache = Cache(
+ name=self.orig.__name__,
+ max_entries=self.max_entries,
+ keylen=self.num_args,
+ lru=self.lru,
)
- @functools.wraps(orig)
- @defer.inlineCallbacks
- def wrapped(self, *keyargs):
+ def __get__(self, obj, objtype=None):
+
+ @functools.wraps(self.orig)
+ def wrapped(*args, **kwargs):
+ arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
+ cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names)
try:
- cached_result = cache.get(*keyargs)
+ cached_result_d = self.cache.get(cache_key)
+
+ observer = cached_result_d.observe()
if DEBUG_CACHES:
- actual_result = yield orig(self, *keyargs)
- if actual_result != cached_result:
- logger.error(
- "Stale cache entry %s%r: cached: %r, actual %r",
- orig.__name__, keyargs,
- cached_result, actual_result,
- )
- raise ValueError("Stale cache entry")
- defer.returnValue(cached_result)
+ @defer.inlineCallbacks
+ def check_result(cached_result):
+ actual_result = yield self.function_to_call(obj, *args, **kwargs)
+ if actual_result != cached_result:
+ logger.error(
+ "Stale cache entry %s%r: cached: %r, actual %r",
+ self.orig.__name__, cache_key,
+ cached_result, actual_result,
+ )
+ raise ValueError("Stale cache entry")
+ defer.returnValue(cached_result)
+ observer.addCallback(check_result)
+
+ return observer
except KeyError:
# Get the sequence number of the cache before reading from the
# database so that we can tell if the cache is invalidated
# while the SELECT is executing (SYN-369)
- sequence = cache.sequence
+ sequence = self.cache.sequence
+
+ ret = defer.maybeDeferred(
+ self.function_to_call,
+ obj, *args, **kwargs
+ )
+
+ def onErr(f):
+ self.cache.invalidate(cache_key)
+ return f
+
+ ret.addErrback(onErr)
+
+ ret = ObservableDeferred(ret, consumeErrors=True)
+ self.cache.update(sequence, cache_key, ret)
- ret = yield orig(self, *keyargs)
+ return ret.observe()
- cache.update(sequence, *keyargs + (ret,))
+ wrapped.invalidate = self.cache.invalidate
+ wrapped.invalidate_all = self.cache.invalidate_all
+ wrapped.prefill = self.cache.prefill
- defer.returnValue(ret)
+ obj.__dict__[self.orig.__name__] = wrapped
- wrapped.invalidate = cache.invalidate
- wrapped.invalidate_all = cache.invalidate_all
- wrapped.prefill = cache.prefill
return wrapped
- return wrap
+
+def cached(max_entries=1000, num_args=1, lru=True):
+ return lambda orig: CacheDescriptor(
+ orig,
+ max_entries=max_entries,
+ num_args=num_args,
+ lru=lru
+ )
+
+
+def cachedInlineCallbacks(max_entries=1000, num_args=1, lru=False):
+ return lambda orig: CacheDescriptor(
+ orig,
+ max_entries=max_entries,
+ num_args=num_args,
+ lru=lru,
+ inlineCallbacks=True,
+ )
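The key behavioural change in CacheDescriptor is that it caches the ObservableDeferred itself, not the resolved value: concurrent callers share one in-flight database hit, each gets an independent deferred via observe(), and failed deferreds evict themselves. A stripped-down sketch of that core idea, outside the descriptor machinery:

    from twisted.internet import defer
    from synapse.util.async import ObservableDeferred

    _pending = {}  # cache_key -> ObservableDeferred

    def get_or_compute(cache_key, compute):
        """Run compute() at most once per key; share the result with all callers."""
        if cache_key in _pending:
            return _pending[cache_key].observe()  # fresh deferred per caller

        d = defer.maybeDeferred(compute)

        def on_err(f):
            _pending.pop(cache_key, None)         # failures are not cached
            return f
        d.addErrback(on_err)

        _pending[cache_key] = ObservableDeferred(d, consumeErrors=True)
        return _pending[cache_key].observe()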
class LoggingTransaction(object):
@@ -312,13 +380,14 @@ class SQLBaseStore(object):
self.database_engine = hs.database_engine
- self._stream_id_gen = StreamIdGenerator()
+ self._stream_id_gen = StreamIdGenerator("events", "stream_ordering")
self._transaction_id_gen = IdGenerator("sent_transactions", "id", self)
self._state_groups_id_gen = IdGenerator("state_groups", "id", self)
self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self)
self._pushers_id_gen = IdGenerator("pushers", "id", self)
self._push_rule_id_gen = IdGenerator("push_rules", "id", self)
self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self)
+ self._receipts_id_gen = StreamIdGenerator("receipts_linearized", "stream_id")
def start_profiling(self):
self._previous_loop_ts = self._clock.time_msec()
diff --git a/synapse/storage/directory.py b/synapse/storage/directory.py
index 2b2bdf8615..f3947bbe89 100644
--- a/synapse/storage/directory.py
+++ b/synapse/storage/directory.py
@@ -104,7 +104,7 @@ class DirectoryStore(SQLBaseStore):
},
desc="create_room_alias_association",
)
- self.get_aliases_for_room.invalidate(room_id)
+ self.get_aliases_for_room.invalidate((room_id,))
@defer.inlineCallbacks
def delete_room_alias(self, room_alias):
@@ -114,7 +114,7 @@ class DirectoryStore(SQLBaseStore):
room_alias,
)
- self.get_aliases_for_room.invalidate(room_id)
+ self.get_aliases_for_room.invalidate((room_id,))
defer.returnValue(room_id)
def _delete_room_alias_txn(self, txn, room_alias):
diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py
new file mode 100644
index 0000000000..325740d7d0
--- /dev/null
+++ b/synapse/storage/end_to_end_keys.py
@@ -0,0 +1,125 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from _base import SQLBaseStore
+
+
+class EndToEndKeyStore(SQLBaseStore):
+ def set_e2e_device_keys(self, user_id, device_id, time_now, json_bytes):
+ return self._simple_upsert(
+ table="e2e_device_keys_json",
+ keyvalues={
+ "user_id": user_id,
+ "device_id": device_id,
+ },
+ values={
+ "ts_added_ms": time_now,
+ "key_json": json_bytes,
+ }
+ )
+
+ def get_e2e_device_keys(self, query_list):
+ """Fetch a list of device keys.
+ Args:
+ query_list(list): List of pairs of user_ids and device_ids.
+ Returns:
+ Dict mapping from user_id to dict mapping from device_id to
+ key json byte strings.
+ """
+ def _get_e2e_device_keys(txn):
+ result = {}
+ for user_id, device_id in query_list:
+ user_result = result.setdefault(user_id, {})
+ keyvalues = {"user_id": user_id}
+ if device_id:
+ keyvalues["device_id"] = device_id
+ rows = self._simple_select_list_txn(
+ txn, table="e2e_device_keys_json",
+ keyvalues=keyvalues,
+ retcols=["device_id", "key_json"]
+ )
+ for row in rows:
+ user_result[row["device_id"]] = row["key_json"]
+ return result
+ return self.runInteraction("get_e2e_device_keys", _get_e2e_device_keys)
+
+ def add_e2e_one_time_keys(self, user_id, device_id, time_now, key_list):
+ def _add_e2e_one_time_keys(txn):
+ for (algorithm, key_id, json_bytes) in key_list:
+ self._simple_upsert_txn(
+ txn, table="e2e_one_time_keys_json",
+ keyvalues={
+ "user_id": user_id,
+ "device_id": device_id,
+ "algorithm": algorithm,
+ "key_id": key_id,
+ },
+ values={
+ "ts_added_ms": time_now,
+ "key_json": json_bytes,
+ }
+ )
+ return self.runInteraction(
+ "add_e2e_one_time_keys", _add_e2e_one_time_keys
+ )
+
+ def count_e2e_one_time_keys(self, user_id, device_id):
+ """ Count the number of one time keys the server has for a device
+ Returns:
+ Dict mapping from algorithm to number of keys for that algorithm.
+ """
+ def _count_e2e_one_time_keys(txn):
+ sql = (
+ "SELECT algorithm, COUNT(key_id) FROM e2e_one_time_keys_json"
+ " WHERE user_id = ? AND device_id = ?"
+ " GROUP BY algorithm"
+ )
+ txn.execute(sql, (user_id, device_id))
+ result = {}
+ for algorithm, key_count in txn.fetchall():
+ result[algorithm] = key_count
+ return result
+ return self.runInteraction(
+ "count_e2e_one_time_keys", _count_e2e_one_time_keys
+ )
+
+ def claim_e2e_one_time_keys(self, query_list):
+ """Take a list of one time keys out of the database"""
+ def _claim_e2e_one_time_keys(txn):
+ sql = (
+ "SELECT key_id, key_json FROM e2e_one_time_keys_json"
+ " WHERE user_id = ? AND device_id = ? AND algorithm = ?"
+ " LIMIT 1"
+ )
+ result = {}
+ delete = []
+ for user_id, device_id, algorithm in query_list:
+ user_result = result.setdefault(user_id, {})
+ device_result = user_result.setdefault(device_id, {})
+ txn.execute(sql, (user_id, device_id, algorithm))
+ for key_id, key_json in txn.fetchall():
+ device_result[algorithm + ":" + key_id] = key_json
+ delete.append((user_id, device_id, algorithm, key_id))
+ sql = (
+ "DELETE FROM e2e_one_time_keys_json"
+ " WHERE user_id = ? AND device_id = ? AND algorithm = ?"
+ " AND key_id = ?"
+ )
+ for user_id, device_id, algorithm, key_id in delete:
+ txn.execute(sql, (user_id, device_id, algorithm, key_id))
+ return result
+ return self.runInteraction(
+ "claim_e2e_one_time_keys", _claim_e2e_one_time_keys
+ )
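A hypothetical caller of the new store, e.g. a key-upload path, would tie these methods together roughly as follows (everything except the store methods is invented for illustration):

    from twisted.internet import defer

    @defer.inlineCallbacks
    def upload_keys(store, user_id, device_id, time_now,
                    device_keys_json, one_time_keys):
        # one_time_keys is a list of (algorithm, key_id, json_bytes) triples
        yield store.set_e2e_device_keys(
            user_id, device_id, time_now, device_keys_json
        )
        yield store.add_e2e_one_time_keys(
            user_id, device_id, time_now, one_time_keys
        )
        counts = yield store.count_e2e_one_time_keys(user_id, device_id)
        defer.returnValue(counts)  # e.g. {"signed_curve25519": 50}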
diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py
index 1ba073884b..910b6598a7 100644
--- a/synapse/storage/event_federation.py
+++ b/synapse/storage/event_federation.py
@@ -49,14 +49,22 @@ class EventFederationStore(SQLBaseStore):
results = set()
base_sql = (
- "SELECT auth_id FROM event_auth WHERE event_id = ?"
+ "SELECT auth_id FROM event_auth WHERE event_id IN (%s)"
)
front = set(event_ids)
while front:
new_front = set()
- for f in front:
- txn.execute(base_sql, (f,))
+ front_list = list(front)
+ chunks = [
+ front_list[x:x+100]
+ for x in xrange(0, len(front), 100)
+ ]
+ for chunk in chunks:
+ txn.execute(
+ base_sql % (",".join(["?"] * len(chunk)),),
+ chunk
+ )
new_front.update([r[0] for r in txn.fetchall()])
new_front -= results
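Batching the IN clause bounds the number of bound parameters per statement (some SQLite builds cap this at 999) while still cutting round trips by up to 100x versus one query per id. The same pattern as a generic helper:

    def execute_in_chunks(txn, sql_template, ids, chunk_size=100):
        """Run sql_template over ids in fixed-size chunks.

        sql_template must contain a single %s where the placeholder list
        goes, e.g. "SELECT auth_id FROM event_auth WHERE event_id IN (%s)".
        """
        rows = []
        for i in xrange(0, len(ids), chunk_size):
            chunk = ids[i:i + chunk_size]
            txn.execute(sql_template % (",".join(["?"] * len(chunk)),), chunk)
            rows.extend(txn.fetchall())
        return rows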
@@ -274,8 +282,7 @@ class EventFederationStore(SQLBaseStore):
},
)
- def _handle_prev_events(self, txn, outlier, event_id, prev_events,
- room_id):
+ def _handle_mult_prev_events(self, txn, events):
"""
For the given event, update the event edges table and forward and
backward extremities tables.
@@ -285,70 +292,77 @@ class EventFederationStore(SQLBaseStore):
table="event_edges",
values=[
{
- "event_id": event_id,
+ "event_id": ev.event_id,
"prev_event_id": e_id,
- "room_id": room_id,
+ "room_id": ev.room_id,
"is_state": False,
}
- for e_id, _ in prev_events
+ for ev in events
+ for e_id, _ in ev.prev_events
],
)
- # Update the extremities table if this is not an outlier.
- if not outlier:
- for e_id, _ in prev_events:
- # TODO (erikj): This could be done as a bulk insert
- self._simple_delete_txn(
- txn,
- table="event_forward_extremities",
- keyvalues={
- "event_id": e_id,
- "room_id": room_id,
- }
- )
+ events_by_room = {}
+ for ev in events:
+ events_by_room.setdefault(ev.room_id, []).append(ev)
- # We only insert as a forward extremity the new event if there are
- # no other events that reference it as a prev event
- query = (
- "SELECT 1 FROM event_edges WHERE prev_event_id = ?"
- )
+ for room_id, room_events in events_by_room.items():
+ prevs = [
+ e_id for ev in room_events for e_id, _ in ev.prev_events
+ if not ev.internal_metadata.is_outlier()
+ ]
+ if prevs:
+ txn.execute(
+ "DELETE FROM event_forward_extremities"
+ " WHERE room_id = ?"
+ " AND event_id in (%s)" % (
+ ",".join(["?"] * len(prevs)),
+ ),
+ [room_id] + prevs,
+ )
- txn.execute(query, (event_id,))
+ query = (
+ "INSERT INTO event_forward_extremities (event_id, room_id)"
+ " SELECT ?, ? WHERE NOT EXISTS ("
+ " SELECT 1 FROM event_edges WHERE prev_event_id = ?"
+ " )"
+ )
- if not txn.fetchone():
- query = (
- "INSERT INTO event_forward_extremities"
- " (event_id, room_id)"
- " VALUES (?, ?)"
- )
+ txn.executemany(
+ query,
+ [(ev.event_id, ev.room_id, ev.event_id) for ev in events]
+ )
- txn.execute(query, (event_id, room_id))
-
- query = (
- "INSERT INTO event_backward_extremities (event_id, room_id)"
- " SELECT ?, ? WHERE NOT EXISTS ("
- " SELECT 1 FROM event_backward_extremities"
- " WHERE event_id = ? AND room_id = ?"
- " )"
- " AND NOT EXISTS ("
- " SELECT 1 FROM events WHERE event_id = ? AND room_id = ? "
- " AND outlier = ?"
- " )"
- )
+ query = (
+ "INSERT INTO event_backward_extremities (event_id, room_id)"
+ " SELECT ?, ? WHERE NOT EXISTS ("
+ " SELECT 1 FROM event_backward_extremities"
+ " WHERE event_id = ? AND room_id = ?"
+ " )"
+ " AND NOT EXISTS ("
+ " SELECT 1 FROM events WHERE event_id = ? AND room_id = ? "
+ " AND outlier = ?"
+ " )"
+ )
- txn.executemany(query, [
- (e_id, room_id, e_id, room_id, e_id, room_id, False)
- for e_id, _ in prev_events
- ])
+ txn.executemany(query, [
+ (e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id, False)
+ for ev in events for e_id, _ in ev.prev_events
+ if not ev.internal_metadata.is_outlier()
+ ])
- query = (
- "DELETE FROM event_backward_extremities"
- " WHERE event_id = ? AND room_id = ?"
- )
- txn.execute(query, (event_id, room_id))
+ query = (
+ "DELETE FROM event_backward_extremities"
+ " WHERE event_id = ? AND room_id = ?"
+ )
+ txn.executemany(
+ query,
+ [(ev.event_id, ev.room_id) for ev in events]
+ )
+ for room_id in events_by_room:
txn.call_after(
- self.get_latest_event_ids_in_room.invalidate, room_id
+ self.get_latest_event_ids_in_room.invalidate, (room_id,)
)
def get_backfill_events(self, room_id, event_list, limit):
@@ -400,10 +414,12 @@ class EventFederationStore(SQLBaseStore):
keyvalues={
"event_id": event_id,
},
- retcol="depth"
+ retcol="depth",
+ allow_none=True,
)
- queue.put((-depth, event_id))
+ if depth:
+ queue.put((-depth, event_id))
while not queue.empty() and len(event_results) < limit:
try:
@@ -489,4 +505,4 @@ class EventFederationStore(SQLBaseStore):
query = "DELETE FROM event_forward_extremities WHERE room_id = ?"
txn.execute(query, (room_id,))
- txn.call_after(self.get_latest_event_ids_in_room.invalidate, room_id)
+ txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,))
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 20a8d81794..5b64918024 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -23,9 +23,7 @@ from synapse.events.utils import prune_event
from synapse.util.logcontext import preserve_context_over_deferred
from synapse.util.logutils import log_function
from synapse.api.constants import EventTypes
-from synapse.crypto.event_signing import compute_event_reference_hash
-from syutil.base64util import decode_base64
from syutil.jsonutil import encode_json
from contextlib import contextmanager
@@ -47,6 +45,48 @@ EVENT_QUEUE_TIMEOUT_S = 0.1 # Timeout when waiting for requests for events
class EventsStore(SQLBaseStore):
@defer.inlineCallbacks
+ def persist_events(self, events_and_contexts, backfilled=False,
+ is_new_state=True):
+ if not events_and_contexts:
+ return
+
+ if backfilled:
+ if not self.min_token_deferred.called:
+ yield self.min_token_deferred
+ start = self.min_token - 1
+ self.min_token -= len(events_and_contexts) + 1
+ stream_orderings = range(start, self.min_token, -1)
+
+ @contextmanager
+ def stream_ordering_manager():
+ yield stream_orderings
+ stream_ordering_manager = stream_ordering_manager()
+ else:
+ stream_ordering_manager = yield self._stream_id_gen.get_next_mult(
+ self, len(events_and_contexts)
+ )
+
+ with stream_ordering_manager as stream_orderings:
+ for (event, _), stream in zip(events_and_contexts, stream_orderings):
+ event.internal_metadata.stream_ordering = stream
+
+ chunks = [
+ events_and_contexts[x:x+100]
+ for x in xrange(0, len(events_and_contexts), 100)
+ ]
+
+ for chunk in chunks:
+ # We can't easily parallelize these since different chunks
+ # might contain the same event. :(
+ yield self.runInteraction(
+ "persist_events",
+ self._persist_events_txn,
+ events_and_contexts=chunk,
+ backfilled=backfilled,
+ is_new_state=is_new_state,
+ )
+
+ @defer.inlineCallbacks
@log_function
def persist_event(self, event, context, backfilled=False,
is_new_state=True, current_state=None):
@@ -67,13 +107,13 @@ class EventsStore(SQLBaseStore):
try:
with stream_ordering_manager as stream_ordering:
+ event.internal_metadata.stream_ordering = stream_ordering
yield self.runInteraction(
"persist_event",
self._persist_event_txn,
event=event,
context=context,
backfilled=backfilled,
- stream_ordering=stream_ordering,
is_new_state=is_new_state,
current_state=current_state,
)
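Both branches of persist_events hand back a context manager that yields a block of stream orderings, so the `with` statement that follows need not care whether the ids were reserved by the generator or fabricated locally for backfill. A standalone sketch of that uniform-interface trick (the id generator here is a stand-in with a simplified signature):

    from contextlib import contextmanager

    def ordering_manager(backfilled, id_gen, n, min_token):
        """Return a context manager yielding n stream orderings."""
        if backfilled:
            # backfilled events get fabricated, descending negative orderings
            orderings = range(min_token - 1, min_token - 1 - n, -1)

            @contextmanager
            def manager():
                yield orderings
            return manager()
        # otherwise reserve n real ids; get_next_mult is assumed to hand back
        # a context manager, as in the diff above
        return id_gen.get_next_mult(n)

    # usage mirrors persist_events:
    #   with ordering_manager(...) as stream_orderings:
    #       for (event, _), stream in zip(events_and_contexts, stream_orderings):
    #           event.internal_metadata.stream_ordering = stream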
@@ -116,19 +156,14 @@ class EventsStore(SQLBaseStore):
@log_function
def _persist_event_txn(self, txn, event, context, backfilled,
- stream_ordering=None, is_new_state=True,
- current_state=None):
-
- # Remove the any existing cache entries for the event_id
- txn.call_after(self._invalidate_get_event_cache, event.event_id)
-
+ is_new_state=True, current_state=None):
# We purposefully do this first since if we include a `current_state`
# key, we *want* to update the `current_state_events` table
if current_state:
txn.call_after(self.get_current_state_for_key.invalidate_all)
txn.call_after(self.get_rooms_for_user.invalidate_all)
- txn.call_after(self.get_users_in_room.invalidate, event.room_id)
- txn.call_after(self.get_joined_hosts_for_room.invalidate, event.room_id)
+ txn.call_after(self.get_users_in_room.invalidate, (event.room_id,))
+ txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,))
txn.call_after(self.get_room_name_and_aliases, event.room_id)
self._simple_delete_txn(
@@ -149,37 +184,78 @@ class EventsStore(SQLBaseStore):
}
)
- outlier = event.internal_metadata.is_outlier()
+ return self._persist_events_txn(
+ txn,
+ [(event, context)],
+ backfilled=backfilled,
+ is_new_state=is_new_state,
+ )
- if not outlier:
- self._update_min_depth_for_room_txn(
- txn,
- event.room_id,
- event.depth
+ @log_function
+ def _persist_events_txn(self, txn, events_and_contexts, backfilled,
+ is_new_state=True):
+
+ # Remove any existing cache entries for the event_ids
+ for event, _ in events_and_contexts:
+ txn.call_after(self._invalidate_get_event_cache, event.event_id)
+
+ depth_updates = {}
+ for event, _ in events_and_contexts:
+ if event.internal_metadata.is_outlier():
+ continue
+ depth_updates[event.room_id] = max(
+ event.depth, depth_updates.get(event.room_id, event.depth)
)
- have_persisted = self._simple_select_one_txn(
- txn,
- table="events",
- keyvalues={"event_id": event.event_id},
- retcols=["event_id", "outlier"],
- allow_none=True,
+ for room_id, depth in depth_updates.items():
+ self._update_min_depth_for_room_txn(txn, room_id, depth)
+
+ txn.execute(
+ "SELECT event_id, outlier FROM events WHERE event_id in (%s)" % (
+ ",".join(["?"] * len(events_and_contexts)),
+ ),
+ [event.event_id for event, _ in events_and_contexts]
)
+ have_persisted = {
+ event_id: outlier
+ for event_id, outlier in txn.fetchall()
+ }
+
+ event_map = {}
+ to_remove = set()
+ for event, context in events_and_contexts:
+ # Handle the case of the list including the same event multiple
+ # times. The tricky thing here is when they differ by whether
+ # they are an outlier.
+ if event.event_id in event_map:
+ other = event_map[event.event_id]
+
+ if not other.internal_metadata.is_outlier():
+ to_remove.add(event)
+ continue
+ elif not event.internal_metadata.is_outlier():
+ to_remove.add(event)
+ continue
+ else:
+ to_remove.add(other)
- metadata_json = encode_json(
- event.internal_metadata.get_dict(),
- using_frozen_dicts=USE_FROZEN_DICTS
- ).decode("UTF-8")
-
- # If we have already persisted this event, we don't need to do any
- # more processing.
- # The processing above must be done on every call to persist event,
- # since they might not have happened on previous calls. For example,
- # if we are persisting an event that we had persisted as an outlier,
- # but is no longer one.
- if have_persisted:
- if not outlier and have_persisted["outlier"]:
- self._store_state_groups_txn(txn, event, context)
+ event_map[event.event_id] = event
+
+ if event.event_id not in have_persisted:
+ continue
+
+ to_remove.add(event)
+
+ outlier_persisted = have_persisted[event.event_id]
+ if not event.internal_metadata.is_outlier() and outlier_persisted:
+ self._store_state_groups_txn(
+ txn, event, context,
+ )
+
+ metadata_json = encode_json(
+ event.internal_metadata.get_dict(),
+ using_frozen_dicts=USE_FROZEN_DICTS
+ ).decode("UTF-8")
sql = (
"UPDATE event_json SET internal_metadata = ?"
@@ -198,94 +274,91 @@ class EventsStore(SQLBaseStore):
sql,
(False, event.event_id,)
)
- return
-
- if not outlier:
- self._store_state_groups_txn(txn, event, context)
- self._handle_prev_events(
- txn,
- outlier=outlier,
- event_id=event.event_id,
- prev_events=event.prev_events,
- room_id=event.room_id,
+ events_and_contexts = filter(
+ lambda ec: ec[0] not in to_remove,
+ events_and_contexts
)
- if event.type == EventTypes.Member:
- self._store_room_member_txn(txn, event)
- elif event.type == EventTypes.Name:
- self._store_room_name_txn(txn, event)
- elif event.type == EventTypes.Topic:
- self._store_room_topic_txn(txn, event)
- elif event.type == EventTypes.Redaction:
- self._store_redaction(txn, event)
-
- event_dict = {
- k: v
- for k, v in event.get_dict().items()
- if k not in [
- "redacted",
- "redacted_because",
- ]
- }
+ if not events_and_contexts:
+ return
- self._simple_insert_txn(
+ self._store_mult_state_groups_txn(txn, [
+ (event, context)
+ for event, context in events_and_contexts
+ if not event.internal_metadata.is_outlier()
+ ])
+
+ self._handle_mult_prev_events(
txn,
- table="event_json",
- values={
- "event_id": event.event_id,
- "room_id": event.room_id,
- "internal_metadata": metadata_json,
- "json": encode_json(
- event_dict, using_frozen_dicts=USE_FROZEN_DICTS
- ).decode("UTF-8"),
- },
+ events=[event for event, _ in events_and_contexts],
)
- content = encode_json(
- event.content, using_frozen_dicts=USE_FROZEN_DICTS
- ).decode("UTF-8")
-
- vals = {
- "topological_ordering": event.depth,
- "event_id": event.event_id,
- "type": event.type,
- "room_id": event.room_id,
- "content": content,
- "processed": True,
- "outlier": outlier,
- "depth": event.depth,
- }
+ for event, _ in events_and_contexts:
+ if event.type == EventTypes.Name:
+ self._store_room_name_txn(txn, event)
+ elif event.type == EventTypes.Topic:
+ self._store_room_topic_txn(txn, event)
+ elif event.type == EventTypes.Redaction:
+ self._store_redaction(txn, event)
- unrec = {
- k: v
- for k, v in event.get_dict().items()
- if k not in vals.keys() and k not in [
- "redacted",
- "redacted_because",
- "signatures",
- "hashes",
- "prev_events",
+ self._store_room_members_txn(
+ txn,
+ [
+ event
+ for event, _ in events_and_contexts
+ if event.type == EventTypes.Member
]
- }
+ )
- vals["unrecognized_keys"] = encode_json(
- unrec, using_frozen_dicts=USE_FROZEN_DICTS
- ).decode("UTF-8")
+ def event_dict(event):
+ return {
+ k: v
+ for k, v in event.get_dict().items()
+ if k not in [
+ "redacted",
+ "redacted_because",
+ ]
+ }
- sql = (
- "INSERT INTO events"
- " (stream_ordering, topological_ordering, event_id, type,"
- " room_id, content, processed, outlier, depth)"
- " VALUES (?,?,?,?,?,?,?,?,?)"
+ self._simple_insert_many_txn(
+ txn,
+ table="event_json",
+ values=[
+ {
+ "event_id": event.event_id,
+ "room_id": event.room_id,
+ "internal_metadata": encode_json(
+ event.internal_metadata.get_dict(),
+ using_frozen_dicts=USE_FROZEN_DICTS
+ ).decode("UTF-8"),
+ "json": encode_json(
+ event_dict(event), using_frozen_dicts=USE_FROZEN_DICTS
+ ).decode("UTF-8"),
+ }
+ for event, _ in events_and_contexts
+ ],
)
- txn.execute(
- sql,
- (
- stream_ordering, event.depth, event.event_id, event.type,
- event.room_id, content, True, outlier, event.depth
- )
+ self._simple_insert_many_txn(
+ txn,
+ table="events",
+ values=[
+ {
+ "stream_ordering": event.internal_metadata.stream_ordering,
+ "topological_ordering": event.depth,
+ "depth": event.depth,
+ "event_id": event.event_id,
+ "room_id": event.room_id,
+ "type": event.type,
+ "processed": True,
+ "outlier": event.internal_metadata.is_outlier(),
+ "content": encode_json(
+ event.content, using_frozen_dicts=USE_FROZEN_DICTS
+ ).decode("UTF-8"),
+ }
+ for event, _ in events_and_contexts
+ ],
)
if context.rejected:
@@ -293,20 +366,6 @@ class EventsStore(SQLBaseStore):
txn, event.event_id, context.rejected
)
- for hash_alg, hash_base64 in event.hashes.items():
- hash_bytes = decode_base64(hash_base64)
- self._store_event_content_hash_txn(
- txn, event.event_id, hash_alg, hash_bytes,
- )
-
- for prev_event_id, prev_hashes in event.prev_events:
- for alg, hash_base64 in prev_hashes.items():
- hash_bytes = decode_base64(hash_base64)
- self._store_prev_event_hash_txn(
- txn, event.event_id, prev_event_id, alg,
- hash_bytes
- )
-
self._simple_insert_many_txn(
txn,
table="event_auth",
@@ -316,16 +375,22 @@ class EventsStore(SQLBaseStore):
"room_id": event.room_id,
"auth_id": auth_id,
}
+ for event, _ in events_and_contexts
for auth_id, _ in event.auth_events
],
)
- (ref_alg, ref_hash_bytes) = compute_event_reference_hash(event)
- self._store_event_reference_hash_txn(
- txn, event.event_id, ref_alg, ref_hash_bytes
+ self._store_event_reference_hashes_txn(
+ txn, [event for event, _ in events_and_contexts]
)
- if event.is_state():
+ state_events_and_contexts = filter(
+ lambda i: i[0].is_state(),
+ events_and_contexts,
+ )
+
+ state_values = []
+ for event, context in state_events_and_contexts:
vals = {
"event_id": event.event_id,
"room_id": event.room_id,
@@ -337,51 +402,55 @@ class EventsStore(SQLBaseStore):
if hasattr(event, "replaces_state"):
vals["prev_state"] = event.replaces_state
- self._simple_insert_txn(
- txn,
- "state_events",
- vals,
- )
+ state_values.append(vals)
- self._simple_insert_many_txn(
- txn,
- table="event_edges",
- values=[
- {
- "event_id": event.event_id,
- "prev_event_id": e_id,
- "room_id": event.room_id,
- "is_state": True,
- }
- for e_id, h in event.prev_state
- ],
- )
+ self._simple_insert_many_txn(
+ txn,
+ table="state_events",
+ values=state_values,
+ )
- if is_new_state and not context.rejected:
- txn.call_after(
- self.get_current_state_for_key.invalidate,
- event.room_id, event.type, event.state_key
- )
+ self._simple_insert_many_txn(
+ txn,
+ table="event_edges",
+ values=[
+ {
+ "event_id": event.event_id,
+ "prev_event_id": prev_id,
+ "room_id": event.room_id,
+ "is_state": True,
+ }
+ for event, _ in state_events_and_contexts
+ for prev_id, _ in event.prev_state
+ ],
+ )
- if (event.type == EventTypes.Name
- or event.type == EventTypes.Aliases):
+ if is_new_state:
+ for event, context in state_events_and_contexts:
+ if not context.rejected:
txn.call_after(
- self.get_room_name_and_aliases.invalidate,
- event.room_id
+ self.get_current_state_for_key.invalidate,
+ (event.room_id, event.type, event.state_key,)
)
- self._simple_upsert_txn(
- txn,
- "current_state_events",
- keyvalues={
- "room_id": event.room_id,
- "type": event.type,
- "state_key": event.state_key,
- },
- values={
- "event_id": event.event_id,
- }
- )
+ if event.type in [EventTypes.Name, EventTypes.Aliases]:
+ txn.call_after(
+ self.get_room_name_and_aliases.invalidate,
+ (event.room_id,)
+ )
+
+ self._simple_upsert_txn(
+ txn,
+ "current_state_events",
+ keyvalues={
+ "room_id": event.room_id,
+ "type": event.type,
+ "state_key": event.state_key,
+ },
+ values={
+ "event_id": event.event_id,
+ }
+ )
return
@@ -498,8 +567,9 @@ class EventsStore(SQLBaseStore):
def _invalidate_get_event_cache(self, event_id):
for check_redacted in (False, True):
for get_prev_content in (False, True):
- self._get_event_cache.invalidate(event_id, check_redacted,
- get_prev_content)
+ self._get_event_cache.invalidate(
+ (event_id, check_redacted, get_prev_content)
+ )
def _get_event_txn(self, txn, event_id, check_redacted=True,
get_prev_content=False, allow_rejected=False):
@@ -520,7 +590,7 @@ class EventsStore(SQLBaseStore):
for event_id in events:
try:
ret = self._get_event_cache.get(
- event_id, check_redacted, get_prev_content
+ (event_id, check_redacted, get_prev_content,)
)
if allow_rejected or not ret.rejected_reason:
@@ -736,7 +806,8 @@ class EventsStore(SQLBaseStore):
because = yield self.get_event(
redaction_id,
- check_redacted=False
+ check_redacted=False,
+ allow_none=True,
)
if because:
@@ -746,12 +817,13 @@ class EventsStore(SQLBaseStore):
prev = yield self.get_event(
ev.unsigned["replaces_state"],
get_prev_content=False,
+ allow_none=True,
)
if prev:
ev.unsigned["prev_content"] = prev.get_dict()["content"]
self._get_event_cache.prefill(
- ev.event_id, check_redacted, get_prev_content, ev
+ (ev.event_id, check_redacted, get_prev_content), ev
)
defer.returnValue(ev)
@@ -808,7 +880,7 @@ class EventsStore(SQLBaseStore):
ev.unsigned["prev_content"] = prev.get_dict()["content"]
self._get_event_cache.prefill(
- ev.event_id, check_redacted, get_prev_content, ev
+ (ev.event_id, check_redacted, get_prev_content), ev
)
return ev
diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py
index 5bdf497b93..49b8e37cfd 100644
--- a/synapse/storage/keys.py
+++ b/synapse/storage/keys.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from _base import SQLBaseStore
+from _base import SQLBaseStore, cachedInlineCallbacks
from twisted.internet import defer
@@ -71,6 +71,24 @@ class KeyStore(SQLBaseStore):
desc="store_server_certificate",
)
+ @cachedInlineCallbacks()
+ def get_all_server_verify_keys(self, server_name):
+ rows = yield self._simple_select_list(
+ table="server_signature_keys",
+ keyvalues={
+ "server_name": server_name,
+ },
+ retcols=["key_id", "verify_key"],
+ desc="get_all_server_verify_keys",
+ )
+
+ defer.returnValue({
+ row["key_id"]: decode_verify_key_bytes(
+ row["key_id"], str(row["verify_key"])
+ )
+ for row in rows
+ })
+
@defer.inlineCallbacks
def get_server_verify_keys(self, server_name, key_ids):
"""Retrieve the NACL verification key for a given server for the given
@@ -81,24 +99,14 @@ class KeyStore(SQLBaseStore):
Returns:
(list of VerifyKey): The verification keys.
"""
- sql = (
- "SELECT key_id, verify_key FROM server_signature_keys"
- " WHERE server_name = ?"
- " AND key_id in (" + ",".join("?" for key_id in key_ids) + ")"
- )
-
- rows = yield self._execute_and_decode(
- "get_server_verify_keys", sql, server_name, *key_ids
- )
-
- keys = []
- for row in rows:
- key_id = row["key_id"]
- key_bytes = row["verify_key"]
- key = decode_verify_key_bytes(key_id, str(key_bytes))
- keys.append(key)
- defer.returnValue(keys)
+ keys = yield self.get_all_server_verify_keys(server_name)
+ defer.returnValue({
+ k: keys[k]
+ for k in key_ids
+ if k in keys and keys[k]
+ })
+ @defer.inlineCallbacks
def store_server_verify_key(self, server_name, from_server, time_now_ms,
verify_key):
"""Stores a NACL verification key for the given server.
@@ -109,7 +117,7 @@ class KeyStore(SQLBaseStore):
ts_now_ms (int): The time now in milliseconds
verification_key (VerifyKey): The NACL verify key.
"""
- return self._simple_upsert(
+ yield self._simple_upsert(
table="server_signature_keys",
keyvalues={
"server_name": server_name,
@@ -123,6 +131,8 @@ class KeyStore(SQLBaseStore):
desc="store_server_verify_key",
)
+ self.get_all_server_verify_keys.invalidate((server_name,))
+
def store_server_keys_json(self, server_name, key_id, from_server,
ts_now_ms, ts_expires_ms, key_json_bytes):
"""Stores the JSON bytes for a set of keys from a server
@@ -152,6 +162,7 @@ class KeyStore(SQLBaseStore):
"ts_valid_until_ms": ts_expires_ms,
"key_json": buffer(key_json_bytes),
},
+ desc="store_server_keys_json",
)
def get_server_keys_json(self, server_keys):
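The rewritten read path trades one query per (server, key_ids) request for a single cached fetch of all of a server's keys, filtered in Python; the write path pairs each upsert with a tuple-keyed invalidation so the next read repopulates the entry. A hypothetical caller exercising both halves:

    from twisted.internet import defer

    @defer.inlineCallbacks
    def rotate_and_lookup(store, server_name, from_server, now_ms,
                          verify_key, key_ids):
        # upserts, then invalidates get_all_server_verify_keys((server_name,))
        yield store.store_server_verify_key(
            server_name, from_server, now_ms, verify_key
        )
        # repopulates the per-server cache entry, then filters to key_ids
        keys = yield store.get_server_verify_keys(server_name, key_ids)
        defer.returnValue(keys)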
diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py
index fefcf6bce0..576cf670cc 100644
--- a/synapse/storage/presence.py
+++ b/synapse/storage/presence.py
@@ -98,7 +98,7 @@ class PresenceStore(SQLBaseStore):
updatevalues={"accepted": True},
desc="set_presence_list_accepted",
)
- self.get_presence_list_accepted.invalidate(observer_localpart)
+ self.get_presence_list_accepted.invalidate((observer_localpart,))
defer.returnValue(result)
def get_presence_list(self, observer_localpart, accepted=None):
@@ -133,4 +133,4 @@ class PresenceStore(SQLBaseStore):
"observed_user_id": observed_userid},
desc="del_presence_list",
)
- self.get_presence_list_accepted.invalidate(observer_localpart)
+ self.get_presence_list_accepted.invalidate((observer_localpart,))
diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py
index 4cac118d17..9b88ca7b39 100644
--- a/synapse/storage/push_rule.py
+++ b/synapse/storage/push_rule.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import SQLBaseStore, cached
+from ._base import SQLBaseStore, cachedInlineCallbacks
from twisted.internet import defer
import logging
@@ -23,8 +23,7 @@ logger = logging.getLogger(__name__)
class PushRuleStore(SQLBaseStore):
- @cached()
- @defer.inlineCallbacks
+ @cachedInlineCallbacks()
def get_push_rules_for_user(self, user_name):
rows = yield self._simple_select_list(
table=PushRuleTable.table_name,
@@ -41,8 +40,7 @@ class PushRuleStore(SQLBaseStore):
defer.returnValue(rows)
- @cached()
- @defer.inlineCallbacks
+ @cachedInlineCallbacks()
def get_push_rules_enabled_for_user(self, user_name):
results = yield self._simple_select_list(
table=PushRuleEnableTable.table_name,
@@ -153,11 +151,11 @@ class PushRuleStore(SQLBaseStore):
txn.execute(sql, (user_name, priority_class, new_rule_priority))
txn.call_after(
- self.get_push_rules_for_user.invalidate, user_name
+ self.get_push_rules_for_user.invalidate, (user_name,)
)
txn.call_after(
- self.get_push_rules_enabled_for_user.invalidate, user_name
+ self.get_push_rules_enabled_for_user.invalidate, (user_name,)
)
self._simple_insert_txn(
@@ -189,10 +187,10 @@ class PushRuleStore(SQLBaseStore):
new_rule['priority'] = new_prio
txn.call_after(
- self.get_push_rules_for_user.invalidate, user_name
+ self.get_push_rules_for_user.invalidate, (user_name,)
)
txn.call_after(
- self.get_push_rules_enabled_for_user.invalidate, user_name
+ self.get_push_rules_enabled_for_user.invalidate, (user_name,)
)
self._simple_insert_txn(
@@ -218,8 +216,8 @@ class PushRuleStore(SQLBaseStore):
desc="delete_push_rule",
)
- self.get_push_rules_for_user.invalidate(user_name)
- self.get_push_rules_enabled_for_user.invalidate(user_name)
+ self.get_push_rules_for_user.invalidate((user_name,))
+ self.get_push_rules_enabled_for_user.invalidate((user_name,))
@defer.inlineCallbacks
def set_push_rule_enabled(self, user_name, rule_id, enabled):
@@ -240,10 +238,10 @@ class PushRuleStore(SQLBaseStore):
{'id': new_id},
)
txn.call_after(
- self.get_push_rules_for_user.invalidate, user_name
+ self.get_push_rules_for_user.invalidate, (user_name,)
)
txn.call_after(
- self.get_push_rules_enabled_for_user.invalidate, user_name
+ self.get_push_rules_enabled_for_user.invalidate, (user_name,)
)
diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py
new file mode 100644
index 0000000000..b79d6683ca
--- /dev/null
+++ b/synapse/storage/receipts.py
@@ -0,0 +1,347 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014, 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ._base import SQLBaseStore, cachedInlineCallbacks
+
+from twisted.internet import defer
+
+from synapse.util import unwrapFirstError
+
+from blist import sorteddict
+import logging
+import ujson as json
+
+
+logger = logging.getLogger(__name__)
+
+
+class ReceiptsStore(SQLBaseStore):
+ def __init__(self, hs):
+ super(ReceiptsStore, self).__init__(hs)
+
+ self._receipts_stream_cache = _RoomStreamChangeCache()
+
+ @defer.inlineCallbacks
+ def get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
+ """Get receipts for multiple rooms for sending to clients.
+
+ Args:
+ room_ids (list): List of room_ids.
+ to_key (int): Max stream id to fetch receipts up to.
+ from_key (int): Min stream id to fetch receipts from. None fetches
+ from the start.
+
+ Returns:
+ list: A list of receipts.
+ """
+ room_ids = set(room_ids)
+
+ if from_key:
+ room_ids = yield self._receipts_stream_cache.get_rooms_changed(
+ self, room_ids, from_key
+ )
+
+ results = yield defer.gatherResults(
+ [
+ self.get_linearized_receipts_for_room(
+ room_id, to_key, from_key=from_key
+ )
+ for room_id in room_ids
+ ],
+ consumeErrors=True,
+ ).addErrback(unwrapFirstError)
+
+ defer.returnValue([ev for res in results for ev in res])
+
+ @defer.inlineCallbacks
+ def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None):
+ """Get receipts for a single room for sending to clients.
+
+ Args:
+ room_id (str): The room id.
+ to_key (int): Max stream id to fetch receipts up to.
+ from_key (int): Min stream id to fetch receipts from. None fetches
+ from the start.
+
+ Returns:
+ list: A list of receipts.
+ """
+ def f(txn):
+ if from_key:
+ sql = (
+ "SELECT * FROM receipts_linearized WHERE"
+ " room_id = ? AND stream_id > ? AND stream_id <= ?"
+ )
+
+ txn.execute(
+ sql,
+ (room_id, from_key, to_key)
+ )
+ else:
+ sql = (
+ "SELECT * FROM receipts_linearized WHERE"
+ " room_id = ? AND stream_id <= ?"
+ )
+
+ txn.execute(
+ sql,
+ (room_id, to_key)
+ )
+
+ rows = self.cursor_to_dict(txn)
+
+ return rows
+
+ rows = yield self.runInteraction(
+ "get_linearized_receipts_for_room", f
+ )
+
+ if not rows:
+ defer.returnValue([])
+
+ content = {}
+ for row in rows:
+ content.setdefault(
+ row["event_id"], {}
+ ).setdefault(
+ row["receipt_type"], {}
+ )[row["user_id"]] = json.loads(row["data"])
+
+ defer.returnValue([{
+ "type": "m.receipt",
+ "room_id": room_id,
+ "content": content,
+ }])
+
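The fold above yields at most one m.receipt event per room, with content keyed event_id -> receipt_type -> user_id. With invented identifiers, a typical return value looks like:

    # hypothetical output of get_linearized_receipts_for_room
    [{
        "type": "m.receipt",
        "room_id": "!room:example.com",
        "content": {
            "$event1:example.com": {
                "m.read": {
                    "@alice:example.com": {"ts": 1436451550453},
                    "@bob:example.com": {"ts": 1436451550460},
                },
            },
        },
    }]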
+ def get_max_receipt_stream_id(self):
+ return self._receipts_id_gen.get_max_token(self)
+
+ @cachedInlineCallbacks()
+ def get_graph_receipts_for_room(self, room_id):
+ """Get receipts for sending to remote servers.
+ """
+ rows = yield self._simple_select_list(
+ table="receipts_graph",
+ keyvalues={"room_id": room_id},
+ retcols=["receipt_type", "user_id", "event_id"],
+ desc="get_linearized_receipts_for_room",
+ )
+
+ result = {}
+ for row in rows:
+ result.setdefault(
+ row["user_id"], {}
+ ).setdefault(
+ row["receipt_type"], []
+ ).append(row["event_id"])
+
+ defer.returnValue(result)
+
+ def insert_linearized_receipt_txn(self, txn, room_id, receipt_type,
+ user_id, event_id, data, stream_id):
+
+ # We don't want to clobber receipts for more recent events, so we
+ # have to compare orderings of existing receipts
+ sql = (
+ "SELECT topological_ordering, stream_ordering, event_id FROM events"
+ " INNER JOIN receipts_linearized as r USING (event_id, room_id)"
+ " WHERE r.room_id = ? AND r.receipt_type = ? AND r.user_id = ?"
+ )
+
+ txn.execute(sql, (room_id, receipt_type, user_id))
+ results = txn.fetchall()
+
+ if results:
+ res = self._simple_select_one_txn(
+ txn,
+ table="events",
+ retcols=["topological_ordering", "stream_ordering"],
+ keyvalues={"event_id": event_id},
+ )
+ topological_ordering = int(res["topological_ordering"])
+ stream_ordering = int(res["stream_ordering"])
+
+ for to, so, _ in results:
+ if int(to) > topological_ordering:
+ return False
+ elif int(to) == topological_ordering and int(so) >= stream_ordering:
+ return False
+
+ self._simple_delete_txn(
+ txn,
+ table="receipts_linearized",
+ keyvalues={
+ "room_id": room_id,
+ "receipt_type": receipt_type,
+ "user_id": user_id,
+ }
+ )
+
+ self._simple_insert_txn(
+ txn,
+ table="receipts_linearized",
+ values={
+ "stream_id": stream_id,
+ "room_id": room_id,
+ "receipt_type": receipt_type,
+ "user_id": user_id,
+ "event_id": event_id,
+ "data": json.dumps(data),
+ }
+ )
+
+ return True
+
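The comparison above means a receipt only ever moves forward: the new receipt loses if any existing one points at an event that is strictly later topologically, or equal topologically and at least as late in stream order. Distilled into a small predicate:

    def receipt_supersedes_existing(new_topo, new_stream, existing_rows):
        """existing_rows: (topological_ordering, stream_ordering, event_id) tuples."""
        for topo, stream, _event_id in existing_rows:
            if int(topo) > new_topo:
                return False
            if int(topo) == new_topo and int(stream) >= new_stream:
                return False
        return True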
+ @defer.inlineCallbacks
+ def insert_receipt(self, room_id, receipt_type, user_id, event_ids, data):
+ """Insert a receipt, either from local client or remote server.
+
+ Automatically does conversion between linearized and graph
+ representations.
+ """
+ if not event_ids:
+ return
+
+ if len(event_ids) == 1:
+ linearized_event_id = event_ids[0]
+ else:
+ # we need to map the points in the graph to linearized form.
+ # TODO: Make this better.
+ def graph_to_linear(txn):
+ query = (
+ "SELECT event_id WHERE room_id = ? AND stream_ordering IN ("
+ " SELECT max(stream_ordering) WHERE event_id IN (%s)"
+ ")"
+ ) % (",".join(["?"] * len(event_ids)))
+
+ txn.execute(query, [room_id] + event_ids)
+ rows = txn.fetchall()
+ if rows:
+ return rows[0][0]
+ else:
+ raise RuntimeError("Unrecognized event_ids: %r" % (event_ids,))
+
+ linearized_event_id = yield self.runInteraction(
+ "insert_receipt_conv", graph_to_linear
+ )
+
+ stream_id_manager = yield self._receipts_id_gen.get_next(self)
+ with stream_id_manager as stream_id:
+ yield self._receipts_stream_cache.room_has_changed(
+ self, room_id, stream_id
+ )
+ have_persisted = yield self.runInteraction(
+ "insert_linearized_receipt",
+ self.insert_linearized_receipt_txn,
+ room_id, receipt_type, user_id, linearized_event_id,
+ data,
+ stream_id=stream_id,
+ )
+
+ if not have_persisted:
+ defer.returnValue(None)
+
+ yield self.insert_graph_receipt(
+ room_id, receipt_type, user_id, event_ids, data
+ )
+
+ max_persisted_id = yield self._receipts_id_gen.get_max_token(self)
+ defer.returnValue((stream_id, max_persisted_id))
+
+ def insert_graph_receipt(self, room_id, receipt_type, user_id, event_ids,
+ data):
+ return self.runInteraction(
+ "insert_graph_receipt",
+ self.insert_graph_receipt_txn,
+ room_id, receipt_type, user_id, event_ids, data
+ )
+
+ def insert_graph_receipt_txn(self, txn, room_id, receipt_type,
+ user_id, event_ids, data):
+ self._simple_delete_txn(
+ txn,
+ table="receipts_graph",
+ keyvalues={
+ "room_id": room_id,
+ "receipt_type": receipt_type,
+ "user_id": user_id,
+ }
+ )
+ self._simple_insert_txn(
+ txn,
+ table="receipts_graph",
+ values={
+ "room_id": room_id,
+ "receipt_type": receipt_type,
+ "user_id": user_id,
+ "event_ids": json.dumps(event_ids),
+ "data": json.dumps(data),
+ }
+ )
+
+
+class _RoomStreamChangeCache(object):
+ """Keeps track of the stream_id of the latest change in rooms.
+
+ Given a list of rooms and stream key, it will give a subset of rooms that
+ may have changed since that key. If the key is too old then the cache
+ will simply return all rooms.
+ """
+ def __init__(self, size_of_cache=10000):
+ self._size_of_cache = size_of_cache
+ self._room_to_key = {}
+ self._cache = sorteddict()
+ self._earliest_key = None
+
+ @defer.inlineCallbacks
+ def get_rooms_changed(self, store, room_ids, key):
+ """Returns subset of room ids that have had new receipts since the
+ given key. If the key is too old it will just return the given list.
+ """
+ if key > (yield self._get_earliest_key(store)):
+ keys = self._cache.keys()
+ i = keys.bisect_right(key)
+
+ result = set(
+ self._cache[k] for k in keys[i:]
+ ).intersection(room_ids)
+ else:
+ result = room_ids
+
+ defer.returnValue(result)
+
+ @defer.inlineCallbacks
+ def room_has_changed(self, store, room_id, key):
+ """Informs the cache that the room has been changed at the given key.
+ """
+ if key > (yield self._get_earliest_key(store)):
+ old_key = self._room_to_key.get(room_id, None)
+ if old_key:
+ key = max(key, old_key)
+ self._cache.pop(old_key, None)
+ self._cache[key] = room_id
+ self._room_to_key[room_id] = key
+
+ while len(self._cache) > self._size_of_cache:
+ k, r = self._cache.popitem()
+ self._earliest_key = max(k, self._earliest_key)
+ self._room_to_key.pop(r, None)
+
+ @defer.inlineCallbacks
+ def _get_earliest_key(self, store):
+ if self._earliest_key is None:
+ self._earliest_key = yield store.get_max_receipt_stream_id()
+ self._earliest_key = int(self._earliest_key)
+
+ defer.returnValue(self._earliest_key)
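A synchronous miniature of the change-cache contract may make the bisect clearer (blist's sorteddict keeps its keys ordered, and its keys() view supports bisect_right, as used above):

    from blist import sorteddict

    class MiniChangeCache(object):
        def __init__(self, earliest_key):
            self._cache = sorteddict()     # stream key -> room_id
            self._earliest_key = earliest_key

        def room_has_changed(self, room_id, key):
            self._cache[key] = room_id

        def get_rooms_changed(self, room_ids, key):
            if key <= self._earliest_key:
                return room_ids            # key predates the cache: assume all changed
            keys = self._cache.keys()
            i = keys.bisect_right(key)     # first entry strictly after `key`
            return set(self._cache[k] for k in keys[i:]).intersection(room_ids)

    cache = MiniChangeCache(earliest_key=10)
    cache.room_has_changed("!a:hs", 12)
    cache.room_has_changed("!b:hs", 15)
    assert cache.get_rooms_changed({"!a:hs", "!b:hs", "!c:hs"}, 13) == {"!b:hs"}
    assert cache.get_rooms_changed({"!a:hs", "!b:hs", "!c:hs"}, 5) == {"!a:hs", "!b:hs", "!c:hs"}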
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 90e2606be2..4eaa088b36 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -131,7 +131,7 @@ class RegistrationStore(SQLBaseStore):
user_id
)
for r in rows:
- self.get_user_by_token.invalidate(r)
+ self.get_user_by_token.invalidate((r,))
@cached()
def get_user_by_token(self, token):
diff --git a/synapse/storage/room.py b/synapse/storage/room.py
index 4612a8aa83..dd5bc2c8fb 100644
--- a/synapse/storage/room.py
+++ b/synapse/storage/room.py
@@ -17,7 +17,7 @@ from twisted.internet import defer
from synapse.api.errors import StoreError
-from ._base import SQLBaseStore, cached
+from ._base import SQLBaseStore, cachedInlineCallbacks
import collections
import logging
@@ -186,8 +186,7 @@ class RoomStore(SQLBaseStore):
}
)
- @cached()
- @defer.inlineCallbacks
+ @cachedInlineCallbacks()
def get_room_name_and_aliases(self, room_id):
def f(txn):
sql = (
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index d36a6c18a8..9f14f38f24 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -35,38 +35,28 @@ RoomsForUser = namedtuple(
class RoomMemberStore(SQLBaseStore):
- def _store_room_member_txn(self, txn, event):
+ def _store_room_members_txn(self, txn, events):
"""Store a room member in the database.
"""
- try:
- target_user_id = event.state_key
- except:
- logger.exception(
- "Failed to parse target_user_id=%s", target_user_id
- )
- raise
-
- logger.debug(
- "_store_room_member_txn: target_user_id=%s, membership=%s",
- target_user_id,
- event.membership,
- )
-
- self._simple_insert_txn(
+ self._simple_insert_many_txn(
txn,
- "room_memberships",
- {
- "event_id": event.event_id,
- "user_id": target_user_id,
- "sender": event.user_id,
- "room_id": event.room_id,
- "membership": event.membership,
- }
+ table="room_memberships",
+ values=[
+ {
+ "event_id": event.event_id,
+ "user_id": event.state_key,
+ "sender": event.user_id,
+ "room_id": event.room_id,
+ "membership": event.membership,
+ }
+ for event in events
+ ]
)
- txn.call_after(self.get_rooms_for_user.invalidate, target_user_id)
- txn.call_after(self.get_joined_hosts_for_room.invalidate, event.room_id)
- txn.call_after(self.get_users_in_room.invalidate, event.room_id)
+ for event in events:
+ txn.call_after(self.get_rooms_for_user.invalidate, (event.state_key,))
+ txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,))
+ txn.call_after(self.get_users_in_room.invalidate, (event.room_id,))
def get_room_member(self, user_id, room_id):
"""Retrieve the current state of a room member.
@@ -88,7 +78,7 @@ class RoomMemberStore(SQLBaseStore):
lambda events: events[0] if events else None
)
- @cached()
+ @cached(max_entries=5000)
def get_users_in_room(self, room_id):
def f(txn):
@@ -164,7 +154,7 @@ class RoomMemberStore(SQLBaseStore):
RoomsForUser(**r) for r in self.cursor_to_dict(txn)
]
- @cached()
+ @cached(max_entries=5000)
def get_joined_hosts_for_room(self, room_id):
return self.runInteraction(
"get_joined_hosts_for_room",
diff --git a/synapse/storage/schema/delta/14/upgrade_appservice_db.py b/synapse/storage/schema/delta/14/upgrade_appservice_db.py
index 9f3a4dd4c5..61232f9757 100644
--- a/synapse/storage/schema/delta/14/upgrade_appservice_db.py
+++ b/synapse/storage/schema/delta/14/upgrade_appservice_db.py
@@ -18,7 +18,7 @@ import logging
logger = logging.getLogger(__name__)
-def run_upgrade(cur):
+def run_upgrade(cur, *args, **kwargs):
cur.execute("SELECT id, regex FROM application_services_regex")
for row in cur.fetchall():
try:
diff --git a/synapse/storage/schema/delta/20/dummy.sql b/synapse/storage/schema/delta/20/dummy.sql
new file mode 100644
index 0000000000..e0ac49d1ec
--- /dev/null
+++ b/synapse/storage/schema/delta/20/dummy.sql
@@ -0,0 +1 @@
+SELECT 1;
diff --git a/synapse/storage/schema/delta/20/pushers.py b/synapse/storage/schema/delta/20/pushers.py
new file mode 100644
index 0000000000..543e57bbe2
--- /dev/null
+++ b/synapse/storage/schema/delta/20/pushers.py
@@ -0,0 +1,76 @@
+# Copyright 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+The main purpose of this upgrade is to change the unique key on the
+pushers table again (it was missed when the v16 full schema was
+made), but it also changes the pushkey and data columns to text.
+When selecting a bytea column into a text column, postgres returns
+the hex-encoded data, and there's no portable way of getting the
+UTF-8 bytes back in SQL, so we have to do the conversion in Python.
+"""
+
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+def run_upgrade(cur, database_engine, *args, **kwargs):
+ logger.info("Porting pushers table...")
+ cur.execute("""
+ CREATE TABLE IF NOT EXISTS pushers2 (
+ id BIGINT PRIMARY KEY,
+ user_name TEXT NOT NULL,
+ access_token BIGINT DEFAULT NULL,
+ profile_tag VARCHAR(32) NOT NULL,
+ kind VARCHAR(8) NOT NULL,
+ app_id VARCHAR(64) NOT NULL,
+ app_display_name VARCHAR(64) NOT NULL,
+ device_display_name VARCHAR(128) NOT NULL,
+ pushkey TEXT NOT NULL,
+ ts BIGINT NOT NULL,
+ lang VARCHAR(8),
+ data TEXT,
+ last_token TEXT,
+ last_success BIGINT,
+ failing_since BIGINT,
+ UNIQUE (app_id, pushkey, user_name)
+ )
+ """)
+ cur.execute("""SELECT
+ id, user_name, access_token, profile_tag, kind,
+ app_id, app_display_name, device_display_name,
+ pushkey, ts, lang, data, last_token, last_success,
+ failing_since
+ FROM pushers
+ """)
+ count = 0
+ for row in cur.fetchall():
+ row = list(row)
+ row[8] = bytes(row[8]).decode("utf-8")
+ row[11] = bytes(row[11]).decode("utf-8")
+ cur.execute(database_engine.convert_param_style("""
+ INSERT into pushers2 (
+ id, user_name, access_token, profile_tag, kind,
+ app_id, app_display_name, device_display_name,
+ pushkey, ts, lang, data, last_token, last_success,
+ failing_since
+ ) values (%s)""" % (','.join(['?' for _ in range(len(row))]))),
+ row
+ )
+ count += 1
+ cur.execute("DROP TABLE pushers")
+ cur.execute("ALTER TABLE pushers2 RENAME TO pushers")
+ logger.info("Moved %d pushers to new table", count)
diff --git a/synapse/storage/schema/delta/21/end_to_end_keys.sql b/synapse/storage/schema/delta/21/end_to_end_keys.sql
new file mode 100644
index 0000000000..8b4a380d11
--- /dev/null
+++ b/synapse/storage/schema/delta/21/end_to_end_keys.sql
@@ -0,0 +1,34 @@
+/* Copyright 2015 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+CREATE TABLE IF NOT EXISTS e2e_device_keys_json (
+ user_id TEXT NOT NULL, -- The user these keys are for.
+ device_id TEXT NOT NULL, -- Which of the user's devices these keys are for.
+ ts_added_ms BIGINT NOT NULL, -- When the keys were uploaded.
+ key_json TEXT NOT NULL, -- The keys for the device as a JSON blob.
+ CONSTRAINT e2e_device_keys_json_uniqueness UNIQUE (user_id, device_id)
+);
+
+
+CREATE TABLE IF NOT EXISTS e2e_one_time_keys_json (
+ user_id TEXT NOT NULL, -- The user this one-time key is for.
+ device_id TEXT NOT NULL, -- The device this one-time key is for.
+ algorithm TEXT NOT NULL, -- Which algorithm this one-time key is for.
+ key_id TEXT NOT NULL, -- An id for suppressing duplicate uploads.
+ ts_added_ms BIGINT NOT NULL, -- When this key was uploaded.
+ key_json TEXT NOT NULL, -- The key as a JSON blob.
+ CONSTRAINT e2e_one_time_keys_json_uniqueness UNIQUE (user_id, device_id, algorithm, key_id)
+);
diff --git a/synapse/storage/schema/delta/21/receipts.sql b/synapse/storage/schema/delta/21/receipts.sql
new file mode 100644
index 0000000000..2f64d609fc
--- /dev/null
+++ b/synapse/storage/schema/delta/21/receipts.sql
@@ -0,0 +1,38 @@
+/* Copyright 2015 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
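+-- Receipts as they relate to the event DAG: event_ids is a JSON-encoded
+-- list of the event ids the receipt points at, and data is a JSON blob
+-- of receipt metadata.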
+CREATE TABLE IF NOT EXISTS receipts_graph(
+ room_id TEXT NOT NULL,
+ receipt_type TEXT NOT NULL,
+ user_id TEXT NOT NULL,
+ event_ids TEXT NOT NULL,
+ data TEXT NOT NULL,
+ CONSTRAINT receipts_graph_uniqueness UNIQUE (room_id, receipt_type, user_id)
+);
+
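+-- The same receipts projected onto the linearized event stream: a single
+-- event_id per receipt, with stream_id providing a global ordering so
+-- that clients can be sent incremental updates.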
+CREATE TABLE IF NOT EXISTS receipts_linearized (
+ stream_id BIGINT NOT NULL,
+ room_id TEXT NOT NULL,
+ receipt_type TEXT NOT NULL,
+ user_id TEXT NOT NULL,
+ event_id TEXT NOT NULL,
+ data TEXT NOT NULL,
+ CONSTRAINT receipts_linearized_uniqueness UNIQUE (room_id, receipt_type, user_id)
+);
+
+CREATE INDEX receipts_linearized_id ON receipts_linearized(
+ stream_id
+);
diff --git a/synapse/storage/signatures.py b/synapse/storage/signatures.py
index f051828630..4f15e534b4 100644
--- a/synapse/storage/signatures.py
+++ b/synapse/storage/signatures.py
@@ -18,6 +18,7 @@ from twisted.internet import defer
from _base import SQLBaseStore
from syutil.base64util import encode_base64
+from synapse.crypto.event_signing import compute_event_reference_hash
class SignatureStore(SQLBaseStore):
@@ -101,23 +102,26 @@ class SignatureStore(SQLBaseStore):
txn.execute(query, (event_id, ))
return {k: v for k, v in txn.fetchall()}
- def _store_event_reference_hash_txn(self, txn, event_id, algorithm,
- hash_bytes):
+ def _store_event_reference_hashes_txn(self, txn, events):
"""Store a hash for a PDU
Args:
txn (cursor):
- event_id (str): Id for the Event.
- algorithm (str): Hashing algorithm.
- hash_bytes (bytes): Hash function output bytes.
+ events (list): list of Events.
"""
- self._simple_insert_txn(
+
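+        # Compute each event's reference hash up front so that all the
+        # rows can be written with one batched insert below.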
+ vals = []
+ for event in events:
+ ref_alg, ref_hash_bytes = compute_event_reference_hash(event)
+ vals.append({
+ "event_id": event.event_id,
+ "algorithm": ref_alg,
+ "hash": buffer(ref_hash_bytes),
+ })
+
+ self._simple_insert_many_txn(
txn,
- "event_reference_hashes",
- {
- "event_id": event_id,
- "algorithm": algorithm,
- "hash": buffer(hash_bytes),
- },
+ table="event_reference_hashes",
+ values=vals,
)
def _get_event_signatures_txn(self, txn, event_id):
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index b24de34f23..7ce51b9bdc 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import SQLBaseStore, cached
+from ._base import SQLBaseStore, cachedInlineCallbacks
from twisted.internet import defer
@@ -81,31 +81,41 @@ class StateStore(SQLBaseStore):
f,
)
- @defer.inlineCallbacks
- def c(vals):
- vals[:] = yield self._get_events(vals, get_prev_content=False)
-
- yield defer.gatherResults(
+ state_list = yield defer.gatherResults(
[
- c(vals)
- for vals in states.values()
+ self._fetch_events_for_group(group, vals)
+ for group, vals in states.items()
],
consumeErrors=True,
)
- defer.returnValue(states)
+ defer.returnValue(dict(state_list))
+
+ def _fetch_events_for_group(self, key, events):
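+        """Fetch the given events, returning a Deferred that resolves to
+        a (key, events) pair so the results of parallel fetches can be
+        reassembled into a dict.
+        """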
+ return self._get_events(
+ events, get_prev_content=False
+ ).addCallback(
+ lambda evs: (key, evs)
+ )
def _store_state_groups_txn(self, txn, event, context):
- if context.current_state is None:
- return
+ return self._store_mult_state_groups_txn(txn, [(event, context)])
- state_events = dict(context.current_state)
+ def _store_mult_state_groups_txn(self, txn, events_and_contexts):
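+        """Store the state groups for a batch of (event, context) pairs,
+        reusing the context's existing state group where one is set and
+        materialising a new group from current_state otherwise.
+        """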
+ state_groups = {}
+ for event, context in events_and_contexts:
+ if context.current_state is None:
+ continue
- if event.is_state():
- state_events[(event.type, event.state_key)] = event
+ if context.state_group is not None:
+ state_groups[event.event_id] = context.state_group
+ continue
+
+ state_events = dict(context.current_state)
+
+ if event.is_state():
+ state_events[(event.type, event.state_key)] = event
- state_group = context.state_group
- if not state_group:
state_group = self._state_groups_id_gen.get_next_txn(txn)
self._simple_insert_txn(
txn,
@@ -131,14 +141,19 @@ class StateStore(SQLBaseStore):
for state in state_events.values()
],
)
+ state_groups[event.event_id] = state_group
- self._simple_insert_txn(
+ self._simple_insert_many_txn(
txn,
table="event_to_state_groups",
- values={
- "state_group": state_group,
- "event_id": event.event_id,
- },
+ values=[
+ {
+ "state_group": state_groups[event.event_id],
+ "event_id": event.event_id,
+ }
+ for event, context in events_and_contexts
+ if context.current_state is not None
+ ],
)
@defer.inlineCallbacks
@@ -173,8 +188,7 @@ class StateStore(SQLBaseStore):
events = yield self._get_events(event_ids, get_prev_content=False)
defer.returnValue(events)
- @cached(num_args=3)
- @defer.inlineCallbacks
+ @cachedInlineCallbacks(num_args=3)
def get_current_state_for_key(self, room_id, event_type, state_key):
def f(txn):
sql = (
@@ -190,6 +204,65 @@ class StateStore(SQLBaseStore):
events = yield self._get_events(event_ids, get_prev_content=False)
defer.returnValue(events)
+ @defer.inlineCallbacks
+ def get_state_for_events(self, room_id, event_ids):
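+        """Look up the room state at each of the given events.
+
+        Returns a list, in the same order as event_ids, of dicts mapping
+        (type, state_key) to the state event, or None for events whose
+        state group is unknown.
+        """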
+ def f(txn):
+ groups = set()
+ event_to_group = {}
+ for event_id in event_ids:
+ # TODO: Remove this loop.
+ group = self._simple_select_one_onecol_txn(
+ txn,
+ table="event_to_state_groups",
+ keyvalues={"event_id": event_id},
+ retcol="state_group",
+ allow_none=True,
+ )
+ if group:
+ event_to_group[event_id] = group
+ groups.add(group)
+
+ group_to_state_ids = {}
+ for group in groups:
+ state_ids = self._simple_select_onecol_txn(
+ txn,
+ table="state_groups_state",
+ keyvalues={"state_group": group},
+ retcol="event_id",
+ )
+
+ group_to_state_ids[group] = state_ids
+
+ return event_to_group, group_to_state_ids
+
+ res = yield self.runInteraction(
+ "annotate_events_with_state_groups",
+ f,
+ )
+
+ event_to_group, group_to_state_ids = res
+
+ state_list = yield defer.gatherResults(
+ [
+ self._fetch_events_for_group(group, vals)
+ for group, vals in group_to_state_ids.items()
+ ],
+ consumeErrors=True,
+ )
+
+ state_dict = {
+ group: {
+ (ev.type, ev.state_key): ev
+ for ev in state
+ }
+ for group, state in state_list
+ }
+
+ defer.returnValue([
+ state_dict.get(event_to_group.get(event, None), None)
+ for event in event_ids
+ ])
+
def _make_group_id(clock):
return str(int(clock.time_msec())) + random_string(5)
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 89d1643f10..e956df62c7 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -72,7 +72,10 @@ class StreamIdGenerator(object):
with stream_id_gen.get_next_txn(txn) as stream_id:
# ... persist event ...
"""
- def __init__(self):
+ def __init__(self, table, column):
+ self.table = table
+ self.column = column
+
self._lock = threading.Lock()
self._current_max = None
@@ -108,6 +111,37 @@ class StreamIdGenerator(object):
defer.returnValue(manager())
@defer.inlineCallbacks
+ def get_next_mult(self, store, n):
+ """
+ Usage:
+ with yield stream_id_gen.get_next(store, n) as stream_ids:
+ # ... persist events ...
+ """
+ if not self._current_max:
+ yield store.runInteraction(
+ "_compute_current_max",
+ self._get_or_compute_current_max,
+ )
+
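+        # Reserve a contiguous block of n ids while holding the lock; they
+        # stay in _unfinished_ids until the caller's with-block exits, so
+        # get_max_token won't advance past ids that aren't yet persisted.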
+ with self._lock:
+ next_ids = range(self._current_max + 1, self._current_max + n + 1)
+ self._current_max += n
+
+ for next_id in next_ids:
+ self._unfinished_ids.append(next_id)
+
+ @contextlib.contextmanager
+ def manager():
+ try:
+ yield next_ids
+ finally:
+ with self._lock:
+ for next_id in next_ids:
+ self._unfinished_ids.remove(next_id)
+
+ defer.returnValue(manager())
+
+ @defer.inlineCallbacks
def get_max_token(self, store):
"""Returns the maximum stream id such that all stream ids less than or
equal to it have been successfully persisted.
@@ -126,7 +160,7 @@ class StreamIdGenerator(object):
def _get_or_compute_current_max(self, txn):
with self._lock:
- txn.execute("SELECT MAX(stream_ordering) FROM events")
+ txn.execute("SELECT MAX(%s) FROM %s" % (self.column, self.table))
rows = txn.fetchall()
val, = rows[0]
|