diff --git a/README.rst b/README.rst
index 291a308b64..297e72f1ae 100644
--- a/README.rst
+++ b/README.rst
@@ -258,6 +258,14 @@ During setup of Synapse you need to call python2.7 directly again::
...substituting your host and domain name as appropriate.
+FreeBSD
+-------
+
+Synapse can be installed via FreeBSD Ports or Packages:
+
+ - Ports: ``cd /usr/ports/net/py-matrix-synapse && make install clean``
+ - Packages: ``pkg install py27-matrix-synapse``
+
Windows Install
---------------
Synapse can be installed on Cygwin. It requires the following Cygwin packages:
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index b86c6c8399..b5536e8565 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -22,7 +22,7 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership, JoinRules
from synapse.api.errors import AuthError, Codes, SynapseError, EventSizeError
-from synapse.types import RoomID, UserID, EventID
+from synapse.types import Requester, RoomID, UserID, EventID
from synapse.util.logutils import log_function
from unpaddedbase64 import decode_base64
@@ -510,35 +510,14 @@ class Auth(object):
"""
# Can optionally look elsewhere in the request (e.g. headers)
try:
- access_token = request.args["access_token"][0]
-
- # Check for application service tokens with a user_id override
- try:
- app_service = yield self.store.get_app_service_by_token(
- access_token
- )
- if not app_service:
- raise KeyError
-
- user_id = app_service.sender
- if "user_id" in request.args:
- user_id = request.args["user_id"][0]
- if not app_service.is_interested_in_user(user_id):
- raise AuthError(
- 403,
- "Application service cannot masquerade as this user."
- )
-
- if not user_id:
- raise KeyError
-
+ user_id = yield self._get_appservice_user_id(request.args)
+ if user_id:
request.authenticated_entity = user_id
+ defer.returnValue(
+ Requester(UserID.from_string(user_id), "", False)
+ )
- defer.returnValue((UserID.from_string(user_id), "", False))
- return
- except KeyError:
- pass # normal users won't have the user_id query parameter set.
-
+ access_token = request.args["access_token"][0]
user_info = yield self._get_user_by_access_token(access_token)
user = user_info["user"]
token_id = user_info["token_id"]
@@ -564,7 +543,7 @@ class Auth(object):
request.authenticated_entity = user.to_string()
- defer.returnValue((user, token_id, is_guest,))
+ defer.returnValue(Requester(user, token_id, is_guest))
except KeyError:
raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.",
@@ -572,6 +551,33 @@ class Auth(object):
)
@defer.inlineCallbacks
+ def _get_appservice_user_id(self, request_args):
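+        # Returns None if the access token does not belong to an application
+        # service, the AS's own sender if no user_id override was supplied,
+        # or the overridden user_id once we have checked that the AS may
+        # masquerade as that user and that the user is actually registered.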
+ app_service = yield self.store.get_app_service_by_token(
+ request_args["access_token"][0]
+ )
+ if app_service is None:
+ defer.returnValue(None)
+
+ if "user_id" not in request_args:
+ defer.returnValue(app_service.sender)
+
+ user_id = request_args["user_id"][0]
+ if app_service.sender == user_id:
+ defer.returnValue(app_service.sender)
+
+ if not app_service.is_interested_in_user(user_id):
+ raise AuthError(
+ 403,
+ "Application service cannot masquerade as this user."
+ )
+ if not (yield self.store.get_user_by_id(user_id)):
+ raise AuthError(
+ 403,
+ "Application service has not registered this user"
+ )
+ defer.returnValue(user_id)
+
+ @defer.inlineCallbacks
def _get_user_by_access_token(self, token):
""" Get a registered user's ID.
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index be0c58a4ca..b106fbed6d 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -29,6 +29,7 @@ class Codes(object):
USER_IN_USE = "M_USER_IN_USE"
ROOM_IN_USE = "M_ROOM_IN_USE"
BAD_PAGINATION = "M_BAD_PAGINATION"
+ BAD_STATE = "M_BAD_STATE"
UNKNOWN = "M_UNKNOWN"
NOT_FOUND = "M_NOT_FOUND"
MISSING_TOKEN = "M_MISSING_TOKEN"
@@ -42,6 +43,7 @@ class Codes(object):
EXCLUSIVE = "M_EXCLUSIVE"
THREEPID_AUTH_FAILED = "M_THREEPID_AUTH_FAILED"
THREEPID_IN_USE = "THREEPID_IN_USE"
+ INVALID_USERNAME = "M_INVALID_USERNAME"
class CodeMessageException(RuntimeError):
diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py
index 869a623090..bbfa5a7265 100644
--- a/synapse/events/__init__.py
+++ b/synapse/events/__init__.py
@@ -117,6 +117,15 @@ class EventBase(object):
def __set__(self, instance, value):
raise AttributeError("Unrecognized attribute %s" % (instance,))
+ def __getitem__(self, field):
+ return self._event_dict[field]
+
+ def __contains__(self, field):
+ return field in self._event_dict
+
+ def items(self):
+ return self._event_dict.items()
+
class FrozenEvent(EventBase):
def __init__(self, event_dict, internal_metadata_dict={}, rejected_reason=None):
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index b474042e84..2d1167296a 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -19,6 +19,7 @@ from synapse.api.errors import LimitExceededError, SynapseError, AuthError
from synapse.crypto.event_signing import add_hashes_and_signatures
from synapse.api.constants import Membership, EventTypes
from synapse.types import UserID, RoomAlias
+from synapse.push.action_generator import ActionGenerator
from synapse.util.logcontext import PreserveLoggingContext
@@ -52,16 +53,54 @@ class BaseHandler(object):
self.event_builder_factory = hs.get_event_builder_factory()
@defer.inlineCallbacks
- def _filter_events_for_client(self, user_id, events, is_guest=False):
- # Assumes that user has at some point joined the room if not is_guest.
+ def _filter_events_for_clients(self, users, events):
+ """ Returns dict of user_id -> list of events that user is allowed to
+ see.
+ """
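+        # Fetch the history-visibility and membership state for every event in
+        # a single batched query, then apply the per-user visibility rules in
+        # allowed() below.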
+ event_id_to_state = yield self.store.get_state_for_events(
+ frozenset(e.event_id for e in events),
+ types=(
+ (EventTypes.RoomHistoryVisibility, ""),
+ (EventTypes.Member, None),
+ )
+ )
+
+ forgotten = yield defer.gatherResults([
+ self.store.who_forgot_in_room(
+ room_id,
+ )
+ for room_id in frozenset(e.room_id for e in events)
+ ], consumeErrors=True)
+
+ # Set of membership event_ids that have been forgotten
+ event_id_forgotten = frozenset(
+ row["event_id"] for rows in forgotten for row in rows
+ )
+
+ def allowed(event, user_id, is_guest):
+ state = event_id_to_state[event.event_id]
+
+ visibility_event = state.get((EventTypes.RoomHistoryVisibility, ""), None)
+ if visibility_event:
+ visibility = visibility_event.content.get("history_visibility", "shared")
+ else:
+ visibility = "shared"
- def allowed(event, membership, visibility):
if visibility == "world_readable":
return True
if is_guest:
return False
+ membership_event = state.get((EventTypes.Member, user_id), None)
+ if membership_event:
+ if membership_event.event_id in event_id_forgotten:
+ membership = None
+ else:
+ membership = membership_event.membership
+ else:
+ membership = None
+
if membership == Membership.JOIN:
return True
@@ -77,43 +116,20 @@ class BaseHandler(object):
return True
- event_id_to_state = yield self.store.get_state_for_events(
- frozenset(e.event_id for e in events),
- types=(
- (EventTypes.RoomHistoryVisibility, ""),
- (EventTypes.Member, user_id),
- )
- )
-
- events_to_return = []
- for event in events:
- state = event_id_to_state[event.event_id]
-
- membership_event = state.get((EventTypes.Member, user_id), None)
- if membership_event:
- was_forgotten_at_event = yield self.store.was_forgotten_at(
- membership_event.state_key,
- membership_event.room_id,
- membership_event.event_id
- )
- if was_forgotten_at_event:
- membership = None
- else:
- membership = membership_event.membership
- else:
- membership = None
-
- visibility_event = state.get((EventTypes.RoomHistoryVisibility, ""), None)
- if visibility_event:
- visibility = visibility_event.content.get("history_visibility", "shared")
- else:
- visibility = "shared"
-
- should_include = allowed(event, membership, visibility)
- if should_include:
- events_to_return.append(event)
+ defer.returnValue({
+ user_id: [
+ event
+ for event in events
+ if allowed(event, user_id, is_guest)
+ ]
+ for user_id, is_guest in users
+ })
- defer.returnValue(events_to_return)
+ @defer.inlineCallbacks
+ def _filter_events_for_client(self, user_id, events, is_guest=False):
+ # Assumes that user has at some point joined the room if not is_guest.
+ res = yield self._filter_events_for_clients([(user_id, is_guest)], events)
+ defer.returnValue(res.get(user_id, []))
def ratelimit(self, user_id):
time_now = self.clock.time()
@@ -170,12 +186,10 @@ class BaseHandler(object):
)
@defer.inlineCallbacks
- def handle_new_client_event(self, event, context, extra_destinations=[],
- extra_users=[], suppress_auth=False):
+ def handle_new_client_event(self, event, context, extra_users=[]):
# We now need to go and hit out to wherever we need to hit out to.
- if not suppress_auth:
- self.auth.check(event, auth_events=context.current_state)
+ self.auth.check(event, auth_events=context.current_state)
yield self.maybe_kick_guest_users(event, context.current_state.values())
@@ -252,7 +266,12 @@ class BaseHandler(object):
event, context=context
)
- destinations = set(extra_destinations)
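+        # Evaluate the push rules for everyone in the room and persist the
+        # resulting actions; the sync handler reads these rows back to report
+        # unread notification counts.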
+ action_generator = ActionGenerator(self.store)
+ yield action_generator.handle_push_actions_for_event(
+ event, self
+ )
+
+ destinations = set()
for k, s in context.current_state.items():
try:
if k[0] == EventTypes.Member:
diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index 28c674730e..c73eec2b91 100644
--- a/synapse/handlers/events.py
+++ b/synapse/handlers/events.py
@@ -36,10 +36,6 @@ def stopped_user_eventstream(distributor, user):
return distributor.fire("stopped_user_eventstream", user)
-def user_joined_room(distributor, user, room_id):
- return distributor.fire("user_joined_room", user, room_id)
-
-
class EventStreamHandler(BaseHandler):
def __init__(self, hs):
@@ -136,9 +132,6 @@ class EventStreamHandler(BaseHandler):
# thundering herds on restart.
timeout = random.randint(int(timeout*0.9), int(timeout*1.1))
- if is_guest:
- yield user_joined_room(self.distributor, auth_user, room_id)
-
events, tokens = yield self.notifier.get_events_for(
auth_user, pagin_config, timeout,
only_room_events=only_room_events,
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index c1704afc53..4b94940e99 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -36,6 +36,8 @@ from synapse.events.utils import prune_event
from synapse.util.retryutils import NotRetryingDestination
+from synapse.push.action_generator import ActionGenerator
+
from twisted.internet import defer
import itertools
@@ -242,6 +244,12 @@ class FederationHandler(BaseHandler):
user = UserID.from_string(event.state_key)
yield user_joined_room(self.distributor, user, event.room_id)
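+        # Only generate push actions for events that are neither backfilled
+        # nor outliers: historical or out-of-graph events shouldn't create
+        # fresh unread notifications.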
+ if not backfilled and not event.internal_metadata.is_outlier():
+ action_generator = ActionGenerator(self.store)
+ yield action_generator.handle_push_actions_for_event(
+ event, self
+ )
+
@defer.inlineCallbacks
def _filter_events_for_server(self, server_name, room_id, events):
event_to_state = yield self.store.get_state_for_events(
@@ -1684,7 +1692,7 @@ class FederationHandler(BaseHandler):
self.auth.check(event, context.current_state)
yield self._validate_keyserver(event, auth_events=context.current_state)
member_handler = self.hs.get_handlers().room_member_handler
- yield member_handler.change_membership(event, context)
+ yield member_handler.send_membership_event(event, context)
else:
destinations = set([x.split(":", 1)[-1] for x in (sender, room_id)])
yield self.replication_layer.forward_third_party_invite(
@@ -1713,7 +1721,7 @@ class FederationHandler(BaseHandler):
# TODO: Make sure the signatures actually are correct.
event.signatures.update(returned_invite.signatures)
member_handler = self.hs.get_handlers().room_member_handler
- yield member_handler.change_membership(event, context)
+ yield member_handler.send_membership_event(event, context)
@defer.inlineCallbacks
def add_display_name_to_third_party_invite(self, event_dict, event, context):
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 5805190ce8..4c7bf2bef3 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -174,30 +174,25 @@ class MessageHandler(BaseHandler):
defer.returnValue(chunk)
@defer.inlineCallbacks
- def create_and_send_event(self, event_dict, ratelimit=True,
- token_id=None, txn_id=None, is_guest=False):
- """ Given a dict from a client, create and handle a new event.
+ def create_event(self, event_dict, token_id=None, txn_id=None):
+ """
+ Given a dict from a client, create a new event.
        Creates a FrozenEvent object, filling out auth_events, prev_events,
etc.
Adds display names to Join membership events.
- Persists and notifies local clients and federation.
-
Args:
event_dict (dict): An entire event
+
+ Returns:
+            Tuple of (FrozenEvent, Context): the created event and its context.
"""
builder = self.event_builder_factory.new(event_dict)
self.validator.validate_new(builder)
- if ratelimit:
- self.ratelimit(builder.user_id)
- # TODO(paul): Why does 'event' not have a 'user' object?
- user = UserID.from_string(builder.user_id)
- assert self.hs.is_mine(user), "User must be our own: %s" % (user,)
-
if builder.type == EventTypes.Member:
membership = builder.content.get("membership", None)
if membership == Membership.JOIN:
@@ -216,6 +211,25 @@ class MessageHandler(BaseHandler):
event, context = yield self._create_new_client_event(
builder=builder,
)
+ defer.returnValue((event, context))
+
+ @defer.inlineCallbacks
+ def send_event(self, event, context, ratelimit=True, is_guest=False):
+ """
+ Persists and notifies local clients and federation of an event.
+
+ Args:
+            event (FrozenEvent): the event to send.
+            context (Context): the context of the event.
+ ratelimit (bool): Whether to rate limit this send.
+ is_guest (bool): Whether the sender is a guest.
+ """
+ user = UserID.from_string(event.sender)
+
+ assert self.hs.is_mine(user), "User must be our own: %s" % (user,)
+
+ if ratelimit:
+ self.ratelimit(event.sender)
if event.is_state():
prev_state = context.current_state.get((event.type, event.state_key))
@@ -229,7 +243,7 @@ class MessageHandler(BaseHandler):
if event.type == EventTypes.Member:
member_handler = self.hs.get_handlers().room_member_handler
- yield member_handler.change_membership(event, context, is_guest=is_guest)
+ yield member_handler.send_membership_event(event, context, is_guest=is_guest)
else:
yield self.handle_new_client_event(
event=event,
@@ -241,6 +255,25 @@ class MessageHandler(BaseHandler):
with PreserveLoggingContext():
presence.bump_presence_active_time(user)
+ @defer.inlineCallbacks
+ def create_and_send_event(self, event_dict, ratelimit=True,
+ token_id=None, txn_id=None, is_guest=False):
+ """
+ Creates an event, then sends it.
+
+ See self.create_event and self.send_event.
+ """
+ event, context = yield self.create_event(
+ event_dict,
+ token_id=token_id,
+ txn_id=txn_id
+ )
+ yield self.send_event(
+ event,
+ context,
+ ratelimit=ratelimit,
+ is_guest=is_guest
+ )
defer.returnValue(event)
@defer.inlineCallbacks
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 6f111ff63e..8e601b052b 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -53,7 +53,8 @@ class RegistrationHandler(BaseHandler):
raise SynapseError(
400,
"User ID must only contain characters which do not"
- " require URL encoding."
+ " require URL encoding.",
+ Codes.INVALID_USERNAME
)
user = UserID(localpart, self.hs.hostname)
@@ -84,7 +85,8 @@ class RegistrationHandler(BaseHandler):
localpart=None,
password=None,
generate_token=True,
- guest_access_token=None
+ guest_access_token=None,
+ make_guest=False
):
"""Registers a new client on the server.
@@ -118,6 +120,7 @@ class RegistrationHandler(BaseHandler):
token=token,
password_hash=password_hash,
was_guest=guest_access_token is not None,
+ make_guest=make_guest,
)
yield registered_user(self.distributor, user)
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 48a07e4e35..a1baf9d200 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -22,7 +22,7 @@ from synapse.types import UserID, RoomAlias, RoomID
from synapse.api.constants import (
EventTypes, Membership, JoinRules, RoomCreationPreset,
)
-from synapse.api.errors import AuthError, StoreError, SynapseError
+from synapse.api.errors import AuthError, StoreError, SynapseError, Codes
from synapse.util import stringutils, unwrapFirstError
from synapse.util.async import run_on_reactor
@@ -397,7 +397,58 @@ class RoomMemberHandler(BaseHandler):
remotedomains.add(member.domain)
@defer.inlineCallbacks
- def change_membership(self, event, context, do_auth=True, is_guest=False):
+ def update_membership(self, requester, target, room_id, action, txn_id=None):
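+        # Translates a high-level membership action (e.g. "kick", "unban",
+        # "forget") into the effective membership state, creates and sends the
+        # m.room.member event, and for "forget" also forgets the room.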
+ effective_membership_state = action
+ if action in ["kick", "unban"]:
+ effective_membership_state = "leave"
+ elif action == "forget":
+ effective_membership_state = "leave"
+
+ msg_handler = self.hs.get_handlers().message_handler
+
+ content = {"membership": unicode(effective_membership_state)}
+ if requester.is_guest:
+ content["kind"] = "guest"
+
+ event, context = yield msg_handler.create_event(
+ {
+ "type": EventTypes.Member,
+ "content": content,
+ "room_id": room_id,
+ "sender": requester.user.to_string(),
+ "state_key": target.to_string(),
+ },
+ token_id=requester.access_token_id,
+ txn_id=txn_id,
+ )
+
+ old_state = context.current_state.get((EventTypes.Member, event.state_key))
+ old_membership = old_state.content.get("membership") if old_state else None
+ if action == "unban" and old_membership != "ban":
+ raise SynapseError(
+ 403,
+ "Cannot unban user who was not banned (membership=%s)" % old_membership,
+ errcode=Codes.BAD_STATE
+ )
+ if old_membership == "ban" and action != "unban":
+ raise SynapseError(
+ 403,
+                "Cannot %s user who is banned" % (action,),
+ errcode=Codes.BAD_STATE
+ )
+
+ yield msg_handler.send_event(
+ event,
+ context,
+ ratelimit=True,
+ is_guest=requester.is_guest
+ )
+
+ if action == "forget":
+ yield self.forget(requester.user, room_id)
+
+ @defer.inlineCallbacks
+ def send_membership_event(self, event, context, is_guest=False):
""" Change the membership status of a user in a room.
Args:
@@ -432,7 +483,7 @@ class RoomMemberHandler(BaseHandler):
if not is_guest_access_allowed:
raise AuthError(403, "Guest access not allowed")
- yield self._do_join(event, context, do_auth=do_auth)
+ yield self._do_join(event, context)
else:
if event.membership == Membership.LEAVE:
is_host_in_room = yield self.is_host_in_room(room_id, context)
@@ -459,9 +510,7 @@ class RoomMemberHandler(BaseHandler):
yield self._do_local_membership_update(
event,
- membership=event.content["membership"],
context=context,
- do_auth=do_auth,
)
if prev_state and prev_state.membership == Membership.JOIN:
@@ -497,12 +546,12 @@ class RoomMemberHandler(BaseHandler):
})
event, context = yield self._create_new_client_event(builder)
- yield self._do_join(event, context, room_hosts=hosts, do_auth=True)
+ yield self._do_join(event, context, room_hosts=hosts)
defer.returnValue({"room_id": room_id})
@defer.inlineCallbacks
- def _do_join(self, event, context, room_hosts=None, do_auth=True):
+ def _do_join(self, event, context, room_hosts=None):
room_id = event.room_id
# XXX: We don't do an auth check if we are doing an invite
@@ -536,9 +585,7 @@ class RoomMemberHandler(BaseHandler):
yield self._do_local_membership_update(
event,
- membership=event.content["membership"],
context=context,
- do_auth=do_auth,
)
prev_state = context.current_state.get((event.type, event.state_key))
@@ -603,8 +650,7 @@ class RoomMemberHandler(BaseHandler):
defer.returnValue(room_ids)
@defer.inlineCallbacks
- def _do_local_membership_update(self, event, membership, context,
- do_auth):
+ def _do_local_membership_update(self, event, context):
yield run_on_reactor()
target_user = UserID.from_string(event.state_key)
@@ -613,7 +659,6 @@ class RoomMemberHandler(BaseHandler):
event,
context,
extra_users=[target_user],
- suppress_auth=(not do_auth),
)
@defer.inlineCallbacks
@@ -880,28 +925,39 @@ class RoomContextHandler(BaseHandler):
(excluding state).
Returns:
- dict
+ dict, or None if the event isn't found
"""
before_limit = math.floor(limit/2.)
after_limit = limit - before_limit
now_token = yield self.hs.get_event_sources().get_current_token()
+ def filter_evts(events):
+ return self._filter_events_for_client(
+ user.to_string(),
+ events,
+ is_guest=is_guest)
+
+ event = yield self.store.get_event(event_id, get_prev_content=True,
+ allow_none=True)
+ if not event:
+ defer.returnValue(None)
+ return
+
+        filtered = yield filter_evts([event])
+ if not filtered:
+ raise AuthError(
+ 403,
+ "You don't have permission to access that event."
+ )
+
results = yield self.store.get_events_around(
room_id, event_id, before_limit, after_limit
)
- results["events_before"] = yield self._filter_events_for_client(
- user.to_string(),
- results["events_before"],
- is_guest=is_guest,
- )
-
- results["events_after"] = yield self._filter_events_for_client(
- user.to_string(),
- results["events_after"],
- is_guest=is_guest,
- )
+ results["events_before"] = yield filter_evts(results["events_before"])
+ results["events_after"] = yield filter_evts(results["events_after"])
+ results["event"] = event
if results["events_after"]:
last_event_id = results["events_after"][-1].event_id
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 3bc18b4338..d2864977b0 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -54,6 +54,8 @@ class JoinedSyncResult(collections.namedtuple("JoinedSyncResult", [
"state", # dict[(str, str), FrozenEvent]
"ephemeral",
"account_data",
+ "unread_notification_count",
+ "unread_highlight_count",
])):
__slots__ = []
@@ -66,6 +68,8 @@ class JoinedSyncResult(collections.namedtuple("JoinedSyncResult", [
or self.state
or self.ephemeral
or self.account_data
+ # nb the notification count does not, er, count: if there's nothing
+ # else in the result, we don't need to send it.
)
@@ -163,6 +167,18 @@ class SyncHandler(BaseHandler):
else:
return self.incremental_sync_with_gap(sync_config, since_token)
+ def last_read_event_id_for_room_and_user(self, room_id, user_id, ephemeral_by_room):
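+        # Scan the room's ephemeral events for m.receipt events and return the
+        # event_id carrying this user's m.read receipt, or None if the user
+        # has no read receipt in the room.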
+ if room_id not in ephemeral_by_room:
+ return None
+ for e in ephemeral_by_room[room_id]:
+ if e['type'] != 'm.receipt':
+ continue
+ for receipt_event_id, val in e['content'].items():
+ if 'm.read' in val:
+ if user_id in val['m.read']:
+ return receipt_event_id
+ return None
+
@defer.inlineCallbacks
def full_state_sync(self, sync_config, timeline_since_token):
"""Get a sync for a client which is starting without any state.
@@ -274,6 +290,18 @@ class SyncHandler(BaseHandler):
room_id, sync_config, now_token, since_token=timeline_since_token
)
+ notifs = yield self.unread_notifs_for_room_id(
+ room_id, sync_config, ephemeral_by_room
+ )
+
+ notif_count = None
+ highlight_count = None
+ if notifs is not None:
+ notif_count = len(notifs)
+ highlight_count = len([
+ 1 for notif in notifs if _action_has_highlight(notif["actions"])
+ ])
+
current_state = yield self.get_state_at(room_id, now_token)
defer.returnValue(JoinedSyncResult(
@@ -284,6 +312,8 @@ class SyncHandler(BaseHandler):
account_data=self.account_data_for_room(
room_id, tags_by_room, account_data_by_room
),
+ unread_notification_count=notif_count,
+ unread_highlight_count=highlight_count,
))
def account_data_for_user(self, account_data):
@@ -423,6 +453,13 @@ class SyncHandler(BaseHandler):
)
now_token = now_token.copy_and_replace("presence_key", presence_key)
+ # We now fetch all ephemeral events for this room in order to get
+ # this users current read receipt. This could almost certainly be
+        # this user's current read receipt. This could almost certainly be
+ _, all_ephemeral_by_room = yield self.ephemeral_by_room(
+ sync_config, now_token
+ )
+
now_token, ephemeral_by_room = yield self.ephemeral_by_room(
sync_config, now_token, since_token
)
@@ -496,6 +533,18 @@ class SyncHandler(BaseHandler):
else:
prev_batch = now_token
+ notifs = yield self.unread_notifs_for_room_id(
+ room_id, sync_config, all_ephemeral_by_room
+ )
+
+ notif_count = None
+ highlight_count = None
+ if notifs is not None:
+ notif_count = len(notifs)
+ highlight_count = len([
+ 1 for notif in notifs if _action_has_highlight(notif["actions"])
+ ])
+
just_joined = yield self.check_joined_room(sync_config, state)
if just_joined:
logger.debug("User has just joined %s: needs full state",
@@ -516,6 +565,8 @@ class SyncHandler(BaseHandler):
account_data=self.account_data_for_room(
room_id, tags_by_room, account_data_by_room
),
+ unread_notification_count=notif_count,
+ unread_highlight_count=highlight_count,
)
logger.debug("Result for room %s: %r", room_id, room_sync)
@@ -537,7 +588,8 @@ class SyncHandler(BaseHandler):
for room_id in joined_room_ids:
room_sync = yield self.incremental_sync_with_gap_for_room(
room_id, sync_config, since_token, now_token,
- ephemeral_by_room, tags_by_room, account_data_by_room
+ ephemeral_by_room, tags_by_room, account_data_by_room,
+ all_ephemeral_by_room=all_ephemeral_by_room,
)
if room_sync:
joined.append(room_sync)
@@ -547,7 +599,8 @@ class SyncHandler(BaseHandler):
sync_config, leave_event, since_token, tags_by_room,
account_data_by_room
)
- archived.append(room_sync)
+ if room_sync:
+ archived.append(room_sync)
invited = [
InvitedSyncResult(room_id=event.room_id, invite=event)
@@ -616,7 +669,8 @@ class SyncHandler(BaseHandler):
def incremental_sync_with_gap_for_room(self, room_id, sync_config,
since_token, now_token,
ephemeral_by_room, tags_by_room,
- account_data_by_room):
+ account_data_by_room,
+ all_ephemeral_by_room):
""" Get the incremental delta needed to bring the client up to date for
the room. Gives the client the most recent events and the changes to
state.
@@ -632,7 +686,7 @@ class SyncHandler(BaseHandler):
room_id, sync_config, now_token, since_token,
)
- logging.debug("Recents %r", batch)
+ logger.debug("Recents %r", batch)
current_state = yield self.get_state_at(room_id, now_token)
@@ -650,6 +704,18 @@ class SyncHandler(BaseHandler):
if just_joined:
state = yield self.get_state_at(room_id, now_token)
+ notifs = yield self.unread_notifs_for_room_id(
+ room_id, sync_config, all_ephemeral_by_room
+ )
+
+ notif_count = None
+ highlight_count = None
+ if notifs is not None:
+ notif_count = len(notifs)
+ highlight_count = len([
+ 1 for notif in notifs if _action_has_highlight(notif["actions"])
+ ])
+
room_sync = JoinedSyncResult(
room_id=room_id,
timeline=batch,
@@ -658,6 +724,8 @@ class SyncHandler(BaseHandler):
account_data=self.account_data_for_room(
room_id, tags_by_room, account_data_by_room
),
+ unread_notification_count=notif_count,
+ unread_highlight_count=highlight_count,
)
logger.debug("Room sync: %r", room_sync)
@@ -680,11 +748,14 @@ class SyncHandler(BaseHandler):
leave_token = since_token.copy_and_replace("room_key", stream_token)
+ if since_token.is_after(leave_token):
+ defer.returnValue(None)
+
batch = yield self.load_filtered_recents(
leave_event.room_id, sync_config, leave_token, since_token,
)
- logging.debug("Recents %r", batch)
+ logger.debug("Recents %r", batch)
state_events_at_leave = yield self.store.get_state_for_event(
leave_event.event_id
@@ -788,3 +859,31 @@ class SyncHandler(BaseHandler):
if join_event.content["membership"] == Membership.JOIN:
return True
return False
+
+ @defer.inlineCallbacks
+ def unread_notifs_for_room_id(self, room_id, sync_config, ephemeral_by_room):
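+        # Returns the list of unread push actions since the user's last read
+        # receipt in this room, or None if no read receipt was found, so that
+        # callers can distinguish "no information" from "zero notifications".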
+ last_unread_event_id = self.last_read_event_id_for_room_and_user(
+ room_id, sync_config.user.to_string(), ephemeral_by_room
+ )
+
+ notifs = []
+ if last_unread_event_id:
+ notifs = yield self.store.get_unread_event_push_actions_by_room_for_user(
+ room_id, sync_config.user.to_string(), last_unread_event_id
+ )
+ defer.returnValue(notifs)
+
+ # There is no new information in this period, so your notification
+ # count is whatever it was last time.
+ defer.returnValue(None)
+
+
+def _action_has_highlight(actions):
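+    # Actions are either strings such as 'notify' or dicts such as
+    # {'set_tweak': 'highlight'}; string entries have no .get() and raise
+    # AttributeError, which is deliberately swallowed below.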
+ for action in actions:
+ try:
+ if action.get("set_tweak", None) == "highlight":
+ return action.get("value", True)
+ except AttributeError:
+ pass
+
+ return False
diff --git a/synapse/push/__init__.py b/synapse/push/__init__.py
index 7dc656b7cb..a5dc84160c 100644
--- a/synapse/push/__init__.py
+++ b/synapse/push/__init__.py
@@ -27,12 +27,15 @@ import random
logger = logging.getLogger(__name__)
+# Pushers could now be moved to pull out of the event_push_actions table instead
+# of listening on the event stream: this would avoid them having to run the
+# rules again.
class Pusher(object):
INITIAL_BACKOFF = 1000
MAX_BACKOFF = 60 * 60 * 1000
GIVE_UP_AFTER = 24 * 60 * 60 * 1000
- def __init__(self, _hs, profile_tag, user_name, app_id,
+ def __init__(self, _hs, profile_tag, user_id, app_id,
app_display_name, device_display_name, pushkey, pushkey_ts,
data, last_token, last_success, failing_since):
self.hs = _hs
@@ -40,7 +43,7 @@ class Pusher(object):
self.store = self.hs.get_datastore()
self.clock = self.hs.get_clock()
self.profile_tag = profile_tag
- self.user_name = user_name
+ self.user_id = user_id
self.app_id = app_id
self.app_display_name = app_display_name
self.device_display_name = device_display_name
@@ -89,15 +92,15 @@ class Pusher(object):
# we fail to dispatch the push)
config = PaginationConfig(from_token=None, limit='1')
chunk = yield self.evStreamHandler.get_stream(
- self.user_name, config, timeout=0, affect_presence=False,
+ self.user_id, config, timeout=0, affect_presence=False,
only_room_events=True
)
self.last_token = chunk['end']
self.store.update_pusher_last_token(
- self.app_id, self.pushkey, self.user_name, self.last_token
+ self.app_id, self.pushkey, self.user_id, self.last_token
)
logger.info("Pusher %s for user %s starting from token %s",
- self.pushkey, self.user_name, self.last_token)
+ self.pushkey, self.user_id, self.last_token)
wait = 0
while self.alive:
@@ -122,7 +125,7 @@ class Pusher(object):
config = PaginationConfig(from_token=from_tok, limit='1')
timeout = (300 + random.randint(-60, 60)) * 1000
chunk = yield self.evStreamHandler.get_stream(
- self.user_name, config, timeout=timeout, affect_presence=False,
+ self.user_id, config, timeout=timeout, affect_presence=False,
only_room_events=True
)
@@ -139,7 +142,7 @@ class Pusher(object):
yield self.store.update_pusher_last_token(
self.app_id,
self.pushkey,
- self.user_name,
+ self.user_id,
self.last_token
)
return
@@ -150,28 +153,14 @@ class Pusher(object):
processed = False
rule_evaluator = yield \
- push_rule_evaluator.evaluator_for_user_name_and_profile_tag(
- self.user_name, self.profile_tag, single_event['room_id'], self.store
+ push_rule_evaluator.evaluator_for_user_id_and_profile_tag(
+ self.user_id, self.profile_tag, single_event['room_id'], self.store
)
actions = yield rule_evaluator.actions_for_event(single_event)
tweaks = rule_evaluator.tweaks_for_actions(actions)
- if len(actions) == 0:
- logger.warn("Empty actions! Using default action.")
- actions = Pusher.DEFAULT_ACTIONS
-
- if 'notify' not in actions and 'dont_notify' not in actions:
- logger.warn("Neither notify nor dont_notify in actions: adding default")
- actions.extend(Pusher.DEFAULT_ACTIONS)
-
- if 'dont_notify' in actions:
- logger.debug(
- "%s for %s: dont_notify",
- single_event['event_id'], self.user_name
- )
- processed = True
- else:
+ if 'notify' in actions:
rejected = yield self.dispatch_push(single_event, tweaks)
self.has_unread = True
if isinstance(rejected, list) or isinstance(rejected, tuple):
@@ -190,8 +179,10 @@ class Pusher(object):
pk
)
yield self.hs.get_pusherpool().remove_pusher(
- self.app_id, pk, self.user_name
+ self.app_id, pk, self.user_id
)
+ else:
+ processed = True
if not self.alive:
return
@@ -202,7 +193,7 @@ class Pusher(object):
yield self.store.update_pusher_last_token_and_success(
self.app_id,
self.pushkey,
- self.user_name,
+ self.user_id,
self.last_token,
self.clock.time_msec()
)
@@ -211,7 +202,7 @@ class Pusher(object):
yield self.store.update_pusher_failing_since(
self.app_id,
self.pushkey,
- self.user_name,
+ self.user_id,
self.failing_since)
else:
if not self.failing_since:
@@ -219,7 +210,7 @@ class Pusher(object):
yield self.store.update_pusher_failing_since(
self.app_id,
self.pushkey,
- self.user_name,
+ self.user_id,
self.failing_since
)
@@ -231,13 +222,13 @@ class Pusher(object):
# of old notifications.
logger.warn("Giving up on a notification to user %s, "
"pushkey %s",
- self.user_name, self.pushkey)
+ self.user_id, self.pushkey)
self.backoff_delay = Pusher.INITIAL_BACKOFF
self.last_token = chunk['end']
yield self.store.update_pusher_last_token(
self.app_id,
self.pushkey,
- self.user_name,
+ self.user_id,
self.last_token
)
@@ -245,14 +236,14 @@ class Pusher(object):
yield self.store.update_pusher_failing_since(
self.app_id,
self.pushkey,
- self.user_name,
+ self.user_id,
self.failing_since
)
else:
logger.warn("Failed to dispatch push for user %s "
"(failing for %dms)."
"Trying again in %dms",
- self.user_name,
+ self.user_id,
self.clock.time_msec() - self.failing_since,
self.backoff_delay)
yield synapse.util.async.sleep(self.backoff_delay / 1000.0)
@@ -289,7 +280,7 @@ class Pusher(object):
if last_active > self.last_last_active_time:
self.last_last_active_time = last_active
if self.has_unread:
- logger.info("Resetting badge count for %s", self.user_name)
+ logger.info("Resetting badge count for %s", self.user_id)
self.reset_badge_count()
self.has_unread = False
diff --git a/synapse/push/action_generator.py b/synapse/push/action_generator.py
new file mode 100644
index 0000000000..4cf94f6c61
--- /dev/null
+++ b/synapse/push/action_generator.py
@@ -0,0 +1,55 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+import bulk_push_rule_evaluator
+
+import logging
+
+from synapse.api.constants import EventTypes
+
+logger = logging.getLogger(__name__)
+
+
+class ActionGenerator:
+ def __init__(self, store):
+ self.store = store
+ # really we want to get all user ids and all profile tags too,
+ # since we want the actions for each profile tag for every user and
+ # also actions for a client with no profile tag for each user.
+        # Currently the event stream doesn't support profile tags,
+        # so we just run the rules for a client with no profile
+        # tag (i.e. we just need all the users).
+
+ @defer.inlineCallbacks
+ def handle_push_actions_for_event(self, event, handler):
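+        # If this event redacts another event, drop any push actions already
+        # stored for the redacted event, then run the push rules for everyone
+        # in the room in bulk and store the resulting actions per user.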
+ if event.type == EventTypes.Redaction and event.redacts is not None:
+ yield self.store.remove_push_actions_for_event_id(
+ event.room_id, event.redacts
+ )
+
+ bulk_evaluator = yield bulk_push_rule_evaluator.evaluator_for_room_id(
+ event.room_id, self.store
+ )
+
+ actions_by_user = yield bulk_evaluator.action_for_event_by_user(event, handler)
+
+ yield self.store.set_push_actions_for_event_and_users(
+ event,
+ [
+ (uid, None, actions) for uid, actions in actions_by_user.items()
+ ]
+ )
diff --git a/synapse/push/baserules.py b/synapse/push/baserules.py
index 57de0e34b4..3b526c4e33 100644
--- a/synapse/push/baserules.py
+++ b/synapse/push/baserules.py
@@ -15,27 +15,25 @@
from synapse.push.rulekinds import PRIORITY_CLASS_MAP, PRIORITY_CLASS_INVERSE_MAP
-def list_with_base_rules(rawrules, user_name):
+def list_with_base_rules(rawrules):
ruleslist = []
# shove the server default rules for each kind onto the end of each
current_prio_class = PRIORITY_CLASS_INVERSE_MAP.keys()[-1]
ruleslist.extend(make_base_prepend_rules(
- user_name, PRIORITY_CLASS_INVERSE_MAP[current_prio_class]
+ PRIORITY_CLASS_INVERSE_MAP[current_prio_class]
))
for r in rawrules:
if r['priority_class'] < current_prio_class:
while r['priority_class'] < current_prio_class:
ruleslist.extend(make_base_append_rules(
- user_name,
PRIORITY_CLASS_INVERSE_MAP[current_prio_class]
))
current_prio_class -= 1
if current_prio_class > 0:
ruleslist.extend(make_base_prepend_rules(
- user_name,
PRIORITY_CLASS_INVERSE_MAP[current_prio_class]
))
@@ -43,223 +41,232 @@ def list_with_base_rules(rawrules, user_name):
while current_prio_class > 0:
ruleslist.extend(make_base_append_rules(
- user_name,
PRIORITY_CLASS_INVERSE_MAP[current_prio_class]
))
current_prio_class -= 1
if current_prio_class > 0:
ruleslist.extend(make_base_prepend_rules(
- user_name,
PRIORITY_CLASS_INVERSE_MAP[current_prio_class]
))
return ruleslist
-def make_base_append_rules(user, kind):
+def make_base_append_rules(kind):
rules = []
if kind == 'override':
- rules = make_base_append_override_rules()
+        rules = BASE_APPEND_OVERRIDE_RULES
elif kind == 'underride':
- rules = make_base_append_underride_rules(user)
+ rules = BASE_APPEND_UNDERRIDE_RULES
elif kind == 'content':
- rules = make_base_append_content_rules(user)
-
- for r in rules:
- r['priority_class'] = PRIORITY_CLASS_MAP[kind]
- r['default'] = True # Deprecated, left for backwards compat
+ rules = BASE_APPEND_CONTENT_RULES
return rules
-def make_base_prepend_rules(user, kind):
+def make_base_prepend_rules(kind):
rules = []
if kind == 'override':
- rules = make_base_prepend_override_rules()
-
- for r in rules:
- r['priority_class'] = PRIORITY_CLASS_MAP[kind]
- r['default'] = True # Deprecated, left for backwards compat
+ rules = BASE_PREPEND_OVERRIDE_RULES
return rules
-def make_base_append_content_rules(user):
- return [
- {
- 'rule_id': 'global/content/.m.rule.contains_user_name',
- 'conditions': [
- {
- 'kind': 'event_match',
- 'key': 'content.body',
- 'pattern': user.localpart, # Matrix ID match
- }
- ],
- 'actions': [
- 'notify',
- {
- 'set_tweak': 'sound',
- 'value': 'default',
- }, {
- 'set_tweak': 'highlight'
- }
- ]
- },
- ]
+BASE_APPEND_CONTENT_RULES = [
+ {
+ 'rule_id': 'global/content/.m.rule.contains_user_name',
+ 'conditions': [
+ {
+ 'kind': 'event_match',
+ 'key': 'content.body',
+ 'pattern_type': 'user_localpart'
+ }
+ ],
+ 'actions': [
+ 'notify',
+ {
+ 'set_tweak': 'sound',
+ 'value': 'default',
+ }, {
+ 'set_tweak': 'highlight'
+ }
+ ]
+ },
+]
+
+
+BASE_PREPEND_OVERRIDE_RULES = [
+ {
+ 'rule_id': 'global/override/.m.rule.master',
+ 'enabled': False,
+ 'conditions': [],
+ 'actions': [
+ "dont_notify"
+ ]
+ }
+]
+
+
+BASE_APPEND_OVERRIDE_RULES = [
+ {
+ 'rule_id': 'global/override/.m.rule.suppress_notices',
+ 'conditions': [
+ {
+ 'kind': 'event_match',
+ 'key': 'content.msgtype',
+ 'pattern': 'm.notice',
+ '_id': '_suppress_notices',
+ }
+ ],
+ 'actions': [
+ 'dont_notify',
+ ]
+ }
+]
+
+BASE_APPEND_UNDERRIDE_RULES = [
+ {
+ 'rule_id': 'global/underride/.m.rule.call',
+ 'conditions': [
+ {
+ 'kind': 'event_match',
+ 'key': 'type',
+ 'pattern': 'm.call.invite',
+ '_id': '_call',
+ }
+ ],
+ 'actions': [
+ 'notify',
+ {
+ 'set_tweak': 'sound',
+ 'value': 'ring'
+ }, {
+ 'set_tweak': 'highlight',
+ 'value': False
+ }
+ ]
+ },
+ {
+ 'rule_id': 'global/underride/.m.rule.contains_display_name',
+ 'conditions': [
+ {
+ 'kind': 'contains_display_name'
+ }
+ ],
+ 'actions': [
+ 'notify',
+ {
+ 'set_tweak': 'sound',
+ 'value': 'default'
+ }, {
+ 'set_tweak': 'highlight'
+ }
+ ]
+ },
+ {
+ 'rule_id': 'global/underride/.m.rule.room_one_to_one',
+ 'conditions': [
+ {
+ 'kind': 'room_member_count',
+ 'is': '2',
+ '_id': 'member_count',
+ }
+ ],
+ 'actions': [
+ 'notify',
+ {
+ 'set_tweak': 'sound',
+ 'value': 'default'
+ }, {
+ 'set_tweak': 'highlight',
+ 'value': False
+ }
+ ]
+ },
+ {
+ 'rule_id': 'global/underride/.m.rule.invite_for_me',
+ 'conditions': [
+ {
+ 'kind': 'event_match',
+ 'key': 'type',
+ 'pattern': 'm.room.member',
+ '_id': '_member',
+ },
+ {
+ 'kind': 'event_match',
+ 'key': 'content.membership',
+ 'pattern': 'invite',
+ '_id': '_invite_member',
+ },
+ {
+ 'kind': 'event_match',
+ 'key': 'state_key',
+ 'pattern_type': 'user_id'
+ },
+ ],
+ 'actions': [
+ 'notify',
+ {
+ 'set_tweak': 'sound',
+ 'value': 'default'
+ }, {
+ 'set_tweak': 'highlight',
+ 'value': False
+ }
+ ]
+ },
+ {
+ 'rule_id': 'global/underride/.m.rule.member_event',
+ 'conditions': [
+ {
+ 'kind': 'event_match',
+ 'key': 'type',
+ 'pattern': 'm.room.member',
+ '_id': '_member',
+ }
+ ],
+ 'actions': [
+ 'notify', {
+ 'set_tweak': 'highlight',
+ 'value': False
+ }
+ ]
+ },
+ {
+ 'rule_id': 'global/underride/.m.rule.message',
+ 'conditions': [
+ {
+ 'kind': 'event_match',
+ 'key': 'type',
+ 'pattern': 'm.room.message',
+ '_id': '_message',
+ }
+ ],
+ 'actions': [
+ 'notify', {
+ 'set_tweak': 'highlight',
+ 'value': False
+ }
+ ]
+ }
+]
-def make_base_prepend_override_rules():
- return [
- {
- 'rule_id': 'global/override/.m.rule.master',
- 'enabled': False,
- 'conditions': [],
- 'actions': [
- "dont_notify"
- ]
- }
- ]
+for r in BASE_APPEND_CONTENT_RULES:
+ r['priority_class'] = PRIORITY_CLASS_MAP['content']
+ r['default'] = True
-def make_base_append_override_rules():
- return [
- {
- 'rule_id': 'global/override/.m.rule.suppress_notices',
- 'conditions': [
- {
- 'kind': 'event_match',
- 'key': 'content.msgtype',
- 'pattern': 'm.notice',
- }
- ],
- 'actions': [
- 'dont_notify',
- ]
- }
- ]
+for r in BASE_PREPEND_OVERRIDE_RULES:
+ r['priority_class'] = PRIORITY_CLASS_MAP['override']
+ r['default'] = True
+for r in BASE_APPEND_OVERRIDE_RULES:
+ r['priority_class'] = PRIORITY_CLASS_MAP['override']
+ r['default'] = True
-def make_base_append_underride_rules(user):
- return [
- {
- 'rule_id': 'global/underride/.m.rule.call',
- 'conditions': [
- {
- 'kind': 'event_match',
- 'key': 'type',
- 'pattern': 'm.call.invite',
- }
- ],
- 'actions': [
- 'notify',
- {
- 'set_tweak': 'sound',
- 'value': 'ring'
- }, {
- 'set_tweak': 'highlight',
- 'value': False
- }
- ]
- },
- {
- 'rule_id': 'global/underride/.m.rule.contains_display_name',
- 'conditions': [
- {
- 'kind': 'contains_display_name'
- }
- ],
- 'actions': [
- 'notify',
- {
- 'set_tweak': 'sound',
- 'value': 'default'
- }, {
- 'set_tweak': 'highlight'
- }
- ]
- },
- {
- 'rule_id': 'global/underride/.m.rule.room_one_to_one',
- 'conditions': [
- {
- 'kind': 'room_member_count',
- 'is': '2'
- }
- ],
- 'actions': [
- 'notify',
- {
- 'set_tweak': 'sound',
- 'value': 'default'
- }, {
- 'set_tweak': 'highlight',
- 'value': False
- }
- ]
- },
- {
- 'rule_id': 'global/underride/.m.rule.invite_for_me',
- 'conditions': [
- {
- 'kind': 'event_match',
- 'key': 'type',
- 'pattern': 'm.room.member',
- },
- {
- 'kind': 'event_match',
- 'key': 'content.membership',
- 'pattern': 'invite',
- },
- {
- 'kind': 'event_match',
- 'key': 'state_key',
- 'pattern': user.to_string(),
- },
- ],
- 'actions': [
- 'notify',
- {
- 'set_tweak': 'sound',
- 'value': 'default'
- }, {
- 'set_tweak': 'highlight',
- 'value': False
- }
- ]
- },
- {
- 'rule_id': 'global/underride/.m.rule.member_event',
- 'conditions': [
- {
- 'kind': 'event_match',
- 'key': 'type',
- 'pattern': 'm.room.member',
- }
- ],
- 'actions': [
- 'notify', {
- 'set_tweak': 'highlight',
- 'value': False
- }
- ]
- },
- {
- 'rule_id': 'global/underride/.m.rule.message',
- 'enabled': False,
- 'conditions': [
- {
- 'kind': 'event_match',
- 'key': 'type',
- 'pattern': 'm.room.message',
- }
- ],
- 'actions': [
- 'notify', {
- 'set_tweak': 'highlight',
- 'value': False
- }
- ]
- }
- ]
+for r in BASE_APPEND_UNDERRIDE_RULES:
+ r['priority_class'] = PRIORITY_CLASS_MAP['underride']
+ r['default'] = True
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
new file mode 100644
index 0000000000..efd686fa6e
--- /dev/null
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -0,0 +1,137 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import ujson as json
+
+from twisted.internet import defer
+
+import baserules
+from push_rule_evaluator import PushRuleEvaluatorForEvent
+
+from synapse.api.constants import EventTypes
+
+
+logger = logging.getLogger(__name__)
+
+
+def decode_rule_json(rule):
+ rule['conditions'] = json.loads(rule['conditions'])
+ rule['actions'] = json.loads(rule['actions'])
+ return rule
+
+
+@defer.inlineCallbacks
+def _get_rules(room_id, user_ids, store):
+ rules_by_user = yield store.bulk_get_push_rules(user_ids)
+ rules_by_user = {
+ uid: baserules.list_with_base_rules([
+ decode_rule_json(rule_list)
+ for rule_list in rules_by_user.get(uid, [])
+ ])
+ for uid in user_ids
+ }
+ defer.returnValue(rules_by_user)
+
+
+@defer.inlineCallbacks
+def evaluator_for_room_id(room_id, store):
+ users = yield store.get_users_in_room(room_id)
+ rules_by_user = yield _get_rules(room_id, users, store)
+
+ defer.returnValue(BulkPushRuleEvaluator(
+ room_id, rules_by_user, users, store
+ ))
+
+
+class BulkPushRuleEvaluator:
+ """
+ Runs push rules for all users in a room.
+ This is faster than running PushRuleEvaluator for each user because it
+ fetches all the rules for all the users in one (batched) db query
+ rather than doing multiple queries per-user. It currently uses
+ the same logic to run the actual rules, but could be optimised further
+ (see https://matrix.org/jira/browse/SYN-562)
+ """
+ def __init__(self, room_id, rules_by_user, users_in_room, store):
+ self.room_id = room_id
+ self.rules_by_user = rules_by_user
+ self.users_in_room = users_in_room
+ self.store = store
+
+ @defer.inlineCallbacks
+ def action_for_event_by_user(self, event, handler):
+ actions_by_user = {}
+
+ users_dict = yield self.store.are_guests(self.rules_by_user.keys())
+
+ filtered_by_user = yield handler._filter_events_for_clients(
+ users_dict.items(), [event]
+ )
+
+ evaluator = PushRuleEvaluatorForEvent(event, len(self.users_in_room))
+
+ condition_cache = {}
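+        # Server-default conditions carry an '_id' (e.g. '_suppress_notices')
+        # and do not depend on the user, so their results are computed once
+        # per event and shared between users via this cache.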
+
+ member_state = yield self.store.get_state_for_event(
+ event.event_id,
+ )
+
+ display_names = {}
+ for ev in member_state.values():
+ nm = ev.content.get("displayname", None)
+ if nm and ev.type == EventTypes.Member:
+ display_names[ev.state_key] = nm
+
+ for uid, rules in self.rules_by_user.items():
+ display_name = display_names.get(uid, None)
+
+ filtered = filtered_by_user[uid]
+ if len(filtered) == 0:
+ continue
+
+ for rule in rules:
+ if 'enabled' in rule and not rule['enabled']:
+ continue
+
+ matches = _condition_checker(
+ evaluator, rule['conditions'], uid, display_name, condition_cache
+ )
+ if matches:
+ actions = [x for x in rule['actions'] if x != 'dont_notify']
+ if actions:
+ actions_by_user[uid] = actions
+ break
+ defer.returnValue(actions_by_user)
+
+
+def _condition_checker(evaluator, conditions, uid, display_name, cache):
+ for cond in conditions:
+ _id = cond.get("_id", None)
+ if _id:
+ res = cache.get(_id, None)
+ if res is False:
+ return False
+ elif res is True:
+ continue
+
+ res = evaluator.matches(cond, uid, display_name, None)
+ if _id:
+ cache[_id] = res
+
+ if not res:
+ return False
+
+ return True
diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index 7866db6a24..28f1fab0e4 100644
--- a/synapse/push/httppusher.py
+++ b/synapse/push/httppusher.py
@@ -23,13 +23,13 @@ logger = logging.getLogger(__name__)
class HttpPusher(Pusher):
- def __init__(self, _hs, profile_tag, user_name, app_id,
+ def __init__(self, _hs, profile_tag, user_id, app_id,
app_display_name, device_display_name, pushkey, pushkey_ts,
data, last_token, last_success, failing_since):
super(HttpPusher, self).__init__(
_hs,
profile_tag,
- user_name,
+ user_id,
app_id,
app_display_name,
device_display_name,
@@ -87,7 +87,7 @@ class HttpPusher(Pusher):
}
if event['type'] == 'm.room.member':
d['notification']['membership'] = event['content']['membership']
- d['notification']['user_is_target'] = event['state_key'] == self.user_name
+ d['notification']['user_is_target'] = event['state_key'] == self.user_id
if 'content' in event:
d['notification']['content'] = event['content']
diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py
index dec81566ba..379652c513 100644
--- a/synapse/push/push_rule_evaluator.py
+++ b/synapse/push/push_rule_evaluator.py
@@ -15,40 +15,70 @@
from twisted.internet import defer
-from synapse.types import UserID
-
import baserules
import logging
import simplejson as json
import re
+from synapse.types import UserID
+
logger = logging.getLogger(__name__)
+GLOB_REGEX = re.compile(r'\\\[(\\\!|)(.*)\\\]')
+IS_GLOB = re.compile(r'[\?\*\[\]]')
+INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$")
+
+
@defer.inlineCallbacks
-def evaluator_for_user_name_and_profile_tag(user_name, profile_tag, room_id, store):
- rawrules = yield store.get_push_rules_for_user(user_name)
- enabled_map = yield store.get_push_rules_enabled_for_user(user_name)
+def evaluator_for_user_id_and_profile_tag(user_id, profile_tag, room_id, store):
+ rawrules = yield store.get_push_rules_for_user(user_id)
+ enabled_map = yield store.get_push_rules_enabled_for_user(user_id)
our_member_event = yield store.get_current_state(
room_id=room_id,
event_type='m.room.member',
- state_key=user_name,
+ state_key=user_id,
)
defer.returnValue(PushRuleEvaluator(
- user_name, profile_tag, rawrules, enabled_map,
+ user_id, profile_tag, rawrules, enabled_map,
room_id, our_member_event, store
))
+def _room_member_count(ev, condition, room_member_count):
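+    # condition['is'] is a comparison string such as '2', '==2', '>=2' or
+    # '<10'; with no comparator it means equality, so for example
+    # {'kind': 'room_member_count', 'is': '2'} matches rooms with exactly
+    # two members.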
+ if 'is' not in condition:
+ return False
+ m = INEQUALITY_EXPR.match(condition['is'])
+ if not m:
+ return False
+ ineq = m.group(1)
+ rhs = m.group(2)
+ if not rhs.isdigit():
+ return False
+ rhs = int(rhs)
+
+ if ineq == '' or ineq == '==':
+ return room_member_count == rhs
+ elif ineq == '<':
+ return room_member_count < rhs
+ elif ineq == '>':
+ return room_member_count > rhs
+ elif ineq == '>=':
+ return room_member_count >= rhs
+ elif ineq == '<=':
+ return room_member_count <= rhs
+ else:
+ return False
+
+
class PushRuleEvaluator:
- DEFAULT_ACTIONS = ['dont_notify']
- INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$")
+ DEFAULT_ACTIONS = []
- def __init__(self, user_name, profile_tag, raw_rules, enabled_map, room_id,
+ def __init__(self, user_id, profile_tag, raw_rules, enabled_map, room_id,
our_member_event, store):
- self.user_name = user_name
+ self.user_id = user_id
self.profile_tag = profile_tag
self.room_id = room_id
self.our_member_event = our_member_event
@@ -61,8 +91,7 @@ class PushRuleEvaluator:
rule['actions'] = json.loads(raw_rule['actions'])
rules.append(rule)
- user = UserID.from_string(self.user_name)
- self.rules = baserules.list_with_base_rules(rules, user)
+ self.rules = baserules.list_with_base_rules(rules)
self.enabled_map = enabled_map
@@ -83,9 +112,9 @@ class PushRuleEvaluator:
has configured both globally and per-room when we have the ability
to do such things.
"""
- if ev['user_id'] == self.user_name:
+ if ev['user_id'] == self.user_id:
# let's assume you probably know about messages you sent yourself
- defer.returnValue(['dont_notify'])
+ defer.returnValue([])
room_id = ev['room_id']
@@ -98,127 +127,178 @@ class PushRuleEvaluator:
room_members = yield self.store.get_users_in_room(room_id)
room_member_count = len(room_members)
+ evaluator = PushRuleEvaluatorForEvent(ev, room_member_count)
+
for r in self.rules:
- if r['rule_id'] in self.enabled_map:
- r['enabled'] = self.enabled_map[r['rule_id']]
- elif 'enabled' not in r:
- r['enabled'] = True
- if not r['enabled']:
+ enabled = self.enabled_map.get(r['rule_id'], None)
+ if enabled is not None and not enabled:
+ continue
+
+ if not r.get("enabled", True):
continue
- matches = True
conditions = r['conditions']
actions = r['actions']
- for c in conditions:
- matches &= self._event_fulfills_condition(
- ev, c, display_name=my_display_name,
- room_member_count=room_member_count
- )
- logger.debug(
- "Rule %s %s",
- r['rule_id'], "matches" if matches else "doesn't match"
- )
            # ignore rules with no actions (we have an explicit 'dont_notify')
if len(actions) == 0:
logger.warn(
"Ignoring rule id %s with no actions for user %s",
- r['rule_id'], self.user_name
+ r['rule_id'], self.user_id
)
continue
+
+ matches = True
+ for c in conditions:
+ matches = evaluator.matches(
+ c, self.user_id, my_display_name, self.profile_tag
+ )
+ if not matches:
+ break
+
+ logger.debug(
+ "Rule %s %s",
+ r['rule_id'], "matches" if matches else "doesn't match"
+ )
+
if matches:
- logger.info(
+ logger.debug(
"%s matches for user %s, event %s",
- r['rule_id'], self.user_name, ev['event_id']
+ r['rule_id'], self.user_id, ev['event_id']
)
+
+ # filter out dont_notify as we treat an empty actions list
+ # as dont_notify, and this doesn't take up a row in our database
+ actions = [x for x in actions if x != 'dont_notify']
+
defer.returnValue(actions)
- logger.info(
+ logger.debug(
"No rules match for user %s, event %s",
- self.user_name, ev['event_id']
+ self.user_id, ev['event_id']
)
defer.returnValue(PushRuleEvaluator.DEFAULT_ACTIONS)
- @staticmethod
- def _glob_to_regexp(glob):
- r = re.escape(glob)
- r = re.sub(r'\\\*', r'.*?', r)
- r = re.sub(r'\\\?', r'.', r)
- # handle [abc], [a-z] and [!a-z] style ranges.
- r = re.sub(r'\\\[(\\\!|)(.*)\\\]',
- lambda x: ('[%s%s]' % (x.group(1) and '^' or '',
- re.sub(r'\\\-', '-', x.group(2)))), r)
- return r
+class PushRuleEvaluatorForEvent(object):
+ def __init__(self, event, room_member_count):
+ self._event = event
+ self._room_member_count = room_member_count
- def _event_fulfills_condition(self, ev, condition, display_name, room_member_count):
- if condition['kind'] == 'event_match':
- if 'pattern' not in condition:
- logger.warn("event_match condition with no pattern")
- return False
- # XXX: optimisation: cache our pattern regexps
- if condition['key'] == 'content.body':
- r = r'\b%s\b' % self._glob_to_regexp(condition['pattern'])
- else:
- r = r'^%s$' % self._glob_to_regexp(condition['pattern'])
- val = _value_for_dotted_key(condition['key'], ev)
- if val is None:
- return False
- return re.search(r, val, flags=re.IGNORECASE) is not None
+ # Maps strings of e.g. 'content.body' -> event["content"]["body"]
+ self._value_cache = _flatten_dict(event)
+ def matches(self, condition, user_id, display_name, profile_tag):
+ if condition['kind'] == 'event_match':
+ return self._event_match(condition, user_id)
elif condition['kind'] == 'device':
if 'profile_tag' not in condition:
return True
- return condition['profile_tag'] == self.profile_tag
-
+ return condition['profile_tag'] == profile_tag
elif condition['kind'] == 'contains_display_name':
- # This is special because display names can be different
- # between rooms and so you can't really hard code it in a rule.
- # Optimisation: we should cache these names and update them from
- # the event stream.
- if 'content' not in ev or 'body' not in ev['content']:
- return False
- if not display_name:
- return False
- return re.search(
- r"\b%s\b" % re.escape(display_name), ev['content']['body'],
- flags=re.IGNORECASE
- ) is not None
-
+ return self._contains_display_name(display_name)
elif condition['kind'] == 'room_member_count':
- if 'is' not in condition:
- return False
- m = PushRuleEvaluator.INEQUALITY_EXPR.match(condition['is'])
- if not m:
- return False
- ineq = m.group(1)
- rhs = m.group(2)
- if not rhs.isdigit():
+ return _room_member_count(
+ self._event, condition, self._room_member_count
+ )
+ else:
+ return True
+
+ def _event_match(self, condition, user_id):
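+        # The pattern normally comes from the rule itself, but server-default
+        # rules may instead set 'pattern_type' to 'user_id' or 'user_localpart'
+        # so that one shared rule definition works for every user (see
+        # baserules.py).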
+ pattern = condition.get('pattern', None)
+
+ if not pattern:
+ pattern_type = condition.get('pattern_type', None)
+ if pattern_type == "user_id":
+ pattern = user_id
+ elif pattern_type == "user_localpart":
+ pattern = UserID.from_string(user_id).localpart
+
+ if not pattern:
+ logger.warn("event_match condition with no pattern")
+ return False
+
+ # XXX: optimisation: cache our pattern regexps
+ if condition['key'] == 'content.body':
+ body = self._event["content"].get("body", None)
+ if not body:
return False
- rhs = int(rhs)
-
- if ineq == '' or ineq == '==':
- return room_member_count == rhs
- elif ineq == '<':
- return room_member_count < rhs
- elif ineq == '>':
- return room_member_count > rhs
- elif ineq == '>=':
- return room_member_count >= rhs
- elif ineq == '<=':
- return room_member_count <= rhs
- else:
+
+ return _glob_matches(pattern, body, word_boundary=True)
+ else:
+ haystack = self._get_value(condition['key'])
+ if haystack is None:
return False
+
+ return _glob_matches(pattern, haystack)
+
+ def _contains_display_name(self, display_name):
+ if not display_name:
+ return False
+
+ body = self._event["content"].get("body", None)
+ if not body:
+ return False
+
+ return _glob_matches(display_name, body, word_boundary=True)
+
+ def _get_value(self, dotted_key):
+ return self._value_cache.get(dotted_key, None)
+
+
+def _glob_matches(glob, value, word_boundary=False):
+ """Tests if value matches glob.
+
+ Args:
+ glob (string): Glob pattern to match value against.
+ value (string): String to test against glob.
+ word_boundary (bool): Whether to match against word boundaries or entire
+ string. Defaults to False.
+
+ Returns:
+ bool
+ """
+ if IS_GLOB.search(glob):
+ r = re.escape(glob)
+
+ r = r.replace(r'\*', '.*?')
+ r = r.replace(r'\?', '.')
+
+ # handle [abc], [a-z] and [!a-z] style ranges.
+ r = GLOB_REGEX.sub(
+ lambda x: (
+ '[%s%s]' % (
+ x.group(1) and '^' or '',
+ x.group(2).replace(r'\-', '-')  # re.escape() renders '-' as '\-'
+ )
+ ),
+ r,
+ )
+ if word_boundary:
+ r = r"\b%s\b" % (r,)
+ r = re.compile(r, flags=re.IGNORECASE)
+
+ return r.search(value)
else:
- return True
+ r = r + "$"
+ r = re.compile(r, flags=re.IGNORECASE)
+
+ return r.match(value)
+ elif word_boundary:
+ r = re.escape(glob)
+ r = r"\b%s\b" % (r,)
+ r = re.compile(r, flags=re.IGNORECASE)
+
+ return r.search(value)
+ else:
+ return value.lower() == glob.lower()
+
+def _flatten_dict(d, prefix=None, result=None):
+ # Avoid mutable default arguments: a shared default `result` dict would
+ # leak cached values from one event into every later evaluator.
+ prefix = [] if prefix is None else prefix
+ result = {} if result is None else result
+ for key, value in d.items():
+ if isinstance(value, basestring):
+ result[".".join(prefix + [key])] = value.lower()
+ elif hasattr(value, "items"):
+ _flatten_dict(value, prefix=(prefix + [key]), result=result)
-def _value_for_dotted_key(dotted_key, event):
- parts = dotted_key.split(".")
- val = event
- while len(parts) > 0:
- if parts[0] not in val:
- return None
- val = val[parts[0]]
- parts = parts[1:]
- return val
+ return result
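For illustration, a minimal sketch (not part of the patch) of how the helpers above are expected to behave, assuming they live in synapse/push/push_rule_evaluator.py alongside PushRuleEvaluatorForEvent:

    from synapse.push.push_rule_evaluator import _flatten_dict, _glob_matches

    # Non-glob patterns fall through to a simple case-insensitive comparison.
    assert _glob_matches("cake", "Cake")
    # '*' is rewritten to '.*?'; without word_boundary the whole value must match.
    assert _glob_matches("cake*lie", "cake is a lie")
    # With word_boundary=True the pattern may match any whole word in the value.
    assert _glob_matches("cake", "I like cake", word_boundary=True)
    assert not _glob_matches("cake", "pancakes", word_boundary=True)

    # _flatten_dict builds the dotted-key cache consulted by _get_value().
    assert _flatten_dict({"content": {"body": "Hi"}}) == {"content.body": "hi"}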
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index 4208e5c76c..12c4af14bd 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -38,12 +38,12 @@ class PusherPool:
@defer.inlineCallbacks
def user_presence_changed(self, user, state):
- user_name = user.to_string()
+ user_id = user.to_string()
# until we have read receipts, pushers use this to reset a user's
# badge counters to zero
for p in self.pushers.values():
- if p.user_name == user_name:
+ if p.user_id == user_id:
yield p.presence_changed(state)
@defer.inlineCallbacks
@@ -52,14 +52,14 @@ class PusherPool:
self._start_pushers(pushers)
@defer.inlineCallbacks
- def add_pusher(self, user_name, access_token, profile_tag, kind, app_id,
+ def add_pusher(self, user_id, access_token, profile_tag, kind, app_id,
app_display_name, device_display_name, pushkey, lang, data):
# we try to create the pusher just to validate the config: it
# will then get pulled out of the database,
# recreated, added and started: this means we have only one
# code path adding pushers.
self._create_pusher({
- "user_name": user_name,
+ "user_name": user_id,
"kind": kind,
"profile_tag": profile_tag,
"app_id": app_id,
@@ -74,7 +74,7 @@ class PusherPool:
"failing_since": None
})
yield self._add_pusher_to_store(
- user_name, access_token, profile_tag, kind, app_id,
+ user_id, access_token, profile_tag, kind, app_id,
app_display_name, device_display_name,
pushkey, lang, data
)
@@ -109,11 +109,11 @@ class PusherPool:
self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])
@defer.inlineCallbacks
- def _add_pusher_to_store(self, user_name, access_token, profile_tag, kind,
+ def _add_pusher_to_store(self, user_id, access_token, profile_tag, kind,
app_id, app_display_name, device_display_name,
pushkey, lang, data):
yield self.store.add_pusher(
- user_name=user_name,
+ user_id=user_id,
access_token=access_token,
profile_tag=profile_tag,
kind=kind,
@@ -125,14 +125,14 @@ class PusherPool:
lang=lang,
data=data,
)
- self._refresh_pusher(app_id, pushkey, user_name)
+ self._refresh_pusher(app_id, pushkey, user_id)
def _create_pusher(self, pusherdict):
if pusherdict['kind'] == 'http':
return HttpPusher(
self.hs,
profile_tag=pusherdict['profile_tag'],
- user_name=pusherdict['user_name'],
+ user_id=pusherdict['user_name'],
app_id=pusherdict['app_id'],
app_display_name=pusherdict['app_display_name'],
device_display_name=pusherdict['device_display_name'],
@@ -150,14 +150,14 @@ class PusherPool:
)
@defer.inlineCallbacks
- def _refresh_pusher(self, app_id, pushkey, user_name):
+ def _refresh_pusher(self, app_id, pushkey, user_id):
resultlist = yield self.store.get_pushers_by_app_id_and_pushkey(
app_id, pushkey
)
p = None
for r in resultlist:
- if r['user_name'] == user_name:
+ if r['user_name'] == user_id:
p = r
if p:
@@ -186,12 +186,12 @@ class PusherPool:
logger.info("Started pushers")
@defer.inlineCallbacks
- def remove_pusher(self, app_id, pushkey, user_name):
- fullid = "%s:%s:%s" % (app_id, pushkey, user_name)
+ def remove_pusher(self, app_id, pushkey, user_id):
+ fullid = "%s:%s:%s" % (app_id, pushkey, user_id)
if fullid in self.pushers:
logger.info("Stopping pusher %s", fullid)
self.pushers[fullid].stop()
del self.pushers[fullid]
- yield self.store.delete_pusher_by_app_id_pushkey_user_name(
- app_id, pushkey, user_name
+ yield self.store.delete_pusher_by_app_id_pushkey_user_id(
+ app_id, pushkey, user_id
)
diff --git a/synapse/rest/client/v1/admin.py b/synapse/rest/client/v1/admin.py
index 4d724dce72..e2f5eb7b29 100644
--- a/synapse/rest/client/v1/admin.py
+++ b/synapse/rest/client/v1/admin.py
@@ -31,8 +31,9 @@ class WhoisRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, user_id):
target_user = UserID.from_string(user_id)
- auth_user, _, _ = yield self.auth.get_user_by_req(request)
- is_admin = yield self.auth.is_server_admin(auth_user)
+ requester = yield self.auth.get_user_by_req(request)
+ auth_user = requester.user
+ is_admin = yield self.auth.is_server_admin(requester.user)
if not is_admin and target_user != auth_user:
raise AuthError(403, "You are not a server admin")
diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py
index 7eef6bf5dc..74ec1e50e0 100644
--- a/synapse/rest/client/v1/directory.py
+++ b/synapse/rest/client/v1/directory.py
@@ -69,9 +69,9 @@ class ClientDirectoryServer(ClientV1RestServlet):
try:
# try to auth as a user
- user, _, _ = yield self.auth.get_user_by_req(request)
+ requester = yield self.auth.get_user_by_req(request)
try:
- user_id = user.to_string()
+ user_id = requester.user.to_string()
yield dir_handler.create_association(
user_id, room_alias, room_id, servers
)
@@ -116,8 +116,8 @@ class ClientDirectoryServer(ClientV1RestServlet):
# fallback to default user behaviour if they aren't an AS
pass
- user, _, _ = yield self.auth.get_user_by_req(request)
-
+ requester = yield self.auth.get_user_by_req(request)
+ user = requester.user
is_admin = yield self.auth.is_server_admin(user)
if not is_admin:
raise AuthError(403, "You need to be a server admin")
diff --git a/synapse/rest/client/v1/events.py b/synapse/rest/client/v1/events.py
index 631f2ca052..e89118b37d 100644
--- a/synapse/rest/client/v1/events.py
+++ b/synapse/rest/client/v1/events.py
@@ -34,10 +34,11 @@ class EventStreamRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_GET(self, request):
- auth_user, _, is_guest = yield self.auth.get_user_by_req(
+ requester = yield self.auth.get_user_by_req(
request,
- allow_guest=True
+ allow_guest=True,
)
+ is_guest = requester.is_guest
room_id = None
if is_guest:
if "room_id" not in request.args:
@@ -56,9 +57,13 @@ class EventStreamRestServlet(ClientV1RestServlet):
as_client_event = "raw" not in request.args
chunk = yield handler.get_stream(
- auth_user.to_string(), pagin_config, timeout=timeout,
- as_client_event=as_client_event, affect_presence=(not is_guest),
- room_id=room_id, is_guest=is_guest
+ requester.user.to_string(),
+ pagin_config,
+ timeout=timeout,
+ as_client_event=as_client_event,
+ affect_presence=(not is_guest),
+ room_id=room_id,
+ is_guest=is_guest,
)
except:
logger.exception("Event stream failed")
@@ -80,9 +85,9 @@ class EventRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, event_id):
- auth_user, _, _ = yield self.auth.get_user_by_req(request)
+ requester = yield self.auth.get_user_by_req(request)
handler = self.handlers.event_handler
- event = yield handler.get_event(auth_user, event_id)
+ event = yield handler.get_event(requester.user, event_id)
time_now = self.clock.time_msec()
if event:
diff --git a/synapse/rest/client/v1/initial_sync.py b/synapse/rest/client/v1/initial_sync.py
index 541319c351..ad161bdbab 100644
--- a/synapse/rest/client/v1/initial_sync.py
+++ b/synapse/rest/client/v1/initial_sync.py
@@ -25,13 +25,13 @@ class InitialSyncRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_GET(self, request):
- user, _, _ = yield self.auth.get_user_by_req(request)
+ requester = yield self.auth.get_user_by_req(request)
as_client_event = "raw" not in request.args
pagination_config = PaginationConfig.from_request(request)
handler = self.handlers.message_handler
include_archived = request.args.get("archived", None) == ["true"]
content = yield handler.snapshot_all_rooms(
- user_id=user.to_string(),
+ user_id=requester.user.to_string(),
pagin_config=pagination_config,
as_client_event=as_client_event,
include_archived=include_archived,
diff --git a/synapse/rest/client/v1/presence.py b/synapse/rest/client/v1/presence.py
index 855385ec16..a6f8754e32 100644
--- a/synapse/rest/client/v1/presence.py
+++ b/synapse/rest/client/v1/presence.py
@@ -32,17 +32,17 @@ class PresenceStatusRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, user_id):
- auth_user, _, _ = yield self.auth.get_user_by_req(request)
+ requester = yield self.auth.get_user_by_req(request)
user = UserID.from_string(user_id)
state = yield self.handlers.presence_handler.get_state(
- target_user=user, auth_user=auth_user)
+ target_user=user, auth_user=requester.user)
defer.returnValue((200, state))
@defer.inlineCallbacks
def on_PUT(self, request, user_id):
- auth_user, _, _ = yield self.auth.get_user_by_req(request)
+ requester = yield self.auth.get_user_by_req(request)
user = UserID.from_string(user_id)
state = {}
@@ -64,7 +64,7 @@ class PresenceStatusRestServlet(ClientV1RestServlet):
raise SynapseError(400, "Unable to parse state")
yield self.handlers.presence_handler.set_state(
- target_user=user, auth_user=auth_user, state=state)
+ target_user=user, auth_user=requester.user, state=state)
defer.returnValue((200, {}))
@@ -77,13 +77,13 @@ class PresenceListRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, user_id):
- auth_user, _, _ = yield self.auth.get_user_by_req(request)
+ requester = yield self.auth.get_user_by_req(request)
user = UserID.from_string(user_id)
if not self.hs.is_mine(user):
raise SynapseError(400, "User not hosted on this Home Server")
- if auth_user != user:
+ if requester.user != user:
raise SynapseError(400, "Cannot get another user's presence list")
presence = yield self.handlers.presence_handler.get_presence_list(
@@ -97,13 +97,13 @@ class PresenceListRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_POST(self, request, user_id):
- auth_user, _, _ = yield self.auth.get_user_by_req(request)
+ requester = yield self.auth.get_user_by_req(request)
user = UserID.from_string(user_id)
if not self.hs.is_mine(user):
raise SynapseError(400, "User not hosted on this Home Server")
- if auth_user != user:
+ if requester.user != user:
raise SynapseError(
400, "Cannot modify another user's presence list")
diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py
index d4bc9e076c..b15defdd07 100644
--- a/synapse/rest/client/v1/profile.py
+++ b/synapse/rest/client/v1/profile.py
@@ -37,7 +37,7 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_PUT(self, request, user_id):
- auth_user, _, _ = yield self.auth.get_user_by_req(request, allow_guest=True)
+ requester = yield self.auth.get_user_by_req(request, allow_guest=True)
user = UserID.from_string(user_id)
try:
@@ -47,7 +47,7 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet):
defer.returnValue((400, "Unable to parse name"))
yield self.handlers.profile_handler.set_displayname(
- user, auth_user, new_name)
+ user, requester.user, new_name)
defer.returnValue((200, {}))
@@ -70,7 +70,7 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_PUT(self, request, user_id):
- auth_user, _, _ = yield self.auth.get_user_by_req(request)
+ requester = yield self.auth.get_user_by_req(request)
user = UserID.from_string(user_id)
try:
@@ -80,7 +80,7 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet):
defer.returnValue((400, "Unable to parse name"))
yield self.handlers.profile_handler.set_avatar_url(
- user, auth_user, new_name)
+ user, requester.user, new_name)
defer.returnValue((200, {}))
diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py
index 2aab28ae7b..2272d66dc7 100644
--- a/synapse/rest/client/v1/push_rule.py
+++ b/synapse/rest/client/v1/push_rule.py
@@ -27,6 +27,7 @@ from synapse.push.rulekinds import (
PRIORITY_CLASS_MAP, PRIORITY_CLASS_INVERSE_MAP
)
+import copy
import simplejson as json
@@ -43,7 +44,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
except InvalidRuleException as e:
raise SynapseError(400, e.message)
- user, _, _ = yield self.auth.get_user_by_req(request)
+ requester = yield self.auth.get_user_by_req(request)
if '/' in spec['rule_id'] or '\\' in spec['rule_id']:
raise SynapseError(400, "rule_id may not contain slashes")
@@ -51,7 +52,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
content = _parse_json(request)
if 'attr' in spec:
- self.set_rule_attr(user.to_string(), spec, content)
+ self.set_rule_attr(requester.user.to_string(), spec, content)
defer.returnValue((200, {}))
try:
@@ -73,7 +74,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
try:
yield self.hs.get_datastore().add_push_rule(
- user_name=user.to_string(),
+ user_id=requester.user.to_string(),
rule_id=_namespaced_rule_id_from_spec(spec),
priority_class=priority_class,
conditions=conditions,
@@ -92,13 +93,13 @@ class PushRuleRestServlet(ClientV1RestServlet):
def on_DELETE(self, request):
spec = _rule_spec_from_path(request.postpath)
- user, _, _ = yield self.auth.get_user_by_req(request)
+ requester = yield self.auth.get_user_by_req(request)
namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
try:
yield self.hs.get_datastore().delete_push_rule(
- user.to_string(), namespaced_rule_id
+ requester.user.to_string(), namespaced_rule_id
)
defer.returnValue((200, {}))
except StoreError as e:
@@ -109,7 +110,8 @@ class PushRuleRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_GET(self, request):
- user, _, _ = yield self.auth.get_user_by_req(request)
+ requester = yield self.auth.get_user_by_req(request)
+ user = requester.user
# we build up the full structure and then decide which bits of it
# to send which means doing unnecessary work sometimes but is
@@ -125,7 +127,8 @@ class PushRuleRestServlet(ClientV1RestServlet):
rule["actions"] = json.loads(rawrule["actions"])
ruleslist.append(rule)
- ruleslist = baserules.list_with_base_rules(ruleslist, user)
+ # We're going to be mutating this a lot, so do a deep copy
+ ruleslist = copy.deepcopy(baserules.list_with_base_rules(ruleslist))
rules = {'global': {}, 'device': {}}
@@ -139,6 +142,16 @@ class PushRuleRestServlet(ClientV1RestServlet):
template_name = _priority_class_to_template_name(r['priority_class'])
+ # Strip internal keys and resolve pattern_type placeholders.
+ for c in r["conditions"]:
+ c.pop("_id", None)
+
+ pattern_type = c.pop("pattern_type", None)
+ if pattern_type == "user_id":
+ c["pattern"] = user.to_string()
+ elif pattern_type == "user_localpart":
+ c["pattern"] = user.localpart
+
if r['priority_class'] > PRIORITY_CLASS_MAP['override']:
# per-device rule
profile_tag = _profile_tag_from_conditions(r["conditions"])
@@ -205,7 +218,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
def on_OPTIONS(self, _):
return 200, {}
- def set_rule_attr(self, user_name, spec, val):
+ def set_rule_attr(self, user_id, spec, val):
if spec['attr'] == 'enabled':
if isinstance(val, dict) and "enabled" in val:
val = val["enabled"]
@@ -216,15 +229,15 @@ class PushRuleRestServlet(ClientV1RestServlet):
raise SynapseError(400, "Value for 'enabled' must be boolean")
namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
self.hs.get_datastore().set_push_rule_enabled(
- user_name, namespaced_rule_id, val
+ user_id, namespaced_rule_id, val
)
else:
raise UnrecognizedRequestError()
- def get_rule_attr(self, user_name, namespaced_rule_id, attr):
+ def get_rule_attr(self, user_id, namespaced_rule_id, attr):
if attr == 'enabled':
return self.hs.get_datastore().get_push_rule_enabled_by_user_rule_id(
- user_name, namespaced_rule_id
+ user_id, namespaced_rule_id
)
else:
raise UnrecognizedRequestError()
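For context on the condition rewriting added to on_GET above: server-default rules may carry a pattern_type placeholder instead of a concrete pattern, and the new loop resolves it per requester before the rules are returned. A hypothetical before/after (the _id value is invented):

    # Stored/base-rule condition, with internal keys:
    {"kind": "event_match", "key": "content.body",
     "pattern_type": "user_localpart", "_id": "_localpart"}

    # What the client sees for @alice:example.com after the loop runs:
    {"kind": "event_match", "key": "content.body", "pattern": "alice"}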
diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py
index 81a8786aeb..e218ed215c 100644
--- a/synapse/rest/client/v1/pusher.py
+++ b/synapse/rest/client/v1/pusher.py
@@ -30,7 +30,8 @@ class PusherRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_POST(self, request):
- user, token_id, _ = yield self.auth.get_user_by_req(request)
+ requester = yield self.auth.get_user_by_req(request)
+ user = requester.user
content = _parse_json(request)
@@ -40,7 +41,7 @@ class PusherRestServlet(ClientV1RestServlet):
and 'kind' in content and
content['kind'] is None):
yield pusher_pool.remove_pusher(
- content['app_id'], content['pushkey'], user_name=user.to_string()
+ content['app_id'], content['pushkey'], user_id=user.to_string()
)
defer.returnValue((200, {}))
@@ -70,8 +71,8 @@ class PusherRestServlet(ClientV1RestServlet):
try:
yield pusher_pool.add_pusher(
- user_name=user.to_string(),
- access_token=token_id,
+ user_id=user.to_string(),
+ access_token=requester.access_token_id,
profile_tag=content['profile_tag'],
kind=content['kind'],
app_id=content['app_id'],
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 926f77d1c3..85b9f253e3 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -61,10 +61,14 @@ class RoomCreateRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_POST(self, request):
- auth_user, _, _ = yield self.auth.get_user_by_req(request)
+ requester = yield self.auth.get_user_by_req(request)
room_config = self.get_room_config(request)
- info = yield self.make_room(room_config, auth_user, None)
+ info = yield self.make_room(
+ room_config,
+ requester.user,
+ None,
+ )
room_config.update(info)
defer.returnValue((200, info))
@@ -124,15 +128,15 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, room_id, event_type, state_key):
- user, _, is_guest = yield self.auth.get_user_by_req(request, allow_guest=True)
+ requester = yield self.auth.get_user_by_req(request, allow_guest=True)
msg_handler = self.handlers.message_handler
data = yield msg_handler.get_room_data(
- user_id=user.to_string(),
+ user_id=requester.user.to_string(),
room_id=room_id,
event_type=event_type,
state_key=state_key,
- is_guest=is_guest,
+ is_guest=requester.is_guest,
)
if not data:
@@ -143,7 +147,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_PUT(self, request, room_id, event_type, state_key, txn_id=None):
- user, token_id, _ = yield self.auth.get_user_by_req(request)
+ requester = yield self.auth.get_user_by_req(request)
content = _parse_json(request)
@@ -151,7 +155,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
"type": event_type,
"content": content,
"room_id": room_id,
- "sender": user.to_string(),
+ "sender": requester.user.to_string(),
}
if state_key is not None:
@@ -159,7 +163,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
msg_handler = self.handlers.message_handler
yield msg_handler.create_and_send_event(
- event_dict, token_id=token_id, txn_id=txn_id,
+ event_dict, token_id=requester.access_token_id, txn_id=txn_id,
)
defer.returnValue((200, {}))
@@ -175,7 +179,7 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_POST(self, request, room_id, event_type, txn_id=None):
- user, token_id, _ = yield self.auth.get_user_by_req(request, allow_guest=True)
+ requester = yield self.auth.get_user_by_req(request, allow_guest=True)
content = _parse_json(request)
msg_handler = self.handlers.message_handler
@@ -184,9 +188,9 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
"type": event_type,
"content": content,
"room_id": room_id,
- "sender": user.to_string(),
+ "sender": requester.user.to_string(),
},
- token_id=token_id,
+ token_id=requester.access_token_id,
txn_id=txn_id,
)
@@ -220,9 +224,9 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_POST(self, request, room_identifier, txn_id=None):
- user, token_id, is_guest = yield self.auth.get_user_by_req(
+ requester = yield self.auth.get_user_by_req(
request,
- allow_guest=True
+ allow_guest=True,
)
# the identifier could be a room alias or a room id. Try one then the
@@ -241,24 +245,27 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
if is_room_alias:
handler = self.handlers.room_member_handler
- ret_dict = yield handler.join_room_alias(user, identifier)
+ ret_dict = yield handler.join_room_alias(
+ requester.user,
+ identifier,
+ )
defer.returnValue((200, ret_dict))
else: # room id
msg_handler = self.handlers.message_handler
content = {"membership": Membership.JOIN}
- if is_guest:
+ if requester.is_guest:
content["kind"] = "guest"
yield msg_handler.create_and_send_event(
{
"type": EventTypes.Member,
"content": content,
"room_id": identifier.to_string(),
- "sender": user.to_string(),
- "state_key": user.to_string(),
+ "sender": requester.user.to_string(),
+ "state_key": requester.user.to_string(),
},
- token_id=token_id,
+ token_id=requester.access_token_id,
txn_id=txn_id,
- is_guest=is_guest,
+ is_guest=requester.is_guest,
)
defer.returnValue((200, {"room_id": identifier.to_string()}))
@@ -296,11 +303,11 @@ class RoomMemberListRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, room_id):
# TODO support Pagination stream API (limit/tokens)
- user, _, _ = yield self.auth.get_user_by_req(request)
+ requester = yield self.auth.get_user_by_req(request)
handler = self.handlers.message_handler
events = yield handler.get_state_events(
room_id=room_id,
- user_id=user.to_string(),
+ user_id=requester.user.to_string(),
)
chunk = []
@@ -315,7 +322,8 @@ class RoomMemberListRestServlet(ClientV1RestServlet):
try:
presence_handler = self.handlers.presence_handler
presence_state = yield presence_handler.get_state(
- target_user=target_user, auth_user=user
+ target_user=target_user,
+ auth_user=requester.user,
)
event["content"].update(presence_state)
except:
@@ -332,7 +340,7 @@ class RoomMessageListRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, room_id):
- user, _, is_guest = yield self.auth.get_user_by_req(request, allow_guest=True)
+ requester = yield self.auth.get_user_by_req(request, allow_guest=True)
pagination_config = PaginationConfig.from_request(
request, default_limit=10,
)
@@ -340,8 +348,8 @@ class RoomMessageListRestServlet(ClientV1RestServlet):
handler = self.handlers.message_handler
msgs = yield handler.get_messages(
room_id=room_id,
- user_id=user.to_string(),
- is_guest=is_guest,
+ user_id=requester.user.to_string(),
+ is_guest=requester.is_guest,
pagin_config=pagination_config,
as_client_event=as_client_event
)
@@ -355,13 +363,13 @@ class RoomStateRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, room_id):
- user, _, is_guest = yield self.auth.get_user_by_req(request, allow_guest=True)
+ requester = yield self.auth.get_user_by_req(request, allow_guest=True)
handler = self.handlers.message_handler
# Get all the current state for this room
events = yield handler.get_state_events(
room_id=room_id,
- user_id=user.to_string(),
- is_guest=is_guest,
+ user_id=requester.user.to_string(),
+ is_guest=requester.is_guest,
)
defer.returnValue((200, events))
@@ -372,13 +380,13 @@ class RoomInitialSyncRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, room_id):
- user, _, is_guest = yield self.auth.get_user_by_req(request, allow_guest=True)
+ requester = yield self.auth.get_user_by_req(request, allow_guest=True)
pagination_config = PaginationConfig.from_request(request)
content = yield self.handlers.message_handler.room_initial_sync(
room_id=room_id,
- user_id=user.to_string(),
+ user_id=requester.user.to_string(),
pagin_config=pagination_config,
- is_guest=is_guest,
+ is_guest=requester.is_guest,
)
defer.returnValue((200, content))
@@ -394,18 +402,28 @@ class RoomEventContext(ClientV1RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, room_id, event_id):
- user, _, is_guest = yield self.auth.get_user_by_req(request, allow_guest=True)
+ requester = yield self.auth.get_user_by_req(request, allow_guest=True)
limit = int(request.args.get("limit", [10])[0])
results = yield self.handlers.room_context_handler.get_event_context(
- user, room_id, event_id, limit, is_guest
+ requester.user,
+ room_id,
+ event_id,
+ limit,
+ requester.is_guest,
)
+ if not results:
+ raise SynapseError(
+ 404, "Event not found.", errcode=Codes.NOT_FOUND
+ )
+
time_now = self.clock.time_msec()
results["events_before"] = [
serialize_event(event, time_now) for event in results["events_before"]
]
+ results["event"] = serialize_event(results["event"], time_now)
results["events_after"] = [
serialize_event(event, time_now) for event in results["events_after"]
]
@@ -424,74 +442,51 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
def register(self, http_server):
# /rooms/$roomid/[invite|join|leave]
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/"
- "(?P<membership_action>join|invite|leave|ban|kick|forget)")
+ "(?P<membership_action>join|invite|leave|ban|unban|kick|forget)")
register_txn_path(self, PATTERNS, http_server)
@defer.inlineCallbacks
def on_POST(self, request, room_id, membership_action, txn_id=None):
- user, token_id, is_guest = yield self.auth.get_user_by_req(
+ requester = yield self.auth.get_user_by_req(
request,
- allow_guest=True
+ allow_guest=True,
)
- effective_membership_action = membership_action
-
- if is_guest and membership_action not in {Membership.JOIN, Membership.LEAVE}:
+ if requester.is_guest and membership_action not in {
+ Membership.JOIN,
+ Membership.LEAVE
+ }:
raise AuthError(403, "Guest access not allowed")
content = _parse_json(request)
- # target user is you unless it is an invite
- state_key = user.to_string()
-
if membership_action == "invite" and self._has_3pid_invite_keys(content):
yield self.handlers.room_member_handler.do_3pid_invite(
room_id,
- user,
+ requester.user,
content["medium"],
content["address"],
content["id_server"],
- token_id,
+ requester.access_token_id,
txn_id
)
defer.returnValue((200, {}))
return
- elif membership_action in ["invite", "ban", "kick"]:
- if "user_id" in content:
- state_key = content["user_id"]
- else:
- raise SynapseError(400, "Missing user_id key.")
-
- # make sure it looks like a user ID; it'll throw if it's invalid.
- UserID.from_string(state_key)
-
- if membership_action == "kick":
- effective_membership_action = "leave"
- elif membership_action == "forget":
- effective_membership_action = "leave"
- msg_handler = self.handlers.message_handler
-
- content = {"membership": unicode(effective_membership_action)}
- if is_guest:
- content["kind"] = "guest"
+ target = requester.user
+ if membership_action in ["invite", "ban", "unban", "kick"]:
+ if "user_id" not in content:
+ raise SynapseError(400, "Missing user_id key.")
+ target = UserID.from_string(content["user_id"])
- yield msg_handler.create_and_send_event(
- {
- "type": EventTypes.Member,
- "content": content,
- "room_id": room_id,
- "sender": user.to_string(),
- "state_key": state_key,
- },
- token_id=token_id,
+ yield self.handlers.room_member_handler.update_membership(
+ requester=requester,
+ target=target,
+ room_id=room_id,
+ action=membership_action,
txn_id=txn_id,
- is_guest=is_guest,
)
- if membership_action == "forget":
- yield self.handlers.room_member_handler.forget(user, room_id)
-
defer.returnValue((200, {}))
def _has_3pid_invite_keys(self, content):
@@ -524,7 +519,7 @@ class RoomRedactEventRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_POST(self, request, room_id, event_id, txn_id=None):
- user, token_id, _ = yield self.auth.get_user_by_req(request)
+ requester = yield self.auth.get_user_by_req(request)
content = _parse_json(request)
msg_handler = self.handlers.message_handler
@@ -533,10 +528,10 @@ class RoomRedactEventRestServlet(ClientV1RestServlet):
"type": EventTypes.Redaction,
"content": content,
"room_id": room_id,
- "sender": user.to_string(),
+ "sender": requester.user.to_string(),
"redacts": event_id,
},
- token_id=token_id,
+ token_id=requester.access_token_id,
txn_id=txn_id,
)
@@ -564,7 +559,7 @@ class RoomTypingRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_PUT(self, request, room_id, user_id):
- auth_user, _, _ = yield self.auth.get_user_by_req(request)
+ requester = yield self.auth.get_user_by_req(request)
room_id = urllib.unquote(room_id)
target_user = UserID.from_string(urllib.unquote(user_id))
@@ -576,14 +571,14 @@ class RoomTypingRestServlet(ClientV1RestServlet):
if content["typing"]:
yield typing_handler.started_typing(
target_user=target_user,
- auth_user=auth_user,
+ auth_user=requester.user,
room_id=room_id,
timeout=content.get("timeout", 30000),
)
else:
yield typing_handler.stopped_typing(
target_user=target_user,
- auth_user=auth_user,
+ auth_user=requester.user,
room_id=room_id,
)
@@ -597,12 +592,16 @@ class SearchRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_POST(self, request):
- auth_user, _, _ = yield self.auth.get_user_by_req(request)
+ requester = yield self.auth.get_user_by_req(request)
content = _parse_json(request)
batch = request.args.get("next_batch", [None])[0]
- results = yield self.handlers.search_handler.search(auth_user, content, batch)
+ results = yield self.handlers.search_handler.search(
+ requester.user,
+ content,
+ batch,
+ )
defer.returnValue((200, results))
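On the membership servlet: the allowed actions now include "unban", and for invite/ban/unban/kick the target comes from the request body rather than being the sender. A hypothetical client request exercising the new action (the v1 path prefix is an assumption):

    # POST /_matrix/client/api/v1/rooms/{room_id}/unban
    body = {
        "user_id": "@troublemaker:example.com",  # becomes `target` in on_POST
    }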
diff --git a/synapse/rest/client/v1/voip.py b/synapse/rest/client/v1/voip.py
index 860cb0a642..ec4cf8db79 100644
--- a/synapse/rest/client/v1/voip.py
+++ b/synapse/rest/client/v1/voip.py
@@ -28,7 +28,7 @@ class VoipRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_GET(self, request):
- auth_user, _, _ = yield self.auth.get_user_by_req(request)
+ requester = yield self.auth.get_user_by_req(request)
turnUris = self.hs.config.turn_uris
turnSecret = self.hs.config.turn_shared_secret
@@ -37,7 +37,7 @@ class VoipRestServlet(ClientV1RestServlet):
defer.returnValue((200, {}))
expiry = (self.hs.get_clock().time_msec() + userLifetime) / 1000
- username = "%d:%s" % (expiry, auth_user.to_string())
+ username = "%d:%s" % (expiry, requester.user.to_string())
mac = hmac.new(turnSecret, msg=username, digestmod=hashlib.sha1)
# We need to use standard padded base64 encoding here
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index ddb6f041cd..fa56249a69 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -55,10 +55,11 @@ class PasswordRestServlet(RestServlet):
if LoginType.PASSWORD in result:
# if using password, they should also be logged in
- auth_user, _, _ = yield self.auth.get_user_by_req(request)
- if auth_user.to_string() != result[LoginType.PASSWORD]:
+ requester = yield self.auth.get_user_by_req(request)
+ requester_user_id = requester.user.to_string()
+ if requester_user_id != result[LoginType.PASSWORD]:
raise LoginError(400, "", Codes.UNKNOWN)
- user_id = auth_user.to_string()
+ user_id = requester_user_id
elif LoginType.EMAIL_IDENTITY in result:
threepid = result[LoginType.EMAIL_IDENTITY]
if 'medium' not in threepid or 'address' not in threepid:
@@ -102,10 +103,10 @@ class ThreepidRestServlet(RestServlet):
def on_GET(self, request):
yield run_on_reactor()
- auth_user, _, _ = yield self.auth.get_user_by_req(request)
+ requester = yield self.auth.get_user_by_req(request)
threepids = yield self.hs.get_datastore().user_get_threepids(
- auth_user.to_string()
+ requester.user.to_string()
)
defer.returnValue((200, {'threepids': threepids}))
@@ -120,7 +121,8 @@ class ThreepidRestServlet(RestServlet):
raise SynapseError(400, "Missing param", Codes.MISSING_PARAM)
threePidCreds = body['threePidCreds']
- auth_user, _, _ = yield self.auth.get_user_by_req(request)
+ requester = yield self.auth.get_user_by_req(request)
+ user_id = requester.user.to_string()
threepid = yield self.identity_handler.threepid_from_creds(threePidCreds)
@@ -135,7 +137,7 @@ class ThreepidRestServlet(RestServlet):
raise SynapseError(500, "Invalid response from ID Server")
yield self.auth_handler.add_threepid(
- auth_user.to_string(),
+ user_id,
threepid['medium'],
threepid['address'],
threepid['validated_at'],
@@ -144,10 +146,10 @@ class ThreepidRestServlet(RestServlet):
if 'bind' in body and body['bind']:
logger.debug(
"Binding emails %s to %s",
- threepid, auth_user.to_string()
+ threepid, user_id
)
yield self.identity_handler.bind_threepid(
- threePidCreds, auth_user.to_string()
+ threePidCreds, user_id
)
defer.returnValue((200, {}))
diff --git a/synapse/rest/client/v2_alpha/account_data.py b/synapse/rest/client/v2_alpha/account_data.py
index 629b04fe7a..985efe2a62 100644
--- a/synapse/rest/client/v2_alpha/account_data.py
+++ b/synapse/rest/client/v2_alpha/account_data.py
@@ -43,8 +43,8 @@ class AccountDataServlet(RestServlet):
@defer.inlineCallbacks
def on_PUT(self, request, user_id, account_data_type):
- auth_user, _, _ = yield self.auth.get_user_by_req(request)
- if user_id != auth_user.to_string():
+ requester = yield self.auth.get_user_by_req(request)
+ if user_id != requester.user.to_string():
raise AuthError(403, "Cannot add account data for other users.")
try:
@@ -82,8 +82,8 @@ class RoomAccountDataServlet(RestServlet):
@defer.inlineCallbacks
def on_PUT(self, request, user_id, room_id, account_data_type):
- auth_user, _, _ = yield self.auth.get_user_by_req(request)
- if user_id != auth_user.to_string():
+ requester = yield self.auth.get_user_by_req(request)
+ if user_id != requester.user.to_string():
raise AuthError(403, "Cannot add account data for other users.")
try:
diff --git a/synapse/rest/client/v2_alpha/filter.py b/synapse/rest/client/v2_alpha/filter.py
index 2af7bfaf99..7695bebc28 100644
--- a/synapse/rest/client/v2_alpha/filter.py
+++ b/synapse/rest/client/v2_alpha/filter.py
@@ -40,9 +40,9 @@ class GetFilterRestServlet(RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, user_id, filter_id):
target_user = UserID.from_string(user_id)
- auth_user, _, _ = yield self.auth.get_user_by_req(request)
+ requester = yield self.auth.get_user_by_req(request)
- if target_user != auth_user:
+ if target_user != requester.user:
raise AuthError(403, "Cannot get filters for other users")
if not self.hs.is_mine(target_user):
@@ -76,9 +76,9 @@ class CreateFilterRestServlet(RestServlet):
@defer.inlineCallbacks
def on_POST(self, request, user_id):
target_user = UserID.from_string(user_id)
- auth_user, _, _ = yield self.auth.get_user_by_req(request)
+ requester = yield self.auth.get_user_by_req(request)
- if target_user != auth_user:
+ if target_user != requester.user:
raise AuthError(403, "Cannot create filters for other users")
if not self.hs.is_mine(target_user):
diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py
index 24c3554831..f989b08614 100644
--- a/synapse/rest/client/v2_alpha/keys.py
+++ b/synapse/rest/client/v2_alpha/keys.py
@@ -64,8 +64,8 @@ class KeyUploadServlet(RestServlet):
@defer.inlineCallbacks
def on_POST(self, request, device_id):
- auth_user, _, _ = yield self.auth.get_user_by_req(request)
- user_id = auth_user.to_string()
+ requester = yield self.auth.get_user_by_req(request)
+ user_id = requester.user.to_string()
# TODO: Check that the device_id matches that in the authentication
# or derive the device_id from the authentication instead.
try:
@@ -78,8 +78,8 @@ class KeyUploadServlet(RestServlet):
device_keys = body.get("device_keys", None)
if device_keys:
logger.info(
- "Updating device_keys for device %r for user %r at %d",
- device_id, auth_user, time_now
+ "Updating device_keys for device %r for user %s at %d",
+ device_id, user_id, time_now
)
# TODO: Sign the JSON with the server key
yield self.store.set_e2e_device_keys(
@@ -109,8 +109,8 @@ class KeyUploadServlet(RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, device_id):
- auth_user, _, _ = yield self.auth.get_user_by_req(request)
- user_id = auth_user.to_string()
+ requester = yield self.auth.get_user_by_req(request)
+ user_id = requester.user.to_string()
result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
defer.returnValue((200, {"one_time_key_counts": result}))
@@ -182,8 +182,8 @@ class KeyQueryServlet(RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, user_id, device_id):
- auth_user, _, _ = yield self.auth.get_user_by_req(request)
- auth_user_id = auth_user.to_string()
+ requester = yield self.auth.get_user_by_req(request)
+ auth_user_id = requester.user.to_string()
user_id = user_id if user_id else auth_user_id
device_ids = [device_id] if device_id else []
result = yield self.handle_request(
diff --git a/synapse/rest/client/v2_alpha/receipts.py b/synapse/rest/client/v2_alpha/receipts.py
index 43c23d6090..eb4b369a3d 100644
--- a/synapse/rest/client/v2_alpha/receipts.py
+++ b/synapse/rest/client/v2_alpha/receipts.py
@@ -40,7 +40,7 @@ class ReceiptRestServlet(RestServlet):
@defer.inlineCallbacks
def on_POST(self, request, room_id, receipt_type, event_id):
- user, _, _ = yield self.auth.get_user_by_req(request)
+ requester = yield self.auth.get_user_by_req(request)
if receipt_type != "m.read":
raise SynapseError(400, "Receipt type must be 'm.read'")
@@ -48,7 +48,7 @@ class ReceiptRestServlet(RestServlet):
yield self.receipts_handler.received_client_receipt(
room_id,
receipt_type,
- user_id=user.to_string(),
+ user_id=requester.user.to_string(),
event_id=event_id
)
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index 25389ceded..c4d025b465 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -259,7 +259,10 @@ class RegisterRestServlet(RestServlet):
def _do_guest_registration(self):
if not self.hs.config.allow_guest_access:
defer.returnValue((403, "Guest access is disabled"))
- user_id, _ = yield self.registration_handler.register(generate_token=False)
+ user_id, _ = yield self.registration_handler.register(
+ generate_token=False,
+ make_guest=True
+ )
access_token = self.auth_handler.generate_access_token(user_id, ["guest = true"])
defer.returnValue((200, {
"user_id": user_id,
diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py
index c05e7d50c8..e300ced214 100644
--- a/synapse/rest/client/v2_alpha/sync.py
+++ b/synapse/rest/client/v2_alpha/sync.py
@@ -85,9 +85,10 @@ class SyncRestServlet(RestServlet):
@defer.inlineCallbacks
def on_GET(self, request):
- user, token_id, is_guest = yield self.auth.get_user_by_req(
+ requester = yield self.auth.get_user_by_req(
request, allow_guest=True
)
+ user = requester.user
timeout = parse_integer(request, "timeout", default=0)
since = parse_string(request, "since")
@@ -123,7 +124,7 @@ class SyncRestServlet(RestServlet):
sync_config = SyncConfig(
user=user,
filter=filter,
- is_guest=is_guest,
+ is_guest=requester.is_guest,
)
if since is not None:
@@ -146,15 +147,15 @@ class SyncRestServlet(RestServlet):
time_now = self.clock.time_msec()
joined = self.encode_joined(
- sync_result.joined, filter, time_now, token_id
+ sync_result.joined, filter, time_now, requester.access_token_id
)
invited = self.encode_invited(
- sync_result.invited, filter, time_now, token_id
+ sync_result.invited, filter, time_now, requester.access_token_id
)
archived = self.encode_archived(
- sync_result.archived, filter, time_now, token_id
+ sync_result.archived, filter, time_now, requester.access_token_id
)
response_content = {
@@ -311,6 +312,8 @@ class SyncRestServlet(RestServlet):
if joined:
ephemeral_events = filter.filter_room_ephemeral(room.ephemeral)
result["ephemeral"] = {"events": ephemeral_events}
+ result["unread_notification_count"] = room.unread_notification_count
+ result["unread_highlight_count"] = room.unread_highlight_count
return result
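The two new per-room sync fields expose the counts derived from the event_push_actions table introduced later in this patch. A hypothetical fragment of one joined room in a /sync response (values invented):

    {
        "ephemeral": {"events": []},
        "unread_notification_count": 3,   # unread events that matched a notify action
        "unread_highlight_count": 1,      # the subset whose actions include a highlight
    }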
diff --git a/synapse/rest/client/v2_alpha/tags.py b/synapse/rest/client/v2_alpha/tags.py
index 1bfc36ab2b..42f2203f3d 100644
--- a/synapse/rest/client/v2_alpha/tags.py
+++ b/synapse/rest/client/v2_alpha/tags.py
@@ -42,8 +42,8 @@ class TagListServlet(RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, user_id, room_id):
- auth_user, _, _ = yield self.auth.get_user_by_req(request)
- if user_id != auth_user.to_string():
+ requester = yield self.auth.get_user_by_req(request)
+ if user_id != requester.user.to_string():
raise AuthError(403, "Cannot get tags for other users.")
tags = yield self.store.get_tags_for_room(user_id, room_id)
@@ -68,8 +68,8 @@ class TagServlet(RestServlet):
@defer.inlineCallbacks
def on_PUT(self, request, user_id, room_id, tag):
- auth_user, _, _ = yield self.auth.get_user_by_req(request)
- if user_id != auth_user.to_string():
+ requester = yield self.auth.get_user_by_req(request)
+ if user_id != requester.user.to_string():
raise AuthError(403, "Cannot add tags for other users.")
try:
@@ -88,8 +88,8 @@ class TagServlet(RestServlet):
@defer.inlineCallbacks
def on_DELETE(self, request, user_id, room_id, tag):
- auth_user, _, _ = yield self.auth.get_user_by_req(request)
- if user_id != auth_user.to_string():
+ requester = yield self.auth.get_user_by_req(request)
+ if user_id != requester.user.to_string():
raise AuthError(403, "Cannot add tags for other users.")
max_id = yield self.store.remove_tag_from_room(user_id, room_id, tag)
diff --git a/synapse/rest/media/v0/content_repository.py b/synapse/rest/media/v0/content_repository.py
index dd7a1b2b31..dcf3eaee1f 100644
--- a/synapse/rest/media/v0/content_repository.py
+++ b/synapse/rest/media/v0/content_repository.py
@@ -66,11 +66,11 @@ class ContentRepoResource(resource.Resource):
@defer.inlineCallbacks
def map_request_to_name(self, request):
# auth the user
- auth_user, _, _ = yield self.auth.get_user_by_req(request)
+ requester = yield self.auth.get_user_by_req(request)
# namespace all file uploads on the user
prefix = base64.urlsafe_b64encode(
- auth_user.to_string()
+ requester.user.to_string()
).replace('=', '')
# use a random string for the main portion
@@ -94,7 +94,7 @@ class ContentRepoResource(resource.Resource):
file_name = prefix + main_part + suffix
file_path = os.path.join(self.directory, file_name)
logger.info("User %s is uploading a file to path %s",
- auth_user.to_string(),
+ requester.user.to_string(),
file_path)
# keep trying to make a non-clashing file, with a sensible max attempts
diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py
index c18160534e..ab52499785 100644
--- a/synapse/rest/media/v1/thumbnail_resource.py
+++ b/synapse/rest/media/v1/thumbnail_resource.py
@@ -253,7 +253,7 @@ class ThumbnailResource(BaseMediaResource):
t_w = info["thumbnail_width"]
t_h = info["thumbnail_height"]
t_method = info["thumbnail_method"]
- if t_method == "scale" or t_method == "crop":
+ if t_method == "crop":
aspect_quality = abs(d_w * t_h - d_h * t_w)
min_quality = 0 if d_w <= t_w and d_h <= t_h else 1
size_quality = abs((d_w - t_w) * (d_h - t_h))
diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py
index c1e895ee81..9c7ad4ae85 100644
--- a/synapse/rest/media/v1/upload_resource.py
+++ b/synapse/rest/media/v1/upload_resource.py
@@ -70,7 +70,7 @@ class UploadResource(BaseMediaResource):
@request_handler
@defer.inlineCallbacks
def _async_render_POST(self, request):
- auth_user, _, _ = yield self.auth.get_user_by_req(request)
+ requester = yield self.auth.get_user_by_req(request)
# TODO: The checks here are a bit late. The content will have
# already been uploaded to a tmp file at this point
content_length = request.getHeader("Content-Length")
@@ -110,7 +110,7 @@ class UploadResource(BaseMediaResource):
content_uri = yield self.create_content(
media_type, upload_name, request.content.read(),
- content_length, auth_user
+ content_length, requester.user
)
respond_with_json(
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 2b650f9fa3..7a3f6c4662 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -33,6 +33,7 @@ from .pusher import PusherStore
from .push_rule import PushRuleStore
from .media_repository import MediaRepositoryStore
from .rejections import RejectionsStore
+from .event_push_actions import EventPushActionsStore
from .state import StateStore
from .signatures import SignatureStore
@@ -75,6 +76,7 @@ class DataStore(RoomMemberStore, RoomStore,
SearchStore,
TagsStore,
AccountDataStore,
+ EventPushActionsStore
):
def __init__(self, hs):
diff --git a/synapse/storage/appservice.py b/synapse/storage/appservice.py
index eab58d9ce9..b5aa55c0a3 100644
--- a/synapse/storage/appservice.py
+++ b/synapse/storage/appservice.py
@@ -15,12 +15,12 @@
import logging
import urllib
import yaml
-from simplejson import JSONDecodeError
import simplejson as json
from twisted.internet import defer
from synapse.api.constants import Membership
from synapse.appservice import ApplicationService, AppServiceTransaction
+from synapse.config._base import ConfigError
from synapse.storage.roommember import RoomsForUser
from synapse.types import UserID
from ._base import SQLBaseStore
@@ -144,66 +144,9 @@ class ApplicationServiceStore(SQLBaseStore):
return rooms_for_user_matching_user_id
- def _parse_services_dict(self, results):
- # SQL results in the form:
- # [
- # {
- # 'regex': "something",
- # 'url': "something",
- # 'namespace': enum,
- # 'as_id': 0,
- # 'token': "something",
- # 'hs_token': "otherthing",
- # 'id': 0
- # }
- # ]
- services = {}
- for res in results:
- as_token = res["token"]
- if as_token is None:
- continue
- if as_token not in services:
- # add the service
- services[as_token] = {
- "id": res["id"],
- "url": res["url"],
- "token": as_token,
- "hs_token": res["hs_token"],
- "sender": res["sender"],
- "namespaces": {
- ApplicationService.NS_USERS: [],
- ApplicationService.NS_ALIASES: [],
- ApplicationService.NS_ROOMS: []
- }
- }
- # add the namespace regex if one exists
- ns_int = res["namespace"]
- if ns_int is None:
- continue
- try:
- services[as_token]["namespaces"][
- ApplicationService.NS_LIST[ns_int]].append(
- json.loads(res["regex"])
- )
- except IndexError:
- logger.error("Bad namespace enum '%s'. %s", ns_int, res)
- except JSONDecodeError:
- logger.error("Bad regex object '%s'", res["regex"])
-
- service_list = []
- for service in services.values():
- service_list.append(ApplicationService(
- token=service["token"],
- url=service["url"],
- namespaces=service["namespaces"],
- hs_token=service["hs_token"],
- sender=service["sender"],
- id=service["id"]
- ))
- return service_list
-
def _load_appservice(self, as_info):
required_string_fields = [
+ # TODO: Add "id" here once it is stable enough to make required
"url", "as_token", "hs_token", "sender_localpart"
]
for field in required_string_fields:
@@ -245,7 +188,7 @@ class ApplicationServiceStore(SQLBaseStore):
namespaces=as_info["namespaces"],
hs_token=as_info["hs_token"],
sender=user_id,
- id=as_info["as_token"] # the token is the only unique thing here
+ id=as_info["id"] if "id" in as_info else as_info["as_token"],
)
def _populate_appservice_cache(self, config_files):
@@ -256,15 +199,38 @@ class ApplicationServiceStore(SQLBaseStore):
)
return
+ # Dicts of value -> filename
+ seen_as_tokens = {}
+ seen_ids = {}
+
for config_file in config_files:
try:
with open(config_file, 'r') as f:
appservice = self._load_appservice(yaml.load(f))
+ if appservice.id in seen_ids:
+ raise ConfigError(
+ "Cannot reuse ID across application services: "
+ "%s (files: %s, %s)" % (
+ appservice.id, config_file, seen_ids[appservice.id],
+ )
+ )
+ seen_ids[appservice.id] = config_file
+ if appservice.token in seen_as_tokens:
+ raise ConfigError(
+ "Cannot reuse as_token across application services: "
+ "%s (files: %s, %s)" % (
+ appservice.token,
+ config_file,
+ seen_as_tokens[appservice.token],
+ )
+ )
+ seen_as_tokens[appservice.token] = config_file
logger.info("Loaded application service: %s", appservice)
self.services_cache.append(appservice)
except Exception as e:
logger.error("Failed to load appservice from '%s'", config_file)
logger.exception(e)
+ raise
class ApplicationServiceTransactionStore(SQLBaseStore):
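Duplicate ids or as_tokens across registration files are now a hard ConfigError at startup rather than being silently accepted. A sketch of two parsed registrations (all values hypothetical) that the new checks reject:

    bridge_a = {
        "id": "irc-bridge",
        "url": "http://localhost:8090",
        "as_token": "SHARED_TOKEN",
        "hs_token": "hs-secret-a",
        "sender_localpart": "irc_bot",
        "namespaces": {"users": [], "aliases": [], "rooms": []},
    }
    # Same as_token, different id: loading both files raises
    # "Cannot reuse as_token across application services: ..."
    bridge_b = dict(bridge_a, id="xmpp-bridge", url="http://localhost:8091")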
diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/event_push_actions.py
new file mode 100644
index 0000000000..6b7cebc9ce
--- /dev/null
+++ b/synapse/storage/event_push_actions.py
@@ -0,0 +1,107 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ._base import SQLBaseStore
+from twisted.internet import defer
+
+import logging
+import ujson as json
+
+logger = logging.getLogger(__name__)
+
+
+class EventPushActionsStore(SQLBaseStore):
+ @defer.inlineCallbacks
+ def set_push_actions_for_event_and_users(self, event, tuples):
+ """
+ :param event: the event to set actions for
+ :param tuples: list of tuples of (user_id, profile_tag, actions)
+ """
+ values = []
+ for uid, profile_tag, actions in tuples:
+ values.append({
+ 'room_id': event.room_id,
+ 'event_id': event.event_id,
+ 'user_id': uid,
+ 'profile_tag': profile_tag,
+ 'actions': json.dumps(actions)
+ })
+
+ yield self.runInteraction(
+ "set_actions_for_event_and_users",
+ self._simple_insert_many_txn,
+ "event_push_actions",
+ values
+ )
+
+ @defer.inlineCallbacks
+ def get_unread_event_push_actions_by_room_for_user(
+ self, room_id, user_id, last_read_event_id
+ ):
+ def _get_unread_event_push_actions_by_room(txn):
+ sql = (
+ "SELECT stream_ordering, topological_ordering"
+ " FROM events"
+ " WHERE room_id = ? AND event_id = ?"
+ )
+ txn.execute(
+ sql, (room_id, last_read_event_id)
+ )
+ results = txn.fetchall()
+ if len(results) == 0:
+ return []
+
+ stream_ordering = results[0][0]
+ topological_ordering = results[0][1]
+
+ sql = (
+ "SELECT ea.event_id, ea.actions"
+ " FROM event_push_actions ea, events e"
+ " WHERE ea.room_id = e.room_id"
+ " AND ea.event_id = e.event_id"
+ " AND ea.user_id = ?"
+ " AND ea.room_id = ?"
+ " AND ("
+ " e.topological_ordering > ?"
+ " OR (e.topological_ordering = ? AND e.stream_ordering > ?)"
+ ")"
+ )
+ txn.execute(sql, (
+ user_id, room_id,
+ topological_ordering, topological_ordering, stream_ordering
+ )
+ )
+ return [
+ {"event_id": row[0], "actions": json.loads(row[1])}
+ for row in txn.fetchall()
+ ]
+
+ ret = yield self.runInteraction(
+ "get_unread_event_push_actions_by_room",
+ _get_unread_event_push_actions_by_room
+ )
+ defer.returnValue(ret)
+
+ @defer.inlineCallbacks
+ def remove_push_actions_for_event_id(self, room_id, event_id):
+ def f(txn):
+ txn.execute(
+ "DELETE FROM event_push_actions WHERE room_id = ? AND event_id = ?",
+ (room_id, event_id)
+ )
+ yield self.runInteraction(
+ "remove_push_actions_for_event_id",
+ f
+ )
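A minimal usage sketch of the new store (hypothetical caller; the action list is only an example shape, and the usual twisted defer import is assumed):

    @defer.inlineCallbacks
    def example(store, event, last_read_event_id):
        # Record the actions computed for one event, for one user.
        yield store.set_push_actions_for_event_and_users(
            event,
            [("@alice:example.com", None, ["notify", {"set_tweak": "highlight"}])],
        )

        # Later: everything after @alice's read receipt in that room.
        unread = yield store.get_unread_event_push_actions_by_room_for_user(
            event.room_id, "@alice:example.com", last_read_event_id
        )
        # Each entry is {"event_id": ..., "actions": [...]}.
        defer.returnValue(len(unread))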
diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py
index a913ea7c50..2adfefd994 100644
--- a/synapse/storage/push_rule.py
+++ b/synapse/storage/push_rule.py
@@ -25,13 +25,16 @@ logger = logging.getLogger(__name__)
class PushRuleStore(SQLBaseStore):
@cachedInlineCallbacks()
- def get_push_rules_for_user(self, user_name):
+ def get_push_rules_for_user(self, user_id):
rows = yield self._simple_select_list(
- table=PushRuleTable.table_name,
+ table="push_rules",
keyvalues={
- "user_name": user_name,
+ "user_name": user_id,
},
- retcols=PushRuleTable.fields,
+ retcols=(
+ "user_name", "rule_id", "priority_class", "priority",
+ "conditions", "actions",
+ ),
desc="get_push_rules_enabled_for_user",
)
@@ -42,13 +45,15 @@ class PushRuleStore(SQLBaseStore):
defer.returnValue(rows)
@cachedInlineCallbacks()
- def get_push_rules_enabled_for_user(self, user_name):
+ def get_push_rules_enabled_for_user(self, user_id):
results = yield self._simple_select_list(
- table=PushRuleEnableTable.table_name,
+ table="push_rules_enable",
keyvalues={
- 'user_name': user_name
+ 'user_name': user_id
},
- retcols=PushRuleEnableTable.fields,
+ retcols=(
+ "user_name", "rule_id", "enabled",
+ ),
desc="get_push_rules_enabled_for_user",
)
defer.returnValue({
@@ -56,6 +61,39 @@ class PushRuleStore(SQLBaseStore):
})
@defer.inlineCallbacks
+ def bulk_get_push_rules(self, user_ids):
+ if not user_ids:
+ defer.returnValue({})
+
+ batch_size = 100
+
+ def f(txn, user_ids_to_fetch):
+ sql = (
+ "SELECT pr.*"
+ " FROM push_rules AS pr"
+ " LEFT JOIN push_rules_enable AS pre"
+ " ON pr.user_name = pre.user_name AND pr.rule_id = pre.rule_id"
+ " WHERE pr.user_name"
+ " IN (" + ",".join("?" for _ in user_ids_to_fetch) + ")"
+ " AND (pre.enabled IS NULL OR pre.enabled = 1)"
+ " ORDER BY pr.user_name, pr.priority_class DESC, pr.priority DESC"
+ )
+ txn.execute(sql, user_ids_to_fetch)
+ return self.cursor_to_dict(txn)
+
+ results = {}
+
+ chunks = [user_ids[i:i+batch_size] for i in xrange(0, len(user_ids), batch_size)]
+ for batch_user_ids in chunks:
+ rows = yield self.runInteraction(
+ "bulk_get_push_rules", f, batch_user_ids
+ )
+
+ for row in rows:
+ results.setdefault(row['user_name'], []).append(row)
+ defer.returnValue(results)
+
+ @defer.inlineCallbacks
def add_push_rule(self, before, after, **kwargs):
vals = kwargs
if 'conditions' in vals:
@@ -84,15 +122,15 @@ class PushRuleStore(SQLBaseStore):
)
defer.returnValue(ret)
- def _add_push_rule_relative_txn(self, txn, user_name, **kwargs):
+ def _add_push_rule_relative_txn(self, txn, user_id, **kwargs):
after = kwargs.pop("after", None)
relative_to_rule = kwargs.pop("before", after)
res = self._simple_select_one_txn(
txn,
- table=PushRuleTable.table_name,
+ table="push_rules",
keyvalues={
- "user_name": user_name,
+ "user_name": user_id,
"rule_id": relative_to_rule,
},
retcols=["priority_class", "priority"],
@@ -116,7 +154,7 @@ class PushRuleStore(SQLBaseStore):
new_rule.pop("before", None)
new_rule.pop("after", None)
new_rule['priority_class'] = priority_class
- new_rule['user_name'] = user_name
+ new_rule['user_name'] = user_id
new_rule['id'] = self._push_rule_id_gen.get_next_txn(txn)
# check if the priority before/after is free
@@ -129,16 +167,16 @@ class PushRuleStore(SQLBaseStore):
new_rule['priority'] = new_rule_priority
sql = (
- "SELECT COUNT(*) FROM " + PushRuleTable.table_name +
+ "SELECT COUNT(*) FROM push_rules"
" WHERE user_name = ? AND priority_class = ? AND priority = ?"
)
- txn.execute(sql, (user_name, priority_class, new_rule_priority))
+ txn.execute(sql, (user_id, priority_class, new_rule_priority))
res = txn.fetchall()
num_conflicting = res[0][0]
# if there are conflicting rules, bump everything
if num_conflicting:
- sql = "UPDATE "+PushRuleTable.table_name+" SET priority = priority "
+ sql = "UPDATE push_rules SET priority = priority "
if after:
sql += "-1"
else:
@@ -149,30 +187,30 @@ class PushRuleStore(SQLBaseStore):
else:
sql += ">= ?"
- txn.execute(sql, (user_name, priority_class, new_rule_priority))
+ txn.execute(sql, (user_id, priority_class, new_rule_priority))
txn.call_after(
- self.get_push_rules_for_user.invalidate, (user_name,)
+ self.get_push_rules_for_user.invalidate, (user_id,)
)
txn.call_after(
- self.get_push_rules_enabled_for_user.invalidate, (user_name,)
+ self.get_push_rules_enabled_for_user.invalidate, (user_id,)
)
self._simple_insert_txn(
txn,
- table=PushRuleTable.table_name,
+ table="push_rules",
values=new_rule,
)
- def _add_push_rule_highest_priority_txn(self, txn, user_name,
+ def _add_push_rule_highest_priority_txn(self, txn, user_id,
priority_class, **kwargs):
# find the highest priority rule in that class
sql = (
- "SELECT COUNT(*), MAX(priority) FROM " + PushRuleTable.table_name +
+ "SELECT COUNT(*), MAX(priority) FROM push_rules"
" WHERE user_name = ? and priority_class = ?"
)
- txn.execute(sql, (user_name, priority_class))
+ txn.execute(sql, (user_id, priority_class))
res = txn.fetchall()
(how_many, highest_prio) = res[0]
@@ -183,66 +221,66 @@ class PushRuleStore(SQLBaseStore):
# and insert the new rule
new_rule = kwargs
new_rule['id'] = self._push_rule_id_gen.get_next_txn(txn)
- new_rule['user_name'] = user_name
+ new_rule['user_name'] = user_id
new_rule['priority_class'] = priority_class
new_rule['priority'] = new_prio
txn.call_after(
- self.get_push_rules_for_user.invalidate, (user_name,)
+ self.get_push_rules_for_user.invalidate, (user_id,)
)
txn.call_after(
- self.get_push_rules_enabled_for_user.invalidate, (user_name,)
+ self.get_push_rules_enabled_for_user.invalidate, (user_id,)
)
self._simple_insert_txn(
txn,
- table=PushRuleTable.table_name,
+ table="push_rules",
values=new_rule,
)
@defer.inlineCallbacks
- def delete_push_rule(self, user_name, rule_id):
+ def delete_push_rule(self, user_id, rule_id):
"""
Delete a push rule. Args specify the row to be deleted and can be
any of the columns in the push_rules table, but below are the
standard ones.
Args:
- user_name (str): The matrix ID of the push rule owner
+ user_id (str): The matrix ID of the push rule owner
rule_id (str): The rule_id of the rule to be deleted
"""
yield self._simple_delete_one(
- PushRuleTable.table_name,
- {'user_name': user_name, 'rule_id': rule_id},
+ "push_rules",
+ {'user_name': user_id, 'rule_id': rule_id},
desc="delete_push_rule",
)
- self.get_push_rules_for_user.invalidate((user_name,))
- self.get_push_rules_enabled_for_user.invalidate((user_name,))
+ self.get_push_rules_for_user.invalidate((user_id,))
+ self.get_push_rules_enabled_for_user.invalidate((user_id,))
@defer.inlineCallbacks
- def set_push_rule_enabled(self, user_name, rule_id, enabled):
+ def set_push_rule_enabled(self, user_id, rule_id, enabled):
ret = yield self.runInteraction(
"_set_push_rule_enabled_txn",
self._set_push_rule_enabled_txn,
- user_name, rule_id, enabled
+ user_id, rule_id, enabled
)
defer.returnValue(ret)
- def _set_push_rule_enabled_txn(self, txn, user_name, rule_id, enabled):
+ def _set_push_rule_enabled_txn(self, txn, user_id, rule_id, enabled):
new_id = self._push_rules_enable_id_gen.get_next_txn(txn)
self._simple_upsert_txn(
txn,
- PushRuleEnableTable.table_name,
- {'user_name': user_name, 'rule_id': rule_id},
+ "push_rules_enable",
+ {'user_name': user_id, 'rule_id': rule_id},
{'enabled': 1 if enabled else 0},
{'id': new_id},
)
txn.call_after(
- self.get_push_rules_for_user.invalidate, (user_name,)
+ self.get_push_rules_for_user.invalidate, (user_id,)
)
txn.call_after(
- self.get_push_rules_enabled_for_user.invalidate, (user_name,)
+ self.get_push_rules_enabled_for_user.invalidate, (user_id,)
)
@@ -252,27 +290,3 @@ class RuleNotFoundException(Exception):
class InconsistentRuleException(Exception):
pass
-
-
-class PushRuleTable(object):
- table_name = "push_rules"
-
- fields = [
- "id",
- "user_name",
- "rule_id",
- "priority_class",
- "priority",
- "conditions",
- "actions",
- ]
-
-
-class PushRuleEnableTable(object):
- table_name = "push_rules_enable"
-
- fields = [
- "user_name",
- "rule_id",
- "enabled"
- ]
diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py
index b9568dad26..8ec706178a 100644
--- a/synapse/storage/pusher.py
+++ b/synapse/storage/pusher.py
@@ -80,17 +80,17 @@ class PusherStore(SQLBaseStore):
defer.returnValue(rows)
@defer.inlineCallbacks
- def add_pusher(self, user_name, access_token, profile_tag, kind, app_id,
+ def add_pusher(self, user_id, access_token, profile_tag, kind, app_id,
app_display_name, device_display_name,
pushkey, pushkey_ts, lang, data):
try:
next_id = yield self._pushers_id_gen.get_next()
yield self._simple_upsert(
- PushersTable.table_name,
+ "pushers",
dict(
app_id=app_id,
pushkey=pushkey,
- user_name=user_name,
+ user_name=user_id,
),
dict(
access_token=access_token,
@@ -112,42 +112,38 @@ class PusherStore(SQLBaseStore):
raise StoreError(500, "Problem creating pusher.")
@defer.inlineCallbacks
- def delete_pusher_by_app_id_pushkey_user_name(self, app_id, pushkey, user_name):
+ def delete_pusher_by_app_id_pushkey_user_id(self, app_id, pushkey, user_id):
yield self._simple_delete_one(
- PushersTable.table_name,
- {"app_id": app_id, "pushkey": pushkey, 'user_name': user_name},
- desc="delete_pusher_by_app_id_pushkey_user_name",
+ "pushers",
+ {"app_id": app_id, "pushkey": pushkey, 'user_name': user_id},
+ desc="delete_pusher_by_app_id_pushkey_user_id",
)
@defer.inlineCallbacks
- def update_pusher_last_token(self, app_id, pushkey, user_name, last_token):
+ def update_pusher_last_token(self, app_id, pushkey, user_id, last_token):
yield self._simple_update_one(
- PushersTable.table_name,
- {'app_id': app_id, 'pushkey': pushkey, 'user_name': user_name},
+ "pushers",
+ {'app_id': app_id, 'pushkey': pushkey, 'user_name': user_id},
{'last_token': last_token},
desc="update_pusher_last_token",
)
@defer.inlineCallbacks
- def update_pusher_last_token_and_success(self, app_id, pushkey, user_name,
+ def update_pusher_last_token_and_success(self, app_id, pushkey, user_id,
last_token, last_success):
yield self._simple_update_one(
- PushersTable.table_name,
- {'app_id': app_id, 'pushkey': pushkey, 'user_name': user_name},
+ "pushers",
+ {'app_id': app_id, 'pushkey': pushkey, 'user_name': user_id},
{'last_token': last_token, 'last_success': last_success},
desc="update_pusher_last_token_and_success",
)
@defer.inlineCallbacks
- def update_pusher_failing_since(self, app_id, pushkey, user_name,
+ def update_pusher_failing_since(self, app_id, pushkey, user_id,
failing_since):
yield self._simple_update_one(
- PushersTable.table_name,
- {'app_id': app_id, 'pushkey': pushkey, 'user_name': user_name},
+ "pushers",
+ {'app_id': app_id, 'pushkey': pushkey, 'user_name': user_id},
{'failing_since': failing_since},
desc="update_pusher_failing_since",
)
-
-
-class PushersTable(object):
- table_name = "pushers"
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index f0fa0bd33c..70cde0d04d 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -18,7 +18,7 @@ from twisted.internet import defer
from synapse.api.errors import StoreError, Codes
from ._base import SQLBaseStore
-from synapse.util.caches.descriptors import cached
+from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
class RegistrationStore(SQLBaseStore):
@@ -73,7 +73,8 @@ class RegistrationStore(SQLBaseStore):
)
@defer.inlineCallbacks
- def register(self, user_id, token, password_hash, was_guest=False):
+ def register(self, user_id, token, password_hash,
+ was_guest=False, make_guest=False):
"""Attempts to register an account.
Args:
@@ -82,15 +83,18 @@ class RegistrationStore(SQLBaseStore):
password_hash (str): Optional. The password hash for this user.
was_guest (bool): Optional. Whether this is a guest account being
upgraded to a non-guest account.
+ make_guest (bool): True if the new user should be registered as
+ a guest, False to create a regular user account.
Raises:
StoreError if the user_id could not be registered.
"""
yield self.runInteraction(
"register",
- self._register, user_id, token, password_hash, was_guest
+ self._register, user_id, token, password_hash, was_guest, make_guest
)
+ self.is_guest.invalidate((user_id,))
- def _register(self, txn, user_id, token, password_hash, was_guest):
+ def _register(self, txn, user_id, token, password_hash, was_guest, make_guest):
now = int(self.clock.time())
next_id = self._access_tokens_id_gen.get_next_txn(txn)
@@ -99,13 +103,15 @@ class RegistrationStore(SQLBaseStore):
if was_guest:
txn.execute("UPDATE users SET"
" password_hash = ?,"
- " upgrade_ts = ?"
+ " upgrade_ts = ?,"
+ " is_guest = ?"
" WHERE name = ?",
- [password_hash, now, user_id])
+ [password_hash, now, 1 if make_guest else 0, user_id])
else:
- txn.execute("INSERT INTO users(name, password_hash, creation_ts) "
- "VALUES (?,?,?)",
- [user_id, password_hash, now])
+ txn.execute("INSERT INTO users "
+ "(name, password_hash, creation_ts, is_guest) "
+ "VALUES (?,?,?,?)",
+ [user_id, password_hash, now, 1 if make_guest else 0])
except self.database_engine.module.IntegrityError:
raise StoreError(
400, "User ID already taken.", errcode=Codes.USER_IN_USE
@@ -126,7 +132,7 @@ class RegistrationStore(SQLBaseStore):
keyvalues={
"name": user_id,
},
- retcols=["name", "password_hash"],
+ retcols=["name", "password_hash", "is_guest"],
allow_none=True,
)
@@ -249,9 +255,41 @@ class RegistrationStore(SQLBaseStore):
defer.returnValue(res if res else False)
+ @cachedInlineCallbacks()
+ def is_guest(self, user_id):
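+ """Returns whether the given user is registered as a guest."""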
+ res = yield self._simple_select_one_onecol(
+ table="users",
+ keyvalues={"name": user_id},
+ retcol="is_guest",
+ allow_none=True,
+ desc="is_guest",
+ )
+
+ defer.returnValue(res if res else False)
+
+ @cachedList(cache=is_guest.cache, list_name="user_ids", num_args=1,
+ inlineCallbacks=True)
+ def are_guests(self, user_ids):
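+ """Bulk version of is_guest: returns a dict of user_id -> bool,
+ defaulting to False for unknown users."""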
+ sql = "SELECT name, is_guest FROM users WHERE name IN (%s)" % (
+ ",".join("?" for _ in user_ids),
+ )
+
+ rows = yield self._execute(
+ "are_guests", self.cursor_to_dict, sql, *user_ids
+ )
+
+ result = {user_id: False for user_id in user_ids}
+
+ result.update({
+ row["name"]: bool(row["is_guest"])
+ for row in rows
+ })
+
+ defer.returnValue(result)
+
def _query_for_auth(self, txn, token):
sql = (
- "SELECT users.name, access_tokens.id as token_id"
+ "SELECT users.name, users.is_guest, access_tokens.id as token_id"
" FROM users"
" INNER JOIN access_tokens on users.name = access_tokens.user_id"
" WHERE token = ?"
diff --git a/synapse/storage/room.py b/synapse/storage/room.py
index 390bd78654..dc09a3aaba 100644
--- a/synapse/storage/room.py
+++ b/synapse/storage/room.py
@@ -49,7 +49,7 @@ class RoomStore(SQLBaseStore):
"""
try:
yield self._simple_insert(
- RoomsTable.table_name,
+ "rooms",
{
"room_id": room_id,
"creator": room_creator_user_id,
@@ -70,9 +70,9 @@ class RoomStore(SQLBaseStore):
A namedtuple containing the room information, or None if the
room is unknown.
"""
return self._simple_select_one(
- table=RoomsTable.table_name,
+ table="rooms",
keyvalues={"room_id": room_id},
- retcols=RoomsTable.fields,
+ retcols=("room_id", "is_public", "creator"),
desc="get_room",
allow_none=True,
)
@@ -275,13 +275,3 @@ class RoomStore(SQLBaseStore):
aliases.extend(e.content['aliases'])
defer.returnValue((name, aliases))
-
-
-class RoomsTable(object):
- table_name = "rooms"
-
- fields = [
- "room_id",
- "is_public",
- "creator"
- ]
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 7d3ce4579d..68ac88905f 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -287,6 +287,7 @@ class RoomMemberStore(SQLBaseStore):
txn.execute(sql, (user_id, room_id))
yield self.runInteraction("forget_membership", f)
self.was_forgotten_at.invalidate_all()
+ self.who_forgot_in_room.invalidate_all()
self.did_forget.invalidate((user_id, room_id))
@cachedInlineCallbacks(num_args=2)
@@ -336,3 +337,15 @@ class RoomMemberStore(SQLBaseStore):
return rows[0][0]
forgot = yield self.runInteraction("did_forget_membership_at", f)
defer.returnValue(forgot == 1)
+
+ @cached()
+ def who_forgot_in_room(self, room_id):
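+ """Returns the user_id and event_id of every membership in the
+ room that has been marked as forgotten."""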
+ return self._simple_select_list(
+ table="room_memberships",
+ retcols=("user_id", "event_id"),
+ keyvalues={
+ "room_id": room_id,
+ "forgotten": 1,
+ },
+ desc="who_forgot"
+ )
diff --git a/synapse/storage/schema/delta/28/event_push_actions.sql b/synapse/storage/schema/delta/28/event_push_actions.sql
new file mode 100644
index 0000000000..bdf6ae3f24
--- /dev/null
+++ b/synapse/storage/schema/delta/28/event_push_actions.sql
@@ -0,0 +1,26 @@
+/* Copyright 2015 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
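+/* Stores the push actions calculated for each event and user;
+ * uniqueness is enforced per (room_id, event_id, user_id, profile_tag).
+ */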
+CREATE TABLE IF NOT EXISTS event_push_actions(
+ room_id TEXT NOT NULL,
+ event_id TEXT NOT NULL,
+ user_id TEXT NOT NULL,
+ profile_tag VARCHAR(32),
+ actions TEXT NOT NULL,
+ CONSTRAINT event_id_user_id_profile_tag_uniqueness UNIQUE (room_id, event_id, user_id, profile_tag)
+);
+
+
+CREATE INDEX event_push_actions_room_id_event_id_user_id_profile_tag on event_push_actions(room_id, event_id, user_id, profile_tag);
diff --git a/synapse/storage/schema/delta/28/users_is_guest.sql b/synapse/storage/schema/delta/28/users_is_guest.sql
new file mode 100644
index 0000000000..21d2b420bf
--- /dev/null
+++ b/synapse/storage/schema/delta/28/users_is_guest.sql
@@ -0,0 +1,22 @@
+/* Copyright 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ALTER TABLE users ADD is_guest SMALLINT DEFAULT 0 NOT NULL;
+/*
+ * NB: any guest users created between schema versions 27 and 28 will be
+ * incorrectly marked as non-guests: we don't bother to backfill these
+ * rows because guest access was not fully functional in 27, so it is
+ * very unlikely that any guest users were created.
+ */
diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py
index b40a070b69..4475c451c1 100644
--- a/synapse/storage/transactions.py
+++ b/synapse/storage/transactions.py
@@ -16,8 +16,6 @@
from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cached
-from collections import namedtuple
-
from canonicaljson import encode_canonical_json
import logging
@@ -50,12 +48,15 @@ class TransactionStore(SQLBaseStore):
def _get_received_txn_response(self, txn, transaction_id, origin):
result = self._simple_select_one_txn(
txn,
- table=ReceivedTransactionsTable.table_name,
+ table="received_transactions",
keyvalues={
"transaction_id": transaction_id,
"origin": origin,
},
- retcols=ReceivedTransactionsTable.fields,
+ retcols=(
+ "transaction_id", "origin", "ts", "response_code", "response_json",
+ "has_been_referenced",
+ ),
allow_none=True,
)
@@ -79,7 +80,7 @@ class TransactionStore(SQLBaseStore):
"""
return self._simple_insert(
- table=ReceivedTransactionsTable.table_name,
+ table="received_transactions",
values={
"transaction_id": transaction_id,
"origin": origin,
@@ -136,7 +137,7 @@ class TransactionStore(SQLBaseStore):
self._simple_insert_txn(
txn,
- table=SentTransactions.table_name,
+ table="sent_transactions",
values={
"id": next_id,
"transaction_id": transaction_id,
@@ -171,7 +172,7 @@ class TransactionStore(SQLBaseStore):
code, response_json):
self._simple_update_one_txn(
txn,
- table=SentTransactions.table_name,
+ table="sent_transactions",
keyvalues={
"transaction_id": transaction_id,
"destination": destination,
@@ -229,11 +230,11 @@ class TransactionStore(SQLBaseStore):
def _get_destination_retry_timings(self, txn, destination):
result = self._simple_select_one_txn(
txn,
- table=DestinationsTable.table_name,
+ table="destinations",
keyvalues={
"destination": destination,
},
- retcols=DestinationsTable.fields,
+ retcols=("destination", "retry_last_ts", "retry_interval"),
allow_none=True,
)
@@ -304,52 +305,3 @@ class TransactionStore(SQLBaseStore):
txn.execute(query, (self._clock.time_msec(),))
return self.cursor_to_dict(txn)
-
-
-class ReceivedTransactionsTable(object):
- table_name = "received_transactions"
-
- fields = [
- "transaction_id",
- "origin",
- "ts",
- "response_code",
- "response_json",
- "has_been_referenced",
- ]
-
-
-class SentTransactions(object):
- table_name = "sent_transactions"
-
- fields = [
- "id",
- "transaction_id",
- "destination",
- "ts",
- "response_code",
- "response_json",
- ]
-
- EntryType = namedtuple("SentTransactionsEntry", fields)
-
-
-class TransactionsToPduTable(object):
- table_name = "transaction_id_to_pdu"
-
- fields = [
- "transaction_id",
- "destination",
- "pdu_id",
- "pdu_origin",
- ]
-
-
-class DestinationsTable(object):
- table_name = "destinations"
-
- fields = [
- "destination",
- "retry_last_ts",
- "retry_interval",
- ]
diff --git a/synapse/streams/config.py b/synapse/streams/config.py
index 8c082bf4e0..4f089bfb94 100644
--- a/synapse/streams/config.py
+++ b/synapse/streams/config.py
@@ -22,6 +22,9 @@ import logging
logger = logging.getLogger(__name__)
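+# Any client-supplied pagination limit is clamped to this maximum.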
+MAX_LIMIT = 1000
+
+
class SourcePaginationConfig(object):
"""A configuration object which stores pagination parameters for a
@@ -32,7 +35,7 @@ class SourcePaginationConfig(object):
self.from_key = from_key
self.to_key = to_key
self.direction = 'f' if direction == 'f' else 'b'
- self.limit = int(limit) if limit is not None else None
+ self.limit = min(int(limit), MAX_LIMIT) if limit is not None else None
def __repr__(self):
return (
@@ -49,7 +52,7 @@ class PaginationConfig(object):
self.from_token = from_token
self.to_token = to_token
self.direction = 'f' if direction == 'f' else 'b'
- self.limit = int(limit) if limit is not None else None
+ self.limit = min(int(limit), MAX_LIMIT) if limit is not None else None
@classmethod
def from_request(cls, request, raise_invalid_params=True,
diff --git a/synapse/types.py b/synapse/types.py
index 1ec7b3e103..2095837ba6 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -18,6 +18,9 @@ from synapse.api.errors import SynapseError
from collections import namedtuple
+Requester = namedtuple("Requester", ["user", "access_token_id", "is_guest"])
+
+
class DomainSpecificString(
namedtuple("DomainSpecificString", ("localpart", "domain"))
):
diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index 5ff4c8a873..474c5c418f 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -51,8 +51,8 @@ class AuthTestCase(unittest.TestCase):
request = Mock(args={})
request.args["access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = Mock(return_value=[""])
- (user, _, _) = yield self.auth.get_user_by_req(request)
- self.assertEquals(user.to_string(), self.test_user)
+ requester = yield self.auth.get_user_by_req(request)
+ self.assertEquals(requester.user.to_string(), self.test_user)
def test_get_user_by_req_user_bad_token(self):
self.store.get_app_service_by_token = Mock(return_value=None)
@@ -86,8 +86,8 @@ class AuthTestCase(unittest.TestCase):
request = Mock(args={})
request.args["access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = Mock(return_value=[""])
- (user, _, _) = yield self.auth.get_user_by_req(request)
- self.assertEquals(user.to_string(), self.test_user)
+ requester = yield self.auth.get_user_by_req(request)
+ self.assertEquals(requester.user.to_string(), self.test_user)
def test_get_user_by_req_appservice_bad_token(self):
self.store.get_app_service_by_token = Mock(return_value=None)
@@ -121,8 +121,8 @@ class AuthTestCase(unittest.TestCase):
request.args["access_token"] = [self.test_token]
request.args["user_id"] = [masquerading_user_id]
request.requestHeaders.getRawHeaders = Mock(return_value=[""])
- (user, _, _) = yield self.auth.get_user_by_req(request)
- self.assertEquals(user.to_string(), masquerading_user_id)
+ requester = yield self.auth.get_user_by_req(request)
+ self.assertEquals(requester.user.to_string(), masquerading_user_id)
def test_get_user_by_req_appservice_valid_token_bad_user_id(self):
masquerading_user_id = "@doppelganger:matrix.org"
diff --git a/tests/appservice/test_appservice.py b/tests/appservice/test_appservice.py
index 191c420c4d..ef48bbc296 100644
--- a/tests/appservice/test_appservice.py
+++ b/tests/appservice/test_appservice.py
@@ -29,6 +29,7 @@ class ApplicationServiceTestCase(unittest.TestCase):
def setUp(self):
self.service = ApplicationService(
+ id="unique_identifier",
url="some_url",
token="some_token",
namespaces={
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
deleted file mode 100644
index 30355ea99a..0000000000
--- a/tests/handlers/test_federation.py
+++ /dev/null
@@ -1,130 +0,0 @@
-# Copyright 2014-2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-from twisted.internet import defer
-from tests import unittest
-
-from synapse.api.constants import EventTypes
-from synapse.events import FrozenEvent
-from synapse.handlers.federation import FederationHandler
-
-from mock import NonCallableMock, ANY, Mock
-
-from ..utils import setup_test_homeserver
-
-
-class FederationTestCase(unittest.TestCase):
-
- @defer.inlineCallbacks
- def setUp(self):
-
- self.state_handler = NonCallableMock(spec_set=[
- "compute_event_context",
- ])
-
- self.auth = NonCallableMock(spec_set=[
- "check",
- "check_host_in_room",
- ])
-
- self.hostname = "test"
- hs = yield setup_test_homeserver(
- self.hostname,
- datastore=NonCallableMock(spec_set=[
- "persist_event",
- "store_room",
- "get_room",
- "get_destination_retry_timings",
- "set_destination_retry_timings",
- "have_events",
- ]),
- resource_for_federation=NonCallableMock(),
- http_client=NonCallableMock(spec_set=[]),
- notifier=NonCallableMock(spec_set=["on_new_room_event"]),
- handlers=NonCallableMock(spec_set=[
- "room_member_handler",
- "federation_handler",
- ]),
- auth=self.auth,
- state_handler=self.state_handler,
- keyring=Mock(),
- )
-
- self.datastore = hs.get_datastore()
- self.handlers = hs.get_handlers()
- self.notifier = hs.get_notifier()
- self.hs = hs
-
- self.handlers.federation_handler = FederationHandler(self.hs)
-
- @defer.inlineCallbacks
- def test_msg(self):
- pdu = FrozenEvent({
- "type": EventTypes.Message,
- "room_id": "foo",
- "content": {"msgtype": u"fooo"},
- "origin_server_ts": 0,
- "event_id": "$a:b",
- "user_id":"@a:b",
- "origin": "b",
- "auth_events": [],
- "hashes": {"sha256":"AcLrgtUIqqwaGoHhrEvYG1YLDIsVPYJdSRGhkp3jJp8"},
- })
-
- self.datastore.persist_event.return_value = defer.succeed((1,1))
- self.datastore.get_room.return_value = defer.succeed(True)
- self.auth.check_host_in_room.return_value = defer.succeed(True)
-
- retry_timings_res = {
- "destination": "",
- "retry_last_ts": 0,
- "retry_interval": 0,
- }
- self.datastore.get_destination_retry_timings.return_value = (
- defer.succeed(retry_timings_res)
- )
-
- def have_events(event_ids):
- return defer.succeed({})
- self.datastore.have_events.side_effect = have_events
-
- def annotate(ev, old_state=None, outlier=False):
- context = Mock()
- context.current_state = {}
- context.auth_events = {}
- return defer.succeed(context)
- self.state_handler.compute_event_context.side_effect = annotate
-
- yield self.handlers.federation_handler.on_receive_pdu(
- "fo", pdu, False
- )
-
- self.datastore.persist_event.assert_called_once_with(
- ANY,
- is_new_state=True,
- backfilled=False,
- current_state=None,
- context=ANY,
- )
-
- self.state_handler.compute_event_context.assert_called_once_with(
- ANY, old_state=None, outlier=False
- )
-
- self.auth.check.assert_called_once_with(ANY, auth_events={})
-
- self.notifier.on_new_room_event.assert_called_once_with(
- ANY, 1, 1, extra_users=[]
- )
diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index 15000aae0c..447a22b5fc 100644
--- a/tests/handlers/test_presence.py
+++ b/tests/handlers/test_presence.py
@@ -28,7 +28,6 @@ from synapse.api.constants import PresenceState
from synapse.api.errors import SynapseError
from synapse.handlers.presence import PresenceHandler, UserPresenceCache
from synapse.streams.config import SourcePaginationConfig
-from synapse.storage.transactions import DestinationsTable
from synapse.types import UserID
OFFLINE = PresenceState.OFFLINE
diff --git a/tests/handlers/test_room.py b/tests/handlers/test_room.py
deleted file mode 100644
index 12c8bed004..0000000000
--- a/tests/handlers/test_room.py
+++ /dev/null
@@ -1,404 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2014-2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-from twisted.internet import defer
-from .. import unittest
-
-from synapse.api.constants import EventTypes, Membership
-from synapse.handlers.room import RoomMemberHandler, RoomCreationHandler
-from synapse.handlers.profile import ProfileHandler
-from synapse.types import UserID
-from ..utils import setup_test_homeserver
-
-from mock import Mock, NonCallableMock
-
-
-class RoomMemberHandlerTestCase(unittest.TestCase):
-
- @defer.inlineCallbacks
- def setUp(self):
- self.hostname = "red"
- hs = yield setup_test_homeserver(
- self.hostname,
- ratelimiter=NonCallableMock(spec_set=[
- "send_message",
- ]),
- datastore=NonCallableMock(spec_set=[
- "persist_event",
- "get_room_member",
- "get_room",
- "store_room",
- "get_latest_events_in_room",
- "add_event_hashes",
- ]),
- resource_for_federation=NonCallableMock(),
- http_client=NonCallableMock(spec_set=[]),
- notifier=NonCallableMock(spec_set=["on_new_room_event"]),
- handlers=NonCallableMock(spec_set=[
- "room_member_handler",
- "profile_handler",
- "federation_handler",
- ]),
- auth=NonCallableMock(spec_set=[
- "check",
- "add_auth_events",
- "check_host_in_room",
- ]),
- state_handler=NonCallableMock(spec_set=[
- "compute_event_context",
- "get_current_state",
- ]),
- )
-
- self.federation = NonCallableMock(spec_set=[
- "handle_new_event",
- "send_invite",
- "get_state_for_room",
- ])
-
- self.datastore = hs.get_datastore()
- self.handlers = hs.get_handlers()
- self.notifier = hs.get_notifier()
- self.state_handler = hs.get_state_handler()
- self.distributor = hs.get_distributor()
- self.auth = hs.get_auth()
- self.hs = hs
-
- self.handlers.federation_handler = self.federation
-
- self.distributor.declare("collect_presencelike_data")
-
- self.handlers.room_member_handler = RoomMemberHandler(self.hs)
- self.handlers.profile_handler = ProfileHandler(self.hs)
- self.room_member_handler = self.handlers.room_member_handler
-
- self.ratelimiter = hs.get_ratelimiter()
- self.ratelimiter.send_message.return_value = (True, 0)
-
- self.datastore.persist_event.return_value = (1,1)
- self.datastore.add_event_hashes.return_value = []
-
- @defer.inlineCallbacks
- def test_invite(self):
- room_id = "!foo:red"
- user_id = "@bob:red"
- target_user_id = "@red:blue"
- content = {"membership": Membership.INVITE}
-
- builder = self.hs.get_event_builder_factory().new({
- "type": EventTypes.Member,
- "sender": user_id,
- "state_key": target_user_id,
- "room_id": room_id,
- "content": content,
- })
-
- self.datastore.get_latest_events_in_room.return_value = (
- defer.succeed([])
- )
-
- def annotate(_):
- ctx = Mock()
- ctx.current_state = {
- (EventTypes.Member, "@alice:green"): self._create_member(
- user_id="@alice:green",
- room_id=room_id,
- ),
- (EventTypes.Member, "@bob:red"): self._create_member(
- user_id="@bob:red",
- room_id=room_id,
- ),
- }
- ctx.prev_state_events = []
-
- return defer.succeed(ctx)
-
- self.state_handler.compute_event_context.side_effect = annotate
-
- def add_auth(_, ctx):
- ctx.auth_events = ctx.current_state[
- (EventTypes.Member, "@bob:red")
- ]
-
- return defer.succeed(True)
- self.auth.add_auth_events.side_effect = add_auth
-
- def send_invite(domain, event):
- return defer.succeed(event)
-
- self.federation.send_invite.side_effect = send_invite
-
- room_handler = self.room_member_handler
- event, context = yield room_handler._create_new_client_event(
- builder
- )
-
- yield room_handler.change_membership(event, context)
-
- self.state_handler.compute_event_context.assert_called_once_with(
- builder
- )
-
- self.auth.add_auth_events.assert_called_once_with(
- builder, context
- )
-
- self.federation.send_invite.assert_called_once_with(
- "blue", event,
- )
-
- self.datastore.persist_event.assert_called_once_with(
- event, context=context,
- )
- self.notifier.on_new_room_event.assert_called_once_with(
- event, 1, 1, extra_users=[UserID.from_string(target_user_id)]
- )
- self.assertFalse(self.datastore.get_room.called)
- self.assertFalse(self.datastore.store_room.called)
- self.assertFalse(self.federation.get_state_for_room.called)
-
- @defer.inlineCallbacks
- def test_simple_join(self):
- room_id = "!foo:red"
- user_id = "@bob:red"
- user = UserID.from_string(user_id)
-
- join_signal_observer = Mock()
- self.distributor.observe("user_joined_room", join_signal_observer)
-
- builder = self.hs.get_event_builder_factory().new({
- "type": EventTypes.Member,
- "sender": user_id,
- "state_key": user_id,
- "room_id": room_id,
- "content": {"membership": Membership.JOIN},
- })
-
- self.datastore.get_latest_events_in_room.return_value = (
- defer.succeed([])
- )
-
- def annotate(_):
- ctx = Mock()
- ctx.current_state = {
- (EventTypes.Member, "@bob:red"): self._create_member(
- user_id="@bob:red",
- room_id=room_id,
- membership=Membership.INVITE
- ),
- }
- ctx.prev_state_events = []
-
- return defer.succeed(ctx)
-
- self.state_handler.compute_event_context.side_effect = annotate
-
- def add_auth(_, ctx):
- ctx.auth_events = ctx.current_state[
- (EventTypes.Member, "@bob:red")
- ]
-
- return defer.succeed(True)
- self.auth.add_auth_events.side_effect = add_auth
-
- room_handler = self.room_member_handler
- event, context = yield room_handler._create_new_client_event(
- builder
- )
-
- # Actual invocation
- yield room_handler.change_membership(event, context)
-
- self.federation.handle_new_event.assert_called_once_with(
- event, destinations=set()
- )
-
- self.datastore.persist_event.assert_called_once_with(
- event, context=context
- )
- self.notifier.on_new_room_event.assert_called_once_with(
- event, 1, 1, extra_users=[user]
- )
-
- join_signal_observer.assert_called_with(
- user=user, room_id=room_id
- )
-
- def _create_member(self, user_id, room_id, membership=Membership.JOIN):
- builder = self.hs.get_event_builder_factory().new({
- "type": EventTypes.Member,
- "sender": user_id,
- "state_key": user_id,
- "room_id": room_id,
- "content": {"membership": membership},
- })
-
- return builder.build()
-
- @defer.inlineCallbacks
- def test_simple_leave(self):
- room_id = "!foo:red"
- user_id = "@bob:red"
- user = UserID.from_string(user_id)
-
- builder = self.hs.get_event_builder_factory().new({
- "type": EventTypes.Member,
- "sender": user_id,
- "state_key": user_id,
- "room_id": room_id,
- "content": {"membership": Membership.LEAVE},
- })
-
- self.datastore.get_latest_events_in_room.return_value = (
- defer.succeed([])
- )
-
- def annotate(_):
- ctx = Mock()
- ctx.current_state = {
- (EventTypes.Member, "@bob:red"): self._create_member(
- user_id="@bob:red",
- room_id=room_id,
- membership=Membership.JOIN
- ),
- }
- ctx.prev_state_events = []
-
- return defer.succeed(ctx)
-
- self.state_handler.compute_event_context.side_effect = annotate
-
- def add_auth(_, ctx):
- ctx.auth_events = ctx.current_state[
- (EventTypes.Member, "@bob:red")
- ]
-
- return defer.succeed(True)
- self.auth.add_auth_events.side_effect = add_auth
-
- room_handler = self.room_member_handler
- event, context = yield room_handler._create_new_client_event(
- builder
- )
-
- leave_signal_observer = Mock()
- self.distributor.observe("user_left_room", leave_signal_observer)
-
- # Actual invocation
- yield room_handler.change_membership(event, context)
-
- self.federation.handle_new_event.assert_called_once_with(
- event, destinations=set(['red'])
- )
-
- self.datastore.persist_event.assert_called_once_with(
- event, context=context
- )
- self.notifier.on_new_room_event.assert_called_once_with(
- event, 1, 1, extra_users=[user]
- )
-
- leave_signal_observer.assert_called_with(
- user=user, room_id=room_id
- )
-
-
-class RoomCreationTest(unittest.TestCase):
-
- @defer.inlineCallbacks
- def setUp(self):
- self.hostname = "red"
-
- hs = yield setup_test_homeserver(
- self.hostname,
- datastore=NonCallableMock(spec_set=[
- "store_room",
- "snapshot_room",
- "persist_event",
- "get_joined_hosts_for_room",
- ]),
- http_client=NonCallableMock(spec_set=[]),
- notifier=NonCallableMock(spec_set=["on_new_room_event"]),
- handlers=NonCallableMock(spec_set=[
- "room_creation_handler",
- "message_handler",
- ]),
- auth=NonCallableMock(spec_set=["check", "add_auth_events"]),
- ratelimiter=NonCallableMock(spec_set=[
- "send_message",
- ]),
- )
-
- self.federation = NonCallableMock(spec_set=[
- "handle_new_event",
- ])
-
- self.handlers = hs.get_handlers()
-
- self.handlers.room_creation_handler = RoomCreationHandler(hs)
- self.room_creation_handler = self.handlers.room_creation_handler
-
- self.message_handler = self.handlers.message_handler
-
- self.ratelimiter = hs.get_ratelimiter()
- self.ratelimiter.send_message.return_value = (True, 0)
-
- @defer.inlineCallbacks
- def test_room_creation(self):
- user_id = "@foo:red"
- room_id = "!bobs_room:red"
- config = {"visibility": "private"}
-
- yield self.room_creation_handler.create_room(
- user_id=user_id,
- room_id=room_id,
- config=config,
- )
-
- self.assertTrue(self.message_handler.create_and_send_event.called)
-
- event_dicts = [
- e[0][0]
- for e in self.message_handler.create_and_send_event.call_args_list
- ]
-
- self.assertTrue(len(event_dicts) > 3)
-
- self.assertDictContainsSubset(
- {
- "type": EventTypes.Create,
- "sender": user_id,
- "room_id": room_id,
- },
- event_dicts[0]
- )
-
- self.assertEqual(user_id, event_dicts[0]["content"]["creator"])
-
- self.assertDictContainsSubset(
- {
- "type": EventTypes.Member,
- "sender": user_id,
- "room_id": room_id,
- "state_key": user_id,
- },
- event_dicts[1]
- )
-
- self.assertEqual(
- Membership.JOIN,
- event_dicts[1]["content"]["membership"]
- )
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 124bc10e0f..763c04d667 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -27,7 +27,6 @@ from ..utils import (
from synapse.api.errors import AuthError
from synapse.handlers.typing import TypingNotificationHandler
-from synapse.storage.transactions import DestinationsTable
from synapse.types import UserID
diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py
index d782eadb6a..90b911f879 100644
--- a/tests/rest/client/v1/test_presence.py
+++ b/tests/rest/client/v1/test_presence.py
@@ -14,7 +14,6 @@
# limitations under the License.
"""Tests REST events for /presence paths."""
-
from tests import unittest
from twisted.internet import defer
@@ -26,7 +25,7 @@ from synapse.api.constants import PresenceState
from synapse.handlers.presence import PresenceHandler
from synapse.rest.client.v1 import presence
from synapse.rest.client.v1 import events
-from synapse.types import UserID
+from synapse.types import Requester, UserID
from synapse.util.async import run_on_reactor
from collections import namedtuple
@@ -301,7 +300,7 @@ class PresenceEventStreamTestCase(unittest.TestCase):
hs.get_clock().time_msec.return_value = 1000000
def _get_user_by_req(req=None, allow_guest=False):
- return (UserID.from_string(myid), "", False)
+ return Requester(UserID.from_string(myid), "", False)
hs.get_v1auth().get_user_by_req = _get_user_by_req
diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py
index 77b7b06c10..c1a3f52043 100644
--- a/tests/rest/client/v1/test_profile.py
+++ b/tests/rest/client/v1/test_profile.py
@@ -14,16 +14,15 @@
# limitations under the License.
"""Tests REST events for /profile paths."""
-
from tests import unittest
from twisted.internet import defer
-from mock import Mock, NonCallableMock
+from mock import Mock
from ....utils import MockHttpResource, setup_test_homeserver
from synapse.api.errors import SynapseError, AuthError
-from synapse.types import UserID
+from synapse.types import Requester, UserID
from synapse.rest.client.v1 import profile
@@ -53,7 +52,7 @@ class ProfileTestCase(unittest.TestCase):
)
def _get_user_by_req(request=None, allow_guest=False):
- return (UserID.from_string(myid), "", False)
+ return Requester(UserID.from_string(myid), "", False)
hs.get_v1auth().get_user_by_req = _get_user_by_req
diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index a5a464640f..5abecdf6e0 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -12,12 +12,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import tempfile
+from synapse.config._base import ConfigError
from tests import unittest
from twisted.internet import defer
from tests.utils import setup_test_homeserver
from synapse.appservice import ApplicationService, ApplicationServiceState
-from synapse.server import HomeServer
from synapse.storage.appservice import (
ApplicationServiceStore, ApplicationServiceTransactionStore
)
@@ -26,7 +27,6 @@ import json
import os
import yaml
from mock import Mock
-from tests.utils import SQLiteMemoryDbPool, MockClock
class ApplicationServiceStoreTestCase(unittest.TestCase):
@@ -41,9 +41,16 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
self.as_token = "token1"
self.as_url = "some_url"
- self._add_appservice(self.as_token, self.as_url, "some_hs_token", "bob")
- self._add_appservice("token2", "some_url", "some_hs_token", "bob")
- self._add_appservice("token3", "some_url", "some_hs_token", "bob")
+ self.as_id = "as1"
+ self._add_appservice(
+ self.as_token,
+ self.as_id,
+ self.as_url,
+ "some_hs_token",
+ "bob"
+ )
+ self._add_appservice("token2", "as2", "some_url", "some_hs_token", "bob")
+ self._add_appservice("token3", "as3", "some_url", "some_hs_token", "bob")
# must be done after inserts
self.store = ApplicationServiceStore(hs)
@@ -55,9 +62,9 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
except:
pass
- def _add_appservice(self, as_token, url, hs_token, sender):
+ def _add_appservice(self, as_token, id, url, hs_token, sender):
as_yaml = dict(url=url, as_token=as_token, hs_token=hs_token,
- sender_localpart=sender, namespaces={})
+ id=id, sender_localpart=sender, namespaces={})
# use the token as the filename
with open(as_token, 'w') as outfile:
outfile.write(yaml.dump(as_yaml))
@@ -74,6 +81,7 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
self.as_token
)
self.assertEquals(stored_service.token, self.as_token)
+ self.assertEquals(stored_service.id, self.as_id)
self.assertEquals(stored_service.url, self.as_url)
self.assertEquals(
stored_service.namespaces[ApplicationService.NS_ALIASES],
@@ -110,34 +118,34 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
{
"token": "token1",
"url": "https://matrix-as.org",
- "id": "token1"
+ "id": "id_1"
},
{
"token": "alpha_tok",
"url": "https://alpha.com",
- "id": "alpha_tok"
+ "id": "id_alpha"
},
{
"token": "beta_tok",
"url": "https://beta.com",
- "id": "beta_tok"
+ "id": "id_beta"
},
{
- "token": "delta_tok",
- "url": "https://delta.com",
- "id": "delta_tok"
+ "token": "gamma_tok",
+ "url": "https://gamma.com",
+ "id": "id_gamma"
},
]
for s in self.as_list:
- yield self._add_service(s["url"], s["token"])
+ yield self._add_service(s["url"], s["token"], s["id"])
self.as_yaml_files = []
self.store = TestTransactionStore(hs)
- def _add_service(self, url, as_token):
+ def _add_service(self, url, as_token, id):
as_yaml = dict(url=url, as_token=as_token, hs_token="something",
- sender_localpart="a_sender", namespaces={})
+ id=id, sender_localpart="a_sender", namespaces={})
# use the token as the filename
with open(as_token, 'w') as outfile:
outfile.write(yaml.dump(as_yaml))
@@ -405,3 +413,64 @@ class TestTransactionStore(ApplicationServiceTransactionStore,
def __init__(self, hs):
super(TestTransactionStore, self).__init__(hs)
+
+
+class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
+
+ def _write_config(self, suffix, **kwargs):
+ vals = {
+ "id": "id" + suffix,
+ "url": "url" + suffix,
+ "as_token": "as_token" + suffix,
+ "hs_token": "hs_token" + suffix,
+ "sender_localpart": "sender_localpart" + suffix,
+ "namespaces": {},
+ }
+ vals.update(kwargs)
+
+ _, path = tempfile.mkstemp(prefix="as_config")
+ with open(path, "w") as f:
+ f.write(yaml.dump(vals))
+ return path
+
+ @defer.inlineCallbacks
+ def test_unique_works(self):
+ f1 = self._write_config(suffix="1")
+ f2 = self._write_config(suffix="2")
+
+ config = Mock(app_service_config_files=[f1, f2])
+ hs = yield setup_test_homeserver(config=config)
+
+ ApplicationServiceStore(hs)
+
+ @defer.inlineCallbacks
+ def test_duplicate_ids(self):
+ f1 = self._write_config(id="id", suffix="1")
+ f2 = self._write_config(id="id", suffix="2")
+
+ config = Mock(app_service_config_files=[f1, f2])
+ hs = yield setup_test_homeserver(config=config)
+
+ with self.assertRaises(ConfigError) as cm:
+ ApplicationServiceStore(hs)
+
+ e = cm.exception
+ self.assertIn(f1, e.message)
+ self.assertIn(f2, e.message)
+ self.assertIn("id", e.message)
+
+ @defer.inlineCallbacks
+ def test_duplicate_as_tokens(self):
+ f1 = self._write_config(as_token="as_token", suffix="1")
+ f2 = self._write_config(as_token="as_token", suffix="2")
+
+ config = Mock(app_service_config_files=[f1, f2])
+ hs = yield setup_test_homeserver(config=config)
+
+ with self.assertRaises(ConfigError) as cm:
+ ApplicationServiceStore(hs)
+
+ e = cm.exception
+ self.assertIn(f1, e.message)
+ self.assertIn(f2, e.message)
+ self.assertIn("as_token", e.message)
diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
index c6d3ea7325..a35efcc71e 100644
--- a/tests/storage/test_registration.py
+++ b/tests/storage/test_registration.py
@@ -45,7 +45,7 @@ class RegistrationStoreTestCase(unittest.TestCase):
self.assertEquals(
# TODO(paul): Surely this field should be 'user_id', not 'name'
# Additionally surely it shouldn't come in a 1-element list
- {"name": self.user_id, "password_hash": self.pwhash},
+ {"name": self.user_id, "password_hash": self.pwhash, "is_guest": 0},
(yield self.store.get_user_by_id(self.user_id))
)
|