From 8462c26485fb4f19fc52accc05870c0ea4c8eb6a Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 20 Jul 2018 12:43:23 +0100 Subject: Improvements to the Limiter * give them names, to improve logging * use a deque rather than a list for efficiency --- synapse/handlers/message.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'synapse/handlers') diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index a39b852ceb..8c12c6990f 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -427,7 +427,7 @@ class EventCreationHandler(object): # We arbitrarily limit concurrent event creation for a room to 5. # This is to stop us from diverging history *too* much. - self.limiter = Limiter(max_count=5) + self.limiter = Limiter(max_count=5, name="room_event_creation_limit") self.action_generator = hs.get_action_generator() -- cgit 1.5.1 From 7c712f95bbc7f405355d5714c92d65551f64fec2 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 20 Jul 2018 13:11:43 +0100 Subject: Combine Limiter and Linearizer Linearizer was effectively a Limiter with max_count=1, so rather than maintaining two sets of code, let's combine them. --- synapse/handlers/message.py | 4 +- synapse/util/async.py | 99 +++++-------------------------------------- tests/util/test_limiter.py | 70 ------------------------------ tests/util/test_linearizer.py | 47 ++++++++++++++++++++ 4 files changed, 59 insertions(+), 161 deletions(-) delete mode 100644 tests/util/test_limiter.py (limited to 'synapse/handlers') diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 8c12c6990f..abc07ea87c 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -33,7 +33,7 @@ from synapse.events.utils import serialize_event from synapse.events.validator import EventValidator from synapse.replication.http.send_event import send_event_to_master from synapse.types import RoomAlias, RoomStreamToken, UserID -from synapse.util.async import Limiter, ReadWriteLock +from synapse.util.async import Linearizer, ReadWriteLock from synapse.util.frozenutils import frozendict_json_encoder from synapse.util.logcontext import run_in_background from synapse.util.metrics import measure_func @@ -427,7 +427,7 @@ class EventCreationHandler(object): # We arbitrarily limit concurrent event creation for a room to 5. # This is to stop us from diverging history *too* much. - self.limiter = Limiter(max_count=5, name="room_event_creation_limit") + self.limiter = Linearizer(max_count=5, name="room_event_creation_limit") self.action_generator = hs.get_action_generator() diff --git a/synapse/util/async.py b/synapse/util/async.py index 22071ddef7..5a50d9700f 100644 --- a/synapse/util/async.py +++ b/synapse/util/async.py @@ -157,91 +157,8 @@ def concurrently_execute(func, args, limit): class Linearizer(object): - """Linearizes access to resources based on a key. Useful to ensure only one - thing is happening at a time on a given resource. - - Example: - - with (yield linearizer.queue("test_key")): - # do some work. - - """ - def __init__(self, name=None, clock=None): - if name is None: - self.name = id(self) - else: - self.name = name - self.key_to_defer = {} - - if not clock: - from twisted.internet import reactor - clock = Clock(reactor) - self._clock = clock - - @defer.inlineCallbacks - def queue(self, key): - # If there is already a deferred in the queue, we pull it out so that - # we can wait on it later. - # Then we replace it with a deferred that we resolve *after* the - # context manager has exited. - # We only return the context manager after the previous deferred has - # resolved. - # This all has the net effect of creating a chain of deferreds that - # wait for the previous deferred before starting their work. - current_defer = self.key_to_defer.get(key) - - new_defer = defer.Deferred() - self.key_to_defer[key] = new_defer - - if current_defer: - logger.info( - "Waiting to acquire linearizer lock %r for key %r", self.name, key - ) - try: - with PreserveLoggingContext(): - yield current_defer - except Exception: - logger.exception("Unexpected exception in Linearizer") - - logger.info("Acquired linearizer lock %r for key %r", self.name, - key) - - # if the code holding the lock completes synchronously, then it - # will recursively run the next claimant on the list. That can - # relatively rapidly lead to stack exhaustion. This is essentially - # the same problem as http://twistedmatrix.com/trac/ticket/9304. - # - # In order to break the cycle, we add a cheeky sleep(0) here to - # ensure that we fall back to the reactor between each iteration. - # - # (There's no particular need for it to happen before we return - # the context manager, but it needs to happen while we hold the - # lock, and the context manager's exit code must be synchronous, - # so actually this is the only sensible place. - yield self._clock.sleep(0) - - else: - logger.info("Acquired uncontended linearizer lock %r for key %r", - self.name, key) - - @contextmanager - def _ctx_manager(): - try: - yield - finally: - logger.info("Releasing linearizer lock %r for key %r", self.name, key) - with PreserveLoggingContext(): - new_defer.callback(None) - current_d = self.key_to_defer.get(key) - if current_d is new_defer: - self.key_to_defer.pop(key, None) - - defer.returnValue(_ctx_manager()) - - -class Limiter(object): """Limits concurrent access to resources based on a key. Useful to ensure - only a few thing happen at a time on a given resource. + only a few things happen at a time on a given resource. Example: @@ -249,7 +166,7 @@ class Limiter(object): # do some work. """ - def __init__(self, max_count=1, name=None, clock=None): + def __init__(self, name=None, max_count=1, clock=None): """ Args: max_count(int): The maximum number of concurrent accesses @@ -283,10 +200,12 @@ class Limiter(object): new_defer = defer.Deferred() entry[1].append(new_defer) - logger.info("Waiting to acquire limiter lock %r for key %r", self.name, key) + logger.info( + "Waiting to acquire linearizer lock %r for key %r", self.name, key, + ) yield make_deferred_yieldable(new_defer) - logger.info("Acquired limiter lock %r for key %r", self.name, key) + logger.info("Acquired linearizer lock %r for key %r", self.name, key) entry[0] += 1 # if the code holding the lock completes synchronously, then it @@ -302,7 +221,9 @@ class Limiter(object): yield self._clock.sleep(0) else: - logger.info("Acquired uncontended limiter lock %r for key %r", self.name, key) + logger.info( + "Acquired uncontended linearizer lock %r for key %r", self.name, key, + ) entry[0] += 1 @contextmanager @@ -310,7 +231,7 @@ class Limiter(object): try: yield finally: - logger.info("Releasing limiter lock %r for key %r", self.name, key) + logger.info("Releasing linearizer lock %r for key %r", self.name, key) # We've finished executing so check if there are any things # blocked waiting to execute and start one of them diff --git a/tests/util/test_limiter.py b/tests/util/test_limiter.py deleted file mode 100644 index f7b942f5c1..0000000000 --- a/tests/util/test_limiter.py +++ /dev/null @@ -1,70 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from twisted.internet import defer - -from synapse.util.async import Limiter - -from tests import unittest - - -class LimiterTestCase(unittest.TestCase): - - @defer.inlineCallbacks - def test_limiter(self): - limiter = Limiter(3) - - key = object() - - d1 = limiter.queue(key) - cm1 = yield d1 - - d2 = limiter.queue(key) - cm2 = yield d2 - - d3 = limiter.queue(key) - cm3 = yield d3 - - d4 = limiter.queue(key) - self.assertFalse(d4.called) - - d5 = limiter.queue(key) - self.assertFalse(d5.called) - - with cm1: - self.assertFalse(d4.called) - self.assertFalse(d5.called) - - cm4 = yield d4 - self.assertFalse(d5.called) - - with cm3: - self.assertFalse(d5.called) - - cm5 = yield d5 - - with cm2: - pass - - with cm4: - pass - - with cm5: - pass - - d6 = limiter.queue(key) - with (yield d6): - pass diff --git a/tests/util/test_linearizer.py b/tests/util/test_linearizer.py index c95907b32c..c9563124f9 100644 --- a/tests/util/test_linearizer.py +++ b/tests/util/test_linearizer.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -65,3 +66,49 @@ class LinearizerTestCase(unittest.TestCase): func(i) return func(1000) + + @defer.inlineCallbacks + def test_multiple_entries(self): + limiter = Linearizer(max_count=3) + + key = object() + + d1 = limiter.queue(key) + cm1 = yield d1 + + d2 = limiter.queue(key) + cm2 = yield d2 + + d3 = limiter.queue(key) + cm3 = yield d3 + + d4 = limiter.queue(key) + self.assertFalse(d4.called) + + d5 = limiter.queue(key) + self.assertFalse(d5.called) + + with cm1: + self.assertFalse(d4.called) + self.assertFalse(d5.called) + + cm4 = yield d4 + self.assertFalse(d5.called) + + with cm3: + self.assertFalse(d5.called) + + cm5 = yield d5 + + with cm2: + pass + + with cm4: + pass + + with cm5: + pass + + d6 = limiter.queue(key) + with (yield d6): + pass -- cgit 1.5.1 From 3132b89f12f0386558045683ad198f090b0e2c90 Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Sat, 21 Jul 2018 15:47:18 +1000 Subject: Make the rest of the .iterwhatever go away (#3562) --- changelog.d/3562.misc | 0 synapse/app/homeserver.py | 6 ++++-- synapse/app/synctl.py | 4 +++- synapse/events/snapshot.py | 4 +++- synapse/handlers/federation.py | 18 +++++++++--------- synapse/state.py | 6 +++--- synapse/visibility.py | 19 ++++++++++--------- tests/test_federation.py | 3 +-- 8 files changed, 33 insertions(+), 27 deletions(-) create mode 100644 changelog.d/3562.misc (limited to 'synapse/handlers') diff --git a/changelog.d/3562.misc b/changelog.d/3562.misc new file mode 100644 index 0000000000..e69de29bb2 diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 14e6dca522..2ad1beb8d8 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -18,6 +18,8 @@ import logging import os import sys +from six import iteritems + from twisted.application import service from twisted.internet import defer, reactor from twisted.web.resource import EncodingResourceWrapper, NoResource @@ -442,7 +444,7 @@ def run(hs): stats["total_nonbridged_users"] = total_nonbridged_users daily_user_type_results = yield hs.get_datastore().count_daily_user_type() - for name, count in daily_user_type_results.iteritems(): + for name, count in iteritems(daily_user_type_results): stats["daily_user_type_" + name] = count room_count = yield hs.get_datastore().get_room_count() @@ -453,7 +455,7 @@ def run(hs): stats["daily_messages"] = yield hs.get_datastore().count_daily_messages() r30_results = yield hs.get_datastore().count_r30_users() - for name, count in r30_results.iteritems(): + for name, count in iteritems(r30_results): stats["r30_users_" + name] = count daily_sent_messages = yield hs.get_datastore().count_daily_sent_messages() diff --git a/synapse/app/synctl.py b/synapse/app/synctl.py index 68acc15a9a..d658f967ba 100755 --- a/synapse/app/synctl.py +++ b/synapse/app/synctl.py @@ -25,6 +25,8 @@ import subprocess import sys import time +from six import iteritems + import yaml SYNAPSE = [sys.executable, "-B", "-m", "synapse.app.homeserver"] @@ -173,7 +175,7 @@ def main(): os.environ["SYNAPSE_CACHE_FACTOR"] = str(cache_factor) cache_factors = config.get("synctl_cache_factors", {}) - for cache_name, factor in cache_factors.iteritems(): + for cache_name, factor in iteritems(cache_factors): os.environ["SYNAPSE_CACHE_FACTOR_" + cache_name.upper()] = str(factor) worker_configfiles = [] diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index bcd9bb5946..f83a1581a6 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from six import iteritems + from frozendict import frozendict from twisted.internet import defer @@ -159,7 +161,7 @@ def _encode_state_dict(state_dict): return [ (etype, state_key, v) - for (etype, state_key), v in state_dict.iteritems() + for (etype, state_key), v in iteritems(state_dict) ] diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 65f6041b10..a6d391c4e8 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -21,8 +21,8 @@ import logging import sys import six -from six import iteritems -from six.moves import http_client +from six import iteritems, itervalues +from six.moves import http_client, zip from signedjson.key import decode_verify_key_bytes from signedjson.sign import verify_signed_json @@ -731,7 +731,7 @@ class FederationHandler(BaseHandler): """ joined_users = [ (state_key, int(event.depth)) - for (e_type, state_key), event in state.iteritems() + for (e_type, state_key), event in iteritems(state) if e_type == EventTypes.Member and event.membership == Membership.JOIN ] @@ -748,7 +748,7 @@ class FederationHandler(BaseHandler): except Exception: pass - return sorted(joined_domains.iteritems(), key=lambda d: d[1]) + return sorted(joined_domains.items(), key=lambda d: d[1]) curr_domains = get_domains_from_state(curr_state) @@ -811,7 +811,7 @@ class FederationHandler(BaseHandler): tried_domains = set(likely_domains) tried_domains.add(self.server_name) - event_ids = list(extremities.iterkeys()) + event_ids = list(extremities.keys()) logger.debug("calling resolve_state_groups in _maybe_backfill") resolve = logcontext.preserve_fn( @@ -827,15 +827,15 @@ class FederationHandler(BaseHandler): states = dict(zip(event_ids, [s.state for s in states])) state_map = yield self.store.get_events( - [e_id for ids in states.itervalues() for e_id in ids.itervalues()], + [e_id for ids in itervalues(states) for e_id in itervalues(ids)], get_prev_content=False ) states = { key: { k: state_map[e_id] - for k, e_id in state_dict.iteritems() + for k, e_id in iteritems(state_dict) if e_id in state_map - } for key, state_dict in states.iteritems() + } for key, state_dict in iteritems(states) } for e_id, _ in sorted_extremeties_tuple: @@ -1515,7 +1515,7 @@ class FederationHandler(BaseHandler): yield self.store.persist_events( [ (ev_info["event"], context) - for ev_info, context in itertools.izip(event_infos, contexts) + for ev_info, context in zip(event_infos, contexts) ], backfilled=backfilled, ) diff --git a/synapse/state.py b/synapse/state.py index 15a593d41c..504caae2f7 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -18,7 +18,7 @@ import hashlib import logging from collections import namedtuple -from six import iteritems, itervalues +from six import iteritems, iterkeys, itervalues from frozendict import frozendict @@ -647,7 +647,7 @@ def resolve_events_with_factory(state_sets, event_map, state_map_factory): for event_id in event_ids ) if event_map is not None: - needed_events -= set(event_map.iterkeys()) + needed_events -= set(iterkeys(event_map)) logger.info("Asking for %d conflicted events", len(needed_events)) @@ -668,7 +668,7 @@ def resolve_events_with_factory(state_sets, event_map, state_map_factory): new_needed_events = set(itervalues(auth_events)) new_needed_events -= needed_events if event_map is not None: - new_needed_events -= set(event_map.iterkeys()) + new_needed_events -= set(iterkeys(event_map)) logger.info("Asking for %d auth events", len(new_needed_events)) diff --git a/synapse/visibility.py b/synapse/visibility.py index 9b97ea2b83..ba0499a022 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -12,11 +12,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import itertools + import logging import operator -import six +from six import iteritems, itervalues +from six.moves import map from twisted.internet import defer @@ -221,7 +222,7 @@ def filter_events_for_client(store, user_id, events, is_peeking=False, return event # check each event: gives an iterable[None|EventBase] - filtered_events = itertools.imap(allowed, events) + filtered_events = map(allowed, events) # remove the None entries filtered_events = filter(operator.truth, filtered_events) @@ -261,7 +262,7 @@ def filter_events_for_server(store, server_name, events): # membership states for the requesting server to determine # if the server is either in the room or has been invited # into the room. - for ev in state.itervalues(): + for ev in itervalues(state): if ev.type != EventTypes.Member: continue try: @@ -295,7 +296,7 @@ def filter_events_for_server(store, server_name, events): ) visibility_ids = set() - for sids in event_to_state_ids.itervalues(): + for sids in itervalues(event_to_state_ids): hist = sids.get((EventTypes.RoomHistoryVisibility, "")) if hist: visibility_ids.add(hist) @@ -308,7 +309,7 @@ def filter_events_for_server(store, server_name, events): event_map = yield store.get_events(visibility_ids) all_open = all( e.content.get("history_visibility") in (None, "shared", "world_readable") - for e in event_map.itervalues() + for e in itervalues(event_map) ) if all_open: @@ -346,7 +347,7 @@ def filter_events_for_server(store, server_name, events): # state_key_to_event_id_set = { e - for key_to_eid in six.itervalues(event_to_state_ids) + for key_to_eid in itervalues(event_to_state_ids) for e in key_to_eid.items() } @@ -369,10 +370,10 @@ def filter_events_for_server(store, server_name, events): event_to_state = { e_id: { key: event_map[inner_e_id] - for key, inner_e_id in key_to_eid.iteritems() + for key, inner_e_id in iteritems(key_to_eid) if inner_e_id in event_map } - for e_id, key_to_eid in event_to_state_ids.iteritems() + for e_id, key_to_eid in iteritems(event_to_state_ids) } defer.returnValue([ diff --git a/tests/test_federation.py b/tests/test_federation.py index 159a136971..f40ff29b52 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -137,7 +137,6 @@ class MessageAcceptTests(unittest.TestCase): ) self.assertEqual(self.successResultOf(extrem)[0], "$join:test.serv") - @unittest.DEBUG def test_cant_hide_past_history(self): """ If you send a message, you must be able to provide the direct @@ -178,7 +177,7 @@ class MessageAcceptTests(unittest.TestCase): for x, y in d.items() if x == ("m.room.member", "@us:test") ], - "auth_chain_ids": d.values(), + "auth_chain_ids": list(d.values()), } ) -- cgit 1.5.1 From e42510ba635b3e4d83215e4f5634ca51411996e0 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 23 Jul 2018 13:00:22 +0100 Subject: Use new getters --- synapse/api/auth.py | 6 ++++-- synapse/handlers/_base.py | 3 ++- synapse/handlers/federation.py | 23 ++++++++++++++++------- synapse/handlers/message.py | 26 ++++++++++++++++---------- synapse/handlers/room_member.py | 9 ++++++--- synapse/push/bulk_push_rule_evaluator.py | 7 ++++--- synapse/storage/events.py | 2 +- synapse/storage/push_rule.py | 7 +++++-- synapse/storage/roommember.py | 7 +++++-- 9 files changed, 59 insertions(+), 31 deletions(-) (limited to 'synapse/handlers') diff --git a/synapse/api/auth.py b/synapse/api/auth.py index bc629832d9..535bdb449d 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -65,8 +65,9 @@ class Auth(object): @defer.inlineCallbacks def check_from_context(self, event, context, do_sig_check=True): + prev_state_ids = yield context.get_prev_state_ids(self.store) auth_events_ids = yield self.compute_auth_events( - event, context.prev_state_ids, for_verification=True, + event, prev_state_ids, for_verification=True, ) auth_events = yield self.store.get_events(auth_events_ids) auth_events = { @@ -544,7 +545,8 @@ class Auth(object): @defer.inlineCallbacks def add_auth_events(self, builder, context): - auth_ids = yield self.compute_auth_events(builder, context.prev_state_ids) + prev_state_ids = yield context.get_prev_state_ids(self.store) + auth_ids = yield self.compute_auth_events(builder, prev_state_ids) auth_events_entries = yield self.store.add_event_hashes( auth_ids diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index b6a8b3aa3b..704181d2d3 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -112,8 +112,9 @@ class BaseHandler(object): guest_access = event.content.get("guest_access", "forbidden") if guest_access != "can_join": if context: + current_state_ids = yield context.get_current_state_ids(self.store) current_state = yield self.store.get_events( - list(context.current_state_ids.values()) + list(current_state_ids.values()) ) else: current_state = yield self.state_handler.get_current_state( diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index a6d391c4e8..98dd4a7fd1 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -486,7 +486,10 @@ class FederationHandler(BaseHandler): # joined the room. Don't bother if the user is just # changing their profile info. newly_joined = True - prev_state_id = context.prev_state_ids.get( + + prev_state_ids = yield context.get_prev_state_ids(self.store) + + prev_state_id = prev_state_ids.get( (event.type, event.state_key) ) if prev_state_id: @@ -1106,10 +1109,12 @@ class FederationHandler(BaseHandler): user = UserID.from_string(event.state_key) yield user_joined_room(self.distributor, user, event.room_id) - state_ids = list(context.prev_state_ids.values()) + prev_state_ids = yield context.get_prev_state_ids(self.store) + + state_ids = list(prev_state_ids.values()) auth_chain = yield self.store.get_auth_chain(state_ids) - state = yield self.store.get_events(list(context.prev_state_ids.values())) + state = yield self.store.get_events(list(prev_state_ids.values())) defer.returnValue({ "state": list(state.values()), @@ -1635,8 +1640,9 @@ class FederationHandler(BaseHandler): ) if not auth_events: + prev_state_ids = yield context.get_prev_state_ids(self.store) auth_events_ids = yield self.auth.compute_auth_events( - event, context.prev_state_ids, for_verification=True, + event, prev_state_ids, for_verification=True, ) auth_events = yield self.store.get_events(auth_events_ids) auth_events = { @@ -1876,9 +1882,10 @@ class FederationHandler(BaseHandler): break if do_resolution: + prev_state_ids = yield context.get_prev_state_ids(self.store) # 1. Get what we think is the auth chain. auth_ids = yield self.auth.compute_auth_events( - event, context.prev_state_ids + event, prev_state_ids ) local_auth_chain = yield self.store.get_auth_chain( auth_ids, include_given=True @@ -2222,7 +2229,8 @@ class FederationHandler(BaseHandler): event.content["third_party_invite"]["signed"]["token"] ) original_invite = None - original_invite_id = context.prev_state_ids.get(key) + prev_state_ids = yield context.get_prev_state_ids(self.store) + original_invite_id = prev_state_ids.get(key) if original_invite_id: original_invite = yield self.store.get_event( original_invite_id, allow_none=True @@ -2264,7 +2272,8 @@ class FederationHandler(BaseHandler): signed = event.content["third_party_invite"]["signed"] token = signed["token"] - invite_event_id = context.prev_state_ids.get( + prev_state_ids = yield context.get_prev_state_ids(self.store) + invite_event_id = prev_state_ids.get( (EventTypes.ThirdPartyInvite, token,) ) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index abc07ea87c..c4bcd9018b 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -630,7 +630,8 @@ class EventCreationHandler(object): If so, returns the version of the event in context. Otherwise, returns None. """ - prev_event_id = context.prev_state_ids.get((event.type, event.state_key)) + prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_event_id = prev_state_ids.get((event.type, event.state_key)) prev_event = yield self.store.get_event(prev_event_id, allow_none=True) if not prev_event: return @@ -752,8 +753,8 @@ class EventCreationHandler(object): event = builder.build() logger.debug( - "Created event %s with state: %s", - event.event_id, context.prev_state_ids, + "Created event %s", + event.event_id, ) defer.returnValue( @@ -884,9 +885,11 @@ class EventCreationHandler(object): e.sender == event.sender ) + current_state_ids = yield context.get_current_state_ids(self.store) + state_to_include_ids = [ e_id - for k, e_id in iteritems(context.current_state_ids) + for k, e_id in iteritems(current_state_ids) if k[0] in self.hs.config.room_invite_state_types or k == (EventTypes.Member, event.sender) ] @@ -922,8 +925,9 @@ class EventCreationHandler(object): ) if event.type == EventTypes.Redaction: + prev_state_ids = yield context.get_prev_state_ids(self.store) auth_events_ids = yield self.auth.compute_auth_events( - event, context.prev_state_ids, for_verification=True, + event, prev_state_ids, for_verification=True, ) auth_events = yield self.store.get_events(auth_events_ids) auth_events = { @@ -943,11 +947,13 @@ class EventCreationHandler(object): "You don't have permission to redact events" ) - if event.type == EventTypes.Create and context.prev_state_ids: - raise AuthError( - 403, - "Changing the room create event is forbidden", - ) + if event.type == EventTypes.Create: + prev_state_ids = yield context.get_prev_state_ids(self.store) + if prev_state_ids: + raise AuthError( + 403, + "Changing the room create event is forbidden", + ) (event_stream_id, max_stream_id) = yield self.store.persist_event( event, context=context diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 00f2e279bc..a832d91809 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -201,7 +201,9 @@ class RoomMemberHandler(object): ratelimit=ratelimit, ) - prev_member_event_id = context.prev_state_ids.get( + prev_state_ids = yield context.get_prev_state_ids(self.store) + + prev_member_event_id = prev_state_ids.get( (EventTypes.Member, target.to_string()), None ) @@ -496,9 +498,10 @@ class RoomMemberHandler(object): if prev_event is not None: return + prev_state_ids = yield context.get_prev_state_ids(self.store) if event.membership == Membership.JOIN: if requester.is_guest: - guest_can_join = yield self._can_guest_join(context.prev_state_ids) + guest_can_join = yield self._can_guest_join(prev_state_ids) if not guest_can_join: # This should be an auth check, but guests are a local concept, # so don't really fit into the general auth process. @@ -517,7 +520,7 @@ class RoomMemberHandler(object): ratelimit=ratelimit, ) - prev_member_event_id = context.prev_state_ids.get( + prev_member_event_id = prev_state_ids.get( (EventTypes.Member, event.state_key), None ) diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index bb181d94ee..1d14d3639c 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -112,7 +112,8 @@ class BulkPushRuleEvaluator(object): @defer.inlineCallbacks def _get_power_levels_and_sender_level(self, event, context): - pl_event_id = context.prev_state_ids.get(POWER_KEY) + prev_state_ids = yield context.get_prev_state_ids(self.store) + pl_event_id = prev_state_ids.get(POWER_KEY) if pl_event_id: # fastpath: if there's a power level event, that's all we need, and # not having a power level event is an extreme edge case @@ -120,7 +121,7 @@ class BulkPushRuleEvaluator(object): auth_events = {POWER_KEY: pl_event} else: auth_events_ids = yield self.auth.compute_auth_events( - event, context.prev_state_ids, for_verification=False, + event, prev_state_ids, for_verification=False, ) auth_events = yield self.store.get_events(auth_events_ids) auth_events = { @@ -304,7 +305,7 @@ class RulesForRoom(object): push_rules_delta_state_cache_metric.inc_hits() else: - current_state_ids = context.current_state_ids + current_state_ids = yield context.get_current_state_ids(self.store) push_rules_delta_state_cache_metric.inc_misses() push_rules_state_size_counter.inc(len(current_state_ids)) diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 4ff0fdc4ab..bf4f3ee92a 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -549,7 +549,7 @@ class EventsStore(EventsWorkerStore): if ctx.state_group in state_groups_map: continue - state_groups_map[ctx.state_group] = ctx.current_state_ids + state_groups_map[ctx.state_group] = yield ctx.get_current_state_ids(self) # We need to map the event_ids to their state groups. First, let's # check if the event is one we're persisting, in which case we can diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py index be655d287b..af564b1b4e 100644 --- a/synapse/storage/push_rule.py +++ b/synapse/storage/push_rule.py @@ -186,6 +186,7 @@ class PushRulesWorkerStore(ApplicationServiceWorkerStore, defer.returnValue(results) + @defer.inlineCallbacks def bulk_get_push_rules_for_room(self, event, context): state_group = context.state_group if not state_group: @@ -195,9 +196,11 @@ class PushRulesWorkerStore(ApplicationServiceWorkerStore, # To do this we set the state_group to a new object as object() != object() state_group = object() - return self._bulk_get_push_rules_for_room( - event.room_id, state_group, context.current_state_ids, event=event + current_state_ids = yield context.get_current_state_ids(self) + result = yield self._bulk_get_push_rules_for_room( + event.room_id, state_group, current_state_ids, event=event ) + defer.returnValue(result) @cachedInlineCallbacks(num_args=2, cache_context=True) def _bulk_get_push_rules_for_room(self, room_id, state_group, current_state_ids, diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 02a802bed9..a27702a7a0 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -232,6 +232,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): defer.returnValue(user_who_share_room) + @defer.inlineCallbacks def get_joined_users_from_context(self, event, context): state_group = context.state_group if not state_group: @@ -241,11 +242,13 @@ class RoomMemberWorkerStore(EventsWorkerStore): # To do this we set the state_group to a new object as object() != object() state_group = object() - return self._get_joined_users_from_context( - event.room_id, state_group, context.current_state_ids, + current_state_ids = yield context.get_current_state_ids(self) + result = yield self._get_joined_users_from_context( + event.room_id, state_group, current_state_ids, event=event, context=context, ) + defer.returnValue(result) def get_joined_users_from_state(self, room_id, state_entry): state_group = state_entry.state_group -- cgit 1.5.1 From 027bc01a1bc254fe08140c6e91a9fb945b08486f Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 23 Jul 2018 13:02:09 +0100 Subject: Add support for updating state --- synapse/events/snapshot.py | 19 +++++++++++++++++++ synapse/handlers/federation.py | 32 +++++++++++++++++++++++--------- 2 files changed, 42 insertions(+), 9 deletions(-) (limited to 'synapse/handlers') diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index f9568638a1..b090751bf1 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -228,6 +228,25 @@ class EventContext(object): else: self._prev_state_ids = self._current_state_ids + @defer.inlineCallbacks + def update_state(self, state_group, prev_state_ids, current_state_ids, + delta_ids): + """Replace the state in the context + """ + + # We need to make sure we wait for any ongoing fetching of state + # to complete so that the updated state doesn't get clobbered + if self._fetching_state_deferred: + yield make_deferred_yieldable(self._fetching_state_deferred) + + self.state_group = state_group + self._prev_state_ids = prev_state_ids + self._current_state_ids = current_state_ids + self.delta_ids = delta_ids + + # We need to ensure that that we've marked as having fetched the state + self._fetching_state_deferred = defer.succeed(None) + def _encode_state_dict(state_dict): """Since dicts of (type, state_key) -> event_id cannot be serialized in diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 98dd4a7fd1..14654d59f1 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1975,21 +1975,35 @@ class FederationHandler(BaseHandler): k: a.event_id for k, a in iteritems(auth_events) if k != event_key } - context.current_state_ids = dict(context.current_state_ids) - context.current_state_ids.update(state_updates) + current_state_ids = yield context.get_current_state_ids(self.store) + current_state_ids = dict(current_state_ids) + + current_state_ids.update(state_updates) + if context.delta_ids is not None: - context.delta_ids = dict(context.delta_ids) - context.delta_ids.update(state_updates) - context.prev_state_ids = dict(context.prev_state_ids) - context.prev_state_ids.update({ + delta_ids = dict(context.delta_ids) + delta_ids.update(state_updates) + + prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = dict(prev_state_ids) + + prev_state_ids.update({ k: a.event_id for k, a in iteritems(auth_events) }) - context.state_group = yield self.store.store_state_group( + + state_group = yield self.store.store_state_group( event.event_id, event.room_id, prev_group=context.prev_group, - delta_ids=context.delta_ids, - current_state_ids=context.current_state_ids, + delta_ids=delta_ids, + current_state_ids=current_state_ids, + ) + + yield context.update_state( + state_group=state_group, + current_state_ids=current_state_ids, + prev_state_ids=prev_state_ids, + delta_ids=delta_ids, ) @defer.inlineCallbacks -- cgit 1.5.1