diff --git a/CHANGES.rst b/CHANGES.rst
index 08efbbf244..78c178bafd 100644
--- a/CHANGES.rst
+++ b/CHANGES.rst
@@ -1,3 +1,19 @@
+Changes in synapse 0.4.2 (2014-10-31)
+=====================================
+
+Homeserver:
+ * Fix bugs where we did not correctly notify users of presence updates.
+ * Fix bug where we did not handle sub-second event stream timeouts.
+
+Webclient:
+ * Add ability to click on messages to see JSON.
+ * Add ability to redact messages.
+ * Add ability to view and edit all room state JSON.
+ * Handle incoming redactions.
+ * Improve feedback on errors.
+ * Fix bugs in mobile CSS.
+ * Fix bugs with desktop notifications.
+
Changes in synapse 0.4.1 (2014-10-17)
=====================================
Webclient:
diff --git a/VERSION b/VERSION
index 267577d47e..2b7c5ae018 100644
--- a/VERSION
+++ b/VERSION
@@ -1 +1 @@
-0.4.1
+0.4.2
diff --git a/demo/start.sh b/demo/start.sh
index fc6cd6303f..0530f0a26e 100755
--- a/demo/start.sh
+++ b/demo/start.sh
@@ -32,7 +32,7 @@ for port in 8080 8081 8082; do
-D --pid-file "$DIR/$port.pid" \
--manhole $((port + 1000)) \
--tls-dh-params-path "demo/demo.tls.dh" \
- $PARAMS
+ $PARAMS $SYNAPSE_PARAMS
python -m synapse.app.homeserver \
--config-path "demo/etc/$port.config" \
diff --git a/synapse/__init__.py b/synapse/__init__.py
index 7067188c5b..23ae5f003f 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -16,4 +16,4 @@
""" This is a reference implementation of a synapse home server.
"""
-__version__ = "0.4.1"
+__version__ = "0.4.2"
diff --git a/synapse/api/__init__.py b/synapse/api/__init__.py
index 9bff9ec169..f9811bfa04 100644
--- a/synapse/api/__init__.py
+++ b/synapse/api/__init__.py
@@ -12,4 +12,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index e1b1823cd7..c684265101 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -21,6 +21,8 @@ from synapse.api.constants import Membership, JoinRules
from synapse.api.errors import AuthError, StoreError, Codes, SynapseError
from synapse.api.events.room import (
RoomMemberEvent, RoomPowerLevelsEvent, RoomRedactionEvent,
+ RoomJoinRulesEvent, RoomOpsPowerLevelsEvent, InviteJoinEvent,
+ RoomCreateEvent,
)
from synapse.util.logutils import log_function
@@ -47,42 +49,60 @@ class Auth(object):
"""
try:
if hasattr(event, "room_id"):
+ if event.old_state_events is None:
+ # Oh, we don't know what the state of the room was, so we
+ # are trusting that this is allowed (at least for now)
+ defer.returnValue(True)
+
+ if hasattr(event, "outlier") and event.outlier is True:
+ # TODO (erikj): Auth for outliers is done differently.
+ defer.returnValue(True)
+
is_state = hasattr(event, "state_key")
+ if event.type == RoomCreateEvent.TYPE:
+ # FIXME
+ defer.returnValue(True)
+
if event.type == RoomMemberEvent.TYPE:
- yield self._can_replace_state(event)
- allowed = yield self.is_membership_change_allowed(event)
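+                # These checks are now synchronous: state comes from the event.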
+ self._can_replace_state(event)
+ allowed = self.is_membership_change_allowed(event)
+ if allowed:
+ logger.debug("Allowing! %s", event)
+ else:
+ logger.debug("Denying! %s", event)
defer.returnValue(allowed)
return
- self._check_joined_room(
- member=snapshot.membership_state,
- user_id=snapshot.user_id,
- room_id=snapshot.room_id,
- )
+ if not event.type == InviteJoinEvent.TYPE:
+ self.check_event_sender_in_room(event)
if is_state:
# TODO (erikj): This really only should be called for *new*
# state
yield self._can_add_state(event)
- yield self._can_replace_state(event)
+ self._can_replace_state(event)
else:
yield self._can_send_event(event)
if event.type == RoomPowerLevelsEvent.TYPE:
- yield self._check_power_levels(event)
+ self._check_power_levels(event)
if event.type == RoomRedactionEvent.TYPE:
- yield self._check_redaction(event)
+ self._check_redaction(event)
+
+ logger.debug("Allowing! %s", event)
defer.returnValue(True)
else:
raise AuthError(500, "Unknown event: %s" % event)
except AuthError as e:
logger.info("Event auth check failed on event %s with msg: %s",
event, e.msg)
+ logger.info("Denying! %s", event)
if raises:
raise e
+
defer.returnValue(False)
@defer.inlineCallbacks
@@ -98,45 +118,72 @@ class Auth(object):
pass
defer.returnValue(None)
+ def check_event_sender_in_room(self, event):
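+        # Look up the sender's m.room.member event in the event's state.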
+ key = (RoomMemberEvent.TYPE, event.user_id, )
+ member_event = event.state_events.get(key)
+
+ return self._check_joined_room(
+ member_event,
+ event.user_id,
+ event.room_id
+ )
+
def _check_joined_room(self, member, user_id, room_id):
if not member or member.membership != Membership.JOIN:
raise AuthError(403, "User %s not in room %s (%s)" % (
user_id, room_id, repr(member)
))
- @defer.inlineCallbacks
+ @log_function
def is_membership_change_allowed(self, event):
target_user_id = event.state_key
- # does this room even exist
- room = yield self.store.get_room(event.room_id)
- if not room:
- raise AuthError(403, "Room does not exist")
-
# get info about the caller
- try:
- caller = yield self.store.get_room_member(
- user_id=event.user_id,
- room_id=event.room_id)
- except:
- caller = None
+ key = (RoomMemberEvent.TYPE, event.user_id, )
+ caller = event.old_state_events.get(key)
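+        # old_state_events holds the room state before this event, so no
+        # store lookup is needed.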
+
caller_in_room = caller and caller.membership == "join"
# get info about the target
- try:
- target = yield self.store.get_room_member(
- user_id=target_user_id,
- room_id=event.room_id)
- except:
- target = None
+ key = (RoomMemberEvent.TYPE, target_user_id, )
+ target = event.old_state_events.get(key)
+
target_in_room = target and target.membership == "join"
membership = event.content["membership"]
- join_rule = yield self.store.get_room_join_rule(event.room_id)
- if not join_rule:
+ key = (RoomJoinRulesEvent.TYPE, "", )
+ join_rule_event = event.old_state_events.get(key)
+ if join_rule_event:
+ join_rule = join_rule_event.content.get(
+ "join_rule", JoinRules.INVITE
+ )
+ else:
join_rule = JoinRules.INVITE
+ user_level = self._get_power_level_from_event_state(
+ event,
+ event.user_id,
+ )
+
+ ban_level, kick_level, redact_level = (
+ self._get_ops_level_from_event_state(
+ event
+ )
+ )
+
+ logger.debug(
+ "is_membership_change_allowed: %s",
+ {
+ "caller_in_room": caller_in_room,
+ "target_in_room": target_in_room,
+ "membership": membership,
+ "join_rule": join_rule,
+ "target_user_id": target_user_id,
+ "event.user_id": event.user_id,
+ }
+ )
+
if Membership.INVITE == membership:
# TODO (erikj): We should probably handle this more intelligently
# PRIVATE join rules.
@@ -153,13 +200,10 @@ class Auth(object):
# joined: It's a NOOP
if event.user_id != target_user_id:
raise AuthError(403, "Cannot force another user to join.")
- elif join_rule == JoinRules.PUBLIC or room.is_public:
+ elif join_rule == JoinRules.PUBLIC:
pass
elif join_rule == JoinRules.INVITE:
- if (
- not caller or caller.membership not in
- [Membership.INVITE, Membership.JOIN]
- ):
+ if not caller_in_room:
raise AuthError(403, "You are not invited to this room.")
else:
# TODO (erikj): may_join list
@@ -171,29 +215,16 @@ class Auth(object):
if not caller_in_room: # trying to leave a room you aren't joined
raise AuthError(403, "You are not in room %s." % event.room_id)
elif target_user_id != event.user_id:
- user_level = yield self.store.get_power_level(
- event.room_id,
- event.user_id,
- )
- _, kick_level, _ = yield self.store.get_ops_levels(event.room_id)
-
if kick_level:
kick_level = int(kick_level)
else:
- kick_level = 50
+ kick_level = 50 # FIXME (erikj): What should we do here?
if user_level < kick_level:
raise AuthError(
403, "You cannot kick user %s." % target_user_id
)
elif Membership.BAN == membership:
- user_level = yield self.store.get_power_level(
- event.room_id,
- event.user_id,
- )
-
- ban_level, _, _ = yield self.store.get_ops_levels(event.room_id)
-
if ban_level:
ban_level = int(ban_level)
else:
@@ -204,7 +235,30 @@ class Auth(object):
else:
raise AuthError(500, "Unknown membership %s" % membership)
- defer.returnValue(True)
+ return True
+
+ def _get_power_level_from_event_state(self, event, user_id):
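+        # Take the user's entry from the power levels state event, falling
+        # back to its "default" key (0 if unset).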
+ key = (RoomPowerLevelsEvent.TYPE, "", )
+ power_level_event = event.old_state_events.get(key)
+ level = None
+ if power_level_event:
+ level = power_level_event.content.get(user_id)
+ if not level:
+ level = power_level_event.content.get("default", 0)
+
+ return level
+
+ def _get_ops_level_from_event_state(self, event):
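+        # Pull the ban/kick/redact thresholds from the ops power levels
+        # state event, if one exists.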
+ key = (RoomOpsPowerLevelsEvent.TYPE, "", )
+ ops_event = event.old_state_events.get(key)
+
+ if ops_event:
+ return (
+ ops_event.content.get("ban_level"),
+ ops_event.content.get("kick_level"),
+ ops_event.content.get("redact_level"),
+ )
+ return None, None, None,
@defer.inlineCallbacks
def get_user_by_req(self, request):
@@ -282,8 +336,8 @@ class Auth(object):
else:
send_level = 0
- user_level = yield self.store.get_power_level(
- event.room_id,
+ user_level = self._get_power_level_from_event_state(
+ event,
event.user_id,
)
@@ -308,8 +362,8 @@ class Auth(object):
add_level = int(add_level)
- user_level = yield self.store.get_power_level(
- event.room_id,
+ user_level = self._get_power_level_from_event_state(
+ event,
event.user_id,
)
@@ -322,19 +376,9 @@ class Auth(object):
defer.returnValue(True)
- @defer.inlineCallbacks
def _can_replace_state(self, event):
- current_state = yield self.store.get_current_state(
- event.room_id,
- event.type,
- event.state_key,
- )
-
- if current_state:
- current_state = current_state[0]
-
- user_level = yield self.store.get_power_level(
- event.room_id,
+ user_level = self._get_power_level_from_event_state(
+ event,
event.user_id,
)
@@ -346,6 +390,10 @@ class Auth(object):
logger.debug(
"Checking power level for %s, %s", event.user_id, user_level
)
+
+ key = (event.type, event.state_key, )
+ current_state = event.old_state_events.get(key)
+
if current_state and hasattr(current_state, "required_power_level"):
req = current_state.required_power_level
@@ -356,10 +404,9 @@ class Auth(object):
"You don't have permission to change that state"
)
- @defer.inlineCallbacks
def _check_redaction(self, event):
- user_level = yield self.store.get_power_level(
- event.room_id,
+ user_level = self._get_power_level_from_event_state(
+ event,
event.user_id,
)
@@ -368,7 +415,9 @@ class Auth(object):
else:
user_level = 0
- _, _, redact_level = yield self.store.get_ops_levels(event.room_id)
+ _, _, redact_level = self._get_ops_level_from_event_state(
+ event
+ )
if not redact_level:
redact_level = 50
@@ -379,7 +428,6 @@ class Auth(object):
"You don't have permission to redact events"
)
- @defer.inlineCallbacks
def _check_power_levels(self, event):
for k, v in event.content.items():
if k == "default":
@@ -399,19 +447,16 @@ class Auth(object):
except:
raise SynapseError(400, "Not a valid power level: %s" % (v,))
- current_state = yield self.store.get_current_state(
- event.room_id,
- event.type,
- event.state_key,
- )
+ key = (event.type, event.state_key, )
+ current_state = event.old_state_events.get(key)
if not current_state:
return
else:
current_state = current_state[0]
- user_level = yield self.store.get_power_level(
- event.room_id,
+ user_level = self._get_power_level_from_event_state(
+ event,
event.user_id,
)
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index 618d3d7577..3cafff0e32 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -58,4 +58,4 @@ class LoginType(object):
EMAIL_CODE = u"m.login.email.code"
EMAIL_URL = u"m.login.email.url"
EMAIL_IDENTITY = u"m.login.email.identity"
- RECAPTCHA = u"m.login.recaptcha"
\ No newline at end of file
+ RECAPTCHA = u"m.login.recaptcha"
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index 6d7d499fea..38ccb4f9d1 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -54,7 +54,7 @@ class SynapseError(CodeMessageException):
"""Constructs a synapse error.
Args:
- code (int): The integer error code (typically an HTTP response code)
+ code (int): The integer error code (an HTTP response code)
msg (str): The human-readable error message.
err (str): The error code e.g 'M_FORBIDDEN'
"""
@@ -67,6 +67,7 @@ class SynapseError(CodeMessageException):
self.errcode,
)
+
class RoomError(SynapseError):
"""An error raised when a room event fails."""
pass
@@ -117,6 +118,7 @@ class InvalidCaptchaError(SynapseError):
error_url=self.error_url,
)
+
class LimitExceededError(SynapseError):
"""A client has sent too many requests and is being throttled.
"""
diff --git a/synapse/api/events/__init__.py b/synapse/api/events/__init__.py
index a5a55742e0..b855811b98 100644
--- a/synapse/api/events/__init__.py
+++ b/synapse/api/events/__init__.py
@@ -71,7 +71,9 @@ class SynapseEvent(JsonEncodedObject):
"outlier",
"power_level",
"redacted",
- "prev_pdus",
+ "prev_events",
+ "hashes",
+ "signatures",
]
required_keys = [
diff --git a/synapse/api/events/factory.py b/synapse/api/events/factory.py
index 74d0ef77f4..9134c82eff 100644
--- a/synapse/api/events/factory.py
+++ b/synapse/api/events/factory.py
@@ -21,6 +21,8 @@ from synapse.api.events.room import (
RoomRedactionEvent,
)
+from synapse.types import EventID
+
from synapse.util.stringutils import random_string
@@ -51,12 +53,26 @@ class EventFactory(object):
self.clock = hs.get_clock()
self.hs = hs
+ self.event_id_count = 0
+
+ def create_event_id(self):
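+        # Timestamp + per-factory counter + random suffix keeps locally
+        # generated event IDs unique.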
+ i = str(self.event_id_count)
+ self.event_id_count += 1
+
+ local_part = str(int(self.clock.time())) + i + random_string(5)
+
+ e_id = EventID.create_local(local_part, self.hs)
+
+ return e_id.to_string()
+
def create_event(self, etype=None, **kwargs):
kwargs["type"] = etype
if "event_id" not in kwargs:
- kwargs["event_id"] = "%s@%s" % (
- random_string(10), self.hs.hostname
- )
+ kwargs["event_id"] = self.create_event_id()
+ kwargs["origin"] = self.hs.hostname
+ else:
+ ev_id = self.hs.parse_eventid(kwargs["event_id"])
+ kwargs["origin"] = ev_id.domain
if "origin_server_ts" not in kwargs:
kwargs["origin_server_ts"] = int(self.clock.time_msec())
diff --git a/synapse/api/events/utils.py b/synapse/api/events/utils.py
index 7fdf45a264..31601fd3a9 100644
--- a/synapse/api/events/utils.py
+++ b/synapse/api/events/utils.py
@@ -32,7 +32,7 @@ def prune_event(event):
def prune_pdu(pdu):
"""Removes keys that contain unrestricted and non-essential data from a PDU
"""
- return _prune_event_or_pdu(pdu.pdu_type, pdu)
+ return _prune_event_or_pdu(pdu.type, pdu)
def _prune_event_or_pdu(event_type, event):
# Remove all extraneous fields.
diff --git a/synapse/app/__init__.py b/synapse/app/__init__.py
index 9bff9ec169..f9811bfa04 100644
--- a/synapse/app/__init__.py
+++ b/synapse/app/__init__.py
@@ -12,4 +12,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 6394bc27d1..a20376b9d6 100755
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -233,7 +233,10 @@ def setup():
f.namespace['hs'] = hs
reactor.listenTCP(config.manhole, f, interface='127.0.0.1')
- hs.start_listening(config.bind_port, config.unsecure_port)
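+    # With --no-tls we pass None for the TLS port so that only the
+    # unsecure port is bound.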
+ bind_port = config.bind_port
+ if config.no_tls:
+ bind_port = None
+ hs.start_listening(bind_port, config.unsecure_port)
if config.daemonize:
print config.pid_file
diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index b3aeff327c..8ebd2eba4a 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -116,18 +116,25 @@ class Config(object):
config = {}
for key, value in vars(args).items():
if (key not in set(["config_path", "generate_config"])
- and value is not None):
+ and value is not None):
config[key] = value
with open(config_args.config_path, "w") as config_file:
# TODO(paul) it would be lovely if we wrote out vim- and emacs-
# style mode markers into the file, to hint to people that
# this is a YAML file.
yaml.dump(config, config_file, default_flow_style=False)
- print "A config file has been generated in %s for server name '%s') with corresponding SSL keys and self-signed certificates. Please review this file and customise it to your needs." % (config_args.config_path, config['server_name'])
- print "If this server name is incorrect, you will need to regenerate the SSL certificates"
+ print (
+ "A config file has been generated in %s for server name"
+ " '%s' with corresponding SSL keys and self-signed"
+ " certificates. Please review this file and customise it to"
+ " your needs."
+ ) % (
+ config_args.config_path, config['server_name']
+ )
+ print (
+ "If this server name is incorrect, you will need to regenerate"
+ " the SSL certificates"
+ )
sys.exit(0)
return cls(args)
-
-
-
diff --git a/synapse/config/database.py b/synapse/config/database.py
index 460445f15d..0aac8c8382 100644
--- a/synapse/config/database.py
+++ b/synapse/config/database.py
@@ -16,6 +16,7 @@
from ._base import Config
import os
+
class DatabaseConfig(Config):
def __init__(self, args):
super(DatabaseConfig, self).__init__(args)
@@ -34,4 +35,3 @@ class DatabaseConfig(Config):
def generate_config(cls, args, config_dir_path):
super(DatabaseConfig, cls).generate_config(args, config_dir_path)
args.database_path = os.path.abspath(args.database_path)
-
diff --git a/synapse/config/email.py b/synapse/config/email.py
index 9bcc5a8fea..6bab133224 100644
--- a/synapse/config/email.py
+++ b/synapse/config/email.py
@@ -35,5 +35,8 @@ class EmailConfig(Config):
email_group.add_argument(
"--email-smtp-server",
default="",
- help="The SMTP server to send emails from (e.g. for password resets)."
- )
\ No newline at end of file
+ help=(
+ "The SMTP server to send emails from (e.g. for password"
+ " resets)."
+ )
+ )
diff --git a/synapse/config/logger.py b/synapse/config/logger.py
index 56cd095433..05611d02f7 100644
--- a/synapse/config/logger.py
+++ b/synapse/config/logger.py
@@ -19,6 +19,7 @@ from twisted.python.log import PythonLoggingObserver
import logging
import logging.config
+
class LoggingConfig(Config):
def __init__(self, args):
super(LoggingConfig, self).__init__(args)
@@ -51,7 +52,7 @@ class LoggingConfig(Config):
level = logging.INFO
if self.verbosity:
- level = logging.DEBUG
+ level = logging.DEBUG
# FIXME: we need a logging.WARN for a -q quiet option
diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py
index f126782b8d..fb63ed7d9b 100644
--- a/synapse/config/ratelimiting.py
+++ b/synapse/config/ratelimiting.py
@@ -14,6 +14,7 @@
from ._base import Config
+
class RatelimitConfig(Config):
def __init__(self, args):
diff --git a/synapse/config/repository.py b/synapse/config/repository.py
index b71d30227c..743bc26474 100644
--- a/synapse/config/repository.py
+++ b/synapse/config/repository.py
@@ -15,6 +15,7 @@
from ._base import Config
+
class ContentRepositoryConfig(Config):
def __init__(self, args):
super(ContentRepositoryConfig, self).__init__(args)
diff --git a/synapse/config/server.py b/synapse/config/server.py
index 086937044f..814a4c349b 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -30,11 +30,12 @@ class ServerConfig(Config):
self.pid_file = self.abspath(args.pid_file)
self.webclient = True
self.manhole = args.manhole
+ self.no_tls = args.no_tls
if not args.content_addr:
host = args.server_name
if ':' not in host:
- host = "%s:%d" % (host, args.bind_port)
+ host = "%s:%d" % (host, args.bind_port)
args.content_addr = "https://%s" % (host,)
self.content_addr = args.content_addr
@@ -67,6 +68,8 @@ class ServerConfig(Config):
server_group.add_argument("--content-addr", default=None,
help="The host and scheme to use for the "
"content repository")
+ server_group.add_argument("--no-tls", action='store_true',
+ help="Don't bind to the https port.")
def read_signing_key(self, signing_key_path):
signing_keys = self.read_file(signing_key_path, "signing_key")
diff --git a/synapse/config/tls.py b/synapse/config/tls.py
index 72d5518a89..3600c3ea9e 100644
--- a/synapse/config/tls.py
+++ b/synapse/config/tls.py
@@ -19,7 +19,7 @@ from OpenSSL import crypto
import subprocess
import os
-GENERATE_DH_PARAMS=False
+GENERATE_DH_PARAMS = False
class TlsConfig(Config):
diff --git a/synapse/config/voip.py b/synapse/config/voip.py
index 3a51664f46..06675966ce 100644
--- a/synapse/config/voip.py
+++ b/synapse/config/voip.py
@@ -33,7 +33,10 @@ class VoipConfig(Config):
)
group.add_argument(
"--turn-shared-secret", type=str, default=None,
- help="The shared secret used to compute passwords for the TURN server"
+ help=(
+ "The shared secret used to compute passwords for the TURN"
+ " server"
+ )
)
group.add_argument(
"--turn-user-lifetime", type=int, default=(1000 * 60 * 60),
diff --git a/synapse/crypto/__init__.py b/synapse/crypto/__init__.py
index 9bff9ec169..f9811bfa04 100644
--- a/synapse/crypto/__init__.py
+++ b/synapse/crypto/__init__.py
@@ -12,4 +12,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
diff --git a/synapse/crypto/context_factory.py b/synapse/crypto/context_factory.py
index f402c795bb..3143322d9c 100644
--- a/synapse/crypto/context_factory.py
+++ b/synapse/crypto/context_factory.py
@@ -20,6 +20,7 @@ import logging
logger = logging.getLogger(__name__)
+
class ServerContextFactory(ssl.ContextFactory):
"""Factory for PyOpenSSL SSL contexts that are used to handle incoming
connections and to make connections to remote servers."""
@@ -43,4 +44,3 @@ class ServerContextFactory(ssl.ContextFactory):
def getContext(self):
return self._context
-
diff --git a/synapse/crypto/event_signing.py b/synapse/crypto/event_signing.py
index 61edd2c6f9..0e8bc7eb6c 100644
--- a/synapse/crypto/event_signing.py
+++ b/synapse/crypto/event_signing.py
@@ -16,11 +16,12 @@
from synapse.federation.units import Pdu
-from synapse.api.events.utils import prune_pdu
+from synapse.api.events.utils import prune_pdu, prune_event
from syutil.jsonutil import encode_canonical_json
from syutil.base64util import encode_base64, decode_base64
from syutil.crypto.jsonsign import sign_json, verify_signed_json
+import copy
import hashlib
import logging
@@ -69,6 +70,16 @@ def compute_pdu_event_reference_hash(pdu, hash_algorithm=hashlib.sha256):
return (hashed.name, hashed.digest())
+def compute_event_reference_hash(event, hash_algorithm=hashlib.sha256):
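+    # Hash the redacted, signature-stripped event JSON so the reference
+    # stays stable across redaction.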
+ tmp_event = copy.deepcopy(event)
+ tmp_event = prune_event(tmp_event)
+ event_json = tmp_event.get_dict()
+ event_json.pop("signatures", None)
+ event_json_bytes = encode_canonical_json(event_json)
+ hashed = hash_algorithm(event_json_bytes)
+ return (hashed.name, hashed.digest())
+
+
def sign_event_pdu(pdu, signature_name, signing_key):
tmp_pdu = Pdu(**pdu.get_dict())
tmp_pdu = prune_pdu(tmp_pdu)
@@ -83,3 +94,25 @@ def verify_signed_event_pdu(pdu, signature_name, verify_key):
tmp_pdu = prune_pdu(tmp_pdu)
pdu_json = tmp_pdu.get_dict()
verify_signed_json(pdu_json, signature_name, verify_key)
+
+
+def add_hashes_and_signatures(event, signature_name, signing_key,
+ hash_algorithm=hashlib.sha256):
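+    # Sign the redacted form of the event so the signature remains valid
+    # even if the event is later redacted.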
+ tmp_event = copy.deepcopy(event)
+ tmp_event = prune_event(tmp_event)
+ redact_json = tmp_event.get_dict()
+ redact_json.pop("signatures", None)
+ redact_json = sign_json(redact_json, signature_name, signing_key)
+ event.signatures = redact_json["signatures"]
+
+ event_json = event.get_full_dict()
+    # TODO: We need to sign the JSON that is going out via federation.
+ event_json.pop("age_ts", None)
+ event_json.pop("unsigned", None)
+ event_json.pop("signatures", None)
+ event_json.pop("hashes", None)
+ event_json_bytes = encode_canonical_json(event_json)
+ hashed = hash_algorithm(event_json_bytes)
+ if not hasattr(event, "hashes"):
+ event.hashes = {}
+ event.hashes[hashed.name] = encode_base64(hashed.digest())
diff --git a/synapse/crypto/keyclient.py b/synapse/crypto/keyclient.py
index 7cfec5148e..5191be4570 100644
--- a/synapse/crypto/keyclient.py
+++ b/synapse/crypto/keyclient.py
@@ -98,4 +98,3 @@ class SynapseKeyClientProtocol(HTTPClient):
class SynapseKeyClientFactory(Factory):
protocol = SynapseKeyClientProtocol
-
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index 2440d604c3..694aed3a7d 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -44,7 +44,7 @@ class Keyring(object):
raise SynapseError(
400,
"Not signed with a supported algorithm",
- Codes.UNAUTHORIZED,
+ Codes.UNAUTHORIZED,
)
try:
verify_key = yield self.get_server_verify_key(server_name, key_ids)
@@ -100,7 +100,7 @@ class Keyring(object):
)
if ("signatures" not in response
- or server_name not in response["signatures"]):
+ or server_name not in response["signatures"]):
raise ValueError("Key response not signed by remote server")
if "tls_certificate" not in response:
diff --git a/synapse/federation/pdu_codec.py b/synapse/federation/pdu_codec.py
index 991aae2a56..5ec97a698e 100644
--- a/synapse/federation/pdu_codec.py
+++ b/synapse/federation/pdu_codec.py
@@ -17,22 +17,11 @@ from .units import Pdu
from synapse.crypto.event_signing import (
add_event_pdu_content_hash, sign_event_pdu
)
+from synapse.types import EventID
import copy
-def decode_event_id(event_id, server_name):
- parts = event_id.split("@")
- if len(parts) < 2:
- return (event_id, server_name)
- else:
- return (parts[0], "".join(parts[1:]))
-
-
-def encode_event_id(pdu_id, origin):
- return "%s@%s" % (pdu_id, origin)
-
-
class PduCodec(object):
def __init__(self, hs):
@@ -40,30 +29,18 @@ class PduCodec(object):
self.server_name = hs.hostname
self.event_factory = hs.get_event_factory()
self.clock = hs.get_clock()
+ self.hs = hs
def event_from_pdu(self, pdu):
kwargs = {}
- kwargs["event_id"] = encode_event_id(pdu.pdu_id, pdu.origin)
- kwargs["room_id"] = pdu.context
- kwargs["etype"] = pdu.pdu_type
- kwargs["prev_pdus"] = pdu.prev_pdus
-
- if hasattr(pdu, "prev_state_id") and hasattr(pdu, "prev_state_origin"):
- kwargs["prev_state"] = encode_event_id(
- pdu.prev_state_id, pdu.prev_state_origin
- )
+ kwargs["etype"] = pdu.type
kwargs.update({
k: v
for k, v in pdu.get_full_dict().items()
if k not in [
- "pdu_id",
- "context",
- "pdu_type",
- "prev_pdus",
- "prev_state_id",
- "prev_state_origin",
+ "type",
]
})
@@ -72,27 +49,12 @@ class PduCodec(object):
def pdu_from_event(self, event):
d = event.get_full_dict()
- d["pdu_id"], d["origin"] = decode_event_id(
- event.event_id, self.server_name
- )
- d["context"] = event.room_id
- d["pdu_type"] = event.type
-
- if hasattr(event, "prev_pdus"):
- d["prev_pdus"] = event.prev_pdus
-
- if hasattr(event, "prev_state"):
- d["prev_state_id"], d["prev_state_origin"] = (
- decode_event_id(event.prev_state, self.server_name)
- )
-
if hasattr(event, "state_key"):
d["is_state"] = True
kwargs = copy.deepcopy(event.unrecognized_keys)
kwargs.update({
k: v for k, v in d.items()
- if k not in ["event_id", "room_id", "type"]
})
if "origin_server_ts" not in kwargs:
diff --git a/synapse/federation/persistence.py b/synapse/federation/persistence.py
index 7043fcc504..b04fbb4177 100644
--- a/synapse/federation/persistence.py
+++ b/synapse/federation/persistence.py
@@ -32,76 +32,6 @@ import logging
logger = logging.getLogger(__name__)
-class PduActions(object):
- """ Defines persistence actions that relate to handling PDUs.
- """
-
- def __init__(self, datastore):
- self.store = datastore
-
- @log_function
- def mark_as_processed(self, pdu):
- """ Persist the fact that we have fully processed the given `Pdu`
-
- Returns:
- Deferred
- """
- return self.store.mark_pdu_as_processed(pdu.pdu_id, pdu.origin)
-
- @defer.inlineCallbacks
- @log_function
- def after_transaction(self, transaction_id, destination, origin):
- """ Returns all `Pdu`s that we sent to the given remote home server
- after a given transaction id.
-
- Returns:
- Deferred: Results in a list of `Pdu`s
- """
- results = yield self.store.get_pdus_after_transaction(
- transaction_id,
- destination
- )
-
- defer.returnValue([Pdu.from_pdu_tuple(p) for p in results])
-
- @defer.inlineCallbacks
- @log_function
- def get_all_pdus_from_context(self, context):
- results = yield self.store.get_all_pdus_from_context(context)
- defer.returnValue([Pdu.from_pdu_tuple(p) for p in results])
-
- @defer.inlineCallbacks
- @log_function
- def backfill(self, context, pdu_list, limit):
- """ For a given list of PDU id and origins return the proceeding
- `limit` `Pdu`s in the given `context`.
-
- Returns:
- Deferred: Results in a list of `Pdu`s.
- """
- results = yield self.store.get_backfill(
- context, pdu_list, limit
- )
-
- defer.returnValue([Pdu.from_pdu_tuple(p) for p in results])
-
- @log_function
- def is_new(self, pdu):
- """ When we receive a `Pdu` from a remote home server, we want to
- figure out whether it is `new`, i.e. it is not some historic PDU that
- we haven't seen simply because we haven't backfilled back that far.
-
- Returns:
- Deferred: Results in a `bool`
- """
- return self.store.is_pdu_new(
- pdu_id=pdu.pdu_id,
- origin=pdu.origin,
- context=pdu.context,
- depth=pdu.depth
- )
-
-
class TransactionActions(object):
""" Defines persistence actions that relate to handling Transactions.
"""
@@ -158,7 +88,6 @@ class TransactionActions(object):
transaction.transaction_id,
transaction.destination,
transaction.origin_server_ts,
- [(p["pdu_id"], p["origin"]) for p in transaction.pdus]
)
@log_function
diff --git a/synapse/federation/replication.py b/synapse/federation/replication.py
index 4a9414c1d4..838e660a46 100644
--- a/synapse/federation/replication.py
+++ b/synapse/federation/replication.py
@@ -21,7 +21,7 @@ from twisted.internet import defer
from .units import Transaction, Pdu, Edu
-from .persistence import PduActions, TransactionActions
+from .persistence import TransactionActions
from synapse.util.logutils import log_function
@@ -57,7 +57,7 @@ class ReplicationLayer(object):
self.transport_layer.register_request_handler(self)
self.store = hs.get_datastore()
- self.pdu_actions = PduActions(self.store)
+ # self.pdu_actions = PduActions(self.store)
self.transaction_actions = TransactionActions(self.store)
self._transaction_queue = _TransactionQueue(
@@ -106,20 +106,11 @@ class ReplicationLayer(object):
self.query_handlers[query_type] = handler
- @defer.inlineCallbacks
@log_function
def send_pdu(self, pdu):
"""Informs the replication layer about a new PDU generated within the
home server that should be transmitted to others.
- This will fill out various attributes on the PDU object, e.g. the
- `prev_pdus` key.
-
- *Note:* The home server should always call `send_pdu` even if it knows
- that it does not need to be replicated to other home servers. This is
- in case e.g. someone else joins via a remote home server and then
- backfills.
-
TODO: Figure out when we should actually resolve the deferred.
Args:
@@ -132,18 +123,12 @@ class ReplicationLayer(object):
order = self._order
self._order += 1
- logger.debug("[%s] Persisting PDU", pdu.pdu_id)
-
- # Save *before* trying to send
- yield self.store.persist_event(pdu=pdu)
-
- logger.debug("[%s] Persisted PDU", pdu.pdu_id)
- logger.debug("[%s] transaction_layer.enqueue_pdu... ", pdu.pdu_id)
+ logger.debug("[%s] transaction_layer.enqueue_pdu... ", pdu.event_id)
# TODO, add errback, etc.
self._transaction_queue.enqueue_pdu(pdu, order)
- logger.debug("[%s] transaction_layer.enqueue_pdu... done", pdu.pdu_id)
+ logger.debug("[%s] transaction_layer.enqueue_pdu... done", pdu.event_id)
@log_function
def send_edu(self, destination, edu_type, content):
@@ -181,7 +166,7 @@ class ReplicationLayer(object):
@defer.inlineCallbacks
@log_function
- def backfill(self, dest, context, limit):
+ def backfill(self, dest, context, limit, extremities):
"""Requests some more historic PDUs for the given context from the
given destination server.
@@ -189,12 +174,12 @@ class ReplicationLayer(object):
dest (str): The remote home server to ask.
context (str): The context to backfill.
limit (int): The maximum number of PDUs to return.
+            extremities (list): List of PDU ids and origins of the first
+                PDUs we have seen from the context.
Returns:
Deferred: Results in the received PDUs.
"""
- extremities = yield self.store.get_oldest_pdus_in_context(context)
-
logger.debug("backfill extrem=%s", extremities)
-        # If there are no extremeties then we've (probably) reached the start.
+        # If there are no extremities then we've (probably) reached the start.
@@ -216,7 +201,7 @@ class ReplicationLayer(object):
@defer.inlineCallbacks
@log_function
- def get_pdu(self, destination, pdu_origin, pdu_id, outlier=False):
+ def get_pdu(self, destination, event_id, outlier=False):
"""Requests the PDU with given origin and ID from the remote home
server.
@@ -225,7 +210,7 @@ class ReplicationLayer(object):
Args:
destination (str): Which home server to query
-            pdu_origin (str): The home server that originally sent the pdu.
-            pdu_id (str)
+            event_id (str)
outlier (bool): Indicates whether the PDU is an `outlier`, i.e. if
-                it's from an arbitary point in the context as opposed to part
+                it's from an arbitrary point in the context as opposed to part
of the current block of PDUs. Defaults to `False`
@@ -234,8 +219,9 @@ class ReplicationLayer(object):
Deferred: Results in the requested PDU.
"""
- transaction_data = yield self.transport_layer.get_pdu(
- destination, pdu_origin, pdu_id)
+ transaction_data = yield self.transport_layer.get_event(
+ destination, event_id
+ )
transaction = Transaction(**transaction_data)
@@ -244,13 +230,13 @@ class ReplicationLayer(object):
pdu = None
if pdu_list:
pdu = pdu_list[0]
- yield self._handle_new_pdu(pdu)
+ yield self._handle_new_pdu(destination, pdu)
defer.returnValue(pdu)
@defer.inlineCallbacks
@log_function
- def get_state_for_context(self, destination, context):
+ def get_state_for_context(self, destination, context, event_id=None):
"""Requests all of the `current` state PDUs for a given context from
a remote home server.
@@ -263,29 +249,32 @@ class ReplicationLayer(object):
"""
transaction_data = yield self.transport_layer.get_context_state(
- destination, context)
+ destination,
+ context,
+ event_id=event_id,
+ )
transaction = Transaction(**transaction_data)
pdus = [Pdu(outlier=True, **p) for p in transaction.pdus]
for pdu in pdus:
- yield self._handle_new_pdu(pdu)
+ yield self._handle_new_pdu(destination, pdu)
defer.returnValue(pdus)
@defer.inlineCallbacks
@log_function
def on_context_pdus_request(self, context):
- pdus = yield self.pdu_actions.get_all_pdus_from_context(
- context
+ raise NotImplementedError(
+ "on_context_pdus_request is a security violation"
)
- defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict()))
@defer.inlineCallbacks
@log_function
def on_backfill_request(self, context, versions, limit):
-
- pdus = yield self.pdu_actions.backfill(context, versions, limit)
+ pdus = yield self.handler.on_backfill_request(
+ context, versions, limit
+ )
defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict()))
@@ -319,7 +308,7 @@ class ReplicationLayer(object):
dl = []
for pdu in pdu_list:
- dl.append(self._handle_new_pdu(pdu))
+ dl.append(self._handle_new_pdu(transaction.origin, pdu))
if hasattr(transaction, "edus"):
for edu in [Edu(**x) for x in transaction.edus]:
@@ -351,20 +340,26 @@ class ReplicationLayer(object):
@defer.inlineCallbacks
@log_function
- def on_context_state_request(self, context):
- results = yield self.store.get_current_state_for_context(
- context
- )
-
- logger.debug("Context returning %d results", len(results))
+ def on_context_state_request(self, context, event_id):
+ if event_id:
+ pdus = yield self.handler.get_state_for_pdu(
+ event_id
+ )
+ else:
+ raise NotImplementedError("Specify an event")
+ # results = yield self.store.get_current_state_for_context(
+ # context
+ # )
+ # pdus = [Pdu.from_pdu_tuple(p) for p in results]
+ #
+ # logger.debug("Context returning %d results", len(pdus))
- pdus = [Pdu.from_pdu_tuple(p) for p in results]
defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict()))
@defer.inlineCallbacks
@log_function
- def on_pdu_request(self, pdu_origin, pdu_id):
- pdu = yield self._get_persisted_pdu(pdu_id, pdu_origin)
+ def on_pdu_request(self, event_id):
+ pdu = yield self._get_persisted_pdu(event_id)
if pdu:
defer.returnValue(
@@ -376,20 +371,22 @@ class ReplicationLayer(object):
@defer.inlineCallbacks
@log_function
def on_pull_request(self, origin, versions):
- transaction_id = max([int(v) for v in versions])
-
- response = yield self.pdu_actions.after_transaction(
- transaction_id,
- origin,
- self.server_name
- )
-
- if not response:
- response = []
-
- defer.returnValue(
- (200, self._transaction_from_pdus(response).get_dict())
- )
+ raise NotImplementedError("Pull transacions not implemented")
+
+ # transaction_id = max([int(v) for v in versions])
+ #
+ # response = yield self.pdu_actions.after_transaction(
+ # transaction_id,
+ # origin,
+ # self.server_name
+ # )
+ #
+ # if not response:
+ # response = []
+ #
+ # defer.returnValue(
+ # (200, self._transaction_from_pdus(response).get_dict())
+ # )
@defer.inlineCallbacks
def on_query_request(self, query_type, args):
@@ -397,21 +394,63 @@ class ReplicationLayer(object):
response = yield self.query_handlers[query_type](args)
defer.returnValue((200, response))
else:
- defer.returnValue((404, "No handler for Query type '%s'"
- % (query_type)
- ))
+ defer.returnValue(
+ (404, "No handler for Query type '%s'" % (query_type, ))
+ )
@defer.inlineCallbacks
+ def on_make_join_request(self, context, user_id):
+ pdu = yield self.handler.on_make_join_request(context, user_id)
+ defer.returnValue(pdu.get_dict())
+
+ @defer.inlineCallbacks
+ def on_invite_request(self, origin, content):
+ pdu = Pdu(**content)
+ ret_pdu = yield self.handler.on_send_join_request(origin, pdu)
+ defer.returnValue((200, ret_pdu.get_dict()))
+
+ @defer.inlineCallbacks
+ def on_send_join_request(self, origin, content):
+ pdu = Pdu(**content)
+ state = yield self.handler.on_send_join_request(origin, pdu)
+ defer.returnValue((200, self._transaction_from_pdus(state).get_dict()))
+
+ @defer.inlineCallbacks
+ def make_join(self, destination, context, user_id):
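+        # Ask the remote server to build a join event template for user_id,
+        # which we can then sign and send back via send_join.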
+ pdu_dict = yield self.transport_layer.make_join(
+ destination=destination,
+ context=context,
+ user_id=user_id,
+ )
+
+ logger.debug("Got response to make_join: %s", pdu_dict)
+
+ defer.returnValue(Pdu(**pdu_dict))
+
+ @defer.inlineCallbacks
+ def send_join(self, destination, pdu):
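+        # Send our join event to the remote server; the response contains
+        # the room state as a list of PDUs.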
+ _, content = yield self.transport_layer.send_join(
+ destination,
+ pdu.room_id,
+ pdu.event_id,
+ pdu.get_dict(),
+ )
+
+ logger.debug("Got content: %s", content)
+ pdus = [Pdu(outlier=True, **p) for p in content.get("pdus", [])]
+ for pdu in pdus:
+ yield self._handle_new_pdu(destination, pdu)
+
+ defer.returnValue(pdus)
+
@log_function
- def _get_persisted_pdu(self, pdu_id, pdu_origin):
+ def _get_persisted_pdu(self, event_id):
""" Get a PDU from the database with given origin and id.
Returns:
Deferred: Results in a `Pdu`.
"""
- pdu_tuple = yield self.store.get_pdu(pdu_id, pdu_origin)
-
- defer.returnValue(Pdu.from_pdu_tuple(pdu_tuple))
+ return self.handler.get_persisted_pdu(event_id)
def _transaction_from_pdus(self, pdu_list):
"""Returns a new Transaction containing the given PDUs suitable for
@@ -433,48 +472,60 @@ class ReplicationLayer(object):
@defer.inlineCallbacks
@log_function
- def _handle_new_pdu(self, pdu, backfilled=False):
+ def _handle_new_pdu(self, origin, pdu, backfilled=False):
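+        # origin is the server we received this PDU from, and is who we
+        # ask for state if we hit a backward extremity.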
# We reprocess pdus when we have seen them only as outliers
- existing = yield self._get_persisted_pdu(pdu.pdu_id, pdu.origin)
+ existing = yield self._get_persisted_pdu(pdu.event_id)
if existing and (not existing.outlier or pdu.outlier):
- logger.debug("Already seen pdu %s %s", pdu.pdu_id, pdu.origin)
+ logger.debug("Already seen pdu %s", pdu.event_id)
defer.returnValue({})
return
+ state = None
+
# Get missing pdus if necessary.
- is_new = yield self.pdu_actions.is_new(pdu)
- if is_new and not pdu.outlier:
+ if not pdu.outlier:
# We only backfill backwards to the min depth.
- min_depth = yield self.store.get_min_depth_for_context(pdu.context)
+ min_depth = yield self.handler.get_min_depth_for_context(
+ pdu.room_id
+ )
if min_depth and pdu.depth > min_depth:
- for pdu_id, origin, hashes in pdu.prev_pdus:
- exists = yield self._get_persisted_pdu(pdu_id, origin)
+ for event_id, hashes in pdu.prev_events:
+ exists = yield self._get_persisted_pdu(event_id)
if not exists:
- logger.debug("Requesting pdu %s %s", pdu_id, origin)
+ logger.debug("Requesting pdu %s", event_id)
try:
yield self.get_pdu(
pdu.origin,
- pdu_id=pdu_id,
- pdu_origin=origin
+ event_id=event_id,
)
- logger.debug("Processed pdu %s %s", pdu_id, origin)
+ logger.debug("Processed pdu %s", event_id)
except:
# TODO(erikj): Do some more intelligent retries.
logger.exception("Failed to get PDU")
+ else:
+ # We need to get the state at this event, since we have reached
+ # a backward extremity edge.
+ state = yield self.get_state_for_context(
+ origin, pdu.room_id, pdu.event_id,
+ )
# Persist the Pdu, but don't mark it as processed yet.
- yield self.store.persist_event(pdu=pdu)
+ # yield self.store.persist_event(pdu=pdu)
if not backfilled:
- ret = yield self.handler.on_receive_pdu(pdu, backfilled=backfilled)
+ ret = yield self.handler.on_receive_pdu(
+ pdu,
+ backfilled=backfilled,
+ state=state,
+ )
else:
ret = None
- yield self.pdu_actions.mark_as_processed(pdu)
+ # yield self.pdu_actions.mark_as_processed(pdu)
defer.returnValue(ret)
diff --git a/synapse/federation/transport.py b/synapse/federation/transport.py
index e7517cac4d..04ad7e63ae 100644
--- a/synapse/federation/transport.py
+++ b/synapse/federation/transport.py
@@ -72,7 +72,7 @@ class TransportLayer(object):
self.received_handler = None
@log_function
- def get_context_state(self, destination, context):
+ def get_context_state(self, destination, context, event_id=None):
""" Requests all state for a given context (i.e. room) from the
given server.
@@ -89,54 +89,62 @@ class TransportLayer(object):
subpath = "/state/%s/" % context
- return self._do_request_for_transaction(destination, subpath)
+ args = {}
+ if event_id:
+ args["event_id"] = event_id
+
+ return self._do_request_for_transaction(
+ destination, subpath, args=args
+ )
@log_function
- def get_pdu(self, destination, pdu_origin, pdu_id):
+ def get_event(self, destination, event_id):
""" Requests the pdu with give id and origin from the given server.
Args:
destination (str): The host name of the remote home server we want
to get the state from.
- pdu_origin (str): The home server which created the PDU.
- pdu_id (str): The id of the PDU being requested.
+ event_id (str): The id of the event being requested.
Returns:
Deferred: Results in a dict received from the remote homeserver.
"""
- logger.debug("get_pdu dest=%s, pdu_origin=%s, pdu_id=%s",
- destination, pdu_origin, pdu_id)
+ logger.debug("get_pdu dest=%s, event_id=%s",
+ destination, event_id)
- subpath = "/pdu/%s/%s/" % (pdu_origin, pdu_id)
+ subpath = "/event/%s/" % (event_id, )
return self._do_request_for_transaction(destination, subpath)
@log_function
- def backfill(self, dest, context, pdu_tuples, limit):
+ def backfill(self, dest, context, event_tuples, limit):
""" Requests `limit` previous PDUs in a given context before list of
PDUs.
Args:
dest (str)
context (str)
- pdu_tuples (list)
+ event_tuples (list)
-            limt (int)
+            limit (int)
Returns:
Deferred: Results in a dict received from the remote homeserver.
"""
logger.debug(
- "backfill dest=%s, context=%s, pdu_tuples=%s, limit=%s",
- dest, context, repr(pdu_tuples), str(limit)
+ "backfill dest=%s, context=%s, event_tuples=%s, limit=%s",
+ dest, context, repr(event_tuples), str(limit)
)
- if not pdu_tuples:
+ if not event_tuples:
+ # TODO: raise?
return
- subpath = "/backfill/%s/" % context
+ subpath = "/backfill/%s/" % (context,)
- args = {"v": ["%s,%s" % (i, o) for i, o in pdu_tuples]}
- args["limit"] = limit
+ args = {
+ "v": event_tuples,
+ "limit": limit,
+ }
return self._do_request_for_transaction(
dest,
@@ -198,6 +206,57 @@ class TransportLayer(object):
defer.returnValue(response)
@defer.inlineCallbacks
+ @log_function
+ def make_join(self, destination, context, user_id, retry_on_dns_fail=True):
+ path = PREFIX + "/make_join/%s/%s" % (context, user_id,)
+
+ response = yield self.client.get_json(
+ destination=destination,
+ path=path,
+ retry_on_dns_fail=retry_on_dns_fail,
+ )
+
+ defer.returnValue(response)
+
+ @defer.inlineCallbacks
+ @log_function
+ def send_join(self, destination, context, event_id, content):
+ path = PREFIX + "/send_join/%s/%s" % (
+ context,
+ event_id,
+ )
+
+ code, content = yield self.client.put_json(
+ destination=destination,
+ path=path,
+ data=content,
+ )
+
+ if not 200 <= code < 300:
+ raise RuntimeError("Got %d from send_join", code)
+
+ defer.returnValue(json.loads(content))
+
+ @defer.inlineCallbacks
+ @log_function
+ def send_invite(self, destination, context, event_id, content):
+ path = PREFIX + "/invite/%s/%s" % (
+ context,
+ event_id,
+ )
+
+ code, content = yield self.client.put_json(
+ destination=destination,
+ path=path,
+ data=content,
+ )
+
+ if not 200 <= code < 300:
+ raise RuntimeError("Got %d from send_invite", code)
+
+ defer.returnValue(json.loads(content))
+
+ @defer.inlineCallbacks
def _authenticate_request(self, request):
json_request = {
"method": request.method,
@@ -313,10 +372,10 @@ class TransportLayer(object):
# data_id pair.
self.server.register_path(
"GET",
- re.compile("^" + PREFIX + "/pdu/([^/]*)/([^/]*)/$"),
+ re.compile("^" + PREFIX + "/event/([^/]*)/$"),
self._with_authentication(
- lambda origin, content, query, pdu_origin, pdu_id:
- handler.on_pdu_request(pdu_origin, pdu_id)
+ lambda origin, content, query, event_id:
+ handler.on_pdu_request(event_id)
)
)
@@ -326,7 +385,10 @@ class TransportLayer(object):
re.compile("^" + PREFIX + "/state/([^/]*)/$"),
self._with_authentication(
lambda origin, content, query, context:
- handler.on_context_state_request(context)
+ handler.on_context_state_request(
+ context,
+ query.get("event_id", [None])[0],
+ )
)
)
@@ -362,6 +424,39 @@ class TransportLayer(object):
)
)
+ self.server.register_path(
+ "GET",
+ re.compile("^" + PREFIX + "/make_join/([^/]*)/([^/]*)$"),
+ self._with_authentication(
+ lambda origin, content, query, context, user_id:
+ self._on_make_join_request(
+ origin, content, query, context, user_id
+ )
+ )
+ )
+
+ self.server.register_path(
+ "PUT",
+ re.compile("^" + PREFIX + "/send_join/([^/]*)/([^/]*)$"),
+ self._with_authentication(
+ lambda origin, content, query, context, event_id:
+ self._on_send_join_request(
+ origin, content, query,
+ )
+ )
+ )
+
+ self.server.register_path(
+ "PUT",
+ re.compile("^" + PREFIX + "/invite/([^/]*)/([^/]*)$"),
+ self._with_authentication(
+ lambda origin, content, query, context, event_id:
+ self._on_invite_request(
+ origin, content, query,
+ )
+ )
+ )
+
@defer.inlineCallbacks
@log_function
def _on_send_request(self, origin, content, query, transaction_id):
@@ -448,124 +543,34 @@ class TransportLayer(object):
limit = int(limits[-1])
- versions = [v.split(",", 1) for v in v_list]
+ versions = v_list
return self.request_handler.on_backfill_request(
- context, versions, limit)
-
-
-class TransportReceivedHandler(object):
- """ Callbacks used when we receive a transaction
- """
- def on_incoming_transaction(self, transaction):
- """ Called on PUT /send/<transaction_id>, or on response to a request
- that we sent (e.g. a backfill request)
-
- Args:
- transaction (synapse.transaction.Transaction): The transaction that
- was sent to us.
-
- Returns:
- twisted.internet.defer.Deferred: A deferred that gets fired when
- the transaction has finished being processed.
-
- The result should be a tuple in the form of
- `(response_code, respond_body)`, where `response_body` is a python
- dict that will get serialized to JSON.
-
- On errors, the dict should have an `error` key with a brief message
- of what went wrong.
- """
- pass
-
-
-class TransportRequestHandler(object):
- """ Handlers used when someone want's data from us
- """
- def on_pull_request(self, versions):
- """ Called on GET /pull/?v=...
-
- This is hit when a remote home server wants to get all data
- after a given transaction. Mainly used when a home server comes back
- online and wants to get everything it has missed.
-
- Args:
- versions (list): A list of transaction_ids that should be used to
- determine what PDUs the remote side have not yet seen.
-
- Returns:
- Deferred: Resultsin a tuple in the form of
- `(response_code, respond_body)`, where `response_body` is a python
- dict that will get serialized to JSON.
-
- On errors, the dict should have an `error` key with a brief message
- of what went wrong.
- """
- pass
-
- def on_pdu_request(self, pdu_origin, pdu_id):
- """ Called on GET /pdu/<pdu_origin>/<pdu_id>/
-
- Someone wants a particular PDU. This PDU may or may not have originated
- from us.
-
- Args:
- pdu_origin (str)
- pdu_id (str)
-
- Returns:
- Deferred: Resultsin a tuple in the form of
- `(response_code, respond_body)`, where `response_body` is a python
- dict that will get serialized to JSON.
-
- On errors, the dict should have an `error` key with a brief message
- of what went wrong.
- """
- pass
-
- def on_context_state_request(self, context):
- """ Called on GET /state/<context>/
-
- Gets hit when someone wants all the *current* state for a given
- contexts.
-
- Args:
- context (str): The name of the context that we're interested in.
-
- Returns:
- twisted.internet.defer.Deferred: A deferred that gets fired when
- the transaction has finished being processed.
-
- The result should be a tuple in the form of
- `(response_code, respond_body)`, where `response_body` is a python
- dict that will get serialized to JSON.
-
- On errors, the dict should have an `error` key with a brief message
- of what went wrong.
- """
- pass
-
- def on_backfill_request(self, context, versions, limit):
- """ Called on GET /backfill/<context>/?v=...&limit=...
+ context, versions, limit
+ )
- Gets hit when we want to backfill backwards on a given context from
- the given point.
+ @defer.inlineCallbacks
+ @log_function
+ def _on_make_join_request(self, origin, content, query, context, user_id):
+ content = yield self.request_handler.on_make_join_request(
+ context, user_id,
+ )
+ defer.returnValue((200, content))
- Args:
- context (str): The context to backfill
- versions (list): A list of 2-tuples representing where to backfill
- from, in the form `(pdu_id, origin)`
- limit (int): How many pdus to return.
+ @defer.inlineCallbacks
+ @log_function
+ def _on_send_join_request(self, origin, content, query):
+ content = yield self.request_handler.on_send_join_request(
+ origin, content,
+ )
- Returns:
- Deferred: Results in a tuple in the form of
- `(response_code, respond_body)`, where `response_body` is a python
- dict that will get serialized to JSON.
+ defer.returnValue((200, content))
- On errors, the dict should have an `error` key with a brief message
- of what went wrong.
- """
- pass
+ @defer.inlineCallbacks
+ @log_function
+ def _on_invite_request(self, origin, content, query):
+ content = yield self.request_handler.on_invite_request(
+ origin, content,
+ )
- def on_query_request(self):
- """ Called on a GET /query/<query_type> request. """
+ defer.returnValue((200, content))
diff --git a/synapse/federation/units.py b/synapse/federation/units.py
index adc3385644..c94dcf64cf 100644
--- a/synapse/federation/units.py
+++ b/synapse/federation/units.py
@@ -34,13 +34,13 @@ class Pdu(JsonEncodedObject):
A Pdu can be classified as "state". For a given context, we can efficiently
-    retrieve all state pdu's that haven't been clobbered. Clobbering is done
+    retrieve all state PDUs that haven't been clobbered. Clobbering is done
- via a unique constraint on the tuple (context, pdu_type, state_key). A pdu
+ via a unique constraint on the tuple (context, type, state_key). A pdu
is a state pdu if `is_state` is True.
Example pdu::
{
- "pdu_id": "78c",
+ "event_id": "$78c:example.com",
"origin_server_ts": 1404835423000,
"origin": "bar",
"prev_ids": [
@@ -53,14 +53,14 @@ class Pdu(JsonEncodedObject):
"""
valid_keys = [
- "pdu_id",
- "context",
+ "event_id",
+ "room_id",
"origin",
"origin_server_ts",
- "pdu_type",
+ "type",
"destinations",
"transaction_id",
- "prev_pdus",
+ "prev_events",
"depth",
"content",
"outlier",
@@ -68,8 +68,7 @@ class Pdu(JsonEncodedObject):
"signatures",
"is_state", # Below this are keys valid only for State Pdus.
"state_key",
- "prev_state_id",
- "prev_state_origin",
+ "prev_state",
"required_power_level",
"user_id",
]
@@ -81,18 +80,18 @@ class Pdu(JsonEncodedObject):
]
required_keys = [
- "pdu_id",
- "context",
+ "event_id",
+ "room_id",
"origin",
"origin_server_ts",
- "pdu_type",
+ "type",
"content",
]
# TODO: We need to make this properly load content rather than
# just leaving it as a dict. (OR DO WE?!)
- def __init__(self, destinations=[], is_state=False, prev_pdus=[],
+ def __init__(self, destinations=[], is_state=False, prev_events=[],
outlier=False, hashes={}, signatures={}, **kwargs):
if is_state:
for required_key in ["state_key"]:
@@ -102,66 +101,13 @@ class Pdu(JsonEncodedObject):
super(Pdu, self).__init__(
destinations=destinations,
is_state=bool(is_state),
- prev_pdus=prev_pdus,
+ prev_events=prev_events,
outlier=outlier,
hashes=hashes,
signatures=signatures,
**kwargs
)
- @classmethod
- def from_pdu_tuple(cls, pdu_tuple):
- """ Converts a PduTuple to a Pdu
-
- Args:
- pdu_tuple (synapse.persistence.transactions.PduTuple): The tuple to
- convert
-
- Returns:
- Pdu
- """
- if pdu_tuple:
- d = copy.copy(pdu_tuple.pdu_entry._asdict())
- d["origin_server_ts"] = d.pop("ts")
-
- for k in d.keys():
- if d[k] is None:
- del d[k]
-
- d["content"] = json.loads(d["content_json"])
- del d["content_json"]
-
- args = {f: d[f] for f in cls.valid_keys if f in d}
- if "unrecognized_keys" in d and d["unrecognized_keys"]:
- args.update(json.loads(d["unrecognized_keys"]))
-
- hashes = {
- alg: encode_base64(hsh)
- for alg, hsh in pdu_tuple.hashes.items()
- }
-
- signatures = {
- kid: encode_base64(sig)
- for kid, sig in pdu_tuple.signatures.items()
- }
-
- prev_pdus = []
- for prev_pdu in pdu_tuple.prev_pdu_list:
- prev_hashes = pdu_tuple.edge_hashes.get(prev_pdu, {})
- prev_hashes = {
- alg: encode_base64(hsh) for alg, hsh in prev_hashes.items()
- }
- prev_pdus.append((prev_pdu[0], prev_pdu[1], prev_hashes))
-
- return Pdu(
- prev_pdus=prev_pdus,
- hashes=hashes,
- signatures=signatures,
- **args
- )
- else:
- return None
-
def __str__(self):
return "(%s, %s)" % (self.__class__.__name__, repr(self.__dict__))
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index de4d23bbb3..28b64565ae 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -16,6 +16,10 @@
from twisted.internet import defer
from synapse.api.errors import LimitExceededError
+from synapse.util.async import run_on_reactor
+
+from synapse.crypto.event_signing import add_hashes_and_signatures
+
class BaseHandler(object):
def __init__(self, hs):
@@ -30,6 +34,9 @@ class BaseHandler(object):
self.clock = hs.get_clock()
self.hs = hs
+ self.signing_key = hs.config.signing_key[0]
+ self.server_name = hs.hostname
+
def ratelimit(self, user_id):
time_now = self.clock.time()
allowed, time_allowed = self.ratelimiter.send_message(
@@ -44,9 +51,23 @@ class BaseHandler(object):
@defer.inlineCallbacks
def _on_new_room_event(self, event, snapshot, extra_destinations=[],
- extra_users=[]):
+ extra_users=[], suppress_auth=False):
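+        # suppress_auth lets internally generated events (e.g. directory
+        # updates) skip the auth check.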
+ yield run_on_reactor()
+
snapshot.fill_out_prev_events(event)
+ yield self.state_handler.annotate_state_groups(event)
+
+ yield add_hashes_and_signatures(
+ event, self.server_name, self.signing_key
+ )
+
+ if not suppress_auth:
+ yield self.auth.check(event, snapshot, raises=True)
+
+ if hasattr(event, "state_key"):
+ yield self.state_handler.handle_new_event(event, snapshot)
+
yield self.store.persist_event(event)
destinations = set(extra_destinations)
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index a56830d520..6e897e915d 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -152,5 +152,6 @@ class DirectoryHandler(BaseHandler):
user_id=user_id,
)
- yield self.state_handler.handle_new_event(event, snapshot)
- yield self._on_new_room_event(event, snapshot, extra_users=[user_id])
+ yield self._on_new_room_event(
+ event, snapshot, extra_users=[user_id], suppress_auth=True
+ )
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index f52591d2a3..bdd28f04bb 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -22,6 +22,8 @@ from synapse.api.constants import Membership
from synapse.util.logutils import log_function
from synapse.federation.pdu_codec import PduCodec
from synapse.api.errors import SynapseError
+from synapse.util.async import run_on_reactor
+from synapse.types import EventID
from twisted.internet import defer, reactor
@@ -62,6 +64,9 @@ class FederationHandler(BaseHandler):
self.pdu_codec = PduCodec(hs)
+ # When joining a room we need to queue any events for that room up
+ self.room_queues = {}
+
@log_function
@defer.inlineCallbacks
def handle_new_event(self, event, snapshot):
@@ -78,6 +83,8 @@ class FederationHandler(BaseHandler):
processing.
"""
+ yield run_on_reactor()
+
pdu = self.pdu_codec.pdu_from_event(event)
if not hasattr(pdu, "destinations") or not pdu.destinations:
@@ -87,98 +94,83 @@ class FederationHandler(BaseHandler):
@log_function
@defer.inlineCallbacks
- def on_receive_pdu(self, pdu, backfilled):
+ def on_receive_pdu(self, pdu, backfilled, state=None):
""" Called by the ReplicationLayer when we have a new pdu. We need to
- do auth checks and put it throught the StateHandler.
+ do auth checks and put it through the StateHandler.
"""
event = self.pdu_codec.event_from_pdu(pdu)
logger.debug("Got event: %s", event.event_id)
- with (yield self.lock_manager.lock(pdu.context)):
- if event.is_state and not backfilled:
- is_new_state = yield self.state_handler.handle_new_state(
- pdu
- )
- else:
- is_new_state = False
+ if event.room_id in self.room_queues:
+ self.room_queues[event.room_id].append(pdu)
+ return
+
+ logger.debug("Processing event: %s", event.event_id)
+
+ if state:
+ state = [self.pdu_codec.event_from_pdu(p) for p in state]
+
+ is_new_state = yield self.state_handler.annotate_state_groups(
+ event,
+ old_state=state
+ )
+
+ logger.debug("Event: %s", event)
+
+ if not backfilled:
+ yield self.auth.check(event, None, raises=True)
+
+ is_new_state = is_new_state and not backfilled
+
# TODO: Implement something in federation that allows us to
# respond to PDU.
- target_is_mine = False
- if hasattr(event, "target_host"):
- target_is_mine = event.target_host == self.hs.hostname
-
- if event.type == InviteJoinEvent.TYPE:
- if not target_is_mine:
- logger.debug("Ignoring invite/join event %s", event)
- return
-
- # If we receive an invite/join event then we need to join the
- # sender to the given room.
- # TODO: We should probably auth this or some such
- content = event.content
- content.update({"membership": Membership.JOIN})
- new_event = self.event_factory.create_event(
- etype=RoomMemberEvent.TYPE,
- state_key=event.user_id,
- room_id=event.room_id,
- user_id=event.user_id,
- membership=Membership.JOIN,
- content=content
+ with (yield self.room_lock.lock(event.room_id)):
+ yield self.store.persist_event(
+ event,
+ backfilled,
+ is_new_state=is_new_state
)
- yield self.hs.get_handlers().room_member_handler.change_membership(
- new_event,
- do_auth=False,
- )
+ room = yield self.store.get_room(event.room_id)
- else:
- with (yield self.room_lock.lock(event.room_id)):
- yield self.store.persist_event(
- event,
- backfilled,
- is_new_state=is_new_state
+ if not room:
+ # Huh, let's try and get the current state
+ try:
+ yield self.replication_layer.get_state_for_context(
+ event.origin, event.room_id, event.event_id,
)
- room = yield self.store.get_room(event.room_id)
-
- if not room:
- # Huh, let's try and get the current state
- try:
- yield self.replication_layer.get_state_for_context(
- event.origin, event.room_id
- )
-
- hosts = yield self.store.get_joined_hosts_for_room(
- event.room_id
- )
- if self.hs.hostname in hosts:
- try:
- yield self.store.store_room(
- room_id=event.room_id,
- room_creator_user_id="",
- is_public=False,
- )
- except:
- pass
- except:
- logger.exception(
- "Failed to get current state for room %s",
- event.room_id
- )
-
- if not backfilled:
- extra_users = []
- if event.type == RoomMemberEvent.TYPE:
- target_user_id = event.state_key
- target_user = self.hs.parse_userid(target_user_id)
- extra_users.append(target_user)
-
- yield self.notifier.on_new_room_event(
- event, extra_users=extra_users
+ hosts = yield self.store.get_joined_hosts_for_room(
+ event.room_id
+ )
+ if self.hs.hostname in hosts:
+ try:
+ yield self.store.store_room(
+ room_id=event.room_id,
+ room_creator_user_id="",
+ is_public=False,
+ )
+ except:
+ pass
+ except:
+ logger.exception(
+ "Failed to get current state for room %s",
+ event.room_id
)
+ if not backfilled:
+ extra_users = []
+ if event.type == RoomMemberEvent.TYPE:
+ target_user_id = event.state_key
+ target_user = self.hs.parse_userid(target_user_id)
+ extra_users.append(target_user)
+
+ yield self.notifier.on_new_room_event(
+ event, extra_users=extra_users
+ )
+
if event.type == RoomMemberEvent.TYPE:
if event.membership == Membership.JOIN:
user = self.hs.parse_userid(event.state_key)
@@ -189,13 +181,28 @@ class FederationHandler(BaseHandler):
@log_function
@defer.inlineCallbacks
def backfill(self, dest, room_id, limit):
- pdus = yield self.replication_layer.backfill(dest, room_id, limit)
+ extremities = yield self.store.get_oldest_events_in_room(room_id)
+
+ pdus = yield self.replication_layer.backfill(
+ dest,
+ room_id,
+ limit,
+ extremities=[
+ self.pdu_codec.decode_event_id(e)
+ for e in extremities
+ ]
+ )
events = []
for pdu in pdus:
event = self.pdu_codec.event_from_pdu(pdu)
+
+ # FIXME (erikj): Not sure this actually works :/
+ yield self.state_handler.annotate_state_groups(event)
+
events.append(event)
+
yield self.store.persist_event(event, backfilled=True)
defer.returnValue(events)
@@ -203,62 +210,230 @@ class FederationHandler(BaseHandler):
@log_function
@defer.inlineCallbacks
def do_invite_join(self, target_host, room_id, joinee, content, snapshot):
-
hosts = yield self.store.get_joined_hosts_for_room(room_id)
if self.hs.hostname in hosts:
# We are already in the room.
logger.debug("We're already in the room apparently")
defer.returnValue(False)
- # First get current state to see if we are already joined.
+ pdu = yield self.replication_layer.make_join(
+ target_host,
+ room_id,
+ joinee
+ )
+
+ logger.debug("Got response to make_join: %s", pdu)
+
+ event = self.pdu_codec.event_from_pdu(pdu)
+
+ # Sanity-check the event the remote server handed back.
+ assert event.type == RoomMemberEvent.TYPE
+ assert event.user_id == joinee
+ assert event.state_key == joinee
+ assert event.room_id == room_id
+
+ event.outlier = False
+
+ self.room_queues[room_id] = []
+
try:
- yield self.replication_layer.get_state_for_context(
- target_host, room_id
+ event.event_id = self.event_factory.create_event_id()
+ event.content = content
+
+ state = yield self.replication_layer.send_join(
+ target_host,
+ self.pdu_codec.pdu_from_event(event)
)
- hosts = yield self.store.get_joined_hosts_for_room(room_id)
- if self.hs.hostname in hosts:
- # Oh, we were actually in the room already.
- logger.debug("We're already in the room apparently")
- defer.returnValue(False)
- except Exception:
- logger.exception("Failed to get current state")
-
- new_event = self.event_factory.create_event(
- etype=InviteJoinEvent.TYPE,
- target_host=target_host,
- room_id=room_id,
- user_id=joinee,
- content=content
- )
+ state = [self.pdu_codec.event_from_pdu(p) for p in state]
- new_event.destinations = [target_host]
+ logger.debug("do_invite_join state: %s", state)
- snapshot.fill_out_prev_events(new_event)
- yield self.handle_new_event(new_event, snapshot)
+ is_new_state = yield self.state_handler.annotate_state_groups(
+ event,
+ old_state=state
+ )
- # TODO (erikj): Time out here.
- d = defer.Deferred()
- self.waiting_for_join_list.setdefault((joinee, room_id), []).append(d)
- reactor.callLater(10, d.cancel)
+ logger.debug("do_invite_join event: %s", event)
- try:
- yield d
- except defer.CancelledError:
- raise SynapseError(500, "Unable to join remote room")
+ try:
+ yield self.store.store_room(
+ room_id=room_id,
+ room_creator_user_id="",
+ is_public=False
+ )
+ except:
+ # FIXME
+ pass
- try:
- yield self.store.store_room(
- room_id=room_id,
- room_creator_user_id="",
- is_public=False
+ for e in state:
+ # FIXME: Auth these.
+ e.outlier = True
+
+ yield self.state_handler.annotate_state_groups(
+ e,
+ )
+
+ yield self.store.persist_event(
+ e,
+ backfilled=False,
+ is_new_state=False
+ )
+
+ yield self.store.persist_event(
+ event,
+ backfilled=False,
+ is_new_state=is_new_state
)
- except:
- pass
+ finally:
+ room_queue = self.room_queues[room_id]
+ del self.room_queues[room_id]
+ for p in room_queue:
+ try:
+ yield self.on_receive_pdu(p, backfilled=False)
+ except:
+ logger.exception("Failed to replay queued PDU")
defer.returnValue(True)
+ @defer.inlineCallbacks
+ @log_function
+ def on_make_join_request(self, context, user_id):
+ event = self.event_factory.create_event(
+ etype=RoomMemberEvent.TYPE,
+ content={"membership": Membership.JOIN},
+ room_id=context,
+ user_id=user_id,
+ state_key=user_id,
+ )
+
+ snapshot = yield self.store.snapshot_room(
+ event.room_id, event.user_id,
+ )
+ snapshot.fill_out_prev_events(event)
+
+ yield self.state_handler.annotate_state_groups(event)
+ yield self.auth.check(event, None, raises=True)
+
+ pdu = self.pdu_codec.pdu_from_event(event)
+
+ defer.returnValue(pdu)
+
+ @defer.inlineCallbacks
+ @log_function
+ def on_send_join_request(self, origin, pdu):
+ event = self.pdu_codec.event_from_pdu(pdu)
+
+ event.outlier = False
+
+ is_new_state = yield self.state_handler.annotate_state_groups(event)
+ yield self.auth.check(event, None, raises=True)
+
+ # FIXME (erikj): All this is duplicated above :(
+
+ yield self.store.persist_event(
+ event,
+ backfilled=False,
+ is_new_state=is_new_state
+ )
+
+ extra_users = []
+ if event.type == RoomMemberEvent.TYPE:
+ target_user_id = event.state_key
+ target_user = self.hs.parse_userid(target_user_id)
+ extra_users.append(target_user)
+
+ yield self.notifier.on_new_room_event(
+ event, extra_users=extra_users
+ )
+
+ if event.type == RoomMemberEvent.TYPE:
+ if event.membership == Membership.JOIN:
+ user = self.hs.parse_userid(event.state_key)
+ self.distributor.fire(
+ "user_joined_room", user=user, room_id=event.room_id
+ )
+
+ new_pdu = self.pdu_codec.pdu_from_event(event)
+ new_pdu.destinations = yield self.store.get_joined_hosts_for_room(
+ event.room_id
+ )
+
+ yield self.replication_layer.send_pdu(new_pdu)
+
+ defer.returnValue([
+ self.pdu_codec.pdu_from_event(e)
+ for e in event.state_events.values()
+ ])
+
+ @defer.inlineCallbacks
+ def get_state_for_pdu(self, event_id):
+ yield run_on_reactor()
+
+ state_groups = yield self.store.get_state_groups(
+ [event_id]
+ )
+
+ if state_groups:
+ results = {
+ (e.type, e.state_key): e for e in state_groups[0].state
+ }
+
+ event = yield self.store.get_event(event_id)
+ if hasattr(event, "state_key"):
+ # Get previous state
+ if hasattr(event, "prev_state") and event.prev_state:
+ prev_event = yield self.store.get_event(event.prev_state)
+ results[(event.type, event.state_key)] = prev_event
+ else:
+ del results[(event.type, event.state_key)]
+
+ defer.returnValue(
+ [
+ self.pdu_codec.pdu_from_event(s)
+ for s in results.values()
+ ]
+ )
+ else:
+ defer.returnValue([])
+
+ @defer.inlineCallbacks
+ @log_function
+ def on_backfill_request(self, context, pdu_list, limit):
+
+ events = yield self.store.get_backfill_events(
+ context,
+ pdu_list,
+ limit
+ )
+
+ defer.returnValue([
+ self.pdu_codec.pdu_from_event(e)
+ for e in events
+ ])
+
+ @defer.inlineCallbacks
+ @log_function
+ def get_persisted_pdu(self, event_id):
+ """ Get a PDU from the database with given origin and id.
+
+ Returns:
+ Deferred: Results in a `Pdu`.
+ """
+ event = yield self.store.get_event(
+ event_id,
+ allow_none=True,
+ )
+
+ if event:
+ defer.returnValue(self.pdu_codec.pdu_from_event(event))
+ else:
+ defer.returnValue(None)
+
+ @log_function
+ def get_min_depth_for_context(self, context):
+ return self.store.get_min_depth(context)
@log_function
def _on_user_joined(self, user, room_id):
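
Taken together, `do_invite_join`, `on_make_join_request` and `on_send_join_request` replace the old fire-and-forget `InviteJoinEvent` with a two-step handshake: the joining server asks the resident server to template a membership event (`make_join`), signs and returns it (`send_join`), and receives the room state in response; PDUs that arrive mid-join are parked in `room_queues` and replayed afterwards. A condensed sketch of the joining side, where `handler` stands in for the `FederationHandler`:

    from twisted.internet import defer


    @defer.inlineCallbacks
    def join_dance(handler, target_host, room_id, joinee, content):
        # 1. Ask the resident server to template a membership event for us.
        pdu = yield handler.replication_layer.make_join(
            target_host, room_id, joinee
        )
        event = handler.pdu_codec.event_from_pdu(pdu)
        event.content = content

        # 2. Park PDUs for this room that arrive while the join is in
        #    flight; they are replayed through on_receive_pdu() below.
        handler.room_queues[room_id] = []
        try:
            # 3. Return the signed join; the response is the room state,
            #    which the real handler auths and persists locally.
            state = yield handler.replication_layer.send_join(
                target_host, handler.pdu_codec.pdu_from_event(event)
            )
        finally:
            for queued in handler.room_queues.pop(room_id):
                yield handler.on_receive_pdu(queued, backfilled=False)
        defer.returnValue(state)
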
diff --git a/synapse/handlers/login.py b/synapse/handlers/login.py
index 3f152e18f0..99d15261d4 100644
--- a/synapse/handlers/login.py
+++ b/synapse/handlers/login.py
@@ -54,7 +54,7 @@ class LoginHandler(BaseHandler):
# pull out the hash for this user if they exist
user_info = yield self.store.get_user_by_id(user_id=user)
if not user_info:
- logger.warn("Attempted to login as %s but they do not exist.", user)
+ logger.warn("Attempted to login as %s but they do not exist", user)
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
stored_hash = user_info[0]["password_hash"]
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 7b2b8549ed..c6f6ab14d1 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -83,10 +83,9 @@ class MessageHandler(BaseHandler):
snapshot = yield self.store.snapshot_room(event.room_id, event.user_id)
- if not suppress_auth:
- yield self.auth.check(event, snapshot, raises=True)
-
- yield self._on_new_room_event(event, snapshot)
+ yield self._on_new_room_event(
+ event, snapshot, suppress_auth=suppress_auth
+ )
self.hs.get_handlers().presence_handler.bump_presence_active_time(
user
@@ -115,8 +114,12 @@ class MessageHandler(BaseHandler):
user = self.hs.parse_userid(user_id)
- events, next_token = yield data_source.get_pagination_rows(
- user, pagin_config, room_id
+ events, next_key = yield data_source.get_pagination_rows(
+ user, pagin_config.get_source_config("room"), room_id
+ )
+
+ next_token = pagin_config.from_token.copy_and_replace(
+ "room_key", next_key
)
chunk = {
@@ -145,10 +148,6 @@ class MessageHandler(BaseHandler):
state_key=event.state_key,
)
- yield self.auth.check(event, snapshot, raises=True)
-
- yield self.state_handler.handle_new_event(event, snapshot)
-
yield self._on_new_room_event(event, snapshot)
@defer.inlineCallbacks
@@ -197,7 +196,7 @@ class MessageHandler(BaseHandler):
raise RoomError(
403, "Member does not meet private room rules.")
- data = yield self.store.get_current_state(
+ data = yield self.state_handler.get_current_state(
room_id, event_type, state_key
)
defer.returnValue(data)
@@ -217,8 +216,6 @@ class MessageHandler(BaseHandler):
def send_feedback(self, event):
snapshot = yield self.store.snapshot_room(event.room_id, event.user_id)
- yield self.auth.check(event, snapshot, raises=True)
-
# store message in db
yield self._on_new_room_event(event, snapshot)
@@ -235,7 +232,7 @@ class MessageHandler(BaseHandler):
yield self.auth.check_joined_room(room_id, user_id)
# TODO: This is duplicating logic from snapshot_all_rooms
- current_state = yield self.store.get_current_state(room_id)
+ current_state = yield self.state_handler.get_current_state(room_id)
defer.returnValue([self.hs.serialize_event(c) for c in current_state])
@defer.inlineCallbacks
@@ -271,7 +268,7 @@ class MessageHandler(BaseHandler):
presence_stream = self.hs.get_event_sources().sources["presence"]
pagination_config = PaginationConfig(from_token=now_token)
presence, _ = yield presence_stream.get_pagination_rows(
- user, pagination_config, None
+ user, pagination_config.get_source_config("presence"), None
)
public_rooms = yield self.store.get_rooms(is_public=True)
@@ -312,7 +309,7 @@ class MessageHandler(BaseHandler):
"end": end_token.to_string(),
}
- current_state = yield self.store.get_current_state(
+ current_state = yield self.state_handler.get_current_state(
event.room_id
)
d["state"] = [self.hs.serialize_event(c) for c in current_state]
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index b2af09f090..2ccc2245b7 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -76,9 +76,7 @@ class PresenceHandler(BaseHandler):
"stopped_user_eventstream", self.stopped_user_eventstream
)
- distributor.observe("user_joined_room",
- self.user_joined_room
- )
+ distributor.observe("user_joined_room", self.user_joined_room)
distributor.declare("collect_presencelike_data")
@@ -156,14 +154,12 @@ class PresenceHandler(BaseHandler):
defer.returnValue(True)
if (yield self.store.user_rooms_intersect(
- [u.to_string() for u in observer_user, observed_user]
- )):
+ [u.to_string() for u in (observer_user, observed_user)])):
defer.returnValue(True)
if (yield self.store.is_presence_visible(
- observed_localpart=observed_user.localpart,
- observer_userid=observer_user.to_string(),
- )):
+ observed_localpart=observed_user.localpart,
+ observer_userid=observer_user.to_string())):
defer.returnValue(True)
defer.returnValue(False)
@@ -171,7 +167,8 @@ class PresenceHandler(BaseHandler):
@defer.inlineCallbacks
def get_state(self, target_user, auth_user):
if target_user.is_mine:
- visible = yield self.is_presence_visible(observer_user=auth_user,
+ visible = yield self.is_presence_visible(
+ observer_user=auth_user,
observed_user=target_user
)
@@ -219,9 +216,9 @@ class PresenceHandler(BaseHandler):
)
if state["presence"] not in self.STATE_LEVELS:
- raise SynapseError(400, "'%s' is not a valid presence state" %
- state["presence"]
- )
+ raise SynapseError(400, "'%s' is not a valid presence state" % (
+ state["presence"],
+ ))
logger.debug("Updating presence state of %s to %s",
target_user.localpart, state["presence"])
@@ -229,7 +226,7 @@ class PresenceHandler(BaseHandler):
state_to_store = dict(state)
state_to_store["state"] = state_to_store.pop("presence")
- statuscache=self._get_or_offline_usercache(target_user)
+ statuscache = self._get_or_offline_usercache(target_user)
was_level = self.STATE_LEVELS[statuscache.get_state()["presence"]]
now_level = self.STATE_LEVELS[state["presence"]]
@@ -649,8 +646,9 @@ class PresenceHandler(BaseHandler):
del state["user_id"]
if "presence" not in state:
- logger.warning("Received a presence 'push' EDU from %s without"
- + " a 'presence' key", origin
+ logger.warning(
+ "Received a presence 'push' EDU from %s without a"
+ " 'presence' key", origin
)
continue
@@ -745,7 +743,7 @@ class PresenceHandler(BaseHandler):
defer.returnValue((localusers, remote_domains))
def push_update_to_clients(self, observed_user, users_to_push=[],
- room_ids=[], statuscache=None):
+ room_ids=[], statuscache=None):
self.notifier.on_new_user_event(
users_to_push,
room_ids,
@@ -765,8 +763,7 @@ class PresenceEventSource(object):
presence = self.hs.get_handlers().presence_handler
if (yield presence.store.user_rooms_intersect(
- [u.to_string() for u in observer_user, observed_user]
- )):
+ [u.to_string() for u in (observer_user, observed_user)])):
defer.returnValue(True)
if observed_user.is_mine:
@@ -823,15 +820,12 @@ class PresenceEventSource(object):
def get_pagination_rows(self, user, pagination_config, key):
# TODO (erikj): Does this make sense? Ordering?
- from_token = pagination_config.from_token
- to_token = pagination_config.to_token
-
observer_user = user
- from_key = int(from_token.presence_key)
+ from_key = int(pagination_config.from_key)
- if to_token:
- to_key = int(to_token.presence_key)
+ if pagination_config.to_key:
+ to_key = int(pagination_config.to_key)
else:
to_key = -1
@@ -841,7 +835,7 @@ class PresenceEventSource(object):
updates = []
# TODO(paul): use a DeferredList ? How to limit concurrency.
for observed_user in cachemap.keys():
- if not (to_key < cachemap[observed_user].serial < from_key):
+ if not (to_key < cachemap[observed_user].serial <= from_key):
continue
if (yield self.is_visible(observer_user, observed_user)):
@@ -849,30 +843,15 @@ class PresenceEventSource(object):
# TODO(paul): limit
- updates = [(k, cachemap[k]) for k in cachemap
- if to_key < cachemap[k].serial < from_key]
-
if updates:
clock = self.clock
earliest_serial = max([x[1].serial for x in updates])
data = [x[1].make_event(user=x[0], clock=clock) for x in updates]
- if to_token:
- next_token = to_token
- else:
- next_token = from_token
-
- next_token = next_token.copy_and_replace(
- "presence_key", earliest_serial
- )
- defer.returnValue((data, next_token))
+ defer.returnValue((data, earliest_serial))
else:
- if not to_token:
- to_token = from_token.copy_and_replace(
- "presence_key", 0
- )
- defer.returnValue(([], to_token))
+ defer.returnValue(([], 0))
class UserPresenceCache(object):
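
Note the comparison change above: the window test used to be `to_key < serial < from_key`, which silently dropped the update whose serial equalled `from_key`; it is now inclusive at the top end. A one-function restatement of the corrected window:

    def updates_in_window(cachemap, from_key, to_key):
        # Half-open window (to_key, from_key]: the "<=" keeps the update
        # whose serial equals from_key, which "<" used to drop.
        return [
            (user, cache) for user, cache in cachemap.items()
            if to_key < cache.serial <= from_key
        ]
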
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index dab9b03f04..4cd0a06093 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -218,5 +218,6 @@ class ProfileHandler(BaseHandler):
user_id=j.state_key,
)
- yield self.state_handler.handle_new_event(new_event, snapshot)
- yield self._on_new_room_event(new_event, snapshot)
+ yield self._on_new_room_event(
+ new_event, snapshot, suppress_auth=True
+ )
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 94b7890b5e..7df9d9b82d 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -15,7 +15,6 @@
"""Contains functions for registering clients."""
from twisted.internet import defer
-from twisted.python import log
from synapse.types import UserID
from synapse.api.errors import (
@@ -64,9 +63,11 @@ class RegistrationHandler(BaseHandler):
user_id = user.to_string()
token = self._generate_token(user_id)
- yield self.store.register(user_id=user_id,
+ yield self.store.register(
+ user_id=user_id,
token=token,
- password_hash=password_hash)
+ password_hash=password_hash
+ )
self.distributor.fire("registered_user", user)
else:
@@ -127,7 +128,7 @@ class RegistrationHandler(BaseHandler):
try:
threepid = yield self._threepid_from_creds(c)
except:
- log.err()
+ logger.exception("Couldn't validate 3pid")
raise RegistrationError(400, "Couldn't validate 3pid")
if not threepid:
@@ -181,8 +182,11 @@ class RegistrationHandler(BaseHandler):
data = yield httpCli.post_urlencoded_get_json(
creds['idServer'],
"/_matrix/identity/api/v1/3pid/bind",
- {'sid': creds['sid'], 'clientSecret': creds['clientSecret'],
- 'mxid': mxid}
+ {
+ 'sid': creds['sid'],
+ 'clientSecret': creds['clientSecret'],
+ 'mxid': mxid,
+ }
)
defer.returnValue(data)
@@ -223,5 +227,3 @@ class RegistrationHandler(BaseHandler):
}
)
defer.returnValue(data)
-
-
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 21ae03df0d..ffc0892f1a 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -129,8 +129,9 @@ class RoomCreationHandler(BaseHandler):
logger.debug("Event: %s", event)
- yield self.state_handler.handle_new_event(event, snapshot)
- yield self._on_new_room_event(event, snapshot, extra_users=[user])
+ yield self._on_new_room_event(
+ event, snapshot, extra_users=[user], suppress_auth=True
+ )
for event in creation_events:
yield handle_event(event)
@@ -391,8 +392,6 @@ class RoomMemberHandler(BaseHandler):
yield self._do_join(event, snapshot, do_auth=do_auth)
else:
# This is not a JOIN, so we can handle it normally.
- if do_auth:
- yield self.auth.check(event, snapshot, raises=True)
# If we're banning someone, set a req power level
if event.membership == Membership.BAN:
@@ -414,6 +413,7 @@ class RoomMemberHandler(BaseHandler):
event,
membership=event.content["membership"],
snapshot=snapshot,
+ do_auth=do_auth,
)
defer.returnValue({"room_id": room_id})
@@ -502,14 +502,11 @@ class RoomMemberHandler(BaseHandler):
if not have_joined:
logger.debug("Doing normal join")
- if do_auth:
- yield self.auth.check(event, snapshot, raises=True)
-
- yield self.state_handler.handle_new_event(event, snapshot)
yield self._do_local_membership_update(
event,
membership=event.content["membership"],
snapshot=snapshot,
+ do_auth=do_auth,
)
user = self.hs.parse_userid(event.user_id)
@@ -553,7 +550,8 @@ class RoomMemberHandler(BaseHandler):
defer.returnValue([r.room_id for r in rooms])
- def _do_local_membership_update(self, event, membership, snapshot):
+ def _do_local_membership_update(self, event, membership, snapshot,
+ do_auth):
destinations = []
# If we're inviting someone, then we should also send it to that
@@ -570,9 +568,10 @@ class RoomMemberHandler(BaseHandler):
return self._on_new_room_event(
event, snapshot, extra_destinations=destinations,
- extra_users=[target_user]
+ extra_users=[target_user], suppress_auth=(not do_auth),
)
+
class RoomListHandler(BaseHandler):
@defer.inlineCallbacks
@@ -612,23 +611,14 @@ class RoomEventSource(object):
return self.store.get_room_events_max_id()
@defer.inlineCallbacks
- def get_pagination_rows(self, user, pagination_config, key):
- from_token = pagination_config.from_token
- to_token = pagination_config.to_token
- limit = pagination_config.limit
- direction = pagination_config.direction
-
- to_key = to_token.room_key if to_token else None
-
+ def get_pagination_rows(self, user, config, key):
events, next_key = yield self.store.paginate_room_events(
room_id=key,
- from_key=from_token.room_key,
- to_key=to_key,
- direction=direction,
- limit=limit,
+ from_key=config.from_key,
+ to_key=config.to_key,
+ direction=config.direction,
+ limit=config.limit,
with_feedback=True
)
- next_token = from_token.copy_and_replace("room_key", next_key)
-
- defer.returnValue((events, next_token))
+ defer.returnValue((events, next_key))
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index 0ca4e5c31e..d88a53242c 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -96,9 +96,10 @@ class TypingNotificationHandler(BaseHandler):
remotedomains = set()
rm_handler = self.homeserver.get_handlers().room_member_handler
- yield rm_handler.fetch_room_distributions_into(room_id,
- localusers=localusers, remotedomains=remotedomains,
- ignore_user=user)
+ yield rm_handler.fetch_room_distributions_into(
+ room_id, localusers=localusers, remotedomains=remotedomains,
+ ignore_user=user
+ )
for u in localusers:
self.push_update_to_clients(
@@ -130,8 +131,9 @@ class TypingNotificationHandler(BaseHandler):
localusers = set()
rm_handler = self.homeserver.get_handlers().room_member_handler
- yield rm_handler.fetch_room_distributions_into(room_id,
- localusers=localusers)
+ yield rm_handler.fetch_room_distributions_into(
+ room_id, localusers=localusers
+ )
for u in localusers:
self.push_update_to_clients(
@@ -142,7 +144,7 @@ class TypingNotificationHandler(BaseHandler):
)
def push_update_to_clients(self, room_id, observer_user, observed_user,
- typing):
+ typing):
# TODO(paul) steal this from presence.py
pass
@@ -158,4 +160,4 @@ class TypingNotificationEventSource(object):
return 0
def get_pagination_rows(self, user, pagination_config, key):
- return ([], pagination_config.from_token)
+ return ([], pagination_config.from_key)
diff --git a/synapse/http/__init__.py b/synapse/http/__init__.py
index 9bff9ec169..f9811bfa04 100644
--- a/synapse/http/__init__.py
+++ b/synapse/http/__init__.py
@@ -12,4 +12,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
diff --git a/synapse/http/client.py b/synapse/http/client.py
index 46c90dbb76..c34b086eb9 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -16,7 +16,9 @@
from twisted.internet import defer, reactor
from twisted.internet.error import DNSLookupError
-from twisted.web.client import _AgentBase, _URI, readBody, FileBodyProducer, PartialDownloadError
+from twisted.web.client import (
+ _AgentBase, _URI, readBody, FileBodyProducer, PartialDownloadError
+)
from twisted.web.http_headers import Headers
from synapse.http.endpoint import matrix_endpoint
@@ -97,7 +99,7 @@ class BaseHttpClient(object):
retries_left = 5
- endpoint = self._getEndpoint(reactor, destination);
+ endpoint = self._getEndpoint(reactor, destination)
while True:
@@ -181,7 +183,7 @@ class MatrixHttpClient(BaseHttpClient):
auth_headers = []
- for key,sig in request["signatures"][self.server_name].items():
+ for key, sig in request["signatures"][self.server_name].items():
auth_headers.append(bytes(
"X-Matrix origin=%s,key=\"%s\",sig=\"%s\"" % (
self.server_name, key, sig,
@@ -276,7 +278,6 @@ class MatrixHttpClient(BaseHttpClient):
defer.returnValue(json.loads(body))
-
def _getEndpoint(self, reactor, destination):
return matrix_endpoint(
reactor, destination, timeout=10,
@@ -351,6 +352,7 @@ class IdentityServerHttpClient(BaseHttpClient):
defer.returnValue(json.loads(body))
+
class CaptchaServerHttpClient(MatrixHttpClient):
"""Separate HTTP client for talking to google's captcha servers"""
@@ -384,6 +386,7 @@ class CaptchaServerHttpClient(MatrixHttpClient):
else:
raise e
+
def _print_ex(e):
if hasattr(e, "reasons") and e.reasons:
for ex in e.reasons:
diff --git a/synapse/http/content_repository.py b/synapse/http/content_repository.py
index 7dd4a859f8..3159ffff0a 100644
--- a/synapse/http/content_repository.py
+++ b/synapse/http/content_repository.py
@@ -38,8 +38,8 @@ class ContentRepoResource(resource.Resource):
Uploads are POSTed to wherever this Resource is linked to. This resource
returns a "content token" which can be used to GET this content again. The
- token is typically a path, but it may not be. Tokens can expire, be one-time
- uses, etc.
+ token is typically a path, but need not be. Tokens can expire, be
+ single-use, etc.
In this case, the token is a path to the file and contains 3 interesting
sections:
@@ -175,10 +175,9 @@ class ContentRepoResource(resource.Resource):
with open(fname, "wb") as f:
f.write(request.content.read())
-
# FIXME (erikj): These should use constants.
file_name = os.path.basename(fname)
- # FIXME: we can't assume what the public mounted path of the repo is
+ # FIXME: we can't assume what the repo's public mounted path is
# ...plus self-signed SSL won't work to remote clients anyway
# ...and we can't assume that it's SSL anyway, as we might want to
# server it via the non-SSL listener...
@@ -201,6 +200,3 @@ class ContentRepoResource(resource.Resource):
500,
json.dumps({"error": "Internal server error"}),
send_cors=True)
-
-
-
diff --git a/synapse/notifier.py b/synapse/notifier.py
index 5b02c71d1e..f38c410e33 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -167,7 +167,8 @@ class Notifier(object):
)
def eb(failure):
- logger.error("Failed to notify listener",
+ logger.error(
+ "Failed to notify listener",
exc_info=(
failure.type,
failure.value,
@@ -207,7 +208,7 @@ class Notifier(object):
)
if timeout:
- reactor.callLater(timeout/1000, self._timeout_listener, listener)
+ reactor.callLater(timeout/1000.0, self._timeout_listener, listener)
self._register_with_keys(listener)
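
The `timeout/1000` change is the sub-second event stream fix from the changelog: under Python 2, dividing two ints floor-divides, so any timeout below 1000ms became a zero-second `callLater`. For example:

    # Python 2 semantics: int / int floor-divides.
    assert 500 / 1000 == 0       # a 500ms timeout silently became 0 seconds
    assert 500 / 1000.0 == 0.5   # float division preserves the fraction

    # reactor.callLater(500 / 1000, cb)    -> fires on the next tick
    # reactor.callLater(500 / 1000.0, cb)  -> fires after half a second
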
diff --git a/synapse/rest/base.py b/synapse/rest/base.py
index 2e8e3fa7d4..dc784c1527 100644
--- a/synapse/rest/base.py
+++ b/synapse/rest/base.py
@@ -18,6 +18,11 @@ from synapse.api.urls import CLIENT_PREFIX
from synapse.rest.transactions import HttpTransactionStore
import re
+import logging
+
+
+logger = logging.getLogger(__name__)
+
def client_path_pattern(path_regex):
"""Creates a regex compiled client path with the correct client path
diff --git a/synapse/rest/events.py b/synapse/rest/events.py
index 097195d7cc..92ff5e5ca7 100644
--- a/synapse/rest/events.py
+++ b/synapse/rest/events.py
@@ -20,6 +20,12 @@ from synapse.api.errors import SynapseError
from synapse.streams.config import PaginationConfig
from synapse.rest.base import RestServlet, client_path_pattern
+import logging
+
+
+logger = logging.getLogger(__name__)
+
+
class EventStreamRestServlet(RestServlet):
PATTERN = client_path_pattern("/events$")
@@ -29,18 +35,22 @@ class EventStreamRestServlet(RestServlet):
@defer.inlineCallbacks
def on_GET(self, request):
auth_user = yield self.auth.get_user_by_req(request)
-
- handler = self.handlers.event_stream_handler
- pagin_config = PaginationConfig.from_request(request)
- timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS
- if "timeout" in request.args:
- try:
- timeout = int(request.args["timeout"][0])
- except ValueError:
- raise SynapseError(400, "timeout must be in milliseconds.")
-
- chunk = yield handler.get_stream(auth_user.to_string(), pagin_config,
- timeout=timeout)
+ try:
+ handler = self.handlers.event_stream_handler
+ pagin_config = PaginationConfig.from_request(request)
+ timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS
+ if "timeout" in request.args:
+ try:
+ timeout = int(request.args["timeout"][0])
+ except ValueError:
+ raise SynapseError(400, "timeout must be in milliseconds.")
+
+ chunk = yield handler.get_stream(
+ auth_user.to_string(), pagin_config, timeout=timeout
+ )
+ except:
+ logger.exception("Event stream failed")
+ raise
defer.returnValue((200, chunk))
diff --git a/synapse/rest/profile.py b/synapse/rest/profile.py
index dad5a208c7..72e02d8dd8 100644
--- a/synapse/rest/profile.py
+++ b/synapse/rest/profile.py
@@ -108,9 +108,9 @@ class ProfileRestServlet(RestServlet):
)
defer.returnValue((200, {
- "displayname": displayname,
- "avatar_url": avatar_url
- }))
+ "displayname": displayname,
+ "avatar_url": avatar_url
+ }))
def register_servlets(hs, http_server):
diff --git a/synapse/rest/register.py b/synapse/rest/register.py
index 804117ee09..5c15614ea9 100644
--- a/synapse/rest/register.py
+++ b/synapse/rest/register.py
@@ -60,40 +60,45 @@ class RegisterRestServlet(RestServlet):
def on_GET(self, request):
if self.hs.config.enable_registration_captcha:
- return (200, {
- "flows": [
+ return (
+ 200,
+ {"flows": [
{
"type": LoginType.RECAPTCHA,
- "stages": ([LoginType.RECAPTCHA,
- LoginType.EMAIL_IDENTITY,
- LoginType.PASSWORD])
+ "stages": [
+ LoginType.RECAPTCHA,
+ LoginType.EMAIL_IDENTITY,
+ LoginType.PASSWORD
+ ]
},
{
"type": LoginType.RECAPTCHA,
"stages": [LoginType.RECAPTCHA, LoginType.PASSWORD]
}
- ]
- })
+ ]}
+ )
else:
- return (200, {
- "flows": [
+ return (
+ 200,
+ {"flows": [
{
"type": LoginType.EMAIL_IDENTITY,
- "stages": ([LoginType.EMAIL_IDENTITY,
- LoginType.PASSWORD])
+ "stages": [
+ LoginType.EMAIL_IDENTITY, LoginType.PASSWORD
+ ]
},
{
"type": LoginType.PASSWORD
}
- ]
- })
+ ]}
+ )
@defer.inlineCallbacks
def on_POST(self, request):
register_json = _parse_json(request)
- session = (register_json["session"] if "session" in register_json
- else None)
+ session = (register_json["session"]
+ if "session" in register_json else None)
login_type = None
if "type" not in register_json:
raise SynapseError(400, "Missing 'type' key.")
@@ -122,7 +127,9 @@ class RegisterRestServlet(RestServlet):
defer.returnValue((200, response))
except KeyError as e:
logger.exception(e)
- raise SynapseError(400, "Missing JSON keys for login type %s." % login_type)
+ raise SynapseError(400, "Missing JSON keys for login type %s." % (
+ login_type,
+ ))
def on_OPTIONS(self, request):
return (200, {})
@@ -183,8 +190,10 @@ class RegisterRestServlet(RestServlet):
session["user"] = register_json["user"]
defer.returnValue(None)
else:
- raise SynapseError(400, "Captcha bypass HMAC incorrect",
- errcode=Codes.CAPTCHA_NEEDED)
+ raise SynapseError(
+ 400, "Captcha bypass HMAC incorrect",
+ errcode=Codes.CAPTCHA_NEEDED
+ )
challenge = None
user_response = None
@@ -230,12 +239,15 @@ class RegisterRestServlet(RestServlet):
if ("user" in session and "user" in register_json and
session["user"] != register_json["user"]):
- raise SynapseError(400, "Cannot change user ID during registration")
+ raise SynapseError(
+ 400, "Cannot change user ID during registration"
+ )
password = register_json["password"].encode("utf-8")
- desired_user_id = (register_json["user"].encode("utf-8") if "user"
- in register_json else None)
- if desired_user_id and urllib.quote(desired_user_id) != desired_user_id:
+ desired_user_id = (register_json["user"].encode("utf-8")
+ if "user" in register_json else None)
+ if (desired_user_id
+ and urllib.quote(desired_user_id) != desired_user_id):
raise SynapseError(
400,
"User ID must only contain characters which do not " +
diff --git a/synapse/rest/room.py b/synapse/rest/room.py
index c72bdc2c34..ec0ce78fda 100644
--- a/synapse/rest/room.py
+++ b/synapse/rest/room.py
@@ -48,7 +48,9 @@ class RoomCreateRestServlet(RestServlet):
@defer.inlineCallbacks
def on_PUT(self, request, txn_id):
try:
- defer.returnValue(self.txns.get_client_transaction(request, txn_id))
+ defer.returnValue(
+ self.txns.get_client_transaction(request, txn_id)
+ )
except KeyError:
pass
@@ -98,8 +100,8 @@ class RoomStateEventRestServlet(RestServlet):
no_state_key = "/rooms/(?P<room_id>[^/]*)/state/(?P<event_type>[^/]*)$"
# /room/$roomid/state/$eventtype/$statekey
- state_key = ("/rooms/(?P<room_id>[^/]*)/state/" +
- "(?P<event_type>[^/]*)/(?P<state_key>[^/]*)$")
+ state_key = ("/rooms/(?P<room_id>[^/]*)/state/"
+ "(?P<event_type>[^/]*)/(?P<state_key>[^/]*)$")
http_server.register_path("GET",
client_path_pattern(state_key),
@@ -133,7 +135,9 @@ class RoomStateEventRestServlet(RestServlet):
)
if not data:
- raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND)
+ raise SynapseError(
+ 404, "Event not found.", errcode=Codes.NOT_FOUND
+ )
defer.returnValue((200, data[0].get_dict()["content"]))
@defer.inlineCallbacks
@@ -195,7 +199,9 @@ class RoomSendEventRestServlet(RestServlet):
@defer.inlineCallbacks
def on_PUT(self, request, room_id, event_type, txn_id):
try:
- defer.returnValue(self.txns.get_client_transaction(request, txn_id))
+ defer.returnValue(
+ self.txns.get_client_transaction(request, txn_id)
+ )
except KeyError:
pass
@@ -254,7 +260,9 @@ class JoinRoomAliasServlet(RestServlet):
@defer.inlineCallbacks
def on_PUT(self, request, room_identifier, txn_id):
try:
- defer.returnValue(self.txns.get_client_transaction(request, txn_id))
+ defer.returnValue(
+ self.txns.get_client_transaction(request, txn_id)
+ )
except KeyError:
pass
@@ -293,7 +301,8 @@ class RoomMemberListRestServlet(RestServlet):
target_user = self.hs.parse_userid(event["user_id"])
# Presence is an optional cache; don't fail if we can't fetch it
try:
- presence_state = yield self.handlers.presence_handler.get_state(
+ presence_handler = self.handlers.presence_handler
+ presence_state = yield presence_handler.get_state(
target_user=target_user, auth_user=user
)
event["content"].update(presence_state)
@@ -359,11 +368,11 @@ class RoomInitialSyncRestServlet(RestServlet):
# { state event } , { state event }
# ]
# }
- # Probably worth keeping the keys room_id and membership for parity with
- # /initialSync even though they must be joined to sync this and know the
- # room ID, so clients can reuse the same code (room_id and membership
- # are MANDATORY for /initialSync, so the code will expect it to be
- # there)
+ # Probably worth keeping the keys room_id and membership for parity
+ # with /initialSync even though they must be joined to sync this and
+ # know the room ID, so clients can reuse the same code (room_id and
+ # membership are MANDATORY for /initialSync, so the code will expect
+ # it to be there)
defer.returnValue((200, {}))
@@ -388,8 +397,8 @@ class RoomMembershipRestServlet(RestServlet):
def register(self, http_server):
# /rooms/$roomid/[invite|join|leave]
- PATTERN = ("/rooms/(?P<room_id>[^/]*)/" +
- "(?P<membership_action>join|invite|leave|ban|kick)")
+ PATTERN = ("/rooms/(?P<room_id>[^/]*)/"
+ "(?P<membership_action>join|invite|leave|ban|kick)")
register_txn_path(self, PATTERN, http_server)
@defer.inlineCallbacks
@@ -422,7 +431,9 @@ class RoomMembershipRestServlet(RestServlet):
@defer.inlineCallbacks
def on_PUT(self, request, room_id, membership_action, txn_id):
try:
- defer.returnValue(self.txns.get_client_transaction(request, txn_id))
+ defer.returnValue(
+ self.txns.get_client_transaction(request, txn_id)
+ )
except KeyError:
pass
@@ -431,6 +442,7 @@ class RoomMembershipRestServlet(RestServlet):
self.txns.store_client_transaction(request, txn_id, response)
defer.returnValue(response)
+
class RoomRedactEventRestServlet(RestServlet):
def register(self, http_server):
PATTERN = ("/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)")
@@ -457,7 +469,9 @@ class RoomRedactEventRestServlet(RestServlet):
@defer.inlineCallbacks
def on_PUT(self, request, room_id, event_id, txn_id):
try:
- defer.returnValue(self.txns.get_client_transaction(request, txn_id))
+ defer.returnValue(
+ self.txns.get_client_transaction(request, txn_id)
+ )
except KeyError:
pass
@@ -503,10 +517,10 @@ def register_txn_path(servlet, regex_string, http_server, with_get=False):
)
if with_get:
http_server.register_path(
- "GET",
- client_path_pattern(regex_string + "/(?P<txn_id>[^/]*)$"),
- servlet.on_GET
- )
+ "GET",
+ client_path_pattern(regex_string + "/(?P<txn_id>[^/]*)$"),
+ servlet.on_GET
+ )
def register_servlets(hs, http_server):
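
Every `on_PUT` above follows the same idempotency pattern: look up a stored response for this (path, access token, transaction id) and only run the operation if none exists. A sketch of the pattern in one place, where `do_request` is a hypothetical stand-in for the servlet's actual work:

    from twisted.internet import defer


    @defer.inlineCallbacks
    def idempotent_put(txns, request, txn_id, do_request):
        # txns: an HttpTransactionStore; do_request: hypothetical callable
        # performing the servlet's real work.
        try:
            # A retried PUT gets the stored response verbatim, so a client
            # resending after a dropped connection cannot double-apply.
            defer.returnValue(txns.get_client_transaction(request, txn_id))
        except KeyError:
            pass

        response = yield do_request(request)
        txns.store_client_transaction(request, txn_id, response)
        defer.returnValue(response)
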
diff --git a/synapse/rest/transactions.py b/synapse/rest/transactions.py
index e06dcc8c57..93c0122f30 100644
--- a/synapse/rest/transactions.py
+++ b/synapse/rest/transactions.py
@@ -30,9 +30,9 @@ class HttpTransactionStore(object):
"""Retrieve a response for this request.
Args:
- key (str): A transaction-independent key for this request. Typically
- this is a combination of the path (without the transaction id) and
- the user's access token.
+ key (str): A transaction-independent key for this request. Usually
+ this is a combination of the path (without the transaction id)
+ and the user's access token.
txn_id (str): The transaction ID for this request
Returns:
A tuple of (HTTP response code, response content) or None.
@@ -51,9 +51,9 @@ class HttpTransactionStore(object):
"""Stores an HTTP response tuple.
Args:
- key (str): A transaction-independent key for this request. Typically
- this is a combination of the path (without the transaction id) and
- the user's access token.
+ key (str): A transaction-independent key for this request. Usually
+ this is a combination of the path (without the transaction id)
+ and the user's access token.
txn_id (str): The transaction ID for this request.
response (tuple): A tuple of (HTTP response code, response content)
"""
@@ -92,5 +92,3 @@ class HttpTransactionStore(object):
token = request.args["access_token"][0]
path_without_txn_id = request.path.rsplit("/", 1)[0]
return path_without_txn_id + "/" + token
-
-
diff --git a/synapse/rest/voip.py b/synapse/rest/voip.py
index 0d0243a249..432c2475f8 100644
--- a/synapse/rest/voip.py
+++ b/synapse/rest/voip.py
@@ -34,23 +34,23 @@ class VoipRestServlet(RestServlet):
turnSecret = self.hs.config.turn_shared_secret
userLifetime = self.hs.config.turn_user_lifetime
if not turnUris or not turnSecret or not userLifetime:
- defer.returnValue( (200, {}) )
+ defer.returnValue((200, {}))
expiry = (self.hs.get_clock().time_msec() + userLifetime) / 1000
username = "%d:%s" % (expiry, auth_user.to_string())
-
+
mac = hmac.new(turnSecret, msg=username, digestmod=hashlib.sha1)
- # We need to use standard base64 encoding here, *not* syutil's encode_base64
- # because we need to add the standard padding to get the same result as the
- # TURN server.
+ # We need to use standard base64 encoding here, *not* syutil's
+ # encode_base64 because we need to add the standard padding to get the
+ # same result as the TURN server.
password = base64.b64encode(mac.digest())
- defer.returnValue( (200, {
+ defer.returnValue((200, {
'username': username,
'password': password,
'ttl': userLifetime / 1000,
'uris': turnUris,
- }) )
+ }))
def on_OPTIONS(self, request):
return (200, {})
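
The handler above implements the TURN REST credential scheme: the username is `expiry:user_id`, and the password is the standard padded base64 of an HMAC-SHA1 over that username, which the TURN server recomputes to verify. A self-contained sketch of the same computation:

    import base64
    import hashlib
    import hmac
    import time


    def turn_credentials(turn_secret, user_id, lifetime_ms):
        expiry = int(time.time() + lifetime_ms / 1000.0)
        username = "%d:%s" % (expiry, user_id)
        mac = hmac.new(turn_secret, msg=username, digestmod=hashlib.sha1)
        # Standard *padded* base64, not syutil's unpadded variant, so the
        # TURN server derives the identical password.
        password = base64.b64encode(mac.digest())
        return username, password


    username, password = turn_credentials(
        "shared-secret", "@alice:example.com", 86400 * 1000
    )
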
diff --git a/synapse/server.py b/synapse/server.py
index a4d2d4aba5..d770b20b19 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -28,7 +28,7 @@ from synapse.handlers import Handlers
from synapse.rest import RestServletFactory
from synapse.state import StateHandler
from synapse.storage import DataStore
-from synapse.types import UserID, RoomAlias, RoomID
+from synapse.types import UserID, RoomAlias, RoomID, EventID
from synapse.util import Clock
from synapse.util.distributor import Distributor
from synapse.util.lockutils import LockManager
@@ -143,6 +143,11 @@ class BaseHomeServer(object):
object."""
return RoomID.from_string(s, hs=self)
+ def parse_eventid(self, s):
+ """Parse the string given by 's' as a Event ID and return a EventID
+ object."""
+ return EventID.from_string(s, hs=self)
+
def serialize_event(self, e):
return serialize_event(self, e)
diff --git a/synapse/state.py b/synapse/state.py
index bc6b928ec7..f4efc287c9 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -16,11 +16,14 @@
from twisted.internet import defer
-from synapse.federation.pdu_codec import encode_event_id, decode_event_id
from synapse.util.logutils import log_function
+from synapse.util.async import run_on_reactor
+
+from synapse.types import EventID
from collections import namedtuple
+import copy
import logging
import hashlib
@@ -35,13 +38,14 @@ KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key"))
class StateHandler(object):
- """ Repsonsible for doing state conflict resolution.
+ """ Responsible for doing state conflict resolution.
"""
def __init__(self, hs):
self.store = hs.get_datastore()
self._replication = hs.get_replication_layer()
self.server_name = hs.hostname
+ self.hs = hs
@defer.inlineCallbacks
@log_function
@@ -50,7 +54,7 @@ class StateHandler(object):
to update the state and b) works out what the prev_state should be.
Returns:
- Deferred: Resolved with a boolean indicating if we succesfully
+ Deferred: Resolved with a boolean indicating if we successfully
updated the state.
Raised:
@@ -61,200 +65,149 @@ class StateHandler(object):
if not hasattr(event, "state_key"):
return
- key = KeyStateTuple(
- event.room_id,
- event.type,
- _get_state_key_from_event(event)
- )
-
# Now I need to fill out the prev state and work out if it has auth
# (w.r.t. to power levels)
snapshot.fill_out_prev_events(event)
+ yield self.annotate_state_groups(event)
- current_state = snapshot.prev_state_pdu
-
- if current_state:
- event.prev_state = encode_event_id(
- current_state.pdu_id, current_state.origin
+ if event.old_state_events:
+ current_state = event.old_state_events.get(
+ (event.type, event.state_key)
)
- # TODO check current_state to see if the min power level is less
- # than the power level of the user
- # power_level = self._get_power_level_for_event(event)
-
- pdu_id, origin = decode_event_id(event.event_id, self.server_name)
-
- yield self.store.update_current_state(
- pdu_id=pdu_id,
- origin=origin,
- context=key.context,
- pdu_type=key.type,
- state_key=key.state_key
- )
+ if current_state:
+ event.prev_state = current_state.event_id
defer.returnValue(True)
@defer.inlineCallbacks
@log_function
- def handle_new_state(self, new_pdu):
- """ Apply conflict resolution to `new_pdu`.
-
- This should be called on every new state pdu, regardless of whether or
- not there is a conflict.
-
- This function is safe against the race of it getting called with two
- `PDU`s trying to update the same state.
- """
-
- # This needs to be done in a transaction.
+ def annotate_state_groups(self, event, old_state=None):
+ yield run_on_reactor()
- is_new = yield self._handle_new_state(new_pdu)
+ if old_state:
+ event.state_group = None
+ event.old_state_events = {
+ (s.type, s.state_key): s for s in old_state
+ }
+ event.state_events = event.old_state_events
- logger.debug("is_new: %s %s %s", is_new, new_pdu.pdu_id, new_pdu.origin)
+ if hasattr(event, "state_key"):
+ event.state_events[(event.type, event.state_key)] = event
- if is_new:
- yield self.store.update_current_state(
- pdu_id=new_pdu.pdu_id,
- origin=new_pdu.origin,
- context=new_pdu.context,
- pdu_type=new_pdu.pdu_type,
- state_key=new_pdu.state_key
- )
-
- defer.returnValue(is_new)
-
- def _get_power_level_for_event(self, event):
- # return self._persistence.get_power_level_for_user(event.room_id,
- # event.sender)
- return event.power_level
+ defer.returnValue(False)
+ return
- @defer.inlineCallbacks
- @log_function
- def _handle_new_state(self, new_pdu):
- tree, missing_branch = yield self.store.get_unresolved_state_tree(
- new_pdu
- )
- new_branch, current_branch = tree
+ if hasattr(event, "outlier") and event.outlier:
+ event.state_group = None
+ event.old_state_events = None
+ event.state_events = {}
+ defer.returnValue(False)
+ return
- logger.debug(
- "_handle_new_state new=%s, current=%s",
- new_branch, current_branch
+ new_state = yield self.resolve_state_groups(
+ [e for e, _ in event.prev_events]
)
- if missing_branch is not None:
- # We're missing some PDUs. Fetch them.
- # TODO (erikj): Limit this.
- missing_prev = tree[missing_branch][-1]
-
- pdu_id = missing_prev.prev_state_id
- origin = missing_prev.prev_state_origin
-
- is_missing = yield self.store.get_pdu(pdu_id, origin) is None
- if not is_missing:
- raise Exception("Conflict resolution failed")
-
- yield self._replication.get_pdu(
- destination=missing_prev.origin,
- pdu_origin=origin,
- pdu_id=pdu_id,
- outlier=True
- )
+ event.old_state_events = copy.deepcopy(new_state)
- updated_current = yield self._handle_new_state(new_pdu)
- defer.returnValue(updated_current)
+ if hasattr(event, "state_key"):
+ new_state[(event.type, event.state_key)] = event
- if not current_branch:
- # There is no current state
- defer.returnValue(True)
- return
+ event.state_group = None
+ event.state_events = new_state
- n = new_branch[-1]
- c = current_branch[-1]
+ defer.returnValue(hasattr(event, "state_key"))
- common_ancestor = n.pdu_id == c.pdu_id and n.origin == c.origin
+ @defer.inlineCallbacks
+ def get_current_state(self, room_id, event_type=None, state_key=""):
+ events = yield self.store.get_latest_events_in_room(room_id)
- if common_ancestor:
- # We found a common ancestor!
+ event_ids = [
+ e_id
+ for e_id, _, _ in events
+ ]
- if len(current_branch) == 1:
- # This is a direct clobber so we can just...
- defer.returnValue(True)
+ res = yield self.resolve_state_groups(event_ids)
- else:
- # We didn't find a common ancestor. This is probably fine.
- pass
+ if event_type:
+ defer.returnValue(res.get((event_type, state_key)))
+ return
- result = yield self._do_conflict_res(
- new_branch, current_branch, common_ancestor
- )
- defer.returnValue(result)
+ defer.returnValue(res.values())
@defer.inlineCallbacks
- def _do_conflict_res(self, new_branch, current_branch, common_ancestor):
- conflict_res = [
- self._do_power_level_conflict_res,
- self._do_chain_length_conflict_res,
- self._do_hash_conflict_res,
- ]
-
- for algo in conflict_res:
- new_res, curr_res = yield defer.maybeDeferred(
- algo,
- new_branch, current_branch, common_ancestor
- )
-
- if new_res < curr_res:
- defer.returnValue(False)
- elif new_res > curr_res:
- defer.returnValue(True)
+ @log_function
+ def resolve_state_groups(self, event_ids):
+ state_groups = yield self.store.get_state_groups(
+ event_ids
+ )
- raise Exception("Conflict resolution failed.")
+ state = {}
+ for group in state_groups:
+ for s in group.state:
+ state.setdefault(
+ (s.type, s.state_key),
+ {}
+ )[s.event_id] = s
+
+ unconflicted_state = {
+ k: v.values()[0] for k, v in state.items()
+ if len(v) == 1
+ }
+
+ conflicted_state = {
+ k: v.values()
+ for k, v in state.items()
+ if len(v) > 1
+ }
+
+ try:
+ new_state = {}
+ new_state.update(unconflicted_state)
+ for key, events in conflicted_state.items():
+ new_state[key] = yield self._resolve_state_events(events)
+ except:
+ logger.exception("Failed to resolve state")
+ raise
+
+ defer.returnValue(new_state)
@defer.inlineCallbacks
- def _do_power_level_conflict_res(self, new_branch, current_branch,
- common_ancestor):
+ @log_function
+ def _resolve_state_events(self, events):
+ curr_events = events
+
new_powers_deferreds = []
- for e in new_branch[:-1] if common_ancestor else new_branch:
- if hasattr(e, "user_id"):
- new_powers_deferreds.append(
- self.store.get_power_level(e.context, e.user_id)
- )
-
- current_powers_deferreds = []
- for e in current_branch[:-1] if common_ancestor else current_branch:
- if hasattr(e, "user_id"):
- current_powers_deferreds.append(
- self.store.get_power_level(e.context, e.user_id)
- )
+ for e in curr_events:
+ new_powers_deferreds.append(
+ self.store.get_power_level(e.room_id, e.user_id)
+ )
new_powers = yield defer.gatherResults(
new_powers_deferreds,
consumeErrors=True
)
- current_powers = yield defer.gatherResults(
- current_powers_deferreds,
- consumeErrors=True
- )
-
- max_power_new = max(new_powers)
- max_power_current = max(current_powers)
+ max_power = max([int(p) for p in new_powers])
- defer.returnValue(
- (max_power_new, max_power_current)
- )
-
- def _do_chain_length_conflict_res(self, new_branch, current_branch,
- common_ancestor):
- return (len(new_branch), len(current_branch))
+ curr_events = [
+ z[0] for z in zip(curr_events, new_powers)
+ if int(z[1]) == max_power
+ ]
- def _do_hash_conflict_res(self, new_branch, current_branch,
- common_ancestor):
- new_str = "".join([p.pdu_id + p.origin for p in new_branch])
- c_str = "".join([p.pdu_id + p.origin for p in current_branch])
+ if not curr_events:
+ raise RuntimeError("Max didn't get a max?")
+ elif len(curr_events) == 1:
+ defer.returnValue(curr_events[0])
- return (
- hashlib.sha1(new_str).hexdigest(),
- hashlib.sha1(c_str).hexdigest()
+ # TODO: For now, tie-break deterministically on a hash of each event.
+ defer.returnValue(
+ sorted(
+ curr_events,
+ key=lambda e: hashlib.sha1(
+ e.event_id + e.user_id + e.room_id + e.type
+ ).hexdigest()
+ )[0]
)
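
The new resolver buckets the candidate state by `(type, state_key)`: keys with a single candidate pass through untouched, while conflicted keys keep only events from the highest-powered senders, with a deterministic hash sort as the final tie-break so every server converges on the same winner. A toy restatement, where `power_of` is a hypothetical callable standing in for the power-level lookup:

    import hashlib


    def resolve_sketch(state_sets, power_of):
        # state_sets: list of {(type, state_key): event} dicts, one per
        # state group; events are deduplicated by event_id across groups.
        candidates = {}
        for state in state_sets:
            for key, event in state.items():
                candidates.setdefault(key, {})[event.event_id] = event

        resolved = {}
        for key, by_id in candidates.items():
            events = list(by_id.values())
            if len(events) == 1:
                resolved[key] = events[0]      # unconflicted: passes through
                continue
            # Keep only events sent by the highest-powered users...
            max_power = max(int(power_of(e)) for e in events)
            events = [e for e in events if int(power_of(e)) == max_power]
            # ...then tie-break with a deterministic (if arbitrary) hash
            # sort, so every server picks the same winner.
            resolved[key] = sorted(
                events,
                key=lambda e: hashlib.sha1(
                    e.event_id + e.user_id + e.room_id + e.type
                ).hexdigest(),
            )[0]
        return resolved
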
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 1639e2c973..6b8fed4502 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -37,14 +37,17 @@ from .registration import RegistrationStore
from .room import RoomStore
from .roommember import RoomMemberStore
from .stream import StreamStore
-from .pdu import StatePduStore, PduStore, PdusTable
from .transactions import TransactionStore
from .keys import KeyStore
+from .event_federation import EventFederationStore
+
+from .state import StateStore
from .signatures import SignatureStore
from syutil.base64util import decode_base64
-from synapse.crypto.event_signing import compute_pdu_event_reference_hash
+from synapse.crypto.event_signing import compute_event_reference_hash
+
import json
import logging
@@ -56,7 +59,6 @@ logger = logging.getLogger(__name__)
SCHEMAS = [
"transactions",
- "pdu",
"users",
"profiles",
"presence",
@@ -64,7 +66,9 @@ SCHEMAS = [
"room_aliases",
"keys",
"redactions",
- "signatures",
+ "state",
+ "event_edges",
+ "event_signatures",
]
@@ -79,10 +83,12 @@ class _RollbackButIsFineException(Exception):
"""
pass
+
class DataStore(RoomMemberStore, RoomStore,
RegistrationStore, StreamStore, ProfileStore, FeedbackStore,
- PresenceStore, PduStore, StatePduStore, TransactionStore,
- DirectoryStore, KeyStore, SignatureStore):
+ PresenceStore, TransactionStore,
+ DirectoryStore, KeyStore, StateStore, SignatureStore,
+ EventFederationStore, ):
def __init__(self, hs):
super(DataStore, self).__init__(hs)
@@ -105,6 +111,7 @@ class DataStore(RoomMemberStore, RoomStore,
try:
yield self.runInteraction(
+ "persist_event",
self._persist_pdu_event_txn,
pdu=pdu,
event=event,
@@ -125,7 +132,8 @@ class DataStore(RoomMemberStore, RoomStore,
"type",
"room_id",
"content",
- "unrecognized_keys"
+ "unrecognized_keys",
+ "depth",
],
allow_none=allow_none,
)
@@ -139,68 +147,12 @@ class DataStore(RoomMemberStore, RoomStore,
def _persist_pdu_event_txn(self, txn, pdu=None, event=None,
backfilled=False, stream_ordering=None,
is_new_state=True):
- if pdu is not None:
- self._persist_event_pdu_txn(txn, pdu)
if event is not None:
return self._persist_event_txn(
txn, event, backfilled, stream_ordering,
is_new_state=is_new_state,
)
- def _persist_event_pdu_txn(self, txn, pdu):
- cols = dict(pdu.__dict__)
- unrec_keys = dict(pdu.unrecognized_keys)
- del cols["hashes"]
- del cols["signatures"]
- del cols["content"]
- del cols["prev_pdus"]
- cols["content_json"] = json.dumps(pdu.content)
-
- unrec_keys.update({
- k: v for k, v in cols.items()
- if k not in PdusTable.fields
- })
-
- cols["unrecognized_keys"] = json.dumps(unrec_keys)
-
- cols["ts"] = cols.pop("origin_server_ts")
-
- logger.debug("Persisting: %s", repr(cols))
-
- for hash_alg, hash_base64 in pdu.hashes.items():
- hash_bytes = decode_base64(hash_base64)
- self._store_pdu_content_hash_txn(
- txn, pdu.pdu_id, pdu.origin, hash_alg, hash_bytes,
- )
-
- signatures = pdu.signatures.get(pdu.origin, {})
-
- for key_id, signature_base64 in signatures.items():
- signature_bytes = decode_base64(signature_base64)
- self._store_pdu_origin_signature_txn(
- txn, pdu.pdu_id, pdu.origin, key_id, signature_bytes,
- )
-
- for prev_pdu_id, prev_origin, prev_hashes in pdu.prev_pdus:
- for alg, hash_base64 in prev_hashes.items():
- hash_bytes = decode_base64(hash_base64)
- self._store_prev_pdu_hash_txn(
- txn, pdu.pdu_id, pdu.origin, prev_pdu_id, prev_origin, alg,
- hash_bytes
- )
-
- (ref_alg, ref_hash_bytes) = compute_pdu_event_reference_hash(pdu)
- self._store_pdu_reference_hash_txn(
- txn, pdu.pdu_id, pdu.origin, ref_alg, ref_hash_bytes
- )
-
- if pdu.is_state:
- self._persist_state_txn(txn, pdu.prev_pdus, cols)
- else:
- self._persist_pdu_txn(txn, pdu.prev_pdus, cols)
-
- self._update_min_depth_for_context_txn(txn, pdu.context, pdu.depth)
-
@log_function
def _persist_event_txn(self, txn, event, backfilled, stream_ordering=None,
is_new_state=True):
@@ -225,6 +177,10 @@ class DataStore(RoomMemberStore, RoomStore,
elif event.type == RoomRedactionEvent.TYPE:
self._store_redaction(txn, event)
+ outlier = False
+ if hasattr(event, "outlier"):
+ outlier = event.outlier
+
vals = {
"topological_ordering": event.depth,
"event_id": event.event_id,
@@ -232,25 +188,33 @@ class DataStore(RoomMemberStore, RoomStore,
"room_id": event.room_id,
"content": json.dumps(event.content),
"processed": True,
+ "outlier": outlier,
+ "depth": event.depth,
}
if stream_ordering is not None:
vals["stream_ordering"] = stream_ordering
- if hasattr(event, "outlier"):
- vals["outlier"] = event.outlier
- else:
- vals["outlier"] = False
-
unrec = {
k: v
for k, v in event.get_full_dict().items()
- if k not in vals.keys() and k not in ["redacted", "redacted_because"]
+ if k not in vals.keys() and k not in [
+ "redacted",
+ "redacted_because",
+ "signatures",
+ "hashes",
+ "prev_events",
+ ]
}
vals["unrecognized_keys"] = json.dumps(unrec)
try:
- self._simple_insert_txn(txn, "events", vals)
+ self._simple_insert_txn(
+ txn,
+ "events",
+ vals,
+ or_replace=(not outlier),
+ )
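+ # (outliers use a plain INSERT rather than INSERT OR REPLACE so they
+ # can never clobber an event persisted via the normal path; a
+ # duplicate then raises and is handled below)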
except:
logger.warn(
"Failed to persist, probably duplicate: %s",
@@ -259,6 +223,16 @@ class DataStore(RoomMemberStore, RoomStore,
)
raise _RollbackButIsFineException("_persist_event")
+ self._handle_prev_events(
+ txn,
+ outlier=outlier,
+ event_id=event.event_id,
+ prev_events=event.prev_events,
+ room_id=event.room_id,
+ )
+
+ self._store_state_groups_txn(txn, event)
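+ # (_store_state_groups_txn is presumably provided by the new
+ # StateStore mixin added to DataStore's bases above)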
+
is_state = hasattr(event, "state_key") and event.state_key is not None
if is_new_state and is_state:
vals = {
@@ -284,6 +258,35 @@ class DataStore(RoomMemberStore, RoomStore,
}
)
+ for hash_alg, hash_base64 in event.hashes.items():
+ hash_bytes = decode_base64(hash_base64)
+ self._store_event_content_hash_txn(
+ txn, event.event_id, hash_alg, hash_bytes,
+ )
+
+ if hasattr(event, "signatures"):
+ signatures = event.signatures.get(event.origin, {})
+
+ for key_id, signature_base64 in signatures.items():
+ signature_bytes = decode_base64(signature_base64)
+ self._store_event_origin_signature_txn(
+ txn, event.event_id, event.origin, key_id, signature_bytes,
+ )
+
+ for prev_event_id, prev_hashes in event.prev_events:
+ for alg, hash_base64 in prev_hashes.items():
+ hash_bytes = decode_base64(hash_base64)
+ self._store_prev_event_hash_txn(
+ txn, event.event_id, prev_event_id, alg, hash_bytes
+ )
+
+ (ref_alg, ref_hash_bytes) = compute_event_reference_hash(event)
+ self._store_event_reference_hash_txn(
+ txn, event.event_id, ref_alg, ref_hash_bytes
+ )
+
+ self._update_min_depth_for_room_txn(txn, event.room_id, event.depth)
+
def _store_redaction(self, txn, event):
txn.execute(
"INSERT OR IGNORE INTO redactions "
@@ -366,29 +369,19 @@ class DataStore(RoomMemberStore, RoomStore,
"""
def _snapshot(txn):
membership_state = self._get_room_member(txn, user_id, room_id)
- prev_pdus = self._get_latest_pdus_in_context(
- txn, room_id
- )
-
- if state_type is not None and state_key is not None:
- prev_state_pdu = self._get_current_state_pdu(
- txn, room_id, state_type, state_key
- )
- else:
- prev_state_pdu = None
+ prev_events = self._get_latest_events_in_room(txn, room_id)
return Snapshot(
store=self,
room_id=room_id,
user_id=user_id,
- prev_pdus=prev_pdus,
+ prev_events=prev_events,
membership_state=membership_state,
state_type=state_type,
state_key=state_key,
- prev_state_pdu=prev_state_pdu,
)
- return self.runInteraction(_snapshot)
+ return self.runInteraction("snapshot_room", _snapshot)
class Snapshot(object):
@@ -397,7 +390,7 @@ class Snapshot(object):
store (DataStore): The datastore.
room_id (RoomId): The room of the snapshot.
user_id (UserId): The user this snapshot is for.
- prev_pdus (list): The list of PDU ids this snapshot is after.
+ prev_events (list): The list of event ids this snapshot is after.
membership_state (RoomMemberEvent): The current state of the user in
the room.
state_type (str, optional): State type captured by the snapshot
@@ -406,29 +399,29 @@ class Snapshot(object):
the previous value of the state type and key in the room.
"""
- def __init__(self, store, room_id, user_id, prev_pdus,
+ def __init__(self, store, room_id, user_id, prev_events,
membership_state, state_type=None, state_key=None,
prev_state_pdu=None):
self.store = store
self.room_id = room_id
self.user_id = user_id
- self.prev_pdus = prev_pdus
+ self.prev_events = prev_events
self.membership_state = membership_state
self.state_type = state_type
self.state_key = state_key
self.prev_state_pdu = prev_state_pdu
def fill_out_prev_events(self, event):
- if hasattr(event, "prev_pdus"):
+ if hasattr(event, "prev_events"):
return
- event.prev_pdus = [
- (p_id, origin, hashes)
- for p_id, origin, hashes, _ in self.prev_pdus
+ event.prev_events = [
+ (event_id, hashes)
+ for event_id, hashes, _ in self.prev_events
]
- if self.prev_pdus:
- event.depth = max([int(v) for _, _, _, v in self.prev_pdus]) + 1
+ if self.prev_events:
+ event.depth = max([int(v) for _, _, v in self.prev_events]) + 1
else:
event.depth = 0
@@ -487,9 +480,10 @@ def prepare_database(db_conn):
db_conn.commit()
else:
- sql_script = "BEGIN TRANSACTION;"
+ sql_script = "BEGIN TRANSACTION;\n"
for sql_loc in SCHEMAS:
sql_script += read_schema(sql_loc)
+ sql_script += "\n"
sql_script += "COMMIT TRANSACTION;"
c.executescript(sql_script)
db_conn.commit()
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 65a86e9056..464b12f032 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -19,54 +19,66 @@ from twisted.internet import defer
from synapse.api.errors import StoreError
from synapse.api.events.utils import prune_event
from synapse.util.logutils import log_function
+from syutil.base64util import encode_base64
import collections
import copy
import json
+import sys
+import time
logger = logging.getLogger(__name__)
sql_logger = logging.getLogger("synapse.storage.SQL")
+transaction_logger = logging.getLogger("synapse.storage.txn")
class LoggingTransaction(object):
"""An object that almost-transparently proxies for the 'txn' object
passed to the constructor. Adds logging to the .execute() method."""
- __slots__ = ["txn"]
+ __slots__ = ["txn", "name"]
- def __init__(self, txn):
+ def __init__(self, txn, name):
object.__setattr__(self, "txn", txn)
+ object.__setattr__(self, "name", name)
- def __getattribute__(self, name):
- if name == "execute":
- return object.__getattribute__(self, "execute")
-
- return getattr(object.__getattribute__(self, "txn"), name)
+ def __getattr__(self, name):
+ return getattr(self.txn, name)
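+ # NB: unlike __getattribute__, __getattr__ is only consulted for
+ # attributes *not* found on the proxy itself, so "execute" and "name"
+ # resolve locally while everything else falls through to the real txn.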
def __setattr__(self, name, value):
- setattr(object.__getattribute__(self, "txn"), name, value)
+ setattr(self.txn, name, value)
def execute(self, sql, *args, **kwargs):
# TODO(paul): Maybe use 'info' and 'debug' for values?
- sql_logger.debug("[SQL] %s", sql)
+ sql_logger.debug("[SQL] {%s} %s", self.name, sql)
try:
if args and args[0]:
values = args[0]
- sql_logger.debug("[SQL values] " +
- ", ".join(("<%s>",) * len(values)), *values)
+ sql_logger.debug(
+ "[SQL values] {%s} " + ", ".join(("<%s>",) * len(values)),
+ self.name,
+ *values
+ )
except:
# Don't let logging failures stop SQL from working
pass
- # TODO(paul): Here would be an excellent place to put some timing
- # measurements, and log (warning?) slow queries.
- return object.__getattribute__(self, "txn").execute(
- sql, *args, **kwargs
- )
+ start = time.clock() * 1000
+ try:
+ return self.txn.execute(
+ sql, *args, **kwargs
+ )
+ except:
+ logger.exception("[SQL FAIL] {%s}", self.name)
+ raise
+ finally:
+ end = time.clock() * 1000
+ sql_logger.debug("[SQL time] {%s} %f", self.name, end - start)
class SQLBaseStore(object):
+ _TXN_ID = 0
def __init__(self, hs):
self.hs = hs
@@ -74,10 +86,30 @@ class SQLBaseStore(object):
self.event_factory = hs.get_event_factory()
self._clock = hs.get_clock()
- def runInteraction(self, func, *args, **kwargs):
+ def runInteraction(self, desc, func, *args, **kwargs):
"""Wraps the .runInteraction() method on the underlying db_pool."""
def inner_func(txn, *args, **kwargs):
- return func(LoggingTransaction(txn), *args, **kwargs)
+ start = time.clock() * 1000
+ txn_id = SQLBaseStore._TXN_ID
+
+ # We don't really need these to be unique, so let's stop the counter
+ # from growing really large. Note we bump the class attribute rather
+ # than an instance attribute, so the id actually advances between
+ # calls on the same store.
+ SQLBaseStore._TXN_ID = (SQLBaseStore._TXN_ID + 1) % (sys.maxint - 1)
+
+ name = "%s-%x" % (desc, txn_id, )
+
+ transaction_logger.debug("[TXN START] {%s}", name)
+ try:
+ return func(LoggingTransaction(txn, name), *args, **kwargs)
+ except:
+ logger.exception("[TXN FAIL] {%s}", name)
+ raise
+ finally:
+ end = time.clock() * 1000
+ transaction_logger.debug(
+ "[TXN END] {%s} %f",
+ name, end - start
+ )
return self._db_pool.runInteraction(inner_func, *args, **kwargs)
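+
+ # A hypothetical caller, purely to illustrate the new signature --
+ # the human-readable description is now the first argument:
+ #
+ #   def count_users(self):
+ #       def f(txn):
+ #           txn.execute("SELECT count(*) FROM users")
+ #           return txn.fetchone()[0]
+ #       return self.runInteraction("count_users", f)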
@@ -113,7 +145,7 @@ class SQLBaseStore(object):
else:
return cursor.fetchall()
- return self.runInteraction(interaction)
+ return self.runInteraction("_execute", interaction)
def _execute_and_decode(self, query, *args):
return self._execute(self.cursor_to_dict, query, *args)
@@ -130,6 +162,7 @@ class SQLBaseStore(object):
or_replace : bool; if True performs an INSERT OR REPLACE
"""
return self.runInteraction(
+ "_simple_insert",
self._simple_insert_txn, table, values, or_replace=or_replace,
or_ignore=or_ignore,
)
@@ -170,7 +203,6 @@ class SQLBaseStore(object):
table, keyvalues, retcols=retcols, allow_none=allow_none
)
- @defer.inlineCallbacks
def _simple_select_one_onecol(self, table, keyvalues, retcol,
allow_none=False):
"""Executes a SELECT query on the named table, which is expected to
@@ -181,19 +213,41 @@ class SQLBaseStore(object):
keyvalues : dict of column names and values to select the row with
retcol : string giving the name of the column to return
"""
- ret = yield self._simple_select_one(
+ return self.runInteraction(
+ "_simple_select_one_onecol_txn",
+ self._simple_select_one_onecol_txn,
+ table, keyvalues, retcol, allow_none=allow_none,
+ )
+
+ def _simple_select_one_onecol_txn(self, txn, table, keyvalues, retcol,
+ allow_none=False):
+ ret = self._simple_select_onecol_txn(
+ txn,
table=table,
keyvalues=keyvalues,
- retcols=[retcol],
- allow_none=allow_none
+ retcol=retcol,
)
if ret:
- defer.returnValue(ret[retcol])
+ return ret[0]
else:
- defer.returnValue(None)
+ if allow_none:
+ return None
+ else:
+ raise StoreError(404, "No row found")
+
+ def _simple_select_onecol_txn(self, txn, table, keyvalues, retcol):
+ sql = "SELECT %(retcol)s FROM %(table)s WHERE %(where)s" % {
+ "retcol": retcol,
+ "table": table,
+ "where": " AND ".join("%s = ?" % k for k in keyvalues.keys()),
+ }
+
+ txn.execute(sql, keyvalues.values())
+
+ return [r[0] for r in txn.fetchall()]
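+ # (keys() and values() of the same, unmodified dict iterate in
+ # matching order, so the "?" placeholders line up with the bound
+ # values)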
+
- @defer.inlineCallbacks
def _simple_select_onecol(self, table, keyvalues, retcol):
"""Executes a SELECT query on the named table, which returns a list
comprising of the values of the named column from the selected rows.
@@ -206,19 +260,11 @@ class SQLBaseStore(object):
Returns:
Deferred: Results in a list
"""
- sql = "SELECT %(retcol)s FROM %(table)s WHERE %(where)s" % {
- "retcol": retcol,
- "table": table,
- "where": " AND ".join("%s = ?" % k for k in keyvalues.keys()),
- }
-
- def func(txn):
- txn.execute(sql, keyvalues.values())
- return txn.fetchall()
-
- res = yield self.runInteraction(func)
-
- defer.returnValue([r[0] for r in res])
+ return self.runInteraction(
+ "_simple_select_onecol",
+ self._simple_select_onecol_txn,
+ table, keyvalues, retcol
+ )
def _simple_select_list(self, table, keyvalues, retcols):
"""Executes a SELECT query on the named table, which may return zero or
@@ -239,7 +285,7 @@ class SQLBaseStore(object):
txn.execute(sql, keyvalues.values())
return self.cursor_to_dict(txn)
- return self.runInteraction(func)
+ return self.runInteraction("_simple_select_list", func)
def _simple_update_one(self, table, keyvalues, updatevalues,
retcols=None):
@@ -307,7 +353,7 @@ class SQLBaseStore(object):
raise StoreError(500, "More than one row matched")
return ret
- return self.runInteraction(func)
+ return self.runInteraction("_simple_selectupdate_one", func)
def _simple_delete_one(self, table, keyvalues):
"""Executes a DELETE query on the named table, expecting to delete a
@@ -319,7 +365,7 @@ class SQLBaseStore(object):
"""
sql = "DELETE FROM %s WHERE %s" % (
table,
- " AND ".join("%s = ?" % (k) for k in keyvalues)
+ " AND ".join("%s = ?" % (k, ) for k in keyvalues)
)
def func(txn):
@@ -328,7 +374,25 @@ class SQLBaseStore(object):
raise StoreError(404, "No row found")
if txn.rowcount > 1:
raise StoreError(500, "more than one row matched")
- return self.runInteraction(func)
+ return self.runInteraction("_simple_delete_one", func)
+
+ def _simple_delete(self, table, keyvalues):
+ """Executes a DELETE query on the named table.
+
+ Args:
+ table : string giving the table name
+ keyvalues : dict of column names and values to select the row with
+ """
+
+ return self.runInteraction("_simple_delete", self._simple_delete_txn)
+
+ def _simple_delete_txn(self, txn, table, keyvalues):
+ sql = "DELETE FROM %s WHERE %s" % (
+ table,
+ " AND ".join("%s = ?" % (k, ) for k in keyvalues)
+ )
+
+ return txn.execute(sql, keyvalues.values())
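+
+ # e.g. (illustrative):
+ #   self._simple_delete_txn(
+ #       txn, "event_forward_extremities", {"room_id": room_id}
+ #   )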
def _simple_max_id(self, table):
"""Executes a SELECT query on the named table, expecting to return the
@@ -346,7 +410,7 @@ class SQLBaseStore(object):
return 0
return max_id
- return self.runInteraction(func)
+ return self.runInteraction("_simple_max_id", func)
def _parse_event_from_row(self, row_dict):
d = copy.deepcopy({k: v for k, v in row_dict.items()})
@@ -370,7 +434,9 @@ class SQLBaseStore(object):
)
def _parse_events(self, rows):
- return self.runInteraction(self._parse_events_txn, rows)
+ return self.runInteraction(
+ "_parse_events", self._parse_events_txn, rows
+ )
def _parse_events_txn(self, txn, rows):
events = [self._parse_event_from_row(r) for r in rows]
@@ -378,6 +444,17 @@ class SQLBaseStore(object):
sql = "SELECT * FROM events WHERE event_id = ?"
for ev in events:
+ signatures = self._get_event_origin_signatures_txn(
+ txn, ev.event_id,
+ )
+
+ ev.signatures = {
+ k: encode_base64(v) for k, v in signatures.items()
+ }
+
+ prev_events = self._get_latest_events_in_room(txn, ev.room_id)
+ ev.prev_events = [(e_id, s,) for e_id, s, _ in prev_events]
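+
+ # NB: prev_events here come from the room's *current* forward
+ # extremities (the same data Snapshot.fill_out_prev_events uses),
+ # not from the edges stored for this particular event.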
+
if hasattr(ev, "prev_state"):
# Load previous state_content.
# TODO: Should we be pulling this out above?
diff --git a/synapse/storage/directory.py b/synapse/storage/directory.py
index 52373a28a6..d6a7113b9c 100644
--- a/synapse/storage/directory.py
+++ b/synapse/storage/directory.py
@@ -95,6 +95,7 @@ class DirectoryStore(SQLBaseStore):
def delete_room_alias(self, room_alias):
return self.runInteraction(
+ "delete_room_alias",
self._delete_room_alias_txn,
room_alias,
)
diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py
new file mode 100644
index 0000000000..dcc116bad2
--- /dev/null
+++ b/synapse/storage/event_federation.py
@@ -0,0 +1,253 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ._base import SQLBaseStore
+from syutil.base64util import encode_base64
+
+import logging
+
+
+logger = logging.getLogger(__name__)
+
+
+class EventFederationStore(SQLBaseStore):
+
+ def get_oldest_events_in_room(self, room_id):
+ return self.runInteraction(
+ "get_oldest_events_in_room",
+ self._get_oldest_events_in_room_txn,
+ room_id,
+ )
+
+ def _get_oldest_events_in_room_txn(self, txn, room_id):
+ return self._simple_select_onecol_txn(
+ txn,
+ table="event_backward_extremities",
+ keyvalues={
+ "room_id": room_id,
+ },
+ retcol="event_id",
+ )
+
+ def get_latest_events_in_room(self, room_id):
+ return self.runInteraction(
+ "get_latest_events_in_room",
+ self._get_latest_events_in_room,
+ room_id,
+ )
+
+ def _get_latest_events_in_room(self, txn, room_id):
+ sql = (
+ "SELECT e.event_id, e.depth FROM events as e "
+ "INNER JOIN event_forward_extremities as f "
+ "ON e.event_id = f.event_id "
+ "WHERE f.room_id = ?"
+ )
+
+ txn.execute(sql, (room_id, ))
+
+ results = []
+ for event_id, depth in txn.fetchall():
+ hashes = self._get_event_reference_hashes_txn(txn, event_id)
+ prev_hashes = {
+ k: encode_base64(v) for k, v in hashes.items()
+ if k == "sha256"
+ }
+ results.append((event_id, prev_hashes, depth))
+
+ return results
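+
+ # Each (event_id, hashes, depth) tuple above is what
+ # fill_out_prev_events consumes when stamping prev_events and depth
+ # onto a new event.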
+
+ def get_min_depth(self, room_id):
+ return self.runInteraction(
+ "get_min_depth",
+ self._get_min_depth_interaction,
+ room_id,
+ )
+
+ def _get_min_depth_interaction(self, txn, room_id):
+ min_depth = self._simple_select_one_onecol_txn(
+ txn,
+ table="room_depth",
+ keyvalues={"room_id": room_id,},
+ retcol="min_depth",
+ allow_none=True,
+ )
+
+ return int(min_depth) if min_depth is not None else None
+
+ def _update_min_depth_for_room_txn(self, txn, room_id, depth):
+ min_depth = self._get_min_depth_interaction(txn, room_id)
+
+ do_insert = depth < min_depth if min_depth is not None else True
+
+ if do_insert:
+ self._simple_insert_txn(
+ txn,
+ table="room_depth",
+ values={
+ "room_id": room_id,
+ "min_depth": depth,
+ },
+ or_replace=True,
+ )
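+
+ # e.g. with min_depth currently 10, persisting an event at depth 7
+ # lowers it to 7, while an event at depth 12 leaves it untouched.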
+
+ def _handle_prev_events(self, txn, outlier, event_id, prev_events,
+ room_id):
+ for e_id, _ in prev_events:
+ # TODO (erikj): This could be done as a bulk insert
+ self._simple_insert_txn(
+ txn,
+ table="event_edges",
+ values={
+ "event_id": event_id,
+ "prev_event_id": e_id,
+ "room_id": room_id,
+ },
+ or_ignore=True,
+ )
+
+ # Update the extremities table if this is not an outlier.
+ if not outlier:
+ for e_id, _ in prev_events:
+ # TODO (erikj): This could be done as a bulk insert
+ self._simple_delete_txn(
+ txn,
+ table="event_forward_extremities",
+ keyvalues={
+ "event_id": e_id,
+ "room_id": room_id,
+ }
+ )
+
+ # We only insert the new event as a forward extremity if there are
+ # no other events that reference it as a prev event
+ query = (
+ "INSERT OR IGNORE INTO %(table)s (event_id, room_id) "
+ "SELECT ?, ? WHERE NOT EXISTS ("
+ "SELECT 1 FROM %(event_edges)s WHERE "
+ "prev_event_id = ? "
+ ")"
+ ) % {
+ "table": "event_forward_extremities",
+ "event_edges": "event_edges",
+ }
+
+ logger.debug("query: %s", query)
+
+ txn.execute(query, (event_id, room_id, event_id))
+
+ # Insert all the prev_events as backward extremities; they'll get
+ # deleted in a second if they're incorrect anyway.
+ for e_id, _ in prev_events:
+ # TODO (erikj): This could be done as a bulk insert
+ self._simple_insert_txn(
+ txn,
+ table="event_backward_extremities",
+ values={
+ "event_id": e_id,
+ "room_id": room_id,
+ },
+ or_ignore=True,
+ )
+
+ # Also delete from the backward extremities table all rows that
+ # reference events that we have already seen
+ query = (
+ "DELETE FROM event_backward_extremities WHERE EXISTS ("
+ "SELECT 1 FROM events "
+ "WHERE "
+ "event_backward_extremities.event_id = events.event_id "
+ "AND not events.outlier "
+ ")"
+ )
+ txn.execute(query)
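+
+ # Worked example (hypothetical ids): a room's graph is A <- B and B is
+ # the sole forward extremity. Persisting C with prev_events = [B]
+ # records the C -> B edge, deletes B from event_forward_extremities,
+ # inserts C there (nothing references C yet), and drops B from
+ # event_backward_extremities since we already have it.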
+
+
+ def get_backfill_events(self, room_id, event_list, limit):
+ """Get a list of Events for a given topic that occured before (and
+ including) the pdus in pdu_list. Return a list of max size `limit`.
+
+ Args:
+ txn
+ room_id (str)
+ event_list (list)
+ limit (int)
+
+ Return:
+ list: A list of PduTuples
+ """
+ return self.runInteraction(
+ "get_backfill_events",
+ self._get_backfill_events, room_id, event_list, limit
+ )
+
+ def _get_backfill_events(self, txn, room_id, event_list, limit):
+ logger.debug(
+ "_get_backfill_events: %s, %s, %s",
+ room_id, repr(event_list), limit
+ )
+
+ # We seed the event_results with the events from event_list.
+ event_results = event_list
+
+ front = event_list
+
+ query = (
+ "SELECT prev_event_id FROM event_edges "
+ "WHERE room_id = ? AND event_id = ? "
+ "LIMIT ?"
+ )
+
+ # We iterate through all event_ids in `front` to select their previous
+ # events. These are dumped in `new_front`.
+ # We continue until we reach the limit *or* new_front is empty (i.e.,
+ # we've run out of things to select).
+ while front and len(event_results) < limit:
+
+ new_front = []
+ for event_id in front:
+ logger.debug(
+ "_backfill_interaction: id=%s",
+ event_id
+ )
+
+ txn.execute(
+ query,
+ (room_id, event_id, limit - len(event_results))
+ )
+
+ for row in txn.fetchall():
+ logger.debug(
+ "_backfill_interaction: got id=%s",
+ *row
+ )
+ new_front.append(row[0])
+
+ front = new_front
+ event_results += new_front
+
+ return event_results
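+
+ # Illustrative walk: with edges C -> B -> A and event_list = [C], the
+ # first pass adds B, the second adds A, and the loop stops once
+ # `limit` ids have been collected or there are no more prev events.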
diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py
index 4feb8335ba..fd705138e6 100644
--- a/synapse/storage/keys.py
+++ b/synapse/storage/keys.py
@@ -18,9 +18,10 @@ from _base import SQLBaseStore
from twisted.internet import defer
import OpenSSL
-from syutil.crypto.signing_key import decode_verify_key_bytes
+from syutil.crypto.signing_key import decode_verify_key_bytes
import hashlib
+
class KeyStore(SQLBaseStore):
"""Persistence for signature verification keys and tls X.509 certificates
"""
diff --git a/synapse/storage/pdu.py b/synapse/storage/pdu.py
deleted file mode 100644
index 3a90c382f0..0000000000
--- a/synapse/storage/pdu.py
+++ /dev/null
@@ -1,932 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2014 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from twisted.internet import defer
-
-from ._base import SQLBaseStore, Table, JoinHelper
-
-from synapse.federation.units import Pdu
-from synapse.util.logutils import log_function
-
-from syutil.base64util import encode_base64
-
-from collections import namedtuple
-
-import logging
-
-
-logger = logging.getLogger(__name__)
-
-
-class PduStore(SQLBaseStore):
- """A collection of queries for handling PDUs.
- """
-
- def get_pdu(self, pdu_id, origin):
- """Given a pdu_id and origin, get a PDU.
-
- Args:
- txn
- pdu_id (str)
- origin (str)
-
- Returns:
- PduTuple: If the pdu does not exist in the database, returns None
- """
-
- return self.runInteraction(
- self._get_pdu_tuple, pdu_id, origin
- )
-
- def _get_pdu_tuple(self, txn, pdu_id, origin):
- res = self._get_pdu_tuples(txn, [(pdu_id, origin)])
- return res[0] if res else None
-
- def _get_pdu_tuples(self, txn, pdu_id_tuples):
- results = []
- for pdu_id, origin in pdu_id_tuples:
- txn.execute(
- PduEdgesTable.select_statement("pdu_id = ? AND origin = ?"),
- (pdu_id, origin)
- )
-
- edges = [
- (r.prev_pdu_id, r.prev_origin)
- for r in PduEdgesTable.decode_results(txn.fetchall())
- ]
-
- edge_hashes = self._get_prev_pdu_hashes_txn(txn, pdu_id, origin)
-
- hashes = self._get_pdu_content_hashes_txn(txn, pdu_id, origin)
- signatures = self._get_pdu_origin_signatures_txn(
- txn, pdu_id, origin
- )
-
- query = (
- "SELECT %(fields)s FROM %(pdus)s as p "
- "LEFT JOIN %(state)s as s "
- "ON p.pdu_id = s.pdu_id AND p.origin = s.origin "
- "WHERE p.pdu_id = ? AND p.origin = ? "
- ) % {
- "fields": _pdu_state_joiner.get_fields(
- PdusTable="p", StatePdusTable="s"),
- "pdus": PdusTable.table_name,
- "state": StatePdusTable.table_name,
- }
-
- txn.execute(query, (pdu_id, origin))
-
- row = txn.fetchone()
- if row:
- results.append(PduTuple(
- PduEntry(*row), edges, hashes, signatures, edge_hashes
- ))
-
- return results
-
- def get_current_state_for_context(self, context):
- """Get a list of PDUs that represent the current state for a given
- context
-
- Args:
- context (str)
-
- Returns:
- list: A list of PduTuples
- """
-
- return self.runInteraction(
- self._get_current_state_for_context,
- context
- )
-
- def _get_current_state_for_context(self, txn, context):
- query = (
- "SELECT pdu_id, origin FROM %s WHERE context = ?"
- % CurrentStateTable.table_name
- )
-
- logger.debug("get_current_state %s, Args=%s", query, context)
- txn.execute(query, (context,))
-
- res = txn.fetchall()
-
- logger.debug("get_current_state %d results", len(res))
-
- return self._get_pdu_tuples(txn, res)
-
- def _persist_pdu_txn(self, txn, prev_pdus, cols):
- """Inserts a (non-state) PDU into the database.
-
- Args:
- txn,
- prev_pdus (list)
- **cols: The columns to insert into the PdusTable.
- """
- entry = PdusTable.EntryType(
- **{k: cols.get(k, None) for k in PdusTable.fields}
- )
-
- txn.execute(PdusTable.insert_statement(), entry)
-
- self._handle_prev_pdus(
- txn, entry.outlier, entry.pdu_id, entry.origin,
- prev_pdus, entry.context
- )
-
- def mark_pdu_as_processed(self, pdu_id, pdu_origin):
- """Mark a received PDU as processed.
-
- Args:
- txn
- pdu_id (str)
- pdu_origin (str)
- """
-
- return self.runInteraction(
- self._mark_as_processed, pdu_id, pdu_origin
- )
-
- def _mark_as_processed(self, txn, pdu_id, pdu_origin):
- txn.execute("UPDATE %s SET have_processed = 1" % PdusTable.table_name)
-
- def get_all_pdus_from_context(self, context):
- """Get a list of all PDUs for a given context."""
- return self.runInteraction(
- self._get_all_pdus_from_context, context,
- )
-
- def _get_all_pdus_from_context(self, txn, context):
- query = (
- "SELECT pdu_id, origin FROM %s "
- "WHERE context = ?"
- ) % PdusTable.table_name
-
- txn.execute(query, (context,))
-
- return self._get_pdu_tuples(txn, txn.fetchall())
-
- def get_backfill(self, context, pdu_list, limit):
- """Get a list of Pdus for a given topic that occured before (and
- including) the pdus in pdu_list. Return a list of max size `limit`.
-
- Args:
- txn
- context (str)
- pdu_list (list)
- limit (int)
-
- Return:
- list: A list of PduTuples
- """
- return self.runInteraction(
- self._get_backfill, context, pdu_list, limit
- )
-
- def _get_backfill(self, txn, context, pdu_list, limit):
- logger.debug(
- "backfill: %s, %s, %s",
- context, repr(pdu_list), limit
- )
-
- # We seed the pdu_results with the things from the pdu_list.
- pdu_results = pdu_list
-
- front = pdu_list
-
- query = (
- "SELECT prev_pdu_id, prev_origin FROM %(edges_table)s "
- "WHERE context = ? AND pdu_id = ? AND origin = ? "
- "LIMIT ?"
- ) % {
- "edges_table": PduEdgesTable.table_name,
- }
-
- # We iterate through all pdu_ids in `front` to select their previous
- # pdus. These are dumped in `new_front`. We continue until we reach the
- # limit *or* new_front is empty (i.e., we've run out of things to
- # select
- while front and len(pdu_results) < limit:
-
- new_front = []
- for pdu_id, origin in front:
- logger.debug(
- "_backfill_interaction: i=%s, o=%s",
- pdu_id, origin
- )
-
- txn.execute(
- query,
- (context, pdu_id, origin, limit - len(pdu_results))
- )
-
- for row in txn.fetchall():
- logger.debug(
- "_backfill_interaction: got i=%s, o=%s",
- *row
- )
- new_front.append(row)
-
- front = new_front
- pdu_results += new_front
-
- # We also want to update the `prev_pdus` attributes before returning.
- return self._get_pdu_tuples(txn, pdu_results)
-
- def get_min_depth_for_context(self, context):
- """Get the current minimum depth for a context
-
- Args:
- txn
- context (str)
- """
- return self.runInteraction(
- self._get_min_depth_for_context, context
- )
-
- def _get_min_depth_for_context(self, txn, context):
- return self._get_min_depth_interaction(txn, context)
-
- def _get_min_depth_interaction(self, txn, context):
- txn.execute(
- "SELECT min_depth FROM %s WHERE context = ?"
- % ContextDepthTable.table_name,
- (context,)
- )
-
- row = txn.fetchone()
-
- return row[0] if row else None
-
- def _update_min_depth_for_context_txn(self, txn, context, depth):
- """Update the minimum `depth` of the given context, which is the line
- on which we stop backfilling backwards.
-
- Args:
- context (str)
- depth (int)
- """
- min_depth = self._get_min_depth_interaction(txn, context)
-
- do_insert = depth < min_depth if min_depth else True
-
- if do_insert:
- txn.execute(
- "INSERT OR REPLACE INTO %s (context, min_depth) "
- "VALUES (?,?)" % ContextDepthTable.table_name,
- (context, depth)
- )
-
- def _get_latest_pdus_in_context(self, txn, context):
- """Get's a list of the most current pdus for a given context. This is
- used when we are sending a Pdu and need to fill out the `prev_pdus`
- key
-
- Args:
- txn
- context
- """
- query = (
- "SELECT p.pdu_id, p.origin, p.depth FROM %(pdus)s as p "
- "INNER JOIN %(forward)s as f ON p.pdu_id = f.pdu_id "
- "AND f.origin = p.origin "
- "WHERE f.context = ?"
- ) % {
- "pdus": PdusTable.table_name,
- "forward": PduForwardExtremitiesTable.table_name,
- }
-
- logger.debug("get_prev query: %s", query)
-
- txn.execute(
- query,
- (context, )
- )
-
- results = []
- for pdu_id, origin, depth in txn.fetchall():
- hashes = self._get_pdu_reference_hashes_txn(txn, pdu_id, origin)
- sha256_bytes = hashes["sha256"]
- prev_hashes = {"sha256": encode_base64(sha256_bytes)}
- results.append((pdu_id, origin, prev_hashes, depth))
-
- return results
-
- @defer.inlineCallbacks
- def get_oldest_pdus_in_context(self, context):
- """Get a list of Pdus that we haven't backfilled beyond yet (and havent
- seen). This list is used when we want to backfill backwards and is the
- list we send to the remote server.
-
- Args:
- txn
- context (str)
-
- Returns:
- list: A list of PduIdTuple.
- """
- results = yield self._execute(
- None,
- "SELECT pdu_id, origin FROM %(back)s WHERE context = ?"
- % {"back": PduBackwardExtremitiesTable.table_name, },
- context
- )
-
- defer.returnValue([PduIdTuple(i, o) for i, o in results])
-
- def is_pdu_new(self, pdu_id, origin, context, depth):
- """For a given Pdu, try and figure out if it's 'new', i.e., if it's
- not something we got randomly from the past, for example when we
- request the current state of the room that will probably return a bunch
- of pdus from before we joined.
-
- Args:
- txn
- pdu_id (str)
- origin (str)
- context (str)
- depth (int)
-
- Returns:
- bool
- """
-
- return self.runInteraction(
- self._is_pdu_new,
- pdu_id=pdu_id,
- origin=origin,
- context=context,
- depth=depth
- )
-
- def _is_pdu_new(self, txn, pdu_id, origin, context, depth):
- # If depth > min depth in back table, then we classify it as new.
- # OR if there is nothing in the back table, then it kinda needs to
- # be a new thing.
- query = (
- "SELECT min(p.depth) FROM %(edges)s as e "
- "INNER JOIN %(back)s as b "
- "ON e.prev_pdu_id = b.pdu_id AND e.prev_origin = b.origin "
- "INNER JOIN %(pdus)s as p "
- "ON e.pdu_id = p.pdu_id AND p.origin = e.origin "
- "WHERE p.context = ?"
- ) % {
- "pdus": PdusTable.table_name,
- "edges": PduEdgesTable.table_name,
- "back": PduBackwardExtremitiesTable.table_name,
- }
-
- txn.execute(query, (context,))
-
- min_depth, = txn.fetchone()
-
- if not min_depth or depth > int(min_depth):
- logger.debug(
- "is_new true: id=%s, o=%s, d=%s min_depth=%s",
- pdu_id, origin, depth, min_depth
- )
- return True
-
- # If this pdu is in the forwards table, then it also is a new one
- query = (
- "SELECT * FROM %(forward)s WHERE pdu_id = ? AND origin = ?"
- ) % {
- "forward": PduForwardExtremitiesTable.table_name,
- }
-
- txn.execute(query, (pdu_id, origin))
-
- # Did we get anything?
- if txn.fetchall():
- logger.debug(
- "is_new true: id=%s, o=%s, d=%s was forward",
- pdu_id, origin, depth
- )
- return True
-
- logger.debug(
- "is_new false: id=%s, o=%s, d=%s",
- pdu_id, origin, depth
- )
-
- # FINE THEN. It's probably old.
- return False
-
- @staticmethod
- @log_function
- def _handle_prev_pdus(txn, outlier, pdu_id, origin, prev_pdus,
- context):
- txn.executemany(
- PduEdgesTable.insert_statement(),
- [(pdu_id, origin, p[0], p[1], context) for p in prev_pdus]
- )
-
- # Update the extremities table if this is not an outlier.
- if not outlier:
-
- # First, we delete the new one from the forwards extremities table.
- query = (
- "DELETE FROM %s WHERE pdu_id = ? AND origin = ?"
- % PduForwardExtremitiesTable.table_name
- )
- txn.executemany(query, list(p[:2] for p in prev_pdus))
-
- # We only insert as a forward extremety the new pdu if there are no
- # other pdus that reference it as a prev pdu
- query = (
- "INSERT INTO %(table)s (pdu_id, origin, context) "
- "SELECT ?, ?, ? WHERE NOT EXISTS ("
- "SELECT 1 FROM %(pdu_edges)s WHERE "
- "prev_pdu_id = ? AND prev_origin = ?"
- ")"
- ) % {
- "table": PduForwardExtremitiesTable.table_name,
- "pdu_edges": PduEdgesTable.table_name
- }
-
- logger.debug("query: %s", query)
-
- txn.execute(query, (pdu_id, origin, context, pdu_id, origin))
-
- # Insert all the prev_pdus as a backwards thing, they'll get
- # deleted in a second if they're incorrect anyway.
- txn.executemany(
- PduBackwardExtremitiesTable.insert_statement(),
- [(i, o, context) for i, o, _ in prev_pdus]
- )
-
- # Also delete from the backwards extremities table all ones that
- # reference pdus that we have already seen
- query = (
- "DELETE FROM %(pdu_back)s WHERE EXISTS ("
- "SELECT 1 FROM %(pdus)s AS pdus "
- "WHERE "
- "%(pdu_back)s.pdu_id = pdus.pdu_id "
- "AND %(pdu_back)s.origin = pdus.origin "
- "AND not pdus.outlier "
- ")"
- ) % {
- "pdu_back": PduBackwardExtremitiesTable.table_name,
- "pdus": PdusTable.table_name,
- }
- txn.execute(query)
-
-
-class StatePduStore(SQLBaseStore):
- """A collection of queries for handling state PDUs.
- """
-
- def _persist_state_txn(self, txn, prev_pdus, cols):
- """Inserts a state PDU into the database
-
- Args:
- txn,
- prev_pdus (list)
- **cols: The columns to insert into the PdusTable and StatePdusTable
- """
- pdu_entry = PdusTable.EntryType(
- **{k: cols.get(k, None) for k in PdusTable.fields}
- )
- state_entry = StatePdusTable.EntryType(
- **{k: cols.get(k, None) for k in StatePdusTable.fields}
- )
-
- logger.debug("Inserting pdu: %s", repr(pdu_entry))
- logger.debug("Inserting state: %s", repr(state_entry))
-
- txn.execute(PdusTable.insert_statement(), pdu_entry)
- txn.execute(StatePdusTable.insert_statement(), state_entry)
-
- self._handle_prev_pdus(
- txn,
- pdu_entry.outlier, pdu_entry.pdu_id, pdu_entry.origin, prev_pdus,
- pdu_entry.context
- )
-
- def get_unresolved_state_tree(self, new_state_pdu):
- return self.runInteraction(
- self._get_unresolved_state_tree, new_state_pdu
- )
-
- @log_function
- def _get_unresolved_state_tree(self, txn, new_pdu):
- current = self._get_current_interaction(
- txn,
- new_pdu.context, new_pdu.pdu_type, new_pdu.state_key
- )
-
- ReturnType = namedtuple(
- "StateReturnType", ["new_branch", "current_branch"]
- )
- return_value = ReturnType([new_pdu], [])
-
- if not current:
- logger.debug("get_unresolved_state_tree No current state.")
- return (return_value, None)
-
- return_value.current_branch.append(current)
-
- enum_branches = self._enumerate_state_branches(
- txn, new_pdu, current
- )
-
- missing_branch = None
- for branch, prev_state, state in enum_branches:
- if state:
- return_value[branch].append(state)
- else:
- # We don't have prev_state :(
- missing_branch = branch
- break
-
- return (return_value, missing_branch)
-
- def update_current_state(self, pdu_id, origin, context, pdu_type,
- state_key):
- return self.runInteraction(
- self._update_current_state,
- pdu_id, origin, context, pdu_type, state_key
- )
-
- def _update_current_state(self, txn, pdu_id, origin, context, pdu_type,
- state_key):
- query = (
- "INSERT OR REPLACE INTO %(curr)s (%(fields)s) VALUES (%(qs)s)"
- ) % {
- "curr": CurrentStateTable.table_name,
- "fields": CurrentStateTable.get_fields_string(),
- "qs": ", ".join(["?"] * len(CurrentStateTable.fields))
- }
-
- query_args = CurrentStateTable.EntryType(
- pdu_id=pdu_id,
- origin=origin,
- context=context,
- pdu_type=pdu_type,
- state_key=state_key
- )
-
- txn.execute(query, query_args)
-
- def get_current_state_pdu(self, context, pdu_type, state_key):
- """For a given context, pdu_type, state_key 3-tuple, return what is
- currently considered the current state.
-
- Args:
- txn
- context (str)
- pdu_type (str)
- state_key (str)
-
- Returns:
- PduEntry
- """
-
- return self.runInteraction(
- self._get_current_state_pdu, context, pdu_type, state_key
- )
-
- def _get_current_state_pdu(self, txn, context, pdu_type, state_key):
- return self._get_current_interaction(txn, context, pdu_type, state_key)
-
- def _get_current_interaction(self, txn, context, pdu_type, state_key):
- logger.debug(
- "_get_current_interaction %s %s %s",
- context, pdu_type, state_key
- )
-
- fields = _pdu_state_joiner.get_fields(
- PdusTable="p", StatePdusTable="s")
-
- current_query = (
- "SELECT %(fields)s FROM %(state)s as s "
- "INNER JOIN %(pdus)s as p "
- "ON s.pdu_id = p.pdu_id AND s.origin = p.origin "
- "INNER JOIN %(curr)s as c "
- "ON s.pdu_id = c.pdu_id AND s.origin = c.origin "
- "WHERE s.context = ? AND s.pdu_type = ? AND s.state_key = ? "
- ) % {
- "fields": fields,
- "curr": CurrentStateTable.table_name,
- "state": StatePdusTable.table_name,
- "pdus": PdusTable.table_name,
- }
-
- txn.execute(
- current_query,
- (context, pdu_type, state_key)
- )
-
- row = txn.fetchone()
-
- result = PduEntry(*row) if row else None
-
- if not result:
- logger.debug("_get_current_interaction not found")
- else:
- logger.debug(
- "_get_current_interaction found %s %s",
- result.pdu_id, result.origin
- )
-
- return result
-
- def handle_new_state(self, new_pdu):
- """Actually perform conflict resolution on the new_pdu on the
- assumption we have all the pdus required to perform it.
-
- Args:
- new_pdu
-
- Returns:
- bool: True if the new_pdu clobbered the current state, False if not
- """
- return self.runInteraction(
- self._handle_new_state, new_pdu
- )
-
- def _handle_new_state(self, txn, new_pdu):
- logger.debug(
- "handle_new_state %s %s",
- new_pdu.pdu_id, new_pdu.origin
- )
-
- current = self._get_current_interaction(
- txn,
- new_pdu.context, new_pdu.pdu_type, new_pdu.state_key
- )
-
- is_current = False
-
- if (not current or not current.prev_state_id
- or not current.prev_state_origin):
- # Oh, we don't have any state for this yet.
- is_current = True
- elif (current.pdu_id == new_pdu.prev_state_id
- and current.origin == new_pdu.prev_state_origin):
- # Oh! A direct clobber. Just do it.
- is_current = True
- else:
- ##
- # Ok, now loop through until we get to a common ancestor.
- max_new = int(new_pdu.power_level)
- max_current = int(current.power_level)
-
- enum_branches = self._enumerate_state_branches(
- txn, new_pdu, current
- )
- for branch, prev_state, state in enum_branches:
- if not state:
- raise RuntimeError(
- "Could not find state_pdu %s %s" %
- (
- prev_state.prev_state_id,
- prev_state.prev_state_origin
- )
- )
-
- if branch == 0:
- max_new = max(int(state.depth), max_new)
- else:
- max_current = max(int(state.depth), max_current)
-
- is_current = max_new > max_current
-
- if is_current:
- logger.debug("handle_new_state make current")
-
- # Right, this is a new thing, so woo, just insert it.
- txn.execute(
- "INSERT OR REPLACE INTO %(curr)s (%(fields)s) VALUES (%(qs)s)"
- % {
- "curr": CurrentStateTable.table_name,
- "fields": CurrentStateTable.get_fields_string(),
- "qs": ", ".join(["?"] * len(CurrentStateTable.fields))
- },
- CurrentStateTable.EntryType(
- *(new_pdu.__dict__[k] for k in CurrentStateTable.fields)
- )
- )
- else:
- logger.debug("handle_new_state not current")
-
- logger.debug("handle_new_state done")
-
- return is_current
-
- @log_function
- def _enumerate_state_branches(self, txn, pdu_a, pdu_b):
- branch_a = pdu_a
- branch_b = pdu_b
-
- while True:
- if (branch_a.pdu_id == branch_b.pdu_id
- and branch_a.origin == branch_b.origin):
- # Woo! We found a common ancestor
- logger.debug("_enumerate_state_branches Found common ancestor")
- break
-
- do_branch_a = (
- hasattr(branch_a, "prev_state_id") and
- branch_a.prev_state_id
- )
-
- do_branch_b = (
- hasattr(branch_b, "prev_state_id") and
- branch_b.prev_state_id
- )
-
- logger.debug(
- "do_branch_a=%s, do_branch_b=%s",
- do_branch_a, do_branch_b
- )
-
- if do_branch_a and do_branch_b:
- do_branch_a = int(branch_a.depth) > int(branch_b.depth)
-
- if do_branch_a:
- pdu_tuple = PduIdTuple(
- branch_a.prev_state_id,
- branch_a.prev_state_origin
- )
-
- prev_branch = branch_a
-
- logger.debug("getting branch_a prev %s", pdu_tuple)
- branch_a = self._get_pdu_tuple(txn, *pdu_tuple)
- if branch_a:
- branch_a = Pdu.from_pdu_tuple(branch_a)
-
- logger.debug("branch_a=%s", branch_a)
-
- yield (0, prev_branch, branch_a)
-
- if not branch_a:
- break
- elif do_branch_b:
- pdu_tuple = PduIdTuple(
- branch_b.prev_state_id,
- branch_b.prev_state_origin
- )
-
- prev_branch = branch_b
-
- logger.debug("getting branch_b prev %s", pdu_tuple)
- branch_b = self._get_pdu_tuple(txn, *pdu_tuple)
- if branch_b:
- branch_b = Pdu.from_pdu_tuple(branch_b)
-
- logger.debug("branch_b=%s", branch_b)
-
- yield (1, prev_branch, branch_b)
-
- if not branch_b:
- break
- else:
- break
-
-
-class PdusTable(Table):
- table_name = "pdus"
-
- fields = [
- "pdu_id",
- "origin",
- "context",
- "pdu_type",
- "ts",
- "depth",
- "is_state",
- "content_json",
- "unrecognized_keys",
- "outlier",
- "have_processed",
- ]
-
- EntryType = namedtuple("PdusEntry", fields)
-
-
-class PduDestinationsTable(Table):
- table_name = "pdu_destinations"
-
- fields = [
- "pdu_id",
- "origin",
- "destination",
- "delivered_ts",
- ]
-
- EntryType = namedtuple("PduDestinationsEntry", fields)
-
-
-class PduEdgesTable(Table):
- table_name = "pdu_edges"
-
- fields = [
- "pdu_id",
- "origin",
- "prev_pdu_id",
- "prev_origin",
- "context"
- ]
-
- EntryType = namedtuple("PduEdgesEntry", fields)
-
-
-class PduForwardExtremitiesTable(Table):
- table_name = "pdu_forward_extremities"
-
- fields = [
- "pdu_id",
- "origin",
- "context",
- ]
-
- EntryType = namedtuple("PduForwardExtremitiesEntry", fields)
-
-
-class PduBackwardExtremitiesTable(Table):
- table_name = "pdu_backward_extremities"
-
- fields = [
- "pdu_id",
- "origin",
- "context",
- ]
-
- EntryType = namedtuple("PduBackwardExtremitiesEntry", fields)
-
-
-class ContextDepthTable(Table):
- table_name = "context_depth"
-
- fields = [
- "context",
- "min_depth",
- ]
-
- EntryType = namedtuple("ContextDepthEntry", fields)
-
-
-class StatePdusTable(Table):
- table_name = "state_pdus"
-
- fields = [
- "pdu_id",
- "origin",
- "context",
- "pdu_type",
- "state_key",
- "power_level",
- "prev_state_id",
- "prev_state_origin",
- ]
-
- EntryType = namedtuple("StatePdusEntry", fields)
-
-
-class CurrentStateTable(Table):
- table_name = "current_state"
-
- fields = [
- "pdu_id",
- "origin",
- "context",
- "pdu_type",
- "state_key",
- ]
-
- EntryType = namedtuple("CurrentStateEntry", fields)
-
-_pdu_state_joiner = JoinHelper(PdusTable, StatePdusTable)
-
-
-# TODO: These should probably be put somewhere more sensible
-PduIdTuple = namedtuple("PduIdTuple", ("pdu_id", "origin"))
-
-PduEntry = _pdu_state_joiner.EntryType
-""" We are always interested in the join of the PdusTable and StatePdusTable,
-rather than just the PdusTable.
-
-This does not include a prev_pdus key.
-"""
-
-PduTuple = namedtuple(
- "PduTuple",
- ("pdu_entry", "prev_pdu_list", "hashes", "signatures", "edge_hashes")
-)
-""" This is a tuple of a `PduEntry` and a list of `PduIdTuple` that represent
-the `prev_pdus` key of a PDU.
-"""
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 719806f82b..a2ca6f9a69 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -62,8 +62,10 @@ class RegistrationStore(SQLBaseStore):
Raises:
StoreError if the user_id could not be registered.
"""
- yield self.runInteraction(self._register, user_id, token,
- password_hash)
+ yield self.runInteraction(
+ "register",
+ self._register, user_id, token, password_hash
+ )
def _register(self, txn, user_id, token, password_hash):
now = int(self.clock.time())
@@ -100,6 +102,7 @@ class RegistrationStore(SQLBaseStore):
StoreError if no user was found.
"""
return self.runInteraction(
+ "get_user_by_token",
self._query_for_auth,
token
)
diff --git a/synapse/storage/room.py b/synapse/storage/room.py
index 8cd46334cf..7e48ce9cc3 100644
--- a/synapse/storage/room.py
+++ b/synapse/storage/room.py
@@ -150,6 +150,7 @@ class RoomStore(SQLBaseStore):
def get_power_level(self, room_id, user_id):
return self.runInteraction(
+ "get_power_level",
self._get_power_level,
room_id, user_id,
)
@@ -183,6 +184,7 @@ class RoomStore(SQLBaseStore):
def get_ops_levels(self, room_id):
return self.runInteraction(
+ "get_ops_levels",
self._get_ops_levels,
room_id,
)
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index ceeef5880e..93329703a2 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -33,7 +33,9 @@ class RoomMemberStore(SQLBaseStore):
target_user_id = event.state_key
domain = self.hs.parse_userid(target_user_id).domain
except:
- logger.exception("Failed to parse target_user_id=%s", target_user_id)
+ logger.exception(
+ "Failed to parse target_user_id=%s", target_user_id
+ )
raise
logger.debug(
@@ -65,7 +67,8 @@ class RoomMemberStore(SQLBaseStore):
# Check if this was the last person to have left.
member_events = self._get_members_query_txn(
txn,
- where_clause="c.room_id = ? AND m.membership = ? AND m.user_id != ?",
+ where_clause=("c.room_id = ? AND m.membership = ?"
+ " AND m.user_id != ?"),
where_values=(event.room_id, Membership.JOIN, target_user_id,)
)
@@ -120,7 +123,6 @@ class RoomMemberStore(SQLBaseStore):
else:
return None
-
def get_room_members(self, room_id, membership=None):
"""Retrieve the current room member list for a room.
diff --git a/synapse/storage/schema/edge_pdus.sql b/synapse/storage/schema/edge_pdus.sql
deleted file mode 100644
index 8a00868065..0000000000
--- a/synapse/storage/schema/edge_pdus.sql
+++ /dev/null
@@ -1,31 +0,0 @@
-/* Copyright 2014 OpenMarket Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-CREATE TABLE IF NOT EXISTS context_edge_pdus(
- id INTEGER PRIMARY KEY AUTOINCREMENT, -- twistar requires this
- pdu_id TEXT,
- origin TEXT,
- context TEXT,
- CONSTRAINT context_edge_pdu_id_origin UNIQUE (pdu_id, origin)
-);
-
-CREATE TABLE IF NOT EXISTS origin_edge_pdus(
- id INTEGER PRIMARY KEY AUTOINCREMENT, -- twistar requires this
- pdu_id TEXT,
- origin TEXT,
- CONSTRAINT origin_edge_pdu_id_origin UNIQUE (pdu_id, origin)
-);
-
-CREATE INDEX IF NOT EXISTS context_edge_pdu_id ON context_edge_pdus(pdu_id, origin);
-CREATE INDEX IF NOT EXISTS origin_edge_pdu_id ON origin_edge_pdus(pdu_id, origin);
diff --git a/synapse/storage/schema/event_edges.sql b/synapse/storage/schema/event_edges.sql
new file mode 100644
index 0000000000..e5f768c705
--- /dev/null
+++ b/synapse/storage/schema/event_edges.sql
@@ -0,0 +1,49 @@
+
+CREATE TABLE IF NOT EXISTS event_forward_extremities(
+ event_id TEXT,
+ room_id TEXT,
+ CONSTRAINT uniqueness UNIQUE (event_id, room_id) ON CONFLICT REPLACE
+);
+
+CREATE INDEX IF NOT EXISTS ev_extrem_room ON event_forward_extremities(room_id);
+CREATE INDEX IF NOT EXISTS ev_extrem_id ON event_forward_extremities(event_id);
+
+
+CREATE TABLE IF NOT EXISTS event_backward_extremities(
+ event_id TEXT,
+ room_id TEXT,
+ CONSTRAINT uniqueness UNIQUE (event_id, room_id) ON CONFLICT REPLACE
+);
+
+CREATE INDEX IF NOT EXISTS ev_b_extrem_room ON event_backward_extremities(room_id);
+CREATE INDEX IF NOT EXISTS ev_b_extrem_id ON event_backward_extremities(event_id);
+
+
+CREATE TABLE IF NOT EXISTS event_edges(
+ event_id TEXT,
+ prev_event_id TEXT,
+ room_id TEXT,
+ CONSTRAINT uniqueness UNIQUE (event_id, prev_event_id, room_id)
+);
+
+CREATE INDEX IF NOT EXISTS ev_edges_id ON event_edges(event_id);
+CREATE INDEX IF NOT EXISTS ev_edges_prev_id ON event_edges(prev_event_id);
+
+
+CREATE TABLE IF NOT EXISTS room_depth(
+ room_id TEXT,
+ min_depth INTEGER,
+ CONSTRAINT uniqueness UNIQUE (room_id)
+);
+
+CREATE INDEX IF NOT EXISTS room_depth_room ON room_depth(room_id);
+
+
+CREATE TABLE IF NOT EXISTS event_destinations(
+ event_id TEXT,
+ destination TEXT,
+ delivered_ts INTEGER DEFAULT 0, -- ts of delivery, or 0 if not delivered
+ CONSTRAINT uniqueness UNIQUE (event_id, destination) ON CONFLICT REPLACE
+);
+
+CREATE INDEX IF NOT EXISTS event_destinations_id ON event_destinations(event_id);
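+
+-- Example (illustrative): the current "tips" of a room's graph are just
+--   SELECT event_id FROM event_forward_extremities WHERE room_id = ?
+-- _handle_prev_events in storage/event_federation.py keeps the edge and
+-- extremity tables in sync as events are persisted.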
diff --git a/synapse/storage/schema/event_signatures.sql b/synapse/storage/schema/event_signatures.sql
new file mode 100644
index 0000000000..5491c7ecec
--- /dev/null
+++ b/synapse/storage/schema/event_signatures.sql
@@ -0,0 +1,65 @@
+/* Copyright 2014 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE IF NOT EXISTS event_content_hashes (
+ event_id TEXT,
+ algorithm TEXT,
+ hash BLOB,
+ CONSTRAINT uniqueness UNIQUE (event_id, algorithm)
+);
+
+CREATE INDEX IF NOT EXISTS event_content_hashes_id ON event_content_hashes(
+ event_id
+);
+
+
+CREATE TABLE IF NOT EXISTS event_reference_hashes (
+ event_id TEXT,
+ algorithm TEXT,
+ hash BLOB,
+ CONSTRAINT uniqueness UNIQUE (event_id, algorithm)
+);
+
+CREATE INDEX IF NOT EXISTS event_reference_hashes_id ON event_reference_hashes (
+ event_id
+);
+
+
+CREATE TABLE IF NOT EXISTS event_origin_signatures (
+ event_id TEXT,
+ origin TEXT,
+ key_id TEXT,
+ signature BLOB,
+ CONSTRAINT uniqueness UNIQUE (event_id, key_id)
+);
+
+CREATE INDEX IF NOT EXISTS event_origin_signatures_id ON event_origin_signatures (
+ event_id
+);
+
+
+CREATE TABLE IF NOT EXISTS event_edge_hashes(
+ event_id TEXT,
+ prev_event_id TEXT,
+ algorithm TEXT,
+ hash BLOB,
+ CONSTRAINT uniqueness UNIQUE (
+ event_id, prev_event_id, algorithm
+ )
+);
+
+CREATE INDEX IF NOT EXISTS event_edge_hashes_id ON event_edge_hashes(
+ event_id
+);
diff --git a/synapse/storage/schema/im.sql b/synapse/storage/schema/im.sql
index 3aa83f5c8c..8d6f655993 100644
--- a/synapse/storage/schema/im.sql
+++ b/synapse/storage/schema/im.sql
@@ -23,6 +23,7 @@ CREATE TABLE IF NOT EXISTS events(
unrecognized_keys TEXT,
processed BOOL NOT NULL,
outlier BOOL NOT NULL,
+ depth INTEGER DEFAULT 0 NOT NULL,
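+ -- (set from event.depth at persist time, mirroring topological_ordering)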
CONSTRAINT ev_uniq UNIQUE (event_id)
);
diff --git a/synapse/storage/schema/pdu.sql b/synapse/storage/schema/pdu.sql
deleted file mode 100644
index 16e111a56c..0000000000
--- a/synapse/storage/schema/pdu.sql
+++ /dev/null
@@ -1,106 +0,0 @@
-/* Copyright 2014 OpenMarket Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
--- Stores pdus and their content
-CREATE TABLE IF NOT EXISTS pdus(
- pdu_id TEXT,
- origin TEXT,
- context TEXT,
- pdu_type TEXT,
- ts INTEGER,
- depth INTEGER DEFAULT 0 NOT NULL,
- is_state BOOL,
- content_json TEXT,
- unrecognized_keys TEXT,
- outlier BOOL NOT NULL,
- have_processed BOOL,
- CONSTRAINT pdu_id_origin UNIQUE (pdu_id, origin)
-);
-
--- Stores what the current state pdu is for a given (context, pdu_type, key) tuple
-CREATE TABLE IF NOT EXISTS state_pdus(
- pdu_id TEXT,
- origin TEXT,
- context TEXT,
- pdu_type TEXT,
- state_key TEXT,
- power_level TEXT,
- prev_state_id TEXT,
- prev_state_origin TEXT,
- CONSTRAINT pdu_id_origin UNIQUE (pdu_id, origin)
- CONSTRAINT prev_pdu_id_origin UNIQUE (prev_state_id, prev_state_origin)
-);
-
-CREATE TABLE IF NOT EXISTS current_state(
- pdu_id TEXT,
- origin TEXT,
- context TEXT,
- pdu_type TEXT,
- state_key TEXT,
- CONSTRAINT pdu_id_origin UNIQUE (pdu_id, origin)
- CONSTRAINT uniqueness UNIQUE (context, pdu_type, state_key) ON CONFLICT REPLACE
-);
-
--- Stores where each pdu we want to send should be sent and the delivery status.
-create TABLE IF NOT EXISTS pdu_destinations(
- pdu_id TEXT,
- origin TEXT,
- destination TEXT,
- delivered_ts INTEGER DEFAULT 0, -- or 0 if not delivered
- CONSTRAINT uniqueness UNIQUE (pdu_id, origin, destination) ON CONFLICT REPLACE
-);
-
-CREATE TABLE IF NOT EXISTS pdu_forward_extremities(
- pdu_id TEXT,
- origin TEXT,
- context TEXT,
- CONSTRAINT uniqueness UNIQUE (pdu_id, origin, context) ON CONFLICT REPLACE
-);
-
-CREATE TABLE IF NOT EXISTS pdu_backward_extremities(
- pdu_id TEXT,
- origin TEXT,
- context TEXT,
- CONSTRAINT uniqueness UNIQUE (pdu_id, origin, context) ON CONFLICT REPLACE
-);
-
-CREATE TABLE IF NOT EXISTS pdu_edges(
- pdu_id TEXT,
- origin TEXT,
- prev_pdu_id TEXT,
- prev_origin TEXT,
- context TEXT,
- CONSTRAINT uniqueness UNIQUE (pdu_id, origin, prev_pdu_id, prev_origin, context)
-);
-
-CREATE TABLE IF NOT EXISTS context_depth(
- context TEXT,
- min_depth INTEGER,
- CONSTRAINT uniqueness UNIQUE (context)
-);
-
-CREATE INDEX IF NOT EXISTS context_depth_context ON context_depth(context);
-
-
-CREATE INDEX IF NOT EXISTS pdu_id ON pdus(pdu_id, origin);
-
-CREATE INDEX IF NOT EXISTS dests_id ON pdu_destinations (pdu_id, origin);
--- CREATE INDEX IF NOT EXISTS dests ON pdu_destinations (destination);
-
-CREATE INDEX IF NOT EXISTS pdu_extrem_context ON pdu_forward_extremities(context);
-CREATE INDEX IF NOT EXISTS pdu_extrem_id ON pdu_forward_extremities(pdu_id, origin);
-
-CREATE INDEX IF NOT EXISTS pdu_edges_id ON pdu_edges(pdu_id, origin);
-
-CREATE INDEX IF NOT EXISTS pdu_b_extrem_context ON pdu_backward_extremities(context);
diff --git a/synapse/storage/schema/signatures.sql b/synapse/storage/schema/signatures.sql
deleted file mode 100644
index 1c45a51bec..0000000000
--- a/synapse/storage/schema/signatures.sql
+++ /dev/null
@@ -1,66 +0,0 @@
-/* Copyright 2014 OpenMarket Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-CREATE TABLE IF NOT EXISTS pdu_content_hashes (
- pdu_id TEXT,
- origin TEXT,
- algorithm TEXT,
- hash BLOB,
- CONSTRAINT uniqueness UNIQUE (pdu_id, origin, algorithm)
-);
-
-CREATE INDEX IF NOT EXISTS pdu_content_hashes_id ON pdu_content_hashes (
- pdu_id, origin
-);
-
-CREATE TABLE IF NOT EXISTS pdu_reference_hashes (
- pdu_id TEXT,
- origin TEXT,
- algorithm TEXT,
- hash BLOB,
- CONSTRAINT uniqueness UNIQUE (pdu_id, origin, algorithm)
-);
-
-CREATE INDEX IF NOT EXISTS pdu_reference_hashes_id ON pdu_reference_hashes (
- pdu_id, origin
-);
-
-CREATE TABLE IF NOT EXISTS pdu_origin_signatures (
- pdu_id TEXT,
- origin TEXT,
- key_id TEXT,
- signature BLOB,
- CONSTRAINT uniqueness UNIQUE (pdu_id, origin, key_id)
-);
-
-CREATE INDEX IF NOT EXISTS pdu_origin_signatures_id ON pdu_origin_signatures (
- pdu_id, origin
-);
-
-CREATE TABLE IF NOT EXISTS pdu_edge_hashes(
- pdu_id TEXT,
- origin TEXT,
- prev_pdu_id TEXT,
- prev_origin TEXT,
- algorithm TEXT,
- hash BLOB,
- CONSTRAINT uniqueness UNIQUE (
- pdu_id, origin, prev_pdu_id, prev_origin, algorithm
- )
-);
-
-CREATE INDEX IF NOT EXISTS pdu_edge_hashes_id ON pdu_edge_hashes(
- pdu_id, origin
-);
diff --git a/synapse/storage/schema/state.sql b/synapse/storage/schema/state.sql
new file mode 100644
index 0000000000..b44c56b519
--- /dev/null
+++ b/synapse/storage/schema/state.sql
@@ -0,0 +1,33 @@
+/* Copyright 2014 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE IF NOT EXISTS state_groups(
+ id INTEGER PRIMARY KEY,
+ room_id TEXT NOT NULL,
+ event_id TEXT NOT NULL
+);
+
+CREATE TABLE IF NOT EXISTS state_groups_state(
+ state_group INTEGER NOT NULL,
+ room_id TEXT NOT NULL,
+ type TEXT NOT NULL,
+ state_key TEXT NOT NULL,
+ event_id TEXT NOT NULL
+);
+
+CREATE TABLE IF NOT EXISTS event_to_state_groups(
+ event_id TEXT NOT NULL,
+ state_group INTEGER NOT NULL
+);
\ No newline at end of file
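
The three tables above fit together as: an event points at a state group, and the group enumerates the (type, state_key) -> event_id entries making up the room state at that event. A minimal sketch of resolving state through them, assuming a plain sqlite3 connection initialised with this schema::

    import sqlite3

    conn = sqlite3.connect("homeserver.db")  # path is illustrative

    def get_state_for_event(conn, event_id):
        # find which state group the event was assigned to
        row = conn.execute(
            "SELECT state_group FROM event_to_state_groups"
            " WHERE event_id = ?",
            (event_id,),
        ).fetchone()
        if row is None:
            return []
        # fetch every (type, state_key) -> event_id entry in that group
        return conn.execute(
            "SELECT type, state_key, event_id FROM state_groups_state"
            " WHERE state_group = ?",
            (row[0],),
        ).fetchall()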
diff --git a/synapse/storage/signatures.py b/synapse/storage/signatures.py
index 82be946d3f..b4b3d5d7ea 100644
--- a/synapse/storage/signatures.py
+++ b/synapse/storage/signatures.py
@@ -17,139 +17,149 @@ from _base import SQLBaseStore
class SignatureStore(SQLBaseStore):
- """Persistence for PDU signatures and hashes"""
+ """Persistence for event signatures and hashes"""
- def _get_pdu_content_hashes_txn(self, txn, pdu_id, origin):
- """Get all the hashes for a given PDU.
+ def _get_event_content_hashes_txn(self, txn, event_id):
+ """Get all the hashes for a given Event.
Args:
txn (cursor):
- pdu_id (str): Id for the PDU.
- origin (str): origin of the PDU.
+ event_id (str): Id for the Event.
Returns:
A dict of algorithm -> hash.
"""
query = (
"SELECT algorithm, hash"
- " FROM pdu_content_hashes"
- " WHERE pdu_id = ? and origin = ?"
+ " FROM event_content_hashes"
+ " WHERE event_id = ?"
)
- txn.execute(query, (pdu_id, origin))
+ txn.execute(query, (event_id, ))
return dict(txn.fetchall())
- def _store_pdu_content_hash_txn(self, txn, pdu_id, origin, algorithm,
+ def _store_event_content_hash_txn(self, txn, event_id, algorithm,
hash_bytes):
- """Store a hash for a PDU
+ """Store a hash for a Event
Args:
txn (cursor):
- pdu_id (str): Id for the PDU.
- origin (str): origin of the PDU.
+ event_id (str): Id for the Event.
algorithm (str): Hashing algorithm.
hash_bytes (bytes): Hash function output bytes.
"""
- self._simple_insert_txn(txn, "pdu_content_hashes", {
- "pdu_id": pdu_id,
- "origin": origin,
- "algorithm": algorithm,
- "hash": buffer(hash_bytes),
- })
+ self._simple_insert_txn(
+ txn,
+ "event_content_hashes",
+ {
+ "event_id": event_id,
+ "algorithm": algorithm,
+ "hash": buffer(hash_bytes),
+ },
+ or_ignore=True,
+ )
- def _get_pdu_reference_hashes_txn(self, txn, pdu_id, origin):
+ def _get_event_reference_hashes_txn(self, txn, event_id):
"""Get all the hashes for a given PDU.
Args:
txn (cursor):
- pdu_id (str): Id for the PDU.
- origin (str): origin of the PDU.
+ event_id (str): Id for the Event.
Returns:
A dict of algorithm -> hash.
"""
query = (
"SELECT algorithm, hash"
- " FROM pdu_reference_hashes"
- " WHERE pdu_id = ? and origin = ?"
+ " FROM event_reference_hashes"
+ " WHERE event_id = ?"
)
- txn.execute(query, (pdu_id, origin))
+ txn.execute(query, (event_id, ))
return dict(txn.fetchall())
- def _store_pdu_reference_hash_txn(self, txn, pdu_id, origin, algorithm,
+ def _store_event_reference_hash_txn(self, txn, event_id, algorithm,
hash_bytes):
"""Store a hash for a PDU
Args:
txn (cursor):
- pdu_id (str): Id for the PDU.
- origin (str): origin of the PDU.
+ event_id (str): Id for the Event.
algorithm (str): Hashing algorithm.
hash_bytes (bytes): Hash function output bytes.
"""
- self._simple_insert_txn(txn, "pdu_reference_hashes", {
- "pdu_id": pdu_id,
- "origin": origin,
- "algorithm": algorithm,
- "hash": buffer(hash_bytes),
- })
+ self._simple_insert_txn(
+ txn,
+ "event_reference_hashes",
+ {
+ "event_id": event_id,
+ "algorithm": algorithm,
+ "hash": buffer(hash_bytes),
+ },
+ or_ignore=True,
+ )
- def _get_pdu_origin_signatures_txn(self, txn, pdu_id, origin):
+ def _get_event_origin_signatures_txn(self, txn, event_id):
"""Get all the signatures for a given PDU.
Args:
txn (cursor):
- pdu_id (str): Id for the PDU.
- origin (str): origin of the PDU.
+ event_id (str): Id for the Event.
Returns:
A dict of key_id -> signature_bytes.
"""
query = (
"SELECT key_id, signature"
- " FROM pdu_origin_signatures"
- " WHERE pdu_id = ? and origin = ?"
+ " FROM event_origin_signatures"
+ " WHERE event_id = ? "
)
- txn.execute(query, (pdu_id, origin))
+ txn.execute(query, (event_id, ))
return dict(txn.fetchall())
- def _store_pdu_origin_signature_txn(self, txn, pdu_id, origin, key_id,
- signature_bytes):
+ def _store_event_origin_signature_txn(self, txn, event_id, origin, key_id,
+ signature_bytes):
"""Store a signature from the origin server for a PDU.
Args:
txn (cursor):
- pdu_id (str): Id for the PDU.
- origin (str): origin of the PDU.
+ event_id (str): Id for the Event.
+ origin (str): origin of the Event.
key_id (str): Id for the signing key.
signature (bytes): The signature.
"""
- self._simple_insert_txn(txn, "pdu_origin_signatures", {
- "pdu_id": pdu_id,
- "origin": origin,
- "key_id": key_id,
- "signature": buffer(signature_bytes),
- })
+ self._simple_insert_txn(
+ txn,
+ "event_origin_signatures",
+ {
+ "event_id": event_id,
+ "origin": origin,
+ "key_id": key_id,
+ "signature": buffer(signature_bytes),
+ },
+ or_ignore=True,
+ )
- def _get_prev_pdu_hashes_txn(self, txn, pdu_id, origin):
+ def _get_prev_event_hashes_txn(self, txn, event_id):
"""Get all the hashes for previous PDUs of a PDU
Args:
txn (cursor):
- pdu_id (str): Id of the PDU.
- origin (str): Origin of the PDU.
+ event_id (str): Id for the Event.
Returns:
-            dict of (pdu_id, origin) -> dict of algorithm -> hash_bytes.
+            dict of prev_event_id -> dict of algorithm -> hash_bytes.
"""
query = (
- "SELECT prev_pdu_id, prev_origin, algorithm, hash"
- " FROM pdu_edge_hashes"
- " WHERE pdu_id = ? and origin = ?"
+ "SELECT prev_event_id, algorithm, hash"
+ " FROM event_edge_hashes"
+ " WHERE event_id = ?"
)
- txn.execute(query, (pdu_id, origin))
+ txn.execute(query, (event_id, ))
results = {}
- for prev_pdu_id, prev_origin, algorithm, hash_bytes in txn.fetchall():
- hashes = results.setdefault((prev_pdu_id, prev_origin), {})
+ for prev_event_id, algorithm, hash_bytes in txn.fetchall():
+ hashes = results.setdefault(prev_event_id, {})
hashes[algorithm] = hash_bytes
return results
- def _store_prev_pdu_hash_txn(self, txn, pdu_id, origin, prev_pdu_id,
- prev_origin, algorithm, hash_bytes):
- self._simple_insert_txn(txn, "pdu_edge_hashes", {
- "pdu_id": pdu_id,
- "origin": origin,
- "prev_pdu_id": prev_pdu_id,
- "prev_origin": prev_origin,
- "algorithm": algorithm,
- "hash": buffer(hash_bytes),
- })
+ def _store_prev_event_hash_txn(self, txn, event_id, prev_event_id,
+ algorithm, hash_bytes):
+ self._simple_insert_txn(
+ txn,
+ "event_edge_hashes",
+ {
+ "event_id": event_id,
+ "prev_event_id": prev_event_id,
+ "algorithm": algorithm,
+ "hash": buffer(hash_bytes),
+ },
+ or_ignore=True,
+ )
\ No newline at end of file
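
A hedged sketch of how the renamed helpers pair up, assuming sha256 over some canonically-encoded form of the event (the canonical serialisation itself is out of scope here)::

    import hashlib

    def store_sha256_content_hash(store, txn, event_id, event_bytes):
        # event_bytes is assumed to be the canonically-encoded event
        digest = hashlib.sha256(event_bytes).digest()
        store._store_event_content_hash_txn(txn, event_id, "sha256", digest)

    def get_content_hashes(store, txn, event_id):
        # returns {algorithm: hash_bytes}, per the docstrings above
        return store._get_event_content_hashes_txn(txn, event_id)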
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
new file mode 100644
index 0000000000..e08acd6404
--- /dev/null
+++ b/synapse/storage/state.py
@@ -0,0 +1,101 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ._base import SQLBaseStore
+from twisted.internet import defer
+
+from collections import namedtuple
+
+
+StateGroup = namedtuple("StateGroup", ("group", "state"))
+
+
+class StateStore(SQLBaseStore):
+
+ @defer.inlineCallbacks
+ def get_state_groups(self, event_ids):
+ groups = set()
+ for event_id in event_ids:
+ group = yield self._simple_select_one_onecol(
+ table="event_to_state_groups",
+ keyvalues={"event_id": event_id},
+ retcol="state_group",
+ allow_none=True,
+ )
+ if group:
+ groups.add(group)
+
+ res = []
+ for group in groups:
+ state_ids = yield self._simple_select_onecol(
+ table="state_groups_state",
+ keyvalues={"state_group": group},
+ retcol="event_id",
+ )
+ state = []
+ for state_id in state_ids:
+ s = yield self.get_event(
+ state_id,
+ allow_none=True,
+ )
+ if s:
+ state.append(s)
+
+ res.append(StateGroup(group, state))
+
+ defer.returnValue(res)
+
+ def store_state_groups(self, event):
+ return self.runInteraction(
+ "store_state_groups",
+ self._store_state_groups_txn, event
+ )
+
+ def _store_state_groups_txn(self, txn, event):
+ if not event.state_events:
+ return
+
+ state_group = event.state_group
+ if not state_group:
+ state_group = self._simple_insert_txn(
+ txn,
+ table="state_groups",
+ values={
+ "room_id": event.room_id,
+ "event_id": event.event_id,
+ }
+ )
+
+ for state in event.state_events.values():
+ self._simple_insert_txn(
+ txn,
+ table="state_groups_state",
+ values={
+ "state_group": state_group,
+ "room_id": state.room_id,
+ "type": state.type,
+ "state_key": state.state_key,
+ "event_id": state.event_id,
+ }
+ )
+
+ self._simple_insert_txn(
+ txn,
+ table="event_to_state_groups",
+ values={
+ "state_group": state_group,
+ "event_id": event.event_id,
+ }
+ )
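
A usage sketch under Twisted's inlineCallbacks, assuming ``store`` is a homeserver datastore that mixes in ``StateStore``::

    from twisted.internet import defer

    @defer.inlineCallbacks
    def summarise_state_groups(store, event_ids):
        groups = yield store.get_state_groups(event_ids)
        # each StateGroup carries its group id and the resolved state events
        for group in groups:
            print "group %s: %d state events" % (group.group, len(group.state))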
diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py
index d61f909939..8f7f61d29d 100644
--- a/synapse/storage/stream.py
+++ b/synapse/storage/stream.py
@@ -309,7 +309,10 @@ class StreamStore(SQLBaseStore):
defer.returnValue(ret)
def get_room_events_max_id(self):
- return self.runInteraction(self._get_room_events_max_id_txn)
+ return self.runInteraction(
+ "get_room_events_max_id",
+ self._get_room_events_max_id_txn
+ )
def _get_room_events_max_id_txn(self, txn):
txn.execute(
diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py
index 2ba8e30efe..00d0f48082 100644
--- a/synapse/storage/transactions.py
+++ b/synapse/storage/transactions.py
@@ -14,7 +14,6 @@
# limitations under the License.
from ._base import SQLBaseStore, Table
-from .pdu import PdusTable
from collections import namedtuple
@@ -42,6 +41,7 @@ class TransactionStore(SQLBaseStore):
"""
return self.runInteraction(
+ "get_received_txn_response",
self._get_received_txn_response, transaction_id, origin
)
@@ -73,6 +73,7 @@ class TransactionStore(SQLBaseStore):
"""
return self.runInteraction(
+ "set_received_txn_response",
self._set_received_txn_response,
transaction_id, origin, code, response_dict
)
@@ -88,7 +89,7 @@ class TransactionStore(SQLBaseStore):
txn.execute(query, (code, response_json, transaction_id, origin))
def prep_send_transaction(self, transaction_id, destination,
- origin_server_ts, pdu_list):
+ origin_server_ts):
"""Persists an outgoing transaction and calculates the values for the
previous transaction id list.
@@ -99,19 +100,19 @@ class TransactionStore(SQLBaseStore):
transaction_id (str)
destination (str)
origin_server_ts (int)
- pdu_list (list)
Returns:
list: A list of previous transaction ids.
"""
return self.runInteraction(
+ "prep_send_transaction",
self._prep_send_transaction,
- transaction_id, destination, origin_server_ts, pdu_list
+ transaction_id, destination, origin_server_ts
)
def _prep_send_transaction(self, txn, transaction_id, destination,
- origin_server_ts, pdu_list):
+ origin_server_ts):
# First we find out what the prev_txs should be.
# Since we know that we are only sending one transaction at a time,
@@ -139,15 +140,15 @@ class TransactionStore(SQLBaseStore):
# Update the tx id -> pdu id mapping
- values = [
- (transaction_id, destination, pdu[0], pdu[1])
- for pdu in pdu_list
- ]
-
- logger.debug("Inserting: %s", repr(values))
-
- query = TransactionsToPduTable.insert_statement()
- txn.executemany(query, values)
+ # values = [
+ # (transaction_id, destination, pdu[0], pdu[1])
+ # for pdu in pdu_list
+ # ]
+ #
+ # logger.debug("Inserting: %s", repr(values))
+ #
+ # query = TransactionsToPduTable.insert_statement()
+ # txn.executemany(query, values)
return prev_txns
@@ -161,6 +162,7 @@ class TransactionStore(SQLBaseStore):
response_json (str)
"""
return self.runInteraction(
+ "delivered_txn",
self._delivered_txn,
transaction_id, destination, code, response_dict
)
@@ -186,6 +188,7 @@ class TransactionStore(SQLBaseStore):
list: A list of `ReceivedTransactionsTable.EntryType`
"""
return self.runInteraction(
+ "get_transactions_after",
self._get_transactions_after, transaction_id, destination
)
@@ -202,49 +205,6 @@ class TransactionStore(SQLBaseStore):
return ReceivedTransactionsTable.decode_results(txn.fetchall())
- def get_pdus_after_transaction(self, transaction_id, destination):
- """For a given local transaction_id that we sent to a given destination
- home server, return a list of PDUs that were sent to that destination
- after it.
-
- Args:
- txn
- transaction_id (str)
- destination (str)
-
- Returns
- list: A list of PduTuple
- """
- return self.runInteraction(
- self._get_pdus_after_transaction,
- transaction_id, destination
- )
-
- def _get_pdus_after_transaction(self, txn, transaction_id, destination):
-
- # Query that first get's all transaction_ids with an id greater than
- # the one given from the `sent_transactions` table. Then JOIN on this
- # from the `tx->pdu` table to get a list of (pdu_id, origin) that
- # specify the pdus that were sent in those transactions.
- query = (
- "SELECT pdu_id, pdu_origin FROM %(tx_pdu)s as tp "
- "INNER JOIN %(sent_tx)s as st "
- "ON tp.transaction_id = st.transaction_id "
- "AND tp.destination = st.destination "
- "WHERE st.id > ("
- "SELECT id FROM %(sent_tx)s "
- "WHERE transaction_id = ? AND destination = ?"
- ) % {
- "tx_pdu": TransactionsToPduTable.table_name,
- "sent_tx": SentTransactions.table_name,
- }
-
- txn.execute(query, (transaction_id, destination))
-
- pdus = PdusTable.decode_results(txn.fetchall())
-
- return self._get_pdu_tuples(txn, pdus)
-
class ReceivedTransactionsTable(Table):
table_name = "received_transactions"
diff --git a/synapse/streams/config.py b/synapse/streams/config.py
index 6483ce2e25..527507e5cd 100644
--- a/synapse/streams/config.py
+++ b/synapse/streams/config.py
@@ -22,6 +22,19 @@ import logging
logger = logging.getLogger(__name__)
+class SourcePaginationConfig(object):
+
+ """A configuration object which stores pagination parameters for a
+ specific event source."""
+
+ def __init__(self, from_key=None, to_key=None, direction='f',
+ limit=0):
+ self.from_key = from_key
+ self.to_key = to_key
+ self.direction = 'f' if direction == 'f' else 'b'
+ self.limit = int(limit)
+
+
class PaginationConfig(object):
"""A configuration object which stores pagination parameters."""
@@ -82,3 +95,13 @@ class PaginationConfig(object):
"<PaginationConfig from_tok=%s, to_tok=%s, "
"direction=%s, limit=%s>"
) % (self.from_token, self.to_token, self.direction, self.limit)
+
+ def get_source_config(self, source_name):
+ keyname = "%s_key" % source_name
+
+ return SourcePaginationConfig(
+ from_key=getattr(self.from_token, keyname),
+ to_key=getattr(self.to_token, keyname) if self.to_token else None,
+ direction=self.direction,
+ limit=self.limit,
+ )
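
``SourcePaginationConfig`` normalises its inputs, so raw query-string values are safe to pass straight through; note that any direction other than ``'f'`` collapses to ``'b'``::

    from synapse.streams.config import SourcePaginationConfig

    config = SourcePaginationConfig(from_key=1, to_key=0,
                                    direction="f", limit="20")
    assert config.direction == "f"
    assert config.limit == 20  # coerced to int

    assert SourcePaginationConfig(direction="x").direction == "b"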
diff --git a/synapse/streams/events.py b/synapse/streams/events.py
index 41715436b0..fb698d2d71 100644
--- a/synapse/streams/events.py
+++ b/synapse/streams/events.py
@@ -35,7 +35,7 @@ class NullSource(object):
return defer.succeed(0)
def get_pagination_rows(self, user, pagination_config, key):
- return defer.succeed(([], pagination_config.from_token))
+ return defer.succeed(([], pagination_config.from_key))
class EventSources(object):
diff --git a/synapse/types.py b/synapse/types.py
index c51bc8e4f2..649ff2f7d7 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -78,6 +78,11 @@ class DomainSpecificString(
"""Create a structure on the local domain"""
return cls(localpart=localpart, domain=hs.hostname, is_mine=True)
+ @classmethod
+ def create(cls, localpart, domain, hs):
+ is_mine = domain == hs.hostname
+ return cls(localpart=localpart, domain=domain, is_mine=is_mine)
+
class UserID(DomainSpecificString):
"""Structure representing a user ID."""
@@ -94,6 +99,11 @@ class RoomID(DomainSpecificString):
SIGIL = "!"
+class EventID(DomainSpecificString):
+ """Structure representing an event id. """
+ SIGIL = "$"
+
+
class StreamToken(
namedtuple(
"Token",
diff --git a/synapse/util/async.py b/synapse/util/async.py
index 647ea6142c..bf578f8bfb 100644
--- a/synapse/util/async.py
+++ b/synapse/util/async.py
@@ -21,3 +21,10 @@ def sleep(seconds):
d = defer.Deferred()
reactor.callLater(seconds, d.callback, seconds)
return d
+
+
+def run_on_reactor():
+ """ This will cause the rest of the function to be invoked upon the next
+ iteration of the main loop
+ """
+ return sleep(0)
\ No newline at end of file
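
Because ``run_on_reactor`` is just ``sleep(0)``, yielding it inside an ``inlineCallbacks`` function defers the rest of the function to the next reactor iteration::

    from twisted.internet import defer
    from synapse.util.async import run_on_reactor

    @defer.inlineCallbacks
    def do_expensive_work():
        # yielding returns control to the reactor before continuing
        yield run_on_reactor()
        defer.returnValue("done on a later main-loop iteration")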
diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py
index 1de50e049f..eddbe5837f 100644
--- a/synapse/util/distributor.py
+++ b/synapse/util/distributor.py
@@ -42,7 +42,8 @@ class Distributor(object):
if name in self.signals:
raise KeyError("%r already has a signal named %s" % (self, name))
- self.signals[name] = Signal(name,
+ self.signals[name] = Signal(
+ name,
suppress_failures=self.suppress_failures,
)
diff --git a/synapse/util/emailutils.py b/synapse/util/emailutils.py
index cdb0abd7ea..7038cab6c2 100644
--- a/synapse/util/emailutils.py
+++ b/synapse/util/emailutils.py
@@ -42,8 +42,8 @@ def send_email(smtp_server, from_addr, to_addr, subject, body):
EmailException if there was a problem sending the mail.
"""
if not smtp_server or not from_addr or not to_addr:
- raise EmailException("Need SMTP server, from and to addresses. Check " +
- "the config to set these.")
+ raise EmailException("Need SMTP server, from and to addresses. Check"
+ " the config to set these.")
msg = MIMEMultipart('alternative')
msg['Subject'] = subject
@@ -68,4 +68,4 @@ def send_email(smtp_server, from_addr, to_addr, subject, body):
twisted.python.log.err()
ese = EmailException()
ese.cause = origException
- raise ese
\ No newline at end of file
+ raise ese
diff --git a/synapse/util/jsonobject.py b/synapse/util/jsonobject.py
index 6c99705747..c91eb897a8 100644
--- a/synapse/util/jsonobject.py
+++ b/synapse/util/jsonobject.py
@@ -13,9 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
import copy
+
class JsonEncodedObject(object):
""" A common base class for defining protocol units that are represented
as JSON.
@@ -89,6 +89,7 @@ class JsonEncodedObject(object):
def __str__(self):
return "(%s, %s)" % (self.__class__.__name__, repr(self.__dict__))
+
def _encode(obj):
if type(obj) is list:
return [_encode(o) for o in obj]
diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index 1850deacf5..fdc2e8de4a 100644
--- a/tests/handlers/test_presence.py
+++ b/tests/handlers/test_presence.py
@@ -29,6 +29,7 @@ from synapse.server import HomeServer
from synapse.api.constants import PresenceState
from synapse.api.errors import SynapseError
from synapse.handlers.presence import PresenceHandler, UserPresenceCache
+from synapse.streams.config import SourcePaginationConfig
OFFLINE = PresenceState.OFFLINE
@@ -676,6 +677,21 @@ class PresencePushTestCase(unittest.TestCase):
msg="Presence event should be visible to self-reflection"
)
+ config = SourcePaginationConfig(from_key=1, to_key=0)
+ (chunk, _) = yield self.event_source.get_pagination_rows(
+ self.u_apple, config, None
+ )
+ self.assertEquals(chunk,
+ [
+ {"type": "m.presence",
+ "content": {
+ "user_id": "@apple:test",
+ "presence": ONLINE,
+ "last_active_ago": 0,
+ }},
+ ]
+ )
+
# Banana sees it because of presence subscription
(events, _) = yield self.event_source.get_new_events_for_user(
self.u_banana, 0, None
diff --git a/webclient/app-controller.js b/webclient/app-controller.js
index 7d61207554..e4b7cd286f 100644
--- a/webclient/app-controller.js
+++ b/webclient/app-controller.js
@@ -53,7 +53,7 @@ angular.module('MatrixWebClientController', ['matrixService', 'mPresence', 'even
* Open a given page.
* @param {String} url url of the page
*/
- $scope.goToPage = function(url) {
+ $rootScope.goToPage = function(url) {
$location.url(url);
};
diff --git a/webclient/app-directive.js b/webclient/app-directive.js
index 75283598ab..c1ba0af3a9 100644
--- a/webclient/app-directive.js
+++ b/webclient/app-directive.js
@@ -40,4 +40,45 @@ angular.module('matrixWebClient')
}
}
};
-}]);
\ No newline at end of file
+}])
+.directive('asjson', function() {
+ return {
+ restrict: 'A',
+ require: 'ngModel',
+ link: function (scope, element, attrs, ngModelCtrl) {
+ function isValidJson(model) {
+ var flag = true;
+ try {
+ angular.fromJson(model);
+ } catch (err) {
+ flag = false;
+ }
+ return flag;
+ };
+
+ function string2JSON(text) {
+ try {
+ var j = angular.fromJson(text);
+ ngModelCtrl.$setValidity('json', true);
+ return j;
+ } catch (err) {
+ //returning undefined results in a parser error as of angular-1.3-rc.0, and will not go through $validators
+ //return undefined
+ ngModelCtrl.$setValidity('json', false);
+ return text;
+ }
+ };
+
+ function JSON2String(object) {
+ return angular.toJson(object, true);
+ };
+
+ //$validators is an object, where key is the error
+ //ngModelCtrl.$validators.json = isValidJson;
+
+ //array pipelines
+ ngModelCtrl.$parsers.push(string2JSON);
+ ngModelCtrl.$formatters.push(JSON2String);
+ }
+ }
+});
diff --git a/webclient/app-filter.js b/webclient/app-filter.js
index 39ea1d637d..f19db4141d 100644
--- a/webclient/app-filter.js
+++ b/webclient/app-filter.js
@@ -76,6 +76,17 @@ angular.module('matrixWebClient')
return filtered;
};
})
+.filter('stateEventsFilter', function($sce) {
+ return function(events) {
+ var filtered = {};
+ angular.forEach(events, function(value, key) {
+ if (value && typeof(value.state_key) === "string") {
+ filtered[key] = value;
+ }
+ });
+ return filtered;
+ };
+})
.filter('unsafe', ['$sce', function($sce) {
return function(text) {
return $sce.trustAsHtml(text);
diff --git a/webclient/app.css b/webclient/app.css
index bdf475d635..5ab8e2b8fd 100755
--- a/webclient/app.css
+++ b/webclient/app.css
@@ -403,6 +403,7 @@ textarea, input {
}
.roomNameSection, .roomTopicSection {
+ text-align: right;
float: right;
width: 100%;
}
@@ -412,9 +413,40 @@ textarea, input {
}
.roomHeaderInfo {
+ text-align: right;
float: right;
margin-top: 15px;
- width: 50%;
+}
+
+/*** Room Info Dialog ***/
+
+.room-info {
+ border-collapse: collapse;
+ width: 100%;
+}
+
+.room-info-event {
+ border-bottom: 1pt solid black;
+}
+
+.room-info-event-meta {
+ padding-top: 1em;
+ padding-bottom: 1em;
+}
+
+.room-info-event-content {
+ padding-top: 1em;
+ padding-bottom: 1em;
+}
+
+.monospace {
+ font-family: monospace;
+}
+
+.room-info-textarea-content {
+ height: auto;
+ width: 100%;
+ resize: vertical;
}
/*** Participant list ***/
diff --git a/webclient/app.js b/webclient/app.js
index 099e2170a0..c091f8c6cf 100644
--- a/webclient/app.js
+++ b/webclient/app.js
@@ -30,8 +30,10 @@ var matrixWebClient = angular.module('matrixWebClient', [
'MatrixCall',
'eventStreamService',
'eventHandlerService',
+ 'notificationService',
'infinite-scroll',
- 'ui.bootstrap'
+ 'ui.bootstrap',
+ 'monospaced.elastic'
]);
matrixWebClient.config(['$routeProvider', '$provide', '$httpProvider',
diff --git a/webclient/components/matrix/event-handler-service.js b/webclient/components/matrix/event-handler-service.js
index e7109c0cb4..e63584510b 100644
--- a/webclient/components/matrix/event-handler-service.js
+++ b/webclient/components/matrix/event-handler-service.js
@@ -27,8 +27,8 @@ Typically, this service will store events or broadcast them to any listeners
if typically all the $on method would do is update its own $scope.
*/
angular.module('eventHandlerService', [])
-.factory('eventHandlerService', ['matrixService', '$rootScope', '$q', '$timeout', 'mPresence',
-function(matrixService, $rootScope, $q, $timeout, mPresence) {
+.factory('eventHandlerService', ['matrixService', '$rootScope', '$q', '$timeout', 'mPresence', 'notificationService',
+function(matrixService, $rootScope, $q, $timeout, mPresence, notificationService) {
var ROOM_CREATE_EVENT = "ROOM_CREATE_EVENT";
var MSG_EVENT = "MSG_EVENT";
var MEMBER_EVENT = "MEMBER_EVENT";
@@ -45,44 +45,6 @@ function(matrixService, $rootScope, $q, $timeout, mPresence) {
var eventMap = {};
$rootScope.presence = {};
-
- // TODO: This is attached to the rootScope so .html can just go containsBingWord
- // for determining classes so it is easy to highlight bing messages. It seems a
- // bit strange to put the impl in this service though, but I can't think of a better
- // file to put it in.
- $rootScope.containsBingWord = function(content) {
- if (!content || $.type(content) != "string") {
- return false;
- }
- var bingWords = matrixService.config().bingWords;
- var shouldBing = false;
-
- // case-insensitive name check for user_id OR display_name if they exist
- var myUserId = matrixService.config().user_id;
- if (myUserId) {
- myUserId = myUserId.toLocaleLowerCase();
- }
- var myDisplayName = matrixService.config().display_name;
- if (myDisplayName) {
- myDisplayName = myDisplayName.toLocaleLowerCase();
- }
- if ( (myDisplayName && content.toLocaleLowerCase().indexOf(myDisplayName) != -1) ||
- (myUserId && content.toLocaleLowerCase().indexOf(myUserId) != -1) ) {
- shouldBing = true;
- }
-
- // bing word list check
- if (bingWords && !shouldBing) {
- for (var i=0; i<bingWords.length; i++) {
- var re = RegExp(bingWords[i]);
- if (content.search(re) != -1) {
- shouldBing = true;
- break;
- }
- }
- }
- return shouldBing;
- };
var initialSyncDeferred;
@@ -172,6 +134,17 @@ function(matrixService, $rootScope, $q, $timeout, mPresence) {
};
var handleMessage = function(event, isLiveEvent) {
+ // Check for empty event content
+ var hasContent = false;
+ for (var prop in event.content) {
+ hasContent = true;
+ break;
+ }
+ if (!hasContent) {
+            // an empty JSON object means the event was redacted, so ignore it.
+ return;
+ }
+
if (isLiveEvent) {
if (event.user_id === matrixService.config().user_id &&
(event.content.msgtype === "m.text" || event.content.msgtype === "m.emote") ) {
@@ -190,7 +163,12 @@ function(matrixService, $rootScope, $q, $timeout, mPresence) {
}
if (window.Notification && event.user_id != matrixService.config().user_id) {
- var shouldBing = $rootScope.containsBingWord(event.content.body);
+ var shouldBing = notificationService.containsBingWord(
+ matrixService.config().user_id,
+ matrixService.config().display_name,
+ matrixService.config().bingWords,
+ event.content.body
+ );
// Ideally we would notify only when the window is hidden (i.e. document.hidden = true).
//
@@ -220,17 +198,29 @@ function(matrixService, $rootScope, $q, $timeout, mPresence) {
if (event.content.msgtype === "m.emote") {
message = "* " + displayname + " " + message;
}
+ else if (event.content.msgtype === "m.image") {
+ message = displayname + " sent an image.";
+ }
+
+ var roomTitle = matrixService.getRoomIdToAliasMapping(event.room_id);
+ var theRoom = $rootScope.events.rooms[event.room_id];
+ if (!roomTitle && theRoom && theRoom["m.room.name"] && theRoom["m.room.name"].content) {
+ roomTitle = theRoom["m.room.name"].content.name;
+ }
- var notification = new window.Notification(
- displayname +
- " (" + (matrixService.getRoomIdToAliasMapping(event.room_id) || event.room_id) + ")", // FIXME: don't leak room_ids here
- {
- "body": message,
- "icon": member ? member.avatar_url : undefined
- });
- $timeout(function() {
- notification.close();
- }, 5 * 1000);
+ if (!roomTitle) {
+ roomTitle = event.room_id;
+ }
+
+ notificationService.showNotification(
+ displayname + " (" + roomTitle + ")",
+ message,
+ member ? member.avatar_url : undefined,
+ function() {
+ console.log("notification.onclick() room=" + event.room_id);
+ $rootScope.goToPage('room/' + event.room_id);
+ }
+ );
}
}
}
@@ -319,6 +309,31 @@ function(matrixService, $rootScope, $q, $timeout, mPresence) {
$rootScope.events.rooms[event.room_id].messages.push(event);
}
};
+
+ var handleRedaction = function(event, isLiveEvent) {
+ if (!isLiveEvent) {
+ // we have nothing to remove, so just ignore it.
+ console.log("Received redacted event: "+JSON.stringify(event));
+ return;
+ }
+
+ // we need to remove something possibly: do we know the redacted
+ // event ID?
+ if (eventMap[event.redacts]) {
+ // remove event from list of messages in this room.
+ var eventList = $rootScope.events.rooms[event.room_id].messages;
+ for (var i=0; i<eventList.length; i++) {
+ if (eventList[i].event_id === event.redacts) {
+ console.log("Removing event " + event.redacts);
+ eventList.splice(i, 1);
+ break;
+ }
+ }
+
+            // TODO: broadcast the redaction so controllers can nuke this
+            console.log("Redacted an event.");
+        }
+    };
/**
* Get the index of the event in $rootScope.events.rooms[room_id].messages
@@ -481,7 +496,17 @@ function(matrixService, $rootScope, $q, $timeout, mPresence) {
case 'm.room.topic':
handleRoomTopic(event, isLiveEvent, isStateEvent);
break;
+ case 'm.room.redaction':
+ handleRedaction(event, isLiveEvent);
+ break;
default:
+ // if it is a state event, then just add it in so it
+ // displays on the Room Info screen.
+                if (typeof(event.state_key) === "string") { // includes zero-length strings
+ if (event.room_id) {
+ handleRoomDateEvent(event, isLiveEvent, false);
+ }
+ }
console.log("Unable to handle event type " + event.type);
console.log(JSON.stringify(event, undefined, 4));
break;
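
The new ``handleRedaction`` path boils down to splicing the redacted event out of the room's message list; the same logic, paraphrased in Python for clarity::

    def handle_redaction(messages, redaction_event):
        # messages: list of event dicts for one room, oldest first
        target = redaction_event["redacts"]
        for i, msg in enumerate(messages):
            if msg.get("event_id") == target:
                del messages[i]
                break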
diff --git a/webclient/components/matrix/matrix-filter.js b/webclient/components/matrix/matrix-filter.js
index e6f2acc5fd..3d64a569a1 100644
--- a/webclient/components/matrix/matrix-filter.js
+++ b/webclient/components/matrix/matrix-filter.js
@@ -47,7 +47,6 @@ angular.module('matrixFilter', [])
else if (room.members && !isPublicRoom) { // Do not rename public room
var user_id = matrixService.config().user_id;
-
// Else, build the name from its users
// Limit the room renaming to 1:1 room
if (2 === Object.keys(room.members).length) {
@@ -65,8 +64,16 @@ angular.module('matrixFilter', [])
var otherUserId;
- if (Object.keys(room.members)[0] && Object.keys(room.members)[0] !== user_id) {
+ if (Object.keys(room.members)[0]) {
otherUserId = Object.keys(room.members)[0];
+ // this could be an invite event (from event stream)
+ if (otherUserId === user_id &&
+ room.members[user_id].content.membership === "invite") {
+ // this is us being invited to this room, so the
+ // *user_id* is the other user ID and not the state
+ // key.
+ otherUserId = room.members[user_id].user_id;
+ }
}
else {
// it's got to be an invite, or failing that a self-chat;
diff --git a/webclient/components/matrix/matrix-service.js b/webclient/components/matrix/matrix-service.js
index a4f0568bce..1840cf46c0 100644
--- a/webclient/components/matrix/matrix-service.js
+++ b/webclient/components/matrix/matrix-service.js
@@ -438,6 +438,14 @@ angular.module('matrixService', [])
return this.sendMessage(room_id, msg_id, content);
},
+ redactEvent: function(room_id, event_id) {
+ var path = "/rooms/$room_id/redact/$event_id";
+ path = path.replace("$room_id", room_id);
+ path = path.replace("$event_id", event_id);
+ var content = {};
+ return doRequest("POST", path, undefined, content);
+ },
+
// get a snapshot of the members in a room.
getMemberList: function(room_id) {
// Like the cmd client, escape room ids
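
``redactEvent`` above issues a POST with an empty JSON body to ``/rooms/{room_id}/redact/{event_id}``. A hedged sketch of the same request outside the webclient (the base URL and the ``access_token`` query parameter are assumptions about the deployment)::

    import json
    import urllib
    import urllib2

    def redact_event(base_url, access_token, room_id, event_id):
        url = "%s/rooms/%s/redact/%s?access_token=%s" % (
            base_url,
            urllib.quote(room_id),
            urllib.quote(event_id),
            urllib.quote(access_token),
        )
        req = urllib2.Request(
            url, data=json.dumps({}),
            headers={"Content-Type": "application/json"},
        )
        return json.loads(urllib2.urlopen(req).read())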
diff --git a/webclient/components/matrix/notification-service.js b/webclient/components/matrix/notification-service.js
new file mode 100644
index 0000000000..9a911413c3
--- /dev/null
+++ b/webclient/components/matrix/notification-service.js
@@ -0,0 +1,104 @@
+/*
+Copyright 2014 OpenMarket Ltd
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+'use strict';
+
+/*
+This service manages notifications: enabling, creating and showing them. This
+also contains 'bing word' logic.
+*/
+angular.module('notificationService', [])
+.factory('notificationService', ['$timeout', function($timeout) {
+
+ var getLocalPartFromUserId = function(user_id) {
+ if (!user_id) {
+ return null;
+ }
+        var localpartRegex = /@(.*):\w+/i;
+ var results = localpartRegex.exec(user_id);
+ if (results && results.length == 2) {
+ return results[1];
+ }
+ return null;
+ };
+
+ return {
+
+ containsBingWord: function(userId, displayName, bingWords, content) {
+ // case-insensitive name check for user_id OR display_name if they exist
+ var userRegex = "";
+ if (userId) {
+ var localpart = getLocalPartFromUserId(userId);
+ if (localpart) {
+ localpart = localpart.toLocaleLowerCase();
+ userRegex += "\\b" + localpart + "\\b";
+ }
+ }
+ if (displayName) {
+ displayName = displayName.toLocaleLowerCase();
+ if (userRegex.length > 0) {
+ userRegex += "|";
+ }
+ userRegex += "\\b" + displayName + "\\b";
+ }
+
+ var regexList = [new RegExp(userRegex, 'i')];
+
+ // bing word list check
+ if (bingWords && bingWords.length > 0) {
+ for (var i=0; i<bingWords.length; i++) {
+ var re = RegExp(bingWords[i], 'i');
+ regexList.push(re);
+ }
+ }
+ return this.hasMatch(regexList, content);
+ },
+
+ hasMatch: function(regExps, content) {
+ if (!content || $.type(content) != "string") {
+ return false;
+ }
+
+ if (regExps && regExps.length > 0) {
+ for (var i=0; i<regExps.length; i++) {
+ if (content.search(regExps[i]) != -1) {
+ return true;
+ }
+ }
+ }
+ return false;
+ },
+
+ showNotification: function(title, body, icon, onclick) {
+ var notification = new window.Notification(
+ title,
+ {
+ "body": body,
+ "icon": icon
+ }
+ );
+
+ if (onclick) {
+ notification.onclick = onclick;
+ }
+
+ $timeout(function() {
+ notification.close();
+ }, 5 * 1000);
+ }
+ };
+
+}]);
diff --git a/webclient/index.html b/webclient/index.html
index 35c8051298..bc011a6c72 100644
--- a/webclient/index.html
+++ b/webclient/index.html
@@ -20,6 +20,7 @@
<script type='text/javascript' src="js/ui-bootstrap-tpls-0.11.2.js"></script>
<script type='text/javascript' src='js/ng-infinite-scroll-matrix.js'></script>
<script type='text/javascript' src='js/autofill-event.js'></script>
+ <script type='text/javascript' src='js/elastic.js'></script>
<script src="app.js"></script>
<script src="config.js"></script>
<script src="app-controller.js"></script>
@@ -40,6 +41,7 @@
<script src="components/matrix/matrix-phone-service.js"></script>
<script src="components/matrix/event-stream-service.js"></script>
<script src="components/matrix/event-handler-service.js"></script>
+ <script src="components/matrix/notification-service.js"></script>
<script src="components/matrix/presence-service.js"></script>
<script src="components/fileInput/file-input-directive.js"></script>
<script src="components/fileUpload/file-upload-service.js"></script>
diff --git a/webclient/js/elastic.js b/webclient/js/elastic.js
new file mode 100644
index 0000000000..d585d81109
--- /dev/null
+++ b/webclient/js/elastic.js
@@ -0,0 +1,216 @@
+/*
+ * angular-elastic v2.4.0
+ * (c) 2014 Monospaced http://monospaced.com
+ * License: MIT
+ */
+
+angular.module('monospaced.elastic', [])
+
+ .constant('msdElasticConfig', {
+ append: ''
+ })
+
+ .directive('msdElastic', [
+ '$timeout', '$window', 'msdElasticConfig',
+ function($timeout, $window, config) {
+ 'use strict';
+
+ return {
+ require: 'ngModel',
+ restrict: 'A, C',
+ link: function(scope, element, attrs, ngModel) {
+
+ // cache a reference to the DOM element
+ var ta = element[0],
+ $ta = element;
+
+ // ensure the element is a textarea, and browser is capable
+ if (ta.nodeName !== 'TEXTAREA' || !$window.getComputedStyle) {
+ return;
+ }
+
+ // set these properties before measuring dimensions
+ $ta.css({
+ 'overflow': 'hidden',
+ 'overflow-y': 'hidden',
+ 'word-wrap': 'break-word'
+ });
+
+ // force text reflow
+ var text = ta.value;
+ ta.value = '';
+ ta.value = text;
+
+ var append = attrs.msdElastic ? attrs.msdElastic.replace(/\\n/g, '\n') : config.append,
+ $win = angular.element($window),
+ mirrorInitStyle = 'position: absolute; top: -999px; right: auto; bottom: auto;' +
+ 'left: 0; overflow: hidden; -webkit-box-sizing: content-box;' +
+ '-moz-box-sizing: content-box; box-sizing: content-box;' +
+ 'min-height: 0 !important; height: 0 !important; padding: 0;' +
+ 'word-wrap: break-word; border: 0;',
+ $mirror = angular.element('<textarea tabindex="-1" ' +
+ 'style="' + mirrorInitStyle + '"/>').data('elastic', true),
+ mirror = $mirror[0],
+ taStyle = getComputedStyle(ta),
+ resize = taStyle.getPropertyValue('resize'),
+ borderBox = taStyle.getPropertyValue('box-sizing') === 'border-box' ||
+ taStyle.getPropertyValue('-moz-box-sizing') === 'border-box' ||
+ taStyle.getPropertyValue('-webkit-box-sizing') === 'border-box',
+ boxOuter = !borderBox ? {width: 0, height: 0} : {
+ width: parseInt(taStyle.getPropertyValue('border-right-width'), 10) +
+ parseInt(taStyle.getPropertyValue('padding-right'), 10) +
+ parseInt(taStyle.getPropertyValue('padding-left'), 10) +
+ parseInt(taStyle.getPropertyValue('border-left-width'), 10),
+ height: parseInt(taStyle.getPropertyValue('border-top-width'), 10) +
+ parseInt(taStyle.getPropertyValue('padding-top'), 10) +
+ parseInt(taStyle.getPropertyValue('padding-bottom'), 10) +
+ parseInt(taStyle.getPropertyValue('border-bottom-width'), 10)
+ },
+ minHeightValue = parseInt(taStyle.getPropertyValue('min-height'), 10),
+ heightValue = parseInt(taStyle.getPropertyValue('height'), 10),
+ minHeight = Math.max(minHeightValue, heightValue) - boxOuter.height,
+ maxHeight = parseInt(taStyle.getPropertyValue('max-height'), 10),
+ mirrored,
+ active,
+ copyStyle = ['font-family',
+ 'font-size',
+ 'font-weight',
+ 'font-style',
+ 'letter-spacing',
+ 'line-height',
+ 'text-transform',
+ 'word-spacing',
+ 'text-indent'];
+
+ // exit if elastic already applied (or is the mirror element)
+ if ($ta.data('elastic')) {
+ return;
+ }
+
+ // Opera returns max-height of -1 if not set
+ maxHeight = maxHeight && maxHeight > 0 ? maxHeight : 9e4;
+
+ // append mirror to the DOM
+ if (mirror.parentNode !== document.body) {
+ angular.element(document.body).append(mirror);
+ }
+
+ // set resize and apply elastic
+ $ta.css({
+ 'resize': (resize === 'none' || resize === 'vertical') ? 'none' : 'horizontal'
+ }).data('elastic', true);
+
+ /*
+ * methods
+ */
+
+ function initMirror() {
+ var mirrorStyle = mirrorInitStyle;
+
+ mirrored = ta;
+ // copy the essential styles from the textarea to the mirror
+ taStyle = getComputedStyle(ta);
+ angular.forEach(copyStyle, function(val) {
+ mirrorStyle += val + ':' + taStyle.getPropertyValue(val) + ';';
+ });
+ mirror.setAttribute('style', mirrorStyle);
+ }
+
+ function adjust() {
+ var taHeight,
+ taComputedStyleWidth,
+ mirrorHeight,
+ width,
+ overflow;
+
+ if (mirrored !== ta) {
+ initMirror();
+ }
+
+ // active flag prevents actions in function from calling adjust again
+ if (!active) {
+ active = true;
+
+ mirror.value = ta.value + append; // optional whitespace to improve animation
+ mirror.style.overflowY = ta.style.overflowY;
+
+ taHeight = ta.style.height === '' ? 'auto' : parseInt(ta.style.height, 10);
+
+ taComputedStyleWidth = getComputedStyle(ta).getPropertyValue('width');
+
+ // ensure getComputedStyle has returned a readable 'used value' pixel width
+ if (taComputedStyleWidth.substr(taComputedStyleWidth.length - 2, 2) === 'px') {
+ // update mirror width in case the textarea width has changed
+ width = parseInt(taComputedStyleWidth, 10) - boxOuter.width;
+ mirror.style.width = width + 'px';
+ }
+
+ mirrorHeight = mirror.scrollHeight;
+
+ if (mirrorHeight > maxHeight) {
+ mirrorHeight = maxHeight;
+ overflow = 'scroll';
+ } else if (mirrorHeight < minHeight) {
+ mirrorHeight = minHeight;
+ }
+ mirrorHeight += boxOuter.height;
+
+ ta.style.overflowY = overflow || 'hidden';
+
+ if (taHeight !== mirrorHeight) {
+ ta.style.height = mirrorHeight + 'px';
+ scope.$emit('elastic:resize', $ta);
+ }
+
+ // small delay to prevent an infinite loop
+ $timeout(function() {
+ active = false;
+ }, 1);
+
+ }
+ }
+
+ function forceAdjust() {
+ active = false;
+ adjust();
+ }
+
+ /*
+ * initialise
+ */
+
+ // listen
+ if ('onpropertychange' in ta && 'oninput' in ta) {
+ // IE9
+ ta['oninput'] = ta.onkeyup = adjust;
+ } else {
+ ta['oninput'] = adjust;
+ }
+
+ $win.bind('resize', forceAdjust);
+
+ scope.$watch(function() {
+ return ngModel.$modelValue;
+ }, function(newValue) {
+ forceAdjust();
+ });
+
+ scope.$on('elastic:adjust', function() {
+ initMirror();
+ forceAdjust();
+ });
+
+ $timeout(adjust);
+
+ /*
+ * destroy
+ */
+
+ scope.$on('$destroy', function() {
+ $mirror.remove();
+ $win.unbind('resize', forceAdjust);
+ });
+ }
+ };
+ }
+ ]);
diff --git a/webclient/mobile.css b/webclient/mobile.css
index 7c62a072d5..6fa9221ccf 100644
--- a/webclient/mobile.css
+++ b/webclient/mobile.css
@@ -65,13 +65,16 @@
}
#roomName {
- float: left;
- font-size: 14px ! important;
+ font-size: 12px ! important;
margin-top: 0px ! important;
}
+
+ .roomTopicSection {
+ display: none;
+ }
#roomPage {
- top: 35px ! important;
+ top: 40px ! important;
left: 5px ! important;
right: 5px ! important;
bottom: 70px ! important;
diff --git a/webclient/room/room-controller.js b/webclient/room/room-controller.js
index 78520a829d..486ead0da9 100644
--- a/webclient/room/room-controller.js
+++ b/webclient/room/room-controller.js
@@ -15,11 +15,21 @@ limitations under the License.
*/
angular.module('RoomController', ['ngSanitize', 'matrixFilter', 'mFileInput'])
-.controller('RoomController', ['$modal', '$filter', '$scope', '$timeout', '$routeParams', '$location', '$rootScope', 'matrixService', 'mPresence', 'eventHandlerService', 'mFileUpload', 'matrixPhoneService', 'MatrixCall',
- function($modal, $filter, $scope, $timeout, $routeParams, $location, $rootScope, matrixService, mPresence, eventHandlerService, mFileUpload, matrixPhoneService, MatrixCall) {
+.controller('RoomController', ['$modal', '$filter', '$scope', '$timeout', '$routeParams', '$location', '$rootScope', 'matrixService', 'mPresence', 'eventHandlerService', 'mFileUpload', 'matrixPhoneService', 'MatrixCall', 'notificationService',
+ function($modal, $filter, $scope, $timeout, $routeParams, $location, $rootScope, matrixService, mPresence, eventHandlerService, mFileUpload, matrixPhoneService, MatrixCall, notificationService) {
'use strict';
var MESSAGES_PER_PAGINATION = 30;
var THUMBNAIL_SIZE = 320;
+
+    // exposed for room.html, which uses it to highlight bing messages
+ $scope.containsBingWord = function(content) {
+ return notificationService.containsBingWord(
+ matrixService.config().user_id,
+ matrixService.config().display_name,
+ matrixService.config().bingWords,
+ content
+ );
+ };
// Room ids. Computed and resolved in onInit
$scope.room_id = undefined;
@@ -133,7 +143,9 @@ angular.module('RoomController', ['ngSanitize', 'matrixFilter', 'mFileInput'])
// Do not autoscroll to the bottom to display the new event if the user is not at the bottom.
// Exception: in case where the event is from the user, we want to force scroll to the bottom
var objDiv = document.getElementById("messageTableWrapper");
- if ((objDiv.offsetHeight + objDiv.scrollTop >= objDiv.scrollHeight) || force) {
+ // add a 10px buffer to this check so if the message list is not *quite*
+ // at the bottom it still scrolls since it basically is at the bottom.
+ if ((10 + objDiv.offsetHeight + objDiv.scrollTop >= objDiv.scrollHeight) || force) {
$timeout(function() {
objDiv.scrollTop = objDiv.scrollHeight;
@@ -189,16 +201,20 @@ angular.module('RoomController', ['ngSanitize', 'matrixFilter', 'mFileInput'])
// Notify when a user joins
if ((document.hidden || matrixService.presence.unavailable === mPresence.getState())
&& event.state_key !== $scope.state.user_id && "join" === event.membership) {
- var notification = new window.Notification(
- event.content.displayname +
- " (" + (matrixService.getRoomIdToAliasMapping(event.room_id) || event.room_id) + ")", // FIXME: don't leak room_ids here
- {
- "body": event.content.displayname + " joined",
- "icon": event.content.avatar_url ? event.content.avatar_url : undefined
- });
- $timeout(function() {
- notification.close();
- }, 5 * 1000);
+ var userName = event.content.displayname;
+ if (!userName) {
+ userName = event.state_key;
+ }
+ notificationService.showNotification(
+ userName +
+ " (" + (matrixService.getRoomIdToAliasMapping(event.room_id) || event.room_id) + ")",
+ userName + " joined",
+ event.content.avatar_url ? event.content.avatar_url : undefined,
+ function() {
+ console.log("notification.onclick() room=" + event.room_id);
+ $rootScope.goToPage('room/' + event.room_id);
+ }
+ );
}
}
}
@@ -983,10 +999,87 @@ angular.module('RoomController', ['ngSanitize', 'matrixFilter', 'mFileInput'])
};
$scope.openJson = function(content) {
- console.log("Displaying modal dialog for " + JSON.stringify(content));
+ $scope.event_selected = content;
+ // scope this so the template can check power levels and enable/disable
+ // buttons
+ $scope.pow = matrixService.getUserPowerLevel;
+
+ var modalInstance = $modal.open({
+ templateUrl: 'eventInfoTemplate.html',
+ controller: 'EventInfoController',
+ scope: $scope
+ });
+
+ modalInstance.result.then(function(action) {
+ if (action === "redact") {
+ var eventId = $scope.event_selected.event_id;
+ console.log("Redacting event ID " + eventId);
+ matrixService.redactEvent(
+ $scope.event_selected.room_id,
+ eventId
+ ).then(function(response) {
+ console.log("Redaction = " + JSON.stringify(response));
+ }, function(error) {
+ console.error("Failed to redact event: "+JSON.stringify(error));
+ if (error.data.error) {
+ $scope.feedback = error.data.error;
+ }
+ });
+ }
+ }, function() {
+ // any dismiss code
+ });
+ };
+
+ $scope.openRoomInfo = function() {
+ $scope.roomInfo = {};
+ $scope.roomInfo.newEvent = {
+ content: {},
+ type: "",
+ state_key: ""
+ };
+
+ var stateFilter = $filter("stateEventsFilter");
+ var stateEvents = stateFilter($scope.events.rooms[$scope.room_id]);
+ // The modal dialog will 2-way bind this field, so we MUST make a deep
+        // copy of the state events else we will be *actually adjusting our view
+ // of the world* when fiddling with the JSON!! Apparently parse/stringify
+ // is faster than jQuery's extend when doing deep copies.
+ $scope.roomInfo.stateEvents = JSON.parse(JSON.stringify(stateEvents));
var modalInstance = $modal.open({
- template: "<pre>" + angular.toJson(content, true) + "</pre>"
+ templateUrl: 'roomInfoTemplate.html',
+ controller: 'RoomInfoController',
+ size: 'lg',
+ scope: $scope
});
};
-}]);
+}])
+.controller('EventInfoController', function($scope, $modalInstance) {
+ console.log("Displaying modal dialog for >>>> " + JSON.stringify($scope.event_selected));
+ $scope.redact = function() {
+ console.log("User level = "+$scope.pow($scope.room_id, $scope.state.user_id)+
+ " Redact level = "+$scope.events.rooms[$scope.room_id]["m.room.ops_levels"].content.redact_level);
+ console.log("Redact event >> " + JSON.stringify($scope.event_selected));
+ $modalInstance.close("redact");
+ };
+})
+.controller('RoomInfoController', function($scope, $modalInstance, $filter, matrixService) {
+ console.log("Displaying room info.");
+
+ $scope.submit = function(event) {
+ if (event.content) {
+ console.log("submit >>> " + JSON.stringify(event.content));
+ matrixService.sendStateEvent($scope.room_id, event.type,
+ event.content, event.state_key).then(function(response) {
+ $modalInstance.dismiss();
+ }, function(err) {
+ $scope.feedback = err.data.error;
+ }
+ );
+ }
+ };
+
+ $scope.dismiss = $modalInstance.dismiss;
+
+});
diff --git a/webclient/room/room.html b/webclient/room/room.html
index e753b037fe..5265f42dd8 100644
--- a/webclient/room/room.html
+++ b/webclient/room/room.html
@@ -1,5 +1,59 @@
<div ng-controller="RoomController" data-ng-init="onInit()" class="room" style="height: 100%;">
+ <script type="text/ng-template" id="eventInfoTemplate.html">
+ <div class="modal-body">
+ <pre> {{event_selected | json}} </pre>
+ </div>
+ <div class="modal-footer">
+ <button ng-click="redact()" type="button" class="btn btn-danger"
+ ng-disabled="!events.rooms[room_id]['m.room.ops_levels'].content.redact_level || !pow(room_id, state.user_id) || pow(room_id, state.user_id) < events.rooms[room_id]['m.room.ops_levels'].content.redact_level"
+ title="Delete this event on all home servers. This cannot be undone.">
+ Redact
+ </button>
+ </div>
+ </script>
+
+ <script type="text/ng-template" id="roomInfoTemplate.html">
+ <div class="modal-body">
+ <table class="room-info">
+ <tr ng-repeat="(key, event) in roomInfo.stateEvents" class="room-info-event">
+ <td class="room-info-event-meta" width="30%">
+ <span class="monospace">{{ key }}</span>
+ <br/>
+ {{ (event.origin_server_ts) | date:'MMM d HH:mm' }}
+ <br/>
+ Set by: <span class="monospace">{{ event.user_id }}</span>
+ <br/>
+ <span ng-show="event.required_power_level >= 0">Required power level: {{event.required_power_level}}<br/></span>
+ <button ng-click="submit(event)" type="button" class="btn btn-success" ng-disabled="!event.content">
+ Submit
+ </button>
+ </td>
+ <td class="room-info-event-content" width="70%">
+ <textarea class="room-info-textarea-content" msd-elastic ng-model="event.content" asjson></textarea>
+ </td>
+ </tr>
+ <tr>
+ <td class="room-info-event-meta" width="30%">
+ <input ng-model="roomInfo.newEvent.type" placeholder="your.event.type" />
+ <br/>
+ <button ng-click="submit(roomInfo.newEvent)" type="button" class="btn btn-success" ng-disabled="!roomInfo.newEvent.content || !roomInfo.newEvent.type">
+ Submit
+ </button>
+ </td>
+ <td class="room-info-event-content" width="70%">
+ <textarea class="room-info-textarea-content" msd-elastic ng-model="roomInfo.newEvent.content" asjson></textarea>
+ </td>
+ </tr>
+ </table>
+ </div>
+ <div class="modal-footer">
+ <button ng-click="dismiss()" type="button" class="btn">
+ Close
+ </button>
+ </div>
+ </script>
+
<div id="roomHeader">
<a href ng-click="goToPage('/')"><img src="img/logo-small.png" width="100" height="43" alt="[matrix]"/></a>
<div class="roomHeaderInfo">
@@ -79,15 +133,15 @@
</div>
</td>
<td class="avatar">
- <img class="avatarImage" ng-src="{{ members[msg.user_id].avatar_url || 'img/default-profile.png' }}" width="32" height="32"
+ <img class="avatarImage" ng-src="{{ members[msg.user_id].avatar_url || 'img/default-profile.png' }}" width="32" height="32" title="{{msg.user_id}}"
ng-hide="events.rooms[room_id].messages[$index - 1].user_id === msg.user_id || msg.user_id === state.user_id"/>
</td>
<td ng-class="(!msg.content.membership && ('m.room.topic' !== msg.type && 'm.room.name' !== msg.type))? (msg.content.msgtype === 'm.emote' ? 'emote text' : 'text') : 'membership text'">
- <div class="bubble">
- <span ng-if="'join' === msg.content.membership && msg.changedKey === 'membership'" ng-click="openJson(msg)">
+ <div class="bubble" ng-click="openJson(msg)">
+ <span ng-if="'join' === msg.content.membership && msg.changedKey === 'membership'">
{{ members[msg.state_key].displayname || msg.state_key }} joined
</span>
- <span ng-if="'leave' === msg.content.membership && msg.changedKey === 'membership'" ng-click="openJson(msg)">
+ <span ng-if="'leave' === msg.content.membership && msg.changedKey === 'membership'">
<span ng-if="msg.user_id === msg.state_key">
{{ members[msg.state_key].displayname || msg.state_key }} left
</span>
@@ -101,7 +155,7 @@
</span>
</span>
<span ng-if="'invite' === msg.content.membership && msg.changedKey === 'membership' ||
- 'ban' === msg.content.membership && msg.changedKey === 'membership'" ng-click="openJson(msg)">
+ 'ban' === msg.content.membership && msg.changedKey === 'membership'">
{{ members[msg.user_id].displayname || msg.user_id }}
{{ {"invite": "invited", "ban": "banned"}[msg.content.membership] }}
{{ members[msg.state_key].displayname || msg.state_key }}
@@ -109,25 +163,24 @@
: {{ msg.content.reason }}
</span>
</span>
- <span ng-if="msg.changedKey === 'displayname'" ng-click="openJson(msg)">
+ <span ng-if="msg.changedKey === 'displayname'">
{{ msg.user_id }} changed their display name from {{ msg.prev_content.displayname }} to {{ msg.content.displayname }}
</span>
<span ng-show='msg.content.msgtype === "m.emote"'
ng-class="msg.echo_msg_state"
ng-bind-html="'* ' + (members[msg.user_id].displayname || msg.user_id) + ' ' + msg.content.body | linky:'_blank'"
- ng-click="openJson(msg)"/>
+ />
<span ng-show='msg.content.msgtype === "m.text"'
class="message"
- ng-click="openJson(msg)"
ng-class="containsBingWord(msg.content.body) && msg.user_id != state.user_id ? msg.echo_msg_state + ' messageBing' : msg.echo_msg_state"
ng-bind-html="(msg.content.msgtype === 'm.text' && msg.type === 'm.room.message' && msg.content.format === 'org.matrix.custom.html') ?
(msg.content.formatted_body | unsanitizedLinky) :
(msg.content.msgtype === 'm.text' && msg.type === 'm.room.message') ? (msg.content.body | linky:'_blank') : '' "/>
- <span ng-show='msg.type === "m.call.invite" && msg.user_id == state.user_id' ng-click="openJson(msg)">Outgoing Call{{ isWebRTCSupported ? '' : ' (But your browser does not support VoIP)' }}</span>
- <span ng-show='msg.type === "m.call.invite" && msg.user_id != state.user_id' ng-click="openJson(msg)">Incoming Call{{ isWebRTCSupported ? '' : ' (But your browser does not support VoIP)' }}</span>
+ <span ng-show='msg.type === "m.call.invite" && msg.user_id == state.user_id'>Outgoing Call{{ isWebRTCSupported ? '' : ' (But your browser does not support VoIP)' }}</span>
+ <span ng-show='msg.type === "m.call.invite" && msg.user_id != state.user_id'>Incoming Call{{ isWebRTCSupported ? '' : ' (But your browser does not support VoIP)' }}</span>
<div ng-show='msg.content.msgtype === "m.image"'>
<div ng-hide='msg.content.thumbnail_url' ng-style="msg.content.body.h && { 'height' : (msg.content.body.h < 320) ? msg.content.body.h : 320}">
@@ -135,15 +188,15 @@
</div>
<div ng-show='msg.content.thumbnail_url' ng-style="{ 'height' : msg.content.thumbnail_info.h }">
<img class="image mouse-pointer" ng-src="{{ msg.content.thumbnail_url }}"
- ng-click="$parent.fullScreenImageURL = msg.content.url"/>
+ ng-click="$parent.fullScreenImageURL = msg.content.url; $event.stopPropagation();"/>
</div>
</div>
- <span ng-if="'m.room.topic' === msg.type" ng-click="openJson(msg)">
+ <span ng-if="'m.room.topic' === msg.type">
{{ members[msg.user_id].displayname || msg.user_id }} changed the topic to: {{ msg.content.topic }}
</span>
- <span ng-if="'m.room.name' === msg.type" ng-click="openJson(msg)">
+ <span ng-if="'m.room.name' === msg.type">
{{ members[msg.user_id].displayname || msg.user_id }} changed the room name to: {{ msg.content.name }}
</span>
@@ -204,6 +257,9 @@
>
Video Call
</button>
+ <button ng-click="openRoomInfo()">
+ Room Info
+ </button>
</div>
{{ feedback }}
|