summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/appservice/__init__.py6
-rw-r--r--synapse/federation/transaction_queue.py31
-rw-r--r--synapse/federation/transport/server.py8
-rw-r--r--synapse/handlers/appservice.py7
-rw-r--r--synapse/handlers/presence.py163
-rw-r--r--synapse/storage/event_federation.py1
-rw-r--r--tests/appservice/test_appservice.py15
-rw-r--r--tests/handlers/test_presence.py13
-rw-r--r--tests/rest/client/v1/test_presence.py3
9 files changed, 145 insertions, 102 deletions
diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py
index 63a18b802b..e3ca45de83 100644
--- a/synapse/appservice/__init__.py
+++ b/synapse/appservice/__init__.py
@@ -148,8 +148,8 @@ class ApplicationService(object):
                 and self.is_interested_in_user(event.state_key)):
             return True
         # check joined member events
-        for member in member_list:
-            if self.is_interested_in_user(member.state_key):
+        for user_id in member_list:
+            if self.is_interested_in_user(user_id):
                 return True
         return False
 
@@ -173,7 +173,7 @@ class ApplicationService(object):
             restrict_to(str): The namespace to restrict regex tests to.
             aliases_for_event(list): A list of all the known room aliases for
             this event.
-            member_list(list): A list of all joined room members in this room.
+            member_list(list): A list of all joined user_ids in this room.
         Returns:
             bool: True if this service would like to know about this event.
         """
diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py
index ca04822fb3..32fa5e8c15 100644
--- a/synapse/federation/transaction_queue.py
+++ b/synapse/federation/transaction_queue.py
@@ -207,13 +207,13 @@ class TransactionQueue(object):
             # request at which point pending_pdus_by_dest just keeps growing.
             # we need application-layer timeouts of some flavour of these
             # requests
-            logger.info(
+            logger.debug(
                 "TX [%s] Transaction already in progress",
                 destination
             )
             return
 
-        logger.info("TX [%s] _attempt_new_transaction", destination)
+        logger.debug("TX [%s] _attempt_new_transaction", destination)
 
         # list of (pending_pdu, deferred, order)
         pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
@@ -221,11 +221,11 @@ class TransactionQueue(object):
         pending_failures = self.pending_failures_by_dest.pop(destination, [])
 
         if pending_pdus:
-            logger.info("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
-                        destination, len(pending_pdus))
+            logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
+                         destination, len(pending_pdus))
 
         if not pending_pdus and not pending_edus and not pending_failures:
-            logger.info("TX [%s] Nothing to send", destination)
+            logger.debug("TX [%s] Nothing to send", destination)
             return
 
         # Sort based on the order field
@@ -242,6 +242,8 @@ class TransactionQueue(object):
         try:
             self.pending_transactions[destination] = 1
 
+            txn_id = str(self._next_txn_id)
+
             limiter = yield get_retry_limiter(
                 destination,
                 self._clock,
@@ -249,9 +251,9 @@ class TransactionQueue(object):
             )
 
             logger.debug(
-                "TX [%s] Attempting new transaction"
+                "TX [%s] {%s} Attempting new transaction"
                 " (pdus: %d, edus: %d, failures: %d)",
-                destination,
+                destination, txn_id,
                 len(pending_pdus),
                 len(pending_edus),
                 len(pending_failures)
@@ -261,7 +263,7 @@ class TransactionQueue(object):
 
             transaction = Transaction.create_new(
                 origin_server_ts=int(self._clock.time_msec()),
-                transaction_id=str(self._next_txn_id),
+                transaction_id=txn_id,
                 origin=self.server_name,
                 destination=destination,
                 pdus=pdus,
@@ -275,9 +277,13 @@ class TransactionQueue(object):
 
             logger.debug("TX [%s] Persisted transaction", destination)
             logger.info(
-                "TX [%s] Sending transaction [%s]",
-                destination,
+                "TX [%s] {%s} Sending transaction [%s],"
+                " (PDUs: %d, EDUs: %d, failures: %d)",
+                destination, txn_id,
                 transaction.transaction_id,
+                len(pending_pdus),
+                len(pending_edus),
+                len(pending_failures),
             )
 
             with limiter:
@@ -313,7 +319,10 @@ class TransactionQueue(object):
                     code = e.code
                     response = e.response
 
-                logger.info("TX [%s] got %d response", destination, code)
+                logger.info(
+                    "TX [%s] {%s} got %d response",
+                    destination, txn_id, code
+                )
 
                 logger.debug("TX [%s] Sent transaction", destination)
                 logger.debug("TX [%s] Marking as delivered...", destination)
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 2bfe0f3c9b..af87805f34 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -196,6 +196,14 @@ class FederationSendServlet(BaseFederationServlet):
                 transaction_id, str(transaction_data)
             )
 
+            logger.info(
+                "Received txn %s from %s. (PDUs: %d, EDUs: %d, failures: %d)",
+                transaction_id, origin,
+                len(transaction_data.get("pdus", [])),
+                len(transaction_data.get("edus", [])),
+                len(transaction_data.get("failures", [])),
+            )
+
             # We should ideally be getting this from the security layer.
             # origin = body["origin"]
 
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 355ab317df..8269482e47 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -15,7 +15,7 @@
 
 from twisted.internet import defer
 
-from synapse.api.constants import EventTypes, Membership
+from synapse.api.constants import EventTypes
 from synapse.appservice import ApplicationService
 from synapse.types import UserID
 
@@ -147,10 +147,7 @@ class ApplicationServicesHandler(object):
                 )
             # We need to know the members associated with this event.room_id,
             # if any.
-            member_list = yield self.store.get_room_members(
-                room_id=event.room_id,
-                membership=Membership.JOIN
-            )
+            member_list = yield self.store.get_users_in_room(event.room_id)
 
         services = yield self.store.get_app_services()
         interested_list = [
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 40794187b1..670c1d353f 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -146,6 +146,10 @@ class PresenceHandler(BaseHandler):
         self._user_cachemap = {}
         self._user_cachemap_latest_serial = 0
 
+        # map room_ids to the latest presence serial for a member of that
+        # room
+        self._room_serials = {}
+
         metrics.register_callback(
             "userCachemap:size",
             lambda: len(self._user_cachemap),
@@ -297,13 +301,34 @@ class PresenceHandler(BaseHandler):
 
         self.changed_presencelike_data(user, {"last_active": now})
 
+    def get_joined_rooms_for_user(self, user):
+        """Get the list of rooms a user is joined to.
+
+        Args:
+            user(UserID): The user.
+        Returns:
+            A Deferred of a list of room id strings.
+        """
+        rm_handler = self.homeserver.get_handlers().room_member_handler
+        return rm_handler.get_joined_rooms_for_user(user)
+
+    def get_joined_users_for_room_id(self, room_id):
+        rm_handler = self.homeserver.get_handlers().room_member_handler
+        return rm_handler.get_room_members(room_id)
+
+    @defer.inlineCallbacks
     def changed_presencelike_data(self, user, state):
-        statuscache = self._get_or_make_usercache(user)
+        """Updates the presence state of a local user.
 
