Coverage for astrocyte/portability.py: 84%

164 statements  

« prev     ^ index     » next       coverage.py v7.15.0, created at 2026-07-04 05:24 +0000

1"""Memory portability — AMA (Astrocyte Memory Archive) export and import. 

2 

3AMA is a newline-delimited JSON (JSONL) format. Line 1 is the header, 

4subsequent lines are individual memories. Streamable, self-describing, 

5and FFI-safe (plain JSON, no Python-specific types). 

6 

7See docs/_design/memory-portability.md for the full specification. 

8""" 

9 

10from __future__ import annotations 

11 

12import json 

13import logging 

14import os 

15from dataclasses import dataclass 

16from datetime import datetime, timezone 

17from pathlib import Path 

18from typing import Literal 

19 

20from astrocyte.types import MemoryHit, Metadata, RecallRequest, RecallResult, RetainRequest 

21 

# Module-level logger; all portability warnings/debug output goes through it.
logger = logging.getLogger("astrocyte.portability")

23 

24# --------------------------------------------------------------------------- 

25# Path containment (CWE-022) 

26# --------------------------------------------------------------------------- 

27# 

28# Path(path).resolve() canonicalises but does NOT contain — a caller can 

29# pass /etc/passwd and resolve() returns it unchanged. ``_safe_resolve`` 

30# validates that the resolved path stays within an explicit allow-list. 

31# 

32# Allow-list resolution order: 

33# 1. ``allowed_roots`` kwarg passed to the public function 

34# 2. ``ASTROCYTE_PORTABILITY_ROOTS`` env var (os.pathsep-joined) 

35# 

36# When neither (1) nor (2) is configured, ``_safe_resolve`` REFUSES to 

37# return a path unless the caller has explicitly opted into uncontained 

38# mode via ``allow_uncontained=True``. This eliminates the silent 

39# "no containment" gap that CodeQL CWE-022 (py/path-injection) flags 

40# and forces every caller to make a conscious security decision. 

41# 

42# Recommended usage: 

43# * Server / gateway code: set ``ASTROCYTE_PORTABILITY_ROOTS`` and 

44# leave ``allow_uncontained=False``. Untrusted HTTP input cannot 

45# escape the configured roots. 

46# * Library / CLI / unit tests with caller-controlled paths: pass 

47# ``allowed_roots=[<known dir>]`` explicitly. 

48# * Trusted internal call sites that genuinely need any path: pass 

49# ``allow_uncontained=True`` to make the decision audit-able. 

50 

# Environment variable holding the os.pathsep-joined containment roots.
_PORTABILITY_ROOTS_ENV = "ASTROCYTE_PORTABILITY_ROOTS"

# Code points that are rejected in any portability path: the ASCII control
# range 0x00-0x1F (which includes the null byte) plus DEL (0x7F). They are
# checked explicitly because resolve() does not strip them.
_ILLEGAL_PATH_CHAR_ORDS = frozenset([*range(0x20), 0x7F])

56 

57 

def _portability_roots() -> list[Path]:
    """Return the containment roots configured through the environment.

    The value of the ``ASTROCYTE_PORTABILITY_ROOTS`` variable is split on
    ``os.pathsep``; empty segments are ignored and every remaining entry is
    user-expanded and resolved. An unset or empty variable yields ``[]``.
    """
    configured = os.environ.get(_PORTABILITY_ROOTS_ENV, "")
    roots: list[Path] = []
    for entry in configured.split(os.pathsep):
        if not entry:
            continue
        roots.append(Path(entry).expanduser().resolve())
    return roots

62 

63 

