# beets-config/plugins/beetslabels.py
#
# 310 lines
# 9.6 KiB
# Python

from beets import ui
from beets.plugins import BeetsPlugin
from beets.ui import Subcommand, print_
from beets.ui.commands import print_and_modify
from beets.dbcore import types
from beets.dbcore import FieldQuery
import optparse
import json
from beets.dbcore import queryparse, query
from beets.dbcore.query import Sort, FieldQueryType, SlowFieldSort
from beets.dbcore.queryparse import query_from_strings, sort_from_strings
from beets.library import LibModel
# Type alias used to annotate the `prefixes` argument of
# parse_sorted_query_override: a mapping from prefix string to FieldQueryType.
Prefixes = dict[str, FieldQueryType]
import re
# Name of the item field holding the labels. The field value is a JSON-encoded
# object (read with json.loads, written with json.dumps) keyed by label name.
LABELS_FIELD_NAME = "mylabels"
def do_modify_labels(lib, objs, label, action):
    """Add, remove, or clear labels on each object after user confirmation.

    Labels are stored in the ``LABELS_FIELD_NAME`` field as a JSON object
    keyed by label name.

    Args:
        lib: Library object; its ``transaction()`` wraps the final writes.
        objs: Iterable of objects to modify.
        label: For action 0 either ``"name"`` (stored with value 0) or
            ``"name:value"`` (value converted with ``int``). For action 1
            the label key to delete (the whole string is used as the key).
            Not used for action 2.
        action: 0 adds/updates ``label``, 1 removes it, 2 drops all labels.

    Raises:
        ValueError: for action 0 when the value part of ``"name:value"`` is
            not an integer. This is raised before any object is touched.
    """
    # The label is the same for every object, so parse it once up front.
    new_key, new_value = None, 0
    if action == 0:
        name_and_value = label.split(":")
        new_key = name_and_value[0]
        if len(name_and_value) > 1:
            new_value = int(name_and_value[1])

    changed = []
    # Modification computed for each changed object, keyed by id(obj); the
    # objects stay alive in `changed`, so the ids remain valid.
    mods_by_obj = {}
    for obj in objs:
        labels = {}
        if action != 2:
            # A missing or empty field means "no labels yet".
            if LABELS_FIELD_NAME in obj and obj[LABELS_FIELD_NAME]:
                labels = json.loads(obj[LABELS_FIELD_NAME])
            if action == 0:
                labels[new_key] = new_value
            elif label in labels:
                del labels[label]
        obj_mods = {LABELS_FIELD_NAME: json.dumps(labels)}
        if print_and_modify(obj, obj_mods, []) and obj not in changed:
            changed.append(obj)
            mods_by_obj[id(obj)] = obj_mods

    # Still something to do?
    if not changed:
        print_("No changes to make.")
        return

    # Confirm action. The preview callback replays the modification that was
    # recorded for the object (the previous code referenced the undefined
    # names `mods` and `dels` and raised NameError when it was invoked).
    changed = ui.input_select_objects(
        "Really modify",
        changed,
        lambda o: print_and_modify(o, mods_by_obj[id(o)], []),
    )

    # Apply changes to database and files
    with lib.transaction():
        for obj in changed:
            obj.try_sync(True, False, False)
# Numeric code for each action name accepted as the first argument of the
# `labels` subcommand; the code is the position of the name in this tuple.
action_map = {
    name: code
    for code, name in enumerate(
        ("add", "remove", "removeall", "show", "transfer")
    )
}
def modify_labels(lib, opts, args):
    """Entry point of the ``labels`` subcommand.

    ``args[0]`` selects the action (a key of ``action_map``):

    * ``add <label> [query...]`` / ``remove <label> [query...]``
    * ``removeall [query...]`` / ``show [query...]``
    * ``transfer ...`` -- the remaining arguments go to ``transfer_labels``.

    Problems with the arguments are reported with ``print_`` and the function
    returns without doing anything.

    Args:
        lib: Library object queried with ``lib.items``.
        opts: Parsed options; only forwarded to ``transfer_labels``.
        args: Positional command-line arguments as described above.
    """
    if not args:
        # Previously this fell through to args[0] and raised IndexError.
        print_("No action given.")
        print_("Valid actions are: add, remove, removeall, show, transfer")
        return

    action = args[0]
    if action not in action_map:
        print_("%s is not a valid action. " % action)
        print_("Valid actions are: add, remove, removeall, show, transfer")
        return

    actnum = action_map[action]
    if actnum == 4:  # transfer
        transfer_labels(lib, opts, args[1:])
        return

    if actnum >= 2:
        # removeall / show take no label, only a query.
        label = None
        query_parts = args[1:]
    else:
        if len(args) < 2:
            # add / remove need a label; previously args[1] raised IndexError.
            print_("The %s action requires a label." % action)
            return
        label = args[1]
        query_parts = args[2:]

    items = lib.items(query=query_parts)
    if actnum == 3:
        # show: print "<title>: key:value;key:value" for items with the field.
        for obj in items:
            if LABELS_FIELD_NAME in obj:
                labels = json.loads(obj[LABELS_FIELD_NAME])
                labelstr = ";".join(
                    f"{key}:{value}" for key, value in labels.items()
                )
                print_(f"{obj.title}: {labelstr}")
    else:
        do_modify_labels(lib, items, label, actnum)
