summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2020-08-12 14:03:08 +0100
committerGitHub <noreply@github.com>2020-08-12 14:03:08 +0100
commit9d1e4942ab728ebfe09ff9a63c66708ceaaf7591 (patch)
tree263d97d6937a48a4cdef067b961aa4d593ce71d2 /synapse
parentMerge pull request #8060 from matrix-org/erikj/type_server (diff)
downloadsynapse-9d1e4942ab728ebfe09ff9a63c66708ceaaf7591.tar.xz
Fix typing for notifier (#8064)
Diffstat (limited to 'synapse')
-rw-r--r--synapse/federation/sender/transaction_manager.py7
-rw-r--r--synapse/notifier.py12
-rw-r--r--synapse/types.py23
-rw-r--r--synapse/util/metrics.py9
4 files changed, 35 insertions, 16 deletions
diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py
index 8280f8b900..c7f6cb3d73 100644
--- a/synapse/federation/sender/transaction_manager.py
+++ b/synapse/federation/sender/transaction_manager.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import TYPE_CHECKING, List
+from typing import TYPE_CHECKING, List, Tuple
 
 from canonicaljson import json
 
@@ -54,7 +54,10 @@ class TransactionManager(object):
 
     @measure_func("_send_new_transaction")
     async def send_new_transaction(
-        self, destination: str, pending_pdus: List[EventBase], pending_edus: List[Edu]
+        self,
+        destination: str,
+        pending_pdus: List[Tuple[EventBase, int]],
+        pending_edus: List[Edu],
     ):
 
         # Make a transaction-sending opentracing span. This span follows on from
diff --git a/synapse/notifier.py b/synapse/notifier.py
index 694efe7116..dfb096e589 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -25,6 +25,7 @@ from typing import (
     Set,
     Tuple,
     TypeVar,
+    Union,
 )
 
 from prometheus_client import Counter
@@ -186,7 +187,7 @@ class Notifier(object):
         self.store = hs.get_datastore()
         self.pending_new_room_events = (
             []
-        )  # type: List[Tuple[int, EventBase, Collection[str]]]
+        )  # type: List[Tuple[int, EventBase, Collection[Union[str, UserID]]]]
 
         # Called when there are new things to stream over replication
         self.replication_callbacks = []  # type: List[Callable[[], None]]
@@ -246,7 +247,7 @@ class Notifier(object):
         event: EventBase,
         room_stream_id: int,
         max_room_stream_id: int,
-        extra_users: Collection[str] = [],
+        extra_users: Collection[Union[str, UserID]] = [],
     ):
         """ Used by handlers to inform the notifier something has happened
         in the room, room event wise.
@@ -282,7 +283,10 @@ class Notifier(object):
                 self._on_new_room_event(event, room_stream_id, extra_users)
 
     def _on_new_room_event(
-        self, event: EventBase, room_stream_id: int, extra_users: Collection[str] = []
+        self,
+        event: EventBase,
+        room_stream_id: int,
+        extra_users: Collection[Union[str, UserID]] = [],
     ):
         """Notify any user streams that are interested in this room event"""
         # poke any interested application service.
@@ -310,7 +314,7 @@ class Notifier(object):
         self,
         stream_key: str,
         new_token: int,
-        users: Collection[str] = [],
+        users: Collection[Union[str, UserID]] = [],
         rooms: Collection[str] = [],
     ):
         """ Used to inform listeners that something has happened event wise.
diff --git a/synapse/types.py b/synapse/types.py
index 238b938064..9e580f4295 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -13,11 +13,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import abc
 import re
 import string
 import sys
 from collections import namedtuple
-from typing import Any, Dict, Tuple, TypeVar
+from typing import Any, Dict, Tuple, Type, TypeVar
 
 import attr
 from signedjson.key import decode_verify_key_bytes
@@ -33,7 +34,7 @@ else:
 
     T_co = TypeVar("T_co", covariant=True)
 
-    class Collection(Iterable[T_co], Container[T_co], Sized):
+    class Collection(Iterable[T_co], Container[T_co], Sized):  # type: ignore
         __slots__ = ()
 
 
@@ -141,6 +142,9 @@ def get_localpart_from_id(string):
     return string[1:idx]
 
 
+DS = TypeVar("DS", bound="DomainSpecificString")
+
+
 class DomainSpecificString(namedtuple("DomainSpecificString", ("localpart", "domain"))):
     """Common base class among ID/name strings that have a local part and a
     domain name, prefixed with a sigil.
@@ -151,6 +155,10 @@ class DomainSpecificString(namedtuple("DomainSpecificString", ("localpart", "dom
         'domain' : The domain part of the name
     """
 
+    __metaclass__ = abc.ABCMeta
+
+    SIGIL = abc.abstractproperty()  # type: str  # type: ignore
+
     # Deny iteration because it will bite you if you try to create a singleton
     # set by:
     #    users = set(user)
@@ -166,7 +174,7 @@ class DomainSpecificString(namedtuple("DomainSpecificString", ("localpart", "dom
         return self
 
     @classmethod
-    def from_string(cls, s: str):
+    def from_string(cls: Type[DS], s: str) -> DS:
         """Parse the string given by 's' into a structure object."""
         if len(s) < 1 or s[0:1] != cls.SIGIL:
             raise SynapseError(
@@ -190,12 +198,12 @@ class DomainSpecificString(namedtuple("DomainSpecificString", ("localpart", "dom
         # names on one HS
         return cls(localpart=parts[0], domain=domain)
 
-    def to_string(self):
+    def to_string(self) -> str:
         """Return a string encoding the fields of the structure object."""
         return "%s%s:%s" % (self.SIGIL, self.localpart, self.domain)
 
     @classmethod
-    def is_valid(cls, s):
+    def is_valid(cls: Type[DS], s: str) -> bool:
         try:
             cls.from_string(s)
             return True
@@ -235,8 +243,9 @@ class GroupID(DomainSpecificString):
     SIGIL = "+"
 
     @classmethod
-    def from_string(cls, s):
-        group_id = super(GroupID, cls).from_string(s)
+    def from_string(cls: Type[DS], s: str) -> DS:
+        group_id = super().from_string(s)  # type: DS # type: ignore
+
         if not group_id.localpart:
             raise SynapseError(400, "Group ID cannot be empty", Codes.INVALID_PARAM)
 
diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py
index a805f51df1..13775b43f9 100644
--- a/synapse/util/metrics.py
+++ b/synapse/util/metrics.py
@@ -15,6 +15,7 @@
 
 import logging
 from functools import wraps
+from typing import Any, Callable, Optional, TypeVar, cast
 
 from prometheus_client import Counter
 
@@ -57,8 +58,10 @@ in_flight = InFlightGauge(
     sub_metrics=["real_time_max", "real_time_sum"],
 )
 
+T = TypeVar("T", bound=Callable[..., Any])
 
-def measure_func(name=None):
+
+def measure_func(name: Optional[str] = None) -> Callable[[T], T]:
     """
     Used to decorate an async function with a `Measure` context manager.
 
@@ -76,7 +79,7 @@ def measure_func(name=None):
 
     """
 
-    def wrapper(func):
+    def wrapper(func: T) -> T:
         block_name = func.__name__ if name is None else name
 
         @wraps(func)
@@ -85,7 +88,7 @@ def measure_func(name=None):
                 r = await func(self, *args, **kwargs)
             return r
 
-        return measured_func
+        return cast(T, measured_func)
 
     return wrapper