diff --git a/src/palabra_ai/audio.py b/src/palabra_ai/audio.py index fc07d0f..262cfec 100644 --- a/src/palabra_ai/audio.py +++ b/src/palabra_ai/audio.py @@ -259,34 +259,6 @@ def to_ws(self) -> bytes: } ) - def to_bench(self): - result = { - "message_type": "__$bench_audio_frame", - "__dbg": { - "size": len(self.data), - "sample_rate": self.sample_rate, - "num_channels": self.num_channels, - "samples_per_channel": self.samples_per_channel, - }, - "data": {}, - } - - # Include transcription metadata only for output frames (those with transcription_id) - if self.transcription_id: - result["data"].update( - { - "transcription_id": self.transcription_id, - "language": self.language, - "last_chunk": self.last_chunk, - } - ) - - # Replace base64 audio data with "..." to avoid log pollution - if "data" in result["data"] and isinstance(result["data"]["data"], str): - result["data"]["data"] = "..." - - return result - @dataclass class AudioBuffer: diff --git a/src/palabra_ai/benchmark/__main__.py b/src/palabra_ai/benchmark/__main__.py index 154e8e1..a7934cd 100644 --- a/src/palabra_ai/benchmark/__main__.py +++ b/src/palabra_ai/benchmark/__main__.py @@ -13,14 +13,13 @@ from palabra_ai import Config, PalabraAI, SourceLang, TargetLang from palabra_ai.audio import save_wav from palabra_ai.benchmark.report import BENCHMARK_ALLOWED_MESSAGE_TYPES -from palabra_ai.benchmark.report import format_report from palabra_ai.benchmark.report import INPUT_CHUNK_DURATION_S from palabra_ai.benchmark.report import Report -from palabra_ai.benchmark.report import save_benchmark_files from palabra_ai.config import WsMode from palabra_ai.lang import Language from palabra_ai.task.adapter.dummy import DummyWriter from palabra_ai.task.adapter.file import FileReader +from palabra_ai.util.fileio import save_text from palabra_ai.util.orjson import to_json from palabra_ai.util.sysinfo import get_system_info @@ -51,20 +50,6 @@ def main(): raise FileNotFoundError(f"Audio file not found: {args.audio}") mode = WsMode(input_chunk_duration_ms=INPUT_CHUNK_DURATION_S * 1000) - # Setup output directory and timestamp if --out is specified - if args.out: - output_dir = args.out - output_dir.mkdir(parents=True, exist_ok=True) - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - - # Save sysinfo immediately at startup - sysinfo = get_system_info() - sysinfo["command"] = " ".join(sys.argv) - sysinfo["argv"] = sys.argv - sysinfo["cwd"] = str(Path.cwd()) - sysinfo_path = output_dir / f"{timestamp}_bench_sysinfo.json" - sysinfo_path.write_bytes(to_json(sysinfo, True)) - # Get audio duration for progress tracking with av.open(str(audio_path)) as container: audio_duration = container.duration / 1000000 # convert microseconds to seconds @@ -98,64 +83,40 @@ def on_transcription(msg): # Force benchmark mode with 100ms buffer regardless of config # Config loaded from JSON defaults to 320ms chunks, but benchmark needs 100ms for optimal performance config.mode = WsMode(input_chunk_duration_ms=INPUT_CHUNK_DURATION_S * 1000) - - source_lang = config.source.lang.code - target_lang = config.targets[0].lang.code else: if not args.source_lang or not args.target_lang: parser.error("source_lang and target_lang required without --config") - source_lang = args.source_lang - target_lang = args.target_lang config = Config( - source=SourceLang(Language.get_or_create(source_lang), reader, on_transcription=on_transcription), - targets=[TargetLang(Language.get_or_create(target_lang), DummyWriter())], + source=SourceLang(Language.get_or_create(args.source_lang), 
reader, on_transcription=on_transcription), + targets=[TargetLang(Language.get_or_create(args.target_lang), DummyWriter())], benchmark=True, mode=mode, allowed_message_types=BENCHMARK_ALLOWED_MESSAGE_TYPES, ) - # Enable debug mode and logging when --out is specified - if output_dir and timestamp: - config.debug = True - config.log_file = str(output_dir / f"{timestamp}_bench.log") - - # Save exact config that goes to set_task (SetTaskMessage.from_config uses to_dict) - config_dict = config.to_dict() - config_path = output_dir / f"{timestamp}_bench_config.json" - config_path.write_bytes(to_json(config_dict, True)) + # Enable debug mode and output directory when --out is specified + # Core will auto-save log, trace, result.json, and audio files + if args.out: + # config.debug = True + config.output_dir = Path(args.out) + print(f"Files will be saved to {args.out}") # Create progress bar with language info progress_bar[0] = tqdm( total=100, - desc=f"Processing {source_lang}→{target_lang}", + desc=f"Processing {config.source_lang}→{config.target_lang}", unit="%", mininterval=7.0, bar_format="{desc}: {percentage:3.0f}%|{bar}| [{elapsed}<{remaining}]" ) - print(f"Running benchmark: {source_lang} → {target_lang}") - if args.out: - print(f"Files will be saved to {args.out}") + print(f"Running benchmark: {config.source_lang} → {config.target_lang}") print("-" * 60) palabra = PalabraAI() result = palabra.run(config, no_raise=True) - # Save RunResult in debug mode when --out is specified - if output_dir and timestamp and result is not None: - try: - result_debug_path = output_dir / f"{timestamp}_bench_runresult_debug.json" - result_debug_path.write_bytes(to_json(result.model_dump(), True)) - except Exception as e: - # If serialization fails, save error info - error_path = output_dir / f"{timestamp}_bench_runresult_error.txt" - error_path.write_text( - f"Failed to serialize RunResult: {e}\n\n" - f"RunResult repr:\n{repr(result)}\n\n" - f"Exception: {result.exc if result else 'N/A'}" - ) - # Complete and close progress bar if progress_bar[0]: progress_bar[0].update(100 - progress_bar[0].n) @@ -183,25 +144,11 @@ def on_transcription(msg): print(" - Task was cancelled by timeout") print(" - Internal cancellation due to error") print(" - One of the subtasks failed and caused cascade cancellation\n") - - # For CancelledError, show ALL logs to understand what happened - if result.log_data and result.log_data.logs: - print(f"Full logs (all {len(result.log_data.logs)} entries):") - for log_line in result.log_data.logs: - print(log_line, end='') - print() else: print(f"\n{'='*80}") print(f"BENCHMARK FAILED: {exc_type}: {exc_msg}") print(f"{'='*80}\n") - # For other errors, show last 100 - if result.log_data and result.log_data.logs: - print("Last 100 log entries:") - for log_line in result.log_data.logs[-100:]: - print(log_line, end='') - print() - # Print traceback from exception if available if hasattr(result.exc, '__traceback__') and result.exc.__traceback__: print("\nOriginal exception traceback:") @@ -212,47 +159,12 @@ def on_transcription(msg): raise RuntimeError("Benchmark failed: no io_data") # Parse report - report, in_audio_canvas, out_audio_canvas = Report.parse(result.io_data) - - # Create file paths (used in report and optionally saved with --out) - if not timestamp: - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - in_wav_name = f"{timestamp}_bench_in_{source_lang}.wav" - out_wav_name = f"{timestamp}_bench_out_{target_lang}.wav" - - # Generate text report - report_text = format_report( - 
report, - result.io_data, - source_lang, - target_lang, - str(audio_path), - out_wav_name, - config - ) - if args.out: - # Use the shared save function - if not output_dir: - output_dir = args.out - save_benchmark_files( - output_dir=output_dir, - timestamp=timestamp, - report=report, - io_data=result.io_data, - config=config, - result=result, - in_audio_canvas=in_audio_canvas, - out_audio_canvas=out_audio_canvas, - source_lang=source_lang, - target_lang=target_lang, - report_text=report_text, - input_file_path=str(audio_path), - file_prefix="bench" - ) - - # Always print report to console - print("\n" + report_text) + report = Report.parse(result.io_data, Path(args.out)) + report.save_all() + else: + report = Report.parse(result.io_data) + print("\n" + report.report_txt) except Exception as e: # Capture traceback IMMEDIATELY - must be done in except block! @@ -264,34 +176,20 @@ def on_transcription(msg): print(f"{'='*80}\n") print(tb_string) - # Save error to file if output directory exists - if output_dir and timestamp: - try: - error_file = output_dir / f"{timestamp}_bench_error.txt" - error_file.write_text(f"Benchmark Error:\n\n{tb_string}") - print(f"\nError details saved to: {error_file}") - except Exception as save_error: - print(f"Failed to save error file: {save_error}") + if config and args.out: + save_text(config.get_out_path(".error.txt"), f"Benchmark Error:\n\n{tb_string}") # Try to save partial report/audio even on error (for debugging) - if output_dir and timestamp and result and result.io_data: + if result and result.io_data: try: print("\nAttempting to save partial results for debugging...") # Try to parse report - report, in_audio, out_audio = Report.parse(result.io_data) - - # Save report files - report_path = output_dir / f"{timestamp}_bench_report_partial.json" - report_path.write_bytes(to_json(report, True)) - print(f"✓ Partial report saved to: {report_path}") - - # Save audio (always when --out is specified) - in_wav = output_dir / f"{timestamp}_bench_in_partial.wav" - out_wav = output_dir / f"{timestamp}_bench_out_partial.wav" - save_wav(in_audio, in_wav, result.io_data.in_sr, result.io_data.channels) - save_wav(out_audio, out_wav, result.io_data.out_sr, result.io_data.channels) - print(f"✓ Partial audio saved: {in_wav.name}, {out_wav.name}") + if args.out: + output_dir = Path(args.out) + report = Report.parse(result.io_data, output_dir) + report.save_all() + print(f"✓ Something saved to: {args.out}") except Exception as save_err: print(f"Could not save partial results: {save_err}") diff --git a/src/palabra_ai/benchmark/report.py b/src/palabra_ai/benchmark/report.py index 9535997..c0cd82f 100644 --- a/src/palabra_ai/benchmark/report.py +++ b/src/palabra_ai/benchmark/report.py @@ -7,6 +7,7 @@ from dataclasses import dataclass from dataclasses import field +from functools import cached_property from pathlib import Path from typing import Any from typing import NamedTuple @@ -20,10 +21,14 @@ from palabra_ai import Message from palabra_ai.audio import save_wav from palabra_ai.benchmark.utils import flatten_container_to_paths, _format_value +from palabra_ai.constant import OUT_IN_AUDIO_SUFFIX +from palabra_ai.constant import OUT_OUT_AUDIO_SUFFIX +from palabra_ai.constant import OUT_REPORT_SUFFIX from palabra_ai.message import IoEvent from palabra_ai.model import IoData +from palabra_ai.util.fileio import save_text from palabra_ai.util.orjson import to_json @@ -143,9 +148,13 @@ class AudioStat: @dataclass class Report: + io_data: IoData + cfg: Config sentences: dict[str, 
Sentence] = field(default_factory=dict) # transcription_id -> Sentence in_audio_stat: AudioStat | None = None out_audio_stat: AudioStat | None = None + in_audio_canvas: np.typing.NDArray | None = None + out_audio_canvas: np.typing.NDArray | None = None metrics_summary: dict[str, dict[str, float]] = field(default_factory=dict) # metric_name -> {min, max, avg, p50, p90, p95} set_task_e: IoEvent | None = None current_task_e: IoEvent | None = None @@ -184,11 +193,10 @@ def playback(cls, events: list[IoEvent], sr: int, ch: int): start_idx_aligned = round(start_idx_rough / ch) * ch cls.put_audio_to_canvas(audio_canvas, start_idx_aligned, e) return audio_canvas, AudioStat(playback_pos, tids_with_actual_tts_playback, deltas) - return playback_pos, audio_canvas, deltas, tids_with_actual_tts_playback @classmethod - def parse(cls, io_data: IoData) -> Self: + def parse(cls, io_data: IoData, output_dir: Path | None = None) -> Self: sentences = {} all_with_tid: list[IoEvent] = [] @@ -218,6 +226,10 @@ def parse(cls, io_data: IoData) -> Self: raise ValueError("Multiple current_task events found, old: {}, new: {}".format(current_task_e, e)) current_task_e = e + cfg = Config.from_dict(set_task_e.body["data"]) + if output_dir is not None: + cfg.output_dir = output_dir + focused_by_tid = defaultdict(list) for fe in focused: focused_by_tid[fe.tid].append(fe) @@ -332,7 +344,164 @@ def parse(cls, io_data: IoData) -> Self: if values: metrics_summary[metric_name] = calculate_stats(values) - return cls(sentences=sentences, in_audio_stat=in_audio_stat, out_audio_stat=out_audio_stat, metrics_summary=metrics_summary, set_task_e=set_task_e, current_task_e=current_task_e), in_audio_canvas, out_audio_canvas + return cls(io_data=io_data, cfg=cfg, sentences=sentences, in_audio_stat=in_audio_stat, out_audio_stat=out_audio_stat, in_audio_canvas=in_audio_canvas, out_audio_canvas=out_audio_canvas, metrics_summary=metrics_summary, set_task_e=set_task_e, current_task_e=current_task_e) + + def save_in(self, path: Path) -> None: + save_wav(self.in_audio_canvas, path, self.io_data.in_sr, self.io_data.channels) + + def save_out(self, path: Path) -> None: + save_wav(self.out_audio_canvas, path, self.io_data.out_sr, self.io_data.channels) + + def save_txt(self) -> None: + save_text(self.cfg.get_out_path(OUT_REPORT_SUFFIX), self.report_txt) + + def save_all(self): + self.save_txt() + self.save_in(self.cfg.get_out_path(OUT_IN_AUDIO_SUFFIX)) + self.save_out(self.cfg.get_out_path(OUT_OUT_AUDIO_SUFFIX)) + + @cached_property + def report_txt(self) -> str: + return self.format_report() + + def format_report(self) -> str: + """Format report as text with tables and histogram""" + lines = [] + lines.append("=" * 80) + lines.append("PALABRA AI BENCHMARK REPORT") + lines.append("=" * 80) + lines.append("") + + # Mode and audio info + mode_name = "WebRTC" if self.io_data.mode == "webrtc" else "Websocket" + lines.append(f"Mode: {mode_name}") + + # Input/Output info + in_dur = f"{self.in_audio_stat.length_s:.1f}s" if self.in_audio_stat else "?.?s" + out_dur = f"{self.out_audio_stat.length_s:.1f}s" if self.out_audio_stat else "?.?s" + lines.append(f"Reader: [{in_dur}, {self.io_data.in_sr}hz, 16bit, PCM] {self.io_data.reader_x_title}") + lines.append(f"Writer: [{out_dur}, {self.io_data.out_sr}hz, 16bit, PCM] {self.io_data.writer_x_title}") + if self.cfg.output_dir: + lines.append(f"Out dir: {self.cfg.output_dir}") + + # TTS autotempo info + queue_config = self.cfg.translation_queue_configs.global_ if self.cfg.translation_queue_configs else None + if 
queue_config: + if queue_config.auto_tempo: + lines.append(f"TTS autotempo: ✅ on ({queue_config.min_tempo}-{queue_config.max_tempo})") + else: + lines.append(f"TTS autotempo: ❌ off") + + # CONFIG - comparison of sent vs applied settings + lines.append("") + lines.append("CONFIG (sent vs applied)") + lines.append("-" * 80) + + set_task_data = self.set_task_e.body["data"] if self.set_task_e else {} + current_task_data = self.current_task_e.body["data"] if self.current_task_e else {} + + sent_paths = [(k, _format_value(v)) for k,v in flatten_container_to_paths(set_task_data)] + applied_paths = [(k, _format_value(v)) for k,v in flatten_container_to_paths(current_task_data)] + + # Merge settings using full outer join + merged_settings = merge_task_settings(sent_paths, applied_paths) + + table = PrettyTable() + table.field_names = ["Key", "Sent", "Applied"] + table.align["Key"] = "l" + table.align["Sent"] = "l" + table.align["Applied"] = "l" + + for key, sent_value, applied_value in merged_settings: + sent_str = sent_value if sent_value is not None else "" + applied_str = applied_value if applied_value is not None else "" + table.add_row([key, sent_str, applied_str]) + + lines.append(str(table)) + lines.append("") + + # Metrics summary table + if self.metrics_summary: + lines.append("METRICS SUMMARY") + lines.append("-" * 80) + table = PrettyTable() + table.field_names = ["Metric", "Min", "Max", "Avg", "P50", "P90", "P95"] + + metric_labels = { + "metric_partial": "Partial", + "metric_validated": "Validated", + "metric_translated": "Translated", + "metric_tts_api": "TTS API", + "metric_tts_playback": "TTS Playback" + } + + for metric_name, stats in self.metrics_summary.items(): + label = metric_labels.get(metric_name, metric_name) + table.add_row([ + label, + f"{stats['min']:.3f}", + f"{stats['max']:.3f}", + f"{stats['avg']:.3f}", + f"{stats['p50']:.3f}", + f"{stats['p90']:.3f}", + f"{stats['p95']:.3f}" + ]) + + lines.append(str(table)) + lines.append("") + + # Sentences breakdown + if self.sentences: + lines.append("SENTENCES BREAKDOWN") + lines.append("-" * 80) + table = PrettyTable() + table.field_names = ["Start", "ID", "Validated", "Translated", "Part", "Valid", "Trans", "TTS API", "TTS Play"] + table.align["ID"] = "l" + table.align["Validated"] = "l" + table.align["Translated"] = "l" + + sorted_sentences = sorted(self.sentences.items(), key=lambda x: x[1].local_start_ts) + global_start = sorted_sentences[0][1].local_start_ts if sorted_sentences else 0 + + for raw_tid, sentence in sorted_sentences: + tid = Tid.parse(raw_tid) + + if sentence.has_metrics: + start_time = sentence.local_start_ts - global_start + table.add_row([ + f"{start_time:.1f}s", + tid.display, + truncate_text(sentence.validated_text), + truncate_text(sentence.translated_text), + f"{sentence.metric_partial:.2f}" if sentence.metric_partial is not None else "-", + f"{sentence.metric_validated:.2f}" if sentence.metric_validated else "-", + f"{sentence.metric_translated:.2f}" if sentence.metric_translated else "-", + f"{sentence.metric_tts_api:.2f}" if sentence.metric_tts_api else "-", + f"{sentence.metric_tts_playback:.2f}" if sentence.metric_tts_playback else "-" + ]) + else: + # Text-only row for extra_parts (_part_1+) + table.add_row([ + "", # no start time + tid.display, + truncate_text(sentence.validated_text), + truncate_text(sentence.translated_text), + "", "", "", "", "" # no metrics + ]) + + lines.append(str(table)) + lines.append("") + + # Histogram for TTS playback + if "metric_tts_playback" in 
self.metrics_summary: + lines.append("TTS PLAYBACK HISTOGRAM") + lines.append("-" * 80) + playback_values = [s.metric_tts_playback for s in self.sentences.values() if s.metric_tts_playback is not None] + lines.append(create_histogram(playback_values)) + lines.append("") + + lines.append("=" * 80) + return "\n".join(lines) def create_histogram(values: list[float], bins: int = 20, width: int = 50) -> str: @@ -408,142 +577,6 @@ def sort_key(item): return result -def format_report(report: Report, io_data: IoData, source_lang: str, target_lang: str, in_file: str, out_file: str, config: Config) -> str: - """Format report as text with tables and histogram""" - lines = [] - lines.append("=" * 80) - lines.append("PALABRA AI BENCHMARK REPORT") - lines.append("=" * 80) - lines.append("") - - # Mode and audio info - mode_name = "WebRTC" if io_data.mode == "webrtc" else "Websocket" - lines.append(f"Mode: {mode_name}") - - # Input/Output info - in_dur = f"{report.in_audio_stat.length_s:.1f}s" if report.in_audio_stat else "?.?s" - out_dur = f"{report.out_audio_stat.length_s:.1f}s" if report.out_audio_stat else "?.?s" - lines.append(f"Input: [{in_dur}, {io_data.in_sr}hz, 16bit, PCM] {in_file}") - lines.append(f"Output: [{out_dur}, {io_data.out_sr}hz, 16bit, PCM] {out_file}") - - # TTS autotempo info - queue_config = config.translation_queue_configs.global_ if config.translation_queue_configs else None - if queue_config: - if queue_config.auto_tempo: - lines.append(f"TTS autotempo: ✅ on ({queue_config.min_tempo}-{queue_config.max_tempo})") - else: - lines.append(f"TTS autotempo: ❌ off") - - # CONFIG - comparison of sent vs applied settings - lines.append("") - lines.append("CONFIG (sent vs applied)") - lines.append("-" * 80) - - set_task_data = report.set_task_e.body["data"] if report.set_task_e else {} - current_task_data = report.current_task_e.body["data"] if report.current_task_e else {} - - sent_paths = [(k, _format_value(v)) for k,v in flatten_container_to_paths(set_task_data)] - applied_paths = [(k, _format_value(v)) for k,v in flatten_container_to_paths(current_task_data)] - - # Merge settings using full outer join - merged_settings = merge_task_settings(sent_paths, applied_paths) - - table = PrettyTable() - table.field_names = ["Key", "Sent", "Applied"] - table.align["Key"] = "l" - table.align["Sent"] = "l" - table.align["Applied"] = "l" - - for key, sent_value, applied_value in merged_settings: - sent_str = sent_value if sent_value is not None else "" - applied_str = applied_value if applied_value is not None else "" - table.add_row([key, sent_str, applied_str]) - - lines.append(str(table)) - lines.append("") - - # Metrics summary table - if report.metrics_summary: - lines.append("METRICS SUMMARY") - lines.append("-" * 80) - table = PrettyTable() - table.field_names = ["Metric", "Min", "Max", "Avg", "P50", "P90", "P95"] - - metric_labels = { - "metric_partial": "Partial", - "metric_validated": "Validated", - "metric_translated": "Translated", - "metric_tts_api": "TTS API", - "metric_tts_playback": "TTS Playback" - } - - for metric_name, stats in report.metrics_summary.items(): - label = metric_labels.get(metric_name, metric_name) - table.add_row([ - label, - f"{stats['min']:.3f}", - f"{stats['max']:.3f}", - f"{stats['avg']:.3f}", - f"{stats['p50']:.3f}", - f"{stats['p90']:.3f}", - f"{stats['p95']:.3f}" - ]) - - lines.append(str(table)) - lines.append("") - - # Sentences breakdown - if report.sentences: - lines.append("SENTENCES BREAKDOWN") - lines.append("-" * 80) - table = PrettyTable() - 
table.field_names = ["Start", "ID", "Validated", "Translated", "Part", "Valid", "Trans", "TTS API", "TTS Play"] - table.align["ID"] = "l" - table.align["Validated"] = "l" - table.align["Translated"] = "l" - - sorted_sentences = sorted(report.sentences.items(), key=lambda x: x[1].local_start_ts) - global_start = sorted_sentences[0][1].local_start_ts if sorted_sentences else 0 - - for raw_tid, sentence in sorted_sentences: - tid = Tid.parse(raw_tid) - - if sentence.has_metrics: - start_time = sentence.local_start_ts - global_start - table.add_row([ - f"{start_time:.1f}s", - tid.display, - truncate_text(sentence.validated_text), - truncate_text(sentence.translated_text), - f"{sentence.metric_partial:.2f}" if sentence.metric_partial is not None else "-", - f"{sentence.metric_validated:.2f}" if sentence.metric_validated else "-", - f"{sentence.metric_translated:.2f}" if sentence.metric_translated else "-", - f"{sentence.metric_tts_api:.2f}" if sentence.metric_tts_api else "-", - f"{sentence.metric_tts_playback:.2f}" if sentence.metric_tts_playback else "-" - ]) - else: - # Text-only row for extra_parts (_part_1+) - table.add_row([ - "", # no start time - tid.display, - truncate_text(sentence.validated_text), - truncate_text(sentence.translated_text), - "", "", "", "", "" # no metrics - ]) - - lines.append(str(table)) - lines.append("") - - # Histogram for TTS playback - if "metric_tts_playback" in report.metrics_summary: - lines.append("TTS PLAYBACK HISTOGRAM") - lines.append("-" * 80) - playback_values = [s.metric_tts_playback for s in report.sentences.values() if s.metric_tts_playback is not None] - lines.append(create_histogram(playback_values)) - lines.append("") - - lines.append("=" * 80) - return "\n".join(lines) def save_benchmark_files( diff --git a/src/palabra_ai/benchmark/rewind.py b/src/palabra_ai/benchmark/rewind.py index 30d8bc6..4c5c116 100644 --- a/src/palabra_ai/benchmark/rewind.py +++ b/src/palabra_ai/benchmark/rewind.py @@ -11,7 +11,6 @@ from palabra_ai.model import IoData from palabra_ai.message import IoEvent, Dbg from palabra_ai.benchmark.report import save_benchmark_files -from palabra_ai.benchmark.report import format_report from palabra_ai.benchmark.report import Report from palabra_ai import Config @@ -51,7 +50,9 @@ def load_run_result(file_path: Path) -> IoData: mode=io_data_dict['mode'], channels=io_data_dict['channels'], events=events, - count_events=len(events) + count_events=len(events), + reader_x_title=io_data_dict.get('reader_x_title', "n/a"), + writer_x_title=io_data_dict.get('writer_x_title', "n/a") ) return io_data @@ -62,53 +63,23 @@ def main(): parser.add_argument("--out", type=Path, help="Output directory for reconstructed files (if not specified, only prints to console)") args = parser.parse_args() - try: - file_path = Path(args.run_result) - - # Load IoData - io_data = load_run_result(file_path) - - # Parse report (same as main benchmark) - report, in_audio_canvas, out_audio_canvas = Report.parse(io_data) - - config = Config.from_dict(report.set_task_e.body["data"]) - - # Extract languages from config - source_lang = config.source.lang.code - target_lang = config.targets[0].lang.code - - # Generate report using existing format_report function directly - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - in_file = str(file_path) - out_file = f"{timestamp}_rewind_out_{target_lang}.wav" - - report_text = format_report(report, io_data, source_lang, target_lang, in_file, out_file, config) - - # Save files if --out option is specified - if args.out: - 
save_benchmark_files( - output_dir=args.out, - timestamp=timestamp, - report=report, - io_data=io_data, - config=config, - result=None, # No RunResult in rewind - in_audio_canvas=in_audio_canvas, - out_audio_canvas=out_audio_canvas, - source_lang=source_lang, - target_lang=target_lang, - report_text=report_text, - input_file_path=str(file_path), - file_prefix="rewind" - ) - print(f"\nFiles saved to: {args.out}") - - # Always print report to console - print("\n" + report_text) - - except Exception as e: - print(f"Error: {e}", file=sys.stderr) - sys.exit(1) + + file_path = Path(args.run_result) + + # Load IoData + io_data = load_run_result(file_path) + + # Parse report (same as main benchmark) + if args.out: + report = Report.parse(io_data, Path(args.out)) + report.save_all() + print(f"\nFiles saved to: {args.out}") + else: + report = Report.parse(io_data) + + # Always print report to console + print("\n" + report.report_txt) + if __name__ == "__main__": main() \ No newline at end of file diff --git a/src/palabra_ai/client.py b/src/palabra_ai/client.py index 9c34618..e933366 100644 --- a/src/palabra_ai/client.py +++ b/src/palabra_ai/client.py @@ -2,19 +2,65 @@ import asyncio import contextlib +import functools from collections.abc import AsyncIterator from dataclasses import dataclass, field from aioshutdown import SIGHUP, SIGINT, SIGTERM from palabra_ai.config import CLIENT_ID, CLIENT_SECRET, DEEP_DEBUG, Config +from palabra_ai.constant import ( + OUT_CONFIG_SUFFIX, + OUT_EXTRA_CONFIG_SUFFIX, + OUT_RUN_RESULT_SUFFIX, + OUT_SYSINFO_SUFFIX, +) from palabra_ai.debug.hang_coroutines import diagnose_hanging_tasks from palabra_ai.exc import ConfigurationError, unwrap_exceptions from palabra_ai.internal.rest import PalabraRESTClient, SessionCredentials from palabra_ai.model import RunResult from palabra_ai.task.base import TaskEvent from palabra_ai.task.manager import Manager -from palabra_ai.util.logger import debug, error, exception, success +from palabra_ai.util.fileio import save_json +from palabra_ai.util.logger import debug, error, exception, success, warning +from palabra_ai.util.sysinfo import get_system_info + + +def with_config_save(func): + @functools.wraps(func) + async def wrapper(self, config, *args, **kwargs): + def safe_save(suffix, data, saver_fn=functools.partial(save_json, indent=True)): + path = None + try: + path = config.get_out_path(suffix) + debug(f"Saving {path}") + saver_fn(path, data) + except Exception as e: + warning(f"⚠️ Exception during save [{suffix}] {path}: {e!r}") + + if not config.output_dir: + warning("Output directory not set, skipping config and sysinfo save") + return await func(self, config, *args, **kwargs) + + safe_save(OUT_SYSINFO_SUFFIX, get_system_info()) + safe_save(OUT_CONFIG_SUFFIX, config.to_dict()) + safe_save(OUT_EXTRA_CONFIG_SUFFIX, config.to_extra_dict()) + + # Run the original async method + result = await func(self, config, *args, **kwargs) + + safe_save(OUT_RUN_RESULT_SUFFIX, result) + + try: + from palabra_ai.benchmark.report import Report + + report = Report.parse(result.io_data, config.output_dir) + report.save_all() + except Exception as e: + warning(f"⚠️ Exception saving run result in {func.__name__}: {e!r}") + return result + + return wrapper @dataclass @@ -86,6 +132,7 @@ def run( finally: debug("Shutdown complete") + @with_config_save async def arun( self, cfg: Config, @@ -107,7 +154,6 @@ async def arun( """ async def _run_with_result(manager: Manager) -> RunResult: - log_data = None exc = None ok = False @@ -127,32 +173,6 @@ async 
def _run_with_result(manager: Manager) -> RunResult: exception("Error in manager task") exc = e - # CRITICAL: Always try to get log_data from logger - try: - if manager.logger and manager.logger._task: - # Give logger time to complete if still running - if not manager.logger._task.done(): - debug("Waiting for logger to complete...") - try: - await asyncio.wait_for(manager.logger._task, timeout=5.0) - except (TimeoutError, asyncio.CancelledError): - debug( - "Logger task timeout or cancelled, checking result anyway" - ) - - # Try to get the result - log_data = manager.logger.result - if not log_data: - debug("Logger.result is None, trying to call exit() directly") - try: - log_data = await asyncio.wait_for( - manager.logger.exit(), timeout=2.0 - ) - except Exception as e: - debug(f"Failed to get log_data from logger.exit(): {e}") - except Exception: - exception("Failed to retrieve log_data") - # Check if EOS was received (only relevant for WS) eos_received = manager.io.eos_received if manager.io else False @@ -161,7 +181,7 @@ async def _run_with_result(manager: Manager) -> RunResult: return RunResult( ok=ok, exc=exc if not ok else None, - log_data=log_data, + # log_data=log_data, io_data=manager.io.io_data if manager.io else None, eos=eos_received, ) @@ -169,7 +189,7 @@ async def _run_with_result(manager: Manager) -> RunResult: return RunResult( ok=True, exc=None, - log_data=log_data, + # log_data=log_data, io_data=manager.io.io_data if manager.io else None, eos=eos_received, ) @@ -178,6 +198,7 @@ async def _run_with_result(manager: Manager) -> RunResult: raise exc try: + cfg.set_logging() async with self.process(cfg, stopper) as manager: if DEEP_DEBUG: debug(diagnose_hanging_tasks()) diff --git a/src/palabra_ai/config.py b/src/palabra_ai/config.py index ee657c1..a933e6e 100644 --- a/src/palabra_ai/config.py +++ b/src/palabra_ai/config.py @@ -36,6 +36,8 @@ MIN_SPLIT_INTERVAL_DEFAULT, MIN_TRANSCRIPTION_LEN_DEFAULT, MIN_TRANSCRIPTION_TIME_DEFAULT, + OUT_LOG_SUFFIX, + OUT_TRACE_SUFFIX, PHRASE_CHANCE_DEFAULT, QUEUE_MAX_TEMPO, QUEUE_MIN_TEMPO, @@ -60,6 +62,7 @@ from palabra_ai.lang import Language, is_valid_source_language, is_valid_target_language from palabra_ai.message import Message from palabra_ai.types import T_ON_TRANSCRIPTION +from palabra_ai.util.dt import get_now_strftime from palabra_ai.util.logger import set_logging from palabra_ai.util.orjson import from_json, to_json from palabra_ai.util.pydantic import mark_fields_as_set @@ -280,7 +283,7 @@ class WsMode(IoMode): input_sample_rate: int = WS_MODE_INPUT_SAMPLE_RATE output_sample_rate: int = WS_MODE_OUTPUT_SAMPLE_RATE num_channels: int = WS_MODE_CHANNELS - input_chunk_duration_ms: int = WS_MODE_CHUNK_DURATION_MS + input_chunk_duration_ms: float = WS_MODE_CHUNK_DURATION_MS def model_dump(self, *args, **kwargs) -> dict[str, Any]: return { @@ -534,6 +537,8 @@ class Config(BaseModel): rich_default_config: SkipJsonSchema[bool] = Field( default=RICH_DEFAULT_CONFIG, exclude=True ) + x_output_dir: Path | None = Field(default=None, exclude=True) + run_name: str = Field(default_factory=lambda: get_now_strftime(), exclude=True) def __init__( self, @@ -548,7 +553,35 @@ def __init__( self.targets = targets self._ensure_default_fields_are_set() + @model_validator(mode="before") + @classmethod + def _handle_output_dir(cls, data: Any) -> Any: + if isinstance(data, dict) and "output_dir" in data: + data["x_output_dir"] = data.pop("output_dir") + return data + + @property + def output_dir(self) -> Path | None: + return self.x_output_dir + + 
@output_dir.setter + def output_dir(self, value): + self.x_output_dir = Path(value).absolute() + self.x_output_dir.mkdir(exist_ok=True, parents=True) + + # Only auto-set log_file if not explicitly provided + if "log_file" not in self.model_fields_set: + self.log_file = self.get_out_path(OUT_LOG_SUFFIX) + + # Only auto-set trace_file if not explicitly provided + if "trace_file" not in self.model_fields_set: + self.trace_file = self.get_out_path(OUT_TRACE_SUFFIX) + def model_post_init(self, context: Any, /) -> None: + # Trigger output_dir setter if output_dir was provided + if self.x_output_dir is not None: + self.output_dir = self.x_output_dir + if self.targets is None: self.targets = [] elif isinstance(self.targets, TargetLang): @@ -556,11 +589,52 @@ def model_post_init(self, context: Any, /) -> None: if self.log_file: self.log_file = Path(self.log_file).absolute() self.log_file.parent.mkdir(exist_ok=True, parents=True) - self.trace_file = self.log_file.with_suffix(".trace.json") - set_logging(self.silent, self.debug, self.internal_logs, self.log_file) + + # Only auto-set trace_file if not already set by output_dir + if not self.trace_file: + self.trace_file = self.log_file.with_suffix(OUT_TRACE_SUFFIX) + self._ensure_default_fields_are_set() super().model_post_init(context) + def set_logging(self): + if not self.log_file: + return + set_logging(self.silent, self.debug, self.internal_logs, self.log_file) + + def get_out_path(self, suffix: str) -> Path | None: + if self.output_dir is None: + return None + return self.output_dir / f"{self.run_name}{suffix}" + + @property + def source_lang(self) -> str: + """Get source language code. + + Returns: + str: Source language code (e.g., 'en', 'es', 'ru') + + Raises: + ConfigurationError: If source is not configured + """ + if not self.source: + raise ConfigurationError("Source language not configured") + return self.source.lang.source_code + + @property + def target_lang(self) -> str: + """Get first target language code. + + Returns: + str: First target language code (e.g., 'en', 'es', 'ru') + + Raises: + ConfigurationError: If no targets configured + """ + if not self.targets or len(self.targets) == 0: + raise ConfigurationError("Target language not configured") + return self.targets[0].lang.target_code + def _ensure_default_fields_are_set(self) -> None: """ Ensure that essential default fields are marked as 'set' in Pydantic @@ -673,6 +747,20 @@ def model_dump( result = {**data, **{"pipeline": pipeline}, **self.mode.model_dump()} return result + def to_extra_dict(self) -> dict[str, Any]: + return { + "silent": self.silent, + "debug": self.debug, + "log_file": self.log_file, + "trace_file": self.trace_file, + "drop_empty_frames": self.drop_empty_frames, + "timeout": self.timeout, + "benchmark": self.benchmark, + "rich_default_config": self.rich_default_config, + "output_dir": self.output_dir, + "run_name": self.run_name, + } + def to_dict(self, full: bool = False) -> dict[str, Any]: """ Convert config to dict. 
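
A minimal usage sketch of the new output-directory plumbing added to Config above (illustrative only, not part of the patch; the dummy adapters stand in for real readers/writers):

    from pathlib import Path

    from palabra_ai import Config, SourceLang, TargetLang
    from palabra_ai.constant import OUT_REPORT_SUFFIX
    from palabra_ai.lang import Language
    from palabra_ai.task.adapter.dummy import DummyReader, DummyWriter

    cfg = Config(
        source=SourceLang(Language.get_or_create("en"), DummyReader()),
        targets=[TargetLang(Language.get_or_create("es"), DummyWriter())],
    )

    # The setter creates the directory and, unless they were set explicitly,
    # derives log_file/trace_file from run_name + OUT_LOG_SUFFIX/OUT_TRACE_SUFFIX.
    cfg.output_dir = Path("./bench_out")

    print(cfg.source_lang, "->", cfg.target_lang)   # e.g. "en -> es"
    # Per-run artifacts share the "<run_name><suffix>" naming scheme,
    # e.g. .../bench_out/20250101T120000.report.txt
    print(cfg.get_out_path(OUT_REPORT_SUFFIX))
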
diff --git a/src/palabra_ai/constant.py b/src/palabra_ai/constant.py index 6c3b3f3..8e0962b 100644 --- a/src/palabra_ai/constant.py +++ b/src/palabra_ai/constant.py @@ -93,11 +93,21 @@ WS_MODE_INPUT_SAMPLE_RATE = 16000 WS_MODE_OUTPUT_SAMPLE_RATE = 24000 WS_MODE_CHANNELS = 1 -WS_MODE_CHUNK_DURATION_MS = 320 +WS_MODE_CHUNK_DURATION_MS = 320.0 WEBRTC_MODE_INPUT_SAMPLE_RATE = 48000 WEBRTC_MODE_OUTPUT_SAMPLE_RATE = 48000 WEBRTC_MODE_CHANNELS = 1 -WEBRTC_MODE_CHUNK_DURATION_MS = 320 +WEBRTC_MODE_CHUNK_DURATION_MS = 320.0 EOF_SILENCE_DURATION_S = 10.0 + +OUT_SYSINFO_SUFFIX = ".sysinfo.json" +OUT_CONFIG_SUFFIX = ".config.json" +OUT_EXTRA_CONFIG_SUFFIX = ".extra_config.json" +OUT_RUN_RESULT_SUFFIX = ".run_result.json" +OUT_LOG_SUFFIX = ".log" +OUT_TRACE_SUFFIX = ".trace.log" +OUT_IN_AUDIO_SUFFIX = ".in.wav" +OUT_OUT_AUDIO_SUFFIX = ".out.wav" +OUT_REPORT_SUFFIX = ".report.txt" diff --git a/src/palabra_ai/model.py b/src/palabra_ai/model.py index 6f6998d..e263ce4 100644 --- a/src/palabra_ai/model.py +++ b/src/palabra_ai/model.py @@ -5,18 +5,6 @@ from palabra_ai.message import IoEvent -class LogData(BaseModel): - version: str - sysinfo: dict - messages: list[dict] - start_ts: float - cfg: dict - log_file: str - trace_file: str - debug: bool - logs: list[str] - - class IoData(BaseModel): model_config = {"use_enum_values": True} start_perf_ts: float @@ -27,12 +15,13 @@ class IoData(BaseModel): channels: int events: list[IoEvent] count_events: int + reader_x_title: str + writer_x_title: str class RunResult(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) ok: bool exc: BaseException | None = None - log_data: LogData | None = Field(default=None, repr=False) io_data: IoData | None = Field(default=None, repr=False) eos: bool = False diff --git a/src/palabra_ai/task/adapter/base.py b/src/palabra_ai/task/adapter/base.py index f7ca795..77f78fe 100644 --- a/src/palabra_ai/task/adapter/base.py +++ b/src/palabra_ai/task/adapter/base.py @@ -15,8 +15,18 @@ from palabra_ai.config import Config +class PathMinix(abc.ABC): + """Mixin for classes that have a path property.""" + + @property + @abc.abstractmethod + def x_title(self) -> str: + """Path or identifier of the audio source.""" + ... 
+ + @dataclass -class Reader(Task): +class Reader(Task, PathMinix): """Abstract PCM audio reader process.""" _: KW_ONLY @@ -46,7 +56,7 @@ async def read(self, size: int) -> bytes | None: @dataclass -class Writer(Task): +class Writer(Task, PathMinix): _: KW_ONLY cfg: Config = field(default=None, init=False, repr=False) q: asyncio.Queue[AudioFrame | None] = field(default_factory=asyncio.Queue) @@ -116,6 +126,10 @@ class BufferedWriter(UnlimitedExitMixin, Writer): _: KW_ONLY ab: AudioBuffer | None = field(default=None, init=False) + @property + def x_title(self) -> str: + return "in-memory-buffer" + async def boot(self): # Create buffer with estimated duration from config self.ab = AudioBuffer( diff --git a/src/palabra_ai/task/adapter/buffer.py b/src/palabra_ai/task/adapter/buffer.py index f7c3627..e1cd2e8 100644 --- a/src/palabra_ai/task/adapter/buffer.py +++ b/src/palabra_ai/task/adapter/buffer.py @@ -21,6 +21,10 @@ class BufferReader(Reader): _: KW_ONLY _buffer_size: int | None = field(default=None, init=False, repr=False) + @property + def x_title(self) -> str: + return "in-memory-buffer" + def __post_init__(self): self._position = 0 current_pos = self.buffer.tell() @@ -70,6 +74,10 @@ class BufferWriter(BufferedWriter): buffer: io.BytesIO _: KW_ONLY + @property + def x_title(self) -> str: + return "in-memory-buffer" + async def boot(self): await super().boot() self.ab.replace_buffer(self.buffer) diff --git a/src/palabra_ai/task/adapter/device.py b/src/palabra_ai/task/adapter/device.py index 41832de..a4c30be 100644 --- a/src/palabra_ai/task/adapter/device.py +++ b/src/palabra_ai/task/adapter/device.py @@ -161,6 +161,12 @@ class DeviceReader(Reader): sdm: SoundDeviceManager = field(default_factory=SoundDeviceManager) tg: asyncio.TaskGroup | None = field(default=None, init=False) + @property + def x_title(self) -> str: + if isinstance(self.device, Device): + return self.device.name + return str(self.device) + def do_preprocess(self): """Set default duration for real-time input.""" self.duration = RT_DEFAULT_DURATION @@ -236,6 +242,12 @@ class DeviceWriter(Writer): init=False, ) + @property + def x_title(self) -> str: + if isinstance(self.device, Device): + return self.device.name + return str(self.device) + async def boot(self): self._sdm.tg = self.sub_tg device_name = ( diff --git a/src/palabra_ai/task/adapter/dummy.py b/src/palabra_ai/task/adapter/dummy.py index 370ebcd..0c1a3fe 100644 --- a/src/palabra_ai/task/adapter/dummy.py +++ b/src/palabra_ai/task/adapter/dummy.py @@ -16,6 +16,10 @@ class DummyReader(Reader): eof_after_reads: int | None = field(default=None, repr=False) _: KW_ONLY + @property + def x_title(self) -> str: + return "dummy" + async def boot(self): pass @@ -46,6 +50,10 @@ class DummyWriter(Writer): frames_processed: int = 0 _q_reader: asyncio.Task = field(default=None, init=False, repr=False) + @property + def x_title(self) -> str: + return "dummy" + async def q_reader(self): while not self.stopper and not self.eof: try: diff --git a/src/palabra_ai/task/adapter/file.py b/src/palabra_ai/task/adapter/file.py index 50d85eb..101ef42 100644 --- a/src/palabra_ai/task/adapter/file.py +++ b/src/palabra_ai/task/adapter/file.py @@ -54,6 +54,10 @@ def __post_init__(self): raise FileNotFoundError(f"File not found: {self.path}") self._buffer = deque() + @property + def x_title(self) -> str: + return str(self.path) + def _preprocess_audio(self): """Preprocess audio with configurable pipeline.""" # Setup progress bar @@ -234,6 +238,10 @@ def __post_init__(self): self.path = 
Path(self.path) self.path.parent.mkdir(parents=True, exist_ok=True) + @property + def x_title(self) -> str: + return str(self.path) + async def exit(self): """Write the buffered WAV data to file""" debug("Finalizing FileWriter...") diff --git a/src/palabra_ai/task/io/base.py b/src/palabra_ai/task/io/base.py index e82e91d..c849481 100644 --- a/src/palabra_ai/task/io/base.py +++ b/src/palabra_ai/task/io/base.py @@ -68,6 +68,8 @@ def io_data(self) -> IoData: channels=self.cfg.mode.num_channels, events=self.io_events, count_events=len(self.io_events), + reader_x_title=f"<{self.reader.__class__.__name__}: {self.reader.x_title}>", + writer_x_title=f"<{self.writer.__class__.__name__}: {self.writer.x_title}>", ) @property diff --git a/src/palabra_ai/task/logger.py b/src/palabra_ai/task/logger.py deleted file mode 100644 index 828486d..0000000 --- a/src/palabra_ai/task/logger.py +++ /dev/null @@ -1,212 +0,0 @@ -from __future__ import annotations - -import asyncio -import time -from dataclasses import KW_ONLY, asdict, dataclass, field - -import palabra_ai -from palabra_ai.config import ( - Config, -) -from palabra_ai.constant import ( - QUEUE_READ_TIMEOUT, - SHUTDOWN_TIMEOUT, - SLEEP_INTERVAL_DEFAULT, -) -from palabra_ai.message import Dbg -from palabra_ai.model import LogData -from palabra_ai.task.base import Task -from palabra_ai.task.io.base import Io - -# from palabra_ai.task.realtime import Realtime -from palabra_ai.util.fanout_queue import Subscription -from palabra_ai.util.logger import debug, error -from palabra_ai.util.orjson import to_json -from palabra_ai.util.sysinfo import get_system_info - - -@dataclass -class Logger(Task): - """Logs all WebSocket and WebRTC messages to files.""" - - cfg: Config - io: Io - _: KW_ONLY - _messages: list[dict] = field(default_factory=list, init=False) - _start_ts: float = field(default_factory=time.time, init=False) - _io_in_sub: Subscription | None = field(default=None, init=False) - _io_audio_in_sub: Subscription | None = field(default=None, init=False) - _io_audio_out_sub: Subscription | None = field(default=None, init=False) - _io_out_sub: Subscription | None = field(default=None, init=False) - _in_task: asyncio.Task | None = field(default=None, init=False) - _out_task: asyncio.Task | None = field(default=None, init=False) - _audio_inout_task: asyncio.Task | None = field(default=None, init=False) - - def __post_init__(self): - self._io_in_sub = self.io.in_msg_foq.subscribe(self, maxsize=0) - self._io_out_sub = self.io.out_msg_foq.subscribe(self, maxsize=0) - if self.cfg.benchmark: - self._io_audio_in_sub = self.io.bench_audio_foq.subscribe(self, maxsize=0) - - async def boot(self): - self._in_task = self.sub_tg.create_task( - self._consume(self._io_in_sub.q), name="Logger:io_in" - ) - self._out_task = self.sub_tg.create_task( - self._consume(self._io_out_sub.q), name="Logger:io_out" - ) - if self.cfg.benchmark: - self._audio_inout_task = self.sub_tg.create_task( - self._consume(self._io_audio_in_sub.q), name="Logger:io_audio_inout" - ) - debug(f"Logger started, writing to {self.cfg.log_file}") - - async def do(self): - # Wait for stopper - while not self.stopper: - await asyncio.sleep(SLEEP_INTERVAL_DEFAULT) - debug(f"{self.name} task stopped, exiting...") - - async def cancel_subtasks(self): - debug("Cancelling Logger subtasks...") - +self.stopper # noqa - tasks_to_wait = [] - if self._in_task and self._out_task: - tasks_to_wait.extend([self._in_task, self._out_task]) - if self.cfg.benchmark and self._audio_inout_task: - 
tasks_to_wait.append(self._audio_inout_task) - for t in tasks_to_wait: - t.cancel() - debug(f"Waiting for {len(tasks_to_wait)} tasks to complete...") - try: - await asyncio.gather( - *(asyncio.wait_for(t, timeout=SHUTDOWN_TIMEOUT) for t in tasks_to_wait), - return_exceptions=True, # This will return CancelledError instead of raising it - ) - debug("All Logger subtasks cancelled successfully") - except Exception: - debug("Some Logger subtasks were cancelled or failed") - - async def _exit(self): - debug(f"{self.name}._exit()") - return await self.exit() - - async def exit(self) -> LogData: - debug("Finalizing Logger...") - - # First create LogData BEFORE cancelling tasks - try: - self.cfg.internal_logs.seek(0) - logs = self.cfg.internal_logs.readlines() - debug(f"Collected {len(logs)} internal log lines") - - try: - sysinfo = get_system_info() - except BaseException as e: - sysinfo = {"error": str(e)} - - log_data = LogData( - version=getattr(palabra_ai, "__version__", "n/a"), - sysinfo=sysinfo, - messages=self._messages.copy(), # Copy to avoid losing data - start_ts=self._start_ts, - cfg=self.cfg.to_dict() if hasattr(self.cfg, "to_dict") else {}, - log_file=str(self.cfg.log_file), - trace_file=str(self.cfg.trace_file), - debug=self.cfg.debug, - logs=logs, - ) - - # CRITICAL: Save result immediately - self.result = log_data - debug( - f"Logger: Saved LogData with {len(self._messages)} messages to self.result" - ) - - # Save to file if needed - if self.cfg.trace_file: - try: - with open(self.cfg.trace_file, "wb") as f: - f.write(to_json(log_data)) - debug(f"Saved trace to {self.cfg.trace_file}") - except Exception as e: - error(f"Failed to save trace file: {e}") - - except Exception as e: - error(f"Failed to create LogData: {e}") - # Create minimal LogData with what we have - log_data = LogData( - version="error", - sysinfo={"error": str(e)}, - messages=self._messages.copy() if self._messages else [], - start_ts=self._start_ts, - cfg={}, - log_file="", - trace_file="", - debug=False, - logs=[], - ) - self.result = log_data - - # Now cancel tasks - try: - cancel_task = asyncio.create_task(self.cancel_subtasks()) - await asyncio.wait_for(cancel_task, timeout=2.0) - except TimeoutError: - debug("Logger subtasks cancellation timeout") - except Exception as e: - debug(f"Error cancelling logger subtasks: {e}") - - # Unsubscribe from queues - try: - self.io.in_msg_foq.unsubscribe(self) - self.io.out_msg_foq.unsubscribe(self) - if self.cfg.benchmark: - self.io.bench_audio_foq.unsubscribe(self) - debug("Unsubscribed from IO queues") - except Exception as e: - debug(f"Error unsubscribing: {e}") - - debug( - f"Logger.exit() completed, returning LogData with {len(log_data.messages)} messages" - ) - return log_data - - async def _exit(self): - return await self.exit() - - async def _consume(self, q: asyncio.Queue): - """Process WebSocket messages.""" - while not self.stopper: - try: - msg = await asyncio.wait_for(q.get(), timeout=QUEUE_READ_TIMEOUT) - if msg is None: - debug(f"Received None from {q}, stopping consumer") - break - - dbg_msg = asdict(getattr(msg, "_dbg", Dbg.empty())) - - # Convert enums to strings for benchmark compatibility - if "kind" in dbg_msg and dbg_msg["kind"] is not None: - dbg_msg["kind"] = dbg_msg["kind"].value - if "ch" in dbg_msg and dbg_msg["ch"] is not None: - dbg_msg["ch"] = dbg_msg["ch"].value - if "dir" in dbg_msg and dbg_msg["dir"] is not None: - dbg_msg["dir"] = dbg_msg["dir"].value - - if hasattr(msg, "model_dump"): - dbg_msg["msg"] = msg.model_dump() - elif 
hasattr(msg, "to_bench"): - dbg_msg["msg"] = msg.to_bench() - else: - raise TypeError( - f"Message {msg} does not have model_dump() or to_bench() method" - ) - self._messages.append(dbg_msg) - debug(f"Consumed message from {q}: {dbg_msg}") - q.task_done() - except TimeoutError: - continue - except asyncio.CancelledError: - debug(f"Consumer for {q} cancelled") - break diff --git a/src/palabra_ai/task/manager.py b/src/palabra_ai/task/manager.py index 8886c7c..e4ed62e 100644 --- a/src/palabra_ai/task/manager.py +++ b/src/palabra_ai/task/manager.py @@ -16,10 +16,7 @@ from palabra_ai.task.adapter.base import Reader, Writer from palabra_ai.task.adapter.dummy import DummyWriter from palabra_ai.task.base import Task - -# from palabra_ai.internal.webrtc import AudioTrackSettings from palabra_ai.task.io.base import Io -from palabra_ai.task.logger import Logger from palabra_ai.task.stat import Stat from palabra_ai.task.transcription import Transcription from palabra_ai.util.logger import debug, exception, success, warning @@ -40,7 +37,7 @@ class Manager(Task): io: Io = field(init=False) # sender: SenderSourceAudio = field(init=False) # receiver: ReceiverTranslatedAudio = field(init=False) - logger: Logger | None = field(default=None, init=False) + # logger: Logger | None = field(default=None, init=False) # rtmon: IoMon = field(init=False) transcription: Transcription = field(init=False) stat: Stat = field(init=False) @@ -101,11 +98,6 @@ def __post_init__(self): f"🔧 {self.name} using default estimated_duration={self.cfg.estimated_duration}s" ) - # if hasattr(self.writer, "set_track_settings"): - # self.writer.set_track_settings(self.track_settings) - # if hasattr(self.reader, "set_track_settings"): - # self.reader.set_track_settings(self.track_settings) - if not self.io_class: self.io_class = self.cfg.mode.get_io_class() self.io = self.io_class( @@ -115,41 +107,16 @@ def __post_init__(self): writer=self.writer, ) - # self.rt = Realtime(self.cfg, self.io) - self.logger = Logger(self.cfg, self.io) - # self.transcription = Transcription(self.cfg, self.io) - # - # self.receiver = ReceiverTranslatedAudio( - # self.cfg, - # self.writer, - # self.rt, - # target.lang, - # ) - - # self.sender = SenderSourceAudio( - # self.cfg, - # self.rt, - # self.reader, - # self.cfg.to_dict(), - # self.track_settings, - # ) - - # self.rtmon = IoMon(self.cfg, self.rt) self.tasks.extend( [ t for t in [ self.reader, - # self.sender, - # self.rt, self.io, - # self.receiver, self.writer, - # self.rtmon, self.transcription, - self.logger, self, self.stat, ] @@ -158,9 +125,6 @@ def __post_init__(self): ) async def start_system(self): - self.logger(self.root_tg) - await self.logger.ready - self.stat(self.root_tg) await self.stat.ready self._show_banner_loop = self.stat.run_banner() @@ -233,7 +197,6 @@ async def exit(self): debug(f"🔧 {self.name}.exit() exiting...") +self.stopper # noqa +self.stat.stopper # noqa - +self.logger.stopper # noqa debug(f"🔧 {self.name}.exit() tasks: {[t.name for t in self.tasks]}") # DON'T use _abort() - it's internal! 
# Cancel all subtasks properly diff --git a/src/palabra_ai/util/dt.py b/src/palabra_ai/util/dt.py new file mode 100644 index 0000000..682457f --- /dev/null +++ b/src/palabra_ai/util/dt.py @@ -0,0 +1,7 @@ +from datetime import UTC, datetime + + +def get_now_strftime() -> str: + """Get current UTC time as a formatted string.""" + + return datetime.now(UTC).strftime("%Y%m%dT%H%M%S") diff --git a/src/palabra_ai/util/fileio.py b/src/palabra_ai/util/fileio.py new file mode 100644 index 0000000..ace9b5b --- /dev/null +++ b/src/palabra_ai/util/fileio.py @@ -0,0 +1,25 @@ +from pathlib import Path +from typing import Any + +from palabra_ai.util.logger import warning +from palabra_ai.util.orjson import to_json + + +def save_json( + path: Path | None, obj: Any, indent: bool = False, sort_keys: bool = True +) -> None: + """Save object as JSON file.""" + if path is None: + warning(f"save_json: path is None, skipping save {obj=}") + return + path.write_bytes(to_json(obj, indent=indent, sort_keys=sort_keys)) + + +def save_text(path: Path | None, text: str) -> None: + """Save text to file.""" + if path is None: + warning( + f"save_text: path is None, skipping save {text[:30] if text else '???'}..." + ) + return + path.write_text(text) diff --git a/src/palabra_ai/util/sysinfo.py b/src/palabra_ai/util/sysinfo.py index fce0221..6816eaa 100644 --- a/src/palabra_ai/util/sysinfo.py +++ b/src/palabra_ai/util/sysinfo.py @@ -14,6 +14,7 @@ import sys import sysconfig from dataclasses import asdict, dataclass, field +from pathlib import Path from typing import Any try: @@ -48,6 +49,15 @@ class SystemInfo: """Collects basic system information for debugging production issues.""" # Python info + command: str = field(default_factory=lambda: " ".join(sys.argv)) + argv: str = field(default_factory=lambda: sys.argv) + cwd: str = field(default_factory=lambda: str(Path.cwd())) + + palabra_version: str = field( + default_factory=lambda: getattr( + sys.modules.get("palabra_ai"), "__version__", "n/a" + ) + ) python_version: str = field(default_factory=lambda: sys.version) python_version_info: dict[str, Any] = field( default_factory=lambda: { diff --git a/tests/test_benchmark.py b/tests/test_benchmark.py index 781c3d8..0a046b2 100644 --- a/tests/test_benchmark.py +++ b/tests/test_benchmark.py @@ -201,232 +201,6 @@ def test_benchmark_exception_without_message(): assert e.__cause__ is original_exc -def test_benchmark_saves_error_to_file_with_out(): - """Test that benchmark saves error.txt when --out is specified""" - from palabra_ai.model import RunResult - from palabra_ai.benchmark.__main__ import main - from pathlib import Path - import tempfile - - original_exc = ValueError("Test error for saving") - failed_result = RunResult(ok=False, exc=original_exc, io_data=None) - - with tempfile.TemporaryDirectory() as tmpdir: - output_dir = Path(tmpdir) - - with patch('palabra_ai.benchmark.__main__.PalabraAI') as mock_palabra_class: - mock_palabra = MagicMock() - mock_palabra.run.return_value = failed_result - mock_palabra_class.return_value = mock_palabra - - with patch('sys.argv', ['benchmark', 'dummy.wav', 'en', 'es', '--out', str(output_dir)]): - with patch('palabra_ai.benchmark.__main__.Path') as mock_path_class: - def path_side_effect(path_str): - if 'dummy.wav' in str(path_str): - mock_path = MagicMock() - mock_path.exists.return_value = True - return mock_path - return Path(path_str) - mock_path_class.side_effect = path_side_effect - - with patch('av.open'): - with patch('palabra_ai.benchmark.__main__.FileReader'): - with 
patch('palabra_ai.benchmark.__main__.tqdm'): - with patch('palabra_ai.benchmark.__main__.get_system_info', return_value={"test": "info"}): - try: - main() - assert False, "main() should have raised RuntimeError" - except RuntimeError: - # Check that error file was created - error_files = list(output_dir.glob("*_bench_error.txt")) - assert len(error_files) == 1, f"Expected 1 error file, found {len(error_files)}" - - error_content = error_files[0].read_text() - assert "ValueError" in error_content - assert "Test error for saving" in error_content - assert "Traceback" in error_content or "traceback" in error_content - - -def test_benchmark_saves_sysinfo_on_start(): - """Test that benchmark saves sysinfo.json immediately when --out is specified""" - from palabra_ai.model import RunResult, IoData - from palabra_ai.benchmark.__main__ import main - from pathlib import Path - import tempfile - - # Create a successful result to avoid hitting error paths - io_data = IoData( - start_perf_ts=0.0, - start_utc_ts=0.0, - in_sr=16000, - out_sr=16000, - mode="ws", - channels=1, - events=[], - count_events=0 - ) - successful_result = RunResult(ok=True, exc=None, io_data=io_data) - - with tempfile.TemporaryDirectory() as tmpdir: - output_dir = Path(tmpdir) - - with patch('palabra_ai.benchmark.__main__.PalabraAI') as mock_palabra_class: - mock_palabra = MagicMock() - mock_palabra.run.return_value = successful_result - mock_palabra_class.return_value = mock_palabra - - with patch('sys.argv', ['benchmark', 'dummy.wav', 'en', 'es', '--out', str(output_dir)]): - with patch('palabra_ai.benchmark.__main__.Path') as mock_path_class: - def path_side_effect(path_str): - if 'dummy.wav' in str(path_str): - mock_path = MagicMock() - mock_path.exists.return_value = True - return mock_path - return Path(path_str) - mock_path_class.side_effect = path_side_effect - - with patch('av.open'): - with patch('palabra_ai.benchmark.__main__.FileReader'): - with patch('palabra_ai.benchmark.__main__.tqdm'): - with patch('palabra_ai.benchmark.__main__.get_system_info', return_value={"test": "sysinfo"}): - with patch('palabra_ai.benchmark.__main__.Report.parse', return_value=(MagicMock(), MagicMock(), MagicMock())): - with patch('palabra_ai.benchmark.__main__.format_report', return_value="Test report"): - with patch('palabra_ai.benchmark.__main__.save_wav'): - try: - main() - except Exception: - pass # We don't care if it fails, just checking sysinfo was saved - - # Check that sysinfo file was created - sysinfo_files = list(output_dir.glob("*_bench_sysinfo.json")) - assert len(sysinfo_files) >= 1, f"Expected at least 1 sysinfo file, found {len(sysinfo_files)}" - - -def test_benchmark_handles_cancelled_error(): - """Test that benchmark properly handles CancelledError with context""" - from palabra_ai.model import RunResult, LogData - from palabra_ai.benchmark.__main__ import main - import asyncio - from pathlib import Path - import tempfile - from io import StringIO - import sys - - # Create CancelledError with traceback - cancelled_exc = asyncio.CancelledError() - - # Create log data with many entries to test "all logs" output - log_entries = [f"Entry {i}: Log line {i}\n" for i in range(200)] - log_entries.extend([ - "2025-10-03 15:10:43.128 | SUCCESS | Starting...\n", - "2025-10-03 15:10:47.623 | INFO | Processing...\n", - "2025-10-03 15:10:50.090 | ERROR | Something went wrong\n", - "2025-10-03 15:10:50.327 | INFO | Cancelling...\n", - ]) - - log_data = LogData( - version="1.0.0", - sysinfo={"platform": "test"}, - messages=[], - 
start_ts=0.0, - cfg={"test": "config"}, - log_file="test.log", - trace_file="", - debug=True, - logs=log_entries - ) - - failed_result = RunResult(ok=False, exc=cancelled_exc, io_data=None, log_data=log_data) - - with tempfile.TemporaryDirectory() as tmpdir: - output_dir = Path(tmpdir) - - with patch('palabra_ai.benchmark.__main__.PalabraAI') as mock_palabra_class: - mock_palabra = MagicMock() - mock_palabra.run.return_value = failed_result - mock_palabra_class.return_value = mock_palabra - - with patch('sys.argv', ['benchmark', 'dummy.wav', 'en', 'es', '--out', str(output_dir)]): - with patch('palabra_ai.benchmark.__main__.Path') as mock_path_class: - def path_side_effect(path_str): - if 'dummy.wav' in str(path_str): - mock_path = MagicMock() - mock_path.exists.return_value = True - return mock_path - return Path(path_str) - mock_path_class.side_effect = path_side_effect - - with patch('av.open'): - with patch('palabra_ai.benchmark.__main__.FileReader'): - with patch('palabra_ai.benchmark.__main__.tqdm'): - with patch('palabra_ai.benchmark.__main__.get_system_info', return_value={"test": "info"}): - # Capture stdout to check that ALL logs are printed - captured_output = StringIO() - try: - with patch('sys.stdout', captured_output): - main() - assert False, "main() should have raised RuntimeError" - except RuntimeError as e: - assert "CancelledError" in str(e) - - # Check that RunResult debug file was saved - runresult_files = list(output_dir.glob("*_bench_runresult_debug.json")) - assert len(runresult_files) >= 1, f"Expected RunResult debug file, found {len(runresult_files)}" - - # Check that error file was saved - error_files = list(output_dir.glob("*_bench_error.txt")) - assert len(error_files) >= 1, f"Expected error file, found {len(error_files)}" - - # Check that output mentions "cascade cancellation" - output = captured_output.getvalue() - assert "cascade cancellation" in output, "Should mention cascade cancellation" - - # Check that ALL logs were printed (not just last 100) - assert f"Full logs (all {len(log_entries)} entries)" in output - # Check that first entry was printed (would not be if only last 100) - assert "Entry 0: Log line 0" in output - - -def test_sysinfo_contains_command(): - """Test that sysinfo.json contains command line information""" - from palabra_ai.benchmark.__main__ import main - from pathlib import Path - import tempfile - import json - - with tempfile.TemporaryDirectory() as tmpdir: - output_dir = Path(tmpdir) - - with patch('sys.argv', ['benchmark', 'test.wav', 'en', 'es', '--out', str(output_dir)]): - with patch('palabra_ai.benchmark.__main__.Path') as mock_path_class: - def path_side_effect(path_str): - if 'test.wav' in str(path_str): - mock_path = MagicMock() - mock_path.exists.return_value = True - return mock_path - return Path(path_str) - mock_path_class.side_effect = path_side_effect - - with patch('av.open'): - # Main should save sysinfo immediately - try: - main() - except Exception: - pass # We expect it to fail, just checking sysinfo was saved - - # Check that sysinfo file was created - sysinfo_files = list(output_dir.glob("*_bench_sysinfo.json")) - assert len(sysinfo_files) >= 1, f"Expected sysinfo file, found {len(sysinfo_files)}" - - # Check content - sysinfo = json.loads(sysinfo_files[0].read_text()) - assert "command" in sysinfo - assert "argv" in sysinfo - assert "cwd" in sysinfo - assert "benchmark" in sysinfo["command"] - assert isinstance(sysinfo["argv"], list) - - def test_manager_has_graceful_completion_flag(): """Test that Manager 
class has _graceful_completion flag""" from palabra_ai.task.manager import Manager @@ -490,6 +264,23 @@ def test_benchmark_parse_handles_part_suffixes(): # Create mock events for different tid patterns base_ts = 0.0 # Use relative timestamps starting from 0 + # Helper to create set_task event + def make_set_task_event(idx, dawn_ts): + from palabra_ai.util.orjson import to_json + body_dict = { + "message_type": "set_task", + "data": { + "source": {"lang": {"code": "en"}}, + "targets": [{"lang": {"code": "es"}}] + } + } + return IoEvent( + head=Dbg(kind=Kind.MESSAGE, ch=None, dir=None, idx=idx, dawn_ts=dawn_ts, dur_s=0.0), + body=to_json(body_dict), + tid=None, + mtype=None + ) + def make_event(idx, tid, mtype, dawn_ts, text="test"): from palabra_ai.util.orjson import to_json import base64 @@ -533,10 +324,13 @@ def make_event(idx, tid, mtype, dawn_ts, text="test"): ) events = [ + # set_task event (required by Report.parse) + make_set_task_event(0, base_ts), + # Input audio events - make_event(0, None, "input_audio_data", base_ts), - make_event(1, None, "input_audio_data", base_ts + 0.1), - make_event(2, None, "input_audio_data", base_ts + 0.2), + make_event(1, None, "input_audio_data", base_ts), + make_event(2, None, "input_audio_data", base_ts + 0.1), + make_event(3, None, "input_audio_data", base_ts + 0.2), # sentence_1 (no _part suffix) - should have metrics make_event(10, "sentence_1", "partial_transcription", base_ts + 0.5, "Hello"), @@ -567,11 +361,13 @@ def make_event(idx, tid, mtype, dawn_ts, text="test"): mode="ws", channels=1, events=events, - count_events=len(events) + count_events=len(events), + reader_x_title="FileReader(test.wav)", + writer_x_title="DummyWriter()" ) # Parse the report - report, _, _ = Report.parse(io_data) + report = Report.parse(io_data) # Check that we have all 4 sentences assert len(report.sentences) == 4, f"Expected 4 sentences, got {len(report.sentences)}" @@ -628,18 +424,7 @@ def make_event(idx, tid, mtype, dawn_ts, text="test"): assert s2p0.local_start_ts == s2p2.local_start_ts, "sentence_2_part_2 should use parent timestamp" # Test that format_report includes IDs correctly - from palabra_ai.benchmark.report import format_report - from palabra_ai.config import Config - from palabra_ai.lang import Language - from palabra_ai import SourceLang, TargetLang - - config = Config( - source=SourceLang(Language.get_or_create("en"), None), - targets=[TargetLang(Language.get_or_create("es"), None)], - benchmark=True - ) - - report_text = format_report(report, io_data, "en", "es", "test.wav", "out.wav", config) + report_text = report.report_txt # Check that table contains formatted IDs assert "sentence_1" in report_text, "Should show sentence_1 without brackets" @@ -660,6 +445,22 @@ def test_benchmark_parse_handles_partial_extra_parts(): base_ts = 0.0 + # Helper to create set_task event + def make_set_task_event(idx, dawn_ts): + body_dict = { + "message_type": "set_task", + "data": { + "source": {"lang": {"code": "en"}}, + "targets": [{"lang": {"code": "es"}}] + } + } + return IoEvent( + head=Dbg(kind=Kind.MESSAGE, ch=None, dir=None, idx=idx, dawn_ts=dawn_ts, dur_s=0.0), + body=to_json(body_dict), + tid=None, + mtype=None + ) + def make_event(idx, tid, mtype, dawn_ts, text="test"): if mtype in ("input_audio_data", "output_audio_data"): audio_samples = np.zeros(160, dtype=np.int16) @@ -689,7 +490,10 @@ def make_event(idx, tid, mtype, dawn_ts, text="test"): ) events = [ - make_event(0, None, "input_audio_data", base_ts), + # set_task event (required by Report.parse) + 
make_set_task_event(0, base_ts), + + make_event(1, None, "input_audio_data", base_ts), # Parent sentence make_event(10, "s1_part_0", "partial_transcription", base_ts + 0.5, "Parent"), @@ -712,10 +516,12 @@ def make_event(idx, tid, mtype, dawn_ts, text="test"): mode="ws", channels=1, events=events, - count_events=len(events) + count_events=len(events), + reader_x_title="FileReader(test.wav)", + writer_x_title="DummyWriter()" ) - report, _, _ = Report.parse(io_data) + report = Report.parse(io_data) assert len(report.sentences) == 3, f"Expected 3 sentences, got {len(report.sentences)}" @@ -744,6 +550,22 @@ def test_benchmark_parse_handles_orphan_extra_parts(): base_ts = 0.0 + # Helper to create set_task event + def make_set_task_event(idx, dawn_ts): + body_dict = { + "message_type": "set_task", + "data": { + "source": {"lang": {"code": "en"}}, + "targets": [{"lang": {"code": "es"}}] + } + } + return IoEvent( + head=Dbg(kind=Kind.MESSAGE, ch=None, dir=None, idx=idx, dawn_ts=dawn_ts, dur_s=0.0), + body=to_json(body_dict), + tid=None, + mtype=None + ) + def make_event(idx, tid, mtype, dawn_ts, text="test"): body_dict = { "message_type": mtype, @@ -763,6 +585,9 @@ def make_event(idx, tid, mtype, dawn_ts, text="test"): ) events = [ + # set_task event (required by Report.parse) + make_set_task_event(0, base_ts), + # Orphan _part_1 without parent make_event(10, "orphan_part_1", "validated_transcription", base_ts + 0.5, "Orphan"), make_event(11, "orphan_part_1", "translated_transcription", base_ts + 0.6, "Huerfano"), @@ -776,7 +601,9 @@ def make_event(idx, tid, mtype, dawn_ts, text="test"): mode="ws", channels=1, events=events, - count_events=len(events) + count_events=len(events), + reader_x_title="FileReader(test.wav)", + writer_x_title="DummyWriter()" ) # Capture stdout to check for warning @@ -784,7 +611,7 @@ def make_event(idx, tid, mtype, dawn_ts, text="test"): old_stdout = sys.stdout try: sys.stdout = captured - report, _, _ = Report.parse(io_data) + report = Report.parse(io_data) sys.stdout = old_stdout finally: sys.stdout = old_stdout @@ -930,6 +757,22 @@ def test_benchmark_handles_missing_partial_transcription(): import base64 import numpy as np + # Helper to create set_task event + def make_set_task_event(idx, dawn_ts): + body_dict = { + "message_type": "set_task", + "data": { + "source": {"lang": {"code": "en"}}, + "targets": [{"lang": {"code": "es"}}] + } + } + return IoEvent( + head=Dbg(kind=Kind.MESSAGE, ch=None, dir=None, idx=idx, dawn_ts=dawn_ts, dur_s=0.0), + body=to_json(body_dict), + tid=None, + mtype=None + ) + def make_event(idx, tid, mtype, dawn_ts, text="test"): if mtype in ("input_audio_data", "output_audio_data"): audio_samples = np.zeros(160, dtype=np.int16) @@ -968,7 +811,10 @@ def make_event(idx, tid, mtype, dawn_ts, text="test"): # Create events: input, validated, translated, output_audio (no partial) events = [ - make_event(0, None, "input_audio_data", base_ts), + # set_task event (required by Report.parse) + make_set_task_event(0, base_ts), + + make_event(1, None, "input_audio_data", base_ts), make_event(10, base_tid, "validated_transcription", base_ts + 0.5, "validated text"), make_event(11, base_tid, "translated_transcription", base_ts + 1.0, "translated text"), make_event(12, base_tid, "output_audio_data", base_ts + 1.5), @@ -977,11 +823,13 @@ def make_event(idx, tid, mtype, dawn_ts, text="test"): # Create IoData io_data = IoData( start_perf_ts=base_ts, start_utc_ts=base_ts, in_sr=16000, out_sr=24000, - mode="test", channels=1, events=events, 
count_events=len(events) + mode="test", channels=1, events=events, count_events=len(events), + reader_x_title="FileReader(test.wav)", + writer_x_title="DummyWriter()" ) # Parse with Report - should handle missing partial_transcription - report, _, _ = Report.parse(io_data) + report = Report.parse(io_data) # Verify sentence was created despite missing partial_transcription assert len(report.sentences) == 1 @@ -1111,7 +959,9 @@ def test_rewind_saves_files_with_out_option(): mode="ws", channels=1, events=[], - count_events=0 + count_events=0, + reader_x_title="FileReader(test.wav)", + writer_x_title="DummyWriter()" ) # Create mock benchmark result file content @@ -1123,7 +973,9 @@ def test_rewind_saves_files_with_out_option(): "out_sr": 24000, "mode": "ws", "channels": 1, - "events": [] + "events": [], + "reader_x_title": "FileReader(test.wav)", + "writer_x_title": "DummyWriter()" } } @@ -1142,45 +994,18 @@ def test_rewind_saves_files_with_out_option(): mock_load.return_value = io_data with patch('palabra_ai.benchmark.rewind.Report.parse') as mock_parse: - # Mock empty audio canvases and report with proper set_task_e - import numpy as np - from palabra_ai.benchmark.report import Report - - # Create mock report with set_task_e that contains config data + # Mock report with save_all and report_txt mock_report = MagicMock() - mock_report.set_task_e = MagicMock() - mock_report.set_task_e.body = { - "data": { - "source": {"lang": {"code": "en"}}, - "targets": [{"lang": {"code": "es"}}] - } - } - - mock_in_audio = np.zeros(1000, dtype=np.int16) - mock_out_audio = np.zeros(1000, dtype=np.int16) - mock_parse.return_value = (mock_report, mock_in_audio, mock_out_audio) - - with patch('palabra_ai.benchmark.rewind.format_report') as mock_format: - mock_format.return_value = "Test report content" - - with patch('palabra_ai.benchmark.rewind.Config.from_dict') as mock_config: - # Create mock config with language properties - mock_config_obj = MagicMock() - mock_config_obj.source.lang.code = "en" - mock_config_obj.targets = [MagicMock()] - mock_config_obj.targets[0].lang.code = "es" - mock_config.return_value = mock_config_obj - - with patch('palabra_ai.benchmark.rewind.save_benchmark_files') as mock_save: - # Run rewind - rewind_main() + mock_report.report_txt = "Test report content" + mock_report.save_all = MagicMock() + mock_parse.return_value = mock_report - # Verify save_benchmark_files was called - mock_save.assert_called_once() + # Run rewind + rewind_main() - # The test was originally checking for actual file creation, but since we're mocking - # save_benchmark_files, we just verify the function was called correctly - assert True, "Rewind executed successfully with mocked dependencies" + # Verify parse and save_all were called + mock_parse.assert_called_once() + mock_report.save_all.assert_called_once() diff --git a/tests/test_client.py b/tests/test_client.py index 55e5b41..00d979f 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -453,7 +453,7 @@ async def test_coro(): mock_manager = MagicMock() mock_manager._graceful_completion = False # External cancellation, not graceful mock_io = MagicMock() - mock_io.io_data = {"start_perf_ts": 0.0, "start_utc_ts": 0.0, "in_sr": 16000, "out_sr": 16000, "mode": "test", "channels": 1, "events": [], "count_events": 0} + mock_io.io_data = {"start_perf_ts": 0.0, "start_utc_ts": 0.0, "in_sr": 16000, "out_sr": 16000, "mode": "test", "channels": 1, "events": [], "count_events": 0, "reader_x_title": "", "writer_x_title": ""} mock_manager.io = mock_io # 
Create a future that raises CancelledError when awaited mock_task = asyncio.Future() @@ -469,7 +469,6 @@ async def test_coro(): result = await client.arun(config, no_raise=True) assert result.ok is False assert isinstance(result.exc, asyncio.CancelledError) - assert result.log_data is None asyncio.run(test_coro()) @@ -484,7 +483,7 @@ async def test_coro(): with patch.object(client, 'process') as mock_process: mock_manager = MagicMock() mock_io = MagicMock() - mock_io.io_data = {"start_perf_ts": 0.0, "start_utc_ts": 0.0, "in_sr": 16000, "out_sr": 16000, "mode": "test", "channels": 1, "events": [], "count_events": 0} + mock_io.io_data = {"start_perf_ts": 0.0, "start_utc_ts": 0.0, "in_sr": 16000, "out_sr": 16000, "mode": "test", "channels": 1, "events": [], "count_events": 0, "reader_x_title": "", "writer_x_title": ""} mock_io.eos_received = False mock_manager.io = mock_io # Create a future that raises ValueError when awaited @@ -503,78 +502,10 @@ async def test_coro(): result = await client.arun(config, no_raise=True) assert result.ok is False assert isinstance(result.exc, ValueError) - assert result.log_data is None mock_exception.assert_any_call("Error in manager task") asyncio.run(test_coro()) -def test_run_async_with_logger_timeout(): - """Test _run_with_result when logger times out""" - config = Config() - config.source = SourceLang(lang="es") - - client = PalabraAI(client_id="test", client_secret="test") - - async def test_coro(): - with patch.object(client, 'process') as mock_process: - mock_manager = MagicMock() - mock_io = MagicMock() - mock_io.io_data = {"start_perf_ts": 0.0, "start_utc_ts": 0.0, "in_sr": 16000, "out_sr": 16000, "mode": "test", "channels": 1, "events": [], "count_events": 0} - mock_manager.io = mock_io - async def normal_task(): - return None - mock_manager.task = asyncio.create_task(normal_task()) - mock_manager.logger = MagicMock() - mock_manager.logger._task = MagicMock() - mock_manager.logger._task.done.return_value = False - mock_manager.logger.result = None - mock_manager.logger.exit = AsyncMock(side_effect=asyncio.TimeoutError()) - mock_process.return_value.__aenter__.return_value = mock_manager - mock_process.return_value.__aexit__.return_value = None - - with patch('asyncio.wait_for', side_effect=asyncio.TimeoutError()): - with patch('palabra_ai.client.debug') as mock_debug: - result = await client.arun(config, no_raise=True) - assert result.ok is True - assert result.log_data is None - mock_debug.assert_any_call("Logger task timeout or cancelled, checking result anyway") - - asyncio.run(test_coro()) - -def test_run_async_with_logger_exception(): - """Test _run_with_result when logger.exit() raises exception""" - config = Config() - config.source = SourceLang(lang="es") - - client = PalabraAI(client_id="test", client_secret="test") - - async def test_coro(): - with patch.object(client, 'process') as mock_process: - mock_manager = MagicMock() - mock_io = MagicMock() - mock_io.io_data = {"start_perf_ts": 0.0, "start_utc_ts": 0.0, "in_sr": 16000, "out_sr": 16000, "mode": "test", "channels": 1, "events": [], "count_events": 0} - mock_manager.io = mock_io - # Create a future that completes normally - mock_task = asyncio.Future() - mock_task.set_result(None) - mock_manager.task = mock_task - mock_manager.logger = MagicMock() - mock_manager.logger._task = MagicMock() - mock_manager.logger._task.done.return_value = True - mock_manager.logger.result = None - mock_manager.logger.exit = AsyncMock(side_effect=RuntimeError("Logger exit error")) - 
mock_process.return_value.__aenter__.return_value = mock_manager - mock_process.return_value.__aexit__.return_value = None - - with patch('palabra_ai.client.debug') as mock_debug: - with patch('palabra_ai.client.error') as mock_error: - result = await client.arun(config, no_raise=True) - assert result.ok is True - assert result.log_data is None - mock_debug.assert_any_call("Failed to get log_data from logger.exit(): Logger exit error") - - asyncio.run(test_coro()) - def test_run_async_with_process_error(): """Test _run when process raises error""" config = Config() @@ -997,7 +928,7 @@ async def test_run_result_eos_field_from_manager(): mock_manager = MagicMock() mock_io = MagicMock() mock_io.eos_received = True - mock_io.io_data = {"start_perf_ts": 0.0, "start_utc_ts": 0.0, "in_sr": 16000, "out_sr": 16000, "mode": "test", "channels": 1, "events": [], "count_events": 0} + mock_io.io_data = {"start_perf_ts": 0.0, "start_utc_ts": 0.0, "in_sr": 16000, "out_sr": 16000, "mode": "test", "channels": 1, "events": [], "count_events": 0, "reader_x_title": "", "writer_x_title": ""} mock_manager.io = mock_io # Mock logger with result @@ -1076,7 +1007,7 @@ async def test_run_result_eos_field_false(): mock_manager = MagicMock() mock_io = MagicMock() mock_io.eos_received = False - mock_io.io_data = {"start_perf_ts": 0.0, "start_utc_ts": 0.0, "in_sr": 16000, "out_sr": 16000, "mode": "test", "channels": 1, "events": [], "count_events": 0} + mock_io.io_data = {"start_perf_ts": 0.0, "start_utc_ts": 0.0, "in_sr": 16000, "out_sr": 16000, "mode": "test", "channels": 1, "events": [], "count_events": 0, "reader_x_title": "", "writer_x_title": ""} mock_manager.io = mock_io # Mock logger with result @@ -1123,14 +1054,15 @@ def test_benchmark_completes_successfully_with_graceful_shutdown(): mode="ws", channels=1, events=[], - count_events=0 + count_events=0, + reader_x_title="", + writer_x_title="" ) successful_result = RunResult( ok=True, exc=None, io_data=io_data, - log_data=None, eos=True ) @@ -1157,7 +1089,7 @@ async def test_manager_cancelled_graceful_shutdown(): mock_manager = MagicMock() mock_manager._graceful_completion = True # GRACEFUL shutdown mock_io = MagicMock() - mock_io.io_data = {"start_perf_ts": 0.0, "start_utc_ts": 0.0, "in_sr": 16000, "out_sr": 16000, "mode": "ws", "channels": 1, "events": [], "count_events": 0} + mock_io.io_data = {"start_perf_ts": 0.0, "start_utc_ts": 0.0, "in_sr": 16000, "out_sr": 16000, "mode": "ws", "channels": 1, "events": [], "count_events": 0, "reader_x_title": "", "writer_x_title": ""} mock_io.eos_received = True mock_manager.io = mock_io @@ -1208,7 +1140,7 @@ async def test_manager_cancelled_external(): mock_manager = MagicMock() mock_manager._graceful_completion = False # EXTERNAL cancellation mock_io = MagicMock() - mock_io.io_data = {"start_perf_ts": 0.0, "start_utc_ts": 0.0, "in_sr": 16000, "out_sr": 16000, "mode": "ws", "channels": 1, "events": [], "count_events": 0} + mock_io.io_data = {"start_perf_ts": 0.0, "start_utc_ts": 0.0, "in_sr": 16000, "out_sr": 16000, "mode": "ws", "channels": 1, "events": [], "count_events": 0, "reader_x_title": "", "writer_x_title": ""} mock_io.eos_received = False mock_manager.io = mock_io @@ -1261,7 +1193,9 @@ async def test_graceful_shutdown_returns_ok_true(): "mode": "ws", "channels": 1, "events": [], - "count_events": 0 + "count_events": 0, + "reader_x_title": "", + "writer_x_title": "" } mock_io.eos_received = True # EOS was received - normal completion mock_manager.io = mock_io diff --git a/tests/test_config.py 
b/tests/test_config.py index b495097..2d4004f 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1306,3 +1306,127 @@ def test_config_json_schema_has_stream_defaults(): assert output_target["format"]["default"] == "pcm_s16le" assert output_target["sample_rate"]["default"] == WS_MODE_OUTPUT_SAMPLE_RATE assert output_target["channels"]["default"] == WS_MODE_CHANNELS + +# ═══════════════════════════════════════════════════ +# NEW TESTS: Config Language Properties (TDD - RED) +# ═══════════════════════════════════════════════════ + +def test_config_source_lang_property(): + """Test Config.source_lang property returns source language code""" + config = Config( + source=SourceLang(lang=Language.get_by_bcp47("en-US")), + targets=[TargetLang(lang=ES)] + ) + # Uses source_code from Language + assert config.source_lang == "en" + + +def test_config_target_lang_property(): + """Test Config.target_lang property returns first target language code""" + config = Config( + source=SourceLang(lang=EN), + targets=[TargetLang(lang=ES)] + ) + # Uses target_code from Language + assert config.target_lang == "es" + + +def test_config_source_lang_raises_when_not_configured(): + """Test Config.source_lang raises ConfigurationError if source not set""" + config = Config() + with pytest.raises(ConfigurationError, match="Source language not configured"): + _ = config.source_lang + + +def test_config_target_lang_raises_when_not_configured(): + """Test Config.target_lang raises ConfigurationError if no targets""" + config = Config(source=SourceLang(lang=EN)) + with pytest.raises(ConfigurationError, match="Target language not configured"): + _ = config.target_lang + + +def test_config_target_lang_with_multiple_targets(): + """Test Config.target_lang returns first target when multiple exist""" + config = Config( + source=SourceLang(lang=EN), + targets=[ + TargetLang(lang=ES), + TargetLang(lang=FR), + ] + ) + assert config.target_lang == "es" + + +# ═══════════════════════════════════════════════════ +# NEW TESTS: Config.output_dir Field (TDD - RED) +# ═══════════════════════════════════════════════════ + +def test_config_output_dir_creates_directory(tmp_path): + """Test Config.output_dir creates directory if doesn't exist""" + output_dir = tmp_path / "outputs" + assert not output_dir.exists() + + config = Config( + source=SourceLang(lang=EN), + targets=[TargetLang(lang=ES)], + output_dir=output_dir + ) + + assert output_dir.exists() + assert output_dir.is_dir() + + +def test_config_output_dir_sets_log_file_automatically(tmp_path): + """Test Config.output_dir auto-sets log_file when not explicitly set""" + output_dir = tmp_path / "outputs" + + config = Config( + source=SourceLang(lang=EN), + targets=[TargetLang(lang=ES)], + output_dir=output_dir + ) + + assert config.log_file is not None + assert config.log_file.parent == output_dir + assert config.log_file.suffix == ".log" + + +def test_config_output_dir_sets_trace_file_automatically(tmp_path): + """Test Config.output_dir auto-sets trace_file when not explicitly set""" + output_dir = tmp_path / "outputs" + + config = Config( + source=SourceLang(lang=EN), + targets=[TargetLang(lang=ES)], + output_dir=output_dir + ) + + assert config.trace_file is not None + assert config.trace_file.parent == output_dir + assert str(config.trace_file).endswith(".trace.log") + + +def test_config_output_dir_respects_explicit_log_file(tmp_path): + """Test Config.output_dir keeps explicit log_file even when output_dir set""" + output_dir = tmp_path / "outputs" + custom_log = 
tmp_path / "custom.log" + + config = Config( + source=SourceLang(lang=EN), + targets=[TargetLang(lang=ES)], + output_dir=output_dir, + log_file=str(custom_log) + ) + + assert config.log_file == custom_log + + +def test_config_output_dir_none_works_normally(): + """Test Config works normally when output_dir is None""" + config = Config( + source=SourceLang(lang=EN), + targets=[TargetLang(lang=ES)], + output_dir=None + ) + + assert config.output_dir is None diff --git a/tests/test_model.py b/tests/test_model.py index cd7030b..74ac90a 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -1,30 +1,5 @@ import pytest -from palabra_ai.model import LogData, RunResult - - -def test_log_data_creation(): - """Test LogData model creation""" - log_data = LogData( - version="1.0.0", - sysinfo={"os": "darwin"}, - messages=[{"type": "test", "data": "example"}], - start_ts=1234567890.0, - cfg={"mode": "test"}, - log_file="test.log", - trace_file="test.trace", - debug=True, - logs=["log1", "log2"] - ) - - assert log_data.version == "1.0.0" - assert log_data.sysinfo["os"] == "darwin" - assert len(log_data.messages) == 1 - assert log_data.start_ts == 1234567890.0 - assert log_data.cfg["mode"] == "test" - assert log_data.log_file == "test.log" - assert log_data.trace_file == "test.trace" - assert log_data.debug is True - assert len(log_data.logs) == 2 +from palabra_ai.model import RunResult def test_run_result_success(): @@ -33,7 +8,6 @@ def test_run_result_success(): assert result.ok is True assert result.exc is None - assert result.log_data is None def test_run_result_with_exception(): @@ -43,28 +17,6 @@ def test_run_result_with_exception(): assert result.ok is False assert result.exc == test_exception - assert result.log_data is None - - -def test_run_result_with_log_data(): - """Test RunResult model with log data""" - log_data = LogData( - version="1.0.0", - sysinfo={}, - messages=[], - start_ts=0.0, - cfg={}, - log_file="", - trace_file="", - debug=False, - logs=[] - ) - - result = RunResult(ok=True, log_data=log_data) - - assert result.ok is True - assert result.exc is None - assert result.log_data == log_data def test_run_result_arbitrary_types(): @@ -94,21 +46,8 @@ def test_run_result_eos_field_true(): def test_run_result_eos_field_with_all_params(): """Test RunResult eos field works with all parameters""" - log_data = LogData( - version="1.0.0", - sysinfo={}, - messages=[], - start_ts=0.0, - cfg={}, - log_file="", - trace_file="", - debug=False, - logs=[] - ) - - result = RunResult(ok=True, exc=None, log_data=log_data, eos=True) + result = RunResult(ok=True, exc=None, eos=True) assert result.ok is True assert result.exc is None - assert result.log_data == log_data assert result.eos is True \ No newline at end of file diff --git a/tests/test_task_adapter_base.py b/tests/test_task_adapter_base.py index 3b922fa..4cd634e 100644 --- a/tests/test_task_adapter_base.py +++ b/tests/test_task_adapter_base.py @@ -35,6 +35,10 @@ def __init__(self, *args, **kwargs): self._task = None self._sub_tasks = [] + @property + def x_title(self) -> str: + return "concrete-reader" + async def read(self, size: int) -> bytes | None: """Mock implementation""" return b"test_data" @@ -76,6 +80,10 @@ def __init__(self, *args, **kwargs): self.ab = None self.drop_empty_frames = drop_empty_frames + @property + def x_title(self) -> str: + return "concrete-buffered-writer" + async def boot(self): """Mock boot implementation""" from palabra_ai.audio import AudioBuffer @@ -118,6 +126,10 @@ def __init__(self, *args, **kwargs): 
self._task = None self._sub_tasks = [] + @property + def x_title(self) -> str: + return "concrete-writer" + async def write(self, frame: AudioFrame): """Mock implementation""" pass diff --git a/tests/test_task_logger.py b/tests/test_task_logger.py deleted file mode 100644 index 623df86..0000000 --- a/tests/test_task_logger.py +++ /dev/null @@ -1,342 +0,0 @@ -import asyncio -import json -import pytest -import time -from pathlib import Path -from unittest.mock import MagicMock, AsyncMock, patch, call, mock_open -from dataclasses import asdict - -from palabra_ai.task.logger import Logger -from palabra_ai.task.base import TaskEvent -from palabra_ai.message import Dbg -from palabra_ai.config import Config -from palabra_ai.util.fanout_queue import FanoutQueue, Subscription -from palabra_ai.enum import Channel, Direction - - -class TestLogger: - """Test Logger class""" - - @pytest.fixture - def mock_config(self, tmp_path): - """Create mock config""" - from io import StringIO - config = MagicMock(spec=Config) - config.log_file = tmp_path / "test.log" - config.trace_file = tmp_path / "trace.json" - config.debug = True - config.benchmark = False # Add benchmark attribute - # Add internal_logs attribute - config.internal_logs = StringIO("internal log 1\ninternal log 2\n") - # Add to_dict method that returns a serializable dict - config.to_dict.return_value = { - "log_file": str(tmp_path / "test.log"), - "trace_file": str(tmp_path / "trace.json"), - "debug": True, - "benchmark": False - } - return config - - @pytest.fixture - def mock_io(self): - """Create mock IO""" - io = MagicMock() - io.in_msg_foq = MagicMock(spec=FanoutQueue) - io.out_msg_foq = MagicMock(spec=FanoutQueue) - - # Create mock subscriptions - in_sub = MagicMock(spec=Subscription) - in_sub.q = asyncio.Queue() - out_sub = MagicMock(spec=Subscription) - out_sub.q = asyncio.Queue() - - io.in_msg_foq.subscribe.return_value = in_sub - io.out_msg_foq.subscribe.return_value = out_sub - - return io - - def test_init(self, mock_config, mock_io): - """Test initialization""" - logger = Logger(cfg=mock_config, io=mock_io) - - assert logger.cfg == mock_config - assert logger.io == mock_io - assert logger._messages == [] - assert isinstance(logger._start_ts, float) - assert logger._io_in_sub is not None - assert logger._io_out_sub is not None - - # Verify subscriptions - mock_io.in_msg_foq.subscribe.assert_called_once_with(logger, maxsize=0) - mock_io.out_msg_foq.subscribe.assert_called_once_with(logger, maxsize=0) - - @pytest.mark.asyncio - async def test_boot(self, mock_config, mock_io): - """Test boot method""" - logger = Logger(cfg=mock_config, io=mock_io) - logger.sub_tg = MagicMock() - mock_in_task = MagicMock() - mock_out_task = MagicMock() - logger.sub_tg.create_task.side_effect = [mock_in_task, mock_out_task] - - with patch('palabra_ai.task.logger.debug') as mock_debug: - await logger.boot() - - # Verify tasks created - assert logger.sub_tg.create_task.call_count == 2 - assert logger._in_task == mock_in_task - assert logger._out_task == mock_out_task - - # Verify debug message - mock_debug.assert_called_once() - assert "Logger started" in str(mock_debug.call_args[0][0]) - - @pytest.mark.asyncio - async def test_do(self, mock_config, mock_io): - """Test do method""" - logger = Logger(cfg=mock_config, io=mock_io) - logger.stopper = TaskEvent() - - # Set stopper after short delay - async def set_stopper(): - await asyncio.sleep(0.01) - +logger.stopper - - with patch('palabra_ai.task.logger.debug') as mock_debug: - 
asyncio.create_task(set_stopper()) - await logger.do() - - mock_debug.assert_called_once() - assert "task stopped" in str(mock_debug.call_args[0][0]) - - @pytest.mark.asyncio - async def test_consume_with_message(self, mock_config, mock_io): - """Test _consume with valid message""" - logger = Logger(cfg=mock_config, io=mock_io) - logger.stopper = TaskEvent() - - # Create mock message - mock_msg = MagicMock() - # Set _dbg as a Dbg instance (as expected by the logger) - from palabra_ai.enum import Kind - dbg = Dbg(kind=Kind.MESSAGE, ch=Channel.WS, dir=Direction.IN) - dbg.ts = 1234.5 # Set specific timestamp - mock_msg._dbg = dbg - mock_msg.model_dump.return_value = {"type": "test_message"} - - # Create queue and add message - q = asyncio.Queue() - await q.put(mock_msg) - await q.put(None) # Signal stop - - with patch('palabra_ai.task.logger.debug') as mock_debug: - await logger._consume(q) - - # Verify message was processed - assert len(logger._messages) == 1 - assert logger._messages[0]["msg"]["type"] == "test_message" - mock_debug.assert_called() - - @pytest.mark.asyncio - async def test_consume_timeout(self, mock_config, mock_io): - """Test _consume with timeout""" - logger = Logger(cfg=mock_config, io=mock_io) - logger.stopper = TaskEvent() - - q = asyncio.Queue() - - # Set stopper after short delay - async def set_stopper(): - await asyncio.sleep(0.1) - +logger.stopper - - asyncio.create_task(set_stopper()) - await logger._consume(q) - - # Should complete without error - assert len(logger._messages) == 0 - - @pytest.mark.asyncio - async def test_consume_cancelled(self, mock_config, mock_io): - """Test _consume when cancelled""" - logger = Logger(cfg=mock_config, io=mock_io) - logger.stopper = TaskEvent() - - q = AsyncMock() - q.get = AsyncMock(side_effect=asyncio.CancelledError()) - - with patch('palabra_ai.task.logger.debug') as mock_debug: - await logger._consume(q) - - mock_debug.assert_called_once() - assert "cancelled" in str(mock_debug.call_args[0][0]) - - @pytest.mark.asyncio - async def test_consume_no_dbg_attribute(self, mock_config, mock_io): - """Test _consume with message without _dbg attribute""" - logger = Logger(cfg=mock_config, io=mock_io) - logger.stopper = TaskEvent() - - # Create mock message without _dbg - mock_msg = MagicMock() - del mock_msg._dbg # Remove _dbg attribute - mock_msg.model_dump.return_value = {"type": "test_message"} - - # Create queue and add message - q = asyncio.Queue() - await q.put(mock_msg) - await q.put(None) # Signal stop - - await logger._consume(q) - - # Verify message was processed with empty Dbg - assert len(logger._messages) == 1 - assert logger._messages[0]["msg"]["type"] == "test_message" - - @pytest.mark.asyncio - async def test_exit_success(self, mock_config, mock_io, tmp_path): - """Test successful exit""" - logger = Logger(cfg=mock_config, io=mock_io) - logger._messages = [{"msg": {"type": "test1"}}, {"msg": {"type": "test2"}}] - - # Create mock tasks that are actual asyncio.Task objects - async def dummy_task(): - await asyncio.sleep(0.001) # Very short sleep - - logger._in_task = asyncio.create_task(dummy_task()) - logger._out_task = asyncio.create_task(dummy_task()) - - # Create log file - log_file = tmp_path / "test.log" - log_file.write_text("Log line 1\nLog line 2\n") - - with patch('palabra_ai.task.logger.debug') as mock_debug: - with patch('palabra_ai.task.logger.get_system_info') as mock_sysinfo: - mock_sysinfo.return_value = {"os": "test", "version": "1.0"} - - # Mock the file write operation - with patch('builtins.open', 
mock_open()) as mock_file: - result = await logger.exit() - - # Tasks should be done or cancelled - assert logger._in_task.done() - assert logger._out_task.done() - - # Verify trace file write was attempted with correct path - mock_file.assert_any_call(mock_config.trace_file, "wb") - - # Verify unsubscribe - mock_io.in_msg_foq.unsubscribe.assert_called_once_with(logger) - mock_io.out_msg_foq.unsubscribe.assert_called_once_with(logger) - - assert mock_debug.call_count >= 3 - - @pytest.mark.asyncio - async def test_exit_log_file_error(self, mock_config, mock_io, tmp_path): - """Test exit when log file can't be read""" - logger = Logger(cfg=mock_config, io=mock_io) - logger._messages = [] - - # Create mock tasks that are actual asyncio.Task objects - async def dummy_task(): - await asyncio.sleep(0.001) # Very short sleep - - logger._in_task = asyncio.create_task(dummy_task()) - logger._out_task = asyncio.create_task(dummy_task()) - - # Make log file unreadable - mock_config.log_file = "/nonexistent/file.log" - - with patch('palabra_ai.task.logger.debug'): - with patch('palabra_ai.task.logger.get_system_info') as mock_sysinfo: - mock_sysinfo.return_value = {"os": "test"} - - # Mock the file write operation - with patch('builtins.open', mock_open()) as mock_file: - result = await logger.exit() - - # Verify trace file write was attempted with correct path - mock_file.assert_any_call(mock_config.trace_file, "wb") - - @pytest.mark.asyncio - async def test_exit_sysinfo_error(self, mock_config, mock_io, tmp_path): - """Test exit when sysinfo fails""" - logger = Logger(cfg=mock_config, io=mock_io) - logger._messages = [] - - # Create mock tasks that are actual asyncio.Task objects - async def dummy_task(): - await asyncio.sleep(0.001) # Very short sleep - - logger._in_task = asyncio.create_task(dummy_task()) - logger._out_task = asyncio.create_task(dummy_task()) - - # Create log file - log_file = tmp_path / "test.log" - log_file.write_text("Log line\n") - - with patch('palabra_ai.task.logger.debug'): - with patch('palabra_ai.task.logger.get_system_info') as mock_sysinfo: - mock_sysinfo.side_effect = RuntimeError("Sysinfo error") - - # Mock the file write operation - with patch('builtins.open', mock_open()) as mock_file: - result = await logger.exit() - - # Verify trace file write was attempted with correct path - mock_file.assert_any_call(mock_config.trace_file, "wb") - - @pytest.mark.asyncio - async def test_exit_with_version(self, mock_config, mock_io): - """Test exit includes version info""" - logger = Logger(cfg=mock_config, io=mock_io) - logger._messages = [] - - # Create mock tasks that are actual asyncio.Task objects - async def dummy_task(): - await asyncio.sleep(0.001) # Very short sleep - - logger._in_task = asyncio.create_task(dummy_task()) - logger._out_task = asyncio.create_task(dummy_task()) - - with patch('palabra_ai.task.logger.debug'): - with patch('palabra_ai.task.logger.get_system_info') as mock_sysinfo: - with patch('palabra_ai.__version__', '1.2.3'): - mock_sysinfo.return_value = {} - - # Mock the file write operation - with patch('builtins.open', mock_open()) as mock_file: - await logger.exit() - - # Verify trace file write was attempted - mock_file.assert_called_with(mock_config.trace_file, "wb") - - @pytest.mark.asyncio - async def test_exit_no_tasks(self, mock_config, mock_io): - """Test exit when no tasks were created""" - logger = Logger(cfg=mock_config, io=mock_io) - logger._messages = [] - logger._in_task = None - logger._out_task = None - - with 
patch('palabra_ai.task.logger.debug'): - with patch('palabra_ai.task.logger.get_system_info') as mock_sysinfo: - mock_sysinfo.return_value = {} - - # Mock the file write operation - with patch('builtins.open', mock_open()) as mock_file: - await logger.exit() - - # Should complete without error - mock_file.assert_called_with(mock_config.trace_file, "wb") - - @pytest.mark.asyncio - async def test_underscore_exit(self, mock_config, mock_io): - """Test _exit method calls exit""" - logger = Logger(cfg=mock_config, io=mock_io) - logger.exit = AsyncMock(return_value="test_result") - - result = await logger._exit() - - logger.exit.assert_called_once() - assert result == "test_result" diff --git a/tests/test_task_manager.py b/tests/test_task_manager.py index b45f5a4..af04e16 100644 --- a/tests/test_task_manager.py +++ b/tests/test_task_manager.py @@ -21,6 +21,10 @@ def __init__(self): self._task = None self.name = "MockReader" + @property + def x_title(self) -> str: + return "mock-reader" + async def boot(self): pass @@ -44,6 +48,10 @@ def __init__(self): self._task = None self.name = "MockWriter" + @property + def x_title(self) -> str: + return "mock-writer" + async def boot(self): pass @@ -181,10 +189,10 @@ async def mock_call(tg): manager.root_tg = MagicMock() manager.sub_tg = MagicMock() - # Mock logger properly - manager.logger = MagicMock() - manager.logger.ready = TaskEvent() - +manager.logger.ready # Set it + # # Mock logger properly + # manager.logger = MagicMock() + # manager.logger.ready = TaskEvent() + # +manager.logger.ready # Set it await manager.start_system()
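Usage sketch for the new utilities introduced above (src/palabra_ai/util/dt.py and src/palabra_ai/util/fileio.py). The function signatures are taken directly from this diff; the output directory and file names below are illustrative only and not part of the change.

    from pathlib import Path

    from palabra_ai.util.dt import get_now_strftime
    from palabra_ai.util.fileio import save_json, save_text

    out_dir = Path("./bench_out")  # hypothetical output directory
    out_dir.mkdir(parents=True, exist_ok=True)

    ts = get_now_strftime()  # UTC timestamp, e.g. "20251003T151043"

    # Both helpers tolerate path=None: they log a warning and skip the write,
    # so call sites don't need to guard against a missing output path.
    save_json(out_dir / f"{ts}_sysinfo.json", {"command": "benchmark ..."}, indent=True)
    save_text(out_dir / f"{ts}_report.txt", "report body")
    save_json(None, {"ignored": True})  # no-op, only warns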
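The rewritten tests imply a simpler benchmark flow: file output is driven by Config.output_dir (which also derives the log/trace paths), languages are read through the new config.source_lang / config.target_lang properties, and Report.parse(io_data) now returns a single report exposing report_txt and save_all() in place of the removed format_report / save_benchmark_files helpers. A sketch of that flow, assuming the constructor arguments shown in the tests; the audio path and language pair are placeholders, and FileReader's exact signature is an assumption.

    from pathlib import Path

    from palabra_ai import Config, PalabraAI, SourceLang, TargetLang
    from palabra_ai.benchmark.report import Report
    from palabra_ai.lang import Language
    from palabra_ai.task.adapter.dummy import DummyWriter
    from palabra_ai.task.adapter.file import FileReader

    config = Config(
        source=SourceLang(Language.get_or_create("en"), FileReader("speech.wav")),
        targets=[TargetLang(Language.get_or_create("es"), DummyWriter())],
        benchmark=True,
        output_dir=Path("./bench_out"),  # created if missing; log/trace files land here
    )
    print(f"{config.source_lang} -> {config.target_lang}")  # "en -> es" via the new properties

    result = PalabraAI().run(config, no_raise=True)
    if result.ok and result.io_data:
        report = Report.parse(result.io_data)  # single return value, no audio-canvas tuple
        print(report.report_txt)               # replaces format_report(...)
        report.save_all()                      # replaces save_benchmark_files(...)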
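Reader and writer adapters now expose an x_title property, and IoData carries it as reader_x_title / writer_x_title (see the updated mocks in tests/test_task_adapter_base.py and tests/test_task_manager.py). A minimal sketch of that contract; the class below is hypothetical and omits the real adapter base class and the rest of its interface.

    class NullWriter:  # stand-in for a palabra_ai writer adapter subclass
        @property
        def x_title(self) -> str:
            # Recorded as IoData.writer_x_title and shown in the benchmark report,
            # mirroring titles like "DummyWriter()" used in the tests above.
            return "NullWriter()"

        async def write(self, frame) -> None:
            pass  # discard audio frames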