Coverage for astrocyte/_discovery.py: 58%

31 statements  

« prev     ^ index     » next       coverage.py v7.15.0, created at 2026-07-04 05:24 +0000

1"""Provider discovery via Python entry points (importlib.metadata).""" 

2 

3from __future__ import annotations 

4 

5import importlib 

6import importlib.metadata 

7from typing import Any 

8 

# Entry point group names — providers register under these in pyproject.toml.
# Each short key maps to its fully-qualified group name "astrocyte.<key>";
# building the mapping from the key list keeps that naming rule in one place.
ENTRY_POINT_GROUPS = {
    short_key: f"astrocyte.{short_key}"
    for short_key in (
        "vector_stores",
        "graph_stores",
        "document_stores",
        "wiki_stores",
        "mental_model_stores",
        "source_stores",
        "pageindex_stores",  # M9 section recall
        "engine_providers",
        "llm_providers",
        "outbound_transports",
        "ingest_stream_drivers",
        "ingest_poll_drivers",
    )
}

24 

25 

def discover_entry_points(group: str) -> dict[str, Any]:
    """Load every provider registered under one entry point group.

    Args:
        group: A short key from ``ENTRY_POINT_GROUPS`` (e.g. ``"vector_stores"``)
            or a raw entry point group name; anything that is not a known key
            is used verbatim as the group name.

    Returns:
        A mapping of entry point name to the object produced by loading it.
        If two installed entry points share a name, the one iterated last wins.

    Any exception raised while loading an entry point propagates unchanged.
    """
    resolved_group = ENTRY_POINT_GROUPS.get(group, group)
    return {
        entry.name: entry.load()
        for entry in importlib.metadata.entry_points(group=resolved_group)
    }

37 

38 

def resolve_provider(name: str, group: str) -> Any:
    """Resolve a single provider by entry point name or by direct import path.

    If *name* contains ``":"`` it is treated as an object reference of the form
    ``"package.module:Attr"`` or ``"package.module:Outer.Inner"`` (the same
    syntax entry points use) and imported directly; the resolved object must be
    a class. Otherwise *name* is looked up among the entry points of *group*.

    Args:
        name: Entry point name, or a ``"module:attr[.attr...]"`` import path.
        group: Short key from ``ENTRY_POINT_GROUPS`` or a raw entry point group
            name. Not consulted for direct import paths.

    Returns:
        The resolved class (import path), or the object loaded from the first
        entry point whose name matches (entry point lookup).

    Raises:
        ImportError: the module part of an import path cannot be imported
            (``ValueError`` if the module part is empty).
        AttributeError: an attribute along the import path does not exist.
        TypeError: an import path resolves to something other than a class.
        LookupError: no entry point named *name* exists in the group.
    """
    if ":" in name:
        # Direct import path
        module_path, attr_path = name.rsplit(":", 1)
        obj: Any = importlib.import_module(module_path)
        # Walk dotted attributes so nested objects ("Outer.Inner") resolve,
        # matching the entry point object-reference syntax.
        for attr in attr_path.split("."):
            obj = getattr(obj, attr)
        if not isinstance(obj, type):
            raise TypeError(f"'{name}' resolved to {type(obj).__name__}, expected a class")
        return obj

    # Entry point lookup
    ep_group = ENTRY_POINT_GROUPS.get(group, group)
    for ep in importlib.metadata.entry_points(group=ep_group):
        if ep.name == name:
            return ep.load()

    raise LookupError(f"Provider '{name}' not found in entry point group '{ep_group}'")

61 

62 

def available_providers() -> dict[str, dict[str, Any]]:
    """Discover installed providers across every known entry point group.

    Returns:
        A mapping of short group key (from ``ENTRY_POINT_GROUPS``) to the
        ``{name: loaded_object}`` mapping produced by
        ``discover_entry_points`` for that key. Groups with no installed
        providers are omitted.

    Any exception raised while loading an entry point propagates unchanged.
    """
    result: dict[str, dict[str, Any]] = {}
    # Only the short keys are needed; discover_entry_points maps each key to
    # its fully-qualified group name itself.
    for group_key in ENTRY_POINT_GROUPS:
        providers = discover_entry_points(group_key)
        if providers:
            result[group_key] = providers
    return result

73 return result