Coverage for astrocyte/pipeline/pgqueuer_tasks.py: 71%
73 statements
« prev ^ index » next coverage.py v7.15.0, created at 2026-07-04 05:24 +0000
« prev ^ index » next coverage.py v7.15.0, created at 2026-07-04 05:24 +0000
1"""PgQueuer-backed worker integration for memory tasks.
3PgQueuer owns the PostgreSQL queue mechanics (``FOR UPDATE SKIP LOCKED``,
4``LISTEN/NOTIFY``, retries, scheduling). Astrocyte keeps the memory semantics in
5``MemoryTaskDispatcher`` so task handlers remain backend-agnostic.
6"""
8from __future__ import annotations
10import json
11from datetime import UTC, datetime, timedelta
12from typing import TYPE_CHECKING, Any
14from astrocyte.pipeline.tasks import (
15 COMPILE_BANK,
16 COMPILE_PERSONA_PAGE,
17 INDEX_WIKI_PAGE_VECTOR,
18 LINT_WIKI_PAGE,
19 NORMALIZE_TEMPORAL_FACTS,
20 PROJECT_ENTITY_EDGES,
21 MemoryTask,
22 MemoryTaskDispatcher,
23)
25if TYPE_CHECKING:
26 from pgqueuer import Job, PgQueuer
27 from psycopg import AsyncConnection
# Task types handled through PgQueuer. ``clear`` and ``register_dispatcher``
# on ``PgQueuerMemoryTaskQueue`` iterate this tuple, so the order here is the
# order in which entrypoints are registered.
TASK_ENTRYPOINTS = (
    COMPILE_BANK,
    COMPILE_PERSONA_PAGE,
    INDEX_WIKI_PAGE_VECTOR,
    PROJECT_ENTITY_EDGES,
    NORMALIZE_TEMPORAL_FACTS,
    LINT_WIKI_PAGE,
)
class PgQueuerMemoryTaskQueue:
    """Adapter that enqueues and runs Astrocyte ``MemoryTask`` jobs via PgQueuer.

    The wrapped ``PgQueuer`` object is stored on ``self.pgq``. Queue-level
    operations (install, clear, enqueue) go through the query object resolved
    by :meth:`_queries`; worker execution goes through ``self.pgq.run``.
    """

    def __init__(self, pgq: PgQueuer) -> None:
        self.pgq = pgq

    @classmethod
    def in_memory(cls) -> PgQueuerMemoryTaskQueue:
        """Create a PgQueuer in-memory queue for tests."""

        # Local import: at module level pgqueuer is only imported under
        # TYPE_CHECKING.
        from pgqueuer import PgQueuer

        return cls(PgQueuer.in_memory())

    @classmethod
    def from_psycopg_connection(cls, connection: AsyncConnection[Any]) -> PgQueuerMemoryTaskQueue:
        """Create a queue backed by a psycopg async Postgres connection."""

        from pgqueuer import PgQueuer

        return cls(PgQueuer.from_psycopg_connection(connection))

    async def install(self) -> None:
        """Install PgQueuer database objects when using a Postgres connection.

        An exception whose message contains ``"already exists"`` is swallowed
        so that installing twice is not an error; every other exception
        propagates unchanged.
        """

        try:
            await self._queries().install()
        except Exception as exc:
            # Deliberate best-effort: only the "already installed" case is ignored.
            if "already exists" not in str(exc):
                raise

    async def clear(self) -> None:
        """Clear queued/logged work for the Astrocyte task entrypoints."""

        queries = self._queries()
        await queries.clear_queue(list(TASK_ENTRYPOINTS))
        await queries.clear_queue_log(list(TASK_ENTRYPOINTS))

    async def enqueue(self, task: MemoryTask, *, priority: int = 0) -> str:
        """Enqueue a memory task and return the PgQueuer job id.

        Args:
            task: The task to enqueue. Its ``task_type`` is used as the
                entrypoint name, ``run_after`` becomes the execution delay and
                ``idempotency_key`` is passed as the dedupe key.
            priority: Priority forwarded to PgQueuer.

        Returns:
            The first job id returned by PgQueuer, converted to ``str``.
        """

        job_ids = await self._queries().enqueue(
            task.task_type,
            _encode_task(task),
            priority=priority,
            execute_after=_delay_from_now(task.run_after),
            dedupe_key=task.idempotency_key,
        )
        return str(job_ids[0])

    def register_dispatcher(self, dispatcher: MemoryTaskDispatcher) -> None:
        """Register one PgQueuer entrypoint for each Astrocyte task type."""

        for task_type in TASK_ENTRYPOINTS:
            self._register_task_type(task_type, dispatcher)

    async def run_drain(
        self,
        *,
        batch_size: int = 10,
        max_concurrent_tasks: int | None = None,
    ) -> None:
        """Process currently queued jobs and return when the queue is drained.

        Args:
            batch_size: How many jobs PgQueuer dequeues per round.
            max_concurrent_tasks: Hard ceiling on simultaneously-running task
                handlers. ``None`` falls back to PgQueuer's unbounded default
                (``sys.maxsize``), which can exhaust downstream connection
                pools when handlers do per-task DB I/O. Recommend setting
                to ~2x ``batch_size`` (PgQueuer's enforced minimum) plus a
                small margin — e.g. ``20`` for ``batch_size=10``.
        """

        from pgqueuer.types import QueueExecutionMode

        await self.pgq.run(
            batch_size=batch_size,
            mode=QueueExecutionMode.drain,
            max_concurrent_tasks=max_concurrent_tasks,
        )

    async def run_continuous(
        self,
        *,
        batch_size: int = 10,
        max_concurrent_tasks: int | None = None,
    ) -> None:
        """Run the PgQueuer worker until shutdown is requested.

        ``max_concurrent_tasks`` bounds in-flight handler concurrency. Without
        a ceiling, PgQueuer dispatches as many tasks as the queue contains,
        which can saturate downstream connection pools (pgvector / AGE) when
        handlers do per-task DB I/O. See :meth:`run_drain` for sizing notes.
        """

        await self.pgq.run(
            batch_size=batch_size,
            max_concurrent_tasks=max_concurrent_tasks,
        )

    async def shutdown(self) -> None:
        """Stop PgQueuer listeners/workers.

        ``self.pgq.shutdown`` is accepted in two shapes: a callable (sync or
        async), which is invoked, or a non-callable event-like object, whose
        ``set()`` method is called.
        """

        import inspect

        stopper = self.pgq.shutdown
        if not callable(stopper):
            stopper.set()
            return

        outcome = stopper()
        # Await only genuine awaitables. A synchronous shutdown() may return a
        # plain value (e.g. ``True``); awaiting that would raise TypeError.
        if inspect.isawaitable(outcome):
            await outcome

    def _register_task_type(
        self,
        task_type: str,
        dispatcher: MemoryTaskDispatcher,
    ) -> None:
        """Register a PgQueuer entrypoint that decodes jobs and runs them.

        The handler closes over ``task_type`` and ``dispatcher``; having one
        call of this method per task type gives each handler its own binding.
        """

        @self.pgq.entrypoint(task_type)
        async def _handle(job: Job) -> None:
            task = task_from_pgqueuer_payload(job.payload, fallback_task_type=task_type)
            await dispatcher.run(task)

    def _queries(self) -> Any:
        """Return ``pgq.queries`` when truthy, otherwise ``pgq.qm.queries``."""

        return self.pgq.queries or self.pgq.qm.queries
