Coverage for astrocyte/pipeline/tasks.py: 83%
315 statements
« prev ^ index » next coverage.py v7.15.0, created at 2026-07-04 05:24 +0000
1"""Async memory task contract and handlers.
3This module is the framework-facing half of the Postgres task design. The
4production backend can persist these ``MemoryTask`` rows in ``astrocyte_tasks``;
5tests and single-process deployments can use ``InMemoryTaskBackend``.
6"""
8from __future__ import annotations
10import json
11import re
12import uuid
13from dataclasses import dataclass, field
14from datetime import UTC, datetime, timedelta
15from typing import TYPE_CHECKING, Any, Literal, Protocol
17from astrocyte.pipeline.temporal import temporal_metadata
18from astrocyte.types import (
19 Entity,
20 EntityLink,
21 MemoryEntityAssociation,
22 Message,
23 VectorItem,
24 WikiPage,
25)
27if TYPE_CHECKING:
28 from astrocyte.pipeline.compile import CompileEngine
29 from astrocyte.pipeline.lint import LintEngine
30 from astrocyte.provider import GraphStore, LLMProvider, VectorStore, WikiStore
# Lifecycle states a task row may carry in its ``status`` field.
TaskStatus = Literal[
    "queued",
    "running",
    "succeeded",
    "failed",
    "dead",
]

# Task-type identifiers; ``MemoryTaskDispatcher.run`` routes on these values.
COMPILE_BANK = "compile_bank"
COMPILE_PERSONA_PAGE = "compile_persona_page"
INDEX_WIKI_PAGE_VECTOR = "index_wiki_page_vector"
PROJECT_ENTITY_EDGES = "project_entity_edges"
NORMALIZE_TEMPORAL_FACTS = "normalize_temporal_facts"
LINT_WIKI_PAGE = "lint_wiki_page"
42@dataclass
43class MemoryTask:
44 """A durable background task for memory quality work."""
46 task_type: str
47 bank_id: str
48 payload: dict[str, Any] = field(default_factory=dict)
49 id: str = field(default_factory=lambda: uuid.uuid4().hex)
50 status: TaskStatus = "queued"
51 idempotency_key: str | None = None
52 attempts: int = 0
53 max_attempts: int = 5
54 run_after: datetime = field(default_factory=lambda: datetime.now(UTC))
55 result: dict[str, Any] | None = None
56 error: str | None = None
57 claimed_by: str | None = None
58 claimed_at: datetime | None = None
59 created_at: datetime = field(default_factory=lambda: datetime.now(UTC))
60 updated_at: datetime = field(default_factory=lambda: datetime.now(UTC))
class TaskBackend(Protocol):
    """Storage contract for task queues, implemented by Postgres in production."""

    async def enqueue(self, task: MemoryTask) -> str:
        """Persist *task* and return the id under which it is stored."""
        ...

    async def claim(self, worker_id: str, limit: int = 10) -> list[MemoryTask]:
        """Hand up to *limit* runnable tasks to *worker_id*."""
        ...

    async def complete(self, task_id: str, result: dict[str, Any]) -> None:
        """Record a successful run of *task_id* together with its *result*."""
        ...

    async def fail(self, task_id: str, error: str, retry_at: datetime | None) -> None:
        """Record a failed run of *task_id*; *retry_at* is when it may run again."""
        ...
