summary refs log tree commit diff
path: root/synapse/events
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/events')
-rw-r--r--synapse/events/__init__.py42
-rw-r--r--synapse/events/builder.py37
-rw-r--r--synapse/events/snapshot.py49
-rw-r--r--synapse/events/spamcheck.py4
-rw-r--r--synapse/events/third_party_rules.py113
-rw-r--r--synapse/events/utils.py58
-rw-r--r--synapse/events/validator.py21
7 files changed, 221 insertions, 103 deletions
diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py
index 1edd19cc13..d3de70e671 100644
--- a/synapse/events/__init__.py
+++ b/synapse/events/__init__.py
@@ -92,6 +92,18 @@ class _EventInternalMetadata(object):
         """
         return getattr(self, "soft_failed", False)
 
+    def should_proactively_send(self):
+        """Whether the event, if ours, should be sent to other clients and
+        servers.
+
+        This is used for sending dummy events internally. Servers and clients
+        can still explicitly fetch the event.
+
+        Returns:
+            bool
+        """
+        return getattr(self, "proactively_send", True)
+
 
 def _event_dict_property(key):
     # We want to be able to use hasattr with the event dict properties.
@@ -115,25 +127,25 @@ def _event_dict_property(key):
         except KeyError:
             raise AttributeError(key)
 
-    return property(
-        getter,
-        setter,
-        delete,
-    )
+    return property(getter, setter, delete)
 
 
 class EventBase(object):
-    def __init__(self, event_dict, signatures={}, unsigned={},
-                 internal_metadata_dict={}, rejected_reason=None):
+    def __init__(
+        self,
+        event_dict,
+        signatures={},
+        unsigned={},
+        internal_metadata_dict={},
+        rejected_reason=None,
+    ):
         self.signatures = signatures
         self.unsigned = unsigned
         self.rejected_reason = rejected_reason
 
         self._event_dict = event_dict
 
-        self.internal_metadata = _EventInternalMetadata(
-            internal_metadata_dict
-        )
+        self.internal_metadata = _EventInternalMetadata(internal_metadata_dict)
 
     auth_events = _event_dict_property("auth_events")
     depth = _event_dict_property("depth")
@@ -156,10 +168,7 @@ class EventBase(object):
 
     def get_dict(self):
         d = dict(self._event_dict)
-        d.update({
-            "signatures": self.signatures,
-            "unsigned": dict(self.unsigned),
-        })
+        d.update({"signatures": self.signatures, "unsigned": dict(self.unsigned)})
 
         return d
 
@@ -346,6 +355,7 @@ class FrozenEventV2(EventBase):
 
 class FrozenEventV3(FrozenEventV2):
     """FrozenEventV3, which differs from FrozenEventV2 only in the event_id format"""
+
     format_version = EventFormatVersions.V3  # All events of this type are V3
 
     @property
@@ -402,6 +412,4 @@ def event_type_from_format_version(format_version):
     elif format_version == EventFormatVersions.V3:
         return FrozenEventV3
     else:
-        raise Exception(
-            "No event format %r" % (format_version,)
-        )
+        raise Exception("No event format %r" % (format_version,))
diff --git a/synapse/events/builder.py b/synapse/events/builder.py
index 546b6f4982..db011e0407 100644
--- a/synapse/events/builder.py
+++ b/synapse/events/builder.py
@@ -78,7 +78,9 @@ class EventBuilder(object):
     _redacts = attr.ib(default=None)
     _origin_server_ts = attr.ib(default=None)
 
-    internal_metadata = attr.ib(default=attr.Factory(lambda: _EventInternalMetadata({})))
+    internal_metadata = attr.ib(
+        default=attr.Factory(lambda: _EventInternalMetadata({}))
+    )
 
     @property
     def state_key(self):
@@ -102,11 +104,9 @@ class EventBuilder(object):
         """
 
         state_ids = yield self._state.get_current_state_ids(
-            self.room_id, prev_event_ids,
-        )
-        auth_ids = yield self._auth.compute_auth_events(
-            self, state_ids,
+            self.room_id, prev_event_ids
         )
