mirror of
https://github.com/remsky/Kokoro-FastAPI.git
synced 2025-04-13 09:39:17 +00:00
commit
5afb9e9be8
5 changed files with 33 additions and 28 deletions
|
@ -25,6 +25,7 @@ class QueueDB:
|
||||||
status TEXT DEFAULT 'pending',
|
status TEXT DEFAULT 'pending',
|
||||||
output_file TEXT,
|
output_file TEXT,
|
||||||
processing_time REAL,
|
processing_time REAL,
|
||||||
|
speed REAL,
|
||||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP)
|
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP)
|
||||||
""")
|
""")
|
||||||
conn.commit()
|
conn.commit()
|
||||||
|
@ -42,18 +43,19 @@ class QueueDB:
|
||||||
status TEXT DEFAULT 'pending',
|
status TEXT DEFAULT 'pending',
|
||||||
output_file TEXT,
|
output_file TEXT,
|
||||||
processing_time REAL,
|
processing_time REAL,
|
||||||
|
speed REAL,
|
||||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP)
|
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP)
|
||||||
""")
|
""")
|
||||||
conn.commit()
|
conn.commit()
|
||||||
|
|
||||||
def add_request(self, text: str, voice: str, stitch_long_output: bool = True) -> int:
|
def add_request(self, text: str, voice: str, speed: float, stitch_long_output: bool = True) -> int:
|
||||||
"""Add a new TTS request to the queue"""
|
"""Add a new TTS request to the queue"""
|
||||||
conn = sqlite3.connect(self.db_path)
|
conn = sqlite3.connect(self.db_path)
|
||||||
try:
|
try:
|
||||||
c = conn.cursor()
|
c = conn.cursor()
|
||||||
c.execute(
|
c.execute(
|
||||||
"INSERT INTO tts_queue (text, voice, stitch_long_output) VALUES (?, ?, ?)",
|
"INSERT INTO tts_queue (text, voice, speed, stitch_long_output) VALUES (?, ?, ?, ?)",
|
||||||
(text, voice, stitch_long_output)
|
(text, voice, speed, stitch_long_output)
|
||||||
)
|
)
|
||||||
request_id = c.lastrowid
|
request_id = c.lastrowid
|
||||||
conn.commit()
|
conn.commit()
|
||||||
|
@ -62,8 +64,8 @@ class QueueDB:
|
||||||
self._ensure_table_if_needed(conn)
|
self._ensure_table_if_needed(conn)
|
||||||
c = conn.cursor()
|
c = conn.cursor()
|
||||||
c.execute(
|
c.execute(
|
||||||
"INSERT INTO tts_queue (text, voice, stitch_long_output) VALUES (?, ?, ?)",
|
"INSERT INTO tts_queue (text, voice, speed, stitch_long_output) VALUES (?, ?, ?, ?)",
|
||||||
(text, voice, stitch_long_output)
|
(text, voice, speed, stitch_long_output)
|
||||||
)
|
)
|
||||||
request_id = c.lastrowid
|
request_id = c.lastrowid
|
||||||
conn.commit()
|
conn.commit()
|
||||||
|
@ -71,13 +73,13 @@ class QueueDB:
|
||||||
finally:
|
finally:
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
def get_next_pending(self) -> Optional[Tuple[int, str, str]]:
|
def get_next_pending(self) -> Optional[Tuple[int, str, float, str]]:
|
||||||
"""Get the next pending request"""
|
"""Get the next pending request"""
|
||||||
conn = sqlite3.connect(self.db_path)
|
conn = sqlite3.connect(self.db_path)
|
||||||
try:
|
try:
|
||||||
c = conn.cursor()
|
c = conn.cursor()
|
||||||
c.execute(
|
c.execute(
|
||||||
'SELECT id, text, voice, stitch_long_output FROM tts_queue WHERE status = "pending" ORDER BY created_at ASC LIMIT 1'
|
'SELECT id, text, voice, speed, stitch_long_output FROM tts_queue WHERE status = "pending" ORDER BY created_at ASC LIMIT 1'
|
||||||
)
|
)
|
||||||
return c.fetchone()
|
return c.fetchone()
|
||||||
except sqlite3.OperationalError: # Table doesn't exist
|
except sqlite3.OperationalError: # Table doesn't exist
|
||||||
|
|
|
@ -7,6 +7,7 @@ class TTSRequest(BaseModel):
|
||||||
text: str
|
text: str
|
||||||
voice: str = "af" # Default voice
|
voice: str = "af" # Default voice
|
||||||
local: bool = False # Whether to save file locally or return bytes
|
local: bool = False # Whether to save file locally or return bytes
|
||||||
|
speed: float = 1.0
|
||||||
stitch_long_output: bool = True # Whether to stitch together long outputs
|
stitch_long_output: bool = True # Whether to stitch together long outputs
|
||||||
|
|
||||||
|
|
||||||
|
@ -20,4 +21,3 @@ class TTSResponse(BaseModel):
|
||||||
class VoicesResponse(BaseModel):
|
class VoicesResponse(BaseModel):
|
||||||
voices: list[str]
|
voices: list[str]
|
||||||
default: str
|
default: str
|
||||||
|
|
||||||
|
|
|
@ -35,6 +35,7 @@ async def create_tts(request: TTSRequest):
|
||||||
request_id = tts_service.create_tts_request(
|
request_id = tts_service.create_tts_request(
|
||||||
request.text,
|
request.text,
|
||||||
request.voice,
|
request.voice,
|
||||||
|
request.speed,
|
||||||
request.stitch_long_output
|
request.stitch_long_output
|
||||||
)
|
)
|
||||||
return {
|
return {
|
||||||
|
|
|
@ -103,7 +103,7 @@ class TTSService:
|
||||||
|
|
||||||
return chunks
|
return chunks
|
||||||
|
|
||||||
def _generate_audio(self, text: str, voice: str, stitch_long_output: bool = True) -> Tuple[torch.Tensor, float]:
|
def _generate_audio(self, text: str, voice: str, speed: float, stitch_long_output: bool = True) -> Tuple[torch.Tensor, float]:
|
||||||
"""Generate audio and measure processing time"""
|
"""Generate audio and measure processing time"""
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
|
@ -118,7 +118,7 @@ class TTSService:
|
||||||
audio_chunks = []
|
audio_chunks = []
|
||||||
|
|
||||||
for chunk in chunks:
|
for chunk in chunks:
|
||||||
chunk_audio, _ = generate(model, chunk, voicepack, lang=voice[0])
|
chunk_audio, _ = generate(model, chunk, voicepack, lang=voice[0], speed=speed)
|
||||||
audio_chunks.append(chunk_audio)
|
audio_chunks.append(chunk_audio)
|
||||||
|
|
||||||
# Concatenate audio chunks
|
# Concatenate audio chunks
|
||||||
|
@ -149,10 +149,10 @@ class TTSService:
|
||||||
while True:
|
while True:
|
||||||
next_request = self.db.get_next_pending()
|
next_request = self.db.get_next_pending()
|
||||||
if next_request:
|
if next_request:
|
||||||
request_id, text, voice, stitch_long_output = next_request
|
request_id, text, voice, speed, stitch_long_output = next_request
|
||||||
try:
|
try:
|
||||||
# Generate audio and measure time
|
# Generate audio and measure time
|
||||||
audio, processing_time = self._generate_audio(text, voice, stitch_long_output)
|
audio, processing_time = self._generate_audio(text, voice, speed, stitch_long_output)
|
||||||
|
|
||||||
# Save to file
|
# Save to file
|
||||||
output_file = os.path.abspath(os.path.join(
|
output_file = os.path.abspath(os.path.join(
|
||||||
|
@ -186,9 +186,9 @@ class TTSService:
|
||||||
print(f"Error listing voices: {str(e)}")
|
print(f"Error listing voices: {str(e)}")
|
||||||
return voices
|
return voices
|
||||||
|
|
||||||
def create_tts_request(self, text: str, voice: str = "af", stitch_long_output: bool = True) -> int:
|
def create_tts_request(self, text: str, voice: str = "af", speed: float = 1.0, stitch_long_output: bool = True) -> int:
|
||||||
"""Create a new TTS request and return the request ID"""
|
"""Create a new TTS request and return the request ID"""
|
||||||
return self.db.add_request(text, voice, stitch_long_output)
|
return self.db.add_request(text, voice, speed, stitch_long_output)
|
||||||
|
|
||||||
def get_request_status(
|
def get_request_status(
|
||||||
self, request_id: int
|
self, request_id: int
|
||||||
|
|
|
@ -22,11 +22,11 @@ def get_voices(
|
||||||
|
|
||||||
|
|
||||||
def submit_tts_request(
|
def submit_tts_request(
|
||||||
text: str, voice: Optional[str] = None, base_url: str = "http://localhost:8880"
|
text: str, voice: Optional[str] = None, speed: Optional[float] = 1.0, base_url: str = "http://localhost:8880"
|
||||||
) -> Optional[int]:
|
) -> Optional[int]:
|
||||||
"""Submit a TTS request and return the request ID"""
|
"""Submit a TTS request and return the request ID"""
|
||||||
try:
|
try:
|
||||||
payload = {"text": text, "voice": voice} if voice else {"text": text}
|
payload = {"text": text, "speed": speed, "voice": voice} if voice else {"text": text, "speed": speed}
|
||||||
response = requests.post(f"{base_url}/tts", json=payload)
|
response = requests.post(f"{base_url}/tts", json=payload)
|
||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
print(f"Error submitting request: {response.text}")
|
print(f"Error submitting request: {response.text}")
|
||||||
|
@ -83,13 +83,14 @@ def download_audio(
|
||||||
def generate_speech(
|
def generate_speech(
|
||||||
text: str,
|
text: str,
|
||||||
voice: Optional[str] = None,
|
voice: Optional[str] = None,
|
||||||
|
speed: Optional[float] = 1.0,
|
||||||
base_url: str = "http://localhost:8880",
|
base_url: str = "http://localhost:8880",
|
||||||
download: bool = True,
|
download: bool = True,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""Generate speech from text"""
|
"""Generate speech from text"""
|
||||||
# Submit request
|
# Submit request
|
||||||
print("Submitting request...")
|
print("Submitting request...")
|
||||||
request_id = submit_tts_request(text, voice, base_url)
|
request_id = submit_tts_request(text, voice, speed, base_url)
|
||||||
if not request_id:
|
if not request_id:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@ -149,6 +150,7 @@ def main():
|
||||||
parser = argparse.ArgumentParser(description="Kokoro TTS CLI")
|
parser = argparse.ArgumentParser(description="Kokoro TTS CLI")
|
||||||
parser.add_argument("text", nargs="?", help="Text to convert to speech")
|
parser.add_argument("text", nargs="?", help="Text to convert to speech")
|
||||||
parser.add_argument("--voice", help="Voice to use")
|
parser.add_argument("--voice", help="Voice to use")
|
||||||
|
parser.add_argument("--speed", default=1.0, help="speed of speech")
|
||||||
parser.add_argument("--url", default="http://localhost:8880", help="API base URL")
|
parser.add_argument("--url", default="http://localhost:8880", help="API base URL")
|
||||||
parser.add_argument("--debug", action="store_true", help="Enable debug logging")
|
parser.add_argument("--debug", action="store_true", help="Enable debug logging")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
@ -177,7 +179,7 @@ def main():
|
||||||
)
|
)
|
||||||
|
|
||||||
success = generate_speech(
|
success = generate_speech(
|
||||||
args.text, args.voice, args.url, download=not args.no_download
|
args.text, args.voice, args.speed, args.url, download=not args.no_download
|
||||||
)
|
)
|
||||||
if not success:
|
if not success:
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
Loading…
Add table
Reference in a new issue