diff options
-rw-r--r-- | changelog.d/7760.bugfix | 1 | ||||
-rw-r--r-- | changelog.d/7765.misc | 1 | ||||
-rw-r--r-- | changelog.d/7768.misc | 1 | ||||
-rw-r--r-- | changelog.d/7769.misc | 1 | ||||
-rw-r--r-- | changelog.d/7770.misc | 1 | ||||
-rw-r--r-- | synapse/handlers/federation.py | 166 | ||||
-rw-r--r-- | synapse/push/httppusher.py | 17 | ||||
-rw-r--r-- | synapse/replication/slave/storage/account_data.py | 5 | ||||
-rw-r--r-- | synapse/replication/slave/storage/deviceinbox.py | 3 | ||||
-rw-r--r-- | synapse/replication/slave/storage/groups.py | 3 | ||||
-rw-r--r-- | synapse/replication/slave/storage/presence.py | 3 | ||||
-rw-r--r-- | synapse/replication/slave/storage/push_rule.py | 3 | ||||
-rw-r--r-- | synapse/replication/slave/storage/pushers.py | 3 | ||||
-rw-r--r-- | synapse/replication/slave/storage/receipts.py | 11 | ||||
-rw-r--r-- | synapse/replication/slave/storage/room.py | 3 | ||||
-rw-r--r-- | synapse/storage/data_stores/main/cache.py | 8 | ||||
-rw-r--r-- | synapse/storage/data_stores/main/events_worker.py | 6 | ||||
-rw-r--r-- | synapse/storage/engines/postgres.py | 2 | ||||
-rw-r--r-- | tests/push/test_http.py | 352 | ||||
-rw-r--r-- | tox.ini | 1 |
20 files changed, 479 insertions, 112 deletions
diff --git a/changelog.d/7760.bugfix b/changelog.d/7760.bugfix new file mode 100644 index 0000000000..f6081f3d30 --- /dev/null +++ b/changelog.d/7760.bugfix @@ -0,0 +1 @@ +Fix incorrect error message when database CTYPE was set incorrectly. diff --git a/changelog.d/7765.misc b/changelog.d/7765.misc new file mode 100644 index 0000000000..fa9cfd24cb --- /dev/null +++ b/changelog.d/7765.misc @@ -0,0 +1 @@ +Send push notifications with a high or low priority depending upon whether they may generate user-observable effects. diff --git a/changelog.d/7768.misc b/changelog.d/7768.misc new file mode 100644 index 0000000000..dfb3d24c7d --- /dev/null +++ b/changelog.d/7768.misc @@ -0,0 +1 @@ +Use symbolic names for replication stream names. diff --git a/changelog.d/7769.misc b/changelog.d/7769.misc new file mode 100644 index 0000000000..2e200286ce --- /dev/null +++ b/changelog.d/7769.misc @@ -0,0 +1 @@ +Add early returns to `_check_for_soft_fail`. diff --git a/changelog.d/7770.misc b/changelog.d/7770.misc new file mode 100644 index 0000000000..5b864084be --- /dev/null +++ b/changelog.d/7770.misc @@ -0,0 +1 @@ +Fix up `synapse.handlers.federation` to pass mypy. diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 4dbd8e1d98..b5aaa244dd 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -19,8 +19,9 @@ import itertools import logging +from collections import Container from http import HTTPStatus -from typing import Dict, Iterable, List, Optional, Sequence, Tuple +from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union import attr from signedjson.key import decode_verify_key_bytes @@ -742,6 +743,9 @@ class FederationHandler(BaseHandler): # device and recognize the algorithm then we can work out the # exact key to expect. Otherwise check it matches any key we # have for that device. + + current_keys = [] # type: Container[str] + if device: keys = device.get("keys", {}).get("keys", {}) @@ -758,15 +762,15 @@ class FederationHandler(BaseHandler): current_keys = keys.values() elif device_id: # We don't have any keys for the device ID. - current_keys = [] + pass else: # The event didn't include a device ID, so we just look for # keys across all devices. - current_keys = ( + current_keys = [ key for device in cached_devices for key in device.get("keys", {}).get("keys", {}).values() - ) + ] # We now check that the sender key matches (one of) the expected # keys. @@ -1011,7 +1015,7 @@ class FederationHandler(BaseHandler): if e_type == EventTypes.Member and event.membership == Membership.JOIN ] - joined_domains = {} + joined_domains = {} # type: Dict[str, int] for u, d in joined_users: try: dom = get_domain_from_id(u) @@ -1277,14 +1281,15 @@ class FederationHandler(BaseHandler): try: # Try the host we successfully got a response to /make_join/ # request first. + host_list = list(target_hosts) try: - target_hosts.remove(origin) - target_hosts.insert(0, origin) + host_list.remove(origin) + host_list.insert(0, origin) except ValueError: pass ret = await self.federation_client.send_join( - target_hosts, event, room_version_obj + host_list, event, room_version_obj ) origin = ret["origin"] @@ -1584,13 +1589,14 @@ class FederationHandler(BaseHandler): # Try the host that we succesfully called /make_leave/ on first for # the /send_leave/ request. + host_list = list(target_hosts) try: - target_hosts.remove(origin) - target_hosts.insert(0, origin) + host_list.remove(origin) + host_list.insert(0, origin) except ValueError: pass - await self.federation_client.send_leave(target_hosts, event) + await self.federation_client.send_leave(host_list, event) context = await self.state_handler.compute_event_context(event) stream_id = await self.persist_events_and_notify([(event, context)]) @@ -1604,7 +1610,7 @@ class FederationHandler(BaseHandler): user_id: str, membership: str, content: JsonDict = {}, - params: Optional[Dict[str, str]] = None, + params: Optional[Dict[str, Union[str, Iterable[str]]]] = None, ) -> Tuple[str, EventBase, RoomVersion]: ( origin, @@ -2018,8 +2024,8 @@ class FederationHandler(BaseHandler): auth_events_ids = await self.auth.compute_auth_events( event, prev_state_ids, for_verification=True ) - auth_events = await self.store.get_events(auth_events_ids) - auth_events = {(e.type, e.state_key): e for e in auth_events.values()} + auth_events_x = await self.store.get_events(auth_events_ids) + auth_events = {(e.type, e.state_key): e for e in auth_events_x.values()} # This is a hack to fix some old rooms where the initial join event # didn't reference the create event in its auth events. @@ -2055,76 +2061,67 @@ class FederationHandler(BaseHandler): # For new (non-backfilled and non-outlier) events we check if the event # passes auth based on the current state. If it doesn't then we # "soft-fail" the event. - do_soft_fail_check = not backfilled and not event.internal_metadata.is_outlier() - if do_soft_fail_check: - extrem_ids = await self.store.get_latest_event_ids_in_room(event.room_id) - - extrem_ids = set(extrem_ids) - prev_event_ids = set(event.prev_event_ids()) - - if extrem_ids == prev_event_ids: - # If they're the same then the current state is the same as the - # state at the event, so no point rechecking auth for soft fail. - do_soft_fail_check = False - - if do_soft_fail_check: - room_version = await self.store.get_room_version_id(event.room_id) - room_version_obj = KNOWN_ROOM_VERSIONS[room_version] - - # Calculate the "current state". - if state is not None: - # If we're explicitly given the state then we won't have all the - # prev events, and so we have a gap in the graph. In this case - # we want to be a little careful as we might have been down for - # a while and have an incorrect view of the current state, - # however we still want to do checks as gaps are easy to - # maliciously manufacture. - # - # So we use a "current state" that is actually a state - # resolution across the current forward extremities and the - # given state at the event. This should correctly handle cases - # like bans, especially with state res v2. + if backfilled or event.internal_metadata.is_outlier(): + return - state_sets = await self.state_store.get_state_groups( - event.room_id, extrem_ids - ) - state_sets = list(state_sets.values()) - state_sets.append(state) - current_state_ids = await self.state_handler.resolve_events( - room_version, state_sets, event - ) - current_state_ids = { - k: e.event_id for k, e in current_state_ids.items() - } - else: - current_state_ids = await self.state_handler.get_current_state_ids( - event.room_id, latest_event_ids=extrem_ids - ) + extrem_ids = await self.store.get_latest_event_ids_in_room(event.room_id) + extrem_ids = set(extrem_ids) + prev_event_ids = set(event.prev_event_ids()) - logger.debug( - "Doing soft-fail check for %s: state %s", - event.event_id, - current_state_ids, + if extrem_ids == prev_event_ids: + # If they're the same then the current state is the same as the + # state at the event, so no point rechecking auth for soft fail. + return + + room_version = await self.store.get_room_version_id(event.room_id) + room_version_obj = KNOWN_ROOM_VERSIONS[room_version] + + # Calculate the "current state". + if state is not None: + # If we're explicitly given the state then we won't have all the + # prev events, and so we have a gap in the graph. In this case + # we want to be a little careful as we might have been down for + # a while and have an incorrect view of the current state, + # however we still want to do checks as gaps are easy to + # maliciously manufacture. + # + # So we use a "current state" that is actually a state + # resolution across the current forward extremities and the + # given state at the event. This should correctly handle cases + # like bans, especially with state res v2. + + state_sets = await self.state_store.get_state_groups( + event.room_id, extrem_ids + ) + state_sets = list(state_sets.values()) + state_sets.append(state) + current_state_ids = await self.state_handler.resolve_events( + room_version, state_sets, event + ) + current_state_ids = {k: e.event_id for k, e in current_state_ids.items()} + else: + current_state_ids = await self.state_handler.get_current_state_ids( + event.room_id, latest_event_ids=extrem_ids ) - # Now check if event pass auth against said current state - auth_types = auth_types_for_event(event) - current_state_ids = [ - e for k, e in current_state_ids.items() if k in auth_types - ] + logger.debug( + "Doing soft-fail check for %s: state %s", event.event_id, current_state_ids, + ) - current_auth_events = await self.store.get_events(current_state_ids) - current_auth_events = { - (e.type, e.state_key): e for e in current_auth_events.values() - } + # Now check if event pass auth against said current state + auth_types = auth_types_for_event(event) + current_state_ids = [e for k, e in current_state_ids.items() if k in auth_types] - try: - event_auth.check( - room_version_obj, event, auth_events=current_auth_events - ) - except AuthError as e: - logger.warning("Soft-failing %r because %s", event, e) - event.internal_metadata.soft_failed = True + current_auth_events = await self.store.get_events(current_state_ids) + current_auth_events = { + (e.type, e.state_key): e for e in current_auth_events.values() + } + + try: + event_auth.check(room_version_obj, event, auth_events=current_auth_events) + except AuthError as e: + logger.warning("Soft-failing %r because %s", event, e) + event.internal_metadata.soft_failed = True async def on_query_auth( self, origin, event_id, room_id, remote_auth_chain, rejects, missing @@ -2293,10 +2290,10 @@ class FederationHandler(BaseHandler): remote_auth_chain = await self.federation_client.get_event_auth( origin, event.room_id, event.event_id ) - except RequestSendFailed as e: + except RequestSendFailed as e1: # The other side isn't around or doesn't implement the # endpoint, so lets just bail out. - logger.info("Failed to get event auth from remote: %s", e) + logger.info("Failed to get event auth from remote: %s", e1) return context seen_remotes = await self.store.have_seen_events( @@ -2774,7 +2771,8 @@ class FederationHandler(BaseHandler): logger.debug("Checking auth on event %r", event.content) - last_exception = None + last_exception = None # type: Optional[Exception] + # for each public key in the 3pid invite event for public_key_object in self.hs.get_auth().get_public_keys(invite_event): try: @@ -2828,6 +2826,12 @@ class FederationHandler(BaseHandler): return except Exception as e: last_exception = e + + if last_exception is None: + # we can only get here if get_public_keys() returned an empty list + # TODO: make this better + raise RuntimeError("no public key in invite event") + raise last_exception async def _check_key_revocation(self, public_key, url): diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index ed60dbc1bf..2fac07593b 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -20,6 +20,7 @@ from prometheus_client import Counter from twisted.internet import defer from twisted.internet.error import AlreadyCalled, AlreadyCancelled +from synapse.api.constants import EventTypes from synapse.logging import opentracing from synapse.metrics.background_process_metrics import run_as_background_process from synapse.push import PusherConfigException @@ -305,12 +306,23 @@ class HttpPusher(object): @defer.inlineCallbacks def _build_notification_dict(self, event, tweaks, badge): + priority = "low" + if ( + event.type == EventTypes.Encrypted + or tweaks.get("highlight") + or tweaks.get("sound") + ): + # HACK send our push as high priority only if it generates a sound, highlight + # or may do so (i.e. is encrypted so has unknown effects). + priority = "high" + if self.data.get("format") == "event_id_only": d = { "notification": { "event_id": event.event_id, "room_id": event.room_id, "counts": {"unread": badge}, + "prio": priority, "devices": [ { "app_id": self.app_id, @@ -334,9 +346,8 @@ class HttpPusher(object): "room_id": event.room_id, "type": event.type, "sender": event.user_id, - "counts": { # -- we don't mark messages as read yet so - # we have no way of knowing - # Just set the badge to 1 until we have read receipts + "prio": priority, + "counts": { "unread": badge, # 'missed_calls': 2 }, diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py index 9db6c62bc7..525b94fd87 100644 --- a/synapse/replication/slave/storage/account_data.py +++ b/synapse/replication/slave/storage/account_data.py @@ -16,6 +16,7 @@ from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker +from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream from synapse.storage.data_stores.main.account_data import AccountDataWorkerStore from synapse.storage.data_stores.main.tags import TagsWorkerStore from synapse.storage.database import Database @@ -39,12 +40,12 @@ class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlaved return self._account_data_id_gen.get_current_token() def process_replication_rows(self, stream_name, instance_name, token, rows): - if stream_name == "tag_account_data": + if stream_name == TagAccountDataStream.NAME: self._account_data_id_gen.advance(token) for row in rows: self.get_tags_for_user.invalidate((row.user_id,)) self._account_data_stream_cache.entity_has_changed(row.user_id, token) - elif stream_name == "account_data": + elif stream_name == AccountDataStream.NAME: self._account_data_id_gen.advance(token) for row in rows: if not row.room_id: diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py index 6e7fd259d4..bd394f6b00 100644 --- a/synapse/replication/slave/storage/deviceinbox.py +++ b/synapse/replication/slave/storage/deviceinbox.py @@ -15,6 +15,7 @@ from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker +from synapse.replication.tcp.streams import ToDeviceStream from synapse.storage.data_stores.main.deviceinbox import DeviceInboxWorkerStore from synapse.storage.database import Database from synapse.util.caches.expiringcache import ExpiringCache @@ -44,7 +45,7 @@ class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore): ) def process_replication_rows(self, stream_name, instance_name, token, rows): - if stream_name == "to_device": + if stream_name == ToDeviceStream.NAME: self._device_inbox_id_gen.advance(token) for row in rows: if row.entity.startswith("@"): diff --git a/synapse/replication/slave/storage/groups.py b/synapse/replication/slave/storage/groups.py index 1851e7d525..5d210fa3a1 100644 --- a/synapse/replication/slave/storage/groups.py +++ b/synapse/replication/slave/storage/groups.py @@ -15,6 +15,7 @@ from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker +from synapse.replication.tcp.streams import GroupServerStream from synapse.storage.data_stores.main.group_server import GroupServerWorkerStore from synapse.storage.database import Database from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -38,7 +39,7 @@ class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore): return self._group_updates_id_gen.get_current_token() def process_replication_rows(self, stream_name, instance_name, token, rows): - if stream_name == "groups": + if stream_name == GroupServerStream.NAME: self._group_updates_id_gen.advance(token) for row in rows: self._group_updates_stream_cache.entity_has_changed(row.user_id, token) diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py index 4e0124842d..2938cb8e43 100644 --- a/synapse/replication/slave/storage/presence.py +++ b/synapse/replication/slave/storage/presence.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from synapse.replication.tcp.streams import PresenceStream from synapse.storage import DataStore from synapse.storage.data_stores.main.presence import PresenceStore from synapse.storage.database import Database @@ -42,7 +43,7 @@ class SlavedPresenceStore(BaseSlavedStore): return self._presence_id_gen.get_current_token() def process_replication_rows(self, stream_name, instance_name, token, rows): - if stream_name == "presence": + if stream_name == PresenceStream.NAME: self._presence_id_gen.advance(token) for row in rows: self.presence_stream_cache.entity_has_changed(row.user_id, token) diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py index 6adb19463a..23ec1c5b11 100644 --- a/synapse/replication/slave/storage/push_rule.py +++ b/synapse/replication/slave/storage/push_rule.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from synapse.replication.tcp.streams import PushRulesStream from synapse.storage.data_stores.main.push_rule import PushRulesWorkerStore from .events import SlavedEventStore @@ -30,7 +31,7 @@ class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore): return self._push_rules_stream_id_gen.get_current_token() def process_replication_rows(self, stream_name, instance_name, token, rows): - if stream_name == "push_rules": + if stream_name == PushRulesStream.NAME: self._push_rules_stream_id_gen.advance(token) for row in rows: self.get_push_rules_for_user.invalidate((row.user_id,)) diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py index cb78b49acb..ff449f3658 100644 --- a/synapse/replication/slave/storage/pushers.py +++ b/synapse/replication/slave/storage/pushers.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from synapse.replication.tcp.streams import PushersStream from synapse.storage.data_stores.main.pusher import PusherWorkerStore from synapse.storage.database import Database @@ -32,6 +33,6 @@ class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore): return self._pushers_id_gen.get_current_token() def process_replication_rows(self, stream_name, instance_name, token, rows): - if stream_name == "pushers": + if stream_name == PushersStream.NAME: self._pushers_id_gen.advance(token) return super().process_replication_rows(stream_name, instance_name, token, rows) diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py index be716cc558..6982686eb5 100644 --- a/synapse/replication/slave/storage/receipts.py +++ b/synapse/replication/slave/storage/receipts.py @@ -14,20 +14,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +from synapse.replication.tcp.streams import ReceiptsStream from synapse.storage.data_stores.main.receipts import ReceiptsWorkerStore from synapse.storage.database import Database from ._base import BaseSlavedStore from ._slaved_id_tracker import SlavedIdTracker -# So, um, we want to borrow a load of functions intended for reading from -# a DataStore, but we don't want to take functions that either write to the -# DataStore or are cached and don't have cache invalidation logic. -# -# Rather than write duplicate versions of those functions, or lift them to -# a common base class, we going to grab the underlying __func__ object from -# the method descriptor on the DataStore and chuck them into our class. - class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore): def __init__(self, database: Database, db_conn, hs): @@ -52,7 +45,7 @@ class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore): self.get_receipts_for_room.invalidate((room_id, receipt_type)) def process_replication_rows(self, stream_name, instance_name, token, rows): - if stream_name == "receipts": + if stream_name == ReceiptsStream.NAME: self._receipts_id_gen.advance(token) for row in rows: self.invalidate_caches_for_receipt( diff --git a/synapse/replication/slave/storage/room.py b/synapse/replication/slave/storage/room.py index 8873bf37e5..8710207ada 100644 --- a/synapse/replication/slave/storage/room.py +++ b/synapse/replication/slave/storage/room.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from synapse.replication.tcp.streams import PublicRoomsStream from synapse.storage.data_stores.main.room import RoomWorkerStore from synapse.storage.database import Database @@ -31,7 +32,7 @@ class RoomStore(RoomWorkerStore, BaseSlavedStore): return self._public_room_id_gen.get_current_token() def process_replication_rows(self, stream_name, instance_name, token, rows): - if stream_name == "public_rooms": + if stream_name == PublicRoomsStream.NAME: self._public_room_id_gen.advance(token) return super().process_replication_rows(stream_name, instance_name, token, rows) diff --git a/synapse/storage/data_stores/main/cache.py b/synapse/storage/data_stores/main/cache.py index eac5a4e55b..d30766e543 100644 --- a/synapse/storage/data_stores/main/cache.py +++ b/synapse/storage/data_stores/main/cache.py @@ -19,7 +19,9 @@ import logging from typing import Any, Iterable, Optional, Tuple from synapse.api.constants import EventTypes +from synapse.replication.tcp.streams import BackfillStream, CachesStream from synapse.replication.tcp.streams.events import ( + EventsStream, EventsStreamCurrentStateRow, EventsStreamEventRow, ) @@ -71,10 +73,10 @@ class CacheInvalidationWorkerStore(SQLBaseStore): ) def process_replication_rows(self, stream_name, instance_name, token, rows): - if stream_name == "events": + if stream_name == EventsStream.NAME: for row in rows: self._process_event_stream_row(token, row) - elif stream_name == "backfill": + elif stream_name == BackfillStream.NAME: for row in rows: self._invalidate_caches_for_event( -token, @@ -86,7 +88,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore): row.relates_to, backfilled=True, ) - elif stream_name == "caches": + elif stream_name == CachesStream.NAME: if self._cache_id_gen: self._cache_id_gen.advance(instance_name, token) diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py index a48c7a96ca..47a3e63589 100644 --- a/synapse/storage/data_stores/main/events_worker.py +++ b/synapse/storage/data_stores/main/events_worker.py @@ -38,6 +38,8 @@ from synapse.events.utils import prune_event from synapse.logging.context import PreserveLoggingContext, current_context from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker +from synapse.replication.tcp.streams import BackfillStream +from synapse.replication.tcp.streams.events import EventsStream from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause from synapse.storage.database import Database from synapse.storage.util.id_generators import StreamIdGenerator @@ -113,9 +115,9 @@ class EventsWorkerStore(SQLBaseStore): self._event_fetch_ongoing = 0 def process_replication_rows(self, stream_name, instance_name, token, rows): - if stream_name == "events": + if stream_name == EventsStream.NAME: self._stream_id_gen.advance(token) - elif stream_name == "backfill": + elif stream_name == BackfillStream.NAME: self._backfill_id_gen.advance(-token) super().process_replication_rows(stream_name, instance_name, token, rows) diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py index 6c7d08a6f2..a31588080d 100644 --- a/synapse/storage/engines/postgres.py +++ b/synapse/storage/engines/postgres.py @@ -92,7 +92,7 @@ class PostgresEngine(BaseDatabaseEngine): errors.append(" - 'COLLATE' is set to %r. Should be 'C'" % (collation,)) if ctype != "C": - errors.append(" - 'CTYPE' is set to %r. Should be 'C'" % (collation,)) + errors.append(" - 'CTYPE' is set to %r. Should be 'C'" % (ctype,)) if errors: raise IncorrectDatabaseSetup( diff --git a/tests/push/test_http.py b/tests/push/test_http.py index baf9c785f4..b567868b02 100644 --- a/tests/push/test_http.py +++ b/tests/push/test_http.py @@ -25,7 +25,6 @@ from tests.unittest import HomeserverTestCase class HTTPPusherTests(HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets_for_client_rest_resource, room.register_servlets, @@ -35,7 +34,6 @@ class HTTPPusherTests(HomeserverTestCase): hijack_auth = False def make_homeserver(self, reactor, clock): - self.push_attempts = [] m = Mock() @@ -90,9 +88,6 @@ class HTTPPusherTests(HomeserverTestCase): # Create a room room = self.helper.create_room_as(user_id, tok=access_token) - # Invite the other person - self.helper.invite(room=room, src=user_id, tok=access_token, targ=other_user_id) - # The other user joins self.helper.join(room=room, user=other_user_id, tok=other_access_token) @@ -157,3 +152,350 @@ class HTTPPusherTests(HomeserverTestCase): pushers = list(pushers) self.assertEqual(len(pushers), 1) self.assertTrue(pushers[0]["last_stream_ordering"] > last_stream_ordering) + + def test_sends_high_priority_for_encrypted(self): + """ + The HTTP pusher will send pushes at high priority if they correspond + to an encrypted message. + This will happen both in 1:1 rooms and larger rooms. + """ + # Register the user who gets notified + user_id = self.register_user("user", "pass") + access_token = self.login("user", "pass") + + # Register the user who sends the message + other_user_id = self.register_user("otheruser", "pass") + other_access_token = self.login("otheruser", "pass") + + # Register a third user + yet_another_user_id = self.register_user("yetanotheruser", "pass") + yet_another_access_token = self.login("yetanotheruser", "pass") + + # Create a room + room = self.helper.create_room_as(user_id, tok=access_token) + + # The other user joins + self.helper.join(room=room, user=other_user_id, tok=other_access_token) + + # Register the pusher + user_tuple = self.get_success( + self.hs.get_datastore().get_user_by_access_token(access_token) + ) + token_id = user_tuple["token_id"] + + self.get_success( + self.hs.get_pusherpool().add_pusher( + user_id=user_id, + access_token=token_id, + kind="http", + app_id="m.http", + app_display_name="HTTP Push Notifications", + device_display_name="pushy push", + pushkey="a@example.com", + lang=None, + data={"url": "example.com"}, + ) + ) + + # Send an encrypted event + # I know there'd normally be set-up of an encrypted room first + # but this will do for our purposes + self.helper.send_event( + room, + "m.room.encrypted", + content={ + "algorithm": "m.megolm.v1.aes-sha2", + "sender_key": "6lImKbzK51MzWLwHh8tUM3UBBSBrLlgup/OOCGTvumM", + "ciphertext": "AwgAErABoRxwpMipdgiwXgu46rHiWQ0DmRj0qUlPrMraBUDk" + "leTnJRljpuc7IOhsYbLY3uo2WI0ab/ob41sV+3JEIhODJPqH" + "TK7cEZaIL+/up9e+dT9VGF5kRTWinzjkeqO8FU5kfdRjm+3w" + "0sy3o1OCpXXCfO+faPhbV/0HuK4ndx1G+myNfK1Nk/CxfMcT" + "BT+zDS/Df/QePAHVbrr9uuGB7fW8ogW/ulnydgZPRluusFGv" + "J3+cg9LoPpZPAmv5Me3ec7NtdlfN0oDZ0gk3TiNkkhsxDG9Y" + "YcNzl78USI0q8+kOV26Bu5dOBpU4WOuojXZHJlP5lMgdzLLl" + "EQ0", + "session_id": "IigqfNWLL+ez/Is+Duwp2s4HuCZhFG9b9CZKTYHtQ4A", + "device_id": "AHQDUSTAAA", + }, + tok=other_access_token, + ) + + # Advance time a bit, so the pusher will register something has happened + self.pump() + + # Make the push succeed + self.push_attempts[0][0].callback({}) + self.pump() + + # Check our push made it with high priority + self.assertEqual(len(self.push_attempts), 1) + self.assertEqual(self.push_attempts[0][1], "example.com") + self.assertEqual(self.push_attempts[0][2]["notification"]["prio"], "high") + + # Add yet another person — we want to make this room not a 1:1 + # (as encrypted messages in a 1:1 currently have tweaks applied + # so it doesn't properly exercise the condition of all encrypted + # messages need to be high). + self.helper.join( + room=room, user=yet_another_user_id, tok=yet_another_access_token + ) + + # Check no push notifications are sent regarding the membership changes + # (that would confuse the test) + self.pump() + self.assertEqual(len(self.push_attempts), 1) + + # Send another encrypted event + self.helper.send_event( + room, + "m.room.encrypted", + content={ + "ciphertext": "AwgAEoABtEuic/2DF6oIpNH+q/PonzlhXOVho8dTv0tzFr5m" + "9vTo50yabx3nxsRlP2WxSqa8I07YftP+EKWCWJvTkg6o7zXq" + "6CK+GVvLQOVgK50SfvjHqJXN+z1VEqj+5mkZVN/cAgJzoxcH" + "zFHkwDPJC8kQs47IHd8EO9KBUK4v6+NQ1uE/BIak4qAf9aS/" + "kI+f0gjn9IY9K6LXlah82A/iRyrIrxkCkE/n0VfvLhaWFecC" + "sAWTcMLoF6fh1Jpke95mljbmFSpsSd/eEQw", + "device_id": "SRCFTWTHXO", + "session_id": "eMA+bhGczuTz1C5cJR1YbmrnnC6Goni4lbvS5vJ1nG4", + "algorithm": "m.megolm.v1.aes-sha2", + "sender_key": "rC/XSIAiYrVGSuaHMop8/pTZbku4sQKBZwRwukgnN1c", + }, + tok=other_access_token, + ) + + # Advance time a bit, so the pusher will register something has happened + self.pump() + self.assertEqual(len(self.push_attempts), 2) + self.assertEqual(self.push_attempts[1][1], "example.com") + self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "high") + + def test_sends_high_priority_for_one_to_one_only(self): + """ + The HTTP pusher will send pushes at high priority if they correspond + to a message in a one-to-one room. + """ + # Register the user who gets notified + user_id = self.register_user("user", "pass") + access_token = self.login("user", "pass") + + # Register the user who sends the message + other_user_id = self.register_user("otheruser", "pass") + other_access_token = self.login("otheruser", "pass") + + # Register a third user + yet_another_user_id = self.register_user("yetanotheruser", "pass") + yet_another_access_token = self.login("yetanotheruser", "pass") + + # Create a room + room = self.helper.create_room_as(user_id, tok=access_token) + + # The other user joins + self.helper.join(room=room, user=other_user_id, tok=other_access_token) + + # Register the pusher + user_tuple = self.get_success( + self.hs.get_datastore().get_user_by_access_token(access_token) + ) + token_id = user_tuple["token_id"] + + self.get_success( + self.hs.get_pusherpool().add_pusher( + user_id=user_id, + access_token=token_id, + kind="http", + app_id="m.http", + app_display_name="HTTP Push Notifications", + device_display_name="pushy push", + pushkey="a@example.com", + lang=None, + data={"url": "example.com"}, + ) + ) + + # Send a message + self.helper.send(room, body="Hi!", tok=other_access_token) + + # Advance time a bit, so the pusher will register something has happened + self.pump() + + # Make the push succeed + self.push_attempts[0][0].callback({}) + self.pump() + + # Check our push made it with high priority — this is a one-to-one room + self.assertEqual(len(self.push_attempts), 1) + self.assertEqual(self.push_attempts[0][1], "example.com") + self.assertEqual(self.push_attempts[0][2]["notification"]["prio"], "high") + + # Yet another user joins + self.helper.join( + room=room, user=yet_another_user_id, tok=yet_another_access_token + ) + + # Check no push notifications are sent regarding the membership changes + # (that would confuse the test) + self.pump() + self.assertEqual(len(self.push_attempts), 1) + + # Send another event + self.helper.send(room, body="Welcome!", tok=other_access_token) + + # Advance time a bit, so the pusher will register something has happened + self.pump() + self.assertEqual(len(self.push_attempts), 2) + self.assertEqual(self.push_attempts[1][1], "example.com") + + # check that this is low-priority + self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low") + + def test_sends_high_priority_for_mention(self): + """ + The HTTP pusher will send pushes at high priority if they correspond + to a message containing the user's display name. + """ + # Register the user who gets notified + user_id = self.register_user("user", "pass") + access_token = self.login("user", "pass") + + # Register the user who sends the message + other_user_id = self.register_user("otheruser", "pass") + other_access_token = self.login("otheruser", "pass") + + # Register a third user + yet_another_user_id = self.register_user("yetanotheruser", "pass") + yet_another_access_token = self.login("yetanotheruser", "pass") + + # Create a room + room = self.helper.create_room_as(user_id, tok=access_token) + + # The other users join + self.helper.join(room=room, user=other_user_id, tok=other_access_token) + self.helper.join( + room=room, user=yet_another_user_id, tok=yet_another_access_token + ) + + # Register the pusher + user_tuple = self.get_success( + self.hs.get_datastore().get_user_by_access_token(access_token) + ) + token_id = user_tuple["token_id"] + + self.get_success( + self.hs.get_pusherpool().add_pusher( + user_id=user_id, + access_token=token_id, + kind="http", + app_id="m.http", + app_display_name="HTTP Push Notifications", + device_display_name="pushy push", + pushkey="a@example.com", + lang=None, + data={"url": "example.com"}, + ) + ) + + # Send a message + self.helper.send(room, body="Oh, user, hello!", tok=other_access_token) + + # Advance time a bit, so the pusher will register something has happened + self.pump() + + # Make the push succeed + self.push_attempts[0][0].callback({}) + self.pump() + + # Check our push made it with high priority + self.assertEqual(len(self.push_attempts), 1) + self.assertEqual(self.push_attempts[0][1], "example.com") + self.assertEqual(self.push_attempts[0][2]["notification"]["prio"], "high") + + # Send another event, this time with no mention + self.helper.send(room, body="Are you there?", tok=other_access_token) + + # Advance time a bit, so the pusher will register something has happened + self.pump() + self.assertEqual(len(self.push_attempts), 2) + self.assertEqual(self.push_attempts[1][1], "example.com") + + # check that this is low-priority + self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low") + + def test_sends_high_priority_for_atroom(self): + """ + The HTTP pusher will send pushes at high priority if they correspond + to a message that contains @room. + """ + # Register the user who gets notified + user_id = self.register_user("user", "pass") + access_token = self.login("user", "pass") + + # Register the user who sends the message + other_user_id = self.register_user("otheruser", "pass") + other_access_token = self.login("otheruser", "pass") + + # Register a third user + yet_another_user_id = self.register_user("yetanotheruser", "pass") + yet_another_access_token = self.login("yetanotheruser", "pass") + + # Create a room (as other_user so the power levels are compatible with + # other_user sending @room). + room = self.helper.create_room_as(other_user_id, tok=other_access_token) + + # The other users join + self.helper.join(room=room, user=user_id, tok=access_token) + self.helper.join( + room=room, user=yet_another_user_id, tok=yet_another_access_token + ) + + # Register the pusher + user_tuple = self.get_success( + self.hs.get_datastore().get_user_by_access_token(access_token) + ) + token_id = user_tuple["token_id"] + + self.get_success( + self.hs.get_pusherpool().add_pusher( + user_id=user_id, + access_token=token_id, + kind="http", + app_id="m.http", + app_display_name="HTTP Push Notifications", + device_display_name="pushy push", + pushkey="a@example.com", + lang=None, + data={"url": "example.com"}, + ) + ) + + # Send a message + self.helper.send( + room, + body="@room eeek! There's a spider on the table!", + tok=other_access_token, + ) + + # Advance time a bit, so the pusher will register something has happened + self.pump() + + # Make the push succeed + self.push_attempts[0][0].callback({}) + self.pump() + + # Check our push made it with high priority + self.assertEqual(len(self.push_attempts), 1) + self.assertEqual(self.push_attempts[0][1], "example.com") + self.assertEqual(self.push_attempts[0][2]["notification"]["prio"], "high") + + # Send another event, this time as someone without the power of @room + self.helper.send( + room, body="@room the spider is gone", tok=yet_another_access_token + ) + + # Advance time a bit, so the pusher will register something has happened + self.pump() + self.assertEqual(len(self.push_attempts), 2) + self.assertEqual(self.push_attempts[1][1], "example.com") + + # check that this is low-priority + self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low") diff --git a/tox.ini b/tox.ini index 812fbff200..ab6557f15e 100644 --- a/tox.ini +++ b/tox.ini @@ -184,6 +184,7 @@ commands = mypy \ synapse/handlers/auth.py \ synapse/handlers/cas_handler.py \ synapse/handlers/directory.py \ + synapse/handlers/federation.py \ synapse/handlers/oidc_handler.py \ synapse/handlers/presence.py \ synapse/handlers/room_member.py \ |