+        Args:
+            user(UserID): The user being updated.
+            state(dict): The new presence state for the user.
+        Returns:
+            A Deferred
+        """
         self._user_cachemap_latest_serial += 1
-        statuscache.update(state, serial=self._user_cachemap_latest_serial)
-
-        return self.push_presence(user, statuscache=statuscache)
+        statuscache = yield self.update_presence_cache(user, state)
+        yield self.push_presence(user, statuscache=statuscache)
 
     @log_function
     def started_user_eventstream(self, user):
@@ -326,13 +351,12 @@ class PresenceHandler(BaseHandler):
             room_id(str): The room id the user joined.
         """
         if self.hs.is_mine(user):
-            statuscache = self._get_or_make_usercache(user)
-
             # No actual update but we need to bump the serial anyway for the
             # event source
             self._user_cachemap_latest_serial += 1
-            statuscache.update({}, serial=self._user_cachemap_latest_serial)
-
+            statuscache = yield self.update_presence_cache(
+                user, room_ids=[room_id]
+            )
             self.push_update_to_local_and_remote(
                 observed_user=user,
                 room_ids=[room_id],
@@ -340,16 +364,17 @@ class PresenceHandler(BaseHandler):
             )
 
         # We also want to tell them about current presence of people.
-        rm_handler = self.homeserver.get_handlers().room_member_handler
-        curr_users = yield rm_handler.get_room_members(room_id)
+        curr_users = yield self.get_joined_users_for_room_id(room_id)
 
         for local_user in [c for c in curr_users if self.hs.is_mine(c)]:
