mirror of https://github.com/remsky/Kokoro-FastAPI.git
synced 2025-04-13 09:39:17 +00:00

Add API authentication and configuration improvements

- Implement OpenAI-compatible API key authentication
- Add configuration options for GPU instances, concurrency, and request handling
- Update README with authentication instructions
- Modify configuration and routing to support optional API key verification
- Enhance system information and debug endpoints to expose authentication status

This commit is contained in:
parent a578d22084
commit 5c8f941f06

17 changed files with 789 additions and 268 deletions
38  .env.example  Normal file
@@ -0,0 +1,38 @@
# API Settings
API_TITLE="Kokoro TTS API"
API_DESCRIPTION="API for text-to-speech generation using Kokoro"
API_VERSION="1.0.0"
HOST="0.0.0.0"
PORT=8880

# Authentication Settings
ENABLE_AUTH=False  # Set to True to enable API key authentication
API_KEYS=["sk-kokoro-1234567890abcdef", "sk-kokoro-0987654321fedcba"]  # List of valid API keys

# Application Settings
OUTPUT_DIR="output"
OUTPUT_DIR_SIZE_LIMIT_MB=500.0
DEFAULT_VOICE="af_heart"
# DEFAULT_VOICE_CODE can be set to a language code like "a" for American English, "j" for Japanese, etc.
# Valid codes: a (American English), b (British English), e (Spanish), f (French), h (Hindi),
# i (Italian), p (Portuguese), j (Japanese), z (Mandarin Chinese)
# Set to null or leave commented out to use the first letter of the voice name
DEFAULT_VOICE_CODE="a"

USE_GPU=True
ALLOW_LOCAL_VOICE_SAVING=False

# Audio Settings
SAMPLE_RATE=24000

# Web Player Settings
ENABLE_WEB_PLAYER=True
WEB_PLAYER_PATH="web"
CORS_ORIGINS=["*"]
CORS_ENABLED=True

# Temp File Settings
TEMP_FILE_DIR="api/temp_files"
MAX_TEMP_DIR_SIZE_MB=2048
MAX_TEMP_DIR_AGE_HOURS=1
MAX_TEMP_DIR_COUNT=3
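For reference, pydantic-settings parses list-valued variables such as `API_KEYS` as JSON, so the bracketed syntax above loads directly into `Settings.api_keys`. A minimal sketch, assuming pydantic-settings v2 and the field names from `api/src/core/config.py` in this commit:

```python
# Minimal sketch of how the .env values above load; assumes pydantic-settings v2.
from pydantic_settings import BaseSettings


class Settings(BaseSettings):
    enable_auth: bool = False
    api_keys: list[str] = []  # API_KEYS is parsed from its JSON list form

    class Config:
        env_file = ".env"


settings = Settings()
print(settings.enable_auth, settings.api_keys)
```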
41  README.md
@@ -124,6 +124,47 @@ with client.audio.speech.with_streaming_response.create(
<img src="assets/webui-screenshot.png" width="42%" alt="Web UI Screenshot" style="border: 2px solid #333; padding: 10px;">
</div>

<details>
<summary>API Authentication</summary>

The API supports OpenAI-compatible API key authentication. This feature is disabled by default but can be enabled through environment variables.

To enable authentication:

1. Create a `.env` file in the project root (or copy from `.env.example`)
2. Set the following variables:
```
ENABLE_AUTH=True
API_KEYS=["sk-kokoro-your-api-key-1", "sk-kokoro-your-api-key-2"]
```

When authentication is enabled, all OpenAI-compatible endpoints require an API key in the `Authorization` header; both `Bearer sk-kokoro-xxx` and bare `sk-kokoro-xxx` formats are accepted.

Example usage with authentication:

```bash
# Using curl
curl -X POST "http://localhost:8880/v1/audio/speech" \
  -H "Content-Type: application/json" \
  -H "Authorization: Bearer sk-kokoro-your-api-key" \
  -d '{"model":"kokoro", "input":"Hello world", "voice":"af_heart"}'
```

```python
# Using Python with the OpenAI client
from openai import OpenAI

client = OpenAI(
    api_key="sk-kokoro-your-api-key",
    base_url="http://localhost:8880/v1"
)

response = client.audio.speech.create(
    model="kokoro",
    voice="af_heart",
    input="Hello world"
)
```

</details>

</details>
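A hedged sketch of what a rejected request looks like: protected endpoints raise HTTP 401 with the error payload defined in `api/src/core/auth.py` (shown later in this commit), which FastAPI wraps under a `detail` key, roughly:

```python
# Illustrative check with an invalid key; the error shape mirrors auth.py.
import requests

r = requests.get(
    "http://localhost:8880/v1/models",
    headers={"Authorization": "Bearer sk-kokoro-wrong-key"},
)
assert r.status_code == 401
print(r.json())
# -> {"detail": {"error": "authentication_error",
#                "message": "Invalid API key",
#                "type": "unauthorized"}}
```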
88  Test copy.py
@@ -1,88 +0,0 @@
import requests
import base64
import json
import pydub

text="""Delving into the Abyss: A Deeper Exploration of Meaning in 5 Seconds of Summer's "Jet Black Heart"
5 Seconds of Summer, initially perceived as purveyors of upbeat, radio-friendly pop-punk, embarked on a significant artistic evolution with their album Sounds Good Feels Good. Among its tracks, "Jet Black Heart" stands out as a powerful testament to this shift, moving beyond catchy melodies and embracing a darker, more emotionally complex sound. Released in 2015, the song transcends the typical themes of youthful exuberance and romantic angst, instead plunging into the depths of personal turmoil and the corrosive effects of inner darkness on interpersonal relationships. "Jet Black Heart" is not merely a song about heartbreak; it is a raw and vulnerable exploration of internal struggle, self-destructive patterns, and the precarious flicker of hope that persists even in the face of profound emotional chaos. Through potent metaphors, starkly honest lyrics, and a sonic landscape that mirrors its thematic weight, the song offers a profound meditation on the human condition, grappling with the shadows that reside within us all and their far-reaching consequences.
The very title, "Jet Black Heart," immediately establishes the song's central motif: an intrinsic darkness residing within the narrator's emotional core. The phrase "jet black" is not simply a descriptor of color; it evokes a sense of absolute darkness, a void devoid of light, and a profound absence of hope. This is not a heart merely bruised by external circumstances, but one fundamentally shaded by internal struggles, suggesting a chronic condition of emotional pain. The opening lines, "Everybody's got their demons, even wide awake or dreaming," acknowledge the universality of inner conflict, a shared human experience of battling internal anxieties and insecurities. However, the designation of a "jet black heart" elevates this struggle to a more profound and potentially entrenched level. It suggests a darkness that is not fleeting or situational, but rather a deeply ingrained aspect of the narrator's being, casting a long shadow over their life and relationships. This internal darkness is further amplified by the subsequent metaphor, "there's a hurricane underneath it." The imagery of a hurricane is intensely evocative, conjuring images of destructive force, uncontrollable chaos, and overwhelming power. This "hurricane" represents the tumultuous emotions and internal disorder raging beneath the surface of the narrator’s composed exterior. It is a maelstrom of pain, anxiety, and self-doubt that threatens to erupt and engulf everything in its path. Crucially, this internal hurricane is not merely passive suffering; it is actively "trying to keep us apart," revealing the insidious way in which these inner demons sabotage connections and erect formidable barriers to genuine intimacy and meaningful relationships.
Expanding on this internal struggle, "Jet Black Heart" delves into the narrator's self-destructive patterns, particularly within the realm of romantic relationships. The lyrics "See a war, I wanna fight it, See a match, I wanna strike it" paint a stark picture of a deeply ingrained tendency towards conflict and destruction. This is not simply a reactive response to external aggression, but rather an active seeking out of discord, a subconscious drive to ignite conflict even in peaceful situations. This behavior can be interpreted as a manifestation of their inner turmoil, a projection of their internal chaos onto their external world. Perhaps the narrator, accustomed to internal strife, unconsciously recreates this turbulence in their relationships, finding a perverse sense of familiarity or even control within the chaos. This destructive impulse is further emphasized by the line "Every fire I've ignited faded to gray." The imagery of fire, initially representing passion, intensity, or perhaps even anger, ultimately devolving into "gray" underscores a recurring cycle of destructive behavior that culminates in emptiness and disappointment. The color gray, often associated with neutrality, lifelessness, and a lack of vibrancy, perfectly encapsulates the emotional aftermath of these self-inflicted relational fires. The initial spark of connection or excitement is inevitably extinguished, leaving behind a landscape of emotional flatness and a profound sense of failure in sustaining meaningful bonds. Further solidifying this theme of self-sabotage is the powerful phrase "I write with a poison pen." This metaphor extends beyond mere hurtful words, encompassing actions, behaviors, and the narrator's overall negative influence on their relationships. The "poison pen" suggests a deliberate, albeit perhaps unconscious, act of inflicting harm, highlighting the narrator's painful awareness of their own damaging tendencies and their capacity to erode the very connections they seemingly desire.
However, amidst this pervasive darkness and self-destructive cycle, "Jet Black Heart" subtly introduces a fragile glimmer of hope, a faint light flickering in the abyss. The pivotal moment of vulnerability and potential transformation arrives with the plaintive plea, "But now that I'm broken, now that you're knowing, caught up in a moment, can you see inside?" This is a desperate and profoundly vulnerable call for understanding, a raw and unfiltered exposure of the "jet black heart" after reaching a critical breaking point. The narrator, stripped bare by the weight of their own struggles and the consequences of their self-destructive behavior, finally seeks empathy and genuine connection. The admission of being "broken" is not a declaration of defeat, but rather a necessary precursor to potential healing. It is in this state of vulnerability, in the raw aftermath of emotional collapse, that the narrator dares to ask, "Can you see inside?" This question is laden with yearning, a desperate hope that someone, perhaps a partner in the strained relationship, can perceive beyond the surface darkness and recognize the wounded humanity beneath the "jet black heart." It is a plea for acceptance, not despite the darkness, but perhaps even because of it, a hope that vulnerability will be met not with judgment or rejection, but with compassion and understanding. Despite the acknowledgement of their "poison pen" and destructive tendencies, the narrator also recognizes a paradoxical source of potential redemption within the very relationship that is strained by their inner darkness: "these chemicals moving between us are the reason to start again." The ambiguous term "chemicals" can be interpreted on multiple levels. It could symbolize the complex and often volatile dynamics of human connection, the unpredictable and sometimes turbulent interplay of emotions and personalities in a relationship. Alternatively, "chemicals" might allude to a more literal, perhaps even neurochemical, imbalance within the narrator, suggesting that the very forces driving their darkness might also hold the key to transformation. Crucially, the phrase "reason to start again" emphasizes the potential for renewal and redemption, not a guaranteed outcome. It is a tentative step towards hope, acknowledging that the path forward will be fraught with challenges, but that the possibility of healing and rebuilding remains, however fragile.
The concluding verses of "Jet Black Heart" further solidify this nascent theme of potential transformation and tentative redemption. "The blood in my veins is made up of mistakes" is a powerful and profoundly honest admission of past errors and a crucial acceptance of human imperfection. This acknowledgement of fallibility is essential for personal growth and relational healing. By owning their mistakes, the narrator begins to dismantle the cycle of self-blame and self-destruction, paving the way for a more compassionate and forgiving self-perception. The subsequent lines, "let's forget who we are and dive into the dark, as we burst into color, returning to life," present a radical and transformative vision of shared vulnerability and mutual healing. The call to "forget who we are" is not an invitation to erase individual identity, but rather a suggestion to shed the constructed personas, ego-driven defenses, and pre-conceived notions that often hinder genuine connection. It is about stripping away the masks and embracing a state of raw, unfiltered vulnerability. The imperative to "dive into the dark" is perhaps the most challenging and transformative element of the song. It is a call to confront the pain, to face the demons, and to embrace the shared vulnerability that lies at the heart of genuine intimacy. This shared descent into darkness is not an act of succumbing to despair, but rather a courageous journey towards healing, suggesting that true connection and growth can only emerge from acknowledging and confronting the deepest, most painful aspects of ourselves and each other. The subsequent image of "bursting into color, returning to life" provides a powerful counterpoint to the prevailing darkness, symbolizing transformation, healing, and a vibrant renewal of life and connection. "Bursting into color" evokes a sense of vibrancy, joy, and emotional richness that stands in stark contrast to the "jet black" and "gray" imagery prevalent throughout the song. This suggests that by confronting and embracing the darkness, there is a possibility of emerging transformed, experiencing a rebirth and a renewed sense of purpose and joy in life. "Returning to life" further reinforces this idea of resurrection and revitalization, implying that the journey through darkness is not an end in itself, but rather a necessary passage towards a fuller, more authentic, and more vibrant existence.
Beyond the lyrical content, the musical elements of "Jet Black Heart" contribute significantly to its overall meaning and emotional impact. Compared to 5 Seconds of Summer's earlier, more upbeat work, "Jet Black Heart" adopts a heavier, more brooding sonic landscape. The driving rhythm, the prominent bassline, and the raw, emotive vocal delivery all mirror the thematic weight of the lyrics, creating an atmosphere of intense emotionality and vulnerability. The song's structure, building from a quiet, introspective beginning to a powerful, anthemic chorus, reflects the narrator's journey from internal struggle to a desperate plea for connection and ultimately a tentative hope for transformation.
In conclusion, "Jet Black Heart" by 5 Seconds of Summer is far more than a typical pop song; it is a poignant and deeply resonant exploration of inner darkness, self-destructive tendencies, and the fragile yet persistent hope for human connection and redemption. Through its powerful central metaphor of the "jet black heart," its unflinching portrayal of internal turmoil, and its subtle yet potent message of vulnerability and potential transformation, the song resonates with anyone who has grappled with their own inner demons and the complexities of human relationships. It is a reminder that even in the deepest darkness, a flicker of hope can endure, and that true healing and connection often emerge from the courageous act of confronting and sharing our most vulnerable selves. "Jet Black Heart" stands as a testament to 5 Seconds of Summer's artistic growth, showcasing their capacity to delve into profound emotional territories and create music that is not only catchy and engaging but also deeply meaningful and emotionally resonant, solidifying their position as a band capable of capturing the complexities of the human experience."""
"""Delving into the Abyss: A Deeper Exploration of Meaning in 5 Seconds of Summer's "Jet Black Heart"
5 Seconds of Summer, initially perceived as purveyors of upbeat, radio-friendly pop-punk, embarked on a significant artistic evolution with their album Sounds Good Feels Good. Among its tracks, "Jet Black Heart" stands out as a powerful testament to this shift, moving beyond catchy melodies and embracing a darker, more emotionally complex sound. Released in 2015, the song transcends the typical themes of youthful exuberance and romantic angst, instead plunging into the depths of personal turmoil and the corrosive effects of inner darkness on interpersonal relationships. "Jet Black Heart" is not merely a song about heartbreak; it is a raw and vulnerable exploration of internal struggle, self-destructive patterns, and the precarious flicker of hope that persists even in the face of profound emotional chaos."""

Type="mp3"
response = requests.post(
    "http://localhost:8880/dev/captioned_speech",
    json={
        "model": "kokoro",
        "input": text,
        "voice": "af_heart+af_sky",
        "speed": 1.0,
        "response_format": Type,
        "stream": True,
    },
    stream=True
)

f=open(f"outputstream.{Type}","wb")
for chunk in response.iter_lines(decode_unicode=True):
    if chunk:
        temp_json=json.loads(chunk)
        if temp_json["timestamps"] != []:
            chunk_json=temp_json

        # Decode base64 stream to bytes
        chunk_audio=base64.b64decode(temp_json["audio"].encode("utf-8"))

        # Process streaming chunks
        f.write(chunk_audio)

# Print word-level timestamps
last3=chunk_json["timestamps"][-3]

print(f"CUTTING TO {last3['word']}")

audioseg=pydub.AudioSegment.from_file(f"outputstream.{Type}",format=Type)
audioseg=audioseg[last3["start_time"]*1000:last3["end_time"]*1000]
audioseg.export(f"outputstreamcut.{Type}",format=Type)


"""
response = requests.post(
    "http://localhost:8880/dev/captioned_speech",
    json={
        "model": "kokoro",
        "input": text,
        "voice": "af_heart+af_sky",
        "speed": 1.0,
        "response_format": Type,
        "stream": False,
    },
    stream=True
)

with open(f"outputnostream.{Type}", "wb") as f:
    audio_json=json.loads(response.content)

    # Decode base64 stream to bytes
    chunk_audio=base64.b64decode(audio_json["audio"].encode("utf-8"))

    # Process streaming chunks
    f.write(chunk_audio)

# Print word-level timestamps
print(audio_json["timestamps"])
"""
60  api/src/core/auth.py  Normal file
@@ -0,0 +1,60 @@
"""Authentication utilities for the API"""

from fastapi import Depends, HTTPException, Header, status
from fastapi.security import APIKeyHeader
from typing import Optional

from .config import settings

# Define API key header
api_key_header = APIKeyHeader(name="Authorization", auto_error=False)


async def verify_api_key(
    authorization: Optional[str] = Depends(api_key_header),
) -> Optional[str]:
    """
    Verify the API key from the Authorization header.

    Args:
        authorization: The Authorization header value

    Returns:
        The API key if valid

    Raises:
        HTTPException: If authentication is enabled and the API key is invalid
    """
    # If authentication is disabled, allow all requests
    if not settings.enable_auth:
        return None

    # Check if Authorization header is present
    if not authorization:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail={
                "error": "authentication_error",
                "message": "API key is required",
                "type": "unauthorized",
            },
        )

    # Extract the API key from the Authorization header
    # Support both "Bearer sk-xxx" and "sk-xxx" formats
    api_key = authorization
    if authorization.lower().startswith("bearer "):
        api_key = authorization[7:].strip()

    # Check if the API key is valid
    if not settings.api_keys or api_key not in settings.api_keys:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail={
                "error": "authentication_error",
                "message": "Invalid API key",
                "type": "unauthorized",
            },
        )

    return api_key
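A minimal sketch of wiring this dependency into a router, mirroring the two patterns the OpenAI-compatible routes below use (the `/ping` endpoint here is hypothetical):

```python
# Sketch only: the real routers in this commit use these same two patterns.
from typing import Optional
from fastapi import APIRouter, Depends

from .auth import verify_api_key

# Pattern 1: protect every route on the router
router = APIRouter(dependencies=[Depends(verify_api_key)])

# Pattern 2: also receive the validated key in the handler
@router.get("/ping")  # hypothetical endpoint
async def ping(api_key: Optional[str] = Depends(verify_api_key)):
    # api_key is None when ENABLE_AUTH=False, otherwise the validated key
    return {"ok": True}
```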
@@ -1,4 +1,5 @@
from pydantic_settings import BaseSettings
from typing import Optional, List, Dict, Union


class Settings(BaseSettings):

@@ -9,11 +10,22 @@ class Settings(BaseSettings):
    host: str = "0.0.0.0"
    port: int = 8880

    # GPU and Concurrency Settings
    gpu_device: int = 0  # The GPU device ID to use
    instances_per_gpu: int = 4  # Number of instances to run on each GPU
    max_concurrent: int = instances_per_gpu  # Maximum number of concurrent model instances
    request_queue_size: int = 100  # Maximum size of request queue
    request_timeout: int = 30  # Request timeout in seconds

    # Authentication Settings
    enable_auth: bool = False  # Whether to enable API key authentication
    api_keys: list[str] = []  # List of valid API keys

    # Application Settings
    output_dir: str = "output"
    output_dir_size_limit_mb: float = 500.0  # Maximum size of output directory in MB
    default_voice: str = "af_heart"
    default_voice_code: str | None = None  # If set, overrides the first letter of voice name, though api call param still takes precedence
    default_voice_code: Optional[str] = None  # If set, overrides the first letter of voice name, though api call param still takes precedence
    use_gpu: bool = True  # Whether to use GPU acceleration if available
    allow_local_voice_saving: bool = (
        False  # Whether to allow saving combined voices locally

@@ -50,5 +62,11 @@ class Settings(BaseSettings):
    class Config:
        env_file = ".env"

    def model_post_init(self, __context):
        """Post-initialization processing to handle special values"""
        # Convert string "null" to None for default_voice_code
        if self.default_voice_code == "null":
            self.default_voice_code = None


settings = Settings()
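One consequence of the `model_post_init` hook worth noting: a literal `DEFAULT_VOICE_CODE="null"` in `.env` is coerced to `None` after validation. A small sketch, assuming the import path laid out in this commit:

```python
# Sketch: pydantic v2 runs model_post_init after field loading,
# so the string "null" never survives as a language code.
from api.src.core.config import Settings  # path as laid out in this commit

s = Settings(default_voice_code="null")
assert s.default_voice_code is None
```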
172  api/src/inference/instance_pool.py  Normal file
@@ -0,0 +1,172 @@
"""GPU instance pool and request queue management."""

import asyncio
from typing import Optional, Dict, List, Any
import torch
from loguru import logger

from ..core.config import settings
from .model_manager import ModelManager, get_manager


class GPUInstance:
    """Represents a model instance running on a specific GPU."""

    def __init__(self, device_id: int, instance_id: int):
        self.device_id = device_id
        self.instance_id = instance_id  # Instance ID for the same GPU
        self.manager: Optional[ModelManager] = None
        self.is_busy: bool = False
        self.current_task: Optional[asyncio.Task] = None

    async def initialize(self) -> None:
        """Initialize the model instance on the specified GPU."""
        try:
            # Set CUDA device
            torch.cuda.set_device(self.device_id)
            # Create a new model manager instance for this GPU instance
            self.manager = await get_manager()
            # Initialize with warmup
            await self.manager.initialize_with_warmup(None)
            logger.info(f"Initialized model instance {self.instance_id} on GPU {self.device_id}")
        except Exception as e:
            logger.error(f"Failed to initialize instance {self.instance_id} on GPU {self.device_id}: {e}")
            raise

    def __del__(self):
        """Cleanup when instance is destroyed."""
        if self.manager:
            self.manager.unload_all()
            self.manager = None


class InstancePool:
    """Manages multiple GPU instances and request queue."""

    _instance = None

    def __init__(self):
        self.instances: List[GPUInstance] = []
        self.request_queue: asyncio.Queue = asyncio.Queue(maxsize=settings.request_queue_size)
        self.current_instance_idx = 0

    @classmethod
    async def get_instance(cls) -> 'InstancePool':
        """Get or create singleton instance."""
        if cls._instance is None:
            cls._instance = cls()
            await cls._instance.initialize()
        return cls._instance

    async def initialize(self) -> None:
        """Initialize GPU instances."""
        # Create multiple instances on the same GPU
        for i in range(settings.instances_per_gpu):
            instance = GPUInstance(settings.gpu_device, i)
            try:
                await instance.initialize()
                self.instances.append(instance)
                logger.info(f"Successfully initialized instance {i} on GPU {settings.gpu_device}")
            except Exception as e:
                logger.error(f"Failed to initialize instance {i}: {e}")
                # If we failed to initialize any instance, cleanup and raise
                if not self.instances:
                    raise RuntimeError("Failed to initialize any GPU instances")
                break

        if not self.instances:
            raise RuntimeError("No GPU instances initialized")

        logger.info(f"Successfully initialized {len(self.instances)} instances on GPU {settings.gpu_device}")

        # Start request processor
        asyncio.create_task(self._process_queue())

    def get_next_available_instance(self) -> Optional[GPUInstance]:
        """Get next available GPU instance using round-robin."""
        start_idx = self.current_instance_idx

        # Try to find an available instance
        for _ in range(len(self.instances)):
            instance = self.instances[self.current_instance_idx]
            self.current_instance_idx = (self.current_instance_idx + 1) % len(self.instances)

            if not instance.is_busy:
                return instance

            # If we're back at start, no instance is available
            if self.current_instance_idx == start_idx:
                return None

        return None

    async def _process_queue(self) -> None:
        """Process requests from queue."""
        while True:
            try:
                # Get request from queue
                request = await self.request_queue.get()
                text, voice_info = request["text"], request["voice_info"]
                future = request["future"]

                # Get available instance
                instance = self.get_next_available_instance()
                if instance is None:
                    # No instance available, put back in queue
                    await self.request_queue.put(request)
                    await asyncio.sleep(0.1)
                    continue

                # Mark instance as busy
                instance.is_busy = True
                logger.debug(f"Processing request on instance {instance.instance_id}")

                try:
                    # Process request
                    result = []
                    async for chunk in instance.manager.generate(text, voice_info):
                        result.append(chunk)
                    future.set_result(result)
                except Exception as e:
                    logger.error(f"Error in instance {instance.instance_id}: {e}")
                    future.set_exception(e)
                finally:
                    # Mark instance as available
                    instance.is_busy = False
                    self.request_queue.task_done()
                    logger.debug(f"Instance {instance.instance_id} is now available")

            except Exception as e:
                logger.error(f"Error processing request: {e}")
                await asyncio.sleep(1)

    async def process_request(self, text: str, voice_info: tuple) -> List[Any]:
        """Submit request to queue and wait for result."""
        # Create future to get result
        future = asyncio.Future()

        # Create request
        request = {
            "text": text,
            "voice_info": voice_info,
            "future": future
        }

        try:
            # Put request in queue with timeout
            await asyncio.wait_for(
                self.request_queue.put(request),
                timeout=settings.request_timeout
            )
        except asyncio.TimeoutError:
            raise RuntimeError("Request queue is full")

        try:
            # Wait for result with timeout
            result = await asyncio.wait_for(
                future,
                timeout=settings.request_timeout
            )
            return result
        except asyncio.TimeoutError:
            raise RuntimeError("Request processing timed out")
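For orientation, the TTS service later in this diff drives the pool roughly like this (a sketch to be run inside an async context; the tuple is the `(voice_name, voice_path)` pair the service resolves, and the path shown is hypothetical):

```python
# Sketch of pool usage, mirroring tts_service.generate_audio_stream below.
pool = await InstancePool.get_instance()
chunks = await pool.process_request(
    "Hello world",
    ("af_heart", "/models/voices/af_heart.pt"),  # hypothetical voice path
)
for chunk in chunks:
    ...  # each element is whatever instance.manager.generate() yielded
```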
@@ -14,9 +14,6 @@ from .kokoro_v1 import KokoroV1
class ModelManager:
    """Manages Kokoro V1 model loading and inference."""

    # Singleton instance
    _instance = None

    def __init__(self, config: Optional[ModelConfig] = None):
        """Initialize manager.

@@ -158,7 +155,7 @@ Model files not found! You need to download the Kokoro V1 model:


async def get_manager(config: Optional[ModelConfig] = None) -> ModelManager:
    """Get model manager instance.
    """Create a new model manager instance.

    Args:
        config: Optional configuration override

@@ -166,6 +163,4 @@ async def get_manager(config: Optional[ModelConfig] = None) -> ModelManager:
    Returns:
        ModelManager instance
    """
    if ModelManager._instance is None:
        ModelManager._instance = ModelManager(config)
    return ModelManager._instance
    return ModelManager(config)
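The net effect of this hunk: `get_manager()` stops caching a process-wide singleton and constructs a fresh `ModelManager` per call, which is what lets each `GPUInstance` in the pool above own an independent model. A sketch of the new behavior:

```python
# After this commit, repeated calls yield independent managers.
m1 = await get_manager()
m2 = await get_manager()
assert m1 is not m2  # previously both names pointed at one cached instance
```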
@@ -47,7 +47,7 @@ setup_logger()
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Lifespan context manager for model initialization"""
    from .inference.model_manager import get_manager
    from .inference.instance_pool import InstancePool
    from .inference.voice_manager import get_manager as get_voice_manager
    from .services.temp_manager import cleanup_temp_files

@@ -57,14 +57,17 @@ async def lifespan(app: FastAPI):
    logger.info("Loading TTS model and voice packs...")

    try:
        # Initialize managers
        model_manager = await get_manager()
        # Initialize voice manager
        voice_manager = await get_voice_manager()

        # Initialize model with warmup and get status
        device, model, voicepack_count = await model_manager.initialize_with_warmup(
            voice_manager
        )

        # Initialize instance pool
        instance_pool = await InstancePool.get_instance()

        # Get first instance for status info
        first_instance = instance_pool.instances[0]
        device = f"cuda:{first_instance.device_id}"
        model = first_instance.manager.current_backend
        instance_count = len(instance_pool.instances)

    except Exception as e:
        logger.error(f"Failed to initialize model: {e}")

@@ -85,8 +88,14 @@ async def lifespan(app: FastAPI):
    {boundary}
    """
    startup_msg += f"\nModel warmed up on {device}: {model}"
    startup_msg += f"CUDA: {torch.cuda.is_available()}"
    startup_msg += f"\n{voicepack_count} voice packs loaded"
    startup_msg += f"\nCUDA: {torch.cuda.is_available()}"
    startup_msg += f"\nRunning {instance_count} instances on GPU {settings.gpu_device}"
    startup_msg += f"\nMax concurrent requests: {settings.max_concurrent}"
    startup_msg += f"\nRequest queue size: {settings.request_queue_size}"

    # Add language code info
    lang_code_info = settings.default_voice_code or f"auto (from voice name: {settings.default_voice[0].lower()})"
    startup_msg += f"\nDefault language code: {lang_code_info}"

    # Add web player info if enabled
    if settings.enable_web_player:

@@ -140,7 +149,14 @@ async def health_check():
@app.get("/v1/test")
async def test_endpoint():
    """Test endpoint to verify routing"""
    return {"status": "ok"}
    from .core.config import settings

    # Include authentication status in response
    return {
        "status": "ok",
        "auth_enabled": settings.enable_auth,
        "api_keys_configured": len(settings.api_keys) > 0 if settings.api_keys else False
    }


if __name__ == "__main__":
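Since `/v1/test` is registered on the app rather than the authenticated router, it can be polled without a key to see what this change reports, as `examples/test_auth.py` below does:

```python
# Sketch: unauthenticated status check against a local server.
import requests

print(requests.get("http://localhost:8880/v1/test").json())
# e.g. {"status": "ok", "auth_enabled": True, "api_keys_configured": True}
```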
@@ -143,6 +143,38 @@ async def get_system_info():
    }


@router.get("/debug/config")
async def get_config_info():
    """Get information about the current configuration."""
    from ..core.config import settings
    from ..inference.kokoro_v1 import LANG_CODES

    # Get the default voice code
    default_voice_code = settings.default_voice_code
    if default_voice_code == "null":
        default_voice_code = None

    # Get the first letter of the default voice
    default_voice_first_letter = settings.default_voice[0].lower() if settings.default_voice else None

    # Determine the effective language code
    effective_lang_code = default_voice_code or default_voice_first_letter

    # Check if the effective language code is valid
    is_valid_lang_code = effective_lang_code in LANG_CODES if effective_lang_code else False

    return {
        "default_voice": settings.default_voice,
        "default_voice_code": default_voice_code,
        "default_voice_first_letter": default_voice_first_letter,
        "effective_lang_code": effective_lang_code,
        "is_valid_lang_code": is_valid_lang_code,
        "available_lang_codes": LANG_CODES,
        "auth_enabled": settings.enable_auth,
        "api_keys_configured": len(settings.api_keys) > 0 if settings.api_keys else False,
    }


@router.get("/debug/session_pools")
async def get_session_pool_info():
    """Get information about ONNX session pools."""
@@ -5,7 +5,7 @@ import json
import os
import re
import tempfile
from typing import AsyncGenerator, Dict, List, Union, Tuple
from typing import AsyncGenerator, Dict, List, Union, Tuple, Optional
from urllib import response
import numpy as np

@@ -19,6 +19,7 @@ from loguru import logger

from ..inference.base import AudioChunk
from ..core.config import settings
from ..core.auth import verify_api_key
from ..services.audio import AudioService
from ..services.tts_service import TTSService
from ..structures import OpenAISpeechRequest

@@ -44,6 +45,7 @@ _openai_mappings = load_openai_mappings()
router = APIRouter(
    tags=["OpenAI Compatible TTS"],
    responses={404: {"description": "Not found"}},
    dependencies=[Depends(verify_api_key)],  # Apply authentication to all routes
)

# Global TTSService instance with lock

@@ -142,14 +144,25 @@ async def stream_audio_chunks(
    if hasattr(request, "return_timestamps"):
        unique_properties["return_timestamps"]=request.return_timestamps

    # Determine language code with proper fallback
    lang_code = request.lang_code
    if not lang_code:
        # Use default_voice_code from settings if available
        lang_code = settings.default_voice_code
    # Otherwise, use first letter of voice name
    if not lang_code and voice_name:
        lang_code = voice_name[0].lower()

    # Log the language code being used
    logger.info(f"Starting audio generation with lang_code: {lang_code}")

    try:
        logger.info(f"Starting audio generation with lang_code: {request.lang_code}")
        async for chunk_data in tts_service.generate_audio_stream(
            text=request.input,
            voice=voice_name,
            speed=request.speed,
            output_format=request.response_format,
            lang_code=request.lang_code or settings.default_voice_code or voice_name[0].lower(),
            lang_code=lang_code,
            normalization_options=request.normalization_options,
            return_timestamps=unique_properties["return_timestamps"],
        ):

@@ -171,10 +184,10 @@ async def stream_audio_chunks(

@router.post("/audio/speech")
async def create_speech(
    request: OpenAISpeechRequest,
    client_request: Request,
    x_raw_response: str = Header(None, alias="x-raw-response"),
    api_key: Optional[str] = Depends(verify_api_key),
):
    """OpenAI-compatible endpoint for text-to-speech"""
    # Validate model before processing request

@@ -280,12 +293,25 @@ async def create_speech(
        )
    else:
        # Generate complete audio using public interface

        # Determine language code with proper fallback
        lang_code = request.lang_code
        if not lang_code:
            # Use default_voice_code from settings if available
            lang_code = settings.default_voice_code
        # Otherwise, use first letter of voice name
        if not lang_code and voice_name:
            lang_code = voice_name[0].lower()

        # Log the language code being used
        logger.info(f"Starting audio generation with lang_code: {lang_code}")

        audio_data = await tts_service.generate_audio(
            text=request.input,
            voice=voice_name,
            speed=request.speed,
            normalization_options=request.normalization_options,
            lang_code=request.lang_code,
            lang_code=lang_code,
        )

        audio_data = await AudioService.convert_audio(

@@ -351,7 +377,10 @@ async def create_speech(

@router.get("/download/{filename}")
async def download_audio_file(filename: str):
async def download_audio_file(
    filename: str,
    api_key: Optional[str] = Depends(verify_api_key),
):
    """Download a generated audio file from temp storage"""
    try:
        from ..core.paths import _find_file, get_content_type

@@ -387,7 +416,9 @@ async def download_audio_file(filename: str):

@router.get("/models")
async def list_models():
async def list_models(
    api_key: Optional[str] = Depends(verify_api_key),
):
    """List all available models"""
    try:
        # Create standard model list

@@ -428,7 +459,10 @@ async def list_models():
    )

@router.get("/models/{model}")
async def retrieve_model(model: str):
async def retrieve_model(
    model: str,
    api_key: Optional[str] = Depends(verify_api_key),
):
    """Retrieve a specific model"""
    try:
        # Define available models

@@ -480,7 +514,9 @@ async def retrieve_model(model: str):
    )

@router.get("/audio/voices")
async def list_voices():
async def list_voices(
    api_key: Optional[str] = Depends(verify_api_key),
):
    """List all available voices for text-to-speech"""
    try:
        tts_service = await get_tts_service()

@@ -499,7 +535,10 @@ async def list_voices():

@router.post("/audio/voices/combine")
async def combine_voices(request: Union[str, List[str]]):
async def combine_voices(
    request: Union[str, List[str]],
    api_key: Optional[str] = Depends(verify_api_key),
):
    """Combine multiple voices into a new voice and return the .pt file.

    Args:
@@ -27,7 +27,7 @@ class StreamingAudioWriter:
            self.output_buffer = BytesIO()
            self.container = av.open(self.output_buffer, mode="w", format=self.format)
            self.stream = self.container.add_stream(codec_map[self.format], sample_rate=self.sample_rate, layout='mono' if self.channels == 1 else 'stereo')
            self.stream.bit_rate = 128000
            self.stream.bit_rate = 96000
        else:
            raise ValueError(f"Unsupported format: {format}")

@@ -43,34 +43,30 @@ class StreamingAudioWriter:

        if finalize:
            if self.format != "pcm":
                packets = self.stream.encode(None)
                for packet in packets:
                # Flush encoder buffers
                for packet in self.stream.encode(None):
                    self.container.mux(packet)

                data=self.output_buffer.getvalue()
                self.container.close()
                data = self.output_buffer.getvalue()
                self.output_buffer.seek(0)
                self.output_buffer.truncate(0)
                return data
            return b""

        if audio_data is None or len(audio_data) == 0:
            return b""

        if self.format == "pcm":
            # Write raw bytes
            return audio_data.tobytes()
        else:
            frame = av.AudioFrame.from_ndarray(audio_data.reshape(1, -1), format='s16', layout='mono' if self.channels == 1 else 'stereo')
            frame.sample_rate=self.sample_rate

            frame.sample_rate = self.sample_rate
            frame.pts = self.pts
            self.pts += frame.samples

            packets = self.stream.encode(frame)
            for packet in packets:
            for packet in self.stream.encode(frame):
                self.container.mux(packet)

            data = self.output_buffer.getvalue()
            self.output_buffer.seek(0)
            self.output_buffer.truncate(0)
            return data
            # Return only empty bytes, keeping the container open
            return b""
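In short, this hunk moves the byte hand-off to finalization: per-chunk calls now encode and mux into the in-memory container but return empty bytes, while the finalize path flushes the encoder with `encode(None)`, closes the container, and returns the complete buffer in one piece.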
@@ -4,7 +4,7 @@ import asyncio
import os
import tempfile
import time
from typing import AsyncGenerator, List, Optional, Tuple, Union
from typing import AsyncGenerator, List, Optional, Tuple, Union, Dict

from ..inference.base import AudioChunk
import numpy as np

@@ -20,6 +20,8 @@ from .audio import AudioNormalizer, AudioService
from .text_processing import tokenize
from .text_processing.text_processor import process_text_chunk, smart_split
from ..structures.schemas import NormalizationOptions
from ..core import paths
from ..inference.instance_pool import InstancePool

class TTSService:
    """Text-to-speech service."""

@@ -27,20 +29,30 @@ class TTSService:
    # Limit concurrent chunk processing
    _chunk_semaphore = asyncio.Semaphore(4)

    def __init__(self, output_dir: str = None):
    def __init__(self):
        """Initialize service."""
        self.output_dir = output_dir
        self.model_manager = None
        self._voice_manager = None
        self.voice_manager = None
        self.instance_pool = None

    @classmethod
    async def create(cls, output_dir: str = None) -> "TTSService":
    async def create(cls) -> "TTSService":
        """Create and initialize TTSService instance."""
        service = cls(output_dir)
        service.model_manager = await get_model_manager()
        service._voice_manager = await get_voice_manager()
        service = cls()
        await service.initialize()
        return service

    async def initialize(self) -> None:
        """Initialize service components."""
        # Initialize model manager
        self.model_manager = await get_model_manager()

        # Initialize voice manager
        self.voice_manager = await get_voice_manager()

        # Initialize instance pool
        self.instance_pool = await InstancePool.get_instance()

    async def _process_chunk(
        self,
        chunk_text: str,

@@ -122,7 +134,7 @@ class TTSService:
        else:
            # For legacy backends, load voice tensor
            voice_tensor = await self._voice_manager.load_voice(
            voice_tensor = await self.voice_manager.load_voice(
                voice_name, device=backend.device
            )
            chunk_data = await self.model_manager.generate(

@@ -205,7 +217,7 @@ class TTSService:
        # Load and combine voices
        voice_tensors = []
        for v, w in zip(voice_parts, weights):
            path = await self._voice_manager.get_voice_path(v)
            path = await self.voice_manager.get_voice_path(v)
            if not path:
                raise RuntimeError(f"Voice not found: {v}")
            logger.debug(f"Loading voice tensor from: {path}")

@@ -229,7 +241,7 @@ class TTSService:
        # Single voice
        if "(" in voice and ")" in voice:
            voice = voice.split("(")[0].strip()
        path = await self._voice_manager.get_voice_path(voice)
        path = await self.voice_manager.get_voice_path(voice)
        if not path:
            raise RuntimeError(f"Voice not found: {voice}")
        logger.debug(f"Using single voice path: {path}")

@@ -243,20 +255,27 @@ class TTSService:
        text: str,
        voice: str,
        speed: float = 1.0,
        output_format: str = "wav",
        output_format: str = "mp3",
        lang_code: Optional[str] = None,
        normalization_options: Optional[NormalizationOptions] = NormalizationOptions(),
        return_timestamps: Optional[bool] = False,
        normalization_options: Optional[Dict] = None,
        return_timestamps: bool = False,
    ) -> AsyncGenerator[AudioChunk, None]:
        """Generate and stream audio chunks."""
        stream_normalizer = AudioNormalizer()
        chunk_index = 0
        current_offset=0.0
        try:
            # Get backend
            backend = self.model_manager.get_backend()
        """Generate audio stream from text.

            # Get voice path, handling combined voices
        Args:
            text: Input text
            voice: Voice name
            speed: Speech speed multiplier
            output_format: Output audio format
            lang_code: Language code for text processing
            normalization_options: Text normalization options
            return_timestamps: Whether to return timestamps

        Yields:
            Audio chunks
        """
        try:
            # Get voice path
            voice_name, voice_path = await self._get_voice_path(voice)
            logger.debug(f"Using voice path: {voice_path}")

@@ -266,70 +285,16 @@ class TTSService:
                f"Using lang_code '{pipeline_lang_code}' for voice '{voice_name}' in audio stream"
            )

            # Process request through instance pool
            chunks = await self.instance_pool.process_request(text, (voice_name, voice_path))

            # Process text in chunks with smart splitting
            async for chunk_text, tokens in smart_split(text, lang_code=lang_code, normalization_options=normalization_options):
                try:
                    # Process audio for chunk
                    async for chunk_data in self._process_chunk(
                        chunk_text,  # Pass text for Kokoro V1
                        tokens,  # Pass tokens for legacy backends
                        voice_name,  # Pass voice name
                        voice_path,  # Pass voice path
                        speed,
                        output_format,
                        is_first=(chunk_index == 0),
                        is_last=False,  # We'll update the last chunk later
                        normalizer=stream_normalizer,
                        lang_code=pipeline_lang_code,  # Pass lang_code
                        return_timestamps=return_timestamps,
                    ):
                        if chunk_data.word_timestamps is not None:
                            for timestamp in chunk_data.word_timestamps:
                                timestamp.start_time+=current_offset
                                timestamp.end_time+=current_offset

                        current_offset+=len(chunk_data.audio) / 24000

                        if chunk_data.output is not None:
                            yield chunk_data

                        else:
                            logger.warning(
                                f"No audio generated for chunk: '{chunk_text[:100]}...'"
                            )
                        chunk_index += 1
                except Exception as e:
                    logger.error(
                        f"Failed to process audio for chunk: '{chunk_text[:100]}...'. Error: {str(e)}"
                    )
                    continue

            # Only finalize if we successfully processed at least one chunk
            if chunk_index > 0:
                try:
                    # Empty tokens list to finalize audio
                    async for chunk_data in self._process_chunk(
                        "",  # Empty text
                        [],  # Empty tokens
                        voice_name,
                        voice_path,
                        speed,
                        output_format,
                        is_first=False,
                        is_last=True,  # Signal this is the last chunk
                        normalizer=stream_normalizer,
                        lang_code=pipeline_lang_code,  # Pass lang_code
                    ):
                        if chunk_data.output is not None:
                            yield chunk_data
                except Exception as e:
                    logger.error(f"Failed to finalize audio stream: {str(e)}")
            # Yield chunks
            for chunk in chunks:
                yield chunk

        except Exception as e:
            logger.error(f"Error in phoneme audio generation: {str(e)}")
            raise e
            logger.error(f"Error generating audio: {e}")
            raise

    async def generate_audio(
        self,

@@ -362,11 +327,11 @@ class TTSService:
        Returns:
            Combined voice tensor
        """
        return await self._voice_manager.combine_voices(voices)
        return await self.voice_manager.combine_voices(voices)

    async def list_voices(self) -> List[str]:
        """List available voices."""
        return await self._voice_manager.list_voices()
        return await paths.list_voices()

    async def generate_from_phonemes(
        self,
@@ -1,51 +1,51 @@
from collections.abc import AsyncIterable, Iterable

import json
import typing
from pydantic import BaseModel
from starlette.background import BackgroundTask
from starlette.concurrency import iterate_in_threadpool
from starlette.responses import JSONResponse, StreamingResponse


class JSONStreamingResponse(StreamingResponse, JSONResponse):
    """StreamingResponse that also renders with JSON."""

    def __init__(
        self,
        content: Iterable | AsyncIterable,
        status_code: int = 200,
        headers: dict[str, str] | None = None,
        media_type: str | None = None,
        background: BackgroundTask | None = None,
    ) -> None:
        if isinstance(content, AsyncIterable):
            self._content_iterable: AsyncIterable = content
        else:
            self._content_iterable = iterate_in_threadpool(content)

        async def body_iterator() -> AsyncIterable[bytes]:
            async for content_ in self._content_iterable:
                if isinstance(content_, BaseModel):
                    content_ = content_.model_dump()
                yield self.render(content_)

        self.body_iterator = body_iterator()
        self.status_code = status_code
        if media_type is not None:
            self.media_type = media_type
        self.background = background
        self.init_headers(headers)

    def render(self, content: typing.Any) -> bytes:
        return (json.dumps(
            content,
            ensure_ascii=False,
            allow_nan=False,
            indent=None,
            separators=(",", ":"),
        ) + "\n").encode("utf-8")
62  examples/test_auth.py  Normal file
@@ -0,0 +1,62 @@
#!/usr/bin/env python
"""
Test script for Kokoro TTS API authentication
"""

import argparse
import json
import os
import sys
from typing import Optional

import requests


def test_auth(base_url: str, api_key: Optional[str] = None) -> None:
    """Test authentication with the API"""
    # Test the test endpoint
    test_url = f"{base_url}/v1/test"
    test_response = requests.get(test_url)
    test_data = test_response.json()

    print(f"Test endpoint response: {json.dumps(test_data, indent=2)}")
    print(f"Authentication enabled: {test_data.get('auth_enabled', False)}")
    print(f"API keys configured: {test_data.get('api_keys_configured', False)}")

    # Test the models endpoint
    models_url = f"{base_url}/v1/models"
    headers = {}
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"

    print("\nTesting models endpoint...")
    models_response = requests.get(models_url, headers=headers)

    if models_response.status_code == 200:
        print("✅ Authentication successful!")
        models_data = models_response.json()
        print(f"Available models: {', '.join([model['id'] for model in models_data.get('data', [])])}")
    elif models_response.status_code == 401:
        print("❌ Authentication failed: Unauthorized")
        print(f"Error details: {models_response.json()}")
    else:
        print(f"❌ Unexpected response: {models_response.status_code}")
        print(f"Response: {models_response.text}")


def main() -> None:
    """Main function"""
    parser = argparse.ArgumentParser(description="Test Kokoro TTS API authentication")
    parser.add_argument("--url", default="http://localhost:8880", help="Base URL of the API")
    parser.add_argument("--key", help="API key to use for authentication")

    args = parser.parse_args()

    # Use environment variable if key not provided
    api_key = args.key or os.environ.get("KOKORO_API_KEY")

    test_auth(args.url, api_key)


if __name__ == "__main__":
    main()
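Typical invocation: `python examples/test_auth.py --url http://localhost:8880 --key sk-kokoro-1234567890abcdef` (a key from `.env.example`), or export `KOKORO_API_KEY` and omit `--key`.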
3  requirements-test.txt  Normal file
@@ -0,0 +1,3 @@
aiohttp>=3.8.0
numpy>=1.21.0
loguru>=0.6.0
@@ -13,4 +13,4 @@ export WEB_PLAYER_PATH=$PROJECT_ROOT/web

# Run FastAPI with GPU extras using uv run
uv pip install -e ".[gpu]"
uv run --no-sync uvicorn api.src.main:app --host 0.0.0.0 --port 8880
uv run --no-sync uvicorn api.src.main:app --host 0.0.0.0 --port 50888
172  tests/performance_test.py  Normal file
@@ -0,0 +1,172 @@
import asyncio
import time
from datetime import datetime
import aiohttp
import json
from loguru import logger
import numpy as np

# Test configuration
BASE_URL = "http://localhost:50888"
API_KEY = "sk-kokoro-f7a9b2c8e6d4g3h1j5k0"
CONCURRENT_REQUESTS = 3  # Number of concurrent requests
TOTAL_REQUESTS = 100  # Total number of requests to make
TEST_DURATION = 60  # Test duration in seconds
CHUNK_SIZE = 8192  # Increased chunk size for better performance

# Test payload
TEST_PAYLOAD = {
    "model": "kokoro",
    "input": "This is a performance test text.",  # Shorter test text
    "voice": "af_heart",
    "response_format": "mp3",
    "download_format": "mp3",
    "speed": 1,
    "stream": True,
    "return_download_link": False,
    "lang_code": "a"
}

class PerformanceMetrics:
    def __init__(self):
        self.request_times = []
        self.audio_sizes = []
        self.success_count = 0
        self.error_count = 0
        self.start_time = None
        self.end_time = None
        self.current_requests = 0
        self.max_concurrent = 0

    def add_request(self, duration, audio_size=0, success=True):
        self.request_times.append(duration)
        if audio_size > 0:
            self.audio_sizes.append(audio_size)
        if success:
            self.success_count += 1
        else:
            self.error_count += 1

    def update_concurrent(self, delta):
        self.current_requests += delta
        self.max_concurrent = max(self.max_concurrent, self.current_requests)

    def calculate_metrics(self):
        test_duration = (self.end_time - self.start_time).total_seconds()
        qps = self.success_count / test_duration if test_duration > 0 else 0
        avg_latency = np.mean(self.request_times) if self.request_times else 0
        p95_latency = np.percentile(self.request_times, 95) if self.request_times else 0
        p99_latency = np.percentile(self.request_times, 99) if self.request_times else 0

        total_audio_mb = sum(self.audio_sizes) / (1024 * 1024)
        audio_throughput = total_audio_mb / test_duration if test_duration > 0 else 0

        return {
            "qps": qps,
            "avg_latency": avg_latency,
            "p95_latency": p95_latency,
            "p99_latency": p99_latency,
            "success_rate": (self.success_count / (self.success_count + self.error_count)) * 100,
            "audio_throughput_mbps": audio_throughput,
            "total_requests": self.success_count + self.error_count,
            "successful_requests": self.success_count,
            "failed_requests": self.error_count,
            "test_duration": test_duration,
            "max_concurrent": self.max_concurrent
        }

async def make_request(session, metrics, semaphore, request_id):
    try:
        async with semaphore:
            metrics.update_concurrent(1)
            start_time = time.time()
            headers = {
                "Authorization": f"Bearer {API_KEY}",
                "Accept": "audio/mpeg"
            }

            try:
                async with session.post(
                    f"{BASE_URL}/v1/audio/speech",
                    json=TEST_PAYLOAD,
                    headers=headers,
                    ssl=False,
                    timeout=aiohttp.ClientTimeout(total=30)
                ) as response:
                    if response.status == 200:
                        total_size = 0
                        audio_data = bytearray()

                        try:
                            async for chunk in response.content.iter_chunked(CHUNK_SIZE):
                                if chunk:  # Only process non-empty chunks
                                    audio_data.extend(chunk)
                                    total_size += len(chunk)

                            duration = time.time() - start_time
                            metrics.add_request(duration, total_size, True)
                            logger.debug(f"Request {request_id} completed successfully: {total_size} bytes in {duration:.2f}s")
                            return True

                        except Exception as chunk_error:
                            logger.error(f"Chunk processing error in request {request_id}: {str(chunk_error)}")
                            duration = time.time() - start_time
                            metrics.add_request(duration, success=False)
                            return False
                    else:
                        error_text = await response.text()
                        logger.error(f"Request {request_id} failed with status {response.status}: {error_text}")
                        duration = time.time() - start_time
                        metrics.add_request(duration, success=False)
                        return False

            except asyncio.TimeoutError:
                logger.error(f"Request {request_id} timed out")
                duration = time.time() - start_time
                metrics.add_request(duration, success=False)
                return False

            except Exception as e:
                logger.error(f"Request {request_id} failed with error: {str(e)}")
                duration = time.time() - start_time
                metrics.add_request(duration, success=False)
                return False

    finally:
        metrics.update_concurrent(-1)

async def run_load_test():
    metrics = PerformanceMetrics()
    metrics.start_time = datetime.now()

    semaphore = asyncio.Semaphore(CONCURRENT_REQUESTS)

    async with aiohttp.ClientSession() as session:
        tasks = []
        for i in range(TOTAL_REQUESTS):
            task = asyncio.create_task(make_request(session, metrics, semaphore, i+1))
            tasks.append(task)

        await asyncio.gather(*tasks)

    metrics.end_time = datetime.now()
    return metrics

def print_results(metrics_data):
    logger.info("\n=== Performance Test Results ===")
    logger.info(f"Total Requests: {metrics_data['total_requests']}")
    logger.info(f"Successful Requests: {metrics_data['successful_requests']}")
    logger.info(f"Failed Requests: {metrics_data['failed_requests']}")
    logger.info(f"Success Rate: {metrics_data['success_rate']:.2f}%")
    logger.info(f"Test Duration: {metrics_data['test_duration']:.2f} seconds")
    logger.info(f"QPS: {metrics_data['qps']:.2f}")
    logger.info(f"Average Latency: {metrics_data['avg_latency']*1000:.2f} ms")
    logger.info(f"P95 Latency: {metrics_data['p95_latency']*1000:.2f} ms")
    logger.info(f"P99 Latency: {metrics_data['p99_latency']*1000:.2f} ms")
    logger.info(f"Audio Throughput: {metrics_data['audio_throughput_mbps']:.2f} MB/s")
    logger.info(f"Max Concurrent Requests: {metrics_data['max_concurrent']}")

if __name__ == "__main__":
    logger.info("Starting performance test...")
    metrics = asyncio.run(run_load_test())
    print_results(metrics.calculate_metrics())
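Note that `BASE_URL` assumes the non-default port 50888 set in the GPU start script above, and `API_KEY` is a placeholder that must match an entry in `API_KEYS`; run it directly with `python tests/performance_test.py` against a live server.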