Coverage for astrocyte/pipeline/tasks.py: 83%

315 statements  

« prev     ^ index     » next       coverage.py v7.15.0, created at 2026-07-04 05:24 +0000

1"""Async memory task contract and handlers. 

2 

3This module is the framework-facing half of the Postgres task design. The 

4production backend can persist these ``MemoryTask`` rows in ``astrocyte_tasks``; 

5tests and single-process deployments can use ``InMemoryTaskBackend``. 

6""" 

7 

8from __future__ import annotations 

9 

10import json 

11import re 

12import uuid 

13from dataclasses import dataclass, field 

14from datetime import UTC, datetime, timedelta 

15from typing import TYPE_CHECKING, Any, Literal, Protocol 

16 

17from astrocyte.pipeline.temporal import temporal_metadata 

18from astrocyte.types import ( 

19 Entity, 

20 EntityLink, 

21 MemoryEntityAssociation, 

22 Message, 

23 VectorItem, 

24 WikiPage, 

25) 

26 

27if TYPE_CHECKING: 

28 from astrocyte.pipeline.compile import CompileEngine 

29 from astrocyte.pipeline.lint import LintEngine 

30 from astrocyte.provider import GraphStore, LLMProvider, VectorStore, WikiStore 

31 

# Lifecycle states a MemoryTask row can be in. The in-memory backend moves
# tasks queued -> running -> succeeded, back to queued for a retry, or to
# dead once retries are exhausted; it never assigns "failed" itself.
TaskStatus = Literal["queued", "running", "succeeded", "failed", "dead"]

# ``MemoryTask.task_type`` identifiers routed by MemoryTaskDispatcher.run.
# Wiki compilation tasks:
COMPILE_BANK = "compile_bank"
COMPILE_PERSONA_PAGE = "compile_persona_page"
LINT_WIKI_PAGE = "lint_wiki_page"
# Index / graph / metadata maintenance tasks:
INDEX_WIKI_PAGE_VECTOR = "index_wiki_page_vector"
NORMALIZE_TEMPORAL_FACTS = "normalize_temporal_facts"
PROJECT_ENTITY_EDGES = "project_entity_edges"

40 

41 

42@dataclass 

43class MemoryTask: 

44 """A durable background task for memory quality work.""" 

45 

46 task_type: str 

47 bank_id: str 

48 payload: dict[str, Any] = field(default_factory=dict) 

49 id: str = field(default_factory=lambda: uuid.uuid4().hex) 

50 status: TaskStatus = "queued" 

51 idempotency_key: str | None = None 

52 attempts: int = 0 

53 max_attempts: int = 5 

54 run_after: datetime = field(default_factory=lambda: datetime.now(UTC)) 

55 result: dict[str, Any] | None = None 

56 error: str | None = None 

57 claimed_by: str | None = None 

58 claimed_at: datetime | None = None 

59 created_at: datetime = field(default_factory=lambda: datetime.now(UTC)) 

60 updated_at: datetime = field(default_factory=lambda: datetime.now(UTC)) 

61 

62 

class TaskBackend(Protocol):
    """Structural contract every task-queue store must satisfy.

    Production uses a Postgres implementation; ``InMemoryTaskBackend`` is the
    in-process implementation.
    """

    async def enqueue(self, task: MemoryTask) -> str:
        """Store *task* and return the id the caller should track."""

    async def claim(self, worker_id: str, limit: int = 10) -> list[MemoryTask]:
        """Hand up to *limit* due tasks to *worker_id* and return them."""

    async def complete(self, task_id: str, result: dict[str, Any]) -> None:
        """Record a successful run of *task_id* with its *result*."""

    async def fail(self, task_id: str, error: str, retry_at: datetime | None) -> None:
        """Record a failed run of *task_id*; *retry_at* of None means no retry."""

77 

78 

79class InMemoryTaskBackend: 

80 """Deterministic task backend for unit tests and local development.""" 

81 

82 def __init__(self) -> None: 

