Coverage for astrocyte/integrations/claude_agent_sdk.py: 51%

67 statements  

« prev     ^ index     » next       coverage.py v7.15.0, created at 2026-07-04 05:24 +0000

1"""Claude Agent SDK integration — Astrocyte memory as native SDK tools. 

2 

3Usage: 

4 from astrocyte import Astrocyte 

5 from astrocyte.integrations.claude_agent_sdk import astrocyte_claude_agent_server 

6 

7 brain = Astrocyte.from_config("astrocyte.yaml") 

8 memory_server = astrocyte_claude_agent_server(brain, bank_id="user-123") 

9 

10 # Use with Claude Agent SDK 

11 from claude_agent_sdk import query, ClaudeAgentOptions 

12 

13 options = ClaudeAgentOptions( 

14 mcp_servers={"memory": memory_server}, 

15 allowed_tools=["mcp__memory__*"], 

16 ) 

17 

18 async for message in query(prompt="What do you remember?", options=options): 

19 ... 

20 

21The Claude Agent SDK uses @tool + create_sdk_mcp_server. Tools receive 

22args as dict[str, Any] and return {"content": [{"type": "text", "text": ...}]}. 

23Tools are registered as in-process MCP servers — no separate process needed. 

24 

25This integration provides TWO options: 

261. astrocyte_claude_agent_server() — returns an SDK MCP server (requires claude_agent_sdk installed) 

272. astrocyte_claude_agent_tools() — returns tool definition dicts (no SDK dependency, for testing) 

28""" 

29 

30from __future__ import annotations 

31 

32import json 

33from typing import TYPE_CHECKING, Any 

34 

35if TYPE_CHECKING: 

36 from astrocyte._astrocyte import Astrocyte 

37 from astrocyte.types import AstrocyteContext 

38 

39 

40# --------------------------------------------------------------------------- 

41# Tool definitions as plain dicts (no SDK dependency) 

42# --------------------------------------------------------------------------- 

43 

44 