+        auth_ids = yield self._auth.compute_auth_events(self, state_ids)
 
         if self.format_version == EventFormatVersions.V1:
             auth_events = yield self._store.add_event_hashes(auth_ids)
@@ -115,9 +115,7 @@ class EventBuilder(object):
             auth_events = auth_ids
             prev_events = prev_event_ids
 
-        old_depth = yield self._store.get_max_depth_of(
-            prev_event_ids,
-        )
+        old_depth = yield self._store.get_max_depth_of(prev_event_ids)
         depth = old_depth + 1
 
         # we cap depth of generated events, to ensure that they are not
@@ -217,9 +215,14 @@ class EventBuilderFactory(object):
         )
 
 
-def create_local_event_from_event_dict(clock, hostname, signing_key,
-                                       format_version, event_dict,
-                                       internal_metadata_dict=None):
+def create_local_event_from_event_dict(
+    clock,
+    hostname,
+    signing_key,
+    format_version,
+    event_dict,
+    internal_metadata_dict=None,
+):
     """Takes a fully formed event dict, ensuring that fields like `origin`
     and `origin_server_ts` have correct values for a locally produced event,
     then signs and hashes it.
@@ -237,9 +240,7 @@ def create_local_event_from_event_dict(clock, hostname, signing_key,
     """
 
     if format_version not in KNOWN_EVENT_FORMAT_VERSIONS:
-        raise Exception(
-            "No event format defined for version %r" % (format_version,)
-        )
+        raise Exception("No event format defined for version %r" % (format_version,))
 
     if internal_metadata_dict is None:
         internal_metadata_dict = {}
