Coverage for astrocyte/integrations/claude_agent_sdk.py: 51%
67 statements
« prev ^ index » next coverage.py v7.15.0, created at 2026-07-04 05:24 +0000
« prev ^ index » next coverage.py v7.15.0, created at 2026-07-04 05:24 +0000
"""Claude Agent SDK integration — Astrocyte memory as native SDK tools.

Usage:
    from astrocyte import Astrocyte
    from astrocyte.integrations.claude_agent_sdk import astrocyte_claude_agent_server

    brain = Astrocyte.from_config("astrocyte.yaml")
    memory_server = astrocyte_claude_agent_server(brain, bank_id="user-123")

    # Use with Claude Agent SDK
    from claude_agent_sdk import query, ClaudeAgentOptions

    options = ClaudeAgentOptions(
        mcp_servers={"memory": memory_server},
        allowed_tools=["mcp__memory__*"],
    )

    async for message in query(prompt="What do you remember?", options=options):
        ...

The Claude Agent SDK uses @tool + create_sdk_mcp_server. Tools receive
args as dict[str, Any] and return {"content": [{"type": "text", "text": ...}]}.
Tools are registered as in-process MCP servers — no separate process needed.

This integration provides TWO options:

1. astrocyte_claude_agent_server() — returns an SDK MCP server (requires claude_agent_sdk installed)
2. astrocyte_claude_agent_tools() — returns tool definition dicts (no SDK dependency, for testing)
"""
30from __future__ import annotations
32import json
33from typing import TYPE_CHECKING, Any
35if TYPE_CHECKING:
36 from astrocyte._astrocyte import Astrocyte
37 from astrocyte.types import AstrocyteContext
# ---------------------------------------------------------------------------
# Tool definitions as plain dicts (no SDK dependency)
# ---------------------------------------------------------------------------