79class InMemoryTaskBackend:
80 """Deterministic task backend for unit tests and local development."""
82 def __init__(self) -> None:
83 self._tasks: dict[str, MemoryTask] = {}
85 async def enqueue(self, task: MemoryTask) -> str:
86 if task.idempotency_key is not None:
87 for existing in self._tasks.values():
88 if (
89 existing.task_type == task.task_type
90 and existing.idempotency_key == task.idempotency_key
91 and existing.status in {"queued", "running"}
92 ):
93 return existing.id
94 self._tasks[task.id] = task
95 return task.id
97 async def claim(self, worker_id: str, limit: int = 10) -> list[MemoryTask]:
98 now = datetime.now(UTC)
99 queued = sorted(
100 (task for task in self._tasks.values() if task.status == "queued" and task.run_after <= now),
101 key=lambda task: (task.run_after, task.created_at),
102 )
103 claimed: list[MemoryTask] = []
104 for task in queued[:limit]:
105 task.status = "running"
106 task.claimed_by = worker_id
107 task.claimed_at = now
108 task.updated_at = now
109 task.attempts += 1
110 claimed.append(task)
111 return claimed
113 async def complete(self, task_id: str, result: dict[str, Any]) -> None:
114 task = self._tasks[task_id]
115 task.status = "succeeded"
116 task.result = result
117 task.error = None
118 task.updated_at = datetime.now(UTC)
120 async def fail(self, task_id: str, error: str, retry_at: datetime | None) -> None:
121 task = self._tasks[task_id]
122 task.error = error
123 task.updated_at = datetime.now(UTC)
124 if retry_at is None or task.attempts >= task.max_attempts:
125 task.status = "dead"
126 return
127 task.status = "queued"
128 task.run_after = retry_at
130 def get(self, task_id: str) -> MemoryTask | None:
131 return self._tasks.get(task_id)
133 def list_by_status(self, status: TaskStatus | None = None) -> list[MemoryTask]:
134 tasks = list(self._tasks.values())
135 if status is not None:
136 tasks = [task for task in tasks if task.status == status]
137 return sorted(tasks, key=lambda task: (task.created_at, task.id))
139 async def recover_stale(
140 self,
141 *,
142 stale_after: timedelta,
143 now: datetime | None = None,
144 ) -> int:
145 """Requeue running tasks whose worker claim is stale."""
146 current = now or datetime.now(UTC)
147 recovered = 0
148 for task in self._tasks.values():
149 if task.status != "running" or task.claimed_at is None:
150 continue
151 if current - task.claimed_at < stale_after:
152 continue
153 task.status = "queued" if task.attempts < task.max_attempts else "dead"
154 task.claimed_by = None
155 task.claimed_at = None
156 task.run_after = current
157 task.updated_at = current
158 recovered += 1
159 return recovered
@dataclass
class TaskHandlerContext:
    """Dependencies available to task handlers.

    The vector store and LLM provider are always required. The remaining
    collaborators are optional; handlers that need one raise ``ValueError``
    when it has not been supplied.
    """

    # Always required.
    vector_store: VectorStore
    llm_provider: LLMProvider

    # Optional collaborators, each needed only by specific task types.
    wiki_store: WikiStore | None = None
    graph_store: GraphStore | None = None
    compile_engine: CompileEngine | None = None
    lint_engine: LintEngine | None = None
