diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py
index e673e96cc0..51f9084b90 100644
--- a/synapse/events/__init__.py
+++ b/synapse/events/__init__.py
@@ -13,9 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.util.frozenutils import freeze
from synapse.util.caches import intern_dict
-
+from synapse.util.frozenutils import freeze
# Whether we should use frozen_dict in FrozenEvent. Using frozen_dicts prevents
# bugs where we accidentally share e.g. signature dicts. However, converting
@@ -47,14 +46,26 @@ class _EventInternalMetadata(object):
def _event_dict_property(key):
+ # We want to be able to use hasattr with the event dict properties.
+ # However, (on python3) hasattr expects AttributeError to be raised. Hence,
+ # we need to transform the KeyError into an AttributeError
def getter(self):
- return self._event_dict[key]
+ try:
+ return self._event_dict[key]
+ except KeyError:
+ raise AttributeError(key)
def setter(self, v):
- self._event_dict[key] = v
+ try:
+ self._event_dict[key] = v
+ except KeyError:
+ raise AttributeError(key)
def delete(self):
- del self._event_dict[key]
+ try:
+ del self._event_dict[key]
+ except KeyError:
+ raise AttributeError(key)
return property(
getter,
@@ -134,7 +145,7 @@ class EventBase(object):
return field in self._event_dict
def items(self):
- return self._event_dict.items()
+ return list(self._event_dict.items())
class FrozenEvent(EventBase):
diff --git a/synapse/events/builder.py b/synapse/events/builder.py
index 365fd96bd2..e662eaef10 100644
--- a/synapse/events/builder.py
+++ b/synapse/events/builder.py
@@ -13,13 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from . import EventBase, FrozenEvent, _event_dict_property
+import copy
from synapse.types import EventID
-
from synapse.util.stringutils import random_string
-import copy
+from . import EventBase, FrozenEvent, _event_dict_property
class EventBuilder(EventBase):
@@ -55,7 +54,7 @@ class EventBuilderFactory(object):
local_part = str(int(self.clock.time())) + i + random_string(5)
- e_id = EventID.create(local_part, self.hostname)
+ e_id = EventID(local_part, self.hostname)
return e_id.to_string()
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index e9a732ff03..368b5f6ae4 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -13,19 +13,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from six import iteritems
+
+from frozendict import frozendict
+
+from twisted.internet import defer
+
+from synapse.util.logcontext import make_deferred_yieldable, run_in_background
+
class EventContext(object):
"""
Attributes:
- current_state_ids (dict[(str, str), str]):
- The current state map including the current event.
- (type, state_key) -> event_id
-
- prev_state_ids (dict[(str, str), str]):
- The current state map excluding the current event.
- (type, state_key) -> event_id
-
- state_group (int): state group id
+ state_group (int|None): state group id, if the state has been stored
+ as a state group. This is usually only None if e.g. the event is
+ an outlier.
rejected (bool|str): A rejection reason if the event was rejected, else
False
@@ -39,35 +41,250 @@ class EventContext(object):
prev_state_events (?): XXX: is this ever set to anything other than
the empty list?
+
+ _current_state_ids (dict[(str, str), str]|None):
+ The current state map including the current event. None if outlier
+ or we haven't fetched the state from DB yet.
+ (type, state_key) -> event_id
+
+ _prev_state_ids (dict[(str, str), str]|None):
+ The current state map excluding the current event. None if outlier
+ or we haven't fetched the state from DB yet.
+ (type, state_key) -> event_id
+
+ _fetching_state_deferred (Deferred|None): Resolves when *_state_ids have
+ been calculated. None if we haven't started calculating yet
+
+ _event_type (str): The type of the event the context is associated with.
+ Only set when state has not been fetched yet.
+
+ _event_state_key (str|None): The state_key of the event the context is
+ associated with. Only set when state has not been fetched yet.
+
+ _prev_state_id (str|None): If the event associated with the context is
+ a state event, then `_prev_state_id` is the event_id of the state
+ that was replaced.
+ Only set when state has not been fetched yet.
"""
__slots__ = [
- "current_state_ids",
- "prev_state_ids",
"state_group",
"rejected",
- "push_actions",
"prev_group",
"delta_ids",
"prev_state_events",
"app_service",
+ "_current_state_ids",
+ "_prev_state_ids",
+ "_prev_state_id",
+ "_event_type",
+ "_event_state_key",
+ "_fetching_state_deferred",
]
def __init__(self):
+ self.prev_state_events = []
+ self.rejected = False
+ self.app_service = None
+
+ @staticmethod
+ def with_state(state_group, current_state_ids, prev_state_ids,
+ prev_group=None, delta_ids=None):
+ context = EventContext()
+
# The current state including the current event
- self.current_state_ids = None
+ context._current_state_ids = current_state_ids
# The current state excluding the current event
- self.prev_state_ids = None
- self.state_group = None
+ context._prev_state_ids = prev_state_ids
+ context.state_group = state_group
- self.rejected = False
- self.push_actions = []
+ context._prev_state_id = None
+ context._event_type = None
+ context._event_state_key = None
+ context._fetching_state_deferred = defer.succeed(None)
# A previously persisted state group and a delta between that
# and this state.
- self.prev_group = None
- self.delta_ids = None
+ context.prev_group = prev_group
+ context.delta_ids = delta_ids
- self.prev_state_events = None
+ return context
- self.app_service = None
+ @defer.inlineCallbacks
+ def serialize(self, event, store):
+ """Converts self to a type that can be serialized as JSON, and then
+ deserialized by `deserialize`
+
+ Args:
+ event (FrozenEvent): The event that this context relates to
+
+ Returns:
+ dict
+ """
+
+ # We don't serialize the full state dicts, instead they get pulled out
+ # of the DB on the other side. However, the other side can't figure out
+ # the prev_state_ids, so if we're a state event we include the event
+ # id that we replaced in the state.
+ if event.is_state():
+ prev_state_ids = yield self.get_prev_state_ids(store)
+ prev_state_id = prev_state_ids.get((event.type, event.state_key))
+ else:
+ prev_state_id = None
+
+ defer.returnValue({
+ "prev_state_id": prev_state_id,
+ "event_type": event.type,
+ "event_state_key": event.state_key if event.is_state() else None,
+ "state_group": self.state_group,
+ "rejected": self.rejected,
+ "prev_group": self.prev_group,
+ "delta_ids": _encode_state_dict(self.delta_ids),
+ "prev_state_events": self.prev_state_events,
+ "app_service_id": self.app_service.id if self.app_service else None
+ })
+
+ @staticmethod
+ def deserialize(store, input):
+ """Converts a dict that was produced by `serialize` back into a
+ EventContext.
+
+ Args:
+ store (DataStore): Used to convert AS ID to AS object
+ input (dict): A dict produced by `serialize`
+
+ Returns:
+ EventContext
+ """
+ context = EventContext()
+
+ # We use the state_group and prev_state_id stuff to pull the
+ # current_state_ids out of the DB and construct prev_state_ids.
+ context._prev_state_id = input["prev_state_id"]
+ context._event_type = input["event_type"]
+ context._event_state_key = input["event_state_key"]
+
+ context._current_state_ids = None
+ context._prev_state_ids = None
+ context._fetching_state_deferred = None
+
+ context.state_group = input["state_group"]
+ context.prev_group = input["prev_group"]
+ context.delta_ids = _decode_state_dict(input["delta_ids"])
+
+ context.rejected = input["rejected"]
+ context.prev_state_events = input["prev_state_events"]
+
+ app_service_id = input["app_service_id"]
+ if app_service_id:
+ context.app_service = store.get_app_service_by_id(app_service_id)
+
+ return context
+
+ @defer.inlineCallbacks
+ def get_current_state_ids(self, store):
+ """Gets the current state IDs
+
+ Returns:
+ Deferred[dict[(str, str), str]|None]: Returns None if state_group
+ is None, which happens when the associated event is an outlier.
+ """
+
+ if not self._fetching_state_deferred:
+ self._fetching_state_deferred = run_in_background(
+ self._fill_out_state, store,
+ )
+
+ yield make_deferred_yieldable(self._fetching_state_deferred)
+
+ defer.returnValue(self._current_state_ids)
+
+ @defer.inlineCallbacks
+ def get_prev_state_ids(self, store):
+ """Gets the prev state IDs
+
+ Returns:
+ Deferred[dict[(str, str), str]|None]: Returns None if state_group
+ is None, which happens when the associated event is an outlier.
+ """
+
+ if not self._fetching_state_deferred:
+ self._fetching_state_deferred = run_in_background(
+ self._fill_out_state, store,
+ )
+
+ yield make_deferred_yieldable(self._fetching_state_deferred)
+
+ defer.returnValue(self._prev_state_ids)
+
+ def get_cached_current_state_ids(self):
+ """Gets the current state IDs if we have them already cached.
+
+ Returns:
+ dict[(str, str), str]|None: Returns None if we haven't cached the
+ state or if state_group is None, which happens when the associated
+ event is an outlier.
+ """
+
+ return self._current_state_ids
+
+ @defer.inlineCallbacks
+ def _fill_out_state(self, store):
+ """Called to populate the _current_state_ids and _prev_state_ids
+ attributes by loading from the database.
+ """
+ if self.state_group is None:
+ return
+
+ self._current_state_ids = yield store.get_state_ids_for_group(
+ self.state_group,
+ )
+ if self._prev_state_id and self._event_state_key is not None:
+ self._prev_state_ids = dict(self._current_state_ids)
+
+ key = (self._event_type, self._event_state_key)
+ self._prev_state_ids[key] = self._prev_state_id
+ else:
+ self._prev_state_ids = self._current_state_ids
+
+ @defer.inlineCallbacks
+ def update_state(self, state_group, prev_state_ids, current_state_ids,
+ prev_group, delta_ids):
+ """Replace the state in the context
+ """
+
+ # We need to make sure we wait for any ongoing fetching of state
+ # to complete so that the updated state doesn't get clobbered
+ if self._fetching_state_deferred:
+ yield make_deferred_yieldable(self._fetching_state_deferred)
+
+ self.state_group = state_group
+ self._prev_state_ids = prev_state_ids
+ self.prev_group = prev_group
+ self._current_state_ids = current_state_ids
+ self.delta_ids = delta_ids
+
+ # We need to ensure that that we've marked as having fetched the state
+ self._fetching_state_deferred = defer.succeed(None)
+
+
+def _encode_state_dict(state_dict):
+ """Since dicts of (type, state_key) -> event_id cannot be serialized in
+ JSON we need to convert them to a form that can.
+ """
+ if state_dict is None:
+ return None
+
+ return [
+ (etype, state_key, v)
+ for (etype, state_key), v in iteritems(state_dict)
+ ]
+
+
+def _decode_state_dict(input):
+ """Decodes a state dict encoded using `_encode_state_dict` above
+ """
+ if input is None:
+ return None
+
+ return frozendict({(etype, state_key,): v for etype, state_key, v in input})
diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py
new file mode 100644
index 0000000000..633e068eb8
--- /dev/null
+++ b/synapse/events/spamcheck.py
@@ -0,0 +1,113 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 New Vector Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+class SpamChecker(object):
+ def __init__(self, hs):
+ self.spam_checker = None
+
+ module = None
+ config = None
+ try:
+ module, config = hs.config.spam_checker
+ except Exception:
+ pass
+
+ if module is not None:
+ self.spam_checker = module(config=config)
+
+ def check_event_for_spam(self, event):
+ """Checks if a given event is considered "spammy" by this server.
+
+ If the server considers an event spammy, then it will be rejected if
+ sent by a local user. If it is sent by a user on another server, then
+ users receive a blank event.
+
+ Args:
+ event (synapse.events.EventBase): the event to be checked
+
+ Returns:
+ bool: True if the event is spammy.
+ """
+ if self.spam_checker is None:
+ return False
+
+ return self.spam_checker.check_event_for_spam(event)
+
+ def user_may_invite(self, inviter_userid, invitee_userid, room_id):
+ """Checks if a given user may send an invite
+
+ If this method returns false, the invite will be rejected.
+
+ Args:
+ userid (string): The sender's user ID
+
+ Returns:
+ bool: True if the user may send an invite, otherwise False
+ """
+ if self.spam_checker is None:
+ return True
+
+ return self.spam_checker.user_may_invite(inviter_userid, invitee_userid, room_id)
+
+ def user_may_create_room(self, userid):
+ """Checks if a given user may create a room
+
+ If this method returns false, the creation request will be rejected.
+
+ Args:
+ userid (string): The sender's user ID
+
+ Returns:
+ bool: True if the user may create a room, otherwise False
+ """
+ if self.spam_checker is None:
+ return True
+
+ return self.spam_checker.user_may_create_room(userid)
+
+ def user_may_create_room_alias(self, userid, room_alias):
+ """Checks if a given user may create a room alias
+
+ If this method returns false, the association request will be rejected.
+
+ Args:
+ userid (string): The sender's user ID
+ room_alias (string): The alias to be created
+
+ Returns:
+ bool: True if the user may create a room alias, otherwise False
+ """
+ if self.spam_checker is None:
+ return True
+
+ return self.spam_checker.user_may_create_room_alias(userid, room_alias)
+
+ def user_may_publish_room(self, userid, room_id):
+ """Checks if a given user may publish a room to the directory
+
+ If this method returns false, the publish request will be rejected.
+
+ Args:
+ userid (string): The sender's user ID
+ room_id (string): The ID of the room that would be published
+
+ Returns:
+ bool: True if the user may publish the room, otherwise False
+ """
+ if self.spam_checker is None:
+ return True
+
+ return self.spam_checker.user_may_publish_room(userid, room_id)
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index 824f4a42e3..652941ca0d 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -13,12 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.api.constants import EventTypes
-from . import EventBase
+import re
+
+from six import string_types
from frozendict import frozendict
-import re
+from synapse.api.constants import EventTypes
+
+from . import EventBase
# Split strings on "." but not "\." This uses a negative lookbehind assertion for '\'
# (?<!stuff) matches if the current position in the string is not preceded
@@ -277,7 +280,7 @@ def serialize_event(e, time_now_ms, as_client_event=True,
if only_event_fields:
if (not isinstance(only_event_fields, list) or
- not all(isinstance(f, basestring) for f in only_event_fields)):
+ not all(isinstance(f, string_types) for f in only_event_fields)):
raise TypeError("only_event_fields must be a list of strings")
d = only_fields(d, only_event_fields)
diff --git a/synapse/events/validator.py b/synapse/events/validator.py
index 2f4c8a1018..cf184748a1 100644
--- a/synapse/events/validator.py
+++ b/synapse/events/validator.py
@@ -13,9 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.types import EventID, RoomID, UserID
-from synapse.api.errors import SynapseError
+from six import string_types
+
from synapse.api.constants import EventTypes, Membership
+from synapse.api.errors import SynapseError
+from synapse.types import EventID, RoomID, UserID
class EventValidator(object):
@@ -49,7 +51,7 @@ class EventValidator(object):
strings.append("state_key")
for s in strings:
- if not isinstance(getattr(event, s), basestring):
+ if not isinstance(getattr(event, s), string_types):
raise SynapseError(400, "Not '%s' a string type" % (s,))
if event.type == EventTypes.Member:
@@ -88,5 +90,5 @@ class EventValidator(object):
for s in keys:
if s not in d:
raise SynapseError(400, "'%s' not in content" % (s,))
- if not isinstance(d[s], basestring):
+ if not isinstance(d[s], string_types):
raise SynapseError(400, "Not '%s' a string type" % (s,))
|