Merge pull request #1 from eschmidbauer/master

add speed
This commit is contained in:
remsky 2024-12-30 12:39:50 -07:00 committed by GitHub
commit 5afb9e9be8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 33 additions and 28 deletions

View file

@ -25,6 +25,7 @@ class QueueDB:
status TEXT DEFAULT 'pending',
output_file TEXT,
processing_time REAL,
speed REAL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP)
""")
conn.commit()
@ -42,18 +43,19 @@ class QueueDB:
status TEXT DEFAULT 'pending',
output_file TEXT,
processing_time REAL,
speed REAL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP)
""")
conn.commit()
def add_request(self, text: str, voice: str, stitch_long_output: bool = True) -> int:
def add_request(self, text: str, voice: str, speed: float, stitch_long_output: bool = True) -> int:
"""Add a new TTS request to the queue"""
conn = sqlite3.connect(self.db_path)
try:
c = conn.cursor()
c.execute(
"INSERT INTO tts_queue (text, voice, stitch_long_output) VALUES (?, ?, ?)",
(text, voice, stitch_long_output)
"INSERT INTO tts_queue (text, voice, speed, stitch_long_output) VALUES (?, ?, ?, ?)",
(text, voice, speed, stitch_long_output)
)
request_id = c.lastrowid
conn.commit()
@ -62,8 +64,8 @@ class QueueDB:
self._ensure_table_if_needed(conn)
c = conn.cursor()
c.execute(
"INSERT INTO tts_queue (text, voice, stitch_long_output) VALUES (?, ?, ?)",
(text, voice, stitch_long_output)
"INSERT INTO tts_queue (text, voice, speed, stitch_long_output) VALUES (?, ?, ?, ?)",
(text, voice, speed, stitch_long_output)
)
request_id = c.lastrowid
conn.commit()
@ -71,13 +73,13 @@ class QueueDB:
finally:
conn.close()
def get_next_pending(self) -> Optional[Tuple[int, str, str]]:
def get_next_pending(self) -> Optional[Tuple[int, str, float, str]]:
"""Get the next pending request"""
conn = sqlite3.connect(self.db_path)
try:
c = conn.cursor()
c.execute(
'SELECT id, text, voice, stitch_long_output FROM tts_queue WHERE status = "pending" ORDER BY created_at ASC LIMIT 1'
'SELECT id, text, voice, speed, stitch_long_output FROM tts_queue WHERE status = "pending" ORDER BY created_at ASC LIMIT 1'
)
return c.fetchone()
except sqlite3.OperationalError: # Table doesn't exist

View file

@ -7,6 +7,7 @@ class TTSRequest(BaseModel):
text: str
voice: str = "af" # Default voice
local: bool = False # Whether to save file locally or return bytes
speed: float = 1.0
stitch_long_output: bool = True # Whether to stitch together long outputs
@ -20,4 +21,3 @@ class TTSResponse(BaseModel):
class VoicesResponse(BaseModel):
voices: list[str]
default: str

View file

