summary refs log tree commit diff
path: root/synapse/handlers
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers')
-rw-r--r--synapse/handlers/cas_handler.py3
-rw-r--r--synapse/handlers/device.py1
-rw-r--r--synapse/handlers/devicemessage.py25
-rw-r--r--synapse/handlers/e2e_room_keys.py1
-rw-r--r--synapse/handlers/federation.py9
-rw-r--r--synapse/handlers/message.py4
-rw-r--r--synapse/handlers/pagination.py24
-rw-r--r--synapse/handlers/presence.py29
-rw-r--r--synapse/handlers/profile.py8
-rw-r--r--synapse/handlers/room.py4
-rw-r--r--synapse/handlers/room_member.py7
-rw-r--r--synapse/handlers/typing.py40
12 files changed, 91 insertions, 64 deletions
diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py
index 64aaa1335c..76f213723a 100644
--- a/synapse/handlers/cas_handler.py
+++ b/synapse/handlers/cas_handler.py
@@ -14,11 +14,10 @@
 # limitations under the License.
 
 import logging
+import urllib
 import xml.etree.ElementTree as ET
 from typing import Dict, Optional, Tuple
 
-from six.moves import urllib
-
 from twisted.web.client import PartialDownloadError
 
 from synapse.api.errors import Codes, LoginError
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 83f8fa1180..31346b56c3 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -691,6 +691,7 @@ class DeviceListUpdater(object):
 
         return False
 
+    @trace
     @defer.inlineCallbacks
     def _maybe_retry_device_resync(self):
         """Retry to resync device lists that are out of sync, except if another retry is
diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py
index 05c4b3eec0..610b08d00b 100644
--- a/synapse/handlers/devicemessage.py
+++ b/synapse/handlers/devicemessage.py
@@ -18,8 +18,6 @@ from typing import Any, Dict
 
 from canonicaljson import json
 
-from twisted.internet import defer
-
 from synapse.api.errors import SynapseError
 from synapse.logging.context import run_in_background
 from synapse.logging.opentracing import (
@@ -51,8 +49,7 @@ class DeviceMessageHandler(object):
 
         self._device_list_updater = hs.get_device_handler().device_list_updater
 
-    @defer.inlineCallbacks
-    def on_direct_to_device_edu(self, origin, content):
+    async def on_direct_to_device_edu(self, origin, content):
         local_messages = {}
         sender_user_id = content["sender"]
         if origin != get_domain_from_id(sender_user_id):
@@ -82,11 +79,11 @@ class DeviceMessageHandler(object):
             }
             local_messages[user_id] = messages_by_device
 
-            yield self._check_for_unknown_devices(
+            await self._check_for_unknown_devices(
                 message_type, sender_user_id, by_device
             )
 
-        stream_id = yield self.store.add_messages_from_remote_to_device_inbox(
+        stream_id = await self.store.add_messages_from_remote_to_device_inbox(
             origin, message_id, local_messages
         )
 
@@ -94,14 +91,13 @@ class DeviceMessageHandler(object):
             "to_device_key", stream_id, users=local_messages.keys()
         )
 
-    @defer.inlineCallbacks
-    def _check_for_unknown_devices(
+    async def _check_for_unknown_devices(
         self,
         message_type: str,
         sender_user_id: str,
         by_device: Dict[str, Dict[str, Any]],
     ):
-        """Checks inbound device messages for unkown remote devices, and if
+        """Checks inbound device messages for unknown remote devices, and if
         found marks the remote cache for the user as stale.
         """
 
@@ -115,7 +111,7 @@ class DeviceMessageHandler(object):
             requesting_device_ids.add(device_id)
 
         # Check if we are tracking the devices of the remote user.
-        room_ids = yield self.store.get_rooms_for_user(sender_user_id)
+        room_ids = await self.store.get_rooms_for_user(sender_user_id)
         if not room_ids:
             logger.info(
                 "Received device message from remote device we don't"
@@ -127,7 +123,7 @@ class DeviceMessageHandler(object):
 
         # If we are tracking check that we know about the sending
         # devices.
-        cached_devices = yield self.store.get_cached_devices_for_user(sender_user_id)
+        cached_devices = await self.store.get_cached_devices_for_user(sender_user_id)
 
         unknown_devices = requesting_device_ids - set(cached_devices)
         if unknown_devices:
@@ -136,15 +132,14 @@ class DeviceMessageHandler(object):
                 sender_user_id,
                 unknown_devices,
             )
