-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy path__init__.py
More file actions
125 lines (102 loc) · 3.39 KB
/
__init__.py
File metadata and controls
125 lines (102 loc) · 3.39 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
#!/usr/bin/env python3
"""
Render Module
This module provides rendering capabilities for GNN specifications to various
target languages and simulation environments.
"""
# Phase 6: render submodules are in-tree; unconditional imports.
from .processor import (
get_available_renderers,
get_module_info,
process_render,
render_gnn_spec,
)
from .pomdp_processor import POMDPRenderProcessor, process_pomdp_for_frameworks
from .generators import (
generate_activeinference_jl_code,
generate_discopy_code,
generate_pymdp_code,
generate_rxinfer_code,
)
from .pymdp import render_gnn_to_pymdp
from .rxinfer import render_gnn_to_rxinfer, render_gnn_to_rxinfer_toml
from .discopy import render_gnn_to_discopy
from .activeinference_jl import render_gnn_to_activeinference_jl
from .pytorch import render_gnn_to_pytorch
from .numpyro import render_gnn_to_numpyro
from .pymdp.pymdp_renderer import PyMDPRenderer
class JAXRenderer:
    """Class-based facade around ``render_gnn_to_jax``.

    Lets callers that prefer a renderer object with a ``render`` method use
    the JAX backend that way. All actual code generation lives in
    ``render/jax/jax_renderer.py``; this class only forwards to it.
    """

    def render(self, spec) -> str:
        """Render *spec* with the JAX backend and return the output as text.

        Args:
            spec: The GNN specification, passed unchanged to
                ``render_gnn_to_jax``.

        Returns:
            The backend's result unchanged when it is already a ``str``,
            otherwise ``str()`` of that result.
        """
        # Imported inside the method, so the JAX backend module is only
        # loaded when render() is actually called.
        from .jax.jax_renderer import render_gnn_to_jax

        rendered = render_gnn_to_jax(spec)
        if isinstance(rendered, str):
            return rendered
        # NOTE(review): non-str results are stringified rather than rejected;
        # confirm that is what callers expect from render_gnn_to_jax.
        return str(rendered)
def get_supported_frameworks():
    """List the framework identifiers this package can render to.

    Returns:
        A new list of framework name strings in a fixed order. A fresh list
        is built on every call, so callers may mutate it without affecting
        later calls.
    """
    names = (
        'pymdp',
        'rxinfer',
        'activeinference_jl',
        'jax',
        'discopy',
        'pytorch',
        'numpyro',
    )
    return list(names)
def validate_render(result, framework=None):
    """Perform a minimal sanity check on a render result.

    Args:
        result: Whatever a renderer produced.
        framework: Accepted for API compatibility but not currently
            consulted; no framework-specific checks are performed.

    Returns:
        ``True`` when the check passes.

    Raises:
        ValueError: If ``result`` is ``None`` or an empty ``str``.
            Whitespace-only strings and all non-``str`` values (including
            empty containers) are accepted.
    """
    # Only two shapes are rejected; everything else passes through.
    if result is None:
        raise ValueError("Render result is None")
    is_empty_text = isinstance(result, str) and not result
    if is_empty_text:
        raise ValueError("Render result is empty string")
    return True
# Public API re-exported by this package (names used by ``from render import *``).
__all__ = [
    # Core processing entry points
    "process_render",
    "render_gnn_spec",
    "get_module_info",
    "get_available_renderers",
    # Code-generator functions
    "generate_pymdp_code",
    "generate_rxinfer_code",
    "generate_activeinference_jl_code",
    "generate_discopy_code",
    # Per-framework render functions
    "render_gnn_to_pymdp",
    "render_gnn_to_rxinfer",
    "render_gnn_to_rxinfer_toml",
    "render_gnn_to_discopy",
    "render_gnn_to_activeinference_jl",
    "render_gnn_to_pytorch",
    "render_gnn_to_numpyro",
    # Renderer classes
    "PyMDPRenderer",
    "JAXRenderer",
    # POMDP processing helpers
    "POMDPRenderProcessor",
    "process_pomdp_for_frameworks",
    # Utilities defined in this module
    "get_supported_frameworks",
    "validate_render",
]
__version__ = "1.6.0"

# Feature flags exposed by this package; every flag is currently True.
FEATURES = dict(
    pymdp_rendering=True,
    rxinfer_rendering=True,
    activeinference_jl_rendering=True,
    discopy_rendering=True,
    jax_rendering=True,
    pytorch_rendering=True,
    numpyro_rendering=True,
    mcp_integration=True,
    pomdp_processing=True,
    state_space_extraction=True,
    modular_injection=True,
    framework_specific_outputs=True,
    structured_documentation=True,
)
from .render import main # expose CLI entry as attribute for tests