Coverage for astrocyte/pipeline/embedding.py: 100%

26 statements  

« prev     ^ index     » next       coverage.py v7.15.0, created at 2026-07-04 05:24 +0000

1"""Embedding generation — calls LLM Provider SPI. 

2 

3Async (I/O-bound). See docs/_design/built-in-pipeline.md section 2. 

4""" 

5 

6from __future__ import annotations 

7 

8import logging 

9from typing import TYPE_CHECKING 

10 

11if TYPE_CHECKING: 

12 from astrocyte.provider import LLMProvider 

13 

14logger = logging.getLogger("astrocyte.pipeline") 

15 

16 

async def generate_embeddings(
    texts: list[str],
    llm_provider: LLMProvider,
    model: str | None = None,
) -> list[list[float]]:
    """Embed a batch of texts through the LLM provider SPI.

    The work is delegated to ``llm_provider.embed(texts, model=model)`` and
    its result is returned unchanged.

    If the provider raises ``NotImplementedError``, the condition is logged
    at error level, a ``UserWarning`` is issued, and one hash-based
    pseudo-embedding per input text is returned instead. Any other exception
    raised by the provider propagates to the caller.

    Args:
        texts: Strings to embed.
        llm_provider: Provider whose ``embed`` coroutine is awaited.
        model: Optional model identifier, forwarded to the provider as-is.

    Returns:
        The provider's embeddings, or pseudo-embeddings on fallback.
    """
    try:
        vectors = await llm_provider.embed(texts, model=model)
    except NotImplementedError:
        import warnings

        fallback_notice = (
            "Embedding provider does not support embed() — falling back to pseudo-embeddings. "
            "Semantic search will NOT work correctly. Configure a real embedding provider for production use."
        )
        # Surface the degraded mode on both channels: the application log and
        # the warnings machinery.
        logger.error(fallback_notice)
        warnings.warn(fallback_notice, UserWarning, stacklevel=2)
        return [_pseudo_embedding(item) for item in texts]
    return vectors

38 

39 

def _pseudo_embedding(text: str, dims: int = 128) -> list[float]:
    """Generate a deterministic, normalized pseudo-embedding from text.

    NOT for production — only for development when no embedding model is
    available. The vector is derived from a SHA-256 hash chain seeded with
    the encoded ``text``; it carries no semantic meaning.

    Args:
        text: Input string to hash.
        dims: Number of components in the returned vector. Must be >= 0.

    Returns:
        A list of exactly ``dims`` floats. Each raw component is one digest
        byte scaled into [0, 1]; the vector is then divided by its Euclidean
        norm unless that norm is zero (e.g. when ``dims`` is 0).

    Raises:
        ValueError: If ``dims`` is negative.
    """
    import hashlib
    import math

    if dims < 0:
        # A negative bound would otherwise be interpreted as "trim from the
        # end" by the slice below and silently yield a wrong-length vector.
        raise ValueError(f"dims must be non-negative, got {dims}")

    digest = hashlib.sha256(text.encode()).digest()
    raw = [b / 255.0 for b in digest]
    # Extend to the desired dimensions by re-hashing the previous digest, so
    # every additional 32 components come from a fresh link in the chain.
    while len(raw) < dims:
        digest = hashlib.sha256(digest).digest()
        raw.extend(b / 255.0 for b in digest)
    raw = raw[:dims]

    # Normalize to unit length; skip when the norm is zero to avoid dividing
    # by zero.
    norm = math.sqrt(sum(x * x for x in raw))
    if norm > 0:
        raw = [x / norm for x in raw]
    return raw