[WIP] Refactor LabelledString and relevant classes

This commit is contained in:
YishiMichael 2022-05-03 23:39:37 +08:00
parent 03cb42ba15
commit ab8f78f40f
No known key found for this signature in database
GPG key ID: EC615C0C5A86BC80
5 changed files with 348 additions and 299 deletions

View file

@ -213,8 +213,9 @@ class AddTextWordByWord(ShowIncreasingSubsets):
def __init__(self, string_mobject, **kwargs): def __init__(self, string_mobject, **kwargs):
assert isinstance(string_mobject, LabelledString) assert isinstance(string_mobject, LabelledString)
grouped_mobject = VGroup(*[ grouped_mobject = string_mobject.build_parts_from_indices_lists([
part for _, part in string_mobject.get_group_part_items() indices_list
for _, indices_list in string_mobject.get_group_part_items()
]) ])
digest_config(self, kwargs) digest_config(self, kwargs)
if self.run_time is None: if self.run_time is None:

View file

@ -168,73 +168,67 @@ class TransformMatchingStrings(AnimationGroup):
assert isinstance(source, LabelledString) assert isinstance(source, LabelledString)
assert isinstance(target, LabelledString) assert isinstance(target, LabelledString)
anims = [] anims = []
source_indices = list(range(len(source.labels)))
target_indices = list(range(len(target.labels)))
source_submobs = [ def get_filtered_indices_lists(indices_lists, rest_indices):
submob for _, submob in source.labelled_submobject_items
]
target_submobs = [
submob for _, submob in target.labelled_submobject_items
]
source_indices = list(range(len(source_submobs)))
target_indices = list(range(len(target_submobs)))
def get_filtered_indices_lists(parts, submobs, rest_indices):
return list(filter( return list(filter(
lambda indices_list: all([ lambda indices_list: all(
index in rest_indices index in rest_indices
for index in indices_list for index in indices_list
]), ),
[ indices_lists
[submobs.index(submob) for submob in part]
for part in parts
]
)) ))
def add_anims(anim_class, parts_pairs): def add_anims(anim_class, indices_lists_pairs):
for source_parts, target_parts in parts_pairs: for source_indices_lists, target_indices_lists in indices_lists_pairs:
source_indices_lists = get_filtered_indices_lists( source_indices_lists = get_filtered_indices_lists(
source_parts, source_submobs, source_indices source_indices_lists, source_indices
) )
target_indices_lists = get_filtered_indices_lists( target_indices_lists = get_filtered_indices_lists(
target_parts, target_submobs, target_indices target_indices_lists, target_indices
) )
if not source_indices_lists or not target_indices_lists: if not source_indices_lists or not target_indices_lists:
continue continue
anims.append(anim_class(source_parts, target_parts, **kwargs)) anims.append(anim_class(
source.build_parts_from_indices_lists(source_indices_lists),
target.build_parts_from_indices_lists(target_indices_lists),
**kwargs
))
for index in it.chain(*source_indices_lists): for index in it.chain(*source_indices_lists):
source_indices.remove(index) source_indices.remove(index)
for index in it.chain(*target_indices_lists): for index in it.chain(*target_indices_lists):
target_indices.remove(index) target_indices.remove(index)
def get_substr_to_parts_map(part_items): def get_substr_to_indices_lists_map(part_items):
result = {} result = {}
for substr, part in part_items: for substr, indices_list in part_items:
if substr not in result: if substr not in result:
result[substr] = [] result[substr] = []
result[substr].append(part) result[substr].append(indices_list)
return result return result
def add_anims_from(anim_class, func): def add_anims_from(anim_class, func):
source_substr_to_parts_map = get_substr_to_parts_map(func(source)) source_substr_map = get_substr_to_indices_lists_map(func(source))
target_substr_to_parts_map = get_substr_to_parts_map(func(target)) target_substr_map = get_substr_to_indices_lists_map(func(target))
common_substrings = sorted([
s for s in source_substr_map if s and s in target_substr_map
], key=len, reverse=True)
add_anims( add_anims(
anim_class, anim_class,
[ [
( (source_substr_map[substr], target_substr_map[substr])
VGroup(*source_substr_to_parts_map[substr]), for substr in common_substrings
VGroup(*target_substr_to_parts_map[substr])
)
for substr in sorted([
s for s in source_substr_to_parts_map
if s and s in target_substr_to_parts_map
], key=len, reverse=True)
] ]
) )
add_anims( add_anims(
ReplacementTransform, ReplacementTransform,
[ [
(source.select_parts(k), target.select_parts(v)) (
source.get_submob_indices_lists_by_selector(k),
target.get_submob_indices_lists_by_selector(v)
)
for k, v in self.key_map.items() for k, v in self.key_map.items()
] ]
) )

View file

