Use eval to compute cost

This commit is contained in:
Benson Chu 2026-04-04 14:33:58 -05:00
parent a5ccea6b33
commit ca8d753cd0
2 changed files with 66 additions and 4 deletions

View file

@ -63,6 +63,7 @@ beetslabels:
# or: "label:a , label:b" # or: "label:a , label:b"
# in the future? concat: ["label:a", "label:b"] # in the future? concat: ["label:a", "label:b"]
query: "label:effortless" query: "label:effortless"
sort: "effortless + exercise"
smartplaylist: smartplaylist:
relative_to: ~/Music/0beets_playlists relative_to: ~/Music/0beets_playlists

View file

@ -5,6 +5,8 @@ from beets.dbcore.query import SlowFieldSort
from .labels import LABELS_FIELD_NAME from .labels import LABELS_FIELD_NAME
import json import json
import ast
import re
playlist_config = {} playlist_config = {}
@ -12,20 +14,72 @@ def initialize_playlists(playlists):
for playlist in playlists: for playlist in playlists:
pl_name = playlist["name"] pl_name = playlist["name"]
query = [playlist["query"], ",", f"label:{pl_name}"] query = [playlist["query"], ",", f"label:{pl_name}"]
playlist_config[pl_name] = query sort_expr = playlist.get("sort", None)
playlist_config[pl_name] = {
"query": query,
"sort": sort_expr
}
def valid_playlist(playlist_name): def valid_playlist(playlist_name):
return playlist_name in playlist_config return playlist_name in playlist_config
def expand_playlist_query(playlist_name): def expand_playlist_query(playlist_name):
print(playlist_config[playlist_name]) return playlist_config[playlist_name]["query"]
return playlist_config[playlist_name]
def get_sort_expression(playlist_name):
"""Get the sort expression for a playlist, or None if not defined."""
return playlist_config[playlist_name].get("sort")
def extract_label_names(expression):
"""Extract label names from a Python expression.
Returns a set of identifier names found in the expression.
"""
try:
tree = ast.parse(expression, mode='eval')
label_names = set()
for node in ast.walk(tree):
if isinstance(node, ast.Name):
label_names.add(node.id)
return label_names
except SyntaxError:
return set()
def evaluate_sort_expression(expression, labels):
"""Evaluate a sort expression with label values.
Args:
expression: Python expression like "effortless + exercise * 2"
labels: Dict of label names to values
Returns:
The evaluated result, or a default value if evaluation fails.
"""
if not expression:
return 0
# Extract label names from the expression
label_names = extract_label_names(expression)
# Build a namespace with label values (defaulting to 0 for missing labels)
namespace = {name: labels.get(name, 0) for name in label_names}
try:
# Evaluate with restricted builtins for safety
result = eval(expression, {"__builtins__": {}}, namespace)
return float(result) if result is not None else 0
except Exception:
# If evaluation fails, return 0
return 0
class PlaylistValueSort(SlowFieldSort): class PlaylistValueSort(SlowFieldSort):
def __init__(self, field, ascending=True, case_insensitive=True): def __init__(self, field, ascending=True, case_insensitive=True):
super().__init__(field, ascending, case_insensitive) super().__init__(field, ascending, case_insensitive)
self.playlist_key = field[len("playlist:"):] self.playlist_key = field[len("playlist:"):]
self.sort_expr = get_sort_expression(self.playlist_key)
def sort(self, objs): def sort(self, objs):
def key(obj): def key(obj):
@ -37,12 +91,19 @@ class PlaylistValueSort(SlowFieldSort):
labels = json.loads(labels_json) labels = json.loads(labels_json)
# If there's a sort expression, evaluate it
if self.sort_expr:
return evaluate_sort_expression(self.sort_expr, labels)
# Otherwise, use the old behavior:
# Check if there's a label with the playlist name
if self.playlist_key in labels: if self.playlist_key in labels:
return labels[self.playlist_key] return labels[self.playlist_key]
# Sum all matching label values from the playlist query
matching_labels = \ matching_labels = \
[label for label in labels [label for label in labels
if label in playlist_config[self.playlist_key]] if label in playlist_config[self.playlist_key]["query"]]
if len(matching_labels) == 0: if len(matching_labels) == 0:
return float('-inf') if self.ascending else float('inf') return float('-inf') if self.ascending else float('inf')