Coverage for astrocyte/db_budget.py: 97%

77 statements  

« prev     ^ index     » next       coverage.py v7.15.0, created at 2026-07-04 05:24 +0000

1"""Connection-pool budget enforcement per operation. 

2 

3Adopted from Hindsight's ``engine/db_budget.py``. Limits the number of 

4concurrent DB connections any single operation can hold, preventing one 

5operation (e.g., a recall that fans out 5 retrieval strategies in 

6parallel) from monopolizing the shared connection pool and starving 

7other in-flight requests. 

8 

9The typical failure mode this prevents: 

10 - pool size = 20 

11 - one recall fans out 5 parallel strategy queries 

12 - 4 concurrent recalls = 20 connections held = pool exhausted 

13 - 5th concurrent recall blocks indefinitely waiting for a connection 

14 

15With a per-operation budget of 4, no single recall can hold more than 

164 connections regardless of fan-out — pool stays available for other 

17in-flight work. 

18 

19Pool-agnostic. Works with both psycopg ``AsyncConnectionPool`` 

20(``pool.connection()`` ctx-manager) and asyncpg pools (``acquire()``/ 

21``release()``). 

22 

23Public API: 

24 ConnectionBudgetManager(default_budget=4) 

25 async with mgr.operation(max_connections=2) as op: 

26 async with op.acquire(pool) as conn: 

27 ... 

28""" 

29 

30from __future__ import annotations 

31 

32import asyncio 

33import logging 

34import uuid 

35from contextlib import asynccontextmanager 

36from dataclasses import dataclass, field 

37from typing import TYPE_CHECKING, Any 

38 

39if TYPE_CHECKING: 

40 from collections.abc import AsyncIterator 

41 

42logger = logging.getLogger(__name__) 

43 

44 

45# ─── budget tracking ────────────────────────────────────────────────── 

46 

47 

@dataclass
class OperationBudget:
    """Per-operation connection budget: one semaphore plus usage counters.

    Args:
        operation_id: Identifier of the owning operation.
        max_connections: Capacity of the semaphore, i.e. the most
            acquisitions that may be open at once under this budget.
            Must be >= 1.

    Raises:
        ValueError: if ``max_connections`` is below 1. A zero-capacity
            semaphore could never be acquired, so every acquire through
            this budget would block forever.
    """

    operation_id: str
    max_connections: int
    # Bounds concurrent acquisitions; built in __post_init__ from max_connections.
    semaphore: asyncio.Semaphore = field(init=False)
    # Acquisitions currently open under this budget.
    active_count: int = field(default=0, init=False)
    # Highest value active_count has reached.
    peak_count: int = field(default=0, init=False)
    # Cumulative number of acquisitions recorded against this budget.
    total_acquired: int = field(default=0, init=False)

    def __post_init__(self) -> None:
        # Same guard (and message) as ConnectionBudgetManager.operation(), so
        # direct construction cannot produce a budget that deadlocks.
        if self.max_connections < 1:
            raise ValueError(f"max_connections must be >= 1 (got {self.max_connections})")
        self.semaphore = asyncio.Semaphore(self.max_connections)

61 

62 

63# ─── manager ────────────────────────────────────────────────────────── 

64 

65 

66class ConnectionBudgetManager: 

67 """Tracks per-operation connection budgets across the process. 

68 

69 Each ``operation()`` call registers a budget keyed by ID, returns a 

70 ``BudgetedOperation`` context. Acquires through that context are 

71 semaphore-bounded; the budget is unregistered on context exit. 

72 """ 

73 

74 def __init__(self, default_budget: int = 4) -> None: 

75 self.default_budget = default_budget 

76 self._operations: dict[str, OperationBudget] = {} 

77 self._lock = asyncio.Lock() 

78 

79 @asynccontextmanager 

80 async def operation( 

81 self, 

82 max_connections: int | None = None, 

83 operation_id: str | None = None, 

84 ) -> AsyncIterator[BudgetedOperation]: 

85 """Open a budgeted operation context. 

86 

87 ``max_connections`` defaults to ``default_budget``. Auto-generates 

88 an op_id if not supplied. 

89 """ 

90 op_id = operation_id or f"op-{uuid.uuid4().hex[:12]}" 

91 budget = max_connections if max_connections is not None else self.default_budget 

92 if budget < 1: 

93 raise ValueError(f"max_connections must be >= 1 (got {budget})") 