def astrocyte_claude_agent_tools(
    brain: Astrocyte,
    bank_id: str,
    *,
    include_reflect: bool = True,
    include_forget: bool = False,
    context: AstrocyteContext | None = None,
) -> list[dict[str, Any]]:
    """Create Claude Agent SDK tool definitions as plain dicts.

    Each dict has the keys ``name``, ``description``, ``input_schema`` and
    ``handler``. A handler is an ``async`` callable that receives the tool
    arguments as ``dict[str, Any]`` and returns
    ``{"content": [{"type": "text", "text": "..."}]}``.

    This works WITHOUT claude_agent_sdk installed (for testing).
    To get an actual SDK MCP server, use astrocyte_claude_agent_server().

    Args:
        brain: Object providing the async ``retain``, ``recall``, ``reflect``
            and ``forget`` operations the handlers delegate to.
        bank_id: Memory bank every handler operates on.
        include_reflect: Add the ``memory_reflect`` tool when true.
        include_forget: Add the ``memory_forget`` tool when true.
        context: Forwarded unchanged to every ``brain`` call.

    Returns:
        Tool definitions in a fixed order: ``memory_retain``,
        ``memory_recall``, then ``memory_reflect`` and ``memory_forget`` when
        enabled.
    """

    def _split_csv(raw: Any) -> list[str]:
        """Return the stripped, non-empty items of a comma-separated string or a list."""
        if isinstance(raw, str):
            pieces = raw.split(",")
        elif isinstance(raw, (list, tuple)):
            pieces = [str(item) for item in raw]
        else:
            return []
        # Drop blanks so "a,,b", a trailing comma or " " never yield "" entries.
        return [piece.strip() for piece in pieces if piece.strip()]

    def _text_result(text: str) -> dict[str, Any]:
        """Wrap *text* in the single-text-block response shape the handlers return."""
        return {"content": [{"type": "text", "text": text}]}

    tools: list[dict[str, Any]] = []

    # ── memory_retain ──

    async def retain_handler(args: dict[str, Any]) -> dict[str, Any]:
        content = args["content"]
        # No usable tag is passed on as None, never as an empty list.
        tag_list = _split_csv(args.get("tags")) or None
        result = await brain.retain(content, bank_id=bank_id, tags=tag_list, context=context)
        return _text_result(json.dumps({"stored": result.stored, "memory_id": result.memory_id}))

    tools.append(
        {
            "name": "memory_retain",
            "description": (
                "Store content into long-term memory for future recall. "
                "Optionally pass comma-separated tags for categorization."
            ),
            # "tags" is advertised because the handler reads it.
            "input_schema": {"content": str, "tags": str},
            "handler": retain_handler,
        }
    )

    # ── memory_recall ──

    async def recall_handler(args: dict[str, Any]) -> dict[str, Any]:
        query = args["query"]
        max_results = args.get("max_results", 5)
        result = await brain.recall(query, bank_id=bank_id, max_results=max_results, context=context)
        hits = [{"text": h.text, "score": round(h.score, 4)} for h in result.hits]
        return _text_result(json.dumps({"hits": hits, "total": result.total_available}))

    tools.append(
        {
            "name": "memory_recall",
            "description": "Search long-term memory for information relevant to a query.",
            "input_schema": {"query": str, "max_results": int},
            "handler": recall_handler,
        }
    )

    # ── memory_reflect ──

    if include_reflect:

        async def reflect_handler(args: dict[str, Any]) -> dict[str, Any]:
            result = await brain.reflect(args["query"], bank_id=bank_id, context=context)
            # The answer is returned as plain text, not JSON-encoded.
            return _text_result(result.answer)

        tools.append(
            {
                "name": "memory_reflect",
                "description": "Synthesize a comprehensive answer from long-term memory.",
                "input_schema": {"query": str},
                "handler": reflect_handler,
            }
        )

    # ── memory_forget ──

    if include_forget:

        async def forget_handler(args: dict[str, Any]) -> dict[str, Any]:
            memory_ids = _split_csv(args["memory_ids"])
            if not memory_ids:
                # Nothing identifiable to delete: never issue a forget call
                # without explicit IDs.
                return _text_result(json.dumps({"deleted_count": 0}))
            result = await brain.forget(bank_id, memory_ids=memory_ids, context=context)
            return _text_result(json.dumps({"deleted_count": result.deleted_count}))

        tools.append(
            {
                "name": "memory_forget",
                "description": "Remove specific memories by their IDs.",
                "input_schema": {"memory_ids": str},
                "handler": forget_handler,
            }
        )

    return tools

167 

168 

169# --------------------------------------------------------------------------- 

170# SDK MCP server (requires claude_agent_sdk installed) 

171# --------------------------------------------------------------------------- 

172 

173 

