summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/api/filtering.py23
-rw-r--r--synapse/handlers/message.py16
-rw-r--r--synapse/handlers/register.py13
-rw-r--r--synapse/rest/client/v1/register.py4
-rw-r--r--synapse/rest/client/v1/room.py11
-rw-r--r--synapse/rest/client/v2_alpha/register.py50
-rw-r--r--synapse/storage/events.py82
-rw-r--r--synapse/storage/registration.py6
-rw-r--r--synapse/storage/schema/delta/33/event_fields.py60
-rw-r--r--synapse/storage/stream.py56
-rw-r--r--tests/rest/client/v2_alpha/test_register.py6
-rw-r--r--tests/storage/event_injector.py1
-rw-r--r--tests/storage/test_events.py12
13 files changed, 303 insertions, 37 deletions
diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py
index 4f5a4281fa..3b3ef70750 100644
--- a/synapse/api/filtering.py
+++ b/synapse/api/filtering.py
@@ -191,6 +191,17 @@ class Filter(object):
     def __init__(self, filter_json):
         self.filter_json = filter_json
 
+        self.types = self.filter_json.get("types", None)
+        self.not_types = self.filter_json.get("not_types", [])
+
+        self.rooms = self.filter_json.get("rooms", None)
+        self.not_rooms = self.filter_json.get("not_rooms", [])
+
+        self.senders = self.filter_json.get("senders", None)
+        self.not_senders = self.filter_json.get("not_senders", [])
+
+        self.contains_url = self.filter_json.get("contains_url", None)
+
     def check(self, event):
         """Checks whether the filter matches the given event.
 
@@ -209,9 +220,10 @@ class Filter(object):
             event.get("room_id", None),
             sender,
             event.get("type", None),
+            "url" in event.get("content", {})
         )
 
-    def check_fields(self, room_id, sender, event_type):
+    def check_fields(self, room_id, sender, event_type, contains_url):
         """Checks whether the filter matches the given event fields.
 
         Returns:
@@ -225,15 +237,20 @@ class Filter(object):
 
         for name, match_func in literal_keys.items():
             not_name = "not_%s" % (name,)
-            disallowed_values = self.filter_json.get(not_name, [])
+            disallowed_values = getattr(self, not_name)
             if any(map(match_func, disallowed_values)):
                 return False
 
-            allowed_values = self.filter_json.get(name, None)
+            allowed_values = getattr(self, name)
             if allowed_values is not None:
                 if not any(map(match_func, allowed_values)):
                     return False
 
+        contains_url_filter = self.filter_json.get("contains_url")
+        if contains_url_filter is not None:
+            if contains_url_filter != contains_url:
+                return False
+
         return True
 
     def filter_rooms(self, room_ids):
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index ad2753c1b5..dc76d34a52 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -66,7 +66,7 @@ class MessageHandler(BaseHandler):
 
     @defer.inlineCallbacks
     def get_messages(self, requester, room_id=None, pagin_config=None,
-                     as_client_event=True):
+                     as_client_event=True, event_filter=None):
         """Get messages in a room.
 
         Args:
@@ -75,11 +75,11 @@ class MessageHandler(BaseHandler):
             pagin_config (synapse.api.streams.PaginationConfig): The pagination
                 config rules to apply, if any.
             as_client_event (bool): True to get events in client-server format.
+            event_filter (Filter): Filter to apply to results or None
         Returns:
             dict: Pagination API results
         """
         user_id = requester.user.to_string()
-        data_source = self.hs.get_event_sources().sources["room"]
 
         if pagin_config.from_token:
             room_token = pagin_config.from_token.room_key
@@ -129,8 +129,13 @@ class MessageHandler(BaseHandler):
                     room_id, max_topo
                 )
 
-            events, next_key = yield data_source.get_pagination_rows(
-                requester.user, source_config, room_id
+            events, next_key = yield self.store.paginate_room_events(
+                room_id=room_id,
+                from_key=source_config.from_key,
+                to_key=source_config.to_key,
+                direction=source_config.direction,
+                limit=source_config.limit,
+                event_filter=event_filter,
             )
 
             next_token = pagin_config.from_token.copy_and_replace(
@@ -144,6 +149,9 @@ class MessageHandler(BaseHandler):
                 "end": next_token.to_string(),
             })
 