83 self._tasks: dict[str, MemoryTask] = {} 

84 

85 async def enqueue(self, task: MemoryTask) -> str: 

86 if task.idempotency_key is not None: 

87 for existing in self._tasks.values(): 

88 if ( 

89 existing.task_type == task.task_type 

90 and existing.idempotency_key == task.idempotency_key 

91 and existing.status in {"queued", "running"} 

92 ): 

93 return existing.id 

94 self._tasks[task.id] = task 

95 return task.id 

96 

97 async def claim(self, worker_id: str, limit: int = 10) -> list[MemoryTask]: 

98 now = datetime.now(UTC) 

99 queued = sorted( 

100 (task for task in self._tasks.values() if task.status == "queued" and task.run_after <= now), 

101 key=lambda task: (task.run_after, task.created_at), 

102 ) 

103 claimed: list[MemoryTask] = [] 

104 for task in queued[:limit]: 

105 task.status = "running" 

106 task.claimed_by = worker_id 

107 task.claimed_at = now 

108 task.updated_at = now 

109 task.attempts += 1 

110 claimed.append(task) 

111 return claimed 

112 

113 async def complete(self, task_id: str, result: dict[str, Any]) -> None: 

114 task = self._tasks[task_id] 

115 task.status = "succeeded" 

116 task.result = result 

117 task.error = None 

118 task.updated_at = datetime.now(UTC) 

119 

120 async def fail(self, task_id: str, error: str, retry_at: datetime | None) -> None: 

121 task = self._tasks[task_id] 

122 task.error = error 

123 task.updated_at = datetime.now(UTC) 

124 if retry_at is None or task.attempts >= task.max_attempts: 

125 task.status = "dead" 

126 return 

127 task.status = "queued" 

128 task.run_after = retry_at 

129 

130 def get(self, task_id: str) -> MemoryTask | None: 

131 return self._tasks.get(task_id) 

132 

133 def list_by_status(self, status: TaskStatus | None = None) -> list[MemoryTask]: 

134 tasks = list(self._tasks.values()) 

135 if status is not None: 

136 tasks = [task for task in tasks if task.status == status] 

137 return sorted(tasks, key=lambda task: (task.created_at, task.id)) 

138 

139 async def recover_stale( 

140 self, 

141 *, 

142 stale_after: timedelta, 

143 now: datetime | None = None, 

144 ) -> int: 

145 """Requeue running tasks whose worker claim is stale.""" 

146 current = now or datetime.now(UTC) 

147 recovered = 0 

148 for task in self._tasks.values(): 

149 if task.status != "running" or task.claimed_at is None: 

150 continue 

151 if current - task.claimed_at < stale_after: 

152 continue 

153 task.status = "queued" if task.attempts < task.max_attempts else "dead" 

154 task.claimed_by = None 

155 task.claimed_at = None 

156 task.run_after = current 

157 task.updated_at = current 

158 recovered += 1 

159 return recovered 

160 

161 

@dataclass
class TaskHandlerContext:
    """Dependencies available to task handlers.

    ``vector_store`` and ``llm_provider`` are always required. The remaining
    stores and engines default to None; MemoryTaskDispatcher raises
    ``ValueError`` when a task needs one of them and it was not supplied.
    """

    vector_store: VectorStore  # scanned by bank-wide handlers; written by wiki indexing and temporal normalization
    llm_provider: LLMProvider  # used for embeddings (wiki indexing) and completions (persona pages)
    wiki_store: WikiStore | None = None  # needed by compile_persona_page and index_wiki_page_vector
    graph_store: GraphStore | None = None  # needed by project_entity_edges
    compile_engine: CompileEngine | None = None  # needed by compile_bank
    lint_engine: LintEngine | None = None  # needed by lint_wiki_page

172 

173 

