From f0ae4432701ef368344c438e77a835a4b0ea9526 Mon Sep 17 00:00:00 2001 From: AkhileshNegi Date: Sat, 14 Feb 2026 18:25:10 +0530 Subject: [PATCH 1/7] first stab at TTS evaluation --- .../versions/047_add_tts_evaluation_tables.py | 157 +++++ .../api/docs/tts_evaluation/create_dataset.md | 4 + .../api/docs/tts_evaluation/get_dataset.md | 3 + .../app/api/docs/tts_evaluation/get_result.md | 3 + .../app/api/docs/tts_evaluation/get_run.md | 6 + .../api/docs/tts_evaluation/list_datasets.md | 3 + .../app/api/docs/tts_evaluation/list_runs.md | 3 + .../docs/tts_evaluation/start_evaluation.md | 8 + .../docs/tts_evaluation/update_feedback.md | 5 + .../app/api/routes/evaluations/__init__.py | 2 + .../api/routes/tts_evaluations/__init__.py | 1 + .../app/api/routes/tts_evaluations/dataset.py | 144 +++++ .../api/routes/tts_evaluations/evaluation.py | 298 ++++++++++ .../app/api/routes/tts_evaluations/result.py | 128 ++++ .../app/api/routes/tts_evaluations/router.py | 12 + backend/app/core/batch/__init__.py | 8 +- backend/app/core/batch/gemini.py | 69 +++ backend/app/crud/evaluations/cron.py | 26 +- backend/app/crud/tts_evaluations/__init__.py | 42 ++ backend/app/crud/tts_evaluations/batch.py | 174 ++++++ backend/app/crud/tts_evaluations/cron.py | 546 ++++++++++++++++++ backend/app/crud/tts_evaluations/dataset.py | 173 ++++++ backend/app/crud/tts_evaluations/result.py | 328 +++++++++++ backend/app/crud/tts_evaluations/run.py | 230 ++++++++ backend/app/models/tts_evaluation.py | 258 +++++++++ .../app/services/tts_evaluations/__init__.py | 1 + backend/app/services/tts_evaluations/audio.py | 57 ++ .../app/services/tts_evaluations/constants.py | 10 + .../app/services/tts_evaluations/dataset.py | 180 ++++++ 29 files changed, 2873 insertions(+), 6 deletions(-) create mode 100644 backend/app/alembic/versions/047_add_tts_evaluation_tables.py create mode 100644 backend/app/api/docs/tts_evaluation/create_dataset.md create mode 100644 backend/app/api/docs/tts_evaluation/get_dataset.md create 
mode 100644 backend/app/api/docs/tts_evaluation/get_result.md create mode 100644 backend/app/api/docs/tts_evaluation/get_run.md create mode 100644 backend/app/api/docs/tts_evaluation/list_datasets.md create mode 100644 backend/app/api/docs/tts_evaluation/list_runs.md create mode 100644 backend/app/api/docs/tts_evaluation/start_evaluation.md create mode 100644 backend/app/api/docs/tts_evaluation/update_feedback.md create mode 100644 backend/app/api/routes/tts_evaluations/__init__.py create mode 100644 backend/app/api/routes/tts_evaluations/dataset.py create mode 100644 backend/app/api/routes/tts_evaluations/evaluation.py create mode 100644 backend/app/api/routes/tts_evaluations/result.py create mode 100644 backend/app/api/routes/tts_evaluations/router.py create mode 100644 backend/app/crud/tts_evaluations/__init__.py create mode 100644 backend/app/crud/tts_evaluations/batch.py create mode 100644 backend/app/crud/tts_evaluations/cron.py create mode 100644 backend/app/crud/tts_evaluations/dataset.py create mode 100644 backend/app/crud/tts_evaluations/result.py create mode 100644 backend/app/crud/tts_evaluations/run.py create mode 100644 backend/app/models/tts_evaluation.py create mode 100644 backend/app/services/tts_evaluations/__init__.py create mode 100644 backend/app/services/tts_evaluations/audio.py create mode 100644 backend/app/services/tts_evaluations/constants.py create mode 100644 backend/app/services/tts_evaluations/dataset.py diff --git a/backend/app/alembic/versions/047_add_tts_evaluation_tables.py b/backend/app/alembic/versions/047_add_tts_evaluation_tables.py new file mode 100644 index 00000000..3a970b00 --- /dev/null +++ b/backend/app/alembic/versions/047_add_tts_evaluation_tables.py @@ -0,0 +1,157 @@ +"""add tts evaluation tables + +Revision ID: 047 +Revises: 046 +Create Date: 2026-02-14 12:00:00.000000 + +""" + +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. 
+revision = "047" +down_revision = "046" +branch_labels = None +depends_on = None + + +def upgrade(): + # Create tts_result table + op.create_table( + "tts_result", + sa.Column( + "id", + sa.Integer(), + nullable=False, + comment="Unique identifier for the TTS result", + ), + sa.Column( + "sample_text", + sa.Text(), + nullable=False, + comment="Input text that was synthesized to speech", + ), + sa.Column( + "object_store_url", + sa.String(), + nullable=True, + comment="S3 URL of the generated WAV audio file", + ), + sa.Column( + "metadata", + postgresql.JSONB(astext_type=sa.Text()), + nullable=True, + comment="Audio metadata: {duration_seconds, size_bytes}", + ), + sa.Column( + "provider", + sa.String(length=100), + nullable=False, + comment="TTS provider used (e.g., gemini-2.5-pro-preview-tts)", + ), + sa.Column( + "status", + sa.String(length=20), + nullable=False, + server_default="PENDING", + comment="Result status: PENDING, SUCCESS, FAILED", + ), + sa.Column( + "score", + postgresql.JSONB(astext_type=sa.Text()), + nullable=True, + comment="Extensible evaluation metrics (null in Phase 1)", + ), + sa.Column( + "is_correct", + sa.Boolean(), + nullable=True, + comment="Human feedback: audio quality correctness (null=not reviewed)", + ), + sa.Column( + "comment", + sa.Text(), + nullable=True, + comment="Human feedback comment", + ), + sa.Column( + "error_message", + sa.Text(), + nullable=True, + comment="Error message if synthesis failed", + ), + sa.Column( + "evaluation_run_id", + sa.Integer(), + nullable=False, + comment="Reference to the evaluation run", + ), + sa.Column( + "organization_id", + sa.Integer(), + nullable=False, + comment="Reference to the organization", + ), + sa.Column( + "project_id", + sa.Integer(), + nullable=False, + comment="Reference to the project", + ), + sa.Column( + "inserted_at", + sa.DateTime(), + nullable=False, + comment="Timestamp when the result was created", + ), + sa.Column( + "updated_at", + sa.DateTime(), + nullable=False, + 
comment="Timestamp when the result was last updated", + ), + sa.ForeignKeyConstraint( + ["evaluation_run_id"], + ["evaluation_run.id"], + name="fk_tts_result_run_id", + ondelete="CASCADE", + ), + sa.ForeignKeyConstraint( + ["organization_id"], + ["organization.id"], + ondelete="CASCADE", + ), + sa.ForeignKeyConstraint( + ["project_id"], + ["project.id"], + ondelete="CASCADE", + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index( + "ix_tts_result_run_id", + "tts_result", + ["evaluation_run_id"], + unique=False, + ) + op.create_index( + "idx_tts_result_feedback", + "tts_result", + ["evaluation_run_id", "is_correct"], + unique=False, + ) + op.create_index( + "idx_tts_result_status", + "tts_result", + ["evaluation_run_id", "status"], + unique=False, + ) + + +def downgrade(): + op.drop_index("idx_tts_result_status", table_name="tts_result") + op.drop_index("idx_tts_result_feedback", table_name="tts_result") + op.drop_index("ix_tts_result_run_id", table_name="tts_result") + op.drop_table("tts_result") diff --git a/backend/app/api/docs/tts_evaluation/create_dataset.md b/backend/app/api/docs/tts_evaluation/create_dataset.md new file mode 100644 index 00000000..7e20bfc2 --- /dev/null +++ b/backend/app/api/docs/tts_evaluation/create_dataset.md @@ -0,0 +1,4 @@ +Create a new TTS evaluation dataset with text samples. + +Each sample requires: +- **text**: Text string to be synthesized into speech diff --git a/backend/app/api/docs/tts_evaluation/get_dataset.md b/backend/app/api/docs/tts_evaluation/get_dataset.md new file mode 100644 index 00000000..92531270 --- /dev/null +++ b/backend/app/api/docs/tts_evaluation/get_dataset.md @@ -0,0 +1,3 @@ +Get a TTS evaluation dataset by ID. + +Returns dataset metadata including sample count. 
diff --git a/backend/app/api/docs/tts_evaluation/get_result.md b/backend/app/api/docs/tts_evaluation/get_result.md new file mode 100644 index 00000000..0d89d90c --- /dev/null +++ b/backend/app/api/docs/tts_evaluation/get_result.md @@ -0,0 +1,3 @@ +Get a single TTS synthesis result by ID. + +Returns the result including audio URL, metadata, and human feedback status. diff --git a/backend/app/api/docs/tts_evaluation/get_run.md b/backend/app/api/docs/tts_evaluation/get_run.md new file mode 100644 index 00000000..66c141a9 --- /dev/null +++ b/backend/app/api/docs/tts_evaluation/get_run.md @@ -0,0 +1,6 @@ +Get a TTS evaluation run by ID with optional results. + +Use query parameters to control result inclusion and pagination: +- `include_results`: Include synthesis results (default: true) +- `result_limit` / `result_offset`: Paginate results +- `provider` / `status`: Filter results diff --git a/backend/app/api/docs/tts_evaluation/list_datasets.md b/backend/app/api/docs/tts_evaluation/list_datasets.md new file mode 100644 index 00000000..6fc2f6c1 --- /dev/null +++ b/backend/app/api/docs/tts_evaluation/list_datasets.md @@ -0,0 +1,3 @@ +List all TTS evaluation datasets for the current project. + +Supports pagination with `limit` and `offset` parameters. diff --git a/backend/app/api/docs/tts_evaluation/list_runs.md b/backend/app/api/docs/tts_evaluation/list_runs.md new file mode 100644 index 00000000..e0119a69 --- /dev/null +++ b/backend/app/api/docs/tts_evaluation/list_runs.md @@ -0,0 +1,3 @@ +List TTS evaluation runs for the current project. + +Supports filtering by `dataset_id` and `status`, with pagination via `limit` and `offset`. diff --git a/backend/app/api/docs/tts_evaluation/start_evaluation.md b/backend/app/api/docs/tts_evaluation/start_evaluation.md new file mode 100644 index 00000000..543ed2de --- /dev/null +++ b/backend/app/api/docs/tts_evaluation/start_evaluation.md @@ -0,0 +1,8 @@ +Start a TTS evaluation run on a dataset. + +The evaluation will: +1. 
Process each text sample through the specified TTS providers +2. Generate speech audio using Gemini Batch API +3. Store WAV audio files in S3 for human review + +**Supported providers:** gemini-2.5-pro-preview-tts diff --git a/backend/app/api/docs/tts_evaluation/update_feedback.md b/backend/app/api/docs/tts_evaluation/update_feedback.md new file mode 100644 index 00000000..7701bb52 --- /dev/null +++ b/backend/app/api/docs/tts_evaluation/update_feedback.md @@ -0,0 +1,5 @@ +Update human feedback on a TTS synthesis result. + +Fields: +- **is_correct**: Whether the synthesized audio quality is acceptable (null to clear) +- **comment**: Optional feedback comment diff --git a/backend/app/api/routes/evaluations/__init__.py b/backend/app/api/routes/evaluations/__init__.py index 14a2d87d..c0aa7175 100644 --- a/backend/app/api/routes/evaluations/__init__.py +++ b/backend/app/api/routes/evaluations/__init__.py @@ -4,9 +4,11 @@ from app.api.routes.evaluations import dataset, evaluation from app.api.routes.stt_evaluations.router import router as stt_router +from app.api.routes.tts_evaluations.router import router as tts_router router = APIRouter() router.include_router(dataset.router) router.include_router(stt_router) +router.include_router(tts_router) router.include_router(evaluation.router) diff --git a/backend/app/api/routes/tts_evaluations/__init__.py b/backend/app/api/routes/tts_evaluations/__init__.py new file mode 100644 index 00000000..78c67295 --- /dev/null +++ b/backend/app/api/routes/tts_evaluations/__init__.py @@ -0,0 +1 @@ +"""TTS Evaluation API routes.""" diff --git a/backend/app/api/routes/tts_evaluations/dataset.py b/backend/app/api/routes/tts_evaluations/dataset.py new file mode 100644 index 00000000..79b3acd4 --- /dev/null +++ b/backend/app/api/routes/tts_evaluations/dataset.py @@ -0,0 +1,144 @@ +"""TTS dataset API routes.""" + +import logging + +from fastapi import APIRouter, Body, Depends, HTTPException, Query + +from app.api.deps import AuthContextDep, 
SessionDep +from app.api.permissions import Permission, require_permission +from app.crud.language import get_language_by_id +from app.crud.tts_evaluations import ( + get_tts_dataset_by_id, + list_tts_datasets, +) +from app.models.tts_evaluation import ( + TTSDatasetCreate, + TTSDatasetPublic, +) +from app.services.tts_evaluations.dataset import upload_tts_dataset +from app.utils import APIResponse, load_description + +logger = logging.getLogger(__name__) + +router = APIRouter() + + +@router.post( + "/datasets", + response_model=APIResponse[TTSDatasetPublic], + dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))], + summary="Create TTS dataset", + description=load_description("tts_evaluation/create_dataset.md"), +) +def create_dataset( + _session: SessionDep, + auth_context: AuthContextDep, + dataset_create: TTSDatasetCreate = Body(...), +) -> APIResponse[TTSDatasetPublic]: + """Create a TTS evaluation dataset.""" + # Validate language_id if provided + if dataset_create.language_id is not None: + language = get_language_by_id( + session=_session, language_id=dataset_create.language_id + ) + if not language: + raise HTTPException( + status_code=400, detail="Invalid language_id: language not found" + ) + + dataset = upload_tts_dataset( + session=_session, + name=dataset_create.name, + samples=dataset_create.samples, + organization_id=auth_context.organization_.id, + project_id=auth_context.project_.id, + description=dataset_create.description, + language_id=dataset_create.language_id, + ) + + return APIResponse.success_response( + data=TTSDatasetPublic( + id=dataset.id, + name=dataset.name, + description=dataset.description, + type=dataset.type, + language_id=dataset.language_id, + object_store_url=dataset.object_store_url, + dataset_metadata=dataset.dataset_metadata, + organization_id=dataset.organization_id, + project_id=dataset.project_id, + inserted_at=dataset.inserted_at, + updated_at=dataset.updated_at, + ) + ) + + +@router.get( + "/datasets", 
+ response_model=APIResponse[list[TTSDatasetPublic]], + dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))], + summary="List TTS datasets", + description=load_description("tts_evaluation/list_datasets.md"), +) +def list_datasets( + _session: SessionDep, + auth_context: AuthContextDep, + limit: int = Query(50, ge=1, le=100, description="Maximum results to return"), + offset: int = Query(0, ge=0, description="Number of results to skip"), +) -> APIResponse[list[TTSDatasetPublic]]: + """List TTS evaluation datasets.""" + datasets, total = list_tts_datasets( + session=_session, + org_id=auth_context.organization_.id, + project_id=auth_context.project_.id, + limit=limit, + offset=offset, + ) + + return APIResponse.success_response( + data=datasets, + metadata={"total": total, "limit": limit, "offset": offset}, + ) + + +@router.get( + "/datasets/{dataset_id}", + response_model=APIResponse[TTSDatasetPublic], + dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))], + summary="Get TTS dataset", + description=load_description("tts_evaluation/get_dataset.md"), +) +def get_dataset( + _session: SessionDep, + auth_context: AuthContextDep, + dataset_id: int, +) -> APIResponse[TTSDatasetPublic]: + """Get a TTS evaluation dataset.""" + dataset = get_tts_dataset_by_id( + session=_session, + dataset_id=dataset_id, + org_id=auth_context.organization_.id, + project_id=auth_context.project_.id, + ) + + if not dataset: + raise HTTPException(status_code=404, detail="Dataset not found") + + return APIResponse.success_response( + data=TTSDatasetPublic( + id=dataset.id, + name=dataset.name, + description=dataset.description, + type=dataset.type, + language_id=dataset.language_id, + object_store_url=dataset.object_store_url, + dataset_metadata=dataset.dataset_metadata, + organization_id=dataset.organization_id, + project_id=dataset.project_id, + inserted_at=dataset.inserted_at, + updated_at=dataset.updated_at, + ), + metadata={ + "sample_count": 
(dataset.dataset_metadata or {}).get("sample_count", 0) + }, + ) diff --git a/backend/app/api/routes/tts_evaluations/evaluation.py b/backend/app/api/routes/tts_evaluations/evaluation.py new file mode 100644 index 00000000..d60f32b4 --- /dev/null +++ b/backend/app/api/routes/tts_evaluations/evaluation.py @@ -0,0 +1,298 @@ +"""TTS evaluation run API routes.""" + +import logging + +from fastapi import APIRouter, Body, Depends, HTTPException, Query + +from app.api.deps import AuthContextDep, SessionDep +from app.api.permissions import Permission, require_permission +from app.crud.tts_evaluations import ( + create_tts_run, + create_tts_results, + get_results_by_run_id, + get_tts_dataset_by_id, + get_tts_run_by_id, + list_tts_runs, + start_tts_evaluation_batch, + update_tts_run, +) +from app.models.tts_evaluation import ( + TTSEvaluationRunCreate, + TTSEvaluationRunPublic, + TTSEvaluationRunWithResults, + TTSSampleCreate, +) +from app.services.tts_evaluations.constants import ( + DEFAULT_STYLE_PROMPT, + DEFAULT_VOICE_NAME, +) +from app.services.tts_evaluations.dataset import parse_tts_samples_from_csv +from app.utils import APIResponse, load_description + +logger = logging.getLogger(__name__) + +router = APIRouter() + + +@router.post( + "/runs", + response_model=APIResponse[TTSEvaluationRunPublic], + dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))], + summary="Start TTS evaluation", + description=load_description("tts_evaluation/start_evaluation.md"), +) +def start_tts_evaluation( + _session: SessionDep, + auth_context: AuthContextDep, + run_create: TTSEvaluationRunCreate = Body(...), +) -> APIResponse[TTSEvaluationRunPublic]: + """Start a TTS evaluation run.""" + logger.info( + f"[start_tts_evaluation] Starting TTS evaluation | " + f"run_name: {run_create.run_name}, dataset_id: {run_create.dataset_id}, " + f"models: {run_create.models}" + ) + + # Validate dataset exists + dataset = get_tts_dataset_by_id( + session=_session, + 
dataset_id=run_create.dataset_id, + org_id=auth_context.organization_.id, + project_id=auth_context.project_.id, + ) + + if not dataset: + raise HTTPException(status_code=404, detail="Dataset not found") + + sample_count = (dataset.dataset_metadata or {}).get("sample_count", 0) + + if sample_count == 0: + raise HTTPException(status_code=400, detail="Dataset has no samples") + + # Get sample texts from the dataset CSV + sample_texts = _get_sample_texts_from_dataset( + _session, dataset, auth_context.project_.id + ) + + if not sample_texts: + raise HTTPException( + status_code=400, detail="Could not load samples from dataset" + ) + + language_id = dataset.language_id + + # Create run record + run = create_tts_run( + session=_session, + run_name=run_create.run_name, + dataset_id=run_create.dataset_id, + dataset_name=dataset.name, + org_id=auth_context.organization_.id, + project_id=auth_context.project_.id, + models=run_create.models, + language_id=language_id, + total_items=len(sample_texts) * len(run_create.models), + ) + + # Create result records for each sample text and model + results = create_tts_results( + session=_session, + sample_texts=sample_texts, + evaluation_run_id=run.id, + org_id=auth_context.organization_.id, + project_id=auth_context.project_.id, + models=run_create.models, + ) + + try: + batch_result = start_tts_evaluation_batch( + session=_session, + run=run, + results=results, + org_id=auth_context.organization_.id, + project_id=auth_context.project_.id, + ) + logger.info( + f"[start_tts_evaluation] TTS evaluation batch submitted | " + f"run_id: {run.id}, batch_jobs: {list(batch_result.get('batch_jobs', {}).keys())}" + ) + except Exception as e: + logger.error( + f"[start_tts_evaluation] Batch submission failed | " + f"run_id: {run.id}, error: {str(e)}" + ) + update_tts_run( + session=_session, + run_id=run.id, + status="failed", + error_message=str(e), + ) + raise HTTPException(status_code=500, detail=f"Batch submission failed: {e}") + + # 
Refresh run to get updated status + run = get_tts_run_by_id( + session=_session, + run_id=run.id, + org_id=auth_context.organization_.id, + project_id=auth_context.project_.id, + ) + + return APIResponse.success_response( + data=TTSEvaluationRunPublic( + id=run.id, + run_name=run.run_name, + dataset_name=run.dataset_name, + type=run.type, + language_id=run.language_id, + models=run.providers, + dataset_id=run.dataset_id, + status=run.status, + total_items=run.total_items, + score=run.score, + error_message=run.error_message, + run_metadata={ + "voice_name": DEFAULT_VOICE_NAME, + "style_prompt": DEFAULT_STYLE_PROMPT, + }, + organization_id=run.organization_id, + project_id=run.project_id, + inserted_at=run.inserted_at, + updated_at=run.updated_at, + ) + ) + + +def _get_sample_texts_from_dataset( + session: "SessionDep", + dataset: "EvaluationDataset", + project_id: int, +) -> list[str]: + """Extract sample texts from a TTS dataset's CSV in S3. + + Args: + session: Database session + dataset: The evaluation dataset record + project_id: Project ID + + Returns: + List of text strings + """ + if not dataset.object_store_url: + logger.warning( + f"[_get_sample_texts_from_dataset] No object_store_url | " + f"dataset_id={dataset.id}" + ) + return [] + + try: + from app.core.cloud.storage import get_cloud_storage + + storage = get_cloud_storage(session=session, project_id=project_id) + csv_bytes = storage.stream(dataset.object_store_url).read() + samples = parse_tts_samples_from_csv(csv_bytes) + return [s["text"] for s in samples] + except Exception as e: + logger.error( + f"[_get_sample_texts_from_dataset] Failed to load CSV | " + f"dataset_id={dataset.id}, error={str(e)}" + ) + return [] + + +@router.get( + "/runs", + response_model=APIResponse[list[TTSEvaluationRunPublic]], + dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))], + summary="List TTS evaluation runs", + description=load_description("tts_evaluation/list_runs.md"), +) +def 
list_tts_evaluation_runs( + _session: SessionDep, + auth_context: AuthContextDep, + dataset_id: int | None = Query(None, description="Filter by dataset ID"), + status: str | None = Query(None, description="Filter by status"), + limit: int = Query(50, ge=1, le=100, description="Maximum results to return"), + offset: int = Query(0, ge=0, description="Number of results to skip"), +) -> APIResponse[list[TTSEvaluationRunPublic]]: + """List TTS evaluation runs.""" + runs, total = list_tts_runs( + session=_session, + org_id=auth_context.organization_.id, + project_id=auth_context.project_.id, + dataset_id=dataset_id, + status=status, + limit=limit, + offset=offset, + ) + + return APIResponse.success_response( + data=runs, + metadata={"total": total, "limit": limit, "offset": offset}, + ) + + +@router.get( + "/runs/{run_id}", + response_model=APIResponse[TTSEvaluationRunWithResults], + dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))], + summary="Get TTS evaluation run", + description=load_description("tts_evaluation/get_run.md"), +) +def get_tts_evaluation_run( + _session: SessionDep, + auth_context: AuthContextDep, + run_id: int, + include_results: bool = Query(True, description="Include results in response"), + result_limit: int = Query(100, ge=1, le=1000, description="Max results to return"), + result_offset: int = Query(0, ge=0, description="Result offset"), + provider: str | None = Query(None, description="Filter results by provider"), + status: str | None = Query(None, description="Filter results by status"), +) -> APIResponse[TTSEvaluationRunWithResults]: + """Get a TTS evaluation run with results.""" + run = get_tts_run_by_id( + session=_session, + run_id=run_id, + org_id=auth_context.organization_.id, + project_id=auth_context.project_.id, + ) + + if not run: + raise HTTPException(status_code=404, detail="Evaluation run not found") + + results = [] + results_total = 0 + + if include_results: + results, results_total = get_results_by_run_id( + 
session=_session, + run_id=run_id, + org_id=auth_context.organization_.id, + project_id=auth_context.project_.id, + provider=provider, + status=status, + limit=result_limit, + offset=result_offset, + ) + + return APIResponse.success_response( + data=TTSEvaluationRunWithResults( + id=run.id, + run_name=run.run_name, + dataset_name=run.dataset_name, + type=run.type, + language_id=run.language_id, + models=run.providers, + dataset_id=run.dataset_id, + status=run.status, + total_items=run.total_items, + score=run.score, + error_message=run.error_message, + organization_id=run.organization_id, + project_id=run.project_id, + inserted_at=run.inserted_at, + updated_at=run.updated_at, + results=results, + results_total=results_total, + ), + metadata={"results_total": results_total}, + ) diff --git a/backend/app/api/routes/tts_evaluations/result.py b/backend/app/api/routes/tts_evaluations/result.py new file mode 100644 index 00000000..577e69a4 --- /dev/null +++ b/backend/app/api/routes/tts_evaluations/result.py @@ -0,0 +1,128 @@ +"""TTS result feedback API routes.""" + +import logging + +from fastapi import APIRouter, Body, Depends, HTTPException + +from app.api.deps import AuthContextDep, SessionDep +from app.api.permissions import Permission, require_permission +from app.crud.tts_evaluations import ( + get_tts_result_by_id, + update_tts_human_feedback, +) +from app.models.tts_evaluation import ( + TTSFeedbackUpdate, + TTSResultPublic, +) +from app.utils import APIResponse, load_description + +logger = logging.getLogger(__name__) + +router = APIRouter() + + +@router.patch( + "/results/{result_id}", + response_model=APIResponse[TTSResultPublic], + dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))], + summary="Update human feedback", + description=load_description("tts_evaluation/update_feedback.md"), +) +def update_result_feedback( + _session: SessionDep, + auth_context: AuthContextDep, + result_id: int, + feedback: TTSFeedbackUpdate = Body(...), +) -> 
APIResponse[TTSResultPublic]: + """Update human feedback on a TTS result.""" + logger.info( + f"[update_result_feedback] Updating feedback | " + f"result_id: {result_id}, is_correct: {feedback.is_correct}" + ) + + # Verify result exists and belongs to this project + existing = get_tts_result_by_id( + session=_session, + result_id=result_id, + org_id=auth_context.organization_.id, + project_id=auth_context.project_.id, + ) + + if not existing: + raise HTTPException(status_code=404, detail="Result not found") + + # Update feedback + result = update_tts_human_feedback( + session=_session, + result_id=result_id, + org_id=auth_context.organization_.id, + project_id=auth_context.project_.id, + is_correct=feedback.is_correct, + comment=feedback.comment, + ) + + return APIResponse.success_response( + data=TTSResultPublic( + id=result.id, + sample_text=result.sample_text, + object_store_url=result.object_store_url, + duration_seconds=(result.metadata_ or {}).get("duration_seconds"), + size_bytes=(result.metadata_ or {}).get("size_bytes"), + provider=result.provider, + status=result.status, + score=result.score, + is_correct=result.is_correct, + comment=result.comment, + error_message=result.error_message, + evaluation_run_id=result.evaluation_run_id, + organization_id=result.organization_id, + project_id=result.project_id, + inserted_at=result.inserted_at, + updated_at=result.updated_at, + ) + ) + + +@router.get( + "/results/{result_id}", + response_model=APIResponse[TTSResultPublic], + dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))], + summary="Get TTS result", + description=load_description("tts_evaluation/get_result.md"), +) +def get_result( + _session: SessionDep, + auth_context: AuthContextDep, + result_id: int, +) -> APIResponse[TTSResultPublic]: + """Get a TTS result by ID.""" + result = get_tts_result_by_id( + session=_session, + result_id=result_id, + org_id=auth_context.organization_.id, + project_id=auth_context.project_.id, + ) + + if not 
result: + raise HTTPException(status_code=404, detail="Result not found") + + return APIResponse.success_response( + data=TTSResultPublic( + id=result.id, + sample_text=result.sample_text, + object_store_url=result.object_store_url, + duration_seconds=(result.metadata_ or {}).get("duration_seconds"), + size_bytes=(result.metadata_ or {}).get("size_bytes"), + provider=result.provider, + status=result.status, + score=result.score, + is_correct=result.is_correct, + comment=result.comment, + error_message=result.error_message, + evaluation_run_id=result.evaluation_run_id, + organization_id=result.organization_id, + project_id=result.project_id, + inserted_at=result.inserted_at, + updated_at=result.updated_at, + ) + ) diff --git a/backend/app/api/routes/tts_evaluations/router.py b/backend/app/api/routes/tts_evaluations/router.py new file mode 100644 index 00000000..1a73ec30 --- /dev/null +++ b/backend/app/api/routes/tts_evaluations/router.py @@ -0,0 +1,12 @@ +"""Main router for TTS evaluation API routes.""" + +from fastapi import APIRouter + +from . 
import dataset, evaluation, result + +router = APIRouter(prefix="/evaluations/tts", tags=["TTS Evaluation"]) + +# Include all sub-routers +router.include_router(dataset.router) +router.include_router(evaluation.router) +router.include_router(result.router) diff --git a/backend/app/core/batch/__init__.py b/backend/app/core/batch/__init__.py index 382fcc4f..1b89c24f 100644 --- a/backend/app/core/batch/__init__.py +++ b/backend/app/core/batch/__init__.py @@ -1,7 +1,12 @@ """Batch processing infrastructure for LLM providers.""" from .base import BatchProvider -from .gemini import BatchJobState, GeminiBatchProvider, create_stt_batch_requests +from .gemini import ( + BatchJobState, + GeminiBatchProvider, + create_stt_batch_requests, + create_tts_batch_requests, +) from .openai import OpenAIBatchProvider from .operations import ( download_batch_results, @@ -17,6 +22,7 @@ "GeminiBatchProvider", "OpenAIBatchProvider", "create_stt_batch_requests", + "create_tts_batch_requests", "start_batch_job", "download_batch_results", "process_completed_batch", diff --git a/backend/app/core/batch/gemini.py b/backend/app/core/batch/gemini.py index 1dd28a95..f008a6f1 100644 --- a/backend/app/core/batch/gemini.py +++ b/backend/app/core/batch/gemini.py @@ -441,3 +441,72 @@ def create_stt_batch_requests( logger.info(f"[create_stt_batch_requests] Created {len(requests)} batch requests") return requests + + +def create_tts_batch_requests( + texts: list[str], + voice_name: str, + style_prompt: str | None = None, + keys: list[str] | None = None, +) -> list[dict[str, Any]]: + """Create batch API requests for Gemini TTS. + + Generates request payloads in Gemini's JSONL batch format for + text-to-speech synthesis with audio output. + + Args: + texts: List of text strings to synthesize + voice_name: Prebuilt voice name (e.g., "Kore") + style_prompt: Optional style/tone instructions prepended to text + keys: Optional list of custom IDs for tracking results. 
If not provided, + uses 0-indexed integers as strings. + + Returns: + List of batch request dicts in Gemini JSONL format + + Example: + >>> texts = ["Hello, how can I help you?"] + >>> requests = create_tts_batch_requests( + ... texts, voice_name="Kore", + ... style_prompt="Read in a calm tone", + ... keys=["result-1"] + ... ) + """ + if keys is not None and len(keys) != len(texts): + raise ValueError( + f"Length of keys ({len(keys)}) must match texts ({len(texts)})" + ) + + requests = [] + for i, text in enumerate(texts): + key = keys[i] if keys is not None else str(i) + + # Prepend style prompt if provided + content_text = f"{style_prompt}: {text}" if style_prompt else text + + request = { + "key": key, + "request": { + "contents": [ + { + "parts": [{"text": content_text}], + "role": "user", + } + ], + "generationConfig": { + "responseModalities": ["AUDIO"], + "speechConfig": { + "voiceConfig": { + "prebuiltVoiceConfig": { + "voiceName": voice_name, + } + } + }, + }, + }, + } + requests.append(request) + + logger.info(f"[create_tts_batch_requests] Created {len(requests)} batch requests") + + return requests diff --git a/backend/app/crud/evaluations/cron.py b/backend/app/crud/evaluations/cron.py index e8393736..309fbe93 100644 --- a/backend/app/crud/evaluations/cron.py +++ b/backend/app/crud/evaluations/cron.py @@ -14,6 +14,7 @@ from app.crud.evaluations.processing import poll_all_pending_evaluations from app.crud.stt_evaluations import poll_all_pending_stt_evaluations +from app.crud.tts_evaluations import poll_all_pending_tts_evaluations logger = logging.getLogger(__name__) @@ -24,7 +25,7 @@ async def process_all_pending_evaluations(session: Session) -> dict[str, Any]: Delegates to poll_all_pending_evaluations which fetches all processing evaluation runs in a single query, groups by project, and processes them. - Also polls STT evaluations similarly. + Also polls STT and TTS evaluations similarly. 
Args: session: Database session @@ -41,13 +42,28 @@ async def process_all_pending_evaluations(session: Session) -> dict[str, Any]: # Poll STT evaluations (single query, grouped by project) stt_summary = await poll_all_pending_stt_evaluations(session=session) + # Poll TTS evaluations (single query, grouped by project) + tts_summary = await poll_all_pending_tts_evaluations(session=session) + # Merge summaries - total_processed = text_summary["processed"] + stt_summary["processed"] - total_failed = text_summary["failed"] + stt_summary["failed"] + total_processed = ( + text_summary["processed"] + + stt_summary["processed"] + + tts_summary["processed"] + ) + total_failed = ( + text_summary["failed"] + stt_summary["failed"] + tts_summary["failed"] + ) total_still_processing = ( - text_summary["still_processing"] + stt_summary["still_processing"] + text_summary["still_processing"] + + stt_summary["still_processing"] + + tts_summary["still_processing"] + ) + all_details = ( + text_summary.get("details", []) + + stt_summary.get("details", []) + + tts_summary.get("details", []) ) - all_details = text_summary.get("details", []) + stt_summary.get("details", []) logger.info( f"[process_all_pending_evaluations] Completed: " diff --git a/backend/app/crud/tts_evaluations/__init__.py b/backend/app/crud/tts_evaluations/__init__.py new file mode 100644 index 00000000..fcced112 --- /dev/null +++ b/backend/app/crud/tts_evaluations/__init__.py @@ -0,0 +1,42 @@ +"""TTS Evaluation CRUD operations.""" + +from .batch import start_tts_evaluation_batch +from .cron import poll_all_pending_tts_evaluations +from .dataset import ( + create_tts_dataset, + get_tts_dataset_by_id, + list_tts_datasets, +) +from .run import ( + create_tts_run, + get_tts_run_by_id, + list_tts_runs, + update_tts_run, +) +from .result import ( + create_tts_results, + get_tts_result_by_id, + get_results_by_run_id, + update_tts_human_feedback, +) + +__all__ = [ + # Batch + "start_tts_evaluation_batch", + # Cron + 
def start_tts_evaluation_batch(
    *,
    session: Session,
    run: EvaluationRun,
    results: list[TTSResult],
    org_id: int,
    project_id: int,
) -> dict[str, Any]:
    """Submit Gemini batch jobs for TTS evaluation.

    Submits one batch job per model. Each batch job is tracked via its
    config containing evaluation_run_id and tts_provider. If submission
    fails for one model, that model's pending results are marked failed;
    the call only raises when every model fails.

    Args:
        session: Database session
        run: The evaluation run record
        results: List of TTSResult records (contains sample_text)
        org_id: Organization ID
        project_id: Project ID

    Returns:
        dict: Result with batch job information per model

    Raises:
        RuntimeError: If batch submission fails for all models
    """
    # Group results by their assigned provider/model first.
    results_by_model: dict[str, list[TTSResult]] = {}
    for result in results:
        results_by_model.setdefault(result.provider, []).append(result)

    # BUGFIX: previously only run.providers was consulted. Any result whose
    # provider was not in that list was silently never submitted and stayed
    # PENDING forever, which permanently stalls run finalization (the run
    # only completes when no results are pending). Submit for the union of
    # configured models and providers actually present on the results.
    configured_models = run.providers or [DEFAULT_TTS_MODEL]
    models = sorted(set(configured_models) | set(results_by_model))

    logger.info(
        f"[start_tts_evaluation_batch] Starting batch submission | "
        f"run_id: {run.id}, result_count: {len(results)}, "
        f"models: {models}"
    )

    # Initialize Gemini client (credentials are per org/project)
    gemini_client = GeminiClient.from_credentials(
        session=session,
        org_id=org_id,
        project_id=project_id,
    )

    batch_jobs: dict[str, Any] = {}
    first_batch_job_id: int | None = None

    for model in models:
        model_results = results_by_model.get(model, [])
        if not model_results:
            # Configured model with no assigned results — nothing to submit.
            continue

        texts = [r.sample_text for r in model_results]
        # Result IDs double as batch custom keys for response correlation.
        keys = [str(r.id) for r in model_results]

        jsonl_data = create_tts_batch_requests(
            texts=texts,
            voice_name=DEFAULT_VOICE_NAME,
            style_prompt=DEFAULT_STYLE_PROMPT,
            keys=keys,
        )

        batch_provider = GeminiBatchProvider(
            client=gemini_client.client, model=f"models/{model}"
        )

        try:
            batch_job = start_batch_job(
                session=session,
                provider=batch_provider,
                provider_name="google",
                job_type="tts_evaluation",
                organization_id=org_id,
                project_id=project_id,
                jsonl_data=jsonl_data,
                config={
                    "model": model,
                    "tts_provider": model,
                    "evaluation_run_id": run.id,
                    "voice_name": DEFAULT_VOICE_NAME,
                    "style_prompt": DEFAULT_STYLE_PROMPT,
                },
            )
        except Exception as e:
            logger.error(
                f"[start_tts_evaluation_batch] Failed to submit batch | "
                f"model: {model}, error: {str(e)}"
            )
            # Fail this model's pending results so the run can still finalize.
            pending = get_pending_results_for_run(
                session=session, run_id=run.id, provider=model
            )
            for result in pending:
                update_tts_result(
                    session=session,
                    result_id=result.id,
                    status=JobStatus.FAILED.value,
                    error_message=f"Batch submission failed for {model}: {str(e)}",
                )
            session.commit()
            continue

        batch_jobs[model] = {
            "batch_job_id": batch_job.id,
            "provider_batch_id": batch_job.provider_batch_id,
        }
        if first_batch_job_id is None:
            first_batch_job_id = batch_job.id

        logger.info(
            f"[start_tts_evaluation_batch] Batch job created | "
            f"run_id: {run.id}, model: {model}, "
            f"batch_job_id: {batch_job.id}"
        )

    if not batch_jobs:
        # RuntimeError is an Exception subclass, so existing callers that
        # catch Exception are unaffected.
        raise RuntimeError("Batch submission failed for all models")

    # Link first batch job to the evaluation run (for pending run detection)
    update_tts_run(
        session=session,
        run_id=run.id,
        status="processing",
        batch_job_id=first_batch_job_id,
    )

    logger.info(
        f"[start_tts_evaluation_batch] Batch submission complete | "
        f"run_id: {run.id}, models_submitted: {list(batch_jobs.keys())}, "
        f"result_count: {len(results)}"
    )

    return {
        "success": True,
        "run_id": run.id,
        "batch_jobs": batch_jobs,
        "result_count": len(results),
    }
+""" + +import base64 +import logging +import uuid +from collections import defaultdict +from typing import Any + +from sqlalchemy import Integer +from sqlmodel import Session, select + +from app.core.batch import BatchJobState, GeminiBatchProvider, poll_batch_status +from app.core.cloud.storage import get_cloud_storage +from app.core.storage_utils import upload_to_object_store +from app.crud.tts_evaluations.result import count_results_by_status, update_tts_result +from app.crud.tts_evaluations.run import update_tts_run +from app.models import EvaluationRun +from app.models.batch_job import BatchJob +from app.models.job import JobStatus +from app.models.stt_evaluation import EvaluationType +from app.models.tts_evaluation import TTSResult +from app.services.stt_evaluations.gemini import GeminiClient +from app.services.tts_evaluations.audio import calculate_duration, pcm_to_wav + +logger = logging.getLogger(__name__) + +# Terminal states that indicate batch processing is complete +TERMINAL_STATES = { + BatchJobState.SUCCEEDED.value, + BatchJobState.FAILED.value, + BatchJobState.CANCELLED.value, + BatchJobState.EXPIRED.value, +} + + +async def poll_all_pending_tts_evaluations( + session: Session, +) -> dict[str, Any]: + """Poll all pending TTS evaluations across all organizations. + + Fetches all TTS evaluation runs with status='processing' in a single query, + groups them by project_id, and processes each project with its own + Gemini client. 
async def poll_all_pending_tts_evaluations(
    session: Session,
) -> dict[str, Any]:
    """Poll all pending TTS evaluations across all organizations.

    Fetches all TTS evaluation runs with status='processing' in a single
    query, groups them by project_id (Gemini credentials are per project),
    and polls each project's runs with its own client.

    Args:
        session: Database session

    Returns:
        Summary dict with total, processed, failed, still_processing counts
        and per-run details.
    """
    logger.info("[poll_all_pending_tts_evaluations] Starting TTS evaluation polling")

    # Single query to fetch all processing TTS evaluation runs
    statement = select(EvaluationRun).where(
        EvaluationRun.type == EvaluationType.TTS.value,
        EvaluationRun.status == "processing",
        EvaluationRun.batch_job_id.is_not(None),
    )
    pending_runs = session.exec(statement).all()

    if not pending_runs:
        logger.info("[poll_all_pending_tts_evaluations] No pending TTS runs found")
        return {
            "total": 0,
            "processed": 0,
            "failed": 0,
            "still_processing": 0,
            "details": [],
        }

    logger.info(
        f"[poll_all_pending_tts_evaluations] Found {len(pending_runs)} pending TTS runs"
    )

    # Group evaluations by project_id since credentials are per project
    evaluations_by_project: dict[int, list[EvaluationRun]] = defaultdict(list)
    for run in pending_runs:
        evaluations_by_project[run.project_id].append(run)

    all_results: list[dict[str, Any]] = []
    total_processed = 0
    total_failed = 0
    total_still_processing = 0

    def _fail_run(run: EvaluationRun, run_error: str, detail_error: str) -> None:
        # Single place for the fail-a-run bookkeeping that was previously
        # duplicated (with slight drift) across three exception handlers.
        nonlocal total_failed
        update_tts_run(
            session=session,
            run_id=run.id,
            status="failed",
            error_message=run_error,
        )
        all_results.append(
            {
                "run_id": run.id,
                "run_name": run.run_name,
                "type": "tts",
                "action": "failed",
                "error": detail_error,
            }
        )
        total_failed += 1

    for project_id, project_runs in evaluations_by_project.items():
        org_id = project_runs[0].organization_id

        try:
            try:
                gemini_client = GeminiClient.from_credentials(
                    session=session,
                    org_id=org_id,
                    project_id=project_id,
                )
            except Exception as client_err:
                logger.error(
                    f"[poll_all_pending_tts_evaluations] Failed to get Gemini client | "
                    f"org_id={org_id} | project_id={project_id} | error={client_err}"
                )
                for run in project_runs:
                    _fail_run(
                        run,
                        f"Gemini client initialization failed: {str(client_err)}",
                        str(client_err),
                    )
                continue

            batch_provider = GeminiBatchProvider(client=gemini_client.client)

            for run in project_runs:
                try:
                    result = await poll_tts_run(
                        session=session,
                        run=run,
                        batch_provider=batch_provider,
                        org_id=org_id,
                    )
                    all_results.append(result)

                    if result["action"] in ("completed", "processed"):
                        total_processed += 1
                    elif result["action"] == "failed":
                        total_failed += 1
                    else:
                        total_still_processing += 1

                except Exception as e:
                    logger.error(
                        f"[poll_all_pending_tts_evaluations] Failed to poll TTS run | "
                        f"run_id={run.id} | {e}",
                        exc_info=True,
                    )
                    _fail_run(run, f"Polling failed: {str(e)}", str(e))

        except Exception as e:
            # NOTE(review): per-run errors are caught in the inner loop, so
            # this handler effectively only fires before any run of the
            # project has been polled (e.g. provider construction).
            logger.error(
                f"[poll_all_pending_tts_evaluations] Failed to process project | "
                f"project_id={project_id} | {e}",
                exc_info=True,
            )
            for run in project_runs:
                _fail_run(
                    run,
                    f"Project processing failed: {str(e)}",
                    f"Project processing failed: {str(e)}",
                )

    summary = {
        "total": len(pending_runs),
        "processed": total_processed,
        "failed": total_failed,
        "still_processing": total_still_processing,
        "details": all_results,
    }

    logger.info(
        f"[poll_all_pending_tts_evaluations] Polling summary | "
        f"processed={total_processed} | failed={total_failed} | "
        f"still_processing={total_still_processing}"
    )

    return summary