def astrocyte_claude_agent_server(
    brain: Astrocyte,
    bank_id: str,
    *,
    server_name: str = "astrocyte_memory",
    include_reflect: bool = True,
    include_forget: bool = False,
    context: AstrocyteContext | None = None,
) -> Any:
    """Create a Claude Agent SDK in-process MCP server backed by Astrocyte.

    Requires claude_agent_sdk to be installed.
    Returns an MCP server that can be passed to ClaudeAgentOptions.mcp_servers.

    Usage:
        memory_server = astrocyte_claude_agent_server(brain, bank_id="user-123")
        options = ClaudeAgentOptions(
            mcp_servers={"memory": memory_server},
            allowed_tools=["mcp__memory__*"],
        )

    For session-scoped memory with Managed Agents, see
    :mod:`astrocyte.integrations.managed_agents`.

    Args:
        brain: Object providing the async ``retain``, ``recall``, ``reflect``
            and ``forget`` operations the tools delegate to.
        bank_id: Memory bank every tool operates on.
        server_name: Name given to the created MCP server.
        include_reflect: Register the ``memory_reflect`` tool when true.
        include_forget: Register the ``memory_forget`` tool when true.
        context: Forwarded unchanged to every ``brain`` call.

    Returns:
        Whatever ``create_sdk_mcp_server`` returns for the registered tools.

    Raises:
        ImportError: If ``claude_agent_sdk`` is not installed.
    """
    # Function-scope import: the SDK is only needed when a server is built.
    from claude_agent_sdk import ToolAnnotations, create_sdk_mcp_server, tool

    def _split_csv(raw: Any) -> list[str]:
        """Return the stripped, non-empty items of a comma-separated string or a list."""
        if isinstance(raw, str):
            pieces = raw.split(",")
        elif isinstance(raw, (list, tuple)):
            pieces = [str(item) for item in raw]
        else:
            return []
        # Drop blanks so "a,,b", a trailing comma or " " never yield "" entries.
        return [piece.strip() for piece in pieces if piece.strip()]

    def _text_result(text: str) -> dict[str, Any]:
        """Wrap *text* in the single-text-block response shape the tools return."""
        return {"content": [{"type": "text", "text": text}]}

    sdk_tools = []

    @tool(
        "memory_retain",
        "Store content into long-term memory for future recall. "
        "Optionally pass comma-separated tags for categorization.",
        {"content": str, "tags": str},
    )
    async def memory_retain(args: dict[str, Any]) -> dict[str, Any]:
        content = args["content"]
        # No usable tag is passed on as None, never as an empty list.
        tag_list = _split_csv(args.get("tags")) or None
        result = await brain.retain(content, bank_id=bank_id, tags=tag_list, context=context)
        return _text_result(json.dumps({"stored": result.stored, "memory_id": result.memory_id}))

    sdk_tools.append(memory_retain)

    @tool(
        "memory_recall",
        "Search long-term memory for information relevant to a query.",
        {"query": str, "max_results": int},
        annotations=ToolAnnotations(readOnlyHint=True),
    )
    async def memory_recall(args: dict[str, Any]) -> dict[str, Any]:
        query_text = args["query"]
        max_results = args.get("max_results", 5)
        result = await brain.recall(query_text, bank_id=bank_id, max_results=max_results, context=context)
        hits = [{"text": h.text, "score": round(h.score, 4)} for h in result.hits]
        return _text_result(json.dumps({"hits": hits, "total": result.total_available}))

    sdk_tools.append(memory_recall)

    if include_reflect:

        @tool(
            "memory_reflect",
            "Synthesize a comprehensive answer from long-term memory. "
            "Use this instead of recall when you need a narrative answer rather than raw hits.",
            {"query": str},
            annotations=ToolAnnotations(readOnlyHint=True),
        )
        async def memory_reflect(args: dict[str, Any]) -> dict[str, Any]:
            result = await brain.reflect(args["query"], bank_id=bank_id, context=context)
            # The answer is returned as plain text, not JSON-encoded.
            return _text_result(result.answer)

        sdk_tools.append(memory_reflect)

    if include_forget:

        @tool(
            "memory_forget",
            "Remove specific memories by their IDs (comma-separated).",
            {"memory_ids": str},
            annotations=ToolAnnotations(destructiveHint=True),
        )
        async def memory_forget(args: dict[str, Any]) -> dict[str, Any]:
            ids = _split_csv(args["memory_ids"])
            if not ids:
                # Nothing identifiable to delete: never issue a forget call
                # without explicit IDs.
                return _text_result(json.dumps({"deleted_count": 0}))
            result = await brain.forget(bank_id, memory_ids=ids, context=context)
            return _text_result(json.dumps({"deleted_count": result.deleted_count}))

        sdk_tools.append(memory_forget)

    return create_sdk_mcp_server(
        name=server_name,
        version="1.0.0",
        tools=sdk_tools,
    )

268 )