Merge pull request #1 from eschmidbauer/master

add speed
This commit is contained in:
remsky 2024-12-30 12:39:50 -07:00 committed by GitHub
commit 5afb9e9be8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 33 additions and 28 deletions

View file

@ -25,6 +25,7 @@ class QueueDB:
status TEXT DEFAULT 'pending', status TEXT DEFAULT 'pending',
output_file TEXT, output_file TEXT,
processing_time REAL, processing_time REAL,
speed REAL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP) created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP)
""") """)
conn.commit() conn.commit()
@ -42,18 +43,19 @@ class QueueDB:
status TEXT DEFAULT 'pending', status TEXT DEFAULT 'pending',
output_file TEXT, output_file TEXT,
processing_time REAL, processing_time REAL,
speed REAL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP) created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP)
""") """)
conn.commit() conn.commit()
def add_request(self, text: str, voice: str, stitch_long_output: bool = True) -> int: def add_request(self, text: str, voice: str, speed: float, stitch_long_output: bool = True) -> int:
"""Add a new TTS request to the queue""" """Add a new TTS request to the queue"""
conn = sqlite3.connect(self.db_path) conn = sqlite3.connect(self.db_path)
try: try:
c = conn.cursor() c = conn.cursor()
c.execute( c.execute(
"INSERT INTO tts_queue (text, voice, stitch_long_output) VALUES (?, ?, ?)", "INSERT INTO tts_queue (text, voice, speed, stitch_long_output) VALUES (?, ?, ?, ?)",
(text, voice, stitch_long_output) (text, voice, speed, stitch_long_output)
) )
request_id = c.lastrowid request_id = c.lastrowid
conn.commit() conn.commit()
@ -62,8 +64,8 @@ class QueueDB:
self._ensure_table_if_needed(conn) self._ensure_table_if_needed(conn)
c = conn.cursor() c = conn.cursor()
c.execute( c.execute(
"INSERT INTO tts_queue (text, voice, stitch_long_output) VALUES (?, ?, ?)", "INSERT INTO tts_queue (text, voice, speed, stitch_long_output) VALUES (?, ?, ?, ?)",
(text, voice, stitch_long_output) (text, voice, speed, stitch_long_output)
) )
request_id = c.lastrowid request_id = c.lastrowid
conn.commit() conn.commit()
@ -71,13 +73,13 @@ class QueueDB:
finally: finally:
conn.close() conn.close()
def get_next_pending(self) -> Optional[Tuple[int, str, str]]: def get_next_pending(self) -> Optional[Tuple[int, str, float, str]]:
"""Get the next pending request""" """Get the next pending request"""
conn = sqlite3.connect(self.db_path) conn = sqlite3.connect(self.db_path)
try: try:
c = conn.cursor() c = conn.cursor()
c.execute( c.execute(
'SELECT id, text, voice, stitch_long_output FROM tts_queue WHERE status = "pending" ORDER BY created_at ASC LIMIT 1' 'SELECT id, text, voice, speed, stitch_long_output FROM tts_queue WHERE status = "pending" ORDER BY created_at ASC LIMIT 1'
) )
return c.fetchone() return c.fetchone()
except sqlite3.OperationalError: # Table doesn't exist except sqlite3.OperationalError: # Table doesn't exist

View file

@ -7,6 +7,7 @@ class TTSRequest(BaseModel):
text: str text: str
voice: str = "af" # Default voice voice: str = "af" # Default voice
local: bool = False # Whether to save file locally or return bytes local: bool = False # Whether to save file locally or return bytes
speed: float = 1.0
stitch_long_output: bool = True # Whether to stitch together long outputs stitch_long_output: bool = True # Whether to stitch together long outputs
@ -20,4 +21,3 @@ class TTSResponse(BaseModel):
class VoicesResponse(BaseModel): class VoicesResponse(BaseModel):
voices: list[str] voices: list[str]
default: str default: str

View file

