diff --git a/apps/application/flow/step_node/ai_chat_step_node/i_chat_node.py b/apps/application/flow/step_node/ai_chat_step_node/i_chat_node.py index 994fda75df6..e262d67f3c7 100644 --- a/apps/application/flow/step_node/ai_chat_step_node/i_chat_node.py +++ b/apps/application/flow/step_node/ai_chat_step_node/i_chat_node.py @@ -49,6 +49,10 @@ class ChatNodeSerializer(serializers.Serializer): label=_("Skill IDs"), ) mcp_output_enable = serializers.BooleanField(required=False, default=True, label=_("Whether to enable MCP output")) + video_list = serializers.ListField(required=False, label=_("video")) + + image_list = serializers.ListField(required=False, label=_("picture")) + class IChatNode(INode): type = 'ai-chat-node' diff --git a/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py b/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py index f32eb49e19b..cdd8a5f0af8 100644 --- a/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py +++ b/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py @@ -6,10 +6,12 @@ @date:2024/6/4 14:30 @desc: """ +import base64 import json import re import time from functools import reduce +from imghdr import what from typing import List, Dict from django.db.models import QuerySet @@ -25,6 +27,7 @@ from common.utils.rsa_util import rsa_long_decrypt from common.utils.shared_resource_auth import filter_authorized_ids from common.utils.tool_code import ToolExecutor +from knowledge.models import File from models_provider.models import Model from models_provider.tools import get_model_credential, get_model_instance_by_model_workspace_id from tools.models import Tool, ToolType @@ -196,11 +199,11 @@ def execute(self, model_id, system, prompt, dialogue_number, history_chat_record self.runtime_node_id) self.context['history_message'] = [{'content': message.content, 'role': message.type} for message in (history_message if history_message is not None else [])] - question = self.generate_prompt_question(prompt) + question = self.generate_prompt_question(prompt, chat_model) self.context['question'] = question.content system = self.workflow_manage.generate_prompt(system) self.context['system'] = system - message_list = self.generate_message_list(prompt, history_message) + message_list = self.generate_message_list(question, history_message) self.context['message_list'] = message_list # 过滤tool_id @@ -386,11 +389,73 @@ def get_history_message(history_chat_record, dialogue_number, dialogue_type, run message.content = re.sub(r'.*?<\/form_rander>', '', message.content, flags=re.DOTALL) return history_message - def generate_prompt_question(self, prompt): - return HumanMessage(self.workflow_manage.generate_prompt(prompt)) + def generate_prompt_question(self, prompt, model): + image = self.get_image() + video = self.get_video() + videos = [] + images = [] + if image: + images = self._process_images(image) + if video: + videos = self._process_videos(video, model) + return HumanMessage( + content=[*videos, *images, {'type': 'text', 'text': self.workflow_manage.generate_prompt(prompt)}]) + + def get_image(self): + if 'image_list' in self.node_params_serializer.data: + image = self.workflow_manage.get_reference_field(self.node_params_serializer.data.get('image_list')[0], + self.node_params_serializer.data.get('image_list')[1:]) + return image + return None + + def get_video(self): + if 'video_list' in self.node_params_serializer.data: + video = self.workflow_manage.get_reference_field(self.node_params_serializer.data.get('video_list')[0], + self.node_params_serializer.data.get('video_list')[1:]) + return video + return None - def generate_message_list(self, prompt: str, history_message): - return [*history_message, HumanMessage(self.workflow_manage.generate_prompt(prompt))] + def _process_videos(self, image, video_model): + videos = [] + if isinstance(image, str) and image.startswith('http'): + videos.append({'type': 'video_url', 'video_url': {'url': image}}) + elif image is not None and len(image) > 0: + for img in image: + if 'file_id' in img: + file_id = img['file_id'] + file = QuerySet(File).filter(id=file_id).first() + url = video_model.upload_file_and_get_url(file.get_bytes(), file.file_name) + videos.append( + {'type': 'video_url', 'video_url': {'url': url}}) + elif 'url' in img and img['url'].startswith('http'): + videos.append( + {'type': 'video_url', 'video_url': {'url': img['url']}}) + return videos + + def _process_images(self, image): + """ + 处理图像数据,转换为模型可识别的格式 + """ + images = [] + if isinstance(image, str) and image.startswith('http'): + images.append({'type': 'image_url', 'image_url': {'url': image}}) + elif image is not None and len(image) > 0: + for img in image: + if 'file_id' in img: + file_id = img['file_id'] + file = QuerySet(File).filter(id=file_id).first() + image_bytes = file.get_bytes() + base64_image = base64.b64encode(image_bytes).decode("utf-8") + image_format = what(None, image_bytes) + images.append( + {'type': 'image_url', 'image_url': {'url': f'data:image/{image_format};base64,{base64_image}'}}) + elif 'url' in img and img['url'].startswith('http'): + images.append( + {'type': 'image_url', 'image_url': {'url': img["url"]}}) + return images + + def generate_message_list(self, question, history_message): + return [*history_message, question] @staticmethod def reset_message_list(message_list: List[BaseMessage], answer_text): diff --git a/ui/src/workflow/nodes/ai-chat-node/index.vue b/ui/src/workflow/nodes/ai-chat-node/index.vue index a7777706abf..7844909f1cf 100644 --- a/ui/src/workflow/nodes/ai-chat-node/index.vue +++ b/ui/src/workflow/nodes/ai-chat-node/index.vue @@ -170,7 +170,60 @@ :step-strictly="true" /> - + +
+
+ {{ $t('workflow.nodes.imageUnderstandNode.image.label') }} +
+ + + + +
+ +
+ +
+
+ {{ $t('workflow.nodes.videoUnderstandNode.video.label') }} +
+ + + + +
+ +
{{ $t('views.tool.skill.title') }} @@ -555,6 +608,7 @@ import { resetUrl } from '@/utils/common' import { relatedObject } from '@/utils/array.ts' import { WorkflowMode } from '@/enums/application' import ApplicationDialog from '@/views/application/component/ApplicationDialog.vue' +import { fileTooltip } from '@/workflow/common/data.ts' const workflowMode = (inject('workflowMode') as WorkflowMode) || WorkflowMode.Application const getResourceDetail = inject('getResourceDetail') as any const route = useRoute() @@ -604,7 +658,6 @@ const defaultPrompt = `${t('workflow.nodes.aiChatNode.defaultPrompt')}: ${t('views.problem.title')}: {{${t('workflow.nodes.startNode.label')}.question}}` - const collapseData = reactive({ MCP: true, tool: true,