mirror of
https://github.com/remsky/Kokoro-FastAPI.git
synced 2025-08-05 16:48:53 +00:00
Merge remote-tracking branch 'upstream/master' into streaming-word-timestamps
This commit is contained in:
commit
c1207f085b
10 changed files with 769 additions and 726 deletions
5
.gitattributes
vendored
Normal file
5
.gitattributes
vendored
Normal file
|
@ -0,0 +1,5 @@
|
||||||
|
* text=auto
|
||||||
|
|
||||||
|
*.py text eol=lf
|
||||||
|
*.sh text eol=lf
|
||||||
|
*.yml text eol=lf
|
|
@ -1,413 +1,413 @@
|
||||||
"""Async file and path operations."""
|
"""Async file and path operations."""
|
||||||
|
|
||||||
import io
|
import io
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, AsyncIterator, Callable, Dict, List, Optional, Set
|
from typing import Any, AsyncIterator, Callable, Dict, List, Optional, Set
|
||||||
|
|
||||||
import aiofiles
|
import aiofiles
|
||||||
import aiofiles.os
|
import aiofiles.os
|
||||||
import torch
|
import torch
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from .config import settings
|
from .config import settings
|
||||||
|
|
||||||
|
|
||||||
async def _find_file(
    filename: str,
    search_paths: List[str],
    filter_fn: Optional[Callable[[str], bool]] = None,
) -> str:
    """Locate a file, either as an existing absolute path or under a search path.

    Args:
        filename: File name (or absolute path) to look for.
        search_paths: Directories to probe, in order.
        filter_fn: Optional predicate called with each existing candidate
            path; a candidate is only returned when the predicate accepts it.

    Returns:
        ``filename`` itself when it is an absolute path that exists, otherwise
        the first existing ``os.path.join(path, filename)`` that passes
        ``filter_fn``.

    Raises:
        FileNotFoundError: If no candidate matches.
    """
    # An existing absolute path short-circuits the search; filter_fn is not
    # consulted on this branch.
    if os.path.isabs(filename) and await aiofiles.os.path.exists(filename):
        return filename

    for directory in search_paths:
        candidate = os.path.join(directory, filename)
        if not await aiofiles.os.path.exists(candidate):
            continue
        if filter_fn is not None and not filter_fn(candidate):
            continue
        return candidate

    raise FileNotFoundError(f"File not found: {filename} in paths: {search_paths}")
|
||||||
|
|
||||||
|
|
||||||
async def _scan_directories(
    search_paths: List[str], filter_fn: Optional[Callable[[str], bool]] = None
) -> Set[str]:
    """Collect the names of directory entries found under the given paths.

    Args:
        search_paths: Directories to scan; paths that do not exist are skipped.
        filter_fn: Optional predicate called with each entry name; only
            accepted names are collected.

    Returns:
        Set of matching entry names (names only, not full paths).
    """
    found = set()

    for directory in search_paths:
        if not await aiofiles.os.path.exists(directory):
            continue

        try:
            # Await the listing first, then walk it synchronously.
            listing = await aiofiles.os.scandir(directory)
            for item in listing:
                if filter_fn is not None and not filter_fn(item.name):
                    continue
                found.add(item.name)
        except Exception as e:
            # Best effort: a directory that fails to scan is logged and skipped.
            logger.warning(f"Error scanning {directory}: {e}")

    return found
|
||||||
|
|
||||||
|
|
||||||
async def get_model_path(model_name: str) -> str:
    """Resolve the path of a model file inside the configured model directory.

    Args:
        model_name: Name of the model file.

    Returns:
        Path to the model file as returned by ``_find_file``.

    Raises:
        FileNotFoundError: If the model file is not found.
    """
    # Base directory: three dirname() hops up from this file.
    base_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))

    # The model directory is settings.model_dir joined onto that base; it is
    # created on demand so the lookup below always has a directory to probe.
    model_dir = os.path.join(base_dir, settings.model_dir)
    os.makedirs(model_dir, exist_ok=True)

    logger.debug(f"Searching for model in path: {model_dir}")
    return await _find_file(model_name, [model_dir])
|
||||||
|
|
||||||
|
|
||||||
async def get_voice_path(voice_name: str) -> str:
    """Resolve the path of a voice file inside the configured voice directory.

    Args:
        voice_name: Name of the voice (without the ``.pt`` extension).

    Returns:
        Path to the voice file as returned by ``_find_file``.

    Raises:
        FileNotFoundError: If the voice file is not found.
    """
    # Base directory: three dirname() hops up from this file.
    base_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))

    # The voice directory is settings.voices_dir joined onto that base; it is
    # created on demand before searching.
    voice_dir = os.path.join(base_dir, settings.voices_dir)
    os.makedirs(voice_dir, exist_ok=True)

    logger.debug(f"Searching for voice in path: {voice_dir}")

    # Voice files are stored as "<name>.pt".
    return await _find_file(f"{voice_name}.pt", [voice_dir])
