Coverage for astrocyte/portability.py: 84%
164 statements
« prev ^ index » next coverage.py v7.15.0, created at 2026-07-04 05:24 +0000
« prev ^ index » next coverage.py v7.15.0, created at 2026-07-04 05:24 +0000
1"""Memory portability — AMA (Astrocyte Memory Archive) export and import.
3AMA is a newline-delimited JSON (JSONL) format. Line 1 is the header,
4subsequent lines are individual memories. Streamable, self-describing,
5and FFI-safe (plain JSON, no Python-specific types).
7See docs/_design/memory-portability.md for the full specification.
8"""
10from __future__ import annotations
12import json
13import logging
14import os
15from dataclasses import dataclass
16from datetime import datetime, timezone
17from pathlib import Path
18from typing import Literal
20from astrocyte.types import MemoryHit, Metadata, RecallRequest, RecallResult, RetainRequest
22logger = logging.getLogger("astrocyte.portability")
24# ---------------------------------------------------------------------------
25# Path containment (CWE-022)
26# ---------------------------------------------------------------------------
27#
28# Path(path).resolve() canonicalises but does NOT contain — a caller can
29# pass /etc/passwd and resolve() returns it unchanged. ``_safe_resolve``
30# validates that the resolved path stays within an explicit allow-list.
31#
32# Allow-list resolution order:
33# 1. ``allowed_roots`` kwarg passed to the public function
34# 2. ``ASTROCYTE_PORTABILITY_ROOTS`` env var (os.pathsep-joined)
35#
36# When neither (1) nor (2) is configured, ``_safe_resolve`` REFUSES to
37# return a path unless the caller has explicitly opted into uncontained
38# mode via ``allow_uncontained=True``. This eliminates the silent
39# "no containment" gap that CodeQL CWE-022 (py/path-injection) flags
40# and forces every caller to make a conscious security decision.
41#
42# Recommended usage:
43# * Server / gateway code: set ``ASTROCYTE_PORTABILITY_ROOTS`` and
44# leave ``allow_uncontained=False``. Untrusted HTTP input cannot
45# escape the configured roots.
46# * Library / CLI / unit tests with caller-controlled paths: pass
47# ``allowed_roots=[<known dir>]`` explicitly.
48# * Trusted internal call sites that genuinely need any path: pass
49# ``allow_uncontained=True`` to make the decision audit-able.
51_PORTABILITY_ROOTS_ENV = "ASTROCYTE_PORTABILITY_ROOTS"
53# Null byte and ASCII control characters never have a legitimate place in
54# a filesystem path. Reject them up front; resolve() does NOT strip them.
55_ILLEGAL_PATH_CHAR_ORDS = frozenset(range(0x00, 0x20)) | {0x7F}
def _portability_roots() -> list[Path]:
    """Return the containment roots configured in the environment.

    The variable named by ``_PORTABILITY_ROOTS_ENV`` holds ``os.pathsep``
    separated directories. Empty entries are ignored; each remaining entry is
    ``~``-expanded and resolved. An unset or empty variable yields ``[]``.
    """
    configured = os.environ.get(_PORTABILITY_ROOTS_ENV, "")
    roots: list[Path] = []
    for entry in configured.split(os.pathsep):
        if not entry:
            continue
        roots.append(Path(entry).expanduser().resolve())
    return roots
def _safe_resolve(
    path: str | Path,
    *,
    allowed_roots: list[str | Path] | None = None,
    allow_uncontained: bool = False,
) -> Path:
    """Resolve ``path`` and verify it stays within an allowed root.

    Containment roots come from ``allowed_roots`` when it is non-empty;
    otherwise (``None`` *or* an empty list) they are read from the
    ``ASTROCYTE_PORTABILITY_ROOTS`` environment variable. If that also yields
    no roots, the resolved path is returned uncontained only when
    ``allow_uncontained=True``; otherwise a ``ValueError`` is raised.

    Args:
        path: Candidate path (possibly user-controlled). It is rejected if it
            contains an ASCII control character; otherwise it is
            ``~``-expanded and passed through ``Path.resolve()``.
        allowed_roots: Explicit containment roots, expanded and resolved the
            same way as ``path``.
        allow_uncontained: Opt-in escape hatch for trusted internal callers
            when no roots are configured anywhere.

    Returns:
        The resolved path.

    Raises:
        ValueError: If the path contains illegal control characters,
            escapes every allowed root, or no containment is configured
            and the caller did not pass ``allow_uncontained=True``.
    """
    path_str = os.fspath(path)
    if any(ord(c) in _ILLEGAL_PATH_CHAR_ORDS for c in path_str):
        raise ValueError(f"Portability path contains illegal control character: {path_str!r}")
    # ``path_str`` is user-controlled, but the taint is neutralized by
    # the allow-list containment check below — ``resolved`` is matched
    # against ``allowed_roots`` (or the env-configured
    # ``_portability_roots()``) before the function returns. Callers
    # cannot opt out without passing ``allow_uncontained=True``
    # explicitly. CodeQL's taint tracker doesn't see through allow-list
    # logic, so this whole file is in CodeQL's ``paths-ignore`` (see
    # ``.github/codeql/codeql-config.yml``). Threat model is locked by
    # ``tests/test_portability.py::TestPathContainment``.
    resolved = Path(path_str).expanduser().resolve()
    roots: list[Path]
    if allowed_roots:
        roots = [Path(r).expanduser().resolve() for r in allowed_roots]
    else:
        roots = _portability_roots()
    if not roots:
        if not allow_uncontained:
            raise ValueError(
                "Portability path containment is required. Provide one of:\n"
                " - allowed_roots=[<dir>, ...] kwarg, OR\n"
                f" - {_PORTABILITY_ROOTS_ENV} environment variable "
                "(os.pathsep-joined directories), OR\n"
                " - allow_uncontained=True for trusted internal callers."
            )
        return resolved
    # ``is_relative_to`` is also True when ``resolved`` equals the root itself,
    # so a single check covers both "is the root" and "is under the root".
    if any(resolved.is_relative_to(root) for root in roots):
        return resolved
    raise ValueError(f"Portability path escapes allowed roots: {resolved!s} not in {[str(r) for r in roots]}")
114# ---------------------------------------------------------------------------
115# AMA header
116# ---------------------------------------------------------------------------
# Format version written to the ``_ama_version`` header field; readers reject
# headers whose version differs from this value.
AMA_VERSION: int = 1
@dataclass
class AmaHeader:
    """Parsed form of line 1 of an AMA file.

    Mirrors the JSON header object that ``export_bank`` writes and
    ``read_ama_header`` validates.
    """

    # Bank the memories were exported from.
    bank_id: str
    # Export timestamp as an ISO 8601 string.
    exported_at: str
    # Provider identifier supplied by the exporter.
    provider: str
    # Number of memory lines the exporter reported writing after the header.
    memory_count: int
    # Format version; named after the JSON key it mirrors.
    _ama_version: int = AMA_VERSION
@dataclass
class AmaMemory:
    """Parsed form of one memory line (line 2 onward) of an AMA file.

    Only ``id`` and ``text`` are required; every other attribute stays
    ``None`` when the corresponding JSON key is absent.
    """

    id: str
    text: str
    fact_type: str | None = None
    tags: list[str] | None = None
    metadata: Metadata | None = None
    # Timestamps are kept as ISO 8601 strings, not parsed datetimes.
    occurred_at: str | None = None
    created_at: str | None = None
    source: str | None = None
    # Bank recorded at export time (the import target may differ).
    bank_id: str | None = None
    # Entities and embeddings are read from the file when present; the
    # exporter in this module does not currently write them.
    entities: list[dict[str, str | list[str]]] | None = None
    embedding: list[float] | None = None
149# ---------------------------------------------------------------------------
150# Writer — export a bank to AMA JSONL
151# ---------------------------------------------------------------------------
def _hit_to_ama_record(hit: MemoryHit) -> dict:
    """Convert one recalled ``MemoryHit`` into an AMA memory-line dict.

    ``id`` and ``text`` are always present (``id`` falls back to ``""`` when
    the hit has no ``memory_id``); every other field is written only when it
    is truthy on the hit.
    """
    record: dict = {
        "id": hit.memory_id or "",
        "text": hit.text,
    }
    if hit.fact_type:
        record["fact_type"] = hit.fact_type
    if hit.tags:
        record["tags"] = hit.tags
    if hit.metadata:
        record["metadata"] = hit.metadata
    if hit.occurred_at:
        record["occurred_at"] = hit.occurred_at.isoformat()
    if hit.source:
        record["source"] = hit.source
    if hit.bank_id:
        record["bank_id"] = hit.bank_id
    # Embeddings and entities would come from provider-specific data;
    # for Phase 1 only what MemoryHit exposes is exported.
    return record


async def export_bank(
    recall_fn,
    bank_id: str,
    path: str | Path,
    provider_name: str = "unknown",
    include_embeddings: bool = False,
    include_entities: bool = True,
    batch_size: int = 100,
    *,
    allowed_roots: list[str | Path] | None = None,
    allow_uncontained: bool = False,
) -> int:
    """Export a memory bank to AMA JSONL format.

    Memories are gathered by repeatedly calling ``recall_fn`` with a wildcard
    ``"*"`` query, de-duplicated by ``memory_id`` (or by text when a hit has
    no id; first occurrence wins), and written as one header line followed by
    one line per memory. Parent directories of ``path`` are created.

    Args:
        recall_fn: Async callable that takes a RecallRequest and returns RecallResult.
            Typically ``brain._do_recall``.
        bank_id: Bank to export.
        path: Output file path (overwritten if it exists).
        provider_name: Provider identifier for the header.
        include_embeddings: Include vector embeddings (not portable across
            models). NOTE: not yet honoured — no embeddings are written.
        include_entities: Include extracted entities. NOTE: not yet
            honoured — no entities are written.
        batch_size: ``max_results`` requested per recall call.
        allowed_roots: Optional list of directory roots; the resolved
            ``path`` must fall under one of them. When ``None`` (or empty),
            falls back to ``ASTROCYTE_PORTABILITY_ROOTS`` env var.
        allow_uncontained: When True, skip path containment if neither
            ``allowed_roots`` nor the env var is set. Use only for
            trusted internal callers — the default ``False`` raises if
            no containment is configured. See ``_safe_resolve``.

    Returns:
        Number of memories exported (after de-duplication).

    Raises:
        ValueError: If ``path`` fails containment (see ``_safe_resolve``).
    """
    path = _safe_resolve(path, allowed_roots=allowed_roots, allow_uncontained=allow_uncontained)
    path.parent.mkdir(parents=True, exist_ok=True)

    # Gather memories batch by batch, de-duplicating as we go so that a full
    # batch contributing nothing new can be detected. RecallRequest is sent
    # without any offset/cursor, so a provider that answers "*" with the same
    # top-N page on every call would otherwise be re-queried until the safety
    # cap (about 100000 / batch_size calls) without adding any memory.
    # NOTE(review): confirm whether RecallRequest supports pagination; without
    # it a bank larger than ``batch_size`` may only be exported partially.
    seen: set[str] = set()
    unique_hits: list[MemoryHit] = []
    requested = 0
    while True:
        result: RecallResult = await recall_fn(
            RecallRequest(
                query="*",  # Wildcard — retrieve everything
                bank_id=bank_id,
                max_results=batch_size,
            )
        )
        if not result.hits:
            break
        added = 0
        for hit in result.hits:
            key = hit.memory_id or hit.text
            if key not in seen:
                seen.add(key)
                unique_hits.append(hit)
                added += 1
        # A short batch means the provider has nothing more to give.
        if len(result.hits) < batch_size:
            break
        # A full batch with nothing new means no progress is being made.
        if added == 0:
            break
        requested += batch_size
        # Safety: hard cap for providers that keep returning new results.
        if requested > 100000:
            break

    header = {
        "_ama_version": AMA_VERSION,
        "bank_id": bank_id,
        "exported_at": datetime.now(timezone.utc).isoformat(),
        "provider": provider_name,
        "memory_count": len(unique_hits),
    }

    with open(path, "w", encoding="utf-8") as f:
        f.write(json.dumps(header, default=str) + "\n")
        for hit in unique_hits:
            f.write(json.dumps(_hit_to_ama_record(hit), default=str) + "\n")

    return len(unique_hits)
258# ---------------------------------------------------------------------------
259# Reader — iterate AMA JSONL lines
260# ---------------------------------------------------------------------------
def read_ama_header(
    path: str | Path,
    *,
    allowed_roots: list[str | Path] | None = None,
    allow_uncontained: bool = False,
) -> AmaHeader:
    """Read and validate the AMA header (first line).

    See ``export_bank`` for ``allowed_roots`` and ``allow_uncontained`` semantics.

    Returns:
        The parsed ``AmaHeader``.

    Raises:
        ValueError: If ``path`` fails containment, the file is empty, the
            first line is not valid JSON (``json.JSONDecodeError`` is a
            ``ValueError``), is not a JSON object, lacks or has an
            unsupported ``_ama_version``, or has wrongly typed required fields.
    """
    path = _safe_resolve(path, allowed_roots=allowed_roots, allow_uncontained=allow_uncontained)
    # ``utf-8-sig`` drops a leading byte-order mark if one is present (and
    # decodes plain UTF-8 unchanged); ``json.loads`` rejects text that starts
    # with a BOM, and ``str.strip()`` does not remove it.
    with open(path, encoding="utf-8-sig") as f:
        first_line = f.readline().strip()
    if not first_line:
        raise ValueError(f"AMA file is empty: {path}")
    data = json.loads(first_line)
    # A bare JSON string/number/array is not a header. Without this guard a
    # JSON string containing "_ama_version" would pass the membership test
    # below as a substring match and then fail with a TypeError.
    if not isinstance(data, dict):
        raise ValueError(f"Not a valid AMA file (header is not a JSON object): {path}")
    if "_ama_version" not in data:
        raise ValueError(f"Not a valid AMA file (missing _ama_version): {path}")
    if data["_ama_version"] != AMA_VERSION:
        raise ValueError(f"Unsupported AMA version {data['_ama_version']} (expected {AMA_VERSION})")
    # Validate required field types
    for key in ("bank_id", "exported_at", "provider"):
        if not isinstance(data.get(key), str):
            raise ValueError(f"AMA header field '{key}' must be a string: {path}")
    memory_count = data.get("memory_count")
    # ``bool`` is a subclass of ``int``; JSON true/false is not a count.
    if isinstance(memory_count, bool) or not isinstance(memory_count, int):
        raise ValueError(f"AMA header field 'memory_count' must be an integer: {path}")
    return AmaHeader(
        bank_id=data["bank_id"],
        exported_at=data["exported_at"],
        provider=data["provider"],
        memory_count=memory_count,
        _ama_version=data["_ama_version"],
    )
def iter_ama_memories(
    path: str | Path,
    *,
    allowed_roots: list[str | Path] | None = None,
    allow_uncontained: bool = False,
) -> list[AmaMemory]:
    """Load every memory line of an AMA file into ``AmaMemory`` objects.

    The first line (the header) is read and discarded without validation;
    use ``read_ama_header`` for that. Blank lines are ignored. A line that is
    not valid JSON, is not an object, or lacks ``id``/``text`` is logged at
    WARNING level with its 1-based line number and skipped, so one bad
    record never aborts the read. Despite its name this returns a fully
    materialised list rather than an iterator.

    See ``export_bank`` for ``allowed_roots`` and ``allow_uncontained`` semantics.
    """
    resolved = _safe_resolve(path, allowed_roots=allowed_roots, allow_uncontained=allow_uncontained)
    optional_keys = (
        "fact_type",
        "tags",
        "metadata",
        "occurred_at",
        "created_at",
        "source",
        "bank_id",
        "entities",
        "embedding",
    )
    parsed: list[AmaMemory] = []
    with open(resolved, encoding="utf-8") as handle:
        handle.readline()  # header line — not part of the memory records
        for line_no, raw in enumerate(handle, start=2):
            stripped = raw.strip()
            if not stripped:
                continue
            try:
                record = json.loads(stripped)
                if not (isinstance(record, dict) and "id" in record and "text" in record):
                    logger.warning("AMA line %d: missing required fields (id, text)", line_no)
                    continue
                extras = {key: record.get(key) for key in optional_keys}
                parsed.append(AmaMemory(id=record["id"], text=record["text"], **extras))
            except (json.JSONDecodeError, KeyError, TypeError) as exc:
                logger.warning("AMA line %d: %s", line_no, exc)
    return parsed
343# ---------------------------------------------------------------------------
344# Import — load AMA into a bank
345# ---------------------------------------------------------------------------
@dataclass
class ImportResult:
    """Tally returned by ``import_bank``; each processed record bumps exactly one counter."""

    # Records whose retain result reported ``stored``.
    imported: int
    # Records neither stored nor counted as errors.
    skipped: int
    # Records whose retain call raised, or deduplicated ones under on_conflict="error".
    errors: int
async def import_bank(
    retain_fn,
    bank_id: str,
    path: str | Path,
    on_conflict: Literal["skip", "overwrite", "error"] = "skip",
    progress_fn=None,
    *,
    allowed_roots: list[str | Path] | None = None,
    allow_uncontained: bool = False,
) -> ImportResult:
    """Import memories from an AMA file into a bank.

    The header is validated first (``read_ama_header``); then every memory
    record is sent to ``retain_fn`` one at a time. A failure on one record is
    logged and counted in ``errors``; it does not stop the import.

    Args:
        retain_fn: Async callable that takes a RetainRequest and returns RetainResult.
            Typically ``brain._do_retain``.
        bank_id: Target bank (may differ from source bank in AMA).
        path: Path to AMA JSONL file.
        on_conflict: How to count records that ``retain_fn`` reports as
            deduplicated: ``"skip"`` counts them as skipped, ``"error"`` as
            errors. ``"overwrite"`` is accepted but currently behaves like
            ``"skip"`` — no conflict policy is passed to ``retain_fn``.
        progress_fn: Optional callback(imported, total) invoked after every
            10th record and once after the last one; ``total`` is the
            header's ``memory_count``.
        allowed_roots: See ``export_bank``.
        allow_uncontained: See ``export_bank``.

    Returns:
        ImportResult with counts.

    Raises:
        ValueError: If ``on_conflict`` is not a supported value, or the file
            fails path containment / header validation.
    """
    # ``Literal`` is not enforced at runtime; reject typos before any I/O
    # instead of silently treating them like "skip".
    if on_conflict not in ("skip", "overwrite", "error"):
        raise ValueError(f"Unsupported on_conflict {on_conflict!r} (expected 'skip', 'overwrite' or 'error')")

    header = read_ama_header(path, allowed_roots=allowed_roots, allow_uncontained=allow_uncontained)
    memories = iter_ama_memories(path, allowed_roots=allowed_roots, allow_uncontained=allow_uncontained)

    imported = 0
    skipped = 0
    errors = 0
    last_index = len(memories) - 1

    for i, mem in enumerate(memories):
        try:
            # Parse occurred_at if present; an unparseable value is dropped.
            occurred_at = None
            if mem.occurred_at:
                try:
                    occurred_at = datetime.fromisoformat(mem.occurred_at)
                except ValueError:
                    logger.debug("Skipping unparseable occurred_at: %s", mem.occurred_at)

            request = RetainRequest(
                content=mem.text,
                bank_id=bank_id,
                metadata=mem.metadata,
                tags=mem.tags,
                occurred_at=occurred_at,
                source=mem.source or f"import:ama:{header.provider}",
                content_type="text",
            )

            result = await retain_fn(request)

            if result.stored:
                imported += 1
            elif result.deduplicated and on_conflict == "error":
                errors += 1
            else:
                # Deduplicated under "skip"/"overwrite", or not stored for
                # another reason.
                skipped += 1

        except Exception as exc:
            # iter_ama_memories drops blank and invalid lines, so the list
            # index does not map to a file line number; report the record
            # position and its id instead.
            logger.warning("AMA import of memory %r (record %d) failed: %s", mem.id, i + 1, exc)
            errors += 1

        # Report every 10 records and always after the final one, so callers
        # see the finished total even when the count is not a multiple of 10.
        if progress_fn and ((i + 1) % 10 == 0 or i == last_index):
            progress_fn(imported, header.memory_count)

    return ImportResult(imported=imported, skipped=skipped, errors=errors)