174class MemoryTaskDispatcher:
175 """Dispatches task rows to memory-quality handlers."""
177 def __init__(self, context: TaskHandlerContext) -> None:
178 self._ctx = context
180 async def run(self, task: MemoryTask) -> dict[str, Any]:
181 if task.task_type == COMPILE_BANK:
182 return await self._compile_bank(task)
183 if task.task_type == COMPILE_PERSONA_PAGE:
184 return await self._compile_persona_page(task)
185 if task.task_type == INDEX_WIKI_PAGE_VECTOR:
186 return await self._index_wiki_page_vector(task)
187 if task.task_type == PROJECT_ENTITY_EDGES:
188 return await self._project_entity_edges(task)
189 if task.task_type == NORMALIZE_TEMPORAL_FACTS:
190 return await self._normalize_temporal_facts(task)
191 if task.task_type == LINT_WIKI_PAGE:
192 return await self._lint_wiki_page(task)
193 raise ValueError(f"Unknown memory task type: {task.task_type}")
195 async def _compile_bank(self, task: MemoryTask) -> dict[str, Any]:
196 if self._ctx.compile_engine is None:
197 raise ValueError("compile_bank requires CompileEngine")
198 result = await self._ctx.compile_engine.run(
199 task.bank_id,
200 scope=_optional_str(task.payload.get("scope")),
201 )
202 return {
203 "bank_id": result.bank_id,
204 "scopes_compiled": result.scopes_compiled,
205 "pages_created": result.pages_created,
206 "pages_updated": result.pages_updated,
207 "noise_memories": result.noise_memories,
208 "tokens_used": result.tokens_used,
209 "elapsed_ms": result.elapsed_ms,
210 "error": result.error,
211 }
213 async def _compile_persona_page(self, task: MemoryTask) -> dict[str, Any]:
214 if self._ctx.wiki_store is None:
215 raise ValueError("compile_persona_page requires WikiStore")
216 person = _optional_str(task.payload.get("person"))
217 # Optional scope qualifier (e.g. ``"convo:convo-3"``) — when set,
218 # the persona page is keyed per-scope so distinct contexts produce
219 # distinct pages even when the person's name is the same. LoCoMo
220 # uses this to keep "Caroline in conversation_0" separate from
221 # "Caroline in conversation_3".
222 scope = _optional_str(task.payload.get("scope"))
223 source_ids = [str(value) for value in task.payload.get("source_ids", []) if value]
224 items = await _list_bank_vectors(self._ctx.vector_store, task.bank_id)
225 relevant = [
226 item for item in items if _item_matches_person(item, person) and (not source_ids or item.id in source_ids)
227 ]
228 # When scope is set, also filter the relevant items to those that
229 # belong to the same scope — otherwise the page would merge memories
230 # from every context that mentioned this person, defeating the
231 # whole point of scoping.
232 if scope is not None:
233 relevant = [item for item in relevant if _item_in_scope(item, scope)]
234 if person is None:
235 names = sorted({name for item in items for name in _item_person_names(item)})
236 if not names:
237 return {"pages_created": 0, "pages_updated": 0, "reason": "no_person_metadata"}
238 created = 0
239 updated = 0
240 page_ids: list[str] = []
241 for name in names:
242 subtask = MemoryTask(
243 task_type=COMPILE_PERSONA_PAGE,
244 bank_id=task.bank_id,
245 payload={"person": name, "scope": scope} if scope else {"person": name},
246 )
247 result = await self._compile_persona_page(subtask)
248 created += int(result.get("pages_created", 0))
249 updated += int(result.get("pages_updated", 0))
250 page_ids.extend(result.get("page_ids", []))
251 return {"pages_created": created, "pages_updated": updated, "page_ids": page_ids}
253 page_id = f"person:{_slug(scope)}:{_slug(person)}" if scope else f"person:{_slug(person)}"
254 existing = await self._ctx.wiki_store.get_page(page_id, task.bank_id)
255 page = await self._build_persona_page(task.bank_id, person, relevant, existing, page_id=page_id, scope=scope)
256 await self._ctx.wiki_store.upsert_page(page, task.bank_id)
257 result = {
258 "pages_created": 1 if existing is None else 0,
259 "pages_updated": 0 if existing is None else 1,
260 "page_ids": [page_id],
261 "source_count": len(relevant),
262 }
263 if task.payload.get("index_vector"):
264 index_result = await self._index_wiki_page_vector(
265 MemoryTask(
266 task_type=INDEX_WIKI_PAGE_VECTOR,
267 bank_id=task.bank_id,
268 payload={"page_id": page_id},
269 )
270 )
271 result.update(index_result)
272 return result
274 async def _index_wiki_page_vector(self, task: MemoryTask) -> dict[str, Any]:
275 if self._ctx.wiki_store is None:
276 raise ValueError("index_wiki_page_vector requires WikiStore")
277 page_id = _required_str(task.payload, "page_id")
278 page = await self._ctx.wiki_store.get_page(page_id, task.bank_id)
279 if page is None:
280 raise ValueError(f"Wiki page not found: {page_id}")
281 vectors = await self._ctx.llm_provider.embed([f"{page.title}\n\n{page.content[:1000]}"])
282 await self._ctx.vector_store.store_vectors(
283 [
284 VectorItem(
285 id=page.page_id,
286 bank_id=task.bank_id,
287 vector=vectors[0],
288 text=f"[WIKI:{page.kind}] {page.title}\n\n{page.content[:500]}",
289 metadata={"_wiki_source_ids": json.dumps(page.source_ids)},
290 tags=page.tags,
291 fact_type="wiki",
292 memory_layer="compiled",
293 retained_at=datetime.now(UTC),
294 )
295 ]
296 )
297 return {"indexed_page_id": page_id, "source_count": len(page.source_ids)}
299 async def _project_entity_edges(self, task: MemoryTask) -> dict[str, Any]:
300 if self._ctx.graph_store is None:
301 raise ValueError("project_entity_edges requires GraphStore")
302 items = await _list_bank_vectors(self._ctx.vector_store, task.bank_id)
303 entity_by_name: dict[str, Entity] = {}
304 memory_names: dict[str, set[str]] = {}
305 for item in items:
306 names = _item_person_names(item)
307 if not names:
308 continue
309 memory_names[item.id] = names
310 for name in names:
311 entity_by_name.setdefault(
312 name,
313 Entity(
314 id=f"person:{_slug(name)}",
315 name=name,
316 entity_type="PERSON",
317 aliases=[name],
318 metadata={"source": "task:project_entity_edges"},
319 ),
320 )
322 entities = list(entity_by_name.values())
323 entity_ids = await self._ctx.graph_store.store_entities(entities, task.bank_id)
324 id_by_name = dict(zip(entity_by_name.keys(), entity_ids, strict=False))
326 associations = [
327 MemoryEntityAssociation(memory_id=memory_id, entity_id=id_by_name[name])
328 for memory_id, names in memory_names.items()
329 for name in names
330 if name in id_by_name
331 ]
332 if associations:
333 await self._ctx.graph_store.link_memories_to_entities(associations, task.bank_id)
335 links: list[EntityLink] = []
336 for item in items:
337 names = sorted(memory_names.get(item.id, set()))
338 for i in range(len(names)):
339 for j in range(i + 1, len(names)):
340 links.append(
341 EntityLink(
342 entity_a=id_by_name[names[i]],
343 entity_b=id_by_name[names[j]],
344 link_type="co_occurs",
345 evidence=item.text[:500],
346 confidence=1.0,
347 created_at=datetime.now(UTC),
348 metadata={
349 "memory_id": item.id,
350 "session_id": str((item.metadata or {}).get("session_id") or ""),
351 "turn_ids": str((item.metadata or {}).get("locomo_turn_ids") or ""),
352 },
353 )
354 )
355 if links:
356 await self._ctx.graph_store.store_links(links, task.bank_id)
357 return {
358 "entities_projected": len(entities),
359 "associations_projected": len(associations),
360 "links_projected": len(links),
361 }
363 async def _normalize_temporal_facts(self, task: MemoryTask) -> dict[str, Any]:
364 items = await _list_bank_vectors(self._ctx.vector_store, task.bank_id)
365 updated: list[VectorItem] = []
366 for item in items:
367 metadata = dict(item.metadata or {})
368 normalized = temporal_metadata(item.text, item.occurred_at)
369 if not normalized:
370 continue
371 metadata.update(normalized)
372 updated.append(
373 VectorItem(
374 id=item.id,
375 bank_id=item.bank_id,
376 vector=item.vector,
377 text=item.text,
378 metadata=metadata,
379 tags=item.tags,
380 fact_type=item.fact_type,
381 occurred_at=item.occurred_at,
382 memory_layer=item.memory_layer,
383 retained_at=item.retained_at,
384 )
385 )
386 if updated:
387 await self._ctx.vector_store.store_vectors(updated)
388 return {"memories_scanned": len(items), "memories_updated": len(updated)}
390 async def _lint_wiki_page(self, task: MemoryTask) -> dict[str, Any]:
391 if self._ctx.lint_engine is None:
392 raise ValueError("lint_wiki_page requires LintEngine")
393 result = await self._ctx.lint_engine.run(task.bank_id)
394 page_id = _optional_str(task.payload.get("page_id"))
395 issues = result.issues
396 if page_id is not None:
397 issues = [issue for issue in issues if issue.page_id == page_id]
398 return {
399 "bank_id": result.bank_id,
400 "pages_checked": result.pages_checked,
401 "stale_count": sum(1 for issue in issues if issue.kind == "stale"),
402 "orphan_count": sum(1 for issue in issues if issue.kind == "orphan"),
403 "contradiction_count": sum(1 for issue in issues if issue.kind == "contradiction"),
404 "issues": [
405 {
406 "kind": issue.kind,
407 "page_id": issue.page_id,
408 "action": issue.action,
409 "detail": issue.detail,
410 "peer_page_id": issue.peer_page_id,
411 }
412 for issue in issues
413 ],
414 "elapsed_ms": result.elapsed_ms,
415 "error": result.error,
416 }
418 async def _build_persona_page(
419 self,
420 bank_id: str,
421 person: str,
422 items: list[VectorItem],
423 existing: WikiPage | None,
424 *,
425 page_id: str | None = None,
426 scope: str | None = None,
427 ) -> WikiPage:
428 source_ids = [item.id for item in items]
429 evidence = "\n".join(f"- {item.text[:500]}" for item in items[:50])
430 prompt = (
431 f"Compile a concise persona/preference wiki page for {person}.\n"
432 "Preserve stable facts: goals, likes, dislikes, relationships, plans, repeated activities, constraints.\n"
433 "Use only the evidence.\n\n"
434 f"Evidence:\n{evidence}"
435 )
436 completion = await self._ctx.llm_provider.complete(
437 [
438 Message(role="system", content="You maintain source-grounded person memory pages."),
439 Message(role="user", content=prompt),
440 ],
441 max_tokens=1200,
442 temperature=0.0,
443 )
444 content = completion.text.strip() or f"## {person}\n\nNo stable persona facts compiled."
445 if page_id is None:
446 page_id = f"person:{_slug(person)}"
447 # Tag the page with its scope (e.g. ``"convo:convo-3"``) when set,
448 # so scoped recall queries — which filter by the same tag — still
449 # surface the persona page. Without this tag a question scoped to
450 # ``convo:convo-3`` cannot retrieve its own conversation's persona.
451 tags = ["persona", f"person:{_slug(person)}"]
452 if scope:
453 tags.append(scope)
454 return WikiPage(
455 page_id=page_id,
456 bank_id=bank_id,
457 kind="entity",
458 title=person,
459 content=content,
460 scope=page_id,
461 source_ids=source_ids,
462 cross_links=[],
463 revision=(existing.revision + 1) if existing else 1,
464 revised_at=datetime.now(UTC),
465 tags=tags,
466 metadata={"task_type": COMPILE_PERSONA_PAGE},
467 )
470class MemoryTaskWorker:
471 """Small worker loop for backends that implement the task contract."""
473 def __init__(
474 self,
475 backend: TaskBackend,
476 dispatcher: MemoryTaskDispatcher,
477 *,
478 worker_id: str,
479 retry_delay_seconds: int = 60,
480 ) -> None:
481 self._backend = backend
482 self._dispatcher = dispatcher
483 self._worker_id = worker_id
484 self._retry_delay = retry_delay_seconds
486 async def run_once(self, *, limit: int = 10) -> int:
487 tasks = await self._backend.claim(self._worker_id, limit=limit)
488 for task in tasks:
489 try:
490 result = await self._dispatcher.run(task)
491 except Exception as exc:
492 retry_at = datetime.now(UTC) + timedelta(seconds=self._retry_delay)
493 await self._backend.fail(task.id, str(exc), retry_at)
494 else:
495 await self._backend.complete(task.id, result)
496 return len(tasks)
def split_texts_by_token_budget(texts: list[str], max_tokens: int) -> list[list[str]]:
    """Group *texts*, in order, into batches that fit an approximate token budget.

    Each text is costed at one token per four characters (minimum one). A
    batch is closed as soon as adding the next text would exceed
    *max_tokens*; a single text larger than the budget still gets a batch of
    its own.

    Raises:
        ValueError: if *max_tokens* is not positive.
    """
    if max_tokens <= 0:
        raise ValueError("max_tokens must be positive")
    batches: list[list[str]] = []
    used = 0
    for text in texts:
        cost = max(1, len(text) // 4)
        # Open a new batch for the very first text or when the budget would overflow.
        if not batches or used + cost > max_tokens:
            batches.append([])
            used = 0
        batches[-1].append(text)
        used += cost
    return batches
async def _list_bank_vectors(vector_store: VectorStore, bank_id: str) -> list[VectorItem]:
    """Return every vector item of *bank_id*, fetched in pages of 500.

    Paging stops at the first empty page; the offset of each request is the
    number of items collected so far.
    """
    collected: list[VectorItem] = []
    while page := await vector_store.list_vectors(bank_id, offset=len(collected), limit=500):
        collected.extend(page)
    return collected
def _item_person_names(item: VectorItem) -> set[str]:
    """Collect candidate person names for *item*.

    Names come from the ``person``, ``locomo_persons``, ``locomo_speakers``
    and ``speaker`` metadata values (split on ``,`` and ``|``) and from
    capitalised one- or two-word runs in the item text. The words "the",
    "this" and "that" are discarded regardless of case.
    """
    metadata = item.metadata or {}
    candidates: set[str] = set()
    for key in ("person", "locomo_persons", "locomo_speakers", "speaker"):
        raw = metadata.get(key)
        if not raw:
            continue
        for piece in str(raw).replace("|", ",").split(","):
            cleaned = piece.strip()
            if cleaned:
                candidates.add(cleaned)
    # The pattern has no capturing group, so ``findall`` yields whole matches.
    candidates.update(re.findall(r"\b[A-Z][a-z]+(?:\s+[A-Z][a-z]+)?\b", item.text))
    ignored = {"the", "this", "that"}
    return {name for name in candidates if name.lower() not in ignored}
def _item_matches_person(item: VectorItem, person: str | None) -> bool:
    """Return True when *person* is unset or names someone found on *item*.

    The comparison against the item's person names ignores case.
    """
    if person is None:
        return True
    wanted = person.lower()
    return any(name.lower() == wanted for name in _item_person_names(item))
def _item_in_scope(item: VectorItem, scope: str) -> bool:
    """Return True when *item* belongs to *scope*.

    An empty scope matches everything. Otherwise the item matches when one
    of its tags equals the scope (case-insensitively), or — for ``convo:*``
    scopes only — when ``metadata["conversation_id"]`` equals the part after
    ``convo:`` (also case-insensitively). Any other scope that is not among
    the item's tags does not match.
    """
    if not scope:
        return True
    wanted = scope.lower()
    if any(str(tag).lower() == wanted for tag in (item.tags or [])):
        return True
    if not wanted.startswith("convo:"):
        return False
    convo_id = scope.split(":", 1)[1]
    conversation = str((item.metadata or {}).get("conversation_id", ""))
    return conversation.lower() == convo_id.lower()
def _slug(value: str) -> str:
    """Lower-case *value* and collapse every non-alphanumeric run into one ``-``.

    Leading and trailing dashes are removed; if nothing remains the slug is
    ``"unknown"``.
    """
    dashed = re.sub(r"[^a-z0-9]+", "-", value.lower())
    trimmed = dashed.strip("-")
    return trimmed if trimmed else "unknown"
def _optional_str(value: Any) -> str | None:
    """Return *value* as a stripped string, or ``None`` when absent or blank."""
    return None if value is None else (str(value).strip() or None)
def _required_str(payload: dict[str, Any], key: str) -> str:
    """Return ``payload[key]`` as a non-blank string.

    Raises:
        ValueError: if the key is missing or its value is ``None`` or blank.
    """
    text = _optional_str(payload.get(key))
    if text is not None:
        return text
    raise ValueError(f"Task payload requires {key!r}")