-            yield self.store.mark_remote_user_device_cache_as_stale(sender_user_id)
+            await self.store.mark_remote_user_device_cache_as_stale(sender_user_id)
 
             # Immediately attempt a resync in the background
             run_in_background(
                 self._device_list_updater.user_device_resync, sender_user_id
             )
 
-    @defer.inlineCallbacks
-    def send_device_message(self, sender_user_id, message_type, messages):
+    async def send_device_message(self, sender_user_id, message_type, messages):
         set_tag("number_of_messages", len(messages))
         set_tag("sender", sender_user_id)
         local_messages = {}
@@ -183,7 +178,7 @@ class DeviceMessageHandler(object):
                 }
 
         log_kv({"local_messages": local_messages})
-        stream_id = yield self.store.add_messages_to_device_inbox(
+        stream_id = await self.store.add_messages_to_device_inbox(
             local_messages, remote_edu_contents
         )
 
diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py
index 2efea801bc..f55470a707 100644
--- a/synapse/handlers/e2e_room_keys.py
+++ b/synapse/handlers/e2e_room_keys.py
@@ -349,6 +349,7 @@ class E2eRoomKeysHandler(object):
                     raise
 
             res["count"] = yield self.store.count_e2e_room_keys(user_id, res["version"])
+            res["etag"] = str(res["etag"])
             return res
 
     @trace
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index d6038d9995..873f6bc39f 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -19,10 +19,9 @@
 
 import itertools
 import logging
+from http import HTTPStatus
 from typing import Dict, Iterable, List, Optional, Sequence, Tuple
 
-from six.moves import http_client, zip
-
 import attr
 from signedjson.key import decode_verify_key_bytes
 from signedjson.sign import verify_signed_json
@@ -1194,7 +1193,7 @@ class FederationHandler(BaseHandler):
                 ev.event_id,
                 len(ev.prev_event_ids()),
             )
-            raise SynapseError(http_client.BAD_REQUEST, "Too many prev_events")
+            raise SynapseError(HTTPStatus.BAD_REQUEST, "Too many prev_events")
 
         if len(ev.auth_event_ids()) > 10:
             logger.warning(
@@ -1202,7 +1201,7 @@ class FederationHandler(BaseHandler):
                 ev.event_id,
                 len(ev.auth_event_ids()),
             )
-            raise SynapseError(http_client.BAD_REQUEST, "Too many auth_events")
+            raise SynapseError(HTTPStatus.BAD_REQUEST, "Too many auth_events")
 
     async def send_invite(self, target_host, event):
         """ Sends the invite to the remote server for signing.
@@ -1545,7 +1544,7 @@ class FederationHandler(BaseHandler):
 
         # block any attempts to invite the server notices mxid
         if event.state_key == self._server_notices_mxid:
-            raise SynapseError(http_client.FORBIDDEN, "Cannot invite this user")
+            raise SynapseError(HTTPStatus.FORBIDDEN, "Cannot invite this user")
 
         # keep a record of the room version, if we don't yet know it.
         # (this may get overwritten if we later get a different room version in a
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 354da9a3b5..200127d291 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -17,8 +17,6 @@
 import logging
 from typing import Optional, Tuple
 
-from six import string_types
-
 from canonicaljson import encode_canonical_json, json
 
 from twisted.internet import defer
@@ -715,7 +713,7 @@ class EventCreationHandler(object):
 
             spam_error = self.spam_checker.check_event_for_spam(event)
             if spam_error:
-                if not isinstance(spam_error, string_types):
+                if not isinstance(spam_error, str):
                     spam_error = "Spam is not permitted here"
                 raise SynapseError(403, spam_error, Codes.FORBIDDEN)
 
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 7fbc229502..da06582d4b 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -15,7 +15,6 @@
 # limitations under the License.
 import logging
 
-from twisted.internet import defer
 from twisted.python.failure import Failure
 
 from synapse.api.constants import EventTypes, Membership
@@ -97,8 +96,7 @@ class PaginationHandler(object):
                     job["longest_max_lifetime"],
                 )
 
-    @defer.inlineCallbacks
-    def purge_history_for_rooms_in_range(self, min_ms, max_ms):
+    async def purge_history_for_rooms_in_range(self, min_ms, max_ms):
         """Purge outdated events from rooms within the given retention range.
 
         If a default retention policy is defined in the server's configuration and its
@@ -137,7 +135,7 @@ class PaginationHandler(object):
             include_null,
         )
 