@ -5,6 +5,7 @@ import itertools as it
import re import re
from manimlib.constants import WHITE from manimlib.constants import WHITE
from manimlib.logger import log
from manimlib.mobject.svg.svg_mobject import SVGMobject from manimlib.mobject.svg.svg_mobject import SVGMobject
from manimlib.mobject.types.vectorized_mobject import VGroup from manimlib.mobject.types.vectorized_mobject import VGroup
from manimlib.utils.color import color_to_rgb from manimlib.utils.color import color_to_rgb
@ -63,6 +64,7 @@ class LabelledString(SVGMobject, ABC):
(submob.label, submob) (submob.label, submob)
for submob in self.submobjects for submob in self.submobjects
] ]
self.labels = [label for label, _ in self.labelled_submobject_items]
def get_file_path(self) -> str: def get_file_path(self) -> str:
return self.get_file_path_by_content(self.original_content) return self.get_file_path_by_content(self.original_content)
@ -78,26 +80,30 @@ class LabelledString(SVGMobject, ABC):
labelled_svg = SVGMobject(file_path) labelled_svg = SVGMobject(file_path)
num_submobjects = len(self.submobjects) num_submobjects = len(self.submobjects)
if num_submobjects != len(labelled_svg.submobjects): if num_submobjects != len(labelled_svg.submobjects):
raise ValueError( log.warning(
"Cannot align submobjects of the labelled svg " "Cannot align submobjects of the labelled svg "
"to the original svg" "to the original svg. Skip the labelling process."
) )
submob_color_ints = [0] * num_submobjects
else:
submob_color_ints = [ submob_color_ints = [
self.hex_to_int(self.color_to_hex(submob.get_fill_color())) self.hex_to_int(self.color_to_hex(submob.get_fill_color()))
for submob in labelled_svg.submobjects for submob in labelled_svg.submobjects
] ]
unrecognized_color_ints = self.remove_redundancies(sorted(filter( unrecognized_colors = list(filter(
lambda color_int: color_int > len(self.label_span_list), lambda color_int: color_int > len(self.labelled_spans),
submob_color_ints submob_color_ints
))) ))
if unrecognized_color_ints: if unrecognized_colors:
raise ValueError( log.warning(
"Unrecognized color label(s) detected: " "Unrecognized color label(s) detected (%s, etc). "
f"{', '.join(map(self.int_to_hex, unrecognized_color_ints))}" "Skip the labelling process.",
self.int_to_hex(unrecognized_colors[0])
) )
submob_color_ints = [0] * num_submobjects
#if self.sort_labelled_submobs: #if self.sort_labelled_submobs:
# TODO: remove this
submob_indices = sorted( submob_indices = sorted(
range(num_submobjects), range(num_submobjects),
key=lambda index: tuple( key=lambda index: tuple(
@ -135,12 +141,10 @@ class LabelledString(SVGMobject, ABC):
# pattern = re.compile(pattern) # pattern = re.compile(pattern)
# return re.compile(pattern).match(self.string, **kwargs) # return re.compile(pattern).match(self.string, **kwargs)
def find_spans(self, pattern: str | re.Pattern, **kwargs) -> list[Span]: def find_spans(self, pattern: str) -> list[Span]:
if isinstance(pattern, str):
pattern = re.compile(pattern)
return [ return [
match_obj.span() match_obj.span()
for match_obj in pattern.finditer(self.string, **kwargs) for match_obj in re.finditer(pattern, self.string)
] ]
#def find_indices(self, pattern: str | re.Pattern, **kwargs) -> list[int]: #def find_indices(self, pattern: str | re.Pattern, **kwargs) -> list[int]:
@ -151,7 +155,18 @@ class LabelledString(SVGMobject, ABC):
if isinstance(sel, str): if isinstance(sel, str):
return self.find_spans(re.escape(sel)) return self.find_spans(re.escape(sel))
if isinstance(sel, re.Pattern): if isinstance(sel, re.Pattern):
return self.find_spans(sel) result_iterator = sel.finditer(self.string)
if not sel.groups:
return [
match_obj.span()
for match_obj in result_iterator
]
return [
span
for match_obj in result_iterator
for span in match_obj.regs[1:]
if span != (-1, -1)
]
if isinstance(sel, tuple) and len(sel) == 2 and all( if isinstance(sel, tuple) and len(sel) == 2 and all(
isinstance(index, int) or index is None isinstance(index, int) or index is None
for index in sel for index in sel
@ -225,7 +240,7 @@ class LabelledString(SVGMobject, ABC):
def span_contains(span_0: Span, span_1: Span) -> bool: def span_contains(span_0: Span, span_1: Span) -> bool:
return span_0[0] <= span_1[0] and span_0[1] >= span_1[1] return span_0[0] <= span_1[0] and span_0[1] >= span_1[1]
def get_piece_items( def get_level_items(
self, self,
tag_span_pairs: list[tuple[Span, Span]], tag_span_pairs: list[tuple[Span, Span]],
entity_spans: list[Span] entity_spans: list[Span]
@ -241,7 +256,7 @@ class LabelledString(SVGMobject, ABC):
piece_levels = [0, *it.accumulate([tag for _, tag in tagged_items])] piece_levels = [0, *it.accumulate([tag for _, tag in tagged_items])]
return piece_spans, piece_levels return piece_spans, piece_levels
def split_span(self, arbitrary_span: Span) -> list[Span]: def split_span_by_levels(self, arbitrary_span: Span) -> list[Span]:
# ignorable_indices -- # ignorable_indices --
# left_bracket_spans # left_bracket_spans
# right_bracket_spans # right_bracket_spans
@ -413,10 +428,10 @@ class LabelledString(SVGMobject, ABC):
@staticmethod @staticmethod
@abstractmethod @abstractmethod
def get_tag_str( def get_tag_string_pair(
attr_dict: dict[str, str], escape_color_keys: bool, is_begin_tag: bool attr_dict: dict[str, str], label_hex: str | None
) -> str: ) -> tuple[str, str]:
return "" return ("", "")
#def get_color_tag_str(self, rgb_int: int, is_begin_tag: bool) -> str: #def get_color_tag_str(self, rgb_int: int, is_begin_tag: bool) -> str:
# return self.get_tag_str({ # return self.get_tag_str({
@ -481,7 +496,7 @@ class LabelledString(SVGMobject, ABC):
def parse(self) -> None: def parse(self) -> None:
self.entity_spans = self.get_entity_spans() self.entity_spans = self.get_entity_spans()
tag_span_pairs, internal_items = self.get_internal_items() tag_span_pairs, internal_items = self.get_internal_items()
self.piece_spans, self.piece_levels = self.get_piece_items( self.piece_spans, self.piece_levels = self.get_level_items(
tag_span_pairs, self.entity_spans tag_span_pairs, self.entity_spans
) )
#self.tag_content_spans = [ #self.tag_content_spans = [
@ -497,26 +512,19 @@ class LabelledString(SVGMobject, ABC):
for span in self.find_spans_by_selector(self.isolate) for span in self.find_spans_by_selector(self.isolate)
] ]
) )
print(f"\n{specified_items=}\n") #print(f"\n{specified_items=}\n")
specified_spans = [span for span, _ in specified_items] #specified_spans =
for span_0, span_1 in it.product(specified_spans, repeat=2):
if not span_0[0] < span_1[0] < span_0[1] < span_1[1]:
continue
raise ValueError(
"Partially overlapping substrings detected: "
f"'{self.get_substr(span_0)}' and '{self.get_substr(span_1)}'"
)
split_items = [ split_items = [
(span, attr_dict) (span, attr_dict)
for specified_span, attr_dict in specified_items for specified_span, attr_dict in specified_items
for span in self.split_span(specified_span) for span in self.split_span_by_levels(specified_span)
] ]
print(f"\n{split_items=}\n") #print(f"\n{split_items=}\n")
split_spans = [span for span, _ in split_items] #labelled_spans = [span for span, _ in split_items]
label_span_list = self.get_label_span_list(split_spans) #labelled_spans = self.get_labelled_spans(split_spans)
if len(label_span_list) >= 16777216: #if len(labelled_spans) >= 16777216:
raise ValueError("Cannot handle that many substrings") # raise ValueError("Cannot handle that many substrings")
#content_strings = [] #content_strings = []
#for is_labelled in (False, True): #for is_labelled in (False, True):
@ -549,17 +557,66 @@ class LabelledString(SVGMobject, ABC):
# for flag in range(2) # for flag in range(2)
#] #]
self.specified_spans = specified_spans command_repl_items = self.get_command_repl_items()
self.label_span_list = label_span_list
self.original_content = self.get_full_content_string( #full_content_strings = {}
label_span_list, split_items, is_labelled=False #for is_labelled in (False, True):
# inserted_str_pairs = [
# (span, self.get_tag_string_pair(
# attr_dict,
# rgb_hex=self.int_to_hex(label + 1) if is_labelled else None
# ))
# for label, (span, attr_dict) in enumerate(split_items)
# ]
# repl_items = self.chain(
# command_repl_items,
# [
# ((index, index), inserted_str)
# for index, inserted_str
# in self.sort_obj_pairs_by_spans(inserted_str_pairs)
# ]
# )
# content_string = self.get_replaced_substr(
# self.full_span, repl_items
# )
# full_content_string = self.get_full_content_string(content_string)
# #full_content_strings[is_labelled] = full_content_string
self.specified_spans = [span for span, _ in specified_items]
self.labelled_spans = [span for span, _ in split_items]
for span_0, span_1 in it.product(self.labelled_spans, repeat=2):
if not span_0[0] < span_1[0] < span_0[1] < span_1[1]:
continue
raise ValueError(
"Partially overlapping substrings detected: "
f"'{self.get_substr(span_0)}' and '{self.get_substr(span_1)}'"
) )
self.labelled_content = self.get_full_content_string(
label_span_list, split_items, is_labelled=True self.original_content, self.labelled_content = (
self.get_full_content_string(self.get_replaced_substr(
self.full_span, self.chain(
command_repl_items,
[
((index, index), inserted_str)
for index, inserted_str in self.sort_obj_pairs_by_spans([
(span, self.get_tag_string_pair(
attr_dict,
label_hex=self.int_to_hex(label + 1) if is_labelled else None
))
for label, (span, attr_dict) in enumerate(split_items)
])
]
) )
print(self.original_content) ), is_labelled=is_labelled)
print() for is_labelled in (False, True)
print(self.labelled_content) )
#self.original_content = full_content_strings[False]
#self.labelled_content = full_content_strings[True]
#print(self.original_content)
#print()
#print(self.labelled_content)
#self.command_repl_dict = self.get_command_repl_dict() #self.command_repl_dict = self.get_command_repl_dict()
@ -569,8 +626,8 @@ class LabelledString(SVGMobject, ABC):
##self.specified_items = self.get_specified_items() ##self.specified_items = self.get_specified_items()
#self.specified_spans = [] #self.specified_spans = []
#self.check_overlapping() ####### #self.check_overlapping() #######
#self.label_span_list = [] #self.labelled_spans = []
#if len(self.label_span_list) >= 16777216: #if len(self.labelled_spans) >= 16777216:
# raise ValueError("Cannot handle that many substrings") # raise ValueError("Cannot handle that many substrings")
@abstractmethod @abstractmethod
@ -636,9 +693,9 @@ class LabelledString(SVGMobject, ABC):
#def get_split_items(self, specified_items: list[T]) -> list[T]: #def get_split_items(self, specified_items: list[T]) -> list[T]:
# return [] # return []
@abstractmethod #@abstractmethod
def get_label_span_list(self, split_spans: list[Span]) -> list[Span]: #def get_labelled_spans(self, split_spans: list[Span]) -> list[Span]:
return [] # return []
#@abstractmethod #@abstractmethod
#def get_predefined_inserted_str_items( #def get_predefined_inserted_str_items(
@ -666,7 +723,7 @@ class LabelledString(SVGMobject, ABC):
# return [] # return []
#@abstractmethod #@abstractmethod
#def get_label_span_list(self) -> list[Span]: #def get_labelled_spans(self) -> list[Span]:
# return [] # return []
#def get_decorated_string( #def get_decorated_string(
@ -694,56 +751,19 @@ class LabelledString(SVGMobject, ABC):
# repl_items.extend(self.command_repl_items) # repl_items.extend(self.command_repl_items)
# return self.get_replaced_substr(self.full_span, repl_items) # return self.get_replaced_substr(self.full_span, repl_items)
#@abstractmethod
#def get_additional_inserted_str_pairs(
# self
#) -> list[tuple[Span, tuple[str, str]]]:
# return []
@abstractmethod @abstractmethod
def get_additional_inserted_str_pairs( def get_command_repl_items(self) -> list[Span, str]:
self
) -> list[tuple[Span, tuple[str, str]]]:
return [] return []
@abstractmethod @abstractmethod
def get_command_repl_items(self, is_labelled: bool) -> list[Span, str]: def get_full_content_string(self, content_string: str, is_labelled: bool) -> str:
return [] return ""
def get_full_content_string(
self,
label_span_list: list[Span],
split_items: list[tuple[Span, dict[str, str]]],
is_labelled: bool
) -> str:
label_items = [
(span, {
"foreground": self.int_to_hex(label + 1)
} if is_labelled else {})
for label, span in enumerate(label_span_list)
]
inserted_str_pairs = self.chain(
self.get_additional_inserted_str_pairs(),
[
(span, tuple(
self.get_tag_str(
attr_dict,
escape_color_keys=is_labelled and not is_label_item,
is_begin_tag=is_begin_tag
)
for is_begin_tag in (True, False)
))
for is_label_item, items in enumerate((
split_items, label_items
))
for span, attr_dict in items
]
)
repl_items = self.chain(
self.get_command_repl_items(is_labelled),
[
((index, index), inserted_str)
for index, inserted_str
in self.sort_obj_pairs_by_spans(inserted_str_pairs)
]
)
return self.get_replaced_substr(
self.full_span, repl_items
)
#def get_content(self, is_labelled: bool) -> str: #def get_content(self, is_labelled: bool) -> str:
# return self.content_strings[int(is_labelled)] # return self.content_strings[int(is_labelled)]
@ -754,16 +774,15 @@ class LabelledString(SVGMobject, ABC):
def get_cleaned_substr(self, span: Span) -> str: def get_cleaned_substr(self, span: Span) -> str:
return "" return ""
def get_group_part_items(self) -> list[tuple[str, VGroup]]: def get_group_part_items(self) -> list[tuple[str, list[int]]]:
if not self.labelled_submobject_items: if not self.labels:
return [] return []
labels, labelled_submobjects = zip(*self.labelled_submobject_items)
group_labels, labelled_submob_ranges = zip( group_labels, labelled_submob_ranges = zip(
*self.compress_neighbours(labels) *self.compress_neighbours(self.labels)
) )
ordered_spans = [ ordered_spans = [
self.label_span_list[label] if label != -1 else self.full_span self.labelled_spans[label] if label != -1 else self.full_span
for label in group_labels for label in group_labels
] ]
interval_spans = [ interval_spans = [
@ -785,37 +804,67 @@ class LabelledString(SVGMobject, ABC):
(ordered_spans[0][0], ordered_spans[-1][1]), interval_spans (ordered_spans[0][0], ordered_spans[-1][1]), interval_spans
) )
] ]
submob_groups = VGroup(*[ submob_indices_lists = [
VGroup(*labelled_submobjects[slice(*submob_range)]) list(range(*submob_range))
for submob_range in labelled_submob_ranges for submob_range in labelled_submob_ranges
]) ]
return list(zip(group_substrs, submob_groups)) return list(zip(group_substrs, submob_indices_lists))
def get_specified_part_items(self) -> list[tuple[str, VGroup]]: def get_submob_indices_list_by_span(
self, arbitrary_span: Span
) -> list[int]:
return [
submob_index
for submob_index, label in enumerate(self.labels)
if label != -1 and self.span_contains(
arbitrary_span, self.labelled_spans[label]
)
]
def get_specified_part_items(self) -> list[tuple[str, list[int]]]:
return [ return [
( (
self.get_substr(span), self.get_substr(span),
self.select_part_by_span(span) self.get_submob_indices_list_by_span(span)
) )
for span in self.specified_spans for span in self.specified_spans
] ]
def select_part_by_span(self, arbitrary_span: Span) -> VGroup: def get_submob_indices_lists_by_selector(
return VGroup(*[ self, selector: Selector
submob for label, submob in self.labelled_submobject_items ) -> list[list[int]]:
if label != -1 return list(filter(
and self.span_contains(arbitrary_span, self.label_span_list[label]) lambda indices_list: indices_list,
])
def select_parts(self, selector: Selector) -> VGroup:
return VGroup(*filter(
lambda part: part.submobjects,
[ [
self.select_part_by_span(span) self.get_submob_indices_list_by_span(span)
for span in self.find_spans_by_selector(selector) for span in self.find_spans_by_selector(selector)
] ]
)) ))
def build_parts_from_indices_lists(
self, submob_indices_lists: list[list[int]]
) -> VGroup:
return VGroup(*[
VGroup(*[
self.labelled_submobject_items[submob_index][1]
for submob_index in indices_list
])
for indices_list in submob_indices_lists
])
#def select_part_by_span(self, arbitrary_span: Span) -> VGroup:
# return VGroup(*[
# self.labelled_submobject_items[submob_index]
# for submob_index in self.get_submob_indices_list_by_span(
# arbitrary_span
# )
# ])
def select_parts(self, selector: Selector) -> VGroup:
return self.build_parts_from_indices_lists(
self.get_submob_indices_lists_by_selector(selector)
)
def select_part(self, selector: Selector, index: int = 0) -> VGroup: def select_part(self, selector: Selector, index: int = 0) -> VGroup:
return self.select_parts(selector)[index] return self.select_parts(selector)[index]

View file

@ -31,14 +31,14 @@ if TYPE_CHECKING:
SCALE_FACTOR_PER_FONT_POINT = 0.001 SCALE_FACTOR_PER_FONT_POINT = 0.001
TEX_COLOR_COMMANDS_DICT = { #TEX_COLOR_COMMANDS_DICT = {
"\\color": (1, False), # "\\color": (1, False),
"\\textcolor": (1, False), # "\\textcolor": (1, False),
"\\pagecolor": (1, True), # "\\pagecolor": (1, True),
"\\colorbox": (1, True), # "\\colorbox": (1, True),
"\\fcolorbox": (2, True), # "\\fcolorbox": (2, True),
} #}
TEX_COLOR_COMMAND_SUFFIX = "replaced" #TEX_COLOR_COMMAND_SUFFIX = "replaced"
class MTex(LabelledString): class MTex(LabelledString):
@ -56,7 +56,7 @@ class MTex(LabelledString):
self.tex_string = tex_string self.tex_string = tex_string
super().__init__(tex_string, **kwargs) super().__init__(tex_string, **kwargs)
#self.set_color_by_tex_to_color_map(self.tex_to_color_map) self.set_color_by_tex_to_color_map(self.tex_to_color_map)
self.scale(SCALE_FACTOR_PER_FONT_POINT * self.font_size) self.scale(SCALE_FACTOR_PER_FONT_POINT * self.font_size)
@property @property
@ -97,16 +97,12 @@ class MTex(LabelledString):
return f"\\color[RGB]{{{r}, {g}, {b}}}" return f"\\color[RGB]{{{r}, {g}, {b}}}"
@staticmethod @staticmethod
def get_tag_str( def get_tag_string_pair(
attr_dict: dict[str, str], escape_color_keys: bool, is_begin_tag: bool attr_dict: dict[str, str], label_hex: str | None
) -> str: ) -> tuple[str, str]:
if escape_color_keys: if label_hex is None:
return "" return ("", "")
if not is_begin_tag: return ("{{" + MTex.get_color_command_str(label_hex), "}}")
return "}}"
if "foreground" not in attr_dict:
return "{{"
return "{{" + MTex.get_color_command_str(attr_dict["foreground"])
#@staticmethod #@staticmethod
#def shrink_span(span: Span, skippable_indices: list[int]) -> Span: #def shrink_span(span: Span, skippable_indices: list[int]) -> Span:
@ -223,20 +219,20 @@ class MTex(LabelledString):
raise ValueError("Missing '}' inserted") raise ValueError("Missing '}' inserted")
#tag_span_pairs = brace_span_pairs.copy() #tag_span_pairs = brace_span_pairs.copy()
script_entity_dict = dict(self.chain( #script_entity_dict = dict(self.chain(
[ # [
(span_begin, span_end) # (span_begin, span_end)
for (span_begin, _), (_, span_end) in brace_span_pairs # for (span_begin, _), (_, span_end) in brace_span_pairs
], # ],
command_spans # command_spans
)) #))
script_additional_brace_spans = [ #script_additional_brace_spans = [
(char_index + 1, script_entity_dict.get( # (char_index + 1, script_entity_dict.get(
script_begin, script_begin + 1 # script_begin, script_begin + 1
)) # ))
for char_index, script_begin in self.find_spans(r"[_^]\s*(?=.)") # for char_index, script_begin in self.find_spans(r"[_^]\s*(?=.)")
if (char_index - 1, char_index + 1) not in command_spans # if (char_index - 1, char_index + 1) not in command_spans
] #]
#for char_index, script_begin in self.find_spans(r"[_^]\s*(?=.)"): #for char_index, script_begin in self.find_spans(r"[_^]\s*(?=.)"):
# if (char_index - 1, char_index + 1) in command_spans: # if (char_index - 1, char_index + 1) in command_spans:
# continue # continue
@ -246,13 +242,13 @@ class MTex(LabelledString):
# ) # )
# script_additional_brace_spans.append((char_index + 1, script_end)) # script_additional_brace_spans.append((char_index + 1, script_end))
tag_span_pairs = self.chain( #tag_span_pairs = self.chain(
brace_span_pairs, # brace_span_pairs,
[ # [
((script_begin - 1, script_begin), (script_end, script_end)) # ((script_begin - 1, script_begin), (script_end, script_end))
for script_begin, script_end in script_additional_brace_spans # for script_begin, script_end in script_additional_brace_spans
] # ]
) #)
brace_content_spans = [ brace_content_spans = [
(span_begin, span_end) (span_begin, span_end)
@ -268,16 +264,19 @@ class MTex(LabelledString):
]) ])
if range_end - range_begin >= 2 if range_end - range_begin >= 2
] ]
self.script_additional_brace_spans = script_additional_brace_spans #self.script_additional_brace_spans = script_additional_brace_spans
return tag_span_pairs, internal_items return brace_span_pairs, internal_items
def get_external_items(self) -> list[tuple[Span, dict[str, str]]]: def get_external_items(self) -> list[tuple[Span, dict[str, str]]]:
return [ return [
(span, {"foreground": self.color_to_hex(color)}) (span, {})
for selector, color in self.tex_to_color_map.items() for selector in self.tex_to_color_map
for span in self.find_spans_by_selector(selector) for span in self.find_spans_by_selector(selector)
] ]
#def get_label_span_list(self, split_spans: list[Span]) -> list[Span]:
# return split_spans.copy()
#def get_spans_from_items(self, specified_items: list[Span]) -> list[Span]: #def get_spans_from_items(self, specified_items: list[Span]) -> list[Span]:
# return specified_items # return specified_items
@ -287,29 +286,30 @@ class MTex(LabelledString):
# for span in specified_items # for span in specified_items
# ])) # ]))
def get_label_span_list(self, split_spans: list[Span]) -> list[Span]: #def get_label_span_list(self, split_spans: list[Span]) -> list[Span]:
return split_spans # return split_spans
def get_additional_inserted_str_pairs( #def get_additional_inserted_str_pairs(
self # self
) -> list[tuple[Span, tuple[str, str]]]: #) -> list[tuple[Span, tuple[str, str]]]:
return [ # return [
(span, ("{", "}")) # (span, ("{", "}"))
for span in self.script_additional_brace_spans # for span in self.script_additional_brace_spans
] # ]
def get_command_repl_items(self, is_labelled: bool) -> list[Span, str]: def get_command_repl_items(self) -> list[Span, str]:
if not is_labelled:
return [] return []
result = [] #if not is_labelled:
command_spans = self.entity_spans # TODO # return []
for cmd_span in command_spans: #result = []
cmd_str = self.get_substr(cmd_span) #command_spans = self.entity_spans # TODO
if cmd_str not in TEX_COLOR_COMMANDS_DICT: #for cmd_span in command_spans:
continue # cmd_str = self.get_substr(cmd_span)
repl_str = f"{cmd_str}{TEX_COLOR_COMMAND_SUFFIX}" # if cmd_str not in TEX_COLOR_COMMANDS_DICT:
result.append((cmd_span, repl_str)) # continue
return result # repl_str = f"{cmd_str}{TEX_COLOR_COMMAND_SUFFIX}"
# result.append((cmd_span, repl_str))
#return result
#def get_predefined_inserted_str_items( #def get_predefined_inserted_str_items(
# self, split_items: list[Span] # self, split_items: list[Span]
@ -558,15 +558,8 @@ class MTex(LabelledString):
# for label, span in enumerate(self.label_span_list) # for label, span in enumerate(self.label_span_list)
# ] # ]
def get_full_content_string( def get_full_content_string(self, content_string: str, is_labelled: bool) -> str:
self, result = content_string
label_span_list: list[Span],
split_items: list[tuple[Span, dict[str, str]]],
is_labelled: bool
) -> str:
result = super().get_full_content_string(
label_span_list, split_items, is_labelled
)
if self.tex_environment: if self.tex_environment:
if isinstance(self.tex_environment, str): if isinstance(self.tex_environment, str):
@ -578,25 +571,25 @@ class MTex(LabelledString):
if self.alignment: if self.alignment:
result = "\n".join([self.alignment, result]) result = "\n".join([self.alignment, result])
if is_labelled: #if is_labelled:
occurred_commands = [ # occurred_commands = [
# TODO # # TODO
self.get_substr(span) for span in self.entity_spans # self.get_substr(span) for span in self.entity_spans
] # ]
newcommand_lines = [ # newcommand_lines = [
"".join([ # "".join([
f"\\newcommand{cmd_name}{TEX_COLOR_COMMAND_SUFFIX}", # f"\\newcommand{cmd_name}{TEX_COLOR_COMMAND_SUFFIX}",
f"[{n_braces + 1}][]", # f"[{n_braces + 1}][]",
"{", # "{",
cmd_name + "{black}" * n_braces if substitute_cmd else "", # cmd_name + "{black}" * n_braces if substitute_cmd else "",
"}" # "}"
]) # ])
for cmd_name, (n_braces, substitute_cmd) # for cmd_name, (n_braces, substitute_cmd)
in TEX_COLOR_COMMANDS_DICT.items() # in TEX_COLOR_COMMANDS_DICT.items()
if cmd_name in occurred_commands # if cmd_name in occurred_commands
] # ]
result = "\n".join([*newcommand_lines, result]) # result = "\n".join([*newcommand_lines, result])
else: if not is_labelled:
result = "\n".join([ result = "\n".join([
self.get_color_command_str(self.base_color_hex), self.get_color_command_str(self.base_color_hex),
result result

View file

@ -114,6 +114,7 @@ class MarkupText(LabelledString):
"t2w": {}, "t2w": {},
"global_config": {}, "global_config": {},
"local_configs": {}, "local_configs": {},
"split_words": True,
} }
def __init__(self, text: str, **kwargs): def __init__(self, text: str, **kwargs):
@ -162,7 +163,8 @@ class MarkupText(LabelledString):
self.t2s, self.t2s,
self.t2w, self.t2w,
self.global_config, self.global_config,
self.local_configs self.local_configs,
self.split_words
) )
def full2short(self, config: dict) -> None: def full2short(self, config: dict) -> None:
@ -250,28 +252,26 @@ class MarkupText(LabelledString):
# Toolkits # Toolkits
@staticmethod @staticmethod
def get_tag_str( def get_tag_string_pair(
attr_dict: dict[str, str], escape_color_keys: bool, is_begin_tag: bool attr_dict: dict[str, str], label_hex: str | None
) -> str: ) -> tuple[str, str]:
if not is_begin_tag: if label_hex is not None:
return "</span>" converted_attr_dict = {"foreground": label_hex}
if escape_color_keys:
converted_attr_dict = {}
for key, val in attr_dict.items(): for key, val in attr_dict.items():
substitute_key = MARKUP_COLOR_KEYS_DICT.get(key.lower(), None) substitute_key = MARKUP_COLOR_KEYS_DICT.get(key.lower(), None)
if substitute_key is None: if substitute_key is None:
converted_attr_dict[key] = val converted_attr_dict[key] = val
elif substitute_key: elif substitute_key:
converted_attr_dict[key] = "black" converted_attr_dict[key] = "black"
else: #else:
converted_attr_dict[key] = "black" # converted_attr_dict[key] = "black"
else: else:
converted_attr_dict = attr_dict.copy() converted_attr_dict = attr_dict.copy()
result = " ".join([ attrs_str = " ".join([
f"{key}='{val}'" f"{key}='{val}'"
for key, val in converted_attr_dict.items() for key, val in converted_attr_dict.items()
]) ])
return f"<span {result}>" return (f"<span {attrs_str}>", "</span>")
def get_global_attr_dict(self) -> dict[str, str]: def get_global_attr_dict(self) -> dict[str, str]:
result = { result = {
@ -286,8 +286,9 @@ class MarkupText(LabelledString):
if tuple(map(int, pango_version.split("."))) < (1, 50): if tuple(map(int, pango_version.split("."))) < (1, 50):
if self.lsh is not None: if self.lsh is not None:
log.warning( log.warning(
f"Pango version {pango_version} found (< 1.50), " "Pango version %s found (< 1.50), "
"unable to set `line_height` attribute" "unable to set `line_height` attribute",
pango_version
) )
else: else:
line_spacing_scale = self.lsh or DEFAULT_LINE_SPACING_SCALE line_spacing_scale = self.lsh or DEFAULT_LINE_SPACING_SCALE
@ -477,8 +478,8 @@ class MarkupText(LabelledString):
if not self.is_markup: if not self.is_markup:
return [], [] return [], []
tag_pattern = r"""<(/?)(\w+)\s*((\w+\s*\=\s*(['"])[\s\S]*?\5\s*)*)>""" tag_pattern = r"<(/?)(\w+)\s*((\w+\s*\=\s*(['\x22])[\s\S]*?\5\s*)*)>"
attr_pattern = r"""(\w+)\s*\=\s*(['"])([\s\S]*?)\2""" attr_pattern = r"(\w+)\s*\=\s*(['\x22])([\s\S]*?)\2"
begin_match_obj_stack = [] begin_match_obj_stack = []
markup_tag_items = [] markup_tag_items = []
for match_obj in re.finditer(tag_pattern, self.string): for match_obj in re.finditer(tag_pattern, self.string):
@ -511,7 +512,7 @@ class MarkupText(LabelledString):
return tag_span_pairs, internal_items return tag_span_pairs, internal_items
def get_external_items(self) -> list[tuple[Span, dict[str, str]]]: def get_external_items(self) -> list[tuple[Span, dict[str, str]]]:
return [ result = [
(self.full_span, self.get_global_attr_dict()), (self.full_span, self.get_global_attr_dict()),
(self.full_span, self.global_config), (self.full_span, self.global_config),
*[ *[
@ -531,6 +532,17 @@ class MarkupText(LabelledString):
for span in self.find_spans_by_selector(selector) for span in self.find_spans_by_selector(selector)
] ]
] ]
if self.split_words:
# For backward compatibility
result.extend([
(span, {})
for span in self.find_spans(r"[a-zA-Z]+")
for pattern in (r"[a-zA-Z]+", r"\S+")
])
return result
#def get_label_span_list(self, split_spans: list[Span]) -> list[Span]:
#def get_spans_from_items( #def get_spans_from_items(
# self, specified_items: list[tuple[Span, dict[str, str]]] # self, specified_items: list[tuple[Span, dict[str, str]]]
@ -546,31 +558,31 @@ class MarkupText(LabelledString):
# for span in self.split_span(specified_span) # for span in self.split_span(specified_span)
# ] # ]
def get_label_span_list(self, split_spans: list[Span]) -> list[Span]: #def get_label_span_list(self, split_spans: list[Span]) -> list[Span]:
interval_spans = sorted(self.chain( # interval_spans = sorted(self.chain(
self.tag_spans, # self.tag_spans,
[ # [
(index, index) # (index, index)
for span in split_spans # for span in split_spans
for index in span # for index in span
] # ]
)) # ))
text_spans = self.get_complement_spans(self.full_span, interval_spans) # text_spans = self.get_complement_spans(self.full_span, interval_spans)
if self.is_markup: # if self.is_markup:
pattern = r"[0-9a-zA-Z]+|(?:&[\s\S]*?;|[^0-9a-zA-Z\s])+" # pattern = r"[0-9a-zA-Z]+|(?:&[\s\S]*?;|[^0-9a-zA-Z\s])+"
else: # else:
pattern = r"[0-9a-zA-Z]+|[^0-9a-zA-Z\s]+" # pattern = r"[0-9a-zA-Z]+|[^0-9a-zA-Z\s]+"
return self.chain(*[ # return self.chain(*[
self.find_spans(pattern, pos=span_begin, endpos=span_end) # self.find_spans(pattern, pos=span_begin, endpos=span_end)
for span_begin, span_end in text_spans # for span_begin, span_end in text_spans
]) # ])
def get_additional_inserted_str_pairs( #def get_additional_inserted_str_pairs(
self # self
) -> list[tuple[Span, tuple[str, str]]]: #) -> list[tuple[Span, tuple[str, str]]]:
return [] # return []
def get_command_repl_items(self, is_labelled: bool) -> list[Span, str]: def get_command_repl_items(self) -> list[Span, str]:
result = [ result = [
(tag_span, "") for tag_span in self.tag_spans (tag_span, "") for tag_span in self.tag_spans
] ]
@ -755,8 +767,8 @@ class MarkupText(LabelledString):
# for span, attr_dict in attr_dict_items # for span, attr_dict in attr_dict_items
# ] # ]
def get_content(self, is_labelled: bool) -> str: def get_full_content_string(self, content_string: str, is_labelled: bool) -> str:
return self.decorated_strings[is_labelled] return content_string
# Selector # Selector