mirror of
https://github.com/remsky/Kokoro-FastAPI.git
synced 2025-08-05 16:48:53 +00:00
Merge pull request #256 from fireblade2534/Fixing-number-normalization
Some checks failed
CI / test (3.10) (push) Has been cancelled
Some checks failed
CI / test (3.10) (push) Has been cancelled
This commit is contained in:
commit
fe99bb7697
15 changed files with 165 additions and 36 deletions
31
README.md
31
README.md
|
@ -516,7 +516,36 @@ Monitor system state and resource usage with these endpoints:
|
||||||
Useful for debugging resource exhaustion or performance issues.
|
Useful for debugging resource exhaustion or performance issues.
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
## Known Issues
|
## Known Issues & Troubleshooting
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary>Missing words & Missing some timestamps</summary>
|
||||||
|
|
||||||
|
The api will automaticly do text normalization on input text which may incorrectly remove or change some phrases. This can be disabled by adding `"normalization_options":{"normalize": false}` to your request json:
|
||||||
|
```python
|
||||||
|
import requests
|
||||||
|
|
||||||
|
response = requests.post(
|
||||||
|
"http://localhost:8880/v1/audio/speech",
|
||||||
|
json={
|
||||||
|
"input": "Hello world!",
|
||||||
|
"voice": "af_heart",
|
||||||
|
"response_format": "pcm",
|
||||||
|
"normalization_options":
|
||||||
|
{
|
||||||
|
"normalize": False
|
||||||
|
}
|
||||||
|
},
|
||||||
|
stream=True
|
||||||
|
)
|
||||||
|
|
||||||
|
for chunk in response.iter_content(chunk_size=1024):
|
||||||
|
if chunk:
|
||||||
|
# Process streaming chunks
|
||||||
|
pass
|
||||||
|
```
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
<details>
|
<details>
|
||||||
<summary>Versioning & Development</summary>
|
<summary>Versioning & Development</summary>
|
||||||
|
|
|
@ -125,20 +125,18 @@ async def process_and_validate_voices(voice_input: Union[str, List[str]], tts_se
|
||||||
async def stream_audio_chunks(tts_service: TTSService, request: Union[OpenAISpeechRequest, CaptionedSpeechRequest], client_request: Request, writer: StreamingAudioWriter) -> AsyncGenerator[AudioChunk, None]:
|
async def stream_audio_chunks(tts_service: TTSService, request: Union[OpenAISpeechRequest, CaptionedSpeechRequest], client_request: Request, writer: StreamingAudioWriter) -> AsyncGenerator[AudioChunk, None]:
|
||||||
"""Stream audio chunks as they're generated with client disconnect handling"""
|
"""Stream audio chunks as they're generated with client disconnect handling"""
|
||||||
voice_name = await process_and_validate_voices(request.voice, tts_service)
|
voice_name = await process_and_validate_voices(request.voice, tts_service)
|
||||||
|
|
||||||
unique_properties = {"return_timestamps": False}
|
unique_properties = {"return_timestamps": False}
|
||||||
if hasattr(request, "return_timestamps"):
|
if hasattr(request, "return_timestamps"):
|
||||||
unique_properties["return_timestamps"] = request.return_timestamps
|
unique_properties["return_timestamps"] = request.return_timestamps
|
||||||
|
|
||||||
try:
|
try:
|
||||||
logger.info(f"Starting audio generation with lang_code: {request.lang_code}")
|
|
||||||
async for chunk_data in tts_service.generate_audio_stream(
|
async for chunk_data in tts_service.generate_audio_stream(
|
||||||
text=request.input,
|
text=request.input,
|
||||||
voice=voice_name,
|
voice=voice_name,
|
||||||
writer=writer,
|
writer=writer,
|
||||||
speed=request.speed,
|
speed=request.speed,
|
||||||
output_format=request.response_format,
|
output_format=request.response_format,
|
||||||
lang_code=request.lang_code or settings.default_voice_code or voice_name[0].lower(),
|
lang_code=request.lang_code,
|
||||||
normalization_options=request.normalization_options,
|
normalization_options=request.normalization_options,
|
||||||
return_timestamps=unique_properties["return_timestamps"],
|
return_timestamps=unique_properties["return_timestamps"],
|
||||||
):
|
):
|
||||||
|
|
|
@ -25,7 +25,7 @@ class StreamingAudioWriter:
|
||||||
if self.format in ["wav","flac","mp3","pcm","aac","opus"]:
|
if self.format in ["wav","flac","mp3","pcm","aac","opus"]:
|
||||||
if self.format != "pcm":
|
if self.format != "pcm":
|
||||||
self.output_buffer = BytesIO()
|
self.output_buffer = BytesIO()
|
||||||
self.container = av.open(self.output_buffer, mode="w", format=self.format)
|
self.container = av.open(self.output_buffer, mode="w", format=self.format if self.format != "aac" else "adts")
|
||||||
self.stream = self.container.add_stream(codec_map[self.format],sample_rate=self.sample_rate,layout='mono' if self.channels == 1 else 'stereo')
|
self.stream = self.container.add_stream(codec_map[self.format],sample_rate=self.sample_rate,layout='mono' if self.channels == 1 else 'stereo')
|
||||||
self.stream.bit_rate = 128000
|
self.stream.bit_rate = 128000
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -8,9 +8,11 @@ import re
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
import inflect
|
import inflect
|
||||||
from numpy import number
|
from numpy import number
|
||||||
|
from torch import mul
|
||||||
from ...structures.schemas import NormalizationOptions
|
from ...structures.schemas import NormalizationOptions
|
||||||
|
|
||||||
|
from text_to_num import text2num
|
||||||
|
|
||||||
# Constants
|
# Constants
|
||||||
VALID_TLDS = [
|
VALID_TLDS = [
|
||||||
"com",
|
"com",
|
||||||
|
@ -134,25 +136,35 @@ def handle_units(u: re.Match[str]) -> str:
|
||||||
unit[0]=INFLECT_ENGINE.no(unit[0],number)
|
unit[0]=INFLECT_ENGINE.no(unit[0],number)
|
||||||
return " ".join(unit)
|
return " ".join(unit)
|
||||||
|
|
||||||
|
def conditional_int(number: float, threshold: float = 0.00001):
|
||||||
|
if abs(round(number) - number) < threshold:
|
||||||
|
return int(round(number))
|
||||||
|
return number
|
||||||
|
|
||||||
def handle_money(m: re.Match[str]) -> str:
|
def handle_money(m: re.Match[str]) -> str:
|
||||||
"""Convert money expressions to spoken form"""
|
"""Convert money expressions to spoken form"""
|
||||||
m = m.group()
|
|
||||||
bill = "dollar" if m[0] == "$" else "pound"
|
|
||||||
if m[-1].isalpha():
|
|
||||||
return f"{INFLECT_ENGINE.number_to_words(m[1:])} {bill}s"
|
|
||||||
elif "." not in m:
|
|
||||||
s = "" if m[1:] == "1" else "s"
|
|
||||||
return f"{INFLECT_ENGINE.number_to_words(m[1:])} {bill}{s}"
|
|
||||||
b, c = m[1:].split(".")
|
|
||||||
s = "" if b == "1" else "s"
|
|
||||||
c = int(c.ljust(2, "0"))
|
|
||||||
coins = (
|
|
||||||
f"cent{'' if c == 1 else 's'}"
|
|
||||||
if m[0] == "$"
|
|
||||||
else ("penny" if c == 1 else "pence")
|
|
||||||
)
|
|
||||||
return f"{INFLECT_ENGINE.number_to_words(b)} {bill}{s} and {INFLECT_ENGINE.number_to_words(c)} {coins}"
|
|
||||||
|
|
||||||
|
bill = "dollar" if m.group(2) == "$" else "pound"
|
||||||
|
coin = "cent" if m.group(2) == "$" else "pence"
|
||||||
|
number = m.group(3)
|
||||||
|
|
||||||
|
multiplier = m.group(4)
|
||||||
|
try:
|
||||||
|
number = float(number)
|
||||||
|
except:
|
||||||
|
return m.group()
|
||||||
|
|
||||||
|
if m.group(1) == "-":
|
||||||
|
number *= -1
|
||||||
|
|
||||||
|
if number % 1 == 0 or multiplier != "":
|
||||||
|
text_number = f"{INFLECT_ENGINE.number_to_words(conditional_int(number))}{multiplier} {INFLECT_ENGINE.plural(bill, count=number)}"
|
||||||
|
else:
|
||||||
|
sub_number = int(str(number).split(".")[-1].ljust(2, "0"))
|
||||||
|
|
||||||
|
text_number = f"{INFLECT_ENGINE.number_to_words(int(round(number)))} {INFLECT_ENGINE.plural(bill, count=number)} and {INFLECT_ENGINE.number_to_words(sub_number)} {INFLECT_ENGINE.plural(coin, count=sub_number)}"
|
||||||
|
|
||||||
|
return text_number
|
||||||
|
|
||||||
def handle_decimal(num: re.Match[str]) -> str:
|
def handle_decimal(num: re.Match[str]) -> str:
|
||||||
"""Convert decimal numbers to spoken form"""
|
"""Convert decimal numbers to spoken form"""
|
||||||
|
@ -297,7 +309,7 @@ def normalize_text(text: str,normalization_options: NormalizationOptions) -> str
|
||||||
text = re.sub(r"(?<=\d),(?=\d)", "", text)
|
text = re.sub(r"(?<=\d),(?=\d)", "", text)
|
||||||
|
|
||||||
text = re.sub(
|
text = re.sub(
|
||||||
r"(?i)[$£]\d+(?:\.\d+)?(?: hundred| thousand| (?:[bm]|tr)illion)*\b|[$£]\d+\.\d\d?\b",
|
r"(?i)(-?)([$£])(\d+(?:\.\d+)?)((?: hundred| thousand| (?:[bm]|tr|quadr)illion)*)\b",
|
||||||
handle_money,
|
handle_money,
|
||||||
text,
|
text,
|
||||||
)
|
)
|
||||||
|
|
|
@ -134,6 +134,7 @@ async def smart_split(
|
||||||
|
|
||||||
# Normalize text
|
# Normalize text
|
||||||
if settings.advanced_text_normalization and normalization_options.normalize:
|
if settings.advanced_text_normalization and normalization_options.normalize:
|
||||||
|
print(lang_code)
|
||||||
if lang_code in ["a","b","en-us","en-gb"]:
|
if lang_code in ["a","b","en-us","en-gb"]:
|
||||||
text = CUSTOM_PHONEMES.sub(lambda s: handle_custom_phonemes(s, custom_phoneme_list), text)
|
text = CUSTOM_PHONEMES.sub(lambda s: handle_custom_phonemes(s, custom_phoneme_list), text)
|
||||||
text=normalize_text(text,normalization_options)
|
text=normalize_text(text,normalization_options)
|
||||||
|
|
|
@ -258,7 +258,7 @@ class TTSService:
|
||||||
logger.info(f"Using lang_code '{pipeline_lang_code}' for voice '{voice_name}' in audio stream")
|
logger.info(f"Using lang_code '{pipeline_lang_code}' for voice '{voice_name}' in audio stream")
|
||||||
|
|
||||||
# Process text in chunks with smart splitting
|
# Process text in chunks with smart splitting
|
||||||
async for chunk_text, tokens in smart_split(text, lang_code=lang_code, normalization_options=normalization_options):
|
async for chunk_text, tokens in smart_split(text, lang_code=pipeline_lang_code, normalization_options=normalization_options):
|
||||||
try:
|
try:
|
||||||
# Process audio for chunk
|
# Process audio for chunk
|
||||||
async for chunk_data in self._process_chunk(
|
async for chunk_data in self._process_chunk(
|
||||||
|
|
|
@ -23,12 +23,11 @@ def test_initial_state(kokoro_backend):
|
||||||
|
|
||||||
|
|
||||||
@patch("torch.cuda.is_available", return_value=True)
|
@patch("torch.cuda.is_available", return_value=True)
|
||||||
@patch("torch.cuda.memory_allocated")
|
@patch("torch.cuda.memory_allocated", return_value=5e9)
|
||||||
def test_memory_management(mock_memory, mock_cuda, kokoro_backend):
|
def test_memory_management(mock_memory, mock_cuda, kokoro_backend):
|
||||||
"""Test GPU memory management functions."""
|
"""Test GPU memory management functions."""
|
||||||
# Mock GPU memory usage
|
# Patch backend so it thinks we have cuda
|
||||||
mock_memory.return_value = 5e9 # 5GB
|
with patch.object(kokoro_backend, "_device", "cuda"):
|
||||||
|
|
||||||
# Test memory check
|
# Test memory check
|
||||||
with patch("api.src.inference.kokoro_v1.model_config") as mock_config:
|
with patch("api.src.inference.kokoro_v1.model_config") as mock_config:
|
||||||
mock_config.pytorch_gpu.memory_threshold = 4
|
mock_config.pytorch_gpu.memory_threshold = 4
|
||||||
|
|
|
@ -83,6 +83,11 @@ def test_url_email_addresses():
|
||||||
== "Send to test dot user at site dot com"
|
== "Send to test dot user at site dot com"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_money():
|
||||||
|
"""Test that money text is normalized correctly"""
|
||||||
|
assert normalize_text("He lost $5.3 thousand.",normalization_options=NormalizationOptions()) == "He lost five point three thousand dollars."
|
||||||
|
assert normalize_text("To put it weirdly -$6.9 million",normalization_options=NormalizationOptions()) == "To put it weirdly minus six point nine million dollars"
|
||||||
|
assert normalize_text("It costs $50.3.",normalization_options=NormalizationOptions()) == "It costs fifty dollars and thirty cents."
|
||||||
|
|
||||||
def test_non_url_text():
|
def test_non_url_text():
|
||||||
"""Test that non-URL text is unaffected"""
|
"""Test that non-URL text is unaffected"""
|
||||||
|
|
|
@ -34,7 +34,7 @@ def test_process_text_chunk_phonemes():
|
||||||
def test_get_sentence_info():
|
def test_get_sentence_info():
|
||||||
"""Test sentence splitting and info extraction."""
|
"""Test sentence splitting and info extraction."""
|
||||||
text = "This is sentence one. This is sentence two! What about three?"
|
text = "This is sentence one. This is sentence two! What about three?"
|
||||||
results = get_sentence_info(text)
|
results = get_sentence_info(text, {})
|
||||||
|
|
||||||
assert len(results) == 3
|
assert len(results) == 3
|
||||||
for sentence, tokens, count in results:
|
for sentence, tokens, count in results:
|
||||||
|
@ -44,6 +44,19 @@ def test_get_sentence_info():
|
||||||
assert count == len(tokens)
|
assert count == len(tokens)
|
||||||
assert count > 0
|
assert count > 0
|
||||||
|
|
||||||
|
def test_get_sentence_info_phenomoes():
|
||||||
|
"""Test sentence splitting and info extraction."""
|
||||||
|
text = "This is sentence one. This is </|custom_phonemes_0|/> two! What about three?"
|
||||||
|
results = get_sentence_info(text, {"</|custom_phonemes_0|/>": r"sˈɛntᵊns"})
|
||||||
|
|
||||||
|
assert len(results) == 3
|
||||||
|
assert "sˈɛntᵊns" in results[1][0]
|
||||||
|
for sentence, tokens, count in results:
|
||||||
|
assert isinstance(sentence, str)
|
||||||
|
assert isinstance(tokens, list)
|
||||||
|
assert isinstance(count, int)
|
||||||
|
assert count == len(tokens)
|
||||||
|
assert count > 0
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_smart_split_short_text():
|
async def test_smart_split_short_text():
|
||||||
|
|
26
dev/Test money.py
Normal file
26
dev/Test money.py
Normal file
|
@ -0,0 +1,26 @@
|
||||||
|
import requests
|
||||||
|
import base64
|
||||||
|
import json
|
||||||
|
|
||||||
|
text="""the administration has offered up a platter of repression for more than a year and is still slated to lose $400 million.
|
||||||
|
|
||||||
|
Columbia is the largest private landowner in New York City and boasts an endowment of $14.8 billion;"""
|
||||||
|
|
||||||
|
|
||||||
|
Type="wav"
|
||||||
|
|
||||||
|
response = requests.post(
|
||||||
|
"http://localhost:8880/v1/audio/speech",
|
||||||
|
json={
|
||||||
|
"model": "kokoro",
|
||||||
|
"input": text,
|
||||||
|
"voice": "af_heart+af_sky",
|
||||||
|
"speed": 1.0,
|
||||||
|
"response_format": Type,
|
||||||
|
"stream": False,
|
||||||
|
},
|
||||||
|
stream=True
|
||||||
|
)
|
||||||
|
|
||||||
|
with open(f"outputnostreammoney.{Type}", "wb") as f:
|
||||||
|
f.write(response.content)
|
45
dev/Test num.py
Normal file
45
dev/Test num.py
Normal file
|
@ -0,0 +1,45 @@
|
||||||
|
from text_to_num import text2num
|
||||||
|
import re
|
||||||
|
import inflect
|
||||||
|
from torch import mul
|
||||||
|
|
||||||
|
INFLECT_ENGINE = inflect.engine()
|
||||||
|
|
||||||
|
|
||||||
|
def conditional_int(number: float, threshold: float = 0.00001):
|
||||||
|
if abs(round(number) - number) < threshold:
|
||||||
|
return int(round(number))
|
||||||
|
return number
|
||||||
|
|
||||||
|
def handle_money(m: re.Match[str]) -> str:
|
||||||
|
"""Convert money expressions to spoken form"""
|
||||||
|
|
||||||
|
bill = "dollar" if m.group(2) == "$" else "pound"
|
||||||
|
coin = "cent" if m.group(2) == "$" else "pence"
|
||||||
|
number = m.group(3)
|
||||||
|
|
||||||
|
multiplier = m.group(4)
|
||||||
|
try:
|
||||||
|
number = float(number)
|
||||||
|
except:
|
||||||
|
return m.group()
|
||||||
|
|
||||||
|
if m.group(1) == "-":
|
||||||
|
number *= -1
|
||||||
|
|
||||||
|
if number % 1 == 0 or multiplier != "":
|
||||||
|
text_number = f"{INFLECT_ENGINE.number_to_words(conditional_int(number))}{multiplier} {INFLECT_ENGINE.plural(bill, count=number)}"
|
||||||
|
else:
|
||||||
|
sub_number = int(str(number).split(".")[-1].ljust(2, "0"))
|
||||||
|
|
||||||
|
text_number = f"{INFLECT_ENGINE.number_to_words(int(round(number)))} {INFLECT_ENGINE.plural(bill, count=number)} and {INFLECT_ENGINE.number_to_words(sub_number)} {INFLECT_ENGINE.plural(coin, count=sub_number)}"
|
||||||
|
|
||||||
|
return text_number
|
||||||
|
|
||||||
|
|
||||||
|
text = re.sub(
|
||||||
|
r"(?i)(-?)([$£])(\d+(?:\.\d+)?)((?: hundred| thousand| (?:[bm]|tr|quadr)illion)*)\b",
|
||||||
|
handle_money,
|
||||||
|
"he administration has offered up a platter of repression for more than a year and is still slated to lose -$5.3 billion",
|
||||||
|
)
|
||||||
|
print(text)
|
|
@ -38,6 +38,7 @@ dependencies = [
|
||||||
"inflect>=7.5.0",
|
"inflect>=7.5.0",
|
||||||
"phonemizer-fork>=3.3.2",
|
"phonemizer-fork>=3.3.2",
|
||||||
"av>=14.2.0",
|
"av>=14.2.0",
|
||||||
|
"text2num>=2.5.1",
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
|
|
Loading…
Add table
Reference in a new issue