-            statuscache = self._get_or_offline_usercache(local_user)
-            statuscache.update({}, serial=self._user_cachemap_latest_serial)
+            statuscache = yield self.update_presence_cache(
+                local_user, room_ids=[room_id], add_to_cache=False
+            )
+
             self.push_update_to_local_and_remote(
                 observed_user=local_user,
                 users_to_push=[user],
-                statuscache=self._get_or_offline_usercache(local_user),
+                statuscache=statuscache,
             )
 
     @defer.inlineCallbacks
@@ -546,8 +571,7 @@ class PresenceHandler(BaseHandler):
 
             # Also include people in all my rooms
 
-            rm_handler = self.homeserver.get_handlers().room_member_handler
-            room_ids = yield rm_handler.get_joined_rooms_for_user(user)
+            room_ids = yield self.get_joined_rooms_for_user(user)
 
         if state is None:
             state = yield self.store.get_presence_state(user.localpart)
@@ -747,8 +771,7 @@ class PresenceHandler(BaseHandler):
         # and also user is informed of server-forced pushes
         localusers.add(user)
 
-        rm_handler = self.homeserver.get_handlers().room_member_handler
-        room_ids = yield rm_handler.get_joined_rooms_for_user(user)
+        room_ids = yield self.get_joined_rooms_for_user(user)
 
         if not localusers and not room_ids:
             defer.returnValue(None)
@@ -793,8 +816,7 @@ class PresenceHandler(BaseHandler):
                     " | %d interested local observers %r", len(observers), observers
                 )
 
-            rm_handler = self.homeserver.get_handlers().room_member_handler
-            room_ids = yield rm_handler.get_joined_rooms_for_user(user)
+            room_ids = yield self.get_joined_rooms_for_user(user)
             if room_ids:
                 logger.debug(" | %d interested room IDs %r", len(room_ids), room_ids)
 
@@ -813,10 +835,8 @@ class PresenceHandler(BaseHandler):
                     self.clock.time_msec() - state.pop("last_active_ago")
                 )
 
-            statuscache = self._get_or_make_usercache(user)
-
             self._user_cachemap_latest_serial += 1
-            statuscache.update(state, serial=self._user_cachemap_latest_serial)
+            yield self.update_presence_cache(user, state, room_ids=room_ids)
 
             if not observers and not room_ids:
                 logger.debug(" | no interested observers or room IDs")
@@ -875,6 +895,35 @@ class PresenceHandler(BaseHandler):
         yield defer.DeferredList(deferreds, consumeErrors=True)
 
     @defer.inlineCallbacks
+    def update_presence_cache(self, user, state={}, room_ids=None,
+                              add_to_cache=True):
+        """Update the presence cache for a user with a new state and bump the
+        serial to the latest value.
+
+        Args:
+            user(UserID): The user being updated
+            state(dict): The presence state being updated
+            room_ids(None or list of str): A list of room_ids to update. If
+                room_ids is None then fetch the list of room_ids the user is
+                joined to.
+            add_to_cache: Whether to add an entry to the presence cache if the
+                user isn't already in the cache.
+        Returns:
+            A Deferred UserPresenceCache for the user being updated.
+        """
+        if room_ids is None:
+            room_ids = yield self.get_joined_rooms_for_user(user)
+
+        for room_id in room_ids:
+            self._room_serials[room_id] = self._user_cachemap_latest_serial
+        if add_to_cache:
+            statuscache = self._get_or_make_usercache(user)
+        else:
+            statuscache = self._get_or_offline_usercache(user)
+        statuscache.update(state, serial=self._user_cachemap_latest_serial)
+        defer.returnValue(statuscache)
+
+    @defer.inlineCallbacks
     def push_update_to_local_and_remote(self, observed_user, statuscache,
                                         users_to_push=[], room_ids=[],
                                         remote_domains=[]):
