Mirror of https://github.com/remsky/Kokoro-FastAPI.git (synced 2025-08-05 16:48:53 +00:00)

Merge branch 'remsky:master' into fixes

Commit f2c5bc1b71: 3 changed files with 24 additions and 6 deletions
@@ -26,8 +26,8 @@ class StreamingAudioWriter:
         if self.format != "pcm":
             self.output_buffer = BytesIO()
             self.container = av.open(self.output_buffer, mode="w", format=self.format)
-            #print(av.codecs_available)
             self.stream = self.container.add_stream(codec_map[self.format],sample_rate=self.sample_rate,layout='mono' if self.channels == 1 else 'stereo')
+            self.stream.bit_rate = 128000
         else:
             raise ValueError(f"Unsupported format: {format}")

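The functional change in this hunk is the fixed encoder bitrate. A minimal standalone sketch, mirroring the `add_stream` call above but with assumed values (MP3 output, mono 24 kHz) in place of the writer's `codec_map` and constructor state:

    import av
    from io import BytesIO

    sample_rate, channels = 24000, 1  # assumed values for illustration

    output_buffer = BytesIO()
    container = av.open(output_buffer, mode="w", format="mp3")
    stream = container.add_stream(
        "mp3",  # stand-in for codec_map[self.format]
        sample_rate=sample_rate,
        layout="mono" if channels == 1 else "stereo",
    )
    # Without this, the encoder falls back to its own default bitrate;
    # the commit pins it at 128 kbit/s for consistent output size and quality.
    stream.bit_rate = 128000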
@@ -270,7 +270,6 @@ def normalize_text(text: str,normalization_options: NormalizationOptions) -> str
     text = text.replace(chr(8216), "'").replace(chr(8217), "'")
     text = text.replace("«", chr(8220)).replace("»", chr(8221))
     text = text.replace(chr(8220), '"').replace(chr(8221), '"')
-    text = text.replace("(", "«").replace(")", "»")

     # Handle CJK punctuation and some non standard chars
     for a, b in zip("、。!,:;?–", ",.!,:;?-"):

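The surviving replacements collapse typographic quotes down to ASCII in two steps; only the line that rewrote parentheses as guillemets is deleted, so `(` and `)` now pass through normalization untouched. A standalone sketch of the remaining chain (`collapse_quotes` is a stand-in name for this slice of `normalize_text`):

    def collapse_quotes(text: str) -> str:
        # U+2018/U+2019 (curly single quotes) -> ASCII apostrophe
        text = text.replace(chr(8216), "'").replace(chr(8217), "'")
        # Guillemets -> curly double quotes (U+201C/U+201D) ...
        text = text.replace("«", chr(8220)).replace("»", chr(8221))
        # ... which the next line folds down to ASCII double quotes
        text = text.replace(chr(8220), '"').replace(chr(8221), '"')
        return text

    print(collapse_quotes("«Hi», it’s me (really)"))  # "Hi", it's me (really)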
@@ -2,7 +2,7 @@

 import re
 import time
-from typing import AsyncGenerator, List, Tuple
+from typing import AsyncGenerator, Dict, List, Tuple

 from loguru import logger

@@ -12,6 +12,9 @@ from .phonemizer import phonemize
 from .vocabulary import tokenize
 from ...structures.schemas import NormalizationOptions

+# Pre-compiled regex patterns for performance
+CUSTOM_PHONEMES = re.compile(r"(\[([^\]]|\n)*?\])(\(\/([^\/)]|\n)*?\/\))")
+
 def process_text_chunk(
     text: str, language: str = "a", skip_phonemize: bool = False
 ) -> List[int]:

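The new `CUSTOM_PHONEMES` pattern captures inline pronunciation markup of the form `[text](/phonemes/)`, with the bracketed text in group 1 and the `(/.../)` phoneme block in group 3. A quick check with an invented sample sentence:

    import re

    CUSTOM_PHONEMES = re.compile(r"(\[([^\]]|\n)*?\])(\(\/([^\/)]|\n)*?\/\))")

    sample = "Say [Kokoro](/kˈOkəɹO/) slowly."
    m = CUSTOM_PHONEMES.search(sample)
    print(m.group(1))  # [Kokoro]
    print(m.group(3))  # (/kˈOkəɹO/)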
@@ -85,12 +88,21 @@ def process_text(text: str, language: str = "a") -> List[int]:
     return process_text_chunk(text, language)

-def get_sentence_info(text: str) -> List[Tuple[str, List[int], int]]:
+def get_sentence_info(text: str, custom_phenomes_list: Dict[str, str]) -> List[Tuple[str, List[int], int]]:
     """Process all sentences and return info."""
     sentences = re.split(r"([.!?;:])(?=\s|$)", text)
+    phoneme_length, min_value = len(custom_phenomes_list), 0

     results = []
     for i in range(0, len(sentences), 2):
         sentence = sentences[i].strip()
+        for replaced in range(min_value, phoneme_length):
+            current_id = f"</|custom_phonemes_{replaced}|/>"
+            if current_id in sentence:
+                sentence = sentence.replace(current_id, custom_phenomes_list.pop(current_id))
+                min_value += 1
+
         punct = sentences[i + 1] if i + 1 < len(sentences) else ""

         if not sentence:

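In isolation, the loop added to `get_sentence_info` swaps each placeholder id back for its stashed markup: `pop` ensures an id is consumed exactly once, and `min_value` lets later sentences skip ids already restored. A minimal sketch with made-up entries:

    custom_phenomes_list = {
        "</|custom_phonemes_0|/>": "[Kokoro](/kˈOkəɹO/)",
        "</|custom_phonemes_1|/>": "[FastAPI](/fˈæstæpɪ/)",
    }
    sentence = "Try </|custom_phonemes_0|/> with </|custom_phonemes_1|/>."
    phoneme_length, min_value = len(custom_phenomes_list), 0

    for replaced in range(min_value, phoneme_length):
        current_id = f"</|custom_phonemes_{replaced}|/>"
        if current_id in sentence:
            # pop() restores the original markup and retires the placeholder id
            sentence = sentence.replace(current_id, custom_phenomes_list.pop(current_id))
            min_value += 1

    print(sentence)  # Try [Kokoro](/kˈOkəɹO/) with [FastAPI](/fˈæstæpɪ/).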
@@ -102,6 +114,10 @@ def get_sentence_info(text: str) -> List[Tuple[str, List[int], int]]:
     return results

+def handle_custom_phonemes(s: re.Match[str], phenomes_list: Dict[str,str]) -> str:
+    latest_id = f"</|custom_phonemes_{len(phenomes_list)}|/>"
+    phenomes_list[latest_id] = s.group(0).strip()
+    return latest_id

 async def smart_split(
     text: str,

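`handle_custom_phonemes` is the stashing side of that exchange: used as a `re.sub` callback, it replaces each match with a sequential placeholder id and records the original markup. A sketch of this first half of the round trip, reusing the pattern from above:

    import re
    from typing import Dict

    CUSTOM_PHONEMES = re.compile(r"(\[([^\]]|\n)*?\])(\(\/([^\/)]|\n)*?\/\))")

    def handle_custom_phonemes(s: re.Match, phenomes_list: Dict[str, str]) -> str:
        latest_id = f"</|custom_phonemes_{len(phenomes_list)}|/>"
        phenomes_list[latest_id] = s.group(0).strip()
        return latest_id

    stash: Dict[str, str] = {}
    masked = CUSTOM_PHONEMES.sub(
        lambda s: handle_custom_phonemes(s, stash),
        "Read [Kokoro](/kˈOkəɹO/) aloud.",
    )
    print(masked)  # Read </|custom_phonemes_0|/> aloud.
    print(stash)   # {'</|custom_phonemes_0|/>': '[Kokoro](/kˈOkəɹO/)'}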
@@ -114,15 +130,18 @@ async def smart_split(
     chunk_count = 0
     logger.info(f"Starting smart split for {len(text)} chars")

+    custom_phoneme_list = {}
+
     # Normalize text
     if settings.advanced_text_normalization and normalization_options.normalize:
         if lang_code in ["a","b","en-us","en-gb"]:
+            text = CUSTOM_PHONEMES.sub(lambda s: handle_custom_phonemes(s, custom_phoneme_list), text)
             text=normalize_text(text,normalization_options)
         else:
             logger.info("Skipping text normalization as it is only supported for english")

     # Process all sentences
-    sentences = get_sentence_info(text)
+    sentences = get_sentence_info(text, custom_phoneme_list)

     current_chunk = []
     current_tokens = []

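Taken together, the changes to this file form one mask-then-restore pipeline: `smart_split` stashes custom phoneme markup behind placeholder ids before `normalize_text` runs, so normalization cannot mangle the bracket and parenthesis syntax, and `get_sentence_info` swaps the markup back in after the text has been split into sentences. The placeholder ids appear to be built from characters the English normalizer is unlikely to rewrite.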