|
||||||
|
|
||||||
|
|
||||||
async def list_voices() -> List[str]:
    """List the voices available in the configured voice directory.

    Returns:
        Sorted list of voice names, i.e. the ``.pt`` file names with the
        extension removed.
    """
    # Base directory: three dirname() hops up from this file.
    base_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))

    # The voice directory is created on demand before scanning.
    voice_dir = os.path.join(base_dir, settings.voices_dir)
    os.makedirs(voice_dir, exist_ok=True)

    logger.debug(f"Scanning for voices in path: {voice_dir}")

    suffix = ".pt"
    voice_files = await _scan_directories(
        [voice_dir], lambda name: name.endswith(suffix)
    )

    # Every collected name ends with the suffix, so slicing it off is safe.
    return sorted(name[: -len(suffix)] for name in voice_files)
|
||||||
|
|
||||||
|
|
||||||
async def load_voice_tensor(
    voice_path: str, device: str = "cpu", weights_only: bool = False
) -> torch.Tensor:
    """Load a voice tensor from a file.

    The file is read asynchronously into memory and then deserialized with
    ``torch.load``.

    Args:
        voice_path: Path to the voice file.
        device: Passed to ``torch.load`` as ``map_location``.
        weights_only: Passed straight through to ``torch.load``.

    Returns:
        The object produced by ``torch.load`` (expected to be the voice tensor).

    Raises:
        RuntimeError: If the file cannot be read or deserialized; the
            underlying exception is attached as ``__cause__``.
    """
    try:
        async with aiofiles.open(voice_path, "rb") as f:
            data = await f.read()
        # NOTE(review): with weights_only=False (the default here) torch.load
        # unpickles arbitrary objects. Only load voice files from trusted
        # sources, or pass weights_only=True.
        return torch.load(
            io.BytesIO(data), map_location=device, weights_only=weights_only
        )
    except Exception as e:
        raise RuntimeError(
            f"Failed to load voice tensor from {voice_path}: {e}"
        ) from e
|
||||||
|
|
||||||
|
|
||||||
async def save_voice_tensor(tensor: torch.Tensor, voice_path: str) -> None:
    """Save a voice tensor to a file.

    The tensor is serialized with ``torch.save`` into an in-memory buffer,
    which is then written to disk asynchronously.

    Args:
        tensor: Voice tensor to save.
        voice_path: Destination path; opened in ``"wb"`` mode.

    Raises:
        RuntimeError: If serialization or writing fails; the underlying
            exception is attached as ``__cause__``.
    """
    try:
        buffer = io.BytesIO()
        torch.save(tensor, buffer)
        async with aiofiles.open(voice_path, "wb") as f:
            await f.write(buffer.getvalue())
    except Exception as e:
        raise RuntimeError(
            f"Failed to save voice tensor to {voice_path}: {e}"
        ) from e
|
||||||
|
|
||||||
|
|
||||||
async def load_json(path: str) -> dict:
    """Load and parse a JSON file asynchronously.

    Args:
        path: Path to the JSON file; read as UTF-8 text.

    Returns:
        The value produced by ``json.loads`` for the file contents.

    Raises:
        RuntimeError: If the file cannot be read or parsed; the underlying
            exception is attached as ``__cause__``.
    """
    try:
        async with aiofiles.open(path, "r", encoding="utf-8") as f:
            content = await f.read()
        return json.loads(content)
    except Exception as e:
        raise RuntimeError(f"Failed to load JSON file {path}: {e}") from e
|
||||||
|
|
||||||
|
|
||||||
async def load_model_weights(path: str, device: str = "cpu") -> dict:
    """Load model weights asynchronously.

    The file is read into memory and deserialized with
    ``torch.load(..., weights_only=True)``.

    Args:
        path: Path to the model file.
            NOTE(review): the previous docstring listed ".pth or .onnx", but
            ``torch.load`` is the only loader used here; confirm whether
            .onnx files are ever passed in.
        device: Passed to ``torch.load`` as ``map_location``.

    Returns:
        The object produced by ``torch.load`` (expected to be the weights).

    Raises:
        RuntimeError: If the file cannot be read or deserialized; the
            underlying exception is attached as ``__cause__``.
    """
    try:
        async with aiofiles.open(path, "rb") as f:
            data = await f.read()
        return torch.load(io.BytesIO(data), map_location=device, weights_only=True)
    except Exception as e:
        raise RuntimeError(f"Failed to load model weights from {path}: {e}") from e
|
||||||
|
|
||||||
|
|
||||||
async def read_file(path: str) -> str:
    """Read a text file asynchronously.

    Args:
        path: Path to the file; read as UTF-8 text.

    Returns:
        File contents as a string.

    Raises:
        RuntimeError: If the file cannot be read; the underlying exception is
            attached as ``__cause__``.
    """
    try:
        async with aiofiles.open(path, "r", encoding="utf-8") as f:
            return await f.read()
    except Exception as e:
        raise RuntimeError(f"Failed to read file {path}: {e}") from e
|
||||||
|
|
||||||
|
|
||||||
async def read_bytes(path: str) -> bytes:
    """Read a file as bytes asynchronously.

    Args:
        path: Path to the file; opened in ``"rb"`` mode.

    Returns:
        File contents as bytes.

    Raises:
        RuntimeError: If the file cannot be read; the underlying exception is
            attached as ``__cause__``.
    """
    try:
        async with aiofiles.open(path, "rb") as f:
            return await f.read()
    except Exception as e:
        raise RuntimeError(f"Failed to read file {path}: {e}") from e
|
||||||
|
|
||||||
|
|
||||||
async def get_web_file_path(filename: str) -> str:
    """Resolve the path of a static file in the web player directory.

    Args:
        filename: Name of the file in the web directory.

    Returns:
        Path to the file as returned by ``_find_file``.

    Raises:
        FileNotFoundError: If the file is not found.
    """
    # NOTE(review): the web directory is anchored at the hard-coded "/app"
    # prefix, not at a directory derived from this file's location. The
    # previously computed (and unused) project-root value has been removed.
    # Confirm "/app" is correct outside the container image.
    web_dir = os.path.join("/app", settings.web_player_path)

    logger.debug(f"Searching for web file in path: {web_dir}")

    return await _find_file(filename, [web_dir])
|
||||||
|
|
||||||
|
|
||||||
async def get_content_type(path: str) -> str:
    """Map a file path to a content type based on its extension.

    Args:
        path: Path (or bare file name) whose extension is inspected.

    Returns:
        The content type for the lower-cased extension, or
        ``"application/octet-stream"`` when the extension is not in the table.
    """
    known_types = {
        ".html": "text/html",
        ".js": "application/javascript",
        ".css": "text/css",
        ".png": "image/png",
        ".jpg": "image/jpeg",
        ".jpeg": "image/jpeg",
        ".gif": "image/gif",
        ".svg": "image/svg+xml",
        ".ico": "image/x-icon",
    }
    _, extension = os.path.splitext(path)
    # The lookup is case-insensitive with respect to the extension.
    return known_types.get(extension.lower(), "application/octet-stream")
|
||||||
|
|
||||||
|
|
||||||
async def verify_model_path(model_path: str) -> bool:
    """Report whether anything exists at ``model_path``.

    Args:
        model_path: Path to check.

    Returns:
        Result of the asynchronous ``aiofiles.os.path.exists`` check.
    """
    path_exists = await aiofiles.os.path.exists(model_path)
    return path_exists
|
||||||
|
|
||||||
|
|
||||||
async def cleanup_temp_files() -> None:
    """Delete expired files from the temp directory (run on startup).

    If the temp directory does not exist it is created and nothing else is
    done. Otherwise every regular file whose modification time is more than
    ``settings.max_temp_dir_age_hours`` hours in the past is removed.

    All failures are logged as warnings; this function does not raise.
    """
    # Local import so the fix is self-contained in this function.
    import time

    try:
        if not await aiofiles.os.path.exists(settings.temp_file_dir):
            await aiofiles.os.makedirs(settings.temp_file_dir, exist_ok=True)
            return

        max_age_seconds = settings.max_temp_dir_age_hours * 3600
        now = time.time()

        entries = await aiofiles.os.scandir(settings.temp_file_dir)
        for entry in entries:
            if not entry.is_file():
                continue

            stat = await aiofiles.os.stat(entry.path)
            # A file expires once mtime + max age lies in the past. The
            # previous code compared that sum against mtime itself, which is
            # never true for a non-negative age, so nothing was ever deleted.
            if stat.st_mtime + max_age_seconds >= now:
                continue

            try:
                await aiofiles.os.remove(entry.path)
                logger.info(f"Cleaned up old temp file: {entry.name}")
            except Exception as e:
                # One undeletable file must not stop the rest of the sweep.
                logger.warning(
                    f"Failed to delete old temp file {entry.name}: {e}"
                )
    except Exception as e:
        logger.warning(f"Error cleaning temp files: {e}")
|
||||||
|
|
||||||
|
|
||||||
async def get_temp_file_path(filename: str) -> str:
    """Build the path of a temporary file, creating the temp directory if needed.

    Args:
        filename: Name of the temp file.

    Returns:
        ``filename`` joined onto ``settings.temp_file_dir``. The file itself
        is not created or checked.
    """
    # Create the temp directory on demand rather than failing.
    temp_dir_exists = await aiofiles.os.path.exists(settings.temp_file_dir)
    if not temp_dir_exists:
        await aiofiles.os.makedirs(settings.temp_file_dir, exist_ok=True)

    return os.path.join(settings.temp_file_dir, filename)
|
||||||
|
|
||||||
|
|
||||||
async def list_temp_files() -> List[str]:
    """List the regular files in the temp directory.

    Returns:
        Names of the entries in ``settings.temp_file_dir`` that are files;
        an empty list when the directory does not exist.
    """
    if not await aiofiles.os.path.exists(settings.temp_file_dir):
        return []

    names = []
    for item in await aiofiles.os.scandir(settings.temp_file_dir):
        if item.is_file():
            names.append(item.name)
    return names
|
||||||
|
|
||||||
|
|
||||||
async def get_temp_dir_size() -> int:
    """Sum the sizes of the regular files in the temp directory.

    Only direct entries of ``settings.temp_file_dir`` are considered;
    non-file entries are skipped.

    Returns:
        Total size in bytes, or 0 when the directory does not exist.
    """
    size_in_bytes = 0

    if await aiofiles.os.path.exists(settings.temp_file_dir):
        for item in await aiofiles.os.scandir(settings.temp_file_dir):
            if not item.is_file():
                continue
            info = await aiofiles.os.stat(item.path)
            size_in_bytes += info.st_size

    return size_in_bytes
