Coverage for astrocyte/db_budget.py: 97%
77 statements
« prev ^ index » next coverage.py v7.15.0, created at 2026-07-04 05:24 +0000
« prev ^ index » next coverage.py v7.15.0, created at 2026-07-04 05:24 +0000
1"""Connection-pool budget enforcement per operation.
Adopted from Hindsight's ``engine/db_budget.py``. Limits the number of
concurrent DB connections any single operation can hold, preventing one
operation (e.g., a recall that fans out 5 retrieval strategies in
parallel) from monopolizing the shared connection pool and starving
other in-flight requests.
The typical failure mode this prevents:
  - pool size = 20
  - one recall fans out 5 parallel strategy queries
  - 4 concurrent recalls = 20 connections held = pool exhausted
  - 5th concurrent recall blocks indefinitely waiting for a connection
With a per-operation budget of 4, no single recall can hold more than
4 connections regardless of fan-out — pool stays available for other
in-flight work.
Pool-agnostic. Works with both psycopg ``AsyncConnectionPool``
(``pool.connection()`` ctx-manager) and asyncpg pools (``acquire()``/
``release()``).
Public API:
    ConnectionBudgetManager(default_budget=4)
    async with mgr.operation(max_connections=2) as op:
        async with op.acquire(pool) as conn:
            ...
28"""
30from __future__ import annotations
32import asyncio
33import logging
34import uuid
35from contextlib import asynccontextmanager
36from dataclasses import dataclass, field
37from typing import TYPE_CHECKING, Any
39if TYPE_CHECKING:
40 from collections.abc import AsyncIterator
42logger = logging.getLogger(__name__)
45# ─── budget tracking ──────────────────────────────────────────────────
@dataclass
class OperationBudget:
    """Connection budget and usage counters for a single operation.

    Only ``operation_id`` and ``max_connections`` are constructor
    arguments; the semaphore and the three counters are excluded from
    ``__init__``.
    """

    # Identifier of the operation this budget belongs to.
    operation_id: str
    # Value the semaphore is created with.
    max_connections: int
    # Assigned in ``__post_init__``; has no default of its own.
    semaphore: asyncio.Semaphore = field(init=False)
    # Usage counters, all starting at zero.
    active_count: int = field(init=False, default=0)
    peak_count: int = field(init=False, default=0)
    total_acquired: int = field(init=False, default=0)

    def __post_init__(self) -> None:
        """Create the semaphore from ``max_connections``."""
        limit = self.max_connections
        self.semaphore = asyncio.Semaphore(limit)
63# ─── manager ──────────────────────────────────────────────────────────
66class ConnectionBudgetManager:
67 """Tracks per-operation connection budgets across the process.
69 Each ``operation()`` call registers a budget keyed by ID, returns a
70 ``BudgetedOperation`` context. Acquires through that context are
71 semaphore-bounded; the budget is unregistered on context exit.
72 """
74 def __init__(self, default_budget: int = 4) -> None:
75 self.default_budget = default_budget
76 self._operations: dict[str, OperationBudget] = {}
77 self._lock = asyncio.Lock()
79 @asynccontextmanager
80 async def operation(
81 self,
82 max_connections: int | None = None,
83 operation_id: str | None = None,
84 ) -> AsyncIterator[BudgetedOperation]:
85 """Open a budgeted operation context.
87 ``max_connections`` defaults to ``default_budget``. Auto-generates
88 an op_id if not supplied.
89 """
90 op_id = operation_id or f"op-{uuid.uuid4().hex[:12]}"
91 budget = max_connections if max_connections is not None else self.default_budget
92 if budget < 1:
93 raise ValueError(f"max_connections must be >= 1 (got {budget})")
95 async with self._lock:
96 if op_id in self._operations:
97 raise ValueError(f"Operation {op_id!r} already exists")
98 self._operations[op_id] = OperationBudget(op_id, budget)
100 try:
101 yield BudgetedOperation(self, op_id)
102 finally:
103 async with self._lock:
104 self._operations.pop(op_id, None)
106 def _get_budget(self, operation_id: str) -> OperationBudget:
107 budget = self._operations.get(operation_id)
108 if not budget:
109 raise ValueError(f"Operation {operation_id!r} not found (already exited?)")
110 return budget
112 @property
113 def active_operations(self) -> int:
114 return len(self._operations)
117# ─── budgeted operation context ───────────────────────────────────────
class BudgetedOperation:
    """One operation's view of the budget. Acquire connections through here."""

    def __init__(self, manager: ConnectionBudgetManager, operation_id: str) -> None:
        """Bind this handle to ``manager`` and ``operation_id``."""
        self._manager = manager
        self.operation_id = operation_id

    @property
    def budget(self) -> OperationBudget:
        """Budget the manager currently holds for this operation.

        Looked up on every access, so any error raised by the manager's
        lookup propagates to the caller.
        """
        return self._manager._get_budget(self.operation_id)

    @asynccontextmanager
    async def acquire(self, pool: Any) -> AsyncIterator[Any]:
        """Acquire a connection within this operation's budget.

        Pool-agnostic: a pool exposing ``connection()`` is used as an
        async context manager (psycopg ``AsyncConnectionPool`` style);
        otherwise a pool exposing both ``acquire()`` and ``release()`` is
        used with an explicit acquire/release pair (asyncpg style).

        Args:
            pool: The connection pool to draw from.

        Yields:
            The connection object handed out by the pool.

        Raises:
            TypeError: If ``pool`` has neither supported shape. This is
                raised before waiting for a budget slot and without
                touching the budget counters.
        """
        budget = self.budget

        # Classify the pool up front. Doing it after taking the semaphore
        # would make an unsupported pool wait for a free slot only to fail,
        # and would count it in total_acquired / peak_count even though no
        # connection was ever requested.
        uses_connection_ctx = hasattr(pool, "connection")
        if not uses_connection_ctx and not (
            hasattr(pool, "acquire") and hasattr(pool, "release")
        ):
            raise TypeError(
                f"Unsupported pool shape {type(pool).__name__}; "
                "expected .connection() or .acquire()/.release()",
            )

        async with budget.semaphore:
            budget.active_count += 1
            budget.total_acquired += 1
            budget.peak_count = max(budget.peak_count, budget.active_count)
            try:
                if uses_connection_ctx:
                    # psycopg AsyncConnectionPool — ctx-manager
                    async with pool.connection() as conn:
                        yield conn
                else:
                    # asyncpg pool — acquire/release
                    conn = await pool.acquire()
                    try:
                        yield conn
                    finally:
                        await pool.release(conn)
            finally:
                # The slot counts as active until the connection is handed
                # back, whether the caller's body succeeded or raised.
                budget.active_count -= 1
165# ─── module-level default ─────────────────────────────────────────────
167_default_manager: ConnectionBudgetManager | None = None
def get_default_budget_manager() -> ConnectionBudgetManager:
    """Return the process-wide default manager, creating it on first use.

    When no default is set yet, a ``ConnectionBudgetManager`` is built
    with no arguments, stored as the module-level default, and returned;
    later calls return that same stored object.
    """
    global _default_manager
    manager = _default_manager
    if manager is None:
        manager = _default_manager = ConnectionBudgetManager()
    return manager
def set_default_budget_manager(manager: ConnectionBudgetManager) -> None:
    """Replace the module-level default manager with ``manager``.

    The previous default, if any, is simply dropped.
    """
    global _default_manager
    _default_manager = manager