904 lines
37 KiB
Python
904 lines
37 KiB
Python
#!/usr/bin/env python
|
|
# coding=utf-8
|
|
|
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
import ast
import builtins
import contextlib
import difflib
import operator
from collections.abc import Mapping
from importlib import import_module
from typing import Any, Callable, Dict, List, Optional

import numpy as np
import pandas as pd
|
|
|
|
|
|
class InterpreterError(ValueError):
    """Raised when the interpreter cannot evaluate a Python expression.

    Typical causes are a syntax error in the evaluated code or an operation
    that the interpreter does not support.
    """
|
|
|
|
|
|
# Every built-in exception class (BaseException and its subclasses), keyed by its
# name, so interpreted code can raise and catch them without importing anything.
ERRORS = {
    name: member
    for name, member in ((attr, getattr(builtins, attr)) for attr in dir(builtins))
    if isinstance(member, type) and issubclass(member, BaseException)
}
|
|
|
|
|
|
# Standard-library modules that interpreted code may import when no explicit
# `authorized_imports` list is supplied.
LIST_SAFE_MODULES = [
    "random", "collections", "math", "time", "queue",
    "itertools", "re", "stat", "statistics", "unicodedata",
]
|
|
|
|
# Text captured from interpreted `print` calls, and the number of characters
# above which it is truncated when reported back.
PRINT_OUTPUTS = ""
MAX_LEN_OUTPUT = 50_000

# Number of AST nodes evaluated in the current run, and the cap used to stop
# runaway computations such as infinite loops.
OPERATIONS_COUNT = 0
MAX_OPERATIONS = 10_000_000
|
|
|
|
|
|
class BreakException(Exception):
    """Internal signal raised for a `break` statement and caught by the enclosing loop."""
|
|
|
|
|
|
class ContinueException(Exception):
    """Internal signal raised for a `continue` statement and caught by the enclosing loop."""
|
|
|
|
|
|
class ReturnException(Exception):
    """Internal signal carrying the value of a `return` statement out of a function body."""

    def __init__(self, value):
        # Value given to `return` (None for a bare `return`).
        self.value = value
|
|
|
|
|
|
def get_iterable(obj):
    """Return ``obj`` as a list.

    A list is returned as-is (the same object). Any other object exposing an
    ``__iter__`` attribute is materialised with ``list()``.

    Raises:
        InterpreterError: if ``obj`` has no ``__iter__`` attribute.
    """
    if isinstance(obj, list):
        return obj
    if not hasattr(obj, "__iter__"):
        raise InterpreterError("Object is not iterable")
    return list(obj)
|
|
|
|
|
|
def evaluate_unaryop(expression, state, static_tools, custom_tools):
    """Evaluate an ``ast.UnaryOp`` node: ``-x``, ``+x``, ``not x`` or ``~x``.

    The operand is evaluated first, then the operator is applied with native
    Python semantics.

    Raises:
        InterpreterError: for an unknown unary operator node.
    """
    operand = evaluate_ast(expression.operand, state, static_tools, custom_tools)
    if isinstance(expression.op, ast.USub):
        return -operand
    elif isinstance(expression.op, ast.UAdd):
        # Apply unary plus instead of returning the operand untouched: some types give
        # `+x` its own meaning (Counter drops non-positive counts, Decimal rounds to
        # the current context).
        return +operand
    elif isinstance(expression.op, ast.Not):
        return not operand
    elif isinstance(expression.op, ast.Invert):
        return ~operand
    else:
        raise InterpreterError(f"Unary operation {expression.op.__class__.__name__} is not supported.")
|
|
|
|
|
|
def evaluate_lambda(lambda_expression, state, static_tools, custom_tools):
    """Turn an ``ast.Lambda`` node into a real Python callable.

    Default values are evaluated once, when the lambda expression itself is
    evaluated (as in Python). Each call runs the body against a copy of ``state``
    taken at call time and extended with the bound parameters. The body therefore
    sees variables assigned after the lambda was created, but cannot rebind the
    caller's variables.

    Supported parameters:
        * positional-or-keyword parameters, with or without defaults;
        * ``*args``;
        * ``**kwargs``.

    Calls with missing, surplus, duplicated or unknown arguments raise
    ``TypeError``, like a native function. They no longer silently leave a
    parameter unbound, which would make the body read an outer variable of the
    same name.
    """
    arguments = lambda_expression.args
    arg_names = [arg.arg for arg in arguments.args]
    vararg_name = arguments.vararg.arg if arguments.vararg else None
    kwarg_name = arguments.kwarg.arg if arguments.kwarg else None
    default_values = [evaluate_ast(d, state, static_tools, custom_tools) for d in arguments.defaults]
    # Defaults belong to the *last* len(defaults) parameters.
    defaults = dict(zip(arg_names[len(arg_names) - len(default_values) :], default_values))

    def lambda_func(*values, **keywords):
        bound = dict(zip(arg_names, values))
        surplus = values[len(arg_names) :]
        if vararg_name is not None:
            bound[vararg_name] = tuple(surplus)
        elif surplus:
            raise TypeError(
                f"<lambda>() takes {len(arg_names)} positional arguments but {len(values)} were given"
            )

        extra_keywords = {}
        for name, value in keywords.items():
            if name in arg_names:
                if name in bound:
                    raise TypeError(f"<lambda>() got multiple values for argument '{name}'")
                bound[name] = value
            elif kwarg_name is not None:
                extra_keywords[name] = value
            else:
                raise TypeError(f"<lambda>() got an unexpected keyword argument '{name}'")
        if kwarg_name is not None:
            bound[kwarg_name] = extra_keywords

        for name in arg_names:
            if name not in bound:
                if name not in defaults:
                    raise TypeError(f"<lambda>() missing required positional argument: '{name}'")
                bound[name] = defaults[name]

        new_state = state.copy()
        new_state.update(bound)
        return evaluate_ast(lambda_expression.body, new_state, static_tools, custom_tools)

    return lambda_func
|
|
|
|
|
|
def evaluate_while(while_loop, state, static_tools, custom_tools, max_iterations=1000):
    """Execute a ``while`` loop, including its optional ``else`` clause.

    ``break`` ends the loop and skips the ``else`` clause. ``continue`` skips the
    rest of the current iteration's body. The ``else`` clause runs once the test
    becomes falsy. Always returns None.

    Args:
        max_iterations: maximum number of completed iterations before the loop is
            aborted (default 1000). Guards against runaway loops.

    Raises:
        InterpreterError: when more than ``max_iterations`` iterations run.
    """
    iterations = 0
    while evaluate_ast(while_loop.test, state, static_tools, custom_tools):
        for node in while_loop.body:
            try:
                evaluate_ast(node, state, static_tools, custom_tools)
            except BreakException:
                # `break` leaves the loop and skips the `else` clause.
                return None
            except ContinueException:
                # Abandon the rest of this body; the test is re-evaluated next.
                break
        iterations += 1
        if iterations > max_iterations:
            raise InterpreterError(f"Maximum number of {max_iterations} iterations in While loop exceeded")

    # The test became falsy without a `break`: run the `else` clause, if any.
    for node in while_loop.orelse:
        evaluate_ast(node, state, static_tools, custom_tools)
    return None