def _get_batch_jobs_for_run(
    session: Session,
    run: EvaluationRun,
) -> list[BatchJob]:
    """Find all batch jobs associated with a TTS evaluation run.

    Jobs are correlated through config["evaluation_run_id"] rather than a
    foreign key, since one run fans out to one batch job per provider.

    Args:
        session: Database session
        run: The evaluation run

    Returns:
        list[BatchJob]: All batch jobs for this run
    """
    stmt = select(BatchJob).where(
        BatchJob.job_type == "tts_evaluation",
        BatchJob.config["evaluation_run_id"].astext.cast(Integer) == run.id,
    )
    return list(session.exec(stmt).all())


async def poll_tts_run(
    session: Session,
    run: EvaluationRun,
    batch_provider: GeminiBatchProvider,
    org_id: int,
) -> dict[str, Any]:
    """Poll a single TTS evaluation run's batch status.

    Finds all batch jobs for this run (one per provider) and polls each.
    The run is finalized only when all batch jobs are in terminal states.

    BUGFIX: once every batch job is terminal, any result rows still PENDING
    (e.g. their custom_id never appeared in the batch output) can never be
    filled in. Previously the run was left in "processing" and re-polled
    forever without making progress; now such results are failed explicitly
    and the run is finalized. The dead ``any_succeeded`` flag was removed.

    Args:
        session: Database session
        run: The evaluation run to poll
        batch_provider: Initialized GeminiBatchProvider
        org_id: Organization ID

    Returns:
        dict: Status result with run details and action taken
    """
    log_prefix = f"[org={org_id}][project={run.project_id}][eval={run.id}]"
    logger.info(f"[poll_tts_run] {log_prefix} Polling run")

    previous_status = run.status

    batch_jobs = _get_batch_jobs_for_run(session=session, run=run)

    if not batch_jobs:
        logger.warning(f"[poll_tts_run] {log_prefix} No batch jobs found")
        update_tts_run(
            session=session,
            run_id=run.id,
            status="failed",
            error_message="No batch jobs found",
        )
        return {
            "run_id": run.id,
            "run_name": run.run_name,
            "type": "tts",
            "previous_status": previous_status,
            "current_status": "failed",
            "action": "failed",
            "error": "No batch jobs found",
        }

    all_terminal = True
    any_failed = False
    errors: list[str] = []

    for batch_job in batch_jobs:
        provider_name = batch_job.config.get("tts_provider", "unknown")

        # Batch jobs already terminal were processed on a previous poll.
        if batch_job.provider_status in TERMINAL_STATES:
            if batch_job.provider_status != BatchJobState.SUCCEEDED.value:
                any_failed = True
                errors.append(
                    f"{provider_name}: "
                    f"{batch_job.error_message or batch_job.provider_status}"
                )
            continue

        poll_batch_status(
            session=session,
            provider=batch_provider,
            batch_job=batch_job,
        )
        session.refresh(batch_job)
        provider_status = batch_job.provider_status

        logger.info(
            f"[poll_tts_run] {log_prefix} Batch status | "
            f"batch_job_id={batch_job.id} | provider={provider_name} | "
            f"state={provider_status}"
        )

        if provider_status not in TERMINAL_STATES:
            all_terminal = False
            continue

        # Batch reached terminal state - process it
        if provider_status == BatchJobState.SUCCEEDED.value:
            await process_completed_tts_batch(
                session=session,
                run=run,
                batch_job=batch_job,
                batch_provider=batch_provider,
            )
        else:
            any_failed = True
            errors.append(
                f"{provider_name}: {batch_job.error_message or provider_status}"
            )

    if not all_terminal:
        return {
            "run_id": run.id,
            "run_name": run.run_name,
            "type": "tts",
            "previous_status": previous_status,
            "current_status": run.status,
            "action": "no_change",
        }

    # All batch jobs are terminal. Results still PENDING can never complete
    # (their batch output is gone), so fail them instead of leaving the run
    # stuck in "processing" and re-polled forever.
    stale_stmt = select(TTSResult).where(
        TTSResult.evaluation_run_id == run.id,
        TTSResult.status == JobStatus.PENDING.value,
    )
    stale_results = list(session.exec(stale_stmt).all())
    for stale in stale_results:
        update_tts_result(
            session=session,
            result_id=stale.id,
            status=JobStatus.FAILED.value,
            error_message="No output returned by batch job",
        )
    if stale_results:
        session.commit()

    status_counts = count_results_by_status(session=session, run_id=run.id)
    failed_count = status_counts.get(JobStatus.FAILED.value, 0)

    error_message = None
    if any_failed:
        error_message = "; ".join(errors)
    elif failed_count > 0:
        error_message = f"{failed_count} synthesis(es) failed"

    update_tts_run(
        session=session,
        run_id=run.id,
        status="completed",
        error_message=error_message,
    )

    action = "failed" if any_failed else "completed"

    return {
        "run_id": run.id,
        "run_name": run.run_name,
        "type": "tts",
        "previous_status": previous_status,
        "current_status": "completed",
        "action": action,
        **({"error": error_message} if error_message else {}),
    }