|
||||||
|
|
|
@ -1,12 +1,12 @@
|
||||||
"""Model inference package."""
|
"""Model inference package."""
|
||||||
|
|
||||||
from .base import BaseModelBackend
|
from .base import BaseModelBackend
|
||||||
from .kokoro_v1 import KokoroV1
|
from .kokoro_v1 import KokoroV1
|
||||||
from .model_manager import ModelManager, get_manager
|
from .model_manager import ModelManager, get_manager
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"BaseModelBackend",
|
"BaseModelBackend",
|
||||||
"ModelManager",
|
"ModelManager",
|
||||||
"get_manager",
|
"get_manager",
|
||||||
"KokoroV1",
|
"KokoroV1",
|
||||||
]
|
]
|
||||||
|
|
|
@ -1,120 +1,120 @@
|
||||||
"""Base interface for Kokoro inference."""
|
"""Base interface for Kokoro inference."""
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import AsyncGenerator, Optional, Tuple, Union, List
|
from typing import AsyncGenerator, Optional, Tuple, Union, List
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
class AudioChunk:
    """Class for audio chunks returned by model backends."""

    # Sentinel for "word_timestamps not supplied". Using it instead of a `[]`
    # default gives every instance its own list, while an explicit `None`
    # from the caller is still stored as `None`.
    _UNSET = object()

    def __init__(
        self,
        audio: np.ndarray,
        word_timestamps: Optional[List] = _UNSET,
        output: Optional[Union[bytes, np.ndarray]] = b"",
    ):
        """Create an audio chunk.

        Args:
            audio: Audio samples for this chunk.
            word_timestamps: Word timestamp entries for this chunk. Defaults
                to a fresh empty list; an explicit ``None`` is kept as
                ``None``.
            output: Optional encoded output associated with the chunk.
        """
        if word_timestamps is AudioChunk._UNSET:
            word_timestamps = []
        self.audio = audio
        self.word_timestamps = word_timestamps
        self.output = output

    @staticmethod
    def combine(audio_chunk_list: List):
        """Merge a list of chunks into a single new chunk.

        The audio arrays are concatenated in order (as ``int16`` when there is
        more than one chunk; a single chunk's audio is passed through
        unchanged). Word timestamps are concatenated into a new list, so the
        input chunks are not modified. If the first chunk has no timestamps
        (``None``) the result has none either; later chunks whose timestamps
        are ``None`` are skipped.

        Args:
            audio_chunk_list: Non-empty list of ``AudioChunk`` objects.

        Returns:
            A new ``AudioChunk`` holding the combined audio and timestamps.

        Raises:
            IndexError: If ``audio_chunk_list`` is empty.
        """
        first = audio_chunk_list[0]

        # Copy so that extending the result never mutates the first chunk.
        timestamps = (
            None if first.word_timestamps is None else list(first.word_timestamps)
        )
        output = AudioChunk(first.audio, timestamps)

        rest = audio_chunk_list[1:]
        if rest:
            # One concatenation instead of re-copying the growing buffer for
            # every chunk.
            output.audio = np.concatenate(
                [chunk.audio for chunk in audio_chunk_list], dtype=np.int16
            )
            if output.word_timestamps is not None:
                for chunk in rest:
                    if chunk.word_timestamps is not None:
                        output.word_timestamps.extend(chunk.word_timestamps)

        return output
|
||||||
|
|
||||||
class ModelBackend(ABC):
    """Interface that every model inference backend must implement."""

    @abstractmethod
    async def load_model(self, path: str) -> None:
        """Load the model stored at ``path``.

        Args:
            path: Path to model file

        Raises:
            RuntimeError: If model loading fails
        """

    @abstractmethod
    async def generate(
        self,
        text: str,
        voice: Union[str, Tuple[str, Union[torch.Tensor, str]]],
        speed: float = 1.0,
    ) -> AsyncGenerator[AudioChunk, None]:
        """Synthesize audio for ``text``.

        Args:
            text: Input text to synthesize
            voice: Either a voice path or tuple of (name, tensor/path)
            speed: Speed multiplier

        Yields:
            Generated audio chunks

        Raises:
            RuntimeError: If generation fails
        """

    @abstractmethod
    def unload(self) -> None:
        """Release the model and any resources it holds."""

    @property
    @abstractmethod
    def is_loaded(self) -> bool:
        """Whether a model is currently loaded.

        Returns:
            True if model is loaded, False otherwise
        """

    @property
    @abstractmethod
    def device(self) -> str:
        """Device the model is running on.

        Returns:
            Device string ('cpu' or 'cuda')
        """
|
||||||
|
|
||||||
|
|
||||||
class BaseModelBackend(ModelBackend):
    """Shared state and housekeeping for concrete model backends."""

    def __init__(self):
        """Start with no model loaded and the device set to "cpu"."""
        self._model: Optional[torch.nn.Module] = None
        self._device: str = "cpu"

    @property
    def is_loaded(self) -> bool:
        """Whether a model object is currently held."""
        return self._model is not None

    @property
    def device(self) -> str:
        """Device string recorded for this backend."""
        return self._device

    def unload(self) -> None:
        """Drop the model reference and, when CUDA is available, clear its cache."""
        if self._model is None:
            return

        del self._model
        self._model = None

        # NOTE(review): indentation was lost in the reviewed source; the CUDA
        # cleanup is assumed to run only when a model was actually unloaded.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
