summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorErik Johnston <erikj@jki.re>2019-01-29 14:07:23 +0000
committerGitHub <noreply@github.com>2019-01-29 14:07:23 +0000
commitb8d75ef53eae4d4922325c5f1a02e93881f36888 (patch)
tree7eab7f65c77d0987d4ca072efb49b3c75811d896 /synapse
parentImplement MSC1708 (.well-known lookups for server routing) (#4489) (diff)
parentCorrectly set context.app_service (diff)
downloadsynapse-b8d75ef53eae4d4922325c5f1a02e93881f36888.tar.xz
Merge pull request #4481 from matrix-org/erikj/event_builder
Refactor event building into EventBuilder
Diffstat (limited to 'synapse')
-rw-r--r--synapse/api/auth.py13
-rw-r--r--synapse/crypto/event_signing.py16
-rw-r--r--synapse/events/builder.py282
-rw-r--r--synapse/federation/federation_client.py20
-rw-r--r--synapse/handlers/message.py34
-rw-r--r--synapse/server.py5
-rw-r--r--synapse/storage/event_federation.py23
7 files changed, 260 insertions, 133 deletions
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 7b213e54c8..2d78a257d3 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -551,17 +551,6 @@ class Auth(object):
         return self.store.is_server_admin(user)
 
     @defer.inlineCallbacks
-    def add_auth_events(self, builder, context):
-        prev_state_ids = yield context.get_prev_state_ids(self.store)
-        auth_ids = yield self.compute_auth_events(builder, prev_state_ids)
-
-        auth_events_entries = yield self.store.add_event_hashes(
-            auth_ids
-        )
-
-        builder.auth_events = auth_events_entries
-
-    @defer.inlineCallbacks
     def compute_auth_events(self, event, current_state_ids, for_verification=False):
         if event.type == EventTypes.Create:
             defer.returnValue([])
@@ -577,7 +566,7 @@ class Auth(object):
         key = (EventTypes.JoinRules, "", )
         join_rule_event_id = current_state_ids.get(key)
 
-        key = (EventTypes.Member, event.user_id, )
+        key = (EventTypes.Member, event.sender, )
         member_event_id = current_state_ids.get(key)
 
         key = (EventTypes.Create, "", )
diff --git a/synapse/crypto/event_signing.py b/synapse/crypto/event_signing.py
index d01ac5075c..1dfa727fcf 100644
--- a/synapse/crypto/event_signing.py
+++ b/synapse/crypto/event_signing.py
@@ -131,12 +131,12 @@ def compute_event_signature(event_dict, signature_name, signing_key):
     return redact_json["signatures"]
 
 
-def add_hashes_and_signatures(event, signature_name, signing_key,
+def add_hashes_and_signatures(event_dict, signature_name, signing_key,
                               hash_algorithm=hashlib.sha256):
     """Add content hash and sign the event
 
     Args:
-        event_dict (EventBuilder): The event to add hashes to and sign
+        event_dict (dict): The event to add hashes to and sign
         signature_name (str): The name of the entity signing the event
             (typically the server's hostname).
         signing_key (syutil.crypto.SigningKey): The key to sign with
@@ -144,16 +144,12 @@ def add_hashes_and_signatures(event, signature_name, signing_key,
             to hash the event
     """
 
-    name, digest = compute_content_hash(
-        event.get_pdu_json(), hash_algorithm=hash_algorithm,
-    )
+    name, digest = compute_content_hash(event_dict, hash_algorithm=hash_algorithm)
 
-    if not hasattr(event, "hashes"):
-        event.hashes = {}
-    event.hashes[name] = encode_base64(digest)
+    event_dict.setdefault("hashes", {})[name] = encode_base64(digest)
 
-    event.signatures = compute_event_signature(
-        event.get_pdu_json(),
+    event_dict["signatures"] = compute_event_signature(
+        event_dict,
         signature_name=signature_name,
         signing_key=signing_key,
     )
diff --git a/synapse/events/builder.py b/synapse/events/builder.py
index 7e63371095..fb0683cea8 100644
--- a/synapse/events/builder.py
+++ b/synapse/events/builder.py
@@ -13,79 +13,156 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import copy
+import attr
 
-from synapse.api.constants import RoomVersions
+from twisted.internet import defer
+
+from synapse.api.constants import (
+    KNOWN_EVENT_FORMAT_VERSIONS,
+    KNOWN_ROOM_VERSIONS,
+    MAX_DEPTH,
+)
+from synapse.crypto.event_signing import add_hashes_and_signatures
 from synapse.types import EventID
 from synapse.util.stringutils import random_string
 
-from . import EventBase, FrozenEvent, _event_dict_property
+from . import (
+    _EventInternalMetadata,
+    event_type_from_format_version,
+    room_version_to_event_format,
+)
 
 
-def get_event_builder(room_version, key_values={}, internal_metadata_dict={}):
-    """Generate an event builder appropriate for the given room version
+@attr.s(slots=True, cmp=False, frozen=True)
+class EventBuilder(object):
+    """A format independent event builder used to build up the event content
+    before signing the event.
 
-    Args:
-        room_version (str): Version of the room that we're creating an
-            event builder for
-        key_values (dict): Fields used as the basis of the new event
-        internal_metadata_dict (dict): Used to create the `_EventInternalMetadata`
-            object.
+    (Note that while objects of this class are frozen, the
+    content/unsigned/internal_metadata fields are still mutable)
 
-    Returns:
-        EventBuilder
+    Attributes:
+        format_version (int): Event format version
+        room_id (str)
+        type (str)
+        sender (str)
+        content (dict)
+        unsigned (dict)
+        internal_metadata (_EventInternalMetadata)
+
+        _state (StateHandler)
+        _auth (synapse.api.Auth)
+        _store (DataStore)
+        _clock (Clock)
+        _hostname (str): The hostname of the server creating the event
+        _signing_key: The signing key to use to sign the event as the server
     """
-    if room_version in {
-        RoomVersions.V1,
-        RoomVersions.V2,
-        RoomVersions.VDH_TEST,
-        RoomVersions.STATE_V2_TEST,
-    }:
-        return EventBuilder(key_values, internal_metadata_dict)
-    else:
-        raise Exception(
-            "No event format defined for version %r" % (room_version,)
-        )
 
+    _state = attr.ib()
+    _auth = attr.ib()
+    _store = attr.ib()
+    _clock = attr.ib()
+    _hostname = attr.ib()
+    _signing_key = attr.ib()
+
+    format_version = attr.ib()
+
+    room_id = attr.ib()
+    type = attr.ib()
+    sender = attr.ib()
+
+    content = attr.ib(default=attr.Factory(dict))
+    unsigned = attr.ib(default=attr.Factory(dict))
+
+    # These only exist on a subset of events, so they raise AttributeError if
+    # someone tries to get them when they don't exist.
+    _state_key = attr.ib(default=None)
+    _redacts = attr.ib(default=None)
+
+    internal_metadata = attr.ib(default=attr.Factory(lambda: _EventInternalMetadata({})))
+
+    @property
+    def state_key(self):
+        if self._state_key is not None:
+            return self._state_key
+
+        raise AttributeError("state_key")
+
+    def is_state(self):
+        return self._state_key is not None
 
-class EventBuilder(EventBase):
-    def __init__(self, key_values={}, internal_metadata_dict={}):
-        signatures = copy.deepcopy(key_values.pop("signatures", {}))
-        unsigned = copy.deepcopy(key_values.pop("unsigned", {}))
+    @defer.inlineCallbacks
+    def build(self, prev_event_ids):
+        """Transform into a fully signed and hashed event
 
-        super(EventBuilder, self).__init__(
-            key_values,
-            signatures=signatures,
-            unsigned=unsigned,
-            internal_metadata_dict=internal_metadata_dict,
+        Args:
+            prev_event_ids (list[str]): The event IDs to use as the prev events
+
+        Returns:
+            Deferred[FrozenEvent]
+        """
+
+        state_ids = yield self._state.get_current_state_ids(
+            self.room_id, prev_event_ids,
+        )
+        auth_ids = yield self._auth.compute_auth_events(
+            self, state_ids,
         )
 
-    event_id = _event_dict_property("event_id")
-    state_key = _event_dict_property("state_key")
-    type = _event_dict_property("type")
+        auth_events = yield self._store.add_event_hashes(auth_ids)
+        prev_events = yield self._store.add_event_hashes(prev_event_ids)
 
-    def build(self):
-        return FrozenEvent.from_event(self)
+        old_depth = yield self._store.get_max_depth_of(
+            prev_event_ids,
+        )
+        depth = old_depth + 1
 
+        # we cap depth of generated events, to ensure that they are not
+        # rejected by other servers (and so that they can be persisted in
+        # the db)
+        depth = min(depth, MAX_DEPTH)
 
-class EventBuilderFactory(object):
-    def __init__(self, clock, hostname):
-        self.clock = clock
-        self.hostname = hostname
+        event_dict = {
+            "auth_events": auth_events,
+            "prev_events": prev_events,
+            "type": self.type,
+            "room_id": self.room_id,
+            "sender": self.sender,
+            "content": self.content,
+            "unsigned": self.unsigned,
+            "depth": depth,
+            "prev_state": [],
+        }
+
+        if self.is_state():
+            event_dict["state_key"] = self._state_key
 
-        self.event_id_count = 0
+        if self._redacts is not None:
+            event_dict["redacts"] = self._redacts
 
-    def create_event_id(self):
-        i = str(self.event_id_count)
-        self.event_id_count += 1
+        defer.returnValue(
+            create_local_event_from_event_dict(
+                clock=self._clock,
+                hostname=self._hostname,
+                signing_key=self._signing_key,
+                format_version=self.format_version,
+                event_dict=event_dict,
+                internal_metadata_dict=self.internal_metadata.get_dict(),
+            )
+        )
 
-        local_part = str(int(self.clock.time())) + i + random_string(5)
 
-        e_id = EventID(local_part, self.hostname)
+class EventBuilderFactory(object):
+    def __init__(self, hs):
+        self.clock = hs.get_clock()
+        self.hostname = hs.hostname
+        self.signing_key = hs.config.signing_key[0]
 
-        return e_id.to_string()
+        self.store = hs.get_datastore()
+        self.state = hs.get_state_handler()
+        self.auth = hs.get_auth()
 
-    def new(self, room_version, key_values={}):
+    def new(self, room_version, key_values):
         """Generate an event builder appropriate for the given room version
 
         Args:
@@ -98,27 +175,102 @@ class EventBuilderFactory(object):
         """
 
         # There's currently only the one event version defined
-        if room_version not in {
-            RoomVersions.V1,
-            RoomVersions.V2,
-            RoomVersions.VDH_TEST,
-            RoomVersions.STATE_V2_TEST,
-        }:
+        if room_version not in KNOWN_ROOM_VERSIONS:
             raise Exception(
                 "No event format defined for version %r" % (room_version,)
             )
 
-        key_values["event_id"] = self.create_event_id()
+        return EventBuilder(
+            store=self.store,
+            state=self.state,
+            auth=self.auth,
+            clock=self.clock,
+            hostname=self.hostname,
+            signing_key=self.signing_key,
+            format_version=room_version_to_event_format(room_version),
+            type=key_values["type"],
+            state_key=key_values.get("state_key"),
+            room_id=key_values["room_id"],
+            sender=key_values["sender"],
+            content=key_values.get("content", {}),
+            unsigned=key_values.get("unsigned", {}),
+            redacts=key_values.get("redacts", None),
+        )
+
+
+def create_local_event_from_event_dict(clock, hostname, signing_key,
+                                       format_version, event_dict,
+                                       internal_metadata_dict=None):
+    """Takes a fully formed event dict, ensuring that fields like `origin`
+    and `origin_server_ts` have correct values for a locally produced event,
+    then signs and hashes it.
+
+    Args:
+        clock (Clock)
+        hostname (str)
+        signing_key
+        format_version (int)
+        event_dict (dict)
+        internal_metadata_dict (dict|None)
+
+    Returns:
+        FrozenEvent
+    """
+
+    # There's currently only the one event version defined
+    if format_version not in KNOWN_EVENT_FORMAT_VERSIONS:
+        raise Exception(
+            "No event format defined for version %r" % (format_version,)
+        )
+
+    if internal_metadata_dict is None:
+        internal_metadata_dict = {}
+
+    time_now = int(clock.time_msec())
+
+    event_dict["event_id"] = _create_event_id(clock, hostname)
+
+    event_dict["origin"] = hostname
+    event_dict["origin_server_ts"] = time_now
+
+    event_dict.setdefault("unsigned", {})
+    age = event_dict["unsigned"].pop("age", 0)
+    event_dict["unsigned"].setdefault("age_ts", time_now - age)
+
+    event_dict.setdefault("signatures", {})
+
+    add_hashes_and_signatures(
+        event_dict,
+        hostname,
+        signing_key,
+    )
+    return event_type_from_format_version(format_version)(
+        event_dict, internal_metadata_dict=internal_metadata_dict,
+    )
+
+
+# A counter used when generating new event IDs
+_event_id_counter = 0
+
+
+def _create_event_id(clock, hostname):
+    """Create a new event ID
+
+    Args:
+        clock (Clock)
+        hostname (str): The server name for the event ID
+
+    Returns:
+        str
+    """
 
-        time_now = int(self.clock.time_msec())
+    global _event_id_counter
 
-        key_values.setdefault("origin", self.hostname)
-        key_values.setdefault("origin_server_ts", time_now)
+    i = str(_event_id_counter)
+    _event_id_counter += 1
 
-        key_values.setdefault("unsigned", {})
-        age = key_values["unsigned"].pop("age", 0)
-        key_values["unsigned"].setdefault("age_ts", time_now - age)
+    local_part = str(int(clock.time())) + i + random_string(5)
 
-        key_values["signatures"] = {}
+    e_id = EventID(local_part, hostname)
 
-        return EventBuilder(key_values=key_values,)
+    return e_id.to_string()
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index cacb1c8aaf..9b4acd2ed7 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -37,8 +37,7 @@ from synapse.api.errors import (
     HttpResponseException,
     SynapseError,
 )
-from synapse.crypto.event_signing import add_hashes_and_signatures
-from synapse.events import room_version_to_event_format
+from synapse.events import builder, room_version_to_event_format
 from synapse.federation.federation_base import FederationBase, event_from_pdu_json
 from synapse.util import logcontext, unwrapFirstError
 from synapse.util.caches.expiringcache import ExpiringCache
@@ -72,7 +71,8 @@ class FederationClient(FederationBase):
         self.state = hs.get_state_handler()
         self.transport_layer = hs.get_federation_transport_client()
 
-        self.event_builder_factory = hs.get_event_builder_factory()
+        self.hostname = hs.hostname
+        self.signing_key = hs.config.signing_key[0]
 
         self._get_pdu_cache = ExpiringCache(
             cache_name="get_pdu_cache",
@@ -608,18 +608,10 @@ class FederationClient(FederationBase):
             if "prev_state" not in pdu_dict:
                 pdu_dict["prev_state"] = []
 
-            # Strip off the fields that we want to clobber.
-            pdu_dict.pop("origin", None)
-            pdu_dict.pop("origin_server_ts", None)
-            pdu_dict.pop("unsigned", None)
-
-            builder = self.event_builder_factory.new(room_version, pdu_dict)
-            add_hashes_and_signatures(
-                builder,
-                self.hs.hostname,
-                self.hs.config.signing_key[0]
+            ev = builder.create_local_event_from_event_dict(
+                self._clock, self.hostname, self.signing_key,
+                format_version=event_format, event_dict=pdu_dict,
             )
-            ev = builder.build()
 
             defer.returnValue(
                 (destination, ev, event_format)
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 05d1370c18..37a7dca794 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -22,7 +22,7 @@ from canonicaljson import encode_canonical_json, json
 from twisted.internet import defer
 from twisted.internet.defer import succeed
 
-from synapse.api.constants import MAX_DEPTH, EventTypes, Membership, RoomVersions
+from synapse.api.constants import EventTypes, Membership, RoomVersions
 from synapse.api.errors import (
     AuthError,
     Codes,
@@ -31,7 +31,6 @@ from synapse.api.errors import (
     SynapseError,
 )
 from synapse.api.urls import ConsentURIBuilder
-from synapse.crypto.event_signing import add_hashes_and_signatures
 from synapse.events.utils import serialize_event
 from synapse.events.validator import EventValidator
 from synapse.replication.http.send_event import ReplicationSendEventRestServlet
@@ -545,40 +544,19 @@ class EventCreationHandler(object):
             prev_events_and_hashes = \
                 yield self.store.get_prev_events_for_room(builder.room_id)
 
-        if prev_events_and_hashes:
-            depth = max([d for _, _, d in prev_events_and_hashes]) + 1
-            # we cap depth of generated events, to ensure that they are not
-            # rejected by other servers (and so that they can be persisted in
-            # the db)
-            depth = min(depth, MAX_DEPTH)
-        else:
-            depth = 1
-
         prev_events = [
             (event_id, prev_hashes)
             for event_id, prev_hashes, _ in prev_events_and_hashes
         ]
 
-        builder.prev_events = prev_events
-        builder.depth = depth
-
-        context = yield self.state.compute_event_context(builder)
+        event = yield builder.build(
+            prev_event_ids=[p for p, _ in prev_events],
+        )
+        context = yield self.state.compute_event_context(event)
         if requester:
             context.app_service = requester.app_service
 
-        if builder.is_state():
-            builder.prev_state = yield self.store.add_event_hashes(
-                context.prev_state_events
-            )
-
-        yield self.auth.add_auth_events(builder, context)
-
-        signing_key = self.hs.config.signing_key[0]
-        add_hashes_and_signatures(
-            builder, self.server_name, signing_key
-        )
-
-        event = builder.build()
+        self.validator.validate_new(event)
 
         logger.debug(
             "Created event %s",
diff --git a/synapse/server.py b/synapse/server.py
index c8914302cf..6c52101616 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -355,10 +355,7 @@ class HomeServer(object):
         return Keyring(self)
 
     def build_event_builder_factory(self):
-        return EventBuilderFactory(
-            clock=self.get_clock(),
-            hostname=self.hostname,
-        )
+        return EventBuilderFactory(self)
 
     def build_filtering(self):
         return Filtering(self)
diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py
index d3b9dea1d6..38809ed0fc 100644
--- a/synapse/storage/event_federation.py
+++ b/synapse/storage/event_federation.py
@@ -125,6 +125,29 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
 
         return dict(txn)
 
+    @defer.inlineCallbacks
+    def get_max_depth_of(self, event_ids):
+        """Returns the max depth of a set of event IDs
+
+        Args:
+            event_ids (list[str])
+
+        Returns
+            Deferred[int]
+        """
+        rows = yield self._simple_select_many_batch(
+            table="events",
+            column="event_id",
+            iterable=event_ids,
+            retcols=("depth",),
+            desc="get_max_depth_of",
+        )
+
+        if not rows:
+            defer.returnValue(0)
+        else:
+            defer.returnValue(max(row["depth"] for row in rows))
+
     def _get_oldest_events_in_room_txn(self, txn, room_id):
         return self._simple_select_onecol_txn(
             txn,