Coverage for astrocyte/integrations/haystack.py: 93%

42 statements  

« prev     ^ index     » next       coverage.py v7.15.0, created at 2026-07-04 05:24 +0000

1"""Haystack (deepset) integration — Astrocyte as a retriever component. 

2 

3Usage: 

4 from astrocyte import Astrocyte 

5 from astrocyte.integrations.haystack import AstrocyteRetriever 

6 

7 brain = Astrocyte.from_config("astrocyte.yaml") 

8 retriever = AstrocyteRetriever(brain, bank_id="knowledge-base") 

9 

10 # Use in a Haystack pipeline 

11 pipe = Pipeline() 

12 pipe.add_component("retriever", retriever) 

13 pipe.add_component("reader", reader) 

14 pipe.connect("retriever.documents", "reader.documents") 

15 

16 result = pipe.run({"retriever": {"query": "What is dark mode?"}}) 

17 

18Haystack uses a component pattern where each component has run() or arun() 

19methods with typed inputs/outputs. Retrievers return documents. 

20""" 

21 

22from __future__ import annotations 

23 

24from dataclasses import dataclass 

25from typing import TYPE_CHECKING, Any 

26 

27if TYPE_CHECKING: 

28 from astrocyte._astrocyte import Astrocyte 

29 

30from astrocyte.types import AstrocyteContext 

31 

32 

@dataclass
class AstrocyteDocument:
    """Minimal document record shaped like ``haystack.Document``.

    Exposes the same four fields — ``content``, ``meta``, ``score`` and
    ``id`` — so instances can be passed around wherever a Haystack-style
    document is expected.
    """

    # Text body of the document.
    content: str
    # Free-form metadata attached to the document.
    meta: dict[str, Any]
    # Score associated with the document.
    score: float
    # Document identifier.
    id: str

44 

45 

class AstrocyteRetriever:
    """Haystack-style retriever component backed by an Astrocyte brain.

    Entry points:
    - ``arun(query, top_k=..., tags=...)`` — asynchronous retrieval
    - ``run(query, **kwargs)`` — synchronous wrapper around ``arun``

    Both return ``{"documents": [...]}`` where each entry is an
    AstrocyteDocument (compatible with haystack.Document).
    """

    def __init__(
        self,
        brain: Astrocyte,
        bank_id: str,
        *,
        context: AstrocyteContext | None = None,
        top_k: int = 10,
    ) -> None:
        """Remember the brain, the bank to query, and the retrieval defaults.

        Args:
            brain: Object whose ``recall`` coroutine performs the lookup.
            bank_id: Bank identifier forwarded on every recall.
            context: Optional context forwarded on every recall.
            top_k: Result limit used when a call does not supply its own.
        """
        self.brain, self.bank_id = brain, bank_id
        self._context, self.top_k = context, top_k

    async def arun(
        self,
        query: str,
        *,
        top_k: int | None = None,
        tags: list[str] | None = None,
    ) -> dict[str, list[AstrocyteDocument]]:
        """Recall memories for *query* and wrap them as documents.

        Args:
            query: Text passed to ``brain.recall``.
            top_k: Per-call result limit. Any falsy value (``None`` or ``0``)
                falls back to the instance-level ``top_k``.
            tags: Optional tags forwarded unchanged to ``brain.recall``.

        Returns:
            ``{"documents": [...]}`` with one AstrocyteDocument per hit.
        """
        limit = top_k or self.top_k
        recalled = await self.brain.recall(
            query,
            bank_id=self.bank_id,
            max_results=limit,
            tags=tags,
            context=self._context,
        )

        documents = []
        for hit in recalled.hits:
            text = hit.text
            if hit.metadata:
                # Copy so the document does not share the hit's mapping.
                meta = dict(hit.metadata)
            else:
                # Hits without metadata get a default marker instead.
                meta = {"source": "astrocyte"}
            documents.append(
                AstrocyteDocument(
                    content=text,
                    meta=meta,
                    score=hit.score,
                    id=hit.memory_id or "",
                )
            )
        return {"documents": documents}

    def run(self, query: str, **kwargs: Any) -> dict[str, list[AstrocyteDocument]]:
        """Blocking counterpart of ``arun``; keyword arguments pass straight through."""
        # Imported at call time, not at module import time.
        from astrocyte.integrations._sync_utils import _run_async_from_sync

        pending = self.arun(query, **kwargs)
        return _run_async_from_sync(pending)

103 

104 

class AstrocyteWriter:
    """Astrocyte-backed document writer for Haystack pipelines.

    Implements a Writer component:
    - run(documents) → {"written": count}
    - Async via arun()
    """

    def __init__(
        self,
        brain: Astrocyte,
        bank_id: str,
        *,
        context: AstrocyteContext | None = None,
    ) -> None:
        """Store the brain, the target bank, and an optional context.

        Args:
            brain: Object whose ``retain`` coroutine persists each document.
            bank_id: Bank identifier forwarded on every retain call.
            context: Optional context forwarded on every retain call.
        """
        self.brain = brain
        self.bank_id = bank_id
        self._context = context

    async def arun(self, documents: list[AstrocyteDocument | dict[str, Any]]) -> dict[str, int]:
        """Write documents to Astrocyte memory.

        Args:
            documents: Entries to persist. Dicts are read through their
                ``"content"`` and ``"meta"`` keys (defaulting to ``""`` and
                ``{}``); AstrocyteDocument instances through their
                ``content`` and ``meta`` attributes. Entries of any other
                type are skipped.

        Returns:
            ``{"written": n}`` where ``n`` counts the retain calls whose
            result reported ``stored`` as truthy.
        """
        written = 0
        for doc in documents:
            if isinstance(doc, dict):
                content = doc.get("content", "")
                meta = doc.get("meta", {})
            elif isinstance(doc, AstrocyteDocument):
                content = doc.content
                meta = doc.meta
            else:
                # Best-effort: unsupported entries are ignored, not fatal.
                continue

            result = await self.brain.retain(
                content,
                bank_id=self.bank_id,
                metadata=meta,
                tags=["haystack"],
                context=self._context,
            )
            if result.stored:
                written += 1
        return {"written": written}

    def run(self, documents: list[AstrocyteDocument | dict[str, Any]]) -> dict[str, int]:
        """Synchronous write for Haystack pipeline compatibility.

        Wraps ``arun`` with the same sync bridge AstrocyteRetriever.run uses,
        so the ``run(documents)`` entry point named in the class docstring
        actually exists.
        """
        # Imported at call time, not at module import time.
        from astrocyte.integrations._sync_utils import _run_async_from_sync

        return _run_async_from_sync(self.arun(documents))