|
|
|
|
|
|
def _bind_call_arguments(
    func_name, positional_names, kwonly_names, vararg_name, kwarg_name, defaults, args, kwargs
):
    """Map call arguments onto parameter names following Python's binding rules.

    Returns a dict of parameter name -> value. Raises ``TypeError`` for surplus,
    duplicated, unexpected or missing arguments, with messages modelled on CPython's.
    """
    bound = dict(zip(positional_names, args))
    surplus = args[len(positional_names) :]
    if vararg_name is not None:
        bound[vararg_name] = tuple(surplus)
    elif surplus:
        raise TypeError(
            f"{func_name}() takes {len(positional_names)} positional arguments but {len(args)} were given"
        )

    named_parameters = set(positional_names) | set(kwonly_names)
    extra_keywords = {}
    for name, value in kwargs.items():
        if name in named_parameters:
            if name in bound:
                raise TypeError(f"{func_name}() got multiple values for argument '{name}'")
            bound[name] = value
        elif kwarg_name is not None:
            extra_keywords[name] = value
        else:
            raise TypeError(f"{func_name}() got an unexpected keyword argument '{name}'")
    if kwarg_name is not None:
        bound[kwarg_name] = extra_keywords

    for name in positional_names + kwonly_names:
        if name not in bound:
            if name not in defaults:
                raise TypeError(f"{func_name}() missing required argument: '{name}'")
            bound[name] = defaults[name]
    return bound


def create_function(func_def, state, static_tools, custom_tools):
    """Build a Python callable that implements an ``ast.FunctionDef`` node.

    Default values are evaluated once, here, when the function is defined (as in
    Python). Each call binds its arguments with Python's rules: positional,
    keyword, defaults, ``*args``, ``**kwargs`` and keyword-only parameters.

    The body runs against a copy of ``state`` taken at call time, so the function
    sees variables defined after it but keeps its own locals. The call returns
    the value of an explicit ``return`` statement, or None when the body ends
    without one.

    Raises (at call time):
        TypeError: for missing, surplus, duplicated or unexpected arguments.
    """
    arguments = func_def.args
    positional_names = [arg.arg for arg in arguments.posonlyargs + arguments.args]
    kwonly_names = [arg.arg for arg in arguments.kwonlyargs]
    vararg_name = arguments.vararg.arg if arguments.vararg else None
    kwarg_name = arguments.kwarg.arg if arguments.kwarg else None

    default_values = [evaluate_ast(d, state, static_tools, custom_tools) for d in arguments.defaults]
    # `defaults` apply to the *last* len(defaults) positional parameters.
    defaults = dict(zip(positional_names[len(positional_names) - len(default_values) :], default_values))
    for name, default_node in zip(kwonly_names, arguments.kw_defaults):
        # kw_defaults holds None for keyword-only parameters that have no default.
        if default_node is not None:
            defaults[name] = evaluate_ast(default_node, state, static_tools, custom_tools)

    def new_func(*args, **kwargs):
        func_state = state.copy()
        # Bound parameters override outer variables of the same name.
        func_state.update(
            _bind_call_arguments(
                func_def.name, positional_names, kwonly_names, vararg_name, kwarg_name, defaults, args, kwargs
            )
        )
        # Methods: expose the instance's class as `__class__` for zero-argument super().
        if arguments.args and arguments.args[0].arg == "self" and args:
            func_state["__class__"] = args[0].__class__

        try:
            for stmt in func_def.body:
                evaluate_ast(stmt, func_state, static_tools, custom_tools)
        except ReturnException as e:
            return e.value
        # Falling off the end of a function returns None, as in Python.
        return None

    new_func.__name__ = func_def.name
    new_func.__qualname__ = func_def.name
    return new_func
|
|
|
|
|
|
def create_class(class_name, class_bases, class_body):
    """Create a new class with ``type(name, bases, namespace)``.

    Args:
        class_name: name of the new class.
        class_bases: iterable of base classes (converted to a tuple).
        class_body: mapping of attribute names to values. It is copied into a
            fresh namespace dict and is not mutated.

    Returns:
        The newly created class object.
    """
    namespace = {key: value for key, value in class_body.items()}
    bases = tuple(class_bases)
    return type(class_name, bases, namespace)
|
|
|
|
|
|
def evaluate_function_def(func_def, state, static_tools, custom_tools):
    """Define the function described by an ``ast.FunctionDef`` node.

    The callable built by ``create_function`` is registered in ``custom_tools``
    under the function's name (making it callable, and overwritable, by later
    code) and is also returned.
    """
    function = create_function(func_def, state, static_tools, custom_tools)
    custom_tools[func_def.name] = function
    return custom_tools[func_def.name]
|
|
|
|
|
|
def evaluate_class_def(class_def, state, static_tools, custom_tools):
    """Define the class described by an ``ast.ClassDef`` node and bind it in ``state``.

    Supported class-body statements:
        * method definitions;
        * plain and annotated assignments;
        * ``pass``;
        * bare expressions — a leading string literal becomes the class docstring.

    Expressions in the body can refer to names defined earlier in the same body,
    as in Python. Methods are built with ``create_function`` directly, so they do
    not leak into ``custom_tools`` as module-level functions.

    Raises:
        InterpreterError: for any other statement type in the class body.
    """
    class_name = class_def.name
    bases = [evaluate_ast(base, state, static_tools, custom_tools) for base in class_def.bases]
    class_dict = {}

    def evaluate_in_body(node):
        # Class-body expressions see module names plus attributes defined so far.
        return evaluate_ast(node, {**state, **class_dict}, static_tools, custom_tools)

    for position, stmt in enumerate(class_def.body):
        if isinstance(stmt, ast.FunctionDef):
            # Methods close over the module `state` (not the class body), as in Python.
            class_dict[stmt.name] = create_function(stmt, state, static_tools, custom_tools)
        elif isinstance(stmt, ast.Assign):
            # Evaluate once so `a = b = []` binds both names to the same object.
            value = evaluate_in_body(stmt.value)
            for target in stmt.targets:
                if isinstance(target, ast.Name):
                    class_dict[target.id] = value
                elif isinstance(target, ast.Attribute):
                    class_dict[target.attr] = value
        elif isinstance(stmt, ast.AnnAssign):
            # `x: int = 0`; a bare annotation without a value defines nothing.
            if stmt.value is not None and isinstance(stmt.target, ast.Name):
                class_dict[stmt.target.id] = evaluate_in_body(stmt.value)
        elif isinstance(stmt, ast.Pass):
            continue
        elif isinstance(stmt, ast.Expr):
            value = evaluate_in_body(stmt.value)
            if position == 0 and isinstance(stmt.value, ast.Constant) and isinstance(value, str):
                class_dict["__doc__"] = value
        else:
            raise InterpreterError(f"Unsupported statement in class body: {stmt.__class__.__name__}")

    new_class = type(class_name, tuple(bases), class_dict)
    state[class_name] = new_class
    return new_class
