diff --git a/synapse/handlers/__init__.py b/synapse/handlers/__init__.py
index 96a9b143ca..b31518bf62 100644
--- a/synapse/handlers/__init__.py
+++ b/synapse/handlers/__init__.py
@@ -27,6 +27,7 @@ from .directory import DirectoryHandler
from .typing import TypingNotificationHandler
from .admin import AdminHandler
from .appservice import ApplicationServicesHandler
+from .sync import SyncHandler
class Handlers(object):
@@ -53,3 +54,4 @@ class Handlers(object):
self.typing_notification_handler = TypingNotificationHandler(hs)
self.admin_handler = AdminHandler(hs)
self.appservice_handler = ApplicationServicesHandler(hs)
+ self.sync_handler = SyncHandler(hs)
diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index d997917cd6..025e7e7e62 100644
--- a/synapse/handlers/events.py
+++ b/synapse/handlers/events.py
@@ -49,24 +49,25 @@ class EventStreamHandler(BaseHandler):
@defer.inlineCallbacks
@log_function
def get_stream(self, auth_user_id, pagin_config, timeout=0,
- as_client_event=True):
+ as_client_event=True, affect_presence=True):
auth_user = UserID.from_string(auth_user_id)
try:
- if auth_user not in self._streams_per_user:
- self._streams_per_user[auth_user] = 0
- if auth_user in self._stop_timer_per_user:
- try:
- self.clock.cancel_call_later(
- self._stop_timer_per_user.pop(auth_user)
+ if affect_presence:
+ if auth_user not in self._streams_per_user:
+ self._streams_per_user[auth_user] = 0
+ if auth_user in self._stop_timer_per_user:
+ try:
+ self.clock.cancel_call_later(
+ self._stop_timer_per_user.pop(auth_user)
+ )
+ except:
+ logger.exception("Failed to cancel event timer")
+ else:
+ yield self.distributor.fire(
+ "started_user_eventstream", auth_user
)
- except:
- logger.exception("Failed to cancel event timer")
- else:
- yield self.distributor.fire(
- "started_user_eventstream", auth_user
- )
- self._streams_per_user[auth_user] += 1
+ self._streams_per_user[auth_user] += 1
if pagin_config.from_token is None:
pagin_config.from_token = None
@@ -94,27 +95,28 @@ class EventStreamHandler(BaseHandler):
defer.returnValue(chunk)
finally:
- self._streams_per_user[auth_user] -= 1
- if not self._streams_per_user[auth_user]:
- del self._streams_per_user[auth_user]
-
- # 10 seconds of grace to allow the client to reconnect again
- # before we think they're gone
- def _later():
- logger.debug(
- "_later stopped_user_eventstream %s", auth_user
- )
+ if affect_presence:
+ self._streams_per_user[auth_user] -= 1
+ if not self._streams_per_user[auth_user]:
+ del self._streams_per_user[auth_user]
+
+ # 10 seconds of grace to allow the client to reconnect again
+ # before we think they're gone
+ def _later():
+ logger.debug(
+ "_later stopped_user_eventstream %s", auth_user
+ )
- self._stop_timer_per_user.pop(auth_user, None)
+ self._stop_timer_per_user.pop(auth_user, None)
- yield self.distributor.fire(
- "stopped_user_eventstream", auth_user
- )
+ return self.distributor.fire(
+ "stopped_user_eventstream", auth_user
+ )
- logger.debug("Scheduling _later: for %s", auth_user)
- self._stop_timer_per_user[auth_user] = (
- self.clock.call_later(30, _later)
- )
+ logger.debug("Scheduling _later: for %s", auth_user)
+ self._stop_timer_per_user[auth_user] = (
+ self.clock.call_later(30, _later)
+ )
class EventHandler(BaseHandler):
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index bcdcc90a18..8bf5a4cc11 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -17,19 +17,16 @@
from ._base import BaseHandler
-from synapse.events.utils import prune_event
from synapse.api.errors import (
- AuthError, FederationError, SynapseError, StoreError,
+ AuthError, FederationError, StoreError,
)
-from synapse.api.constants import EventTypes, Membership
+from synapse.api.constants import EventTypes, Membership, RejectedReason
from synapse.util.logutils import log_function
from synapse.util.async import run_on_reactor
from synapse.crypto.event_signing import (
- compute_event_signature, check_event_content_hash,
- add_hashes_and_signatures,
+ compute_event_signature, add_hashes_and_signatures,
)
from synapse.types import UserID
-from syutil.jsonutil import encode_canonical_json
from twisted.internet import defer
@@ -113,33 +110,6 @@ class FederationHandler(BaseHandler):
logger.debug("Processing event: %s", event.event_id)
- redacted_event = prune_event(event)
-
- redacted_pdu_json = redacted_event.get_pdu_json()
- try:
- yield self.keyring.verify_json_for_server(
- event.origin, redacted_pdu_json
- )
- except SynapseError as e:
- logger.warn(
- "Signature check failed for %s redacted to %s",
- encode_canonical_json(pdu.get_pdu_json()),
- encode_canonical_json(redacted_pdu_json),
- )
- raise FederationError(
- "ERROR",
- e.code,
- e.msg,
- affected=event.event_id,
- )
-
- if not check_event_content_hash(event):
- logger.warn(
- "Event content has been tampered, redacting %s, %s",
- event.event_id, encode_canonical_json(event.get_dict())
- )
- event = redacted_event
-
logger.debug("Event: %s", event)
# FIXME (erikj): Awful hack to make the case where we are not currently
@@ -149,41 +119,20 @@ class FederationHandler(BaseHandler):
event.room_id,
self.server_name
)
- if not is_in_room and not event.internal_metadata.outlier:
+ if not is_in_room and not event.internal_metadata.is_outlier():
logger.debug("Got event for room we're not in.")
-
- replication = self.replication_layer
-
- if not state:
- state, auth_chain = yield replication.get_state_for_room(
- origin, context=event.room_id, event_id=event.event_id,
- )
-
- if not auth_chain:
- auth_chain = yield replication.get_event_auth(
- origin,
- context=event.room_id,
- event_id=event.event_id,
- )
-
- for e in auth_chain:
- e.internal_metadata.outlier = True
- try:
- yield self._handle_new_event(e, fetch_auth_from=origin)
- except:
- logger.exception(
- "Failed to handle auth event %s",
- e.event_id,
- )
-
current_state = state
- if state:
+ if state and auth_chain is not None:
for e in state:
- logging.info("A :) %r", e)
e.internal_metadata.outlier = True
try:
- yield self._handle_new_event(e)
+ auth_ids = [e_id for e_id, _ in e.auth_events]
+ auth = {
+ (e.type, e.state_key): e for e in auth_chain
+ if e.event_id in auth_ids
+ }
+ yield self._handle_new_event(origin, e, auth_events=auth)
except:
logger.exception(
"Failed to handle state event %s",
@@ -192,6 +141,7 @@ class FederationHandler(BaseHandler):
try:
yield self._handle_new_event(
+ origin,
event,
state=state,
backfilled=backfilled,
@@ -393,8 +343,19 @@ class FederationHandler(BaseHandler):
for e in auth_chain:
e.internal_metadata.outlier = True
+
+ if e.event_id == event.event_id:
+ continue
+
try:
- yield self._handle_new_event(e)
+ auth_ids = [e_id for e_id, _ in e.auth_events]
+ auth = {
+ (e.type, e.state_key): e for e in auth_chain
+ if e.event_id in auth_ids
+ }
+ yield self._handle_new_event(
+ target_host, e, auth_events=auth
+ )
except:
logger.exception(
"Failed to handle auth event %s",
@@ -402,11 +363,18 @@ class FederationHandler(BaseHandler):
)
for e in state:
- # FIXME: Auth these.
+ if e.event_id == event.event_id:
+ continue
+
e.internal_metadata.outlier = True
try:
+ auth_ids = [e_id for e_id, _ in e.auth_events]
+ auth = {
+ (e.type, e.state_key): e for e in auth_chain
+ if e.event_id in auth_ids
+ }
yield self._handle_new_event(
- e, fetch_auth_from=target_host
+ target_host, e, auth_events=auth
)
except:
logger.exception(
@@ -414,10 +382,18 @@ class FederationHandler(BaseHandler):
e.event_id,
)
+ auth_ids = [e_id for e_id, _ in event.auth_events]
+ auth_events = {
+ (e.type, e.state_key): e for e in auth_chain
+ if e.event_id in auth_ids
+ }
+
yield self._handle_new_event(
+ target_host,
new_event,
state=state,
current_state=state,
+ auth_events=auth_events,
)
yield self.notifier.on_new_room_event(
@@ -481,7 +457,7 @@ class FederationHandler(BaseHandler):
event.internal_metadata.outlier = False
- context = yield self._handle_new_event(event)
+ context = yield self._handle_new_event(origin, event)
logger.debug(
"on_send_join_request: After _handle_new_event: %s, sigs: %s",
@@ -682,11 +658,12 @@ class FederationHandler(BaseHandler):
waiters.pop().callback(None)
@defer.inlineCallbacks
- def _handle_new_event(self, event, state=None, backfilled=False,
- current_state=None, fetch_auth_from=None):
+ @log_function
+ def _handle_new_event(self, origin, event, state=None, backfilled=False,
+ current_state=None, auth_events=None):
logger.debug(
- "_handle_new_event: Before annotate: %s, sigs: %s",
+ "_handle_new_event: %s, sigs: %s",
event.event_id, event.signatures,
)
@@ -694,65 +671,44 @@ class FederationHandler(BaseHandler):
event, old_state=state
)
+ if not auth_events:
+ auth_events = context.auth_events
+
logger.debug(
- "_handle_new_event: Before auth fetch: %s, sigs: %s",
- event.event_id, event.signatures,
+ "_handle_new_event: %s, auth_events: %s",
+ event.event_id, auth_events,
)
is_new_state = not event.internal_metadata.is_outlier()
- known_ids = set(
- [s.event_id for s in context.auth_events.values()]
- )
-
- for e_id, _ in event.auth_events:
- if e_id not in known_ids:
- e = yield self.store.get_event(e_id, allow_none=True)
-
- if not e and fetch_auth_from is not None:
- # Grab the auth_chain over federation if we are missing
- # auth events.
- auth_chain = yield self.replication_layer.get_event_auth(
- fetch_auth_from, event.event_id, event.room_id
- )
- for auth_event in auth_chain:
- yield self._handle_new_event(auth_event)
- e = yield self.store.get_event(e_id, allow_none=True)
-
- if not e:
- # TODO: Do some conflict res to make sure that we're
- # not the ones who are wrong.
- logger.info(
- "Rejecting %s as %s not in db or %s",
- event.event_id, e_id, known_ids,
- )
- # FIXME: How does raising AuthError work with federation?
- raise AuthError(403, "Cannot find auth event")
-
- context.auth_events[(e.type, e.state_key)] = e
-
- logger.debug(
- "_handle_new_event: Before hack: %s, sigs: %s",
- event.event_id, event.signatures,
- )
-
+ # This is a hack to fix some old rooms where the initial join event
+ # didn't reference the create event in its auth events.
if event.type == EventTypes.Member and not event.auth_events:
if len(event.prev_events) == 1:
c = yield self.store.get_event(event.prev_events[0][0])
if c.type == EventTypes.Create:
- context.auth_events[(c.type, c.state_key)] = c
+ auth_events[(c.type, c.state_key)] = c
- logger.debug(
- "_handle_new_event: Before auth check: %s, sigs: %s",
- event.event_id, event.signatures,
- )
+ try:
+ yield self.do_auth(
+ origin, event, context, auth_events=auth_events
+ )
+ except AuthError as e:
+ logger.warn(
+ "Rejecting %s because %s",
+ event.event_id, e.msg
+ )
- self.auth.check(event, auth_events=context.auth_events)
+ context.rejected = RejectedReason.AUTH_ERROR
- logger.debug(
- "_handle_new_event: Before persist_event: %s, sigs: %s",
- event.event_id, event.signatures,
- )
+ yield self.store.persist_event(
+ event,
+ context=context,
+ backfilled=backfilled,
+ is_new_state=False,
+ current_state=current_state,
+ )
+ raise
yield self.store.persist_event(
event,
@@ -762,9 +718,294 @@ class FederationHandler(BaseHandler):
current_state=current_state,
)
- logger.debug(
- "_handle_new_event: After persist_event: %s, sigs: %s",
- event.event_id, event.signatures,
+ defer.returnValue(context)
+
+ @defer.inlineCallbacks
+ def on_query_auth(self, origin, event_id, remote_auth_chain, rejects,
+ missing):
+ # Just go through and process each event in `remote_auth_chain`. We
+ # don't want to fall into the trap of `missing` being wrong.
+ for e in remote_auth_chain:
+ try:
+ yield self._handle_new_event(origin, e)
+ except AuthError:
+ pass
+
+ # Now get the current auth_chain for the event.
+ local_auth_chain = yield self.store.get_auth_chain([event_id])
+
+ # TODO: Check if we would now reject event_id. If so we need to tell
+ # everyone.
+
+ ret = yield self.construct_auth_difference(
+ local_auth_chain, remote_auth_chain
)
- defer.returnValue(context)
+ for event in ret["auth_chain"]:
+ event.signatures.update(
+ compute_event_signature(
+ event,
+ self.hs.hostname,
+ self.hs.config.signing_key[0]
+ )
+ )
+
+ logger.debug("on_query_auth reutrning: %s", ret)
+
+ defer.returnValue(ret)
+
+ @defer.inlineCallbacks
+ @log_function
+ def do_auth(self, origin, event, context, auth_events):
+ # Check if we have all the auth events.
+ res = yield self.store.have_events(
+ [e_id for e_id, _ in event.auth_events]
+ )
+
+ event_auth_events = set(e_id for e_id, _ in event.auth_events)
+ seen_events = set(res.keys())
+
+ missing_auth = event_auth_events - seen_events
+
+ if missing_auth:
+ logger.debug("Missing auth: %s", missing_auth)
+ # If we don't have all the auth events, we need to get them.
+ remote_auth_chain = yield self.replication_layer.get_event_auth(
+ origin, event.room_id, event.event_id
+ )
+
+ seen_remotes = yield self.store.have_events(
+ [e.event_id for e in remote_auth_chain]
+ )
+
+ for e in remote_auth_chain:
+ if e.event_id in seen_remotes.keys():
+ continue
+
+ if e.event_id == event.event_id:
+ continue
+
+ try:
+ auth_ids = [e_id for e_id, _ in e.auth_events]
+ auth = {
+ (e.type, e.state_key): e for e in remote_auth_chain
+ if e.event_id in auth_ids
+ }
+ e.internal_metadata.outlier = True
+
+ logger.debug(
+ "do_auth %s missing_auth: %s",
+ event.event_id, e.event_id
+ )
+ yield self._handle_new_event(
+ origin, e, auth_events=auth
+ )
+
+ if e.event_id in event_auth_events:
+ auth_events[(e.type, e.state_key)] = e
+ except AuthError:
+ pass
+
+ # FIXME: Assumes we have and stored all the state for all the
+ # prev_events
+ current_state = set(e.event_id for e in auth_events.values())
+ different_auth = event_auth_events - current_state
+
+ if different_auth and not event.internal_metadata.is_outlier():
+ # Do auth conflict res.
+ logger.debug("Different auth: %s", different_auth)
+
+ # 1. Get what we think is the auth chain.
+ auth_ids = self.auth.compute_auth_events(event, context)
+ local_auth_chain = yield self.store.get_auth_chain(auth_ids)
+
+ # 2. Get remote difference.
+ result = yield self.replication_layer.query_auth(
+ origin,
+ event.room_id,
+ event.event_id,
+ local_auth_chain,
+ )
+
+ seen_remotes = yield self.store.have_events(
+ [e.event_id for e in result["auth_chain"]]
+ )
+
+ # 3. Process any remote auth chain events we haven't seen.
+ for ev in result["auth_chain"]:
+ if ev.event_id in seen_remotes.keys():
+ continue
+
+ if ev.event_id == event.event_id:
+ continue
+
+ try:
+ auth_ids = [e_id for e_id, _ in ev.auth_events]
+ auth = {
+ (e.type, e.state_key): e for e in result["auth_chain"]
+ if e.event_id in auth_ids
+ }
+ ev.internal_metadata.outlier = True
+
+ logger.debug(
+ "do_auth %s different_auth: %s",
+ event.event_id, e.event_id
+ )
+
+ yield self._handle_new_event(
+ origin, ev, auth_events=auth
+ )
+
+ if ev.event_id in event_auth_events:
+ auth_events[(ev.type, ev.state_key)] = ev
+ except AuthError:
+ pass
+
+ # 4. Look at rejects and their proofs.
+ # TODO.
+
+ context.current_state.update(auth_events)
+ context.state_group = None
+
+ try:
+ self.auth.check(event, auth_events=auth_events)
+ except AuthError:
+ raise
+
+ @defer.inlineCallbacks
+ def construct_auth_difference(self, local_auth, remote_auth):
+ """ Given a local and remote auth chain, find the differences. This
+ assumes that we have already processed all events in remote_auth
+
+ Params:
+ local_auth (list)
+ remote_auth (list)
+
+ Returns:
+ dict
+ """
+
+ logger.debug("construct_auth_difference Start!")
+
+ # TODO: Make sure we are OK with local_auth or remote_auth having more
+ # auth events in them than strictly necessary.
+
+ def sort_fun(ev):
+ return ev.depth, ev.event_id
+
+ logger.debug("construct_auth_difference after sort_fun!")
+
+ # We find the differences by starting at the "bottom" of each list
+ # and iterating up on both lists. The lists are ordered by depth and
+ # then event_id, we iterate up both lists until we find the event ids
+ # don't match. Then we look at depth/event_id to see which side is
+ # missing that event, and iterate only up that list. Repeat.
+
+ remote_list = list(remote_auth)
+ remote_list.sort(key=sort_fun)
+
+ local_list = list(local_auth)
+ local_list.sort(key=sort_fun)
+
+ local_iter = iter(local_list)
+ remote_iter = iter(remote_list)
+
+ logger.debug("construct_auth_difference before get_next!")
+
+ def get_next(it, opt=None):
+ try:
+ return it.next()
+ except:
+ return opt
+
+ current_local = get_next(local_iter)
+ current_remote = get_next(remote_iter)
+
+ logger.debug("construct_auth_difference before while")
+
+ missing_remotes = []
+ missing_locals = []
+ while current_local or current_remote:
+ if current_remote is None:
+ missing_locals.append(current_local)
+ current_local = get_next(local_iter)
+ continue
+
+ if current_local is None:
+ missing_remotes.append(current_remote)
+ current_remote = get_next(remote_iter)
+ continue
+
+ if current_local.event_id == current_remote.event_id:
+ current_local = get_next(local_iter)
+ current_remote = get_next(remote_iter)
+ continue
+
+ if current_local.depth < current_remote.depth:
+ missing_locals.append(current_local)
+ current_local = get_next(local_iter)
+ continue
+
+ if current_local.depth > current_remote.depth:
+ missing_remotes.append(current_remote)
+ current_remote = get_next(remote_iter)
+ continue
+
+ # They have the same depth, so we fall back to the event_id order
+ if current_local.event_id < current_remote.event_id:
+ missing_locals.append(current_local)
+ current_local = get_next(local_iter)
+
+ if current_local.event_id > current_remote.event_id:
+ missing_remotes.append(current_remote)
+ current_remote = get_next(remote_iter)
+ continue
+
+ logger.debug("construct_auth_difference after while")
+
+ # missing locals should be sent to the server
+ # We should find why we are missing remotes, as they will have been
+ # rejected.
+
+ # Remove events from missing_remotes if they are referencing a missing
+ # remote. We only care about the "root" rejected ones.
+ missing_remote_ids = [e.event_id for e in missing_remotes]
+ base_remote_rejected = list(missing_remotes)
+ for e in missing_remotes:
+ for e_id, _ in e.auth_events:
+ if e_id in missing_remote_ids:
+ base_remote_rejected.remove(e)
+
+ reason_map = {}
+
+ for e in base_remote_rejected:
+ reason = yield self.store.get_rejection_reason(e.event_id)
+ if reason is None:
+ # FIXME: ERRR?!
+ logger.warn("Could not find reason for %s", e.event_id)
+ raise RuntimeError("")
+
+ reason_map[e.event_id] = reason
+
+ if reason == RejectedReason.AUTH_ERROR:
+ pass
+ elif reason == RejectedReason.REPLACED:
+ # TODO: Get proof
+ pass
+ elif reason == RejectedReason.NOT_ANCESTOR:
+ # TODO: Get proof.
+ pass
+
+ logger.debug("construct_auth_difference returning")
+
+ defer.returnValue({
+ "auth_chain": local_auth,
+ "rejects": {
+ e.event_id: {
+ "reason": reason_map[e.event_id],
+ "proof": None,
+ }
+ for e in base_remote_rejected
+ },
+ "missing": [e.event_id for e in missing_locals],
+ })
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 9c3271fe88..6fbd2af4ab 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -114,7 +114,8 @@ class MessageHandler(BaseHandler):
defer.returnValue(chunk)
@defer.inlineCallbacks
- def create_and_send_event(self, event_dict, ratelimit=True):
+ def create_and_send_event(self, event_dict, ratelimit=True,
+ client=None, txn_id=None):
""" Given a dict from a client, create and handle a new event.
Creates an FrozenEvent object, filling out auth_events, prev_events,
@@ -148,6 +149,15 @@ class MessageHandler(BaseHandler):
builder.content
)
+ if client is not None:
+ if client.token_id is not None:
+ builder.internal_metadata.token_id = client.token_id
+ if client.device_id is not None:
+ builder.internal_metadata.device_id = client.device_id
+
+ if txn_id is not None:
+ builder.internal_metadata.txn_id = txn_id
+
event, context = yield self._create_new_client_event(
builder=builder,
)
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index d66bfea7b1..cd0798c2b0 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -87,6 +87,10 @@ class PresenceHandler(BaseHandler):
"changed_presencelike_data", self.changed_presencelike_data
)
+ # outbound signal from the presence module to advertise when a user's
+ # presence has changed
+ distributor.declare("user_presence_changed")
+
self.distributor = distributor
self.federation = hs.get_replication_layer()
@@ -604,6 +608,7 @@ class PresenceHandler(BaseHandler):
room_ids=room_ids,
statuscache=statuscache,
)
+ yield self.distributor.fire("user_presence_changed", user, statuscache)
@defer.inlineCallbacks
def _push_presence_remote(self, user, destination, state=None):
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 732652c228..66a89c10b2 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -163,7 +163,7 @@ class RegistrationHandler(BaseHandler):
# each request
httpCli = SimpleHttpClient(self.hs)
# XXX: make this configurable!
- trustedIdServers = ['matrix.org:8090']
+ trustedIdServers = ['matrix.org:8090', 'matrix.org']
if not creds['idServer'] in trustedIdServers:
logger.warn('%s is not a trusted ID server: rejecting 3pid ' +
'credentials', creds['idServer'])
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
new file mode 100644
index 0000000000..962686f4bb
--- /dev/null
+++ b/synapse/handlers/sync.py
@@ -0,0 +1,434 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ._base import BaseHandler
+
+from synapse.streams.config import PaginationConfig
+from synapse.api.constants import Membership, EventTypes
+
+from twisted.internet import defer
+
+import collections
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+SyncConfig = collections.namedtuple("SyncConfig", [
+ "user",
+ "client_info",
+ "limit",
+ "gap",
+ "sort",
+ "backfill",
+ "filter",
+])
+
+
+class RoomSyncResult(collections.namedtuple("RoomSyncResult", [
+ "room_id",
+ "limited",
+ "published",
+ "events",
+ "state",
+ "prev_batch",
+ "ephemeral",
+])):
+ __slots__ = []
+
+ def __nonzero__(self):
+ """Make the result appear empty if there are no updates. This is used
+ to tell if room needs to be part of the sync result.
+ """
+ return bool(self.events or self.state or self.ephemeral)
+
+
+class SyncResult(collections.namedtuple("SyncResult", [
+ "next_batch", # Token for the next sync
+ "private_user_data", # List of private events for the user.
+ "public_user_data", # List of public events for all users.
+ "rooms", # RoomSyncResult for each room.
+])):
+ __slots__ = []
+
+ def __nonzero__(self):
+ """Make the result appear empty if there are no updates. This is used
+ to tell if the notifier needs to wait for more events when polling for
+ events.
+ """
+ return bool(
+ self.private_user_data or self.public_user_data or self.rooms
+ )
+
+
+class SyncHandler(BaseHandler):
+
+ def __init__(self, hs):
+ super(SyncHandler, self).__init__(hs)
+ self.event_sources = hs.get_event_sources()
+ self.clock = hs.get_clock()
+
+ @defer.inlineCallbacks
+ def wait_for_sync_for_user(self, sync_config, since_token=None, timeout=0):
+ """Get the sync for a client if we have new data for it now. Otherwise
+ wait for new data to arrive on the server. If the timeout expires, then
+ return an empty sync result.
+ Returns:
+ A Deferred SyncResult.
+ """
+ if timeout == 0 or since_token is None:
+ result = yield self.current_sync_for_user(sync_config, since_token)
+ defer.returnValue(result)
+ else:
+ def current_sync_callback():
+ return self.current_sync_for_user(sync_config, since_token)
+
+ rm_handler = self.hs.get_handlers().room_member_handler
+ room_ids = yield rm_handler.get_rooms_for_user(sync_config.user)
+ result = yield self.notifier.wait_for_events(
+ sync_config.user, room_ids,
+ sync_config.filter, timeout, current_sync_callback
+ )
+ defer.returnValue(result)
+
+ def current_sync_for_user(self, sync_config, since_token=None):
+ """Get the sync for client needed to match what the server has now.
+ Returns:
+ A Deferred SyncResult.
+ """
+ if since_token is None:
+ return self.initial_sync(sync_config)
+ else:
+ if sync_config.gap:
+ return self.incremental_sync_with_gap(sync_config, since_token)
+ else:
+ #TODO(mjark): Handle gapless sync
+ raise NotImplementedError()
+
+ @defer.inlineCallbacks
+ def initial_sync(self, sync_config):
+ """Get a sync for a client which is starting without any state
+ Returns:
+ A Deferred SyncResult.
+ """
+ if sync_config.sort == "timeline,desc":
+ # TODO(mjark): Handle going through events in reverse order?.
+ # What does "most recent events" mean when applying the limits mean
+ # in this case?
+ raise NotImplementedError()
+
+ now_token = yield self.event_sources.get_current_token()
+
+ presence_stream = self.event_sources.sources["presence"]
+ # TODO (mjark): This looks wrong, shouldn't we be getting the presence
+ # UP to the present rather than after the present?
+ pagination_config = PaginationConfig(from_token=now_token)
+ presence, _ = yield presence_stream.get_pagination_rows(
+ user=sync_config.user,
+ pagination_config=pagination_config.get_source_config("presence"),
+ key=None
+ )
+ room_list = yield self.store.get_rooms_for_user_where_membership_is(
+ user_id=sync_config.user.to_string(),
+ membership_list=[Membership.INVITE, Membership.JOIN]
+ )
+
+ # TODO (mjark): Does public mean "published"?
+ published_rooms = yield self.store.get_rooms(is_public=True)
+ published_room_ids = set(r["room_id"] for r in published_rooms)
+
+ rooms = []
+ for event in room_list:
+ room_sync = yield self.initial_sync_for_room(
+ event.room_id, sync_config, now_token, published_room_ids
+ )
+ rooms.append(room_sync)
+
+ defer.returnValue(SyncResult(
+ public_user_data=presence,
+ private_user_data=[],
+ rooms=rooms,
+ next_batch=now_token,
+ ))
+
+ @defer.inlineCallbacks
+ def initial_sync_for_room(self, room_id, sync_config, now_token,
+ published_room_ids):
+ """Sync a room for a client which is starting without any state
+ Returns:
+ A Deferred RoomSyncResult.
+ """
+
+ recents, prev_batch_token, limited = yield self.load_filtered_recents(
+ room_id, sync_config, now_token,
+ )
+
+ current_state_events = yield self.state_handler.get_current_state(
+ room_id
+ )
+
+ defer.returnValue(RoomSyncResult(
+ room_id=room_id,
+ published=room_id in published_room_ids,
+ events=recents,
+ prev_batch=prev_batch_token,
+ state=current_state_events,
+ limited=limited,
+ ephemeral=[],
+ ))
+
+ @defer.inlineCallbacks
+ def incremental_sync_with_gap(self, sync_config, since_token):
+ """ Get the incremental delta needed to bring the client up to
+ date with the server.
+ Returns:
+ A Deferred SyncResult.
+ """
+ if sync_config.sort == "timeline,desc":
+ # TODO(mjark): Handle going through events in reverse order?.
+ # What does "most recent events" mean when applying the limits mean
+ # in this case?
+ raise NotImplementedError()
+
+ now_token = yield self.event_sources.get_current_token()
+
+ presence_source = self.event_sources.sources["presence"]
+ presence, presence_key = yield presence_source.get_new_events_for_user(
+ user=sync_config.user,
+ from_key=since_token.presence_key,
+ limit=sync_config.limit,
+ )
+ now_token = now_token.copy_and_replace("presence_key", presence_key)
+
+ typing_source = self.event_sources.sources["typing"]
+ typing, typing_key = yield typing_source.get_new_events_for_user(
+ user=sync_config.user,
+ from_key=since_token.typing_key,
+ limit=sync_config.limit,
+ )
+ now_token = now_token.copy_and_replace("typing_key", typing_key)
+
+ typing_by_room = {event["room_id"]: [event] for event in typing}
+ for event in typing:
+ event.pop("room_id")
+ logger.debug("Typing %r", typing_by_room)
+
+ rm_handler = self.hs.get_handlers().room_member_handler
+ room_ids = yield rm_handler.get_rooms_for_user(sync_config.user)
+
+ # TODO (mjark): Does public mean "published"?
+ published_rooms = yield self.store.get_rooms(is_public=True)
+ published_room_ids = set(r["room_id"] for r in published_rooms)
+
+ room_events, _ = yield self.store.get_room_events_stream(
+ sync_config.user.to_string(),
+ from_key=since_token.room_key,
+ to_key=now_token.room_key,
+ room_id=None,
+ limit=sync_config.limit + 1,
+ )
+
+ rooms = []
+ if len(room_events) <= sync_config.limit:
+ # There is no gap in any of the rooms. Therefore we can just
+ # partition the new events by room and return them.
+ events_by_room_id = {}
+ for event in room_events:
+ events_by_room_id.setdefault(event.room_id, []).append(event)
+
+ for room_id in room_ids:
+ recents = events_by_room_id.get(room_id, [])
+ state = [event for event in recents if event.is_state()]
+ if recents:
+ prev_batch = now_token.copy_and_replace(
+ "room_key", recents[0].internal_metadata.before
+ )
+ else:
+ prev_batch = now_token
+
+ state = yield self.check_joined_room(
+ sync_config, room_id, state
+ )
+
+ room_sync = RoomSyncResult(
+ room_id=room_id,
+ published=room_id in published_room_ids,
+ events=recents,
+ prev_batch=prev_batch,
+ state=state,
+ limited=False,
+ ephemeral=typing_by_room.get(room_id, [])
+ )
+ if room_sync:
+ rooms.append(room_sync)
+ else:
+ for room_id in room_ids:
+ room_sync = yield self.incremental_sync_with_gap_for_room(
+ room_id, sync_config, since_token, now_token,
+ published_room_ids, typing_by_room
+ )
+ if room_sync:
+ rooms.append(room_sync)
+
+ defer.returnValue(SyncResult(
+ public_user_data=presence,
+ private_user_data=[],
+ rooms=rooms,
+ next_batch=now_token,
+ ))
+
+ @defer.inlineCallbacks
+ def load_filtered_recents(self, room_id, sync_config, now_token,
+ since_token=None):
+ limited = True
+ recents = []
+ filtering_factor = 2
+ load_limit = max(sync_config.limit * filtering_factor, 100)
+ max_repeat = 3 # Only try a few times per room, otherwise
+ room_key = now_token.room_key
+
+ while limited and len(recents) < sync_config.limit and max_repeat:
+ events, keys = yield self.store.get_recent_events_for_room(
+ room_id,
+ limit=load_limit + 1,
+ from_token=since_token.room_key if since_token else None,
+ end_token=room_key,
+ )
+ (room_key, _) = keys
+ loaded_recents = sync_config.filter.filter_room_events(events)
+ loaded_recents.extend(recents)
+ recents = loaded_recents
+ if len(events) <= load_limit:
+ limited = False
+ max_repeat -= 1
+
+ if len(recents) > sync_config.limit:
+ recents = recents[-sync_config.limit:]
+ room_key = recents[0].internal_metadata.before
+
+ prev_batch_token = now_token.copy_and_replace(
+ "room_key", room_key
+ )
+
+ defer.returnValue((recents, prev_batch_token, limited))
+
+ @defer.inlineCallbacks
+ def incremental_sync_with_gap_for_room(self, room_id, sync_config,
+ since_token, now_token,
+ published_room_ids, typing_by_room):
+ """ Get the incremental delta needed to bring the client up to date for
+ the room. Gives the client the most recent events and the changes to
+ state.
+ Returns:
+ A Deferred RoomSyncResult
+ """
+
+ # TODO(mjark): Check for redactions we might have missed.
+
+ recents, prev_batch_token, limited = yield self.load_filtered_recents(
+ room_id, sync_config, now_token, since_token,
+ )
+
+ logging.debug("Recents %r", recents)
+
+ # TODO(mjark): This seems racy since this isn't being passed a
+ # token to indicate what point in the stream this is
+ current_state_events = yield self.state_handler.get_current_state(
+ room_id
+ )
+
+ state_at_previous_sync = yield self.get_state_at_previous_sync(
+ room_id, since_token=since_token
+ )
+
+ state_events_delta = yield self.compute_state_delta(
+ since_token=since_token,
+ previous_state=state_at_previous_sync,
+ current_state=current_state_events,
+ )
+
+ state_events_delta = yield self.check_joined_room(
+ sync_config, room_id, state_events_delta
+ )
+
+ room_sync = RoomSyncResult(
+ room_id=room_id,
+ published=room_id in published_room_ids,
+ events=recents,
+ prev_batch=prev_batch_token,
+ state=state_events_delta,
+ limited=limited,
+ ephemeral=typing_by_room.get(room_id, [])
+ )
+
+ logging.debug("Room sync: %r", room_sync)
+
+ defer.returnValue(room_sync)
+
+ @defer.inlineCallbacks
+ def get_state_at_previous_sync(self, room_id, since_token):
+ """ Get the room state at the previous sync the client made.
+ Returns:
+ A Deferred list of Events.
+ """
+ last_events, token = yield self.store.get_recent_events_for_room(
+ room_id, end_token=since_token.room_key, limit=1,
+ )
+
+ if last_events:
+ last_event = last_events[0]
+ last_context = yield self.state_handler.compute_event_context(
+ last_event
+ )
+ if last_event.is_state():
+ state = [last_event] + last_context.current_state.values()
+ else:
+ state = last_context.current_state.values()
+ else:
+ state = ()
+ defer.returnValue(state)
+
+ def compute_state_delta(self, since_token, previous_state, current_state):
+ """ Works out the differnce in state between the current state and the
+ state the client got when it last performed a sync.
+ Returns:
+ A list of events.
+ """
+ # TODO(mjark) Check if the state events were received by the server
+ # after the previous sync, since we need to include those state
+ # updates even if they occured logically before the previous event.
+ # TODO(mjark) Check for new redactions in the state events.
+ previous_dict = {event.event_id: event for event in previous_state}
+ state_delta = []
+ for event in current_state:
+ if event.event_id not in previous_dict:
+ state_delta.append(event)
+ return state_delta
+
+ @defer.inlineCallbacks
+ def check_joined_room(self, sync_config, room_id, state_delta):
+ joined = False
+ for event in state_delta:
+ if (
+ event.type == EventTypes.Member
+ and event.state_key == sync_config.user.to_string()
+ ):
+ if event.content["membership"] == Membership.JOIN:
+ joined = True
+
+ if joined:
+ state_delta = yield self.state_handler.get_current_state(room_id)
+
+ defer.returnValue(state_delta)
|