|
||||||
|
|
|
@ -1,171 +1,171 @@
|
||||||
"""Kokoro V1 model management."""
|
"""Kokoro V1 model management."""
|
||||||
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from ..core import paths
|
from ..core import paths
|
||||||
from ..core.config import settings
|
from ..core.config import settings
|
||||||
from ..core.model_config import ModelConfig, model_config
|
from ..core.model_config import ModelConfig, model_config
|
||||||
from .base import BaseModelBackend
|
from .base import BaseModelBackend
|
||||||
from .kokoro_v1 import KokoroV1
|
from .kokoro_v1 import KokoroV1
|
||||||
|
|
||||||
|
|
||||||
class ModelManager:
    """Manages Kokoro V1 model loading and inference.

    The manager owns a single ``KokoroV1`` backend. The backend is created by
    ``initialize()``; ``load_model()`` and ``generate()`` refuse to run until
    that has happened. A process-wide instance is cached on ``_instance`` by
    the module-level ``get_manager()`` helper.
    """

    # Singleton instance
    _instance = None

    def __init__(self, config: Optional[ModelConfig] = None):
        """Initialize manager.

        Args:
            config: Optional model configuration override. When falsy, the
                module-level ``model_config`` is used instead.
        """
        self._config = config or model_config
        self._backend: Optional[KokoroV1] = None  # Explicitly type as KokoroV1
        self._device: Optional[str] = None

    def _determine_device(self) -> str:
        """Determine device based on settings.

        Returns:
            ``"cuda"`` when ``settings.use_gpu`` is truthy, otherwise ``"cpu"``.
        """
        return "cuda" if settings.use_gpu else "cpu"

    async def initialize(self) -> None:
        """Initialize Kokoro V1 backend.

        Records the selected device and creates a fresh ``KokoroV1`` backend.

        Raises:
            RuntimeError: If device selection or backend construction fails.
                The original exception is attached as ``__cause__``.
        """
        try:
            self._device = self._determine_device()
            logger.info(f"Initializing Kokoro V1 on {self._device}")
            self._backend = KokoroV1()

        except Exception as e:
            raise RuntimeError(f"Failed to initialize Kokoro V1: {e}") from e

    async def initialize_with_warmup(self, voice_manager) -> tuple[str, str, int]:
        """Initialize and warm up model.

        Args:
            voice_manager: Voice manager instance for warmup. Not used by the
                code in this method; voices are resolved through ``paths``.

        Returns:
            Tuple of (device, backend type, voice count)

        Raises:
            RuntimeError: If initialization fails
            SystemExit: If a ``FileNotFoundError`` reaches the outer handler
                (for example from ``load_model``); download instructions are
                logged first.
        """
        import time

        start = time.perf_counter()

        try:
            # Initialize backend
            await self.initialize()

            # Load model
            model_path = self._config.pytorch_kokoro_v1_file
            await self.load_model(model_path)

            # Use paths module to get voice path
            try:
                voices = await paths.list_voices()
                voice_path = await paths.get_voice_path(settings.default_voice)

                # Warm up with short text
                warmup_text = "Warmup text for initialization."
                # Use default voice name for warmup
                voice_name = settings.default_voice
                logger.debug(f"Using default voice '{voice_name}' for warmup")
                # Drain the generator; the produced audio is discarded.
                async for _ in self.generate(warmup_text, (voice_name, voice_path)):
                    pass
            except Exception as e:
                # Covers voice lookup *and* the warmup generation itself.
                raise RuntimeError(f"Failed to get default voice: {e}") from e

            ms = int((time.perf_counter() - start) * 1000)
            logger.info(f"Warmup completed in {ms}ms")

            return self._device, "kokoro_v1", len(voices)
        except FileNotFoundError:
            logger.error("""
Model files not found! You need to download the Kokoro V1 model:

1. Download model using the script:
   python docker/scripts/download_model.py --output api/src/models/v1_0

2. Or set environment variable in docker-compose:
   DOWNLOAD_MODEL=true
""")
            # Raise SystemExit directly instead of calling the `exit` helper,
            # which only exists when the `site` module has been loaded.
            # NOTE(review): status 0 is kept from the original even though
            # startup failed - confirm this is intentional.
            raise SystemExit(0)
        except Exception as e:
            raise RuntimeError(f"Warmup failed: {e}") from e

    def get_backend(self) -> BaseModelBackend:
        """Get initialized backend.

        Returns:
            Initialized backend instance

        Raises:
            RuntimeError: If backend not initialized
        """
        if not self._backend:
            raise RuntimeError("Backend not initialized")
        return self._backend

    async def load_model(self, path: str) -> None:
        """Load model using initialized backend.

        Args:
            path: Path to model file

        Raises:
            FileNotFoundError: Propagated unchanged from the backend so that
                callers can tell a missing model file from other failures.
            RuntimeError: If the backend is not initialized or loading fails
                for any other reason.
        """
        if not self._backend:
            raise RuntimeError("Backend not initialized")

        try:
            await self._backend.load_model(path)
        except FileNotFoundError:
            # Re-raise as-is (bare raise keeps the original traceback).
            raise
        except Exception as e:
            raise RuntimeError(f"Failed to load model: {e}") from e

    async def generate(self, *args, **kwargs):
        """Generate audio using initialized backend.

        All positional and keyword arguments are forwarded to the backend's
        ``generate``; each chunk it yields is yielded on unchanged.

        Raises:
            RuntimeError: If the backend is not initialized or generation fails
        """
        if not self._backend:
            raise RuntimeError("Backend not initialized")

        try:
            async for chunk in self._backend.generate(*args, **kwargs):
                yield chunk
        except Exception as e:
            raise RuntimeError(f"Generation failed: {e}") from e

    def unload_all(self) -> None:
        """Unload model and free resources.

        Calls ``unload()`` on the backend, if one exists, and drops the
        reference so that later calls report "Backend not initialized".
        """
        if self._backend:
            self._backend.unload()
            self._backend = None

    @property
    def current_backend(self) -> str:
        """Get current backend type."""
        return "kokoro_v1"
