summary refs log tree commit diff
path: root/synapse/storage/push_rule.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/push_rule.py')
-rw-r--r--synapse/storage/push_rule.py35
1 files changed, 14 insertions, 21 deletions
diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py
index 04a0b59a39..6a5028961d 100644
--- a/synapse/storage/push_rule.py
+++ b/synapse/storage/push_rule.py
@@ -14,20 +14,22 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ._base import SQLBaseStore
+import abc
+import logging
+
+from canonicaljson import json
+
+from twisted.internet import defer
+
+from synapse.push.baserules import list_with_base_rules
 from synapse.storage.appservice import ApplicationServiceWorkerStore
 from synapse.storage.pusher import PusherWorkerStore
 from synapse.storage.receipts import ReceiptsWorkerStore
 from synapse.storage.roommember import RoomMemberWorkerStore
 from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
 from synapse.util.caches.stream_change_cache import StreamChangeCache
-from synapse.push.baserules import list_with_base_rules
-from synapse.api.constants import EventTypes
-from twisted.internet import defer
 
-import abc
-import logging
-import simplejson as json
+from ._base import SQLBaseStore
 
 logger = logging.getLogger(__name__)
 
@@ -183,6 +185,7 @@ class PushRulesWorkerStore(ApplicationServiceWorkerStore,
 
         defer.returnValue(results)
 
+    @defer.inlineCallbacks
     def bulk_get_push_rules_for_room(self, event, context):
         state_group = context.state_group
         if not state_group:
@@ -192,9 +195,11 @@ class PushRulesWorkerStore(ApplicationServiceWorkerStore,
             # To do this we set the state_group to a new object as object() != object()
             state_group = object()
 
-        return self._bulk_get_push_rules_for_room(
-            event.room_id, state_group, context.current_state_ids, event=event
+        current_state_ids = yield context.get_current_state_ids(self)
+        result = yield self._bulk_get_push_rules_for_room(
+            event.room_id, state_group, current_state_ids, event=event
         )
+        defer.returnValue(result)
 
     @cachedInlineCallbacks(num_args=2, cache_context=True)
     def _bulk_get_push_rules_for_room(self, room_id, state_group, current_state_ids,
@@ -244,18 +249,6 @@ class PushRulesWorkerStore(ApplicationServiceWorkerStore,
             if uid in local_users_in_room:
                 user_ids.add(uid)
 
-        forgotten = yield self.who_forgot_in_room(
-            event.room_id, on_invalidate=cache_context.invalidate,
-        )
-
-        for row in forgotten:
-            user_id = row["user_id"]
-            event_id = row["event_id"]
-
-            mem_id = current_state_ids.get((EventTypes.Member, user_id), None)
-            if event_id == mem_id:
-                user_ids.discard(user_id)
-
         rules_by_user = yield self.bulk_get_push_rules(
             user_ids, on_invalidate=cache_context.invalidate,
         )