diff options
Diffstat (limited to 'synapse/util')
-rw-r--r-- | synapse/util/wheel_timer.py | 39 |
1 files changed, 26 insertions, 13 deletions
diff --git a/synapse/util/wheel_timer.py b/synapse/util/wheel_timer.py index e108adc460..177e198e7e 100644 --- a/synapse/util/wheel_timer.py +++ b/synapse/util/wheel_timer.py @@ -11,17 +11,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Generic, List, TypeVar +import logging +from typing import Generic, Hashable, List, Set, TypeVar -T = TypeVar("T") +import attr +logger = logging.getLogger(__name__) + +T = TypeVar("T", bound=Hashable) -class _Entry(Generic[T]): - __slots__ = ["end_key", "queue"] - def __init__(self, end_key: int) -> None: - self.end_key: int = end_key - self.queue: List[T] = [] +@attr.s(slots=True, frozen=True, auto_attribs=True) +class _Entry(Generic[T]): + end_key: int + elements: Set[T] = attr.Factory(set) class WheelTimer(Generic[T]): @@ -48,17 +51,27 @@ class WheelTimer(Generic[T]): then: When to return the object strictly after. """ then_key = int(then / self.bucket_size) + 1 + now_key = int(now / self.bucket_size) if self.entries: min_key = self.entries[0].end_key max_key = self.entries[-1].end_key + if min_key < now_key - 10: + # If we have ten buckets that are due and still nothing has + # called `fetch()` then we likely have a bug that is causing a + # memory leak. + logger.warning( + "Inserting into a wheel timer that hasn't been read from recently. Item: %s", + obj, + ) + if then_key <= max_key: # The max here is to protect against inserts for times in the past - self.entries[max(min_key, then_key) - min_key].queue.append(obj) + self.entries[max(min_key, then_key) - min_key].elements.add(obj) return - next_key = int(now / self.bucket_size) + 1 + next_key = now_key + 1 if self.entries: last_key = self.entries[-1].end_key else: @@ -71,7 +84,7 @@ class WheelTimer(Generic[T]): # to insert. This ensures there are no gaps. self.entries.extend(_Entry(key) for key in range(last_key, then_key + 1)) - self.entries[-1].queue.append(obj) + self.entries[-1].elements.add(obj) def fetch(self, now: int) -> List[T]: """Fetch any objects that have timed out @@ -84,11 +97,11 @@ class WheelTimer(Generic[T]): """ now_key = int(now / self.bucket_size) - ret = [] + ret: List[T] = [] while self.entries and self.entries[0].end_key <= now_key: - ret.extend(self.entries.pop(0).queue) + ret.extend(self.entries.pop(0).elements) return ret def __len__(self) -> int: - return sum(len(entry.queue) for entry in self.entries) + return sum(len(entry.elements) for entry in self.entries) |