@@ -258,13 +259,9 @@ def create_local_event_from_event_dict(clock, hostname, signing_key,
 
     event_dict.setdefault("signatures", {})
 
-    add_hashes_and_signatures(
-        event_dict,
-        hostname,
-        signing_key,
-    )
+    add_hashes_and_signatures(event_dict, hostname, signing_key)
     return event_type_from_format_version(format_version)(
-        event_dict, internal_metadata_dict=internal_metadata_dict,
+        event_dict, internal_metadata_dict=internal_metadata_dict
     )
 
 
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index fa09c132a0..a96cdada3d 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -88,8 +88,9 @@ class EventContext(object):
         self.app_service = None
 
     @staticmethod
-    def with_state(state_group, current_state_ids, prev_state_ids,
-                   prev_group=None, delta_ids=None):
+    def with_state(
+        state_group, current_state_ids, prev_state_ids, prev_group=None, delta_ids=None
+    ):
         context = EventContext()
 
         # The current state including the current event
@@ -132,17 +133,19 @@ class EventContext(object):
         else:
             prev_state_id = None
 
-        defer.returnValue({
-            "prev_state_id": prev_state_id,
-            "event_type": event.type,
-            "event_state_key": event.state_key if event.is_state() else None,
-            "state_group": self.state_group,
-            "rejected": self.rejected,
-            "prev_group": self.prev_group,
-            "delta_ids": _encode_state_dict(self.delta_ids),
-            "prev_state_events": self.prev_state_events,
-            "app_service_id": self.app_service.id if self.app_service else None
-        })
+        defer.returnValue(
+            {
+                "prev_state_id": prev_state_id,
+                "event_type": event.type,
+                "event_state_key": event.state_key if event.is_state() else None,
+                "state_group": self.state_group,
+                "rejected": self.rejected,
+                "prev_group": self.prev_group,
+                "delta_ids": _encode_state_dict(self.delta_ids),
+                "prev_state_events": self.prev_state_events,
+                "app_service_id": self.app_service.id if self.app_service else None,
+            }
+        )
 
     @staticmethod
     def deserialize(store, input):
@@ -194,7 +197,7 @@ class EventContext(object):
 
         if not self._fetching_state_deferred:
             self._fetching_state_deferred = run_in_background(
-                self._fill_out_state, store,
+                self._fill_out_state, store
             )
 
         yield make_deferred_yieldable(self._fetching_state_deferred)
@@ -214,7 +217,7 @@ class EventContext(object):
 
         if not self._fetching_state_deferred:
             self._fetching_state_deferred = run_in_background(
-                self._fill_out_state, store,
+                self._fill_out_state, store
             )
 
         yield make_deferred_yieldable(self._fetching_state_deferred)
@@ -240,9 +243,7 @@ class EventContext(object):
         if self.state_group is None:
             return
 
-        self._current_state_ids = yield store.get_state_ids_for_group(
-            self.state_group,
-        )
+        self._current_state_ids = yield store.get_state_ids_for_group(self.state_group)
         if self._prev_state_id and self._event_state_key is not None:
             self._prev_state_ids = dict(self._current_state_ids)
 
@@ -252,8 +253,9 @@ class EventContext(object):
             self._prev_state_ids = self._current_state_ids
 
     @defer.inlineCallbacks
-    def update_state(self, state_group, prev_state_ids, current_state_ids,
-                     prev_group, delta_ids):
+    def update_state(
+        self, state_group, prev_state_ids, current_state_ids, prev_group, delta_ids
+    ):
         """Replace the state in the context
         """
 
@@ -279,10 +281,7 @@ def _encode_state_dict(state_dict):
     if state_dict is None:
         return None
 
-    return [
-        (etype, state_key, v)
-        for (etype, state_key), v in iteritems(state_dict)
-    ]
+    return [(etype, state_key, v) for (etype, state_key), v in iteritems(state_dict)]
 
 
 def _decode_state_dict(input):
@@ -291,4 +290,4 @@ def _decode_state_dict(input):
     if input is None:
         return None
 
-    return frozendict({(etype, state_key,): v for etype, state_key, v in input})
+    return frozendict({(etype, state_key): v for etype, state_key, v in input})
diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py
index 6058077f75..129771f183 100644
--- a/synapse/events/spamcheck.py
+++ b/synapse/events/spamcheck.py
@@ -60,7 +60,9 @@ class SpamChecker(object):
         if self.spam_checker is None:
             return True
 
-        return self.spam_checker.user_may_invite(inviter_userid, invitee_userid, room_id)
+        return self.spam_checker.user_may_invite(
+            inviter_userid, invitee_userid, room_id
+        )
 
     def user_may_create_room(self, userid):
         """Checks if a given user may create a room
diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py
new file mode 100644
index 0000000000..8f5d95696b
--- /dev/null
+++ b/synapse/events/third_party_rules.py
@@ -0,0 +1,113 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+
+class ThirdPartyEventRules(object):
+    """Allows server admins to provide a Python module implementing an extra
+    set of rules to apply when processing events.
+
+    This is designed to help admins of closed federations with enforcing custom
+    behaviours.
+    """
+
+    def __init__(self, hs):
+        self.third_party_rules = None
+
+        self.store = hs.get_datastore()
+
+        module = None
+        config = None
+        if hs.config.third_party_event_rules:
+            module, config = hs.config.third_party_event_rules
+
+        if module is not None:
+            self.third_party_rules = module(
+                config=config, http_client=hs.get_simple_http_client()
+            )
+
+    @defer.inlineCallbacks
+    def check_event_allowed(self, event, context):
+        """Check if a provided event should be allowed in the given context.
+
+        Args:
+            event (synapse.events.EventBase): The event to be checked.
+            context (synapse.events.snapshot.EventContext): The context of the event.
+
+        Returns:
+            defer.Deferred[bool]: True if the event should be allowed, False if not.
+        """
+        if self.third_party_rules is None:
+            defer.returnValue(True)
+
+        prev_state_ids = yield context.get_prev_state_ids(self.store)
+
+        # Retrieve the state events from the database.
+        state_events = {}
+        for key, event_id in prev_state_ids.items():
+            state_events[key] = yield self.store.get_event(event_id, allow_none=True)
+
+        ret = yield self.third_party_rules.check_event_allowed(event, state_events)
+        defer.returnValue(ret)
+
+    @defer.inlineCallbacks
+    def on_create_room(self, requester, config, is_requester_admin):
+        """Intercept requests to create room to allow, deny or update the
+        request config.
+
+        Args:
+            requester (Requester)
+            config (dict): The creation config from the client.
+            is_requester_admin (bool): If the requester is an admin
+
+        Returns:
+            defer.Deferred
+        """
+
+        if self.third_party_rules is None:
+            return
+
+        yield self.third_party_rules.on_create_room(
+            requester, config, is_requester_admin
+        )
+
+    @defer.inlineCallbacks
+    def check_threepid_can_be_invited(self, medium, address, room_id):
+        """Check if a provided 3PID can be invited in the given room.
+
+        Args:
+            medium (str): The 3PID's medium.
+            address (str): The 3PID's address.
+            room_id (str): The room we want to invite the threepid to.
+
+        Returns:
+            defer.Deferred[bool], True if the 3PID can be invited, False if not.
+        """
+
+        if self.third_party_rules is None:
+            defer.returnValue(True)
+
+        state_ids = yield self.store.get_filtered_current_state_ids(room_id)
+        room_state_events = yield self.store.get_events(state_ids.values())
+
+        state_events = {}
+        for key, event_id in state_ids.items():
+            state_events[key] = room_state_events[event_id]
+
+        ret = yield self.third_party_rules.check_threepid_can_be_invited(
+            medium, address, state_events
+        )
+        defer.returnValue(ret)
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index e2d4384de1..f24f0c16f0 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -31,7 +31,7 @@ from . import EventBase
 # by a match for 'stuff'.
 # TODO: This is fast, but fails to handle "foo\\.bar" which should be treated as
 #       the literal fields "foo\" and "bar" but will instead be treated as "foo\\.bar"
-SPLIT_FIELD_REGEX = re.compile(r'(?<!\\)\.')
+SPLIT_FIELD_REGEX = re.compile(r"(?<!\\)\.")
 
 
 def prune_event(event):
@@ -51,6 +51,7 @@ def prune_event(event):
     pruned_event_dict = prune_event_dict(event.get_dict())
 
     from . import event_type_from_format_version
+
     return event_type_from_format_version(event.format_version)(
         pruned_event_dict, event.internal_metadata.get_dict()
     )
@@ -116,11 +117,7 @@ def prune_event_dict(event_dict):
     elif event_type == EventTypes.RoomHistoryVisibility:
         add_fields("history_visibility")
 
-    allowed_fields = {
-        k: v
-        for k, v in event_dict.items()
-        if k in allowed_keys
-    }
+    allowed_fields = {k: v for k, v in event_dict.items() if k in allowed_keys}
 
     allowed_fields["content"] = new_content
 
@@ -205,7 +202,7 @@ def only_fields(dictionary, fields):
     # for each element of the output array of arrays:
     # remove escaping so we can use the right key names.
     split_fields[:] = [
-        [f.replace(r'\.', r'.') for f in field_array] for field_array in split_fields
+        [f.replace(r"\.", r".") for f in field_array] for field_array in split_fields
     ]
 
     output = {}
@@ -226,7 +223,10 @@ def format_event_for_client_v1(d):
         d["user_id"] = sender
 
     copy_keys = (
-        "age", "redacted_because", "replaces_state", "prev_content",
+        "age",
+        "redacted_because",
+        "replaces_state",
+        "prev_content",
         "invite_room_state",
     )
     for key in copy_keys:
@@ -238,8 +238,13 @@ def format_event_for_client_v1(d):
 
 def format_event_for_client_v2(d):
     drop_keys = (
-        "auth_events", "prev_events", "hashes", "signatures", "depth",
-        "origin", "prev_state",
+        "auth_events",
+        "prev_events",
+        "hashes",
+        "signatures",
+        "depth",
+        "origin",
+        "prev_state",
     )
     for key in drop_keys:
         d.pop(key, None)
@@ -252,9 +257,15 @@ def format_event_for_client_v2_without_room_id(d):
     return d
 
 
-def serialize_event(e, time_now_ms, as_client_event=True,
-                    event_format=format_event_for_client_v1,
-                    token_id=None, only_event_fields=None, is_invite=False):
+def serialize_event(
+    e,
+    time_now_ms,
+    as_client_event=True,
+    event_format=format_event_for_client_v1,
+    token_id=None,
+    only_event_fields=None,
+    is_invite=False,
+):
     """Serialize event for clients
 
     Args:
@@ -288,8 +299,7 @@ def serialize_event(e, time_now_ms, as_client_event=True,
 
     if "redacted_because" in e.unsigned:
         d["unsigned"]["redacted_because"] = serialize_event(
-            e.unsigned["redacted_because"], time_now_ms,
-            event_format=event_format
+            e.unsigned["redacted_because"], time_now_ms, event_format=event_format
         )
 
     if token_id is not None:
@@ -308,8 +318,9 @@ def serialize_event(e, time_now_ms, as_client_event=True,
         d = event_format(d)
 
     if only_event_fields:
-        if (not isinstance(only_event_fields, list) or
-                not all(isinstance(f, string_types) for f in only_event_fields)):
+        if not isinstance(only_event_fields, list) or not all(
+            isinstance(f, string_types) for f in only_event_fields
+        ):
             raise TypeError("only_event_fields must be a list of strings")
         d = only_fields(d, only_event_fields)
 
@@ -352,11 +363,9 @@ class EventClientSerializer(object):
         # If MSC1849 is enabled then we need to look if thre are any relations
         # we need to bundle in with the event
         if self.experimental_msc1849_support_enabled and bundle_aggregations:
-            annotations = yield self.store.get_aggregation_groups_for_event(
-                event_id,
-            )
+            annotations = yield self.store.get_aggregation_groups_for_event(event_id)
             references = yield self.store.get_relations_for_event(
-                event_id, RelationTypes.REFERENCE, direction="f",
+                event_id, RelationTypes.REFERENCE, direction="f"
             )
 
             if annotations.chunk:
@@ -383,9 +392,7 @@ class EventClientSerializer(object):
                     serialized_event["content"].pop("m.relates_to", None)
 
                 r = serialized_event["unsigned"].setdefault("m.relations", {})
-                r[RelationTypes.REPLACE] = {
-                    "event_id": edit.event_id,
-                }
+                r[RelationTypes.REPLACE] = {"event_id": edit.event_id}
 
         defer.returnValue(serialized_event)
 
@@ -401,6 +408,5 @@ class EventClientSerializer(object):
             Deferred[list[dict]]: The list of serialized events
         """
         return yieldable_gather_results(
-            self.serialize_event, events,
-            time_now=time_now, **kwargs
+            self.serialize_event, events, time_now=time_now, **kwargs
         )
diff --git a/synapse/events/validator.py b/synapse/events/validator.py
index 711af512b2..f7ffd1d561 100644
--- a/synapse/events/validator.py
+++ b/synapse/events/validator.py
@@ -48,9 +48,7 @@ class EventValidator(object):
                 raise SynapseError(400, "Event does not have key %s" % (k,))
 
         # Check that the following keys have string values
-        event_strings = [
-            "origin",
-        ]
+        event_strings = ["origin"]
 
         for s in event_strings:
             if not isinstance(getattr(event, s), string_types):
@@ -62,8 +60,10 @@ class EventValidator(object):
                     if len(alias) > MAX_ALIAS_LENGTH:
                         raise SynapseError(
                             400,
-                            ("Can't create aliases longer than"
-                             " %d characters" % (MAX_ALIAS_LENGTH,)),
+                            (
+                                "Can't create aliases longer than"
+                                " %d characters" % (MAX_ALIAS_LENGTH,)
+                            ),
                             Codes.INVALID_PARAM,
                         )
 
@@ -76,11 +76,7 @@ class EventValidator(object):
             event (EventBuilder|FrozenEvent)
         """
 
-        strings = [
-            "room_id",
-            "sender",
-            "type",
-        ]
+        strings = ["room_id", "sender", "type"]
 
         if hasattr(event, "state_key"):
             strings.append("state_key")
@@ -93,10 +89,7 @@ class EventValidator(object):
         UserID.from_string(event.sender)
 
         if event.type == EventTypes.Message:
-            strings = [
-                "body",
-                "msgtype",
-            ]
+            strings = ["body", "msgtype"]
 
             self._ensure_strings(event.content, strings)