|
|
|
|
|
|
# In-place operator implementing each augmented-assignment node type. The in-place
# variants mirror Python: mutable objects (lists, numpy arrays, ...) are updated in
# place, so every alias of the object observes the change.
_AUGASSIGN_OPERATORS = {
    ast.Add: operator.iadd,
    ast.Sub: operator.isub,
    ast.Mult: operator.imul,
    ast.Div: operator.itruediv,
    ast.Mod: operator.imod,
    ast.Pow: operator.ipow,
    ast.FloorDiv: operator.ifloordiv,
    ast.BitAnd: operator.iand,
    ast.BitOr: operator.ior,
    ast.BitXor: operator.ixor,
    ast.LShift: operator.ilshift,
    ast.RShift: operator.irshift,
    ast.MatMult: operator.imatmul,
}


def evaluate_augassign(expression, state, static_tools, custom_tools):
    """Execute an augmented assignment such as ``x += 1`` or ``d[k] *= 2``.

    The target's container and key (or owner object) are evaluated exactly once,
    then the current value is read, the right-hand side is evaluated, and the
    in-place operator is applied before storing the result back.

    An unknown *name* target starts from 0 instead of raising NameError. This is
    existing, deliberate leniency.

    Returns:
        The updated value.

    Raises:
        InterpreterError:
            * for unsupported targets or operators;
            * when adding a non-list to a list.
    """
    target = expression.target
    if isinstance(target, ast.Name):
        current_value = state.get(target.id, 0)

        def store(new_value):
            # set_value keeps the guard against overwriting static tools.
            set_value(target, new_value, state, static_tools, custom_tools)

    elif isinstance(target, ast.Subscript):
        container = evaluate_ast(target.value, state, static_tools, custom_tools)
        key = evaluate_ast(target.slice, state, static_tools, custom_tools)
        current_value = container[key]

        def store(new_value):
            container[key] = new_value

    elif isinstance(target, ast.Attribute):
        owner = evaluate_ast(target.value, state, static_tools, custom_tools)
        current_value = getattr(owner, target.attr)

        def store(new_value):
            setattr(owner, target.attr, new_value)

    else:
        raise InterpreterError(f"AugAssign not supported for {type(target)} targets.")

    value_to_add = evaluate_ast(expression.value, state, static_tools, custom_tools)

    if isinstance(expression.op, ast.Add) and isinstance(current_value, list) and not isinstance(value_to_add, list):
        raise InterpreterError(f"Cannot add non-list value {value_to_add} to a list.")
    apply_operator = _AUGASSIGN_OPERATORS.get(type(expression.op))
    if apply_operator is None:
        raise InterpreterError(f"Operation {type(expression.op).__name__} is not supported.")

    updated_value = apply_operator(current_value, value_to_add)
    store(updated_value)
    return updated_value
|
|
|
|
|
|
def evaluate_boolop(node, state, static_tools, custom_tools):
    """Evaluate ``and`` / ``or`` with Python's short-circuit semantics.

    As in Python, the result is the operand value that decided the outcome, not
    a coerced bool. ``x or "default"`` yields ``"default"`` when ``x`` is falsy,
    and ``a and b`` yields ``b`` when ``a`` is truthy. Operands after the
    deciding one are not evaluated.

    Raises:
        InterpreterError: for an unknown boolean operator node.
    """
    if isinstance(node.op, ast.And):
        for operand in node.values:
            value = evaluate_ast(operand, state, static_tools, custom_tools)
            if not value:
                return value
        return value
    elif isinstance(node.op, ast.Or):
        for operand in node.values:
            value = evaluate_ast(operand, state, static_tools, custom_tools)
            if value:
                return value
        return value
    raise InterpreterError(f"Boolean operation {type(node.op).__name__} is not supported.")
|
|
|
|
|
|
# Function implementing each binary-operator node type (`a + b` is operator.add(a, b)).
_BINARY_OPERATORS = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.Div: operator.truediv,
    ast.Mod: operator.mod,
    ast.Pow: operator.pow,
    ast.FloorDiv: operator.floordiv,
    ast.BitAnd: operator.and_,
    ast.BitOr: operator.or_,
    ast.BitXor: operator.xor,
    ast.LShift: operator.lshift,
    ast.RShift: operator.rshift,
    ast.MatMult: operator.matmul,
}


def evaluate_binop(binop, state, static_tools, custom_tools):
    """Evaluate an ``ast.BinOp`` node.

    Both operands are evaluated, left first, then the operator is applied.
    Matrix multiplication (``a @ b``, e.g. numpy arrays) is supported alongside
    the arithmetic, bitwise and shift operators.

    Raises:
        NotImplementedError: for an operator node without a known implementation.
    """
    left_val = evaluate_ast(binop.left, state, static_tools, custom_tools)
    right_val = evaluate_ast(binop.right, state, static_tools, custom_tools)
    operation = _BINARY_OPERATORS.get(type(binop.op))
    if operation is None:
        raise NotImplementedError(f"Binary operation {type(binop.op).__name__} is not implemented.")
    return operation(left_val, right_val)
|
|
|
|
|
|
def evaluate_assign(assign, state, static_tools, custom_tools):
    """Execute an assignment statement and return the assigned value.

    The right-hand side is evaluated once, then bound to every target from left
    to right. Chained assignments such as ``a = b = 0`` therefore give every
    target the same object, as in Python. Destructuring of each target is
    handled by ``set_value``.
    """
    result = evaluate_ast(assign.value, state, static_tools, custom_tools)
    for target in assign.targets:
        set_value(target, result, state, static_tools, custom_tools)
    return result
