diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py
index e1700ca8aa..52f4f54215 100644
--- a/synapse/federation/send_queue.py
+++ b/synapse/federation/send_queue.py
@@ -31,6 +31,7 @@ Events are replicated via a separate events stream.
import logging
from collections import namedtuple
+from typing import Dict, List, Tuple, Type
from six import iteritems
@@ -56,21 +57,35 @@ class FederationRemoteSendQueue(object):
self.notifier = hs.get_notifier()
self.is_mine_id = hs.is_mine_id
- self.presence_map = {} # Pending presence map user_id -> UserPresenceState
- self.presence_changed = SortedDict() # Stream position -> list[user_id]
+ # Pending presence map user_id -> UserPresenceState
+ self.presence_map = {} # type: Dict[str, UserPresenceState]
+
+ # Stream position -> list[user_id]
+ self.presence_changed = SortedDict() # type: SortedDict[int, List[str]]
# Stores the destinations we need to explicitly send presence to about a
# given user.
# Stream position -> (user_id, destinations)
- self.presence_destinations = SortedDict()
+ self.presence_destinations = (
+ SortedDict()
+ ) # type: SortedDict[int, Tuple[str, List[str]]]
+
+ # (destination, key) -> EDU
+ self.keyed_edu = {} # type: Dict[Tuple[str, tuple], Edu]
- self.keyed_edu = {} # (destination, key) -> EDU
- self.keyed_edu_changed = SortedDict() # stream position -> (destination, key)
+ # stream position -> (destination, key)
+ self.keyed_edu_changed = (
+ SortedDict()
+ ) # type: SortedDict[int, Tuple[str, tuple]]
- self.edus = SortedDict() # stream position -> Edu
+ self.edus = SortedDict() # type: SortedDict[int, Edu]
+ # stream ID for the next entry into presence_changed/keyed_edu_changed/edus.
self.pos = 1
- self.pos_time = SortedDict()
+
+ # map from stream ID to the time that stream entry was generated, so that we
+ # can clear out entries after a while
+ self.pos_time = SortedDict() # type: SortedDict[int, int]
# EVERYTHING IS SAD. In particular, python only makes new scopes when
# we make a new function, so we need to make a new function so the inner
@@ -158,8 +173,10 @@ class FederationRemoteSendQueue(object):
for edu_key in self.keyed_edu_changed.values():
live_keys.add(edu_key)
- to_del = [edu_key for edu_key in self.keyed_edu if edu_key not in live_keys]
- for edu_key in to_del:
+ keys_to_del = [
+ edu_key for edu_key in self.keyed_edu if edu_key not in live_keys
+ ]
+ for edu_key in keys_to_del:
del self.keyed_edu[edu_key]
# Delete things out of edu map
@@ -250,19 +267,23 @@ class FederationRemoteSendQueue(object):
self._clear_queue_before_pos(token)
async def get_replication_rows(
- self, from_token, to_token, limit, federation_ack=None
- ):
+ self, instance_name: str, from_token: int, to_token: int, target_row_count: int
+ ) -> Tuple[List[Tuple[int, Tuple]], int, bool]:
"""Get rows to be sent over federation between the two tokens
Args:
- from_token (int)
- to_token(int)
- limit (int)
- federation_ack (int): Optional. The position where the worker is
- explicitly acknowledged it has handled. Allows us to drop
- data from before that point
+ instance_name: the name of the current process
+ from_token: the previous stream token: the starting point for fetching the
+ updates
+ to_token: the new stream token: the point to get updates up to
+ target_row_count: a target for the number of rows to be returned.
+
+ Returns: a triplet `(updates, new_last_token, limited)`, where:
+ * `updates` is a list of `(token, row)` entries.
+            * `new_last_token` is the new position in the stream.
+ * `limited` is whether there are more updates to fetch.
"""
- # TODO: Handle limit.
+ # TODO: Handle target_row_count.
# To handle restarts where we wrap around
if from_token > self.pos:
@@ -270,12 +291,7 @@ class FederationRemoteSendQueue(object):
# list of tuple(int, BaseFederationRow), where the first is the position
# of the federation stream.
- rows = []
-
- # There should be only one reader, so lets delete everything its
- # acknowledged its seen.
- if federation_ack:
- self._clear_queue_before_pos(federation_ack)
+ rows = [] # type: List[Tuple[int, BaseFederationRow]]
# Fetch changed presence
i = self.presence_changed.bisect_right(from_token)
@@ -332,7 +348,11 @@ class FederationRemoteSendQueue(object):
# Sort rows based on pos
rows.sort()
- return [(pos, row.TypeId, row.to_data()) for pos, row in rows]
+ return (
+ [(pos, (row.TypeId, row.to_data())) for pos, row in rows],
+ to_token,
+ False,
+ )
class BaseFederationRow(object):
@@ -341,7 +361,7 @@ class BaseFederationRow(object):
Specifies how to identify, serialize and deserialize the different types.
"""
- TypeId = None # Unique string that ids the type. Must be overriden in sub classes.
+ TypeId = "" # Unique string that ids the type. Must be overriden in sub classes.
@staticmethod
def from_data(data):
@@ -454,10 +474,14 @@ class EduRow(BaseFederationRow, namedtuple("EduRow", ("edu",))): # Edu
buff.edus.setdefault(self.edu.destination, []).append(self.edu)
-TypeToRow = {
- Row.TypeId: Row
- for Row in (PresenceRow, PresenceDestinationsRow, KeyedEduRow, EduRow,)
-}
+_rowtypes = (
+ PresenceRow,
+ PresenceDestinationsRow,
+ KeyedEduRow,
+ EduRow,
+) # type: Tuple[Type[BaseFederationRow], ...]
+
+TypeToRow = {Row.TypeId: Row for Row in _rowtypes}
ParsedFederationStreamData = namedtuple(
diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index a477578e44..d473576902 100644
--- a/synapse/federation/sender/__init__.py
+++ b/synapse/federation/sender/__init__.py
@@ -14,7 +14,7 @@
# limitations under the License.
import logging
-from typing import Dict, Hashable, Iterable, List, Optional, Set
+from typing import Dict, Hashable, Iterable, List, Optional, Set, Tuple
from six import itervalues
@@ -498,14 +498,16 @@ class FederationSender(object):
self._get_per_destination_queue(destination).attempt_new_transaction()
- def get_current_token(self) -> int:
+ @staticmethod
+ def get_current_token() -> int:
# Dummy implementation for case where federation sender isn't offloaded
# to a worker.
return 0
+ @staticmethod
async def get_replication_rows(
- self, from_token, to_token, limit, federation_ack=None
- ):
+ instance_name: str, from_token: int, to_token: int, target_row_count: int
+ ) -> Tuple[List[Tuple[int, Tuple]], int, bool]:
# Dummy implementation for case where federation sender isn't offloaded
# to a worker.
- return []
+ return [], 0, False
diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py
index e13cd20ffa..276a2b596f 100644
--- a/synapse/federation/sender/per_destination_queue.py
+++ b/synapse/federation/sender/per_destination_queue.py
@@ -15,11 +15,10 @@
# limitations under the License.
import datetime
import logging
-from typing import Dict, Hashable, Iterable, List, Tuple
+from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Tuple
from prometheus_client import Counter
-import synapse.server
from synapse.api.errors import (
FederationDeniedError,
HttpResponseException,
@@ -34,6 +33,9 @@ from synapse.storage.presence import UserPresenceState
from synapse.types import ReadReceipt
from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter
+if TYPE_CHECKING:
+ import synapse.server
+
# This is defined in the Matrix spec and enforced by the receiver.
MAX_EDUS_PER_TRANSACTION = 100
diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py
index 3c2a02a3b3..a2752a54a5 100644
--- a/synapse/federation/sender/transaction_manager.py
+++ b/synapse/federation/sender/transaction_manager.py
@@ -13,11 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import List
+from typing import TYPE_CHECKING, List
from canonicaljson import json
-import synapse.server
from synapse.api.errors import HttpResponseException
from synapse.events import EventBase
from synapse.federation.persistence import TransactionActions
@@ -31,6 +30,9 @@ from synapse.logging.opentracing import (
)
from synapse.util.metrics import measure_func
+if TYPE_CHECKING:
+ import synapse.server
+
logger = logging.getLogger(__name__)
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index da12df7f53..73f9eeb399 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -25,8 +25,6 @@ from collections import OrderedDict
from six import iteritems, string_types
-from twisted.internet import defer
-
from synapse.api.constants import EventTypes, JoinRules, RoomCreationPreset
from synapse.api.errors import AuthError, Codes, NotFoundError, StoreError, SynapseError
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
@@ -103,8 +101,7 @@ class RoomCreationHandler(BaseHandler):
self.third_party_event_rules = hs.get_third_party_event_rules()
- @defer.inlineCallbacks
- def upgrade_room(
+ async def upgrade_room(
self, requester: Requester, old_room_id: str, new_version: RoomVersion
):
"""Replace a room with a new room with a different version
@@ -117,7 +114,7 @@ class RoomCreationHandler(BaseHandler):
Returns:
Deferred[unicode]: the new room id
"""
- yield self.ratelimit(requester)
+ await self.ratelimit(requester)
user_id = requester.user.to_string()
@@ -138,7 +135,7 @@ class RoomCreationHandler(BaseHandler):
# If this user has sent multiple upgrade requests for the same room
# and one of them is not complete yet, cache the response and
# return it to all subsequent requests
- ret = yield self._upgrade_response_cache.wrap(
+ ret = await self._upgrade_response_cache.wrap(
(old_room_id, user_id),
self._upgrade_room,
requester,
@@ -856,8 +853,7 @@ class RoomCreationHandler(BaseHandler):
for (etype, state_key), content in initial_state.items():
await send(etype=etype, state_key=state_key, content=content)
- @defer.inlineCallbacks
- def _generate_room_id(
+ async def _generate_room_id(
+        self, creator_id: str, is_public: bool, room_version: RoomVersion,
):
# autogen room IDs and try to create it. We may clash, so just
@@ -869,7 +865,7 @@ class RoomCreationHandler(BaseHandler):
gen_room_id = RoomID(random_string, self.hs.hostname).to_string()
if isinstance(gen_room_id, bytes):
gen_room_id = gen_room_id.decode("utf-8")
- yield self.store.store_room(
+ await self.store.store_room(
room_id=gen_room_id,
room_creator_user_id=creator_id,
is_public=is_public,
@@ -888,8 +884,7 @@ class RoomContextHandler(object):
self.storage = hs.get_storage()
self.state_store = self.storage.state
- @defer.inlineCallbacks
- def get_event_context(self, user, room_id, event_id, limit, event_filter):
+ async def get_event_context(self, user, room_id, event_id, limit, event_filter):
"""Retrieves events, pagination tokens and state around a given event
in a room.
@@ -908,7 +903,7 @@ class RoomContextHandler(object):
before_limit = math.floor(limit / 2.0)
after_limit = limit - before_limit
- users = yield self.store.get_users_in_room(room_id)
+ users = await self.store.get_users_in_room(room_id)
is_peeking = user.to_string() not in users
def filter_evts(events):
@@ -916,17 +911,17 @@ class RoomContextHandler(object):
self.storage, user.to_string(), events, is_peeking=is_peeking
)
- event = yield self.store.get_event(
+ event = await self.store.get_event(
event_id, get_prev_content=True, allow_none=True
)
if not event:
return None
- filtered = yield (filter_evts([event]))
+ filtered = await filter_evts([event])
if not filtered:
raise AuthError(403, "You don't have permission to access that event.")
- results = yield self.store.get_events_around(
+ results = await self.store.get_events_around(
room_id, event_id, before_limit, after_limit, event_filter
)
@@ -934,8 +929,8 @@ class RoomContextHandler(object):
results["events_before"] = event_filter.filter(results["events_before"])
results["events_after"] = event_filter.filter(results["events_after"])
- results["events_before"] = yield filter_evts(results["events_before"])
- results["events_after"] = yield filter_evts(results["events_after"])
+ results["events_before"] = await filter_evts(results["events_before"])
+ results["events_after"] = await filter_evts(results["events_after"])
# filter_evts can return a pruned event in case the user is allowed to see that
# there's something there but not see the content, so use the event that's in
# `filtered` rather than the event we retrieved from the datastore.
@@ -962,7 +957,7 @@ class RoomContextHandler(object):
# first? Shouldn't we be consistent with /sync?
# https://github.com/matrix-org/matrix-doc/issues/687
- state = yield self.state_store.get_state_for_events(
+ state = await self.state_store.get_state_for_events(
[last_event_id], state_filter=state_filter
)
@@ -970,7 +965,7 @@ class RoomContextHandler(object):
if event_filter:
state_events = event_filter.filter(state_events)
- results["state"] = yield filter_evts(state_events)
+ results["state"] = await filter_evts(state_events)
# We use a dummy token here as we only care about the room portion of
# the token, which we replace.
@@ -989,13 +984,12 @@ class RoomEventSource(object):
def __init__(self, hs):
self.store = hs.get_datastore()
- @defer.inlineCallbacks
- def get_new_events(
+ async def get_new_events(
self, user, from_key, limit, room_ids, is_guest, explicit_room_id=None
):
# We just ignore the key for now.
- to_key = yield self.get_current_key()
+ to_key = await self.get_current_key()
from_token = RoomStreamToken.parse(from_key)
if from_token.topological:
@@ -1008,11 +1002,11 @@ class RoomEventSource(object):
# See https://github.com/matrix-org/matrix-doc/issues/1144
raise NotImplementedError()
else:
- room_events = yield self.store.get_membership_changes_for_user(
+ room_events = await self.store.get_membership_changes_for_user(
user.to_string(), from_key, to_key
)
- room_to_events = yield self.store.get_room_events_stream_for_rooms(
+ room_to_events = await self.store.get_room_events_stream_for_rooms(
room_ids=room_ids,
from_key=from_key,
to_key=to_key,
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index 33d2f589ac..b690abedad 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -80,7 +80,7 @@ class ReplicationStreamer(object):
for stream in STREAMS_MAP.values():
if stream == FederationStream and hs.config.send_federation:
# We only support federation stream if federation sending
- # hase been disabled on the master.
+ # has been disabled on the master.
continue
self.streams.append(stream(hs))
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index b0f87c365b..084604e8b0 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -104,7 +104,8 @@ class Stream(object):
implemented by subclasses.
current_token_function is called to get the current token of the underlying
- stream.
+ stream. It is only meaningful on the process that is the source of the
+ replication stream (ie, usually the master).
update_function is called to get updates for this stream between a pair of
stream tokens. See the UpdateFunction type definition for more info.
diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py
index e8bd52e389..b0505b8a2c 100644
--- a/synapse/replication/tcp/streams/federation.py
+++ b/synapse/replication/tcp/streams/federation.py
@@ -15,7 +15,7 @@
# limitations under the License.
from collections import namedtuple
-from synapse.replication.tcp.streams._base import Stream, db_query_to_update_function
+from synapse.replication.tcp.streams._base import Stream, make_http_update_function
class FederationStream(Stream):
@@ -35,21 +35,33 @@ class FederationStream(Stream):
ROW_TYPE = FederationStreamRow
def __init__(self, hs):
- # Not all synapse instances will have a federation sender instance,
- # whether that's a `FederationSender` or a `FederationRemoteSendQueue`,
- # so we stub the stream out when that is the case.
- if hs.config.worker_app is None or hs.should_send_federation():
+ if hs.config.worker_app is None:
+ # master process: get updates from the FederationRemoteSendQueue.
+ # (if the master is configured to send federation itself, federation_sender
+ # will be a real FederationSender, which has stubs for current_token and
+ # get_replication_rows.)
federation_sender = hs.get_federation_sender()
current_token = federation_sender.get_current_token
- update_function = db_query_to_update_function(
- federation_sender.get_replication_rows
- )
+ update_function = federation_sender.get_replication_rows
+
+ elif hs.should_send_federation():
+            # federation sender: query the master process via HTTP replication
+ update_function = make_http_update_function(hs, self.NAME)
+ current_token = self._stub_current_token
+
else:
- current_token = lambda: 0
+ # other worker: stub out the update function (we're not interested in
+            # any updates, so when we get a POSITION we do nothing)
update_function = self._stub_update_function
+ current_token = self._stub_current_token
super().__init__(hs.get_instance_name(), current_token, update_function)
@staticmethod
+ def _stub_current_token():
+ # dummy current-token method for use on workers
+ return 0
+
+ @staticmethod
async def _stub_update_function(instance_name, from_token, upto_token, limit):
return [], upto_token, False
diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py
index ee3a2ab031..536cef3abd 100644
--- a/synapse/storage/data_stores/main/devices.py
+++ b/synapse/storage/data_stores/main/devices.py
@@ -55,6 +55,10 @@ DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES = (
BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes"
+BG_UPDATE_DROP_DEVICE_LISTS_OUTBOUND_LAST_SUCCESS_NON_UNIQUE_IDX = (
+ "drop_device_lists_outbound_last_success_non_unique_idx"
+)
+
class DeviceWorkerStore(SQLBaseStore):
def get_device(self, user_id, device_id):
@@ -342,32 +346,23 @@ class DeviceWorkerStore(SQLBaseStore):
def _mark_as_sent_devices_by_remote_txn(self, txn, destination, stream_id):
# We update the device_lists_outbound_last_success with the successfully
- # poked users. We do the join to see which users need to be inserted and
- # which updated.
+ # poked users.
sql = """
- SELECT user_id, coalesce(max(o.stream_id), 0), (max(s.stream_id) IS NOT NULL)
+ SELECT user_id, coalesce(max(o.stream_id), 0)
FROM device_lists_outbound_pokes as o
- LEFT JOIN device_lists_outbound_last_success as s
- USING (destination, user_id)
WHERE destination = ? AND o.stream_id <= ?
GROUP BY user_id
"""
txn.execute(sql, (destination, stream_id))
rows = txn.fetchall()
- sql = """
- UPDATE device_lists_outbound_last_success
- SET stream_id = ?
- WHERE destination = ? AND user_id = ?
- """
- txn.executemany(sql, ((row[1], destination, row[0]) for row in rows if row[2]))
-
- sql = """
- INSERT INTO device_lists_outbound_last_success
- (destination, user_id, stream_id) VALUES (?, ?, ?)
- """
- txn.executemany(
- sql, ((destination, row[0], row[1]) for row in rows if not row[2])
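+        # Upsert the highest acknowledged stream_id for each user in one batched call.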
+ self.db.simple_upsert_many_txn(
+ txn=txn,
+ table="device_lists_outbound_last_success",
+ key_names=("destination", "user_id"),
+ key_values=((destination, user_id) for user_id, _ in rows),
+ value_names=("stream_id",),
+ value_values=((stream_id,) for _, stream_id in rows),
)
# Delete all sent outbound pokes
@@ -725,6 +720,21 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, self._remove_duplicate_outbound_pokes,
)
+ # create a unique index on device_lists_outbound_last_success
+ self.db.updates.register_background_index_update(
+ "device_lists_outbound_last_success_unique_idx",
+ index_name="device_lists_outbound_last_success_unique_idx",
+ table="device_lists_outbound_last_success",
+ columns=["destination", "user_id"],
+ unique=True,
+ )
+
+ # once that completes, we can remove the old non-unique index.
+ self.db.updates.register_background_update_handler(
+ BG_UPDATE_DROP_DEVICE_LISTS_OUTBOUND_LAST_SUCCESS_NON_UNIQUE_IDX,
+ self._drop_device_lists_outbound_last_success_non_unique_idx,
+ )
+
@defer.inlineCallbacks
def _drop_device_list_streams_non_unique_indexes(self, progress, batch_size):
def f(conn):
@@ -799,6 +809,20 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
return rows
+ async def _drop_device_lists_outbound_last_success_non_unique_idx(
+ self, progress, batch_size
+ ):
+ def f(txn):
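+            # This index is superseded by the new unique index; the schema delta's
+            # `depends_on` ensures the unique index is built before we get here.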
+ txn.execute("DROP INDEX IF EXISTS device_lists_outbound_last_success_idx")
+
+ await self.db.runInteraction(
+ "drop_device_lists_outbound_last_success_non_unique_idx", f,
+ )
+ await self.db.updates._end_background_update(
+ BG_UPDATE_DROP_DEVICE_LISTS_OUTBOUND_LAST_SUCCESS_NON_UNIQUE_IDX
+ )
+ return 1
+
class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
def __init__(self, database: Database, db_conn, hs):
diff --git a/synapse/storage/data_stores/main/schema/delta/58/04device_lists_outbound_last_success_unique_idx.sql b/synapse/storage/data_stores/main/schema/delta/58/04device_lists_outbound_last_success_unique_idx.sql
new file mode 100644
index 0000000000..d5e6deb878
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/58/04device_lists_outbound_last_success_unique_idx.sql
@@ -0,0 +1,28 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- register a background update which will create a unique index on
+-- device_lists_outbound_last_success
+INSERT into background_updates (ordering, update_name, progress_json)
+ VALUES (5804, 'device_lists_outbound_last_success_unique_idx', '{}');
+
+-- once that completes, we can drop the old index.
+INSERT into background_updates (ordering, update_name, progress_json, depends_on)
+ VALUES (
+ 5804,
+ 'drop_device_lists_outbound_last_success_non_unique_idx',
+ '{}',
+ 'device_lists_outbound_last_success_unique_idx'
+ );
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index a7cd97b0b0..2b635d6ca0 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -49,6 +49,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.background_updates import BackgroundUpdater
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
from synapse.storage.types import Connection, Cursor
+from synapse.types import Collection
from synapse.util.stringutils import exception_to_unicode
logger = logging.getLogger(__name__)
@@ -78,6 +79,7 @@ UNIQUE_INDEX_BACKGROUND_UPDATES = {
"device_lists_remote_extremeties": "device_lists_remote_extremeties_unique_idx",
"device_lists_remote_cache": "device_lists_remote_cache_unique_idx",
"event_search": "event_search_event_id_idx",
+ "device_lists_outbound_last_success": "device_lists_outbound_last_success_unique_idx",
}
@@ -889,20 +891,24 @@ class Database(object):
txn.execute(sql, list(allvalues.values()))
def simple_upsert_many_txn(
- self, txn, table, key_names, key_values, value_names, value_values
- ):
+ self,
+ txn: LoggingTransaction,
+ table: str,
+ key_names: Collection[str],
+ key_values: Collection[Iterable[Any]],
+ value_names: Collection[str],
+        value_values: Iterable[Iterable[Any]],
+ ) -> None:
"""
Upsert, many times.
Args:
- table (str): The table to upsert into
- key_names (list[str]): The key column names.
- key_values (list[list]): A list of each row's key column values.
- value_names (list[str]): The value column names. If empty, no
- values will be used, even if value_values is provided.
- value_values (list[list]): A list of each row's value column values.
- Returns:
- None
+ table: The table to upsert into
+ key_names: The key column names.
+ key_values: A list of each row's key column values.
+ value_names: The value column names
+ value_values: A list of each row's value column values.
+ Ignored if value_names is empty.
"""
if self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables:
return self.simple_upsert_many_txn_native_upsert(
@@ -914,20 +920,24 @@ class Database(object):
)
def simple_upsert_many_txn_emulated(
- self, txn, table, key_names, key_values, value_names, value_values
- ):
+ self,
+ txn: LoggingTransaction,
+ table: str,
+ key_names: Iterable[str],
+ key_values: Collection[Iterable[Any]],
+ value_names: Collection[str],
+        value_values: Iterable[Iterable[Any]],
+ ) -> None:
"""
Upsert, many times, but without native UPSERT support or batching.
Args:
- table (str): The table to upsert into
- key_names (list[str]): The key column names.
- key_values (list[list]): A list of each row's key column values.
- value_names (list[str]): The value column names. If empty, no
- values will be used, even if value_values is provided.
- value_values (list[list]): A list of each row's value column values.
- Returns:
- None
+ table: The table to upsert into
+ key_names: The key column names.
+ key_values: A list of each row's key column values.
+ value_names: The value column names
+ value_values: A list of each row's value column values.
+ Ignored if value_names is empty.
"""
# No value columns, therefore make a blank list so that the following
# zip() works correctly.
@@ -941,20 +951,24 @@ class Database(object):
self.simple_upsert_txn_emulated(txn, table, _keys, _vals)
def simple_upsert_many_txn_native_upsert(
- self, txn, table, key_names, key_values, value_names, value_values
- ):
+ self,
+ txn: LoggingTransaction,
+ table: str,
+ key_names: Collection[str],
+ key_values: Collection[Iterable[Any]],
+ value_names: Collection[str],
+ value_values: Iterable[Iterable[Any]],
+ ) -> None:
"""
Upsert, many times, using batching where possible.
Args:
- table (str): The table to upsert into
- key_names (list[str]): The key column names.
- key_values (list[list]): A list of each row's key column values.
- value_names (list[str]): The value column names. If empty, no
- values will be used, even if value_values is provided.
- value_values (list[list]): A list of each row's value column values.
- Returns:
- None
+ table: The table to upsert into
+ key_names: The key column names.
+ key_values: A list of each row's key column values.
+ value_names: The value column names
+ value_values: A list of each row's value column values.
+ Ignored if value_names is empty.
"""
allnames = [] # type: List[str]
allnames.extend(key_names)
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 9d851beaa5..86d04ea9ac 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -16,6 +16,11 @@
import contextlib
import threading
from collections import deque
+from typing import Dict, Optional, Set, Tuple
+
+from typing_extensions import Deque
+
+from synapse.storage.database import Database, LoggingTransaction
class IdGenerator(object):
@@ -87,7 +92,7 @@ class StreamIdGenerator(object):
self._current = (max if step > 0 else min)(
self._current, _load_current_id(db_conn, table, column, step)
)
- self._unfinished_ids = deque()
+ self._unfinished_ids = deque() # type: Deque[int]
def get_next(self):
"""
@@ -163,7 +168,7 @@ class ChainedIdGenerator(object):
self.chained_generator = chained_generator
self._lock = threading.Lock()
self._current_max = _load_current_id(db_conn, table, column)
- self._unfinished_ids = deque()
+ self._unfinished_ids = deque() # type: Deque[Tuple[int, int]]
def get_next(self):
"""
@@ -198,3 +203,163 @@ class ChainedIdGenerator(object):
return stream_id - 1, chained_id
return self._current_max, self.chained_generator.get_current_token()
+
+
+class MultiWriterIdGenerator:
+ """An ID generator that tracks a stream that can have multiple writers.
+
+ Uses a Postgres sequence to coordinate ID assignment, but positions of other
+ writers will only get updated when `advance` is called (by replication).
+
+ Note: Only works with Postgres.
+
+ Args:
+        db_conn: a database connection, used to load the current positions at startup
+        db: the Database, used to allocate new IDs
+ instance_name: The name of this instance.
+ table: Database table associated with stream.
+ instance_column: Column that stores the row's writer's instance name
+ id_column: Column that stores the stream ID.
+ sequence_name: The name of the postgres sequence used to generate new
+ IDs.
+ """
+
+ def __init__(
+ self,
+ db_conn,
+ db: Database,
+ instance_name: str,
+ table: str,
+ instance_column: str,
+ id_column: str,
+ sequence_name: str,
+ ):
+ self._db = db
+ self._instance_name = instance_name
+ self._sequence_name = sequence_name
+
+ # We lock as some functions may be called from DB threads.
+ self._lock = threading.Lock()
+
+ self._current_positions = self._load_current_ids(
+ db_conn, table, instance_column, id_column
+ )
+
+ # Set of local IDs that we're still processing. The current position
+ # should be less than the minimum of this set (if not empty).
+ self._unfinished_ids = set() # type: Set[int]
+
+ def _load_current_ids(
+ self, db_conn, table: str, instance_column: str, id_column: str
+ ) -> Dict[str, int]:
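+        # Note: the table and column names are trusted identifiers supplied by the
+        # codebase, not user input, so interpolating them into the SQL is safe.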
+ sql = """
+ SELECT %(instance)s, MAX(%(id)s) FROM %(table)s
+ GROUP BY %(instance)s
+ """ % {
+ "instance": instance_column,
+ "id": id_column,
+ "table": table,
+ }
+
+ cur = db_conn.cursor()
+ cur.execute(sql)
+
+ # `cur` is an iterable over returned rows, which are 2-tuples.
+ current_positions = dict(cur)
+
+ cur.close()
+
+ return current_positions
+
+ def _load_next_id_txn(self, txn):
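+        # nextval() atomically advances the Postgres sequence, so concurrent writers
+        # always receive distinct, increasing IDs.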
+ txn.execute("SELECT nextval(?)", (self._sequence_name,))
+ (next_id,) = txn.fetchone()
+ return next_id
+
+ async def get_next(self):
+ """
+ Usage:
+ with await stream_id_gen.get_next() as stream_id:
+ # ... persist event ...
+ """
+ next_id = await self._db.runInteraction("_load_next_id", self._load_next_id_txn)
+
+ # Assert the fetched ID is actually greater than what we currently
+ # believe the ID to be. If not, then the sequence and table have got
+ # out of sync somehow.
+ assert self.get_current_token() < next_id
+
+ with self._lock:
+ self._unfinished_ids.add(next_id)
+
+ @contextlib.contextmanager
+ def manager():
+ try:
+ yield next_id
+ finally:
+ self._mark_id_as_finished(next_id)
+
+ return manager()
+
+ def get_next_txn(self, txn: LoggingTransaction):
+ """
+ Usage:
+
+            stream_id = stream_id_gen.get_next_txn(txn)
+ # ... persist event ...
+ """
+
+ next_id = self._load_next_id_txn(txn)
+
+ with self._lock:
+ self._unfinished_ids.add(next_id)
+
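+        # Mark the ID as finished whether the transaction commits or fails, so that
+        # the current position can still advance.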
+ txn.call_after(self._mark_id_as_finished, next_id)
+ txn.call_on_exception(self._mark_id_as_finished, next_id)
+
+ return next_id
+
+ def _mark_id_as_finished(self, next_id: int):
+ """The ID has finished being processed so we should advance the
+        current position if possible.
+ """
+
+ with self._lock:
+ self._unfinished_ids.discard(next_id)
+
+            # Figure out if it's safe to advance the position by checking there
+ # aren't any lower allocated IDs that are yet to finish.
+ if all(c > next_id for c in self._unfinished_ids):
+ curr = self._current_positions.get(self._instance_name, 0)
+ self._current_positions[self._instance_name] = max(curr, next_id)
+
+    def get_current_token(self, instance_name: Optional[str] = None) -> int:
+ """Gets the current position of a named writer (defaults to current
+ instance).
+
+ Returns 0 if we don't have a position for the named writer (likely due
+ to it being a new writer).
+ """
+
+ if instance_name is None:
+ instance_name = self._instance_name
+
+ with self._lock:
+ return self._current_positions.get(instance_name, 0)
+
+ def get_positions(self) -> Dict[str, int]:
+ """Get a copy of the current positon map.
+ """
+
+ with self._lock:
+ return dict(self._current_positions)
+
+ def advance(self, instance_name: str, new_id: int):
+ """Advance the postion of the named writer to the given ID, if greater
+ than existing entry.
+ """
+
+ with self._lock:
+ self._current_positions[instance_name] = max(
+ new_id, self._current_positions.get(instance_name, 0)
+ )
|