@ -35,6 +35,7 @@ async def create_tts(request: TTSRequest):
request_id = tts_service.create_tts_request( request_id = tts_service.create_tts_request(
request.text, request.text,
request.voice, request.voice,
request.speed,
request.stitch_long_output request.stitch_long_output
) )
return { return {

View file

@ -103,7 +103,7 @@ class TTSService:
return chunks return chunks
def _generate_audio(self, text: str, voice: str, stitch_long_output: bool = True) -> Tuple[torch.Tensor, float]: def _generate_audio(self, text: str, voice: str, speed: float, stitch_long_output: bool = True) -> Tuple[torch.Tensor, float]:
"""Generate audio and measure processing time""" """Generate audio and measure processing time"""
start_time = time.time() start_time = time.time()
@ -118,7 +118,7 @@ class TTSService:
audio_chunks = [] audio_chunks = []
for chunk in chunks: for chunk in chunks:
chunk_audio, _ = generate(model, chunk, voicepack, lang=voice[0]) chunk_audio, _ = generate(model, chunk, voicepack, lang=voice[0], speed=speed)
audio_chunks.append(chunk_audio) audio_chunks.append(chunk_audio)
# Concatenate audio chunks # Concatenate audio chunks
@ -149,10 +149,10 @@ class TTSService:
while True: while True:
next_request = self.db.get_next_pending() next_request = self.db.get_next_pending()
if next_request: if next_request:
request_id, text, voice, stitch_long_output = next_request request_id, text, voice, speed, stitch_long_output = next_request
try: try:
# Generate audio and measure time # Generate audio and measure time
audio, processing_time = self._generate_audio(text, voice, stitch_long_output) audio, processing_time = self._generate_audio(text, voice, speed, stitch_long_output)
# Save to file # Save to file
output_file = os.path.abspath(os.path.join( output_file = os.path.abspath(os.path.join(
@ -186,9 +186,9 @@ class TTSService:
print(f"Error listing voices: {str(e)}") print(f"Error listing voices: {str(e)}")
return voices return voices
def create_tts_request(self, text: str, voice: str = "af", stitch_long_output: bool = True) -> int: def create_tts_request(self, text: str, voice: str = "af", speed: float = 1.0, stitch_long_output: bool = True) -> int:
"""Create a new TTS request and return the request ID""" """Create a new TTS request and return the request ID"""
return self.db.add_request(text, voice, stitch_long_output) return self.db.add_request(text, voice, speed, stitch_long_output)
def get_request_status( def get_request_status(
self, request_id: int self, request_id: int

View file

@ -22,11 +22,11 @@ def get_voices(
def submit_tts_request( def submit_tts_request(
text: str, voice: Optional[str] = None, base_url: str = "http://localhost:8880" text: str, voice: Optional[str] = None, speed: Optional[float] = 1.0, base_url: str = "http://localhost:8880"
) -> Optional[int]: ) -> Optional[int]:
"""Submit a TTS request and return the request ID""" """Submit a TTS request and return the request ID"""
try: try:
payload = {"text": text, "voice": voice} if voice else {"text": text} payload = {"text": text, "speed": speed, "voice": voice} if voice else {"text": text, "speed": speed}
response = requests.post(f"{base_url}/tts", json=payload) response = requests.post(f"{base_url}/tts", json=payload)
if response.status_code != 200: if response.status_code != 200:
print(f"Error submitting request: {response.text}") print(f"Error submitting request: {response.text}")
@ -83,13 +83,14 @@ def download_audio(
def generate_speech( def generate_speech(
text: str, text: str,
voice: Optional[str] = None, voice: Optional[str] = None,
speed: Optional[float] = 1.0,
base_url: str = "http://localhost:8880", base_url: str = "http://localhost:8880",
download: bool = True, download: bool = True,
) -> bool: ) -> bool:
"""Generate speech from text""" """Generate speech from text"""
# Submit request # Submit request
print("Submitting request...") print("Submitting request...")
request_id = submit_tts_request(text, voice, base_url) request_id = submit_tts_request(text, voice, speed, base_url)
if not request_id: if not request_id:
return False return False
@ -149,6 +150,7 @@ def main():
parser = argparse.ArgumentParser(description="Kokoro TTS CLI") parser = argparse.ArgumentParser(description="Kokoro TTS CLI")
parser.add_argument("text", nargs="?", help="Text to convert to speech") parser.add_argument("text", nargs="?", help="Text to convert to speech")
parser.add_argument("--voice", help="Voice to use") parser.add_argument("--voice", help="Voice to use")
parser.add_argument("--speed", default=1.0, help="speed of speech")
parser.add_argument("--url", default="http://localhost:8880", help="API base URL") parser.add_argument("--url", default="http://localhost:8880", help="API base URL")
parser.add_argument("--debug", action="store_true", help="Enable debug logging") parser.add_argument("--debug", action="store_true", help="Enable debug logging")
parser.add_argument( parser.add_argument(
@ -177,7 +179,7 @@ def main():
) )
success = generate_speech( success = generate_speech(
args.text, args.voice, args.url, download=not args.no_download args.text, args.voice, args.speed, args.url, download=not args.no_download
) )
if not success: if not success:
sys.exit(1) sys.exit(1)