Coverage for astrocyte/integrations/dspy.py: 92%
25 statements
« prev ^ index » next coverage.py v7.15.0, created at 2026-07-04 05:24 +0000
« prev ^ index » next coverage.py v7.15.0, created at 2026-07-04 05:24 +0000
1"""DSPy integration — Astrocyte as a DSPy retrieval module.
3Usage:
4 from astrocyte import Astrocyte
5 from astrocyte.integrations.dspy import AstrocyteRM
7 brain = Astrocyte.from_config("astrocyte.yaml")
8 retriever = AstrocyteRM(brain, bank_id="knowledge-base")
10 # Use as a DSPy retrieval model
11 import dspy
12 dspy.configure(rm=retriever)
14 # Or use directly
15 results = retriever("What is dark mode?", k=5)
17DSPy uses retrieval models (RM) that implement __call__(query, k) → list[str].
18The AstrocyteRM wraps brain.recall() to match this pattern.
19"""
21from __future__ import annotations
23from typing import TYPE_CHECKING, Any
25if TYPE_CHECKING:
26 from astrocyte._astrocyte import Astrocyte
28from astrocyte.integrations._sync_utils import _run_async_from_sync
29from astrocyte.types import AstrocyteContext
class AstrocyteRM:
    """Astrocyte-backed retrieval model for DSPy.

    Implements DSPy's RM protocol: ``__call__(query, k)`` → list of passage
    strings. Also provides async methods for direct use.

    Every operation targets the single ``bank_id`` given at construction
    time and, unless a method says otherwise, the ``context`` given at
    construction time.
    """

    def __init__(
        self,
        brain: Astrocyte,
        bank_id: str,
        *,
        context: AstrocyteContext | None = None,
        default_k: int = 5,
    ) -> None:
        """Bind the retriever to one brain and one bank.

        Args:
            brain: Object whose ``recall``/``retain``/``reflect`` coroutines
                are awaited by this wrapper.
            bank_id: Bank identifier forwarded on every call.
            context: Optional context forwarded on every call.
            default_k: Number of passages requested when the caller does not
                pass ``k``.
        """
        self.brain = brain
        self.bank_id = bank_id
        self._context = context
        self.default_k = default_k

    def _resolve_k(self, k: int | None) -> int:
        """Return ``k``, falling back to ``default_k`` only when it is ``None``.

        An ``is None`` test is used instead of ``k or default`` so that an
        explicit ``k=0`` is honoured rather than silently replaced by the
        default.
        """
        return self.default_k if k is None else k

    def __call__(self, query: str, k: int | None = None) -> list[str]:
        """Synchronous retrieval (DSPy RM protocol).

        Args:
            query: Text to search for.
            k: Maximum number of passages; ``None`` means ``default_k``.

        Returns:
            List of passage strings.
        """
        # The coroutine is handed to the project's sync bridge, whose return
        # value is passed straight back to the caller.
        return _run_async_from_sync(self._retrieve(query, self._resolve_k(k)))

    async def _retrieve(self, query: str, k: int) -> list[str]:
        """Recall up to ``k`` hits from the bank and return their ``text``."""
        result = await self.brain.recall(
            query,
            bank_id=self.bank_id,
            max_results=k,
            context=self._context,
        )
        return [hit.text for hit in result.hits]

    async def aretrieve(self, query: str, k: int | None = None) -> list[str]:
        """Async retrieval for use in async DSPy pipelines.

        Args:
            query: Text to search for.
            k: Maximum number of passages; ``None`` means ``default_k``.

        Returns:
            List of passage strings.
        """
        return await self._retrieve(query, self._resolve_k(k))

    async def aretain(self, content: str, **kwargs: Any) -> str | None:
        """Store content for later retrieval.

        Args:
            content: Text to store in the bank.
            **kwargs: Forwarded to ``brain.retain``. A ``context`` entry, if
                present, replaces the instance context for this call only.

        Returns:
            The result's ``memory_id`` when the result reports ``stored``,
            otherwise ``None``.
        """
        # Pop ``context`` so it is not passed twice to ``brain.retain``.
        ctx = kwargs.pop("context", self._context)
        result = await self.brain.retain(
            content,
            bank_id=self.bank_id,
            context=ctx,
            **kwargs,
        )
        return result.memory_id if result.stored else None

    async def areflect(self, query: str) -> str:
        """Synthesize an answer from memory.

        Args:
            query: Question passed to ``brain.reflect``.

        Returns:
            The ``answer`` attribute of the reflect result.
        """
        result = await self.brain.reflect(
            query,
            bank_id=self.bank_id,
            context=self._context,
        )
        return result.answer