diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index b440280b74..88fa0bb2e4 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -176,12 +176,41 @@ class ApplicationServicesHandler(object):
defer.returnValue(ret)
@defer.inlineCallbacks
- def get_3pe_protocols(self):
+ def get_3pe_protocols(self, only_protocol=None):
services = yield self.store.get_app_services()
protocols = {}
+
+ # Collect up all the individual protocol responses out of the ASes
for s in services:
for p in s.protocols:
- protocols[p] = yield self.appservice_api.get_3pe_protocol(s, p)
+ if only_protocol is not None and p != only_protocol:
+ continue
+
+ if p not in protocols:
+ protocols[p] = []
+
+ info = yield self.appservice_api.get_3pe_protocol(s, p)
+
+ if info is not None:
+ protocols[p].append(info)
+
+ def _merge_instances(infos):
+ if not infos:
+ return {}
+
+            # Merge the 'instances' lists of multiple results, but just take
+            # the other fields from the first as they ought to be identical.
+            # Copy the result so as not to corrupt the cached one.
+ combined = dict(infos[0])
+ combined["instances"] = list(combined["instances"])
+
+ for info in infos[1:]:
+ combined["instances"].extend(info["instances"])
+
+ return combined
+
+ for p in protocols.keys():
+ protocols[p] = _merge_instances(protocols[p])
defer.returnValue(protocols)
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 8d630c6b1a..aa68755936 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -58,7 +58,7 @@ class DeviceHandler(BaseHandler):
attempts = 0
while attempts < 5:
try:
- device_id = stringutils.random_string_with_symbols(16)
+ device_id = stringutils.random_string(10).upper()
yield self.store.store_device(
user_id=user_id,
device_id=device_id,
diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py
new file mode 100644
index 0000000000..c5368e5df2
--- /dev/null
+++ b/synapse/handlers/devicemessage.py
@@ -0,0 +1,117 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from twisted.internet import defer
+
+from synapse.types import get_domain_from_id
+from synapse.util.stringutils import random_string
+
+
+logger = logging.getLogger(__name__)
+
+
+class DeviceMessageHandler(object):
+
+ def __init__(self, hs):
+ """
+ Args:
+ hs (synapse.server.HomeServer): server
+ """
+ self.store = hs.get_datastore()
+ self.notifier = hs.get_notifier()
+ self.is_mine_id = hs.is_mine_id
+ self.federation = hs.get_replication_layer()
+
+ self.federation.register_edu_handler(
+ "m.direct_to_device", self.on_direct_to_device_edu
+ )
+
+ @defer.inlineCallbacks
+ def on_direct_to_device_edu(self, origin, content):
+ local_messages = {}
+ sender_user_id = content["sender"]
+ if origin != get_domain_from_id(sender_user_id):
+ logger.warn(
+ "Dropping device message from %r with spoofed sender %r",
+ origin, sender_user_id
+            )
+            return
+ message_type = content["type"]
+ message_id = content["message_id"]
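+        # Reshape the incoming messages into per-user, per-device entries
+        # ready to be stored in the local device inbox.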
+ for user_id, by_device in content["messages"].items():
+ messages_by_device = {
+ device_id: {
+ "content": message_content,
+ "type": message_type,
+ "sender": sender_user_id,
+ }
+ for device_id, message_content in by_device.items()
+ }
+ if messages_by_device:
+ local_messages[user_id] = messages_by_device
+
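+        # Persist the messages, then poke the notifier so that anything
+        # waiting on the to_device stream (e.g. /sync requests) wakes up.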
+ stream_id = yield self.store.add_messages_from_remote_to_device_inbox(
+ origin, message_id, local_messages
+ )
+
+ self.notifier.on_new_event(
+ "to_device_key", stream_id, users=local_messages.keys()
+ )
+
+ @defer.inlineCallbacks
+ def send_device_message(self, sender_user_id, message_type, messages):
+
+ local_messages = {}
+ remote_messages = {}
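+        # Split the messages by whether each recipient user is local to this
+        # homeserver or lives on a remote server.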
+ for user_id, by_device in messages.items():
+ if self.is_mine_id(user_id):
+ messages_by_device = {
+ device_id: {
+ "content": message_content,
+ "type": message_type,
+ "sender": sender_user_id,
+ }
+ for device_id, message_content in by_device.items()
+ }
+ if messages_by_device:
+ local_messages[user_id] = messages_by_device
+ else:
+ destination = get_domain_from_id(user_id)
+ remote_messages.setdefault(destination, {})[user_id] = by_device
+
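+        # A random ID identifying this batch of messages; remote servers use
+        # it to de-duplicate the EDU if the transaction is retried.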
+ message_id = random_string(16)
+
+ remote_edu_contents = {}
+ for destination, messages in remote_messages.items():
+ remote_edu_contents[destination] = {
+ "messages": messages,
+ "sender": sender_user_id,
+ "type": message_type,
+ "message_id": message_id,
+ }
+
+ stream_id = yield self.store.add_messages_to_device_inbox(
+ local_messages, remote_edu_contents
+ )
+
+ self.notifier.on_new_event(
+ "to_device_key", stream_id, users=local_messages.keys()
+ )
+
+ for destination in remote_messages.keys():
+ # Enqueue a new federation transaction to send the new
+ # device messages to each remote destination.
+ self.federation.send_device_messages(destination)
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 2c7bfd91ed..fd11935b40 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -13,14 +13,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import collections
-import json
+import ujson as json
import logging
+from canonicaljson import encode_canonical_json
from twisted.internet import defer
-from synapse.api import errors
-import synapse.types
+from synapse.api.errors import SynapseError, CodeMessageException
+from synapse.types import get_domain_from_id
+from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
+from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination
logger = logging.getLogger(__name__)
@@ -29,8 +31,9 @@ class E2eKeysHandler(object):
def __init__(self, hs):
self.store = hs.get_datastore()
self.federation = hs.get_replication_layer()
+ self.device_handler = hs.get_device_handler()
self.is_mine_id = hs.is_mine_id
- self.server_name = hs.hostname
+ self.clock = hs.get_clock()
# doesn't really work as part of the generic query API, because the
# query request requires an object POST, but we abuse the
@@ -40,7 +43,7 @@ class E2eKeysHandler(object):
)
@defer.inlineCallbacks
- def query_devices(self, query_body):
+ def query_devices(self, query_body, timeout):
""" Handle a device key query from a client
{
@@ -63,27 +66,60 @@ class E2eKeysHandler(object):
# separate users by domain.
# make a map from domain to user_id to device_ids
- queries_by_domain = collections.defaultdict(dict)
+ local_query = {}
+ remote_queries = {}
+
for user_id, device_ids in device_keys_query.items():
- user = synapse.types.UserID.from_string(user_id)
- queries_by_domain[user.domain][user_id] = device_ids
+ if self.is_mine_id(user_id):
+ local_query[user_id] = device_ids
+ else:
+ domain = get_domain_from_id(user_id)
+ remote_queries.setdefault(domain, {})[user_id] = device_ids
# do the queries
- # TODO: do these in parallel
+ failures = {}
results = {}
- for destination, destination_query in queries_by_domain.items():
- if destination == self.server_name:
- res = yield self.query_local_devices(destination_query)
- else:
- res = yield self.federation.query_client_keys(
- destination, {"device_keys": destination_query}
- )
- res = res["device_keys"]
- for user_id, keys in res.items():
- if user_id in destination_query:
+ if local_query:
+ local_result = yield self.query_local_devices(local_query)
+ for user_id, keys in local_result.items():
+ if user_id in local_query:
results[user_id] = keys
- defer.returnValue((200, {"device_keys": results}))
+        @defer.inlineCallbacks
+        def do_remote_query(destination):
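+            # Query a single remote homeserver for device keys, respecting
+            # its retry limiter so we back off from servers that are down.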
+            destination_query = remote_queries[destination]
+            try:
+                limiter = yield get_retry_limiter(
+                    destination, self.clock, self.store
+                )
+                with limiter:
+                    remote_result = yield self.federation.query_client_keys(
+                        destination,
+                        {"device_keys": destination_query},
+                        timeout=timeout
+                    )
+
+                for user_id, keys in remote_result["device_keys"].items():
+                    if user_id in destination_query:
+                        results[user_id] = keys
+
+            except CodeMessageException as e:
+                failures[destination] = {
+                    "status": e.code, "message": e.message
+                }
+            except NotRetryingDestination as e:
+                failures[destination] = {
+                    "status": 503, "message": "Not ready for retry",
+                }
+
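+        # Fire off the remote queries in parallel; each one records its own
+        # results or failure.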
+ yield preserve_context_over_deferred(defer.gatherResults([
+ preserve_fn(do_remote_query)(destination)
+ for destination in remote_queries
+ ]))
+
+ defer.returnValue({
+ "device_keys": results, "failures": failures,
+ })
@defer.inlineCallbacks
def query_local_devices(self, query):
@@ -104,7 +140,7 @@ class E2eKeysHandler(object):
if not self.is_mine_id(user_id):
logger.warning("Request for keys for non-local user %s",
user_id)
- raise errors.SynapseError(400, "Not a user here")
+ raise SynapseError(400, "Not a user here")
if not device_ids:
local_query.append((user_id, None))
@@ -137,3 +173,107 @@ class E2eKeysHandler(object):
device_keys_query = query_body.get("device_keys", {})
res = yield self.query_local_devices(device_keys_query)
defer.returnValue({"device_keys": res})
+
+ @defer.inlineCallbacks
+ def claim_one_time_keys(self, query, timeout):
+ local_query = []
+ remote_queries = {}
+
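+        # Split the claim request between keys we hold locally and keys that
+        # must be claimed from remote homeservers.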
+ for user_id, device_keys in query.get("one_time_keys", {}).items():
+ if self.is_mine_id(user_id):
+ for device_id, algorithm in device_keys.items():
+ local_query.append((user_id, device_id, algorithm))
+ else:
+ domain = get_domain_from_id(user_id)
+ remote_queries.setdefault(domain, {})[user_id] = device_keys
+
+ results = yield self.store.claim_e2e_one_time_keys(local_query)
+
+ json_result = {}
+ failures = {}
+ for user_id, device_keys in results.items():
+ for device_id, keys in device_keys.items():
+ for key_id, json_bytes in keys.items():
+ json_result.setdefault(user_id, {})[device_id] = {
+ key_id: json.loads(json_bytes)
+ }
+
+        @defer.inlineCallbacks
+        def claim_client_keys(destination):
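+            # Claim one-time keys from a single remote homeserver, again
+            # respecting its retry limiter.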
+            device_keys = remote_queries[destination]
+            try:
+                limiter = yield get_retry_limiter(
+                    destination, self.clock, self.store
+                )
+                with limiter:
+                    remote_result = yield self.federation.claim_client_keys(
+                        destination,
+                        {"one_time_keys": device_keys},
+                        timeout=timeout
+                    )
+                    for user_id, keys in remote_result["one_time_keys"].items():
+                        if user_id in device_keys:
+                            json_result[user_id] = keys
+            except CodeMessageException as e:
+                failures[destination] = {
+                    "status": e.code, "message": e.message
+                }
+            except NotRetryingDestination as e:
+                failures[destination] = {
+                    "status": 503, "message": "Not ready for retry",
+                }
+
+ yield preserve_context_over_deferred(defer.gatherResults([
+ preserve_fn(claim_client_keys)(destination)
+ for destination in remote_queries
+ ]))
+
+ defer.returnValue({
+ "one_time_keys": json_result,
+ "failures": failures
+ })
+
+ @defer.inlineCallbacks
+ def upload_keys_for_user(self, user_id, device_id, keys):
+ time_now = self.clock.time_msec()
+
+ # TODO: Validate the JSON to make sure it has the right keys.
+ device_keys = keys.get("device_keys", None)
+ if device_keys:
+ logger.info(
+ "Updating device_keys for device %r for user %s at %d",
+ device_id, user_id, time_now
+ )
+ # TODO: Sign the JSON with the server key
+ yield self.store.set_e2e_device_keys(
+ user_id, device_id, time_now,
+ encode_canonical_json(device_keys)
+ )
+
+ one_time_keys = keys.get("one_time_keys", None)
+ if one_time_keys:
+ logger.info(
+ "Adding %d one_time_keys for device %r for user %r at %d",
+ len(one_time_keys), device_id, user_id, time_now
+ )
+ key_list = []
+ for key_id, key_json in one_time_keys.items():
+ algorithm, key_id = key_id.split(":")
+ key_list.append((
+ algorithm, key_id, encode_canonical_json(key_json)
+ ))
+
+ yield self.store.add_e2e_one_time_keys(
+ user_id, device_id, time_now, key_list
+ )
+
+ # the device should have been registered already, but it may have been
+ # deleted due to a race with a DELETE request. Or we may be using an
+ # old access_token without an associated device_id. Either way, we
+ # need to double-check the device is registered to avoid ending up with
+ # keys without a corresponding device.
+        yield self.device_handler.check_device_registered(user_id, device_id)
+
+ result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
+
+ defer.returnValue({"one_time_key_counts": result})
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index dc90a5dde4..f7cb3c1bb2 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -832,11 +832,13 @@ class FederationHandler(BaseHandler):
new_pdu = event
- message_handler = self.hs.get_handlers().message_handler
- destinations = yield message_handler.get_joined_hosts_for_room_from_state(
- context
+ users_in_room = yield self.store.get_joined_users_from_context(event, context)
+
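+        # Work out which remote servers need to receive the event: every
+        # server with a user currently joined to the room.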
+ destinations = set(
+ get_domain_from_id(user_id) for user_id in users_in_room
+ if not self.hs.is_mine_id(user_id)
)
- destinations = set(destinations)
+
destinations.discard(origin)
logger.debug(
@@ -1055,11 +1057,12 @@ class FederationHandler(BaseHandler):
new_pdu = event
- message_handler = self.hs.get_handlers().message_handler
- destinations = yield message_handler.get_joined_hosts_for_room_from_state(
- context
+ users_in_room = yield self.store.get_joined_users_from_context(event, context)
+
+ destinations = set(
+ get_domain_from_id(user_id) for user_id in users_in_room
+ if not self.hs.is_mine_id(user_id)
)
- destinations = set(destinations)
destinations.discard(origin)
logger.debug(
@@ -1582,10 +1585,12 @@ class FederationHandler(BaseHandler):
current_state = set(e.event_id for e in auth_events.values())
different_auth = event_auth_events - current_state
+ context.current_state_ids = dict(context.current_state_ids)
context.current_state_ids.update({
k: a.event_id for k, a in auth_events.items()
if k != event_key
})
+ context.prev_state_ids = dict(context.prev_state_ids)
context.prev_state_ids.update({
k: a.event_id for k, a in auth_events.items()
})
@@ -1667,10 +1672,12 @@ class FederationHandler(BaseHandler):
# 4. Look at rejects and their proofs.
# TODO.
+ context.current_state_ids = dict(context.current_state_ids)
context.current_state_ids.update({
k: a.event_id for k, a in auth_events.items()
if k != event_key
})
+ context.prev_state_ids = dict(context.prev_state_ids)
context.prev_state_ids.update({
k: a.event_id for k, a in auth_events.items()
})
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 3577db0595..178209a209 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -30,7 +30,6 @@ from synapse.util.async import concurrently_execute, run_on_reactor, ReadWriteLo
from synapse.util.caches.snapshot_cache import SnapshotCache
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
from synapse.util.metrics import measure_func
-from synapse.util.caches.descriptors import cachedInlineCallbacks
from synapse.visibility import filter_events_for_client
from ._base import BaseHandler
@@ -945,7 +944,12 @@ class MessageHandler(BaseHandler):
event_stream_id, max_stream_id
)
- destinations = yield self.get_joined_hosts_for_room_from_state(context)
+ users_in_room = yield self.store.get_joined_users_from_context(event, context)
+
+ destinations = [
+ get_domain_from_id(user_id) for user_id in users_in_room
+ if not self.hs.is_mine_id(user_id)
+ ]
@defer.inlineCallbacks
def _notify():
@@ -963,39 +967,3 @@ class MessageHandler(BaseHandler):
preserve_fn(federation_handler.handle_new_event)(
event, destinations=destinations,
)
-
- def get_joined_hosts_for_room_from_state(self, context):
- state_group = context.state_group
- if not state_group:
- # If state_group is None it means it has yet to be assigned a
- # state group, i.e. we need to make sure that calls with a state_group
- # of None don't hit previous cached calls with a None state_group.
- # To do this we set the state_group to a new object as object() != object()
- state_group = object()
-
- return self._get_joined_hosts_for_room_from_state(
- state_group, context.current_state_ids
- )
-
- @cachedInlineCallbacks(num_args=1, cache_context=True)
- def _get_joined_hosts_for_room_from_state(self, state_group, current_state_ids,
- cache_context):
-
- # Don't bother getting state for people on the same HS
- current_state = yield self.store.get_events([
- e_id for key, e_id in current_state_ids.items()
- if key[0] == EventTypes.Member and not self.hs.is_mine_id(key[1])
- ])
-
- destinations = set()
- for e in current_state.itervalues():
- try:
- if e.type == EventTypes.Member:
- if e.content["membership"] == Membership.JOIN:
- destinations.add(get_domain_from_id(e.state_key))
- except SynapseError:
- logger.warn(
- "Failed to get destination from event %s", e.event_id
- )
-
- defer.returnValue(destinations)
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index cf82a2336e..b047ae2250 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -52,6 +52,11 @@ bump_active_time_counter = metrics.register_counter("bump_active_time")
get_updates_counter = metrics.register_counter("get_updates", labels=["type"])
+notify_reason_counter = metrics.register_counter("notify_reason", labels=["reason"])
+state_transition_counter = metrics.register_counter(
+ "state_transition", labels=["from", "to"]
+)
+
# If a user was last active in the last LAST_ACTIVE_GRANULARITY, consider them
# "currently_active"
@@ -212,7 +217,7 @@ class PresenceHandler(object):
is some spurious presence changes that will self-correct.
"""
logger.info(
- "Performing _on_shutdown. Persiting %d unpersisted changes",
+ "Performing _on_shutdown. Persisting %d unpersisted changes",
len(self.user_to_current_state)
)
@@ -229,7 +234,7 @@ class PresenceHandler(object):
may stack up and slow down shutdown times.
"""
logger.info(
- "Performing _persist_unpersisted_changes. Persiting %d unpersisted changes",
+ "Performing _persist_unpersisted_changes. Persisting %d unpersisted changes",
len(self.unpersisted_users_changes)
)
@@ -260,6 +265,12 @@ class PresenceHandler(object):
to_notify = {} # Changes we want to notify everyone about
to_federation_ping = {} # These need sending keep-alives
+ # Only bother handling the last presence change for each user
+ new_states_dict = {}
+ for new_state in new_states:
+ new_states_dict[new_state.user_id] = new_state
+            new_states = new_states_dict.values()
+
for new_state in new_states:
user_id = new_state.user_id
@@ -614,18 +625,8 @@ class PresenceHandler(object):
Args:
hosts_to_states (dict): Mapping `server_name` -> `[UserPresenceState]`
"""
- now = self.clock.time_msec()
for host, states in hosts_to_states.items():
- self.federation.send_edu(
- destination=host,
- edu_type="m.presence",
- content={
- "push": [
- _format_user_presence_state(state, now)
- for state in states
- ]
- }
- )
+ self.federation.send_presence(host, states)
@defer.inlineCallbacks
def incoming_presence(self, origin, content):
@@ -646,6 +647,13 @@ class PresenceHandler(object):
)
continue
+ if get_domain_from_id(user_id) != origin:
+ logger.info(
+ "Got presence update from %r with bad 'user_id': %r",
+ origin, user_id,
+ )
+ continue
+
presence_state = push.get("presence", None)
if not presence_state:
logger.info(
@@ -705,13 +713,13 @@ class PresenceHandler(object):
defer.returnValue([
{
"type": "m.presence",
- "content": _format_user_presence_state(state, now),
+ "content": format_user_presence_state(state, now),
}
for state in updates
])
else:
defer.returnValue([
- _format_user_presence_state(state, now) for state in updates
+ format_user_presence_state(state, now) for state in updates
])
@defer.inlineCallbacks
@@ -939,33 +947,38 @@ class PresenceHandler(object):
def should_notify(old_state, new_state):
"""Decides if a presence state change should be sent to interested parties.
"""
+ if old_state == new_state:
+ return False
+
if old_state.status_msg != new_state.status_msg:
+ notify_reason_counter.inc("status_msg_change")
return True
- if old_state.state == PresenceState.ONLINE:
- if new_state.state != PresenceState.ONLINE:
- # Always notify for online -> anything
- return True
+ if old_state.state != new_state.state:
+ notify_reason_counter.inc("state_change")
+ state_transition_counter.inc(old_state.state, new_state.state)
+ return True
+ if old_state.state == PresenceState.ONLINE:
if new_state.currently_active != old_state.currently_active:
+ notify_reason_counter.inc("current_active_change")
return True
if new_state.last_active_ts - old_state.last_active_ts > LAST_ACTIVE_GRANULARITY:
             # Only notify about last active bumps if we're not currently active
- if not (old_state.currently_active and new_state.currently_active):
+ if not new_state.currently_active:
+ notify_reason_counter.inc("last_active_change_online")
return True
elif new_state.last_active_ts - old_state.last_active_ts > LAST_ACTIVE_GRANULARITY:
# Always notify for a transition where last active gets bumped.
- return True
-
- if old_state.state != new_state.state:
+ notify_reason_counter.inc("last_active_change_not_online")
return True
return False
-def _format_user_presence_state(state, now):
+def format_user_presence_state(state, now):
"""Convert UserPresenceState to a format that can be sent down to clients
and to other servers.
"""
@@ -1078,7 +1091,7 @@ class PresenceEventSource(object):
defer.returnValue(([
{
"type": "m.presence",
- "content": _format_user_presence_state(s, now),
+ "content": format_user_presence_state(s, now),
}
for s in updates.values()
if include_offline or s.state != PresenceState.OFFLINE
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index 726f7308d2..e536a909d0 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -156,6 +156,7 @@ class ReceiptsHandler(BaseHandler):
}
},
},
+ key=(room_id, receipt_type, user_id),
)
@defer.inlineCallbacks
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index bf6b1c1535..cbd26f8f95 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -20,12 +20,10 @@ from ._base import BaseHandler
from synapse.types import UserID, RoomAlias, RoomID, RoomStreamToken
from synapse.api.constants import (
- EventTypes, JoinRules, RoomCreationPreset, Membership,
+ EventTypes, JoinRules, RoomCreationPreset
)
from synapse.api.errors import AuthError, StoreError, SynapseError
from synapse.util import stringutils
-from synapse.util.async import concurrently_execute
-from synapse.util.caches.response_cache import ResponseCache
from synapse.visibility import filter_events_for_client
from collections import OrderedDict
@@ -36,8 +34,6 @@ import string
logger = logging.getLogger(__name__)
-REMOTE_ROOM_LIST_POLL_INTERVAL = 60 * 1000
-
id_server_scheme = "https://"
@@ -196,6 +192,11 @@ class RoomCreationHandler(BaseHandler):
},
ratelimit=False)
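+        # Pass any is_direct flag from the createRoom request through to the
+        # invite events so invitees' clients can treat the room as a direct
+        # chat.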
+ content = {}
+ is_direct = config.get("is_direct", None)
+ if is_direct:
+ content["is_direct"] = is_direct
+
for invitee in invite_list:
yield room_member_handler.update_membership(
requester,
@@ -203,6 +204,7 @@ class RoomCreationHandler(BaseHandler):
room_id,
"invite",
ratelimit=False,
+ content=content,
)
for invite_3pid in invite_3pid_list:
@@ -342,149 +344,6 @@ class RoomCreationHandler(BaseHandler):
)
-class RoomListHandler(BaseHandler):
- def __init__(self, hs):
- super(RoomListHandler, self).__init__(hs)
- self.response_cache = ResponseCache(hs)
- self.remote_list_request_cache = ResponseCache(hs)
- self.remote_list_cache = {}
- self.fetch_looping_call = hs.get_clock().looping_call(
- self.fetch_all_remote_lists, REMOTE_ROOM_LIST_POLL_INTERVAL
- )
- self.fetch_all_remote_lists()
-
- def get_local_public_room_list(self):
- result = self.response_cache.get(())
- if not result:
- result = self.response_cache.set((), self._get_public_room_list())
- return result
-
- @defer.inlineCallbacks
- def _get_public_room_list(self):
- room_ids = yield self.store.get_public_room_ids()
-
- results = []
-
- @defer.inlineCallbacks
- def handle_room(room_id):
- current_state = yield self.state_handler.get_current_state(room_id)
-
- # Double check that this is actually a public room.
- join_rules_event = current_state.get((EventTypes.JoinRules, ""))
- if join_rules_event:
- join_rule = join_rules_event.content.get("join_rule", None)
- if join_rule and join_rule != JoinRules.PUBLIC:
- defer.returnValue(None)
-
- result = {"room_id": room_id}
-
- num_joined_users = len([
- 1 for _, event in current_state.items()
- if event.type == EventTypes.Member
- and event.membership == Membership.JOIN
- ])
- if num_joined_users == 0:
- return
-
- result["num_joined_members"] = num_joined_users
-
- aliases = yield self.store.get_aliases_for_room(room_id)
- if aliases:
- result["aliases"] = aliases
-
- name_event = yield current_state.get((EventTypes.Name, ""))
- if name_event:
- name = name_event.content.get("name", None)
- if name:
- result["name"] = name
-
- topic_event = current_state.get((EventTypes.Topic, ""))
- if topic_event:
- topic = topic_event.content.get("topic", None)
- if topic:
- result["topic"] = topic
-
- canonical_event = current_state.get((EventTypes.CanonicalAlias, ""))
- if canonical_event:
- canonical_alias = canonical_event.content.get("alias", None)
- if canonical_alias:
- result["canonical_alias"] = canonical_alias
-
- visibility_event = current_state.get((EventTypes.RoomHistoryVisibility, ""))
- visibility = None
- if visibility_event:
- visibility = visibility_event.content.get("history_visibility", None)
- result["world_readable"] = visibility == "world_readable"
-
- guest_event = current_state.get((EventTypes.GuestAccess, ""))
- guest = None
- if guest_event:
- guest = guest_event.content.get("guest_access", None)
- result["guest_can_join"] = guest == "can_join"
-
- avatar_event = current_state.get(("m.room.avatar", ""))
- if avatar_event:
- avatar_url = avatar_event.content.get("url", None)
- if avatar_url:
- result["avatar_url"] = avatar_url
-
- results.append(result)
-
- yield concurrently_execute(handle_room, room_ids, 10)
-
- # FIXME (erikj): START is no longer a valid value
- defer.returnValue({"start": "START", "end": "END", "chunk": results})
-
- @defer.inlineCallbacks
- def fetch_all_remote_lists(self):
- deferred = self.hs.get_replication_layer().get_public_rooms(
- self.hs.config.secondary_directory_servers
- )
- self.remote_list_request_cache.set((), deferred)
- self.remote_list_cache = yield deferred
-
- @defer.inlineCallbacks
- def get_aggregated_public_room_list(self):
- """
- Get the public room list from this server and the servers
- specified in the secondary_directory_servers config option.
- XXX: Pagination...
- """
- # We return the results from out cache which is updated by a looping call,
- # unless we're missing a cache entry, in which case wait for the result
- # of the fetch if there's one in progress. If not, omit that server.
- wait = False
- for s in self.hs.config.secondary_directory_servers:
- if s not in self.remote_list_cache:
- logger.warn("No cached room list from %s: waiting for fetch", s)
- wait = True
- break
-
- if wait and self.remote_list_request_cache.get(()):
- yield self.remote_list_request_cache.get(())
-
- public_rooms = yield self.get_local_public_room_list()
-
- # keep track of which room IDs we've seen so we can de-dup
- room_ids = set()
-
- # tag all the ones in our list with our server name.
- # Also add the them to the de-deping set
- for room in public_rooms['chunk']:
- room["server_name"] = self.hs.hostname
- room_ids.add(room["room_id"])
-
- # Now add the results from federation
- for server_name, server_result in self.remote_list_cache.items():
- for room in server_result["chunk"]:
- if room["room_id"] not in room_ids:
- room["server_name"] = server_name
- public_rooms["chunk"].append(room)
- room_ids.add(room["room_id"])
-
- defer.returnValue(public_rooms)
-
-
class RoomContextHandler(BaseHandler):
@defer.inlineCallbacks
def get_event_context(self, user, room_id, event_id, limit, is_guest):
diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
new file mode 100644
index 0000000000..5a533682c5
--- /dev/null
+++ b/synapse/handlers/room_list.py
@@ -0,0 +1,400 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 - 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+from ._base import BaseHandler
+
+from synapse.api.constants import (
+ EventTypes, JoinRules,
+)
+from synapse.util.async import concurrently_execute
+from synapse.util.caches.response_cache import ResponseCache
+
+from collections import namedtuple
+from unpaddedbase64 import encode_base64, decode_base64
+
+import logging
+import msgpack
+
+logger = logging.getLogger(__name__)
+
+REMOTE_ROOM_LIST_POLL_INTERVAL = 60 * 1000
+
+
+class RoomListHandler(BaseHandler):
+ def __init__(self, hs):
+ super(RoomListHandler, self).__init__(hs)
+ self.response_cache = ResponseCache(hs)
+ self.remote_response_cache = ResponseCache(hs, timeout_ms=30 * 1000)
+
+ def get_local_public_room_list(self, limit=None, since_token=None,
+ search_filter=None):
+ if search_filter:
+ # We explicitly don't bother caching searches.
+ return self._get_public_room_list(limit, since_token, search_filter)
+
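+        # Otherwise cache the response, keyed on (limit, since_token), so
+        # that concurrent requests for the same page share the same work.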
+ result = self.response_cache.get((limit, since_token))
+ if not result:
+ result = self.response_cache.set(
+ (limit, since_token),
+ self._get_public_room_list(limit, since_token)
+ )
+ return result
+
+ @defer.inlineCallbacks
+ def _get_public_room_list(self, limit=None, since_token=None,
+ search_filter=None):
+ if since_token and since_token != "END":
+ since_token = RoomListNextBatch.from_token(since_token)
+ else:
+ since_token = None
+
+ rooms_to_order_value = {}
+ rooms_to_num_joined = {}
+ rooms_to_latest_event_ids = {}
+
+ newly_visible = []
+ newly_unpublished = []
+ if since_token:
+ stream_token = since_token.stream_ordering
+ current_public_id = yield self.store.get_current_public_room_stream_id()
+ public_room_stream_id = since_token.public_room_stream_id
+ newly_visible, newly_unpublished = yield self.store.get_public_room_changes(
+ public_room_stream_id, current_public_id
+ )
+ else:
+ stream_token = yield self.store.get_room_max_stream_ordering()
+ public_room_stream_id = yield self.store.get_current_public_room_stream_id()
+
+        room_ids = yield self.store.get_public_room_ids_at_stream_id(
+            public_room_stream_id
+        )
+
+        # We want to return rooms ordered by the number of joined users, with
+        # the largest rooms first. We then arbitrarily use the room_id as a
+        # tie breaker.
+
+ @defer.inlineCallbacks
+ def get_order_for_room(room_id):
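+            # Work out how many users are joined to the room at the given
+            # stream position so the rooms can be ordered by size.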
+ latest_event_ids = rooms_to_latest_event_ids.get(room_id, None)
+ if not latest_event_ids:
+ latest_event_ids = yield self.store.get_forward_extremeties_for_room(
+ room_id, stream_token
+ )
+ rooms_to_latest_event_ids[room_id] = latest_event_ids
+
+ if not latest_event_ids:
+ return
+
+ joined_users = yield self.state_handler.get_current_user_in_room(
+ room_id, latest_event_ids,
+ )
+ num_joined_users = len(joined_users)
+ rooms_to_num_joined[room_id] = num_joined_users
+
+ if num_joined_users == 0:
+ return
+
+ # We want larger rooms to be first, hence negating num_joined_users
+ rooms_to_order_value[room_id] = (-num_joined_users, room_id)
+
+ yield concurrently_execute(get_order_for_room, room_ids, 10)
+
+ sorted_entries = sorted(rooms_to_order_value.items(), key=lambda e: e[1])
+ sorted_rooms = [room_id for room_id, _ in sorted_entries]
+
+ # `sorted_rooms` should now be a list of all public room ids that is
+ # stable across pagination. Therefore, we can use indices into this
+ # list as our pagination tokens.
+
+ # Filter out rooms that we don't want to return
+ rooms_to_scan = [
+ r for r in sorted_rooms
+            if r not in newly_unpublished and rooms_to_num_joined[r] > 0
+ ]
+
+ if since_token:
+ # Filter out rooms we've already returned previously
+            # `since_token.current_limit` is the index of the last room we
+            # sent down, so when paginating forwards we exclude it and
+            # everything before it (and everything after it when paginating
+            # backwards).
+ if since_token.direction_is_forward:
+ rooms_to_scan = rooms_to_scan[since_token.current_limit + 1:]
+ else:
+ rooms_to_scan = rooms_to_scan[:since_token.current_limit]
+ rooms_to_scan.reverse()
+
+ # Actually generate the entries. _generate_room_entry will append to
+ # chunk but will stop if len(chunk) > limit
+ chunk = []
+ if limit and not search_filter:
+ step = limit + 1
+ for i in xrange(0, len(rooms_to_scan), step):
+                # We iterate here because in the vast majority of cases we
+                # will stop at the first iteration, but occasionally
+                # _generate_room_entry won't append to the chunk and so we
+                # need to loop again.
+ # We don't want to scan over the entire range either as that
+ # would potentially waste a lot of work.
+                yield concurrently_execute(
+                    lambda r: self._generate_room_entry(
+                        r, rooms_to_num_joined[r],
+                        chunk, limit, search_filter
+                    ),
+                    rooms_to_scan[i:i + step], 10
+                )
+                if len(chunk) >= limit + 1:
+                    break
+ else:
+ yield concurrently_execute(
+ lambda r: self._generate_room_entry(
+ r, rooms_to_num_joined[r],
+ chunk, limit, search_filter
+ ),
+ rooms_to_scan, 5
+ )
+
+ chunk.sort(key=lambda e: (-e["num_joined_members"], e["room_id"]))
+
+ # Work out the new limit of the batch for pagination, or None if we
+ # know there are no more results that would be returned.
+ # i.e., [since_token.current_limit..new_limit] is the batch of rooms
+ # we've returned (or the reverse if we paginated backwards)
+ # We tried to pull out limit + 1 rooms above, so if we have <= limit
+ # then we know there are no more results to return
+ new_limit = None
+ if chunk and (not limit or len(chunk) > limit):
+
+ if not since_token or since_token.direction_is_forward:
+ if limit:
+ chunk = chunk[:limit]
+ last_room_id = chunk[-1]["room_id"]
+ else:
+ if limit:
+ chunk = chunk[-limit:]
+ last_room_id = chunk[0]["room_id"]
+
+ new_limit = sorted_rooms.index(last_room_id)
+
+ results = {
+ "chunk": chunk,
+ }
+
+ if since_token:
+ results["new_rooms"] = bool(newly_visible)
+
+ if not since_token or since_token.direction_is_forward:
+ if new_limit is not None:
+ results["next_batch"] = RoomListNextBatch(
+ stream_ordering=stream_token,
+ public_room_stream_id=public_room_stream_id,
+ current_limit=new_limit,
+ direction_is_forward=True,
+ ).to_token()
+
+ if since_token:
+ results["prev_batch"] = since_token.copy_and_replace(
+ direction_is_forward=False,
+ current_limit=since_token.current_limit + 1,
+ ).to_token()
+ else:
+ if new_limit is not None:
+ results["prev_batch"] = RoomListNextBatch(
+ stream_ordering=stream_token,
+ public_room_stream_id=public_room_stream_id,
+ current_limit=new_limit,
+ direction_is_forward=False,
+ ).to_token()
+
+ if since_token:
+ results["next_batch"] = since_token.copy_and_replace(
+ direction_is_forward=True,
+ current_limit=since_token.current_limit - 1,
+ ).to_token()
+
+ defer.returnValue(results)
+
+ @defer.inlineCallbacks
+ def _generate_room_entry(self, room_id, num_joined_users, chunk, limit,
+ search_filter):
+ if limit and len(chunk) > limit + 1:
+            # We've already got enough, so let's just drop it.
+ return
+
+ result = {
+ "room_id": room_id,
+ "num_joined_members": num_joined_users,
+ }
+
+ current_state_ids = yield self.state_handler.get_current_state_ids(room_id)
+
+ event_map = yield self.store.get_events([
+ event_id for key, event_id in current_state_ids.items()
+ if key[0] in (
+ EventTypes.JoinRules,
+ EventTypes.Name,
+ EventTypes.Topic,
+ EventTypes.CanonicalAlias,
+ EventTypes.RoomHistoryVisibility,
+ EventTypes.GuestAccess,
+ "m.room.avatar",
+ )
+ ])
+
+ current_state = {
+ (ev.type, ev.state_key): ev
+ for ev in event_map.values()
+ }
+
+ # Double check that this is actually a public room.
+ join_rules_event = current_state.get((EventTypes.JoinRules, ""))
+ if join_rules_event:
+ join_rule = join_rules_event.content.get("join_rule", None)
+ if join_rule and join_rule != JoinRules.PUBLIC:
+ defer.returnValue(None)
+
+ aliases = yield self.store.get_aliases_for_room(room_id)
+ if aliases:
+ result["aliases"] = aliases
+
+        name_event = current_state.get((EventTypes.Name, ""))
+ if name_event:
+ name = name_event.content.get("name", None)
+ if name:
+ result["name"] = name
+
+ topic_event = current_state.get((EventTypes.Topic, ""))
+ if topic_event:
+ topic = topic_event.content.get("topic", None)
+ if topic:
+ result["topic"] = topic
+
+ canonical_event = current_state.get((EventTypes.CanonicalAlias, ""))
+ if canonical_event:
+ canonical_alias = canonical_event.content.get("alias", None)
+ if canonical_alias:
+ result["canonical_alias"] = canonical_alias
+
+ visibility_event = current_state.get((EventTypes.RoomHistoryVisibility, ""))
+ visibility = None
+ if visibility_event:
+ visibility = visibility_event.content.get("history_visibility", None)
+ result["world_readable"] = visibility == "world_readable"
+
+ guest_event = current_state.get((EventTypes.GuestAccess, ""))
+ guest = None
+ if guest_event:
+ guest = guest_event.content.get("guest_access", None)
+ result["guest_can_join"] = guest == "can_join"
+
+ avatar_event = current_state.get(("m.room.avatar", ""))
+ if avatar_event:
+ avatar_url = avatar_event.content.get("url", None)
+ if avatar_url:
+ result["avatar_url"] = avatar_url
+
+ if _matches_room_entry(result, search_filter):
+ chunk.append(result)
+
+ @defer.inlineCallbacks
+ def get_remote_public_room_list(self, server_name, limit=None, since_token=None,
+ search_filter=None):
+ if search_filter:
+ # We currently don't support searching across federation, so we have
+ # to do it manually without pagination
+ limit = None
+ since_token = None
+
+ res = yield self._get_remote_list_cached(
+ server_name, limit=limit, since_token=since_token,
+ )
+
+ if search_filter:
+ res = {"chunk": [
+ entry
+ for entry in list(res.get("chunk", []))
+ if _matches_room_entry(entry, search_filter)
+ ]}
+
+ defer.returnValue(res)
+
+ def _get_remote_list_cached(self, server_name, limit=None, since_token=None,
+ search_filter=None):
+ repl_layer = self.hs.get_replication_layer()
+ if search_filter:
+ # We can't cache when asking for search
+ return repl_layer.get_public_rooms(
+ server_name, limit=limit, since_token=since_token,
+ search_filter=search_filter,
+ )
+
+ result = self.remote_response_cache.get((server_name, limit, since_token))
+ if not result:
+ result = self.remote_response_cache.set(
+ (server_name, limit, since_token),
+ repl_layer.get_public_rooms(
+ server_name, limit=limit, since_token=since_token,
+ search_filter=search_filter,
+ )
+ )
+ return result
+
+
+class RoomListNextBatch(namedtuple("RoomListNextBatch", (
+ "stream_ordering", # stream_ordering of the first public room list
+ "public_room_stream_id", # public room stream id for first public room list
+ "current_limit", # The number of previous rooms returned
+ "direction_is_forward", # Bool if this is a next_batch, false if prev_batch
+))):
+
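+    # Tokens are the fields above encoded with msgpack, using the short keys
+    # in KEY_DICT, and then unpadded-base64 encoded.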
+ KEY_DICT = {
+ "stream_ordering": "s",
+ "public_room_stream_id": "p",
+ "current_limit": "n",
+ "direction_is_forward": "d",
+ }
+
+ REVERSE_KEY_DICT = {v: k for k, v in KEY_DICT.items()}
+
+ @classmethod
+ def from_token(cls, token):
+ return RoomListNextBatch(**{
+ cls.REVERSE_KEY_DICT[key]: val
+ for key, val in msgpack.loads(decode_base64(token)).items()
+ })
+
+ def to_token(self):
+ return encode_base64(msgpack.dumps({
+ self.KEY_DICT[key]: val
+ for key, val in self._asdict().items()
+ }))
+
+ def copy_and_replace(self, **kwds):
+ return self._replace(
+ **kwds
+ )
+
+
+def _matches_room_entry(room_entry, search_filter):
+ if search_filter and search_filter.get("generic_search_term", None):
+ generic_search_term = search_filter["generic_search_term"].upper()
+ if generic_search_term in room_entry.get("name", "").upper():
+ return True
+ elif generic_search_term in room_entry.get("topic", "").upper():
+ return True
+ elif generic_search_term in room_entry.get("canonical_alias", "").upper():
+ return True
+ else:
+ return True
+
+ return False
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index 0b530b9034..0548b81c34 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -187,6 +187,7 @@ class TypingHandler(object):
"user_id": user_id,
"typing": typing,
},
+ key=(room_id, user_id),
))
yield preserve_context_over_deferred(
@@ -199,7 +200,14 @@ class TypingHandler(object):
user_id = content["user_id"]
# Check that the string is a valid user id
- UserID.from_string(user_id)
+ user = UserID.from_string(user_id)
+
+ if user.domain != origin:
+ logger.info(
+ "Got typing update from %r with bad 'user_id': %r",
+ origin, user_id,
+ )
+ return
users = yield self.state.get_current_user_in_room(room_id)
domains = set(get_domain_from_id(u) for u in users)
|