summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/api/filtering.py8
-rw-r--r--synapse/handlers/_base.py22
-rw-r--r--synapse/handlers/sync.py22
-rw-r--r--synapse/storage/account_data.py50
4 files changed, 93 insertions, 9 deletions
diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py
index cd699ef27f..4f5a4281fa 100644
--- a/synapse/api/filtering.py
+++ b/synapse/api/filtering.py
@@ -15,6 +15,8 @@
 from synapse.api.errors import SynapseError
 from synapse.types import UserID, RoomID
 
+from twisted.internet import defer
+
 import ujson as json
 
 
@@ -24,10 +26,10 @@ class Filtering(object):
         super(Filtering, self).__init__()
         self.store = hs.get_datastore()
 
+    @defer.inlineCallbacks
     def get_user_filter(self, user_localpart, filter_id):
-        result = self.store.get_user_filter(user_localpart, filter_id)
-        result.addCallback(FilterCollection)
-        return result
+        result = yield self.store.get_user_filter(user_localpart, filter_id)
+        defer.returnValue(FilterCollection(result))
 
     def add_user_filter(self, user_localpart, user_filter):
         self.check_valid_filter(user_filter)
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index 287024c1ca..745c8901ee 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -84,7 +84,7 @@ class BaseHandler(object):
             events ([synapse.events.EventBase]): list of events to filter
         """
         forgotten = yield defer.gatherResults([
-            self.store.who_forgot_in_room(
+            preserve_fn(self.store.who_forgot_in_room)(
                 room_id,
             )
             for room_id in frozenset(e.room_id for e in events)
@@ -95,13 +95,29 @@ class BaseHandler(object):
             row["event_id"] for rows in forgotten for row in rows
         )
 
-        def allowed(event, user_id, is_peeking):
+        ignore_dict_content = yield self.store.get_global_account_data_by_type_for_users(
+            "m.ignored_user_list", user_ids=[user_id for user_id, _ in user_tuples]
+        )
+
+        # FIXME: This will explode if people upload something incorrect.
+        ignore_dict = {
+            user_id: frozenset(
+                content.get("ignored_users", {}).keys() if content else []
+            )
+            for user_id, content in ignore_dict_content.items()
+        }
+
+        def allowed(event, user_id, is_peeking, ignore_list):
             """
             Args:
                 event (synapse.events.EventBase): event to check
                 user_id (str)
                 is_peeking (bool)
+                ignore_list (list): list of users to ignore
             """
+            if not event.is_state() and event.sender in ignore_list:
+                return False
+
             state = event_id_to_state[event.event_id]
 
             # get the room_visibility at the time of the event.
@@ -186,7 +202,7 @@ class BaseHandler(object):
             user_id: [
                 event
                 for event in events
-                if allowed(event, user_id, is_peeking)
+                if allowed(event, user_id, is_peeking, ignore_dict.get(user_id, []))
             ]
             for user_id, is_peeking in user_tuples
         })
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 231140b655..0bb1913285 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -247,6 +247,10 @@ class SyncHandler(BaseHandler):
             sync_config.user.to_string()
         )
 
+        ignored_users = account_data.get(
+            "m.ignored_user_list", {}
+        ).get("ignored_users", {}).keys()
+
         joined = []
         invited = []
         archived = []
@@ -267,6 +271,8 @@ class SyncHandler(BaseHandler):
                 )
                 joined.append(room_result)
             elif event.membership == Membership.INVITE:
+                if event.sender in ignored_users:
+                    return
                 invite = yield self.store.get_event(event.event_id)
                 invited.append(InvitedSyncResult(
                     room_id=event.room_id,
@@ -515,6 +521,15 @@ class SyncHandler(BaseHandler):
                 sync_config.user
             )
 
+        ignored_account_data = yield self.store.get_global_account_data_by_type_for_user(
+            "m.ignored_user_list", user_id=user_id,
+        )
+
+        if ignored_account_data:
+            ignored_users = ignored_account_data.get("ignored_users", {}).keys()
+        else:
+            ignored_users = frozenset()
+
         # Get a list of membership change events that have happened.
         rooms_changed = yield self.store.get_membership_changes_for_user(
             user_id, since_token.room_key, now_token.room_key
@@ -549,9 +564,10 @@ class SyncHandler(BaseHandler):
             # Only bother if we're still currently invited
             should_invite = non_joins[-1].membership == Membership.INVITE
             if should_invite:
-                room_sync = InvitedSyncResult(room_id, invite=non_joins[-1])
-                if room_sync:
-                    invited.append(room_sync)
+                if event.sender not in ignored_users:
+                    room_sync = InvitedSyncResult(room_id, invite=non_joins[-1])
+                    if room_sync:
+                        invited.append(room_sync)
 
             # Always include leave/ban events. Just take the last one.
             # TODO: How do we handle ban -> leave in same batch?
diff --git a/synapse/storage/account_data.py b/synapse/storage/account_data.py
index 7a7fbf1e52..ec7e8d40d2 100644
--- a/synapse/storage/account_data.py
+++ b/synapse/storage/account_data.py
@@ -16,6 +16,8 @@
 from ._base import SQLBaseStore
 from twisted.internet import defer
 
+from synapse.util.caches.descriptors import cached, cachedList, cachedInlineCallbacks
+
 import ujson as json
 import logging
 
@@ -24,6 +26,7 @@ logger = logging.getLogger(__name__)
 
 class AccountDataStore(SQLBaseStore):
 
+    @cached()
     def get_account_data_for_user(self, user_id):
         """Get all the client account_data for a user.
 
