Coverage for astrocyte/pipeline/pgqueuer_tasks.py: 71%

73 statements  

« prev     ^ index     » next       coverage.py v7.15.0, created at 2026-07-04 05:24 +0000

1"""PgQueuer-backed worker integration for memory tasks. 

2 

3PgQueuer owns the PostgreSQL queue mechanics (``FOR UPDATE SKIP LOCKED``, 

4``LISTEN/NOTIFY``, retries, scheduling). Astrocyte keeps the memory semantics in 

5``MemoryTaskDispatcher`` so task handlers remain backend-agnostic. 

6""" 

7 

8from __future__ import annotations 

9 

10import json 

11from datetime import UTC, datetime, timedelta 

12from typing import TYPE_CHECKING, Any 

13 

14from astrocyte.pipeline.tasks import ( 

15 COMPILE_BANK, 

16 COMPILE_PERSONA_PAGE, 

17 INDEX_WIKI_PAGE_VECTOR, 

18 LINT_WIKI_PAGE, 

19 NORMALIZE_TEMPORAL_FACTS, 

20 PROJECT_ENTITY_EDGES, 

21 MemoryTask, 

22 MemoryTaskDispatcher, 

23) 

24 

25if TYPE_CHECKING: 

26 from pgqueuer import Job, PgQueuer 

27 from psycopg import AsyncConnection 

28 

# Task types that each get their own PgQueuer entrypoint. The queue adapter
# iterates this tuple both when registering handlers and when clearing queued
# or logged work, so the order here is the registration order.
TASK_ENTRYPOINTS: tuple[str, ...] = (
    COMPILE_BANK,
    COMPILE_PERSONA_PAGE,
    INDEX_WIKI_PAGE_VECTOR,
    PROJECT_ENTITY_EDGES,
    NORMALIZE_TEMPORAL_FACTS,
    LINT_WIKI_PAGE,
)

37 

38 

class PgQueuerMemoryTaskQueue:
    """Bridge between Astrocyte ``MemoryTask`` objects and a PgQueuer instance.

    The wrapped PgQueuer object is kept on ``self.pgq``; every method here is a
    thin translation layer that forwards to it (or to its query helper).
    """

    def __init__(self, pgq: PgQueuer) -> None:
        """Wrap an already-constructed PgQueuer instance."""
        self.pgq = pgq

    @classmethod
    def in_memory(cls) -> PgQueuerMemoryTaskQueue:
        """Build an adapter around ``PgQueuer.in_memory()`` (intended for tests)."""

        # Imported lazily so importing this module does not require pgqueuer.
        from pgqueuer import PgQueuer

        backend = PgQueuer.in_memory()
        return cls(backend)

    @classmethod
    def from_psycopg_connection(cls, connection: AsyncConnection[Any]) -> PgQueuerMemoryTaskQueue:
        """Build an adapter on top of a psycopg async Postgres connection."""

        # Imported lazily so importing this module does not require pgqueuer.
        from pgqueuer import PgQueuer

        backend = PgQueuer.from_psycopg_connection(connection)
        return cls(backend)

    async def install(self) -> None:
        """Create the PgQueuer database objects, tolerating a prior install.

        Any exception whose message contains ``"already exists"`` is swallowed;
        every other exception propagates unchanged.
        """

        try:
            await self._queries().install()
        except Exception as exc:
            if "already exists" in str(exc):
                # Objects were installed earlier; nothing left to do.
                return
            raise

    async def clear(self) -> None:
        """Remove queued jobs and log rows for every Astrocyte entrypoint."""

        query_api = self._queries()
        await query_api.clear_queue([*TASK_ENTRYPOINTS])
        await query_api.clear_queue_log([*TASK_ENTRYPOINTS])

    async def enqueue(self, task: MemoryTask, *, priority: int = 0) -> str:
        """Submit ``task`` to PgQueuer and return the new job id as a string.

        The task type is used as the entrypoint name, ``task.run_after`` is
        converted into a relative delay, and ``task.idempotency_key`` is
        forwarded as the dedupe key.
        """

        created_ids = await self._queries().enqueue(
            task.task_type,
            _encode_task(task),
            priority=priority,
            execute_after=_delay_from_now(task.run_after),
            dedupe_key=task.idempotency_key,
        )
        first_id = created_ids[0]
        return str(first_id)

    def register_dispatcher(self, dispatcher: MemoryTaskDispatcher) -> None:
        """Attach ``dispatcher`` to one PgQueuer entrypoint per task type."""

        for entrypoint_name in TASK_ENTRYPOINTS:
            self._register_task_type(entrypoint_name, dispatcher)

    async def run_drain(
        self,
        *,
        batch_size: int = 10,
        max_concurrent_tasks: int | None = None,
    ) -> None:
        """Work through the jobs currently queued, then return.

        Args:
            batch_size: Number of jobs PgQueuer dequeues per round.
            max_concurrent_tasks: Upper bound on handlers running at once.
                ``None`` leaves PgQueuer on its unbounded default
                (``sys.maxsize``), which can exhaust downstream connection
                pools when handlers do per-task DB I/O. A value of roughly
                2x ``batch_size`` (PgQueuer's enforced minimum) plus a small
                margin is recommended, e.g. ``20`` for ``batch_size=10``.
        """

        from pgqueuer.types import QueueExecutionMode

        run_options = {
            "batch_size": batch_size,
            "mode": QueueExecutionMode.drain,
            "max_concurrent_tasks": max_concurrent_tasks,
        }
        await self.pgq.run(**run_options)

    async def run_continuous(
        self,
        *,
        batch_size: int = 10,
        max_concurrent_tasks: int | None = None,
    ) -> None:
        """Run the PgQueuer worker loop until a shutdown is requested.

        ``max_concurrent_tasks`` caps in-flight handler concurrency. With no
        cap, PgQueuer dispatches as many tasks as the queue holds, which can
        saturate downstream connection pools (pgvector / AGE) when handlers do
        per-task DB I/O. Sizing guidance lives on :meth:`run_drain`.
        """

        run_options = {
            "batch_size": batch_size,
            "max_concurrent_tasks": max_concurrent_tasks,
        }
        await self.pgq.run(**run_options)

    async def shutdown(self) -> None:
        """Ask PgQueuer to stop its listeners/workers.

        ``self.pgq.shutdown`` is handled in either of two shapes: a callable
        (invoked, and awaited if it returns something) or an object exposing
        ``set()`` such as an event.
        """

        stopper = self.pgq.shutdown
        if not callable(stopper):
            stopper.set()
            return

        pending = stopper()
        if pending is not None:
            # Drive the shutdown awaitable to completion; its value is unused.
            await pending

    def _register_task_type(
        self,
        task_type: str,
        dispatcher: MemoryTaskDispatcher,
    ) -> None:
        """Register a PgQueuer entrypoint that decodes jobs and dispatches them."""

        @self.pgq.entrypoint(task_type)
        async def _handle(job: Job) -> None:
            # The entrypoint name doubles as the task type when the stored
            # payload does not carry one.
            decoded = task_from_pgqueuer_payload(job.payload, fallback_task_type=task_type)
            await dispatcher.run(decoded)

    def _queries(self) -> Any:
        """Return the query helper: ``pgq.queries`` if truthy, else ``pgq.qm.queries``."""

        direct = self.pgq.queries
        if direct:
            return direct
        return self.pgq.qm.queries