|
|
|
|
|
|
def set_value(target, value, state, static_tools, custom_tools):
    """Assign ``value`` to an assignment-target node, as ``=`` would.

    Supported targets:
        * names;
        * tuple/list destructuring, possibly nested and with at most one starred
          element (e.g. ``first, *rest = items``; the starred name receives a list);
        * subscripts;
        * attributes.

    Non-tuple iterables are unpacked via ``tuple()``; ``str`` and ``bytes`` are
    deliberately refused.

    Raises:
        InterpreterError:
            * when assigning to the name of a static tool;
            * when the value cannot be unpacked into the target;
            * for unsupported target types.
    """
    if isinstance(target, ast.Name):
        if target.id in static_tools:
            raise InterpreterError(f"Cannot assign to name '{target.id}': doing this would erase the existing tool!")
        state[target.id] = value
    elif isinstance(target, (ast.Tuple, ast.List)):
        if not isinstance(value, tuple):
            if hasattr(value, "__iter__") and not isinstance(value, (str, bytes)):
                value = tuple(value)
            else:
                raise InterpreterError("Cannot unpack non-tuple value")
        elements = target.elts
        starred_positions = [i for i, elt in enumerate(elements) if isinstance(elt, ast.Starred)]
        if len(starred_positions) > 1:
            raise InterpreterError("Multiple starred expressions in assignment")
        if not starred_positions:
            if len(elements) != len(value):
                raise InterpreterError("Cannot unpack tuple of wrong size")
            for elem, item in zip(elements, value):
                set_value(elem, item, state, static_tools, custom_tools)
        else:
            star = starred_positions[0]
            n_after = len(elements) - star - 1
            if len(value) < len(elements) - 1:
                raise InterpreterError("Cannot unpack tuple of wrong size")
            split = len(value) - n_after
            for elem, item in zip(elements[:star], value[:star]):
                set_value(elem, item, state, static_tools, custom_tools)
            set_value(elements[star].value, list(value[star:split]), state, static_tools, custom_tools)
            for elem, item in zip(elements[star + 1 :], value[split:]):
                set_value(elem, item, state, static_tools, custom_tools)
    elif isinstance(target, ast.Subscript):
        obj = evaluate_ast(target.value, state, static_tools, custom_tools)
        key = evaluate_ast(target.slice, state, static_tools, custom_tools)
        obj[key] = value
    elif isinstance(target, ast.Attribute):
        obj = evaluate_ast(target.value, state, static_tools, custom_tools)
        setattr(obj, target.attr, value)
    else:
        # Previously a silent no-op, which hid failed assignments.
        raise InterpreterError(f"Assignment to {type(target).__name__} targets is not supported.")
|
|
|
|
|
|
def _call_builtin_super(args, state):
    """Emulate ``super(...)`` for interpreted code.

    With no arguments, uses the ``__class__`` and ``self`` entries of ``state``.
    Otherwise forwards one or two explicit arguments to the real ``super``.
    """
    if not args:
        if "__class__" in state and "self" in state:
            return super(state["__class__"], state["self"])
        raise InterpreterError("super() needs at least one argument")
    cls = args[0]
    if not isinstance(cls, type):
        raise InterpreterError("super() argument 1 must be type")
    if len(args) == 1:
        return super(cls)
    if len(args) == 2:
        return super(cls, args[1])
    raise InterpreterError("super() takes at most 2 arguments")


def evaluate_call(call, state, static_tools, custom_tools):
    """Evaluate a function or method call.

    The callee must be either:
        * a bare name, looked up in ``state``, then ``static_tools``, then
          ``custom_tools``, then the built-in exception classes; or
        * an attribute of an evaluated object.

    Each argument expression is evaluated exactly once, left to right;
    ``*iterable`` and ``**mapping`` arguments are expanded.

    Two bare names are intercepted:
        * ``super`` is emulated;
        * ``print`` appends its output, honouring ``sep`` and ``end``, to the
          module-level ``PRINT_OUTPUTS`` buffer instead of stdout.

    Every other call is a regular Python call, including class instantiation.

    Raises:
        InterpreterError:
            * if the callee is neither a name nor an attribute;
            * if the callee cannot be found;
            * if a ``*`` argument is not an unpackable iterable.
    """
    global PRINT_OUTPUTS

    if isinstance(call.func, ast.Attribute):
        obj = evaluate_ast(call.func.value, state, static_tools, custom_tools)
        func_name = call.func.attr
        if not hasattr(obj, func_name):
            raise InterpreterError(f"Object {obj} has no attribute {func_name}")
        func = getattr(obj, func_name)
    elif isinstance(call.func, ast.Name):
        func_name = call.func.id
        if func_name in state:
            func = state[func_name]
        elif func_name in static_tools:
            func = static_tools[func_name]
        elif func_name in custom_tools:
            func = custom_tools[func_name]
        elif func_name in ERRORS:
            func = ERRORS[func_name]
        else:
            raise InterpreterError(
                f"It is not permitted to evaluate other functions than the provided tools or functions defined in previous code (tried to execute {call.func.id})."
            )
    else:
        raise InterpreterError(f"This is not a correct function: {call.func}.")

    # Evaluate each argument exactly once (side effects such as `lst.pop()` must not repeat).
    args = []
    for arg in call.args:
        if isinstance(arg, ast.Starred):
            unpacked = evaluate_ast(arg.value, state, static_tools, custom_tools)
            if not hasattr(unpacked, "__iter__") or isinstance(unpacked, (str, bytes)):
                raise InterpreterError(f"Cannot unpack non-iterable value {unpacked}")
            args.extend(unpacked)
        else:
            args.append(evaluate_ast(arg, state, static_tools, custom_tools))

    kwargs = {}
    for keyword in call.keywords:
        value = evaluate_ast(keyword.value, state, static_tools, custom_tools)
        if keyword.arg is None:
            # `**mapping` in the call: merge the mapping into the keyword arguments.
            kwargs.update(value)
        else:
            kwargs[keyword.arg] = value

    # Only bare-name calls are intercepted; `obj.print(...)` calls the real method.
    called_by_name = isinstance(call.func, ast.Name)
    if called_by_name and func_name == "super":
        return _call_builtin_super(args, state)
    if called_by_name and func_name == "print":
        sep = kwargs.get("sep")
        end = kwargs.get("end")
        PRINT_OUTPUTS += (" " if sep is None else sep).join(map(str, args)) + ("\n" if end is None else end)
        return None
    # Regular call. Classes are instantiated normally so their own __new__ and
    # __init__ both run (e.g. pd.Index, pd.Timestamp need arguments in __new__).
    return func(*args, **kwargs)
|
|
|
|
|
|
def evaluate_subscript(subscript, state, static_tools, custom_tools):
    """Evaluate ``value[index]``.

    Dispatch:
        * pandas ``.loc`` indexers, pandas/numpy containers, groupbys and slice
          indices use native indexing;
        * lists, tuples and strings get a bounds check with a clearer error;
        * any other object is indexed through its own ``__getitem__``, so
          ``defaultdict``/``Counter`` defaults, ``range`` and ``deque`` indexing
          behave as in Python.

    When a mapping lookup with a string key fails, the closest-spelled existing
    key is used (``difflib.get_close_matches``). This is deliberate leniency
    towards typos in generated code.

    Raises:
        InterpreterError: when indexing fails (out-of-bounds index, missing key
            without a close match, unsubscriptable value, string index into a
            string).
    """
    index = evaluate_ast(subscript.slice, state, static_tools, custom_tools)
    value = evaluate_ast(subscript.value, state, static_tools, custom_tools)

    if isinstance(value, str) and isinstance(index, str):
        raise InterpreterError("You're trying to subscript a string with a string index, which is impossible")
    if isinstance(value, pd.core.indexing._LocIndexer):
        parent_object = value.obj
        return parent_object.loc[index]
    if isinstance(value, (pd.DataFrame, pd.Series, np.ndarray, pd.core.groupby.generic.DataFrameGroupBy)):
        return value[index]
    if isinstance(index, slice):
        return value[index]
    if isinstance(value, (list, tuple)):
        if not (-len(value) <= index < len(value)):
            raise InterpreterError(f"Index {index} out of bounds for list of length {len(value)}")
        return value[int(index)]
    if isinstance(value, str):
        if not (-len(value) <= index < len(value)):
            raise InterpreterError(f"Index {index} out of bounds for string of length {len(value)}")
        return value[index]

    try:
        # Let the container decide (a membership test would bypass __missing__/defaults).
        return value[index]
    except (KeyError, IndexError, TypeError) as error:
        lookup_error = error

    if isinstance(index, str) and isinstance(value, Mapping):
        close_matches = difflib.get_close_matches(index, list(value.keys()))
        if len(close_matches) > 0:
            return value[close_matches[0]]
    raise InterpreterError(f"Could not index {value} with '{index}'.") from lookup_error