@@ -60,6 +63,47 @@ class AccountDataStore(SQLBaseStore):
             "get_account_data_for_user", get_account_data_for_user_txn
         )
 
+    @cachedInlineCallbacks(num_args=2)
+    def get_global_account_data_by_type_for_user(self, data_type, user_id):
+        """
+        Returns:
+            Deferred: A dict
+        """
+        result = yield self._simple_select_one_onecol(
+            table="account_data",
+            keyvalues={
+                "user_id": user_id,
+                "account_data_type": data_type,
+            },
+            retcol="content",
+            desc="get_global_account_data_by_type_for_user",
+            allow_none=True,
+        )
+
+        if result:
+            defer.returnValue(json.loads(result))
+        else:
+            defer.returnValue(None)
+
+    @cachedList(cached_method_name="get_global_account_data_by_type_for_user",
+                num_args=2, list_name="user_ids", inlineCallbacks=True)
+    def get_global_account_data_by_type_for_users(self, data_type, user_ids):
+        rows = yield self._simple_select_many_batch(
+            table="account_data",
+            column="user_id",
+            iterable=user_ids,
+            keyvalues={
+                "account_data_type": data_type,
+            },
+            retcols=("user_id", "content",),
+            desc="get_global_account_data_by_type_for_users",
+        )
+
+        defer.returnValue({
+            row["user_id"]: json.loads(row["content"]) if row["content"] else None
+            for row in rows
+        })
+
     def get_account_data_for_room(self, user_id, room_id):
         """Get all the client account_data for a user for a room.
 
@@ -193,6 +237,7 @@ class AccountDataStore(SQLBaseStore):
                 self._account_data_stream_cache.entity_has_changed,
                 user_id, next_id,
             )
+            txn.call_after(self.get_account_data_for_user.invalidate, (user_id,))
             self._update_max_stream_id(txn, next_id)
 
         with self._account_data_id_gen.get_next() as next_id:
@@ -232,6 +277,11 @@ class AccountDataStore(SQLBaseStore):
                 self._account_data_stream_cache.entity_has_changed,
                 user_id, next_id,
             )
+            txn.call_after(self.get_account_data_for_user.invalidate, (user_id,))
+            txn.call_after(
+                self.get_global_account_data_by_type_for_user.invalidate,
+                (account_data_type, user_id,)
+            )
             self._update_max_stream_id(txn, next_id)
 
         with self._account_data_id_gen.get_next() as next_id: