summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/handlers/federation.py166
-rw-r--r--synapse/push/httppusher.py17
-rw-r--r--synapse/replication/slave/storage/account_data.py5
-rw-r--r--synapse/replication/slave/storage/deviceinbox.py3
-rw-r--r--synapse/replication/slave/storage/groups.py3
-rw-r--r--synapse/replication/slave/storage/presence.py3
-rw-r--r--synapse/replication/slave/storage/push_rule.py3
-rw-r--r--synapse/replication/slave/storage/pushers.py3
-rw-r--r--synapse/replication/slave/storage/receipts.py11
-rw-r--r--synapse/replication/slave/storage/room.py3
-rw-r--r--synapse/storage/data_stores/main/cache.py8
-rw-r--r--synapse/storage/data_stores/main/events_worker.py6
-rw-r--r--synapse/storage/engines/postgres.py2
13 files changed, 126 insertions, 107 deletions
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 4dbd8e1d98..b5aaa244dd 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -19,8 +19,9 @@
 
 import itertools
 import logging
+from collections import Container
 from http import HTTPStatus
-from typing import Dict, Iterable, List, Optional, Sequence, Tuple
+from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union
 
 import attr
 from signedjson.key import decode_verify_key_bytes
@@ -742,6 +743,9 @@ class FederationHandler(BaseHandler):
                 # device and recognize the algorithm then we can work out the
                 # exact key to expect. Otherwise check it matches any key we
                 # have for that device.
+
+                current_keys = []  # type: Container[str]
+
                 if device:
                     keys = device.get("keys", {}).get("keys", {})
 
@@ -758,15 +762,15 @@ class FederationHandler(BaseHandler):
                         current_keys = keys.values()
                 elif device_id:
                     # We don't have any keys for the device ID.
-                    current_keys = []
+                    pass
                 else:
                     # The event didn't include a device ID, so we just look for
                     # keys across all devices.
-                    current_keys = (
+                    current_keys = [
                         key
                         for device in cached_devices
                         for key in device.get("keys", {}).get("keys", {}).values()
-                    )
+                    ]
 
                 # We now check that the sender key matches (one of) the expected
                 # keys.
@@ -1011,7 +1015,7 @@ class FederationHandler(BaseHandler):
                 if e_type == EventTypes.Member and event.membership == Membership.JOIN
             ]
 
-            joined_domains = {}
+            joined_domains = {}  # type: Dict[str, int]
             for u, d in joined_users:
                 try:
                     dom = get_domain_from_id(u)
@@ -1277,14 +1281,15 @@ class FederationHandler(BaseHandler):
         try:
             # Try the host we successfully got a response to /make_join/
             # request first.
+            host_list = list(target_hosts)
             try:
-                target_hosts.remove(origin)
-                target_hosts.insert(0, origin)
+                host_list.remove(origin)
+                host_list.insert(0, origin)
             except ValueError:
                 pass
 
             ret = await self.federation_client.send_join(
-                target_hosts, event, room_version_obj
+                host_list, event, room_version_obj
             )
 
             origin = ret["origin"]
@@ -1584,13 +1589,14 @@ class FederationHandler(BaseHandler):
 
         # Try the host that we succesfully called /make_leave/ on first for
         # the /send_leave/ request.
+        host_list = list(target_hosts)
         try:
-            target_hosts.remove(origin)
-            target_hosts.insert(0, origin)
+            host_list.remove(origin)
+            host_list.insert(0, origin)
         except ValueError:
             pass
 
-        await self.federation_client.send_leave(target_hosts, event)
+        await self.federation_client.send_leave(host_list, event)
 
         context = await self.state_handler.compute_event_context(event)
         stream_id = await self.persist_events_and_notify([(event, context)])
@@ -1604,7 +1610,7 @@ class FederationHandler(BaseHandler):
         user_id: str,
         membership: str,
         content: JsonDict = {},
