summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/handlers/_base.py51
-rw-r--r--synapse/handlers/room_member.py36
-rw-r--r--synapse/replication/slave/storage/events.py9
-rw-r--r--synapse/storage/util/id_generators.py2
-rw-r--r--synapse/util/async.py50
-rw-r--r--synapse/util/caches/response_cache.py2
-rw-r--r--tests/replication/slave/storage/_base.py2
-rw-r--r--tests/replication/slave/storage/test_events.py142
-rw-r--r--tests/util/test_linearizer.py44
9 files changed, 319 insertions, 19 deletions
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index c77afe7f51..88d8b9ba54 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -37,6 +37,15 @@ VISIBILITY_PRIORITY = (
 )
 
 
+MEMBERSHIP_PRIORITY = (
+    Membership.JOIN,
+    Membership.INVITE,
+    Membership.KNOCK,
+    Membership.LEAVE,
+    Membership.BAN,
+)
+
+
 class BaseHandler(object):
     """
     Common base class for the event handlers.
@@ -72,6 +81,7 @@ class BaseHandler(object):
                 * the user is not currently a member of the room, and:
                 * the user has not been a member of the room since the
                 given events
+            events ([synapse.events.EventBase]): list of events to filter
         """
         forgotten = yield defer.gatherResults([
             self.store.who_forgot_in_room(
@@ -86,6 +96,12 @@ class BaseHandler(object):
         )
 
         def allowed(event, user_id, is_peeking):
+            """
+            Args:
+                event (synapse.events.EventBase): event to check
+                user_id (str)
+                is_peeking (bool)
+            """
             state = event_id_to_state[event.event_id]
 
             # get the room_visibility at the time of the event.
@@ -117,17 +133,30 @@ class BaseHandler(object):
                 if old_priority < new_priority:
                     visibility = prev_visibility
 
-            # get the user's membership at the time of the event. (or rather,
-            # just *after* the event. Which means that people can see their
-            # own join events, but not (currently) their own leave events.)
-            membership_event = state.get((EventTypes.Member, user_id), None)
-            if membership_event:
-                if membership_event.event_id in event_id_forgotten:
-                    membership = None
-                else:
-                    membership = membership_event.membership
-            else:
-                membership = None
+            # likewise, if the event is the user's own membership event, use
+            # the 'most joined' membership
+            membership = None
+            if event.type == EventTypes.Member and event.state_key == user_id:
+                membership = event.content.get("membership", None)
+                if membership not in MEMBERSHIP_PRIORITY:
+                    membership = "leave"
+
+                prev_content = event.unsigned.get("prev_content", {})
+                prev_membership = prev_content.get("membership", None)
+                if prev_membership not in MEMBERSHIP_PRIORITY:
+                    prev_membership = "leave"
+
+                new_priority = MEMBERSHIP_PRIORITY.index(membership)
+                old_priority = MEMBERSHIP_PRIORITY.index(prev_membership)
+                if old_priority < new_priority:
+                    membership = prev_membership
+
+            # otherwise, get the user's membership at the time of the event.
+            if membership is None:
+                membership_event = state.get((EventTypes.Member, user_id), None)
+                if membership_event:
+                    if membership_event.event_id not in event_id_forgotten:
+                        membership = membership_event.membership
 
             # if the user was a member of the room at the time of the event,
             # they can see it.
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index fe2315df8f..b6ef3c91af 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -24,6 +24,7 @@ from synapse.api.constants import (
 )
 from synapse.api.errors import AuthError, SynapseError, Codes
 from synapse.util.logcontext import preserve_context_over_fn
+from synapse.util.async import Linearizer
 
 from signedjson.sign import verify_signed_json
 from signedjson.key import decode_verify_key_bytes
@@ -60,6 +61,8 @@ class RoomMemberHandler(BaseHandler):
     def __init__(self, hs):
         super(RoomMemberHandler, self).__init__(hs)
 
+        self.member_linearizer = Linearizer()
+
         self.clock = hs.get_clock()
 
         self.distributor = hs.get_distributor()
@@ -183,6 +186,34 @@ class RoomMemberHandler(BaseHandler):
             third_party_signed=None,
             ratelimit=True,
     ):
+        key = (target, room_id,)
+
+        with (yield self.member_linearizer.queue(key)):
+            result = yield self._update_membership(
+                requester,
+                target,
+                room_id,
+                action,
+                txn_id=txn_id,
+                remote_room_hosts=remote_room_hosts,
+                third_party_signed=third_party_signed,
+                ratelimit=ratelimit,
+            )
+
+        defer.returnValue(result)
+
+    @defer.inlineCallbacks
+    def _update_membership(
+            self,
+            requester,
+            target,
+            room_id,
+            action,
+            txn_id=None,
+            remote_room_hosts=None,
+            third_party_signed=None,
+            ratelimit=True,
+    ):
         effective_membership_state = action
         if action in ["kick", "unban"]:
             effective_membership_state = "leave"
@@ -233,6 +264,11 @@ class RoomMemberHandler(BaseHandler):
                     remote_room_hosts.append(inviter.domain)
 
                 content = {"membership": Membership.JOIN}
+
+                profile = self.hs.get_handlers().profile_handler
+                content["displayname"] = yield profile.get_displayname(target)
+                content["avatar_url"] = yield profile.get_avatar_url(target)
+
                 if requester.is_guest:
                     content["kind"] = "guest"
 
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index 680dc89536..cfc728a038 100644
--- a/synapse/replication/slave/storage/events.py
+++ b/synapse/replication/slave/storage/events.py
@@ -69,6 +69,7 @@ class SlavedEventStore(BaseSlavedStore):
         "_get_current_state_for_key"
     ]
 
