diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index e1b1823cd7..d4f284bd60 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -21,8 +21,10 @@ from synapse.api.constants import Membership, JoinRules
from synapse.api.errors import AuthError, StoreError, Codes, SynapseError
from synapse.api.events.room import (
RoomMemberEvent, RoomPowerLevelsEvent, RoomRedactionEvent,
+ RoomJoinRulesEvent, RoomCreateEvent,
)
from synapse.util.logutils import log_function
+from syutil.base64util import encode_base64
import logging
@@ -35,8 +37,7 @@ class Auth(object):
self.hs = hs
self.store = hs.get_datastore()
- @defer.inlineCallbacks
- def check(self, event, snapshot, raises=False):
+ def check(self, event, raises=False):
""" Checks if this event is correctly authed.
Returns:
@@ -47,43 +48,48 @@ class Auth(object):
"""
try:
if hasattr(event, "room_id"):
- is_state = hasattr(event, "state_key")
+ if event.old_state_events is None:
+ # Oh, we don't know what the state of the room was, so we
+ # are trusting that this is allowed (at least for now)
+ logger.warn("Trusting event: %s", event.event_id)
+ return True
+
+ if hasattr(event, "outlier") and event.outlier is True:
+ # TODO (erikj): Auth for outliers is done differently.
+ return True
+
+ if event.type == RoomCreateEvent.TYPE:
+ # FIXME
+ return True
if event.type == RoomMemberEvent.TYPE:
- yield self._can_replace_state(event)
- allowed = yield self.is_membership_change_allowed(event)
- defer.returnValue(allowed)
- return
-
- self._check_joined_room(
- member=snapshot.membership_state,
- user_id=snapshot.user_id,
- room_id=snapshot.room_id,
- )
+ allowed = self.is_membership_change_allowed(event)
+ if allowed:
+ logger.debug("Allowing! %s", event)
+ else:
+ logger.debug("Denying! %s", event)
+ return allowed
- if is_state:
- # TODO (erikj): This really only should be called for *new*
- # state
- yield self._can_add_state(event)
- yield self._can_replace_state(event)
- else:
- yield self._can_send_event(event)
+ self._can_send_event(event)
if event.type == RoomPowerLevelsEvent.TYPE:
- yield self._check_power_levels(event)
+ self._check_power_levels(event)
if event.type == RoomRedactionEvent.TYPE:
- yield self._check_redaction(event)
+ self._check_redaction(event)
- defer.returnValue(True)
+ logger.debug("Allowing! %s", event)
+ return True
else:
raise AuthError(500, "Unknown event: %s" % event)
except AuthError as e:
logger.info("Event auth check failed on event %s with msg: %s",
event, e.msg)
+ logger.info("Denying! %s", event)
if raises:
raise e
- defer.returnValue(False)
+
+ return False
@defer.inlineCallbacks
def check_joined_room(self, room_id, user_id):
@@ -98,45 +104,72 @@ class Auth(object):
pass
defer.returnValue(None)
+ def check_event_sender_in_room(self, event):
+ key = (RoomMemberEvent.TYPE, event.user_id, )
+ member_event = event.state_events.get(key)
+
+ return self._check_joined_room(
+ member_event,
+ event.user_id,
+ event.room_id
+ )
+
def _check_joined_room(self, member, user_id, room_id):
if not member or member.membership != Membership.JOIN:
raise AuthError(403, "User %s not in room %s (%s)" % (
user_id, room_id, repr(member)
))
- @defer.inlineCallbacks
+ @log_function
def is_membership_change_allowed(self, event):
target_user_id = event.state_key
- # does this room even exist
- room = yield self.store.get_room(event.room_id)
- if not room:
- raise AuthError(403, "Room does not exist")
-
# get info about the caller
- try:
- caller = yield self.store.get_room_member(
- user_id=event.user_id,
- room_id=event.room_id)
- except:
- caller = None
+ key = (RoomMemberEvent.TYPE, event.user_id, )
+ caller = event.old_state_events.get(key)
+
caller_in_room = caller and caller.membership == "join"
# get info about the target
- try:
- target = yield self.store.get_room_member(
- user_id=target_user_id,
- room_id=event.room_id)
- except:
- target = None
+ key = (RoomMemberEvent.TYPE, target_user_id, )
+ target = event.old_state_events.get(key)
+
target_in_room = target and target.membership == "join"
membership = event.content["membership"]
- join_rule = yield self.store.get_room_join_rule(event.room_id)
- if not join_rule:
+ key = (RoomJoinRulesEvent.TYPE, "", )
+ join_rule_event = event.old_state_events.get(key)
+ if join_rule_event:
+ join_rule = join_rule_event.content.get(
+ "join_rule", JoinRules.INVITE
+ )
+ else:
join_rule = JoinRules.INVITE
+ user_level = self._get_power_level_from_event_state(
+ event,
+ event.user_id,
+ )
+
+ ban_level, kick_level, redact_level = (
+ self._get_ops_level_from_event_state(
+ event
+ )
+ )
+
+ logger.debug(
+ "is_membership_change_allowed: %s",
+ {
+ "caller_in_room": caller_in_room,
+ "target_in_room": target_in_room,
+ "membership": membership,
+ "join_rule": join_rule,
+ "target_user_id": target_user_id,
+ "event.user_id": event.user_id,
+ }
+ )
+
if Membership.INVITE == membership:
# TODO (erikj): We should probably handle this more intelligently
# PRIVATE join rules.
@@ -153,13 +186,10 @@ class Auth(object):
# joined: It's a NOOP
if event.user_id != target_user_id:
raise AuthError(403, "Cannot force another user to join.")
- elif join_rule == JoinRules.PUBLIC or room.is_public:
+ elif join_rule == JoinRules.PUBLIC:
pass
elif join_rule == JoinRules.INVITE:
- if (
- not caller or caller.membership not in
- [Membership.INVITE, Membership.JOIN]
- ):
+ if not caller_in_room:
raise AuthError(403, "You are not invited to this room.")
else:
# TODO (erikj): may_join list
@@ -171,29 +201,16 @@ class Auth(object):
if not caller_in_room: # trying to leave a room you aren't joined
raise AuthError(403, "You are not in room %s." % event.room_id)
elif target_user_id != event.user_id:
- user_level = yield self.store.get_power_level(
- event.room_id,
- event.user_id,
- )
- _, kick_level, _ = yield self.store.get_ops_levels(event.room_id)
-
if kick_level:
kick_level = int(kick_level)
else:
- kick_level = 50
+ kick_level = 50 # FIXME (erikj): What should we do here?
if user_level < kick_level:
raise AuthError(
403, "You cannot kick user %s." % target_user_id
)
elif Membership.BAN == membership:
- user_level = yield self.store.get_power_level(
- event.room_id,
- event.user_id,
- )
-
- ban_level, _, _ = yield self.store.get_ops_levels(event.room_id)
-
if ban_level:
ban_level = int(ban_level)
else:
@@ -204,7 +221,30 @@ class Auth(object):
else:
raise AuthError(500, "Unknown membership %s" % membership)
- defer.returnValue(True)
+ return True
+
+ def _get_power_level_from_event_state(self, event, user_id):
+ key = (RoomPowerLevelsEvent.TYPE, "", )
+ power_level_event = event.old_state_events.get(key)
+ level = None
+ if power_level_event:
+ level = power_level_event.content.get("users", {}).get(user_id)
+            if level is None:
+                level = power_level_event.content.get("users_default", 0)
+
+ return level
+
+ def _get_ops_level_from_event_state(self, event):
+ key = (RoomPowerLevelsEvent.TYPE, "", )
+ power_level_event = event.old_state_events.get(key)
+
+ if power_level_event:
+ return (
+ power_level_event.content.get("ban", 50),
+ power_level_event.content.get("kick", 50),
+ power_level_event.content.get("redact", 50),
+ )
+ return None, None, None,
@defer.inlineCallbacks
def get_user_by_req(self, request):
@@ -273,68 +313,79 @@ class Auth(object):
return self.store.is_server_admin(user)
@defer.inlineCallbacks
- @log_function
- def _can_send_event(self, event):
- send_level = yield self.store.get_send_event_level(event.room_id)
+ def add_auth_events(self, event):
+ if event.type == RoomCreateEvent.TYPE:
+ event.auth_events = []
+ return
- if send_level:
- send_level = int(send_level)
- else:
- send_level = 0
+ auth_events = []
- user_level = yield self.store.get_power_level(
- event.room_id,
- event.user_id,
- )
+ key = (RoomPowerLevelsEvent.TYPE, "", )
+ power_level_event = event.old_state_events.get(key)
- if user_level:
- user_level = int(user_level)
- else:
- user_level = 0
+ if power_level_event:
+ auth_events.append(power_level_event.event_id)
- if user_level < send_level:
- raise AuthError(
- 403, "You don't have permission to post to the room"
- )
+ key = (RoomJoinRulesEvent.TYPE, "", )
+ join_rule_event = event.old_state_events.get(key)
- defer.returnValue(True)
+ key = (RoomMemberEvent.TYPE, event.user_id, )
+ member_event = event.old_state_events.get(key)
- @defer.inlineCallbacks
- def _can_add_state(self, event):
- add_level = yield self.store.get_add_state_level(event.room_id)
+        is_public = False
+        if join_rule_event:
+            join_rule = join_rule_event.content.get("join_rule")
+            is_public = join_rule == JoinRules.PUBLIC if join_rule else False
- if not add_level:
- defer.returnValue(True)
+ if event.type == RoomMemberEvent.TYPE:
+ if event.content["membership"] == Membership.JOIN:
+ if is_public:
+ auth_events.append(join_rule_event.event_id)
+ elif member_event:
+ auth_events.append(member_event.event_id)
- add_level = int(add_level)
+ if member_event:
+ if member_event.content["membership"] == Membership.JOIN:
+ auth_events.append(member_event.event_id)
- user_level = yield self.store.get_power_level(
- event.room_id,
- event.user_id,
+ hashes = yield self.store.get_event_reference_hashes(
+ auth_events
)
+ hashes = [
+ {
+ k: encode_base64(v) for k, v in h.items()
+ if k == "sha256"
+ }
+ for h in hashes
+ ]
+ event.auth_events = zip(auth_events, hashes)
- user_level = int(user_level)
- if user_level < add_level:
- raise AuthError(
- 403, "You don't have permission to add state to the room"
+ @log_function
+ def _can_send_event(self, event):
+ key = (RoomPowerLevelsEvent.TYPE, "", )
+ send_level_event = event.old_state_events.get(key)
+ send_level = None
+ if send_level_event:
+ send_level = send_level_event.content.get("events", {}).get(
+ event.type
)
+ if not send_level:
+ if hasattr(event, "state_key"):
+ send_level = send_level_event.content.get(
+ "state_default", 50
+ )
+ else:
+ send_level = send_level_event.content.get(
+ "events_default", 0
+ )
- defer.returnValue(True)
-
- @defer.inlineCallbacks
- def _can_replace_state(self, event):
- current_state = yield self.store.get_current_state(
- event.room_id,
- event.type,
- event.state_key,
- )
-
- if current_state:
- current_state = current_state[0]
+ if send_level:
+ send_level = int(send_level)
+ else:
+ send_level = 0
- user_level = yield self.store.get_power_level(
- event.room_id,
+ user_level = self._get_power_level_from_event_state(
+ event,
event.user_id,
)
@@ -343,35 +394,22 @@ class Auth(object):
else:
user_level = 0
- logger.debug(
- "Checking power level for %s, %s", event.user_id, user_level
- )
- if current_state and hasattr(current_state, "required_power_level"):
- req = current_state.required_power_level
+ if user_level < send_level:
+ raise AuthError(
+ 403, "You don't have permission to post that to the room"
+ )
- logger.debug("Checked power level for %s, %s", event.user_id, req)
- if user_level < req:
- raise AuthError(
- 403,
- "You don't have permission to change that state"
- )
+ return True
- @defer.inlineCallbacks
def _check_redaction(self, event):
- user_level = yield self.store.get_power_level(
- event.room_id,
+ user_level = self._get_power_level_from_event_state(
+ event,
event.user_id,
)
- if user_level:
- user_level = int(user_level)
- else:
- user_level = 0
-
- _, _, redact_level = yield self.store.get_ops_levels(event.room_id)
-
- if not redact_level:
- redact_level = 50
+        _, _, redact_level = self._get_ops_level_from_event_state(
+            event
+        )
+
+        if redact_level is None:
+            redact_level = 50
if user_level < redact_level:
raise AuthError(
@@ -379,16 +417,10 @@ class Auth(object):
"You don't have permission to redact events"
)
- @defer.inlineCallbacks
def _check_power_levels(self, event):
- for k, v in event.content.items():
- if k == "default":
- continue
-
- # FIXME (erikj): We don't want hsob_Ts in content.
- if k == "hsob_ts":
- continue
-
+ user_list = event.content.get("users", {})
+ # Validate users
+ for k, v in user_list.items():
try:
self.hs.parse_userid(k)
except:
@@ -399,80 +431,68 @@ class Auth(object):
except:
raise SynapseError(400, "Not a valid power level: %s" % (v,))
- current_state = yield self.store.get_current_state(
- event.room_id,
- event.type,
- event.state_key,
- )
+ key = (event.type, event.state_key, )
+ current_state = event.old_state_events.get(key)
if not current_state:
return
- else:
- current_state = current_state[0]
- user_level = yield self.store.get_power_level(
- event.room_id,
+ user_level = self._get_power_level_from_event_state(
+ event,
event.user_id,
)
- if user_level:
- user_level = int(user_level)
- else:
- user_level = 0
+ # Check other levels:
+ levels_to_check = [
+ ("users_default", []),
+ ("events_default", []),
+ ("ban", []),
+ ("redact", []),
+ ("kick", []),
+ ]
+
+ old_list = current_state.content.get("users")
+ for user in set(old_list.keys() + user_list.keys()):
+ levels_to_check.append(
+ (user, ["users"])
+ )
- old_list = current_state.content
+ old_list = current_state.content.get("events")
+ new_list = event.content.get("events")
+ for ev_id in set(old_list.keys() + new_list.keys()):
+ levels_to_check.append(
+ (ev_id, ["events"])
+ )
- # FIXME (erikj)
- old_people = {k: v for k, v in old_list.items() if k.startswith("@")}
- new_people = {
- k: v for k, v in event.content.items()
- if k.startswith("@")
- }
+ old_state = current_state.content
+ new_state = event.content
- removed = set(old_people.keys()) - set(new_people.keys())
- added = set(new_people.keys()) - set(old_people.keys())
- same = set(old_people.keys()) & set(new_people.keys())
+ for level_to_check, dir in levels_to_check:
+ old_loc = old_state
+ for d in dir:
+ old_loc = old_loc.get(d, {})
- for r in removed:
- if int(old_list[r]) > user_level:
- raise AuthError(
- 403,
- "You don't have permission to remove user: %s" % (r, )
- )
+ new_loc = new_state
+ for d in dir:
+ new_loc = new_loc.get(d, {})
- for n in added:
- if int(event.content[n]) > user_level:
- raise AuthError(
- 403,
- "You don't have permission to add ops level greater "
- "than your own"
- )
+ if level_to_check in old_loc:
+ old_level = int(old_loc[level_to_check])
+ else:
+ old_level = None
- for s in same:
- if int(event.content[s]) != int(old_list[s]):
- if int(event.content[s]) > user_level:
- raise AuthError(
- 403,
- "You don't have permission to add ops level greater "
- "than your own"
- )
+ if level_to_check in new_loc:
+ new_level = int(new_loc[level_to_check])
+ else:
+ new_level = None
- if "default" in old_list:
- old_default = int(old_list["default"])
+ if new_level is not None and old_level is not None:
+ if new_level == old_level:
+ continue
- if old_default > user_level:
+ if old_level > user_level or new_level > user_level:
raise AuthError(
403,
- "You don't have permission to add ops level greater than "
- "your own"
+ "You don't have permission to add ops level greater "
+ "than your own"
)
-
- if "default" in event.content:
- new_default = int(event.content["default"])
-
- if new_default > user_level:
- raise AuthError(
- 403,
- "You don't have permission to add ops level greater "
- "than your own"
- )
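
For context on the auth.py rewrite as a whole: every check now reads room
state off the event itself via `old_state_events`, a dict keyed by
`(event_type, state_key)` tuples, instead of issuing datastore queries. A
minimal sketch of the power-level lookup this enables (dict shape assumed
from the lookups above):

    def get_user_power_level(old_state_events, user_id):
        # old_state_events: {(event_type, state_key): event}
        power_event = old_state_events.get(("m.room.power_levels", ""))
        if not power_event:
            return None
        content = power_event.content
        level = content.get("users", {}).get(user_id)
        if level is None:
            level = content.get("users_default", 0)
        return level
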
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index 38ccb4f9d1..33d15072af 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -158,3 +158,37 @@ def cs_error(msg, code=Codes.UNKNOWN, **kwargs):
for key, value in kwargs.iteritems():
err[key] = value
return err
+
+
+class FederationError(RuntimeError):
+ """ This class is used to inform remote home servers about erroneous
+ PDUs they sent us.
+
+ FATAL: The remote server could not interpret the source event.
+ (e.g., it was missing a required field)
+ ERROR: The remote server interpreted the event, but it failed some other
+ check (e.g. auth)
+ WARN: The remote server accepted the event, but believes some part of it
+ is wrong (e.g., it referred to an invalid event)
+ """
+
+ def __init__(self, level, code, reason, affected, source=None):
+ if level not in ["FATAL", "ERROR", "WARN"]:
+ raise ValueError("Level is not valid: %s" % (level,))
+ self.level = level
+ self.code = code
+ self.reason = reason
+ self.affected = affected
+ self.source = source
+
+ msg = "%s %s: %s" % (level, code, reason,)
+ super(FederationError, self).__init__(msg)
+
+ def get_dict(self):
+ return {
+ "level": self.level,
+ "code": self.code,
+ "reason": self.reason,
+ "affected": self.affected,
+ "source": self.source if self.source else self.affected,
+ }
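
A hedged sketch of how FederationError is intended to be used (the
`pdu_failures` transaction field it feeds is added in replication.py below;
the event id is illustrative):

    err = FederationError(
        level="ERROR",
        code=403,
        reason="Event failed auth checks",
        affected="$78c:example.com",  # illustrative event id
    )
    # send_failure()/enqueue_failure() eventually drain this dict into the
    # transaction's `pdu_failures` list
    failure_dict = err.get_dict()
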
diff --git a/synapse/api/events/__init__.py b/synapse/api/events/__init__.py
index f66fea2904..e5980c4be3 100644
--- a/synapse/api/events/__init__.py
+++ b/synapse/api/events/__init__.py
@@ -56,22 +56,25 @@ class SynapseEvent(JsonEncodedObject):
"user_id", # sender/initiator
"content", # HTTP body, JSON
"state_key",
- "required_power_level",
"age_ts",
"prev_content",
- "prev_state",
+ "replaces_state",
"redacted_because",
+ "origin_server_ts",
]
internal_keys = [
"is_state",
- "prev_events",
"depth",
"destinations",
"origin",
"outlier",
- "power_level",
"redacted",
+ "prev_events",
+ "hashes",
+ "signatures",
+ "prev_state",
+ "auth_events",
]
required_keys = [
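
Putting the key renames together, a serialised event after this patch looks
roughly as follows (a sketch assembled from valid_keys/internal_keys and the
(event_id, hash) pairs built in auth.py's add_auth_events; all values are
illustrative):

    event = {
        "event_id": "$78c:example.com",
        "type": "m.room.member",
        "room_id": "!room:example.com",
        "user_id": "@alice:example.com",
        "origin": "example.com",
        "origin_server_ts": 1404835423000,
        "state_key": "@alice:example.com",
        "content": {"membership": "join"},
        "prev_events": [["$77b:example.com", {"sha256": "<base64 hash>"}]],
        "auth_events": [["$70a:example.com", {"sha256": "<base64 hash>"}]],
        "hashes": {"sha256": "<base64 content hash>"},
        "signatures": {"example.com": {"ed25519:1": "<base64 signature>"}},
    }
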
diff --git a/synapse/api/events/factory.py b/synapse/api/events/factory.py
index 74d0ef77f4..a1ec708a81 100644
--- a/synapse/api/events/factory.py
+++ b/synapse/api/events/factory.py
@@ -16,11 +16,13 @@
from synapse.api.events.room import (
RoomTopicEvent, MessageEvent, RoomMemberEvent, FeedbackEvent,
InviteJoinEvent, RoomConfigEvent, RoomNameEvent, GenericEvent,
- RoomPowerLevelsEvent, RoomJoinRulesEvent, RoomOpsPowerLevelsEvent,
- RoomCreateEvent, RoomAddStateLevelEvent, RoomSendEventLevelEvent,
+ RoomPowerLevelsEvent, RoomJoinRulesEvent,
+ RoomCreateEvent,
RoomRedactionEvent,
)
+from synapse.types import EventID
+
from synapse.util.stringutils import random_string
@@ -37,9 +39,6 @@ class EventFactory(object):
RoomPowerLevelsEvent,
RoomJoinRulesEvent,
RoomCreateEvent,
- RoomAddStateLevelEvent,
- RoomSendEventLevelEvent,
- RoomOpsPowerLevelsEvent,
RoomRedactionEvent,
]
@@ -51,12 +50,26 @@ class EventFactory(object):
self.clock = hs.get_clock()
self.hs = hs
+ self.event_id_count = 0
+
+ def create_event_id(self):
+ i = str(self.event_id_count)
+ self.event_id_count += 1
+
+ local_part = str(int(self.clock.time())) + i + random_string(5)
+
+ e_id = EventID.create_local(local_part, self.hs)
+
+ return e_id.to_string()
+
def create_event(self, etype=None, **kwargs):
kwargs["type"] = etype
if "event_id" not in kwargs:
- kwargs["event_id"] = "%s@%s" % (
- random_string(10), self.hs.hostname
- )
+ kwargs["event_id"] = self.create_event_id()
+ kwargs["origin"] = self.hs.hostname
+ else:
+ ev_id = self.hs.parse_eventid(kwargs["event_id"])
+ kwargs["origin"] = ev_id.domain
if "origin_server_ts" not in kwargs:
kwargs["origin_server_ts"] = int(self.clock.time_msec())
diff --git a/synapse/api/events/room.py b/synapse/api/events/room.py
index cd936074fc..25bc883706 100644
--- a/synapse/api/events/room.py
+++ b/synapse/api/events/room.py
@@ -153,28 +153,6 @@ class RoomPowerLevelsEvent(SynapseStateEvent):
def get_content_template(self):
return {}
-
-class RoomAddStateLevelEvent(SynapseStateEvent):
- TYPE = "m.room.add_state_level"
-
- def get_content_template(self):
- return {}
-
-
-class RoomSendEventLevelEvent(SynapseStateEvent):
- TYPE = "m.room.send_event_level"
-
- def get_content_template(self):
- return {}
-
-
-class RoomOpsPowerLevelsEvent(SynapseStateEvent):
- TYPE = "m.room.ops_levels"
-
- def get_content_template(self):
- return {}
-
-
class RoomAliasesEvent(SynapseStateEvent):
TYPE = "m.room.aliases"
diff --git a/synapse/api/events/utils.py b/synapse/api/events/utils.py
index c3a32be8c1..5fc79105b5 100644
--- a/synapse/api/events/utils.py
+++ b/synapse/api/events/utils.py
@@ -15,7 +15,6 @@
from .room import (
RoomMemberEvent, RoomJoinRulesEvent, RoomPowerLevelsEvent,
- RoomAddStateLevelEvent, RoomSendEventLevelEvent, RoomOpsPowerLevelsEvent,
RoomAliasesEvent, RoomCreateEvent,
)
@@ -27,7 +26,14 @@ def prune_event(event):
the user has specified, but we do want to keep necessary information like
type, state_key etc.
"""
+ return _prune_event_or_pdu(event.type, event)
+def prune_pdu(pdu):
+ """Removes keys that contain unrestricted and non-essential data from a PDU
+ """
+ return _prune_event_or_pdu(pdu.type, pdu)
+
+def _prune_event_or_pdu(event_type, event):
# Remove all extraneous fields.
event.unrecognized_keys = {}
@@ -38,25 +44,25 @@ def prune_event(event):
if field in event.content:
new_content[field] = event.content[field]
- if event.type == RoomMemberEvent.TYPE:
+ if event_type == RoomMemberEvent.TYPE:
add_fields("membership")
- elif event.type == RoomCreateEvent.TYPE:
+ elif event_type == RoomCreateEvent.TYPE:
add_fields("creator")
- elif event.type == RoomJoinRulesEvent.TYPE:
+ elif event_type == RoomJoinRulesEvent.TYPE:
add_fields("join_rule")
- elif event.type == RoomPowerLevelsEvent.TYPE:
- # TODO: Actually check these are valid user_ids etc.
- add_fields("default")
- for k, v in event.content.items():
- if k.startswith("@") and isinstance(v, (int, long)):
- new_content[k] = v
- elif event.type == RoomAddStateLevelEvent.TYPE:
- add_fields("level")
- elif event.type == RoomSendEventLevelEvent.TYPE:
- add_fields("level")
- elif event.type == RoomOpsPowerLevelsEvent.TYPE:
- add_fields("kick_level", "ban_level", "redact_level")
- elif event.type == RoomAliasesEvent.TYPE:
+ elif event_type == RoomPowerLevelsEvent.TYPE:
+ add_fields(
+ "users",
+ "users_default",
+ "events",
+ "events_default",
+ "events_default",
+ "state_default",
+ "ban",
+ "kick",
+ "redact",
+ )
+ elif event_type == RoomAliasesEvent.TYPE:
add_fields("aliases")
event.content = new_content
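
For illustration, redacting an m.room.power_levels event now keeps only the
structured keys whitelisted above (a sketch of the effect, not the real
prune_event signature):

    ALLOWED = ("users", "users_default", "events", "events_default",
               "state_default", "ban", "kick", "redact")

    content = {
        "users": {"@alice:example.com": 100},
        "ban": 50,
        "custom_key": "dropped on redaction",  # not whitelisted
    }
    pruned = {k: v for k, v in content.items() if k in ALLOWED}
    # pruned == {"users": {"@alice:example.com": 100}, "ban": 50}
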
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index b3dae5da64..43164c8d67 100755
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -236,7 +236,10 @@ def setup():
f.namespace['hs'] = hs
reactor.listenTCP(config.manhole, f, interface='127.0.0.1')
- hs.start_listening(config.bind_port, config.unsecure_port)
+ bind_port = config.bind_port
+ if config.no_tls:
+ bind_port = None
+ hs.start_listening(bind_port, config.unsecure_port)
if config.daemonize:
print config.pid_file
diff --git a/synapse/config/server.py b/synapse/config/server.py
index 3afda12d5a..814a4c349b 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -30,6 +30,7 @@ class ServerConfig(Config):
self.pid_file = self.abspath(args.pid_file)
self.webclient = True
self.manhole = args.manhole
+ self.no_tls = args.no_tls
if not args.content_addr:
host = args.server_name
@@ -67,6 +68,8 @@ class ServerConfig(Config):
server_group.add_argument("--content-addr", default=None,
help="The host and scheme to use for the "
"content repository")
+ server_group.add_argument("--no-tls", action='store_true',
+ help="Don't bind to the https port.")
def read_signing_key(self, signing_key_path):
signing_keys = self.read_file(signing_key_path, "signing_key")
diff --git a/synapse/crypto/event_signing.py b/synapse/crypto/event_signing.py
new file mode 100644
index 0000000000..de5d2e7465
--- /dev/null
+++ b/synapse/crypto/event_signing.py
@@ -0,0 +1,94 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2014 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from synapse.api.events.utils import prune_event
+from syutil.jsonutil import encode_canonical_json
+from syutil.base64util import encode_base64, decode_base64
+from syutil.crypto.jsonsign import sign_json
+from synapse.api.events.room import GenericEvent
+
+import copy
+import hashlib
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+def check_event_content_hash(event, hash_algorithm=hashlib.sha256):
+ """Check whether the hash for this PDU matches the contents"""
+ computed_hash = _compute_content_hash(event, hash_algorithm)
+ if computed_hash.name not in event.hashes:
+ raise Exception("Algorithm %s not in hashes %s" % (
+ computed_hash.name, list(event.hashes)
+ ))
+ message_hash_base64 = event.hashes[computed_hash.name]
+ try:
+ message_hash_bytes = decode_base64(message_hash_base64)
+ except:
+ raise Exception("Invalid base64: %s" % (message_hash_base64,))
+ return message_hash_bytes == computed_hash.digest()
+
+
+def _compute_content_hash(event, hash_algorithm):
+ event_json = event.get_full_dict()
+    # TODO: We need to sign the JSON that is going out via federation.
+ event_json.pop("age_ts", None)
+ event_json.pop("unsigned", None)
+ event_json.pop("signatures", None)
+ event_json.pop("hashes", None)
+ event_json_bytes = encode_canonical_json(event_json)
+ return hash_algorithm(event_json_bytes)
+
+
+def compute_event_reference_hash(event, hash_algorithm=hashlib.sha256):
+ # FIXME(erikj): GenericEvent!
+ tmp_event = GenericEvent(**event.get_full_dict())
+ tmp_event = prune_event(tmp_event)
+ event_json = tmp_event.get_dict()
+ event_json.pop("signatures", None)
+ event_json.pop("age_ts", None)
+ event_json.pop("unsigned", None)
+ event_json_bytes = encode_canonical_json(event_json)
+ hashed = hash_algorithm(event_json_bytes)
+ return (hashed.name, hashed.digest())
+
+
+def compute_event_signature(event, signature_name, signing_key):
+ tmp_event = copy.deepcopy(event)
+ tmp_event = prune_event(tmp_event)
+ redact_json = tmp_event.get_full_dict()
+ redact_json.pop("signatures", None)
+ redact_json.pop("age_ts", None)
+ redact_json.pop("unsigned", None)
+ logger.debug("Signing event: %s", redact_json)
+ redact_json = sign_json(redact_json, signature_name, signing_key)
+ return redact_json["signatures"]
+
+
+def add_hashes_and_signatures(event, signature_name, signing_key,
+ hash_algorithm=hashlib.sha256):
+ hashed = _compute_content_hash(event, hash_algorithm=hash_algorithm)
+
+ if not hasattr(event, "hashes"):
+ event.hashes = {}
+ event.hashes[hashed.name] = encode_base64(hashed.digest())
+
+ event.signatures = compute_event_signature(
+ event,
+ signature_name=signature_name,
+ signing_key=signing_key,
+ )
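
A hedged usage sketch of the helpers above (`event` and `signing_key` are
assumed to come from the event factory and homeserver config): the content
hash covers the full event minus transient keys, while the signature covers
the pruned, redaction-safe form.

    # hash and sign before sending over federation
    add_hashes_and_signatures(event, "example.com", signing_key)

    # reference hash, e.g. for the (event_id, hash) pairs in prev_events
    name, digest = compute_event_reference_hash(event)  # ("sha256", <bytes>)

    # on receipt, check the content hash still matches
    assert check_event_content_hash(event)
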
diff --git a/synapse/federation/pdu_codec.py b/synapse/federation/pdu_codec.py
index e8180d94fd..52c84efb5b 100644
--- a/synapse/federation/pdu_codec.py
+++ b/synapse/federation/pdu_codec.py
@@ -18,50 +18,25 @@ from .units import Pdu
import copy
-def decode_event_id(event_id, server_name):
- parts = event_id.split("@")
- if len(parts) < 2:
- return (event_id, server_name)
- else:
- return (parts[0], "".join(parts[1:]))
-
-
-def encode_event_id(pdu_id, origin):
- return "%s@%s" % (pdu_id, origin)
-
-
class PduCodec(object):
def __init__(self, hs):
+ self.signing_key = hs.config.signing_key[0]
self.server_name = hs.hostname
self.event_factory = hs.get_event_factory()
self.clock = hs.get_clock()
+ self.hs = hs
def event_from_pdu(self, pdu):
kwargs = {}
- kwargs["event_id"] = encode_event_id(pdu.pdu_id, pdu.origin)
- kwargs["room_id"] = pdu.context
- kwargs["etype"] = pdu.pdu_type
- kwargs["prev_events"] = [
- encode_event_id(p[0], p[1]) for p in pdu.prev_pdus
- ]
-
- if hasattr(pdu, "prev_state_id") and hasattr(pdu, "prev_state_origin"):
- kwargs["prev_state"] = encode_event_id(
- pdu.prev_state_id, pdu.prev_state_origin
- )
+ kwargs["etype"] = pdu.type
kwargs.update({
k: v
for k, v in pdu.get_full_dict().items()
if k not in [
- "pdu_id",
- "context",
- "pdu_type",
- "prev_pdus",
- "prev_state_id",
- "prev_state_origin",
+ "type",
]
})
@@ -70,33 +45,10 @@ class PduCodec(object):
def pdu_from_event(self, event):
d = event.get_full_dict()
- d["pdu_id"], d["origin"] = decode_event_id(
- event.event_id, self.server_name
- )
- d["context"] = event.room_id
- d["pdu_type"] = event.type
-
- if hasattr(event, "prev_events"):
- d["prev_pdus"] = [
- decode_event_id(e, self.server_name)
- for e in event.prev_events
- ]
-
- if hasattr(event, "prev_state"):
- d["prev_state_id"], d["prev_state_origin"] = (
- decode_event_id(event.prev_state, self.server_name)
- )
-
- if hasattr(event, "state_key"):
- d["is_state"] = True
-
kwargs = copy.deepcopy(event.unrecognized_keys)
kwargs.update({
k: v for k, v in d.items()
- if k not in ["event_id", "room_id", "type", "prev_events"]
})
- if "origin_server_ts" not in kwargs:
- kwargs["origin_server_ts"] = int(self.clock.time_msec())
-
- return Pdu(**kwargs)
+ pdu = Pdu(**kwargs)
+ return pdu
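
With the shared field names, the codec is now nearly a pass-through; only
"type" (PDU) versus "etype" (event factory kwarg) differs. A minimal sketch
of the remaining conversion:

    def event_kwargs_from_pdu_dict(pdu_dict):
        # everything carries over verbatim except the type key
        kwargs = dict(pdu_dict)
        kwargs["etype"] = kwargs.pop("type")
        return kwargs
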
diff --git a/synapse/federation/persistence.py b/synapse/federation/persistence.py
index 7043fcc504..73dc844d59 100644
--- a/synapse/federation/persistence.py
+++ b/synapse/federation/persistence.py
@@ -21,8 +21,6 @@ These actions are mostly only used by the :py:mod:`.replication` module.
from twisted.internet import defer
-from .units import Pdu
-
from synapse.util.logutils import log_function
import json
@@ -32,76 +30,6 @@ import logging
logger = logging.getLogger(__name__)
-class PduActions(object):
- """ Defines persistence actions that relate to handling PDUs.
- """
-
- def __init__(self, datastore):
- self.store = datastore
-
- @log_function
- def mark_as_processed(self, pdu):
- """ Persist the fact that we have fully processed the given `Pdu`
-
- Returns:
- Deferred
- """
- return self.store.mark_pdu_as_processed(pdu.pdu_id, pdu.origin)
-
- @defer.inlineCallbacks
- @log_function
- def after_transaction(self, transaction_id, destination, origin):
- """ Returns all `Pdu`s that we sent to the given remote home server
- after a given transaction id.
-
- Returns:
- Deferred: Results in a list of `Pdu`s
- """
- results = yield self.store.get_pdus_after_transaction(
- transaction_id,
- destination
- )
-
- defer.returnValue([Pdu.from_pdu_tuple(p) for p in results])
-
- @defer.inlineCallbacks
- @log_function
- def get_all_pdus_from_context(self, context):
- results = yield self.store.get_all_pdus_from_context(context)
- defer.returnValue([Pdu.from_pdu_tuple(p) for p in results])
-
- @defer.inlineCallbacks
- @log_function
- def backfill(self, context, pdu_list, limit):
- """ For a given list of PDU id and origins return the proceeding
- `limit` `Pdu`s in the given `context`.
-
- Returns:
- Deferred: Results in a list of `Pdu`s.
- """
- results = yield self.store.get_backfill(
- context, pdu_list, limit
- )
-
- defer.returnValue([Pdu.from_pdu_tuple(p) for p in results])
-
- @log_function
- def is_new(self, pdu):
- """ When we receive a `Pdu` from a remote home server, we want to
- figure out whether it is `new`, i.e. it is not some historic PDU that
- we haven't seen simply because we haven't backfilled back that far.
-
- Returns:
- Deferred: Results in a `bool`
- """
- return self.store.is_pdu_new(
- pdu_id=pdu.pdu_id,
- origin=pdu.origin,
- context=pdu.context,
- depth=pdu.depth
- )
-
-
class TransactionActions(object):
""" Defines persistence actions that relate to handling Transactions.
"""
@@ -158,7 +86,6 @@ class TransactionActions(object):
transaction.transaction_id,
transaction.destination,
transaction.origin_server_ts,
- [(p["pdu_id"], p["origin"]) for p in transaction.pdus]
)
@log_function
diff --git a/synapse/federation/replication.py b/synapse/federation/replication.py
index 092411eaf9..92a9678e2c 100644
--- a/synapse/federation/replication.py
+++ b/synapse/federation/replication.py
@@ -21,7 +21,7 @@ from twisted.internet import defer
from .units import Transaction, Pdu, Edu
-from .persistence import PduActions, TransactionActions
+from .persistence import TransactionActions
from synapse.util.logutils import log_function
@@ -57,7 +57,7 @@ class ReplicationLayer(object):
self.transport_layer.register_request_handler(self)
self.store = hs.get_datastore()
- self.pdu_actions = PduActions(self.store)
+ # self.pdu_actions = PduActions(self.store)
self.transaction_actions = TransactionActions(self.store)
self._transaction_queue = _TransactionQueue(
@@ -81,7 +81,7 @@ class ReplicationLayer(object):
def register_edu_handler(self, edu_type, handler):
if edu_type in self.edu_handlers:
- raise KeyError("Already have an EDU handler for %s" % (edu_type))
+ raise KeyError("Already have an EDU handler for %s" % (edu_type,))
self.edu_handlers[edu_type] = handler
@@ -102,24 +102,17 @@ class ReplicationLayer(object):
object to encode as JSON.
"""
if query_type in self.query_handlers:
- raise KeyError("Already have a Query handler for %s" % (query_type))
+ raise KeyError(
+ "Already have a Query handler for %s" % (query_type,)
+ )
self.query_handlers[query_type] = handler
- @defer.inlineCallbacks
@log_function
def send_pdu(self, pdu):
"""Informs the replication layer about a new PDU generated within the
home server that should be transmitted to others.
- This will fill out various attributes on the PDU object, e.g. the
- `prev_pdus` key.
-
- *Note:* The home server should always call `send_pdu` even if it knows
- that it does not need to be replicated to other home servers. This is
- in case e.g. someone else joins via a remote home server and then
- backfills.
-
TODO: Figure out when we should actually resolve the deferred.
Args:
@@ -132,18 +125,15 @@ class ReplicationLayer(object):
order = self._order
self._order += 1
- logger.debug("[%s] Persisting PDU", pdu.pdu_id)
-
- # Save *before* trying to send
- yield self.store.persist_event(pdu=pdu)
-
- logger.debug("[%s] Persisted PDU", pdu.pdu_id)
- logger.debug("[%s] transaction_layer.enqueue_pdu... ", pdu.pdu_id)
+ logger.debug("[%s] transaction_layer.enqueue_pdu... ", pdu.event_id)
# TODO, add errback, etc.
self._transaction_queue.enqueue_pdu(pdu, order)
- logger.debug("[%s] transaction_layer.enqueue_pdu... done", pdu.pdu_id)
+ logger.debug(
+ "[%s] transaction_layer.enqueue_pdu... done",
+ pdu.event_id
+ )
@log_function
def send_edu(self, destination, edu_type, content):
@@ -159,6 +149,11 @@ class ReplicationLayer(object):
return defer.succeed(None)
@log_function
+ def send_failure(self, failure, destination):
+ self._transaction_queue.enqueue_failure(failure, destination)
+ return defer.succeed(None)
+
+ @log_function
def make_query(self, destination, query_type, args,
retry_on_dns_fail=True):
"""Sends a federation Query to a remote homeserver of the given type
@@ -181,7 +176,7 @@ class ReplicationLayer(object):
@defer.inlineCallbacks
@log_function
- def backfill(self, dest, context, limit):
+ def backfill(self, dest, context, limit, extremities):
"""Requests some more historic PDUs for the given context from the
given destination server.
@@ -189,12 +184,12 @@ class ReplicationLayer(object):
dest (str): The remote home server to ask.
context (str): The context to backfill.
limit (int): The maximum number of PDUs to return.
+            extremities (list): List of PDU ids and origins of the oldest
+                PDUs we have seen in the context.
Returns:
Deferred: Results in the received PDUs.
"""
- extremities = yield self.store.get_oldest_pdus_in_context(context)
-
logger.debug("backfill extrem=%s", extremities)
        # If there are no extremities then we've (probably) reached the start.
@@ -216,7 +211,7 @@ class ReplicationLayer(object):
@defer.inlineCallbacks
@log_function
- def get_pdu(self, destination, pdu_origin, pdu_id, outlier=False):
+ def get_pdu(self, destination, event_id, outlier=False):
"""Requests the PDU with given origin and ID from the remote home
server.
@@ -225,7 +220,7 @@ class ReplicationLayer(object):
Args:
destination (str): Which home server to query
-            pdu_origin (str): The home server that originally sent the pdu.
- pdu_id (str)
+ event_id (str)
outlier (bool): Indicates whether the PDU is an `outlier`, i.e. if
it's from an arbitary point in the context as opposed to part
of the current block of PDUs. Defaults to `False`
@@ -234,8 +229,9 @@ class ReplicationLayer(object):
Deferred: Results in the requested PDU.
"""
- transaction_data = yield self.transport_layer.get_pdu(
- destination, pdu_origin, pdu_id)
+ transaction_data = yield self.transport_layer.get_event(
+ destination, event_id
+ )
transaction = Transaction(**transaction_data)
@@ -244,13 +240,13 @@ class ReplicationLayer(object):
pdu = None
if pdu_list:
pdu = pdu_list[0]
- yield self._handle_new_pdu(pdu)
+ yield self._handle_new_pdu(destination, pdu)
defer.returnValue(pdu)
@defer.inlineCallbacks
@log_function
- def get_state_for_context(self, destination, context):
+ def get_state_for_context(self, destination, context, event_id=None):
"""Requests all of the `current` state PDUs for a given context from
a remote home server.
@@ -263,29 +259,32 @@ class ReplicationLayer(object):
"""
transaction_data = yield self.transport_layer.get_context_state(
- destination, context)
+ destination,
+ context,
+ event_id=event_id,
+ )
transaction = Transaction(**transaction_data)
pdus = [Pdu(outlier=True, **p) for p in transaction.pdus]
for pdu in pdus:
- yield self._handle_new_pdu(pdu)
+ yield self._handle_new_pdu(destination, pdu)
defer.returnValue(pdus)
@defer.inlineCallbacks
@log_function
def on_context_pdus_request(self, context):
- pdus = yield self.pdu_actions.get_all_pdus_from_context(
- context
+ raise NotImplementedError(
+ "on_context_pdus_request is a security violation"
)
- defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict()))
@defer.inlineCallbacks
@log_function
def on_backfill_request(self, context, versions, limit):
-
- pdus = yield self.pdu_actions.backfill(context, versions, limit)
+ pdus = yield self.handler.on_backfill_request(
+ context, versions, limit
+ )
defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict()))
@@ -295,6 +294,10 @@ class ReplicationLayer(object):
transaction = Transaction(**transaction_data)
for p in transaction.pdus:
+ if "unsigned" in p:
+ unsigned = p["unsigned"]
+ if "age" in unsigned:
+ p["age"] = unsigned["age"]
if "age" in p:
p["age_ts"] = int(self._clock.time_msec()) - int(p["age"])
del p["age"]
@@ -315,11 +318,15 @@ class ReplicationLayer(object):
dl = []
for pdu in pdu_list:
- dl.append(self._handle_new_pdu(pdu))
+ dl.append(self._handle_new_pdu(transaction.origin, pdu))
if hasattr(transaction, "edus"):
for edu in [Edu(**x) for x in transaction.edus]:
- self.received_edu(transaction.origin, edu.edu_type, edu.content)
+ self.received_edu(
+ transaction.origin,
+ edu.edu_type,
+ edu.content
+ )
results = yield defer.DeferredList(dl)
@@ -347,20 +354,26 @@ class ReplicationLayer(object):
@defer.inlineCallbacks
@log_function
- def on_context_state_request(self, context):
- results = yield self.store.get_current_state_for_context(
- context
- )
-
- logger.debug("Context returning %d results", len(results))
+ def on_context_state_request(self, context, event_id):
+ if event_id:
+ pdus = yield self.handler.get_state_for_pdu(
+ event_id
+ )
+ else:
+ raise NotImplementedError("Specify an event")
+ # results = yield self.store.get_current_state_for_context(
+ # context
+ # )
+ # pdus = [Pdu.from_pdu_tuple(p) for p in results]
+ #
+ # logger.debug("Context returning %d results", len(pdus))
- pdus = [Pdu.from_pdu_tuple(p) for p in results]
defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict()))
@defer.inlineCallbacks
@log_function
- def on_pdu_request(self, pdu_origin, pdu_id):
- pdu = yield self._get_persisted_pdu(pdu_id, pdu_origin)
+ def on_pdu_request(self, event_id):
+ pdu = yield self._get_persisted_pdu(event_id)
if pdu:
defer.returnValue(
@@ -372,20 +385,22 @@ class ReplicationLayer(object):
@defer.inlineCallbacks
@log_function
def on_pull_request(self, origin, versions):
- transaction_id = max([int(v) for v in versions])
-
- response = yield self.pdu_actions.after_transaction(
- transaction_id,
- origin,
- self.server_name
- )
-
- if not response:
- response = []
-
- defer.returnValue(
- (200, self._transaction_from_pdus(response).get_dict())
- )
+ raise NotImplementedError("Pull transacions not implemented")
+
+ # transaction_id = max([int(v) for v in versions])
+ #
+ # response = yield self.pdu_actions.after_transaction(
+ # transaction_id,
+ # origin,
+ # self.server_name
+ # )
+ #
+ # if not response:
+ # response = []
+ #
+ # defer.returnValue(
+ # (200, self._transaction_from_pdus(response).get_dict())
+ # )
@defer.inlineCallbacks
def on_query_request(self, query_type, args):
@@ -393,82 +408,138 @@ class ReplicationLayer(object):
response = yield self.query_handlers[query_type](args)
defer.returnValue((200, response))
else:
- defer.returnValue((404, "No handler for Query type '%s'"
- % (query_type)
- ))
+ defer.returnValue(
+ (404, "No handler for Query type '%s'" % (query_type, ))
+ )
+
+ @defer.inlineCallbacks
+ def on_make_join_request(self, context, user_id):
+ pdu = yield self.handler.on_make_join_request(context, user_id)
+ defer.returnValue(pdu.get_dict())
+
+ @defer.inlineCallbacks
+ def on_invite_request(self, origin, content):
+ pdu = Pdu(**content)
+ ret_pdu = yield self.handler.on_send_join_request(origin, pdu)
+ defer.returnValue((200, ret_pdu.get_dict()))
@defer.inlineCallbacks
+ def on_send_join_request(self, origin, content):
+ pdu = Pdu(**content)
+ state = yield self.handler.on_send_join_request(origin, pdu)
+ defer.returnValue((200, self._transaction_from_pdus(state).get_dict()))
+
+ @defer.inlineCallbacks
+ def make_join(self, destination, context, user_id):
+ pdu_dict = yield self.transport_layer.make_join(
+ destination=destination,
+ context=context,
+ user_id=user_id,
+ )
+
+ logger.debug("Got response to make_join: %s", pdu_dict)
+
+ defer.returnValue(Pdu(**pdu_dict))
+
+ @defer.inlineCallbacks
+ def send_join(self, destination, pdu):
+ _, content = yield self.transport_layer.send_join(
+ destination,
+ pdu.room_id,
+ pdu.event_id,
+ pdu.get_dict(),
+ )
+
+ logger.debug("Got content: %s", content)
+ pdus = [Pdu(outlier=True, **p) for p in content.get("pdus", [])]
+ for pdu in pdus:
+ yield self._handle_new_pdu(destination, pdu)
+
+ defer.returnValue(pdus)
+
@log_function
- def _get_persisted_pdu(self, pdu_id, pdu_origin):
+ def _get_persisted_pdu(self, event_id):
""" Get a PDU from the database with given origin and id.
Returns:
Deferred: Results in a `Pdu`.
"""
- pdu_tuple = yield self.store.get_pdu(pdu_id, pdu_origin)
-
- defer.returnValue(Pdu.from_pdu_tuple(pdu_tuple))
+ return self.handler.get_persisted_pdu(event_id)
def _transaction_from_pdus(self, pdu_list):
"""Returns a new Transaction containing the given PDUs suitable for
transmission.
"""
pdus = [p.get_dict() for p in pdu_list]
+ time_now = self._clock.time_msec()
for p in pdus:
- if "age_ts" in pdus:
- p["age"] = int(self.clock.time_msec()) - p["age_ts"]
-
+ if "age_ts" in p:
+ age = time_now - p["age_ts"]
+ p.setdefault("unsigned", {})["age"] = int(age)
+ del p["age_ts"]
return Transaction(
origin=self.server_name,
pdus=pdus,
- origin_server_ts=int(self._clock.time_msec()),
+ origin_server_ts=int(time_now),
destination=None,
)
@defer.inlineCallbacks
@log_function
- def _handle_new_pdu(self, pdu, backfilled=False):
+ def _handle_new_pdu(self, origin, pdu, backfilled=False):
# We reprocess pdus when we have seen them only as outliers
- existing = yield self._get_persisted_pdu(pdu.pdu_id, pdu.origin)
+ existing = yield self._get_persisted_pdu(pdu.event_id)
if existing and (not existing.outlier or pdu.outlier):
- logger.debug("Already seen pdu %s %s", pdu.pdu_id, pdu.origin)
+ logger.debug("Already seen pdu %s", pdu.event_id)
defer.returnValue({})
return
+ state = None
+
# Get missing pdus if necessary.
- is_new = yield self.pdu_actions.is_new(pdu)
- if is_new and not pdu.outlier:
+ if not pdu.outlier:
# We only backfill backwards to the min depth.
- min_depth = yield self.store.get_min_depth_for_context(pdu.context)
+ min_depth = yield self.handler.get_min_depth_for_context(
+ pdu.room_id
+ )
if min_depth and pdu.depth > min_depth:
- for pdu_id, origin in pdu.prev_pdus:
- exists = yield self._get_persisted_pdu(pdu_id, origin)
+ for event_id, hashes in pdu.prev_events:
+ exists = yield self._get_persisted_pdu(event_id)
if not exists:
- logger.debug("Requesting pdu %s %s", pdu_id, origin)
+ logger.debug("Requesting pdu %s", event_id)
try:
yield self.get_pdu(
pdu.origin,
- pdu_id=pdu_id,
- pdu_origin=origin
+ event_id=event_id,
)
- logger.debug("Processed pdu %s %s", pdu_id, origin)
+ logger.debug("Processed pdu %s", event_id)
except:
# TODO(erikj): Do some more intelligent retries.
logger.exception("Failed to get PDU")
+ else:
+ # We need to get the state at this event, since we have reached
+ # a backward extremity edge.
+ state = yield self.get_state_for_context(
+ origin, pdu.room_id, pdu.event_id,
+ )
# Persist the Pdu, but don't mark it as processed yet.
- yield self.store.persist_event(pdu=pdu)
+ # yield self.store.persist_event(pdu=pdu)
if not backfilled:
- ret = yield self.handler.on_receive_pdu(pdu, backfilled=backfilled)
+ ret = yield self.handler.on_receive_pdu(
+ pdu,
+ backfilled=backfilled,
+ state=state,
+ )
else:
ret = None
- yield self.pdu_actions.mark_as_processed(pdu)
+ # yield self.pdu_actions.mark_as_processed(pdu)
defer.returnValue(ret)
@@ -476,14 +547,6 @@ class ReplicationLayer(object):
return "<ReplicationLayer(%s)>" % self.server_name
-class ReplicationHandler(object):
- """This defines the methods that the :py:class:`.ReplicationLayer` will
- use to communicate with the rest of the home server.
- """
- def on_receive_pdu(self, pdu):
- raise NotImplementedError("on_receive_pdu")
-
-
class _TransactionQueue(object):
"""This class makes sure we only have one transaction in flight at
a time for a given destination.
@@ -509,6 +572,9 @@ class _TransactionQueue(object):
# destination -> list of tuple(edu, deferred)
self.pending_edus_by_dest = {}
+ # destination -> list of tuple(failure, deferred)
+ self.pending_failures_by_dest = {}
+
# HACK to get unique tx id
self._next_txn_id = int(self._clock.time_msec())
@@ -562,6 +628,18 @@ class _TransactionQueue(object):
return deferred
@defer.inlineCallbacks
+ def enqueue_failure(self, failure, destination):
+ deferred = defer.Deferred()
+
+ self.pending_failures_by_dest.setdefault(
+ destination, []
+ ).append(
+ (failure, deferred)
+ )
+
+ yield deferred
+
+ @defer.inlineCallbacks
@log_function
def _attempt_new_transaction(self, destination):
if destination in self.pending_transactions:
@@ -570,8 +648,9 @@ class _TransactionQueue(object):
# list of (pending_pdu, deferred, order)
pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
pending_edus = self.pending_edus_by_dest.pop(destination, [])
+ pending_failures = self.pending_failures_by_dest.pop(destination, [])
- if not pending_pdus and not pending_edus:
+ if not pending_pdus and not pending_edus and not pending_failures:
return
logger.debug("TX [%s] Attempting new transaction", destination)
@@ -581,7 +660,11 @@ class _TransactionQueue(object):
pdus = [x[0] for x in pending_pdus]
edus = [x[0] for x in pending_edus]
- deferreds = [x[1] for x in pending_pdus + pending_edus]
+ failures = [x[0].get_dict() for x in pending_failures]
+ deferreds = [
+ x[1]
+ for x in pending_pdus + pending_edus + pending_failures
+ ]
try:
self.pending_transactions[destination] = 1
@@ -589,12 +672,13 @@ class _TransactionQueue(object):
logger.debug("TX [%s] Persisting transaction...", destination)
transaction = Transaction.create_new(
- origin_server_ts=self._clock.time_msec(),
+ origin_server_ts=int(self._clock.time_msec()),
transaction_id=str(self._next_txn_id),
origin=self.server_name,
destination=destination,
pdus=pdus,
edus=edus,
+ pdu_failures=failures,
)
self._next_txn_id += 1
@@ -614,7 +698,9 @@ class _TransactionQueue(object):
if "pdus" in data:
for p in data["pdus"]:
if "age_ts" in p:
- p["age"] = now - int(p["age_ts"])
+ unsigned = p.setdefault("unsigned", {})
+ unsigned["age"] = now - int(p["age_ts"])
+ del p["age_ts"]
return data
code, response = yield self.transport_layer.send_transaction(
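
The age bookkeeping is now symmetric: outgoing PDUs turn the locally stored
absolute "age_ts" into a relative "unsigned.age", and on_incoming_transaction
reverses it. A sketch of both directions (plain dicts, milliseconds):

    def serialise_age(pdu_dict, now_ms):
        # outgoing: absolute age_ts -> relative unsigned.age
        if "age_ts" in pdu_dict:
            age = now_ms - int(pdu_dict.pop("age_ts"))
            pdu_dict.setdefault("unsigned", {})["age"] = age
        return pdu_dict

    def deserialise_age(pdu_dict, now_ms):
        # incoming: relative unsigned.age -> absolute age_ts
        age = pdu_dict.get("unsigned", {}).get("age")
        if age is not None:
            pdu_dict["age_ts"] = now_ms - int(age)
        return pdu_dict
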
diff --git a/synapse/federation/transport.py b/synapse/federation/transport.py
index e7517cac4d..04ad7e63ae 100644
--- a/synapse/federation/transport.py
+++ b/synapse/federation/transport.py
@@ -72,7 +72,7 @@ class TransportLayer(object):
self.received_handler = None
@log_function
- def get_context_state(self, destination, context):
+ def get_context_state(self, destination, context, event_id=None):
""" Requests all state for a given context (i.e. room) from the
given server.
@@ -89,54 +89,62 @@ class TransportLayer(object):
subpath = "/state/%s/" % context
- return self._do_request_for_transaction(destination, subpath)
+ args = {}
+ if event_id:
+ args["event_id"] = event_id
+
+ return self._do_request_for_transaction(
+ destination, subpath, args=args
+ )
@log_function
- def get_pdu(self, destination, pdu_origin, pdu_id):
+ def get_event(self, destination, event_id):
""" Requests the pdu with give id and origin from the given server.
Args:
destination (str): The host name of the remote home server we want
to get the state from.
- pdu_origin (str): The home server which created the PDU.
- pdu_id (str): The id of the PDU being requested.
+ event_id (str): The id of the event being requested.
Returns:
Deferred: Results in a dict received from the remote homeserver.
"""
- logger.debug("get_pdu dest=%s, pdu_origin=%s, pdu_id=%s",
- destination, pdu_origin, pdu_id)
+ logger.debug("get_pdu dest=%s, event_id=%s",
+ destination, event_id)
- subpath = "/pdu/%s/%s/" % (pdu_origin, pdu_id)
+ subpath = "/event/%s/" % (event_id, )
return self._do_request_for_transaction(destination, subpath)
@log_function
- def backfill(self, dest, context, pdu_tuples, limit):
+ def backfill(self, dest, context, event_tuples, limit):
""" Requests `limit` previous PDUs in a given context before list of
PDUs.
Args:
dest (str)
context (str)
- pdu_tuples (list)
+ event_tuples (list)
            limit (int)
Returns:
Deferred: Results in a dict received from the remote homeserver.
"""
logger.debug(
- "backfill dest=%s, context=%s, pdu_tuples=%s, limit=%s",
- dest, context, repr(pdu_tuples), str(limit)
+ "backfill dest=%s, context=%s, event_tuples=%s, limit=%s",
+ dest, context, repr(event_tuples), str(limit)
)
- if not pdu_tuples:
+ if not event_tuples:
+ # TODO: raise?
return
- subpath = "/backfill/%s/" % context
+ subpath = "/backfill/%s/" % (context,)
- args = {"v": ["%s,%s" % (i, o) for i, o in pdu_tuples]}
- args["limit"] = limit
+ args = {
+ "v": event_tuples,
+ "limit": limit,
+ }
return self._do_request_for_transaction(
dest,
@@ -198,6 +206,57 @@ class TransportLayer(object):
defer.returnValue(response)
@defer.inlineCallbacks
+ @log_function
+ def make_join(self, destination, context, user_id, retry_on_dns_fail=True):
+ path = PREFIX + "/make_join/%s/%s" % (context, user_id,)
+
+ response = yield self.client.get_json(
+ destination=destination,
+ path=path,
+ retry_on_dns_fail=retry_on_dns_fail,
+ )
+
+ defer.returnValue(response)
+
+ @defer.inlineCallbacks
+ @log_function
+ def send_join(self, destination, context, event_id, content):
+ path = PREFIX + "/send_join/%s/%s" % (
+ context,
+ event_id,
+ )
+
+ code, content = yield self.client.put_json(
+ destination=destination,
+ path=path,
+ data=content,
+ )
+
+ if not 200 <= code < 300:
+ raise RuntimeError("Got %d from send_join", code)
+
+ defer.returnValue(json.loads(content))
+
+ @defer.inlineCallbacks
+ @log_function
+ def send_invite(self, destination, context, event_id, content):
+ path = PREFIX + "/invite/%s/%s" % (
+ context,
+ event_id,
+ )
+
+ code, content = yield self.client.put_json(
+ destination=destination,
+ path=path,
+ data=content,
+ )
+
+ if not 200 <= code < 300:
+ raise RuntimeError("Got %d from send_invite", code)
+
+ defer.returnValue(json.loads(content))
+
+ @defer.inlineCallbacks
def _authenticate_request(self, request):
json_request = {
"method": request.method,
@@ -313,10 +372,10 @@ class TransportLayer(object):
# data_id pair.
self.server.register_path(
"GET",
- re.compile("^" + PREFIX + "/pdu/([^/]*)/([^/]*)/$"),
+ re.compile("^" + PREFIX + "/event/([^/]*)/$"),
self._with_authentication(
- lambda origin, content, query, pdu_origin, pdu_id:
- handler.on_pdu_request(pdu_origin, pdu_id)
+ lambda origin, content, query, event_id:
+ handler.on_pdu_request(event_id)
)
)
@@ -326,7 +385,10 @@ class TransportLayer(object):
re.compile("^" + PREFIX + "/state/([^/]*)/$"),
self._with_authentication(
lambda origin, content, query, context:
- handler.on_context_state_request(context)
+ handler.on_context_state_request(
+ context,
+ query.get("event_id", [None])[0],
+ )
)
)
@@ -362,6 +424,39 @@ class TransportLayer(object):
)
)
+ self.server.register_path(
+ "GET",
+ re.compile("^" + PREFIX + "/make_join/([^/]*)/([^/]*)$"),
+ self._with_authentication(
+ lambda origin, content, query, context, user_id:
+ self._on_make_join_request(
+ origin, content, query, context, user_id
+ )
+ )
+ )
+
+ self.server.register_path(
+ "PUT",
+ re.compile("^" + PREFIX + "/send_join/([^/]*)/([^/]*)$"),
+ self._with_authentication(
+ lambda origin, content, query, context, event_id:
+ self._on_send_join_request(
+ origin, content, query,
+ )
+ )
+ )
+
+ self.server.register_path(
+ "PUT",
+ re.compile("^" + PREFIX + "/invite/([^/]*)/([^/]*)$"),
+ self._with_authentication(
+ lambda origin, content, query, context, event_id:
+ self._on_invite_request(
+ origin, content, query,
+ )
+ )
+ )
+
@defer.inlineCallbacks
@log_function
def _on_send_request(self, origin, content, query, transaction_id):
@@ -448,124 +543,34 @@ class TransportLayer(object):
limit = int(limits[-1])
- versions = [v.split(",", 1) for v in v_list]
+ versions = v_list
return self.request_handler.on_backfill_request(
- context, versions, limit)
-
-
-class TransportReceivedHandler(object):
- """ Callbacks used when we receive a transaction
- """
- def on_incoming_transaction(self, transaction):
- """ Called on PUT /send/<transaction_id>, or on response to a request
- that we sent (e.g. a backfill request)
-
- Args:
- transaction (synapse.transaction.Transaction): The transaction that
- was sent to us.
-
- Returns:
- twisted.internet.defer.Deferred: A deferred that gets fired when
- the transaction has finished being processed.
-
- The result should be a tuple in the form of
- `(response_code, respond_body)`, where `response_body` is a python
- dict that will get serialized to JSON.
-
- On errors, the dict should have an `error` key with a brief message
- of what went wrong.
- """
- pass
-
-
-class TransportRequestHandler(object):
- """ Handlers used when someone want's data from us
- """
- def on_pull_request(self, versions):
- """ Called on GET /pull/?v=...
-
- This is hit when a remote home server wants to get all data
- after a given transaction. Mainly used when a home server comes back
- online and wants to get everything it has missed.
-
- Args:
- versions (list): A list of transaction_ids that should be used to
- determine what PDUs the remote side have not yet seen.
-
- Returns:
- Deferred: Resultsin a tuple in the form of
- `(response_code, respond_body)`, where `response_body` is a python
- dict that will get serialized to JSON.
-
- On errors, the dict should have an `error` key with a brief message
- of what went wrong.
- """
- pass
-
- def on_pdu_request(self, pdu_origin, pdu_id):
- """ Called on GET /pdu/<pdu_origin>/<pdu_id>/
-
- Someone wants a particular PDU. This PDU may or may not have originated
- from us.
-
- Args:
- pdu_origin (str)
- pdu_id (str)
-
- Returns:
- Deferred: Resultsin a tuple in the form of
- `(response_code, respond_body)`, where `response_body` is a python
- dict that will get serialized to JSON.
-
- On errors, the dict should have an `error` key with a brief message
- of what went wrong.
- """
- pass
-
- def on_context_state_request(self, context):
- """ Called on GET /state/<context>/
-
- Gets hit when someone wants all the *current* state for a given
- contexts.
-
- Args:
- context (str): The name of the context that we're interested in.
-
- Returns:
- twisted.internet.defer.Deferred: A deferred that gets fired when
- the transaction has finished being processed.
-
- The result should be a tuple in the form of
- `(response_code, respond_body)`, where `response_body` is a python
- dict that will get serialized to JSON.
-
- On errors, the dict should have an `error` key with a brief message
- of what went wrong.
- """
- pass
-
- def on_backfill_request(self, context, versions, limit):
- """ Called on GET /backfill/<context>/?v=...&limit=...
+ context, versions, limit
+ )
- Gets hit when we want to backfill backwards on a given context from
- the given point.
+ @defer.inlineCallbacks
+ @log_function
+ def _on_make_join_request(self, origin, content, query, context, user_id):
+ content = yield self.request_handler.on_make_join_request(
+ context, user_id,
+ )
+ defer.returnValue((200, content))
- Args:
- context (str): The context to backfill
- versions (list): A list of 2-tuples representing where to backfill
- from, in the form `(pdu_id, origin)`
- limit (int): How many pdus to return.
+ @defer.inlineCallbacks
+ @log_function
+ def _on_send_join_request(self, origin, content, query):
+ content = yield self.request_handler.on_send_join_request(
+ origin, content,
+ )
- Returns:
- Deferred: Results in a tuple in the form of
- `(response_code, respond_body)`, where `response_body` is a python
- dict that will get serialized to JSON.
+ defer.returnValue((200, content))
- On errors, the dict should have an `error` key with a brief message
- of what went wrong.
- """
- pass
+ @defer.inlineCallbacks
+ @log_function
+ def _on_invite_request(self, origin, content, query):
+ content = yield self.request_handler.on_invite_request(
+ origin, content,
+ )
- def on_query_request(self):
- """ Called on a GET /query/<query_type> request. """
+ defer.returnValue((200, content))
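The removed interface docstrings above all described the same contract, which the new join/invite paths keep: a transport-level callback resolves to a `(response_code, response_body)` tuple whose body is a JSON-serialisable dict, with the new private methods wrapping the request handler's result as `(200, content)`. A minimal sketch of a conforming callback (the handler method it delegates to is a hypothetical stand-in):

```python
from twisted.internet import defer

@defer.inlineCallbacks
def _on_example_request(self, origin, content, query, context):
    # Delegate to the request handler and wrap its JSON-serialisable
    # result in the (response_code, response_body) tuple the transport
    # expects; on_example_request is a hypothetical handler method.
    content = yield self.request_handler.on_example_request(context)
    defer.returnValue((200, content))
```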
diff --git a/synapse/federation/units.py b/synapse/federation/units.py
index b2fb964180..2070ffe1e2 100644
--- a/synapse/federation/units.py
+++ b/synapse/federation/units.py
@@ -20,8 +20,6 @@ server protocol.
from synapse.util.jsonobject import JsonEncodedObject
import logging
-import json
-import copy
logger = logging.getLogger(__name__)
@@ -33,13 +31,13 @@ class Pdu(JsonEncodedObject):
A Pdu can be classified as "state". For a given context, we can efficiently
retrieve all state pdu's that haven't been clobbered. Clobbering is done
- via a unique constraint on the tuple (context, pdu_type, state_key). A pdu
+ via a unique constraint on the tuple (context, type, state_key). A pdu
is a state pdu if `is_state` is True.
Example pdu::
{
- "pdu_id": "78c",
+ "event_id": "$78c:example.com",
"origin_server_ts": 1404835423000,
"origin": "bar",
"prev_ids": [
@@ -52,22 +50,21 @@ class Pdu(JsonEncodedObject):
"""
valid_keys = [
- "pdu_id",
- "context",
+ "event_id",
+ "room_id",
"origin",
"origin_server_ts",
- "pdu_type",
+ "type",
"destinations",
"transaction_id",
- "prev_pdus",
+ "prev_events",
"depth",
"content",
"outlier",
- "is_state", # Below this are keys valid only for State Pdus.
+ "hashes",
+ "signatures", # Below this are keys valid only for State Pdus.
"state_key",
- "power_level",
- "prev_state_id",
- "prev_state_origin",
+ "prev_state",
"required_power_level",
"user_id",
]
@@ -79,61 +76,28 @@ class Pdu(JsonEncodedObject):
]
required_keys = [
- "pdu_id",
- "context",
+ "event_id",
+ "room_id",
"origin",
"origin_server_ts",
- "pdu_type",
+ "type",
"content",
]
# TODO: We need to make this properly load content rather than
# just leaving it as a dict. (OR DO WE?!)
- def __init__(self, destinations=[], is_state=False, prev_pdus=[],
- outlier=False, **kwargs):
- if is_state:
- for required_key in ["state_key"]:
- if required_key not in kwargs:
- raise RuntimeError("Key %s is required" % required_key)
-
+ def __init__(self, destinations=[], prev_events=[],
+ outlier=False, hashes={}, signatures={}, **kwargs):
super(Pdu, self).__init__(
destinations=destinations,
- is_state=is_state,
- prev_pdus=prev_pdus,
+ prev_events=prev_events,
outlier=outlier,
+ hashes=hashes,
+ signatures=signatures,
**kwargs
)
- @classmethod
- def from_pdu_tuple(cls, pdu_tuple):
- """ Converts a PduTuple to a Pdu
-
- Args:
- pdu_tuple (synapse.persistence.transactions.PduTuple): The tuple to
- convert
-
- Returns:
- Pdu
- """
- if pdu_tuple:
- d = copy.copy(pdu_tuple.pdu_entry._asdict())
- d["origin_server_ts"] = d.pop("ts")
-
- d["content"] = json.loads(d["content_json"])
- del d["content_json"]
-
- args = {f: d[f] for f in cls.valid_keys if f in d}
- if "unrecognized_keys" in d and d["unrecognized_keys"]:
- args.update(json.loads(d["unrecognized_keys"]))
-
- return Pdu(
- prev_pdus=pdu_tuple.prev_pdu_list,
- **args
- )
- else:
- return None
-
def __str__(self):
return "(%s, %s)" % (self.__class__.__name__, repr(self.__dict__))
@@ -193,6 +157,7 @@ class Transaction(JsonEncodedObject):
"edus",
"transaction_id",
"destination",
+ "pdu_failures",
]
internal_keys = [
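With the renamed keys, a populated `Pdu` now carries hashes and signatures alongside `event_id`/`room_id`/`type`. A minimal sketch under the new schema (all field values are illustrative placeholders, not real hashes or signatures):

```python
# Illustrative values only; hash/signature payloads are placeholders.
pdu = Pdu(
    event_id="$78c:example.com",
    room_id="!room:example.com",
    origin="example.com",
    origin_server_ts=1404835423000,
    type="m.room.message",
    content={"msgtype": "m.text", "body": "hello"},
    prev_events=[("$77b:example.com", {"sha256": "<base64 hash>"})],
    hashes={"sha256": "<base64 hash>"},
    signatures={"example.com": {"ed25519:key1": "<base64 sig>"}},
)
```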
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index de4d23bbb3..f630280031 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -14,7 +14,16 @@
# limitations under the License.
from twisted.internet import defer
+
from synapse.api.errors import LimitExceededError
+from synapse.util.async import run_on_reactor
+from synapse.crypto.event_signing import add_hashes_and_signatures
+
+import logging
+
+
+logger = logging.getLogger(__name__)
+
class BaseHandler(object):
@@ -30,6 +39,9 @@ class BaseHandler(object):
self.clock = hs.get_clock()
self.hs = hs
+ self.signing_key = hs.config.signing_key[0]
+ self.server_name = hs.hostname
+
def ratelimit(self, user_id):
time_now = self.clock.time()
allowed, time_allowed = self.ratelimiter.send_message(
@@ -44,9 +56,30 @@ class BaseHandler(object):
@defer.inlineCallbacks
def _on_new_room_event(self, event, snapshot, extra_destinations=[],
- extra_users=[]):
+ extra_users=[], suppress_auth=False):
+ yield run_on_reactor()
+
snapshot.fill_out_prev_events(event)
+ yield self.state_handler.annotate_state_groups(event)
+
+ yield self.auth.add_auth_events(event)
+
+ logger.debug("Signing event...")
+
+ add_hashes_and_signatures(
+ event, self.server_name, self.signing_key
+ )
+
+ logger.debug("Signed event.")
+
+ if not suppress_auth:
+ logger.debug("Authing...")
+ self.auth.check(event, raises=True)
+ logger.debug("Authed")
+ else:
+ logger.debug("Suppressed auth.")
+
yield self.store.persist_event(event)
destinations = set(extra_destinations)
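Note the ordering above: the event is signed before the auth check, so the event that gets checked is exactly the event that gets persisted and federated. A loose sketch of what a hash-and-sign step does (this is not Synapse's exact canonicalisation; `sign_bytes` and the key's `version` attribute are assumptions):

```python
import hashlib
import json

from syutil.base64util import encode_base64

def sketch_add_hashes_and_signatures(event_dict, server_name, signing_key):
    # Hash the canonical JSON of the event, excluding any existing
    # hashes/signatures, and record it under "hashes".
    stripped = {
        k: v for k, v in event_dict.items()
        if k not in ("hashes", "signatures")
    }
    canonical = json.dumps(stripped, sort_keys=True, separators=(",", ":"))
    digest = hashlib.sha256(canonical.encode("utf-8")).digest()
    event_dict.setdefault("hashes", {})["sha256"] = encode_base64(digest)

    # Sign with this server's key; sign_bytes stands in for the real
    # syutil signing call, and .version is an assumed key attribute.
    sig = sign_bytes(signing_key, canonical)  # hypothetical helper
    key_id = "ed25519:%s" % signing_key.version
    sigs = event_dict.setdefault("signatures", {}).setdefault(server_name, {})
    sigs[key_id] = encode_base64(sig)
```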
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index a56830d520..164363cdc5 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -147,10 +147,8 @@ class DirectoryHandler(BaseHandler):
content={"aliases": aliases},
)
- snapshot = yield self.store.snapshot_room(
- room_id=room_id,
- user_id=user_id,
- )
+ snapshot = yield self.store.snapshot_room(event)
- yield self.state_handler.handle_new_event(event, snapshot)
- yield self._on_new_room_event(event, snapshot, extra_users=[user_id])
+ yield self._on_new_room_event(
+ event, snapshot, extra_users=[user_id], suppress_auth=True
+ )
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index f52591d2a3..09593303a4 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -17,13 +17,14 @@
from ._base import BaseHandler
-from synapse.api.events.room import InviteJoinEvent, RoomMemberEvent
+from synapse.api.errors import AuthError, FederationError
+from synapse.api.events.room import RoomMemberEvent
from synapse.api.constants import Membership
from synapse.util.logutils import log_function
from synapse.federation.pdu_codec import PduCodec
-from synapse.api.errors import SynapseError
+from synapse.util.async import run_on_reactor
-from twisted.internet import defer, reactor
+from twisted.internet import defer
import logging
@@ -62,6 +63,9 @@ class FederationHandler(BaseHandler):
self.pdu_codec = PduCodec(hs)
+ # When joining a room we need to queue up any events we receive for
+ # that room until the join has completed

+ self.room_queues = {}
+
@log_function
@defer.inlineCallbacks
def handle_new_event(self, event, snapshot):
@@ -78,6 +82,8 @@ class FederationHandler(BaseHandler):
processing.
"""
+ yield run_on_reactor()
+
pdu = self.pdu_codec.pdu_from_event(event)
if not hasattr(pdu, "destinations") or not pdu.destinations:
@@ -87,97 +93,88 @@ class FederationHandler(BaseHandler):
@log_function
@defer.inlineCallbacks
- def on_receive_pdu(self, pdu, backfilled):
+ def on_receive_pdu(self, pdu, backfilled, state=None):
""" Called by the ReplicationLayer when we have a new pdu. We need to
- do auth checks and put it throught the StateHandler.
+ do auth checks and put it through the StateHandler.
"""
event = self.pdu_codec.event_from_pdu(pdu)
logger.debug("Got event: %s", event.event_id)
- with (yield self.lock_manager.lock(pdu.context)):
- if event.is_state and not backfilled:
- is_new_state = yield self.state_handler.handle_new_state(
- pdu
- )
- else:
- is_new_state = False
- # TODO: Implement something in federation that allows us to
- # respond to PDU.
+ if event.room_id in self.room_queues:
+ self.room_queues[event.room_id].append(pdu)
+ return
- target_is_mine = False
- if hasattr(event, "target_host"):
- target_is_mine = event.target_host == self.hs.hostname
-
- if event.type == InviteJoinEvent.TYPE:
- if not target_is_mine:
- logger.debug("Ignoring invite/join event %s", event)
- return
-
- # If we receive an invite/join event then we need to join the
- # sender to the given room.
- # TODO: We should probably auth this or some such
- content = event.content
- content.update({"membership": Membership.JOIN})
- new_event = self.event_factory.create_event(
- etype=RoomMemberEvent.TYPE,
- state_key=event.user_id,
- room_id=event.room_id,
- user_id=event.user_id,
- membership=Membership.JOIN,
- content=content
- )
+ logger.debug("Processing event: %s", event.event_id)
+
+ if state:
+ state = [self.pdu_codec.event_from_pdu(p) for p in state]
- yield self.hs.get_handlers().room_member_handler.change_membership(
- new_event,
- do_auth=False,
+ is_new_state = yield self.state_handler.annotate_state_groups(
+ event,
+ old_state=state
+ )
+
+ logger.debug("Event: %s", event)
+
+ try:
+ self.auth.check(event, raises=True)
+ except AuthError as e:
+ raise FederationError(
+ "ERROR",
+ e.code,
+ e.msg,
+ affected=event.event_id,
)
- else:
- with (yield self.room_lock.lock(event.room_id)):
- yield self.store.persist_event(
- event,
- backfilled,
- is_new_state=is_new_state
- )
+ is_new_state = is_new_state and not backfilled
- room = yield self.store.get_room(event.room_id)
+ # TODO: Implement something in federation that allows us to
+ # respond to PDU.
- if not room:
- # Huh, let's try and get the current state
- try:
- yield self.replication_layer.get_state_for_context(
- event.origin, event.room_id
- )
+ yield self.store.persist_event(
+ event,
+ backfilled,
+ is_new_state=is_new_state
+ )
- hosts = yield self.store.get_joined_hosts_for_room(
- event.room_id
- )
- if self.hs.hostname in hosts:
- try:
- yield self.store.store_room(
- room_id=event.room_id,
- room_creator_user_id="",
- is_public=False,
- )
- except:
- pass
- except:
- logger.exception(
- "Failed to get current state for room %s",
- event.room_id
- )
+ room = yield self.store.get_room(event.room_id)
- if not backfilled:
- extra_users = []
- if event.type == RoomMemberEvent.TYPE:
- target_user_id = event.state_key
- target_user = self.hs.parse_userid(target_user_id)
- extra_users.append(target_user)
+ if not room:
+ # Huh, let's try and get the current state
+ try:
+ yield self.replication_layer.get_state_for_context(
+ event.origin, event.room_id, event.event_id,
+ )
- yield self.notifier.on_new_room_event(
- event, extra_users=extra_users
+ hosts = yield self.store.get_joined_hosts_for_room(
+ event.room_id
)
+ if self.hs.hostname in hosts:
+ try:
+ yield self.store.store_room(
+ room_id=event.room_id,
+ room_creator_user_id="",
+ is_public=False,
+ )
+ except:
+ pass
+ except:
+ logger.exception(
+ "Failed to get current state for room %s",
+ event.room_id
+ )
+
+ if not backfilled:
+ extra_users = []
+ if event.type == RoomMemberEvent.TYPE:
+ target_user_id = event.state_key
+ target_user = self.hs.parse_userid(target_user_id)
+ extra_users.append(target_user)
+
+ yield self.notifier.on_new_room_event(
+ event, extra_users=extra_users
+ )
if event.type == RoomMemberEvent.TYPE:
if event.membership == Membership.JOIN:
@@ -189,13 +186,28 @@ class FederationHandler(BaseHandler):
@log_function
@defer.inlineCallbacks
def backfill(self, dest, room_id, limit):
- pdus = yield self.replication_layer.backfill(dest, room_id, limit)
+ extremities = yield self.store.get_oldest_events_in_room(room_id)
+
+ pdus = yield self.replication_layer.backfill(
+ dest,
+ room_id,
+ limit,
+ extremities=[
+ self.pdu_codec.decode_event_id(e)
+ for e in extremities
+ ]
+ )
events = []
for pdu in pdus:
event = self.pdu_codec.event_from_pdu(pdu)
+
+ # FIXME (erikj): Not sure this actually works :/
+ yield self.state_handler.annotate_state_groups(event)
+
events.append(event)
+
yield self.store.persist_event(event, backfilled=True)
defer.returnValue(events)
@@ -203,62 +215,231 @@ class FederationHandler(BaseHandler):
@log_function
@defer.inlineCallbacks
def do_invite_join(self, target_host, room_id, joinee, content, snapshot):
-
hosts = yield self.store.get_joined_hosts_for_room(room_id)
if self.hs.hostname in hosts:
# We are already in the room.
logger.debug("We're already in the room apparently")
defer.returnValue(False)
- # First get current state to see if we are already joined.
+ pdu = yield self.replication_layer.make_join(
+ target_host,
+ room_id,
+ joinee
+ )
+
+ logger.debug("Got response to make_join: %s", pdu)
+
+ event = self.pdu_codec.event_from_pdu(pdu)
+
+ # We should assert some things.
+ assert(event.type == RoomMemberEvent.TYPE)
+ assert(event.user_id == joinee)
+ assert(event.state_key == joinee)
+ assert(event.room_id == room_id)
+
+ event.outlier = False
+
+ self.room_queues[room_id] = []
+
try:
- yield self.replication_layer.get_state_for_context(
- target_host, room_id
+ event.event_id = self.event_factory.create_event_id()
+ event.content = content
+
+ state = yield self.replication_layer.send_join(
+ target_host,
+ self.pdu_codec.pdu_from_event(event)
)
- hosts = yield self.store.get_joined_hosts_for_room(room_id)
- if self.hs.hostname in hosts:
- # Oh, we were actually in the room already.
- logger.debug("We're already in the room apparently")
- defer.returnValue(False)
- except Exception:
- logger.exception("Failed to get current state")
-
- new_event = self.event_factory.create_event(
- etype=InviteJoinEvent.TYPE,
- target_host=target_host,
- room_id=room_id,
- user_id=joinee,
- content=content
- )
+ state = [self.pdu_codec.event_from_pdu(p) for p in state]
- new_event.destinations = [target_host]
+ logger.debug("do_invite_join state: %s", state)
- snapshot.fill_out_prev_events(new_event)
- yield self.handle_new_event(new_event, snapshot)
+ is_new_state = yield self.state_handler.annotate_state_groups(
+ event,
+ old_state=state
+ )
- # TODO (erikj): Time out here.
- d = defer.Deferred()
- self.waiting_for_join_list.setdefault((joinee, room_id), []).append(d)
- reactor.callLater(10, d.cancel)
+ logger.debug("do_invite_join event: %s", event)
- try:
- yield d
- except defer.CancelledError:
- raise SynapseError(500, "Unable to join remote room")
+ try:
+ yield self.store.store_room(
+ room_id=room_id,
+ room_creator_user_id="",
+ is_public=False
+ )
+ except:
+ # FIXME
+ pass
- try:
- yield self.store.store_room(
- room_id=room_id,
- room_creator_user_id="",
- is_public=False
+ for e in state:
+ # FIXME: Auth these.
+ e.outlier = True
+
+ yield self.state_handler.annotate_state_groups(
+ e,
+ )
+
+ yield self.store.persist_event(
+ e,
+ backfilled=False,
+ is_new_state=False
+ )
+
+ yield self.store.persist_event(
+ event,
+ backfilled=False,
+ is_new_state=is_new_state
)
- except:
- pass
+ finally:
+ room_queue = self.room_queues[room_id]
+ del self.room_queues[room_id]
+ for p in room_queue:
+ try:
+ yield self.on_receive_pdu(p, backfilled=False)
+ except:
+ # Best effort: a bad queued PDU shouldn't block the rest.
+ pass
defer.returnValue(True)
+ @defer.inlineCallbacks
+ @log_function
+ def on_make_join_request(self, context, user_id):
+ event = self.event_factory.create_event(
+ etype=RoomMemberEvent.TYPE,
+ content={"membership": Membership.JOIN},
+ room_id=context,
+ user_id=user_id,
+ state_key=user_id,
+ )
+
+ snapshot = yield self.store.snapshot_room(event)
+ snapshot.fill_out_prev_events(event)
+
+ yield self.state_handler.annotate_state_groups(event)
+ yield self.auth.add_auth_events(event)
+ self.auth.check(event, raises=True)
+
+ pdu = self.pdu_codec.pdu_from_event(event)
+
+ defer.returnValue(pdu)
+
+ @defer.inlineCallbacks
+ @log_function
+ def on_send_join_request(self, origin, pdu):
+ event = self.pdu_codec.event_from_pdu(pdu)
+
+ event.outlier = False
+
+ is_new_state = yield self.state_handler.annotate_state_groups(event)
+ self.auth.check(event, raises=True)
+
+ # FIXME (erikj): All this is duplicated above :(
+
+ yield self.store.persist_event(
+ event,
+ backfilled=False,
+ is_new_state=is_new_state
+ )
+
+ extra_users = []
+ if event.type == RoomMemberEvent.TYPE:
+ target_user_id = event.state_key
+ target_user = self.hs.parse_userid(target_user_id)
+ extra_users.append(target_user)
+
+ yield self.notifier.on_new_room_event(
+ event, extra_users=extra_users
+ )
+
+ if event.type == RoomMemberEvent.TYPE:
+ if event.membership == Membership.JOIN:
+ user = self.hs.parse_userid(event.state_key)
+ self.distributor.fire(
+ "user_joined_room", user=user, room_id=event.room_id
+ )
+
+ new_pdu = self.pdu_codec.pdu_from_event(event)
+ new_pdu.destinations = yield self.store.get_joined_hosts_for_room(
+ event.room_id
+ )
+
+ yield self.replication_layer.send_pdu(new_pdu)
+
+ defer.returnValue([
+ self.pdu_codec.pdu_from_event(e)
+ for e in event.state_events.values()
+ ])
+
+ @defer.inlineCallbacks
+ def get_state_for_pdu(self, event_id):
+ yield run_on_reactor()
+
+ state_groups = yield self.store.get_state_groups(
+ [event_id]
+ )
+
+ if state_groups:
+ results = {
+ (e.type, e.state_key): e for e in state_groups[0].state
+ }
+
+ event = yield self.store.get_event(event_id)
+ if hasattr(event, "state_key"):
+ # Get previous state
+ if hasattr(event, "replaces_state") and event.replaces_state:
+ prev_event = yield self.store.get_event(
+ event.replaces_state
+ )
+ results[(event.type, event.state_key)] = prev_event
+ else:
+ del results[(event.type, event.state_key)]
+
+ defer.returnValue(
+ [
+ self.pdu_codec.pdu_from_event(s)
+ for s in results.values()
+ ]
+ )
+ else:
+ defer.returnValue([])
+
+ @defer.inlineCallbacks
+ @log_function
+ def on_backfill_request(self, context, pdu_list, limit):
+
+ events = yield self.store.get_backfill_events(
+ context,
+ pdu_list,
+ limit
+ )
+
+ defer.returnValue([
+ self.pdu_codec.pdu_from_event(e)
+ for e in events
+ ])
+
+ @defer.inlineCallbacks
+ @log_function
+ def get_persisted_pdu(self, event_id):
+ """ Get a PDU from the database with the given event id.
+
+ Returns:
+ Deferred: Results in a `Pdu`.
+ """
+ event = yield self.store.get_event(
+ event_id,
+ allow_none=True,
+ )
+
+ if event:
+ defer.returnValue(self.pdu_codec.pdu_from_event(event))
+ else:
+ defer.returnValue(None)
+
+ @log_function
+ def get_min_depth_for_context(self, context):
+ return self.store.get_min_depth(context)
@log_function
def _on_user_joined(self, user, room_id):
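For reference, the make_join/send_join handshake that `do_invite_join` now drives, condensed into one sketch (method shapes follow the calls above; error handling, the outlier persistence loop, and the `room_queues` drain are elided):

```python
from twisted.internet import defer

@defer.inlineCallbacks
def sketch_remote_join(self, target_host, room_id, joinee, content):
    # 1. Ask the resident server to draft a join event for us.
    pdu = yield self.replication_layer.make_join(target_host, room_id, joinee)
    event = self.pdu_codec.event_from_pdu(pdu)

    # 2. Fill in our own event id and content and send it back; the
    #    response is the room's current state as the remote sees it.
    event.event_id = self.event_factory.create_event_id()
    event.content = content
    state = yield self.replication_layer.send_join(
        target_host, self.pdu_codec.pdu_from_event(event)
    )

    # 3. The caller then persists each state event as an outlier,
    #    persists the join itself, and drains anything that was queued
    #    in room_queues while the join was in flight.
    defer.returnValue((event, state))
```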
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 72894869ea..8394013df3 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -81,12 +81,11 @@ class MessageHandler(BaseHandler):
user = self.hs.parse_userid(event.user_id)
assert user.is_mine, "User must be our own: %s" % (user,)
- snapshot = yield self.store.snapshot_room(event.room_id, event.user_id)
+ snapshot = yield self.store.snapshot_room(event)
- if not suppress_auth:
- yield self.auth.check(event, snapshot, raises=True)
-
- yield self._on_new_room_event(event, snapshot)
+ yield self._on_new_room_event(
+ event, snapshot, suppress_auth=suppress_auth
+ )
self.hs.get_handlers().presence_handler.bump_presence_active_time(
user
@@ -142,16 +141,7 @@ class MessageHandler(BaseHandler):
SynapseError if something went wrong.
"""
- snapshot = yield self.store.snapshot_room(
- event.room_id,
- event.user_id,
- state_type=event.type,
- state_key=event.state_key,
- )
-
- yield self.auth.check(event, snapshot, raises=True)
-
- yield self.state_handler.handle_new_event(event, snapshot)
+ snapshot = yield self.store.snapshot_room(event)
yield self._on_new_room_event(event, snapshot)
@@ -201,7 +191,7 @@ class MessageHandler(BaseHandler):
raise RoomError(
403, "Member does not meet private room rules.")
- data = yield self.store.get_current_state(
+ data = yield self.state_handler.get_current_state(
room_id, event_type, state_key
)
defer.returnValue(data)
@@ -219,9 +209,7 @@ class MessageHandler(BaseHandler):
@defer.inlineCallbacks
def send_feedback(self, event):
- snapshot = yield self.store.snapshot_room(event.room_id, event.user_id)
-
- yield self.auth.check(event, snapshot, raises=True)
+ snapshot = yield self.store.snapshot_room(event)
# store message in db
yield self._on_new_room_event(event, snapshot)
@@ -239,7 +227,7 @@ class MessageHandler(BaseHandler):
yield self.auth.check_joined_room(room_id, user_id)
# TODO: This is duplicating logic from snapshot_all_rooms
- current_state = yield self.store.get_current_state(room_id)
+ current_state = yield self.state_handler.get_current_state(room_id)
defer.returnValue([self.hs.serialize_event(c) for c in current_state])
@defer.inlineCallbacks
@@ -316,7 +304,7 @@ class MessageHandler(BaseHandler):
"end": end_token.to_string(),
}
- current_state = yield self.store.get_current_state(
+ current_state = yield self.state_handler.get_current_state(
event.room_id
)
d["state"] = [self.hs.serialize_event(c) for c in current_state]
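The `store.get_current_state` to `state_handler.get_current_state` swaps above reflect that current state is now derived rather than read from a table; roughly (a sketch mirroring the state handler, not new API):

```python
from twisted.internet import defer

@defer.inlineCallbacks
def sketch_current_state(store, state_handler, room_id):
    # Resolve the state at the room's forward extremities instead of
    # consulting a stored "current state" table.
    latest = yield store.get_latest_events_in_room(room_id)
    event_ids = [e_id for e_id, _, _ in latest]
    state = yield state_handler.resolve_state_groups(event_ids)
    defer.returnValue(state.values())
```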
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index dab9b03f04..e47814483a 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -17,7 +17,6 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError, AuthError, CodeMessageException
from synapse.api.constants import Membership
-from synapse.api.events.room import RoomMemberEvent
from ._base import BaseHandler
@@ -196,10 +195,7 @@ class ProfileHandler(BaseHandler):
)
for j in joins:
- snapshot = yield self.store.snapshot_room(
- j.room_id, j.state_key, RoomMemberEvent.TYPE,
- j.state_key
- )
+ snapshot = yield self.store.snapshot_room(j)
content = {
"membership": j.content["membership"],
@@ -218,5 +214,6 @@ class ProfileHandler(BaseHandler):
user_id=j.state_key,
)
- yield self.state_handler.handle_new_event(new_event, snapshot)
- yield self._on_new_room_event(new_event, snapshot)
+ yield self._on_new_room_event(
+ new_event, snapshot, suppress_auth=True
+ )
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 81ce1a5907..42a6c9f9bf 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -21,8 +21,7 @@ from synapse.api.constants import Membership, JoinRules
from synapse.api.errors import StoreError, SynapseError
from synapse.api.events.room import (
RoomMemberEvent, RoomCreateEvent, RoomPowerLevelsEvent,
- RoomJoinRulesEvent, RoomAddStateLevelEvent, RoomTopicEvent,
- RoomSendEventLevelEvent, RoomOpsPowerLevelsEvent, RoomNameEvent,
+ RoomTopicEvent, RoomNameEvent, RoomJoinRulesEvent,
)
from synapse.util import stringutils
from ._base import BaseHandler
@@ -122,15 +121,13 @@ class RoomCreationHandler(BaseHandler):
@defer.inlineCallbacks
def handle_event(event):
- snapshot = yield self.store.snapshot_room(
- room_id=room_id,
- user_id=user_id,
- )
+ snapshot = yield self.store.snapshot_room(event)
logger.debug("Event: %s", event)
- yield self.state_handler.handle_new_event(event, snapshot)
- yield self._on_new_room_event(event, snapshot, extra_users=[user])
+ yield self._on_new_room_event(
+ event, snapshot, extra_users=[user], suppress_auth=True
+ )
for event in creation_events:
yield handle_event(event)
@@ -141,7 +138,6 @@ class RoomCreationHandler(BaseHandler):
etype=RoomNameEvent.TYPE,
room_id=room_id,
user_id=user_id,
- required_power_level=50,
content={"name": name},
)
@@ -153,7 +149,6 @@ class RoomCreationHandler(BaseHandler):
etype=RoomTopicEvent.TYPE,
room_id=room_id,
user_id=user_id,
- required_power_level=50,
content={"topic": topic},
)
@@ -198,7 +193,6 @@ class RoomCreationHandler(BaseHandler):
event_keys = {
"room_id": room_id,
"user_id": creator.to_string(),
- "required_power_level": 100,
}
def create(etype, **content):
@@ -215,7 +209,21 @@ class RoomCreationHandler(BaseHandler):
power_levels_event = self.event_factory.create_event(
etype=RoomPowerLevelsEvent.TYPE,
- content={creator.to_string(): 100, "default": 0},
+ content={
+ "users": {
+ creator.to_string(): 100,
+ },
+ "users_default": 0,
+ "events": {
+ RoomNameEvent.TYPE: 100,
+ RoomPowerLevelsEvent.TYPE: 100,
+ },
+ "events_default": 0,
+ "state_default": 50,
+ "ban": 50,
+ "kick": 50,
+ "redact": 50
+ },
**event_keys
)
@@ -225,30 +233,10 @@ class RoomCreationHandler(BaseHandler):
join_rule=join_rule,
)
- add_state_event = create(
- etype=RoomAddStateLevelEvent.TYPE,
- level=100,
- )
-
- send_event = create(
- etype=RoomSendEventLevelEvent.TYPE,
- level=0,
- )
-
- ops = create(
- etype=RoomOpsPowerLevelsEvent.TYPE,
- ban_level=50,
- kick_level=50,
- redact_level=50,
- )
-
return [
creation_event,
power_levels_event,
join_rules_event,
- add_state_event,
- send_event,
- ops,
]
@@ -363,10 +351,8 @@ class RoomMemberHandler(BaseHandler):
"""
target_user_id = event.state_key
- snapshot = yield self.store.snapshot_room(
- event.room_id, event.user_id,
- RoomMemberEvent.TYPE, target_user_id
- )
+ snapshot = yield self.store.snapshot_room(event)
+
## TODO(markjh): get prev state from snapshot.
prev_state = yield self.store.get_room_member(
target_user_id, event.room_id
@@ -391,29 +377,17 @@ class RoomMemberHandler(BaseHandler):
yield self._do_join(event, snapshot, do_auth=do_auth)
else:
# This is not a JOIN, so we can handle it normally.
- if do_auth:
- yield self.auth.check(event, snapshot, raises=True)
-
- # If we're banning someone, set a req power level
- if event.membership == Membership.BAN:
- if not hasattr(event, "required_power_level") or event.required_power_level is None:
- # Add some default required_power_level
- user_level = yield self.store.get_power_level(
- event.room_id,
- event.user_id,
- )
- event.required_power_level = user_level
if prev_state and prev_state.membership == event.membership:
# double same action, treat this event as a NOOP.
defer.returnValue({})
return
- yield self.state_handler.handle_new_event(event, snapshot)
yield self._do_local_membership_update(
event,
membership=event.content["membership"],
snapshot=snapshot,
+ do_auth=do_auth,
)
defer.returnValue({"room_id": room_id})
@@ -443,10 +417,7 @@ class RoomMemberHandler(BaseHandler):
content=content,
)
- snapshot = yield self.store.snapshot_room(
- room_id, joinee.to_string(), RoomMemberEvent.TYPE,
- joinee.to_string()
- )
+ snapshot = yield self.store.snapshot_room(new_event)
yield self._do_join(new_event, snapshot, room_host=host, do_auth=True)
@@ -502,14 +473,11 @@ class RoomMemberHandler(BaseHandler):
if not have_joined:
logger.debug("Doing normal join")
- if do_auth:
- yield self.auth.check(event, snapshot, raises=True)
-
- yield self.state_handler.handle_new_event(event, snapshot)
yield self._do_local_membership_update(
event,
membership=event.content["membership"],
snapshot=snapshot,
+ do_auth=do_auth,
)
user = self.hs.parse_userid(event.user_id)
@@ -553,7 +521,8 @@ class RoomMemberHandler(BaseHandler):
defer.returnValue([r.room_id for r in rooms])
- def _do_local_membership_update(self, event, membership, snapshot):
+ def _do_local_membership_update(self, event, membership, snapshot,
+ do_auth):
destinations = []
# If we're inviting someone, then we should also send it to that
@@ -570,9 +539,10 @@ class RoomMemberHandler(BaseHandler):
return self._on_new_room_event(
event, snapshot, extra_destinations=destinations,
- extra_users=[target_user]
+ extra_users=[target_user], suppress_auth=(not do_auth),
)
+
class RoomListHandler(BaseHandler):
@defer.inlineCallbacks
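The consolidated power-levels content above replaces the separate add-state/send-event/ops events. A sketch of how such content could be queried (the helper and the fallback defaults are assumptions, not part of this patch):

```python
def sketch_required_level(power_content, event_type, is_state_event):
    # Explicit per-event overrides win; otherwise state events fall
    # back to state_default and everything else to events_default.
    events = power_content.get("events", {})
    if event_type in events:
        return events[event_type]
    if is_state_event:
        return power_content.get("state_default", 50)
    return power_content.get("events_default", 0)

# With the content built above, renaming the room needs power 100:
# sketch_required_level(content, "m.room.name", True) == 100
```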
diff --git a/synapse/rest/base.py b/synapse/rest/base.py
index 2e8e3fa7d4..dc784c1527 100644
--- a/synapse/rest/base.py
+++ b/synapse/rest/base.py
@@ -18,6 +18,11 @@ from synapse.api.urls import CLIENT_PREFIX
from synapse.rest.transactions import HttpTransactionStore
import re
+import logging
+
+
+logger = logging.getLogger(__name__)
+
def client_path_pattern(path_regex):
"""Creates a regex compiled client path with the correct client path
diff --git a/synapse/rest/events.py b/synapse/rest/events.py
index 097195d7cc..92ff5e5ca7 100644
--- a/synapse/rest/events.py
+++ b/synapse/rest/events.py
@@ -20,6 +20,12 @@ from synapse.api.errors import SynapseError
from synapse.streams.config import PaginationConfig
from synapse.rest.base import RestServlet, client_path_pattern
+import logging
+
+
+logger = logging.getLogger(__name__)
+
+
class EventStreamRestServlet(RestServlet):
PATTERN = client_path_pattern("/events$")
@@ -29,18 +35,22 @@ class EventStreamRestServlet(RestServlet):
@defer.inlineCallbacks
def on_GET(self, request):
auth_user = yield self.auth.get_user_by_req(request)
-
- handler = self.handlers.event_stream_handler
- pagin_config = PaginationConfig.from_request(request)
- timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS
- if "timeout" in request.args:
- try:
- timeout = int(request.args["timeout"][0])
- except ValueError:
- raise SynapseError(400, "timeout must be in milliseconds.")
-
- chunk = yield handler.get_stream(auth_user.to_string(), pagin_config,
- timeout=timeout)
+ try:
+ handler = self.handlers.event_stream_handler
+ pagin_config = PaginationConfig.from_request(request)
+ timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS
+ if "timeout" in request.args:
+ try:
+ timeout = int(request.args["timeout"][0])
+ except ValueError:
+ raise SynapseError(400, "timeout must be in milliseconds.")
+
+ chunk = yield handler.get_stream(
+ auth_user.to_string(), pagin_config, timeout=timeout
+ )
+ except:
+ logger.exception("Event stream failed")
+ raise
defer.returnValue((200, chunk))
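A hypothetical client-side view of the long-poll parameter being parsed above (the exact path prefix is an assumption):

```python
import requests

# timeout is in milliseconds; a non-integer value now yields a 400.
resp = requests.get(
    "https://hs.example.com/_matrix/client/api/v1/events",
    params={"access_token": "<token>", "from": "END", "timeout": 30000},
)
chunk = resp.json()
```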
diff --git a/synapse/rest/room.py b/synapse/rest/room.py
index 7724967061..5c9c9d3af4 100644
--- a/synapse/rest/room.py
+++ b/synapse/rest/room.py
@@ -138,7 +138,7 @@ class RoomStateEventRestServlet(RestServlet):
raise SynapseError(
404, "Event not found.", errcode=Codes.NOT_FOUND
)
- defer.returnValue((200, data[0].get_dict()["content"]))
+ defer.returnValue((200, data.get_dict()["content"]))
@defer.inlineCallbacks
def on_PUT(self, request, room_id, event_type, state_key):
diff --git a/synapse/server.py b/synapse/server.py
index a4d2d4aba5..d770b20b19 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -28,7 +28,7 @@ from synapse.handlers import Handlers
from synapse.rest import RestServletFactory
from synapse.state import StateHandler
from synapse.storage import DataStore
-from synapse.types import UserID, RoomAlias, RoomID
+from synapse.types import UserID, RoomAlias, RoomID, EventID
from synapse.util import Clock
from synapse.util.distributor import Distributor
from synapse.util.lockutils import LockManager
@@ -143,6 +143,11 @@ class BaseHomeServer(object):
object."""
return RoomID.from_string(s, hs=self)
+ def parse_eventid(self, s):
+ """Parse the string given by 's' as an Event ID and return an EventID
+ object."""
+ return EventID.from_string(s, hs=self)
+
def serialize_event(self, e):
return serialize_event(self, e)
diff --git a/synapse/state.py b/synapse/state.py
index 9db84c9b5c..e2fd48bdae 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -16,11 +16,13 @@
from twisted.internet import defer
-from synapse.federation.pdu_codec import encode_event_id, decode_event_id
from synapse.util.logutils import log_function
+from synapse.util.async import run_on_reactor
+from synapse.api.events.room import RoomPowerLevelsEvent
from collections import namedtuple
+import copy
import logging
import hashlib
@@ -35,230 +37,154 @@ KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key"))
class StateHandler(object):
- """ Repsonsible for doing state conflict resolution.
+ """ Responsible for doing state conflict resolution.
"""
def __init__(self, hs):
self.store = hs.get_datastore()
self._replication = hs.get_replication_layer()
self.server_name = hs.hostname
+ self.hs = hs
@defer.inlineCallbacks
@log_function
- def handle_new_event(self, event, snapshot):
- """ Given an event this works out if a) we have sufficient power level
- to update the state and b) works out what the prev_state should be.
+ def annotate_state_groups(self, event, old_state=None):
+ yield run_on_reactor()
- Returns:
- Deferred: Resolved with a boolean indicating if we succesfully
- updated the state.
+ if old_state:
+ event.state_group = None
+ event.old_state_events = {
+ (s.type, s.state_key): s for s in old_state
+ }
+ event.state_events = event.old_state_events
- Raised:
- AuthError
- """
- # This needs to be done in a transaction.
+ if hasattr(event, "state_key"):
+ event.state_events[(event.type, event.state_key)] = event
- if not hasattr(event, "state_key"):
+ defer.returnValue(False)
return
- key = KeyStateTuple(
- event.room_id,
- event.type,
- _get_state_key_from_event(event)
- )
-
- # Now I need to fill out the prev state and work out if it has auth
- # (w.r.t. to power levels)
-
- snapshot.fill_out_prev_events(event)
-
- event.prev_events = [
- e for e in event.prev_events if e != event.event_id
- ]
-
- current_state = snapshot.prev_state_pdu
-
- if current_state:
- event.prev_state = encode_event_id(
- current_state.pdu_id, current_state.origin
- )
-
- # TODO check current_state to see if the min power level is less
- # than the power level of the user
- # power_level = self._get_power_level_for_event(event)
-
- pdu_id, origin = decode_event_id(event.event_id, self.server_name)
+ if hasattr(event, "outlier") and event.outlier:
+ event.state_group = None
+ event.old_state_events = None
+ event.state_events = {}
+ defer.returnValue(False)
+ return
- yield self.store.update_current_state(
- pdu_id=pdu_id,
- origin=origin,
- context=key.context,
- pdu_type=key.type,
- state_key=key.state_key
+ new_state = yield self.resolve_state_groups(
+ [e for e, _ in event.prev_events]
)
- defer.returnValue(True)
-
- @defer.inlineCallbacks
- @log_function
- def handle_new_state(self, new_pdu):
- """ Apply conflict resolution to `new_pdu`.
+ event.old_state_events = copy.deepcopy(new_state)
- This should be called on every new state pdu, regardless of whether or
- not there is a conflict.
+ if hasattr(event, "state_key"):
+ key = (event.type, event.state_key)
+ if key in new_state:
+ event.replaces_state = new_state[key].event_id
+ new_state[key] = event
- This function is safe against the race of it getting called with two
- `PDU`s trying to update the same state.
- """
+ event.state_group = None
+ event.state_events = new_state
- # This needs to be done in a transaction.
+ defer.returnValue(hasattr(event, "state_key"))
- is_new = yield self._handle_new_state(new_pdu)
+ @defer.inlineCallbacks
+ def get_current_state(self, room_id, event_type=None, state_key=""):
+ events = yield self.store.get_latest_events_in_room(room_id)
- logger.debug("is_new: %s %s %s", is_new, new_pdu.pdu_id, new_pdu.origin)
+ event_ids = [
+ e_id
+ for e_id, _, _ in events
+ ]
- if is_new:
- yield self.store.update_current_state(
- pdu_id=new_pdu.pdu_id,
- origin=new_pdu.origin,
- context=new_pdu.context,
- pdu_type=new_pdu.pdu_type,
- state_key=new_pdu.state_key
- )
+ res = yield self.resolve_state_groups(event_ids)
- defer.returnValue(is_new)
+ if event_type:
+ defer.returnValue(res.get((event_type, state_key)))
+ return
- def _get_power_level_for_event(self, event):
- # return self._persistence.get_power_level_for_user(event.room_id,
- # event.sender)
- return event.power_level
+ defer.returnValue(res.values())
@defer.inlineCallbacks
@log_function
- def _handle_new_state(self, new_pdu):
- tree, missing_branch = yield self.store.get_unresolved_state_tree(
- new_pdu
+ def resolve_state_groups(self, event_ids):
+ state_groups = yield self.store.get_state_groups(
+ event_ids
)
- new_branch, current_branch = tree
-
- logger.debug(
- "_handle_new_state new=%s, current=%s",
- new_branch, current_branch
- )
-
- if missing_branch is not None:
- # We're missing some PDUs. Fetch them.
- # TODO (erikj): Limit this.
- missing_prev = tree[missing_branch][-1]
-
- pdu_id = missing_prev.prev_state_id
- origin = missing_prev.prev_state_origin
-
- is_missing = yield self.store.get_pdu(pdu_id, origin) is None
- if not is_missing:
- raise Exception("Conflict resolution failed")
-
- yield self._replication.get_pdu(
- destination=missing_prev.origin,
- pdu_origin=origin,
- pdu_id=pdu_id,
- outlier=True
- )
- updated_current = yield self._handle_new_state(new_pdu)
- defer.returnValue(updated_current)
-
- if not current_branch:
- # There is no current state
- defer.returnValue(True)
- return
-
- n = new_branch[-1]
- c = current_branch[-1]
-
- common_ancestor = n.pdu_id == c.pdu_id and n.origin == c.origin
-
- if common_ancestor:
- # We found a common ancestor!
-
- if len(current_branch) == 1:
- # This is a direct clobber so we can just...
- defer.returnValue(True)
-
- else:
- # We didn't find a common ancestor. This is probably fine.
- pass
-
- result = yield self._do_conflict_res(
- new_branch, current_branch, common_ancestor
- )
- defer.returnValue(result)
+ state = {}
+ for group in state_groups:
+ for s in group.state:
+ state.setdefault(
+ (s.type, s.state_key),
+ {}
+ )[s.event_id] = s
+
+ unconflicted_state = {
+ k: v.values()[0] for k, v in state.items()
+ if len(v.values()) == 1
+ }
+
+ conflicted_state = {
+ k: v.values()
+ for k, v in state.items()
+ if len(v.values()) > 1
+ }
+
+ try:
+ new_state = {}
+ new_state.update(unconflicted_state)
+ for key, events in conflicted_state.items():
+ new_state[key] = yield self._resolve_state_events(events)
+ except:
+ logger.exception("Failed to resolve state")
+ raise
+
+ defer.returnValue(new_state)
+
+ def _get_power_level_from_event_state(self, event, user_id):
+ key = (RoomPowerLevelsEvent.TYPE, "", )
+ power_level_event = event.old_state_events.get(key)
+ level = None
+ if power_level_event:
+ level = power_level_event.content.get("users", {}).get(user_id)
+ if not level:
+ level = power_level_event.content.get("users_default", 0)
+
+ return level
@defer.inlineCallbacks
- def _do_conflict_res(self, new_branch, current_branch, common_ancestor):
- conflict_res = [
- self._do_power_level_conflict_res,
- self._do_chain_length_conflict_res,
- self._do_hash_conflict_res,
- ]
-
- for algo in conflict_res:
- new_res, curr_res = yield defer.maybeDeferred(
- algo,
- new_branch, current_branch, common_ancestor
- )
+ @log_function
+ def _resolve_state_events(self, events):
+ curr_events = events
- if new_res < curr_res:
- defer.returnValue(False)
- elif new_res > curr_res:
- defer.returnValue(True)
+ new_powers = [
+ self._get_power_level_from_event_state(e, e.user_id)
+ for e in curr_events
+ ]
- raise Exception("Conflict resolution failed.")
+ new_powers = [
+ int(p) if p else 0 for p in new_powers
+ ]
- @defer.inlineCallbacks
- def _do_power_level_conflict_res(self, new_branch, current_branch,
- common_ancestor):
- new_powers_deferreds = []
- for e in new_branch[:-1] if common_ancestor else new_branch:
- if hasattr(e, "user_id"):
- new_powers_deferreds.append(
- self.store.get_power_level(e.context, e.user_id)
- )
-
- current_powers_deferreds = []
- for e in current_branch[:-1] if common_ancestor else current_branch:
- if hasattr(e, "user_id"):
- current_powers_deferreds.append(
- self.store.get_power_level(e.context, e.user_id)
- )
-
- new_powers = yield defer.gatherResults(
- new_powers_deferreds,
- consumeErrors=True
- )
+ max_power = max(new_powers)
- current_powers = yield defer.gatherResults(
- current_powers_deferreds,
- consumeErrors=True
- )
+ curr_events = [
+ z[0] for z in zip(curr_events, new_powers)
+ if z[1] == max_power
+ ]
- max_power_new = max(new_powers)
- max_power_current = max(current_powers)
+ if not curr_events:
+ raise RuntimeError("Max didn't get a max?")
+ elif len(curr_events) == 1:
+ defer.returnValue(curr_events[0])
+ # TODO: For now, just choose the one with the largest event_id.
defer.returnValue(
- (max_power_new, max_power_current)
- )
-
- def _do_chain_length_conflict_res(self, new_branch, current_branch,
- common_ancestor):
- return (len(new_branch), len(current_branch))
-
- def _do_hash_conflict_res(self, new_branch, current_branch,
- common_ancestor):
- new_str = "".join([p.pdu_id + p.origin for p in new_branch])
- c_str = "".join([p.pdu_id + p.origin for p in current_branch])
-
- return (
- hashlib.sha1(new_str).hexdigest(),
- hashlib.sha1(c_str).hexdigest()
+ sorted(
+ curr_events,
+ key=lambda e: hashlib.sha1(
+ e.event_id + e.user_id + e.room_id + e.type
+ ).hexdigest()
+ )[0]
)
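The resolution rule in `_resolve_state_events`, restated as a standalone sketch: the highest sender power level wins, and ties are broken by the smallest sha1 over the concatenated ids (`power_of` is a hypothetical lookup standing in for `_get_power_level_from_event_state`):

```python
import hashlib

def sketch_resolve(events, power_of):
    # Keep only the events whose sender has the maximum power level.
    powers = [int(power_of(e) or 0) for e in events]
    top = [e for e, p in zip(events, powers) if p == max(powers)]

    # Deterministic tiebreak: lexicographically smallest sha1 of
    # event_id + user_id + room_id + type.
    return sorted(
        top,
        key=lambda e: hashlib.sha1(
            (e.event_id + e.user_id + e.room_id + e.type).encode("utf-8")
        ).hexdigest(),
    )[0]
```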
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 4e9291fdff..96adf20c89 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -17,13 +17,8 @@ from twisted.internet import defer
from synapse.api.events.room import (
RoomMemberEvent, RoomTopicEvent, FeedbackEvent,
-# RoomConfigEvent,
RoomNameEvent,
RoomJoinRulesEvent,
- RoomPowerLevelsEvent,
- RoomAddStateLevelEvent,
- RoomSendEventLevelEvent,
- RoomOpsPowerLevelsEvent,
RoomRedactionEvent,
)
@@ -37,9 +32,17 @@ from .registration import RegistrationStore
from .room import RoomStore
from .roommember import RoomMemberStore
from .stream import StreamStore
-from .pdu import StatePduStore, PduStore, PdusTable
from .transactions import TransactionStore
from .keys import KeyStore
+from .event_federation import EventFederationStore
+
+from .state import StateStore
+from .signatures import SignatureStore
+
+from syutil.base64util import decode_base64
+
+from synapse.crypto.event_signing import compute_event_reference_hash
+
import json
import logging
@@ -51,7 +54,6 @@ logger = logging.getLogger(__name__)
SCHEMAS = [
"transactions",
- "pdu",
"users",
"profiles",
"presence",
@@ -59,6 +61,9 @@ SCHEMAS = [
"room_aliases",
"keys",
"redactions",
+ "state",
+ "event_edges",
+ "event_signatures",
]
@@ -73,10 +78,12 @@ class _RollbackButIsFineException(Exception):
"""
pass
+
class DataStore(RoomMemberStore, RoomStore,
RegistrationStore, StreamStore, ProfileStore, FeedbackStore,
- PresenceStore, PduStore, StatePduStore, TransactionStore,
- DirectoryStore, KeyStore):
+ PresenceStore, TransactionStore,
+ DirectoryStore, KeyStore, StateStore, SignatureStore,
+ EventFederationStore, ):
def __init__(self, hs):
super(DataStore, self).__init__(hs)
@@ -99,6 +106,7 @@ class DataStore(RoomMemberStore, RoomStore,
try:
yield self.runInteraction(
+ "persist_event",
self._persist_pdu_event_txn,
pdu=pdu,
event=event,
@@ -119,7 +127,8 @@ class DataStore(RoomMemberStore, RoomStore,
"type",
"room_id",
"content",
- "unrecognized_keys"
+ "unrecognized_keys",
+ "depth",
],
allow_none=allow_none,
)
@@ -133,39 +142,12 @@ class DataStore(RoomMemberStore, RoomStore,
def _persist_pdu_event_txn(self, txn, pdu=None, event=None,
backfilled=False, stream_ordering=None,
is_new_state=True):
- if pdu is not None:
- self._persist_event_pdu_txn(txn, pdu)
if event is not None:
return self._persist_event_txn(
txn, event, backfilled, stream_ordering,
is_new_state=is_new_state,
)
- def _persist_event_pdu_txn(self, txn, pdu):
- cols = dict(pdu.__dict__)
- unrec_keys = dict(pdu.unrecognized_keys)
- del cols["content"]
- del cols["prev_pdus"]
- cols["content_json"] = json.dumps(pdu.content)
-
- unrec_keys.update({
- k: v for k, v in cols.items()
- if k not in PdusTable.fields
- })
-
- cols["unrecognized_keys"] = json.dumps(unrec_keys)
-
- cols["ts"] = cols.pop("origin_server_ts")
-
- logger.debug("Persisting: %s", repr(cols))
-
- if pdu.is_state:
- self._persist_state_txn(txn, pdu.prev_pdus, cols)
- else:
- self._persist_pdu_txn(txn, pdu.prev_pdus, cols)
-
- self._update_min_depth_for_context_txn(txn, pdu.context, pdu.depth)
-
@log_function
def _persist_event_txn(self, txn, event, backfilled, stream_ordering=None,
is_new_state=True):
@@ -179,17 +161,13 @@ class DataStore(RoomMemberStore, RoomStore,
self._store_room_topic_txn(txn, event)
elif event.type == RoomJoinRulesEvent.TYPE:
self._store_join_rule(txn, event)
- elif event.type == RoomPowerLevelsEvent.TYPE:
- self._store_power_levels(txn, event)
- elif event.type == RoomAddStateLevelEvent.TYPE:
- self._store_add_state_level(txn, event)
- elif event.type == RoomSendEventLevelEvent.TYPE:
- self._store_send_event_level(txn, event)
- elif event.type == RoomOpsPowerLevelsEvent.TYPE:
- self._store_ops_level(txn, event)
elif event.type == RoomRedactionEvent.TYPE:
self._store_redaction(txn, event)
+ outlier = False
+ if hasattr(event, "outlier"):
+ outlier = event.outlier
+
vals = {
"topological_ordering": event.depth,
"event_id": event.event_id,
@@ -197,25 +175,33 @@ class DataStore(RoomMemberStore, RoomStore,
"room_id": event.room_id,
"content": json.dumps(event.content),
"processed": True,
+ "outlier": outlier,
+ "depth": event.depth,
}
if stream_ordering is not None:
vals["stream_ordering"] = stream_ordering
- if hasattr(event, "outlier"):
- vals["outlier"] = event.outlier
- else:
- vals["outlier"] = False
-
unrec = {
k: v
for k, v in event.get_full_dict().items()
- if k not in vals.keys() and k not in ["redacted", "redacted_because"]
+ if k not in vals.keys() and k not in [
+ "redacted",
+ "redacted_because",
+ "signatures",
+ "hashes",
+ "prev_events",
+ ]
}
vals["unrecognized_keys"] = json.dumps(unrec)
try:
- self._simple_insert_txn(txn, "events", vals)
+ self._simple_insert_txn(
+ txn,
+ "events",
+ vals,
+ or_replace=(not outlier),
+ )
except:
logger.warn(
"Failed to persist, probably duplicate: %s",
@@ -224,6 +210,16 @@ class DataStore(RoomMemberStore, RoomStore,
)
raise _RollbackButIsFineException("_persist_event")
+ self._handle_prev_events(
+ txn,
+ outlier=outlier,
+ event_id=event.event_id,
+ prev_events=event.prev_events,
+ room_id=event.room_id,
+ )
+
+ self._store_state_groups_txn(txn, event)
+
is_state = hasattr(event, "state_key") and event.state_key is not None
if is_new_state and is_state:
vals = {
@@ -233,8 +229,8 @@ class DataStore(RoomMemberStore, RoomStore,
"state_key": event.state_key,
}
- if hasattr(event, "prev_state"):
- vals["prev_state"] = event.prev_state
+ if hasattr(event, "replaces_state"):
+ vals["prev_state"] = event.replaces_state
self._simple_insert_txn(txn, "state_events", vals)
@@ -249,6 +245,81 @@ class DataStore(RoomMemberStore, RoomStore,
}
)
+ for e_id, h in event.prev_state:
+ self._simple_insert_txn(
+ txn,
+ table="event_edges",
+ values={
+ "event_id": event.event_id,
+ "prev_event_id": e_id,
+ "room_id": event.room_id,
+ "is_state": 1,
+ },
+ or_ignore=True,
+ )
+
+ if not backfilled:
+ self._simple_insert_txn(
+ txn,
+ table="state_forward_extremities",
+ values={
+ "event_id": event.event_id,
+ "room_id": event.room_id,
+ "type": event.type,
+ "state_key": event.state_key,
+ }
+ )
+
+ for prev_state_id, _ in event.prev_state:
+ self._simple_delete_txn(
+ txn,
+ table="state_forward_extremities",
+ keyvalues={
+ "event_id": prev_state_id,
+ }
+ )
+
+ for hash_alg, hash_base64 in event.hashes.items():
+ hash_bytes = decode_base64(hash_base64)
+ self._store_event_content_hash_txn(
+ txn, event.event_id, hash_alg, hash_bytes,
+ )
+
+ if hasattr(event, "signatures"):
+ signatures = event.signatures.get(event.origin, {})
+
+ for key_id, signature_base64 in signatures.items():
+ signature_bytes = decode_base64(signature_base64)
+ self._store_event_origin_signature_txn(
+ txn, event.event_id, event.origin, key_id, signature_bytes,
+ )
+
+ for prev_event_id, prev_hashes in event.prev_events:
+ for alg, hash_base64 in prev_hashes.items():
+ hash_bytes = decode_base64(hash_base64)
+ self._store_prev_event_hash_txn(
+ txn, event.event_id, prev_event_id, alg, hash_bytes
+ )
+
+ for auth_id, _ in event.auth_events:
+ self._simple_insert_txn(
+ txn,
+ table="event_auth",
+ values={
+ "event_id": event.event_id,
+ "room_id": event.room_id,
+ "auth_id": auth_id,
+ },
+ or_ignore=True,
+ )
+
+ (ref_alg, ref_hash_bytes) = compute_event_reference_hash(event)
+ self._store_event_reference_hash_txn(
+ txn, event.event_id, ref_alg, ref_hash_bytes
+ )
+
+ self._update_min_depth_for_room_txn(txn, event.room_id, event.depth)
+
def _store_redaction(self, txn, event):
txn.execute(
"INSERT OR IGNORE INTO redactions "
@@ -319,7 +390,7 @@ class DataStore(RoomMemberStore, RoomStore,
],
)
- def snapshot_room(self, room_id, user_id, state_type=None, state_key=None):
+ def snapshot_room(self, event):
"""Snapshot the room for an update by a user
Args:
- room_id (synapse.types.RoomId): The room to snapshot.
+ event (SynapseEvent): The event to snapshot the room for.
@@ -330,29 +401,33 @@ class DataStore(RoomMemberStore, RoomStore,
synapse.storage.Snapshot: A snapshot of the state of the room.
"""
def _snapshot(txn):
- membership_state = self._get_room_member(txn, user_id, room_id)
- prev_pdus = self._get_latest_pdus_in_context(
- txn, room_id
+ prev_events = self._get_latest_events_in_room(
+ txn,
+ event.room_id
)
- if state_type is not None and state_key is not None:
- prev_state_pdu = self._get_current_state_pdu(
- txn, room_id, state_type, state_key
+
+ prev_state = None
+ state_key = None
+ if hasattr(event, "state_key"):
+ state_key = event.state_key
+ prev_state = self._get_latest_state_in_room(
+ txn,
+ event.room_id,
+ type=event.type,
+ state_key=state_key,
)
- else:
- prev_state_pdu = None
return Snapshot(
store=self,
- room_id=room_id,
- user_id=user_id,
- prev_pdus=prev_pdus,
- membership_state=membership_state,
- state_type=state_type,
+ room_id=event.room_id,
+ user_id=event.user_id,
+ prev_events=prev_events,
+ prev_state=prev_state,
+ state_type=event.type,
state_key=state_key,
- prev_state_pdu=prev_state_pdu,
)
- return self.runInteraction(_snapshot)
+ return self.runInteraction("snapshot_room", _snapshot)
class Snapshot(object):
@@ -361,7 +436,7 @@ class Snapshot(object):
store (DataStore): The datastore.
room_id (RoomId): The room of the snapshot.
user_id (UserId): The user this snapshot is for.
- prev_pdus (list): The list of PDU ids this snapshot is after.
+ prev_events (list): The list of event ids this snapshot is after.
- membership_state (RoomMemberEvent): The current state of the user in
- the room.
+ prev_state (list): The (event_id, hash) pairs of the latest state
+ event matching (state_type, state_key), if any.
state_type (str, optional): State type captured by the snapshot
@@ -370,32 +445,30 @@ class Snapshot(object):
the previous value of the state type and key in the room.
"""
- def __init__(self, store, room_id, user_id, prev_pdus,
- membership_state, state_type=None, state_key=None,
- prev_state_pdu=None):
+ def __init__(self, store, room_id, user_id, prev_events,
+ prev_state, state_type=None, state_key=None):
self.store = store
self.room_id = room_id
self.user_id = user_id
- self.prev_pdus = prev_pdus
- self.membership_state = membership_state
+ self.prev_events = prev_events
+ self.prev_state = prev_state
self.state_type = state_type
self.state_key = state_key
- self.prev_state_pdu = prev_state_pdu
def fill_out_prev_events(self, event):
- if hasattr(event, "prev_events"):
- return
-
- es = [
- "%s@%s" % (p_id, origin) for p_id, origin, _ in self.prev_pdus
- ]
-
- event.prev_events = [e for e in es if e != event.event_id]
+ if not hasattr(event, "prev_events"):
+ event.prev_events = [
+ (event_id, hashes)
+ for event_id, hashes, _ in self.prev_events
+ ]
+
+ if self.prev_events:
+ event.depth = max([int(v) for _, _, v in self.prev_events]) + 1
+ else:
+ event.depth = 0
- if self.prev_pdus:
- event.depth = max([int(v) for _, _, v in self.prev_pdus]) + 1
- else:
- event.depth = 0
+ if not hasattr(event, "prev_state") and self.prev_state is not None:
+ event.prev_state = self.prev_state
def schema_path(schema):
@@ -452,9 +525,10 @@ def prepare_database(db_conn):
db_conn.commit()
else:
- sql_script = "BEGIN TRANSACTION;"
+ sql_script = "BEGIN TRANSACTION;\n"
for sql_loc in SCHEMAS:
sql_script += read_schema(sql_loc)
+ sql_script += "\n"
sql_script += "COMMIT TRANSACTION;"
c.executescript(sql_script)
db_conn.commit()
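`_handle_prev_events` (called above but defined elsewhere) maintains the forward-extremity set that `snapshot_room` and `get_latest_events_in_room` read. The idea, sketched under an assumed table name:

```python
def sketch_update_forward_extremities(txn, event):
    # The new event is now a frontier node of the room's event graph...
    txn.execute(
        "INSERT OR IGNORE INTO event_forward_extremities"
        " (event_id, room_id) VALUES (?, ?)",
        (event.event_id, event.room_id),
    )
    # ...and every event it points back at no longer is.
    for prev_event_id, _hashes in event.prev_events:
        txn.execute(
            "DELETE FROM event_forward_extremities WHERE event_id = ?",
            (prev_event_id,),
        )
```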
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 65a86e9056..9aa404695d 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -14,59 +14,69 @@
# limitations under the License.
import logging
-from twisted.internet import defer
-
from synapse.api.errors import StoreError
from synapse.api.events.utils import prune_event
from synapse.util.logutils import log_function
+from syutil.base64util import encode_base64
import collections
import copy
import json
+import sys
+import time
logger = logging.getLogger(__name__)
sql_logger = logging.getLogger("synapse.storage.SQL")
+transaction_logger = logging.getLogger("synapse.storage.txn")
class LoggingTransaction(object):
"""An object that almost-transparently proxies for the 'txn' object
passed to the constructor. Adds logging to the .execute() method."""
- __slots__ = ["txn"]
+ __slots__ = ["txn", "name"]
- def __init__(self, txn):
+ def __init__(self, txn, name):
object.__setattr__(self, "txn", txn)
+ object.__setattr__(self, "name", name)
- def __getattribute__(self, name):
- if name == "execute":
- return object.__getattribute__(self, "execute")
-
- return getattr(object.__getattribute__(self, "txn"), name)
+ def __getattr__(self, name):
+ return getattr(self.txn, name)
def __setattr__(self, name, value):
- setattr(object.__getattribute__(self, "txn"), name, value)
+ setattr(self.txn, name, value)
def execute(self, sql, *args, **kwargs):
# TODO(paul): Maybe use 'info' and 'debug' for values?
- sql_logger.debug("[SQL] %s", sql)
+ sql_logger.debug("[SQL] {%s} %s", self.name, sql)
try:
if args and args[0]:
values = args[0]
- sql_logger.debug("[SQL values] " +
- ", ".join(("<%s>",) * len(values)), *values)
+ sql_logger.debug(
+ "[SQL values] {%s} " + ", ".join(("<%s>",) * len(values)),
+ self.name,
+ *values
+ )
except:
# Don't let logging failures stop SQL from working
pass
- # TODO(paul): Here would be an excellent place to put some timing
- # measurements, and log (warning?) slow queries.
- return object.__getattribute__(self, "txn").execute(
- sql, *args, **kwargs
- )
+ start = time.clock() * 1000
+ try:
+ return self.txn.execute(
+ sql, *args, **kwargs
+ )
+ except:
+ logger.exception("[SQL FAIL] {%s}", self.name)
+ raise
+ finally:
+ end = time.clock() * 1000
+ sql_logger.debug("[SQL time] {%s} %f", self.name, end - start)
class SQLBaseStore(object):
+ _TXN_ID = 0
def __init__(self, hs):
self.hs = hs
@@ -74,10 +84,30 @@ class SQLBaseStore(object):
self.event_factory = hs.get_event_factory()
self._clock = hs.get_clock()
- def runInteraction(self, func, *args, **kwargs):
+ def runInteraction(self, desc, func, *args, **kwargs):
"""Wraps the .runInteraction() method on the underlying db_pool."""
def inner_func(txn, *args, **kwargs):
- return func(LoggingTransaction(txn), *args, **kwargs)
+ start = time.clock() * 1000
+ txn_id = SQLBaseStore._TXN_ID
+
+ # We don't really need these to be unique, so lets stop it from
+ # growing really large.
+ SQLBaseStore._TXN_ID = (SQLBaseStore._TXN_ID + 1) % (sys.maxint - 1)
+
+ name = "%s-%x" % (desc, txn_id, )
+
+ transaction_logger.debug("[TXN START] {%s}", name)
+ try:
+ return func(LoggingTransaction(txn, name), *args, **kwargs)
+ except:
+ logger.exception("[TXN FAIL] {%s}", name)
+ raise
+ finally:
+ end = time.clock() * 1000
+ transaction_logger.debug(
+ "[TXN END] {%s} %f",
+ name, end - start
+ )
return self._db_pool.runInteraction(inner_func, *args, **kwargs)
@@ -113,7 +143,7 @@ class SQLBaseStore(object):
else:
return cursor.fetchall()
- return self.runInteraction(interaction)
+ return self.runInteraction("_execute", interaction)
def _execute_and_decode(self, query, *args):
return self._execute(self.cursor_to_dict, query, *args)
@@ -130,6 +160,7 @@ class SQLBaseStore(object):
or_replace : bool; if True performs an INSERT OR REPLACE
"""
return self.runInteraction(
+ "_simple_insert",
self._simple_insert_txn, table, values, or_replace=or_replace,
or_ignore=or_ignore,
)
@@ -170,7 +201,6 @@ class SQLBaseStore(object):
table, keyvalues, retcols=retcols, allow_none=allow_none
)
- @defer.inlineCallbacks
def _simple_select_one_onecol(self, table, keyvalues, retcol,
allow_none=False):
"""Executes a SELECT query on the named table, which is expected to
@@ -181,19 +211,40 @@ class SQLBaseStore(object):
keyvalues : dict of column names and values to select the row with
retcol : string giving the name of the column to return
"""
- ret = yield self._simple_select_one(
+ return self.runInteraction(
+ "_simple_select_one_onecol_txn",
+ self._simple_select_one_onecol_txn,
+ table, keyvalues, retcol, allow_none=allow_none,
+ )
+
+ def _simple_select_one_onecol_txn(self, txn, table, keyvalues, retcol,
+ allow_none=False):
+ ret = self._simple_select_onecol_txn(
+ txn,
table=table,
keyvalues=keyvalues,
- retcols=[retcol],
- allow_none=allow_none
+ retcol=retcol,
)
if ret:
- defer.returnValue(ret[retcol])
+ return ret[0]
else:
- defer.returnValue(None)
+ if allow_none:
+ return None
+ else:
+ raise StoreError(404, "No row found")
+
+ def _simple_select_onecol_txn(self, txn, table, keyvalues, retcol):
+ sql = "SELECT %(retcol)s FROM %(table)s WHERE %(where)s" % {
+ "retcol": retcol,
+ "table": table,
+ "where": " AND ".join("%s = ?" % k for k in keyvalues.keys()),
+ }
+
+ txn.execute(sql, keyvalues.values())
+
+ return [r[0] for r in txn.fetchall()]
- @defer.inlineCallbacks
def _simple_select_onecol(self, table, keyvalues, retcol):
"""Executes a SELECT query on the named table, which returns a list
comprising of the values of the named column from the selected rows.
@@ -206,25 +257,33 @@ class SQLBaseStore(object):
Returns:
Deferred: Results in a list
"""
- sql = "SELECT %(retcol)s FROM %(table)s WHERE %(where)s" % {
- "retcol": retcol,
- "table": table,
- "where": " AND ".join("%s = ?" % k for k in keyvalues.keys()),
- }
-
- def func(txn):
- txn.execute(sql, keyvalues.values())
- return txn.fetchall()
+ return self.runInteraction(
+ "_simple_select_onecol",
+ self._simple_select_onecol_txn,
+ table, keyvalues, retcol
+ )
- res = yield self.runInteraction(func)
+ def _simple_select_list(self, table, keyvalues, retcols):
+ """Executes a SELECT query on the named table, which may return zero or
+ more rows, returning the result as a list of dicts.
- defer.returnValue([r[0] for r in res])
+ Args:
+ table : string giving the table name
+ keyvalues : dict of column names and values to select the rows with
+ retcols : list of strings giving the names of the columns to return
+ """
+ return self.runInteraction(
+ "_simple_select_list",
+ self._simple_select_list_txn,
+ table, keyvalues, retcols
+ )
- def _simple_select_list(self, table, keyvalues, retcols):
+ def _simple_select_list_txn(self, txn, table, keyvalues, retcols):
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.
Args:
+ txn : Transaction object
table : string giving the table name
keyvalues : dict of column names and values to select the rows with
retcols : list of strings giving the names of the columns to return
@@ -232,14 +291,11 @@ class SQLBaseStore(object):
sql = "SELECT %s FROM %s WHERE %s" % (
", ".join(retcols),
table,
- " AND ".join("%s = ?" % (k) for k in keyvalues)
+ " AND ".join("%s = ?" % (k, ) for k in keyvalues)
)
- def func(txn):
- txn.execute(sql, keyvalues.values())
- return self.cursor_to_dict(txn)
-
- return self.runInteraction(func)
+ txn.execute(sql, keyvalues.values())
+ return self.cursor_to_dict(txn)
def _simple_update_one(self, table, keyvalues, updatevalues,
retcols=None):
@@ -307,7 +363,7 @@ class SQLBaseStore(object):
raise StoreError(500, "More than one row matched")
return ret
- return self.runInteraction(func)
+ return self.runInteraction("_simple_selectupdate_one", func)
def _simple_delete_one(self, table, keyvalues):
"""Executes a DELETE query on the named table, expecting to delete a
@@ -319,7 +375,7 @@ class SQLBaseStore(object):
"""
sql = "DELETE FROM %s WHERE %s" % (
table,
- " AND ".join("%s = ?" % (k) for k in keyvalues)
+ " AND ".join("%s = ?" % (k, ) for k in keyvalues)
)
def func(txn):
@@ -328,7 +384,25 @@ class SQLBaseStore(object):
raise StoreError(404, "No row found")
if txn.rowcount > 1:
raise StoreError(500, "more than one row matched")
- return self.runInteraction(func)
+ return self.runInteraction("_simple_delete_one", func)
+
+ def _simple_delete(self, table, keyvalues):
+ """Executes a DELETE query on the named table.
+
+ Args:
+ table : string giving the table name
+ keyvalues : dict of column names and values to select the row with
+ """
+
+ return self.runInteraction("_simple_delete", self._simple_delete_txn)
+
+ def _simple_delete_txn(self, txn, table, keyvalues):
+ sql = "DELETE FROM %s WHERE %s" % (
+ table,
+ " AND ".join("%s = ?" % (k, ) for k in keyvalues)
+ )
+
+ return txn.execute(sql, keyvalues.values())
def _simple_max_id(self, table):
"""Executes a SELECT query on the named table, expecting to return the
@@ -346,7 +420,7 @@ class SQLBaseStore(object):
return 0
return max_id
- return self.runInteraction(func)
+ return self.runInteraction("_simple_max_id", func)
def _parse_event_from_row(self, row_dict):
d = copy.deepcopy({k: v for k, v in row_dict.items()})
@@ -355,6 +429,10 @@ class SQLBaseStore(object):
d.pop("topological_ordering", None)
d.pop("processed", None)
d["origin_server_ts"] = d.pop("ts", 0)
+ replaces_state = d.pop("prev_state", None)
+
+ if replaces_state:
+ d["replaces_state"] = replaces_state
d.update(json.loads(row_dict["unrecognized_keys"]))
d["content"] = json.loads(d["content"])
@@ -370,22 +448,52 @@ class SQLBaseStore(object):
)
def _parse_events(self, rows):
- return self.runInteraction(self._parse_events_txn, rows)
+ return self.runInteraction(
+ "_parse_events", self._parse_events_txn, rows
+ )
def _parse_events_txn(self, txn, rows):
events = [self._parse_event_from_row(r) for r in rows]
- sql = "SELECT * FROM events WHERE event_id = ?"
+ select_event_sql = "SELECT * FROM events WHERE event_id = ?"
for ev in events:
- if hasattr(ev, "prev_state"):
- # Load previous state_content.
- # TODO: Should we be pulling this out above?
- cursor = txn.execute(sql, (ev.prev_state,))
- prevs = self.cursor_to_dict(cursor)
- if prevs:
- prev = self._parse_event_from_row(prevs[0])
- ev.prev_content = prev.content
+ signatures = self._get_event_origin_signatures_txn(
+ txn, ev.event_id,
+ )
+
+ ev.signatures = {
+ k: encode_base64(v) for k, v in signatures.items()
+ }
+
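+ # event_edges rows carry an is_state flag: 0 for ordinary prev_events
+ # edges, 1 for prev_state edges.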
+ prevs = self._get_prev_events_and_state(txn, ev.event_id)
+
+ ev.prev_events = [
+ (e_id, h)
+ for e_id, h, is_state in prevs
+ if is_state == 0
+ ]
+
+ ev.auth_events = self._get_auth_events(txn, ev.event_id)
+
+ if hasattr(ev, "state_key"):
+ ev.prev_state = [
+ (e_id, h)
+ for e_id, h, is_state in prevs
+ if is_state == 1
+ ]
+
+ if hasattr(ev, "replaces_state"):
+ # Load previous state_content.
+ # FIXME (erikj): Handle multiple prev_states.
+ cursor = txn.execute(
+ select_event_sql,
+ (ev.replaces_state,)
+ )
+ prevs = self.cursor_to_dict(cursor)
+ if prevs:
+ prev = self._parse_event_from_row(prevs[0])
+ ev.prev_content = prev.content
if not hasattr(ev, "redacted"):
logger.debug("Doesn't have redacted key: %s", ev)
@@ -393,8 +501,8 @@ class SQLBaseStore(object):
if ev.redacted:
# Get the redaction event.
- sql = "SELECT * FROM events WHERE event_id = ?"
- txn.execute(sql, (ev.redacted,))
+ select_event_sql = "SELECT * FROM events WHERE event_id = ?"
+ txn.execute(select_event_sql, (ev.redacted,))
del_evs = self._parse_events_txn(
txn, self.cursor_to_dict(txn)
diff --git a/synapse/storage/directory.py b/synapse/storage/directory.py
index 52373a28a6..d6a7113b9c 100644
--- a/synapse/storage/directory.py
+++ b/synapse/storage/directory.py
@@ -95,6 +95,7 @@ class DirectoryStore(SQLBaseStore):
def delete_room_alias(self, room_alias):
return self.runInteraction(
+ "delete_room_alias",
self._delete_room_alias_txn,
room_alias,
)
diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py
new file mode 100644
index 0000000000..7140ea3d57
--- /dev/null
+++ b/synapse/storage/event_federation.py
@@ -0,0 +1,371 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ._base import SQLBaseStore
+from syutil.base64util import encode_base64
+
+import logging
+
+
+logger = logging.getLogger(__name__)
+
+
+class EventFederationStore(SQLBaseStore):
+
+ def get_auth_chain(self, event_id):
+ return self.runInteraction(
+ "get_auth_chain",
+ self._get_auth_chain_txn,
+ event_id
+ )
+
+ def _get_auth_chain_txn(self, txn, event_id):
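+ # Breadth-first walk of the event_auth graph: start at the given event
+ # and repeatedly pull in the auth events of everything found so far.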
+ results = set([event_id])
+
+ front = set([event_id])
+ while front:
+ new_front = set()
+ for ev_id in front:
+ auth_ids = self._simple_select_onecol_txn(
+ txn,
+ table="event_auth",
+ keyvalues={
+ "event_id": ev_id,
+ },
+ retcol="auth_id",
+ )
+
+ new_front.update(auth_ids)
+
+ new_front -= results
+ results.update(new_front)
+ front = new_front
+
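+ # Now load and parse the full events for all the ids we collected.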
+ sql = "SELECT * FROM events WHERE event_id = ?"
+ rows = []
+ for ev_id in results:
+ c = txn.execute(sql, (ev_id,))
+ rows.extend(self.cursor_to_dict(c))
+
+ return self._parse_events_txn(txn, rows)
+
+ def get_oldest_events_in_room(self, room_id):
+ return self.runInteraction(
+ "get_oldest_events_in_room",
+ self._get_oldest_events_in_room_txn,
+ room_id,
+ )
+
+ def _get_oldest_events_in_room_txn(self, txn, room_id):
+ return self._simple_select_onecol_txn(
+ txn,
+ table="event_backward_extremities",
+ keyvalues={
+ "room_id": room_id,
+ },
+ retcol="event_id",
+ )
+
+ def get_latest_events_in_room(self, room_id):
+ return self.runInteraction(
+ "get_latest_events_in_room",
+ self._get_latest_events_in_room,
+ room_id,
+ )
+
+ def _get_latest_events_in_room(self, txn, room_id):
+ sql = (
+ "SELECT e.event_id, e.depth FROM events as e "
+ "INNER JOIN event_forward_extremities as f "
+ "ON e.event_id = f.event_id "
+ "WHERE f.room_id = ?"
+ )
+
+ txn.execute(sql, (room_id, ))
+
+ results = []
+ for event_id, depth in txn.fetchall():
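+ # Keep only the sha256 reference hash for each forward extremity;
+ # these (event_id, hashes, depth) tuples are used to fill out the
+ # prev_events of new events.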
+ hashes = self._get_event_reference_hashes_txn(txn, event_id)
+ prev_hashes = {
+ k: encode_base64(v) for k, v in hashes.items()
+ if k == "sha256"
+ }
+ results.append((event_id, prev_hashes, depth))
+
+ return results
+
+ def _get_latest_state_in_room(self, txn, room_id, type, state_key):
+ event_ids = self._simple_select_onecol_txn(
+ txn,
+ table="state_forward_extremities",
+ keyvalues={
+ "room_id": room_id,
+ "type": type,
+ "state_key": state_key,
+ },
+ retcol="event_id",
+ )
+
+ results = []
+ for event_id in event_ids:
+ hashes = self._get_event_reference_hashes_txn(txn, event_id)
+ prev_hashes = {
+ k: encode_base64(v) for k, v in hashes.items()
+ if k == "sha256"
+ }
+ results.append((event_id, prev_hashes))
+
+ return results
+
+ def _get_prev_events(self, txn, event_id):
+ results = self._get_prev_events_and_state(
+ txn,
+ event_id,
+ is_state=0,
+ )
+
+ return [(e_id, h, ) for e_id, h, _ in results]
+
+ def _get_prev_state(self, txn, event_id):
+ results = self._get_prev_events_and_state(
+ txn,
+ event_id,
+ is_state=1,
+ )
+
+ return [(e_id, h, ) for e_id, h, _ in results]
+
+ def _get_prev_events_and_state(self, txn, event_id, is_state=None):
+ keyvalues = {
+ "event_id": event_id,
+ }
+
+ if is_state is not None:
+ keyvalues["is_state"] = is_state
+
+ res = self._simple_select_list_txn(
+ txn,
+ table="event_edges",
+ keyvalues=keyvalues,
+ retcols=["prev_event_id", "is_state"],
+ )
+
+ results = []
+ for d in res:
+ hashes = self._get_event_reference_hashes_txn(
+ txn,
+ d["prev_event_id"]
+ )
+ prev_hashes = {
+ k: encode_base64(v) for k, v in hashes.items()
+ if k == "sha256"
+ }
+ results.append((d["prev_event_id"], prev_hashes, d["is_state"]))
+
+ return results
+
+ def _get_auth_events(self, txn, event_id):
+ auth_ids = self._simple_select_onecol_txn(
+ txn,
+ table="event_auth",
+ keyvalues={
+ "event_id": event_id,
+ },
+ retcol="auth_id",
+ )
+
+ results = []
+ for auth_id in auth_ids:
+ hashes = self._get_event_reference_hashes_txn(txn, auth_id)
+ prev_hashes = {
+ k: encode_base64(v) for k, v in hashes.items()
+ if k == "sha256"
+ }
+ results.append((auth_id, prev_hashes))
+
+ return results
+
+ def get_min_depth(self, room_id):
+ return self.runInteraction(
+ "get_min_depth",
+ self._get_min_depth_interaction,
+ room_id,
+ )
+
+ def _get_min_depth_interaction(self, txn, room_id):
+ min_depth = self._simple_select_one_onecol_txn(
+ txn,
+ table="room_depth",
+ keyvalues={"room_id": room_id,},
+ retcol="min_depth",
+ allow_none=True,
+ )
+
+ return int(min_depth) if min_depth is not None else None
+
+ def _update_min_depth_for_room_txn(self, txn, room_id, depth):
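+ # min_depth records how far back we have backfilled in this room, so
+ # it should only ever move downwards.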
+ min_depth = self._get_min_depth_interaction(txn, room_id)
+
+ do_insert = depth < min_depth if min_depth is not None else True
+
+ if do_insert:
+ self._simple_insert_txn(
+ txn,
+ table="room_depth",
+ values={
+ "room_id": room_id,
+ "min_depth": depth,
+ },
+ or_replace=True,
+ )
+
+ def _handle_prev_events(self, txn, outlier, event_id, prev_events,
+ room_id):
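+ # Record an edge to each prev_event, then update the forward and
+ # backward extremity tables to match the new shape of the room graph.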
+ for e_id, _ in prev_events:
+ # TODO (erikj): This could be done as a bulk insert
+ self._simple_insert_txn(
+ txn,
+ table="event_edges",
+ values={
+ "event_id": event_id,
+ "prev_event_id": e_id,
+ "room_id": room_id,
+ "is_state": 0,
+ },
+ or_ignore=True,
+ )
+
+ # Update the extremities table if this is not an outlier.
+ if not outlier:
+ for e_id, _ in prev_events:
+ # TODO (erikj): This could be done as a bulk insert
+ self._simple_delete_txn(
+ txn,
+ table="event_forward_extremities",
+ keyvalues={
+ "event_id": e_id,
+ "room_id": room_id,
+ }
+ )
+
+ # We only insert the new event as a forward extremity if no other
+ # events already reference it as a prev event.
+ query = (
+ "INSERT OR IGNORE INTO %(table)s (event_id, room_id) "
+ "SELECT ?, ? WHERE NOT EXISTS ("
+ "SELECT 1 FROM %(event_edges)s WHERE "
+ "prev_event_id = ? "
+ ")"
+ ) % {
+ "table": "event_forward_extremities",
+ "event_edges": "event_edges",
+ }
+
+ logger.debug("query: %s", query)
+
+ txn.execute(query, (event_id, room_id, event_id))
+
+ # Insert all the prev_events as backward extremities; any that we
+ # have already seen are deleted again just below.
+ for e_id, _ in prev_events:
+ # TODO (erikj): This could be done as a bulk insert
+ self._simple_insert_txn(
+ txn,
+ table="event_backward_extremities",
+ values={
+ "event_id": e_id,
+ "room_id": room_id,
+ },
+ or_ignore=True,
+ )
+
+ # Also delete from the backward extremities table all entries that
+ # reference events that we have already seen.
+ query = (
+ "DELETE FROM event_backward_extremities WHERE EXISTS ("
+ "SELECT 1 FROM events "
+ "WHERE "
+ "event_backward_extremities.event_id = events.event_id "
+ "AND not events.outlier "
+ ")"
+ )
+ txn.execute(query)
+
+ def get_backfill_events(self, room_id, event_list, limit):
+ """Get a list of Events for a given topic that occured before (and
+ including) the pdus in pdu_list. Return a list of max size `limit`.
+
+ Args:
+ txn
+ room_id (str)
+ event_list (list)
+ limit (int)
+
+ Return:
+ list: A list of PduTuples
+ """
+ return self.runInteraction(
+ "get_backfill_events",
+ self._get_backfill_events, room_id, event_list, limit
+ )
+
+ def _get_backfill_events(self, txn, room_id, event_list, limit):
+ logger.debug(
+ "_get_backfill_events: %s, %s, %s",
+ room_id, repr(event_list), limit
+ )
+
+ # We seed event_results with the events from event_list; copy it so
+ # we don't mutate the caller's list.
+ event_results = list(event_list)
+
+ front = event_list
+
+ query = (
+ "SELECT prev_event_id FROM event_edges "
+ "WHERE room_id = ? AND event_id = ? "
+ "LIMIT ?"
+ )
+
+ # We iterate through all event_ids in `front` to select their previous
+ # events. These are dumped in `new_front`.
+ # We continue until we reach the limit *or* new_front is empty (i.e.,
+ # we've run out of things to select).
+ while front and len(event_results) < limit:
+
+ new_front = []
+ for event_id in front:
+ logger.debug(
+ "_backfill_interaction: id=%s",
+ event_id
+ )
+
+ txn.execute(
+ query,
+ (room_id, event_id, limit - len(event_results))
+ )
+
+ for row in txn.fetchall():
+ logger.debug(
+ "_backfill_interaction: got id=%s",
+ *row
+ )
+ new_front.append(row[0])
+
+ front = new_front
+ event_results += new_front
+
+ # Load and parse the full events before returning them.
+ select_event_sql = "SELECT * FROM events WHERE event_id = ?"
+ rows = []
+ for event_id in event_results:
+ c = txn.execute(select_event_sql, (event_id,))
+ rows.extend(self.cursor_to_dict(c))
+
+ return self._parse_events_txn(txn, rows)
diff --git a/synapse/storage/pdu.py b/synapse/storage/pdu.py
deleted file mode 100644
index d70467dcd6..0000000000
--- a/synapse/storage/pdu.py
+++ /dev/null
@@ -1,915 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2014 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from twisted.internet import defer
-
-from ._base import SQLBaseStore, Table, JoinHelper
-
-from synapse.federation.units import Pdu
-from synapse.util.logutils import log_function
-
-from collections import namedtuple
-
-import logging
-
-logger = logging.getLogger(__name__)
-
-
-class PduStore(SQLBaseStore):
- """A collection of queries for handling PDUs.
- """
-
- def get_pdu(self, pdu_id, origin):
- """Given a pdu_id and origin, get a PDU.
-
- Args:
- txn
- pdu_id (str)
- origin (str)
-
- Returns:
- PduTuple: If the pdu does not exist in the database, returns None
- """
-
- return self.runInteraction(
- self._get_pdu_tuple, pdu_id, origin
- )
-
- def _get_pdu_tuple(self, txn, pdu_id, origin):
- res = self._get_pdu_tuples(txn, [(pdu_id, origin)])
- return res[0] if res else None
-
- def _get_pdu_tuples(self, txn, pdu_id_tuples):
- results = []
- for pdu_id, origin in pdu_id_tuples:
- txn.execute(
- PduEdgesTable.select_statement("pdu_id = ? AND origin = ?"),
- (pdu_id, origin)
- )
-
- edges = [
- (r.prev_pdu_id, r.prev_origin)
- for r in PduEdgesTable.decode_results(txn.fetchall())
- ]
-
- query = (
- "SELECT %(fields)s FROM %(pdus)s as p "
- "LEFT JOIN %(state)s as s "
- "ON p.pdu_id = s.pdu_id AND p.origin = s.origin "
- "WHERE p.pdu_id = ? AND p.origin = ? "
- ) % {
- "fields": _pdu_state_joiner.get_fields(
- PdusTable="p", StatePdusTable="s"),
- "pdus": PdusTable.table_name,
- "state": StatePdusTable.table_name,
- }
-
- txn.execute(query, (pdu_id, origin))
-
- row = txn.fetchone()
- if row:
- results.append(PduTuple(PduEntry(*row), edges))
-
- return results
-
- def get_current_state_for_context(self, context):
- """Get a list of PDUs that represent the current state for a given
- context
-
- Args:
- context (str)
-
- Returns:
- list: A list of PduTuples
- """
-
- return self.runInteraction(
- self._get_current_state_for_context,
- context
- )
-
- def _get_current_state_for_context(self, txn, context):
- query = (
- "SELECT pdu_id, origin FROM %s WHERE context = ?"
- % CurrentStateTable.table_name
- )
-
- logger.debug("get_current_state %s, Args=%s", query, context)
- txn.execute(query, (context,))
-
- res = txn.fetchall()
-
- logger.debug("get_current_state %d results", len(res))
-
- return self._get_pdu_tuples(txn, res)
-
- def _persist_pdu_txn(self, txn, prev_pdus, cols):
- """Inserts a (non-state) PDU into the database.
-
- Args:
- txn,
- prev_pdus (list)
- **cols: The columns to insert into the PdusTable.
- """
- entry = PdusTable.EntryType(
- **{k: cols.get(k, None) for k in PdusTable.fields}
- )
-
- txn.execute(PdusTable.insert_statement(), entry)
-
- self._handle_prev_pdus(
- txn, entry.outlier, entry.pdu_id, entry.origin,
- prev_pdus, entry.context
- )
-
- def mark_pdu_as_processed(self, pdu_id, pdu_origin):
- """Mark a received PDU as processed.
-
- Args:
- txn
- pdu_id (str)
- pdu_origin (str)
- """
-
- return self.runInteraction(
- self._mark_as_processed, pdu_id, pdu_origin
- )
-
- def _mark_as_processed(self, txn, pdu_id, pdu_origin):
- txn.execute("UPDATE %s SET have_processed = 1" % PdusTable.table_name)
-
- def get_all_pdus_from_context(self, context):
- """Get a list of all PDUs for a given context."""
- return self.runInteraction(
- self._get_all_pdus_from_context, context,
- )
-
- def _get_all_pdus_from_context(self, txn, context):
- query = (
- "SELECT pdu_id, origin FROM %s "
- "WHERE context = ?"
- ) % PdusTable.table_name
-
- txn.execute(query, (context,))
-
- return self._get_pdu_tuples(txn, txn.fetchall())
-
- def get_backfill(self, context, pdu_list, limit):
- """Get a list of Pdus for a given topic that occured before (and
- including) the pdus in pdu_list. Return a list of max size `limit`.
-
- Args:
- txn
- context (str)
- pdu_list (list)
- limit (int)
-
- Return:
- list: A list of PduTuples
- """
- return self.runInteraction(
- self._get_backfill, context, pdu_list, limit
- )
-
- def _get_backfill(self, txn, context, pdu_list, limit):
- logger.debug(
- "backfill: %s, %s, %s",
- context, repr(pdu_list), limit
- )
-
- # We seed the pdu_results with the things from the pdu_list.
- pdu_results = pdu_list
-
- front = pdu_list
-
- query = (
- "SELECT prev_pdu_id, prev_origin FROM %(edges_table)s "
- "WHERE context = ? AND pdu_id = ? AND origin = ? "
- "LIMIT ?"
- ) % {
- "edges_table": PduEdgesTable.table_name,
- }
-
- # We iterate through all pdu_ids in `front` to select their previous
- # pdus. These are dumped in `new_front`. We continue until we reach the
- # limit *or* new_front is empty (i.e., we've run out of things to
- # select
- while front and len(pdu_results) < limit:
-
- new_front = []
- for pdu_id, origin in front:
- logger.debug(
- "_backfill_interaction: i=%s, o=%s",
- pdu_id, origin
- )
-
- txn.execute(
- query,
- (context, pdu_id, origin, limit - len(pdu_results))
- )
-
- for row in txn.fetchall():
- logger.debug(
- "_backfill_interaction: got i=%s, o=%s",
- *row
- )
- new_front.append(row)
-
- front = new_front
- pdu_results += new_front
-
- # We also want to update the `prev_pdus` attributes before returning.
- return self._get_pdu_tuples(txn, pdu_results)
-
- def get_min_depth_for_context(self, context):
- """Get the current minimum depth for a context
-
- Args:
- txn
- context (str)
- """
- return self.runInteraction(
- self._get_min_depth_for_context, context
- )
-
- def _get_min_depth_for_context(self, txn, context):
- return self._get_min_depth_interaction(txn, context)
-
- def _get_min_depth_interaction(self, txn, context):
- txn.execute(
- "SELECT min_depth FROM %s WHERE context = ?"
- % ContextDepthTable.table_name,
- (context,)
- )
-
- row = txn.fetchone()
-
- return row[0] if row else None
-
- def _update_min_depth_for_context_txn(self, txn, context, depth):
- """Update the minimum `depth` of the given context, which is the line
- on which we stop backfilling backwards.
-
- Args:
- context (str)
- depth (int)
- """
- min_depth = self._get_min_depth_interaction(txn, context)
-
- do_insert = depth < min_depth if min_depth else True
-
- if do_insert:
- txn.execute(
- "INSERT OR REPLACE INTO %s (context, min_depth) "
- "VALUES (?,?)" % ContextDepthTable.table_name,
- (context, depth)
- )
-
- def _get_latest_pdus_in_context(self, txn, context):
- """Get's a list of the most current pdus for a given context. This is
- used when we are sending a Pdu and need to fill out the `prev_pdus`
- key
-
- Args:
- txn
- context
- """
- query = (
- "SELECT p.pdu_id, p.origin, p.depth FROM %(pdus)s as p "
- "INNER JOIN %(forward)s as f ON p.pdu_id = f.pdu_id "
- "AND f.origin = p.origin "
- "WHERE f.context = ?"
- ) % {
- "pdus": PdusTable.table_name,
- "forward": PduForwardExtremitiesTable.table_name,
- }
-
- logger.debug("get_prev query: %s", query)
-
- txn.execute(
- query,
- (context, )
- )
-
- results = txn.fetchall()
-
- return [(row[0], row[1], row[2]) for row in results]
-
- @defer.inlineCallbacks
- def get_oldest_pdus_in_context(self, context):
- """Get a list of Pdus that we haven't backfilled beyond yet (and havent
- seen). This list is used when we want to backfill backwards and is the
- list we send to the remote server.
-
- Args:
- txn
- context (str)
-
- Returns:
- list: A list of PduIdTuple.
- """
- results = yield self._execute(
- None,
- "SELECT pdu_id, origin FROM %(back)s WHERE context = ?"
- % {"back": PduBackwardExtremitiesTable.table_name, },
- context
- )
-
- defer.returnValue([PduIdTuple(i, o) for i, o in results])
-
- def is_pdu_new(self, pdu_id, origin, context, depth):
- """For a given Pdu, try and figure out if it's 'new', i.e., if it's
- not something we got randomly from the past, for example when we
- request the current state of the room that will probably return a bunch
- of pdus from before we joined.
-
- Args:
- txn
- pdu_id (str)
- origin (str)
- context (str)
- depth (int)
-
- Returns:
- bool
- """
-
- return self.runInteraction(
- self._is_pdu_new,
- pdu_id=pdu_id,
- origin=origin,
- context=context,
- depth=depth
- )
-
- def _is_pdu_new(self, txn, pdu_id, origin, context, depth):
- # If depth > min depth in back table, then we classify it as new.
- # OR if there is nothing in the back table, then it kinda needs to
- # be a new thing.
- query = (
- "SELECT min(p.depth) FROM %(edges)s as e "
- "INNER JOIN %(back)s as b "
- "ON e.prev_pdu_id = b.pdu_id AND e.prev_origin = b.origin "
- "INNER JOIN %(pdus)s as p "
- "ON e.pdu_id = p.pdu_id AND p.origin = e.origin "
- "WHERE p.context = ?"
- ) % {
- "pdus": PdusTable.table_name,
- "edges": PduEdgesTable.table_name,
- "back": PduBackwardExtremitiesTable.table_name,
- }
-
- txn.execute(query, (context,))
-
- min_depth, = txn.fetchone()
-
- if not min_depth or depth > int(min_depth):
- logger.debug(
- "is_new true: id=%s, o=%s, d=%s min_depth=%s",
- pdu_id, origin, depth, min_depth
- )
- return True
-
- # If this pdu is in the forwards table, then it also is a new one
- query = (
- "SELECT * FROM %(forward)s WHERE pdu_id = ? AND origin = ?"
- ) % {
- "forward": PduForwardExtremitiesTable.table_name,
- }
-
- txn.execute(query, (pdu_id, origin))
-
- # Did we get anything?
- if txn.fetchall():
- logger.debug(
- "is_new true: id=%s, o=%s, d=%s was forward",
- pdu_id, origin, depth
- )
- return True
-
- logger.debug(
- "is_new false: id=%s, o=%s, d=%s",
- pdu_id, origin, depth
- )
-
- # FINE THEN. It's probably old.
- return False
-
- @staticmethod
- @log_function
- def _handle_prev_pdus(txn, outlier, pdu_id, origin, prev_pdus,
- context):
- txn.executemany(
- PduEdgesTable.insert_statement(),
- [(pdu_id, origin, p[0], p[1], context) for p in prev_pdus]
- )
-
- # Update the extremities table if this is not an outlier.
- if not outlier:
-
- # First, we delete the new one from the forwards extremities table.
- query = (
- "DELETE FROM %s WHERE pdu_id = ? AND origin = ?"
- % PduForwardExtremitiesTable.table_name
- )
- txn.executemany(query, prev_pdus)
-
- # We only insert as a forward extremety the new pdu if there are no
- # other pdus that reference it as a prev pdu
- query = (
- "INSERT INTO %(table)s (pdu_id, origin, context) "
- "SELECT ?, ?, ? WHERE NOT EXISTS ("
- "SELECT 1 FROM %(pdu_edges)s WHERE "
- "prev_pdu_id = ? AND prev_origin = ?"
- ")"
- ) % {
- "table": PduForwardExtremitiesTable.table_name,
- "pdu_edges": PduEdgesTable.table_name
- }
-
- logger.debug("query: %s", query)
-
- txn.execute(query, (pdu_id, origin, context, pdu_id, origin))
-
- # Insert all the prev_pdus as a backwards thing, they'll get
- # deleted in a second if they're incorrect anyway.
- txn.executemany(
- PduBackwardExtremitiesTable.insert_statement(),
- [(i, o, context) for i, o in prev_pdus]
- )
-
- # Also delete from the backwards extremities table all ones that
- # reference pdus that we have already seen
- query = (
- "DELETE FROM %(pdu_back)s WHERE EXISTS ("
- "SELECT 1 FROM %(pdus)s AS pdus "
- "WHERE "
- "%(pdu_back)s.pdu_id = pdus.pdu_id "
- "AND %(pdu_back)s.origin = pdus.origin "
- "AND not pdus.outlier "
- ")"
- ) % {
- "pdu_back": PduBackwardExtremitiesTable.table_name,
- "pdus": PdusTable.table_name,
- }
- txn.execute(query)
-
-
-class StatePduStore(SQLBaseStore):
- """A collection of queries for handling state PDUs.
- """
-
- def _persist_state_txn(self, txn, prev_pdus, cols):
- """Inserts a state PDU into the database
-
- Args:
- txn,
- prev_pdus (list)
- **cols: The columns to insert into the PdusTable and StatePdusTable
- """
- pdu_entry = PdusTable.EntryType(
- **{k: cols.get(k, None) for k in PdusTable.fields}
- )
- state_entry = StatePdusTable.EntryType(
- **{k: cols.get(k, None) for k in StatePdusTable.fields}
- )
-
- logger.debug("Inserting pdu: %s", repr(pdu_entry))
- logger.debug("Inserting state: %s", repr(state_entry))
-
- txn.execute(PdusTable.insert_statement(), pdu_entry)
- txn.execute(StatePdusTable.insert_statement(), state_entry)
-
- self._handle_prev_pdus(
- txn,
- pdu_entry.outlier, pdu_entry.pdu_id, pdu_entry.origin, prev_pdus,
- pdu_entry.context
- )
-
- def get_unresolved_state_tree(self, new_state_pdu):
- return self.runInteraction(
- self._get_unresolved_state_tree, new_state_pdu
- )
-
- @log_function
- def _get_unresolved_state_tree(self, txn, new_pdu):
- current = self._get_current_interaction(
- txn,
- new_pdu.context, new_pdu.pdu_type, new_pdu.state_key
- )
-
- ReturnType = namedtuple(
- "StateReturnType", ["new_branch", "current_branch"]
- )
- return_value = ReturnType([new_pdu], [])
-
- if not current:
- logger.debug("get_unresolved_state_tree No current state.")
- return (return_value, None)
-
- return_value.current_branch.append(current)
-
- enum_branches = self._enumerate_state_branches(
- txn, new_pdu, current
- )
-
- missing_branch = None
- for branch, prev_state, state in enum_branches:
- if state:
- return_value[branch].append(state)
- else:
- # We don't have prev_state :(
- missing_branch = branch
- break
-
- return (return_value, missing_branch)
-
- def update_current_state(self, pdu_id, origin, context, pdu_type,
- state_key):
- return self.runInteraction(
- self._update_current_state,
- pdu_id, origin, context, pdu_type, state_key
- )
-
- def _update_current_state(self, txn, pdu_id, origin, context, pdu_type,
- state_key):
- query = (
- "INSERT OR REPLACE INTO %(curr)s (%(fields)s) VALUES (%(qs)s)"
- ) % {
- "curr": CurrentStateTable.table_name,
- "fields": CurrentStateTable.get_fields_string(),
- "qs": ", ".join(["?"] * len(CurrentStateTable.fields))
- }
-
- query_args = CurrentStateTable.EntryType(
- pdu_id=pdu_id,
- origin=origin,
- context=context,
- pdu_type=pdu_type,
- state_key=state_key
- )
-
- txn.execute(query, query_args)
-
- def get_current_state_pdu(self, context, pdu_type, state_key):
- """For a given context, pdu_type, state_key 3-tuple, return what is
- currently considered the current state.
-
- Args:
- txn
- context (str)
- pdu_type (str)
- state_key (str)
-
- Returns:
- PduEntry
- """
-
- return self.runInteraction(
- self._get_current_state_pdu, context, pdu_type, state_key
- )
-
- def _get_current_state_pdu(self, txn, context, pdu_type, state_key):
- return self._get_current_interaction(txn, context, pdu_type, state_key)
-
- def _get_current_interaction(self, txn, context, pdu_type, state_key):
- logger.debug(
- "_get_current_interaction %s %s %s",
- context, pdu_type, state_key
- )
-
- fields = _pdu_state_joiner.get_fields(
- PdusTable="p", StatePdusTable="s")
-
- current_query = (
- "SELECT %(fields)s FROM %(state)s as s "
- "INNER JOIN %(pdus)s as p "
- "ON s.pdu_id = p.pdu_id AND s.origin = p.origin "
- "INNER JOIN %(curr)s as c "
- "ON s.pdu_id = c.pdu_id AND s.origin = c.origin "
- "WHERE s.context = ? AND s.pdu_type = ? AND s.state_key = ? "
- ) % {
- "fields": fields,
- "curr": CurrentStateTable.table_name,
- "state": StatePdusTable.table_name,
- "pdus": PdusTable.table_name,
- }
-
- txn.execute(
- current_query,
- (context, pdu_type, state_key)
- )
-
- row = txn.fetchone()
-
- result = PduEntry(*row) if row else None
-
- if not result:
- logger.debug("_get_current_interaction not found")
- else:
- logger.debug(
- "_get_current_interaction found %s %s",
- result.pdu_id, result.origin
- )
-
- return result
-
- def handle_new_state(self, new_pdu):
- """Actually perform conflict resolution on the new_pdu on the
- assumption we have all the pdus required to perform it.
-
- Args:
- new_pdu
-
- Returns:
- bool: True if the new_pdu clobbered the current state, False if not
- """
- return self.runInteraction(
- self._handle_new_state, new_pdu
- )
-
- def _handle_new_state(self, txn, new_pdu):
- logger.debug(
- "handle_new_state %s %s",
- new_pdu.pdu_id, new_pdu.origin
- )
-
- current = self._get_current_interaction(
- txn,
- new_pdu.context, new_pdu.pdu_type, new_pdu.state_key
- )
-
- is_current = False
-
- if (not current or not current.prev_state_id
- or not current.prev_state_origin):
- # Oh, we don't have any state for this yet.
- is_current = True
- elif (current.pdu_id == new_pdu.prev_state_id
- and current.origin == new_pdu.prev_state_origin):
- # Oh! A direct clobber. Just do it.
- is_current = True
- else:
- ##
- # Ok, now loop through until we get to a common ancestor.
- max_new = int(new_pdu.power_level)
- max_current = int(current.power_level)
-
- enum_branches = self._enumerate_state_branches(
- txn, new_pdu, current
- )
- for branch, prev_state, state in enum_branches:
- if not state:
- raise RuntimeError(
- "Could not find state_pdu %s %s" %
- (
- prev_state.prev_state_id,
- prev_state.prev_state_origin
- )
- )
-
- if branch == 0:
- max_new = max(int(state.depth), max_new)
- else:
- max_current = max(int(state.depth), max_current)
-
- is_current = max_new > max_current
-
- if is_current:
- logger.debug("handle_new_state make current")
-
- # Right, this is a new thing, so woo, just insert it.
- txn.execute(
- "INSERT OR REPLACE INTO %(curr)s (%(fields)s) VALUES (%(qs)s)"
- % {
- "curr": CurrentStateTable.table_name,
- "fields": CurrentStateTable.get_fields_string(),
- "qs": ", ".join(["?"] * len(CurrentStateTable.fields))
- },
- CurrentStateTable.EntryType(
- *(new_pdu.__dict__[k] for k in CurrentStateTable.fields)
- )
- )
- else:
- logger.debug("handle_new_state not current")
-
- logger.debug("handle_new_state done")
-
- return is_current
-
- @log_function
- def _enumerate_state_branches(self, txn, pdu_a, pdu_b):
- branch_a = pdu_a
- branch_b = pdu_b
-
- while True:
- if (branch_a.pdu_id == branch_b.pdu_id
- and branch_a.origin == branch_b.origin):
- # Woo! We found a common ancestor
- logger.debug("_enumerate_state_branches Found common ancestor")
- break
-
- do_branch_a = (
- hasattr(branch_a, "prev_state_id") and
- branch_a.prev_state_id
- )
-
- do_branch_b = (
- hasattr(branch_b, "prev_state_id") and
- branch_b.prev_state_id
- )
-
- logger.debug(
- "do_branch_a=%s, do_branch_b=%s",
- do_branch_a, do_branch_b
- )
-
- if do_branch_a and do_branch_b:
- do_branch_a = int(branch_a.depth) > int(branch_b.depth)
-
- if do_branch_a:
- pdu_tuple = PduIdTuple(
- branch_a.prev_state_id,
- branch_a.prev_state_origin
- )
-
- prev_branch = branch_a
-
- logger.debug("getting branch_a prev %s", pdu_tuple)
- branch_a = self._get_pdu_tuple(txn, *pdu_tuple)
- if branch_a:
- branch_a = Pdu.from_pdu_tuple(branch_a)
-
- logger.debug("branch_a=%s", branch_a)
-
- yield (0, prev_branch, branch_a)
-
- if not branch_a:
- break
- elif do_branch_b:
- pdu_tuple = PduIdTuple(
- branch_b.prev_state_id,
- branch_b.prev_state_origin
- )
-
- prev_branch = branch_b
-
- logger.debug("getting branch_b prev %s", pdu_tuple)
- branch_b = self._get_pdu_tuple(txn, *pdu_tuple)
- if branch_b:
- branch_b = Pdu.from_pdu_tuple(branch_b)
-
- logger.debug("branch_b=%s", branch_b)
-
- yield (1, prev_branch, branch_b)
-
- if not branch_b:
- break
- else:
- break
-
-
-class PdusTable(Table):
- table_name = "pdus"
-
- fields = [
- "pdu_id",
- "origin",
- "context",
- "pdu_type",
- "ts",
- "depth",
- "is_state",
- "content_json",
- "unrecognized_keys",
- "outlier",
- "have_processed",
- ]
-
- EntryType = namedtuple("PdusEntry", fields)
-
-
-class PduDestinationsTable(Table):
- table_name = "pdu_destinations"
-
- fields = [
- "pdu_id",
- "origin",
- "destination",
- "delivered_ts",
- ]
-
- EntryType = namedtuple("PduDestinationsEntry", fields)
-
-
-class PduEdgesTable(Table):
- table_name = "pdu_edges"
-
- fields = [
- "pdu_id",
- "origin",
- "prev_pdu_id",
- "prev_origin",
- "context"
- ]
-
- EntryType = namedtuple("PduEdgesEntry", fields)
-
-
-class PduForwardExtremitiesTable(Table):
- table_name = "pdu_forward_extremities"
-
- fields = [
- "pdu_id",
- "origin",
- "context",
- ]
-
- EntryType = namedtuple("PduForwardExtremitiesEntry", fields)
-
-
-class PduBackwardExtremitiesTable(Table):
- table_name = "pdu_backward_extremities"
-
- fields = [
- "pdu_id",
- "origin",
- "context",
- ]
-
- EntryType = namedtuple("PduBackwardExtremitiesEntry", fields)
-
-
-class ContextDepthTable(Table):
- table_name = "context_depth"
-
- fields = [
- "context",
- "min_depth",
- ]
-
- EntryType = namedtuple("ContextDepthEntry", fields)
-
-
-class StatePdusTable(Table):
- table_name = "state_pdus"
-
- fields = [
- "pdu_id",
- "origin",
- "context",
- "pdu_type",
- "state_key",
- "power_level",
- "prev_state_id",
- "prev_state_origin",
- ]
-
- EntryType = namedtuple("StatePdusEntry", fields)
-
-
-class CurrentStateTable(Table):
- table_name = "current_state"
-
- fields = [
- "pdu_id",
- "origin",
- "context",
- "pdu_type",
- "state_key",
- ]
-
- EntryType = namedtuple("CurrentStateEntry", fields)
-
-_pdu_state_joiner = JoinHelper(PdusTable, StatePdusTable)
-
-
-# TODO: These should probably be put somewhere more sensible
-PduIdTuple = namedtuple("PduIdTuple", ("pdu_id", "origin"))
-
-PduEntry = _pdu_state_joiner.EntryType
-""" We are always interested in the join of the PdusTable and StatePdusTable,
-rather than just the PdusTable.
-
-This does not include a prev_pdus key.
-"""
-
-PduTuple = namedtuple(
- "PduTuple",
- ("pdu_entry", "prev_pdu_list")
-)
-""" This is a tuple of a `PduEntry` and a list of `PduIdTuple` that represent
-the `prev_pdus` key of a PDU.
-"""
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 719806f82b..a2ca6f9a69 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -62,8 +62,10 @@ class RegistrationStore(SQLBaseStore):
Raises:
StoreError if the user_id could not be registered.
"""
- yield self.runInteraction(self._register, user_id, token,
- password_hash)
+ yield self.runInteraction(
+ "register",
+ self._register, user_id, token, password_hash
+ )
def _register(self, txn, user_id, token, password_hash):
now = int(self.clock.time())
@@ -100,6 +102,7 @@ class RegistrationStore(SQLBaseStore):
StoreError if no user was found.
"""
return self.runInteraction(
+ "get_user_by_token",
self._query_for_auth,
token
)
diff --git a/synapse/storage/room.py b/synapse/storage/room.py
index 8cd46334cf..0c83c11ad3 100644
--- a/synapse/storage/room.py
+++ b/synapse/storage/room.py
@@ -148,83 +148,6 @@ class RoomStore(SQLBaseStore):
else:
defer.returnValue(None)
- def get_power_level(self, room_id, user_id):
- return self.runInteraction(
- self._get_power_level,
- room_id, user_id,
- )
-
- def _get_power_level(self, txn, room_id, user_id):
- sql = (
- "SELECT level FROM room_power_levels as r "
- "INNER JOIN current_state_events as c "
- "ON r.event_id = c.event_id "
- "WHERE c.room_id = ? AND r.user_id = ? "
- )
-
- rows = txn.execute(sql, (room_id, user_id,)).fetchall()
-
- if len(rows) == 1:
- return rows[0][0]
-
- sql = (
- "SELECT level FROM room_default_levels as r "
- "INNER JOIN current_state_events as c "
- "ON r.event_id = c.event_id "
- "WHERE c.room_id = ? "
- )
-
- rows = txn.execute(sql, (room_id,)).fetchall()
-
- if len(rows) == 1:
- return rows[0][0]
- else:
- return None
-
- def get_ops_levels(self, room_id):
- return self.runInteraction(
- self._get_ops_levels,
- room_id,
- )
-
- def _get_ops_levels(self, txn, room_id):
- sql = (
- "SELECT ban_level, kick_level, redact_level "
- "FROM room_ops_levels as r "
- "INNER JOIN current_state_events as c "
- "ON r.event_id = c.event_id "
- "WHERE c.room_id = ? "
- )
-
- rows = txn.execute(sql, (room_id,)).fetchall()
-
- if len(rows) == 1:
- return OpsLevel(rows[0][0], rows[0][1], rows[0][2])
- else:
- return OpsLevel(None, None)
-
- def get_add_state_level(self, room_id):
- return self._get_level_from_table("room_add_state_levels", room_id)
-
- def get_send_event_level(self, room_id):
- return self._get_level_from_table("room_send_event_levels", room_id)
-
- @defer.inlineCallbacks
- def _get_level_from_table(self, table, room_id):
- sql = (
- "SELECT level FROM %(table)s as r "
- "INNER JOIN current_state_events as c "
- "ON r.event_id = c.event_id "
- "WHERE c.room_id = ? "
- ) % {"table": table}
-
- rows = yield self._execute(None, sql, room_id)
-
- if len(rows) == 1:
- defer.returnValue(rows[0][0])
- else:
- defer.returnValue(None)
-
def _store_room_topic_txn(self, txn, event):
self._simple_insert_txn(
txn,
@@ -258,84 +181,6 @@ class RoomStore(SQLBaseStore):
},
)
- def _store_power_levels(self, txn, event):
- for user_id, level in event.content.items():
- if user_id == "default":
- self._simple_insert_txn(
- txn,
- "room_default_levels",
- {
- "event_id": event.event_id,
- "room_id": event.room_id,
- "level": level,
- },
- )
- else:
- self._simple_insert_txn(
- txn,
- "room_power_levels",
- {
- "event_id": event.event_id,
- "room_id": event.room_id,
- "user_id": user_id,
- "level": level
- },
- )
-
- def _store_default_level(self, txn, event):
- self._simple_insert_txn(
- txn,
- "room_default_levels",
- {
- "event_id": event.event_id,
- "room_id": event.room_id,
- "level": event.content["default_level"],
- },
- )
-
- def _store_add_state_level(self, txn, event):
- self._simple_insert_txn(
- txn,
- "room_add_state_levels",
- {
- "event_id": event.event_id,
- "room_id": event.room_id,
- "level": event.content["level"],
- },
- )
-
- def _store_send_event_level(self, txn, event):
- self._simple_insert_txn(
- txn,
- "room_send_event_levels",
- {
- "event_id": event.event_id,
- "room_id": event.room_id,
- "level": event.content["level"],
- },
- )
-
- def _store_ops_level(self, txn, event):
- content = {
- "event_id": event.event_id,
- "room_id": event.room_id,
- }
-
- if "kick_level" in event.content:
- content["kick_level"] = event.content["kick_level"]
-
- if "ban_level" in event.content:
- content["ban_level"] = event.content["ban_level"]
-
- if "redact_level" in event.content:
- content["redact_level"] = event.content["redact_level"]
-
- self._simple_insert_txn(
- txn,
- "room_ops_levels",
- content,
- )
-
class RoomsTable(Table):
table_name = "rooms"
diff --git a/synapse/storage/schema/edge_pdus.sql b/synapse/storage/schema/edge_pdus.sql
deleted file mode 100644
index 8a00868065..0000000000
--- a/synapse/storage/schema/edge_pdus.sql
+++ /dev/null
@@ -1,31 +0,0 @@
-/* Copyright 2014 OpenMarket Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-CREATE TABLE IF NOT EXISTS context_edge_pdus(
- id INTEGER PRIMARY KEY AUTOINCREMENT, -- twistar requires this
- pdu_id TEXT,
- origin TEXT,
- context TEXT,
- CONSTRAINT context_edge_pdu_id_origin UNIQUE (pdu_id, origin)
-);
-
-CREATE TABLE IF NOT EXISTS origin_edge_pdus(
- id INTEGER PRIMARY KEY AUTOINCREMENT, -- twistar requires this
- pdu_id TEXT,
- origin TEXT,
- CONSTRAINT origin_edge_pdu_id_origin UNIQUE (pdu_id, origin)
-);
-
-CREATE INDEX IF NOT EXISTS context_edge_pdu_id ON context_edge_pdus(pdu_id, origin);
-CREATE INDEX IF NOT EXISTS origin_edge_pdu_id ON origin_edge_pdus(pdu_id, origin);
diff --git a/synapse/storage/schema/event_edges.sql b/synapse/storage/schema/event_edges.sql
new file mode 100644
index 0000000000..be1c72a775
--- /dev/null
+++ b/synapse/storage/schema/event_edges.sql
@@ -0,0 +1,75 @@
+
+CREATE TABLE IF NOT EXISTS event_forward_extremities(
+ event_id TEXT NOT NULL,
+ room_id TEXT NOT NULL,
+ CONSTRAINT uniqueness UNIQUE (event_id, room_id) ON CONFLICT REPLACE
+);
+
+CREATE INDEX IF NOT EXISTS ev_extrem_room ON event_forward_extremities(room_id);
+CREATE INDEX IF NOT EXISTS ev_extrem_id ON event_forward_extremities(event_id);
+
+
+CREATE TABLE IF NOT EXISTS event_backward_extremities(
+ event_id TEXT NOT NULL,
+ room_id TEXT NOT NULL,
+ CONSTRAINT uniqueness UNIQUE (event_id, room_id) ON CONFLICT REPLACE
+);
+
+CREATE INDEX IF NOT EXISTS ev_b_extrem_room ON event_backward_extremities(room_id);
+CREATE INDEX IF NOT EXISTS ev_b_extrem_id ON event_backward_extremities(event_id);
+
+
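+-- An edge in the room DAG: event_id points back at prev_event_id.
+-- is_state is 1 when the edge comes from prev_state rather than
+-- prev_events.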
+CREATE TABLE IF NOT EXISTS event_edges(
+ event_id TEXT NOT NULL,
+ prev_event_id TEXT NOT NULL,
+ room_id TEXT NOT NULL,
+ is_state INTEGER NOT NULL,
+ CONSTRAINT uniqueness UNIQUE (event_id, prev_event_id, room_id, is_state)
+);
+
+CREATE INDEX IF NOT EXISTS ev_edges_id ON event_edges(event_id);
+CREATE INDEX IF NOT EXISTS ev_edges_prev_id ON event_edges(prev_event_id);
+
+
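+-- The minimum depth we have backfilled back to in each room; see
+-- _update_min_depth_for_room_txn.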
+CREATE TABLE IF NOT EXISTS room_depth(
+ room_id TEXT NOT NULL,
+ min_depth INTEGER NOT NULL,
+ CONSTRAINT uniqueness UNIQUE (room_id)
+);
+
+CREATE INDEX IF NOT EXISTS room_depth_room ON room_depth(room_id);
+
+
+CREATE TABLE IF NOT EXISTS event_destinations(
+ event_id TEXT NOT NULL,
+ destination TEXT NOT NULL,
+ delivered_ts INTEGER DEFAULT 0, -- delivery timestamp, or 0 if not delivered
+ CONSTRAINT uniqueness UNIQUE (event_id, destination) ON CONFLICT REPLACE
+);
+
+CREATE INDEX IF NOT EXISTS event_destinations_id ON event_destinations(event_id);
+
+
+CREATE TABLE IF NOT EXISTS state_forward_extremities(
+ event_id TEXT NOT NULL,
+ room_id TEXT NOT NULL,
+ type TEXT NOT NULL,
+ state_key TEXT NOT NULL,
+ CONSTRAINT uniqueness UNIQUE (event_id, room_id) ON CONFLICT REPLACE
+);
+
+CREATE INDEX IF NOT EXISTS st_extrem_keys ON state_forward_extremities(
+ room_id, type, state_key
+);
+CREATE INDEX IF NOT EXISTS st_extrem_id ON state_forward_extremities(event_id);
+
+
+CREATE TABLE IF NOT EXISTS event_auth(
+ event_id TEXT NOT NULL,
+ auth_id TEXT NOT NULL,
+ room_id TEXT NOT NULL,
+ CONSTRAINT uniqueness UNIQUE (event_id, auth_id, room_id)
+);
+
+CREATE INDEX IF NOT EXISTS evauth_edges_id ON event_auth(event_id);
+CREATE INDEX IF NOT EXISTS evauth_edges_auth_id ON event_auth(auth_id);
\ No newline at end of file
diff --git a/synapse/storage/schema/event_signatures.sql b/synapse/storage/schema/event_signatures.sql
new file mode 100644
index 0000000000..5491c7ecec
--- /dev/null
+++ b/synapse/storage/schema/event_signatures.sql
@@ -0,0 +1,65 @@
+/* Copyright 2014 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE IF NOT EXISTS event_content_hashes (
+ event_id TEXT,
+ algorithm TEXT,
+ hash BLOB,
+ CONSTRAINT uniqueness UNIQUE (event_id, algorithm)
+);
+
+CREATE INDEX IF NOT EXISTS event_content_hashes_id ON event_content_hashes(
+ event_id
+);
+
+
+CREATE TABLE IF NOT EXISTS event_reference_hashes (
+ event_id TEXT,
+ algorithm TEXT,
+ hash BLOB,
+ CONSTRAINT uniqueness UNIQUE (event_id, algorithm)
+);
+
+CREATE INDEX IF NOT EXISTS event_reference_hashes_id ON event_reference_hashes (
+ event_id
+);
+
+
+CREATE TABLE IF NOT EXISTS event_origin_signatures (
+ event_id TEXT,
+ origin TEXT,
+ key_id TEXT,
+ signature BLOB,
+ CONSTRAINT uniqueness UNIQUE (event_id, key_id)
+);
+
+CREATE INDEX IF NOT EXISTS event_origin_signatures_id ON event_origin_signatures (
+ event_id
+);
+
+
+CREATE TABLE IF NOT EXISTS event_edge_hashes(
+ event_id TEXT,
+ prev_event_id TEXT,
+ algorithm TEXT,
+ hash BLOB,
+ CONSTRAINT uniqueness UNIQUE (
+ event_id, prev_event_id, algorithm
+ )
+);
+
+CREATE INDEX IF NOT EXISTS event_edge_hashes_id ON event_edge_hashes(
+ event_id
+);
diff --git a/synapse/storage/schema/im.sql b/synapse/storage/schema/im.sql
index 3aa83f5c8c..8d6f655993 100644
--- a/synapse/storage/schema/im.sql
+++ b/synapse/storage/schema/im.sql
@@ -23,6 +23,7 @@ CREATE TABLE IF NOT EXISTS events(
unrecognized_keys TEXT,
processed BOOL NOT NULL,
outlier BOOL NOT NULL,
+ depth INTEGER DEFAULT 0 NOT NULL,
CONSTRAINT ev_uniq UNIQUE (event_id)
);
diff --git a/synapse/storage/schema/pdu.sql b/synapse/storage/schema/pdu.sql
deleted file mode 100644
index 16e111a56c..0000000000
--- a/synapse/storage/schema/pdu.sql
+++ /dev/null
@@ -1,106 +0,0 @@
-/* Copyright 2014 OpenMarket Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
--- Stores pdus and their content
-CREATE TABLE IF NOT EXISTS pdus(
- pdu_id TEXT,
- origin TEXT,
- context TEXT,
- pdu_type TEXT,
- ts INTEGER,
- depth INTEGER DEFAULT 0 NOT NULL,
- is_state BOOL,
- content_json TEXT,
- unrecognized_keys TEXT,
- outlier BOOL NOT NULL,
- have_processed BOOL,
- CONSTRAINT pdu_id_origin UNIQUE (pdu_id, origin)
-);
-
--- Stores what the current state pdu is for a given (context, pdu_type, key) tuple
-CREATE TABLE IF NOT EXISTS state_pdus(
- pdu_id TEXT,
- origin TEXT,
- context TEXT,
- pdu_type TEXT,
- state_key TEXT,
- power_level TEXT,
- prev_state_id TEXT,
- prev_state_origin TEXT,
- CONSTRAINT pdu_id_origin UNIQUE (pdu_id, origin)
- CONSTRAINT prev_pdu_id_origin UNIQUE (prev_state_id, prev_state_origin)
-);
-
-CREATE TABLE IF NOT EXISTS current_state(
- pdu_id TEXT,
- origin TEXT,
- context TEXT,
- pdu_type TEXT,
- state_key TEXT,
- CONSTRAINT pdu_id_origin UNIQUE (pdu_id, origin)
- CONSTRAINT uniqueness UNIQUE (context, pdu_type, state_key) ON CONFLICT REPLACE
-);
-
--- Stores where each pdu we want to send should be sent and the delivery status.
-create TABLE IF NOT EXISTS pdu_destinations(
- pdu_id TEXT,
- origin TEXT,
- destination TEXT,
- delivered_ts INTEGER DEFAULT 0, -- or 0 if not delivered
- CONSTRAINT uniqueness UNIQUE (pdu_id, origin, destination) ON CONFLICT REPLACE
-);
-
-CREATE TABLE IF NOT EXISTS pdu_forward_extremities(
- pdu_id TEXT,
- origin TEXT,
- context TEXT,
- CONSTRAINT uniqueness UNIQUE (pdu_id, origin, context) ON CONFLICT REPLACE
-);
-
-CREATE TABLE IF NOT EXISTS pdu_backward_extremities(
- pdu_id TEXT,
- origin TEXT,
- context TEXT,
- CONSTRAINT uniqueness UNIQUE (pdu_id, origin, context) ON CONFLICT REPLACE
-);
-
-CREATE TABLE IF NOT EXISTS pdu_edges(
- pdu_id TEXT,
- origin TEXT,
- prev_pdu_id TEXT,
- prev_origin TEXT,
- context TEXT,
- CONSTRAINT uniqueness UNIQUE (pdu_id, origin, prev_pdu_id, prev_origin, context)
-);
-
-CREATE TABLE IF NOT EXISTS context_depth(
- context TEXT,
- min_depth INTEGER,
- CONSTRAINT uniqueness UNIQUE (context)
-);
-
-CREATE INDEX IF NOT EXISTS context_depth_context ON context_depth(context);
-
-
-CREATE INDEX IF NOT EXISTS pdu_id ON pdus(pdu_id, origin);
-
-CREATE INDEX IF NOT EXISTS dests_id ON pdu_destinations (pdu_id, origin);
--- CREATE INDEX IF NOT EXISTS dests ON pdu_destinations (destination);
-
-CREATE INDEX IF NOT EXISTS pdu_extrem_context ON pdu_forward_extremities(context);
-CREATE INDEX IF NOT EXISTS pdu_extrem_id ON pdu_forward_extremities(pdu_id, origin);
-
-CREATE INDEX IF NOT EXISTS pdu_edges_id ON pdu_edges(pdu_id, origin);
-
-CREATE INDEX IF NOT EXISTS pdu_b_extrem_context ON pdu_backward_extremities(context);
diff --git a/synapse/storage/schema/state.sql b/synapse/storage/schema/state.sql
new file mode 100644
index 0000000000..b44c56b519
--- /dev/null
+++ b/synapse/storage/schema/state.sql
@@ -0,0 +1,33 @@
+/* Copyright 2014 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
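+-- A state group is a set of state events. Events are mapped to a group
+-- via event_to_state_groups, so identical room state is stored once.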
+CREATE TABLE IF NOT EXISTS state_groups(
+ id INTEGER PRIMARY KEY,
+ room_id TEXT NOT NULL,
+ event_id TEXT NOT NULL
+);
+
+CREATE TABLE IF NOT EXISTS state_groups_state(
+ state_group INTEGER NOT NULL,
+ room_id TEXT NOT NULL,
+ type TEXT NOT NULL,
+ state_key TEXT NOT NULL,
+ event_id TEXT NOT NULL
+);
+
+CREATE TABLE IF NOT EXISTS event_to_state_groups(
+ event_id TEXT NOT NULL,
+ state_group INTEGER NOT NULL
+);
\ No newline at end of file
diff --git a/synapse/storage/signatures.py b/synapse/storage/signatures.py
new file mode 100644
index 0000000000..84a49088a2
--- /dev/null
+++ b/synapse/storage/signatures.py
@@ -0,0 +1,177 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ._base import SQLBaseStore
+
+
+class SignatureStore(SQLBaseStore):
+ """Persistence for event signatures and hashes"""
+
+ def _get_event_content_hashes_txn(self, txn, event_id):
+ """Get all the hashes for a given Event.
+ Args:
+ txn (cursor):
+ event_id (str): Id for the Event.
+ Returns:
+ A dict of algorithm -> hash.
+ """
+ query = (
+ "SELECT algorithm, hash"
+ " FROM event_content_hashes"
+ " WHERE event_id = ?"
+ )
+ txn.execute(query, (event_id, ))
+ return dict(txn.fetchall())
+
+ def _store_event_content_hash_txn(self, txn, event_id, algorithm,
+ hash_bytes):
+ """Store a hash for a Event
+ Args:
+ txn (cursor):
+ event_id (str): Id for the Event.
+ algorithm (str): Hashing algorithm.
+ hash_bytes (bytes): Hash function output bytes.
+ """
+ self._simple_insert_txn(
+ txn,
+ "event_content_hashes",
+ {
+ "event_id": event_id,
+ "algorithm": algorithm,
+ "hash": buffer(hash_bytes),
+ },
+ or_ignore=True,
+ )
+
+ def get_event_reference_hashes(self, event_ids):
+ def f(txn):
+ return [
+ self._get_event_reference_hashes_txn(txn, ev)
+ for ev in event_ids
+ ]
+
+ return self.runInteraction(
+ "get_event_reference_hashes",
+ f
+ )
+
+ def _get_event_reference_hashes_txn(self, txn, event_id):
+ """Get all the hashes for a given PDU.
+ Args:
+ txn (cursor):
+ event_id (str): Id for the Event.
+ Returns:
+ A dict of algorithm -> hash.
+ """
+ query = (
+ "SELECT algorithm, hash"
+ " FROM event_reference_hashes"
+ " WHERE event_id = ?"
+ )
+ txn.execute(query, (event_id, ))
+ return dict(txn.fetchall())
+
+ def _store_event_reference_hash_txn(self, txn, event_id, algorithm,
+ hash_bytes):
+ """Store a hash for a PDU
+ Args:
+ txn (cursor):
+ event_id (str): Id for the Event.
+ algorithm (str): Hashing algorithm.
+ hash_bytes (bytes): Hash function output bytes.
+ """
+ self._simple_insert_txn(
+ txn,
+ "event_reference_hashes",
+ {
+ "event_id": event_id,
+ "algorithm": algorithm,
+ "hash": buffer(hash_bytes),
+ },
+ or_ignore=True,
+ )
+
+ def _get_event_origin_signatures_txn(self, txn, event_id):
+ """Get all the signatures for a given Event.
+ Args:
+ txn (cursor):
+ event_id (str): Id for the Event.
+ Returns:
+ A dict of key_id -> signature_bytes.
+ """
+ query = (
+ "SELECT key_id, signature"
+ " FROM event_origin_signatures"
+ " WHERE event_id = ? "
+ )
+ txn.execute(query, (event_id, ))
+ return dict(txn.fetchall())
+
+ def _store_event_origin_signature_txn(self, txn, event_id, origin, key_id,
+ signature_bytes):
+ """Store a signature from the origin server for a PDU.
+ Args:
+ txn (cursor):
+ event_id (str): Id for the Event.
+ origin (str): origin of the Event.
+ key_id (str): Id for the signing key.
+ signature_bytes (bytes): The signature.
+ """
+ self._simple_insert_txn(
+ txn,
+ "event_origin_signatures",
+ {
+ "event_id": event_id,
+ "origin": origin,
+ "key_id": key_id,
+ "signature": buffer(signature_bytes),
+ },
+ or_ignore=True,
+ )
+
+ def _get_prev_event_hashes_txn(self, txn, event_id):
+ """Get all the hashes for previous PDUs of a PDU
+ Args:
+ txn (cursor):
+ event_id (str): Id for the Event.
+ Returns:
+ A dict of prev_event_id -> dict of algorithm -> hash_bytes.
+ """
+ query = (
+ "SELECT prev_event_id, algorithm, hash"
+ " FROM event_edge_hashes"
+ " WHERE event_id = ?"
+ )
+ txn.execute(query, (event_id, ))
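+ # Group the rows into {prev_event_id: {algorithm: hash_bytes}}.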
+ results = {}
+ for prev_event_id, algorithm, hash_bytes in txn.fetchall():
+ hashes = results.setdefault(prev_event_id, {})
+ hashes[algorithm] = hash_bytes
+ return results
+
+ def _store_prev_event_hash_txn(self, txn, event_id, prev_event_id,
+ algorithm, hash_bytes):
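+        """Store a hash of a previous event referenced by an event.
+        Args:
+            txn (cursor):
+            event_id (str): Id for the Event.
+            prev_event_id (str): Id of the referenced previous event.
+            algorithm (str): Hashing algorithm.
+            hash_bytes (bytes): Hash function output bytes.
+        """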
+ self._simple_insert_txn(
+ txn,
+ "event_edge_hashes",
+ {
+ "event_id": event_id,
+ "prev_event_id": prev_event_id,
+ "algorithm": algorithm,
+ "hash": buffer(hash_bytes),
+ },
+ or_ignore=True,
+ )
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
new file mode 100644
index 0000000000..e08acd6404
--- /dev/null
+++ b/synapse/storage/state.py
@@ -0,0 +1,101 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ._base import SQLBaseStore
+from twisted.internet import defer
+
+from collections import namedtuple
+
+
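+# "group" is the identifier of a set of state events; "state" is the list of
+# the state events in that group.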
+StateGroup = namedtuple("StateGroup", ("group", "state"))
+
+
+class StateStore(SQLBaseStore):
+
+ @defer.inlineCallbacks
+ def get_state_groups(self, event_ids):
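+        """Get the state groups for the given events.
+        Args:
+            event_ids (list): Ids of the events.
+        Returns:
+            Deferred: resolves to a list of StateGroup(group, state) tuples,
+            where "state" is the list of state events in that group.
+        """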
+ groups = set()
+ for event_id in event_ids:
+ group = yield self._simple_select_one_onecol(
+ table="event_to_state_groups",
+ keyvalues={"event_id": event_id},
+ retcol="state_group",
+ allow_none=True,
+ )
+ if group:
+ groups.add(group)
+
+ res = []
+ for group in groups:
+ state_ids = yield self._simple_select_onecol(
+ table="state_groups_state",
+ keyvalues={"state_group": group},
+ retcol="event_id",
+ )
+ state = []
+ for state_id in state_ids:
+ s = yield self.get_event(
+ state_id,
+ allow_none=True,
+ )
+ if s:
+ state.append(s)
+
+ res.append(StateGroup(group, state))
+
+ defer.returnValue(res)
+
+ def store_state_groups(self, event):
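+        """Persist the state group for an event, creating a new group
+        populated from event.state_events if the event does not already
+        have one.
+        """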
+ return self.runInteraction(
+ "store_state_groups",
+ self._store_state_groups_txn, event
+ )
+
+ def _store_state_groups_txn(self, txn, event):
+ if not event.state_events:
+ return
+
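+        # If the event doesn't already have a state group, create a new row
+        # in "state_groups" and fill it with the event's current state; in
+        # either case, record the event -> state_group mapping.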
+ state_group = event.state_group
+ if not state_group:
+ state_group = self._simple_insert_txn(
+ txn,
+ table="state_groups",
+ values={
+ "room_id": event.room_id,
+ "event_id": event.event_id,
+ }
+ )
+
+ for state in event.state_events.values():
+ self._simple_insert_txn(
+ txn,
+ table="state_groups_state",
+ values={
+ "state_group": state_group,
+ "room_id": state.room_id,
+ "type": state.type,
+ "state_key": state.state_key,
+ "event_id": state.event_id,
+ }
+ )
+
+ self._simple_insert_txn(
+ txn,
+ table="event_to_state_groups",
+ values={
+ "state_group": state_group,
+ "event_id": event.event_id,
+ }
+ )
diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py
index d61f909939..8f7f61d29d 100644
--- a/synapse/storage/stream.py
+++ b/synapse/storage/stream.py
@@ -309,7 +309,10 @@ class StreamStore(SQLBaseStore):
defer.returnValue(ret)
def get_room_events_max_id(self):
- return self.runInteraction(self._get_room_events_max_id_txn)
+ return self.runInteraction(
+ "get_room_events_max_id",
+ self._get_room_events_max_id_txn
+ )
def _get_room_events_max_id_txn(self, txn):
txn.execute(
diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py
index 2ba8e30efe..00d0f48082 100644
--- a/synapse/storage/transactions.py
+++ b/synapse/storage/transactions.py
@@ -14,7 +14,6 @@
# limitations under the License.
from ._base import SQLBaseStore, Table
-from .pdu import PdusTable
from collections import namedtuple
@@ -42,6 +41,7 @@ class TransactionStore(SQLBaseStore):
"""
return self.runInteraction(
+ "get_received_txn_response",
self._get_received_txn_response, transaction_id, origin
)
@@ -73,6 +73,7 @@ class TransactionStore(SQLBaseStore):
"""
return self.runInteraction(
+ "set_received_txn_response",
self._set_received_txn_response,
transaction_id, origin, code, response_dict
)
@@ -88,7 +89,7 @@ class TransactionStore(SQLBaseStore):
txn.execute(query, (code, response_json, transaction_id, origin))
def prep_send_transaction(self, transaction_id, destination,
- origin_server_ts, pdu_list):
+ origin_server_ts):
"""Persists an outgoing transaction and calculates the values for the
previous transaction id list.
@@ -99,19 +100,19 @@ class TransactionStore(SQLBaseStore):
transaction_id (str)
destination (str)
origin_server_ts (int)
- pdu_list (list)
Returns:
list: A list of previous transaction ids.
"""
return self.runInteraction(
+ "prep_send_transaction",
self._prep_send_transaction,
- transaction_id, destination, origin_server_ts, pdu_list
+ transaction_id, destination, origin_server_ts
)
def _prep_send_transaction(self, txn, transaction_id, destination,
- origin_server_ts, pdu_list):
+ origin_server_ts):
# First we find out what the prev_txs should be.
# Since we know that we are only sending one transaction at a time,
@@ -139,15 +140,15 @@ class TransactionStore(SQLBaseStore):
# Update the tx id -> pdu id mapping
- values = [
- (transaction_id, destination, pdu[0], pdu[1])
- for pdu in pdu_list
- ]
-
- logger.debug("Inserting: %s", repr(values))
-
- query = TransactionsToPduTable.insert_statement()
- txn.executemany(query, values)
+ # values = [
+ # (transaction_id, destination, pdu[0], pdu[1])
+ # for pdu in pdu_list
+ # ]
+ #
+ # logger.debug("Inserting: %s", repr(values))
+ #
+ # query = TransactionsToPduTable.insert_statement()
+ # txn.executemany(query, values)
return prev_txns
@@ -161,6 +162,7 @@ class TransactionStore(SQLBaseStore):
response_json (str)
"""
return self.runInteraction(
+ "delivered_txn",
self._delivered_txn,
transaction_id, destination, code, response_dict
)
@@ -186,6 +188,7 @@ class TransactionStore(SQLBaseStore):
list: A list of `ReceivedTransactionsTable.EntryType`
"""
return self.runInteraction(
+ "get_transactions_after",
self._get_transactions_after, transaction_id, destination
)
@@ -202,49 +205,6 @@ class TransactionStore(SQLBaseStore):
return ReceivedTransactionsTable.decode_results(txn.fetchall())
- def get_pdus_after_transaction(self, transaction_id, destination):
- """For a given local transaction_id that we sent to a given destination
- home server, return a list of PDUs that were sent to that destination
- after it.
-
- Args:
- txn
- transaction_id (str)
- destination (str)
-
- Returns
- list: A list of PduTuple
- """
- return self.runInteraction(
- self._get_pdus_after_transaction,
- transaction_id, destination
- )
-
- def _get_pdus_after_transaction(self, txn, transaction_id, destination):
-
- # Query that first get's all transaction_ids with an id greater than
- # the one given from the `sent_transactions` table. Then JOIN on this
- # from the `tx->pdu` table to get a list of (pdu_id, origin) that
- # specify the pdus that were sent in those transactions.
- query = (
- "SELECT pdu_id, pdu_origin FROM %(tx_pdu)s as tp "
- "INNER JOIN %(sent_tx)s as st "
- "ON tp.transaction_id = st.transaction_id "
- "AND tp.destination = st.destination "
- "WHERE st.id > ("
- "SELECT id FROM %(sent_tx)s "
- "WHERE transaction_id = ? AND destination = ?"
- ) % {
- "tx_pdu": TransactionsToPduTable.table_name,
- "sent_tx": SentTransactions.table_name,
- }
-
- txn.execute(query, (transaction_id, destination))
-
- pdus = PdusTable.decode_results(txn.fetchall())
-
- return self._get_pdu_tuples(txn, pdus)
-
class ReceivedTransactionsTable(Table):
table_name = "received_transactions"
diff --git a/synapse/types.py b/synapse/types.py
index c51bc8e4f2..649ff2f7d7 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -78,6 +78,11 @@ class DomainSpecificString(
"""Create a structure on the local domain"""
return cls(localpart=localpart, domain=hs.hostname, is_mine=True)
+ @classmethod
+ def create(cls, localpart, domain, hs):
+ is_mine = domain == hs.hostname
+ return cls(localpart=localpart, domain=domain, is_mine=is_mine)
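+    # Sketch: UserID.create("alice", "example.com", hs) builds a UserID whose
+    # is_mine flag reflects whether "example.com" is this homeserver.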
+
class UserID(DomainSpecificString):
"""Structure representing a user ID."""
@@ -94,6 +99,11 @@ class RoomID(DomainSpecificString):
SIGIL = "!"
+class EventID(DomainSpecificString):
+ """Structure representing an event id. """
+ SIGIL = "$"
+
+
class StreamToken(
namedtuple(
"Token",
diff --git a/synapse/util/async.py b/synapse/util/async.py
index 647ea6142c..bf578f8bfb 100644
--- a/synapse/util/async.py
+++ b/synapse/util/async.py
@@ -21,3 +21,10 @@ def sleep(seconds):
d = defer.Deferred()
reactor.callLater(seconds, d.callback, seconds)
return d
+
+
+def run_on_reactor():
+    """ Causes the rest of the calling function to be invoked on the next
+    iteration of the reactor's main loop.
+    """
+ return sleep(0)
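+
+
+# Sketch of intended usage from an @defer.inlineCallbacks function
+# (hypothetical caller, for illustration only):
+#
+#     @defer.inlineCallbacks
+#     def _handle(self):
+#         yield run_on_reactor()  # resume on the next reactor iteration
+#         ...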
diff --git a/synapse/util/jsonobject.py b/synapse/util/jsonobject.py
index c91eb897a8..e79b68f661 100644
--- a/synapse/util/jsonobject.py
+++ b/synapse/util/jsonobject.py
@@ -80,7 +80,7 @@ class JsonEncodedObject(object):
def get_full_dict(self):
d = {
- k: v for (k, v) in self.__dict__.items()
+ k: _encode(v) for (k, v) in self.__dict__.items()
if k in self.valid_keys or k in self.internal_keys
}
d.update(self.unrecognized_keys)