diff options
Diffstat (limited to 'synapse')
-rw-r--r-- | synapse/handlers/_base.py | 51 | ||||
-rw-r--r-- | synapse/handlers/room_member.py | 36 | ||||
-rw-r--r-- | synapse/replication/slave/storage/events.py | 9 | ||||
-rw-r--r-- | synapse/storage/util/id_generators.py | 2 | ||||
-rw-r--r-- | synapse/util/async.py | 50 | ||||
-rw-r--r-- | synapse/util/caches/response_cache.py | 2 |
6 files changed, 134 insertions, 16 deletions
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index c77afe7f51..88d8b9ba54 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -37,6 +37,15 @@ VISIBILITY_PRIORITY = ( ) +MEMBERSHIP_PRIORITY = ( + Membership.JOIN, + Membership.INVITE, + Membership.KNOCK, + Membership.LEAVE, + Membership.BAN, +) + + class BaseHandler(object): """ Common base class for the event handlers. @@ -72,6 +81,7 @@ class BaseHandler(object): * the user is not currently a member of the room, and: * the user has not been a member of the room since the given events + events ([synapse.events.EventBase]): list of events to filter """ forgotten = yield defer.gatherResults([ self.store.who_forgot_in_room( @@ -86,6 +96,12 @@ class BaseHandler(object): ) def allowed(event, user_id, is_peeking): + """ + Args: + event (synapse.events.EventBase): event to check + user_id (str) + is_peeking (bool) + """ state = event_id_to_state[event.event_id] # get the room_visibility at the time of the event. @@ -117,17 +133,30 @@ class BaseHandler(object): if old_priority < new_priority: visibility = prev_visibility - # get the user's membership at the time of the event. (or rather, - # just *after* the event. Which means that people can see their - # own join events, but not (currently) their own leave events.) - membership_event = state.get((EventTypes.Member, user_id), None) - if membership_event: - if membership_event.event_id in event_id_forgotten: - membership = None - else: - membership = membership_event.membership - else: - membership = None + # likewise, if the event is the user's own membership event, use + # the 'most joined' membership + membership = None + if event.type == EventTypes.Member and event.state_key == user_id: + membership = event.content.get("membership", None) + if membership not in MEMBERSHIP_PRIORITY: + membership = "leave" + + prev_content = event.unsigned.get("prev_content", {}) + prev_membership = prev_content.get("membership", None) + if prev_membership not in MEMBERSHIP_PRIORITY: + prev_membership = "leave" + + new_priority = MEMBERSHIP_PRIORITY.index(membership) + old_priority = MEMBERSHIP_PRIORITY.index(prev_membership) + if old_priority < new_priority: + membership = prev_membership + + # otherwise, get the user's membership at the time of the event. + if membership is None: + membership_event = state.get((EventTypes.Member, user_id), None) + if membership_event: + if membership_event.event_id not in event_id_forgotten: + membership = membership_event.membership # if the user was a member of the room at the time of the event, # they can see it. diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index fe2315df8f..b6ef3c91af 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -24,6 +24,7 @@ from synapse.api.constants import ( ) from synapse.api.errors import AuthError, SynapseError, Codes from synapse.util.logcontext import preserve_context_over_fn +from synapse.util.async import Linearizer from signedjson.sign import verify_signed_json from signedjson.key import decode_verify_key_bytes @@ -60,6 +61,8 @@ class RoomMemberHandler(BaseHandler): def __init__(self, hs): super(RoomMemberHandler, self).__init__(hs) + self.member_linearizer = Linearizer() + self.clock = hs.get_clock() self.distributor = hs.get_distributor() @@ -183,6 +186,34 @@ class RoomMemberHandler(BaseHandler): third_party_signed=None, ratelimit=True, ): + key = (target, room_id,) + + with (yield self.member_linearizer.queue(key)): + result = yield self._update_membership( + requester, + target, + room_id, + action, + txn_id=txn_id, + remote_room_hosts=remote_room_hosts, + third_party_signed=third_party_signed, + ratelimit=ratelimit, + ) + + defer.returnValue(result) + + @defer.inlineCallbacks + def _update_membership( + self, + requester, + target, + room_id, + action, + txn_id=None, + remote_room_hosts=None, + third_party_signed=None, + ratelimit=True, + ): effective_membership_state = action if action in ["kick", "unban"]: effective_membership_state = "leave" @@ -233,6 +264,11 @@ class RoomMemberHandler(BaseHandler): remote_room_hosts.append(inviter.domain) content = {"membership": Membership.JOIN} + + profile = self.hs.get_handlers().profile_handler + content["displayname"] = yield profile.get_displayname(target) + content["avatar_url"] = yield profile.get_avatar_url(target) + if requester.is_guest: content["kind"] = "guest" diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py index 680dc89536..cfc728a038 100644 --- a/synapse/replication/slave/storage/events.py +++ b/synapse/replication/slave/storage/events.py @@ -69,6 +69,7 @@ class SlavedEventStore(BaseSlavedStore): "_get_current_state_for_key" ] + get_event = DataStore.get_event.__func__ get_current_state = DataStore.get_current_state.__func__ get_current_state_for_key = DataStore.get_current_state_for_key.__func__ get_rooms_for_user_where_membership_is = ( @@ -89,8 +90,11 @@ class SlavedEventStore(BaseSlavedStore): _invalidate_get_event_cache = DataStore._invalidate_get_event_cache.__func__ _parse_events_txn = DataStore._parse_events_txn.__func__ _get_events_txn = DataStore._get_events_txn.__func__ + _enqueue_events = DataStore._enqueue_events.__func__ + _do_fetch = DataStore._do_fetch.__func__ _fetch_events_txn = DataStore._fetch_events_txn.__func__ _fetch_event_rows = DataStore._fetch_event_rows.__func__ + _get_event_from_row = DataStore._get_event_from_row.__func__ _get_event_from_row_txn = DataStore._get_event_from_row_txn.__func__ _get_rooms_for_user_where_membership_is_txn = ( DataStore._get_rooms_for_user_where_membership_is_txn.__func__ @@ -100,7 +104,7 @@ class SlavedEventStore(BaseSlavedStore): def stream_positions(self): result = super(SlavedEventStore, self).stream_positions() result["events"] = self._stream_id_gen.get_current_token() - result["backfilled"] = self._backfill_id_gen.get_current_token() + result["backfill"] = self._backfill_id_gen.get_current_token() return result def process_replication(self, result): @@ -142,7 +146,6 @@ class SlavedEventStore(BaseSlavedStore): position = row[0] internal = json.loads(row[1]) event_json = json.loads(row[2]) - event = FrozenEvent(event_json, internal_metadata_dict=internal) self._invalidate_caches_for_event( event, backfilled, reset_state=position in state_resets @@ -158,6 +161,8 @@ class SlavedEventStore(BaseSlavedStore): self._invalidate_get_event_cache(event.event_id) + self.get_latest_event_ids_in_room.invalidate((event.room_id,)) + if not backfilled: self._events_stream_cache.entity_has_changed( event.room_id, event.internal_metadata.stream_ordering diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index f69f1cdad4..46cf93ff87 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -112,7 +112,7 @@ class StreamIdGenerator(object): self._current + self._step * (n + 1), self._step ) - self._current += n + self._current += n * self._step for next_id in next_ids: self._unfinished_ids.append(next_id) diff --git a/synapse/util/async.py b/synapse/util/async.py index cd4d90f3cf..0d6f48e2d8 100644 --- a/synapse/util/async.py +++ b/synapse/util/async.py @@ -16,9 +16,13 @@ from twisted.internet import defer, reactor -from .logcontext import PreserveLoggingContext, preserve_fn +from .logcontext import ( + PreserveLoggingContext, preserve_fn, preserve_context_over_deferred, +) from synapse.util import unwrapFirstError +from contextlib import contextmanager + @defer.inlineCallbacks def sleep(seconds): @@ -137,3 +141,47 @@ def concurrently_execute(func, args, limit): preserve_fn(_concurrently_execute_inner)() for _ in xrange(limit) ], consumeErrors=True).addErrback(unwrapFirstError) + + +class Linearizer(object): + """Linearizes access to resources based on a key. Useful to ensure only one + thing is happening at a time on a given resource. + + Example: + + with (yield linearizer.queue("test_key")): + # do some work. + + """ + def __init__(self): + self.key_to_defer = {} + + @defer.inlineCallbacks + def queue(self, key): + # If there is already a deferred in the queue, we pull it out so that + # we can wait on it later. + # Then we replace it with a deferred that we resolve *after* the + # context manager has exited. + # We only return the context manager after the previous deferred has + # resolved. + # This all has the net effect of creating a chain of deferreds that + # wait for the previous deferred before starting their work. + current_defer = self.key_to_defer.get(key) + + new_defer = defer.Deferred() + self.key_to_defer[key] = new_defer + + if current_defer: + yield preserve_context_over_deferred(current_defer) + + @contextmanager + def _ctx_manager(): + try: + yield + finally: + new_defer.callback(None) + current_d = self.key_to_defer.get(key) + if current_d is new_defer: + self.key_to_defer.pop(key, None) + + defer.returnValue(_ctx_manager()) diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py index be310ba320..36686b479e 100644 --- a/synapse/util/caches/response_cache.py +++ b/synapse/util/caches/response_cache.py @@ -35,7 +35,7 @@ class ResponseCache(object): return None def set(self, key, deferred): - result = ObservableDeferred(deferred) + result = ObservableDeferred(deferred, consumeErrors=True) self.pending_result_cache[key] = result def remove(r): |