Coverage for astrocyte/integrations/langgraph.py: 92%
50 statements
« prev ^ index » next coverage.py v7.15.0, created at 2026-07-04 05:24 +0000
« prev ^ index » next coverage.py v7.15.0, created at 2026-07-04 05:24 +0000
"""LangGraph / LangChain integration — Astrocyte as a memory store.

Usage:
    from astrocyte import Astrocyte
    from astrocyte.integrations.langgraph import AstrocyteMemory

    brain = Astrocyte.from_config("astrocyte.yaml")
    memory = AstrocyteMemory(brain, bank_id="user-123")
    # Optional: memory = AstrocyteMemory(..., context=AstrocyteContext(principal="user:me"))

    # Call from your graph nodes / chains (async):
    await memory.save_context({"input": question}, {"output": answer}, thread_id=tid)
    variables = await memory.load_memory_variables({"input": question}, thread_id=tid)
    # variables["memory"] is a "- "-prefixed bullet list of recalled texts

Note: ``AstrocyteMemory`` is not a LangGraph checkpointer (it does not
implement ``BaseCheckpointSaver``), so do not pass it to
``graph.compile(checkpointer=...)``. Call its methods from your nodes instead.

Maps:
    - save_context → brain.retain()
    - search / load_memory_variables → brain.recall()
    - Thread ID → bank ID (configurable via thread_to_bank)

Synchronous ``*_sync`` wrappers are provided for frameworks without async support.
"""

from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    from astrocyte._astrocyte import Astrocyte

from astrocyte.integrations._sync_utils import _run_async_from_sync
from astrocyte.types import AstrocyteContext