+        if event_filter:
+            events = event_filter.filter(events)
+
         events = yield filter_events_for_client(
             self.store,
             user_id,
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 6b33b27149..94b19d0cb0 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -99,8 +99,13 @@ class RegistrationHandler(BaseHandler):
             localpart : The local part of the user ID to register. If None,
               one will be generated.
             password (str) : The password to assign to this user so they can
-            login again. This can be None which means they cannot login again
-            via a password (e.g. the user is an application service user).
+              login again. This can be None which means they cannot login again
+              via a password (e.g. the user is an application service user).
+            generate_token (bool): Whether a new access token should be
+              generated. Having this be True should be considered deprecated,
+              since it offers no means of associating a device_id with the
+              access_token. Instead you should call auth_handler.issue_access_token
+              after registration.
         Returns:
             A tuple of (user_id, access_token).
         Raises:
@@ -196,15 +201,13 @@ class RegistrationHandler(BaseHandler):
             user_id, allowed_appservice=service
         )
 
-        token = self.auth_handler().generate_access_token(user_id)
         yield self.store.register(
             user_id=user_id,
-            token=token,
             password_hash="",
             appservice_id=service_id,
             create_profile_with_localpart=user.localpart,
         )
-        defer.returnValue((user_id, token))
+        defer.returnValue(user_id)
 
     @defer.inlineCallbacks
     def check_recaptcha(self, ip, private_key, challenge, response):
diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1/register.py
index efe796c65f..2383b9df86 100644
--- a/synapse/rest/client/v1/register.py
+++ b/synapse/rest/client/v1/register.py
@@ -64,6 +64,7 @@ class RegisterRestServlet(ClientV1RestServlet):
         # TODO: persistent storage
         self.sessions = {}
         self.enable_registration = hs.config.enable_registration
+        self.auth_handler = hs.get_auth_handler()
 
     def on_GET(self, request):
         if self.hs.config.enable_registration_captcha:
@@ -303,9 +304,10 @@ class RegisterRestServlet(ClientV1RestServlet):
         user_localpart = register_json["user"].encode("utf-8")
 
         handler = self.handlers.registration_handler
