diff options
Diffstat (limited to 'synapse/events')
-rw-r--r-- | synapse/events/__init__.py | 42 | ||||
-rw-r--r-- | synapse/events/builder.py | 37 | ||||
-rw-r--r-- | synapse/events/snapshot.py | 49 | ||||
-rw-r--r-- | synapse/events/spamcheck.py | 4 | ||||
-rw-r--r-- | synapse/events/third_party_rules.py | 59 | ||||
-rw-r--r-- | synapse/events/utils.py | 58 | ||||
-rw-r--r-- | synapse/events/validator.py | 21 |
7 files changed, 163 insertions, 107 deletions
diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index 1edd19cc13..d3de70e671 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -92,6 +92,18 @@ class _EventInternalMetadata(object): """ return getattr(self, "soft_failed", False) + def should_proactively_send(self): + """Whether the event, if ours, should be sent to other clients and + servers. + + This is used for sending dummy events internally. Servers and clients + can still explicitly fetch the event. + + Returns: + bool + """ + return getattr(self, "proactively_send", True) + def _event_dict_property(key): # We want to be able to use hasattr with the event dict properties. @@ -115,25 +127,25 @@ def _event_dict_property(key): except KeyError: raise AttributeError(key) - return property( - getter, - setter, - delete, - ) + return property(getter, setter, delete) class EventBase(object): - def __init__(self, event_dict, signatures={}, unsigned={}, - internal_metadata_dict={}, rejected_reason=None): + def __init__( + self, + event_dict, + signatures={}, + unsigned={}, + internal_metadata_dict={}, + rejected_reason=None, + ): self.signatures = signatures self.unsigned = unsigned self.rejected_reason = rejected_reason self._event_dict = event_dict - self.internal_metadata = _EventInternalMetadata( - internal_metadata_dict - ) + self.internal_metadata = _EventInternalMetadata(internal_metadata_dict) auth_events = _event_dict_property("auth_events") depth = _event_dict_property("depth") @@ -156,10 +168,7 @@ class EventBase(object): def get_dict(self): d = dict(self._event_dict) - d.update({ - "signatures": self.signatures, - "unsigned": dict(self.unsigned), - }) + d.update({"signatures": self.signatures, "unsigned": dict(self.unsigned)}) return d @@ -346,6 +355,7 @@ class FrozenEventV2(EventBase): class FrozenEventV3(FrozenEventV2): """FrozenEventV3, which differs from FrozenEventV2 only in the event_id format""" + format_version = EventFormatVersions.V3 # All events of this type are V3 @property @@ -402,6 +412,4 @@ def event_type_from_format_version(format_version): elif format_version == EventFormatVersions.V3: return FrozenEventV3 else: - raise Exception( - "No event format %r" % (format_version,) - ) + raise Exception("No event format %r" % (format_version,)) diff --git a/synapse/events/builder.py b/synapse/events/builder.py index 546b6f4982..db011e0407 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py @@ -78,7 +78,9 @@ class EventBuilder(object): _redacts = attr.ib(default=None) _origin_server_ts = attr.ib(default=None) - internal_metadata = attr.ib(default=attr.Factory(lambda: _EventInternalMetadata({}))) + internal_metadata = attr.ib( + default=attr.Factory(lambda: _EventInternalMetadata({})) + ) @property def state_key(self): @@ -102,11 +104,9 @@ class EventBuilder(object): """ state_ids = yield self._state.get_current_state_ids( - self.room_id, prev_event_ids, - ) - auth_ids = yield self._auth.compute_auth_events( - self, state_ids, + self.room_id, prev_event_ids ) + auth_ids = yield self._auth.compute_auth_events(self, state_ids) if self.format_version == EventFormatVersions.V1: auth_events = yield self._store.add_event_hashes(auth_ids) @@ -115,9 +115,7 @@ class EventBuilder(object): auth_events = auth_ids prev_events = prev_event_ids - old_depth = yield self._store.get_max_depth_of( - prev_event_ids, - ) + old_depth = yield self._store.get_max_depth_of(prev_event_ids) depth = old_depth + 1 # we cap depth of generated events, to ensure that they are not @@ -217,9 +215,14 @@ class EventBuilderFactory(object): ) -def create_local_event_from_event_dict(clock, hostname, signing_key, - format_version, event_dict, - internal_metadata_dict=None): +def create_local_event_from_event_dict( + clock, + hostname, + signing_key, + format_version, + event_dict, + internal_metadata_dict=None, +): """Takes a fully formed event dict, ensuring that fields like `origin` and `origin_server_ts` have correct values for a locally produced event, then signs and hashes it. @@ -237,9 +240,7 @@ def create_local_event_from_event_dict(clock, hostname, signing_key, """ if format_version not in KNOWN_EVENT_FORMAT_VERSIONS: - raise Exception( - "No event format defined for version %r" % (format_version,) - ) + raise Exception("No event format defined for version %r" % (format_version,)) if internal_metadata_dict is None: internal_metadata_dict = {} @@ -258,13 +259,9 @@ def create_local_event_from_event_dict(clock, hostname, signing_key, event_dict.setdefault("signatures", {}) - add_hashes_and_signatures( - event_dict, - hostname, - signing_key, - ) + add_hashes_and_signatures(event_dict, hostname, signing_key) return event_type_from_format_version(format_version)( - event_dict, internal_metadata_dict=internal_metadata_dict, + event_dict, internal_metadata_dict=internal_metadata_dict ) diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index fa09c132a0..a96cdada3d 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -88,8 +88,9 @@ class EventContext(object): self.app_service = None @staticmethod - def with_state(state_group, current_state_ids, prev_state_ids, - prev_group=None, delta_ids=None): + def with_state( + state_group, current_state_ids, prev_state_ids, prev_group=None, delta_ids=None + ): context = EventContext() # The current state including the current event @@ -132,17 +133,19 @@ class EventContext(object): else: prev_state_id = None - defer.returnValue({ - "prev_state_id": prev_state_id, - "event_type": event.type, - "event_state_key": event.state_key if event.is_state() else None, - "state_group": self.state_group, - "rejected": self.rejected, - "prev_group": self.prev_group, - "delta_ids": _encode_state_dict(self.delta_ids), - "prev_state_events": self.prev_state_events, - "app_service_id": self.app_service.id if self.app_service else None - }) + defer.returnValue( + { + "prev_state_id": prev_state_id, + "event_type": event.type, + "event_state_key": event.state_key if event.is_state() else None, + "state_group": self.state_group, + "rejected": self.rejected, + "prev_group": self.prev_group, + "delta_ids": _encode_state_dict(self.delta_ids), + "prev_state_events": self.prev_state_events, + "app_service_id": self.app_service.id if self.app_service else None, + } + ) @staticmethod def deserialize(store, input): @@ -194,7 +197,7 @@ class EventContext(object): if not self._fetching_state_deferred: self._fetching_state_deferred = run_in_background( - self._fill_out_state, store, + self._fill_out_state, store ) yield make_deferred_yieldable(self._fetching_state_deferred) @@ -214,7 +217,7 @@ class EventContext(object): if not self._fetching_state_deferred: self._fetching_state_deferred = run_in_background( - self._fill_out_state, store, + self._fill_out_state, store ) yield make_deferred_yieldable(self._fetching_state_deferred) @@ -240,9 +243,7 @@ class EventContext(object): if self.state_group is None: return - self._current_state_ids = yield store.get_state_ids_for_group( - self.state_group, - ) + self._current_state_ids = yield store.get_state_ids_for_group(self.state_group) if self._prev_state_id and self._event_state_key is not None: self._prev_state_ids = dict(self._current_state_ids) @@ -252,8 +253,9 @@ class EventContext(object): self._prev_state_ids = self._current_state_ids @defer.inlineCallbacks - def update_state(self, state_group, prev_state_ids, current_state_ids, - prev_group, delta_ids): + def update_state( + self, state_group, prev_state_ids, current_state_ids, prev_group, delta_ids + ): """Replace the state in the context """ @@ -279,10 +281,7 @@ def _encode_state_dict(state_dict): if state_dict is None: return None - return [ - (etype, state_key, v) - for (etype, state_key), v in iteritems(state_dict) - ] + return [(etype, state_key, v) for (etype, state_key), v in iteritems(state_dict)] def _decode_state_dict(input): @@ -291,4 +290,4 @@ def _decode_state_dict(input): if input is None: return None - return frozendict({(etype, state_key,): v for etype, state_key, v in input}) + return frozendict({(etype, state_key): v for etype, state_key, v in input}) diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py index 6058077f75..129771f183 100644 --- a/synapse/events/spamcheck.py +++ b/synapse/events/spamcheck.py @@ -60,7 +60,9 @@ class SpamChecker(object): if self.spam_checker is None: return True - return self.spam_checker.user_may_invite(inviter_userid, invitee_userid, room_id) + return self.spam_checker.user_may_invite( + inviter_userid, invitee_userid, room_id + ) def user_may_create_room(self, userid): """Checks if a given user may create a room diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py index 9f98d51523..8f5d95696b 100644 --- a/synapse/events/third_party_rules.py +++ b/synapse/events/third_party_rules.py @@ -17,8 +17,8 @@ from twisted.internet import defer class ThirdPartyEventRules(object): - """Allows server admins to provide a Python module implementing an extra set of rules - to apply when processing events. + """Allows server admins to provide a Python module implementing an extra + set of rules to apply when processing events. This is designed to help admins of closed federations with enforcing custom behaviours. @@ -35,7 +35,9 @@ class ThirdPartyEventRules(object): module, config = hs.config.third_party_event_rules if module is not None: - self.third_party_rules = module(config=config) + self.third_party_rules = module( + config=config, http_client=hs.get_simple_http_client() + ) @defer.inlineCallbacks def check_event_allowed(self, event, context): @@ -46,7 +48,7 @@ class ThirdPartyEventRules(object): context (synapse.events.snapshot.EventContext): The context of the event. Returns: - defer.Deferred(bool), True if the event should be allowed, False if not. + defer.Deferred[bool]: True if the event should be allowed, False if not. """ if self.third_party_rules is None: defer.returnValue(True) @@ -60,3 +62,52 @@ class ThirdPartyEventRules(object): ret = yield self.third_party_rules.check_event_allowed(event, state_events) defer.returnValue(ret) + + @defer.inlineCallbacks + def on_create_room(self, requester, config, is_requester_admin): + """Intercept requests to create room to allow, deny or update the + request config. + + Args: + requester (Requester) + config (dict): The creation config from the client. + is_requester_admin (bool): If the requester is an admin + + Returns: + defer.Deferred + """ + + if self.third_party_rules is None: + return + + yield self.third_party_rules.on_create_room( + requester, config, is_requester_admin + ) + + @defer.inlineCallbacks + def check_threepid_can_be_invited(self, medium, address, room_id): + """Check if a provided 3PID can be invited in the given room. + + Args: + medium (str): The 3PID's medium. + address (str): The 3PID's address. + room_id (str): The room we want to invite the threepid to. + + Returns: + defer.Deferred[bool], True if the 3PID can be invited, False if not. + """ + + if self.third_party_rules is None: + defer.returnValue(True) + + state_ids = yield self.store.get_filtered_current_state_ids(room_id) + room_state_events = yield self.store.get_events(state_ids.values()) + + state_events = {} + for key, event_id in state_ids.items(): + state_events[key] = room_state_events[event_id] + + ret = yield self.third_party_rules.check_threepid_can_be_invited( + medium, address, state_events + ) + defer.returnValue(ret) diff --git a/synapse/events/utils.py b/synapse/events/utils.py index e2d4384de1..f24f0c16f0 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -31,7 +31,7 @@ from . import EventBase # by a match for 'stuff'. # TODO: This is fast, but fails to handle "foo\\.bar" which should be treated as # the literal fields "foo\" and "bar" but will instead be treated as "foo\\.bar" -SPLIT_FIELD_REGEX = re.compile(r'(?<!\\)\.') +SPLIT_FIELD_REGEX = re.compile(r"(?<!\\)\.") def prune_event(event): @@ -51,6 +51,7 @@ def prune_event(event): pruned_event_dict = prune_event_dict(event.get_dict()) from . import event_type_from_format_version + return event_type_from_format_version(event.format_version)( pruned_event_dict, event.internal_metadata.get_dict() ) @@ -116,11 +117,7 @@ def prune_event_dict(event_dict): elif event_type == EventTypes.RoomHistoryVisibility: add_fields("history_visibility") - allowed_fields = { - k: v - for k, v in event_dict.items() - if k in allowed_keys - } + allowed_fields = {k: v for k, v in event_dict.items() if k in allowed_keys} allowed_fields["content"] = new_content @@ -205,7 +202,7 @@ def only_fields(dictionary, fields): # for each element of the output array of arrays: # remove escaping so we can use the right key names. split_fields[:] = [ - [f.replace(r'\.', r'.') for f in field_array] for field_array in split_fields + [f.replace(r"\.", r".") for f in field_array] for field_array in split_fields ] output = {} @@ -226,7 +223,10 @@ def format_event_for_client_v1(d): d["user_id"] = sender copy_keys = ( - "age", "redacted_because", "replaces_state", "prev_content", + "age", + "redacted_because", + "replaces_state", + "prev_content", "invite_room_state", ) for key in copy_keys: @@ -238,8 +238,13 @@ def format_event_for_client_v1(d): def format_event_for_client_v2(d): drop_keys = ( - "auth_events", "prev_events", "hashes", "signatures", "depth", - "origin", "prev_state", + "auth_events", + "prev_events", + "hashes", + "signatures", + "depth", + "origin", + "prev_state", ) for key in drop_keys: d.pop(key, None) @@ -252,9 +257,15 @@ def format_event_for_client_v2_without_room_id(d): return d -def serialize_event(e, time_now_ms, as_client_event=True, - event_format=format_event_for_client_v1, - token_id=None, only_event_fields=None, is_invite=False): +def serialize_event( + e, + time_now_ms, + as_client_event=True, + event_format=format_event_for_client_v1, + token_id=None, + only_event_fields=None, + is_invite=False, +): """Serialize event for clients Args: @@ -288,8 +299,7 @@ def serialize_event(e, time_now_ms, as_client_event=True, if "redacted_because" in e.unsigned: d["unsigned"]["redacted_because"] = serialize_event( - e.unsigned["redacted_because"], time_now_ms, - event_format=event_format + e.unsigned["redacted_because"], time_now_ms, event_format=event_format ) if token_id is not None: @@ -308,8 +318,9 @@ def serialize_event(e, time_now_ms, as_client_event=True, d = event_format(d) if only_event_fields: - if (not isinstance(only_event_fields, list) or - not all(isinstance(f, string_types) for f in only_event_fields)): + if not isinstance(only_event_fields, list) or not all( + isinstance(f, string_types) for f in only_event_fields + ): raise TypeError("only_event_fields must be a list of strings") d = only_fields(d, only_event_fields) @@ -352,11 +363,9 @@ class EventClientSerializer(object): # If MSC1849 is enabled then we need to look if thre are any relations # we need to bundle in with the event if self.experimental_msc1849_support_enabled and bundle_aggregations: - annotations = yield self.store.get_aggregation_groups_for_event( - event_id, - ) + annotations = yield self.store.get_aggregation_groups_for_event(event_id) references = yield self.store.get_relations_for_event( - event_id, RelationTypes.REFERENCE, direction="f", + event_id, RelationTypes.REFERENCE, direction="f" ) if annotations.chunk: @@ -383,9 +392,7 @@ class EventClientSerializer(object): serialized_event["content"].pop("m.relates_to", None) r = serialized_event["unsigned"].setdefault("m.relations", {}) - r[RelationTypes.REPLACE] = { - "event_id": edit.event_id, - } + r[RelationTypes.REPLACE] = {"event_id": edit.event_id} defer.returnValue(serialized_event) @@ -401,6 +408,5 @@ class EventClientSerializer(object): Deferred[list[dict]]: The list of serialized events """ return yieldable_gather_results( - self.serialize_event, events, - time_now=time_now, **kwargs + self.serialize_event, events, time_now=time_now, **kwargs ) diff --git a/synapse/events/validator.py b/synapse/events/validator.py index 711af512b2..f7ffd1d561 100644 --- a/synapse/events/validator.py +++ b/synapse/events/validator.py @@ -48,9 +48,7 @@ class EventValidator(object): raise SynapseError(400, "Event does not have key %s" % (k,)) # Check that the following keys have string values - event_strings = [ - "origin", - ] + event_strings = ["origin"] for s in event_strings: if not isinstance(getattr(event, s), string_types): @@ -62,8 +60,10 @@ class EventValidator(object): if len(alias) > MAX_ALIAS_LENGTH: raise SynapseError( 400, - ("Can't create aliases longer than" - " %d characters" % (MAX_ALIAS_LENGTH,)), + ( + "Can't create aliases longer than" + " %d characters" % (MAX_ALIAS_LENGTH,) + ), Codes.INVALID_PARAM, ) @@ -76,11 +76,7 @@ class EventValidator(object): event (EventBuilder|FrozenEvent) """ - strings = [ - "room_id", - "sender", - "type", - ] + strings = ["room_id", "sender", "type"] if hasattr(event, "state_key"): strings.append("state_key") @@ -93,10 +89,7 @@ class EventValidator(object): UserID.from_string(event.sender) if event.type == EventTypes.Message: - strings = [ - "body", - "msgtype", - ] + strings = ["body", "msgtype"] self._ensure_strings(event.content, strings) |