-        rooms = yield self.store.get_rooms_for_retention_period_in_range(
+        rooms = await self.store.get_rooms_for_retention_period_in_range(
             min_ms, max_ms, include_null
         )
 
@@ -165,9 +163,9 @@ class PaginationHandler(object):
             # Figure out what token we should start purging at.
             ts = self.clock.time_msec() - max_lifetime
 
-            stream_ordering = yield self.store.find_first_stream_ordering_after_ts(ts)
+            stream_ordering = await self.store.find_first_stream_ordering_after_ts(ts)
 
-            r = yield self.store.get_room_event_before_stream_ordering(
+            r = await self.store.get_room_event_before_stream_ordering(
                 room_id, stream_ordering,
             )
             if not r:
@@ -227,8 +225,7 @@ class PaginationHandler(object):
         )
         return purge_id
 
-    @defer.inlineCallbacks
-    def _purge_history(self, purge_id, room_id, token, delete_local_events):
+    async def _purge_history(self, purge_id, room_id, token, delete_local_events):
         """Carry out a history purge on a room.
 
         Args:
@@ -237,14 +234,11 @@ class PaginationHandler(object):
             token (str): topological token to delete events before
             delete_local_events (bool): True to delete local events as well as
                 remote ones
-
-        Returns:
-            Deferred
         """
         self._purges_in_progress_by_room.add(room_id)
         try:
-            with (yield self.pagination_lock.write(room_id)):
-                yield self.storage.purge_events.purge_history(
+            with await self.pagination_lock.write(room_id):
+                await self.storage.purge_events.purge_history(
                     room_id, token, delete_local_events
                 )
             logger.info("[purge] complete")
@@ -282,9 +276,7 @@ class PaginationHandler(object):
             await self.store.get_room_version_id(room_id)
 
             # first check that we have no users in this room
-            joined = await defer.maybeDeferred(
-                self.store.is_host_joined, room_id, self._server_name
-            )
+            joined = await self.store.is_host_joined(room_id, self._server_name)
 
             if joined:
                 raise SynapseError(400, "Users are still joined to this room")
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 2e8914be14..d2f25ae12a 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -25,7 +25,7 @@ The methods that define policy are:
 import abc
 import logging
 from contextlib import contextmanager
-from typing import Dict, Iterable, List, Set
+from typing import Dict, Iterable, List, Set, Tuple
 
 from prometheus_client import Counter
 from typing_extensions import ContextManager
@@ -773,7 +773,9 @@ class PresenceHandler(BasePresenceHandler):
 
         return False
 
-    async def get_all_presence_updates(self, last_id, current_id, limit):
+    async def get_all_presence_updates(
+        self, instance_name: str, last_id: int, current_id: int, limit: int
+    ) -> Tuple[List[Tuple[int, list]], int, bool]:
         """
         Gets a list of presence update rows from between the given stream ids.
         Each row has:
@@ -785,10 +787,31 @@ class PresenceHandler(BasePresenceHandler):
         - last_user_sync_ts(int)
         - status_msg(int)
         - currently_active(int)
+
+        Args:
+            instance_name: The writer we want to fetch updates from. Unused
+                here since there is only ever one writer.
+            last_id: The token to fetch updates from. Exclusive.
+            current_id: The token to fetch updates up to. Inclusive.
+            limit: The requested limit for the number of rows to return. The
+                function may return more or fewer rows.
+
+        Returns:
+            A tuple consisting of: the updates, a token to use to fetch
+            subsequent updates, and whether we returned fewer rows than exists
+            between the requested tokens due to the limit.
+
+            The token returned can be used in a subsequent call to this
+            function to get further updatees.
+
+            The updates are a list of 2-tuples of stream ID and the row data
         """
+
         # TODO(markjh): replicate the unpersisted changes.
         # This could use the in-memory stores for recent changes.
-        rows = await self.store.get_all_presence_updates(last_id, current_id, limit)
+        rows = await self.store.get_all_presence_updates(
+            instance_name, last_id, current_id, limit
+        )
         return rows
 
     def notify_new_event(self):
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index 302efc1b9a..4b1e3073a8 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -15,8 +15,6 @@
 
 import logging
 
-from six import raise_from
-
 from twisted.internet import defer
 
 from synapse.api.errors import (
@@ -84,7 +82,7 @@ class BaseProfileHandler(BaseHandler):
                 )
                 return result
             except RequestSendFailed as e:
-                raise_from(SynapseError(502, "Failed to fetch profile"), e)
+                raise SynapseError(502, "Failed to fetch profile") from e
             except HttpResponseException as e:
                 raise e.to_synapse_error()
 
@@ -135,7 +133,7 @@ class BaseProfileHandler(BaseHandler):
                     ignore_backoff=True,
                 )
             except RequestSendFailed as e:
-                raise_from(SynapseError(502, "Failed to fetch profile"), e)
+                raise SynapseError(502, "Failed to fetch profile") from e
             except HttpResponseException as e:
                 raise e.to_synapse_error()
 
@@ -212,7 +210,7 @@ class BaseProfileHandler(BaseHandler):
                     ignore_backoff=True,
                 )
             except RequestSendFailed as e:
