diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py
index dc49df0812..8028663fa8 100644
--- a/synapse/events/__init__.py
+++ b/synapse/events/__init__.py
@@ -59,7 +59,7 @@ class DictProperty:
#
# To exclude the KeyError from the traceback, we explicitly
# 'raise from e1.__context__' (which is better than 'raise from None',
- # becuase that would omit any *earlier* exceptions).
+ # because that would omit any *earlier* exceptions).
#
raise AttributeError(
"'%s' has no '%s' property" % (type(instance), self.key)
@@ -97,13 +97,16 @@ class DefaultDictProperty(DictProperty):
class _EventInternalMetadata:
- __slots__ = ["_dict"]
+ __slots__ = ["_dict", "stream_ordering"]
def __init__(self, internal_metadata_dict: JsonDict):
# we have to copy the dict, because it turns out that the same dict is
# reused. TODO: fix that
self._dict = dict(internal_metadata_dict)
+ # the stream ordering of this event. None, until it has been persisted.
+ self.stream_ordering = None # type: Optional[int]
+
outlier = DictProperty("outlier") # type: bool
out_of_band_membership = DictProperty("out_of_band_membership") # type: bool
send_on_behalf_of = DictProperty("send_on_behalf_of") # type: str
@@ -113,7 +116,6 @@ class _EventInternalMetadata:
redacted = DictProperty("redacted") # type: bool
txn_id = DictProperty("txn_id") # type: str
token_id = DictProperty("token_id") # type: str
- stream_ordering = DictProperty("stream_ordering") # type: int
# XXX: These are set by StreamWorkerStore._set_before_and_after.
# I'm pretty sure that these are never persisted to the database, so shouldn't
@@ -310,6 +312,12 @@ class EventBase(metaclass=abc.ABCMeta):
"""
return [e for e, _ in self.auth_events]
+ def freeze(self):
+ """'Freeze' the event dict, so it cannot be modified by accident"""
+
+ # this will be a no-op if the event dict is already frozen.
+ self._dict = freeze(self._dict)
+
class FrozenEvent(EventBase):
format_version = EventFormatVersions.V1 # All events of this type are V1
@@ -360,7 +368,7 @@ class FrozenEvent(EventBase):
return self.__repr__()
def __repr__(self):
- return "<FrozenEvent event_id='%s', type='%s', state_key='%s'>" % (
+ return "<FrozenEvent event_id=%r, type=%r, state_key=%r>" % (
self.get("event_id", None),
self.get("type", None),
self.get("state_key", None),
@@ -443,7 +451,7 @@ class FrozenEventV2(EventBase):
return self.__repr__()
def __repr__(self):
- return "<%s event_id='%s', type='%s', state_key='%s'>" % (
+ return "<%s event_id=%r, type=%r, state_key=%r>" % (
self.__class__.__name__,
self.event_id,
self.get("type", None),
diff --git a/synapse/events/builder.py b/synapse/events/builder.py
index b6c47be646..07df258e6e 100644
--- a/synapse/events/builder.py
+++ b/synapse/events/builder.py
@@ -97,32 +97,37 @@ class EventBuilder:
def is_state(self):
return self._state_key is not None
- async def build(self, prev_event_ids: List[str]) -> EventBase:
+ async def build(
+ self, prev_event_ids: List[str], auth_event_ids: Optional[List[str]],
+ ) -> EventBase:
"""Transform into a fully signed and hashed event
Args:
prev_event_ids: The event IDs to use as the prev events
+ auth_event_ids: The event IDs to use as the auth events.
+ Should normally be set to None, which will cause them to be calculated
+ based on the room state at the prev_events.
Returns:
The signed and hashed event.
"""
-
- state_ids = await self._state.get_current_state_ids(
- self.room_id, prev_event_ids
- )
- auth_ids = self._auth.compute_auth_events(self, state_ids)
+ if auth_event_ids is None:
+ state_ids = await self._state.get_current_state_ids(
+ self.room_id, prev_event_ids
+ )
+ auth_event_ids = self._auth.compute_auth_events(self, state_ids)
format_version = self.room_version.event_format
if format_version == EventFormatVersions.V1:
# The types of auth/prev events changes between event versions.
auth_events = await self._store.add_event_hashes(
- auth_ids
+ auth_event_ids
) # type: Union[List[str], List[Tuple[str, Dict[str, str]]]]
prev_events = await self._store.add_event_hashes(
prev_event_ids
) # type: Union[List[str], List[Tuple[str, Dict[str, str]]]]
else:
- auth_events = auth_ids
+ auth_events = auth_event_ids
prev_events = prev_event_ids
old_depth = await self._store.get_max_depth_of(prev_event_ids)
diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py
index b0fc859a47..936896656a 100644
--- a/synapse/events/spamcheck.py
+++ b/synapse/events/spamcheck.py
@@ -15,26 +15,26 @@
# limitations under the License.
import inspect
-from typing import Any, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
-from synapse.spam_checker_api import RegistrationBehaviour, SpamCheckerApi
+from synapse.spam_checker_api import RegistrationBehaviour
from synapse.types import Collection
-MYPY = False
-if MYPY:
+if TYPE_CHECKING:
+ import synapse.events
import synapse.server
class SpamChecker:
def __init__(self, hs: "synapse.server.HomeServer"):
self.spam_checkers = [] # type: List[Any]
+ api = hs.get_module_api()
for module, config in hs.config.spam_checkers:
# Older spam checkers don't accept the `api` argument, so we
# try and detect support.
spam_args = inspect.getfullargspec(module)
if "api" in spam_args.args:
- api = SpamCheckerApi(hs)
self.spam_checkers.append(module(config=config, api=api))
else:
self.spam_checkers.append(module(config=config))
diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py
index 9d5310851c..77fbd3f68a 100644
--- a/synapse/events/third_party_rules.py
+++ b/synapse/events/third_party_rules.py
@@ -13,9 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Callable, Union
+
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
-from synapse.types import Requester
+from synapse.types import Requester, StateMap
class ThirdPartyEventRules:
@@ -38,20 +40,25 @@ class ThirdPartyEventRules:
if module is not None:
self.third_party_rules = module(
- config=config, http_client=hs.get_simple_http_client()
+ config=config, module_api=hs.get_module_api(),
)
async def check_event_allowed(
self, event: EventBase, context: EventContext
- ) -> bool:
+ ) -> Union[bool, dict]:
"""Check if a provided event should be allowed in the given context.
+ The module can return:
+ * True: the event is allowed.
+ * False: the event is not allowed, and should be rejected with M_FORBIDDEN.
+ * a dict: replacement event data.
+
Args:
event: The event to be checked.
context: The context of the event.
Returns:
- True if the event should be allowed, False if not.
+ The result from the ThirdPartyRules module, as above
"""
if self.third_party_rules is None:
return True
@@ -59,12 +66,15 @@ class ThirdPartyEventRules:
prev_state_ids = await context.get_prev_state_ids()
# Retrieve the state events from the database.
- state_events = {}
- for key, event_id in prev_state_ids.items():
- state_events[key] = await self.store.get_event(event_id, allow_none=True)
+ events = await self.store.get_events(prev_state_ids.values())
+ state_events = {(ev.type, ev.state_key): ev for ev in events.values()}
- ret = await self.third_party_rules.check_event_allowed(event, state_events)
- return ret
+ # Ensure that the event is frozen, to make sure that the module is not tempted
+ # to try to modify it. Any attempt to modify it at this point will invalidate
+ # the hashes and signatures.
+ event.freeze()
+
+ return await self.third_party_rules.check_event_allowed(event, state_events)
async def on_create_room(
self, requester: Requester, config: dict, is_requester_admin: bool
@@ -106,6 +116,48 @@ class ThirdPartyEventRules:
if self.third_party_rules is None:
return True
+ state_events = await self._get_state_map_for_room(room_id)
+
+ ret = await self.third_party_rules.check_threepid_can_be_invited(
+ medium, address, state_events
+ )
+ return ret
+
+ async def check_visibility_can_be_modified(
+ self, room_id: str, new_visibility: str
+ ) -> bool:
+ """Check if a room is allowed to be published to, or removed from, the public room
+ list.
+
+ Args:
+ room_id: The ID of the room.
+ new_visibility: The new visibility state. Either "public" or "private".
+
+ Returns:
+ True if the room's visibility can be modified, False if not.
+ """
+ if self.third_party_rules is None:
+ return True
+
+ check_func = getattr(
+ self.third_party_rules, "check_visibility_can_be_modified", None
+ )
+ if not check_func or not isinstance(check_func, Callable):
+ return True
+
+ state_events = await self._get_state_map_for_room(room_id)
+
+ return await check_func(room_id, state_events, new_visibility)
+
+ async def _get_state_map_for_room(self, room_id: str) -> StateMap[EventBase]:
+ """Given a room ID, return the state events of that room.
+
+ Args:
+ room_id: The ID of the room.
+
+ Returns:
+ A dict mapping (event type, state key) to state event.
+ """
state_ids = await self.store.get_filtered_current_state_ids(room_id)
room_state_events = await self.store.get_events(state_ids.values())
@@ -113,7 +165,4 @@ class ThirdPartyEventRules:
for key, event_id in state_ids.items():
state_events[key] = room_state_events[event_id]
- ret = await self.third_party_rules.check_threepid_can_be_invited(
- medium, address, state_events
- )
- return ret
+ return state_events
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index 32c73d3413..14f7f1156f 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -49,6 +49,11 @@ def prune_event(event: EventBase) -> EventBase:
pruned_event_dict, event.room_version, event.internal_metadata.get_dict()
)
+ # copy the internal fields
+ pruned_event.internal_metadata.stream_ordering = (
+ event.internal_metadata.stream_ordering
+ )
+
# Mark the event as redacted
pruned_event.internal_metadata.redacted = True
@@ -175,7 +180,7 @@ def only_fields(dictionary, fields):
in 'fields'.
If there are no event fields specified then all fields are included.
- The entries may include '.' charaters to indicate sub-fields.
+ The entries may include '.' characters to indicate sub-fields.
So ['content.body'] will include the 'body' field of the 'content' object.
A literal '.' character in a field name may be escaped using a '\'.
diff --git a/synapse/events/validator.py b/synapse/events/validator.py
index 9df35b54ba..f8f3b1a31e 100644
--- a/synapse/events/validator.py
+++ b/synapse/events/validator.py
@@ -13,20 +13,26 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Union
+
from synapse.api.constants import MAX_ALIAS_LENGTH, EventTypes, Membership
from synapse.api.errors import Codes, SynapseError
from synapse.api.room_versions import EventFormatVersions
+from synapse.config.homeserver import HomeServerConfig
+from synapse.events import EventBase
+from synapse.events.builder import EventBuilder
from synapse.events.utils import validate_canonicaljson
+from synapse.federation.federation_server import server_matches_acl_event
from synapse.types import EventID, RoomID, UserID
class EventValidator:
- def validate_new(self, event, config):
+ def validate_new(self, event: EventBase, config: HomeServerConfig):
"""Validates the event has roughly the right format
Args:
- event (FrozenEvent): The event to validate.
- config (Config): The homeserver's configuration.
+ event: The event to validate.
+ config: The homeserver's configuration.
"""
self.validate_builder(event)
@@ -76,13 +82,22 @@ class EventValidator:
if event.type == EventTypes.Retention:
self._validate_retention(event)
- def _validate_retention(self, event):
+ if event.type == EventTypes.ServerACL:
+ if not server_matches_acl_event(config.server_name, event):
+ raise SynapseError(
+ 400, "Can't create an ACL event that denies the local server"
+ )
+
+ def _validate_retention(self, event: EventBase):
"""Checks that an event that defines the retention policy for a room respects the
format enforced by the spec.
Args:
- event (FrozenEvent): The event to validate.
+ event: The event to validate.
"""
+ if not event.is_state():
+ raise SynapseError(code=400, msg="must be a state event")
+
min_lifetime = event.content.get("min_lifetime")
max_lifetime = event.content.get("max_lifetime")
@@ -113,13 +128,10 @@ class EventValidator:
errcode=Codes.BAD_JSON,
)
- def validate_builder(self, event):
+ def validate_builder(self, event: Union[EventBase, EventBuilder]):
"""Validates that the builder/event has roughly the right format. Only
checks values that we expect a proto event to have, rather than all the
fields an event would have
-
- Args:
- event (EventBuilder|FrozenEvent)
"""
strings = ["room_id", "sender", "type"]
|