From 4a942421d0566b5e58c3031f751e66cd7d08ec70 Mon Sep 17 00:00:00 2001 From: Xiaoya Chong <150726549+xiaoyachong@users.noreply.github.com> Date: Fri, 15 Aug 2025 12:23:04 -0700 Subject: [PATCH 1/2] add smi replay simulator --- src/arroyosas/app/db_replay_sim_cli.py | 218 +++++++++++++++++++++++++ 1 file changed, 218 insertions(+) create mode 100644 src/arroyosas/app/db_replay_sim_cli.py diff --git a/src/arroyosas/app/db_replay_sim_cli.py b/src/arroyosas/app/db_replay_sim_cli.py new file mode 100644 index 0000000..1e13653 --- /dev/null +++ b/src/arroyosas/app/db_replay_sim_cli.py @@ -0,0 +1,218 @@ +import asyncio +import logging +import os +import sqlite3 +from datetime import datetime + +import typer +import zmq +import zmq.asyncio +import msgpack +from tiled.client import from_uri + +from ..config import settings +from ..schemas import ( + RawFrameEvent, + SASStart, + SASStop, + SerializableNumpyArrayModel, +) + +""" +Simulates image retrieval by reading Tiled URLs from a local SQLite database +and sends the fetched images onto ZMQ. +""" + +# Get configuration from environment variables with defaults +DEFAULT_DB_PATH = os.getenv("DB_PATH", "latent_vectors.db") +DEFAULT_API_KEY = os.getenv("TILED_API_KEY", None) + +# Setup logging +logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO) + +app = typer.Typer() + + +def get_urls_from_db(db_path, limit=None): + """Get a list of Tiled URLs from the database""" + try: + conn = sqlite3.connect(db_path) + cursor = conn.cursor() + + query = "SELECT id, tiled_url FROM vectors ORDER BY id" + if limit: + query += f" LIMIT {limit}" + + cursor.execute(query) + results = cursor.fetchall() + conn.close() + + if not results: + logger.warning(f"No Tiled URLs found in database {db_path}") + return [] + + logger.info(f"Found {len(results)} Tiled URLs in database") + return results + except Exception as e: + logger.error(f"Error reading from database: {e}") + return [] + + +def read_image_from_tiled_url(tiled_url, api_key=None): + """ + Read an image from a Tiled URL. + + Example URL: http://tiled-dev.nsls2.bnl.gov/api/v1/array/full/79b18ce7-c6a9-4c25-8b7c-7d3c5a57a536/streams/primary/pil2M_image?slice=0:1,0:1679,0:1475 + + The function: + 1. Extracts the index from the slice parameter (number before ":" in the first part) + 2. Changes "array/full" to "metadata" in the base URL + 3. Extracts the dataset URI (part after "array/full/") without query parameters + 4. Connects to the Tiled server and retrieves the image + """ + try: + # Extract index from slice parameter + index = 0 # Default index + if '?' in tiled_url and 'slice=' in tiled_url: + slice_param = tiled_url.split('slice=')[1].split('&')[0] + if ':' in slice_param: + parts = slice_param.split(',')[0].split(':') + if parts[0].isdigit(): + index = int(parts[0]) + + # Parse the URL to extract base URL and path + url_without_query = tiled_url.split('?')[0] # Remove query parameters + url_parts = url_without_query.split('/api/v1/') + + if len(url_parts) != 2: + logger.error(f"Invalid Tiled URL format: {tiled_url}") + return None, 0 + + # Change array/full to metadata + base_uri = f"{url_parts[0]}/api/v1/metadata" + + # Extract dataset URI - get everything after "array/full/" + full_path = url_parts[1] + + if 'array/full/' in url_without_query: + # If the URL contains array/full, extract the part after it + path_parts = full_path.split('array/full/') + if len(path_parts) > 1: + dataset_uri = path_parts[1] + else: + dataset_uri = full_path + else: + # If URL doesn't contain array/full, use the whole path + dataset_uri = full_path + + logger.debug(f"Base URI: {base_uri}, Dataset URI: {dataset_uri}, Index: {index}") + + # Connect to the Tiled server + client = from_uri(base_uri, api_key=api_key) + + # Access the dataset + tiled_data = client[dataset_uri] + logger.debug(f"Dataset shape: {tiled_data.shape}, dtype: {tiled_data.dtype}") + + # Retrieve the image at the specified index + image = tiled_data[index] + + return image, index + + except Exception as e: + logger.error(f"Error reading from Tiled URL {tiled_url}: {e}") + return None, 0 + + +@app.command() +def main( + db_path: str = typer.Option(DEFAULT_DB_PATH, help="Path to the SQLite database containing Tiled URLs"), + max_frames: int = typer.Option(10000, help="Maximum number of frames to process"), + api_key: str = typer.Option(DEFAULT_API_KEY, help="API key for Tiled authentication") +): + """ + Run the image simulator that reads Tiled URLs from a database, fetches the images, and publishes them via ZMQ. + + Configuration can be set via environment variables: + - DB_PATH: Path to the SQLite database + - TILED_API_KEY: API key for Tiled authentication + + Command-line arguments override environment variables. + """ + # Log the configuration + logger.info(f"Starting DB Image Simulator with:") + logger.info(f"- Database path: {db_path}") + logger.info(f"- Max frames: {max_frames}") + logger.info(f"- API key provided: {api_key is not None}") + + async def run(): + # Check if database exists + if not os.path.exists(db_path): + logger.error(f"Database file not found: {db_path}") + return + + # Setup ZMQ socket + context = zmq.asyncio.Context() + socket = context.socket(zmq.PUB) + address = settings.tiled_poller.zmq_frame_publisher.address + logger.info(f"Binding to ZMQ address: {address}") + socket.bind(address) + + # Get URLs from database + urls = get_urls_from_db(db_path, limit=max_frames) + if not urls: + logger.error("No URLs found in database, cannot continue") + return + + # Send start event + current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + start = SASStart( + width=1679, # Default values - these will be updated from real data + height=1475, + data_type="uint32", + tiled_url="db://latent_vectors", + run_name=f"db_tiled_run", + run_id=str(current_time), + ) + logger.info(f"Sending start event") + await socket.send(msgpack.packb(start.model_dump())) + + # Process each URL + for db_id, tiled_url in urls: + try: + logger.info(f"Processing URL from DB record {db_id}: {tiled_url}") + + # Read image data from Tiled URL + image_data, index = read_image_from_tiled_url(tiled_url, api_key) + + if image_data is None: + logger.error(f"Failed to read image from {tiled_url}") + continue + + # Send the frame event + event = RawFrameEvent( + image=SerializableNumpyArrayModel(array=image_data), + frame_number=index, + tiled_url=tiled_url, + ) + logger.info(f"Sending frame {index}") + await socket.send(msgpack.packb(event.model_dump())) + + # Small delay between frames + await asyncio.sleep(0.1) + + except Exception as e: + logger.error(f"Error processing frame from {tiled_url}: {e}") + + # Send stop event + stop = SASStop(num_frames=len(urls)) + logger.info(f"Sending stop event") + await socket.send(msgpack.packb(stop.model_dump())) + logger.info(f"Complete - sent {len(urls)} frames") + + asyncio.run(run()) + + +if __name__ == "__main__": + app() \ No newline at end of file From 7f984e00367bef01a49334c7f559ebfc963f1efb Mon Sep 17 00:00:00 2001 From: Xiaoya Chong <150726549+xiaoyachong@users.noreply.github.com> Date: Fri, 22 Aug 2025 13:50:57 -0700 Subject: [PATCH 2/2] update threading in replay simulator --- pyproject.toml | 1 + src/arroyosas/app/db_replay_sim_cli.py | 159 +++++++++++++++++++------ 2 files changed, 124 insertions(+), 36 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index c7cf815..ff48065 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -73,6 +73,7 @@ all = [ "pandas", "python-dotenv", "pyarrow>=14.0.1", + "aiosqlite", ] diff --git a/src/arroyosas/app/db_replay_sim_cli.py b/src/arroyosas/app/db_replay_sim_cli.py index 1e13653..bcf54c5 100644 --- a/src/arroyosas/app/db_replay_sim_cli.py +++ b/src/arroyosas/app/db_replay_sim_cli.py @@ -1,7 +1,7 @@ import asyncio import logging import os -import sqlite3 +import aiosqlite from datetime import datetime import typer @@ -26,6 +26,18 @@ # Get configuration from environment variables with defaults DEFAULT_DB_PATH = os.getenv("DB_PATH", "latent_vectors.db") DEFAULT_API_KEY = os.getenv("TILED_API_KEY", None) +# New environment variable for selecting prod vs dev environment +TILED_ENV = os.getenv("TILED_ENV", "dev").lower() + +# Define environment-specific URLs +TILED_URLS = { + "dev": { + "url_pattern": "http://tiled-dev.nsls2.bnl.gov/api/v1/array/full/" + }, + "prod": { + "url_pattern": "http://tiled.nsls2.bnl.gov/api/v1/array/full/" + } +} # Setup logging logger = logging.getLogger(__name__) @@ -34,42 +46,105 @@ app = typer.Typer() -def get_urls_from_db(db_path, limit=None): - """Get a list of Tiled URLs from the database""" +async def get_urls_from_db(db_path, limit=None): + """Get a list of Tiled URLs from the database asynchronously""" try: - conn = sqlite3.connect(db_path) - cursor = conn.cursor() - - query = "SELECT id, tiled_url FROM vectors ORDER BY id" - if limit: - query += f" LIMIT {limit}" - - cursor.execute(query) - results = cursor.fetchall() - conn.close() - - if not results: - logger.warning(f"No Tiled URLs found in database {db_path}") - return [] + async with aiosqlite.connect(db_path) as conn: + query = "SELECT id, tiled_url FROM vectors ORDER BY id" + if limit: + query += f" LIMIT {limit}" - logger.info(f"Found {len(results)} Tiled URLs in database") - return results + async with conn.execute(query) as cursor: + results = await cursor.fetchall() + + if not results: + logger.warning(f"No Tiled URLs found in database {db_path}") + return [] + + logger.info(f"Found {len(results)} Tiled URLs in database") + return results except Exception as e: logger.error(f"Error reading from database: {e}") return [] -def read_image_from_tiled_url(tiled_url, api_key=None): +def transform_url_for_env(tiled_url, env): """ - Read an image from a Tiled URL. + Transform a Tiled URL to match the specified environment format. + + Args: + tiled_url: The original Tiled URL (typically from dev environment) + env: The target environment ('dev' or 'prod') + + Returns: + str: Transformed URL for the target environment + """ + if env not in TILED_URLS: + logger.warning(f"Unknown environment '{env}', falling back to 'dev'") + env = "dev" + + # If we're staying in dev, no transformation needed + if env == "dev" and "tiled-dev.nsls2.bnl.gov" in tiled_url: + return tiled_url + + # Extract slice parameter if present + slice_param = None + if '?' in tiled_url: + slice_param = tiled_url.split('?')[1] + + # Parse the URL to extract UUID and stream path + url_without_query = tiled_url.split('?')[0] # Remove query parameters + + # Extract UUID and stream path + uuid = None + stream_path = None - Example URL: http://tiled-dev.nsls2.bnl.gov/api/v1/array/full/79b18ce7-c6a9-4c25-8b7c-7d3c5a57a536/streams/primary/pil2M_image?slice=0:1,0:1679,0:1475 + if 'array/full/' in url_without_query: + path_after_full = url_without_query.split('array/full/')[1] + parts = path_after_full.split('/') + if len(parts) >= 1: + uuid = parts[0] + if len(parts) > 1: + stream_path = '/'.join(parts[1:]) - The function: - 1. Extracts the index from the slice parameter (number before ":" in the first part) - 2. Changes "array/full" to "metadata" in the base URL - 3. Extracts the dataset URI (part after "array/full/") without query parameters - 4. Connects to the Tiled server and retrieves the image + if not uuid or not stream_path: + logger.error(f"Could not parse Tiled URL: {tiled_url}") + return tiled_url # Return original if parsing fails + + # Transform URL based on environment + if env == "prod": + # Get image name from stream_path + parts = stream_path.split('/') + if len(parts) > 0: + image_name = parts[-1] + # Format: http://tiled.nsls2.bnl.gov/api/v1/array/full/smi/raw/{uuid}/primary/data/{image_name}?slice=... + new_url = f"{TILED_URLS[env]['url_pattern']}smi/raw/{uuid}/primary/data/{image_name}" + else: + # Fallback if we can't extract image name + logger.error(f"Could not extract image name from stream path: {stream_path}") + return tiled_url + else: + # Dev URL format: http://tiled-dev.nsls2.bnl.gov/api/v1/array/full/{uuid}/{stream_path}?slice=... + new_url = f"{TILED_URLS[env]['url_pattern']}{uuid}/{stream_path}" + + # Add slice parameter if it exists + if slice_param: + new_url = f"{new_url}?{slice_param}" + + logger.debug(f"Transformed URL: {tiled_url} -> {new_url}") + return new_url + + +def _read_image_from_tiled_url_sync(tiled_url, api_key=None): + """ + Read an image from a Tiled URL. + + Args: + tiled_url: The Tiled URL (already transformed for the appropriate environment) + api_key: API key for Tiled authentication + + Returns: + tuple: (image_data, index) """ try: # Extract index from slice parameter @@ -125,11 +200,17 @@ def read_image_from_tiled_url(tiled_url, api_key=None): return None, 0 +async def read_image_from_tiled_url(tiled_url, api_key=None): + """Async wrapper for _read_image_from_tiled_url_sync""" + return await asyncio.to_thread(_read_image_from_tiled_url_sync, tiled_url, api_key) + + @app.command() def main( db_path: str = typer.Option(DEFAULT_DB_PATH, help="Path to the SQLite database containing Tiled URLs"), max_frames: int = typer.Option(10000, help="Maximum number of frames to process"), - api_key: str = typer.Option(DEFAULT_API_KEY, help="API key for Tiled authentication") + api_key: str = typer.Option(DEFAULT_API_KEY, help="API key for Tiled authentication"), + env: str = typer.Option(TILED_ENV, help="Tiled environment to use ('dev' or 'prod')") ): """ Run the image simulator that reads Tiled URLs from a database, fetches the images, and publishes them via ZMQ. @@ -137,6 +218,7 @@ def main( Configuration can be set via environment variables: - DB_PATH: Path to the SQLite database - TILED_API_KEY: API key for Tiled authentication + - TILED_ENV: Environment to use ('dev' or 'prod') Command-line arguments override environment variables. """ @@ -145,6 +227,7 @@ def main( logger.info(f"- Database path: {db_path}") logger.info(f"- Max frames: {max_frames}") logger.info(f"- API key provided: {api_key is not None}") + logger.info(f"- Tiled environment: {env}") async def run(): # Check if database exists @@ -160,7 +243,7 @@ async def run(): socket.bind(address) # Get URLs from database - urls = get_urls_from_db(db_path, limit=max_frames) + urls = await get_urls_from_db(db_path, limit=max_frames) if not urls: logger.error("No URLs found in database, cannot continue") return @@ -171,8 +254,8 @@ async def run(): width=1679, # Default values - these will be updated from real data height=1475, data_type="uint32", - tiled_url="db://latent_vectors", - run_name=f"db_tiled_run", + tiled_url=f"{env}://latent_vectors", + run_name=f"{env}_tiled_run", run_id=str(current_time), ) logger.info(f"Sending start event") @@ -183,18 +266,22 @@ async def run(): try: logger.info(f"Processing URL from DB record {db_id}: {tiled_url}") - # Read image data from Tiled URL - image_data, index = read_image_from_tiled_url(tiled_url, api_key) + # Transform the URL for the current environment before processing + transformed_url = transform_url_for_env(tiled_url, env) + logger.info(f"Transformed URL: {transformed_url}") + + # Read image data from transformed Tiled URL + image_data, index = await read_image_from_tiled_url(transformed_url, api_key) if image_data is None: - logger.error(f"Failed to read image from {tiled_url}") + logger.error(f"Failed to read image from {transformed_url}") continue - # Send the frame event + # Send the frame event with transformed URL event = RawFrameEvent( image=SerializableNumpyArrayModel(array=image_data), frame_number=index, - tiled_url=tiled_url, + tiled_url=transformed_url, ) logger.info(f"Sending frame {index}") await socket.send(msgpack.packb(event.model_dump()))