diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index cd4bd28e8c..2807abbc90 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -144,16 +144,17 @@ class Config(object):
)
config_args, remaining_args = config_parser.parse_known_args(argv)
- if not config_args.config_path:
- config_parser.error(
- "Must supply a config file.\nA config file can be automatically"
- " generated using \"--generate-config -h SERVER_NAME"
- " -c CONFIG-FILE\""
- )
-
- config_dir_path = os.path.dirname(config_args.config_path[0])
- config_dir_path = os.path.abspath(config_dir_path)
if config_args.generate_config:
+ if not config_args.config_path:
+ config_parser.error(
+ "Must supply a config file.\nA config file can be automatically"
+ " generated using \"--generate-config -h SERVER_NAME"
+ " -c CONFIG-FILE\""
+ )
+
+ config_dir_path = os.path.dirname(config_args.config_path[0])
+ config_dir_path = os.path.abspath(config_dir_path)
+
server_name = config_args.server_name
if not server_name:
print "Most specify a server_name to a generate config for."
@@ -196,6 +197,25 @@ class Config(object):
)
sys.exit(0)
+ parser = argparse.ArgumentParser(
+ parents=[config_parser],
+ description=description,
+ formatter_class=argparse.RawDescriptionHelpFormatter,
+ )
+
+ obj.invoke_all("add_arguments", parser)
+ args = parser.parse_args(remaining_args)
+
+ if not config_args.config_path:
+ config_parser.error(
+ "Must supply a config file.\nA config file can be automatically"
+ " generated using \"--generate-config -h SERVER_NAME"
+ " -c CONFIG-FILE\""
+ )
+
+ config_dir_path = os.path.dirname(config_args.config_path[0])
+ config_dir_path = os.path.abspath(config_dir_path)
+
specified_config = {}
for config_path in config_args.config_path:
yaml_config = cls.read_config_file(config_path)
@@ -208,15 +228,6 @@ class Config(object):
obj.invoke_all("read_config", config)
- parser = argparse.ArgumentParser(
- parents=[config_parser],
- description=description,
- formatter_class=argparse.RawDescriptionHelpFormatter,
- )
-
- obj.invoke_all("add_arguments", parser)
- args = parser.parse_args(remaining_args)
-
obj.invoke_all("read_arguments", args)
return obj
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 6811a0e3d1..904c7c0945 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -491,7 +491,7 @@ class FederationClient(FederationBase):
]
signed_events = yield self._check_sigs_and_hash_and_fetch(
- destination, events, outlier=True
+ destination, events, outlier=False
)
have_gotten_all_from_destination = True
diff --git a/synapse/federation/persistence.py b/synapse/federation/persistence.py
index 76a9dcd777..1a7cc02f92 100644
--- a/synapse/federation/persistence.py
+++ b/synapse/federation/persistence.py
@@ -23,8 +23,6 @@ from twisted.internet import defer
from synapse.util.logutils import log_function
-from syutil.jsonutil import encode_canonical_json
-
import logging
@@ -71,7 +69,7 @@ class TransactionActions(object):
transaction.transaction_id,
transaction.origin,
code,
- encode_canonical_json(response)
+ response,
)
@defer.inlineCallbacks
@@ -101,5 +99,5 @@ class TransactionActions(object):
transaction.transaction_id,
transaction.destination,
response_code,
- encode_canonical_json(response_dict)
+ response_dict,
)
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index b7c3cf03c8..ee5587c721 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -31,7 +31,9 @@ import functools
import simplejson as json
import sys
import time
+import threading
+DEBUG_CACHES = False
logger = logging.getLogger(__name__)
@@ -68,9 +70,20 @@ class Cache(object):
self.name = name
self.keylen = keylen
-
+ self.sequence = 0
+ self.thread = None
caches_by_name[name] = self.cache
+ def check_thread(self):
+ expected_thread = self.thread
+ if expected_thread is None:
+ self.thread = threading.current_thread()
+ else:
+ if expected_thread is not threading.current_thread():
+ raise ValueError(
+ "Cache objects can only be accessed from the main thread"
+ )
+
def get(self, *keyargs):
if len(keyargs) != self.keylen:
raise ValueError("Expected a key to have %d items", self.keylen)
@@ -82,6 +95,13 @@ class Cache(object):
cache_counter.inc_misses(self.name)
raise KeyError()
+ def update(self, sequence, *args):
+ self.check_thread()
+ if self.sequence == sequence:
+ # Only update the cache if the caches sequence number matches the
+ # number that the cache had before the SELECT was started (SYN-369)
+ self.prefill(*args)
+
def prefill(self, *args): # because I can't *keyargs, value
keyargs = args[:-1]
value = args[-1]
@@ -96,9 +116,12 @@ class Cache(object):
self.cache[keyargs] = value
def invalidate(self, *keyargs):
+ self.check_thread()
if len(keyargs) != self.keylen:
raise ValueError("Expected a key to have %d items", self.keylen)
-
+ # Increment the sequence number so that any SELECT statements that
+ # raced with the INSERT don't update the cache (SYN-369)
+ self.sequence += 1
self.cache.pop(keyargs, None)
@@ -128,11 +151,26 @@ def cached(max_entries=1000, num_args=1, lru=False):
@defer.inlineCallbacks
def wrapped(self, *keyargs):
try:
- defer.returnValue(cache.get(*keyargs))
+ cached_result = cache.get(*keyargs)
+ if DEBUG_CACHES:
+ actual_result = yield orig(self, *keyargs)
+ if actual_result != cached_result:
+ logger.error(
+ "Stale cache entry %s%r: cached: %r, actual %r",
+ orig.__name__, keyargs,
+ cached_result, actual_result,
+ )
+ raise ValueError("Stale cache entry")
+ defer.returnValue(cached_result)
except KeyError:
+ # Get the sequence number of the cache before reading from the
+ # database so that we can tell if the cache is invalidated
+ # while the SELECT is executing (SYN-369)
+ sequence = cache.sequence
+
ret = yield orig(self, *keyargs)
- cache.prefill(*keyargs + (ret,))
+ cache.update(sequence, *keyargs + (ret,))
defer.returnValue(ret)
@@ -147,12 +185,20 @@ class LoggingTransaction(object):
"""An object that almost-transparently proxies for the 'txn' object
passed to the constructor. Adds logging and metrics to the .execute()
method."""
- __slots__ = ["txn", "name", "database_engine"]
+ __slots__ = ["txn", "name", "database_engine", "after_callbacks"]
- def __init__(self, txn, name, database_engine):
+ def __init__(self, txn, name, database_engine, after_callbacks):
object.__setattr__(self, "txn", txn)
object.__setattr__(self, "name", name)
object.__setattr__(self, "database_engine", database_engine)
+ object.__setattr__(self, "after_callbacks", after_callbacks)
+
+ def call_after(self, callback, *args):
+ """Call the given callback on the main twisted thread after the
+ transaction has finished. Used to invalidate the caches on the
+ correct thread.
+ """
+ self.after_callbacks.append((callback, args))
def __getattr__(self, name):
return getattr(self.txn, name)
@@ -299,6 +345,8 @@ class SQLBaseStore(object):
start_time = time.time() * 1000
+ after_callbacks = []
+
def inner_func(conn, *args, **kwargs):
with LoggingContext("runInteraction") as context:
if self.database_engine.is_connection_closed(conn):
@@ -323,10 +371,10 @@ class SQLBaseStore(object):
while True:
try:
txn = conn.cursor()
- return func(
- LoggingTransaction(txn, name, self.database_engine),
- *args, **kwargs
+ txn = LoggingTransaction(
+ txn, name, self.database_engine, after_callbacks
)
+ return func(txn, *args, **kwargs)
except self.database_engine.module.OperationalError as e:
# This can happen if the database disappears mid
# transaction.
@@ -375,6 +423,8 @@ class SQLBaseStore(object):
result = yield self._db_pool.runWithConnection(
inner_func, *args, **kwargs
)
+ for after_callback, after_args in after_callbacks:
+ after_callback(*after_args)
defer.returnValue(result)
def cursor_to_dict(self, cursor):
@@ -453,6 +503,14 @@ class SQLBaseStore(object):
if not values:
return
+ # This is a *slight* abomination to get a list of tuples of key names
+ # and a list of tuples of value names.
+ #
+ # i.e. [{"a": 1, "b": 2}, {"c": 3, "d": 4}]
+ # => [("a", "b",), ("c", "d",)] and [(1, 2,), (3, 4,)]
+ #
+ # The sort is to ensure that we don't rely on dictionary iteration
+ # order.
keys, vals = zip(*[
zip(
*(sorted(i.items(), key=lambda kv: kv[0]))
diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py
index 36b1feac60..74b4e23590 100644
--- a/synapse/storage/event_federation.py
+++ b/synapse/storage/event_federation.py
@@ -332,7 +332,9 @@ class EventFederationStore(SQLBaseStore):
)
txn.execute(query)
- self.get_latest_event_ids_in_room.invalidate(room_id)
+ txn.call_after(
+ self.get_latest_event_ids_in_room.invalidate, room_id
+ )
def get_backfill_events(self, room_id, event_list, limit):
"""Get a list of Events for a given topic that occurred before (and
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 34bd49cfe9..38395c66ab 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -93,7 +93,7 @@ class EventsStore(SQLBaseStore):
current_state=None):
# Remove the any existing cache entries for the event_id
- self._invalidate_get_event_cache(event.event_id)
+ txn.call_after(self._invalidate_get_event_cache, event.event_id)
if stream_ordering is None:
with self._stream_id_gen.get_next_txn(txn) as stream_ordering:
@@ -113,19 +113,24 @@ class EventsStore(SQLBaseStore):
keyvalues={"room_id": event.room_id},
)
- self._simple_insert_many_txn(
- txn,
- "current_state_events",
- [
+ for s in current_state:
+ if s.type == EventTypes.Member:
+ txn.call_after(
+ self.get_rooms_for_user.invalidate, s.state_key
+ )
+ txn.call_after(
+ self.get_joined_hosts_for_room.invalidate, s.room_id
+ )
+ self._simple_insert_txn(
+ txn,
+ "current_state_events",
{
"event_id": s.event_id,
"room_id": s.room_id,
"type": s.type,
"state_key": s.state_key,
}
- for s in current_state
- ],
- )
+ )
outlier = event.internal_metadata.is_outlier()
@@ -261,7 +266,9 @@ class EventsStore(SQLBaseStore):
)
if context.rejected:
- self._store_rejections_txn(txn, event.event_id, context.rejected)
+ self._store_rejections_txn(
+ txn, event.event_id, context.rejected
+ )
for hash_alg, hash_base64 in event.hashes.items():
hash_bytes = decode_base64(hash_base64)
@@ -273,7 +280,8 @@ class EventsStore(SQLBaseStore):
for alg, hash_base64 in prev_hashes.items():
hash_bytes = decode_base64(hash_base64)
self._store_prev_event_hash_txn(
- txn, event.event_id, prev_event_id, alg, hash_bytes
+ txn, event.event_id, prev_event_id, alg,
+ hash_bytes
)
self._simple_insert_many_txn(
@@ -340,9 +348,11 @@ class EventsStore(SQLBaseStore):
}
)
+ return
+
def _store_redaction(self, txn, event):
# invalidate the cache for the redacted event
- self._invalidate_get_event_cache(event.redacts)
+ txn.call_after(self._invalidate_get_event_cache, event.redacts)
txn.execute(
"INSERT INTO redactions (event_id, redacts) VALUES (?,?)",
(event.event_id, event.redacts)
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 09fb77a194..839c74f63a 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -64,8 +64,8 @@ class RoomMemberStore(SQLBaseStore):
}
)
- self.get_rooms_for_user.invalidate(target_user_id)
- self.get_joined_hosts_for_room.invalidate(event.room_id)
+ txn.call_after(self.get_rooms_for_user.invalidate, target_user_id)
+ txn.call_after(self.get_joined_hosts_for_room.invalidate, event.room_id)
def get_room_member(self, user_id, room_id):
"""Retrieve the current state of a room member.
diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py
index 89dd7d8947..624da4a9dc 100644
--- a/synapse/storage/transactions.py
+++ b/synapse/storage/transactions.py
@@ -17,6 +17,7 @@ from ._base import SQLBaseStore, cached
from collections import namedtuple
+from syutil.jsonutil import encode_canonical_json
import logging
logger = logging.getLogger(__name__)
@@ -82,7 +83,7 @@ class TransactionStore(SQLBaseStore):
"transaction_id": transaction_id,
"origin": origin,
"response_code": code,
- "response_json": response_dict,
+ "response_json": buffer(encode_canonical_json(response_dict)),
},
or_ignore=True,
desc="set_received_txn_response",
@@ -161,7 +162,8 @@ class TransactionStore(SQLBaseStore):
return self.runInteraction(
"delivered_txn",
self._delivered_txn,
- transaction_id, destination, code, response_dict
+ transaction_id, destination, code,
+ buffer(encode_canonical_json(response_dict)),
)
def _delivered_txn(self, txn, transaction_id, destination,
|