diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py
index 1edd19cc13..d3de70e671 100644
--- a/synapse/events/__init__.py
+++ b/synapse/events/__init__.py
@@ -92,6 +92,18 @@ class _EventInternalMetadata(object):
"""
return getattr(self, "soft_failed", False)
+ def should_proactively_send(self):
+ """Whether the event, if ours, should be sent to other clients and
+ servers.
+
+ This is used for sending dummy events internally. Servers and clients
+ can still explicitly fetch the event.
+
+ Returns:
+ bool
+ """
+ return getattr(self, "proactively_send", True)
+
def _event_dict_property(key):
# We want to be able to use hasattr with the event dict properties.
@@ -115,25 +127,25 @@ def _event_dict_property(key):
except KeyError:
raise AttributeError(key)
- return property(
- getter,
- setter,
- delete,
- )
+ return property(getter, setter, delete)
class EventBase(object):
- def __init__(self, event_dict, signatures={}, unsigned={},
- internal_metadata_dict={}, rejected_reason=None):
+ def __init__(
+ self,
+ event_dict,
+ signatures={},
+ unsigned={},
+ internal_metadata_dict={},
+ rejected_reason=None,
+ ):
self.signatures = signatures
self.unsigned = unsigned
self.rejected_reason = rejected_reason
self._event_dict = event_dict
- self.internal_metadata = _EventInternalMetadata(
- internal_metadata_dict
- )
+ self.internal_metadata = _EventInternalMetadata(internal_metadata_dict)
auth_events = _event_dict_property("auth_events")
depth = _event_dict_property("depth")
@@ -156,10 +168,7 @@ class EventBase(object):
def get_dict(self):
d = dict(self._event_dict)
- d.update({
- "signatures": self.signatures,
- "unsigned": dict(self.unsigned),
- })
+ d.update({"signatures": self.signatures, "unsigned": dict(self.unsigned)})
return d
@@ -346,6 +355,7 @@ class FrozenEventV2(EventBase):
class FrozenEventV3(FrozenEventV2):
"""FrozenEventV3, which differs from FrozenEventV2 only in the event_id format"""
+
format_version = EventFormatVersions.V3 # All events of this type are V3
@property
@@ -402,6 +412,4 @@ def event_type_from_format_version(format_version):
elif format_version == EventFormatVersions.V3:
return FrozenEventV3
else:
- raise Exception(
- "No event format %r" % (format_version,)
- )
+ raise Exception("No event format %r" % (format_version,))
diff --git a/synapse/events/builder.py b/synapse/events/builder.py
index 546b6f4982..db011e0407 100644
--- a/synapse/events/builder.py
+++ b/synapse/events/builder.py
@@ -78,7 +78,9 @@ class EventBuilder(object):
_redacts = attr.ib(default=None)
_origin_server_ts = attr.ib(default=None)
- internal_metadata = attr.ib(default=attr.Factory(lambda: _EventInternalMetadata({})))
+ internal_metadata = attr.ib(
+ default=attr.Factory(lambda: _EventInternalMetadata({}))
+ )
@property
def state_key(self):
@@ -102,11 +104,9 @@ class EventBuilder(object):
"""
state_ids = yield self._state.get_current_state_ids(
- self.room_id, prev_event_ids,
- )
- auth_ids = yield self._auth.compute_auth_events(
- self, state_ids,
+ self.room_id, prev_event_ids
)
+ auth_ids = yield self._auth.compute_auth_events(self, state_ids)
if self.format_version == EventFormatVersions.V1:
auth_events = yield self._store.add_event_hashes(auth_ids)
@@ -115,9 +115,7 @@ class EventBuilder(object):
auth_events = auth_ids
prev_events = prev_event_ids
- old_depth = yield self._store.get_max_depth_of(
- prev_event_ids,
- )
+ old_depth = yield self._store.get_max_depth_of(prev_event_ids)
depth = old_depth + 1
# we cap depth of generated events, to ensure that they are not
@@ -217,9 +215,14 @@ class EventBuilderFactory(object):
)
-def create_local_event_from_event_dict(clock, hostname, signing_key,
- format_version, event_dict,
- internal_metadata_dict=None):
+def create_local_event_from_event_dict(
+ clock,
+ hostname,
+ signing_key,
+ format_version,
+ event_dict,
+ internal_metadata_dict=None,
+):
"""Takes a fully formed event dict, ensuring that fields like `origin`
and `origin_server_ts` have correct values for a locally produced event,
then signs and hashes it.
@@ -237,9 +240,7 @@ def create_local_event_from_event_dict(clock, hostname, signing_key,
"""
if format_version not in KNOWN_EVENT_FORMAT_VERSIONS:
- raise Exception(
- "No event format defined for version %r" % (format_version,)
- )
+ raise Exception("No event format defined for version %r" % (format_version,))
if internal_metadata_dict is None:
internal_metadata_dict = {}
@@ -258,13 +259,9 @@ def create_local_event_from_event_dict(clock, hostname, signing_key,
event_dict.setdefault("signatures", {})
- add_hashes_and_signatures(
- event_dict,
- hostname,
- signing_key,
- )
+ add_hashes_and_signatures(event_dict, hostname, signing_key)
return event_type_from_format_version(format_version)(
- event_dict, internal_metadata_dict=internal_metadata_dict,
+ event_dict, internal_metadata_dict=internal_metadata_dict
)
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index fa09c132a0..a96cdada3d 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -88,8 +88,9 @@ class EventContext(object):
self.app_service = None
@staticmethod
- def with_state(state_group, current_state_ids, prev_state_ids,
- prev_group=None, delta_ids=None):
+ def with_state(
+ state_group, current_state_ids, prev_state_ids, prev_group=None, delta_ids=None
+ ):
context = EventContext()
# The current state including the current event
@@ -132,17 +133,19 @@ class EventContext(object):
else:
prev_state_id = None
- defer.returnValue({
- "prev_state_id": prev_state_id,
- "event_type": event.type,
- "event_state_key": event.state_key if event.is_state() else None,
- "state_group": self.state_group,
- "rejected": self.rejected,
- "prev_group": self.prev_group,
- "delta_ids": _encode_state_dict(self.delta_ids),
- "prev_state_events": self.prev_state_events,
- "app_service_id": self.app_service.id if self.app_service else None
- })
+ defer.returnValue(
+ {
+ "prev_state_id": prev_state_id,
+ "event_type": event.type,
+ "event_state_key": event.state_key if event.is_state() else None,
+ "state_group": self.state_group,
+ "rejected": self.rejected,
+ "prev_group": self.prev_group,
+ "delta_ids": _encode_state_dict(self.delta_ids),
+ "prev_state_events": self.prev_state_events,
+ "app_service_id": self.app_service.id if self.app_service else None,
+ }
+ )
@staticmethod
def deserialize(store, input):
@@ -194,7 +197,7 @@ class EventContext(object):
if not self._fetching_state_deferred:
self._fetching_state_deferred = run_in_background(
- self._fill_out_state, store,
+ self._fill_out_state, store
)
yield make_deferred_yieldable(self._fetching_state_deferred)
@@ -214,7 +217,7 @@ class EventContext(object):
if not self._fetching_state_deferred:
self._fetching_state_deferred = run_in_background(
- self._fill_out_state, store,
+ self._fill_out_state, store
)
yield make_deferred_yieldable(self._fetching_state_deferred)
@@ -240,9 +243,7 @@ class EventContext(object):
if self.state_group is None:
return
- self._current_state_ids = yield store.get_state_ids_for_group(
- self.state_group,
- )
+ self._current_state_ids = yield store.get_state_ids_for_group(self.state_group)
if self._prev_state_id and self._event_state_key is not None:
self._prev_state_ids = dict(self._current_state_ids)
@@ -252,8 +253,9 @@ class EventContext(object):
self._prev_state_ids = self._current_state_ids
@defer.inlineCallbacks
- def update_state(self, state_group, prev_state_ids, current_state_ids,
- prev_group, delta_ids):
+ def update_state(
+ self, state_group, prev_state_ids, current_state_ids, prev_group, delta_ids
+ ):
"""Replace the state in the context
"""
@@ -279,10 +281,7 @@ def _encode_state_dict(state_dict):
if state_dict is None:
return None
- return [
- (etype, state_key, v)
- for (etype, state_key), v in iteritems(state_dict)
- ]
+ return [(etype, state_key, v) for (etype, state_key), v in iteritems(state_dict)]
def _decode_state_dict(input):
@@ -291,4 +290,4 @@ def _decode_state_dict(input):
if input is None:
return None
- return frozendict({(etype, state_key,): v for etype, state_key, v in input})
+ return frozendict({(etype, state_key): v for etype, state_key, v in input})
diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py
index 6058077f75..129771f183 100644
--- a/synapse/events/spamcheck.py
+++ b/synapse/events/spamcheck.py
@@ -60,7 +60,9 @@ class SpamChecker(object):
if self.spam_checker is None:
return True
- return self.spam_checker.user_may_invite(inviter_userid, invitee_userid, room_id)
+ return self.spam_checker.user_may_invite(
+ inviter_userid, invitee_userid, room_id
+ )
def user_may_create_room(self, userid):
"""Checks if a given user may create a room
diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py
new file mode 100644
index 0000000000..8f5d95696b
--- /dev/null
+++ b/synapse/events/third_party_rules.py
@@ -0,0 +1,113 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+
+class ThirdPartyEventRules(object):
+ """Allows server admins to provide a Python module implementing an extra
+ set of rules to apply when processing events.
+
+ This is designed to help admins of closed federations with enforcing custom
+ behaviours.
+ """
+
+ def __init__(self, hs):
+ self.third_party_rules = None
+
+ self.store = hs.get_datastore()
+
+ module = None
+ config = None
+ if hs.config.third_party_event_rules:
+ module, config = hs.config.third_party_event_rules
+
+ if module is not None:
+ self.third_party_rules = module(
+ config=config, http_client=hs.get_simple_http_client()
+ )
+
+ @defer.inlineCallbacks
+ def check_event_allowed(self, event, context):
+ """Check if a provided event should be allowed in the given context.
+
+ Args:
+ event (synapse.events.EventBase): The event to be checked.
+ context (synapse.events.snapshot.EventContext): The context of the event.
+
+ Returns:
+ defer.Deferred[bool]: True if the event should be allowed, False if not.
+ """
+ if self.third_party_rules is None:
+ defer.returnValue(True)
+
+ prev_state_ids = yield context.get_prev_state_ids(self.store)
+
+ # Retrieve the state events from the database.
+ state_events = {}
+ for key, event_id in prev_state_ids.items():
+ state_events[key] = yield self.store.get_event(event_id, allow_none=True)
+
+ ret = yield self.third_party_rules.check_event_allowed(event, state_events)
+ defer.returnValue(ret)
+
+ @defer.inlineCallbacks
+ def on_create_room(self, requester, config, is_requester_admin):
+ """Intercept requests to create room to allow, deny or update the
+ request config.
+
+ Args:
+ requester (Requester)
+ config (dict): The creation config from the client.
+ is_requester_admin (bool): If the requester is an admin
+
+ Returns:
+ defer.Deferred
+ """
+
+ if self.third_party_rules is None:
+ return
+
+ yield self.third_party_rules.on_create_room(
+ requester, config, is_requester_admin
+ )
+
+ @defer.inlineCallbacks
+ def check_threepid_can_be_invited(self, medium, address, room_id):
+ """Check if a provided 3PID can be invited in the given room.
+
+ Args:
+ medium (str): The 3PID's medium.
+ address (str): The 3PID's address.
+ room_id (str): The room we want to invite the threepid to.
+
+ Returns:
+ defer.Deferred[bool], True if the 3PID can be invited, False if not.
+ """
+
+ if self.third_party_rules is None:
+ defer.returnValue(True)
+
+ state_ids = yield self.store.get_filtered_current_state_ids(room_id)
+ room_state_events = yield self.store.get_events(state_ids.values())
+
+ state_events = {}
+ for key, event_id in state_ids.items():
+ state_events[key] = room_state_events[event_id]
+
+ ret = yield self.third_party_rules.check_threepid_can_be_invited(
+ medium, address, state_events
+ )
+ defer.returnValue(ret)
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index e2d4384de1..f24f0c16f0 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -31,7 +31,7 @@ from . import EventBase
# by a match for 'stuff'.
# TODO: This is fast, but fails to handle "foo\\.bar" which should be treated as
# the literal fields "foo\" and "bar" but will instead be treated as "foo\\.bar"
-SPLIT_FIELD_REGEX = re.compile(r'(?<!\\)\.')
+SPLIT_FIELD_REGEX = re.compile(r"(?<!\\)\.")
def prune_event(event):
@@ -51,6 +51,7 @@ def prune_event(event):
pruned_event_dict = prune_event_dict(event.get_dict())
from . import event_type_from_format_version
+
return event_type_from_format_version(event.format_version)(
pruned_event_dict, event.internal_metadata.get_dict()
)
@@ -116,11 +117,7 @@ def prune_event_dict(event_dict):
elif event_type == EventTypes.RoomHistoryVisibility:
add_fields("history_visibility")
- allowed_fields = {
- k: v
- for k, v in event_dict.items()
- if k in allowed_keys
- }
+ allowed_fields = {k: v for k, v in event_dict.items() if k in allowed_keys}
allowed_fields["content"] = new_content
@@ -205,7 +202,7 @@ def only_fields(dictionary, fields):
# for each element of the output array of arrays:
# remove escaping so we can use the right key names.
split_fields[:] = [
- [f.replace(r'\.', r'.') for f in field_array] for field_array in split_fields
+ [f.replace(r"\.", r".") for f in field_array] for field_array in split_fields
]
output = {}
@@ -226,7 +223,10 @@ def format_event_for_client_v1(d):
d["user_id"] = sender
copy_keys = (
- "age", "redacted_because", "replaces_state", "prev_content",
+ "age",
+ "redacted_because",
+ "replaces_state",
+ "prev_content",
"invite_room_state",
)
for key in copy_keys:
@@ -238,8 +238,13 @@ def format_event_for_client_v1(d):
def format_event_for_client_v2(d):
drop_keys = (
- "auth_events", "prev_events", "hashes", "signatures", "depth",
- "origin", "prev_state",
+ "auth_events",
+ "prev_events",
+ "hashes",
+ "signatures",
+ "depth",
+ "origin",
+ "prev_state",
)
for key in drop_keys:
d.pop(key, None)
@@ -252,9 +257,15 @@ def format_event_for_client_v2_without_room_id(d):
return d
-def serialize_event(e, time_now_ms, as_client_event=True,
- event_format=format_event_for_client_v1,
- token_id=None, only_event_fields=None, is_invite=False):
+def serialize_event(
+ e,
+ time_now_ms,
+ as_client_event=True,
+ event_format=format_event_for_client_v1,
+ token_id=None,
+ only_event_fields=None,
+ is_invite=False,
+):
"""Serialize event for clients
Args:
@@ -288,8 +299,7 @@ def serialize_event(e, time_now_ms, as_client_event=True,
if "redacted_because" in e.unsigned:
d["unsigned"]["redacted_because"] = serialize_event(
- e.unsigned["redacted_because"], time_now_ms,
- event_format=event_format
+ e.unsigned["redacted_because"], time_now_ms, event_format=event_format
)
if token_id is not None:
@@ -308,8 +318,9 @@ def serialize_event(e, time_now_ms, as_client_event=True,
d = event_format(d)
if only_event_fields:
- if (not isinstance(only_event_fields, list) or
- not all(isinstance(f, string_types) for f in only_event_fields)):
+ if not isinstance(only_event_fields, list) or not all(
+ isinstance(f, string_types) for f in only_event_fields
+ ):
raise TypeError("only_event_fields must be a list of strings")
d = only_fields(d, only_event_fields)
@@ -352,11 +363,9 @@ class EventClientSerializer(object):
# If MSC1849 is enabled then we need to look if thre are any relations
# we need to bundle in with the event
if self.experimental_msc1849_support_enabled and bundle_aggregations:
- annotations = yield self.store.get_aggregation_groups_for_event(
- event_id,
- )
+ annotations = yield self.store.get_aggregation_groups_for_event(event_id)
references = yield self.store.get_relations_for_event(
- event_id, RelationTypes.REFERENCE, direction="f",
+ event_id, RelationTypes.REFERENCE, direction="f"
)
if annotations.chunk:
@@ -383,9 +392,7 @@ class EventClientSerializer(object):
serialized_event["content"].pop("m.relates_to", None)
r = serialized_event["unsigned"].setdefault("m.relations", {})
- r[RelationTypes.REPLACE] = {
- "event_id": edit.event_id,
- }
+ r[RelationTypes.REPLACE] = {"event_id": edit.event_id}
defer.returnValue(serialized_event)
@@ -401,6 +408,5 @@ class EventClientSerializer(object):
Deferred[list[dict]]: The list of serialized events
"""
return yieldable_gather_results(
- self.serialize_event, events,
- time_now=time_now, **kwargs
+ self.serialize_event, events, time_now=time_now, **kwargs
)
diff --git a/synapse/events/validator.py b/synapse/events/validator.py
index 711af512b2..f7ffd1d561 100644
--- a/synapse/events/validator.py
+++ b/synapse/events/validator.py
@@ -48,9 +48,7 @@ class EventValidator(object):
raise SynapseError(400, "Event does not have key %s" % (k,))
# Check that the following keys have string values
- event_strings = [
- "origin",
- ]
+ event_strings = ["origin"]
for s in event_strings:
if not isinstance(getattr(event, s), string_types):
@@ -62,8 +60,10 @@ class EventValidator(object):
if len(alias) > MAX_ALIAS_LENGTH:
raise SynapseError(
400,
- ("Can't create aliases longer than"
- " %d characters" % (MAX_ALIAS_LENGTH,)),
+ (
+ "Can't create aliases longer than"
+ " %d characters" % (MAX_ALIAS_LENGTH,)
+ ),
Codes.INVALID_PARAM,
)
@@ -76,11 +76,7 @@ class EventValidator(object):
event (EventBuilder|FrozenEvent)
"""
- strings = [
- "room_id",
- "sender",
- "type",
- ]
+ strings = ["room_id", "sender", "type"]
if hasattr(event, "state_key"):
strings.append("state_key")
@@ -93,10 +89,7 @@ class EventValidator(object):
UserID.from_string(event.sender)
if event.type == EventTypes.Message:
- strings = [
- "body",
- "msgtype",
- ]
+ strings = ["body", "msgtype"]
self._ensure_strings(event.content, strings)
|