def astrocyte_claude_agent_tools(
    brain: Astrocyte,
    bank_id: str,
    *,
    include_reflect: bool = True,
    include_forget: bool = False,
    context: AstrocyteContext | None = None,
) -> list[dict[str, Any]]:
    """Create Claude Agent SDK tool definitions as plain dicts.

    Each dict has: name, description, input_schema, handler.
    The handler follows the SDK pattern: receives dict[str, Any], returns
    {"content": [{"type": "text", "text": "..."}]}.

    This works WITHOUT claude_agent_sdk installed (for testing).
    To get an actual SDK MCP server, use astrocyte_claude_agent_server().

    Args:
        brain: Astrocyte instance; the handlers await its ``retain``,
            ``recall``, ``reflect`` and ``forget`` coroutines.
        bank_id: Memory bank that every tool call is scoped to.
        include_reflect: Append the ``memory_reflect`` tool (default True).
        include_forget: Append the ``memory_forget`` tool (default False).
        context: Forwarded unchanged to every ``brain`` call.

    Returns:
        Tool dicts in order: memory_retain, memory_recall, then
        memory_reflect and memory_forget when enabled.
    """

    def _split_csv(raw: str) -> list[str]:
        # "a, ,b," -> ["a", "b"]: strip whitespace and drop empty entries so a
        # stray or trailing comma never yields an empty tag or memory ID.
        return [part.strip() for part in raw.split(",") if part.strip()]

    def _text_result(text: str) -> dict[str, Any]:
        # The SDK tool-result envelope: a single text content item.
        return {"content": [{"type": "text", "text": text}]}

    tools: list[dict[str, Any]] = []

    # ── memory_retain ──

    async def retain_handler(args: dict[str, Any]) -> dict[str, Any]:
        tags = args.get("tags")
        # Tags arrive as one comma-separated string. A non-string value, or a
        # string with no real entries, means "no tags".
        tag_list = (_split_csv(tags) or None) if isinstance(tags, str) else None
        result = await brain.retain(args["content"], bank_id=bank_id, tags=tag_list, context=context)
        return _text_result(json.dumps({"stored": result.stored, "memory_id": result.memory_id}))

    tools.append(
        {
            "name": "memory_retain",
            "description": "Store content into long-term memory for future recall.",
            # Advertise the optional ``tags`` argument the handler accepts.
            "input_schema": {"content": str, "tags": str},
            "handler": retain_handler,
        }
    )

    # ── memory_recall ──

    async def recall_handler(args: dict[str, Any]) -> dict[str, Any]:
        max_results = args.get("max_results")
        if max_results is None:
            # Treat an omitted limit and an explicit null the same way.
            max_results = 5
        result = await brain.recall(
            args["query"], bank_id=bank_id, max_results=max_results, context=context
        )
        hits = [{"text": h.text, "score": round(h.score, 4)} for h in result.hits]
        return _text_result(json.dumps({"hits": hits, "total": result.total_available}))

    tools.append(
        {
            "name": "memory_recall",
            "description": "Search long-term memory for information relevant to a query.",
            "input_schema": {"query": str, "max_results": int},
            "handler": recall_handler,
        }
    )

    # ── memory_reflect ──

    if include_reflect:

        async def reflect_handler(args: dict[str, Any]) -> dict[str, Any]:
            result = await brain.reflect(args["query"], bank_id=bank_id, context=context)
            return _text_result(result.answer)

        tools.append(
            {
                "name": "memory_reflect",
                "description": "Synthesize a comprehensive answer from long-term memory.",
                "input_schema": {"query": str},
                "handler": reflect_handler,
            }
        )

    # ── memory_forget ──

    if include_forget:

        async def forget_handler(args: dict[str, Any]) -> dict[str, Any]:
            raw_ids = args["memory_ids"]
            # A comma-separated string is parsed; any other value (e.g. a list)
            # is passed through as before.
            memory_ids = _split_csv(raw_ids) if isinstance(raw_ids, str) else raw_ids
            if not memory_ids:
                # Never reach the destructive backend call without concrete IDs;
                # how brain.forget treats an empty list is not visible here.
                return _text_result(json.dumps({"deleted_count": 0}))
            result = await brain.forget(bank_id, memory_ids=memory_ids, context=context)
            return _text_result(json.dumps({"deleted_count": result.deleted_count}))

        tools.append(
            {
                "name": "memory_forget",
                "description": "Remove specific memories by their IDs.",
                "input_schema": {"memory_ids": str},
                "handler": forget_handler,
            }
        )

    return tools
# ---------------------------------------------------------------------------
# SDK MCP server (requires claude_agent_sdk installed)
# ---------------------------------------------------------------------------