def transfer_labels(lib, opts, args):
    """Copy the labels field from one item to another, after confirmation.

    ``args[0]`` is the source query and ``args[1]`` the destination query;
    each is passed as-is to ``lib.items(query=...)`` and must match exactly
    one item. Any labels already on the destination are overwritten. All
    problems are reported with ``print_`` and abort the transfer.

    Args:
        lib: Library object used for the queries and the write transaction.
        opts: Parsed options (not used here).
        args: ``[source_query, dest_query]``; extra entries are ignored.
    """
    if len(args) < 2:
        # Previously args[0] / args[1] raised IndexError here.
        print_("transfer requires a source query and a destination query.")
        return

    source_query, dest_query = args[0], args[1]
    source_list = list(lib.items(query=source_query))
    dest_list = list(lib.items(query=dest_query))

    if not source_list:
        print_("No source items found.")
        return
    if not dest_list:
        print_("No destination items found.")
        return
    if len(source_list) > 1:
        print_("Multiple source items found. Please refine your query to select one source item.")
        for item in source_list:
            print_(f" - {item.artist} - {item.title} ({item.path})")
        return
    if len(dest_list) > 1:
        print_("Multiple destination items found. Please refine your query to select one destination item.")
        for item in dest_list:
            print_(f" - {item.artist} - {item.title} ({item.path})")
        return

    source = source_list[0]
    dest = dest_list[0]
    print_(f"Source: {source.artist} - {source.title}")
    print_(f"Dest: {dest.artist} - {dest.title}")
    print_("")

    if LABELS_FIELD_NAME not in source or not source[LABELS_FIELD_NAME]:
        print_("Source has no labels to transfer.")
        return

    labels = json.loads(source[LABELS_FIELD_NAME])
    labelstr = ";".join(f"{key}:{value}" for key, value in labels.items())
    print_(f"Labels to transfer: {labelstr}")

    if LABELS_FIELD_NAME in dest and dest[LABELS_FIELD_NAME]:
        dest_labels = json.loads(dest[LABELS_FIELD_NAME])
        dest_labelstr = ";".join(
            f"{key}:{value}" for key, value in dest_labels.items()
        )
        print_(f"WARNING: Destination already has labels: {dest_labelstr}")
        print_("These will be overwritten!")
    print_("")

    if not ui.input_yn("Transfer labels? (y/n)", True):
        return

    # The raw JSON string is copied verbatim, replacing the destination's.
    with lib.transaction():
        dest[LABELS_FIELD_NAME] = source[LABELS_FIELD_NAME]
        dest.try_sync(True, False, False)
    print_("Transfer complete.")
# The `labels` subcommand object returned by BeetsLabelsPlugin.commands();
# its `func` attribute points at modify_labels, which takes (lib, opts, args).
labels_command = Subcommand('labels', help='Add or remove labels')
labels_command.func = modify_labels
class HasLabelQuery(FieldQuery):
    """Query that matches items by label name, optionally with a value.

    Pattern forms:

    * ``name`` -- matches when the label ``name`` is present.
    * ``name.value`` -- matches when ``name`` is present and the string form
      of its stored value equals ``value``.
    """

    def __init__(self, _, pattern: str, __):
        # The first and third arguments supplied by the caller are ignored:
        # the query always targets LABELS_FIELD_NAME and passes False as the
        # third FieldQuery argument (presumably its `fast` flag -- confirm).
        super().__init__(LABELS_FIELD_NAME, pattern, False)

    @classmethod
    def value_match(cls, pattern, jsonstr):
        """Return True if the JSON label object ``jsonstr`` matches ``pattern``.

        Args:
            pattern: ``"name"`` or ``"name.value"``.
            jsonstr: JSON text of the labels object, or None/empty when the
                item has no labels (never matches).
        """
        if not jsonstr:
            return False
        labels = json.loads(jsonstr)

        # Split on the LAST dot so a pattern with several dots no longer
        # raises ValueError from tuple unpacking. Patterns with zero or one
        # dot behave exactly as before.
        label, dot, value = pattern.rpartition(".")
        if not dot:
            return pattern in labels
        return label in labels and str(labels[label]) == value