|
||||||
|
|
||||||
|
|
||||||
async def get_manager(config: Optional[ModelConfig] = None) -> ModelManager:
    """Return the shared ModelManager, creating it on first use.

    Args:
        config: Optional configuration override. It is only consulted when the
            shared instance does not exist yet; later calls return the cached
            manager regardless of this argument.

    Returns:
        The ModelManager stored on ``ModelManager._instance``.
    """
    manager = ModelManager._instance
    if manager is not None:
        return manager

    # First call: build the manager and cache it on the class.
    manager = ModelManager(config)
    ModelManager._instance = manager
    return manager
|
||||||
|
|
|
@ -7,6 +7,7 @@ Converts them into a format suitable for text-to-speech processing.
|
||||||
import re
|
import re
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
import inflect
|
import inflect
|
||||||
|
from numpy import number
|
||||||
|
|
||||||
from ...structures.schemas import NormalizationOptions
|
from ...structures.schemas import NormalizationOptions
|
||||||
|
|
||||||
|
@ -87,6 +88,8 @@ URL_PATTERN = re.compile(
|
||||||
|
|
||||||
UNIT_PATTERN = re.compile(r"((?<!\w)([+-]?)(\d{1,3}(,\d{3})*|\d+)(\.\d+)?)\s*(" + "|".join(sorted(list(VALID_UNITS.keys()),reverse=True)) + r"""){1}(?=[^\w\d]{1}|\b)""",re.IGNORECASE)
|
UNIT_PATTERN = re.compile(r"((?<!\w)([+-]?)(\d{1,3}(,\d{3})*|\d+)(\.\d+)?)\s*(" + "|".join(sorted(list(VALID_UNITS.keys()),reverse=True)) + r"""){1}(?=[^\w\d]{1}|\b)""",re.IGNORECASE)
|
||||||
|
|
||||||
|
TIME_PATTERN = re.compile(r"([0-9]{2} ?: ?[0-9]{2}( ?: ?[0-9]{2})?)( ?(pm|am)\b)?", re.IGNORECASE)
|
||||||
|
|
||||||
INFLECT_ENGINE=inflect.engine()
|
INFLECT_ENGINE=inflect.engine()
|
||||||
|
|
||||||
def split_num(num: re.Match[str]) -> str:
|
def split_num(num: re.Match[str]) -> str:
|
||||||
|
@ -136,10 +139,10 @@ def handle_money(m: re.Match[str]) -> str:
|
||||||
m = m.group()
|
m = m.group()
|
||||||
bill = "dollar" if m[0] == "$" else "pound"
|
bill = "dollar" if m[0] == "$" else "pound"
|
||||||
if m[-1].isalpha():
|
if m[-1].isalpha():
|
||||||
return f"{m[1:]} {bill}s"
|
return f"{INFLECT_ENGINE.number_to_words(m[1:])} {bill}s"
|
||||||
elif "." not in m:
|
elif "." not in m:
|
||||||
s = "" if m[1:] == "1" else "s"
|
s = "" if m[1:] == "1" else "s"
|
||||||
return f"{m[1:]} {bill}{s}"
|
return f"{INFLECT_ENGINE.number_to_words(m[1:])} {bill}{s}"
|
||||||
b, c = m[1:].split(".")
|
b, c = m[1:].split(".")
|
||||||
s = "" if b == "1" else "s"
|
s = "" if b == "1" else "s"
|
||||||
c = int(c.ljust(2, "0"))
|
c = int(c.ljust(2, "0"))
|
||||||
|
@ -148,7 +151,7 @@ def handle_money(m: re.Match[str]) -> str:
|
||||||
if m[0] == "$"
|
if m[0] == "$"
|
||||||
else ("penny" if c == 1 else "pence")
|
else ("penny" if c == 1 else "pence")
|
||||||
)
|
)
|
||||||
return f"{b} {bill}{s} and {c} {coins}"
|
return f"{INFLECT_ENGINE.number_to_words(b)} {bill}{s} and {INFLECT_ENGINE.number_to_words(c)} {coins}"
|
||||||
|
|
||||||
|
|
||||||
def handle_decimal(num: re.Match[str]) -> str:
|
def handle_decimal(num: re.Match[str]) -> str:
|
||||||
|
@ -214,6 +217,32 @@ def handle_url(u: re.Match[str]) -> str:
|
||||||
# Clean up extra spaces
|
# Clean up extra spaces
|
||||||
return re.sub(r"\s+", " ", url).strip()
|
return re.sub(r"\s+", " ", url).strip()
|
||||||
|
|
||||||
|
def handle_phone_number(p: re.Match[str]) -> str:
    """Convert a matched phone number into spoken words.

    Args:
        p: Regex match whose groups are read by position: group 0 is the
            optional country code (may be ``None``), group 2 the area code
            (optionally wrapped in parentheses), group 3 the telephone prefix
            and group 4 the line number. Group 1 is not used.

    Returns:
        The spoken parts joined with commas. The country code is included
        only when it was matched, so the result never starts with a stray
        comma.
    """
    groups = p.groups()
    parts = []

    country_code = groups[0]
    if country_code is not None:
        # Unlike the other parts, the country code is converted without
        # group=1.
        parts.append(INFLECT_ENGINE.number_to_words(country_code.replace("+", "")))

    # group=1 asks inflect to read the digits one at a time; comma="" drops
    # the separators it would otherwise insert between them.
    area_digits = groups[2].replace("(", "").replace(")", "")
    parts.append(INFLECT_ENGINE.number_to_words(area_digits, group=1, comma=""))
    parts.append(INFLECT_ENGINE.number_to_words(groups[3], group=1, comma=""))
    parts.append(INFLECT_ENGINE.number_to_words(groups[4], group=1, comma=""))

    return ",".join(parts)
