diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py
index 04a0b59a39..6a5028961d 100644
--- a/synapse/storage/push_rule.py
+++ b/synapse/storage/push_rule.py
@@ -14,20 +14,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import SQLBaseStore
+import abc
+import logging
+
+from canonicaljson import json
+
+from twisted.internet import defer
+
+from synapse.push.baserules import list_with_base_rules
from synapse.storage.appservice import ApplicationServiceWorkerStore
from synapse.storage.pusher import PusherWorkerStore
from synapse.storage.receipts import ReceiptsWorkerStore
from synapse.storage.roommember import RoomMemberWorkerStore
from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache
-from synapse.push.baserules import list_with_base_rules
-from synapse.api.constants import EventTypes
-from twisted.internet import defer
-import abc
-import logging
-import simplejson as json
+from ._base import SQLBaseStore
logger = logging.getLogger(__name__)
@@ -183,6 +185,7 @@ class PushRulesWorkerStore(ApplicationServiceWorkerStore,
defer.returnValue(results)
+ @defer.inlineCallbacks
def bulk_get_push_rules_for_room(self, event, context):
state_group = context.state_group
if not state_group:
@@ -192,9 +195,11 @@ class PushRulesWorkerStore(ApplicationServiceWorkerStore,
# To do this we set the state_group to a new object as object() != object()
state_group = object()
- return self._bulk_get_push_rules_for_room(
- event.room_id, state_group, context.current_state_ids, event=event
+ current_state_ids = yield context.get_current_state_ids(self)
+ result = yield self._bulk_get_push_rules_for_room(
+ event.room_id, state_group, current_state_ids, event=event
)
+ defer.returnValue(result)
@cachedInlineCallbacks(num_args=2, cache_context=True)
def _bulk_get_push_rules_for_room(self, room_id, state_group, current_state_ids,
@@ -244,18 +249,6 @@ class PushRulesWorkerStore(ApplicationServiceWorkerStore,
if uid in local_users_in_room:
user_ids.add(uid)
- forgotten = yield self.who_forgot_in_room(
- event.room_id, on_invalidate=cache_context.invalidate,
- )
-
- for row in forgotten:
- user_id = row["user_id"]
- event_id = row["event_id"]
-
- mem_id = current_state_ids.get((EventTypes.Member, user_id), None)
- if event_id == mem_id:
- user_ids.discard(user_id)
-
rules_by_user = yield self.bulk_get_push_rules(
user_ids, on_invalidate=cache_context.invalidate,
)
|