def astrocyte_claude_agent_server(
    brain: Astrocyte,
    bank_id: str,
    *,
    server_name: str = "astrocyte_memory",
    include_reflect: bool = True,
    include_forget: bool = False,
    context: AstrocyteContext | None = None,
) -> Any:
    """Create a Claude Agent SDK in-process MCP server backed by Astrocyte.

    Requires claude_agent_sdk to be installed.
    Returns an MCP server that can be passed to ClaudeAgentOptions.mcp_servers.

    Usage:
        memory_server = astrocyte_claude_agent_server(brain, bank_id="user-123")
        options = ClaudeAgentOptions(
            mcp_servers={"memory": memory_server},
            allowed_tools=["mcp__memory__*"],
        )

    For session-scoped memory with Managed Agents, see
    :mod:`astrocyte.integrations.managed_agents`.

    Args:
        brain: Astrocyte instance; the tools await its ``retain``, ``recall``,
            ``reflect`` and ``forget`` coroutines.
        bank_id: Memory bank that every tool call is scoped to.
        server_name: Name passed to ``create_sdk_mcp_server``.
        include_reflect: Register the read-only ``memory_reflect`` tool.
        include_forget: Register the destructive ``memory_forget`` tool.
        context: Forwarded unchanged to every ``brain`` call.

    Raises:
        ImportError: If ``claude_agent_sdk`` is not installed (imported lazily here).
    """
    from claude_agent_sdk import ToolAnnotations, create_sdk_mcp_server, tool

    def _split_csv(raw: str) -> list[str]:
        # "a, ,b," -> ["a", "b"]: strip whitespace and drop empty entries so a
        # stray or trailing comma never yields an empty tag or memory ID.
        return [part.strip() for part in raw.split(",") if part.strip()]

    def _text_result(text: str) -> dict[str, Any]:
        # The SDK tool-result envelope: a single text content item.
        return {"content": [{"type": "text", "text": text}]}

    sdk_tools = []

    # Full JSON Schema (not the {"name": type} shorthand) is used where an
    # argument is optional: the shorthand lists every key as required, which
    # would force the model to always send ``tags`` / ``max_results``.
    @tool(
        "memory_retain",
        "Store content into long-term memory for future recall. "
        "Optionally pass comma-separated tags for categorization.",
        {
            "type": "object",
            "properties": {
                "content": {"type": "string"},
                "tags": {"type": "string"},
            },
            "required": ["content"],
        },
    )
    async def memory_retain(args: dict[str, Any]) -> dict[str, Any]:
        tags = args.get("tags")
        # A non-string value, or a string with no real entries, means "no tags".
        tag_list = (_split_csv(tags) or None) if isinstance(tags, str) else None
        result = await brain.retain(args["content"], bank_id=bank_id, tags=tag_list, context=context)
        return _text_result(json.dumps({"stored": result.stored, "memory_id": result.memory_id}))

    sdk_tools.append(memory_retain)

    @tool(
        "memory_recall",
        "Search long-term memory for information relevant to a query.",
        {
            "type": "object",
            "properties": {
                "query": {"type": "string"},
                "max_results": {"type": "integer"},
            },
            "required": ["query"],
        },
        annotations=ToolAnnotations(readOnlyHint=True),
    )
    async def memory_recall(args: dict[str, Any]) -> dict[str, Any]:
        max_results = args.get("max_results")
        if max_results is None:
            # Treat an omitted limit and an explicit null the same way.
            max_results = 5
        result = await brain.recall(
            args["query"], bank_id=bank_id, max_results=max_results, context=context
        )
        hits = [{"text": h.text, "score": round(h.score, 4)} for h in result.hits]
        return _text_result(json.dumps({"hits": hits, "total": result.total_available}))

    sdk_tools.append(memory_recall)

    if include_reflect:

        @tool(
            "memory_reflect",
            "Synthesize a comprehensive answer from long-term memory. "
            "Use this instead of recall when you need a narrative answer rather than raw hits.",
            {"query": str},
            annotations=ToolAnnotations(readOnlyHint=True),
        )
        async def memory_reflect(args: dict[str, Any]) -> dict[str, Any]:
            result = await brain.reflect(args["query"], bank_id=bank_id, context=context)
            return _text_result(result.answer)

        sdk_tools.append(memory_reflect)

    if include_forget:

        @tool(
            "memory_forget",
            "Remove specific memories by their IDs (comma-separated).",
            {"memory_ids": str},
            annotations=ToolAnnotations(destructiveHint=True),
        )
        async def memory_forget(args: dict[str, Any]) -> dict[str, Any]:
            raw_ids = args["memory_ids"]
            # A comma-separated string is parsed; any other value (e.g. a list)
            # is passed through instead of crashing on ``.split``.
            ids = _split_csv(raw_ids) if isinstance(raw_ids, str) else raw_ids
            if not ids:
                # Never reach the destructive backend call without concrete IDs;
                # how brain.forget treats an empty list is not visible here.
                return _text_result(json.dumps({"deleted_count": 0}))
            result = await brain.forget(bank_id, memory_ids=ids, context=context)
            return _text_result(json.dumps({"deleted_count": result.deleted_count}))

        sdk_tools.append(memory_forget)

    return create_sdk_mcp_server(
        name=server_name,
        version="1.0.0",
        tools=sdk_tools,
    )