def _safe_resolve(
    path: str | Path,
    *,
    allowed_roots: list[str | Path] | None = None,
    allow_uncontained: bool = False,
) -> Path:
    """Canonicalise ``path`` and enforce that it lies inside an allowed root.

    Roots come from ``allowed_roots`` when it is non-empty, otherwise from
    the environment (``_portability_roots()``). When no root is configured
    at all, the resolved path is only returned if the caller opted in with
    ``allow_uncontained=True``.

    Args:
        path: Candidate filesystem path (user-controlled).
        allowed_roots: Directories the resolved path must equal or sit under.
        allow_uncontained: Explicit opt-in to skip containment when no root
            is configured.

    Returns:
        The user-expanded, fully resolved path.

    Raises:
        ValueError: If the path contains illegal control characters,
            escapes every allowed root, or no containment is configured
            and the caller did not pass ``allow_uncontained=True``.
    """
    raw = os.fspath(path)
    if not _ILLEGAL_PATH_CHAR_ORDS.isdisjoint(map(ord, raw)):
        raise ValueError(f"Portability path contains illegal control character: {raw!r}")

    # ``raw`` is user-controlled. The taint is neutralised by the allow-list
    # check below: the resolved candidate must match a configured root
    # before it is returned, unless the caller explicitly passed
    # ``allow_uncontained=True``. CodeQL's taint tracker does not see
    # through allow-list logic, so this file is listed in CodeQL's
    # ``paths-ignore`` (``.github/codeql/codeql-config.yml``); the threat
    # model is pinned by ``tests/test_portability.py::TestPathContainment``.
    candidate = Path(raw).expanduser().resolve()

    if allowed_roots:
        containment = [Path(entry).expanduser().resolve() for entry in allowed_roots]
    else:
        containment = _portability_roots()

    if containment:
        # ``is_relative_to`` is also true when candidate == base.
        if any(candidate.is_relative_to(base) for base in containment):
            return candidate
        raise ValueError(
            f"Portability path escapes allowed roots: {candidate!s} not in {[str(r) for r in containment]}"
        )

    if allow_uncontained:
        return candidate
    raise ValueError(
        "Portability path containment is required. Provide one of:\n"
        " - allowed_roots=[<dir>, ...] kwarg, OR\n"
        f" - {_PORTABILITY_ROOTS_ENV} environment variable "
        "(os.pathsep-joined directories), OR\n"
        " - allow_uncontained=True for trusted internal callers."
    )

112 

113 

114# --------------------------------------------------------------------------- 

115# AMA header 

116# --------------------------------------------------------------------------- 

117 

# Format version written to, and required from, every AMA header line.
AMA_VERSION = 1


@dataclass
class AmaHeader:
    """Parsed representation of the header record (line 1) of an AMA file."""

    # Bank the archive was exported from.
    bank_id: str
    # Export timestamp, ISO 8601 string.
    exported_at: str
    # Identifier of the provider that produced the export.
    provider: str
    # Number of memory records announced by the header.
    memory_count: int
    # Archive format version; defaults to the version this module writes.
    _ama_version: int = AMA_VERSION

130 

131 

@dataclass
class AmaMemory:
    """A single memory record, i.e. one non-header line of an AMA file.

    Only ``id`` and ``text`` are required; every other field defaults to
    ``None`` when the record does not carry it.
    """

    # Required fields.
    id: str
    text: str

    # Optional descriptive fields.
    fact_type: str | None = None
    tags: list[str] | None = None
    metadata: Metadata | None = None

    # Optional timestamps, kept as ISO 8601 strings.
    occurred_at: str | None = None
    created_at: str | None = None

    # Optional provenance.
    source: str | None = None
    bank_id: str | None = None

    # Optional provider-specific payloads.
    entities: list[dict[str, str | list[str]]] | None = None
    embedding: list[float] | None = None

147 

148 

149# --------------------------------------------------------------------------- 

150# Writer — export a bank to AMA JSONL 

151# --------------------------------------------------------------------------- 

152 

153 