163 

164 

def _encode_task(task: MemoryTask) -> bytes:
    """Serialize ``task`` to the UTF-8 JSON bytes stored as the PgQueuer payload."""

    document = task_to_pgqueuer_payload(task)
    # Keys are sorted so the same task always encodes to the same bytes.
    serialized = json.dumps(document, sort_keys=True)
    return serialized.encode("utf-8")

167 

168 

def task_to_pgqueuer_payload(task: MemoryTask) -> dict[str, Any]:
    """Build the JSON-ready dict that represents ``task`` inside a PgQueuer job.

    Most fields are copied from the task unchanged; ``run_after`` and
    ``created_at`` are rendered as ISO-8601 strings.
    """

    copied_fields = (
        "id",
        "task_type",
        "bank_id",
        "payload",
        "idempotency_key",
        "attempts",
        "max_attempts",
    )
    document: dict[str, Any] = {name: getattr(task, name) for name in copied_fields}
    document["run_after"] = task.run_after.isoformat()
    document["created_at"] = task.created_at.isoformat()
    return document

183 

184 

def task_from_pgqueuer_payload(payload: bytes | None, *, fallback_task_type: str) -> MemoryTask:
    """Rebuild a ``MemoryTask`` from the JSON bytes stored on a PgQueuer job.

    Args:
        payload: UTF-8 encoded JSON object as stored on the job.
        fallback_task_type: Task type to use when the payload has none.

    Returns:
        The decoded task. Missing or falsy optional fields fall back to
        defaults: empty id, empty payload dict, 0 attempts, 5 max attempts,
        and the current UTC time for both timestamps.

    Raises:
        ValueError: If ``payload`` is ``None`` or empty.
        KeyError: If the decoded object has no ``bank_id``.
    """

    if not payload:
        raise ValueError("PgQueuer job payload is required for Astrocyte memory tasks")

    document = json.loads(payload.decode("utf-8"))

    def _timestamp(key: str) -> datetime:
        # A missing or empty timestamp means "now".
        return _parse_dt(document.get(key)) or datetime.now(UTC)

    # The dict is built in the same field order the constructor call used, so
    # any conversion error surfaces at the same point.
    fields: dict[str, Any] = {
        "id": str(document.get("id") or ""),
        "task_type": str(document.get("task_type") or fallback_task_type),
        "bank_id": str(document["bank_id"]),
        "payload": dict(document.get("payload") or {}),
        "idempotency_key": document.get("idempotency_key"),
        "attempts": int(document.get("attempts") or 0),
        "max_attempts": int(document.get("max_attempts") or 5),
        "run_after": _timestamp("run_after"),
        "created_at": _timestamp("created_at"),
    }
    return MemoryTask(**fields)

202 

203 

204def _parse_dt(value: Any) -> datetime | None: 

205 if not value: 

206 return None 

207 parsed = datetime.fromisoformat(str(value)) 

208 if parsed.tzinfo is None: 

209 return parsed.replace(tzinfo=UTC) 

210 return parsed 

211 

212 

213def _delay_from_now(run_after: datetime) -> timedelta | None: 

214 delay = run_after - datetime.now(UTC) 

215 if delay.total_seconds() <= 0: 

216 return None 

217 return delay