Mirror of https://github.com/remsky/Kokoro-FastAPI.git (synced 2025-08-05 16:48:53 +00:00)

Merge pull request #256 from fireblade2534/Fixing-number-normalization

Some checks failed: CI / test (3.10) (push) has been cancelled

Commit fe99bb7697

15 changed files with 165 additions and 36 deletions

README.md (31 lines changed)
@@ -516,7 +516,36 @@ Monitor system state and resource usage with these endpoints:

Useful for debugging resource exhaustion or performance issues.

</details>

-## Known Issues
+## Known Issues & Troubleshooting

<details>
<summary>Missing words & Missing some timestamps</summary>

The API will automatically apply text normalization to the input text, which may incorrectly remove or change some phrases. This can be disabled by adding `"normalization_options":{"normalize": false}` to your request JSON:

```python
import requests

response = requests.post(
    "http://localhost:8880/v1/audio/speech",
    json={
        "input": "Hello world!",
        "voice": "af_heart",
        "response_format": "pcm",
        "normalization_options":
        {
            "normalize": False
        }
    },
    stream=True
)

for chunk in response.iter_content(chunk_size=1024):
    if chunk:
        # Process streaming chunks
        pass
```

</details>

<details>
<summary>Versioning & Development</summary>
@@ -125,20 +125,18 @@ async def process_and_validate_voices(voice_input: Union[str, List[str]], tts_se

async def stream_audio_chunks(tts_service: TTSService, request: Union[OpenAISpeechRequest, CaptionedSpeechRequest], client_request: Request, writer: StreamingAudioWriter) -> AsyncGenerator[AudioChunk, None]:
    """Stream audio chunks as they're generated with client disconnect handling"""
    voice_name = await process_and_validate_voices(request.voice, tts_service)

    unique_properties = {"return_timestamps": False}
    if hasattr(request, "return_timestamps"):
        unique_properties["return_timestamps"] = request.return_timestamps

    try:
        logger.info(f"Starting audio generation with lang_code: {request.lang_code}")
        async for chunk_data in tts_service.generate_audio_stream(
            text=request.input,
            voice=voice_name,
            writer=writer,
            speed=request.speed,
            output_format=request.response_format,
-            lang_code=request.lang_code or settings.default_voice_code or voice_name[0].lower(),
+            lang_code=request.lang_code,
            normalization_options=request.normalization_options,
            return_timestamps=unique_properties["return_timestamps"],
        ):
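Since the router now passes `request.lang_code` through unchanged (any fallback happens in the service layer), a client that needs a specific pipeline can set the language explicitly. A minimal client-side sketch, assuming a local server and the optional `lang_code` field on the speech request shown above:

```python
import requests

# Pin the pipeline language instead of relying on a server-side fallback
# derived from the voice name ('a' is Kokoro's American English code).
response = requests.post(
    "http://localhost:8880/v1/audio/speech",
    json={
        "input": "Hello world!",
        "voice": "af_heart",
        "response_format": "wav",
        "lang_code": "a",
    },
)

with open("output.wav", "wb") as f:
    f.write(response.content)
```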
@@ -25,7 +25,7 @@ class StreamingAudioWriter:
        if self.format in ["wav","flac","mp3","pcm","aac","opus"]:
            if self.format != "pcm":
                self.output_buffer = BytesIO()
-                self.container = av.open(self.output_buffer, mode="w", format=self.format)
+                self.container = av.open(self.output_buffer, mode="w", format=self.format if self.format != "aac" else "adts")
                self.stream = self.container.add_stream(codec_map[self.format],sample_rate=self.sample_rate,layout='mono' if self.channels == 1 else 'stereo')
                self.stream.bit_rate = 128000
        else:
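The changed line above routes raw AAC output through PyAV's ADTS muxer, since raw `.aac` has no standalone container format. A minimal sketch of the same container setup, mirroring the call pattern in the diff rather than reproducing the project code:

```python
import av
from io import BytesIO

fmt = "aac"
buffer = BytesIO()

# AAC is written as an ADTS stream; every other supported format keeps its own name.
container = av.open(buffer, mode="w", format=fmt if fmt != "aac" else "adts")
stream = container.add_stream("aac", sample_rate=24000, layout="mono")
stream.bit_rate = 128000

# ...encode resampled audio frames and mux the resulting packets here,
# as the streaming writer does...

container.close()
```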
@@ -8,9 +8,11 @@ import re
from functools import lru_cache

import inflect
from numpy import number
from torch import mul
from ...structures.schemas import NormalizationOptions
from text_to_num import text2num

# Constants
VALID_TLDS = [
    "com",
@@ -134,25 +136,35 @@ def handle_units(u: re.Match[str]) -> str:
        unit[0]=INFLECT_ENGINE.no(unit[0],number)
    return " ".join(unit)

def conditional_int(number: float, threshold: float = 0.00001):
    if abs(round(number) - number) < threshold:
        return int(round(number))
    return number

def handle_money(m: re.Match[str]) -> str:
    """Convert money expressions to spoken form"""
-    m = m.group()
-    bill = "dollar" if m[0] == "$" else "pound"
-    if m[-1].isalpha():
-        return f"{INFLECT_ENGINE.number_to_words(m[1:])} {bill}s"
-    elif "." not in m:
-        s = "" if m[1:] == "1" else "s"
-        return f"{INFLECT_ENGINE.number_to_words(m[1:])} {bill}{s}"
-    b, c = m[1:].split(".")
-    s = "" if b == "1" else "s"
-    c = int(c.ljust(2, "0"))
-    coins = (
-        f"cent{'' if c == 1 else 's'}"
-        if m[0] == "$"
-        else ("penny" if c == 1 else "pence")
-    )
-    return f"{INFLECT_ENGINE.number_to_words(b)} {bill}{s} and {INFLECT_ENGINE.number_to_words(c)} {coins}"
+    bill = "dollar" if m.group(2) == "$" else "pound"
+    coin = "cent" if m.group(2) == "$" else "pence"
+    number = m.group(3)
+
+    multiplier = m.group(4)
+    try:
+        number = float(number)
+    except:
+        return m.group()
+
+    if m.group(1) == "-":
+        number *= -1
+
+    if number % 1 == 0 or multiplier != "":
+        text_number = f"{INFLECT_ENGINE.number_to_words(conditional_int(number))}{multiplier} {INFLECT_ENGINE.plural(bill, count=number)}"
+    else:
+        sub_number = int(str(number).split(".")[-1].ljust(2, "0"))
+
+        text_number = f"{INFLECT_ENGINE.number_to_words(int(round(number)))} {INFLECT_ENGINE.plural(bill, count=number)} and {INFLECT_ENGINE.number_to_words(sub_number)} {INFLECT_ENGINE.plural(coin, count=sub_number)}"
+
+    return text_number

def handle_decimal(num: re.Match[str]) -> str:
    """Convert decimal numbers to spoken form"""
@@ -297,7 +309,7 @@ def normalize_text(text: str,normalization_options: NormalizationOptions) -> str
    text = re.sub(r"(?<=\d),(?=\d)", "", text)

    text = re.sub(
-        r"(?i)[$£]\d+(?:\.\d+)?(?: hundred| thousand| (?:[bm]|tr)illion)*\b|[$£]\d+\.\d\d?\b",
+        r"(?i)(-?)([$£])(\d+(?:\.\d+)?)((?: hundred| thousand| (?:[bm]|tr|quadr)illion)*)\b",
        handle_money,
        text,
    )
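For reference (not part of the diff), the four capture groups in the new pattern are what `handle_money` reads back as sign, currency symbol, numeric value, and scale word; the variable name below is just for the example:

```python
import re

MONEY_PATTERN = r"(?i)(-?)([$£])(\d+(?:\.\d+)?)((?: hundred| thousand| (?:[bm]|tr|quadr)illion)*)\b"

match = re.search(MONEY_PATTERN, "slated to lose -$5.3 billion")
print(match.groups())  # ('-', '$', '5.3', ' billion')
```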
@@ -134,6 +134,7 @@ async def smart_split(

    # Normalize text
    if settings.advanced_text_normalization and normalization_options.normalize:
        print(lang_code)
        if lang_code in ["a","b","en-us","en-gb"]:
            text = CUSTOM_PHONEMES.sub(lambda s: handle_custom_phonemes(s, custom_phoneme_list), text)
            text=normalize_text(text,normalization_options)
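In other words, advanced normalization now only runs when the pipeline language is one of the English codes; other languages pass through untouched. A condensed restatement with hypothetical helper names, not the project code:

```python
# Hypothetical helper mirroring the condition added to smart_split.
ENGLISH_LANG_CODES = {"a", "b", "en-us", "en-gb"}

def should_normalize(lang_code: str, normalize_requested: bool, advanced_enabled: bool = True) -> bool:
    return advanced_enabled and normalize_requested and lang_code in ENGLISH_LANG_CODES

print(should_normalize("a", True))  # True  (American English pipeline)
print(should_normalize("j", True))  # False (non-English pipeline skips English normalization)
```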
@@ -258,7 +258,7 @@ class TTSService:
            logger.info(f"Using lang_code '{pipeline_lang_code}' for voice '{voice_name}' in audio stream")

            # Process text in chunks with smart splitting
-            async for chunk_text, tokens in smart_split(text, lang_code=lang_code, normalization_options=normalization_options):
+            async for chunk_text, tokens in smart_split(text, lang_code=pipeline_lang_code, normalization_options=normalization_options):
                try:
                    # Process audio for chunk
                    async for chunk_data in self._process_chunk(
@@ -23,12 +23,11 @@ def test_initial_state(kokoro_backend):


@patch("torch.cuda.is_available", return_value=True)
-@patch("torch.cuda.memory_allocated")
+@patch("torch.cuda.memory_allocated", return_value=5e9)
def test_memory_management(mock_memory, mock_cuda, kokoro_backend):
    """Test GPU memory management functions."""
    # Mock GPU memory usage
-    mock_memory.return_value = 5e9  # 5GB

    # Patch backend so it thinks we have cuda
    with patch.object(kokoro_backend, "_device", "cuda"):
        # Test memory check
        with patch("api.src.inference.kokoro_v1.model_config") as mock_config:
            mock_config.pytorch_gpu.memory_threshold = 4
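The tweak above is just the two equivalent ways of stubbing a call with `unittest.mock`: configuring `return_value` in the decorator versus assigning it inside the test body. A small sketch, assuming `torch` is installed so the attribute exists to be patched:

```python
import torch
from unittest.mock import patch

@patch("torch.cuda.memory_allocated", return_value=5e9)
def check_decorator_style(mock_memory):
    # return_value configured up front in the decorator
    assert torch.cuda.memory_allocated() == 5e9

@patch("torch.cuda.memory_allocated")
def check_body_style(mock_memory):
    # same effect, configured after patching
    mock_memory.return_value = 5e9
    assert torch.cuda.memory_allocated() == 5e9

check_decorator_style()
check_body_style()
```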
@@ -83,6 +83,11 @@ def test_url_email_addresses():
        == "Send to test dot user at site dot com"
    )

def test_money():
    """Test that money text is normalized correctly"""
    assert normalize_text("He lost $5.3 thousand.",normalization_options=NormalizationOptions()) == "He lost five point three thousand dollars."
    assert normalize_text("To put it weirdly -$6.9 million",normalization_options=NormalizationOptions()) == "To put it weirdly minus six point nine million dollars"
    assert normalize_text("It costs $50.3.",normalization_options=NormalizationOptions()) == "It costs fifty dollars and thirty cents."

def test_non_url_text():
    """Test that non-URL text is unaffected"""
@@ -34,7 +34,7 @@ def test_process_text_chunk_phonemes():
def test_get_sentence_info():
    """Test sentence splitting and info extraction."""
    text = "This is sentence one. This is sentence two! What about three?"
-    results = get_sentence_info(text)
+    results = get_sentence_info(text, {})

    assert len(results) == 3
    for sentence, tokens, count in results:

@@ -44,6 +44,19 @@ def test_get_sentence_info():
        assert count == len(tokens)
        assert count > 0

def test_get_sentence_info_phenomoes():
    """Test sentence splitting and info extraction."""
    text = "This is sentence one. This is </|custom_phonemes_0|/> two! What about three?"
    results = get_sentence_info(text, {"</|custom_phonemes_0|/>": r"sˈɛntᵊns"})

    assert len(results) == 3
    assert "sˈɛntᵊns" in results[1][0]
    for sentence, tokens, count in results:
        assert isinstance(sentence, str)
        assert isinstance(tokens, list)
        assert isinstance(count, int)
        assert count == len(tokens)
        assert count > 0

@pytest.mark.asyncio
async def test_smart_split_short_text():
dev/Test money.py (new file, 26 lines)
@@ -0,0 +1,26 @@
import requests
import base64
import json

text="""the administration has offered up a platter of repression for more than a year and is still slated to lose $400 million.

Columbia is the largest private landowner in New York City and boasts an endowment of $14.8 billion;"""


Type="wav"

response = requests.post(
    "http://localhost:8880/v1/audio/speech",
    json={
        "model": "kokoro",
        "input": text,
        "voice": "af_heart+af_sky",
        "speed": 1.0,
        "response_format": Type,
        "stream": False,
    },
    stream=True
)

with open(f"outputnostreammoney.{Type}", "wb") as f:
    f.write(response.content)
dev/Test num.py (new file, 45 lines)
@@ -0,0 +1,45 @@
from text_to_num import text2num
import re
import inflect
from torch import mul

INFLECT_ENGINE = inflect.engine()


def conditional_int(number: float, threshold: float = 0.00001):
    if abs(round(number) - number) < threshold:
        return int(round(number))
    return number

def handle_money(m: re.Match[str]) -> str:
    """Convert money expressions to spoken form"""

    bill = "dollar" if m.group(2) == "$" else "pound"
    coin = "cent" if m.group(2) == "$" else "pence"
    number = m.group(3)

    multiplier = m.group(4)
    try:
        number = float(number)
    except:
        return m.group()

    if m.group(1) == "-":
        number *= -1

    if number % 1 == 0 or multiplier != "":
        text_number = f"{INFLECT_ENGINE.number_to_words(conditional_int(number))}{multiplier} {INFLECT_ENGINE.plural(bill, count=number)}"
    else:
        sub_number = int(str(number).split(".")[-1].ljust(2, "0"))

        text_number = f"{INFLECT_ENGINE.number_to_words(int(round(number)))} {INFLECT_ENGINE.plural(bill, count=number)} and {INFLECT_ENGINE.number_to_words(sub_number)} {INFLECT_ENGINE.plural(coin, count=sub_number)}"

    return text_number


text = re.sub(
    r"(?i)(-?)([$£])(\d+(?:\.\d+)?)((?: hundred| thousand| (?:[bm]|tr|quadr)illion)*)\b",
    handle_money,
    "he administration has offered up a platter of repression for more than a year and is still slated to lose -$5.3 billion",
)
print(text)
@@ -38,6 +38,7 @@ dependencies = [
    "inflect>=7.5.0",
    "phonemizer-fork>=3.3.2",
    "av>=14.2.0",
    "text2num>=2.5.1",
]

[project.optional-dependencies]
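The new `text2num` dependency provides the `text_to_num.text2num` function imported by the normalizer. A minimal usage sketch, independent of the project code and assuming the library's documented word-to-number API:

```python
from text_to_num import text2num

# Convert spelled-out English numbers to integers.
print(text2num("fifty-one", "en"))     # 51
print(text2num("one thousand", "en"))  # 1000
```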