logger = logging.getLogger("astrocyte.integrations.langgraph")
class AstrocyteMemory:
    """Astrocyte-backed memory for LangGraph / LangChain agents.

    Exposes a LangChain-style memory interface (``save_context()``,
    ``load_memory_variables()`` and ``search()``) plus synchronous ``*_sync``
    wrappers. Storage goes through ``brain.retain()`` and lookups through
    ``brain.recall()``.

    Pass optional ``context`` (:class:`~astrocyte.types.AstrocyteContext`) for
    access control and OBO when ``access_control`` is enabled. It is forwarded
    unchanged on every retain/recall call.

    This is a thin wrapper — all policy enforcement happens inside Astrocyte.

    Args:
        brain: The Astrocyte instance that stores and recalls memories.
        bank_id: Default bank, used when a thread ID has no explicit mapping.
        context: Optional caller context forwarded to every brain call.
        auto_retain: Stored on the instance but not read by any method of
            this class (NOTE(review): presumably consumed elsewhere — confirm).
        auto_retain_filter: Stored on the instance; likewise not read here.
        thread_to_bank: Optional mapping of thread ID → bank ID. The dict is
            kept by reference, so entries the caller adds later are honoured.
    """

    def __init__(
        self,
        brain: Astrocyte,
        bank_id: str,
        *,
        context: AstrocyteContext | None = None,
        auto_retain: bool = False,
        auto_retain_filter: str | None = None,
        thread_to_bank: dict[str, str] | None = None,
    ) -> None:
        self.brain = brain
        self.bank_id = bank_id
        self._context = context
        self.auto_retain = auto_retain
        self.auto_retain_filter = auto_retain_filter
        # ``is not None`` instead of ``or``: an empty dict supplied by the
        # caller must be kept by reference just like a non-empty one;
        # otherwise mappings added after construction are silently ignored.
        self._thread_to_bank: dict[str, str] = (
            thread_to_bank if thread_to_bank is not None else {}
        )

    def _resolve_bank(self, thread_id: str | None = None) -> str:
        """Return the bank mapped to *thread_id*, or the default ``bank_id``.

        ``None`` and the empty string always resolve to the default bank.
        """
        if thread_id and thread_id in self._thread_to_bank:
            return self._thread_to_bank[thread_id]
        return self.bank_id

    async def save_context(
        self,
        inputs: dict[str, Any],
        outputs: dict[str, Any],
        *,
        thread_id: str | None = None,
        tags: list[str] | None = None,
    ) -> None:
        """Save one interaction to memory via ``brain.retain()``.

        Every ``key: value`` pair from *inputs*, then from *outputs*, becomes
        one line of a single memory entry. Nothing is stored when both are
        empty (or ``None``).

        Args:
            inputs: Input variables of the interaction.
            outputs: Output variables of the interaction.
            thread_id: Selects the bank via ``thread_to_bank``. It is also
                recorded in the entry metadata (``""`` when absent).
            tags: Tags for the entry. ``None`` or an empty list falls back
                to ``["langgraph"]``.

        A retain result whose ``stored`` is falsy is logged as a warning, not
        raised.
        """
        bank = self._resolve_bank(thread_id)

        # Build a concise, line-per-field representation of the interaction.
        # ``or {}`` keeps tolerating ``None`` for either mapping.
        parts = [
            f"{key}: {value}"
            for mapping in (inputs or {}, outputs or {})
            for key, value in mapping.items()
        ]
        if not parts:
            return

        result = await self.brain.retain(
            "\n".join(parts),
            bank_id=bank,
            tags=tags or ["langgraph"],
            metadata={"source": "langgraph", "thread_id": thread_id or ""},
            context=self._context,
        )
        if not result.stored:
            logger.warning("LangGraph save_context failed for bank %s: %s", bank, result.error)

    async def search(
        self,
        query: str,
        *,
        thread_id: str | None = None,
        max_results: int = 5,
        tags: list[str] | None = None,
    ) -> list[dict[str, Any]]:
        """Search memory via ``brain.recall()``.

        Args:
            query: Free-text recall query.
            thread_id: Selects the bank via ``thread_to_bank``.
            max_results: Upper bound on hits, forwarded to ``recall()``.
            tags: Optional tag filter, forwarded to ``recall()``.

        Returns:
            One dict per hit with keys ``text``, ``score``, ``metadata`` and
            ``memory_id``, in the order ``recall()`` returned them.
        """
        bank = self._resolve_bank(thread_id)
        result = await self.brain.recall(
            query,
            bank_id=bank,
            max_results=max_results,
            tags=tags,
            context=self._context,
        )
        return [
            {
                "text": hit.text,
                "score": hit.score,
                "metadata": hit.metadata,
                "memory_id": hit.memory_id,
            }
            for hit in result.hits
        ]

    async def load_memory_variables(
        self,
        inputs: dict[str, Any],
        *,
        thread_id: str | None = None,
        max_results: int = 5,
    ) -> dict[str, str]:
        """Load memories relevant to the current input (LangChain pattern).

        The query is the space-joined ``str()`` of every input value.

        Args:
            inputs: Current input variables (``None`` is treated as empty).
            thread_id: Selects the bank via ``thread_to_bank``.
            max_results: Upper bound on recalled memories (default 5).

        Returns:
            ``{"memory": formatted}``, where *formatted* is one ``- <text>``
            line per hit, or ``""`` when the query is blank or nothing matched.
        """
        query = " ".join(str(v) for v in (inputs or {}).values())
        if not query.strip():
            return {"memory": ""}

        hits = await self.search(query, thread_id=thread_id, max_results=max_results)
        if not hits:
            return {"memory": ""}

        formatted = "\n".join(f"- {h['text']}" for h in hits)
        return {"memory": formatted}

    # Sync wrappers for frameworks that don't support async
    def save_context_sync(
        self, inputs: dict[str, Any], outputs: dict[str, Any], **kwargs: Any
    ) -> None:
        """Synchronous wrapper for :meth:`save_context`."""
        _run_async_from_sync(self.save_context(inputs, outputs, **kwargs))

    def search_sync(self, query: str, **kwargs: Any) -> list[dict[str, Any]]:
        """Synchronous wrapper for :meth:`search`."""
        return _run_async_from_sync(self.search(query, **kwargs))

    def load_memory_variables_sync(self, inputs: dict[str, Any], **kwargs: Any) -> dict[str, str]:
        """Synchronous wrapper for :meth:`load_memory_variables`."""
        return _run_async_from_sync(self.load_memory_variables(inputs, **kwargs))