94 

95 async with self._lock: 

96 if op_id in self._operations: 

97 raise ValueError(f"Operation {op_id!r} already exists") 

98 self._operations[op_id] = OperationBudget(op_id, budget) 

99 

100 try: 

101 yield BudgetedOperation(self, op_id) 

102 finally: 

103 async with self._lock: 

104 self._operations.pop(op_id, None) 

105 

106 def _get_budget(self, operation_id: str) -> OperationBudget: 

107 budget = self._operations.get(operation_id) 

108 if not budget: 

109 raise ValueError(f"Operation {operation_id!r} not found (already exited?)") 

110 return budget 

111 

112 @property 

113 def active_operations(self) -> int: 

114 return len(self._operations) 

115 

116 

117# ─── budgeted operation context ─────────────────────────────────────── 

118 

119 

class BudgetedOperation:
    """One operation's handle onto its connection budget.

    Obtained from ``ConnectionBudgetManager.operation()``. Every connection
    checked out through :meth:`acquire` first takes a slot of the
    operation's semaphore, so at most ``max_connections`` checkouts are
    open at once.
    """

    def __init__(self, manager: ConnectionBudgetManager, operation_id: str) -> None:
        self._manager = manager
        self.operation_id = operation_id

    @property
    def budget(self) -> OperationBudget:
        """The live budget for this operation.

        Raises:
            ValueError: raised by the manager's lookup if the operation is
                no longer registered.
        """
        return self._manager._get_budget(self.operation_id)

    @staticmethod
    def _check_pool_shape(pool: Any) -> None:
        """Raise ``TypeError`` unless *pool* offers a supported checkout API."""
        if hasattr(pool, "connection"):
            return
        if hasattr(pool, "acquire") and hasattr(pool, "release"):
            return
        raise TypeError(
            f"Unsupported pool shape {type(pool).__name__}; "
            "expected .connection() or .acquire()/.release()",
        )

    @staticmethod
    @asynccontextmanager
    async def _checkout(pool: Any) -> AsyncIterator[Any]:
        """Check one connection out of an already-validated *pool* and return it on exit."""
        if hasattr(pool, "connection"):
            # psycopg AsyncConnectionPool — ctx-manager
            async with pool.connection() as conn:
                yield conn
        else:
            # asyncpg pool — acquire/release
            conn = await pool.acquire()
            try:
                yield conn
            finally:
                await pool.release(conn)

    @asynccontextmanager
    async def acquire(self, pool: Any) -> AsyncIterator[Any]:
        """Acquire a connection within this operation's budget.

        Pool-agnostic: works with psycopg ``AsyncConnectionPool``
        (``pool.connection()`` ctx-manager) or asyncpg pools
        (``acquire()``/``release()``).

        The pool's shape is validated before waiting on the budget, so a
        misconfigured pool fails immediately instead of first queueing for
        a slot. The usage counters are updated only after the pool has
        actually handed out a connection, so a failed checkout leaves them
        untouched.

        Raises:
            ValueError: if the operation has already exited.
            TypeError: if *pool* exposes neither supported interface.
        """
        budget = self.budget
        self._check_pool_shape(pool)
        async with budget.semaphore:
            async with self._checkout(pool) as conn:
                budget.active_count += 1
                budget.total_acquired += 1
                budget.peak_count = max(budget.peak_count, budget.active_count)
                try:
                    yield conn
                finally:
                    budget.active_count -= 1

163 

164 

165# ─── module-level default ───────────────────────────────────────────── 

166 

# Module-level manager handed out by get_default_budget_manager(), which
# creates it on first use; set_default_budget_manager() replaces it.
_default_manager: ConnectionBudgetManager | None = None


def get_default_budget_manager() -> ConnectionBudgetManager:
    """Return the process-wide manager, creating it on the first call.

    The lazily created instance is built with ``ConnectionBudgetManager``'s
    own default of 4 connections per operation.
    """
    global _default_manager
    manager = _default_manager
    if manager is None:
        manager = _default_manager = ConnectionBudgetManager()
    return manager

176 

177 

def set_default_budget_manager(manager: ConnectionBudgetManager) -> None:
    """Install *manager* as the process-wide default.

    Subsequent calls to ``get_default_budget_manager()`` return this
    instance instead of creating one.
    """
    global _default_manager
    _default_manager = manager