I just wanted custom sorting behavior...
This commit is contained in:
parent
1da18c5e1b
commit
1edc0b270d
1 changed files with 121 additions and 0 deletions
|
|
@ -8,6 +8,16 @@ from beets.dbcore import FieldQuery
|
|||
|
||||
import optparse
|
||||
import json
|
||||
|
||||
from beets.dbcore import queryparse, query
|
||||
from beets.dbcore.query import Sort, FieldQueryType, SlowFieldSort
|
||||
from beets.dbcore.queryparse import query_from_strings, sort_from_strings
|
||||
|
||||
from beets.library import LibModel
|
||||
Prefixes = dict[str, FieldQueryType]
|
||||
|
||||
import re
|
||||
|
||||
LABELS_FIELD_NAME = "mylabels"
|
||||
|
||||
def do_modify_labels(lib, objs, label, action):
|
||||
|
|
@ -176,6 +186,103 @@ class HasLabelQuery(FieldQuery):
|
|||
return label in labels and str(labels[label]) == value
|
||||
return False
|
||||
|
||||
class LabelValueSort(SlowFieldSort):
|
||||
def __init__(self, field, ascending=True, case_insensitive=True):
|
||||
super().__init__(field, ascending, case_insensitive)
|
||||
|
||||
# Extract the label key from the field name
|
||||
# Field format: "label:<labelkey>" or just "labels"
|
||||
if field.startswith("label:"):
|
||||
self.label_key = field[len("label:"):]
|
||||
else:
|
||||
self.label_key = None
|
||||
|
||||
def sort(self, objs):
|
||||
def key(obj):
|
||||
labels_json = obj.get(LABELS_FIELD_NAME, None)
|
||||
|
||||
if not labels_json:
|
||||
# No labels, return minimum value for sorting
|
||||
return float('-inf') if self.ascending else float('inf')
|
||||
|
||||
labels = json.loads(labels_json)
|
||||
|
||||
if self.label_key:
|
||||
value = labels.get(self.label_key, float('-inf') if self.ascending else float('inf'))
|
||||
else:
|
||||
value = max(labels.values()) if labels else (float('-inf') if self.ascending else float('inf'))
|
||||
|
||||
return value
|
||||
|
||||
return sorted(objs, key=key, reverse=not self.ascending)
|
||||
|
||||
class LabelValueSortsDict(dict):
|
||||
"""Custom dict that returns LabelValueSort for any label:* key."""
|
||||
|
||||
def __missing__(self, key):
|
||||
if re.match(r"^label[s:]", key):
|
||||
return LabelValueSort
|
||||
raise KeyError(key)
|
||||
|
||||
def get(self, key, default=None):
|
||||
try:
|
||||
return self[key]
|
||||
except KeyError:
|
||||
return default
|
||||
|
||||
# Store the original parse_sorted_query function
|
||||
_original_parse_sorted_query = queryparse.parse_sorted_query
|
||||
|
||||
def parse_sorted_query_override(
|
||||
model_cls: type[LibModel],
|
||||
parts: list[str],
|
||||
prefixes: Prefixes = {},
|
||||
case_insensitive: bool = True,
|
||||
) -> tuple[query.Query, Sort]:
|
||||
"""Given a list of strings, create the `Query` and `Sort` that they
|
||||
represent.
|
||||
"""
|
||||
# Separate query token and sort token.
|
||||
query_parts = []
|
||||
sort_parts = []
|
||||
|
||||
# Split up query in to comma-separated subqueries, each representing
|
||||
# an AndQuery, which need to be joined together in one OrQuery
|
||||
subquery_parts = []
|
||||
for part in parts + [","]:
|
||||
if part.endswith(","):
|
||||
# Ensure we can catch "foo, bar" as well as "foo , bar"
|
||||
last_subquery_part = part[:-1]
|
||||
if last_subquery_part:
|
||||
subquery_parts.append(last_subquery_part)
|
||||
# Parse the subquery in to a single AndQuery
|
||||
# TODO: Avoid needlessly wrapping AndQueries containing 1 subquery?
|
||||
query_parts.append(
|
||||
query_from_strings(
|
||||
query.AndQuery, model_cls, prefixes, subquery_parts
|
||||
)
|
||||
)
|
||||
del subquery_parts[:]
|
||||
else:
|
||||
# Sort parts (1) end in + or -, (2) don't have a field, and
|
||||
# (3) consist of more than just the + or -.
|
||||
if part.startswith("label:"):
|
||||
if part.endswith(("+", "-")):
|
||||
sort_parts.append(part)
|
||||
part = part[0:-1]
|
||||
|
||||
subquery_parts.append(part)
|
||||
else:
|
||||
if part.endswith(("+", "-")) and ":" not in part and len(part) > 1:
|
||||
sort_parts.append(part)
|
||||
else:
|
||||
subquery_parts.append(part)
|
||||
|
||||
# Avoid needlessly wrapping single statements in an OR
|
||||
q = query.OrQuery(query_parts) if len(query_parts) > 1 else query_parts[0]
|
||||
s = sort_from_strings(model_cls, sort_parts, case_insensitive)
|
||||
return q, s
|
||||
|
||||
class BeetsLabelsPlugin(BeetsPlugin):
|
||||
# item_queries = { 'has_label': DelimitedHasExact }
|
||||
# item_types = {'mylabels': types.SEMICOLON_SPACE_DSV}
|
||||
|
|
@ -184,6 +291,20 @@ class BeetsLabelsPlugin(BeetsPlugin):
|
|||
super().__init__()
|
||||
self.item_queries = {"label": HasLabelQuery}
|
||||
|
||||
# Register the custom sort for label values
|
||||
# Import Item here to avoid circular imports
|
||||
from beets.library import Item
|
||||
from beets import dbcore
|
||||
|
||||
# Replace Item._sorts with our custom dict that handles label:* patterns
|
||||
original_sorts = Item._sorts.copy() if isinstance(Item._sorts, dict) else {}
|
||||
new_sorts = LabelValueSortsDict(original_sorts)
|
||||
Item._sorts = new_sorts
|
||||
|
||||
# Monkey patch parse_sorted_query globally
|
||||
# Need to patch both queryparse and dbcore since dbcore imports it
|
||||
queryparse.parse_sorted_query = parse_sorted_query_override
|
||||
dbcore.parse_sorted_query = parse_sorted_query_override
|
||||
|
||||
def commands(self):
|
||||
return [labels_command]
|
||||
|
|
|
|||
Loading…
Reference in a new issue