-        (user_id, token) = yield handler.appservice_register(
+        user_id = yield handler.appservice_register(
             user_localpart, as_token
         )
+        token = yield self.auth_handler.issue_access_token(user_id)
         self._remove_session(session)
         defer.returnValue({
             "user_id": user_id,
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 86fbe2747d..866a1e9120 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -20,12 +20,14 @@ from .base import ClientV1RestServlet, client_path_patterns
 from synapse.api.errors import SynapseError, Codes, AuthError
 from synapse.streams.config import PaginationConfig
 from synapse.api.constants import EventTypes, Membership
+from synapse.api.filtering import Filter
 from synapse.types import UserID, RoomID, RoomAlias
 from synapse.events.utils import serialize_event
 from synapse.http.servlet import parse_json_object_from_request
 
 import logging
 import urllib
+import ujson as json
 
 logger = logging.getLogger(__name__)
 
@@ -327,12 +329,19 @@ class RoomMessageListRestServlet(ClientV1RestServlet):
             request, default_limit=10,
         )
         as_client_event = "raw" not in request.args
+        filter_bytes = request.args.get("filter", None)
+        if filter_bytes:
+            filter_json = urllib.unquote(filter_bytes[-1]).decode("UTF-8")
+            event_filter = Filter(json.loads(filter_json))
+        else:
+            event_filter = None
         handler = self.handlers.message_handler
         msgs = yield handler.get_messages(
             room_id=room_id,
             requester=requester,
             pagin_config=pagination_config,
-            as_client_event=as_client_event
+            as_client_event=as_client_event,
+            event_filter=event_filter,
         )
 
         defer.returnValue((200, msgs))
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index 2722a58e3e..b7e03ea9d1 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -235,19 +235,17 @@ class RegisterRestServlet(RestServlet):
 
             add_email = True
 
-        access_token = yield self.auth_handler.issue_access_token(
+        result = yield self._create_registration_details(
             registered_user_id
         )
 
         if add_email and result and LoginType.EMAIL_IDENTITY in result:
             threepid = result[LoginType.EMAIL_IDENTITY]
             yield self._register_email_threepid(
-                registered_user_id, threepid, access_token,
+                registered_user_id, threepid, result["access_token"],
                 params.get("bind_email")
             )
 
-        result = yield self._create_registration_details(registered_user_id,
-                                                         access_token)
         defer.returnValue((200, result))
 
     def on_OPTIONS(self, _):
@@ -255,10 +253,10 @@ class RegisterRestServlet(RestServlet):
 
     @defer.inlineCallbacks
     def _do_appservice_registration(self, username, as_token):
-        (user_id, token) = yield self.registration_handler.appservice_register(
+        user_id = yield self.registration_handler.appservice_register(
             username, as_token
         )
-        defer.returnValue((yield self._create_registration_details(user_id, token)))
+        defer.returnValue((yield self._create_registration_details(user_id)))
 
     @defer.inlineCallbacks
     def _do_shared_secret_registration(self, username, password, mac):
@@ -282,10 +280,12 @@ class RegisterRestServlet(RestServlet):
                 403, "HMAC incorrect",
             )
 
-        (user_id, token) = yield self.registration_handler.register(
-            localpart=username, password=password
+        (user_id, _) = yield self.registration_handler.register(
+            localpart=username, password=password, generate_token=False,
         )
-        defer.returnValue((yield self._create_registration_details(user_id, token)))
+
+        result = yield self._create_registration_details(user_id)
+        defer.returnValue(result)
 
     @defer.inlineCallbacks
     def _register_email_threepid(self, user_id, threepid, token, bind_email):
@@ -358,11 +358,31 @@ class RegisterRestServlet(RestServlet):
         defer.returnValue()
 
     @defer.inlineCallbacks
-    def _create_registration_details(self, user_id, token):
-        refresh_token = yield self.auth_handler.issue_refresh_token(user_id)
+    def _create_registration_details(self, user_id):
+        """Complete registration of newly-registered user
+
+        Issues access_token and refresh_token, and builds the success response
+        body.
+
+        Args:
+            (str) user_id: full canonical @user:id
+
+
+        Returns:
+            defer.Deferred: (object) dictionary for response from /register
+        """
+
+        access_token = yield self.auth_handler.issue_access_token(
+            user_id
+        )
+
+        refresh_token = yield self.auth_handler.issue_refresh_token(
+            user_id
+        )
+
         defer.returnValue({
             "user_id": user_id,
-            "access_token": token,
+            "access_token": access_token,
             "home_server": self.hs.hostname,
             "refresh_token": refresh_token,
         })
@@ -375,7 +395,11 @@ class RegisterRestServlet(RestServlet):
             generate_token=False,
             make_guest=True
         )
-        access_token = self.auth_handler.generate_access_token(user_id, ["guest = true"])
+        access_token = self.auth_handler.generate_access_token(
+            user_id, ["guest = true"]
+        )
+        # XXX the "guest" caveat is not copied by /tokenrefresh. That's ok
+        # so long as we don't return a refresh_token here.
         defer.returnValue((200, {
             "user_id": user_id,
             "access_token": access_token,
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 9d74fd159d..6610549281 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -152,6 +152,7 @@ _EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event"))
 
 class EventsStore(SQLBaseStore):
     EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts"
+    EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url"
 
     def __init__(self, hs):
         super(EventsStore, self).__init__(hs)
@@ -159,6 +160,10 @@ class EventsStore(SQLBaseStore):
         self.register_background_update_handler(
             self.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts
         )
+        self.register_background_update_handler(
+            self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME,
+            self._background_reindex_fields_sender,
+        )
 
         self._event_persist_queue = _EventPeristenceQueue()
 
@@ -576,6 +581,11 @@ class EventsStore(SQLBaseStore):
                     "content": encode_json(event.content).decode("UTF-8"),
                     "origin_server_ts": int(event.origin_server_ts),
                     "received_ts": self._clock.time_msec(),
+                    "sender": event.sender,
+                    "contains_url": (
+                        "url" in event.content
+                        and isinstance(event.content["url"], basestring)
+                    ),
                 }
                 for event, _ in events_and_contexts
             ],
@@ -1116,6 +1126,78 @@ class EventsStore(SQLBaseStore):
         defer.returnValue(ret)
 
     @defer.inlineCallbacks
+    def _background_reindex_fields_sender(self, progress, batch_size):
+        target_min_stream_id = progress["target_min_stream_id_inclusive"]
+        max_stream_id = progress["max_stream_id_exclusive"]
+        rows_inserted = progress.get("rows_inserted", 0)
+
+        INSERT_CLUMP_SIZE = 1000
+
+        def reindex_txn(txn):
+            sql = (
+                "SELECT stream_ordering, event_id, json FROM events"
+                " INNER JOIN event_json USING (event_id)"
+                " WHERE ? <= stream_ordering AND stream_ordering < ?"
+                " ORDER BY stream_ordering DESC"
+                " LIMIT ?"
+            )
+
+            txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size))
+
+            rows = txn.fetchall()
+            if not rows:
+                return 0
+
+            min_stream_id = rows[-1][0]
+
+            update_rows = []
+            for row in rows:
+                try:
+                    event_id = row[1]
+                    event_json = json.loads(row[2])
+                    sender = event_json["sender"]
+                    content = event_json["content"]
+
+                    contains_url = "url" in content
+                    if contains_url:
+                        contains_url &= isinstance(content["url"], basestring)
+                except (KeyError, AttributeError):
+                    # If the event is missing a necessary field then
+                    # skip over it.
+                    continue
+
+                update_rows.append((sender, contains_url, event_id))
+
+            sql = (
+                "UPDATE events SET sender = ?, contains_url = ? WHERE event_id = ?"
+            )
+
+            for index in range(0, len(update_rows), INSERT_CLUMP_SIZE):
+                clump = update_rows[index:index + INSERT_CLUMP_SIZE]
+                txn.executemany(sql, clump)
+
+            progress = {
+                "target_min_stream_id_inclusive": target_min_stream_id,
+                "max_stream_id_exclusive": min_stream_id,
+                "rows_inserted": rows_inserted + len(rows)
+            }
+
+            self._background_update_progress_txn(
+                txn, self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, progress
+            )
+
+            return len(rows)
+
+        result = yield self.runInteraction(
+            self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, reindex_txn
+        )
+
+        if not result:
+            yield self._end_background_update(self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME)
+
+        defer.returnValue(result)
+
+    @defer.inlineCallbacks
     def _background_reindex_origin_server_ts(self, progress, batch_size):
         target_min_stream_id = progress["target_min_stream_id_inclusive"]
         max_stream_id = progress["max_stream_id_exclusive"]
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 26ef1cfd8a..9a92b35361 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -81,14 +81,16 @@ class RegistrationStore(SQLBaseStore):
         )
 
     @defer.inlineCallbacks
-    def register(self, user_id, token, password_hash,
+    def register(self, user_id, token=None, password_hash=None,
                  was_guest=False, make_guest=False, appservice_id=None,
                  create_profile_with_localpart=None, admin=False):
         """Attempts to register an account.
 
         Args:
             user_id (str): The desired user ID to register.
-            token (str): The desired access token to use for this user.
+            token (str): The desired access token to use for this user. If this
+                is not None, the given access token is associated with the user
+                id.
             password_hash (str): Optional. The password hash for this user.
             was_guest (bool): Optional. Whether this is a guest account being
                 upgraded to a non-guest account.
diff --git a/synapse/storage/schema/delta/33/event_fields.py b/synapse/storage/schema/delta/33/event_fields.py
new file mode 100644
index 0000000000..83066cccc9
--- /dev/null
+++ b/synapse/storage/schema/delta/33/event_fields.py
@@ -0,0 +1,60 @@
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.storage.prepare_database import get_statements
+
+import logging
+import ujson
+
+logger = logging.getLogger(__name__)
+
+
+ALTER_TABLE = """
+ALTER TABLE events ADD COLUMN sender TEXT;
+ALTER TABLE events ADD COLUMN contains_url BOOLEAN;
+"""
+
+
+def run_create(cur, database_engine, *args, **kwargs):
+    for statement in get_statements(ALTER_TABLE.splitlines()):
+        cur.execute(statement)
+
+    cur.execute("SELECT MIN(stream_ordering) FROM events")
+    rows = cur.fetchall()
+    min_stream_id = rows[0][0]
+
+    cur.execute("SELECT MAX(stream_ordering) FROM events")
+    rows = cur.fetchall()
+    max_stream_id = rows[0][0]
+
+    if min_stream_id is not None and max_stream_id is not None:
+        progress = {
+            "target_min_stream_id_inclusive": min_stream_id,
+            "max_stream_id_exclusive": max_stream_id + 1,
+            "rows_inserted": 0,
+        }
+        progress_json = ujson.dumps(progress)
+
+        sql = (
+            "INSERT into background_updates (update_name, progress_json)"
+            " VALUES (?, ?)"
+        )
+
+        sql = database_engine.convert_param_style(sql)
+
+        cur.execute(sql, ("event_fields_sender_url", progress_json))
+
+
+def run_upgrade(cur, database_engine, *args, **kwargs):
+    pass
diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py
index c33ac5a8d7..862c5c3ea1 100644
--- a/synapse/storage/stream.py
+++ b/synapse/storage/stream.py
@@ -95,6 +95,54 @@ def upper_bound(token, engine, inclusive=True):
         )
 
 
+def filter_to_clause(event_filter):
+    # NB: This may create SQL clauses that don't optimise well (and we don't
+    # have indices on all possible clauses). E.g. it may create
+    # "room_id == X AND room_id != X", which postgres doesn't optimise.
+
+    if not event_filter:
+        return "", []
+
+    clauses = []
+    args = []
+
+    if event_filter.types:
+        clauses.append(
+            "(%s)" % " OR ".join("type = ?" for _ in event_filter.types)
+        )
+        args.extend(event_filter.types)
+
+    for typ in event_filter.not_types:
+        clauses.append("type != ?")
+        args.append(typ)
+
+    if event_filter.senders:
+        clauses.append(
+            "(%s)" % " OR ".join("sender = ?" for _ in event_filter.senders)
+        )
+        args.extend(event_filter.senders)
+
+    for sender in event_filter.not_senders:
+        clauses.append("sender != ?")
+        args.append(sender)
+
+    if event_filter.rooms:
+        clauses.append(
+            "(%s)" % " OR ".join("room_id = ?" for _ in event_filter.rooms)
+        )
+        args.extend(event_filter.rooms)
+
+    for room_id in event_filter.not_rooms:
+        clauses.append("room_id != ?")
+        args.append(room_id)
+
+    if event_filter.contains_url:
+        clauses.append("contains_url = ?")
+        args.append(event_filter.contains_url)
+
+    return " AND ".join(clauses), args
+
+
 class StreamStore(SQLBaseStore):
     @defer.inlineCallbacks
     def get_appservice_room_stream(self, service, from_key, to_key, limit=0):
@@ -320,7 +368,7 @@ class StreamStore(SQLBaseStore):
 
     @defer.inlineCallbacks
     def paginate_room_events(self, room_id, from_key, to_key=None,
-                             direction='b', limit=-1):
+                             direction='b', limit=-1, event_filter=None):
         # Tokens really represent positions between elements, but we use
         # the convention of pointing to the event before the gap. Hence
         # we have a bit of asymmetry when it comes to equalities.
@@ -344,6 +392,12 @@ class StreamStore(SQLBaseStore):
                     RoomStreamToken.parse(to_key), self.database_engine
                 ))
 
+        filter_clause, filter_args = filter_to_clause(event_filter)
+
+        if filter_clause:
+            bounds += " AND " + filter_clause
+            args.extend(filter_args)
+
         if int(limit) > 0:
             args.append(int(limit))
             limit_str = " LIMIT ?"
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index 9a4215fef7..ccbb8776d3 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -61,8 +61,10 @@ class RegisterRestServletTestCase(unittest.TestCase):
             "id": "1234"
         }
         self.registration_handler.appservice_register = Mock(
-            return_value=(user_id, token)
+            return_value=user_id
         )
+        self.auth_handler.issue_access_token = Mock(return_value=token)
+
         (code, result) = yield self.servlet.on_POST(self.request)
         self.assertEquals(code, 200)
         det_data = {
@@ -126,6 +128,8 @@ class RegisterRestServletTestCase(unittest.TestCase):
         }
         self.assertDictContainsSubset(det_data, result)
         self.assertIn("refresh_token", result)
+        self.auth_handler.issue_access_token.assert_called_once_with(
+            user_id)
 
     def test_POST_disabled_registration(self):
         self.hs.config.enable_registration = False
diff --git a/tests/storage/event_injector.py b/tests/storage/event_injector.py
index f22ba8db89..38556da9a7 100644
--- a/tests/storage/event_injector.py
+++ b/tests/storage/event_injector.py
@@ -30,6 +30,7 @@ class EventInjector:
     def create_room(self, room):
         builder = self.event_builder_factory.new({
             "type": EventTypes.Create,
+            "sender": "",
             "room_id": room.to_string(),
             "content": {},
         })
diff --git a/tests/storage/test_events.py b/tests/storage/test_events.py
index 18a6cff0c7..3762b38e37 100644
--- a/tests/storage/test_events.py
+++ b/tests/storage/test_events.py
@@ -37,7 +37,7 @@ class EventsStoreTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_count_daily_messages(self):
-        self.db_pool.runQuery("DELETE FROM stats_reporting")
+        yield self.db_pool.runQuery("DELETE FROM stats_reporting")
 
         self.hs.clock.now = 100
 
@@ -60,7 +60,7 @@ class EventsStoreTestCase(unittest.TestCase):
         # it isn't old enough.
         count = yield self.store.count_daily_messages()
         self.assertIsNone(count)
-        self._assert_stats_reporting(1, self.hs.clock.now)
+        yield self._assert_stats_reporting(1, self.hs.clock.now)
 
         # Already reported yesterday, two new events from today.
         yield self.event_injector.inject_message(room, user, "Yeah they are!")
@@ -68,21 +68,21 @@ class EventsStoreTestCase(unittest.TestCase):
         self.hs.clock.now += 60 * 60 * 24
         count = yield self.store.count_daily_messages()
         self.assertEqual(2, count)  # 2 since yesterday
-        self._assert_stats_reporting(3, self.hs.clock.now)  # 3 ever
+        yield self._assert_stats_reporting(3, self.hs.clock.now)  # 3 ever
 
         # Last reported too recently.
         yield self.event_injector.inject_message(room, user, "Who could disagree?")
         self.hs.clock.now += 60 * 60 * 22
         count = yield self.store.count_daily_messages()
         self.assertIsNone(count)
-        self._assert_stats_reporting(4, self.hs.clock.now)
+        yield self._assert_stats_reporting(4, self.hs.clock.now)
 
         # Last reported too long ago
         yield self.event_injector.inject_message(room, user, "No one.")
         self.hs.clock.now += 60 * 60 * 26
         count = yield self.store.count_daily_messages()
         self.assertIsNone(count)
-        self._assert_stats_reporting(5, self.hs.clock.now)
+        yield self._assert_stats_reporting(5, self.hs.clock.now)
 
         # And now let's actually report something
         yield self.event_injector.inject_message(room, user, "Indeed.")
@@ -92,7 +92,7 @@ class EventsStoreTestCase(unittest.TestCase):
         self.hs.clock.now += (60 * 60 * 24) + 50
         count = yield self.store.count_daily_messages()
         self.assertEqual(3, count)
-        self._assert_stats_reporting(8, self.hs.clock.now)
+        yield self._assert_stats_reporting(8, self.hs.clock.now)
 
     @defer.inlineCallbacks
     def _get_last_stream_token(self):