-        params: Optional[Dict[str, str]] = None,
+        params: Optional[Dict[str, Union[str, Iterable[str]]]] = None,
     ) -> Tuple[str, EventBase, RoomVersion]:
         (
             origin,
@@ -2018,8 +2024,8 @@ class FederationHandler(BaseHandler):
             auth_events_ids = await self.auth.compute_auth_events(
                 event, prev_state_ids, for_verification=True
             )
-            auth_events = await self.store.get_events(auth_events_ids)
-            auth_events = {(e.type, e.state_key): e for e in auth_events.values()}
+            auth_events_x = await self.store.get_events(auth_events_ids)
+            auth_events = {(e.type, e.state_key): e for e in auth_events_x.values()}
 
         # This is a hack to fix some old rooms where the initial join event
         # didn't reference the create event in its auth events.
@@ -2055,76 +2061,67 @@ class FederationHandler(BaseHandler):
         # For new (non-backfilled and non-outlier) events we check if the event
         # passes auth based on the current state. If it doesn't then we
         # "soft-fail" the event.
-        do_soft_fail_check = not backfilled and not event.internal_metadata.is_outlier()
-        if do_soft_fail_check:
-            extrem_ids = await self.store.get_latest_event_ids_in_room(event.room_id)
-
-            extrem_ids = set(extrem_ids)
-            prev_event_ids = set(event.prev_event_ids())
-
-            if extrem_ids == prev_event_ids:
-                # If they're the same then the current state is the same as the
-                # state at the event, so no point rechecking auth for soft fail.
-                do_soft_fail_check = False
-
-        if do_soft_fail_check:
-            room_version = await self.store.get_room_version_id(event.room_id)
-            room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
-
-            # Calculate the "current state".
-            if state is not None:
-                # If we're explicitly given the state then we won't have all the
-                # prev events, and so we have a gap in the graph. In this case
-                # we want to be a little careful as we might have been down for
-                # a while and have an incorrect view of the current state,
-                # however we still want to do checks as gaps are easy to
-                # maliciously manufacture.
-                #
-                # So we use a "current state" that is actually a state
-                # resolution across the current forward extremities and the
-                # given state at the event. This should correctly handle cases
-                # like bans, especially with state res v2.
+        if backfilled or event.internal_metadata.is_outlier():
+            return
 
-                state_sets = await self.state_store.get_state_groups(
-                    event.room_id, extrem_ids
-                )
-                state_sets = list(state_sets.values())
-                state_sets.append(state)
-                current_state_ids = await self.state_handler.resolve_events(
-                    room_version, state_sets, event
-                )
-                current_state_ids = {
-                    k: e.event_id for k, e in current_state_ids.items()
-                }
-            else:
-                current_state_ids = await self.state_handler.get_current_state_ids(
-                    event.room_id, latest_event_ids=extrem_ids
-                )
+        extrem_ids = await self.store.get_latest_event_ids_in_room(event.room_id)
+        extrem_ids = set(extrem_ids)
+        prev_event_ids = set(event.prev_event_ids())
 
-            logger.debug(
-                "Doing soft-fail check for %s: state %s",
-                event.event_id,
-                current_state_ids,
+        if extrem_ids == prev_event_ids:
+            # If they're the same then the current state is the same as the
+            # state at the event, so no point rechecking auth for soft fail.
+            return
+
+        room_version = await self.store.get_room_version_id(event.room_id)
+        room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
+
+        # Calculate the "current state".
+        if state is not None:
+            # If we're explicitly given the state then we won't have all the
+            # prev events, and so we have a gap in the graph. In this case
+            # we want to be a little careful as we might have been down for
+            # a while and have an incorrect view of the current state,
+            # however we still want to do checks as gaps are easy to
+            # maliciously manufacture.
+            #
+            # So we use a "current state" that is actually a state
+            # resolution across the current forward extremities and the
+            # given state at the event. This should correctly handle cases
+            # like bans, especially with state res v2.
+
+            state_sets = await self.state_store.get_state_groups(
+                event.room_id, extrem_ids
+            )
+            state_sets = list(state_sets.values())
+            state_sets.append(state)
+            current_state_ids = await self.state_handler.resolve_events(
+                room_version, state_sets, event
+            )
+            current_state_ids = {k: e.event_id for k, e in current_state_ids.items()}
+        else:
+            current_state_ids = await self.state_handler.get_current_state_ids(
+                event.room_id, latest_event_ids=extrem_ids
             )
 
-            # Now check if event pass auth against said current state
-            auth_types = auth_types_for_event(event)
-            current_state_ids = [
-                e for k, e in current_state_ids.items() if k in auth_types
-            ]
+        logger.debug(
+            "Doing soft-fail check for %s: state %s", event.event_id, current_state_ids,
+        )
 
-            current_auth_events = await self.store.get_events(current_state_ids)
-            current_auth_events = {
-                (e.type, e.state_key): e for e in current_auth_events.values()
-            }
+        # Now check if event pass auth against said current state
+        auth_types = auth_types_for_event(event)
+        current_state_ids = [e for k, e in current_state_ids.items() if k in auth_types]
 
-            try:
-                event_auth.check(
-                    room_version_obj, event, auth_events=current_auth_events
-                )
-            except AuthError as e:
-                logger.warning("Soft-failing %r because %s", event, e)
-                event.internal_metadata.soft_failed = True
+        current_auth_events = await self.store.get_events(current_state_ids)
+        current_auth_events = {
+            (e.type, e.state_key): e for e in current_auth_events.values()
+        }
+
+        try:
+            event_auth.check(room_version_obj, event, auth_events=current_auth_events)
+        except AuthError as e:
+            logger.warning("Soft-failing %r because %s", event, e)
+            event.internal_metadata.soft_failed = True
 
     async def on_query_auth(
         self, origin, event_id, room_id, remote_auth_chain, rejects, missing
@@ -2293,10 +2290,10 @@ class FederationHandler(BaseHandler):
                     remote_auth_chain = await self.federation_client.get_event_auth(
                         origin, event.room_id, event.event_id
                     )
-                except RequestSendFailed as e:
+                except RequestSendFailed as e1:
                     # The other side isn't around or doesn't implement the
                     # endpoint, so lets just bail out.
-                    logger.info("Failed to get event auth from remote: %s", e)
+                    logger.info("Failed to get event auth from remote: %s", e1)
                     return context
 
                 seen_remotes = await self.store.have_seen_events(
@@ -2774,7 +2771,8 @@ class FederationHandler(BaseHandler):
 
         logger.debug("Checking auth on event %r", event.content)
 
-        last_exception = None
+        last_exception = None  # type: Optional[Exception]
+
         # for each public key in the 3pid invite event
         for public_key_object in self.hs.get_auth().get_public_keys(invite_event):
             try:
@@ -2828,6 +2826,12 @@ class FederationHandler(BaseHandler):
                         return
             except Exception as e:
                 last_exception = e
+
+        if last_exception is None:
+            # we can only get here if get_public_keys() returned an empty list
+            # TODO: make this better
+            raise RuntimeError("no public key in invite event")
+
         raise last_exception
 
     async def _check_key_revocation(self, public_key, url):
diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index ed60dbc1bf..2fac07593b 100644
--- a/synapse/push/httppusher.py
+++ b/synapse/push/httppusher.py
@@ -20,6 +20,7 @@ from prometheus_client import Counter
 from twisted.internet import defer
 from twisted.internet.error import AlreadyCalled, AlreadyCancelled
 
+from synapse.api.constants import EventTypes
 from synapse.logging import opentracing
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.push import PusherConfigException
@@ -305,12 +306,23 @@ class HttpPusher(object):
 
     @defer.inlineCallbacks
     def _build_notification_dict(self, event, tweaks, badge):
+        priority = "low"
+        if (
+            event.type == EventTypes.Encrypted
+            or tweaks.get("highlight")
+            or tweaks.get("sound")
+        ):
+            # HACK send our push as high priority only if it generates a sound, highlight
+            #  or may do so (i.e. is encrypted so has unknown effects).
+            priority = "high"
+
         if self.data.get("format") == "event_id_only":
             d = {
                 "notification": {
                     "event_id": event.event_id,
                     "room_id": event.room_id,
                     "counts": {"unread": badge},
+                    "prio": priority,
                     "devices": [
                         {
                             "app_id": self.app_id,
@@ -334,9 +346,8 @@ class HttpPusher(object):
                 "room_id": event.room_id,
                 "type": event.type,
                 "sender": event.user_id,
-                "counts": {  # -- we don't mark messages as read yet so
-                    # we have no way of knowing
-                    # Just set the badge to 1 until we have read receipts
+                "prio": priority,
+                "counts": {
                     "unread": badge,
                     # 'missed_calls': 2
                 },
diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py
index 9db6c62bc7..525b94fd87 100644
--- a/synapse/replication/slave/storage/account_data.py
+++ b/synapse/replication/slave/storage/account_data.py
@@ -16,6 +16,7 @@
 
 from synapse.replication.slave.storage._base import BaseSlavedStore
 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
+from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream
 from synapse.storage.data_stores.main.account_data import AccountDataWorkerStore
 from synapse.storage.data_stores.main.tags import TagsWorkerStore
 from synapse.storage.database import Database
@@ -39,12 +40,12 @@ class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlaved
         return self._account_data_id_gen.get_current_token()
 
     def process_replication_rows(self, stream_name, instance_name, token, rows):
-        if stream_name == "tag_account_data":
+        if stream_name == TagAccountDataStream.NAME:
             self._account_data_id_gen.advance(token)
             for row in rows:
                 self.get_tags_for_user.invalidate((row.user_id,))
                 self._account_data_stream_cache.entity_has_changed(row.user_id, token)
-        elif stream_name == "account_data":
+        elif stream_name == AccountDataStream.NAME:
             self._account_data_id_gen.advance(token)
             for row in rows:
                 if not row.room_id:
diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py
index 6e7fd259d4..bd394f6b00 100644
--- a/synapse/replication/slave/storage/deviceinbox.py
+++ b/synapse/replication/slave/storage/deviceinbox.py
@@ -15,6 +15,7 @@
 
 from synapse.replication.slave.storage._base import BaseSlavedStore
 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
+from synapse.replication.tcp.streams import ToDeviceStream
 from synapse.storage.data_stores.main.deviceinbox import DeviceInboxWorkerStore
 from synapse.storage.database import Database
 from synapse.util.caches.expiringcache import ExpiringCache
@@ -44,7 +45,7 @@ class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
         )
 
     def process_replication_rows(self, stream_name, instance_name, token, rows):
-        if stream_name == "to_device":
+        if stream_name == ToDeviceStream.NAME:
             self._device_inbox_id_gen.advance(token)
             for row in rows:
                 if row.entity.startswith("@"):
diff --git a/synapse/replication/slave/storage/groups.py b/synapse/replication/slave/storage/groups.py
index 1851e7d525..5d210fa3a1 100644
--- a/synapse/replication/slave/storage/groups.py
+++ b/synapse/replication/slave/storage/groups.py
@@ -15,6 +15,7 @@
 
 from synapse.replication.slave.storage._base import BaseSlavedStore
 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
+from synapse.replication.tcp.streams import GroupServerStream
 from synapse.storage.data_stores.main.group_server import GroupServerWorkerStore
 from synapse.storage.database import Database
 from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -38,7 +39,7 @@ class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore):
         return self._group_updates_id_gen.get_current_token()
 
     def process_replication_rows(self, stream_name, instance_name, token, rows):
-        if stream_name == "groups":
+        if stream_name == GroupServerStream.NAME:
             self._group_updates_id_gen.advance(token)
             for row in rows:
                 self._group_updates_stream_cache.entity_has_changed(row.user_id, token)
diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py
index 4e0124842d..2938cb8e43 100644
--- a/synapse/replication/slave/storage/presence.py
+++ b/synapse/replication/slave/storage/presence.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from synapse.replication.tcp.streams import PresenceStream
 from synapse.storage import DataStore
 from synapse.storage.data_stores.main.presence import PresenceStore
 from synapse.storage.database import Database
@@ -42,7 +43,7 @@ class SlavedPresenceStore(BaseSlavedStore):
         return self._presence_id_gen.get_current_token()
 
     def process_replication_rows(self, stream_name, instance_name, token, rows):
-        if stream_name == "presence":
+        if stream_name == PresenceStream.NAME:
             self._presence_id_gen.advance(token)
             for row in rows:
                 self.presence_stream_cache.entity_has_changed(row.user_id, token)
diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py
index 6adb19463a..23ec1c5b11 100644
--- a/synapse/replication/slave/storage/push_rule.py
+++ b/synapse/replication/slave/storage/push_rule.py
@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from synapse.replication.tcp.streams import PushRulesStream
 from synapse.storage.data_stores.main.push_rule import PushRulesWorkerStore
 
 from .events import SlavedEventStore
@@ -30,7 +31,7 @@ class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
         return self._push_rules_stream_id_gen.get_current_token()
 
     def process_replication_rows(self, stream_name, instance_name, token, rows):
-        if stream_name == "push_rules":
+        if stream_name == PushRulesStream.NAME:
             self._push_rules_stream_id_gen.advance(token)
             for row in rows:
                 self.get_push_rules_for_user.invalidate((row.user_id,))
diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py
index cb78b49acb..ff449f3658 100644
--- a/synapse/replication/slave/storage/pushers.py
+++ b/synapse/replication/slave/storage/pushers.py
@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from synapse.replication.tcp.streams import PushersStream
 from synapse.storage.data_stores.main.pusher import PusherWorkerStore
 from synapse.storage.database import Database
 
@@ -32,6 +33,6 @@ class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
         return self._pushers_id_gen.get_current_token()
 
     def process_replication_rows(self, stream_name, instance_name, token, rows):
-        if stream_name == "pushers":
+        if stream_name == PushersStream.NAME:
             self._pushers_id_gen.advance(token)
         return super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py
index be716cc558..6982686eb5 100644
--- a/synapse/replication/slave/storage/receipts.py
+++ b/synapse/replication/slave/storage/receipts.py
@@ -14,20 +14,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from synapse.replication.tcp.streams import ReceiptsStream
 from synapse.storage.data_stores.main.receipts import ReceiptsWorkerStore
 from synapse.storage.database import Database
 
 from ._base import BaseSlavedStore
 from ._slaved_id_tracker import SlavedIdTracker
 
-# So, um, we want to borrow a load of functions intended for reading from
-# a DataStore, but we don't want to take functions that either write to the
-# DataStore or are cached and don't have cache invalidation logic.
-#
-# Rather than write duplicate versions of those functions, or lift them to
-# a common base class, we going to grab the underlying __func__ object from
-# the method descriptor on the DataStore and chuck them into our class.
-
 
 class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
     def __init__(self, database: Database, db_conn, hs):
@@ -52,7 +45,7 @@ class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
         self.get_receipts_for_room.invalidate((room_id, receipt_type))
 
     def process_replication_rows(self, stream_name, instance_name, token, rows):
-        if stream_name == "receipts":
+        if stream_name == ReceiptsStream.NAME:
             self._receipts_id_gen.advance(token)
             for row in rows:
                 self.invalidate_caches_for_receipt(
diff --git a/synapse/replication/slave/storage/room.py b/synapse/replication/slave/storage/room.py
index 8873bf37e5..8710207ada 100644
--- a/synapse/replication/slave/storage/room.py
+++ b/synapse/replication/slave/storage/room.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from synapse.replication.tcp.streams import PublicRoomsStream
 from synapse.storage.data_stores.main.room import RoomWorkerStore
 from synapse.storage.database import Database
 
@@ -31,7 +32,7 @@ class RoomStore(RoomWorkerStore, BaseSlavedStore):
         return self._public_room_id_gen.get_current_token()
 
     def process_replication_rows(self, stream_name, instance_name, token, rows):
-        if stream_name == "public_rooms":
+        if stream_name == PublicRoomsStream.NAME:
             self._public_room_id_gen.advance(token)
 
         return super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/synapse/storage/data_stores/main/cache.py b/synapse/storage/data_stores/main/cache.py
index eac5a4e55b..d30766e543 100644
--- a/synapse/storage/data_stores/main/cache.py
+++ b/synapse/storage/data_stores/main/cache.py
@@ -19,7 +19,9 @@ import logging
 from typing import Any, Iterable, Optional, Tuple
 
 from synapse.api.constants import EventTypes
+from synapse.replication.tcp.streams import BackfillStream, CachesStream
 from synapse.replication.tcp.streams.events import (
+    EventsStream,
     EventsStreamCurrentStateRow,
     EventsStreamEventRow,
 )
@@ -71,10 +73,10 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
         )
 
     def process_replication_rows(self, stream_name, instance_name, token, rows):
-        if stream_name == "events":
+        if stream_name == EventsStream.NAME:
             for row in rows:
                 self._process_event_stream_row(token, row)
-        elif stream_name == "backfill":
+        elif stream_name == BackfillStream.NAME:
             for row in rows:
                 self._invalidate_caches_for_event(
                     -token,
@@ -86,7 +88,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
                     row.relates_to,
                     backfilled=True,
                 )
-        elif stream_name == "caches":
+        elif stream_name == CachesStream.NAME:
             if self._cache_id_gen:
                 self._cache_id_gen.advance(instance_name, token)
 
diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py
index a48c7a96ca..47a3e63589 100644
--- a/synapse/storage/data_stores/main/events_worker.py
+++ b/synapse/storage/data_stores/main/events_worker.py
@@ -38,6 +38,8 @@ from synapse.events.utils import prune_event
 from synapse.logging.context import PreserveLoggingContext, current_context
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
+from synapse.replication.tcp.streams import BackfillStream
+from synapse.replication.tcp.streams.events import EventsStream
 from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
 from synapse.storage.database import Database
 from synapse.storage.util.id_generators import StreamIdGenerator
@@ -113,9 +115,9 @@ class EventsWorkerStore(SQLBaseStore):
         self._event_fetch_ongoing = 0
 
     def process_replication_rows(self, stream_name, instance_name, token, rows):
-        if stream_name == "events":
+        if stream_name == EventsStream.NAME:
             self._stream_id_gen.advance(token)
-        elif stream_name == "backfill":
+        elif stream_name == BackfillStream.NAME:
             self._backfill_id_gen.advance(-token)
 
         super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py
index 6c7d08a6f2..a31588080d 100644
--- a/synapse/storage/engines/postgres.py
+++ b/synapse/storage/engines/postgres.py
@@ -92,7 +92,7 @@ class PostgresEngine(BaseDatabaseEngine):
             errors.append("    - 'COLLATE' is set to %r. Should be 'C'" % (collation,))
 
         if ctype != "C":
-            errors.append("    - 'CTYPE' is set to %r. Should be 'C'" % (collation,))
+            errors.append("    - 'CTYPE' is set to %r. Should be 'C'" % (ctype,))
 
         if errors:
             raise IncorrectDatabaseSetup(