Skip to content

Commit 7470cd9

Browse files
committed
Rebase
1 parent f9886d8 commit 7470cd9

File tree

2 files changed

+44
-34
lines changed

2 files changed

+44
-34
lines changed

template/server/main.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from messaging import ContextWebSocket
1818
from stream import StreamingListJsonResponse
1919
from utils.locks import LockedMap
20-
from envs import get_envs
2120

2221
logging.basicConfig(level=logging.DEBUG, stream=sys.stdout)
2322
logger = logging.Logger(__name__)
@@ -36,11 +35,15 @@ async def lifespan(app: FastAPI):
3635
client = httpx.AsyncClient()
3736

3837
try:
39-
python_context = await create_context(client, websockets, "python", "/home/user")
38+
python_context = await create_context(
39+
client, websockets, "python", "/home/user"
40+
)
4041
default_websockets["python"] = python_context.id
4142
websockets["default"] = websockets[python_context.id]
4243

43-
javascript_context = await create_context(client, websockets, "javascript", "/home/user")
44+
javascript_context = await create_context(
45+
client, websockets, "javascript", "/home/user"
46+
)
4447
default_websockets["javascript"] = javascript_context.id
4548

4649
logger.info("Connected to default runtime")
@@ -112,6 +115,7 @@ async def post_execute(request: ExecutionRequest):
112115
ws.execute(
113116
request.code,
114117
env_vars=request.env_vars,
118+
access_token=request.headers.get("X-Access-Token", None),
115119
)
116120
)
117121

template/server/messaging.py

Lines changed: 37 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727

2828
logger = logging.getLogger(__name__)
2929

