diff --git a/backend/app/alembic/versions/048_add_tts_evaluation_tables.py b/backend/app/alembic/versions/048_add_tts_evaluation_tables.py new file mode 100644 index 00000000..0b9ae126 --- /dev/null +++ b/backend/app/alembic/versions/048_add_tts_evaluation_tables.py @@ -0,0 +1,157 @@ +"""add tts evaluation tables + +Revision ID: 048 +Revises: 047 +Create Date: 2026-02-14 12:00:00.000000 + +""" + +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = "048" +down_revision = "047" +branch_labels = None +depends_on = None + + +def upgrade(): + # Create tts_result table + op.create_table( + "tts_result", + sa.Column( + "id", + sa.Integer(), + nullable=False, + comment="Unique identifier for the TTS result", + ), + sa.Column( + "sample_text", + sa.Text(), + nullable=False, + comment="Input text that was synthesized to speech", + ), + sa.Column( + "object_store_url", + sa.String(), + nullable=True, + comment="S3 URL of the generated WAV audio file", + ), + sa.Column( + "metadata", + postgresql.JSONB(astext_type=sa.Text()), + nullable=True, + comment="Audio metadata: {duration_seconds, size_bytes}", + ), + sa.Column( + "provider", + sa.String(length=100), + nullable=False, + comment="TTS provider used (e.g., gemini-2.5-pro-preview-tts)", + ), + sa.Column( + "status", + sa.String(length=20), + nullable=False, + server_default="PENDING", + comment="Result status: PENDING, SUCCESS, FAILED", + ), + sa.Column( + "score", + postgresql.JSONB(astext_type=sa.Text()), + nullable=True, + comment="Extensible evaluation metrics (null in Phase 1)", + ), + sa.Column( + "is_correct", + sa.Boolean(), + nullable=True, + comment="Human feedback: audio quality correctness (null=not reviewed)", + ), + sa.Column( + "comment", + sa.Text(), + nullable=True, + comment="Human feedback comment", + ), + sa.Column( + "error_message", + sa.Text(), + nullable=True, + comment="Error message if synthesis failed", 
+ ), + sa.Column( + "evaluation_run_id", + sa.Integer(), + nullable=False, + comment="Reference to the evaluation run", + ), + sa.Column( + "organization_id", + sa.Integer(), + nullable=False, + comment="Reference to the organization", + ), + sa.Column( + "project_id", + sa.Integer(), + nullable=False, + comment="Reference to the project", + ), + sa.Column( + "inserted_at", + sa.DateTime(), + nullable=False, + comment="Timestamp when the result was created", + ), + sa.Column( + "updated_at", + sa.DateTime(), + nullable=False, + comment="Timestamp when the result was last updated", + ), + sa.ForeignKeyConstraint( + ["evaluation_run_id"], + ["evaluation_run.id"], + name="fk_tts_result_run_id", + ondelete="CASCADE", + ), + sa.ForeignKeyConstraint( + ["organization_id"], + ["organization.id"], + ondelete="CASCADE", + ), + sa.ForeignKeyConstraint( + ["project_id"], + ["project.id"], + ondelete="CASCADE", + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index( + "ix_tts_result_run_id", + "tts_result", + ["evaluation_run_id"], + unique=False, + ) + op.create_index( + "idx_tts_result_feedback", + "tts_result", + ["evaluation_run_id", "is_correct"], + unique=False, + ) + op.create_index( + "idx_tts_result_status", + "tts_result", + ["evaluation_run_id", "status"], + unique=False, + ) + + +def downgrade(): + op.drop_index("idx_tts_result_status", table_name="tts_result") + op.drop_index("idx_tts_result_feedback", table_name="tts_result") + op.drop_index("ix_tts_result_run_id", table_name="tts_result") + op.drop_table("tts_result") diff --git a/backend/app/api/docs/tts_evaluation/create_dataset.md b/backend/app/api/docs/tts_evaluation/create_dataset.md new file mode 100644 index 00000000..7e20bfc2 --- /dev/null +++ b/backend/app/api/docs/tts_evaluation/create_dataset.md @@ -0,0 +1,4 @@ +Create a new TTS evaluation dataset with text samples. 
+ +Each sample requires: +- **text**: Text string to be synthesized into speech diff --git a/backend/app/api/docs/tts_evaluation/get_dataset.md b/backend/app/api/docs/tts_evaluation/get_dataset.md new file mode 100644 index 00000000..92531270 --- /dev/null +++ b/backend/app/api/docs/tts_evaluation/get_dataset.md @@ -0,0 +1,3 @@ +Get a TTS evaluation dataset by ID. + +Returns dataset metadata including sample count. diff --git a/backend/app/api/docs/tts_evaluation/get_result.md b/backend/app/api/docs/tts_evaluation/get_result.md new file mode 100644 index 00000000..0d89d90c --- /dev/null +++ b/backend/app/api/docs/tts_evaluation/get_result.md @@ -0,0 +1,3 @@ +Get a single TTS synthesis result by ID. + +Returns the result including audio URL, metadata, and human feedback status. diff --git a/backend/app/api/docs/tts_evaluation/get_run.md b/backend/app/api/docs/tts_evaluation/get_run.md new file mode 100644 index 00000000..66c141a9 --- /dev/null +++ b/backend/app/api/docs/tts_evaluation/get_run.md @@ -0,0 +1,6 @@ +Get a TTS evaluation run by ID with optional results. + +Use query parameters to control result inclusion and pagination: +- `include_results`: Include synthesis results (default: true) +- `result_limit` / `result_offset`: Paginate results +- `provider` / `status`: Filter results diff --git a/backend/app/api/docs/tts_evaluation/list_datasets.md b/backend/app/api/docs/tts_evaluation/list_datasets.md new file mode 100644 index 00000000..6fc2f6c1 --- /dev/null +++ b/backend/app/api/docs/tts_evaluation/list_datasets.md @@ -0,0 +1,3 @@ +List all TTS evaluation datasets for the current project. + +Supports pagination with `limit` and `offset` parameters. diff --git a/backend/app/api/docs/tts_evaluation/list_runs.md b/backend/app/api/docs/tts_evaluation/list_runs.md new file mode 100644 index 00000000..e0119a69 --- /dev/null +++ b/backend/app/api/docs/tts_evaluation/list_runs.md @@ -0,0 +1,3 @@ +List TTS evaluation runs for the current project. 
+ +Supports filtering by `dataset_id` and `status`, with pagination via `limit` and `offset`. diff --git a/backend/app/api/docs/tts_evaluation/start_evaluation.md b/backend/app/api/docs/tts_evaluation/start_evaluation.md new file mode 100644 index 00000000..543ed2de --- /dev/null +++ b/backend/app/api/docs/tts_evaluation/start_evaluation.md @@ -0,0 +1,8 @@ +Start a TTS evaluation run on a dataset. + +The evaluation will: +1. Process each text sample through the specified TTS providers +2. Generate speech audio using Gemini Batch API +3. Store WAV audio files in S3 for human review + +**Supported providers:** gemini-2.5-pro-preview-tts diff --git a/backend/app/api/docs/tts_evaluation/update_feedback.md b/backend/app/api/docs/tts_evaluation/update_feedback.md new file mode 100644 index 00000000..7701bb52 --- /dev/null +++ b/backend/app/api/docs/tts_evaluation/update_feedback.md @@ -0,0 +1,5 @@ +Update human feedback on a TTS synthesis result. + +Fields: +- **is_correct**: Whether the synthesized audio quality is acceptable (null to clear) +- **comment**: Optional feedback comment diff --git a/backend/app/api/routes/evaluations/__init__.py b/backend/app/api/routes/evaluations/__init__.py index 14a2d87d..c0aa7175 100644 --- a/backend/app/api/routes/evaluations/__init__.py +++ b/backend/app/api/routes/evaluations/__init__.py @@ -4,9 +4,11 @@ from app.api.routes.evaluations import dataset, evaluation from app.api.routes.stt_evaluations.router import router as stt_router +from app.api.routes.tts_evaluations.router import router as tts_router router = APIRouter() router.include_router(dataset.router) router.include_router(stt_router) +router.include_router(tts_router) router.include_router(evaluation.router) diff --git a/backend/app/api/routes/tts_evaluations/__init__.py b/backend/app/api/routes/tts_evaluations/__init__.py new file mode 100644 index 00000000..78c67295 --- /dev/null +++ b/backend/app/api/routes/tts_evaluations/__init__.py @@ -0,0 +1 @@ +"""TTS Evaluation 
API routes.""" diff --git a/backend/app/api/routes/tts_evaluations/dataset.py b/backend/app/api/routes/tts_evaluations/dataset.py new file mode 100644 index 00000000..9c185143 --- /dev/null +++ b/backend/app/api/routes/tts_evaluations/dataset.py @@ -0,0 +1,118 @@ +"""TTS dataset API routes.""" + +import logging + +from fastapi import APIRouter, Body, Depends, HTTPException, Query + +from app.api.deps import AuthContextDep, SessionDep +from app.api.permissions import Permission, require_permission +from app.crud.language import get_language_by_id +from app.crud.tts_evaluations import ( + get_tts_dataset_by_id, + list_tts_datasets, +) +from app.models.tts_evaluation import ( + TTSDatasetCreate, + TTSDatasetPublic, +) +from app.services.tts_evaluations.dataset import upload_tts_dataset +from app.utils import APIResponse, load_description + +logger = logging.getLogger(__name__) + +router = APIRouter() + + +@router.post( + "/datasets", + response_model=APIResponse[TTSDatasetPublic], + dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))], + summary="Create TTS dataset", + description=load_description("tts_evaluation/create_dataset.md"), +) +def create_dataset( + _session: SessionDep, + auth_context: AuthContextDep, + dataset_create: TTSDatasetCreate = Body(...), +) -> APIResponse[TTSDatasetPublic]: + """Create a TTS evaluation dataset.""" + # Validate language_id if provided + if dataset_create.language_id is not None: + language = get_language_by_id( + session=_session, language_id=dataset_create.language_id + ) + if not language: + raise HTTPException( + status_code=400, detail="Invalid language_id: language not found" + ) + + dataset = upload_tts_dataset( + session=_session, + name=dataset_create.name, + samples=dataset_create.samples, + organization_id=auth_context.organization_.id, + project_id=auth_context.project_.id, + description=dataset_create.description, + language_id=dataset_create.language_id, + ) + + return 
APIResponse.success_response(data=TTSDatasetPublic.from_model(dataset)) + + +@router.get( + "/datasets", + response_model=APIResponse[list[TTSDatasetPublic]], + dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))], + summary="List TTS datasets", + description=load_description("tts_evaluation/list_datasets.md"), +) +def list_datasets( + _session: SessionDep, + auth_context: AuthContextDep, + limit: int = Query(50, ge=1, le=100, description="Maximum results to return"), + offset: int = Query(0, ge=0, description="Number of results to skip"), +) -> APIResponse[list[TTSDatasetPublic]]: + """List TTS evaluation datasets.""" + datasets, total = list_tts_datasets( + session=_session, + org_id=auth_context.organization_.id, + project_id=auth_context.project_.id, + limit=limit, + offset=offset, + ) + + return APIResponse.success_response( + data=datasets, + metadata={"total": total, "limit": limit, "offset": offset}, + ) + + +@router.get( + "/datasets/{dataset_id}", + response_model=APIResponse[TTSDatasetPublic], + dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))], + summary="Get TTS dataset", + description=load_description("tts_evaluation/get_dataset.md"), +) +def get_dataset( + _session: SessionDep, + auth_context: AuthContextDep, + dataset_id: int, +) -> APIResponse[TTSDatasetPublic]: + """Get a TTS evaluation dataset.""" + dataset = get_tts_dataset_by_id( + session=_session, + dataset_id=dataset_id, + org_id=auth_context.organization_.id, + project_id=auth_context.project_.id, + ) + + if not dataset: + raise HTTPException(status_code=404, detail="Dataset not found") + + return APIResponse.success_response( + data=TTSDatasetPublic.from_model(dataset), + metadata={ + "sample_count": (dataset.dataset_metadata or {}).get("sample_count", 0) + }, + ) diff --git a/backend/app/api/routes/tts_evaluations/evaluation.py b/backend/app/api/routes/tts_evaluations/evaluation.py new file mode 100644 index 00000000..a1e4575d --- /dev/null +++ 
b/backend/app/api/routes/tts_evaluations/evaluation.py @@ -0,0 +1,290 @@ +"""TTS evaluation run API routes.""" + +import logging + +from fastapi import APIRouter, Body, Depends, HTTPException, Query + +from app.api.deps import AuthContextDep, SessionDep +from app.api.permissions import Permission, require_permission +from app.crud.tts_evaluations import ( + create_tts_run, + create_tts_results, + get_results_by_run_id, + get_tts_dataset_by_id, + get_tts_run_by_id, + list_tts_runs, + start_tts_evaluation_batch, + update_tts_run, +) +from app.models.tts_evaluation import ( + TTSEvaluationRunCreate, + TTSEvaluationRunPublic, + TTSEvaluationRunWithResults, + TTSSampleCreate, +) +from app.services.tts_evaluations.constants import ( + DEFAULT_STYLE_PROMPT, + DEFAULT_VOICE_NAME, +) +from app.services.tts_evaluations.dataset import parse_tts_samples_from_csv +from app.utils import APIResponse, load_description + +logger = logging.getLogger(__name__) + +router = APIRouter() + + +@router.post( + "/runs", + response_model=APIResponse[TTSEvaluationRunPublic], + dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))], + summary="Start TTS evaluation", + description=load_description("tts_evaluation/start_evaluation.md"), +) +def start_tts_evaluation( + _session: SessionDep, + auth_context: AuthContextDep, + run_create: TTSEvaluationRunCreate = Body(...), +) -> APIResponse[TTSEvaluationRunPublic]: + """Start a TTS evaluation run.""" + logger.info( + f"[start_tts_evaluation] Starting TTS evaluation | " + f"run_name: {run_create.run_name}, dataset_id: {run_create.dataset_id}, " + f"models: {run_create.models}" + ) + + # Validate dataset exists + dataset = get_tts_dataset_by_id( + session=_session, + dataset_id=run_create.dataset_id, + org_id=auth_context.organization_.id, + project_id=auth_context.project_.id, + ) + + if not dataset: + raise HTTPException(status_code=404, detail="Dataset not found") + + sample_count = (dataset.dataset_metadata or 
{}).get("sample_count", 0) + + if sample_count == 0: + raise HTTPException(status_code=400, detail="Dataset has no samples") + + # Get sample texts from the dataset CSV + sample_texts = _get_sample_texts_from_dataset( + _session, dataset, auth_context.project_.id + ) + + if not sample_texts: + raise HTTPException( + status_code=400, detail="Could not load samples from dataset" + ) + + language_id = dataset.language_id + + # Create run record + run = create_tts_run( + session=_session, + run_name=run_create.run_name, + dataset_id=run_create.dataset_id, + dataset_name=dataset.name, + org_id=auth_context.organization_.id, + project_id=auth_context.project_.id, + models=run_create.models, + language_id=language_id, + total_items=len(sample_texts) * len(run_create.models), + ) + + # Create result records for each sample text and model + results = create_tts_results( + session=_session, + sample_texts=sample_texts, + evaluation_run_id=run.id, + org_id=auth_context.organization_.id, + project_id=auth_context.project_.id, + models=run_create.models, + ) + + try: + batch_result = start_tts_evaluation_batch( + session=_session, + run=run, + results=results, + org_id=auth_context.organization_.id, + project_id=auth_context.project_.id, + ) + logger.info( + f"[start_tts_evaluation] TTS evaluation batch submitted | " + f"run_id: {run.id}, batch_jobs: {list(batch_result.get('batch_jobs', {}).keys())}" + ) + except Exception as e: + logger.error( + f"[start_tts_evaluation] Batch submission failed | " + f"run_id: {run.id}, error: {str(e)}" + ) + update_tts_run( + session=_session, + run_id=run.id, + status="failed", + error_message=str(e), + ) + raise HTTPException(status_code=500, detail=f"Batch submission failed: {e}") + + # Refresh run to get updated status + run = get_tts_run_by_id( + session=_session, + run_id=run.id, + org_id=auth_context.organization_.id, + project_id=auth_context.project_.id, + ) + + return APIResponse.success_response( + data=TTSEvaluationRunPublic( + 
id=run.id, + run_name=run.run_name, + dataset_name=run.dataset_name, + type=run.type, + language_id=run.language_id, + models=run.providers, + dataset_id=run.dataset_id, + status=run.status, + total_items=run.total_items, + score=run.score, + error_message=run.error_message, + run_metadata={ + "voice_name": DEFAULT_VOICE_NAME, + "style_prompt": DEFAULT_STYLE_PROMPT, + }, + organization_id=run.organization_id, + project_id=run.project_id, + inserted_at=run.inserted_at, + updated_at=run.updated_at, + ) + ) + + +def _get_sample_texts_from_dataset( + session: "SessionDep", + dataset: "EvaluationDataset", + project_id: int, +) -> list[str]: + """Extract sample texts from a TTS dataset's CSV in S3. + + Args: + session: Database session + dataset: The evaluation dataset record + project_id: Project ID + + Returns: + List of text strings + """ + if not dataset.object_store_url: + logger.warning( + f"[_get_sample_texts_from_dataset] No object_store_url | " + f"dataset_id={dataset.id}" + ) + return [] + + try: + from app.core.cloud.storage import get_cloud_storage + + storage = get_cloud_storage(session=session, project_id=project_id) + csv_bytes = storage.stream(dataset.object_store_url).read() + samples = parse_tts_samples_from_csv(csv_bytes) + return [s["text"] for s in samples] + except Exception as e: + logger.error( + f"[_get_sample_texts_from_dataset] Failed to load CSV | " + f"dataset_id={dataset.id}, error={str(e)}" + ) + return [] + + +@router.get( + "/runs", + response_model=APIResponse[list[TTSEvaluationRunPublic]], + dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))], + summary="List TTS evaluation runs", + description=load_description("tts_evaluation/list_runs.md"), +) +def list_tts_evaluation_runs( + _session: SessionDep, + auth_context: AuthContextDep, + dataset_id: int | None = Query(None, description="Filter by dataset ID"), + status: str | None = Query(None, description="Filter by status"), + limit: int = Query(50, ge=1, le=100, 
description="Maximum results to return"), + offset: int = Query(0, ge=0, description="Number of results to skip"), +) -> APIResponse[list[TTSEvaluationRunPublic]]: + """List TTS evaluation runs.""" + runs, total = list_tts_runs( + session=_session, + org_id=auth_context.organization_.id, + project_id=auth_context.project_.id, + dataset_id=dataset_id, + status=status, + limit=limit, + offset=offset, + ) + + return APIResponse.success_response( + data=runs, + metadata={"total": total, "limit": limit, "offset": offset}, + ) + + +@router.get( + "/runs/{run_id}", + response_model=APIResponse[TTSEvaluationRunWithResults], + dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))], + summary="Get TTS evaluation run", + description=load_description("tts_evaluation/get_run.md"), +) +def get_tts_evaluation_run( + _session: SessionDep, + auth_context: AuthContextDep, + run_id: int, + include_results: bool = Query(True, description="Include results in response"), +) -> APIResponse[TTSEvaluationRunWithResults]: + """Get a TTS evaluation run with results.""" + run = get_tts_run_by_id( + session=_session, + run_id=run_id, + org_id=auth_context.organization_.id, + project_id=auth_context.project_.id, + ) + + if not run: + raise HTTPException(status_code=404, detail="Evaluation run not found") + + results = [] + results_total = 0 + + if include_results: + results, results_total = get_results_by_run_id( + session=_session, + run_id=run_id, + org_id=auth_context.organization_.id, + project_id=auth_context.project_.id, + ) + + return APIResponse.success_response( + data=TTSEvaluationRunWithResults( + id=run.id, + run_name=run.run_name, + dataset_name=run.dataset_name, + type=run.type, + language_id=run.language_id, + models=run.providers, + dataset_id=run.dataset_id, + status=run.status, + total_items=run.total_items, + score=run.score, + error_message=run.error_message, + organization_id=run.organization_id, + project_id=run.project_id, + inserted_at=run.inserted_at, + 
updated_at=run.updated_at, + results=results, + results_total=results_total, + ), + metadata={"results_total": results_total}, + ) diff --git a/backend/app/api/routes/tts_evaluations/result.py b/backend/app/api/routes/tts_evaluations/result.py new file mode 100644 index 00000000..e4450b3c --- /dev/null +++ b/backend/app/api/routes/tts_evaluations/result.py @@ -0,0 +1,90 @@ +"""TTS result feedback API routes.""" + +import logging + +from fastapi import APIRouter, Body, Depends, HTTPException + +from app.api.deps import AuthContextDep, SessionDep +from app.api.permissions import Permission, require_permission +from app.crud.tts_evaluations import ( + get_tts_result_by_id, + update_tts_human_feedback, +) +from app.models.tts_evaluation import ( + TTSFeedbackUpdate, + TTSResultPublic, +) +from app.utils import APIResponse, load_description + +logger = logging.getLogger(__name__) + +router = APIRouter() + + +@router.patch( + "/results/{result_id}", + response_model=APIResponse[TTSResultPublic], + dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))], + summary="Update human feedback", + description=load_description("tts_evaluation/update_feedback.md"), +) +def update_result_feedback( + _session: SessionDep, + auth_context: AuthContextDep, + result_id: int, + feedback: TTSFeedbackUpdate = Body(...), +) -> APIResponse[TTSResultPublic]: + """Update human feedback on a TTS result.""" + logger.info( + f"[update_result_feedback] Updating feedback | " + f"result_id: {result_id}, is_correct: {feedback.is_correct}" + ) + + # Verify result exists and belongs to this project + existing = get_tts_result_by_id( + session=_session, + result_id=result_id, + org_id=auth_context.organization_.id, + project_id=auth_context.project_.id, + ) + + if not existing: + raise HTTPException(status_code=404, detail="Result not found") + + # Update feedback + result = update_tts_human_feedback( + session=_session, + result_id=result_id, + org_id=auth_context.organization_.id, + 
project_id=auth_context.project_.id, + is_correct=feedback.is_correct, + comment=feedback.comment, + ) + + return APIResponse.success_response(data=TTSResultPublic.from_model(result)) + + +@router.get( + "/results/{result_id}", + response_model=APIResponse[TTSResultPublic], + dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))], + summary="Get TTS result", + description=load_description("tts_evaluation/get_result.md"), +) +def get_result( + _session: SessionDep, + auth_context: AuthContextDep, + result_id: int, +) -> APIResponse[TTSResultPublic]: + """Get a TTS result by ID.""" + result = get_tts_result_by_id( + session=_session, + result_id=result_id, + org_id=auth_context.organization_.id, + project_id=auth_context.project_.id, + ) + + if not result: + raise HTTPException(status_code=404, detail="Result not found") + + return APIResponse.success_response(data=TTSResultPublic.from_model(result)) diff --git a/backend/app/api/routes/tts_evaluations/router.py b/backend/app/api/routes/tts_evaluations/router.py new file mode 100644 index 00000000..1a73ec30 --- /dev/null +++ b/backend/app/api/routes/tts_evaluations/router.py @@ -0,0 +1,12 @@ +"""Main router for TTS evaluation API routes.""" + +from fastapi import APIRouter + +from . 
import dataset, evaluation, result + +router = APIRouter(prefix="/evaluations/tts", tags=["TTS Evaluation"]) + +# Include all sub-routers +router.include_router(dataset.router) +router.include_router(evaluation.router) +router.include_router(result.router) diff --git a/backend/app/core/batch/__init__.py b/backend/app/core/batch/__init__.py index 05fcb03d..292bca0d 100644 --- a/backend/app/core/batch/__init__.py +++ b/backend/app/core/batch/__init__.py @@ -1,7 +1,13 @@ """Batch processing infrastructure for LLM providers.""" from .base import BATCH_KEY, BatchProvider -from .gemini import BatchJobState, GeminiBatchProvider, create_stt_batch_requests +from .gemini import ( + BatchJobState, + GeminiBatchProvider, + create_stt_batch_requests, + create_tts_batch_requests, + extract_text_from_response_dict, +) from .openai import OpenAIBatchProvider from .operations import ( download_batch_results, @@ -18,6 +24,8 @@ "GeminiBatchProvider", "OpenAIBatchProvider", "create_stt_batch_requests", + "create_tts_batch_requests", + "extract_text_from_response_dict", "start_batch_job", "download_batch_results", "process_completed_batch", diff --git a/backend/app/core/batch/gemini.py b/backend/app/core/batch/gemini.py index 69520ef0..e824e646 100644 --- a/backend/app/core/batch/gemini.py +++ b/backend/app/core/batch/gemini.py @@ -18,6 +18,29 @@ logger = logging.getLogger(__name__) +def extract_text_from_response_dict(response: dict[str, Any]) -> str: + """Extract text content from a Gemini response dictionary. 
+ + Args: + response: Gemini response as a dictionary + + Returns: + str: Extracted text + """ + # Try direct text field first + if "text" in response: + return response["text"] + + # Extract from candidates structure + text = "" + for candidate in response.get("candidates", []): + content = candidate.get("content", {}) + for part in content.get("parts", []): + if "text" in part: + text += part["text"] + return text + + class BatchJobState(str, Enum): """Gemini batch job states.""" @@ -207,11 +230,19 @@ def download_batch_results( results: list[dict[str, Any]] = [] # Handle file-based results (keys are included in each response line) - if ( + has_dest_file = ( batch_job.dest and hasattr(batch_job.dest, "file_name") and batch_job.dest.file_name - ): + ) + if not has_dest_file: + logger.warning( + f"[download_batch_results] No dest file found | " + f"batch_id={output_file_id} | " + f"dest={batch_job.dest} | " + f"dest_attrs={dir(batch_job.dest) if batch_job.dest else 'None'}" + ) + if has_dest_file: file_content = self.download_file(batch_job.dest.file_name) lines = file_content.strip().split("\n") for i, line in enumerate(lines): @@ -219,14 +250,14 @@ def download_batch_results( parsed = json.loads(line) custom_id = parsed.get("key", str(i)) - # Extract text from response + # Return the raw response so callers can extract + # text (STT) or audio (TTS) as needed. response_obj = parsed.get("response") if response_obj: - text = self._extract_text_from_response_dict(response_obj) results.append( { BATCH_KEY: custom_id, - "response": {"text": text}, + "response": response_obj, "error": None, } ) @@ -261,26 +292,8 @@ def download_batch_results( @staticmethod def _extract_text_from_response_dict(response: dict[str, Any]) -> str: - """Extract text content from a Gemini response dictionary. 
- - Args: - response: Gemini response as a dictionary - - Returns: - str: Extracted text - """ - # Try direct text field first - if "text" in response: - return response["text"] - - # Extract from candidates structure - text = "" - for candidate in response.get("candidates", []): - content = candidate.get("content", {}) - for part in content.get("parts", []): - if "text" in part: - text += part["text"] - return text + """Extract text content from a Gemini response dictionary.""" + return extract_text_from_response_dict(response) def upload_file(self, content: str, purpose: str = "batch") -> str: """Upload a JSONL file to Gemini Files API. @@ -441,3 +454,72 @@ def create_stt_batch_requests( logger.info(f"[create_stt_batch_requests] Created {len(requests)} batch requests") return requests + + +def create_tts_batch_requests( + texts: list[str], + voice_name: str, + style_prompt: str | None = None, + keys: list[str] | None = None, +) -> list[dict[str, Any]]: + """Create batch API requests for Gemini TTS. + + Generates request payloads in Gemini's JSONL batch format for + text-to-speech synthesis with audio output. + + Args: + texts: List of text strings to synthesize + voice_name: Prebuilt voice name (e.g., "Kore") + style_prompt: Optional style/tone instructions prepended to text + keys: Optional list of custom IDs for tracking results. If not provided, + uses 0-indexed integers as strings. + + Returns: + List of batch request dicts in Gemini JSONL format + + Example: + >>> texts = ["Hello, how can I help you?"] + >>> requests = create_tts_batch_requests( + ... texts, voice_name="Kore", + ... style_prompt="Read in a calm tone", + ... keys=["result-1"] + ... 
) + """ + if keys is not None and len(keys) != len(texts): + raise ValueError( + f"Length of keys ({len(keys)}) must match texts ({len(texts)})" + ) + + requests = [] + for i, text in enumerate(texts): + key = keys[i] if keys is not None else str(i) + + # Prepend style prompt if provided + content_text = f"{style_prompt}: {text}" if style_prompt else text + + request = { + "key": key, + "request": { + "contents": [ + { + "parts": [{"text": content_text}], + "role": "user", + } + ], + "generationConfig": { + "responseModalities": ["AUDIO"], + "speechConfig": { + "voiceConfig": { + "prebuiltVoiceConfig": { + "voiceName": voice_name, + } + } + }, + }, + }, + } + requests.append(request) + + logger.info(f"[create_tts_batch_requests] Created {len(requests)} batch requests") + + return requests diff --git a/backend/app/crud/evaluations/cron.py b/backend/app/crud/evaluations/cron.py index e8393736..61fd64c4 100644 --- a/backend/app/crud/evaluations/cron.py +++ b/backend/app/crud/evaluations/cron.py @@ -13,7 +13,6 @@ from sqlmodel import Session from app.crud.evaluations.processing import poll_all_pending_evaluations -from app.crud.stt_evaluations import poll_all_pending_stt_evaluations logger = logging.getLogger(__name__) @@ -24,7 +23,7 @@ async def process_all_pending_evaluations(session: Session) -> dict[str, Any]: Delegates to poll_all_pending_evaluations which fetches all processing evaluation runs in a single query, groups by project, and processes them. - Also polls STT evaluations similarly. + Also polls STT and TTS evaluations similarly. 
Args: session: Database session @@ -38,16 +37,35 @@ async def process_all_pending_evaluations(session: Session) -> dict[str, Any]: # Poll text evaluations (single query, grouped by project) text_summary = await poll_all_pending_evaluations(session=session) + # Lazy imports to avoid circular dependency with cron_utils + from app.crud.stt_evaluations import poll_all_pending_stt_evaluations + from app.crud.tts_evaluations import poll_all_pending_tts_evaluations + # Poll STT evaluations (single query, grouped by project) stt_summary = await poll_all_pending_stt_evaluations(session=session) + # Poll TTS evaluations (single query, grouped by project) + tts_summary = await poll_all_pending_tts_evaluations(session=session) + # Merge summaries - total_processed = text_summary["processed"] + stt_summary["processed"] - total_failed = text_summary["failed"] + stt_summary["failed"] + total_processed = ( + text_summary["processed"] + + stt_summary["processed"] + + tts_summary["processed"] + ) + total_failed = ( + text_summary["failed"] + stt_summary["failed"] + tts_summary["failed"] + ) total_still_processing = ( - text_summary["still_processing"] + stt_summary["still_processing"] + text_summary["still_processing"] + + stt_summary["still_processing"] + + tts_summary["still_processing"] + ) + all_details = ( + text_summary.get("details", []) + + stt_summary.get("details", []) + + tts_summary.get("details", []) ) - all_details = text_summary.get("details", []) + stt_summary.get("details", []) logger.info( f"[process_all_pending_evaluations] Completed: " diff --git a/backend/app/crud/evaluations/cron_utils.py b/backend/app/crud/evaluations/cron_utils.py new file mode 100644 index 00000000..63e4639a --- /dev/null +++ b/backend/app/crud/evaluations/cron_utils.py @@ -0,0 +1,117 @@ +"""Shared utilities for evaluation cron processing. + +Common constants, queries, and helpers used by both STT and TTS +evaluation polling loops. 
+""" + +from collections import defaultdict + +from sqlalchemy import Integer +from sqlmodel import Session, select + +from app.core.batch import BatchJobState +from app.models import EvaluationRun +from app.models.batch_job import BatchJob + +# Terminal states that indicate batch processing is complete +TERMINAL_STATES = { + BatchJobState.SUCCEEDED.value, + BatchJobState.FAILED.value, + BatchJobState.CANCELLED.value, + BatchJobState.EXPIRED.value, +} + + +def fetch_processing_runs( + session: Session, + eval_type: str, +) -> list[EvaluationRun]: + """Fetch all evaluation runs with status='processing' for a given type. + + Args: + session: Database session + eval_type: Evaluation type value (e.g. EvaluationType.STT.value) + + Returns: + list[EvaluationRun]: Runs currently processing + """ + statement = select(EvaluationRun).where( + EvaluationRun.type == eval_type, + EvaluationRun.status == "processing", + EvaluationRun.batch_job_id.is_not(None), + ) + return list(session.exec(statement).all()) + + +def group_runs_by_project( + runs: list[EvaluationRun], +) -> dict[int, list[EvaluationRun]]: + """Group evaluation runs by project_id. + + Args: + runs: List of evaluation runs + + Returns: + dict mapping project_id to list of runs + """ + by_project: dict[int, list[EvaluationRun]] = defaultdict(list) + for run in runs: + by_project[run.project_id].append(run) + return by_project + + +def get_batch_jobs_for_run( + session: Session, + run: EvaluationRun, + job_type: str, +) -> list[BatchJob]: + """Find all batch jobs associated with an evaluation run. + + Args: + session: Database session + run: The evaluation run + job_type: Batch job type (e.g. 
"stt_evaluation", "tts_evaluation") + + Returns: + list[BatchJob]: All batch jobs for this run + """ + stmt = select(BatchJob).where( + BatchJob.job_type == job_type, + BatchJob.config["evaluation_run_id"].astext.cast(Integer) == run.id, + ) + return list(session.exec(stmt).all()) + + +def make_empty_summary() -> dict: + """Return an empty polling summary.""" + return { + "total": 0, + "processed": 0, + "failed": 0, + "still_processing": 0, + "details": [], + } + + +def make_failure_result( + run: EvaluationRun, + eval_type: str, + error: str, +) -> dict: + """Build a failure result dict for a run. + + Args: + run: The evaluation run + eval_type: Short type label ("stt" or "tts") + error: Error message + + Returns: + dict with run_id, run_name, type, action, and error + """ + return { + "run_id": run.id, + "run_name": run.run_name, + "type": eval_type, + "action": "failed", + "error": error, + } diff --git a/backend/app/crud/stt_evaluations/cron.py b/backend/app/crud/stt_evaluations/cron.py index ceaf4a2a..e9d30a0e 100644 --- a/backend/app/crud/stt_evaluations/cron.py +++ b/backend/app/crud/stt_evaluations/cron.py @@ -8,15 +8,26 @@ """ import logging -from collections import defaultdict from typing import Any -from sqlalchemy import Integer -from sqlmodel import Session, select +from sqlmodel import Session -from app.core.batch import BatchJobState, GeminiBatchProvider, poll_batch_status -from app.core.batch.base import BATCH_KEY +from app.core.batch import ( + BATCH_KEY, + BatchJobState, + GeminiBatchProvider, + extract_text_from_response_dict, + poll_batch_status, +) from app.core.util import now +from app.crud.evaluations.cron_utils import ( + TERMINAL_STATES, + fetch_processing_runs, + get_batch_jobs_for_run, + group_runs_by_project, + make_empty_summary, + make_failure_result, +) from app.crud.stt_evaluations.result import count_results_by_status from app.crud.stt_evaluations.run import update_stt_run from app.models import EvaluationRun @@ -27,14 +38,6 @@ 
logger = logging.getLogger(__name__) -# Terminal states that indicate batch processing is complete -TERMINAL_STATES = { - BatchJobState.SUCCEEDED.value, - BatchJobState.FAILED.value, - BatchJobState.CANCELLED.value, - BatchJobState.EXPIRED.value, -} - async def poll_all_pending_stt_evaluations( session: Session, @@ -60,32 +63,17 @@ async def poll_all_pending_stt_evaluations( """ logger.info("[poll_all_pending_stt_evaluations] Starting STT evaluation polling") - # Single query to fetch all processing STT evaluation runs - statement = select(EvaluationRun).where( - EvaluationRun.type == EvaluationType.STT.value, - EvaluationRun.status == "processing", - EvaluationRun.batch_job_id.is_not(None), - ) - pending_runs = session.exec(statement).all() + pending_runs = fetch_processing_runs(session, EvaluationType.STT.value) if not pending_runs: logger.info("[poll_all_pending_stt_evaluations] No pending STT runs found") - return { - "total": 0, - "processed": 0, - "failed": 0, - "still_processing": 0, - "details": [], - } + return make_empty_summary() logger.info( f"[poll_all_pending_stt_evaluations] Found {len(pending_runs)} pending STT runs" ) - # Group evaluations by project_id since credentials are per project - evaluations_by_project: dict[int, list[EvaluationRun]] = defaultdict(list) - for run in pending_runs: - evaluations_by_project[run.project_id].append(run) + evaluations_by_project = group_runs_by_project(pending_runs) # Process each project separately all_results: list[dict[str, Any]] = [] @@ -135,15 +123,7 @@ async def poll_all_pending_stt_evaluations( status="failed", error_message=f"Polling failed: {str(e)}", ) - all_results.append( - { - "run_id": run.id, - "run_name": run.run_name, - "type": "stt", - "action": "failed", - "error": str(e), - } - ) + all_results.append(make_failure_result(run, "stt", str(e))) total_failed += 1 except Exception as e: @@ -160,13 +140,9 @@ async def poll_all_pending_stt_evaluations( error_message=f"Project processing failed: 
{str(e)}", ) all_results.append( - { - "run_id": run.id, - "run_name": run.run_name, - "type": "stt", - "action": "failed", - "error": f"Project processing failed: {str(e)}", - } + make_failure_result( + run, "stt", f"Project processing failed: {str(e)}" + ) ) total_failed += len(project_runs) @@ -187,28 +163,6 @@ async def poll_all_pending_stt_evaluations( return summary -def _get_batch_jobs_for_run( - session: Session, - run: EvaluationRun, -) -> list[BatchJob]: - """Find all batch jobs associated with an STT evaluation run. - - Queries batch_job table where config contains the evaluation_run_id. - - Args: - session: Database session - run: The evaluation run - - Returns: - list[BatchJob]: All batch jobs for this run - """ - stmt = select(BatchJob).where( - BatchJob.job_type == "stt_evaluation", - BatchJob.config["evaluation_run_id"].astext.cast(Integer) == run.id, - ) - return list(session.exec(stmt).all()) - - async def poll_stt_run( session: Session, run: EvaluationRun, @@ -237,7 +191,9 @@ async def poll_stt_run( previous_status = run.status # Find all batch jobs for this run - batch_jobs = _get_batch_jobs_for_run(session=session, run=run) + batch_jobs = get_batch_jobs_for_run( + session=session, run=run, job_type="stt_evaluation" + ) if not batch_jobs: logger.warning(f"[poll_stt_run] No batch jobs found | run_id: {run.id}") @@ -355,7 +311,7 @@ async def poll_stt_run( async def process_completed_stt_batch( session: Session, run: EvaluationRun, - batch_job: Any, + batch_job: BatchJob, batch_provider: GeminiBatchProvider, ) -> None: """Process completed Gemini batch - download results and create STT result records. 
@@ -415,7 +371,9 @@ async def process_completed_stt_batch( } if response.get("response"): - row["transcription"] = response["response"].get("text", "") + row["transcription"] = extract_text_from_response_dict( + response["response"] + ) row["status"] = JobStatus.SUCCESS.value success_count += 1 else: diff --git a/backend/app/crud/tts_evaluations/__init__.py b/backend/app/crud/tts_evaluations/__init__.py new file mode 100644 index 00000000..fcced112 --- /dev/null +++ b/backend/app/crud/tts_evaluations/__init__.py @@ -0,0 +1,42 @@ +"""TTS Evaluation CRUD operations.""" + +from .batch import start_tts_evaluation_batch +from .cron import poll_all_pending_tts_evaluations +from .dataset import ( + create_tts_dataset, + get_tts_dataset_by_id, + list_tts_datasets, +) +from .run import ( + create_tts_run, + get_tts_run_by_id, + list_tts_runs, + update_tts_run, +) +from .result import ( + create_tts_results, + get_tts_result_by_id, + get_results_by_run_id, + update_tts_human_feedback, +) + +__all__ = [ + # Batch + "start_tts_evaluation_batch", + # Cron + "poll_all_pending_tts_evaluations", + # Dataset + "create_tts_dataset", + "get_tts_dataset_by_id", + "list_tts_datasets", + # Run + "create_tts_run", + "get_tts_run_by_id", + "list_tts_runs", + "update_tts_run", + # Result + "create_tts_results", + "get_tts_result_by_id", + "get_results_by_run_id", + "update_tts_human_feedback", +] diff --git a/backend/app/crud/tts_evaluations/batch.py b/backend/app/crud/tts_evaluations/batch.py new file mode 100644 index 00000000..37c5a63c --- /dev/null +++ b/backend/app/crud/tts_evaluations/batch.py @@ -0,0 +1,174 @@ +"""Batch submission functions for TTS evaluation processing.""" + +import logging +from typing import Any + +from sqlmodel import Session + +from app.core.batch import ( + GeminiBatchProvider, + create_tts_batch_requests, + start_batch_job, +) +from app.crud.tts_evaluations.result import ( + get_pending_results_for_run, + update_tts_result, +) +from 
app.crud.tts_evaluations.run import update_tts_run
+from app.models import EvaluationRun
+from app.models.job import JobStatus
+from app.models.tts_evaluation import TTSResult
+from app.services.stt_evaluations.gemini import GeminiClient
+from app.services.tts_evaluations.constants import (
+    DEFAULT_STYLE_PROMPT,
+    DEFAULT_TTS_MODEL,
+    DEFAULT_VOICE_NAME,
+)
+
+logger = logging.getLogger(__name__)
+
+
+def start_tts_evaluation_batch(
+    *,
+    session: Session,
+    run: EvaluationRun,
+    results: list[TTSResult],
+    org_id: int,
+    project_id: int,
+) -> dict[str, Any]:
+    """Submit Gemini batch jobs for TTS evaluation.
+
+    Submits one batch job per model. Each batch job is tracked via
+    its config containing evaluation_run_id and tts_provider.
+
+    Args:
+        session: Database session
+        run: The evaluation run record
+        results: List of TTSResult records (contains sample_text)
+        org_id: Organization ID
+        project_id: Project ID
+
+    Returns:
+        dict: Result with batch job information per model
+
+    Raises:
+        Exception: If batch submission fails for all models
+    """
+    models = run.providers or [DEFAULT_TTS_MODEL]
+
+    logger.info(
+        f"[start_tts_evaluation_batch] Starting batch submission | "
+        f"run_id: {run.id}, result_count: {len(results)}, "
+        f"models: {models}"
+    )
+
+    # Initialize Gemini client
+    gemini_client = GeminiClient.from_credentials(
+        session=session,
+        org_id=org_id,
+        project_id=project_id,
+    )
+
+    # Group results by model so each model gets its own batch request;
+    # each result supplies its sample text and its ID (used as the request key).
+    results_by_model: dict[str, list[TTSResult]] = {}
+    for result in results:
+        results_by_model.setdefault(result.provider, []).append(result)
+
+    # Submit one batch job per model
+    batch_jobs: dict[str, Any] = {}
+    first_batch_job_id: int | None = None
+
+    for model in models:
+        model_results = results_by_model.get(model, [])
+        if not model_results:
+            continue
+
+        texts = [r.sample_text for r in model_results]
+        keys = [str(r.id) for r in 
model_results] + + # Create JSONL batch requests for TTS + jsonl_data = create_tts_batch_requests( + texts=texts, + voice_name=DEFAULT_VOICE_NAME, + style_prompt=DEFAULT_STYLE_PROMPT, + keys=keys, + ) + + model_path = f"models/{model}" + batch_provider = GeminiBatchProvider( + client=gemini_client.client, model=model_path + ) + + try: + batch_job = start_batch_job( + session=session, + provider=batch_provider, + provider_name="google", + job_type="tts_evaluation", + organization_id=org_id, + project_id=project_id, + jsonl_data=jsonl_data, + config={ + "model": model, + "tts_provider": model, + "evaluation_run_id": run.id, + "voice_name": DEFAULT_VOICE_NAME, + "style_prompt": DEFAULT_STYLE_PROMPT, + }, + ) + + batch_jobs[model] = { + "batch_job_id": batch_job.id, + "provider_batch_id": batch_job.provider_batch_id, + } + + if first_batch_job_id is None: + first_batch_job_id = batch_job.id + + logger.info( + f"[start_tts_evaluation_batch] Batch job created | " + f"run_id: {run.id}, model: {model}, " + f"batch_job_id: {batch_job.id}" + ) + + except Exception as e: + logger.error( + f"[start_tts_evaluation_batch] Failed to submit batch | " + f"model: {model}, error: {str(e)}" + ) + pending = get_pending_results_for_run( + session=session, run_id=run.id, provider=model + ) + for result in pending: + update_tts_result( + session=session, + result_id=result.id, + status=JobStatus.FAILED.value, + error_message=f"Batch submission failed for {model}: {str(e)}", + ) + session.commit() + + if not batch_jobs: + raise Exception("Batch submission failed for all models") + + # Link first batch job to the evaluation run (for pending run detection) + update_tts_run( + session=session, + run_id=run.id, + status="processing", + batch_job_id=first_batch_job_id, + ) + + logger.info( + f"[start_tts_evaluation_batch] Batch submission complete | " + f"run_id: {run.id}, models_submitted: {list(batch_jobs.keys())}, " + f"result_count: {len(results)}" + ) + + return { + "success": True, + 
"run_id": run.id, + "batch_jobs": batch_jobs, + "result_count": len(results), + } diff --git a/backend/app/crud/tts_evaluations/cron.py b/backend/app/crud/tts_evaluations/cron.py new file mode 100644 index 00000000..545536fb --- /dev/null +++ b/backend/app/crud/tts_evaluations/cron.py @@ -0,0 +1,389 @@ +"""Cron processing functions for TTS evaluations. + +This module provides functions that are called periodically to process +pending TTS evaluations - polling batch status and dispatching result +processing to Celery workers. + +Follows the same pattern as STT evaluations: single query to fetch all +processing runs, grouped by project_id for credential management. +""" + +import logging +from typing import Any + +from sqlmodel import Session + +from app.celery.utils import start_low_priority_job +from app.core.batch import ( + BatchJobState, + GeminiBatchProvider, + poll_batch_status, +) +from app.crud.evaluations.cron_utils import ( + TERMINAL_STATES, + fetch_processing_runs, + get_batch_jobs_for_run, + group_runs_by_project, + make_empty_summary, + make_failure_result, +) +from app.crud.tts_evaluations.result import ( + count_results_by_status, + get_pending_results_for_run, +) +from app.crud.tts_evaluations.run import update_tts_run +from app.models import EvaluationRun +from app.models.batch_job import BatchJob +from app.models.job import JobStatus +from app.models.stt_evaluation import EvaluationType +from app.services.stt_evaluations.gemini import GeminiClient + +logger = logging.getLogger(__name__) + +# Function path for Celery task dispatch +_TTS_RESULT_PROCESSING_PATH = ( + "app.services.tts_evaluations.batch_result_processing.execute_tts_result_processing" +) + + +async def poll_all_pending_tts_evaluations( + session: Session, +) -> dict[str, Any]: + """Poll all pending TTS evaluations across all organizations. 
+ + Fetches all TTS evaluation runs with status='processing' in a single query, + groups them by project_id, and processes each project with its own + Gemini client. + + Args: + session: Database session + + Returns: + Summary dict with total, processed, failed, still_processing counts + """ + logger.info("[poll_all_pending_tts_evaluations] Starting TTS evaluation polling") + + pending_runs = fetch_processing_runs(session, EvaluationType.TTS.value) + + if not pending_runs: + logger.info("[poll_all_pending_tts_evaluations] No pending TTS runs found") + return make_empty_summary() + + logger.info( + f"[poll_all_pending_tts_evaluations] Found {len(pending_runs)} pending TTS runs" + ) + + evaluations_by_project = group_runs_by_project(pending_runs) + + # Process each project separately + all_results: list[dict[str, Any]] = [] + total_processed = 0 + total_failed = 0 + total_still_processing = 0 + + for project_id, project_runs in evaluations_by_project.items(): + org_id = project_runs[0].organization_id + + try: + try: + gemini_client = GeminiClient.from_credentials( + session=session, + org_id=org_id, + project_id=project_id, + ) + except Exception as client_err: + logger.error( + f"[poll_all_pending_tts_evaluations] Failed to get Gemini client | " + f"org_id={org_id} | project_id={project_id} | error={client_err}" + ) + for run in project_runs: + update_tts_run( + session=session, + run_id=run.id, + status="failed", + error_message=f"Gemini client initialization failed: {str(client_err)}", + ) + all_results.append(make_failure_result(run, "tts", str(client_err))) + total_failed += 1 + continue + + batch_provider = GeminiBatchProvider(client=gemini_client.client) + + for run in project_runs: + try: + result = await poll_tts_run( + session=session, + run=run, + batch_provider=batch_provider, + org_id=org_id, + ) + all_results.append(result) + + if result["action"] in ("completed", "processed", "dispatched"): + total_processed += 1 + elif result["action"] == "failed": + 
total_failed += 1 + else: + total_still_processing += 1 + + except Exception as e: + logger.error( + f"[poll_all_pending_tts_evaluations] Failed to poll TTS run | " + f"run_id={run.id} | {e}", + exc_info=True, + ) + update_tts_run( + session=session, + run_id=run.id, + status="failed", + error_message=f"Polling failed: {str(e)}", + ) + all_results.append(make_failure_result(run, "tts", str(e))) + total_failed += 1 + + except Exception as e: + logger.error( + f"[poll_all_pending_tts_evaluations] Failed to process project | " + f"project_id={project_id} | {e}", + exc_info=True, + ) + for run in project_runs: + update_tts_run( + session=session, + run_id=run.id, + status="failed", + error_message=f"Project processing failed: {str(e)}", + ) + all_results.append( + make_failure_result( + run, "tts", f"Project processing failed: {str(e)}" + ) + ) + total_failed += 1 + + summary = { + "total": len(pending_runs), + "processed": total_processed, + "failed": total_failed, + "still_processing": total_still_processing, + "details": all_results, + } + + logger.info( + f"[poll_all_pending_tts_evaluations] Polling summary | " + f"processed={total_processed} | failed={total_failed} | " + f"still_processing={total_still_processing}" + ) + + return summary + + +def _dispatch_tts_result_processing( + run: EvaluationRun, + batch_job: BatchJob, + org_id: int, + provider_name: str, +) -> str: + """Dispatch TTS result processing to Celery low priority queue. 
+ + Args: + run: The evaluation run + batch_job: The batch job record + org_id: Organization ID + provider_name: TTS provider/model name + + Returns: + str: Celery task ID + """ + celery_task_id = start_low_priority_job( + function_path=_TTS_RESULT_PROCESSING_PATH, + project_id=run.project_id, + job_id=str(batch_job.id), + organization_id=org_id, + evaluation_run_id=run.id, + tts_provider=provider_name, + provider_batch_id=batch_job.provider_batch_id, + ) + + logger.info( + f"[_dispatch_tts_result_processing] Dispatched to Celery | " + f"run_id={run.id}, batch_job_id={batch_job.id}, " + f"provider={provider_name}, celery_task_id={celery_task_id}" + ) + + return celery_task_id + + +async def poll_tts_run( + session: Session, + run: EvaluationRun, + batch_provider: GeminiBatchProvider, + org_id: int, +) -> dict[str, Any]: + """Poll a single TTS evaluation run's batch status. + + Finds all batch jobs for this run (one per provider) and polls each. + When a batch reaches SUCCEEDED, dispatches result processing to Celery. 
+ + Args: + session: Database session + run: The evaluation run to poll + batch_provider: Initialized GeminiBatchProvider + org_id: Organization ID + + Returns: + dict: Status result with run details and action taken + """ + log_prefix = f"[org={org_id}][project={run.project_id}][eval={run.id}]" + logger.info(f"[poll_tts_run] {log_prefix} Polling run") + + previous_status = run.status + + batch_jobs = get_batch_jobs_for_run( + session=session, run=run, job_type="tts_evaluation" + ) + + if not batch_jobs: + logger.warning(f"[poll_tts_run] {log_prefix} No batch jobs found") + update_tts_run( + session=session, + run_id=run.id, + status="failed", + error_message="No batch jobs found", + ) + return { + "run_id": run.id, + "run_name": run.run_name, + "type": "tts", + "previous_status": previous_status, + "current_status": "failed", + "action": "failed", + "error": "No batch jobs found", + } + + all_terminal = True + any_succeeded = False + any_failed = False + dispatched = False + errors: list[str] = [] + + for batch_job in batch_jobs: + provider_name = batch_job.config.get("tts_provider", "unknown") + + # Handle batch jobs already in terminal state + if batch_job.provider_status in TERMINAL_STATES: + if batch_job.provider_status == BatchJobState.SUCCEEDED.value: + # Check if there are still unprocessed results for this batch. + # This handles retries when a previous processing attempt failed. 
+ pending_results = get_pending_results_for_run( + session=session, run_id=run.id, provider=provider_name + ) + if pending_results: + logger.info( + f"[poll_tts_run] {log_prefix} Dispatching reprocessing for " + f"{len(pending_results)} pending results | " + f"batch_job_id={batch_job.id}" + ) + _dispatch_tts_result_processing( + run=run, + batch_job=batch_job, + org_id=org_id, + provider_name=provider_name, + ) + dispatched = True + any_succeeded = True + else: + any_failed = True + errors.append( + f"{provider_name}: {batch_job.error_message or batch_job.provider_status}" + ) + continue + + # Poll batch job status + poll_batch_status( + session=session, + provider=batch_provider, + batch_job=batch_job, + ) + + session.refresh(batch_job) + provider_status = batch_job.provider_status + + logger.info( + f"[poll_tts_run] {log_prefix} Batch status | " + f"batch_job_id={batch_job.id} | provider={provider_name} | " + f"state={provider_status}" + ) + + if provider_status not in TERMINAL_STATES: + all_terminal = False + continue + + # Batch reached terminal state - dispatch processing to Celery + if provider_status == BatchJobState.SUCCEEDED.value: + _dispatch_tts_result_processing( + run=run, + batch_job=batch_job, + org_id=org_id, + provider_name=provider_name, + ) + any_succeeded = True + dispatched = True + else: + any_failed = True + errors.append( + f"{provider_name}: {batch_job.error_message or provider_status}" + ) + + if not all_terminal: + return { + "run_id": run.id, + "run_name": run.run_name, + "type": "tts", + "previous_status": previous_status, + "current_status": run.status, + "action": "no_change", + } + + # If we dispatched processing to Celery, keep the run as "processing". + # The Celery task will finalize the run status when done. 
+ if dispatched: + return { + "run_id": run.id, + "run_name": run.run_name, + "type": "tts", + "previous_status": previous_status, + "current_status": "processing", + "action": "dispatched", + } + + # All batch jobs are done and no dispatching needed - finalize the run + status_counts = count_results_by_status(session=session, run_id=run.id) + pending = status_counts.get(JobStatus.PENDING.value, 0) + failed_count = status_counts.get(JobStatus.FAILED.value, 0) + + final_status = "completed" if pending == 0 else "processing" + error_message = None + if any_failed: + error_message = "; ".join(errors) + elif failed_count > 0: + error_message = f"{failed_count} synthesis(es) failed" + + update_tts_run( + session=session, + run_id=run.id, + status=final_status, + error_message=error_message, + ) + + action = "completed" if not any_failed else "failed" + + return { + "run_id": run.id, + "run_name": run.run_name, + "type": "tts", + "previous_status": previous_status, + "current_status": final_status, + "action": action, + **({"error": error_message} if error_message else {}), + } diff --git a/backend/app/crud/tts_evaluations/dataset.py b/backend/app/crud/tts_evaluations/dataset.py new file mode 100644 index 00000000..e2956885 --- /dev/null +++ b/backend/app/crud/tts_evaluations/dataset.py @@ -0,0 +1,158 @@ +"""CRUD operations for TTS evaluation datasets.""" + +import logging +from typing import Any + +from sqlalchemy.exc import IntegrityError +from sqlmodel import Session, func, select + +from app.core.exception_handlers import HTTPException +from app.core.util import now +from app.models import EvaluationDataset +from app.models.stt_evaluation import EvaluationType +from app.models.tts_evaluation import TTSDatasetPublic + +logger = logging.getLogger(__name__) + + +def create_tts_dataset( + *, + session: Session, + name: str, + org_id: int, + project_id: int, + description: str | None = None, + language_id: int | None = None, + object_store_url: str | None = None, + 
dataset_metadata: dict[str, Any] | None = None, +) -> EvaluationDataset: + """Create a new TTS evaluation dataset. + + Args: + session: Database session + name: Dataset name + org_id: Organization ID + project_id: Project ID + description: Optional description + language_id: Optional reference to global.languages table + object_store_url: Optional object store URL + dataset_metadata: Optional metadata dict + + Returns: + EvaluationDataset: Created dataset + + Raises: + HTTPException: If dataset with same name already exists + """ + logger.info( + f"[create_tts_dataset] Creating TTS dataset | " + f"name: {name}, org_id: {org_id}, project_id: {project_id}" + ) + + dataset = EvaluationDataset( + name=name, + description=description, + type=EvaluationType.TTS.value, + language_id=language_id, + object_store_url=object_store_url, + dataset_metadata=dataset_metadata or {}, + organization_id=org_id, + project_id=project_id, + inserted_at=now(), + updated_at=now(), + ) + + try: + session.add(dataset) + session.flush() + + logger.info( + f"[create_tts_dataset] TTS dataset created | " + f"dataset_id: {dataset.id}, name: {name}" + ) + + return dataset + + except IntegrityError as e: + session.rollback() + if "uq_evaluation_dataset_name_org_project" in str(e): + logger.error( + f"[create_tts_dataset] Dataset name already exists | name: {name}" + ) + raise HTTPException( + status_code=400, + detail=f"Dataset with name '{name}' already exists", + ) + raise + + +def get_tts_dataset_by_id( + *, + session: Session, + dataset_id: int, + org_id: int, + project_id: int, +) -> EvaluationDataset | None: + """Get a TTS dataset by ID. 
+ + Args: + session: Database session + dataset_id: Dataset ID + org_id: Organization ID + project_id: Project ID + + Returns: + EvaluationDataset | None: Dataset if found + """ + statement = select(EvaluationDataset).where( + EvaluationDataset.id == dataset_id, + EvaluationDataset.organization_id == org_id, + EvaluationDataset.project_id == project_id, + EvaluationDataset.type == EvaluationType.TTS.value, + ) + + return session.exec(statement).one_or_none() + + +def list_tts_datasets( + *, + session: Session, + org_id: int, + project_id: int, + limit: int = 50, + offset: int = 0, +) -> tuple[list[TTSDatasetPublic], int]: + """List TTS datasets for a project. + + Args: + session: Database session + org_id: Organization ID + project_id: Project ID + limit: Maximum results to return + offset: Number of results to skip + + Returns: + tuple[list[TTSDatasetPublic], int]: Datasets and total count + """ + base_filter = ( + EvaluationDataset.organization_id == org_id, + EvaluationDataset.project_id == project_id, + EvaluationDataset.type == EvaluationType.TTS.value, + ) + + count_stmt = select(func.count(EvaluationDataset.id)).where(*base_filter) + total = session.exec(count_stmt).one() + + statement = ( + select(EvaluationDataset) + .where(*base_filter) + .order_by(EvaluationDataset.inserted_at.desc()) + .offset(offset) + .limit(limit) + ) + + datasets = session.exec(statement).all() + + result = [TTSDatasetPublic.from_model(dataset) for dataset in datasets] + + return result, total diff --git a/backend/app/crud/tts_evaluations/result.py b/backend/app/crud/tts_evaluations/result.py new file mode 100644 index 00000000..c2de66cf --- /dev/null +++ b/backend/app/crud/tts_evaluations/result.py @@ -0,0 +1,286 @@ +"""CRUD operations for TTS evaluation results.""" + +import logging +from typing import Any + +from sqlmodel import Session, func, select + +from app.core.exception_handlers import HTTPException +from app.core.util import now +from app.models.job import JobStatus +from 
app.models.tts_evaluation import TTSResult, TTSResultPublic + +logger = logging.getLogger(__name__) + + +def create_tts_results( + *, + session: Session, + sample_texts: list[str], + evaluation_run_id: int, + org_id: int, + project_id: int, + models: list[str], +) -> list[TTSResult]: + """Create TTS result records for all sample texts and models. + + Creates one result per sample text per model. + + Args: + session: Database session + sample_texts: List of text strings to synthesize + evaluation_run_id: Run ID + org_id: Organization ID + project_id: Project ID + models: List of TTS models + + Returns: + list[TTSResult]: Created results + """ + logger.info( + f"[create_tts_results] Creating TTS results | " + f"run_id: {evaluation_run_id}, sample_count: {len(sample_texts)}, " + f"model_count: {len(models)}" + ) + + timestamp = now() + results = [ + TTSResult( + sample_text=text, + evaluation_run_id=evaluation_run_id, + organization_id=org_id, + project_id=project_id, + provider=model, + status=JobStatus.PENDING.value, + inserted_at=timestamp, + updated_at=timestamp, + ) + for text in sample_texts + for model in models + ] + + session.add_all(results) + session.flush() + session.commit() + + logger.info( + f"[create_tts_results] TTS results created | " + f"run_id: {evaluation_run_id}, result_count: {len(results)}" + ) + + return results + + +def get_tts_result_by_id( + *, + session: Session, + result_id: int, + org_id: int, + project_id: int, +) -> TTSResult | None: + """Get a TTS result by ID. 
+ + Args: + session: Database session + result_id: Result ID + org_id: Organization ID + project_id: Project ID + + Returns: + TTSResult | None: Result if found + """ + statement = select(TTSResult).where( + TTSResult.id == result_id, + TTSResult.organization_id == org_id, + TTSResult.project_id == project_id, + ) + + return session.exec(statement).one_or_none() + + +def get_results_by_run_id( + *, + session: Session, + run_id: int, + org_id: int, + project_id: int, +) -> tuple[list[TTSResultPublic], int]: + """Get all results for an evaluation run. + + Args: + session: Database session + run_id: Run ID + org_id: Organization ID + project_id: Project ID + + Returns: + tuple[list[TTSResultPublic], int]: Results and total count + """ + where_clauses = [ + TTSResult.evaluation_run_id == run_id, + TTSResult.organization_id == org_id, + TTSResult.project_id == project_id, + ] + + statement = select(TTSResult).where(*where_clauses).order_by(TTSResult.id) + + rows = session.exec(statement).all() + total = len(rows) + + results = [TTSResultPublic.from_model(result) for result in rows] + + return results, total + + +def update_tts_result( + *, + session: Session, + result_id: int, + object_store_url: str | None = None, + metadata: dict[str, Any] | None = None, + status: str | None = None, + error_message: str | None = None, +) -> TTSResult | None: + """Update a TTS result. 
+
+    Args:
+        session: Database session
+        result_id: Result ID
+        object_store_url: S3 URL of generated WAV
+        metadata: Audio metadata (duration_seconds, size_bytes)
+        status: New status
+        error_message: Error message if failed
+
+    Returns:
+        TTSResult | None: Updated result
+    """
+    statement = select(TTSResult).where(TTSResult.id == result_id)
+    result = session.exec(statement).one_or_none()
+
+    if not result:
+        return None
+
+    # None means "leave unchanged" — only non-None kwargs are applied.
+    if object_store_url is not None:
+        result.object_store_url = object_store_url
+    if metadata is not None:
+        result.metadata_ = metadata
+    if status is not None:
+        result.status = status
+    if error_message is not None:
+        result.error_message = error_message
+
+    result.updated_at = now()
+
+    # Flush only (no commit): the batch-processing caller commits once per batch.
+    session.add(result)
+    session.flush()
+
+    return result
+
+
+def update_tts_human_feedback(
+    *,
+    session: Session,
+    result_id: int,
+    org_id: int,
+    project_id: int,
+    is_correct: bool | None = None,
+    comment: str | None = None,
+) -> TTSResult | None:
+    """Update human feedback on a TTS result.
+
+    Args:
+        session: Database session
+        result_id: Result ID
+        org_id: Organization ID
+        project_id: Project ID
+        is_correct: Human verification of quality
+        comment: Feedback comment
+
+    Returns:
+        TTSResult | None: Updated result
+
+    Raises:
+        HTTPException: If result not found
+    """
+    result = get_tts_result_by_id(
+        session=session,
+        result_id=result_id,
+        org_id=org_id,
+        project_id=project_id,
+    )
+
+    if not result:
+        raise HTTPException(status_code=404, detail="Result not found")
+
+    # None means "leave unchanged"; feedback fields cannot be cleared to NULL here.
+    if is_correct is not None:
+        result.is_correct = is_correct
+
+    if comment is not None:
+        result.comment = comment
+
+    result.updated_at = now()
+
+    # Unlike update_tts_result, this commits immediately (API-facing path).
+    session.add(result)
+    session.commit()
+    session.refresh(result)
+
+    logger.info(
+        f"[update_tts_human_feedback] Human feedback updated | "
+        f"result_id: {result_id}, is_correct: {is_correct}"
+    )
+
+    return result
+
+
+def get_pending_results_for_run(
+    *,
+    session: Session,
+    run_id: int,
+    provider: str | None = None,
+) -> list[TTSResult]:
+    """Get all pending results for a run.
+
+    Args:
+        session: Database session
+        run_id: Run ID
+        provider: Optional filter by provider
+
+    Returns:
+        list[TTSResult]: Pending results
+    """
+    where_clauses = [
+        TTSResult.evaluation_run_id == run_id,
+        TTSResult.status == JobStatus.PENDING.value,
+    ]
+
+    if provider is not None:
+        where_clauses.append(TTSResult.provider == provider)
+
+    statement = select(TTSResult).where(*where_clauses)
+
+    return list(session.exec(statement).all())
+
+
+def count_results_by_status(
+    *,
+    session: Session,
+    run_id: int,
+) -> dict[str, int]:
+    """Count results by status for a run.
+
+    Args:
+        session: Database session
+        run_id: Run ID
+
+    Returns:
+        dict[str, int]: Counts by status
+    """
+    statement = (
+        select(TTSResult.status, func.count(TTSResult.id))
+        .where(TTSResult.evaluation_run_id == run_id)
+        .group_by(TTSResult.status)
+    )
+
+    rows = session.exec(statement).all()
+
+    # Unpack (status, count) rows from the GROUP BY into a plain dict;
+    # statuses with zero results are simply absent (callers use .get(..., 0)).
+    return {status: count for status, count in rows}
diff --git a/backend/app/crud/tts_evaluations/run.py b/backend/app/crud/tts_evaluations/run.py
new file mode 100644
index 00000000..272bced6
--- /dev/null
+++ b/backend/app/crud/tts_evaluations/run.py
@@ -0,0 +1,230 @@
+"""CRUD operations for TTS evaluation runs."""
+
+import logging
+from typing import Any
+
+from sqlmodel import Session, func, select
+
+from app.core.util import now
+from app.models import EvaluationDataset, EvaluationRun
+# NOTE(review): EvaluationDataset appears unused in this module's visible code —
+# confirm before removing.
+from app.models.stt_evaluation import EvaluationType
+from app.models.tts_evaluation import TTSEvaluationRunPublic
+
+logger = logging.getLogger(__name__)
+
+
+def create_tts_run(
+    *,
+    session: Session,
+    run_name: str,
+    dataset_id: int,
+    dataset_name: str,
+    org_id: int,
+    project_id: int,
+    models: list[str],
+    language_id: int | None = None,
+    total_items: int = 0,
+) -> EvaluationRun:
+    """Create a new TTS evaluation run.
+
+    Args:
+        session: Database session
+        run_name: Name for the run
+        dataset_id: ID of the dataset to evaluate
+        dataset_name: Name of the dataset
+        org_id: Organization ID
+        project_id: Project ID
+        models: List of TTS models to use
+        language_id: Optional language ID override
+        total_items: Total number of items to process
+
+    Returns:
+        EvaluationRun: Created run
+    """
+    logger.info(
+        f"[create_tts_run] Creating TTS evaluation run | "
+        f"run_name: {run_name}, dataset_id: {dataset_id}, "
+        f"models: {models}"
+    )
+
+    # TTS models are stored in the shared EvaluationRun.providers column.
+    run = EvaluationRun(
+        run_name=run_name,
+        dataset_name=dataset_name,
+        dataset_id=dataset_id,
+        type=EvaluationType.TTS.value,
+        language_id=language_id,
+        providers=models,
+        status="pending",
+        total_items=total_items,
+        organization_id=org_id,
+        project_id=project_id,
+        inserted_at=now(),
+        updated_at=now(),
+    )
+
+    session.add(run)
+    session.commit()
+    session.refresh(run)
+
+    logger.info(
+        f"[create_tts_run] TTS evaluation run created | "
+        f"run_id: {run.id}, run_name: {run_name}"
+    )
+
+    return run
+
+
+def get_tts_run_by_id(
+    *,
+    session: Session,
+    run_id: int,
+    org_id: int,
+    project_id: int,
+) -> EvaluationRun | None:
+    """Get a TTS evaluation run by ID.
+
+    Args:
+        session: Database session
+        run_id: Run ID
+        org_id: Organization ID
+        project_id: Project ID
+
+    Returns:
+        EvaluationRun | None: Run if found
+    """
+    # Scoped to org/project and restricted to TTS-type runs.
+    statement = select(EvaluationRun).where(
+        EvaluationRun.id == run_id,
+        EvaluationRun.organization_id == org_id,
+        EvaluationRun.project_id == project_id,
+        EvaluationRun.type == EvaluationType.TTS.value,
+    )
+
+    return session.exec(statement).one_or_none()
+
+
+def list_tts_runs(
+    *,
+    session: Session,
+    org_id: int,
+    project_id: int,
+    dataset_id: int | None = None,
+    status: str | None = None,
+    limit: int = 50,
+    offset: int = 0,
+) -> tuple[list[TTSEvaluationRunPublic], int]:
+    """List TTS evaluation runs for a project.
+
+    Args:
+        session: Database session
+        org_id: Organization ID
+        project_id: Project ID
+        dataset_id: Optional filter by dataset
+        status: Optional filter by status
+        limit: Maximum results to return
+        offset: Number of results to skip
+
+    Returns:
+        tuple[list[TTSEvaluationRunPublic], int]: Runs and total count
+    """
+    where_clauses = [
+        EvaluationRun.organization_id == org_id,
+        EvaluationRun.project_id == project_id,
+        EvaluationRun.type == EvaluationType.TTS.value,
+    ]
+
+    if dataset_id is not None:
+        where_clauses.append(EvaluationRun.dataset_id == dataset_id)
+
+    if status is not None:
+        where_clauses.append(EvaluationRun.status == status)
+
+    # Count first with the same filters, then fetch the requested page, newest first.
+    count_stmt = select(func.count(EvaluationRun.id)).where(*where_clauses)
+    total = session.exec(count_stmt).one()
+
+    statement = (
+        select(EvaluationRun)
+        .where(*where_clauses)
+        .order_by(EvaluationRun.inserted_at.desc())
+        .offset(offset)
+        .limit(limit)
+    )
+
+    runs = session.exec(statement).all()
+
+    # Map ORM rows to public DTOs; providers column surfaces as "models".
+    result = [
+        TTSEvaluationRunPublic(
+            id=run.id,
+            run_name=run.run_name,
+            dataset_name=run.dataset_name,
+            type=run.type,
+            language_id=run.language_id,
+            models=run.providers,
+            dataset_id=run.dataset_id,
+            status=run.status,
+            total_items=run.total_items,
+            score=run.score,
+            error_message=run.error_message,
+            organization_id=run.organization_id,
+            project_id=run.project_id,
+            inserted_at=run.inserted_at,
+            updated_at=run.updated_at,
+        )
+        for run in runs
+    ]
+
+    return result, total
+
+
+def update_tts_run(
+    *,
+    session: Session,
+    run_id: int,
+    status: str | None = None,
+    score: dict[str, Any] | None = None,
+    error_message: str | None = None,
+    object_store_url: str | None = None,
+    batch_job_id: int | None = None,
+) -> EvaluationRun | None:
+    """Update a TTS evaluation run.
+
+    Args:
+        session: Database session
+        run_id: Run ID
+        status: New status
+        score: Score data
+        error_message: Error message
+        object_store_url: URL to stored results
+        batch_job_id: ID of the associated batch job
+
+    Returns:
+        EvaluationRun | None: Updated run
+    """
+    statement = select(EvaluationRun).where(EvaluationRun.id == run_id)
+    run = session.exec(statement).one_or_none()
+
+    if not run:
+        return None
+
+    updates = {
+        "status": status,
+        "score": score,
+        "error_message": error_message,
+        "object_store_url": object_store_url,
+        "batch_job_id": batch_job_id,
+    }
+
+    # NOTE(review): None-valued kwargs are skipped, so a field can be set but
+    # never cleared back to NULL through this helper — presumably intentional
+    # partial-update semantics; confirm.
+    for field, value in updates.items():
+        if value is not None:
+            setattr(run, field, value)
+
+    run.updated_at = now()
+
+    session.add(run)
+    session.commit()
+    session.refresh(run)
+
+    logger.info(
+        f"[update_tts_run] TTS run updated | run_id: {run_id}, status: {run.status}"
+    )
+
+    return run
diff --git a/backend/app/models/tts_evaluation.py b/backend/app/models/tts_evaluation.py
new file mode 100644
index 00000000..caa45ae7
--- /dev/null
+++ b/backend/app/models/tts_evaluation.py
@@ -0,0 +1,297 @@
+"""TTS Evaluation models for Text-to-Speech evaluation feature."""
+
+from datetime import datetime
+from typing import Any
+
+from pydantic import BaseModel, Field, field_validator
+from sqlalchemy import Column, Text
+from sqlalchemy.dialects.postgresql import JSONB
+from sqlmodel import Field as SQLField
+from sqlmodel import SQLModel
+
+from app.core.util import now
+from app.models.job import JobStatus
+from app.models.stt_evaluation import EvaluationType
+
+# Supported TTS models for evaluation
+SUPPORTED_TTS_MODELS = ["gemini-2.5-pro-preview-tts"]
+
+
+class TTSResult(SQLModel, table=True):
+    """Database table for TTS synthesis results."""
+
+    __tablename__ = "tts_result"
+
+    # NOTE(review): default=None with a non-Optional annotation is the usual
+    # SQLModel primary-key idiom (value assigned by the DB on insert).
+    id: int = SQLField(
+        default=None,
+        primary_key=True,
+        sa_column_kwargs={"comment": "Unique identifier for the TTS result"},
+    )
+
+    sample_text: str = SQLField(
+        sa_column=Column(
+            Text,
+            nullable=False,
+            comment="Input text that was synthesized to speech",
+        ),
+        description="Input text that was synthesized to speech",
+    )
+
+    object_store_url: str | None = SQLField(
+        default=None,
+        sa_column_kwargs={"comment": "S3 URL of the generated WAV audio file"},
+    )
+
+    # Attribute is metadata_ while the DB column is "metadata" — presumably to
+    # avoid clashing with SQLAlchemy's reserved declarative attribute; confirm.
+    metadata_: dict[str, Any] | None = SQLField(
+        default=None,
+        sa_column=Column(
+            "metadata",
+            JSONB,
+            nullable=True,
+            comment="Audio metadata: {duration_seconds, size_bytes}",
+        ),
+        description="Audio metadata (duration_seconds, size_bytes)",
+    )
+
+    provider: str = SQLField(
+        max_length=100,
+        description="TTS provider used (e.g., gemini-2.5-pro-preview-tts)",
+        sa_column_kwargs={
+            "comment": "TTS provider used (e.g., gemini-2.5-pro-preview-tts)"
+        },
+    )
+
+    status: str = SQLField(
+        default=JobStatus.PENDING.value,
+        max_length=20,
+        description="Result status: PENDING, SUCCESS, FAILED",
+        sa_column_kwargs={"comment": "Result status: PENDING, SUCCESS, FAILED"},
+    )
+
+    score: dict[str, Any] | None = SQLField(
+        default=None,
+        sa_column=Column(
+            JSONB,
+            nullable=True,
+            comment="Extensible evaluation metrics (null in Phase 1)",
+        ),
+        description="Extensible evaluation metrics",
+    )
+
+    is_correct: bool | None = SQLField(
+        default=None,
+        description="Human feedback: audio quality correctness",
+        sa_column_kwargs={
+            "comment": "Human feedback: audio quality correctness (null=not reviewed)"
+        },
+    )
+    comment: str | None = SQLField(
+        default=None,
+        sa_column=Column(
+            Text,
+            nullable=True,
+            comment="Human feedback comment",
+        ),
+        description="Human feedback comment",
+    )
+
+    error_message: str | None = SQLField(
+        default=None,
+        sa_column=Column(
+            Text,
+            nullable=True,
+            comment="Error message if synthesis failed",
+        ),
+        description="Error message if synthesis failed",
+    )
+
+    # Tenancy/ownership foreign keys; rows are deleted with their parent.
+    evaluation_run_id: int = SQLField(
+        foreign_key="evaluation_run.id",
+        nullable=False,
+        ondelete="CASCADE",
+        sa_column_kwargs={"comment": "Reference to the evaluation run"},
+    )
+    organization_id: int = SQLField(
+        foreign_key="organization.id",
+        nullable=False,
+        ondelete="CASCADE",
+        sa_column_kwargs={"comment": "Reference to the organization"},
+    )
+    project_id: int = SQLField(
+        foreign_key="project.id",
+        nullable=False,
+        ondelete="CASCADE",
+        sa_column_kwargs={"comment": "Reference to the project"},
+    )
+
+    inserted_at: datetime = SQLField(
+        default_factory=now,
+        nullable=False,
+        sa_column_kwargs={"comment": "Timestamp when the result was created"},
+    )
+    updated_at: datetime = SQLField(
+        default_factory=now,
+        nullable=False,
+        sa_column_kwargs={"comment": "Timestamp when the result was last updated"},
+    )
+
+
+# --- Pydantic request/response models ---
+
+
+class TTSSampleCreate(BaseModel):
+    """Request model for a single TTS sample."""
+
+    text: str = Field(..., description="Text to synthesize", min_length=1)
+
+
+class TTSDatasetCreate(BaseModel):
+    """Request model for creating a TTS dataset."""
+
+    name: str = Field(..., description="Dataset name", min_length=1)
+    description: str | None = Field(None, description="Dataset description")
+    language_id: int | None = Field(
+        None, description="ID of the language from global languages table"
+    )
+    samples: list[TTSSampleCreate] = Field(
+        ..., description="List of text samples", min_length=1
+    )
+
+
+class TTSDatasetPublic(BaseModel):
+    """Public model for TTS datasets."""
+
+    id: int
+    name: str
+    description: str | None
+    type: str
+    language_id: int | None
+    object_store_url: str | None
+    dataset_metadata: dict[str, Any]
+    organization_id: int
+    project_id: int
+    inserted_at: datetime
+    updated_at: datetime
+
+    @classmethod
+    def from_model(cls, dataset: "EvaluationDataset") -> "TTSDatasetPublic":
+        """Create from an EvaluationDataset model instance."""
+        return cls(
+            id=dataset.id,
+            name=dataset.name,
+            description=dataset.description,
+            type=dataset.type,
+            language_id=dataset.language_id,
+            object_store_url=dataset.object_store_url,
+            dataset_metadata=dataset.dataset_metadata,
+            organization_id=dataset.organization_id,
+            project_id=dataset.project_id,
+            inserted_at=dataset.inserted_at,
+            updated_at=dataset.updated_at,
+        )
+
+
+class TTSResultPublic(BaseModel):
+    """Public model for TTS results."""
+
+    id: int
+    sample_text: str
+    object_store_url: str | None
+    duration_seconds: float | None = None
+    size_bytes: int | None = None
+    provider: str
+    status: str
+    score: dict[str, Any] | None
+    is_correct: bool | None
+    comment: str | None
+    error_message: str | None
+    evaluation_run_id: int
+    organization_id: int
+    project_id: int
+    inserted_at: datetime
+    updated_at: datetime
+
+    @classmethod
+    def from_model(cls, result: "TTSResult") -> "TTSResultPublic":
+        """Create from a TTSResult model instance."""
+        return cls(
+            id=result.id,
+            sample_text=result.sample_text,
+            object_store_url=result.object_store_url,
+            # Flatten the metadata_ JSON into top-level convenience fields.
+            duration_seconds=(result.metadata_ or {}).get("duration_seconds"),
+            size_bytes=(result.metadata_ or {}).get("size_bytes"),
+            provider=result.provider,
+            status=result.status,
+            score=result.score,
+            is_correct=result.is_correct,
+            comment=result.comment,
+            error_message=result.error_message,
+            evaluation_run_id=result.evaluation_run_id,
+            organization_id=result.organization_id,
+            project_id=result.project_id,
+            inserted_at=result.inserted_at,
+            updated_at=result.updated_at,
+        )
+
+
+class TTSFeedbackUpdate(BaseModel):
+    """Request model for updating human feedback on a TTS result."""
+
+    is_correct: bool | None = Field(
+        None, description="Is the synthesized audio correct?"
+ ) + comment: str | None = Field(None, description="Feedback comment") + + +class TTSEvaluationRunCreate(BaseModel): + """Request model for starting a TTS evaluation run.""" + + run_name: str = Field(..., description="Name for this evaluation run", min_length=1) + dataset_id: int = Field(..., description="ID of the TTS dataset to evaluate") + models: list[str] = Field( + default_factory=lambda: ["gemini-2.5-pro-preview-tts"], + description="List of TTS models to use", + min_length=1, + ) + + @field_validator("models") + @classmethod + def validate_models(cls, valid_model: list[str]) -> list[str]: + """Validate that all models are supported.""" + if not valid_model: + raise ValueError("At least one model must be specified") + unsupported = [m for m in valid_model if m not in SUPPORTED_TTS_MODELS] + if unsupported: + raise ValueError( + f"Unsupported model(s): {', '.join(unsupported)}. " + f"Supported models are: {', '.join(SUPPORTED_TTS_MODELS)}" + ) + return valid_model + + +class TTSEvaluationRunPublic(BaseModel): + """Public model for TTS evaluation runs.""" + + id: int + run_name: str + dataset_name: str + type: str + language_id: int | None + models: list[str] | None + dataset_id: int + status: str + total_items: int + score: dict[str, Any] | None + error_message: str | None + run_metadata: dict[str, Any] | None = None + organization_id: int + project_id: int + inserted_at: datetime + updated_at: datetime + + +class TTSEvaluationRunWithResults(TTSEvaluationRunPublic): + """TTS evaluation run with embedded results.""" + + results: list[TTSResultPublic] + results_total: int = Field(0, description="Total number of results") diff --git a/backend/app/services/tts_evaluations/__init__.py b/backend/app/services/tts_evaluations/__init__.py new file mode 100644 index 00000000..fc19336d --- /dev/null +++ b/backend/app/services/tts_evaluations/__init__.py @@ -0,0 +1 @@ +"""TTS Evaluation services.""" diff --git a/backend/app/services/tts_evaluations/audio.py 
# --- new file: backend/app/services/tts_evaluations/audio.py (diff header elided) ---
"""Audio processing utilities for TTS evaluation."""

import io
import wave


def pcm_to_wav(
    pcm_data: bytes,
    sample_rate: int = 24000,
    bits_per_sample: int = 16,
    channels: int = 1,
) -> bytes:
    """Wrap raw PCM audio data in a WAV container.

    Args:
        pcm_data: Raw PCM audio bytes
        sample_rate: Sample rate in Hz (default: 24000 for Gemini TTS)
        bits_per_sample: Bits per sample (default: 16)
        channels: Number of audio channels (default: 1 mono)

    Returns:
        WAV file bytes with proper RIFF/WAVE headers
    """
    output = io.BytesIO()

    # The wave module writes the RIFF header; we only supply frame parameters.
    with wave.open(output, "wb") as wav_file:
        wav_file.setnchannels(channels)
        wav_file.setsampwidth(bits_per_sample // 8)
        wav_file.setframerate(sample_rate)
        wav_file.writeframes(pcm_data)

    return output.getvalue()


def calculate_duration(
    pcm_size: int,
    sample_rate: int = 24000,
    bits_per_sample: int = 16,
    channels: int = 1,
) -> float:
    """Calculate audio duration from PCM data size.

    Args:
        pcm_size: Size of raw PCM data in bytes
        sample_rate: Sample rate in Hz
        bits_per_sample: Bits per sample
        channels: Number of audio channels

    Returns:
        Duration in seconds (0.0 when parameters would divide by zero)
    """
    bytes_per_sample = bits_per_sample // 8
    bytes_per_second = sample_rate * bytes_per_sample * channels
    # Guard degenerate parameters (e.g. sample_rate=0 or bits_per_sample<8).
    if bytes_per_second == 0:
        return 0.0
    return pcm_size / bytes_per_second
# --- next in diff: new file backend/app/services/tts_evaluations/batch_result_processing.py
# (its module docstring continues in the following lines) ---
+
+Processes completed Gemini TTS batch results: downloads JSONL,
+extracts audio, converts PCM to WAV, uploads to S3, updates DB.
+"""
+
+import base64
+import logging
+import uuid
+from typing import Any
+
+from sqlmodel import Session, select
+
+from app.core.batch import BATCH_KEY, GeminiBatchProvider
+from app.core.cloud.storage import get_cloud_storage
+from app.core.db import engine
+from app.core.storage_utils import upload_to_object_store
+from app.crud.tts_evaluations.result import (
+    count_results_by_status,
+    update_tts_result,
+)
+from app.crud.tts_evaluations.run import update_tts_run
+from app.models.job import JobStatus
+from app.models.tts_evaluation import TTSResult
+from app.services.stt_evaluations.gemini import GeminiClient
+from app.services.tts_evaluations.audio import calculate_duration, pcm_to_wav
+
+logger = logging.getLogger(__name__)
+
+
+def execute_tts_result_processing(
+    project_id: int,
+    job_id: str,
+    task_id: str,
+    task_instance: Any,
+    organization_id: int,
+    evaluation_run_id: int,
+    tts_provider: str,
+    provider_batch_id: str,
+    **kwargs: Any,
+) -> dict:
+    """Process completed TTS batch results in a Celery worker.
+
+    Downloads batch results from Gemini, extracts audio, converts to WAV,
+    uploads to S3, and updates TTSResult records.
+
+    Args:
+        project_id: Project ID
+        job_id: Batch job ID (as string)
+        task_id: Celery task ID
+        task_instance: Celery task instance
+        organization_id: Organization ID
+        evaluation_run_id: Evaluation run ID
+        tts_provider: TTS provider/model name
+        provider_batch_id: Gemini batch job ID
+
+    Returns:
+        dict: Result summary with processed/failed counts
+    """
+    # NOTE(review): task_instance and **kwargs are not used in this function —
+    # presumably required by the generic Celery task signature; confirm.
+    logger.info(
+        f"[execute_tts_result_processing] Starting | "
+        f"run_id={evaluation_run_id}, batch_job_id={job_id}, "
+        f"provider={tts_provider}, celery_task_id={task_id}"
+    )
+
+    with Session(engine) as session:
+        try:
+            # Initialize Gemini client and batch provider
+            gemini_client = GeminiClient.from_credentials(
+                session=session,
+                org_id=organization_id,
+                project_id=project_id,
+            )
+            batch_provider = GeminiBatchProvider(client=gemini_client.client)
+
+            # Get cloud storage for S3 uploads
+            storage = get_cloud_storage(session=session, project_id=project_id)
+
+            # Download batch results
+            results = batch_provider.download_batch_results(provider_batch_id)
+
+            logger.info(
+                f"[execute_tts_result_processing] Got batch results | "
+                f"run_id={evaluation_run_id}, result_count={len(results)}"
+            )
+
+            # Diagnostic log of the first result's shape (schema varies by SDK).
+            if results:
+                first = results[0]
+                resp = first.get("response") or {}
+                logger.info(
+                    f"[execute_tts_result_processing] First result structure | "
+                    f"keys={list(first.keys())}, "
+                    f"response_keys={list(resp.keys()) if isinstance(resp, dict) else type(resp).__name__}"
+                )
+
+            processed_count = 0
+            failed_count = 0
+
+            for batch_result in results:
+                # BATCH_KEY carries the TTSResult primary key as the custom id.
+                custom_id = batch_result[BATCH_KEY]
+                try:
+                    result_id = int(custom_id)
+                except (ValueError, TypeError):
+                    logger.warning(
+                        f"[execute_tts_result_processing] Invalid {BATCH_KEY} | "
+                        f"run_id={evaluation_run_id}, {BATCH_KEY}={custom_id}"
+                    )
+                    failed_count += 1
+                    continue
+
+                # Find result record
+                stmt = select(TTSResult).where(
+                    TTSResult.id == result_id,
+                    TTSResult.evaluation_run_id == evaluation_run_id,
+                    TTSResult.provider == tts_provider,
+                )
+
+                result_record = session.exec(stmt).one_or_none()
+
+                if not result_record:
+                    logger.warning(
+                        f"[execute_tts_result_processing] Result record not found | "
+                        f"result_id={result_id}"
+                    )
+                    failed_count += 1
+                    continue
+
+                if batch_result.get("response"):
+                    try:
+                        audio_b64 = _extract_audio_from_response(
+                            batch_result["response"]
+                        )
+
+                        if not audio_b64:
+                            update_tts_result(
+                                session=session,
+                                result_id=result_record.id,
+                                status=JobStatus.FAILED.value,
+                                error_message="No audio data in response",
+                            )
+                            failed_count += 1
+                            continue
+
+                        # Decode base64 -> raw PCM bytes
+                        pcm_data = base64.b64decode(audio_b64)
+
+                        # Wrap in WAV container
+                        wav_data = pcm_to_wav(pcm_data)
+
+                        # Calculate duration
+                        duration = calculate_duration(len(pcm_data))
+
+                        # Upload WAV to S3
+                        audio_filename = f"{uuid.uuid4()}.wav"
+                        audio_url = upload_to_object_store(
+                            storage=storage,
+                            content=wav_data,
+                            filename=audio_filename,
+                            subdirectory="evaluations/tts/audio",
+                            content_type="audio/wav",
+                        )
+
+                        # Update result
+                        update_tts_result(
+                            session=session,
+                            result_id=result_record.id,
+                            object_store_url=audio_url,
+                            metadata={
+                                "duration_seconds": round(duration, 3),
+                                "size_bytes": len(wav_data),
+                            },
+                            status=JobStatus.SUCCESS.value,
+                        )
+                        processed_count += 1
+
+                    except Exception as audio_err:
+                        logger.error(
+                            f"[execute_tts_result_processing] Audio processing failed | "
+                            f"result_id={result_id}, error={str(audio_err)}"
+                        )
+                        update_tts_result(
+                            session=session,
+                            result_id=result_record.id,
+                            status=JobStatus.FAILED.value,
+                            error_message=f"Audio processing failed: {str(audio_err)}",
+                        )
+                        failed_count += 1
+                else:
+                    update_tts_result(
+                        session=session,
+                        result_id=result_record.id,
+                        status=JobStatus.FAILED.value,
+                        error_message=batch_result.get("error", "Unknown error"),
+                    )
+                    failed_count += 1
+
+            # Single commit persists all per-result updates
+            # (update_tts_result only flushes).
+            session.commit()
+
+            # Finalize run status
+            status_counts = count_results_by_status(
+                session=session, run_id=evaluation_run_id
+            )
+            pending = status_counts.get(JobStatus.PENDING.value, 0)
+            total_failed = status_counts.get(JobStatus.FAILED.value, 0)
+
+            # A run completes only when no result for it (any provider) is
+            # still PENDING; otherwise it stays "processing".
+            final_status = "completed" if pending == 0 else "processing"
+            error_message = (
+                f"{total_failed} synthesis(es) failed" if total_failed > 0 else None
+            )
+
+            update_tts_run(
+                session=session,
+                run_id=evaluation_run_id,
+                status=final_status,
+                error_message=error_message,
+            )
+
+            logger.info(
+                f"[execute_tts_result_processing] Completed | "
+                f"run_id={evaluation_run_id}, provider={tts_provider}, "
+                f"processed={processed_count}, failed={failed_count}, "
+                f"run_status={final_status}"
+            )
+
+            return {
+                "success": True,
+                "run_id": evaluation_run_id,
+                "processed": processed_count,
+                "failed": failed_count,
+                "run_status": final_status,
+            }
+
+        except Exception as e:
+            logger.error(
+                f"[execute_tts_result_processing] Failed | "
+                f"run_id={evaluation_run_id}, error={str(e)}",
+                exc_info=True,
+            )
+            update_tts_run(
+                session=session,
+                run_id=evaluation_run_id,
+                status="failed",
+                error_message=f"Result processing failed: {str(e)}",
+            )
+            return {"success": False, "error": str(e)}
+
+
+def _extract_audio_from_response(response: dict[str, Any]) -> str | None:
+    """Extract base64-encoded audio data from a Gemini TTS response.
+
+    Gemini TTS returns audio as base64-encoded PCM data in the
+    inlineData field of the response parts. Handles both camelCase
+    (REST API) and snake_case (Python SDK / batch JSONL) field names.
+
+    Args:
+        response: Gemini response dictionary
+
+    Returns:
+        Base64 encoded audio string, or None if not found
+    """
+    # Navigate: candidates -> content -> parts -> inlineData/inline_data -> data
+    for candidate in response.get("candidates", []):
+        content = candidate.get("content", {})
+        for part in content.get("parts", []):
+            # Handle both camelCase (inlineData) and snake_case (inline_data)
+            inline_data = part.get("inlineData") or part.get("inline_data") or {}
+            if inline_data.get("data"):
+                return inline_data["data"]
+
+    logger.warning(
+        f"[_extract_audio_from_response] No audio data found | "
+        f"response_keys={list(response.keys())}, "
+        f"parts={[list(p.keys()) for c in response.get('candidates', []) for p in c.get('content', {}).get('parts', [])]}"
+    )
+    return None
diff --git a/backend/app/services/tts_evaluations/constants.py b/backend/app/services/tts_evaluations/constants.py
new file mode 100644
index 00000000..3a0b812c
--- /dev/null
+++ b/backend/app/services/tts_evaluations/constants.py
@@ -0,0 +1,10 @@
+"""Shared constants for TTS evaluation services."""
+
+# Default TTS model
+DEFAULT_TTS_MODEL = "gemini-2.5-pro-preview-tts"
+
+# Default voice configuration
+DEFAULT_VOICE_NAME = "Kore"
+
+# Default style prompt for TTS synthesis
+DEFAULT_STYLE_PROMPT = "Read in a calm, professional customer service tone"
diff --git a/backend/app/services/tts_evaluations/dataset.py b/backend/app/services/tts_evaluations/dataset.py
new file mode 100644
index 00000000..fb2ffa33
--- /dev/null
+++ b/backend/app/services/tts_evaluations/dataset.py
@@ -0,0 +1,180 @@
+"""Dataset management service for TTS evaluations."""
+
+import csv
+import io
+import logging
+from typing import Any
+
+from sqlmodel import Session
+
+from app.core.cloud import get_cloud_storage
+from app.core.storage_utils import (
+    generate_timestamped_filename,
+    upload_to_object_store,
+)
+from app.crud.tts_evaluations import create_tts_dataset
+from app.models import EvaluationDataset
+from 
app.models.tts_evaluation import TTSSampleCreate + +logger = logging.getLogger(__name__) + + +def upload_tts_dataset( + session: Session, + name: str, + samples: list[TTSSampleCreate], + organization_id: int, + project_id: int, + description: str | None = None, + language_id: int | None = None, +) -> EvaluationDataset: + """Orchestrate TTS dataset upload workflow. + + Steps: + 1. Convert samples to CSV format + 2. Upload CSV to object store + 3. Create dataset record in database + + Args: + session: Database session + name: Dataset name + samples: List of TTS text samples + organization_id: Organization ID + project_id: Project ID + description: Optional dataset description + language_id: Optional reference to global.languages table + + Returns: + Created dataset record + """ + logger.info( + f"[upload_tts_dataset] Uploading TTS dataset | name={name} | " + f"sample_count={len(samples)} | org_id={organization_id} | " + f"project_id={project_id}" + ) + + # Step 1: Convert samples to CSV and upload to object store + object_store_url = _upload_samples_to_object_store( + session=session, + project_id=project_id, + dataset_name=name, + samples=samples, + ) + + # Step 2: Calculate metadata + metadata: dict[str, Any] = { + "sample_count": len(samples), + } + + # Step 3: Create dataset record + try: + dataset = create_tts_dataset( + session=session, + name=name, + org_id=organization_id, + project_id=project_id, + description=description, + language_id=language_id, + object_store_url=object_store_url, + dataset_metadata=metadata, + ) + + logger.info( + f"[upload_tts_dataset] Created dataset record | " + f"id={dataset.id} | name={name}" + ) + + session.commit() + + return dataset + + except Exception: + session.rollback() + raise + + +def _upload_samples_to_object_store( + session: Session, + project_id: int, + dataset_name: str, + samples: list[TTSSampleCreate], +) -> str | None: + """Upload TTS samples as CSV to object store. 
+ + Args: + session: Database session + project_id: Project ID for storage credentials + dataset_name: Dataset name for filename + samples: List of samples to upload + + Returns: + Object store URL if successful, None otherwise + """ + try: + storage = get_cloud_storage(session=session, project_id=project_id) + + csv_content = _samples_to_csv(samples) + + filename = generate_timestamped_filename(dataset_name, "csv") + object_store_url = upload_to_object_store( + storage=storage, + content=csv_content, + filename=filename, + subdirectory="tts_datasets", + content_type="text/csv", + ) + + if object_store_url: + logger.info( + f"[_upload_samples_to_object_store] Upload successful | " + f"url={object_store_url}" + ) + else: + logger.info( + "[_upload_samples_to_object_store] Upload returned None | " + "continuing without object store storage" + ) + + return object_store_url + + except Exception as e: + logger.warning( + f"[_upload_samples_to_object_store] Failed to upload | {e}", + exc_info=True, + ) + return None + + +def _samples_to_csv(samples: list[TTSSampleCreate]) -> bytes: + """Convert TTS samples to CSV format. + + Args: + samples: List of TTS samples + + Returns: + CSV content as bytes + """ + output = io.StringIO() + writer = csv.writer(output) + writer.writerow(["text"]) + for sample in samples: + writer.writerow([sample.text]) + return output.getvalue().encode("utf-8") + + +def parse_tts_samples_from_csv(csv_content: bytes) -> list[dict[str, Any]]: + """Parse TTS samples from CSV content. 
+ + Args: + csv_content: CSV file content as bytes + + Returns: + List of dicts with {index, text} for each sample + """ + reader = csv.DictReader(io.StringIO(csv_content.decode("utf-8"))) + samples = [] + for i, row in enumerate(reader): + text = row.get("text", "").strip() + if text: + samples.append({"index": i, "text": text}) + return samples diff --git a/backend/app/tests/core/batch/test_gemini.py b/backend/app/tests/core/batch/test_gemini.py index fc62f7b4..c9281cfb 100644 --- a/backend/app/tests/core/batch/test_gemini.py +++ b/backend/app/tests/core/batch/test_gemini.py @@ -233,9 +233,15 @@ def test_download_batch_results_success(self, provider, mock_genai_client): ) assert len(results) == 2 assert results[0]["custom_id"] == "req-1" - assert results[0]["response"]["text"] == "Hello" + assert ( + results[0]["response"]["candidates"][0]["content"]["parts"][0]["text"] + == "Hello" + ) assert results[1]["custom_id"] == "req-2" - assert results[1]["response"]["text"] == "World" + assert ( + results[1]["response"]["candidates"][0]["content"]["parts"][0]["text"] + == "World" + ) def test_download_batch_results_with_direct_text_response( self, provider, mock_genai_client @@ -255,7 +261,9 @@ def test_download_batch_results_with_direct_text_response( results = provider.download_batch_results(batch_id) assert len(results) == 1 - assert results[0]["response"]["text"] == "Direct text" + assert ( + results[0]["response"]["text"] == "Direct text" + ) # raw response passthrough def test_download_batch_results_with_errors(self, provider, mock_genai_client): """Test downloading batch results that contain errors."""