diff --git a/synapse/push/__init__.py b/synapse/push/__init__.py
index ad07ee86f6..9e7ac149a1 100644
--- a/synapse/push/__init__.py
+++ b/synapse/push/__init__.py
@@ -14,24 +14,70 @@
# limitations under the License.
import abc
-from typing import TYPE_CHECKING, Any, Dict
+from typing import TYPE_CHECKING, Any, Dict, Optional
-from synapse.types import RoomStreamToken
+import attr
+
+from synapse.types import JsonDict, RoomStreamToken
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
+@attr.s(slots=True)
+class PusherConfig:
+ """Parameters necessary to configure a pusher."""
+
+ id = attr.ib(type=Optional[str])
+ user_name = attr.ib(type=str)
+ access_token = attr.ib(type=Optional[int])
+ profile_tag = attr.ib(type=str)
+ kind = attr.ib(type=str)
+ app_id = attr.ib(type=str)
+ app_display_name = attr.ib(type=str)
+ device_display_name = attr.ib(type=str)
+ pushkey = attr.ib(type=str)
+ ts = attr.ib(type=int)
+ lang = attr.ib(type=Optional[str])
+ data = attr.ib(type=Optional[JsonDict])
+ last_stream_ordering = attr.ib(type=Optional[int])
+ last_success = attr.ib(type=Optional[int])
+ failing_since = attr.ib(type=Optional[int])
+
+ def as_dict(self) -> Dict[str, Any]:
+ """Information that can be retrieved about a pusher after creation."""
+ return {
+ "app_display_name": self.app_display_name,
+ "app_id": self.app_id,
+ "data": self.data,
+ "device_display_name": self.device_display_name,
+ "kind": self.kind,
+ "lang": self.lang,
+ "profile_tag": self.profile_tag,
+ "pushkey": self.pushkey,
+ }
+
+
+@attr.s(slots=True)
+class ThrottleParams:
+ """Parameters for controlling the rate of sending pushes via email."""
+
+ last_sent_ts = attr.ib(type=int)
+ throttle_ms = attr.ib(type=int)
+
+
class Pusher(metaclass=abc.ABCMeta):
- def __init__(self, hs: "HomeServer", pusherdict: Dict[str, Any]):
+ def __init__(self, hs: "HomeServer", pusher_config: PusherConfig):
self.hs = hs
self.store = self.hs.get_datastore()
self.clock = self.hs.get_clock()
- self.pusher_id = pusherdict["id"]
- self.user_id = pusherdict["user_name"]
- self.app_id = pusherdict["app_id"]
- self.pushkey = pusherdict["pushkey"]
+ self.pusher_id = pusher_config.id
+ self.user_id = pusher_config.user_name
+ self.app_id = pusher_config.app_id
+ self.pushkey = pusher_config.pushkey
+
+ self.last_stream_ordering = pusher_config.last_stream_ordering
# This is the highest stream ordering we know it's safe to process.
# When new events arrive, we'll be given a window of new events: we
diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py
index 11a97b8df4..d2eff75a58 100644
--- a/synapse/push/emailpusher.py
+++ b/synapse/push/emailpusher.py
@@ -14,13 +14,13 @@
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Any, Dict, List, Optional
+from typing import TYPE_CHECKING, Dict, List, Optional
from twisted.internet.base import DelayedCall
from twisted.internet.error import AlreadyCalled, AlreadyCancelled
from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.push import Pusher
+from synapse.push import Pusher, PusherConfig, ThrottleParams
from synapse.push.mailer import Mailer
if TYPE_CHECKING:
@@ -60,15 +60,14 @@ class EmailPusher(Pusher):
factor out the common parts
"""
- def __init__(self, hs: "HomeServer", pusherdict: Dict[str, Any], mailer: Mailer):
- super().__init__(hs, pusherdict)
+ def __init__(self, hs: "HomeServer", pusher_config: PusherConfig, mailer: Mailer):
+ super().__init__(hs, pusher_config)
self.mailer = mailer
self.store = self.hs.get_datastore()
- self.email = pusherdict["pushkey"]
- self.last_stream_ordering = pusherdict["last_stream_ordering"]
+ self.email = pusher_config.pushkey
self.timed_call = None # type: Optional[DelayedCall]
- self.throttle_params = {} # type: Dict[str, Dict[str, int]]
+ self.throttle_params = {} # type: Dict[str, ThrottleParams]
self._inited = False
self._is_processing = False
@@ -132,6 +131,7 @@ class EmailPusher(Pusher):
if not self._inited:
# this is our first loop: load up the throttle params
+ assert self.pusher_id is not None
self.throttle_params = await self.store.get_throttle_params_by_room(
self.pusher_id
)
@@ -157,6 +157,7 @@ class EmailPusher(Pusher):
being run.
"""
start = 0 if INCLUDE_ALL_UNREAD_NOTIFS else self.last_stream_ordering
+ assert start is not None
unprocessed = await self.store.get_unread_push_actions_for_user_in_range_for_email(
self.user_id, start, self.max_stream_ordering
)
@@ -244,13 +245,13 @@ class EmailPusher(Pusher):
def get_room_throttle_ms(self, room_id: str) -> int:
if room_id in self.throttle_params:
- return self.throttle_params[room_id]["throttle_ms"]
+ return self.throttle_params[room_id].throttle_ms
else:
return 0
def get_room_last_sent_ts(self, room_id: str) -> int:
if room_id in self.throttle_params:
- return self.throttle_params[room_id]["last_sent_ts"]
+ return self.throttle_params[room_id].last_sent_ts
else:
return 0
@@ -301,10 +302,10 @@ class EmailPusher(Pusher):
new_throttle_ms = min(
current_throttle_ms * THROTTLE_MULTIPLIER, THROTTLE_MAX_MS
)
- self.throttle_params[room_id] = {
- "last_sent_ts": self.clock.time_msec(),
- "throttle_ms": new_throttle_ms,
- }
+ self.throttle_params[room_id] = ThrottleParams(
+ self.clock.time_msec(), new_throttle_ms,
+ )
+ assert self.pusher_id is not None
await self.store.set_throttle_params(
self.pusher_id, room_id, self.throttle_params[room_id]
)
diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index e8b25bcd2a..417fe0f1f5 100644
--- a/synapse/push/httppusher.py
+++ b/synapse/push/httppusher.py
@@ -25,7 +25,7 @@ from synapse.api.constants import EventTypes
from synapse.events import EventBase
from synapse.logging import opentracing
from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.push import Pusher, PusherConfigException
+from synapse.push import Pusher, PusherConfig, PusherConfigException
from . import push_rule_evaluator, push_tools
@@ -62,33 +62,29 @@ class HttpPusher(Pusher):
# This one's in ms because we compare it against the clock
GIVE_UP_AFTER_MS = 24 * 60 * 60 * 1000
- def __init__(self, hs: "HomeServer", pusherdict: Dict[str, Any]):
- super().__init__(hs, pusherdict)
+ def __init__(self, hs: "HomeServer", pusher_config: PusherConfig):
+ super().__init__(hs, pusher_config)
self.storage = self.hs.get_storage()
- self.app_display_name = pusherdict["app_display_name"]
- self.device_display_name = pusherdict["device_display_name"]
- self.pushkey_ts = pusherdict["ts"]
- self.data = pusherdict["data"]
- self.last_stream_ordering = pusherdict["last_stream_ordering"]
+ self.app_display_name = pusher_config.app_display_name
+ self.device_display_name = pusher_config.device_display_name
+ self.pushkey_ts = pusher_config.ts
+ self.data = pusher_config.data
self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
- self.failing_since = pusherdict["failing_since"]
+ self.failing_since = pusher_config.failing_since
self.timed_call = None
self._is_processing = False
self._group_unread_count_by_room = hs.config.push_group_unread_count_by_room
- if "data" not in pusherdict:
- raise PusherConfigException("No 'data' key for HTTP pusher")
- self.data = pusherdict["data"]
+ self.data = pusher_config.data
+ if self.data is None:
+ raise PusherConfigException("'data' key can not be null for HTTP pusher")
self.name = "%s/%s/%s" % (
- pusherdict["user_name"],
- pusherdict["app_id"],
- pusherdict["pushkey"],
+ pusher_config.user_name,
+ pusher_config.app_id,
+ pusher_config.pushkey,
)
- if self.data is None:
- raise PusherConfigException("data can not be null for HTTP pusher")
-
# Validate that there's a URL and it is of the proper form.
if "url" not in self.data:
raise PusherConfigException("'url' required in data for HTTP pusher")
@@ -180,6 +176,7 @@ class HttpPusher(Pusher):
Never call this directly: use _process which will only allow this to
run once per pusher.
"""
+ assert self.last_stream_ordering is not None
unprocessed = await self.store.get_unread_push_actions_for_user_in_range_for_http(
self.user_id, self.last_stream_ordering, self.max_stream_ordering
)
@@ -208,6 +205,7 @@ class HttpPusher(Pusher):
http_push_processed_counter.inc()
self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
self.last_stream_ordering = push_action["stream_ordering"]
+ assert self.last_stream_ordering is not None
pusher_still_exists = await self.store.update_pusher_last_stream_ordering_and_success(
self.app_id,
self.pushkey,
@@ -314,6 +312,8 @@ class HttpPusher(Pusher):
# or may do so (i.e. is encrypted so has unknown effects).
priority = "high"
+ # This was checked in the __init__, but mypy doesn't seem to know that.
+ assert self.data is not None
if self.data.get("format") == "event_id_only":
d = {
"notification": {
diff --git a/synapse/push/pusher.py b/synapse/push/pusher.py
index 8f1072b094..2aa7918fb4 100644
--- a/synapse/push/pusher.py
+++ b/synapse/push/pusher.py
@@ -14,9 +14,9 @@
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Any, Callable, Dict, Optional
+from typing import TYPE_CHECKING, Callable, Dict, Optional
-from synapse.push import Pusher
+from synapse.push import Pusher, PusherConfig
from synapse.push.emailpusher import EmailPusher
from synapse.push.httppusher import HttpPusher
from synapse.push.mailer import Mailer
@@ -34,7 +34,7 @@ class PusherFactory:
self.pusher_types = {
"http": HttpPusher
- } # type: Dict[str, Callable[[HomeServer, dict], Pusher]]
+ } # type: Dict[str, Callable[[HomeServer, PusherConfig], Pusher]]
logger.info("email enable notifs: %r", hs.config.email_enable_notifs)
if hs.config.email_enable_notifs:
@@ -47,18 +47,18 @@ class PusherFactory:
logger.info("defined email pusher type")
- def create_pusher(self, pusherdict: Dict[str, Any]) -> Optional[Pusher]:
- kind = pusherdict["kind"]
+ def create_pusher(self, pusher_config: PusherConfig) -> Optional[Pusher]:
+ kind = pusher_config.kind
f = self.pusher_types.get(kind, None)
if not f:
return None
- logger.debug("creating %s pusher for %r", kind, pusherdict)
- return f(self.hs, pusherdict)
+ logger.debug("creating %s pusher for %r", kind, pusher_config)
+ return f(self.hs, pusher_config)
def _create_email_pusher(
- self, _hs: "HomeServer", pusherdict: Dict[str, Any]
+ self, _hs: "HomeServer", pusher_config: PusherConfig
) -> EmailPusher:
- app_name = self._app_name_from_pusherdict(pusherdict)
+ app_name = self._app_name_from_pusherdict(pusher_config)
mailer = self.mailers.get(app_name)
if not mailer:
mailer = Mailer(
@@ -68,10 +68,10 @@ class PusherFactory:
template_text=self._notif_template_text,
)
self.mailers[app_name] = mailer
- return EmailPusher(self.hs, pusherdict, mailer)
+ return EmailPusher(self.hs, pusher_config, mailer)
- def _app_name_from_pusherdict(self, pusherdict: Dict[str, Any]) -> str:
- data = pusherdict["data"]
+ def _app_name_from_pusherdict(self, pusher_config: PusherConfig) -> str:
+ data = pusher_config.data
if isinstance(data, dict):
brand = data.get("brand")
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index 9c12d81cfb..8158356d40 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -15,7 +15,7 @@
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Any, Dict, Optional
+from typing import TYPE_CHECKING, Dict, Iterable, Optional
from prometheus_client import Gauge
@@ -23,9 +23,9 @@ from synapse.metrics.background_process_metrics import (
run_as_background_process,
wrap_as_background_process,
)
-from synapse.push import Pusher, PusherConfigException
+from synapse.push import Pusher, PusherConfig, PusherConfigException
from synapse.push.pusher import PusherFactory
-from synapse.types import RoomStreamToken
+from synapse.types import JsonDict, RoomStreamToken
from synapse.util.async_helpers import concurrently_execute
if TYPE_CHECKING:
@@ -77,7 +77,7 @@ class PusherPool:
# map from user id to app_id:pushkey to pusher
self.pushers = {} # type: Dict[str, Dict[str, Pusher]]
- def start(self):
+ def start(self) -> None:
"""Starts the pushers off in a background process.
"""
if not self._should_start_pushers:
@@ -87,16 +87,16 @@ class PusherPool:
async def add_pusher(
self,
- user_id,
- access_token,
- kind,
- app_id,
- app_display_name,
- device_display_name,
- pushkey,
- lang,
- data,
- profile_tag="",
+ user_id: str,
+ access_token: Optional[int],
+ kind: str,
+ app_id: str,
+ app_display_name: str,
+ device_display_name: str,
+ pushkey: str,
+ lang: Optional[str],
+ data: JsonDict,
+ profile_tag: str = "",
) -> Optional[Pusher]:
"""Creates a new pusher and adds it to the pool
@@ -111,21 +111,23 @@ class PusherPool:
# recreated, added and started: this means we have only one
# code path adding pushers.
self.pusher_factory.create_pusher(
- {
- "id": None,
- "user_name": user_id,
- "kind": kind,
- "app_id": app_id,
- "app_display_name": app_display_name,
- "device_display_name": device_display_name,
- "pushkey": pushkey,
- "ts": time_now_msec,
- "lang": lang,
- "data": data,
- "last_stream_ordering": None,
- "last_success": None,
- "failing_since": None,
- }
+ PusherConfig(
+ id=None,
+ user_name=user_id,
+ access_token=access_token,
+ profile_tag=profile_tag,
+ kind=kind,
+ app_id=app_id,
+ app_display_name=app_display_name,
+ device_display_name=device_display_name,
+ pushkey=pushkey,
+ ts=time_now_msec,
+ lang=lang,
+ data=data,
+ last_stream_ordering=None,
+ last_success=None,
+ failing_since=None,
+ )
)
# create the pusher setting last_stream_ordering to the current maximum
@@ -151,43 +153,44 @@ class PusherPool:
return pusher
async def remove_pushers_by_app_id_and_pushkey_not_user(
- self, app_id, pushkey, not_user_id
- ):
+ self, app_id: str, pushkey: str, not_user_id: str
+ ) -> None:
to_remove = await self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey)
for p in to_remove:
- if p["user_name"] != not_user_id:
+ if p.user_name != not_user_id:
logger.info(
"Removing pusher for app id %s, pushkey %s, user %s",
app_id,
pushkey,
- p["user_name"],
+ p.user_name,
)
- await self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"])
+ await self.remove_pusher(p.app_id, p.pushkey, p.user_name)
- async def remove_pushers_by_access_token(self, user_id, access_tokens):
+ async def remove_pushers_by_access_token(
+ self, user_id: str, access_tokens: Iterable[int]
+ ) -> None:
"""Remove the pushers for a given user corresponding to a set of
access_tokens.
Args:
- user_id (str): user to remove pushers for
- access_tokens (Iterable[int]): access token *ids* to remove pushers
- for
+ user_id: user to remove pushers for
+ access_tokens: access token *ids* to remove pushers for
"""
if not self._pusher_shard_config.should_handle(self._instance_name, user_id):
return
tokens = set(access_tokens)
for p in await self.store.get_pushers_by_user_id(user_id):
- if p["access_token"] in tokens:
+ if p.access_token in tokens:
logger.info(
"Removing pusher for app id %s, pushkey %s, user %s",
- p["app_id"],
- p["pushkey"],
- p["user_name"],
+ p.app_id,
+ p.pushkey,
+ p.user_name,
)
- await self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"])
+ await self.remove_pusher(p.app_id, p.pushkey, p.user_name)
- def on_new_notifications(self, max_token: RoomStreamToken):
+ def on_new_notifications(self, max_token: RoomStreamToken) -> None:
if not self.pushers:
# nothing to do here.
return
@@ -206,7 +209,7 @@ class PusherPool:
self._on_new_notifications(max_token)
@wrap_as_background_process("on_new_notifications")
- async def _on_new_notifications(self, max_token: RoomStreamToken):
+ async def _on_new_notifications(self, max_token: RoomStreamToken) -> None:
# We just use the minimum stream ordering and ignore the vector clock
# component. This is safe to do as long as we *always* ignore the vector
# clock components.
@@ -236,7 +239,9 @@ class PusherPool:
except Exception:
logger.exception("Exception in pusher on_new_notifications")
- async def on_new_receipts(self, min_stream_id, max_stream_id, affected_room_ids):
+ async def on_new_receipts(
+ self, min_stream_id: int, max_stream_id: int, affected_room_ids: Iterable[str]
+ ) -> None:
if not self.pushers:
# nothing to do here.
return
@@ -280,14 +285,14 @@ class PusherPool:
resultlist = await self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey)
- pusher_dict = None
+ pusher_config = None
for r in resultlist:
- if r["user_name"] == user_id:
- pusher_dict = r
+ if r.user_name == user_id:
+ pusher_config = r
pusher = None
- if pusher_dict:
- pusher = await self._start_pusher(pusher_dict)
+ if pusher_config:
+ pusher = await self._start_pusher(pusher_config)
return pusher
@@ -302,44 +307,44 @@ class PusherPool:
logger.info("Started pushers")
- async def _start_pusher(self, pusherdict: Dict[str, Any]) -> Optional[Pusher]:
+ async def _start_pusher(self, pusher_config: PusherConfig) -> Optional[Pusher]:
"""Start the given pusher
Args:
- pusherdict: dict with the values pulled from the db table
+ pusher_config: The pusher configuration with the values pulled from the db table
Returns:
The newly created pusher or None.
"""
if not self._pusher_shard_config.should_handle(
- self._instance_name, pusherdict["user_name"]
+ self._instance_name, pusher_config.user_name
):
return None
try:
- p = self.pusher_factory.create_pusher(pusherdict)
+ p = self.pusher_factory.create_pusher(pusher_config)
except PusherConfigException as e:
logger.warning(
"Pusher incorrectly configured id=%i, user=%s, appid=%s, pushkey=%s: %s",
- pusherdict["id"],
- pusherdict.get("user_name"),
- pusherdict.get("app_id"),
- pusherdict.get("pushkey"),
+ pusher_config.id,
+ pusher_config.user_name,
+ pusher_config.app_id,
+ pusher_config.pushkey,
e,
)
return None
except Exception:
logger.exception(
- "Couldn't start pusher id %i: caught Exception", pusherdict["id"],
+ "Couldn't start pusher id %i: caught Exception", pusher_config.id,
)
return None
if not p:
return None
- appid_pushkey = "%s:%s" % (pusherdict["app_id"], pusherdict["pushkey"])
+ appid_pushkey = "%s:%s" % (pusher_config.app_id, pusher_config.pushkey)
- byuser = self.pushers.setdefault(pusherdict["user_name"], {})
+ byuser = self.pushers.setdefault(pusher_config.user_name, {})
if appid_pushkey in byuser:
byuser[appid_pushkey].on_stop()
byuser[appid_pushkey] = p
@@ -349,8 +354,8 @@ class PusherPool:
# Check if there *may* be push to process. We do this as this check is a
# lot cheaper to do than actually fetching the exact rows we need to
# push.
- user_id = pusherdict["user_name"]
- last_stream_ordering = pusherdict["last_stream_ordering"]
+ user_id = pusher_config.user_name
+ last_stream_ordering = pusher_config.last_stream_ordering
if last_stream_ordering:
have_notifs = await self.store.get_if_maybe_push_in_range_for_user(
user_id, last_stream_ordering
@@ -364,7 +369,7 @@ class PusherPool:
return p
- async def remove_pusher(self, app_id, pushkey, user_id):
+ async def remove_pusher(self, app_id: str, pushkey: str, user_id: str) -> None:
appid_pushkey = "%s:%s" % (app_id, pushkey)
byuser = self.pushers.get(user_id, {})
diff --git a/synapse/replication/slave/storage/_slaved_id_tracker.py b/synapse/replication/slave/storage/_slaved_id_tracker.py
index eb74903d68..0d39a93ed2 100644
--- a/synapse/replication/slave/storage/_slaved_id_tracker.py
+++ b/synapse/replication/slave/storage/_slaved_id_tracker.py
@@ -12,21 +12,31 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import List, Optional, Tuple
+from synapse.storage.types import Connection
from synapse.storage.util.id_generators import _load_current_id
class SlavedIdTracker:
- def __init__(self, db_conn, table, column, extra_tables=[], step=1):
+ def __init__(
+ self,
+ db_conn: Connection,
+ table: str,
+ column: str,
+ extra_tables: Optional[List[Tuple[str, str]]] = None,
+ step: int = 1,
+ ):
self.step = step
self._current = _load_current_id(db_conn, table, column, step)
- for table, column in extra_tables:
- self.advance(None, _load_current_id(db_conn, table, column))
+ if extra_tables:
+ for table, column in extra_tables:
+ self.advance(None, _load_current_id(db_conn, table, column))
- def advance(self, instance_name, new_id):
+ def advance(self, instance_name: Optional[str], new_id: int):
self._current = (max if self.step > 0 else min)(self._current, new_id)
- def get_current_token(self):
+ def get_current_token(self) -> int:
"""
Returns:
diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py
index c418730ba8..045bd014da 100644
--- a/synapse/replication/slave/storage/pushers.py
+++ b/synapse/replication/slave/storage/pushers.py
@@ -13,26 +13,33 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import TYPE_CHECKING
from synapse.replication.tcp.streams import PushersStream
from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.pusher import PusherWorkerStore
+from synapse.storage.types import Connection
from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
- self._pushers_id_gen = SlavedIdTracker(
+ self._pushers_id_gen = SlavedIdTracker( # type: ignore
db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
)
- def get_pushers_stream_token(self):
+ def get_pushers_stream_token(self) -> int:
return self._pushers_id_gen.get_current_token()
- def process_replication_rows(self, stream_name, instance_name, token, rows):
+ def process_replication_rows(
+ self, stream_name: str, instance_name: str, token, rows
+ ) -> None:
if stream_name == PushersStream.NAME:
- self._pushers_id_gen.advance(instance_name, token)
+ self._pushers_id_gen.advance(instance_name, token) # type: ignore
return super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index 88cba369f5..6658c2da56 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -42,17 +42,6 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
-_GET_PUSHERS_ALLOWED_KEYS = {
- "app_display_name",
- "app_id",
- "data",
- "device_display_name",
- "kind",
- "lang",
- "profile_tag",
- "pushkey",
-}
-
class UsersRestServlet(RestServlet):
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)$")
@@ -770,10 +759,7 @@ class PushersRestServlet(RestServlet):
pushers = await self.store.get_pushers_by_user_id(user_id)
- filtered_pushers = [
- {k: v for k, v in p.items() if k in _GET_PUSHERS_ALLOWED_KEYS}
- for p in pushers
- ]
+ filtered_pushers = [p.as_dict() for p in pushers]
return 200, {"pushers": filtered_pushers, "total": len(filtered_pushers)}
diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py
index 8fe83f321a..89823fcc39 100644
--- a/synapse/rest/client/v1/pusher.py
+++ b/synapse/rest/client/v1/pusher.py
@@ -28,17 +28,6 @@ from synapse.rest.client.v2_alpha._base import client_patterns
logger = logging.getLogger(__name__)
-ALLOWED_KEYS = {
- "app_display_name",
- "app_id",
- "data",
- "device_display_name",
- "kind",
- "lang",
- "profile_tag",
- "pushkey",
-}
-
class PushersRestServlet(RestServlet):
PATTERNS = client_patterns("/pushers$", v1=True)
@@ -54,9 +43,7 @@ class PushersRestServlet(RestServlet):
pushers = await self.hs.get_datastore().get_pushers_by_user_id(user.to_string())
- filtered_pushers = [
- {k: v for k, v in p.items() if k in ALLOWED_KEYS} for p in pushers
- ]
+ filtered_pushers = [p.as_dict() for p in pushers]
return 200, {"pushers": filtered_pushers}
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 43660ec4fb..871fb646a5 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -149,9 +149,6 @@ class DataStore(
self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
- self._pushers_id_gen = StreamIdGenerator(
- db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
- )
self._group_updates_id_gen = StreamIdGenerator(
db_conn, "local_group_updates", "stream_id"
)
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index 7997242d90..77ba9d819e 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -15,18 +15,32 @@
# limitations under the License.
import logging
-from typing import Iterable, Iterator, List, Tuple
+from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Optional, Tuple
from canonicaljson import encode_canonical_json
+from synapse.push import PusherConfig, ThrottleParams
from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.storage.database import DatabasePool
+from synapse.storage.types import Connection
+from synapse.storage.util.id_generators import StreamIdGenerator
+from synapse.types import JsonDict
from synapse.util.caches.descriptors import cached, cachedList
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
logger = logging.getLogger(__name__)
class PusherWorkerStore(SQLBaseStore):
- def _decode_pushers_rows(self, rows: Iterable[dict]) -> Iterator[dict]:
+ def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+ super().__init__(database, db_conn, hs)
+ self._pushers_id_gen = StreamIdGenerator(
+ db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
+ )
+
+ def _decode_pushers_rows(self, rows: Iterable[dict]) -> Iterator[PusherConfig]:
"""JSON-decode the data in the rows returned from the `pushers` table
Drops any rows whose data cannot be decoded
@@ -44,21 +58,23 @@ class PusherWorkerStore(SQLBaseStore):
)
continue
- yield r
+ yield PusherConfig(**r)
- async def user_has_pusher(self, user_id):
+ async def user_has_pusher(self, user_id: str) -> bool:
ret = await self.db_pool.simple_select_one_onecol(
"pushers", {"user_name": user_id}, "id", allow_none=True
)
return ret is not None
- def get_pushers_by_app_id_and_pushkey(self, app_id, pushkey):
- return self.get_pushers_by({"app_id": app_id, "pushkey": pushkey})
+ async def get_pushers_by_app_id_and_pushkey(
+ self, app_id: str, pushkey: str
+ ) -> Iterator[PusherConfig]:
+ return await self.get_pushers_by({"app_id": app_id, "pushkey": pushkey})
- def get_pushers_by_user_id(self, user_id):
- return self.get_pushers_by({"user_name": user_id})
+ async def get_pushers_by_user_id(self, user_id: str) -> Iterator[PusherConfig]:
+ return await self.get_pushers_by({"user_name": user_id})
- async def get_pushers_by(self, keyvalues):
+ async def get_pushers_by(self, keyvalues: Dict[str, Any]) -> Iterator[PusherConfig]:
ret = await self.db_pool.simple_select_list(
"pushers",
keyvalues,
@@ -83,7 +99,7 @@ class PusherWorkerStore(SQLBaseStore):
)
return self._decode_pushers_rows(ret)
- async def get_all_pushers(self):
+ async def get_all_pushers(self) -> Iterator[PusherConfig]:
def get_pushers(txn):
txn.execute("SELECT * FROM pushers")
rows = self.db_pool.cursor_to_dict(txn)
@@ -159,14 +175,16 @@ class PusherWorkerStore(SQLBaseStore):
)
@cached(num_args=1, max_entries=15000)
- async def get_if_user_has_pusher(self, user_id):
+ async def get_if_user_has_pusher(self, user_id: str):
# This only exists for the cachedList decorator
raise NotImplementedError()
@cachedList(
cached_method_name="get_if_user_has_pusher", list_name="user_ids", num_args=1,
)
- async def get_if_users_have_pushers(self, user_ids):
+ async def get_if_users_have_pushers(
+ self, user_ids: Iterable[str]
+ ) -> Dict[str, bool]:
rows = await self.db_pool.simple_select_many_batch(
table="pushers",
column="user_name",
@@ -224,7 +242,7 @@ class PusherWorkerStore(SQLBaseStore):
return bool(updated)
async def update_pusher_failing_since(
- self, app_id, pushkey, user_id, failing_since
+ self, app_id: str, pushkey: str, user_id: str, failing_since: Optional[int]
) -> None:
await self.db_pool.simple_update(
table="pushers",
@@ -233,7 +251,9 @@ class PusherWorkerStore(SQLBaseStore):
desc="update_pusher_failing_since",
)
- async def get_throttle_params_by_room(self, pusher_id):
+ async def get_throttle_params_by_room(
+ self, pusher_id: str
+ ) -> Dict[str, ThrottleParams]:
res = await self.db_pool.simple_select_list(
"pusher_throttle",
{"pusher": pusher_id},
@@ -243,43 +263,44 @@ class PusherWorkerStore(SQLBaseStore):
params_by_room = {}
for row in res:
- params_by_room[row["room_id"]] = {
- "last_sent_ts": row["last_sent_ts"],
- "throttle_ms": row["throttle_ms"],
- }
+ params_by_room[row["room_id"]] = ThrottleParams(
+ row["last_sent_ts"], row["throttle_ms"],
+ )
return params_by_room
- async def set_throttle_params(self, pusher_id, room_id, params) -> None:
+ async def set_throttle_params(
+ self, pusher_id: str, room_id: str, params: ThrottleParams
+ ) -> None:
# no need to lock because `pusher_throttle` has a primary key on
# (pusher, room_id) so simple_upsert will retry
await self.db_pool.simple_upsert(
"pusher_throttle",
{"pusher": pusher_id, "room_id": room_id},
- params,
+ {"last_sent_ts": params.last_sent_ts, "throttle_ms": params.throttle_ms},
desc="set_throttle_params",
lock=False,
)
class PusherStore(PusherWorkerStore):
- def get_pushers_stream_token(self):
+ def get_pushers_stream_token(self) -> int:
return self._pushers_id_gen.get_current_token()
async def add_pusher(
self,
- user_id,
- access_token,
- kind,
- app_id,
- app_display_name,
- device_display_name,
- pushkey,
- pushkey_ts,
- lang,
- data,
- last_stream_ordering,
- profile_tag="",
+ user_id: str,
+ access_token: Optional[int],
+ kind: str,
+ app_id: str,
+ app_display_name: str,
+ device_display_name: str,
+ pushkey: str,
+ pushkey_ts: int,
+ lang: Optional[str],
+ data: Optional[JsonDict],
+ last_stream_ordering: int,
+ profile_tag: str = "",
) -> None:
async with self._pushers_id_gen.get_next() as stream_id:
# no need to lock because `pushers` has a unique key on
@@ -311,16 +332,16 @@ class PusherStore(PusherWorkerStore):
# invalidate, since we the user might not have had a pusher before
await self.db_pool.runInteraction(
"add_pusher",
- self._invalidate_cache_and_stream,
+ self._invalidate_cache_and_stream, # type: ignore
self.get_if_user_has_pusher,
(user_id,),
)
async def delete_pusher_by_app_id_pushkey_user_id(
- self, app_id, pushkey, user_id
+ self, app_id: str, pushkey: str, user_id: str
) -> None:
def delete_pusher_txn(txn, stream_id):
- self._invalidate_cache_and_stream(
+ self._invalidate_cache_and_stream( # type: ignore
txn, self.get_if_user_has_pusher, (user_id,)
)
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 02d71302ea..133c0e7a28 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -153,12 +153,12 @@ class StreamIdGenerator:
return _AsyncCtxManagerWrapper(manager())
- def get_current_token(self):
+ def get_current_token(self) -> int:
"""Returns the maximum stream id such that all stream ids less than or
equal to it have been successfully persisted.
Returns:
- int
+ The maximum stream id.
"""
with self._lock:
if self._unfinished_ids:
|