diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index 6fbd5d6876..d0dfa959dc 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -66,6 +66,17 @@ class CodeMessageException(RuntimeError):
return cs_error(self.msg)
+class MatrixCodeMessageException(CodeMessageException):
+ """An error from a general matrix endpoint, eg. from a proxied Matrix API call.
+
+ Attributes:
+ errcode (str): Matrix error code e.g 'M_FORBIDDEN'
+ """
+ def __init__(self, code, msg, errcode=Codes.UNKNOWN):
+ super(MatrixCodeMessageException, self).__init__(code, msg)
+ self.errcode = errcode
+
+
class SynapseError(CodeMessageException):
"""A base exception type for matrix errors which have an errcode and error
message (as well as an HTTP status code).
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index 6be18880b9..e9a732ff03 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -50,6 +50,7 @@ class EventContext(object):
"prev_group",
"delta_ids",
"prev_state_events",
+ "app_service",
]
def __init__(self):
@@ -68,3 +69,5 @@ class EventContext(object):
self.delta_ids = None
self.prev_state_events = None
+
+ self.app_service = None
diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py
index dee387eb7f..695f1a7375 100644
--- a/synapse/federation/transaction_queue.py
+++ b/synapse/federation/transaction_queue.py
@@ -24,7 +24,6 @@ from synapse.util.async import run_on_reactor
from synapse.util.logcontext import preserve_context_over_fn, preserve_fn
from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter
from synapse.util.metrics import measure_func
-from synapse.types import get_domain_from_id
from synapse.handlers.presence import format_user_presence_state, get_interested_remotes
import synapse.metrics
@@ -183,15 +182,12 @@ class TransactionQueue(object):
# Otherwise if the last member on a server in a room is
# banned then it won't receive the event because it won't
# be in the room after the ban.
- users_in_room = yield self.state.get_current_user_in_room(
+ destinations = yield self.state.get_current_hosts_in_room(
event.room_id, latest_event_ids=[
prev_id for prev_id, _ in event.prev_events
],
)
- destinations = set(
- get_domain_from_id(user_id) for user_id in users_in_room
- )
if send_on_behalf_of is not None:
# If we are sending the event on behalf of another server
# then it already has the event and there is no reason to
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index ebbf844489..2af9849ed0 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -171,6 +171,16 @@ class FederationHandler(BaseHandler):
yield self._get_missing_events_for_pdu(
origin, pdu, prevs, min_depth
)
+
+ # Update the set of things we've seen after trying to
+ # fetch the missing stuff
+ have_seen = yield self.store.have_events(prevs)
+ seen = set(have_seen.iterkeys())
+
+ if not prevs - seen:
+ logger.info(
+ "Found all missing prev events for %s", pdu.event_id
+ )
elif prevs - seen:
logger.info(
"Not fetching %d missing events for room %r,event %s: %r...",
@@ -178,8 +188,6 @@ class FederationHandler(BaseHandler):
list(prevs - seen)[:5],
)
- prevs = {e_id for e_id, _ in pdu.prev_events}
- seen = set(have_seen.keys())
if prevs - seen:
logger.info(
"Still missing %d events for room %r: %r...",
@@ -214,19 +222,15 @@ class FederationHandler(BaseHandler):
Args:
origin (str): Origin of the pdu. Will be called to get the missing events
pdu: received pdu
- prevs (str[]): List of event ids which we are missing
+ prevs (set(str)): List of event ids which we are missing
min_depth (int): Minimum depth of events to return.
-
- Returns:
- Deferred<dict(str, str?)>: updated have_seen dictionary
"""
# We recalculate seen, since it may have changed.
have_seen = yield self.store.have_events(prevs)
seen = set(have_seen.keys())
if not prevs - seen:
- # nothing left to do
- defer.returnValue(have_seen)
+ return
latest = yield self.store.get_latest_event_ids_in_room(
pdu.room_id
@@ -288,19 +292,6 @@ class FederationHandler(BaseHandler):
get_missing=False
)
- have_seen = yield self.store.have_events(
- [ev for ev, _ in pdu.prev_events]
- )
- seen = set(have_seen.keys())
- if prevs - seen:
- logger.info(
- "Still missing %d prev events for %s: %r...",
- len(prevs - seen), pdu.event_id, list(prevs - seen)[:5]
- )
- else:
- logger.info("Found all missing prev events for %s", pdu.event_id)
- defer.returnValue(have_seen)
-
@log_function
@defer.inlineCallbacks
def _process_received_pdu(self, origin, pdu, state, auth_chain):
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index 6a53c5eb47..9efcdff1d6 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -18,7 +18,7 @@
from twisted.internet import defer
from synapse.api.errors import (
- CodeMessageException
+ MatrixCodeMessageException, CodeMessageException
)
from ._base import BaseHandler
from synapse.util.async import run_on_reactor
@@ -90,6 +90,9 @@ class IdentityHandler(BaseHandler):
),
{'sid': creds['sid'], 'client_secret': client_secret}
)
+ except MatrixCodeMessageException as e:
+ logger.info("getValidated3pid failed with Matrix error: %r", e)
+ raise SynapseError(e.code, e.msg, e.errcode)
except CodeMessageException as e:
data = json.loads(e.msg)
@@ -159,6 +162,9 @@ class IdentityHandler(BaseHandler):
params
)
defer.returnValue(data)
+ except MatrixCodeMessageException as e:
+ logger.info("Proxied requestToken failed with Matrix error: %r", e)
+ raise SynapseError(e.code, e.msg, e.errcode)
except CodeMessageException as e:
logger.info("Proxied requestToken failed: %r", e)
raise e
@@ -193,6 +199,9 @@ class IdentityHandler(BaseHandler):
params
)
defer.returnValue(data)
+ except MatrixCodeMessageException as e:
+ logger.info("Proxied requestToken failed with Matrix error: %r", e)
+ raise SynapseError(e.code, e.msg, e.errcode)
except CodeMessageException as e:
logger.info("Proxied requestToken failed: %r", e)
raise e
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 82a2ade1f6..57265c6d7d 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -175,7 +175,8 @@ class MessageHandler(BaseHandler):
defer.returnValue(chunk)
@defer.inlineCallbacks
- def create_event(self, event_dict, token_id=None, txn_id=None, prev_event_ids=None):
+ def create_event(self, requester, event_dict, token_id=None, txn_id=None,
+ prev_event_ids=None):
"""
Given a dict from a client, create a new event.
@@ -185,6 +186,7 @@ class MessageHandler(BaseHandler):
Adds display names to Join membership events.
Args:
+ requester
event_dict (dict): An entire event
token_id (str)
txn_id (str)
@@ -226,6 +228,7 @@ class MessageHandler(BaseHandler):
event, context = yield self._create_new_client_event(
builder=builder,
+ requester=requester,
prev_event_ids=prev_event_ids,
)
@@ -319,6 +322,7 @@ class MessageHandler(BaseHandler):
See self.create_event and self.send_nonmember_event.
"""
event, context = yield self.create_event(
+ requester,
event_dict,
token_id=requester.access_token_id,
txn_id=txn_id
@@ -416,7 +420,7 @@ class MessageHandler(BaseHandler):
@measure_func("_create_new_client_event")
@defer.inlineCallbacks
- def _create_new_client_event(self, builder, prev_event_ids=None):
+ def _create_new_client_event(self, builder, requester=None, prev_event_ids=None):
if prev_event_ids:
prev_events = yield self.store.add_event_hashes(prev_event_ids)
prev_max_depth = yield self.store.get_max_depth_of_events(prev_event_ids)
@@ -456,6 +460,8 @@ class MessageHandler(BaseHandler):
state_handler = self.state_handler
context = yield state_handler.compute_event_context(builder)
+ if requester:
+ context.app_service = requester.app_service
if builder.is_state():
builder.prev_state = yield self.store.add_event_hashes(
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 28b2c80a93..ab87632d99 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -70,6 +70,7 @@ class RoomMemberHandler(BaseHandler):
content["kind"] = "guest"
event, context = yield msg_handler.create_event(
+ requester,
{
"type": EventTypes.Member,
"content": content,
diff --git a/synapse/http/client.py b/synapse/http/client.py
index ca2f770f5d..9cf797043a 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -16,7 +16,7 @@ from OpenSSL import SSL
from OpenSSL.SSL import VERIFY_NONE
from synapse.api.errors import (
- CodeMessageException, SynapseError, Codes,
+ CodeMessageException, MatrixCodeMessageException, SynapseError, Codes,
)
from synapse.util.logcontext import preserve_context_over_fn
import synapse.metrics
@@ -145,6 +145,11 @@ class SimpleHttpClient(object):
body = yield preserve_context_over_fn(readBody, response)
+ if 200 <= response.code < 300:
+ defer.returnValue(json.loads(body))
+ else:
+ raise self._exceptionFromFailedRequest(response, body)
+
defer.returnValue(json.loads(body))
@defer.inlineCallbacks
@@ -164,8 +169,11 @@ class SimpleHttpClient(object):
On a non-2xx HTTP response. The response body will be used as the
error message.
"""
- body = yield self.get_raw(uri, args)
- defer.returnValue(json.loads(body))
+ try:
+ body = yield self.get_raw(uri, args)
+ defer.returnValue(json.loads(body))
+ except CodeMessageException as e:
+ raise self._exceptionFromFailedRequest(e.code, e.msg)
@defer.inlineCallbacks
def put_json(self, uri, json_body, args={}):
@@ -246,6 +254,15 @@ class SimpleHttpClient(object):
else:
raise CodeMessageException(response.code, body)
+ def _exceptionFromFailedRequest(self, response, body):
+ try:
+ jsonBody = json.loads(body)
+ errcode = jsonBody['errcode']
+ error = jsonBody['error']
+ return MatrixCodeMessageException(response.code, error, errcode)
+ except (ValueError, KeyError):
+ return CodeMessageException(response.code, body)
+
# XXX: FIXME: This is horribly copy-pasted from matrixfederationclient.
# The two should be factored out.
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index ab48ff925e..fcaf58b93b 100644
--- a/synapse/replication/slave/storage/events.py
+++ b/synapse/replication/slave/storage/events.py
@@ -144,6 +144,9 @@ class SlavedEventStore(BaseSlavedStore):
RoomMemberStore.__dict__["_get_joined_users_from_context"]
)
+ get_joined_hosts = DataStore.get_joined_hosts.__func__
+ _get_joined_hosts = RoomMemberStore.__dict__["_get_joined_hosts"]
+
get_recent_events_for_room = DataStore.get_recent_events_for_room.__func__
get_room_events_stream_for_rooms = (
DataStore.get_room_events_stream_for_rooms.__func__
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index c376ab8fd7..cd388770c8 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -164,6 +164,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
else:
msg_handler = self.handlers.message_handler
event, context = yield msg_handler.create_event(
+ requester,
event_dict,
token_id=requester.access_token_id,
txn_id=txn_id,
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index 3acf4eacdd..38a739f2f8 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -31,6 +31,7 @@ import logging
import hmac
from hashlib import sha1
from synapse.util.async import run_on_reactor
+from synapse.util.ratelimitutils import FederationRateLimiter
# We ought to be using hmac.compare_digest() but on older pythons it doesn't
@@ -115,6 +116,45 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
defer.returnValue((200, ret))
+class UsernameAvailabilityRestServlet(RestServlet):
+ PATTERNS = client_v2_patterns("/register/available")
+
+ def __init__(self, hs):
+ """
+ Args:
+ hs (synapse.server.HomeServer): server
+ """
+ super(UsernameAvailabilityRestServlet, self).__init__()
+ self.hs = hs
+ self.registration_handler = hs.get_handlers().registration_handler
+ self.ratelimiter = FederationRateLimiter(
+ hs.get_clock(),
+ # Time window of 2s
+ window_size=2000,
+ # Artificially delay requests if rate > sleep_limit/window_size
+ sleep_limit=1,
+ # Amount of artificial delay to apply
+ sleep_msec=1000,
+ # Error with 429 if more than reject_limit requests are queued
+ reject_limit=1,
+ # Allow 1 request at a time
+ concurrent_requests=1,
+ )
+
+ @defer.inlineCallbacks
+ def on_GET(self, request):
+ ip = self.hs.get_ip_from_request(request)
+ with self.ratelimiter.ratelimit(ip) as wait_deferred:
+ yield wait_deferred
+
+ body = parse_json_object_from_request(request)
+ assert_params_in_request(body, ['username'])
+
+ yield self.registration_handler.check_username(body['username'])
+
+ defer.returnValue((200, {"available": True}))
+
+
class RegisterRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/register$")
@@ -555,4 +595,5 @@ class RegisterRestServlet(RestServlet):
def register_servlets(hs, http_server):
EmailRegisterRequestTokenRestServlet(hs).register(http_server)
MsisdnRegisterRequestTokenRestServlet(hs).register(http_server)
+ UsernameAvailabilityRestServlet(hs).register(http_server)
RegisterRestServlet(hs).register(http_server)
diff --git a/synapse/state.py b/synapse/state.py
index f6b83d888a..02fee47f39 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -176,6 +176,17 @@ class StateHandler(object):
defer.returnValue(joined_users)
@defer.inlineCallbacks
+ def get_current_hosts_in_room(self, room_id, latest_event_ids=None):
+ if not latest_event_ids:
+ latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
+ logger.debug("calling resolve_state_groups from get_current_hosts_in_room")
+ entry = yield self.resolve_state_groups(room_id, latest_event_ids)
+ joined_hosts = yield self.store.get_joined_hosts(
+ room_id, entry.state_id, entry.state
+ )
+ defer.returnValue(joined_hosts)
+
+ @defer.inlineCallbacks
def compute_event_context(self, event, old_state=None):
"""Build an EventContext structure for the event.
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index a3790419dd..98707d40ee 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -29,6 +29,7 @@ from synapse.api.constants import EventTypes
from synapse.api.errors import SynapseError
from synapse.state import resolve_events
from synapse.util.caches.descriptors import cached
+from synapse.types import get_domain_from_id
from canonicaljson import encode_canonical_json
from collections import deque, namedtuple, OrderedDict
@@ -49,6 +50,9 @@ logger = logging.getLogger(__name__)
metrics = synapse.metrics.get_metrics_for(__name__)
persist_event_counter = metrics.register_counter("persisted_events")
+event_counter = metrics.register_counter(
+ "persisted_events_sep", labels=["type", "origin_type", "origin_entity"]
+)
def encode_json(json_object):
@@ -370,6 +374,18 @@ class EventsStore(SQLBaseStore):
new_forward_extremeties=new_forward_extremeties,
)
persist_event_counter.inc_by(len(chunk))
+ for event, context in chunk:
+ if context.app_service:
+ origin_type = "local"
+ origin_entity = context.app_service.id
+ elif self.hs.is_mine_id(event.sender):
+ origin_type = "local"
+ origin_entity = "*client*"
+ else:
+ origin_type = "remote"
+ origin_entity = get_domain_from_id(event.sender)
+
+ event_counter.inc(event.type, origin_type, origin_entity)
@defer.inlineCallbacks
def _calculate_new_extremeties(self, room_id, event_contexts, latest_event_ids):
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index c63c0622dd..ad3c9b06d9 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -18,6 +18,7 @@ from twisted.internet import defer
from collections import namedtuple
from ._base import SQLBaseStore
+from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
from synapse.util.stringutils import to_ascii
@@ -147,7 +148,7 @@ class RoomMemberStore(SQLBaseStore):
hosts = frozenset(get_domain_from_id(user_id) for user_id in user_ids)
defer.returnValue(hosts)
- @cached(max_entries=500000, iterable=True)
+ @cached(max_entries=100000, iterable=True)
def get_users_in_room(self, room_id):
def f(txn):
sql = (
@@ -160,7 +161,7 @@ class RoomMemberStore(SQLBaseStore):
)
txn.execute(sql, (room_id, Membership.JOIN,))
- return [r[0] for r in txn]
+ return [to_ascii(r[0]) for r in txn]
return self.runInteraction("get_users_in_room", f)
@cached()
@@ -508,6 +509,44 @@ class RoomMemberStore(SQLBaseStore):
defer.returnValue(False)
+ def get_joined_hosts(self, room_id, state_group, state_ids):
+ if not state_group:
+ # If state_group is None it means it has yet to be assigned a
+ # state group, i.e. we need to make sure that calls with a state_group
+ # of None don't hit previous cached calls with a None state_group.
+ # To do this we set the state_group to a new object as object() != object()
+ state_group = object()
+
+ return self._get_joined_hosts(
+ room_id, state_group, state_ids
+ )
+
+ @cachedInlineCallbacks(num_args=2, max_entries=10000, iterable=True)
+ def _get_joined_hosts(self, room_id, state_group, current_state_ids):
+ # We don't use `state_group`, its there so that we can cache based
+ # on it. However, its important that its never None, since two current_state's
+ # with a state_group of None are likely to be different.
+ # See bulk_get_push_rules_for_room for how we work around this.
+ assert state_group is not None
+
+ joined_hosts = set()
+ for (etype, state_key), event_id in current_state_ids.items():
+ if etype == EventTypes.Member:
+ try:
+ host = get_domain_from_id(state_key)
+ except:
+ logger.warn("state_key not user_id: %s", state_key)
+ continue
+
+ if host in joined_hosts:
+ continue
+
+ event = yield self.get_event(event_id, allow_none=True)
+ if event and event.content["membership"] == Membership.JOIN:
+ joined_hosts.add(intern_string(host))
+
+ defer.returnValue(joined_hosts)
+
@defer.inlineCallbacks
def _background_add_membership_profile(self, progress, batch_size):
target_min_stream_id = progress.get(
diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py
index 98a5a26ac5..2a2360ab5d 100644
--- a/synapse/util/__init__.py
+++ b/synapse/util/__init__.py
@@ -26,7 +26,7 @@ logger = logging.getLogger(__name__)
class DeferredTimedOutError(SynapseError):
def __init__(self):
- super(SynapseError, self).__init__(504, "Timed out")
+ super(DeferredTimedOutError, self).__init__(504, "Timed out")
def unwrapFirstError(failure):
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 807e147657..aa182eeac7 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -18,6 +18,7 @@ from synapse.util.async import ObservableDeferred
from synapse.util import unwrapFirstError, logcontext
from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
+from synapse.util.stringutils import to_ascii
from . import register_cache
@@ -163,10 +164,6 @@ class Cache(object):
def invalidate(self, key):
self.check_thread()
- if not isinstance(key, tuple):
- raise TypeError(
- "The cache key must be a tuple not %r" % (type(key),)
- )
# Increment the sequence number so that any SELECT statements that
# raced with the INSERT don't update the cache (SYN-369)
@@ -312,7 +309,7 @@ class CacheDescriptor(_CacheDescriptorBase):
iterable=self.iterable,
)
- def get_cache_key(args, kwargs):
+ def get_cache_key_gen(args, kwargs):
"""Given some args/kwargs return a generator that resolves into
the cache_key.
@@ -330,13 +327,29 @@ class CacheDescriptor(_CacheDescriptorBase):
else:
yield self.arg_defaults[nm]
+ # By default our cache key is a tuple, but if there is only one item
+ # then don't bother wrapping in a tuple. This is to save memory.
+ if self.num_args == 1:
+ nm = self.arg_names[0]
+
+ def get_cache_key(args, kwargs):
+ if nm in kwargs:
+ return kwargs[nm]
+ elif len(args):
+ return args[0]
+ else:
+ return self.arg_defaults[nm]
+ else:
+ def get_cache_key(args, kwargs):
+ return tuple(get_cache_key_gen(args, kwargs))
+
@functools.wraps(self.orig)
def wrapped(*args, **kwargs):
# If we're passed a cache_context then we'll want to call its invalidate()
# whenever we are invalidated
invalidate_callback = kwargs.pop("on_invalidate", None)
- cache_key = tuple(get_cache_key(args, kwargs))
+ cache_key = get_cache_key(args, kwargs)
# Add our own `cache_context` to argument list if the wrapped function
# has asked for one
@@ -363,6 +376,11 @@ class CacheDescriptor(_CacheDescriptorBase):
ret.addErrback(onErr)
+ # If our cache_key is a string, try to convert to ascii to save
+ # a bit of space in large caches
+ if isinstance(cache_key, basestring):
+ cache_key = to_ascii(cache_key)
+
result_d = ObservableDeferred(ret, consumeErrors=True)
cache.set(cache_key, result_d, callback=invalidate_callback)
observer = result_d.observe()
@@ -372,10 +390,16 @@ class CacheDescriptor(_CacheDescriptorBase):
else:
return observer
- wrapped.invalidate = cache.invalidate
+ if self.num_args == 1:
+ wrapped.invalidate = lambda key: cache.invalidate(key[0])
+ wrapped.prefill = lambda key, val: cache.prefill(key[0], val)
+ else:
+ wrapped.invalidate = cache.invalidate
+ wrapped.invalidate_all = cache.invalidate_all
+ wrapped.invalidate_many = cache.invalidate_many
+ wrapped.prefill = cache.prefill
+
wrapped.invalidate_all = cache.invalidate_all
- wrapped.invalidate_many = cache.invalidate_many
- wrapped.prefill = cache.prefill
wrapped.cache = cache
obj.__dict__[self.orig.__name__] = wrapped
|