|
|
|
|
|
|
def evaluate_name(name, state, static_tools, custom_tools):
    """Resolve a variable name to its value.

    Lookup order is ``state``, ``static_tools``, ``custom_tools`` (where
    functions defined with ``def`` are stored, so they can be passed around as
    values), then the built-in exception classes — the same order used to
    resolve called names.

    As a last resort, the closest-spelled variable in ``state`` is returned
    (``difflib.get_close_matches``). This is deliberate leniency towards typos.

    Raises:
        InterpreterError: if nothing matches, even approximately.
    """
    if name.id in state:
        return state[name.id]
    if name.id in static_tools:
        return static_tools[name.id]
    if name.id in custom_tools:
        return custom_tools[name.id]
    if name.id in ERRORS:
        return ERRORS[name.id]
    close_matches = difflib.get_close_matches(name.id, list(state.keys()))
    if len(close_matches) > 0:
        return state[close_matches[0]]
    raise InterpreterError(f"The variable `{name.id}` is not defined.")
|
|
|
|
|
|
# Function implementing each comparison-operator node type.
_COMPARISON_OPERATORS = {
    ast.Eq: operator.eq,
    ast.NotEq: operator.ne,
    ast.Lt: operator.lt,
    ast.LtE: operator.le,
    ast.Gt: operator.gt,
    ast.GtE: operator.ge,
    ast.Is: operator.is_,
    ast.IsNot: operator.is_not,
    ast.In: lambda left, right: left in right,
    ast.NotIn: lambda left, right: left not in right,
}


def evaluate_condition(condition, state, static_tools, custom_tools):
    """Evaluate an ``ast.Compare`` node, including chains like ``a < b < c``.

    A single comparison returns the raw result of the operator. Element-wise
    numpy/pandas results are therefore kept as masks (usable for boolean
    indexing) instead of being collapsed to one value.

    Chained comparisons combine the partial results with ``&``. Each comparator
    is evaluated lazily, and evaluation stops at the first plain ``False``, as
    Python does.

    Raises:
        InterpreterError: for an unknown comparison operator node.
    """
    current_left = evaluate_ast(condition.left, state, static_tools, custom_tools)
    result = None
    for position, (op, comparator_node) in enumerate(zip(condition.ops, condition.comparators)):
        compare = _COMPARISON_OPERATORS.get(type(op))
        if compare is None:
            raise InterpreterError(f"Operator not supported: {type(op)}")
        comparator = evaluate_ast(comparator_node, state, static_tools, custom_tools)
        current_result = compare(current_left, comparator)
        result = current_result if position == 0 else result & current_result
        if isinstance(result, (bool, np.bool_)) and not result:
            break
        current_left = comparator
    return result
|
|
|
|
|
|
def evaluate_if(if_statement, state, static_tools, custom_tools):
    """Execute the branch of an ``if`` statement selected by its test.

    The test is evaluated once. Then either the body or the ``else`` branch
    runs. Returns the last non-None value produced by a statement of the
    executed branch, or None.
    """
    condition = evaluate_ast(if_statement.test, state, static_tools, custom_tools)
    branch = if_statement.body if condition else if_statement.orelse
    last_value = None
    for statement in branch:
        outcome = evaluate_ast(statement, state, static_tools, custom_tools)
        if outcome is not None:
            last_value = outcome
    return last_value
|
|
|
|
|
|
def evaluate_for(for_loop, state, static_tools, custom_tools):
    """Execute a ``for`` loop, including its optional ``else`` clause.

    Each item is bound to the loop target with ``set_value``. ``break`` ends the
    loop and skips ``else``. ``continue`` skips the rest of the current
    iteration's body — not just the statement that raised it.

    Returns:
        The last non-None value produced by a body (or ``else``) statement, or None.
    """
    result = None
    iterator = evaluate_ast(for_loop.iter, state, static_tools, custom_tools)
    for counter in iterator:
        set_value(for_loop.target, counter, state, static_tools, custom_tools)
        try:
            for node in for_loop.body:
                line_result = evaluate_ast(node, state, static_tools, custom_tools)
                if line_result is not None:
                    result = line_result
        except BreakException:
            break
        except ContinueException:
            # Abandon the remaining body statements; move on to the next item.
            continue
    else:
        # Loop exhausted without `break`: run the `else` clause, if any.
        for node in for_loop.orelse:
            line_result = evaluate_ast(node, state, static_tools, custom_tools)
            if line_result is not None:
                result = line_result
    return result
|
|
|
|
|
|
def evaluate_listcomp(listcomp, state, static_tools, custom_tools):
    """Evaluate a list comprehension (or generator expression) into a list.

    Generators are nested left to right. Each binding is made in a copy of the
    enclosing scope, so comprehension variables never leak into ``state``.
    Generator expressions are evaluated eagerly into a list.

    Targets may be names or nested tuple/list patterns, e.g.
    ``for i, (key, value) in enumerate(pairs)``. Elements are unpacked by index.

    Raises:
        InterpreterError: for target patterns other than names and tuples/lists.
    """

    def bind_target(target, value, scope):
        # Recursive so nested patterns like `i, (a, b)` work.
        if isinstance(target, ast.Name):
            scope[target.id] = value
        elif isinstance(target, (ast.Tuple, ast.List)):
            for idx, element in enumerate(target.elts):
                bind_target(element, value[idx], scope)
        else:
            raise InterpreterError(f"Comprehension target {type(target).__name__} is not supported.")

    def inner_evaluate(generators, index, current_state):
        if index >= len(generators):
            return [evaluate_ast(listcomp.elt, current_state, static_tools, custom_tools)]
        generator = generators[index]
        iter_value = evaluate_ast(generator.iter, current_state, static_tools, custom_tools)
        result = []
        for value in iter_value:
            new_state = current_state.copy()
            bind_target(generator.target, value, new_state)
            if all(evaluate_ast(if_clause, new_state, static_tools, custom_tools) for if_clause in generator.ifs):
                result.extend(inner_evaluate(generators, index + 1, new_state))
        return result

    return inner_evaluate(listcomp.generators, 0, state)