async def export_bank(
    recall_fn,
    bank_id: str,
    path: str | Path,
    provider_name: str = "unknown",
    include_embeddings: bool = False,
    include_entities: bool = True,
    batch_size: int = 100,
    *,
    allowed_roots: list[str | Path] | None = None,
    allow_uncontained: bool = False,
) -> int:
    """Export a memory bank to AMA JSONL format.

    Args:
        recall_fn: Async callable that takes a RecallRequest and returns RecallResult.
            Typically ``brain._do_recall``.
        bank_id: Bank to export.
        path: Output file path.
        provider_name: Provider identifier for the header.
        include_embeddings: Include vector embeddings (not portable across models).
            Accepted for interface stability; this function currently only
            writes the fields available on ``MemoryHit``.
        include_entities: Include extracted entities. Accepted for interface
            stability; not consulted by the current implementation.
        batch_size: Number of memories requested per recall call.
        allowed_roots: Optional list of directory roots; the resolved
            ``path`` must fall under one of them. When ``None``, falls
            back to ``ASTROCYTE_PORTABILITY_ROOTS`` env var.
        allow_uncontained: When True, skip path containment if neither
            ``allowed_roots`` nor the env var is set. Use only for
            trusted internal callers — the default ``False`` raises if
            no containment is configured. See ``_safe_resolve``.

    Returns:
        Number of (deduplicated) memories exported.

    Raises:
        ValueError: If ``path`` fails the containment check in ``_safe_resolve``.
    """
    path = _safe_resolve(path, allowed_roots=allowed_roots, allow_uncontained=allow_uncontained)
    path.parent.mkdir(parents=True, exist_ok=True)

    # Collect memories via wildcard recall, deduplicating as we go. The key
    # is the memory id, falling back to the text when the id is empty.
    seen: set[str] = set()
    unique_hits: list[MemoryHit] = []
    requested = 0
    while True:
        result: RecallResult = await recall_fn(
            RecallRequest(
                query="*",  # Wildcard — retrieve everything
                bank_id=bank_id,
                max_results=batch_size,
            )
        )
        if not result.hits:
            break

        new_in_batch = 0
        for hit in result.hits:
            key = hit.memory_id or hit.text
            if key not in seen:
                seen.add(key)
                unique_hits.append(hit)
                new_in_batch += 1

        # A short page means the bank is exhausted.
        if len(result.hits) < batch_size:
            break
        # The request carries no offset/cursor, so a provider may hand back
        # the same page again. Once a page contributes nothing new there is
        # no point in asking again (this also guarantees termination when
        # batch_size <= 0).
        if new_in_batch == 0:
            break
        requested += batch_size
        # Safety: hard cap for providers that keep returning fresh results.
        if requested > 100000:
            break

    # Write AMA file: header line first, then one JSON object per memory.
    now = datetime.now(timezone.utc).isoformat()
    header = {
        "_ama_version": AMA_VERSION,
        "bank_id": bank_id,
        "exported_at": now,
        "provider": provider_name,
        "memory_count": len(unique_hits),
    }

    with open(path, "w", encoding="utf-8") as f:
        f.write(json.dumps(header, default=str) + "\n")
        for hit in unique_hits:
            record: dict = {
                "id": hit.memory_id or "",
                "text": hit.text,
            }
            # Optional fields are only emitted when they carry a truthy value.
            if hit.fact_type:
                record["fact_type"] = hit.fact_type
            if hit.tags:
                record["tags"] = hit.tags
            if hit.metadata:
                record["metadata"] = hit.metadata
            if hit.occurred_at:
                record["occurred_at"] = hit.occurred_at.isoformat()
            if hit.source:
                record["source"] = hit.source
            if hit.bank_id:
                record["bank_id"] = hit.bank_id
            # Embeddings and entities would come from provider-specific data.
            # For Phase 1, we export what's available in MemoryHit.
            f.write(json.dumps(record, default=str) + "\n")

    return len(unique_hits)

256 

257 

258# --------------------------------------------------------------------------- 

259# Reader — iterate AMA JSONL lines 

