[WIP] Refactor LabelledString and relevant classes

This commit is contained in:
YishiMichael 2022-05-03 23:39:37 +08:00
parent 03cb42ba15
commit ab8f78f40f
No known key found for this signature in database
GPG key ID: EC615C0C5A86BC80
5 changed files with 348 additions and 299 deletions

View file

@ -213,8 +213,9 @@ class AddTextWordByWord(ShowIncreasingSubsets):
def __init__(self, string_mobject, **kwargs):
assert isinstance(string_mobject, LabelledString)
grouped_mobject = VGroup(*[
part for _, part in string_mobject.get_group_part_items()
grouped_mobject = string_mobject.build_parts_from_indices_lists([
indices_list
for _, indices_list in string_mobject.get_group_part_items()
])
digest_config(self, kwargs)
if self.run_time is None:

View file

@ -168,73 +168,67 @@ class TransformMatchingStrings(AnimationGroup):
assert isinstance(source, LabelledString)
assert isinstance(target, LabelledString)
anims = []
source_indices = list(range(len(source.labels)))
target_indices = list(range(len(target.labels)))
source_submobs = [
submob for _, submob in source.labelled_submobject_items
]
target_submobs = [
submob for _, submob in target.labelled_submobject_items
]
source_indices = list(range(len(source_submobs)))
target_indices = list(range(len(target_submobs)))
def get_filtered_indices_lists(parts, submobs, rest_indices):
def get_filtered_indices_lists(indices_lists, rest_indices):
return list(filter(
lambda indices_list: all([
lambda indices_list: all(
index in rest_indices
for index in indices_list
]),
[
[submobs.index(submob) for submob in part]
for part in parts
]
),
indices_lists
))
def add_anims(anim_class, parts_pairs):
for source_parts, target_parts in parts_pairs:
def add_anims(anim_class, indices_lists_pairs):
for source_indices_lists, target_indices_lists in indices_lists_pairs:
source_indices_lists = get_filtered_indices_lists(
source_parts, source_submobs, source_indices
source_indices_lists, source_indices
)
target_indices_lists = get_filtered_indices_lists(
target_parts, target_submobs, target_indices
target_indices_lists, target_indices
)
if not source_indices_lists or not target_indices_lists:
continue
anims.append(anim_class(source_parts, target_parts, **kwargs))
anims.append(anim_class(
source.build_parts_from_indices_lists(source_indices_lists),
target.build_parts_from_indices_lists(target_indices_lists),
**kwargs
))
for index in it.chain(*source_indices_lists):
source_indices.remove(index)
for index in it.chain(*target_indices_lists):
target_indices.remove(index)
def get_substr_to_parts_map(part_items):
def get_substr_to_indices_lists_map(part_items):
result = {}
for substr, part in part_items:
for substr, indices_list in part_items:
if substr not in result:
result[substr] = []
result[substr].append(part)
result[substr].append(indices_list)
return result
def add_anims_from(anim_class, func):
source_substr_to_parts_map = get_substr_to_parts_map(func(source))
target_substr_to_parts_map = get_substr_to_parts_map(func(target))
source_substr_map = get_substr_to_indices_lists_map(func(source))
target_substr_map = get_substr_to_indices_lists_map(func(target))
common_substrings = sorted([
s for s in source_substr_map if s and s in target_substr_map
], key=len, reverse=True)
add_anims(
anim_class,
[
(
VGroup(*source_substr_to_parts_map[substr]),
VGroup(*target_substr_to_parts_map[substr])
)
for substr in sorted([
s for s in source_substr_to_parts_map
if s and s in target_substr_to_parts_map
], key=len, reverse=True)
(source_substr_map[substr], target_substr_map[substr])
for substr in common_substrings
]
)
add_anims(
ReplacementTransform,
[
(source.select_parts(k), target.select_parts(v))
(
source.get_submob_indices_lists_by_selector(k),
target.get_submob_indices_lists_by_selector(v)
)
for k, v in self.key_map.items()
]
)

View file

@ -5,6 +5,7 @@ import itertools as it
import re
from manimlib.constants import WHITE
from manimlib.logger import log
from manimlib.mobject.svg.svg_mobject import SVGMobject
from manimlib.mobject.types.vectorized_mobject import VGroup
from manimlib.utils.color import color_to_rgb
@ -63,6 +64,7 @@ class LabelledString(SVGMobject, ABC):
(submob.label, submob)
for submob in self.submobjects
]
self.labels = [label for label, _ in self.labelled_submobject_items]
def get_file_path(self) -> str:
return self.get_file_path_by_content(self.original_content)
@ -78,26 +80,30 @@ class LabelledString(SVGMobject, ABC):
labelled_svg = SVGMobject(file_path)
num_submobjects = len(self.submobjects)
if num_submobjects != len(labelled_svg.submobjects):
raise ValueError(
log.warning(
"Cannot align submobjects of the labelled svg "
"to the original svg"
)
submob_color_ints = [
self.hex_to_int(self.color_to_hex(submob.get_fill_color()))
for submob in labelled_svg.submobjects
]
unrecognized_color_ints = self.remove_redundancies(sorted(filter(
lambda color_int: color_int > len(self.label_span_list),
submob_color_ints
)))
if unrecognized_color_ints:
raise ValueError(
"Unrecognized color label(s) detected: "
f"{', '.join(map(self.int_to_hex, unrecognized_color_ints))}"
"to the original svg. Skip the labelling process."
)
submob_color_ints = [0] * num_submobjects
else:
submob_color_ints = [
self.hex_to_int(self.color_to_hex(submob.get_fill_color()))
for submob in labelled_svg.submobjects
]
unrecognized_colors = list(filter(
lambda color_int: color_int > len(self.labelled_spans),
submob_color_ints
))
if unrecognized_colors:
log.warning(
"Unrecognized color label(s) detected (%s, etc). "
"Skip the labelling process.",
self.int_to_hex(unrecognized_colors[0])
)
submob_color_ints = [0] * num_submobjects
#if self.sort_labelled_submobs:
# TODO: remove this
submob_indices = sorted(
range(num_submobjects),
key=lambda index: tuple(
@ -135,12 +141,10 @@ class LabelledString(SVGMobject, ABC):
# pattern = re.compile(pattern)
# return re.compile(pattern).match(self.string, **kwargs)
def find_spans(self, pattern: str | re.Pattern, **kwargs) -> list[Span]:
if isinstance(pattern, str):
pattern = re.compile(pattern)
def find_spans(self, pattern: str) -> list[Span]:
return [
match_obj.span()
for match_obj in pattern.finditer(self.string, **kwargs)
for match_obj in re.finditer(pattern, self.string)
]
#def find_indices(self, pattern: str | re.Pattern, **kwargs) -> list[int]:
@ -151,7 +155,18 @@ class LabelledString(SVGMobject, ABC):
if isinstance(sel, str):
return self.find_spans(re.escape(sel))
if isinstance(sel, re.Pattern):
return self.find_spans(sel)
result_iterator = sel.finditer(self.string)
if not sel.groups:
return [
match_obj.span()
for match_obj in result_iterator
]
return [
span
for match_obj in result_iterator
for span in match_obj.regs[1:]
if span != (-1, -1)
]
if isinstance(sel, tuple) and len(sel) == 2 and all(
isinstance(index, int) or index is None
for index in sel
@ -225,7 +240,7 @@ class LabelledString(SVGMobject, ABC):
def span_contains(span_0: Span, span_1: Span) -> bool:
return span_0[0] <= span_1[0] and span_0[1] >= span_1[1]
def get_piece_items(
def get_level_items(
self,
tag_span_pairs: list[tuple[Span, Span]],
entity_spans: list[Span]
@ -241,7 +256,7 @@ class LabelledString(SVGMobject, ABC):
piece_levels = [0, *it.accumulate([tag for _, tag in tagged_items])]
return piece_spans, piece_levels
def split_span(self, arbitrary_span: Span) -> list[Span]:
def split_span_by_levels(self, arbitrary_span: Span) -> list[Span]:
# ignorable_indices --
# left_bracket_spans
# right_bracket_spans
@ -413,10 +428,10 @@ class LabelledString(SVGMobject, ABC):
@staticmethod
@abstractmethod
def get_tag_str(
attr_dict: dict[str, str], escape_color_keys: bool, is_begin_tag: bool
) -> str:
return ""
def get_tag_string_pair(
attr_dict: dict[str, str], label_hex: str | None
) -> tuple[str, str]:
return ("", "")
#def get_color_tag_str(self, rgb_int: int, is_begin_tag: bool) -> str:
# return self.get_tag_str({
@ -481,7 +496,7 @@ class LabelledString(SVGMobject, ABC):
def parse(self) -> None:
self.entity_spans = self.get_entity_spans()
tag_span_pairs, internal_items = self.get_internal_items()
self.piece_spans, self.piece_levels = self.get_piece_items(
self.piece_spans, self.piece_levels = self.get_level_items(
tag_span_pairs, self.entity_spans
)
#self.tag_content_spans = [
@ -497,26 +512,19 @@ class LabelledString(SVGMobject, ABC):
for span in self.find_spans_by_selector(self.isolate)
]
)
print(f"\n{specified_items=}\n")
specified_spans = [span for span, _ in specified_items]
for span_0, span_1 in it.product(specified_spans, repeat=2):
if not span_0[0] < span_1[0] < span_0[1] < span_1[1]:
continue
raise ValueError(
"Partially overlapping substrings detected: "
f"'{self.get_substr(span_0)}' and '{self.get_substr(span_1)}'"
)
#print(f"\n{specified_items=}\n")
#specified_spans =
split_items = [
(span, attr_dict)
for specified_span, attr_dict in specified_items
for span in self.split_span(specified_span)
for span in self.split_span_by_levels(specified_span)
]
print(f"\n{split_items=}\n")
split_spans = [span for span, _ in split_items]
label_span_list = self.get_label_span_list(split_spans)
if len(label_span_list) >= 16777216:
raise ValueError("Cannot handle that many substrings")
#print(f"\n{split_items=}\n")
#labelled_spans = [span for span, _ in split_items]
#labelled_spans = self.get_labelled_spans(split_spans)
#if len(labelled_spans) >= 16777216:
# raise ValueError("Cannot handle that many substrings")
#content_strings = []
#for is_labelled in (False, True):
@ -549,17 +557,66 @@ class LabelledString(SVGMobject, ABC):
# for flag in range(2)
#]
self.specified_spans = specified_spans
self.label_span_list = label_span_list
self.original_content = self.get_full_content_string(
label_span_list, split_items, is_labelled=False
command_repl_items = self.get_command_repl_items()
#full_content_strings = {}
#for is_labelled in (False, True):
# inserted_str_pairs = [
# (span, self.get_tag_string_pair(
# attr_dict,
# rgb_hex=self.int_to_hex(label + 1) if is_labelled else None
# ))
# for label, (span, attr_dict) in enumerate(split_items)
# ]
# repl_items = self.chain(
# command_repl_items,
# [
# ((index, index), inserted_str)
# for index, inserted_str
# in self.sort_obj_pairs_by_spans(inserted_str_pairs)
# ]
# )
# content_string = self.get_replaced_substr(
# self.full_span, repl_items
# )
# full_content_string = self.get_full_content_string(content_string)
# #full_content_strings[is_labelled] = full_content_string
self.specified_spans = [span for span, _ in specified_items]
self.labelled_spans = [span for span, _ in split_items]
for span_0, span_1 in it.product(self.labelled_spans, repeat=2):
if not span_0[0] < span_1[0] < span_0[1] < span_1[1]:
continue
raise ValueError(
"Partially overlapping substrings detected: "
f"'{self.get_substr(span_0)}' and '{self.get_substr(span_1)}'"
)
self.original_content, self.labelled_content = (
self.get_full_content_string(self.get_replaced_substr(
self.full_span, self.chain(
command_repl_items,
[
((index, index), inserted_str)
for index, inserted_str in self.sort_obj_pairs_by_spans([
(span, self.get_tag_string_pair(
attr_dict,
label_hex=self.int_to_hex(label + 1) if is_labelled else None
))
for label, (span, attr_dict) in enumerate(split_items)
])
]
)
), is_labelled=is_labelled)
for is_labelled in (False, True)
)
self.labelled_content = self.get_full_content_string(
label_span_list, split_items, is_labelled=True
)
print(self.original_content)
print()
print(self.labelled_content)
#self.original_content = full_content_strings[False]
#self.labelled_content = full_content_strings[True]
#print(self.original_content)
#print()
#print(self.labelled_content)
#self.command_repl_dict = self.get_command_repl_dict()
@ -569,8 +626,8 @@ class LabelledString(SVGMobject, ABC):
##self.specified_items = self.get_specified_items()
#self.specified_spans = []
#self.check_overlapping() #######
#self.label_span_list = []
#if len(self.label_span_list) >= 16777216:
#self.labelled_spans = []
#if len(self.labelled_spans) >= 16777216:
# raise ValueError("Cannot handle that many substrings")
@abstractmethod
@ -636,9 +693,9 @@ class LabelledString(SVGMobject, ABC):
#def get_split_items(self, specified_items: list[T]) -> list[T]:
# return []
@abstractmethod
def get_label_span_list(self, split_spans: list[Span]) -> list[Span]:
return []
#@abstractmethod
#def get_labelled_spans(self, split_spans: list[Span]) -> list[Span]:
# return []
#@abstractmethod
#def get_predefined_inserted_str_items(
@ -666,7 +723,7 @@ class LabelledString(SVGMobject, ABC):
# return []
#@abstractmethod
#def get_label_span_list(self) -> list[Span]:
#def get_labelled_spans(self) -> list[Span]:
# return []
#def get_decorated_string(
@ -694,56 +751,19 @@ class LabelledString(SVGMobject, ABC):
# repl_items.extend(self.command_repl_items)
# return self.get_replaced_substr(self.full_span, repl_items)
#@abstractmethod
#def get_additional_inserted_str_pairs(
# self
#) -> list[tuple[Span, tuple[str, str]]]:
# return []
@abstractmethod
def get_additional_inserted_str_pairs(
self
) -> list[tuple[Span, tuple[str, str]]]:
def get_command_repl_items(self) -> list[Span, str]:
return []
@abstractmethod
def get_command_repl_items(self, is_labelled: bool) -> list[Span, str]:
return []
def get_full_content_string(
self,
label_span_list: list[Span],
split_items: list[tuple[Span, dict[str, str]]],
is_labelled: bool
) -> str:
label_items = [
(span, {
"foreground": self.int_to_hex(label + 1)
} if is_labelled else {})
for label, span in enumerate(label_span_list)
]
inserted_str_pairs = self.chain(
self.get_additional_inserted_str_pairs(),
[
(span, tuple(
self.get_tag_str(
attr_dict,
escape_color_keys=is_labelled and not is_label_item,
is_begin_tag=is_begin_tag
)
for is_begin_tag in (True, False)
))
for is_label_item, items in enumerate((
split_items, label_items
))
for span, attr_dict in items
]
)
repl_items = self.chain(
self.get_command_repl_items(is_labelled),
[
((index, index), inserted_str)
for index, inserted_str
in self.sort_obj_pairs_by_spans(inserted_str_pairs)
]
)
return self.get_replaced_substr(
self.full_span, repl_items
)
def get_full_content_string(self, content_string: str, is_labelled: bool) -> str:
return ""
#def get_content(self, is_labelled: bool) -> str:
# return self.content_strings[int(is_labelled)]
@ -754,16 +774,15 @@ class LabelledString(SVGMobject, ABC):
def get_cleaned_substr(self, span: Span) -> str:
return ""
def get_group_part_items(self) -> list[tuple[str, VGroup]]:
if not self.labelled_submobject_items:
def get_group_part_items(self) -> list[tuple[str, list[int]]]:
if not self.labels:
return []
labels, labelled_submobjects = zip(*self.labelled_submobject_items)
group_labels, labelled_submob_ranges = zip(
*self.compress_neighbours(labels)
*self.compress_neighbours(self.labels)
)
ordered_spans = [
self.label_span_list[label] if label != -1 else self.full_span
self.labelled_spans[label] if label != -1 else self.full_span
for label in group_labels
]
interval_spans = [
@ -785,37 +804,67 @@ class LabelledString(SVGMobject, ABC):
(ordered_spans[0][0], ordered_spans[-1][1]), interval_spans
)
]
submob_groups = VGroup(*[
VGroup(*labelled_submobjects[slice(*submob_range)])
submob_indices_lists = [
list(range(*submob_range))
for submob_range in labelled_submob_ranges
])
return list(zip(group_substrs, submob_groups))
]
return list(zip(group_substrs, submob_indices_lists))
def get_specified_part_items(self) -> list[tuple[str, VGroup]]:
def get_submob_indices_list_by_span(
self, arbitrary_span: Span
) -> list[int]:
return [
submob_index
for submob_index, label in enumerate(self.labels)
if label != -1 and self.span_contains(
arbitrary_span, self.labelled_spans[label]
)
]
def get_specified_part_items(self) -> list[tuple[str, list[int]]]:
return [
(
self.get_substr(span),
self.select_part_by_span(span)
self.get_submob_indices_list_by_span(span)
)
for span in self.specified_spans
]
def select_part_by_span(self, arbitrary_span: Span) -> VGroup:
return VGroup(*[
submob for label, submob in self.labelled_submobject_items
if label != -1
and self.span_contains(arbitrary_span, self.label_span_list[label])
])
def select_parts(self, selector: Selector) -> VGroup:
return VGroup(*filter(
lambda part: part.submobjects,
def get_submob_indices_lists_by_selector(
self, selector: Selector
) -> list[list[int]]:
return list(filter(
lambda indices_list: indices_list,
[
self.select_part_by_span(span)
self.get_submob_indices_list_by_span(span)
for span in self.find_spans_by_selector(selector)
]
))
def build_parts_from_indices_lists(
self, submob_indices_lists: list[list[int]]
) -> VGroup:
return VGroup(*[
VGroup(*[
self.labelled_submobject_items[submob_index][1]
for submob_index in indices_list
])
for indices_list in submob_indices_lists
])
#def select_part_by_span(self, arbitrary_span: Span) -> VGroup:
# return VGroup(*[
# self.labelled_submobject_items[submob_index]
# for submob_index in self.get_submob_indices_list_by_span(
# arbitrary_span
# )
# ])
def select_parts(self, selector: Selector) -> VGroup:
return self.build_parts_from_indices_lists(
self.get_submob_indices_lists_by_selector(selector)
)
def select_part(self, selector: Selector, index: int = 0) -> VGroup:
return self.select_parts(selector)[index]

View file

@ -31,14 +31,14 @@ if TYPE_CHECKING:
SCALE_FACTOR_PER_FONT_POINT = 0.001
TEX_COLOR_COMMANDS_DICT = {
"\\color": (1, False),
"\\textcolor": (1, False),
"\\pagecolor": (1, True),
"\\colorbox": (1, True),
"\\fcolorbox": (2, True),
}
TEX_COLOR_COMMAND_SUFFIX = "replaced"
#TEX_COLOR_COMMANDS_DICT = {
# "\\color": (1, False),
# "\\textcolor": (1, False),
# "\\pagecolor": (1, True),
# "\\colorbox": (1, True),
# "\\fcolorbox": (2, True),
#}
#TEX_COLOR_COMMAND_SUFFIX = "replaced"
class MTex(LabelledString):
@ -56,7 +56,7 @@ class MTex(LabelledString):
self.tex_string = tex_string
super().__init__(tex_string, **kwargs)
#self.set_color_by_tex_to_color_map(self.tex_to_color_map)
self.set_color_by_tex_to_color_map(self.tex_to_color_map)
self.scale(SCALE_FACTOR_PER_FONT_POINT * self.font_size)
@property
@ -97,16 +97,12 @@ class MTex(LabelledString):
return f"\\color[RGB]{{{r}, {g}, {b}}}"
@staticmethod
def get_tag_str(
attr_dict: dict[str, str], escape_color_keys: bool, is_begin_tag: bool
) -> str:
if escape_color_keys:
return ""
if not is_begin_tag:
return "}}"
if "foreground" not in attr_dict:
return "{{"
return "{{" + MTex.get_color_command_str(attr_dict["foreground"])
def get_tag_string_pair(
attr_dict: dict[str, str], label_hex: str | None
) -> tuple[str, str]:
if label_hex is None:
return ("", "")
return ("{{" + MTex.get_color_command_str(label_hex), "}}")
#@staticmethod
#def shrink_span(span: Span, skippable_indices: list[int]) -> Span:
@ -223,20 +219,20 @@ class MTex(LabelledString):
raise ValueError("Missing '}' inserted")
#tag_span_pairs = brace_span_pairs.copy()
script_entity_dict = dict(self.chain(
[
(span_begin, span_end)
for (span_begin, _), (_, span_end) in brace_span_pairs
],
command_spans
))
script_additional_brace_spans = [
(char_index + 1, script_entity_dict.get(
script_begin, script_begin + 1
))
for char_index, script_begin in self.find_spans(r"[_^]\s*(?=.)")
if (char_index - 1, char_index + 1) not in command_spans
]
#script_entity_dict = dict(self.chain(
# [
# (span_begin, span_end)
# for (span_begin, _), (_, span_end) in brace_span_pairs
# ],
# command_spans
#))
#script_additional_brace_spans = [
# (char_index + 1, script_entity_dict.get(
# script_begin, script_begin + 1
# ))
# for char_index, script_begin in self.find_spans(r"[_^]\s*(?=.)")
# if (char_index - 1, char_index + 1) not in command_spans
#]
#for char_index, script_begin in self.find_spans(r"[_^]\s*(?=.)"):
# if (char_index - 1, char_index + 1) in command_spans:
# continue
@ -246,13 +242,13 @@ class MTex(LabelledString):
# )
# script_additional_brace_spans.append((char_index + 1, script_end))
tag_span_pairs = self.chain(
brace_span_pairs,
[
((script_begin - 1, script_begin), (script_end, script_end))
for script_begin, script_end in script_additional_brace_spans
]
)
#tag_span_pairs = self.chain(
# brace_span_pairs,
# [
# ((script_begin - 1, script_begin), (script_end, script_end))
# for script_begin, script_end in script_additional_brace_spans
# ]
#)
brace_content_spans = [
(span_begin, span_end)
@ -268,16 +264,19 @@ class MTex(LabelledString):
])
if range_end - range_begin >= 2
]
self.script_additional_brace_spans = script_additional_brace_spans
return tag_span_pairs, internal_items
#self.script_additional_brace_spans = script_additional_brace_spans
return brace_span_pairs, internal_items
def get_external_items(self) -> list[tuple[Span, dict[str, str]]]:
return [
(span, {"foreground": self.color_to_hex(color)})
for selector, color in self.tex_to_color_map.items()
(span, {})
for selector in self.tex_to_color_map
for span in self.find_spans_by_selector(selector)
]
#def get_label_span_list(self, split_spans: list[Span]) -> list[Span]:
# return split_spans.copy()
#def get_spans_from_items(self, specified_items: list[Span]) -> list[Span]:
# return specified_items
@ -287,29 +286,30 @@ class MTex(LabelledString):
# for span in specified_items
# ]))
def get_label_span_list(self, split_spans: list[Span]) -> list[Span]:
return split_spans
#def get_label_span_list(self, split_spans: list[Span]) -> list[Span]:
# return split_spans
def get_additional_inserted_str_pairs(
self
) -> list[tuple[Span, tuple[str, str]]]:
return [
(span, ("{", "}"))
for span in self.script_additional_brace_spans
]
#def get_additional_inserted_str_pairs(
# self
#) -> list[tuple[Span, tuple[str, str]]]:
# return [
# (span, ("{", "}"))
# for span in self.script_additional_brace_spans
# ]
def get_command_repl_items(self, is_labelled: bool) -> list[Span, str]:
if not is_labelled:
return []
result = []
command_spans = self.entity_spans # TODO
for cmd_span in command_spans:
cmd_str = self.get_substr(cmd_span)
if cmd_str not in TEX_COLOR_COMMANDS_DICT:
continue
repl_str = f"{cmd_str}{TEX_COLOR_COMMAND_SUFFIX}"
result.append((cmd_span, repl_str))
return result
def get_command_repl_items(self) -> list[Span, str]:
return []
#if not is_labelled:
# return []
#result = []
#command_spans = self.entity_spans # TODO
#for cmd_span in command_spans:
# cmd_str = self.get_substr(cmd_span)
# if cmd_str not in TEX_COLOR_COMMANDS_DICT:
# continue
# repl_str = f"{cmd_str}{TEX_COLOR_COMMAND_SUFFIX}"
# result.append((cmd_span, repl_str))
#return result
#def get_predefined_inserted_str_items(
# self, split_items: list[Span]
@ -558,15 +558,8 @@ class MTex(LabelledString):
# for label, span in enumerate(self.label_span_list)
# ]
def get_full_content_string(
self,
label_span_list: list[Span],
split_items: list[tuple[Span, dict[str, str]]],
is_labelled: bool
) -> str:
result = super().get_full_content_string(
label_span_list, split_items, is_labelled
)
def get_full_content_string(self, content_string: str, is_labelled: bool) -> str:
result = content_string
if self.tex_environment:
if isinstance(self.tex_environment, str):
@ -578,25 +571,25 @@ class MTex(LabelledString):
if self.alignment:
result = "\n".join([self.alignment, result])
if is_labelled:
occurred_commands = [
# TODO
self.get_substr(span) for span in self.entity_spans
]
newcommand_lines = [
"".join([
f"\\newcommand{cmd_name}{TEX_COLOR_COMMAND_SUFFIX}",
f"[{n_braces + 1}][]",
"{",
cmd_name + "{black}" * n_braces if substitute_cmd else "",
"}"
])
for cmd_name, (n_braces, substitute_cmd)
in TEX_COLOR_COMMANDS_DICT.items()
if cmd_name in occurred_commands
]
result = "\n".join([*newcommand_lines, result])
else:
#if is_labelled:
# occurred_commands = [
# # TODO
# self.get_substr(span) for span in self.entity_spans
# ]
# newcommand_lines = [
# "".join([
# f"\\newcommand{cmd_name}{TEX_COLOR_COMMAND_SUFFIX}",
# f"[{n_braces + 1}][]",
# "{",
# cmd_name + "{black}" * n_braces if substitute_cmd else "",
# "}"
# ])
# for cmd_name, (n_braces, substitute_cmd)
# in TEX_COLOR_COMMANDS_DICT.items()
# if cmd_name in occurred_commands
# ]
# result = "\n".join([*newcommand_lines, result])
if not is_labelled:
result = "\n".join([
self.get_color_command_str(self.base_color_hex),
result

View file

@ -114,6 +114,7 @@ class MarkupText(LabelledString):
"t2w": {},
"global_config": {},
"local_configs": {},
"split_words": True,
}
def __init__(self, text: str, **kwargs):
@ -162,7 +163,8 @@ class MarkupText(LabelledString):
self.t2s,
self.t2w,
self.global_config,
self.local_configs
self.local_configs,
self.split_words
)
def full2short(self, config: dict) -> None:
@ -250,28 +252,26 @@ class MarkupText(LabelledString):
# Toolkits
@staticmethod
def get_tag_str(
attr_dict: dict[str, str], escape_color_keys: bool, is_begin_tag: bool
) -> str:
if not is_begin_tag:
return "</span>"
if escape_color_keys:
converted_attr_dict = {}
def get_tag_string_pair(
attr_dict: dict[str, str], label_hex: str | None
) -> tuple[str, str]:
if label_hex is not None:
converted_attr_dict = {"foreground": label_hex}
for key, val in attr_dict.items():
substitute_key = MARKUP_COLOR_KEYS_DICT.get(key.lower(), None)
if substitute_key is None:
converted_attr_dict[key] = val
elif substitute_key:
converted_attr_dict[key] = "black"
else:
converted_attr_dict[key] = "black"
#else:
# converted_attr_dict[key] = "black"
else:
converted_attr_dict = attr_dict.copy()
result = " ".join([
attrs_str = " ".join([
f"{key}='{val}'"
for key, val in converted_attr_dict.items()
])
return f"<span {result}>"
return (f"<span {attrs_str}>", "</span>")
def get_global_attr_dict(self) -> dict[str, str]:
result = {
@ -286,8 +286,9 @@ class MarkupText(LabelledString):
if tuple(map(int, pango_version.split("."))) < (1, 50):
if self.lsh is not None:
log.warning(
f"Pango version {pango_version} found (< 1.50), "
"unable to set `line_height` attribute"
"Pango version %s found (< 1.50), "
"unable to set `line_height` attribute",
pango_version
)
else:
line_spacing_scale = self.lsh or DEFAULT_LINE_SPACING_SCALE
@ -477,8 +478,8 @@ class MarkupText(LabelledString):
if not self.is_markup:
return [], []
tag_pattern = r"""<(/?)(\w+)\s*((\w+\s*\=\s*(['"])[\s\S]*?\5\s*)*)>"""
attr_pattern = r"""(\w+)\s*\=\s*(['"])([\s\S]*?)\2"""
tag_pattern = r"<(/?)(\w+)\s*((\w+\s*\=\s*(['\x22])[\s\S]*?\5\s*)*)>"
attr_pattern = r"(\w+)\s*\=\s*(['\x22])([\s\S]*?)\2"
begin_match_obj_stack = []
markup_tag_items = []
for match_obj in re.finditer(tag_pattern, self.string):
@ -511,7 +512,7 @@ class MarkupText(LabelledString):
return tag_span_pairs, internal_items
def get_external_items(self) -> list[tuple[Span, dict[str, str]]]:
return [
result = [
(self.full_span, self.get_global_attr_dict()),
(self.full_span, self.global_config),
*[
@ -531,6 +532,17 @@ class MarkupText(LabelledString):
for span in self.find_spans_by_selector(selector)
]
]
if self.split_words:
# For backward compatibility
result.extend([
(span, {})
for span in self.find_spans(r"[a-zA-Z]+")
for pattern in (r"[a-zA-Z]+", r"\S+")
])
return result
#def get_label_span_list(self, split_spans: list[Span]) -> list[Span]:
#def get_spans_from_items(
# self, specified_items: list[tuple[Span, dict[str, str]]]
@ -546,31 +558,31 @@ class MarkupText(LabelledString):
# for span in self.split_span(specified_span)
# ]
def get_label_span_list(self, split_spans: list[Span]) -> list[Span]:
interval_spans = sorted(self.chain(
self.tag_spans,
[
(index, index)
for span in split_spans
for index in span
]
))
text_spans = self.get_complement_spans(self.full_span, interval_spans)
if self.is_markup:
pattern = r"[0-9a-zA-Z]+|(?:&[\s\S]*?;|[^0-9a-zA-Z\s])+"
else:
pattern = r"[0-9a-zA-Z]+|[^0-9a-zA-Z\s]+"
return self.chain(*[
self.find_spans(pattern, pos=span_begin, endpos=span_end)
for span_begin, span_end in text_spans
])
#def get_label_span_list(self, split_spans: list[Span]) -> list[Span]:
# interval_spans = sorted(self.chain(
# self.tag_spans,
# [
# (index, index)
# for span in split_spans
# for index in span
# ]
# ))
# text_spans = self.get_complement_spans(self.full_span, interval_spans)
# if self.is_markup:
# pattern = r"[0-9a-zA-Z]+|(?:&[\s\S]*?;|[^0-9a-zA-Z\s])+"
# else:
# pattern = r"[0-9a-zA-Z]+|[^0-9a-zA-Z\s]+"
# return self.chain(*[
# self.find_spans(pattern, pos=span_begin, endpos=span_end)
# for span_begin, span_end in text_spans
# ])
def get_additional_inserted_str_pairs(
self
) -> list[tuple[Span, tuple[str, str]]]:
return []
#def get_additional_inserted_str_pairs(
# self
#) -> list[tuple[Span, tuple[str, str]]]:
# return []
def get_command_repl_items(self, is_labelled: bool) -> list[Span, str]:
def get_command_repl_items(self) -> list[Span, str]:
result = [
(tag_span, "") for tag_span in self.tag_spans
]
@ -755,8 +767,8 @@ class MarkupText(LabelledString):
# for span, attr_dict in attr_dict_items
# ]
def get_content(self, is_labelled: bool) -> str:
return self.decorated_strings[is_labelled]
def get_full_content_string(self, content_string: str, is_labelled: bool) -> str:
return content_string
# Selector