fix final_answer issue in e2b_executor (#319)
This commit is contained in:
parent
2105811da6
commit
33b38e6cb7
|
@ -14,6 +14,7 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import re
|
||||
import base64
|
||||
import pickle
|
||||
import textwrap
|
||||
|
@ -45,6 +46,8 @@ class E2BExecutor:
|
|||
)
|
||||
|
||||
self.custom_tools = {}
|
||||
self.final_answer = False
|
||||
self.final_answer_pattern = re.compile(r'^final_answer\((.*)\)$')
|
||||
self.sbx = Sandbox() # "qywp2ctmu2q7jzprcf4j")
|
||||
# TODO: validate installing agents package or not
|
||||
# print("Installing agents package on remote executor...")
|
||||
|
@ -85,6 +88,8 @@ class E2BExecutor:
|
|||
self.logger.log(tool_definition_execution.logs)
|
||||
|
||||
def run_code_raise_errors(self, code: str):
|
||||
if self.final_answer_pattern.match(code):
|
||||
self.final_answer = True
|
||||
execution = self.sbx.run_code(
|
||||
code,
|
||||
)
|
||||
|
@ -122,7 +127,7 @@ locals().update({key: value for key, value in pickle_dict.items()})
|
|||
execution = self.run_code_raise_errors(code_action)
|
||||
execution_logs = "\n".join([str(log) for log in execution.logs.stdout])
|
||||
if not execution.results:
|
||||
return None, execution_logs
|
||||
return None, execution_logs, self.final_answer
|
||||
else:
|
||||
for result in execution.results:
|
||||
if result.is_main_result:
|
||||
|
@ -144,7 +149,7 @@ locals().update({key: value for key, value in pickle_dict.items()})
|
|||
"text",
|
||||
]:
|
||||
if getattr(result, attribute_name) is not None:
|
||||
return getattr(result, attribute_name), execution_logs
|
||||
return getattr(result, attribute_name), execution_logs, self.final_answer
|
||||
raise ValueError("No main result returned by executor!")
|
||||
|
||||
|
||||
|
|
|
@ -25,6 +25,9 @@ class MethodChecker(ast.NodeVisitor):
|
|||
self.class_attributes = class_attributes
|
||||
self.errors = []
|
||||
self.check_imports = check_imports
|
||||
self.typing_names = {
|
||||
'Any'
|
||||
}
|
||||
|
||||
def visit_arguments(self, node):
|
||||
"""Collect function arguments"""
|
||||
|
@ -97,6 +100,7 @@ class MethodChecker(ast.NodeVisitor):
|
|||
or node.id in self.imports
|
||||
or node.id in self.from_imports
|
||||
or node.id in self.assigned_names
|
||||
or node.id in self.typing_names
|
||||
):
|
||||
self.errors.append(f"Name '{node.id}' is undefined.")
|
||||
|
||||
|
|
Loading…
Reference in New Issue