class LabelValueSort(SlowFieldSort):
    """Sort objects by a value taken from their JSON labels field.

    A field of the form ``label:<labelkey>`` sorts by the value stored under
    ``<labelkey>``; any other field name sorts by the largest value among all
    of the object's labels.
    """

    def __init__(self, field, ascending=True, case_insensitive=True):
        super().__init__(field, ascending, case_insensitive)
        # Field format: "label:<labelkey>" or just "labels"
        prefix = "label:"
        self.label_key = field[len(prefix):] if field.startswith(prefix) else None

    def sort(self, objs):
        """Return ``objs`` as a new list ordered by label value."""
        # Stand-in value for objects without labels (or without the requested
        # key): -inf when ascending, +inf when descending. Combined with
        # `reverse=not self.ascending` below, such objects are ordered ahead
        # of the others in both directions.
        fallback = float('-inf') if self.ascending else float('inf')

        def label_value(obj):
            raw = obj.get(LABELS_FIELD_NAME, None)
            if not raw:
                return fallback
            parsed = json.loads(raw)
            if self.label_key:
                return parsed.get(self.label_key, fallback)
            if parsed:
                return max(parsed.values())
            return fallback

        return sorted(objs, key=label_value, reverse=not self.ascending)
class LabelValueSortsDict(dict):
    """Dict that additionally resolves label sort keys to LabelValueSort.

    Keys stored in the dict behave normally. A key that is not stored but
    matches the regex ``^label[s:]`` yields the LabelValueSort class; any
    other missing key raises KeyError as usual.
    """

    def __missing__(self, key):
        # Called by dict.__getitem__ only when `key` is not stored.
        if re.match(r"^label[s:]", key) is None:
            raise KeyError(key)
        return LabelValueSort

    def get(self, key, default=None):
        # dict.get() bypasses __missing__, so route the lookup through
        # self[key] to make label keys resolvable here as well.
        try:
            found = self[key]
        except KeyError:
            found = default
        return found
# Store the original parse_sorted_query function
# NOTE(review): nothing in the visible code reads this reference again;
# presumably it is kept so the monkey patch can be undone -- confirm.
_original_parse_sorted_query = queryparse.parse_sorted_query
def parse_sorted_query_override(
    model_cls: type[LibModel],
    parts: list[str],
    prefixes: Prefixes = {},
    case_insensitive: bool = True,
) -> tuple[query.Query, Sort]:
    """Given a list of strings, create the `Query` and `Sort` that they
    represent.

    Parts are grouped into comma-separated subqueries; each group becomes an
    AndQuery, and several groups are combined in one OrQuery. A part starting
    with ``label:`` and ending in ``+`` or ``-`` is used both as a sort term
    and, with the sign removed, as a query term. `prefixes` is only passed
    through to `query_from_strings`.
    """
    and_queries = []
    sort_parts = []
    pending = []  # query terms of the subquery currently being collected

    # The extra "," closes the final subquery.
    for part in parts + [","]:
        if part.endswith(","):
            # Ensure we can catch "foo, bar" as well as "foo , bar"
            head = part[:-1]
            if head:
                pending.append(head)
            # Parse the subquery in to a single AndQuery
            # TODO: Avoid needlessly wrapping AndQueries containing 1 subquery?
            and_queries.append(
                query_from_strings(query.AndQuery, model_cls, prefixes, pending)
            )
            pending.clear()
            continue

        signed = part.endswith(("+", "-"))
        if part.startswith("label:"):
            # Label terms always filter; a trailing sign additionally sorts.
            if signed:
                sort_parts.append(part)
                part = part[:-1]
            pending.append(part)
        elif signed and ":" not in part and len(part) > 1:
            # Sort parts (1) end in + or -, (2) don't have a field, and
            # (3) consist of more than just the + or -.
            sort_parts.append(part)
        else:
            pending.append(part)

    # Avoid needlessly wrapping single statements in an OR
    if len(and_queries) > 1:
        q = query.OrQuery(and_queries)
    else:
        q = and_queries[0]
    s = sort_from_strings(model_cls, sort_parts, case_insensitive)
    return q, s
class BeetsLabelsPlugin(BeetsPlugin):
    """Plugin wiring: the ``label`` item query, label-value sorting, and the
    ``labels`` subcommand.

    Instantiating the plugin replaces ``Item._sorts`` and monkey-patches
    ``parse_sorted_query`` on both ``queryparse`` and ``beets.dbcore``.
    """

    # NOTE(review): commented-out alternatives below; presumably earlier
    # experiments -- confirm before deleting.
    # item_queries = { 'has_label': DelimitedHasExact }
    # item_types = {'mylabels': types.SEMICOLON_SPACE_DSV}

    def __init__(self):
        super().__init__()
        self.item_queries = {"label": HasLabelQuery}

        # Imported here rather than at module level to avoid circular imports.
        from beets.library import Item
        from beets import dbcore

        # Swap Item._sorts for a dict that also answers label:* lookups,
        # keeping whatever sorts were registered before.
        if isinstance(Item._sorts, dict):
            existing_sorts = Item._sorts.copy()
        else:
            existing_sorts = {}
        Item._sorts = LabelValueSortsDict(existing_sorts)

        # Patch both modules: dbcore holds its own reference to the function.
        for module in (queryparse, dbcore):
            module.parse_sorted_query = parse_sorted_query_override

    def commands(self):
        """Return the subcommands provided by this plugin."""
        return [labels_command]