class MemoryTaskDispatcher:
    """Dispatch ``MemoryTask`` rows to the memory-quality handler for their type.

    Every handler returns a JSON-friendly ``dict`` that becomes the task
    result. Handlers raise ``ValueError`` when a dependency they need is
    missing from the :class:`TaskHandlerContext`, when a required payload key
    is absent, or when a referenced wiki page does not exist.
    """

    def __init__(self, context: TaskHandlerContext) -> None:
        self._ctx = context

    async def run(self, task: MemoryTask) -> dict[str, Any]:
        """Execute *task* with the handler registered for its ``task_type``.

        Raises:
            ValueError: if ``task.task_type`` is not a known memory task type.
        """
        handlers = {
            COMPILE_BANK: self._compile_bank,
            COMPILE_PERSONA_PAGE: self._compile_persona_page,
            INDEX_WIKI_PAGE_VECTOR: self._index_wiki_page_vector,
            PROJECT_ENTITY_EDGES: self._project_entity_edges,
            NORMALIZE_TEMPORAL_FACTS: self._normalize_temporal_facts,
            LINT_WIKI_PAGE: self._lint_wiki_page,
        }
        handler = handlers.get(task.task_type)
        if handler is None:
            raise ValueError(f"Unknown memory task type: {task.task_type}")
        return await handler(task)

    async def _compile_bank(self, task: MemoryTask) -> dict[str, Any]:
        """Run the CompileEngine over the task's bank, optionally for one payload ``scope``."""
        if self._ctx.compile_engine is None:
            raise ValueError("compile_bank requires CompileEngine")
        result = await self._ctx.compile_engine.run(
            task.bank_id,
            scope=_optional_str(task.payload.get("scope")),
        )
        return {
            "bank_id": result.bank_id,
            "scopes_compiled": result.scopes_compiled,
            "pages_created": result.pages_created,
            "pages_updated": result.pages_updated,
            "noise_memories": result.noise_memories,
            "tokens_used": result.tokens_used,
            "elapsed_ms": result.elapsed_ms,
            "error": result.error,
        }

    async def _compile_persona_page(self, task: MemoryTask) -> dict[str, Any]:
        """Compile persona wiki page(s) from the memories that name a person.

        Payload keys:
            person: compile one page for this person. When absent, one page is
                compiled for every person named in the selected memories.
            scope: optional qualifier (e.g. ``"convo:convo-3"``). Only memories
                in this scope are considered, and pages are keyed per scope, so
                "Caroline in conversation_0" and "Caroline in conversation_3"
                get distinct pages.
            source_ids: optional memory ids. With ``person`` set, only these
                memories are used as evidence. Without ``person`` they select
                which people get a page; each page is then built from all of
                that person's memories in scope.
            index_vector: when truthy, each compiled page is also indexed as a
                vector.
        """
        if self._ctx.wiki_store is None:
            raise ValueError("compile_persona_page requires WikiStore")
        person = _optional_str(task.payload.get("person"))
        scope = _optional_str(task.payload.get("scope"))
        source_ids = {str(value) for value in task.payload.get("source_ids") or [] if value}
        index_vector = bool(task.payload.get("index_vector"))

        items = await _list_bank_vectors(self._ctx.vector_store, task.bank_id)
        # Apply the scope filter before anything else. Otherwise a scoped page
        # would merge memories (and people) from every context in the bank.
        if scope is not None:
            items = [item for item in items if _item_in_scope(item, scope)]
        # Name extraction is regex-based, so run it once per memory rather
        # than once per (memory, person) pair.
        named = [(item, _item_person_names(item)) for item in items]

        if person is not None:
            relevant = self._memories_for_person(named, person, source_ids)
            return await self._compile_single_persona(
                task.bank_id, person, relevant, scope=scope, index_vector=index_vector
            )

        people = sorted(
            {name for item, names in named if not source_ids or item.id in source_ids for name in names}
        )
        if not people:
            return {"pages_created": 0, "pages_updated": 0, "reason": "no_person_metadata"}
        created = 0
        updated = 0
        page_ids: list[str] = []
        for name in people:
            result = await self._compile_single_persona(
                task.bank_id,
                name,
                self._memories_for_person(named, name, set()),
                scope=scope,
                index_vector=index_vector,
            )
            created += int(result.get("pages_created", 0))
            updated += int(result.get("pages_updated", 0))
            page_ids.extend(result.get("page_ids", []))
        return {"pages_created": created, "pages_updated": updated, "page_ids": page_ids}

    @staticmethod
    def _memories_for_person(
        named: list[tuple[VectorItem, set[str]]],
        person: str,
        source_ids: set[str],
    ) -> list[VectorItem]:
        """Return the memories naming *person* (case-insensitive), limited to *source_ids* when non-empty."""
        wanted = person.lower()
        return [
            item
            for item, names in named
            if (not source_ids or item.id in source_ids) and wanted in {name.lower() for name in names}
        ]

    async def _compile_single_persona(
        self,
        bank_id: str,
        person: str,
        relevant: list[VectorItem],
        *,
        scope: str | None,
        index_vector: bool,
    ) -> dict[str, Any]:
        """Build and upsert one persona page from *relevant*; optionally index it as a vector.

        The page id is ``person:<scope-slug>:<person-slug>`` when *scope* is
        set, otherwise ``person:<person-slug>``.
        """
        wiki_store = self._ctx.wiki_store
        if wiki_store is None:
            raise ValueError("compile_persona_page requires WikiStore")
        page_id = f"person:{_slug(scope)}:{_slug(person)}" if scope else f"person:{_slug(person)}"
        existing = await wiki_store.get_page(page_id, bank_id)
        page = await self._build_persona_page(bank_id, person, relevant, existing, page_id=page_id, scope=scope)
        await wiki_store.upsert_page(page, bank_id)
        result: dict[str, Any] = {
            "pages_created": 1 if existing is None else 0,
            "pages_updated": 0 if existing is None else 1,
            "page_ids": [page_id],
            "source_count": len(relevant),
        }
        if index_vector:
            index_result = await self._index_wiki_page_vector(
                MemoryTask(
                    task_type=INDEX_WIKI_PAGE_VECTOR,
                    bank_id=bank_id,
                    payload={"page_id": page_id},
                )
            )
            result.update(index_result)
        return result

    async def _index_wiki_page_vector(self, task: MemoryTask) -> dict[str, Any]:
        """Embed the wiki page named by payload ``page_id`` and store it as a compiled vector.

        The vector id is the page id, so re-indexing the same page writes the
        same id again.
        """
        if self._ctx.wiki_store is None:
            raise ValueError("index_wiki_page_vector requires WikiStore")
        page_id = _required_str(task.payload, "page_id")
        page = await self._ctx.wiki_store.get_page(page_id, task.bank_id)
        if page is None:
            raise ValueError(f"Wiki page not found: {page_id}")
        vectors = await self._ctx.llm_provider.embed([f"{page.title}\n\n{page.content[:1000]}"])
        if not vectors:
            # Surface a clear error instead of an IndexError on vectors[0].
            raise ValueError(f"Embedding provider returned no vector for wiki page: {page_id}")
        await self._ctx.vector_store.store_vectors(
            [
                VectorItem(
                    id=page.page_id,
                    bank_id=task.bank_id,
                    vector=vectors[0],
                    text=f"[WIKI:{page.kind}] {page.title}\n\n{page.content[:500]}",
                    metadata={"_wiki_source_ids": json.dumps(page.source_ids)},
                    tags=page.tags,
                    fact_type="wiki",
                    memory_layer="compiled",
                    retained_at=datetime.now(UTC),
                )
            ]
        )
        return {"indexed_page_id": page_id, "source_count": len(page.source_ids)}

    async def _project_entity_edges(self, task: MemoryTask) -> dict[str, Any]:
        """Project person names found in memories into graph entities, associations and links.

        One ``PERSON`` entity is stored per distinct name. Each memory is
        associated with the entities it names, and every pair of names in the
        same memory gets a ``co_occurs`` link with that memory as evidence.
        """
        if self._ctx.graph_store is None:
            raise ValueError("project_entity_edges requires GraphStore")
        items = await _list_bank_vectors(self._ctx.vector_store, task.bank_id)
        entity_by_name: dict[str, Entity] = {}
        memory_names: dict[str, set[str]] = {}
        for item in items:
            names = _item_person_names(item)
            if not names:
                continue
            memory_names[item.id] = names
            for name in names:
                entity_by_name.setdefault(
                    name,
                    Entity(
                        id=f"person:{_slug(name)}",
                        name=name,
                        entity_type="PERSON",
                        aliases=[name],
                        metadata={"source": "task:project_entity_edges"},
                    ),
                )

        entities = list(entity_by_name.values())
        entity_ids = await self._ctx.graph_store.store_entities(entities, task.bank_id)
        # strict=False tolerates a short id list. Names left without an id are
        # skipped for associations AND links, rather than raising KeyError.
        id_by_name = dict(zip(entity_by_name.keys(), entity_ids, strict=False))

        associations = [
            MemoryEntityAssociation(memory_id=memory_id, entity_id=id_by_name[name])
            for memory_id, names in memory_names.items()
            for name in names
            if name in id_by_name
        ]
        if associations:
            await self._ctx.graph_store.link_memories_to_entities(associations, task.bank_id)

        links: list[EntityLink] = []
        created_at = datetime.now(UTC)
        for item in items:
            names = sorted(name for name in memory_names.get(item.id, ()) if name in id_by_name)
            if len(names) < 2:
                continue
            metadata = item.metadata or {}
            link_metadata = {
                "memory_id": item.id,
                "session_id": str(metadata.get("session_id") or ""),
                "turn_ids": str(metadata.get("locomo_turn_ids") or ""),
            }
            for index, first in enumerate(names):
                for second in names[index + 1 :]:
                    links.append(
                        EntityLink(
                            entity_a=id_by_name[first],
                            entity_b=id_by_name[second],
                            link_type="co_occurs",
                            evidence=item.text[:500],
                            confidence=1.0,
                            created_at=created_at,
                            metadata=dict(link_metadata),
                        )
                    )
        if links:
            await self._ctx.graph_store.store_links(links, task.bank_id)
        return {
            "entities_projected": len(entities),
            "associations_projected": len(associations),
            "links_projected": len(links),
        }

    async def _normalize_temporal_facts(self, task: MemoryTask) -> dict[str, Any]:
        """Merge ``temporal_metadata`` output into each memory's metadata and re-store changed memories.

        Memories for which ``temporal_metadata`` returns nothing are left untouched.
        """
        items = await _list_bank_vectors(self._ctx.vector_store, task.bank_id)
        updated: list[VectorItem] = []
        for item in items:
            metadata = dict(item.metadata or {})
            normalized = temporal_metadata(item.text, item.occurred_at)
            if not normalized:
                continue
            metadata.update(normalized)
            updated.append(
                VectorItem(
                    id=item.id,
                    bank_id=item.bank_id,
                    vector=item.vector,
                    text=item.text,
                    metadata=metadata,
                    tags=item.tags,
                    fact_type=item.fact_type,
                    occurred_at=item.occurred_at,
                    memory_layer=item.memory_layer,
                    retained_at=item.retained_at,
                )
            )
        if updated:
            await self._ctx.vector_store.store_vectors(updated)
        return {"memories_scanned": len(items), "memories_updated": len(updated)}

    async def _lint_wiki_page(self, task: MemoryTask) -> dict[str, Any]:
        """Lint the bank's wiki and report issues, optionally only those for payload ``page_id``.

        The LintEngine always runs over the whole bank; ``page_id`` only
        filters the reported issues and counts (``pages_checked`` stays bank-wide).
        """
        if self._ctx.lint_engine is None:
            raise ValueError("lint_wiki_page requires LintEngine")
        result = await self._ctx.lint_engine.run(task.bank_id)
        page_id = _optional_str(task.payload.get("page_id"))
        issues = result.issues
        if page_id is not None:
            issues = [issue for issue in issues if issue.page_id == page_id]
        return {
            "bank_id": result.bank_id,
            "pages_checked": result.pages_checked,
            "stale_count": sum(1 for issue in issues if issue.kind == "stale"),
            "orphan_count": sum(1 for issue in issues if issue.kind == "orphan"),
            "contradiction_count": sum(1 for issue in issues if issue.kind == "contradiction"),
            "issues": [
                {
                    "kind": issue.kind,
                    "page_id": issue.page_id,
                    "action": issue.action,
                    "detail": issue.detail,
                    "peer_page_id": issue.peer_page_id,
                }
                for issue in issues
            ],
            "elapsed_ms": result.elapsed_ms,
            "error": result.error,
        }

    async def _build_persona_page(
        self,
        bank_id: str,
        person: str,
        items: list[VectorItem],
        existing: WikiPage | None,
        *,
        page_id: str | None = None,
        scope: str | None = None,
    ) -> WikiPage:
        """Compile *items* into a persona ``WikiPage`` for *person*.

        At most the first 50 memories (each truncated to 500 characters) are
        sent to the LLM as evidence, while ``source_ids`` lists every memory
        in *items*. With no memories at all the LLM is not called and a
        placeholder body is used, so the model is never asked to write a
        page without evidence.
        """
        source_ids = [item.id for item in items]
        fallback = f"## {person}\n\nNo stable persona facts compiled."
        if items:
            evidence = "\n".join(f"- {item.text[:500]}" for item in items[:50])
            prompt = (
                f"Compile a concise persona/preference wiki page for {person}.\n"
                "Preserve stable facts: goals, likes, dislikes, relationships, plans, repeated activities, constraints.\n"
                "Use only the evidence.\n\n"
                f"Evidence:\n{evidence}"
            )
            completion = await self._ctx.llm_provider.complete(
                [
                    Message(role="system", content="You maintain source-grounded person memory pages."),
                    Message(role="user", content=prompt),
                ],
                max_tokens=1200,
                temperature=0.0,
            )
            content = completion.text.strip() or fallback
        else:
            content = fallback
        if page_id is None:
            page_id = f"person:{_slug(person)}"
        # Tag the page with its scope (e.g. ``"convo:convo-3"``) when set, so
        # scoped recall queries that filter on the same tag still surface it.
        tags = ["persona", f"person:{_slug(person)}"]
        if scope:
            tags.append(scope)
        return WikiPage(
            page_id=page_id,
            bank_id=bank_id,
            kind="entity",
            title=person,
            content=content,
            # NOTE(review): WikiPage.scope is set to the page id rather than the
            # *scope* argument; confirm this is intended.
            scope=page_id,
            source_ids=source_ids,
            cross_links=[],
            revision=(existing.revision + 1) if existing else 1,
            revised_at=datetime.now(UTC),
            tags=tags,
            metadata={"task_type": COMPILE_PERSONA_PAGE},
        )

