Use eval to compute cost
This commit is contained in:
parent
a5ccea6b33
commit
ca8d753cd0
2 changed files with 66 additions and 4 deletions
|
|
@ -63,6 +63,7 @@ beetslabels:
|
||||||
# or: "label:a , label:b"
|
# or: "label:a , label:b"
|
||||||
# in the future? concat: ["label:a", "label:b"]
|
# in the future? concat: ["label:a", "label:b"]
|
||||||
query: "label:effortless"
|
query: "label:effortless"
|
||||||
|
sort: "effortless + exercise"
|
||||||
|
|
||||||
smartplaylist:
|
smartplaylist:
|
||||||
relative_to: ~/Music/0beets_playlists
|
relative_to: ~/Music/0beets_playlists
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,8 @@ from beets.dbcore.query import SlowFieldSort
|
||||||
from .labels import LABELS_FIELD_NAME
|
from .labels import LABELS_FIELD_NAME
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
import ast
|
||||||
|
import re
|
||||||
|
|
||||||
playlist_config = {}
|
playlist_config = {}
|
||||||
|
|
||||||
|
|
@ -12,20 +14,72 @@ def initialize_playlists(playlists):
|
||||||
for playlist in playlists:
|
for playlist in playlists:
|
||||||
pl_name = playlist["name"]
|
pl_name = playlist["name"]
|
||||||
query = [playlist["query"], ",", f"label:{pl_name}"]
|
query = [playlist["query"], ",", f"label:{pl_name}"]
|
||||||
playlist_config[pl_name] = query
|
sort_expr = playlist.get("sort", None)
|
||||||
|
playlist_config[pl_name] = {
|
||||||
|
"query": query,
|
||||||
|
"sort": sort_expr
|
||||||
|
}
|
||||||
|
|
||||||
def valid_playlist(playlist_name):
|
def valid_playlist(playlist_name):
|
||||||
return playlist_name in playlist_config
|
return playlist_name in playlist_config
|
||||||
|
|
||||||
def expand_playlist_query(playlist_name):
|
def expand_playlist_query(playlist_name):
|
||||||
print(playlist_config[playlist_name])
|
return playlist_config[playlist_name]["query"]
|
||||||
return playlist_config[playlist_name]
|
|
||||||
|
def get_sort_expression(playlist_name):
|
||||||
|
"""Get the sort expression for a playlist, or None if not defined."""
|
||||||
|
return playlist_config[playlist_name].get("sort")
|
||||||
|
|
||||||
|
def extract_label_names(expression):
|
||||||
|
"""Extract label names from a Python expression.
|
||||||
|
|
||||||
|
Returns a set of identifier names found in the expression.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
tree = ast.parse(expression, mode='eval')
|
||||||
|
label_names = set()
|
||||||
|
|
||||||
|
for node in ast.walk(tree):
|
||||||
|
if isinstance(node, ast.Name):
|
||||||
|
label_names.add(node.id)
|
||||||
|
|
||||||
|
return label_names
|
||||||
|
except SyntaxError:
|
||||||
|
return set()
|
||||||
|
|
||||||
|
def evaluate_sort_expression(expression, labels):
|
||||||
|
"""Evaluate a sort expression with label values.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
expression: Python expression like "effortless + exercise * 2"
|
||||||
|
labels: Dict of label names to values
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The evaluated result, or a default value if evaluation fails.
|
||||||
|
"""
|
||||||
|
if not expression:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
# Extract label names from the expression
|
||||||
|
label_names = extract_label_names(expression)
|
||||||
|
|
||||||
|
# Build a namespace with label values (defaulting to 0 for missing labels)
|
||||||
|
namespace = {name: labels.get(name, 0) for name in label_names}
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Evaluate with restricted builtins for safety
|
||||||
|
result = eval(expression, {"__builtins__": {}}, namespace)
|
||||||
|
return float(result) if result is not None else 0
|
||||||
|
except Exception:
|
||||||
|
# If evaluation fails, return 0
|
||||||
|
return 0
|
||||||
|
|
||||||
class PlaylistValueSort(SlowFieldSort):
|
class PlaylistValueSort(SlowFieldSort):
|
||||||
def __init__(self, field, ascending=True, case_insensitive=True):
|
def __init__(self, field, ascending=True, case_insensitive=True):
|
||||||
super().__init__(field, ascending, case_insensitive)
|
super().__init__(field, ascending, case_insensitive)
|
||||||
|
|
||||||
self.playlist_key = field[len("playlist:"):]
|
self.playlist_key = field[len("playlist:"):]
|
||||||
|
self.sort_expr = get_sort_expression(self.playlist_key)
|
||||||
|
|
||||||
def sort(self, objs):
|
def sort(self, objs):
|
||||||
def key(obj):
|
def key(obj):
|
||||||
|
|
@ -37,12 +91,19 @@ class PlaylistValueSort(SlowFieldSort):
|
||||||
|
|
||||||
labels = json.loads(labels_json)
|
labels = json.loads(labels_json)
|
||||||
|
|
||||||
|
# If there's a sort expression, evaluate it
|
||||||
|
if self.sort_expr:
|
||||||
|
return evaluate_sort_expression(self.sort_expr, labels)
|
||||||
|
|
||||||
|
# Otherwise, use the old behavior:
|
||||||
|
# Check if there's a label with the playlist name
|
||||||
if self.playlist_key in labels:
|
if self.playlist_key in labels:
|
||||||
return labels[self.playlist_key]
|
return labels[self.playlist_key]
|
||||||
|
|
||||||
|
# Sum all matching label values from the playlist query
|
||||||
matching_labels = \
|
matching_labels = \
|
||||||
[label for label in labels
|
[label for label in labels
|
||||||
if label in playlist_config[self.playlist_key]]
|
if label in playlist_config[self.playlist_key]["query"]]
|
||||||
|
|
||||||
if len(matching_labels) == 0:
|
if len(matching_labels) == 0:
|
||||||
return float('-inf') if self.ascending else float('inf')
|
return float('-inf') if self.ascending else float('inf')
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue