Add back base_color attribute

This commit is contained in:
YishiMichael 2022-04-15 13:27:50 +08:00
parent 4c324767bd
commit 020bd87271
No known key found for this signature in database
GPG key ID: EC615C0C5A86BC80
3 changed files with 20 additions and 15 deletions

View file

@ -36,18 +36,15 @@ class LabelledString(SVGMobject, ABC):
"should_subdivide_sharp_curves": True, "should_subdivide_sharp_curves": True,
"should_remove_null_curves": True, "should_remove_null_curves": True,
}, },
"base_color": WHITE,
"isolate": [], "isolate": [],
} }
def __init__(self, string: str, **kwargs): def __init__(self, string: str, **kwargs):
self.string = string self.string = string
digest_config(self, kwargs) digest_config(self, kwargs)
if self.base_color is None:
self.base_color_int = self.color_to_int( self.base_color = WHITE
self.svg_default.get("fill_color") \
or self.svg_default.get("color") \
or WHITE
)
self.pre_parse() self.pre_parse()
self.parse() self.parse()
@ -68,7 +65,8 @@ class LabelledString(SVGMobject, ABC):
def generate_mobject(self) -> None: def generate_mobject(self) -> None:
super().generate_mobject() super().generate_mobject()
if self.label_span_list: num_labels = len(self.label_span_list)
if num_labels:
file_path = self.get_file_path_(use_plain_file=False) file_path = self.get_file_path_(use_plain_file=False)
labelled_svg = SVGMobject(file_path) labelled_svg = SVGMobject(file_path)
submob_color_ints = [ submob_color_ints = [
@ -85,7 +83,7 @@ class LabelledString(SVGMobject, ABC):
) )
unrecognized_color_ints = remove_list_redundancies(sorted(filter( unrecognized_color_ints = remove_list_redundancies(sorted(filter(
lambda color_int: color_int > len(self.label_span_list), lambda color_int: color_int > num_labels,
submob_color_ints submob_color_ints
))) )))
if unrecognized_color_ints: if unrecognized_color_ints:
@ -100,6 +98,7 @@ class LabelledString(SVGMobject, ABC):
def pre_parse(self) -> None: def pre_parse(self) -> None:
self.string_len = len(self.string) self.string_len = len(self.string)
self.full_span = (0, self.string_len) self.full_span = (0, self.string_len)
self.base_color_int = self.color_to_int(self.base_color)
def parse(self) -> None: def parse(self) -> None:
self.command_repl_items = self.get_command_repl_items() self.command_repl_items = self.get_command_repl_items()
@ -311,7 +310,7 @@ class LabelledString(SVGMobject, ABC):
self.extra_entity_spans self.extra_entity_spans
)) ))
def index_not_in_entity_spans(self, index: int) -> bool: def is_splittable_index(self, index: int) -> bool:
return not any([ return not any([
entity_span[0] < index < entity_span[1] entity_span[0] < index < entity_span[1]
for entity_span in self.entity_spans for entity_span in self.entity_spans
@ -348,12 +347,16 @@ class LabelledString(SVGMobject, ABC):
self.external_specified_spans, self.external_specified_spans,
self.find_substrs(self.isolate) self.find_substrs(self.isolate)
)) ))
shrinked_spans = list(filter( filtered_spans = list(filter(
lambda span: span[0] < span[1] and all([ lambda span: all([
self.index_not_in_entity_spans(index) self.is_splittable_index(index)
for index in span for index in span
]), ]),
[self.shrink_span(span) for span in spans] spans
))
shrinked_spans = list(filter(
lambda span: span[0] < span[1],
[self.shrink_span(span) for span in filtered_spans]
)) ))
return remove_list_redundancies(shrinked_spans) return remove_list_redundancies(shrinked_spans)

View file

@ -47,6 +47,7 @@ class MTex(LabelledString):
self.__class__.__name__, self.__class__.__name__,
self.svg_default, self.svg_default,
self.path_string_config, self.path_string_config,
self.base_color,
self.isolate, self.isolate,
self.tex_string, self.tex_string,
self.alignment, self.alignment,

View file

@ -161,6 +161,7 @@ class MarkupText(LabelledString):
self.__class__.__name__, self.__class__.__name__,
self.svg_default, self.svg_default,
self.path_string_config, self.path_string_config,
self.base_color,
self.isolate, self.isolate,
self.text, self.text,
self.is_markup, self.is_markup,
@ -452,7 +453,7 @@ class MarkupText(LabelledString):
self.specified_spans self.specified_spans
)))) ))))
breakup_indices = sorted(filter( breakup_indices = sorted(filter(
self.index_not_in_entity_spans, breakup_indices self.is_splittable_index, breakup_indices
)) ))
return list(filter( return list(filter(
lambda span: self.get_substr(span).strip(), lambda span: self.get_substr(span).strip(),
@ -462,7 +463,7 @@ class MarkupText(LabelledString):
def get_content(self, use_plain_file: bool) -> str: def get_content(self, use_plain_file: bool) -> str:
filtered_attr_dicts = list(filter( filtered_attr_dicts = list(filter(
lambda item: all([ lambda item: all([
self.index_not_in_entity_spans(index) self.is_splittable_index(index)
for index in item[0] for index in item[0]
]), ]),
self.predefined_attr_dicts self.predefined_attr_dicts