async def process_completed_tts_batch(
    session: Session,
    run: EvaluationRun,
    batch_job: BatchJob,
    batch_provider: GeminiBatchProvider,
) -> None:
    """Process completed Gemini batch - download results, convert audio, upload.

    For each result:
    1. Download JSONL results from Gemini
    2. Extract base64-encoded PCM audio from response
    3. Decode base64 -> raw PCM bytes
    4. Wrap in WAV container (24kHz, 16-bit, mono)
    5. Upload WAV to object storage
    6. Update TTSResult with object_store_url and metadata

    Args:
        session: Database session
        run: The evaluation run
        batch_job: The BatchJob record (was annotated ``Any``)
        batch_provider: Initialized GeminiBatchProvider
    """
    logger.info(
        f"[process_completed_tts_batch] Processing batch results | "
        f"run_id={run.id}, batch_job_id={batch_job.id}"
    )

    tts_provider = batch_job.config.get("tts_provider", "gemini-2.5-pro-preview-tts")

    # Get cloud storage for object-store uploads
    storage = get_cloud_storage(session=session, project_id=run.project_id)

    processed_count = 0
    failed_count = 0

    try:
        results = batch_provider.download_batch_results(batch_job.provider_batch_id)

        logger.info(
            f"[process_completed_tts_batch] Got batch results | "
            f"batch_job_id={batch_job.id}, result_count={len(results)}"
        )

        for batch_result in results:
            # custom_id carries the TTSResult primary key (set at submission)
            custom_id = batch_result["custom_id"]
            try:
                result_id = int(custom_id)
            except (ValueError, TypeError):
                logger.warning(
                    f"[process_completed_tts_batch] Invalid custom_id | "
                    f"batch_job_id={batch_job.id}, custom_id={custom_id}"
                )
                failed_count += 1
                continue

            stmt = select(TTSResult).where(
                TTSResult.id == result_id,
                TTSResult.evaluation_run_id == run.id,
                TTSResult.provider == tts_provider,
            )
            result_record = session.exec(stmt).one_or_none()

            if not result_record:
                logger.warning(
                    f"[process_completed_tts_batch] Result record not found | "
                    f"result_id={result_id}"
                )
                failed_count += 1
                continue

            if not batch_result.get("response"):
                # No response payload: record the provider-side error.
                # str() guards against non-string error payloads (e.g. dicts).
                update_tts_result(
                    session=session,
                    result_id=result_record.id,
                    status=JobStatus.FAILED.value,
                    error_message=str(batch_result.get("error") or "Unknown error"),
                )
                failed_count += 1
                continue

            try:
                # Extract base64 audio from Gemini TTS response
                audio_b64 = _extract_audio_from_response(batch_result["response"])

                if not audio_b64:
                    update_tts_result(
                        session=session,
                        result_id=result_record.id,
                        status=JobStatus.FAILED.value,
                        error_message="No audio data in response",
                    )
                    failed_count += 1
                    continue

                # Decode base64 -> raw PCM bytes, then wrap in WAV container
                pcm_data = base64.b64decode(audio_b64)
                wav_data = pcm_to_wav(pcm_data)
                duration = calculate_duration(len(pcm_data))

                audio_filename = f"{uuid.uuid4()}.wav"
                # Plain string literal: the previous f-string had no
                # placeholders.
                audio_url = upload_to_object_store(
                    storage=storage,
                    content=wav_data,
                    filename=audio_filename,
                    subdirectory="evaluations/tts/audio",
                    content_type="audio/wav",
                )

                update_tts_result(
                    session=session,
                    result_id=result_record.id,
                    object_store_url=audio_url,
                    metadata={
                        "duration_seconds": round(duration, 3),
                        "size_bytes": len(wav_data),
                    },
                    status=JobStatus.SUCCESS.value,
                )
                processed_count += 1

            except Exception as audio_err:
                logger.error(
                    f"[process_completed_tts_batch] Audio processing failed | "
                    f"result_id={result_id}, error={str(audio_err)}"
                )
                update_tts_result(
                    session=session,
                    result_id=result_record.id,
                    status=JobStatus.FAILED.value,
                    error_message=f"Audio processing failed: {str(audio_err)}",
                )
                failed_count += 1

        session.commit()

    except Exception as e:
        logger.error(
            f"[process_completed_tts_batch] Failed to process batch results | "
            f"batch_job_id={batch_job.id}, error={str(e)}",
            exc_info=True,
        )
        raise

    logger.info(
        f"[process_completed_tts_batch] Batch results processed | "
        f"run_id={run.id}, provider={tts_provider}, "
        f"processed={processed_count}, failed={failed_count}"
    )
