mirror of
https://github.com/remsky/Kokoro-FastAPI.git
synced 2025-04-13 09:39:17 +00:00
commit
5afb9e9be8
5 changed files with 33 additions and 28 deletions
|
@ -25,6 +25,7 @@ class QueueDB:
|
|||
status TEXT DEFAULT 'pending',
|
||||
output_file TEXT,
|
||||
processing_time REAL,
|
||||
speed REAL,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP)
|
||||
""")
|
||||
conn.commit()
|
||||
|
@ -42,18 +43,19 @@ class QueueDB:
|
|||
status TEXT DEFAULT 'pending',
|
||||
output_file TEXT,
|
||||
processing_time REAL,
|
||||
speed REAL,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP)
|
||||
""")
|
||||
conn.commit()
|
||||
|
||||
def add_request(self, text: str, voice: str, stitch_long_output: bool = True) -> int:
|
||||
def add_request(self, text: str, voice: str, speed: float, stitch_long_output: bool = True) -> int:
|
||||
"""Add a new TTS request to the queue"""
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
try:
|
||||
c = conn.cursor()
|
||||
c.execute(
|
||||
"INSERT INTO tts_queue (text, voice, stitch_long_output) VALUES (?, ?, ?)",
|
||||
(text, voice, stitch_long_output)
|
||||
"INSERT INTO tts_queue (text, voice, speed, stitch_long_output) VALUES (?, ?, ?, ?)",
|
||||
(text, voice, speed, stitch_long_output)
|
||||
)
|
||||
request_id = c.lastrowid
|
||||
conn.commit()
|
||||
|
@ -62,8 +64,8 @@ class QueueDB:
|
|||
self._ensure_table_if_needed(conn)
|
||||
c = conn.cursor()
|
||||
c.execute(
|
||||
"INSERT INTO tts_queue (text, voice, stitch_long_output) VALUES (?, ?, ?)",
|
||||
(text, voice, stitch_long_output)
|
||||
"INSERT INTO tts_queue (text, voice, speed, stitch_long_output) VALUES (?, ?, ?, ?)",
|
||||
(text, voice, speed, stitch_long_output)
|
||||
)
|
||||
request_id = c.lastrowid
|
||||
conn.commit()
|
||||
|
@ -71,13 +73,13 @@ class QueueDB:
|
|||
finally:
|
||||
conn.close()
|
||||
|
||||
def get_next_pending(self) -> Optional[Tuple[int, str, str]]:
|
||||
def get_next_pending(self) -> Optional[Tuple[int, str, float, str]]:
|
||||
"""Get the next pending request"""
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
try:
|
||||
c = conn.cursor()
|
||||
c.execute(
|
||||
'SELECT id, text, voice, stitch_long_output FROM tts_queue WHERE status = "pending" ORDER BY created_at ASC LIMIT 1'
|
||||
'SELECT id, text, voice, speed, stitch_long_output FROM tts_queue WHERE status = "pending" ORDER BY created_at ASC LIMIT 1'
|
||||
)
|
||||
return c.fetchone()
|
||||
except sqlite3.OperationalError: # Table doesn't exist
|
||||
|
|
|
@ -7,6 +7,7 @@ class TTSRequest(BaseModel):
|
|||
text: str
|
||||
voice: str = "af" # Default voice
|
||||
local: bool = False # Whether to save file locally or return bytes
|
||||
speed: float = 1.0
|
||||
stitch_long_output: bool = True # Whether to stitch together long outputs
|
||||
|
||||
|
||||
|
@ -20,4 +21,3 @@ class TTSResponse(BaseModel):
|
|||
class VoicesResponse(BaseModel):
|
||||
voices: list[str]
|
||||
default: str
|
||||
|
||||
|
|
|
@ -35,6 +35,7 @@ async def create_tts(request: TTSRequest):
|
|||
request_id = tts_service.create_tts_request(
|
||||
request.text,
|
||||
request.voice,
|
||||
request.speed,
|
||||
request.stitch_long_output
|
||||
)
|
||||
return {
|
||||
|
|
|
@ -103,7 +103,7 @@ class TTSService:
|
|||
|
||||
return chunks
|
||||
|
||||
def _generate_audio(self, text: str, voice: str, stitch_long_output: bool = True) -> Tuple[torch.Tensor, float]:
|
||||
def _generate_audio(self, text: str, voice: str, speed: float, stitch_long_output: bool = True) -> Tuple[torch.Tensor, float]:
|
||||
"""Generate audio and measure processing time"""
|
||||
start_time = time.time()
|
||||
|
||||
|
@ -118,7 +118,7 @@ class TTSService:
|
|||
audio_chunks = []
|
||||
|
||||
for chunk in chunks:
|
||||
chunk_audio, _ = generate(model, chunk, voicepack, lang=voice[0])
|
||||
chunk_audio, _ = generate(model, chunk, voicepack, lang=voice[0], speed=speed)
|
||||
audio_chunks.append(chunk_audio)
|
||||
|
||||
# Concatenate audio chunks
|
||||
|
@ -149,10 +149,10 @@ class TTSService:
|
|||
while True:
|
||||
next_request = self.db.get_next_pending()
|
||||
if next_request:
|
||||
request_id, text, voice, stitch_long_output = next_request
|
||||
request_id, text, voice, speed, stitch_long_output = next_request
|
||||
try:
|
||||
# Generate audio and measure time
|
||||
audio, processing_time = self._generate_audio(text, voice, stitch_long_output)
|
||||
audio, processing_time = self._generate_audio(text, voice, speed, stitch_long_output)
|
||||
|
||||
# Save to file
|
||||
output_file = os.path.abspath(os.path.join(
|
||||
|
@ -186,9 +186,9 @@ class TTSService:
|
|||
print(f"Error listing voices: {str(e)}")
|
||||
return voices
|
||||
|
||||
def create_tts_request(self, text: str, voice: str = "af", stitch_long_output: bool = True) -> int:
|
||||
def create_tts_request(self, text: str, voice: str = "af", speed: float = 1.0, stitch_long_output: bool = True) -> int:
|
||||
"""Create a new TTS request and return the request ID"""
|
||||
return self.db.add_request(text, voice, stitch_long_output)
|
||||
return self.db.add_request(text, voice, speed, stitch_long_output)
|
||||
|
||||
def get_request_status(
|
||||
self, request_id: int
|
||||
|
|
|
@ -22,11 +22,11 @@ def get_voices(
|
|||
|
||||
|
||||
def submit_tts_request(
|
||||
text: str, voice: Optional[str] = None, base_url: str = "http://localhost:8880"
|
||||
text: str, voice: Optional[str] = None, speed: Optional[float] = 1.0, base_url: str = "http://localhost:8880"
|
||||
) -> Optional[int]:
|
||||
"""Submit a TTS request and return the request ID"""
|
||||
try:
|
||||
payload = {"text": text, "voice": voice} if voice else {"text": text}
|
||||
payload = {"text": text, "speed": speed, "voice": voice} if voice else {"text": text, "speed": speed}
|
||||
response = requests.post(f"{base_url}/tts", json=payload)
|
||||
if response.status_code != 200:
|
||||
print(f"Error submitting request: {response.text}")
|
||||
|
@ -83,13 +83,14 @@ def download_audio(
|
|||
def generate_speech(
|
||||
text: str,
|
||||
voice: Optional[str] = None,
|
||||
speed: Optional[float] = 1.0,
|
||||
base_url: str = "http://localhost:8880",
|
||||
download: bool = True,
|
||||
) -> bool:
|
||||
"""Generate speech from text"""
|
||||
# Submit request
|
||||
print("Submitting request...")
|
||||
request_id = submit_tts_request(text, voice, base_url)
|
||||
request_id = submit_tts_request(text, voice, speed, base_url)
|
||||
if not request_id:
|
||||
return False
|
||||
|
||||
|
@ -149,6 +150,7 @@ def main():
|
|||
parser = argparse.ArgumentParser(description="Kokoro TTS CLI")
|
||||
parser.add_argument("text", nargs="?", help="Text to convert to speech")
|
||||
parser.add_argument("--voice", help="Voice to use")
|
||||
parser.add_argument("--speed", default=1.0, help="speed of speech")
|
||||
parser.add_argument("--url", default="http://localhost:8880", help="API base URL")
|
||||
parser.add_argument("--debug", action="store_true", help="Enable debug logging")
|
||||
parser.add_argument(
|
||||
|
@ -177,7 +179,7 @@ def main():
|
|||
)
|
||||
|
||||
success = generate_speech(
|
||||
args.text, args.voice, args.url, download=not args.no_download
|
||||
args.text, args.voice, args.speed, args.url, download=not args.no_download
|
||||
)
|
||||
if not success:
|
||||
sys.exit(1)
|
||||
|
|
Loading…
Add table
Reference in a new issue