468 

469 

470class MemoryTaskWorker: 

471 """Small worker loop for backends that implement the task contract.""" 

472 

473 def __init__( 

474 self, 

475 backend: TaskBackend, 

476 dispatcher: MemoryTaskDispatcher, 

477 *, 

478 worker_id: str, 

479 retry_delay_seconds: int = 60, 

480 ) -> None: 

481 self._backend = backend 

482 self._dispatcher = dispatcher 

483 self._worker_id = worker_id 

484 self._retry_delay = retry_delay_seconds 

485 

486 async def run_once(self, *, limit: int = 10) -> int: 

487 tasks = await self._backend.claim(self._worker_id, limit=limit) 

488 for task in tasks: 

489 try: 

490 result = await self._dispatcher.run(task) 

491 except Exception as exc: 

492 retry_at = datetime.now(UTC) + timedelta(seconds=self._retry_delay) 

493 await self._backend.fail(task.id, str(exc), retry_at) 

494 else: 

495 await self._backend.complete(task.id, result) 

496 return len(tasks) 

497 

498 

def split_texts_by_token_budget(texts: list[str], max_tokens: int) -> list[list[str]]:
    """Group *texts* into consecutive batches whose estimated size fits *max_tokens*.

    Each text is estimated at ``len(text) // 4`` tokens (minimum 1). A new
    batch starts whenever adding the next text would exceed the budget. A
    single text larger than the budget still gets a batch of its own.

    Raises:
        ValueError: if *max_tokens* is not positive.
    """
    if max_tokens <= 0:
        raise ValueError("max_tokens must be positive")
    batches: list[list[str]] = []
    used = 0
    for text in texts:
        cost = max(1, len(text) // 4)
        if not batches or used + cost > max_tokens:
            # Open a fresh batch: either this is the first text, or the
            # current batch has no room left for it.
            batches.append([text])
            used = cost
        else:
            batches[-1].append(text)
            used += cost
    return batches

517 

518 

519async def _list_bank_vectors(vector_store: VectorStore, bank_id: str) -> list[VectorItem]: 

520 items: list[VectorItem] = [] 

521 offset = 0 

522 while True: 

523 batch = await vector_store.list_vectors(bank_id, offset=offset, limit=500) 

524 if not batch: 

525 return items 

526 items.extend(batch) 

527 offset += len(batch) 

528 

529 

530def _item_person_names(item: VectorItem) -> set[str]: 

531 names: set[str] = set() 

532 metadata = item.metadata or {} 

533 for key in ("person", "locomo_persons", "locomo_speakers", "speaker"): 

534 raw = metadata.get(key) 

535 if raw: 

536 names.update(part.strip() for part in str(raw).replace("|", ",").split(",") if part.strip()) 

537 names.update(match.group(0) for match in re.finditer(r"\b[A-Z][a-z]+(?:\s+[A-Z][a-z]+)?\b", item.text)) 

538 return {name for name in names if name.lower() not in {"the", "this", "that"}} 

539 

540 

def _item_matches_person(item: VectorItem, person: str | None) -> bool:
    """Return True when *item* names *person* (case-insensitive); None matches everything."""
    if person is not None:
        wanted = person.lower()
        return any(name.lower() == wanted for name in _item_person_names(item))
    return True

545 

546 

547def _item_in_scope(item: VectorItem, scope: str) -> bool: 

548 """Return True when *item* belongs to *scope*. 

549 

550 Scope is matched against the item's tags (e.g. ``"convo:convo-3"``) 

551 and against ``metadata.conversation_id`` (the LoCoMo benchmark stamps 

552 both). Falls back to ``True`` when the scope marker isn't ``convo:*`` — 

553 callers can extend this for other scope kinds without breaking the 

554 LoCoMo path. 

555 """ 

556 if not scope: 

557 return True 

558 tags = {str(tag).lower() for tag in (item.tags or [])} 

559 if scope.lower() in tags: 

560 return True 

561 if scope.lower().startswith("convo:"): 

562 convo_id = scope.split(":", 1)[1] 

563 metadata = item.metadata or {} 

564 if str(metadata.get("conversation_id", "")).lower() == convo_id.lower(): 

565 return True 

566 return False 

567 

568 

569def _slug(value: str) -> str: 

570 return re.sub(r"[^a-z0-9]+", "-", value.lower()).strip("-") or "unknown" 

571 

572 

573def _optional_str(value: Any) -> str | None: 

574 if value is None: 

575 return None 

576 text = str(value).strip() 

577 return text or None 

578 

579 

580def _required_str(payload: dict[str, Any], key: str) -> str: 

581 value = _optional_str(payload.get(key)) 

582 if value is None: 

583 raise ValueError(f"Task payload requires {key!r}") 

584 return value