Refactor evaluate ast to improve readability (#625)

This commit is contained in:
CalOmnie 2025-02-18 09:37:40 +01:00 committed by GitHub
parent f631c7521e
commit 4e05fab51d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 41 additions and 54 deletions

View File

@ -1179,126 +1179,113 @@ def evaluate_ast(
f"Reached the max number of operations of {MAX_OPERATIONS}. Maybe there is an infinite loop somewhere in the code, or you're just asking too many calculations." f"Reached the max number of operations of {MAX_OPERATIONS}. Maybe there is an infinite loop somewhere in the code, or you're just asking too many calculations."
) )
state["_operations_count"] += 1 state["_operations_count"] += 1
common_params = (state, static_tools, custom_tools, authorized_imports)
if isinstance(expression, ast.Assign): if isinstance(expression, ast.Assign):
# Assignment -> we evaluate the assignment which should update the state # Assignment -> we evaluate the assignment which should update the state
# We return the variable assigned as it may be used to determine the final result. # We return the variable assigned as it may be used to determine the final result.
return evaluate_assign(expression, state, static_tools, custom_tools, authorized_imports) return evaluate_assign(expression, *common_params)
elif isinstance(expression, ast.AugAssign): elif isinstance(expression, ast.AugAssign):
return evaluate_augassign(expression, state, static_tools, custom_tools, authorized_imports) return evaluate_augassign(expression, *common_params)
elif isinstance(expression, ast.Call): elif isinstance(expression, ast.Call):
# Function call -> we return the value of the function call # Function call -> we return the value of the function call
return evaluate_call(expression, state, static_tools, custom_tools, authorized_imports) return evaluate_call(expression, *common_params)
elif isinstance(expression, ast.Constant): elif isinstance(expression, ast.Constant):
# Constant -> just return the value # Constant -> just return the value
return expression.value return expression.value
elif isinstance(expression, ast.Tuple): elif isinstance(expression, ast.Tuple):
return tuple( return tuple((evaluate_ast(elt, *common_params) for elt in expression.elts))
evaluate_ast(elt, state, static_tools, custom_tools, authorized_imports) for elt in expression.elts
)
elif isinstance(expression, (ast.ListComp, ast.GeneratorExp)): elif isinstance(expression, (ast.ListComp, ast.GeneratorExp)):
return evaluate_listcomp(expression, state, static_tools, custom_tools, authorized_imports) return evaluate_listcomp(expression, *common_params)
elif isinstance(expression, ast.UnaryOp): elif isinstance(expression, ast.UnaryOp):
return evaluate_unaryop(expression, state, static_tools, custom_tools, authorized_imports) return evaluate_unaryop(expression, *common_params)
elif isinstance(expression, ast.Starred): elif isinstance(expression, ast.Starred):
return evaluate_ast(expression.value, state, static_tools, custom_tools, authorized_imports) return evaluate_ast(expression.value, *common_params)
elif isinstance(expression, ast.BoolOp): elif isinstance(expression, ast.BoolOp):
# Boolean operation -> evaluate the operation # Boolean operation -> evaluate the operation
return evaluate_boolop(expression, state, static_tools, custom_tools, authorized_imports) return evaluate_boolop(expression, *common_params)
elif isinstance(expression, ast.Break): elif isinstance(expression, ast.Break):
raise BreakException() raise BreakException()
elif isinstance(expression, ast.Continue): elif isinstance(expression, ast.Continue):
raise ContinueException() raise ContinueException()
elif isinstance(expression, ast.BinOp): elif isinstance(expression, ast.BinOp):
# Binary operation -> execute operation # Binary operation -> execute operation
return evaluate_binop(expression, state, static_tools, custom_tools, authorized_imports) return evaluate_binop(expression, *common_params)
elif isinstance(expression, ast.Compare): elif isinstance(expression, ast.Compare):
# Comparison -> evaluate the comparison # Comparison -> evaluate the comparison
return evaluate_condition(expression, state, static_tools, custom_tools, authorized_imports) return evaluate_condition(expression, *common_params)
elif isinstance(expression, ast.Lambda): elif isinstance(expression, ast.Lambda):
return evaluate_lambda(expression, state, static_tools, custom_tools, authorized_imports) return evaluate_lambda(expression, *common_params)
elif isinstance(expression, ast.FunctionDef): elif isinstance(expression, ast.FunctionDef):
return evaluate_function_def(expression, state, static_tools, custom_tools, authorized_imports) return evaluate_function_def(expression, *common_params)
elif isinstance(expression, ast.Dict): elif isinstance(expression, ast.Dict):
# Dict -> evaluate all keys and values # Dict -> evaluate all keys and values
keys = [evaluate_ast(k, state, static_tools, custom_tools, authorized_imports) for k in expression.keys] keys = (evaluate_ast(k, *common_params) for k in expression.keys)
values = [evaluate_ast(v, state, static_tools, custom_tools, authorized_imports) for v in expression.values] values = (evaluate_ast(v, *common_params) for v in expression.values)
return dict(zip(keys, values)) return dict(zip(keys, values))
elif isinstance(expression, ast.Expr): elif isinstance(expression, ast.Expr):
# Expression -> evaluate the content # Expression -> evaluate the content
return evaluate_ast(expression.value, state, static_tools, custom_tools, authorized_imports) return evaluate_ast(expression.value, *common_params)
elif isinstance(expression, ast.For): elif isinstance(expression, ast.For):
# For loop -> execute the loop # For loop -> execute the loop
return evaluate_for(expression, state, static_tools, custom_tools, authorized_imports) return evaluate_for(expression, *common_params)
elif isinstance(expression, ast.FormattedValue): elif isinstance(expression, ast.FormattedValue):
# Formatted value (part of f-string) -> evaluate the content and return # Formatted value (part of f-string) -> evaluate the content and return
return evaluate_ast(expression.value, state, static_tools, custom_tools, authorized_imports) return evaluate_ast(expression.value, *common_params)
elif isinstance(expression, ast.If): elif isinstance(expression, ast.If):
# If -> execute the right branch # If -> execute the right branch
return evaluate_if(expression, state, static_tools, custom_tools, authorized_imports) return evaluate_if(expression, *common_params)
elif hasattr(ast, "Index") and isinstance(expression, ast.Index): elif hasattr(ast, "Index") and isinstance(expression, ast.Index):
return evaluate_ast(expression.value, state, static_tools, custom_tools, authorized_imports) return evaluate_ast(expression.value, *common_params)
elif isinstance(expression, ast.JoinedStr): elif isinstance(expression, ast.JoinedStr):
return "".join( return "".join([str(evaluate_ast(v, *common_params)) for v in expression.values])
[str(evaluate_ast(v, state, static_tools, custom_tools, authorized_imports)) for v in expression.values]
)
elif isinstance(expression, ast.List): elif isinstance(expression, ast.List):
# List -> evaluate all elements # List -> evaluate all elements
return [evaluate_ast(elt, state, static_tools, custom_tools, authorized_imports) for elt in expression.elts] return [evaluate_ast(elt, *common_params) for elt in expression.elts]
elif isinstance(expression, ast.Name): elif isinstance(expression, ast.Name):
# Name -> pick up the value in the state # Name -> pick up the value in the state
return evaluate_name(expression, state, static_tools, custom_tools, authorized_imports) return evaluate_name(expression, *common_params)
elif isinstance(expression, ast.Subscript): elif isinstance(expression, ast.Subscript):
# Subscript -> return the value of the indexing # Subscript -> return the value of the indexing
return evaluate_subscript(expression, state, static_tools, custom_tools, authorized_imports) return evaluate_subscript(expression, *common_params)
elif isinstance(expression, ast.IfExp): elif isinstance(expression, ast.IfExp):
test_val = evaluate_ast(expression.test, state, static_tools, custom_tools, authorized_imports) test_val = evaluate_ast(expression.test, *common_params)
if test_val: if test_val:
return evaluate_ast(expression.body, state, static_tools, custom_tools, authorized_imports) return evaluate_ast(expression.body, *common_params)
else: else:
return evaluate_ast(expression.orelse, state, static_tools, custom_tools, authorized_imports) return evaluate_ast(expression.orelse, *common_params)
elif isinstance(expression, ast.Attribute): elif isinstance(expression, ast.Attribute):
value = evaluate_ast(expression.value, state, static_tools, custom_tools, authorized_imports) value = evaluate_ast(expression.value, *common_params)
return getattr(value, expression.attr) return getattr(value, expression.attr)
elif isinstance(expression, ast.Slice): elif isinstance(expression, ast.Slice):
return slice( return slice(
evaluate_ast(expression.lower, state, static_tools, custom_tools, authorized_imports) evaluate_ast(expression.lower, *common_params) if expression.lower is not None else None,
if expression.lower is not None evaluate_ast(expression.upper, *common_params) if expression.upper is not None else None,
else None, evaluate_ast(expression.step, *common_params) if expression.step is not None else None,
evaluate_ast(expression.upper, state, static_tools, custom_tools, authorized_imports)
if expression.upper is not None
else None,
evaluate_ast(expression.step, state, static_tools, custom_tools, authorized_imports)
if expression.step is not None
else None,
) )
elif isinstance(expression, ast.DictComp): elif isinstance(expression, ast.DictComp):
return evaluate_dictcomp(expression, state, static_tools, custom_tools, authorized_imports) return evaluate_dictcomp(expression, *common_params)
elif isinstance(expression, ast.While): elif isinstance(expression, ast.While):
return evaluate_while(expression, state, static_tools, custom_tools, authorized_imports) return evaluate_while(expression, *common_params)
elif isinstance(expression, (ast.Import, ast.ImportFrom)): elif isinstance(expression, (ast.Import, ast.ImportFrom)):
return import_modules(expression, state, authorized_imports) return import_modules(expression, state, authorized_imports)
elif isinstance(expression, ast.ClassDef): elif isinstance(expression, ast.ClassDef):
return evaluate_class_def(expression, state, static_tools, custom_tools, authorized_imports) return evaluate_class_def(expression, *common_params)
elif isinstance(expression, ast.Try): elif isinstance(expression, ast.Try):
return evaluate_try(expression, state, static_tools, custom_tools, authorized_imports) return evaluate_try(expression, *common_params)
elif isinstance(expression, ast.Raise): elif isinstance(expression, ast.Raise):
return evaluate_raise(expression, state, static_tools, custom_tools, authorized_imports) return evaluate_raise(expression, *common_params)
elif isinstance(expression, ast.Assert): elif isinstance(expression, ast.Assert):
return evaluate_assert(expression, state, static_tools, custom_tools, authorized_imports) return evaluate_assert(expression, *common_params)
elif isinstance(expression, ast.With): elif isinstance(expression, ast.With):
return evaluate_with(expression, state, static_tools, custom_tools, authorized_imports) return evaluate_with(expression, *common_params)
elif isinstance(expression, ast.Set): elif isinstance(expression, ast.Set):
return {evaluate_ast(elt, state, static_tools, custom_tools, authorized_imports) for elt in expression.elts} return set((evaluate_ast(elt, *common_params) for elt in expression.elts))
elif isinstance(expression, ast.Return): elif isinstance(expression, ast.Return):
raise ReturnException( raise ReturnException(evaluate_ast(expression.value, *common_params) if expression.value else None)
evaluate_ast(expression.value, state, static_tools, custom_tools, authorized_imports)
if expression.value
else None
)
elif isinstance(expression, ast.Pass): elif isinstance(expression, ast.Pass):
return None return None
elif isinstance(expression, ast.Delete): elif isinstance(expression, ast.Delete):
return evaluate_delete(expression, state, static_tools, custom_tools, authorized_imports) return evaluate_delete(expression, *common_params)
else: else:
# For now we refuse anything else. Let's add things as we need them. # For now we refuse anything else. Let's add things as we need them.
raise InterpreterError(f"{expression.__class__.__name__} is not supported.") raise InterpreterError(f"{expression.__class__.__name__} is not supported.")