Merge pull request #1 from eschmidbauer/master

add speed
This commit is contained in:
remsky 2024-12-30 12:39:50 -07:00 committed by GitHub
commit 5afb9e9be8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 33 additions and 28 deletions

View file

@ -25,6 +25,7 @@ class QueueDB:
status TEXT DEFAULT 'pending',
output_file TEXT,
processing_time REAL,
speed REAL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP)
""")
conn.commit()
@ -42,18 +43,19 @@ class QueueDB:
status TEXT DEFAULT 'pending',
output_file TEXT,
processing_time REAL,
speed REAL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP)
""")
conn.commit()
def add_request(self, text: str, voice: str, stitch_long_output: bool = True) -> int:
def add_request(self, text: str, voice: str, speed: float, stitch_long_output: bool = True) -> int:
"""Add a new TTS request to the queue"""
conn = sqlite3.connect(self.db_path)
try:
c = conn.cursor()
c.execute(
"INSERT INTO tts_queue (text, voice, stitch_long_output) VALUES (?, ?, ?)",
(text, voice, stitch_long_output)
"INSERT INTO tts_queue (text, voice, speed, stitch_long_output) VALUES (?, ?, ?, ?)",
(text, voice, speed, stitch_long_output)
)
request_id = c.lastrowid
conn.commit()
@ -62,8 +64,8 @@ class QueueDB:
self._ensure_table_if_needed(conn)
c = conn.cursor()
c.execute(
"INSERT INTO tts_queue (text, voice, stitch_long_output) VALUES (?, ?, ?)",
(text, voice, stitch_long_output)
"INSERT INTO tts_queue (text, voice, speed, stitch_long_output) VALUES (?, ?, ?, ?)",
(text, voice, speed, stitch_long_output)
)
request_id = c.lastrowid
conn.commit()
@ -71,13 +73,13 @@ class QueueDB:
finally:
conn.close()
def get_next_pending(self) -> Optional[Tuple[int, str, str]]:
def get_next_pending(self) -> Optional[Tuple[int, str, float, str]]:
"""Get the next pending request"""
conn = sqlite3.connect(self.db_path)
try:
c = conn.cursor()
c.execute(
'SELECT id, text, voice, stitch_long_output FROM tts_queue WHERE status = "pending" ORDER BY created_at ASC LIMIT 1'
'SELECT id, text, voice, speed, stitch_long_output FROM tts_queue WHERE status = "pending" ORDER BY created_at ASC LIMIT 1'
)
return c.fetchone()
except sqlite3.OperationalError: # Table doesn't exist

View file

@ -7,6 +7,7 @@ class TTSRequest(BaseModel):
text: str
voice: str = "af" # Default voice
local: bool = False # Whether to save file locally or return bytes
speed: float = 1.0
stitch_long_output: bool = True # Whether to stitch together long outputs
@ -20,4 +21,3 @@ class TTSResponse(BaseModel):
class VoicesResponse(BaseModel):
voices: list[str]
default: str

View file

@ -35,6 +35,7 @@ async def create_tts(request: TTSRequest):
request_id = tts_service.create_tts_request(
request.text,
request.voice,
request.speed,
request.stitch_long_output
)
return {

View file

@ -103,7 +103,7 @@ class TTSService:
return chunks
def _generate_audio(self, text: str, voice: str, stitch_long_output: bool = True) -> Tuple[torch.Tensor, float]:
def _generate_audio(self, text: str, voice: str, speed: float, stitch_long_output: bool = True) -> Tuple[torch.Tensor, float]:
"""Generate audio and measure processing time"""
start_time = time.time()
@ -118,7 +118,7 @@ class TTSService:
audio_chunks = []
for chunk in chunks:
chunk_audio, _ = generate(model, chunk, voicepack, lang=voice[0])
chunk_audio, _ = generate(model, chunk, voicepack, lang=voice[0], speed=speed)
audio_chunks.append(chunk_audio)
# Concatenate audio chunks
@ -149,10 +149,10 @@ class TTSService:
while True:
next_request = self.db.get_next_pending()
if next_request:
request_id, text, voice, stitch_long_output = next_request
request_id, text, voice, speed, stitch_long_output = next_request
try:
# Generate audio and measure time
audio, processing_time = self._generate_audio(text, voice, stitch_long_output)
audio, processing_time = self._generate_audio(text, voice, speed, stitch_long_output)
# Save to file
output_file = os.path.abspath(os.path.join(
@ -186,9 +186,9 @@ class TTSService:
print(f"Error listing voices: {str(e)}")
return voices
def create_tts_request(self, text: str, voice: str = "af", stitch_long_output: bool = True) -> int:
def create_tts_request(self, text: str, voice: str = "af", speed: float = 1.0, stitch_long_output: bool = True) -> int:
"""Create a new TTS request and return the request ID"""
return self.db.add_request(text, voice, stitch_long_output)
return self.db.add_request(text, voice, speed, stitch_long_output)
def get_request_status(
self, request_id: int

View file

@ -22,11 +22,11 @@ def get_voices(
def submit_tts_request(
text: str, voice: Optional[str] = None, base_url: str = "http://localhost:8880"
text: str, voice: Optional[str] = None, speed: Optional[float] = 1.0, base_url: str = "http://localhost:8880"
) -> Optional[int]:
"""Submit a TTS request and return the request ID"""
try:
payload = {"text": text, "voice": voice} if voice else {"text": text}
payload = {"text": text, "speed": speed, "voice": voice} if voice else {"text": text, "speed": speed}
response = requests.post(f"{base_url}/tts", json=payload)
if response.status_code != 200:
print(f"Error submitting request: {response.text}")
@ -83,13 +83,14 @@ def download_audio(
def generate_speech(
text: str,
voice: Optional[str] = None,
speed: Optional[float] = 1.0,
base_url: str = "http://localhost:8880",
download: bool = True,
) -> bool:
"""Generate speech from text"""
# Submit request
print("Submitting request...")
request_id = submit_tts_request(text, voice, base_url)
request_id = submit_tts_request(text, voice, speed, base_url)
if not request_id:
return False
@ -149,6 +150,7 @@ def main():
parser = argparse.ArgumentParser(description="Kokoro TTS CLI")
parser.add_argument("text", nargs="?", help="Text to convert to speech")
parser.add_argument("--voice", help="Voice to use")
parser.add_argument("--speed", default=1.0, help="speed of speech")
parser.add_argument("--url", default="http://localhost:8880", help="API base URL")
parser.add_argument("--debug", action="store_true", help="Enable debug logging")
parser.add_argument(
@ -177,7 +179,7 @@ def main():
)
success = generate_speech(
args.text, args.voice, args.url, download=not args.no_download
args.text, args.voice, args.speed, args.url, download=not args.no_download
)
if not success:
sys.exit(1)