diff --git a/.travis.yml b/.travis.yml
index a98d547978..b34b17af75 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -24,6 +24,9 @@ matrix:
env: TOX_ENV=py36
- python: 3.6
+ env: TOX_ENV=check_isort
+
+ - python: 3.6
env: TOX_ENV=check-newsfragment
install:
diff --git a/changelog.d/3316.feature b/changelog.d/3316.feature
new file mode 100644
index 0000000000..50068b7222
--- /dev/null
+++ b/changelog.d/3316.feature
@@ -0,0 +1 @@
+Enforce the specified API for report_event
diff --git a/changelog.d/3351.misc b/changelog.d/3351.misc
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/changelog.d/3351.misc
diff --git a/changelog.d/3499.misc b/changelog.d/3499.misc
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/changelog.d/3499.misc
diff --git a/changelog.d/3505.feature b/changelog.d/3505.feature
new file mode 100644
index 0000000000..ca1867f529
--- /dev/null
+++ b/changelog.d/3505.feature
@@ -0,0 +1 @@
+Reduce database consumption when processing large numbers of receipts
diff --git a/changelog.d/3521.feature b/changelog.d/3521.feature
new file mode 100644
index 0000000000..6dced5f2ae
--- /dev/null
+++ b/changelog.d/3521.feature
@@ -0,0 +1 @@
+Cache optimisation for /sync requests
\ No newline at end of file
diff --git a/changelog.d/3530.misc b/changelog.d/3530.misc
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/changelog.d/3530.misc
diff --git a/changelog.d/3533.bugfix b/changelog.d/3533.bugfix
new file mode 100644
index 0000000000..04cbbefd5f
--- /dev/null
+++ b/changelog.d/3533.bugfix
@@ -0,0 +1 @@
+Fix queued federation requests being processed in the wrong order
diff --git a/changelog.d/3534.misc b/changelog.d/3534.misc
new file mode 100644
index 0000000000..949c12dc69
--- /dev/null
+++ b/changelog.d/3534.misc
@@ -0,0 +1 @@
+refactor: use parse_{string,integer} and assert's from http.servlet for deduplication
diff --git a/changelog.d/3535.misc b/changelog.d/3535.misc
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/changelog.d/3535.misc
diff --git a/changelog.d/3540.misc b/changelog.d/3540.misc
new file mode 100644
index 0000000000..99dcad8e46
--- /dev/null
+++ b/changelog.d/3540.misc
@@ -0,0 +1 @@
+check isort for each PR
diff --git a/changelog.d/3541.feature b/changelog.d/3541.feature
new file mode 100644
index 0000000000..24524136ea
--- /dev/null
+++ b/changelog.d/3541.feature
@@ -0,0 +1 @@
+Optimisation to make handling incoming federation requests more efficient.
\ No newline at end of file
diff --git a/changelog.d/3544.misc b/changelog.d/3544.misc
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/changelog.d/3544.misc
diff --git a/changelog.d/3546.bugfix b/changelog.d/3546.bugfix
new file mode 100644
index 0000000000..921dc6e7b0
--- /dev/null
+++ b/changelog.d/3546.bugfix
@@ -0,0 +1 @@
+Ensure that erasure requests are correctly honoured for publicly accessible rooms when accessed over federation.
\ No newline at end of file
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 6dec862fec..bc629832d9 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -193,7 +193,7 @@ class Auth(object):
synapse.types.create_requester(user_id, app_service=app_service)
)
- access_token = get_access_token_from_request(
+ access_token = self.get_access_token_from_request(
request, self.TOKEN_NOT_FOUND_HTTP_STATUS
)
@@ -239,7 +239,7 @@ class Auth(object):
@defer.inlineCallbacks
def _get_appservice_user_id(self, request):
app_service = self.store.get_app_service_by_token(
- get_access_token_from_request(
+ self.get_access_token_from_request(
request, self.TOKEN_NOT_FOUND_HTTP_STATUS
)
)
@@ -513,7 +513,7 @@ class Auth(object):
def get_appservice_by_req(self, request):
try:
- token = get_access_token_from_request(
+ token = self.get_access_token_from_request(
request, self.TOKEN_NOT_FOUND_HTTP_STATUS
)
service = self.store.get_app_service_by_token(token)
@@ -673,67 +673,67 @@ class Auth(object):
" edit its room list entry"
)
+ @staticmethod
+ def has_access_token(request):
+ """Checks if the request has an access_token.
-def has_access_token(request):
- """Checks if the request has an access_token.
+ Returns:
+ bool: False if no access_token was given, True otherwise.
+ """
+ query_params = request.args.get("access_token")
+ auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
+ return bool(query_params) or bool(auth_headers)
- Returns:
- bool: False if no access_token was given, True otherwise.
- """
- query_params = request.args.get("access_token")
- auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
- return bool(query_params) or bool(auth_headers)
-
-
-def get_access_token_from_request(request, token_not_found_http_status=401):
- """Extracts the access_token from the request.
-
- Args:
- request: The http request.
- token_not_found_http_status(int): The HTTP status code to set in the
- AuthError if the token isn't found. This is used in some of the
- legacy APIs to change the status code to 403 from the default of
- 401 since some of the old clients depended on auth errors returning
- 403.
- Returns:
- str: The access_token
- Raises:
- AuthError: If there isn't an access_token in the request.
- """
+ @staticmethod
+ def get_access_token_from_request(request, token_not_found_http_status=401):
+ """Extracts the access_token from the request.
- auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
- query_params = request.args.get(b"access_token")
- if auth_headers:
- # Try the get the access_token from a "Authorization: Bearer"
- # header
- if query_params is not None:
- raise AuthError(
- token_not_found_http_status,
- "Mixing Authorization headers and access_token query parameters.",
- errcode=Codes.MISSING_TOKEN,
- )
- if len(auth_headers) > 1:
- raise AuthError(
- token_not_found_http_status,
- "Too many Authorization headers.",
- errcode=Codes.MISSING_TOKEN,
- )
- parts = auth_headers[0].split(" ")
- if parts[0] == "Bearer" and len(parts) == 2:
- return parts[1]
+ Args:
+ request: The http request.
+ token_not_found_http_status(int): The HTTP status code to set in the
+ AuthError if the token isn't found. This is used in some of the
+ legacy APIs to change the status code to 403 from the default of
+ 401 since some of the old clients depended on auth errors returning
+ 403.
+ Returns:
+ str: The access_token
+ Raises:
+ AuthError: If there isn't an access_token in the request.
+ """
+
+ auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
+ query_params = request.args.get(b"access_token")
+ if auth_headers:
+ # Try the get the access_token from a "Authorization: Bearer"
+ # header
+ if query_params is not None:
+ raise AuthError(
+ token_not_found_http_status,
+ "Mixing Authorization headers and access_token query parameters.",
+ errcode=Codes.MISSING_TOKEN,
+ )
+ if len(auth_headers) > 1:
+ raise AuthError(
+ token_not_found_http_status,
+ "Too many Authorization headers.",
+ errcode=Codes.MISSING_TOKEN,
+ )
+ parts = auth_headers[0].split(" ")
+ if parts[0] == "Bearer" and len(parts) == 2:
+ return parts[1]
+ else:
+ raise AuthError(
+ token_not_found_http_status,
+ "Invalid Authorization header.",
+ errcode=Codes.MISSING_TOKEN,
+ )
else:
- raise AuthError(
- token_not_found_http_status,
- "Invalid Authorization header.",
- errcode=Codes.MISSING_TOKEN,
- )
- else:
- # Try to get the access_token from the query params.
- if not query_params:
- raise AuthError(
- token_not_found_http_status,
- "Missing access token.",
- errcode=Codes.MISSING_TOKEN
- )
+ # Try to get the access_token from the query params.
+ if not query_params:
+ raise AuthError(
+ token_not_found_http_status,
+ "Missing access token.",
+ errcode=Codes.MISSING_TOKEN
+ )
- return query_params[0]
+ return query_params[0]
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index f0c7a06718..c11798093d 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -23,7 +23,7 @@ from synapse.api.errors import Codes, SynapseError
from synapse.crypto.event_signing import check_event_content_hash
from synapse.events import FrozenEvent
from synapse.events.utils import prune_event
-from synapse.http.servlet import assert_params_in_request
+from synapse.http.servlet import assert_params_in_dict
from synapse.util import logcontext, unwrapFirstError
logger = logging.getLogger(__name__)
@@ -199,7 +199,7 @@ def event_from_pdu_json(pdu_json, outlier=False):
"""
# we could probably enforce a bunch of other fields here (room_id, sender,
# origin, etc etc)
- assert_params_in_request(pdu_json, ('event_id', 'type', 'depth'))
+ assert_params_in_dict(pdu_json, ('event_id', 'type', 'depth'))
depth = pdu_json['depth']
if not isinstance(depth, six.integer_types):
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index d3ecebd29f..20fb46fc89 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -43,7 +43,6 @@ from synapse.crypto.event_signing import (
add_hashes_and_signatures,
compute_event_signature,
)
-from synapse.events.utils import prune_event
from synapse.events.validator import EventValidator
from synapse.state import resolve_events_with_factory
from synapse.types import UserID, get_domain_from_id
@@ -52,8 +51,8 @@ from synapse.util.async import Linearizer
from synapse.util.distributor import user_joined_room
from synapse.util.frozenutils import unfreeze
from synapse.util.logutils import log_function
-from synapse.util.metrics import measure_func
from synapse.util.retryutils import NotRetryingDestination
+from synapse.visibility import filter_events_for_server
from ._base import BaseHandler
@@ -501,137 +500,6 @@ class FederationHandler(BaseHandler):
user = UserID.from_string(event.state_key)
yield user_joined_room(self.distributor, user, event.room_id)
- @measure_func("_filter_events_for_server")
- @defer.inlineCallbacks
- def _filter_events_for_server(self, server_name, room_id, events):
- """Filter the given events for the given server, redacting those the
- server can't see.
-
- Assumes the server is currently in the room.
-
- Returns
- list[FrozenEvent]
- """
- # First lets check to see if all the events have a history visibility
- # of "shared" or "world_readable". If thats the case then we don't
- # need to check membership (as we know the server is in the room).
- event_to_state_ids = yield self.store.get_state_ids_for_events(
- frozenset(e.event_id for e in events),
- types=(
- (EventTypes.RoomHistoryVisibility, ""),
- )
- )
-
- visibility_ids = set()
- for sids in event_to_state_ids.itervalues():
- hist = sids.get((EventTypes.RoomHistoryVisibility, ""))
- if hist:
- visibility_ids.add(hist)
-
- # If we failed to find any history visibility events then the default
- # is "shared" visiblity.
- if not visibility_ids:
- defer.returnValue(events)
-
- event_map = yield self.store.get_events(visibility_ids)
- all_open = all(
- e.content.get("history_visibility") in (None, "shared", "world_readable")
- for e in event_map.itervalues()
- )
-
- if all_open:
- defer.returnValue(events)
-
- # Ok, so we're dealing with events that have non-trivial visibility
- # rules, so we need to also get the memberships of the room.
-
- event_to_state_ids = yield self.store.get_state_ids_for_events(
- frozenset(e.event_id for e in events),
- types=(
- (EventTypes.RoomHistoryVisibility, ""),
- (EventTypes.Member, None),
- )
- )
-
- # We only want to pull out member events that correspond to the
- # server's domain.
-
- def check_match(id):
- try:
- return server_name == get_domain_from_id(id)
- except Exception:
- return False
-
- # Parses mapping `event_id -> (type, state_key) -> state event_id`
- # to get all state ids that we're interested in.
- event_map = yield self.store.get_events([
- e_id
- for key_to_eid in list(event_to_state_ids.values())
- for key, e_id in key_to_eid.items()
- if key[0] != EventTypes.Member or check_match(key[1])
- ])
-
- event_to_state = {
- e_id: {
- key: event_map[inner_e_id]
- for key, inner_e_id in key_to_eid.iteritems()
- if inner_e_id in event_map
- }
- for e_id, key_to_eid in event_to_state_ids.iteritems()
- }
-
- erased_senders = yield self.store.are_users_erased(
- e.sender for e in events,
- )
-
- def redact_disallowed(event, state):
- # if the sender has been gdpr17ed, always return a redacted
- # copy of the event.
- if erased_senders[event.sender]:
- logger.info(
- "Sender of %s has been erased, redacting",
- event.event_id,
- )
- return prune_event(event)
-
- if not state:
- return event
-
- history = state.get((EventTypes.RoomHistoryVisibility, ''), None)
- if history:
- visibility = history.content.get("history_visibility", "shared")
- if visibility in ["invited", "joined"]:
- # We now loop through all state events looking for
- # membership states for the requesting server to determine
- # if the server is either in the room or has been invited
- # into the room.
- for ev in state.itervalues():
- if ev.type != EventTypes.Member:
- continue
- try:
- domain = get_domain_from_id(ev.state_key)
- except Exception:
- continue
-
- if domain != server_name:
- continue
-
- memtype = ev.membership
- if memtype == Membership.JOIN:
- return event
- elif memtype == Membership.INVITE:
- if visibility == "invited":
- return event
- else:
- return prune_event(event)
-
- return event
-
- defer.returnValue([
- redact_disallowed(e, event_to_state[e.event_id])
- for e in events
- ])
-
@log_function
@defer.inlineCallbacks
def backfill(self, dest, room_id, limit, extremities):
@@ -1558,7 +1426,7 @@ class FederationHandler(BaseHandler):
limit
)
- events = yield self._filter_events_for_server(origin, room_id, events)
+ events = yield filter_events_for_server(self.store, origin, events)
defer.returnValue(events)
@@ -1605,8 +1473,8 @@ class FederationHandler(BaseHandler):
if not in_room:
raise AuthError(403, "Host not in room.")
- events = yield self._filter_events_for_server(
- origin, event.room_id, [event]
+ events = yield filter_events_for_server(
+ self.store, origin, [event],
)
event = events[0]
defer.returnValue(event)
@@ -1896,8 +1764,8 @@ class FederationHandler(BaseHandler):
min_depth=min_depth,
)
- missing_events = yield self._filter_events_for_server(
- origin, room_id, missing_events,
+ missing_events = yield filter_events_for_server(
+ self.store, origin, missing_events,
)
defer.returnValue(missing_events)
diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
index 1abd45297b..828229f5c3 100644
--- a/synapse/handlers/room_list.py
+++ b/synapse/handlers/room_list.py
@@ -38,7 +38,7 @@ REMOTE_ROOM_LIST_POLL_INTERVAL = 60 * 1000
# This is used to indicate we should only return rooms published to the main list.
-EMTPY_THIRD_PARTY_ID = ThirdPartyInstanceID(None, None)
+EMPTY_THIRD_PARTY_ID = ThirdPartyInstanceID(None, None)
class RoomListHandler(BaseHandler):
@@ -50,7 +50,7 @@ class RoomListHandler(BaseHandler):
def get_local_public_room_list(self, limit=None, since_token=None,
search_filter=None,
- network_tuple=EMTPY_THIRD_PARTY_ID,):
+ network_tuple=EMPTY_THIRD_PARTY_ID,):
"""Generate a local public room list.
There are multiple different lists: the main one plus one per third
@@ -87,7 +87,7 @@ class RoomListHandler(BaseHandler):
@defer.inlineCallbacks
def _get_public_room_list(self, limit=None, since_token=None,
search_filter=None,
- network_tuple=EMTPY_THIRD_PARTY_ID,):
+ network_tuple=EMPTY_THIRD_PARTY_ID,):
if since_token and since_token != "END":
since_token = RoomListNextBatch.from_token(since_token)
else:
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index cf6723563a..882816dc8f 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -206,7 +206,7 @@ def parse_json_object_from_request(request, allow_empty_body=False):
return content
-def assert_params_in_request(body, required):
+def assert_params_in_dict(body, required):
absent = []
for k in required:
if k not in body:
diff --git a/synapse/http/site.py b/synapse/http/site.py
index 21e26f9c5e..5fd30a4c2c 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -20,7 +20,7 @@ from twisted.web.server import Request, Site
from synapse.http import redact_uri
from synapse.http.request_metrics import RequestMetrics
-from synapse.util.logcontext import LoggingContext, ContextResourceUsage
+from synapse.util.logcontext import ContextResourceUsage, LoggingContext
logger = logging.getLogger(__name__)
@@ -42,9 +42,10 @@ class SynapseRequest(Request):
which is handling the request, and returns a context manager.
"""
- def __init__(self, site, *args, **kw):
- Request.__init__(self, *args, **kw)
+ def __init__(self, site, channel, *args, **kw):
+ Request.__init__(self, channel, *args, **kw)
self.site = site
+ self._channel = channel
self.authenticated_entity = None
self.start_time = 0
diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py
index 7ab12b850f..ed12342f40 100644
--- a/synapse/replication/slave/storage/receipts.py
+++ b/synapse/replication/slave/storage/receipts.py
@@ -49,7 +49,7 @@ class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id):
self.get_receipts_for_user.invalidate((user_id, receipt_type))
- self.get_linearized_receipts_for_room.invalidate_many((room_id,))
+ self._get_linearized_receipts_for_room.invalidate_many((room_id,))
self.get_last_receipt_event_id_for_user.invalidate(
(user_id, room_id, receipt_type)
)
diff --git a/synapse/rest/client/transactions.py b/synapse/rest/client/transactions.py
index 7c01b438cb..00b1b3066e 100644
--- a/synapse/rest/client/transactions.py
+++ b/synapse/rest/client/transactions.py
@@ -17,38 +17,20 @@
to ensure idempotency when performing PUTs using the REST API."""
import logging
-from synapse.api.auth import get_access_token_from_request
from synapse.util.async import ObservableDeferred
from synapse.util.logcontext import make_deferred_yieldable, run_in_background
logger = logging.getLogger(__name__)
-
-def get_transaction_key(request):
- """A helper function which returns a transaction key that can be used
- with TransactionCache for idempotent requests.
-
- Idempotency is based on the returned key being the same for separate
- requests to the same endpoint. The key is formed from the HTTP request
- path and the access_token for the requesting user.
-
- Args:
- request (twisted.web.http.Request): The incoming request. Must
- contain an access_token.
- Returns:
- str: A transaction key
- """
- token = get_access_token_from_request(request)
- return request.path + "/" + token
-
-
CLEANUP_PERIOD_MS = 1000 * 60 * 30 # 30 mins
class HttpTransactionCache(object):
- def __init__(self, clock):
- self.clock = clock
+ def __init__(self, hs):
+ self.hs = hs
+ self.auth = self.hs.get_auth()
+ self.clock = self.hs.get_clock()
self.transactions = {
# $txn_key: (ObservableDeferred<(res_code, res_json_body)>, timestamp)
}
@@ -56,6 +38,23 @@ class HttpTransactionCache(object):
# for at *LEAST* 30 mins, and at *MOST* 60 mins.
self.cleaner = self.clock.looping_call(self._cleanup, CLEANUP_PERIOD_MS)
+ def _get_transaction_key(self, request):
+ """A helper function which returns a transaction key that can be used
+ with TransactionCache for idempotent requests.
+
+ Idempotency is based on the returned key being the same for separate
+ requests to the same endpoint. The key is formed from the HTTP request
+ path and the access_token for the requesting user.
+
+ Args:
+ request (twisted.web.http.Request): The incoming request. Must
+ contain an access_token.
+ Returns:
+ str: A transaction key
+ """
+ token = self.auth.get_access_token_from_request(request)
+ return request.path + "/" + token
+
def fetch_or_execute_request(self, request, fn, *args, **kwargs):
"""A helper function for fetch_or_execute which extracts
a transaction key from the given request.
@@ -64,7 +63,7 @@ class HttpTransactionCache(object):
fetch_or_execute
"""
return self.fetch_or_execute(
- get_transaction_key(request), fn, *args, **kwargs
+ self._get_transaction_key(request), fn, *args, **kwargs
)
def fetch_or_execute(self, txn_key, fn, *args, **kwargs):
diff --git a/synapse/rest/client/v1/admin.py b/synapse/rest/client/v1/admin.py
index 2b091d61a5..2dc50e582b 100644
--- a/synapse/rest/client/v1/admin.py
+++ b/synapse/rest/client/v1/admin.py
@@ -22,7 +22,12 @@ from twisted.internet import defer
from synapse.api.constants import Membership
from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
-from synapse.http.servlet import parse_json_object_from_request
+from synapse.http.servlet import (
+ assert_params_in_dict,
+ parse_integer,
+ parse_json_object_from_request,
+ parse_string,
+)
from synapse.types import UserID, create_requester
from .base import ClientV1RestServlet, client_path_patterns
@@ -98,16 +103,8 @@ class PurgeMediaCacheRestServlet(ClientV1RestServlet):
if not is_admin:
raise AuthError(403, "You are not a server admin")
- before_ts = request.args.get("before_ts", None)
- if not before_ts:
- raise SynapseError(400, "Missing 'before_ts' arg")
-
- logger.info("before_ts: %r", before_ts[0])
-
- try:
- before_ts = int(before_ts[0])
- except Exception:
- raise SynapseError(400, "Invalid 'before_ts' arg")
+ before_ts = parse_integer(request, "before_ts", required=True)
+ logger.info("before_ts: %r", before_ts)
ret = yield self.media_repository.delete_old_remote_media(before_ts)
@@ -300,10 +297,8 @@ class ShutdownRoomRestServlet(ClientV1RestServlet):
raise AuthError(403, "You are not a server admin")
content = parse_json_object_from_request(request)
-
- new_room_user_id = content.get("new_room_user_id")
- if not new_room_user_id:
- raise SynapseError(400, "Please provide field `new_room_user_id`")
+ assert_params_in_dict(content, ["new_room_user_id"])
+ new_room_user_id = content["new_room_user_id"]
room_creator_requester = create_requester(new_room_user_id)
@@ -464,9 +459,8 @@ class ResetPasswordRestServlet(ClientV1RestServlet):
raise AuthError(403, "You are not a server admin")
params = parse_json_object_from_request(request)
+ assert_params_in_dict(params, ["new_password"])
new_password = params['new_password']
- if not new_password:
- raise SynapseError(400, "Missing 'new_password' arg")
logger.info("new_password: %r", new_password)
@@ -514,12 +508,9 @@ class GetUsersPaginatedRestServlet(ClientV1RestServlet):
raise SynapseError(400, "Can only users a local user")
order = "name" # order by name in user table
- start = request.args.get("start")[0]
- limit = request.args.get("limit")[0]
- if not limit:
- raise SynapseError(400, "Missing 'limit' arg")
- if not start:
- raise SynapseError(400, "Missing 'start' arg")
+ start = parse_integer(request, "start", required=True)
+ limit = parse_integer(request, "limit", required=True)
+
logger.info("limit: %s, start: %s", limit, start)
ret = yield self.handlers.admin_handler.get_users_paginate(
@@ -551,12 +542,9 @@ class GetUsersPaginatedRestServlet(ClientV1RestServlet):
order = "name" # order by name in user table
params = parse_json_object_from_request(request)
+ assert_params_in_dict(params, ["limit", "start"])
limit = params['limit']
start = params['start']
- if not limit:
- raise SynapseError(400, "Missing 'limit' arg")
- if not start:
- raise SynapseError(400, "Missing 'start' arg")
logger.info("limit: %s, start: %s", limit, start)
ret = yield self.handlers.admin_handler.get_users_paginate(
@@ -604,10 +592,7 @@ class SearchUsersRestServlet(ClientV1RestServlet):
if not self.hs.is_mine(target_user):
raise SynapseError(400, "Can only users a local user")
- term = request.args.get("term")[0]
- if not term:
- raise SynapseError(400, "Missing 'term' arg")
-
+ term = parse_string(request, "term", required=True)
logger.info("term: %s ", term)
ret = yield self.handlers.admin_handler.search_users(
diff --git a/synapse/rest/client/v1/base.py b/synapse/rest/client/v1/base.py
index dde02328c3..c77d7aba68 100644
--- a/synapse/rest/client/v1/base.py
+++ b/synapse/rest/client/v1/base.py
@@ -62,4 +62,4 @@ class ClientV1RestServlet(RestServlet):
self.hs = hs
self.builder_factory = hs.get_event_builder_factory()
self.auth = hs.get_auth()
- self.txns = HttpTransactionCache(hs.get_clock())
+ self.txns = HttpTransactionCache(hs)
diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py
index 4fdbb83815..69dcd618cb 100644
--- a/synapse/rest/client/v1/directory.py
+++ b/synapse/rest/client/v1/directory.py
@@ -52,15 +52,14 @@ class ClientDirectoryServer(ClientV1RestServlet):
@defer.inlineCallbacks
def on_PUT(self, request, room_alias):
+ room_alias = RoomAlias.from_string(room_alias)
+
content = parse_json_object_from_request(request)
if "room_id" not in content:
- raise SynapseError(400, "Missing room_id key",
+ raise SynapseError(400, 'Missing params: ["room_id"]',
errcode=Codes.BAD_JSON)
logger.debug("Got content: %s", content)
-
- room_alias = RoomAlias.from_string(room_alias)
-
logger.debug("Got room name: %s", room_alias.to_string())
room_id = content["room_id"]
diff --git a/synapse/rest/client/v1/initial_sync.py b/synapse/rest/client/v1/initial_sync.py
index fbe8cb2023..fd5f85b53e 100644
--- a/synapse/rest/client/v1/initial_sync.py
+++ b/synapse/rest/client/v1/initial_sync.py
@@ -15,6 +15,7 @@
from twisted.internet import defer
+from synapse.http.servlet import parse_boolean
from synapse.streams.config import PaginationConfig
from .base import ClientV1RestServlet, client_path_patterns
@@ -33,7 +34,7 @@ class InitialSyncRestServlet(ClientV1RestServlet):
requester = yield self.auth.get_user_by_req(request)
as_client_event = "raw" not in request.args
pagination_config = PaginationConfig.from_request(request)
- include_archived = request.args.get("archived", None) == ["true"]
+ include_archived = parse_boolean(request, "archived", default=False)
content = yield self.initial_sync_handler.snapshot_all_rooms(
user_id=requester.user.to_string(),
pagin_config=pagination_config,
diff --git a/synapse/rest/client/v1/logout.py b/synapse/rest/client/v1/logout.py
index 05a8ecfcd8..430c692336 100644
--- a/synapse/rest/client/v1/logout.py
+++ b/synapse/rest/client/v1/logout.py
@@ -17,7 +17,6 @@ import logging
from twisted.internet import defer
-from synapse.api.auth import get_access_token_from_request
from synapse.api.errors import AuthError
from .base import ClientV1RestServlet, client_path_patterns
@@ -51,7 +50,7 @@ class LogoutRestServlet(ClientV1RestServlet):
if requester.device_id is None:
# the acccess token wasn't associated with a device.
# Just delete the access token
- access_token = get_access_token_from_request(request)
+ access_token = self._auth.get_access_token_from_request(request)
yield self._auth_handler.delete_access_token(access_token)
else:
yield self._device_handler.delete_device(
diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py
index 0df7ce570f..6e95d9bec2 100644
--- a/synapse/rest/client/v1/push_rule.py
+++ b/synapse/rest/client/v1/push_rule.py
@@ -21,7 +21,7 @@ from synapse.api.errors import (
SynapseError,
UnrecognizedRequestError,
)
-from synapse.http.servlet import parse_json_value_from_request
+from synapse.http.servlet import parse_json_value_from_request, parse_string
from synapse.push.baserules import BASE_RULE_IDS
from synapse.push.clientformat import format_push_rules_for_user
from synapse.push.rulekinds import PRIORITY_CLASS_MAP
@@ -75,13 +75,13 @@ class PushRuleRestServlet(ClientV1RestServlet):
except InvalidRuleException as e:
raise SynapseError(400, e.message)
- before = request.args.get("before", None)
+ before = parse_string(request, "before")
if before:
- before = _namespaced_rule_id(spec, before[0])
+ before = _namespaced_rule_id(spec, before)
- after = request.args.get("after", None)
+ after = parse_string(request, "after")
if after:
- after = _namespaced_rule_id(spec, after[0])
+ after = _namespaced_rule_id(spec, after)
try:
yield self.store.add_push_rule(
diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py
index 1581f88db5..182a68b1e2 100644
--- a/synapse/rest/client/v1/pusher.py
+++ b/synapse/rest/client/v1/pusher.py
@@ -21,6 +21,7 @@ from synapse.api.errors import Codes, StoreError, SynapseError
from synapse.http.server import finish_request
from synapse.http.servlet import (
RestServlet,
+ assert_params_in_dict,
parse_json_object_from_request,
parse_string,
)
@@ -91,15 +92,11 @@ class PushersSetRestServlet(ClientV1RestServlet):
)
defer.returnValue((200, {}))
- reqd = ['kind', 'app_id', 'app_display_name',
- 'device_display_name', 'pushkey', 'lang', 'data']
- missing = []
- for i in reqd:
- if i not in content:
- missing.append(i)
- if len(missing):
- raise SynapseError(400, "Missing parameters: " + ','.join(missing),
- errcode=Codes.MISSING_PARAM)
+ assert_params_in_dict(
+ content,
+ ['kind', 'app_id', 'app_display_name',
+ 'device_display_name', 'pushkey', 'lang', 'data']
+ )
logger.debug("set pushkey %s to kind %s", content['pushkey'], content['kind'])
logger.debug("Got pushers request with body: %r", content)
@@ -148,7 +145,7 @@ class PushersRemoveRestServlet(RestServlet):
SUCCESS_HTML = "<html><body>You have been unsubscribed</body><html>"
def __init__(self, hs):
- super(RestServlet, self).__init__()
+ super(PushersRemoveRestServlet, self).__init__()
self.hs = hs
self.notifier = hs.get_notifier()
self.auth = hs.get_auth()
diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1/register.py
index 3ce5f8b726..25a143af8d 100644
--- a/synapse/rest/client/v1/register.py
+++ b/synapse/rest/client/v1/register.py
@@ -18,15 +18,12 @@ import hmac
import logging
from hashlib import sha1
-from six import string_types
-
from twisted.internet import defer
import synapse.util.stringutils as stringutils
-from synapse.api.auth import get_access_token_from_request
from synapse.api.constants import LoginType
from synapse.api.errors import Codes, SynapseError
-from synapse.http.servlet import parse_json_object_from_request
+from synapse.http.servlet import assert_params_in_dict, parse_json_object_from_request
from synapse.types import create_requester
from .base import ClientV1RestServlet, client_path_patterns
@@ -67,6 +64,7 @@ class RegisterRestServlet(ClientV1RestServlet):
# TODO: persistent storage
self.sessions = {}
self.enable_registration = hs.config.enable_registration
+ self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
self.handlers = hs.get_handlers()
@@ -124,8 +122,7 @@ class RegisterRestServlet(ClientV1RestServlet):
session = (register_json["session"]
if "session" in register_json else None)
login_type = None
- if "type" not in register_json:
- raise SynapseError(400, "Missing 'type' key.")
+ assert_params_in_dict(register_json, ["type"])
try:
login_type = register_json["type"]
@@ -310,11 +307,9 @@ class RegisterRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def _do_app_service(self, request, register_json, session):
- as_token = get_access_token_from_request(request)
-
- if "user" not in register_json:
- raise SynapseError(400, "Expected 'user' key.")
+ as_token = self.auth.get_access_token_from_request(request)
+ assert_params_in_dict(register_json, ["user"])
user_localpart = register_json["user"].encode("utf-8")
handler = self.handlers.registration_handler
@@ -331,12 +326,7 @@ class RegisterRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def _do_shared_secret(self, request, register_json, session):
- if not isinstance(register_json.get("mac", None), string_types):
- raise SynapseError(400, "Expected mac.")
- if not isinstance(register_json.get("user", None), string_types):
- raise SynapseError(400, "Expected 'user' key.")
- if not isinstance(register_json.get("password", None), string_types):
- raise SynapseError(400, "Expected 'password' key.")
+ assert_params_in_dict(register_json, ["mac", "user", "password"])
if not self.hs.config.registration_shared_secret:
raise SynapseError(400, "Shared secret registration is not enabled")
@@ -400,7 +390,7 @@ class CreateUserRestServlet(ClientV1RestServlet):
def on_POST(self, request):
user_json = parse_json_object_from_request(request)
- access_token = get_access_token_from_request(request)
+ access_token = self.auth.get_access_token_from_request(request)
app_service = self.store.get_app_service_by_token(
access_token
)
@@ -419,11 +409,7 @@ class CreateUserRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def _do_create(self, requester, user_json):
- if "localpart" not in user_json:
- raise SynapseError(400, "Expected 'localpart' key.")
-
- if "displayname" not in user_json:
- raise SynapseError(400, "Expected 'displayname' key.")
+ assert_params_in_dict(user_json, ["localpart", "displayname"])
localpart = user_json["localpart"].encode("utf-8")
displayname = user_json["displayname"].encode("utf-8")
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 2470db52ba..3d62447854 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -28,6 +28,7 @@ from synapse.api.errors import AuthError, Codes, SynapseError
from synapse.api.filtering import Filter
from synapse.events.utils import format_event_for_client_v2, serialize_event
from synapse.http.servlet import (
+ assert_params_in_dict,
parse_integer,
parse_json_object_from_request,
parse_string,
@@ -435,9 +436,9 @@ class RoomMessageListRestServlet(ClientV1RestServlet):
request, default_limit=10,
)
as_client_event = "raw" not in request.args
- filter_bytes = request.args.get("filter", None)
+ filter_bytes = parse_string(request, "filter")
if filter_bytes:
- filter_json = urlparse.unquote(filter_bytes[-1]).decode("UTF-8")
+ filter_json = urlparse.unquote(filter_bytes).decode("UTF-8")
event_filter = Filter(json.loads(filter_json))
else:
event_filter = None
@@ -530,7 +531,7 @@ class RoomEventContextServlet(ClientV1RestServlet):
def on_GET(self, request, room_id, event_id):
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
- limit = int(request.args.get("limit", [10])[0])
+ limit = parse_integer(request, "limit", default=10)
results = yield self.handlers.room_context_handler.get_event_context(
requester.user,
@@ -636,8 +637,7 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
target = requester.user
if membership_action in ["invite", "ban", "unban", "kick"]:
- if "user_id" not in content:
- raise SynapseError(400, "Missing user_id key.")
+ assert_params_in_dict(content, ["user_id"])
target = UserID.from_string(content["user_id"])
event_content = None
@@ -764,7 +764,7 @@ class SearchRestServlet(ClientV1RestServlet):
content = parse_json_object_from_request(request)
- batch = request.args.get("next_batch", [None])[0]
+ batch = parse_string(request, "next_batch")
results = yield self.handlers.search_handler.search(
requester.user,
content,
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index 528c1f43f9..eeae466d82 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -20,12 +20,11 @@ from six.moves import http_client
from twisted.internet import defer
-from synapse.api.auth import has_access_token
from synapse.api.constants import LoginType
from synapse.api.errors import Codes, SynapseError
from synapse.http.servlet import (
RestServlet,
- assert_params_in_request,
+ assert_params_in_dict,
parse_json_object_from_request,
)
from synapse.util.msisdn import phone_number_to_msisdn
@@ -48,7 +47,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
def on_POST(self, request):
body = parse_json_object_from_request(request)
- assert_params_in_request(body, [
+ assert_params_in_dict(body, [
'id_server', 'client_secret', 'email', 'send_attempt'
])
@@ -81,7 +80,7 @@ class MsisdnPasswordRequestTokenRestServlet(RestServlet):
def on_POST(self, request):
body = parse_json_object_from_request(request)
- assert_params_in_request(body, [
+ assert_params_in_dict(body, [
'id_server', 'client_secret',
'country', 'phone_number', 'send_attempt',
])
@@ -130,7 +129,7 @@ class PasswordRestServlet(RestServlet):
#
# In the second case, we require a password to confirm their identity.
- if has_access_token(request):
+ if self.auth.has_access_token(request):
requester = yield self.auth.get_user_by_req(request)
params = yield self.auth_handler.validate_user_via_ui_auth(
requester, body, self.hs.get_ip_from_request(request),
@@ -160,11 +159,10 @@ class PasswordRestServlet(RestServlet):
raise SynapseError(404, "Email address not found", Codes.NOT_FOUND)
user_id = threepid_user_id
else:
- logger.error("Auth succeeded but no known type!", result.keys())
+ logger.error("Auth succeeded but no known type! %r", result.keys())
raise SynapseError(500, "", Codes.UNKNOWN)
- if 'new_password' not in params:
- raise SynapseError(400, "", Codes.MISSING_PARAM)
+ assert_params_in_dict(params, ["new_password"])
new_password = params['new_password']
yield self._set_password_handler.set_password(
@@ -229,15 +227,10 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
@defer.inlineCallbacks
def on_POST(self, request):
body = parse_json_object_from_request(request)
-
- required = ['id_server', 'client_secret', 'email', 'send_attempt']
- absent = []
- for k in required:
- if k not in body:
- absent.append(k)
-
- if absent:
- raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
+ assert_params_in_dict(
+ body,
+ ['id_server', 'client_secret', 'email', 'send_attempt'],
+ )
if not check_3pid_allowed(self.hs, "email", body['email']):
raise SynapseError(
@@ -267,18 +260,10 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
@defer.inlineCallbacks
def on_POST(self, request):
body = parse_json_object_from_request(request)
-
- required = [
+ assert_params_in_dict(body, [
'id_server', 'client_secret',
'country', 'phone_number', 'send_attempt',
- ]
- absent = []
- for k in required:
- if k not in body:
- absent.append(k)
-
- if absent:
- raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
+ ])
msisdn = phone_number_to_msisdn(body['country'], body['phone_number'])
@@ -373,15 +358,7 @@ class ThreepidDeleteRestServlet(RestServlet):
@defer.inlineCallbacks
def on_POST(self, request):
body = parse_json_object_from_request(request)
-
- required = ['medium', 'address']
- absent = []
- for k in required:
- if k not in body:
- absent.append(k)
-
- if absent:
- raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
+ assert_params_in_dict(body, ['medium', 'address'])
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py
index 09f6a8efe3..9b75bb1377 100644
--- a/synapse/rest/client/v2_alpha/devices.py
+++ b/synapse/rest/client/v2_alpha/devices.py
@@ -18,14 +18,18 @@ import logging
from twisted.internet import defer
from synapse.api import errors
-from synapse.http import servlet
+from synapse.http.servlet import (
+ RestServlet,
+ assert_params_in_dict,
+ parse_json_object_from_request,
+)
from ._base import client_v2_patterns, interactive_auth_handler
logger = logging.getLogger(__name__)
-class DevicesRestServlet(servlet.RestServlet):
+class DevicesRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/devices$", v2_alpha=False)
def __init__(self, hs):
@@ -47,7 +51,7 @@ class DevicesRestServlet(servlet.RestServlet):
defer.returnValue((200, {"devices": devices}))
-class DeleteDevicesRestServlet(servlet.RestServlet):
+class DeleteDevicesRestServlet(RestServlet):
"""
API for bulk deletion of devices. Accepts a JSON object with a devices
key which lists the device_ids to delete. Requires user interactive auth.
@@ -67,19 +71,17 @@ class DeleteDevicesRestServlet(servlet.RestServlet):
requester = yield self.auth.get_user_by_req(request)
try:
- body = servlet.parse_json_object_from_request(request)
+ body = parse_json_object_from_request(request)
except errors.SynapseError as e:
if e.errcode == errors.Codes.NOT_JSON:
- # deal with older clients which didn't pass a J*DELETESON dict
+ # DELETE
+ # deal with older clients which didn't pass a JSON dict
# the same as those that pass an empty dict
body = {}
else:
raise e
- if 'devices' not in body:
- raise errors.SynapseError(
- 400, "No devices supplied", errcode=errors.Codes.MISSING_PARAM
- )
+ assert_params_in_dict(body, ["devices"])
yield self.auth_handler.validate_user_via_ui_auth(
requester, body, self.hs.get_ip_from_request(request),
@@ -92,7 +94,7 @@ class DeleteDevicesRestServlet(servlet.RestServlet):
defer.returnValue((200, {}))
-class DeviceRestServlet(servlet.RestServlet):
+class DeviceRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/devices/(?P<device_id>[^/]*)$", v2_alpha=False)
def __init__(self, hs):
@@ -121,7 +123,7 @@ class DeviceRestServlet(servlet.RestServlet):
requester = yield self.auth.get_user_by_req(request)
try:
- body = servlet.parse_json_object_from_request(request)
+ body = parse_json_object_from_request(request)
except errors.SynapseError as e:
if e.errcode == errors.Codes.NOT_JSON:
@@ -144,7 +146,7 @@ class DeviceRestServlet(servlet.RestServlet):
def on_PUT(self, request, device_id):
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
- body = servlet.parse_json_object_from_request(request)
+ body = parse_json_object_from_request(request)
yield self.device_handler.update_device(
requester.user.to_string(),
device_id,
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index 896650d5a5..d6cf915d86 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -24,12 +24,11 @@ from twisted.internet import defer
import synapse
import synapse.types
-from synapse.api.auth import get_access_token_from_request, has_access_token
from synapse.api.constants import LoginType
from synapse.api.errors import Codes, SynapseError, UnrecognizedRequestError
from synapse.http.servlet import (
RestServlet,
- assert_params_in_request,
+ assert_params_in_dict,
parse_json_object_from_request,
parse_string,
)
@@ -69,7 +68,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
def on_POST(self, request):
body = parse_json_object_from_request(request)
- assert_params_in_request(body, [
+ assert_params_in_dict(body, [
'id_server', 'client_secret', 'email', 'send_attempt'
])
@@ -105,7 +104,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
def on_POST(self, request):
body = parse_json_object_from_request(request)
- assert_params_in_request(body, [
+ assert_params_in_dict(body, [
'id_server', 'client_secret',
'country', 'phone_number',
'send_attempt',
@@ -224,7 +223,7 @@ class RegisterRestServlet(RestServlet):
desired_username = body['username']
appservice = None
- if has_access_token(request):
+ if self.auth.has_access_token(request):
appservice = yield self.auth.get_appservice_by_req(request)
# fork off as soon as possible for ASes and shared secret auth which
@@ -242,7 +241,7 @@ class RegisterRestServlet(RestServlet):
# because the IRC bridges rely on being able to register stupid
# IDs.
- access_token = get_access_token_from_request(request)
+ access_token = self.auth.get_access_token_from_request(request)
if isinstance(desired_username, string_types):
result = yield self._do_appservice_registration(
@@ -387,9 +386,7 @@ class RegisterRestServlet(RestServlet):
add_msisdn = False
else:
# NB: This may be from the auth handler and NOT from the POST
- if 'password' not in params:
- raise SynapseError(400, "Missing password.",
- Codes.MISSING_PARAM)
+ assert_params_in_dict(params, ["password"])
desired_username = params.get("username", None)
new_password = params.get("password", None)
@@ -566,11 +563,14 @@ class RegisterRestServlet(RestServlet):
Returns:
defer.Deferred:
"""
- reqd = ('medium', 'address', 'validated_at')
- if any(x not in threepid for x in reqd):
- # This will only happen if the ID server returns a malformed response
- logger.info("Can't add incomplete 3pid")
- defer.returnValue()
+ try:
+ assert_params_in_dict(threepid, ['medium', 'address', 'validated_at'])
+ except SynapseError as ex:
+ if ex.errcode == Codes.MISSING_PARAM:
+ # This will only happen if the ID server returns a malformed response
+ logger.info("Can't add incomplete 3pid")
+ defer.returnValue(None)
+ raise
yield self.auth_handler.add_threepid(
user_id,
@@ -643,7 +643,7 @@ class RegisterRestServlet(RestServlet):
@defer.inlineCallbacks
def _do_guest_registration(self, params):
if not self.hs.config.allow_guest_access:
- defer.returnValue((403, "Guest access is disabled"))
+ raise SynapseError(403, "Guest access is disabled")
user_id, _ = yield self.registration_handler.register(
generate_token=False,
make_guest=True
diff --git a/synapse/rest/client/v2_alpha/report_event.py b/synapse/rest/client/v2_alpha/report_event.py
index 08bb8e04fd..95d2a71ec2 100644
--- a/synapse/rest/client/v2_alpha/report_event.py
+++ b/synapse/rest/client/v2_alpha/report_event.py
@@ -15,9 +15,17 @@
import logging
+from six import string_types
+from six.moves import http_client
+
from twisted.internet import defer
-from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.api.errors import Codes, SynapseError
+from synapse.http.servlet import (
+ RestServlet,
+ assert_params_in_dict,
+ parse_json_object_from_request,
+)
from ._base import client_v2_patterns
@@ -42,12 +50,26 @@ class ReportEventRestServlet(RestServlet):
user_id = requester.user.to_string()
body = parse_json_object_from_request(request)
+ assert_params_in_dict(body, ("reason", "score"))
+
+ if not isinstance(body["reason"], string_types):
+ raise SynapseError(
+ http_client.BAD_REQUEST,
+ "Param 'reason' must be a string",
+ Codes.BAD_JSON,
+ )
+ if not isinstance(body["score"], int):
+ raise SynapseError(
+ http_client.BAD_REQUEST,
+ "Param 'score' must be an integer",
+ Codes.BAD_JSON,
+ )
yield self.store.add_event_report(
room_id=room_id,
event_id=event_id,
user_id=user_id,
- reason=body.get("reason"),
+ reason=body["reason"],
content=body,
received_ts=self.clock.time_msec(),
)
diff --git a/synapse/rest/client/v2_alpha/sendtodevice.py b/synapse/rest/client/v2_alpha/sendtodevice.py
index 90bdb1db15..a9e9a47a0b 100644
--- a/synapse/rest/client/v2_alpha/sendtodevice.py
+++ b/synapse/rest/client/v2_alpha/sendtodevice.py
@@ -40,7 +40,7 @@ class SendToDeviceRestServlet(servlet.RestServlet):
super(SendToDeviceRestServlet, self).__init__()
self.hs = hs
self.auth = hs.get_auth()
- self.txns = HttpTransactionCache(hs.get_clock())
+ self.txns = HttpTransactionCache(hs)
self.device_message_handler = hs.get_device_message_handler()
def on_PUT(self, request, message_type, txn_id):
diff --git a/synapse/rest/media/v1/identicon_resource.py b/synapse/rest/media/v1/identicon_resource.py
index a2e391415f..bdbd8d50dd 100644
--- a/synapse/rest/media/v1/identicon_resource.py
+++ b/synapse/rest/media/v1/identicon_resource.py
@@ -16,6 +16,8 @@ from pydenticon import Generator
from twisted.web.resource import Resource
+from synapse.http.servlet import parse_integer
+
FOREGROUND = [
"rgb(45,79,255)",
"rgb(254,180,44)",
@@ -56,8 +58,8 @@ class IdenticonResource(Resource):
def render_GET(self, request):
name = "/".join(request.postpath)
- width = int(request.args.get("width", [96])[0])
- height = int(request.args.get("height", [96])[0])
+ width = parse_integer(request, "width", default=96)
+ height = parse_integer(request, "height", default=96)
identicon_bytes = self.generate_identicon(name, width, height)
request.setHeader(b"Content-Type", b"image/png")
request.setHeader(
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index 4e3a18ce08..b70b15c4c2 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -40,6 +40,7 @@ from synapse.http.server import (
respond_with_json_bytes,
wrap_json_request_handler,
)
+from synapse.http.servlet import parse_integer, parse_string
from synapse.util.async import ObservableDeferred
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.logcontext import make_deferred_yieldable, run_in_background
@@ -96,9 +97,9 @@ class PreviewUrlResource(Resource):
# XXX: if get_user_by_req fails, what should we do in an async render?
requester = yield self.auth.get_user_by_req(request)
- url = request.args.get("url")[0]
+ url = parse_string(request, "url")
if "ts" in request.args:
- ts = int(request.args.get("ts")[0])
+ ts = parse_integer(request, "ts")
else:
ts = self.clock.time_msec()
diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py
index 1a98120e1d..9b22d204a6 100644
--- a/synapse/rest/media/v1/upload_resource.py
+++ b/synapse/rest/media/v1/upload_resource.py
@@ -21,6 +21,7 @@ from twisted.web.server import NOT_DONE_YET
from synapse.api.errors import SynapseError
from synapse.http.server import respond_with_json, wrap_json_request_handler
+from synapse.http.servlet import parse_string
logger = logging.getLogger(__name__)
@@ -65,10 +66,10 @@ class UploadResource(Resource):
code=413,
)
- upload_name = request.args.get("filename", None)
+ upload_name = parse_string(request, "filename")
if upload_name:
try:
- upload_name = upload_name[0].decode('UTF-8')
+ upload_name = upload_name.decode('UTF-8')
except UnicodeDecodeError:
raise SynapseError(
msg="Invalid UTF-8 filename parameter: %r" % (upload_name),
diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py
index 3738901ea4..0ac665e967 100644
--- a/synapse/storage/receipts.py
+++ b/synapse/storage/receipts.py
@@ -140,7 +140,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
"""
room_ids = set(room_ids)
- if from_key:
+ if from_key is not None:
+ # Only ask the database about rooms where there have been new
+ # receipts added since `from_key`
room_ids = yield self._receipts_stream_cache.get_entities_changed(
room_ids, from_key
)
@@ -151,7 +153,6 @@ class ReceiptsWorkerStore(SQLBaseStore):
defer.returnValue([ev for res in results.values() for ev in res])
- @cachedInlineCallbacks(num_args=3, tree=True)
def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None):
"""Get receipts for a single room for sending to clients.
@@ -162,7 +163,19 @@ class ReceiptsWorkerStore(SQLBaseStore):
from the start.
Returns:
- list: A list of receipts.
+ Deferred[list]: A list of receipts.
+ """
+ if from_key is not None:
+ # Check the cache first to see if any new receipts have been added
+ # since`from_key`. If not we can no-op.
+ if not self._receipts_stream_cache.has_entity_changed(room_id, from_key):
+ defer.succeed([])
+
+ return self._get_linearized_receipts_for_room(room_id, to_key, from_key)
+
+ @cachedInlineCallbacks(num_args=3, tree=True)
+ def _get_linearized_receipts_for_room(self, room_id, to_key, from_key=None):
+ """See get_linearized_receipts_for_room
"""
def f(txn):
if from_key:
@@ -211,7 +224,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
"content": content,
}])
- @cachedList(cached_method_name="get_linearized_receipts_for_room",
+ @cachedList(cached_method_name="_get_linearized_receipts_for_room",
list_name="room_ids", num_args=3, inlineCallbacks=True)
def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
if not room_ids:
@@ -373,7 +386,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
self.get_receipts_for_user.invalidate, (user_id, receipt_type)
)
# FIXME: This shouldn't invalidate the whole cache
- txn.call_after(self.get_linearized_receipts_for_room.invalidate_many, (room_id,))
+ txn.call_after(self._get_linearized_receipts_for_room.invalidate_many, (room_id,))
txn.call_after(
self._receipts_stream_cache.entity_has_changed,
@@ -493,7 +506,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
self.get_receipts_for_user.invalidate, (user_id, receipt_type)
)
# FIXME: This shouldn't invalidate the whole cache
- txn.call_after(self.get_linearized_receipts_for_room.invalidate_many, (room_id,))
+ txn.call_after(self._get_linearized_receipts_for_room.invalidate_many, (room_id,))
self._simple_delete_txn(
txn,
diff --git a/synapse/streams/config.py b/synapse/streams/config.py
index 46ccbbda7d..451e4fa441 100644
--- a/synapse/streams/config.py
+++ b/synapse/streams/config.py
@@ -16,6 +16,7 @@
import logging
from synapse.api.errors import SynapseError
+from synapse.http.servlet import parse_integer, parse_string
from synapse.types import StreamToken
logger = logging.getLogger(__name__)
@@ -56,23 +57,10 @@ class PaginationConfig(object):
@classmethod
def from_request(cls, request, raise_invalid_params=True,
default_limit=None):
- def get_param(name, default=None):
- lst = request.args.get(name, [])
- if len(lst) > 1:
- raise SynapseError(
- 400, "%s must be specified only once" % (name,)
- )
- elif len(lst) == 1:
- return lst[0]
- else:
- return default
-
- direction = get_param("dir", 'f')
- if direction not in ['f', 'b']:
- raise SynapseError(400, "'dir' parameter is invalid.")
-
- from_tok = get_param("from")
- to_tok = get_param("to")
+ direction = parse_string(request, "dir", default='f', allowed_values=['f', 'b'])
+
+ from_tok = parse_string(request, "from")
+ to_tok = parse_string(request, "to")
try:
if from_tok == "END":
@@ -88,12 +76,10 @@ class PaginationConfig(object):
except Exception:
raise SynapseError(400, "'to' paramater is invalid")
- limit = get_param("limit", None)
- if limit is not None and not limit.isdigit():
- raise SynapseError(400, "'limit' parameter must be an integer.")
+ limit = parse_integer(request, "limit", default=default_limit)
- if limit is None:
- limit = default_limit
+ if limit and limit < 0:
+ raise SynapseError(400, "Limit must be 0 or above")
try:
return PaginationConfig(from_tok, to_tok, direction, limit)
diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py
index 8637867c6d..f2bde74dc5 100644
--- a/synapse/util/caches/stream_change_cache.py
+++ b/synapse/util/caches/stream_change_cache.py
@@ -74,14 +74,13 @@ class StreamChangeCache(object):
assert type(stream_pos) is int
if stream_pos >= self._earliest_known_stream_pos:
- not_known_entities = set(entities) - set(self._entity_to_key)
-
- result = (
- {self._cache[k] for k in self._cache.islice(
- start=self._cache.bisect_right(stream_pos))}
- .intersection(entities)
- .union(not_known_entities)
- )
+ changed_entities = {
+ self._cache[k] for k in self._cache.islice(
+ start=self._cache.bisect_right(stream_pos),
+ )
+ }
+
+ result = changed_entities.intersection(entities)
self.metrics.inc_hits()
else:
diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py
index 5ac33b2132..7deb38f2a7 100644
--- a/synapse/util/ratelimitutils.py
+++ b/synapse/util/ratelimitutils.py
@@ -92,13 +92,22 @@ class _PerHostRatelimiter(object):
self.window_size = window_size
self.sleep_limit = sleep_limit
- self.sleep_msec = sleep_msec
+ self.sleep_sec = sleep_msec / 1000.0
self.reject_limit = reject_limit
self.concurrent_requests = concurrent_requests
+ # request_id objects for requests which have been slept
self.sleeping_requests = set()
+
+ # map from request_id object to Deferred for requests which are ready
+ # for processing but have been queued
self.ready_request_queue = collections.OrderedDict()
+
+ # request id objects for requests which are in progress
self.current_processing = set()
+
+ # times at which we have recently (within the last window_size ms)
+ # received requests.
self.request_times = []
@contextlib.contextmanager
@@ -117,11 +126,15 @@ class _PerHostRatelimiter(object):
def _on_enter(self, request_id):
time_now = self.clock.time_msec()
+
+ # remove any entries from request_times which aren't within the window
self.request_times[:] = [
r for r in self.request_times
if time_now - r < self.window_size
]
+ # reject the request if we already have too many queued up (either
+ # sleeping or in the ready queue).
queue_size = len(self.ready_request_queue) + len(self.sleeping_requests)
if queue_size > self.reject_limit:
raise LimitExceededError(
@@ -134,9 +147,13 @@ class _PerHostRatelimiter(object):
def queue_request():
if len(self.current_processing) > self.concurrent_requests:
- logger.debug("Ratelimit [%s]: Queue req", id(request_id))
queue_defer = defer.Deferred()
self.ready_request_queue[request_id] = queue_defer
+ logger.info(
+ "Ratelimiter: queueing request (queue now %i items)",
+ len(self.ready_request_queue),
+ )
+
return queue_defer
else:
return defer.succeed(None)
@@ -148,10 +165,9 @@ class _PerHostRatelimiter(object):
if len(self.request_times) > self.sleep_limit:
logger.debug(
- "Ratelimit [%s]: sleeping req",
- id(request_id),
+ "Ratelimiter: sleeping request for %f sec", self.sleep_sec,
)
- ret_defer = run_in_background(self.clock.sleep, self.sleep_msec / 1000.0)
+ ret_defer = run_in_background(self.clock.sleep, self.sleep_sec)
self.sleeping_requests.add(request_id)
@@ -200,11 +216,8 @@ class _PerHostRatelimiter(object):
)
self.current_processing.discard(request_id)
try:
- request_id, deferred = self.ready_request_queue.popitem()
-
- # XXX: why do we do the following? the on_start callback above will
- # do it for us.
- self.current_processing.add(request_id)
+ # start processing the next item on the queue.
+ _, deferred = self.ready_request_queue.popitem(last=False)
with PreserveLoggingContext():
deferred.callback(None)
diff --git a/synapse/visibility.py b/synapse/visibility.py
index 015c2bab37..9b97ea2b83 100644
--- a/synapse/visibility.py
+++ b/synapse/visibility.py
@@ -16,10 +16,13 @@ import itertools
import logging
import operator
+import six
+
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
from synapse.events.utils import prune_event
+from synapse.types import get_domain_from_id
from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
logger = logging.getLogger(__name__)
@@ -225,3 +228,154 @@ def filter_events_for_client(store, user_id, events, is_peeking=False,
# we turn it into a list before returning it.
defer.returnValue(list(filtered_events))
+
+
+@defer.inlineCallbacks
+def filter_events_for_server(store, server_name, events):
+ # Whatever else we do, we need to check for senders which have requested
+ # erasure of their data.
+ erased_senders = yield store.are_users_erased(
+ e.sender for e in events,
+ )
+
+ def redact_disallowed(event, state):
+ # if the sender has been gdpr17ed, always return a redacted
+ # copy of the event.
+ if erased_senders[event.sender]:
+ logger.info(
+ "Sender of %s has been erased, redacting",
+ event.event_id,
+ )
+ return prune_event(event)
+
+ # state will be None if we decided we didn't need to filter by
+ # room membership.
+ if not state:
+ return event
+
+ history = state.get((EventTypes.RoomHistoryVisibility, ''), None)
+ if history:
+ visibility = history.content.get("history_visibility", "shared")
+ if visibility in ["invited", "joined"]:
+ # We now loop through all state events looking for
+ # membership states for the requesting server to determine
+ # if the server is either in the room or has been invited
+ # into the room.
+ for ev in state.itervalues():
+ if ev.type != EventTypes.Member:
+ continue
+ try:
+ domain = get_domain_from_id(ev.state_key)
+ except Exception:
+ continue
+
+ if domain != server_name:
+ continue
+
+ memtype = ev.membership
+ if memtype == Membership.JOIN:
+ return event
+ elif memtype == Membership.INVITE:
+ if visibility == "invited":
+ return event
+ else:
+ # server has no users in the room: redact
+ return prune_event(event)
+
+ return event
+
+ # Next lets check to see if all the events have a history visibility
+ # of "shared" or "world_readable". If thats the case then we don't
+ # need to check membership (as we know the server is in the room).
+ event_to_state_ids = yield store.get_state_ids_for_events(
+ frozenset(e.event_id for e in events),
+ types=(
+ (EventTypes.RoomHistoryVisibility, ""),
+ )
+ )
+
+ visibility_ids = set()
+ for sids in event_to_state_ids.itervalues():
+ hist = sids.get((EventTypes.RoomHistoryVisibility, ""))
+ if hist:
+ visibility_ids.add(hist)
+
+ # If we failed to find any history visibility events then the default
+ # is "shared" visiblity.
+ if not visibility_ids:
+ all_open = True
+ else:
+ event_map = yield store.get_events(visibility_ids)
+ all_open = all(
+ e.content.get("history_visibility") in (None, "shared", "world_readable")
+ for e in event_map.itervalues()
+ )
+
+ if all_open:
+ # all the history_visibility state affecting these events is open, so
+ # we don't need to filter by membership state. We *do* need to check
+ # for user erasure, though.
+ if erased_senders:
+ events = [
+ redact_disallowed(e, None)
+ for e in events
+ ]
+
+ defer.returnValue(events)
+
+ # Ok, so we're dealing with events that have non-trivial visibility
+ # rules, so we need to also get the memberships of the room.
+
+ # first, for each event we're wanting to return, get the event_ids
+ # of the history vis and membership state at those events.
+ event_to_state_ids = yield store.get_state_ids_for_events(
+ frozenset(e.event_id for e in events),
+ types=(
+ (EventTypes.RoomHistoryVisibility, ""),
+ (EventTypes.Member, None),
+ )
+ )
+
+ # We only want to pull out member events that correspond to the
+ # server's domain.
+ #
+ # event_to_state_ids contains lots of duplicates, so it turns out to be
+ # cheaper to build a complete set of unique
+ # ((type, state_key), event_id) tuples, and then filter out the ones we
+ # don't want.
+ #
+ state_key_to_event_id_set = {
+ e
+ for key_to_eid in six.itervalues(event_to_state_ids)
+ for e in key_to_eid.items()
+ }
+
+ def include(typ, state_key):
+ if typ != EventTypes.Member:
+ return True
+
+ # we avoid using get_domain_from_id here for efficiency.
+ idx = state_key.find(":")
+ if idx == -1:
+ return False
+ return state_key[idx + 1:] == server_name
+
+ event_map = yield store.get_events([
+ e_id
+ for key, e_id in state_key_to_event_id_set
+ if include(key[0], key[1])
+ ])
+
+ event_to_state = {
+ e_id: {
+ key: event_map[inner_e_id]
+ for key, inner_e_id in key_to_eid.iteritems()
+ if inner_e_id in event_map
+ }
+ for e_id, key_to_eid in event_to_state_ids.iteritems()
+ }
+
+ defer.returnValue([
+ redact_disallowed(e, event_to_state[e.event_id])
+ for e in events
+ ])
diff --git a/tests/rest/client/test_transactions.py b/tests/rest/client/test_transactions.py
index eee99ca2e0..34e68ae82f 100644
--- a/tests/rest/client/test_transactions.py
+++ b/tests/rest/client/test_transactions.py
@@ -14,7 +14,10 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
def setUp(self):
self.clock = MockClock()
- self.cache = HttpTransactionCache(self.clock)
+ self.hs = Mock()
+ self.hs.get_clock = Mock(return_value=self.clock)
+ self.hs.get_auth = Mock()
+ self.cache = HttpTransactionCache(self.hs)
self.mock_http_response = (200, "GOOD JOB!")
self.mock_key = "foo"
diff --git a/tests/rest/client/v1/test_register.py b/tests/rest/client/v1/test_register.py
index f596acb85f..f15fb36213 100644
--- a/tests/rest/client/v1/test_register.py
+++ b/tests/rest/client/v1/test_register.py
@@ -17,26 +17,22 @@ import json
from mock import Mock
-from twisted.internet import defer
+from twisted.test.proto_helpers import MemoryReactorClock
-from synapse.rest.client.v1.register import CreateUserRestServlet
+from synapse.http.server import JsonResource
+from synapse.rest.client.v1.register import register_servlets
+from synapse.util import Clock
from tests import unittest
-from tests.utils import mock_getRawHeaders
+from tests.server import make_request, setup_test_homeserver
class CreateUserServletTestCase(unittest.TestCase):
+ """
+ Tests for CreateUserRestServlet.
+ """
def setUp(self):
- # do the dance to hook up request data to self.request_data
- self.request_data = ""
- self.request = Mock(
- content=Mock(read=Mock(side_effect=lambda: self.request_data)),
- path='/_matrix/client/api/v1/createUser'
- )
- self.request.args = {}
- self.request.requestHeaders.getRawHeaders = mock_getRawHeaders()
-
self.registration_handler = Mock()
self.appservice = Mock(sender="@as:test")
@@ -44,39 +40,49 @@ class CreateUserServletTestCase(unittest.TestCase):
get_app_service_by_token=Mock(return_value=self.appservice)
)
- # do the dance to hook things up to the hs global
- handlers = Mock(
- registration_handler=self.registration_handler,
+ handlers = Mock(registration_handler=self.registration_handler)
+ self.clock = MemoryReactorClock()
+ self.hs_clock = Clock(self.clock)
+
+ self.hs = self.hs = setup_test_homeserver(
+ http_client=None, clock=self.hs_clock, reactor=self.clock
)
- self.hs = Mock()
- self.hs.hostname = "superbig~testing~thing.com"
self.hs.get_datastore = Mock(return_value=self.datastore)
self.hs.get_handlers = Mock(return_value=handlers)
- self.servlet = CreateUserRestServlet(self.hs)
- @defer.inlineCallbacks
def test_POST_createuser_with_valid_user(self):
+
+ res = JsonResource(self.hs)
+ register_servlets(self.hs, res)
+
+ request_data = json.dumps(
+ {
+ "localpart": "someone",
+ "displayname": "someone interesting",
+ "duration_seconds": 200,
+ }
+ )
+
+ url = b'/_matrix/client/api/v1/createUser?access_token=i_am_an_app_service'
+
user_id = "@someone:interesting"
token = "my token"
- self.request.args = {
- "access_token": "i_am_an_app_service"
- }
- self.request_data = json.dumps({
- "localpart": "someone",
- "displayname": "someone interesting",
- "duration_seconds": 200
- })
self.registration_handler.get_or_create_user = Mock(
return_value=(user_id, token)
)
- (code, result) = yield self.servlet.on_POST(self.request)
- self.assertEquals(code, 200)
+ request, channel = make_request(b"POST", url, request_data)
+ request.render(res)
+
+ # Advance the clock because it waits
+ self.clock.advance(1)
+
+ self.assertEquals(channel.result["code"], b"200")
det_data = {
"user_id": user_id,
"access_token": token,
- "home_server": self.hs.hostname
+ "home_server": self.hs.hostname,
}
- self.assertDictContainsSubset(det_data, result)
+ self.assertDictContainsSubset(det_data, json.loads(channel.result["body"]))
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index 895dffa095..6b5764095e 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -25,949 +25,772 @@ from twisted.internet import defer
import synapse.rest.client.v1.room
from synapse.api.constants import Membership
+from synapse.http.server import JsonResource
from synapse.types import UserID
+from synapse.util import Clock
-from ....utils import MockHttpResource, setup_test_homeserver
-from .utils import RestTestCase
+from tests import unittest
+from tests.server import (
+ ThreadedMemoryReactorClock,
+ make_request,
+ render,
+ setup_test_homeserver,
+)
-PATH_PREFIX = "/_matrix/client/api/v1"
+from .utils import RestHelper
+PATH_PREFIX = b"/_matrix/client/api/v1"
-class RoomPermissionsTestCase(RestTestCase):
- """ Tests room permissions. """
- user_id = "@sid1:red"
- rmcreator_id = "@notme:red"
- @defer.inlineCallbacks
+class RoomBase(unittest.TestCase):
+ rmcreator_id = None
+
def setUp(self):
- self.mock_resource = MockHttpResource(prefix=PATH_PREFIX)
- hs = yield setup_test_homeserver(
+ self.clock = ThreadedMemoryReactorClock()
+ self.hs_clock = Clock(self.clock)
+
+ self.hs = setup_test_homeserver(
"red",
http_client=None,
+ clock=self.hs_clock,
+ reactor=self.clock,
federation_client=Mock(),
ratelimiter=NonCallableMock(spec_set=["send_message"]),
)
- self.ratelimiter = hs.get_ratelimiter()
+ self.ratelimiter = self.hs.get_ratelimiter()
self.ratelimiter.send_message.return_value = (True, 0)
- hs.get_handlers().federation_handler = Mock()
+ self.hs.get_federation_handler = Mock(return_value=Mock())
def get_user_by_access_token(token=None, allow_guest=False):
return {
- "user": UserID.from_string(self.auth_user_id),
+ "user": UserID.from_string(self.helper.auth_user_id),
"token_id": 1,
"is_guest": False,
}
- hs.get_auth().get_user_by_access_token = get_user_by_access_token
+
+ def get_user_by_req(request, allow_guest=False, rights="access"):
+ return synapse.types.create_requester(
+ UserID.from_string(self.helper.auth_user_id), 1, False, None
+ )
+
+ self.hs.get_auth().get_user_by_req = get_user_by_req
+ self.hs.get_auth().get_user_by_access_token = get_user_by_access_token
+ self.hs.get_auth().get_access_token_from_request = Mock(return_value=b"1234")
def _insert_client_ip(*args, **kwargs):
return defer.succeed(None)
- hs.get_datastore().insert_client_ip = _insert_client_ip
- self.auth_user_id = self.rmcreator_id
+ self.hs.get_datastore().insert_client_ip = _insert_client_ip
- synapse.rest.client.v1.room.register_servlets(hs, self.mock_resource)
+ self.resource = JsonResource(self.hs)
+ synapse.rest.client.v1.room.register_servlets(self.hs, self.resource)
+ self.helper = RestHelper(self.hs, self.resource, self.user_id)
- self.auth = hs.get_auth()
- # create some rooms under the name rmcreator_id
- self.uncreated_rmid = "!aa:test"
+class RoomPermissionsTestCase(RoomBase):
+ """ Tests room permissions. """
- self.created_rmid = yield self.create_room_as(self.rmcreator_id,
- is_public=False)
+ user_id = b"@sid1:red"
+ rmcreator_id = b"@notme:red"
+
+ def setUp(self):
- self.created_public_rmid = yield self.create_room_as(self.rmcreator_id,
- is_public=True)
+ super(RoomPermissionsTestCase, self).setUp()
+
+ self.helper.auth_user_id = self.rmcreator_id
+ # create some rooms under the name rmcreator_id
+ self.uncreated_rmid = "!aa:test"
+ self.created_rmid = self.helper.create_room_as(
+ self.rmcreator_id, is_public=False
+ )
+ self.created_public_rmid = self.helper.create_room_as(
+ self.rmcreator_id, is_public=True
+ )
# send a message in one of the rooms
self.created_rmid_msg_path = (
- "/rooms/%s/send/m.room.message/a1" % (self.created_rmid)
- )
- (code, response) = yield self.mock_resource.trigger(
- "PUT",
+ "rooms/%s/send/m.room.message/a1" % (self.created_rmid)
+ ).encode('ascii')
+ request, channel = make_request(
+ b"PUT",
self.created_rmid_msg_path,
- '{"msgtype":"m.text","body":"test msg"}'
+ b'{"msgtype":"m.text","body":"test msg"}',
)
- self.assertEquals(200, code, msg=str(response))
+ render(request, self.resource, self.clock)
+ self.assertEquals(channel.result["code"], b"200", channel.result)
# set topic for public room
- (code, response) = yield self.mock_resource.trigger(
- "PUT",
- "/rooms/%s/state/m.room.topic" % self.created_public_rmid,
- '{"topic":"Public Room Topic"}'
+ request, channel = make_request(
+ b"PUT",
+ ("rooms/%s/state/m.room.topic" % self.created_public_rmid).encode('ascii'),
+ b'{"topic":"Public Room Topic"}',
)
- self.assertEquals(200, code, msg=str(response))
+ render(request, self.resource, self.clock)
+ self.assertEquals(channel.result["code"], b"200", channel.result)
# auth as user_id now
- self.auth_user_id = self.user_id
-
- def tearDown(self):
- pass
+ self.helper.auth_user_id = self.user_id
- @defer.inlineCallbacks
def test_send_message(self):
- msg_content = '{"msgtype":"m.text","body":"hello"}'
- send_msg_path = (
- "/rooms/%s/send/m.room.message/mid1" % (self.created_rmid,)
- )
+ msg_content = b'{"msgtype":"m.text","body":"hello"}'
+
+ seq = iter(range(100))
+
+ def send_msg_path():
+ return b"/rooms/%s/send/m.room.message/mid%s" % (
+ self.created_rmid,
+ str(next(seq)).encode('ascii'),
+ )
# send message in uncreated room, expect 403
- (code, response) = yield self.mock_resource.trigger(
- "PUT",
- "/rooms/%s/send/m.room.message/mid2" % (self.uncreated_rmid,),
- msg_content
+ request, channel = make_request(
+ b"PUT",
+ b"/rooms/%s/send/m.room.message/mid2" % (self.uncreated_rmid,),
+ msg_content,
)
- self.assertEquals(403, code, msg=str(response))
+ render(request, self.resource, self.clock)
+ self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"])
# send message in created room not joined (no state), expect 403
- (code, response) = yield self.mock_resource.trigger(
- "PUT",
- send_msg_path,
- msg_content
- )
- self.assertEquals(403, code, msg=str(response))
+ request, channel = make_request(b"PUT", send_msg_path(), msg_content)
+ render(request, self.resource, self.clock)
+ self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"])
# send message in created room and invited, expect 403
- yield self.invite(
- room=self.created_rmid,
- src=self.rmcreator_id,
- targ=self.user_id
- )
- (code, response) = yield self.mock_resource.trigger(
- "PUT",
- send_msg_path,
- msg_content
+ self.helper.invite(
+ room=self.created_rmid, src=self.rmcreator_id, targ=self.user_id
)
- self.assertEquals(403, code, msg=str(response))
+ request, channel = make_request(b"PUT", send_msg_path(), msg_content)
+ render(request, self.resource, self.clock)
+ self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"])
# send message in created room and joined, expect 200
- yield self.join(room=self.created_rmid, user=self.user_id)
- (code, response) = yield self.mock_resource.trigger(
- "PUT",
- send_msg_path,
- msg_content
- )
- self.assertEquals(200, code, msg=str(response))
+ self.helper.join(room=self.created_rmid, user=self.user_id)
+ request, channel = make_request(b"PUT", send_msg_path(), msg_content)
+ render(request, self.resource, self.clock)
+ self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
# send message in created room and left, expect 403
- yield self.leave(room=self.created_rmid, user=self.user_id)
- (code, response) = yield self.mock_resource.trigger(
- "PUT",
- send_msg_path,
- msg_content
- )
- self.assertEquals(403, code, msg=str(response))
+ self.helper.leave(room=self.created_rmid, user=self.user_id)
+ request, channel = make_request(b"PUT", send_msg_path(), msg_content)
+ render(request, self.resource, self.clock)
+ self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"])
- @defer.inlineCallbacks
def test_topic_perms(self):
- topic_content = '{"topic":"My Topic Name"}'
- topic_path = "/rooms/%s/state/m.room.topic" % self.created_rmid
+ topic_content = b'{"topic":"My Topic Name"}'
+ topic_path = b"/rooms/%s/state/m.room.topic" % self.created_rmid
# set/get topic in uncreated room, expect 403
- (code, response) = yield self.mock_resource.trigger(
- "PUT", "/rooms/%s/state/m.room.topic" % self.uncreated_rmid,
- topic_content
+ request, channel = make_request(
+ b"PUT", b"/rooms/%s/state/m.room.topic" % self.uncreated_rmid, topic_content
)
- self.assertEquals(403, code, msg=str(response))
- (code, response) = yield self.mock_resource.trigger_get(
- "/rooms/%s/state/m.room.topic" % self.uncreated_rmid
+ render(request, self.resource, self.clock)
+ self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"])
+ request, channel = make_request(
+ b"GET", "/rooms/%s/state/m.room.topic" % self.uncreated_rmid
)
- self.assertEquals(403, code, msg=str(response))
+ render(request, self.resource, self.clock)
+ self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"])
# set/get topic in created PRIVATE room not joined, expect 403
- (code, response) = yield self.mock_resource.trigger(
- "PUT", topic_path, topic_content
- )
- self.assertEquals(403, code, msg=str(response))
- (code, response) = yield self.mock_resource.trigger_get(topic_path)
- self.assertEquals(403, code, msg=str(response))
+ request, channel = make_request(b"PUT", topic_path, topic_content)
+ render(request, self.resource, self.clock)
+ self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"])
+ request, channel = make_request(b"GET", topic_path)
+ render(request, self.resource, self.clock)
+ self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"])
# set topic in created PRIVATE room and invited, expect 403
- yield self.invite(
+ self.helper.invite(
room=self.created_rmid, src=self.rmcreator_id, targ=self.user_id
)
- (code, response) = yield self.mock_resource.trigger(
- "PUT", topic_path, topic_content
- )
- self.assertEquals(403, code, msg=str(response))
+ request, channel = make_request(b"PUT", topic_path, topic_content)
+ render(request, self.resource, self.clock)
+ self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"])
# get topic in created PRIVATE room and invited, expect 403
- (code, response) = yield self.mock_resource.trigger_get(topic_path)
- self.assertEquals(403, code, msg=str(response))
+ request, channel = make_request(b"GET", topic_path)
+ render(request, self.resource, self.clock)
+ self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"])
# set/get topic in created PRIVATE room and joined, expect 200
- yield self.join(room=self.created_rmid, user=self.user_id)
+ self.helper.join(room=self.created_rmid, user=self.user_id)
# Only room ops can set topic by default
- self.auth_user_id = self.rmcreator_id
- (code, response) = yield self.mock_resource.trigger(
- "PUT", topic_path, topic_content
- )
- self.assertEquals(200, code, msg=str(response))
- self.auth_user_id = self.user_id
+ self.helper.auth_user_id = self.rmcreator_id
+ request, channel = make_request(b"PUT", topic_path, topic_content)
+ render(request, self.resource, self.clock)
+ self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.helper.auth_user_id = self.user_id
- (code, response) = yield self.mock_resource.trigger_get(topic_path)
- self.assertEquals(200, code, msg=str(response))
- self.assert_dict(json.loads(topic_content), response)
+ request, channel = make_request(b"GET", topic_path)
+ render(request, self.resource, self.clock)
+ self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assert_dict(json.loads(topic_content), channel.json_body)
# set/get topic in created PRIVATE room and left, expect 403
- yield self.leave(room=self.created_rmid, user=self.user_id)
- (code, response) = yield self.mock_resource.trigger(
- "PUT", topic_path, topic_content
- )
- self.assertEquals(403, code, msg=str(response))
- (code, response) = yield self.mock_resource.trigger_get(topic_path)
- self.assertEquals(200, code, msg=str(response))
+ self.helper.leave(room=self.created_rmid, user=self.user_id)
+ request, channel = make_request(b"PUT", topic_path, topic_content)
+ render(request, self.resource, self.clock)
+ self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"])
+ request, channel = make_request(b"GET", topic_path)
+ render(request, self.resource, self.clock)
+ self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
# get topic in PUBLIC room, not joined, expect 403
- (code, response) = yield self.mock_resource.trigger_get(
- "/rooms/%s/state/m.room.topic" % self.created_public_rmid
+ request, channel = make_request(
+ b"GET", b"/rooms/%s/state/m.room.topic" % self.created_public_rmid
)
- self.assertEquals(403, code, msg=str(response))
+ render(request, self.resource, self.clock)
+ self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"])
# set topic in PUBLIC room, not joined, expect 403
- (code, response) = yield self.mock_resource.trigger(
- "PUT",
- "/rooms/%s/state/m.room.topic" % self.created_public_rmid,
- topic_content
+ request, channel = make_request(
+ b"PUT",
+ b"/rooms/%s/state/m.room.topic" % self.created_public_rmid,
+ topic_content,
)
- self.assertEquals(403, code, msg=str(response))
+ render(request, self.resource, self.clock)
+ self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"])
- @defer.inlineCallbacks
def _test_get_membership(self, room=None, members=[], expect_code=None):
for member in members:
- path = "/rooms/%s/state/m.room.member/%s" % (room, member)
- (code, response) = yield self.mock_resource.trigger_get(path)
- self.assertEquals(expect_code, code)
+ path = b"/rooms/%s/state/m.room.member/%s" % (room, member)
+ request, channel = make_request(b"GET", path)
+ render(request, self.resource, self.clock)
+ self.assertEquals(expect_code, int(channel.result["code"]))
- @defer.inlineCallbacks
def test_membership_basic_room_perms(self):
# === room does not exist ===
room = self.uncreated_rmid
# get membership of self, get membership of other, uncreated room
# expect all 403s
- yield self._test_get_membership(
- members=[self.user_id, self.rmcreator_id],
- room=room, expect_code=403)
+ self._test_get_membership(
+ members=[self.user_id, self.rmcreator_id], room=room, expect_code=403
+ )
# trying to invite people to this room should 403
- yield self.invite(room=room, src=self.user_id, targ=self.rmcreator_id,
- expect_code=403)
+ self.helper.invite(
+ room=room, src=self.user_id, targ=self.rmcreator_id, expect_code=403
+ )
# set [invite/join/left] of self, set [invite/join/left] of other,
# expect all 404s because room doesn't exist on any server
for usr in [self.user_id, self.rmcreator_id]:
- yield self.join(room=room, user=usr, expect_code=404)
- yield self.leave(room=room, user=usr, expect_code=404)
+ self.helper.join(room=room, user=usr, expect_code=404)
+ self.helper.leave(room=room, user=usr, expect_code=404)
- @defer.inlineCallbacks
def test_membership_private_room_perms(self):
room = self.created_rmid
# get membership of self, get membership of other, private room + invite
# expect all 403s
- yield self.invite(room=room, src=self.rmcreator_id,
- targ=self.user_id)
- yield self._test_get_membership(
- members=[self.user_id, self.rmcreator_id],
- room=room, expect_code=403)
+ self.helper.invite(room=room, src=self.rmcreator_id, targ=self.user_id)
+ self._test_get_membership(
+ members=[self.user_id, self.rmcreator_id], room=room, expect_code=403
+ )
# get membership of self, get membership of other, private room + joined
# expect all 200s
- yield self.join(room=room, user=self.user_id)
- yield self._test_get_membership(
- members=[self.user_id, self.rmcreator_id],
- room=room, expect_code=200)
+ self.helper.join(room=room, user=self.user_id)
+ self._test_get_membership(
+ members=[self.user_id, self.rmcreator_id], room=room, expect_code=200
+ )
# get membership of self, get membership of other, private room + left
# expect all 200s
- yield self.leave(room=room, user=self.user_id)
- yield self._test_get_membership(
- members=[self.user_id, self.rmcreator_id],
- room=room, expect_code=200)
+ self.helper.leave(room=room, user=self.user_id)
+ self._test_get_membership(
+ members=[self.user_id, self.rmcreator_id], room=room, expect_code=200
+ )
- @defer.inlineCallbacks
def test_membership_public_room_perms(self):
room = self.created_public_rmid
# get membership of self, get membership of other, public room + invite
# expect 403
- yield self.invite(room=room, src=self.rmcreator_id,
- targ=self.user_id)
- yield self._test_get_membership(
- members=[self.user_id, self.rmcreator_id],
- room=room, expect_code=403)
+ self.helper.invite(room=room, src=self.rmcreator_id, targ=self.user_id)
+ self._test_get_membership(
+ members=[self.user_id, self.rmcreator_id], room=room, expect_code=403
+ )
# get membership of self, get membership of other, public room + joined
# expect all 200s
- yield self.join(room=room, user=self.user_id)
- yield self._test_get_membership(
- members=[self.user_id, self.rmcreator_id],
- room=room, expect_code=200)
+ self.helper.join(room=room, user=self.user_id)
+ self._test_get_membership(
+ members=[self.user_id, self.rmcreator_id], room=room, expect_code=200
+ )
# get membership of self, get membership of other, public room + left
# expect 200.
- yield self.leave(room=room, user=self.user_id)
- yield self._test_get_membership(
- members=[self.user_id, self.rmcreator_id],
- room=room, expect_code=200)
+ self.helper.leave(room=room, user=self.user_id)
+ self._test_get_membership(
+ members=[self.user_id, self.rmcreator_id], room=room, expect_code=200
+ )
- @defer.inlineCallbacks
def test_invited_permissions(self):
room = self.created_rmid
- yield self.invite(room=room, src=self.rmcreator_id, targ=self.user_id)
+ self.helper.invite(room=room, src=self.rmcreator_id, targ=self.user_id)
# set [invite/join/left] of other user, expect 403s
- yield self.invite(room=room, src=self.user_id, targ=self.rmcreator_id,
- expect_code=403)
- yield self.change_membership(room=room, src=self.user_id,
- targ=self.rmcreator_id,
- membership=Membership.JOIN,
- expect_code=403)
- yield self.change_membership(room=room, src=self.user_id,
- targ=self.rmcreator_id,
- membership=Membership.LEAVE,
- expect_code=403)
-
- @defer.inlineCallbacks
+ self.helper.invite(
+ room=room, src=self.user_id, targ=self.rmcreator_id, expect_code=403
+ )
+ self.helper.change_membership(
+ room=room,
+ src=self.user_id,
+ targ=self.rmcreator_id,
+ membership=Membership.JOIN,
+ expect_code=403,
+ )
+ self.helper.change_membership(
+ room=room,
+ src=self.user_id,
+ targ=self.rmcreator_id,
+ membership=Membership.LEAVE,
+ expect_code=403,
+ )
+
def test_joined_permissions(self):
room = self.created_rmid
- yield self.invite(room=room, src=self.rmcreator_id, targ=self.user_id)
- yield self.join(room=room, user=self.user_id)
+ self.helper.invite(room=room, src=self.rmcreator_id, targ=self.user_id)
+ self.helper.join(room=room, user=self.user_id)
# set invited of self, expect 403
- yield self.invite(room=room, src=self.user_id, targ=self.user_id,
- expect_code=403)
+ self.helper.invite(
+ room=room, src=self.user_id, targ=self.user_id, expect_code=403
+ )
# set joined of self, expect 200 (NOOP)
- yield self.join(room=room, user=self.user_id)
+ self.helper.join(room=room, user=self.user_id)
other = "@burgundy:red"
# set invited of other, expect 200
- yield self.invite(room=room, src=self.user_id, targ=other,
- expect_code=200)
+ self.helper.invite(room=room, src=self.user_id, targ=other, expect_code=200)
# set joined of other, expect 403
- yield self.change_membership(room=room, src=self.user_id,
- targ=other,
- membership=Membership.JOIN,
- expect_code=403)
+ self.helper.change_membership(
+ room=room,
+ src=self.user_id,
+ targ=other,
+ membership=Membership.JOIN,
+ expect_code=403,
+ )
# set left of other, expect 403
- yield self.change_membership(room=room, src=self.user_id,
- targ=other,
- membership=Membership.LEAVE,
- expect_code=403)
+ self.helper.change_membership(
+ room=room,
+ src=self.user_id,
+ targ=other,
+ membership=Membership.LEAVE,
+ expect_code=403,
+ )
# set left of self, expect 200
- yield self.leave(room=room, user=self.user_id)
+ self.helper.leave(room=room, user=self.user_id)
- @defer.inlineCallbacks
def test_leave_permissions(self):
room = self.created_rmid
- yield self.invite(room=room, src=self.rmcreator_id, targ=self.user_id)
- yield self.join(room=room, user=self.user_id)
- yield self.leave(room=room, user=self.user_id)
+ self.helper.invite(room=room, src=self.rmcreator_id, targ=self.user_id)
+ self.helper.join(room=room, user=self.user_id)
+ self.helper.leave(room=room, user=self.user_id)
# set [invite/join/left] of self, set [invite/join/left] of other,
# expect all 403s
for usr in [self.user_id, self.rmcreator_id]:
- yield self.change_membership(
+ self.helper.change_membership(
room=room,
src=self.user_id,
targ=usr,
membership=Membership.INVITE,
- expect_code=403
+ expect_code=403,
)
- yield self.change_membership(
+ self.helper.change_membership(
room=room,
src=self.user_id,
targ=usr,
membership=Membership.JOIN,
- expect_code=403
+ expect_code=403,
)
# It is always valid to LEAVE if you've already left (currently.)
- yield self.change_membership(
+ self.helper.change_membership(
room=room,
src=self.user_id,
targ=self.rmcreator_id,
membership=Membership.LEAVE,
- expect_code=403
+ expect_code=403,
)
-class RoomsMemberListTestCase(RestTestCase):
+class RoomsMemberListTestCase(RoomBase):
""" Tests /rooms/$room_id/members/list REST events."""
- user_id = "@sid1:red"
-
- @defer.inlineCallbacks
- def setUp(self):
- self.mock_resource = MockHttpResource(prefix=PATH_PREFIX)
-
- hs = yield setup_test_homeserver(
- "red",
- http_client=None,
- federation_client=Mock(),
- ratelimiter=NonCallableMock(spec_set=["send_message"]),
- )
- self.ratelimiter = hs.get_ratelimiter()
- self.ratelimiter.send_message.return_value = (True, 0)
-
- hs.get_handlers().federation_handler = Mock()
-
- self.auth_user_id = self.user_id
-
- def get_user_by_access_token(token=None, allow_guest=False):
- return {
- "user": UserID.from_string(self.auth_user_id),
- "token_id": 1,
- "is_guest": False,
- }
- hs.get_auth().get_user_by_access_token = get_user_by_access_token
-
- def _insert_client_ip(*args, **kwargs):
- return defer.succeed(None)
- hs.get_datastore().insert_client_ip = _insert_client_ip
- synapse.rest.client.v1.room.register_servlets(hs, self.mock_resource)
+ user_id = b"@sid1:red"
- def tearDown(self):
- pass
-
- @defer.inlineCallbacks
def test_get_member_list(self):
- room_id = yield self.create_room_as(self.user_id)
- (code, response) = yield self.mock_resource.trigger_get(
- "/rooms/%s/members" % room_id
- )
- self.assertEquals(200, code, msg=str(response))
+ room_id = self.helper.create_room_as(self.user_id)
+ request, channel = make_request(b"GET", b"/rooms/%s/members" % room_id)
+ render(request, self.resource, self.clock)
+ self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
- @defer.inlineCallbacks
def test_get_member_list_no_room(self):
- (code, response) = yield self.mock_resource.trigger_get(
- "/rooms/roomdoesnotexist/members"
- )
- self.assertEquals(403, code, msg=str(response))
+ request, channel = make_request(b"GET", b"/rooms/roomdoesnotexist/members")
+ render(request, self.resource, self.clock)
+ self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"])
- @defer.inlineCallbacks
def test_get_member_list_no_permission(self):
- room_id = yield self.create_room_as("@some_other_guy:red")
- (code, response) = yield self.mock_resource.trigger_get(
- "/rooms/%s/members" % room_id
- )
- self.assertEquals(403, code, msg=str(response))
+ room_id = self.helper.create_room_as(b"@some_other_guy:red")
+ request, channel = make_request(b"GET", b"/rooms/%s/members" % room_id)
+ render(request, self.resource, self.clock)
+ self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"])
- @defer.inlineCallbacks
def test_get_member_list_mixed_memberships(self):
- room_creator = "@some_other_guy:red"
- room_id = yield self.create_room_as(room_creator)
- room_path = "/rooms/%s/members" % room_id
- yield self.invite(room=room_id, src=room_creator,
- targ=self.user_id)
+ room_creator = b"@some_other_guy:red"
+ room_id = self.helper.create_room_as(room_creator)
+ room_path = b"/rooms/%s/members" % room_id
+ self.helper.invite(room=room_id, src=room_creator, targ=self.user_id)
# can't see list if you're just invited.
- (code, response) = yield self.mock_resource.trigger_get(room_path)
- self.assertEquals(403, code, msg=str(response))
+ request, channel = make_request(b"GET", room_path)
+ render(request, self.resource, self.clock)
+ self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"])
- yield self.join(room=room_id, user=self.user_id)
+ self.helper.join(room=room_id, user=self.user_id)
# can see list now joined
- (code, response) = yield self.mock_resource.trigger_get(room_path)
- self.assertEquals(200, code, msg=str(response))
+ request, channel = make_request(b"GET", room_path)
+ render(request, self.resource, self.clock)
+ self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
- yield self.leave(room=room_id, user=self.user_id)
+ self.helper.leave(room=room_id, user=self.user_id)
# can see old list once left
- (code, response) = yield self.mock_resource.trigger_get(room_path)
- self.assertEquals(200, code, msg=str(response))
+ request, channel = make_request(b"GET", room_path)
+ render(request, self.resource, self.clock)
+ self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
-class RoomsCreateTestCase(RestTestCase):
+class RoomsCreateTestCase(RoomBase):
""" Tests /rooms and /rooms/$room_id REST events. """
- user_id = "@sid1:red"
-
- @defer.inlineCallbacks
- def setUp(self):
- self.mock_resource = MockHttpResource(prefix=PATH_PREFIX)
- self.auth_user_id = self.user_id
- hs = yield setup_test_homeserver(
- "red",
- http_client=None,
- federation_client=Mock(),
- ratelimiter=NonCallableMock(spec_set=["send_message"]),
- )
- self.ratelimiter = hs.get_ratelimiter()
- self.ratelimiter.send_message.return_value = (True, 0)
-
- hs.get_handlers().federation_handler = Mock()
-
- def get_user_by_access_token(token=None, allow_guest=False):
- return {
- "user": UserID.from_string(self.auth_user_id),
- "token_id": 1,
- "is_guest": False,
- }
- hs.get_auth().get_user_by_access_token = get_user_by_access_token
-
- def _insert_client_ip(*args, **kwargs):
- return defer.succeed(None)
- hs.get_datastore().insert_client_ip = _insert_client_ip
-
- synapse.rest.client.v1.room.register_servlets(hs, self.mock_resource)
+ user_id = b"@sid1:red"
- @defer.inlineCallbacks
def test_post_room_no_keys(self):
# POST with no config keys, expect new room id
- (code, response) = yield self.mock_resource.trigger("POST",
- "/createRoom",
- "{}")
- self.assertEquals(200, code, response)
- self.assertTrue("room_id" in response)
+ request, channel = make_request(b"POST", b"/createRoom", b"{}")
+
+ render(request, self.resource, self.clock)
+ self.assertEquals(200, int(channel.result["code"]), channel.result)
+ self.assertTrue("room_id" in channel.json_body)
- @defer.inlineCallbacks
def test_post_room_visibility_key(self):
# POST with visibility config key, expect new room id
- (code, response) = yield self.mock_resource.trigger(
- "POST",
- "/createRoom",
- '{"visibility":"private"}')
- self.assertEquals(200, code)
- self.assertTrue("room_id" in response)
-
- @defer.inlineCallbacks
+ request, channel = make_request(
+ b"POST", b"/createRoom", b'{"visibility":"private"}'
+ )
+ render(request, self.resource, self.clock)
+ self.assertEquals(200, int(channel.result["code"]))
+ self.assertTrue("room_id" in channel.json_body)
+
def test_post_room_custom_key(self):
# POST with custom config keys, expect new room id
- (code, response) = yield self.mock_resource.trigger(
- "POST",
- "/createRoom",
- '{"custom":"stuff"}')
- self.assertEquals(200, code)
- self.assertTrue("room_id" in response)
-
- @defer.inlineCallbacks
+ request, channel = make_request(b"POST", b"/createRoom", b'{"custom":"stuff"}')
+ render(request, self.resource, self.clock)
+ self.assertEquals(200, int(channel.result["code"]))
+ self.assertTrue("room_id" in channel.json_body)
+
def test_post_room_known_and_unknown_keys(self):
# POST with custom + known config keys, expect new room id
- (code, response) = yield self.mock_resource.trigger(
- "POST",
- "/createRoom",
- '{"visibility":"private","custom":"things"}')
- self.assertEquals(200, code)
- self.assertTrue("room_id" in response)
-
- @defer.inlineCallbacks
+ request, channel = make_request(
+ b"POST", b"/createRoom", b'{"visibility":"private","custom":"things"}'
+ )
+ render(request, self.resource, self.clock)
+ self.assertEquals(200, int(channel.result["code"]))
+ self.assertTrue("room_id" in channel.json_body)
+
def test_post_room_invalid_content(self):
# POST with invalid content / paths, expect 400
- (code, response) = yield self.mock_resource.trigger(
- "POST",
- "/createRoom",
- '{"visibili')
- self.assertEquals(400, code)
+ request, channel = make_request(b"POST", b"/createRoom", b'{"visibili')
+ render(request, self.resource, self.clock)
+ self.assertEquals(400, int(channel.result["code"]))
- (code, response) = yield self.mock_resource.trigger(
- "POST",
- "/createRoom",
- '["hello"]')
- self.assertEquals(400, code)
+ request, channel = make_request(b"POST", b"/createRoom", b'["hello"]')
+ render(request, self.resource, self.clock)
+ self.assertEquals(400, int(channel.result["code"]))
-class RoomTopicTestCase(RestTestCase):
+class RoomTopicTestCase(RoomBase):
""" Tests /rooms/$room_id/topic REST events. """
- user_id = "@sid1:red"
- @defer.inlineCallbacks
- def setUp(self):
- self.mock_resource = MockHttpResource(prefix=PATH_PREFIX)
- self.auth_user_id = self.user_id
+ user_id = b"@sid1:red"
- hs = yield setup_test_homeserver(
- "red",
- http_client=None,
- federation_client=Mock(),
- ratelimiter=NonCallableMock(spec_set=["send_message"]),
- )
- self.ratelimiter = hs.get_ratelimiter()
- self.ratelimiter.send_message.return_value = (True, 0)
-
- hs.get_handlers().federation_handler = Mock()
-
- def get_user_by_access_token(token=None, allow_guest=False):
- return {
- "user": UserID.from_string(self.auth_user_id),
- "token_id": 1,
- "is_guest": False,
- }
-
- hs.get_auth().get_user_by_access_token = get_user_by_access_token
-
- def _insert_client_ip(*args, **kwargs):
- return defer.succeed(None)
- hs.get_datastore().insert_client_ip = _insert_client_ip
+ def setUp(self):
- synapse.rest.client.v1.room.register_servlets(hs, self.mock_resource)
+ super(RoomTopicTestCase, self).setUp()
# create the room
- self.room_id = yield self.create_room_as(self.user_id)
- self.path = "/rooms/%s/state/m.room.topic" % (self.room_id,)
-
- def tearDown(self):
- pass
+ self.room_id = self.helper.create_room_as(self.user_id)
+ self.path = b"/rooms/%s/state/m.room.topic" % (self.room_id,)
- @defer.inlineCallbacks
def test_invalid_puts(self):
# missing keys or invalid json
- (code, response) = yield self.mock_resource.trigger(
- "PUT", self.path, '{}'
- )
- self.assertEquals(400, code, msg=str(response))
+ request, channel = make_request(b"PUT", self.path, '{}')
+ render(request, self.resource, self.clock)
+ self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"])
- (code, response) = yield self.mock_resource.trigger(
- "PUT", self.path, '{"_name":"bob"}'
- )
- self.assertEquals(400, code, msg=str(response))
+ request, channel = make_request(b"PUT", self.path, '{"_name":"bob"}')
+ render(request, self.resource, self.clock)
+ self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"])
- (code, response) = yield self.mock_resource.trigger(
- "PUT", self.path, '{"nao'
- )
- self.assertEquals(400, code, msg=str(response))
+ request, channel = make_request(b"PUT", self.path, '{"nao')
+ render(request, self.resource, self.clock)
+ self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"])
- (code, response) = yield self.mock_resource.trigger(
- "PUT", self.path, '[{"_name":"bob"},{"_name":"jill"}]'
+ request, channel = make_request(
+ b"PUT", self.path, '[{"_name":"bob"},{"_name":"jill"}]'
)
- self.assertEquals(400, code, msg=str(response))
+ render(request, self.resource, self.clock)
+ self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"])
- (code, response) = yield self.mock_resource.trigger(
- "PUT", self.path, 'text only'
- )
- self.assertEquals(400, code, msg=str(response))
+ request, channel = make_request(b"PUT", self.path, 'text only')
+ render(request, self.resource, self.clock)
+ self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"])
- (code, response) = yield self.mock_resource.trigger(
- "PUT", self.path, ''
- )
- self.assertEquals(400, code, msg=str(response))
+ request, channel = make_request(b"PUT", self.path, '')
+ render(request, self.resource, self.clock)
+ self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"])
# valid key, wrong type
content = '{"topic":["Topic name"]}'
- (code, response) = yield self.mock_resource.trigger(
- "PUT", self.path, content
- )
- self.assertEquals(400, code, msg=str(response))
+ request, channel = make_request(b"PUT", self.path, content)
+ render(request, self.resource, self.clock)
+ self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"])
- @defer.inlineCallbacks
def test_rooms_topic(self):
# nothing should be there
- (code, response) = yield self.mock_resource.trigger_get(self.path)
- self.assertEquals(404, code, msg=str(response))
+ request, channel = make_request(b"GET", self.path)
+ render(request, self.resource, self.clock)
+ self.assertEquals(404, int(channel.result["code"]), msg=channel.result["body"])
# valid put
content = '{"topic":"Topic name"}'
- (code, response) = yield self.mock_resource.trigger(
- "PUT", self.path, content
- )
- self.assertEquals(200, code, msg=str(response))
+ request, channel = make_request(b"PUT", self.path, content)
+ render(request, self.resource, self.clock)
+ self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
# valid get
- (code, response) = yield self.mock_resource.trigger_get(self.path)
- self.assertEquals(200, code, msg=str(response))
- self.assert_dict(json.loads(content), response)
+ request, channel = make_request(b"GET", self.path)
+ render(request, self.resource, self.clock)
+ self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assert_dict(json.loads(content), channel.json_body)
- @defer.inlineCallbacks
def test_rooms_topic_with_extra_keys(self):
# valid put with extra keys
content = '{"topic":"Seasons","subtopic":"Summer"}'
- (code, response) = yield self.mock_resource.trigger(
- "PUT", self.path, content
- )
- self.assertEquals(200, code, msg=str(response))
+ request, channel = make_request(b"PUT", self.path, content)
+ render(request, self.resource, self.clock)
+ self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
# valid get
- (code, response) = yield self.mock_resource.trigger_get(self.path)
- self.assertEquals(200, code, msg=str(response))
- self.assert_dict(json.loads(content), response)
+ request, channel = make_request(b"GET", self.path)
+ render(request, self.resource, self.clock)
+ self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assert_dict(json.loads(content), channel.json_body)
-class RoomMemberStateTestCase(RestTestCase):
+class RoomMemberStateTestCase(RoomBase):
""" Tests /rooms/$room_id/members/$user_id/state REST events. """
- user_id = "@sid1:red"
-
- @defer.inlineCallbacks
- def setUp(self):
- self.mock_resource = MockHttpResource(prefix=PATH_PREFIX)
- self.auth_user_id = self.user_id
-
- hs = yield setup_test_homeserver(
- "red",
- http_client=None,
- federation_client=Mock(),
- ratelimiter=NonCallableMock(spec_set=["send_message"]),
- )
- self.ratelimiter = hs.get_ratelimiter()
- self.ratelimiter.send_message.return_value = (True, 0)
- hs.get_handlers().federation_handler = Mock()
+ user_id = b"@sid1:red"
- def get_user_by_access_token(token=None, allow_guest=False):
- return {
- "user": UserID.from_string(self.auth_user_id),
- "token_id": 1,
- "is_guest": False,
- }
- hs.get_auth().get_user_by_access_token = get_user_by_access_token
-
- def _insert_client_ip(*args, **kwargs):
- return defer.succeed(None)
- hs.get_datastore().insert_client_ip = _insert_client_ip
-
- synapse.rest.client.v1.room.register_servlets(hs, self.mock_resource)
+ def setUp(self):
- self.room_id = yield self.create_room_as(self.user_id)
+ super(RoomMemberStateTestCase, self).setUp()
+ self.room_id = self.helper.create_room_as(self.user_id)
def tearDown(self):
pass
- @defer.inlineCallbacks
def test_invalid_puts(self):
path = "/rooms/%s/state/m.room.member/%s" % (self.room_id, self.user_id)
# missing keys or invalid json
- (code, response) = yield self.mock_resource.trigger("PUT", path, '{}')
- self.assertEquals(400, code, msg=str(response))
+ request, channel = make_request(b"PUT", path, '{}')
+ render(request, self.resource, self.clock)
+ self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"])
- (code, response) = yield self.mock_resource.trigger(
- "PUT", path, '{"_name":"bob"}'
- )
- self.assertEquals(400, code, msg=str(response))
+ request, channel = make_request(b"PUT", path, '{"_name":"bob"}')
+ render(request, self.resource, self.clock)
+ self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"])
- (code, response) = yield self.mock_resource.trigger(
- "PUT", path, '{"nao'
- )
- self.assertEquals(400, code, msg=str(response))
+ request, channel = make_request(b"PUT", path, '{"nao')
+ render(request, self.resource, self.clock)
+ self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"])
- (code, response) = yield self.mock_resource.trigger(
- "PUT", path, '[{"_name":"bob"},{"_name":"jill"}]'
+ request, channel = make_request(
+ b"PUT", path, b'[{"_name":"bob"},{"_name":"jill"}]'
)
- self.assertEquals(400, code, msg=str(response))
+ render(request, self.resource, self.clock)
+ self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"])
- (code, response) = yield self.mock_resource.trigger(
- "PUT", path, 'text only'
- )
- self.assertEquals(400, code, msg=str(response))
+ request, channel = make_request(b"PUT", path, 'text only')
+ render(request, self.resource, self.clock)
+ self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"])
- (code, response) = yield self.mock_resource.trigger(
- "PUT", path, ''
- )
- self.assertEquals(400, code, msg=str(response))
+ request, channel = make_request(b"PUT", path, '')
+ render(request, self.resource, self.clock)
+ self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"])
# valid keys, wrong types
- content = ('{"membership":["%s","%s","%s"]}' % (
- Membership.INVITE, Membership.JOIN, Membership.LEAVE
- ))
- (code, response) = yield self.mock_resource.trigger("PUT", path, content)
- self.assertEquals(400, code, msg=str(response))
+ content = '{"membership":["%s","%s","%s"]}' % (
+ Membership.INVITE,
+ Membership.JOIN,
+ Membership.LEAVE,
+ )
+ request, channel = make_request(b"PUT", path, content.encode('ascii'))
+ render(request, self.resource, self.clock)
+ self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"])
- @defer.inlineCallbacks
def test_rooms_members_self(self):
path = "/rooms/%s/state/m.room.member/%s" % (
- urlparse.quote(self.room_id), self.user_id
+ urlparse.quote(self.room_id),
+ self.user_id,
)
# valid join message (NOOP since we made the room)
content = '{"membership":"%s"}' % Membership.JOIN
- (code, response) = yield self.mock_resource.trigger("PUT", path, content)
- self.assertEquals(200, code, msg=str(response))
+ request, channel = make_request(b"PUT", path, content.encode('ascii'))
+ render(request, self.resource, self.clock)
+ self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
- (code, response) = yield self.mock_resource.trigger("GET", path, None)
- self.assertEquals(200, code, msg=str(response))
+ request, channel = make_request(b"GET", path, None)
+ render(request, self.resource, self.clock)
+ self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
- expected_response = {
- "membership": Membership.JOIN,
- }
- self.assertEquals(expected_response, response)
+ expected_response = {"membership": Membership.JOIN}
+ self.assertEquals(expected_response, channel.json_body)
- @defer.inlineCallbacks
def test_rooms_members_other(self):
self.other_id = "@zzsid1:red"
path = "/rooms/%s/state/m.room.member/%s" % (
- urlparse.quote(self.room_id), self.other_id
+ urlparse.quote(self.room_id),
+ self.other_id,
)
# valid invite message
content = '{"membership":"%s"}' % Membership.INVITE
- (code, response) = yield self.mock_resource.trigger("PUT", path, content)
- self.assertEquals(200, code, msg=str(response))
+ request, channel = make_request(b"PUT", path, content)
+ render(request, self.resource, self.clock)
+ self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
- (code, response) = yield self.mock_resource.trigger("GET", path, None)
- self.assertEquals(200, code, msg=str(response))
- self.assertEquals(json.loads(content), response)
+ request, channel = make_request(b"GET", path, None)
+ render(request, self.resource, self.clock)
+ self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEquals(json.loads(content), channel.json_body)
- @defer.inlineCallbacks
def test_rooms_members_other_custom_keys(self):
self.other_id = "@zzsid1:red"
path = "/rooms/%s/state/m.room.member/%s" % (
- urlparse.quote(self.room_id), self.other_id
+ urlparse.quote(self.room_id),
+ self.other_id,
)
# valid invite message with custom key
- content = ('{"membership":"%s","invite_text":"%s"}' % (
- Membership.INVITE, "Join us!"
- ))
- (code, response) = yield self.mock_resource.trigger("PUT", path, content)
- self.assertEquals(200, code, msg=str(response))
+ content = '{"membership":"%s","invite_text":"%s"}' % (
+ Membership.INVITE,
+ "Join us!",
+ )
+ request, channel = make_request(b"PUT", path, content)
+ render(request, self.resource, self.clock)
+ self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
- (code, response) = yield self.mock_resource.trigger("GET", path, None)
- self.assertEquals(200, code, msg=str(response))
- self.assertEquals(json.loads(content), response)
+ request, channel = make_request(b"GET", path, None)
+ render(request, self.resource, self.clock)
+ self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEquals(json.loads(content), channel.json_body)
-class RoomMessagesTestCase(RestTestCase):
+class RoomMessagesTestCase(RoomBase):
""" Tests /rooms/$room_id/messages/$user_id/$msg_id REST events. """
+
user_id = "@sid1:red"
- @defer.inlineCallbacks
def setUp(self):
- self.mock_resource = MockHttpResource(prefix=PATH_PREFIX)
- self.auth_user_id = self.user_id
-
- hs = yield setup_test_homeserver(
- "red",
- http_client=None,
- federation_client=Mock(),
- ratelimiter=NonCallableMock(spec_set=["send_message"]),
- )
- self.ratelimiter = hs.get_ratelimiter()
- self.ratelimiter.send_message.return_value = (True, 0)
-
- hs.get_handlers().federation_handler = Mock()
+ super(RoomMessagesTestCase, self).setUp()
- def get_user_by_access_token(token=None, allow_guest=False):
- return {
- "user": UserID.from_string(self.auth_user_id),
- "token_id": 1,
- "is_guest": False,
- }
- hs.get_auth().get_user_by_access_token = get_user_by_access_token
-
- def _insert_client_ip(*args, **kwargs):
- return defer.succeed(None)
- hs.get_datastore().insert_client_ip = _insert_client_ip
+ self.room_id = self.helper.create_room_as(self.user_id)
- synapse.rest.client.v1.room.register_servlets(hs, self.mock_resource)
-
- self.room_id = yield self.create_room_as(self.user_id)
-
- def tearDown(self):
- pass
-
- @defer.inlineCallbacks
def test_invalid_puts(self):
- path = "/rooms/%s/send/m.room.message/mid1" % (
- urlparse.quote(self.room_id))
+ path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id))
# missing keys or invalid json
- (code, response) = yield self.mock_resource.trigger(
- "PUT", path, '{}'
- )
- self.assertEquals(400, code, msg=str(response))
+ request, channel = make_request(b"PUT", path, '{}')
+ render(request, self.resource, self.clock)
+ self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"])
- (code, response) = yield self.mock_resource.trigger(
- "PUT", path, '{"_name":"bob"}'
- )
- self.assertEquals(400, code, msg=str(response))
+ request, channel = make_request(b"PUT", path, '{"_name":"bob"}')
+ render(request, self.resource, self.clock)
+ self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"])
- (code, response) = yield self.mock_resource.trigger(
- "PUT", path, '{"nao'
- )
- self.assertEquals(400, code, msg=str(response))
+ request, channel = make_request(b"PUT", path, '{"nao')
+ render(request, self.resource, self.clock)
+ self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"])
- (code, response) = yield self.mock_resource.trigger(
- "PUT", path, '[{"_name":"bob"},{"_name":"jill"}]'
+ request, channel = make_request(
+ b"PUT", path, '[{"_name":"bob"},{"_name":"jill"}]'
)
- self.assertEquals(400, code, msg=str(response))
+ render(request, self.resource, self.clock)
+ self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"])
- (code, response) = yield self.mock_resource.trigger(
- "PUT", path, 'text only'
- )
- self.assertEquals(400, code, msg=str(response))
+ request, channel = make_request(b"PUT", path, 'text only')
+ render(request, self.resource, self.clock)
+ self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"])
- (code, response) = yield self.mock_resource.trigger(
- "PUT", path, ''
- )
- self.assertEquals(400, code, msg=str(response))
+ request, channel = make_request(b"PUT", path, '')
+ render(request, self.resource, self.clock)
+ self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"])
- @defer.inlineCallbacks
def test_rooms_messages_sent(self):
- path = "/rooms/%s/send/m.room.message/mid1" % (
- urlparse.quote(self.room_id))
+ path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id))
content = '{"body":"test","msgtype":{"type":"a"}}'
- (code, response) = yield self.mock_resource.trigger("PUT", path, content)
- self.assertEquals(400, code, msg=str(response))
+ request, channel = make_request(b"PUT", path, content)
+ render(request, self.resource, self.clock)
+ self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"])
# custom message types
content = '{"body":"test","msgtype":"test.custom.text"}'
- (code, response) = yield self.mock_resource.trigger("PUT", path, content)
- self.assertEquals(200, code, msg=str(response))
-
-# (code, response) = yield self.mock_resource.trigger("GET", path, None)
-# self.assertEquals(200, code, msg=str(response))
-# self.assert_dict(json.loads(content), response)
+ request, channel = make_request(b"PUT", path, content)
+ render(request, self.resource, self.clock)
+ self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
# m.text message type
- path = "/rooms/%s/send/m.room.message/mid2" % (
- urlparse.quote(self.room_id))
+ path = "/rooms/%s/send/m.room.message/mid2" % (urlparse.quote(self.room_id))
content = '{"body":"test2","msgtype":"m.text"}'
- (code, response) = yield self.mock_resource.trigger("PUT", path, content)
- self.assertEquals(200, code, msg=str(response))
+ request, channel = make_request(b"PUT", path, content)
+ render(request, self.resource, self.clock)
+ self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
-class RoomInitialSyncTestCase(RestTestCase):
+class RoomInitialSyncTestCase(RoomBase):
""" Tests /rooms/$room_id/initialSync. """
+
user_id = "@sid1:red"
- @defer.inlineCallbacks
def setUp(self):
- self.mock_resource = MockHttpResource(prefix=PATH_PREFIX)
- self.auth_user_id = self.user_id
-
- hs = yield setup_test_homeserver(
- "red",
- http_client=None,
- federation_client=Mock(),
- ratelimiter=NonCallableMock(spec_set=[
- "send_message",
- ]),
- )
- self.ratelimiter = hs.get_ratelimiter()
- self.ratelimiter.send_message.return_value = (True, 0)
-
- hs.get_handlers().federation_handler = Mock()
-
- def get_user_by_access_token(token=None, allow_guest=False):
- return {
- "user": UserID.from_string(self.auth_user_id),
- "token_id": 1,
- "is_guest": False,
- }
- hs.get_auth().get_user_by_access_token = get_user_by_access_token
-
- def _insert_client_ip(*args, **kwargs):
- return defer.succeed(None)
- hs.get_datastore().insert_client_ip = _insert_client_ip
-
- synapse.rest.client.v1.room.register_servlets(hs, self.mock_resource)
+ super(RoomInitialSyncTestCase, self).setUp()
# create the room
- self.room_id = yield self.create_room_as(self.user_id)
+ self.room_id = self.helper.create_room_as(self.user_id)
- @defer.inlineCallbacks
def test_initial_sync(self):
- (code, response) = yield self.mock_resource.trigger_get(
- "/rooms/%s/initialSync" % self.room_id
- )
- self.assertEquals(200, code)
+ request, channel = make_request(b"GET", "/rooms/%s/initialSync" % self.room_id)
+ render(request, self.resource, self.clock)
+ self.assertEquals(200, int(channel.result["code"]))
- self.assertEquals(self.room_id, response["room_id"])
- self.assertEquals("join", response["membership"])
+ self.assertEquals(self.room_id, channel.json_body["room_id"])
+ self.assertEquals("join", channel.json_body["membership"])
# Room state is easier to assert on if we unpack it into a dict
state = {}
- for event in response["state"]:
+ for event in channel.json_body["state"]:
if "state_key" not in event:
continue
t = event["type"]
@@ -977,75 +800,48 @@ class RoomInitialSyncTestCase(RestTestCase):
self.assertTrue("m.room.create" in state)
- self.assertTrue("messages" in response)
- self.assertTrue("chunk" in response["messages"])
- self.assertTrue("end" in response["messages"])
+ self.assertTrue("messages" in channel.json_body)
+ self.assertTrue("chunk" in channel.json_body["messages"])
+ self.assertTrue("end" in channel.json_body["messages"])
- self.assertTrue("presence" in response)
+ self.assertTrue("presence" in channel.json_body)
presence_by_user = {
- e["content"]["user_id"]: e for e in response["presence"]
+ e["content"]["user_id"]: e for e in channel.json_body["presence"]
}
self.assertTrue(self.user_id in presence_by_user)
self.assertEquals("m.presence", presence_by_user[self.user_id]["type"])
-class RoomMessageListTestCase(RestTestCase):
+class RoomMessageListTestCase(RoomBase):
""" Tests /rooms/$room_id/messages REST events. """
+
user_id = "@sid1:red"
- @defer.inlineCallbacks
def setUp(self):
- self.mock_resource = MockHttpResource(prefix=PATH_PREFIX)
- self.auth_user_id = self.user_id
-
- hs = yield setup_test_homeserver(
- "red",
- http_client=None,
- federation_client=Mock(),
- ratelimiter=NonCallableMock(spec_set=["send_message"]),
- )
- self.ratelimiter = hs.get_ratelimiter()
- self.ratelimiter.send_message.return_value = (True, 0)
-
- hs.get_handlers().federation_handler = Mock()
-
- def get_user_by_access_token(token=None, allow_guest=False):
- return {
- "user": UserID.from_string(self.auth_user_id),
- "token_id": 1,
- "is_guest": False,
- }
- hs.get_auth().get_user_by_access_token = get_user_by_access_token
+ super(RoomMessageListTestCase, self).setUp()
+ self.room_id = self.helper.create_room_as(self.user_id)
- def _insert_client_ip(*args, **kwargs):
- return defer.succeed(None)
- hs.get_datastore().insert_client_ip = _insert_client_ip
-
- synapse.rest.client.v1.room.register_servlets(hs, self.mock_resource)
-
- self.room_id = yield self.create_room_as(self.user_id)
-
- @defer.inlineCallbacks
def test_topo_token_is_accepted(self):
token = "t1-0_0_0_0_0_0_0_0_0"
- (code, response) = yield self.mock_resource.trigger_get(
- "/rooms/%s/messages?access_token=x&from=%s" %
- (self.room_id, token))
- self.assertEquals(200, code)
- self.assertTrue("start" in response)
- self.assertEquals(token, response['start'])
- self.assertTrue("chunk" in response)
- self.assertTrue("end" in response)
-
- @defer.inlineCallbacks
+ request, channel = make_request(
+ b"GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token)
+ )
+ render(request, self.resource, self.clock)
+ self.assertEquals(200, int(channel.result["code"]))
+ self.assertTrue("start" in channel.json_body)
+ self.assertEquals(token, channel.json_body['start'])
+ self.assertTrue("chunk" in channel.json_body)
+ self.assertTrue("end" in channel.json_body)
+
def test_stream_token_is_accepted_for_fwd_pagianation(self):
token = "s0_0_0_0_0_0_0_0_0"
- (code, response) = yield self.mock_resource.trigger_get(
- "/rooms/%s/messages?access_token=x&from=%s" %
- (self.room_id, token))
- self.assertEquals(200, code)
- self.assertTrue("start" in response)
- self.assertEquals(token, response['start'])
- self.assertTrue("chunk" in response)
- self.assertTrue("end" in response)
+ request, channel = make_request(
+ b"GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token)
+ )
+ render(request, self.resource, self.clock)
+ self.assertEquals(200, int(channel.result["code"]))
+ self.assertTrue("start" in channel.json_body)
+ self.assertEquals(token, channel.json_body['start'])
+ self.assertTrue("chunk" in channel.json_body)
+ self.assertTrue("end" in channel.json_body)
diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index 54d7ba380d..41de8e0762 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/v1/utils.py
@@ -16,13 +16,14 @@
import json
import time
-# twisted imports
+import attr
+
from twisted.internet import defer
from synapse.api.constants import Membership
-# trial imports
from tests import unittest
+from tests.server import make_request, wait_until_result
class RestTestCase(unittest.TestCase):
@@ -133,3 +134,113 @@ class RestTestCase(unittest.TestCase):
for key in required:
self.assertEquals(required[key], actual[key],
msg="%s mismatch. %s" % (key, actual))
+
+
+@attr.s
+class RestHelper(object):
+ """Contains extra helper functions to quickly and clearly perform a given
+ REST action, which isn't the focus of the test.
+ """
+
+ hs = attr.ib()
+ resource = attr.ib()
+ auth_user_id = attr.ib()
+
+ def create_room_as(self, room_creator, is_public=True, tok=None):
+ temp_id = self.auth_user_id
+ self.auth_user_id = room_creator
+ path = b"/_matrix/client/r0/createRoom"
+ content = {}
+ if not is_public:
+ content["visibility"] = "private"
+ if tok:
+ path = path + b"?access_token=%s" % tok.encode('ascii')
+
+ request, channel = make_request(b"POST", path, json.dumps(content).encode('utf8'))
+ request.render(self.resource)
+ wait_until_result(self.hs.get_reactor(), channel)
+
+ assert channel.result["code"] == b"200", channel.result
+ self.auth_user_id = temp_id
+ return channel.json_body["room_id"]
+
+ def invite(self, room=None, src=None, targ=None, expect_code=200, tok=None):
+ self.change_membership(
+ room=room,
+ src=src,
+ targ=targ,
+ tok=tok,
+ membership=Membership.INVITE,
+ expect_code=expect_code,
+ )
+
+ def join(self, room=None, user=None, expect_code=200, tok=None):
+ self.change_membership(
+ room=room,
+ src=user,
+ targ=user,
+ tok=tok,
+ membership=Membership.JOIN,
+ expect_code=expect_code,
+ )
+
+ def leave(self, room=None, user=None, expect_code=200, tok=None):
+ self.change_membership(
+ room=room,
+ src=user,
+ targ=user,
+ tok=tok,
+ membership=Membership.LEAVE,
+ expect_code=expect_code,
+ )
+
+ def change_membership(self, room, src, targ, membership, tok=None, expect_code=200):
+ temp_id = self.auth_user_id
+ self.auth_user_id = src
+
+ path = "/_matrix/client/r0/rooms/%s/state/m.room.member/%s" % (room, targ)
+ if tok:
+ path = path + "?access_token=%s" % tok
+
+ data = {"membership": membership}
+
+ request, channel = make_request(
+ b"PUT", path.encode('ascii'), json.dumps(data).encode('utf8')
+ )
+
+ request.render(self.resource)
+ wait_until_result(self.hs.get_reactor(), channel)
+
+ assert int(channel.result["code"]) == expect_code, (
+ "Expected: %d, got: %d, resp: %r"
+ % (expect_code, int(channel.result["code"]), channel.result["body"])
+ )
+
+ self.auth_user_id = temp_id
+
+ @defer.inlineCallbacks
+ def register(self, user_id):
+ (code, response) = yield self.mock_resource.trigger(
+ "POST",
+ "/_matrix/client/r0/register",
+ json.dumps(
+ {"user": user_id, "password": "test", "type": "m.login.password"}
+ ),
+ )
+ self.assertEquals(200, code)
+ defer.returnValue(response)
+
+ @defer.inlineCallbacks
+ def send(self, room_id, body=None, txn_id=None, tok=None, expect_code=200):
+ if txn_id is None:
+ txn_id = "m%s" % (str(time.time()))
+ if body is None:
+ body = "body_text_here"
+
+ path = "/_matrix/client/r0/rooms/%s/send/m.room.message/%s" % (room_id, txn_id)
+ content = '{"msgtype":"m.text","body":"%s"}' % body
+ if tok:
+ path = path + "?access_token=%s" % tok
+
+ (code, response) = yield self.mock_resource.trigger("PUT", path, content)
+ self.assertEquals(expect_code, code, msg=str(response))
diff --git a/tests/rest/client/v2_alpha/__init__.py b/tests/rest/client/v2_alpha/__init__.py
index f18a8a6027..e69de29bb2 100644
--- a/tests/rest/client/v2_alpha/__init__.py
+++ b/tests/rest/client/v2_alpha/__init__.py
@@ -1,61 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2015, 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from mock import Mock
-
-from twisted.internet import defer
-
-from synapse.types import UserID
-
-from tests import unittest
-
-from ....utils import MockHttpResource, setup_test_homeserver
-
-PATH_PREFIX = "/_matrix/client/v2_alpha"
-
-
-class V2AlphaRestTestCase(unittest.TestCase):
- # Consumer must define
- # USER_ID = <some string>
- # TO_REGISTER = [<list of REST servlets to register>]
-
- @defer.inlineCallbacks
- def setUp(self):
- self.mock_resource = MockHttpResource(prefix=PATH_PREFIX)
-
- hs = yield setup_test_homeserver(
- datastore=self.make_datastore_mock(),
- http_client=None,
- resource_for_client=self.mock_resource,
- resource_for_federation=self.mock_resource,
- )
-
- def get_user_by_access_token(token=None, allow_guest=False):
- return {
- "user": UserID.from_string(self.USER_ID),
- "token_id": 1,
- "is_guest": False,
- }
- hs.get_auth().get_user_by_access_token = get_user_by_access_token
-
- for r in self.TO_REGISTER:
- r.register_servlets(hs, self.mock_resource)
-
- def make_datastore_mock(self):
- store = Mock(spec=[
- "insert_client_ip",
- ])
- store.get_app_service_by_token = Mock(return_value=None)
- return store
diff --git a/tests/rest/client/v2_alpha/test_filter.py b/tests/rest/client/v2_alpha/test_filter.py
index bb0b2f94ea..5ea9cc825f 100644
--- a/tests/rest/client/v2_alpha/test_filter.py
+++ b/tests/rest/client/v2_alpha/test_filter.py
@@ -13,35 +13,33 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.internet import defer
-
import synapse.types
from synapse.api.errors import Codes
+from synapse.http.server import JsonResource
from synapse.rest.client.v2_alpha import filter
from synapse.types import UserID
+from synapse.util import Clock
from tests import unittest
-
-from ....utils import MockHttpResource, setup_test_homeserver
+from tests.server import ThreadedMemoryReactorClock as MemoryReactorClock
+from tests.server import make_request, setup_test_homeserver, wait_until_result
PATH_PREFIX = "/_matrix/client/v2_alpha"
class FilterTestCase(unittest.TestCase):
- USER_ID = "@apple:test"
+ USER_ID = b"@apple:test"
EXAMPLE_FILTER = {"room": {"timeline": {"types": ["m.room.message"]}}}
- EXAMPLE_FILTER_JSON = '{"room": {"timeline": {"types": ["m.room.message"]}}}'
+ EXAMPLE_FILTER_JSON = b'{"room": {"timeline": {"types": ["m.room.message"]}}}'
TO_REGISTER = [filter]
- @defer.inlineCallbacks
def setUp(self):
- self.mock_resource = MockHttpResource(prefix=PATH_PREFIX)
+ self.clock = MemoryReactorClock()
+ self.hs_clock = Clock(self.clock)
- self.hs = yield setup_test_homeserver(
- http_client=None,
- resource_for_client=self.mock_resource,
- resource_for_federation=self.mock_resource,
+ self.hs = setup_test_homeserver(
+ http_client=None, clock=self.hs_clock, reactor=self.clock
)
self.auth = self.hs.get_auth()
@@ -55,82 +53,103 @@ class FilterTestCase(unittest.TestCase):
def get_user_by_req(request, allow_guest=False, rights="access"):
return synapse.types.create_requester(
- UserID.from_string(self.USER_ID), 1, False, None)
+ UserID.from_string(self.USER_ID), 1, False, None
+ )
self.auth.get_user_by_access_token = get_user_by_access_token
self.auth.get_user_by_req = get_user_by_req
self.store = self.hs.get_datastore()
self.filtering = self.hs.get_filtering()
+ self.resource = JsonResource(self.hs)
for r in self.TO_REGISTER:
- r.register_servlets(self.hs, self.mock_resource)
+ r.register_servlets(self.hs, self.resource)
- @defer.inlineCallbacks
def test_add_filter(self):
- (code, response) = yield self.mock_resource.trigger(
- "POST", "/user/%s/filter" % (self.USER_ID), self.EXAMPLE_FILTER_JSON
- )
- self.assertEquals(200, code)
- self.assertEquals({"filter_id": "0"}, response)
- filter = yield self.store.get_user_filter(
- user_localpart='apple',
- filter_id=0,
+ request, channel = make_request(
+ b"POST",
+ b"/_matrix/client/r0/user/%s/filter" % (self.USER_ID),
+ self.EXAMPLE_FILTER_JSON,
)
- self.assertEquals(filter, self.EXAMPLE_FILTER)
+ request.render(self.resource)
+ wait_until_result(self.clock, channel)
+
+ self.assertEqual(channel.result["code"], b"200")
+ self.assertEqual(channel.json_body, {"filter_id": "0"})
+ filter = self.store.get_user_filter(user_localpart="apple", filter_id=0)
+ self.clock.advance(0)
+ self.assertEquals(filter.result, self.EXAMPLE_FILTER)
- @defer.inlineCallbacks
def test_add_filter_for_other_user(self):
- (code, response) = yield self.mock_resource.trigger(
- "POST", "/user/%s/filter" % ('@watermelon:test'), self.EXAMPLE_FILTER_JSON
+ request, channel = make_request(
+ b"POST",
+ b"/_matrix/client/r0/user/%s/filter" % (b"@watermelon:test"),
+ self.EXAMPLE_FILTER_JSON,
)
- self.assertEquals(403, code)
- self.assertEquals(response['errcode'], Codes.FORBIDDEN)
+ request.render(self.resource)
+ wait_until_result(self.clock, channel)
+
+ self.assertEqual(channel.result["code"], b"403")
+ self.assertEquals(channel.json_body["errcode"], Codes.FORBIDDEN)
- @defer.inlineCallbacks
def test_add_filter_non_local_user(self):
_is_mine = self.hs.is_mine
self.hs.is_mine = lambda target_user: False
- (code, response) = yield self.mock_resource.trigger(
- "POST", "/user/%s/filter" % (self.USER_ID), self.EXAMPLE_FILTER_JSON
+ request, channel = make_request(
+ b"POST",
+ b"/_matrix/client/r0/user/%s/filter" % (self.USER_ID),
+ self.EXAMPLE_FILTER_JSON,
)
+ request.render(self.resource)
+ wait_until_result(self.clock, channel)
+
self.hs.is_mine = _is_mine
- self.assertEquals(403, code)
- self.assertEquals(response['errcode'], Codes.FORBIDDEN)
+ self.assertEqual(channel.result["code"], b"403")
+ self.assertEquals(channel.json_body["errcode"], Codes.FORBIDDEN)
- @defer.inlineCallbacks
def test_get_filter(self):
- filter_id = yield self.filtering.add_user_filter(
- user_localpart='apple',
- user_filter=self.EXAMPLE_FILTER
+ filter_id = self.filtering.add_user_filter(
+ user_localpart="apple", user_filter=self.EXAMPLE_FILTER
)
- (code, response) = yield self.mock_resource.trigger_get(
- "/user/%s/filter/%s" % (self.USER_ID, filter_id)
+ self.clock.advance(1)
+ filter_id = filter_id.result
+ request, channel = make_request(
+ b"GET", b"/_matrix/client/r0/user/%s/filter/%s" % (self.USER_ID, filter_id)
)
- self.assertEquals(200, code)
- self.assertEquals(self.EXAMPLE_FILTER, response)
+ request.render(self.resource)
+ wait_until_result(self.clock, channel)
+
+ self.assertEqual(channel.result["code"], b"200")
+ self.assertEquals(channel.json_body, self.EXAMPLE_FILTER)
- @defer.inlineCallbacks
def test_get_filter_non_existant(self):
- (code, response) = yield self.mock_resource.trigger_get(
- "/user/%s/filter/12382148321" % (self.USER_ID)
+ request, channel = make_request(
+ b"GET", "/_matrix/client/r0/user/%s/filter/12382148321" % (self.USER_ID)
)
- self.assertEquals(400, code)
- self.assertEquals(response['errcode'], Codes.NOT_FOUND)
+ request.render(self.resource)
+ wait_until_result(self.clock, channel)
+
+ self.assertEqual(channel.result["code"], b"400")
+ self.assertEquals(channel.json_body["errcode"], Codes.NOT_FOUND)
# Currently invalid params do not have an appropriate errcode
# in errors.py
- @defer.inlineCallbacks
def test_get_filter_invalid_id(self):
- (code, response) = yield self.mock_resource.trigger_get(
- "/user/%s/filter/foobar" % (self.USER_ID)
+ request, channel = make_request(
+ b"GET", "/_matrix/client/r0/user/%s/filter/foobar" % (self.USER_ID)
)
- self.assertEquals(400, code)
+ request.render(self.resource)
+ wait_until_result(self.clock, channel)
+
+ self.assertEqual(channel.result["code"], b"400")
# No ID also returns an invalid_id error
- @defer.inlineCallbacks
def test_get_filter_no_id(self):
- (code, response) = yield self.mock_resource.trigger_get(
- "/user/%s/filter/" % (self.USER_ID)
+ request, channel = make_request(
+ b"GET", "/_matrix/client/r0/user/%s/filter/" % (self.USER_ID)
)
- self.assertEquals(400, code)
+ request.render(self.resource)
+ wait_until_result(self.clock, channel)
+
+ self.assertEqual(channel.result["code"], b"400")
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index 9b57a56070..e004d8fc73 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -2,165 +2,192 @@ import json
from mock import Mock
-from twisted.internet import defer
from twisted.python import failure
+from twisted.test.proto_helpers import MemoryReactorClock
-from synapse.api.errors import InteractiveAuthIncompleteError, SynapseError
-from synapse.rest.client.v2_alpha.register import RegisterRestServlet
+from synapse.api.errors import InteractiveAuthIncompleteError
+from synapse.http.server import JsonResource
+from synapse.rest.client.v2_alpha.register import register_servlets
+from synapse.util import Clock
from tests import unittest
-from tests.utils import mock_getRawHeaders
+from tests.server import make_request, setup_test_homeserver, wait_until_result
class RegisterRestServletTestCase(unittest.TestCase):
-
def setUp(self):
- # do the dance to hook up request data to self.request_data
- self.request_data = ""
- self.request = Mock(
- content=Mock(read=Mock(side_effect=lambda: self.request_data)),
- path='/_matrix/api/v2_alpha/register'
- )
- self.request.args = {}
- self.request.requestHeaders.getRawHeaders = mock_getRawHeaders()
+
+ self.clock = MemoryReactorClock()
+ self.hs_clock = Clock(self.clock)
+ self.url = b"/_matrix/client/r0/register"
self.appservice = None
- self.auth = Mock(get_appservice_by_req=Mock(
- side_effect=lambda x: self.appservice)
+ self.auth = Mock(
+ get_appservice_by_req=Mock(side_effect=lambda x: self.appservice)
)
self.auth_result = failure.Failure(InteractiveAuthIncompleteError(None))
self.auth_handler = Mock(
check_auth=Mock(side_effect=lambda x, y, z: self.auth_result),
- get_session_data=Mock(return_value=None)
+ get_session_data=Mock(return_value=None),
)
self.registration_handler = Mock()
self.identity_handler = Mock()
self.login_handler = Mock()
self.device_handler = Mock()
+ self.device_handler.check_device_registered = Mock(return_value="FAKE")
+
+ self.datastore = Mock(return_value=Mock())
+ self.datastore.get_current_state_deltas = Mock(return_value=[])
# do the dance to hook it up to the hs global
self.handlers = Mock(
registration_handler=self.registration_handler,
identity_handler=self.identity_handler,
- login_handler=self.login_handler
+ login_handler=self.login_handler,
+ )
+ self.hs = setup_test_homeserver(
+ http_client=None, clock=self.hs_clock, reactor=self.clock
)
- self.hs = Mock()
- self.hs.hostname = "superbig~testing~thing.com"
self.hs.get_auth = Mock(return_value=self.auth)
self.hs.get_handlers = Mock(return_value=self.handlers)
self.hs.get_auth_handler = Mock(return_value=self.auth_handler)
self.hs.get_device_handler = Mock(return_value=self.device_handler)
+ self.hs.get_datastore = Mock(return_value=self.datastore)
self.hs.config.enable_registration = True
self.hs.config.registrations_require_3pid = []
self.hs.config.auto_join_rooms = []
- # init the thing we're testing
- self.servlet = RegisterRestServlet(self.hs)
+ self.resource = JsonResource(self.hs)
+ register_servlets(self.hs, self.resource)
- @defer.inlineCallbacks
def test_POST_appservice_registration_valid(self):
user_id = "@kermit:muppet"
token = "kermits_access_token"
- self.request.args = {
- "access_token": "i_am_an_app_service"
- }
- self.request_data = json.dumps({
- "username": "kermit"
- })
- self.appservice = {
- "id": "1234"
- }
- self.registration_handler.appservice_register = Mock(
- return_value=user_id
- )
- self.auth_handler.get_access_token_for_user_id = Mock(
- return_value=token
+ self.appservice = {"id": "1234"}
+ self.registration_handler.appservice_register = Mock(return_value=user_id)
+ self.auth_handler.get_access_token_for_user_id = Mock(return_value=token)
+ request_data = json.dumps({"username": "kermit"})
+
+ request, channel = make_request(
+ b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
)
+ request.render(self.resource)
+ wait_until_result(self.clock, channel)
- (code, result) = yield self.servlet.on_POST(self.request)
- self.assertEquals(code, 200)
+ self.assertEquals(channel.result["code"], b"200", channel.result)
det_data = {
"user_id": user_id,
"access_token": token,
- "home_server": self.hs.hostname
+ "home_server": self.hs.hostname,
}
- self.assertDictContainsSubset(det_data, result)
+ self.assertDictContainsSubset(det_data, json.loads(channel.result["body"]))
- @defer.inlineCallbacks
def test_POST_appservice_registration_invalid(self):
- self.request.args = {
- "access_token": "i_am_an_app_service"
- }
-
- self.request_data = json.dumps({
- "username": "kermit"
- })
self.appservice = None # no application service exists
- result = yield self.servlet.on_POST(self.request)
- self.assertEquals(result, (401, None))
+ request_data = json.dumps({"username": "kermit"})
+ request, channel = make_request(
+ b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
+ )
+ request.render(self.resource)
+ wait_until_result(self.clock, channel)
+
+ self.assertEquals(channel.result["code"], b"401", channel.result)
def test_POST_bad_password(self):
- self.request_data = json.dumps({
- "username": "kermit",
- "password": 666
- })
- d = self.servlet.on_POST(self.request)
- return self.assertFailure(d, SynapseError)
+ request_data = json.dumps({"username": "kermit", "password": 666})
+ request, channel = make_request(b"POST", self.url, request_data)
+ request.render(self.resource)
+ wait_until_result(self.clock, channel)
+
+ self.assertEquals(channel.result["code"], b"400", channel.result)
+ self.assertEquals(
+ json.loads(channel.result["body"])["error"], "Invalid password"
+ )
def test_POST_bad_username(self):
- self.request_data = json.dumps({
- "username": 777,
- "password": "monkey"
- })
- d = self.servlet.on_POST(self.request)
- return self.assertFailure(d, SynapseError)
-
- @defer.inlineCallbacks
+ request_data = json.dumps({"username": 777, "password": "monkey"})
+ request, channel = make_request(b"POST", self.url, request_data)
+ request.render(self.resource)
+ wait_until_result(self.clock, channel)
+
+ self.assertEquals(channel.result["code"], b"400", channel.result)
+ self.assertEquals(
+ json.loads(channel.result["body"])["error"], "Invalid username"
+ )
+
def test_POST_user_valid(self):
user_id = "@kermit:muppet"
token = "kermits_access_token"
device_id = "frogfone"
- self.request_data = json.dumps({
- "username": "kermit",
- "password": "monkey",
- "device_id": device_id,
- })
+ request_data = json.dumps(
+ {"username": "kermit", "password": "monkey", "device_id": device_id}
+ )
self.registration_handler.check_username = Mock(return_value=True)
- self.auth_result = (None, {
- "username": "kermit",
- "password": "monkey"
- }, None)
+ self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None)
self.registration_handler.register = Mock(return_value=(user_id, None))
- self.auth_handler.get_access_token_for_user_id = Mock(
- return_value=token
- )
- self.device_handler.check_device_registered = \
- Mock(return_value=device_id)
+ self.auth_handler.get_access_token_for_user_id = Mock(return_value=token)
+ self.device_handler.check_device_registered = Mock(return_value=device_id)
+
+ request, channel = make_request(b"POST", self.url, request_data)
+ request.render(self.resource)
+ wait_until_result(self.clock, channel)
- (code, result) = yield self.servlet.on_POST(self.request)
- self.assertEquals(code, 200)
det_data = {
"user_id": user_id,
"access_token": token,
"home_server": self.hs.hostname,
"device_id": device_id,
}
- self.assertDictContainsSubset(det_data, result)
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+ self.assertDictContainsSubset(det_data, json.loads(channel.result["body"]))
self.auth_handler.get_login_tuple_for_user_id(
- user_id, device_id=device_id, initial_device_display_name=None)
+ user_id, device_id=device_id, initial_device_display_name=None
+ )
def test_POST_disabled_registration(self):
self.hs.config.enable_registration = False
- self.request_data = json.dumps({
- "username": "kermit",
- "password": "monkey"
- })
+ request_data = json.dumps({"username": "kermit", "password": "monkey"})
self.registration_handler.check_username = Mock(return_value=True)
- self.auth_result = (None, {
- "username": "kermit",
- "password": "monkey"
- }, None)
+ self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None)
self.registration_handler.register = Mock(return_value=("@user:id", "t"))
- d = self.servlet.on_POST(self.request)
- return self.assertFailure(d, SynapseError)
+
+ request, channel = make_request(b"POST", self.url, request_data)
+ request.render(self.resource)
+ wait_until_result(self.clock, channel)
+
+ self.assertEquals(channel.result["code"], b"403", channel.result)
+ self.assertEquals(
+ json.loads(channel.result["body"])["error"],
+ "Registration has been disabled",
+ )
+
+ def test_POST_guest_registration(self):
+ user_id = "a@b"
+ self.hs.config.macaroon_secret_key = "test"
+ self.hs.config.allow_guest_access = True
+ self.registration_handler.register = Mock(return_value=(user_id, None))
+
+ request, channel = make_request(b"POST", self.url + b"?kind=guest", b"{}")
+ request.render(self.resource)
+ wait_until_result(self.clock, channel)
+
+ det_data = {
+ "user_id": user_id,
+ "home_server": self.hs.hostname,
+ "device_id": "guest_device",
+ }
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+ self.assertDictContainsSubset(det_data, json.loads(channel.result["body"]))
+
+ def test_POST_disabled_guest_registration(self):
+ self.hs.config.allow_guest_access = False
+
+ request, channel = make_request(b"POST", self.url + b"?kind=guest", b"{}")
+ request.render(self.resource)
+ wait_until_result(self.clock, channel)
+
+ self.assertEquals(channel.result["code"], b"403", channel.result)
+ self.assertEquals(
+ json.loads(channel.result["body"])["error"], "Guest access is disabled"
+ )
diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py
new file mode 100644
index 0000000000..704cf97a40
--- /dev/null
+++ b/tests/rest/client/v2_alpha/test_sync.py
@@ -0,0 +1,83 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import synapse.types
+from synapse.http.server import JsonResource
+from synapse.rest.client.v2_alpha import sync
+from synapse.types import UserID
+from synapse.util import Clock
+
+from tests import unittest
+from tests.server import ThreadedMemoryReactorClock as MemoryReactorClock
+from tests.server import make_request, setup_test_homeserver, wait_until_result
+
+PATH_PREFIX = "/_matrix/client/v2_alpha"
+
+
+class FilterTestCase(unittest.TestCase):
+
+ USER_ID = b"@apple:test"
+ TO_REGISTER = [sync]
+
+ def setUp(self):
+ self.clock = MemoryReactorClock()
+ self.hs_clock = Clock(self.clock)
+
+ self.hs = setup_test_homeserver(
+ http_client=None, clock=self.hs_clock, reactor=self.clock
+ )
+
+ self.auth = self.hs.get_auth()
+
+ def get_user_by_access_token(token=None, allow_guest=False):
+ return {
+ "user": UserID.from_string(self.USER_ID),
+ "token_id": 1,
+ "is_guest": False,
+ }
+
+ def get_user_by_req(request, allow_guest=False, rights="access"):
+ return synapse.types.create_requester(
+ UserID.from_string(self.USER_ID), 1, False, None
+ )
+
+ self.auth.get_user_by_access_token = get_user_by_access_token
+ self.auth.get_user_by_req = get_user_by_req
+
+ self.store = self.hs.get_datastore()
+ self.filtering = self.hs.get_filtering()
+ self.resource = JsonResource(self.hs)
+
+ for r in self.TO_REGISTER:
+ r.register_servlets(self.hs, self.resource)
+
+ def test_sync_argless(self):
+ request, channel = make_request(b"GET", b"/_matrix/client/r0/sync")
+ request.render(self.resource)
+ wait_until_result(self.clock, channel)
+
+ self.assertEqual(channel.result["code"], b"200")
+ self.assertTrue(
+ set(
+ [
+ "next_batch",
+ "rooms",
+ "presence",
+ "account_data",
+ "to_device",
+ "device_lists",
+ ]
+ ).issubset(set(channel.json_body.keys()))
+ )
diff --git a/tests/server.py b/tests/server.py
index e93f5a7f94..c611dd6059 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -80,6 +80,11 @@ def make_request(method, path, content=b""):
content, and return the Request and the Channel underneath.
"""
+ # Decorate it to be the full path
+ if not path.startswith(b"/_matrix"):
+ path = b"/_matrix/client/r0/" + path
+ path = path.replace("//", "/")
+
if isinstance(content, text_type):
content = content.encode('utf8')
@@ -110,6 +115,11 @@ def wait_until_result(clock, channel, timeout=100):
clock.advance(0.1)
+def render(request, resource, clock):
+ request.render(resource)
+ wait_until_result(clock, request._channel)
+
+
class ThreadedMemoryReactorClock(MemoryReactorClock):
"""
A MemoryReactorClock that supports callFromThread.
diff --git a/tests/test_server.py b/tests/test_server.py
index 4192013f6d..7e063c0290 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -33,9 +33,11 @@ class JsonResourceTests(unittest.TestCase):
return (200, kwargs)
res = JsonResource(self.homeserver)
- res.register_paths("GET", [re.compile("^/foo/(?P<room_id>[^/]*)$")], _callback)
+ res.register_paths(
+ "GET", [re.compile("^/_matrix/foo/(?P<room_id>[^/]*)$")], _callback
+ )
- request, channel = make_request(b"GET", b"/foo/%E2%98%83?a=%E2%98%83")
+ request, channel = make_request(b"GET", b"/_matrix/foo/%E2%98%83?a=%E2%98%83")
request.render(res)
self.assertEqual(request.args, {b'a': [u"\N{SNOWMAN}".encode('utf8')]})
@@ -51,9 +53,9 @@ class JsonResourceTests(unittest.TestCase):
raise Exception("boo")
res = JsonResource(self.homeserver)
- res.register_paths("GET", [re.compile("^/foo$")], _callback)
+ res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback)
- request, channel = make_request(b"GET", b"/foo")
+ request, channel = make_request(b"GET", b"/_matrix/foo")
request.render(res)
self.assertEqual(channel.result["code"], b'500')
@@ -74,9 +76,9 @@ class JsonResourceTests(unittest.TestCase):
return d
res = JsonResource(self.homeserver)
- res.register_paths("GET", [re.compile("^/foo$")], _callback)
+ res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback)
- request, channel = make_request(b"GET", b"/foo")
+ request, channel = make_request(b"GET", b"/_matrix/foo")
request.render(res)
# No error has been raised yet
@@ -96,9 +98,9 @@ class JsonResourceTests(unittest.TestCase):
raise SynapseError(403, "Forbidden!!one!", Codes.FORBIDDEN)
res = JsonResource(self.homeserver)
- res.register_paths("GET", [re.compile("^/foo$")], _callback)
+ res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback)
- request, channel = make_request(b"GET", b"/foo")
+ request, channel = make_request(b"GET", b"/_matrix/foo")
request.render(res)
self.assertEqual(channel.result["code"], b'403')
@@ -118,9 +120,9 @@ class JsonResourceTests(unittest.TestCase):
self.fail("shouldn't ever get here")
res = JsonResource(self.homeserver)
- res.register_paths("GET", [re.compile("^/foo$")], _callback)
+ res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback)
- request, channel = make_request(b"GET", b"/foobar")
+ request, channel = make_request(b"GET", b"/_matrix/foobar")
request.render(res)
self.assertEqual(channel.result["code"], b'400')
diff --git a/tests/test_visibility.py b/tests/test_visibility.py
new file mode 100644
index 0000000000..0dc1a924d3
--- /dev/null
+++ b/tests/test_visibility.py
@@ -0,0 +1,324 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from twisted.internet import defer
+from twisted.internet.defer import succeed
+
+from synapse.events import FrozenEvent
+from synapse.visibility import filter_events_for_server
+
+import tests.unittest
+from tests.utils import setup_test_homeserver
+
+logger = logging.getLogger(__name__)
+
+TEST_ROOM_ID = "!TEST:ROOM"
+
+
+class FilterEventsForServerTestCase(tests.unittest.TestCase):
+ @defer.inlineCallbacks
+ def setUp(self):
+ self.hs = yield setup_test_homeserver()
+ self.event_creation_handler = self.hs.get_event_creation_handler()
+ self.event_builder_factory = self.hs.get_event_builder_factory()
+ self.store = self.hs.get_datastore()
+
+ @defer.inlineCallbacks
+ def test_filtering(self):
+ #
+ # The events to be filtered consist of 10 membership events (it doesn't
+ # really matter if they are joins or leaves, so let's make them joins).
+ # One of those membership events is going to be for a user on the
+ # server we are filtering for (so we can check the filtering is doing
+ # the right thing).
+ #
+
+ # before we do that, we persist some other events to act as state.
+ self.inject_visibility("@admin:hs", "joined")
+ for i in range(0, 10):
+ yield self.inject_room_member("@resident%i:hs" % i)
+
+ events_to_filter = []
+
+ for i in range(0, 10):
+ user = "@user%i:%s" % (
+ i, "test_server" if i == 5 else "other_server"
+ )
+ evt = yield self.inject_room_member(user, extra_content={"a": "b"})
+ events_to_filter.append(evt)
+
+ filtered = yield filter_events_for_server(
+ self.store, "test_server", events_to_filter,
+ )
+
+ # the result should be 5 redacted events, and 5 unredacted events.
+ for i in range(0, 5):
+ self.assertEqual(events_to_filter[i].event_id, filtered[i].event_id)
+ self.assertNotIn("a", filtered[i].content)
+
+ for i in range(5, 10):
+ self.assertEqual(events_to_filter[i].event_id, filtered[i].event_id)
+ self.assertEqual(filtered[i].content["a"], "b")
+
+ @tests.unittest.DEBUG
+ @defer.inlineCallbacks
+ def test_erased_user(self):
+ # 4 message events, from erased and unerased users, with a membership
+ # change in the middle of them.
+ events_to_filter = []
+
+ evt = yield self.inject_message("@unerased:local_hs")
+ events_to_filter.append(evt)
+
+ evt = yield self.inject_message("@erased:local_hs")
+ events_to_filter.append(evt)
+
+ evt = yield self.inject_room_member("@joiner:remote_hs")
+ events_to_filter.append(evt)
+
+ evt = yield self.inject_message("@unerased:local_hs")
+ events_to_filter.append(evt)
+
+ evt = yield self.inject_message("@erased:local_hs")
+ events_to_filter.append(evt)
+
+ # the erasey user gets erased
+ self.hs.get_datastore().mark_user_erased("@erased:local_hs")
+
+ # ... and the filtering happens.
+ filtered = yield filter_events_for_server(
+ self.store, "test_server", events_to_filter,
+ )
+
+ for i in range(0, len(events_to_filter)):
+ self.assertEqual(
+ events_to_filter[i].event_id, filtered[i].event_id,
+ "Unexpected event at result position %i" % (i, )
+ )
+
+ for i in (0, 3):
+ self.assertEqual(
+ events_to_filter[i].content["body"], filtered[i].content["body"],
+ "Unexpected event content at result position %i" % (i,)
+ )
+
+ for i in (1, 4):
+ self.assertNotIn("body", filtered[i].content)
+
+ @defer.inlineCallbacks
+ def inject_visibility(self, user_id, visibility):
+ content = {"history_visibility": visibility}
+ builder = self.event_builder_factory.new({
+ "type": "m.room.history_visibility",
+ "sender": user_id,
+ "state_key": "",
+ "room_id": TEST_ROOM_ID,
+ "content": content,
+ })
+
+ event, context = yield self.event_creation_handler.create_new_client_event(
+ builder
+ )
+ yield self.hs.get_datastore().persist_event(event, context)
+ defer.returnValue(event)
+
+ @defer.inlineCallbacks
+ def inject_room_member(self, user_id, membership="join", extra_content={}):
+ content = {"membership": membership}
+ content.update(extra_content)
+ builder = self.event_builder_factory.new({
+ "type": "m.room.member",
+ "sender": user_id,
+ "state_key": user_id,
+ "room_id": TEST_ROOM_ID,
+ "content": content,
+ })
+
+ event, context = yield self.event_creation_handler.create_new_client_event(
+ builder
+ )
+
+ yield self.hs.get_datastore().persist_event(event, context)
+ defer.returnValue(event)
+
+ @defer.inlineCallbacks
+ def inject_message(self, user_id, content=None):
+ if content is None:
+ content = {"body": "testytest"}
+ builder = self.event_builder_factory.new({
+ "type": "m.room.message",
+ "sender": user_id,
+ "room_id": TEST_ROOM_ID,
+ "content": content,
+ })
+
+ event, context = yield self.event_creation_handler.create_new_client_event(
+ builder
+ )
+
+ yield self.hs.get_datastore().persist_event(event, context)
+ defer.returnValue(event)
+
+ @defer.inlineCallbacks
+ def test_large_room(self):
+ # see what happens when we have a large room with hundreds of thousands
+ # of membership events
+
+ # As above, the events to be filtered consist of 10 membership events,
+ # where one of them is for a user on the server we are filtering for.
+
+ import cProfile
+ import pstats
+ import time
+
+ # we stub out the store, because building up all that state the normal
+ # way is very slow.
+ test_store = _TestStore()
+
+ # our initial state is 100000 membership events and one
+ # history_visibility event.
+ room_state = []
+
+ history_visibility_evt = FrozenEvent({
+ "event_id": "$history_vis",
+ "type": "m.room.history_visibility",
+ "sender": "@resident_user_0:test.com",
+ "state_key": "",
+ "room_id": TEST_ROOM_ID,
+ "content": {"history_visibility": "joined"},
+ })
+ room_state.append(history_visibility_evt)
+ test_store.add_event(history_visibility_evt)
+
+ for i in range(0, 100000):
+ user = "@resident_user_%i:test.com" % (i, )
+ evt = FrozenEvent({
+ "event_id": "$res_event_%i" % (i, ),
+ "type": "m.room.member",
+ "state_key": user,
+ "sender": user,
+ "room_id": TEST_ROOM_ID,
+ "content": {
+ "membership": "join",
+ "extra": "zzz,"
+ },
+ })
+ room_state.append(evt)
+ test_store.add_event(evt)
+
+ events_to_filter = []
+ for i in range(0, 10):
+ user = "@user%i:%s" % (
+ i, "test_server" if i == 5 else "other_server"
+ )
+ evt = FrozenEvent({
+ "event_id": "$evt%i" % (i, ),
+ "type": "m.room.member",
+ "state_key": user,
+ "sender": user,
+ "room_id": TEST_ROOM_ID,
+ "content": {
+ "membership": "join",
+ "extra": "zzz",
+ },
+ })
+ events_to_filter.append(evt)
+ room_state.append(evt)
+
+ test_store.add_event(evt)
+ test_store.set_state_ids_for_event(evt, {
+ (e.type, e.state_key): e.event_id for e in room_state
+ })
+
+ pr = cProfile.Profile()
+ pr.enable()
+
+ logger.info("Starting filtering")
+ start = time.time()
+ filtered = yield filter_events_for_server(
+ test_store, "test_server", events_to_filter,
+ )
+ logger.info("Filtering took %f seconds", time.time() - start)
+
+ pr.disable()
+ with open("filter_events_for_server.profile", "w+") as f:
+ ps = pstats.Stats(pr, stream=f).sort_stats('cumulative')
+ ps.print_stats()
+
+ # the result should be 5 redacted events, and 5 unredacted events.
+ for i in range(0, 5):
+ self.assertEqual(events_to_filter[i].event_id, filtered[i].event_id)
+ self.assertNotIn("extra", filtered[i].content)
+
+ for i in range(5, 10):
+ self.assertEqual(events_to_filter[i].event_id, filtered[i].event_id)
+ self.assertEqual(filtered[i].content["extra"], "zzz")
+
+ test_large_room.skip = "Disabled by default because it's slow"
+
+
+class _TestStore(object):
+ """Implements a few methods of the DataStore, so that we can test
+ filter_events_for_server
+
+ """
+ def __init__(self):
+ # data for get_events: a map from event_id to event
+ self.events = {}
+
+ # data for get_state_ids_for_events mock: a map from event_id to
+ # a map from (type_state_key) -> event_id for the state at that
+ # event
+ self.state_ids_for_events = {}
+
+ def add_event(self, event):
+ self.events[event.event_id] = event
+
+ def set_state_ids_for_event(self, event, state):
+ self.state_ids_for_events[event.event_id] = state
+
+ def get_state_ids_for_events(self, events, types):
+ res = {}
+ include_memberships = False
+ for (type, state_key) in types:
+ if type == "m.room.history_visibility":
+ continue
+ if type != "m.room.member" or state_key is not None:
+ raise RuntimeError(
+ "Unimplemented: get_state_ids with type (%s, %s)" %
+ (type, state_key),
+ )
+ include_memberships = True
+
+ if include_memberships:
+ for event_id in events:
+ res[event_id] = self.state_ids_for_events[event_id]
+
+ else:
+ k = ("m.room.history_visibility", "")
+ for event_id in events:
+ hve = self.state_ids_for_events[event_id][k]
+ res[event_id] = {k: hve}
+
+ return succeed(res)
+
+ def get_events(self, events):
+ return succeed({
+ event_id: self.events[event_id] for event_id in events
+ })
+
+ def are_users_erased(self, users):
+ return succeed({u: False for u in users})
diff --git a/tests/unittest.py b/tests/unittest.py
index b25f2db5d5..b15b06726b 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -109,6 +109,17 @@ class TestCase(unittest.TestCase):
except AssertionError as e:
raise (type(e))(e.message + " for '.%s'" % key)
+ def assert_dict(self, required, actual):
+ """Does a partial assert of a dict.
+
+ Args:
+ required (dict): The keys and value which MUST be in 'actual'.
+ actual (dict): The test result. Extra keys will not be checked.
+ """
+ for key in required:
+ self.assertEquals(required[key], actual[key],
+ msg="%s mismatch. %s" % (key, actual))
+
def DEBUG(target):
"""A decorator to set the .loglevel attribute to logging.DEBUG.
diff --git a/tests/util/test_stream_change_cache.py b/tests/util/test_stream_change_cache.py
index e3897c0d19..65b0f2e6fb 100644
--- a/tests/util/test_stream_change_cache.py
+++ b/tests/util/test_stream_change_cache.py
@@ -141,8 +141,8 @@ class StreamChangeCacheTests(unittest.TestCase):
)
# Query all the entries mid-way through the stream, but include one
- # that doesn't exist in it. We should get back the one that doesn't
- # exist, too.
+ # that doesn't exist in it. We shouldn't get back the one that doesn't
+ # exist.
self.assertEqual(
cache.get_entities_changed(
[
@@ -153,7 +153,7 @@ class StreamChangeCacheTests(unittest.TestCase):
],
stream_pos=2,
),
- set(["bar@baz.net", "user@elsewhere.org", "not@here.website"]),
+ set(["bar@baz.net", "user@elsewhere.org"]),
)
# Query all the entries, but before the first known point. We will get
@@ -178,6 +178,22 @@ class StreamChangeCacheTests(unittest.TestCase):
),
)
+ # Query a subset of the entries mid-way through the stream. We should
+ # only get back the subset.
+ self.assertEqual(
+ cache.get_entities_changed(
+ [
+ "bar@baz.net",
+ ],
+ stream_pos=2,
+ ),
+ set(
+ [
+ "bar@baz.net",
+ ]
+ ),
+ )
+
def test_max_pos(self):
"""
StreamChangeCache.get_max_pos_of_last_change will return the most
diff --git a/tests/utils.py b/tests/utils.py
index 6adbdbfca1..e488238bb3 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -65,6 +65,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, reactor=None
config.federation_domain_whitelist = None
config.federation_rc_reject_limit = 10
config.federation_rc_sleep_limit = 10
+ config.federation_rc_sleep_delay = 100
config.federation_rc_concurrent = 10
config.filter_timeline_limit = 5000
config.user_directory_search_all_users = False
diff --git a/tox.ini b/tox.ini
index 61a20a10cb..ed26644bd9 100644
--- a/tox.ini
+++ b/tox.ini
@@ -1,5 +1,5 @@
[tox]
-envlist = packaging, py27, py36, pep8
+envlist = packaging, py27, py36, pep8, check_isort
[testenv]
deps =
@@ -103,10 +103,14 @@ deps =
flake8
commands = /bin/sh -c "flake8 synapse tests {env:PEP8SUFFIX:}"
+[testenv:check_isort]
+skip_install = True
+deps = isort
+commands = /bin/sh -c "isort -c -sp setup.cfg -rc synapse tests"
[testenv:check-newsfragment]
skip_install = True
deps = towncrier>=18.6.0rc1
commands =
python -m towncrier.check --compare-with=origin/develop
-basepython = python3.6
\ No newline at end of file
+basepython = python3.6
|