30+
3031
class Execution:
3132
def __init__(self, in_background: bool = False):
3233
self.queue = Queue[
@@ -51,13 +52,7 @@ class ContextWebSocket:
5152
_global_env_vars: Optional[Dict[StrictStr, str]] = None
5253
_cleanup_task: Optional[asyncio.Task] = None
5354

54-
def __init__(
55-
self,
56-
context_id: str,
57-
session_id: str,
58-
language: str,
59-
cwd: str
60-
):
55+
def __init__(self, context_id: str, session_id: str, language: str, cwd: str):
6156
self.language = language
6257
self.cwd = cwd
6358
self.context_id = context_id
@@ -155,13 +150,13 @@ def _set_env_vars_code(self, env_vars: Dict[StrictStr, str]) -> str:
155150
command = self._set_env_var_snippet(k, v)
156151
if command:
157152
env_commands.append(command)
158-
153+
159154
return "\n".join(env_commands)
160155

161156
def _reset_env_vars_code(self, env_vars: Dict[StrictStr, str]) -> str:
162157
"""Build environment variable cleanup code for the current language."""
163158
cleanup_commands = []
164-
159+
165160
for key in env_vars:
166161
# Check if this var exists in global env vars
167162
if self._global_env_vars and key in self._global_env_vars:
@@ -171,39 +166,39 @@ def _reset_env_vars_code(self, env_vars: Dict[StrictStr, str]) -> str:
171166
else:
172167
# Remove the variable
173168
command = self._delete_env_var_snippet(key)
174-
169+
175170
if command:
176171
cleanup_commands.append(command)
177-
172+
178173
return "\n".join(cleanup_commands)
179174

180175
def _get_code_indentation(self, code: str) -> str:
181176
"""Get the indentation from the first non-empty line of code."""
182177
if not code or not code.strip():
183178
return ""
184-
185-
lines = code.split('\n')
179+
180+
lines = code.split("\n")
186181
for line in lines:
187182
if line.strip(): # First non-empty line
188-
return line[:len(line) - len(line.lstrip())]
189-
183+
return line[: len(line) - len(line.lstrip())]
184+
190185
return ""
191186

192187
def _indent_code_with_level(self, code: str, indent_level: str) -> str:
193188
"""Apply the given indentation level to each line of code."""
194189
if not code or not indent_level:
195190
return code
196-
197-
lines = code.split('\n')
191+
192+
lines = code.split("\n")
198193
indented_lines = []
199-
194+
200195
for line in lines:
201196
if line.strip(): # Non-empty lines
202197
indented_lines.append(indent_level + line)
203198
else:
204199
indented_lines.append(line)
205-
206-
return '\n'.join(indented_lines)
200+
201+
return "\n".join(indented_lines)
207202

208203
async def _cleanup_env_vars(self, env_vars: Dict[StrictStr, str]):
209204
"""Clean up environment variables in a separate execution request."""
@@ -276,7 +271,8 @@ async def change_current_directory(
276271
async def execute(
277272
self,
278273
code: Union[str, StrictStr],
279-
env_vars: Dict[StrictStr, str] = None,
274+
env_vars: Dict[StrictStr, str],
275+
access_token: str,
280276
):
281277
message_id = str(uuid.uuid4())
282278
self._executions[message_id] = Execution()
@@ -294,28 +290,32 @@ async def execute(
294290
logger.warning(f"Cleanup task failed: {e}")
295291
finally:
296292
self._cleanup_task = None
297-
293+
298294
# Get the indentation level from the code
299295
code_indent = self._get_code_indentation(code)
300-
296+
301297
# Build the complete code snippet with env vars
302298
complete_code = code
303-
299+
304300
global_env_vars_snippet = ""
305301
env_vars_snippet = ""
306302

307303
if self._global_env_vars is None:
308-
self._global_env_vars = await get_envs()
304+
self._global_env_vars = await get_envs(access_token=access_token)
309305
global_env_vars_snippet = self._set_env_vars_code(self._global_env_vars)
310-
306+
311307
if env_vars:
312308
env_vars_snippet = self._set_env_vars_code(env_vars)
313309

314310
if global_env_vars_snippet or env_vars_snippet:
315-
indented_env_code = self._indent_code_with_level(f"{global_env_vars_snippet}\n{env_vars_snippet}", code_indent)
311+
indented_env_code = self._indent_code_with_level(
312+
f"{global_env_vars_snippet}\n{env_vars_snippet}", code_indent
313+
)
316314
complete_code = f"{indented_env_code}\n{complete_code}"
317315

318-
logger.info(f"Sending code for the execution ({message_id}): {complete_code}")
316+
logger.info(
317+
f"Sending code for the execution ({message_id}): {complete_code}"
318+
)
319319
request = self._get_execute_request(message_id, complete_code, False)
320320

321321
# Send the code for execution
@@ -329,7 +329,9 @@ async def execute(
329329

330330
# Clean up env vars in a separate request after the main code has run
331331
if env_vars:
332-
self._cleanup_task = asyncio.create_task(self._cleanup_env_vars(env_vars))
332+
self._cleanup_task = asyncio.create_task(
333+
self._cleanup_env_vars(env_vars)
334+
)
333335

334336
async def _receive_message(self):
335337
if not self._ws:
@@ -396,15 +398,19 @@ async def _process_message(self, data: dict):
396398

397399
elif data["msg_type"] == "stream":
398400
if data["content"]["name"] == "stdout":
399-
logger.debug(f"Execution {parent_msg_ig} received stdout: {data['content']['text']}")
401+
logger.debug(
402+
f"Execution {parent_msg_ig} received stdout: {data['content']['text']}"
403+
)
400404
await queue.put(
401405
Stdout(
402406
text=data["content"]["text"], timestamp=data["header"]["date"]
403407
)
404408
)
405409

406410
elif data["content"]["name"] == "stderr":
407-
logger.debug(f"Execution {parent_msg_ig} received stderr: {data['content']['text']}")
411+
logger.debug(
412+
f"Execution {parent_msg_ig} received stderr: {data['content']['text']}"
413+
)
408414
await queue.put(
409415
Stderr(
410416
text=data["content"]["text"], timestamp=data["header"]["date"]

0 commit comments

Comments
 (0)