+    get_event = DataStore.get_event.__func__
     get_current_state = DataStore.get_current_state.__func__
     get_current_state_for_key = DataStore.get_current_state_for_key.__func__
     get_rooms_for_user_where_membership_is = (
@@ -89,8 +90,11 @@ class SlavedEventStore(BaseSlavedStore):
     _invalidate_get_event_cache = DataStore._invalidate_get_event_cache.__func__
     _parse_events_txn = DataStore._parse_events_txn.__func__
     _get_events_txn = DataStore._get_events_txn.__func__
+    _enqueue_events = DataStore._enqueue_events.__func__
+    _do_fetch = DataStore._do_fetch.__func__
     _fetch_events_txn = DataStore._fetch_events_txn.__func__
     _fetch_event_rows = DataStore._fetch_event_rows.__func__
+    _get_event_from_row = DataStore._get_event_from_row.__func__
     _get_event_from_row_txn = DataStore._get_event_from_row_txn.__func__
     _get_rooms_for_user_where_membership_is_txn = (
         DataStore._get_rooms_for_user_where_membership_is_txn.__func__
@@ -100,7 +104,7 @@ class SlavedEventStore(BaseSlavedStore):
     def stream_positions(self):
         result = super(SlavedEventStore, self).stream_positions()
         result["events"] = self._stream_id_gen.get_current_token()
-        result["backfilled"] = self._backfill_id_gen.get_current_token()
+        result["backfill"] = self._backfill_id_gen.get_current_token()
         return result
 
     def process_replication(self, result):
@@ -142,7 +146,6 @@ class SlavedEventStore(BaseSlavedStore):
         position = row[0]
         internal = json.loads(row[1])
         event_json = json.loads(row[2])
-
         event = FrozenEvent(event_json, internal_metadata_dict=internal)
         self._invalidate_caches_for_event(
             event, backfilled, reset_state=position in state_resets
@@ -158,6 +161,8 @@ class SlavedEventStore(BaseSlavedStore):
 
         self._invalidate_get_event_cache(event.event_id)
 
+        self.get_latest_event_ids_in_room.invalidate((event.room_id,))
+
         if not backfilled:
             self._events_stream_cache.entity_has_changed(
                 event.room_id, event.internal_metadata.stream_ordering
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index f69f1cdad4..46cf93ff87 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -112,7 +112,7 @@ class StreamIdGenerator(object):
                 self._current + self._step * (n + 1),
                 self._step
             )
-            self._current += n
+            self._current += n * self._step
 
             for next_id in next_ids:
                 self._unfinished_ids.append(next_id)
diff --git a/synapse/util/async.py b/synapse/util/async.py
index cd4d90f3cf..0d6f48e2d8 100644
--- a/synapse/util/async.py
+++ b/synapse/util/async.py
@@ -16,9 +16,13 @@
 
 from twisted.internet import defer, reactor
 
-from .logcontext import PreserveLoggingContext, preserve_fn
+from .logcontext import (
+    PreserveLoggingContext, preserve_fn, preserve_context_over_deferred,
+)
 from synapse.util import unwrapFirstError
 
+from contextlib import contextmanager
+
 
 @defer.inlineCallbacks
 def sleep(seconds):
@@ -137,3 +141,47 @@ def concurrently_execute(func, args, limit):
         preserve_fn(_concurrently_execute_inner)()
         for _ in xrange(limit)
     ], consumeErrors=True).addErrback(unwrapFirstError)
+
+
+class Linearizer(object):
+    """Linearizes access to resources based on a key. Useful to ensure only one
+    thing is happening at a time on a given resource.
+
+    Example:
+
+        with (yield linearizer.queue("test_key")):
+            # do some work.
+
+    """
+    def __init__(self):
+        self.key_to_defer = {}
+
+    @defer.inlineCallbacks
+    def queue(self, key):
+        # If there is already a deferred in the queue, we pull it out so that
+        # we can wait on it later.
+        # Then we replace it with a deferred that we resolve *after* the
+        # context manager has exited.
+        # We only return the context manager after the previous deferred has
+        # resolved.
+        # This all has the net effect of creating a chain of deferreds that
+        # wait for the previous deferred before starting their work.
+        current_defer = self.key_to_defer.get(key)
+
+        new_defer = defer.Deferred()
+        self.key_to_defer[key] = new_defer
+
+        if current_defer:
+            yield preserve_context_over_deferred(current_defer)
+
+        @contextmanager
+        def _ctx_manager():
+            try:
+                yield
+            finally:
+                new_defer.callback(None)
+                current_d = self.key_to_defer.get(key)
+                if current_d is new_defer:
+                    self.key_to_defer.pop(key, None)
+
+        defer.returnValue(_ctx_manager())
diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py
index be310ba320..36686b479e 100644
--- a/synapse/util/caches/response_cache.py
+++ b/synapse/util/caches/response_cache.py
@@ -35,7 +35,7 @@ class ResponseCache(object):
             return None
 
     def set(self, key, deferred):
-        result = ObservableDeferred(deferred)
+        result = ObservableDeferred(deferred, consumeErrors=True)
         self.pending_result_cache[key] = result
 
         def remove(r):
diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py
index 0f525a8943..983caafe8a 100644
--- a/tests/replication/slave/storage/_base.py
+++ b/tests/replication/slave/storage/_base.py
@@ -51,7 +51,7 @@ class BaseSlavedStoreTestCase(unittest.TestCase):
     def check(self, method, args, expected_result=None):
         master_result = yield getattr(self.master_store, method)(*args)
         slaved_result = yield getattr(self.slaved_store, method)(*args)
-        self.assertEqual(master_result, slaved_result)
         if expected_result is not None:
             self.assertEqual(master_result, expected_result)
             self.assertEqual(slaved_result, expected_result)
+        self.assertEqual(master_result, slaved_result)
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index 351d777fb2..baa4a26eb5 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -14,20 +14,47 @@
 
 from ._base import BaseSlavedStoreTestCase
 
-from synapse.events import FrozenEvent
+from synapse.events import FrozenEvent, _EventInternalMetadata
 from synapse.events.snapshot import EventContext
 from synapse.storage.roommember import RoomsForUser
 
 from twisted.internet import defer
 
+
 USER_ID = "@feeling:blue"
 USER_ID_2 = "@bright:blue"
 OUTLIER = {"outlier": True}
 ROOM_ID = "!room:blue"
 
 
+def dict_equals(self, other):
+    return self.__dict__ == other.__dict__
+
+
+def patch__eq__(cls):
+    eq = getattr(cls, "__eq__", None)
+    cls.__eq__ = dict_equals
+
+    def unpatch():
+        if eq is not None:
+            cls.__eq__ = eq
+    return unpatch
+
+
 class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
 
+    def setUp(self):
+        # Patch up the equality operator for events so that we can check
+        # whether lists of events match using assertEquals
+        self.unpatches = [
+            patch__eq__(_EventInternalMetadata),
+            patch__eq__(FrozenEvent),
+        ]
+        return super(SlavedEventStoreTestCase, self).setUp()
+
+    def tearDown(self):
+        [unpatch() for unpatch in self.unpatches]
+
     @defer.inlineCallbacks
     def test_room_name_and_aliases(self):
         create = yield self.persist(type="m.room.create", key="", creator=USER_ID)
@@ -116,13 +143,121 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
         yield self.check("get_users_in_room", (ROOM_ID,), [USER_ID])
         yield self.check("get_rooms_for_user", (USER_ID_2,), [])
 
+    @defer.inlineCallbacks
+    def test_get_latest_event_ids_in_room(self):
+        create = yield self.persist(type="m.room.create", key="", creator=USER_ID)
+        yield self.replicate()
+        yield self.check(
+            "get_latest_event_ids_in_room", (ROOM_ID,), [create.event_id]
+        )
+
+        join = yield self.persist(
+            type="m.room.member", key=USER_ID, membership="join",
+            prev_events=[(create.event_id, {})],
+        )
+        yield self.replicate()
+        yield self.check(
+            "get_latest_event_ids_in_room", (ROOM_ID,), [join.event_id]
+        )
+
+    @defer.inlineCallbacks
+    def test_get_current_state(self):
+        # Create the room.
+        create = yield self.persist(type="m.room.create", key="", creator=USER_ID)
+        yield self.replicate()
+        yield self.check(
+            "get_current_state_for_key", (ROOM_ID, "m.room.member", USER_ID), []
+        )
+
+        # Join the room.
+        join1 = yield self.persist(
+            type="m.room.member", key=USER_ID, membership="join",
+        )
+        yield self.replicate()
+        yield self.check(
+            "get_current_state_for_key", (ROOM_ID, "m.room.member", USER_ID),
+            [join1]
+        )
+
+        # Add some other user to the room.
+        join2 = yield self.persist(
+            type="m.room.member", key=USER_ID_2, membership="join",
+        )
+        yield self.replicate()
+        yield self.check(
+            "get_current_state_for_key", (ROOM_ID, "m.room.member", USER_ID_2),
+            [join2]
+        )
+
+        # Leave the room, then rejoin the room clobbering state.
+        yield self.persist(type="m.room.member", key=USER_ID, membership="leave")
+        join3 = yield self.persist(
+            type="m.room.member", key=USER_ID, membership="join",
+            reset_state=[create]
+        )
+        yield self.replicate()
+        yield self.check(
+            "get_current_state_for_key", (ROOM_ID, "m.room.member", USER_ID_2),
+            []
+        )
+        yield self.check(
+            "get_current_state_for_key", (ROOM_ID, "m.room.member", USER_ID),
+            [join3]
+        )
+
+    @defer.inlineCallbacks
+    def test_redactions(self):
+        yield self.persist(type="m.room.create", key="", creator=USER_ID)
+        yield self.persist(type="m.room.member", key=USER_ID, membership="join")
+
+        msg = yield self.persist(
+            type="m.room.message", msgtype="m.text", body="Hello"
+        )
+        yield self.replicate()
+        yield self.check("get_event", [msg.event_id], msg)
+
+        redaction = yield self.persist(
+            type="m.room.redaction", redacts=msg.event_id
+        )
+        yield self.replicate()
+
+        msg_dict = msg.get_dict()
+        msg_dict["content"] = {}
+        msg_dict["unsigned"]["redacted_by"] = redaction.event_id
+        msg_dict["unsigned"]["redacted_because"] = redaction
+        redacted = FrozenEvent(msg_dict, msg.internal_metadata.get_dict())
+        yield self.check("get_event", [msg.event_id], redacted)
+
+    @defer.inlineCallbacks
+    def test_backfilled_redactions(self):
+        yield self.persist(type="m.room.create", key="", creator=USER_ID)
+        yield self.persist(type="m.room.member", key=USER_ID, membership="join")
+
+        msg = yield self.persist(
+            type="m.room.message", msgtype="m.text", body="Hello"
+        )
+        yield self.replicate()
+        yield self.check("get_event", [msg.event_id], msg)
+
+        redaction = yield self.persist(
+            type="m.room.redaction", redacts=msg.event_id, backfill=True
+        )
+        yield self.replicate()
+
+        msg_dict = msg.get_dict()
+        msg_dict["content"] = {}
+        msg_dict["unsigned"]["redacted_by"] = redaction.event_id
+        msg_dict["unsigned"]["redacted_because"] = redaction
+        redacted = FrozenEvent(msg_dict, msg.internal_metadata.get_dict())
+        yield self.check("get_event", [msg.event_id], redacted)
+
     event_id = 0
 
     @defer.inlineCallbacks
     def persist(
         self, sender=USER_ID, room_id=ROOM_ID, type={}, key=None, internal={},
         state=None, reset_state=False, backfill=False,
-        depth=None, prev_events=[], auth_events=[], prev_state=[],
+        depth=None, prev_events=[], auth_events=[], prev_state=[], redacts=None,
         **content
     ):
         """
@@ -147,6 +282,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
             event_dict["state_key"] = key
             event_dict["prev_state"] = prev_state
 
+        if redacts is not None:
+            event_dict["redacts"] = redacts
+
         event = FrozenEvent(event_dict, internal_metadata_dict=internal)
 
         self.event_id += 1
diff --git a/tests/util/test_linearizer.py b/tests/util/test_linearizer.py
new file mode 100644
index 0000000000..afcba482f9
--- /dev/null
+++ b/tests/util/test_linearizer.py
@@ -0,0 +1,44 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from tests import unittest
+
+from twisted.internet import defer
+
+from synapse.util.async import Linearizer
+
+
+class LinearizerTestCase(unittest.TestCase):
+
+    @defer.inlineCallbacks
+    def test_linearizer(self):
+        linearizer = Linearizer()
+
+        key = object()
+
+        d1 = linearizer.queue(key)
+        cm1 = yield d1
+
+        d2 = linearizer.queue(key)
+        self.assertFalse(d2.called)
+
+        with cm1:
+            self.assertFalse(d2.called)
+
+        self.assertTrue(d2.called)
+
+        with (yield d2):
+            pass