|
|
|
|
|
|
def evaluate_try(try_node, state, static_tools, custom_tools):
    """Execute a ``try`` statement with its handlers, ``else`` and ``finally`` clauses.

    An exception from the body is matched against the handlers in order. A bare
    ``except`` or an ``isinstance`` match runs that handler, binding the
    exception to its name if given. An unmatched exception is re-raised.

    ``else`` runs only when the body raised nothing; ``finally`` always runs.
    The interpreter's own ``break``/``continue``/``return`` signals are never
    offered to the handlers: an ``except Exception`` must not swallow a
    ``return`` or a loop ``break``.
    """
    try:
        for stmt in try_node.body:
            evaluate_ast(stmt, state, static_tools, custom_tools)
    except (BreakException, ContinueException, ReturnException):
        # Control-flow signals pass straight through (the `finally` below still runs).
        raise
    except Exception as e:
        for handler in try_node.handlers:
            if handler.type is None or isinstance(e, evaluate_ast(handler.type, state, static_tools, custom_tools)):
                if handler.name:
                    state[handler.name] = e
                for stmt in handler.body:
                    evaluate_ast(stmt, state, static_tools, custom_tools)
                break
        else:
            # No handler matched: propagate the original exception.
            raise
    else:
        for stmt in try_node.orelse:
            evaluate_ast(stmt, state, static_tools, custom_tools)
    finally:
        for stmt in try_node.finalbody:
            evaluate_ast(stmt, state, static_tools, custom_tools)
|
|
|
|
|
|
def evaluate_raise(raise_node, state, static_tools, custom_tools):
    """Execute a ``raise`` statement, with an optional ``from`` cause.

    The exception expression is evaluated before the cause expression. A bare
    ``raise`` (re-raise) is not supported.

    Raises:
        The evaluated exception (chained to the cause, if one is given), or
        InterpreterError when there is no exception to raise.
    """
    exc = None if raise_node.exc is None else evaluate_ast(raise_node.exc, state, static_tools, custom_tools)
    cause = None if raise_node.cause is None else evaluate_ast(raise_node.cause, state, static_tools, custom_tools)
    if exc is None:
        raise InterpreterError("Re-raise is not supported without an active exception")
    if cause is None:
        raise exc
    raise exc from cause
|
|
|
|
|
|
def evaluate_assert(assert_node, state, static_tools, custom_tools):
    """Execute an ``assert`` statement.

    Raises:
        AssertionError: when the test is falsy. The message is the evaluated
            ``msg`` expression, or — when no message is given — the source of
            the failing test (via ``ast.unparse``).
    """
    if evaluate_ast(assert_node.test, state, static_tools, custom_tools):
        return None
    if not assert_node.msg:
        failing_source = ast.unparse(assert_node.test)
        raise AssertionError(f"Assertion failed: {failing_source}")
    raise AssertionError(evaluate_ast(assert_node.msg, state, static_tools, custom_tools))
|
|
|
|
|
|
def evaluate_with(with_node, state, static_tools, custom_tools):
    """Execute a ``with`` statement with Python's context-manager semantics.

    Managers are entered left to right and exited in reverse order through
    ``contextlib.ExitStack``. The stack:
        * passes an exception raised by the body to each ``__exit__``;
        * honours a manager that suppresses it;
        * exits managers already entered if a later ``__enter__`` or ``as``
          binding fails.

    ``break``/``continue``/``return`` inside the body leave it normally: the
    managers exit with no exception, then the control-flow signal is re-raised.
    Always returns None.
    """
    control_flow_signal = None
    with contextlib.ExitStack() as stack:
        for item in with_node.items:
            manager = evaluate_ast(item.context_expr, state, static_tools, custom_tools)
            # The stack keeps the *manager* for __exit__; __enter__'s return value is
            # only what `as` binds (they differ for locks, TemporaryDirectory, ...).
            entered_value = stack.enter_context(manager)
            if item.optional_vars is not None:
                set_value(item.optional_vars, entered_value, state, static_tools, custom_tools)
        try:
            for stmt in with_node.body:
                evaluate_ast(stmt, state, static_tools, custom_tools)
        except (BreakException, ContinueException, ReturnException) as signal:
            control_flow_signal = signal
    if control_flow_signal is not None:
        raise control_flow_signal
    return None
|
|
|
|
|
|
def import_modules(expression, state, authorized_imports):
    """Execute an ``import`` / ``from ... import`` statement, restricted to authorized modules.

    A module is authorized when it, or one of its parent packages, appears in
    ``authorized_imports``.

    Binding rules follow Python:
        * ``import a.b as c`` binds ``c`` to ``a.b``;
        * plain ``import a.b`` binds the top-level package ``a``. If only the
          submodule is authorized, ``a`` is not exposed and the old
          ``"a.b"`` binding is kept.
        * ``from m import *`` binds ``m.__all__``, or every public name when
          ``__all__`` is absent.

    Returns:
        None.

    Raises:
        InterpreterError: for unauthorized or relative imports.
    """

    def check_module_authorized(module_name):
        module_path = module_name.split(".")
        module_subpaths = [".".join(module_path[:i]) for i in range(1, len(module_path) + 1)]
        return any(subpath in authorized_imports for subpath in module_subpaths)

    if isinstance(expression, ast.Import):
        for alias in expression.names:
            if not check_module_authorized(alias.name):
                raise InterpreterError(
                    f"Import of {alias.name} is not allowed. Authorized imports are: {str(authorized_imports)}"
                )
            module = import_module(alias.name)
            if alias.asname:
                state[alias.asname] = module
                continue
            top_level = alias.name.split(".")[0]
            if check_module_authorized(top_level):
                state[top_level] = import_module(top_level)
            else:
                # Do not expose an unauthorized parent package.
                state[alias.name] = module
        return None
    elif isinstance(expression, ast.ImportFrom):
        if expression.level or expression.module is None:
            raise InterpreterError("Relative imports are not supported.")
        if not check_module_authorized(expression.module):
            raise InterpreterError(f"Import from {expression.module} is not allowed.")
        module = __import__(expression.module, fromlist=[alias.name for alias in expression.names])
        for alias in expression.names:
            if alias.name == "*":
                public_names = getattr(module, "__all__", None)
                if public_names is None:
                    public_names = [name for name in dir(module) if not name.startswith("_")]
                for name in public_names:
                    state[name] = getattr(module, name)
            else:
                state[alias.asname or alias.name] = getattr(module, alias.name)
        return None
