Coverage for astrocyte/pipeline/orchestrator.py: 74%
1117 statements
« prev ^ index » next coverage.py v7.15.0, created at 2026-07-04 05:24 +0000
1"""Pipeline orchestrator — coordinates Tier 1 retain/recall/reflect flows.
3Async (coordinates I/O stages). See docs/_design/built-in-pipeline.md.
4"""
6from __future__ import annotations
8import asyncio
9import hashlib
10import inspect
11import json
12import logging
13import os
14import re
15import time
16import uuid
17from collections import defaultdict
18from contextlib import asynccontextmanager
19from datetime import datetime, timezone
20from typing import TYPE_CHECKING, Any
22from astrocyte.config import ExtractionProfileConfig, RecallAuthorityConfig
23from astrocyte.pipeline.agentic_reflect import AgenticReflectParams
24from astrocyte.pipeline.chunking import DEFAULT_CHUNK_SIZE, chunk_text
25from astrocyte.pipeline.cross_encoder_rerank import (
26 CrossEncoderProtocol,
27 cross_encoder_rerank,
28)
29from astrocyte.pipeline.embedding import generate_embeddings
30from astrocyte.pipeline.entity_extraction import extract_entities
31from astrocyte.pipeline.extraction import (
32 merged_user_and_builtin_profiles,
33 prepare_retain_input,
34 resolve_retain_chunking,
35)
36from astrocyte.pipeline.fusion import (
37 DEFAULT_RRF_K,
38 ScoredItem,
39 memory_hits_as_scored,
40 rrf_fusion,
41 weighted_rrf_fusion,
42)
43from astrocyte.pipeline.hyde import generate_hyde_vector
44from astrocyte.pipeline.link_expansion import LinkExpansionParams, link_expansion
45from astrocyte.pipeline.query_intent import (
46 QueryIntent,
47 classify_query_intent,
48 weights_for_intent,
49)
50from astrocyte.pipeline.query_plan import build_query_plan
51from astrocyte.pipeline.reflect import synthesize
52from astrocyte.pipeline.reranking import (
53 apply_context_diversity,
54 basic_rerank,
55 cross_encoder_like_rerank,
56 llm_pairwise_rerank,
57)
58from astrocyte.pipeline.retrieval import parallel_retrieve
59from astrocyte.policy.homeostasis import enforce_token_budget
60from astrocyte.policy.signal_quality import DedupDetector
61from astrocyte.recall.authority import apply_recall_authority, merge_retain_metadata_authority_tier
62from astrocyte.types import (
63 Completion,
64 Entity,
65 EntityLink,
66 MemoryEntityAssociation,
67 MemoryHit,
68 MemoryLink,
69 Message,
70 RecallRequest,
71 RecallResult,
72 RecallTrace,
73 ReflectRequest,
74 ReflectResult,
75 RetainRequest,
76 RetainResult,
77 VectorFilters,
78 VectorItem,
79)
81if TYPE_CHECKING:
82 from astrocyte.mip.router import MipRouter
83 from astrocyte.mip.schema import PipelineSpec
84 from astrocyte.pipeline.entity_resolution import EntityResolver
85 from astrocyte.provider import DocumentStore, GraphStore, LLMProvider, VectorStore, WikiStore
88_logger = logging.getLogger("astrocyte.mip")
def _warn_on_version_drift(
    bank_pipeline: PipelineSpec | None,
    hits: list[MemoryHit],
    bank_id: str,
) -> None:
    """Log one warning if any hit was retained under another MIP pipeline version.

    This is purely advisory: recall results are never modified. The check is
    skipped entirely when the bank has no pipeline or the pipeline carries no
    version. Individual hits are skipped when they have no metadata, no
    ``_mip.pipeline_version`` entry, or an entry that cannot be parsed as an
    integer.

    Args:
        bank_pipeline: Pipeline spec currently configured for the bank, if any.
        hits: Retrieved hits whose metadata is inspected.
        bank_id: Bank identifier, used only in the log message.
    """
    if bank_pipeline is None:
        return
    if bank_pipeline.version is None:
        return

    expected = int(bank_pipeline.version)
    drifted: set[int] = set()
    for hit in hits:
        raw = hit.metadata.get("_mip.pipeline_version") if hit.metadata else None
        if raw is None:
            continue
        try:
            parsed = int(raw)
        except (TypeError, ValueError):
            # Unparseable version markers are treated as "no version recorded".
            continue
        if parsed != expected:
            drifted.add(parsed)

    if not drifted:
        return
    # A single aggregated message per call, listing every foreign version once.
    _logger.warning(
        "MIP pipeline version drift in bank %r: current=%d, hits retained under versions %s. "
        "Consider re-indexing or accepting the drift.",
        bank_id,
        expected,
        sorted(drifted),
    )
def _abstention_floor_for_skepticism(skepticism: int, base_floor: float) -> float:
    """Derive the effective abstention floor from a disposition skepticism level.

    The configured ``base_floor`` is scaled linearly around the legacy
    default (``skepticism=3``), so one deployment can serve both
    answer-everything assistants and adversarial-resistant agents through
    per-call ``dispositions`` instead of separate configs.

    Resulting floor per level:

    - ``skepticism <= 1`` → ``0.0`` (never abstain; mirrors the legacy
      ``abstention_enabled=False`` knob).
    - ``skepticism=2`` → ``base_floor * 0.5`` (trusting).
    - ``skepticism=3`` → ``base_floor`` (legacy ``abstention_enabled=True``).
    - ``skepticism=4`` → ``base_floor * 1.5`` (moderately skeptical).
    - ``skepticism=5`` → ``base_floor * 2.0`` (aggressive abstention).

    The result is never larger than ``1.0``.
    """
    if skepticism > 1:
        # Each step above 1 adds half of the base floor.
        scaled_floor = base_floor * 0.5 * (skepticism - 1)
        return min(1.0, scaled_floor)
    return 0.0
def _build_cooccurrence_pairs(
    entity_ids: list[str],
    max_entities: int,
) -> list[tuple[str, str]]:
    """Return the ``(a, b)`` pairs to create ``co_occurs`` links for.

    Only the head of ``entity_ids`` is considered (extraction order is
    preserved, which typically tracks prominence in the source text), and
    every unordered pair within that head is emitted once, in positional
    order. At most ``C(K, 2)`` pairs are produced for a cap of ``K``.

    The cap is floored at 2: a ``max_entities`` of 0 or 1 still yields the
    single pair formed by the first two entities, because a smaller cap
    could never produce a pair at all.

    Kept as a pure helper so the cap is testable without standing up the
    full retain pipeline. The 2026-05-06 LME profile measured the unbounded
    path at 34% of retain wall — capping here is the targeted fix.

    Args:
        entity_ids: Entity identifiers in extraction order.
        max_entities: Maximum number of leading entities to pair up.

    Returns:
        The list of pairs, or ``[]`` when there are fewer than two entities
        so the caller can skip the storage call entirely.
    """
    from itertools import combinations

    if len(entity_ids) <= 1:
        return []
    cap = max(2, max_entities)
    # combinations() yields pairs in the same positional order as the
    # equivalent nested index loop: (0,1), (0,2), ..., (1,2), ...
    return list(combinations(entity_ids[:cap], 2))
def _resolve_skepticism_for_abstention(
    request_dispositions,
    fallback_enabled: bool,
) -> int:
    """Pick the skepticism level that drives the abstention decision.

    Resolution order:

    1. A per-call ``request_dispositions`` object wins; its ``skepticism``
       attribute is used (``3`` when the attribute is missing).
    2. Otherwise the legacy ``adversarial_defense.abstention_enabled`` bool
       is translated: ``True`` → ``3`` (legacy default behaviour),
       ``False`` → ``1`` (never abstain). New code should pass
       ``dispositions`` per-call instead of toggling the bool.
    """
    if request_dispositions is None:
        # No per-call override: fall back to the legacy on/off switch.
        if fallback_enabled:
            return 3
        return 1
    level = getattr(request_dispositions, "skepticism", 3)
    return int(level)
class _RetainProfiler:
    """Collect per-stage wall-clock timings during retain to locate bottlenecks.

    Profiling is off unless the ``ASTROCYTE_RETAIN_PROFILE`` environment
    variable equals ``"1"`` when the profiler is constructed — the
    ``time.monotonic()`` calls are cheap but not free, so production runs
    skip them. When on, callers wrap suspected hot paths (SFE LLM, embedding
    generation, vector insert, entity merge, entity resolution) with
    :meth:`time`, and :meth:`report` prints an aggregated p50/p95/max table.

    Samples accumulate for the lifetime of the instance, so one bench run
    yields one breakdown covering every retain call. Use :meth:`reset`
    between batches for per-batch isolation.

    This deliberately avoids Prometheus / OTel (``observability.*`` config):
    those need collectors and a separate analysis stack. This is a stop-gap
    for dev/bench investigation — record, print at end of run, done.
    """

    def __init__(self) -> None:
        # Only the exact string "1" enables profiling.
        self.enabled = os.environ.get("ASTROCYTE_RETAIN_PROFILE") == "1"
        # stage name -> elapsed milliseconds, one entry per timed block
        self.samples: dict[str, list[float]] = defaultdict(list)

    @asynccontextmanager
    async def time(self, stage: str):
        """Time the enclosed ``async with`` body and file the result under ``stage``.

        Elapsed wall time is stored in milliseconds, and is recorded even if
        the body raises. When profiling is disabled the body runs untouched
        and nothing is recorded.
        """
        if self.enabled:
            started = time.monotonic()
            try:
                yield
            finally:
                elapsed_ms = (time.monotonic() - started) * 1000.0
                self.samples[stage].append(elapsed_ms)
        else:
            yield

    def reset(self) -> None:
        """Discard every recorded sample."""
        self.samples.clear()

    def report(self, prefix: str = "[retain.profile]") -> None:
        """Print count/total/p50/p95/max per stage, largest total first.

        The stage with the highest cumulative time leads the table — that is
        the one worth optimizing. p95 falls back to the maximum when a stage
        has fewer than 20 samples.

        Output goes through ``print`` rather than the module logger so the
        table reaches stdout no matter how logging is configured; this is
        dev/bench tooling output, not production telemetry.
        """
        if not self.enabled:
            print(f"{prefix} (profiler disabled — set ASTROCYTE_RETAIN_PROFILE=1)")
            return
        if not self.samples:
            print(f"{prefix} (profiler enabled but captured no samples — instrumentation unwired?)")
            return
        import statistics

        def _summarize(name: str, durations: list[float]) -> tuple[str, int, float, float, float, float]:
            # 19 cut points for n=20; index 18 is the 95th percentile.
            if len(durations) >= 20:
                p95 = statistics.quantiles(durations, n=20)[18]
            else:
                p95 = max(durations)
            return (
                name,
                len(durations),
                sum(durations),
                statistics.median(durations),
                p95,
                max(durations),
            )

        summaries = [_summarize(name, durations) for name, durations in self.samples.items() if durations]
        # Biggest cumulative cost first.
        ordered = sorted(summaries, key=lambda row: row[2], reverse=True)

        print(f"{prefix} aggregate breakdown (sorted by total wall time):")
        print(
            f"{prefix} {'stage':<22} {'n':<7} {'total_ms':<12} {'p50_ms':<10} {'p95_ms':<10} {'max_ms':<10}",
        )
        for stage, n, total, p50, p95, mx in ordered:
            print(
                f"{prefix} {stage:<22} {n:<7d} {total:<12.0f} {p50:<10.1f} {p95:<10.1f} {mx:<10.1f}",
            )
class _TrackingLLMProvider:
    """Pass-through LLMProvider wrapper that tallies completion token usage.

    Every call is delegated to the wrapped provider. After each
    ``complete()`` the input and output token counts from
    ``Completion.usage`` (when present) are added to ``tokens_used``.
    ``embed()`` is delegated unchanged and does not touch ``tokens_used``.

    :class:`PipelineOrchestrator` uses this internally so the evaluation
    framework can report ``total_tokens_used`` per run without modifying
    individual pipeline modules. See ``evaluation.md`` §2.2.

    If a ``_metrics_collector`` attribute has been attached to the instance,
    completion and embedding calls are additionally reported to it on a
    best-effort basis.
    """

    SPI_VERSION: int = 1

    def __init__(self, inner: LLMProvider) -> None:
        self._inner = inner
        self.tokens_used: int = 0

    async def complete(
        self,
        messages: list[Message],
        model: str | None = None,
        max_tokens: int = 1024,
        temperature: float = 0.0,
        tools: list | None = None,  # list[ToolDefinition] — kept loose to avoid import cycle
        tool_choice: str | None = None,
        response_format: dict | None = None,
    ) -> Completion:
        """Delegate a completion call and add its token usage to the running total."""
        # Native function-calling kwargs must reach the underlying provider
        # so the agentic reflect loop (Hindsight parity) can use
        # ``tools=``/``tool_choice=`` end-to-end. Without this the loop
        # silently fell back to forced single-shot synthesis on every call —
        # observed in 1986/1986 questions on the 2026-05-01 bench.
        #
        # Backward compat: the newer kwargs are forwarded only when the
        # caller supplied them, so legacy providers / test fakes whose
        # ``complete()`` predates tools/response_format keep working on the
        # old text-only path.
        optional = {
            "tools": tools,
            "tool_choice": tool_choice,
            "response_format": response_format,
        }
        forwarded: dict = {key: value for key, value in optional.items() if value is not None}
        completion = await self._inner.complete(
            messages,
            model=model,
            max_tokens=max_tokens,
            temperature=temperature,
            **forwarded,
        )
        if not completion.usage:
            return completion

        self.tokens_used += completion.usage.input_tokens + completion.usage.output_tokens
        # Tier-3 observability: record cost + per-phase tokens when a
        # benchmark collector is attached. Silent no-op otherwise.
        sink = getattr(self, "_metrics_collector", None)
        if sink is None:
            return completion
        try:
            sink.record_completion_call(
                model=getattr(completion, "model", None) or "",
                input_tokens=completion.usage.input_tokens,
                output_tokens=completion.usage.output_tokens,
            )
        except Exception:
            pass  # never let metrics break a real call
        return completion

    async def embed(
        self,
        texts: list[str],
        model: str | None = None,
    ) -> list[list[float]]:
        """Delegate an embedding call, reporting an estimated token count if a collector is attached."""
        vectors = await self._inner.embed(texts, model=model)
        sink = getattr(self, "_metrics_collector", None)
        if sink is None:
            return vectors
        try:
            # Embeddings don't return usage objects; estimate input tokens
            # with the rough heuristic of ~1 token per 4 chars, at least 1
            # per text.
            estimated = sum(max(1, len(text) // 4) for text in texts)
            sink.record_embedding_call(
                model=model or "text-embedding-3-small",
                tokens=estimated,
            )
        except Exception:
            pass  # metrics are best-effort; never block the embedding call
        return vectors

    def capabilities(self) -> Any:
        """Return the wrapped provider's capabilities unchanged."""
        return self._inner.capabilities()

    def reset_tokens(self) -> int:
        """Return accumulated tokens and reset counter to zero."""
        previous = self.tokens_used
        self.tokens_used = 0
        return previous
385class PipelineOrchestrator:
386 """Orchestrates the Tier 1 built-in intelligence pipeline.
388 Coordinates async stages: chunk → embed → store → retrieve → fuse → rerank.
389 """
    def __init__(
        self,
        vector_store: VectorStore,
        llm_provider: LLMProvider,
        graph_store: GraphStore | None = None,
        document_store: DocumentStore | None = None,
        chunk_strategy: str = "sentence",
        max_chunk_size: int = DEFAULT_CHUNK_SIZE,
        rrf_k: int = DEFAULT_RRF_K,
        semantic_overfetch: int = 5,
        extraction_profiles: dict[str, ExtractionProfileConfig] | None = None,
        *,
        enable_temporal_retrieval: bool = True,
        temporal_scan_cap: int = 500,
        temporal_half_life_days: float = 7.0,
        enable_intent_aware_recall: bool = True,
        enable_multi_query_expansion: bool = True,
        enable_hyde: bool = False,
        wiki_store: WikiStore | None = None,
        wiki_confidence_threshold: float = 0.7,
        entity_resolver: EntityResolver | None = None,
        enable_observation_consolidation: bool = True,
        observation_weight: float = 0.0,
        observation_injection_weight: float = 1.5,
        multi_query_confidence_threshold: float = 0.72,
        final_rerank_mode: str = "heuristic",
        final_rerank_top_n: int = 30,
        final_rerank_keep_n: int | None = None,
    ) -> None:
        """Store the configured backends and initialise every pipeline knob.

        Constructor arguments are copied onto same-named attributes. The
        LLM provider is wrapped in a :class:`_TrackingLLMProvider` so token
        usage can be read back via ``tokens_used``. The remaining attributes
        are set to hard-coded defaults here; per the inline comments, most of
        them are expected to be overwritten by ``Astrocyte.set_pipeline``
        from the corresponding config blocks.
        """
        self.vector_store = vector_store
        # All LLM traffic goes through the tracking wrapper; ``llm_provider``
        # and ``_tracker`` are the same object.
        self._tracker = _TrackingLLMProvider(llm_provider)
        self.llm_provider: LLMProvider = self._tracker  # type: ignore[assignment]
        self.graph_store = graph_store
        self.document_store = document_store
        self.chunk_strategy = chunk_strategy
        self.max_chunk_size = max_chunk_size
        self.extraction_profiles = extraction_profiles
        self.rrf_k = rrf_k
        self.semantic_overfetch = semantic_overfetch
        # Temporal retrieval knobs — see astrocyte.pipeline.retrieval for full
        # semantics. Enabled by default: the strategy no-ops gracefully on
        # banks whose vectors have no timestamps, so enabling it is safe.
        self.enable_temporal_retrieval = enable_temporal_retrieval
        self.temporal_scan_cap = temporal_scan_cap
        self.temporal_half_life_days = temporal_half_life_days
        # M9 BM25-IDF keyword strategy. When True (and the document store
        # advertises ``search_fulltext_bm25``), the keyword leg routes
        # through the materialized-view path with corpus IDF + length
        # normalisation instead of the classic ``ts_rank_cd``. Wired by
        # ``Astrocyte.set_pipeline`` from ``bm25_idf.enabled`` config.
        self.bm25_idf_enabled: bool = False
        #: M10 source-aware retain + recall. ``source_store`` is the
        #: SourceStore instance (or ``None``); the flags below
        #: control behaviour. All wired by ``Astrocyte.set_pipeline`` from
        #: the ``source_aware_retrieval`` config block.
        self.source_store: object | None = None
        self.source_retain_provenance: bool = False
        self.source_chunk_expansion: bool = False
        self.source_expansion_score_multiplier: float = 0.5
        self.source_expansion_max_per_hit: int = 4
        # Intent-aware recall: heuristic query classifier biases RRF
        # weights per strategy. Conservative (always fuses all strategies
        # even under bias), so enabling is safe — a misclassification
        # degrades to a soft lean rather than a strategy drop.
        self.enable_intent_aware_recall = enable_intent_aware_recall
        # Multi-query expansion: decompose complex questions into sub-questions,
        # recall for each independently, and merge via RRF. Improves multi-hop
        # coverage at the cost of N-1 extra embedding + retrieval passes per query.
        # Enabled by default (constructor default ``True``); most useful for
        # multi-hop-heavy workloads.
        # When enabled, the confidence gate (multi_query_confidence_threshold)
        # suppresses expansion when the top raw semantic score already exceeds the
        # threshold — avoiding costly sub-query decomposition when a direct answer
        # is already retrieved with high confidence. Cosine similarity is used as
        # the gate signal (not RRF scores, which are rank-based and not comparable
        # across query runs). Default threshold 0.72 passes ~20-30% of queries.
        self.enable_multi_query_expansion = enable_multi_query_expansion
        self.multi_query_confidence_threshold: float = multi_query_confidence_threshold
        # HyDE (R1): generate a hypothetical answer, embed it, and run a second
        # semantic search pass with that vector. Disabled by default — adds one
        # LLM call per recall. Enable for multi-hop / paraphrase-heavy workloads.
        self.enable_hyde = enable_hyde
        # In-memory near-duplicate detector used at retain time.
        # Forget-cache invalidation: ``Astrocyte.forget`` calls
        # ``invalidate_dedup_cache`` so this cache doesn't keep matching
        # against memories that are gone from the vector store. Without
        # that, re-retain after forget silently returns ``stored=False,
        # error="All chunks are near-duplicates"`` for similar content.
        self._dedup = DedupDetector(similarity_threshold=0.95)
        #: Per-stage retain timing aggregator. No-op unless
        #: ``ASTROCYTE_RETAIN_PROFILE=1`` is set in the environment;
        #: when enabled, retain_many wraps suspect call sites and the
        #: caller can inspect the breakdown via ``profiler.report()``
        #: or read raw samples from ``profiler.samples``.
        self._profiler = _RetainProfiler()
        #: Set by :meth:`astrocyte._astrocyte.Astrocyte.set_pipeline` when ``recall_authority`` is configured.
        self.recall_authority: RecallAuthorityConfig | None = None
        #: Set by :meth:`astrocyte._astrocyte.Astrocyte.set_pipeline` when
        #: ``cross_encoder_rerank.enabled`` is true. ``None`` falls back to
        #: the heuristic ``cross_encoder_like_rerank`` in ``_rank_reflect_context``.
        self.cross_encoder: CrossEncoderProtocol | None = None
        #: When ``cross_encoder`` is set, only the first ``cross_encoder_top_k``
        #: candidates are scored to bound inference cost. Default mirrors the
        #: config default (30).
        self.cross_encoder_top_k: int = 30
        #: Set by :meth:`astrocyte._astrocyte.Astrocyte.set_pipeline` when
        #: ``link_expansion.enabled`` is true. ``None`` skips the
        #: post-fusion 3-signal expansion. Replaced the old BFS-hop
        #: ``spreading_activation_params`` per Hindsight C3 rewrite.
        self.link_expansion_params: LinkExpansionParams | None = None
        #: Adversarial-defense score-floor abstention. When the top
        #: recall hit's score is below this floor, reflect short-circuits
        #: to "insufficient evidence" without invoking the LLM. Targets
        #: adversarial questions where the LLM left to its own devices
        #: would hallucinate an answer from weak hits. Wired by
        #: ``Astrocyte.set_pipeline`` from ``adversarial_defense`` config.
        self.adversarial_abstention_enabled: bool = False
        self.adversarial_abstention_floor: float = 0.2
        #: Pre-loop premise verification — decompose question into atomic
        #: claims, verify each. Wired below in the reflect path.
        self.adversarial_premise_verification_enabled: bool = False
        self.adversarial_premise_min_confidence: float = 0.6
        #: Tighten the agentic-reflect system prompt with explicit
        #: adversarial-defense rules ("insufficient evidence is always
        #: a valid answer", premise-check, etc.).
        self.adversarial_prompt_enabled: bool = False
        #: Hindsight-parity causal-link extraction at retain time. When
        #: enabled, one extra LLM call per record produces directional
        #: ``causes`` edges. Wired by ``Astrocyte.set_pipeline``.
        self.causal_links_enabled: bool = False
        self.causal_max_pairs_per_memory: int = 4
        self.causal_min_confidence: float = 0.7
        #: Hindsight-parity semantic-kNN graph (C3a). When enabled, each
        #: new memory at retain time gets ``MemoryLink(link_type="semantic")``
        #: edges to its top-K most-similar existing memories above the
        #: similarity threshold. Wired by ``Astrocyte.set_pipeline``.
        self.semantic_link_graph_enabled: bool = False
        self.semantic_link_graph_top_k: int = 5
        self.semantic_link_graph_threshold: float = 0.7
        #: Structured fact extraction at retain time (5-dim
        #: what/when/where/who/why with embedded entities + caused_by
        #: relations). When enabled, replaces chunk_text +
        #: extract_entities + fact_causal_extraction with a single
        #: LLM call producing structured facts. Each fact becomes one
        #: memory.
        self.structured_fact_extraction_enabled: bool = False
        self.structured_fact_extraction_max_facts: int = 30
        #: Mode: "verbatim" stores raw chunk text + metadata (preserves
        #: vocabulary for embedding-match), "concise" stores LLM-paraphrased
        #: structured facts. Default verbatim because of the recall_hit_rate
        #: regression that concise paraphrasing causes.
        self.structured_fact_extraction_mode: str = "verbatim"
        #: Chunking strategy used by verbatim SFE pre-chunking. Defaults
        #: to "paragraph" (which gives the LLM full session context for
        #: metadata extraction); the LoCoMo benchmark loses 2.5 pts overall
        #: when set to "dialogue".
        self.structured_fact_extraction_chunk_strategy: str = "paragraph"
        #: Per-chunk character budget for verbatim SFE pre-chunking.
        #: When ``None`` the SFE path falls through to the same
        #: chunk-size resolver legacy retain uses (orchestrator default
        #: 512 chars). Bumping this is the dominant retain throughput
        #: lever for SFE — fewer chunks = fewer ``facts[]`` entries the
        #: LLM has to emit per session, and gpt-4o-mini latency is
        #: roughly linear in output tokens. 2048 measured ~4× on LME-
        #: shaped traffic without accuracy regression on LoCoMo.
        self.structured_fact_extraction_chunk_max_size: int | None = None
        #: Per-chunk parallel verbatim extraction (Phase 3 of cost-
        #: control port). When True, the SFE path sends one LLM call
        #: per chunk in parallel rather than one batched call. Drops
        #: cross-chunk causal_relations — pair with
        #: ``causal_links.enabled=false``.
        self.structured_fact_extraction_parallel_chunks: bool = False
        #: Max in-flight LLM calls per session when
        #: ``parallel_chunks`` is True.
        self.structured_fact_extraction_parallel_chunks_max_concurrency: int = 6
        #: Entity co-occurrence link cap (2026-05-06 retain-profile fix).
        #: When ``enabled``, retain creates ``co_occurs`` links between
        #: at most ``max_entities`` entities per memory — bounding the
        #: Cartesian product to ``C(K,2)`` per retain regardless of N.
        #: Profiling on LME measured the unbounded path at 34% of
        #: retain wall with O(N²) drift; capping at K=5 brings it to
        #: <1% steady state.
        self.entity_cooccurrence_enabled: bool = True
        self.entity_cooccurrence_max_entities: int = 5
        #: Query-level temporal constraint extraction. When enabled,
        #: recall parses temporal expressions in the query into a
        #: time_range filter applied to retrieval. Regex pre-pass is
        #: free; ``allow_llm_fallback`` opt-in adds 1 LLM call per
        #: temporal-marker query.
        self.query_analyzer_enabled: bool = False
        self.query_analyzer_allow_llm_fallback: bool = True
        #: M18a-1 — extended temporal-expansion pattern set in the regex
        #: pre-pass (word-numbers, "a few X ago", "the other day", "this/
        #: earlier this <unit>", "recently/lately"). Default False.
        #: Flipped via per-bank config or env override
        #: ``ASTROCYTE_M18_ENABLE_TEMPORAL_EXPANSION=1`` for bench runs.
        self.query_analyzer_enable_temporal_expansion: bool = False
        #: Hindsight-parity agentic reflect loop. ``None`` = single-shot
        #: synthesis (legacy path). Set by ``Astrocyte.set_pipeline``
        #: when ``agentic_reflect.enabled`` is true.
        self.agentic_reflect_params: AgenticReflectParams | None = None
        #: Set by :meth:`astrocyte._astrocyte.Astrocyte.set_pipeline` when MIP is configured.
        #: Used by :meth:`recall` to resolve per-bank rerank/reflect overrides (P3).
        self.mip_router: MipRouter | None = None
        # M8 W5 — wiki tier precedence. When a WikiStore is wired up and a
        # compiled wiki page scores above ``wiki_confidence_threshold``, recall
        # returns the wiki hit + raw-memory citations instead of running the full
        # parallel-retrieve / RRF pipeline.
        self.wiki_store: WikiStore | None = wiki_store
        self.wiki_confidence_threshold: float = wiki_confidence_threshold
        # M11: entity resolution — alias-of links between entities.
        # None means the stage is skipped (opt-in, no cost when disabled).
        self.entity_resolver: EntityResolver | None = entity_resolver
        # Observation consolidation — post-retain background LLM pass that
        # maintains a deduplicated observations layer. Enabled by default
        # (constructor default ``True``). When enabled, recall also runs a
        # separate "observation" strategy that searches the observations layer
        # and fuses results into the main RRF pipeline with a configurable
        # weight boost.
        self.enable_observation_consolidation: bool = enable_observation_consolidation
        self.observation_weight: float = observation_weight
        # Weight applied to the ::obs bank when intent-gated injection fires
        # (EXPLORATORY / RELATIONAL queries). Kept separate from observation_weight
        # so callers can disable global injection (observation_weight=0.0) while
        # still enabling the intent-gated path.
        self.observation_injection_weight: float = observation_injection_weight
        self.final_rerank_mode = final_rerank_mode
        self.final_rerank_top_n = final_rerank_top_n
        self.final_rerank_keep_n = final_rerank_keep_n
        if enable_observation_consolidation:
            # Imported lazily so the module is only loaded when the feature is on.
            from astrocyte.pipeline.observation import ObservationConsolidator

            self._observation_consolidator: ObservationConsolidator | None = ObservationConsolidator()
        else:
            self._observation_consolidator = None
        # Set of asyncio tasks; presumably holds strong references to
        # fire-and-forget background work — confirm against the code that
        # adds to it.
        self._background_tasks: set[asyncio.Task[None]] = set()

        # Mental-model service — wires the agentic reflect loop to the
        # configured ``MentalModelStore`` (typically ``PostgresMentalModelStore``).
        # ``None`` when no store is configured; ``set_mental_model_service``
        # is called from ``Astrocyte.set_pipeline`` when a store is present.
        # When set, the agent gets ``search_mental_models`` as a tool —
        # the highest-quality tier in the hierarchical priority order
        # (mental_models → observations → recall → expand → done).
        self.mental_model_service: object | None = None
635 @property
636 def tokens_used(self) -> int:
637 """Total LLM tokens consumed through this orchestrator since last reset."""
638 return self._tracker.tokens_used
640 def reset_token_counter(self) -> int:
641 """Return accumulated token count and reset to zero."""
642 return self._tracker.reset_tokens()
644 def invalidate_dedup_cache(
645 self,
646 bank_id: str,
647 memory_ids: list[str] | None = None,
648 ) -> None:
649 """Drop forgotten memories from the in-memory dedup cache.
651 Called by ``Astrocyte.forget`` after the underlying store has
652 accepted the forget. Without this, a re-retain of similar content
653 after forget hits the in-memory ``DedupDetector`` cache (still
654 holding the forgotten memory's embedding) and the pipeline
655 short-circuits with ``RetainResult(stored=False, deduplicated=True,
656 error="All chunks are near-duplicates")`` — a silent no-op from
657 the caller's perspective.
659 Args:
660 bank_id: The bank the forget targeted.
661 memory_ids: Specific memories to drop from the cache. ``None``
662 (the whole-bank or tag-filtered forget path) clears the
663 entire bank's cache — coarser but always-correct.
664 """
665 if memory_ids is None:
666 self._dedup.clear_bank(bank_id)
667 return
669 for mid in memory_ids:
670 self._dedup.remove(bank_id, mid)
672 async def _attach_entity_name_embeddings(self, entities: list[Entity]) -> None:
673 """Embed each entity's name and attach the vector to ``entity.embedding``.
675 Powers the Hindsight-inspired entity-resolution cascade — at retain
676 time we generate one batched embedding call across all entity names,
677 then attach the resulting vector to each entity. Both
678 ``store_entities`` (for adapters that persist embeddings) and
679 ``EntityResolver.resolve()`` (for the cosine-similarity tier of the
680 cascade) consume the attached vectors.
682 Skips entities whose ``embedding`` is already set (caller-supplied)
683 and skips entities with empty/whitespace names. On failure logs a
684 warning and leaves embeddings unset — the resolver degrades to
685 trigram-only and the cost is correctness for that one batch, not
686 a retain failure.
687 """
688 targets = [e for e in entities if e.embedding is None and e.name and e.name.strip()]
689 if not targets:
690 return
691 try:
692 vectors = await generate_embeddings(
693 [e.name.strip() for e in targets],
694 self.llm_provider,
695 )
696 except Exception as exc:
697 _logger.warning(
698 "entity name embedding failed (resolver will fall back to trigram-only tier): %s",
699 exc,
700 )
701 return
702 for entity, vec in zip(targets, vectors, strict=False):
703 entity.embedding = vec
async def _structured_fact_extraction_for_text(
    self,
    prepared,
    request: RetainRequest,
    *,
    chunk_strategy: str = "paragraph",
    chunk_max_size: int | None = None,
    chunk_overlap: int | None = None,
) -> tuple[
    list[str] | None,
    list[Entity] | None,
    list[tuple[int, int]] | None,  # (entity_idx, memory_idx) associations
    list[tuple[int, int, float]] | None,  # (source_memory_idx, target_memory_idx, confidence) caused_by
]:
    """Run the structured 5-dim fact extraction (SFE) path.

    Args:
        prepared: Prepared retain input; only ``prepared.text`` is read here.
        request: The retain request; ``bank_id`` and ``occurred_at`` are read.
        chunk_strategy: Strategy name forwarded to ``chunk_text`` for the
            pre-chunking step.
        chunk_max_size: Optional ``max_chunk_size`` forwarded to
            ``chunk_text``; omitted from the call when ``None``.
        chunk_overlap: Optional ``overlap`` forwarded to ``chunk_text``;
            omitted from the call when ``None``.

    Returns:
        ``(fact_texts, entities, associations, caused_by)`` when the path
        is enabled and produces facts; ``(None, None, None, None)`` when it
        is disabled, no LLM provider is configured, extraction raises, or
        no facts come back. Callers fall through to the legacy
        chunk_text + extract_entities path when ``None`` is returned.

        ``associations`` holds ``(entity_idx, memory_idx)`` pairs and
        ``caused_by`` holds ``(source_memory_idx, target_memory_idx,
        confidence)`` triples. Entity indices reference positions in the
        returned ``entities`` list; memory indices reference positions in
        the returned ``fact_texts`` list. Callers resolve them to memory
        IDs after store_vectors.
    """
    # SFE supersedes the legacy ``prepared.extract_entities`` gate.
    # The legacy setting controlled which entity-extraction PATH
    # ran (metadata vs LLM); SFE is a third path that produces
    # both entities and structured facts in one call. When SFE is
    # enabled, it fires regardless of the profile's entity
    # extraction mode — the profile's metadata-entities are
    # ignored in favor of the richer SFE output.
    if not self.structured_fact_extraction_enabled or self.llm_provider is None:
        return None, None, None, None

    try:
        # Pre-chunk with the caller-supplied strategy / size / overlap,
        # then ask the LLM to enrich each chunk with metadata WITHOUT
        # paraphrasing, so the stored memory text keeps the original
        # chunk vocabulary at the granularity the caller chose.
        #
        # The legacy "concise" path (``extract_facts``) was removed
        # in M9 — it caused severe recall_hit_rate degradation
        # because the LLM-paraphrased ``what`` field lost the
        # surface vocabulary that question embeddings share. Verbatim
        # is now the only supported mode; ``validate_astrocyte_config``
        # rejects ``extraction_mode: concise`` at config load.
        #
        # Open-domain regression observed on 2026-05-02 traced to a
        # hardcoded "paragraph" strategy producing one giant chunk
        # per LoCoMo session (no paragraph breaks in dialogue text);
        # that is why the strategy is a parameter rather than a
        # constant in this method.
        from astrocyte.pipeline.chunking import chunk_text
        from astrocyte.pipeline.fact_extraction import (
            extract_facts_verbatim,
            extract_facts_verbatim_parallel,
            materialize_facts,
        )

        pre_chunk_kwargs: dict[str, int] = {}
        if chunk_max_size is not None:
            pre_chunk_kwargs["max_chunk_size"] = chunk_max_size
        if chunk_overlap is not None:
            pre_chunk_kwargs["overlap"] = chunk_overlap
        chunks_local = chunk_text(
            prepared.text,
            strategy=chunk_strategy,
            **pre_chunk_kwargs,
        )
        # Phase 3: route to the per-chunk parallel path when opt-in.
        # Drops cross-chunk causal_relations — pair with
        # ``causal_links.enabled=false``.
        if self.structured_fact_extraction_parallel_chunks:
            facts = await extract_facts_verbatim_parallel(
                chunks_local,
                self.llm_provider,
                event_date=request.occurred_at,
                max_concurrency=(self.structured_fact_extraction_parallel_chunks_max_concurrency),
            )
        else:
            facts = await extract_facts_verbatim(
                chunks_local,
                self.llm_provider,
                event_date=request.occurred_at,
            )
    except Exception as exc:
        _logger.warning(
            "structured fact extraction failed (%s); falling back to legacy chunk + entity-extraction path.",
            exc,
        )
        return None, None, None, None
    if not facts:
        return None, None, None, None

    # Materialize without embeddings here — embeddings are batched
    # later for cost. We only need the list of fact texts and the
    # pre-extracted entities + association index map.
    materialized = materialize_facts(
        facts,
        bank_id=request.bank_id,
        occurred_at=request.occurred_at,
    )
    fact_texts = [item.text for item in materialized.vector_items]
    entities = list(materialized.entities)

    # Build (entity_idx, memory_idx) pairs from the materialized
    # association tuples (which reference items by ID). We need
    # indices because the caller hasn't assigned final memory IDs
    # yet — those come after dedup + store_vectors.
    item_id_to_idx = {item.id: i for i, item in enumerate(materialized.vector_items)}
    ent_id_to_idx = {e.id: i for i, e in enumerate(entities)}
    associations: list[tuple[int, int]] = []
    for item_id, ent_id in materialized.memory_entity_associations:
        mem_idx = item_id_to_idx.get(item_id)
        ent_idx = ent_id_to_idx.get(ent_id)
        if mem_idx is None or ent_idx is None:
            # Association references an item/entity that was not
            # materialized; drop it rather than emit a dangling index.
            continue
        associations.append((ent_idx, mem_idx))

    # caused_by edges: indices both into fact list (= memory list).
    caused_by: list[tuple[int, int, float]] = []
    for link in materialized.memory_links:
        src_idx = item_id_to_idx.get(link.source_memory_id)
        tgt_idx = item_id_to_idx.get(link.target_memory_id)
        if src_idx is None or tgt_idx is None:
            continue
        caused_by.append((src_idx, tgt_idx, float(link.confidence)))

    return fact_texts, entities, associations, caused_by
async def _persist_semantic_links(
    self,
    bank_id: str,
    memory_ids: list[str],
    embeddings: list[list[float]],
) -> None:
    """Compute and persist semantic-kNN edges for a freshly stored batch.

    Hindsight parity (C3a): ``compute_semantic_links`` is asked for links
    from each new memory to its ``semantic_link_graph_top_k`` most similar
    existing memories at or above ``semantic_link_graph_threshold``; any
    links returned are written through ``graph_store.store_memory_links``.

    Best-effort: both the computation and the store step log a warning on
    failure and return instead of aborting retain. Does nothing when the
    feature flag is off, no graph store is configured, or ``memory_ids``
    is empty.
    """
    # Guard clauses, checked in the same order as the feature's
    # preconditions: flag, then graph store, then a non-empty batch.
    if not self.semantic_link_graph_enabled:
        return
    if self.graph_store is None:
        return
    if not memory_ids:
        return

    try:
        # Late import keeps the dependency off the module import path.
        from astrocyte.pipeline.semantic_link_graph import compute_semantic_links

        knn_links = await compute_semantic_links(
            bank_id=bank_id,
            new_memory_ids=memory_ids,
            new_embeddings=embeddings,
            vector_store=self.vector_store,
            top_k=self.semantic_link_graph_top_k,
            similarity_threshold=self.semantic_link_graph_threshold,
        )
    except Exception as err:
        _logger.warning("semantic_link_graph computation failed: %s", err)
        return

    if knn_links:
        try:
            await self.graph_store.store_memory_links(knn_links, bank_id)
        except Exception as err:
            _logger.warning("storing semantic memory_links failed: %s", err)
async def _process_record_entities_with_retry(
    self,
    record: dict[str, Any],
    *,
    max_retries: int = 10,
    base_delay: float = 0.1,
) -> None:
    """Retry-wrapper around :meth:`_process_record_entities`.

    AGE's entity ``MERGE`` path takes per-label advisory locks
    (``pg_advisory_xact_lock(label_oid)``) so concurrent retains
    racing on the same label name don't produce duplicate label
    rows. Three or more retains contending on the same label can
    produce a CYCLIC wait: P1 holds advisory, waits on tx; P2 waits
    on advisory; P3 holds tx, waits on advisory — Postgres breaks
    the cycle by aborting one of them with ``DeadlockDetected``.

    This is timing-sensitive: with HNSW vector inserts the retain
    latency was high enough that the cyclic pattern was rare in
    practice. With DiskANN's faster inserts (2026-05-06 default
    switch) the LoCoMo bench surfaced it within 20 retains.
    Postgres-style deadlocks are inherently retriable — only the
    loser's transaction rolls back, so we just try again with
    exponential backoff.

    Args:
        record: The per-record dict handed to
            :meth:`_process_record_entities` unchanged.
        max_retries: Total number of attempts (not additional retries).
            Values below 1 are treated as 1 so the record is always
            processed at least once.
        base_delay: Seconds slept after the first deadlock; the delay
            doubles after every further deadlock
            (``base_delay * 2**attempt``). No upper cap is applied to
            an individual delay.

    Raises:
        Exception: Any non-deadlock exception is re-raised immediately;
            a deadlock exception is re-raised once the final attempt
            has failed.

    Tuning history:

    - 2026-05-06 (initial): ``max_retries=3``, ``base_delay=0.2``.
      Absorbed most deadlocks but the LoCoMo bench surfaced 4-way
      cyclic waits that the 3 attempts couldn't outlast.
    - 2026-05-06 (revised): ``max_retries=10``, ``base_delay=0.1``.
      Retain throughput evidence: with ``concurrency=10`` retains
      and a per-bank entity overlap rate around 30% (LoCoMo
      characters appearing in many sessions), deadlocks fire on
      ~5% of retains. 10 attempts sleep at most 9 times with
      0.1s × 2^n backoff (0.1s … 25.6s), i.e. 51.1s max total
      stall — ample headroom for the cyclic-wait probability tail.
      The bench earlier observed 6 deadlocks absorbed before one
      exhausted — bumping to 10 attempts more than doubles the
      probability mass we cover.
    """
    # Always run at least once: ``range(0)`` would otherwise skip the
    # entity processing entirely and return as if it had succeeded.
    attempts = max(1, max_retries)
    for attempt in range(attempts):
        try:
            await self._process_record_entities(record)
            return
        except Exception as exc:  # noqa: BLE001 — match psycopg without import dep
            # Match by class name + message so we don't need a hard
            # dependency on psycopg in this module (the orchestrator
            # is provider-agnostic; deadlocks are a Postgres-side
            # concept that bubbles up via whatever DB driver the
            # graph_store happens to use).
            name = type(exc).__name__.lower()
            msg = str(exc).lower()
            is_deadlock = "deadlock" in name or "deadlock detected" in msg
            if not is_deadlock or attempt == attempts - 1:
                # Not retriable, or out of attempts: surface the
                # original exception with its traceback intact.
                raise
            delay = base_delay * (2**attempt)
            _logger.warning(
                "_process_record_entities deadlock on attempt %d/%d (%s); retrying in %.2fs",
                attempt + 1,
                attempts,
                exc,
                delay,
            )
            await asyncio.sleep(delay)
async def _process_record_entities(self, record: dict[str, Any]) -> None:
    """Run all entity-related graph I/O for a single ``retain_many`` record.

    Steps, in order: optional entity-name embedding and canonical-ID
    resolution (only when an entity resolver is configured), entity
    storage, memory→entity linking, optional co-occurrence links,
    ``caused_by`` memory links, and finally the legacy post-store alias
    resolution when canonical resolution is off.

    ``record`` must carry ``"request"``, ``"prepared"``, ``"memory_ids"``
    and ``"entities"``; ``"sfe_associations"``, ``"sfe_caused_by"`` and
    ``"chunks"`` are optional. Returns immediately when no graph store is
    configured or the record has no entities.

    Failures in canonical resolution, causal extraction, memory-link
    storage and alias resolution are caught and logged so they do not
    abort retain. Errors from entity storage, memory→entity linking and
    co-occurrence storage propagate to the caller.
    """
    graph = self.graph_store
    if graph is None:
        return

    req: RetainRequest = record["request"]
    prepared = record["prepared"]
    mem_ids: list[str] = record["memory_ids"]
    entities: list[Entity] = record["entities"]
    if not entities:
        return

    resolver = self.entity_resolver
    if resolver is not None:
        async with self._profiler.time("entity_emb"):
            await self._attach_entity_name_embeddings(entities)
        # Path B (Hindsight-style): rewrite tentative IDs to canonical
        # IDs BEFORE storage so different surface forms that match an
        # existing canonical never produce duplicate entity rows. The
        # post-store ``resolve()`` alias pass is skipped in this mode
        # (see the end of this method).
        if resolver.canonical_resolution:
            try:
                async with self._profiler.time("entity_resolve"):
                    await resolver.resolve_canonical_ids_in_place(
                        new_entities=entities,
                        bank_id=req.bank_id,
                        graph_store=graph,
                        event_date=req.occurred_at,
                    )
            except Exception as err:
                _logger.warning(
                    "canonical resolution failed during retain (falling back to tentative IDs): %s",
                    err,
                )

    async with self._profiler.time("entity_store"):
        stored_entity_ids = await graph.store_entities(entities, req.bank_id)

    # Memory→entity links. SFE supplies per-fact (entity_idx, memory_idx)
    # pairs so each memory is linked only to entities it mentions; the
    # legacy path links every memory to every entity.
    sfe_pairs = record.get("sfe_associations")
    entity_links_for_memories = []
    if sfe_pairs is not None:
        for ent_idx, mem_idx in sfe_pairs:
            # Skip out-of-range indices instead of raising IndexError.
            if 0 <= ent_idx < len(stored_entity_ids) and 0 <= mem_idx < len(mem_ids):
                entity_links_for_memories.append(
                    MemoryEntityAssociation(
                        memory_id=mem_ids[mem_idx],
                        entity_id=stored_entity_ids[ent_idx],
                    )
                )
    else:
        for mid in mem_ids:
            for eid in stored_entity_ids:
                entity_links_for_memories.append(MemoryEntityAssociation(memory_id=mid, entity_id=eid))
    async with self._profiler.time("entity_link_mem"):
        await graph.link_memories_to_entities(entity_links_for_memories, req.bank_id)

    # Co-occurrence links, bounded by ``entity_cooccurrence_max_entities``.
    # Without the cap the all-pairs Cartesian product dominates retain
    # wall on entity-dense workloads (LME profile 2026-05-06).
    if self.entity_cooccurrence_enabled:
        entity_pairs = _build_cooccurrence_pairs(
            stored_entity_ids,
            self.entity_cooccurrence_max_entities,
        )
        if entity_pairs:
            co_links = [
                EntityLink(entity_a=left, entity_b=right, link_type="co_occurs") for left, right in entity_pairs
            ]
            async with self._profiler.time("entity_co_occur"):
                await graph.store_links(co_links, req.bank_id)

    # MemoryLinks (caused_by). Two sources:
    #   - SFE supplied them inline → resolve indices to memory IDs.
    #   - Legacy fact_causal_extraction LLM call per record.
    # Either way the resulting links go through one shared store step.
    sfe_causal = record.get("sfe_caused_by")
    chunk_texts: list[str] = record.get("chunks") or []
    causal_links = []
    if sfe_causal:
        for src_idx, tgt_idx, conf in sfe_causal:
            in_range = 0 <= src_idx < len(mem_ids) and 0 <= tgt_idx < len(mem_ids)
            # Drop out-of-range indices and self-loops.
            if in_range and src_idx != tgt_idx:
                causal_links.append(
                    MemoryLink(
                        source_memory_id=mem_ids[src_idx],
                        target_memory_id=mem_ids[tgt_idx],
                        link_type="caused_by",
                        confidence=conf,
                        weight=1.0,
                        created_at=datetime.now(timezone.utc),
                        metadata={"bank_id": req.bank_id, "source": "fact_extraction"},
                    )
                )
    elif (
        self.causal_links_enabled
        and len(mem_ids) > 1
        and len(chunk_texts) == len(mem_ids)
        and self.llm_provider is not None
    ):
        try:
            from astrocyte.pipeline.fact_causal_extraction import (
                build_memory_links_from_relations,
                extract_fact_causal_relations,
            )

            relations = await extract_fact_causal_relations(
                chunk_texts,
                self.llm_provider,
                max_pairs_per_fact=self.causal_max_pairs_per_memory,
                min_confidence=self.causal_min_confidence,
            )
            causal_links = build_memory_links_from_relations(
                relations,
                mem_ids,
                bank_id=req.bank_id,
            )
        except Exception as err:
            _logger.warning("fact-level causal extraction failed: %s", err)
            causal_links = []

    if causal_links:
        try:
            await graph.store_memory_links(
                causal_links,
                req.bank_id,
            )
        except Exception as err:
            _logger.warning("storing memory_links failed: %s", err)

    # Legacy two-stage flow: post-store alias resolution. Skipped when
    # ``canonical_resolution=True`` because IDs are already canonical.
    if resolver is not None and not resolver.canonical_resolution:
        try:
            await resolver.resolve(
                new_entities=entities,
                source_text=prepared.text,
                bank_id=req.bank_id,
                graph_store=graph,
                llm_provider=self.llm_provider,
                event_date=req.occurred_at,
            )
        except Exception as err:
            _logger.warning("entity resolution failed during retain: %s", err)
async def _provision_source_provenance(
    self,
    request: RetainRequest,
    prepared_text: str,
    chunks: list[str],
) -> list[str | None]:
    """Record the (document, chunks) pair in the SourceStore.

    Returns one ``chunk_id`` per entry of ``chunks``, in the same order.
    When no SourceStore is configured, the ``source_retain_provenance``
    flag is off, the store raises, or it returns a different number of
    ids than chunks, the result is ``[None] * len(chunks)`` so retain
    continues with anonymous vectors — provenance must never break ingest.

    The default document id is ``doc:`` plus the first 16 hex characters
    of the SHA-256 of ``prepared_text``, so identical input yields the
    same default id. Chunk ids are ``{document_id}:{i}`` built from the
    id that ``store_document`` returns.
    """
    expected = len(chunks)
    anonymous: list[str | None] = [None] * expected
    if self.source_store is None:
        return anonymous
    if not self.source_retain_provenance:
        return anonymous

    # Late import: keeps the module import-cost low and avoids a hard
    # dependency on the M10 types from the orchestrator's surface area.
    from astrocyte.types import SourceChunk, SourceDocument

    try:
        text_digest = hashlib.sha256(prepared_text.encode("utf-8")).hexdigest()
        document = SourceDocument(
            id=f"doc:{text_digest[:16]}",
            bank_id=request.bank_id,
            title=None,
            source_uri=request.source,
            content_hash=text_digest,
            content_type=request.content_type,
            metadata=None,
        )
        # store_document is idempotent on (bank_id, content_hash); the
        # returned id may differ from the default one if a prior live
        # row with the same hash already exists, so chunk rows are keyed
        # off the RETURNED id.
        document_id = await self.source_store.store_document(document)  # type: ignore[union-attr]

        chunk_rows: list[SourceChunk] = [
            SourceChunk(
                id=f"{document_id}:{position}",
                bank_id=request.bank_id,
                document_id=document_id,
                chunk_index=position,
                text=body,
                content_hash=hashlib.sha256(body.encode("utf-8")).hexdigest(),
                metadata=None,
            )
            for position, body in enumerate(chunks)
        ]
        stored_ids = await self.source_store.store_chunks(chunk_rows)  # type: ignore[union-attr]
        if len(stored_ids) == expected:
            return list(stored_ids)
        # Should not happen for well-behaved adapters; defend
        # against shape mismatch by falling back rather than
        # silently misaligning chunk_ids onto VectorItems.
        _logger.warning(
            "source_store.store_chunks returned %d ids, expected %d — falling back",
            len(stored_ids),
            expected,
        )
        return anonymous
    except Exception as err:
        # Defensive: provenance is best-effort. Log and degrade to
        # anonymous vectors so a SourceStore outage doesn't break ingest.
        _logger.warning("source-aware retain failed; continuing without provenance: %s", err)
        return anonymous
async def _expand_via_sibling_chunks(
    self,
    fused: list,
    bank_id: str,
) -> list:
    """M10: extend ``fused`` with vectors from sibling chunks of the top hits.

    Walks the leading hits, asks the SourceStore for the sibling chunks
    of each hit's ``chunk_id`` (capped at ``source_expansion_max_per_hit``
    per seed), then asks the vector store for the vectors backing those
    chunks. Each expanded hit's score is multiplied by
    ``source_expansion_score_multiplier`` and clamped to ``[0.0, 1.0]``.

    Args:
        fused: Ranked hits; mutated in place (extended and re-sorted).
        bank_id: Bank the lookups are scoped to.

    Returns:
        ``fused`` itself, sorted by score descending. It is returned
        unchanged when it is empty, no SourceStore is configured, no
        seed carries a ``chunk_id``, no new sibling chunks are found, or
        the vector-store lookup fails or returns nothing. A memory id is
        never added twice — siblings whose id is already in ``fused``,
        or that the vector store returns more than once, are skipped.
    """
    if not fused:
        return fused
    store = self.source_store
    if store is None:
        return fused

    per_hit_cap = self.source_expansion_max_per_hit

    # 1. Collect the chunk ids of the leading seeds. Only the first
    #    ``source_expansion_max_per_hit * 4`` hits are considered so the
    #    per-recall fan-out stays bounded. Several hits can share one
    #    chunk; ``dict.fromkeys`` de-duplicates while keeping rank order
    #    so each chunk is looked up once.
    seed_cap = max(1, per_hit_cap * 4)
    seed_chunk_ids = list(
        dict.fromkeys(cid for hit in fused[:seed_cap] if (cid := getattr(hit, "chunk_id", None)))
    )
    if not seed_chunk_ids:
        return fused

    # 2. Resolve each seed chunk → its document, then list siblings
    #    of that document. Lookups run concurrently; a failed lookup
    #    contributes no siblings instead of failing the recall.
    async def _siblings_for(chunk_id: str) -> list[str]:
        try:
            chunk = await store.get_chunk(chunk_id, bank_id)  # type: ignore[union-attr]
            if chunk is None:
                return []
            siblings = await store.list_chunks(chunk.document_id, bank_id)  # type: ignore[union-attr]
            # Drop the seed chunk itself; cap at the per-hit limit.
            return [c.id for c in siblings if c.id != chunk_id][:per_hit_cap]
        except Exception as exc:
            _logger.debug("source_store sibling lookup failed for %s: %s", chunk_id, exc)
            return []

    sibling_lists = await asyncio.gather(*(_siblings_for(cid) for cid in seed_chunk_ids))
    sibling_chunk_ids: set[str] = set()
    for ids in sibling_lists:
        sibling_chunk_ids.update(ids)
    # Drop any chunk_ids already represented in ``fused`` so we don't
    # re-rank the same memory twice.
    already_present = {cid for hit in fused if (cid := getattr(hit, "chunk_id", None))}
    sibling_chunk_ids -= already_present
    if not sibling_chunk_ids:
        return fused

    # 3. Pull the sibling vectors and scale their scores by the
    #    configured multiplier.
    #    NOTE(review): the original comment says the vector store
    #    returns score=1.0 for these hits — confirm against the adapter.
    try:
        sibling_hits = await self.vector_store.get_by_chunk_ids(  # type: ignore[attr-defined]
            list(sibling_chunk_ids),
            bank_id,
        )
    except Exception as exc:
        _logger.warning("vector_store.get_by_chunk_ids failed; skipping chunk expansion: %s", exc)
        return fused

    if not sibling_hits:
        return fused

    seen_ids = {getattr(hit, "id", None) for hit in fused}
    multiplier = self.source_expansion_score_multiplier
    for sibling in sibling_hits:
        if sibling.id in seen_ids:
            continue
        # Track ids as they are added so a memory returned more than
        # once by the vector store is still appended only once.
        seen_ids.add(sibling.id)
        sibling.score = max(0.0, min(1.0, sibling.score * multiplier))
        fused.append(sibling)

    # Re-sort: seeds keep their fused score, expansions slot in by
    # the scaled score. The downstream rerank stage gets the final say.
    fused.sort(key=lambda hit: getattr(hit, "score", 0.0), reverse=True)
    return fused
1267 async def retain(self, request: RetainRequest) -> RetainResult:
1268 """Retain pipeline: normalize → chunk → extract entities → embed → store."""
1269 # 0–1. Raw → normalizer → profile metadata/tags (M3 extraction chain)
1270 profile: ExtractionProfileConfig | None = None
1271 name = getattr(request, "extraction_profile", None)
1272 profile_table = merged_user_and_builtin_profiles(self.extraction_profiles)
1273 if name:
1274 profile = profile_table.get(name)
1276 # MIP pipeline overrides (from RoutingDecision) — optional, no behavior
1277 # change when absent. ``chunker`` and ``dedup`` are consumed here;
1278 # ``rerank`` and ``reflect`` apply to recall/reflect (Phase 2).
1279 mip_pipeline = request.mip_pipeline
1280 mip_chunker = mip_pipeline.chunker if mip_pipeline else None
1281 mip_dedup = mip_pipeline.dedup if mip_pipeline else None
1282 dedup_threshold_override = mip_dedup.threshold if mip_dedup else None
1283 dedup_action = (mip_dedup.action if mip_dedup else None) or "skip_chunk"
1285 prepared = prepare_retain_input(
1286 request,
1287 profile,
1288 graph_store_configured=self.graph_store is not None,
1289 )
1290 chunking = resolve_retain_chunking(
1291 prepared.effective_content_type,
1292 profile=profile,
1293 default_strategy=self.chunk_strategy,
1294 default_max_chunk_size=self.max_chunk_size,
1295 mip_chunker=mip_chunker,
1296 )
1297 chunk_kwargs: dict[str, int] = {"max_chunk_size": chunking.max_size}
1298 if chunking.overlap is not None:
1299 chunk_kwargs["overlap"] = chunking.overlap
1301 # Structured 5-dim fact extraction (opt-in). Replaces chunking +
1302 # entity extraction with a single LLM call producing facts.
1303 # Falls back to legacy chunk_text + extract_entities when
1304 # disabled or when extraction returns no facts.
1305 # SFE uses its OWN chunking strategy (not the profile's), because
1306 # SFE has different granularity needs: large chunks give the LLM
1307 # full context for metadata extraction. Default "paragraph" wins
1308 # +2.5pts on LoCoMo over the profile-driven "dialogue" choice.
1309 async with self._profiler.time("sfe"):
1310 (
1311 sfe_fact_texts,
1312 sfe_entities,
1313 sfe_associations,
1314 sfe_caused_by,
1315 ) = await self._structured_fact_extraction_for_text(
1316 prepared,
1317 request,
1318 chunk_strategy=self.structured_fact_extraction_chunk_strategy,
1319 chunk_max_size=(
1320 self.structured_fact_extraction_chunk_max_size
1321 if self.structured_fact_extraction_chunk_max_size is not None
1322 else chunking.max_size
1323 ),
1324 chunk_overlap=chunking.overlap,
1325 )
1326 if sfe_fact_texts is not None:
1327 chunks = list(sfe_fact_texts)
1328 else:
1329 chunks = chunk_text(prepared.text, strategy=chunking.strategy, **chunk_kwargs)
1330 if not chunks:
1331 return RetainResult(stored=False, error="No content after chunking")
1333 # 2. Generate embeddings for all chunks
1334 async with self._profiler.time("embed"):
1335 embeddings = await generate_embeddings(chunks, self.llm_provider)
1337 # 2b. Per-chunk dedup — behavior depends on dedup_action:
1338 # "skip_chunk" (default): drop duplicate chunks, keep the rest
1339 # "skip": if any chunk is a duplicate, reject the entire retain
1340 # "warn": keep all chunks regardless of duplicates
1341 # "update": not yet implemented; falls back to "skip_chunk"
1342 keep_indices: list[int] = []
1343 any_duplicate = False
1344 for i, emb in enumerate(embeddings):
1345 is_dup, _sim = self._dedup.is_duplicate(
1346 request.bank_id,
1347 emb,
1348 threshold_override=dedup_threshold_override,
1349 )
1350 if is_dup:
1351 any_duplicate = True
1352 if dedup_action == "warn" or not is_dup:
1353 keep_indices.append(i)
1355 if dedup_action == "skip" and any_duplicate:
1356 return RetainResult(stored=False, deduplicated=True, error="Duplicate chunk(s) found; rule action=skip")
1358 if not keep_indices:
1359 return RetainResult(stored=False, deduplicated=True, error="All chunks are near-duplicates")
1361 chunks = [chunks[i] for i in keep_indices]
1362 embeddings = [embeddings[i] for i in keep_indices]
1364 # Remap structured-fact-extraction indices through dedup. The
1365 # association and caused_by lists referenced positions in the
1366 # ORIGINAL chunks; after dedup, indices need remapping to the
1367 # surviving positions (or the records dropped entirely).
1368 if sfe_associations is not None or sfe_caused_by is not None:
1369 old_to_new = {old: new for new, old in enumerate(keep_indices)}
1370 if sfe_associations is not None:
1371 sfe_associations = [
1372 (ent_idx, old_to_new[mem_idx]) for ent_idx, mem_idx in sfe_associations if mem_idx in old_to_new
1373 ]
1374 if sfe_caused_by is not None:
1375 sfe_caused_by = [
1376 (old_to_new[src], old_to_new[tgt], conf)
1377 for src, tgt, conf in sfe_caused_by
1378 if src in old_to_new and tgt in old_to_new
1379 ]
1381 # 3. Extract entities (profile can disable; metadata profiles avoid LLM calls).
1382 # Structured fact extraction provides entities pre-extracted —
1383 # skip the second LLM call when it ran successfully.
1384 if sfe_entities is not None:
1385 entities = list(sfe_entities)
1386 elif profile is not None and profile.entity_extraction == "metadata":
1387 entities = _entities_from_metadata(prepared.metadata)
1388 elif prepared.extract_entities:
1389 entities = await extract_entities(prepared.text, self.llm_provider)
1390 else:
1391 entities = []
1393 profile_tier = profile.authority_tier if profile else None
1394 chunk_metadata = merge_retain_metadata_authority_tier(
1395 prepared.metadata,
1396 bank_id=request.bank_id,
1397 profile_authority_tier=profile_tier,
1398 recall_authority=self.recall_authority,
1399 )
1401 # Persist MIP provenance so recall can warn on rule-version drift (P2).
1402 if request.mip_rule_name or (mip_pipeline and mip_pipeline.version is not None):
1403 chunk_metadata = dict(chunk_metadata or {})
1404 if request.mip_rule_name:
1405 chunk_metadata["_mip.rule"] = request.mip_rule_name
1406 if mip_pipeline and mip_pipeline.version is not None:
1407 chunk_metadata["_mip.pipeline_version"] = int(mip_pipeline.version)
1409 # Stamp ``_created_at`` so (1) temporal retrieval can rank by
1410 # recency, (2) MIP forget.min_age_days can enforce age guards, and
1411 # (3) lifecycle TTL has a consistent retain timestamp. Callers that
1412 # already set ``_created_at`` (imports / replay) win via setdefault.
1413 # Mirrors InMemoryEngineProvider.retain so pipeline- and engine-
1414 # backed deployments share the same observability contract.
1415 chunk_metadata = dict(chunk_metadata or {})
1416 chunk_metadata.setdefault(
1417 "_created_at",
1418 datetime.now(timezone.utc).isoformat(),
1419 )
1421 # 3b. M10: persist source-document + chunk provenance, get back per-chunk
1422 # ``chunk_id``s that we'll stamp onto each VectorItem so recall can
1423 # later resolve "which document/chunk did this memory come from".
1424 # Returns ``[None] * len(chunks)`` when source_store is not configured
1425 # or the flag is off — fully backward-compatible.
1426 chunk_ids = await self._provision_source_provenance(
1427 request,
1428 prepared.text,
1429 chunks,
1430 )
1432 # 4. Store vectors
1433 memory_ids: list[str] = []
1434 items: list[VectorItem] = []
1435 for chunk, embedding, chunk_id in zip(chunks, embeddings, chunk_ids):
1436 mem_id = uuid.uuid4().hex[:16]
1437 memory_ids.append(mem_id)
1438 items.append(
1439 VectorItem(
1440 id=mem_id,
1441 bank_id=request.bank_id,
1442 vector=embedding,
1443 text=chunk,
1444 metadata=chunk_metadata,
1445 tags=prepared.tags,
1446 fact_type=prepared.fact_type,
1447 occurred_at=request.occurred_at,
1448 retained_at=datetime.now(timezone.utc), # M9: wall-clock store time
1449 chunk_id=chunk_id, # M10: source-chunk backreference
1450 )
1451 )
1453 async with self._profiler.time("store_vec"):
1454 await self.vector_store.store_vectors(items)
1456 # Hindsight-parity semantic-kNN graph (C3a). Best-effort: links
1457 # the new memories to their nearest existing neighbors so the
1458 # link-expansion retrieval CTE has a precomputed semantic
1459 # signal at recall time. Skipped when disabled.
1460 await self._persist_semantic_links(
1461 request.bank_id,
1462 [item.id for item in items],
1463 [item.vector for item in items],
1464 )
1466 # 4b. Mirror chunks into the document store so keyword (BM25) search
1467 # has content to retrieve from. Previously recall.parallel_retrieve
1468 # ran a keyword strategy when ``document_store`` was configured —
1469 # but nothing was ever stored there. Fixed here: every chunk that
1470 # lands in the vector store also lands in the document store, with
1471 # the same memory_id so RRF fusion can dedupe across strategies.
1472 # Failures don't abort the retain — degrading to semantic-only
1473 # retrieval is preferable to losing the whole memory.
1474 if self.document_store is not None:
1475 from astrocyte.types import Document
1477 for item in items:
1478 try:
1479 await self.document_store.store_document(
1480 Document(
1481 id=item.id,
1482 text=item.text,
1483 metadata=item.metadata,
1484 tags=item.tags,
1485 ),
1486 request.bank_id,
1487 )
1488 except Exception as exc:
1489 _logger.warning(
1490 "document_store.store_document failed for chunk %s: %s "
1491 "(keyword retrieval will miss this chunk)",
1492 item.id,
1493 exc,
1494 )
1496 # 5. Store entities and links (if graph store configured)
1497 if self.graph_store and entities:
1498 # 5a. Attach name embeddings before persistence so the
1499 # Hindsight-inspired entity-resolution cascade has the cheap
1500 # cosine-similarity tier available. Only fires when an
1501 # entity_resolver is wired up — without one, the embedding
1502 # would be persisted but never used, wasting API budget.
1503 if self.entity_resolver is not None:
1504 async with self._profiler.time("entity_emb"):
1505 await self._attach_entity_name_embeddings(entities)
1506 # Path B (Hindsight): rewrite tentative IDs to canonicals
1507 # before storage. Skipped when canonical_resolution=False
1508 # to preserve the legacy two-stage (store → resolve aliases)
1509 # flow.
1510 if self.entity_resolver.canonical_resolution:
1511 try:
1512 async with self._profiler.time("entity_resolve"):
1513 await self.entity_resolver.resolve_canonical_ids_in_place(
1514 new_entities=entities,
1515 bank_id=request.bank_id,
1516 graph_store=self.graph_store,
1517 event_date=request.occurred_at,
1518 )
1519 except Exception as exc:
1520 _logger.warning(
1521 "canonical resolution failed during retain (falling back to tentative IDs): %s",
1522 exc,
1523 )
1525 async with self._profiler.time("entity_store"):
1526 entity_ids = await self.graph_store.store_entities(entities, request.bank_id)
1528 # Link memories to entities. Two paths:
1529 # - Structured fact extraction provides per-fact-per-entity
1530 # associations (each memory linked only to entities IT mentions),
1531 # matching Hindsight's fact-grained granularity.
1532 # - Legacy path uses the Cartesian product (every memory linked
1533 # to every entity in the batch), which works when extraction
1534 # was over the whole text rather than per fact.
1535 if sfe_associations is not None:
1536 associations = [
1537 MemoryEntityAssociation(
1538 memory_id=memory_ids[mem_idx],
1539 entity_id=entity_ids[ent_idx],
1540 )
1541 for ent_idx, mem_idx in sfe_associations
1542 if 0 <= ent_idx < len(entity_ids) and 0 <= mem_idx < len(memory_ids)
1543 ]
1544 else:
1545 associations = [
1546 MemoryEntityAssociation(memory_id=mid, entity_id=eid) for mid in memory_ids for eid in entity_ids
1547 ]
1548 async with self._profiler.time("entity_link_mem"):
1549 await self.graph_store.link_memories_to_entities(associations, request.bank_id)
1551 # Create co-occurrence links between entities. The 2026-05-06
1552 # LME retain-profile measured the unbounded all-pairs
1553 # Cartesian product at 34% of retain wall (52s tail at the
1554 # session-100 mark). The cap in
1555 # ``_build_cooccurrence_pairs`` bounds the work to O(K²) per
1556 # retain regardless of corpus size.
1557 if self.entity_cooccurrence_enabled:
1558 pairs = _build_cooccurrence_pairs(
1559 entity_ids,
1560 self.entity_cooccurrence_max_entities,
1561 )
1562 if pairs:
1563 links = [EntityLink(entity_a=a, entity_b=b, link_type="co_occurs") for a, b in pairs]
1564 async with self._profiler.time("entity_co_occur"):
1565 await self.graph_store.store_links(links, request.bank_id)
1567 # Causal MemoryLinks. Two paths:
1568 # - Structured fact extraction supplies them inline (no
1569 # extra LLM call); resolve indices to memory IDs and store.
1570 # - Legacy path runs a separate fact_causal_extraction LLM call.
1571 if sfe_caused_by is not None and sfe_caused_by:
1572 memory_links = [
1573 MemoryLink(
1574 source_memory_id=memory_ids[src_idx],
1575 target_memory_id=memory_ids[tgt_idx],
1576 link_type="caused_by",
1577 confidence=conf,
1578 weight=1.0,
1579 created_at=datetime.now(timezone.utc),
1580 metadata={"bank_id": request.bank_id, "source": "fact_extraction"},
1581 )
1582 for src_idx, tgt_idx, conf in sfe_caused_by
1583 if 0 <= src_idx < len(memory_ids) and 0 <= tgt_idx < len(memory_ids) and src_idx != tgt_idx
1584 ]
1585 if memory_links:
1586 try:
1587 await self.graph_store.store_memory_links(
1588 memory_links,
1589 request.bank_id,
1590 )
1591 except Exception as exc:
1592 _logger.warning("storing memory_links failed: %s", exc)
1593 elif (
1594 self.causal_links_enabled
1595 and len(memory_ids) > 1
1596 and len(chunks) == len(memory_ids)
1597 and self.llm_provider is not None
1598 ):
1599 # Legacy fact-causal-extraction LLM pass (separate call).
1600 try:
1601 from astrocyte.pipeline.fact_causal_extraction import (
1602 build_memory_links_from_relations,
1603 extract_fact_causal_relations,
1604 )
1606 relations = await extract_fact_causal_relations(
1607 chunks,
1608 self.llm_provider,
1609 max_pairs_per_fact=self.causal_max_pairs_per_memory,
1610 min_confidence=self.causal_min_confidence,
1611 )
1612 memory_links = build_memory_links_from_relations(
1613 relations,
1614 memory_ids,
1615 bank_id=request.bank_id,
1616 )
1617 except Exception as exc:
1618 _logger.warning("fact-level causal extraction failed: %s", exc)
1619 memory_links = []
1620 if memory_links:
1621 try:
1622 await self.graph_store.store_memory_links(
1623 memory_links,
1624 request.bank_id,
1625 )
1626 except Exception as exc:
1627 _logger.warning("storing memory_links failed: %s", exc)
1629 # 5b. Entity resolution (M11) — Hindsight-inspired tiered cascade.
1630 # Skipped when canonical_resolution=True (Path B) because IDs
1631 # are already canonical from the pre-store pass above. Otherwise
1632 # the legacy two-stage path runs the cascade and writes
1633 # alias_of links for matched candidates.
1634 if self.entity_resolver is not None and not self.entity_resolver.canonical_resolution:
1635 try:
1636 await self.entity_resolver.resolve(
1637 new_entities=entities,
1638 source_text=prepared.text,
1639 bank_id=request.bank_id,
1640 graph_store=self.graph_store,
1641 llm_provider=self.llm_provider,
1642 event_date=request.occurred_at,
1643 )
1644 except Exception as exc:
1645 # Resolution failures must never abort a retain — degrade gracefully.
1646 _logger.warning("entity resolution failed during retain: %s", exc)
1648 # 6. Update dedup cache with stored embeddings
1649 for mem_id, emb in zip(memory_ids, embeddings):
1650 self._dedup.add(request.bank_id, mem_id, emb)
1652 # 7. Observation consolidation (fire-and-forget async task).
1653 # Runs after the retain response is returned so it never adds latency
1654 # to the caller. A failure here must never surface to the caller —
1655 # the raw memories are already stored and the consolidation is
1656 # best-effort. The representative vector is the first chunk's
1657 # embedding, which is sufficient for the observation similarity search.
1658 if self._observation_consolidator is not None and memory_ids and chunks:
1659 representative_vec = embeddings[0]
1660 consolidator = self._observation_consolidator
1661 bank_id = request.bank_id
1662 first_chunk = chunks[0]
1663 all_ids = list(memory_ids)
1664 vs = self.vector_store
1665 llm = self.llm_provider
1667 async def _run_consolidation() -> None:
1668 try:
1669 await consolidator.consolidate(
1670 new_memory_text=first_chunk,
1671 new_memory_ids=all_ids,
1672 bank_id=bank_id,
1673 vector_store=vs,
1674 llm_provider=llm,
1675 query_vector=representative_vec,
1676 scope="|".join(sorted(request.tags)) if request.tags else None,
1677 )
1678 except Exception as exc:
1679 _logger.warning("Observation consolidation task failed for bank %s: %s", bank_id, exc)
1681 task = asyncio.create_task(_run_consolidation())
1682 self._background_tasks.add(task)
1683 task.add_done_callback(self._background_tasks.discard)
1685 return RetainResult(
1686 stored=True,
1687 memory_id=memory_ids[0] if memory_ids else None,
1688 )
async def retain_many(self, requests: list[RetainRequest]) -> list[RetainResult]:
    """Retain multiple requests while batching embedding generation and vector writes.

    The work is split into phases:

    1. Per request: prepare the input, resolve chunking and produce chunks
       (structured fact extraction when it yields facts, ``chunk_text``
       otherwise).
    2. One ``generate_embeddings`` call over the chunks of every request.
    3. Per record: dedup against the cache, remap structured-fact indices
       through the dedup result, pick entities, build ``VectorItem`` objects.
    4. One ``store_vectors`` call for all items, then per-record semantic
       links and (optionally) document-store writes.
    5. Entity processing for all records concurrently via ``asyncio.gather``.
    6. Per record: update the dedup cache, schedule best-effort observation
       consolidation as a background task, and fill in the result slot.

    Args:
        requests: Retain requests to process. May be empty.

    Returns:
        One ``RetainResult`` per input request, in input order. Requests
        that produced no chunks or were fully deduplicated get
        ``stored=False`` with an explanatory ``error``.
    """

    if not requests:
        return []

    results: list[RetainResult | None] = [None] * len(requests)
    records: list[dict[str, Any]] = []
    all_chunks: list[str] = []
    profile_table = merged_user_and_builtin_profiles(self.extraction_profiles)

    for index, request in enumerate(requests):
        profile = profile_table.get(request.extraction_profile) if request.extraction_profile else None
        mip_pipeline = request.mip_pipeline
        mip_chunker = mip_pipeline.chunker if mip_pipeline else None
        prepared = prepare_retain_input(
            request,
            profile,
            graph_store_configured=self.graph_store is not None,
        )
        chunking = resolve_retain_chunking(
            prepared.effective_content_type,
            profile=profile,
            default_strategy=self.chunk_strategy,
            default_max_chunk_size=self.max_chunk_size,
            mip_chunker=mip_chunker,
        )
        chunk_kwargs: dict[str, int] = {"max_chunk_size": chunking.max_size}
        if chunking.overlap is not None:
            chunk_kwargs["overlap"] = chunking.overlap

        # Structured 5-dim fact extraction (opt-in) — same branch as
        # retain(); produces fact texts + pre-extracted entities +
        # caused_by index pairs for this record.
        # SFE uses its own chunking strategy (see retain() above).
        async with self._profiler.time("sfe"):
            (
                sfe_fact_texts,
                sfe_entities,
                sfe_associations,
                sfe_caused_by,
            ) = await self._structured_fact_extraction_for_text(
                prepared,
                request,
                chunk_strategy=self.structured_fact_extraction_chunk_strategy,
                chunk_max_size=(
                    self.structured_fact_extraction_chunk_max_size
                    if self.structured_fact_extraction_chunk_max_size is not None
                    else chunking.max_size
                ),
                chunk_overlap=chunking.overlap,
            )
        if sfe_fact_texts is not None:
            chunks = list(sfe_fact_texts)
        else:
            chunks = chunk_text(prepared.text, strategy=chunking.strategy, **chunk_kwargs)
        if not chunks:
            results[index] = RetainResult(stored=False, error="No content after chunking")
            continue

        start = len(all_chunks)
        all_chunks.extend(chunks)
        records.append(
            {
                "index": index,
                "request": request,
                "profile": profile,
                "prepared": prepared,
                "chunks": chunks,
                # [start, end) is this record's slice of the batched embeddings.
                "start": start,
                "end": len(all_chunks),
                # Carry the SFE artefacts through to the second loop and
                # _process_record_entities. None entries mean the record
                # took the legacy chunk_text + extract_entities path.
                "sfe_entities": sfe_entities,
                "sfe_associations": sfe_associations,
                "sfe_caused_by": sfe_caused_by,
            }
        )

    if not records:
        return [result or RetainResult(stored=False, error="No content after chunking") for result in results]

    async with self._profiler.time("embed"):
        all_embeddings = await generate_embeddings(all_chunks, self.llm_provider)
    stored_records: list[dict[str, Any]] = []
    all_items: list[VectorItem] = []

    for record in records:
        request = record["request"]
        profile = record["profile"]
        prepared = record["prepared"]
        chunks = record["chunks"]
        embeddings = all_embeddings[record["start"] : record["end"]]
        mip_pipeline = request.mip_pipeline
        mip_dedup = mip_pipeline.dedup if mip_pipeline else None
        dedup_threshold_override = mip_dedup.threshold if mip_dedup else None
        dedup_action = (mip_dedup.action if mip_dedup else None) or "skip_chunk"

        keep_indices: list[int] = []
        any_duplicate = False
        for chunk_index, embedding in enumerate(embeddings):
            is_dup, _sim = self._dedup.is_duplicate(
                request.bank_id,
                embedding,
                threshold_override=dedup_threshold_override,
            )
            if is_dup:
                any_duplicate = True
            # "warn" keeps duplicates; every other action drops them.
            if dedup_action == "warn" or not is_dup:
                keep_indices.append(chunk_index)

        result_index = record["index"]
        if dedup_action == "skip" and any_duplicate:
            results[result_index] = RetainResult(
                stored=False,
                deduplicated=True,
                error="Duplicate chunk(s) found; rule action=skip",
            )
            continue
        if not keep_indices:
            results[result_index] = RetainResult(
                stored=False,
                deduplicated=True,
                error="All chunks are near-duplicates",
            )
            continue

        chunks = [chunks[i] for i in keep_indices]
        embeddings = [embeddings[i] for i in keep_indices]

        # Remap SFE indices through dedup (mirrors retain() path).
        sfe_associations = record["sfe_associations"]
        sfe_caused_by = record["sfe_caused_by"]
        sfe_entities = record["sfe_entities"]
        if sfe_associations is not None or sfe_caused_by is not None:
            old_to_new = {old: new for new, old in enumerate(keep_indices)}
            if sfe_associations is not None:
                sfe_associations = [
                    (ent_idx, old_to_new[mem_idx]) for ent_idx, mem_idx in sfe_associations if mem_idx in old_to_new
                ]
            if sfe_caused_by is not None:
                sfe_caused_by = [
                    (old_to_new[src], old_to_new[tgt], conf)
                    for src, tgt, conf in sfe_caused_by
                    if src in old_to_new and tgt in old_to_new
                ]
            # Persist remapped versions for _process_record_entities.
            record["sfe_associations"] = sfe_associations
            record["sfe_caused_by"] = sfe_caused_by

        if sfe_entities is not None:
            entities = list(sfe_entities)
        elif profile is not None and profile.entity_extraction == "metadata":
            entities = _entities_from_metadata(prepared.metadata)
        elif prepared.extract_entities:
            entities = await extract_entities(prepared.text, self.llm_provider)
        else:
            entities = []

        profile_tier = profile.authority_tier if profile else None
        chunk_metadata = merge_retain_metadata_authority_tier(
            prepared.metadata,
            bank_id=request.bank_id,
            profile_authority_tier=profile_tier,
            recall_authority=self.recall_authority,
        )
        chunk_metadata = dict(chunk_metadata or {})
        if request.mip_rule_name:
            chunk_metadata["_mip.rule"] = request.mip_rule_name
        if mip_pipeline and mip_pipeline.version is not None:
            chunk_metadata["_mip.pipeline_version"] = int(mip_pipeline.version)
        chunk_metadata.setdefault("_created_at", datetime.now(timezone.utc).isoformat())

        # M10 source-aware retain — same helper as the single-retain path.
        chunk_ids = await self._provision_source_provenance(
            request,
            prepared.text,
            chunks,
        )

        memory_ids: list[str] = []
        items: list[VectorItem] = []
        for chunk, embedding, chunk_id in zip(chunks, embeddings, chunk_ids, strict=False):
            mem_id = uuid.uuid4().hex[:16]
            memory_ids.append(mem_id)
            items.append(
                VectorItem(
                    id=mem_id,
                    bank_id=request.bank_id,
                    vector=embedding,
                    text=chunk,
                    metadata=chunk_metadata,
                    tags=prepared.tags,
                    fact_type=prepared.fact_type,
                    occurred_at=request.occurred_at,
                    retained_at=datetime.now(timezone.utc),
                    chunk_id=chunk_id,  # M10: source-chunk backreference
                )
            )

        record.update(
            {
                "chunks": chunks,
                "embeddings": embeddings,
                "entities": entities,
                "memory_ids": memory_ids,
                "items": items,
            }
        )
        stored_records.append(record)
        all_items.extend(items)

    if all_items:
        async with self._profiler.time("store_vec"):
            await self.vector_store.store_vectors(all_items)

    # Hindsight-parity semantic-kNN graph (C3a) — same call as the
    # single-retain path, applied per record so each batch's
    # neighbors-search is scoped to its own bank.
    for record in stored_records:
        await self._persist_semantic_links(
            record["request"].bank_id,
            record["memory_ids"],
            record["embeddings"],
        )

    if self.document_store is not None:
        from astrocyte.types import Document

        for record in stored_records:
            request = record["request"]
            for item in record["items"]:
                try:
                    await self.document_store.store_document(
                        Document(
                            id=item.id,
                            text=item.text,
                            metadata=item.metadata,
                            tags=item.tags,
                        ),
                        request.bank_id,
                    )
                except Exception as exc:
                    # Best-effort: a document-store failure must not abort the batch.
                    _logger.warning(
                        "document_store.store_document failed for chunk %s: %s "
                        "(keyword retrieval will miss this chunk)",
                        item.id,
                        exc,
                    )

    # ── Phase 1: parallel entity processing across all records ──
    # Each record's entity work (embed names, store entities, write
    # co-occurrence links, run resolution cascade) is independent of
    # other records and dominates retain wall-clock when the bank is
    # large. ``asyncio.gather`` lets all records' LLM/DB I/O run
    # concurrently — for batches of 10 records this gives ~10× speedup
    # on the entity-resolution-bound retain phase. The non-I/O steps
    # (dedup cache, result assembly) stay sequential below.
    if self.graph_store is not None:
        entity_records = [r for r in stored_records if r["entities"]]
        if entity_records:
            await asyncio.gather(*[self._process_record_entities_with_retry(r) for r in entity_records])

    async def _run_consolidation(
        consolidator: Any,
        *,
        first_chunk: str,
        new_memory_ids: list[str],
        bank_id: str,
        vector_store: Any,
        llm_provider: Any,
        query_vector: list[float],
        scope: str | None,
    ) -> None:
        """Run one record's observation consolidation; never raises.

        Everything arrives as an argument. The scheduling loop below has no
        ``await``, so tasks only start after it has finished; a closure over
        the loop's locals would make every task see the *last* record's
        values (late binding).
        """
        try:
            await consolidator.consolidate(
                new_memory_text=first_chunk,
                new_memory_ids=new_memory_ids,
                bank_id=bank_id,
                vector_store=vector_store,
                llm_provider=llm_provider,
                query_vector=query_vector,
                scope=scope,
            )
        except Exception as exc:
            # Raw memories are already stored; consolidation is best-effort.
            _logger.warning(
                "Observation consolidation task failed for bank %s: %s",
                bank_id,
                exc,
            )

    for record in stored_records:
        rec_request = record["request"]
        rec_memory_ids: list[str] = record["memory_ids"]
        rec_embeddings: list[list[float]] = record["embeddings"]
        rec_chunks: list[str] = record["chunks"]

        for mem_id, embedding in zip(rec_memory_ids, rec_embeddings, strict=False):
            self._dedup.add(rec_request.bank_id, mem_id, embedding)

        if self._observation_consolidator is not None and rec_memory_ids and rec_chunks:
            # Arguments are evaluated now, at scheduling time, so each task
            # is bound to its own record. The representative vector is the
            # first chunk's embedding.
            task = asyncio.create_task(
                _run_consolidation(
                    self._observation_consolidator,
                    first_chunk=rec_chunks[0],
                    new_memory_ids=list(rec_memory_ids),
                    bank_id=rec_request.bank_id,
                    vector_store=self.vector_store,
                    llm_provider=self.llm_provider,
                    query_vector=rec_embeddings[0],
                    scope="|".join(sorted(rec_request.tags)) if rec_request.tags else None,
                )
            )
            # Hold a strong reference until the task finishes so it is not
            # garbage-collected mid-flight.
            self._background_tasks.add(task)
            task.add_done_callback(self._background_tasks.discard)

        results[record["index"]] = RetainResult(
            stored=True,
            memory_id=rec_memory_ids[0] if rec_memory_ids else None,
        )

    return [result or RetainResult(stored=False, error="Not retained") for result in results]
async def recall(self, request: RecallRequest) -> RecallResult:
    """Recall pipeline: embed query → parallel retrieve → fuse → rerank → budget.

    Steps, in execution order:

    1.  Embed the query; optionally short-circuit on a wiki-tier hit.
    2.  Build ``VectorFilters``; optionally narrow them with a temporal
        range extracted from the query text; look up query entities in the
        graph store; resolve the per-bank MIP pipeline; optionally generate
        a HyDE vector.
    3.  Run ``parallel_retrieve``; optionally add the ``entity_path`` and
        ``observation`` strategies.
    4.  Fuse strategies with (weighted) RRF; optionally apply link
        expansion, multi-query expansion and sibling-chunk expansion.
    5.  Rerank, 6. trim to ``max_results``, 7. convert to ``MemoryHit``,
        8. enforce the token budget.

    Args:
        request: The recall request (query, bank, filters, limits).

    Returns:
        A ``RecallResult`` with the final hits, a ``RecallTrace`` and the
        top raw semantic score (``0.0`` when the semantic strategy returned
        nothing).
    """
    # 1. Embed query
    query_embeddings = await generate_embeddings([request.query], self.llm_provider)
    query_vector = query_embeddings[0]

    # 1b. Wiki-tier precedence (M8 W5) — search compiled wiki pages first.
    # If the top hit scores above wiki_confidence_threshold, return wiki hits
    # + raw-memory citations and skip the full parallel-retrieve pipeline.
    # Caller can bypass this tier by setting fact_types (which implies they
    # want raw memories, not compiled wiki pages).
    if self.wiki_store is not None and not request.fact_types:
        wiki_result = await self._try_wiki_tier(request, query_vector)
        if wiki_result is not None:
            return wiki_result

    # 2. Build filters
    filters = VectorFilters(
        bank_id=request.bank_id,
        tags=request.tags,
        fact_types=request.fact_types,
        time_range=request.time_range,
        as_of=request.as_of,  # M9: time-travel filter
        session_id=request.session_id,  # M31 Fix 2
    )

    # 2a. Query-level temporal constraint extraction. The analyzer
    # parses expressions like "what happened in March 2024?" or
    # "last week's events" into a (start, end) range that becomes
    # the time_range filter. The analyzer is skipped entirely when
    # the request already supplies a time_range — the caller's
    # range always wins and is never combined with an extracted one.
    if (
        self.query_analyzer_enabled and request.time_range is None  # caller-supplied range wins
    ):
        try:
            from astrocyte.pipeline.query_analyzer import analyze_query

            # Resolution order for the relative-phrase anchor:
            #   1. ``query_reference_date`` if explicitly set
            #      (preferred — see types.py for rationale)
            #   2. ``as_of`` if set (back-compat: M9 audit-replay
            #      callers got temporal anchoring "for free" before
            #      the split — preserve that)
            #   3. ``None`` → analyze_query falls back to ``now()``
            anchor = request.query_reference_date or request.as_of
            analysis = await analyze_query(
                request.query,
                reference_date=anchor,
                llm_provider=self.llm_provider,
                allow_llm_fallback=self.query_analyzer_allow_llm_fallback,
                allow_temporal_expansion=self.query_analyzer_enable_temporal_expansion,
            )
            if analysis.temporal_constraint and analysis.temporal_constraint.is_bounded():
                c = analysis.temporal_constraint
                # VectorFilters.time_range is (start, end) tuple,
                # both required for the SQL/in-memory adapters.
                # Fixed far-past / far-future sentinels fill in the
                # open side of a half-open range so the adapter
                # contract is satisfied.
                far_past = datetime(1900, 1, 1, tzinfo=timezone.utc)
                far_future = datetime(2200, 1, 1, tzinfo=timezone.utc)
                # Rebuild the filters with every other field carried over.
                filters = VectorFilters(
                    bank_id=filters.bank_id,
                    tags=filters.tags,
                    fact_types=filters.fact_types,
                    time_range=(
                        c.start_date or far_past,
                        c.end_date or far_future,
                    ),
                    as_of=filters.as_of,
                    session_id=filters.session_id,  # M31 Fix 2 — preserve
                )
                _logger.info(
                    "query_analyzer extracted temporal range %s",
                    analysis.temporal_constraint,
                )
        except Exception as exc:  # pragma: no cover — defensive
            _logger.warning(
                "query_analyzer failed (%s); continuing without temporal filter.",
                exc,
            )

    # 2b. Extract entities from query for graph search
    entity_ids: list[str] | None = None
    if self.graph_store:
        query_entities = await extract_entities(request.query, self.llm_provider)
        if query_entities:
            # Look up entity IDs in the graph store (up to 3 matches per
            # extracted entity name; lookups run one after another).
            found: list[str] = []
            for ent in query_entities:
                matches = await self.graph_store.query_entities(ent.name, request.bank_id, limit=3)
                found.extend(m.id for m in matches)
            # An empty match list collapses to None.
            entity_ids = found or None

    # Resolve per-bank MIP PipelineSpec once (reused for temporal
    # half-life override at retrieval time AND for rerank + version
    # check further down). Previously resolved after retrieval — moved
    # earlier so temporal_half_life_days can flow into parallel_retrieve.
    bank_pipeline = None
    if self.mip_router is not None:
        bank_pipeline = self.mip_router.resolve_pipeline_for_bank(request.bank_id)

    # Per-bank half-life override beats the orchestrator default.
    # Long-term knowledge banks (e.g. LongMemEval-style corpora with
    # months-old answers) set this to 90+; chat banks leave it at 7.
    effective_half_life = (
        bank_pipeline.temporal_half_life_days
        if bank_pipeline is not None and bank_pipeline.temporal_half_life_days is not None
        else self.temporal_half_life_days
    )

    # 2c. HyDE (R1) — generate hypothetical document embedding.
    # Awaited here, so it completes before the retrieval step below
    # starts (it does not overlap with retrieval).
    # Failures return None; parallel_retrieve treats None as "HyDE disabled".
    hyde_vec: list[float] | None = None
    if self.enable_hyde:
        hyde_vec = await generate_hyde_vector(request.query, self.llm_provider)

    # 3. Parallel retrieval
    overfetch_limit = request.max_results * self.semantic_overfetch
    # Passed into parallel_retrieve and reported in the trace at the end.
    strategy_timings_ms: dict[str, float] = {}
    strategy_candidate_counts: dict[str, int] = {}
    strategy_results = await parallel_retrieve(
        query_vector=query_vector,
        query_text=request.query,
        bank_id=request.bank_id,
        vector_store=self.vector_store,
        graph_store=self.graph_store,
        document_store=self.document_store,
        entity_ids=entity_ids,
        limit=overfetch_limit,
        filters=filters,
        enable_temporal=self.enable_temporal_retrieval,
        temporal_scan_cap=self.temporal_scan_cap,
        temporal_half_life_days=effective_half_life,
        hyde_vector=hyde_vec,
        strategy_timings_ms=strategy_timings_ms,
        strategy_candidate_counts=strategy_candidate_counts,
        use_bm25_idf=self.bm25_idf_enabled,
    )
    # Extra "entity_path" strategy, only for queries the plan flags as
    # needing multi-hop synthesis and only when fact_types is not set.
    query_plan = build_query_plan(request.query)
    if not request.fact_types and query_plan.needs_multi_hop_synthesis:
        entity_path_hits = await self._retrieve_entity_path_fallback(
            request.query,
            request.bank_id,
            limit=overfetch_limit,
        )
        if entity_path_hits:
            strategy_results["entity_path"] = entity_path_hits

    # 3b. Intent classification — hoisted here so both observation injection
    # (step 3c) and RRF weighting (step 4) share the same result without
    # computing it twice. UNKNOWN → neutral 1.0 weights everywhere.
    query_intent = QueryIntent.UNKNOWN
    if self.enable_intent_aware_recall:
        query_intent = classify_query_intent(request.query).intent
    intent_weights = weights_for_intent(query_intent)

    # 3c. Observation strategy — intent-gated injection.
    # The ::obs bank holds distilled, multi-evidence facts synthesised from
    # raw memories. Injecting it for *every* recall degrades factual
    # precision (abstract summaries displace verbatim answers). Instead,
    # inject only for EXPLORATORY and RELATIONAL queries — the two intents
    # where synthesised behavioural patterns add value over raw memories:
    #   EXPLORATORY: "What are Alice's hobbies?" / "Describe Bob's personality"
    #   RELATIONAL:  "How does Alice's role relate to her projects?"
    # The "evidence_inference" prompt variant also selects the injection
    # weight. For all other queries, effective_obs_weight falls back to
    # the configured observation_weight (default 0.0 = disabled).
    from astrocyte.pipeline.reflect import _auto_prompt_variant

    _OBS_INJECTION_INTENTS = {QueryIntent.EXPLORATORY, QueryIntent.RELATIONAL}
    prompt_variant = _auto_prompt_variant(request.query)
    effective_obs_weight = (
        self.observation_injection_weight
        if query_intent in _OBS_INJECTION_INTENTS or prompt_variant == "evidence_inference"
        else self.observation_weight
    )
    if self._observation_consolidator is not None and not request.fact_types and effective_obs_weight > 0.0:
        obs_results = await self._retrieve_observations(
            query_vector, request.bank_id, overfetch_limit, request.as_of
        )
        if obs_results:
            strategy_results["observation"] = obs_results

    # 4. RRF fusion (local strategies + optional federated / manual external_context)
    weighted_inputs: list[tuple[list[Any], float]] = []
    for strategy, results in strategy_results.items():
        if not results:
            continue
        # Observation strategy uses the effective intent-gated weight;
        # other strategies use intent-derived weights.
        if strategy == "observation":
            weight = effective_obs_weight
        else:
            # Map strategy name → weight; a strategy name with no
            # matching attribute on intent_weights gets a neutral 1.0.
            weight = getattr(intent_weights, strategy, 1.0)
        weighted_inputs.append((results, weight))
    if request.external_context:
        # External/proxy results use the semantic weight (they're
        # typically semantic fusions upstream).
        weighted_inputs.append(
            (memory_hits_as_scored(request.external_context), intent_weights.semantic),
        )

    # Short-circuit to plain RRF when all weights are 1.0 — keeps the
    # baseline path bit-identical for deployments that disable
    # intent-aware recall.
    if all(w == 1.0 for _, w in weighted_inputs):
        fused = rrf_fusion([r for r, _ in weighted_inputs], k=self.rrf_k)
    else:
        fused = weighted_rrf_fusion(weighted_inputs, k=self.rrf_k)

    # 4a. Link expansion (Hindsight parity, C3). After initial RRF
    # surfaces direct hits, query the three first-class memory-link
    # signals — entity overlap, semantic kNN, causal — and merge
    # the resulting candidates back into the fused set. Replaces
    # the previous BFS-hop spreading-activation path.
    if self.link_expansion_params is not None and self.graph_store is not None and fused:
        try:
            expansion_hits = await link_expansion(
                fused[: self.link_expansion_params.expansion_limit],
                bank_id=request.bank_id,
                vector_store=self.vector_store,
                graph_store=self.graph_store,
                params=self.link_expansion_params,
                tags=request.tags,
            )
        except Exception as exc:  # pragma: no cover — defensive
            _logger.warning(
                "link expansion failed (%s); continuing with direct fused hits only.",
                exc,
            )
            expansion_hits = []
        if expansion_hits:
            # Re-fuse via RRF so direct evidence keeps its precedence
            # (top-1 direct beats top-1 expansion by construction).
            fused = rrf_fusion([fused, expansion_hits], k=self.rrf_k)

    # 4b. Multi-query expansion: decompose the question into sub-questions,
    # recall for each independently (parallel), and merge all fused lists
    # via a final RRF pass. Only runs when enabled and the LLM judges the
    # question as multi-hop (len > 1). The original query's fused result is
    # always included so the expansion never discards the baseline recall.
    #
    # Two guards:
    #   (a) Empty-bank guard — if initial retrieval returned nothing,
    #       decomposition cannot help and we avoid a wasted LLM call.
    #   (b) Confidence gate — peek at the top raw semantic score *before*
    #       fusion. Cosine similarity is the cleanest signal for "did we
    #       already find the answer?": if the top hit already exceeds
    #       multi_query_confidence_threshold we skip decomposition entirely,
    #       preventing broad sub-queries from displacing the precise answer.
    #       RRF scores are rank-based and not used here.
    _top_semantic_score: float = 0.0
    if "semantic" in strategy_results and strategy_results["semantic"]:
        _top_semantic_score = strategy_results["semantic"][0].score

    if self.enable_multi_query_expansion and fused and _top_semantic_score < self.multi_query_confidence_threshold:
        from astrocyte.pipeline.multi_query import decompose_query

        sub_queries = await decompose_query(request.query, self.llm_provider)
        if len(sub_queries) > 1:

            async def _fuse_sub_query(sq: str) -> list[Any]:
                # Embed, retrieve and fuse one sub-query. Reuses the main
                # query's filters, entity_ids and half-life; no HyDE vector
                # is passed for sub-queries.
                sq_vec = (await generate_embeddings([sq], self.llm_provider))[0]
                sq_strategy_results = await parallel_retrieve(
                    query_vector=sq_vec,
                    query_text=sq,
                    bank_id=request.bank_id,
                    vector_store=self.vector_store,
                    graph_store=self.graph_store,
                    document_store=self.document_store,
                    entity_ids=entity_ids,
                    limit=overfetch_limit,
                    filters=filters,
                    enable_temporal=self.enable_temporal_retrieval,
                    temporal_scan_cap=self.temporal_scan_cap,
                    temporal_half_life_days=effective_half_life,
                    use_bm25_idf=self.bm25_idf_enabled,
                )
                # Each sub-query gets its own intent classification.
                sq_intent = (
                    classify_query_intent(sq).intent if self.enable_intent_aware_recall else QueryIntent.UNKNOWN
                )
                sq_weights = weights_for_intent(sq_intent)
                sq_weighted = [
                    (res, getattr(sq_weights, strat, 1.0)) for strat, res in sq_strategy_results.items() if res
                ]
                if not sq_weighted:
                    return []
                if all(w == 1.0 for _, w in sq_weighted):
                    return rrf_fusion([r for r, _ in sq_weighted], k=self.rrf_k)
                return weighted_rrf_fusion(sq_weighted, k=self.rrf_k)

            # sub_queries[0] is the original (already in fused); expand the rest
            sub_fused_lists = await asyncio.gather(*[_fuse_sub_query(sq) for sq in sub_queries[1:]])
            non_empty = [sf for sf in sub_fused_lists if sf]
            if non_empty:
                fused = rrf_fusion([fused, *non_empty], k=self.rrf_k)

    # 4c. M10 chunk expansion — for each top-K hit with a chunk_id,
    # fetch sibling chunks (other vectors from the same SourceDocument)
    # and merge them into the candidate pool. Helps multi-hop / split-
    # evidence questions where the answer key is in chunk N±1 of a
    # chunk that hit. No-op when source_store is unwired or the flag
    # is off, or when the vector store doesn't expose
    # ``get_by_chunk_ids`` (older adapters degrade gracefully).
    if (
        self.source_chunk_expansion
        and self.source_store is not None
        and hasattr(self.vector_store, "get_by_chunk_ids")
    ):
        fused = await self._expand_via_sibling_chunks(fused, request.bank_id)

    # 5. Reranking — apply per-bank MIP RerankSpec when a rule targets this bank (P3)
    # (bank_pipeline already resolved above for temporal half-life override)
    mip_rerank = bank_pipeline.rerank if bank_pipeline is not None else None
    reranked = basic_rerank(fused, request.query, mip_rerank=mip_rerank)
    if self.final_rerank_mode == "llm_pairwise":
        reranked = await llm_pairwise_rerank(
            reranked,
            request.query,
            self.llm_provider,
            top_n=self.final_rerank_top_n,
            # A falsy keep_n setting means "keep everything".
            keep_n=self.final_rerank_keep_n or len(reranked),
        )
    else:
        reranked = apply_context_diversity(reranked, request.query)

    # 6. Trim to max_results
    trimmed = reranked[: request.max_results]

    # 7. Convert to MemoryHit
    hits = [
        MemoryHit(
            text=item.text,
            score=item.score,
            fact_type=item.fact_type,
            metadata=item.metadata,
            tags=item.tags,
            memory_id=item.id,
            bank_id=request.bank_id,
            # getattr with a default: these attributes are treated as optional.
            occurred_at=getattr(item, "occurred_at", None),
            retained_at=getattr(item, "retained_at", None),  # M9
            chunk_id=getattr(item, "chunk_id", None),  # M10
        )
        for item in trimmed
    ]

    # 7b. Version-drift warning (Phase 2, Step 10)
    # Compare each hit's persisted ``_mip.pipeline_version`` against the
    # version currently configured for this bank's rule. Stale hits indicate
    # rule changes since retain — operator should re-index or accept drift.
    _warn_on_version_drift(bank_pipeline, hits, request.bank_id)

    # 8. Token budget
    truncated = False
    if request.max_tokens:
        hits, truncated = enforce_token_budget(hits, request.max_tokens)

    # Trace bookkeeping: caller-supplied external context is reported as
    # an extra "proxy" strategy and counted among the candidates.
    ext_n = len(request.external_context) if request.external_context else 0
    strategies_used = list(strategy_results.keys())
    if ext_n:
        strategies_used.append("proxy")
    total_candidates = sum(len(r) for r in strategy_results.values()) + ext_n

    return RecallResult(
        hits=hits,
        # Size of the fused pool before reranking and trimming.
        total_available=len(fused),
        truncated=truncated,
        trace=RecallTrace(
            strategies_used=strategies_used,
            total_candidates=total_candidates,
            fusion_method="rrf",
            # Empty dicts are reported as None.
            strategy_timings_ms=strategy_timings_ms or None,
            strategy_candidate_counts=strategy_candidate_counts or None,
        ),
        top_semantic_score=_top_semantic_score,
    )
async def _retrieve_observations(
    self,
    query_vector: list[float],
    bank_id: str,
    limit: int,
    as_of: Any | None,
) -> list[Any]:
    """Search the observation bank and adapt its hits for RRF fusion.

    Observations live in their own bank (``{bank_id}::obs``, resolved via
    ``obs_bank_id``), separate from the raw memory bank, so the regular
    semantic/keyword/temporal strategies never see them and nothing is
    counted twice. Only this method queries that bank.

    Args:
        query_vector: Embedding of the query.
        bank_id: Raw bank id; the observation bank id is derived from it.
        limit: Maximum number of observation hits to request.
        as_of: Optional point-in-time filter forwarded to the store.

    Returns:
        ``ScoredItem`` objects tagged ``memory_layer="observation"``, or an
        empty list when the vector store search raises (logged as a warning).
    """
    from astrocyte.pipeline.fusion import ScoredItem
    from astrocyte.pipeline.observation import obs_bank_id

    target_bank = obs_bank_id(bank_id)
    scoped_filters = VectorFilters(bank_id=target_bank, as_of=as_of)
    try:
        vector_hits = await self.vector_store.search_similar(
            query_vector,
            target_bank,
            limit=limit,
            filters=scoped_filters,
        )
    except Exception as exc:
        # Observation retrieval is one strategy among several; a failure
        # here must not take down the whole recall.
        _logger.warning("Observation retrieval failed for bank %s: %s", bank_id, exc)
        return []

    scored: list[Any] = []
    for vector_hit in vector_hits:
        scored.append(
            ScoredItem(
                id=vector_hit.id,
                text=vector_hit.text,
                score=vector_hit.score,
                fact_type=vector_hit.fact_type,
                metadata=vector_hit.metadata,
                tags=vector_hit.tags,
                memory_layer="observation",
                retained_at=getattr(vector_hit, "retained_at", None),
            )
        )
    return scored
async def _retrieve_entity_path_fallback(
    self,
    query: str,
    bank_id: str,
    *,
    limit: int,
) -> list[ScoredItem]:
    """In-memory entity-path recall when no graph backend is configured.

    Name candidates are pulled from ``query`` with ``_deterministic_names``
    and intersected with the names found in each scanned memory's text and
    in its ``locomo_persons`` / ``locomo_speakers`` / ``person`` metadata.
    Matching memories are returned with a synthetic score of
    ``0.65 + 0.05 * min(matches, 3)`` and the matched names recorded under
    ``_entity_path`` / ``_entity_path_kind`` in a copy of their metadata.

    Args:
        query: User query to extract names from.
        bank_id: Bank to scan.
        limit: Maximum number of hits to return; also sizes the scan window
            (``max(limit * 10, 200)`` vectors from offset 0).

    Returns:
        Up to ``limit`` hits sorted by descending score. Empty when the
        query has no name candidates or the vector store scan raises.
    """
    query_names = _deterministic_names(query)
    if not query_names:
        return []
    try:
        items = await self.vector_store.list_vectors(bank_id, offset=0, limit=max(limit * 10, 200))
    except Exception:
        # Best-effort strategy: stay non-fatal, but leave a trace so a
        # misbehaving store is diagnosable instead of silently ignored.
        _logger.debug("Entity-path fallback scan failed for bank %s", bank_id, exc_info=True)
        return []

    hits: list[ScoredItem] = []
    for item in items:
        metadata = dict(item.metadata or {})
        text_names = _deterministic_names(item.text)

        metadata_names: set[str] = set()
        for key in ("locomo_persons", "locomo_speakers", "person"):
            value = metadata.get(key)
            if not value:
                continue
            if isinstance(value, (list, tuple, set, frozenset)):
                # Collections are taken element-wise (as
                # ``_entities_from_metadata`` does for lists); stringifying
                # the whole list would yield tokens like "['alice'".
                parts = [str(part) for part in value if part is not None]
            else:
                parts = str(value).replace("|", ",").split(",")
            metadata_names.update(part.strip().lower() for part in parts if part.strip())

        matched = query_names & (text_names | metadata_names)
        if not matched:
            continue
        metadata["_entity_path"] = " -> ".join(sorted(matched))
        metadata["_entity_path_kind"] = "metadata_fallback"
        hits.append(
            ScoredItem(
                id=item.id,
                text=item.text,
                score=0.65 + min(len(matched), 3) * 0.05,
                fact_type=item.fact_type,
                metadata=metadata,
                tags=item.tags,
                memory_layer=item.memory_layer,
                occurred_at=item.occurred_at,
                retained_at=item.retained_at,
            )
        )
    return sorted(hits, key=lambda hit: hit.score, reverse=True)[:limit]
async def _try_wiki_tier(
    self,
    request: RecallRequest,
    query_vector: list[float],
) -> RecallResult | None:
    """Search compiled wiki pages and return hits if the top score meets the threshold.

    Returns ``None`` when:
    - No wiki pages exist in the bank.
    - The top wiki score is below ``wiki_confidence_threshold``.
    - The vector store raises, either during the wiki search or while
      looking up citations (wiki tier is non-fatal; full recall continues).

    When a wiki hit is returned the result also includes raw-memory citations
    derived from ``_wiki_source_ids`` stored in VectorItem metadata during
    compile. The citations are returned as additional MemoryHits with
    ``memory_layer="raw"`` and ``score=0.0`` so callers can distinguish
    synthesised wiki content from the underlying evidence. Each cited
    memory appears at most once, in first-citation order.
    """
    try:
        wiki_filters = VectorFilters(
            bank_id=request.bank_id,
            tags=request.tags,
            fact_types=["wiki"],
        )
        wiki_hits = await self.vector_store.search_similar(
            query_vector,
            request.bank_id,
            limit=request.max_results,
            filters=wiki_filters,
        )
    except Exception:
        _logger.debug("Wiki-tier search failed; falling back to standard recall", exc_info=True)
        return None

    if not wiki_hits or wiki_hits[0].score < self.wiki_confidence_threshold:
        return None

    # Convert wiki VectorHits → MemoryHits
    hits: list[MemoryHit] = [
        MemoryHit(
            text=h.text,
            score=h.score,
            fact_type=h.fact_type,
            metadata=h.metadata,
            tags=h.tags,
            memory_id=h.id,
            bank_id=request.bank_id,
            memory_layer="compiled",
        )
        for h in wiki_hits
    ]

    # Collect raw-memory citation ids from the wiki metadata. Parsing goes
    # through ``_source_ids_from_metadata`` so comma strings, JSON lists and
    # plain lists are all accepted and whitespace/empty segments are
    # dropped; ids cited by several pages are kept once.
    citation_ids: list[str] = []
    seen_citations: set[str] = set()
    for h in wiki_hits:
        cited = (h.metadata or {}).get("_wiki_source_ids")
        for cid in _source_ids_from_metadata({"_wiki_source_ids": cited}):
            if cid not in seen_citations:
                seen_citations.add(cid)
                citation_ids.append(cid)

    if citation_ids:
        # Fetch raw memories by scanning the vector store.
        # We use list_vectors (paginated) to find matches by ID — the
        # VectorStore SPI has no get_by_ids, so we scan once and filter.
        # NOTE(review): unlike ``_expand_reflect_sources`` this lookup does
        # not apply ``request.tags`` to the cited memories — confirm that
        # citations are meant to bypass tag scoping.
        raw_map: dict[str, VectorItem] = {}
        offset = 0
        batch = 100
        target = set(citation_ids)
        try:
            while target:
                chunk = await self.vector_store.list_vectors(request.bank_id, offset=offset, limit=batch)
                if not chunk:
                    break
                for item in chunk:
                    if item.id in target:
                        raw_map[item.id] = item
                        target.discard(item.id)
                if len(chunk) < batch:
                    break
                offset += batch
        except Exception:
            # Honour the non-fatal contract: a store failure during the
            # citation scan falls back to standard recall instead of
            # propagating out of the wiki tier.
            _logger.debug("Wiki-tier citation lookup failed; falling back to standard recall", exc_info=True)
            return None

        for cid in citation_ids:
            item = raw_map.get(cid)
            if item is None:
                continue
            hits.append(
                MemoryHit(
                    text=item.text,
                    score=0.0,  # citations are provenance, not ranked hits
                    fact_type=item.fact_type,
                    metadata=item.metadata,
                    tags=item.tags,
                    memory_id=item.id,
                    bank_id=request.bank_id,
                    memory_layer="raw",
                )
            )

    # Apply token budget if requested
    truncated = False
    if request.max_tokens:
        hits, truncated = enforce_token_budget(hits, request.max_tokens)

    return RecallResult(
        hits=hits[: request.max_results],
        total_available=len(wiki_hits),
        truncated=truncated,
        trace=RecallTrace(
            strategies_used=["wiki"],
            total_candidates=len(wiki_hits),
            fusion_method="wiki_tier",
            wiki_tier_used=True,
        ),
    )
async def reflect(self, request: ReflectRequest) -> ReflectResult:
    """Reflect pipeline: recall → LLM synthesis.

    Stages, in execution order:

    1. Build a query plan, run ``recall`` with the plan's
       ``recall_max_results``, rerank the hits, expand cited sources and
       (when ``request.max_tokens`` is set) re-apply the token budget.
    2. Assemble the authority context from entity-path evidence and, when
       recall authority is enabled for reflect, its authority context.
    3. Resolve the ``ReflectSpec``: an explicit MIP prompt wins; otherwise
       a prompt variant is chosen automatically, upgraded to
       ``"evidence_strict"`` when the top semantic score is below 0.5.
    4. Optional adversarial gates that return early with ``sources=[]``:
       premise verification and the score-floor abstention.
    5. Synthesis, either through the agentic loop (when
       ``self.agentic_reflect_params`` is set) or a single ``synthesize``
       call.

    Args:
        request: The reflect request (query, bank, tags, token budget,
            ``as_of`` / ``query_reference_date`` and dispositions).

    Returns:
        The ``ReflectResult`` produced by synthesis, or a short-circuit
        result with an explanatory answer and no sources when a gate fires.
    """
    query_plan = build_query_plan(request.query)

    # 1. Run recall with larger result set. Aggregate/multi-hop queries need
    # more candidate memories so synthesis can combine facts instead of
    # answering from the first plausible hit.
    recall_request = RecallRequest(
        query=request.query,
        bank_id=request.bank_id,
        max_results=query_plan.recall_max_results,
        max_tokens=request.max_tokens,
        tags=request.tags,
        as_of=request.as_of,
        query_reference_date=request.query_reference_date,
    )
    recall_result = await self.recall(recall_request)
    # Rerank first, then append the raw sources cited by the surviving hits.
    expanded_hits = await self._expand_reflect_sources(
        request.bank_id,
        self._rank_reflect_context(
            request.query,
            recall_result.hits,
            limit=query_plan.reflect_rank_limit,
        ),
        limit=query_plan.reflect_expand_limit,
        tags=request.tags,
    )
    recall_result.hits = expanded_hits
    # Expansion can add hits, so the token budget is enforced a second time.
    if request.max_tokens:
        recall_result.hits, expanded_truncated = enforce_token_budget(
            recall_result.hits,
            request.max_tokens,
        )
        recall_result.truncated = recall_result.truncated or expanded_truncated

    path_ctx = self._entity_path_authority_context(recall_result.hits)
    auth_ctx: str | None = path_ctx
    ra = self.recall_authority
    if ra and ra.enabled and ra.apply_to_reflect:
        recall_result = apply_recall_authority(recall_result, ra)
        # Join whichever of the two contexts are non-empty.
        auth_ctx = "\n\n".join(part for part in (path_ctx, recall_result.authority_context) if part)

    # 2. Resolve per-bank ReflectSpec from MIP (Phase 2, Step 9)
    mip_reflect = None
    if self.mip_router is not None:
        bank_pipeline = self.mip_router.resolve_pipeline_for_bank(request.bank_id)
        if bank_pipeline is not None:
            mip_reflect = bank_pipeline.reflect

    # 2b. Auto-select a prompt when no MIP prompt override is set. Priority
    # (highest → lowest):
    #   1. MIP explicit prompt (always wins — never overridden here)
    #   2. Evidence-strict gate: when top raw semantic score < threshold,
    #      retrieval is uncertain. Force citation to prevent the LLM from
    #      constructing answers from tangential memories — the primary
    #      adversarial failure mode in open-domain benchmarks. Uses cosine
    #      similarity from recall(), which is set to 0.0 when no semantic
    #      results were found.
    #   3. query_plan prompt variant (pre-retrieval query-shape routing)
    #   4. _auto_prompt_variant fallback (legacy temporal/inference heuristic)
    _EVIDENCE_STRICT_THRESHOLD = 0.5
    if mip_reflect is None or mip_reflect.prompt is None:
        from astrocyte.mip.schema import ReflectSpec
        from astrocyte.pipeline.reflect import _auto_prompt_variant

        prompt_variant = query_plan.prompt_variant or _auto_prompt_variant(request.query)

        # Evidence-strict override: weak retrieval scores mean the top
        # semantic match is only tangentially related — upgrade to citation
        # mode so the LLM admits uncertainty instead of hallucinating.
        # Skip if query_plan already chose evidence_strict (adversarial
        # query shape) — no need to double-set.
        if recall_result.top_semantic_score < _EVIDENCE_STRICT_THRESHOLD and prompt_variant != "evidence_strict":
            prompt_variant = "evidence_strict"

        if prompt_variant is not None:
            if mip_reflect is None:
                mip_reflect = ReflectSpec(prompt=prompt_variant)
            else:
                # Preserve any other MIP settings (e.g. promote_metadata)
                mip_reflect = ReflectSpec(
                    prompt=prompt_variant,
                    promote_metadata=mip_reflect.promote_metadata,
                )

    # 2b1. Adversarial-defense gate: premise verification.
    # Decompose the question into atomic claims, verify each
    # against memory. Short-circuit to "insufficient evidence"
    # when ANY presupposition fails. Targets false-premise and
    # negative-existence adversarial questions.
    if self.adversarial_premise_verification_enabled and self.llm_provider is not None:
        try:
            from astrocyte.pipeline.premise_verification import verify_question

            # Recall callback handed to the verifier; it reuses the outer
            # request's bank, tags, token budget and temporal context.
            async def _premise_recall(claim: str, max_results: int) -> list[MemoryHit]:
                sub_request = RecallRequest(
                    query=claim,
                    bank_id=request.bank_id,
                    max_results=max_results,
                    max_tokens=request.max_tokens,
                    tags=request.tags,
                    as_of=request.as_of,
                    query_reference_date=request.query_reference_date,
                )
                sub_result = await self.recall(sub_request)
                return sub_result.hits

            verification = await verify_question(
                request.query,
                recall_fn=_premise_recall,
                llm_provider=self.llm_provider,
                min_confidence=self.adversarial_premise_min_confidence,
            )
            short_circuit = verification.short_circuit_message()
            if short_circuit is not None:
                _logger.info(
                    "reflect: premise verification short-circuit — %s",
                    short_circuit,
                )
                return ReflectResult(answer=short_circuit, sources=[])
        except Exception as exc:
            # The guard is advisory: any failure is logged and reflect
            # carries on as if the guard were disabled.
            _logger.warning(
                "premise verification failed (%s); continuing without the guard.",
                exc,
            )

    # 2c. Adversarial-defense gate: score-floor abstention.
    # Distinct from the evidence-strict prompt switch above.
    # ``evidence_strict`` keeps invoking the LLM with a tighter
    # prompt; the abstention floor short-circuits BEFORE any LLM
    # call when retrieval is so weak that the question almost
    # certainly has no answer in memory. Targets adversarial
    # questions (negative existence, false premise, time-shift)
    # where the LLM left to its own devices invents an answer.
    #
    # The effective floor is now derived from the request's
    # ``dispositions.skepticism`` (1=trust→never abstain,
    # 5=skeptical→aggressive abstention) rather than the legacy
    # global ``abstention_enabled`` bool. This lets one deployment
    # serve both adversarial-resistant agents and answer-everything
    # assistants from a single config — see
    # ``_abstention_floor_for_skepticism`` for the mapping. Legacy
    # ``abstention_enabled`` remains as a fallback when no
    # per-call dispositions are supplied.
    skepticism = _resolve_skepticism_for_abstention(
        request.dispositions,
        self.adversarial_abstention_enabled,
    )
    effective_floor = _abstention_floor_for_skepticism(
        skepticism,
        self.adversarial_abstention_floor,
    )
    # Abstain only when the floor is active, the top semantic score is
    # below it, and none of the first five hits reaches it either
    # (a missing hit score counts as 0.0).
    if (
        effective_floor > 0.0
        and recall_result.top_semantic_score < effective_floor
        and (not recall_result.hits or all((h.score or 0.0) < effective_floor for h in recall_result.hits[:5]))
    ):
        _logger.info(
            "reflect: abstention floor triggered (top_semantic=%.3f < "
            "%.3f, skepticism=%d); returning 'insufficient evidence' "
            "without LLM call.",
            recall_result.top_semantic_score,
            effective_floor,
            skepticism,
        )
        return ReflectResult(
            answer="insufficient evidence: no memory supports this question.",
            sources=[],
        )

    # 3. Synthesize. Two paths:
    #
    #   (a) Agentic loop (Hindsight parity) — when configured, the LLM
    #       selects between ``recall`` and ``done`` over up to N
    #       iterations. Targets multi-hop / open-domain where a
    #       single-shot recall misses bridge memories. Each loop
    #       ``recall`` reuses the full upgraded recall pipeline (RRF
    #       + spread + cross-encoder rerank + scope tags).
    #
    #   (b) Single-shot synthesis — original path. Faster, simpler.
    synthesize_kwargs = dict(
        dispositions=request.dispositions,
        # Falls back to 2048 when the request sets no token budget.
        max_tokens=request.max_tokens or 2048,
        authority_context=auth_ctx,
        mip_reflect=mip_reflect,
        # Forward the relative-phrase anchor so the synthesis
        # prompt's ``<reference_date>`` block reflects the
        # question's contemporaneous date, not the run wall-clock.
        query_reference_date=request.query_reference_date,
    )
    if self.agentic_reflect_params is not None:
        from astrocyte.pipeline.agentic_reflect import agentic_reflect

        # NOTE(review): builds the same RecallRequest as ``_premise_recall``
        # above; the two closures could share one helper.
        async def _loop_recall(query: str, max_results: int) -> list[MemoryHit]:
            # Reuses everything we just shipped: spread, cross-
            # encoder rerank, tag scoping, RRF — same recall the
            # outer step ran, parameterized by the agent's refined
            # query and its requested max_results. Forwards the
            # outer ``as_of`` (audit replay) AND
            # ``query_reference_date`` (relative-phrase anchor) so
            # every sub-recall the agent makes sees the same
            # temporal context as the seed reflect.
            sub_request = RecallRequest(
                query=query,
                bank_id=request.bank_id,
                max_results=max_results,
                max_tokens=request.max_tokens,
                tags=request.tags,
                as_of=request.as_of,
                query_reference_date=request.query_reference_date,
            )
            sub_result = await self.recall(sub_request)
            return sub_result.hits

        # ``search_observations`` tool — searches the consolidated
        # observation layer when the consolidator is wired up. The
        # ::obs bank scope is reused to mirror how observations are
        # stored at retain time.
        observations_fn = None
        if self._observation_consolidator is not None:

            async def _loop_observations(query: str, max_results: int) -> list[MemoryHit]:
                qvec_batch = await generate_embeddings([query], self.llm_provider)
                # An empty embedding batch degrades to an empty query vector.
                qvec = qvec_batch[0] if qvec_batch else []
                # ``request.as_of`` (when set) propagates through to
                # observation search so time-travel reflect sees the
                # same observation snapshot as raw recall. When
                # unset, search uses the bank's current state.
                obs_results = await self._retrieve_observations(
                    qvec,
                    request.bank_id,
                    max_results,
                    request.as_of,
                )
                # Convert ScoredItem → MemoryHit so the loop can
                # cite IDs uniformly across tools.
                return [
                    MemoryHit(
                        text=item.text,
                        score=item.score,
                        fact_type=item.fact_type,
                        metadata=item.metadata,
                        tags=item.tags,
                        memory_id=item.id,
                        bank_id=request.bank_id,
                        memory_layer="observation",
                    )
                    for item in (obs_results or [])
                ]

            observations_fn = _loop_observations

        # ``expand`` tool — fetch source memories cited by a
        # compiled fact / wiki / observation. Reuses the existing
        # source-expansion path with a single-id seed.
        async def _loop_expand(memory_id: str, max_sources: int) -> list[MemoryHit]:
            # Find the seed hit in the running pool to expand from.
            seed: MemoryHit | None = next(
                (h for h in recall_result.hits if h.memory_id == memory_id),
                None,
            )
            if seed is None:
                return []
            # ``+ 1`` leaves room for the seed, which is removed again below.
            expanded = await self._expand_reflect_sources(
                request.bank_id,
                [seed],
                limit=max(1, max_sources) + 1,
                tags=request.tags,
            )
            # Drop the seed itself so the model sees only NEW evidence.
            return [h for h in expanded if h.memory_id != memory_id]

        # ``search_mental_models`` tool — top of the hierarchy.
        # Forwards to ``MentalModelStore.list(bank_id, scope=...)``,
        # converts each MentalModel to a MemoryHit so the agent can
        # cite ``model_id`` like any other ``memory_id``. Returns
        # all models in scope (no semantic ranking — mental models
        # are usually a handful per bank, the LLM picks).
        mental_models_fn = None
        if self.mental_model_service is not None:

            async def _loop_mental_models(
                query: str,
                scope: str | None,
            ) -> list[MemoryHit]:
                # ``query`` is accepted for the tool signature but not used:
                # the listing is filtered by ``scope`` only.
                try:
                    models = await self.mental_model_service.list(  # type: ignore[union-attr]
                        request.bank_id,
                        scope=scope,
                    )
                except Exception as exc:
                    _logger.warning(
                        "agentic_reflect.mental_models list failed (%s)",
                        exc,
                    )
                    return []
                return [
                    MemoryHit(
                        text=f"# {m.title}\n\n{m.content}",
                        score=1.0,
                        fact_type="mental_model",
                        metadata={
                            "scope": m.scope,
                            "revision": m.revision,
                            "refreshed_at": m.refreshed_at.isoformat(),
                        },
                        memory_id=m.model_id,
                        bank_id=m.bank_id,
                        memory_layer="mental_model",
                    )
                    for m in models
                ]

            mental_models_fn = _loop_mental_models

        return await agentic_reflect(
            request.query,
            initial_hits=recall_result.hits,
            recall_fn=_loop_recall,
            observations_fn=observations_fn,
            expand_fn=_loop_expand,
            mental_models_fn=mental_models_fn,
            llm_provider=self.llm_provider,
            params=self.agentic_reflect_params,
            final_synthesize_fn=synthesize,
            final_synthesize_kwargs=synthesize_kwargs,
        )
    return await synthesize(
        query=request.query,
        hits=recall_result.hits,
        llm_provider=self.llm_provider,
        **synthesize_kwargs,
    )
async def shutdown(self) -> None:
    """Drain background work and close provider resources owned by the pipeline.

    Background tasks get up to two seconds to finish; whatever is still
    running after that is cancelled and awaited. Afterwards the vector,
    graph and document stores are closed through their ``close`` attribute
    when they have one (sync or awaitable). A failing ``close`` is logged
    and does not stop the remaining providers from being closed.
    """
    background = self._background_tasks
    if background:
        _finished, unfinished = await asyncio.wait(background, timeout=2.0)
        if unfinished:
            for straggler in unfinished:
                straggler.cancel()
            # Await the cancelled tasks so their cancellation is observed;
            # exceptions are collected rather than raised.
            await asyncio.gather(*unfinished, return_exceptions=True)

    for owned in (self.vector_store, self.graph_store, self.document_store):
        closer = getattr(owned, "close", None)
        if closer is not None:
            try:
                outcome = closer()
                if inspect.isawaitable(outcome):
                    _ = await outcome
            except Exception as exc:
                _logger.warning("provider close failed during pipeline shutdown: %s", exc)
2938 def _rank_reflect_context(
2939 self,
2940 query: str,
2941 hits: list[MemoryHit],
2942 *,
2943 limit: int,
2944 ) -> list[MemoryHit]:
2945 """Apply final precision rerank and hierarchy before synthesis.
2947 Uses the configured cross-encoder reranker (Hindsight parity) when
2948 ``self.cross_encoder`` is set; otherwise falls back to the
2949 deterministic heuristic ``cross_encoder_like_rerank``. The
2950 heuristic remains the default so reflect stays dependency-free
2951 unless the operator opts into the cross-encoder backend.
2952 """
2953 if not hits:
2954 return hits
2956 items = [
2957 ScoredItem(
2958 id=h.memory_id or f"hit-{idx}",
2959 text=h.text,
2960 score=h.score,
2961 fact_type=h.fact_type,
2962 metadata=h.metadata,
2963 tags=h.tags,
2964 memory_layer=h.memory_layer,
2965 occurred_at=h.occurred_at,
2966 retained_at=h.retained_at,
2967 )
2968 for idx, h in enumerate(hits)
2969 ]
2971 if self.cross_encoder is not None:
2972 try:
2973 scored = cross_encoder_rerank(
2974 items,
2975 query,
2976 model=self.cross_encoder,
2977 top_k=self.cross_encoder_top_k,
2978 )
2979 except Exception as exc: # pragma: no cover — defensive
2980 _logger.warning(
2981 "cross-encoder rerank failed (%s); falling back to heuristic.",
2982 exc,
2983 )
2984 scored = cross_encoder_like_rerank(items, query)
2985 else:
2986 scored = cross_encoder_like_rerank(items, query)
2988 hit_by_id = {h.memory_id or f"hit-{idx}": h for idx, h in enumerate(hits)}
2989 return [hit_by_id[item.id] for item in scored[:limit] if item.id in hit_by_id]
2991 def _entity_path_authority_context(self, hits: list[MemoryHit]) -> str | None:
2992 path_lines: list[str] = []
2993 direct_lines: list[str] = []
2994 for idx, hit in enumerate(hits, 1):
2995 path = (hit.metadata or {}).get("_entity_path") if hit.metadata else None
2996 if path:
2997 path_lines.append(f"- Memory {idx}: entity_path={path}")
2998 elif hit.score >= 0.5:
2999 direct_lines.append(f"- Memory {idx}: direct_facts")
3000 if not path_lines:
3001 return None
3002 sections = ["entity_path_evidence:", *path_lines]
3003 if direct_lines:
3004 sections.extend(["direct_facts:", *direct_lines[:8]])
3005 sections.append("supporting_context: use entity-path evidence before unrelated semantic matches.")
3006 return "\n".join(sections)
3008 async def _expand_reflect_sources(
3009 self,
3010 bank_id: str,
3011 hits: list[MemoryHit],
3012 *,
3013 limit: int,
3014 tags: list[str] | None = None,
3015 ) -> list[MemoryHit]:
3016 """Append raw sources cited by top wiki/observation hits.
3018 This mirrors Hindsight's reflect loop in a bounded, non-agentic form:
3019 start from compiled/observation evidence, then expand to raw facts for
3020 grounding before synthesis.
3022 ``tags`` (optional): when set, only fetched memories carrying every
3023 listed tag are appended. Closes the leak where a tag-scoped reflect
3024 could otherwise pull cross-scope raw memories via a wiki page's
3025 ``_wiki_source_ids`` metadata.
3026 """
3027 if not hits:
3028 return hits
3030 source_ids: list[str] = []
3031 seen_sources: set[str] = set()
3032 for hit in hits:
3033 for sid in _source_ids_from_metadata(hit.metadata):
3034 if sid not in seen_sources:
3035 seen_sources.add(sid)
3036 source_ids.append(sid)
3037 if not source_ids:
3038 return hits[:limit]
3040 raw_hits = await self._fetch_memory_hits_by_id(bank_id, source_ids, tags=tags)
3041 existing_ids = {h.memory_id for h in hits if h.memory_id}
3042 expanded = list(hits)
3043 for raw in raw_hits:
3044 if raw.memory_id and raw.memory_id not in existing_ids:
3045 expanded.append(raw)
3046 existing_ids.add(raw.memory_id)
3047 if len(expanded) >= limit:
3048 break
3049 return expanded[:limit]
3051 async def _fetch_memory_hits_by_id(
3052 self,
3053 bank_id: str,
3054 ids: list[str],
3055 *,
3056 tags: list[str] | None = None,
3057 ) -> list[MemoryHit]:
3058 target = set(ids)
3059 found: dict[str, VectorItem] = {}
3060 offset = 0
3061 batch = 100
3062 # Tag scoping: a fetched memory is kept only if it carries every
3063 # tag in ``tags``. ``None``/empty disables the filter (legacy
3064 # behavior). Comparison is case-insensitive to match recall's
3065 # tag-filter convention.
3066 required_tags = {str(t).lower() for t in tags} if tags else None
3067 while target:
3068 chunk = await self.vector_store.list_vectors(bank_id, offset=offset, limit=batch)
3069 if not chunk:
3070 break
3071 for item in chunk:
3072 if item.id not in target:
3073 continue
3074 if required_tags is not None:
3075 item_tags = {str(t).lower() for t in (item.tags or [])}
3076 if not required_tags.issubset(item_tags):
3077 # ID matches but scope tags missing — drop and
3078 # mark as resolved so we don't keep scanning.
3079 target.discard(item.id)
3080 continue
3081 found[item.id] = item
3082 target.discard(item.id)
3083 if len(chunk) < batch:
3084 break
3085 offset += batch
3087 return [
3088 MemoryHit(
3089 text=item.text,
3090 score=0.0,
3091 fact_type=item.fact_type,
3092 metadata=item.metadata,
3093 tags=item.tags,
3094 memory_id=item.id,
3095 bank_id=bank_id,
3096 memory_layer=item.memory_layer or "raw",
3097 occurred_at=item.occurred_at,
3098 retained_at=item.retained_at,
3099 )
3100 for sid in ids
3101 if (item := found.get(sid)) is not None
3102 ]
3105def _source_ids_from_metadata(metadata: dict[str, Any] | None) -> list[str]:
3106 if not metadata:
3107 return []
3108 raw = metadata.get("_obs_source_ids") or metadata.get("_wiki_source_ids")
3109 if raw is None:
3110 return []
3111 if isinstance(raw, str):
3112 text = raw.strip()
3113 if not text:
3114 return []
3115 if text.startswith("["):
3116 try:
3117 parsed = json.loads(text)
3118 except json.JSONDecodeError:
3119 parsed = []
3120 if isinstance(parsed, list):
3121 return [str(x) for x in parsed if x]
3122 return [part.strip() for part in text.split(",") if part.strip()]
3123 if isinstance(raw, list):
3124 return [str(x) for x in raw if x]
3125 return []
3128def _deterministic_names(text: str) -> set[str]:
3129 return {match.group(0).strip().lower() for match in re.finditer(r"\b[A-Z][a-z]+(?:\s+[A-Z][a-z]+)?\b", text or "")}
3132def _entities_from_metadata(metadata: dict[str, Any] | None) -> list[Entity]:
3133 """Build stable entities from structured retain metadata without an LLM call."""
3135 if not metadata:
3136 return []
3137 names: set[str] = set()
3138 for key in ("locomo_persons", "locomo_speakers", "person"):
3139 value = metadata.get(key)
3140 if value is None:
3141 continue
3142 if isinstance(value, str):
3143 names.update(part.strip() for part in value.replace("|", ",").split(",") if part.strip())
3144 elif isinstance(value, list):
3145 names.update(str(part).strip() for part in value if str(part).strip())
3147 entities: list[Entity] = []
3148 for name in sorted(names):
3149 entity_id = re.sub(r"[^a-z0-9]+", "_", name.lower()).strip("_")
3150 if not entity_id:
3151 continue
3152 entities.append(
3153 Entity(
3154 id=f"person:{entity_id}",
3155 name=name,
3156 entity_type="PERSON",
3157 aliases=[],
3158 metadata={"source": "retain_metadata"},
3159 )
3160 )
3161 return entities