From f95e526a3f7a92185432ecd56d2e416c427db638 Mon Sep 17 00:00:00 2001 From: Emmanuel Schmidbauer Date: Mon, 30 Dec 2024 13:39:35 -0500 Subject: [PATCH] add speed --- api/src/database/queue.py | 16 +++++++++------- api/src/models/schemas.py | 2 +- api/src/routers/tts.py | 3 ++- api/src/services/tts.py | 30 +++++++++++++++--------------- examples/test_tts.py | 10 ++++++---- 5 files changed, 33 insertions(+), 28 deletions(-) diff --git a/api/src/database/queue.py b/api/src/database/queue.py index 70a43d6..1e18e60 100644 --- a/api/src/database/queue.py +++ b/api/src/database/queue.py @@ -25,6 +25,7 @@ class QueueDB: status TEXT DEFAULT 'pending', output_file TEXT, processing_time REAL, + speed REAL, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP) """) conn.commit() @@ -42,18 +43,19 @@ class QueueDB: status TEXT DEFAULT 'pending', output_file TEXT, processing_time REAL, + speed REAL, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP) """) conn.commit() - def add_request(self, text: str, voice: str, stitch_long_output: bool = True) -> int: + def add_request(self, text: str, voice: str, speed: float, stitch_long_output: bool = True) -> int: """Add a new TTS request to the queue""" conn = sqlite3.connect(self.db_path) try: c = conn.cursor() c.execute( - "INSERT INTO tts_queue (text, voice, stitch_long_output) VALUES (?, ?, ?)", - (text, voice, stitch_long_output) + "INSERT INTO tts_queue (text, voice, speed, stitch_long_output) VALUES (?, ?, ?, ?)", + (text, voice, speed, stitch_long_output) ) request_id = c.lastrowid conn.commit() @@ -62,8 +64,8 @@ class QueueDB: self._ensure_table_if_needed(conn) c = conn.cursor() c.execute( - "INSERT INTO tts_queue (text, voice, stitch_long_output) VALUES (?, ?, ?)", - (text, voice, stitch_long_output) + "INSERT INTO tts_queue (text, voice, speed, stitch_long_output) VALUES (?, ?, ?, ?)", + (text, voice, speed, stitch_long_output) ) request_id = c.lastrowid conn.commit() @@ -71,13 +73,13 @@ class QueueDB: finally: conn.close() - def get_next_pending(self) -> Optional[Tuple[int, str, str]]: + def get_next_pending(self) -> Optional[Tuple[int, str, float, str]]: """Get the next pending request""" conn = sqlite3.connect(self.db_path) try: c = conn.cursor() c.execute( - 'SELECT id, text, voice, stitch_long_output FROM tts_queue WHERE status = "pending" ORDER BY created_at ASC LIMIT 1' + 'SELECT id, text, voice, speed, stitch_long_output FROM tts_queue WHERE status = "pending" ORDER BY created_at ASC LIMIT 1' ) return c.fetchone() except sqlite3.OperationalError: # Table doesn't exist diff --git a/api/src/models/schemas.py b/api/src/models/schemas.py index daac584..4c02f45 100644 --- a/api/src/models/schemas.py +++ b/api/src/models/schemas.py @@ -7,6 +7,7 @@ class TTSRequest(BaseModel): text: str voice: str = "af" # Default voice local: bool = False # Whether to save file locally or return bytes + speed: float = 1.0 stitch_long_output: bool = True # Whether to stitch together long outputs @@ -20,4 +21,3 @@ class TTSResponse(BaseModel): class VoicesResponse(BaseModel): voices: list[str] default: str - diff --git a/api/src/routers/tts.py b/api/src/routers/tts.py index ae09e77..7bfa023 100644 --- a/api/src/routers/tts.py +++ b/api/src/routers/tts.py @@ -33,8 +33,9 @@ async def create_tts(request: TTSRequest): # Queue the request request_id = tts_service.create_tts_request( - request.text, + request.text, request.voice, + request.speed, request.stitch_long_output ) return { diff --git a/api/src/services/tts.py b/api/src/services/tts.py index e3c6a8d..dcf1c6b 100644 --- a/api/src/services/tts.py +++ b/api/src/services/tts.py @@ -65,16 +65,16 @@ class TTSService: # Look for the last occurrence of marker before max_tokens test_text = text[:max_tokens + margin] # Look a bit beyond the limit last_idx = test_text.rfind(marker) - + if last_idx != -1: # Verify this boundary is within our token limit candidate = text[:last_idx + len(marker)].strip() ps = phonemize(candidate, voice[0]) tokens = tokenize(ps) - + if len(tokens) <= max_tokens: return last_idx + len(marker) - + # If no good boundary found, find last whitespace within limit test_text = text[:max_tokens] last_space = test_text.rfind(' ') @@ -85,7 +85,7 @@ class TTSService: MAX_TOKENS = 450 # Leave wider margin from 510 limit to account for tokenizer differences chunks = [] remaining = text - + while remaining: # If remaining text is within limit, add it as final chunk ps = phonemize(remaining, voice[0]) @@ -93,17 +93,17 @@ class TTSService: if len(tokens) <= MAX_TOKENS: chunks.append(remaining.strip()) break - + # Find best boundary position split_pos = self._find_boundary(remaining, MAX_TOKENS, voice) - + # Add chunk and continue with remaining text chunks.append(remaining[:split_pos].strip()) remaining = remaining[split_pos:].strip() - + return chunks - def _generate_audio(self, text: str, voice: str, stitch_long_output: bool = True) -> Tuple[torch.Tensor, float]: + def _generate_audio(self, text: str, voice: str, speed: float, stitch_long_output: bool = True) -> Tuple[torch.Tensor, float]: """Generate audio and measure processing time""" start_time = time.time() @@ -116,11 +116,11 @@ class TTSService: # Split text if needed and generate audio for each chunk chunks = self._split_text(text, voice) audio_chunks = [] - + for chunk in chunks: - chunk_audio, _ = generate(model, chunk, voicepack, lang=voice[0]) + chunk_audio, _ = generate(model, chunk, voicepack, lang=voice[0], speed=speed) audio_chunks.append(chunk_audio) - + # Concatenate audio chunks if len(audio_chunks) > 1: audio = np.concatenate(audio_chunks) @@ -149,10 +149,10 @@ class TTSService: while True: next_request = self.db.get_next_pending() if next_request: - request_id, text, voice, stitch_long_output = next_request + request_id, text, voice, speed, stitch_long_output = next_request try: # Generate audio and measure time - audio, processing_time = self._generate_audio(text, voice, stitch_long_output) + audio, processing_time = self._generate_audio(text, voice, speed, stitch_long_output) # Save to file output_file = os.path.abspath(os.path.join( @@ -186,9 +186,9 @@ class TTSService: print(f"Error listing voices: {str(e)}") return voices - def create_tts_request(self, text: str, voice: str = "af", stitch_long_output: bool = True) -> int: + def create_tts_request(self, text: str, voice: str = "af", speed: float = 1.0, stitch_long_output: bool = True) -> int: """Create a new TTS request and return the request ID""" - return self.db.add_request(text, voice, stitch_long_output) + return self.db.add_request(text, voice, speed, stitch_long_output) def get_request_status( self, request_id: int diff --git a/examples/test_tts.py b/examples/test_tts.py index e4fa7c6..c89b19d 100644 --- a/examples/test_tts.py +++ b/examples/test_tts.py @@ -22,11 +22,11 @@ def get_voices( def submit_tts_request( - text: str, voice: Optional[str] = None, base_url: str = "http://localhost:8880" + text: str, voice: Optional[str] = None, speed: Optional[float] = 1.0, base_url: str = "http://localhost:8880" ) -> Optional[int]: """Submit a TTS request and return the request ID""" try: - payload = {"text": text, "voice": voice} if voice else {"text": text} + payload = {"text": text, "speed": speed, "voice": voice} if voice else {"text": text, "speed": speed} response = requests.post(f"{base_url}/tts", json=payload) if response.status_code != 200: print(f"Error submitting request: {response.text}") @@ -83,13 +83,14 @@ def download_audio( def generate_speech( text: str, voice: Optional[str] = None, + speed: Optional[float] = 1.0, base_url: str = "http://localhost:8880", download: bool = True, ) -> bool: """Generate speech from text""" # Submit request print("Submitting request...") - request_id = submit_tts_request(text, voice, base_url) + request_id = submit_tts_request(text, voice, speed, base_url) if not request_id: return False @@ -149,6 +150,7 @@ def main(): parser = argparse.ArgumentParser(description="Kokoro TTS CLI") parser.add_argument("text", nargs="?", help="Text to convert to speech") parser.add_argument("--voice", help="Voice to use") + parser.add_argument("--speed", default=1.0, help="speed of speech") parser.add_argument("--url", default="http://localhost:8880", help="API base URL") parser.add_argument("--debug", action="store_true", help="Enable debug logging") parser.add_argument( @@ -177,7 +179,7 @@ def main(): ) success = generate_speech( - args.text, args.voice, args.url, download=not args.no_download + args.text, args.voice, args.speed, args.url, download=not args.no_download ) if not success: sys.exit(1)