|
|
|
|
|
|
def evaluate_dictcomp(dictcomp, state, static_tools, custom_tools):
    """Evaluate a dict comprehension.

    Generators are nested left to right, as in Python. Each ``for`` clause runs
    once per binding produced by the clauses before it, and its iterable and
    ``if`` filters see those bindings (so ``{(i, j): 0 for i in a for j in b}``
    works). Targets are bound with ``set_value`` in a copy of the scope, so
    comprehension variables do not leak into ``state``. The key is evaluated
    before the value.
    """
    result = {}

    def expand(position, scope):
        if position == len(dictcomp.generators):
            key = evaluate_ast(dictcomp.key, scope, static_tools, custom_tools)
            result[key] = evaluate_ast(dictcomp.value, scope, static_tools, custom_tools)
            return
        generator = dictcomp.generators[position]
        for item in evaluate_ast(generator.iter, scope, static_tools, custom_tools):
            item_scope = scope.copy()
            set_value(generator.target, item, item_scope, static_tools, custom_tools)
            if all(evaluate_ast(condition, item_scope, static_tools, custom_tools) for condition in generator.ifs):
                expand(position + 1, item_scope)

    expand(0, state)
    return result
|
|
|
|
|
|
def _evaluate_elements(elements, state, static_tools, custom_tools):
    """Evaluate the items of a list/tuple/set display, expanding ``*iterable`` items in place."""
    values = []
    for element in elements:
        if isinstance(element, ast.Starred):
            values.extend(evaluate_ast(element.value, state, static_tools, custom_tools))
        else:
            values.append(evaluate_ast(element, state, static_tools, custom_tools))
    return values


def _evaluate_formatted_value(node, state, static_tools, custom_tools):
    """Render one ``{...}`` field of an f-string, honouring ``!s``/``!r``/``!a`` and the format spec."""
    value = evaluate_ast(node.value, state, static_tools, custom_tools)
    if node.conversion == ord("s"):
        value = str(value)
    elif node.conversion == ord("r"):
        value = repr(value)
    elif node.conversion == ord("a"):
        value = ascii(value)
    # The format spec is itself a JoinedStr (it may contain nested fields).
    format_spec = "" if node.format_spec is None else evaluate_ast(node.format_spec, state, static_tools, custom_tools)
    return format(value, format_spec)


def evaluate_ast(
    expression: ast.AST,
    state: Dict[str, Any],
    static_tools: Dict[str, Callable],
    custom_tools: Dict[str, Callable],
    authorized_imports: List[str] = LIST_SAFE_MODULES,
):
    """
    Evaluate an abstract syntax tree using the content of the variables stored in a state and only evaluating a given
    set of functions.

    This function will recurse through the nodes of the tree provided.

    Args:
        expression (`ast.AST`):
            The code to evaluate, as an abstract syntax tree.
        state (`Dict[str, Any]`):
            A dictionary mapping variable names to values. The `state` is updated if need be when the evaluation
            encounters assignments.
        static_tools (`Dict[str, Callable]`):
            Functions that may be called during the evaluation. Trying to change one of these static_tools will raise an error.
        custom_tools (`Dict[str, Callable]`):
            Functions that may be called during the evaluation. These custom_tools can be overwritten.
        authorized_imports (`List[str]`):
            The list of modules that can be imported by the code. By default, only a few safe modules are allowed.
            Add more at your own risk!
            NOTE: the per-node helpers recurse without forwarding this list, so an import nested inside a block
            (if/for/def body, ...) is checked against the default list.

    Raises:
        InterpreterError: for unsupported nodes, or once `MAX_OPERATIONS` nodes have been evaluated in the run.
    """
    global OPERATIONS_COUNT
    if OPERATIONS_COUNT >= MAX_OPERATIONS:
        raise InterpreterError(
            f"Reached the max number of operations of {MAX_OPERATIONS}. Maybe there is an infinite loop somewhere in the code, or you're just asking too many calculations."
        )
    OPERATIONS_COUNT += 1

    if isinstance(expression, ast.Assign):
        # Assignment -> update the state; the assigned value may be the final result.
        return evaluate_assign(expression, state, static_tools, custom_tools)
    elif isinstance(expression, ast.AnnAssign):
        # `x: T = value` assigns like `x = value`; a bare `x: T` does nothing here.
        if expression.value is None:
            return None
        value = evaluate_ast(expression.value, state, static_tools, custom_tools)
        set_value(expression.target, value, state, static_tools, custom_tools)
        return value
    elif isinstance(expression, ast.AugAssign):
        return evaluate_augassign(expression, state, static_tools, custom_tools)
    elif isinstance(expression, ast.Call):
        # Function call -> we return the value of the function call
        return evaluate_call(expression, state, static_tools, custom_tools)
    elif isinstance(expression, ast.Constant):
        # Constant -> just return the value
        return expression.value
    elif isinstance(expression, ast.Tuple):
        return tuple(_evaluate_elements(expression.elts, state, static_tools, custom_tools))
    elif isinstance(expression, (ast.ListComp, ast.GeneratorExp)):
        # Generator expressions are evaluated eagerly into a list.
        return evaluate_listcomp(expression, state, static_tools, custom_tools)
    elif isinstance(expression, ast.SetComp):
        return set(evaluate_listcomp(expression, state, static_tools, custom_tools))
    elif isinstance(expression, ast.UnaryOp):
        return evaluate_unaryop(expression, state, static_tools, custom_tools)
    elif isinstance(expression, ast.Starred):
        return evaluate_ast(expression.value, state, static_tools, custom_tools)
    elif isinstance(expression, ast.BoolOp):
        # Boolean operation -> evaluate the operation
        return evaluate_boolop(expression, state, static_tools, custom_tools)
    elif isinstance(expression, ast.Break):
        raise BreakException()
    elif isinstance(expression, ast.Continue):
        raise ContinueException()
    elif isinstance(expression, ast.Pass):
        return None
    elif isinstance(expression, ast.BinOp):
        # Binary operation -> execute operation
        return evaluate_binop(expression, state, static_tools, custom_tools)
    elif isinstance(expression, ast.Compare):
        # Comparison -> evaluate the comparison
        return evaluate_condition(expression, state, static_tools, custom_tools)
    elif isinstance(expression, ast.Lambda):
        return evaluate_lambda(expression, state, static_tools, custom_tools)
    elif isinstance(expression, ast.FunctionDef):
        return evaluate_function_def(expression, state, static_tools, custom_tools)
    elif isinstance(expression, ast.Dict):
        # Dict display; a None key marks a `**mapping` entry to merge in.
        result = {}
        for key_node, value_node in zip(expression.keys, expression.values):
            if key_node is None:
                result.update(evaluate_ast(value_node, state, static_tools, custom_tools))
            else:
                key = evaluate_ast(key_node, state, static_tools, custom_tools)
                result[key] = evaluate_ast(value_node, state, static_tools, custom_tools)
        return result
    elif isinstance(expression, ast.Expr):
        # Expression -> evaluate the content
        return evaluate_ast(expression.value, state, static_tools, custom_tools)
    elif isinstance(expression, ast.For):
        # For loop -> execute the loop
        return evaluate_for(expression, state, static_tools, custom_tools)
    elif isinstance(expression, ast.FormattedValue):
        # Formatted value (part of f-string) -> apply conversion and format spec
        return _evaluate_formatted_value(expression, state, static_tools, custom_tools)
    elif isinstance(expression, ast.If):
        # If -> execute the right branch
        return evaluate_if(expression, state, static_tools, custom_tools)
    elif hasattr(ast, "Index") and isinstance(expression, ast.Index):
        return evaluate_ast(expression.value, state, static_tools, custom_tools)
    elif isinstance(expression, ast.JoinedStr):
        return "".join([str(evaluate_ast(v, state, static_tools, custom_tools)) for v in expression.values])
    elif isinstance(expression, ast.List):
        # List -> evaluate all elements, expanding `*iterable` items
        return _evaluate_elements(expression.elts, state, static_tools, custom_tools)
    elif isinstance(expression, ast.Name):
        # Name -> pick up the value in the state
        return evaluate_name(expression, state, static_tools, custom_tools)
    elif isinstance(expression, ast.Subscript):
        # Subscript -> return the value of the indexing
        return evaluate_subscript(expression, state, static_tools, custom_tools)
    elif isinstance(expression, ast.IfExp):
        test_val = evaluate_ast(expression.test, state, static_tools, custom_tools)
        if test_val:
            return evaluate_ast(expression.body, state, static_tools, custom_tools)
        else:
            return evaluate_ast(expression.orelse, state, static_tools, custom_tools)
    elif isinstance(expression, ast.Attribute):
        value = evaluate_ast(expression.value, state, static_tools, custom_tools)
        return getattr(value, expression.attr)
    elif isinstance(expression, ast.Slice):
        return slice(
            evaluate_ast(expression.lower, state, static_tools, custom_tools) if expression.lower is not None else None,
            evaluate_ast(expression.upper, state, static_tools, custom_tools) if expression.upper is not None else None,
            evaluate_ast(expression.step, state, static_tools, custom_tools) if expression.step is not None else None,
        )
    elif isinstance(expression, ast.DictComp):
        return evaluate_dictcomp(expression, state, static_tools, custom_tools)
    elif isinstance(expression, ast.While):
        return evaluate_while(expression, state, static_tools, custom_tools)
    elif isinstance(expression, (ast.Import, ast.ImportFrom)):
        return import_modules(expression, state, authorized_imports)
    elif isinstance(expression, ast.ClassDef):
        return evaluate_class_def(expression, state, static_tools, custom_tools)
    elif isinstance(expression, ast.Try):
        return evaluate_try(expression, state, static_tools, custom_tools)
    elif isinstance(expression, ast.Raise):
        return evaluate_raise(expression, state, static_tools, custom_tools)
    elif isinstance(expression, ast.Assert):
        return evaluate_assert(expression, state, static_tools, custom_tools)
    elif isinstance(expression, ast.With):
        return evaluate_with(expression, state, static_tools, custom_tools)
    elif isinstance(expression, ast.Set):
        return set(_evaluate_elements(expression.elts, state, static_tools, custom_tools))
    elif isinstance(expression, ast.Return):
        raise ReturnException(
            evaluate_ast(expression.value, state, static_tools, custom_tools) if expression.value else None
        )
    else:
        # For now we refuse anything else. Let's add things as we need them.
        raise InterpreterError(f"{expression.__class__.__name__} is not supported.")
