summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--synapse/handlers/room_member.py31
-rw-r--r--synapse/util/async.py42
-rw-r--r--synapse/util/caches/response_cache.py2
3 files changed, 74 insertions, 1 deletions
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index fe2315df8f..0fcc9445a8 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -24,6 +24,7 @@ from synapse.api.constants import (
 )
 from synapse.api.errors import AuthError, SynapseError, Codes
 from synapse.util.logcontext import preserve_context_over_fn
+from synapse.util.async import Linearizer
 
 from signedjson.sign import verify_signed_json
 from signedjson.key import decode_verify_key_bytes
@@ -60,6 +61,8 @@ class RoomMemberHandler(BaseHandler):
     def __init__(self, hs):
         super(RoomMemberHandler, self).__init__(hs)
 
+        self.member_linearizer = Linearizer()
+
         self.clock = hs.get_clock()
 
         self.distributor = hs.get_distributor()
@@ -183,6 +186,34 @@ class RoomMemberHandler(BaseHandler):
             third_party_signed=None,
             ratelimit=True,
     ):
+        key = (target, room_id,)
+
+        with (yield self.member_linearizer.queue(key)):
+            result = yield self._update_membership(
+                requester,
+                target,
+                room_id,
+                action,
+                txn_id=txn_id,
+                remote_room_hosts=remote_room_hosts,
+                third_party_signed=third_party_signed,
+                ratelimit=ratelimit,
+            )
+
+        defer.returnValue(result)
+
+    @defer.inlineCallbacks
+    def _update_membership(
+            self,
+            requester,
+            target,
+            room_id,
+            action,
+            txn_id=None,
+            remote_room_hosts=None,
+            third_party_signed=None,
+            ratelimit=True,
+    ):
         effective_membership_state = action
         if action in ["kick", "unban"]:
             effective_membership_state = "leave"
diff --git a/synapse/util/async.py b/synapse/util/async.py
index cd4d90f3cf..408c86be91 100644
--- a/synapse/util/async.py
+++ b/synapse/util/async.py
@@ -19,6 +19,8 @@ from twisted.internet import defer, reactor
 from .logcontext import PreserveLoggingContext, preserve_fn
 from synapse.util import unwrapFirstError
 
+from contextlib import contextmanager
+
 
 @defer.inlineCallbacks
 def sleep(seconds):
@@ -137,3 +139,43 @@ def concurrently_execute(func, args, limit):
         preserve_fn(_concurrently_execute_inner)()
         for _ in xrange(limit)
     ], consumeErrors=True).addErrback(unwrapFirstError)
+
+
+@contextmanager
+def _trigger_defer_manager(d):
+    try:
+        yield
+    finally:
+        d.callback(None)
+
+
+class Linearizer(object):
+    """Linearizes access to resources based on a key. Useful to ensure only one
+    thing is happening at a time on a given resource.
+
+    Example:
+
+        with (yield linearizer.queue("test_key")):
+            # do some work.
+
+    """
+    def __init__(self):
+        self.key_to_defer = {}
+
+    @defer.inlineCallbacks
+    def queue(self, key):
+        current_defer = self.key_to_defer.get(key)
+
+        new_defer = defer.Deferred()
+        self.key_to_defer[key] = new_defer
+
+        def remove_if_current(_):
+            d = self.key_to_defer.get(key)
+            if d is new_defer:
+                self.key_to_defer.pop(key, None)
+
+        new_defer.addBoth(remove_if_current)
+
+        yield current_defer
+
+        defer.returnValue(_trigger_defer_manager(new_defer))
diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py
index be310ba320..36686b479e 100644
--- a/synapse/util/caches/response_cache.py
+++ b/synapse/util/caches/response_cache.py
@@ -35,7 +35,7 @@ class ResponseCache(object):
             return None
 
     def set(self, key, deferred):
-        result = ObservableDeferred(deferred)
+        result = ObservableDeferred(deferred, consumeErrors=True)
         self.pending_result_cache[key] = result
 
         def remove(r):