-                raise_from(SynapseError(502, "Failed to fetch profile"), e)
+                raise SynapseError(502, "Failed to fetch profile") from e
             except HttpResponseException as e:
                 raise e.to_synapse_error()
 
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index f7401373ca..950a84acd0 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -24,8 +24,6 @@ import string
 from collections import OrderedDict
 from typing import Tuple
 
-from six import string_types
-
 from synapse.api.constants import (
     EventTypes,
     JoinRules,
@@ -595,7 +593,7 @@ class RoomCreationHandler(BaseHandler):
             "room_version", self.config.default_room_version.identifier
         )
 
-        if not isinstance(room_version_id, string_types):
+        if not isinstance(room_version_id, str):
             raise SynapseError(400, "room_version must be a string", Codes.BAD_JSON)
 
         room_version = KNOWN_ROOM_VERSIONS.get(room_version_id)
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 0f7af982f0..27c479da9e 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -17,10 +17,9 @@
 
 import abc
 import logging
+from http import HTTPStatus
 from typing import Dict, Iterable, List, Optional, Tuple
 
-from six.moves import http_client
-
 from synapse import types
 from synapse.api.constants import EventTypes, Membership
 from synapse.api.errors import AuthError, Codes, SynapseError
@@ -361,7 +360,7 @@ class RoomMemberHandler(object):
         if effective_membership_state == Membership.INVITE:
             # block any attempts to invite the server notices mxid
             if target.to_string() == self._server_notices_mxid:
-                raise SynapseError(http_client.FORBIDDEN, "Cannot invite this user")
+                raise SynapseError(HTTPStatus.FORBIDDEN, "Cannot invite this user")
 
             block_invite = False
 
@@ -444,7 +443,7 @@ class RoomMemberHandler(object):
                 is_blocked = await self._is_server_notice_room(room_id)
                 if is_blocked:
                     raise SynapseError(
-                        http_client.FORBIDDEN,
+                        HTTPStatus.FORBIDDEN,
                         "You cannot reject this invite",
                         errcode=Codes.CANNOT_LEAVE_SERVER_NOTICE_ROOM,
                     )
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index c7bc14c623..4330abb9f7 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -15,7 +15,7 @@
 
 import logging
 from collections import namedtuple
-from typing import List
+from typing import List, Tuple
 
 from twisted.internet import defer
 
@@ -259,14 +259,31 @@ class TypingHandler(object):
         )
 
     async def get_all_typing_updates(
-        self, last_id: int, current_id: int, limit: int
-    ) -> List[dict]:
-        """Get up to `limit` typing updates between the given tokens, earliest
-        updates first.
+        self, instance_name: str, last_id: int, current_id: int, limit: int
+    ) -> Tuple[List[Tuple[int, list]], int, bool]:
+        """Get updates for typing replication stream.
+
+        Args:
+            instance_name: The writer we want to fetch updates from. Unused
+                here since there is only ever one writer.
+            last_id: The token to fetch updates from. Exclusive.
+            current_id: The token to fetch updates up to. Inclusive.
+            limit: The requested limit for the number of rows to return. The
+                function may return more or fewer rows.
+
+        Returns:
+            A tuple consisting of: the updates, a token to use to fetch
+            subsequent updates, and whether we returned fewer rows than exists
+            between the requested tokens due to the limit.
+
+            The token returned can be used in a subsequent call to this
+            function to get further updatees.
+
+            The updates are a list of 2-tuples of stream ID and the row data
         """
 
         if last_id == current_id:
-            return []
+            return [], current_id, False
 
         changed_rooms = self._typing_stream_change_cache.get_all_entities_changed(
             last_id
@@ -280,9 +297,16 @@ class TypingHandler(object):
             serial = self._room_serials[room_id]
             if last_id < serial <= current_id:
                 typing = self._room_typing[room_id]
-                rows.append((serial, room_id, list(typing)))
+                rows.append((serial, [room_id, list(typing)]))
         rows.sort()
-        return rows[:limit]
+
+        limited = False
+        if len(rows) > limit:
+            rows = rows[:limit]
+            current_id = rows[-1][0]
+            limited = True
+
+        return rows, current_id, limited
 
     def get_current_token(self):
         return self._latest_room_serial