|
|
|
|
|
|
def truncate_print_outputs(print_outputs: str, max_len_outputs: int = MAX_LEN_OUTPUT) -> str:
    """Return captured print output, truncated to ``max_len_outputs`` characters.

    Output that fits within the limit is returned unchanged. Longer output is
    cut to its first ``max_len_outputs`` characters, prefixed with a
    ``Print outputs:`` header, and followed by a notice that it was truncated.
    """
    # `<=`: output of exactly the maximum length loses nothing, so it must not be
    # reported as truncated.
    if len(print_outputs) <= max_len_outputs:
        return print_outputs
    return f"Print outputs:\n{print_outputs[:max_len_outputs]}\n_Print outputs have been truncated over the limit of {max_len_outputs} characters._\n"
|
|
|
|
|
|
def evaluate_python_code(
    code: str,
    static_tools: Optional[Dict[str, Callable]] = None,
    custom_tools: Optional[Dict[str, Callable]] = None,
    state: Optional[Dict[str, Any]] = None,
    authorized_imports: List[str] = LIST_SAFE_MODULES,
):
    """
    Evaluate a python expression using the content of the variables stored in a state and only evaluating a given set
    of functions.

    This function will recurse through the nodes of the tree provided.

    Args:
        code (`str`):
            The code to evaluate.
        static_tools (`Dict[str, Callable]`):
            The functions that may be called during the evaluation.
            These tools cannot be overwritten in the code: any assignment to their name will raise an error.
        custom_tools (`Dict[str, Callable]`):
            The functions that may be called during the evaluation.
            These tools can be overwritten in the code: any assignment to their name will overwrite them.
        state (`Dict[str, Any]`):
            A dictionary mapping variable names to values. The `state` should contain the initial inputs but will be
            updated by this function to contain all variables as they are evaluated.
            On success, the (possibly truncated) print outputs are stored in the state under the key 'print_outputs'.
        authorized_imports (`List[str]`):
            Modules the code may import (a module is also allowed when one of its parent packages is listed).

    Returns:
        The value of the last top-level statement (None if the code is empty).

    Raises:
        SyntaxError: if `code` cannot be parsed.
        InterpreterError: if evaluation fails inside the interpreter. The message includes the print outputs so far
            and the source of the failing top-level statement.
    """
    global PRINT_OUTPUTS, OPERATIONS_COUNT

    try:
        expression = ast.parse(code)
    except SyntaxError as e:
        raise SyntaxError(f"The code generated by the agent is not valid.\n{e}") from e

    if state is None:
        state = {}
    if static_tools is None:
        static_tools = {}
    if custom_tools is None:
        custom_tools = {}

    # Reset the per-run counters shared with the evaluators.
    PRINT_OUTPUTS = ""
    OPERATIONS_COUNT = 0
    result = None
    try:
        for node in expression.body:
            result = evaluate_ast(node, state, static_tools, custom_tools, authorized_imports)
        state["print_outputs"] = truncate_print_outputs(PRINT_OUTPUTS, max_len_outputs=MAX_LEN_OUTPUT)
        return result
    except InterpreterError as e:
        msg = truncate_print_outputs(PRINT_OUTPUTS, max_len_outputs=MAX_LEN_OUTPUT)
        msg += f"EXECUTION FAILED:\nEvaluation stopped at line '{ast.get_source_segment(code, node)}' because of the following error:\n{e}"
        raise InterpreterError(msg) from e
|