diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index b5536e8565..e2f84c4d57 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -24,6 +24,7 @@ from synapse.api.constants import EventTypes, Membership, JoinRules
from synapse.api.errors import AuthError, Codes, SynapseError, EventSizeError
from synapse.types import Requester, RoomID, UserID, EventID
from synapse.util.logutils import log_function
+from synapse.util.logcontext import preserve_context_over_fn
from unpaddedbase64 import decode_base64
import logging
@@ -529,7 +530,8 @@ class Auth(object):
default=[""]
)[0]
if user and access_token and ip_addr:
- self.store.insert_client_ip(
+ preserve_context_over_fn(
+ self.store.insert_client_ip,
user=user,
access_token=access_token,
ip=ip_addr,
@@ -574,7 +576,7 @@ class Auth(object):
raise AuthError(
403,
"Application service has not registered this user"
- )
+ )
defer.returnValue(user_id)
@defer.inlineCallbacks
@@ -696,6 +698,7 @@ class Auth(object):
def _look_up_user_by_access_token(self, token):
ret = yield self.store.get_user_by_access_token(token)
if not ret:
+ logger.warn("Unrecognised access token - not in store: %s" % (token,))
raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Unrecognised access token.",
errcode=Codes.UNKNOWN_TOKEN
@@ -713,6 +716,7 @@ class Auth(object):
token = request.args["access_token"][0]
service = yield self.store.get_app_service_by_token(token)
if not service:
+ logger.warn("Unrecognised appservice access token: %s" % (token,))
raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS,
"Unrecognised access token.",
diff --git a/synapse/api/urls.py b/synapse/api/urls.py
index 19824f9a02..0fd9b7f244 100644
--- a/synapse/api/urls.py
+++ b/synapse/api/urls.py
@@ -23,5 +23,6 @@ WEB_CLIENT_PREFIX = "/_matrix/client"
CONTENT_REPO_PREFIX = "/_matrix/content"
SERVER_KEY_PREFIX = "/_matrix/key/v1"
SERVER_KEY_V2_PREFIX = "/_matrix/key/v2"
-MEDIA_PREFIX = "/_matrix/media/v1"
+MEDIA_PREFIX = "/_matrix/media/r0"
+LEGACY_MEDIA_PREFIX = "/_matrix/media/v1"
APP_SERVICE_PREFIX = "/_matrix/appservice/v1"
diff --git a/synapse/app/__init__.py b/synapse/app/__init__.py
index bfebb0f644..1bc4279807 100644
--- a/synapse/app/__init__.py
+++ b/synapse/app/__init__.py
@@ -12,3 +12,22 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
+import sys
+sys.dont_write_bytecode = True
+
+from synapse.python_dependencies import (
+ check_requirements, MissingRequirementError
+) # NOQA
+
+try:
+ check_requirements()
+except MissingRequirementError as e:
+ message = "\n".join([
+ "Missing Requirement: %s" % (e.message,),
+ "To install run:",
+ " pip install --upgrade --force \"%s\"" % (e.dependency,),
+ "",
+ ])
+ sys.stderr.writelines(message)
+ sys.exit(1)
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index e5066c48ef..2b4be7bdd0 100755
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -14,27 +14,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import synapse
+
+import contextlib
+import logging
+import os
+import re
+import resource
+import subprocess
import sys
-from synapse.rest import ClientRestResource
+import time
+from synapse.config._base import ConfigError
-sys.dont_write_bytecode = True
from synapse.python_dependencies import (
- check_requirements, DEPENDENCY_LINKS, MissingRequirementError
+ check_requirements, DEPENDENCY_LINKS
)
-if __name__ == '__main__':
- try:
- check_requirements()
- except MissingRequirementError as e:
- message = "\n".join([
- "Missing Requirement: %s" % (e.message,),
- "To install run:",
- " pip install --upgrade --force \"%s\"" % (e.dependency,),
- "",
- ])
- sys.stderr.writelines(message)
- sys.exit(1)
-
+from synapse.rest import ClientRestResource
from synapse.storage.engines import create_engine, IncorrectDatabaseSetup
from synapse.storage import are_all_users_on_domain
from synapse.storage.prepare_database import UpgradeDatabaseException
@@ -60,7 +56,7 @@ from synapse.rest.key.v1.server_key_resource import LocalKey
from synapse.rest.key.v2 import KeyApiV2Resource
from synapse.api.urls import (
FEDERATION_PREFIX, WEB_CLIENT_PREFIX, CONTENT_REPO_PREFIX,
- SERVER_KEY_PREFIX, MEDIA_PREFIX, STATIC_PREFIX,
+ SERVER_KEY_PREFIX, LEGACY_MEDIA_PREFIX, MEDIA_PREFIX, STATIC_PREFIX,
SERVER_KEY_V2_PREFIX,
)
from synapse.config.homeserver import HomeServerConfig
@@ -73,17 +69,6 @@ from synapse import events
from daemonize import Daemonize
-import synapse
-
-import contextlib
-import logging
-import os
-import re
-import resource
-import subprocess
-import time
-
-
logger = logging.getLogger("synapse.app.homeserver")
@@ -163,8 +148,10 @@ class SynapseHomeServer(HomeServer):
})
if name in ["media", "federation", "client"]:
+ media_repo = MediaRepositoryResource(self)
resources.update({
- MEDIA_PREFIX: MediaRepositoryResource(self),
+ MEDIA_PREFIX: media_repo,
+ LEGACY_MEDIA_PREFIX: media_repo,
CONTENT_REPO_PREFIX: ContentRepoResource(
self, self.config.uploads_path, self.auth, self.content_addr
),
@@ -366,11 +353,20 @@ def setup(config_options):
Returns:
HomeServer
"""
- config = HomeServerConfig.load_config(
- "Synapse Homeserver",
- config_options,
- generate_section="Homeserver"
- )
+ try:
+ config = HomeServerConfig.load_config(
+ "Synapse Homeserver",
+ config_options,
+ generate_section="Homeserver"
+ )
+ except ConfigError as e:
+ sys.stderr.write("\n" + e.message + "\n")
+ sys.exit(1)
+
+ if not config:
+ # If a config isn't returned, and an exception isn't raised, we're just
+ # generating config files and shouldn't try to continue.
+ sys.exit(0)
config.setup_logging()
@@ -690,8 +686,8 @@ def run(hs):
stats["uptime_seconds"] = uptime
stats["total_users"] = yield hs.get_datastore().count_all_users()
- all_rooms = yield hs.get_datastore().get_rooms(False)
- stats["total_room_count"] = len(all_rooms)
+ room_count = yield hs.get_datastore().get_room_count()
+ stats["total_room_count"] = room_count
stats["daily_active_users"] = yield hs.get_datastore().count_daily_users()
daily_messages = yield hs.get_datastore().count_daily_messages()
@@ -713,6 +709,8 @@ def run(hs):
phone_home_task.start(60 * 60 * 24, now=False)
def in_thread():
+ # Uncomment to enable tracing of log context changes.
+ # sys.settrace(logcontext_tracer)
with LoggingContext("run"):
change_resource_limit(hs.config.soft_file_limit)
reactor.run()
diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index e1c07028e8..bc90605324 100644
--- a/synapse/appservice/api.py
+++ b/synapse/appservice/api.py
@@ -29,7 +29,7 @@ class ApplicationServiceApi(SimpleHttpClient):
pushing.
"""
- def __init__(self, hs):
+ def __init__(self, hs):
super(ApplicationServiceApi, self).__init__(hs)
self.clock = hs.get_clock()
diff --git a/synapse/config/__main__.py b/synapse/config/__main__.py
index ea9e7907a6..0a3b70e11f 100644
--- a/synapse/config/__main__.py
+++ b/synapse/config/__main__.py
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from synapse.config._base import ConfigError
if __name__ == "__main__":
import sys
@@ -21,7 +22,11 @@ if __name__ == "__main__":
if action == "read":
key = sys.argv[2]
- config = HomeServerConfig.load_config("", sys.argv[3:])
+ try:
+ config = HomeServerConfig.load_config("", sys.argv[3:])
+ except ConfigError as e:
+ sys.stderr.write("\n" + e.message + "\n")
+ sys.exit(1)
print getattr(config, key)
sys.exit(0)
diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index a9304a11ba..15d78ff33a 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -17,7 +17,6 @@ import argparse
import errno
import os
import yaml
-import sys
from textwrap import dedent
@@ -136,13 +135,20 @@ class Config(object):
results.append(getattr(cls, name)(self, *args, **kargs))
return results
- def generate_config(self, config_dir_path, server_name, report_stats=None):
+ def generate_config(
+ self,
+ config_dir_path,
+ server_name,
+ is_generating_file,
+ report_stats=None,
+ ):
default_config = "# vim:ft=yaml\n"
default_config += "\n\n".join(dedent(conf) for conf in self.invoke_all(
"default_config",
config_dir_path=config_dir_path,
server_name=server_name,
+ is_generating_file=is_generating_file,
report_stats=report_stats,
))
@@ -244,8 +250,10 @@ class Config(object):
server_name = config_args.server_name
if not server_name:
- print "Must specify a server_name to a generate config for."
- sys.exit(1)
+ raise ConfigError(
+ "Must specify a server_name to a generate config for."
+ " Pass -H server.name."
+ )
if not os.path.exists(config_dir_path):
os.makedirs(config_dir_path)
with open(config_path, "wb") as config_file:
@@ -253,6 +261,7 @@ class Config(object):
config_dir_path=config_dir_path,
server_name=server_name,
report_stats=(config_args.report_stats == "yes"),
+ is_generating_file=True
)
obj.invoke_all("generate_files", config)
config_file.write(config_bytes)
@@ -266,7 +275,7 @@ class Config(object):
"If this server name is incorrect, you will need to"
" regenerate the SSL certificates"
)
- sys.exit(0)
+ return
else:
print (
"Config file %r already exists. Generating any missing key"
@@ -302,25 +311,25 @@ class Config(object):
specified_config.update(yaml_config)
if "server_name" not in specified_config:
- sys.stderr.write("\n" + MISSING_SERVER_NAME + "\n")
- sys.exit(1)
+ raise ConfigError(MISSING_SERVER_NAME)
server_name = specified_config["server_name"]
_, config = obj.generate_config(
config_dir_path=config_dir_path,
- server_name=server_name
+ server_name=server_name,
+ is_generating_file=False,
)
config.pop("log_config")
config.update(specified_config)
if "report_stats" not in config:
- sys.stderr.write(
- "\n" + MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS + "\n" +
- MISSING_REPORT_STATS_SPIEL + "\n")
- sys.exit(1)
+ raise ConfigError(
+ MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS + "\n" +
+ MISSING_REPORT_STATS_SPIEL
+ )
if generate_keys:
obj.invoke_all("generate_files", config)
- sys.exit(0)
+ return
obj.invoke_all("read_config", config)
diff --git a/synapse/config/key.py b/synapse/config/key.py
index ac90cd3fc1..a072aec714 100644
--- a/synapse/config/key.py
+++ b/synapse/config/key.py
@@ -22,8 +22,14 @@ from signedjson.key import (
read_signing_keys, write_signing_keys, NACL_ED25519
)
from unpaddedbase64 import decode_base64
+from synapse.util.stringutils import random_string_with_symbols
import os
+import hashlib
+import logging
+
+
+logger = logging.getLogger(__name__)
class KeyConfig(Config):
@@ -40,9 +46,29 @@ class KeyConfig(Config):
config["perspectives"]
)
- def default_config(self, config_dir_path, server_name, **kwargs):
+ self.macaroon_secret_key = config.get(
+ "macaroon_secret_key", self.registration_shared_secret
+ )
+
+ if not self.macaroon_secret_key:
+ # Unfortunately, there are people out there that don't have this
+ # set. Lets just be "nice" and derive one from their secret key.
+ logger.warn("Config is missing missing macaroon_secret_key")
+ seed = self.signing_key[0].seed
+ self.macaroon_secret_key = hashlib.sha256(seed)
+
+ def default_config(self, config_dir_path, server_name, is_generating_file=False,
+ **kwargs):
base_key_name = os.path.join(config_dir_path, server_name)
+
+ if is_generating_file:
+ macaroon_secret_key = random_string_with_symbols(50)
+ else:
+ macaroon_secret_key = None
+
return """\
+ macaroon_secret_key: "%(macaroon_secret_key)s"
+
## Signing Keys ##
# Path to the signing key to sign messages with
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index d3f4b9d543..ab062d528c 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -23,22 +23,23 @@ from distutils.util import strtobool
class RegistrationConfig(Config):
def read_config(self, config):
- self.disable_registration = not bool(
+ self.enable_registration = bool(
strtobool(str(config["enable_registration"]))
)
if "disable_registration" in config:
- self.disable_registration = bool(
+ self.enable_registration = not bool(
strtobool(str(config["disable_registration"]))
)
self.registration_shared_secret = config.get("registration_shared_secret")
- self.macaroon_secret_key = config.get("macaroon_secret_key")
+
self.bcrypt_rounds = config.get("bcrypt_rounds", 12)
+ self.trusted_third_party_id_servers = config["trusted_third_party_id_servers"]
self.allow_guest_access = config.get("allow_guest_access", False)
def default_config(self, **kwargs):
registration_shared_secret = random_string_with_symbols(50)
- macaroon_secret_key = random_string_with_symbols(50)
+
return """\
## Registration ##
@@ -49,8 +50,6 @@ class RegistrationConfig(Config):
# secret, even if registration is otherwise disabled.
registration_shared_secret: "%(registration_shared_secret)s"
- macaroon_secret_key: "%(macaroon_secret_key)s"
-
# Set the number of bcrypt rounds used to generate password hash.
# Larger numbers increase the work factor needed to generate the hash.
# The default number of rounds is 12.
@@ -60,6 +59,12 @@ class RegistrationConfig(Config):
# participate in rooms hosted on this server which have been made
# accessible to anonymous users.
allow_guest_access: False
+
+ # The list of identity servers trusted to verify third party
+ # identifiers by this server.
+ trusted_third_party_id_servers:
+ - matrix.org
+ - vector.im
""" % locals()
def add_arguments(self, parser):
@@ -71,6 +76,6 @@ class RegistrationConfig(Config):
def read_arguments(self, args):
if args.enable_registration is not None:
- self.disable_registration = not bool(
+ self.enable_registration = bool(
strtobool(str(args.enable_registration))
)
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index cddec0b2bc..d08ee0aa91 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -18,6 +18,10 @@ from synapse.api.errors import SynapseError, Codes
from synapse.util.retryutils import get_retry_limiter
from synapse.util import unwrapFirstError
from synapse.util.async import ObservableDeferred
+from synapse.util.logcontext import (
+ preserve_context_over_deferred, preserve_context_over_fn, PreserveLoggingContext,
+ preserve_fn
+)
from twisted.internet import defer
@@ -142,40 +146,43 @@ class Keyring(object):
for server_name, _ in server_and_json
}
- # We want to wait for any previous lookups to complete before
- # proceeding.
- wait_on_deferred = self.wait_for_previous_lookups(
- [server_name for server_name, _ in server_and_json],
- server_to_deferred,
- )
+ with PreserveLoggingContext():
- # Actually start fetching keys.
- wait_on_deferred.addBoth(
- lambda _: self.get_server_verify_keys(group_id_to_group, deferreds)
- )
+ # We want to wait for any previous lookups to complete before
+ # proceeding.
+ wait_on_deferred = self.wait_for_previous_lookups(
+ [server_name for server_name, _ in server_and_json],
+ server_to_deferred,
+ )
- # When we've finished fetching all the keys for a given server_name,
- # resolve the deferred passed to `wait_for_previous_lookups` so that
- # any lookups waiting will proceed.
- server_to_gids = {}
+ # Actually start fetching keys.
+ wait_on_deferred.addBoth(
+ lambda _: self.get_server_verify_keys(group_id_to_group, deferreds)
+ )
+
+ # When we've finished fetching all the keys for a given server_name,
+ # resolve the deferred passed to `wait_for_previous_lookups` so that
+ # any lookups waiting will proceed.
+ server_to_gids = {}
- def remove_deferreds(res, server_name, group_id):
- server_to_gids[server_name].discard(group_id)
- if not server_to_gids[server_name]:
- d = server_to_deferred.pop(server_name, None)
- if d:
- d.callback(None)
- return res
+ def remove_deferreds(res, server_name, group_id):
+ server_to_gids[server_name].discard(group_id)
+ if not server_to_gids[server_name]:
+ d = server_to_deferred.pop(server_name, None)
+ if d:
+ d.callback(None)
+ return res
- for g_id, deferred in deferreds.items():
- server_name = group_id_to_group[g_id].server_name
- server_to_gids.setdefault(server_name, set()).add(g_id)
- deferred.addBoth(remove_deferreds, server_name, g_id)
+ for g_id, deferred in deferreds.items():
+ server_name = group_id_to_group[g_id].server_name
+ server_to_gids.setdefault(server_name, set()).add(g_id)
+ deferred.addBoth(remove_deferreds, server_name, g_id)
# Pass those keys to handle_key_deferred so that the json object
# signatures can be verified
return [
- handle_key_deferred(
+ preserve_context_over_fn(
+ handle_key_deferred,
group_id_to_group[g_id],
deferreds[g_id],
)
@@ -198,12 +205,13 @@ class Keyring(object):
if server_name in self.key_downloads
]
if wait_on:
- yield defer.DeferredList(wait_on)
+ with PreserveLoggingContext():
+ yield defer.DeferredList(wait_on)
else:
break
for server_name, deferred in server_to_deferred.items():
- d = ObservableDeferred(deferred)
+ d = ObservableDeferred(preserve_context_over_deferred(deferred))
self.key_downloads[server_name] = d
def rm(r, server_name):
@@ -244,12 +252,13 @@ class Keyring(object):
for group in group_id_to_group.values():
for key_id in group.key_ids:
if key_id in merged_results[group.server_name]:
- group_id_to_deferred[group.group_id].callback((
- group.group_id,
- group.server_name,
- key_id,
- merged_results[group.server_name][key_id],
- ))
+ with PreserveLoggingContext():
+ group_id_to_deferred[group.group_id].callback((
+ group.group_id,
+ group.server_name,
+ key_id,
+ merged_results[group.server_name][key_id],
+ ))
break
else:
missing_groups.setdefault(
@@ -504,7 +513,7 @@ class Keyring(object):
yield defer.gatherResults(
[
- self.store_keys(
+ preserve_fn(self.store_keys)(
server_name=key_server_name,
from_server=server_name,
verify_keys=verify_keys,
@@ -573,7 +582,7 @@ class Keyring(object):
yield defer.gatherResults(
[
- self.store.store_server_keys_json(
+ preserve_fn(self.store.store_server_keys_json)(
server_name=server_name,
key_id=key_id,
from_server=server_name,
@@ -675,7 +684,7 @@ class Keyring(object):
# TODO(markjh): Store whether the keys have expired.
yield defer.gatherResults(
[
- self.store.store_server_verify_key(
+ preserve_fn(self.store.store_server_verify_key)(
server_name, server_name, key.time_added, key
)
for key_id, key in verify_keys.items()
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index f51200d18e..8a475417a6 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -20,3 +20,4 @@ class EventContext(object):
self.current_state = current_state
self.state_group = None
self.rejected = False
+ self.push_actions = []
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index c6259f9dc8..e30e2da58d 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -57,7 +57,7 @@ class FederationClient(FederationBase):
cache_name="get_pdu_cache",
clock=self._clock,
max_len=1000,
- expiry_ms=120*1000,
+ expiry_ms=120 * 1000,
reset_expiry_on_get=False,
)
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index a97aa0c94a..90718192dd 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -126,10 +126,8 @@ class FederationServer(FederationBase):
results = []
for pdu in pdu_list:
- d = self._handle_new_pdu(transaction.origin, pdu)
-
try:
- yield d
+ yield self._handle_new_pdu(transaction.origin, pdu)
results.append({})
except FederationError as e:
self.send_failure(e, transaction.origin)
diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py
index 622adad3ae..1928da03b3 100644
--- a/synapse/federation/transaction_queue.py
+++ b/synapse/federation/transaction_queue.py
@@ -103,7 +103,6 @@ class TransactionQueue(object):
else:
return not destination.startswith("localhost")
- @defer.inlineCallbacks
def enqueue_pdu(self, pdu, destinations, order):
# We loop through all destinations to see whether we already have
# a transaction in progress. If we do, stick it in the pending_pdus
@@ -141,8 +140,6 @@ class TransactionQueue(object):
deferreds.append(deferred)
- yield defer.DeferredList(deferreds, consumeErrors=True)
-
# NO inlineCallbacks
def enqueue_edu(self, edu):
destination = edu.destination
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index 744a9ee507..064e8723c8 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -53,25 +53,10 @@ class BaseHandler(object):
self.event_builder_factory = hs.get_event_builder_factory()
@defer.inlineCallbacks
- def _filter_events_for_clients(self, user_tuples, events):
+ def _filter_events_for_clients(self, user_tuples, events, event_id_to_state):
""" Returns dict of user_id -> list of events that user is allowed to
see.
"""
- # If there is only one user, just get the state for that one user,
- # otherwise just get all the state.
- if len(user_tuples) == 1:
- types = (
- (EventTypes.RoomHistoryVisibility, ""),
- (EventTypes.Member, user_tuples[0][0]),
- )
- else:
- types = None
-
- event_id_to_state = yield self.store.get_state_for_events(
- frozenset(e.event_id for e in events),
- types=types
- )
-
forgotten = yield defer.gatherResults([
self.store.who_forgot_in_room(
room_id,
@@ -135,7 +120,17 @@ class BaseHandler(object):
@defer.inlineCallbacks
def _filter_events_for_client(self, user_id, events, is_peeking=False):
# Assumes that user has at some point joined the room if not is_guest.
- res = yield self._filter_events_for_clients([(user_id, is_peeking)], events)
+ types = (
+ (EventTypes.RoomHistoryVisibility, ""),
+ (EventTypes.Member, user_id),
+ )
+ event_id_to_state = yield self.store.get_state_for_events(
+ frozenset(e.event_id for e in events),
+ types=types
+ )
+ res = yield self._filter_events_for_clients(
+ [(user_id, is_peeking)], events, event_id_to_state
+ )
defer.returnValue(res.get(user_id, []))
def ratelimit(self, user_id):
@@ -147,7 +142,7 @@ class BaseHandler(object):
)
if not allowed:
raise LimitExceededError(
- retry_after_ms=int(1000*(time_allowed - time_now)),
+ retry_after_ms=int(1000 * (time_allowed - time_now)),
)
@defer.inlineCallbacks
@@ -269,13 +264,13 @@ class BaseHandler(object):
"You don't have permission to redact events"
)
- (event_stream_id, max_stream_id) = yield self.store.persist_event(
- event, context=context
- )
-
action_generator = ActionGenerator(self.hs)
yield action_generator.handle_push_actions_for_event(
- event, self
+ event, context, self
+ )
+
+ (event_stream_id, max_stream_id) = yield self.store.persist_event(
+ event, context=context
)
destinations = set()
@@ -293,19 +288,11 @@ class BaseHandler(object):
with PreserveLoggingContext():
# Don't block waiting on waking up all the listeners.
- notify_d = self.notifier.on_new_room_event(
+ self.notifier.on_new_room_event(
event, event_stream_id, max_stream_id,
extra_users=extra_users
)
- def log_failure(f):
- logger.warn(
- "Failed to notify about %s: %s",
- event.event_id, f.value
- )
-
- notify_d.addErrback(log_failure)
-
# If invite, remove room_state from unsigned before sending.
event.unsigned.pop("invite_room_state", None)
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index 691564c651..4efecb1ffd 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -175,8 +175,8 @@ class DirectoryHandler(BaseHandler):
# If this server is in the list of servers, return it first.
if self.server_name in servers:
servers = (
- [self.server_name]
- + [s for s in servers if s != self.server_name]
+ [self.server_name] +
+ [s for s in servers if s != self.server_name]
)
else:
servers = list(servers)
diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index 254b483da6..4933c31c19 100644
--- a/synapse/handlers/events.py
+++ b/synapse/handlers/events.py
@@ -18,6 +18,7 @@ from twisted.internet import defer
from synapse.util.logutils import log_function
from synapse.types import UserID
from synapse.events.utils import serialize_event
+from synapse.util.logcontext import preserve_context_over_fn
from ._base import BaseHandler
@@ -29,11 +30,17 @@ logger = logging.getLogger(__name__)
def started_user_eventstream(distributor, user):
- return distributor.fire("started_user_eventstream", user)
+ return preserve_context_over_fn(
+ distributor.fire,
+ "started_user_eventstream", user
+ )
def stopped_user_eventstream(distributor, user):
- return distributor.fire("stopped_user_eventstream", user)
+ return preserve_context_over_fn(
+ distributor.fire,
+ "stopped_user_eventstream", user
+ )
class EventStreamHandler(BaseHandler):
@@ -130,7 +137,7 @@ class EventStreamHandler(BaseHandler):
# Add some randomness to this value to try and mitigate against
# thundering herds on restart.
- timeout = random.randint(int(timeout*0.9), int(timeout*1.1))
+ timeout = random.randint(int(timeout * 0.9), int(timeout * 1.1))
events, tokens = yield self.notifier.get_events_for(
auth_user, pagin_config, timeout,
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 2ce1e9d6c7..da55d43541 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -221,19 +221,11 @@ class FederationHandler(BaseHandler):
extra_users.append(target_user)
with PreserveLoggingContext():
- d = self.notifier.on_new_room_event(
+ self.notifier.on_new_room_event(
event, event_stream_id, max_stream_id,
extra_users=extra_users
)
- def log_failure(f):
- logger.warn(
- "Failed to notify about %s: %s",
- event.event_id, f.value
- )
-
- d.addErrback(log_failure)
-
if event.type == EventTypes.Member:
if event.membership == Membership.JOIN:
prev_state = context.current_state.get((event.type, event.state_key))
@@ -244,12 +236,6 @@ class FederationHandler(BaseHandler):
user = UserID.from_string(event.state_key)
yield user_joined_room(self.distributor, user, event.room_id)
- if not backfilled and not event.internal_metadata.is_outlier():
- action_generator = ActionGenerator(self.hs)
- yield action_generator.handle_push_actions_for_event(
- event, self
- )
-
@defer.inlineCallbacks
def _filter_events_for_server(self, server_name, room_id, events):
event_to_state = yield self.store.get_state_for_events(
@@ -643,19 +629,11 @@ class FederationHandler(BaseHandler):
)
with PreserveLoggingContext():
- d = self.notifier.on_new_room_event(
+ self.notifier.on_new_room_event(
event, event_stream_id, max_stream_id,
extra_users=[joinee]
)
- def log_failure(f):
- logger.warn(
- "Failed to notify about %s: %s",
- event.event_id, f.value
- )
-
- d.addErrback(log_failure)
-
logger.debug("Finished joining %s to %s", joinee, room_id)
finally:
room_queue = self.room_queues[room_id]
@@ -730,18 +708,10 @@ class FederationHandler(BaseHandler):
extra_users.append(target_user)
with PreserveLoggingContext():
- d = self.notifier.on_new_room_event(
+ self.notifier.on_new_room_event(
event, event_stream_id, max_stream_id, extra_users=extra_users
)
- def log_failure(f):
- logger.warn(
- "Failed to notify about %s: %s",
- event.event_id, f.value
- )
-
- d.addErrback(log_failure)
-
if event.type == EventTypes.Member:
if event.content["membership"] == Membership.JOIN:
user = UserID.from_string(event.state_key)
@@ -811,19 +781,11 @@ class FederationHandler(BaseHandler):
target_user = UserID.from_string(event.state_key)
with PreserveLoggingContext():
- d = self.notifier.on_new_room_event(
+ self.notifier.on_new_room_event(
event, event_stream_id, max_stream_id,
extra_users=[target_user],
)
- def log_failure(f):
- logger.warn(
- "Failed to notify about %s: %s",
- event.event_id, f.value
- )
-
- d.addErrback(log_failure)
-
defer.returnValue(event)
@defer.inlineCallbacks
@@ -948,18 +910,10 @@ class FederationHandler(BaseHandler):
extra_users.append(target_user)
with PreserveLoggingContext():
- d = self.notifier.on_new_room_event(
+ self.notifier.on_new_room_event(
event, event_stream_id, max_stream_id, extra_users=extra_users
)
- def log_failure(f):
- logger.warn(
- "Failed to notify about %s: %s",
- event.event_id, f.value
- )
-
- d.addErrback(log_failure)
-
new_pdu = event
destinations = set()
@@ -1113,6 +1067,12 @@ class FederationHandler(BaseHandler):
auth_events=auth_events,
)
+ if not backfilled and not event.internal_metadata.is_outlier():
+ action_generator = ActionGenerator(self.hs)
+ yield action_generator.handle_push_actions_for_event(
+ event, context, self
+ )
+
event_stream_id, max_stream_id = yield self.store.persist_event(
event,
context=context,
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index 819ec57c4f..656ce124f9 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -36,14 +36,15 @@ class IdentityHandler(BaseHandler):
self.http_client = hs.get_simple_http_client()
+ self.trusted_id_servers = set(hs.config.trusted_third_party_id_servers)
+ self.trust_any_id_server_just_for_testing_do_not_use = (
+ hs.config.use_insecure_ssl_client_just_for_testing_do_not_use
+ )
+
@defer.inlineCallbacks
def threepid_from_creds(self, creds):
yield run_on_reactor()
- # XXX: make this configurable!
- # trustedIdServers = ['matrix.org', 'localhost:8090']
- trustedIdServers = ['matrix.org', 'vector.im']
-
if 'id_server' in creds:
id_server = creds['id_server']
elif 'idServer' in creds:
@@ -58,10 +59,19 @@ class IdentityHandler(BaseHandler):
else:
raise SynapseError(400, "No client_secret in creds")
- if id_server not in trustedIdServers:
- logger.warn('%s is not a trusted ID server: rejecting 3pid ' +
- 'credentials', id_server)
- defer.returnValue(None)
+ if id_server not in self.trusted_id_servers:
+ if self.trust_any_id_server_just_for_testing_do_not_use:
+ logger.warn(
+ "Trusting untrustworthy ID server %r even though it isn't"
+ " in the trusted id list for testing because"
+ " 'use_insecure_ssl_client_just_for_testing_do_not_use'"
+ " is set in the config",
+ id_server,
+ )
+ else:
+ logger.warn('%s is not a trusted ID server: rejecting 3pid ' +
+ 'credentials', id_server)
+ defer.returnValue(None)
data = {}
try:
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index d36eb3b8d7..b61394f2b5 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -34,7 +34,7 @@ metrics = synapse.metrics.get_metrics_for(__name__)
# Don't bother bumping "last active" time if it differs by less than 60 seconds
-LAST_ACTIVE_GRANULARITY = 60*1000
+LAST_ACTIVE_GRANULARITY = 60 * 1000
# Keep no more than this number of offline serial revisions
MAX_OFFLINE_SERIALS = 1000
@@ -378,9 +378,9 @@ class PresenceHandler(BaseHandler):
was_polling = target_user in self._user_cachemap
if now_online and not was_polling:
- self.start_polling_presence(target_user, state=state)
+ yield self.start_polling_presence(target_user, state=state)
elif not now_online and was_polling:
- self.stop_polling_presence(target_user)
+ yield self.stop_polling_presence(target_user)
# TODO(paul): perform a presence push as part of start/stop poll so
# we don't have to do this all the time
@@ -394,7 +394,8 @@ class PresenceHandler(BaseHandler):
if now - prev_state.state.get("last_active", 0) < LAST_ACTIVE_GRANULARITY:
return
- self.changed_presencelike_data(user, {"last_active": now})
+ with PreserveLoggingContext():
+ self.changed_presencelike_data(user, {"last_active": now})
def get_joined_rooms_for_user(self, user):
"""Get the list of rooms a user is joined to.
@@ -466,11 +467,12 @@ class PresenceHandler(BaseHandler):
local_user, room_ids=[room_id], add_to_cache=False
)
- self.push_update_to_local_and_remote(
- observed_user=local_user,
- users_to_push=[user],
- statuscache=statuscache,
- )
+ with PreserveLoggingContext():
+ self.push_update_to_local_and_remote(
+ observed_user=local_user,
+ users_to_push=[user],
+ statuscache=statuscache,
+ )
@defer.inlineCallbacks
def send_presence_invite(self, observer_user, observed_user):
@@ -556,7 +558,7 @@ class PresenceHandler(BaseHandler):
observer_user.localpart, observed_user.to_string()
)
- self.start_polling_presence(
+ yield self.start_polling_presence(
observer_user, target_user=observed_user
)
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index c11b98d0b7..24c850ae9b 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -21,7 +21,6 @@ from synapse.api.errors import (
AuthError, Codes, SynapseError, RegistrationError, InvalidCaptchaError
)
from ._base import BaseHandler
-import synapse.util.stringutils as stringutils
from synapse.util.async import run_on_reactor
from synapse.http.client import CaptchaServerHttpClient
@@ -45,6 +44,8 @@ class RegistrationHandler(BaseHandler):
self.distributor.declare("registered_user")
self.captcha_client = CaptchaServerHttpClient(hs)
+ self._next_generated_user_id = None
+
@defer.inlineCallbacks
def check_username(self, localpart, guest_access_token=None):
yield run_on_reactor()
@@ -91,7 +92,7 @@ class RegistrationHandler(BaseHandler):
Args:
localpart : The local part of the user ID to register. If None,
- one will be randomly generated.
+ one will be generated.
password (str) : The password to assign to this user so they can
login again. This can be None which means they cannot login again
via a password (e.g. the user is an application service user).
@@ -108,6 +109,18 @@ class RegistrationHandler(BaseHandler):
if localpart:
yield self.check_username(localpart, guest_access_token=guest_access_token)
+ was_guest = guest_access_token is not None
+
+ if not was_guest:
+ try:
+ int(localpart)
+ raise RegistrationError(
+ 400,
+ "Numeric user IDs are reserved for guest users."
+ )
+ except ValueError:
+ pass
+
user = UserID(localpart, self.hs.hostname)
user_id = user.to_string()
@@ -118,38 +131,36 @@ class RegistrationHandler(BaseHandler):
user_id=user_id,
token=token,
password_hash=password_hash,
- was_guest=guest_access_token is not None,
+ was_guest=was_guest,
make_guest=make_guest,
)
yield registered_user(self.distributor, user)
else:
- # autogen a random user ID
+ # autogen a sequential user ID
attempts = 0
- user_id = None
token = None
- while not user_id:
+ user = None
+ while not user:
+ localpart = yield self._generate_user_id(attempts > 0)
+ user = UserID(localpart, self.hs.hostname)
+ user_id = user.to_string()
+ yield self.check_user_id_is_valid(user_id)
+ if generate_token:
+ token = self.auth_handler().generate_access_token(user_id)
try:
- localpart = self._generate_user_id()
- user = UserID(localpart, self.hs.hostname)
- user_id = user.to_string()
- yield self.check_user_id_is_valid(user_id)
- if generate_token:
- token = self.auth_handler().generate_access_token(user_id)
yield self.store.register(
user_id=user_id,
token=token,
- password_hash=password_hash)
-
- yield registered_user(self.distributor, user)
+ password_hash=password_hash,
+ make_guest=make_guest
+ )
except SynapseError:
# if user id is taken, just generate another
user_id = None
token = None
attempts += 1
- if attempts > 5:
- raise RegistrationError(
- 500, "Cannot generate user ID.")
+ yield registered_user(self.distributor, user)
# We used to generate default identicons here, but nowadays
# we want clients to generate their own as part of their branding
@@ -175,7 +186,7 @@ class RegistrationHandler(BaseHandler):
token=token,
password_hash=""
)
- registered_user(self.distributor, user)
+ yield registered_user(self.distributor, user)
defer.returnValue((user_id, token))
@defer.inlineCallbacks
@@ -211,7 +222,7 @@ class RegistrationHandler(BaseHandler):
400,
"User ID must only contain characters which do not"
" require URL encoding."
- )
+ )
user = UserID(localpart, self.hs.hostname)
user_id = user.to_string()
@@ -281,8 +292,16 @@ class RegistrationHandler(BaseHandler):
errcode=Codes.EXCLUSIVE
)
- def _generate_user_id(self):
- return "-" + stringutils.random_string(18)
+ @defer.inlineCallbacks
+ def _generate_user_id(self, reseed=False):
+ if reseed or self._next_generated_user_id is None:
+ self._next_generated_user_id = (
+ yield self.store.find_next_generated_user_id_localpart()
+ )
+
+ id = self._next_generated_user_id
+ self._next_generated_user_id += 1
+ defer.returnValue(str(id))
@defer.inlineCallbacks
def _validate_captcha(self, ip_addr, private_key, challenge, response):
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 58e2d25f97..a8e3a9029c 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -18,13 +18,14 @@ from twisted.internet import defer
from ._base import BaseHandler
-from synapse.types import UserID, RoomAlias, RoomID
+from synapse.types import UserID, RoomAlias, RoomID, RoomStreamToken
from synapse.api.constants import (
EventTypes, Membership, JoinRules, RoomCreationPreset,
)
from synapse.api.errors import AuthError, StoreError, SynapseError, Codes
from synapse.util import stringutils, unwrapFirstError
from synapse.util.async import run_on_reactor
+from synapse.util.logcontext import preserve_context_over_fn
from signedjson.sign import verify_signed_json
from signedjson.key import decode_verify_key_bytes
@@ -46,11 +47,17 @@ def collect_presencelike_data(distributor, user, content):
def user_left_room(distributor, user, room_id):
- return distributor.fire("user_left_room", user=user, room_id=room_id)
+ return preserve_context_over_fn(
+ distributor.fire,
+ "user_left_room", user=user, room_id=room_id
+ )
def user_joined_room(distributor, user, room_id):
- return distributor.fire("user_joined_room", user=user, room_id=room_id)
+ return preserve_context_over_fn(
+ distributor.fire,
+ "user_joined_room", user=user, room_id=room_id
+ )
class RoomCreationHandler(BaseHandler):
@@ -876,39 +883,71 @@ class RoomListHandler(BaseHandler):
@defer.inlineCallbacks
def get_public_room_list(self):
- chunk = yield self.store.get_rooms(is_public=True)
-
- room_members = yield defer.gatherResults(
- [
- self.store.get_users_in_room(room["room_id"])
- for room in chunk
- ],
- consumeErrors=True,
- ).addErrback(unwrapFirstError)
-
- avatar_urls = yield defer.gatherResults(
- [
- self.get_room_avatar_url(room["room_id"])
- for room in chunk
- ],
- consumeErrors=True,
- ).addErrback(unwrapFirstError)
-
- for i, room in enumerate(chunk):
- room["num_joined_members"] = len(room_members[i])
- if avatar_urls[i]:
- room["avatar_url"] = avatar_urls[i]
+ room_ids = yield self.store.get_public_room_ids()
+
+ @defer.inlineCallbacks
+ def handle_room(room_id):
+ aliases = yield self.store.get_aliases_for_room(room_id)
+ if not aliases:
+ defer.returnValue(None)
+
+ state = yield self.state_handler.get_current_state(room_id)
+
+ result = {"aliases": aliases, "room_id": room_id}
+
+ name_event = state.get((EventTypes.Name, ""), None)
+ if name_event:
+ name = name_event.content.get("name", None)
+ if name:
+ result["name"] = name
+
+ topic_event = state.get((EventTypes.Topic, ""), None)
+ if topic_event:
+ topic = topic_event.content.get("topic", None)
+ if topic:
+ result["topic"] = topic
+
+ canonical_event = state.get((EventTypes.CanonicalAlias, ""), None)
+ if canonical_event:
+ canonical_alias = canonical_event.content.get("alias", None)
+ if canonical_alias:
+ result["canonical_alias"] = canonical_alias
+
+ visibility_event = state.get((EventTypes.RoomHistoryVisibility, ""), None)
+ visibility = None
+ if visibility_event:
+ visibility = visibility_event.content.get("history_visibility", None)
+ result["world_readable"] = visibility == "world_readable"
+
+ guest_event = state.get((EventTypes.GuestAccess, ""), None)
+ guest = None
+ if guest_event:
+ guest = guest_event.content.get("guest_access", None)
+ result["guest_can_join"] = guest == "can_join"
+
+ avatar_event = state.get(("m.room.avatar", ""), None)
+ if avatar_event:
+ avatar_url = avatar_event.content.get("url", None)
+ if avatar_url:
+ result["avatar_url"] = avatar_url
+
+ result["num_joined_members"] = sum(
+ 1 for (event_type, _), ev in state.items()
+ if event_type == EventTypes.Member and ev.membership == Membership.JOIN
+ )
- # FIXME (erikj): START is no longer a valid value
- defer.returnValue({"start": "START", "end": "END", "chunk": chunk})
+ defer.returnValue(result)
- @defer.inlineCallbacks
- def get_room_avatar_url(self, room_id):
- event = yield self.hs.get_state_handler().get_current_state(
- room_id, "m.room.avatar"
- )
- if event and "url" in event.content:
- defer.returnValue(event.content["url"])
+ result = []
+ for chunk in (room_ids[i:i + 10] for i in xrange(0, len(room_ids), 10)):
+ chunk_result = yield defer.gatherResults([
+ handle_room(room_id)
+ for room_id in chunk
+ ], consumeErrors=True).addErrback(unwrapFirstError)
+ result.extend(v for v in chunk_result if v)
+
+ # FIXME (erikj): START is no longer a valid value
+ defer.returnValue({"start": "START", "end": "END", "chunk": result})
class RoomContextHandler(BaseHandler):
@@ -927,7 +966,7 @@ class RoomContextHandler(BaseHandler):
Returns:
dict, or None if the event isn't found
"""
- before_limit = math.floor(limit/2.)
+ before_limit = math.floor(limit / 2.)
after_limit = limit - before_limit
now_token = yield self.hs.get_event_sources().get_current_token()
@@ -997,6 +1036,11 @@ class RoomEventSource(object):
to_key = yield self.get_current_key()
+ from_token = RoomStreamToken.parse(from_key)
+ if from_token.topological:
+ logger.warn("Stream has topological part!!!! %r", from_key)
+ from_key = "s%s" % (from_token.stream,)
+
app_service = yield self.store.get_app_service_by_user_id(
user.to_string()
)
@@ -1008,15 +1052,30 @@ class RoomEventSource(object):
limit=limit,
)
else:
- events, end_key = yield self.store.get_room_events_stream(
- user_id=user.to_string(),
+ room_events = yield self.store.get_membership_changes_for_user(
+ user.to_string(), from_key, to_key
+ )
+
+ room_to_events = yield self.store.get_room_events_stream_for_rooms(
+ room_ids=room_ids,
from_key=from_key,
to_key=to_key,
- limit=limit,
- room_ids=room_ids,
- is_guest=is_guest,
+ limit=limit or 10,
)
+ events = list(room_events)
+ events.extend(e for evs, _ in room_to_events.values() for e in evs)
+
+ events.sort(key=lambda e: e.internal_metadata.order)
+
+ if limit:
+ events[:] = events[:limit]
+
+ if events:
+ end_key = events[-1].internal_metadata.after
+ else:
+ end_key = to_key
+
defer.returnValue((events, end_key))
def get_current_key(self, direction='f'):
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 075566417f..1d0f0058a2 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -18,11 +18,14 @@ from ._base import BaseHandler
from synapse.streams.config import PaginationConfig
from synapse.api.constants import Membership, EventTypes
from synapse.util import unwrapFirstError
+from synapse.util.logcontext import LoggingContext, preserve_fn
+from synapse.util.metrics import Measure
from twisted.internet import defer
import collections
import logging
+import itertools
logger = logging.getLogger(__name__)
@@ -139,6 +142,15 @@ class SyncHandler(BaseHandler):
A Deferred SyncResult.
"""
+ context = LoggingContext.current_context()
+ if context:
+ if since_token is None:
+ context.tag = "initial_sync"
+ elif full_state:
+ context.tag = "full_state_sync"
+ else:
+ context.tag = "incremental_sync"
+
if timeout == 0 or since_token is None or full_state:
# we are going to return immediately, so don't bother calling
# notifier.wait_for_events.
@@ -167,18 +179,6 @@ class SyncHandler(BaseHandler):
else:
return self.incremental_sync_with_gap(sync_config, since_token)
- def last_read_event_id_for_room_and_user(self, room_id, user_id, ephemeral_by_room):
- if room_id not in ephemeral_by_room:
- return None
- for e in ephemeral_by_room[room_id]:
- if e['type'] != 'm.receipt':
- continue
- for receipt_event_id, val in e['content'].items():
- if 'm.read' in val:
- if user_id in val['m.read']:
- return receipt_event_id
- return None
-
@defer.inlineCallbacks
def full_state_sync(self, sync_config, timeline_since_token):
"""Get a sync for a client which is starting without any state.
@@ -228,44 +228,51 @@ class SyncHandler(BaseHandler):
invited = []
archived = []
deferreds = []
- for event in room_list:
- if event.membership == Membership.JOIN:
- room_sync_deferred = self.full_state_sync_for_joined_room(
- room_id=event.room_id,
- sync_config=sync_config,
- now_token=now_token,
- timeline_since_token=timeline_since_token,
- ephemeral_by_room=ephemeral_by_room,
- tags_by_room=tags_by_room,
- account_data_by_room=account_data_by_room,
- )
- room_sync_deferred.addCallback(joined.append)
- deferreds.append(room_sync_deferred)
- elif event.membership == Membership.INVITE:
- invite = yield self.store.get_event(event.event_id)
- invited.append(InvitedSyncResult(
- room_id=event.room_id,
- invite=invite,
- ))
- elif event.membership in (Membership.LEAVE, Membership.BAN):
- leave_token = now_token.copy_and_replace(
- "room_key", "s%d" % (event.stream_ordering,)
- )
- room_sync_deferred = self.full_state_sync_for_archived_room(
- sync_config=sync_config,
- room_id=event.room_id,
- leave_event_id=event.event_id,
- leave_token=leave_token,
- timeline_since_token=timeline_since_token,
- tags_by_room=tags_by_room,
- account_data_by_room=account_data_by_room,
- )
- room_sync_deferred.addCallback(archived.append)
- deferreds.append(room_sync_deferred)
- yield defer.gatherResults(
- deferreds, consumeErrors=True
- ).addErrback(unwrapFirstError)
+ room_list_chunks = [room_list[i:i + 10] for i in xrange(0, len(room_list), 10)]
+ for room_list_chunk in room_list_chunks:
+ for event in room_list_chunk:
+ if event.membership == Membership.JOIN:
+ room_sync_deferred = preserve_fn(
+ self.full_state_sync_for_joined_room
+ )(
+ room_id=event.room_id,
+ sync_config=sync_config,
+ now_token=now_token,
+ timeline_since_token=timeline_since_token,
+ ephemeral_by_room=ephemeral_by_room,
+ tags_by_room=tags_by_room,
+ account_data_by_room=account_data_by_room,
+ )
+ room_sync_deferred.addCallback(joined.append)
+ deferreds.append(room_sync_deferred)
+ elif event.membership == Membership.INVITE:
+ invite = yield self.store.get_event(event.event_id)
+ invited.append(InvitedSyncResult(
+ room_id=event.room_id,
+ invite=invite,
+ ))
+ elif event.membership in (Membership.LEAVE, Membership.BAN):
+ leave_token = now_token.copy_and_replace(
+ "room_key", "s%d" % (event.stream_ordering,)
+ )
+ room_sync_deferred = preserve_fn(
+ self.full_state_sync_for_archived_room
+ )(
+ sync_config=sync_config,
+ room_id=event.room_id,
+ leave_event_id=event.event_id,
+ leave_token=leave_token,
+ timeline_since_token=timeline_since_token,
+ tags_by_room=tags_by_room,
+ account_data_by_room=account_data_by_room,
+ )
+ room_sync_deferred.addCallback(archived.append)
+ deferreds.append(room_sync_deferred)
+
+ yield defer.gatherResults(
+ deferreds, consumeErrors=True
+ ).addErrback(unwrapFirstError)
account_data_for_user = sync_config.filter_collection.filter_account_data(
self.account_data_for_user(account_data)
@@ -305,7 +312,6 @@ class SyncHandler(BaseHandler):
ephemeral_by_room=ephemeral_by_room,
tags_by_room=tags_by_room,
account_data_by_room=account_data_by_room,
- all_ephemeral_by_room=ephemeral_by_room,
batch=batch,
full_state=True,
)
@@ -355,50 +361,51 @@ class SyncHandler(BaseHandler):
typing events for that room.
"""
- typing_key = since_token.typing_key if since_token else "0"
-
- rooms = yield self.store.get_rooms_for_user(sync_config.user.to_string())
- room_ids = [room.room_id for room in rooms]
-
- typing_source = self.event_sources.sources["typing"]
- typing, typing_key = yield typing_source.get_new_events(
- user=sync_config.user,
- from_key=typing_key,
- limit=sync_config.filter_collection.ephemeral_limit(),
- room_ids=room_ids,
- is_guest=sync_config.is_guest,
- )
- now_token = now_token.copy_and_replace("typing_key", typing_key)
-
- ephemeral_by_room = {}
-
- for event in typing:
- # we want to exclude the room_id from the event, but modifying the
- # result returned by the event source is poor form (it might cache
- # the object)
- room_id = event["room_id"]
- event_copy = {k: v for (k, v) in event.iteritems()
- if k != "room_id"}
- ephemeral_by_room.setdefault(room_id, []).append(event_copy)
+ with Measure(self.clock, "ephemeral_by_room"):
+ typing_key = since_token.typing_key if since_token else "0"
- receipt_key = since_token.receipt_key if since_token else "0"
+ rooms = yield self.store.get_rooms_for_user(sync_config.user.to_string())
+ room_ids = [room.room_id for room in rooms]
- receipt_source = self.event_sources.sources["receipt"]
- receipts, receipt_key = yield receipt_source.get_new_events(
- user=sync_config.user,
- from_key=receipt_key,
- limit=sync_config.filter_collection.ephemeral_limit(),
- room_ids=room_ids,
- is_guest=sync_config.is_guest,
- )
- now_token = now_token.copy_and_replace("receipt_key", receipt_key)
+ typing_source = self.event_sources.sources["typing"]
+ typing, typing_key = yield typing_source.get_new_events(
+ user=sync_config.user,
+ from_key=typing_key,
+ limit=sync_config.filter_collection.ephemeral_limit(),
+ room_ids=room_ids,
+ is_guest=sync_config.is_guest,
+ )
+ now_token = now_token.copy_and_replace("typing_key", typing_key)
+
+ ephemeral_by_room = {}
+
+ for event in typing:
+ # we want to exclude the room_id from the event, but modifying the
+ # result returned by the event source is poor form (it might cache
+ # the object)
+ room_id = event["room_id"]
+ event_copy = {k: v for (k, v) in event.iteritems()
+ if k != "room_id"}
+ ephemeral_by_room.setdefault(room_id, []).append(event_copy)
+
+ receipt_key = since_token.receipt_key if since_token else "0"
+
+ receipt_source = self.event_sources.sources["receipt"]
+ receipts, receipt_key = yield receipt_source.get_new_events(
+ user=sync_config.user,
+ from_key=receipt_key,
+ limit=sync_config.filter_collection.ephemeral_limit(),
+ room_ids=room_ids,
+ is_guest=sync_config.is_guest,
+ )
+ now_token = now_token.copy_and_replace("receipt_key", receipt_key)
- for event in receipts:
- room_id = event["room_id"]
- # exclude room id, as above
- event_copy = {k: v for (k, v) in event.iteritems()
- if k != "room_id"}
- ephemeral_by_room.setdefault(room_id, []).append(event_copy)
+ for event in receipts:
+ room_id = event["room_id"]
+ # exclude room id, as above
+ event_copy = {k: v for (k, v) in event.iteritems()
+ if k != "room_id"}
+ ephemeral_by_room.setdefault(room_id, []).append(event_copy)
defer.returnValue((now_token, ephemeral_by_room))
@@ -438,13 +445,6 @@ class SyncHandler(BaseHandler):
)
now_token = now_token.copy_and_replace("presence_key", presence_key)
- # We now fetch all ephemeral events for this room in order to get
- # this users current read receipt. This could almost certainly be
- # optimised.
- _, all_ephemeral_by_room = yield self.ephemeral_by_room(
- sync_config, now_token
- )
-
now_token, ephemeral_by_room = yield self.ephemeral_by_room(
sync_config, now_token, since_token
)
@@ -478,7 +478,7 @@ class SyncHandler(BaseHandler):
)
# Get a list of membership change events that have happened.
- rooms_changed = yield self.store.get_room_changes_for_user(
+ rooms_changed = yield self.store.get_membership_changes_for_user(
user_id, since_token.room_key, now_token.room_key
)
@@ -576,7 +576,6 @@ class SyncHandler(BaseHandler):
ephemeral_by_room=ephemeral_by_room,
tags_by_room=tags_by_room,
account_data_by_room=account_data_by_room,
- all_ephemeral_by_room=all_ephemeral_by_room,
batch=batch,
full_state=full_state,
)
@@ -606,58 +605,64 @@ class SyncHandler(BaseHandler):
"""
:returns a Deferred TimelineBatch
"""
- filtering_factor = 2
- timeline_limit = sync_config.filter_collection.timeline_limit()
- load_limit = max(timeline_limit * filtering_factor, 10)
- max_repeat = 5 # Only try a few times per room, otherwise
- room_key = now_token.room_key
- end_key = room_key
-
- limited = recents is None or newly_joined_room or timeline_limit < len(recents)
-
- if recents is not None:
- recents = sync_config.filter_collection.filter_room_timeline(recents)
- recents = yield self._filter_events_for_client(
- sync_config.user.to_string(),
- recents,
- is_peeking=sync_config.is_guest,
- )
- else:
- recents = []
-
- since_key = None
- if since_token and not newly_joined_room:
- since_key = since_token.room_key
-
- while limited and len(recents) < timeline_limit and max_repeat:
- events, end_key = yield self.store.get_room_events_stream_for_room(
- room_id,
- limit=load_limit + 1,
- from_key=since_key,
- to_key=end_key,
- )
- loaded_recents = sync_config.filter_collection.filter_room_timeline(events)
- loaded_recents = yield self._filter_events_for_client(
- sync_config.user.to_string(),
- loaded_recents,
- is_peeking=sync_config.is_guest,
- )
- loaded_recents.extend(recents)
- recents = loaded_recents
-
- if len(events) <= load_limit:
+ with Measure(self.clock, "load_filtered_recents"):
+ filtering_factor = 2
+ timeline_limit = sync_config.filter_collection.timeline_limit()
+ load_limit = max(timeline_limit * filtering_factor, 10)
+ max_repeat = 5 # Only try a few times per room, otherwise
+ room_key = now_token.room_key
+ end_key = room_key
+
+ if recents is None or newly_joined_room or timeline_limit < len(recents):
+ limited = True
+ else:
limited = False
- break
- max_repeat -= 1
- if len(recents) > timeline_limit:
- limited = True
- recents = recents[-timeline_limit:]
- room_key = recents[0].internal_metadata.before
+ if recents is not None:
+ recents = sync_config.filter_collection.filter_room_timeline(recents)
+ recents = yield self._filter_events_for_client(
+ sync_config.user.to_string(),
+ recents,
+ is_peeking=sync_config.is_guest,
+ )
+ else:
+ recents = []
+
+ since_key = None
+ if since_token and not newly_joined_room:
+ since_key = since_token.room_key
+
+ while limited and len(recents) < timeline_limit and max_repeat:
+ events, end_key = yield self.store.get_room_events_stream_for_room(
+ room_id,
+ limit=load_limit + 1,
+ from_key=since_key,
+ to_key=end_key,
+ )
+ loaded_recents = sync_config.filter_collection.filter_room_timeline(
+ events
+ )
+ loaded_recents = yield self._filter_events_for_client(
+ sync_config.user.to_string(),
+ loaded_recents,
+ is_peeking=sync_config.is_guest,
+ )
+ loaded_recents.extend(recents)
+ recents = loaded_recents
- prev_batch_token = now_token.copy_and_replace(
- "room_key", room_key
- )
+ if len(events) <= load_limit:
+ limited = False
+ break
+ max_repeat -= 1
+
+ if len(recents) > timeline_limit:
+ limited = True
+ recents = recents[-timeline_limit:]
+ room_key = recents[0].internal_metadata.before
+
+ prev_batch_token = now_token.copy_and_replace(
+ "room_key", room_key
+ )
defer.returnValue(TimelineBatch(
events=recents,
@@ -670,37 +675,11 @@ class SyncHandler(BaseHandler):
since_token, now_token,
ephemeral_by_room, tags_by_room,
account_data_by_room,
- all_ephemeral_by_room,
batch, full_state=False):
- if full_state:
- state = yield self.get_state_at(room_id, now_token)
-
- elif batch.limited:
- current_state = yield self.get_state_at(room_id, now_token)
-
- state_at_previous_sync = yield self.get_state_at(
- room_id, stream_position=since_token
- )
-
- state = yield self.compute_state_delta(
- since_token=since_token,
- previous_state=state_at_previous_sync,
- current_state=current_state,
- )
- else:
- state = {
- (event.type, event.state_key): event
- for event in batch.events if event.is_state()
- }
-
- just_joined = yield self.check_joined_room(sync_config, state)
- if just_joined:
- state = yield self.get_state_at(room_id, now_token)
-
- state = {
- (e.type, e.state_key): e
- for e in sync_config.filter_collection.filter_room_state(state.values())
- }
+ state = yield self.compute_state_delta(
+ room_id, batch, sync_config, since_token, now_token,
+ full_state=full_state
+ )
account_data = self.account_data_for_room(
room_id, tags_by_room, account_data_by_room
@@ -726,14 +705,12 @@ class SyncHandler(BaseHandler):
if room_sync:
notifs = yield self.unread_notifs_for_room_id(
- room_id, sync_config, all_ephemeral_by_room
+ room_id, sync_config
)
if notifs is not None:
- unread_notifications["notification_count"] = len(notifs)
- unread_notifications["highlight_count"] = len([
- 1 for notif in notifs if _action_has_highlight(notif["actions"])
- ])
+ unread_notifications["notification_count"] = notifs["notify_count"]
+ unread_notifications["highlight_count"] = notifs["highlight_count"]
logger.debug("Room sync: %r", room_sync)
@@ -766,30 +743,11 @@ class SyncHandler(BaseHandler):
logger.debug("Recents %r", batch)
- state_events_at_leave = yield self.store.get_state_for_event(
- leave_event_id
+ state_events_delta = yield self.compute_state_delta(
+ room_id, batch, sync_config, since_token, leave_token,
+ full_state=full_state
)
- if not full_state:
- state_at_previous_sync = yield self.get_state_at(
- room_id, stream_position=since_token
- )
-
- state_events_delta = yield self.compute_state_delta(
- since_token=since_token,
- previous_state=state_at_previous_sync,
- current_state=state_events_at_leave,
- )
- else:
- state_events_delta = state_events_at_leave
-
- state_events_delta = {
- (e.type, e.state_key): e
- for e in sync_config.filter_collection.filter_room_state(
- state_events_delta.values()
- )
- }
-
account_data = self.account_data_for_room(
room_id, tags_by_room, account_data_by_room
)
@@ -843,15 +801,19 @@ class SyncHandler(BaseHandler):
state = {}
defer.returnValue(state)
- def compute_state_delta(self, since_token, previous_state, current_state):
- """ Works out the differnce in state between the current state and the
- state the client got when it last performed a sync.
-
- :param str since_token: the point we are comparing against
- :param dict[(str,str), synapse.events.FrozenEvent] previous_state: the
- state to compare to
- :param dict[(str,str), synapse.events.FrozenEvent] current_state: the
- new state
+ @defer.inlineCallbacks
+ def compute_state_delta(self, room_id, batch, sync_config, since_token, now_token,
+ full_state):
+ """ Works out the differnce in state between the start of the timeline
+ and the previous sync.
+
+ :param str room_id
+ :param TimelineBatch batch: The timeline batch for the room that will
+ be sent to the user.
+ :param sync_config
+ :param str since_token: Token of the end of the previous batch. May be None.
+ :param str now_token: Token of the end of the current batch.
+ :param bool full_state: Whether to force returning the full state.
:returns A new event dictionary
"""
@@ -860,12 +822,53 @@ class SyncHandler(BaseHandler):
# updates even if they occured logically before the previous event.
# TODO(mjark) Check for new redactions in the state events.
- state_delta = {}
- for key, event in current_state.iteritems():
- if (key not in previous_state or
- previous_state[key].event_id != event.event_id):
- state_delta[key] = event
- return state_delta
+ with Measure(self.clock, "compute_state_delta"):
+ if full_state:
+ if batch:
+ state = yield self.store.get_state_for_event(
+ batch.events[0].event_id
+ )
+ else:
+ state = yield self.get_state_at(
+ room_id, stream_position=now_token
+ )
+
+ timeline_state = {
+ (event.type, event.state_key): event
+ for event in batch.events if event.is_state()
+ }
+
+ state = _calculate_state(
+ timeline_contains=timeline_state,
+ timeline_start=state,
+ previous={},
+ )
+ elif batch.limited:
+ state_at_previous_sync = yield self.get_state_at(
+ room_id, stream_position=since_token
+ )
+
+ state_at_timeline_start = yield self.store.get_state_for_event(
+ batch.events[0].event_id
+ )
+
+ timeline_state = {
+ (event.type, event.state_key): event
+ for event in batch.events if event.is_state()
+ }
+
+ state = _calculate_state(
+ timeline_contains=timeline_state,
+ timeline_start=state_at_timeline_start,
+ previous=state_at_previous_sync,
+ )
+ else:
+ state = {}
+
+ defer.returnValue({
+ (e.type, e.state_key): e
+ for e in sync_config.filter_collection.filter_room_state(state.values())
+ })
def check_joined_room(self, sync_config, state_delta):
"""
@@ -886,21 +889,24 @@ class SyncHandler(BaseHandler):
return False
@defer.inlineCallbacks
- def unread_notifs_for_room_id(self, room_id, sync_config, ephemeral_by_room):
- last_unread_event_id = self.last_read_event_id_for_room_and_user(
- room_id, sync_config.user.to_string(), ephemeral_by_room
- )
-
- notifs = []
- if last_unread_event_id:
- notifs = yield self.store.get_unread_event_push_actions_by_room_for_user(
- room_id, sync_config.user.to_string(), last_unread_event_id
+ def unread_notifs_for_room_id(self, room_id, sync_config):
+ with Measure(self.clock, "unread_notifs_for_room_id"):
+ last_unread_event_id = yield self.store.get_last_receipt_event_id_for_user(
+ user_id=sync_config.user.to_string(),
+ room_id=room_id,
+ receipt_type="m.read"
)
- defer.returnValue(notifs)
- # There is no new information in this period, so your notification
- # count is whatever it was last time.
- defer.returnValue(None)
+ notifs = []
+ if last_unread_event_id:
+ notifs = yield self.store.get_unread_event_push_actions_by_room_for_user(
+ room_id, sync_config.user.to_string(), last_unread_event_id
+ )
+ defer.returnValue(notifs)
+
+ # There is no new information in this period, so your notification
+ # count is whatever it was last time.
+ defer.returnValue(None)
def _action_has_highlight(actions):
@@ -912,3 +918,37 @@ def _action_has_highlight(actions):
pass
return False
+
+
+def _calculate_state(timeline_contains, timeline_start, previous):
+ """Works out what state to include in a sync response.
+
+ Args:
+ timeline_contains (dict): state in the timeline
+ timeline_start (dict): state at the start of the timeline
+ previous (dict): state at the end of the previous sync (or empty dict
+ if this is an initial sync)
+
+ Returns:
+ dict
+ """
+ event_id_to_state = {
+ e.event_id: e
+ for e in itertools.chain(
+ timeline_contains.values(),
+ previous.values(),
+ timeline_start.values(),
+ )
+ }
+
+ tc_ids = set(e.event_id for e in timeline_contains.values())
+ p_ids = set(e.event_id for e in previous.values())
+ ts_ids = set(e.event_id for e in timeline_start.values())
+
+ state_ids = (ts_ids - p_ids) - tc_ids
+
+ evs = (event_id_to_state[e] for e in state_ids)
+ return {
+ (e.type, e.state_key): e
+ for e in evs
+ }
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index 43bf600913..b16d0017df 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -19,6 +19,7 @@ from ._base import BaseHandler
from synapse.api.errors import SynapseError, AuthError
from synapse.util.logcontext import PreserveLoggingContext
+from synapse.util.metrics import Measure
from synapse.types import UserID
import logging
@@ -222,6 +223,7 @@ class TypingNotificationHandler(BaseHandler):
class TypingNotificationEventSource(object):
def __init__(self, hs):
self.hs = hs
+ self.clock = hs.get_clock()
self._handler = None
self._room_member_handler = None
@@ -247,19 +249,20 @@ class TypingNotificationEventSource(object):
}
def get_new_events(self, from_key, room_ids, **kwargs):
- from_key = int(from_key)
- handler = self.handler()
+ with Measure(self.clock, "typing.get_new_events"):
+ from_key = int(from_key)
+ handler = self.handler()
- events = []
- for room_id in room_ids:
- if room_id not in handler._room_serials:
- continue
- if handler._room_serials[room_id] <= from_key:
- continue
+ events = []
+ for room_id in room_ids:
+ if room_id not in handler._room_serials:
+ continue
+ if handler._room_serials[room_id] <= from_key:
+ continue
- events.append(self._make_event_for(room_id))
+ events.append(self._make_event_for(room_id))
- return events, handler._latest_room_serial
+ return events, handler._latest_room_serial
def get_current_key(self):
return self.handler()._latest_room_serial
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index da13e32e78..c3589534f8 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -152,7 +152,7 @@ class MatrixFederationHttpClient(object):
return self.clock.time_bound_deferred(
request_deferred,
- time_out=timeout/1000. if timeout else 60,
+ time_out=timeout / 1000. if timeout else 60,
)
response = yield preserve_context_over_fn(
diff --git a/synapse/http/server.py b/synapse/http/server.py
index 10d1fcd3f6..a90e2e1125 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -41,7 +41,7 @@ metrics = synapse.metrics.get_metrics_for(__name__)
incoming_requests_counter = metrics.register_counter(
"requests",
- labels=["method", "servlet"],
+ labels=["method", "servlet", "tag"],
)
outgoing_responses_counter = metrics.register_counter(
"responses",
@@ -50,23 +50,23 @@ outgoing_responses_counter = metrics.register_counter(
response_timer = metrics.register_distribution(
"response_time",
- labels=["method", "servlet"]
+ labels=["method", "servlet", "tag"]
)
response_ru_utime = metrics.register_distribution(
- "response_ru_utime", labels=["method", "servlet"]
+ "response_ru_utime", labels=["method", "servlet", "tag"]
)
response_ru_stime = metrics.register_distribution(
- "response_ru_stime", labels=["method", "servlet"]
+ "response_ru_stime", labels=["method", "servlet", "tag"]
)
response_db_txn_count = metrics.register_distribution(
- "response_db_txn_count", labels=["method", "servlet"]
+ "response_db_txn_count", labels=["method", "servlet", "tag"]
)
response_db_txn_duration = metrics.register_distribution(
- "response_db_txn_duration", labels=["method", "servlet"]
+ "response_db_txn_duration", labels=["method", "servlet", "tag"]
)
@@ -99,9 +99,8 @@ def request_handler(request_handler):
request_context.request = request_id
with request.processing():
try:
- d = request_handler(self, request)
- with PreserveLoggingContext():
- yield d
+ with PreserveLoggingContext(request_context):
+ yield request_handler(self, request)
except CodeMessageException as e:
code = e.code
if isinstance(e, SynapseError):
@@ -208,6 +207,9 @@ class JsonResource(HttpServer, resource.Resource):
if request.method == "OPTIONS":
self._send_response(request, 200, {})
return
+
+ start_context = LoggingContext.current_context()
+
# Loop through all the registered callbacks to check if the method
# and path regex match
for path_entry in self.path_regexs.get(request.method, []):
@@ -226,7 +228,6 @@ class JsonResource(HttpServer, resource.Resource):
servlet_classname = servlet_instance.__class__.__name__
else:
servlet_classname = "%r" % callback
- incoming_requests_counter.inc(request.method, servlet_classname)
args = [
urllib.unquote(u).decode("UTF-8") if u else u for u in m.groups()
@@ -237,21 +238,40 @@ class JsonResource(HttpServer, resource.Resource):
code, response = callback_return
self._send_response(request, code, response)
- response_timer.inc_by(
- self.clock.time_msec() - start, request.method, servlet_classname
- )
-
try:
context = LoggingContext.current_context()
+
+ tag = ""
+ if context:
+ tag = context.tag
+
+ if context != start_context:
+ logger.warn(
+ "Context have unexpectedly changed %r, %r",
+ context, self.start_context
+ )
+ return
+
+ incoming_requests_counter.inc(request.method, servlet_classname, tag)
+
+ response_timer.inc_by(
+ self.clock.time_msec() - start, request.method,
+ servlet_classname, tag
+ )
+
ru_utime, ru_stime = context.get_resource_usage()
- response_ru_utime.inc_by(ru_utime, request.method, servlet_classname)
- response_ru_stime.inc_by(ru_stime, request.method, servlet_classname)
+ response_ru_utime.inc_by(
+ ru_utime, request.method, servlet_classname, tag
+ )
+ response_ru_stime.inc_by(
+ ru_stime, request.method, servlet_classname, tag
+ )
response_db_txn_count.inc_by(
- context.db_txn_count, request.method, servlet_classname
+ context.db_txn_count, request.method, servlet_classname, tag
)
response_db_txn_duration.inc_by(
- context.db_txn_duration, request.method, servlet_classname
+ context.db_txn_duration, request.method, servlet_classname, tag
)
except:
pass
diff --git a/synapse/notifier.py b/synapse/notifier.py
index 6eaa65071e..560866b26e 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -18,10 +18,13 @@ from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError
from synapse.util.logutils import log_function
-from synapse.util.async import run_on_reactor, ObservableDeferred
+from synapse.util.async import ObservableDeferred
+from synapse.util.logcontext import PreserveLoggingContext
from synapse.types import StreamToken
import synapse.metrics
+from collections import namedtuple
+
import logging
@@ -71,7 +74,8 @@ class _NotifierUserStream(object):
self.current_token = current_token
self.last_notified_ms = time_now_ms
- self.notify_deferred = ObservableDeferred(defer.Deferred())
+ with PreserveLoggingContext():
+ self.notify_deferred = ObservableDeferred(defer.Deferred())
def notify(self, stream_key, stream_id, time_now_ms):
"""Notify any listeners for this user of a new event from an
@@ -86,8 +90,10 @@ class _NotifierUserStream(object):
)
self.last_notified_ms = time_now_ms
noify_deferred = self.notify_deferred
- self.notify_deferred = ObservableDeferred(defer.Deferred())
- noify_deferred.callback(self.current_token)
+
+ with PreserveLoggingContext():
+ self.notify_deferred = ObservableDeferred(defer.Deferred())
+ noify_deferred.callback(self.current_token)
def remove(self, notifier):
""" Remove this listener from all the indexes in the Notifier
@@ -118,6 +124,11 @@ class _NotifierUserStream(object):
return _NotificationListener(self.notify_deferred.observe())
+class EventStreamResult(namedtuple("EventStreamResult", ("events", "tokens"))):
+ def __nonzero__(self):
+ return bool(self.events)
+
+
class Notifier(object):
""" This class is responsible for notifying any listeners when there are
new events available for it.
@@ -177,8 +188,6 @@ class Notifier(object):
lambda: count(bool, self.appservice_to_user_streams.values()),
)
- @log_function
- @defer.inlineCallbacks
def on_new_room_event(self, event, room_stream_id, max_room_stream_id,
extra_users=[]):
""" Used by handlers to inform the notifier something has happened
@@ -192,12 +201,11 @@ class Notifier(object):
until all previous events have been persisted before notifying
the client streams.
"""
- yield run_on_reactor()
-
- self.pending_new_room_events.append((
- room_stream_id, event, extra_users
- ))
- self._notify_pending_new_room_events(max_room_stream_id)
+ with PreserveLoggingContext():
+ self.pending_new_room_events.append((
+ room_stream_id, event, extra_users
+ ))
+ self._notify_pending_new_room_events(max_room_stream_id)
def _notify_pending_new_room_events(self, max_room_stream_id):
"""Notify for the room events that were queued waiting for a previous
@@ -244,31 +252,29 @@ class Notifier(object):
extra_streams=app_streams,
)
- @defer.inlineCallbacks
- @log_function
def on_new_event(self, stream_key, new_token, users=[], rooms=[],
extra_streams=set()):
""" Used to inform listeners that something has happend event wise.
Will wake up all listeners for the given users and rooms.
"""
- yield run_on_reactor()
- user_streams = set()
+ with PreserveLoggingContext():
+ user_streams = set()
- for user in users:
- user_stream = self.user_to_user_stream.get(str(user))
- if user_stream is not None:
- user_streams.add(user_stream)
+ for user in users:
+ user_stream = self.user_to_user_stream.get(str(user))
+ if user_stream is not None:
+ user_streams.add(user_stream)
- for room in rooms:
- user_streams |= self.room_to_user_streams.get(room, set())
+ for room in rooms:
+ user_streams |= self.room_to_user_streams.get(room, set())
- time_now_ms = self.clock.time_msec()
- for user_stream in user_streams:
- try:
- user_stream.notify(stream_key, new_token, time_now_ms)
- except:
- logger.exception("Failed to notify listener")
+ time_now_ms = self.clock.time_msec()
+ for user_stream in user_streams:
+ try:
+ user_stream.notify(stream_key, new_token, time_now_ms)
+ except:
+ logger.exception("Failed to notify listener")
@defer.inlineCallbacks
def wait_for_events(self, user_id, timeout, callback, room_ids=None,
@@ -301,7 +307,7 @@ class Notifier(object):
def timed_out():
if listener:
listener.deferred.cancel()
- timer = self.clock.call_later(timeout/1000., timed_out)
+ timer = self.clock.call_later(timeout / 1000., timed_out)
prev_token = from_token
while not result:
@@ -318,7 +324,8 @@ class Notifier(object):
# that we don't miss any current_token updates.
prev_token = current_token
listener = user_stream.new_listener(prev_token)
- yield listener.deferred
+ with PreserveLoggingContext():
+ yield listener.deferred
except defer.CancelledError:
break
@@ -356,7 +363,7 @@ class Notifier(object):
@defer.inlineCallbacks
def check_for_updates(before_token, after_token):
if not after_token.is_after(before_token):
- defer.returnValue(None)
+ defer.returnValue(EventStreamResult([], (from_token, from_token)))
events = []
end_token = from_token
@@ -369,6 +376,7 @@ class Notifier(object):
continue
if only_keys and name not in only_keys:
continue
+
new_events, new_key = yield source.get_new_events(
user=user,
from_key=getattr(from_token, keyname),
@@ -388,10 +396,7 @@ class Notifier(object):
events.extend(new_events)
end_token = end_token.copy_and_replace(keyname, new_key)
- if events:
- defer.returnValue((events, (from_token, end_token)))
- else:
- defer.returnValue(None)
+ defer.returnValue(EventStreamResult(events, (from_token, end_token)))
user_id_for_stream = user.to_string()
if is_peeking:
@@ -415,9 +420,6 @@ class Notifier(object):
from_token=from_token,
)
- if result is None:
- result = ([], (from_token, from_token))
-
defer.returnValue(result)
@defer.inlineCallbacks
diff --git a/synapse/push/__init__.py b/synapse/push/__init__.py
index 9bc0b356f4..8da2d8716c 100644
--- a/synapse/push/__init__.py
+++ b/synapse/push/__init__.py
@@ -17,6 +17,8 @@ from twisted.internet import defer
from synapse.streams.config import PaginationConfig
from synapse.types import StreamToken
+from synapse.util.logcontext import LoggingContext
+from synapse.util.metrics import Measure
import synapse.util.async
import push_rule_evaluator as push_rule_evaluator
@@ -27,6 +29,16 @@ import random
logger = logging.getLogger(__name__)
+_NEXT_ID = 1
+
+
+def _get_next_id():
+ global _NEXT_ID
+ _id = _NEXT_ID
+ _NEXT_ID += 1
+ return _id
+
+
# Pushers could now be moved to pull out of the event_push_actions table instead
# of listening on the event stream: this would avoid them having to run the
# rules again.
@@ -57,6 +69,8 @@ class Pusher(object):
self.alive = True
self.badge = None
+ self.name = "Pusher-%d" % (_get_next_id(),)
+
# The last value of last_active_time that we saw
self.last_last_active_time = 0
self.has_unread = True
@@ -86,38 +100,46 @@ class Pusher(object):
@defer.inlineCallbacks
def start(self):
- if not self.last_token:
- # First-time setup: get a token to start from (we can't
- # just start from no token, ie. 'now'
- # because we need the result to be reproduceable in case
- # we fail to dispatch the push)
- config = PaginationConfig(from_token=None, limit='1')
- chunk = yield self.evStreamHandler.get_stream(
- self.user_id, config, timeout=0, affect_presence=False
- )
- self.last_token = chunk['end']
- self.store.update_pusher_last_token(
- self.app_id, self.pushkey, self.user_id, self.last_token
- )
- logger.info("Pusher %s for user %s starting from token %s",
- self.pushkey, self.user_id, self.last_token)
-
- wait = 0
- while self.alive:
- try:
- if wait > 0:
- yield synapse.util.async.sleep(wait)
- yield self.get_and_dispatch()
- wait = 0
- except:
- if wait == 0:
- wait = 1
- else:
- wait = min(wait * 2, 1800)
- logger.exception(
- "Exception in pusher loop for pushkey %s. Pausing for %ds",
- self.pushkey, wait
+ with LoggingContext(self.name):
+ if not self.last_token:
+ # First-time setup: get a token to start from (we can't
+ # just start from no token, ie. 'now'
+ # because we need the result to be reproduceable in case
+ # we fail to dispatch the push)
+ config = PaginationConfig(from_token=None, limit='1')
+ chunk = yield self.evStreamHandler.get_stream(
+ self.user_id, config, timeout=0, affect_presence=False
)
+ self.last_token = chunk['end']
+ yield self.store.update_pusher_last_token(
+ self.app_id, self.pushkey, self.user_id, self.last_token
+ )
+ logger.info("New pusher %s for user %s starting from token %s",
+ self.pushkey, self.user_id, self.last_token)
+
+ else:
+ logger.info(
+ "Old pusher %s for user %s starting",
+ self.pushkey, self.user_id,
+ )
+
+ wait = 0
+ while self.alive:
+ try:
+ if wait > 0:
+ yield synapse.util.async.sleep(wait)
+ with Measure(self.clock, "push"):
+ yield self.get_and_dispatch()
+ wait = 0
+ except:
+ if wait == 0:
+ wait = 1
+ else:
+ wait = min(wait * 2, 1800)
+ logger.exception(
+ "Exception in pusher loop for pushkey %s. Pausing for %ds",
+ self.pushkey, wait
+ )
@defer.inlineCallbacks
def get_and_dispatch(self):
@@ -316,7 +338,7 @@ class Pusher(object):
r.room_id, self.user_id, last_unread_event_id
)
)
- badge += len(notifs)
+ badge += notifs["notify_count"]
defer.returnValue(badge)
diff --git a/synapse/push/action_generator.py b/synapse/push/action_generator.py
index 1d2e558f9a..e0da0868ec 100644
--- a/synapse/push/action_generator.py
+++ b/synapse/push/action_generator.py
@@ -19,8 +19,6 @@ import bulk_push_rule_evaluator
import logging
-from synapse.api.constants import EventTypes
-
logger = logging.getLogger(__name__)
@@ -36,21 +34,15 @@ class ActionGenerator:
# tag (ie. we just need all the users).
@defer.inlineCallbacks
- def handle_push_actions_for_event(self, event, handler):
- if event.type == EventTypes.Redaction and event.redacts is not None:
- yield self.store.remove_push_actions_for_event_id(
- event.room_id, event.redacts
- )
-
+ def handle_push_actions_for_event(self, event, context, handler):
bulk_evaluator = yield bulk_push_rule_evaluator.evaluator_for_room_id(
event.room_id, self.hs, self.store
)
- actions_by_user = yield bulk_evaluator.action_for_event_by_user(event, handler)
-
- yield self.store.set_push_actions_for_event_and_users(
- event,
- [
- (uid, None, actions) for uid, actions in actions_by_user.items()
- ]
+ actions_by_user = yield bulk_evaluator.action_for_event_by_user(
+ event, handler, context.current_state
)
+
+ context.push_actions = [
+ (uid, None, actions) for uid, actions in actions_by_user.items()
+ ]
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index 20c60422bf..8ac5ceb9ef 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -98,25 +98,21 @@ class BulkPushRuleEvaluator:
self.store = store
@defer.inlineCallbacks
- def action_for_event_by_user(self, event, handler):
+ def action_for_event_by_user(self, event, handler, current_state):
actions_by_user = {}
users_dict = yield self.store.are_guests(self.rules_by_user.keys())
filtered_by_user = yield handler._filter_events_for_clients(
- users_dict.items(), [event]
+ users_dict.items(), [event], {event.event_id: current_state}
)
evaluator = PushRuleEvaluatorForEvent(event, len(self.users_in_room))
condition_cache = {}
- member_state = yield self.store.get_state_for_event(
- event.event_id,
- )
-
display_names = {}
- for ev in member_state.values():
+ for ev in current_state.values():
nm = ev.content.get("displayname", None)
if nm and ev.type == EventTypes.Member:
display_names[ev.state_key] = nm
diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py
index dca018af95..2a2b4437dc 100644
--- a/synapse/push/push_rule_evaluator.py
+++ b/synapse/push/push_rule_evaluator.py
@@ -304,7 +304,7 @@ def _flatten_dict(d, prefix=[], result={}):
if isinstance(value, basestring):
result[".".join(prefix + [key])] = value.lower()
elif hasattr(value, "items"):
- _flatten_dict(value, prefix=(prefix+[key]), result=result)
+ _flatten_dict(value, prefix=(prefix + [key]), result=result)
return result
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index d1b7c0802f..d7dcb2de4b 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -18,6 +18,7 @@ from twisted.internet import defer
from httppusher import HttpPusher
from synapse.push import PusherConfigException
+from synapse.util.logcontext import preserve_fn
import logging
@@ -76,7 +77,7 @@ class PusherPool:
"Removing pusher for app id %s, pushkey %s, user %s",
app_id, pushkey, p['user_name']
)
- self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])
+ yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])
@defer.inlineCallbacks
def remove_pushers_by_user(self, user_id):
@@ -91,7 +92,7 @@ class PusherPool:
"Removing pusher for app id %s, pushkey %s, user %s",
p['app_id'], p['pushkey'], p['user_name']
)
- self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])
+ yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])
@defer.inlineCallbacks
def _add_pusher_to_store(self, user_id, access_token, profile_tag, kind,
@@ -110,7 +111,7 @@ class PusherPool:
lang=lang,
data=data,
)
- self._refresh_pusher(app_id, pushkey, user_id)
+ yield self._refresh_pusher(app_id, pushkey, user_id)
def _create_pusher(self, pusherdict):
if pusherdict['kind'] == 'http':
@@ -166,7 +167,7 @@ class PusherPool:
if fullid in self.pushers:
self.pushers[fullid].stop()
self.pushers[fullid] = p
- p.start()
+ preserve_fn(p.start)()
logger.info("Started pushers")
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 07836709fb..7199113dac 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -89,7 +89,7 @@ class LoginRestServlet(ClientV1RestServlet):
LoginRestServlet.SAML2_TYPE):
relay_state = ""
if "relay_state" in login_submission:
- relay_state = "&RelayState="+urllib.quote(
+ relay_state = "&RelayState=" + urllib.quote(
login_submission["relay_state"])
result = {
"uri": "%s%s" % (self.idp_redirect_url, relay_state)
diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py
index b15defdd07..3c5a212920 100644
--- a/synapse/rest/client/v1/profile.py
+++ b/synapse/rest/client/v1/profile.py
@@ -33,7 +33,11 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet):
user,
)
- defer.returnValue((200, {"displayname": displayname}))
+ ret = {}
+ if displayname is not None:
+ ret["displayname"] = displayname
+
+ defer.returnValue((200, ret))
@defer.inlineCallbacks
def on_PUT(self, request, user_id):
@@ -66,7 +70,11 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet):
user,
)
- defer.returnValue((200, {"avatar_url": avatar_url}))
+ ret = {}
+ if avatar_url is not None:
+ ret["avatar_url"] = avatar_url
+
+ defer.returnValue((200, ret))
@defer.inlineCallbacks
def on_PUT(self, request, user_id):
@@ -102,10 +110,13 @@ class ProfileRestServlet(ClientV1RestServlet):
user,
)
- defer.returnValue((200, {
- "displayname": displayname,
- "avatar_url": avatar_url
- }))
+ ret = {}
+ if displayname is not None:
+ ret["displayname"] = displayname
+ if avatar_url is not None:
+ ret["avatar_url"] = avatar_url
+
+ defer.returnValue((200, ret))
def register_servlets(hs, http_server):
diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py
index e218ed215c..5547f1b112 100644
--- a/synapse/rest/client/v1/pusher.py
+++ b/synapse/rest/client/v1/pusher.py
@@ -52,7 +52,7 @@ class PusherRestServlet(ClientV1RestServlet):
if i not in content:
missing.append(i)
if len(missing):
- raise SynapseError(400, "Missing parameters: "+','.join(missing),
+ raise SynapseError(400, "Missing parameters: " + ','.join(missing),
errcode=Codes.MISSING_PARAM)
logger.debug("set pushkey %s to kind %s", content['pushkey'], content['kind'])
@@ -83,7 +83,7 @@ class PusherRestServlet(ClientV1RestServlet):
data=content['data']
)
except PusherConfigException as pce:
- raise SynapseError(400, "Config Error: "+pce.message,
+ raise SynapseError(400, "Config Error: " + pce.message,
errcode=Codes.MISSING_PARAM)
defer.returnValue((200, {}))
diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1/register.py
index 5378a9a938..6d6d03c34c 100644
--- a/synapse/rest/client/v1/register.py
+++ b/synapse/rest/client/v1/register.py
@@ -38,7 +38,8 @@ logger = logging.getLogger(__name__)
if hasattr(hmac, "compare_digest"):
compare_digest = hmac.compare_digest
else:
- compare_digest = lambda a, b: a == b
+ def compare_digest(a, b):
+ return a == b
class RegisterRestServlet(ClientV1RestServlet):
@@ -58,7 +59,7 @@ class RegisterRestServlet(ClientV1RestServlet):
# }
# TODO: persistent storage
self.sessions = {}
- self.disable_registration = hs.config.disable_registration
+ self.enable_registration = hs.config.enable_registration
def on_GET(self, request):
if self.hs.config.enable_registration_captcha:
@@ -112,7 +113,7 @@ class RegisterRestServlet(ClientV1RestServlet):
is_using_shared_secret = login_type == LoginType.SHARED_SECRET
can_register = (
- not self.disable_registration
+ self.enable_registration
or is_application_server
or is_using_shared_secret
)
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index c7ea15c624..81bfe377bd 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -429,8 +429,6 @@ class RoomEventContext(ClientV1RestServlet):
serialize_event(event, time_now) for event in results["state"]
]
- logger.info("Responding with %r", results)
-
defer.returnValue((200, results))
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index d507172704..a614b79d45 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -116,9 +116,10 @@ class ThreepidRestServlet(RestServlet):
body = parse_json_dict_from_request(request)
- if 'threePidCreds' not in body:
+ threePidCreds = body.get('threePidCreds')
+ threePidCreds = body.get('three_pid_creds', threePidCreds)
+ if threePidCreds is None:
raise SynapseError(400, "Missing param", Codes.MISSING_PARAM)
- threePidCreds = body['threePidCreds']
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
diff --git a/synapse/rest/client/v2_alpha/account_data.py b/synapse/rest/client/v2_alpha/account_data.py
index 985efe2a62..1456881c1a 100644
--- a/synapse/rest/client/v2_alpha/account_data.py
+++ b/synapse/rest/client/v2_alpha/account_data.py
@@ -57,7 +57,7 @@ class AccountDataServlet(RestServlet):
user_id, account_data_type, body
)
- yield self.notifier.on_new_event(
+ self.notifier.on_new_event(
"account_data_key", max_id, users=[user_id]
)
@@ -99,7 +99,7 @@ class RoomAccountDataServlet(RestServlet):
user_id, room_id, account_data_type, body
)
- yield self.notifier.on_new_event(
+ self.notifier.on_new_event(
"account_data_key", max_id, users=[user_id]
)
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index c4d025b465..ec5c21fa1f 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -34,7 +34,8 @@ from synapse.util.async import run_on_reactor
if hasattr(hmac, "compare_digest"):
compare_digest = hmac.compare_digest
else:
- compare_digest = lambda a, b: a == b
+ def compare_digest(a, b):
+ return a == b
logger = logging.getLogger(__name__)
@@ -116,7 +117,7 @@ class RegisterRestServlet(RestServlet):
return
# == Normal User Registration == (everyone else)
- if self.hs.config.disable_registration:
+ if not self.hs.config.enable_registration:
raise SynapseError(403, "Registration has been disabled")
guest_access_token = body.get("guest_access_token", None)
@@ -152,6 +153,7 @@ class RegisterRestServlet(RestServlet):
desired_username = params.get("username", None)
new_password = params.get("password", None)
+ guest_access_token = params.get("guest_access_token", None)
(user_id, token) = yield self.registration_handler.register(
localpart=desired_username,
diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py
index 07b5b5dfd5..140ce2704b 100644
--- a/synapse/rest/client/v2_alpha/sync.py
+++ b/synapse/rest/client/v2_alpha/sync.py
@@ -20,7 +20,6 @@ from synapse.http.servlet import (
)
from synapse.handlers.sync import SyncConfig
from synapse.types import StreamToken
-from synapse.events import FrozenEvent
from synapse.events.utils import (
serialize_event, format_event_for_client_v2_without_room_id,
)
@@ -287,9 +286,6 @@ class SyncRestServlet(RestServlet):
state_dict = room.state
timeline_events = room.timeline.events
- state_dict = SyncRestServlet._rollback_state_for_timeline(
- state_dict, timeline_events)
-
state_events = state_dict.values()
serialized_state = [serialize(e) for e in state_events]
@@ -314,77 +310,6 @@ class SyncRestServlet(RestServlet):
return result
- @staticmethod
- def _rollback_state_for_timeline(state, timeline):
- """
- Wind the state dictionary backwards, so that it represents the
- state at the start of the timeline, rather than at the end.
-
- :param dict[(str, str), synapse.events.EventBase] state: the
- state dictionary. Will be updated to the state before the timeline.
- :param list[synapse.events.EventBase] timeline: the event timeline
- :return: updated state dictionary
- """
-
- result = state.copy()
-
- for timeline_event in reversed(timeline):
- if not timeline_event.is_state():
- continue
-
- event_key = (timeline_event.type, timeline_event.state_key)
-
- logger.debug("Considering %s for removal", event_key)
-
- state_event = result.get(event_key)
- if (state_event is None or
- state_event.event_id != timeline_event.event_id):
- # the event in the timeline isn't present in the state
- # dictionary.
- #
- # the most likely cause for this is that there was a fork in
- # the event graph, and the state is no longer valid. Really,
- # the event shouldn't be in the timeline. We're going to ignore
- # it for now, however.
- logger.debug("Found state event %r in timeline which doesn't "
- "match state dictionary", timeline_event)
- continue
-
- prev_event_id = timeline_event.unsigned.get("replaces_state", None)
-
- prev_content = timeline_event.unsigned.get('prev_content')
- prev_sender = timeline_event.unsigned.get('prev_sender')
- # Empircally it seems possible for the event to have a
- # "replaces_state" key but not a prev_content or prev_sender
- # markjh conjectures that it could be due to the server not
- # having a copy of that event.
- # If this is the case the we ignore the previous event. This will
- # cause the displayname calculations on the client to be incorrect
- if prev_event_id is None or not prev_content or not prev_sender:
- logger.debug(
- "Removing %r from the state dict, as it is missing"
- " prev_content (prev_event_id=%r)",
- timeline_event.event_id, prev_event_id
- )
- del result[event_key]
- else:
- logger.debug(
- "Replacing %r with %r in state dict",
- timeline_event.event_id, prev_event_id
- )
- result[event_key] = FrozenEvent({
- "type": timeline_event.type,
- "state_key": timeline_event.state_key,
- "content": prev_content,
- "sender": prev_sender,
- "event_id": prev_event_id,
- "room_id": timeline_event.room_id,
- })
-
- logger.debug("New value: %r", result.get(event_key))
-
- return result
-
def register_servlets(hs, http_server):
SyncRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/tags.py b/synapse/rest/client/v2_alpha/tags.py
index 42f2203f3d..79c436a8cf 100644
--- a/synapse/rest/client/v2_alpha/tags.py
+++ b/synapse/rest/client/v2_alpha/tags.py
@@ -80,7 +80,7 @@ class TagServlet(RestServlet):
max_id = yield self.store.add_tag_to_room(user_id, room_id, tag, body)
- yield self.notifier.on_new_event(
+ self.notifier.on_new_event(
"account_data_key", max_id, users=[user_id]
)
@@ -94,7 +94,7 @@ class TagServlet(RestServlet):
max_id = yield self.store.remove_tag_from_room(user_id, room_id, tag)
- yield self.notifier.on_new_event(
+ self.notifier.on_new_event(
"account_data_key", max_id, users=[user_id]
)
diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py
index 349ef6b396..ca5468c402 100644
--- a/synapse/rest/client/versions.py
+++ b/synapse/rest/client/versions.py
@@ -26,9 +26,7 @@ class VersionsRestServlet(RestServlet):
def on_GET(self, request):
return (200, {
- "versions": [
- "r0.0.1",
- ]
+ "versions": ["r0.0.1"]
})
diff --git a/synapse/rest/media/v1/base_resource.py b/synapse/rest/media/v1/base_resource.py
index bdc65f0198..58d56ec7a4 100644
--- a/synapse/rest/media/v1/base_resource.py
+++ b/synapse/rest/media/v1/base_resource.py
@@ -28,6 +28,7 @@ from twisted.protocols.basic import FileSender
from synapse.util.async import ObservableDeferred
from synapse.util.stringutils import is_ascii
+from synapse.util.logcontext import preserve_context_over_fn
import os
@@ -276,7 +277,8 @@ class BaseMediaResource(Resource):
)
self._makedirs(t_path)
- t_len = yield threads.deferToThread(
+ t_len = yield preserve_context_over_fn(
+ threads.deferToThread,
self._generate_thumbnail,
input_path, t_path, t_width, t_height, t_method, t_type
)
@@ -298,7 +300,8 @@ class BaseMediaResource(Resource):
)
self._makedirs(t_path)
- t_len = yield threads.deferToThread(
+ t_len = yield preserve_context_over_fn(
+ threads.deferToThread,
self._generate_thumbnail,
input_path, t_path, t_width, t_height, t_method, t_type
)
@@ -372,7 +375,7 @@ class BaseMediaResource(Resource):
media_id, t_width, t_height, t_type, t_method, t_len
))
- yield threads.deferToThread(generate_thumbnails)
+ yield preserve_context_over_fn(threads.deferToThread, generate_thumbnails)
for l in local_thumbnails:
yield self.store.store_local_thumbnail(*l)
@@ -445,7 +448,7 @@ class BaseMediaResource(Resource):
t_width, t_height, t_type, t_method, t_len
])
- yield threads.deferToThread(generate_thumbnails)
+ yield preserve_context_over_fn(threads.deferToThread, generate_thumbnails)
for r in remote_thumbnails:
yield self.store.store_remote_media_thumbnail(*r)
diff --git a/synapse/server.py b/synapse/server.py
index 5fee7fe130..368d615576 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -23,7 +23,7 @@ from twisted.web.client import BrowserLikePolicyForHTTPS
from twisted.enterprise import adbapi
from synapse.federation import initialize_http_replication
-from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory
+from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory
from synapse.notifier import Notifier
from synapse.api.auth import Auth
from synapse.handlers import Handlers
diff --git a/synapse/state.py b/synapse/state.py
index 0acf309fe0..b9a1387520 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -63,7 +63,7 @@ class StateHandler(object):
cache_name="state_cache",
clock=self.clock,
max_len=SIZE_OF_CACHE,
- expiry_ms=EVICTION_TIMEOUT_SECONDS*1000,
+ expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000,
reset_expiry_on_get=True,
)
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index eb88842308..5a9e7720d9 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -45,9 +45,10 @@ from .search import SearchStore
from .tags import TagsStore
from .account_data import AccountDataStore
-
from util.id_generators import IdGenerator, StreamIdGenerator
+from synapse.util.caches.stream_change_cache import StreamChangeCache
+
import logging
@@ -58,7 +59,7 @@ logger = logging.getLogger(__name__)
# Number of msec of granularity to store the user IP 'last seen' time. Smaller
# times give more inserts into the database even for readonly API hits
# 120 seconds == 2 minutes
-LAST_SEEN_GRANULARITY = 120*1000
+LAST_SEEN_GRANULARITY = 120 * 1000
class DataStore(RoomMemberStore, RoomStore,
@@ -84,6 +85,7 @@ class DataStore(RoomMemberStore, RoomStore,
def __init__(self, db_conn, hs):
self.hs = hs
+ self.database_engine = hs.database_engine
cur = db_conn.cursor()
try:
@@ -117,8 +119,61 @@ class DataStore(RoomMemberStore, RoomStore,
self._push_rule_id_gen = IdGenerator("push_rules", "id", self)
self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self)
+ events_max = self._stream_id_gen.get_max_token(None)
+ event_cache_prefill, min_event_val = self._get_cache_dict(
+ db_conn, "events",
+ entity_column="room_id",
+ stream_column="stream_ordering",
+ max_value=events_max,
+ )
+ self._events_stream_cache = StreamChangeCache(
+ "EventsRoomStreamChangeCache", min_event_val,
+ prefilled_cache=event_cache_prefill,
+ )
+
+ self._membership_stream_cache = StreamChangeCache(
+ "MembershipStreamChangeCache", events_max,
+ )
+
+ account_max = self._account_data_id_gen.get_max_token(None)
+ self._account_data_stream_cache = StreamChangeCache(
+ "AccountDataAndTagsChangeCache", account_max,
+ )
+
super(DataStore, self).__init__(hs)
+ def _get_cache_dict(self, db_conn, table, entity_column, stream_column, max_value):
+ # Fetch a mapping of room_id -> max stream position for "recent" rooms.
+ # It doesn't really matter how many we get, the StreamChangeCache will
+ # do the right thing to ensure it respects the max size of cache.
+ sql = (
+ "SELECT %(entity)s, MAX(%(stream)s) FROM %(table)s"
+ " WHERE %(stream)s > ? - 100000"
+ " GROUP BY %(entity)s"
+ ) % {
+ "table": table,
+ "entity": entity_column,
+ "stream": stream_column,
+ }
+
+ sql = self.database_engine.convert_param_style(sql)
+
+ txn = db_conn.cursor()
+ txn.execute(sql, (int(max_value),))
+ rows = txn.fetchall()
+
+ cache = {
+ row[0]: int(row[1])
+ for row in rows
+ }
+
+ if cache:
+ min_val = min(cache.values())
+ else:
+ min_val = max_value
+
+ return cache, min_val
+
@defer.inlineCallbacks
def insert_client_ip(self, user, access_token, ip, user_agent):
now = int(self._clock.time_msec())
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 5e77320540..2e97ac84a8 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -15,7 +15,7 @@
import logging
from synapse.api.errors import StoreError
-from synapse.util.logcontext import preserve_context_over_fn, LoggingContext
+from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
from synapse.util.caches.dictionary_cache import DictionaryCache
from synapse.util.caches.descriptors import Cache
import synapse.metrics
@@ -185,7 +185,7 @@ class SQLBaseStore(object):
time_then = self._previous_loop_ts
self._previous_loop_ts = time_now
- ratio = (curr - prev)/(time_now - time_then)
+ ratio = (curr - prev) / (time_now - time_then)
top_three_counters = self._txn_perf_counters.interval(
time_now - time_then, limit=3
@@ -298,10 +298,10 @@ class SQLBaseStore(object):
func, *args, **kwargs
)
- result = yield preserve_context_over_fn(
- self._db_pool.runWithConnection,
- inner_func, *args, **kwargs
- )
+ with PreserveLoggingContext():
+ result = yield self._db_pool.runWithConnection(
+ inner_func, *args, **kwargs
+ )
for after_callback, after_args in after_callbacks:
after_callback(*after_args)
@@ -326,10 +326,10 @@ class SQLBaseStore(object):
return func(conn, *args, **kwargs)
- result = yield preserve_context_over_fn(
- self._db_pool.runWithConnection,
- inner_func, *args, **kwargs
- )
+ with PreserveLoggingContext():
+ result = yield self._db_pool.runWithConnection(
+ inner_func, *args, **kwargs
+ )
defer.returnValue(result)
@@ -643,7 +643,10 @@ class SQLBaseStore(object):
if not iterable:
defer.returnValue(results)
- chunks = [iterable[i:i+batch_size] for i in xrange(0, len(iterable), batch_size)]
+ chunks = [
+ iterable[i:i + batch_size]
+ for i in xrange(0, len(iterable), batch_size)
+ ]
for chunk in chunks:
rows = yield self.runInteraction(
desc,
diff --git a/synapse/storage/account_data.py b/synapse/storage/account_data.py
index ed6587429b..b8387fc500 100644
--- a/synapse/storage/account_data.py
+++ b/synapse/storage/account_data.py
@@ -14,7 +14,6 @@
# limitations under the License.
from ._base import SQLBaseStore
-from synapse.util.caches.stream_change_cache import StreamChangeCache
from twisted.internet import defer
import ujson as json
@@ -24,14 +23,6 @@ logger = logging.getLogger(__name__)
class AccountDataStore(SQLBaseStore):
- def __init__(self, hs):
- super(AccountDataStore, self).__init__(hs)
-
- self._account_data_stream_cache = StreamChangeCache(
- "AccountDataAndTagsChangeCache",
- self._account_data_id_gen.get_max_token(None),
- max_size=10000,
- )
def get_account_data_for_user(self, user_id):
"""Get all the client account_data for a user.
@@ -166,6 +157,10 @@ class AccountDataStore(SQLBaseStore):
"content": content_json,
}
)
+ txn.call_after(
+ self._account_data_stream_cache.entity_has_changed,
+ user_id, next_id,
+ )
self._update_max_stream_id(txn, next_id)
with (yield self._account_data_id_gen.get_next(self)) as next_id:
diff --git a/synapse/storage/appservice.py b/synapse/storage/appservice.py
index b5aa55c0a3..1100c67714 100644
--- a/synapse/storage/appservice.py
+++ b/synapse/storage/appservice.py
@@ -276,7 +276,8 @@ class ApplicationServiceTransactionStore(SQLBaseStore):
"application_services_state",
dict(as_id=service.id),
["state"],
- allow_none=True
+ allow_none=True,
+ desc="get_appservice_state",
)
if result:
defer.returnValue(result.get("state"))
diff --git a/synapse/storage/engines/sqlite3.py b/synapse/storage/engines/sqlite3.py
index 400c10103c..91fac33b8b 100644
--- a/synapse/storage/engines/sqlite3.py
+++ b/synapse/storage/engines/sqlite3.py
@@ -54,7 +54,7 @@ class Sqlite3Engine(object):
def _parse_match_info(buf):
bufsize = len(buf)
- return [struct.unpack('@I', buf[i:i+4])[0] for i in range(0, bufsize, 4)]
+ return [struct.unpack('@I', buf[i:i + 4])[0] for i in range(0, bufsize, 4)]
def _rank(raw_match_info):
diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py
index 5f32eec6f8..ce2c794025 100644
--- a/synapse/storage/event_federation.py
+++ b/synapse/storage/event_federation.py
@@ -58,7 +58,7 @@ class EventFederationStore(SQLBaseStore):
new_front = set()
front_list = list(front)
chunks = [
- front_list[x:x+100]
+ front_list[x:x + 100]
for x in xrange(0, len(front), 100)
]
for chunk in chunks:
diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/event_push_actions.py
index a05c4f84cf..d77a817682 100644
--- a/synapse/storage/event_push_actions.py
+++ b/synapse/storage/event_push_actions.py
@@ -24,8 +24,7 @@ logger = logging.getLogger(__name__)
class EventPushActionsStore(SQLBaseStore):
- @defer.inlineCallbacks
- def set_push_actions_for_event_and_users(self, event, tuples):
+ def _set_push_actions_for_event_and_users_txn(self, txn, event, tuples):
"""
:param event: the event set actions for
:param tuples: list of tuples of (user_id, profile_tag, actions)
@@ -37,21 +36,19 @@ class EventPushActionsStore(SQLBaseStore):
'event_id': event.event_id,
'user_id': uid,
'profile_tag': profile_tag,
- 'actions': json.dumps(actions)
+ 'actions': json.dumps(actions),
+ 'stream_ordering': event.internal_metadata.stream_ordering,
+ 'topological_ordering': event.depth,
+ 'notif': 1,
+ 'highlight': 1 if _action_has_highlight(actions) else 0,
})
- def f(txn):
- for uid, _, __ in tuples:
- txn.call_after(
- self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
- (event.room_id, uid)
- )
- return self._simple_insert_many_txn(txn, "event_push_actions", values)
-
- yield self.runInteraction(
- "set_actions_for_event_and_users",
- f,
- )
+ for uid, _, __ in tuples:
+ txn.call_after(
+ self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
+ (event.room_id, uid)
+ )
+ self._simple_insert_many_txn(txn, "event_push_actions", values)
@cachedInlineCallbacks(num_args=3, lru=True, tree=True)
def get_unread_event_push_actions_by_room_for_user(
@@ -68,32 +65,34 @@ class EventPushActionsStore(SQLBaseStore):
)
results = txn.fetchall()
if len(results) == 0:
- return []
+ return {"notify_count": 0, "highlight_count": 0}
stream_ordering = results[0][0]
topological_ordering = results[0][1]
sql = (
- "SELECT ea.event_id, ea.actions"
- " FROM event_push_actions ea, events e"
- " WHERE ea.room_id = e.room_id"
- " AND ea.event_id = e.event_id"
- " AND ea.user_id = ?"
- " AND ea.room_id = ?"
+ "SELECT sum(notif), sum(highlight)"
+ " FROM event_push_actions ea"
+ " WHERE"
+ " user_id = ?"
+ " AND room_id = ?"
" AND ("
- " e.topological_ordering > ?"
- " OR (e.topological_ordering = ? AND e.stream_ordering > ?)"
+ " topological_ordering > ?"
+ " OR (topological_ordering = ? AND stream_ordering > ?)"
")"
)
txn.execute(sql, (
user_id, room_id,
topological_ordering, topological_ordering, stream_ordering
- )
- )
- return [
- {"event_id": row[0], "actions": json.loads(row[1])}
- for row in txn.fetchall()
- ]
+ ))
+ row = txn.fetchone()
+ if row:
+ return {
+ "notify_count": row[0] or 0,
+ "highlight_count": row[1] or 0,
+ }
+ else:
+ return {"notify_count": 0, "highlight_count": 0}
ret = yield self.runInteraction(
"get_unread_event_push_actions_by_room",
@@ -101,19 +100,24 @@ class EventPushActionsStore(SQLBaseStore):
)
defer.returnValue(ret)
- @defer.inlineCallbacks
- def remove_push_actions_for_event_id(self, room_id, event_id):
- def f(txn):
- # Sad that we have to blow away the cache for the whole room here
- txn.call_after(
- self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
- (room_id,)
- )
- txn.execute(
- "DELETE FROM event_push_actions WHERE room_id = ? AND event_id = ?",
- (room_id, event_id)
- )
- yield self.runInteraction(
- "remove_push_actions_for_event_id",
- f
+ def _remove_push_actions_for_event_id_txn(self, txn, room_id, event_id):
+ # Sad that we have to blow away the cache for the whole room here
+ txn.call_after(
+ self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
+ (room_id,)
)
+ txn.execute(
+ "DELETE FROM event_push_actions WHERE room_id = ? AND event_id = ?",
+ (room_id, event_id)
+ )
+
+
+def _action_has_highlight(actions):
+ for action in actions:
+ try:
+ if action.get("set_tweak", None) == "highlight":
+ return action.get("value", True)
+ except AttributeError:
+ pass
+
+ return False
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 5e85552029..3a5c6ee4b1 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -19,7 +19,7 @@ from twisted.internet import defer, reactor
from synapse.events import FrozenEvent, USE_FROZEN_DICTS
from synapse.events.utils import prune_event
-from synapse.util.logcontext import preserve_context_over_deferred
+from synapse.util.logcontext import preserve_fn, PreserveLoggingContext
from synapse.util.logutils import log_function
from synapse.api.constants import EventTypes
@@ -84,7 +84,7 @@ class EventsStore(SQLBaseStore):
event.internal_metadata.stream_ordering = stream
chunks = [
- events_and_contexts[x:x+100]
+ events_and_contexts[x:x + 100]
for x in xrange(0, len(events_and_contexts), 100)
]
@@ -205,23 +205,29 @@ class EventsStore(SQLBaseStore):
@log_function
def _persist_events_txn(self, txn, events_and_contexts, backfilled,
is_new_state=True):
-
- # Remove the any existing cache entries for the event_ids
- for event, _ in events_and_contexts:
+ depth_updates = {}
+ for event, context in events_and_contexts:
+ # Remove the any existing cache entries for the event_ids
txn.call_after(self._invalidate_get_event_cache, event.event_id)
-
if not backfilled:
txn.call_after(
self._events_stream_cache.entity_has_changed,
event.room_id, event.internal_metadata.stream_ordering,
)
- depth_updates = {}
- for event, _ in events_and_contexts:
- if event.internal_metadata.is_outlier():
- continue
- depth_updates[event.room_id] = max(
- event.depth, depth_updates.get(event.room_id, event.depth)
+ if not event.internal_metadata.is_outlier():
+ depth_updates[event.room_id] = max(
+ event.depth, depth_updates.get(event.room_id, event.depth)
+ )
+
+ if context.push_actions:
+ self._set_push_actions_for_event_and_users_txn(
+ txn, event, context.push_actions
+ )
+
+ if event.type == EventTypes.Redaction and event.redacts is not None:
+ self._remove_push_actions_for_event_id_txn(
+ txn, event.room_id, event.redacts
)
for room_id, depth in depth_updates.items():
@@ -664,14 +670,16 @@ class EventsStore(SQLBaseStore):
for ids, d in lst:
if not d.called:
try:
- d.callback([
- res[i]
- for i in ids
- if i in res
- ])
+ with PreserveLoggingContext():
+ d.callback([
+ res[i]
+ for i in ids
+ if i in res
+ ])
except:
logger.exception("Failed to callback")
- reactor.callFromThread(fire, event_list, row_dict)
+ with PreserveLoggingContext():
+ reactor.callFromThread(fire, event_list, row_dict)
except Exception as e:
logger.exception("do_fetch")
@@ -679,10 +687,12 @@ class EventsStore(SQLBaseStore):
def fire(evs):
for _, d in evs:
if not d.called:
- d.errback(e)
+ with PreserveLoggingContext():
+ d.errback(e)
if event_list:
- reactor.callFromThread(fire, event_list)
+ with PreserveLoggingContext():
+ reactor.callFromThread(fire, event_list)
@defer.inlineCallbacks
def _enqueue_events(self, events, check_redacted=True,
@@ -709,18 +719,20 @@ class EventsStore(SQLBaseStore):
should_start = False
if should_start:
- self.runWithConnection(
- self._do_fetch
- )
+ with PreserveLoggingContext():
+ self.runWithConnection(
+ self._do_fetch
+ )
- rows = yield preserve_context_over_deferred(events_d)
+ with PreserveLoggingContext():
+ rows = yield events_d
if not allow_rejected:
rows[:] = [r for r in rows if not r["rejects"]]
res = yield defer.gatherResults(
[
- self._get_event_from_row(
+ preserve_fn(self._get_event_from_row)(
row["internal_metadata"], row["json"], row["redacts"],
check_redacted=check_redacted,
get_prev_content=get_prev_content,
@@ -740,7 +752,7 @@ class EventsStore(SQLBaseStore):
rows = []
N = 200
for i in range(1 + len(events) / N):
- evs = events[i*N:(i + 1)*N]
+ evs = events[i * N:(i + 1) * N]
if not evs:
break
@@ -755,7 +767,7 @@ class EventsStore(SQLBaseStore):
" LEFT JOIN rejections as rej USING (event_id)"
" LEFT JOIN redactions as r ON e.event_id = r.redacts"
" WHERE e.event_id IN (%s)"
- ) % (",".join(["?"]*len(evs)),)
+ ) % (",".join(["?"] * len(evs)),)
txn.execute(sql, evs)
rows.extend(self.cursor_to_dict(txn))
diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py
index 8022b8cfc6..fd05bfe54e 100644
--- a/synapse/storage/keys.py
+++ b/synapse/storage/keys.py
@@ -39,6 +39,7 @@ class KeyStore(SQLBaseStore):
table="server_tls_certificates",
keyvalues={"server_name": server_name},
retcols=("tls_certificate",),
+ desc="get_server_certificate",
)
tls_certificate = OpenSSL.crypto.load_certificate(
OpenSSL.crypto.FILETYPE_ASN1, tls_certificate_bytes,
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index c1f5f99789..850736c85e 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts.
-SCHEMA_VERSION = 28
+SCHEMA_VERSION = 29
dir_path = os.path.abspath(os.path.dirname(__file__))
@@ -211,7 +211,7 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files,
logger.debug("applied_delta_files: %s", applied_delta_files)
for v in range(start_ver, SCHEMA_VERSION + 1):
- logger.debug("Upgrading schema to v%d", v)
+ logger.info("Upgrading schema to v%d", v)
delta_dir = os.path.join(dir_path, "schema", "delta", str(v))
diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py
index 9b3aecaf8c..ef525f34c5 100644
--- a/synapse/storage/presence.py
+++ b/synapse/storage/presence.py
@@ -68,8 +68,9 @@ class PresenceStore(SQLBaseStore):
for row in rows
})
+ @defer.inlineCallbacks
def set_presence_state(self, user_localpart, new_state):
- res = self._simple_update_one(
+ res = yield self._simple_update_one(
table="presence",
keyvalues={"user_id": user_localpart},
updatevalues={"state": new_state["state"],
@@ -79,7 +80,7 @@ class PresenceStore(SQLBaseStore):
)
self.get_presence_state.invalidate((user_localpart,))
- return res
+ defer.returnValue(res)
def allow_presence_visible(self, observed_localpart, observer_userid):
return self._simple_insert(
diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py
index 8068c73740..4202a6b3dc 100644
--- a/synapse/storage/receipts.py
+++ b/synapse/storage/receipts.py
@@ -46,6 +46,20 @@ class ReceiptsStore(SQLBaseStore):
desc="get_receipts_for_room",
)
+ @cached(num_args=3)
+ def get_last_receipt_event_id_for_user(self, user_id, room_id, receipt_type):
+ return self._simple_select_one_onecol(
+ table="receipts_linearized",
+ keyvalues={
+ "room_id": room_id,
+ "receipt_type": receipt_type,
+ "user_id": user_id
+ },
+ retcol="event_id",
+ desc="get_own_receipt_for_user",
+ allow_none=True,
+ )
+
@cachedInlineCallbacks(num_args=2)
def get_receipts_for_user(self, user_id, receipt_type):
def f(txn):
@@ -226,6 +240,11 @@ class ReceiptsStore(SQLBaseStore):
room_id, stream_id
)
+ txn.call_after(
+ self.get_last_receipt_event_id_for_user.invalidate,
+ (user_id, room_id, receipt_type)
+ )
+
# We don't want to clobber receipts for more recent events, so we
# have to compare orderings of existing receipts
sql = (
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 70cde0d04d..967c732bda 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import re
+
from twisted.internet import defer
from synapse.api.errors import StoreError, Codes
@@ -134,6 +136,7 @@ class RegistrationStore(SQLBaseStore):
},
retcols=["name", "password_hash", "is_guest"],
allow_none=True,
+ desc="get_user_by_id",
)
def get_users_by_id_case_insensitive(self, user_id):
@@ -350,3 +353,37 @@ class RegistrationStore(SQLBaseStore):
ret = yield self.runInteraction("count_users", _count_users)
defer.returnValue(ret)
+
+ @defer.inlineCallbacks
+ def find_next_generated_user_id_localpart(self):
+ """
+ Gets the localpart of the next generated user ID.
+
+ Generated user IDs are integers, and we aim for them to be as small as
+ we can. Unfortunately, it's possible some of them are already taken by
+ existing users, and there may be gaps in the already taken range. This
+ function returns the start of the first allocatable gap. This is to
+ avoid the case of ID 10000000 being pre-allocated, so us wasting the
+ first (and shortest) many generated user IDs.
+ """
+ def _find_next_generated_user_id(txn):
+ txn.execute("SELECT name FROM users")
+ rows = self.cursor_to_dict(txn)
+
+ regex = re.compile("^@(\d+):")
+
+ found = set()
+
+ for r in rows:
+ user_id = r["name"]
+ match = regex.search(user_id)
+ if match:
+ found.add(int(match.group(1)))
+ for i in xrange(len(found) + 1):
+ if i not in found:
+ return i
+
+ defer.returnValue((yield self.runInteraction(
+ "find_next_generated_user_id",
+ _find_next_generated_user_id
+ )))
diff --git a/synapse/storage/room.py b/synapse/storage/room.py
index dc09a3aaba..46ab38a313 100644
--- a/synapse/storage/room.py
+++ b/synapse/storage/room.py
@@ -87,90 +87,20 @@ class RoomStore(SQLBaseStore):
desc="get_public_room_ids",
)
- @defer.inlineCallbacks
- def get_rooms(self, is_public):
- """Retrieve a list of all public rooms.
-
- Args:
- is_public (bool): True if the rooms returned should be public.
- Returns:
- A list of room dicts containing at least a "room_id" key, a
- "topic" key if one is set, and a "name" key if one is set
+ def get_room_count(self):
+ """Retrieve a list of all rooms
"""
def f(txn):
- def subquery(table_name, column_name=None):
- column_name = column_name or table_name
- return (
- "SELECT %(table_name)s.event_id as event_id, "
- "%(table_name)s.room_id as room_id, %(column_name)s "
- "FROM %(table_name)s "
- "INNER JOIN current_state_events as c "
- "ON c.event_id = %(table_name)s.event_id " % {
- "column_name": column_name,
- "table_name": table_name,
- }
- )
-
- sql = (
- "SELECT"
- " r.room_id,"
- " max(n.name),"
- " max(t.topic),"
- " max(v.history_visibility),"
- " max(g.guest_access)"
- " FROM rooms AS r"
- " LEFT JOIN (%(topic)s) AS t ON t.room_id = r.room_id"
- " LEFT JOIN (%(name)s) AS n ON n.room_id = r.room_id"
- " LEFT JOIN (%(history_visibility)s) AS v ON v.room_id = r.room_id"
- " LEFT JOIN (%(guest_access)s) AS g ON g.room_id = r.room_id"
- " WHERE r.is_public = ?"
- " GROUP BY r.room_id" % {
- "topic": subquery("topics", "topic"),
- "name": subquery("room_names", "name"),
- "history_visibility": subquery("history_visibility"),
- "guest_access": subquery("guest_access"),
- }
- )
-
- txn.execute(sql, (is_public,))
-
- rows = txn.fetchall()
-
- for i, row in enumerate(rows):
- room_id = row[0]
- aliases = self._simple_select_onecol_txn(
- txn,
- table="room_aliases",
- keyvalues={
- "room_id": room_id
- },
- retcol="room_alias",
- )
+ sql = "SELECT count(*) FROM rooms"
+ txn.execute(sql)
+ row = txn.fetchone()
+ return row[0] or 0
- rows[i] = list(row) + [aliases]
-
- return rows
-
- rows = yield self.runInteraction(
+ return self.runInteraction(
"get_rooms", f
)
- ret = [
- {
- "room_id": r[0],
- "name": r[1],
- "topic": r[2],
- "world_readable": r[3] == "world_readable",
- "guest_can_join": r[4] == "can_join",
- "aliases": r[5],
- }
- for r in rows
- if r[5] # We only return rooms that have at least one alias.
- ]
-
- defer.returnValue(ret)
-
def _store_room_topic_txn(self, txn, event):
if hasattr(event, "content") and "topic" in event.content:
self._simple_insert_txn(
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 1d3e004c90..3065b0c1a5 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -58,6 +58,10 @@ class RoomMemberStore(SQLBaseStore):
txn.call_after(self.get_rooms_for_user.invalidate, (event.state_key,))
txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,))
txn.call_after(self.get_users_in_room.invalidate, (event.room_id,))
+ txn.call_after(
+ self._membership_stream_cache.entity_has_changed,
+ event.state_key, event.internal_metadata.stream_ordering
+ )
def get_room_member(self, user_id, room_id):
"""Retrieve the current state of a room member.
diff --git a/synapse/storage/schema/delta/28/public_roms_index.sql b/synapse/storage/schema/delta/28/public_roms_index.sql
new file mode 100644
index 0000000000..ba62a974a4
--- /dev/null
+++ b/synapse/storage/schema/delta/28/public_roms_index.sql
@@ -0,0 +1,16 @@
+/* Copyright 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+*/
+
+CREATE INDEX public_room_index on rooms(is_public);
diff --git a/synapse/storage/schema/delta/29/push_actions.sql b/synapse/storage/schema/delta/29/push_actions.sql
new file mode 100644
index 0000000000..7e7b09820a
--- /dev/null
+++ b/synapse/storage/schema/delta/29/push_actions.sql
@@ -0,0 +1,31 @@
+/* Copyright 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ALTER TABLE event_push_actions ADD COLUMN topological_ordering BIGINT;
+ALTER TABLE event_push_actions ADD COLUMN stream_ordering BIGINT;
+ALTER TABLE event_push_actions ADD COLUMN notif SMALLINT;
+ALTER TABLE event_push_actions ADD COLUMN highlight SMALLINT;
+
+UPDATE event_push_actions SET stream_ordering = (
+ SELECT stream_ordering FROM events WHERE event_id = event_push_actions.event_id
+), topological_ordering = (
+ SELECT topological_ordering FROM events WHERE event_id = event_push_actions.event_id
+);
+
+UPDATE event_push_actions SET notif = 1, highlight = 0;
+
+CREATE INDEX event_push_actions_rm_tokens on event_push_actions(
+ user_id, room_id, topological_ordering, stream_ordering
+);
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 6c32e8f7b3..372b540002 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -171,41 +171,43 @@ class StateStore(SQLBaseStore):
events = yield self._get_events(event_ids, get_prev_content=False)
defer.returnValue(events)
- def _get_state_groups_from_groups(self, groups_and_types):
+ def _get_state_groups_from_groups(self, groups, types):
"""Returns dictionary state_group -> state event ids
-
- Args:
- groups_and_types (list): list of 2-tuple (`group`, `types`)
"""
- def f(txn):
- results = {}
- for group, types in groups_and_types:
- if types is not None:
- where_clause = "AND (%s)" % (
- " OR ".join(["(type = ? AND state_key = ?)"] * len(types)),
- )
- else:
- where_clause = ""
-
- sql = (
- "SELECT event_id FROM state_groups_state WHERE"
- " state_group = ? %s"
- ) % (where_clause,)
+ def f(txn, groups):
+ if types is not None:
+ where_clause = "AND (%s)" % (
+ " OR ".join(["(type = ? AND state_key = ?)"] * len(types)),
+ )
+ else:
+ where_clause = ""
- args = [group]
- if types is not None:
- args.extend([i for typ in types for i in typ])
+ sql = (
+ "SELECT state_group, event_id FROM state_groups_state WHERE"
+ " state_group IN (%s) %s" % (
+ ",".join("?" for _ in groups),
+ where_clause,
+ )
+ )
- txn.execute(sql, args)
+ args = list(groups)
+ if types is not None:
+ args.extend([i for typ in types for i in typ])
- results[group] = [r[0] for r in txn.fetchall()]
+ txn.execute(sql, args)
+ rows = self.cursor_to_dict(txn)
+ results = {}
+ for row in rows:
+ results.setdefault(row["state_group"], []).append(row["event_id"])
return results
- return self.runInteraction(
- "_get_state_groups_from_groups",
- f,
- )
+ chunks = [groups[i:i + 100] for i in xrange(0, len(groups), 100)]
+ for chunk in chunks:
+ return self.runInteraction(
+ "_get_state_groups_from_groups",
+ f, chunk
+ )
@defer.inlineCallbacks
def get_state_for_events(self, event_ids, types):
@@ -264,26 +266,20 @@ class StateStore(SQLBaseStore):
)
@cachedList(cache=_get_state_group_for_event.cache, list_name="event_ids",
- num_args=1)
+ num_args=1, inlineCallbacks=True)
def _get_state_group_for_events(self, event_ids):
"""Returns mapping event_id -> state_group
"""
- def f(txn):
- results = {}
- for event_id in event_ids:
- results[event_id] = self._simple_select_one_onecol_txn(
- txn,
- table="event_to_state_groups",
- keyvalues={
- "event_id": event_id,
- },
- retcol="state_group",
- allow_none=True,
- )
-
- return results
+ rows = yield self._simple_select_many_batch(
+ table="event_to_state_groups",
+ column="event_id",
+ iterable=event_ids,
+ keyvalues={},
+ retcols=("event_id", "state_group",),
+ desc="_get_state_group_for_events",
+ )
- return self.runInteraction("_get_state_group_for_events", f)
+ defer.returnValue({row["event_id"]: row["state_group"] for row in rows})
def _get_some_state_from_cache(self, group, types):
"""Checks if group is in cache. See `_get_state_for_groups`
@@ -355,7 +351,7 @@ class StateStore(SQLBaseStore):
all events are returned.
"""
results = {}
- missing_groups_and_types = []
+ missing_groups = []
if types is not None:
for group in set(groups):
state_dict, missing_types, got_all = self._get_some_state_from_cache(
@@ -364,7 +360,7 @@ class StateStore(SQLBaseStore):
results[group] = state_dict
if not got_all:
- missing_groups_and_types.append((group, missing_types))
+ missing_groups.append(group)
else:
for group in set(groups):
state_dict, got_all = self._get_all_state_from_cache(
@@ -373,9 +369,9 @@ class StateStore(SQLBaseStore):
results[group] = state_dict
if not got_all:
- missing_groups_and_types.append((group, None))
+ missing_groups.append(group)
- if not missing_groups_and_types:
+ if not missing_groups:
defer.returnValue({
group: {
type_tuple: event
@@ -389,7 +385,7 @@ class StateStore(SQLBaseStore):
cache_seq_num = self._state_group_cache.sequence
group_state_dict = yield self._get_state_groups_from_groups(
- missing_groups_and_types
+ missing_groups, types
)
state_events = yield self._get_events(
diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py
index 6e81d46c60..367ffc9543 100644
--- a/synapse/storage/stream.py
+++ b/synapse/storage/stream.py
@@ -37,10 +37,9 @@ from twisted.internet import defer
from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cachedInlineCallbacks
-from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.api.constants import EventTypes
from synapse.types import RoomStreamToken
-from synapse.util.logutils import log_function
+from synapse.util.logcontext import preserve_fn
import logging
@@ -78,13 +77,6 @@ def upper_bound(token):
class StreamStore(SQLBaseStore):
- def __init__(self, hs):
- super(StreamStore, self).__init__(hs)
-
- self._events_stream_cache = StreamChangeCache(
- "EventsRoomStreamChangeCache", self._stream_id_gen.get_max_token(None)
- )
-
@defer.inlineCallbacks
def get_appservice_room_stream(self, service, from_key, to_key, limit=0):
# NB this lives here instead of appservice.py so we can reuse the
@@ -177,14 +169,14 @@ class StreamStore(SQLBaseStore):
results = {}
room_ids = list(room_ids)
- for rm_ids in (room_ids[i:i+20] for i in xrange(0, len(room_ids), 20)):
+ for rm_ids in (room_ids[i:i + 20] for i in xrange(0, len(room_ids), 20)):
res = yield defer.gatherResults([
- self.get_room_events_stream_for_room(
- room_id, from_key, to_key, limit
- ).addCallback(lambda r, rm: (rm, r), room_id)
+ preserve_fn(self.get_room_events_stream_for_room)(
+ room_id, from_key, to_key, limit,
+ )
for room_id in room_ids
])
- results.update(dict(res))
+ results.update(dict(zip(rm_ids, res)))
defer.returnValue(results)
@@ -229,28 +221,30 @@ class StreamStore(SQLBaseStore):
rows = self.cursor_to_dict(txn)
- ret = self._get_events_txn(
- txn,
- [r["event_id"] for r in rows],
- get_prev_content=True
- )
+ return rows
+
+ rows = yield self.runInteraction("get_room_events_stream_for_room", f)
- self._set_before_and_after(ret, rows, topo_order=False)
+ ret = yield self._get_events(
+ [r["event_id"] for r in rows],
+ get_prev_content=True
+ )
- ret.reverse()
+ self._set_before_and_after(ret, rows, topo_order=False)
- if rows:
- key = "s%d" % min(r["stream_ordering"] for r in rows)
- else:
- # Assume we didn't get anything because there was nothing to
- # get.
- key = from_key
+ ret.reverse()
- return ret, key
- res = yield self.runInteraction("get_room_events_stream_for_room", f)
- defer.returnValue(res)
+ if rows:
+ key = "s%d" % min(r["stream_ordering"] for r in rows)
+ else:
+ # Assume we didn't get anything because there was nothing to
+ # get.
+ key = from_key
- def get_room_changes_for_user(self, user_id, from_key, to_key):
+ defer.returnValue((ret, key))
+
+ @defer.inlineCallbacks
+ def get_membership_changes_for_user(self, user_id, from_key, to_key):
if from_key is not None:
from_id = RoomStreamToken.parse_stream_token(from_key).stream
else:
@@ -258,7 +252,14 @@ class StreamStore(SQLBaseStore):
to_id = RoomStreamToken.parse_stream_token(to_key).stream
if from_key == to_key:
- return defer.succeed([])
+ defer.returnValue([])
+
+ if from_id:
+ has_changed = self._membership_stream_cache.has_entity_changed(
+ user_id, int(from_id)
+ )
+ if not has_changed:
+ defer.returnValue([])
def f(txn):
if from_id is not None:
@@ -283,17 +284,19 @@ class StreamStore(SQLBaseStore):
txn.execute(sql, (user_id, to_id,))
rows = self.cursor_to_dict(txn)
- ret = self._get_events_txn(
- txn,
- [r["event_id"] for r in rows],
- get_prev_content=True
- )
+ return rows
+
+ rows = yield self.runInteraction("get_membership_changes_for_user", f)
- return ret
+ ret = yield self._get_events(
+ [r["event_id"] for r in rows],
+ get_prev_content=True
+ )
- return self.runInteraction("get_room_changes_for_user", f)
+ self._set_before_and_after(ret, rows, topo_order=False)
+
+ defer.returnValue(ret)
- @log_function
def get_room_events_stream(
self,
user_id,
@@ -324,11 +327,6 @@ class StreamStore(SQLBaseStore):
" WHERE m.user_id = ? AND m.membership = 'join'"
)
current_room_membership_args = [user_id]
- if room_ids:
- current_room_membership_sql += " AND m.room_id in (%s)" % (
- ",".join(map(lambda _: "?", room_ids))
- )
- current_room_membership_args = [user_id] + room_ids
# We also want to get any membership events about that user, e.g.
# invites or leave notifications.
@@ -567,6 +565,7 @@ class StreamStore(SQLBaseStore):
table="events",
keyvalues={"event_id": event_id},
retcols=("stream_ordering", "topological_ordering"),
+ desc="get_topological_token_for_event",
).addCallback(lambda row: "t%d-%d" % (
row["topological_ordering"], row["stream_ordering"],)
)
@@ -604,6 +603,10 @@ class StreamStore(SQLBaseStore):
internal = event.internal_metadata
internal.before = str(RoomStreamToken(topo, stream - 1))
internal.after = str(RoomStreamToken(topo, stream))
+ internal.order = (
+ int(topo) if topo else 0,
+ int(stream),
+ )
@defer.inlineCallbacks
def get_events_around(self, room_id, event_id, before_limit, after_limit):
diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py
index f1fe963adf..133671e238 100644
--- a/synapse/util/__init__.py
+++ b/synapse/util/__init__.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
+from synapse.util.logcontext import PreserveLoggingContext
from twisted.internet import defer, reactor, task
@@ -46,7 +46,7 @@ class Clock(object):
def looping_call(self, f, msec):
l = task.LoopingCall(f)
- l.start(msec/1000.0, now=False)
+ l.start(msec / 1000.0, now=False)
return l
def stop_looping_call(self, loop):
@@ -61,10 +61,8 @@ class Clock(object):
*args: Postional arguments to pass to function.
**kwargs: Key arguments to pass to function.
"""
- current_context = LoggingContext.current_context()
-
def wrapped_callback(*args, **kwargs):
- with PreserveLoggingContext(current_context):
+ with PreserveLoggingContext():
callback(*args, **kwargs)
with PreserveLoggingContext():
diff --git a/synapse/util/async.py b/synapse/util/async.py
index 200edd404c..640fae3890 100644
--- a/synapse/util/async.py
+++ b/synapse/util/async.py
@@ -16,13 +16,16 @@
from twisted.internet import defer, reactor
-from .logcontext import preserve_context_over_deferred
+from .logcontext import PreserveLoggingContext
+@defer.inlineCallbacks
def sleep(seconds):
d = defer.Deferred()
- reactor.callLater(seconds, d.callback, seconds)
- return preserve_context_over_deferred(d)
+ with PreserveLoggingContext():
+ reactor.callLater(seconds, d.callback, seconds)
+ res = yield d
+ defer.returnValue(res)
def run_on_reactor():
@@ -54,6 +57,7 @@ class ObservableDeferred(object):
object.__setattr__(self, "_result", (True, r))
while self._observers:
try:
+ # TODO: Handle errors here.
self._observers.pop().callback(r)
except:
pass
@@ -63,6 +67,7 @@ class ObservableDeferred(object):
object.__setattr__(self, "_result", (False, f))
while self._observers:
try:
+ # TODO: Handle errors here.
self._observers.pop().errback(f)
except:
pass
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 88e56e3302..277854ccbc 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -18,6 +18,9 @@ from synapse.util.async import ObservableDeferred
from synapse.util import unwrapFirstError
from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.treecache import TreeCache
+from synapse.util.logcontext import (
+ PreserveLoggingContext, preserve_context_over_deferred, preserve_context_over_fn
+)
from . import caches_by_name, DEBUG_CACHES, cache_counter
@@ -149,7 +152,7 @@ class CacheDescriptor(object):
self.lru = lru
self.tree = tree
- self.arg_names = inspect.getargspec(orig).args[1:num_args+1]
+ self.arg_names = inspect.getargspec(orig).args[1:num_args + 1]
if len(self.arg_names) < self.num_args:
raise Exception(
@@ -190,7 +193,7 @@ class CacheDescriptor(object):
defer.returnValue(cached_result)
observer.addCallback(check_result)
- return observer
+ return preserve_context_over_deferred(observer)
except KeyError:
# Get the sequence number of the cache before reading from the
# database so that we can tell if the cache is invalidated
@@ -198,6 +201,7 @@ class CacheDescriptor(object):
sequence = self.cache.sequence
ret = defer.maybeDeferred(
+ preserve_context_over_fn,
self.function_to_call,
obj, *args, **kwargs
)
@@ -211,7 +215,7 @@ class CacheDescriptor(object):
ret = ObservableDeferred(ret, consumeErrors=True)
self.cache.update(sequence, cache_key, ret)
- return ret.observe()
+ return preserve_context_over_deferred(ret.observe())
wrapped.invalidate = self.cache.invalidate
wrapped.invalidate_all = self.cache.invalidate_all
@@ -250,7 +254,7 @@ class CacheListDescriptor(object):
self.num_args = num_args
self.list_name = list_name
- self.arg_names = inspect.getargspec(orig).args[1:num_args+1]
+ self.arg_names = inspect.getargspec(orig).args[1:num_args + 1]
self.list_pos = self.arg_names.index(self.list_name)
self.cache = cache
@@ -299,6 +303,7 @@ class CacheListDescriptor(object):
args_to_call[self.list_name] = missing
ret_d = defer.maybeDeferred(
+ preserve_context_over_fn,
self.function_to_call,
**args_to_call
)
@@ -308,7 +313,8 @@ class CacheListDescriptor(object):
# We need to create deferreds for each arg in the list so that
# we can insert the new deferred into the cache.
for arg in missing:
- observer = ret_d.observe()
+ with PreserveLoggingContext():
+ observer = ret_d.observe()
observer.addCallback(lambda r, arg: r.get(arg, None), arg)
observer = ObservableDeferred(observer)
@@ -327,10 +333,10 @@ class CacheListDescriptor(object):
cached[arg] = res
- return defer.gatherResults(
+ return preserve_context_over_deferred(defer.gatherResults(
cached.values(),
consumeErrors=True,
- ).addErrback(unwrapFirstError).addCallback(lambda res: dict(res))
+ ).addErrback(unwrapFirstError).addCallback(lambda res: dict(res)))
obj.__dict__[self.orig.__name__] = wrapped
diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py
index 494226f5ea..62cae99649 100644
--- a/synapse/util/caches/expiringcache.py
+++ b/synapse/util/caches/expiringcache.py
@@ -55,7 +55,7 @@ class ExpiringCache(object):
def f():
self._prune_cache()
- self._clock.looping_call(f, self._expiry_ms/2)
+ self._clock.looping_call(f, self._expiry_ms / 2)
def __setitem__(self, key, value):
now = self._clock.time_msec()
diff --git a/synapse/util/caches/snapshot_cache.py b/synapse/util/caches/snapshot_cache.py
index b1e40417fd..d03678b8c8 100644
--- a/synapse/util/caches/snapshot_cache.py
+++ b/synapse/util/caches/snapshot_cache.py
@@ -87,7 +87,8 @@ class SnapshotCache(object):
# expire from the rotation of that cache.
self.next_result_cache[key] = result
self.pending_result_cache.pop(key, None)
+ return r
- result.observe().addBoth(shuffle_along)
+ result.addBoth(shuffle_along)
return result.observe()
diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py
index c673b1bdfc..b37f1c0725 100644
--- a/synapse/util/caches/stream_change_cache.py
+++ b/synapse/util/caches/stream_change_cache.py
@@ -32,7 +32,7 @@ class StreamChangeCache(object):
entities that may have changed since that position. If position key is too
old then the cache will simply return all given entities.
"""
- def __init__(self, name, current_stream_pos, max_size=10000):
+ def __init__(self, name, current_stream_pos, max_size=10000, prefilled_cache={}):
self._max_size = max_size
self._entity_to_key = {}
self._cache = sorteddict()
@@ -40,6 +40,9 @@ class StreamChangeCache(object):
self.name = name
caches_by_name[self.name] = self._cache
+ for entity, stream_pos in prefilled_cache.items():
+ self.entity_has_changed(entity, stream_pos)
+
def has_entity_changed(self, entity, stream_pos):
"""Returns True if the entity may have been updated since stream_pos
"""
@@ -49,15 +52,10 @@ class StreamChangeCache(object):
cache_counter.inc_misses(self.name)
return True
- if stream_pos == self._earliest_known_stream_pos:
- # If the same as the earliest key, assume nothing has changed.
- cache_counter.inc_hits(self.name)
- return False
-
latest_entity_change_pos = self._entity_to_key.get(entity, None)
if latest_entity_change_pos is None:
- cache_counter.inc_misses(self.name)
- return True
+ cache_counter.inc_hits(self.name)
+ return False
if stream_pos < latest_entity_change_pos:
cache_counter.inc_misses(self.name)
@@ -95,7 +93,7 @@ class StreamChangeCache(object):
if stream_pos > self._earliest_known_stream_pos:
old_pos = self._entity_to_key.get(entity, None)
- if old_pos:
+ if old_pos is not None:
stream_pos = max(stream_pos, old_pos)
self._cache.pop(old_pos, None)
self._cache[stream_pos] = entity
diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py
index 29d02f7e95..03bc1401b7 100644
--- a/synapse/util/caches/treecache.py
+++ b/synapse/util/caches/treecache.py
@@ -58,7 +58,7 @@ class TreeCache(object):
if n:
break
- node_and_keys[i+1][0].pop(k)
+ node_and_keys[i + 1][0].pop(k)
popped, cnt = _strip_and_count_entires(popped)
self.size -= cnt
diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py
index 4ebfebf701..8875813de4 100644
--- a/synapse/util/distributor.py
+++ b/synapse/util/distributor.py
@@ -15,9 +15,7 @@
from twisted.internet import defer
-from synapse.util.logcontext import (
- PreserveLoggingContext, preserve_context_over_deferred,
-)
+from synapse.util.logcontext import PreserveLoggingContext
from synapse.util import unwrapFirstError
@@ -97,6 +95,7 @@ class Signal(object):
Each observer callable may return a Deferred."""
self.observers.append(observer)
+ @defer.inlineCallbacks
def fire(self, *args, **kwargs):
"""Invokes every callable in the observer list, passing in the args and
kwargs. Exceptions thrown by observers are logged but ignored. It is
@@ -116,6 +115,7 @@ class Signal(object):
failure.getTracebackObject()))
if not self.suppress_failures:
return failure
+
return defer.maybeDeferred(observer, *args, **kwargs).addErrback(eb)
with PreserveLoggingContext():
@@ -124,8 +124,11 @@ class Signal(object):
for observer in self.observers
]
- d = defer.gatherResults(deferreds, consumeErrors=True)
+ res = yield defer.gatherResults(
+ deferreds, consumeErrors=True
+ ).addErrback(unwrapFirstError)
- d.addErrback(unwrapFirstError)
+ defer.returnValue(res)
- return preserve_context_over_deferred(d)
+ def __repr__(self):
+ return "<Signal name=%r>" % (self.name,)
diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py
index 0595c0fa4f..5316259d15 100644
--- a/synapse/util/logcontext.py
+++ b/synapse/util/logcontext.py
@@ -41,13 +41,14 @@ except:
class LoggingContext(object):
"""Additional context for log formatting. Contexts are scoped within a
- "with" block. Contexts inherit the state of their parent contexts.
+ "with" block.
Args:
name (str): Name for the context for debugging.
"""
__slots__ = [
- "parent_context", "name", "usage_start", "usage_end", "main_thread", "__dict__"
+ "previous_context", "name", "usage_start", "usage_end", "main_thread",
+ "__dict__", "tag", "alive",
]
thread_local = threading.local()
@@ -72,10 +73,13 @@ class LoggingContext(object):
def add_database_transaction(self, duration_ms):
pass
+ def __nonzero__(self):
+ return False
+
sentinel = Sentinel()
def __init__(self, name=None):
- self.parent_context = None
+ self.previous_context = LoggingContext.current_context()
self.name = name
self.ru_stime = 0.
self.ru_utime = 0.
@@ -83,6 +87,8 @@ class LoggingContext(object):
self.db_txn_duration = 0.
self.usage_start = None
self.main_thread = threading.current_thread()
+ self.tag = ""
+ self.alive = True
def __str__(self):
return "%s@%x" % (self.name, id(self))
@@ -101,6 +107,7 @@ class LoggingContext(object):
The context that was previously active
"""
current = cls.current_context()
+
if current is not context:
current.stop()
cls.thread_local.current_context = context
@@ -109,9 +116,13 @@ class LoggingContext(object):
def __enter__(self):
"""Enters this logging context into thread local storage"""
- if self.parent_context is not None:
- raise Exception("Attempt to enter logging context multiple times")
- self.parent_context = self.set_current_context(self)
+ old_context = self.set_current_context(self)
+ if self.previous_context != old_context:
+ logger.warn(
+ "Expected previous context %r, found %r",
+ self.previous_context, old_context
+ )
+ self.alive = True
return self
def __exit__(self, type, value, traceback):
@@ -120,7 +131,7 @@ class LoggingContext(object):
Returns:
None to avoid suppressing any exeptions that were thrown.
"""
- current = self.set_current_context(self.parent_context)
+ current = self.set_current_context(self.previous_context)
if current is not self:
if current is self.sentinel:
logger.debug("Expected logging context %s has been lost", self)
@@ -130,16 +141,11 @@ class LoggingContext(object):
current,
self
)
- self.parent_context = None
-
- def __getattr__(self, name):
- """Delegate member lookup to parent context"""
- return getattr(self.parent_context, name)
+ self.previous_context = None
+ self.alive = False
def copy_to(self, record):
- """Copy fields from this context and its parents to the record"""
- if self.parent_context is not None:
- self.parent_context.copy_to(record)
+ """Copy fields from this context to the record"""
for key, value in self.__dict__.items():
setattr(record, key, value)
@@ -208,7 +214,7 @@ class PreserveLoggingContext(object):
exited. Used to restore the context after a function using
@defer.inlineCallbacks is resumed by a callback from the reactor."""
- __slots__ = ["current_context", "new_context"]
+ __slots__ = ["current_context", "new_context", "has_parent"]
def __init__(self, new_context=LoggingContext.sentinel):
self.new_context = new_context
@@ -219,12 +225,27 @@ class PreserveLoggingContext(object):
self.new_context
)
+ if self.current_context:
+ self.has_parent = self.current_context.previous_context is not None
+ if not self.current_context.alive:
+ logger.debug(
+ "Entering dead context: %s",
+ self.current_context,
+ )
+
def __exit__(self, type, value, traceback):
"""Restores the current logging context"""
- LoggingContext.set_current_context(self.current_context)
+ context = LoggingContext.set_current_context(self.current_context)
+
+ if context != self.new_context:
+ logger.debug(
+ "Unexpected logging context: %s is not %s",
+ context, self.new_context,
+ )
+
if self.current_context is not LoggingContext.sentinel:
- if self.current_context.parent_context is None:
- logger.warn(
+ if not self.current_context.alive:
+ logger.debug(
"Restoring dead context: %s",
self.current_context,
)
@@ -284,3 +305,74 @@ def preserve_context_over_deferred(deferred):
d = _PreservingContextDeferred(current_context)
deferred.chainDeferred(d)
return d
+
+
+def preserve_fn(f):
+ """Ensures that function is called with correct context and that context is
+ restored after return. Useful for wrapping functions that return a deferred
+ which you don't yield on.
+ """
+ current = LoggingContext.current_context()
+
+ def g(*args, **kwargs):
+ with PreserveLoggingContext(current):
+ return f(*args, **kwargs)
+
+ return g
+
+
+# modules to ignore in `logcontext_tracer`
+_to_ignore = [
+ "synapse.util.logcontext",
+ "synapse.http.server",
+ "synapse.storage._base",
+ "synapse.util.async",
+]
+
+
+def logcontext_tracer(frame, event, arg):
+ """A tracer that logs whenever a logcontext "unexpectedly" changes within
+ a function. Probably inaccurate.
+
+ Use by calling `sys.settrace(logcontext_tracer)` in the main thread.
+ """
+ if event == 'call':
+ name = frame.f_globals["__name__"]
+ if name.startswith("synapse"):
+ if name == "synapse.util.logcontext":
+ if frame.f_code.co_name in ["__enter__", "__exit__"]:
+ tracer = frame.f_back.f_trace
+ if tracer:
+ tracer.just_changed = True
+
+ tracer = frame.f_trace
+ if tracer:
+ return tracer
+
+ if not any(name.startswith(ig) for ig in _to_ignore):
+ return LineTracer()
+
+
+class LineTracer(object):
+ __slots__ = ["context", "just_changed"]
+
+ def __init__(self):
+ self.context = LoggingContext.current_context()
+ self.just_changed = False
+
+ def __call__(self, frame, event, arg):
+ if event in 'line':
+ if self.just_changed:
+ self.context = LoggingContext.current_context()
+ self.just_changed = False
+ else:
+ c = LoggingContext.current_context()
+ if c != self.context:
+ logger.info(
+ "Context changed! %s -> %s, %s, %s",
+ self.context, c,
+ frame.f_code.co_filename, frame.f_lineno
+ )
+ self.context = c
+
+ return self
diff --git a/synapse/util/logutils.py b/synapse/util/logutils.py
index d5b1a37eff..3a83828d25 100644
--- a/synapse/util/logutils.py
+++ b/synapse/util/logutils.py
@@ -111,7 +111,7 @@ def time_function(f):
_log_debug_as_f(
f,
"[FUNC END] {%s-%d} %f",
- (func_name, id, end-start,),
+ (func_name, id, end - start,),
)
return r
@@ -168,3 +168,38 @@ def trace_function(f):
wrapped.__name__ = func_name
return wrapped
+
+
+def get_previous_frames():
+ s = inspect.currentframe().f_back.f_back
+ to_return = []
+ while s:
+ if s.f_globals["__name__"].startswith("synapse"):
+ filename, lineno, function, _, _ = inspect.getframeinfo(s)
+ args_string = inspect.formatargvalues(*inspect.getargvalues(s))
+
+ to_return.append("{{ %s:%d %s - Args: %s }}" % (
+ filename, lineno, function, args_string
+ ))
+
+ s = s.f_back
+
+ return ", ". join(to_return)
+
+
+def get_previous_frame(ignore=[]):
+ s = inspect.currentframe().f_back.f_back
+
+ while s:
+ if s.f_globals["__name__"].startswith("synapse"):
+ if not any(s.f_globals["__name__"].startswith(ig) for ig in ignore):
+ filename, lineno, function, _, _ = inspect.getframeinfo(s)
+ args_string = inspect.formatargvalues(*inspect.getargvalues(s))
+
+ return "{{ %s:%d %s - Args: %s }}" % (
+ filename, lineno, function, args_string
+ )
+
+ s = s.f_back
+
+ return None
diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py
new file mode 100644
index 0000000000..c51b641125
--- /dev/null
+++ b/synapse/util/metrics.py
@@ -0,0 +1,97 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from synapse.util.logcontext import LoggingContext
+import synapse.metrics
+
+import logging
+
+
+logger = logging.getLogger(__name__)
+
+
+metrics = synapse.metrics.get_metrics_for(__name__)
+
+block_timer = metrics.register_distribution(
+ "block_timer",
+ labels=["block_name"]
+)
+
+block_ru_utime = metrics.register_distribution(
+ "block_ru_utime", labels=["block_name"]
+)
+
+block_ru_stime = metrics.register_distribution(
+ "block_ru_stime", labels=["block_name"]
+)
+
+block_db_txn_count = metrics.register_distribution(
+ "block_db_txn_count", labels=["block_name"]
+)
+
+block_db_txn_duration = metrics.register_distribution(
+ "block_db_txn_duration", labels=["block_name"]
+)
+
+
+class Measure(object):
+ __slots__ = [
+ "clock", "name", "start_context", "start", "new_context", "ru_utime",
+ "ru_stime", "db_txn_count", "db_txn_duration"
+ ]
+
+ def __init__(self, clock, name):
+ self.clock = clock
+ self.name = name
+ self.start_context = None
+ self.start = None
+
+ def __enter__(self):
+ self.start = self.clock.time_msec()
+ self.start_context = LoggingContext.current_context()
+ if self.start_context:
+ self.ru_utime, self.ru_stime = self.start_context.get_resource_usage()
+ self.db_txn_count = self.start_context.db_txn_count
+ self.db_txn_duration = self.start_context.db_txn_duration
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ if exc_type is not None or not self.start_context:
+ return
+
+ duration = self.clock.time_msec() - self.start
+ block_timer.inc_by(duration, self.name)
+
+ context = LoggingContext.current_context()
+
+ if context != self.start_context:
+ logger.warn(
+ "Context have unexpectedly changed from '%s' to '%s'. (%r)",
+ context, self.start_context, self.name
+ )
+ return
+
+ if not context:
+ logger.warn("Expected context. (%r)", self.name)
+ return
+
+ ru_utime, ru_stime = context.get_resource_usage()
+
+ block_ru_utime.inc_by(ru_utime - self.ru_utime, self.name)
+ block_ru_stime.inc_by(ru_stime - self.ru_stime, self.name)
+ block_db_txn_count.inc_by(context.db_txn_count - self.db_txn_count, self.name)
+ block_db_txn_duration.inc_by(
+ context.db_txn_duration - self.db_txn_duration, self.name
+ )
diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py
index c37d6f12e3..4076eed269 100644
--- a/synapse/util/ratelimitutils.py
+++ b/synapse/util/ratelimitutils.py
@@ -18,6 +18,7 @@ from twisted.internet import defer
from synapse.api.errors import LimitExceededError
from synapse.util.async import sleep
+from synapse.util.logcontext import preserve_fn
import collections
import contextlib
@@ -163,7 +164,7 @@ class _PerHostRatelimiter(object):
"Ratelimit [%s]: sleeping req",
id(request_id),
)
- ret_defer = sleep(self.sleep_msec/1000.0)
+ ret_defer = preserve_fn(sleep)(self.sleep_msec / 1000.0)
self.sleeping_requests.add(request_id)
|