Coverage for astrocyte/integrations/dspy.py: 92%

25 statements  

« prev     ^ index     » next       coverage.py v7.15.0, created at 2026-07-04 05:24 +0000

"""DSPy integration — Astrocyte as a DSPy retrieval module.

Usage:
    from astrocyte import Astrocyte
    from astrocyte.integrations.dspy import AstrocyteRM

    brain = Astrocyte.from_config("astrocyte.yaml")
    retriever = AstrocyteRM(brain, bank_id="knowledge-base")

    # Use as a DSPy retrieval model
    import dspy
    dspy.configure(rm=retriever)

    # Or use directly
    results = retriever("What is dark mode?", k=5)

AstrocyteRM targets the DSPy retrieval-model (RM) calling convention
__call__(query, k) → list[str] by wrapping brain.recall() and returning the
text of each hit. It also exposes async helpers for async code:
aretrieve() (wraps brain.recall()), aretain() (wraps brain.retain()) and
areflect() (wraps brain.reflect()).

NOTE(review): DSPy's expectations for RM return values have changed across
releases (some versions read a ``long_text`` attribute from each passage);
confirm this convention against the DSPy version in use.
"""

20 

21from __future__ import annotations 

22 

23from typing import TYPE_CHECKING, Any 

24 

25if TYPE_CHECKING: 

26 from astrocyte._astrocyte import Astrocyte 

27 

28from astrocyte.integrations._sync_utils import _run_async_from_sync 

29from astrocyte.types import AstrocyteContext 

30 

31 

class AstrocyteRM:
    """Astrocyte-backed retrieval model for DSPy.

    Implements DSPy's RM calling convention ``__call__(query, k)`` returning a
    list of passage strings, and also provides async methods for direct use.

    Args:
        brain: Astrocyte instance whose ``recall``, ``retain`` and ``reflect``
            coroutines back every call.
        bank_id: Memory bank passed as ``bank_id=`` to every brain call.
        context: Optional ``AstrocyteContext`` passed as ``context=`` to every
            brain call (``aretain`` accepts a per-call override).
        default_k: Passage count used when ``k`` is omitted or ``None``.
    """

    def __init__(
        self,
        brain: Astrocyte,
        bank_id: str,
        *,
        context: AstrocyteContext | None = None,
        default_k: int = 5,
    ) -> None:
        self.brain = brain
        self.bank_id = bank_id
        self._context = context
        self.default_k = default_k

    def _resolve_k(self, k: int | None) -> int:
        """Return ``k`` unchanged, or ``self.default_k`` when ``k`` is None."""
        # Explicit None check (rather than ``k or default_k``) so an explicit
        # k=0 is honored instead of silently becoming default_k.
        return self.default_k if k is None else k

    def __call__(self, query: str, k: int | None = None) -> list[str]:
        """Synchronous retrieval (DSPy RM protocol).

        Drives the async retrieval to completion via ``_run_async_from_sync``.

        Args:
            query: Text to recall against ``bank_id``.
            k: Maximum number of passages; ``None`` means ``default_k``.

        Returns:
            The ``text`` of each recall hit, in the order the brain returned them.
        """
        return _run_async_from_sync(self._retrieve(query, self._resolve_k(k)))

    async def _retrieve(self, query: str, k: int) -> list[str]:
        """Recall up to ``k`` hits from ``bank_id`` and return their texts."""
        result = await self.brain.recall(
            query, bank_id=self.bank_id, max_results=k, context=self._context
        )
        return [hit.text for hit in result.hits]

    async def aretrieve(self, query: str, k: int | None = None) -> list[str]:
        """Async retrieval for use in async DSPy pipelines.

        Same contract as ``__call__``: ``k=None`` means ``default_k``.
        """
        return await self._retrieve(query, self._resolve_k(k))

    async def aretain(self, content: str, **kwargs: Any) -> str | None:
        """Store ``content`` in ``bank_id`` for later retrieval.

        Args:
            content: Text to store.
            **kwargs: Forwarded to ``brain.retain``. A ``context`` key, if
                present, replaces the instance's default context for this
                call (an explicit ``None`` included).

        Returns:
            The ``memory_id`` when the brain reports the content as stored,
            otherwise None.
        """
        ctx = kwargs.pop("context", self._context)
        result = await self.brain.retain(content, bank_id=self.bank_id, context=ctx, **kwargs)
        return result.memory_id if result.stored else None

    async def areflect(self, query: str) -> str:
        """Synthesize an answer from memory via ``brain.reflect`` and return it."""
        result = await self.brain.reflect(query, bank_id=self.bank_id, context=self._context)
        return result.answer