@@ -997,38 +1046,10 @@ class PresenceEventSource(object):
         self.clock = hs.get_clock()
 
     @defer.inlineCallbacks
-    def is_visible(self, observer_user, observed_user):
-        if observer_user == observed_user:
-            defer.returnValue(True)
-
-        presence = self.hs.get_handlers().presence_handler
-
-        if (yield presence.store.user_rooms_intersect(
-                [u.to_string() for u in observer_user, observed_user])):
-            defer.returnValue(True)
-
-        if self.hs.is_mine(observed_user):
-            pushmap = presence._local_pushmap
-
-            defer.returnValue(
-                observed_user.localpart in pushmap and
-                observer_user in pushmap[observed_user.localpart]
-            )
-        else:
-            recvmap = presence._remote_recvmap
-
-            defer.returnValue(
-                observed_user in recvmap and
-                observer_user in recvmap[observed_user]
-            )
-
-    @defer.inlineCallbacks
     @log_function
     def get_new_events_for_user(self, user, from_key, limit):
         from_key = int(from_key)
 
-        observer_user = user
-
         presence = self.hs.get_handlers().presence_handler
         cachemap = presence._user_cachemap
 
@@ -1037,17 +1058,27 @@ class PresenceEventSource(object):
         clock = self.clock
         latest_serial = 0
 
+        user_ids_to_check = {user}
+        presence_list = yield presence.store.get_presence_list(
+            user.localpart, accepted=True
+        )
+        if presence_list is not None:
+            user_ids_to_check |= set(
+                UserID.from_string(p["observed_user_id"]) for p in presence_list
+            )
+        room_ids = yield presence.get_joined_rooms_for_user(user)
+        for room_id in set(room_ids) & set(presence._room_serials):
+            if presence._room_serials[room_id] > from_key:
+                joined = yield presence.get_joined_users_for_room_id(room_id)
+                user_ids_to_check |= set(joined)
+
         updates = []
-        # TODO(paul): use a DeferredList ? How to limit concurrency.
-        for observed_user in cachemap.keys():
+        for observed_user in user_ids_to_check & set(cachemap):
             cached = cachemap[observed_user]
 
             if cached.serial <= from_key or cached.serial > max_serial:
                 continue
 
-            if not (yield self.is_visible(observer_user, observed_user)):
-                continue
-
             latest_serial = max(cached.serial, latest_serial)
             updates.append(cached.make_event(user=observed_user, clock=clock))
 
@@ -1084,8 +1115,6 @@ class PresenceEventSource(object):
     def get_pagination_rows(self, user, pagination_config, key):
         # TODO (erikj): Does this make sense? Ordering?
 
-        observer_user = user
-
         from_key = int(pagination_config.from_key)
 
         if pagination_config.to_key:
@@ -1096,14 +1125,26 @@ class PresenceEventSource(object):
         presence = self.hs.get_handlers().presence_handler
         cachemap = presence._user_cachemap
 
+        user_ids_to_check = {user}
+        presence_list = yield presence.store.get_presence_list(
+            user.localpart, accepted=True
+        )
+        if presence_list is not None:
+            user_ids_to_check |= set(
+                UserID.from_string(p["observed_user_id"]) for p in presence_list
+            )
+        room_ids = yield presence.get_joined_rooms_for_user(user)
+        for room_id in set(room_ids) & set(presence._room_serials):
+            if presence._room_serials[room_id] >= from_key:
+                joined = yield presence.get_joined_users_for_room_id(room_id)
+                user_ids_to_check |= set(joined)
+
         updates = []