|
||||||
|
|
||||||
|
def handle_time(t: re.Match[str]) -> str:
    """Convert a matched clock time into spoken words.

    Args:
        t: Regex match whose group 0 (by ``groups()`` index) is the
            colon-separated time and whose group 2 is the optional am/pm
            suffix, possibly with a leading space; it is ``None`` when absent.

    Returns:
        Each colon-separated field converted to words and joined by spaces,
        followed by a space and the am/pm suffix when one was matched.
    """
    groups = t.groups()

    numbers = " ".join(
        INFLECT_ENGINE.number_to_words(field.strip())
        for field in groups[0].split(":")
    )

    suffix = groups[2]
    if suffix is None:
        return numbers

    # The match swallows any space before am/pm, so one has to be put back
    # explicitly; otherwise the suffix fuses with the last number word.
    return f"{numbers} {suffix.strip()}"
|
||||||
|
|
||||||
def normalize_text(text: str,normalization_options: NormalizationOptions) -> str:
|
def normalize_text(text: str,normalization_options: NormalizationOptions) -> str:
|
||||||
"""Normalize text for TTS processing"""
|
"""Normalize text for TTS processing"""
|
||||||
|
@ -233,6 +262,10 @@ def normalize_text(text: str,normalization_options: NormalizationOptions) -> str
|
||||||
if normalization_options.optional_pluralization_normalization:
|
if normalization_options.optional_pluralization_normalization:
|
||||||
text = re.sub(r"\(s\)","s",text)
|
text = re.sub(r"\(s\)","s",text)
|
||||||
|
|
||||||
|
# Replace phone numbers:
|
||||||
|
if normalization_options.phone_normalization:
|
||||||
|
text = re.sub(r"(\+?\d{1,2})?([ .-]?)(\(?\d{3}\)?)[\s.-](\d{3})[\s.-](\d{4})",handle_phone_number,text)
|
||||||
|
|
||||||
# Replace quotes and brackets
|
# Replace quotes and brackets
|
||||||
text = text.replace(chr(8216), "'").replace(chr(8217), "'")
|
text = text.replace(chr(8216), "'").replace(chr(8217), "'")
|
||||||
text = text.replace("«", chr(8220)).replace("»", chr(8221))
|
text = text.replace("«", chr(8220)).replace("»", chr(8221))
|
||||||
|
@ -243,6 +276,9 @@ def normalize_text(text: str,normalization_options: NormalizationOptions) -> str
|
||||||
for a, b in zip("、。!,:;?–", ",.!,:;?-"):
|
for a, b in zip("、。!,:;?–", ",.!,:;?-"):
|
||||||
text = text.replace(a, b + " ")
|
text = text.replace(a, b + " ")
|
||||||
|
|
||||||
|
# Handle simple time in the format of HH:MM:SS
|
||||||
|
text = TIME_PATTERN.sub(handle_time, text, )
|
||||||
|
|
||||||
# Clean up whitespace
|
# Clean up whitespace
|
||||||
text = re.sub(r"[^\S \n]", " ", text)
|
text = re.sub(r"[^\S \n]", " ", text)
|
||||||
text = re.sub(r" +", " ", text)
|
text = re.sub(r" +", " ", text)
|
||||||
|
@ -259,17 +295,18 @@ def normalize_text(text: str,normalization_options: NormalizationOptions) -> str
|
||||||
text = re.sub(r"(?i)\b(y)eah?\b", r"\1e'a", text)
|
text = re.sub(r"(?i)\b(y)eah?\b", r"\1e'a", text)
|
||||||
|
|
||||||
# Handle numbers and money
|
# Handle numbers and money
|
||||||
text = re.sub(
|
|
||||||
r"\d*\.\d+|\b\d{4}s?\b|(?<!:)\b(?:[1-9]|1[0-2]):[0-5]\d\b(?!:)", split_num, text
|
|
||||||
)
|
|
||||||
|
|
||||||
text = re.sub(r"(?<=\d),(?=\d)", "", text)
|
text = re.sub(r"(?<=\d),(?=\d)", "", text)
|
||||||
|
|
||||||
text = re.sub(
|
text = re.sub(
|
||||||
r"(?i)[$£]\d+(?:\.\d+)?(?: hundred| thousand| (?:[bm]|tr)illion)*\b|[$£]\d+\.\d\d?\b",
|
r"(?i)[$£]\d+(?:\.\d+)?(?: hundred| thousand| (?:[bm]|tr)illion)*\b|[$£]\d+\.\d\d?\b",
|
||||||
handle_money,
|
handle_money,
|
||||||
text,
|
text,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
text = re.sub(
|
||||||
|
r"\d*\.\d+|\b\d{4}s?\b|(?<!:)\b(?:[1-9]|1[0-2]):[0-5]\d\b(?!:)", split_num, text
|
||||||
|
)
|
||||||
|
|
||||||
text = re.sub(r"\d*\.\d+", handle_decimal, text)
|
text = re.sub(r"\d*\.\d+", handle_decimal, text)
|
||||||
|
|
||||||
# Handle various formatting
|
# Handle various formatting
|
||||||
|
|
|
@ -44,6 +44,7 @@ class NormalizationOptions(BaseModel):
|
||||||
url_normalization: bool = Field(default=True, description="Changes urls so they can be properly pronouced by kokoro")
|
url_normalization: bool = Field(default=True, description="Changes urls so they can be properly pronouced by kokoro")
|
||||||
email_normalization: bool = Field(default=True, description="Changes emails so they can be properly pronouced by kokoro")
|
email_normalization: bool = Field(default=True, description="Changes emails so they can be properly pronouced by kokoro")
|
||||||
optional_pluralization_normalization: bool = Field(default=True, description="Replaces (s) with s so some words get pronounced correctly")
|
optional_pluralization_normalization: bool = Field(default=True, description="Replaces (s) with s so some words get pronounced correctly")
|
||||||
|
phone_normalization: bool = Field(default=True, description="Changes phone numbers so they can be properly pronouced by kokoro")
|
||||||
|
|
||||||
class OpenAISpeechRequest(BaseModel):
|
class OpenAISpeechRequest(BaseModel):
|
||||||
"""Request schema for OpenAI-compatible speech endpoint"""
|
"""Request schema for OpenAI-compatible speech endpoint"""
|
||||||
|
|
|
@ -88,4 +88,4 @@ def test_non_url_text():
|
||||||
"""Test that non-URL text is unaffected"""
|
"""Test that non-URL text is unaffected"""
|
||||||
assert normalize_text("This is not.a.url text",normalization_options=NormalizationOptions()) == "This is not-a-url text"
|
assert normalize_text("This is not.a.url text",normalization_options=NormalizationOptions()) == "This is not-a-url text"
|
||||||
assert normalize_text("Hello, how are you today?",normalization_options=NormalizationOptions()) == "Hello, how are you today?"
|
assert normalize_text("Hello, how are you today?",normalization_options=NormalizationOptions()) == "Hello, how are you today?"
|
||||||
assert normalize_text("It costs $50.",normalization_options=NormalizationOptions()) == "It costs 50 dollars."
|
assert normalize_text("It costs $50.",normalization_options=NormalizationOptions()) == "It costs fifty dollars."
|
||||||
|
|
|
@ -14,4 +14,4 @@ export WEB_PLAYER_PATH=$PROJECT_ROOT/web
|
||||||
# Run FastAPI with CPU extras using uv run
|
# Run FastAPI with CPU extras using uv run
|
||||||
# Note: espeak may still require manual installation,
|
# Note: espeak may still require manual installation,
|
||||||
uv pip install -e ".[cpu]"
|
uv pip install -e ".[cpu]"
|
||||||
uv run uvicorn api.src.main:app --reload --host 0.0.0.0 --port 8880
|
uv run uvicorn api.src.main:app --host 0.0.0.0 --port 8880
|
||||||
|
|
|
@ -13,4 +13,4 @@ export WEB_PLAYER_PATH=$PROJECT_ROOT/web
|
||||||
|
|
||||||
# Run FastAPI with GPU extras using uv run
|
# Run FastAPI with GPU extras using uv run
|
||||||
uv pip install -e ".[gpu]"
|
uv pip install -e ".[gpu]"
|
||||||
uv run uvicorn api.src.main:app --reload --host 0.0.0.0 --port 8880
|
uv run uvicorn api.src.main:app --host 0.0.0.0 --port 8880
|
||||||
|
|
Loading…
Add table
Reference in a new issue