def _encode_task(task: MemoryTask) -> bytes:
    """Serialize ``task`` to UTF-8 JSON bytes with keys in sorted order."""
    document = task_to_pgqueuer_payload(task)
    text = json.dumps(document, sort_keys=True)
    return text.encode("utf-8")
def task_to_pgqueuer_payload(task: MemoryTask) -> dict[str, Any]:
    """Return the stable JSON payload shape stored in PgQueuer jobs.

    Seven attributes are copied through unchanged; ``run_after`` and
    ``created_at`` are emitted as ISO-8601 strings via ``isoformat()``.
    """

    passthrough = (
        "id",
        "task_type",
        "bank_id",
        "payload",
        "idempotency_key",
        "attempts",
        "max_attempts",
    )
    encoded: dict[str, Any] = {name: getattr(task, name) for name in passthrough}
    encoded["run_after"] = task.run_after.isoformat()
    encoded["created_at"] = task.created_at.isoformat()
    return encoded
def task_from_pgqueuer_payload(payload: bytes | None, *, fallback_task_type: str) -> MemoryTask:
    """Decode the stable PgQueuer JSON payload shape into a MemoryTask.

    Args:
        payload: UTF-8 encoded JSON bytes taken from the PgQueuer job.
        fallback_task_type: Task type to use when the payload carries no
            (or a falsy) ``task_type``.

    Raises:
        ValueError: If ``payload`` is ``None`` or empty.
        KeyError: If the decoded payload has no ``bank_id``.
    """

    if not payload:
        raise ValueError("PgQueuer job payload is required for Astrocyte memory tasks")

    fields = json.loads(payload.decode("utf-8"))

    # Falsy or missing values fall back to the defaults on the right of each ``or``.
    task_id = str(fields.get("id") or "")
    task_type = str(fields.get("task_type") or fallback_task_type)
    bank_id = str(fields["bank_id"])
    body = dict(fields.get("payload") or {})
    dedupe_key = fields.get("idempotency_key")
    attempts = int(fields.get("attempts") or 0)
    max_attempts = int(fields.get("max_attempts") or 5)
    run_after = _parse_dt(fields.get("run_after")) or datetime.now(UTC)
    created_at = _parse_dt(fields.get("created_at")) or datetime.now(UTC)

    return MemoryTask(
        id=task_id,
        task_type=task_type,
        bank_id=bank_id,
        payload=body,
        idempotency_key=dedupe_key,
        attempts=attempts,
        max_attempts=max_attempts,
        run_after=run_after,
        created_at=created_at,
    )
204def _parse_dt(value: Any) -> datetime | None:
205 if not value:
206 return None
207 parsed = datetime.fromisoformat(str(value))
208 if parsed.tzinfo is None:
209 return parsed.replace(tzinfo=UTC)
210 return parsed
213def _delay_from_now(run_after: datetime) -> timedelta | None:
214 delay = run_after - datetime.now(UTC)
215 if delay.total_seconds() <= 0:
216 return None
217 return delay