260# --------------------------------------------------------------------------- 

261 

262 

def read_ama_header(
    path: str | Path,
    *,
    allowed_roots: list[str | Path] | None = None,
    allow_uncontained: bool = False,
) -> AmaHeader:
    """Read and validate the AMA header (first line).

    See ``export_bank`` for ``allowed_roots`` and ``allow_uncontained`` semantics.

    Args:
        path: Path to the AMA JSONL file.
        allowed_roots: Directory roots the resolved path must fall under.
        allow_uncontained: Opt-in to skip containment when no root is configured.

    Returns:
        The parsed ``AmaHeader``.

    Raises:
        ValueError: If the path fails containment, the file is empty, the
            first line is not valid JSON (``json.JSONDecodeError`` is a
            ``ValueError`` subclass), the header is not a JSON object, the
            version is missing or unsupported, or a required field has the
            wrong type.
    """
    path = _safe_resolve(path, allowed_roots=allowed_roots, allow_uncontained=allow_uncontained)
    with open(path, encoding="utf-8") as f:
        first_line = f.readline().strip()
    if not first_line:
        raise ValueError(f"AMA file is empty: {path}")
    data = json.loads(first_line)
    # Valid JSON is not necessarily an object: a bare number or string would
    # otherwise surface as a TypeError from the membership/index operations.
    if not isinstance(data, dict):
        raise ValueError(f"Not a valid AMA file (header is not a JSON object): {path}")
    if "_ama_version" not in data:
        raise ValueError(f"Not a valid AMA file (missing _ama_version): {path}")
    if data["_ama_version"] != AMA_VERSION:
        raise ValueError(f"Unsupported AMA version {data['_ama_version']} (expected {AMA_VERSION})")
    # Validate required field types
    for field in ("bank_id", "exported_at", "provider"):
        if not isinstance(data.get(field), str):
            raise ValueError(f"AMA header field '{field}' must be a string: {path}")
    memory_count = data.get("memory_count")
    # bool is a subclass of int, so JSON true/false must be rejected explicitly.
    if isinstance(memory_count, bool) or not isinstance(memory_count, int):
        raise ValueError(f"AMA header field 'memory_count' must be an integer: {path}")
    return AmaHeader(
        bank_id=data["bank_id"],
        exported_at=data["exported_at"],
        provider=data["provider"],
        memory_count=memory_count,
        _ama_version=data["_ama_version"],
    )

296 

297 

def iter_ama_memories(
    path: str | Path,
    *,
    allowed_roots: list[str | Path] | None = None,
    allow_uncontained: bool = False,
) -> list[AmaMemory]:
    """Load every memory record of an AMA file into a list.

    The first line (header) is skipped. Blank lines are ignored. A line that
    is not valid JSON, is not a JSON object, or lacks ``id``/``text`` is
    logged at warning level and skipped rather than aborting the read.

    See ``export_bank`` for ``allowed_roots`` and ``allow_uncontained`` semantics.

    Returns:
        The parsed records in file order.
    """
    ama_path = _safe_resolve(path, allowed_roots=allowed_roots, allow_uncontained=allow_uncontained)
    optional_fields = (
        "fact_type",
        "tags",
        "metadata",
        "occurred_at",
        "created_at",
        "source",
        "bank_id",
        "entities",
        "embedding",
    )
    parsed: list[AmaMemory] = []
    with open(ama_path, encoding="utf-8") as handle:
        # Discard the header; numbering below therefore starts at line 2.
        next(handle, None)
        for line_num, raw_line in enumerate(handle, start=2):
            payload = raw_line.strip()
            if not payload:
                continue
            try:
                data = json.loads(payload)
                well_formed = isinstance(data, dict) and "id" in data and "text" in data
                if not well_formed:
                    logger.warning("AMA line %d: missing required fields (id, text)", line_num)
                    continue
                # Absent optional keys become None, matching the dataclass defaults.
                extras = {name: data.get(name) for name in optional_fields}
                parsed.append(AmaMemory(id=data["id"], text=data["text"], **extras))
            except (json.JSONDecodeError, KeyError, TypeError) as exc:
                logger.warning("AMA line %d: %s", line_num, exc)
    return parsed