_extract_audio_from_response(response: dict[str, Any]) -> str | None: + """Extract base64-encoded audio data from a Gemini TTS response. + + Gemini TTS returns audio as base64-encoded PCM data in the + inlineData field of the response parts. + + Args: + response: Gemini response dictionary + + Returns: + Base64 encoded audio string, or None if not found + """ + # Navigate: candidates -> content -> parts -> inlineData -> data + for candidate in response.get("candidates", []): + content = candidate.get("content", {}) + for part in content.get("parts", []): + inline_data = part.get("inlineData", {}) + if inline_data.get("data"): + return inline_data["data"] + return None diff --git a/backend/app/crud/tts_evaluations/dataset.py b/backend/app/crud/tts_evaluations/dataset.py new file mode 100644 index 00000000..a2352073 --- /dev/null +++ b/backend/app/crud/tts_evaluations/dataset.py @@ -0,0 +1,173 @@ +"""CRUD operations for TTS evaluation datasets.""" + +import logging +from typing import Any + +from sqlalchemy.exc import IntegrityError +from sqlmodel import Session, func, select + +from app.core.exception_handlers import HTTPException +from app.core.util import now +from app.models import EvaluationDataset +from app.models.stt_evaluation import EvaluationType +from app.models.tts_evaluation import TTSDatasetPublic + +logger = logging.getLogger(__name__) + + +def create_tts_dataset( + *, + session: Session, + name: str, + org_id: int, + project_id: int, + description: str | None = None, + language_id: int | None = None, + object_store_url: str | None = None, + dataset_metadata: dict[str, Any] | None = None, +) -> EvaluationDataset: + """Create a new TTS evaluation dataset. 
+ + Args: + session: Database session + name: Dataset name + org_id: Organization ID + project_id: Project ID + description: Optional description + language_id: Optional reference to global.languages table + object_store_url: Optional object store URL + dataset_metadata: Optional metadata dict + + Returns: + EvaluationDataset: Created dataset + + Raises: + HTTPException: If dataset with same name already exists + """ + logger.info( + f"[create_tts_dataset] Creating TTS dataset | " + f"name: {name}, org_id: {org_id}, project_id: {project_id}" + ) + + dataset = EvaluationDataset( + name=name, + description=description, + type=EvaluationType.TTS.value, + language_id=language_id, + object_store_url=object_store_url, + dataset_metadata=dataset_metadata or {}, + organization_id=org_id, + project_id=project_id, + inserted_at=now(), + updated_at=now(), + ) + + try: + session.add(dataset) + session.flush() + + logger.info( + f"[create_tts_dataset] TTS dataset created | " + f"dataset_id: {dataset.id}, name: {name}" + ) + + return dataset + + except IntegrityError as e: + session.rollback() + if "uq_evaluation_dataset_name_org_project" in str(e): + logger.error( + f"[create_tts_dataset] Dataset name already exists | name: {name}" + ) + raise HTTPException( + status_code=400, + detail=f"Dataset with name '{name}' already exists", + ) + raise + + +def get_tts_dataset_by_id( + *, + session: Session, + dataset_id: int, + org_id: int, + project_id: int, +) -> EvaluationDataset | None: + """Get a TTS dataset by ID. 
+ + Args: + session: Database session + dataset_id: Dataset ID + org_id: Organization ID + project_id: Project ID + + Returns: + EvaluationDataset | None: Dataset if found + """ + statement = select(EvaluationDataset).where( + EvaluationDataset.id == dataset_id, + EvaluationDataset.organization_id == org_id, + EvaluationDataset.project_id == project_id, + EvaluationDataset.type == EvaluationType.TTS.value, + ) + + return session.exec(statement).one_or_none() + + +def list_tts_datasets( + *, + session: Session, + org_id: int, + project_id: int, + limit: int = 50, + offset: int = 0, +) -> tuple[list[TTSDatasetPublic], int]: + """List TTS datasets for a project. + + Args: + session: Database session + org_id: Organization ID + project_id: Project ID + limit: Maximum results to return + offset: Number of results to skip + + Returns: + tuple[list[TTSDatasetPublic], int]: Datasets and total count + """ + base_filter = ( + EvaluationDataset.organization_id == org_id, + EvaluationDataset.project_id == project_id, + EvaluationDataset.type == EvaluationType.TTS.value, + ) + + count_stmt = select(func.count(EvaluationDataset.id)).where(*base_filter) + total = session.exec(count_stmt).one() + + statement = ( + select(EvaluationDataset) + .where(*base_filter) + .order_by(EvaluationDataset.inserted_at.desc()) + .offset(offset) + .limit(limit) + ) + + datasets = session.exec(statement).all() + + result = [ + TTSDatasetPublic( + id=dataset.id, + name=dataset.name, + description=dataset.description, + type=dataset.type, + language_id=dataset.language_id, + object_store_url=dataset.object_store_url, + dataset_metadata=dataset.dataset_metadata, + organization_id=dataset.organization_id, + project_id=dataset.project_id, + inserted_at=dataset.inserted_at, + updated_at=dataset.updated_at, + ) + for dataset in datasets + ] + + return result, total diff --git a/backend/app/crud/tts_evaluations/result.py b/backend/app/crud/tts_evaluations/result.py new file mode 100644 index 
00000000..d21ac6e9 --- /dev/null +++ b/backend/app/crud/tts_evaluations/result.py @@ -0,0 +1,328 @@ +"""CRUD operations for TTS evaluation results.""" + +import logging +from typing import Any + +from sqlmodel import Session, func, select + +from app.core.exception_handlers import HTTPException +from app.core.util import now +from app.models.job import JobStatus +from app.models.tts_evaluation import TTSResult, TTSResultPublic + +logger = logging.getLogger(__name__) + + +def create_tts_results( + *, + session: Session, + sample_texts: list[str], + evaluation_run_id: int, + org_id: int, + project_id: int, + models: list[str], +) -> list[TTSResult]: + """Create TTS result records for all sample texts and models. + + Creates one result per sample text per model. + + Args: + session: Database session + sample_texts: List of text strings to synthesize + evaluation_run_id: Run ID + org_id: Organization ID + project_id: Project ID + models: List of TTS models + + Returns: + list[TTSResult]: Created results + """ + logger.info( + f"[create_tts_results] Creating TTS results | " + f"run_id: {evaluation_run_id}, sample_count: {len(sample_texts)}, " + f"model_count: {len(models)}" + ) + + timestamp = now() + results = [ + TTSResult( + sample_text=text, + evaluation_run_id=evaluation_run_id, + organization_id=org_id, + project_id=project_id, + provider=model, + status=JobStatus.PENDING.value, + inserted_at=timestamp, + updated_at=timestamp, + ) + for text in sample_texts + for model in models + ] + + session.add_all(results) + session.flush() + session.commit() + + logger.info( + f"[create_tts_results] TTS results created | " + f"run_id: {evaluation_run_id}, result_count: {len(results)}" + ) + + return results + + +def get_tts_result_by_id( + *, + session: Session, + result_id: int, + org_id: int, + project_id: int, +) -> TTSResult | None: + """Get a TTS result by ID. 
+ + Args: + session: Database session + result_id: Result ID + org_id: Organization ID + project_id: Project ID + + Returns: + TTSResult | None: Result if found + """ + statement = select(TTSResult).where( + TTSResult.id == result_id, + TTSResult.organization_id == org_id, + TTSResult.project_id == project_id, + ) + + return session.exec(statement).one_or_none() + + +def get_results_by_run_id( + *, + session: Session, + run_id: int, + org_id: int, + project_id: int, + provider: str | None = None, + status: str | None = None, + limit: int = 100, + offset: int = 0, +) -> tuple[list[TTSResultPublic], int]: + """Get results for an evaluation run. + + Args: + session: Database session + run_id: Run ID + org_id: Organization ID + project_id: Project ID + provider: Optional filter by provider + status: Optional filter by status + limit: Maximum results to return + offset: Number of results to skip + + Returns: + tuple[list[TTSResultPublic], int]: Results and total count + """ + where_clauses = [ + TTSResult.evaluation_run_id == run_id, + TTSResult.organization_id == org_id, + TTSResult.project_id == project_id, + ] + + if provider is not None: + where_clauses.append(TTSResult.provider == provider) + + if status is not None: + where_clauses.append(TTSResult.status == status) + + count_stmt = select(func.count(TTSResult.id)).where(*where_clauses) + total = session.exec(count_stmt).one() + + statement = ( + select(TTSResult) + .where(*where_clauses) + .order_by(TTSResult.id) + .offset(offset) + .limit(limit) + ) + + rows = session.exec(statement).all() + + results = [ + TTSResultPublic( + id=result.id, + sample_text=result.sample_text, + object_store_url=result.object_store_url, + duration_seconds=(result.metadata_ or {}).get("duration_seconds"), + size_bytes=(result.metadata_ or {}).get("size_bytes"), + provider=result.provider, + status=result.status, + score=result.score, + is_correct=result.is_correct, + comment=result.comment, + error_message=result.error_message, + 
def update_tts_result(
    *,
    session: Session,
    result_id: int,
    object_store_url: str | None = None,
    metadata: dict[str, Any] | None = None,
    status: str | None = None,
    error_message: str | None = None,
) -> TTSResult | None:
    """Apply partial updates to a TTS result.

    Only fields passed as non-None are written. The session is flushed
    but not committed; the caller owns the transaction.

    Args:
        session: Database session
        result_id: Result ID
        object_store_url: S3 URL of generated WAV
        metadata: Audio metadata (duration_seconds, size_bytes)
        status: New status
        error_message: Error message if failed

    Returns:
        TTSResult | None: Updated result, or None if not found
    """
    row = session.exec(
        select(TTSResult).where(TTSResult.id == result_id)
    ).one_or_none()
    if row is None:
        return None

    # Map keyword arguments onto model attributes; None means "leave as-is".
    changes = {
        "object_store_url": object_store_url,
        "metadata_": metadata,
        "status": status,
        "error_message": error_message,
    }
    for attr, new_value in changes.items():
        if new_value is not None:
            setattr(row, attr, new_value)

    row.updated_at = now()
    session.add(row)
    session.flush()
    return row


def update_tts_human_feedback(
    *,
    session: Session,
    result_id: int,
    org_id: int,
    project_id: int,
    is_correct: bool | None = None,
    comment: str | None = None,
) -> TTSResult | None:
    """Record human feedback (correctness flag and/or comment) on a result.

    Args:
        session: Database session
        result_id: Result ID
        org_id: Organization ID
        project_id: Project ID
        is_correct: Human verification of quality
        comment: Feedback comment

    Returns:
        TTSResult | None: Updated result

    Raises:
        HTTPException: 404 if the result does not exist in this org/project
    """
    feedback_target = get_tts_result_by_id(
        session=session,
        result_id=result_id,
        org_id=org_id,
        project_id=project_id,
    )
    if feedback_target is None:
        raise HTTPException(status_code=404, detail="Result not found")

    if is_correct is not None:
        feedback_target.is_correct = is_correct
    if comment is not None:
        feedback_target.comment = comment
    feedback_target.updated_at = now()

    session.add(feedback_target)
    session.commit()
    session.refresh(feedback_target)

    logger.info(
        f"[update_tts_human_feedback] Human feedback updated | "
        f"result_id: {result_id}, is_correct: {is_correct}"
    )

    return feedback_target


def get_pending_results_for_run(
    *,
    session: Session,
    run_id: int,
    provider: str | None = None,
) -> list[TTSResult]:
    """Fetch all PENDING results for a run, optionally for one provider.

    Args:
        session: Database session
        run_id: Run ID
        provider: Optional filter by provider

    Returns:
        list[TTSResult]: Pending results
    """
    stmt = select(TTSResult).where(
        TTSResult.evaluation_run_id == run_id,
        TTSResult.status == JobStatus.PENDING.value,
    )
    if provider is not None:
        stmt = stmt.where(TTSResult.provider == provider)

    return list(session.exec(stmt).all())


def count_results_by_status(
    *,
    session: Session,
    run_id: int,
) -> dict[str, int]:
    """Count a run's results grouped by status.

    Args:
        session: Database session
        run_id: Run ID

    Returns:
        dict[str, int]: Mapping of status value to row count
    """
    grouped = (
        select(TTSResult.status, func.count(TTSResult.id))
        .where(TTSResult.evaluation_run_id == run_id)
        .group_by(TTSResult.status)
    )
    # Each row is a (status, count) pair, so dict() builds the mapping.
    return dict(session.exec(grouped).all())
"""CRUD operations for TTS evaluation runs."""

import logging
from typing import Any

from sqlmodel import Session, func, select

from app.core.util import now
from app.models import EvaluationDataset, EvaluationRun
from app.models.stt_evaluation import EvaluationType
from app.models.tts_evaluation import TTSEvaluationRunPublic

logger = logging.getLogger(__name__)


def create_tts_run(
    *,
    session: Session,
    run_name: str,
    dataset_id: int,
    dataset_name: str,
    org_id: int,
    project_id: int,
    models: list[str],
    language_id: int | None = None,
    total_items: int = 0,
) -> EvaluationRun:
    """Create a new TTS evaluation run.

    Args:
        session: Database session
        run_name: Name for the run
        dataset_id: ID of the dataset to evaluate
        dataset_name: Name of the dataset
        org_id: Organization ID
        project_id: Project ID
        models: List of TTS models to use
        language_id: Optional language ID override
        total_items: Total number of items to process

    Returns:
        EvaluationRun: Created run (committed and refreshed)
    """
    logger.info(
        f"[create_tts_run] Creating TTS evaluation run | "
        f"run_name: {run_name}, dataset_id: {dataset_id}, "
        f"models: {models}"
    )

    # Single timestamp so inserted_at == updated_at on creation,
    # consistent with create_tts_results in result.py (the original
    # called now() twice, yielding slightly different values).
    timestamp = now()
    run = EvaluationRun(
        run_name=run_name,
        dataset_name=dataset_name,
        dataset_id=dataset_id,
        type=EvaluationType.TTS.value,
        language_id=language_id,
        providers=models,
        # NOTE(review): raw "pending" literal; TTSResult rows use
        # JobStatus.PENDING.value — confirm the two spellings match.
        status="pending",
        total_items=total_items,
        organization_id=org_id,
        project_id=project_id,
        inserted_at=timestamp,
        updated_at=timestamp,
    )

    session.add(run)
    session.commit()
    session.refresh(run)

    logger.info(
        f"[create_tts_run] TTS evaluation run created | "
        f"run_id: {run.id}, run_name: {run_name}"
    )

    return run
def get_tts_run_by_id(
    *,
    session: Session,
    run_id: int,
    org_id: int,
    project_id: int,
) -> EvaluationRun | None:
    """Get a TTS evaluation run by ID, scoped to an org and project.

    Args:
        session: Database session
        run_id: Run ID
        org_id: Organization ID
        project_id: Project ID

    Returns:
        EvaluationRun | None: Run if found, else None
    """
    lookup = select(EvaluationRun).where(
        EvaluationRun.id == run_id,
        EvaluationRun.organization_id == org_id,
        EvaluationRun.project_id == project_id,
        # Scope to TTS runs only; STT runs share the same table.
        EvaluationRun.type == EvaluationType.TTS.value,
    )
    return session.exec(lookup).one_or_none()


def list_tts_runs(
    *,
    session: Session,
    org_id: int,
    project_id: int,
    dataset_id: int | None = None,
    status: str | None = None,
    limit: int = 50,
    offset: int = 0,
) -> tuple[list[TTSEvaluationRunPublic], int]:
    """List TTS evaluation runs for a project, newest first.

    Args:
        session: Database session
        org_id: Organization ID
        project_id: Project ID
        dataset_id: Optional filter by dataset
        status: Optional filter by status
        limit: Maximum results to return
        offset: Number of results to skip

    Returns:
        tuple[list[TTSEvaluationRunPublic], int]: Page of runs and the
        total count matching the filters (before pagination).
    """
    filters = [
        EvaluationRun.organization_id == org_id,
        EvaluationRun.project_id == project_id,
        EvaluationRun.type == EvaluationType.TTS.value,
    ]
    if dataset_id is not None:
        filters.append(EvaluationRun.dataset_id == dataset_id)
    if status is not None:
        filters.append(EvaluationRun.status == status)

    total = session.exec(
        select(func.count(EvaluationRun.id)).where(*filters)
    ).one()

    page_query = (
        select(EvaluationRun)
        .where(*filters)
        .order_by(EvaluationRun.inserted_at.desc())
        .offset(offset)
        .limit(limit)
    )

    def _to_public(run: EvaluationRun) -> TTSEvaluationRunPublic:
        # Project the ORM row into the public schema; the DB "providers"
        # column is exposed as "models" in the API.
        return TTSEvaluationRunPublic(
            id=run.id,
            run_name=run.run_name,
            dataset_name=run.dataset_name,
            type=run.type,
            language_id=run.language_id,
            models=run.providers,
            dataset_id=run.dataset_id,
            status=run.status,
            total_items=run.total_items,
            score=run.score,
            error_message=run.error_message,
            organization_id=run.organization_id,
            project_id=run.project_id,
            inserted_at=run.inserted_at,
            updated_at=run.updated_at,
        )

    page = [_to_public(run) for run in session.exec(page_query).all()]
    return page, total
+ + Args: + session: Database session + run_id: Run ID + status: New status + score: Score data + error_message: Error message + object_store_url: URL to stored results + batch_job_id: ID of the associated batch job + + Returns: + EvaluationRun | None: Updated run + """ + statement = select(EvaluationRun).where(EvaluationRun.id == run_id) + run = session.exec(statement).one_or_none() + + if not run: + return None + + updates = { + "status": status, + "score": score, + "error_message": error_message, + "object_store_url": object_store_url, + "batch_job_id": batch_job_id, + } + + for field, value in updates.items(): + if value is not None: + setattr(run, field, value) + + run.updated_at = now() + + session.add(run) + session.commit() + session.refresh(run) + + logger.info( + f"[update_tts_run] TTS run updated | run_id: {run_id}, status: {run.status}" + ) + + return run diff --git a/backend/app/models/tts_evaluation.py b/backend/app/models/tts_evaluation.py new file mode 100644 index 00000000..48aa0a51 --- /dev/null +++ b/backend/app/models/tts_evaluation.py @@ -0,0 +1,258 @@ +"""TTS Evaluation models for Text-to-Speech evaluation feature.""" + +from datetime import datetime +from typing import Any + +from pydantic import BaseModel, Field, field_validator +from sqlalchemy import Column, Text +from sqlalchemy.dialects.postgresql import JSONB +from sqlmodel import Field as SQLField +from sqlmodel import SQLModel + +from app.core.util import now +from app.models.job import JobStatus +from app.models.stt_evaluation import EvaluationType + +# Supported TTS models for evaluation +SUPPORTED_TTS_MODELS = ["gemini-2.5-pro-preview-tts"] + + +class TTSResult(SQLModel, table=True): + """Database table for TTS synthesis results.""" + + __tablename__ = "tts_result" + + id: int = SQLField( + default=None, + primary_key=True, + sa_column_kwargs={"comment": "Unique identifier for the TTS result"}, + ) + + sample_text: str = SQLField( + sa_column=Column( + Text, + nullable=False, + 
comment="Input text that was synthesized to speech", + ), + description="Input text that was synthesized to speech", + ) + + object_store_url: str | None = SQLField( + default=None, + sa_column_kwargs={"comment": "S3 URL of the generated WAV audio file"}, + ) + + metadata_: dict[str, Any] | None = SQLField( + default=None, + sa_column=Column( + "metadata", + JSONB, + nullable=True, + comment="Audio metadata: {duration_seconds, size_bytes}", + ), + description="Audio metadata (duration_seconds, size_bytes)", + ) + + provider: str = SQLField( + max_length=100, + description="TTS provider used (e.g., gemini-2.5-pro-preview-tts)", + sa_column_kwargs={ + "comment": "TTS provider used (e.g., gemini-2.5-pro-preview-tts)" + }, + ) + + status: str = SQLField( + default=JobStatus.PENDING.value, + max_length=20, + description="Result status: PENDING, SUCCESS, FAILED", + sa_column_kwargs={"comment": "Result status: PENDING, SUCCESS, FAILED"}, + ) + + score: dict[str, Any] | None = SQLField( + default=None, + sa_column=Column( + JSONB, + nullable=True, + comment="Extensible evaluation metrics (null in Phase 1)", + ), + description="Extensible evaluation metrics", + ) + + is_correct: bool | None = SQLField( + default=None, + description="Human feedback: audio quality correctness", + sa_column_kwargs={ + "comment": "Human feedback: audio quality correctness (null=not reviewed)" + }, + ) + comment: str | None = SQLField( + default=None, + sa_column=Column( + Text, + nullable=True, + comment="Human feedback comment", + ), + description="Human feedback comment", + ) + + error_message: str | None = SQLField( + default=None, + sa_column=Column( + Text, + nullable=True, + comment="Error message if synthesis failed", + ), + description="Error message if synthesis failed", + ) + + evaluation_run_id: int = SQLField( + foreign_key="evaluation_run.id", + nullable=False, + ondelete="CASCADE", + sa_column_kwargs={"comment": "Reference to the evaluation run"}, + ) + organization_id: int = 
SQLField( + foreign_key="organization.id", + nullable=False, + ondelete="CASCADE", + sa_column_kwargs={"comment": "Reference to the organization"}, + ) + project_id: int = SQLField( + foreign_key="project.id", + nullable=False, + ondelete="CASCADE", + sa_column_kwargs={"comment": "Reference to the project"}, + ) + + inserted_at: datetime = SQLField( + default_factory=now, + nullable=False, + sa_column_kwargs={"comment": "Timestamp when the result was created"}, + ) + updated_at: datetime = SQLField( + default_factory=now, + nullable=False, + sa_column_kwargs={"comment": "Timestamp when the result was last updated"}, + ) + + +# --- Pydantic request/response models --- + + +class TTSSampleCreate(BaseModel): + """Request model for a single TTS sample.""" + + text: str = Field(..., description="Text to synthesize", min_length=1) + + +class TTSDatasetCreate(BaseModel): + """Request model for creating a TTS dataset.""" + + name: str = Field(..., description="Dataset name", min_length=1) + description: str | None = Field(None, description="Dataset description") + language_id: int | None = Field( + None, description="ID of the language from global languages table" + ) + samples: list[TTSSampleCreate] = Field( + ..., description="List of text samples", min_length=1 + ) + + +class TTSDatasetPublic(BaseModel): + """Public model for TTS datasets.""" + + id: int + name: str + description: str | None + type: str + language_id: int | None + object_store_url: str | None + dataset_metadata: dict[str, Any] + organization_id: int + project_id: int + inserted_at: datetime + updated_at: datetime + + +class TTSResultPublic(BaseModel): + """Public model for TTS results.""" + + id: int + sample_text: str + object_store_url: str | None + duration_seconds: float | None = None + size_bytes: int | None = None + provider: str + status: str + score: dict[str, Any] | None + is_correct: bool | None + comment: str | None + error_message: str | None + evaluation_run_id: int + organization_id: int + 
class TTSResultPublic(BaseModel):
    """Public (API response) model for TTS results."""

    id: int
    sample_text: str
    object_store_url: str | None
    # duration_seconds / size_bytes are unpacked from TTSResult.metadata_
    duration_seconds: float | None = None
    size_bytes: int | None = None
    provider: str
    status: str
    score: dict[str, Any] | None
    is_correct: bool | None
    comment: str | None
    error_message: str | None
    evaluation_run_id: int
    organization_id: int
    project_id: int
    inserted_at: datetime
    updated_at: datetime


