Coverage for astrocyte/pipeline/entity_extraction.py: 76%

62 statements  

« prev     ^ index     » next       coverage.py v7.15.0, created at 2026-07-04 05:24 +0000

1"""Entity extraction — calls LLM Provider SPI for NER. 

2 

3Async (I/O-bound). See docs/_design/built-in-pipeline.md section 2. 

4""" 

5 

6from __future__ import annotations 

7 

8import json 

9import logging 

10import uuid 

11from typing import TYPE_CHECKING 

12 

13from astrocyte.types import Entity, Message 

14 

15if TYPE_CHECKING: 

16 from astrocyte.provider import LLMProvider 

17 

# Logger registered under the "astrocyte.pipeline" name rather than this
# module's own ``__name__``.
logger = logging.getLogger("astrocyte.pipeline")

# Language tag that may directly follow an opening markdown code fence.
_JSON_FENCE_TAG = "json"

# System prompt sent ahead of the user content; one instruction per line.
_EXTRACTION_SYSTEM_PROMPT = "\n".join(
    (
        "Extract named entities from user-provided text.",
        'Return a JSON array of objects with keys: "name", "entity_type", "aliases".',
        "entity_type must be one of: PERSON, ORG, LOCATION, PRODUCT, EVENT, CONCEPT, OTHER.",
        "aliases should be an array of alternative names (empty array if none).",
        "If no entities are found, return an empty array [].",
        "Respond with ONLY the JSON array, no other text.",
    )
)

28 

29 

async def extract_entities(
    text: str,
    llm_provider: LLMProvider,
    model: str | None = None,
    max_text_length: int = 2000,
) -> list[Entity]:
    """Extract named entities from ``text`` via the LLM provider.

    The text is truncated to ``max_text_length`` characters, wrapped in
    ``<content>`` tags and sent together with the extraction system prompt.
    The completion text is then parsed by ``_parse_entities``.

    Args:
        text: Raw text to scan for entities.
        llm_provider: Provider whose ``complete`` coroutine is awaited.
        model: Optional model name, forwarded unchanged to the provider.
        max_text_length: Maximum number of leading characters of ``text``
            that are sent. Negative values are treated as 0.

    Returns:
        The parsed entities. An empty list is returned when the truncated
        text is empty or whitespace-only (no provider call is made), or
        when the provider call or the parsing fails. Failures are logged
        as warnings and never raised.
    """
    # Clamp the limit: a negative value would otherwise turn the slice into
    # "everything except the last N characters".
    snippet = text[: max(max_text_length, 0)] if text else ""
    if not snippet.strip():
        # Nothing to analyse, so skip the provider round-trip entirely.
        return []

    user_msg = f"<content>\n{snippet}\n</content>"
    try:
        completion = await llm_provider.complete(
            messages=[
                Message(role="system", content=_EXTRACTION_SYSTEM_PROMPT),
                Message(role="user", content=user_msg),
            ],
            model=model,
            max_tokens=1024,
            temperature=0.0,
        )
        return _parse_entities(completion.text)
    except Exception:
        # Deliberate best-effort boundary: extraction must never break the
        # caller, so every failure degrades to "no entities".
        logger.warning("Entity extraction failed, returning empty list", exc_info=True)
        return []

55 

56 

def _parse_entities(response: str) -> list[Entity]:
    """Parse an LLM completion into ``Entity`` objects.

    Accepted shapes, in order of processing:

    * optional markdown code fence (with an optional ``json`` tag) around
      the payload; only the first fenced block is considered;
    * arbitrary text before and after the first JSON value;
    * a JSON array of entity objects, or an object wrapping such an array
      under one of the keys ``entities``, ``results``, ``items``, ``data``.

    Each entity object must carry a non-blank string ``name``; other items
    are skipped. A missing, ``null`` or non-string ``entity_type`` becomes
    ``"OTHER"``. ``aliases`` is normalised to a list of non-blank strings
    (``null`` or an unexpected type becomes ``[]``, a bare string becomes a
    one-element list).

    Args:
        response: Raw completion text returned by the LLM.

    Returns:
        The entities that could be built, each with a fresh random hex id.
        An empty list is returned (and a warning logged) when no usable
        JSON is found or decoding fails.
    """
    try:
        # Treat a missing completion (None) like an empty one.
        text = (response or "").strip()

        # Unwrap the first markdown code fence, if present.
        if "```" in text:
            start = text.index("```") + 3
            # Skip an optional, case-insensitive "json" language tag. Slicing
            # past the end just yields a shorter string, so no bounds check
            # is needed.
            if text[start : start + len(_JSON_FENCE_TAG)].lower() == _JSON_FENCE_TAG:
                start += len(_JSON_FENCE_TAG)
            close = text.find("```", start)
            if close < 0:
                logger.warning("Malformed markdown code block in entity response")
                return []
            text = text[start:close].strip()

        # Seek to the first JSON value (array or object) and decode only
        # that: raw_decode stops after the first complete value, so
        # trailing text or a second array appended by the LLM does not
        # cause "Extra data" errors.
        candidates = [pos for pos in (text.find("["), text.find("{")) if pos >= 0]
        if not candidates:
            logger.warning("Entity extraction: no JSON found in LLM response")
            return []
        entities_data, _ = json.JSONDecoder().raw_decode(text, min(candidates))

        # Handle common LLM wrapper formats: {"entities": [...]}, {"results": [...]}
        if isinstance(entities_data, dict):
            for key in ("entities", "results", "items", "data"):
                if isinstance(entities_data.get(key), list):
                    logger.info("Entity extraction: unwrapped JSON from %r key", key)
                    entities_data = entities_data[key]
                    break
        if not isinstance(entities_data, list):
            logger.warning(
                "Entity extraction returned non-list JSON: %s",
                type(entities_data).__name__,
            )
            return []

        entities: list[Entity] = []
        for item in entities_data:
            if not isinstance(item, dict):
                continue

            # A usable entity needs a real, non-blank name.
            name = item.get("name")
            if not isinstance(name, str) or not name.strip():
                continue

            # dict.get() only applies its default when the key is absent; an
            # explicit JSON null arrives as None and must be coerced too.
            entity_type = item.get("entity_type")
            if not isinstance(entity_type, str) or not entity_type.strip():
                entity_type = "OTHER"

            raw_aliases = item.get("aliases")
            if isinstance(raw_aliases, str):
                raw_aliases = [raw_aliases]
            elif not isinstance(raw_aliases, list):
                raw_aliases = []
            aliases = [
                alias
                for alias in raw_aliases
                if isinstance(alias, str) and alias.strip()
            ]

            entities.append(
                Entity(
                    id=uuid.uuid4().hex,
                    name=name,
                    entity_type=entity_type,
                    aliases=aliases,
                )
            )
        return entities
    except ValueError:
        # json.JSONDecodeError is a ValueError subclass, so this covers
        # decoding failures as well as other value errors raised above.
        logger.warning("Entity extraction: failed to parse LLM response as JSON", exc_info=True)
        return []