@ -33,8 +33,9 @@ async def create_tts(request: TTSRequest):
# Queue the request
request_id = tts_service.create_tts_request(
request.text,
request.text,
request.voice,
request.speed,
request.stitch_long_output
)
return {

View file

@ -65,16 +65,16 @@ class TTSService:
# Look for the last occurrence of marker before max_tokens
test_text = text[:max_tokens + margin] # Look a bit beyond the limit
last_idx = test_text.rfind(marker)
if last_idx != -1:
# Verify this boundary is within our token limit
candidate = text[:last_idx + len(marker)].strip()
ps = phonemize(candidate, voice[0])
tokens = tokenize(ps)
if len(tokens) <= max_tokens:
return last_idx + len(marker)
# If no good boundary found, find last whitespace within limit
test_text = text[:max_tokens]
last_space = test_text.rfind(' ')
@ -85,7 +85,7 @@ class TTSService:
MAX_TOKENS = 450 # Leave wider margin from 510 limit to account for tokenizer differences
chunks = []
remaining = text
while remaining:
# If remaining text is within limit, add it as final chunk
ps = phonemize(remaining, voice[0])
@ -93,17 +93,17 @@ class TTSService:
if len(tokens) <= MAX_TOKENS:
chunks.append(remaining.strip())
break
# Find best boundary position
split_pos = self._find_boundary(remaining, MAX_TOKENS, voice)
# Add chunk and continue with remaining text
chunks.append(remaining[:split_pos].strip())
remaining = remaining[split_pos:].strip()
return chunks
def _generate_audio(self, text: str, voice: str, stitch_long_output: bool = True) -> Tuple[torch.Tensor, float]:
def _generate_audio(self, text: str, voice: str, speed: float, stitch_long_output: bool = True) -> Tuple[torch.Tensor, float]:
"""Generate audio and measure processing time"""
start_time = time.time()
@ -116,11 +116,11 @@ class TTSService:
# Split text if needed and generate audio for each chunk
chunks = self._split_text(text, voice)
audio_chunks = []
for chunk in chunks:
chunk_audio, _ = generate(model, chunk, voicepack, lang=voice[0])
chunk_audio, _ = generate(model, chunk, voicepack, lang=voice[0], speed=speed)
audio_chunks.append(chunk_audio)
# Concatenate audio chunks
if len(audio_chunks) > 1:
audio = np.concatenate(audio_chunks)
@ -149,10 +149,10 @@ class TTSService:
while True:
next_request = self.db.get_next_pending()
if next_request:
request_id, text, voice, stitch_long_output = next_request
request_id, text, voice, speed, stitch_long_output = next_request
try:
# Generate audio and measure time
audio, processing_time = self._generate_audio(text, voice, stitch_long_output)
audio, processing_time = self._generate_audio(text, voice, speed, stitch_long_output)
# Save to file
output_file = os.path.abspath(os.path.join(
@ -186,9 +186,9 @@ class TTSService:
print(f"Error listing voices: {str(e)}")
return voices
def create_tts_request(self, text: str, voice: str = "af", stitch_long_output: bool = True) -> int:
def create_tts_request(self, text: str, voice: str = "af", speed: float = 1.0, stitch_long_output: bool = True) -> int:
"""Create a new TTS request and return the request ID"""
return self.db.add_request(text, voice, stitch_long_output)
return self.db.add_request(text, voice, speed, stitch_long_output)
def get_request_status(
self, request_id: int

View file

@ -22,11 +22,11 @@ def get_voices(
def submit_tts_request(
text: str, voice: Optional[str] = None, base_url: str = "http://localhost:8880"
text: str, voice: Optional[str] = None, speed: Optional[float] = 1.0, base_url: str = "http://localhost:8880"
) -> Optional[int]:
"""Submit a TTS request and return the request ID"""
try:
payload = {"text": text, "voice": voice} if voice else {"text": text}
payload = {"text": text, "speed": speed, "voice": voice} if voice else {"text": text, "speed": speed}
response = requests.post(f"{base_url}/tts", json=payload)
if response.status_code != 200:
print(f"Error submitting request: {response.text}")
@ -83,13 +83,14 @@ def download_audio(
def generate_speech(
text: str,
voice: Optional[str] = None,
speed: Optional[float] = 1.0,
base_url: str = "http://localhost:8880",
download: bool = True,
) -> bool:
"""Generate speech from text"""
# Submit request
print("Submitting request...")
request_id = submit_tts_request(text, voice, base_url)
request_id = submit_tts_request(text, voice, speed, base_url)
if not request_id:
return False
@ -149,6 +150,7 @@ def main():
parser = argparse.ArgumentParser(description="Kokoro TTS CLI")
parser.add_argument("text", nargs="?", help="Text to convert to speech")
parser.add_argument("--voice", help="Voice to use")
parser.add_argument("--speed", default=1.0, help="speed of speech")
parser.add_argument("--url", default="http://localhost:8880", help="API base URL")
parser.add_argument("--debug", action="store_true", help="Enable debug logging")
parser.add_argument(
@ -177,7 +179,7 @@ def main():
)
success = generate_speech(
args.text, args.voice, args.url, download=not args.no_download
args.text, args.voice, args.speed, args.url, download=not args.no_download
)
if not success:
sys.exit(1)