diff --git a/plugins/beetslabels.py b/plugins/beetslabels.py index 7c1093c..e40086f 100644 --- a/plugins/beetslabels.py +++ b/plugins/beetslabels.py @@ -8,6 +8,16 @@ from beets.dbcore import FieldQuery import optparse import json + +from beets.dbcore import queryparse, query +from beets.dbcore.query import Sort, FieldQueryType, SlowFieldSort +from beets.dbcore.queryparse import query_from_strings, sort_from_strings + +from beets.library import LibModel +Prefixes = dict[str, FieldQueryType] + +import re + LABELS_FIELD_NAME = "mylabels" def do_modify_labels(lib, objs, label, action): @@ -176,6 +186,103 @@ class HasLabelQuery(FieldQuery): return label in labels and str(labels[label]) == value return False +class LabelValueSort(SlowFieldSort): + def __init__(self, field, ascending=True, case_insensitive=True): + super().__init__(field, ascending, case_insensitive) + + # Extract the label key from the field name + # Field format: "label:" or just "labels" + if field.startswith("label:"): + self.label_key = field[len("label:"):] + else: + self.label_key = None + + def sort(self, objs): + def key(obj): + labels_json = obj.get(LABELS_FIELD_NAME, None) + + if not labels_json: + # No labels, return minimum value for sorting + return float('-inf') if self.ascending else float('inf') + + labels = json.loads(labels_json) + + if self.label_key: + value = labels.get(self.label_key, float('-inf') if self.ascending else float('inf')) + else: + value = max(labels.values()) if labels else (float('-inf') if self.ascending else float('inf')) + + return value + + return sorted(objs, key=key, reverse=not self.ascending) + +class LabelValueSortsDict(dict): + """Custom dict that returns LabelValueSort for any label:* key.""" + + def __missing__(self, key): + if re.match(r"^label[s:]", key): + return LabelValueSort + raise KeyError(key) + + def get(self, key, default=None): + try: + return self[key] + except KeyError: + return default + +# Store the original parse_sorted_query function +_original_parse_sorted_query = queryparse.parse_sorted_query + +def parse_sorted_query_override( + model_cls: type[LibModel], + parts: list[str], + prefixes: Prefixes = {}, + case_insensitive: bool = True, +) -> tuple[query.Query, Sort]: + """Given a list of strings, create the `Query` and `Sort` that they + represent. + """ + # Separate query token and sort token. + query_parts = [] + sort_parts = [] + + # Split up query in to comma-separated subqueries, each representing + # an AndQuery, which need to be joined together in one OrQuery + subquery_parts = [] + for part in parts + [","]: + if part.endswith(","): + # Ensure we can catch "foo, bar" as well as "foo , bar" + last_subquery_part = part[:-1] + if last_subquery_part: + subquery_parts.append(last_subquery_part) + # Parse the subquery in to a single AndQuery + # TODO: Avoid needlessly wrapping AndQueries containing 1 subquery? + query_parts.append( + query_from_strings( + query.AndQuery, model_cls, prefixes, subquery_parts + ) + ) + del subquery_parts[:] + else: + # Sort parts (1) end in + or -, (2) don't have a field, and + # (3) consist of more than just the + or -. + if part.startswith("label:"): + if part.endswith(("+", "-")): + sort_parts.append(part) + part = part[0:-1] + + subquery_parts.append(part) + else: + if part.endswith(("+", "-")) and ":" not in part and len(part) > 1: + sort_parts.append(part) + else: + subquery_parts.append(part) + + # Avoid needlessly wrapping single statements in an OR + q = query.OrQuery(query_parts) if len(query_parts) > 1 else query_parts[0] + s = sort_from_strings(model_cls, sort_parts, case_insensitive) + return q, s + class BeetsLabelsPlugin(BeetsPlugin): # item_queries = { 'has_label': DelimitedHasExact } # item_types = {'mylabels': types.SEMICOLON_SPACE_DSV} @@ -184,6 +291,20 @@ class BeetsLabelsPlugin(BeetsPlugin): super().__init__() self.item_queries = {"label": HasLabelQuery} + # Register the custom sort for label values + # Import Item here to avoid circular imports + from beets.library import Item + from beets import dbcore + + # Replace Item._sorts with our custom dict that handles label:* patterns + original_sorts = Item._sorts.copy() if isinstance(Item._sorts, dict) else {} + new_sorts = LabelValueSortsDict(original_sorts) + Item._sorts = new_sorts + + # Monkey patch parse_sorted_query globally + # Need to patch both queryparse and dbcore since dbcore imports it + queryparse.parse_sorted_query = parse_sorted_query_override + dbcore.parse_sorted_query = parse_sorted_query_override def commands(self): return [labels_command]