Coverage for astrocyte/pipeline/entity_extraction.py: 76%
62 statements
« prev ^ index » next coverage.py v7.15.0, created at 2026-07-04 05:24 +0000
« prev ^ index » next coverage.py v7.15.0, created at 2026-07-04 05:24 +0000
1"""Entity extraction — calls LLM Provider SPI for NER.
3Async (I/O-bound). See docs/_design/built-in-pipeline.md section 2.
4"""
6from __future__ import annotations
8import json
9import logging
10import uuid
11from typing import TYPE_CHECKING
13from astrocyte.types import Entity, Message
15if TYPE_CHECKING:
16 from astrocyte.provider import LLMProvider
# Logger registered under the package-level name "astrocyte.pipeline"
# rather than this module's own dotted name.
logger = logging.getLogger("astrocyte.pipeline")

# Language tag that may directly follow an opening Markdown code fence
# (i.e. ```json); the response parser skips it before reading the fence body.
_JSON_FENCE_TAG = "json"

# System prompt sent ahead of the user content. It asks the model for a bare
# JSON array of objects with "name", "entity_type" and "aliases" keys.
_EXTRACTION_SYSTEM_PROMPT = """Extract named entities from user-provided text.
Return a JSON array of objects with keys: "name", "entity_type", "aliases".
entity_type must be one of: PERSON, ORG, LOCATION, PRODUCT, EVENT, CONCEPT, OTHER.
aliases should be an array of alternative names (empty array if none).
If no entities are found, return an empty array [].
Respond with ONLY the JSON array, no other text."""
async def extract_entities(
    text: str,
    llm_provider: LLMProvider,
    model: str | None = None,
    max_text_length: int = 2000,
) -> list[Entity]:
    """Run LLM-backed named-entity recognition over ``text``.

    The input is cut to its first ``max_text_length`` characters, wrapped in
    ``<content>`` tags and sent to ``llm_provider.complete`` together with the
    extraction system prompt (``max_tokens=1024``, ``temperature=0.0``). The
    completion text is then handed to ``_parse_entities``.

    Args:
        text: Raw text to scan for entities.
        llm_provider: Provider whose ``complete`` coroutine is awaited.
        model: Optional model identifier forwarded to the provider.
        max_text_length: Upper bound on the number of characters sent.

    Returns:
        The extracted entities. Any exception raised while building the
        messages, calling the provider or parsing is logged as a warning and
        results in an empty list.
    """
    truncated = text[:max_text_length]
    user_msg = f"<content>\n{truncated}\n</content>"
    try:
        # Message construction stays inside the try so that a failure there
        # is also reported as "extraction failed" instead of propagating.
        prompt_messages = [
            Message(role="system", content=_EXTRACTION_SYSTEM_PROMPT),
            Message(role="user", content=user_msg),
        ]
        completion = await llm_provider.complete(
            messages=prompt_messages,
            model=model,
            max_tokens=1024,
            temperature=0.0,
        )
        extracted = _parse_entities(completion.text)
    except Exception:
        logger.warning("Entity extraction failed, returning empty list", exc_info=True)
        return []
    return extracted
57def _parse_entities(response: str) -> list[Entity]:
58 """Parse LLM response into Entity objects."""
59 try:
60 # Try to find JSON array in the response
61 text = response.strip()
62 # Handle markdown code blocks
63 if "```" in text:
64 start = text.index("```") + 3
65 if start > len(text):
66 logger.warning("Malformed markdown code block in entity response")
67 return []
68 if (
69 start + len(_JSON_FENCE_TAG) <= len(text)
70 and text[start : start + len(_JSON_FENCE_TAG)].lower() == _JSON_FENCE_TAG
71 ):
72 start += len(_JSON_FENCE_TAG)
73 close = text.find("```", start)
74 if close < 0:
75 logger.warning("Malformed markdown code block in entity response")
76 return []
77 text = text[start:close].strip()
79 # Seek to the first JSON value (array or object) and parse only
80 # that — raw_decode stops after the first complete JSON token,
81 # so trailing text or a second array the LLM appended won't
82 # cause "Extra data" errors.
83 bracket = text.find("[")
84 brace = text.find("{")
85 if bracket < 0 and brace < 0:
86 logger.warning("Entity extraction: no JSON found in LLM response")
87 return []
88 if bracket < 0:
89 json_start = brace
90 elif brace < 0:
91 json_start = bracket
92 else:
93 json_start = min(bracket, brace)
94 decoder = json.JSONDecoder()
95 entities_data, _ = decoder.raw_decode(text, json_start)
96 # Handle common LLM wrapper formats: {"entities": [...]}, {"results": [...]}
97 if isinstance(entities_data, dict):
98 for key in ("entities", "results", "items", "data"):
99 if key in entities_data and isinstance(entities_data[key], list):
100 logger.info("Entity extraction: unwrapped JSON from %r key", key)
101 entities_data = entities_data[key]
102 break
103 if not isinstance(entities_data, list):
104 logger.warning("Entity extraction returned non-list JSON: %s", type(entities_data).__name__)
105 return []
107 entities: list[Entity] = []
108 for item in entities_data:
109 if not isinstance(item, dict) or "name" not in item:
110 continue
111 entities.append(
112 Entity(
113 id=uuid.uuid4().hex,
114 name=item["name"],
115 entity_type=item.get("entity_type", "OTHER"),
116 aliases=item.get("aliases", []),
117 )
118 )
119 return entities
120 except (json.JSONDecodeError, ValueError):
121 logger.warning("Entity extraction: failed to parse LLM response as JSON", exc_info=True)
122 return []