Coverage for astrocyte/integrations/smolagents.py: 74%

43 statements  

« prev     ^ index     » next       coverage.py v7.15.0, created at 2026-07-04 05:24 +0000

1"""Smolagents (HuggingFace) integration — Astrocyte as @tool functions. 

2 

3Usage: 

4 from astrocyte import Astrocyte 

5 from astrocyte.integrations.smolagents import astrocyte_smolagent_tools 

6 

7 brain = Astrocyte.from_config("astrocyte.yaml") 

8 tools = astrocyte_smolagent_tools(brain, bank_id="user-123") 

9 

10 # Use with smolagents 

11 from smolagents import CodeAgent, HfApiModel 

12 agent = CodeAgent(tools=tools, model=HfApiModel()) 

13 

14Smolagents uses a code-centric approach where tools are plain Python functions 

15with type annotations and docstrings. The agent writes Python code that calls 

16these functions. 

17""" 

18 

19from __future__ import annotations 

20 

21from typing import TYPE_CHECKING, Any 

22 

23if TYPE_CHECKING: 

24 from astrocyte._astrocyte import Astrocyte 

25 

26from astrocyte.types import AstrocyteContext 

27 

28 

class AstrocyteSmolTool:
    """A single Astrocyte tool compatible with smolagents' Tool protocol.

    Smolagents expects tools with: name, description, inputs (schema), output_type,
    and a __call__ or forward method.

    Both entry points pass positional as well as keyword arguments through to
    the wrapped function, because agent-written code may call a tool either
    way (``tool("text")`` or ``tool(content="text")``).
    """

    def __init__(
        self,
        name: str,
        description: str,
        inputs: dict[str, dict[str, str]],
        output_type: str,
        fn: Any,
    ) -> None:
        """Store the tool metadata and the function that implements the tool.

        Args:
            name: Identifier under which the tool is exposed.
            description: Human-readable summary of what the tool does.
            inputs: Schema mapping each argument name to a dict of string
                fields describing that argument.
            output_type: Name of the type of value the tool produces.
            fn: Callable that returns an awaitable; it receives the tool's
                arguments unchanged.
        """
        self.name = name
        self.description = description
        self.inputs = inputs
        self.output_type = output_type
        self._fn = fn

    async def forward(self, *args: Any, **kwargs: Any) -> Any:
        """Execute the tool (async).

        Args:
            *args: Positional arguments forwarded to the wrapped function.
            **kwargs: Keyword arguments forwarded to the wrapped function.

        Returns:
            Whatever the awaited wrapped function returns.
        """
        return await self._fn(*args, **kwargs)

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Sync fallback — smolagents may call tools synchronously.

        Args:
            *args: Positional arguments forwarded to the wrapped function.
            **kwargs: Keyword arguments forwarded to the wrapped function.

        Returns:
            The value produced by ``_run_async_from_sync`` for the wrapped
            function's awaitable.
        """
        # Imported at call time rather than at module import time.
        from astrocyte.integrations._sync_utils import _run_async_from_sync

        return _run_async_from_sync(self._fn(*args, **kwargs))

59 

60 

def astrocyte_smolagent_tools(
    brain: Astrocyte,
    bank_id: str,
    *,
    context: AstrocyteContext | None = None,
    include_reflect: bool = True,
    include_forget: bool = False,
) -> list[AstrocyteSmolTool]:
    """Create smolagents-compatible tools backed by Astrocyte.

    Returns a list of AstrocyteSmolTool instances that implement the
    smolagents Tool protocol (name, description, inputs, output_type, forward).

    Args:
        brain: Astrocyte instance whose retain/recall/reflect/forget
            coroutines back the tools.
        bank_id: Memory bank every tool operates on.
        context: Optional context passed unchanged to every Astrocyte call.
        include_reflect: When true, add the ``memory_reflect`` tool.
        include_forget: When true, add the ``memory_forget`` tool.

    Returns:
        The tools in the order ``memory_retain``, ``memory_recall``, then
        ``memory_reflect`` and ``memory_forget`` if enabled.
    """
    tools: list[AstrocyteSmolTool] = []

    async def _retain(content: str, tags: str = "") -> str:
        # Tags arrive as one comma-separated string; blank entries are dropped
        # and a missing/empty string is passed on as None.
        tag_list = [t.strip() for t in tags.split(",") if t.strip()] if tags else None
        result = await brain.retain(content, bank_id=bank_id, tags=tag_list, context=context)
        if result.stored:
            return f"Stored memory with id: {result.memory_id}"
        return f"Failed to store: {result.error}"

    tools.append(
        AstrocyteSmolTool(
            name="memory_retain",
            description="Store content into long-term memory for future recall.",
            inputs={
                "content": {"type": "string", "description": "The text to memorize."},
                "tags": {"type": "string", "description": "Comma-separated tags (optional)."},
            },
            output_type="string",
            fn=_retain,
        )
    )

    async def _recall(query: str, max_results: int = 5) -> str:
        result = await brain.recall(query, bank_id=bank_id, max_results=max_results, context=context)
        if not result.hits:
            return "No relevant memories found."
        # One bullet per hit: score with two decimals, then the hit text.
        lines = [f"- [{h.score:.2f}] {h.text}" for h in result.hits]
        return f"Found {len(result.hits)} memories:\n" + "\n".join(lines)

    tools.append(
        AstrocyteSmolTool(
            name="memory_recall",
            description="Search long-term memory for information relevant to a query.",
            inputs={
                "query": {"type": "string", "description": "Natural language search query."},
                "max_results": {"type": "integer", "description": "Maximum number of results."},
            },
            output_type="string",
            fn=_recall,
        )
    )

    if include_reflect:

        async def _reflect(query: str) -> str:
            result = await brain.reflect(query, bank_id=bank_id, context=context)
            return result.answer

        tools.append(
            AstrocyteSmolTool(
                name="memory_reflect",
                description="Synthesize a comprehensive answer from long-term memory.",
                inputs={
                    "query": {"type": "string", "description": "The question to answer."},
                },
                output_type="string",
                fn=_reflect,
            )
        )

    if include_forget:

        async def _forget(memory_ids: str) -> str:
            # Drop blank entries (trailing commas, doubled commas, blank input)
            # so an empty string is never sent as a memory ID; this matches
            # how _retain filters its tag list.
            ids = [mid.strip() for mid in memory_ids.split(",") if mid.strip()]
            if not ids:
                # No usable ID was supplied, so there is nothing to delete.
                # Skipping the call avoids relying on how forget() treats an
                # empty ID list, which is not visible from here.
                return "Deleted 0 memories."
            result = await brain.forget(bank_id, memory_ids=ids, context=context)
            return f"Deleted {result.deleted_count} memories."

        tools.append(
            AstrocyteSmolTool(
                name="memory_forget",
                description="Remove specific memories by their IDs (comma-separated).",
                inputs={
                    "memory_ids": {"type": "string", "description": "Comma-separated memory IDs."},
                },
                output_type="string",
                fn=_forget,
            )
        )

    return tools