diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index e1b1823cd7..c684265101 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -21,6 +21,8 @@ from synapse.api.constants import Membership, JoinRules
from synapse.api.errors import AuthError, StoreError, Codes, SynapseError
from synapse.api.events.room import (
RoomMemberEvent, RoomPowerLevelsEvent, RoomRedactionEvent,
+ RoomJoinRulesEvent, RoomOpsPowerLevelsEvent, InviteJoinEvent,
+ RoomCreateEvent,
)
from synapse.util.logutils import log_function
@@ -47,42 +49,60 @@ class Auth(object):
"""
try:
if hasattr(event, "room_id"):
+ if event.old_state_events is None:
+ # Oh, we don't know what the state of the room was, so we
+ # are trusting that this is allowed (at least for now)
+ defer.returnValue(True)
+
+ if hasattr(event, "outlier") and event.outlier is True:
+ # TODO (erikj): Auth for outliers is done differently.
+ defer.returnValue(True)
+
is_state = hasattr(event, "state_key")
+ if event.type == RoomCreateEvent.TYPE:
+ # FIXME
+ defer.returnValue(True)
+
if event.type == RoomMemberEvent.TYPE:
- yield self._can_replace_state(event)
- allowed = yield self.is_membership_change_allowed(event)
+ self._can_replace_state(event)
+ allowed = self.is_membership_change_allowed(event)
+ if allowed:
+ logger.debug("Allowing! %s", event)
+ else:
+ logger.debug("Denying! %s", event)
defer.returnValue(allowed)
return
- self._check_joined_room(
- member=snapshot.membership_state,
- user_id=snapshot.user_id,
- room_id=snapshot.room_id,
- )
+ if event.type != InviteJoinEvent.TYPE:
+ self.check_event_sender_in_room(event)
if is_state:
# TODO (erikj): This really only should be called for *new*
# state
yield self._can_add_state(event)
- yield self._can_replace_state(event)
+ self._can_replace_state(event)
else:
yield self._can_send_event(event)
if event.type == RoomPowerLevelsEvent.TYPE:
- yield self._check_power_levels(event)
+ self._check_power_levels(event)
if event.type == RoomRedactionEvent.TYPE:
- yield self._check_redaction(event)
+ self._check_redaction(event)
+
+ logger.debug("Allowing! %s", event)
defer.returnValue(True)
else:
raise AuthError(500, "Unknown event: %s" % event)
except AuthError as e:
logger.info("Event auth check failed on event %s with msg: %s",
event, e.msg)
+ logger.info("Denying! %s", event)
if raises:
raise e
+
defer.returnValue(False)
@defer.inlineCallbacks
@@ -98,45 +118,72 @@ class Auth(object):
pass
defer.returnValue(None)
+ def check_event_sender_in_room(self, event):
+ key = (RoomMemberEvent.TYPE, event.user_id, )
+ member_event = event.state_events.get(key)
+
+ return self._check_joined_room(
+ member_event,
+ event.user_id,
+ event.room_id
+ )
+
def _check_joined_room(self, member, user_id, room_id):
if not member or member.membership != Membership.JOIN:
raise AuthError(403, "User %s not in room %s (%s)" % (
user_id, room_id, repr(member)
))
- @defer.inlineCallbacks
+ @log_function
def is_membership_change_allowed(self, event):
target_user_id = event.state_key
- # does this room even exist
- room = yield self.store.get_room(event.room_id)
- if not room:
- raise AuthError(403, "Room does not exist")
-
# get info about the caller
- try:
- caller = yield self.store.get_room_member(
- user_id=event.user_id,
- room_id=event.room_id)
- except:
- caller = None
+ key = (RoomMemberEvent.TYPE, event.user_id, )
+ caller = event.old_state_events.get(key)
+
caller_in_room = caller and caller.membership == "join"
# get info about the target
- try:
- target = yield self.store.get_room_member(
- user_id=target_user_id,
- room_id=event.room_id)
- except:
- target = None
+ key = (RoomMemberEvent.TYPE, target_user_id, )
+ target = event.old_state_events.get(key)
+
target_in_room = target and target.membership == "join"
membership = event.content["membership"]
- join_rule = yield self.store.get_room_join_rule(event.room_id)
- if not join_rule:
+ key = (RoomJoinRulesEvent.TYPE, "", )
+ join_rule_event = event.old_state_events.get(key)
+ if join_rule_event:
+ join_rule = join_rule_event.content.get(
+ "join_rule", JoinRules.INVITE
+ )
+ else:
join_rule = JoinRules.INVITE
+ user_level = self._get_power_level_from_event_state(
+ event,
+ event.user_id,
+ )
+
+ ban_level, kick_level, redact_level = (
+ self._get_ops_level_from_event_state(
+ event
+ )
+ )
+
+ logger.debug(
+ "is_membership_change_allowed: %s",
+ {
+ "caller_in_room": caller_in_room,
+ "target_in_room": target_in_room,
+ "membership": membership,
+ "join_rule": join_rule,
+ "target_user_id": target_user_id,
+ "event.user_id": event.user_id,
+ }
+ )
+
if Membership.INVITE == membership:
# TODO (erikj): We should probably handle this more intelligently
# PRIVATE join rules.
@@ -153,13 +200,10 @@ class Auth(object):
# joined: It's a NOOP
if event.user_id != target_user_id:
raise AuthError(403, "Cannot force another user to join.")
- elif join_rule == JoinRules.PUBLIC or room.is_public:
+ elif join_rule == JoinRules.PUBLIC:
pass
elif join_rule == JoinRules.INVITE:
- if (
- not caller or caller.membership not in
- [Membership.INVITE, Membership.JOIN]
- ):
+ if not caller_in_room:
raise AuthError(403, "You are not invited to this room.")
else:
# TODO (erikj): may_join list
@@ -171,29 +215,16 @@ class Auth(object):
if not caller_in_room: # trying to leave a room you aren't joined
raise AuthError(403, "You are not in room %s." % event.room_id)
elif target_user_id != event.user_id:
- user_level = yield self.store.get_power_level(
- event.room_id,
- event.user_id,
- )
- _, kick_level, _ = yield self.store.get_ops_levels(event.room_id)
-
if kick_level:
kick_level = int(kick_level)
else:
- kick_level = 50
+ kick_level = 50 # FIXME (erikj): What should we do here?
if user_level < kick_level:
raise AuthError(
403, "You cannot kick user %s." % target_user_id
)
elif Membership.BAN == membership:
- user_level = yield self.store.get_power_level(
- event.room_id,
- event.user_id,
- )
-
- ban_level, _, _ = yield self.store.get_ops_levels(event.room_id)
-
if ban_level:
ban_level = int(ban_level)
else:
@@ -204,7 +235,30 @@ class Auth(object):
else:
raise AuthError(500, "Unknown membership %s" % membership)
- defer.returnValue(True)
+ return True
+
+ def _get_power_level_from_event_state(self, event, user_id):
+ key = (RoomPowerLevelsEvent.TYPE, "", )
+ power_level_event = event.old_state_events.get(key)
+ level = None
+ if power_level_event:
+ level = power_level_event.content.get(user_id)
+ if level is None:
+ level = power_level_event.content.get("default", 0)
+
+ return level
+
+ def _get_ops_level_from_event_state(self, event):
+ key = (RoomOpsPowerLevelsEvent.TYPE, "", )
+ ops_event = event.old_state_events.get(key)
+
+ if ops_event:
+ return (
+ ops_event.content.get("ban_level"),
+ ops_event.content.get("kick_level"),
+ ops_event.content.get("redact_level"),
+ )
+ return None, None, None,
@defer.inlineCallbacks
def get_user_by_req(self, request):
@@ -282,8 +336,8 @@ class Auth(object):
else:
send_level = 0
- user_level = yield self.store.get_power_level(
- event.room_id,
+ user_level = self._get_power_level_from_event_state(
+ event,
event.user_id,
)
@@ -308,8 +362,8 @@ class Auth(object):
add_level = int(add_level)
- user_level = yield self.store.get_power_level(
- event.room_id,
+ user_level = self._get_power_level_from_event_state(
+ event,
event.user_id,
)
@@ -322,19 +376,9 @@ class Auth(object):
defer.returnValue(True)
- @defer.inlineCallbacks
def _can_replace_state(self, event):
- current_state = yield self.store.get_current_state(
- event.room_id,
- event.type,
- event.state_key,
- )
-
- if current_state:
- current_state = current_state[0]
-
- user_level = yield self.store.get_power_level(
- event.room_id,
+ user_level = self._get_power_level_from_event_state(
+ event,
event.user_id,
)
@@ -346,6 +390,10 @@ class Auth(object):
logger.debug(
"Checking power level for %s, %s", event.user_id, user_level
)
+
+ key = (event.type, event.state_key, )
+ current_state = event.old_state_events.get(key)
+
if current_state and hasattr(current_state, "required_power_level"):
req = current_state.required_power_level
@@ -356,10 +404,9 @@ class Auth(object):
"You don't have permission to change that state"
)
- @defer.inlineCallbacks
def _check_redaction(self, event):
- user_level = yield self.store.get_power_level(
- event.room_id,
+ user_level = self._get_power_level_from_event_state(
+ event,
event.user_id,
)
@@ -368,7 +415,9 @@ class Auth(object):
else:
user_level = 0
- _, _, redact_level = yield self.store.get_ops_levels(event.room_id)
+ _, _, redact_level = self._get_ops_level_from_event_state(
+ event
+ )
if not redact_level:
redact_level = 50
@@ -379,7 +428,6 @@ class Auth(object):
"You don't have permission to redact events"
)
- @defer.inlineCallbacks
def _check_power_levels(self, event):
for k, v in event.content.items():
if k == "default":
@@ -399,19 +447,16 @@ class Auth(object):
except:
raise SynapseError(400, "Not a valid power level: %s" % (v,))
- current_state = yield self.store.get_current_state(
- event.room_id,
- event.type,
- event.state_key,
- )
+ key = (event.type, event.state_key, )
+ current_state = event.old_state_events.get(key)
if not current_state:
return
- else:
- current_state = current_state[0]
- user_level = yield self.store.get_power_level(
- event.room_id,
+ user_level = self._get_power_level_from_event_state(
+ event,
event.user_id,
)
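
Taken together, the auth.py changes replace every store lookup with a dict lookup against the event's snapshot of room state, keyed by (event type, state key). A minimal sketch of that pattern; the m.room.power_levels type string and the level values are illustrative, not taken from this patch:

    class FakeEvent(object):
        def __init__(self, content):
            self.content = content

    # Stand-in for event.old_state_events: state keyed by (type, state_key).
    old_state = {
        ("m.room.power_levels", ""): FakeEvent(
            {"@alice:example.com": 100, "default": 0}
        ),
    }

    def get_power_level(old_state, user_id):
        power_event = old_state.get(("m.room.power_levels", ""))
        if not power_event:
            return None
        level = power_event.content.get(user_id)
        if level is None:
            level = power_event.content.get("default", 0)
        return level

    print(get_power_level(old_state, "@alice:example.com"))  # 100
    print(get_power_level(old_state, "@bob:example.com"))    # 0
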
diff --git a/synapse/api/events/__init__.py b/synapse/api/events/__init__.py
index f66fea2904..b855811b98 100644
--- a/synapse/api/events/__init__.py
+++ b/synapse/api/events/__init__.py
@@ -65,13 +65,15 @@ class SynapseEvent(JsonEncodedObject):
internal_keys = [
"is_state",
- "prev_events",
"depth",
"destinations",
"origin",
"outlier",
"power_level",
"redacted",
+ "prev_events",
+ "hashes",
+ "signatures",
]
required_keys = [
diff --git a/synapse/api/events/factory.py b/synapse/api/events/factory.py
index 74d0ef77f4..750096c618 100644
--- a/synapse/api/events/factory.py
+++ b/synapse/api/events/factory.py
@@ -21,6 +21,8 @@ from synapse.api.events.room import (
RoomRedactionEvent,
)
+from synapse.types import EventID
+
from synapse.util.stringutils import random_string
@@ -51,12 +53,22 @@ class EventFactory(object):
self.clock = hs.get_clock()
self.hs = hs
+ self.event_id_count = 0
+
+ def create_event_id(self):
+ i = str(self.event_id_count)
+ self.event_id_count += 1
+
+ local_part = str(int(self.clock.time())) + i + random_string(5)
+
+ e_id = EventID.create_local(local_part, self.hs)
+
+ return e_id.to_string()
+
def create_event(self, etype=None, **kwargs):
kwargs["type"] = etype
if "event_id" not in kwargs:
- kwargs["event_id"] = "%s@%s" % (
- random_string(10), self.hs.hostname
- )
+ kwargs["event_id"] = self.create_event_id()
if "origin_server_ts" not in kwargs:
kwargs["origin_server_ts"] = int(self.clock.time_msec())
diff --git a/synapse/api/events/utils.py b/synapse/api/events/utils.py
index c3a32be8c1..7fdf45a264 100644
--- a/synapse/api/events/utils.py
+++ b/synapse/api/events/utils.py
@@ -27,7 +27,14 @@ def prune_event(event):
the user has specified, but we do want to keep necessary information like
type, state_key etc.
"""
+ return _prune_event_or_pdu(event.type, event)
+def prune_pdu(pdu):
+ """Removes keys that contain unrestricted and non-essential data from a PDU
+ """
+ return _prune_event_or_pdu(pdu.pdu_type, pdu)
+
+def _prune_event_or_pdu(event_type, event):
# Remove all extraneous fields.
event.unrecognized_keys = {}
@@ -38,25 +45,25 @@ def prune_event(event):
if field in event.content:
new_content[field] = event.content[field]
- if event.type == RoomMemberEvent.TYPE:
+ if event_type == RoomMemberEvent.TYPE:
add_fields("membership")
- elif event.type == RoomCreateEvent.TYPE:
+ elif event_type == RoomCreateEvent.TYPE:
add_fields("creator")
- elif event.type == RoomJoinRulesEvent.TYPE:
+ elif event_type == RoomJoinRulesEvent.TYPE:
add_fields("join_rule")
- elif event.type == RoomPowerLevelsEvent.TYPE:
+ elif event_type == RoomPowerLevelsEvent.TYPE:
# TODO: Actually check these are valid user_ids etc.
add_fields("default")
for k, v in event.content.items():
if k.startswith("@") and isinstance(v, (int, long)):
new_content[k] = v
- elif event.type == RoomAddStateLevelEvent.TYPE:
+ elif event_type == RoomAddStateLevelEvent.TYPE:
add_fields("level")
- elif event.type == RoomSendEventLevelEvent.TYPE:
+ elif event_type == RoomSendEventLevelEvent.TYPE:
add_fields("level")
- elif event.type == RoomOpsPowerLevelsEvent.TYPE:
+ elif event_type == RoomOpsPowerLevelsEvent.TYPE:
add_fields("kick_level", "ban_level", "redact_level")
- elif event.type == RoomAliasesEvent.TYPE:
+ elif event_type == RoomAliasesEvent.TYPE:
add_fields("aliases")
event.content = new_content
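
The refactor makes the field whitelist reusable for both events and PDUs. The effect on a member event's content, as a toy example (assuming the usual m.room.* values for the TYPE constants):

    # Mirrors the add_fields() calls above; mapping is illustrative.
    KEPT_CONTENT_FIELDS = {
        "m.room.member": ["membership"],
        "m.room.create": ["creator"],
        "m.room.join_rules": ["join_rule"],
    }

    def prune_content(event_type, content):
        allowed = KEPT_CONTENT_FIELDS.get(event_type, [])
        return dict((k, v) for k, v in content.items() if k in allowed)

    before = {"membership": "join", "displayname": "Alice"}
    print(prune_content("m.room.member", before))  # {'membership': 'join'}
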
diff --git a/synapse/crypto/event_signing.py b/synapse/crypto/event_signing.py
new file mode 100644
index 0000000000..61edd2c6f9
--- /dev/null
+++ b/synapse/crypto/event_signing.py
@@ -0,0 +1,85 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2014 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from synapse.federation.units import Pdu
+from synapse.api.events.utils import prune_pdu
+from syutil.jsonutil import encode_canonical_json
+from syutil.base64util import encode_base64, decode_base64
+from syutil.crypto.jsonsign import sign_json, verify_signed_json
+
+import hashlib
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+def add_event_pdu_content_hash(pdu, hash_algorithm=hashlib.sha256):
+ hashed = _compute_content_hash(pdu, hash_algorithm)
+ pdu.hashes[hashed.name] = encode_base64(hashed.digest())
+ return pdu
+
+
+def check_event_pdu_content_hash(pdu, hash_algorithm=hashlib.sha256):
+ """Check whether the hash for this PDU matches the contents"""
+ computed_hash = _compute_content_hash(pdu, hash_algorithm)
+ if computed_hash.name not in pdu.hashes:
+ raise Exception("Algorithm %s not in hashes %s" % (
+ computed_hash.name, list(pdu.hashes)
+ ))
+ message_hash_base64 = pdu.hashes[computed_hash.name]
+ try:
+ message_hash_bytes = decode_base64(message_hash_base64)
+ except:
+ raise Exception("Invalid base64: %s" % (message_hash_base64,))
+ return message_hash_bytes == computed_hash.digest()
+
+
+def _compute_content_hash(pdu, hash_algorithm):
+ pdu_json = pdu.get_dict()
+ # TODO: Make "age_ts" key internal
+ pdu_json.pop("age_ts", None)
+ pdu_json.pop("unsigned", None)
+ pdu_json.pop("signatures", None)
+ pdu_json.pop("hashes", None)
+ pdu_json_bytes = encode_canonical_json(pdu_json)
+ return hash_algorithm(pdu_json_bytes)
+
+
+def compute_pdu_event_reference_hash(pdu, hash_algorithm=hashlib.sha256):
+ tmp_pdu = Pdu(**pdu.get_dict())
+ tmp_pdu = prune_pdu(tmp_pdu)
+ pdu_json = tmp_pdu.get_dict()
+ pdu_json.pop("signatures", None)
+ pdu_json_bytes = encode_canonical_json(pdu_json)
+ hashed = hash_algorithm(pdu_json_bytes)
+ return (hashed.name, hashed.digest())
+
+
+def sign_event_pdu(pdu, signature_name, signing_key):
+ tmp_pdu = Pdu(**pdu.get_dict())
+ tmp_pdu = prune_pdu(tmp_pdu)
+ pdu_json = tmp_pdu.get_dict()
+ pdu_json = sign_json(pdu_json, signature_name, signing_key)
+ pdu.signatures = pdu_json["signatures"]
+ return pdu
+
+
+def verify_signed_event_pdu(pdu, signature_name, verify_key):
+ tmp_pdu = Pdu(**pdu.get_dict())
+ tmp_pdu = prune_pdu(tmp_pdu)
+ pdu_json = tmp_pdu.get_dict()
+ verify_signed_json(pdu_json, signature_name, verify_key)
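
Content hashes are computed over the canonical JSON of the PDU after the mutable and meta keys (age_ts, unsigned, signatures, hashes) have been stripped. A stdlib approximation of that flow; syutil's canonical encoder and unpadded base64 are only approximated here, so real hash values will differ:

    import base64
    import hashlib
    import json

    def compute_content_hash(pdu_dict):
        # Stand-in for _compute_content_hash: sort_keys plus compact
        # separators approximates syutil's canonical JSON (the real encoder
        # also guarantees UTF-8 output and rejects floats).
        pdu = dict(pdu_dict)
        for key in ("age_ts", "unsigned", "signatures", "hashes"):
            pdu.pop(key, None)
        pdu_bytes = json.dumps(pdu, sort_keys=True,
                               separators=(",", ":")).encode("utf-8")
        return hashlib.sha256(pdu_bytes)

    pdu = {"pdu_id": "abc", "origin": "example.com",
           "content": {"body": "hi"}, "hashes": {}}
    hashed = compute_content_hash(pdu)
    # syutil's encode_base64 is unpadded; base64.b64encode keeps padding.
    pdu["hashes"]["sha256"] = base64.b64encode(hashed.digest()).decode("ascii")
    print(pdu["hashes"])
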
diff --git a/synapse/federation/pdu_codec.py b/synapse/federation/pdu_codec.py
index e8180d94fd..6d31286290 100644
--- a/synapse/federation/pdu_codec.py
+++ b/synapse/federation/pdu_codec.py
@@ -14,41 +14,43 @@
# limitations under the License.
from .units import Pdu
+from synapse.crypto.event_signing import (
+ add_event_pdu_content_hash, sign_event_pdu
+)
+from synapse.types import EventID
import copy
-def decode_event_id(event_id, server_name):
- parts = event_id.split("@")
- if len(parts) < 2:
- return (event_id, server_name)
- else:
- return (parts[0], "".join(parts[1:]))
-
-
-def encode_event_id(pdu_id, origin):
- return "%s@%s" % (pdu_id, origin)
-
-
class PduCodec(object):
def __init__(self, hs):
+ self.signing_key = hs.config.signing_key[0]
self.server_name = hs.hostname
self.event_factory = hs.get_event_factory()
self.clock = hs.get_clock()
+ self.hs = hs
+
+ def encode_event_id(self, local, domain):
+ return EventID.create(local, domain, self.hs).to_string()
+
+ def decode_event_id(self, event_id):
+ e_id = self.hs.parse_eventid(event_id)
+ return e_id.localpart, e_id.domain
def event_from_pdu(self, pdu):
kwargs = {}
- kwargs["event_id"] = encode_event_id(pdu.pdu_id, pdu.origin)
+ kwargs["event_id"] = self.encode_event_id(pdu.pdu_id, pdu.origin)
kwargs["room_id"] = pdu.context
kwargs["etype"] = pdu.pdu_type
kwargs["prev_events"] = [
- encode_event_id(p[0], p[1]) for p in pdu.prev_pdus
+ (self.encode_event_id(i, o), s)
+ for i, o, s in pdu.prev_pdus
]
if hasattr(pdu, "prev_state_id") and hasattr(pdu, "prev_state_origin"):
- kwargs["prev_state"] = encode_event_id(
+ kwargs["prev_state"] = self.encode_event_id(
pdu.prev_state_id, pdu.prev_state_origin
)
@@ -70,21 +72,24 @@ class PduCodec(object):
def pdu_from_event(self, event):
d = event.get_full_dict()
- d["pdu_id"], d["origin"] = decode_event_id(
- event.event_id, self.server_name
+ d["pdu_id"], d["origin"] = self.decode_event_id(
+ event.event_id
)
d["context"] = event.room_id
d["pdu_type"] = event.type
if hasattr(event, "prev_events"):
+ def f(e, s):
+ i, o = self.decode_event_id(e)
+ return i, o, s
d["prev_pdus"] = [
- decode_event_id(e, self.server_name)
- for e in event.prev_events
+ f(e, s)
+ for e, s in event.prev_events
]
if hasattr(event, "prev_state"):
d["prev_state_id"], d["prev_state_origin"] = (
- decode_event_id(event.prev_state, self.server_name)
+ self.decode_event_id(event.prev_state)
)
if hasattr(event, "state_key"):
@@ -99,4 +104,6 @@ class PduCodec(object):
if "origin_server_ts" not in kwargs:
kwargs["origin_server_ts"] = int(self.clock.time_msec())
- return Pdu(**kwargs)
+ pdu = Pdu(**kwargs)
+ pdu = add_event_pdu_content_hash(pdu)
+ return sign_event_pdu(pdu, self.server_name, self.signing_key)
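
The codec now round-trips prev_pdus as (pdu_id, origin, hashes) triples rather than bare pairs, and event IDs are parsed as typed IDs rather than split on "@". A toy round-trip; the "$" sigil and the localpart:domain split are assumptions about synapse.types.EventID:

    def encode(pdu_id, origin):
        # Assumed EventID string form; not shown in this patch.
        return "$%s:%s" % (pdu_id, origin)

    def decode(event_id):
        return tuple(event_id[1:].split(":", 1))

    # prev_pdus now carry per-edge hashes alongside id and origin.
    prev_pdus = [("pdu1", "example.com", {"sha256": "aGFzaA"})]

    prev_events = [(encode(i, o), s) for i, o, s in prev_pdus]
    print(prev_events)  # [('$pdu1:example.com', {'sha256': 'aGFzaA'})]

    round_tripped = [decode(e) + (s,) for e, s in prev_events]
    print(round_tripped == prev_pdus)  # True
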
diff --git a/synapse/federation/replication.py b/synapse/federation/replication.py
index 092411eaf9..000a3081c2 100644
--- a/synapse/federation/replication.py
+++ b/synapse/federation/replication.py
@@ -244,13 +244,14 @@ class ReplicationLayer(object):
pdu = None
if pdu_list:
pdu = pdu_list[0]
- yield self._handle_new_pdu(pdu)
+ yield self._handle_new_pdu(destination, pdu)
defer.returnValue(pdu)
@defer.inlineCallbacks
@log_function
- def get_state_for_context(self, destination, context):
+ def get_state_for_context(self, destination, context, pdu_id=None,
+ pdu_origin=None):
"""Requests all of the `current` state PDUs for a given context from
a remote home server.
@@ -263,13 +264,14 @@ class ReplicationLayer(object):
"""
transaction_data = yield self.transport_layer.get_context_state(
- destination, context)
+ destination, context, pdu_id=pdu_id, pdu_origin=pdu_origin,
+ )
transaction = Transaction(**transaction_data)
pdus = [Pdu(outlier=True, **p) for p in transaction.pdus]
for pdu in pdus:
- yield self._handle_new_pdu(pdu)
+ yield self._handle_new_pdu(destination, pdu)
defer.returnValue(pdus)
@@ -295,6 +297,10 @@ class ReplicationLayer(object):
transaction = Transaction(**transaction_data)
for p in transaction.pdus:
+ if "unsigned" in p:
+ unsigned = p["unsigned"]
+ if "age" in unsigned:
+ p["age"] = unsigned["age"]
if "age" in p:
p["age_ts"] = int(self._clock.time_msec()) - int(p["age"])
del p["age"]
@@ -315,7 +321,7 @@ class ReplicationLayer(object):
dl = []
for pdu in pdu_list:
- dl.append(self._handle_new_pdu(pdu))
+ dl.append(self._handle_new_pdu(transaction.origin, pdu))
if hasattr(transaction, "edus"):
for edu in [Edu(**x) for x in transaction.edus]:
@@ -347,14 +353,19 @@ class ReplicationLayer(object):
@defer.inlineCallbacks
@log_function
- def on_context_state_request(self, context):
- results = yield self.store.get_current_state_for_context(
- context
- )
+ def on_context_state_request(self, context, pdu_id, pdu_origin):
+ if pdu_id and pdu_origin:
+ pdus = yield self.handler.get_state_for_pdu(
+ pdu_id, pdu_origin
+ )
+ else:
+ results = yield self.store.get_current_state_for_context(
+ context
+ )
+ pdus = [Pdu.from_pdu_tuple(p) for p in results]
- logger.debug("Context returning %d results", len(results))
+ logger.debug("Context returning %d results", len(pdus))
- pdus = [Pdu.from_pdu_tuple(p) for p in results]
defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict()))
@defer.inlineCallbacks
@@ -393,9 +404,55 @@ class ReplicationLayer(object):
response = yield self.query_handlers[query_type](args)
defer.returnValue((200, response))
else:
- defer.returnValue((404, "No handler for Query type '%s'"
- % (query_type)
- ))
+ defer.returnValue(
+ (404, "No handler for Query type '%s'" % (query_type, ))
+ )
+
+ @defer.inlineCallbacks
+ def on_make_join_request(self, context, user_id):
+ pdu = yield self.handler.on_make_join_request(context, user_id)
+ defer.returnValue(pdu.get_dict())
+
+ @defer.inlineCallbacks
+ def on_invite_request(self, origin, content):
+ pdu = Pdu(**content)
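+ # NOTE: this reuses the send_join handler for now; FederationHandler
+ # does not yet define a dedicated on_invite_request.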
+ ret_pdu = yield self.handler.on_send_join_request(origin, pdu)
+ defer.returnValue((200, ret_pdu.get_dict()))
+
+ @defer.inlineCallbacks
+ def on_send_join_request(self, origin, content):
+ pdu = Pdu(**content)
+ state = yield self.handler.on_send_join_request(origin, pdu)
+ defer.returnValue((200, self._transaction_from_pdus(state).get_dict()))
+
+ @defer.inlineCallbacks
+ def make_join(self, destination, context, user_id):
+ pdu_dict = yield self.transport_layer.make_join(
+ destination=destination,
+ context=context,
+ user_id=user_id,
+ )
+
+ logger.debug("Got response to make_join: %s", pdu_dict)
+
+ defer.returnValue(Pdu(**pdu_dict))
+
+ @defer.inlineCallbacks
+ def send_join(self, destination, pdu):
+ _, content = yield self.transport_layer.send_join(
+ destination,
+ pdu.context,
+ pdu.pdu_id,
+ pdu.origin,
+ pdu.get_dict(),
+ )
+
+ logger.debug("Got content: %s", content)
+ pdus = [Pdu(outlier=True, **p) for p in content.get("pdus", [])]
+ for pdu in pdus:
+ yield self._handle_new_pdu(destination, pdu)
+
+ defer.returnValue(pdus)
@defer.inlineCallbacks
@log_function
@@ -414,20 +471,22 @@ class ReplicationLayer(object):
transmission.
"""
pdus = [p.get_dict() for p in pdu_list]
+ time_now = self._clock.time_msec()
for p in pdus:
- if "age_ts" in pdus:
- p["age"] = int(self.clock.time_msec()) - p["age_ts"]
-
+ if "age_ts" in p:
+ age = time_now - p["age_ts"]
+ p.setdefault("unsigned", {})["age"] = int(age)
+ del p["age_ts"]
return Transaction(
origin=self.server_name,
pdus=pdus,
- origin_server_ts=int(self._clock.time_msec()),
+ origin_server_ts=int(time_now),
destination=None,
)
@defer.inlineCallbacks
@log_function
- def _handle_new_pdu(self, pdu, backfilled=False):
+ def _handle_new_pdu(self, origin, pdu, backfilled=False):
# We reprocess pdus when we have seen them only as outliers
existing = yield self._get_persisted_pdu(pdu.pdu_id, pdu.origin)
@@ -436,6 +495,8 @@ class ReplicationLayer(object):
defer.returnValue({})
return
+ state = None
+
# Get missing pdus if necessary.
is_new = yield self.pdu_actions.is_new(pdu)
if is_new and not pdu.outlier:
@@ -443,7 +504,7 @@ class ReplicationLayer(object):
min_depth = yield self.store.get_min_depth_for_context(pdu.context)
if min_depth and pdu.depth > min_depth:
- for pdu_id, origin in pdu.prev_pdus:
+ for pdu_id, origin, hashes in pdu.prev_pdus:
exists = yield self._get_persisted_pdu(pdu_id, origin)
if not exists:
@@ -459,12 +520,22 @@ class ReplicationLayer(object):
except:
# TODO(erikj): Do some more intelligent retries.
logger.exception("Failed to get PDU")
+ else:
+ # We need to get the state at this event, since we have reached
+ # a backward extremity edge.
+ state = yield self.get_state_for_context(
+ origin, pdu.context, pdu.pdu_id, pdu.origin,
+ )
# Persist the Pdu, but don't mark it as processed yet.
yield self.store.persist_event(pdu=pdu)
if not backfilled:
- ret = yield self.handler.on_receive_pdu(pdu, backfilled=backfilled)
+ ret = yield self.handler.on_receive_pdu(
+ pdu,
+ backfilled=backfilled,
+ state=state,
+ )
else:
ret = None
@@ -589,7 +660,7 @@ class _TransactionQueue(object):
logger.debug("TX [%s] Persisting transaction...", destination)
transaction = Transaction.create_new(
- origin_server_ts=self._clock.time_msec(),
+ origin_server_ts=int(self._clock.time_msec()),
transaction_id=str(self._next_txn_id),
origin=self.server_name,
destination=destination,
@@ -614,7 +685,9 @@ class _TransactionQueue(object):
if "pdus" in data:
for p in data["pdus"]:
if "age_ts" in p:
- p["age"] = now - int(p["age_ts"])
+ unsigned = p.setdefault("unsigned", {})
+ unsigned["age"] = now - int(p["age_ts"])
+ del p["age_ts"]
return data
code, response = yield self.transport_layer.send_transaction(
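
The age bookkeeping now moves the relative "age" under "unsigned" on the wire, keeping it out of the hashed and signed content, while "age_ts" stays the absolute form used internally. Both directions in isolation, as a sketch:

    def outbound(p, now_ms):
        # Sender side (_transaction_from_pdus): absolute age_ts becomes a
        # relative "age" under "unsigned".
        if "age_ts" in p:
            p.setdefault("unsigned", {})["age"] = int(now_ms - p["age_ts"])
            del p["age_ts"]
        return p

    def inbound(p, now_ms):
        # Receiver side: lift unsigned.age back out and restore age_ts.
        unsigned = p.get("unsigned", {})
        if "age" in unsigned:
            p["age"] = unsigned["age"]
        if "age" in p:
            p["age_ts"] = int(now_ms) - int(p["age"])
            del p["age"]
        return p

    p = outbound({"age_ts": 1000}, now_ms=5000)
    print(p)                 # {'unsigned': {'age': 4000}}
    print(inbound(p, 9000))  # age_ts restored as 9000 - 4000 = 5000
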
diff --git a/synapse/federation/transport.py b/synapse/federation/transport.py
index e7517cac4d..7f01b4faaf 100644
--- a/synapse/federation/transport.py
+++ b/synapse/federation/transport.py
@@ -72,7 +72,8 @@ class TransportLayer(object):
self.received_handler = None
@log_function
- def get_context_state(self, destination, context):
+ def get_context_state(self, destination, context, pdu_id=None,
+ pdu_origin=None):
""" Requests all state for a given context (i.e. room) from the
given server.
@@ -89,7 +90,14 @@ class TransportLayer(object):
subpath = "/state/%s/" % context
- return self._do_request_for_transaction(destination, subpath)
+ args = {}
+ if pdu_id and pdu_origin:
+ args["pdu_id"] = pdu_id
+ args["pdu_origin"] = pdu_origin
+
+ return self._do_request_for_transaction(
+ destination, subpath, args=args
+ )
@log_function
def get_pdu(self, destination, pdu_origin, pdu_id):
@@ -135,8 +143,10 @@ class TransportLayer(object):
subpath = "/backfill/%s/" % context
- args = {"v": ["%s,%s" % (i, o) for i, o in pdu_tuples]}
- args["limit"] = limit
+ args = {
+ "v": ["%s,%s" % (i, o) for i, o in pdu_tuples],
+ "limit": limit,
+ }
return self._do_request_for_transaction(
dest,
@@ -198,6 +208,59 @@ class TransportLayer(object):
defer.returnValue(response)
@defer.inlineCallbacks
+ @log_function
+ def make_join(self, destination, context, user_id, retry_on_dns_fail=True):
+ path = PREFIX + "/make_join/%s/%s" % (context, user_id,)
+
+ response = yield self.client.get_json(
+ destination=destination,
+ path=path,
+ retry_on_dns_fail=retry_on_dns_fail,
+ )
+
+ defer.returnValue(response)
+
+ @defer.inlineCallbacks
+ @log_function
+ def send_join(self, destination, context, pdu_id, origin, content):
+ path = PREFIX + "/send_join/%s/%s/%s" % (
+ context,
+ origin,
+ pdu_id,
+ )
+
+ code, content = yield self.client.put_json(
+ destination=destination,
+ path=path,
+ data=content,
+ )
+
+ if not 200 <= code < 300:
+ raise RuntimeError("Got %d from send_join", code)
+
+ defer.returnValue(json.loads(content))
+
+ @defer.inlineCallbacks
+ @log_function
+ def send_invite(self, destination, context, pdu_id, origin, content):
+ path = PREFIX + "/invite/%s/%s/%s" % (
+ context,
+ origin,
+ pdu_id,
+ )
+
+ code, content = yield self.client.put_json(
+ destination=destination,
+ path=path,
+ data=content,
+ )
+
+ if not 200 <= code < 300:
+ raise RuntimeError("Got %d from send_invite", code)
+
+ defer.returnValue(json.loads(content))
+
+ @defer.inlineCallbacks
def _authenticate_request(self, request):
json_request = {
"method": request.method,
@@ -326,7 +389,11 @@ class TransportLayer(object):
re.compile("^" + PREFIX + "/state/([^/]*)/$"),
self._with_authentication(
lambda origin, content, query, context:
- handler.on_context_state_request(context)
+ handler.on_context_state_request(
+ context,
+ query.get("pdu_id", [None])[0],
+ query.get("pdu_origin", [None])[0]
+ )
)
)
@@ -362,6 +429,39 @@ class TransportLayer(object):
)
)
+ self.server.register_path(
+ "GET",
+ re.compile("^" + PREFIX + "/make_join/([^/]*)/([^/]*)$"),
+ self._with_authentication(
+ lambda origin, content, query, context, user_id:
+ self._on_make_join_request(
+ origin, content, query, context, user_id
+ )
+ )
+ )
+
+ self.server.register_path(
+ "PUT",
+ re.compile("^" + PREFIX + "/send_join/([^/]*)/([^/]*)/([^/]*)$"),
+ self._with_authentication(
+ lambda origin, content, query, context, pdu_origin, pdu_id:
+ self._on_send_join_request(
+ origin, content, query,
+ )
+ )
+ )
+
+ self.server.register_path(
+ "PUT",
+ re.compile("^" + PREFIX + "/invite/([^/]*)/([^/]*)/([^/]*)$"),
+ self._with_authentication(
+ lambda origin, content, query, context, pdu_origin, pdu_id:
+ self._on_invite_request(
+ origin, content, query,
+ )
+ )
+ )
+
@defer.inlineCallbacks
@log_function
def _on_send_request(self, origin, content, query, transaction_id):
@@ -451,7 +551,34 @@ class TransportLayer(object):
versions = [v.split(",", 1) for v in v_list]
return self.request_handler.on_backfill_request(
- context, versions, limit)
+ context, versions, limit
+ )
+
+ @defer.inlineCallbacks
+ @log_function
+ def _on_make_join_request(self, origin, content, query, context, user_id):
+ content = yield self.request_handler.on_make_join_request(
+ context, user_id,
+ )
+ defer.returnValue((200, content))
+
+ @defer.inlineCallbacks
+ @log_function
+ def _on_send_join_request(self, origin, content, query):
+ content = yield self.request_handler.on_send_join_request(
+ origin, content,
+ )
+
+ defer.returnValue((200, content))
+
+ @defer.inlineCallbacks
+ @log_function
+ def _on_invite_request(self, origin, content, query):
+ content = yield self.request_handler.on_invite_request(
+ origin, content,
+ )
+
+ defer.returnValue((200, content))
class TransportReceivedHandler(object):
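
make_join and send_join give the join handshake its transport legs: a GET fetches a templated membership PDU from the remote server, then a PUT sends back the filled-in, signed PDU and receives the room state. A sketch of the path construction; the PREFIX value is an assumption, it is not shown in this hunk:

    PREFIX = "/_matrix/federation/v1"  # assumed value

    def make_join_path(context, user_id):
        # Step 1 of the handshake: GET a templated membership PDU.
        return PREFIX + "/make_join/%s/%s" % (context, user_id)

    def send_join_path(context, origin, pdu_id):
        # Step 2: PUT the signed PDU; note the context/origin/pdu_id
        # ordering used by send_join above.
        return PREFIX + "/send_join/%s/%s/%s" % (context, origin, pdu_id)

    print(make_join_path("!room:example.com", "@alice:example.com"))
    print(send_join_path("!room:example.com", "example.com", "abc123"))
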
diff --git a/synapse/federation/units.py b/synapse/federation/units.py
index b2fb964180..adc3385644 100644
--- a/synapse/federation/units.py
+++ b/synapse/federation/units.py
@@ -18,6 +18,7 @@ server protocol.
"""
from synapse.util.jsonobject import JsonEncodedObject
+from syutil.base64util import encode_base64
import logging
import json
@@ -63,9 +64,10 @@ class Pdu(JsonEncodedObject):
"depth",
"content",
"outlier",
+ "hashes",
+ "signatures",
"is_state", # Below this are keys valid only for State Pdus.
"state_key",
- "power_level",
"prev_state_id",
"prev_state_origin",
"required_power_level",
@@ -91,7 +93,7 @@ class Pdu(JsonEncodedObject):
# just leaving it as a dict. (OR DO WE?!)
def __init__(self, destinations=[], is_state=False, prev_pdus=[],
- outlier=False, **kwargs):
+ outlier=False, hashes={}, signatures={}, **kwargs):
if is_state:
for required_key in ["state_key"]:
if required_key not in kwargs:
@@ -99,9 +101,11 @@ class Pdu(JsonEncodedObject):
super(Pdu, self).__init__(
destinations=destinations,
- is_state=is_state,
+ is_state=bool(is_state),
prev_pdus=prev_pdus,
outlier=outlier,
+ hashes=hashes,
+ signatures=signatures,
**kwargs
)
@@ -120,6 +124,10 @@ class Pdu(JsonEncodedObject):
d = copy.copy(pdu_tuple.pdu_entry._asdict())
d["origin_server_ts"] = d.pop("ts")
+ for k in d.keys():
+ if d[k] is None:
+ del d[k]
+
d["content"] = json.loads(d["content_json"])
del d["content_json"]
@@ -127,8 +135,28 @@ class Pdu(JsonEncodedObject):
if "unrecognized_keys" in d and d["unrecognized_keys"]:
args.update(json.loads(d["unrecognized_keys"]))
+ hashes = {
+ alg: encode_base64(hsh)
+ for alg, hsh in pdu_tuple.hashes.items()
+ }
+
+ signatures = {
+ kid: encode_base64(sig)
+ for kid, sig in pdu_tuple.signatures.items()
+ }
+
+ prev_pdus = []
+ for prev_pdu in pdu_tuple.prev_pdu_list:
+ prev_hashes = pdu_tuple.edge_hashes.get(prev_pdu, {})
+ prev_hashes = {
+ alg: encode_base64(hsh) for alg, hsh in prev_hashes.items()
+ }
+ prev_pdus.append((prev_pdu[0], prev_pdu[1], prev_hashes))
+
return Pdu(
- prev_pdus=pdu_tuple.prev_pdu_list,
+ prev_pdus=prev_pdus,
+ hashes=hashes,
+ signatures=signatures,
**args
)
else:
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index de4d23bbb3..787a01efc5 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -16,6 +16,8 @@
from twisted.internet import defer
from synapse.api.errors import LimitExceededError
+from synapse.util.async import run_on_reactor
+
class BaseHandler(object):
def __init__(self, hs):
@@ -44,9 +46,19 @@ class BaseHandler(object):
@defer.inlineCallbacks
def _on_new_room_event(self, event, snapshot, extra_destinations=[],
- extra_users=[]):
+ extra_users=[], suppress_auth=False):
+ yield run_on_reactor()
+
snapshot.fill_out_prev_events(event)
+ yield self.state_handler.annotate_state_groups(event)
+
+ if not suppress_auth:
+ yield self.auth.check(event, snapshot, raises=True)
+
+ if hasattr(event, "state_key"):
+ yield self.state_handler.handle_new_event(event, snapshot)
+
yield self.store.persist_event(event)
destinations = set(extra_destinations)
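
_on_new_room_event is now the single funnel for local events: fill out prev events, annotate state, auth against that state unless suppressed, do state bookkeeping for state events, then persist and federate. The ordering as a compact sketch; "deps" stands in for the handler's attributes, and deferreds and error handling are elided:

    def on_new_room_event(event, snapshot, deps, suppress_auth=False):
        snapshot.fill_out_prev_events(event)       # 1. hang off latest events
        deps.state.annotate_state_groups(event)    # 2. resolve current state
        if not suppress_auth:
            deps.auth.check(event, snapshot, raises=True)  # 3. auth vs state
        if hasattr(event, "state_key"):
            deps.state.handle_new_event(event, snapshot)   # 4. bookkeeping
        deps.store.persist_event(event)            # 5. persist, then federate
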
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index a56830d520..6e897e915d 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -152,5 +152,6 @@ class DirectoryHandler(BaseHandler):
user_id=user_id,
)
- yield self.state_handler.handle_new_event(event, snapshot)
- yield self._on_new_room_event(event, snapshot, extra_users=[user_id])
+ yield self._on_new_room_event(
+ event, snapshot, extra_users=[user_id], suppress_auth=True
+ )
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index f52591d2a3..1daeee833b 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -22,6 +22,8 @@ from synapse.api.constants import Membership
from synapse.util.logutils import log_function
from synapse.federation.pdu_codec import PduCodec
from synapse.api.errors import SynapseError
+from synapse.util.async import run_on_reactor
+from synapse.types import EventID
from twisted.internet import defer, reactor
@@ -62,6 +64,9 @@ class FederationHandler(BaseHandler):
self.pdu_codec = PduCodec(hs)
+ # When joining a room we need to queue up any incoming events for that room
+ self.room_queues = {}
+
@log_function
@defer.inlineCallbacks
def handle_new_event(self, event, snapshot):
@@ -78,6 +83,8 @@ class FederationHandler(BaseHandler):
processing.
"""
+ yield run_on_reactor()
+
pdu = self.pdu_codec.pdu_from_event(event)
if not hasattr(pdu, "destinations") or not pdu.destinations:
@@ -87,97 +94,82 @@ class FederationHandler(BaseHandler):
@log_function
@defer.inlineCallbacks
- def on_receive_pdu(self, pdu, backfilled):
+ def on_receive_pdu(self, pdu, backfilled, state=None):
""" Called by the ReplicationLayer when we have a new pdu. We need to
- do auth checks and put it throught the StateHandler.
+ do auth checks and put it through the StateHandler.
"""
event = self.pdu_codec.event_from_pdu(pdu)
logger.debug("Got event: %s", event.event_id)
- with (yield self.lock_manager.lock(pdu.context)):
- if event.is_state and not backfilled:
- is_new_state = yield self.state_handler.handle_new_state(
- pdu
- )
- else:
- is_new_state = False
+ if event.room_id in self.room_queues:
+ self.room_queues[event.room_id].append(pdu)
+ return
+
+ logger.debug("Processing event: %s", event.event_id)
+
+ if state:
+ state = [self.pdu_codec.event_from_pdu(p) for p in state]
+
+ is_new_state = yield self.state_handler.annotate_state_groups(
+ event,
+ old_state=state
+ )
+
+ logger.debug("Event: %s", event)
+
+ if not backfilled:
+ yield self.auth.check(event, None, raises=True)
+
+ is_new_state = is_new_state and not backfilled
+
# TODO: Implement something in federation that allows us to
# respond to PDU.
- target_is_mine = False
- if hasattr(event, "target_host"):
- target_is_mine = event.target_host == self.hs.hostname
-
- if event.type == InviteJoinEvent.TYPE:
- if not target_is_mine:
- logger.debug("Ignoring invite/join event %s", event)
- return
-
- # If we receive an invite/join event then we need to join the
- # sender to the given room.
- # TODO: We should probably auth this or some such
- content = event.content
- content.update({"membership": Membership.JOIN})
- new_event = self.event_factory.create_event(
- etype=RoomMemberEvent.TYPE,
- state_key=event.user_id,
- room_id=event.room_id,
- user_id=event.user_id,
- membership=Membership.JOIN,
- content=content
+ with (yield self.room_lock.lock(event.room_id)):
+ yield self.store.persist_event(
+ event,
+ backfilled,
+ is_new_state=is_new_state
)
- yield self.hs.get_handlers().room_member_handler.change_membership(
- new_event,
- do_auth=False,
- )
+ room = yield self.store.get_room(event.room_id)
- else:
- with (yield self.room_lock.lock(event.room_id)):
- yield self.store.persist_event(
- event,
- backfilled,
- is_new_state=is_new_state
+ if not room:
+ # Huh, let's try and get the current state
+ try:
+ yield self.replication_layer.get_state_for_context(
+ event.origin, event.room_id, pdu.pdu_id, pdu.origin,
)
- room = yield self.store.get_room(event.room_id)
-
- if not room:
- # Huh, let's try and get the current state
- try:
- yield self.replication_layer.get_state_for_context(
- event.origin, event.room_id
- )
-
- hosts = yield self.store.get_joined_hosts_for_room(
- event.room_id
- )
- if self.hs.hostname in hosts:
- try:
- yield self.store.store_room(
- room_id=event.room_id,
- room_creator_user_id="",
- is_public=False,
- )
- except:
- pass
- except:
- logger.exception(
- "Failed to get current state for room %s",
- event.room_id
- )
-
- if not backfilled:
- extra_users = []
- if event.type == RoomMemberEvent.TYPE:
- target_user_id = event.state_key
- target_user = self.hs.parse_userid(target_user_id)
- extra_users.append(target_user)
-
- yield self.notifier.on_new_room_event(
- event, extra_users=extra_users
+ hosts = yield self.store.get_joined_hosts_for_room(
+ event.room_id
)
+ if self.hs.hostname in hosts:
+ try:
+ yield self.store.store_room(
+ room_id=event.room_id,
+ room_creator_user_id="",
+ is_public=False,
+ )
+ except:
+ pass
+ except:
+ logger.exception(
+ "Failed to get current state for room %s",
+ event.room_id
+ )
+
+ if not backfilled:
+ extra_users = []
+ if event.type == RoomMemberEvent.TYPE:
+ target_user_id = event.state_key
+ target_user = self.hs.parse_userid(target_user_id)
+ extra_users.append(target_user)
+
+ yield self.notifier.on_new_room_event(
+ event, extra_users=extra_users
+ )
if event.type == RoomMemberEvent.TYPE:
if event.membership == Membership.JOIN:
@@ -195,7 +187,12 @@ class FederationHandler(BaseHandler):
for pdu in pdus:
event = self.pdu_codec.event_from_pdu(pdu)
+
+ # FIXME (erikj): Not sure this actually works :/
+ yield self.state_handler.annotate_state_groups(event)
+
events.append(event)
+
yield self.store.persist_event(event, backfilled=True)
defer.returnValue(events)
@@ -203,62 +200,195 @@ class FederationHandler(BaseHandler):
@log_function
@defer.inlineCallbacks
def do_invite_join(self, target_host, room_id, joinee, content, snapshot):
-
hosts = yield self.store.get_joined_hosts_for_room(room_id)
if self.hs.hostname in hosts:
# We are already in the room.
logger.debug("We're already in the room apparently")
defer.returnValue(False)
- # First get current state to see if we are already joined.
+ pdu = yield self.replication_layer.make_join(
+ target_host,
+ room_id,
+ joinee
+ )
+
+ logger.debug("Got response to make_join: %s", pdu)
+
+ event = self.pdu_codec.event_from_pdu(pdu)
+
+ # We should assert some things.
+ assert(event.type == RoomMemberEvent.TYPE)
+ assert(event.user_id == joinee)
+ assert(event.state_key == joinee)
+ assert(event.room_id == room_id)
+
+ event.outlier = False
+
+ self.room_queues[room_id] = []
+
try:
- yield self.replication_layer.get_state_for_context(
- target_host, room_id
+ event.event_id = self.event_factory.create_event_id()
+ event.content = content
+
+ state = yield self.replication_layer.send_join(
+ target_host,
+ self.pdu_codec.pdu_from_event(event)
)
- hosts = yield self.store.get_joined_hosts_for_room(room_id)
- if self.hs.hostname in hosts:
- # Oh, we were actually in the room already.
- logger.debug("We're already in the room apparently")
- defer.returnValue(False)
- except Exception:
- logger.exception("Failed to get current state")
-
- new_event = self.event_factory.create_event(
- etype=InviteJoinEvent.TYPE,
- target_host=target_host,
- room_id=room_id,
- user_id=joinee,
- content=content
- )
+ state = [self.pdu_codec.event_from_pdu(p) for p in state]
- new_event.destinations = [target_host]
+ logger.debug("do_invite_join state: %s", state)
- snapshot.fill_out_prev_events(new_event)
- yield self.handle_new_event(new_event, snapshot)
+ is_new_state = yield self.state_handler.annotate_state_groups(
+ event,
+ old_state=state
+ )
- # TODO (erikj): Time out here.
- d = defer.Deferred()
- self.waiting_for_join_list.setdefault((joinee, room_id), []).append(d)
- reactor.callLater(10, d.cancel)
+ logger.debug("do_invite_join event: %s", event)
- try:
- yield d
- except defer.CancelledError:
- raise SynapseError(500, "Unable to join remote room")
+ try:
+ yield self.store.store_room(
+ room_id=room_id,
+ room_creator_user_id="",
+ is_public=False
+ )
+ except:
+ # FIXME
+ pass
- try:
- yield self.store.store_room(
- room_id=room_id,
- room_creator_user_id="",
- is_public=False
+ for e in state:
+ # FIXME: Auth these.
+ e.outlier = True
+
+ yield self.state_handler.annotate_state_groups(
+ e,
+ )
+
+ yield self.store.persist_event(
+ e,
+ backfilled=False,
+ is_new_state=False
+ )
+
+ yield self.store.persist_event(
+ event,
+ backfilled=False,
+ is_new_state=is_new_state
)
- except:
- pass
+ finally:
+ room_queue = self.room_queues[room_id]
+ del self.room_queues[room_id]
+ for p in room_queue:
+ try:
+ yield self.on_receive_pdu(p, backfilled=False)
+ except:
+ pass
defer.returnValue(True)
+ @defer.inlineCallbacks
+ @log_function
+ def on_make_join_request(self, context, user_id):
+ event = self.event_factory.create_event(
+ etype=RoomMemberEvent.TYPE,
+ content={"membership": Membership.JOIN},
+ room_id=context,
+ user_id=user_id,
+ state_key=user_id,
+ )
+
+ snapshot = yield self.store.snapshot_room(
+ event.room_id, event.user_id,
+ )
+ snapshot.fill_out_prev_events(event)
+
+ yield self.state_handler.annotate_state_groups(event)
+ yield self.auth.check(event, None, raises=True)
+
+ pdu = self.pdu_codec.pdu_from_event(event)
+
+ defer.returnValue(pdu)
+
+ @defer.inlineCallbacks
+ @log_function
+ def on_send_join_request(self, origin, pdu):
+ event = self.pdu_codec.event_from_pdu(pdu)
+
+ event.outlier = False
+
+ is_new_state = yield self.state_handler.annotate_state_groups(event)
+ yield self.auth.check(event, None, raises=True)
+
+ # FIXME (erikj): All this is duplicated above :(
+
+ yield self.store.persist_event(
+ event,
+ backfilled=False,
+ is_new_state=is_new_state
+ )
+
+ extra_users = []
+ if event.type == RoomMemberEvent.TYPE:
+ target_user_id = event.state_key
+ target_user = self.hs.parse_userid(target_user_id)
+ extra_users.append(target_user)
+
+ yield self.notifier.on_new_room_event(
+ event, extra_users=extra_users
+ )
+
+ if event.type == RoomMemberEvent.TYPE:
+ if event.membership == Membership.JOIN:
+ user = self.hs.parse_userid(event.state_key)
+ self.distributor.fire(
+ "user_joined_room", user=user, room_id=event.room_id
+ )
+
+ new_pdu = self.pdu_codec.pdu_from_event(event)
+ new_pdu.destinations = yield self.store.get_joined_hosts_for_room(
+ event.room_id
+ )
+
+ yield self.replication_layer.send_pdu(new_pdu)
+
+ defer.returnValue([
+ self.pdu_codec.pdu_from_event(e)
+ for e in event.state_events.values()
+ ])
+
+ @defer.inlineCallbacks
+ def get_state_for_pdu(self, pdu_id, pdu_origin):
+ yield run_on_reactor()
+
+ event_id = EventID.create(pdu_id, pdu_origin, self.hs).to_string()
+
+ state_groups = yield self.store.get_state_groups(
+ [event_id]
+ )
+
+ if state_groups:
+ results = {
+ (e.type, e.state_key): e for e in state_groups[0].state
+ }
+
+ event = yield self.store.get_event(event_id)
+ if hasattr(event, "state_key"):
+ # Get previous state
+ if hasattr(event, "prev_state") and event.prev_state:
+ prev_event = yield self.store.get_event(event.prev_state)
+ results[(event.type, event.state_key)] = prev_event
+ else:
+ del results[(event.type, event.state_key)]
+
+ defer.returnValue(
+ [
+ self.pdu_codec.pdu_from_event(s)
+ for s in results.values()
+ ]
+ )
+ else:
+ defer.returnValue([])
@log_function
def _on_user_joined(self, user, room_id):
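
room_queues is what keeps the new join dance sane: while send_join is in flight the server cannot resolve state for the room, so incoming PDUs are parked and replayed once the join has been persisted, as in do_invite_join's finally block. The same idea in miniature:

    room_queues = {}

    def on_receive(room_id, pdu, process):
        # While a join for room_id is in flight, park incoming PDUs.
        if room_id in room_queues:
            room_queues[room_id].append(pdu)
            return
        process(pdu)

    def do_join(room_id, join, process):
        room_queues[room_id] = []
        try:
            join()
        finally:
            # Replay whatever arrived mid-join.
            for pdu in room_queues.pop(room_id):
                process(pdu)

    log = []
    do_join(
        "!r:example.com",
        lambda: on_receive("!r:example.com", "pdu-during-join", log.append),
        log.append,
    )
    print(log)  # ['pdu-during-join']
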
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 72894869ea..c6f6ab14d1 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -83,10 +83,9 @@ class MessageHandler(BaseHandler):
snapshot = yield self.store.snapshot_room(event.room_id, event.user_id)
- if not suppress_auth:
- yield self.auth.check(event, snapshot, raises=True)
-
- yield self._on_new_room_event(event, snapshot)
+ yield self._on_new_room_event(
+ event, snapshot, suppress_auth=suppress_auth
+ )
self.hs.get_handlers().presence_handler.bump_presence_active_time(
user
@@ -149,10 +148,6 @@ class MessageHandler(BaseHandler):
state_key=event.state_key,
)
- yield self.auth.check(event, snapshot, raises=True)
-
- yield self.state_handler.handle_new_event(event, snapshot)
-
yield self._on_new_room_event(event, snapshot)
@defer.inlineCallbacks
@@ -201,7 +196,7 @@ class MessageHandler(BaseHandler):
raise RoomError(
403, "Member does not meet private room rules.")
- data = yield self.store.get_current_state(
+ data = yield self.state_handler.get_current_state(
room_id, event_type, state_key
)
defer.returnValue(data)
@@ -221,8 +216,6 @@ class MessageHandler(BaseHandler):
def send_feedback(self, event):
snapshot = yield self.store.snapshot_room(event.room_id, event.user_id)
- yield self.auth.check(event, snapshot, raises=True)
-
# store message in db
yield self._on_new_room_event(event, snapshot)
@@ -239,7 +232,7 @@ class MessageHandler(BaseHandler):
yield self.auth.check_joined_room(room_id, user_id)
# TODO: This is duplicating logic from snapshot_all_rooms
- current_state = yield self.store.get_current_state(room_id)
+ current_state = yield self.state_handler.get_current_state(room_id)
defer.returnValue([self.hs.serialize_event(c) for c in current_state])
@defer.inlineCallbacks
@@ -316,7 +309,7 @@ class MessageHandler(BaseHandler):
"end": end_token.to_string(),
}
- current_state = yield self.store.get_current_state(
+ current_state = yield self.state_handler.get_current_state(
event.room_id
)
d["state"] = [self.hs.serialize_event(c) for c in current_state]
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index dab9b03f04..4cd0a06093 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -218,5 +218,6 @@ class ProfileHandler(BaseHandler):
user_id=j.state_key,
)
- yield self.state_handler.handle_new_event(new_event, snapshot)
- yield self._on_new_room_event(new_event, snapshot)
+ yield self._on_new_room_event(
+ new_event, snapshot, suppress_auth=True
+ )
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 81ce1a5907..ffc0892f1a 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -129,8 +129,9 @@ class RoomCreationHandler(BaseHandler):
logger.debug("Event: %s", event)
- yield self.state_handler.handle_new_event(event, snapshot)
- yield self._on_new_room_event(event, snapshot, extra_users=[user])
+ yield self._on_new_room_event(
+ event, snapshot, extra_users=[user], suppress_auth=True
+ )
for event in creation_events:
yield handle_event(event)
@@ -391,8 +392,6 @@ class RoomMemberHandler(BaseHandler):
yield self._do_join(event, snapshot, do_auth=do_auth)
else:
# This is not a JOIN, so we can handle it normally.
- if do_auth:
- yield self.auth.check(event, snapshot, raises=True)
# If we're banning someone, set a req power level
if event.membership == Membership.BAN:
@@ -414,6 +413,7 @@ class RoomMemberHandler(BaseHandler):
event,
membership=event.content["membership"],
snapshot=snapshot,
+ do_auth=do_auth,
)
defer.returnValue({"room_id": room_id})
@@ -502,14 +502,11 @@ class RoomMemberHandler(BaseHandler):
if not have_joined:
logger.debug("Doing normal join")
- if do_auth:
- yield self.auth.check(event, snapshot, raises=True)
-
- yield self.state_handler.handle_new_event(event, snapshot)
yield self._do_local_membership_update(
event,
membership=event.content["membership"],
snapshot=snapshot,
+ do_auth=do_auth,
)
user = self.hs.parse_userid(event.user_id)
@@ -553,7 +550,8 @@ class RoomMemberHandler(BaseHandler):
defer.returnValue([r.room_id for r in rooms])
- def _do_local_membership_update(self, event, membership, snapshot):
+ def _do_local_membership_update(self, event, membership, snapshot,
+ do_auth):
destinations = []
# If we're inviting someone, then we should also send it to that
@@ -570,9 +568,10 @@ class RoomMemberHandler(BaseHandler):
return self._on_new_room_event(
event, snapshot, extra_destinations=destinations,
- extra_users=[target_user]
+ extra_users=[target_user], suppress_auth=(not do_auth),
)
+
class RoomListHandler(BaseHandler):
@defer.inlineCallbacks
diff --git a/synapse/rest/base.py b/synapse/rest/base.py
index 2e8e3fa7d4..dc784c1527 100644
--- a/synapse/rest/base.py
+++ b/synapse/rest/base.py
@@ -18,6 +18,11 @@ from synapse.api.urls import CLIENT_PREFIX
from synapse.rest.transactions import HttpTransactionStore
import re
+import logging
+
+
+logger = logging.getLogger(__name__)
+
def client_path_pattern(path_regex):
"""Creates a regex compiled client path with the correct client path
diff --git a/synapse/rest/events.py b/synapse/rest/events.py
index 097195d7cc..92ff5e5ca7 100644
--- a/synapse/rest/events.py
+++ b/synapse/rest/events.py
@@ -20,6 +20,12 @@ from synapse.api.errors import SynapseError
from synapse.streams.config import PaginationConfig
from synapse.rest.base import RestServlet, client_path_pattern
+import logging
+
+
+logger = logging.getLogger(__name__)
+
+
class EventStreamRestServlet(RestServlet):
PATTERN = client_path_pattern("/events$")
@@ -29,18 +35,22 @@ class EventStreamRestServlet(RestServlet):
@defer.inlineCallbacks
def on_GET(self, request):
auth_user = yield self.auth.get_user_by_req(request)
-
- handler = self.handlers.event_stream_handler
- pagin_config = PaginationConfig.from_request(request)
- timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS
- if "timeout" in request.args:
- try:
- timeout = int(request.args["timeout"][0])
- except ValueError:
- raise SynapseError(400, "timeout must be in milliseconds.")
-
- chunk = yield handler.get_stream(auth_user.to_string(), pagin_config,
- timeout=timeout)
+ try:
+ handler = self.handlers.event_stream_handler
+ pagin_config = PaginationConfig.from_request(request)
+ timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS
+ if "timeout" in request.args:
+ try:
+ timeout = int(request.args["timeout"][0])
+ except ValueError:
+ raise SynapseError(400, "timeout must be in milliseconds.")
+
+ chunk = yield handler.get_stream(
+ auth_user.to_string(), pagin_config, timeout=timeout
+ )
+ except:
+ logger.exception("Event stream failed")
+ raise
defer.returnValue((200, chunk))
diff --git a/synapse/server.py b/synapse/server.py
index a4d2d4aba5..d770b20b19 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -28,7 +28,7 @@ from synapse.handlers import Handlers
from synapse.rest import RestServletFactory
from synapse.state import StateHandler
from synapse.storage import DataStore
-from synapse.types import UserID, RoomAlias, RoomID
+from synapse.types import UserID, RoomAlias, RoomID, EventID
from synapse.util import Clock
from synapse.util.distributor import Distributor
from synapse.util.lockutils import LockManager
@@ -143,6 +143,11 @@ class BaseHomeServer(object):
object."""
return RoomID.from_string(s, hs=self)
+ def parse_eventid(self, s):
+ """Parse the string given by 's' as a Event ID and return a EventID
+ object."""
+ return EventID.from_string(s, hs=self)
+
def serialize_event(self, e):
return serialize_event(self, e)
diff --git a/synapse/state.py b/synapse/state.py
index 9db84c9b5c..414701b272 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -16,11 +16,14 @@
from twisted.internet import defer
-from synapse.federation.pdu_codec import encode_event_id, decode_event_id
from synapse.util.logutils import log_function
+from synapse.util.async import run_on_reactor
+
+from synapse.types import EventID
from collections import namedtuple
+import copy
import logging
import hashlib
@@ -35,13 +38,14 @@ KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key"))
class StateHandler(object):
- """ Repsonsible for doing state conflict resolution.
+ """ Responsible for doing state conflict resolution.
"""
def __init__(self, hs):
self.store = hs.get_datastore()
self._replication = hs.get_replication_layer()
self.server_name = hs.hostname
+ self.hs = hs
@defer.inlineCallbacks
@log_function
@@ -50,7 +54,7 @@ class StateHandler(object):
to update the state and b) works out what the prev_state should be.
Returns:
- Deferred: Resolved with a boolean indicating if we succesfully
+ Deferred: Resolved with a boolean indicating if we successfully
updated the state.
Raised:
@@ -71,23 +75,22 @@ class StateHandler(object):
# (w.r.t. to power levels)
snapshot.fill_out_prev_events(event)
-
- event.prev_events = [
- e for e in event.prev_events if e != event.event_id
- ]
+ yield self.annotate_state_groups(event)
current_state = snapshot.prev_state_pdu
if current_state:
- event.prev_state = encode_event_id(
- current_state.pdu_id, current_state.origin
- )
+ event.prev_state = EventID.create(
+ current_state.pdu_id, current_state.origin, self.hs
+ ).to_string()
# TODO check current_state to see if the min power level is less
# than the power level of the user
# power_level = self._get_power_level_for_event(event)
- pdu_id, origin = decode_event_id(event.event_id, self.server_name)
+ e_id = self.hs.parse_eventid(event.event_id)
+ pdu_id = e_id.localpart
+ origin = e_id.domain
yield self.store.update_current_state(
pdu_id=pdu_id,
@@ -128,6 +131,135 @@ class StateHandler(object):
defer.returnValue(is_new)
+ @defer.inlineCallbacks
+ @log_function
+ def annotate_state_groups(self, event, old_state=None):
+ yield run_on_reactor()
+
+ if old_state:
+ event.state_group = None
+ event.old_state_events = {(s.type, s.state_key): s for s in old_state}
+ event.state_events = dict(event.old_state_events)
+
+ if hasattr(event, "state_key"):
+ event.state_events[(event.type, event.state_key)] = event
+
+ defer.returnValue(False)
+ return
+
+ if hasattr(event, "outlier") and event.outlier:
+ event.state_group = None
+ event.old_state_events = None
+ event.state_events = {}
+ defer.returnValue(False)
+ return
+
+ new_state = yield self.resolve_state_groups(
+ [e for e, _ in event.prev_events]
+ )
+
+ event.old_state_events = copy.deepcopy(new_state)
+
+ if hasattr(event, "state_key"):
+ new_state[(event.type, event.state_key)] = event
+
+ event.state_group = None
+ event.state_events = new_state
+
+ defer.returnValue(hasattr(event, "state_key"))
+
+ @defer.inlineCallbacks
+ def get_current_state(self, room_id, event_type=None, state_key=""):
+ events = yield self.store.get_latest_events_in_room(room_id)
+
+ event_ids = [
+ e_id
+ for e_id, _, _ in events
+ ]
+
+ res = yield self.resolve_state_groups(event_ids)
+
+ if event_type:
+ defer.returnValue(res.get((event_type, state_key)))
+ return
+
+ defer.returnValue(res.values())
+
+ @defer.inlineCallbacks
+ @log_function
+ def resolve_state_groups(self, event_ids):
+ state_groups = yield self.store.get_state_groups(
+ event_ids
+ )
+
+ state = {}
+ for group in state_groups:
+ for s in group.state:
+ state.setdefault(
+ (s.type, s.state_key),
+ {}
+ )[s.event_id] = s
+
+ unconflicted_state = {
+ k: v.values()[0] for k, v in state.items()
+ if len(v.values()) == 1
+ }
+
+ conflicted_state = {
+ k: v.values()
+ for k, v in state.items()
+ if len(v.values()) > 1
+ }
+
+ try:
+ new_state = {}
+ new_state.update(unconflicted_state)
+ for key, events in conflicted_state.items():
+ new_state[key] = yield self._resolve_state_events(events)
+ except:
+ logger.exception("Failed to resolve state")
+ raise
+
+ defer.returnValue(new_state)
+
+ @defer.inlineCallbacks
+ @log_function
+ def _resolve_state_events(self, events):
+ curr_events = events
+
+ new_powers_deferreds = []
+ for e in curr_events:
+ new_powers_deferreds.append(
+ self.store.get_power_level(e.room_id, e.user_id)
+ )
+
+ new_powers = yield defer.gatherResults(
+ new_powers_deferreds,
+ consumeErrors=True
+ )
+
+ max_power = max([int(p) for p in new_powers])
+
+ curr_events = [
+ z[0] for z in zip(curr_events, new_powers)
+ if int(z[1]) == max_power
+ ]
+
+ if not curr_events:
+ raise RuntimeError("Max didn't get a max?")
+ elif len(curr_events) == 1:
+ defer.returnValue(curr_events[0])
+
+        # TODO: For now, just deterministically pick one by sorting on a
+        # hash of the events.
+ defer.returnValue(
+ sorted(
+ curr_events,
+ key=lambda e: hashlib.sha1(
+ e.event_id + e.user_id + e.room_id + e.type
+ ).hexdigest()
+ )[0]
+ )
+
def _get_power_level_for_event(self, event):
# return self._persistence.get_power_level_for_user(event.room_id,
# event.sender)
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 4e9291fdff..d75c366834 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -40,6 +40,15 @@ from .stream import StreamStore
from .pdu import StatePduStore, PduStore, PdusTable
from .transactions import TransactionStore
from .keys import KeyStore
+from .event_federation import EventFederationStore
+
+from .state import StateStore
+from .signatures import SignatureStore
+
+from syutil.base64util import decode_base64
+
+from synapse.crypto.event_signing import compute_pdu_event_reference_hash
+
import json
import logging
@@ -59,6 +68,10 @@ SCHEMAS = [
"room_aliases",
"keys",
"redactions",
+ "state",
+ "signatures",
+ "event_edges",
+ "event_signatures",
]
@@ -73,10 +86,12 @@ class _RollbackButIsFineException(Exception):
"""
pass
+
class DataStore(RoomMemberStore, RoomStore,
RegistrationStore, StreamStore, ProfileStore, FeedbackStore,
PresenceStore, PduStore, StatePduStore, TransactionStore,
- DirectoryStore, KeyStore):
+ DirectoryStore, KeyStore, StateStore, SignatureStore,
+ EventFederationStore, ):
def __init__(self, hs):
super(DataStore, self).__init__(hs)
@@ -99,6 +114,7 @@ class DataStore(RoomMemberStore, RoomStore,
try:
yield self.runInteraction(
+ "persist_event",
self._persist_pdu_event_txn,
pdu=pdu,
event=event,
@@ -119,7 +135,8 @@ class DataStore(RoomMemberStore, RoomStore,
"type",
"room_id",
"content",
- "unrecognized_keys"
+ "unrecognized_keys",
+ "depth",
],
allow_none=allow_none,
)
@@ -144,6 +161,8 @@ class DataStore(RoomMemberStore, RoomStore,
def _persist_event_pdu_txn(self, txn, pdu):
cols = dict(pdu.__dict__)
unrec_keys = dict(pdu.unrecognized_keys)
+ del cols["hashes"]
+ del cols["signatures"]
del cols["content"]
del cols["prev_pdus"]
cols["content_json"] = json.dumps(pdu.content)
@@ -159,6 +178,33 @@ class DataStore(RoomMemberStore, RoomStore,
logger.debug("Persisting: %s", repr(cols))
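+        # Persist the content hashes, origin signatures, prev-pdu hashes and
+        # the reference hash alongside the PDU itself.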
+ for hash_alg, hash_base64 in pdu.hashes.items():
+ hash_bytes = decode_base64(hash_base64)
+ self._store_pdu_content_hash_txn(
+ txn, pdu.pdu_id, pdu.origin, hash_alg, hash_bytes,
+ )
+
+ signatures = pdu.signatures.get(pdu.origin, {})
+
+ for key_id, signature_base64 in signatures.items():
+ signature_bytes = decode_base64(signature_base64)
+ self._store_pdu_origin_signature_txn(
+ txn, pdu.pdu_id, pdu.origin, key_id, signature_bytes,
+ )
+
+ for prev_pdu_id, prev_origin, prev_hashes in pdu.prev_pdus:
+ for alg, hash_base64 in prev_hashes.items():
+ hash_bytes = decode_base64(hash_base64)
+ self._store_prev_pdu_hash_txn(
+ txn, pdu.pdu_id, pdu.origin, prev_pdu_id, prev_origin, alg,
+ hash_bytes
+ )
+
+ (ref_alg, ref_hash_bytes) = compute_pdu_event_reference_hash(pdu)
+ self._store_pdu_reference_hash_txn(
+ txn, pdu.pdu_id, pdu.origin, ref_alg, ref_hash_bytes
+ )
+
if pdu.is_state:
self._persist_state_txn(txn, pdu.prev_pdus, cols)
else:
@@ -190,6 +236,10 @@ class DataStore(RoomMemberStore, RoomStore,
elif event.type == RoomRedactionEvent.TYPE:
self._store_redaction(txn, event)
+ outlier = False
+ if hasattr(event, "outlier"):
+ outlier = event.outlier
+
vals = {
"topological_ordering": event.depth,
"event_id": event.event_id,
@@ -197,25 +247,30 @@ class DataStore(RoomMemberStore, RoomStore,
"room_id": event.room_id,
"content": json.dumps(event.content),
"processed": True,
+ "outlier": outlier,
+ "depth": event.depth,
}
if stream_ordering is not None:
vals["stream_ordering"] = stream_ordering
- if hasattr(event, "outlier"):
- vals["outlier"] = event.outlier
- else:
- vals["outlier"] = False
-
unrec = {
k: v
for k, v in event.get_full_dict().items()
- if k not in vals.keys() and k not in ["redacted", "redacted_because"]
+ if k not in vals.keys() and k not in [
+ "redacted", "redacted_because", "signatures", "hashes",
+ "prev_events",
+ ]
}
vals["unrecognized_keys"] = json.dumps(unrec)
try:
- self._simple_insert_txn(txn, "events", vals)
+ self._simple_insert_txn(
+ txn,
+ "events",
+ vals,
+ or_replace=(not outlier),
+ )
except:
logger.warn(
"Failed to persist, probably duplicate: %s",
@@ -224,6 +279,16 @@ class DataStore(RoomMemberStore, RoomStore,
)
raise _RollbackButIsFineException("_persist_event")
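+        # Update the event graph and forward/backward extremities for this
+        # event, and record which state group it belongs to.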
+ self._handle_prev_events(
+ txn,
+ outlier=outlier,
+ event_id=event.event_id,
+ prev_events=event.prev_events,
+ room_id=event.room_id,
+ )
+
+ self._store_state_groups_txn(txn, event)
+
is_state = hasattr(event, "state_key") and event.state_key is not None
if is_new_state and is_state:
vals = {
@@ -249,6 +314,30 @@ class DataStore(RoomMemberStore, RoomStore,
}
)
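+        # Store the signatures the origin attached to the event, and the
+        # hashes it asserts for each of its prev_events.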
+ if hasattr(event, "signatures"):
+ signatures = event.signatures.get(event.origin, {})
+
+ for key_id, signature_base64 in signatures.items():
+ signature_bytes = decode_base64(signature_base64)
+ self._store_event_origin_signature_txn(
+ txn, event.event_id, event.origin, key_id, signature_bytes,
+ )
+
+ for prev_event_id, prev_hashes in event.prev_events:
+ for alg, hash_base64 in prev_hashes.items():
+ hash_bytes = decode_base64(hash_base64)
+ self._store_prev_event_hash_txn(
+ txn, event.event_id, prev_event_id, alg, hash_bytes
+ )
+
+ # TODO
+ # (ref_alg, ref_hash_bytes) = compute_pdu_event_reference_hash(pdu)
+ # self._store_event_reference_hash_txn(
+ # txn, event.event_id, ref_alg, ref_hash_bytes
+ # )
+
+ self._update_min_depth_for_room_txn(txn, event.room_id, event.depth)
+
def _store_redaction(self, txn, event):
txn.execute(
"INSERT OR IGNORE INTO redactions "
@@ -331,9 +420,8 @@ class DataStore(RoomMemberStore, RoomStore,
"""
def _snapshot(txn):
membership_state = self._get_room_member(txn, user_id, room_id)
- prev_pdus = self._get_latest_pdus_in_context(
- txn, room_id
- )
+ prev_events = self._get_latest_events_in_room(txn, room_id)
+
if state_type is not None and state_key is not None:
prev_state_pdu = self._get_current_state_pdu(
txn, room_id, state_type, state_key
@@ -345,14 +433,14 @@ class DataStore(RoomMemberStore, RoomStore,
store=self,
room_id=room_id,
user_id=user_id,
- prev_pdus=prev_pdus,
+ prev_events=prev_events,
membership_state=membership_state,
state_type=state_type,
state_key=state_key,
prev_state_pdu=prev_state_pdu,
)
- return self.runInteraction(_snapshot)
+ return self.runInteraction("snapshot_room", _snapshot)
class Snapshot(object):
@@ -361,7 +449,7 @@ class Snapshot(object):
store (DataStore): The datastore.
room_id (RoomId): The room of the snapshot.
user_id (UserId): The user this snapshot is for.
- prev_pdus (list): The list of PDU ids this snapshot is after.
+ prev_events (list): The list of event ids this snapshot is after.
membership_state (RoomMemberEvent): The current state of the user in
the room.
state_type (str, optional): State type captured by the snapshot
@@ -370,13 +458,13 @@ class Snapshot(object):
the previous value of the state type and key in the room.
"""
- def __init__(self, store, room_id, user_id, prev_pdus,
+ def __init__(self, store, room_id, user_id, prev_events,
membership_state, state_type=None, state_key=None,
prev_state_pdu=None):
self.store = store
self.room_id = room_id
self.user_id = user_id
- self.prev_pdus = prev_pdus
+ self.prev_events = prev_events
self.membership_state = membership_state
self.state_type = state_type
self.state_key = state_key
@@ -386,14 +474,13 @@ class Snapshot(object):
if hasattr(event, "prev_events"):
return
- es = [
- "%s@%s" % (p_id, origin) for p_id, origin, _ in self.prev_pdus
+ event.prev_events = [
+ (event_id, hashes)
+ for event_id, hashes, _ in self.prev_events
]
- event.prev_events = [e for e in es if e != event.event_id]
-
- if self.prev_pdus:
- event.depth = max([int(v) for _, _, v in self.prev_pdus]) + 1
+ if self.prev_events:
+ event.depth = max([int(v) for _, _, v in self.prev_events]) + 1
else:
event.depth = 0
@@ -452,9 +539,10 @@ def prepare_database(db_conn):
db_conn.commit()
else:
- sql_script = "BEGIN TRANSACTION;"
+ sql_script = "BEGIN TRANSACTION;\n"
for sql_loc in SCHEMAS:
sql_script += read_schema(sql_loc)
+ sql_script += "\n"
sql_script += "COMMIT TRANSACTION;"
c.executescript(sql_script)
db_conn.commit()
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 65a86e9056..464b12f032 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -19,54 +19,66 @@ from twisted.internet import defer
from synapse.api.errors import StoreError
from synapse.api.events.utils import prune_event
from synapse.util.logutils import log_function
+from syutil.base64util import encode_base64
import collections
import copy
import json
+import sys
+import time
logger = logging.getLogger(__name__)
sql_logger = logging.getLogger("synapse.storage.SQL")
+transaction_logger = logging.getLogger("synapse.storage.txn")
class LoggingTransaction(object):
"""An object that almost-transparently proxies for the 'txn' object
passed to the constructor. Adds logging to the .execute() method."""
- __slots__ = ["txn"]
+ __slots__ = ["txn", "name"]
- def __init__(self, txn):
+ def __init__(self, txn, name):
object.__setattr__(self, "txn", txn)
+ object.__setattr__(self, "name", name)
- def __getattribute__(self, name):
- if name == "execute":
- return object.__getattribute__(self, "execute")
-
- return getattr(object.__getattribute__(self, "txn"), name)
+ def __getattr__(self, name):
+ return getattr(self.txn, name)
def __setattr__(self, name, value):
- setattr(object.__getattribute__(self, "txn"), name, value)
+ setattr(self.txn, name, value)
def execute(self, sql, *args, **kwargs):
# TODO(paul): Maybe use 'info' and 'debug' for values?
- sql_logger.debug("[SQL] %s", sql)
+ sql_logger.debug("[SQL] {%s} %s", self.name, sql)
try:
if args and args[0]:
values = args[0]
- sql_logger.debug("[SQL values] " +
- ", ".join(("<%s>",) * len(values)), *values)
+ sql_logger.debug(
+ "[SQL values] {%s} " + ", ".join(("<%s>",) * len(values)),
+ self.name,
+ *values
+ )
except:
# Don't let logging failures stop SQL from working
pass
- # TODO(paul): Here would be an excellent place to put some timing
- # measurements, and log (warning?) slow queries.
- return object.__getattribute__(self, "txn").execute(
- sql, *args, **kwargs
- )
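+        # Time each query so that slow queries show up in the SQL log.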
+ start = time.clock() * 1000
+ try:
+ return self.txn.execute(
+ sql, *args, **kwargs
+ )
+ except:
+ logger.exception("[SQL FAIL] {%s}", self.name)
+ raise
+ finally:
+ end = time.clock() * 1000
+ sql_logger.debug("[SQL time] {%s} %f", self.name, end - start)
class SQLBaseStore(object):
+ _TXN_ID = 0
def __init__(self, hs):
self.hs = hs
@@ -74,10 +86,30 @@ class SQLBaseStore(object):
self.event_factory = hs.get_event_factory()
self._clock = hs.get_clock()
- def runInteraction(self, func, *args, **kwargs):
+ def runInteraction(self, desc, func, *args, **kwargs):
"""Wraps the .runInteraction() method on the underlying db_pool."""
def inner_func(txn, *args, **kwargs):
- return func(LoggingTransaction(txn), *args, **kwargs)
+ start = time.clock() * 1000
+ txn_id = SQLBaseStore._TXN_ID
+
+ # We don't really need these to be unique, so lets stop it from
+ # growing really large.
+            SQLBaseStore._TXN_ID = (SQLBaseStore._TXN_ID + 1) % (sys.maxint - 1)
+
+ name = "%s-%x" % (desc, txn_id, )
+
+ transaction_logger.debug("[TXN START] {%s}", name)
+ try:
+ return func(LoggingTransaction(txn, name), *args, **kwargs)
+ except:
+ logger.exception("[TXN FAIL] {%s}", name)
+ raise
+ finally:
+ end = time.clock() * 1000
+ transaction_logger.debug(
+ "[TXN END] {%s} %f",
+ name, end - start
+ )
return self._db_pool.runInteraction(inner_func, *args, **kwargs)
@@ -113,7 +145,7 @@ class SQLBaseStore(object):
else:
return cursor.fetchall()
- return self.runInteraction(interaction)
+ return self.runInteraction("_execute", interaction)
def _execute_and_decode(self, query, *args):
return self._execute(self.cursor_to_dict, query, *args)
@@ -130,6 +162,7 @@ class SQLBaseStore(object):
or_replace : bool; if True performs an INSERT OR REPLACE
"""
return self.runInteraction(
+ "_simple_insert",
self._simple_insert_txn, table, values, or_replace=or_replace,
or_ignore=or_ignore,
)
@@ -170,7 +203,6 @@ class SQLBaseStore(object):
table, keyvalues, retcols=retcols, allow_none=allow_none
)
- @defer.inlineCallbacks
def _simple_select_one_onecol(self, table, keyvalues, retcol,
allow_none=False):
"""Executes a SELECT query on the named table, which is expected to
@@ -181,19 +213,41 @@ class SQLBaseStore(object):
keyvalues : dict of column names and values to select the row with
retcol : string giving the name of the column to return
"""
- ret = yield self._simple_select_one(
+ return self.runInteraction(
+ "_simple_select_one_onecol_txn",
+ self._simple_select_one_onecol_txn,
+ table, keyvalues, retcol, allow_none=allow_none,
+ )
+
+ def _simple_select_one_onecol_txn(self, txn, table, keyvalues, retcol,
+ allow_none=False):
+ ret = self._simple_select_onecol_txn(
+ txn,
table=table,
keyvalues=keyvalues,
- retcols=[retcol],
- allow_none=allow_none
+ retcol=retcol,
)
if ret:
- defer.returnValue(ret[retcol])
+ return ret[0]
else:
- defer.returnValue(None)
+ if allow_none:
+ return None
+ else:
+ raise StoreError(404, "No row found")
+
+ def _simple_select_onecol_txn(self, txn, table, keyvalues, retcol):
+ sql = "SELECT %(retcol)s FROM %(table)s WHERE %(where)s" % {
+ "retcol": retcol,
+ "table": table,
+ "where": " AND ".join("%s = ?" % k for k in keyvalues.keys()),
+ }
+
+ txn.execute(sql, keyvalues.values())
+
+ return [r[0] for r in txn.fetchall()]
+
- @defer.inlineCallbacks
def _simple_select_onecol(self, table, keyvalues, retcol):
"""Executes a SELECT query on the named table, which returns a list
comprising of the values of the named column from the selected rows.
@@ -206,19 +260,11 @@ class SQLBaseStore(object):
Returns:
Deferred: Results in a list
"""
- sql = "SELECT %(retcol)s FROM %(table)s WHERE %(where)s" % {
- "retcol": retcol,
- "table": table,
- "where": " AND ".join("%s = ?" % k for k in keyvalues.keys()),
- }
-
- def func(txn):
- txn.execute(sql, keyvalues.values())
- return txn.fetchall()
-
- res = yield self.runInteraction(func)
-
- defer.returnValue([r[0] for r in res])
+ return self.runInteraction(
+ "_simple_select_onecol",
+ self._simple_select_onecol_txn,
+ table, keyvalues, retcol
+ )
def _simple_select_list(self, table, keyvalues, retcols):
"""Executes a SELECT query on the named table, which may return zero or
@@ -239,7 +285,7 @@ class SQLBaseStore(object):
txn.execute(sql, keyvalues.values())
return self.cursor_to_dict(txn)
- return self.runInteraction(func)
+ return self.runInteraction("_simple_select_list", func)
def _simple_update_one(self, table, keyvalues, updatevalues,
retcols=None):
@@ -307,7 +353,7 @@ class SQLBaseStore(object):
raise StoreError(500, "More than one row matched")
return ret
- return self.runInteraction(func)
+ return self.runInteraction("_simple_selectupdate_one", func)
def _simple_delete_one(self, table, keyvalues):
"""Executes a DELETE query on the named table, expecting to delete a
@@ -319,7 +365,7 @@ class SQLBaseStore(object):
"""
sql = "DELETE FROM %s WHERE %s" % (
table,
- " AND ".join("%s = ?" % (k) for k in keyvalues)
+ " AND ".join("%s = ?" % (k, ) for k in keyvalues)
)
def func(txn):
@@ -328,7 +374,25 @@ class SQLBaseStore(object):
raise StoreError(404, "No row found")
if txn.rowcount > 1:
raise StoreError(500, "more than one row matched")
- return self.runInteraction(func)
+ return self.runInteraction("_simple_delete_one", func)
+
+ def _simple_delete(self, table, keyvalues):
+ """Executes a DELETE query on the named table.
+
+ Args:
+ table : string giving the table name
+ keyvalues : dict of column names and values to select the row with
+ """
+
+        return self.runInteraction(
+            "_simple_delete", self._simple_delete_txn, table, keyvalues
+        )
+
+ def _simple_delete_txn(self, txn, table, keyvalues):
+ sql = "DELETE FROM %s WHERE %s" % (
+ table,
+ " AND ".join("%s = ?" % (k, ) for k in keyvalues)
+ )
+
+ return txn.execute(sql, keyvalues.values())
def _simple_max_id(self, table):
"""Executes a SELECT query on the named table, expecting to return the
@@ -346,7 +410,7 @@ class SQLBaseStore(object):
return 0
return max_id
- return self.runInteraction(func)
+ return self.runInteraction("_simple_max_id", func)
def _parse_event_from_row(self, row_dict):
d = copy.deepcopy({k: v for k, v in row_dict.items()})
@@ -370,7 +434,9 @@ class SQLBaseStore(object):
)
def _parse_events(self, rows):
- return self.runInteraction(self._parse_events_txn, rows)
+ return self.runInteraction(
+ "_parse_events", self._parse_events_txn, rows
+ )
def _parse_events_txn(self, txn, rows):
events = [self._parse_event_from_row(r) for r in rows]
@@ -378,6 +444,17 @@ class SQLBaseStore(object):
sql = "SELECT * FROM events WHERE event_id = ?"
for ev in events:
+ signatures = self._get_event_origin_signatures_txn(
+ txn, ev.event_id,
+ )
+
+ ev.signatures = {
+ k: encode_base64(v) for k, v in signatures.items()
+ }
+
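+            # NB: this fills prev_events from the room's current forward
+            # extremities, not from the edges stored for this event.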
+ prev_events = self._get_latest_events_in_room(txn, ev.room_id)
+ ev.prev_events = [(e_id, s,) for e_id, s, _ in prev_events]
+
if hasattr(ev, "prev_state"):
# Load previous state_content.
# TODO: Should we be pulling this out above?
diff --git a/synapse/storage/directory.py b/synapse/storage/directory.py
index 52373a28a6..d6a7113b9c 100644
--- a/synapse/storage/directory.py
+++ b/synapse/storage/directory.py
@@ -95,6 +95,7 @@ class DirectoryStore(SQLBaseStore):
def delete_room_alias(self, room_alias):
return self.runInteraction(
+ "delete_room_alias",
self._delete_room_alias_txn,
room_alias,
)
diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py
new file mode 100644
index 0000000000..88d09d9ba8
--- /dev/null
+++ b/synapse/storage/event_federation.py
@@ -0,0 +1,162 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ._base import SQLBaseStore
+from syutil.base64util import encode_base64
+
+import logging
+
+
+logger = logging.getLogger(__name__)
+
+
+class EventFederationStore(SQLBaseStore):
+
+ def get_latest_events_in_room(self, room_id):
+ return self.runInteraction(
+ "get_latest_events_in_room",
+ self._get_latest_events_in_room,
+ room_id,
+ )
+
+ def _get_latest_events_in_room(self, txn, room_id):
+ sql = (
+ "SELECT e.event_id, e.depth FROM events as e "
+ "INNER JOIN event_forward_extremities as f "
+ "ON e.event_id = f.event_id "
+ "WHERE f.room_id = ?"
+ )
+
+ txn.execute(sql, (room_id, ))
+
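+        # For each extremity, include the hashes that should accompany it
+        # when it is referenced as a prev_event.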
+ results = []
+ for event_id, depth in txn.fetchall():
+ hashes = self._get_prev_event_hashes_txn(txn, event_id)
+ prev_hashes = {
+ k: encode_base64(v) for k, v in hashes.items()
+ if k == "sha256"
+ }
+ results.append((event_id, prev_hashes, depth))
+
+ return results
+
+ def _get_min_depth_interaction(self, txn, room_id):
+ min_depth = self._simple_select_one_onecol_txn(
+ txn,
+ table="room_depth",
+            keyvalues={"room_id": room_id},
+ retcol="min_depth",
+ allow_none=True,
+ )
+
+ return int(min_depth) if min_depth is not None else None
+
+ def _update_min_depth_for_room_txn(self, txn, room_id, depth):
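+        # room_depth tracks the minimum event depth we have seen for a
+        # room; it only ever decreases.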
+ min_depth = self._get_min_depth_interaction(txn, room_id)
+
+        do_insert = depth < min_depth if min_depth is not None else True
+
+ if do_insert:
+ self._simple_insert_txn(
+ txn,
+ table="room_depth",
+ values={
+ "room_id": room_id,
+ "min_depth": depth,
+ },
+ or_replace=True,
+ )
+
+ def _handle_prev_events(self, txn, outlier, event_id, prev_events,
+ room_id):
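+        # Record an edge in the event graph for each prev_event this event
+        # references.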
+ for e_id, _ in prev_events:
+ # TODO (erikj): This could be done as a bulk insert
+ self._simple_insert_txn(
+ txn,
+ table="event_edges",
+ values={
+ "event_id": event_id,
+ "prev_event_id": e_id,
+ "room_id": room_id,
+ },
+ or_ignore=True,
+ )
+
+ # Update the extremities table if this is not an outlier.
+ if not outlier:
+ for e_id, _ in prev_events:
+ # TODO (erikj): This could be done as a bulk insert
+ self._simple_delete_txn(
+ txn,
+ table="event_forward_extremities",
+ keyvalues={
+ "event_id": e_id,
+ "room_id": room_id,
+ }
+ )
+
+            # We only insert the event as a forward extremity if no other
+            # events reference it as a prev_event.
+ query = (
+ "INSERT OR IGNORE INTO %(table)s (event_id, room_id) "
+ "SELECT ?, ? WHERE NOT EXISTS ("
+ "SELECT 1 FROM %(event_edges)s WHERE "
+ "prev_event_id = ? "
+ ")"
+ ) % {
+ "table": "event_forward_extremities",
+ "event_edges": "event_edges",
+ }
+
+ logger.debug("query: %s", query)
+
+ txn.execute(query, (event_id, room_id, event_id))
+
+            # Insert all the prev_events as backward extremities; they'll get
+            # deleted in a second if they turn out to be incorrect anyway.
+ for e_id, _ in prev_events:
+ # TODO (erikj): This could be done as a bulk insert
+ self._simple_insert_txn(
+ txn,
+ table="event_backward_extremities",
+ values={
+ "event_id": e_id,
+ "room_id": room_id,
+ },
+ or_ignore=True,
+ )
+
+            # Also delete from the backward extremities table any entries that
+            # reference events we have already seen.
+ query = (
+ "DELETE FROM event_backward_extremities WHERE EXISTS ("
+ "SELECT 1 FROM events "
+ "WHERE "
+ "event_backward_extremities.event_id = events.event_id "
+ "AND not events.outlier "
+ ")"
+ )
+ txn.execute(query)
\ No newline at end of file
diff --git a/synapse/storage/pdu.py b/synapse/storage/pdu.py
index d70467dcd6..4a4341907b 100644
--- a/synapse/storage/pdu.py
+++ b/synapse/storage/pdu.py
@@ -20,10 +20,13 @@ from ._base import SQLBaseStore, Table, JoinHelper
from synapse.federation.units import Pdu
from synapse.util.logutils import log_function
+from syutil.base64util import encode_base64
+
from collections import namedtuple
import logging
+
logger = logging.getLogger(__name__)
@@ -44,7 +47,7 @@ class PduStore(SQLBaseStore):
"""
return self.runInteraction(
- self._get_pdu_tuple, pdu_id, origin
+ "get_pdu", self._get_pdu_tuple, pdu_id, origin
)
def _get_pdu_tuple(self, txn, pdu_id, origin):
@@ -64,6 +67,13 @@ class PduStore(SQLBaseStore):
for r in PduEdgesTable.decode_results(txn.fetchall())
]
+ edge_hashes = self._get_prev_pdu_hashes_txn(txn, pdu_id, origin)
+
+ hashes = self._get_pdu_content_hashes_txn(txn, pdu_id, origin)
+ signatures = self._get_pdu_origin_signatures_txn(
+ txn, pdu_id, origin
+ )
+
query = (
"SELECT %(fields)s FROM %(pdus)s as p "
"LEFT JOIN %(state)s as s "
@@ -80,7 +90,9 @@ class PduStore(SQLBaseStore):
row = txn.fetchone()
if row:
- results.append(PduTuple(PduEntry(*row), edges))
+ results.append(PduTuple(
+ PduEntry(*row), edges, hashes, signatures, edge_hashes
+ ))
return results
@@ -96,6 +108,7 @@ class PduStore(SQLBaseStore):
"""
return self.runInteraction(
+ "get_current_state_for_context",
self._get_current_state_for_context,
context
)
@@ -144,6 +157,7 @@ class PduStore(SQLBaseStore):
"""
return self.runInteraction(
+ "mark_pdu_as_processed",
self._mark_as_processed, pdu_id, pdu_origin
)
@@ -153,6 +167,7 @@ class PduStore(SQLBaseStore):
def get_all_pdus_from_context(self, context):
"""Get a list of all PDUs for a given context."""
return self.runInteraction(
+ "get_all_pdus_from_context",
self._get_all_pdus_from_context, context,
)
@@ -180,6 +195,7 @@ class PduStore(SQLBaseStore):
list: A list of PduTuples
"""
return self.runInteraction(
+ "get_backfill",
self._get_backfill, context, pdu_list, limit
)
@@ -241,6 +257,7 @@ class PduStore(SQLBaseStore):
context (str)
"""
return self.runInteraction(
+ "get_min_depth_for_context",
self._get_min_depth_for_context, context
)
@@ -277,6 +294,13 @@ class PduStore(SQLBaseStore):
(context, depth)
)
+ def get_latest_pdus_in_context(self, context):
+ return self.runInteraction(
+ "get_latest_pdus_in_context",
+ self._get_latest_pdus_in_context,
+ context
+ )
+
def _get_latest_pdus_in_context(self, txn, context):
"""Get's a list of the most current pdus for a given context. This is
used when we are sending a Pdu and need to fill out the `prev_pdus`
@@ -303,9 +327,14 @@ class PduStore(SQLBaseStore):
(context, )
)
- results = txn.fetchall()
+ results = []
+ for pdu_id, origin, depth in txn.fetchall():
+ hashes = self._get_pdu_reference_hashes_txn(txn, pdu_id, origin)
+ sha256_bytes = hashes["sha256"]
+ prev_hashes = {"sha256": encode_base64(sha256_bytes)}
+ results.append((pdu_id, origin, prev_hashes, depth))
- return [(row[0], row[1], row[2]) for row in results]
+ return results
@defer.inlineCallbacks
def get_oldest_pdus_in_context(self, context):
@@ -347,6 +376,7 @@ class PduStore(SQLBaseStore):
"""
return self.runInteraction(
+ "is_pdu_new",
self._is_pdu_new,
pdu_id=pdu_id,
origin=origin,
@@ -424,7 +454,7 @@ class PduStore(SQLBaseStore):
"DELETE FROM %s WHERE pdu_id = ? AND origin = ?"
% PduForwardExtremitiesTable.table_name
)
- txn.executemany(query, prev_pdus)
+ txn.executemany(query, list(p[:2] for p in prev_pdus))
# We only insert as a forward extremity the new pdu if there are no
# other pdus that reference it as a prev pdu
@@ -447,7 +477,7 @@ class PduStore(SQLBaseStore):
# deleted in a second if they're incorrect anyway.
txn.executemany(
PduBackwardExtremitiesTable.insert_statement(),
- [(i, o, context) for i, o in prev_pdus]
+ [(i, o, context) for i, o, _ in prev_pdus]
)
# Also delete from the backwards extremities table all ones that
@@ -500,6 +530,7 @@ class StatePduStore(SQLBaseStore):
def get_unresolved_state_tree(self, new_state_pdu):
return self.runInteraction(
+ "get_unresolved_state_tree",
self._get_unresolved_state_tree, new_state_pdu
)
@@ -539,6 +570,7 @@ class StatePduStore(SQLBaseStore):
def update_current_state(self, pdu_id, origin, context, pdu_type,
state_key):
return self.runInteraction(
+ "update_current_state",
self._update_current_state,
pdu_id, origin, context, pdu_type, state_key
)
@@ -578,6 +610,7 @@ class StatePduStore(SQLBaseStore):
"""
return self.runInteraction(
+ "get_current_state_pdu",
self._get_current_state_pdu, context, pdu_type, state_key
)
@@ -637,6 +670,7 @@ class StatePduStore(SQLBaseStore):
bool: True if the new_pdu clobbered the current state, False if not
"""
return self.runInteraction(
+ "handle_new_state",
self._handle_new_state, new_pdu
)
@@ -908,7 +942,7 @@ This does not include a prev_pdus key.
PduTuple = namedtuple(
"PduTuple",
- ("pdu_entry", "prev_pdu_list")
+ ("pdu_entry", "prev_pdu_list", "hashes", "signatures", "edge_hashes")
)
""" This is a tuple of a `PduEntry` and a list of `PduIdTuple` that represent
the `prev_pdus` key of a PDU.
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 719806f82b..a2ca6f9a69 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -62,8 +62,10 @@ class RegistrationStore(SQLBaseStore):
Raises:
StoreError if the user_id could not be registered.
"""
- yield self.runInteraction(self._register, user_id, token,
- password_hash)
+ yield self.runInteraction(
+ "register",
+ self._register, user_id, token, password_hash
+ )
def _register(self, txn, user_id, token, password_hash):
now = int(self.clock.time())
@@ -100,6 +102,7 @@ class RegistrationStore(SQLBaseStore):
StoreError if no user was found.
"""
return self.runInteraction(
+ "get_user_by_token",
self._query_for_auth,
token
)
diff --git a/synapse/storage/room.py b/synapse/storage/room.py
index 8cd46334cf..7e48ce9cc3 100644
--- a/synapse/storage/room.py
+++ b/synapse/storage/room.py
@@ -150,6 +150,7 @@ class RoomStore(SQLBaseStore):
def get_power_level(self, room_id, user_id):
return self.runInteraction(
+ "get_power_level",
self._get_power_level,
room_id, user_id,
)
@@ -183,6 +184,7 @@ class RoomStore(SQLBaseStore):
def get_ops_levels(self, room_id):
return self.runInteraction(
+ "get_ops_levels",
self._get_ops_levels,
room_id,
)
diff --git a/synapse/storage/schema/event_edges.sql b/synapse/storage/schema/event_edges.sql
new file mode 100644
index 0000000000..e5f768c705
--- /dev/null
+++ b/synapse/storage/schema/event_edges.sql
@@ -0,0 +1,49 @@
+CREATE TABLE IF NOT EXISTS event_forward_extremities(
+ event_id TEXT,
+ room_id TEXT,
+ CONSTRAINT uniqueness UNIQUE (event_id, room_id) ON CONFLICT REPLACE
+);
+
+CREATE INDEX IF NOT EXISTS ev_extrem_room ON event_forward_extremities(room_id);
+CREATE INDEX IF NOT EXISTS ev_extrem_id ON event_forward_extremities(event_id);
+
+
+CREATE TABLE IF NOT EXISTS event_backward_extremities(
+ event_id TEXT,
+ room_id TEXT,
+ CONSTRAINT uniqueness UNIQUE (event_id, room_id) ON CONFLICT REPLACE
+);
+
+CREATE INDEX IF NOT EXISTS ev_b_extrem_room ON event_backward_extremities(room_id);
+CREATE INDEX IF NOT EXISTS ev_b_extrem_id ON event_backward_extremities(event_id);
+
+
+CREATE TABLE IF NOT EXISTS event_edges(
+ event_id TEXT,
+ prev_event_id TEXT,
+ room_id TEXT,
+ CONSTRAINT uniqueness UNIQUE (event_id, prev_event_id, room_id)
+);
+
+CREATE INDEX IF NOT EXISTS ev_edges_id ON event_edges(event_id);
+CREATE INDEX IF NOT EXISTS ev_edges_prev_id ON event_edges(prev_event_id);
+
+
+CREATE TABLE IF NOT EXISTS room_depth(
+ room_id TEXT,
+ min_depth INTEGER,
+ CONSTRAINT uniqueness UNIQUE (room_id)
+);
+
+CREATE INDEX IF NOT EXISTS room_depth_room ON room_depth(room_id);
+
+
+CREATE TABLE IF NOT EXISTS event_destinations(
+ event_id TEXT,
+ destination TEXT,
+    delivered_ts INTEGER DEFAULT 0, -- ts of successful delivery, or 0 if not delivered
+ CONSTRAINT uniqueness UNIQUE (event_id, destination) ON CONFLICT REPLACE
+);
+
+CREATE INDEX IF NOT EXISTS event_destinations_id ON event_destinations(event_id);
diff --git a/synapse/storage/schema/event_signatures.sql b/synapse/storage/schema/event_signatures.sql
new file mode 100644
index 0000000000..5491c7ecec
--- /dev/null
+++ b/synapse/storage/schema/event_signatures.sql
@@ -0,0 +1,65 @@
+/* Copyright 2014 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE IF NOT EXISTS event_content_hashes (
+ event_id TEXT,
+ algorithm TEXT,
+ hash BLOB,
+ CONSTRAINT uniqueness UNIQUE (event_id, algorithm)
+);
+
+CREATE INDEX IF NOT EXISTS event_content_hashes_id ON event_content_hashes(
+ event_id
+);
+
+
+CREATE TABLE IF NOT EXISTS event_reference_hashes (
+ event_id TEXT,
+ algorithm TEXT,
+ hash BLOB,
+ CONSTRAINT uniqueness UNIQUE (event_id, algorithm)
+);
+
+CREATE INDEX IF NOT EXISTS event_reference_hashes_id ON event_reference_hashes (
+ event_id
+);
+
+
+CREATE TABLE IF NOT EXISTS event_origin_signatures (
+ event_id TEXT,
+ origin TEXT,
+ key_id TEXT,
+ signature BLOB,
+ CONSTRAINT uniqueness UNIQUE (event_id, key_id)
+);
+
+CREATE INDEX IF NOT EXISTS event_origin_signatures_id ON event_origin_signatures (
+ event_id
+);
+
+
+CREATE TABLE IF NOT EXISTS event_edge_hashes(
+ event_id TEXT,
+ prev_event_id TEXT,
+ algorithm TEXT,
+ hash BLOB,
+ CONSTRAINT uniqueness UNIQUE (
+ event_id, prev_event_id, algorithm
+ )
+);
+
+CREATE INDEX IF NOT EXISTS event_edge_hashes_id ON event_edge_hashes(
+ event_id
+);
diff --git a/synapse/storage/schema/im.sql b/synapse/storage/schema/im.sql
index 3aa83f5c8c..8d6f655993 100644
--- a/synapse/storage/schema/im.sql
+++ b/synapse/storage/schema/im.sql
@@ -23,6 +23,7 @@ CREATE TABLE IF NOT EXISTS events(
unrecognized_keys TEXT,
processed BOOL NOT NULL,
outlier BOOL NOT NULL,
+ depth INTEGER DEFAULT 0 NOT NULL,
CONSTRAINT ev_uniq UNIQUE (event_id)
);
diff --git a/synapse/storage/schema/signatures.sql b/synapse/storage/schema/signatures.sql
new file mode 100644
index 0000000000..1c45a51bec
--- /dev/null
+++ b/synapse/storage/schema/signatures.sql
@@ -0,0 +1,66 @@
+/* Copyright 2014 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE IF NOT EXISTS pdu_content_hashes (
+ pdu_id TEXT,
+ origin TEXT,
+ algorithm TEXT,
+ hash BLOB,
+ CONSTRAINT uniqueness UNIQUE (pdu_id, origin, algorithm)
+);
+
+CREATE INDEX IF NOT EXISTS pdu_content_hashes_id ON pdu_content_hashes (
+ pdu_id, origin
+);
+
+CREATE TABLE IF NOT EXISTS pdu_reference_hashes (
+ pdu_id TEXT,
+ origin TEXT,
+ algorithm TEXT,
+ hash BLOB,
+ CONSTRAINT uniqueness UNIQUE (pdu_id, origin, algorithm)
+);
+
+CREATE INDEX IF NOT EXISTS pdu_reference_hashes_id ON pdu_reference_hashes (
+ pdu_id, origin
+);
+
+CREATE TABLE IF NOT EXISTS pdu_origin_signatures (
+ pdu_id TEXT,
+ origin TEXT,
+ key_id TEXT,
+ signature BLOB,
+ CONSTRAINT uniqueness UNIQUE (pdu_id, origin, key_id)
+);
+
+CREATE INDEX IF NOT EXISTS pdu_origin_signatures_id ON pdu_origin_signatures (
+ pdu_id, origin
+);
+
+CREATE TABLE IF NOT EXISTS pdu_edge_hashes(
+ pdu_id TEXT,
+ origin TEXT,
+ prev_pdu_id TEXT,
+ prev_origin TEXT,
+ algorithm TEXT,
+ hash BLOB,
+ CONSTRAINT uniqueness UNIQUE (
+ pdu_id, origin, prev_pdu_id, prev_origin, algorithm
+ )
+);
+
+CREATE INDEX IF NOT EXISTS pdu_edge_hashes_id ON pdu_edge_hashes(
+ pdu_id, origin
+);
diff --git a/synapse/storage/schema/state.sql b/synapse/storage/schema/state.sql
new file mode 100644
index 0000000000..b44c56b519
--- /dev/null
+++ b/synapse/storage/schema/state.sql
@@ -0,0 +1,33 @@
+/* Copyright 2014 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE IF NOT EXISTS state_groups(
+ id INTEGER PRIMARY KEY,
+ room_id TEXT NOT NULL,
+ event_id TEXT NOT NULL
+);
+
+CREATE TABLE IF NOT EXISTS state_groups_state(
+ state_group INTEGER NOT NULL,
+ room_id TEXT NOT NULL,
+ type TEXT NOT NULL,
+ state_key TEXT NOT NULL,
+ event_id TEXT NOT NULL
+);
+
+CREATE TABLE IF NOT EXISTS event_to_state_groups(
+ event_id TEXT NOT NULL,
+ state_group INTEGER NOT NULL
+);
\ No newline at end of file
diff --git a/synapse/storage/signatures.py b/synapse/storage/signatures.py
new file mode 100644
index 0000000000..5e99174fcd
--- /dev/null
+++ b/synapse/storage/signatures.py
@@ -0,0 +1,302 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ._base import SQLBaseStore
+
+
+class SignatureStore(SQLBaseStore):
+    """Persistence for PDU and event signatures and hashes"""
+
+ def _get_pdu_content_hashes_txn(self, txn, pdu_id, origin):
+ """Get all the hashes for a given PDU.
+ Args:
+ txn (cursor):
+ pdu_id (str): Id for the PDU.
+ origin (str): origin of the PDU.
+ Returns:
+ A dict of algorithm -> hash.
+ """
+ query = (
+ "SELECT algorithm, hash"
+ " FROM pdu_content_hashes"
+ " WHERE pdu_id = ? and origin = ?"
+ )
+ txn.execute(query, (pdu_id, origin))
+ return dict(txn.fetchall())
+
+ def _store_pdu_content_hash_txn(self, txn, pdu_id, origin, algorithm,
+ hash_bytes):
+ """Store a hash for a PDU
+ Args:
+ txn (cursor):
+ pdu_id (str): Id for the PDU.
+ origin (str): origin of the PDU.
+ algorithm (str): Hashing algorithm.
+ hash_bytes (bytes): Hash function output bytes.
+ """
+ self._simple_insert_txn(txn, "pdu_content_hashes", {
+ "pdu_id": pdu_id,
+ "origin": origin,
+ "algorithm": algorithm,
+ "hash": buffer(hash_bytes),
+ })
+
+ def _get_pdu_reference_hashes_txn(self, txn, pdu_id, origin):
+ """Get all the hashes for a given PDU.
+ Args:
+ txn (cursor):
+ pdu_id (str): Id for the PDU.
+ origin (str): origin of the PDU.
+ Returns:
+ A dict of algorithm -> hash.
+ """
+ query = (
+ "SELECT algorithm, hash"
+ " FROM pdu_reference_hashes"
+ " WHERE pdu_id = ? and origin = ?"
+ )
+ txn.execute(query, (pdu_id, origin))
+ return dict(txn.fetchall())
+
+ def _store_pdu_reference_hash_txn(self, txn, pdu_id, origin, algorithm,
+ hash_bytes):
+ """Store a hash for a PDU
+ Args:
+ txn (cursor):
+ pdu_id (str): Id for the PDU.
+ origin (str): origin of the PDU.
+ algorithm (str): Hashing algorithm.
+ hash_bytes (bytes): Hash function output bytes.
+ """
+ self._simple_insert_txn(txn, "pdu_reference_hashes", {
+ "pdu_id": pdu_id,
+ "origin": origin,
+ "algorithm": algorithm,
+ "hash": buffer(hash_bytes),
+ })
+
+ def _get_pdu_origin_signatures_txn(self, txn, pdu_id, origin):
+ """Get all the signatures for a given PDU.
+ Args:
+ txn (cursor):
+ pdu_id (str): Id for the PDU.
+ origin (str): origin of the PDU.
+ Returns:
+ A dict of key_id -> signature_bytes.
+ """
+ query = (
+ "SELECT key_id, signature"
+ " FROM pdu_origin_signatures"
+ " WHERE pdu_id = ? and origin = ?"
+ )
+ txn.execute(query, (pdu_id, origin))
+ return dict(txn.fetchall())
+
+ def _store_pdu_origin_signature_txn(self, txn, pdu_id, origin, key_id,
+ signature_bytes):
+ """Store a signature from the origin server for a PDU.
+ Args:
+ txn (cursor):
+ pdu_id (str): Id for the PDU.
+ origin (str): origin of the PDU.
+ key_id (str): Id for the signing key.
+ signature (bytes): The signature.
+ """
+ self._simple_insert_txn(txn, "pdu_origin_signatures", {
+ "pdu_id": pdu_id,
+ "origin": origin,
+ "key_id": key_id,
+ "signature": buffer(signature_bytes),
+ })
+
+ def _get_prev_pdu_hashes_txn(self, txn, pdu_id, origin):
+ """Get all the hashes for previous PDUs of a PDU
+ Args:
+ txn (cursor):
+ pdu_id (str): Id of the PDU.
+ origin (str): Origin of the PDU.
+ Returns:
+ dict of (pdu_id, origin) -> dict of algorithm -> hash_bytes.
+ """
+ query = (
+ "SELECT prev_pdu_id, prev_origin, algorithm, hash"
+ " FROM pdu_edge_hashes"
+ " WHERE pdu_id = ? and origin = ?"
+ )
+ txn.execute(query, (pdu_id, origin))
+ results = {}
+ for prev_pdu_id, prev_origin, algorithm, hash_bytes in txn.fetchall():
+ hashes = results.setdefault((prev_pdu_id, prev_origin), {})
+ hashes[algorithm] = hash_bytes
+ return results
+
+ def _store_prev_pdu_hash_txn(self, txn, pdu_id, origin, prev_pdu_id,
+ prev_origin, algorithm, hash_bytes):
+ self._simple_insert_txn(txn, "pdu_edge_hashes", {
+ "pdu_id": pdu_id,
+ "origin": origin,
+ "prev_pdu_id": prev_pdu_id,
+ "prev_origin": prev_origin,
+ "algorithm": algorithm,
+ "hash": buffer(hash_bytes),
+ })
+
+ ## Events ##
+
+ def _get_event_content_hashes_txn(self, txn, event_id):
+ """Get all the hashes for a given Event.
+ Args:
+ txn (cursor):
+ event_id (str): Id for the Event.
+ Returns:
+ A dict of algorithm -> hash.
+ """
+ query = (
+ "SELECT algorithm, hash"
+ " FROM event_content_hashes"
+ " WHERE event_id = ?"
+ )
+ txn.execute(query, (event_id, ))
+ return dict(txn.fetchall())
+
+ def _store_event_content_hash_txn(self, txn, event_id, algorithm,
+ hash_bytes):
+        """Store a content hash for an event.
+ Args:
+ txn (cursor):
+ event_id (str): Id for the Event.
+ algorithm (str): Hashing algorithm.
+ hash_bytes (bytes): Hash function output bytes.
+ """
+ self._simple_insert_txn(
+ txn,
+ "event_content_hashes",
+ {
+ "event_id": event_id,
+ "algorithm": algorithm,
+ "hash": buffer(hash_bytes),
+ },
+ or_ignore=True,
+ )
+
+ def _get_event_reference_hashes_txn(self, txn, event_id):
+        """Get all the reference hashes for a given event.
+ Args:
+ txn (cursor):
+ event_id (str): Id for the Event.
+ Returns:
+ A dict of algorithm -> hash.
+ """
+ query = (
+ "SELECT algorithm, hash"
+ " FROM event_reference_hashes"
+ " WHERE event_id = ?"
+ )
+ txn.execute(query, (event_id, ))
+ return dict(txn.fetchall())
+
+ def _store_event_reference_hash_txn(self, txn, event_id, algorithm,
+ hash_bytes):
+        """Store a reference hash for an event.
+ Args:
+ txn (cursor):
+ event_id (str): Id for the Event.
+ algorithm (str): Hashing algorithm.
+ hash_bytes (bytes): Hash function output bytes.
+ """
+ self._simple_insert_txn(
+ txn,
+ "event_reference_hashes",
+ {
+ "event_id": event_id,
+ "algorithm": algorithm,
+ "hash": buffer(hash_bytes),
+ },
+ or_ignore=True,
+ )
+
+ def _get_event_origin_signatures_txn(self, txn, event_id):
+        """Get all the signatures for a given event.
+ Args:
+ txn (cursor):
+ event_id (str): Id for the Event.
+ Returns:
+ A dict of key_id -> signature_bytes.
+ """
+ query = (
+ "SELECT key_id, signature"
+ " FROM event_origin_signatures"
+ " WHERE event_id = ? "
+ )
+ txn.execute(query, (event_id, ))
+ return dict(txn.fetchall())
+
+ def _store_event_origin_signature_txn(self, txn, event_id, origin, key_id,
+ signature_bytes):
+        """Store a signature from the origin server for an event.
+ Args:
+ txn (cursor):
+ event_id (str): Id for the Event.
+ origin (str): origin of the Event.
+ key_id (str): Id for the signing key.
+ signature (bytes): The signature.
+ """
+ self._simple_insert_txn(
+ txn,
+ "event_origin_signatures",
+ {
+ "event_id": event_id,
+ "origin": origin,
+ "key_id": key_id,
+ "signature": buffer(signature_bytes),
+ },
+ or_ignore=True,
+ )
+
+ def _get_prev_event_hashes_txn(self, txn, event_id):
+        """Get all the hashes for the previous events of an event.
+ Args:
+ txn (cursor):
+ event_id (str): Id for the Event.
+ Returns:
+            dict of prev_event_id -> dict of algorithm -> hash_bytes.
+ """
+ query = (
+ "SELECT prev_event_id, algorithm, hash"
+ " FROM event_edge_hashes"
+ " WHERE event_id = ?"
+ )
+ txn.execute(query, (event_id, ))
+ results = {}
+ for prev_event_id, algorithm, hash_bytes in txn.fetchall():
+ hashes = results.setdefault(prev_event_id, {})
+ hashes[algorithm] = hash_bytes
+ return results
+
+ def _store_prev_event_hash_txn(self, txn, event_id, prev_event_id,
+ algorithm, hash_bytes):
+ self._simple_insert_txn(
+ txn,
+ "event_edge_hashes",
+ {
+ "event_id": event_id,
+ "prev_event_id": prev_event_id,
+ "algorithm": algorithm,
+ "hash": buffer(hash_bytes),
+ },
+ or_ignore=True,
+ )
\ No newline at end of file
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
new file mode 100644
index 0000000000..e08acd6404
--- /dev/null
+++ b/synapse/storage/state.py
@@ -0,0 +1,101 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ._base import SQLBaseStore
+from twisted.internet import defer
+
+from collections import namedtuple
+
+
+StateGroup = namedtuple("StateGroup", ("group", "state"))
+
+
+class StateStore(SQLBaseStore):
+
+ @defer.inlineCallbacks
+ def get_state_groups(self, event_ids):
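+        # Map each event to its state group, then load every state event
+        # in each distinct group.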
+ groups = set()
+ for event_id in event_ids:
+ group = yield self._simple_select_one_onecol(
+ table="event_to_state_groups",
+ keyvalues={"event_id": event_id},
+ retcol="state_group",
+ allow_none=True,
+ )
+ if group:
+ groups.add(group)
+
+ res = []
+ for group in groups:
+ state_ids = yield self._simple_select_onecol(
+ table="state_groups_state",
+ keyvalues={"state_group": group},
+ retcol="event_id",
+ )
+ state = []
+ for state_id in state_ids:
+ s = yield self.get_event(
+ state_id,
+ allow_none=True,
+ )
+ if s:
+ state.append(s)
+
+ res.append(StateGroup(group, state))
+
+ defer.returnValue(res)
+
+ def store_state_groups(self, event):
+ return self.runInteraction(
+ "store_state_groups",
+ self._store_state_groups_txn, event
+ )
+
+ def _store_state_groups_txn(self, txn, event):
+ if not event.state_events:
+ return
+
+ state_group = event.state_group
+ if not state_group:
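+            # No existing group to reuse, so create a new one containing
+            # exactly this event's state.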
+ state_group = self._simple_insert_txn(
+ txn,
+ table="state_groups",
+ values={
+ "room_id": event.room_id,
+ "event_id": event.event_id,
+ }
+ )
+
+ for state in event.state_events.values():
+ self._simple_insert_txn(
+ txn,
+ table="state_groups_state",
+ values={
+ "state_group": state_group,
+ "room_id": state.room_id,
+ "type": state.type,
+ "state_key": state.state_key,
+ "event_id": state.event_id,
+ }
+ )
+
+ self._simple_insert_txn(
+ txn,
+ table="event_to_state_groups",
+ values={
+ "state_group": state_group,
+ "event_id": event.event_id,
+ }
+ )
diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py
index d61f909939..8f7f61d29d 100644
--- a/synapse/storage/stream.py
+++ b/synapse/storage/stream.py
@@ -309,7 +309,10 @@ class StreamStore(SQLBaseStore):
defer.returnValue(ret)
def get_room_events_max_id(self):
- return self.runInteraction(self._get_room_events_max_id_txn)
+ return self.runInteraction(
+ "get_room_events_max_id",
+ self._get_room_events_max_id_txn
+ )
def _get_room_events_max_id_txn(self, txn):
txn.execute(
diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py
index 2ba8e30efe..908014d38b 100644
--- a/synapse/storage/transactions.py
+++ b/synapse/storage/transactions.py
@@ -42,6 +42,7 @@ class TransactionStore(SQLBaseStore):
"""
return self.runInteraction(
+ "get_received_txn_response",
self._get_received_txn_response, transaction_id, origin
)
@@ -73,6 +74,7 @@ class TransactionStore(SQLBaseStore):
"""
return self.runInteraction(
+ "set_received_txn_response",
self._set_received_txn_response,
transaction_id, origin, code, response_dict
)
@@ -106,6 +108,7 @@ class TransactionStore(SQLBaseStore):
"""
return self.runInteraction(
+ "prep_send_transaction",
self._prep_send_transaction,
transaction_id, destination, origin_server_ts, pdu_list
)
@@ -161,6 +164,7 @@ class TransactionStore(SQLBaseStore):
response_json (str)
"""
return self.runInteraction(
+ "delivered_txn",
self._delivered_txn,
transaction_id, destination, code, response_dict
)
@@ -186,6 +190,7 @@ class TransactionStore(SQLBaseStore):
list: A list of `ReceivedTransactionsTable.EntryType`
"""
return self.runInteraction(
+ "get_transactions_after",
self._get_transactions_after, transaction_id, destination
)
@@ -216,6 +221,7 @@ class TransactionStore(SQLBaseStore):
list: A list of PduTuple
"""
return self.runInteraction(
+ "get_pdus_after_transaction",
self._get_pdus_after_transaction,
transaction_id, destination
)
diff --git a/synapse/types.py b/synapse/types.py
index c51bc8e4f2..649ff2f7d7 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -78,6 +78,11 @@ class DomainSpecificString(
"""Create a structure on the local domain"""
return cls(localpart=localpart, domain=hs.hostname, is_mine=True)
+ @classmethod
+ def create(cls, localpart, domain, hs):
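+        """Create a structure for the given domain, which may be remote"""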
+ is_mine = domain == hs.hostname
+ return cls(localpart=localpart, domain=domain, is_mine=is_mine)
+
class UserID(DomainSpecificString):
"""Structure representing a user ID."""
@@ -94,6 +99,11 @@ class RoomID(DomainSpecificString):
SIGIL = "!"
+class EventID(DomainSpecificString):
+ """Structure representing an event id. """
+ SIGIL = "$"
+
+
class StreamToken(
namedtuple(
"Token",
diff --git a/synapse/util/async.py b/synapse/util/async.py
index 647ea6142c..bf578f8bfb 100644
--- a/synapse/util/async.py
+++ b/synapse/util/async.py
@@ -21,3 +21,10 @@ def sleep(seconds):
d = defer.Deferred()
reactor.callLater(seconds, d.callback, seconds)
return d
+
+
+def run_on_reactor():
+ """ This will cause the rest of the function to be invoked upon the next
+ iteration of the main loop
+ """
+ return sleep(0)
\ No newline at end of file
|