213 lines
7.2 KiB
Python
213 lines
7.2 KiB
Python
from beets import ui
|
|
from beets.plugins import BeetsPlugin
|
|
from beets.ui import print_
|
|
|
|
from beets.dbcore import types
|
|
from typing import Any
|
|
import json
|
|
|
|
from beets.dbcore import queryparse, query
|
|
from beets.dbcore.query import Sort, FieldQueryType
|
|
from beets.dbcore.queryparse import query_from_strings, sort_from_strings
|
|
|
|
from beets.library import LibModel
|
|
# Type alias for the ``prefixes`` argument of parse_sorted_query_override:
# str keys mapped to FieldQueryType values. The mapping is forwarded
# unchanged to beets' query_from_strings.
Prefixes = dict[str, FieldQueryType]
|
|
|
|
import re
|
|
|
|
from .labels import HasLabelQuery, labels_command, LABELS_FIELD_NAME, LabelValueSort
|
|
from .playlists import initialize_playlists, expand_playlist_query, valid_playlist, PlaylistValueSort
|
|
|
|
|
|
class KeyValueDelimitedString(types.Type):
    """Dictionary-valued field type persisted as ``key:value; key2:value2`` text.

    Keys are strings and values are integers. On reading, a pair with no
    ``:`` or with a non-integer value gets the value 1. Text beginning with
    ``{`` is first tried as JSON, for backward compatibility with the older
    storage format.
    """

    sql = "TEXT"
    query = query.SubstringQuery
    model_type = dict

    @property
    def null(self):
        # A fresh empty dict per access, so callers never share one instance.
        return {}

    def format(self, value: dict | None) -> str:
        """Render *value* as ``key:value`` pairs joined by ``"; "``, ordered by key.

        Empty or ``None`` values render as the empty string.
        """
        if not value:
            return ""
        pairs = [f"{key}:{count}" for key, count in sorted(value.items())]
        return "; ".join(pairs)

    def parse(self, string: str) -> dict:
        """Turn stored text (delimited pairs or legacy JSON) back into a dict."""
        if not string:
            return {}

        if string.startswith("{"):
            try:
                return json.loads(string)
            except ValueError:
                # json.JSONDecodeError subclasses ValueError; not valid JSON,
                # so fall back to the delimited format below.
                pass

        parsed = {}
        for chunk in map(str.strip, string.split("; ")):
            if not chunk:
                continue
            name, sep, raw = chunk.partition(":")
            if not sep:
                # Bare key without a value.
                parsed[chunk] = 1
                continue
            name = name.strip()
            try:
                parsed[name] = int(raw.strip())
            except ValueError:
                parsed[name] = 1
        return parsed

    def to_sql(self, model_value: Any) -> str:
        """Serialize *model_value* for the TEXT column (``None`` -> ``""``)."""
        return "" if model_value is None else self.format(model_value)

    def normalize(self, value: Any) -> dict:
        """Coerce *value* to a dict: dicts pass through, strings are parsed,
        anything else (including ``None``) becomes an empty dict."""
        if isinstance(value, dict):
            return value
        if isinstance(value, str):
            return self.parse(value)
        return {}
|
|
|
|
class LabelValueSortsDict(dict):
    """Sort registry that also resolves dynamic label and playlist keys.

    Explicit entries behave like a normal dict. For a missing key, names
    matching ``^label[s:]`` resolve to ``LabelValueSort`` and names matching
    ``^playlist:`` resolve to ``PlaylistValueSort``; anything else raises
    ``KeyError``. ``get`` is overridden because ``dict.get`` does not consult
    ``__missing__``.
    """

    def __missing__(self, key):
        if re.match(r"^label[s:]", key):
            sort_cls = LabelValueSort
        elif re.match(r"^playlist:", key):
            sort_cls = PlaylistValueSort
        else:
            raise KeyError(key)
        return sort_cls

    def get(self, key, default=None):
        # Route through __getitem__ so __missing__ participates in lookups.
        try:
            found = self[key]
        except KeyError:
            found = default
        return found
|
|
|
|
# Keep a reference to beets' original parse_sorted_query. It is captured at
# module import time, before BeetsLabelsPlugin.__init__ replaces
# queryparse.parse_sorted_query with parse_sorted_query_override, so the
# unpatched implementation stays reachable.
# NOTE(review): not referenced elsewhere in the visible code.
_original_parse_sorted_query = queryparse.parse_sorted_query
|
|
|
|
def _expand_playlist_macros(parts: list[str]) -> list[str]:
    """Expand ``sorted_playlist:NAME`` and ``playlist:NAME`` query tokens.

    * ``sorted_playlist:NAME`` becomes the playlist's query parts (from
      ``expand_playlist_query``) followed by a ``playlist:NAME-`` sort token.
    * ``playlist:NAME`` becomes the playlist's query parts only.

    A token is passed through unchanged when ``valid_playlist`` rejects the
    name, or when it is a ``playlist:`` token ending in ``+``/``-`` (a sort
    request, classified later by ``_is_sort_part``).
    """
    expanded: list[str] = []
    for part in parts:
        if part.startswith("sorted_playlist:"):
            name = part[len("sorted_playlist:"):]
            if valid_playlist(name):
                expanded.extend(expand_playlist_query(name))
                # Inject the sort on the playlist's ordering.
                expanded.append(f"playlist:{name}-")
                continue
        elif part.startswith("playlist:") and not part.endswith(("+", "-")):
            # Tokens ending in +/- (e.g. "playlist:foo-") are sorts, not names.
            name = part[len("playlist:"):]
            if valid_playlist(name):
                expanded.extend(expand_playlist_query(name))
                continue
        # Unknown playlist, sort token, or ordinary query term.
        expanded.append(part)
    return expanded


