diff --git a/AUTHORS.rst b/AUTHORS.rst
index e13ac5ad34..9a83d90153 100644
--- a/AUTHORS.rst
+++ b/AUTHORS.rst
@@ -62,4 +62,7 @@ Christoph Witzany <christoph at web.crofting.com>
* Add LDAP support for authentication
Pierre Jaury <pierre at jaury.eu>
-* Docker packaging
\ No newline at end of file
+* Docker packaging
+
+Serban Constantin <serban.constantin at gmail dot com>
+ * Small bug fix
\ No newline at end of file
diff --git a/Dockerfile b/Dockerfile
index 565341fee3..0242be5f68 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,16 +1,32 @@
FROM docker.io/python:2-alpine3.7
-RUN apk add --no-cache --virtual .nacl_deps su-exec build-base libffi-dev zlib-dev libressl-dev libjpeg-turbo-dev linux-headers postgresql-dev libxslt-dev
+RUN apk add --no-cache --virtual .nacl_deps \
+ build-base \
+ libffi-dev \
+ libjpeg-turbo-dev \
+ libressl-dev \
+ libxslt-dev \
+ linux-headers \
+ postgresql-dev \
+ su-exec \
+ zlib-dev
COPY . /synapse
# A wheel cache may be provided in ./cache for faster build
RUN cd /synapse \
- && pip install --upgrade pip setuptools psycopg2 lxml \
+ && pip install --upgrade \
+ lxml \
+ pip \
+ psycopg2 \
+ setuptools \
&& mkdir -p /synapse/cache \
&& pip install -f /synapse/cache --upgrade --process-dependency-links . \
&& mv /synapse/contrib/docker/start.py /synapse/contrib/docker/conf / \
- && rm -rf setup.py setup.cfg synapse
+ && rm -rf \
+ setup.cfg \
+ setup.py \
+ synapse
VOLUME ["/data"]
diff --git a/changelog.d/2952.bugfix b/changelog.d/2952.bugfix
new file mode 100644
index 0000000000..07a3e48304
--- /dev/null
+++ b/changelog.d/2952.bugfix
@@ -0,0 +1 @@
+Make /directory/list API return 404 for room not found instead of 400
diff --git a/changelog.d/3384.misc b/changelog.d/3384.misc
new file mode 100644
index 0000000000..5d56c876d9
--- /dev/null
+++ b/changelog.d/3384.misc
@@ -0,0 +1 @@
+Rewrite cache list decorator
diff --git a/changelog.d/3543.misc b/changelog.d/3543.misc
new file mode 100644
index 0000000000..d231d17749
--- /dev/null
+++ b/changelog.d/3543.misc
@@ -0,0 +1 @@
+Improve Dockerfile and docker-compose instructions
diff --git a/changelog.d/3569.bugfix b/changelog.d/3569.bugfix
new file mode 100644
index 0000000000..d77f035ee0
--- /dev/null
+++ b/changelog.d/3569.bugfix
@@ -0,0 +1 @@
+Unicode passwords are now normalised before hashing, preventing the instance where two different devices or browsers might send a different UTF-8 sequence for the password.
diff --git a/changelog.d/3612.misc b/changelog.d/3612.misc
new file mode 100644
index 0000000000..f90d2f2ff5
--- /dev/null
+++ b/changelog.d/3612.misc
@@ -0,0 +1 @@
+Make EventStore inherit from EventFederationStore
diff --git a/changelog.d/3628.misc b/changelog.d/3628.misc
new file mode 100644
index 0000000000..1aebefbe18
--- /dev/null
+++ b/changelog.d/3628.misc
@@ -0,0 +1 @@
+Remove unused field "pdu_failures" from transactions.
diff --git a/changelog.d/3630.feature b/changelog.d/3630.feature
new file mode 100644
index 0000000000..8007a04840
--- /dev/null
+++ b/changelog.d/3630.feature
@@ -0,0 +1 @@
+Add ability to limit number of monthly active users on the server
diff --git a/changelog.d/3634.misc b/changelog.d/3634.misc
new file mode 100644
index 0000000000..2cd6af91ff
--- /dev/null
+++ b/changelog.d/3634.misc
@@ -0,0 +1 @@
+rename replication_layer to federation_client
diff --git a/contrib/docker/README.md b/contrib/docker/README.md
index 61592109cb..562cdaac2b 100644
--- a/contrib/docker/README.md
+++ b/contrib/docker/README.md
@@ -9,13 +9,7 @@ use that server.
## Build
-Build the docker image with the `docker build` command from the root of the synapse repository.
-
-```
-docker build -t docker.io/matrixdotorg/synapse .
-```
-
-The `-t` option sets the image tag. Official images are tagged `matrixdotorg/synapse:<version>` where `<version>` is the same as the release tag in the synapse git repository.
+Build the docker image with the `docker-compose build` command.
You may have a local Python wheel cache available, in which case copy the relevant packages in the ``cache/`` directory at the root of the project.
diff --git a/contrib/docker/docker-compose.yml b/contrib/docker/docker-compose.yml
index 0b531949e0..3a8dfbae34 100644
--- a/contrib/docker/docker-compose.yml
+++ b/contrib/docker/docker-compose.yml
@@ -6,6 +6,7 @@ version: '3'
services:
synapse:
+ build: ../..
image: docker.io/matrixdotorg/synapse:latest
# Since snyapse does not retry to connect to the database, restart upon
# failure
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 073229b4c4..5bbbe8e2e7 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -252,10 +252,10 @@ class Auth(object):
if ip_address not in app_service.ip_range_whitelist:
defer.returnValue((None, None))
- if "user_id" not in request.args:
+ if b"user_id" not in request.args:
defer.returnValue((app_service.sender, app_service))
- user_id = request.args["user_id"][0]
+ user_id = request.args[b"user_id"][0].decode('utf8')
if app_service.sender == user_id:
defer.returnValue((app_service.sender, app_service))
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index 6074df292f..14f5540280 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -55,6 +55,7 @@ class Codes(object):
SERVER_NOT_TRUSTED = "M_SERVER_NOT_TRUSTED"
CONSENT_NOT_GIVEN = "M_CONSENT_NOT_GIVEN"
CANNOT_LEAVE_SERVER_NOTICE_ROOM = "M_CANNOT_LEAVE_SERVER_NOTICE_ROOM"
+ MAU_LIMIT_EXCEEDED = "M_MAU_LIMIT_EXCEEDED"
class CodeMessageException(RuntimeError):
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 79772fa61a..5f0ca51ac7 100755
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -20,6 +20,8 @@ import sys
from six import iteritems
+from prometheus_client import Gauge
+
from twisted.application import service
from twisted.internet import defer, reactor
from twisted.web.resource import EncodingResourceWrapper, NoResource
@@ -301,6 +303,11 @@ class SynapseHomeServer(HomeServer):
quit_with_error(e.message)
+# Gauges to expose monthly active user control metrics
+current_mau_gauge = Gauge("synapse_admin_current_mau", "Current MAU")
+max_mau_value_gauge = Gauge("synapse_admin_max_mau_value", "MAU Limit")
+
+
def setup(config_options):
"""
Args:
@@ -516,6 +523,18 @@ def run(hs):
MonthlyActiveUsersStore(hs).reap_monthly_active_users, 1000 * 60 * 60
)
+ @defer.inlineCallbacks
+ def generate_monthly_active_users():
+ count = 0
+ if hs.config.limit_usage_by_mau:
+ count = yield hs.get_datastore().count_monthly_users()
+ current_mau_gauge.set(float(count))
+ max_mau_value_gauge.set(float(hs.config.max_mau_value))
+
+ generate_monthly_active_users()
+ if hs.config.limit_usage_by_mau:
+ clock.looping_call(generate_monthly_active_users, 5 * 60 * 1000)
+
if hs.config.report_stats:
logger.info("Scheduling stats reporting for 3 hour intervals")
clock.looping_call(start_phone_stats_home, 3 * 60 * 60 * 1000)
diff --git a/synapse/config/server.py b/synapse/config/server.py
index 18102656b0..a8014e9c50 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -67,6 +67,14 @@ class ServerConfig(Config):
"block_non_admin_invites", False,
)
+ # Options to control access by tracking MAU
+ self.limit_usage_by_mau = config.get("limit_usage_by_mau", False)
+ if self.limit_usage_by_mau:
+ self.max_mau_value = config.get(
+ "max_mau_value", 0,
+ )
+ else:
+ self.max_mau_value = 0
# FIXME: federation_domain_whitelist needs sytests
self.federation_domain_whitelist = None
federation_domain_whitelist = config.get(
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index e501251b6e..657935d1ac 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -207,10 +207,6 @@ class FederationServer(FederationBase):
edu.content
)
- pdu_failures = getattr(transaction, "pdu_failures", [])
- for fail in pdu_failures:
- logger.info("Got failure %r", fail)
-
response = {
"pdus": pdu_results,
}
diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py
index 5157c3860d..0bb468385d 100644
--- a/synapse/federation/send_queue.py
+++ b/synapse/federation/send_queue.py
@@ -62,8 +62,6 @@ class FederationRemoteSendQueue(object):
self.edus = SortedDict() # stream position -> Edu
- self.failures = SortedDict() # stream position -> (destination, Failure)
-
self.device_messages = SortedDict() # stream position -> destination
self.pos = 1
@@ -79,7 +77,7 @@ class FederationRemoteSendQueue(object):
for queue_name in [
"presence_map", "presence_changed", "keyed_edu", "keyed_edu_changed",
- "edus", "failures", "device_messages", "pos_time",
+ "edus", "device_messages", "pos_time",
]:
register(queue_name, getattr(self, queue_name))
@@ -149,12 +147,6 @@ class FederationRemoteSendQueue(object):
for key in keys[:i]:
del self.edus[key]
- # Delete things out of failure map
- keys = self.failures.keys()
- i = self.failures.bisect_left(position_to_delete)
- for key in keys[:i]:
- del self.failures[key]
-
# Delete things out of device map
keys = self.device_messages.keys()
i = self.device_messages.bisect_left(position_to_delete)
@@ -204,13 +196,6 @@ class FederationRemoteSendQueue(object):
self.notifier.on_new_replication_data()
- def send_failure(self, failure, destination):
- """As per TransactionQueue"""
- pos = self._next_pos()
-
- self.failures[pos] = (destination, str(failure))
- self.notifier.on_new_replication_data()
-
def send_device_messages(self, destination):
"""As per TransactionQueue"""
pos = self._next_pos()
@@ -285,17 +270,6 @@ class FederationRemoteSendQueue(object):
for (pos, edu) in edus:
rows.append((pos, EduRow(edu)))
- # Fetch changed failures
- i = self.failures.bisect_right(from_token)
- j = self.failures.bisect_right(to_token) + 1
- failures = self.failures.items()[i:j]
-
- for (pos, (destination, failure)) in failures:
- rows.append((pos, FailureRow(
- destination=destination,
- failure=failure,
- )))
-
# Fetch changed device messages
i = self.device_messages.bisect_right(from_token)
j = self.device_messages.bisect_right(to_token) + 1
@@ -417,34 +391,6 @@ class EduRow(BaseFederationRow, namedtuple("EduRow", (
buff.edus.setdefault(self.edu.destination, []).append(self.edu)
-class FailureRow(BaseFederationRow, namedtuple("FailureRow", (
- "destination", # str
- "failure",
-))):
- """Streams failures to a remote server. Failures are issued when there was
- something wrong with a transaction the remote sent us, e.g. it included
- an event that was invalid.
- """
-
- TypeId = "f"
-
- @staticmethod
- def from_data(data):
- return FailureRow(
- destination=data["destination"],
- failure=data["failure"],
- )
-
- def to_data(self):
- return {
- "destination": self.destination,
- "failure": self.failure,
- }
-
- def add_to_buffer(self, buff):
- buff.failures.setdefault(self.destination, []).append(self.failure)
-
-
class DeviceRow(BaseFederationRow, namedtuple("DeviceRow", (
"destination", # str
))):
@@ -471,7 +417,6 @@ TypeToRow = {
PresenceRow,
KeyedEduRow,
EduRow,
- FailureRow,
DeviceRow,
)
}
@@ -481,7 +426,6 @@ ParsedFederationStreamData = namedtuple("ParsedFederationStreamData", (
"presence", # list(UserPresenceState)
"keyed_edus", # dict of destination -> { key -> Edu }
"edus", # dict of destination -> [Edu]
- "failures", # dict of destination -> [failures]
"device_destinations", # set of destinations
))
@@ -503,7 +447,6 @@ def process_rows_for_federation(transaction_queue, rows):
presence=[],
keyed_edus={},
edus={},
- failures={},
device_destinations=set(),
)
@@ -532,9 +475,5 @@ def process_rows_for_federation(transaction_queue, rows):
edu.destination, edu.edu_type, edu.content, key=None,
)
- for destination, failure_list in iteritems(buff.failures):
- for failure in failure_list:
- transaction_queue.send_failure(destination, failure)
-
for destination in buff.device_destinations:
transaction_queue.send_device_messages(destination)
diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py
index 6996d6b695..78f9d40a3a 100644
--- a/synapse/federation/transaction_queue.py
+++ b/synapse/federation/transaction_queue.py
@@ -116,9 +116,6 @@ class TransactionQueue(object):
),
)
- # destination -> list of tuple(failure, deferred)
- self.pending_failures_by_dest = {}
-
# destination -> stream_id of last successfully sent to-device message.
# NB: may be a long or an int.
self.last_device_stream_id_by_dest = {}
@@ -382,19 +379,6 @@ class TransactionQueue(object):
self._attempt_new_transaction(destination)
- def send_failure(self, failure, destination):
- if destination == self.server_name or destination == "localhost":
- return
-
- if not self.can_send_to(destination):
- return
-
- self.pending_failures_by_dest.setdefault(
- destination, []
- ).append(failure)
-
- self._attempt_new_transaction(destination)
-
def send_device_messages(self, destination):
if destination == self.server_name or destination == "localhost":
return
@@ -469,7 +453,6 @@ class TransactionQueue(object):
pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
pending_edus = self.pending_edus_by_dest.pop(destination, [])
pending_presence = self.pending_presence_by_dest.pop(destination, {})
- pending_failures = self.pending_failures_by_dest.pop(destination, [])
pending_edus.extend(
self.pending_edus_keyed_by_dest.pop(destination, {}).values()
@@ -497,7 +480,7 @@ class TransactionQueue(object):
logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
destination, len(pending_pdus))
- if not pending_pdus and not pending_edus and not pending_failures:
+ if not pending_pdus and not pending_edus:
logger.debug("TX [%s] Nothing to send", destination)
self.last_device_stream_id_by_dest[destination] = (
device_stream_id
@@ -507,7 +490,7 @@ class TransactionQueue(object):
# END CRITICAL SECTION
success = yield self._send_new_transaction(
- destination, pending_pdus, pending_edus, pending_failures,
+ destination, pending_pdus, pending_edus,
)
if success:
sent_transactions_counter.inc()
@@ -584,14 +567,12 @@ class TransactionQueue(object):
@measure_func("_send_new_transaction")
@defer.inlineCallbacks
- def _send_new_transaction(self, destination, pending_pdus, pending_edus,
- pending_failures):
+ def _send_new_transaction(self, destination, pending_pdus, pending_edus):
# Sort based on the order field
pending_pdus.sort(key=lambda t: t[1])
pdus = [x[0] for x in pending_pdus]
edus = pending_edus
- failures = [x.get_dict() for x in pending_failures]
success = True
@@ -601,11 +582,10 @@ class TransactionQueue(object):
logger.debug(
"TX [%s] {%s} Attempting new transaction"
- " (pdus: %d, edus: %d, failures: %d)",
+ " (pdus: %d, edus: %d)",
destination, txn_id,
len(pdus),
len(edus),
- len(failures)
)
logger.debug("TX [%s] Persisting transaction...", destination)
@@ -617,7 +597,6 @@ class TransactionQueue(object):
destination=destination,
pdus=pdus,
edus=edus,
- pdu_failures=failures,
)
self._next_txn_id += 1
@@ -627,12 +606,11 @@ class TransactionQueue(object):
logger.debug("TX [%s] Persisted transaction", destination)
logger.info(
"TX [%s] {%s} Sending transaction [%s],"
- " (PDUs: %d, EDUs: %d, failures: %d)",
+ " (PDUs: %d, EDUs: %d)",
destination, txn_id,
transaction.transaction_id,
len(pdus),
len(edus),
- len(failures),
)
# Actually send the transaction
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 8574898f0c..eae5f2b427 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -165,7 +165,7 @@ def _parse_auth_header(header_bytes):
param_dict = dict(kv.split("=") for kv in params)
def strip_quotes(value):
- if value.startswith(b"\""):
+ if value.startswith("\""):
return value[1:-1]
else:
return value
@@ -283,11 +283,10 @@ class FederationSendServlet(BaseFederationServlet):
)
logger.info(
- "Received txn %s from %s. (PDUs: %d, EDUs: %d, failures: %d)",
+ "Received txn %s from %s. (PDUs: %d, EDUs: %d)",
transaction_id, origin,
len(transaction_data.get("pdus", [])),
len(transaction_data.get("edus", [])),
- len(transaction_data.get("failures", [])),
)
# We should ideally be getting this from the security layer.
diff --git a/synapse/federation/units.py b/synapse/federation/units.py
index bb1b3b13f7..c5ab14314e 100644
--- a/synapse/federation/units.py
+++ b/synapse/federation/units.py
@@ -73,7 +73,6 @@ class Transaction(JsonEncodedObject):
"previous_ids",
"pdus",
"edus",
- "pdu_failures",
]
internal_keys = [
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 402e44cdef..184eef09d0 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -15,6 +15,7 @@
# limitations under the License.
import logging
+import unicodedata
import attr
import bcrypt
@@ -519,6 +520,7 @@ class AuthHandler(BaseHandler):
"""
logger.info("Logging in user %s on device %s", user_id, device_id)
access_token = yield self.issue_access_token(user_id, device_id)
+ yield self._check_mau_limits()
# the device *should* have been registered before we got here; however,
# it's possible we raced against a DELETE operation. The thing we
@@ -626,6 +628,7 @@ class AuthHandler(BaseHandler):
# special case to check for "password" for the check_password interface
# for the auth providers
password = login_submission.get("password")
+
if login_type == LoginType.PASSWORD:
if not self._password_enabled:
raise SynapseError(400, "Password login has been disabled.")
@@ -707,9 +710,10 @@ class AuthHandler(BaseHandler):
multiple inexact matches.
Args:
- user_id (str): complete @user:id
+ user_id (unicode): complete @user:id
+ password (unicode): the provided password
Returns:
- (str) the canonical_user_id, or None if unknown user / bad password
+ (unicode) the canonical_user_id, or None if unknown user / bad password
"""
lookupres = yield self._find_user_id_and_pwd_hash(user_id)
if not lookupres:
@@ -728,15 +732,18 @@ class AuthHandler(BaseHandler):
device_id)
defer.returnValue(access_token)
+ @defer.inlineCallbacks
def validate_short_term_login_token_and_get_user_id(self, login_token):
+ yield self._check_mau_limits()
auth_api = self.hs.get_auth()
+ user_id = None
try:
macaroon = pymacaroons.Macaroon.deserialize(login_token)
user_id = auth_api.get_user_id_from_macaroon(macaroon)
auth_api.validate_macaroon(macaroon, "login", True, user_id)
- return user_id
except Exception:
raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN)
+ defer.returnValue(user_id)
@defer.inlineCallbacks
def delete_access_token(self, access_token):
@@ -849,14 +856,19 @@ class AuthHandler(BaseHandler):
"""Computes a secure hash of password.
Args:
- password (str): Password to hash.
+ password (unicode): Password to hash.
Returns:
- Deferred(str): Hashed password.
+ Deferred(unicode): Hashed password.
"""
def _do_hash():
- return bcrypt.hashpw(password.encode('utf8') + self.hs.config.password_pepper,
- bcrypt.gensalt(self.bcrypt_rounds))
+ # Normalise the Unicode in the password
+ pw = unicodedata.normalize("NFKC", password)
+
+ return bcrypt.hashpw(
+ pw.encode('utf8') + self.hs.config.password_pepper.encode("utf8"),
+ bcrypt.gensalt(self.bcrypt_rounds),
+ ).decode('ascii')
return make_deferred_yieldable(
threads.deferToThreadPool(
@@ -868,16 +880,19 @@ class AuthHandler(BaseHandler):
"""Validates that self.hash(password) == stored_hash.
Args:
- password (str): Password to hash.
- stored_hash (str): Expected hash value.
+ password (unicode): Password to hash.
+ stored_hash (unicode): Expected hash value.
Returns:
Deferred(bool): Whether self.hash(password) == stored_hash.
"""
def _do_validate_hash():
+ # Normalise the Unicode in the password
+ pw = unicodedata.normalize("NFKC", password)
+
return bcrypt.checkpw(
- password.encode('utf8') + self.hs.config.password_pepper,
+ pw.encode('utf8') + self.hs.config.password_pepper.encode("utf8"),
stored_hash.encode('utf8')
)
@@ -892,6 +907,19 @@ class AuthHandler(BaseHandler):
else:
return defer.succeed(False)
+ @defer.inlineCallbacks
+ def _check_mau_limits(self):
+ """
+ Ensure that if mau blocking is enabled that invalid users cannot
+ log in.
+ """
+ if self.hs.config.limit_usage_by_mau is True:
+ current_mau = yield self.store.count_monthly_users()
+ if current_mau >= self.hs.config.max_mau_value:
+ raise AuthError(
+ 403, "MAU Limit Exceeded", errcode=Codes.MAU_LIMIT_EXCEEDED
+ )
+
@attr.s
class MacaroonGenerator(object):
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 49068c06d9..91d8def08b 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -76,7 +76,7 @@ class FederationHandler(BaseHandler):
self.hs = hs
self.store = hs.get_datastore()
- self.replication_layer = hs.get_federation_client()
+ self.federation_client = hs.get_federation_client()
self.state_handler = hs.get_state_handler()
self.server_name = hs.hostname
self.keyring = hs.get_keyring()
@@ -255,7 +255,7 @@ class FederationHandler(BaseHandler):
# know about
for p in prevs - seen:
state, got_auth_chain = (
- yield self.replication_layer.get_state_for_room(
+ yield self.federation_client.get_state_for_room(
origin, pdu.room_id, p
)
)
@@ -338,7 +338,7 @@ class FederationHandler(BaseHandler):
#
# see https://github.com/matrix-org/synapse/pull/1744
- missing_events = yield self.replication_layer.get_missing_events(
+ missing_events = yield self.federation_client.get_missing_events(
origin,
pdu.room_id,
earliest_events_ids=list(latest),
@@ -522,7 +522,7 @@ class FederationHandler(BaseHandler):
if dest == self.server_name:
raise SynapseError(400, "Can't backfill from self.")
- events = yield self.replication_layer.backfill(
+ events = yield self.federation_client.backfill(
dest,
room_id,
limit=limit,
@@ -570,7 +570,7 @@ class FederationHandler(BaseHandler):
state_events = {}
events_to_state = {}
for e_id in edges:
- state, auth = yield self.replication_layer.get_state_for_room(
+ state, auth = yield self.federation_client.get_state_for_room(
destination=dest,
room_id=room_id,
event_id=e_id
@@ -612,7 +612,7 @@ class FederationHandler(BaseHandler):
results = yield logcontext.make_deferred_yieldable(defer.gatherResults(
[
logcontext.run_in_background(
- self.replication_layer.get_pdu,
+ self.federation_client.get_pdu,
[dest],
event_id,
outlier=True,
@@ -893,7 +893,7 @@ class FederationHandler(BaseHandler):
Invites must be signed by the invitee's server before distribution.
"""
- pdu = yield self.replication_layer.send_invite(
+ pdu = yield self.federation_client.send_invite(
destination=target_host,
room_id=event.room_id,
event_id=event.event_id,
@@ -955,7 +955,7 @@ class FederationHandler(BaseHandler):
target_hosts.insert(0, origin)
except ValueError:
pass
- ret = yield self.replication_layer.send_join(target_hosts, event)
+ ret = yield self.federation_client.send_join(target_hosts, event)
origin = ret["origin"]
state = ret["state"]
@@ -1211,7 +1211,7 @@ class FederationHandler(BaseHandler):
except ValueError:
pass
- yield self.replication_layer.send_leave(
+ yield self.federation_client.send_leave(
target_hosts,
event
)
@@ -1234,7 +1234,7 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks
def _make_and_verify_event(self, target_hosts, room_id, user_id, membership,
content={},):
- origin, pdu = yield self.replication_layer.make_membership_event(
+ origin, pdu = yield self.federation_client.make_membership_event(
target_hosts,
room_id,
user_id,
@@ -1567,7 +1567,7 @@ class FederationHandler(BaseHandler):
missing_auth_events.add(e_id)
for e_id in missing_auth_events:
- m_ev = yield self.replication_layer.get_pdu(
+ m_ev = yield self.federation_client.get_pdu(
[origin],
e_id,
outlier=True,
@@ -1777,7 +1777,7 @@ class FederationHandler(BaseHandler):
logger.info("Missing auth: %s", missing_auth)
# If we don't have all the auth events, we need to get them.
try:
- remote_auth_chain = yield self.replication_layer.get_event_auth(
+ remote_auth_chain = yield self.federation_client.get_event_auth(
origin, event.room_id, event.event_id
)
@@ -1893,7 +1893,7 @@ class FederationHandler(BaseHandler):
try:
# 2. Get remote difference.
- result = yield self.replication_layer.query_auth(
+ result = yield self.federation_client.query_auth(
origin,
event.room_id,
event.event_id,
@@ -2192,7 +2192,7 @@ class FederationHandler(BaseHandler):
yield member_handler.send_membership_event(None, event, context)
else:
destinations = set(x.split(":", 1)[-1] for x in (sender_user_id, room_id))
- yield self.replication_layer.forward_third_party_invite(
+ yield self.federation_client.forward_third_party_invite(
destinations,
room_id,
event_dict,
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 7caff0cbc8..289704b241 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -45,7 +45,7 @@ class RegistrationHandler(BaseHandler):
hs (synapse.server.HomeServer):
"""
super(RegistrationHandler, self).__init__(hs)
-
+ self.hs = hs
self.auth = hs.get_auth()
self._auth_handler = hs.get_auth_handler()
self.profile_handler = hs.get_profile_handler()
@@ -131,7 +131,7 @@ class RegistrationHandler(BaseHandler):
Args:
localpart : The local part of the user ID to register. If None,
one will be generated.
- password (str) : The password to assign to this user so they can
+ password (unicode) : The password to assign to this user so they can
login again. This can be None which means they cannot login again
via a password (e.g. the user is an application service user).
generate_token (bool): Whether a new access token should be
@@ -144,6 +144,7 @@ class RegistrationHandler(BaseHandler):
Raises:
RegistrationError if there was a problem registering.
"""
+ yield self._check_mau_limits()
password_hash = None
if password:
password_hash = yield self.auth_handler().hash(password)
@@ -288,6 +289,7 @@ class RegistrationHandler(BaseHandler):
400,
"User ID can only contain characters a-z, 0-9, or '=_-./'",
)
+ yield self._check_mau_limits()
user = UserID(localpart, self.hs.hostname)
user_id = user.to_string()
@@ -437,7 +439,7 @@ class RegistrationHandler(BaseHandler):
"""
if localpart is None:
raise SynapseError(400, "Request must include user id")
-
+ yield self._check_mau_limits()
need_register = True
try:
@@ -531,3 +533,16 @@ class RegistrationHandler(BaseHandler):
remote_room_hosts=remote_room_hosts,
action="join",
)
+
+ @defer.inlineCallbacks
+ def _check_mau_limits(self):
+ """
+ Do not accept registrations if monthly active user limits exceeded
+ and limiting is enabled
+ """
+ if self.hs.config.limit_usage_by_mau is True:
+ current_mau = yield self.store.count_monthly_users()
+ if current_mau >= self.hs.config.max_mau_value:
+ raise RegistrationError(
+ 403, "MAU Limit Exceeded", Codes.MAU_LIMIT_EXCEEDED
+ )
diff --git a/synapse/http/server.py b/synapse/http/server.py
index c70fdbdfd2..1940c1c4f4 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -13,12 +13,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
import cgi
import collections
import logging
-import urllib
-from six.moves import http_client
+from six import PY3
+from six.moves import http_client, urllib
from canonicaljson import encode_canonical_json, encode_pretty_printed_json, json
@@ -264,6 +265,7 @@ class JsonResource(HttpServer, resource.Resource):
self.hs = hs
def register_paths(self, method, path_patterns, callback):
+ method = method.encode("utf-8") # method is bytes on py3
for path_pattern in path_patterns:
logger.debug("Registering for %s %s", method, path_pattern.pattern)
self.path_regexs.setdefault(method, []).append(
@@ -296,8 +298,19 @@ class JsonResource(HttpServer, resource.Resource):
# here. If it throws an exception, that is handled by the wrapper
# installed by @request_handler.
+ def _unquote(s):
+ if PY3:
+ # On Python 3, unquote is unicode -> unicode
+ return urllib.parse.unquote(s)
+ else:
+ # On Python 2, unquote is bytes -> bytes We need to encode the
+ # URL again (as it was decoded by _get_handler_for request), as
+ # ASCII because it's a URL, and then decode it to get the UTF-8
+ # characters that were quoted.
+ return urllib.parse.unquote(s.encode('ascii')).decode('utf8')
+
kwargs = intern_dict({
- name: urllib.unquote(value).decode("UTF-8") if value else value
+ name: _unquote(value) if value else value
for name, value in group_dict.items()
})
@@ -313,9 +326,9 @@ class JsonResource(HttpServer, resource.Resource):
request (twisted.web.http.Request):
Returns:
- Tuple[Callable, dict[str, str]]: callback method, and the dict
- mapping keys to path components as specified in the handler's
- path match regexp.
+ Tuple[Callable, dict[unicode, unicode]]: callback method, and the
+ dict mapping keys to path components as specified in the
+ handler's path match regexp.
The callback will normally be a method registered via
register_paths, so will return (possibly via Deferred) either
@@ -327,7 +340,7 @@ class JsonResource(HttpServer, resource.Resource):
# Loop through all the registered callbacks to check if the method
# and path regex match
for path_entry in self.path_regexs.get(request.method, []):
- m = path_entry.pattern.match(request.path)
+ m = path_entry.pattern.match(request.path.decode('ascii'))
if m:
# We found a match!
return path_entry.callback, m.groupdict()
@@ -383,7 +396,7 @@ class RootRedirect(resource.Resource):
self.url = path
def render_GET(self, request):
- return redirectTo(self.url, request)
+ return redirectTo(self.url.encode('ascii'), request)
def getChild(self, name, request):
if len(name) == 0:
@@ -404,12 +417,14 @@ def respond_with_json(request, code, json_object, send_cors=False,
return
if pretty_print:
- json_bytes = encode_pretty_printed_json(json_object) + "\n"
+ json_bytes = (encode_pretty_printed_json(json_object) + "\n"
+ ).encode("utf-8")
else:
if canonical_json or synapse.events.USE_FROZEN_DICTS:
+ # canonicaljson already encodes to bytes
json_bytes = encode_canonical_json(json_object)
else:
- json_bytes = json.dumps(json_object)
+ json_bytes = json.dumps(json_object).encode("utf-8")
return respond_with_json_bytes(
request, code, json_bytes,
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index 882816dc8f..69f7085291 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -171,8 +171,16 @@ def parse_json_value_from_request(request, allow_empty_body=False):
if not content_bytes and allow_empty_body:
return None
+ # Decode to Unicode so that simplejson will return Unicode strings on
+ # Python 2
try:
- content = json.loads(content_bytes)
+ content_unicode = content_bytes.decode('utf8')
+ except UnicodeDecodeError:
+ logger.warn("Unable to decode UTF-8")
+ raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
+
+ try:
+ content = json.loads(content_unicode)
except Exception as e:
logger.warn("Unable to parse JSON: %s", e)
raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
diff --git a/synapse/rest/client/v1/admin.py b/synapse/rest/client/v1/admin.py
index 99f6c6e3c3..80d625eecc 100644
--- a/synapse/rest/client/v1/admin.py
+++ b/synapse/rest/client/v1/admin.py
@@ -18,6 +18,7 @@ import hashlib
import hmac
import logging
+from six import text_type
from six.moves import http_client
from twisted.internet import defer
@@ -131,7 +132,10 @@ class UserRegisterServlet(ClientV1RestServlet):
400, "username must be specified", errcode=Codes.BAD_JSON,
)
else:
- if (not isinstance(body['username'], str) or len(body['username']) > 512):
+ if (
+ not isinstance(body['username'], text_type)
+ or len(body['username']) > 512
+ ):
raise SynapseError(400, "Invalid username")
username = body["username"].encode("utf-8")
@@ -143,7 +147,10 @@ class UserRegisterServlet(ClientV1RestServlet):
400, "password must be specified", errcode=Codes.BAD_JSON,
)
else:
- if (not isinstance(body['password'], str) or len(body['password']) > 512):
+ if (
+ not isinstance(body['password'], text_type)
+ or len(body['password']) > 512
+ ):
raise SynapseError(400, "Invalid password")
password = body["password"].encode("utf-8")
@@ -166,17 +173,18 @@ class UserRegisterServlet(ClientV1RestServlet):
want_mac.update(b"admin" if admin else b"notadmin")
want_mac = want_mac.hexdigest()
- if not hmac.compare_digest(want_mac, got_mac):
- raise SynapseError(
- 403, "HMAC incorrect",
- )
+ if not hmac.compare_digest(want_mac, got_mac.encode('ascii')):
+ raise SynapseError(403, "HMAC incorrect")
# Reuse the parts of RegisterRestServlet to reduce code duplication
from synapse.rest.client.v2_alpha.register import RegisterRestServlet
+
register = RegisterRestServlet(self.hs)
(user_id, _) = yield register.registration_handler.register(
- localpart=username.lower(), password=password, admin=bool(admin),
+ localpart=body['username'].lower(),
+ password=body["password"],
+ admin=bool(admin),
generate_token=False,
)
diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py
index 69dcd618cb..97733f3026 100644
--- a/synapse/rest/client/v1/directory.py
+++ b/synapse/rest/client/v1/directory.py
@@ -18,7 +18,7 @@ import logging
from twisted.internet import defer
-from synapse.api.errors import AuthError, Codes, SynapseError
+from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
from synapse.http.servlet import parse_json_object_from_request
from synapse.types import RoomAlias
@@ -159,7 +159,7 @@ class ClientDirectoryListServer(ClientV1RestServlet):
def on_GET(self, request, room_id):
room = yield self.store.get_room(room_id)
if room is None:
- raise SynapseError(400, "Unknown room")
+ raise NotFoundError("Unknown room")
defer.returnValue((200, {
"visibility": "public" if room["is_public"] else "private"
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index d6cf915d86..2f64155d13 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -193,15 +193,15 @@ class RegisterRestServlet(RestServlet):
def on_POST(self, request):
body = parse_json_object_from_request(request)
- kind = "user"
- if "kind" in request.args:
- kind = request.args["kind"][0]
+ kind = b"user"
+ if b"kind" in request.args:
+ kind = request.args[b"kind"][0]
- if kind == "guest":
+ if kind == b"guest":
ret = yield self._do_guest_registration(body)
defer.returnValue(ret)
return
- elif kind != "user":
+ elif kind != b"user":
raise UnrecognizedRequestError(
"Do not understand membership kind: %s" % (kind,)
)
@@ -389,8 +389,8 @@ class RegisterRestServlet(RestServlet):
assert_params_in_dict(params, ["password"])
desired_username = params.get("username", None)
- new_password = params.get("password", None)
guest_access_token = params.get("guest_access_token", None)
+ new_password = params.get("password", None)
if desired_username is not None:
desired_username = desired_username.lower()
diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py
index b25993fcb5..a6189224ee 100644
--- a/synapse/rest/media/v1/media_storage.py
+++ b/synapse/rest/media/v1/media_storage.py
@@ -177,7 +177,7 @@ class MediaStorage(object):
if res:
with res:
consumer = BackgroundFileConsumer(
- open(local_path, "w"), self.hs.get_reactor())
+ open(local_path, "wb"), self.hs.get_reactor())
yield res.write_to_consumer(consumer)
yield consumer.wait()
defer.returnValue(local_path)
diff --git a/synapse/state.py b/synapse/state.py
index 033f55d967..e1092b97a9 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -577,7 +577,7 @@ def _make_state_cache_entry(
def _ordered_events(events):
def key_func(e):
- return -int(e.depth), hashlib.sha1(e.event_id.encode()).hexdigest()
+ return -int(e.depth), hashlib.sha1(e.event_id.encode('ascii')).hexdigest()
return sorted(events, key=key_func)
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index ba88a54979..134e4a80f1 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -66,6 +66,7 @@ class DataStore(RoomMemberStore, RoomStore,
PresenceStore, TransactionStore,
DirectoryStore, KeyStore, StateStore, SignatureStore,
ApplicationServiceStore,
+ EventsStore,
EventFederationStore,
MediaRepositoryStore,
RejectionsStore,
@@ -73,7 +74,6 @@ class DataStore(RoomMemberStore, RoomStore,
PusherStore,
PushRuleStore,
ApplicationServiceTransactionStore,
- EventsStore,
ReceiptsStore,
EndToEndKeyStore,
SearchStore,
@@ -94,6 +94,7 @@ class DataStore(RoomMemberStore, RoomStore,
self._clock = hs.get_clock()
self.database_engine = hs.database_engine
+ self.db_conn = db_conn
self._stream_id_gen = StreamIdGenerator(
db_conn, "events", "stream_ordering",
extra_tables=[("local_invites", "stream_id")]
@@ -266,6 +267,31 @@ class DataStore(RoomMemberStore, RoomStore,
return self.runInteraction("count_users", _count_users)
+ def count_monthly_users(self):
+ """Counts the number of users who used this homeserver in the last 30 days
+
+ This method should be refactored with count_daily_users - the only
+ reason not to is waiting on definition of mau
+
+ Returns:
+ Defered[int]
+ """
+ def _count_monthly_users(txn):
+ thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30)
+ sql = """
+ SELECT COALESCE(count(*), 0) FROM (
+ SELECT user_id FROM user_ips
+ WHERE last_seen > ?
+ GROUP BY user_id
+ ) u
+ """
+
+ txn.execute(sql, (thirty_days_ago,))
+ count, = txn.fetchone()
+ return count
+
+ return self.runInteraction("count_monthly_users", _count_monthly_users)
+
def count_r30_users(self):
"""
Counts the number of 30 day retained users, defined as:-
diff --git a/synapse/storage/appservice.py b/synapse/storage/appservice.py
index 9f12b360bc..31248d5e06 100644
--- a/synapse/storage/appservice.py
+++ b/synapse/storage/appservice.py
@@ -22,7 +22,7 @@ from twisted.internet import defer
from synapse.appservice import AppServiceTransaction
from synapse.config.appservice import load_appservices
-from synapse.storage.events import EventsWorkerStore
+from synapse.storage.events_worker import EventsWorkerStore
from ._base import SQLBaseStore
diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py
index 5d3ee90017..8bd35df119 100644
--- a/synapse/storage/event_federation.py
+++ b/synapse/storage/event_federation.py
@@ -25,7 +25,7 @@ from twisted.internet import defer
from synapse.api.errors import StoreError
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import SQLBaseStore
-from synapse.storage.events import EventsWorkerStore
+from synapse.storage.events_worker import EventsWorkerStore
from synapse.storage.signatures import SignatureWorkerStore
from synapse.util.caches.descriptors import cached
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 2f482af3a1..61223da1a5 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -34,6 +34,8 @@ from synapse.api.errors import SynapseError
from synapse.events import EventBase # noqa: F401
from synapse.events.snapshot import EventContext # noqa: F401
from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage.background_updates import BackgroundUpdateStore
+from synapse.storage.event_federation import EventFederationStore
from synapse.storage.events_worker import EventsWorkerStore
from synapse.types import RoomStreamToken, get_domain_from_id
from synapse.util.async import ObservableDeferred
@@ -65,7 +67,13 @@ state_delta_reuse_delta_counter = Counter(
def encode_json(json_object):
- return frozendict_json_encoder.encode(json_object)
+ """
+ Encode a Python object as JSON and return it in a Unicode string.
+ """
+ out = frozendict_json_encoder.encode(json_object)
+ if isinstance(out, bytes):
+ out = out.decode('utf8')
+ return out
class _EventPeristenceQueue(object):
@@ -193,7 +201,9 @@ def _retry_on_integrity_error(func):
return f
-class EventsStore(EventsWorkerStore):
+# inherits from EventFederationStore so that we can call _update_backward_extremities
+# and _handle_mult_prev_events (though arguably those could both be moved in here)
+class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore):
EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts"
EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url"
@@ -1054,7 +1064,7 @@ class EventsStore(EventsWorkerStore):
metadata_json = encode_json(
event.internal_metadata.get_dict()
- ).decode("UTF-8")
+ )
sql = (
"UPDATE event_json SET internal_metadata = ?"
@@ -1168,8 +1178,8 @@ class EventsStore(EventsWorkerStore):
"room_id": event.room_id,
"internal_metadata": encode_json(
event.internal_metadata.get_dict()
- ).decode("UTF-8"),
- "json": encode_json(event_dict(event)).decode("UTF-8"),
+ ),
+ "json": encode_json(event_dict(event)),
}
for event, _ in events_and_contexts
],
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 027bf8c85e..10dce21cea 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -24,7 +24,7 @@ from canonicaljson import json
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
-from synapse.storage.events import EventsWorkerStore
+from synapse.storage.events_worker import EventsWorkerStore
from synapse.types import get_domain_from_id
from synapse.util.async import Linearizer
from synapse.util.caches import intern_string
diff --git a/synapse/storage/signatures.py b/synapse/storage/signatures.py
index 470212aa2a..5623391f6e 100644
--- a/synapse/storage/signatures.py
+++ b/synapse/storage/signatures.py
@@ -74,7 +74,7 @@ class SignatureWorkerStore(SQLBaseStore):
txn (cursor):
event_id (str): Id for the Event.
Returns:
- A dict of algorithm -> hash.
+ A dict[unicode, bytes] of algorithm -> hash.
"""
query = (
"SELECT algorithm, hash"
diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py
index 25d0097b58..b9f2b74ac6 100644
--- a/synapse/storage/stream.py
+++ b/synapse/storage/stream.py
@@ -43,7 +43,7 @@ from twisted.internet import defer
from synapse.storage._base import SQLBaseStore
from synapse.storage.engines import PostgresEngine
-from synapse.storage.events import EventsWorkerStore
+from synapse.storage.events_worker import EventsWorkerStore
from synapse.types import RoomStreamToken
from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.logcontext import make_deferred_yieldable, run_in_background
diff --git a/synapse/types.py b/synapse/types.py
index 08f058f714..41afb27a74 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -137,7 +137,7 @@ class DomainSpecificString(
@classmethod
def from_string(cls, s):
"""Parse the string given by 's' into a structure object."""
- if len(s) < 1 or s[0] != cls.SIGIL:
+ if len(s) < 1 or s[0:1] != cls.SIGIL:
raise SynapseError(400, "Expected %s string to start with '%s'" % (
cls.__name__, cls.SIGIL,
))
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index f8a07df6b8..861c24809c 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -473,105 +473,101 @@ class CacheListDescriptor(_CacheDescriptorBase):
@functools.wraps(self.orig)
def wrapped(*args, **kwargs):
- # If we're passed a cache_context then we'll want to call its invalidate()
- # whenever we are invalidated
+ # If we're passed a cache_context then we'll want to call its
+ # invalidate() whenever we are invalidated
invalidate_callback = kwargs.pop("on_invalidate", None)
arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
list_args = arg_dict[self.list_name]
- # cached is a dict arg -> deferred, where deferred results in a
- # 2-tuple (`arg`, `result`)
results = {}
- cached_defers = {}
- missing = []
+
+ def update_results_dict(res, arg):
+ results[arg] = res
+
+ # list of deferreds to wait for
+ cached_defers = []
+
+ missing = set()
# If the cache takes a single arg then that is used as the key,
# otherwise a tuple is used.
if num_args == 1:
- def cache_get(arg):
- return cache.get(arg, callback=invalidate_callback)
+ def arg_to_cache_key(arg):
+ return arg
else:
- key = list(keyargs)
+ keylist = list(keyargs)
- def cache_get(arg):
- key[self.list_pos] = arg
- return cache.get(tuple(key), callback=invalidate_callback)
+ def arg_to_cache_key(arg):
+ keylist[self.list_pos] = arg
+ return tuple(keylist)
for arg in list_args:
try:
- res = cache_get(arg)
-
+ res = cache.get(arg_to_cache_key(arg),
+ callback=invalidate_callback)
if not isinstance(res, ObservableDeferred):
results[arg] = res
elif not res.has_succeeded():
res = res.observe()
- res.addCallback(lambda r, arg: (arg, r), arg)
- cached_defers[arg] = res
+ res.addCallback(update_results_dict, arg)
+ cached_defers.append(res)
else:
results[arg] = res.get_result()
except KeyError:
- missing.append(arg)
+ missing.add(arg)
if missing:
+ # we need an observable deferred for each entry in the list,
+ # which we put in the cache. Each deferred resolves with the
+ # relevant result for that key.
+ deferreds_map = {}
+ for arg in missing:
+ deferred = defer.Deferred()
+ deferreds_map[arg] = deferred
+ key = arg_to_cache_key(arg)
+ observable = ObservableDeferred(deferred)
+ cache.set(key, observable, callback=invalidate_callback)
+
+ def complete_all(res):
+ # the wrapped function has completed. It returns a
+ # a dict. We can now resolve the observable deferreds in
+ # the cache and update our own result map.
+ for e in missing:
+ val = res.get(e, None)
+ deferreds_map[e].callback(val)
+ results[e] = val
+
+ def errback(f):
+ # the wrapped function has failed. Invalidate any cache
+ # entries we're supposed to be populating, and fail
+ # their deferreds.
+ for e in missing:
+ key = arg_to_cache_key(e)
+ cache.invalidate(key)
+ deferreds_map[e].errback(f)
+
+ # return the failure, to propagate to our caller.
+ return f
+
args_to_call = dict(arg_dict)
- args_to_call[self.list_name] = missing
+ args_to_call[self.list_name] = list(missing)
- ret_d = defer.maybeDeferred(
+ cached_defers.append(defer.maybeDeferred(
logcontext.preserve_fn(self.function_to_call),
**args_to_call
- )
-
- ret_d = ObservableDeferred(ret_d)
-
- # We need to create deferreds for each arg in the list so that
- # we can insert the new deferred into the cache.
- for arg in missing:
- observer = ret_d.observe()
- observer.addCallback(lambda r, arg: r.get(arg, None), arg)
-
- observer = ObservableDeferred(observer)
-
- if num_args == 1:
- cache.set(
- arg, observer,
- callback=invalidate_callback
- )
-
- def invalidate(f, key):
- cache.invalidate(key)
- return f
- observer.addErrback(invalidate, arg)
- else:
- key = list(keyargs)
- key[self.list_pos] = arg
- cache.set(
- tuple(key), observer,
- callback=invalidate_callback
- )
-
- def invalidate(f, key):
- cache.invalidate(key)
- return f
- observer.addErrback(invalidate, tuple(key))
-
- res = observer.observe()
- res.addCallback(lambda r, arg: (arg, r), arg)
-
- cached_defers[arg] = res
+ ).addCallbacks(complete_all, errback))
if cached_defers:
- def update_results_dict(res):
- results.update(res)
- return results
-
- return logcontext.make_deferred_yieldable(defer.gatherResults(
- list(cached_defers.values()),
+ d = defer.gatherResults(
+ cached_defers,
consumeErrors=True,
- ).addCallback(update_results_dict).addErrback(
+ ).addCallbacks(
+ lambda _: results,
unwrapFirstError
- ))
+ )
+ return logcontext.make_deferred_yieldable(d)
else:
return results
@@ -625,7 +621,8 @@ def cachedList(cached_method_name, list_name, num_args=None, inlineCallbacks=Fal
cache.
Args:
- cache (Cache): The underlying cache to use.
+ cached_method_name (str): The name of the single-item lookup method.
+ This is only used to find the cache to use.
list_name (str): The name of the argument that is the list to use to
do batch lookups in the cache.
num_args (int): Number of arguments to use as the key in the cache
diff --git a/synapse/util/frozenutils.py b/synapse/util/frozenutils.py
index 581c6052ac..014edea971 100644
--- a/synapse/util/frozenutils.py
+++ b/synapse/util/frozenutils.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from six import string_types
+from six import binary_type, text_type
from canonicaljson import json
from frozendict import frozendict
@@ -26,7 +26,7 @@ def freeze(o):
if isinstance(o, frozendict):
return o
- if isinstance(o, string_types):
+ if isinstance(o, (binary_type, text_type)):
return o
try:
@@ -41,7 +41,7 @@ def unfreeze(o):
if isinstance(o, (dict, frozendict)):
return dict({k: unfreeze(v) for k, v in o.items()})
- if isinstance(o, string_types):
+ if isinstance(o, (binary_type, text_type)):
return o
try:
diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index 5f158ec4b9..a82d737e71 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -46,7 +46,7 @@ class AuthTestCase(unittest.TestCase):
self.auth = Auth(self.hs)
self.test_user = "@foo:bar"
- self.test_token = "_test_token_"
+ self.test_token = b"_test_token_"
# this is overridden for the appservice tests
self.store.get_app_service_by_token = Mock(return_value=None)
@@ -61,7 +61,7 @@ class AuthTestCase(unittest.TestCase):
self.store.get_user_by_access_token = Mock(return_value=user_info)
request = Mock(args={})
- request.args["access_token"] = [self.test_token]
+ request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
requester = yield self.auth.get_user_by_req(request)
self.assertEquals(requester.user.to_string(), self.test_user)
@@ -70,7 +70,7 @@ class AuthTestCase(unittest.TestCase):
self.store.get_user_by_access_token = Mock(return_value=None)
request = Mock(args={})
- request.args["access_token"] = [self.test_token]
+ request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
d = self.auth.get_user_by_req(request)
self.failureResultOf(d, AuthError)
@@ -98,7 +98,7 @@ class AuthTestCase(unittest.TestCase):
request = Mock(args={})
request.getClientIP.return_value = "127.0.0.1"
- request.args["access_token"] = [self.test_token]
+ request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
requester = yield self.auth.get_user_by_req(request)
self.assertEquals(requester.user.to_string(), self.test_user)
@@ -115,7 +115,7 @@ class AuthTestCase(unittest.TestCase):
request = Mock(args={})
request.getClientIP.return_value = "192.168.10.10"
- request.args["access_token"] = [self.test_token]
+ request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
requester = yield self.auth.get_user_by_req(request)
self.assertEquals(requester.user.to_string(), self.test_user)
@@ -131,7 +131,7 @@ class AuthTestCase(unittest.TestCase):
request = Mock(args={})
request.getClientIP.return_value = "131.111.8.42"
- request.args["access_token"] = [self.test_token]
+ request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
d = self.auth.get_user_by_req(request)
self.failureResultOf(d, AuthError)
@@ -141,7 +141,7 @@ class AuthTestCase(unittest.TestCase):
self.store.get_user_by_access_token = Mock(return_value=None)
request = Mock(args={})
- request.args["access_token"] = [self.test_token]
+ request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
d = self.auth.get_user_by_req(request)
self.failureResultOf(d, AuthError)
@@ -158,7 +158,7 @@ class AuthTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_get_user_by_req_appservice_valid_token_valid_user_id(self):
- masquerading_user_id = "@doppelganger:matrix.org"
+ masquerading_user_id = b"@doppelganger:matrix.org"
app_service = Mock(
token="foobar", url="a_url", sender=self.test_user,
ip_range_whitelist=None,
@@ -169,14 +169,17 @@ class AuthTestCase(unittest.TestCase):
request = Mock(args={})
request.getClientIP.return_value = "127.0.0.1"
- request.args["access_token"] = [self.test_token]
- request.args["user_id"] = [masquerading_user_id]
+ request.args[b"access_token"] = [self.test_token]
+ request.args[b"user_id"] = [masquerading_user_id]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
requester = yield self.auth.get_user_by_req(request)
- self.assertEquals(requester.user.to_string(), masquerading_user_id)
+ self.assertEquals(
+ requester.user.to_string(),
+ masquerading_user_id.decode('utf8')
+ )
def test_get_user_by_req_appservice_valid_token_bad_user_id(self):
- masquerading_user_id = "@doppelganger:matrix.org"
+ masquerading_user_id = b"@doppelganger:matrix.org"
app_service = Mock(
token="foobar", url="a_url", sender=self.test_user,
ip_range_whitelist=None,
@@ -187,8 +190,8 @@ class AuthTestCase(unittest.TestCase):
request = Mock(args={})
request.getClientIP.return_value = "127.0.0.1"
- request.args["access_token"] = [self.test_token]
- request.args["user_id"] = [masquerading_user_id]
+ request.args[b"access_token"] = [self.test_token]
+ request.args[b"user_id"] = [masquerading_user_id]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
d = self.auth.get_user_by_req(request)
self.failureResultOf(d, AuthError)
@@ -418,7 +421,7 @@ class AuthTestCase(unittest.TestCase):
# check the token works
request = Mock(args={})
- request.args["access_token"] = [token]
+ request.args[b"access_token"] = [token.encode('ascii')]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
self.assertEqual(UserID.from_string(USER_ID), requester.user)
@@ -431,7 +434,7 @@ class AuthTestCase(unittest.TestCase):
# the token should *not* work now
request = Mock(args={})
- request.args["access_token"] = [guest_tok]
+ request.args[b"access_token"] = [guest_tok.encode('ascii')]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
with self.assertRaises(AuthError) as cm:
diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py
index 2e5e8e4dec..55eab9e9cf 100644
--- a/tests/handlers/test_auth.py
+++ b/tests/handlers/test_auth.py
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from mock import Mock
import pymacaroons
@@ -19,6 +20,7 @@ from twisted.internet import defer
import synapse
import synapse.api.errors
+from synapse.api.errors import AuthError
from synapse.handlers.auth import AuthHandler
from tests import unittest
@@ -37,6 +39,10 @@ class AuthTestCase(unittest.TestCase):
self.hs.handlers = AuthHandlers(self.hs)
self.auth_handler = self.hs.handlers.auth_handler
self.macaroon_generator = self.hs.get_macaroon_generator()
+ # MAU tests
+ self.hs.config.max_mau_value = 50
+ self.small_number_of_users = 1
+ self.large_number_of_users = 100
def test_token_is_a_macaroon(self):
token = self.macaroon_generator.generate_access_token("some_user")
@@ -71,38 +77,37 @@ class AuthTestCase(unittest.TestCase):
v.satisfy_general(verify_nonce)
v.verify(macaroon, self.hs.config.macaroon_secret_key)
+ @defer.inlineCallbacks
def test_short_term_login_token_gives_user_id(self):
self.hs.clock.now = 1000
token = self.macaroon_generator.generate_short_term_login_token(
"a_user", 5000
)
-
- self.assertEqual(
- "a_user",
- self.auth_handler.validate_short_term_login_token_and_get_user_id(
- token
- )
+ user_id = yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
+ token
)
+ self.assertEqual("a_user", user_id)
# when we advance the clock, the token should be rejected
self.hs.clock.now = 6000
with self.assertRaises(synapse.api.errors.AuthError):
- self.auth_handler.validate_short_term_login_token_and_get_user_id(
+ yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
token
)
+ @defer.inlineCallbacks
def test_short_term_login_token_cannot_replace_user_id(self):
token = self.macaroon_generator.generate_short_term_login_token(
"a_user", 5000
)
macaroon = pymacaroons.Macaroon.deserialize(token)
+ user_id = yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
+ macaroon.serialize()
+ )
self.assertEqual(
- "a_user",
- self.auth_handler.validate_short_term_login_token_and_get_user_id(
- macaroon.serialize()
- )
+ "a_user", user_id
)
# add another "user_id" caveat, which might allow us to override the
@@ -110,6 +115,57 @@ class AuthTestCase(unittest.TestCase):
macaroon.add_first_party_caveat("user_id = b_user")
with self.assertRaises(synapse.api.errors.AuthError):
- self.auth_handler.validate_short_term_login_token_and_get_user_id(
+ yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
macaroon.serialize()
)
+
+ @defer.inlineCallbacks
+ def test_mau_limits_disabled(self):
+ self.hs.config.limit_usage_by_mau = False
+ # Ensure does not throw exception
+ yield self.auth_handler.get_access_token_for_user_id('user_a')
+
+ yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
+ self._get_macaroon().serialize()
+ )
+
+ @defer.inlineCallbacks
+ def test_mau_limits_exceeded(self):
+ self.hs.config.limit_usage_by_mau = True
+ self.hs.get_datastore().count_monthly_users = Mock(
+ return_value=defer.succeed(self.large_number_of_users)
+ )
+
+ with self.assertRaises(AuthError):
+ yield self.auth_handler.get_access_token_for_user_id('user_a')
+
+ self.hs.get_datastore().count_monthly_users = Mock(
+ return_value=defer.succeed(self.large_number_of_users)
+ )
+ with self.assertRaises(AuthError):
+ yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
+ self._get_macaroon().serialize()
+ )
+
+ @defer.inlineCallbacks
+ def test_mau_limits_not_exceeded(self):
+ self.hs.config.limit_usage_by_mau = True
+
+ self.hs.get_datastore().count_monthly_users = Mock(
+ return_value=defer.succeed(self.small_number_of_users)
+ )
+ # Ensure does not raise exception
+ yield self.auth_handler.get_access_token_for_user_id('user_a')
+
+ self.hs.get_datastore().count_monthly_users = Mock(
+ return_value=defer.succeed(self.small_number_of_users)
+ )
+ yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
+ self._get_macaroon().serialize()
+ )
+
+ def _get_macaroon(self):
+ token = self.macaroon_generator.generate_short_term_login_token(
+ "user_a", 5000
+ )
+ return pymacaroons.Macaroon.deserialize(token)
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index 025fa1be81..0937d71cf6 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -17,6 +17,7 @@ from mock import Mock
from twisted.internet import defer
+from synapse.api.errors import RegistrationError
from synapse.handlers.register import RegistrationHandler
from synapse.types import UserID, create_requester
@@ -77,3 +78,53 @@ class RegistrationTestCase(unittest.TestCase):
requester, local_part, display_name)
self.assertEquals(result_user_id, user_id)
self.assertEquals(result_token, 'secret')
+
+ @defer.inlineCallbacks
+ def test_cannot_register_when_mau_limits_exceeded(self):
+ local_part = "someone"
+ display_name = "someone"
+ requester = create_requester("@as:test")
+ store = self.hs.get_datastore()
+ self.hs.config.limit_usage_by_mau = False
+ self.hs.config.max_mau_value = 50
+ lots_of_users = 100
+ small_number_users = 1
+
+ store.count_monthly_users = Mock(return_value=defer.succeed(lots_of_users))
+
+ # Ensure does not throw exception
+ yield self.handler.get_or_create_user(requester, 'a', display_name)
+
+ self.hs.config.limit_usage_by_mau = True
+
+ with self.assertRaises(RegistrationError):
+ yield self.handler.get_or_create_user(requester, 'b', display_name)
+
+ store.count_monthly_users = Mock(return_value=defer.succeed(small_number_users))
+
+ self._macaroon_mock_generator("another_secret")
+
+ # Ensure does not throw exception
+ yield self.handler.get_or_create_user("@neil:matrix.org", 'c', "Neil")
+
+ self._macaroon_mock_generator("another another secret")
+ store.count_monthly_users = Mock(return_value=defer.succeed(lots_of_users))
+
+ with self.assertRaises(RegistrationError):
+ yield self.handler.register(localpart=local_part)
+
+ self._macaroon_mock_generator("another another secret")
+ store.count_monthly_users = Mock(return_value=defer.succeed(lots_of_users))
+
+ with self.assertRaises(RegistrationError):
+ yield self.handler.register_saml2(local_part)
+
+ def _macaroon_mock_generator(self, secret):
+ """
+ Reset macaroon generator in the case where the test creates multiple users
+ """
+ macaroon_generator = Mock(
+ generate_access_token=Mock(return_value=secret))
+ self.hs.get_macaroon_generator = Mock(return_value=macaroon_generator)
+ self.hs.handlers = RegistrationHandlers(self.hs)
+ self.handler = self.hs.get_handlers().registration_handler
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index b08856f763..2c263af1a3 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -44,7 +44,6 @@ def _expect_edu(destination, edu_type, content, origin="test"):
"content": content,
}
],
- "pdu_failures": [],
}
diff --git a/tests/storage/test__init__.py b/tests/storage/test__init__.py
new file mode 100644
index 0000000000..f19cb1265c
--- /dev/null
+++ b/tests/storage/test__init__.py
@@ -0,0 +1,65 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+import tests.utils
+
+
+class InitTestCase(tests.unittest.TestCase):
+ def __init__(self, *args, **kwargs):
+ super(InitTestCase, self).__init__(*args, **kwargs)
+ self.store = None # type: synapse.storage.DataStore
+
+ @defer.inlineCallbacks
+ def setUp(self):
+ hs = yield tests.utils.setup_test_homeserver()
+
+ hs.config.max_mau_value = 50
+ hs.config.limit_usage_by_mau = True
+ self.store = hs.get_datastore()
+ self.clock = hs.get_clock()
+
+ @defer.inlineCallbacks
+ def test_count_monthly_users(self):
+ count = yield self.store.count_monthly_users()
+ self.assertEqual(0, count)
+
+ yield self._insert_user_ips("@user:server1")
+ yield self._insert_user_ips("@user:server2")
+
+ count = yield self.store.count_monthly_users()
+ self.assertEqual(2, count)
+
+ @defer.inlineCallbacks
+ def _insert_user_ips(self, user):
+ """
+ Helper function to populate user_ips without using batch insertion infra
+ args:
+ user (str): specify username i.e. @user:server.com
+ """
+ yield self.store._simple_upsert(
+ table="user_ips",
+ keyvalues={
+ "user_id": user,
+ "access_token": "access_token",
+ "ip": "ip",
+ "user_agent": "user_agent",
+ "device_id": "device_id",
+ },
+ values={
+ "last_seen": self.clock.time_msec(),
+ }
+ )
diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index 8176a7dabd..ca8a7c907f 100644
--- a/tests/util/caches/test_descriptors.py
+++ b/tests/util/caches/test_descriptors.py
@@ -273,3 +273,104 @@ class DescriptorTestCase(unittest.TestCase):
r = yield obj.fn(2, 3)
self.assertEqual(r, 'chips')
obj.mock.assert_not_called()
+
+
+class CachedListDescriptorTestCase(unittest.TestCase):
+ @defer.inlineCallbacks
+ def test_cache(self):
+ class Cls(object):
+ def __init__(self):
+ self.mock = mock.Mock()
+
+ @descriptors.cached()
+ def fn(self, arg1, arg2):
+ pass
+
+ @descriptors.cachedList("fn", "args1", inlineCallbacks=True)
+ def list_fn(self, args1, arg2):
+ assert (
+ logcontext.LoggingContext.current_context().request == "c1"
+ )
+ # we want this to behave like an asynchronous function
+ yield run_on_reactor()
+ assert (
+ logcontext.LoggingContext.current_context().request == "c1"
+ )
+ defer.returnValue(self.mock(args1, arg2))
+
+ with logcontext.LoggingContext() as c1:
+ c1.request = "c1"
+ obj = Cls()
+ obj.mock.return_value = {10: 'fish', 20: 'chips'}
+ d1 = obj.list_fn([10, 20], 2)
+ self.assertEqual(
+ logcontext.LoggingContext.current_context(),
+ logcontext.LoggingContext.sentinel,
+ )
+ r = yield d1
+ self.assertEqual(
+ logcontext.LoggingContext.current_context(),
+ c1
+ )
+ obj.mock.assert_called_once_with([10, 20], 2)
+ self.assertEqual(r, {10: 'fish', 20: 'chips'})
+ obj.mock.reset_mock()
+
+ # a call with different params should call the mock again
+ obj.mock.return_value = {30: 'peas'}
+ r = yield obj.list_fn([20, 30], 2)
+ obj.mock.assert_called_once_with([30], 2)
+ self.assertEqual(r, {20: 'chips', 30: 'peas'})
+ obj.mock.reset_mock()
+
+ # all the values should now be cached
+ r = yield obj.fn(10, 2)
+ self.assertEqual(r, 'fish')
+ r = yield obj.fn(20, 2)
+ self.assertEqual(r, 'chips')
+ r = yield obj.fn(30, 2)
+ self.assertEqual(r, 'peas')
+ r = yield obj.list_fn([10, 20, 30], 2)
+ obj.mock.assert_not_called()
+ self.assertEqual(r, {10: 'fish', 20: 'chips', 30: 'peas'})
+
+ @defer.inlineCallbacks
+ def test_invalidate(self):
+ """Make sure that invalidation callbacks are called."""
+ class Cls(object):
+ def __init__(self):
+ self.mock = mock.Mock()
+
+ @descriptors.cached()
+ def fn(self, arg1, arg2):
+ pass
+
+ @descriptors.cachedList("fn", "args1", inlineCallbacks=True)
+ def list_fn(self, args1, arg2):
+ # we want this to behave like an asynchronous function
+ yield run_on_reactor()
+ defer.returnValue(self.mock(args1, arg2))
+
+ obj = Cls()
+ invalidate0 = mock.Mock()
+ invalidate1 = mock.Mock()
+
+ # cache miss
+ obj.mock.return_value = {10: 'fish', 20: 'chips'}
+ r1 = yield obj.list_fn([10, 20], 2, on_invalidate=invalidate0)
+ obj.mock.assert_called_once_with([10, 20], 2)
+ self.assertEqual(r1, {10: 'fish', 20: 'chips'})
+ obj.mock.reset_mock()
+
+ # cache hit
+ r2 = yield obj.list_fn([10, 20], 2, on_invalidate=invalidate1)
+ obj.mock.assert_not_called()
+ self.assertEqual(r2, {10: 'fish', 20: 'chips'})
+
+ invalidate0.assert_not_called()
+ invalidate1.assert_not_called()
+
+ # now if we invalidate the keys, both invalidations should get called
+ obj.fn.invalidate((10, 2))
+ invalidate0.assert_called_once()
+ invalidate1.assert_called_once()
diff --git a/tests/utils.py b/tests/utils.py
index c3dbff8507..9bff3ff3b9 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -193,7 +193,7 @@ class MockHttpResource(HttpServer):
self.prefix = prefix
def trigger_get(self, path):
- return self.trigger("GET", path, None)
+ return self.trigger(b"GET", path, None)
@patch('twisted.web.http.Request')
@defer.inlineCallbacks
@@ -227,7 +227,7 @@ class MockHttpResource(HttpServer):
headers = {}
if federation_auth:
- headers[b"Authorization"] = ["X-Matrix origin=test,key=,sig="]
+ headers[b"Authorization"] = [b"X-Matrix origin=test,key=,sig="]
mock_request.requestHeaders.getRawHeaders = mock_getRawHeaders(headers)
# return the right path if the event requires it
@@ -241,6 +241,9 @@ class MockHttpResource(HttpServer):
except Exception:
pass
+ if isinstance(path, bytes):
+ path = path.decode('utf8')
+
for (method, pattern, func) in self.callbacks:
if http_method != method:
continue
@@ -249,7 +252,7 @@ class MockHttpResource(HttpServer):
if matcher:
try:
args = [
- urlparse.unquote(u).decode("UTF-8")
+ urlparse.unquote(u)
for u in matcher.groups()
]
|