diff options
-rw-r--r-- | synapse/federation/transaction_queue.py | 6 | ||||
-rw-r--r-- | synapse/http/client.py | 35 | ||||
-rw-r--r-- | synapse/replication/slave/storage/events.py | 3 | ||||
-rw-r--r-- | synapse/rest/client/v2_alpha/register.py | 41 | ||||
-rw-r--r-- | synapse/state.py | 11 | ||||
-rw-r--r-- | synapse/storage/_base.py | 8 | ||||
-rw-r--r-- | synapse/storage/events.py | 16 | ||||
-rw-r--r-- | synapse/storage/roommember.py | 101 | ||||
-rw-r--r-- | synapse/storage/state.py | 12 | ||||
-rw-r--r-- | synapse/util/caches/descriptors.py | 42 |
10 files changed, 202 insertions, 73 deletions
diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py index dee387eb7f..695f1a7375 100644 --- a/synapse/federation/transaction_queue.py +++ b/synapse/federation/transaction_queue.py @@ -24,7 +24,6 @@ from synapse.util.async import run_on_reactor from synapse.util.logcontext import preserve_context_over_fn, preserve_fn from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter from synapse.util.metrics import measure_func -from synapse.types import get_domain_from_id from synapse.handlers.presence import format_user_presence_state, get_interested_remotes import synapse.metrics @@ -183,15 +182,12 @@ class TransactionQueue(object): # Otherwise if the last member on a server in a room is # banned then it won't receive the event because it won't # be in the room after the ban. - users_in_room = yield self.state.get_current_user_in_room( + destinations = yield self.state.get_current_hosts_in_room( event.room_id, latest_event_ids=[ prev_id for prev_id, _ in event.prev_events ], ) - destinations = set( - get_domain_from_id(user_id) for user_id in users_in_room - ) if send_on_behalf_of is not None: # If we are sending the event on behalf of another server # then it already has the event and there is no reason to diff --git a/synapse/http/client.py b/synapse/http/client.py index 9cf797043a..9eba046bbf 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -19,6 +19,7 @@ from synapse.api.errors import ( CodeMessageException, MatrixCodeMessageException, SynapseError, Codes, ) from synapse.util.logcontext import preserve_context_over_fn +from synapse.util import logcontext import synapse.metrics from synapse.http.endpoint import SpiderEndpoint @@ -72,39 +73,45 @@ class SimpleHttpClient(object): contextFactory=hs.get_http_client_context_factory() ) self.user_agent = hs.version_string + self.clock = hs.get_clock() if hs.config.user_agent_suffix: self.user_agent = "%s %s" % (self.user_agent, hs.config.user_agent_suffix,) + @defer.inlineCallbacks def request(self, method, uri, *args, **kwargs): # A small wrapper around self.agent.request() so we can easily attach # counters to it outgoing_requests_counter.inc(method) - d = preserve_context_over_fn( - self.agent.request, - method, uri, *args, **kwargs - ) + + def send_request(): + request_deferred = self.agent.request( + method, uri, *args, **kwargs + ) + + return self.clock.time_bound_deferred( + request_deferred, + time_out=60, + ) logger.info("Sending request %s %s", method, uri) - def _cb(response): + try: + with logcontext.PreserveLoggingContext(): + response = yield send_request() + incoming_responses_counter.inc(method, response.code) logger.info( "Received response to %s %s: %s", method, uri, response.code ) - return response - - def _eb(failure): + defer.returnValue(response) + except Exception as e: incoming_responses_counter.inc(method, "ERR") logger.info( "Error sending request to %s %s: %s %s", - method, uri, failure.type, failure.getErrorMessage() + method, uri, type(e).__name__, e.message ) - return failure - - d.addCallbacks(_cb, _eb) - - return d + raise e @defer.inlineCallbacks def post_urlencoded_get_json(self, uri, args={}): diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py index ab48ff925e..fcaf58b93b 100644 --- a/synapse/replication/slave/storage/events.py +++ b/synapse/replication/slave/storage/events.py @@ -144,6 +144,9 @@ class SlavedEventStore(BaseSlavedStore): RoomMemberStore.__dict__["_get_joined_users_from_context"] ) + get_joined_hosts = DataStore.get_joined_hosts.__func__ + _get_joined_hosts = RoomMemberStore.__dict__["_get_joined_hosts"] + get_recent_events_for_room = DataStore.get_recent_events_for_room.__func__ get_room_events_stream_for_rooms = ( DataStore.get_room_events_stream_for_rooms.__func__ diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 3acf4eacdd..38a739f2f8 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -31,6 +31,7 @@ import logging import hmac from hashlib import sha1 from synapse.util.async import run_on_reactor +from synapse.util.ratelimitutils import FederationRateLimiter # We ought to be using hmac.compare_digest() but on older pythons it doesn't @@ -115,6 +116,45 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet): defer.returnValue((200, ret)) +class UsernameAvailabilityRestServlet(RestServlet): + PATTERNS = client_v2_patterns("/register/available") + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + super(UsernameAvailabilityRestServlet, self).__init__() + self.hs = hs + self.registration_handler = hs.get_handlers().registration_handler + self.ratelimiter = FederationRateLimiter( + hs.get_clock(), + # Time window of 2s + window_size=2000, + # Artificially delay requests if rate > sleep_limit/window_size + sleep_limit=1, + # Amount of artificial delay to apply + sleep_msec=1000, + # Error with 429 if more than reject_limit requests are queued + reject_limit=1, + # Allow 1 request at a time + concurrent_requests=1, + ) + + @defer.inlineCallbacks + def on_GET(self, request): + ip = self.hs.get_ip_from_request(request) + with self.ratelimiter.ratelimit(ip) as wait_deferred: + yield wait_deferred + + body = parse_json_object_from_request(request) + assert_params_in_request(body, ['username']) + + yield self.registration_handler.check_username(body['username']) + + defer.returnValue((200, {"available": True})) + + class RegisterRestServlet(RestServlet): PATTERNS = client_v2_patterns("/register$") @@ -555,4 +595,5 @@ class RegisterRestServlet(RestServlet): def register_servlets(hs, http_server): EmailRegisterRequestTokenRestServlet(hs).register(http_server) MsisdnRegisterRequestTokenRestServlet(hs).register(http_server) + UsernameAvailabilityRestServlet(hs).register(http_server) RegisterRestServlet(hs).register(http_server) diff --git a/synapse/state.py b/synapse/state.py index f6b83d888a..02fee47f39 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -176,6 +176,17 @@ class StateHandler(object): defer.returnValue(joined_users) @defer.inlineCallbacks + def get_current_hosts_in_room(self, room_id, latest_event_ids=None): + if not latest_event_ids: + latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) + logger.debug("calling resolve_state_groups from get_current_hosts_in_room") + entry = yield self.resolve_state_groups(room_id, latest_event_ids) + joined_hosts = yield self.store.get_joined_hosts( + room_id, entry.state_id, entry.state + ) + defer.returnValue(joined_hosts) + + @defer.inlineCallbacks def compute_event_context(self, event, old_state=None): """Build an EventContext structure for the event. diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 58b73af7d2..c659004e8d 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -60,12 +60,12 @@ class LoggingTransaction(object): object.__setattr__(self, "database_engine", database_engine) object.__setattr__(self, "after_callbacks", after_callbacks) - def call_after(self, callback, *args, **kwargs): + def call_after(self, callback, *args): """Call the given callback on the main twisted thread after the transaction has finished. Used to invalidate the caches on the correct thread. """ - self.after_callbacks.append((callback, args, kwargs)) + self.after_callbacks.append((callback, args)) def __getattr__(self, name): return getattr(self.txn, name) @@ -319,8 +319,8 @@ class SQLBaseStore(object): inner_func, *args, **kwargs ) finally: - for after_callback, after_args, after_kwargs in after_callbacks: - after_callback(*after_args, **after_kwargs) + for after_callback, after_args in after_callbacks: + after_callback(*after_args) defer.returnValue(result) @defer.inlineCallbacks diff --git a/synapse/storage/events.py b/synapse/storage/events.py index d946024c9b..98707d40ee 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -374,12 +374,6 @@ class EventsStore(SQLBaseStore): new_forward_extremeties=new_forward_extremeties, ) persist_event_counter.inc_by(len(chunk)) - - for room_id, (_, _, new_state) in current_state_for_room.iteritems(): - self.get_current_state_ids.prefill( - (room_id, ), new_state - ) - for event, context in chunk: if context.app_service: origin_type = "local" @@ -441,10 +435,10 @@ class EventsStore(SQLBaseStore): Assumes that we are only persisting events for one room at a time. Returns: - 3-tuple (to_delete, to_insert, new_state) where both are state dicts, - i.e. (type, state_key) -> event_id. `to_delete` are the entries to + 2-tuple (to_delete, to_insert) where both are state dicts, i.e. + (type, state_key) -> event_id. `to_delete` are the entries to first be deleted from current_state_events, `to_insert` are entries - to insert. `new_state` is the full set of state. + to insert. May return None if there are no changes to be applied. """ # Now we need to work out the different state sets for @@ -551,7 +545,7 @@ class EventsStore(SQLBaseStore): if ev_id in events_to_insert } - defer.returnValue((to_delete, to_insert, current_state)) + defer.returnValue((to_delete, to_insert)) @defer.inlineCallbacks def get_event(self, event_id, check_redacted=True, @@ -704,7 +698,7 @@ class EventsStore(SQLBaseStore): def _update_current_state_txn(self, txn, state_delta_by_room): for room_id, current_state_tuple in state_delta_by_room.iteritems(): - to_delete, to_insert, _ = current_state_tuple + to_delete, to_insert = current_state_tuple txn.executemany( "DELETE FROM current_state_events WHERE event_id = ?", [(ev_id,) for ev_id in to_delete.itervalues()], diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 7ad2198d96..ad3c9b06d9 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -18,6 +18,7 @@ from twisted.internet import defer from collections import namedtuple from ._base import SQLBaseStore +from synapse.util.caches import intern_string from synapse.util.caches.descriptors import cached, cachedInlineCallbacks from synapse.util.stringutils import to_ascii @@ -147,7 +148,7 @@ class RoomMemberStore(SQLBaseStore): hosts = frozenset(get_domain_from_id(user_id) for user_id in user_ids) defer.returnValue(hosts) - @cached(max_entries=500000, iterable=True) + @cached(max_entries=100000, iterable=True) def get_users_in_room(self, room_id): def f(txn): sql = ( @@ -160,7 +161,7 @@ class RoomMemberStore(SQLBaseStore): ) txn.execute(sql, (room_id, Membership.JOIN,)) - return [r[0] for r in txn] + return [to_ascii(r[0]) for r in txn] return self.runInteraction("get_users_in_room", f) @cached() @@ -417,25 +418,51 @@ class RoomMemberStore(SQLBaseStore): if key[0] == EventTypes.Member ] - rows = yield self._simple_select_many_batch( - table="room_memberships", - column="event_id", - iterable=member_event_ids, - retcols=['user_id', 'display_name', 'avatar_url'], - keyvalues={ - "membership": Membership.JOIN, - }, - batch_size=500, - desc="_get_joined_users_from_context", + # We check if we have any of the member event ids in the event cache + # before we ask the DB + + event_map = self._get_events_from_cache( + member_event_ids, + allow_rejected=False, ) - users_in_room = { - to_ascii(row["user_id"]): ProfileInfo( - avatar_url=to_ascii(row["avatar_url"]), - display_name=to_ascii(row["display_name"]), + missing_member_event_ids = [] + users_in_room = {} + for event_id in member_event_ids: + ev_entry = event_map.get(event_id) + if ev_entry: + if ev_entry.event.membership == Membership.JOIN: + users_in_room[to_ascii(ev_entry.event.state_key)] = ProfileInfo( + display_name=to_ascii( + ev_entry.event.content.get("displayname", None) + ), + avatar_url=to_ascii( + ev_entry.event.content.get("avatar_url", None) + ), + ) + else: + missing_member_event_ids.append(event_id) + + if missing_member_event_ids: + rows = yield self._simple_select_many_batch( + table="room_memberships", + column="event_id", + iterable=missing_member_event_ids, + retcols=('user_id', 'display_name', 'avatar_url',), + keyvalues={ + "membership": Membership.JOIN, + }, + batch_size=500, + desc="_get_joined_users_from_context", ) - for row in rows - } + + users_in_room.update({ + to_ascii(row["user_id"]): ProfileInfo( + avatar_url=to_ascii(row["avatar_url"]), + display_name=to_ascii(row["display_name"]), + ) + for row in rows + }) if event is not None and event.type == EventTypes.Member: if event.membership == Membership.JOIN: @@ -482,6 +509,44 @@ class RoomMemberStore(SQLBaseStore): defer.returnValue(False) + def get_joined_hosts(self, room_id, state_group, state_ids): + if not state_group: + # If state_group is None it means it has yet to be assigned a + # state group, i.e. we need to make sure that calls with a state_group + # of None don't hit previous cached calls with a None state_group. + # To do this we set the state_group to a new object as object() != object() + state_group = object() + + return self._get_joined_hosts( + room_id, state_group, state_ids + ) + + @cachedInlineCallbacks(num_args=2, max_entries=10000, iterable=True) + def _get_joined_hosts(self, room_id, state_group, current_state_ids): + # We don't use `state_group`, its there so that we can cache based + # on it. However, its important that its never None, since two current_state's + # with a state_group of None are likely to be different. + # See bulk_get_push_rules_for_room for how we work around this. + assert state_group is not None + + joined_hosts = set() + for (etype, state_key), event_id in current_state_ids.items(): + if etype == EventTypes.Member: + try: + host = get_domain_from_id(state_key) + except: + logger.warn("state_key not user_id: %s", state_key) + continue + + if host in joined_hosts: + continue + + event = yield self.get_event(event_id, allow_none=True) + if event and event.content["membership"] == Membership.JOIN: + joined_hosts.add(intern_string(host)) + + defer.returnValue(joined_hosts) + @defer.inlineCallbacks def _background_add_membership_profile(self, progress, batch_size): target_min_stream_id = progress.get( diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 03981f5d2b..a16afa8df5 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -227,18 +227,6 @@ class StateStore(SQLBaseStore): ], ) - # Prefill the state group cache with this group. - # It's fine to use the sequence like this as the state group map - # is immutable. (If the map wasn't immutable then this prefill could - # race with another update) - txn.call_after( - self._state_group_cache.update, - self._state_group_cache.sequence, - key=context.state_group, - value=context.current_state_ids, - full=True, - ) - self._simple_insert_many_txn( txn, table="event_to_state_groups", diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 807e147657..aa182eeac7 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -18,6 +18,7 @@ from synapse.util.async import ObservableDeferred from synapse.util import unwrapFirstError, logcontext from synapse.util.caches.lrucache import LruCache from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry +from synapse.util.stringutils import to_ascii from . import register_cache @@ -163,10 +164,6 @@ class Cache(object): def invalidate(self, key): self.check_thread() - if not isinstance(key, tuple): - raise TypeError( - "The cache key must be a tuple not %r" % (type(key),) - ) # Increment the sequence number so that any SELECT statements that # raced with the INSERT don't update the cache (SYN-369) @@ -312,7 +309,7 @@ class CacheDescriptor(_CacheDescriptorBase): iterable=self.iterable, ) - def get_cache_key(args, kwargs): + def get_cache_key_gen(args, kwargs): """Given some args/kwargs return a generator that resolves into the cache_key. @@ -330,13 +327,29 @@ class CacheDescriptor(_CacheDescriptorBase): else: yield self.arg_defaults[nm] + # By default our cache key is a tuple, but if there is only one item + # then don't bother wrapping in a tuple. This is to save memory. + if self.num_args == 1: + nm = self.arg_names[0] + + def get_cache_key(args, kwargs): + if nm in kwargs: + return kwargs[nm] + elif len(args): + return args[0] + else: + return self.arg_defaults[nm] + else: + def get_cache_key(args, kwargs): + return tuple(get_cache_key_gen(args, kwargs)) + @functools.wraps(self.orig) def wrapped(*args, **kwargs): # If we're passed a cache_context then we'll want to call its invalidate() # whenever we are invalidated invalidate_callback = kwargs.pop("on_invalidate", None) - cache_key = tuple(get_cache_key(args, kwargs)) + cache_key = get_cache_key(args, kwargs) # Add our own `cache_context` to argument list if the wrapped function # has asked for one @@ -363,6 +376,11 @@ class CacheDescriptor(_CacheDescriptorBase): ret.addErrback(onErr) + # If our cache_key is a string, try to convert to ascii to save + # a bit of space in large caches + if isinstance(cache_key, basestring): + cache_key = to_ascii(cache_key) + result_d = ObservableDeferred(ret, consumeErrors=True) cache.set(cache_key, result_d, callback=invalidate_callback) observer = result_d.observe() @@ -372,10 +390,16 @@ class CacheDescriptor(_CacheDescriptorBase): else: return observer - wrapped.invalidate = cache.invalidate + if self.num_args == 1: + wrapped.invalidate = lambda key: cache.invalidate(key[0]) + wrapped.prefill = lambda key, val: cache.prefill(key[0], val) + else: + wrapped.invalidate = cache.invalidate + wrapped.invalidate_all = cache.invalidate_all + wrapped.invalidate_many = cache.invalidate_many + wrapped.prefill = cache.prefill + wrapped.invalidate_all = cache.invalidate_all - wrapped.invalidate_many = cache.invalidate_many - wrapped.prefill = cache.prefill wrapped.cache = cache obj.__dict__[self.orig.__name__] = wrapped |