class TTSFeedbackUpdate(BaseModel):
    """Request model for updating human feedback on a TTS result."""

    is_correct: bool | None = Field(
        None, description="Is the synthesized audio correct?"
    )
    comment: str | None = Field(None, description="Feedback comment")


class TTSEvaluationRunCreate(BaseModel):
    """Request model for starting a TTS evaluation run."""

    run_name: str = Field(..., description="Name for this evaluation run", min_length=1)
    dataset_id: int = Field(..., description="ID of the TTS dataset to evaluate")
    models: list[str] = Field(
        default_factory=lambda: ["gemini-2.5-pro-preview-tts"],
        description="List of TTS models to use",
        min_length=1,
    )

    @field_validator("models")
    @classmethod
    def validate_models(cls, models: list[str]) -> list[str]:
        """Validate that every requested model is supported.

        Note: the parameter is the raw, not-yet-validated value (it was
        previously misnamed ``valid_model``).
        """
        # Defensive; Field(min_length=1) already rejects empty lists.
        if not models:
            raise ValueError("At least one model must be specified")
        unsupported = [m for m in models if m not in SUPPORTED_TTS_MODELS]
        if unsupported:
            raise ValueError(
                f"Unsupported model(s): {', '.join(unsupported)}. "
                f"Supported models are: {', '.join(SUPPORTED_TTS_MODELS)}"
            )
        return models


class TTSEvaluationRunPublic(BaseModel):
    """Public model for TTS evaluation runs."""

    id: int
    run_name: str
    dataset_name: str
    type: str
    language_id: int | None
    models: list[str] | None
    dataset_id: int
    status: str
    total_items: int
    score: dict[str, Any] | None
    error_message: str | None
    run_metadata: dict[str, Any] | None = None
    organization_id: int
    project_id: int
    inserted_at: datetime
    updated_at: datetime


class TTSEvaluationRunWithResults(TTSEvaluationRunPublic):
    """TTS evaluation run with embedded results."""

    results: list[TTSResultPublic]
    results_total: int = Field(0, description="Total number of results")


"""Audio processing utilities for TTS evaluation."""

