diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py
index 869a623090..bbfa5a7265 100644
--- a/synapse/events/__init__.py
+++ b/synapse/events/__init__.py
@@ -117,6 +117,15 @@ class EventBase(object):
def __set__(self, instance, value):
raise AttributeError("Unrecognized attribute %s" % (instance,))
+ def __getitem__(self, field):
+ return self._event_dict[field]
+
+ def __contains__(self, field):
+ return field in self._event_dict
+
+ def items(self):
+ return self._event_dict.items()
+
class FrozenEvent(EventBase):
def __init__(self, event_dict, internal_metadata_dict={}, rejected_reason=None):
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index bb2c6733d5..2d1167296a 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -53,16 +53,54 @@ class BaseHandler(object):
self.event_builder_factory = hs.get_event_builder_factory()
@defer.inlineCallbacks
- def _filter_events_for_client(self, user_id, events, is_guest=False):
- # Assumes that user has at some point joined the room if not is_guest.
+ def _filter_events_for_clients(self, users, events):
+ """ Returns dict of user_id -> list of events that user is allowed to
+ see.
+ """
+ event_id_to_state = yield self.store.get_state_for_events(
+ frozenset(e.event_id for e in events),
+ types=(
+ (EventTypes.RoomHistoryVisibility, ""),
+ (EventTypes.Member, None),
+ )
+ )
+
+ forgotten = yield defer.gatherResults([
+ self.store.who_forgot_in_room(
+ room_id,
+ )
+ for room_id in frozenset(e.room_id for e in events)
+ ], consumeErrors=True)
+
+ # Set of membership event_ids that have been forgotten
+ event_id_forgotten = frozenset(
+ row["event_id"] for rows in forgotten for row in rows
+ )
+
+ def allowed(event, user_id, is_guest):
+ state = event_id_to_state[event.event_id]
+
+ visibility_event = state.get((EventTypes.RoomHistoryVisibility, ""), None)
+ if visibility_event:
+ visibility = visibility_event.content.get("history_visibility", "shared")
+ else:
+ visibility = "shared"
- def allowed(event, membership, visibility):
if visibility == "world_readable":
return True
if is_guest:
return False
+ membership_event = state.get((EventTypes.Member, user_id), None)
+ if membership_event:
+ if membership_event.event_id in event_id_forgotten:
+ membership = None
+ else:
+ membership = membership_event.membership
+ else:
+ membership = None
+
if membership == Membership.JOIN:
return True
@@ -78,43 +116,20 @@ class BaseHandler(object):
return True
- event_id_to_state = yield self.store.get_state_for_events(
- frozenset(e.event_id for e in events),
- types=(
- (EventTypes.RoomHistoryVisibility, ""),
- (EventTypes.Member, user_id),
- )
- )
-
- events_to_return = []
- for event in events:
- state = event_id_to_state[event.event_id]
+ defer.returnValue({
+ user_id: [
+ event
+ for event in events
+ if allowed(event, user_id, is_guest)
+ ]
+ for user_id, is_guest in users
+ })
- membership_event = state.get((EventTypes.Member, user_id), None)
- if membership_event:
- was_forgotten_at_event = yield self.store.was_forgotten_at(
- membership_event.state_key,
- membership_event.room_id,
- membership_event.event_id
- )
- if was_forgotten_at_event:
- membership = None
- else:
- membership = membership_event.membership
- else:
- membership = None
-
- visibility_event = state.get((EventTypes.RoomHistoryVisibility, ""), None)
- if visibility_event:
- visibility = visibility_event.content.get("history_visibility", "shared")
- else:
- visibility = "shared"
-
- should_include = allowed(event, membership, visibility)
- if should_include:
- events_to_return.append(event)
-
- defer.returnValue(events_to_return)
+ @defer.inlineCallbacks
+ def _filter_events_for_client(self, user_id, events, is_guest=False):
+ # Assumes that user has at some point joined the room if not is_guest.
+ res = yield self._filter_events_for_clients([(user_id, is_guest)], events)
+ defer.returnValue(res.get(user_id, []))
def ratelimit(self, user_id):
time_now = self.clock.time()
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 26402ea9cd..4b94940e99 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -36,7 +36,7 @@ from synapse.events.utils import prune_event
from synapse.util.retryutils import NotRetryingDestination
-# from synapse.push.action_generator import ActionGenerator
+from synapse.push.action_generator import ActionGenerator
from twisted.internet import defer
@@ -244,12 +244,11 @@ class FederationHandler(BaseHandler):
user = UserID.from_string(event.state_key)
yield user_joined_room(self.distributor, user, event.room_id)
- # Temporarily disable notifications due to performance concerns.
- # if not backfilled and not event.internal_metadata.is_outlier():
- # action_generator = ActionGenerator(self.store)
- # yield action_generator.handle_push_actions_for_event(
- # event, self
- # )
+ if not backfilled and not event.internal_metadata.is_outlier():
+ action_generator = ActionGenerator(self.store)
+ yield action_generator.handle_push_actions_for_event(
+ event, self
+ )
@defer.inlineCallbacks
def _filter_events_for_server(self, server_name, room_id, events):
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 1942268c3c..52202d8e63 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -841,9 +841,6 @@ class SyncHandler(BaseHandler):
@defer.inlineCallbacks
def unread_notifs_for_room_id(self, room_id, sync_config, ephemeral_by_room):
- # Temporarily disable notifications due to performance concerns.
- defer.returnValue([])
-
last_unread_event_id = self.last_read_event_id_for_room_and_user(
room_id, sync_config.user.to_string(), ephemeral_by_room
)
diff --git a/synapse/push/action_generator.py b/synapse/push/action_generator.py
index 73467f3adc..4cf94f6c61 100644
--- a/synapse/push/action_generator.py
+++ b/synapse/push/action_generator.py
@@ -36,9 +36,6 @@ class ActionGenerator:
@defer.inlineCallbacks
def handle_push_actions_for_event(self, event, handler):
- # Temporarily disable notifications due to performance concerns.
- return
-
if event.type == EventTypes.Redaction and event.redacts is not None:
yield self.store.remove_push_actions_for_event_id(
event.room_id, event.redacts
diff --git a/synapse/push/baserules.py b/synapse/push/baserules.py
index 8bac7fd6af..e1217b5c52 100644
--- a/synapse/push/baserules.py
+++ b/synapse/push/baserules.py
@@ -15,27 +15,25 @@
from synapse.push.rulekinds import PRIORITY_CLASS_MAP, PRIORITY_CLASS_INVERSE_MAP
-def list_with_base_rules(rawrules, user_id):
+def list_with_base_rules(rawrules):
ruleslist = []
# shove the server default rules for each kind onto the end of each
current_prio_class = PRIORITY_CLASS_INVERSE_MAP.keys()[-1]
ruleslist.extend(make_base_prepend_rules(
- user_id, PRIORITY_CLASS_INVERSE_MAP[current_prio_class]
+ PRIORITY_CLASS_INVERSE_MAP[current_prio_class]
))
for r in rawrules:
if r['priority_class'] < current_prio_class:
while r['priority_class'] < current_prio_class:
ruleslist.extend(make_base_append_rules(
- user_id,
PRIORITY_CLASS_INVERSE_MAP[current_prio_class]
))
current_prio_class -= 1
if current_prio_class > 0:
ruleslist.extend(make_base_prepend_rules(
- user_id,
PRIORITY_CLASS_INVERSE_MAP[current_prio_class]
))
@@ -43,223 +41,233 @@ def list_with_base_rules(rawrules, user_id):
while current_prio_class > 0:
ruleslist.extend(make_base_append_rules(
- user_id,
PRIORITY_CLASS_INVERSE_MAP[current_prio_class]
))
current_prio_class -= 1
if current_prio_class > 0:
ruleslist.extend(make_base_prepend_rules(
- user_id,
PRIORITY_CLASS_INVERSE_MAP[current_prio_class]
))
return ruleslist
-def make_base_append_rules(user, kind):
+def make_base_append_rules(kind):
rules = []
if kind == 'override':
- rules = make_base_append_override_rules()
+ rules = BASE_APPEND_OVRRIDE_RULES
elif kind == 'underride':
- rules = make_base_append_underride_rules(user)
+ rules = BASE_APPEND_UNDERRIDE_RULES
elif kind == 'content':
- rules = make_base_append_content_rules(user)
-
- for r in rules:
- r['priority_class'] = PRIORITY_CLASS_MAP[kind]
- r['default'] = True # Deprecated, left for backwards compat
+ rules = BASE_APPEND_CONTENT_RULES
return rules
-def make_base_prepend_rules(user, kind):
+def make_base_prepend_rules(kind):
rules = []
if kind == 'override':
- rules = make_base_prepend_override_rules()
-
- for r in rules:
- r['priority_class'] = PRIORITY_CLASS_MAP[kind]
- r['default'] = True # Deprecated, left for backwards compat
+ rules = BASE_PREPEND_OVERRIDE_RULES
return rules
-def make_base_append_content_rules(user):
- return [
- {
- 'rule_id': 'global/content/.m.rule.contains_user_name',
- 'conditions': [
- {
- 'kind': 'event_match',
- 'key': 'content.body',
- 'pattern': user.localpart, # Matrix ID match
- }
- ],
- 'actions': [
- 'notify',
- {
- 'set_tweak': 'sound',
- 'value': 'default',
- }, {
- 'set_tweak': 'highlight'
- }
- ]
- },
- ]
+BASE_APPEND_CONTENT_RULES = [
+ {
+ 'rule_id': 'global/content/.m.rule.contains_user_name',
+ 'conditions': [
+ {
+ 'kind': 'event_match',
+ 'key': 'content.body',
+ 'pattern_type': 'user_localpart'
+ }
+ ],
+ 'actions': [
+ 'notify',
+ {
+ 'set_tweak': 'sound',
+ 'value': 'default',
+ }, {
+ 'set_tweak': 'highlight'
+ }
+ ]
+ },
+]
+
+
+BASE_PREPEND_OVERRIDE_RULES = [
+ {
+ 'rule_id': 'global/override/.m.rule.master',
+ 'enabled': False,
+ 'conditions': [],
+ 'actions': [
+ "dont_notify"
+ ]
+ }
+]
+
+
+BASE_APPEND_OVRRIDE_RULES = [
+ {
+ 'rule_id': 'global/override/.m.rule.suppress_notices',
+ 'conditions': [
+ {
+ 'kind': 'event_match',
+ 'key': 'content.msgtype',
+ 'pattern': 'm.notice',
+ '_id': '_suppress_notices',
+ }
+ ],
+ 'actions': [
+ 'dont_notify',
+ ]
+ }
+]
+
+BASE_APPEND_UNDERRIDE_RULES = [
+ {
+ 'rule_id': 'global/underride/.m.rule.call',
+ 'conditions': [
+ {
+ 'kind': 'event_match',
+ 'key': 'type',
+ 'pattern': 'm.call.invite',
+ '_id': '_call',
+ }
+ ],
+ 'actions': [
+ 'notify',
+ {
+ 'set_tweak': 'sound',
+ 'value': 'ring'
+ }, {
+ 'set_tweak': 'highlight',
+ 'value': False
+ }
+ ]
+ },
+ {
+ 'rule_id': 'global/underride/.m.rule.contains_display_name',
+ 'conditions': [
+ {
+ 'kind': 'contains_display_name'
+ }
+ ],
+ 'actions': [
+ 'notify',
+ {
+ 'set_tweak': 'sound',
+ 'value': 'default'
+ }, {
+ 'set_tweak': 'highlight'
+ }
+ ]
+ },
+ {
+ 'rule_id': 'global/underride/.m.rule.room_one_to_one',
+ 'conditions': [
+ {
+ 'kind': 'room_member_count',
+ 'is': '2',
+ '_id': 'member_count',
+ }
+ ],
+ 'actions': [
+ 'notify',
+ {
+ 'set_tweak': 'sound',
+ 'value': 'default'
+ }, {
+ 'set_tweak': 'highlight',
+ 'value': False
+ }
+ ]
+ },
+ {
+ 'rule_id': 'global/underride/.m.rule.invite_for_me',
+ 'conditions': [
+ {
+ 'kind': 'event_match',
+ 'key': 'type',
+ 'pattern': 'm.room.member',
+ '_id': '_member',
+ },
+ {
+ 'kind': 'event_match',
+ 'key': 'content.membership',
+ 'pattern': 'invite',
+ '_id': '_invite_member',
+ },
+ {
+ 'kind': 'event_match',
+ 'key': 'state_key',
+ 'pattern_type': 'user_id'
+ },
+ ],
+ 'actions': [
+ 'notify',
+ {
+ 'set_tweak': 'sound',
+ 'value': 'default'
+ }, {
+ 'set_tweak': 'highlight',
+ 'value': False
+ }
+ ]
+ },
+ {
+ 'rule_id': 'global/underride/.m.rule.member_event',
+ 'conditions': [
+ {
+ 'kind': 'event_match',
+ 'key': 'type',
+ 'pattern': 'm.room.member',
+ '_id': '_member',
+ }
+ ],
+ 'actions': [
+ 'notify', {
+ 'set_tweak': 'highlight',
+ 'value': False
+ }
+ ]
+ },
+ {
+ 'rule_id': 'global/underride/.m.rule.message',
+ 'enabled': False,
+ 'conditions': [
+ {
+ 'kind': 'event_match',
+ 'key': 'type',
+ 'pattern': 'm.room.message',
+ '_id': '_message',
+ }
+ ],
+ 'actions': [
+ 'notify', {
+ 'set_tweak': 'highlight',
+ 'value': False
+ }
+ ]
+ }
+]
-def make_base_prepend_override_rules():
- return [
- {
- 'rule_id': 'global/override/.m.rule.master',
- 'enabled': False,
- 'conditions': [],
- 'actions': [
- "dont_notify"
- ]
- }
- ]
+for r in BASE_APPEND_CONTENT_RULES:
+ r['priority_class'] = PRIORITY_CLASS_MAP['content']
+ r['default'] = True
-def make_base_append_override_rules():
- return [
- {
- 'rule_id': 'global/override/.m.rule.suppress_notices',
- 'conditions': [
- {
- 'kind': 'event_match',
- 'key': 'content.msgtype',
- 'pattern': 'm.notice',
- }
- ],
- 'actions': [
- 'dont_notify',
- ]
- }
- ]
+for r in BASE_PREPEND_OVERRIDE_RULES:
+ r['priority_class'] = PRIORITY_CLASS_MAP['override']
+ r['default'] = True
+for r in BASE_APPEND_OVRRIDE_RULES:
+ r['priority_class'] = PRIORITY_CLASS_MAP['override']
+ r['default'] = True
-def make_base_append_underride_rules(user):
- return [
- {
- 'rule_id': 'global/underride/.m.rule.call',
- 'conditions': [
- {
- 'kind': 'event_match',
- 'key': 'type',
- 'pattern': 'm.call.invite',
- }
- ],
- 'actions': [
- 'notify',
- {
- 'set_tweak': 'sound',
- 'value': 'ring'
- }, {
- 'set_tweak': 'highlight',
- 'value': False
- }
- ]
- },
- {
- 'rule_id': 'global/underride/.m.rule.contains_display_name',
- 'conditions': [
- {
- 'kind': 'contains_display_name'
- }
- ],
- 'actions': [
- 'notify',
- {
- 'set_tweak': 'sound',
- 'value': 'default'
- }, {
- 'set_tweak': 'highlight'
- }
- ]
- },
- {
- 'rule_id': 'global/underride/.m.rule.room_one_to_one',
- 'conditions': [
- {
- 'kind': 'room_member_count',
- 'is': '2'
- }
- ],
- 'actions': [
- 'notify',
- {
- 'set_tweak': 'sound',
- 'value': 'default'
- }, {
- 'set_tweak': 'highlight',
- 'value': False
- }
- ]
- },
- {
- 'rule_id': 'global/underride/.m.rule.invite_for_me',
- 'conditions': [
- {
- 'kind': 'event_match',
- 'key': 'type',
- 'pattern': 'm.room.member',
- },
- {
- 'kind': 'event_match',
- 'key': 'content.membership',
- 'pattern': 'invite',
- },
- {
- 'kind': 'event_match',
- 'key': 'state_key',
- 'pattern': user.to_string(),
- },
- ],
- 'actions': [
- 'notify',
- {
- 'set_tweak': 'sound',
- 'value': 'default'
- }, {
- 'set_tweak': 'highlight',
- 'value': False
- }
- ]
- },
- {
- 'rule_id': 'global/underride/.m.rule.member_event',
- 'conditions': [
- {
- 'kind': 'event_match',
- 'key': 'type',
- 'pattern': 'm.room.member',
- }
- ],
- 'actions': [
- 'notify', {
- 'set_tweak': 'highlight',
- 'value': False
- }
- ]
- },
- {
- 'rule_id': 'global/underride/.m.rule.message',
- 'enabled': False,
- 'conditions': [
- {
- 'kind': 'event_match',
- 'key': 'type',
- 'pattern': 'm.room.message',
- }
- ],
- 'actions': [
- 'notify', {
- 'set_tweak': 'highlight',
- 'value': False
- }
- ]
- }
- ]
+for r in BASE_APPEND_UNDERRIDE_RULES:
+ r['priority_class'] = PRIORITY_CLASS_MAP['underride']
+ r['default'] = True
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index ce244fa959..b0b3a38db7 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -14,16 +14,15 @@
# limitations under the License.
import logging
-import simplejson as json
+import ujson as json
from twisted.internet import defer
-from synapse.types import UserID
-
import baserules
-from push_rule_evaluator import PushRuleEvaluator
+from push_rule_evaluator import PushRuleEvaluatorForEvent
+
+from synapse.api.constants import EventTypes
-from synapse.events.utils import serialize_event
logger = logging.getLogger(__name__)
@@ -35,28 +34,25 @@ def decode_rule_json(rule):
@defer.inlineCallbacks
-def evaluator_for_room_id(room_id, store):
- users = yield store.get_users_in_room(room_id)
- rules_by_user = yield store.bulk_get_push_rules(users)
+def _get_rules(room_id, user_ids, store):
+ rules_by_user = yield store.bulk_get_push_rules(user_ids)
rules_by_user = {
- uid: baserules.list_with_base_rules(
- [decode_rule_json(rule_list) for rule_list in rules_by_user[uid]]
- if uid in rules_by_user else [],
- UserID.from_string(uid),
- )
- for uid in users
+ uid: baserules.list_with_base_rules([
+ decode_rule_json(rule_list)
+ for rule_list in rules_by_user.get(uid, [])
+ ])
+ for uid in user_ids
}
- member_events = yield store.get_current_state(
- room_id=room_id,
- event_type='m.room.member',
- )
- display_names = {}
- for ev in member_events:
- if ev.content.get("displayname"):
- display_names[ev.state_key] = ev.content.get("displayname")
+ defer.returnValue(rules_by_user)
+
+
+@defer.inlineCallbacks
+def evaluator_for_room_id(room_id, store):
+ users = yield store.get_users_in_room(room_id)
+ rules_by_user = yield _get_rules(room_id, users, store)
defer.returnValue(BulkPushRuleEvaluator(
- room_id, rules_by_user, display_names, users, store
+ room_id, rules_by_user, users, store
))
@@ -69,10 +65,9 @@ class BulkPushRuleEvaluator:
the same logic to run the actual rules, but could be optimised further
(see https://matrix.org/jira/browse/SYN-562)
"""
- def __init__(self, room_id, rules_by_user, display_names, users_in_room, store):
+ def __init__(self, room_id, rules_by_user, users_in_room, store):
self.room_id = room_id
self.rules_by_user = rules_by_user
- self.display_names = display_names
self.users_in_room = users_in_room
self.store = store
@@ -80,15 +75,30 @@ class BulkPushRuleEvaluator:
def action_for_event_by_user(self, event, handler):
actions_by_user = {}
+ users_dict = yield self.store.are_guests(self.rules_by_user.keys())
+
+ filtered_by_user = yield handler._filter_events_for_clients(
+ users_dict.items(), [event]
+ )
+
+ evaluator = PushRuleEvaluatorForEvent(event, len(self.users_in_room))
+
+ condition_cache = {}
+
+ member_state = yield self.store.get_state_for_event(
+ event.event_id,
+ )
+
+ display_names = {}
+ for ev in member_state.values():
+ nm = ev.content.get("displayname", None)
+ if nm and ev.type == EventTypes.Member:
+ display_names[ev.state_key] = nm
+
for uid, rules in self.rules_by_user.items():
- display_name = None
- if uid in self.display_names:
- display_name = self.display_names[uid]
-
- is_guest = yield self.store.is_guest(UserID.from_string(uid))
- filtered = yield handler._filter_events_for_client(
- uid, [event], is_guest=is_guest
- )
+ display_name = display_names.get(uid, None)
+
+ filtered = filtered_by_user[uid]
if len(filtered) == 0:
continue
@@ -96,29 +106,32 @@ class BulkPushRuleEvaluator:
if 'enabled' in rule and not rule['enabled']:
continue
- # XXX: profile tags
- if BulkPushRuleEvaluator.event_matches_rule(
- event, rule,
- display_name, len(self.users_in_room), None
- ):
+ matches = _condition_checker(
+ evaluator, rule['conditions'], uid, display_name, condition_cache
+ )
+ if matches:
actions = [x for x in rule['actions'] if x != 'dont_notify']
- if len(actions) > 0:
+ if actions:
actions_by_user[uid] = actions
break
defer.returnValue(actions_by_user)
- @staticmethod
- def event_matches_rule(event, rule,
- display_name, room_member_count, profile_tag):
- matches = True
-
- # passing the clock all the way into here is extremely awkward and push
- # rules do not care about any of the relative timestamps, so we just
- # pass 0 for the current time.
- client_event = serialize_event(event, 0)
-
- for cond in rule['conditions']:
- matches &= PushRuleEvaluator._event_fulfills_condition(
- client_event, cond, display_name, room_member_count, profile_tag
- )
- return matches
+
+def _condition_checker(evaluator, conditions, uid, display_name, cache):
+ for cond in conditions:
+ _id = cond.get("_id", None)
+ if _id:
+ res = cache.get(_id, None)
+ if res is False:
+ break
+ elif res is True:
+ continue
+
+ res = evaluator.matches(cond, uid, display_name, None)
+ if _id:
+ cache[_id] = res
+
+ if res is False:
+ return False
+
+ return True
diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py
index b0283743a2..379652c513 100644
--- a/synapse/push/push_rule_evaluator.py
+++ b/synapse/push/push_rule_evaluator.py
@@ -15,17 +15,22 @@
from twisted.internet import defer
-from synapse.types import UserID
-
import baserules
import logging
import simplejson as json
import re
+from synapse.types import UserID
+
logger = logging.getLogger(__name__)
+GLOB_REGEX = re.compile(r'\\\[(\\\!|)(.*)\\\]')
+IS_GLOB = re.compile(r'[\?\*\[\]]')
+INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$")
+
+
@defer.inlineCallbacks
def evaluator_for_user_id_and_profile_tag(user_id, profile_tag, room_id, store):
rawrules = yield store.get_push_rules_for_user(user_id)
@@ -42,9 +47,34 @@ def evaluator_for_user_id_and_profile_tag(user_id, profile_tag, room_id, store):
))
+def _room_member_count(ev, condition, room_member_count):
+ if 'is' not in condition:
+ return False
+ m = INEQUALITY_EXPR.match(condition['is'])
+ if not m:
+ return False
+ ineq = m.group(1)
+ rhs = m.group(2)
+ if not rhs.isdigit():
+ return False
+ rhs = int(rhs)
+
+ if ineq == '' or ineq == '==':
+ return room_member_count == rhs
+ elif ineq == '<':
+ return room_member_count < rhs
+ elif ineq == '>':
+ return room_member_count > rhs
+ elif ineq == '>=':
+ return room_member_count >= rhs
+ elif ineq == '<=':
+ return room_member_count <= rhs
+ else:
+ return False
+
+
class PushRuleEvaluator:
DEFAULT_ACTIONS = []
- INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$")
def __init__(self, user_id, profile_tag, raw_rules, enabled_map, room_id,
our_member_event, store):
@@ -61,8 +91,7 @@ class PushRuleEvaluator:
rule['actions'] = json.loads(raw_rule['actions'])
rules.append(rule)
- user = UserID.from_string(self.user_id)
- self.rules = baserules.list_with_base_rules(rules, user)
+ self.rules = baserules.list_with_base_rules(rules)
self.enabled_map = enabled_map
@@ -98,28 +127,19 @@ class PushRuleEvaluator:
room_members = yield self.store.get_users_in_room(room_id)
room_member_count = len(room_members)
+ evaluator = PushRuleEvaluatorForEvent(ev, room_member_count)
+
for r in self.rules:
- if r['rule_id'] in self.enabled_map:
- r['enabled'] = self.enabled_map[r['rule_id']]
- elif 'enabled' not in r:
- r['enabled'] = True
- if not r['enabled']:
+ enabled = self.enabled_map.get(r['rule_id'], None)
+ if enabled is not None and not enabled:
+ continue
+
+ if not r.get("enabled", True):
continue
- matches = True
conditions = r['conditions']
actions = r['actions']
- for c in conditions:
- matches &= self._event_fulfills_condition(
- ev, c, display_name=my_display_name,
- room_member_count=room_member_count,
- profile_tag=self.profile_tag
- )
- logger.debug(
- "Rule %s %s",
- r['rule_id'], "matches" if matches else "doesn't match"
- )
# ignore rules with no actions (we have an explict 'dont_notify')
if len(actions) == 0:
logger.warn(
@@ -127,8 +147,22 @@ class PushRuleEvaluator:
r['rule_id'], self.user_id
)
continue
+
+ matches = True
+ for c in conditions:
+ matches = evaluator.matches(
+ c, self.user_id, my_display_name, self.profile_tag
+ )
+ if not matches:
+ break
+
+ logger.debug(
+ "Rule %s %s",
+ r['rule_id'], "matches" if matches else "doesn't match"
+ )
+
if matches:
- logger.info(
+ logger.debug(
"%s matches for user %s, event %s",
r['rule_id'], self.user_id, ev['event_id']
)
@@ -139,94 +173,132 @@ class PushRuleEvaluator:
defer.returnValue(actions)
- logger.info(
+ logger.debug(
"No rules match for user %s, event %s",
self.user_id, ev['event_id']
)
defer.returnValue(PushRuleEvaluator.DEFAULT_ACTIONS)
- @staticmethod
- def _glob_to_regexp(glob):
- r = re.escape(glob)
- r = re.sub(r'\\\*', r'.*?', r)
- r = re.sub(r'\\\?', r'.', r)
- # handle [abc], [a-z] and [!a-z] style ranges.
- r = re.sub(r'\\\[(\\\!|)(.*)\\\]',
- lambda x: ('[%s%s]' % (x.group(1) and '^' or '',
- re.sub(r'\\\-', '-', x.group(2)))), r)
- return r
+class PushRuleEvaluatorForEvent(object):
+ def __init__(self, event, room_member_count):
+ self._event = event
+ self._room_member_count = room_member_count
- @staticmethod
- def _event_fulfills_condition(ev, condition,
- display_name, room_member_count, profile_tag):
- if condition['kind'] == 'event_match':
- if 'pattern' not in condition:
- logger.warn("event_match condition with no pattern")
- return False
- # XXX: optimisation: cache our pattern regexps
- if condition['key'] == 'content.body':
- r = r'\b%s\b' % PushRuleEvaluator._glob_to_regexp(condition['pattern'])
- else:
- r = r'^%s$' % PushRuleEvaluator._glob_to_regexp(condition['pattern'])
- val = _value_for_dotted_key(condition['key'], ev)
- if val is None:
- return False
- return re.search(r, val, flags=re.IGNORECASE) is not None
+ # Maps strings of e.g. 'content.body' -> event["content"]["body"]
+ self._value_cache = _flatten_dict(event)
+ def matches(self, condition, user_id, display_name, profile_tag):
+ if condition['kind'] == 'event_match':
+ return self._event_match(condition, user_id)
elif condition['kind'] == 'device':
if 'profile_tag' not in condition:
return True
return condition['profile_tag'] == profile_tag
-
elif condition['kind'] == 'contains_display_name':
- # This is special because display names can be different
- # between rooms and so you can't really hard code it in a rule.
- # Optimisation: we should cache these names and update them from
- # the event stream.
- if 'content' not in ev or 'body' not in ev['content']:
- return False
- if not display_name:
- return False
- return re.search(
- r"\b%s\b" % re.escape(display_name), ev['content']['body'],
- flags=re.IGNORECASE
- ) is not None
-
+ return self._contains_display_name(display_name)
elif condition['kind'] == 'room_member_count':
- if 'is' not in condition:
- return False
- m = PushRuleEvaluator.INEQUALITY_EXPR.match(condition['is'])
- if not m:
- return False
- ineq = m.group(1)
- rhs = m.group(2)
- if not rhs.isdigit():
+ return _room_member_count(
+ self._event, condition, self._room_member_count
+ )
+ else:
+ return True
+
+ def _event_match(self, condition, user_id):
+ pattern = condition.get('pattern', None)
+
+ if not pattern:
+ pattern_type = condition.get('pattern_type', None)
+ if pattern_type == "user_id":
+ pattern = user_id
+ elif pattern_type == "user_localpart":
+ pattern = UserID.from_string(user_id).localpart
+
+ if not pattern:
+ logger.warn("event_match condition with no pattern")
+ return False
+
+ # XXX: optimisation: cache our pattern regexps
+ if condition['key'] == 'content.body':
+ body = self._event["content"].get("body", None)
+ if not body:
return False
- rhs = int(rhs)
-
- if ineq == '' or ineq == '==':
- return room_member_count == rhs
- elif ineq == '<':
- return room_member_count < rhs
- elif ineq == '>':
- return room_member_count > rhs
- elif ineq == '>=':
- return room_member_count >= rhs
- elif ineq == '<=':
- return room_member_count <= rhs
- else:
+
+ return _glob_matches(pattern, body, word_boundary=True)
+ else:
+ haystack = self._get_value(condition['key'])
+ if haystack is None:
return False
+
+ return _glob_matches(pattern, haystack)
+
+ def _contains_display_name(self, display_name):
+ if not display_name:
+ return False
+
+ body = self._event["content"].get("body", None)
+ if not body:
+ return False
+
+ return _glob_matches(display_name, body, word_boundary=True)
+
+ def _get_value(self, dotted_key):
+ return self._value_cache.get(dotted_key, None)
+
+
+def _glob_matches(glob, value, word_boundary=False):
+ """Tests if value matches glob.
+
+ Args:
+ glob (string)
+ value (string): String to test against glob.
+ word_boundary (bool): Whether to match against word boundaries or entire
+ string. Defaults to False.
+
+ Returns:
+ bool
+ """
+ if IS_GLOB.search(glob):
+ r = re.escape(glob)
+
+ r = r.replace(r'\*', '.*?')
+ r = r.replace(r'\?', '.')
+
+ # handle [abc], [a-z] and [!a-z] style ranges.
+ r = GLOB_REGEX.sub(
+ lambda x: (
+ '[%s%s]' % (
+ x.group(1) and '^' or '',
+ x.group(2).replace(r'\\\-', '-')
+ )
+ ),
+ r,
+ )
+ if word_boundary:
+ r = r"\b%s\b" % (r,)
+ r = re.compile(r, flags=re.IGNORECASE)
+
+ return r.search(value)
else:
- return True
+ r = r + "$"
+ r = re.compile(r, flags=re.IGNORECASE)
+
+ return r.match(value)
+ elif word_boundary:
+ r = re.escape(glob)
+ r = r"\b%s\b" % (r,)
+ r = re.compile(r, flags=re.IGNORECASE)
+
+ return r.search(value)
+ else:
+ return value.lower() == glob.lower()
+
+def _flatten_dict(d, prefix=[], result={}):
+ for key, value in d.items():
+ if isinstance(value, basestring):
+ result[".".join(prefix + [key])] = value.lower()
+ elif hasattr(value, "items"):
+ _flatten_dict(value, prefix=(prefix+[key]), result=result)
-def _value_for_dotted_key(dotted_key, event):
- parts = dotted_key.split(".")
- val = event
- while len(parts) > 0:
- if parts[0] not in val:
- return None
- val = val[parts[0]]
- parts = parts[1:]
- return val
+ return result
diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py
index 0cbd9fe08a..2272d66dc7 100644
--- a/synapse/rest/client/v1/push_rule.py
+++ b/synapse/rest/client/v1/push_rule.py
@@ -27,6 +27,7 @@ from synapse.push.rulekinds import (
PRIORITY_CLASS_MAP, PRIORITY_CLASS_INVERSE_MAP
)
+import copy
import simplejson as json
@@ -126,7 +127,8 @@ class PushRuleRestServlet(ClientV1RestServlet):
rule["actions"] = json.loads(rawrule["actions"])
ruleslist.append(rule)
- ruleslist = baserules.list_with_base_rules(ruleslist, user)
+ # We're going to be mutating this a lot, so do a deep copy
+ ruleslist = copy.deepcopy(baserules.list_with_base_rules(ruleslist))
rules = {'global': {}, 'device': {}}
@@ -140,6 +142,16 @@ class PushRuleRestServlet(ClientV1RestServlet):
template_name = _priority_class_to_template_name(r['priority_class'])
+ # Remove internal stuff.
+ for c in r["conditions"]:
+ c.pop("_id", None)
+
+ pattern_type = c.pop("pattern_type", None)
+ if pattern_type == "user_id":
+ c["pattern"] = user.to_string()
+ elif pattern_type == "user_localpart":
+ c["pattern"] = user.localpart
+
if r['priority_class'] > PRIORITY_CLASS_MAP['override']:
# per-device rule
profile_tag = _profile_tag_from_conditions(r["conditions"])
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 999b710fbb..70cde0d04d 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -18,7 +18,7 @@ from twisted.internet import defer
from synapse.api.errors import StoreError, Codes
from ._base import SQLBaseStore
-from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
+from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
class RegistrationStore(SQLBaseStore):
@@ -256,10 +256,10 @@ class RegistrationStore(SQLBaseStore):
defer.returnValue(res if res else False)
@cachedInlineCallbacks()
- def is_guest(self, user):
+ def is_guest(self, user_id):
res = yield self._simple_select_one_onecol(
table="users",
- keyvalues={"name": user.to_string()},
+ keyvalues={"name": user_id},
retcol="is_guest",
allow_none=True,
desc="is_guest",
@@ -267,6 +267,26 @@ class RegistrationStore(SQLBaseStore):
defer.returnValue(res if res else False)
+ @cachedList(cache=is_guest.cache, list_name="user_ids", num_args=1,
+ inlineCallbacks=True)
+ def are_guests(self, user_ids):
+ sql = "SELECT name, is_guest FROM users WHERE name IN (%s)" % (
+ ",".join("?" for _ in user_ids),
+ )
+
+ rows = yield self._execute(
+ "are_guests", self.cursor_to_dict, sql, *user_ids
+ )
+
+ result = {user_id: False for user_id in user_ids}
+
+ result.update({
+ row["name"]: bool(row["is_guest"])
+ for row in rows
+ })
+
+ defer.returnValue(result)
+
def _query_for_auth(self, txn, token):
sql = (
"SELECT users.name, users.is_guest, access_tokens.id as token_id"
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 7d3ce4579d..68ac88905f 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -287,6 +287,7 @@ class RoomMemberStore(SQLBaseStore):
txn.execute(sql, (user_id, room_id))
yield self.runInteraction("forget_membership", f)
self.was_forgotten_at.invalidate_all()
+ self.who_forgot_in_room.invalidate_all()
self.did_forget.invalidate((user_id, room_id))
@cachedInlineCallbacks(num_args=2)
@@ -336,3 +337,15 @@ class RoomMemberStore(SQLBaseStore):
return rows[0][0]
forgot = yield self.runInteraction("did_forget_membership_at", f)
defer.returnValue(forgot == 1)
+
+ @cached()
+ def who_forgot_in_room(self, room_id):
+ return self._simple_select_list(
+ table="room_memberships",
+ retcols=("user_id", "event_id"),
+ keyvalues={
+ "room_id": room_id,
+ "forgotten": 1,
+ },
+ desc="who_forgot"
+ )
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
deleted file mode 100644
index 11a3d94bb0..0000000000
--- a/tests/handlers/test_federation.py
+++ /dev/null
@@ -1,141 +0,0 @@
-# Copyright 2014-2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-from twisted.internet import defer
-from tests import unittest
-
-from synapse.api.constants import EventTypes
-from synapse.events import FrozenEvent
-from synapse.handlers.federation import FederationHandler
-
-from mock import NonCallableMock, ANY, Mock
-
-from ..utils import setup_test_homeserver
-
-
-class FederationTestCase(unittest.TestCase):
-
- @defer.inlineCallbacks
- def setUp(self):
-
- self.state_handler = NonCallableMock(spec_set=[
- "compute_event_context",
- ])
-
- self.auth = NonCallableMock(spec_set=[
- "check",
- "check_host_in_room",
- ])
-
- self.hostname = "test"
- hs = yield setup_test_homeserver(
- self.hostname,
- datastore=NonCallableMock(spec_set=[
- "persist_event",
- "store_room",
- "get_room",
- "get_destination_retry_timings",
- "set_destination_retry_timings",
- "have_events",
- "get_users_in_room",
- "bulk_get_push_rules",
- "get_current_state",
- "set_push_actions_for_event_and_users",
- "is_guest",
- "get_state_for_events",
- ]),
- resource_for_federation=NonCallableMock(),
- http_client=NonCallableMock(spec_set=[]),
- notifier=NonCallableMock(spec_set=["on_new_room_event"]),
- handlers=NonCallableMock(spec_set=[
- "room_member_handler",
- "federation_handler",
- ]),
- auth=self.auth,
- state_handler=self.state_handler,
- keyring=Mock(),
- )
-
- self.datastore = hs.get_datastore()
- self.handlers = hs.get_handlers()
- self.notifier = hs.get_notifier()
- self.hs = hs
-
- self.handlers.federation_handler = FederationHandler(self.hs)
-
- self.datastore.get_state_for_events.return_value = {"$a:b": {}}
-
- @defer.inlineCallbacks
- def test_msg(self):
- pdu = FrozenEvent({
- "type": EventTypes.Message,
- "room_id": "foo",
- "content": {"msgtype": u"fooo"},
- "origin_server_ts": 0,
- "event_id": "$a:b",
- "user_id":"@a:b",
- "origin": "b",
- "auth_events": [],
- "hashes": {"sha256":"AcLrgtUIqqwaGoHhrEvYG1YLDIsVPYJdSRGhkp3jJp8"},
- })
-
- self.datastore.persist_event.return_value = defer.succeed((1,1))
- self.datastore.get_room.return_value = defer.succeed(True)
- self.datastore.get_users_in_room.return_value = ["@a:b"]
- self.datastore.bulk_get_push_rules.return_value = {}
- self.datastore.get_current_state.return_value = {}
- self.auth.check_host_in_room.return_value = defer.succeed(True)
-
- retry_timings_res = {
- "destination": "",
- "retry_last_ts": 0,
- "retry_interval": 0,
- }
- self.datastore.get_destination_retry_timings.return_value = (
- defer.succeed(retry_timings_res)
- )
-
- def have_events(event_ids):
- return defer.succeed({})
- self.datastore.have_events.side_effect = have_events
-
- def annotate(ev, old_state=None, outlier=False):
- context = Mock()
- context.current_state = {}
- context.auth_events = {}
- return defer.succeed(context)
- self.state_handler.compute_event_context.side_effect = annotate
-
- yield self.handlers.federation_handler.on_receive_pdu(
- "fo", pdu, False
- )
-
- self.datastore.persist_event.assert_called_once_with(
- ANY,
- is_new_state=True,
- backfilled=False,
- current_state=None,
- context=ANY,
- )
-
- self.state_handler.compute_event_context.assert_called_once_with(
- ANY, old_state=None, outlier=False
- )
-
- self.auth.check.assert_called_once_with(ANY, auth_events={})
-
- self.notifier.on_new_room_event.assert_called_once_with(
- ANY, 1, 1, extra_users=[]
- )
diff --git a/tests/handlers/test_room.py b/tests/handlers/test_room.py
deleted file mode 100644
index e7a12a2ba2..0000000000
--- a/tests/handlers/test_room.py
+++ /dev/null
@@ -1,418 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2014-2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-from twisted.internet import defer
-from .. import unittest
-
-from synapse.api.constants import EventTypes, Membership
-from synapse.handlers.room import RoomMemberHandler, RoomCreationHandler
-from synapse.handlers.profile import ProfileHandler
-from synapse.types import UserID
-from ..utils import setup_test_homeserver
-
-from mock import Mock, NonCallableMock
-
-
-class RoomMemberHandlerTestCase(unittest.TestCase):
-
- @defer.inlineCallbacks
- def setUp(self):
- self.hostname = "red"
- hs = yield setup_test_homeserver(
- self.hostname,
- ratelimiter=NonCallableMock(spec_set=[
- "send_message",
- ]),
- datastore=NonCallableMock(spec_set=[
- "persist_event",
- "get_room_member",
- "get_room",
- "store_room",
- "get_latest_events_in_room",
- "add_event_hashes",
- "get_users_in_room",
- "bulk_get_push_rules",
- "get_current_state",
- "set_push_actions_for_event_and_users",
- "get_state_for_events",
- "is_guest",
- ]),
- resource_for_federation=NonCallableMock(),
- http_client=NonCallableMock(spec_set=[]),
- notifier=NonCallableMock(spec_set=["on_new_room_event"]),
- handlers=NonCallableMock(spec_set=[
- "room_member_handler",
- "profile_handler",
- "federation_handler",
- ]),
- auth=NonCallableMock(spec_set=[
- "check",
- "add_auth_events",
- "check_host_in_room",
- ]),
- state_handler=NonCallableMock(spec_set=[
- "compute_event_context",
- "get_current_state",
- ]),
- )
-
- self.federation = NonCallableMock(spec_set=[
- "handle_new_event",
- "send_invite",
- "get_state_for_room",
- ])
-
- self.datastore = hs.get_datastore()
- self.handlers = hs.get_handlers()
- self.notifier = hs.get_notifier()
- self.state_handler = hs.get_state_handler()
- self.distributor = hs.get_distributor()
- self.auth = hs.get_auth()
- self.hs = hs
-
- self.handlers.federation_handler = self.federation
-
- self.distributor.declare("collect_presencelike_data")
-
- self.handlers.room_member_handler = RoomMemberHandler(self.hs)
- self.handlers.profile_handler = ProfileHandler(self.hs)
- self.room_member_handler = self.handlers.room_member_handler
-
- self.ratelimiter = hs.get_ratelimiter()
- self.ratelimiter.send_message.return_value = (True, 0)
-
- self.datastore.persist_event.return_value = (1,1)
- self.datastore.add_event_hashes.return_value = []
- self.datastore.get_users_in_room.return_value = ["@bob:red"]
- self.datastore.bulk_get_push_rules.return_value = {}
-
- @defer.inlineCallbacks
- def test_invite(self):
- room_id = "!foo:red"
- user_id = "@bob:red"
- target_user_id = "@red:blue"
- content = {"membership": Membership.INVITE}
-
- builder = self.hs.get_event_builder_factory().new({
- "type": EventTypes.Member,
- "sender": user_id,
- "state_key": target_user_id,
- "room_id": room_id,
- "content": content,
- })
-
- self.datastore.get_latest_events_in_room.return_value = (
- defer.succeed([])
- )
- self.datastore.get_current_state.return_value = {}
- self.datastore.get_state_for_events = lambda event_ids,types: {x: {} for x in event_ids}
-
- def annotate(_):
- ctx = Mock()
- ctx.current_state = {
- (EventTypes.Member, "@alice:green"): self._create_member(
- user_id="@alice:green",
- room_id=room_id,
- ),
- (EventTypes.Member, "@bob:red"): self._create_member(
- user_id="@bob:red",
- room_id=room_id,
- ),
- }
- ctx.prev_state_events = []
-
- return defer.succeed(ctx)
-
- self.state_handler.compute_event_context.side_effect = annotate
-
- def add_auth(_, ctx):
- ctx.auth_events = ctx.current_state[
- (EventTypes.Member, "@bob:red")
- ]
-
- return defer.succeed(True)
- self.auth.add_auth_events.side_effect = add_auth
-
- def send_invite(domain, event):
- return defer.succeed(event)
-
- self.federation.send_invite.side_effect = send_invite
-
- room_handler = self.room_member_handler
- event, context = yield room_handler._create_new_client_event(
- builder
- )
-
- yield room_handler.send_membership_event(event, context)
-
- self.state_handler.compute_event_context.assert_called_once_with(
- builder
- )
-
- self.auth.add_auth_events.assert_called_once_with(
- builder, context
- )
-
- self.federation.send_invite.assert_called_once_with(
- "blue", event,
- )
-
- self.datastore.persist_event.assert_called_once_with(
- event, context=context,
- )
- self.notifier.on_new_room_event.assert_called_once_with(
- event, 1, 1, extra_users=[UserID.from_string(target_user_id)]
- )
- self.assertFalse(self.datastore.get_room.called)
- self.assertFalse(self.datastore.store_room.called)
- self.assertFalse(self.federation.get_state_for_room.called)
-
- @defer.inlineCallbacks
- def test_simple_join(self):
- room_id = "!foo:red"
- user_id = "@bob:red"
- user = UserID.from_string(user_id)
-
- join_signal_observer = Mock()
- self.distributor.observe("user_joined_room", join_signal_observer)
-
- builder = self.hs.get_event_builder_factory().new({
- "type": EventTypes.Member,
- "sender": user_id,
- "state_key": user_id,
- "room_id": room_id,
- "content": {"membership": Membership.JOIN},
- })
-
- self.datastore.get_latest_events_in_room.return_value = (
- defer.succeed([])
- )
- self.datastore.get_current_state.return_value = {}
- self.datastore.get_state_for_events = lambda event_ids,types: {x: {} for x in event_ids}
-
- def annotate(_):
- ctx = Mock()
- ctx.current_state = {
- (EventTypes.Member, "@bob:red"): self._create_member(
- user_id="@bob:red",
- room_id=room_id,
- membership=Membership.INVITE
- ),
- }
- ctx.prev_state_events = []
-
- return defer.succeed(ctx)
-
- self.state_handler.compute_event_context.side_effect = annotate
-
- def add_auth(_, ctx):
- ctx.auth_events = ctx.current_state[
- (EventTypes.Member, "@bob:red")
- ]
-
- return defer.succeed(True)
- self.auth.add_auth_events.side_effect = add_auth
-
- room_handler = self.room_member_handler
- event, context = yield room_handler._create_new_client_event(
- builder
- )
-
- # Actual invocation
- yield room_handler.send_membership_event(event, context)
-
- self.federation.handle_new_event.assert_called_once_with(
- event, destinations=set()
- )
-
- self.datastore.persist_event.assert_called_once_with(
- event, context=context
- )
- self.notifier.on_new_room_event.assert_called_once_with(
- event, 1, 1, extra_users=[user]
- )
-
- join_signal_observer.assert_called_with(
- user=user, room_id=room_id
- )
-
- def _create_member(self, user_id, room_id, membership=Membership.JOIN):
- builder = self.hs.get_event_builder_factory().new({
- "type": EventTypes.Member,
- "sender": user_id,
- "state_key": user_id,
- "room_id": room_id,
- "content": {"membership": membership},
- })
-
- return builder.build()
-
- @defer.inlineCallbacks
- def test_simple_leave(self):
- room_id = "!foo:red"
- user_id = "@bob:red"
- user = UserID.from_string(user_id)
-
- builder = self.hs.get_event_builder_factory().new({
- "type": EventTypes.Member,
- "sender": user_id,
- "state_key": user_id,
- "room_id": room_id,
- "content": {"membership": Membership.LEAVE},
- })
-
- self.datastore.get_latest_events_in_room.return_value = (
- defer.succeed([])
- )
- self.datastore.get_current_state.return_value = {}
- self.datastore.get_state_for_events = lambda event_ids,types: {x: {} for x in event_ids}
-
- def annotate(_):
- ctx = Mock()
- ctx.current_state = {
- (EventTypes.Member, "@bob:red"): self._create_member(
- user_id="@bob:red",
- room_id=room_id,
- membership=Membership.JOIN
- ),
- }
- ctx.prev_state_events = []
-
- return defer.succeed(ctx)
-
- self.state_handler.compute_event_context.side_effect = annotate
-
- def add_auth(_, ctx):
- ctx.auth_events = ctx.current_state[
- (EventTypes.Member, "@bob:red")
- ]
-
- return defer.succeed(True)
- self.auth.add_auth_events.side_effect = add_auth
-
- room_handler = self.room_member_handler
- event, context = yield room_handler._create_new_client_event(
- builder
- )
-
- leave_signal_observer = Mock()
- self.distributor.observe("user_left_room", leave_signal_observer)
-
- # Actual invocation
- yield room_handler.send_membership_event(event, context)
-
- self.federation.handle_new_event.assert_called_once_with(
- event, destinations=set(['red'])
- )
-
- self.datastore.persist_event.assert_called_once_with(
- event, context=context
- )
- self.notifier.on_new_room_event.assert_called_once_with(
- event, 1, 1, extra_users=[user]
- )
-
- leave_signal_observer.assert_called_with(
- user=user, room_id=room_id
- )
-
-
-class RoomCreationTest(unittest.TestCase):
-
- @defer.inlineCallbacks
- def setUp(self):
- self.hostname = "red"
-
- hs = yield setup_test_homeserver(
- self.hostname,
- datastore=NonCallableMock(spec_set=[
- "store_room",
- "snapshot_room",
- "persist_event",
- "get_joined_hosts_for_room",
- ]),
- http_client=NonCallableMock(spec_set=[]),
- notifier=NonCallableMock(spec_set=["on_new_room_event"]),
- handlers=NonCallableMock(spec_set=[
- "room_creation_handler",
- "message_handler",
- ]),
- auth=NonCallableMock(spec_set=["check", "add_auth_events"]),
- ratelimiter=NonCallableMock(spec_set=[
- "send_message",
- ]),
- )
-
- self.federation = NonCallableMock(spec_set=[
- "handle_new_event",
- ])
-
- self.handlers = hs.get_handlers()
-
- self.handlers.room_creation_handler = RoomCreationHandler(hs)
- self.room_creation_handler = self.handlers.room_creation_handler
-
- self.message_handler = self.handlers.message_handler
-
- self.ratelimiter = hs.get_ratelimiter()
- self.ratelimiter.send_message.return_value = (True, 0)
-
- @defer.inlineCallbacks
- def test_room_creation(self):
- user_id = "@foo:red"
- room_id = "!bobs_room:red"
- config = {"visibility": "private"}
-
- yield self.room_creation_handler.create_room(
- user_id=user_id,
- room_id=room_id,
- config=config,
- )
-
- self.assertTrue(self.message_handler.create_and_send_event.called)
-
- event_dicts = [
- e[0][0]
- for e in self.message_handler.create_and_send_event.call_args_list
- ]
-
- self.assertTrue(len(event_dicts) > 3)
-
- self.assertDictContainsSubset(
- {
- "type": EventTypes.Create,
- "sender": user_id,
- "room_id": room_id,
- },
- event_dicts[0]
- )
-
- self.assertEqual(user_id, event_dicts[0]["content"]["creator"])
-
- self.assertDictContainsSubset(
- {
- "type": EventTypes.Member,
- "sender": user_id,
- "room_id": room_id,
- "state_key": user_id,
- },
- event_dicts[1]
- )
-
- self.assertEqual(
- Membership.JOIN,
- event_dicts[1]["content"]["membership"]
- )
|