def _is_sort_part(part: str) -> bool:
    """Return True if *part* is a sort criterion rather than a query term.

    ``label:NAME`` and ``playlist:NAME`` tokens are sorts whenever they end in
    ``+`` or ``-`` (their sort field itself contains a colon). Any other token
    is a sort when it (1) ends in ``+`` or ``-``, (2) has no ``field:`` part,
    and (3) consists of more than just the ``+``/``-``.
    """
    if part.startswith(("label:", "playlist:")):
        return part.endswith(("+", "-"))
    return part.endswith(("+", "-")) and ":" not in part and len(part) > 1


def parse_sorted_query_override(
    model_cls: type[LibModel],
    parts: list[str],
    prefixes: Prefixes | None = None,
    case_insensitive: bool = True,
) -> tuple[query.Query, Sort]:
    """Given a list of strings, create the `Query` and `Sort` that they
    represent.

    Replacement for beets' ``queryparse.parse_sorted_query`` that first
    expands ``playlist:``/``sorted_playlist:`` macros and treats
    ``label:NAME+`` / ``playlist:NAME-`` tokens as sort criteria.

    Args:
        model_cls: the model class being queried.
        parts: query tokens. Comma-separated groups are OR-ed together;
            tokens within a group are AND-ed.
        prefixes: forwarded to ``query_from_strings``; ``None`` (the default)
            means an empty mapping.
        case_insensitive: forwarded to ``sort_from_strings``.

    Returns:
        ``(query, sort)``. The query is a single group's query when there is
        one group, otherwise an ``OrQuery`` of the groups.
    """
    if prefixes is None:
        prefixes = {}

    parts = _expand_playlist_macros(parts)

    query_parts: list[query.Query] = []
    sort_parts: list[str] = []
    subquery_parts: list[str] = []

    # The trailing "," sentinel flushes the final group.
    for part in parts + [","]:
        if part.endswith(","):
            # Ensure we can catch "foo, bar" as well as "foo , bar".
            last_subquery_part = part[:-1]
            if last_subquery_part:
                subquery_parts.append(last_subquery_part)
            # Parse the group into a single AndQuery.
            # TODO: Avoid needlessly wrapping AndQueries containing 1 subquery?
            query_parts.append(
                query_from_strings(
                    query.AndQuery, model_cls, prefixes, subquery_parts
                )
            )
            subquery_parts = []
        elif _is_sort_part(part):
            sort_parts.append(part)
        else:
            subquery_parts.append(part)

    # Avoid needlessly wrapping single statements in an OR.
    q = query.OrQuery(query_parts) if len(query_parts) > 1 else query_parts[0]
    s = sort_from_strings(model_cls, sort_parts, case_insensitive)
    return q, s
|
|
|
|
class BeetsLabelsPlugin(BeetsPlugin):
    """beets plugin providing label/playlist queries and sorts.

    On construction it:
    * registers ``HasLabelQuery`` as the ``label`` item query and
      ``KeyValueDelimitedString`` as the type of ``LABELS_FIELD_NAME``;
    * initializes playlists from this plugin's ``playlists`` config;
    * replaces ``Item._sorts`` with a ``LabelValueSortsDict``; and
    * monkey-patches ``parse_sorted_query`` in both ``queryparse`` and
      ``dbcore`` with ``parse_sorted_query_override``.
    """

    def __init__(self):
        super().__init__()
        self.item_queries = {"label": HasLabelQuery}
        self.item_types = {LABELS_FIELD_NAME: KeyValueDelimitedString()}

        # Register an empty default so an unconfigured "playlists" section
        # does not make confuse's .get() raise NotFoundError at plugin load.
        # User-supplied configuration still takes precedence over this default.
        # NOTE(review): assumes initialize_playlists accepts an empty dict.
        self.config.add({"playlists": {}})
        playlists_config = self.config["playlists"].get()
        initialize_playlists(playlists_config)

        # Imported here to avoid circular imports.
        from beets.library import Item
        from beets import dbcore

        # Swap Item._sorts for a dict that also resolves label:* and
        # playlist:* keys, preserving any existing entries.
        original_sorts = Item._sorts.copy() if isinstance(Item._sorts, dict) else {}
        Item._sorts = LabelValueSortsDict(original_sorts)

        # Patch both names: dbcore holds its own reference to
        # parse_sorted_query, so patching queryparse alone is not enough.
        queryparse.parse_sorted_query = parse_sorted_query_override
        dbcore.parse_sorted_query = parse_sorted_query_override

    def commands(self):
        """Return the CLI subcommands provided by this plugin."""
        return [labels_command]
|