diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py
index 32d9d00ffb..a8326f5296 100644
--- a/synapse/storage/events_worker.py
+++ b/synapse/storage/events_worker.py
@@ -12,27 +12,31 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import SQLBaseStore
-from twisted.internet import defer, reactor
+import itertools
+import logging
+from collections import namedtuple
+
+from canonicaljson import json
+from twisted.internet import defer
+
+from synapse.api.errors import NotFoundError
+# these are only included to make the type annotations work
+from synapse.events import EventBase # noqa: F401
from synapse.events import FrozenEvent
+from synapse.events.snapshot import EventContext # noqa: F401
from synapse.events.utils import prune_event
-
+from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.util.logcontext import (
- PreserveLoggingContext, make_deferred_yieldable, run_in_background,
+ LoggingContext,
+ PreserveLoggingContext,
+ make_deferred_yieldable,
+ run_in_background,
)
from synapse.util.metrics import Measure
-from synapse.api.errors import SynapseError
-from collections import namedtuple
-
-import logging
-import simplejson as json
-
-# these are only included to make the type annotations work
-from synapse.events import EventBase # noqa: F401
-from synapse.events.snapshot import EventContext # noqa: F401
+from ._base import SQLBaseStore
logger = logging.getLogger(__name__)
@@ -75,7 +79,7 @@ class EventsWorkerStore(SQLBaseStore):
@defer.inlineCallbacks
def get_event(self, event_id, check_redacted=True,
get_prev_content=False, allow_rejected=False,
- allow_none=False):
+ allow_none=False, check_room_id=None):
"""Get an event from the database by event_id.
Args:
@@ -86,7 +90,9 @@ class EventsWorkerStore(SQLBaseStore):
include the previous states content in the unsigned field.
allow_rejected (bool): If True return rejected events.
allow_none (bool): If True, return None if no event found, if
- False throw an exception.
+                False, raise a NotFoundError
+ check_room_id (str|None): if not None, check the room of the found event.
+ If there is a mismatch, behave as per allow_none.
Returns:
Deferred : A FrozenEvent.
@@ -98,10 +104,16 @@ class EventsWorkerStore(SQLBaseStore):
allow_rejected=allow_rejected,
)
- if not events and not allow_none:
- raise SynapseError(404, "Could not find event %s" % (event_id,))
+ event = events[0] if events else None
+
+ if event is not None and check_room_id is not None:
+ if event.room_id != check_room_id:
+ event = None
- defer.returnValue(events[0] if events else None)
+ if event is None and not allow_none:
+ raise NotFoundError("Could not find event %s" % (event_id,))
+
+ defer.returnValue(event)
@defer.inlineCallbacks
def get_events(self, event_ids, check_redacted=True,
@@ -145,6 +157,9 @@ class EventsWorkerStore(SQLBaseStore):
missing_events_ids = [e for e in event_ids if e not in event_entry_map]
if missing_events_ids:
+ log_ctx = LoggingContext.current_context()
+ log_ctx.record_event_fetch(len(missing_events_ids))
+
missing_events = yield self._enqueue_events(
missing_events_ids,
check_redacted=check_redacted,
@@ -218,32 +233,47 @@ class EventsWorkerStore(SQLBaseStore):
"""Takes a database connection and waits for requests for events from
the _event_fetch_list queue.
"""
- event_list = []
i = 0
while True:
+ with self._event_fetch_lock:
+ event_list = self._event_fetch_list
+ self._event_fetch_list = []
+
+ if not event_list:
+ single_threaded = self.database_engine.single_threaded
+ if single_threaded or i > EVENT_QUEUE_ITERATIONS:
+ self._event_fetch_ongoing -= 1
+ return
+ else:
+ self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S)
+ i += 1
+ continue
+ i = 0
+
+ self._fetch_event_list(conn, event_list)
+
+ def _fetch_event_list(self, conn, event_list):
+ """Handle a load of requests from the _event_fetch_list queue
+
+ Args:
+ conn (twisted.enterprise.adbapi.Connection): database connection
+
+ event_list (list[Tuple[list[str], Deferred]]):
+ The fetch requests. Each entry consists of a list of event
+ ids to be fetched, and a deferred to be completed once the
+ events have been fetched.
+
+ """
+ with Measure(self._clock, "_fetch_event_list"):
try:
- with self._event_fetch_lock:
- event_list = self._event_fetch_list
- self._event_fetch_list = []
-
- if not event_list:
- single_threaded = self.database_engine.single_threaded
- if single_threaded or i > EVENT_QUEUE_ITERATIONS:
- self._event_fetch_ongoing -= 1
- return
- else:
- self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S)
- i += 1
- continue
- i = 0
-
- event_id_lists = zip(*event_list)[0]
+ event_id_lists = list(zip(*event_list))[0]
event_ids = [
item for sublist in event_id_lists for item in sublist
]
rows = self._new_transaction(
- conn, "do_fetch", [], [], None, self._fetch_event_rows, event_ids
+ conn, "do_fetch", [], [],
+ self._fetch_event_rows, event_ids,
)
row_dict = {
@@ -265,20 +295,19 @@ class EventsWorkerStore(SQLBaseStore):
except Exception:
logger.exception("Failed to callback")
with PreserveLoggingContext():
- reactor.callFromThread(fire, event_list, row_dict)
+ self.hs.get_reactor().callFromThread(fire, event_list, row_dict)
except Exception as e:
logger.exception("do_fetch")
# We only want to resolve deferreds from the main thread
- def fire(evs):
+ def fire(evs, exc):
for _, d in evs:
if not d.called:
with PreserveLoggingContext():
- d.errback(e)
+ d.errback(exc)
- if event_list:
- with PreserveLoggingContext():
- reactor.callFromThread(fire, event_list)
+ with PreserveLoggingContext():
+ self.hs.get_reactor().callFromThread(fire, event_list, e)
@defer.inlineCallbacks
def _enqueue_events(self, events, check_redacted=True, allow_rejected=False):
@@ -304,10 +333,11 @@ class EventsWorkerStore(SQLBaseStore):
should_start = False
if should_start:
- with PreserveLoggingContext():
- self.runWithConnection(
- self._do_fetch
- )
+ run_as_background_process(
+ "fetch_events",
+ self.runWithConnection,
+ self._do_fetch,
+ )
logger.debug("Loading %d events", len(events))
with PreserveLoggingContext():
@@ -414,3 +444,85 @@ class EventsWorkerStore(SQLBaseStore):
self._get_event_cache.prefill((original_ev.event_id,), cache_entry)
defer.returnValue(cache_entry)
+
+ @defer.inlineCallbacks
+ def have_events_in_timeline(self, event_ids):
+ """Given a list of event ids, check if we have already processed and
+ stored them as non outliers.
+ """
+ rows = yield self._simple_select_many_batch(
+ table="events",
+ retcols=("event_id",),
+ column="event_id",
+ iterable=list(event_ids),
+ keyvalues={"outlier": False},
+ desc="have_events_in_timeline",
+ )
+
+ defer.returnValue(set(r["event_id"] for r in rows))
+
+ @defer.inlineCallbacks
+ def have_seen_events(self, event_ids):
+ """Given a list of event ids, check if we have already processed them.
+
+ Args:
+ event_ids (iterable[str]):
+
+ Returns:
+ Deferred[set[str]]: The events we have already seen.
+ """
+ results = set()
+
+ def have_seen_events_txn(txn, chunk):
+ sql = (
+ "SELECT event_id FROM events as e WHERE e.event_id IN (%s)"
+ % (",".join("?" * len(chunk)), )
+ )
+ txn.execute(sql, chunk)
+ for (event_id, ) in txn:
+ results.add(event_id)
+
+ # break the input up into chunks of 100
+ input_iterator = iter(event_ids)
+ for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)),
+ []):
+ yield self.runInteraction(
+ "have_seen_events",
+ have_seen_events_txn,
+ chunk,
+ )
+ defer.returnValue(results)
+
+ def get_seen_events_with_rejections(self, event_ids):
+ """Given a list of event ids, check if we rejected them.
+
+ Args:
+ event_ids (list[str])
+
+ Returns:
+            Deferred[dict[str, str|None]]:
+ Has an entry for each event id we already have seen. Maps to
+ the rejected reason string if we rejected the event, else maps
+ to None.
+ """
+ if not event_ids:
+ return defer.succeed({})
+
+ def f(txn):
+ sql = (
+ "SELECT e.event_id, reason FROM events as e "
+ "LEFT JOIN rejections as r ON e.event_id = r.event_id "
+ "WHERE e.event_id = ?"
+ )
+
+ res = {}
+ for event_id in event_ids:
+ txn.execute(sql, (event_id,))
+ row = txn.fetchone()
+ if row:
+ _, rejected = row
+ res[event_id] = rejected
+
+ return res
+
+ return self.runInteraction("get_rejection_reasons", f)