-        # TODO(paul): use a DeferredList ? How to limit concurrency.
-        for observed_user in cachemap.keys():
+        for observed_user in user_ids_to_check & set(cachemap):
             if not (to_key < cachemap[observed_user].serial <= from_key):
                 continue
 
-            if (yield self.is_visible(observer_user, observed_user)):
-                updates.append((observed_user, cachemap[observed_user]))
+            updates.append((observed_user, cachemap[observed_user]))
 
         # TODO(paul): limit
 
diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py
index 5d4b7843f3..23573e8b2b 100644
--- a/synapse/storage/event_federation.py
+++ b/synapse/storage/event_federation.py
@@ -472,3 +472,4 @@ class EventFederationStore(SQLBaseStore):
         query = "DELETE FROM event_forward_extremities WHERE room_id = ?"
 
         txn.execute(query, (room_id,))
+        txn.call_after(self.get_latest_event_ids_in_room.invalidate, room_id)
diff --git a/tests/appservice/test_appservice.py b/tests/appservice/test_appservice.py
index 62149d6902..8ce8dc0a87 100644
--- a/tests/appservice/test_appservice.py
+++ b/tests/appservice/test_appservice.py
@@ -217,18 +217,9 @@ class ApplicationServiceTestCase(unittest.TestCase):
             _regex("@irc_.*")
         )
         join_list = [
-            Mock(
-                type="m.room.member", room_id="!foo:bar", sender="@alice:here",
-                state_key="@alice:here"
-            ),
-            Mock(
-                type="m.room.member", room_id="!foo:bar", sender="@irc_fo:here",
-                state_key="@irc_fo:here"  # AS user
-            ),
-            Mock(
-                type="m.room.member", room_id="!foo:bar", sender="@bob:here",
-                state_key="@bob:here"
-            )
+            "@alice:here",
+            "@irc_fo:here",  # AS user
+            "@bob:here",
         ]
 
         self.event.sender = "@xmpp_foobar:matrix.org"
diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index ee773797e7..12cf5747a2 100644
--- a/tests/handlers/test_presence.py
+++ b/tests/handlers/test_presence.py
@@ -624,6 +624,7 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase):
     """
     PRESENCE_LIST = {
             'apple': [ "@banana:test", "@clementine:test" ],
+            'banana': [ "@apple:test" ],
     }
 
     @defer.inlineCallbacks
@@ -836,12 +837,7 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase):
 
     @defer.inlineCallbacks
     def test_recv_remote(self):
-        # TODO(paul): Gut-wrenching
-        potato_set = self.handler._remote_recvmap.setdefault(self.u_potato,
-                set())
-        potato_set.add(self.u_apple)
-
-        self.room_members = [self.u_banana, self.u_potato]
+        self.room_members = [self.u_apple, self.u_banana, self.u_potato]
 
         self.assertEquals(self.event_source.get_current_key(), 0)
 
@@ -886,11 +882,8 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase):
     @defer.inlineCallbacks
     def test_recv_remote_offline(self):
         """ Various tests relating to SYN-261 """
-        potato_set = self.handler._remote_recvmap.setdefault(self.u_potato,
-                set())
-        potato_set.add(self.u_apple)
 
-        self.room_members = [self.u_banana, self.u_potato]
+        self.room_members = [self.u_apple, self.u_banana, self.u_potato]
 
         self.assertEquals(self.event_source.get_current_key(), 0)
 
diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py
index 21f42b3d3e..523b30cf8a 100644
--- a/tests/rest/client/v1/test_presence.py
+++ b/tests/rest/client/v1/test_presence.py
@@ -297,6 +297,9 @@ class PresenceEventStreamTestCase(unittest.TestCase):
             else:
                 return []
         hs.handlers.room_member_handler.get_joined_rooms_for_user = get_rooms_for_user
+        hs.handlers.room_member_handler.get_room_members = (
+            lambda r: self.room_members if r == "a-room" else []
+        )
 
         self.mock_datastore = hs.get_datastore()
         self.mock_datastore.get_app_service_by_token = Mock(return_value=None)