# "struct" was imported here but never used; removed.
import io
import wave
def pcm_to_wav(
    pcm_data: bytes,
    sample_rate: int = 24000,
    bits_per_sample: int = 16,
    channels: int = 1,
) -> bytes:
    """Wrap raw PCM audio bytes in a WAV (RIFF) container.

    Args:
        pcm_data: Raw PCM audio bytes
        sample_rate: Sample rate in Hz (default: 24000 for Gemini TTS)
        bits_per_sample: Bits per sample (default: 16)
        channels: Number of audio channels (default: 1 mono)

    Returns:
        WAV file bytes with proper headers
    """
    container = io.BytesIO()
    writer = wave.open(container, "wb")
    try:
        writer.setnchannels(channels)
        # wave expects the sample width in bytes, not bits.
        writer.setsampwidth(bits_per_sample // 8)
        writer.setframerate(sample_rate)
        writer.writeframes(pcm_data)
    finally:
        # Closing finalizes the RIFF header (chunk sizes).
        writer.close()
    return container.getvalue()


def calculate_duration(
    pcm_size: int,
    sample_rate: int = 24000,
    bits_per_sample: int = 16,
    channels: int = 1,
) -> float:
    """Compute audio duration in seconds from a raw PCM byte count.

    Args:
        pcm_size: Size of raw PCM data in bytes
        sample_rate: Sample rate in Hz
        bits_per_sample: Bits per sample
        channels: Number of audio channels

    Returns:
        Duration in seconds (0.0 when the byte rate is zero).
    """
    byte_rate = sample_rate * (bits_per_sample // 8) * channels
    return pcm_size / byte_rate if byte_rate else 0.0


"""Shared constants for TTS evaluation services."""

# Default TTS model
DEFAULT_TTS_MODEL = "gemini-2.5-pro-preview-tts"

# Default voice configuration
DEFAULT_VOICE_NAME = "Kore"

# Default style prompt for TTS synthesis
DEFAULT_STYLE_PROMPT = "Read in a calm, professional customer service tone"
"""Dataset management service for TTS evaluations."""

import csv
import io
import logging
from typing import Any

from sqlmodel import Session

from app.core.cloud import get_cloud_storage
from app.core.storage_utils import (
    generate_timestamped_filename,
    upload_to_object_store,
)
from app.crud.tts_evaluations import create_tts_dataset
from app.models import EvaluationDataset
from app.models.tts_evaluation import TTSSampleCreate

logger = logging.getLogger(__name__)


def upload_tts_dataset(
    session: Session,
    name: str,
    samples: list[TTSSampleCreate],
    organization_id: int,
    project_id: int,
    description: str | None = None,
    language_id: int | None = None,
) -> EvaluationDataset:
    """Orchestrate the TTS dataset upload workflow.

    Converts the samples to CSV, uploads the CSV to the object store
    (best effort — a failed upload yields a None URL, not an error),
    then creates and commits the dataset row.

    Args:
        session: Database session
        name: Dataset name
        samples: List of TTS text samples
        organization_id: Organization ID
        project_id: Project ID
        description: Optional dataset description
        language_id: Optional reference to global.languages table

    Returns:
        Created dataset record
    """
    logger.info(
        f"[upload_tts_dataset] Uploading TTS dataset | name={name} | "
        f"sample_count={len(samples)} | org_id={organization_id} | "
        f"project_id={project_id}"
    )

    object_store_url = _upload_samples_to_object_store(
        session=session,
        project_id=project_id,
        dataset_name=name,
        samples=samples,
    )

    dataset_metadata: dict[str, Any] = {"sample_count": len(samples)}

    try:
        dataset = create_tts_dataset(
            session=session,
            name=name,
            org_id=organization_id,
            project_id=project_id,
            description=description,
            language_id=language_id,
            object_store_url=object_store_url,
            dataset_metadata=dataset_metadata,
        )
        logger.info(
            f"[upload_tts_dataset] Created dataset record | "
            f"id={dataset.id} | name={name}"
        )
        session.commit()
        return dataset
    except Exception:
        session.rollback()
        raise


def _upload_samples_to_object_store(
    session: Session,
    project_id: int,
    dataset_name: str,
    samples: list[TTSSampleCreate],
) -> str | None:
    """Upload TTS samples as CSV to the object store (best effort).

    Args:
        session: Database session
        project_id: Project ID for storage credentials
        dataset_name: Dataset name for filename
        samples: List of samples to upload

    Returns:
        Object store URL, or None on failure / when the backend
        returns nothing.
    """
    try:
        storage = get_cloud_storage(session=session, project_id=project_id)
        payload = _samples_to_csv(samples)
        target_name = generate_timestamped_filename(dataset_name, "csv")
        url = upload_to_object_store(
            storage=storage,
            content=payload,
            filename=target_name,
            subdirectory="tts_datasets",
            content_type="text/csv",
        )

        if url:
            logger.info(
                f"[_upload_samples_to_object_store] Upload successful | "
                f"url={url}"
            )
        else:
            logger.info(
                "[_upload_samples_to_object_store] Upload returned None | "
                "continuing without object store storage"
            )
        return url
    except Exception as e:
        # Upload is best-effort; the dataset is still created without a URL.
        logger.warning(
            f"[_upload_samples_to_object_store] Failed to upload | {e}",
            exc_info=True,
        )
        return None


def _samples_to_csv(samples: list[TTSSampleCreate]) -> bytes:
    """Render TTS samples as a single-column CSV (header row: "text").

    Args:
        samples: List of TTS samples

    Returns:
        CSV content as UTF-8 bytes
    """
    sink = io.StringIO()
    writer = csv.writer(sink)
    writer.writerow(["text"])
    writer.writerows([sample.text] for sample in samples)
    return sink.getvalue().encode("utf-8")
def parse_tts_samples_from_csv(csv_content: bytes) -> list[dict[str, Any]]:
    """Parse TTS samples from CSV content.

    Args:
        csv_content: CSV file content as bytes (UTF-8, header row with
            a "text" column)

    Returns:
        List of dicts with {index, text} for each non-empty sample;
        ``index`` is the row's position among parsed rows, so rows
        filtered out for being blank still consume an index.
    """
    reader = csv.DictReader(io.StringIO(csv_content.decode("utf-8")))
    samples: list[dict[str, Any]] = []
    for i, row in enumerate(reader):
        # DictReader yields None (not "") for cells missing from a short
        # row, so row.get("text", "").strip() could raise AttributeError;
        # `or ""` normalizes None before stripping.
        text = (row.get("text") or "").strip()
        if text:
            samples.append({"index": i, "text": text})
    return samples
-revision = "047" -down_revision = "046" +revision = "048" +down_revision = "047" branch_labels = None depends_on = None From 348e81bd6523b9ed7beb5fe465e07b25d4d3e831 Mon Sep 17 00:00:00 2001 From: AkhileshNegi Date: Tue, 24 Feb 2026 11:24:53 +0530 Subject: [PATCH 3/7] update to custom id --- backend/app/crud/tts_evaluations/cron.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/backend/app/crud/tts_evaluations/cron.py b/backend/app/crud/tts_evaluations/cron.py index d9a32bab..9f1e0bfe 100644 --- a/backend/app/crud/tts_evaluations/cron.py +++ b/backend/app/crud/tts_evaluations/cron.py @@ -16,7 +16,12 @@ from sqlalchemy import Integer from sqlmodel import Session, select -from app.core.batch import BatchJobState, GeminiBatchProvider, poll_batch_status +from app.core.batch import ( + BATCH_KEY, + BatchJobState, + GeminiBatchProvider, + poll_batch_status, +) from app.core.cloud.storage import get_cloud_storage from app.core.storage_utils import upload_to_object_store from app.crud.tts_evaluations.result import count_results_by_status, update_tts_result @@ -412,13 +417,13 @@ async def process_completed_tts_batch( ) for batch_result in results: - custom_id = batch_result["custom_id"] + custom_id = batch_result[BATCH_KEY] try: result_id = int(custom_id) except (ValueError, TypeError): logger.warning( - f"[process_completed_tts_batch] Invalid custom_id | " - f"batch_job_id={batch_job.id}, custom_id={custom_id}" + f"[process_completed_tts_batch] Invalid {BATCH_KEY} | " + f"batch_job_id={batch_job.id}, {BATCH_KEY}={custom_id}" ) failed_count += 1 continue From f9cfdb648eab2327831eb450bcb0f96818a93c76 Mon Sep 17 00:00:00 2001 From: AkhileshNegi Date: Tue, 24 Feb 2026 11:37:29 +0530 Subject: [PATCH 4/7] fix bug for parsing audio --- backend/app/core/batch/gemini.py | 6 +++--- backend/app/crud/stt_evaluations/cron.py | 24 ++++++++++++++++++--- backend/app/tests/core/batch/test_gemini.py | 14 +++++++++--- 3 files changed, 35 insertions(+), 9 
deletions(-) diff --git a/backend/app/core/batch/gemini.py b/backend/app/core/batch/gemini.py index 40f42957..4349997d 100644 --- a/backend/app/core/batch/gemini.py +++ b/backend/app/core/batch/gemini.py @@ -219,14 +219,14 @@ def download_batch_results( parsed = json.loads(line) custom_id = parsed.get("key", str(i)) - # Extract text from response + # Return the raw response so callers can extract + # text (STT) or audio (TTS) as needed. response_obj = parsed.get("response") if response_obj: - text = self._extract_text_from_response_dict(response_obj) results.append( { BATCH_KEY: custom_id, - "response": {"text": text}, + "response": response_obj, "error": None, } ) diff --git a/backend/app/crud/stt_evaluations/cron.py b/backend/app/crud/stt_evaluations/cron.py index ceaf4a2a..fd6c3e8d 100644 --- a/backend/app/crud/stt_evaluations/cron.py +++ b/backend/app/crud/stt_evaluations/cron.py @@ -14,8 +14,12 @@ from sqlalchemy import Integer from sqlmodel import Session, select -from app.core.batch import BatchJobState, GeminiBatchProvider, poll_batch_status -from app.core.batch.base import BATCH_KEY +from app.core.batch import ( + BATCH_KEY, + BatchJobState, + GeminiBatchProvider, + poll_batch_status, +) from app.core.util import now from app.crud.stt_evaluations.result import count_results_by_status from app.crud.stt_evaluations.run import update_stt_run @@ -27,6 +31,20 @@ logger = logging.getLogger(__name__) + +def _extract_text_from_response(response: dict[str, Any]) -> str: + """Extract text from a raw Gemini response dict.""" + if "text" in response: + return response["text"] + text = "" + for candidate in response.get("candidates", []): + content = candidate.get("content", {}) + for part in content.get("parts", []): + if "text" in part: + text += part["text"] + return text + + # Terminal states that indicate batch processing is complete TERMINAL_STATES = { BatchJobState.SUCCEEDED.value, @@ -415,7 +433,7 @@ async def process_completed_stt_batch( } if 
response.get("response"): - row["transcription"] = response["response"].get("text", "") + row["transcription"] = _extract_text_from_response(response["response"]) row["status"] = JobStatus.SUCCESS.value success_count += 1 else: diff --git a/backend/app/tests/core/batch/test_gemini.py b/backend/app/tests/core/batch/test_gemini.py index fc62f7b4..c9281cfb 100644 --- a/backend/app/tests/core/batch/test_gemini.py +++ b/backend/app/tests/core/batch/test_gemini.py @@ -233,9 +233,15 @@ def test_download_batch_results_success(self, provider, mock_genai_client): ) assert len(results) == 2 assert results[0]["custom_id"] == "req-1" - assert results[0]["response"]["text"] == "Hello" + assert ( + results[0]["response"]["candidates"][0]["content"]["parts"][0]["text"] + == "Hello" + ) assert results[1]["custom_id"] == "req-2" - assert results[1]["response"]["text"] == "World" + assert ( + results[1]["response"]["candidates"][0]["content"]["parts"][0]["text"] + == "World" + ) def test_download_batch_results_with_direct_text_response( self, provider, mock_genai_client @@ -255,7 +261,9 @@ def test_download_batch_results_with_direct_text_response( results = provider.download_batch_results(batch_id) assert len(results) == 1 - assert results[0]["response"]["text"] == "Direct text" + assert ( + results[0]["response"]["text"] == "Direct text" + ) # raw response passthrough def test_download_batch_results_with_errors(self, provider, mock_genai_client): """Test downloading batch results that contain errors.""" From 8b37b8bb29f80ce4bdc5afdb1b521f82151d5707 Mon Sep 17 00:00:00 2001 From: AkhileshNegi Date: Tue, 24 Feb 2026 11:54:59 +0530 Subject: [PATCH 5/7] updats --- backend/app/crud/tts_evaluations/cron.py | 25 ++++++++++++++++++++++-- 1 file changed, 23 insertions(+), 2 deletions(-) diff --git a/backend/app/crud/tts_evaluations/cron.py b/backend/app/crud/tts_evaluations/cron.py index 9f1e0bfe..528d7223 100644 --- a/backend/app/crud/tts_evaluations/cron.py +++ 
b/backend/app/crud/tts_evaluations/cron.py @@ -24,7 +24,11 @@ ) from app.core.cloud.storage import get_cloud_storage from app.core.storage_utils import upload_to_object_store -from app.crud.tts_evaluations.result import count_results_by_status, update_tts_result +from app.crud.tts_evaluations.result import ( + count_results_by_status, + get_pending_results_for_run, + update_tts_result, +) from app.crud.tts_evaluations.run import update_tts_run from app.models import EvaluationRun from app.models.batch_job import BatchJob @@ -285,9 +289,26 @@ async def poll_tts_run( for batch_job in batch_jobs: provider_name = batch_job.config.get("tts_provider", "unknown") - # Skip batch jobs already in terminal state + # Handle batch jobs already in terminal state if batch_job.provider_status in TERMINAL_STATES: if batch_job.provider_status == BatchJobState.SUCCEEDED.value: + # Check if there are still unprocessed results for this batch. + # This handles retries when a previous processing attempt failed. 
+ pending_results = get_pending_results_for_run( + session=session, run_id=run.id, provider=provider_name + ) + if pending_results: + logger.info( + f"[poll_tts_run] {log_prefix} Reprocessing SUCCEEDED batch " + f"with {len(pending_results)} pending results | " + f"batch_job_id={batch_job.id}" + ) + await process_completed_tts_batch( + session=session, + run=run, + batch_job=batch_job, + batch_provider=batch_provider, + ) any_succeeded = True else: any_failed = True From 603a6eeaf4173742586c4a6a210ee82c3334086b Mon Sep 17 00:00:00 2001 From: AkhileshNegi Date: Tue, 24 Feb 2026 12:50:08 +0530 Subject: [PATCH 6/7] first stab at moving to celery --- backend/app/core/batch/gemini.py | 12 +- backend/app/crud/tts_evaluations/cron.py | 270 +++++------------ .../batch_result_processing.py | 274 ++++++++++++++++++ 3 files changed, 355 insertions(+), 201 deletions(-) create mode 100644 backend/app/services/tts_evaluations/batch_result_processing.py diff --git a/backend/app/core/batch/gemini.py b/backend/app/core/batch/gemini.py index 4349997d..9a3bb877 100644 --- a/backend/app/core/batch/gemini.py +++ b/backend/app/core/batch/gemini.py @@ -207,11 +207,19 @@ def download_batch_results( results: list[dict[str, Any]] = [] # Handle file-based results (keys are included in each response line) - if ( + has_dest_file = ( batch_job.dest and hasattr(batch_job.dest, "file_name") and batch_job.dest.file_name - ): + ) + if not has_dest_file: + logger.warning( + f"[download_batch_results] No dest file found | " + f"batch_id={output_file_id} | " + f"dest={batch_job.dest} | " + f"dest_attrs={dir(batch_job.dest) if batch_job.dest else 'None'}" + ) + if has_dest_file: file_content = self.download_file(batch_job.dest.file_name) lines = file_content.strip().split("\n") for i, line in enumerate(lines): diff --git a/backend/app/crud/tts_evaluations/cron.py b/backend/app/crud/tts_evaluations/cron.py index 528d7223..15864033 100644 --- a/backend/app/crud/tts_evaluations/cron.py +++ 
b/backend/app/crud/tts_evaluations/cron.py @@ -1,42 +1,36 @@ """Cron processing functions for TTS evaluations. This module provides functions that are called periodically to process -pending TTS evaluations - polling batch status and processing completed batches. +pending TTS evaluations - polling batch status and dispatching result +processing to Celery workers. Follows the same pattern as STT evaluations: single query to fetch all processing runs, grouped by project_id for credential management. """ -import base64 import logging -import uuid from collections import defaultdict from typing import Any from sqlalchemy import Integer from sqlmodel import Session, select +from app.celery.utils import start_low_priority_job from app.core.batch import ( - BATCH_KEY, BatchJobState, GeminiBatchProvider, poll_batch_status, ) -from app.core.cloud.storage import get_cloud_storage -from app.core.storage_utils import upload_to_object_store from app.crud.tts_evaluations.result import ( count_results_by_status, get_pending_results_for_run, - update_tts_result, ) from app.crud.tts_evaluations.run import update_tts_run from app.models import EvaluationRun from app.models.batch_job import BatchJob from app.models.job import JobStatus from app.models.stt_evaluation import EvaluationType -from app.models.tts_evaluation import TTSResult from app.services.stt_evaluations.gemini import GeminiClient -from app.services.tts_evaluations.audio import calculate_duration, pcm_to_wav logger = logging.getLogger(__name__) @@ -48,6 +42,11 @@ BatchJobState.EXPIRED.value, } +# Function path for Celery task dispatch +_TTS_RESULT_PROCESSING_PATH = ( + "app.services.tts_evaluations.batch_result_processing.execute_tts_result_processing" +) + async def poll_all_pending_tts_evaluations( session: Session, @@ -145,7 +144,7 @@ async def poll_all_pending_tts_evaluations( ) all_results.append(result) - if result["action"] in ("completed", "processed"): + if result["action"] in ("completed", "processed", 
"dispatched"): total_processed += 1 elif result["action"] == "failed": total_failed += 1 @@ -236,6 +235,42 @@ def _get_batch_jobs_for_run( return list(session.exec(stmt).all()) +def _dispatch_tts_result_processing( + run: EvaluationRun, + batch_job: BatchJob, + org_id: int, + provider_name: str, +) -> str: + """Dispatch TTS result processing to Celery low priority queue. + + Args: + run: The evaluation run + batch_job: The batch job record + org_id: Organization ID + provider_name: TTS provider/model name + + Returns: + str: Celery task ID + """ + celery_task_id = start_low_priority_job( + function_path=_TTS_RESULT_PROCESSING_PATH, + project_id=run.project_id, + job_id=str(batch_job.id), + organization_id=org_id, + evaluation_run_id=run.id, + tts_provider=provider_name, + provider_batch_id=batch_job.provider_batch_id, + ) + + logger.info( + f"[_dispatch_tts_result_processing] Dispatched to Celery | " + f"run_id={run.id}, batch_job_id={batch_job.id}, " + f"provider={provider_name}, celery_task_id={celery_task_id}" + ) + + return celery_task_id + + async def poll_tts_run( session: Session, run: EvaluationRun, @@ -245,7 +280,7 @@ async def poll_tts_run( """Poll a single TTS evaluation run's batch status. Finds all batch jobs for this run (one per provider) and polls each. - Only marks the run as complete when all batch jobs are in terminal states. + When a batch reaches SUCCEEDED, dispatches result processing to Celery. 
Args: session: Database session @@ -284,6 +319,7 @@ async def poll_tts_run( all_terminal = True any_succeeded = False any_failed = False + dispatched = False errors: list[str] = [] for batch_job in batch_jobs: @@ -299,16 +335,17 @@ async def poll_tts_run( ) if pending_results: logger.info( - f"[poll_tts_run] {log_prefix} Reprocessing SUCCEEDED batch " - f"with {len(pending_results)} pending results | " + f"[poll_tts_run] {log_prefix} Dispatching reprocessing for " + f"{len(pending_results)} pending results | " f"batch_job_id={batch_job.id}" ) - await process_completed_tts_batch( - session=session, + _dispatch_tts_result_processing( run=run, batch_job=batch_job, - batch_provider=batch_provider, + org_id=org_id, + provider_name=provider_name, ) + dispatched = True any_succeeded = True else: any_failed = True @@ -337,15 +374,16 @@ async def poll_tts_run( all_terminal = False continue - # Batch reached terminal state - process it + # Batch reached terminal state - dispatch processing to Celery if provider_status == BatchJobState.SUCCEEDED.value: - await process_completed_tts_batch( - session=session, + _dispatch_tts_result_processing( run=run, batch_job=batch_job, - batch_provider=batch_provider, + org_id=org_id, + provider_name=provider_name, ) any_succeeded = True + dispatched = True else: any_failed = True errors.append( @@ -362,7 +400,19 @@ async def poll_tts_run( "action": "no_change", } - # All batch jobs are done - finalize the run + # If we dispatched processing to Celery, keep the run as "processing". + # The Celery task will finalize the run status when done. 
+ if dispatched: + return { + "run_id": run.id, + "run_name": run.run_name, + "type": "tts", + "previous_status": previous_status, + "current_status": "processing", + "action": "dispatched", + } + + # All batch jobs are done and no dispatching needed - finalize the run status_counts = count_results_by_status(session=session, run_id=run.id) pending = status_counts.get(JobStatus.PENDING.value, 0) failed_count = status_counts.get(JobStatus.FAILED.value, 0) @@ -392,181 +442,3 @@ async def poll_tts_run( "action": action, **({"error": error_message} if error_message else {}), } - - -async def process_completed_tts_batch( - session: Session, - run: EvaluationRun, - batch_job: Any, - batch_provider: GeminiBatchProvider, -) -> None: - """Process completed Gemini batch - download results, convert audio, upload to S3. - - For each result: - 1. Download JSONL results from Gemini - 2. Extract base64-encoded PCM audio from response - 3. Decode base64 -> raw PCM bytes - 4. Wrap in WAV container (24kHz, 16-bit, mono) - 5. Upload WAV to S3 - 6. 
Update TTSResult with object_store_url and metadata - - Args: - session: Database session - run: The evaluation run - batch_job: The BatchJob record - batch_provider: Initialized GeminiBatchProvider - """ - logger.info( - f"[process_completed_tts_batch] Processing batch results | " - f"run_id={run.id}, batch_job_id={batch_job.id}" - ) - - tts_provider = batch_job.config.get("tts_provider", "gemini-2.5-pro-preview-tts") - - # Get cloud storage for S3 uploads - storage = get_cloud_storage(session=session, project_id=run.project_id) - - processed_count = 0 - failed_count = 0 - - try: - results = batch_provider.download_batch_results(batch_job.provider_batch_id) - - logger.info( - f"[process_completed_tts_batch] Got batch results | " - f"batch_job_id={batch_job.id}, result_count={len(results)}" - ) - - for batch_result in results: - custom_id = batch_result[BATCH_KEY] - try: - result_id = int(custom_id) - except (ValueError, TypeError): - logger.warning( - f"[process_completed_tts_batch] Invalid {BATCH_KEY} | " - f"batch_job_id={batch_job.id}, {BATCH_KEY}={custom_id}" - ) - failed_count += 1 - continue - - # Find result record - stmt = select(TTSResult).where( - TTSResult.id == result_id, - TTSResult.evaluation_run_id == run.id, - TTSResult.provider == tts_provider, - ) - result_record = session.exec(stmt).one_or_none() - - if not result_record: - logger.warning( - f"[process_completed_tts_batch] Result record not found | " - f"result_id={result_id}" - ) - failed_count += 1 - continue - - if batch_result.get("response"): - try: - # Extract base64 audio from Gemini TTS response - audio_b64 = _extract_audio_from_response(batch_result["response"]) - - if not audio_b64: - update_tts_result( - session=session, - result_id=result_record.id, - status=JobStatus.FAILED.value, - error_message="No audio data in response", - ) - failed_count += 1 - continue - - # Decode base64 -> raw PCM bytes - pcm_data = base64.b64decode(audio_b64) - - # Wrap in WAV container - wav_data = 
pcm_to_wav(pcm_data) - - # Calculate duration - duration = calculate_duration(len(pcm_data)) - - # Upload WAV to S3 - audio_filename = f"{uuid.uuid4()}.wav" - audio_url = upload_to_object_store( - storage=storage, - content=wav_data, - filename=audio_filename, - subdirectory=f"evaluations/tts/audio", - content_type="audio/wav", - ) - - # Update result - update_tts_result( - session=session, - result_id=result_record.id, - object_store_url=audio_url, - metadata={ - "duration_seconds": round(duration, 3), - "size_bytes": len(wav_data), - }, - status=JobStatus.SUCCESS.value, - ) - processed_count += 1 - - except Exception as audio_err: - logger.error( - f"[process_completed_tts_batch] Audio processing failed | " - f"result_id={result_id}, error={str(audio_err)}" - ) - update_tts_result( - session=session, - result_id=result_record.id, - status=JobStatus.FAILED.value, - error_message=f"Audio processing failed: {str(audio_err)}", - ) - failed_count += 1 - else: - update_tts_result( - session=session, - result_id=result_record.id, - status=JobStatus.FAILED.value, - error_message=batch_result.get("error", "Unknown error"), - ) - failed_count += 1 - - session.commit() - - except Exception as e: - logger.error( - f"[process_completed_tts_batch] Failed to process batch results | " - f"batch_job_id={batch_job.id}, error={str(e)}", - exc_info=True, - ) - raise - - logger.info( - f"[process_completed_tts_batch] Batch results processed | " - f"run_id={run.id}, provider={tts_provider}, " - f"processed={processed_count}, failed={failed_count}" - ) - - -def _extract_audio_from_response(response: dict[str, Any]) -> str | None: - """Extract base64-encoded audio data from a Gemini TTS response. - - Gemini TTS returns audio as base64-encoded PCM data in the - inlineData field of the response parts. 
- - Args: - response: Gemini response dictionary - - Returns: - Base64 encoded audio string, or None if not found - """ - # Navigate: candidates -> content -> parts -> inlineData -> data - for candidate in response.get("candidates", []): - content = candidate.get("content", {}) - for part in content.get("parts", []): - inline_data = part.get("inlineData", {}) - if inline_data.get("data"): - return inline_data["data"] - return None diff --git a/backend/app/services/tts_evaluations/batch_result_processing.py b/backend/app/services/tts_evaluations/batch_result_processing.py new file mode 100644 index 00000000..ccfdf623 --- /dev/null +++ b/backend/app/services/tts_evaluations/batch_result_processing.py @@ -0,0 +1,274 @@ +"""Celery task function for TTS evaluation result processing. + +Processes completed Gemini TTS batch results: downloads JSONL, +extracts audio, converts PCM to WAV, uploads to S3, updates DB. +""" + +import base64 +import logging +import uuid +from typing import Any + +from sqlmodel import Session, select + +from app.core.batch import BATCH_KEY, GeminiBatchProvider +from app.core.cloud.storage import get_cloud_storage +from app.core.db import engine +from app.core.storage_utils import upload_to_object_store +from app.crud.tts_evaluations.result import ( + count_results_by_status, + update_tts_result, +) +from app.crud.tts_evaluations.run import update_tts_run +from app.models.job import JobStatus +from app.models.tts_evaluation import TTSResult +from app.services.stt_evaluations.gemini import GeminiClient +from app.services.tts_evaluations.audio import calculate_duration, pcm_to_wav + +logger = logging.getLogger(__name__) + + +def execute_tts_result_processing( + project_id: int, + job_id: str, + task_id: str, + task_instance: Any, + organization_id: int, + evaluation_run_id: int, + tts_provider: str, + provider_batch_id: str, + **kwargs: Any, +) -> dict: + """Process completed TTS batch results in a Celery worker. 
+ + Downloads batch results from Gemini, extracts audio, converts to WAV, + uploads to S3, and updates TTSResult records. + + Args: + project_id: Project ID + job_id: Batch job ID (as string) + task_id: Celery task ID + task_instance: Celery task instance + organization_id: Organization ID + evaluation_run_id: Evaluation run ID + tts_provider: TTS provider/model name + provider_batch_id: Gemini batch job ID + + Returns: + dict: Result summary with processed/failed counts + """ + logger.info( + f"[execute_tts_result_processing] Starting | " + f"run_id={evaluation_run_id}, batch_job_id={job_id}, " + f"provider={tts_provider}, celery_task_id={task_id}" + ) + + with Session(engine) as session: + try: + # Initialize Gemini client and batch provider + gemini_client = GeminiClient.from_credentials( + session=session, + org_id=organization_id, + project_id=project_id, + ) + batch_provider = GeminiBatchProvider(client=gemini_client.client) + + # Get cloud storage for S3 uploads + storage = get_cloud_storage(session=session, project_id=project_id) + + # Download batch results + results = batch_provider.download_batch_results(provider_batch_id) + + logger.info( + f"[execute_tts_result_processing] Got batch results | " + f"run_id={evaluation_run_id}, result_count={len(results)}" + ) + + if results: + first = results[0] + resp = first.get("response") or {} + logger.info( + f"[execute_tts_result_processing] First result structure | " + f"keys={list(first.keys())}, " + f"response_keys={list(resp.keys()) if isinstance(resp, dict) else type(resp).__name__}" + ) + + processed_count = 0 + failed_count = 0 + + for batch_result in results: + custom_id = batch_result[BATCH_KEY] + try: + result_id = int(custom_id) + except (ValueError, TypeError): + logger.warning( + f"[execute_tts_result_processing] Invalid {BATCH_KEY} | " + f"run_id={evaluation_run_id}, {BATCH_KEY}={custom_id}" + ) + failed_count += 1 + continue + + # Find result record + stmt = select(TTSResult).where( + TTSResult.id 
== result_id, + TTSResult.evaluation_run_id == evaluation_run_id, + TTSResult.provider == tts_provider, + ) + result_record = session.exec(stmt).one_or_none() + + if not result_record: + logger.warning( + f"[execute_tts_result_processing] Result record not found | " + f"result_id={result_id}" + ) + failed_count += 1 + continue + + if batch_result.get("response"): + try: + audio_b64 = _extract_audio_from_response( + batch_result["response"] + ) + + if not audio_b64: + update_tts_result( + session=session, + result_id=result_record.id, + status=JobStatus.FAILED.value, + error_message="No audio data in response", + ) + failed_count += 1 + continue + + # Decode base64 -> raw PCM bytes + pcm_data = base64.b64decode(audio_b64) + + # Wrap in WAV container + wav_data = pcm_to_wav(pcm_data) + + # Calculate duration + duration = calculate_duration(len(pcm_data)) + + # Upload WAV to S3 + audio_filename = f"{uuid.uuid4()}.wav" + audio_url = upload_to_object_store( + storage=storage, + content=wav_data, + filename=audio_filename, + subdirectory="evaluations/tts/audio", + content_type="audio/wav", + ) + + # Update result + update_tts_result( + session=session, + result_id=result_record.id, + object_store_url=audio_url, + metadata={ + "duration_seconds": round(duration, 3), + "size_bytes": len(wav_data), + }, + status=JobStatus.SUCCESS.value, + ) + processed_count += 1 + + except Exception as audio_err: + logger.error( + f"[execute_tts_result_processing] Audio processing failed | " + f"result_id={result_id}, error={str(audio_err)}" + ) + update_tts_result( + session=session, + result_id=result_record.id, + status=JobStatus.FAILED.value, + error_message=f"Audio processing failed: {str(audio_err)}", + ) + failed_count += 1 + else: + update_tts_result( + session=session, + result_id=result_record.id, + status=JobStatus.FAILED.value, + error_message=batch_result.get("error", "Unknown error"), + ) + failed_count += 1 + + session.commit() + + # Finalize run status + status_counts = 
count_results_by_status( + session=session, run_id=evaluation_run_id + ) + pending = status_counts.get(JobStatus.PENDING.value, 0) + total_failed = status_counts.get(JobStatus.FAILED.value, 0) + + final_status = "completed" if pending == 0 else "processing" + error_message = ( + f"{total_failed} synthesis(es) failed" if total_failed > 0 else None + ) + + update_tts_run( + session=session, + run_id=evaluation_run_id, + status=final_status, + error_message=error_message, + ) + + logger.info( + f"[execute_tts_result_processing] Completed | " + f"run_id={evaluation_run_id}, provider={tts_provider}, " + f"processed={processed_count}, failed={failed_count}, " + f"run_status={final_status}" + ) + + return { + "success": True, + "run_id": evaluation_run_id, + "processed": processed_count, + "failed": failed_count, + "run_status": final_status, + } + + except Exception as e: + logger.error( + f"[execute_tts_result_processing] Failed | " + f"run_id={evaluation_run_id}, error={str(e)}", + exc_info=True, + ) + update_tts_run( + session=session, + run_id=evaluation_run_id, + status="failed", + error_message=f"Result processing failed: {str(e)}", + ) + return {"success": False, "error": str(e)} + + +def _extract_audio_from_response(response: dict[str, Any]) -> str | None: + """Extract base64-encoded audio data from a Gemini TTS response. + + Gemini TTS returns audio as base64-encoded PCM data in the + inlineData field of the response parts. Handles both camelCase + (REST API) and snake_case (Python SDK / batch JSONL) field names. 
+ + Args: + response: Gemini response dictionary + + Returns: + Base64 encoded audio string, or None if not found + """ + # Navigate: candidates -> content -> parts -> inlineData/inline_data -> data + for candidate in response.get("candidates", []): + content = candidate.get("content", {}) + for part in content.get("parts", []): + # Handle both camelCase (inlineData) and snake_case (inline_data) + inline_data = part.get("inlineData") or part.get("inline_data") or {} + if inline_data.get("data"): + return inline_data["data"] + + logger.warning( + f"[_extract_audio_from_response] No audio data found | " + f"response_keys={list(response.keys())}, " + f"parts={[list(p.keys()) for c in response.get('candidates', []) for p in c.get('content', {}).get('parts', [])]}" + ) + return None From 3b6c2fe9c42c97546d0bf66612ef90441a4f876f Mon Sep 17 00:00:00 2001 From: AkhileshNegi Date: Tue, 24 Feb 2026 13:16:46 +0530 Subject: [PATCH 7/7] cleanups and refactoring --- .../app/api/routes/tts_evaluations/dataset.py | 30 +---- .../api/routes/tts_evaluations/evaluation.py | 8 -- .../app/api/routes/tts_evaluations/result.py | 42 +------ backend/app/core/batch/__init__.py | 2 + backend/app/core/batch/gemini.py | 45 ++++--- backend/app/crud/evaluations/cron.py | 6 +- backend/app/crud/evaluations/cron_utils.py | 117 ++++++++++++++++++ backend/app/crud/stt_evaluations/cron.py | 108 ++++------------ backend/app/crud/tts_evaluations/cron.py | 95 +++----------- backend/app/crud/tts_evaluations/dataset.py | 17 +-- backend/app/crud/tts_evaluations/result.py | 50 +------- backend/app/models/tts_evaluation.py | 39 ++++++ 12 files changed, 240 insertions(+), 319 deletions(-) create mode 100644 backend/app/crud/evaluations/cron_utils.py diff --git a/backend/app/api/routes/tts_evaluations/dataset.py b/backend/app/api/routes/tts_evaluations/dataset.py index 79b3acd4..9c185143 100644 --- a/backend/app/api/routes/tts_evaluations/dataset.py +++ b/backend/app/api/routes/tts_evaluations/dataset.py @@ 
-56,21 +56,7 @@ def create_dataset( language_id=dataset_create.language_id, ) - return APIResponse.success_response( - data=TTSDatasetPublic( - id=dataset.id, - name=dataset.name, - description=dataset.description, - type=dataset.type, - language_id=dataset.language_id, - object_store_url=dataset.object_store_url, - dataset_metadata=dataset.dataset_metadata, - organization_id=dataset.organization_id, - project_id=dataset.project_id, - inserted_at=dataset.inserted_at, - updated_at=dataset.updated_at, - ) - ) + return APIResponse.success_response(data=TTSDatasetPublic.from_model(dataset)) @router.get( @@ -125,19 +111,7 @@ def get_dataset( raise HTTPException(status_code=404, detail="Dataset not found") return APIResponse.success_response( - data=TTSDatasetPublic( - id=dataset.id, - name=dataset.name, - description=dataset.description, - type=dataset.type, - language_id=dataset.language_id, - object_store_url=dataset.object_store_url, - dataset_metadata=dataset.dataset_metadata, - organization_id=dataset.organization_id, - project_id=dataset.project_id, - inserted_at=dataset.inserted_at, - updated_at=dataset.updated_at, - ), + data=TTSDatasetPublic.from_model(dataset), metadata={ "sample_count": (dataset.dataset_metadata or {}).get("sample_count", 0) }, diff --git a/backend/app/api/routes/tts_evaluations/evaluation.py b/backend/app/api/routes/tts_evaluations/evaluation.py index d60f32b4..a1e4575d 100644 --- a/backend/app/api/routes/tts_evaluations/evaluation.py +++ b/backend/app/api/routes/tts_evaluations/evaluation.py @@ -243,10 +243,6 @@ def get_tts_evaluation_run( auth_context: AuthContextDep, run_id: int, include_results: bool = Query(True, description="Include results in response"), - result_limit: int = Query(100, ge=1, le=1000, description="Max results to return"), - result_offset: int = Query(0, ge=0, description="Result offset"), - provider: str | None = Query(None, description="Filter results by provider"), - status: str | None = Query(None, 
description="Filter results by status"), ) -> APIResponse[TTSEvaluationRunWithResults]: """Get a TTS evaluation run with results.""" run = get_tts_run_by_id( @@ -268,10 +264,6 @@ def get_tts_evaluation_run( run_id=run_id, org_id=auth_context.organization_.id, project_id=auth_context.project_.id, - provider=provider, - status=status, - limit=result_limit, - offset=result_offset, ) return APIResponse.success_response( diff --git a/backend/app/api/routes/tts_evaluations/result.py b/backend/app/api/routes/tts_evaluations/result.py index 577e69a4..e4450b3c 100644 --- a/backend/app/api/routes/tts_evaluations/result.py +++ b/backend/app/api/routes/tts_evaluations/result.py @@ -61,26 +61,7 @@ def update_result_feedback( comment=feedback.comment, ) - return APIResponse.success_response( - data=TTSResultPublic( - id=result.id, - sample_text=result.sample_text, - object_store_url=result.object_store_url, - duration_seconds=(result.metadata_ or {}).get("duration_seconds"), - size_bytes=(result.metadata_ or {}).get("size_bytes"), - provider=result.provider, - status=result.status, - score=result.score, - is_correct=result.is_correct, - comment=result.comment, - error_message=result.error_message, - evaluation_run_id=result.evaluation_run_id, - organization_id=result.organization_id, - project_id=result.project_id, - inserted_at=result.inserted_at, - updated_at=result.updated_at, - ) - ) + return APIResponse.success_response(data=TTSResultPublic.from_model(result)) @router.get( @@ -106,23 +87,4 @@ def get_result( if not result: raise HTTPException(status_code=404, detail="Result not found") - return APIResponse.success_response( - data=TTSResultPublic( - id=result.id, - sample_text=result.sample_text, - object_store_url=result.object_store_url, - duration_seconds=(result.metadata_ or {}).get("duration_seconds"), - size_bytes=(result.metadata_ or {}).get("size_bytes"), - provider=result.provider, - status=result.status, - score=result.score, - is_correct=result.is_correct, - 
comment=result.comment, - error_message=result.error_message, - evaluation_run_id=result.evaluation_run_id, - organization_id=result.organization_id, - project_id=result.project_id, - inserted_at=result.inserted_at, - updated_at=result.updated_at, - ) - ) + return APIResponse.success_response(data=TTSResultPublic.from_model(result)) diff --git a/backend/app/core/batch/__init__.py b/backend/app/core/batch/__init__.py index 5e0560e0..292bca0d 100644 --- a/backend/app/core/batch/__init__.py +++ b/backend/app/core/batch/__init__.py @@ -6,6 +6,7 @@ GeminiBatchProvider, create_stt_batch_requests, create_tts_batch_requests, + extract_text_from_response_dict, ) from .openai import OpenAIBatchProvider from .operations import ( @@ -24,6 +25,7 @@ "OpenAIBatchProvider", "create_stt_batch_requests", "create_tts_batch_requests", + "extract_text_from_response_dict", "start_batch_job", "download_batch_results", "process_completed_batch", diff --git a/backend/app/core/batch/gemini.py b/backend/app/core/batch/gemini.py index 9a3bb877..e824e646 100644 --- a/backend/app/core/batch/gemini.py +++ b/backend/app/core/batch/gemini.py @@ -18,6 +18,29 @@ logger = logging.getLogger(__name__) +def extract_text_from_response_dict(response: dict[str, Any]) -> str: + """Extract text content from a Gemini response dictionary. + + Args: + response: Gemini response as a dictionary + + Returns: + str: Extracted text + """ + # Try direct text field first + if "text" in response: + return response["text"] + + # Extract from candidates structure + text = "" + for candidate in response.get("candidates", []): + content = candidate.get("content", {}) + for part in content.get("parts", []): + if "text" in part: + text += part["text"] + return text + + class BatchJobState(str, Enum): """Gemini batch job states.""" @@ -269,26 +292,8 @@ def download_batch_results( @staticmethod def _extract_text_from_response_dict(response: dict[str, Any]) -> str: - """Extract text content from a Gemini response dictionary. 
- - Args: - response: Gemini response as a dictionary - - Returns: - str: Extracted text - """ - # Try direct text field first - if "text" in response: - return response["text"] - - # Extract from candidates structure - text = "" - for candidate in response.get("candidates", []): - content = candidate.get("content", {}) - for part in content.get("parts", []): - if "text" in part: - text += part["text"] - return text + """Extract text content from a Gemini response dictionary.""" + return extract_text_from_response_dict(response) def upload_file(self, content: str, purpose: str = "batch") -> str: """Upload a JSONL file to Gemini Files API. diff --git a/backend/app/crud/evaluations/cron.py b/backend/app/crud/evaluations/cron.py index 309fbe93..61fd64c4 100644 --- a/backend/app/crud/evaluations/cron.py +++ b/backend/app/crud/evaluations/cron.py @@ -13,8 +13,6 @@ from sqlmodel import Session from app.crud.evaluations.processing import poll_all_pending_evaluations -from app.crud.stt_evaluations import poll_all_pending_stt_evaluations -from app.crud.tts_evaluations import poll_all_pending_tts_evaluations logger = logging.getLogger(__name__) @@ -39,6 +37,10 @@ async def process_all_pending_evaluations(session: Session) -> dict[str, Any]: # Poll text evaluations (single query, grouped by project) text_summary = await poll_all_pending_evaluations(session=session) + # Lazy imports to avoid circular dependency with cron_utils + from app.crud.stt_evaluations import poll_all_pending_stt_evaluations + from app.crud.tts_evaluations import poll_all_pending_tts_evaluations + # Poll STT evaluations (single query, grouped by project) stt_summary = await poll_all_pending_stt_evaluations(session=session) diff --git a/backend/app/crud/evaluations/cron_utils.py b/backend/app/crud/evaluations/cron_utils.py new file mode 100644 index 00000000..63e4639a --- /dev/null +++ b/backend/app/crud/evaluations/cron_utils.py @@ -0,0 +1,117 @@ +"""Shared utilities for evaluation cron processing. 
+ +Common constants, queries, and helpers used by both STT and TTS +evaluation polling loops. +""" + +from collections import defaultdict + +from sqlalchemy import Integer +from sqlmodel import Session, select + +from app.core.batch import BatchJobState +from app.models import EvaluationRun +from app.models.batch_job import BatchJob + +# Terminal states that indicate batch processing is complete +TERMINAL_STATES = { + BatchJobState.SUCCEEDED.value, + BatchJobState.FAILED.value, + BatchJobState.CANCELLED.value, + BatchJobState.EXPIRED.value, +} + + +def fetch_processing_runs( + session: Session, + eval_type: str, +) -> list[EvaluationRun]: + """Fetch all evaluation runs with status='processing' for a given type. + + Args: + session: Database session + eval_type: Evaluation type value (e.g. EvaluationType.STT.value) + + Returns: + list[EvaluationRun]: Runs currently processing + """ + statement = select(EvaluationRun).where( + EvaluationRun.type == eval_type, + EvaluationRun.status == "processing", + EvaluationRun.batch_job_id.is_not(None), + ) + return list(session.exec(statement).all()) + + +def group_runs_by_project( + runs: list[EvaluationRun], +) -> dict[int, list[EvaluationRun]]: + """Group evaluation runs by project_id. + + Args: + runs: List of evaluation runs + + Returns: + dict mapping project_id to list of runs + """ + by_project: dict[int, list[EvaluationRun]] = defaultdict(list) + for run in runs: + by_project[run.project_id].append(run) + return by_project + + +def get_batch_jobs_for_run( + session: Session, + run: EvaluationRun, + job_type: str, +) -> list[BatchJob]: + """Find all batch jobs associated with an evaluation run. + + Args: + session: Database session + run: The evaluation run + job_type: Batch job type (e.g. 
"stt_evaluation", "tts_evaluation") + + Returns: + list[BatchJob]: All batch jobs for this run + """ + stmt = select(BatchJob).where( + BatchJob.job_type == job_type, + BatchJob.config["evaluation_run_id"].astext.cast(Integer) == run.id, + ) + return list(session.exec(stmt).all()) + + +def make_empty_summary() -> dict: + """Return an empty polling summary.""" + return { + "total": 0, + "processed": 0, + "failed": 0, + "still_processing": 0, + "details": [], + } + + +def make_failure_result( + run: EvaluationRun, + eval_type: str, + error: str, +) -> dict: + """Build a failure result dict for a run. + + Args: + run: The evaluation run + eval_type: Short type label ("stt" or "tts") + error: Error message + + Returns: + dict with run_id, run_name, type, action, and error + """ + return { + "run_id": run.id, + "run_name": run.run_name, + "type": eval_type, + "action": "failed", + "error": error, + } diff --git a/backend/app/crud/stt_evaluations/cron.py b/backend/app/crud/stt_evaluations/cron.py index fd6c3e8d..e9d30a0e 100644 --- a/backend/app/crud/stt_evaluations/cron.py +++ b/backend/app/crud/stt_evaluations/cron.py @@ -8,19 +8,26 @@ """ import logging -from collections import defaultdict from typing import Any -from sqlalchemy import Integer -from sqlmodel import Session, select +from sqlmodel import Session from app.core.batch import ( BATCH_KEY, BatchJobState, GeminiBatchProvider, + extract_text_from_response_dict, poll_batch_status, ) from app.core.util import now +from app.crud.evaluations.cron_utils import ( + TERMINAL_STATES, + fetch_processing_runs, + get_batch_jobs_for_run, + group_runs_by_project, + make_empty_summary, + make_failure_result, +) from app.crud.stt_evaluations.result import count_results_by_status from app.crud.stt_evaluations.run import update_stt_run from app.models import EvaluationRun @@ -32,28 +39,6 @@ logger = logging.getLogger(__name__) -def _extract_text_from_response(response: dict[str, Any]) -> str: - """Extract text from a raw 
Gemini response dict.""" - if "text" in response: - return response["text"] - text = "" - for candidate in response.get("candidates", []): - content = candidate.get("content", {}) - for part in content.get("parts", []): - if "text" in part: - text += part["text"] - return text - - -# Terminal states that indicate batch processing is complete -TERMINAL_STATES = { - BatchJobState.SUCCEEDED.value, - BatchJobState.FAILED.value, - BatchJobState.CANCELLED.value, - BatchJobState.EXPIRED.value, -} - - async def poll_all_pending_stt_evaluations( session: Session, ) -> dict[str, Any]: @@ -78,32 +63,17 @@ async def poll_all_pending_stt_evaluations( """ logger.info("[poll_all_pending_stt_evaluations] Starting STT evaluation polling") - # Single query to fetch all processing STT evaluation runs - statement = select(EvaluationRun).where( - EvaluationRun.type == EvaluationType.STT.value, - EvaluationRun.status == "processing", - EvaluationRun.batch_job_id.is_not(None), - ) - pending_runs = session.exec(statement).all() + pending_runs = fetch_processing_runs(session, EvaluationType.STT.value) if not pending_runs: logger.info("[poll_all_pending_stt_evaluations] No pending STT runs found") - return { - "total": 0, - "processed": 0, - "failed": 0, - "still_processing": 0, - "details": [], - } + return make_empty_summary() logger.info( f"[poll_all_pending_stt_evaluations] Found {len(pending_runs)} pending STT runs" ) - # Group evaluations by project_id since credentials are per project - evaluations_by_project: dict[int, list[EvaluationRun]] = defaultdict(list) - for run in pending_runs: - evaluations_by_project[run.project_id].append(run) + evaluations_by_project = group_runs_by_project(pending_runs) # Process each project separately all_results: list[dict[str, Any]] = [] @@ -153,15 +123,7 @@ async def poll_all_pending_stt_evaluations( status="failed", error_message=f"Polling failed: {str(e)}", ) - all_results.append( - { - "run_id": run.id, - "run_name": run.run_name, - "type": 
"stt", - "action": "failed", - "error": str(e), - } - ) + all_results.append(make_failure_result(run, "stt", str(e))) total_failed += 1 except Exception as e: @@ -178,13 +140,9 @@ async def poll_all_pending_stt_evaluations( error_message=f"Project processing failed: {str(e)}", ) all_results.append( - { - "run_id": run.id, - "run_name": run.run_name, - "type": "stt", - "action": "failed", - "error": f"Project processing failed: {str(e)}", - } + make_failure_result( + run, "stt", f"Project processing failed: {str(e)}" + ) ) total_failed += len(project_runs) @@ -205,28 +163,6 @@ async def poll_all_pending_stt_evaluations( return summary -def _get_batch_jobs_for_run( - session: Session, - run: EvaluationRun, -) -> list[BatchJob]: - """Find all batch jobs associated with an STT evaluation run. - - Queries batch_job table where config contains the evaluation_run_id. - - Args: - session: Database session - run: The evaluation run - - Returns: - list[BatchJob]: All batch jobs for this run - """ - stmt = select(BatchJob).where( - BatchJob.job_type == "stt_evaluation", - BatchJob.config["evaluation_run_id"].astext.cast(Integer) == run.id, - ) - return list(session.exec(stmt).all()) - - async def poll_stt_run( session: Session, run: EvaluationRun, @@ -255,7 +191,9 @@ async def poll_stt_run( previous_status = run.status # Find all batch jobs for this run - batch_jobs = _get_batch_jobs_for_run(session=session, run=run) + batch_jobs = get_batch_jobs_for_run( + session=session, run=run, job_type="stt_evaluation" + ) if not batch_jobs: logger.warning(f"[poll_stt_run] No batch jobs found | run_id: {run.id}") @@ -373,7 +311,7 @@ async def poll_stt_run( async def process_completed_stt_batch( session: Session, run: EvaluationRun, - batch_job: Any, + batch_job: BatchJob, batch_provider: GeminiBatchProvider, ) -> None: """Process completed Gemini batch - download results and create STT result records. 
@@ -433,7 +371,9 @@ async def process_completed_stt_batch( } if response.get("response"): - row["transcription"] = _extract_text_from_response(response["response"]) + row["transcription"] = extract_text_from_response_dict( + response["response"] + ) row["status"] = JobStatus.SUCCESS.value success_count += 1 else: diff --git a/backend/app/crud/tts_evaluations/cron.py b/backend/app/crud/tts_evaluations/cron.py index 15864033..545536fb 100644 --- a/backend/app/crud/tts_evaluations/cron.py +++ b/backend/app/crud/tts_evaluations/cron.py @@ -9,11 +9,9 @@ """ import logging -from collections import defaultdict from typing import Any -from sqlalchemy import Integer -from sqlmodel import Session, select +from sqlmodel import Session from app.celery.utils import start_low_priority_job from app.core.batch import ( @@ -21,6 +19,14 @@ GeminiBatchProvider, poll_batch_status, ) +from app.crud.evaluations.cron_utils import ( + TERMINAL_STATES, + fetch_processing_runs, + get_batch_jobs_for_run, + group_runs_by_project, + make_empty_summary, + make_failure_result, +) from app.crud.tts_evaluations.result import ( count_results_by_status, get_pending_results_for_run, @@ -34,14 +40,6 @@ logger = logging.getLogger(__name__) -# Terminal states that indicate batch processing is complete -TERMINAL_STATES = { - BatchJobState.SUCCEEDED.value, - BatchJobState.FAILED.value, - BatchJobState.CANCELLED.value, - BatchJobState.EXPIRED.value, -} - # Function path for Celery task dispatch _TTS_RESULT_PROCESSING_PATH = ( "app.services.tts_evaluations.batch_result_processing.execute_tts_result_processing" @@ -65,32 +63,17 @@ async def poll_all_pending_tts_evaluations( """ logger.info("[poll_all_pending_tts_evaluations] Starting TTS evaluation polling") - # Single query to fetch all processing TTS evaluation runs - statement = select(EvaluationRun).where( - EvaluationRun.type == EvaluationType.TTS.value, - EvaluationRun.status == "processing", - EvaluationRun.batch_job_id.is_not(None), - ) - 
pending_runs = session.exec(statement).all() + pending_runs = fetch_processing_runs(session, EvaluationType.TTS.value) if not pending_runs: logger.info("[poll_all_pending_tts_evaluations] No pending TTS runs found") - return { - "total": 0, - "processed": 0, - "failed": 0, - "still_processing": 0, - "details": [], - } + return make_empty_summary() logger.info( f"[poll_all_pending_tts_evaluations] Found {len(pending_runs)} pending TTS runs" ) - # Group evaluations by project_id since credentials are per project - evaluations_by_project: dict[int, list[EvaluationRun]] = defaultdict(list) - for run in pending_runs: - evaluations_by_project[run.project_id].append(run) + evaluations_by_project = group_runs_by_project(pending_runs) # Process each project separately all_results: list[dict[str, Any]] = [] @@ -120,15 +103,7 @@ async def poll_all_pending_tts_evaluations( status="failed", error_message=f"Gemini client initialization failed: {str(client_err)}", ) - all_results.append( - { - "run_id": run.id, - "run_name": run.run_name, - "type": "tts", - "action": "failed", - "error": str(client_err), - } - ) + all_results.append(make_failure_result(run, "tts", str(client_err))) total_failed += 1 continue @@ -163,15 +138,7 @@ async def poll_all_pending_tts_evaluations( status="failed", error_message=f"Polling failed: {str(e)}", ) - all_results.append( - { - "run_id": run.id, - "run_name": run.run_name, - "type": "tts", - "action": "failed", - "error": str(e), - } - ) + all_results.append(make_failure_result(run, "tts", str(e))) total_failed += 1 except Exception as e: @@ -188,13 +155,9 @@ async def poll_all_pending_tts_evaluations( error_message=f"Project processing failed: {str(e)}", ) all_results.append( - { - "run_id": run.id, - "run_name": run.run_name, - "type": "tts", - "action": "failed", - "error": f"Project processing failed: {str(e)}", - } + make_failure_result( + run, "tts", f"Project processing failed: {str(e)}" + ) ) total_failed += 1 @@ -215,26 +178,6 @@ async 
def poll_all_pending_tts_evaluations( return summary -def _get_batch_jobs_for_run( - session: Session, - run: EvaluationRun, -) -> list[BatchJob]: - """Find all batch jobs associated with a TTS evaluation run. - - Args: - session: Database session - run: The evaluation run - - Returns: - list[BatchJob]: All batch jobs for this run - """ - stmt = select(BatchJob).where( - BatchJob.job_type == "tts_evaluation", - BatchJob.config["evaluation_run_id"].astext.cast(Integer) == run.id, - ) - return list(session.exec(stmt).all()) - - def _dispatch_tts_result_processing( run: EvaluationRun, batch_job: BatchJob, @@ -296,7 +239,9 @@ async def poll_tts_run( previous_status = run.status - batch_jobs = _get_batch_jobs_for_run(session=session, run=run) + batch_jobs = get_batch_jobs_for_run( + session=session, run=run, job_type="tts_evaluation" + ) if not batch_jobs: logger.warning(f"[poll_tts_run] {log_prefix} No batch jobs found") diff --git a/backend/app/crud/tts_evaluations/dataset.py b/backend/app/crud/tts_evaluations/dataset.py index a2352073..e2956885 100644 --- a/backend/app/crud/tts_evaluations/dataset.py +++ b/backend/app/crud/tts_evaluations/dataset.py @@ -153,21 +153,6 @@ def list_tts_datasets( datasets = session.exec(statement).all() - result = [ - TTSDatasetPublic( - id=dataset.id, - name=dataset.name, - description=dataset.description, - type=dataset.type, - language_id=dataset.language_id, - object_store_url=dataset.object_store_url, - dataset_metadata=dataset.dataset_metadata, - organization_id=dataset.organization_id, - project_id=dataset.project_id, - inserted_at=dataset.inserted_at, - updated_at=dataset.updated_at, - ) - for dataset in datasets - ] + result = [TTSDatasetPublic.from_model(dataset) for dataset in datasets] return result, total diff --git a/backend/app/crud/tts_evaluations/result.py b/backend/app/crud/tts_evaluations/result.py index d21ac6e9..c2de66cf 100644 --- a/backend/app/crud/tts_evaluations/result.py +++ 
b/backend/app/crud/tts_evaluations/result.py @@ -104,22 +104,14 @@ def get_results_by_run_id( run_id: int, org_id: int, project_id: int, - provider: str | None = None, - status: str | None = None, - limit: int = 100, - offset: int = 0, ) -> tuple[list[TTSResultPublic], int]: - """Get results for an evaluation run. + """Get all results for an evaluation run. Args: session: Database session run_id: Run ID org_id: Organization ID project_id: Project ID - provider: Optional filter by provider - status: Optional filter by status - limit: Maximum results to return - offset: Number of results to skip Returns: tuple[list[TTSResultPublic], int]: Results and total count @@ -130,46 +122,12 @@ def get_results_by_run_id( TTSResult.project_id == project_id, ] - if provider is not None: - where_clauses.append(TTSResult.provider == provider) - - if status is not None: - where_clauses.append(TTSResult.status == status) - - count_stmt = select(func.count(TTSResult.id)).where(*where_clauses) - total = session.exec(count_stmt).one() - - statement = ( - select(TTSResult) - .where(*where_clauses) - .order_by(TTSResult.id) - .offset(offset) - .limit(limit) - ) + statement = select(TTSResult).where(*where_clauses).order_by(TTSResult.id) rows = session.exec(statement).all() + total = len(rows) - results = [ - TTSResultPublic( - id=result.id, - sample_text=result.sample_text, - object_store_url=result.object_store_url, - duration_seconds=(result.metadata_ or {}).get("duration_seconds"), - size_bytes=(result.metadata_ or {}).get("size_bytes"), - provider=result.provider, - status=result.status, - score=result.score, - is_correct=result.is_correct, - comment=result.comment, - error_message=result.error_message, - evaluation_run_id=result.evaluation_run_id, - organization_id=result.organization_id, - project_id=result.project_id, - inserted_at=result.inserted_at, - updated_at=result.updated_at, - ) - for result in rows - ] + results = [TTSResultPublic.from_model(result) for result in rows] 
return results, total diff --git a/backend/app/models/tts_evaluation.py b/backend/app/models/tts_evaluation.py index 48aa0a51..caa45ae7 100644 --- a/backend/app/models/tts_evaluation.py +++ b/backend/app/models/tts_evaluation.py @@ -173,6 +173,23 @@ class TTSDatasetPublic(BaseModel): inserted_at: datetime updated_at: datetime + @classmethod + def from_model(cls, dataset: "EvaluationDataset") -> "TTSDatasetPublic": + """Create from an EvaluationDataset model instance.""" + return cls( + id=dataset.id, + name=dataset.name, + description=dataset.description, + type=dataset.type, + language_id=dataset.language_id, + object_store_url=dataset.object_store_url, + dataset_metadata=dataset.dataset_metadata, + organization_id=dataset.organization_id, + project_id=dataset.project_id, + inserted_at=dataset.inserted_at, + updated_at=dataset.updated_at, + ) + class TTSResultPublic(BaseModel): """Public model for TTS results.""" @@ -194,6 +211,28 @@ class TTSResultPublic(BaseModel): inserted_at: datetime updated_at: datetime + @classmethod + def from_model(cls, result: "TTSResult") -> "TTSResultPublic": + """Create from a TTSResult model instance.""" + return cls( + id=result.id, + sample_text=result.sample_text, + object_store_url=result.object_store_url, + duration_seconds=(result.metadata_ or {}).get("duration_seconds"), + size_bytes=(result.metadata_ or {}).get("size_bytes"), + provider=result.provider, + status=result.status, + score=result.score, + is_correct=result.is_correct, + comment=result.comment, + error_message=result.error_message, + evaluation_run_id=result.evaluation_run_id, + organization_id=result.organization_id, + project_id=result.project_id, + inserted_at=result.inserted_at, + updated_at=result.updated_at, + ) + class TTSFeedbackUpdate(BaseModel): """Request model for updating human feedback on a TTS result."""