summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2017-05-22 15:02:12 +0100
committerErik Johnston <erik@matrix.org>2017-05-22 15:02:12 +0100
commit7fb80b5eaeb00f27ac46a043e53341bd8a1d1cfc (patch)
tree6ef74197b2273e6f80bfef98c1d0184a438c96f4 /synapse
parentAdd debug logging (diff)
downloadsynapse-7fb80b5eaeb00f27ac46a043e53341bd8a1d1cfc.tar.xz
Check if current event is a membership event
Diffstat (limited to 'synapse')
-rw-r--r--synapse/push/bulk_push_rule_evaluator.py13
1 files changed, 9 insertions, 4 deletions
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index 6bf203993c..2ee07f2f7e 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -53,7 +53,7 @@ class BulkPushRuleEvaluator(object):
         room_id = event.room_id
         rules_for_room = self._get_rules_for_room(room_id)
 
-        rules_by_user = yield rules_for_room.get_rules(context)
+        rules_by_user = yield rules_for_room.get_rules(event, context)
 
         # if this event is an invite event, we may need to run rules for the user
         # who's been invited, otherwise they won't get told they've been invited
@@ -216,7 +216,7 @@ class RulesForRoom(object):
         self.invalidate_all_cb = _Invalidation(rules_for_room_cache, room_id)
 
     @defer.inlineCallbacks
-    def get_rules(self, context):
+    def get_rules(self, event, context):
         """Given an event context return the rules for all users who are
         currently in the room.
         """
@@ -280,7 +280,7 @@ class RulesForRoom(object):
                 # and fetch push rules for them if appropriate.
                 logger.debug("Found new member events %r", missing_member_event_ids)
                 yield self._update_rules_with_member_event_ids(
-                    ret_rules_by_user, missing_member_event_ids, state_group
+                    ret_rules_by_user, missing_member_event_ids, state_group, event
                 )
 
         if logger.isEnabledFor(logging.DEBUG):
@@ -292,7 +292,7 @@ class RulesForRoom(object):
 
     @defer.inlineCallbacks
     def _update_rules_with_member_event_ids(self, ret_rules_by_user, member_event_ids,
-                                            state_group):
+                                            state_group, event):
         """Update the partially filled rules_by_user dict by fetching rules for
         any newly joined users in the `member_event_ids` list.
 
@@ -321,6 +321,11 @@ class RulesForRoom(object):
             for row in rows
         }
 
+        if event.type == EventTypes.Member:
+            for event_id in member_event_ids.itervalues():
+                if event_id == event.event_id:
+                    members[event_id] = (event.state_key, event.membership)
+
         if logger.isEnabledFor(logging.DEBUG):
             logger.debug("Found members %r: %r", self.room_id, members.values())