author     Will Hunt <will@half-shot.uk>  2020-10-01 14:50:29 +0100
committer  Will Hunt <will@half-shot.uk>  2020-10-01 15:00:14 +0100
commit     97d173991098ea6221c21fbe2b044f5173b182e4 (patch)
tree       801c62cfbb77de9fdafcbe1407e26f56333a2b40
parent     changelog (diff)
download   synapse-97d173991098ea6221c21fbe2b044f5173b182e4.tar.xz
Fixup types
-rw-r--r--  synapse/appservice/__init__.py                 27
-rw-r--r--  synapse/appservice/api.py                      16
-rw-r--r--  synapse/appservice/scheduler.py                31
-rw-r--r--  synapse/handlers/receipts.py                    5
-rw-r--r--  synapse/storage/databases/main/appservice.py   17
-rw-r--r--  synapse/storage/databases/main/receipts.py      4
6 files changed, 72 insertions, 28 deletions
diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py
index 1335175009..2a6a180665 100644
--- a/synapse/appservice/__init__.py
+++ b/synapse/appservice/__init__.py
@@ -14,12 +14,13 @@
 # limitations under the License.
 import logging
 import re
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, List
 
 from synapse.api.constants import EventTypes
 from synapse.appservice.api import ApplicationServiceApi
-from synapse.types import GroupID, get_domain_from_id
-from synapse.util.caches.descriptors import cached
+from synapse.events import EventBase
+from synapse.types import GroupID, UserID, get_domain_from_id
+from synapse.util.caches.descriptors import _CacheContext, cached
 
 if TYPE_CHECKING:
     from synapse.storage.databases.main import DataStore
@@ -35,7 +36,13 @@ class ApplicationServiceState:
 class AppServiceTransaction:
     """Represents an application service transaction."""
 
-    def __init__(self, service, id, events, ephemeral=None):
+    def __init__(
+        self,
+        service: ApplicationService,
+        id: int,
+        events: List[EventBase],
+        ephemeral=None,
+    ):
         self.service = service
         self.id = id
         self.events = events
@@ -198,9 +205,11 @@ class ApplicationService:
         return does_match
 
     @cached(num_args=1, cache_context=True)
-    async def matches_user_in_member_list(self, room_id, store, cache_context):
+    async def matches_user_in_member_list(
+        self, room_id: str, store: DataStore, cache_context: _CacheContext
+    ):
         member_list = await store.get_users_in_room(
-            room_id, on_invalidate=cache_context.invalidate
+            room_id
         )
 
         # check joined member events
@@ -246,7 +255,9 @@ class ApplicationService:
         return False
 
     @cached(num_args=1, cache_context=True)
-    async def is_interested_in_presence(self, user_id, store, cache_context):
+    async def is_interested_in_presence(
+        self, user_id: UserID, store: DataStore, cache_context: _CacheContext
+    ):
         # Find all the rooms the sender is in
         if self.is_interested_in_user(user_id.to_string()):
             return True
@@ -254,7 +265,7 @@ class ApplicationService:
 
         # Then find out if the appservice is interested in any of those rooms
         for room_id in room_ids:
-            if await self.matches_user_in_member_list(room_id, store, cache_context):
+            if await self.matches_user_in_member_list(room_id, store):
                 return True
         return False
 
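Note: the typed AppServiceTransaction constructor above still leaves ephemeral unannotated. A minimal standalone sketch of the same constructor shape, using plain dicts in place of EventBase and assuming ephemeral is an optional list (an assumption, not taken from this commit; the class and field names here are illustrative only):

    from typing import Any, Dict, List, Optional

    JsonDict = Dict[str, Any]  # stand-in for an event body; the real code uses EventBase

    class TransactionSketch:
        # Sketch only: mirrors the typed constructor shape introduced above.
        def __init__(
            self,
            service_id: str,
            id: int,
            events: List[JsonDict],
            ephemeral: Optional[List[JsonDict]] = None,  # assumed type; untyped in the commit
        ):
            self.service_id = service_id
            self.id = id
            self.events = events
            self.ephemeral = ephemeral if ephemeral is not None else []

    txn = TransactionSketch("irc_bridge", 1, [{"type": "m.room.message"}])
    print(len(txn.events), len(txn.ephemeral))  # 1 0
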
diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index 6d0038ddd1..d405c1c7e3 100644
--- a/synapse/appservice/api.py
+++ b/synapse/appservice/api.py
@@ -14,12 +14,13 @@
 # limitations under the License.
 import logging
 import urllib
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Any, List, Optional
 
 from prometheus_client import Counter
 
 from synapse.api.constants import EventTypes, ThirdPartyEntityKind
 from synapse.api.errors import CodeMessageException
+from synapse.events import EventBase
 from synapse.events.utils import serialize_event
 from synapse.http.client import SimpleHttpClient
 from synapse.types import JsonDict, ThirdPartyInstanceID
@@ -201,7 +202,13 @@ class ApplicationServiceApi(SimpleHttpClient):
         key = (service.id, protocol)
         return await self.protocol_meta_cache.wrap(key, _get)
 
-    async def push_bulk(self, service, events, ephemeral=None, txn_id=None):
+    async def push_bulk(
+        self,
+        service: ApplicationService,
+        events: List[EventBase],
+        ephemeral: Optional[Any] = None,
+        txn_id: Optional[int] = None,
+    ):
         if service.url is None:
             return True
 
@@ -211,10 +218,9 @@ class ApplicationServiceApi(SimpleHttpClient):
             logger.warning(
                 "push_bulk: Missing txn ID sending events to %s", service.url
             )
-            txn_id = str(0)
-        txn_id = str(txn_id)
+            txn_id = 0
 
-        uri = service.url + ("/transactions/%s" % urllib.parse.quote(txn_id))
+        uri = service.url + ("/transactions/%s" % urllib.parse.quote(str(txn_id)))
         body = {"events": events}
         if ephemeral:
             body["de.sorunome.msc2409.ephemeral"] = ephemeral
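Note: for reference, a self-contained sketch of the URI construction push_bulk now performs, after a missing txn_id is normalised to the integer 0 (the helper name and example URL are mine, not Synapse's):

    import urllib.parse
    from typing import Optional

    def build_txn_uri(service_url: str, txn_id: Optional[int]) -> str:
        # Mirrors push_bulk above: default a missing txn id to 0, then stringify
        # and percent-encode it before appending to the appservice base URL.
        if txn_id is None:
            txn_id = 0
        return service_url + ("/transactions/%s" % urllib.parse.quote(str(txn_id)))

    print(build_txn_uri("https://bridge.example.com", 42))
    # https://bridge.example.com/transactions/42
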
diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py
index 31b664dd36..03b81b0e9f 100644
--- a/synapse/appservice/scheduler.py
+++ b/synapse/appservice/scheduler.py
@@ -49,8 +49,10 @@ This is all tied together by the AppServiceScheduler which DIs the required
 components.
 """
 import logging
+from typing import Any, List, Optional
 
-from synapse.appservice import ApplicationServiceState
+from synapse.appservice import ApplicationService, ApplicationServiceState
+from synapse.events import EventBase
 from synapse.logging.context import run_in_background
 from synapse.metrics.background_process_metrics import run_as_background_process
 
@@ -82,10 +84,12 @@ class ApplicationServiceScheduler:
         for service in services:
             self.txn_ctrl.start_recoverer(service)
 
-    def submit_event_for_as(self, service, event):
+    def submit_event_for_as(self, service: ApplicationService, event: EventBase):
         self.queuer.enqueue(service, event)
 
-    def submit_ephemeral_events_for_as(self, service, events):
+    def submit_ephemeral_events_for_as(
+        self, service: ApplicationService, events: List[Any]
+    ):
         self.queuer.enqueue_ephemeral(service, events)
 
 
@@ -99,7 +103,7 @@ class _ServiceQueuer:
 
     def __init__(self, txn_ctrl, clock):
         self.queued_events = {}  # dict of {service_id: [events]}
-        self.queued_ephemeral = {} # dict of {service_id: [events]}
+        self.queued_ephemeral = {}  # dict of {service_id: [events]}
 
         # the appservices which currently have a transaction in flight
         self.requests_in_flight = set()
@@ -118,7 +122,7 @@ class _ServiceQueuer:
             "as-sender-%s" % (service.id), self._send_request, service
         )
 
-    def enqueue_ephemeral(self, service, events):
+    def enqueue_ephemeral(self, service: ApplicationService, events: List[Any]):
         self.queued_ephemeral.setdefault(service.id, []).extend(events)
 
         # start a sender for this appservice if we don't already have one
@@ -130,7 +134,9 @@ class _ServiceQueuer:
             "as-sender-%s" % (service.id), self._send_request, service
         )
 
-    async def _send_request(self, service, ephemeral=None):
+    async def _send_request(
+        self, service: ApplicationService, ephemeral: Optional[Any] = None
+    ):
         # sanity-check: we shouldn't get here if this service already has a sender
         # running.
         assert service.id not in self.requests_in_flight
@@ -175,9 +181,16 @@ class _TransactionController:
         # for UTs
         self.RECOVERER_CLASS = _Recoverer
 
-    async def send(self, service, events, ephemeral=None):
+    async def send(
+        self,
+        service: ApplicationService,
+        events: List[EventBase],
+        ephemeral: Optional[Any] = None,
+    ):
         try:
-            txn = await self.store.create_appservice_txn(service=service, events=events, ephemeral=ephemeral)
+            txn = await self.store.create_appservice_txn(
+                service=service, events=events, ephemeral=ephemeral
+            )
             service_is_up = await self.is_service_up(service)
             if service_is_up:
                 sent = await txn.send(self.as_api)
@@ -221,7 +234,7 @@ class _TransactionController:
         recoverer.recover()
         logger.info("Now %i active recoverers", len(self.recoverers))
 
-    async def is_service_up(self, service):
+    async def is_service_up(self, service: ApplicationService):
         state = await self.store.get_appservice_state(service)
         return state == ApplicationServiceState.UP or state is None
 
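Note: the ephemeral queue added to _ServiceQueuer is a plain dict keyed by service id. A standalone sketch of that pattern (class name and service id are illustrative, not Synapse code):

    from typing import Any, Dict, List

    class QueuerSketch:
        def __init__(self) -> None:
            self.queued_ephemeral: Dict[str, List[Any]] = {}  # {service_id: [events]}

        def enqueue_ephemeral(self, service_id: str, events: List[Any]) -> None:
            # Append to this service's queue, creating it on first use.
            self.queued_ephemeral.setdefault(service_id, []).extend(events)

    q = QueuerSketch()
    q.enqueue_ephemeral("irc_bridge", [{"type": "m.receipt"}])
    q.enqueue_ephemeral("irc_bridge", [{"type": "m.presence"}])
    print(q.queued_ephemeral["irc_bridge"])  # both batches queued for the same service
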
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index d9e4b1c271..a6db85c888 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 import logging
 
+from synapse.appservice import ApplicationService
 from synapse.handlers._base import BaseHandler
 from synapse.types import ReadReceipt, get_domain_from_id
 from synapse.util.async_helpers import maybe_awaitable
@@ -140,7 +141,9 @@ class ReceiptEventSource:
 
         return (events, to_key)
 
-    async def get_new_events_as(self, from_key, service, **kwargs):
+    async def get_new_events_as(
+        self, from_key: int, service: ApplicationService, **kwargs
+    ):
         from_key = int(from_key)
         to_key = self.get_current_key()
 
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index b29ce1d1be..a1d3f4be16 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -15,9 +15,11 @@
 # limitations under the License.
 import logging
 import re
+from typing import Any, List, Optional
 
-from synapse.appservice import AppServiceTransaction
+from synapse.appservice import ApplicationService, AppServiceTransaction
 from synapse.config.appservice import load_appservices
+from synapse.events import EventBase
 from synapse.storage._base import SQLBaseStore, db_to_json
 from synapse.storage.database import DatabasePool
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
@@ -172,7 +174,12 @@ class ApplicationServiceTransactionWorkerStore(
             "application_services_state", {"as_id": service.id}, {"state": state}
         )
 
-    async def create_appservice_txn(self, service, events, ephemeral=None):
+    async def create_appservice_txn(
+        self,
+        service: ApplicationService,
+        events: List[EventBase],
+        ephemeral: Optional[Any] = None,
+    ):
         """Atomically creates a new transaction for this application service
         with the given list of events.
 
@@ -353,7 +360,9 @@ class ApplicationServiceTransactionWorkerStore(
 
         return upper_bound, events
 
-    async def get_type_stream_id_for_appservice(self, service, type: str) -> int:
+    async def get_type_stream_id_for_appservice(
+        self, service: ApplicationService, type: str
+    ) -> int:
         def get_type_stream_id_for_appservice_txn(txn):
             stream_id_type = "%s_stream_id" % type
             txn.execute(
@@ -371,7 +380,7 @@ class ApplicationServiceTransactionWorkerStore(
         )
 
     async def set_type_stream_id_for_appservice(
-        self, service, type: str, pos: int
+        self, service: ApplicationService, type: str, pos: int
     ) -> None:
         def set_type_stream_id_for_appservice_txn(txn):
             stream_id_type = "%s_stream_id" % type
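Note: get_type_stream_id_for_appservice and set_type_stream_id_for_appservice derive a column name from the type argument. A tiny sketch of that naming, with a hypothetical type value:

    def stream_id_column(type: str) -> str:
        # Mirrors the "%s_stream_id" % type pattern in the storage methods above.
        return "%s_stream_id" % type

    print(stream_id_column("read_receipt"))  # read_receipt_stream_id (illustrative type value)
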
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index d26c315ed4..66862acb7d 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -284,7 +284,9 @@ class ReceiptsWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
         return results
 
     @cached(num_args=2,)
-    async def _get_linearized_receipts_for_all_rooms(self, to_key, from_key=None):
+    async def _get_linearized_receipts_for_all_rooms(
+        self, to_key: int, from_key: Optional[int] = None
+    ):
         def f(txn):
             if from_key:
                 sql = """