Coverage for astrocyte/integrations/langgraph.py: 92%

50 statements  

« prev     ^ index     » next       coverage.py v7.15.0, created at 2026-07-04 05:24 +0000

1"""LangGraph / LangChain integration — Astrocyte as a memory store. 

2 

3Usage: 

4 from astrocyte import Astrocyte 

5 from astrocyte.integrations.langgraph import AstrocyteMemory 

6 

7 brain = Astrocyte.from_config("astrocyte.yaml") 

8 memory = AstrocyteMemory(brain, bank_id="user-123") 

9 # Optional: memory = AstrocyteMemory(..., context=AstrocyteContext(principal="user:me")) 

10 

11 # Use as LangGraph memory store 

12 graph = StateGraph(AgentState) 

13 app = graph.compile(checkpointer=memory) 

14 

15Maps: 

16 - put / save_context → brain.retain() 

17 - search → brain.recall() 

18 - Thread ID → bank ID (configurable via thread_to_bank) 

19""" 

20 

21from __future__ import annotations 

22 

23import logging 

24from typing import TYPE_CHECKING, Any 

25 

26if TYPE_CHECKING: 

27 from astrocyte._astrocyte import Astrocyte 

28 

29from astrocyte.integrations._sync_utils import _run_async_from_sync 

30from astrocyte.types import AstrocyteContext 

31 

32logger = logging.getLogger("astrocyte.integrations.langgraph") 

33 

34 

class AstrocyteMemory:
    """Astrocyte-backed memory for LangGraph / LangChain agents.

    Implements the interface pattern expected by LangGraph's memory store:
    save_context(), load_memory_variables(), search().

    Pass optional ``context`` (:class:`~astrocyte.types.AstrocyteContext`) for
    access control and OBO when ``access_control`` is enabled.

    This is a thin wrapper — all policy enforcement happens inside Astrocyte.
    """

    def __init__(
        self,
        brain: Astrocyte,
        bank_id: str,
        *,
        context: AstrocyteContext | None = None,
        auto_retain: bool = False,
        auto_retain_filter: str | None = None,
        thread_to_bank: dict[str, str] | None = None,
    ) -> None:
        """Bind the memory wrapper to a brain and a default bank.

        Args:
            brain: Object whose awaitable ``retain()`` / ``recall()`` methods
                back this memory.
            bank_id: Bank used whenever a thread ID has no explicit mapping.
            context: Forwarded unchanged as ``context=`` on every
                ``retain()`` / ``recall()`` call.
            auto_retain: Stored on the instance only; no method of this class
                reads it.
            auto_retain_filter: Stored on the instance only; no method of this
                class reads it.
            thread_to_bank: Optional thread-ID → bank-ID overrides; ``None``
                means no overrides.
        """
        self.brain = brain
        self.bank_id = bank_id
        self._context = context
        self.auto_retain = auto_retain
        self.auto_retain_filter = auto_retain_filter
        self._thread_to_bank = thread_to_bank or {}

    def _resolve_bank(self, thread_id: str | None = None) -> str:
        """Map thread ID to bank ID, falling back to the default bank."""
        if not thread_id:
            return self.bank_id
        # Single lookup: unmapped thread IDs fall through to the default bank.
        return self._thread_to_bank.get(thread_id, self.bank_id)

    async def save_context(
        self,
        inputs: dict[str, Any],
        outputs: dict[str, Any],
        *,
        thread_id: str | None = None,
        tags: list[str] | None = None,
    ) -> None:
        """Save interaction context to memory (retain).

        Combines inputs and outputs into a single memory entry made of
        ``"key: value"`` lines (inputs first, then outputs). Nothing is
        retained when both mappings are empty.

        Args:
            inputs: Input values of the interaction.
            outputs: Output values of the interaction.
            thread_id: Optional thread ID used to pick the bank; also recorded
                in the entry's metadata (as ``""`` when absent).
            tags: Tags for the entry; defaults to ``["langgraph"]`` when
                omitted or empty.
        """
        bank = self._resolve_bank(thread_id)

        # Build a concise representation of the interaction.
        parts: list[str] = []
        for mapping in (inputs, outputs):
            if mapping:
                parts.extend(f"{key}: {value}" for key, value in mapping.items())

        if not parts:
            return

        result = await self.brain.retain(
            "\n".join(parts),
            bank_id=bank,
            tags=tags or ["langgraph"],
            metadata={"source": "langgraph", "thread_id": thread_id or ""},
            context=self._context,
        )
        # Best-effort: a rejected write is logged, never raised to the agent.
        if not result.stored:
            logger.warning("LangGraph save_context failed for bank %s: %s", bank, result.error)

    async def search(
        self,
        query: str,
        *,
        thread_id: str | None = None,
        max_results: int = 5,
        tags: list[str] | None = None,
    ) -> list[dict[str, Any]]:
        """Search memory (recall).

        Args:
            query: Query text passed to ``brain.recall()``.
            thread_id: Optional thread ID used to pick the bank.
            max_results: Forwarded to ``brain.recall()`` as ``max_results``.
            tags: Optional tags forwarded to ``brain.recall()``.

        Returns:
            One dict per hit with 'text', 'score', 'metadata', and
            'memory_id', in the order the hits were returned by recall.
        """
        bank = self._resolve_bank(thread_id)
        result = await self.brain.recall(
            query,
            bank_id=bank,
            max_results=max_results,
            tags=tags,
            context=self._context,
        )
        return [
            {
                "text": hit.text,
                "score": hit.score,
                "metadata": hit.metadata,
                "memory_id": hit.memory_id,
            }
            for hit in result.hits
        ]

    async def load_memory_variables(
        self,
        inputs: dict[str, Any],
        *,
        thread_id: str | None = None,
    ) -> dict[str, str]:
        """Load relevant memories for the current input (LangChain pattern).

        The query is the space-joined ``str()`` of every input value; the
        search runs with ``search()``'s default ``max_results``.

        Returns:
            ``{"memory": "formatted memories"}`` for injection into prompts,
            one ``"- text"`` line per hit. The value is ``""`` when the inputs
            are empty / blank or when nothing is found.
        """
        # Treat missing inputs like an empty mapping instead of raising.
        if not inputs:
            return {"memory": ""}

        query = " ".join(str(v) for v in inputs.values())
        if not query.strip():
            return {"memory": ""}

        hits = await self.search(query, thread_id=thread_id)
        if not hits:
            return {"memory": ""}

        formatted = "\n".join(f"- {h['text']}" for h in hits)
        return {"memory": formatted}

    # Sync wrappers for frameworks that don't support async.
    # They delegate to ``_run_async_from_sync``; how that helper drives the
    # coroutine is defined in ``astrocyte.integrations._sync_utils``.
    def save_context_sync(
        self,
        inputs: dict[str, Any],
        outputs: dict[str, Any],
        **kwargs: Any,
    ) -> None:
        """Synchronous wrapper for save_context."""
        _run_async_from_sync(self.save_context(inputs, outputs, **kwargs))

    def search_sync(self, query: str, **kwargs: Any) -> list[dict[str, Any]]:
        """Synchronous wrapper for search."""
        return _run_async_from_sync(self.search(query, **kwargs))

    def load_memory_variables_sync(
        self,
        inputs: dict[str, Any],
        **kwargs: Any,
    ) -> dict[str, str]:
        """Synchronous wrapper for load_memory_variables."""
        return _run_async_from_sync(self.load_memory_variables(inputs, **kwargs))