341 

342 

343# --------------------------------------------------------------------------- 

344# Import — load AMA into a bank 

345# --------------------------------------------------------------------------- 

346 

347 

@dataclass
class ImportResult:
    """Outcome counters returned by ``import_bank``."""

    # Records for which the retain call reported ``stored``.
    imported: int
    # Records that were not stored and not counted as errors.
    skipped: int
    # Records that raised, or were deduplicated under on_conflict="error".
    errors: int

353 

354 

async def import_bank(
    retain_fn,
    bank_id: str,
    path: str | Path,
    on_conflict: Literal["skip", "overwrite", "error"] = "skip",
    progress_fn=None,
    *,
    allowed_roots: list[str | Path] | None = None,
    allow_uncontained: bool = False,
) -> ImportResult:
    """Import memories from an AMA file into a bank.

    Args:
        retain_fn: Async callable that takes a RetainRequest and returns RetainResult.
            Typically ``brain._do_retain``.
        bank_id: Target bank (may differ from source bank in AMA).
        path: Path to AMA JSONL file.
        on_conflict: How to count memories the retain call reports as
            deduplicated: ``"error"`` counts them as errors; ``"skip"`` and
            ``"overwrite"`` both count them as skipped (this function does
            not perform an overwrite itself).
        progress_fn: Optional callback(imported, total) for progress reporting.
            Called after every 10th record and once more after the last
            record when the record count is not a multiple of 10. ``total``
            is the header's ``memory_count``.
        allowed_roots: See ``export_bank``.
        allow_uncontained: See ``export_bank``.

    Returns:
        ImportResult with counts.

    Raises:
        ValueError: If the path fails containment or the header is invalid
            (propagated from ``read_ama_header``).
    """
    header = read_ama_header(path, allowed_roots=allowed_roots, allow_uncontained=allow_uncontained)
    memories = iter_ama_memories(path, allowed_roots=allowed_roots, allow_uncontained=allow_uncontained)

    imported = 0
    skipped = 0
    errors = 0

    for i, mem in enumerate(memories):
        try:
            # Parse occurred_at if present. A value that cannot be parsed —
            # malformed string or a non-string JSON value — is dropped
            # instead of failing the whole record.
            occurred_at = None
            if mem.occurred_at:
                try:
                    occurred_at = datetime.fromisoformat(mem.occurred_at)
                except (ValueError, TypeError):
                    logger.debug("Skipping unparseable occurred_at: %s", mem.occurred_at)

            request = RetainRequest(
                content=mem.text,
                bank_id=bank_id,
                metadata=mem.metadata,
                tags=mem.tags,
                occurred_at=occurred_at,
                source=mem.source or f"import:ama:{header.provider}",
                content_type="text",
            )

            result = await retain_fn(request)

            if result.stored:
                imported += 1
            elif result.deduplicated and on_conflict == "error":
                errors += 1
            else:
                skipped += 1

        except Exception as exc:
            # Best-effort import: one failing record must not abort the rest.
            # The reported number is the record's position (+2 for the header
            # and 1-based counting); it can differ from the physical file
            # line when earlier lines were blank or skipped.
            logger.warning("AMA import line %d failed: %s", i + 2, exc)
            errors += 1

        if progress_fn and (i + 1) % 10 == 0:
            progress_fn(imported, header.memory_count)

    # Report the trailing partial batch so small imports (and the tail of
    # large ones) still produce a final progress update.
    if progress_fn and len(memories) % 10 != 0:
        progress_fn(imported, header.memory_count)

    return ImportResult(imported=imported, skipped=skipped, errors=errors)