diff options
-rw-r--r-- | synapse/events/__init__.py | 9 | ||||
-rw-r--r-- | synapse/handlers/_base.py | 93 | ||||
-rw-r--r-- | synapse/handlers/federation.py | 13 | ||||
-rw-r--r-- | synapse/handlers/sync.py | 3 | ||||
-rw-r--r-- | synapse/push/action_generator.py | 3 | ||||
-rw-r--r-- | synapse/push/baserules.py | 390 | ||||
-rw-r--r-- | synapse/push/bulk_push_rule_evaluator.py | 121 | ||||
-rw-r--r-- | synapse/push/push_rule_evaluator.py | 260 | ||||
-rw-r--r-- | synapse/rest/client/v1/push_rule.py | 14 | ||||
-rw-r--r-- | synapse/storage/registration.py | 26 | ||||
-rw-r--r-- | synapse/storage/roommember.py | 13 | ||||
-rw-r--r-- | tests/handlers/test_federation.py | 141 | ||||
-rw-r--r-- | tests/handlers/test_room.py | 418 |
13 files changed, 550 insertions, 954 deletions
diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index 869a623090..bbfa5a7265 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -117,6 +117,15 @@ class EventBase(object): def __set__(self, instance, value): raise AttributeError("Unrecognized attribute %s" % (instance,)) + def __getitem__(self, field): + return self._event_dict[field] + + def __contains__(self, field): + return field in self._event_dict + + def items(self): + return self._event_dict.items() + class FrozenEvent(EventBase): def __init__(self, event_dict, internal_metadata_dict={}, rejected_reason=None): diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index bb2c6733d5..2d1167296a 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -53,16 +53,54 @@ class BaseHandler(object): self.event_builder_factory = hs.get_event_builder_factory() @defer.inlineCallbacks - def _filter_events_for_client(self, user_id, events, is_guest=False): - # Assumes that user has at some point joined the room if not is_guest. + def _filter_events_for_clients(self, users, events): + """ Returns dict of user_id -> list of events that user is allowed to + see. + """ + event_id_to_state = yield self.store.get_state_for_events( + frozenset(e.event_id for e in events), + types=( + (EventTypes.RoomHistoryVisibility, ""), + (EventTypes.Member, None), + ) + ) + + forgotten = yield defer.gatherResults([ + self.store.who_forgot_in_room( + room_id, + ) + for room_id in frozenset(e.room_id for e in events) + ], consumeErrors=True) + + # Set of membership event_ids that have been forgotten + event_id_forgotten = frozenset( + row["event_id"] for rows in forgotten for row in rows + ) + + def allowed(event, user_id, is_guest): + state = event_id_to_state[event.event_id] + + visibility_event = state.get((EventTypes.RoomHistoryVisibility, ""), None) + if visibility_event: + visibility = visibility_event.content.get("history_visibility", "shared") + else: + visibility = "shared" - def allowed(event, membership, visibility): if visibility == "world_readable": return True if is_guest: return False + membership_event = state.get((EventTypes.Member, user_id), None) + if membership_event: + if membership_event.event_id in event_id_forgotten: + membership = None + else: + membership = membership_event.membership + else: + membership = None + if membership == Membership.JOIN: return True @@ -78,43 +116,20 @@ class BaseHandler(object): return True - event_id_to_state = yield self.store.get_state_for_events( - frozenset(e.event_id for e in events), - types=( - (EventTypes.RoomHistoryVisibility, ""), - (EventTypes.Member, user_id), - ) - ) - - events_to_return = [] - for event in events: - state = event_id_to_state[event.event_id] + defer.returnValue({ + user_id: [ + event + for event in events + if allowed(event, user_id, is_guest) + ] + for user_id, is_guest in users + }) - membership_event = state.get((EventTypes.Member, user_id), None) - if membership_event: - was_forgotten_at_event = yield self.store.was_forgotten_at( - membership_event.state_key, - membership_event.room_id, - membership_event.event_id - ) - if was_forgotten_at_event: - membership = None - else: - membership = membership_event.membership - else: - membership = None - - visibility_event = state.get((EventTypes.RoomHistoryVisibility, ""), None) - if visibility_event: - visibility = visibility_event.content.get("history_visibility", "shared") - else: - visibility = "shared" - - should_include = allowed(event, membership, visibility) - if should_include: - events_to_return.append(event) - - defer.returnValue(events_to_return) + @defer.inlineCallbacks + def _filter_events_for_client(self, user_id, events, is_guest=False): + # Assumes that user has at some point joined the room if not is_guest. + res = yield self._filter_events_for_clients([(user_id, is_guest)], events) + defer.returnValue(res.get(user_id, [])) def ratelimit(self, user_id): time_now = self.clock.time() diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 26402ea9cd..4b94940e99 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -36,7 +36,7 @@ from synapse.events.utils import prune_event from synapse.util.retryutils import NotRetryingDestination -# from synapse.push.action_generator import ActionGenerator +from synapse.push.action_generator import ActionGenerator from twisted.internet import defer @@ -244,12 +244,11 @@ class FederationHandler(BaseHandler): user = UserID.from_string(event.state_key) yield user_joined_room(self.distributor, user, event.room_id) - # Temporarily disable notifications due to performance concerns. - # if not backfilled and not event.internal_metadata.is_outlier(): - # action_generator = ActionGenerator(self.store) - # yield action_generator.handle_push_actions_for_event( - # event, self - # ) + if not backfilled and not event.internal_metadata.is_outlier(): + action_generator = ActionGenerator(self.store) + yield action_generator.handle_push_actions_for_event( + event, self + ) @defer.inlineCallbacks def _filter_events_for_server(self, server_name, room_id, events): diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 1942268c3c..52202d8e63 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -841,9 +841,6 @@ class SyncHandler(BaseHandler): @defer.inlineCallbacks def unread_notifs_for_room_id(self, room_id, sync_config, ephemeral_by_room): - # Temporarily disable notifications due to performance concerns. - defer.returnValue([]) - last_unread_event_id = self.last_read_event_id_for_room_and_user( room_id, sync_config.user.to_string(), ephemeral_by_room ) diff --git a/synapse/push/action_generator.py b/synapse/push/action_generator.py index 73467f3adc..4cf94f6c61 100644 --- a/synapse/push/action_generator.py +++ b/synapse/push/action_generator.py @@ -36,9 +36,6 @@ class ActionGenerator: @defer.inlineCallbacks def handle_push_actions_for_event(self, event, handler): - # Temporarily disable notifications due to performance concerns. - return - if event.type == EventTypes.Redaction and event.redacts is not None: yield self.store.remove_push_actions_for_event_id( event.room_id, event.redacts diff --git a/synapse/push/baserules.py b/synapse/push/baserules.py index 8bac7fd6af..e1217b5c52 100644 --- a/synapse/push/baserules.py +++ b/synapse/push/baserules.py @@ -15,27 +15,25 @@ from synapse.push.rulekinds import PRIORITY_CLASS_MAP, PRIORITY_CLASS_INVERSE_MAP -def list_with_base_rules(rawrules, user_id): +def list_with_base_rules(rawrules): ruleslist = [] # shove the server default rules for each kind onto the end of each current_prio_class = PRIORITY_CLASS_INVERSE_MAP.keys()[-1] ruleslist.extend(make_base_prepend_rules( - user_id, PRIORITY_CLASS_INVERSE_MAP[current_prio_class] + PRIORITY_CLASS_INVERSE_MAP[current_prio_class] )) for r in rawrules: if r['priority_class'] < current_prio_class: while r['priority_class'] < current_prio_class: ruleslist.extend(make_base_append_rules( - user_id, PRIORITY_CLASS_INVERSE_MAP[current_prio_class] )) current_prio_class -= 1 if current_prio_class > 0: ruleslist.extend(make_base_prepend_rules( - user_id, PRIORITY_CLASS_INVERSE_MAP[current_prio_class] )) @@ -43,223 +41,233 @@ def list_with_base_rules(rawrules, user_id): while current_prio_class > 0: ruleslist.extend(make_base_append_rules( - user_id, PRIORITY_CLASS_INVERSE_MAP[current_prio_class] )) current_prio_class -= 1 if current_prio_class > 0: ruleslist.extend(make_base_prepend_rules( - user_id, PRIORITY_CLASS_INVERSE_MAP[current_prio_class] )) return ruleslist -def make_base_append_rules(user, kind): +def make_base_append_rules(kind): rules = [] if kind == 'override': - rules = make_base_append_override_rules() + rules = BASE_APPEND_OVRRIDE_RULES elif kind == 'underride': - rules = make_base_append_underride_rules(user) + rules = BASE_APPEND_UNDERRIDE_RULES elif kind == 'content': - rules = make_base_append_content_rules(user) - - for r in rules: - r['priority_class'] = PRIORITY_CLASS_MAP[kind] - r['default'] = True # Deprecated, left for backwards compat + rules = BASE_APPEND_CONTENT_RULES return rules -def make_base_prepend_rules(user, kind): +def make_base_prepend_rules(kind): rules = [] if kind == 'override': - rules = make_base_prepend_override_rules() - - for r in rules: - r['priority_class'] = PRIORITY_CLASS_MAP[kind] - r['default'] = True # Deprecated, left for backwards compat + rules = BASE_PREPEND_OVERRIDE_RULES return rules -def make_base_append_content_rules(user): - return [ - { - 'rule_id': 'global/content/.m.rule.contains_user_name', - 'conditions': [ - { - 'kind': 'event_match', - 'key': 'content.body', - 'pattern': user.localpart, # Matrix ID match - } - ], - 'actions': [ - 'notify', - { - 'set_tweak': 'sound', - 'value': 'default', - }, { - 'set_tweak': 'highlight' - } - ] - }, - ] +BASE_APPEND_CONTENT_RULES = [ + { + 'rule_id': 'global/content/.m.rule.contains_user_name', + 'conditions': [ + { + 'kind': 'event_match', + 'key': 'content.body', + 'pattern_type': 'user_localpart' + } + ], + 'actions': [ + 'notify', + { + 'set_tweak': 'sound', + 'value': 'default', + }, { + 'set_tweak': 'highlight' + } + ] + }, +] + + +BASE_PREPEND_OVERRIDE_RULES = [ + { + 'rule_id': 'global/override/.m.rule.master', + 'enabled': False, + 'conditions': [], + 'actions': [ + "dont_notify" + ] + } +] + + +BASE_APPEND_OVRRIDE_RULES = [ + { + 'rule_id': 'global/override/.m.rule.suppress_notices', + 'conditions': [ + { + 'kind': 'event_match', + 'key': 'content.msgtype', + 'pattern': 'm.notice', + '_id': '_suppress_notices', + } + ], + 'actions': [ + 'dont_notify', + ] + } +] + +BASE_APPEND_UNDERRIDE_RULES = [ + { + 'rule_id': 'global/underride/.m.rule.call', + 'conditions': [ + { + 'kind': 'event_match', + 'key': 'type', + 'pattern': 'm.call.invite', + '_id': '_call', + } + ], + 'actions': [ + 'notify', + { + 'set_tweak': 'sound', + 'value': 'ring' + }, { + 'set_tweak': 'highlight', + 'value': False + } + ] + }, + { + 'rule_id': 'global/underride/.m.rule.contains_display_name', + 'conditions': [ + { + 'kind': 'contains_display_name' + } + ], + 'actions': [ + 'notify', + { + 'set_tweak': 'sound', + 'value': 'default' + }, { + 'set_tweak': 'highlight' + } + ] + }, + { + 'rule_id': 'global/underride/.m.rule.room_one_to_one', + 'conditions': [ + { + 'kind': 'room_member_count', + 'is': '2', + '_id': 'member_count', + } + ], + 'actions': [ + 'notify', + { + 'set_tweak': 'sound', + 'value': 'default' + }, { + 'set_tweak': 'highlight', + 'value': False + } + ] + }, + { + 'rule_id': 'global/underride/.m.rule.invite_for_me', + 'conditions': [ + { + 'kind': 'event_match', + 'key': 'type', + 'pattern': 'm.room.member', + '_id': '_member', + }, + { + 'kind': 'event_match', + 'key': 'content.membership', + 'pattern': 'invite', + '_id': '_invite_member', + }, + { + 'kind': 'event_match', + 'key': 'state_key', + 'pattern_type': 'user_id' + }, + ], + 'actions': [ + 'notify', + { + 'set_tweak': 'sound', + 'value': 'default' + }, { + 'set_tweak': 'highlight', + 'value': False + } + ] + }, + { + 'rule_id': 'global/underride/.m.rule.member_event', + 'conditions': [ + { + 'kind': 'event_match', + 'key': 'type', + 'pattern': 'm.room.member', + '_id': '_member', + } + ], + 'actions': [ + 'notify', { + 'set_tweak': 'highlight', + 'value': False + } + ] + }, + { + 'rule_id': 'global/underride/.m.rule.message', + 'enabled': False, + 'conditions': [ + { + 'kind': 'event_match', + 'key': 'type', + 'pattern': 'm.room.message', + '_id': '_message', + } + ], + 'actions': [ + 'notify', { + 'set_tweak': 'highlight', + 'value': False + } + ] + } +] -def make_base_prepend_override_rules(): - return [ - { - 'rule_id': 'global/override/.m.rule.master', - 'enabled': False, - 'conditions': [], - 'actions': [ - "dont_notify" - ] - } - ] +for r in BASE_APPEND_CONTENT_RULES: + r['priority_class'] = PRIORITY_CLASS_MAP['content'] + r['default'] = True -def make_base_append_override_rules(): - return [ - { - 'rule_id': 'global/override/.m.rule.suppress_notices', - 'conditions': [ - { - 'kind': 'event_match', - 'key': 'content.msgtype', - 'pattern': 'm.notice', - } - ], - 'actions': [ - 'dont_notify', - ] - } - ] +for r in BASE_PREPEND_OVERRIDE_RULES: + r['priority_class'] = PRIORITY_CLASS_MAP['override'] + r['default'] = True +for r in BASE_APPEND_OVRRIDE_RULES: + r['priority_class'] = PRIORITY_CLASS_MAP['override'] + r['default'] = True -def make_base_append_underride_rules(user): - return [ - { - 'rule_id': 'global/underride/.m.rule.call', - 'conditions': [ - { - 'kind': 'event_match', - 'key': 'type', - 'pattern': 'm.call.invite', - } - ], - 'actions': [ - 'notify', - { - 'set_tweak': 'sound', - 'value': 'ring' - }, { - 'set_tweak': 'highlight', - 'value': False - } - ] - }, - { - 'rule_id': 'global/underride/.m.rule.contains_display_name', - 'conditions': [ - { - 'kind': 'contains_display_name' - } - ], - 'actions': [ - 'notify', - { - 'set_tweak': 'sound', - 'value': 'default' - }, { - 'set_tweak': 'highlight' - } - ] - }, - { - 'rule_id': 'global/underride/.m.rule.room_one_to_one', - 'conditions': [ - { - 'kind': 'room_member_count', - 'is': '2' - } - ], - 'actions': [ - 'notify', - { - 'set_tweak': 'sound', - 'value': 'default' - }, { - 'set_tweak': 'highlight', - 'value': False - } - ] - }, - { - 'rule_id': 'global/underride/.m.rule.invite_for_me', - 'conditions': [ - { - 'kind': 'event_match', - 'key': 'type', - 'pattern': 'm.room.member', - }, - { - 'kind': 'event_match', - 'key': 'content.membership', - 'pattern': 'invite', - }, - { - 'kind': 'event_match', - 'key': 'state_key', - 'pattern': user.to_string(), - }, - ], - 'actions': [ - 'notify', - { - 'set_tweak': 'sound', - 'value': 'default' - }, { - 'set_tweak': 'highlight', - 'value': False - } - ] - }, - { - 'rule_id': 'global/underride/.m.rule.member_event', - 'conditions': [ - { - 'kind': 'event_match', - 'key': 'type', - 'pattern': 'm.room.member', - } - ], - 'actions': [ - 'notify', { - 'set_tweak': 'highlight', - 'value': False - } - ] - }, - { - 'rule_id': 'global/underride/.m.rule.message', - 'enabled': False, - 'conditions': [ - { - 'kind': 'event_match', - 'key': 'type', - 'pattern': 'm.room.message', - } - ], - 'actions': [ - 'notify', { - 'set_tweak': 'highlight', - 'value': False - } - ] - } - ] +for r in BASE_APPEND_UNDERRIDE_RULES: + r['priority_class'] = PRIORITY_CLASS_MAP['underride'] + r['default'] = True diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index ce244fa959..b0b3a38db7 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -14,16 +14,15 @@ # limitations under the License. import logging -import simplejson as json +import ujson as json from twisted.internet import defer -from synapse.types import UserID - import baserules -from push_rule_evaluator import PushRuleEvaluator +from push_rule_evaluator import PushRuleEvaluatorForEvent + +from synapse.api.constants import EventTypes -from synapse.events.utils import serialize_event logger = logging.getLogger(__name__) @@ -35,28 +34,25 @@ def decode_rule_json(rule): @defer.inlineCallbacks -def evaluator_for_room_id(room_id, store): - users = yield store.get_users_in_room(room_id) - rules_by_user = yield store.bulk_get_push_rules(users) +def _get_rules(room_id, user_ids, store): + rules_by_user = yield store.bulk_get_push_rules(user_ids) rules_by_user = { - uid: baserules.list_with_base_rules( - [decode_rule_json(rule_list) for rule_list in rules_by_user[uid]] - if uid in rules_by_user else [], - UserID.from_string(uid), - ) - for uid in users + uid: baserules.list_with_base_rules([ + decode_rule_json(rule_list) + for rule_list in rules_by_user.get(uid, []) + ]) + for uid in user_ids } - member_events = yield store.get_current_state( - room_id=room_id, - event_type='m.room.member', - ) - display_names = {} - for ev in member_events: - if ev.content.get("displayname"): - display_names[ev.state_key] = ev.content.get("displayname") + defer.returnValue(rules_by_user) + + +@defer.inlineCallbacks +def evaluator_for_room_id(room_id, store): + users = yield store.get_users_in_room(room_id) + rules_by_user = yield _get_rules(room_id, users, store) defer.returnValue(BulkPushRuleEvaluator( - room_id, rules_by_user, display_names, users, store + room_id, rules_by_user, users, store )) @@ -69,10 +65,9 @@ class BulkPushRuleEvaluator: the same logic to run the actual rules, but could be optimised further (see https://matrix.org/jira/browse/SYN-562) """ - def __init__(self, room_id, rules_by_user, display_names, users_in_room, store): + def __init__(self, room_id, rules_by_user, users_in_room, store): self.room_id = room_id self.rules_by_user = rules_by_user - self.display_names = display_names self.users_in_room = users_in_room self.store = store @@ -80,15 +75,30 @@ class BulkPushRuleEvaluator: def action_for_event_by_user(self, event, handler): actions_by_user = {} + users_dict = yield self.store.are_guests(self.rules_by_user.keys()) + + filtered_by_user = yield handler._filter_events_for_clients( + users_dict.items(), [event] + ) + + evaluator = PushRuleEvaluatorForEvent(event, len(self.users_in_room)) + + condition_cache = {} + + member_state = yield self.store.get_state_for_event( + event.event_id, + ) + + display_names = {} + for ev in member_state.values(): + nm = ev.content.get("displayname", None) + if nm and ev.type == EventTypes.Member: + display_names[ev.state_key] = nm + for uid, rules in self.rules_by_user.items(): - display_name = None - if uid in self.display_names: - display_name = self.display_names[uid] - - is_guest = yield self.store.is_guest(UserID.from_string(uid)) - filtered = yield handler._filter_events_for_client( - uid, [event], is_guest=is_guest - ) + display_name = display_names.get(uid, None) + + filtered = filtered_by_user[uid] if len(filtered) == 0: continue @@ -96,29 +106,32 @@ class BulkPushRuleEvaluator: if 'enabled' in rule and not rule['enabled']: continue - # XXX: profile tags - if BulkPushRuleEvaluator.event_matches_rule( - event, rule, - display_name, len(self.users_in_room), None - ): + matches = _condition_checker( + evaluator, rule['conditions'], uid, display_name, condition_cache + ) + if matches: actions = [x for x in rule['actions'] if x != 'dont_notify'] - if len(actions) > 0: + if actions: actions_by_user[uid] = actions break defer.returnValue(actions_by_user) - @staticmethod - def event_matches_rule(event, rule, - display_name, room_member_count, profile_tag): - matches = True - - # passing the clock all the way into here is extremely awkward and push - # rules do not care about any of the relative timestamps, so we just - # pass 0 for the current time. - client_event = serialize_event(event, 0) - - for cond in rule['conditions']: - matches &= PushRuleEvaluator._event_fulfills_condition( - client_event, cond, display_name, room_member_count, profile_tag - ) - return matches + +def _condition_checker(evaluator, conditions, uid, display_name, cache): + for cond in conditions: + _id = cond.get("_id", None) + if _id: + res = cache.get(_id, None) + if res is False: + break + elif res is True: + continue + + res = evaluator.matches(cond, uid, display_name, None) + if _id: + cache[_id] = res + + if res is False: + return False + + return True diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py index b0283743a2..379652c513 100644 --- a/synapse/push/push_rule_evaluator.py +++ b/synapse/push/push_rule_evaluator.py @@ -15,17 +15,22 @@ from twisted.internet import defer -from synapse.types import UserID - import baserules import logging import simplejson as json import re +from synapse.types import UserID + logger = logging.getLogger(__name__) +GLOB_REGEX = re.compile(r'\\\[(\\\!|)(.*)\\\]') +IS_GLOB = re.compile(r'[\?\*\[\]]') +INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$") + + @defer.inlineCallbacks def evaluator_for_user_id_and_profile_tag(user_id, profile_tag, room_id, store): rawrules = yield store.get_push_rules_for_user(user_id) @@ -42,9 +47,34 @@ def evaluator_for_user_id_and_profile_tag(user_id, profile_tag, room_id, store): )) +def _room_member_count(ev, condition, room_member_count): + if 'is' not in condition: + return False + m = INEQUALITY_EXPR.match(condition['is']) + if not m: + return False + ineq = m.group(1) + rhs = m.group(2) + if not rhs.isdigit(): + return False + rhs = int(rhs) + + if ineq == '' or ineq == '==': + return room_member_count == rhs + elif ineq == '<': + return room_member_count < rhs + elif ineq == '>': + return room_member_count > rhs + elif ineq == '>=': + return room_member_count >= rhs + elif ineq == '<=': + return room_member_count <= rhs + else: + return False + + class PushRuleEvaluator: DEFAULT_ACTIONS = [] - INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$") def __init__(self, user_id, profile_tag, raw_rules, enabled_map, room_id, our_member_event, store): @@ -61,8 +91,7 @@ class PushRuleEvaluator: rule['actions'] = json.loads(raw_rule['actions']) rules.append(rule) - user = UserID.from_string(self.user_id) - self.rules = baserules.list_with_base_rules(rules, user) + self.rules = baserules.list_with_base_rules(rules) self.enabled_map = enabled_map @@ -98,28 +127,19 @@ class PushRuleEvaluator: room_members = yield self.store.get_users_in_room(room_id) room_member_count = len(room_members) + evaluator = PushRuleEvaluatorForEvent(ev, room_member_count) + for r in self.rules: - if r['rule_id'] in self.enabled_map: - r['enabled'] = self.enabled_map[r['rule_id']] - elif 'enabled' not in r: - r['enabled'] = True - if not r['enabled']: + enabled = self.enabled_map.get(r['rule_id'], None) + if enabled is not None and not enabled: + continue + + if not r.get("enabled", True): continue - matches = True conditions = r['conditions'] actions = r['actions'] - for c in conditions: - matches &= self._event_fulfills_condition( - ev, c, display_name=my_display_name, - room_member_count=room_member_count, - profile_tag=self.profile_tag - ) - logger.debug( - "Rule %s %s", - r['rule_id'], "matches" if matches else "doesn't match" - ) # ignore rules with no actions (we have an explict 'dont_notify') if len(actions) == 0: logger.warn( @@ -127,8 +147,22 @@ class PushRuleEvaluator: r['rule_id'], self.user_id ) continue + + matches = True + for c in conditions: + matches = evaluator.matches( + c, self.user_id, my_display_name, self.profile_tag + ) + if not matches: + break + + logger.debug( + "Rule %s %s", + r['rule_id'], "matches" if matches else "doesn't match" + ) + if matches: - logger.info( + logger.debug( "%s matches for user %s, event %s", r['rule_id'], self.user_id, ev['event_id'] ) @@ -139,94 +173,132 @@ class PushRuleEvaluator: defer.returnValue(actions) - logger.info( + logger.debug( "No rules match for user %s, event %s", self.user_id, ev['event_id'] ) defer.returnValue(PushRuleEvaluator.DEFAULT_ACTIONS) - @staticmethod - def _glob_to_regexp(glob): - r = re.escape(glob) - r = re.sub(r'\\\*', r'.*?', r) - r = re.sub(r'\\\?', r'.', r) - # handle [abc], [a-z] and [!a-z] style ranges. - r = re.sub(r'\\\[(\\\!|)(.*)\\\]', - lambda x: ('[%s%s]' % (x.group(1) and '^' or '', - re.sub(r'\\\-', '-', x.group(2)))), r) - return r +class PushRuleEvaluatorForEvent(object): + def __init__(self, event, room_member_count): + self._event = event + self._room_member_count = room_member_count - @staticmethod - def _event_fulfills_condition(ev, condition, - display_name, room_member_count, profile_tag): - if condition['kind'] == 'event_match': - if 'pattern' not in condition: - logger.warn("event_match condition with no pattern") - return False - # XXX: optimisation: cache our pattern regexps - if condition['key'] == 'content.body': - r = r'\b%s\b' % PushRuleEvaluator._glob_to_regexp(condition['pattern']) - else: - r = r'^%s$' % PushRuleEvaluator._glob_to_regexp(condition['pattern']) - val = _value_for_dotted_key(condition['key'], ev) - if val is None: - return False - return re.search(r, val, flags=re.IGNORECASE) is not None + # Maps strings of e.g. 'content.body' -> event["content"]["body"] + self._value_cache = _flatten_dict(event) + def matches(self, condition, user_id, display_name, profile_tag): + if condition['kind'] == 'event_match': + return self._event_match(condition, user_id) elif condition['kind'] == 'device': if 'profile_tag' not in condition: return True return condition['profile_tag'] == profile_tag - elif condition['kind'] == 'contains_display_name': - # This is special because display names can be different - # between rooms and so you can't really hard code it in a rule. - # Optimisation: we should cache these names and update them from - # the event stream. - if 'content' not in ev or 'body' not in ev['content']: - return False - if not display_name: - return False - return re.search( - r"\b%s\b" % re.escape(display_name), ev['content']['body'], - flags=re.IGNORECASE - ) is not None - + return self._contains_display_name(display_name) elif condition['kind'] == 'room_member_count': - if 'is' not in condition: - return False - m = PushRuleEvaluator.INEQUALITY_EXPR.match(condition['is']) - if not m: - return False - ineq = m.group(1) - rhs = m.group(2) - if not rhs.isdigit(): + return _room_member_count( + self._event, condition, self._room_member_count + ) + else: + return True + + def _event_match(self, condition, user_id): + pattern = condition.get('pattern', None) + + if not pattern: + pattern_type = condition.get('pattern_type', None) + if pattern_type == "user_id": + pattern = user_id + elif pattern_type == "user_localpart": + pattern = UserID.from_string(user_id).localpart + + if not pattern: + logger.warn("event_match condition with no pattern") + return False + + # XXX: optimisation: cache our pattern regexps + if condition['key'] == 'content.body': + body = self._event["content"].get("body", None) + if not body: return False - rhs = int(rhs) - - if ineq == '' or ineq == '==': - return room_member_count == rhs - elif ineq == '<': - return room_member_count < rhs - elif ineq == '>': - return room_member_count > rhs - elif ineq == '>=': - return room_member_count >= rhs - elif ineq == '<=': - return room_member_count <= rhs - else: + + return _glob_matches(pattern, body, word_boundary=True) + else: + haystack = self._get_value(condition['key']) + if haystack is None: return False + + return _glob_matches(pattern, haystack) + + def _contains_display_name(self, display_name): + if not display_name: + return False + + body = self._event["content"].get("body", None) + if not body: + return False + + return _glob_matches(display_name, body, word_boundary=True) + + def _get_value(self, dotted_key): + return self._value_cache.get(dotted_key, None) + + +def _glob_matches(glob, value, word_boundary=False): + """Tests if value matches glob. + + Args: + glob (string) + value (string): String to test against glob. + word_boundary (bool): Whether to match against word boundaries or entire + string. Defaults to False. + + Returns: + bool + """ + if IS_GLOB.search(glob): + r = re.escape(glob) + + r = r.replace(r'\*', '.*?') + r = r.replace(r'\?', '.') + + # handle [abc], [a-z] and [!a-z] style ranges. + r = GLOB_REGEX.sub( + lambda x: ( + '[%s%s]' % ( + x.group(1) and '^' or '', + x.group(2).replace(r'\\\-', '-') + ) + ), + r, + ) + if word_boundary: + r = r"\b%s\b" % (r,) + r = re.compile(r, flags=re.IGNORECASE) + + return r.search(value) else: - return True + r = r + "$" + r = re.compile(r, flags=re.IGNORECASE) + + return r.match(value) + elif word_boundary: + r = re.escape(glob) + r = r"\b%s\b" % (r,) + r = re.compile(r, flags=re.IGNORECASE) + + return r.search(value) + else: + return value.lower() == glob.lower() + +def _flatten_dict(d, prefix=[], result={}): + for key, value in d.items(): + if isinstance(value, basestring): + result[".".join(prefix + [key])] = value.lower() + elif hasattr(value, "items"): + _flatten_dict(value, prefix=(prefix+[key]), result=result) -def _value_for_dotted_key(dotted_key, event): - parts = dotted_key.split(".") - val = event - while len(parts) > 0: - if parts[0] not in val: - return None - val = val[parts[0]] - parts = parts[1:] - return val + return result diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py index 0cbd9fe08a..2272d66dc7 100644 --- a/synapse/rest/client/v1/push_rule.py +++ b/synapse/rest/client/v1/push_rule.py @@ -27,6 +27,7 @@ from synapse.push.rulekinds import ( PRIORITY_CLASS_MAP, PRIORITY_CLASS_INVERSE_MAP ) +import copy import simplejson as json @@ -126,7 +127,8 @@ class PushRuleRestServlet(ClientV1RestServlet): rule["actions"] = json.loads(rawrule["actions"]) ruleslist.append(rule) - ruleslist = baserules.list_with_base_rules(ruleslist, user) + # We're going to be mutating this a lot, so do a deep copy + ruleslist = copy.deepcopy(baserules.list_with_base_rules(ruleslist)) rules = {'global': {}, 'device': {}} @@ -140,6 +142,16 @@ class PushRuleRestServlet(ClientV1RestServlet): template_name = _priority_class_to_template_name(r['priority_class']) + # Remove internal stuff. + for c in r["conditions"]: + c.pop("_id", None) + + pattern_type = c.pop("pattern_type", None) + if pattern_type == "user_id": + c["pattern"] = user.to_string() + elif pattern_type == "user_localpart": + c["pattern"] = user.localpart + if r['priority_class'] > PRIORITY_CLASS_MAP['override']: # per-device rule profile_tag = _profile_tag_from_conditions(r["conditions"]) diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 999b710fbb..70cde0d04d 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -18,7 +18,7 @@ from twisted.internet import defer from synapse.api.errors import StoreError, Codes from ._base import SQLBaseStore -from synapse.util.caches.descriptors import cached, cachedInlineCallbacks +from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList class RegistrationStore(SQLBaseStore): @@ -256,10 +256,10 @@ class RegistrationStore(SQLBaseStore): defer.returnValue(res if res else False) @cachedInlineCallbacks() - def is_guest(self, user): + def is_guest(self, user_id): res = yield self._simple_select_one_onecol( table="users", - keyvalues={"name": user.to_string()}, + keyvalues={"name": user_id}, retcol="is_guest", allow_none=True, desc="is_guest", @@ -267,6 +267,26 @@ class RegistrationStore(SQLBaseStore): defer.returnValue(res if res else False) + @cachedList(cache=is_guest.cache, list_name="user_ids", num_args=1, + inlineCallbacks=True) + def are_guests(self, user_ids): + sql = "SELECT name, is_guest FROM users WHERE name IN (%s)" % ( + ",".join("?" for _ in user_ids), + ) + + rows = yield self._execute( + "are_guests", self.cursor_to_dict, sql, *user_ids + ) + + result = {user_id: False for user_id in user_ids} + + result.update({ + row["name"]: bool(row["is_guest"]) + for row in rows + }) + + defer.returnValue(result) + def _query_for_auth(self, txn, token): sql = ( "SELECT users.name, users.is_guest, access_tokens.id as token_id" diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 7d3ce4579d..68ac88905f 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -287,6 +287,7 @@ class RoomMemberStore(SQLBaseStore): txn.execute(sql, (user_id, room_id)) yield self.runInteraction("forget_membership", f) self.was_forgotten_at.invalidate_all() + self.who_forgot_in_room.invalidate_all() self.did_forget.invalidate((user_id, room_id)) @cachedInlineCallbacks(num_args=2) @@ -336,3 +337,15 @@ class RoomMemberStore(SQLBaseStore): return rows[0][0] forgot = yield self.runInteraction("did_forget_membership_at", f) defer.returnValue(forgot == 1) + + @cached() + def who_forgot_in_room(self, room_id): + return self._simple_select_list( + table="room_memberships", + retcols=("user_id", "event_id"), + keyvalues={ + "room_id": room_id, + "forgotten": 1, + }, + desc="who_forgot" + ) diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py deleted file mode 100644 index 11a3d94bb0..0000000000 --- a/tests/handlers/test_federation.py +++ /dev/null @@ -1,141 +0,0 @@ -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from twisted.internet import defer -from tests import unittest - -from synapse.api.constants import EventTypes -from synapse.events import FrozenEvent -from synapse.handlers.federation import FederationHandler - -from mock import NonCallableMock, ANY, Mock - -from ..utils import setup_test_homeserver - - -class FederationTestCase(unittest.TestCase): - - @defer.inlineCallbacks - def setUp(self): - - self.state_handler = NonCallableMock(spec_set=[ - "compute_event_context", - ]) - - self.auth = NonCallableMock(spec_set=[ - "check", - "check_host_in_room", - ]) - - self.hostname = "test" - hs = yield setup_test_homeserver( - self.hostname, - datastore=NonCallableMock(spec_set=[ - "persist_event", - "store_room", - "get_room", - "get_destination_retry_timings", - "set_destination_retry_timings", - "have_events", - "get_users_in_room", - "bulk_get_push_rules", - "get_current_state", - "set_push_actions_for_event_and_users", - "is_guest", - "get_state_for_events", - ]), - resource_for_federation=NonCallableMock(), - http_client=NonCallableMock(spec_set=[]), - notifier=NonCallableMock(spec_set=["on_new_room_event"]), - handlers=NonCallableMock(spec_set=[ - "room_member_handler", - "federation_handler", - ]), - auth=self.auth, - state_handler=self.state_handler, - keyring=Mock(), - ) - - self.datastore = hs.get_datastore() - self.handlers = hs.get_handlers() - self.notifier = hs.get_notifier() - self.hs = hs - - self.handlers.federation_handler = FederationHandler(self.hs) - - self.datastore.get_state_for_events.return_value = {"$a:b": {}} - - @defer.inlineCallbacks - def test_msg(self): - pdu = FrozenEvent({ - "type": EventTypes.Message, - "room_id": "foo", - "content": {"msgtype": u"fooo"}, - "origin_server_ts": 0, - "event_id": "$a:b", - "user_id":"@a:b", - "origin": "b", - "auth_events": [], - "hashes": {"sha256":"AcLrgtUIqqwaGoHhrEvYG1YLDIsVPYJdSRGhkp3jJp8"}, - }) - - self.datastore.persist_event.return_value = defer.succeed((1,1)) - self.datastore.get_room.return_value = defer.succeed(True) - self.datastore.get_users_in_room.return_value = ["@a:b"] - self.datastore.bulk_get_push_rules.return_value = {} - self.datastore.get_current_state.return_value = {} - self.auth.check_host_in_room.return_value = defer.succeed(True) - - retry_timings_res = { - "destination": "", - "retry_last_ts": 0, - "retry_interval": 0, - } - self.datastore.get_destination_retry_timings.return_value = ( - defer.succeed(retry_timings_res) - ) - - def have_events(event_ids): - return defer.succeed({}) - self.datastore.have_events.side_effect = have_events - - def annotate(ev, old_state=None, outlier=False): - context = Mock() - context.current_state = {} - context.auth_events = {} - return defer.succeed(context) - self.state_handler.compute_event_context.side_effect = annotate - - yield self.handlers.federation_handler.on_receive_pdu( - "fo", pdu, False - ) - - self.datastore.persist_event.assert_called_once_with( - ANY, - is_new_state=True, - backfilled=False, - current_state=None, - context=ANY, - ) - - self.state_handler.compute_event_context.assert_called_once_with( - ANY, old_state=None, outlier=False - ) - - self.auth.check.assert_called_once_with(ANY, auth_events={}) - - self.notifier.on_new_room_event.assert_called_once_with( - ANY, 1, 1, extra_users=[] - ) diff --git a/tests/handlers/test_room.py b/tests/handlers/test_room.py deleted file mode 100644 index e7a12a2ba2..0000000000 --- a/tests/handlers/test_room.py +++ /dev/null @@ -1,418 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from twisted.internet import defer -from .. import unittest - -from synapse.api.constants import EventTypes, Membership -from synapse.handlers.room import RoomMemberHandler, RoomCreationHandler -from synapse.handlers.profile import ProfileHandler -from synapse.types import UserID -from ..utils import setup_test_homeserver - -from mock import Mock, NonCallableMock - - -class RoomMemberHandlerTestCase(unittest.TestCase): - - @defer.inlineCallbacks - def setUp(self): - self.hostname = "red" - hs = yield setup_test_homeserver( - self.hostname, - ratelimiter=NonCallableMock(spec_set=[ - "send_message", - ]), - datastore=NonCallableMock(spec_set=[ - "persist_event", - "get_room_member", - "get_room", - "store_room", - "get_latest_events_in_room", - "add_event_hashes", - "get_users_in_room", - "bulk_get_push_rules", - "get_current_state", - "set_push_actions_for_event_and_users", - "get_state_for_events", - "is_guest", - ]), - resource_for_federation=NonCallableMock(), - http_client=NonCallableMock(spec_set=[]), - notifier=NonCallableMock(spec_set=["on_new_room_event"]), - handlers=NonCallableMock(spec_set=[ - "room_member_handler", - "profile_handler", - "federation_handler", - ]), - auth=NonCallableMock(spec_set=[ - "check", - "add_auth_events", - "check_host_in_room", - ]), - state_handler=NonCallableMock(spec_set=[ - "compute_event_context", - "get_current_state", - ]), - ) - - self.federation = NonCallableMock(spec_set=[ - "handle_new_event", - "send_invite", - "get_state_for_room", - ]) - - self.datastore = hs.get_datastore() - self.handlers = hs.get_handlers() - self.notifier = hs.get_notifier() - self.state_handler = hs.get_state_handler() - self.distributor = hs.get_distributor() - self.auth = hs.get_auth() - self.hs = hs - - self.handlers.federation_handler = self.federation - - self.distributor.declare("collect_presencelike_data") - - self.handlers.room_member_handler = RoomMemberHandler(self.hs) - self.handlers.profile_handler = ProfileHandler(self.hs) - self.room_member_handler = self.handlers.room_member_handler - - self.ratelimiter = hs.get_ratelimiter() - self.ratelimiter.send_message.return_value = (True, 0) - - self.datastore.persist_event.return_value = (1,1) - self.datastore.add_event_hashes.return_value = [] - self.datastore.get_users_in_room.return_value = ["@bob:red"] - self.datastore.bulk_get_push_rules.return_value = {} - - @defer.inlineCallbacks - def test_invite(self): - room_id = "!foo:red" - user_id = "@bob:red" - target_user_id = "@red:blue" - content = {"membership": Membership.INVITE} - - builder = self.hs.get_event_builder_factory().new({ - "type": EventTypes.Member, - "sender": user_id, - "state_key": target_user_id, - "room_id": room_id, - "content": content, - }) - - self.datastore.get_latest_events_in_room.return_value = ( - defer.succeed([]) - ) - self.datastore.get_current_state.return_value = {} - self.datastore.get_state_for_events = lambda event_ids,types: {x: {} for x in event_ids} - - def annotate(_): - ctx = Mock() - ctx.current_state = { - (EventTypes.Member, "@alice:green"): self._create_member( - user_id="@alice:green", - room_id=room_id, - ), - (EventTypes.Member, "@bob:red"): self._create_member( - user_id="@bob:red", - room_id=room_id, - ), - } - ctx.prev_state_events = [] - - return defer.succeed(ctx) - - self.state_handler.compute_event_context.side_effect = annotate - - def add_auth(_, ctx): - ctx.auth_events = ctx.current_state[ - (EventTypes.Member, "@bob:red") - ] - - return defer.succeed(True) - self.auth.add_auth_events.side_effect = add_auth - - def send_invite(domain, event): - return defer.succeed(event) - - self.federation.send_invite.side_effect = send_invite - - room_handler = self.room_member_handler - event, context = yield room_handler._create_new_client_event( - builder - ) - - yield room_handler.send_membership_event(event, context) - - self.state_handler.compute_event_context.assert_called_once_with( - builder - ) - - self.auth.add_auth_events.assert_called_once_with( - builder, context - ) - - self.federation.send_invite.assert_called_once_with( - "blue", event, - ) - - self.datastore.persist_event.assert_called_once_with( - event, context=context, - ) - self.notifier.on_new_room_event.assert_called_once_with( - event, 1, 1, extra_users=[UserID.from_string(target_user_id)] - ) - self.assertFalse(self.datastore.get_room.called) - self.assertFalse(self.datastore.store_room.called) - self.assertFalse(self.federation.get_state_for_room.called) - - @defer.inlineCallbacks - def test_simple_join(self): - room_id = "!foo:red" - user_id = "@bob:red" - user = UserID.from_string(user_id) - - join_signal_observer = Mock() - self.distributor.observe("user_joined_room", join_signal_observer) - - builder = self.hs.get_event_builder_factory().new({ - "type": EventTypes.Member, - "sender": user_id, - "state_key": user_id, - "room_id": room_id, - "content": {"membership": Membership.JOIN}, - }) - - self.datastore.get_latest_events_in_room.return_value = ( - defer.succeed([]) - ) - self.datastore.get_current_state.return_value = {} - self.datastore.get_state_for_events = lambda event_ids,types: {x: {} for x in event_ids} - - def annotate(_): - ctx = Mock() - ctx.current_state = { - (EventTypes.Member, "@bob:red"): self._create_member( - user_id="@bob:red", - room_id=room_id, - membership=Membership.INVITE - ), - } - ctx.prev_state_events = [] - - return defer.succeed(ctx) - - self.state_handler.compute_event_context.side_effect = annotate - - def add_auth(_, ctx): - ctx.auth_events = ctx.current_state[ - (EventTypes.Member, "@bob:red") - ] - - return defer.succeed(True) - self.auth.add_auth_events.side_effect = add_auth - - room_handler = self.room_member_handler - event, context = yield room_handler._create_new_client_event( - builder - ) - - # Actual invocation - yield room_handler.send_membership_event(event, context) - - self.federation.handle_new_event.assert_called_once_with( - event, destinations=set() - ) - - self.datastore.persist_event.assert_called_once_with( - event, context=context - ) - self.notifier.on_new_room_event.assert_called_once_with( - event, 1, 1, extra_users=[user] - ) - - join_signal_observer.assert_called_with( - user=user, room_id=room_id - ) - - def _create_member(self, user_id, room_id, membership=Membership.JOIN): - builder = self.hs.get_event_builder_factory().new({ - "type": EventTypes.Member, - "sender": user_id, - "state_key": user_id, - "room_id": room_id, - "content": {"membership": membership}, - }) - - return builder.build() - - @defer.inlineCallbacks - def test_simple_leave(self): - room_id = "!foo:red" - user_id = "@bob:red" - user = UserID.from_string(user_id) - - builder = self.hs.get_event_builder_factory().new({ - "type": EventTypes.Member, - "sender": user_id, - "state_key": user_id, - "room_id": room_id, - "content": {"membership": Membership.LEAVE}, - }) - - self.datastore.get_latest_events_in_room.return_value = ( - defer.succeed([]) - ) - self.datastore.get_current_state.return_value = {} - self.datastore.get_state_for_events = lambda event_ids,types: {x: {} for x in event_ids} - - def annotate(_): - ctx = Mock() - ctx.current_state = { - (EventTypes.Member, "@bob:red"): self._create_member( - user_id="@bob:red", - room_id=room_id, - membership=Membership.JOIN - ), - } - ctx.prev_state_events = [] - - return defer.succeed(ctx) - - self.state_handler.compute_event_context.side_effect = annotate - - def add_auth(_, ctx): - ctx.auth_events = ctx.current_state[ - (EventTypes.Member, "@bob:red") - ] - - return defer.succeed(True) - self.auth.add_auth_events.side_effect = add_auth - - room_handler = self.room_member_handler - event, context = yield room_handler._create_new_client_event( - builder - ) - - leave_signal_observer = Mock() - self.distributor.observe("user_left_room", leave_signal_observer) - - # Actual invocation - yield room_handler.send_membership_event(event, context) - - self.federation.handle_new_event.assert_called_once_with( - event, destinations=set(['red']) - ) - - self.datastore.persist_event.assert_called_once_with( - event, context=context - ) - self.notifier.on_new_room_event.assert_called_once_with( - event, 1, 1, extra_users=[user] - ) - - leave_signal_observer.assert_called_with( - user=user, room_id=room_id - ) - - -class RoomCreationTest(unittest.TestCase): - - @defer.inlineCallbacks - def setUp(self): - self.hostname = "red" - - hs = yield setup_test_homeserver( - self.hostname, - datastore=NonCallableMock(spec_set=[ - "store_room", - "snapshot_room", - "persist_event", - "get_joined_hosts_for_room", - ]), - http_client=NonCallableMock(spec_set=[]), - notifier=NonCallableMock(spec_set=["on_new_room_event"]), - handlers=NonCallableMock(spec_set=[ - "room_creation_handler", - "message_handler", - ]), - auth=NonCallableMock(spec_set=["check", "add_auth_events"]), - ratelimiter=NonCallableMock(spec_set=[ - "send_message", - ]), - ) - - self.federation = NonCallableMock(spec_set=[ - "handle_new_event", - ]) - - self.handlers = hs.get_handlers() - - self.handlers.room_creation_handler = RoomCreationHandler(hs) - self.room_creation_handler = self.handlers.room_creation_handler - - self.message_handler = self.handlers.message_handler - - self.ratelimiter = hs.get_ratelimiter() - self.ratelimiter.send_message.return_value = (True, 0) - - @defer.inlineCallbacks - def test_room_creation(self): - user_id = "@foo:red" - room_id = "!bobs_room:red" - config = {"visibility": "private"} - - yield self.room_creation_handler.create_room( - user_id=user_id, - room_id=room_id, - config=config, - ) - - self.assertTrue(self.message_handler.create_and_send_event.called) - - event_dicts = [ - e[0][0] - for e in self.message_handler.create_and_send_event.call_args_list - ] - - self.assertTrue(len(event_dicts) > 3) - - self.assertDictContainsSubset( - { - "type": EventTypes.Create, - "sender": user_id, - "room_id": room_id, - }, - event_dicts[0] - ) - - self.assertEqual(user_id, event_dicts[0]["content"]["creator"]) - - self.assertDictContainsSubset( - { - "type": EventTypes.Member, - "sender": user_id, - "room_id": room_id, - "state_key": user_id, - }, - event_dicts[1] - ) - - self.assertEqual( - Membership.JOIN, - event_dicts[1]["content"]["membership"] - ) |