diff --git a/.gitattributes b/.gitattributes index dcf74f06de2455294ac1618b78564a1c3ce54d97..795574b32ce3a0b061b5bf1d6eebe5be2f50b20b 100644 --- a/.gitattributes +++ b/.gitattributes @@ -5,3 +5,4 @@ # them. *.sh text eol=lf +api/tests/integration_tests/model_runtime/assets/audio.mp3 filter=lfs diff=lfs merge=lfs -text diff --git a/api/docker/entrypoint.sh b/api/docker/entrypoint.sh new file mode 100644 index 0000000000000000000000000000000000000000..68f3c65a4b93c54f2b4be32ffbd1ae3e339bc8f1 --- /dev/null +++ b/api/docker/entrypoint.sh @@ -0,0 +1,40 @@ +#!/bin/bash + +set -e + +if [[ "${MIGRATION_ENABLED}" == "true" ]]; then + echo "Running migrations" + flask upgrade-db +fi + +if [[ "${MODE}" == "worker" ]]; then + + # Get the number of available CPU cores + if [ "${CELERY_AUTO_SCALE,,}" = "true" ]; then + # Set MAX_WORKERS to the number of available cores if not specified + AVAILABLE_CORES=$(nproc) + MAX_WORKERS=${CELERY_MAX_WORKERS:-$AVAILABLE_CORES} + MIN_WORKERS=${CELERY_MIN_WORKERS:-1} + CONCURRENCY_OPTION="--autoscale=${MAX_WORKERS},${MIN_WORKERS}" + else + CONCURRENCY_OPTION="-c ${CELERY_WORKER_AMOUNT:-1}" + fi + + exec celery -A app.celery worker -P ${CELERY_WORKER_CLASS:-gevent} $CONCURRENCY_OPTION --loglevel ${LOG_LEVEL:-INFO} \ + -Q ${CELERY_QUEUES:-dataset,mail,ops_trace,app_deletion} + +elif [[ "${MODE}" == "beat" ]]; then + exec celery -A app.celery beat --loglevel ${LOG_LEVEL:-INFO} +else + if [[ "${DEBUG}" == "true" ]]; then + exec flask run --host=${DIFY_BIND_ADDRESS:-0.0.0.0} --port=${DIFY_PORT:-5001} --debug + else + exec gunicorn \ + --bind "${DIFY_BIND_ADDRESS:-0.0.0.0}:${DIFY_PORT:-5001}" \ + --workers ${SERVER_WORKER_AMOUNT:-1} \ + --worker-class ${SERVER_WORKER_CLASS:-gevent} \ + --worker-connections ${SERVER_WORKER_CONNECTIONS:-10} \ + --timeout ${GUNICORN_TIMEOUT:-200} \ + app:app + fi +fi diff --git a/api/events/__init__.py b/api/events/__init__.py new file mode 100644 index 
0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/events/app_event.py b/api/events/app_event.py new file mode 100644 index 0000000000000000000000000000000000000000..f2ce71bbbb3632457a1cad9683f87cf4a94f8b63 --- /dev/null +++ b/api/events/app_event.py @@ -0,0 +1,13 @@ +from blinker import signal + +# sender: app, kwargs: account (optional) +app_was_created = signal("app-was-created") + +# sender: app, kwargs: app_model_config +app_model_config_was_updated = signal("app-model-config-was-updated") + +# sender: app, kwargs: published_workflow +app_published_workflow_was_updated = signal("app-published-workflow-was-updated") + +# sender: app, kwargs: synced_draft_workflow +app_draft_workflow_was_synced = signal("app-draft-workflow-was-synced") diff --git a/api/events/dataset_event.py b/api/events/dataset_event.py new file mode 100644 index 0000000000000000000000000000000000000000..750b7424e2347b73386f759d68f3e76704253d2d --- /dev/null +++ b/api/events/dataset_event.py @@ -0,0 +1,4 @@ +from blinker import signal + +# sender: dataset +dataset_was_deleted = signal("dataset-was-deleted") diff --git a/api/events/document_event.py b/api/events/document_event.py new file mode 100644 index 0000000000000000000000000000000000000000..2c5a416a5e0c91fd6cc7370c8338b57f40dcc0f9 --- /dev/null +++ b/api/events/document_event.py @@ -0,0 +1,4 @@ +from blinker import signal + +# sender: document_id, kwargs: dataset_id, doc_form, file_id +document_was_deleted = signal("document-was-deleted") diff --git a/api/events/event_handlers/__init__.py b/api/events/event_handlers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1d6ad35333c0149810c1db3102db1935db6de1d5 --- /dev/null +++ b/api/events/event_handlers/__init__.py @@ -0,0 +1,10 @@ +from .clean_when_dataset_deleted import handle +from .clean_when_document_deleted import handle +from .create_document_index import handle +from .create_installed_app_when_app_created import handle +from .create_site_record_when_app_created import 
handle +from .deduct_quota_when_message_created import handle +from .delete_tool_parameters_cache_when_sync_draft_workflow import handle +from .update_app_dataset_join_when_app_model_config_updated import handle +from .update_app_dataset_join_when_app_published_workflow_updated import handle +from .update_provider_last_used_at_when_message_created import handle diff --git a/api/events/event_handlers/clean_when_dataset_deleted.py b/api/events/event_handlers/clean_when_dataset_deleted.py new file mode 100644 index 0000000000000000000000000000000000000000..7caa2d1cc9f3f26e95b8b3d05067480d505ccf54 --- /dev/null +++ b/api/events/event_handlers/clean_when_dataset_deleted.py @@ -0,0 +1,15 @@ +from events.dataset_event import dataset_was_deleted +from tasks.clean_dataset_task import clean_dataset_task + + +@dataset_was_deleted.connect +def handle(sender, **kwargs): + dataset = sender + clean_dataset_task.delay( + dataset.id, + dataset.tenant_id, + dataset.indexing_technique, + dataset.index_struct, + dataset.collection_binding_id, + dataset.doc_form, + ) diff --git a/api/events/event_handlers/clean_when_document_deleted.py b/api/events/event_handlers/clean_when_document_deleted.py new file mode 100644 index 0000000000000000000000000000000000000000..00a66f50ad93192a4e21cbed7e9d23cf395c3592 --- /dev/null +++ b/api/events/event_handlers/clean_when_document_deleted.py @@ -0,0 +1,11 @@ +from events.document_event import document_was_deleted +from tasks.clean_document_task import clean_document_task + + +@document_was_deleted.connect +def handle(sender, **kwargs): + document_id = sender + dataset_id = kwargs.get("dataset_id") + doc_form = kwargs.get("doc_form") + file_id = kwargs.get("file_id") + clean_document_task.delay(document_id, dataset_id, doc_form, file_id) diff --git a/api/events/event_handlers/create_document_index.py b/api/events/event_handlers/create_document_index.py new file mode 100644 index 
0000000000000000000000000000000000000000..8a677f6b6fc017f11e07866f06881961f34badb3 --- /dev/null +++ b/api/events/event_handlers/create_document_index.py @@ -0,0 +1,49 @@ +import datetime +import logging +import time + +import click +from werkzeug.exceptions import NotFound + +from core.indexing_runner import DocumentIsPausedError, IndexingRunner +from events.event_handlers.document_index_event import document_index_created +from extensions.ext_database import db +from models.dataset import Document + + +@document_index_created.connect +def handle(sender, **kwargs): + dataset_id = sender + document_ids = kwargs.get("document_ids", []) + documents = [] + start_at = time.perf_counter() + for document_id in document_ids: + logging.info(click.style("Start process document: {}".format(document_id), fg="green")) + + document = ( + db.session.query(Document) + .filter( + Document.id == document_id, + Document.dataset_id == dataset_id, + ) + .first() + ) + + if not document: + raise NotFound("Document not found") + + document.indexing_status = "parsing" + document.processing_started_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + documents.append(document) + db.session.add(document) + db.session.commit() + + try: + indexing_runner = IndexingRunner() + indexing_runner.run(documents) + end_at = time.perf_counter() + logging.info(click.style("Processed dataset: {} latency: {}".format(dataset_id, end_at - start_at), fg="green")) + except DocumentIsPausedError as ex: + logging.info(click.style(str(ex), fg="yellow")) + except Exception: + pass diff --git a/api/events/event_handlers/create_installed_app_when_app_created.py b/api/events/event_handlers/create_installed_app_when_app_created.py new file mode 100644 index 0000000000000000000000000000000000000000..57412cc4ad0af2d7a7484cb04c5e36874242b2c9 --- /dev/null +++ b/api/events/event_handlers/create_installed_app_when_app_created.py @@ -0,0 +1,16 @@ +from events.app_event import app_was_created +from 
extensions.ext_database import db +from models.model import InstalledApp + + +@app_was_created.connect +def handle(sender, **kwargs): + """Create an installed app when an app is created.""" + app = sender + installed_app = InstalledApp( + tenant_id=app.tenant_id, + app_id=app.id, + app_owner_tenant_id=app.tenant_id, + ) + db.session.add(installed_app) + db.session.commit() diff --git a/api/events/event_handlers/create_site_record_when_app_created.py b/api/events/event_handlers/create_site_record_when_app_created.py new file mode 100644 index 0000000000000000000000000000000000000000..5e7caf8cbed71e6b1bcbf80405aba49e33c80f70 --- /dev/null +++ b/api/events/event_handlers/create_site_record_when_app_created.py @@ -0,0 +1,26 @@ +from events.app_event import app_was_created +from extensions.ext_database import db +from models.model import Site + + +@app_was_created.connect +def handle(sender, **kwargs): + """Create site record when an app is created.""" + app = sender + account = kwargs.get("account") + if account is not None: + site = Site( + app_id=app.id, + title=app.name, + icon_type=app.icon_type, + icon=app.icon, + icon_background=app.icon_background, + default_language=account.interface_language, + customize_token_strategy="not_allow", + code=Site.generate_code(16), + created_by=app.created_by, + updated_by=app.updated_by, + ) + + db.session.add(site) + db.session.commit() diff --git a/api/events/event_handlers/deduct_quota_when_message_created.py b/api/events/event_handlers/deduct_quota_when_message_created.py new file mode 100644 index 0000000000000000000000000000000000000000..d196a4862013b7d2c42556a2239e9d80aeda71c8 --- /dev/null +++ b/api/events/event_handlers/deduct_quota_when_message_created.py @@ -0,0 +1,53 @@ +from configs import dify_config +from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ChatAppGenerateEntity +from core.entities.provider_entities import QuotaUnit +from events.message_event import message_was_created +from 
extensions.ext_database import db +from models.provider import Provider, ProviderType + + +@message_was_created.connect +def handle(sender, **kwargs): + message = sender + application_generate_entity = kwargs.get("application_generate_entity") + + if not isinstance(application_generate_entity, ChatAppGenerateEntity | AgentChatAppGenerateEntity): + return + + model_config = application_generate_entity.model_conf + provider_model_bundle = model_config.provider_model_bundle + provider_configuration = provider_model_bundle.configuration + + if provider_configuration.using_provider_type != ProviderType.SYSTEM: + return + + system_configuration = provider_configuration.system_configuration + + quota_unit = None + for quota_configuration in system_configuration.quota_configurations: + if quota_configuration.quota_type == system_configuration.current_quota_type: + quota_unit = quota_configuration.quota_unit + + if quota_configuration.quota_limit == -1: + return + + break + + used_quota = None + if quota_unit: + if quota_unit == QuotaUnit.TOKENS: + used_quota = message.message_tokens + message.answer_tokens + elif quota_unit == QuotaUnit.CREDITS: + used_quota = dify_config.get_model_credits(model_config.model) + else: + used_quota = 1 + + if used_quota is not None and system_configuration.current_quota_type is not None: + db.session.query(Provider).filter( + Provider.tenant_id == application_generate_entity.app_config.tenant_id, + Provider.provider_name == model_config.provider, + Provider.provider_type == ProviderType.SYSTEM.value, + Provider.quota_type == system_configuration.current_quota_type.value, + Provider.quota_limit > Provider.quota_used, + ).update({"quota_used": Provider.quota_used + used_quota}) + db.session.commit() diff --git a/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py b/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py new file mode 100644 index 
0000000000000000000000000000000000000000..249bd144299c94d1188aedca4f91120acae9725f --- /dev/null +++ b/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py @@ -0,0 +1,34 @@ +from core.tools.tool_manager import ToolManager +from core.tools.utils.configuration import ToolParameterConfigurationManager +from core.workflow.nodes import NodeType +from core.workflow.nodes.tool.entities import ToolEntity +from events.app_event import app_draft_workflow_was_synced + + +@app_draft_workflow_was_synced.connect +def handle(sender, **kwargs): + app = sender + synced_draft_workflow = kwargs.get("synced_draft_workflow") + if synced_draft_workflow is None: + return + for node_data in synced_draft_workflow.graph_dict.get("nodes", []): + if node_data.get("data", {}).get("type") == NodeType.TOOL.value: + try: + tool_entity = ToolEntity(**node_data["data"]) + tool_runtime = ToolManager.get_tool_runtime( + provider_type=tool_entity.provider_type, + provider_id=tool_entity.provider_id, + tool_name=tool_entity.tool_name, + tenant_id=app.tenant_id, + ) + manager = ToolParameterConfigurationManager( + tenant_id=app.tenant_id, + tool_runtime=tool_runtime, + provider_name=tool_entity.provider_name, + provider_type=tool_entity.provider_type, + identity_id=f"WORKFLOW.{app.id}.{node_data.get('id')}", + ) + manager.delete_tool_parameters_cache() + except: + # tool does not exist + pass diff --git a/api/events/event_handlers/document_index_event.py b/api/events/event_handlers/document_index_event.py new file mode 100644 index 0000000000000000000000000000000000000000..3d463fe5b35acf7c2a5249903508cf7bd252736d --- /dev/null +++ b/api/events/event_handlers/document_index_event.py @@ -0,0 +1,4 @@ +from blinker import signal + +# sender: dataset_id, kwargs: document_ids +document_index_created = signal("document-index-created") diff --git a/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py 
b/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py new file mode 100644 index 0000000000000000000000000000000000000000..14396e9920a2c2dd4c25f9cf0f65c8abf998b662 --- /dev/null +++ b/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py @@ -0,0 +1,68 @@ +from events.app_event import app_model_config_was_updated +from extensions.ext_database import db +from models.dataset import AppDatasetJoin +from models.model import AppModelConfig + + +@app_model_config_was_updated.connect +def handle(sender, **kwargs): + app = sender + app_model_config = kwargs.get("app_model_config") + if app_model_config is None: + return + + dataset_ids = get_dataset_ids_from_model_config(app_model_config) + + app_dataset_joins = db.session.query(AppDatasetJoin).filter(AppDatasetJoin.app_id == app.id).all() + + removed_dataset_ids: set[str] = set() + if not app_dataset_joins: + added_dataset_ids = dataset_ids + else: + old_dataset_ids: set[str] = set() + old_dataset_ids.update(app_dataset_join.dataset_id for app_dataset_join in app_dataset_joins) + + added_dataset_ids = dataset_ids - old_dataset_ids + removed_dataset_ids = old_dataset_ids - dataset_ids + + if removed_dataset_ids: + for dataset_id in removed_dataset_ids: + db.session.query(AppDatasetJoin).filter( + AppDatasetJoin.app_id == app.id, AppDatasetJoin.dataset_id == dataset_id + ).delete() + + if added_dataset_ids: + for dataset_id in added_dataset_ids: + app_dataset_join = AppDatasetJoin(app_id=app.id, dataset_id=dataset_id) + db.session.add(app_dataset_join) + + db.session.commit() + + +def get_dataset_ids_from_model_config(app_model_config: AppModelConfig) -> set[str]: + dataset_ids: set[str] = set() + if not app_model_config: + return dataset_ids + + agent_mode = app_model_config.agent_mode_dict + + tools = agent_mode.get("tools", []) or [] + for tool in tools: + if len(list(tool.keys())) != 1: + continue + + tool_type = list(tool.keys())[0] + tool_config = 
list(tool.values())[0] + if tool_type == "dataset": + dataset_ids.add(tool_config.get("id")) + + # get dataset from dataset_configs + dataset_configs = app_model_config.dataset_configs_dict + datasets = dataset_configs.get("datasets", {}) or {} + for dataset in datasets.get("datasets", []) or []: + keys = list(dataset.keys()) + if len(keys) == 1 and keys[0] == "dataset": + if dataset["dataset"].get("id"): + dataset_ids.add(dataset["dataset"].get("id")) + + return dataset_ids diff --git a/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py b/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py new file mode 100644 index 0000000000000000000000000000000000000000..dd2efed94bca7f2440623d46a95c4906b241b03f --- /dev/null +++ b/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py @@ -0,0 +1,67 @@ +from typing import cast + +from core.workflow.nodes import NodeType +from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData +from events.app_event import app_published_workflow_was_updated +from extensions.ext_database import db +from models.dataset import AppDatasetJoin +from models.workflow import Workflow + + +@app_published_workflow_was_updated.connect +def handle(sender, **kwargs): + app = sender + published_workflow = kwargs.get("published_workflow") + published_workflow = cast(Workflow, published_workflow) + + dataset_ids = get_dataset_ids_from_workflow(published_workflow) + app_dataset_joins = db.session.query(AppDatasetJoin).filter(AppDatasetJoin.app_id == app.id).all() + + removed_dataset_ids: set[str] = set() + if not app_dataset_joins: + added_dataset_ids = dataset_ids + else: + old_dataset_ids: set[str] = set() + old_dataset_ids.update(app_dataset_join.dataset_id for app_dataset_join in app_dataset_joins) + + added_dataset_ids = dataset_ids - old_dataset_ids + removed_dataset_ids = old_dataset_ids - dataset_ids + + if 
removed_dataset_ids: + for dataset_id in removed_dataset_ids: + db.session.query(AppDatasetJoin).filter( + AppDatasetJoin.app_id == app.id, AppDatasetJoin.dataset_id == dataset_id + ).delete() + + if added_dataset_ids: + for dataset_id in added_dataset_ids: + app_dataset_join = AppDatasetJoin(app_id=app.id, dataset_id=dataset_id) + db.session.add(app_dataset_join) + + db.session.commit() + + +def get_dataset_ids_from_workflow(published_workflow: Workflow) -> set[str]: + dataset_ids: set[str] = set() + graph = published_workflow.graph_dict + if not graph: + return dataset_ids + + nodes = graph.get("nodes", []) + + # fetch all knowledge retrieval nodes + knowledge_retrieval_nodes = [ + node for node in nodes if node.get("data", {}).get("type") == NodeType.KNOWLEDGE_RETRIEVAL.value + ] + + if not knowledge_retrieval_nodes: + return dataset_ids + + for node in knowledge_retrieval_nodes: + try: + node_data = KnowledgeRetrievalNodeData(**node.get("data", {})) + dataset_ids.update(dataset_id for dataset_id in node_data.dataset_ids) + except Exception as e: + continue + + return dataset_ids diff --git a/api/events/event_handlers/update_provider_last_used_at_when_message_created.py b/api/events/event_handlers/update_provider_last_used_at_when_message_created.py new file mode 100644 index 0000000000000000000000000000000000000000..f225ef8e880771235fa2e0284fe6493befb8d258 --- /dev/null +++ b/api/events/event_handlers/update_provider_last_used_at_when_message_created.py @@ -0,0 +1,21 @@ +from datetime import UTC, datetime + +from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ChatAppGenerateEntity +from events.message_event import message_was_created +from extensions.ext_database import db +from models.provider import Provider + + +@message_was_created.connect +def handle(sender, **kwargs): + message = sender + application_generate_entity = kwargs.get("application_generate_entity") + + if not isinstance(application_generate_entity, 
ChatAppGenerateEntity | AgentChatAppGenerateEntity): + return + + db.session.query(Provider).filter( + Provider.tenant_id == application_generate_entity.app_config.tenant_id, + Provider.provider_name == application_generate_entity.model_conf.provider, + ).update({"last_used": datetime.now(UTC).replace(tzinfo=None)}) + db.session.commit() diff --git a/api/events/message_event.py b/api/events/message_event.py new file mode 100644 index 0000000000000000000000000000000000000000..6576c35c453c9538f269861a9e05a276de8dc61a --- /dev/null +++ b/api/events/message_event.py @@ -0,0 +1,4 @@ +from blinker import signal + +# sender: message, kwargs: conversation, application_generate_entity +message_was_created = signal("message-was-created") diff --git a/api/events/tenant_event.py b/api/events/tenant_event.py new file mode 100644 index 0000000000000000000000000000000000000000..d99feaac40896d14c4cb01465cdb9089b8be8473 --- /dev/null +++ b/api/events/tenant_event.py @@ -0,0 +1,7 @@ +from blinker import signal + +# sender: tenant +tenant_was_created = signal("tenant-was-created") + +# sender: tenant +tenant_was_updated = signal("tenant-was-updated") diff --git a/api/extensions/__init__.py b/api/extensions/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/extensions/ext_app_metrics.py b/api/extensions/ext_app_metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..b7d412d68deda107ebc341cf019123405d57b4ed --- /dev/null +++ b/api/extensions/ext_app_metrics.py @@ -0,0 +1,67 @@ +import json +import os +import threading + +from flask import Response + +from configs import dify_config +from dify_app import DifyApp + + +def init_app(app: DifyApp): + @app.after_request + def after_request(response): + """Add Version headers to the response.""" + response.headers.add("X-Version", dify_config.CURRENT_VERSION) + response.headers.add("X-Env", dify_config.DEPLOY_ENV) + return response + + @app.route("/health") + 
def health(): + return Response( + json.dumps({"pid": os.getpid(), "status": "ok", "version": dify_config.CURRENT_VERSION}), + status=200, + content_type="application/json", + ) + + @app.route("/threads") + def threads(): + num_threads = threading.active_count() + threads = threading.enumerate() + + thread_list = [] + for thread in threads: + thread_name = thread.name + thread_id = thread.ident + is_alive = thread.is_alive() + + thread_list.append( + { + "name": thread_name, + "id": thread_id, + "is_alive": is_alive, + } + ) + + return { + "pid": os.getpid(), + "thread_num": num_threads, + "threads": thread_list, + } + + @app.route("/db-pool-stat") + def pool_stat(): + from extensions.ext_database import db + + engine = db.engine + # TODO: Fix the type error + # FIXME maybe its sqlalchemy issue + return { + "pid": os.getpid(), + "pool_size": engine.pool.size(), # type: ignore + "checked_in_connections": engine.pool.checkedin(), # type: ignore + "checked_out_connections": engine.pool.checkedout(), # type: ignore + "overflow_connections": engine.pool.overflow(), # type: ignore + "connection_timeout": engine.pool.timeout(), # type: ignore + "recycle_time": db.engine.pool._recycle, # type: ignore + } diff --git a/api/extensions/ext_blueprints.py b/api/extensions/ext_blueprints.py new file mode 100644 index 0000000000000000000000000000000000000000..316be12f5c14b8ad006bd463b586239b7506f95f --- /dev/null +++ b/api/extensions/ext_blueprints.py @@ -0,0 +1,48 @@ +from configs import dify_config +from dify_app import DifyApp + + +def init_app(app: DifyApp): + # register blueprint routers + + from flask_cors import CORS # type: ignore + + from controllers.console import bp as console_app_bp + from controllers.files import bp as files_bp + from controllers.inner_api import bp as inner_api_bp + from controllers.service_api import bp as service_api_bp + from controllers.web import bp as web_bp + + CORS( + service_api_bp, + allow_headers=["Content-Type", "Authorization", 
"X-App-Code"], + methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], + ) + app.register_blueprint(service_api_bp) + + CORS( + web_bp, + resources={r"/*": {"origins": dify_config.WEB_API_CORS_ALLOW_ORIGINS}}, + supports_credentials=True, + allow_headers=["Content-Type", "Authorization", "X-App-Code"], + methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], + expose_headers=["X-Version", "X-Env"], + ) + + app.register_blueprint(web_bp) + + CORS( + console_app_bp, + resources={r"/*": {"origins": dify_config.CONSOLE_CORS_ALLOW_ORIGINS}}, + supports_credentials=True, + allow_headers=["Content-Type", "Authorization"], + methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], + expose_headers=["X-Version", "X-Env"], + ) + + app.register_blueprint(console_app_bp) + + CORS(files_bp, allow_headers=["Content-Type"], methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"]) + app.register_blueprint(files_bp) + + app.register_blueprint(inner_api_bp) diff --git a/api/extensions/ext_celery.py b/api/extensions/ext_celery.py new file mode 100644 index 0000000000000000000000000000000000000000..26bd6b357712c9a5e499349fe8f685fe51669e78 --- /dev/null +++ b/api/extensions/ext_celery.py @@ -0,0 +1,104 @@ +from datetime import timedelta + +import pytz +from celery import Celery, Task # type: ignore +from celery.schedules import crontab # type: ignore + +from configs import dify_config +from dify_app import DifyApp + + +def init_app(app: DifyApp) -> Celery: + class FlaskTask(Task): + def __call__(self, *args: object, **kwargs: object) -> object: + with app.app_context(): + return self.run(*args, **kwargs) + + broker_transport_options = {} + + if dify_config.CELERY_USE_SENTINEL: + broker_transport_options = { + "master_name": dify_config.CELERY_SENTINEL_MASTER_NAME, + "sentinel_kwargs": { + "socket_timeout": dify_config.CELERY_SENTINEL_SOCKET_TIMEOUT, + }, + } + + celery_app = Celery( + app.name, + task_cls=FlaskTask, + broker=dify_config.CELERY_BROKER_URL, + 
backend=dify_config.CELERY_BACKEND, + task_ignore_result=True, + ) + + # Add SSL options to the Celery configuration + ssl_options = { + "ssl_cert_reqs": None, + "ssl_ca_certs": None, + "ssl_certfile": None, + "ssl_keyfile": None, + } + + celery_app.conf.update( + result_backend=dify_config.CELERY_RESULT_BACKEND, + broker_transport_options=broker_transport_options, + broker_connection_retry_on_startup=True, + worker_log_format=dify_config.LOG_FORMAT, + worker_task_log_format=dify_config.LOG_FORMAT, + worker_hijack_root_logger=False, + timezone=pytz.timezone(dify_config.LOG_TZ or "UTC"), + ) + + if dify_config.BROKER_USE_SSL: + celery_app.conf.update( + broker_use_ssl=ssl_options, # Add the SSL options to the broker configuration + ) + + if dify_config.LOG_FILE: + celery_app.conf.update( + worker_logfile=dify_config.LOG_FILE, + ) + + celery_app.set_default() + app.extensions["celery"] = celery_app + + imports = [ + "schedule.clean_embedding_cache_task", + "schedule.clean_unused_datasets_task", + "schedule.create_tidb_serverless_task", + "schedule.update_tidb_serverless_status_task", + "schedule.clean_messages", + "schedule.mail_clean_document_notify_task", + ] + day = dify_config.CELERY_BEAT_SCHEDULER_TIME + beat_schedule = { + "clean_embedding_cache_task": { + "task": "schedule.clean_embedding_cache_task.clean_embedding_cache_task", + "schedule": timedelta(days=day), + }, + "clean_unused_datasets_task": { + "task": "schedule.clean_unused_datasets_task.clean_unused_datasets_task", + "schedule": timedelta(days=day), + }, + "create_tidb_serverless_task": { + "task": "schedule.create_tidb_serverless_task.create_tidb_serverless_task", + "schedule": crontab(minute="0", hour="*"), + }, + "update_tidb_serverless_status_task": { + "task": "schedule.update_tidb_serverless_status_task.update_tidb_serverless_status_task", + "schedule": timedelta(minutes=10), + }, + "clean_messages": { + "task": "schedule.clean_messages.clean_messages", + "schedule": timedelta(days=day), + }, + 
# every Monday + "mail_clean_document_notify_task": { + "task": "schedule.mail_clean_document_notify_task.mail_clean_document_notify_task", + "schedule": crontab(minute="0", hour="10", day_of_week="1"), + }, + } + celery_app.conf.update(beat_schedule=beat_schedule, imports=imports) + + return celery_app diff --git a/api/extensions/ext_code_based_extension.py b/api/extensions/ext_code_based_extension.py new file mode 100644 index 0000000000000000000000000000000000000000..9e4b4a41d917c2870210a799cc4437788739f0f7 --- /dev/null +++ b/api/extensions/ext_code_based_extension.py @@ -0,0 +1,9 @@ +from core.extension.extension import Extension +from dify_app import DifyApp + + +def init_app(app: DifyApp): + code_based_extension.init() + + +code_based_extension = Extension() diff --git a/api/extensions/ext_commands.py b/api/extensions/ext_commands.py new file mode 100644 index 0000000000000000000000000000000000000000..ccf0d316ca486e5cbb997c0d44149ee9aff6860e --- /dev/null +++ b/api/extensions/ext_commands.py @@ -0,0 +1,29 @@ +from dify_app import DifyApp + + +def init_app(app: DifyApp): + from commands import ( + add_qdrant_doc_id_index, + convert_to_agent_apps, + create_tenant, + fix_app_site_missing, + reset_email, + reset_encrypt_key_pair, + reset_password, + upgrade_db, + vdb_migrate, + ) + + cmds_to_register = [ + reset_password, + reset_email, + reset_encrypt_key_pair, + vdb_migrate, + convert_to_agent_apps, + add_qdrant_doc_id_index, + create_tenant, + upgrade_db, + fix_app_site_missing, + ] + for cmd in cmds_to_register: + app.cli.add_command(cmd) diff --git a/api/extensions/ext_compress.py b/api/extensions/ext_compress.py new file mode 100644 index 0000000000000000000000000000000000000000..26ff6427bef1ccaaf40cc21d52dd283772fd983d --- /dev/null +++ b/api/extensions/ext_compress.py @@ -0,0 +1,13 @@ +from configs import dify_config +from dify_app import DifyApp + + +def is_enabled() -> bool: + return dify_config.API_COMPRESSION_ENABLED + + +def init_app(app: DifyApp): 
+ from flask_compress import Compress # type: ignore + + compress = Compress() + compress.init_app(app) diff --git a/api/extensions/ext_database.py b/api/extensions/ext_database.py new file mode 100644 index 0000000000000000000000000000000000000000..93842a303683bb31d6f75d2934a0c28d2bdc862a --- /dev/null +++ b/api/extensions/ext_database.py @@ -0,0 +1,6 @@ +from dify_app import DifyApp +from models import db + + +def init_app(app: DifyApp): + db.init_app(app) diff --git a/api/extensions/ext_hosting_provider.py b/api/extensions/ext_hosting_provider.py new file mode 100644 index 0000000000000000000000000000000000000000..3980eccf8edc2aa56ca70834d3d24205da0ec752 --- /dev/null +++ b/api/extensions/ext_hosting_provider.py @@ -0,0 +1,10 @@ +from core.hosting_configuration import HostingConfiguration + +hosting_configuration = HostingConfiguration() + + +from dify_app import DifyApp + + +def init_app(app: DifyApp): + hosting_configuration.init_app(app) diff --git a/api/extensions/ext_import_modules.py b/api/extensions/ext_import_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..9566f430b647fe2dfaf774802a5239da68b26e92 --- /dev/null +++ b/api/extensions/ext_import_modules.py @@ -0,0 +1,5 @@ +from dify_app import DifyApp + + +def init_app(app: DifyApp): + from events import event_handlers # noqa: F401 diff --git a/api/extensions/ext_logging.py b/api/extensions/ext_logging.py new file mode 100644 index 0000000000000000000000000000000000000000..bf9b492a506970aa0f05b2ea75228e3340889bed --- /dev/null +++ b/api/extensions/ext_logging.py @@ -0,0 +1,71 @@ +import logging +import os +import sys +import uuid +from logging.handlers import RotatingFileHandler + +import flask + +from configs import dify_config +from dify_app import DifyApp + + +def init_app(app: DifyApp): + log_handlers: list[logging.Handler] = [] + log_file = dify_config.LOG_FILE + if log_file: + log_dir = os.path.dirname(log_file) + os.makedirs(log_dir, exist_ok=True) + 
log_handlers.append( + RotatingFileHandler( + filename=log_file, + maxBytes=dify_config.LOG_FILE_MAX_SIZE * 1024 * 1024, + backupCount=dify_config.LOG_FILE_BACKUP_COUNT, + ) + ) + + # Always add StreamHandler to log to console + sh = logging.StreamHandler(sys.stdout) + sh.addFilter(RequestIdFilter()) + log_handlers.append(sh) + + logging.basicConfig( + level=dify_config.LOG_LEVEL, + format=dify_config.LOG_FORMAT, + datefmt=dify_config.LOG_DATEFORMAT, + handlers=log_handlers, + force=True, + ) + log_tz = dify_config.LOG_TZ + if log_tz: + from datetime import datetime + + import pytz + + timezone = pytz.timezone(log_tz) + + def time_converter(seconds): + return datetime.fromtimestamp(seconds, tz=timezone).timetuple() + + for handler in logging.root.handlers: + if handler.formatter: + handler.formatter.converter = time_converter + + +def get_request_id(): + if getattr(flask.g, "request_id", None): + return flask.g.request_id + + new_uuid = uuid.uuid4().hex[:10] + flask.g.request_id = new_uuid + + return new_uuid + + +class RequestIdFilter(logging.Filter): + # This is a logging filter that makes the request ID available for use in + # the logging format. Note that we're checking if we're in a request + # context, as we may want to log things before Flask is fully loaded. 
+ def filter(self, record): + record.req_id = get_request_id() if flask.has_request_context() else "" + return True diff --git a/api/extensions/ext_login.py b/api/extensions/ext_login.py new file mode 100644 index 0000000000000000000000000000000000000000..10fb89eb7370eefd34a11e108099357c089d9ca1 --- /dev/null +++ b/api/extensions/ext_login.py @@ -0,0 +1,62 @@ +import json + +import flask_login # type: ignore +from flask import Response, request +from flask_login import user_loaded_from_request, user_logged_in +from werkzeug.exceptions import Unauthorized + +import contexts +from dify_app import DifyApp +from libs.passport import PassportService +from services.account_service import AccountService + +login_manager = flask_login.LoginManager() + + +# Flask-Login configuration +@login_manager.request_loader +def load_user_from_request(request_from_flask_login): + """Load user based on the request.""" + if request.blueprint not in {"console", "inner_api"}: + return None + # Check if the user_id contains a dot, indicating the old format + auth_header = request.headers.get("Authorization", "") + if not auth_header: + auth_token = request.args.get("_token") + if not auth_token: + raise Unauthorized("Invalid Authorization token.") + else: + if " " not in auth_header: + raise Unauthorized("Invalid Authorization header format. Expected 'Bearer ' format.") + auth_scheme, auth_token = auth_header.split(None, 1) + auth_scheme = auth_scheme.lower() + if auth_scheme != "bearer": + raise Unauthorized("Invalid Authorization header format. 
Expected 'Bearer ' format.") + + decoded = PassportService().verify(auth_token) + user_id = decoded.get("user_id") + + logged_in_account = AccountService.load_logged_in_account(account_id=user_id) + return logged_in_account + + +@user_logged_in.connect +@user_loaded_from_request.connect +def on_user_logged_in(_sender, user): + """Called when a user logged in.""" + if user: + contexts.tenant_id.set(user.current_tenant_id) + + +@login_manager.unauthorized_handler +def unauthorized_handler(): + """Handle unauthorized requests.""" + return Response( + json.dumps({"code": "unauthorized", "message": "Unauthorized."}), + status=401, + content_type="application/json", + ) + + +def init_app(app: DifyApp): + login_manager.init_app(app) diff --git a/api/extensions/ext_mail.py b/api/extensions/ext_mail.py new file mode 100644 index 0000000000000000000000000000000000000000..9240ebe7fcba73f06e62c0ce7b2a0be338a55bb2 --- /dev/null +++ b/api/extensions/ext_mail.py @@ -0,0 +1,97 @@ +import logging +from typing import Optional + +from flask import Flask + +from configs import dify_config +from dify_app import DifyApp + + +class Mail: + def __init__(self): + self._client = None + self._default_send_from = None + + def is_inited(self) -> bool: + return self._client is not None + + def init_app(self, app: Flask): + mail_type = dify_config.MAIL_TYPE + if not mail_type: + logging.warning("MAIL_TYPE is not set") + return + + if dify_config.MAIL_DEFAULT_SEND_FROM: + self._default_send_from = dify_config.MAIL_DEFAULT_SEND_FROM + + match mail_type: + case "resend": + import resend # type: ignore + + api_key = dify_config.RESEND_API_KEY + if not api_key: + raise ValueError("RESEND_API_KEY is not set") + + api_url = dify_config.RESEND_API_URL + if api_url: + resend.api_url = api_url + + resend.api_key = api_key + self._client = resend.Emails + case "smtp": + from libs.smtp import SMTPClient + + if not dify_config.SMTP_SERVER or not dify_config.SMTP_PORT: + raise ValueError("SMTP_SERVER and 
SMTP_PORT are required for smtp mail type") + if not dify_config.SMTP_USE_TLS and dify_config.SMTP_OPPORTUNISTIC_TLS: + raise ValueError("SMTP_OPPORTUNISTIC_TLS is not supported without enabling SMTP_USE_TLS") + self._client = SMTPClient( + server=dify_config.SMTP_SERVER, + port=dify_config.SMTP_PORT, + username=dify_config.SMTP_USERNAME or "", + password=dify_config.SMTP_PASSWORD or "", + _from=dify_config.MAIL_DEFAULT_SEND_FROM or "", + use_tls=dify_config.SMTP_USE_TLS, + opportunistic_tls=dify_config.SMTP_OPPORTUNISTIC_TLS, + ) + case _: + raise ValueError("Unsupported mail type {}".format(mail_type)) + + def send(self, to: str, subject: str, html: str, from_: Optional[str] = None): + if not self._client: + raise ValueError("Mail client is not initialized") + + if not from_ and self._default_send_from: + from_ = self._default_send_from + + if not from_: + raise ValueError("mail from is not set") + + if not to: + raise ValueError("mail to is not set") + + if not subject: + raise ValueError("mail subject is not set") + + if not html: + raise ValueError("mail html is not set") + + self._client.send( + { + "from": from_, + "to": to, + "subject": subject, + "html": html, + } + ) + + +def is_enabled() -> bool: + return dify_config.MAIL_TYPE is not None and dify_config.MAIL_TYPE != "" + + +def init_app(app: DifyApp): + mail.init_app(app) + + +mail = Mail() diff --git a/api/extensions/ext_migrate.py b/api/extensions/ext_migrate.py new file mode 100644 index 0000000000000000000000000000000000000000..5f862181fa8540d570fcb34b1d69eb07ea7c93a5 --- /dev/null +++ b/api/extensions/ext_migrate.py @@ -0,0 +1,9 @@ +from dify_app import DifyApp + + +def init_app(app: DifyApp): + import flask_migrate # type: ignore + + from extensions.ext_database import db + + flask_migrate.Migrate(app, db) diff --git a/api/extensions/ext_proxy_fix.py b/api/extensions/ext_proxy_fix.py new file mode 100644 index 0000000000000000000000000000000000000000..c085aed98643d33623ff383b5ff7f100c7d5c5de --- 
/dev/null +++ b/api/extensions/ext_proxy_fix.py @@ -0,0 +1,9 @@ +from configs import dify_config +from dify_app import DifyApp + + +def init_app(app: DifyApp): + if dify_config.RESPECT_XFORWARD_HEADERS_ENABLED: + from werkzeug.middleware.proxy_fix import ProxyFix + + app.wsgi_app = ProxyFix(app.wsgi_app, x_port=1) # type: ignore diff --git a/api/extensions/ext_redis.py b/api/extensions/ext_redis.py new file mode 100644 index 0000000000000000000000000000000000000000..da4180570793e5c60524f555aeaa89bb99a26a64 --- /dev/null +++ b/api/extensions/ext_redis.py @@ -0,0 +1,98 @@ +from typing import Any, Union + +import redis +from redis.cluster import ClusterNode, RedisCluster +from redis.connection import Connection, SSLConnection +from redis.sentinel import Sentinel + +from configs import dify_config +from dify_app import DifyApp + + +class RedisClientWrapper: + """ + A wrapper class for the Redis client that addresses the issue where the global + `redis_client` variable cannot be updated when a new Redis instance is returned + by Sentinel. + + This class allows for deferred initialization of the Redis client, enabling the + client to be re-initialized with a new instance when necessary. This is particularly + useful in scenarios where the Redis instance may change dynamically, such as during + a failover in a Sentinel-managed Redis setup. + + Attributes: + _client (redis.Redis): The actual Redis client instance. It remains None until + initialized with the `initialize` method. + + Methods: + initialize(client): Initializes the Redis client if it hasn't been initialized already. + __getattr__(item): Delegates attribute access to the Redis client, raising an error + if the client is not initialized. + """ + + def __init__(self): + self._client = None + + def initialize(self, client): + if self._client is None: + self._client = client + + def __getattr__(self, item): + if self._client is None: + raise RuntimeError("Redis client is not initialized. 
Call init_app first.") + return getattr(self._client, item) + + +redis_client = RedisClientWrapper() + + +def init_app(app: DifyApp): + global redis_client + connection_class: type[Union[Connection, SSLConnection]] = Connection + if dify_config.REDIS_USE_SSL: + connection_class = SSLConnection + + redis_params: dict[str, Any] = { + "username": dify_config.REDIS_USERNAME, + "password": dify_config.REDIS_PASSWORD, + "db": dify_config.REDIS_DB, + "encoding": "utf-8", + "encoding_errors": "strict", + "decode_responses": False, + } + + if dify_config.REDIS_USE_SENTINEL: + assert dify_config.REDIS_SENTINELS is not None, "REDIS_SENTINELS must be set when REDIS_USE_SENTINEL is True" + sentinel_hosts = [ + (node.split(":")[0], int(node.split(":")[1])) for node in dify_config.REDIS_SENTINELS.split(",") + ] + sentinel = Sentinel( + sentinel_hosts, + sentinel_kwargs={ + "socket_timeout": dify_config.REDIS_SENTINEL_SOCKET_TIMEOUT, + "username": dify_config.REDIS_SENTINEL_USERNAME, + "password": dify_config.REDIS_SENTINEL_PASSWORD, + }, + ) + master = sentinel.master_for(dify_config.REDIS_SENTINEL_SERVICE_NAME, **redis_params) + redis_client.initialize(master) + elif dify_config.REDIS_USE_CLUSTERS: + assert dify_config.REDIS_CLUSTERS is not None, "REDIS_CLUSTERS must be set when REDIS_USE_CLUSTERS is True" + nodes = [ + ClusterNode(host=node.split(":")[0], port=int(node.split(":")[1])) + for node in dify_config.REDIS_CLUSTERS.split(",") + ] + # FIXME: mypy error here, try to figure out how to fix it + redis_client.initialize(RedisCluster(startup_nodes=nodes, password=dify_config.REDIS_CLUSTERS_PASSWORD)) # type: ignore + else: + redis_params.update( + { + "host": dify_config.REDIS_HOST, + "port": dify_config.REDIS_PORT, + "connection_class": connection_class, + } + ) + pool = redis.ConnectionPool(**redis_params) + redis_client.initialize(redis.Redis(connection_pool=pool)) + + app.extensions["redis"] = redis_client diff --git a/api/extensions/ext_sentry.py 
b/api/extensions/ext_sentry.py new file mode 100644 index 0000000000000000000000000000000000000000..3a74aace6a34cf551d55ce8058171b9ad64e5642 --- /dev/null +++ b/api/extensions/ext_sentry.py @@ -0,0 +1,40 @@ +from configs import dify_config +from dify_app import DifyApp + + +def init_app(app: DifyApp): + if dify_config.SENTRY_DSN: + import openai + import sentry_sdk + from langfuse import parse_error # type: ignore + from sentry_sdk.integrations.celery import CeleryIntegration + from sentry_sdk.integrations.flask import FlaskIntegration + from werkzeug.exceptions import HTTPException + + from core.model_runtime.errors.invoke import InvokeRateLimitError + + def before_send(event, hint): + if "exc_info" in hint: + exc_type, exc_value, tb = hint["exc_info"] + if parse_error.defaultErrorResponse in str(exc_value): + return None + + return event + + sentry_sdk.init( + dsn=dify_config.SENTRY_DSN, + integrations=[FlaskIntegration(), CeleryIntegration()], + ignore_errors=[ + HTTPException, + ValueError, + FileNotFoundError, + openai.APIStatusError, + InvokeRateLimitError, + parse_error.defaultErrorResponse, + ], + traces_sample_rate=dify_config.SENTRY_TRACES_SAMPLE_RATE, + profiles_sample_rate=dify_config.SENTRY_PROFILES_SAMPLE_RATE, + environment=dify_config.DEPLOY_ENV, + release=f"dify-{dify_config.CURRENT_VERSION}-{dify_config.COMMIT_SHA}", + before_send=before_send, + ) diff --git a/api/extensions/ext_set_secretkey.py b/api/extensions/ext_set_secretkey.py new file mode 100644 index 0000000000000000000000000000000000000000..dfb87c0167dfbfe85a34de5b2957916144412652 --- /dev/null +++ b/api/extensions/ext_set_secretkey.py @@ -0,0 +1,6 @@ +from configs import dify_config +from dify_app import DifyApp + + +def init_app(app: DifyApp): + app.secret_key = dify_config.SECRET_KEY diff --git a/api/extensions/ext_storage.py b/api/extensions/ext_storage.py new file mode 100644 index 0000000000000000000000000000000000000000..588bdb2d2717e089abec20ace1eae4788e45a3cf --- /dev/null +++ 
b/api/extensions/ext_storage.py @@ -0,0 +1,138 @@ +import logging +from collections.abc import Callable, Generator +from typing import Literal, Union, overload + +from flask import Flask + +from configs import dify_config +from dify_app import DifyApp +from extensions.storage.base_storage import BaseStorage +from extensions.storage.storage_type import StorageType + +logger = logging.getLogger(__name__) + + +class Storage: + def init_app(self, app: Flask): + storage_factory = self.get_storage_factory(dify_config.STORAGE_TYPE) + with app.app_context(): + self.storage_runner = storage_factory() + + @staticmethod + def get_storage_factory(storage_type: str) -> Callable[[], BaseStorage]: + match storage_type: + case StorageType.S3: + from extensions.storage.aws_s3_storage import AwsS3Storage + + return AwsS3Storage + case StorageType.OPENDAL: + from extensions.storage.opendal_storage import OpenDALStorage + + return lambda: OpenDALStorage(dify_config.OPENDAL_SCHEME) + case StorageType.LOCAL: + from extensions.storage.opendal_storage import OpenDALStorage + + return lambda: OpenDALStorage(scheme="fs", root=dify_config.STORAGE_LOCAL_PATH) + case StorageType.AZURE_BLOB: + from extensions.storage.azure_blob_storage import AzureBlobStorage + + return AzureBlobStorage + case StorageType.ALIYUN_OSS: + from extensions.storage.aliyun_oss_storage import AliyunOssStorage + + return AliyunOssStorage + case StorageType.GOOGLE_STORAGE: + from extensions.storage.google_cloud_storage import GoogleCloudStorage + + return GoogleCloudStorage + case StorageType.TENCENT_COS: + from extensions.storage.tencent_cos_storage import TencentCosStorage + + return TencentCosStorage + case StorageType.OCI_STORAGE: + from extensions.storage.oracle_oci_storage import OracleOCIStorage + + return OracleOCIStorage + case StorageType.HUAWEI_OBS: + from extensions.storage.huawei_obs_storage import HuaweiObsStorage + + return HuaweiObsStorage + case StorageType.BAIDU_OBS: + from 
extensions.storage.baidu_obs_storage import BaiduObsStorage + + return BaiduObsStorage + case StorageType.VOLCENGINE_TOS: + from extensions.storage.volcengine_tos_storage import VolcengineTosStorage + + return VolcengineTosStorage + case StorageType.SUPBASE: + from extensions.storage.supabase_storage import SupabaseStorage + + return SupabaseStorage + case _: + raise ValueError(f"unsupported storage type {storage_type}") + + def save(self, filename, data): + try: + self.storage_runner.save(filename, data) + except Exception as e: + logger.exception(f"Failed to save file {filename}") + raise e + + @overload + def load(self, filename: str, /, *, stream: Literal[False] = False) -> bytes: ... + + @overload + def load(self, filename: str, /, *, stream: Literal[True]) -> Generator: ... + + def load(self, filename: str, /, *, stream: bool = False) -> Union[bytes, Generator]: + try: + if stream: + return self.load_stream(filename) + else: + return self.load_once(filename) + except Exception as e: + logger.exception(f"Failed to load file {filename}") + raise e + + def load_once(self, filename: str) -> bytes: + try: + return self.storage_runner.load_once(filename) + except Exception as e: + logger.exception(f"Failed to load_once file {filename}") + raise e + + def load_stream(self, filename: str) -> Generator: + try: + return self.storage_runner.load_stream(filename) + except Exception as e: + logger.exception(f"Failed to load_stream file {filename}") + raise e + + def download(self, filename, target_filepath): + try: + self.storage_runner.download(filename, target_filepath) + except Exception as e: + logger.exception(f"Failed to download file {filename}") + raise e + + def exists(self, filename): + try: + return self.storage_runner.exists(filename) + except Exception as e: + logger.exception(f"Failed to check file exists {filename}") + raise e + + def delete(self, filename): + try: + return self.storage_runner.delete(filename) + except Exception as e: + 
logger.exception(f"Failed to delete file {filename}") + raise e + + +storage = Storage() + + +def init_app(app: DifyApp): + storage.init_app(app) diff --git a/api/extensions/ext_timezone.py b/api/extensions/ext_timezone.py new file mode 100644 index 0000000000000000000000000000000000000000..77650bf972a0b665ef898208128d0f5cf767c0b8 --- /dev/null +++ b/api/extensions/ext_timezone.py @@ -0,0 +1,11 @@ +import os +import time + +from dify_app import DifyApp + + +def init_app(app: DifyApp): + os.environ["TZ"] = "UTC" + # windows platform not support tzset + if hasattr(time, "tzset"): + time.tzset() diff --git a/api/extensions/ext_warnings.py b/api/extensions/ext_warnings.py new file mode 100644 index 0000000000000000000000000000000000000000..246f977af5e436f0837fe6e79965deec6b9b3da9 --- /dev/null +++ b/api/extensions/ext_warnings.py @@ -0,0 +1,7 @@ +from dify_app import DifyApp + + +def init_app(app: DifyApp): + import warnings + + warnings.simplefilter("ignore", ResourceWarning) diff --git a/api/extensions/storage/aliyun_oss_storage.py b/api/extensions/storage/aliyun_oss_storage.py new file mode 100644 index 0000000000000000000000000000000000000000..00bf5d4f93ae3b640b09a8dbfbc954a44decc847 --- /dev/null +++ b/api/extensions/storage/aliyun_oss_storage.py @@ -0,0 +1,54 @@ +import posixpath +from collections.abc import Generator + +import oss2 as aliyun_s3 # type: ignore + +from configs import dify_config +from extensions.storage.base_storage import BaseStorage + + +class AliyunOssStorage(BaseStorage): + """Implementation for Aliyun OSS storage.""" + + def __init__(self): + super().__init__() + self.bucket_name = dify_config.ALIYUN_OSS_BUCKET_NAME + self.folder = dify_config.ALIYUN_OSS_PATH + oss_auth_method = aliyun_s3.Auth + region = None + if dify_config.ALIYUN_OSS_AUTH_VERSION == "v4": + oss_auth_method = aliyun_s3.AuthV4 + region = dify_config.ALIYUN_OSS_REGION + oss_auth = oss_auth_method(dify_config.ALIYUN_OSS_ACCESS_KEY, dify_config.ALIYUN_OSS_SECRET_KEY) + 
self.client = aliyun_s3.Bucket( + oss_auth, + dify_config.ALIYUN_OSS_ENDPOINT, + self.bucket_name, + connect_timeout=30, + region=region, + ) + + def save(self, filename, data): + self.client.put_object(self.__wrapper_folder_filename(filename), data) + + def load_once(self, filename: str) -> bytes: + obj = self.client.get_object(self.__wrapper_folder_filename(filename)) + data: bytes = obj.read() + return data + + def load_stream(self, filename: str) -> Generator: + obj = self.client.get_object(self.__wrapper_folder_filename(filename)) + while chunk := obj.read(4096): + yield chunk + + def download(self, filename: str, target_filepath): + self.client.get_object_to_file(self.__wrapper_folder_filename(filename), target_filepath) + + def exists(self, filename: str): + return self.client.object_exists(self.__wrapper_folder_filename(filename)) + + def delete(self, filename: str): + self.client.delete_object(self.__wrapper_folder_filename(filename)) + + def __wrapper_folder_filename(self, filename: str) -> str: + return posixpath.join(self.folder, filename) if self.folder else filename diff --git a/api/extensions/storage/aws_s3_storage.py b/api/extensions/storage/aws_s3_storage.py new file mode 100644 index 0000000000000000000000000000000000000000..ce6b194f4063f12cc1e768b1ee5b09ce2722fcb1 --- /dev/null +++ b/api/extensions/storage/aws_s3_storage.py @@ -0,0 +1,91 @@ +import logging +from collections.abc import Generator + +import boto3 # type: ignore +from botocore.client import Config # type: ignore +from botocore.exceptions import ClientError # type: ignore + +from configs import dify_config +from extensions.storage.base_storage import BaseStorage + +logger = logging.getLogger(__name__) + + +class AwsS3Storage(BaseStorage): + """Implementation for Amazon Web Services S3 storage.""" + + def __init__(self): + super().__init__() + self.bucket_name = dify_config.S3_BUCKET_NAME + if dify_config.S3_USE_AWS_MANAGED_IAM: + logger.info("Using AWS managed IAM role for S3") + + 
session = boto3.Session() + region_name = dify_config.S3_REGION + self.client = session.client(service_name="s3", region_name=region_name) + else: + logger.info("Using ak and sk for S3") + + self.client = boto3.client( + "s3", + aws_secret_access_key=dify_config.S3_SECRET_KEY, + aws_access_key_id=dify_config.S3_ACCESS_KEY, + endpoint_url=dify_config.S3_ENDPOINT, + region_name=dify_config.S3_REGION, + config=Config( + s3={"addressing_style": dify_config.S3_ADDRESS_STYLE}, + request_checksum_calculation="when_required", + response_checksum_validation="when_required", + ), + ) + # create bucket + try: + self.client.head_bucket(Bucket=self.bucket_name) + except ClientError as e: + # if bucket not exists, create it + if e.response["Error"]["Code"] == "404": + self.client.create_bucket(Bucket=self.bucket_name) + # if bucket is not accessible, pass, maybe the bucket is existing but not accessible + elif e.response["Error"]["Code"] == "403": + pass + else: + # other error, raise exception + raise + + def save(self, filename, data): + self.client.put_object(Bucket=self.bucket_name, Key=filename, Body=data) + + def load_once(self, filename: str) -> bytes: + try: + data: bytes = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].read() + except ClientError as ex: + if ex.response["Error"]["Code"] == "NoSuchKey": + raise FileNotFoundError("File not found") + else: + raise + return data + + def load_stream(self, filename: str) -> Generator: + try: + response = self.client.get_object(Bucket=self.bucket_name, Key=filename) + yield from response["Body"].iter_chunks() + except ClientError as ex: + if ex.response["Error"]["Code"] == "NoSuchKey": + raise FileNotFoundError("file not found") + elif "reached max retries" in str(ex): + raise ValueError("please do not request the same file too frequently") + else: + raise + + def download(self, filename, target_filepath): + self.client.download_file(self.bucket_name, filename, target_filepath) + + def exists(self, 
filename): + try: + self.client.head_object(Bucket=self.bucket_name, Key=filename) + return True + except: + return False + + def delete(self, filename): + self.client.delete_object(Bucket=self.bucket_name, Key=filename) diff --git a/api/extensions/storage/azure_blob_storage.py b/api/extensions/storage/azure_blob_storage.py new file mode 100644 index 0000000000000000000000000000000000000000..7448fd4a6bb4becd751dade543d9f93b99f2cd4c --- /dev/null +++ b/api/extensions/storage/azure_blob_storage.py @@ -0,0 +1,84 @@ +from collections.abc import Generator +from datetime import UTC, datetime, timedelta +from typing import Optional + +from azure.identity import ChainedTokenCredential, DefaultAzureCredential +from azure.storage.blob import AccountSasPermissions, BlobServiceClient, ResourceTypes, generate_account_sas + +from configs import dify_config +from extensions.ext_redis import redis_client +from extensions.storage.base_storage import BaseStorage + + +class AzureBlobStorage(BaseStorage): + """Implementation for Azure Blob storage.""" + + def __init__(self): + super().__init__() + self.bucket_name = dify_config.AZURE_BLOB_CONTAINER_NAME + self.account_url = dify_config.AZURE_BLOB_ACCOUNT_URL + self.account_name = dify_config.AZURE_BLOB_ACCOUNT_NAME + self.account_key = dify_config.AZURE_BLOB_ACCOUNT_KEY + + self.credential: Optional[ChainedTokenCredential] = None + if self.account_key == "managedidentity": + self.credential = DefaultAzureCredential() + else: + self.credential = None + + def save(self, filename, data): + client = self._sync_client() + blob_container = client.get_container_client(container=self.bucket_name) + blob_container.upload_blob(filename, data) + + def load_once(self, filename: str) -> bytes: + client = self._sync_client() + blob = client.get_container_client(container=self.bucket_name) + blob = blob.get_blob_client(blob=filename) + data: bytes = blob.download_blob().readall() + return data + + def load_stream(self, filename: str) -> Generator: + 
client = self._sync_client() + blob = client.get_blob_client(container=self.bucket_name, blob=filename) + blob_data = blob.download_blob() + yield from blob_data.chunks() + + def download(self, filename, target_filepath): + client = self._sync_client() + + blob = client.get_blob_client(container=self.bucket_name, blob=filename) + with open(target_filepath, "wb") as my_blob: + blob_data = blob.download_blob() + blob_data.readinto(my_blob) + + def exists(self, filename): + client = self._sync_client() + + blob = client.get_blob_client(container=self.bucket_name, blob=filename) + return blob.exists() + + def delete(self, filename): + client = self._sync_client() + + blob_container = client.get_container_client(container=self.bucket_name) + blob_container.delete_blob(filename) + + def _sync_client(self): + if self.account_key == "managedidentity": + return BlobServiceClient(account_url=self.account_url, credential=self.credential) # type: ignore + + cache_key = "azure_blob_sas_token_{}_{}".format(self.account_name, self.account_key) + cache_result = redis_client.get(cache_key) + if cache_result is not None: + sas_token = cache_result.decode("utf-8") + else: + sas_token = generate_account_sas( + account_name=self.account_name or "", + account_key=self.account_key or "", + resource_types=ResourceTypes(service=True, container=True, object=True), + permission=AccountSasPermissions(read=True, write=True, delete=True, list=True, add=True, create=True), + expiry=datetime.now(UTC).replace(tzinfo=None) + timedelta(hours=1), + ) + redis_client.set(cache_key, sas_token, ex=3000) + return BlobServiceClient(account_url=self.account_url or "", credential=sas_token) diff --git a/api/extensions/storage/baidu_obs_storage.py b/api/extensions/storage/baidu_obs_storage.py new file mode 100644 index 0000000000000000000000000000000000000000..b94efa08be761384ce757d9b4f56eec1d792733d --- /dev/null +++ b/api/extensions/storage/baidu_obs_storage.py @@ -0,0 +1,57 @@ +import base64 +import 
hashlib +from collections.abc import Generator + +from baidubce.auth.bce_credentials import BceCredentials # type: ignore +from baidubce.bce_client_configuration import BceClientConfiguration # type: ignore +from baidubce.services.bos.bos_client import BosClient # type: ignore + +from configs import dify_config +from extensions.storage.base_storage import BaseStorage + + +class BaiduObsStorage(BaseStorage): + """Implementation for Baidu OBS storage.""" + + def __init__(self): + super().__init__() + self.bucket_name = dify_config.BAIDU_OBS_BUCKET_NAME + client_config = BceClientConfiguration( + credentials=BceCredentials( + access_key_id=dify_config.BAIDU_OBS_ACCESS_KEY, + secret_access_key=dify_config.BAIDU_OBS_SECRET_KEY, + ), + endpoint=dify_config.BAIDU_OBS_ENDPOINT, + ) + + self.client = BosClient(config=client_config) + + def save(self, filename, data): + md5 = hashlib.md5() + md5.update(data) + content_md5 = base64.standard_b64encode(md5.digest()) + self.client.put_object( + bucket_name=self.bucket_name, key=filename, data=data, content_length=len(data), content_md5=content_md5 + ) + + def load_once(self, filename: str) -> bytes: + response = self.client.get_object(bucket_name=self.bucket_name, key=filename) + data: bytes = response.data.read() + return data + + def load_stream(self, filename: str) -> Generator: + response = self.client.get_object(bucket_name=self.bucket_name, key=filename).data + while chunk := response.read(4096): + yield chunk + + def download(self, filename, target_filepath): + self.client.get_object_to_file(bucket_name=self.bucket_name, key=filename, file_name=target_filepath) + + def exists(self, filename): + res = self.client.get_object_meta_data(bucket_name=self.bucket_name, key=filename) + if res is None: + return False + return True + + def delete(self, filename): + self.client.delete_object(bucket_name=self.bucket_name, key=filename) diff --git a/api/extensions/storage/base_storage.py b/api/extensions/storage/base_storage.py new 
file mode 100644 index 0000000000000000000000000000000000000000..0dedd7ff8cc3259bc63680888668960e8ca6a49f --- /dev/null +++ b/api/extensions/storage/base_storage.py @@ -0,0 +1,32 @@ +"""Abstract interface for file storage implementations.""" + +from abc import ABC, abstractmethod +from collections.abc import Generator + + +class BaseStorage(ABC): + """Interface for file storage.""" + + @abstractmethod + def save(self, filename, data): + raise NotImplementedError + + @abstractmethod + def load_once(self, filename: str) -> bytes: + raise NotImplementedError + + @abstractmethod + def load_stream(self, filename: str) -> Generator: + raise NotImplementedError + + @abstractmethod + def download(self, filename, target_filepath): + raise NotImplementedError + + @abstractmethod + def exists(self, filename): + raise NotImplementedError + + @abstractmethod + def delete(self, filename): + raise NotImplementedError diff --git a/api/extensions/storage/google_cloud_storage.py b/api/extensions/storage/google_cloud_storage.py new file mode 100644 index 0000000000000000000000000000000000000000..705639f42e716f5906bbcda8d7986b5907a9b8f8 --- /dev/null +++ b/api/extensions/storage/google_cloud_storage.py @@ -0,0 +1,60 @@ +import base64 +import io +import json +from collections.abc import Generator + +from google.cloud import storage as google_cloud_storage # type: ignore + +from configs import dify_config +from extensions.storage.base_storage import BaseStorage + + +class GoogleCloudStorage(BaseStorage): + """Implementation for Google Cloud storage.""" + + def __init__(self): + super().__init__() + + self.bucket_name = dify_config.GOOGLE_STORAGE_BUCKET_NAME + service_account_json_str = dify_config.GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64 + # if service_account_json_str is empty, use Application Default Credentials + if service_account_json_str: + service_account_json = base64.b64decode(service_account_json_str).decode("utf-8") + # convert str to object + service_account_obj = 
json.loads(service_account_json) + self.client = google_cloud_storage.Client.from_service_account_info(service_account_obj) + else: + self.client = google_cloud_storage.Client() + + def save(self, filename, data): + bucket = self.client.get_bucket(self.bucket_name) + blob = bucket.blob(filename) + with io.BytesIO(data) as stream: + blob.upload_from_file(stream) + + def load_once(self, filename: str) -> bytes: + bucket = self.client.get_bucket(self.bucket_name) + blob = bucket.get_blob(filename) + data: bytes = blob.download_as_bytes() + return data + + def load_stream(self, filename: str) -> Generator: + bucket = self.client.get_bucket(self.bucket_name) + blob = bucket.get_blob(filename) + with blob.open(mode="rb") as blob_stream: + while chunk := blob_stream.read(4096): + yield chunk + + def download(self, filename, target_filepath): + bucket = self.client.get_bucket(self.bucket_name) + blob = bucket.get_blob(filename) + blob.download_to_filename(target_filepath) + + def exists(self, filename): + bucket = self.client.get_bucket(self.bucket_name) + blob = bucket.blob(filename) + return blob.exists() + + def delete(self, filename): + bucket = self.client.get_bucket(self.bucket_name) + bucket.delete_blob(filename) diff --git a/api/extensions/storage/huawei_obs_storage.py b/api/extensions/storage/huawei_obs_storage.py new file mode 100644 index 0000000000000000000000000000000000000000..07f1d199701be4f59c4f3992314af1d79aec6f1e --- /dev/null +++ b/api/extensions/storage/huawei_obs_storage.py @@ -0,0 +1,51 @@ +from collections.abc import Generator + +from obs import ObsClient # type: ignore + +from configs import dify_config +from extensions.storage.base_storage import BaseStorage + + +class HuaweiObsStorage(BaseStorage): + """Implementation for Huawei OBS storage.""" + + def __init__(self): + super().__init__() + + self.bucket_name = dify_config.HUAWEI_OBS_BUCKET_NAME + self.client = ObsClient( + access_key_id=dify_config.HUAWEI_OBS_ACCESS_KEY, + 
secret_access_key=dify_config.HUAWEI_OBS_SECRET_KEY, + server=dify_config.HUAWEI_OBS_SERVER, + ) + + def save(self, filename, data): + self.client.putObject(bucketName=self.bucket_name, objectKey=filename, content=data) + + def load_once(self, filename: str) -> bytes: + data: bytes = self.client.getObject(bucketName=self.bucket_name, objectKey=filename)["body"].response.read() + return data + + def load_stream(self, filename: str) -> Generator: + response = self.client.getObject(bucketName=self.bucket_name, objectKey=filename)["body"].response + while chunk := response.read(4096): + yield chunk + + def download(self, filename, target_filepath): + self.client.getObject(bucketName=self.bucket_name, objectKey=filename, downloadPath=target_filepath) + + def exists(self, filename): + res = self._get_meta(filename) + if res is None: + return False + return True + + def delete(self, filename): + self.client.deleteObject(bucketName=self.bucket_name, objectKey=filename) + + def _get_meta(self, filename): + res = self.client.getObjectMetadata(bucketName=self.bucket_name, objectKey=filename) + if res.status < 300: + return res + else: + return None diff --git a/api/extensions/storage/opendal_storage.py b/api/extensions/storage/opendal_storage.py new file mode 100644 index 0000000000000000000000000000000000000000..b78fc94dae7843058c946aa052427cc6182b51c0 --- /dev/null +++ b/api/extensions/storage/opendal_storage.py @@ -0,0 +1,89 @@ +import logging +import os +from collections.abc import Generator +from pathlib import Path + +import opendal # type: ignore[import] +from dotenv import dotenv_values + +from extensions.storage.base_storage import BaseStorage + +logger = logging.getLogger(__name__) + + +def _get_opendal_kwargs(*, scheme: str, env_file_path: str = ".env", prefix: str = "OPENDAL_"): + kwargs = {} + config_prefix = prefix + scheme.upper() + "_" + for key, value in os.environ.items(): + if key.startswith(config_prefix): + kwargs[key[len(config_prefix) :].lower()] = 
def load_stream(self, filename: str) -> Generator:
    """Yield the contents of ``filename`` in chunks of at most 4096 bytes.

    Args:
        filename: Path of the object inside the operator's root.

    Yields:
        Successive non-empty ``bytes`` chunks until the file is exhausted.

    Raises:
        FileNotFoundError: If ``self.exists(filename)`` is false. Because this
            is a generator, the check runs when iteration starts.
    """
    if not self.exists(filename):
        raise FileNotFoundError("File not found")

    batch_size = 4096
    file = self.op.open(path=filename, mode="rb")
    try:
        while chunk := file.read(batch_size):
            yield chunk
    finally:
        # Release the handle even when the consumer stops iterating early or
        # a read fails; previously it was left for the garbage collector.
        file.close()
    logger.debug(f"file {filename} loaded as stream")
def load_once(self, filename: str) -> bytes:
    """Fetch ``filename`` from the bucket and return its full body.

    A ``ClientError`` whose error code is ``NoSuchKey`` is translated into
    ``FileNotFoundError``; every other ``ClientError`` is re-raised as is.
    """
    try:
        response = self.client.get_object(Bucket=self.bucket_name, Key=filename)
        payload: bytes = response["Body"].read()
    except ClientError as err:
        # Only the "object is missing" case gets translated.
        if err.response["Error"]["Code"] != "NoSuchKey":
            raise
        raise FileNotFoundError("File not found")
    return payload
except ClientError as ex: + if ex.response["Error"]["Code"] == "NoSuchKey": + raise FileNotFoundError("File not found") + else: + raise + + def download(self, filename, target_filepath): + self.client.download_file(self.bucket_name, filename, target_filepath) + + def exists(self, filename): + try: + self.client.head_object(Bucket=self.bucket_name, Key=filename) + return True + except: + return False + + def delete(self, filename): + self.client.delete_object(Bucket=self.bucket_name, Key=filename) diff --git a/api/extensions/storage/storage_type.py b/api/extensions/storage/storage_type.py new file mode 100644 index 0000000000000000000000000000000000000000..0a891e36cf17a9a85721074bb092636985f8f7ba --- /dev/null +++ b/api/extensions/storage/storage_type.py @@ -0,0 +1,16 @@ +from enum import StrEnum + + +class StorageType(StrEnum): + ALIYUN_OSS = "aliyun-oss" + AZURE_BLOB = "azure-blob" + BAIDU_OBS = "baidu-obs" + GOOGLE_STORAGE = "google-storage" + HUAWEI_OBS = "huawei-obs" + LOCAL = "local" + OCI_STORAGE = "oci-storage" + OPENDAL = "opendal" + S3 = "s3" + TENCENT_COS = "tencent-cos" + VOLCENGINE_TOS = "volcengine-tos" + SUPBASE = "supabase" diff --git a/api/extensions/storage/supabase_storage.py b/api/extensions/storage/supabase_storage.py new file mode 100644 index 0000000000000000000000000000000000000000..711c3f72117c86dfdb944f2fabdb85d7f16b9914 --- /dev/null +++ b/api/extensions/storage/supabase_storage.py @@ -0,0 +1,59 @@ +import io +from collections.abc import Generator +from pathlib import Path + +from supabase import Client + +from configs import dify_config +from extensions.storage.base_storage import BaseStorage + + +class SupabaseStorage(BaseStorage): + """Implementation for supabase obs storage.""" + + def __init__(self): + super().__init__() + if dify_config.SUPABASE_URL is None: + raise ValueError("SUPABASE_URL is not set") + if dify_config.SUPABASE_API_KEY is None: + raise ValueError("SUPABASE_API_KEY is not set") + if dify_config.SUPABASE_BUCKET_NAME 
def exists(self, filename):
    """Return True when listing ``filename`` in the bucket yields any entry.

    Args:
        filename: Path passed to the bucket's ``list`` call.

    Returns:
        ``True`` if the listing is non-empty, otherwise ``False``.
    """
    result = self.client.storage.from_(self.bucket_name).list(filename)
    # The listing is a plain list; calling ``.count()`` on it without an
    # argument raises TypeError, so test its length instead.
    # NOTE(review): ``list()`` receives the file name as its path argument;
    # confirm this matches a single object rather than a folder prefix.
    return len(result) > 0
class TencentCosStorage(BaseStorage):
    """Implementation for Tencent Cloud COS storage."""

    def __init__(self):
        """Build the COS client from the ``TENCENT_COS_*`` settings."""
        super().__init__()

        cos_config = CosConfig(
            Region=dify_config.TENCENT_COS_REGION,
            SecretId=dify_config.TENCENT_COS_SECRET_ID,
            SecretKey=dify_config.TENCENT_COS_SECRET_KEY,
            Scheme=dify_config.TENCENT_COS_SCHEME,
        )
        self.client = CosS3Client(cos_config)
        self.bucket_name = dify_config.TENCENT_COS_BUCKET_NAME

    def save(self, filename, data):
        """Upload ``data`` to the bucket under the key ``filename``."""
        self.client.put_object(Bucket=self.bucket_name, Body=data, Key=filename)

    def load_once(self, filename: str) -> bytes:
        """Return the whole object stored under ``filename``."""
        body = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"]
        content: bytes = body.get_raw_stream().read()
        return content

    def load_stream(self, filename: str) -> Generator:
        """Yield the object stored under ``filename`` in 4096-byte chunks."""
        body = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"]
        yield from body.get_stream(chunk_size=4096)

    def download(self, filename, target_filepath):
        """Write the object stored under ``filename`` to ``target_filepath``."""
        body = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"]
        body.get_stream_to_file(target_filepath)

    def exists(self, filename):
        """Return the client's ``object_exists`` answer for ``filename``."""
        return self.client.object_exists(Bucket=self.bucket_name, Key=filename)

    def delete(self, filename):
        """Delete the object stored under ``filename``."""
        self.client.delete_object(Bucket=self.bucket_name, Key=filename)
def exists(self, filename):
    """Return True when a HEAD request for ``filename`` answers with 200.

    Args:
        filename: Object key inside ``self.bucket_name``.

    Returns:
        ``True`` for a 200 response; ``False`` for any other status code or
        when the server reports 404.

    Raises:
        tos.exceptions.TosServerError: For server errors other than 404.
    """
    try:
        res = self.client.head_object(bucket=self.bucket_name, key=filename)
    except tos.exceptions.TosServerError as ex:
        # A missing key is reported by raising, not via the returned status
        # code, so translate 404 here and let real failures propagate.
        if ex.status_code == 404:
            return False
        raise
    return res.status_code == 200
def build_from_mappings(
    *,
    mappings: Sequence[Mapping[str, Any]],
    config: FileUploadConfig | None = None,
    tenant_id: str,
) -> Sequence[File]:
    """Build one ``File`` per mapping and enforce the configured count limits.

    Every mapping is first converted with ``build_from_mapping``. When a
    config is given, the number of image files is checked against
    ``config.image_config.number_limits`` and the total number of files
    against ``config.number_limits``; exceeding either raises ``ValueError``.
    """
    files = [build_from_mapping(mapping=item, tenant_id=tenant_id, config=config) for item in mappings]

    # Without a config there is nothing to enforce.
    if not config:
        return files

    if config.image_config:
        image_count = sum(1 for file in files if file.type == FileType.IMAGE)
        if image_count > config.image_config.number_limits:
            raise ValueError(f"Number of image files exceeds the maximum limit {config.image_config.number_limits}")

    if config.number_limits and len(files) > config.number_limits:
        raise ValueError(f"Number of files exceeds the maximum limit {config.number_limits}")

    return files
def _build_from_remote_url(
    *,
    mapping: Mapping[str, Any],
    tenant_id: str,
    transfer_method: FileTransferMethod,
) -> File:
    """Build a ``File`` describing a remote URL.

    Args:
        mapping: Source mapping; the URL is read from ``"url"`` or, failing
            that, ``"remote_url"``. ``"id"`` and ``"type"`` are also read.
        tenant_id: Tenant the file belongs to.
        transfer_method: Transfer method recorded on the resulting file.

    Returns:
        A ``File`` whose MIME type, filename and size come from
        ``_get_remote_file_info``.

    Raises:
        ValueError: If the mapping carries no URL.
    """
    url = mapping.get("url") or mapping.get("remote_url")
    if not url:
        raise ValueError("Invalid file url")

    mime_type, filename, file_size = _get_remote_file_info(url)
    # Prefer the extension implied by the MIME type; only when that is unknown
    # fall back to the filename suffix, and finally to ".bin". The parentheses
    # matter: without them the conditional expression wraps the whole `or`, so
    # a dot-less filename forced ".bin" even for a known MIME type.
    extension = mimetypes.guess_extension(mime_type) or (
        "." + filename.split(".")[-1] if "." in filename else ".bin"
    )

    file_type = FileType(mapping.get("type", "custom"))
    file_type = _standardize_file_type(file_type, extension=extension, mime_type=mime_type)

    return File(
        id=mapping.get("id"),
        filename=filename,
        tenant_id=tenant_id,
        type=file_type,
        transfer_method=transfer_method,
        remote_url=url,
        mime_type=mime_type,
        extension=extension,
        size=file_size,
        storage_key="",
    )
def _is_file_valid_with_config(
    *,
    input_file_type: str,
    file_extension: str,
    file_transfer_method: FileTransferMethod,
    config: FileUploadConfig,
) -> bool:
    """Check a file's declared type, extension and transfer method against ``config``.

    Returns ``False`` as soon as one rule is violated, ``True`` otherwise.
    """
    is_custom = input_file_type == FileType.CUSTOM

    # Non-custom types must be listed when a type whitelist is configured.
    if not is_custom and config.allowed_file_types and input_file_type not in config.allowed_file_types:
        return False

    # Custom files are filtered by extension instead.
    if is_custom and config.allowed_file_extensions is not None:
        if file_extension not in config.allowed_file_extensions:
            return False

    # Images may additionally be restricted to certain transfer methods.
    if input_file_type == FileType.IMAGE and config.image_config:
        permitted_methods = config.image_config.transfer_methods
        if permitted_methods and file_transfer_method not in permitted_methods:
            return False

    return True
def _get_file_type_by_mimetype(mime_type: str) -> FileType | None:
    """Guess a ``FileType`` from substrings of ``mime_type``.

    Keyword groups are tried in order and the first match wins. When nothing
    matches, ``FileType.CUSTOM`` is returned, so despite the annotation this
    function never returns ``None``.
    """
    keyword_groups = (
        (("image",), FileType.IMAGE),
        (("video",), FileType.VIDEO),
        (("audio",), FileType.AUDIO),
        (("text", "pdf"), FileType.DOCUMENT),
    )
    for keywords, guessed in keyword_groups:
        if any(keyword in mime_type for keyword in keywords):
            return guessed
    return FileType.CUSTOM
def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequence[str]) -> Variable:
    """
    Create an environment or conversation variable from ``mapping``.

    The File type is not supported. ``selector`` is applied only when the
    validated variable does not already carry one. Raises ``VariableError``
    for a missing or unsupported value type, a missing value, a non-numeric
    NUMBER value, or a variable larger than ``dify_config.MAX_VARIABLE_SIZE``.
    """
    value_type = mapping.get("value_type")
    if value_type is None:
        raise VariableError("missing value type")
    value = mapping.get("value")
    if value is None:
        raise VariableError("missing value")

    # FIXME: using Any here, fix it later
    result: Any
    if value_type == SegmentType.STRING:
        result = StringVariable.model_validate(mapping)
    elif value_type == SegmentType.SECRET:
        result = SecretVariable.model_validate(mapping)
    elif value_type == SegmentType.NUMBER:
        # int is tested first, matching the original case order.
        if isinstance(value, int):
            result = IntegerVariable.model_validate(mapping)
        elif isinstance(value, float):
            result = FloatVariable.model_validate(mapping)
        else:
            raise VariableError(f"invalid number value {value}")
    else:
        # Container types additionally require a matching Python value type.
        container_variables = (
            (SegmentType.OBJECT, dict, ObjectVariable),
            (SegmentType.ARRAY_STRING, list, ArrayStringVariable),
            (SegmentType.ARRAY_NUMBER, list, ArrayNumberVariable),
            (SegmentType.ARRAY_OBJECT, list, ArrayObjectVariable),
        )
        for segment_type, python_type, variable_cls in container_variables:
            if value_type == segment_type and isinstance(value, python_type):
                result = variable_cls.model_validate(mapping)
                break
        else:
            raise VariableError(f"not supported value type {value_type}")

    if result.size > dify_config.MAX_VARIABLE_SIZE:
        raise VariableError(f"variable size {result.size} exceeds limit {dify_config.MAX_VARIABLE_SIZE}")
    if not result.selector:
        result = result.model_copy(update={"selector": selector})
    return cast(Variable, result)
def segment_to_variable(
    *,
    segment: Segment,
    selector: Sequence[str],
    id: str | None = None,
    name: str | None = None,
    description: str = "",
) -> Variable:
    """Wrap ``segment`` in the variable class registered for its exact type.

    A segment that already is a ``Variable`` is returned untouched. A missing
    ``name`` defaults to the last selector element and a missing ``id`` to a
    fresh UUID4 string. Raises ``UnsupportedSegmentTypeError`` when the
    segment's type has no entry in ``SEGMENT_TO_VARIABLE_MAP``.
    """
    if isinstance(segment, Variable):
        return segment

    resolved_name = name if name else selector[-1]
    resolved_id = id if id else str(uuid4())

    segment_type = type(segment)
    variable_class = SEGMENT_TO_VARIABLE_MAP.get(segment_type)
    if variable_class is None:
        raise UnsupportedSegmentTypeError(f"not supported segment type {segment_type}")

    variable = variable_class(
        id=resolved_id,
        name=resolved_name,
        description=description,
        value=segment.value,
        selector=selector,
    )
    return cast(Variable, variable)
class HiddenAPIKey(fields.Raw):
    """Response field that masks ``obj.api_key`` before it is serialized."""

    def output(self, key, obj):
        """Return a masked form of ``obj.api_key``.

        Args:
            key: Field name supplied by the marshaller (unused).
            obj: Object exposing an ``api_key`` attribute.

        Returns:
            The key with its middle replaced by ``"******"``; a missing or
            empty key is returned unchanged because there is nothing to mask.
        """
        api_key = obj.api_key
        # Indexing an empty string (or None) below would raise.
        if not api_key:
            return api_key
        # Keys of at most 8 characters show only the first and last character.
        if len(api_key) <= 8:
            return api_key[0] + "******" + api_key[-1]
        # Longer keys show the first three and the last three characters.
        return api_key[:3] + "******" + api_key[-3:]
# Marshalling map holding an app's identity and icon fields.
app_detail_kernel_fields = {
    "id": fields.String,
    "name": fields.String,
    "description": fields.String,
    # Exposed as "mode" but read from the `mode_compatible_with_agent` attribute.
    "mode": fields.String(attribute="mode_compatible_with_agent"),
    "icon_type": fields.String,
    "icon": fields.String,
    "icon_background": fields.String,
    # Rendered through the AppIconUrlField helper rather than a plain string field.
    "icon_url": AppIconUrlField,
}
# Marshalling map for a single app template entry.
template_fields = {
    "name": fields.String,
    "icon": fields.String,
    "icon_background": fields.String,
    "description": fields.String,
    # Read from the plain `mode` attribute; no `attribute=` override is given here.
    "mode": fields.String,
    # Nested model configuration, serialized with `model_config_fields`.
    "model_config": fields.Nested(model_config_fields),
}
{ + "app_id": fields.String, + "access_token": fields.String(attribute="code"), + "code": fields.String, + "title": fields.String, + "icon": fields.String, + "icon_background": fields.String, + "description": fields.String, + "default_language": fields.String, + "customize_domain": fields.String, + "copyright": fields.String, + "privacy_policy": fields.String, + "custom_disclaimer": fields.String, + "customize_token_strategy": fields.String, + "prompt_public": fields.Boolean, + "show_workflow_steps": fields.Boolean, + "use_icon_as_answer_icon": fields.Boolean, +} + +app_import_fields = { + "id": fields.String, + "status": fields.String, + "app_id": fields.String, + "current_dsl_version": fields.String, + "imported_dsl_version": fields.String, + "error": fields.String, +} diff --git a/api/fields/conversation_fields.py b/api/fields/conversation_fields.py new file mode 100644 index 0000000000000000000000000000000000000000..c54554a6de84059b2c8beeb19062bb7bf22f6b94 --- /dev/null +++ b/api/fields/conversation_fields.py @@ -0,0 +1,225 @@ +from flask_restful import fields # type: ignore + +from fields.member_fields import simple_account_fields +from libs.helper import TimestampField + +from .raws import FilesContainedField + + +class MessageTextField(fields.Raw): + def format(self, value): + return value[0]["text"] if value else "" + + +feedback_fields = { + "rating": fields.String, + "content": fields.String, + "from_source": fields.String, + "from_end_user_id": fields.String, + "from_account": fields.Nested(simple_account_fields, allow_null=True), +} + +annotation_fields = { + "id": fields.String, + "question": fields.String, + "content": fields.String, + "account": fields.Nested(simple_account_fields, allow_null=True), + "created_at": TimestampField, +} + +annotation_hit_history_fields = { + "annotation_id": fields.String(attribute="id"), + "annotation_create_account": fields.Nested(simple_account_fields, allow_null=True), + "created_at": TimestampField, +} + 
+message_file_fields = { + "id": fields.String, + "filename": fields.String, + "type": fields.String, + "url": fields.String, + "mime_type": fields.String, + "size": fields.Integer, + "transfer_method": fields.String, + "belongs_to": fields.String(default="user"), +} + +agent_thought_fields = { + "id": fields.String, + "chain_id": fields.String, + "message_id": fields.String, + "position": fields.Integer, + "thought": fields.String, + "tool": fields.String, + "tool_labels": fields.Raw, + "tool_input": fields.String, + "created_at": TimestampField, + "observation": fields.String, + "files": fields.List(fields.String), +} + +message_detail_fields = { + "id": fields.String, + "conversation_id": fields.String, + "inputs": FilesContainedField, + "query": fields.String, + "message": fields.Raw, + "message_tokens": fields.Integer, + "answer": fields.String(attribute="re_sign_file_url_answer"), + "answer_tokens": fields.Integer, + "provider_response_latency": fields.Float, + "from_source": fields.String, + "from_end_user_id": fields.String, + "from_account_id": fields.String, + "feedbacks": fields.List(fields.Nested(feedback_fields)), + "workflow_run_id": fields.String, + "annotation": fields.Nested(annotation_fields, allow_null=True), + "annotation_hit_history": fields.Nested(annotation_hit_history_fields, allow_null=True), + "created_at": TimestampField, + "agent_thoughts": fields.List(fields.Nested(agent_thought_fields)), + "message_files": fields.List(fields.Nested(message_file_fields)), + "metadata": fields.Raw(attribute="message_metadata_dict"), + "status": fields.String, + "error": fields.String, + "parent_message_id": fields.String, +} + +feedback_stat_fields = {"like": fields.Integer, "dislike": fields.Integer} +status_count_fields = {"success": fields.Integer, "failed": fields.Integer, "partial_success": fields.Integer} +model_config_fields = { + "opening_statement": fields.String, + "suggested_questions": fields.Raw, + "model": fields.Raw, + "user_input_form": 
fields.Raw, + "pre_prompt": fields.String, + "agent_mode": fields.Raw, +} + +simple_configs_fields = { + "prompt_template": fields.String, +} + +simple_model_config_fields = { + "model": fields.Raw(attribute="model_dict"), + "pre_prompt": fields.String, +} + +simple_message_detail_fields = { + "inputs": FilesContainedField, + "query": fields.String, + "message": MessageTextField, + "answer": fields.String, +} + +conversation_fields = { + "id": fields.String, + "status": fields.String, + "from_source": fields.String, + "from_end_user_id": fields.String, + "from_end_user_session_id": fields.String(), + "from_account_id": fields.String, + "from_account_name": fields.String, + "read_at": TimestampField, + "created_at": TimestampField, + "updated_at": TimestampField, + "annotation": fields.Nested(annotation_fields, allow_null=True), + "model_config": fields.Nested(simple_model_config_fields), + "user_feedback_stats": fields.Nested(feedback_stat_fields), + "admin_feedback_stats": fields.Nested(feedback_stat_fields), + "message": fields.Nested(simple_message_detail_fields, attribute="first_message"), +} + +conversation_pagination_fields = { + "page": fields.Integer, + "limit": fields.Integer(attribute="per_page"), + "total": fields.Integer, + "has_more": fields.Boolean(attribute="has_next"), + "data": fields.List(fields.Nested(conversation_fields), attribute="items"), +} + +conversation_message_detail_fields = { + "id": fields.String, + "status": fields.String, + "from_source": fields.String, + "from_end_user_id": fields.String, + "from_account_id": fields.String, + "created_at": TimestampField, + "model_config": fields.Nested(model_config_fields), + "message": fields.Nested(message_detail_fields, attribute="first_message"), +} + +conversation_with_summary_fields = { + "id": fields.String, + "status": fields.String, + "from_source": fields.String, + "from_end_user_id": fields.String, + "from_end_user_session_id": fields.String, + "from_account_id": fields.String, + 
"from_account_name": fields.String, + "name": fields.String, + "summary": fields.String(attribute="summary_or_query"), + "read_at": TimestampField, + "created_at": TimestampField, + "updated_at": TimestampField, + "annotated": fields.Boolean, + "model_config": fields.Nested(simple_model_config_fields), + "message_count": fields.Integer, + "user_feedback_stats": fields.Nested(feedback_stat_fields), + "admin_feedback_stats": fields.Nested(feedback_stat_fields), + "status_count": fields.Nested(status_count_fields), +} + +conversation_with_summary_pagination_fields = { + "page": fields.Integer, + "limit": fields.Integer(attribute="per_page"), + "total": fields.Integer, + "has_more": fields.Boolean(attribute="has_next"), + "data": fields.List(fields.Nested(conversation_with_summary_fields), attribute="items"), +} + +conversation_detail_fields = { + "id": fields.String, + "status": fields.String, + "from_source": fields.String, + "from_end_user_id": fields.String, + "from_account_id": fields.String, + "created_at": TimestampField, + "updated_at": TimestampField, + "annotated": fields.Boolean, + "introduction": fields.String, + "model_config": fields.Nested(model_config_fields), + "message_count": fields.Integer, + "user_feedback_stats": fields.Nested(feedback_stat_fields), + "admin_feedback_stats": fields.Nested(feedback_stat_fields), +} + +simple_conversation_fields = { + "id": fields.String, + "name": fields.String, + "inputs": FilesContainedField, + "status": fields.String, + "introduction": fields.String, + "created_at": TimestampField, + "updated_at": TimestampField, +} + +conversation_delete_fields = { + "result": fields.String, +} + +conversation_infinite_scroll_pagination_fields = { + "limit": fields.Integer, + "has_more": fields.Boolean, + "data": fields.List(fields.Nested(simple_conversation_fields)), +} + +conversation_with_model_config_fields = { + **simple_conversation_fields, + "model_config": fields.Raw, +} + 
+conversation_with_model_config_infinite_scroll_pagination_fields = { + "limit": fields.Integer, + "has_more": fields.Boolean, + "data": fields.List(fields.Nested(conversation_with_model_config_fields)), +} diff --git a/api/fields/conversation_variable_fields.py b/api/fields/conversation_variable_fields.py new file mode 100644 index 0000000000000000000000000000000000000000..c6385efb5a3cf1d8bf3cab95ba501ea26a731992 --- /dev/null +++ b/api/fields/conversation_variable_fields.py @@ -0,0 +1,21 @@ +from flask_restful import fields # type: ignore + +from libs.helper import TimestampField + +conversation_variable_fields = { + "id": fields.String, + "name": fields.String, + "value_type": fields.String(attribute="value_type.value"), + "value": fields.String, + "description": fields.String, + "created_at": TimestampField, + "updated_at": TimestampField, +} + +paginated_conversation_variable_fields = { + "page": fields.Integer, + "limit": fields.Integer, + "total": fields.Integer, + "has_more": fields.Boolean, + "data": fields.List(fields.Nested(conversation_variable_fields), attribute="data"), +} diff --git a/api/fields/data_source_fields.py b/api/fields/data_source_fields.py new file mode 100644 index 0000000000000000000000000000000000000000..608672121e2b50d298f37cb7e7a5639bc3dcdee1 --- /dev/null +++ b/api/fields/data_source_fields.py @@ -0,0 +1,57 @@ +from flask_restful import fields # type: ignore + +from libs.helper import TimestampField + +integrate_icon_fields = {"type": fields.String, "url": fields.String, "emoji": fields.String} + +integrate_page_fields = { + "page_name": fields.String, + "page_id": fields.String, + "page_icon": fields.Nested(integrate_icon_fields, allow_null=True), + "is_bound": fields.Boolean, + "parent_id": fields.String, + "type": fields.String, +} + +integrate_workspace_fields = { + "workspace_name": fields.String, + "workspace_id": fields.String, + "workspace_icon": fields.String, + "pages": fields.List(fields.Nested(integrate_page_fields)), +} 
+ +integrate_notion_info_list_fields = { + "notion_info": fields.List(fields.Nested(integrate_workspace_fields)), +} + +integrate_icon_fields = {"type": fields.String, "url": fields.String, "emoji": fields.String} + +integrate_page_fields = { + "page_name": fields.String, + "page_id": fields.String, + "page_icon": fields.Nested(integrate_icon_fields, allow_null=True), + "parent_id": fields.String, + "type": fields.String, +} + +integrate_workspace_fields = { + "workspace_name": fields.String, + "workspace_id": fields.String, + "workspace_icon": fields.String, + "pages": fields.List(fields.Nested(integrate_page_fields)), + "total": fields.Integer, +} + +integrate_fields = { + "id": fields.String, + "provider": fields.String, + "created_at": TimestampField, + "is_bound": fields.Boolean, + "disabled": fields.Boolean, + "link": fields.String, + "source_info": fields.Nested(integrate_workspace_fields), +} + +integrate_list_fields = { + "data": fields.List(fields.Nested(integrate_fields)), +} diff --git a/api/fields/dataset_fields.py b/api/fields/dataset_fields.py new file mode 100644 index 0000000000000000000000000000000000000000..bedab5750f1d664a0fdf1926bbaaf9a94ddd187f --- /dev/null +++ b/api/fields/dataset_fields.py @@ -0,0 +1,89 @@ +from flask_restful import fields # type: ignore + +from libs.helper import TimestampField + +dataset_fields = { + "id": fields.String, + "name": fields.String, + "description": fields.String, + "permission": fields.String, + "data_source_type": fields.String, + "indexing_technique": fields.String, + "created_by": fields.String, + "created_at": TimestampField, +} + +reranking_model_fields = {"reranking_provider_name": fields.String, "reranking_model_name": fields.String} + +keyword_setting_fields = {"keyword_weight": fields.Float} + +vector_setting_fields = { + "vector_weight": fields.Float, + "embedding_model_name": fields.String, + "embedding_provider_name": fields.String, +} + +weighted_score_fields = { + "keyword_setting": 
fields.Nested(keyword_setting_fields), + "vector_setting": fields.Nested(vector_setting_fields), +} + +dataset_retrieval_model_fields = { + "search_method": fields.String, + "reranking_enable": fields.Boolean, + "reranking_mode": fields.String, + "reranking_model": fields.Nested(reranking_model_fields), + "weights": fields.Nested(weighted_score_fields, allow_null=True), + "top_k": fields.Integer, + "score_threshold_enabled": fields.Boolean, + "score_threshold": fields.Float, +} +external_retrieval_model_fields = { + "top_k": fields.Integer, + "score_threshold": fields.Float, + "score_threshold_enabled": fields.Boolean, +} + +tag_fields = {"id": fields.String, "name": fields.String, "type": fields.String} + +external_knowledge_info_fields = { + "external_knowledge_id": fields.String, + "external_knowledge_api_id": fields.String, + "external_knowledge_api_name": fields.String, + "external_knowledge_api_endpoint": fields.String, +} + +dataset_detail_fields = { + "id": fields.String, + "name": fields.String, + "description": fields.String, + "provider": fields.String, + "permission": fields.String, + "data_source_type": fields.String, + "indexing_technique": fields.String, + "app_count": fields.Integer, + "document_count": fields.Integer, + "word_count": fields.Integer, + "created_by": fields.String, + "created_at": TimestampField, + "updated_by": fields.String, + "updated_at": TimestampField, + "embedding_model": fields.String, + "embedding_model_provider": fields.String, + "embedding_available": fields.Boolean, + "retrieval_model_dict": fields.Nested(dataset_retrieval_model_fields), + "tags": fields.List(fields.Nested(tag_fields)), + "doc_form": fields.String, + "external_knowledge_info": fields.Nested(external_knowledge_info_fields), + "external_retrieval_model": fields.Nested(external_retrieval_model_fields, allow_null=True), +} + +dataset_query_detail_fields = { + "id": fields.String, + "content": fields.String, + "source": fields.String, + "source_app_id": 
fields.String, + "created_by_role": fields.String, + "created_by": fields.String, + "created_at": TimestampField, +} diff --git a/api/fields/document_fields.py b/api/fields/document_fields.py new file mode 100644 index 0000000000000000000000000000000000000000..f2250d964ac12d88dac561dd60efb7e541e869a1 --- /dev/null +++ b/api/fields/document_fields.py @@ -0,0 +1,77 @@ +from flask_restful import fields # type: ignore + +from fields.dataset_fields import dataset_fields +from libs.helper import TimestampField + +document_fields = { + "id": fields.String, + "position": fields.Integer, + "data_source_type": fields.String, + "data_source_info": fields.Raw(attribute="data_source_info_dict"), + "data_source_detail_dict": fields.Raw(attribute="data_source_detail_dict"), + "dataset_process_rule_id": fields.String, + "name": fields.String, + "created_from": fields.String, + "created_by": fields.String, + "created_at": TimestampField, + "tokens": fields.Integer, + "indexing_status": fields.String, + "error": fields.String, + "enabled": fields.Boolean, + "disabled_at": TimestampField, + "disabled_by": fields.String, + "archived": fields.Boolean, + "display_status": fields.String, + "word_count": fields.Integer, + "hit_count": fields.Integer, + "doc_form": fields.String, +} + +document_with_segments_fields = { + "id": fields.String, + "position": fields.Integer, + "data_source_type": fields.String, + "data_source_info": fields.Raw(attribute="data_source_info_dict"), + "data_source_detail_dict": fields.Raw(attribute="data_source_detail_dict"), + "dataset_process_rule_id": fields.String, + "process_rule_dict": fields.Raw(attribute="process_rule_dict"), + "name": fields.String, + "created_from": fields.String, + "created_by": fields.String, + "created_at": TimestampField, + "tokens": fields.Integer, + "indexing_status": fields.String, + "error": fields.String, + "enabled": fields.Boolean, + "disabled_at": TimestampField, + "disabled_by": fields.String, + "archived": fields.Boolean, + 
"display_status": fields.String, + "word_count": fields.Integer, + "hit_count": fields.Integer, + "completed_segments": fields.Integer, + "total_segments": fields.Integer, +} + +dataset_and_document_fields = { + "dataset": fields.Nested(dataset_fields), + "documents": fields.List(fields.Nested(document_fields)), + "batch": fields.String, +} + +document_status_fields = { + "id": fields.String, + "indexing_status": fields.String, + "processing_started_at": TimestampField, + "parsing_completed_at": TimestampField, + "cleaning_completed_at": TimestampField, + "splitting_completed_at": TimestampField, + "completed_at": TimestampField, + "paused_at": TimestampField, + "error": fields.String, + "stopped_at": TimestampField, + "completed_segments": fields.Integer, + "total_segments": fields.Integer, +} + +document_status_fields_list = {"data": fields.List(fields.Nested(document_status_fields))} diff --git a/api/fields/end_user_fields.py b/api/fields/end_user_fields.py new file mode 100644 index 0000000000000000000000000000000000000000..aefa0b27580ca73726bde5534e9d3e160e47b00f --- /dev/null +++ b/api/fields/end_user_fields.py @@ -0,0 +1,8 @@ +from flask_restful import fields # type: ignore + +simple_end_user_fields = { + "id": fields.String, + "type": fields.String, + "is_anonymous": fields.Boolean, + "session_id": fields.String, +} diff --git a/api/fields/external_dataset_fields.py b/api/fields/external_dataset_fields.py new file mode 100644 index 0000000000000000000000000000000000000000..9cc4e14a0575d713d139c5da9f24398017daf787 --- /dev/null +++ b/api/fields/external_dataset_fields.py @@ -0,0 +1,11 @@ +from flask_restful import fields # type: ignore + +from libs.helper import TimestampField + +external_knowledge_api_query_detail_fields = { + "id": fields.String, + "name": fields.String, + "setting": fields.String, + "created_by": fields.String, + "created_at": TimestampField, +} diff --git a/api/fields/file_fields.py b/api/fields/file_fields.py new file mode 100644 index 
0000000000000000000000000000000000000000..f896c15f0fec70d29fd50242ca15394da1393ec1 --- /dev/null +++ b/api/fields/file_fields.py @@ -0,0 +1,39 @@ +from flask_restful import fields # type: ignore + +from libs.helper import TimestampField + +upload_config_fields = { + "file_size_limit": fields.Integer, + "batch_count_limit": fields.Integer, + "image_file_size_limit": fields.Integer, + "video_file_size_limit": fields.Integer, + "audio_file_size_limit": fields.Integer, + "workflow_file_upload_limit": fields.Integer, +} + +file_fields = { + "id": fields.String, + "name": fields.String, + "size": fields.Integer, + "extension": fields.String, + "mime_type": fields.String, + "created_by": fields.String, + "created_at": TimestampField, +} + +remote_file_info_fields = { + "file_type": fields.String(attribute="file_type"), + "file_length": fields.Integer(attribute="file_length"), +} + + +file_fields_with_signed_url = { + "id": fields.String, + "name": fields.String, + "size": fields.Integer, + "extension": fields.String, + "url": fields.String, + "mime_type": fields.String, + "created_by": fields.String, + "created_at": TimestampField, +} diff --git a/api/fields/hit_testing_fields.py b/api/fields/hit_testing_fields.py new file mode 100644 index 0000000000000000000000000000000000000000..b9f7e78c170529e9367e75577750b380b3584a65 --- /dev/null +++ b/api/fields/hit_testing_fields.py @@ -0,0 +1,49 @@ +from flask_restful import fields # type: ignore + +from libs.helper import TimestampField + +document_fields = { + "id": fields.String, + "data_source_type": fields.String, + "name": fields.String, + "doc_type": fields.String, +} + +segment_fields = { + "id": fields.String, + "position": fields.Integer, + "document_id": fields.String, + "content": fields.String, + "answer": fields.String, + "word_count": fields.Integer, + "tokens": fields.Integer, + "keywords": fields.List(fields.String), + "index_node_id": fields.String, + "index_node_hash": fields.String, + "hit_count": 
fields.Integer, + "enabled": fields.Boolean, + "disabled_at": TimestampField, + "disabled_by": fields.String, + "status": fields.String, + "created_by": fields.String, + "created_at": TimestampField, + "indexing_at": TimestampField, + "completed_at": TimestampField, + "error": fields.String, + "stopped_at": TimestampField, + "document": fields.Nested(document_fields), +} + +child_chunk_fields = { + "id": fields.String, + "content": fields.String, + "position": fields.Integer, + "score": fields.Float, +} + +hit_testing_record_fields = { + "segment": fields.Nested(segment_fields), + "child_chunks": fields.List(fields.Nested(child_chunk_fields)), + "score": fields.Float, + "tsne_position": fields.Raw, +} diff --git a/api/fields/installed_app_fields.py b/api/fields/installed_app_fields.py new file mode 100644 index 0000000000000000000000000000000000000000..16f265b9bb6d077a1407e8d9616852cc28cb3772 --- /dev/null +++ b/api/fields/installed_app_fields.py @@ -0,0 +1,26 @@ +from flask_restful import fields # type: ignore + +from libs.helper import AppIconUrlField, TimestampField + +app_fields = { + "id": fields.String, + "name": fields.String, + "mode": fields.String, + "icon_type": fields.String, + "icon": fields.String, + "icon_background": fields.String, + "icon_url": AppIconUrlField, + "use_icon_as_answer_icon": fields.Boolean, +} + +installed_app_fields = { + "id": fields.String, + "app": fields.Nested(app_fields), + "app_owner_tenant_id": fields.String, + "is_pinned": fields.Boolean, + "last_used_at": TimestampField, + "editable": fields.Boolean, + "uninstallable": fields.Boolean, +} + +installed_app_list_fields = {"installed_apps": fields.List(fields.Nested(installed_app_fields))} diff --git a/api/fields/member_fields.py b/api/fields/member_fields.py new file mode 100644 index 0000000000000000000000000000000000000000..0900bffb8a5d079640f2af70b12f2a703af0c7f6 --- /dev/null +++ b/api/fields/member_fields.py @@ -0,0 +1,35 @@ +from flask_restful import fields # type: 
ignore + +from libs.helper import AvatarUrlField, TimestampField + +simple_account_fields = {"id": fields.String, "name": fields.String, "email": fields.String} + +account_fields = { + "id": fields.String, + "name": fields.String, + "avatar": fields.String, + "avatar_url": AvatarUrlField, + "email": fields.String, + "is_password_set": fields.Boolean, + "interface_language": fields.String, + "interface_theme": fields.String, + "timezone": fields.String, + "last_login_at": TimestampField, + "last_login_ip": fields.String, + "created_at": TimestampField, +} + +account_with_role_fields = { + "id": fields.String, + "name": fields.String, + "avatar": fields.String, + "avatar_url": AvatarUrlField, + "email": fields.String, + "last_login_at": TimestampField, + "last_active_at": TimestampField, + "created_at": TimestampField, + "role": fields.String, + "status": fields.String, +} + +account_with_role_list_fields = {"accounts": fields.List(fields.Nested(account_with_role_fields))} diff --git a/api/fields/message_fields.py b/api/fields/message_fields.py new file mode 100644 index 0000000000000000000000000000000000000000..0571faab08c1349f16007bc854185fa0b6403254 --- /dev/null +++ b/api/fields/message_fields.py @@ -0,0 +1,84 @@ +from flask_restful import fields # type: ignore + +from fields.conversation_fields import message_file_fields +from libs.helper import TimestampField + +from .raws import FilesContainedField + +feedback_fields = {"rating": fields.String} + +retriever_resource_fields = { + "id": fields.String, + "message_id": fields.String, + "position": fields.Integer, + "dataset_id": fields.String, + "dataset_name": fields.String, + "document_id": fields.String, + "document_name": fields.String, + "data_source_type": fields.String, + "segment_id": fields.String, + "score": fields.Float, + "hit_count": fields.Integer, + "word_count": fields.Integer, + "segment_position": fields.Integer, + "index_node_hash": fields.String, + "content": fields.String, + "created_at": 
TimestampField, +} + +feedback_fields = {"rating": fields.String} + +agent_thought_fields = { + "id": fields.String, + "chain_id": fields.String, + "message_id": fields.String, + "position": fields.Integer, + "thought": fields.String, + "tool": fields.String, + "tool_labels": fields.Raw, + "tool_input": fields.String, + "created_at": TimestampField, + "observation": fields.String, + "files": fields.List(fields.String), +} + +retriever_resource_fields = { + "id": fields.String, + "message_id": fields.String, + "position": fields.Integer, + "dataset_id": fields.String, + "dataset_name": fields.String, + "document_id": fields.String, + "document_name": fields.String, + "data_source_type": fields.String, + "segment_id": fields.String, + "score": fields.Float, + "hit_count": fields.Integer, + "word_count": fields.Integer, + "segment_position": fields.Integer, + "index_node_hash": fields.String, + "content": fields.String, + "created_at": TimestampField, +} + +message_fields = { + "id": fields.String, + "conversation_id": fields.String, + "parent_message_id": fields.String, + "inputs": FilesContainedField, + "query": fields.String, + "answer": fields.String(attribute="re_sign_file_url_answer"), + "feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True), + "retriever_resources": fields.List(fields.Nested(retriever_resource_fields)), + "created_at": TimestampField, + "agent_thoughts": fields.List(fields.Nested(agent_thought_fields)), + "message_files": fields.List(fields.Nested(message_file_fields)), + "status": fields.String, + "error": fields.String, +} + +message_infinite_scroll_pagination_fields = { + "limit": fields.Integer, + "has_more": fields.Boolean, + "data": fields.List(fields.Nested(message_fields)), +} diff --git a/api/fields/raws.py b/api/fields/raws.py new file mode 100644 index 0000000000000000000000000000000000000000..493d4b6cce7d3141db0393e354dc72bb32d2889c --- /dev/null +++ b/api/fields/raws.py @@ -0,0 +1,17 @@ +from 
flask_restful import fields # type: ignore + +from core.file import File + + +class FilesContainedField(fields.Raw): + def format(self, value): + return self._format_file_object(value) + + def _format_file_object(self, v): + if isinstance(v, File): + return v.model_dump() + if isinstance(v, dict): + return {k: self._format_file_object(vv) for k, vv in v.items()} + if isinstance(v, list): + return [self._format_file_object(vv) for vv in v] + return v diff --git a/api/fields/segment_fields.py b/api/fields/segment_fields.py new file mode 100644 index 0000000000000000000000000000000000000000..52f89859c931b74186d98cc74ffb608ae42850a9 --- /dev/null +++ b/api/fields/segment_fields.py @@ -0,0 +1,47 @@ +from flask_restful import fields # type: ignore + +from libs.helper import TimestampField + +child_chunk_fields = { + "id": fields.String, + "segment_id": fields.String, + "content": fields.String, + "position": fields.Integer, + "word_count": fields.Integer, + "type": fields.String, + "created_at": TimestampField, + "updated_at": TimestampField, +} + +segment_fields = { + "id": fields.String, + "position": fields.Integer, + "document_id": fields.String, + "content": fields.String, + "answer": fields.String, + "word_count": fields.Integer, + "tokens": fields.Integer, + "keywords": fields.List(fields.String), + "index_node_id": fields.String, + "index_node_hash": fields.String, + "hit_count": fields.Integer, + "enabled": fields.Boolean, + "disabled_at": TimestampField, + "disabled_by": fields.String, + "status": fields.String, + "created_by": fields.String, + "created_at": TimestampField, + "updated_at": TimestampField, + "updated_by": fields.String, + "indexing_at": TimestampField, + "completed_at": TimestampField, + "error": fields.String, + "stopped_at": TimestampField, + "child_chunks": fields.List(fields.Nested(child_chunk_fields)), +} + +segment_list_response = { + "data": fields.List(fields.Nested(segment_fields)), + "has_more": fields.Boolean, + "limit": 
fields.Integer, +} diff --git a/api/fields/tag_fields.py b/api/fields/tag_fields.py new file mode 100644 index 0000000000000000000000000000000000000000..986cd725f70910d9f1c697e1d1d1d933673c66b1 --- /dev/null +++ b/api/fields/tag_fields.py @@ -0,0 +1,3 @@ +from flask_restful import fields # type: ignore + +tag_fields = {"id": fields.String, "name": fields.String, "type": fields.String, "binding_count": fields.String} diff --git a/api/fields/workflow_app_log_fields.py b/api/fields/workflow_app_log_fields.py new file mode 100644 index 0000000000000000000000000000000000000000..c45b33597b39788f28b5f13239bbba0b9b0690cd --- /dev/null +++ b/api/fields/workflow_app_log_fields.py @@ -0,0 +1,24 @@ +from flask_restful import fields # type: ignore + +from fields.end_user_fields import simple_end_user_fields +from fields.member_fields import simple_account_fields +from fields.workflow_run_fields import workflow_run_for_log_fields +from libs.helper import TimestampField + +workflow_app_log_partial_fields = { + "id": fields.String, + "workflow_run": fields.Nested(workflow_run_for_log_fields, attribute="workflow_run", allow_null=True), + "created_from": fields.String, + "created_by_role": fields.String, + "created_by_account": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True), + "created_by_end_user": fields.Nested(simple_end_user_fields, attribute="created_by_end_user", allow_null=True), + "created_at": TimestampField, +} + +workflow_app_log_pagination_fields = { + "page": fields.Integer, + "limit": fields.Integer(attribute="per_page"), + "total": fields.Integer, + "has_more": fields.Boolean(attribute="has_next"), + "data": fields.List(fields.Nested(workflow_app_log_partial_fields), attribute="items"), +} diff --git a/api/fields/workflow_fields.py b/api/fields/workflow_fields.py new file mode 100644 index 0000000000000000000000000000000000000000..32f979a5f2aa08549b34e8eda02246a7858fa186 --- /dev/null +++ b/api/fields/workflow_fields.py @@ -0,0 
+1,71 @@ +from flask_restful import fields # type: ignore + +from core.helper import encrypter +from core.variables import SecretVariable, SegmentType, Variable +from fields.member_fields import simple_account_fields +from libs.helper import TimestampField + +ENVIRONMENT_VARIABLE_SUPPORTED_TYPES = (SegmentType.STRING, SegmentType.NUMBER, SegmentType.SECRET) + + +class EnvironmentVariableField(fields.Raw): + def format(self, value): + # Mask secret variables values in environment_variables + if isinstance(value, SecretVariable): + return { + "id": value.id, + "name": value.name, + "value": encrypter.obfuscated_token(value.value), + "value_type": value.value_type.value, + } + if isinstance(value, Variable): + return { + "id": value.id, + "name": value.name, + "value": value.value, + "value_type": value.value_type.value, + } + if isinstance(value, dict): + value_type = value.get("value_type") + if value_type not in ENVIRONMENT_VARIABLE_SUPPORTED_TYPES: + raise ValueError(f"Unsupported environment variable value type: {value_type}") + return value + + +conversation_variable_fields = { + "id": fields.String, + "name": fields.String, + "value_type": fields.String(attribute="value_type.value"), + "value": fields.Raw, + "description": fields.String, +} + +workflow_fields = { + "id": fields.String, + "graph": fields.Raw(attribute="graph_dict"), + "features": fields.Raw(attribute="features_dict"), + "hash": fields.String(attribute="unique_hash"), + "version": fields.String(attribute="version"), + "created_by": fields.Nested(simple_account_fields, attribute="created_by_account"), + "created_at": TimestampField, + "updated_by": fields.Nested(simple_account_fields, attribute="updated_by_account", allow_null=True), + "updated_at": TimestampField, + "tool_published": fields.Boolean, + "environment_variables": fields.List(EnvironmentVariableField()), + "conversation_variables": fields.List(fields.Nested(conversation_variable_fields)), +} + +workflow_partial_fields = { + "id": 
fields.String, + "created_by": fields.String, + "created_at": TimestampField, + "updated_by": fields.String, + "updated_at": TimestampField, +} + +workflow_pagination_fields = { + "items": fields.List(fields.Nested(workflow_fields), attribute="items"), + "page": fields.Integer, + "limit": fields.Integer(attribute="limit"), + "has_more": fields.Boolean(attribute="has_more"), +} diff --git a/api/fields/workflow_run_fields.py b/api/fields/workflow_run_fields.py new file mode 100644 index 0000000000000000000000000000000000000000..ef59c57ec379571c307f53dab95bdf05d337cbbc --- /dev/null +++ b/api/fields/workflow_run_fields.py @@ -0,0 +1,121 @@ +from flask_restful import fields # type: ignore + +from fields.end_user_fields import simple_end_user_fields +from fields.member_fields import simple_account_fields +from libs.helper import TimestampField + +workflow_run_for_log_fields = { + "id": fields.String, + "version": fields.String, + "status": fields.String, + "error": fields.String, + "elapsed_time": fields.Float, + "total_tokens": fields.Integer, + "total_steps": fields.Integer, + "created_at": TimestampField, + "finished_at": TimestampField, + "exceptions_count": fields.Integer, +} + +workflow_run_for_list_fields = { + "id": fields.String, + "sequence_number": fields.Integer, + "version": fields.String, + "status": fields.String, + "elapsed_time": fields.Float, + "total_tokens": fields.Integer, + "total_steps": fields.Integer, + "created_by_account": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True), + "created_at": TimestampField, + "finished_at": TimestampField, + "exceptions_count": fields.Integer, + "retry_index": fields.Integer, +} + +advanced_chat_workflow_run_for_list_fields = { + "id": fields.String, + "conversation_id": fields.String, + "message_id": fields.String, + "sequence_number": fields.Integer, + "version": fields.String, + "status": fields.String, + "elapsed_time": fields.Float, + "total_tokens": fields.Integer, + 
"total_steps": fields.Integer, + "created_by_account": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True), + "created_at": TimestampField, + "finished_at": TimestampField, + "exceptions_count": fields.Integer, + "retry_index": fields.Integer, +} + +advanced_chat_workflow_run_pagination_fields = { + "limit": fields.Integer(attribute="limit"), + "has_more": fields.Boolean(attribute="has_more"), + "data": fields.List(fields.Nested(advanced_chat_workflow_run_for_list_fields), attribute="data"), +} + +workflow_run_pagination_fields = { + "limit": fields.Integer(attribute="limit"), + "has_more": fields.Boolean(attribute="has_more"), + "data": fields.List(fields.Nested(workflow_run_for_list_fields), attribute="data"), +} + +workflow_run_detail_fields = { + "id": fields.String, + "sequence_number": fields.Integer, + "version": fields.String, + "graph": fields.Raw(attribute="graph_dict"), + "inputs": fields.Raw(attribute="inputs_dict"), + "status": fields.String, + "outputs": fields.Raw(attribute="outputs_dict"), + "error": fields.String, + "elapsed_time": fields.Float, + "total_tokens": fields.Integer, + "total_steps": fields.Integer, + "created_by_role": fields.String, + "created_by_account": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True), + "created_by_end_user": fields.Nested(simple_end_user_fields, attribute="created_by_end_user", allow_null=True), + "created_at": TimestampField, + "finished_at": TimestampField, + "exceptions_count": fields.Integer, +} + +retry_event_field = { + "elapsed_time": fields.Float, + "status": fields.String, + "inputs": fields.Raw(attribute="inputs"), + "process_data": fields.Raw(attribute="process_data"), + "outputs": fields.Raw(attribute="outputs"), + "metadata": fields.Raw(attribute="metadata"), + "llm_usage": fields.Raw(attribute="llm_usage"), + "error": fields.String, + "retry_index": fields.Integer, +} + + +workflow_run_node_execution_fields = { + "id": 
fields.String, + "index": fields.Integer, + "predecessor_node_id": fields.String, + "node_id": fields.String, + "node_type": fields.String, + "title": fields.String, + "inputs": fields.Raw(attribute="inputs_dict"), + "process_data": fields.Raw(attribute="process_data_dict"), + "outputs": fields.Raw(attribute="outputs_dict"), + "status": fields.String, + "error": fields.String, + "elapsed_time": fields.Float, + "execution_metadata": fields.Raw(attribute="execution_metadata_dict"), + "extras": fields.Raw, + "created_at": TimestampField, + "created_by_role": fields.String, + "created_by_account": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True), + "created_by_end_user": fields.Nested(simple_end_user_fields, attribute="created_by_end_user", allow_null=True), + "finished_at": TimestampField, +} + +workflow_run_node_execution_list_fields = { + "data": fields.List(fields.Nested(workflow_run_node_execution_fields)), +} diff --git a/api/libs/__init__.py b/api/libs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/libs/exception.py b/api/libs/exception.py new file mode 100644 index 0000000000000000000000000000000000000000..5970269ecdbed5f159871d8b3ac9727af78cda18 --- /dev/null +++ b/api/libs/exception.py @@ -0,0 +1,17 @@ +from typing import Optional + +from werkzeug.exceptions import HTTPException + + +class BaseHTTPException(HTTPException): + error_code: str = "unknown" + data: Optional[dict] = None + + def __init__(self, description=None, response=None): + super().__init__(description, response) + + self.data = { + "code": self.error_code, + "message": self.description, + "status": self.code, + } diff --git a/api/libs/external_api.py b/api/libs/external_api.py new file mode 100644 index 0000000000000000000000000000000000000000..922d2d9cd3332446b899e80a1a4de9d4659b2243 --- /dev/null +++ b/api/libs/external_api.py @@ -0,0 +1,119 @@ +import re +import 
sys +from typing import Any + +from flask import current_app, got_request_exception +from flask_restful import Api, http_status_message # type: ignore +from werkzeug.datastructures import Headers +from werkzeug.exceptions import HTTPException + +from core.errors.error import AppInvokeQuotaExceededError + + +class ExternalApi(Api): + def handle_error(self, e): + """Error handler for the API transforms a raised exception into a Flask + response, with the appropriate HTTP status code and body. + + :param e: the raised Exception object + :type e: Exception + + """ + got_request_exception.send(current_app, exception=e) + + headers = Headers() + if isinstance(e, HTTPException): + if e.response is not None: + resp = e.get_response() + return resp + + status_code = e.code + default_data = { + "code": re.sub(r"(?= 500: + exc_info: Any = sys.exc_info() + if exc_info[1] is None: + exc_info = None + current_app.log_exception(exc_info) + + if status_code == 406 and self.default_mediatype is None: + # if we are handling NotAcceptable (406), make sure that + # make_response uses a representation we support as the + # default mediatype (so that make_response doesn't throw + # another NotAcceptable error). 
+ supported_mediatypes = list(self.representations.keys()) # only supported application/json + fallback_mediatype = supported_mediatypes[0] if supported_mediatypes else "text/plain" + data = {"code": "not_acceptable", "message": data.get("message")} + resp = self.make_response(data, status_code, headers, fallback_mediatype=fallback_mediatype) + elif status_code == 400: + if isinstance(data.get("message"), dict): + param_key, param_value = list(data.get("message", {}).items())[0] + data = {"code": "invalid_param", "message": param_value, "params": param_key} + else: + if "code" not in data: + data["code"] = "unknown" + + resp = self.make_response(data, status_code, headers) + else: + if "code" not in data: + data["code"] = "unknown" + + resp = self.make_response(data, status_code, headers) + + if status_code == 401: + resp = self.unauthorized(resp) + return resp diff --git a/api/libs/gmpy2_pkcs10aep_cipher.py b/api/libs/gmpy2_pkcs10aep_cipher.py new file mode 100644 index 0000000000000000000000000000000000000000..2dae87e1710bf646f8cfae2dba31c2e814205e72 --- /dev/null +++ b/api/libs/gmpy2_pkcs10aep_cipher.py @@ -0,0 +1,241 @@ +# +# Cipher/PKCS1_OAEP.py : PKCS#1 OAEP +# +# =================================================================== +# The contents of this file are dedicated to the public domain. To +# the extent that dedication to the public domain is not available, +# everyone is granted a worldwide, perpetual, royalty-free, +# non-exclusive license to exercise all rights associated with the +# contents of this file for any purpose whatsoever. +# No rights are reserved. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +# NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS +# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN +# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# =================================================================== + +from hashlib import sha1 + +import Crypto.Hash.SHA1 +import Crypto.Util.number +import gmpy2 # type: ignore +from Crypto import Random +from Crypto.Signature.pss import MGF1 +from Crypto.Util.number import bytes_to_long, ceil_div, long_to_bytes +from Crypto.Util.py3compat import _copy_bytes, bord +from Crypto.Util.strxor import strxor + + +class PKCS1OAepCipher: + """Cipher object for PKCS#1 v1.5 OAEP. + Do not create directly: use :func:`new` instead.""" + + def __init__(self, key, hashAlgo, mgfunc, label, randfunc): + """Initialize this PKCS#1 OAEP cipher object. + + :Parameters: + key : an RSA key object + If a private half is given, both encryption and decryption are possible. + If a public half is given, only encryption is possible. + hashAlgo : hash object + The hash function to use. This can be a module under `Crypto.Hash` + or an existing hash object created from any of such modules. If not specified, + `Crypto.Hash.SHA1` is used. + mgfunc : callable + A mask generation function that accepts two parameters: a string to + use as seed, and the length of the mask to generate, in bytes. + If not specified, the standard MGF1 consistent with ``hashAlgo`` is used (a safe choice). + label : bytes/bytearray/memoryview + A label to apply to this particular encryption. If not specified, + an empty string is used. Specifying a label does not improve + security. + randfunc : callable + A function that returns random bytes. + + :attention: Modify the mask generation function only if you know what you are doing. + Sender and receiver must use the same one. 
+ """ + self._key = key + + if hashAlgo: + self._hashObj = hashAlgo + else: + self._hashObj = Crypto.Hash.SHA1 + + if mgfunc: + self._mgf = mgfunc + else: + self._mgf = lambda x, y: MGF1(x, y, self._hashObj) + + self._label = _copy_bytes(None, None, label) + self._randfunc = randfunc + + def can_encrypt(self): + """Legacy function to check if you can call :meth:`encrypt`. + + .. deprecated:: 3.0""" + return self._key.can_encrypt() + + def can_decrypt(self): + """Legacy function to check if you can call :meth:`decrypt`. + + .. deprecated:: 3.0""" + return self._key.can_decrypt() + + def encrypt(self, message): + """Encrypt a message with PKCS#1 OAEP. + + :param message: + The message to encrypt, also known as plaintext. It can be of + variable length, but not longer than the RSA modulus (in bytes) + minus 2, minus twice the hash output size. + For instance, if you use RSA 2048 and SHA-256, the longest message + you can encrypt is 190 byte long. + :type message: bytes/bytearray/memoryview + + :returns: The ciphertext, as large as the RSA modulus. + :rtype: bytes + + :raises ValueError: + if the message is too long. 
+ """ + + # See 7.1.1 in RFC3447 + modBits = Crypto.Util.number.size(self._key.n) + k = ceil_div(modBits, 8) # Convert from bits to bytes + hLen = self._hashObj.digest_size + mLen = len(message) + + # Step 1b + ps_len = k - mLen - 2 * hLen - 2 + if ps_len < 0: + raise ValueError("Plaintext is too long.") + # Step 2a + lHash = sha1(self._label).digest() + # Step 2b + ps = b"\x00" * ps_len + # Step 2c + db = lHash + ps + b"\x01" + _copy_bytes(None, None, message) + # Step 2d + ros = self._randfunc(hLen) + # Step 2e + dbMask = self._mgf(ros, k - hLen - 1) + # Step 2f + maskedDB = strxor(db, dbMask) + # Step 2g + seedMask = self._mgf(maskedDB, hLen) + # Step 2h + maskedSeed = strxor(ros, seedMask) + # Step 2i + em = b"\x00" + maskedSeed + maskedDB + # Step 3a (OS2IP) + em_int = bytes_to_long(em) + # Step 3b (RSAEP) + m_int = gmpy2.powmod(em_int, self._key.e, self._key.n) + # Step 3c (I2OSP) + c = long_to_bytes(m_int, k) + return c + + def decrypt(self, ciphertext): + """Decrypt a message with PKCS#1 OAEP. + + :param ciphertext: The encrypted message. + :type ciphertext: bytes/bytearray/memoryview + + :returns: The original message (plaintext). + :rtype: bytes + + :raises ValueError: + if the ciphertext has the wrong length, or if decryption + fails the integrity check (in which case, the decryption + key is probably wrong). + :raises TypeError: + if the RSA key has no private half (i.e. you are trying + to decrypt using a public key). 
+ """ + # See 7.1.2 in RFC3447 + modBits = Crypto.Util.number.size(self._key.n) + k = ceil_div(modBits, 8) # Convert from bits to bytes + hLen = self._hashObj.digest_size + # Step 1b and 1c + if len(ciphertext) != k or k < hLen + 2: + raise ValueError("Ciphertext with incorrect length.") + # Step 2a (O2SIP) + ct_int = bytes_to_long(ciphertext) + # Step 2b (RSADP) + # m_int = self._key._decrypt(ct_int) + m_int = gmpy2.powmod(ct_int, self._key.d, self._key.n) + # Complete step 2c (I2OSP) + em = long_to_bytes(m_int, k) + # Step 3a + lHash = sha1(self._label).digest() + # Step 3b + y = em[0] + # y must be 0, but we MUST NOT check it here in order not to + # allow attacks like Manger's (http://dl.acm.org/citation.cfm?id=704143) + maskedSeed = em[1 : hLen + 1] + maskedDB = em[hLen + 1 :] + # Step 3c + seedMask = self._mgf(maskedDB, hLen) + # Step 3d + seed = strxor(maskedSeed, seedMask) + # Step 3e + dbMask = self._mgf(seed, k - hLen - 1) + # Step 3f + db = strxor(maskedDB, dbMask) + # Step 3g + one_pos = hLen + db[hLen:].find(b"\x01") + lHash1 = db[:hLen] + invalid = bord(y) | int(one_pos < hLen) # type: ignore + hash_compare = strxor(lHash1, lHash) + for x in hash_compare: + invalid |= bord(x) # type: ignore + for x in db[hLen:one_pos]: + invalid |= bord(x) # type: ignore + if invalid != 0: + raise ValueError("Incorrect decryption.") + # Step 4 + return db[one_pos + 1 :] + + +def new(key, hashAlgo=None, mgfunc=None, label=b"", randfunc=None): + """Return a cipher object :class:`PKCS1OAEP_Cipher` + that can be used to perform PKCS#1 OAEP encryption or decryption. + + :param key: + The key object to use to encrypt or decrypt the message. + Decryption is only possible with a private RSA key. + :type key: RSA key object + + :param hashAlgo: + The hash function to use. This can be a module under `Crypto.Hash` + or an existing hash object created from any of such modules. + If not specified, `Crypto.Hash.SHA1` is used. 
+ :type hashAlgo: hash object + + :param mgfunc: + A mask generation function that accepts two parameters: a string to + use as seed, and the length of the mask to generate, in bytes. + If not specified, the standard MGF1 consistent with ``hashAlgo`` is used (a safe choice). + :type mgfunc: callable + + :param label: + A label to apply to this particular encryption. If not specified, + an empty string is used. Specifying a label does not improve + security. + :type label: bytes/bytearray/memoryview + + :param randfunc: + A function that returns random bytes. + The default is `Random.get_random_bytes`. + :type randfunc: callable + """ + + if randfunc is None: + randfunc = Random.get_random_bytes + return PKCS1OAepCipher(key, hashAlgo, mgfunc, label, randfunc) diff --git a/api/libs/helper.py b/api/libs/helper.py new file mode 100644 index 0000000000000000000000000000000000000000..4f14f010f4b98bdadbb04c28c6a33d0cb1822adc --- /dev/null +++ b/api/libs/helper.py @@ -0,0 +1,311 @@ +import json +import logging +import random +import re +import string +import subprocess +import time +import uuid +from collections.abc import Generator, Mapping +from datetime import datetime +from hashlib import sha256 +from typing import Any, Optional, Union, cast +from zoneinfo import available_timezones + +from flask import Response, stream_with_context +from flask_restful import fields # type: ignore + +from configs import dify_config +from core.app.features.rate_limiting.rate_limit import RateLimitGenerator +from core.file import helpers as file_helpers +from extensions.ext_redis import redis_client +from models.account import Account + + +def run(script): + return subprocess.getstatusoutput("source /root/.bashrc && " + script) + + +class AppIconUrlField(fields.Raw): + def output(self, key, obj): + if obj is None: + return None + + from models.model import App, IconType, Site + + if isinstance(obj, dict) and "app" in obj: + obj = obj["app"] + + if isinstance(obj, App | Site) and 
obj.icon_type == IconType.IMAGE.value: + return file_helpers.get_signed_file_url(obj.icon) + return None + + +class AvatarUrlField(fields.Raw): + def output(self, key, obj): + if obj is None: + return None + + from models.account import Account + + if isinstance(obj, Account) and obj.avatar is not None: + return file_helpers.get_signed_file_url(obj.avatar) + return None + + +class TimestampField(fields.Raw): + def format(self, value) -> int: + return int(value.timestamp()) + + +def email(email): + # Define a regex pattern for email addresses + pattern = r"^[\w\.!#$%&'*+\-/=?^_`{|}~]+@([\w-]+\.)+[\w-]{2,}$" + # Check if the email matches the pattern + if re.match(pattern, email) is not None: + return email + + error = "{email} is not a valid email.".format(email=email) + raise ValueError(error) + + +def uuid_value(value): + if value == "": + return str(value) + + try: + uuid_obj = uuid.UUID(value) + return str(uuid_obj) + except ValueError: + error = "{value} is not a valid uuid.".format(value=value) + raise ValueError(error) + + +def alphanumeric(value: str): + # check if the value is alphanumeric and underlined + if re.match(r"^[a-zA-Z0-9_]+$", value): + return value + + raise ValueError(f"{value} is not a valid alphanumeric value") + + +def timestamp_value(timestamp): + try: + int_timestamp = int(timestamp) + if int_timestamp < 0: + raise ValueError + return int_timestamp + except ValueError: + error = "{timestamp} is not a valid timestamp.".format(timestamp=timestamp) + raise ValueError(error) + + +class StrLen: + """Restrict input to an integer in a range (inclusive)""" + + def __init__(self, max_length, argument="argument"): + self.max_length = max_length + self.argument = argument + + def __call__(self, value): + length = len(value) + if length > self.max_length: + error = "Invalid {arg}: {val}. 
{arg} cannot exceed length {length}".format( + arg=self.argument, val=value, length=self.max_length + ) + raise ValueError(error) + + return value + + +class FloatRange: + """Restrict input to an float in a range (inclusive)""" + + def __init__(self, low, high, argument="argument"): + self.low = low + self.high = high + self.argument = argument + + def __call__(self, value): + value = _get_float(value) + if value < self.low or value > self.high: + error = "Invalid {arg}: {val}. {arg} must be within the range {lo} - {hi}".format( + arg=self.argument, val=value, lo=self.low, hi=self.high + ) + raise ValueError(error) + + return value + + +class DatetimeString: + def __init__(self, format, argument="argument"): + self.format = format + self.argument = argument + + def __call__(self, value): + try: + datetime.strptime(value, self.format) + except ValueError: + error = "Invalid {arg}: {val}. {arg} must be conform to the format {format}".format( + arg=self.argument, val=value, format=self.format + ) + raise ValueError(error) + + return value + + +def _get_float(value): + try: + return float(value) + except (TypeError, ValueError): + raise ValueError("{} is not a valid float".format(value)) + + +def timezone(timezone_string): + if timezone_string and timezone_string in available_timezones(): + return timezone_string + + error = "{timezone_string} is not a valid timezone.".format(timezone_string=timezone_string) + raise ValueError(error) + + +def generate_string(n): + letters_digits = string.ascii_letters + string.digits + result = "" + for i in range(n): + result += random.choice(letters_digits) + + return result + + +def extract_remote_ip(request) -> str: + if request.headers.get("CF-Connecting-IP"): + return cast(str, request.headers.get("Cf-Connecting-Ip")) + elif request.headers.getlist("X-Forwarded-For"): + return cast(str, request.headers.getlist("X-Forwarded-For")[0]) + else: + return cast(str, request.remote_addr) + + +def generate_text_hash(text: str) -> str: + 
hash_text = str(text) + "None" + return sha256(hash_text.encode()).hexdigest() + + +def compact_generate_response( + response: Union[Mapping[str, Any], RateLimitGenerator, Generator[str, None, None]], +) -> Response: + if isinstance(response, dict): + return Response(response=json.dumps(response), status=200, mimetype="application/json") + else: + + def generate() -> Generator: + yield from response + + return Response(stream_with_context(generate()), status=200, mimetype="text/event-stream") + + +class TokenManager: + @classmethod + def generate_token( + cls, + token_type: str, + account: Optional[Account] = None, + email: Optional[str] = None, + additional_data: Optional[dict] = None, + ) -> str: + if account is None and email is None: + raise ValueError("Account or email must be provided") + + account_id = account.id if account else None + account_email = account.email if account else email + + if account_id: + old_token = cls._get_current_token_for_account(account_id, token_type) + if old_token: + if isinstance(old_token, bytes): + old_token = old_token.decode("utf-8") + cls.revoke_token(old_token, token_type) + + token = str(uuid.uuid4()) + token_data = {"account_id": account_id, "email": account_email, "token_type": token_type} + if additional_data: + token_data.update(additional_data) + + expiry_minutes = dify_config.model_dump().get(f"{token_type.upper()}_TOKEN_EXPIRY_MINUTES") + if expiry_minutes is None: + raise ValueError(f"Expiry minutes for {token_type} token is not set") + token_key = cls._get_token_key(token, token_type) + expiry_time = int(expiry_minutes * 60) + redis_client.setex(token_key, expiry_time, json.dumps(token_data)) + + if account_id: + cls._set_current_token_for_account(account_id, token, token_type, expiry_minutes) + + return token + + @classmethod + def _get_token_key(cls, token: str, token_type: str) -> str: + return f"{token_type}:token:{token}" + + @classmethod + def revoke_token(cls, token: str, token_type: str): + token_key = 
cls._get_token_key(token, token_type) + redis_client.delete(token_key) + + @classmethod + def get_token_data(cls, token: str, token_type: str) -> Optional[dict[str, Any]]: + key = cls._get_token_key(token, token_type) + token_data_json = redis_client.get(key) + if token_data_json is None: + logging.warning(f"{token_type} token {token} not found with key {key}") + return None + token_data: Optional[dict[str, Any]] = json.loads(token_data_json) + return token_data + + @classmethod + def _get_current_token_for_account(cls, account_id: str, token_type: str) -> Optional[str]: + key = cls._get_account_token_key(account_id, token_type) + current_token: Optional[str] = redis_client.get(key) + return current_token + + @classmethod + def _set_current_token_for_account( + cls, account_id: str, token: str, token_type: str, expiry_hours: Union[int, float] + ): + key = cls._get_account_token_key(account_id, token_type) + expiry_time = int(expiry_hours * 60 * 60) + redis_client.setex(key, expiry_time, token) + + @classmethod + def _get_account_token_key(cls, account_id: str, token_type: str) -> str: + return f"{token_type}:account:{account_id}" + + +class RateLimiter: + def __init__(self, prefix: str, max_attempts: int, time_window: int): + self.prefix = prefix + self.max_attempts = max_attempts + self.time_window = time_window + + def _get_key(self, email: str) -> str: + return f"{self.prefix}:{email}" + + def is_rate_limited(self, email: str) -> bool: + key = self._get_key(email) + current_time = int(time.time()) + window_start_time = current_time - self.time_window + + redis_client.zremrangebyscore(key, "-inf", window_start_time) + attempts = redis_client.zcard(key) + + if attempts and int(attempts) >= self.max_attempts: + return True + return False + + def increment_rate_limit(self, email: str): + key = self._get_key(email) + current_time = int(time.time()) + + redis_client.zadd(key, {current_time: current_time}) + redis_client.expire(key, self.time_window * 2) diff --git 
a/api/libs/infinite_scroll_pagination.py b/api/libs/infinite_scroll_pagination.py new file mode 100644 index 0000000000000000000000000000000000000000..133ccb188338e26365eda2da325753b2801c564c --- /dev/null +++ b/api/libs/infinite_scroll_pagination.py @@ -0,0 +1,5 @@ +class InfiniteScrollPagination: + def __init__(self, data, limit, has_more): + self.data = data + self.limit = limit + self.has_more = has_more diff --git a/api/libs/json_in_md_parser.py b/api/libs/json_in_md_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..9ab53b6294db93a6a804dbe179bf7dc21548f091 --- /dev/null +++ b/api/libs/json_in_md_parser.py @@ -0,0 +1,46 @@ +import json + +from core.llm_generator.output_parser.errors import OutputParserError + + +def parse_json_markdown(json_string: str) -> dict: + # Get json from the backticks/braces + json_string = json_string.strip() + starts = ["```json", "```", "``", "`", "{"] + ends = ["```", "``", "`", "}"] + end_index = -1 + start_index = 0 + parsed: dict = {} + for s in starts: + start_index = json_string.find(s) + if start_index != -1: + if json_string[start_index] != "{": + start_index += len(s) + break + if start_index != -1: + for e in ends: + end_index = json_string.rfind(e, start_index) + if end_index != -1: + if json_string[end_index] == "}": + end_index += 1 + break + if start_index != -1 and end_index != -1 and start_index < end_index: + extracted_content = json_string[start_index:end_index].strip() + parsed = json.loads(extracted_content) + else: + raise ValueError("could not find json block in the output.") + + return parsed + + +def parse_and_check_json_markdown(text: str, expected_keys: list[str]) -> dict: + try: + json_obj = parse_json_markdown(text) + except json.JSONDecodeError as e: + raise OutputParserError(f"got invalid json object. error: {e}") + for key in expected_keys: + if key not in json_obj: + raise OutputParserError( + f"got invalid return object. 
expected key `{key}` to be present, but got {json_obj}" + ) + return json_obj diff --git a/api/libs/login.py b/api/libs/login.py new file mode 100644 index 0000000000000000000000000000000000000000..5395534a6df93ad9d2c02a51a1535fe33a7fce35 --- /dev/null +++ b/api/libs/login.py @@ -0,0 +1,106 @@ +from functools import wraps +from typing import Any + +from flask import current_app, g, has_request_context, request +from flask_login import user_logged_in # type: ignore +from flask_login.config import EXEMPT_METHODS # type: ignore +from werkzeug.exceptions import Unauthorized +from werkzeug.local import LocalProxy + +from configs import dify_config +from extensions.ext_database import db +from models.account import Account, Tenant, TenantAccountJoin + +#: A proxy for the current user. If no user is logged in, this will be an +#: anonymous user +current_user: Any = LocalProxy(lambda: _get_user()) + + +def login_required(func): + """ + If you decorate a view with this, it will ensure that the current user is + logged in and authenticated before calling the actual view. (If they are + not, it calls the :attr:`LoginManager.unauthorized` callback.) For + example:: + + @app.route('/post') + @login_required + def post(): + pass + + If there are only certain times you need to require that your user is + logged in, you can do so with:: + + if not current_user.is_authenticated: + return current_app.login_manager.unauthorized() + + ...which is essentially the code that this function adds to your views. + + It can be convenient to globally turn off authentication when unit testing. + To enable this, if the application configuration variable `LOGIN_DISABLED` + is set to `True`, this decorator will be ignored. + + .. Note :: + + Per `W3 guidelines for CORS preflight requests + `_, + HTTP ``OPTIONS`` requests are exempt from login checks. + + :param func: The view function to decorate. 
+ :type func: function + """ + + @wraps(func) + def decorated_view(*args, **kwargs): + auth_header = request.headers.get("Authorization") + if dify_config.ADMIN_API_KEY_ENABLE: + if auth_header: + if " " not in auth_header: + raise Unauthorized("Invalid Authorization header format. Expected 'Bearer ' format.") + auth_scheme, auth_token = auth_header.split(None, 1) + auth_scheme = auth_scheme.lower() + if auth_scheme != "bearer": + raise Unauthorized("Invalid Authorization header format. Expected 'Bearer ' format.") + + admin_api_key = dify_config.ADMIN_API_KEY + if admin_api_key: + if admin_api_key == auth_token: + workspace_id = request.headers.get("X-WORKSPACE-ID") + if workspace_id: + tenant_account_join = ( + db.session.query(Tenant, TenantAccountJoin) + .filter(Tenant.id == workspace_id) + .filter(TenantAccountJoin.tenant_id == Tenant.id) + .filter(TenantAccountJoin.role == "owner") + .one_or_none() + ) + if tenant_account_join: + tenant, ta = tenant_account_join + account = Account.query.filter_by(id=ta.account_id).first() + # Login admin + if account: + account.current_tenant = tenant + current_app.login_manager._update_request_context_with_user(account) # type: ignore + user_logged_in.send(current_app._get_current_object(), user=_get_user()) # type: ignore + if request.method in EXEMPT_METHODS or dify_config.LOGIN_DISABLED: + pass + elif not current_user.is_authenticated: + return current_app.login_manager.unauthorized() # type: ignore + + # flask 1.x compatibility + # current_app.ensure_sync is only available in Flask >= 2.0 + if callable(getattr(current_app, "ensure_sync", None)): + return current_app.ensure_sync(func)(*args, **kwargs) + return func(*args, **kwargs) + + return decorated_view + + +def _get_user(): + if has_request_context(): + if "_login_user" not in g: + current_app.login_manager._load_user() # type: ignore + + return g._login_user + + return None diff --git a/api/libs/oauth.py b/api/libs/oauth.py new file mode 100644 index 
0000000000000000000000000000000000000000..df75b550195298529017122f0a1cce0e950eaa52 --- /dev/null +++ b/api/libs/oauth.py @@ -0,0 +1,133 @@ +import urllib.parse +from dataclasses import dataclass +from typing import Optional + +import requests + + +@dataclass +class OAuthUserInfo: + id: str + name: str + email: str + + +class OAuth: + def __init__(self, client_id: str, client_secret: str, redirect_uri: str): + self.client_id = client_id + self.client_secret = client_secret + self.redirect_uri = redirect_uri + + def get_authorization_url(self): + raise NotImplementedError() + + def get_access_token(self, code: str): + raise NotImplementedError() + + def get_raw_user_info(self, token: str): + raise NotImplementedError() + + def get_user_info(self, token: str) -> OAuthUserInfo: + raw_info = self.get_raw_user_info(token) + return self._transform_user_info(raw_info) + + def _transform_user_info(self, raw_info: dict) -> OAuthUserInfo: + raise NotImplementedError() + + +class GitHubOAuth(OAuth): + _AUTH_URL = "https://github.com/login/oauth/authorize" + _TOKEN_URL = "https://github.com/login/oauth/access_token" + _USER_INFO_URL = "https://api.github.com/user" + _EMAIL_INFO_URL = "https://api.github.com/user/emails" + + def get_authorization_url(self, invite_token: Optional[str] = None): + params = { + "client_id": self.client_id, + "redirect_uri": self.redirect_uri, + "scope": "user:email", # Request only basic user information + } + if invite_token: + params["state"] = invite_token + return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}" + + def get_access_token(self, code: str): + data = { + "client_id": self.client_id, + "client_secret": self.client_secret, + "code": code, + "redirect_uri": self.redirect_uri, + } + headers = {"Accept": "application/json"} + response = requests.post(self._TOKEN_URL, data=data, headers=headers) + + response_json = response.json() + access_token = response_json.get("access_token") + + if not access_token: + raise ValueError(f"Error in 
GitHub OAuth: {response_json}") + + return access_token + + def get_raw_user_info(self, token: str): + headers = {"Authorization": f"token {token}"} + response = requests.get(self._USER_INFO_URL, headers=headers) + response.raise_for_status() + user_info = response.json() + + email_response = requests.get(self._EMAIL_INFO_URL, headers=headers) + email_info = email_response.json() + primary_email: dict = next((email for email in email_info if email["primary"] == True), {}) + + return {**user_info, "email": primary_email.get("email", "")} + + def _transform_user_info(self, raw_info: dict) -> OAuthUserInfo: + email = raw_info.get("email") + if not email: + email = f"{raw_info['id']}+{raw_info['login']}@users.noreply.github.com" + return OAuthUserInfo(id=str(raw_info["id"]), name=raw_info["name"], email=email) + + +class GoogleOAuth(OAuth): + _AUTH_URL = "https://accounts.google.com/o/oauth2/v2/auth" + _TOKEN_URL = "https://oauth2.googleapis.com/token" + _USER_INFO_URL = "https://www.googleapis.com/oauth2/v3/userinfo" + + def get_authorization_url(self, invite_token: Optional[str] = None): + params = { + "client_id": self.client_id, + "response_type": "code", + "redirect_uri": self.redirect_uri, + "scope": "openid email", + } + if invite_token: + params["state"] = invite_token + return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}" + + def get_access_token(self, code: str): + data = { + "client_id": self.client_id, + "client_secret": self.client_secret, + "code": code, + "grant_type": "authorization_code", + "redirect_uri": self.redirect_uri, + } + headers = {"Accept": "application/json"} + response = requests.post(self._TOKEN_URL, data=data, headers=headers) + + response_json = response.json() + access_token = response_json.get("access_token") + + if not access_token: + raise ValueError(f"Error in Google OAuth: {response_json}") + + return access_token + + def get_raw_user_info(self, token: str): + headers = {"Authorization": f"Bearer {token}"} + response = 
requests.get(self._USER_INFO_URL, headers=headers) + response.raise_for_status() + return response.json() + + def _transform_user_info(self, raw_info: dict) -> OAuthUserInfo: + return OAuthUserInfo(id=str(raw_info["sub"]), name="", email=raw_info["email"]) diff --git a/api/libs/oauth_data_source.py b/api/libs/oauth_data_source.py new file mode 100644 index 0000000000000000000000000000000000000000..a5ba08d351c0ccff42d2714f438e0159e12822af --- /dev/null +++ b/api/libs/oauth_data_source.py @@ -0,0 +1,303 @@ +import datetime +import urllib.parse +from typing import Any + +import requests +from flask_login import current_user # type: ignore + +from extensions.ext_database import db +from models.source import DataSourceOauthBinding + + +class OAuthDataSource: + def __init__(self, client_id: str, client_secret: str, redirect_uri: str): + self.client_id = client_id + self.client_secret = client_secret + self.redirect_uri = redirect_uri + + def get_authorization_url(self): + raise NotImplementedError() + + def get_access_token(self, code: str): + raise NotImplementedError() + + +class NotionOAuth(OAuthDataSource): + _AUTH_URL = "https://api.notion.com/v1/oauth/authorize" + _TOKEN_URL = "https://api.notion.com/v1/oauth/token" + _NOTION_PAGE_SEARCH = "https://api.notion.com/v1/search" + _NOTION_BLOCK_SEARCH = "https://api.notion.com/v1/blocks" + _NOTION_BOT_USER = "https://api.notion.com/v1/users/me" + + def get_authorization_url(self): + params = { + "client_id": self.client_id, + "response_type": "code", + "redirect_uri": self.redirect_uri, + "owner": "user", + } + return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}" + + def get_access_token(self, code: str): + data = {"code": code, "grant_type": "authorization_code", "redirect_uri": self.redirect_uri} + headers = {"Accept": "application/json"} + auth = (self.client_id, self.client_secret) + response = requests.post(self._TOKEN_URL, data=data, auth=auth, headers=headers) + + response_json = response.json() + 
access_token = response_json.get("access_token") + if not access_token: + raise ValueError(f"Error in Notion OAuth: {response_json}") + workspace_name = response_json.get("workspace_name") + workspace_icon = response_json.get("workspace_icon") + workspace_id = response_json.get("workspace_id") + # get all authorized pages + pages = self.get_authorized_pages(access_token) + source_info = { + "workspace_name": workspace_name, + "workspace_icon": workspace_icon, + "workspace_id": workspace_id, + "pages": pages, + "total": len(pages), + } + # save data source binding + data_source_binding = DataSourceOauthBinding.query.filter( + db.and_( + DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, + DataSourceOauthBinding.provider == "notion", + DataSourceOauthBinding.access_token == access_token, + ) + ).first() + if data_source_binding: + data_source_binding.source_info = source_info + data_source_binding.disabled = False + data_source_binding.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + db.session.commit() + else: + new_data_source_binding = DataSourceOauthBinding( + tenant_id=current_user.current_tenant_id, + access_token=access_token, + source_info=source_info, + provider="notion", + ) + db.session.add(new_data_source_binding) + db.session.commit() + + def save_internal_access_token(self, access_token: str): + workspace_name = self.notion_workspace_name(access_token) + workspace_icon = None + workspace_id = current_user.current_tenant_id + # get all authorized pages + pages = self.get_authorized_pages(access_token) + source_info = { + "workspace_name": workspace_name, + "workspace_icon": workspace_icon, + "workspace_id": workspace_id, + "pages": pages, + "total": len(pages), + } + # save data source binding + data_source_binding = DataSourceOauthBinding.query.filter( + db.and_( + DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, + DataSourceOauthBinding.provider == "notion", + 
DataSourceOauthBinding.access_token == access_token, + ) + ).first() + if data_source_binding: + data_source_binding.source_info = source_info + data_source_binding.disabled = False + data_source_binding.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + db.session.commit() + else: + new_data_source_binding = DataSourceOauthBinding( + tenant_id=current_user.current_tenant_id, + access_token=access_token, + source_info=source_info, + provider="notion", + ) + db.session.add(new_data_source_binding) + db.session.commit() + + def sync_data_source(self, binding_id: str): + # look up the enabled Notion binding of the current tenant that is to be refreshed + data_source_binding = DataSourceOauthBinding.query.filter( + db.and_( + DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, + DataSourceOauthBinding.provider == "notion", + DataSourceOauthBinding.id == binding_id, + DataSourceOauthBinding.disabled == False, + ) + ).first() + if data_source_binding: + # get all authorized pages + pages = self.get_authorized_pages(data_source_binding.access_token) + source_info = data_source_binding.source_info + new_source_info = { + "workspace_name": source_info["workspace_name"], + "workspace_icon": source_info["workspace_icon"], + "workspace_id": source_info["workspace_id"], + "pages": pages, + "total": len(pages), + } + data_source_binding.source_info = new_source_info + data_source_binding.disabled = False + data_source_binding.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + db.session.commit() + else: + raise ValueError("Data source binding not found") + + def get_authorized_pages(self, access_token: str): + pages = [] + page_results = self.notion_page_search(access_token) + database_results = self.notion_database_search(access_token) + # get page detail + for page_result in page_results: + page_id = page_result["id"] + page_name = "Untitled" + for key in page_result["properties"]: + if "title" in page_result["properties"][key] and page_result["properties"][key]["title"]: +
title_list = page_result["properties"][key]["title"] + if len(title_list) > 0 and "plain_text" in title_list[0]: + page_name = title_list[0]["plain_text"] + page_icon = page_result["icon"] + if page_icon: + icon_type = page_icon["type"] + if icon_type in {"external", "file"}: + url = page_icon[icon_type]["url"] + icon = {"type": "url", "url": url if url.startswith("http") else f"https://www.notion.so{url}"} + else: + icon = {"type": "emoji", "emoji": page_icon[icon_type]} + else: + icon = None + parent = page_result["parent"] + parent_type = parent["type"] + if parent_type == "block_id": + parent_id = self.notion_block_parent_page_id(access_token, parent[parent_type]) + elif parent_type == "workspace": + parent_id = "root" + else: + parent_id = parent[parent_type] + page = { + "page_id": page_id, + "page_name": page_name, + "page_icon": icon, + "parent_id": parent_id, + "type": "page", + } + pages.append(page) + # get database detail + for database_result in database_results: + page_id = database_result["id"] + if len(database_result["title"]) > 0: + page_name = database_result["title"][0]["plain_text"] + else: + page_name = "Untitled" + page_icon = database_result["icon"] + if page_icon: + icon_type = page_icon["type"] + if icon_type in {"external", "file"}: + url = page_icon[icon_type]["url"] + icon = {"type": "url", "url": url if url.startswith("http") else f"https://www.notion.so{url}"} + else: + icon = {"type": icon_type, icon_type: page_icon[icon_type]} + else: + icon = None + parent = database_result["parent"] + parent_type = parent["type"] + if parent_type == "block_id": + parent_id = self.notion_block_parent_page_id(access_token, parent[parent_type]) + elif parent_type == "workspace": + parent_id = "root" + else: + parent_id = parent[parent_type] + page = { + "page_id": page_id, + "page_name": page_name, + "page_icon": icon, + "parent_id": parent_id, + "type": "database", + } + pages.append(page) + return pages + + def notion_page_search(self, 
access_token: str): + results = [] + next_cursor = None + has_more = True + + while has_more: + data: dict[str, Any] = { + "filter": {"value": "page", "property": "object"}, + **({"start_cursor": next_cursor} if next_cursor else {}), + } + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {access_token}", + "Notion-Version": "2022-06-28", + } + + response = requests.post(url=self._NOTION_PAGE_SEARCH, json=data, headers=headers) + response_json = response.json() + + results.extend(response_json.get("results", [])) + + has_more = response_json.get("has_more", False) + next_cursor = response_json.get("next_cursor", None) + + return results + + def notion_block_parent_page_id(self, access_token: str, block_id: str): + headers = { + "Authorization": f"Bearer {access_token}", + "Notion-Version": "2022-06-28", + } + response = requests.get(url=f"{self._NOTION_BLOCK_SEARCH}/{block_id}", headers=headers) + response_json = response.json() + if response.status_code != 200: + message = response_json.get("message", "unknown error") + raise ValueError(f"Error fetching block parent page ID: {message}") + parent = response_json["parent"] + parent_type = parent["type"] + if parent_type == "block_id": + return self.notion_block_parent_page_id(access_token, parent[parent_type]) + return parent[parent_type] + + def notion_workspace_name(self, access_token: str): + headers = { + "Authorization": f"Bearer {access_token}", + "Notion-Version": "2022-06-28", + } + response = requests.get(url=self._NOTION_BOT_USER, headers=headers) + response_json = response.json() + if "object" in response_json and response_json["object"] == "user": + user_type = response_json["type"] + user_info = response_json[user_type] + if "workspace_name" in user_info: + return user_info["workspace_name"] + return "workspace" + + def notion_database_search(self, access_token: str): + results = [] + next_cursor = None + has_more = True + + while has_more: + data: dict[str, Any] = { + 
"filter": {"value": "database", "property": "object"}, + **({"start_cursor": next_cursor} if next_cursor else {}), + } + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {access_token}", + "Notion-Version": "2022-06-28", + } + response = requests.post(url=self._NOTION_PAGE_SEARCH, json=data, headers=headers) + response_json = response.json() + + results.extend(response_json.get("results", [])) + + has_more = response_json.get("has_more", False) + next_cursor = response_json.get("next_cursor", None) + + return results diff --git a/api/libs/passport.py b/api/libs/passport.py new file mode 100644 index 0000000000000000000000000000000000000000..8df4f529bc389830c6db7d60389834a87b9403ab --- /dev/null +++ b/api/libs/passport.py @@ -0,0 +1,22 @@ +import jwt +from werkzeug.exceptions import Unauthorized + +from configs import dify_config + + +class PassportService: + def __init__(self): + self.sk = dify_config.SECRET_KEY + + def issue(self, payload): + return jwt.encode(payload, self.sk, algorithm="HS256") + + def verify(self, token): + try: + return jwt.decode(token, self.sk, algorithms=["HS256"]) + except jwt.exceptions.InvalidSignatureError: + raise Unauthorized("Invalid token signature.") + except jwt.exceptions.DecodeError: + raise Unauthorized("Invalid token.") + except jwt.exceptions.ExpiredSignatureError: + raise Unauthorized("Token has expired.") diff --git a/api/libs/password.py b/api/libs/password.py new file mode 100644 index 0000000000000000000000000000000000000000..cdf55c57e5bc6eeb272a7db87d19657d9e9d8feb --- /dev/null +++ b/api/libs/password.py @@ -0,0 +1,26 @@ +import base64 +import binascii +import hashlib +import re + +password_pattern = r"^(?=.*[a-zA-Z])(?=.*\d).{8,}$" + + +def valid_password(password): + # Define a regex pattern for password rules + pattern = password_pattern + # Check if the password matches the pattern + if re.match(pattern, password) is not None: + return password + + raise ValueError("Password must 
contain letters and numbers, and the length must be greater than 8.") + + +def hash_password(password_str, salt_byte): + dk = hashlib.pbkdf2_hmac("sha256", password_str.encode("utf-8"), salt_byte, 10000) + return binascii.hexlify(dk) + + +def compare_password(password_str, password_hashed_base64, salt_base64): + # compare password for login + return hash_password(password_str, base64.b64decode(salt_base64)) == base64.b64decode(password_hashed_base64) diff --git a/api/libs/rsa.py b/api/libs/rsa.py new file mode 100644 index 0000000000000000000000000000000000000000..637bcc4a1dda6177745ceeee4b3754d25b892b0a --- /dev/null +++ b/api/libs/rsa.py @@ -0,0 +1,93 @@ +import hashlib + +from Crypto.Cipher import AES +from Crypto.PublicKey import RSA +from Crypto.Random import get_random_bytes + +from extensions.ext_redis import redis_client +from extensions.ext_storage import storage +from libs import gmpy2_pkcs10aep_cipher + + +def generate_key_pair(tenant_id): + private_key = RSA.generate(2048) + public_key = private_key.publickey() + + pem_private = private_key.export_key() + pem_public = public_key.export_key() + + filepath = "privkeys/{tenant_id}".format(tenant_id=tenant_id) + "/private.pem" + + storage.save(filepath, pem_private) + + return pem_public.decode() + + +prefix_hybrid = b"HYBRID:" + + +def encrypt(text, public_key): + if isinstance(public_key, str): + public_key = public_key.encode() + + aes_key = get_random_bytes(16) + cipher_aes = AES.new(aes_key, AES.MODE_EAX) + + ciphertext, tag = cipher_aes.encrypt_and_digest(text.encode()) + + rsa_key = RSA.import_key(public_key) + cipher_rsa = gmpy2_pkcs10aep_cipher.new(rsa_key) + + enc_aes_key = cipher_rsa.encrypt(aes_key) + + encrypted_data = enc_aes_key + cipher_aes.nonce + tag + ciphertext + + return prefix_hybrid + encrypted_data + + +def get_decrypt_decoding(tenant_id): + filepath = "privkeys/{tenant_id}".format(tenant_id=tenant_id) + "/private.pem" + + cache_key = 
"tenant_privkey:{hash}".format(hash=hashlib.sha3_256(filepath.encode()).hexdigest()) + private_key = redis_client.get(cache_key) + if not private_key: + try: + private_key = storage.load(filepath) + except FileNotFoundError: + raise PrivkeyNotFoundError("Private key not found, tenant_id: {tenant_id}".format(tenant_id=tenant_id)) + + redis_client.setex(cache_key, 120, private_key) + + rsa_key = RSA.import_key(private_key) + cipher_rsa = gmpy2_pkcs10aep_cipher.new(rsa_key) + + return rsa_key, cipher_rsa + + +def decrypt_token_with_decoding(encrypted_text, rsa_key, cipher_rsa): + if encrypted_text.startswith(prefix_hybrid): + encrypted_text = encrypted_text[len(prefix_hybrid) :] + + enc_aes_key = encrypted_text[: rsa_key.size_in_bytes()] + nonce = encrypted_text[rsa_key.size_in_bytes() : rsa_key.size_in_bytes() + 16] + tag = encrypted_text[rsa_key.size_in_bytes() + 16 : rsa_key.size_in_bytes() + 32] + ciphertext = encrypted_text[rsa_key.size_in_bytes() + 32 :] + + aes_key = cipher_rsa.decrypt(enc_aes_key) + + cipher_aes = AES.new(aes_key, AES.MODE_EAX, nonce=nonce) + decrypted_text = cipher_aes.decrypt_and_verify(ciphertext, tag) + else: + decrypted_text = cipher_rsa.decrypt(encrypted_text) + + return decrypted_text.decode() + + +def decrypt(encrypted_text, tenant_id): + rsa_key, cipher_rsa = get_decrypt_decoding(tenant_id) + + return decrypt_token_with_decoding(encrypted_text, rsa_key, cipher_rsa) + + +class PrivkeyNotFoundError(Exception): + pass diff --git a/api/libs/smtp.py b/api/libs/smtp.py new file mode 100644 index 0000000000000000000000000000000000000000..2325d69a413d4a89d4c49c9b6a992f69464e25ab --- /dev/null +++ b/api/libs/smtp.py @@ -0,0 +1,52 @@ +import logging +import smtplib +from email.mime.multipart import MIMEMultipart +from email.mime.text import MIMEText + + +class SMTPClient: + def __init__( + self, server: str, port: int, username: str, password: str, _from: str, use_tls=False, opportunistic_tls=False + ): + self.server = server + self.port = port 
+ self._from = _from + self.username = username + self.password = password + self.use_tls = use_tls + self.opportunistic_tls = opportunistic_tls + + def send(self, mail: dict): + smtp = None + try: + if self.use_tls: + if self.opportunistic_tls: + smtp = smtplib.SMTP(self.server, self.port, timeout=10) + smtp.starttls() + else: + smtp = smtplib.SMTP_SSL(self.server, self.port, timeout=10) + else: + smtp = smtplib.SMTP(self.server, self.port, timeout=10) + + if self.username and self.password: + smtp.login(self.username, self.password) + + msg = MIMEMultipart() + msg["Subject"] = mail["subject"] + msg["From"] = self._from + msg["To"] = mail["to"] + msg.attach(MIMEText(mail["html"], "html")) + + smtp.sendmail(self._from, mail["to"], msg.as_string()) + except smtplib.SMTPException as e: + logging.exception("SMTP error occurred") + raise + except TimeoutError as e: + logging.exception("Timeout occurred while sending email") + raise + except Exception as e: + logging.exception(f"Unexpected error occurred while sending email to {mail['to']}") + raise + finally: + if smtp: + smtp.quit() diff --git a/api/migrations/README b/api/migrations/README new file mode 100644 index 0000000000000000000000000000000000000000..0e048441597444a7e2850d6d7c4ce15550f79bda --- /dev/null +++ b/api/migrations/README @@ -0,0 +1 @@ +Single-database configuration for Flask. diff --git a/api/migrations/alembic.ini b/api/migrations/alembic.ini new file mode 100644 index 0000000000000000000000000000000000000000..aa21ecabcddd27eef9415fb8b25082402806968f --- /dev/null +++ b/api/migrations/alembic.ini @@ -0,0 +1,51 @@ +# A generic, single database configuration. 
+ +[alembic] +# template used to generate migration files +# file_template = %%(rev)s_%%(slug)s +file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s + +# set to 'true' to run the environment during +# the 'revision' command, regardless of autogenerate +# revision_environment = false + + +# Logging configuration +[loggers] +keys = root,sqlalchemy,alembic,flask_migrate + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARN +handlers = console +qualname = + +[logger_sqlalchemy] +level = WARN +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[logger_flask_migrate] +level = INFO +handlers = +qualname = flask_migrate + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S diff --git a/api/migrations/env.py b/api/migrations/env.py new file mode 100644 index 0000000000000000000000000000000000000000..ad3a122c04bc2d266a45165deea4b7df19ef7055 --- /dev/null +++ b/api/migrations/env.py @@ -0,0 +1,113 @@ +import logging +from logging.config import fileConfig + +from alembic import context +from flask import current_app + +# this is the Alembic Config object, which provides +# access to the values within the .ini file in use. +config = context.config + +# Interpret the config file for Python logging. +# This line sets up loggers basically. 
+fileConfig(config.config_file_name) +logger = logging.getLogger('alembic.env') + + +def get_engine(): + return current_app.extensions['migrate'].db.engine + + +def get_engine_url(): + try: + return get_engine().url.render_as_string(hide_password=False).replace( + '%', '%%') + except AttributeError: + return str(get_engine().url).replace('%', '%%') + + +# add your model's MetaData object here +# for 'autogenerate' support +# from myapp import mymodel +# target_metadata = mymodel.Base.metadata +config.set_main_option('sqlalchemy.url', get_engine_url()) +target_db = current_app.extensions['migrate'].db + +# other values from the config, defined by the needs of env.py, +# can be acquired: +# my_important_option = config.get_main_option("my_important_option") +# ... etc. + + +def get_metadata(): + if hasattr(target_db, 'metadatas'): + return target_db.metadatas[None] + return target_db.metadata + + +def include_object(object, name, type_, reflected, compare_to): + if type_ == "foreign_key_constraint": + return False + else: + return True + + +def run_migrations_offline(): + """Run migrations in 'offline' mode. + + This configures the context with just a URL + and not an Engine, though an Engine is acceptable + here as well. By skipping the Engine creation + we don't even need a DBAPI to be available. + + Calls to context.execute() here emit the given string to the + script output. + + """ + url = config.get_main_option("sqlalchemy.url") + context.configure( + url=url, target_metadata=get_metadata(), literal_binds=True + ) + + with context.begin_transaction(): + context.run_migrations() + + +def run_migrations_online(): + """Run migrations in 'online' mode. + + In this scenario we need to create an Engine + and associate a connection with the context. 
+ + """ + + # this callback is used to prevent an auto-migration from being generated + # when there are no changes to the schema + # reference: http://alembic.zzzcomputing.com/en/latest/cookbook.html + def process_revision_directives(context, revision, directives): + if getattr(config.cmd_opts, 'autogenerate', False): + script = directives[0] + if script.upgrade_ops.is_empty(): + directives[:] = [] + logger.info('No changes in schema detected.') + + connectable = get_engine() + + with connectable.connect() as connection: + context.configure( + connection=connection, + target_metadata=get_metadata(), + process_revision_directives=process_revision_directives, + include_object=include_object, + **current_app.extensions['migrate'].configure_args + ) + + with context.begin_transaction(): + context.run_migrations() + + +if context.is_offline_mode(): + run_migrations_offline() +else: + run_migrations_online() + diff --git a/api/migrations/script.py.mako b/api/migrations/script.py.mako new file mode 100644 index 0000000000000000000000000000000000000000..728ccc6a9a530dc907554d9899d6df09279d5fba --- /dev/null +++ b/api/migrations/script.py.mako @@ -0,0 +1,25 @@ +"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} + +""" +from alembic import op +import models as models +import sqlalchemy as sa +${imports if imports else ""} + +# revision identifiers, used by Alembic. 
+revision = ${repr(up_revision)} +down_revision = ${repr(down_revision)} +branch_labels = ${repr(branch_labels)} +depends_on = ${repr(depends_on)} + + +def upgrade(): + ${upgrades if upgrades else "pass"} + + +def downgrade(): + ${downgrades if downgrades else "pass"} diff --git a/api/migrations/versions/00bacef91f18_rename_api_provider_description.py b/api/migrations/versions/00bacef91f18_rename_api_provider_description.py new file mode 100644 index 0000000000000000000000000000000000000000..5ae9e8769a21e5c71dc119ecb0d6ca3c1bff1540 --- /dev/null +++ b/api/migrations/versions/00bacef91f18_rename_api_provider_description.py @@ -0,0 +1,33 @@ +"""rename api provider description + +Revision ID: 00bacef91f18 +Revises: 8ec536f3c800 +Create Date: 2024-01-07 04:07:34.482983 + +""" +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = '00bacef91f18' +down_revision = '8ec536f3c800' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('tool_api_providers', schema=None) as batch_op: + batch_op.add_column(sa.Column('description', sa.Text(), nullable=False)) + batch_op.drop_column('description_str') + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('tool_api_providers', schema=None) as batch_op: + batch_op.add_column(sa.Column('description_str', sa.TEXT(), autoincrement=False, nullable=False)) + batch_op.drop_column('description') + + # ### end Alembic commands ### diff --git a/api/migrations/versions/03f98355ba0e_add_workflow_tool_label_and_tool_.py b/api/migrations/versions/03f98355ba0e_add_workflow_tool_label_and_tool_.py new file mode 100644 index 0000000000000000000000000000000000000000..8cd4ec552b4ea9e9867799513fc77e33638a2764 --- /dev/null +++ b/api/migrations/versions/03f98355ba0e_add_workflow_tool_label_and_tool_.py @@ -0,0 +1,33 @@ +"""add workflow tool label and tool bindings idx + +Revision ID: 03f98355ba0e +Revises: 9e98fbaffb88 +Create Date: 2024-05-25 07:17:00.539125 + +""" +import sqlalchemy as sa +from alembic import op + +import models as models + +# revision identifiers, used by Alembic. +revision = '03f98355ba0e' +down_revision = '9e98fbaffb88' +branch_labels = None +depends_on = None + + +def upgrade(): + with op.batch_alter_table('tool_label_bindings', schema=None) as batch_op: + batch_op.create_unique_constraint('unique_tool_label_bind', ['tool_id', 'label_name']) + + with op.batch_alter_table('tool_workflow_providers', schema=None) as batch_op: + batch_op.add_column(sa.Column('label', sa.String(length=255), server_default='', nullable=False)) + + +def downgrade(): + with op.batch_alter_table('tool_workflow_providers', schema=None) as batch_op: + batch_op.drop_column('label') + + with op.batch_alter_table('tool_label_bindings', schema=None) as batch_op: + batch_op.drop_constraint('unique_tool_label_bind', type_='unique') diff --git a/api/migrations/versions/04c602f5dc9b_update_appmodelconfig_and_add_table_.py b/api/migrations/versions/04c602f5dc9b_update_appmodelconfig_and_add_table_.py new file mode 100644 index 0000000000000000000000000000000000000000..153861a71a59948617f18191c75f63d191f0f3bd --- /dev/null +++ 
b/api/migrations/versions/04c602f5dc9b_update_appmodelconfig_and_add_table_.py @@ -0,0 +1,39 @@ +"""update AppModelConfig and add table TracingAppConfig + +Revision ID: 04c602f5dc9b +Revises: 4ff534e1eb11 +Create Date: 2024-06-12 07:49:07.666510 + +""" +import sqlalchemy as sa +from alembic import op + +import models.types + +# revision identifiers, used by Alembic. +revision = '04c602f5dc9b' +down_revision = '4ff534e1eb11' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('tracing_app_configs', + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('app_id', models.types.StringUUID(), nullable=False), + sa.Column('tracing_provider', sa.String(length=255), nullable=True), + sa.Column('tracing_config', sa.JSON(), nullable=True), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('now()'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('now()'), nullable=False), + sa.PrimaryKeyConstraint('id', name='tracing_app_config_pkey') + ) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table('tracing_app_configs') + + # ### end Alembic commands ### diff --git a/api/migrations/versions/053da0c1d756_add_api_tool_privacy.py b/api/migrations/versions/053da0c1d756_add_api_tool_privacy.py new file mode 100644 index 0000000000000000000000000000000000000000..a589f1f08b099c771d6c6ce008cbd059f4a518cf --- /dev/null +++ b/api/migrations/versions/053da0c1d756_add_api_tool_privacy.py @@ -0,0 +1,51 @@ +"""add api tool privacy + +Revision ID: 053da0c1d756 +Revises: 4829e54d2fee +Create Date: 2024-01-12 06:47:21.656262 + +""" +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. 
+revision = '053da0c1d756' +down_revision = '4829e54d2fee' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('tool_conversation_variables', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('user_id', postgresql.UUID(), nullable=False), + sa.Column('tenant_id', postgresql.UUID(), nullable=False), + sa.Column('conversation_id', postgresql.UUID(), nullable=False), + sa.Column('variables_str', sa.Text(), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='tool_conversation_variables_pkey') + ) + with op.batch_alter_table('tool_api_providers', schema=None) as batch_op: + batch_op.add_column(sa.Column('privacy_policy', sa.String(length=255), nullable=True)) + batch_op.alter_column('icon', + existing_type=sa.VARCHAR(length=256), + type_=sa.String(length=255), + existing_nullable=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('tool_api_providers', schema=None) as batch_op: + batch_op.alter_column('icon', + existing_type=sa.String(length=255), + type_=sa.VARCHAR(length=256), + existing_nullable=False) + batch_op.drop_column('privacy_policy') + + op.drop_table('tool_conversation_variables') + # ### end Alembic commands ### diff --git a/api/migrations/versions/114eed84c228_remove_tool_id_from_model_invoke.py b/api/migrations/versions/114eed84c228_remove_tool_id_from_model_invoke.py new file mode 100644 index 0000000000000000000000000000000000000000..58863fe3a7b890badda1225e95d0096317863431 --- /dev/null +++ b/api/migrations/versions/114eed84c228_remove_tool_id_from_model_invoke.py @@ -0,0 +1,32 @@ +"""remove tool id from model invoke + +Revision ID: 114eed84c228 +Revises: c71211c8f604 +Create Date: 2024-01-10 04:40:57.257824 + +""" +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '114eed84c228' +down_revision = 'c71211c8f604' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('tool_model_invokes', schema=None) as batch_op: + batch_op.drop_column('tool_id') + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('tool_model_invokes', schema=None) as batch_op: + batch_op.add_column(sa.Column('tool_id', postgresql.UUID(), autoincrement=False, nullable=False)) + + # ### end Alembic commands ### diff --git a/api/migrations/versions/161cadc1af8d_add_dataset_permission_tenant_id.py b/api/migrations/versions/161cadc1af8d_add_dataset_permission_tenant_id.py new file mode 100644 index 0000000000000000000000000000000000000000..8907f781174b6bd7b3194d3c2c207f4757ea17e3 --- /dev/null +++ b/api/migrations/versions/161cadc1af8d_add_dataset_permission_tenant_id.py @@ -0,0 +1,34 @@ +"""add dataset permission tenant id + +Revision ID: 161cadc1af8d +Revises: 7e6a8693e07a +Create Date: 2024-07-05 14:30:59.472593 + +""" +import sqlalchemy as sa +from alembic import op + +import models as models + +# revision identifiers, used by Alembic. +revision = '161cadc1af8d' +down_revision = '7e6a8693e07a' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('dataset_permissions', schema=None) as batch_op: + # Add tenant_id as a NOT NULL column (nullable=False, no server default) + op.add_column('dataset_permissions', sa.Column('tenant_id', sa.UUID(), nullable=False)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('dataset_permissions', schema=None) as batch_op: + batch_op.drop_column('tenant_id') + + # ### end Alembic commands ### diff --git a/api/migrations/versions/16830a790f0f_.py b/api/migrations/versions/16830a790f0f_.py new file mode 100644 index 0000000000000000000000000000000000000000..38d6e4940a0196e9afd4a8908ac057ab922ba781 --- /dev/null +++ b/api/migrations/versions/16830a790f0f_.py @@ -0,0 +1,31 @@ +"""empty message + +Revision ID: 16830a790f0f +Revises: 380c6aa5a70d +Create Date: 2024-02-01 08:21:31.111119 + +""" +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = '16830a790f0f' +down_revision = '380c6aa5a70d' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('tenant_account_joins', schema=None) as batch_op: + batch_op.add_column(sa.Column('current', sa.Boolean(), server_default=sa.text('false'), nullable=False)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('tenant_account_joins', schema=None) as batch_op: + batch_op.drop_column('current') + + # ### end Alembic commands ### diff --git a/api/migrations/versions/16fa53d9faec_add_provider_model_support.py b/api/migrations/versions/16fa53d9faec_add_provider_model_support.py new file mode 100644 index 0000000000000000000000000000000000000000..6791cf4578332d6234a847c17b09acbf0ff5c574 --- /dev/null +++ b/api/migrations/versions/16fa53d9faec_add_provider_model_support.py @@ -0,0 +1,79 @@ +"""add provider model support + +Revision ID: 16fa53d9faec +Revises: 8d2d099ceb74 +Create Date: 2023-08-06 16:57:51.248337 + +""" +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. 
+revision = '16fa53d9faec' +down_revision = '8d2d099ceb74' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('provider_models', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', postgresql.UUID(), nullable=False), + sa.Column('provider_name', sa.String(length=40), nullable=False), + sa.Column('model_name', sa.String(length=40), nullable=False), + sa.Column('model_type', sa.String(length=40), nullable=False), + sa.Column('encrypted_config', sa.Text(), nullable=True), + sa.Column('is_valid', sa.Boolean(), server_default=sa.text('false'), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='provider_model_pkey'), + sa.UniqueConstraint('tenant_id', 'provider_name', 'model_name', 'model_type', name='unique_provider_model_name') + ) + with op.batch_alter_table('provider_models', schema=None) as batch_op: + batch_op.create_index('provider_model_tenant_id_provider_idx', ['tenant_id', 'provider_name'], unique=False) + + op.create_table('tenant_default_models', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', postgresql.UUID(), nullable=False), + sa.Column('provider_name', sa.String(length=40), nullable=False), + sa.Column('model_name', sa.String(length=40), nullable=False), + sa.Column('model_type', sa.String(length=40), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='tenant_default_model_pkey') + ) + with 
op.batch_alter_table('tenant_default_models', schema=None) as batch_op: + batch_op.create_index('tenant_default_model_tenant_id_provider_type_idx', ['tenant_id', 'provider_name', 'model_type'], unique=False) + + op.create_table('tenant_preferred_model_providers', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', postgresql.UUID(), nullable=False), + sa.Column('provider_name', sa.String(length=40), nullable=False), + sa.Column('preferred_provider_type', sa.String(length=40), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='tenant_preferred_model_provider_pkey') + ) + with op.batch_alter_table('tenant_preferred_model_providers', schema=None) as batch_op: + batch_op.create_index('tenant_preferred_model_provider_tenant_provider_idx', ['tenant_id', 'provider_name'], unique=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('tenant_preferred_model_providers', schema=None) as batch_op: + batch_op.drop_index('tenant_preferred_model_provider_tenant_provider_idx') + + op.drop_table('tenant_preferred_model_providers') + with op.batch_alter_table('tenant_default_models', schema=None) as batch_op: + batch_op.drop_index('tenant_default_model_tenant_id_provider_type_idx') + + op.drop_table('tenant_default_models') + with op.batch_alter_table('provider_models', schema=None) as batch_op: + batch_op.drop_index('provider_model_tenant_id_provider_idx') + + op.drop_table('provider_models') + # ### end Alembic commands ### diff --git a/api/migrations/versions/17b5ab037c40_add_keyworg_table_storage_type.py b/api/migrations/versions/17b5ab037c40_add_keyworg_table_storage_type.py new file mode 100644 index 0000000000000000000000000000000000000000..77071484892cb1ffdc2c872a5c4f18956001769e --- /dev/null +++ b/api/migrations/versions/17b5ab037c40_add_keyworg_table_storage_type.py @@ -0,0 +1,33 @@ +"""add-keyworg-table-storage-type + +Revision ID: 17b5ab037c40 +Revises: a8f9b3c45e4a +Create Date: 2024-04-01 09:48:54.232201 + +""" +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = '17b5ab037c40' +down_revision = 'a8f9b3c45e4a' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + + with op.batch_alter_table('dataset_keyword_tables', schema=None) as batch_op: + batch_op.add_column(sa.Column('data_source_type', sa.String(length=255), server_default=sa.text("'database'::character varying"), nullable=False)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + + with op.batch_alter_table('dataset_keyword_tables', schema=None) as batch_op: + batch_op.drop_column('data_source_type') + + # ### end Alembic commands ### diff --git a/api/migrations/versions/187385f442fc_modify_provider_model_name_length.py b/api/migrations/versions/187385f442fc_modify_provider_model_name_length.py new file mode 100644 index 0000000000000000000000000000000000000000..13a823f7ec6373b34530336c58e6360ceb07b2c3 --- /dev/null +++ b/api/migrations/versions/187385f442fc_modify_provider_model_name_length.py @@ -0,0 +1,37 @@ +"""modify provider model name length + +Revision ID: 187385f442fc +Revises: 88072f0caa04 +Create Date: 2024-01-02 07:18:43.887428 + +""" +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = '187385f442fc' +down_revision = '88072f0caa04' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('provider_models', schema=None) as batch_op: + batch_op.alter_column('model_name', + existing_type=sa.VARCHAR(length=40), + type_=sa.String(length=255), + existing_nullable=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('provider_models', schema=None) as batch_op: + batch_op.alter_column('model_name', + existing_type=sa.String(length=255), + type_=sa.VARCHAR(length=40), + existing_nullable=False) + + # ### end Alembic commands ### diff --git a/api/migrations/versions/2024_08_09_0801-1787fbae959a_update_tools_original_url_length.py b/api/migrations/versions/2024_08_09_0801-1787fbae959a_update_tools_original_url_length.py new file mode 100644 index 0000000000000000000000000000000000000000..db966252f1a63c49fe335cfeb13c08c10215527b --- /dev/null +++ b/api/migrations/versions/2024_08_09_0801-1787fbae959a_update_tools_original_url_length.py @@ -0,0 +1,39 @@ +"""update tools original_url length + +Revision ID: 1787fbae959a +Revises: eeb2e349e6ac +Create Date: 2024-08-09 08:01:12.817620 + +""" +import sqlalchemy as sa +from alembic import op + +import models as models + +# revision identifiers, used by Alembic. +revision = '1787fbae959a' +down_revision = 'eeb2e349e6ac' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('tool_files', schema=None) as batch_op: + batch_op.alter_column('original_url', + existing_type=sa.VARCHAR(length=255), + type_=sa.String(length=2048), + existing_nullable=True) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('tool_files', schema=None) as batch_op: + batch_op.alter_column('original_url', + existing_type=sa.String(length=2048), + type_=sa.VARCHAR(length=255), + existing_nullable=True) + + # ### end Alembic commands ### diff --git a/api/migrations/versions/2024_08_13_0633-63a83fcf12ba_support_conversation_variables.py b/api/migrations/versions/2024_08_13_0633-63a83fcf12ba_support_conversation_variables.py new file mode 100644 index 0000000000000000000000000000000000000000..16e1efd4efd4ed07ef2d293f69ab2b456a3f5af9 --- /dev/null +++ b/api/migrations/versions/2024_08_13_0633-63a83fcf12ba_support_conversation_variables.py @@ -0,0 +1,51 @@ +"""support conversation variables + +Revision ID: 63a83fcf12ba +Revises: 1787fbae959a +Create Date: 2024-08-13 06:33:07.950379 + +""" +import sqlalchemy as sa +from alembic import op + +import models as models + +# revision identifiers, used by Alembic. +revision = '63a83fcf12ba' +down_revision = '1787fbae959a' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + op.create_table('workflow__conversation_variables', + sa.Column('id', models.types.StringUUID(), nullable=False), + sa.Column('conversation_id', models.types.StringUUID(), nullable=False), + sa.Column('app_id', models.types.StringUUID(), nullable=False), + sa.Column('data', sa.Text(), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.PrimaryKeyConstraint('id', 'conversation_id', name=op.f('workflow__conversation_variables_pkey')) + ) + with op.batch_alter_table('workflow__conversation_variables', schema=None) as batch_op: + batch_op.create_index(batch_op.f('workflow__conversation_variables_app_id_idx'), ['app_id'], unique=False) + batch_op.create_index(batch_op.f('workflow__conversation_variables_created_at_idx'), ['created_at'], unique=False) + + with op.batch_alter_table('workflows', schema=None) as batch_op: + batch_op.add_column(sa.Column('conversation_variables', sa.Text(), server_default='{}', nullable=False)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('workflows', schema=None) as batch_op: + batch_op.drop_column('conversation_variables') + + with op.batch_alter_table('workflow__conversation_variables', schema=None) as batch_op: + batch_op.drop_index(batch_op.f('workflow__conversation_variables_created_at_idx')) + batch_op.drop_index(batch_op.f('workflow__conversation_variables_app_id_idx')) + + op.drop_table('workflow__conversation_variables') + # ### end Alembic commands ### diff --git a/api/migrations/versions/2024_08_14_1354-8782057ff0dc_add_conversations_dialogue_count.py b/api/migrations/versions/2024_08_14_1354-8782057ff0dc_add_conversations_dialogue_count.py new file mode 100644 index 0000000000000000000000000000000000000000..eba78e2e77d5d85a920814d20d9ad8a39e625c04 --- /dev/null +++ b/api/migrations/versions/2024_08_14_1354-8782057ff0dc_add_conversations_dialogue_count.py @@ -0,0 +1,33 @@ +"""add conversations.dialogue_count + +Revision ID: 8782057ff0dc +Revises: 63a83fcf12ba +Create Date: 2024-08-14 13:54:25.161324 + +""" +import sqlalchemy as sa +from alembic import op + +import models as models + +# revision identifiers, used by Alembic. +revision = '8782057ff0dc' +down_revision = '63a83fcf12ba' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('conversations', schema=None) as batch_op: + batch_op.add_column(sa.Column('dialogue_count', sa.Integer(), server_default='0', nullable=False)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('conversations', schema=None) as batch_op: + batch_op.drop_column('dialogue_count') + + # ### end Alembic commands ### diff --git a/api/migrations/versions/2024_08_15_0956-0251a1c768cc_add_tidb_auth_binding.py b/api/migrations/versions/2024_08_15_0956-0251a1c768cc_add_tidb_auth_binding.py new file mode 100644 index 0000000000000000000000000000000000000000..ca2e4104426275e38b1ea7ca38974f590b9b3d39 --- /dev/null +++ b/api/migrations/versions/2024_08_15_0956-0251a1c768cc_add_tidb_auth_binding.py @@ -0,0 +1,51 @@ +"""add-tidb-auth-binding + +Revision ID: 0251a1c768cc +Revises: bbadea11becb +Create Date: 2024-08-15 09:56:59.012490 + +""" +import sqlalchemy as sa +from alembic import op + +import models as models + +# revision identifiers, used by Alembic. +revision = '0251a1c768cc' +down_revision = 'bbadea11becb' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('tidb_auth_bindings', + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', models.types.StringUUID(), nullable=True), + sa.Column('cluster_id', sa.String(length=255), nullable=False), + sa.Column('cluster_name', sa.String(length=255), nullable=False), + sa.Column('active', sa.Boolean(), server_default=sa.text('false'), nullable=False), + sa.Column('status', sa.String(length=255), server_default=sa.text("'CREATING'::character varying"), nullable=False), + sa.Column('account', sa.String(length=255), nullable=False), + sa.Column('password', sa.String(length=255), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='tidb_auth_bindings_pkey') + ) + with op.batch_alter_table('tidb_auth_bindings', schema=None) as batch_op: + batch_op.create_index('tidb_auth_bindings_active_idx', ['active'], unique=False) + 
batch_op.create_index('tidb_auth_bindings_status_idx', ['status'], unique=False) + batch_op.create_index('tidb_auth_bindings_created_at_idx', ['created_at'], unique=False) + batch_op.create_index('tidb_auth_bindings_tenant_idx', ['tenant_id'], unique=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('tidb_auth_bindings', schema=None) as batch_op: + batch_op.drop_index('tidb_auth_bindings_tenant_idx') + batch_op.drop_index('tidb_auth_bindings_created_at_idx') + batch_op.drop_index('tidb_auth_bindings_active_idx') + batch_op.drop_index('tidb_auth_bindings_status_idx') + op.drop_table('tidb_auth_bindings') + # ### end Alembic commands ### diff --git a/api/migrations/versions/2024_08_15_1001-a6be81136580_app_and_site_icon_type.py b/api/migrations/versions/2024_08_15_1001-a6be81136580_app_and_site_icon_type.py new file mode 100644 index 0000000000000000000000000000000000000000..d814666eefd2f2ced646b5e1f66e63712d3d0d32 --- /dev/null +++ b/api/migrations/versions/2024_08_15_1001-a6be81136580_app_and_site_icon_type.py @@ -0,0 +1,39 @@ +"""app and site icon type + +Revision ID: a6be81136580 +Revises: 8782057ff0dc +Create Date: 2024-08-15 10:01:24.697888 + +""" +import sqlalchemy as sa +from alembic import op + +import models as models + +# revision identifiers, used by Alembic. +revision = 'a6be81136580' +down_revision = '8782057ff0dc' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('apps', schema=None) as batch_op: + batch_op.add_column(sa.Column('icon_type', sa.String(length=255), nullable=True)) + + with op.batch_alter_table('sites', schema=None) as batch_op: + batch_op.add_column(sa.Column('icon_type', sa.String(length=255), nullable=True)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('sites', schema=None) as batch_op: + batch_op.drop_column('icon_type') + + with op.batch_alter_table('apps', schema=None) as batch_op: + batch_op.drop_column('icon_type') + + # ### end Alembic commands ### diff --git a/api/migrations/versions/2024_08_20_0455-2dbe42621d96_rename_workflow__conversation_variables_.py b/api/migrations/versions/2024_08_20_0455-2dbe42621d96_rename_workflow__conversation_variables_.py new file mode 100644 index 0000000000000000000000000000000000000000..3dc7fed818ea2b02a641047f432efa3bdcaf5334 --- /dev/null +++ b/api/migrations/versions/2024_08_20_0455-2dbe42621d96_rename_workflow__conversation_variables_.py @@ -0,0 +1,28 @@ +"""rename workflow__conversation_variables to workflow_conversation_variables + +Revision ID: 2dbe42621d96 +Revises: a6be81136580 +Create Date: 2024-08-20 04:55:38.160010 + +""" +from alembic import op + +import models as models + +# revision identifiers, used by Alembic. +revision = '2dbe42621d96' +down_revision = 'a6be81136580' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.rename_table('workflow__conversation_variables', 'workflow_conversation_variables') + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + op.rename_table('workflow_conversation_variables', 'workflow__conversation_variables') + # ### end Alembic commands ### diff --git a/api/migrations/versions/2024_08_25_0441-d0187d6a88dd_add_created_by_and_updated_by_to_app_.py b/api/migrations/versions/2024_08_25_0441-d0187d6a88dd_add_created_by_and_updated_by_to_app_.py new file mode 100644 index 0000000000000000000000000000000000000000..e0066a302cd5a2adc51bc07d527d66237bd652db --- /dev/null +++ b/api/migrations/versions/2024_08_25_0441-d0187d6a88dd_add_created_by_and_updated_by_to_app_.py @@ -0,0 +1,52 @@ +"""add created_by and updated_by to app, modelconfig, and site + +Revision ID: d0187d6a88dd +Revises: 2dbe42621d96 +Create Date: 2024-08-25 04:41:18.157397 + +""" + +import sqlalchemy as sa +from alembic import op + +import models as models + +# revision identifiers, used by Alembic. +revision = "d0187d6a88dd" +down_revision = "2dbe42621d96" +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("app_model_configs", schema=None) as batch_op: + batch_op.add_column(sa.Column("created_by", models.types.StringUUID(), nullable=True)) + batch_op.add_column(sa.Column("updated_by", models.types.StringUUID(), nullable=True)) + + with op.batch_alter_table("apps", schema=None) as batch_op: + batch_op.add_column(sa.Column("created_by", models.types.StringUUID(), nullable=True)) + batch_op.add_column(sa.Column("updated_by", models.types.StringUUID(), nullable=True)) + + with op.batch_alter_table("sites", schema=None) as batch_op: + batch_op.add_column(sa.Column("created_by", models.types.StringUUID(), nullable=True)) + batch_op.add_column(sa.Column("updated_by", models.types.StringUUID(), nullable=True)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table("sites", schema=None) as batch_op: + batch_op.drop_column("updated_by") + batch_op.drop_column("created_by") + + with op.batch_alter_table("apps", schema=None) as batch_op: + batch_op.drop_column("updated_by") + batch_op.drop_column("created_by") + + with op.batch_alter_table("app_model_configs", schema=None) as batch_op: + batch_op.drop_column("updated_by") + batch_op.drop_column("created_by") + + # ### end Alembic commands ### diff --git a/api/migrations/versions/2024_09_01_1255-030f4915f36a_add_use_icon_as_answer_icon_fields_for_.py b/api/migrations/versions/2024_09_01_1255-030f4915f36a_add_use_icon_as_answer_icon_fields_for_.py new file mode 100644 index 0000000000000000000000000000000000000000..4406d51ed07aa250a17e8b924016193d613bbd03 --- /dev/null +++ b/api/migrations/versions/2024_09_01_1255-030f4915f36a_add_use_icon_as_answer_icon_fields_for_.py @@ -0,0 +1,45 @@ +"""add use_icon_as_answer_icon fields for app and site + +Revision ID: 030f4915f36a +Revises: d0187d6a88dd +Create Date: 2024-09-01 12:55:45.129687 + +""" + +import sqlalchemy as sa +from alembic import op + +import models as models + +# revision identifiers, used by Alembic. +revision = "030f4915f36a" +down_revision = "d0187d6a88dd" +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("apps", schema=None) as batch_op: + batch_op.add_column( + sa.Column("use_icon_as_answer_icon", sa.Boolean(), server_default=sa.text("false"), nullable=False) + ) + + with op.batch_alter_table("sites", schema=None) as batch_op: + batch_op.add_column( + sa.Column("use_icon_as_answer_icon", sa.Boolean(), server_default=sa.text("false"), nullable=False) + ) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + + with op.batch_alter_table("sites", schema=None) as batch_op: + batch_op.drop_column("use_icon_as_answer_icon") + + with op.batch_alter_table("apps", schema=None) as batch_op: + batch_op.drop_column("use_icon_as_answer_icon") + + # ### end Alembic commands ### diff --git a/api/migrations/versions/2024_09_11_1012-d57ba9ebb251_add_parent_message_id_to_messages.py b/api/migrations/versions/2024_09_11_1012-d57ba9ebb251_add_parent_message_id_to_messages.py new file mode 100644 index 0000000000000000000000000000000000000000..fd957eeafb2b6c9a393855624b885cda9bd91c2b --- /dev/null +++ b/api/migrations/versions/2024_09_11_1012-d57ba9ebb251_add_parent_message_id_to_messages.py @@ -0,0 +1,36 @@ +"""add parent_message_id to messages + +Revision ID: d57ba9ebb251 +Revises: 675b5321501b +Create Date: 2024-09-11 10:12:45.826265 + +""" +import sqlalchemy as sa +from alembic import op + +import models as models + +# revision identifiers, used by Alembic. +revision = 'd57ba9ebb251' +down_revision = '675b5321501b' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('messages', schema=None) as batch_op: + batch_op.add_column(sa.Column('parent_message_id', models.types.StringUUID(), nullable=True)) + + # Set parent_message_id for existing messages to uuid_nil() to distinguish them from new messages with actual parent IDs or NULLs + op.execute('UPDATE messages SET parent_message_id = uuid_nil() WHERE parent_message_id IS NULL') + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('messages', schema=None) as batch_op: + batch_op.drop_column('parent_message_id') + + # ### end Alembic commands ### diff --git a/api/migrations/versions/2024_09_24_0922-6af6a521a53e_update_retrieval_resource.py b/api/migrations/versions/2024_09_24_0922-6af6a521a53e_update_retrieval_resource.py new file mode 100644 index 0000000000000000000000000000000000000000..5337b340db7690f6ee0d13483f4ab2cb71205438 --- /dev/null +++ b/api/migrations/versions/2024_09_24_0922-6af6a521a53e_update_retrieval_resource.py @@ -0,0 +1,48 @@ +"""update-retrieval-resource + +Revision ID: 6af6a521a53e +Revises: d57ba9ebb251 +Create Date: 2024-09-24 09:22:43.570120 + +""" +from alembic import op +import models as models +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '6af6a521a53e' +down_revision = 'd57ba9ebb251' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('dataset_retriever_resources', schema=None) as batch_op: + batch_op.alter_column('document_id', + existing_type=sa.UUID(), + nullable=True) + batch_op.alter_column('data_source_type', + existing_type=sa.TEXT(), + nullable=True) + batch_op.alter_column('segment_id', + existing_type=sa.UUID(), + nullable=True) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('dataset_retriever_resources', schema=None) as batch_op: + batch_op.alter_column('segment_id', + existing_type=sa.UUID(), + nullable=False) + batch_op.alter_column('data_source_type', + existing_type=sa.TEXT(), + nullable=False) + batch_op.alter_column('document_id', + existing_type=sa.UUID(), + nullable=False) + + # ### end Alembic commands ### diff --git a/api/migrations/versions/2024_09_25_0434-33f5fac87f29_external_knowledge_api.py b/api/migrations/versions/2024_09_25_0434-33f5fac87f29_external_knowledge_api.py new file mode 100644 index 0000000000000000000000000000000000000000..3cb76e72c1ebb24b4867ba4eda636275ecad8791 --- /dev/null +++ b/api/migrations/versions/2024_09_25_0434-33f5fac87f29_external_knowledge_api.py @@ -0,0 +1,73 @@ +"""external_knowledge_api + +Revision ID: 33f5fac87f29 +Revises: 6af6a521a53e +Create Date: 2024-09-25 04:34:57.249436 + +""" +from alembic import op +import models as models +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '33f5fac87f29' +down_revision = '6af6a521a53e' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + op.create_table('external_knowledge_apis', + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('name', sa.String(length=255), nullable=False), + sa.Column('description', sa.String(length=255), nullable=False), + sa.Column('tenant_id', models.types.StringUUID(), nullable=False), + sa.Column('settings', sa.Text(), nullable=True), + sa.Column('created_by', models.types.StringUUID(), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('updated_by', models.types.StringUUID(), nullable=True), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='external_knowledge_apis_pkey') + ) + with op.batch_alter_table('external_knowledge_apis', schema=None) as batch_op: + batch_op.create_index('external_knowledge_apis_name_idx', ['name'], unique=False) + batch_op.create_index('external_knowledge_apis_tenant_idx', ['tenant_id'], unique=False) + + op.create_table('external_knowledge_bindings', + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', models.types.StringUUID(), nullable=False), + sa.Column('external_knowledge_api_id', models.types.StringUUID(), nullable=False), + sa.Column('dataset_id', models.types.StringUUID(), nullable=False), + sa.Column('external_knowledge_id', sa.Text(), nullable=False), + sa.Column('created_by', models.types.StringUUID(), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('updated_by', models.types.StringUUID(), nullable=True), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='external_knowledge_bindings_pkey') + ) + with 
op.batch_alter_table('external_knowledge_bindings', schema=None) as batch_op: + batch_op.create_index('external_knowledge_bindings_dataset_idx', ['dataset_id'], unique=False) + batch_op.create_index('external_knowledge_bindings_external_knowledge_api_idx', ['external_knowledge_api_id'], unique=False) + batch_op.create_index('external_knowledge_bindings_external_knowledge_idx', ['external_knowledge_id'], unique=False) + batch_op.create_index('external_knowledge_bindings_tenant_idx', ['tenant_id'], unique=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('external_knowledge_bindings', schema=None) as batch_op: + batch_op.drop_index('external_knowledge_bindings_tenant_idx') + batch_op.drop_index('external_knowledge_bindings_external_knowledge_idx') + batch_op.drop_index('external_knowledge_bindings_external_knowledge_api_idx') + batch_op.drop_index('external_knowledge_bindings_dataset_idx') + + op.drop_table('external_knowledge_bindings') + with op.batch_alter_table('external_knowledge_apis', schema=None) as batch_op: + batch_op.drop_index('external_knowledge_apis_tenant_idx') + batch_op.drop_index('external_knowledge_apis_name_idx') + + op.drop_table('external_knowledge_apis') + # ### end Alembic commands ### diff --git a/api/migrations/versions/2024_10_09_1329-d8e744d88ed6_fix_wrong_service_api_history.py b/api/migrations/versions/2024_10_09_1329-d8e744d88ed6_fix_wrong_service_api_history.py new file mode 100644 index 0000000000000000000000000000000000000000..38a5cdf8e5008f916bfa521e801a33ffbc2b27b5 --- /dev/null +++ b/api/migrations/versions/2024_10_09_1329-d8e744d88ed6_fix_wrong_service_api_history.py @@ -0,0 +1,48 @@ +"""fix wrong service-api history + +Revision ID: d8e744d88ed6 +Revises: 33f5fac87f29 +Create Date: 2024-10-09 13:29:23.548498 + +""" +from alembic import op +from constants import UUID_NIL +import models as models +import sqlalchemy as sa + + +# 
revision identifiers, used by Alembic. +revision = 'd8e744d88ed6' +down_revision = '33f5fac87f29' +branch_labels = None +depends_on = None + +# (UTC) release date of v0.9.0 +v0_9_0_release_date= '2024-09-29 12:00:00' + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + sql = f"""UPDATE + messages +SET + parent_message_id = '{UUID_NIL}' +WHERE + invoke_from = 'service-api' + AND parent_message_id IS NULL + AND created_at >= '{v0_9_0_release_date}';""" + op.execute(sql) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + sql = f"""UPDATE + messages +SET + parent_message_id = NULL +WHERE + invoke_from = 'service-api' + AND parent_message_id = '{UUID_NIL}' + AND created_at >= '{v0_9_0_release_date}';""" + op.execute(sql) + # ### end Alembic commands ### diff --git a/api/migrations/versions/2024_10_10_0516-bbadea11becb_add_name_and_size_to_tool_files.py b/api/migrations/versions/2024_10_10_0516-bbadea11becb_add_name_and_size_to_tool_files.py new file mode 100644 index 0000000000000000000000000000000000000000..c17d1db77a96df418df8520d1d833cb432a761d0 --- /dev/null +++ b/api/migrations/versions/2024_10_10_0516-bbadea11becb_add_name_and_size_to_tool_files.py @@ -0,0 +1,49 @@ +"""add name and size to tool_files + +Revision ID: bbadea11becb +Revises: 33f5fac87f29 +Create Date: 2024-10-10 05:16:14.764268 + +""" +from alembic import op +import models as models +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = 'bbadea11becb' +down_revision = 'd8e744d88ed6' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + # Get the database connection + conn = op.get_bind() + + # Use SQLAlchemy inspector to get the columns of the 'tool_files' table + inspector = sa.inspect(conn) + columns = [col['name'] for col in inspector.get_columns('tool_files')] + + # If 'name' or 'size' columns already exist, exit the upgrade function + if 'name' in columns or 'size' in columns: + return + + with op.batch_alter_table('tool_files', schema=None) as batch_op: + batch_op.add_column(sa.Column('name', sa.String(), nullable=True)) + batch_op.add_column(sa.Column('size', sa.Integer(), nullable=True)) + op.execute("UPDATE tool_files SET name = '' WHERE name IS NULL") + op.execute("UPDATE tool_files SET size = -1 WHERE size IS NULL") + with op.batch_alter_table('tool_files', schema=None) as batch_op: + batch_op.alter_column('name', existing_type=sa.String(), nullable=False) + batch_op.alter_column('size', existing_type=sa.Integer(), nullable=False) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('tool_files', schema=None) as batch_op: + batch_op.drop_column('size') + batch_op.drop_column('name') + # ### end Alembic commands ### diff --git a/api/migrations/versions/2024_10_22_0959-43fa78bc3b7d_add_white_list.py b/api/migrations/versions/2024_10_22_0959-43fa78bc3b7d_add_white_list.py new file mode 100644 index 0000000000000000000000000000000000000000..9daf148bc4e881aff05c64a748271bc1621b72a9 --- /dev/null +++ b/api/migrations/versions/2024_10_22_0959-43fa78bc3b7d_add_white_list.py @@ -0,0 +1,42 @@ +"""add_white_list + +Revision ID: 43fa78bc3b7d +Revises: 0251a1c768cc +Create Date: 2024-10-22 09:59:23.713716 + +""" +from alembic import op +import models as models +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. 
+revision = '43fa78bc3b7d' +down_revision = '0251a1c768cc' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('whitelists', + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', models.types.StringUUID(), nullable=True), + sa.Column('category', sa.String(length=255), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='whitelists_pkey') + ) + with op.batch_alter_table('whitelists', schema=None) as batch_op: + batch_op.create_index('whitelists_tenant_idx', ['tenant_id'], unique=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + + with op.batch_alter_table('whitelists', schema=None) as batch_op: + batch_op.drop_index('whitelists_tenant_idx') + + op.drop_table('whitelists') + # ### end Alembic commands ### diff --git a/api/migrations/versions/2024_11_01_0434-d3f6769a94a3_add_upload_files_source_url.py b/api/migrations/versions/2024_11_01_0434-d3f6769a94a3_add_upload_files_source_url.py new file mode 100644 index 0000000000000000000000000000000000000000..a749c8bddfee012a85c28781e153ffab7df5d8fc --- /dev/null +++ b/api/migrations/versions/2024_11_01_0434-d3f6769a94a3_add_upload_files_source_url.py @@ -0,0 +1,31 @@ +"""Add upload_files.source_url + +Revision ID: d3f6769a94a3 +Revises: 43fa78bc3b7d +Create Date: 2024-11-01 04:34:23.816198 + +""" +from alembic import op +import models as models +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = 'd3f6769a94a3' +down_revision = '43fa78bc3b7d' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('upload_files', schema=None) as batch_op: + batch_op.add_column(sa.Column('source_url', sa.String(length=255), server_default='', nullable=False)) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('upload_files', schema=None) as batch_op: + batch_op.drop_column('source_url') + # ### end Alembic commands ### diff --git a/api/migrations/versions/2024_11_01_0449-93ad8c19c40b_rename_conversation_variables_index_name.py b/api/migrations/versions/2024_11_01_0449-93ad8c19c40b_rename_conversation_variables_index_name.py new file mode 100644 index 0000000000000000000000000000000000000000..81a7978f730a37fd70817609b95676a03080080f --- /dev/null +++ b/api/migrations/versions/2024_11_01_0449-93ad8c19c40b_rename_conversation_variables_index_name.py @@ -0,0 +1,52 @@ +"""rename conversation variables index name + +Revision ID: 93ad8c19c40b +Revises: d3f6769a94a3 +Create Date: 2024-11-01 04:49:53.100250 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '93ad8c19c40b' +down_revision = 'd3f6769a94a3' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + conn = op.get_bind() + if conn.dialect.name == 'postgresql': + # Rename indexes for PostgreSQL + op.execute('ALTER INDEX workflow__conversation_variables_app_id_idx RENAME TO workflow_conversation_variables_app_id_idx') + op.execute('ALTER INDEX workflow__conversation_variables_created_at_idx RENAME TO workflow_conversation_variables_created_at_idx') + else: + # For other databases, use the original drop and create method + with op.batch_alter_table('workflow_conversation_variables', schema=None) as batch_op: + batch_op.drop_index('workflow__conversation_variables_app_id_idx') + batch_op.drop_index('workflow__conversation_variables_created_at_idx') + batch_op.create_index(batch_op.f('workflow_conversation_variables_app_id_idx'), ['app_id'], unique=False) + batch_op.create_index(batch_op.f('workflow_conversation_variables_created_at_idx'), ['created_at'], unique=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + conn = op.get_bind() + if conn.dialect.name == 'postgresql': + # Rename indexes back for PostgreSQL + op.execute('ALTER INDEX workflow_conversation_variables_app_id_idx RENAME TO workflow__conversation_variables_app_id_idx') + op.execute('ALTER INDEX workflow_conversation_variables_created_at_idx RENAME TO workflow__conversation_variables_created_at_idx') + else: + # For other databases, use the original drop and create method + with op.batch_alter_table('workflow_conversation_variables', schema=None) as batch_op: + batch_op.drop_index(batch_op.f('workflow_conversation_variables_created_at_idx')) + batch_op.drop_index(batch_op.f('workflow_conversation_variables_app_id_idx')) + batch_op.create_index('workflow__conversation_variables_created_at_idx', ['created_at'], unique=False) + batch_op.create_index('workflow__conversation_variables_app_id_idx', ['app_id'], unique=False) + + # ### end Alembic commands ### diff --git a/api/migrations/versions/2024_11_01_0540-f4d7ce70a7ca_update_upload_files_source_url.py b/api/migrations/versions/2024_11_01_0540-f4d7ce70a7ca_update_upload_files_source_url.py new file mode 100644 index 0000000000000000000000000000000000000000..222379a49021a6c465038e023f4c73b2449c55c8 --- /dev/null +++ b/api/migrations/versions/2024_11_01_0540-f4d7ce70a7ca_update_upload_files_source_url.py @@ -0,0 +1,41 @@ +"""update upload_files.source_url + +Revision ID: f4d7ce70a7ca +Revises: 93ad8c19c40b +Create Date: 2024-11-01 05:40:03.531751 + +""" +from alembic import op +import models as models +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = 'f4d7ce70a7ca' +down_revision = '93ad8c19c40b' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('upload_files', schema=None) as batch_op: + batch_op.alter_column('source_url', + existing_type=sa.VARCHAR(length=255), + type_=sa.TEXT(), + existing_nullable=False, + existing_server_default=sa.text("''::character varying")) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('upload_files', schema=None) as batch_op: + batch_op.alter_column('source_url', + existing_type=sa.TEXT(), + type_=sa.VARCHAR(length=255), + existing_nullable=False, + existing_server_default=sa.text("''::character varying")) + + # ### end Alembic commands ### diff --git a/api/migrations/versions/2024_11_01_0622-d07474999927_update_type_of_custom_disclaimer_to_text.py b/api/migrations/versions/2024_11_01_0622-d07474999927_update_type_of_custom_disclaimer_to_text.py new file mode 100644 index 0000000000000000000000000000000000000000..9a4ccf352df0983d55b83aadeb801ca8ff69a097 --- /dev/null +++ b/api/migrations/versions/2024_11_01_0622-d07474999927_update_type_of_custom_disclaimer_to_text.py @@ -0,0 +1,67 @@ +"""update type of custom_disclaimer to TEXT + +Revision ID: d07474999927 +Revises: f4d7ce70a7ca +Create Date: 2024-11-01 06:22:27.981398 + +""" +from alembic import op +import models as models +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = 'd07474999927' +down_revision = 'f4d7ce70a7ca' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + op.execute("UPDATE recommended_apps SET custom_disclaimer = '' WHERE custom_disclaimer IS NULL") + op.execute("UPDATE sites SET custom_disclaimer = '' WHERE custom_disclaimer IS NULL") + op.execute("UPDATE tool_api_providers SET custom_disclaimer = '' WHERE custom_disclaimer IS NULL") + + with op.batch_alter_table('recommended_apps', schema=None) as batch_op: + batch_op.alter_column('custom_disclaimer', + existing_type=sa.VARCHAR(length=255), + type_=sa.TEXT(), + nullable=False) + + with op.batch_alter_table('sites', schema=None) as batch_op: + batch_op.alter_column('custom_disclaimer', + existing_type=sa.VARCHAR(length=255), + type_=sa.TEXT(), + nullable=False) + + with op.batch_alter_table('tool_api_providers', schema=None) as batch_op: + batch_op.alter_column('custom_disclaimer', + existing_type=sa.VARCHAR(length=255), + type_=sa.TEXT(), + nullable=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('tool_api_providers', schema=None) as batch_op: + batch_op.alter_column('custom_disclaimer', + existing_type=sa.TEXT(), + type_=sa.VARCHAR(length=255), + nullable=True) + + with op.batch_alter_table('sites', schema=None) as batch_op: + batch_op.alter_column('custom_disclaimer', + existing_type=sa.TEXT(), + type_=sa.VARCHAR(length=255), + nullable=True) + + with op.batch_alter_table('recommended_apps', schema=None) as batch_op: + batch_op.alter_column('custom_disclaimer', + existing_type=sa.TEXT(), + type_=sa.VARCHAR(length=255), + nullable=True) + + # ### end Alembic commands ### diff --git a/api/migrations/versions/2024_11_01_0623-09a8d1878d9b_update_workflows_graph_features_and_.py b/api/migrations/versions/2024_11_01_0623-09a8d1878d9b_update_workflows_graph_features_and_.py new file mode 100644 index 0000000000000000000000000000000000000000..117a7351cd67e7a6dc1b61faae0815dcd94f52f8 --- /dev/null +++ 
b/api/migrations/versions/2024_11_01_0623-09a8d1878d9b_update_workflows_graph_features_and_.py @@ -0,0 +1,73 @@ +"""update workflows graph, features and updated_at + +Revision ID: 09a8d1878d9b +Revises: d07474999927 +Create Date: 2024-11-01 06:23:59.579186 + +""" +from alembic import op +import models as models +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '09a8d1878d9b' +down_revision = 'd07474999927' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('conversations', schema=None) as batch_op: + batch_op.alter_column('inputs', + existing_type=postgresql.JSON(astext_type=sa.Text()), + nullable=False) + + with op.batch_alter_table('messages', schema=None) as batch_op: + batch_op.alter_column('inputs', + existing_type=postgresql.JSON(astext_type=sa.Text()), + nullable=False) + + op.execute("UPDATE workflows SET updated_at = created_at WHERE updated_at IS NULL") + op.execute("UPDATE workflows SET graph = '' WHERE graph IS NULL") + op.execute("UPDATE workflows SET features = '' WHERE features IS NULL") + + with op.batch_alter_table('workflows', schema=None) as batch_op: + batch_op.alter_column('graph', + existing_type=sa.TEXT(), + nullable=False) + batch_op.alter_column('features', + existing_type=sa.TEXT(), + nullable=False) + batch_op.alter_column('updated_at', + existing_type=postgresql.TIMESTAMP(), + nullable=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('workflows', schema=None) as batch_op: + batch_op.alter_column('updated_at', + existing_type=postgresql.TIMESTAMP(), + nullable=True) + batch_op.alter_column('features', + existing_type=sa.TEXT(), + nullable=True) + batch_op.alter_column('graph', + existing_type=sa.TEXT(), + nullable=True) + + with op.batch_alter_table('messages', schema=None) as batch_op: + batch_op.alter_column('inputs', + existing_type=postgresql.JSON(astext_type=sa.Text()), + nullable=True) + + with op.batch_alter_table('conversations', schema=None) as batch_op: + batch_op.alter_column('inputs', + existing_type=postgresql.JSON(astext_type=sa.Text()), + nullable=True) + + # ### end Alembic commands ### diff --git a/api/migrations/versions/2024_11_12_0925-01d6889832f7_add_created_at_index_for_messages.py b/api/migrations/versions/2024_11_12_0925-01d6889832f7_add_created_at_index_for_messages.py new file mode 100644 index 0000000000000000000000000000000000000000..d94508edcf44b3f63af175a13e094ac7e4cacf92 --- /dev/null +++ b/api/migrations/versions/2024_11_12_0925-01d6889832f7_add_created_at_index_for_messages.py @@ -0,0 +1,31 @@ +"""add_created_at_index_for_messages + +Revision ID: 01d6889832f7 +Revises: 09a8d1878d9b +Create Date: 2024-11-12 09:25:05.527827 + +""" +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '01d6889832f7' +down_revision = '09a8d1878d9b' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('messages', schema=None) as batch_op: + batch_op.create_index('message_created_at_idx', ['created_at'], unique=False) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('messages', schema=None) as batch_op: + batch_op.drop_index('message_created_at_idx') + # ### end Alembic commands ### diff --git a/api/migrations/versions/2024_11_22_0701-e19037032219_parent_child_index.py b/api/migrations/versions/2024_11_22_0701-e19037032219_parent_child_index.py new file mode 100644 index 0000000000000000000000000000000000000000..9238e5a0a81c5a0a5f0796f2ae217f1fbec54a4d --- /dev/null +++ b/api/migrations/versions/2024_11_22_0701-e19037032219_parent_child_index.py @@ -0,0 +1,55 @@ +"""parent-child-index + +Revision ID: e19037032219 +Revises: 01d6889832f7 +Create Date: 2024-11-22 07:01:17.550037 + +""" +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = 'e19037032219' +down_revision = 'd7999dfa4aae' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('child_chunks', + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', models.types.StringUUID(), nullable=False), + sa.Column('dataset_id', models.types.StringUUID(), nullable=False), + sa.Column('document_id', models.types.StringUUID(), nullable=False), + sa.Column('segment_id', models.types.StringUUID(), nullable=False), + sa.Column('position', sa.Integer(), nullable=False), + sa.Column('content', sa.Text(), nullable=False), + sa.Column('word_count', sa.Integer(), nullable=False), + sa.Column('index_node_id', sa.String(length=255), nullable=True), + sa.Column('index_node_hash', sa.String(length=255), nullable=True), + sa.Column('type', sa.String(length=255), server_default=sa.text("'automatic'::character varying"), nullable=False), + sa.Column('created_by', models.types.StringUUID(), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + 
sa.Column('updated_by', models.types.StringUUID(), nullable=True), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('indexing_at', sa.DateTime(), nullable=True), + sa.Column('completed_at', sa.DateTime(), nullable=True), + sa.Column('error', sa.Text(), nullable=True), + sa.PrimaryKeyConstraint('id', name='child_chunk_pkey') + ) + with op.batch_alter_table('child_chunks', schema=None) as batch_op: + batch_op.create_index('child_chunk_dataset_id_idx', ['tenant_id', 'dataset_id', 'document_id', 'segment_id', 'index_node_id'], unique=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('child_chunks', schema=None) as batch_op: + batch_op.drop_index('child_chunk_dataset_id_idx') + + op.drop_table('child_chunks') + # ### end Alembic commands ### diff --git a/api/migrations/versions/2024_11_28_0553-cf8f4fc45278_add_exceptions_count_field_to_.py b/api/migrations/versions/2024_11_28_0553-cf8f4fc45278_add_exceptions_count_field_to_.py new file mode 100644 index 0000000000000000000000000000000000000000..8c576339bae8cfd3bdc7a37c80f48476276b9864 --- /dev/null +++ b/api/migrations/versions/2024_11_28_0553-cf8f4fc45278_add_exceptions_count_field_to_.py @@ -0,0 +1,33 @@ +"""add exceptions_count field to WorkflowRun model + +Revision ID: cf8f4fc45278 +Revises: 01d6889832f7 +Create Date: 2024-11-28 05:53:21.576178 + +""" +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = 'cf8f4fc45278' +down_revision = '01d6889832f7' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('workflow_runs', schema=None) as batch_op: + batch_op.add_column(sa.Column('exceptions_count', sa.Integer(), server_default=sa.text('0'), nullable=True)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('workflow_runs', schema=None) as batch_op: + batch_op.drop_column('exceptions_count') + + # ### end Alembic commands ### diff --git a/api/migrations/versions/2024_12_19_1746-11b07f66c737_remove_unused_tool_providers.py b/api/migrations/versions/2024_12_19_1746-11b07f66c737_remove_unused_tool_providers.py new file mode 100644 index 0000000000000000000000000000000000000000..881a9e3c1e06b6b954c9aa3c9b0ade45610ec882 --- /dev/null +++ b/api/migrations/versions/2024_12_19_1746-11b07f66c737_remove_unused_tool_providers.py @@ -0,0 +1,39 @@ +"""remove unused tool_providers + +Revision ID: 11b07f66c737 +Revises: cf8f4fc45278 +Create Date: 2024-12-19 17:46:25.780116 + +""" +from alembic import op +import models as models +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '11b07f66c737' +down_revision = 'cf8f4fc45278' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table('tool_providers') + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + op.create_table('tool_providers', + sa.Column('id', sa.UUID(), server_default=sa.text('uuid_generate_v4()'), autoincrement=False, nullable=False), + sa.Column('tenant_id', sa.UUID(), autoincrement=False, nullable=False), + sa.Column('tool_name', sa.VARCHAR(length=40), autoincrement=False, nullable=False), + sa.Column('encrypted_credentials', sa.TEXT(), autoincrement=False, nullable=True), + sa.Column('is_enabled', sa.BOOLEAN(), server_default=sa.text('false'), autoincrement=False, nullable=False), + sa.Column('created_at', postgresql.TIMESTAMP(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), autoincrement=False, nullable=False), + sa.Column('updated_at', postgresql.TIMESTAMP(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), autoincrement=False, nullable=False), + sa.PrimaryKeyConstraint('id', name='tool_provider_pkey'), + sa.UniqueConstraint('tenant_id', 'tool_name', name='unique_tool_provider_tool_name') + ) + # ### end Alembic commands ### diff --git a/api/migrations/versions/2024_12_20_0628-e1944c35e15e_add_retry_index_field_to_node_execution_.py b/api/migrations/versions/2024_12_20_0628-e1944c35e15e_add_retry_index_field_to_node_execution_.py new file mode 100644 index 0000000000000000000000000000000000000000..814dec423c63c4529ea11e711e6e880c33ac1aa0 --- /dev/null +++ b/api/migrations/versions/2024_12_20_0628-e1944c35e15e_add_retry_index_field_to_node_execution_.py @@ -0,0 +1,37 @@ +"""add retry_index field to node-execution model +Revision ID: e1944c35e15e +Revises: 11b07f66c737 +Create Date: 2024-12-20 06:28:30.287197 +""" +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = 'e1944c35e15e' +down_revision = '11b07f66c737' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + + # We don't need these fields anymore, but this file is already merged into the main branch, + # so we need to keep this file for the sake of history, and this change will be reverted in the next migration. + # with op.batch_alter_table('workflow_node_executions', schema=None) as batch_op: + # batch_op.add_column(sa.Column('retry_index', sa.Integer(), server_default=sa.text('0'), nullable=True)) + + pass + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + # with op.batch_alter_table('workflow_node_executions', schema=None) as batch_op: + # batch_op.drop_column('retry_index') + pass + + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/2024_12_23_1154-d7999dfa4aae_remove_workflow_node_executions_retry_.py b/api/migrations/versions/2024_12_23_1154-d7999dfa4aae_remove_workflow_node_executions_retry_.py new file mode 100644 index 0000000000000000000000000000000000000000..ea129d15f7e6e6e0c5a5934d7ce1b47999537d66 --- /dev/null +++ b/api/migrations/versions/2024_12_23_1154-d7999dfa4aae_remove_workflow_node_executions_retry_.py @@ -0,0 +1,34 @@ +"""remove workflow_node_executions.retry_index if exists + +Revision ID: d7999dfa4aae +Revises: e1944c35e15e +Create Date: 2024-12-23 11:54:15.344543 + +""" +from alembic import op +import models as models +import sqlalchemy as sa +from sqlalchemy import inspect + + +# revision identifiers, used by Alembic. 
+revision = 'd7999dfa4aae' +down_revision = 'e1944c35e15e' +branch_labels = None +depends_on = None + + +def upgrade(): + # Check if column exists before attempting to remove it + conn = op.get_bind() + inspector = inspect(conn) + has_column = 'retry_index' in [col['name'] for col in inspector.get_columns('workflow_node_executions')] + + if has_column: + with op.batch_alter_table('workflow_node_executions', schema=None) as batch_op: + batch_op.drop_column('retry_index') + + +def downgrade(): + # No downgrade needed as we don't want to restore the column + pass diff --git a/api/migrations/versions/2024_12_25_1137-923752d42eb6_add_auto_disabled_dataset_logs.py b/api/migrations/versions/2024_12_25_1137-923752d42eb6_add_auto_disabled_dataset_logs.py new file mode 100644 index 0000000000000000000000000000000000000000..6dadd4e4a8afe5e2d5975881644dff13c268f1cc --- /dev/null +++ b/api/migrations/versions/2024_12_25_1137-923752d42eb6_add_auto_disabled_dataset_logs.py @@ -0,0 +1,47 @@ +"""add_auto_disabled_dataset_logs + +Revision ID: 923752d42eb6 +Revises: e19037032219 +Create Date: 2024-12-25 11:37:55.467101 + +""" +from alembic import op +import models as models +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '923752d42eb6' +down_revision = 'e19037032219' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + op.create_table('dataset_auto_disable_logs', + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', models.types.StringUUID(), nullable=False), + sa.Column('dataset_id', models.types.StringUUID(), nullable=False), + sa.Column('document_id', models.types.StringUUID(), nullable=False), + sa.Column('notified', sa.Boolean(), server_default=sa.text('false'), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='dataset_auto_disable_log_pkey') + ) + with op.batch_alter_table('dataset_auto_disable_logs', schema=None) as batch_op: + batch_op.create_index('dataset_auto_disable_log_created_atx', ['created_at'], unique=False) + batch_op.create_index('dataset_auto_disable_log_dataset_idx', ['dataset_id'], unique=False) + batch_op.create_index('dataset_auto_disable_log_tenant_idx', ['tenant_id'], unique=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('dataset_auto_disable_logs', schema=None) as batch_op: + batch_op.drop_index('dataset_auto_disable_log_tenant_idx') + batch_op.drop_index('dataset_auto_disable_log_dataset_idx') + batch_op.drop_index('dataset_auto_disable_log_created_atx') + + op.drop_table('dataset_auto_disable_logs') + # ### end Alembic commands ### diff --git a/api/migrations/versions/2025_01_01_2000-a91b476a53de_change_workflow_runs_total_tokens_to_.py b/api/migrations/versions/2025_01_01_2000-a91b476a53de_change_workflow_runs_total_tokens_to_.py new file mode 100644 index 0000000000000000000000000000000000000000..798c895863b8dd49e4283cbcf07b0a8dcd75850a --- /dev/null +++ b/api/migrations/versions/2025_01_01_2000-a91b476a53de_change_workflow_runs_total_tokens_to_.py @@ -0,0 +1,41 @@ +"""change workflow_runs.total_tokens to bigint + +Revision ID: a91b476a53de +Revises: 923752d42eb6 +Create Date: 2025-01-01 20:00:01.207369 + +""" +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = 'a91b476a53de' +down_revision = '923752d42eb6' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('workflow_runs', schema=None) as batch_op: + batch_op.alter_column('total_tokens', + existing_type=sa.INTEGER(), + type_=sa.BigInteger(), + existing_nullable=False, + existing_server_default=sa.text('0')) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('workflow_runs', schema=None) as batch_op: + batch_op.alter_column('total_tokens', + existing_type=sa.BigInteger(), + type_=sa.INTEGER(), + existing_nullable=False, + existing_server_default=sa.text('0')) + + # ### end Alembic commands ### diff --git a/api/migrations/versions/23db93619b9d_add_message_files_into_agent_thought.py b/api/migrations/versions/23db93619b9d_add_message_files_into_agent_thought.py new file mode 100644 index 0000000000000000000000000000000000000000..f3eef4681e380fdbf6e3e0e3d5a51614f1f3cd1a --- /dev/null +++ b/api/migrations/versions/23db93619b9d_add_message_files_into_agent_thought.py @@ -0,0 +1,31 @@ +"""add message files into agent thought + +Revision ID: 23db93619b9d +Revises: 8ae9bc661daa +Create Date: 2024-01-18 08:46:37.302657 + +""" +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = '23db93619b9d' +down_revision = '8ae9bc661daa' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op: + batch_op.add_column(sa.Column('message_files', sa.Text(), nullable=True)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op: + batch_op.drop_column('message_files') + + # ### end Alembic commands ### diff --git a/api/migrations/versions/246ba09cbbdb_add_app_anntation_setting.py b/api/migrations/versions/246ba09cbbdb_add_app_anntation_setting.py new file mode 100644 index 0000000000000000000000000000000000000000..9816e92dd12c039f4c86fa08b450b524c8ed5956 --- /dev/null +++ b/api/migrations/versions/246ba09cbbdb_add_app_anntation_setting.py @@ -0,0 +1,50 @@ +"""add_app_anntation_setting + +Revision ID: 246ba09cbbdb +Revises: 714aafe25d39 +Create Date: 2023-12-14 11:26:12.287264 + +""" +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '246ba09cbbdb' +down_revision = '714aafe25d39' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('app_annotation_settings', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('app_id', postgresql.UUID(), nullable=False), + sa.Column('score_threshold', sa.Float(), server_default=sa.text('0'), nullable=False), + sa.Column('collection_binding_id', postgresql.UUID(), nullable=False), + sa.Column('created_user_id', postgresql.UUID(), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('updated_user_id', postgresql.UUID(), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='app_annotation_settings_pkey') + ) + with op.batch_alter_table('app_annotation_settings', schema=None) as batch_op: + batch_op.create_index('app_annotation_settings_app_idx', ['app_id'], unique=False) + + with op.batch_alter_table('app_model_configs', schema=None) as batch_op: + 
batch_op.drop_column('annotation_reply') + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('app_model_configs', schema=None) as batch_op: + batch_op.add_column(sa.Column('annotation_reply', sa.TEXT(), autoincrement=False, nullable=True)) + + with op.batch_alter_table('app_annotation_settings', schema=None) as batch_op: + batch_op.drop_index('app_annotation_settings_app_idx') + + op.drop_table('app_annotation_settings') + # ### end Alembic commands ### diff --git a/api/migrations/versions/2a3aebbbf4bb_add_app_tracing.py b/api/migrations/versions/2a3aebbbf4bb_add_app_tracing.py new file mode 100644 index 0000000000000000000000000000000000000000..99b7010612aa0f56cc74e3888467ede7ca4f8b0f --- /dev/null +++ b/api/migrations/versions/2a3aebbbf4bb_add_app_tracing.py @@ -0,0 +1,33 @@ +"""add app tracing + +Revision ID: 2a3aebbbf4bb +Revises: c031d46af369 +Create Date: 2024-06-17 10:08:54.803701 + +""" +import sqlalchemy as sa +from alembic import op + +import models as models + +# revision identifiers, used by Alembic. +revision = '2a3aebbbf4bb' +down_revision = 'c031d46af369' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('apps', schema=None) as batch_op: + batch_op.add_column(sa.Column('tracing', sa.Text(), nullable=True)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('apps', schema=None) as batch_op: + batch_op.drop_column('tracing') + + # ### end Alembic commands ### diff --git a/api/migrations/versions/2beac44e5f5f_add_is_universal_in_apps.py b/api/migrations/versions/2beac44e5f5f_add_is_universal_in_apps.py new file mode 100644 index 0000000000000000000000000000000000000000..e933623d1cf826f0b1bb183944466ab4948665e0 --- /dev/null +++ b/api/migrations/versions/2beac44e5f5f_add_is_universal_in_apps.py @@ -0,0 +1,31 @@ +"""add is_universal in apps + +Revision ID: 2beac44e5f5f +Revises: d3d503a3471c +Create Date: 2023-07-07 12:11:29.156057 + +""" +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = '2beac44e5f5f' +down_revision = 'a5b56fb053ef' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('apps', schema=None) as batch_op: + batch_op.add_column(sa.Column('is_universal', sa.Boolean(), server_default=sa.text('false'), nullable=False)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('apps', schema=None) as batch_op: + batch_op.drop_column('is_universal') + + # ### end Alembic commands ### diff --git a/api/migrations/versions/2c8af9671032_add_qa_document_language.py b/api/migrations/versions/2c8af9671032_add_qa_document_language.py new file mode 100644 index 0000000000000000000000000000000000000000..1f0c14544631b628cd94a17c2f7d3c34a80f038a --- /dev/null +++ b/api/migrations/versions/2c8af9671032_add_qa_document_language.py @@ -0,0 +1,31 @@ +"""add_qa_document_language + +Revision ID: 2c8af9671032 +Revises: 8d2d099ceb74 +Create Date: 2023-08-01 18:57:27.294973 + +""" +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. 
+revision = '2c8af9671032' +down_revision = '5022897aaceb' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('documents', schema=None) as batch_op: + batch_op.add_column(sa.Column('doc_language', sa.String(length=255), nullable=True)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('documents', schema=None) as batch_op: + batch_op.drop_column('doc_language') + + # ### end Alembic commands ### diff --git a/api/migrations/versions/2e9819ca5b28_add_tenant_id_in_api_token.py b/api/migrations/versions/2e9819ca5b28_add_tenant_id_in_api_token.py new file mode 100644 index 0000000000000000000000000000000000000000..b06a3530b88a3dfa46c896254bf5fec4537aeebc --- /dev/null +++ b/api/migrations/versions/2e9819ca5b28_add_tenant_id_in_api_token.py @@ -0,0 +1,36 @@ +"""add_tenant_id_in_api_token + +Revision ID: 2e9819ca5b28 +Revises: 6e2cfb077b04 +Create Date: 2023-09-22 15:41:01.243183 + +""" +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '2e9819ca5b28' +down_revision = 'ab23c11305d4' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('api_tokens', schema=None) as batch_op: + batch_op.add_column(sa.Column('tenant_id', postgresql.UUID(), nullable=True)) + batch_op.create_index('api_token_tenant_idx', ['tenant_id', 'type'], unique=False) + batch_op.drop_column('dataset_id') + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('api_tokens', schema=None) as batch_op: + batch_op.add_column(sa.Column('dataset_id', postgresql.UUID(), autoincrement=False, nullable=True)) + batch_op.drop_index('api_token_tenant_idx') + batch_op.drop_column('tenant_id') + + # ### end Alembic commands ### diff --git a/api/migrations/versions/380c6aa5a70d_add_tool_labels_to_agent_thought.py b/api/migrations/versions/380c6aa5a70d_add_tool_labels_to_agent_thought.py new file mode 100644 index 0000000000000000000000000000000000000000..6c1381846360aa7889d0e6a5c7251680e6ba3480 --- /dev/null +++ b/api/migrations/versions/380c6aa5a70d_add_tool_labels_to_agent_thought.py @@ -0,0 +1,31 @@ +"""add tool labels to agent thought + +Revision ID: 380c6aa5a70d +Revises: dfb3b7f477da +Create Date: 2024-01-24 10:58:15.644445 + +""" +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = '380c6aa5a70d' +down_revision = 'dfb3b7f477da' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op: + batch_op.add_column(sa.Column('tool_labels_str', sa.Text(), server_default=sa.text("'{}'::text"), nullable=False)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op: + batch_op.drop_column('tool_labels_str') + + # ### end Alembic commands ### diff --git a/api/migrations/versions/3b18fea55204_add_tool_label_bings.py b/api/migrations/versions/3b18fea55204_add_tool_label_bings.py new file mode 100644 index 0000000000000000000000000000000000000000..bf54c247ead19c335e7a3ea35feb4e34068e0439 --- /dev/null +++ b/api/migrations/versions/3b18fea55204_add_tool_label_bings.py @@ -0,0 +1,42 @@ +"""add tool label bings + +Revision ID: 3b18fea55204 +Revises: 7bdef072e63a +Create Date: 2024-05-14 09:27:18.857890 + +""" +import sqlalchemy as sa +from alembic import op + +import models.types + +# revision identifiers, used by Alembic. +revision = '3b18fea55204' +down_revision = '7bdef072e63a' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('tool_label_bindings', + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tool_id', sa.String(length=64), nullable=False), + sa.Column('tool_type', sa.String(length=40), nullable=False), + sa.Column('label_name', sa.String(length=40), nullable=False), + sa.PrimaryKeyConstraint('id', name='tool_label_bind_pkey') + ) + + with op.batch_alter_table('tool_workflow_providers', schema=None) as batch_op: + batch_op.add_column(sa.Column('privacy_policy', sa.String(length=255), server_default='', nullable=True)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('tool_workflow_providers', schema=None) as batch_op: + batch_op.drop_column('privacy_policy') + + op.drop_table('tool_label_bindings') + # ### end Alembic commands ### diff --git a/api/migrations/versions/3c7cac9521c6_add_tags_and_binding_table.py b/api/migrations/versions/3c7cac9521c6_add_tags_and_binding_table.py new file mode 100644 index 0000000000000000000000000000000000000000..5f118806832489ff97cebb8e0187d1c5be3f95b1 --- /dev/null +++ b/api/migrations/versions/3c7cac9521c6_add_tags_and_binding_table.py @@ -0,0 +1,62 @@ +"""add-tags-and-binding-table + +Revision ID: 3c7cac9521c6 +Revises: c3311b089690 +Create Date: 2024-04-11 06:17:34.278594 + +""" +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '3c7cac9521c6' +down_revision = 'c3311b089690' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + op.create_table('tag_bindings', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', postgresql.UUID(), nullable=True), + sa.Column('tag_id', postgresql.UUID(), nullable=True), + sa.Column('target_id', postgresql.UUID(), nullable=True), + sa.Column('created_by', postgresql.UUID(), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='tag_binding_pkey') + ) + with op.batch_alter_table('tag_bindings', schema=None) as batch_op: + batch_op.create_index('tag_bind_tag_id_idx', ['tag_id'], unique=False) + batch_op.create_index('tag_bind_target_id_idx', ['target_id'], unique=False) + + op.create_table('tags', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', postgresql.UUID(), nullable=True), + sa.Column('type', sa.String(length=16), nullable=False), + sa.Column('name', sa.String(length=255), nullable=False), + sa.Column('created_by', postgresql.UUID(), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='tag_pkey') + ) + with op.batch_alter_table('tags', schema=None) as batch_op: + batch_op.create_index('tag_name_idx', ['name'], unique=False) + batch_op.create_index('tag_type_idx', ['type'], unique=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('tags', schema=None) as batch_op: + batch_op.drop_index('tag_type_idx') + batch_op.drop_index('tag_name_idx') + + op.drop_table('tags') + with op.batch_alter_table('tag_bindings', schema=None) as batch_op: + batch_op.drop_index('tag_bind_target_id_idx') + batch_op.drop_index('tag_bind_tag_id_idx') + + op.drop_table('tag_bindings') + # ### end Alembic commands ### diff --git a/api/migrations/versions/3ef9b2b6bee6_add_assistant_app.py b/api/migrations/versions/3ef9b2b6bee6_add_assistant_app.py new file mode 100644 index 0000000000000000000000000000000000000000..4fbc5703036a4a87b9e94bb522d595f98308f733 --- /dev/null +++ b/api/migrations/versions/3ef9b2b6bee6_add_assistant_app.py @@ -0,0 +1,67 @@ +"""add_assistant_app + +Revision ID: 3ef9b2b6bee6 +Revises: 89c7899ca936 +Create Date: 2024-01-05 15:26:25.117551 + +""" +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '3ef9b2b6bee6' +down_revision = '89c7899ca936' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + op.create_table('tool_api_providers', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('name', sa.String(length=40), nullable=False), + sa.Column('schema', sa.Text(), nullable=False), + sa.Column('schema_type_str', sa.String(length=40), nullable=False), + sa.Column('user_id', postgresql.UUID(), nullable=False), + sa.Column('tenant_id', postgresql.UUID(), nullable=False), + sa.Column('description_str', sa.Text(), nullable=False), + sa.Column('tools_str', sa.Text(), nullable=False), + sa.PrimaryKeyConstraint('id', name='tool_api_provider_pkey') + ) + op.create_table('tool_builtin_providers', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', postgresql.UUID(), nullable=True), + sa.Column('user_id', postgresql.UUID(), nullable=False), + sa.Column('provider', sa.String(length=40), nullable=False), + sa.Column('encrypted_credentials', sa.Text(), nullable=True), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='tool_builtin_provider_pkey'), + sa.UniqueConstraint('tenant_id', 'provider', name='unique_builtin_tool_provider') + ) + op.create_table('tool_published_apps', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('app_id', postgresql.UUID(), nullable=False), + sa.Column('user_id', postgresql.UUID(), nullable=False), + sa.Column('description', sa.Text(), nullable=False), + sa.Column('llm_description', sa.Text(), nullable=False), + sa.Column('query_description', sa.Text(), nullable=False), + sa.Column('query_name', sa.String(length=40), nullable=False), + sa.Column('tool_name', sa.String(length=40), nullable=False), + sa.Column('author', sa.String(length=40), 
nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.ForeignKeyConstraint(['app_id'], ['apps.id'], ), + sa.PrimaryKeyConstraint('id', name='published_app_tool_pkey'), + sa.UniqueConstraint('app_id', 'user_id', name='unique_published_app_tool') + ) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table('tool_published_apps') + op.drop_table('tool_builtin_providers') + op.drop_table('tool_api_providers') + # ### end Alembic commands ### diff --git a/api/migrations/versions/408176b91ad3_add_max_active_requests.py b/api/migrations/versions/408176b91ad3_add_max_active_requests.py new file mode 100644 index 0000000000000000000000000000000000000000..c19a68586ff975cd9d25323bcb3f34c8dda2d271 --- /dev/null +++ b/api/migrations/versions/408176b91ad3_add_max_active_requests.py @@ -0,0 +1,33 @@ +"""'add_max_active_requests' + +Revision ID: 408176b91ad3 +Revises: 7e6a8693e07a +Create Date: 2024-07-04 09:25:14.029023 + +""" +import sqlalchemy as sa +from alembic import op + +import models as models + +# revision identifiers, used by Alembic. +revision = '408176b91ad3' +down_revision = '161cadc1af8d' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('apps', schema=None) as batch_op: + batch_op.add_column(sa.Column('max_active_requests', sa.Integer(), nullable=True)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('apps', schema=None) as batch_op: + batch_op.drop_column('max_active_requests') + + # ### end Alembic commands ### diff --git a/api/migrations/versions/42e85ed5564d_conversation_columns_set_nullable.py b/api/migrations/versions/42e85ed5564d_conversation_columns_set_nullable.py new file mode 100644 index 0000000000000000000000000000000000000000..f388b99b9068a0d69674926297882aec43238649 --- /dev/null +++ b/api/migrations/versions/42e85ed5564d_conversation_columns_set_nullable.py @@ -0,0 +1,48 @@ +"""conversation columns set nullable + +Revision ID: 42e85ed5564d +Revises: f9107f83abab +Create Date: 2024-03-07 08:30:29.133614 + +""" +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '42e85ed5564d' +down_revision = 'f9107f83abab' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('conversations', schema=None) as batch_op: + batch_op.alter_column('app_model_config_id', + existing_type=postgresql.UUID(), + nullable=True) + batch_op.alter_column('model_provider', + existing_type=sa.VARCHAR(length=255), + nullable=True) + batch_op.alter_column('model_id', + existing_type=sa.VARCHAR(length=255), + nullable=True) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('conversations', schema=None) as batch_op: + batch_op.alter_column('model_id', + existing_type=sa.VARCHAR(length=255), + nullable=False) + batch_op.alter_column('model_provider', + existing_type=sa.VARCHAR(length=255), + nullable=False) + batch_op.alter_column('app_model_config_id', + existing_type=postgresql.UUID(), + nullable=False) + + # ### end Alembic commands ### diff --git a/api/migrations/versions/46976cc39132_add_annotation_histoiry_score.py b/api/migrations/versions/46976cc39132_add_annotation_histoiry_score.py new file mode 100644 index 0000000000000000000000000000000000000000..b47dd3c8ab6021be8b4ffa4edcddfd6bbf73091d --- /dev/null +++ b/api/migrations/versions/46976cc39132_add_annotation_histoiry_score.py @@ -0,0 +1,31 @@ +"""add-annotation-histoiry-score + +Revision ID: 46976cc39132 +Revises: e1901f623fd0 +Create Date: 2023-12-13 04:39:59.302971 + +""" +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = '46976cc39132' +down_revision = 'e1901f623fd0' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op: + batch_op.add_column(sa.Column('score', sa.Float(), server_default=sa.text('0'), nullable=False)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op: + batch_op.drop_column('score') + + # ### end Alembic commands ### diff --git a/api/migrations/versions/47cc7df8c4f3_modify_default_model_name_length.py b/api/migrations/versions/47cc7df8c4f3_modify_default_model_name_length.py new file mode 100644 index 0000000000000000000000000000000000000000..b37928d3c08a8696ae0a4eb17424b99700a00c4e --- /dev/null +++ b/api/migrations/versions/47cc7df8c4f3_modify_default_model_name_length.py @@ -0,0 +1,39 @@ +"""modify default model name length + +Revision ID: 47cc7df8c4f3 +Revises: 3c7cac9521c6 +Create Date: 2024-05-10 09:48:09.046298 + +""" +import sqlalchemy as sa +from alembic import op + +import models as models + +# revision identifiers, used by Alembic. +revision = '47cc7df8c4f3' +down_revision = '3c7cac9521c6' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('tenant_default_models', schema=None) as batch_op: + batch_op.alter_column('model_name', + existing_type=sa.VARCHAR(length=40), + type_=sa.String(length=255), + existing_nullable=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('tenant_default_models', schema=None) as batch_op: + batch_op.alter_column('model_name', + existing_type=sa.String(length=255), + type_=sa.VARCHAR(length=40), + existing_nullable=False) + + # ### end Alembic commands ### diff --git a/api/migrations/versions/4823da1d26cf_add_tool_file.py b/api/migrations/versions/4823da1d26cf_add_tool_file.py new file mode 100644 index 0000000000000000000000000000000000000000..1a473a10fe811bc984ef60b4b9ebd755eab96be8 --- /dev/null +++ b/api/migrations/versions/4823da1d26cf_add_tool_file.py @@ -0,0 +1,37 @@ +"""add tool file + +Revision ID: 4823da1d26cf +Revises: 053da0c1d756 +Create Date: 2024-01-15 11:37:16.782718 + +""" +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '4823da1d26cf' +down_revision = '053da0c1d756' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('tool_files', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('user_id', postgresql.UUID(), nullable=False), + sa.Column('tenant_id', postgresql.UUID(), nullable=False), + sa.Column('conversation_id', postgresql.UUID(), nullable=False), + sa.Column('file_key', sa.String(length=255), nullable=False), + sa.Column('mimetype', sa.String(length=255), nullable=False), + sa.Column('original_url', sa.String(length=255), nullable=True), + sa.PrimaryKeyConstraint('id', name='tool_file_pkey') + ) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + op.drop_table('tool_files') + # ### end Alembic commands ### diff --git a/api/migrations/versions/4829e54d2fee_change_message_chain_id_to_nullable.py b/api/migrations/versions/4829e54d2fee_change_message_chain_id_to_nullable.py new file mode 100644 index 0000000000000000000000000000000000000000..240502185683740a667ec9318f2f54a6b1c127ee --- /dev/null +++ b/api/migrations/versions/4829e54d2fee_change_message_chain_id_to_nullable.py @@ -0,0 +1,35 @@ +"""change message chain id to nullable + +Revision ID: 4829e54d2fee +Revises: 114eed84c228 +Create Date: 2024-01-12 03:42:27.362415 + +""" +from alembic import op +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '4829e54d2fee' +down_revision = '114eed84c228' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op: + batch_op.alter_column('message_chain_id', + existing_type=postgresql.UUID(), + nullable=True) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op: + batch_op.alter_column('message_chain_id', + existing_type=postgresql.UUID(), + nullable=False) + + # ### end Alembic commands ### diff --git a/api/migrations/versions/4bcffcd64aa4_update_dataset_model_field_null_.py b/api/migrations/versions/4bcffcd64aa4_update_dataset_model_field_null_.py new file mode 100644 index 0000000000000000000000000000000000000000..178bd24e3c63b368f32e3dec4cfc659b64ade668 --- /dev/null +++ b/api/migrations/versions/4bcffcd64aa4_update_dataset_model_field_null_.py @@ -0,0 +1,45 @@ +"""update_dataset_model_field_null_available + +Revision ID: 4bcffcd64aa4 +Revises: 853f9b9cd3b6 +Create Date: 2023-08-28 20:58:50.077056 + +""" +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = '4bcffcd64aa4' +down_revision = '853f9b9cd3b6' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('datasets', schema=None) as batch_op: + batch_op.alter_column('embedding_model', + existing_type=sa.VARCHAR(length=255), + nullable=True, + existing_server_default=sa.text("'text-embedding-ada-002'::character varying")) + batch_op.alter_column('embedding_model_provider', + existing_type=sa.VARCHAR(length=255), + nullable=True, + existing_server_default=sa.text("'openai'::character varying")) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('datasets', schema=None) as batch_op: + batch_op.alter_column('embedding_model_provider', + existing_type=sa.VARCHAR(length=255), + nullable=False, + existing_server_default=sa.text("'openai'::character varying")) + batch_op.alter_column('embedding_model', + existing_type=sa.VARCHAR(length=255), + nullable=False, + existing_server_default=sa.text("'text-embedding-ada-002'::character varying")) + + # ### end Alembic commands ### diff --git a/api/migrations/versions/4e99a8df00ff_add_load_balancing.py b/api/migrations/versions/4e99a8df00ff_add_load_balancing.py new file mode 100644 index 0000000000000000000000000000000000000000..3be4ba4f2a82e4e73c0a6ca7a1590b27f44f44ff --- /dev/null +++ b/api/migrations/versions/4e99a8df00ff_add_load_balancing.py @@ -0,0 +1,126 @@ +"""add load balancing + +Revision ID: 4e99a8df00ff +Revises: 47cc7df8c4f3 +Create Date: 2024-05-10 12:08:09.812736 + +""" +import sqlalchemy as sa +from alembic import op + +import models.types + +# revision identifiers, used by Alembic. +revision = '4e99a8df00ff' +down_revision = '64a70a7aab8b' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + op.create_table('load_balancing_model_configs', + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', models.types.StringUUID(), nullable=False), + sa.Column('provider_name', sa.String(length=255), nullable=False), + sa.Column('model_name', sa.String(length=255), nullable=False), + sa.Column('model_type', sa.String(length=40), nullable=False), + sa.Column('name', sa.String(length=255), nullable=False), + sa.Column('encrypted_config', sa.Text(), nullable=True), + sa.Column('enabled', sa.Boolean(), server_default=sa.text('true'), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='load_balancing_model_config_pkey') + ) + with op.batch_alter_table('load_balancing_model_configs', schema=None) as batch_op: + batch_op.create_index('load_balancing_model_config_tenant_provider_model_idx', ['tenant_id', 'provider_name', 'model_type'], unique=False) + + op.create_table('provider_model_settings', + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', models.types.StringUUID(), nullable=False), + sa.Column('provider_name', sa.String(length=255), nullable=False), + sa.Column('model_name', sa.String(length=255), nullable=False), + sa.Column('model_type', sa.String(length=40), nullable=False), + sa.Column('enabled', sa.Boolean(), server_default=sa.text('true'), nullable=False), + sa.Column('load_balancing_enabled', sa.Boolean(), server_default=sa.text('false'), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + 
sa.PrimaryKeyConstraint('id', name='provider_model_setting_pkey') + ) + with op.batch_alter_table('provider_model_settings', schema=None) as batch_op: + batch_op.create_index('provider_model_setting_tenant_provider_model_idx', ['tenant_id', 'provider_name', 'model_type'], unique=False) + + with op.batch_alter_table('provider_models', schema=None) as batch_op: + batch_op.alter_column('provider_name', + existing_type=sa.VARCHAR(length=40), + type_=sa.String(length=255), + existing_nullable=False) + + with op.batch_alter_table('provider_orders', schema=None) as batch_op: + batch_op.alter_column('provider_name', + existing_type=sa.VARCHAR(length=40), + type_=sa.String(length=255), + existing_nullable=False) + + with op.batch_alter_table('providers', schema=None) as batch_op: + batch_op.alter_column('provider_name', + existing_type=sa.VARCHAR(length=40), + type_=sa.String(length=255), + existing_nullable=False) + + with op.batch_alter_table('tenant_default_models', schema=None) as batch_op: + batch_op.alter_column('provider_name', + existing_type=sa.VARCHAR(length=40), + type_=sa.String(length=255), + existing_nullable=False) + + with op.batch_alter_table('tenant_preferred_model_providers', schema=None) as batch_op: + batch_op.alter_column('provider_name', + existing_type=sa.VARCHAR(length=40), + type_=sa.String(length=255), + existing_nullable=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('tenant_preferred_model_providers', schema=None) as batch_op: + batch_op.alter_column('provider_name', + existing_type=sa.String(length=255), + type_=sa.VARCHAR(length=40), + existing_nullable=False) + + with op.batch_alter_table('tenant_default_models', schema=None) as batch_op: + batch_op.alter_column('provider_name', + existing_type=sa.String(length=255), + type_=sa.VARCHAR(length=40), + existing_nullable=False) + + with op.batch_alter_table('providers', schema=None) as batch_op: + batch_op.alter_column('provider_name', + existing_type=sa.String(length=255), + type_=sa.VARCHAR(length=40), + existing_nullable=False) + + with op.batch_alter_table('provider_orders', schema=None) as batch_op: + batch_op.alter_column('provider_name', + existing_type=sa.String(length=255), + type_=sa.VARCHAR(length=40), + existing_nullable=False) + + with op.batch_alter_table('provider_models', schema=None) as batch_op: + batch_op.alter_column('provider_name', + existing_type=sa.String(length=255), + type_=sa.VARCHAR(length=40), + existing_nullable=False) + + with op.batch_alter_table('provider_model_settings', schema=None) as batch_op: + batch_op.drop_index('provider_model_setting_tenant_provider_model_idx') + + op.drop_table('provider_model_settings') + with op.batch_alter_table('load_balancing_model_configs', schema=None) as batch_op: + batch_op.drop_index('load_balancing_model_config_tenant_provider_model_idx') + + op.drop_table('load_balancing_model_configs') + # ### end Alembic commands ### diff --git a/api/migrations/versions/4ff534e1eb11_add_workflow_to_site.py b/api/migrations/versions/4ff534e1eb11_add_workflow_to_site.py new file mode 100644 index 0000000000000000000000000000000000000000..c09cf2af60cdffccdb2bab1331943e5a50c82b76 --- /dev/null +++ b/api/migrations/versions/4ff534e1eb11_add_workflow_to_site.py @@ -0,0 +1,33 @@ +"""add workflow to site + +Revision ID: 4ff534e1eb11 +Revises: 7b45942e39bb +Create Date: 2024-06-21 04:16:03.419634 + 
+""" +import sqlalchemy as sa +from alembic import op + +import models as models + +# revision identifiers, used by Alembic. +revision = '4ff534e1eb11' +down_revision = '7b45942e39bb' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('sites', schema=None) as batch_op: + batch_op.add_column(sa.Column('show_workflow_steps', sa.Boolean(), server_default=sa.text('true'), nullable=False)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('sites', schema=None) as batch_op: + batch_op.drop_column('show_workflow_steps') + + # ### end Alembic commands ### diff --git a/api/migrations/versions/5022897aaceb_add_model_name_in_embedding.py b/api/migrations/versions/5022897aaceb_add_model_name_in_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..c0f4af5a00a2a82ca52eb3268335d42f9151e7e4 --- /dev/null +++ b/api/migrations/versions/5022897aaceb_add_model_name_in_embedding.py @@ -0,0 +1,35 @@ +"""add model name in embedding + +Revision ID: 5022897aaceb +Revises: bf0aec5ba2cf +Create Date: 2023-08-11 14:38:15.499460 + +""" +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = '5022897aaceb' +down_revision = 'bf0aec5ba2cf' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('embeddings', schema=None) as batch_op: + batch_op.add_column(sa.Column('model_name', sa.String(length=40), server_default=sa.text("'text-embedding-ada-002'::character varying"), nullable=False)) + batch_op.drop_constraint('embedding_hash_idx', type_='unique') + batch_op.create_unique_constraint('embedding_hash_idx', ['model_name', 'hash']) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('embeddings', schema=None) as batch_op: + batch_op.drop_constraint('embedding_hash_idx', type_='unique') + batch_op.create_unique_constraint('embedding_hash_idx', ['hash']) + batch_op.drop_column('model_name') + + # ### end Alembic commands ### diff --git a/api/migrations/versions/53bf8af60645_update_model.py b/api/migrations/versions/53bf8af60645_update_model.py new file mode 100644 index 0000000000000000000000000000000000000000..3d0928d013dba45813e64da97714c0268027341b --- /dev/null +++ b/api/migrations/versions/53bf8af60645_update_model.py @@ -0,0 +1,41 @@ +"""update model + +Revision ID: 53bf8af60645 +Revises: 8e5588e6412e +Create Date: 2024-07-24 08:06:55.291031 + +""" +import sqlalchemy as sa +from alembic import op + +import models as models + +# revision identifiers, used by Alembic. +revision = '53bf8af60645' +down_revision = '8e5588e6412e' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('embeddings', schema=None) as batch_op: + batch_op.alter_column('provider_name', + existing_type=sa.VARCHAR(length=40), + type_=sa.String(length=255), + existing_nullable=False, + existing_server_default=sa.text("''::character varying")) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('embeddings', schema=None) as batch_op: + batch_op.alter_column('provider_name', + existing_type=sa.String(length=255), + type_=sa.VARCHAR(length=40), + existing_nullable=False, + existing_server_default=sa.text("''::character varying")) + + # ### end Alembic commands ### diff --git a/api/migrations/versions/563cf8bf777b_enable_tool_file_without_conversation_id.py b/api/migrations/versions/563cf8bf777b_enable_tool_file_without_conversation_id.py new file mode 100644 index 0000000000000000000000000000000000000000..299f442de989be9449f5a8467e9cf4ba9a2156d6 --- /dev/null +++ b/api/migrations/versions/563cf8bf777b_enable_tool_file_without_conversation_id.py @@ -0,0 +1,35 @@ +"""enable tool file without conversation id + +Revision ID: 563cf8bf777b +Revises: b5429b71023c +Create Date: 2024-03-14 04:54:56.679506 + +""" +from alembic import op +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '563cf8bf777b' +down_revision = 'b5429b71023c' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('tool_files', schema=None) as batch_op: + batch_op.alter_column('conversation_id', + existing_type=postgresql.UUID(), + nullable=True) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('tool_files', schema=None) as batch_op: + batch_op.alter_column('conversation_id', + existing_type=postgresql.UUID(), + nullable=False) + + # ### end Alembic commands ### diff --git a/api/migrations/versions/5fda94355fce_custom_disclaimer.py b/api/migrations/versions/5fda94355fce_custom_disclaimer.py new file mode 100644 index 0000000000000000000000000000000000000000..73bcdc4500041ae80162d13a17314669ba0b9cda --- /dev/null +++ b/api/migrations/versions/5fda94355fce_custom_disclaimer.py @@ -0,0 +1,45 @@ +"""Custom Disclaimer + +Revision ID: 5fda94355fce +Revises: 47cc7df8c4f3 +Create Date: 2024-05-10 20:04:45.806549 + +""" +import sqlalchemy as sa +from alembic import op + +import models as models + +# revision identifiers, used by Alembic. +revision = '5fda94355fce' +down_revision = '47cc7df8c4f3' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('recommended_apps', schema=None) as batch_op: + batch_op.add_column(sa.Column('custom_disclaimer', sa.String(length=255), nullable=True)) + + with op.batch_alter_table('sites', schema=None) as batch_op: + batch_op.add_column(sa.Column('custom_disclaimer', sa.String(length=255), nullable=True)) + + with op.batch_alter_table('tool_api_providers', schema=None) as batch_op: + batch_op.add_column(sa.Column('custom_disclaimer', sa.String(length=255), nullable=True)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('tool_api_providers', schema=None) as batch_op: + batch_op.drop_column('custom_disclaimer') + + with op.batch_alter_table('sites', schema=None) as batch_op: + batch_op.drop_column('custom_disclaimer') + + with op.batch_alter_table('recommended_apps', schema=None) as batch_op: + batch_op.drop_column('custom_disclaimer') + + # ### end Alembic commands ### diff --git a/api/migrations/versions/614f77cecc48_add_last_active_at.py b/api/migrations/versions/614f77cecc48_add_last_active_at.py new file mode 100644 index 0000000000000000000000000000000000000000..182f8f89f19b307a681de7bdb757ce3b39f0a0cc --- /dev/null +++ b/api/migrations/versions/614f77cecc48_add_last_active_at.py @@ -0,0 +1,31 @@ +"""add last active at + +Revision ID: 614f77cecc48 +Revises: a45f4dfde53b +Create Date: 2023-06-15 13:33:00.357467 + +""" +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = '614f77cecc48' +down_revision = 'a45f4dfde53b' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('accounts', schema=None) as batch_op: + batch_op.add_column(sa.Column('last_active_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('accounts', schema=None) as batch_op: + batch_op.drop_column('last_active_at') + + # ### end Alembic commands ### diff --git a/api/migrations/versions/63f9175e515b_merge_branches.py b/api/migrations/versions/63f9175e515b_merge_branches.py new file mode 100644 index 0000000000000000000000000000000000000000..062365994195fa4e83b0806f1f5929457c1ed8fe --- /dev/null +++ b/api/migrations/versions/63f9175e515b_merge_branches.py @@ -0,0 +1,22 @@ +"""merge branches + +Revision ID: 63f9175e515b +Revises: 2a3aebbbf4bb, b69ca54b9208 +Create Date: 2024-06-26 09:46:36.573505 + +""" +import models as models + +# revision identifiers, used by Alembic. +revision = '63f9175e515b' +down_revision = ('2a3aebbbf4bb', 'b69ca54b9208') +branch_labels = None +depends_on = None + + +def upgrade(): + pass + + +def downgrade(): + pass diff --git a/api/migrations/versions/64a70a7aab8b_add_workflow_run_index.py b/api/migrations/versions/64a70a7aab8b_add_workflow_run_index.py new file mode 100644 index 0000000000000000000000000000000000000000..73242908f4ae8db52d6f71a6d46297c7d3e38e61 --- /dev/null +++ b/api/migrations/versions/64a70a7aab8b_add_workflow_run_index.py @@ -0,0 +1,32 @@ +"""add workflow run index + +Revision ID: 64a70a7aab8b +Revises: 03f98355ba0e +Create Date: 2024-05-28 12:32:00.276061 + +""" +from alembic import op + +import models as models + +# revision identifiers, used by Alembic. +revision = '64a70a7aab8b' +down_revision = '03f98355ba0e' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('workflow_runs', schema=None) as batch_op: + batch_op.create_index('workflow_run_tenant_app_sequence_idx', ['tenant_id', 'app_id', 'sequence_number'], unique=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('workflow_runs', schema=None) as batch_op: + batch_op.drop_index('workflow_run_tenant_app_sequence_idx') + + # ### end Alembic commands ### diff --git a/api/migrations/versions/64b051264f32_init.py b/api/migrations/versions/64b051264f32_init.py new file mode 100644 index 0000000000000000000000000000000000000000..8c45ae898dd0623445f3e42bb8c59152dc7a1102 --- /dev/null +++ b/api/migrations/versions/64b051264f32_init.py @@ -0,0 +1,797 @@ +"""init + +Revision ID: 64b051264f32 +Revises: +Create Date: 2023-05-13 14:26:59.085018 + +""" +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '64b051264f32' +down_revision = None +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.execute('CREATE EXTENSION IF NOT EXISTS "uuid-ossp";') + + op.create_table('account_integrates', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('account_id', postgresql.UUID(), nullable=False), + sa.Column('provider', sa.String(length=16), nullable=False), + sa.Column('open_id', sa.String(length=255), nullable=False), + sa.Column('encrypted_token', sa.String(length=255), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='account_integrate_pkey'), + sa.UniqueConstraint('account_id', 'provider', name='unique_account_provider'), + sa.UniqueConstraint('provider', 'open_id', name='unique_provider_open_id') + ) + op.create_table('accounts', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('name', sa.String(length=255), nullable=False), + sa.Column('email', 
sa.String(length=255), nullable=False), + sa.Column('password', sa.String(length=255), nullable=True), + sa.Column('password_salt', sa.String(length=255), nullable=True), + sa.Column('avatar', sa.String(length=255), nullable=True), + sa.Column('interface_language', sa.String(length=255), nullable=True), + sa.Column('interface_theme', sa.String(length=255), nullable=True), + sa.Column('timezone', sa.String(length=255), nullable=True), + sa.Column('last_login_at', sa.DateTime(), nullable=True), + sa.Column('last_login_ip', sa.String(length=255), nullable=True), + sa.Column('status', sa.String(length=16), server_default=sa.text("'active'::character varying"), nullable=False), + sa.Column('initialized_at', sa.DateTime(), nullable=True), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='account_pkey') + ) + with op.batch_alter_table('accounts', schema=None) as batch_op: + batch_op.create_index('account_email_idx', ['email'], unique=False) + + op.create_table('api_requests', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', postgresql.UUID(), nullable=False), + sa.Column('api_token_id', postgresql.UUID(), nullable=False), + sa.Column('path', sa.String(length=255), nullable=False), + sa.Column('request', sa.Text(), nullable=True), + sa.Column('response', sa.Text(), nullable=True), + sa.Column('ip', sa.String(length=255), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='api_request_pkey') + ) + with op.batch_alter_table('api_requests', schema=None) as batch_op: + batch_op.create_index('api_request_token_idx', ['tenant_id', 'api_token_id'], unique=False) + + op.create_table('api_tokens', + 
sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('app_id', postgresql.UUID(), nullable=True), + sa.Column('dataset_id', postgresql.UUID(), nullable=True), + sa.Column('type', sa.String(length=16), nullable=False), + sa.Column('token', sa.String(length=255), nullable=False), + sa.Column('last_used_at', sa.DateTime(), nullable=True), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='api_token_pkey') + ) + with op.batch_alter_table('api_tokens', schema=None) as batch_op: + batch_op.create_index('api_token_app_id_type_idx', ['app_id', 'type'], unique=False) + batch_op.create_index('api_token_token_idx', ['token', 'type'], unique=False) + + op.create_table('app_dataset_joins', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('app_id', postgresql.UUID(), nullable=False), + sa.Column('dataset_id', postgresql.UUID(), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.PrimaryKeyConstraint('id', name='app_dataset_join_pkey') + ) + with op.batch_alter_table('app_dataset_joins', schema=None) as batch_op: + batch_op.create_index('app_dataset_join_app_dataset_idx', ['dataset_id', 'app_id'], unique=False) + + op.create_table('app_model_configs', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('app_id', postgresql.UUID(), nullable=False), + sa.Column('provider', sa.String(length=255), nullable=False), + sa.Column('model_id', sa.String(length=255), nullable=False), + sa.Column('configs', sa.JSON(), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + 
sa.Column('opening_statement', sa.Text(), nullable=True), + sa.Column('suggested_questions', sa.Text(), nullable=True), + sa.Column('suggested_questions_after_answer', sa.Text(), nullable=True), + sa.Column('more_like_this', sa.Text(), nullable=True), + sa.Column('model', sa.Text(), nullable=True), + sa.Column('user_input_form', sa.Text(), nullable=True), + sa.Column('pre_prompt', sa.Text(), nullable=True), + sa.Column('agent_mode', sa.Text(), nullable=True), + sa.PrimaryKeyConstraint('id', name='app_model_config_pkey') + ) + with op.batch_alter_table('app_model_configs', schema=None) as batch_op: + batch_op.create_index('app_app_id_idx', ['app_id'], unique=False) + + op.create_table('apps', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', postgresql.UUID(), nullable=False), + sa.Column('name', sa.String(length=255), nullable=False), + sa.Column('mode', sa.String(length=255), nullable=False), + sa.Column('icon', sa.String(length=255), nullable=True), + sa.Column('icon_background', sa.String(length=255), nullable=True), + sa.Column('app_model_config_id', postgresql.UUID(), nullable=True), + sa.Column('status', sa.String(length=255), server_default=sa.text("'normal'::character varying"), nullable=False), + sa.Column('enable_site', sa.Boolean(), nullable=False), + sa.Column('enable_api', sa.Boolean(), nullable=False), + sa.Column('api_rpm', sa.Integer(), nullable=False), + sa.Column('api_rph', sa.Integer(), nullable=False), + sa.Column('is_demo', sa.Boolean(), server_default=sa.text('false'), nullable=False), + sa.Column('is_public', sa.Boolean(), server_default=sa.text('false'), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='app_pkey') + ) + with op.batch_alter_table('apps', 
schema=None) as batch_op: + batch_op.create_index('app_tenant_id_idx', ['tenant_id'], unique=False) + + op.execute('CREATE SEQUENCE task_id_sequence;') + op.execute('CREATE SEQUENCE taskset_id_sequence;') + + op.create_table('celery_taskmeta', + sa.Column('id', sa.Integer(), nullable=False, + server_default=sa.text('nextval(\'task_id_sequence\')')), + sa.Column('task_id', sa.String(length=155), nullable=True), + sa.Column('status', sa.String(length=50), nullable=True), + sa.Column('result', sa.PickleType(), nullable=True), + sa.Column('date_done', sa.DateTime(), nullable=True), + sa.Column('traceback', sa.Text(), nullable=True), + sa.Column('name', sa.String(length=155), nullable=True), + sa.Column('args', sa.LargeBinary(), nullable=True), + sa.Column('kwargs', sa.LargeBinary(), nullable=True), + sa.Column('worker', sa.String(length=155), nullable=True), + sa.Column('retries', sa.Integer(), nullable=True), + sa.Column('queue', sa.String(length=155), nullable=True), + sa.PrimaryKeyConstraint('id'), + sa.UniqueConstraint('task_id') + ) + op.create_table('celery_tasksetmeta', + sa.Column('id', sa.Integer(), nullable=False, + server_default=sa.text('nextval(\'taskset_id_sequence\')')), + sa.Column('taskset_id', sa.String(length=155), nullable=True), + sa.Column('result', sa.PickleType(), nullable=True), + sa.Column('date_done', sa.DateTime(), nullable=True), + sa.PrimaryKeyConstraint('id'), + sa.UniqueConstraint('taskset_id') + ) + op.create_table('conversations', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('app_id', postgresql.UUID(), nullable=False), + sa.Column('app_model_config_id', postgresql.UUID(), nullable=False), + sa.Column('model_provider', sa.String(length=255), nullable=False), + sa.Column('override_model_configs', sa.Text(), nullable=True), + sa.Column('model_id', sa.String(length=255), nullable=False), + sa.Column('mode', sa.String(length=255), nullable=False), + sa.Column('name', 
sa.String(length=255), nullable=False), + sa.Column('summary', sa.Text(), nullable=True), + sa.Column('inputs', sa.JSON(), nullable=True), + sa.Column('introduction', sa.Text(), nullable=True), + sa.Column('system_instruction', sa.Text(), nullable=True), + sa.Column('system_instruction_tokens', sa.Integer(), server_default=sa.text('0'), nullable=False), + sa.Column('status', sa.String(length=255), nullable=False), + sa.Column('from_source', sa.String(length=255), nullable=False), + sa.Column('from_end_user_id', postgresql.UUID(), nullable=True), + sa.Column('from_account_id', postgresql.UUID(), nullable=True), + sa.Column('read_at', sa.DateTime(), nullable=True), + sa.Column('read_account_id', postgresql.UUID(), nullable=True), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='conversation_pkey') + ) + with op.batch_alter_table('conversations', schema=None) as batch_op: + batch_op.create_index('conversation_app_from_user_idx', ['app_id', 'from_source', 'from_end_user_id'], unique=False) + + op.create_table('dataset_keyword_tables', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('dataset_id', postgresql.UUID(), nullable=False), + sa.Column('keyword_table', sa.Text(), nullable=False), + sa.PrimaryKeyConstraint('id', name='dataset_keyword_table_pkey'), + sa.UniqueConstraint('dataset_id') + ) + with op.batch_alter_table('dataset_keyword_tables', schema=None) as batch_op: + batch_op.create_index('dataset_keyword_table_dataset_id_idx', ['dataset_id'], unique=False) + + op.create_table('dataset_process_rules', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('dataset_id', postgresql.UUID(), nullable=False), + sa.Column('mode', sa.String(length=255), 
server_default=sa.text("'automatic'::character varying"), nullable=False), + sa.Column('rules', sa.Text(), nullable=True), + sa.Column('created_by', postgresql.UUID(), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='dataset_process_rule_pkey') + ) + with op.batch_alter_table('dataset_process_rules', schema=None) as batch_op: + batch_op.create_index('dataset_process_rule_dataset_id_idx', ['dataset_id'], unique=False) + + op.create_table('dataset_queries', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('dataset_id', postgresql.UUID(), nullable=False), + sa.Column('content', sa.Text(), nullable=False), + sa.Column('source', sa.String(length=255), nullable=False), + sa.Column('source_app_id', postgresql.UUID(), nullable=True), + sa.Column('created_by_role', sa.String(), nullable=False), + sa.Column('created_by', postgresql.UUID(), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.PrimaryKeyConstraint('id', name='dataset_query_pkey') + ) + with op.batch_alter_table('dataset_queries', schema=None) as batch_op: + batch_op.create_index('dataset_query_dataset_id_idx', ['dataset_id'], unique=False) + + op.create_table('datasets', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', postgresql.UUID(), nullable=False), + sa.Column('name', sa.String(length=255), nullable=False), + sa.Column('description', sa.Text(), nullable=True), + sa.Column('provider', sa.String(length=255), server_default=sa.text("'vendor'::character varying"), nullable=False), + sa.Column('permission', sa.String(length=255), server_default=sa.text("'only_me'::character varying"), nullable=False), + sa.Column('data_source_type', sa.String(length=255), nullable=True), + 
sa.Column('indexing_technique', sa.String(length=255), nullable=True), + sa.Column('index_struct', sa.Text(), nullable=True), + sa.Column('created_by', postgresql.UUID(), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('updated_by', postgresql.UUID(), nullable=True), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='dataset_pkey') + ) + with op.batch_alter_table('datasets', schema=None) as batch_op: + batch_op.create_index('dataset_tenant_idx', ['tenant_id'], unique=False) + + op.create_table('dify_setups', + sa.Column('version', sa.String(length=255), nullable=False), + sa.Column('setup_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('version', name='dify_setup_pkey') + ) + op.create_table('document_segments', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', postgresql.UUID(), nullable=False), + sa.Column('dataset_id', postgresql.UUID(), nullable=False), + sa.Column('document_id', postgresql.UUID(), nullable=False), + sa.Column('position', sa.Integer(), nullable=False), + sa.Column('content', sa.Text(), nullable=False), + sa.Column('word_count', sa.Integer(), nullable=False), + sa.Column('tokens', sa.Integer(), nullable=False), + sa.Column('keywords', sa.JSON(), nullable=True), + sa.Column('index_node_id', sa.String(length=255), nullable=True), + sa.Column('index_node_hash', sa.String(length=255), nullable=True), + sa.Column('hit_count', sa.Integer(), nullable=False), + sa.Column('enabled', sa.Boolean(), server_default=sa.text('true'), nullable=False), + sa.Column('disabled_at', sa.DateTime(), nullable=True), + sa.Column('disabled_by', postgresql.UUID(), nullable=True), + sa.Column('status', sa.String(length=255), 
server_default=sa.text("'waiting'::character varying"), nullable=False), + sa.Column('created_by', postgresql.UUID(), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('indexing_at', sa.DateTime(), nullable=True), + sa.Column('completed_at', sa.DateTime(), nullable=True), + sa.Column('error', sa.Text(), nullable=True), + sa.Column('stopped_at', sa.DateTime(), nullable=True), + sa.PrimaryKeyConstraint('id', name='document_segment_pkey') + ) + with op.batch_alter_table('document_segments', schema=None) as batch_op: + batch_op.create_index('document_segment_dataset_id_idx', ['dataset_id'], unique=False) + batch_op.create_index('document_segment_dataset_node_idx', ['dataset_id', 'index_node_id'], unique=False) + batch_op.create_index('document_segment_document_id_idx', ['document_id'], unique=False) + batch_op.create_index('document_segment_tenant_dataset_idx', ['dataset_id', 'tenant_id'], unique=False) + batch_op.create_index('document_segment_tenant_document_idx', ['document_id', 'tenant_id'], unique=False) + + op.create_table('documents', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', postgresql.UUID(), nullable=False), + sa.Column('dataset_id', postgresql.UUID(), nullable=False), + sa.Column('position', sa.Integer(), nullable=False), + sa.Column('data_source_type', sa.String(length=255), nullable=False), + sa.Column('data_source_info', sa.Text(), nullable=True), + sa.Column('dataset_process_rule_id', postgresql.UUID(), nullable=True), + sa.Column('batch', sa.String(length=255), nullable=False), + sa.Column('name', sa.String(length=255), nullable=False), + sa.Column('created_from', sa.String(length=255), nullable=False), + sa.Column('created_by', postgresql.UUID(), nullable=False), + sa.Column('created_api_request_id', postgresql.UUID(), nullable=True), + sa.Column('created_at', sa.DateTime(), 
server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('processing_started_at', sa.DateTime(), nullable=True), + sa.Column('file_id', sa.Text(), nullable=True), + sa.Column('word_count', sa.Integer(), nullable=True), + sa.Column('parsing_completed_at', sa.DateTime(), nullable=True), + sa.Column('cleaning_completed_at', sa.DateTime(), nullable=True), + sa.Column('splitting_completed_at', sa.DateTime(), nullable=True), + sa.Column('tokens', sa.Integer(), nullable=True), + sa.Column('indexing_latency', sa.Float(), nullable=True), + sa.Column('completed_at', sa.DateTime(), nullable=True), + sa.Column('is_paused', sa.Boolean(), server_default=sa.text('false'), nullable=True), + sa.Column('paused_by', postgresql.UUID(), nullable=True), + sa.Column('paused_at', sa.DateTime(), nullable=True), + sa.Column('error', sa.Text(), nullable=True), + sa.Column('stopped_at', sa.DateTime(), nullable=True), + sa.Column('indexing_status', sa.String(length=255), server_default=sa.text("'waiting'::character varying"), nullable=False), + sa.Column('enabled', sa.Boolean(), server_default=sa.text('true'), nullable=False), + sa.Column('disabled_at', sa.DateTime(), nullable=True), + sa.Column('disabled_by', postgresql.UUID(), nullable=True), + sa.Column('archived', sa.Boolean(), server_default=sa.text('false'), nullable=False), + sa.Column('archived_reason', sa.String(length=255), nullable=True), + sa.Column('archived_by', postgresql.UUID(), nullable=True), + sa.Column('archived_at', sa.DateTime(), nullable=True), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('doc_type', sa.String(length=40), nullable=True), + sa.Column('doc_metadata', sa.JSON(), nullable=True), + sa.PrimaryKeyConstraint('id', name='document_pkey') + ) + with op.batch_alter_table('documents', schema=None) as batch_op: + batch_op.create_index('document_dataset_id_idx', ['dataset_id'], unique=False) + 
batch_op.create_index('document_is_paused_idx', ['is_paused'], unique=False) + + op.create_table('embeddings', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('hash', sa.String(length=64), nullable=False), + sa.Column('embedding', sa.LargeBinary(), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='embedding_pkey'), + sa.UniqueConstraint('hash', name='embedding_hash_idx') + ) + op.create_table('end_users', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', postgresql.UUID(), nullable=False), + sa.Column('app_id', postgresql.UUID(), nullable=True), + sa.Column('type', sa.String(length=255), nullable=False), + sa.Column('external_user_id', sa.String(length=255), nullable=True), + sa.Column('name', sa.String(length=255), nullable=True), + sa.Column('is_anonymous', sa.Boolean(), server_default=sa.text('true'), nullable=False), + sa.Column('session_id', sa.String(length=255), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='end_user_pkey') + ) + with op.batch_alter_table('end_users', schema=None) as batch_op: + batch_op.create_index('end_user_session_id_idx', ['session_id', 'type'], unique=False) + batch_op.create_index('end_user_tenant_session_id_idx', ['tenant_id', 'session_id', 'type'], unique=False) + + op.create_table('installed_apps', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', postgresql.UUID(), nullable=False), + sa.Column('app_id', postgresql.UUID(), nullable=False), + sa.Column('app_owner_tenant_id', postgresql.UUID(), 
nullable=False), + sa.Column('position', sa.Integer(), nullable=False), + sa.Column('is_pinned', sa.Boolean(), server_default=sa.text('false'), nullable=False), + sa.Column('last_used_at', sa.DateTime(), nullable=True), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='installed_app_pkey'), + sa.UniqueConstraint('tenant_id', 'app_id', name='unique_tenant_app') + ) + with op.batch_alter_table('installed_apps', schema=None) as batch_op: + batch_op.create_index('installed_app_app_id_idx', ['app_id'], unique=False) + batch_op.create_index('installed_app_tenant_id_idx', ['tenant_id'], unique=False) + + op.create_table('invitation_codes', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('batch', sa.String(length=255), nullable=False), + sa.Column('code', sa.String(length=32), nullable=False), + sa.Column('status', sa.String(length=16), server_default=sa.text("'unused'::character varying"), nullable=False), + sa.Column('used_at', sa.DateTime(), nullable=True), + sa.Column('used_by_tenant_id', postgresql.UUID(), nullable=True), + sa.Column('used_by_account_id', postgresql.UUID(), nullable=True), + sa.Column('deprecated_at', sa.DateTime(), nullable=True), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='invitation_code_pkey') + ) + with op.batch_alter_table('invitation_codes', schema=None) as batch_op: + batch_op.create_index('invitation_codes_batch_idx', ['batch'], unique=False) + batch_op.create_index('invitation_codes_code_idx', ['code', 'status'], unique=False) + + op.create_table('message_agent_thoughts', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('message_id', postgresql.UUID(), nullable=False), + sa.Column('message_chain_id', postgresql.UUID(), nullable=False), + sa.Column('position', sa.Integer(), 
nullable=False), + sa.Column('thought', sa.Text(), nullable=True), + sa.Column('tool', sa.Text(), nullable=True), + sa.Column('tool_input', sa.Text(), nullable=True), + sa.Column('observation', sa.Text(), nullable=True), + sa.Column('tool_process_data', sa.Text(), nullable=True), + sa.Column('message', sa.Text(), nullable=True), + sa.Column('message_token', sa.Integer(), nullable=True), + sa.Column('message_unit_price', sa.Numeric(), nullable=True), + sa.Column('answer', sa.Text(), nullable=True), + sa.Column('answer_token', sa.Integer(), nullable=True), + sa.Column('answer_unit_price', sa.Numeric(), nullable=True), + sa.Column('tokens', sa.Integer(), nullable=True), + sa.Column('total_price', sa.Numeric(), nullable=True), + sa.Column('currency', sa.String(), nullable=True), + sa.Column('latency', sa.Float(), nullable=True), + sa.Column('created_by_role', sa.String(), nullable=False), + sa.Column('created_by', postgresql.UUID(), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.PrimaryKeyConstraint('id', name='message_agent_thought_pkey') + ) + with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op: + batch_op.create_index('message_agent_thought_message_chain_id_idx', ['message_chain_id'], unique=False) + batch_op.create_index('message_agent_thought_message_id_idx', ['message_id'], unique=False) + + op.create_table('message_chains', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('message_id', postgresql.UUID(), nullable=False), + sa.Column('type', sa.String(length=255), nullable=False), + sa.Column('input', sa.Text(), nullable=True), + sa.Column('output', sa.Text(), nullable=True), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.PrimaryKeyConstraint('id', name='message_chain_pkey') + ) + with op.batch_alter_table('message_chains', schema=None) as 
batch_op: + batch_op.create_index('message_chain_message_id_idx', ['message_id'], unique=False) + + op.create_table('message_feedbacks', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('app_id', postgresql.UUID(), nullable=False), + sa.Column('conversation_id', postgresql.UUID(), nullable=False), + sa.Column('message_id', postgresql.UUID(), nullable=False), + sa.Column('rating', sa.String(length=255), nullable=False), + sa.Column('content', sa.Text(), nullable=True), + sa.Column('from_source', sa.String(length=255), nullable=False), + sa.Column('from_end_user_id', postgresql.UUID(), nullable=True), + sa.Column('from_account_id', postgresql.UUID(), nullable=True), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='message_feedback_pkey') + ) + with op.batch_alter_table('message_feedbacks', schema=None) as batch_op: + batch_op.create_index('message_feedback_app_idx', ['app_id'], unique=False) + batch_op.create_index('message_feedback_conversation_idx', ['conversation_id', 'from_source', 'rating'], unique=False) + batch_op.create_index('message_feedback_message_idx', ['message_id', 'from_source'], unique=False) + + op.create_table('operation_logs', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', postgresql.UUID(), nullable=False), + sa.Column('account_id', postgresql.UUID(), nullable=False), + sa.Column('action', sa.String(length=255), nullable=False), + sa.Column('content', sa.JSON(), nullable=True), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('created_ip', sa.String(length=255), nullable=False), + sa.Column('updated_at', sa.DateTime(), 
server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='operation_log_pkey') + ) + with op.batch_alter_table('operation_logs', schema=None) as batch_op: + batch_op.create_index('operation_log_account_action_idx', ['tenant_id', 'account_id', 'action'], unique=False) + + op.create_table('pinned_conversations', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('app_id', postgresql.UUID(), nullable=False), + sa.Column('conversation_id', postgresql.UUID(), nullable=False), + sa.Column('created_by', postgresql.UUID(), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='pinned_conversation_pkey') + ) + with op.batch_alter_table('pinned_conversations', schema=None) as batch_op: + batch_op.create_index('pinned_conversation_conversation_idx', ['app_id', 'conversation_id', 'created_by'], unique=False) + + op.create_table('providers', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', postgresql.UUID(), nullable=False), + sa.Column('provider_name', sa.String(length=40), nullable=False), + sa.Column('provider_type', sa.String(length=40), nullable=False, server_default=sa.text("'custom'::character varying")), + sa.Column('encrypted_config', sa.Text(), nullable=True), + sa.Column('is_valid', sa.Boolean(), server_default=sa.text('false'), nullable=False), + sa.Column('last_used', sa.DateTime(), nullable=True), + sa.Column('quota_type', sa.String(length=40), nullable=True, server_default=sa.text("''::character varying")), + sa.Column('quota_limit', sa.Integer(), nullable=True), + sa.Column('quota_used', sa.Integer(), nullable=True), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('updated_at', sa.DateTime(), 
server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='provider_pkey'), + sa.UniqueConstraint('tenant_id', 'provider_name', 'provider_type', 'quota_type', name='unique_provider_name_type_quota') + ) + with op.batch_alter_table('providers', schema=None) as batch_op: + batch_op.create_index('provider_tenant_id_provider_idx', ['tenant_id', 'provider_name'], unique=False) + + op.create_table('recommended_apps', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('app_id', postgresql.UUID(), nullable=False), + sa.Column('description', sa.JSON(), nullable=False), + sa.Column('copyright', sa.String(length=255), nullable=False), + sa.Column('privacy_policy', sa.String(length=255), nullable=False), + sa.Column('category', sa.String(length=255), nullable=False), + sa.Column('position', sa.Integer(), nullable=False), + sa.Column('is_listed', sa.Boolean(), nullable=False), + sa.Column('install_count', sa.Integer(), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='recommended_app_pkey') + ) + with op.batch_alter_table('recommended_apps', schema=None) as batch_op: + batch_op.create_index('recommended_app_app_id_idx', ['app_id'], unique=False) + batch_op.create_index('recommended_app_is_listed_idx', ['is_listed'], unique=False) + + op.create_table('saved_messages', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('app_id', postgresql.UUID(), nullable=False), + sa.Column('message_id', postgresql.UUID(), nullable=False), + sa.Column('created_by', postgresql.UUID(), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + 
sa.PrimaryKeyConstraint('id', name='saved_message_pkey') + ) + with op.batch_alter_table('saved_messages', schema=None) as batch_op: + batch_op.create_index('saved_message_message_idx', ['app_id', 'message_id', 'created_by'], unique=False) + + op.create_table('sessions', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('session_id', sa.String(length=255), nullable=True), + sa.Column('data', sa.LargeBinary(), nullable=True), + sa.Column('expiry', sa.DateTime(), nullable=True), + sa.PrimaryKeyConstraint('id'), + sa.UniqueConstraint('session_id') + ) + op.create_table('sites', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('app_id', postgresql.UUID(), nullable=False), + sa.Column('title', sa.String(length=255), nullable=False), + sa.Column('icon', sa.String(length=255), nullable=True), + sa.Column('icon_background', sa.String(length=255), nullable=True), + sa.Column('description', sa.String(length=255), nullable=True), + sa.Column('default_language', sa.String(length=255), nullable=False), + sa.Column('copyright', sa.String(length=255), nullable=True), + sa.Column('privacy_policy', sa.String(length=255), nullable=True), + sa.Column('customize_domain', sa.String(length=255), nullable=True), + sa.Column('customize_token_strategy', sa.String(length=255), nullable=False), + sa.Column('prompt_public', sa.Boolean(), server_default=sa.text('false'), nullable=False), + sa.Column('status', sa.String(length=255), server_default=sa.text("'normal'::character varying"), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('code', sa.String(length=255), nullable=True), + sa.PrimaryKeyConstraint('id', name='site_pkey') + ) + with op.batch_alter_table('sites', schema=None) as batch_op: + 
batch_op.create_index('site_app_id_idx', ['app_id'], unique=False) + batch_op.create_index('site_code_idx', ['code', 'status'], unique=False) + + op.create_table('tenant_account_joins', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', postgresql.UUID(), nullable=False), + sa.Column('account_id', postgresql.UUID(), nullable=False), + sa.Column('role', sa.String(length=16), server_default='normal', nullable=False), + sa.Column('invited_by', postgresql.UUID(), nullable=True), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='tenant_account_join_pkey'), + sa.UniqueConstraint('tenant_id', 'account_id', name='unique_tenant_account_join') + ) + with op.batch_alter_table('tenant_account_joins', schema=None) as batch_op: + batch_op.create_index('tenant_account_join_account_id_idx', ['account_id'], unique=False) + batch_op.create_index('tenant_account_join_tenant_id_idx', ['tenant_id'], unique=False) + + op.create_table('tenants', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('name', sa.String(length=255), nullable=False), + sa.Column('encrypt_public_key', sa.Text(), nullable=True), + sa.Column('plan', sa.String(length=255), server_default=sa.text("'basic'::character varying"), nullable=False), + sa.Column('status', sa.String(length=255), server_default=sa.text("'normal'::character varying"), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='tenant_pkey') + ) + op.create_table('upload_files', + sa.Column('id', 
postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', postgresql.UUID(), nullable=False), + sa.Column('storage_type', sa.String(length=255), nullable=False), + sa.Column('key', sa.String(length=255), nullable=False), + sa.Column('name', sa.String(length=255), nullable=False), + sa.Column('size', sa.Integer(), nullable=False), + sa.Column('extension', sa.String(length=255), nullable=False), + sa.Column('mime_type', sa.String(length=255), nullable=True), + sa.Column('created_by', postgresql.UUID(), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('used', sa.Boolean(), server_default=sa.text('false'), nullable=False), + sa.Column('used_by', postgresql.UUID(), nullable=True), + sa.Column('used_at', sa.DateTime(), nullable=True), + sa.Column('hash', sa.String(length=255), nullable=True), + sa.PrimaryKeyConstraint('id', name='upload_file_pkey') + ) + with op.batch_alter_table('upload_files', schema=None) as batch_op: + batch_op.create_index('upload_file_tenant_idx', ['tenant_id'], unique=False) + + op.create_table('message_annotations', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('app_id', postgresql.UUID(), nullable=False), + sa.Column('conversation_id', postgresql.UUID(), nullable=False), + sa.Column('message_id', postgresql.UUID(), nullable=False), + sa.Column('content', sa.Text(), nullable=False), + sa.Column('account_id', postgresql.UUID(), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='message_annotation_pkey') + ) + with op.batch_alter_table('message_annotations', schema=None) as batch_op: + batch_op.create_index('message_annotation_app_idx', 
['app_id'], unique=False) + batch_op.create_index('message_annotation_conversation_idx', ['conversation_id'], unique=False) + batch_op.create_index('message_annotation_message_idx', ['message_id'], unique=False) + + op.create_table('messages', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('app_id', postgresql.UUID(), nullable=False), + sa.Column('model_provider', sa.String(length=255), nullable=False), + sa.Column('model_id', sa.String(length=255), nullable=False), + sa.Column('override_model_configs', sa.Text(), nullable=True), + sa.Column('conversation_id', postgresql.UUID(), nullable=False), + sa.Column('inputs', sa.JSON(), nullable=True), + sa.Column('query', sa.Text(), nullable=False), + sa.Column('message', sa.JSON(), nullable=False), + sa.Column('message_tokens', sa.Integer(), server_default=sa.text('0'), nullable=False), + sa.Column('message_unit_price', sa.Numeric(precision=10, scale=4), nullable=False), + sa.Column('answer', sa.Text(), nullable=False), + sa.Column('answer_tokens', sa.Integer(), server_default=sa.text('0'), nullable=False), + sa.Column('answer_unit_price', sa.Numeric(precision=10, scale=4), nullable=False), + sa.Column('provider_response_latency', sa.Float(), server_default=sa.text('0'), nullable=False), + sa.Column('total_price', sa.Numeric(precision=10, scale=7), nullable=True), + sa.Column('currency', sa.String(length=255), nullable=False), + sa.Column('from_source', sa.String(length=255), nullable=False), + sa.Column('from_end_user_id', postgresql.UUID(), nullable=True), + sa.Column('from_account_id', postgresql.UUID(), nullable=True), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('agent_based', sa.Boolean(), server_default=sa.text('false'), nullable=False), + sa.PrimaryKeyConstraint('id', 
name='message_pkey') + ) + with op.batch_alter_table('messages', schema=None) as batch_op: + batch_op.create_index('message_account_idx', ['app_id', 'from_source', 'from_account_id'], unique=False) + batch_op.create_index('message_app_id_idx', ['app_id', 'created_at'], unique=False) + batch_op.create_index('message_conversation_id_idx', ['conversation_id'], unique=False) + batch_op.create_index('message_end_user_idx', ['app_id', 'from_source', 'from_end_user_id'], unique=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('messages', schema=None) as batch_op: + batch_op.drop_index('message_end_user_idx') + batch_op.drop_index('message_conversation_id_idx') + batch_op.drop_index('message_app_id_idx') + batch_op.drop_index('message_account_idx') + + op.drop_table('messages') + with op.batch_alter_table('message_annotations', schema=None) as batch_op: + batch_op.drop_index('message_annotation_message_idx') + batch_op.drop_index('message_annotation_conversation_idx') + batch_op.drop_index('message_annotation_app_idx') + + op.drop_table('message_annotations') + with op.batch_alter_table('upload_files', schema=None) as batch_op: + batch_op.drop_index('upload_file_tenant_idx') + + op.drop_table('upload_files') + op.drop_table('tenants') + with op.batch_alter_table('tenant_account_joins', schema=None) as batch_op: + batch_op.drop_index('tenant_account_join_tenant_id_idx') + batch_op.drop_index('tenant_account_join_account_id_idx') + + op.drop_table('tenant_account_joins') + with op.batch_alter_table('sites', schema=None) as batch_op: + batch_op.drop_index('site_code_idx') + batch_op.drop_index('site_app_id_idx') + + op.drop_table('sites') + op.drop_table('sessions') + with op.batch_alter_table('saved_messages', schema=None) as batch_op: + batch_op.drop_index('saved_message_message_idx') + + op.drop_table('saved_messages') + with op.batch_alter_table('recommended_apps', 
schema=None) as batch_op: + batch_op.drop_index('recommended_app_is_listed_idx') + batch_op.drop_index('recommended_app_app_id_idx') + + op.drop_table('recommended_apps') + with op.batch_alter_table('providers', schema=None) as batch_op: + batch_op.drop_index('provider_tenant_id_provider_idx') + + op.drop_table('providers') + with op.batch_alter_table('pinned_conversations', schema=None) as batch_op: + batch_op.drop_index('pinned_conversation_conversation_idx') + + op.drop_table('pinned_conversations') + with op.batch_alter_table('operation_logs', schema=None) as batch_op: + batch_op.drop_index('operation_log_account_action_idx') + + op.drop_table('operation_logs') + with op.batch_alter_table('message_feedbacks', schema=None) as batch_op: + batch_op.drop_index('message_feedback_message_idx') + batch_op.drop_index('message_feedback_conversation_idx') + batch_op.drop_index('message_feedback_app_idx') + + op.drop_table('message_feedbacks') + with op.batch_alter_table('message_chains', schema=None) as batch_op: + batch_op.drop_index('message_chain_message_id_idx') + + op.drop_table('message_chains') + with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op: + batch_op.drop_index('message_agent_thought_message_id_idx') + batch_op.drop_index('message_agent_thought_message_chain_id_idx') + + op.drop_table('message_agent_thoughts') + with op.batch_alter_table('invitation_codes', schema=None) as batch_op: + batch_op.drop_index('invitation_codes_code_idx') + batch_op.drop_index('invitation_codes_batch_idx') + + op.drop_table('invitation_codes') + with op.batch_alter_table('installed_apps', schema=None) as batch_op: + batch_op.drop_index('installed_app_tenant_id_idx') + batch_op.drop_index('installed_app_app_id_idx') + + op.drop_table('installed_apps') + with op.batch_alter_table('end_users', schema=None) as batch_op: + batch_op.drop_index('end_user_tenant_session_id_idx') + batch_op.drop_index('end_user_session_id_idx') + + op.drop_table('end_users') + 
op.drop_table('embeddings') + with op.batch_alter_table('documents', schema=None) as batch_op: + batch_op.drop_index('document_is_paused_idx') + batch_op.drop_index('document_dataset_id_idx') + + op.drop_table('documents') + with op.batch_alter_table('document_segments', schema=None) as batch_op: + batch_op.drop_index('document_segment_tenant_document_idx') + batch_op.drop_index('document_segment_tenant_dataset_idx') + batch_op.drop_index('document_segment_document_id_idx') + batch_op.drop_index('document_segment_dataset_node_idx') + batch_op.drop_index('document_segment_dataset_id_idx') + + op.drop_table('document_segments') + op.drop_table('dify_setups') + with op.batch_alter_table('datasets', schema=None) as batch_op: + batch_op.drop_index('dataset_tenant_idx') + + op.drop_table('datasets') + with op.batch_alter_table('dataset_queries', schema=None) as batch_op: + batch_op.drop_index('dataset_query_dataset_id_idx') + + op.drop_table('dataset_queries') + with op.batch_alter_table('dataset_process_rules', schema=None) as batch_op: + batch_op.drop_index('dataset_process_rule_dataset_id_idx') + + op.drop_table('dataset_process_rules') + with op.batch_alter_table('dataset_keyword_tables', schema=None) as batch_op: + batch_op.drop_index('dataset_keyword_table_dataset_id_idx') + + op.drop_table('dataset_keyword_tables') + with op.batch_alter_table('conversations', schema=None) as batch_op: + batch_op.drop_index('conversation_app_from_user_idx') + + op.drop_table('conversations') + op.drop_table('celery_tasksetmeta') + op.drop_table('celery_taskmeta') + + op.execute('DROP SEQUENCE taskset_id_sequence;') + op.execute('DROP SEQUENCE task_id_sequence;') + with op.batch_alter_table('apps', schema=None) as batch_op: + batch_op.drop_index('app_tenant_id_idx') + + op.drop_table('apps') + with op.batch_alter_table('app_model_configs', schema=None) as batch_op: + batch_op.drop_index('app_app_id_idx') + + op.drop_table('app_model_configs') + with 
op.batch_alter_table('app_dataset_joins', schema=None) as batch_op: + batch_op.drop_index('app_dataset_join_app_dataset_idx') + + op.drop_table('app_dataset_joins') + with op.batch_alter_table('api_tokens', schema=None) as batch_op: + batch_op.drop_index('api_token_token_idx') + batch_op.drop_index('api_token_app_id_type_idx') + + op.drop_table('api_tokens') + with op.batch_alter_table('api_requests', schema=None) as batch_op: + batch_op.drop_index('api_request_token_idx') + + op.drop_table('api_requests') + with op.batch_alter_table('accounts', schema=None) as batch_op: + batch_op.drop_index('account_email_idx') + + op.drop_table('accounts') + op.drop_table('account_integrates') + + op.execute('DROP EXTENSION IF EXISTS "uuid-ossp";') + # ### end Alembic commands ### diff --git a/api/migrations/versions/675b5321501b_add_node_execution_id_into_node_.py b/api/migrations/versions/675b5321501b_add_node_execution_id_into_node_.py new file mode 100644 index 0000000000000000000000000000000000000000..55824945da49b0848b908a2e4fa1ae78341eb5ca --- /dev/null +++ b/api/migrations/versions/675b5321501b_add_node_execution_id_into_node_.py @@ -0,0 +1,35 @@ +"""add node_execution_id into node_executions + +Revision ID: 675b5321501b +Revises: 030f4915f36a +Create Date: 2024-08-12 10:54:02.259331 + +""" +import sqlalchemy as sa +from alembic import op + +import models as models + +# revision identifiers, used by Alembic. +revision = '675b5321501b' +down_revision = '030f4915f36a' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('workflow_node_executions', schema=None) as batch_op: + batch_op.add_column(sa.Column('node_execution_id', sa.String(length=255), nullable=True)) + batch_op.create_index('workflow_node_execution_id_idx', ['tenant_id', 'app_id', 'workflow_id', 'triggered_from', 'node_execution_id'], unique=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('workflow_node_executions', schema=None) as batch_op: + batch_op.drop_index('workflow_node_execution_id_idx') + batch_op.drop_column('node_execution_id') + + # ### end Alembic commands ### diff --git a/api/migrations/versions/6dcb43972bdc_add_dataset_retriever_resource.py b/api/migrations/versions/6dcb43972bdc_add_dataset_retriever_resource.py new file mode 100644 index 0000000000000000000000000000000000000000..da27dd4426bb42480cff15902e094e0c930b11c2 --- /dev/null +++ b/api/migrations/versions/6dcb43972bdc_add_dataset_retriever_resource.py @@ -0,0 +1,54 @@ +"""add_dataset_retriever_resource + +Revision ID: 6dcb43972bdc +Revises: 4bcffcd64aa4 +Create Date: 2023-09-06 16:51:27.385844 + +""" +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '6dcb43972bdc' +down_revision = '4bcffcd64aa4' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + op.create_table('dataset_retriever_resources', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('message_id', postgresql.UUID(), nullable=False), + sa.Column('position', sa.Integer(), nullable=False), + sa.Column('dataset_id', postgresql.UUID(), nullable=False), + sa.Column('dataset_name', sa.Text(), nullable=False), + sa.Column('document_id', postgresql.UUID(), nullable=False), + sa.Column('document_name', sa.Text(), nullable=False), + sa.Column('data_source_type', sa.Text(), nullable=False), + sa.Column('segment_id', postgresql.UUID(), nullable=False), + sa.Column('score', sa.Float(), nullable=True), + sa.Column('content', sa.Text(), nullable=False), + sa.Column('hit_count', sa.Integer(), nullable=True), + sa.Column('word_count', sa.Integer(), nullable=True), + sa.Column('segment_position', sa.Integer(), nullable=True), + sa.Column('index_node_hash', sa.Text(), nullable=True), + sa.Column('retriever_from', sa.Text(), nullable=False), + sa.Column('created_by', postgresql.UUID(), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.PrimaryKeyConstraint('id', name='dataset_retriever_resource_pkey') + ) + with op.batch_alter_table('dataset_retriever_resources', schema=None) as batch_op: + batch_op.create_index('dataset_retriever_resource_message_id_idx', ['message_id'], unique=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('dataset_retriever_resources', schema=None) as batch_op: + batch_op.drop_index('dataset_retriever_resource_message_id_idx') + + op.drop_table('dataset_retriever_resources') + # ### end Alembic commands ### diff --git a/api/migrations/versions/6e2cfb077b04_add_dataset_collection_binding.py b/api/migrations/versions/6e2cfb077b04_add_dataset_collection_binding.py new file mode 100644 index 0000000000000000000000000000000000000000..4fa322f69394ec6202de397200ee69ded598d2d1 --- /dev/null +++ b/api/migrations/versions/6e2cfb077b04_add_dataset_collection_binding.py @@ -0,0 +1,47 @@ +"""add_dataset_collection_binding + +Revision ID: 6e2cfb077b04 +Revises: 77e83833755c +Create Date: 2023-09-13 22:16:48.027810 + +""" +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '6e2cfb077b04' +down_revision = '77e83833755c' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + op.create_table('dataset_collection_bindings', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('provider_name', sa.String(length=40), nullable=False), + sa.Column('model_name', sa.String(length=40), nullable=False), + sa.Column('collection_name', sa.String(length=64), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='dataset_collection_bindings_pkey') + ) + with op.batch_alter_table('dataset_collection_bindings', schema=None) as batch_op: + batch_op.create_index('provider_model_name_idx', ['provider_name', 'model_name'], unique=False) + + with op.batch_alter_table('datasets', schema=None) as batch_op: + batch_op.add_column(sa.Column('collection_binding_id', postgresql.UUID(), nullable=True)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('datasets', schema=None) as batch_op: + batch_op.drop_column('collection_binding_id') + + with op.batch_alter_table('dataset_collection_bindings', schema=None) as batch_op: + batch_op.drop_index('provider_model_name_idx') + + op.drop_table('dataset_collection_bindings') + # ### end Alembic commands ### diff --git a/api/migrations/versions/6e957a32015b_add_embedding_cache_created_at_index.py b/api/migrations/versions/6e957a32015b_add_embedding_cache_created_at_index.py new file mode 100644 index 0000000000000000000000000000000000000000..7445f664cd75a1a63a92d6b97c208ef57f79c9b5 --- /dev/null +++ b/api/migrations/versions/6e957a32015b_add_embedding_cache_created_at_index.py @@ -0,0 +1,32 @@ +"""add-embedding-cache-created_at_index + +Revision ID: 6e957a32015b +Revises: fecff1c3da27 +Create Date: 2024-07-19 17:21:34.414705 + +""" +from alembic import op + +import models as models + +# revision identifiers, used by Alembic. 
+revision = '6e957a32015b' +down_revision = 'fecff1c3da27' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('embeddings', schema=None) as batch_op: + batch_op.create_index('created_at_idx', ['created_at'], unique=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('embeddings', schema=None) as batch_op: + batch_op.drop_index('created_at_idx') + + # ### end Alembic commands ### diff --git a/api/migrations/versions/714aafe25d39_add_anntation_history_match_response.py b/api/migrations/versions/714aafe25d39_add_anntation_history_match_response.py new file mode 100644 index 0000000000000000000000000000000000000000..498b46e3c47364c1bc61ab286bff49cded517c18 --- /dev/null +++ b/api/migrations/versions/714aafe25d39_add_anntation_history_match_response.py @@ -0,0 +1,33 @@ +"""add_anntation_history_match_response + +Revision ID: 714aafe25d39 +Revises: f2a6fc85e260 +Create Date: 2023-12-14 06:38:02.972527 + +""" +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = '714aafe25d39' +down_revision = 'f2a6fc85e260' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op: + batch_op.add_column(sa.Column('annotation_question', sa.Text(), nullable=False)) + batch_op.add_column(sa.Column('annotation_content', sa.Text(), nullable=False)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op: + batch_op.drop_column('annotation_content') + batch_op.drop_column('annotation_question') + + # ### end Alembic commands ### diff --git a/api/migrations/versions/77e83833755c_add_app_config_retriever_resource.py b/api/migrations/versions/77e83833755c_add_app_config_retriever_resource.py new file mode 100644 index 0000000000000000000000000000000000000000..c5d8c3d88d4b49723156f31b826c870c52dfcc8e --- /dev/null +++ b/api/migrations/versions/77e83833755c_add_app_config_retriever_resource.py @@ -0,0 +1,31 @@ +"""add_app_config_retriever_resource + +Revision ID: 77e83833755c +Revises: 6dcb43972bdc +Create Date: 2023-09-06 17:26:40.311927 + +""" +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = '77e83833755c' +down_revision = '6dcb43972bdc' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('app_model_configs', schema=None) as batch_op: + batch_op.add_column(sa.Column('retriever_resource', sa.Text(), nullable=True)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('app_model_configs', schema=None) as batch_op: + batch_op.drop_column('retriever_resource') + + # ### end Alembic commands ### diff --git a/api/migrations/versions/7b45942e39bb_add_api_key_auth_binding.py b/api/migrations/versions/7b45942e39bb_add_api_key_auth_binding.py new file mode 100644 index 0000000000000000000000000000000000000000..2ba0e13caa936ab71b30ccbe3669cc5ab6e159f3 --- /dev/null +++ b/api/migrations/versions/7b45942e39bb_add_api_key_auth_binding.py @@ -0,0 +1,67 @@ +"""add-api-key-auth-binding + +Revision ID: 7b45942e39bb +Revises: 4e99a8df00ff +Create Date: 2024-05-14 07:31:29.702766 + +""" +import sqlalchemy as sa +from alembic import op + +import models.types + +# revision identifiers, used by Alembic. +revision = '7b45942e39bb' +down_revision = '4e99a8df00ff' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('data_source_api_key_auth_bindings', + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', models.types.StringUUID(), nullable=False), + sa.Column('category', sa.String(length=255), nullable=False), + sa.Column('provider', sa.String(length=255), nullable=False), + sa.Column('credentials', sa.Text(), nullable=True), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('disabled', sa.Boolean(), server_default=sa.text('false'), nullable=True), + sa.PrimaryKeyConstraint('id', name='data_source_api_key_auth_binding_pkey') + ) + with op.batch_alter_table('data_source_api_key_auth_bindings', schema=None) as batch_op: + batch_op.create_index('data_source_api_key_auth_binding_provider_idx', ['provider'], unique=False) + 
batch_op.create_index('data_source_api_key_auth_binding_tenant_id_idx', ['tenant_id'], unique=False) + + with op.batch_alter_table('data_source_bindings', schema=None) as batch_op: + batch_op.drop_index('source_binding_tenant_id_idx') + batch_op.drop_index('source_info_idx') + + op.rename_table('data_source_bindings', 'data_source_oauth_bindings') + + with op.batch_alter_table('data_source_oauth_bindings', schema=None) as batch_op: + batch_op.create_index('source_binding_tenant_id_idx', ['tenant_id'], unique=False) + batch_op.create_index('source_info_idx', ['source_info'], unique=False, postgresql_using='gin') + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + + with op.batch_alter_table('data_source_oauth_bindings', schema=None) as batch_op: + batch_op.drop_index('source_info_idx', postgresql_using='gin') + batch_op.drop_index('source_binding_tenant_id_idx') + + op.rename_table('data_source_oauth_bindings', 'data_source_bindings') + + with op.batch_alter_table('data_source_bindings', schema=None) as batch_op: + batch_op.create_index('source_info_idx', ['source_info'], unique=False) + batch_op.create_index('source_binding_tenant_id_idx', ['tenant_id'], unique=False) + + with op.batch_alter_table('data_source_api_key_auth_bindings', schema=None) as batch_op: + batch_op.drop_index('data_source_api_key_auth_binding_tenant_id_idx') + batch_op.drop_index('data_source_api_key_auth_binding_provider_idx') + + op.drop_table('data_source_api_key_auth_bindings') + # ### end Alembic commands ### diff --git a/api/migrations/versions/7bdef072e63a_add_workflow_tool.py b/api/migrations/versions/7bdef072e63a_add_workflow_tool.py new file mode 100644 index 0000000000000000000000000000000000000000..f09a682f285bd2cb4f44d2a056665d58cf16e304 --- /dev/null +++ b/api/migrations/versions/7bdef072e63a_add_workflow_tool.py @@ -0,0 +1,42 @@ +"""add workflow tool + +Revision ID: 7bdef072e63a +Revises: 5fda94355fce +Create 
Date: 2024-05-04 09:47:19.366961 + +""" +import sqlalchemy as sa +from alembic import op + +import models.types + +# revision identifiers, used by Alembic. +revision = '7bdef072e63a' +down_revision = '5fda94355fce' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('tool_workflow_providers', + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('name', sa.String(length=40), nullable=False), + sa.Column('icon', sa.String(length=255), nullable=False), + sa.Column('app_id', models.types.StringUUID(), nullable=False), + sa.Column('user_id', models.types.StringUUID(), nullable=False), + sa.Column('tenant_id', models.types.StringUUID(), nullable=False), + sa.Column('description', sa.Text(), nullable=False), + sa.Column('parameter_configuration', sa.Text(), server_default='[]', nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='tool_workflow_provider_pkey'), + sa.UniqueConstraint('name', 'tenant_id', name='unique_workflow_tool_provider'), + sa.UniqueConstraint('tenant_id', 'app_id', name='unique_workflow_tool_provider_app_id') + ) + # ### end Alembic commands ### + + +def downgrade(): + op.drop_table('tool_workflow_providers') + # ### end Alembic commands ### diff --git a/api/migrations/versions/7ce5a52e4eee_add_tool_providers.py b/api/migrations/versions/7ce5a52e4eee_add_tool_providers.py new file mode 100644 index 0000000000000000000000000000000000000000..881ffec61d76cffcaf7ab99afab5d5a14e75ad8e --- /dev/null +++ b/api/migrations/versions/7ce5a52e4eee_add_tool_providers.py @@ -0,0 +1,44 @@ +"""add tool providers + +Revision ID: 7ce5a52e4eee +Revises: 2beac44e5f5f +Create Date: 2023-07-10 10:26:50.074515 
+ +""" +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '7ce5a52e4eee' +down_revision = '2beac44e5f5f' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('tool_providers', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', postgresql.UUID(), nullable=False), + sa.Column('tool_name', sa.String(length=40), nullable=False), + sa.Column('encrypted_credentials', sa.Text(), nullable=True), + sa.Column('is_enabled', sa.Boolean(), server_default=sa.text('false'), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='tool_provider_pkey'), + sa.UniqueConstraint('tenant_id', 'tool_name', name='unique_tool_provider_tool_name') + ) + with op.batch_alter_table('app_model_configs', schema=None) as batch_op: + batch_op.add_column(sa.Column('sensitive_word_avoidance', sa.Text(), nullable=True)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('app_model_configs', schema=None) as batch_op: + batch_op.drop_column('sensitive_word_avoidance') + + op.drop_table('tool_providers') + # ### end Alembic commands ### diff --git a/api/migrations/versions/7e6a8693e07a_add_table_dataset_permissions.py b/api/migrations/versions/7e6a8693e07a_add_table_dataset_permissions.py new file mode 100644 index 0000000000000000000000000000000000000000..865572f3a75c71761be8070811b72da490f4c50d --- /dev/null +++ b/api/migrations/versions/7e6a8693e07a_add_table_dataset_permissions.py @@ -0,0 +1,42 @@ +"""add table dataset_permissions + +Revision ID: 7e6a8693e07a +Revises: b2602e131636 +Create Date: 2024-06-25 03:20:46.012193 + +""" +import sqlalchemy as sa +from alembic import op + +import models.types + +# revision identifiers, used by Alembic. +revision = '7e6a8693e07a' +down_revision = 'b2602e131636' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('dataset_permissions', + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('dataset_id', models.types.StringUUID(), nullable=False), + sa.Column('account_id', models.types.StringUUID(), nullable=False), + sa.Column('has_permission', sa.Boolean(), server_default=sa.text('true'), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='dataset_permission_pkey') + ) + with op.batch_alter_table('dataset_permissions', schema=None) as batch_op: + batch_op.create_index('idx_dataset_permissions_account_id', ['account_id'], unique=False) + batch_op.create_index('idx_dataset_permissions_dataset_id', ['dataset_id'], unique=False) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('dataset_permissions', schema=None) as batch_op: + batch_op.drop_index('idx_dataset_permissions_dataset_id') + batch_op.drop_index('idx_dataset_permissions_account_id') + op.drop_table('dataset_permissions') + # ### end Alembic commands ### diff --git a/api/migrations/versions/853f9b9cd3b6_add_message_price_unit.py b/api/migrations/versions/853f9b9cd3b6_add_message_price_unit.py new file mode 100644 index 0000000000000000000000000000000000000000..5a8476501b5e91ad49578384a196214798b4e44a --- /dev/null +++ b/api/migrations/versions/853f9b9cd3b6_add_message_price_unit.py @@ -0,0 +1,42 @@ +"""add message price unit + +Revision ID: 853f9b9cd3b6 +Revises: e8883b0148c9 +Create Date: 2023-08-19 17:01:57.471562 + +""" +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = '853f9b9cd3b6' +down_revision = 'e8883b0148c9' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + + with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op: + batch_op.add_column(sa.Column('message_price_unit', sa.Numeric(precision=10, scale=7), server_default=sa.text('0.001'), nullable=False)) + batch_op.add_column(sa.Column('answer_price_unit', sa.Numeric(precision=10, scale=7), server_default=sa.text('0.001'), nullable=False)) + + with op.batch_alter_table('messages', schema=None) as batch_op: + batch_op.add_column(sa.Column('message_price_unit', sa.Numeric(precision=10, scale=7), server_default=sa.text('0.001'), nullable=False)) + batch_op.add_column(sa.Column('answer_price_unit', sa.Numeric(precision=10, scale=7), server_default=sa.text('0.001'), nullable=False)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('messages', schema=None) as batch_op: + batch_op.drop_column('answer_price_unit') + batch_op.drop_column('message_price_unit') + + with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op: + batch_op.drop_column('answer_price_unit') + batch_op.drop_column('message_price_unit') + + # ### end Alembic commands ### diff --git a/api/migrations/versions/88072f0caa04_add_custom_config_in_tenant.py b/api/migrations/versions/88072f0caa04_add_custom_config_in_tenant.py new file mode 100644 index 0000000000000000000000000000000000000000..f7625bff8cb305b392298025a0c6a679dfc81b7b --- /dev/null +++ b/api/migrations/versions/88072f0caa04_add_custom_config_in_tenant.py @@ -0,0 +1,31 @@ +"""add custom config in tenant + +Revision ID: 88072f0caa04 +Revises: 246ba09cbbdb +Create Date: 2023-12-14 07:36:50.705362 + +""" +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = '88072f0caa04' +down_revision = '246ba09cbbdb' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('tenants', schema=None) as batch_op: + batch_op.add_column(sa.Column('custom_config', sa.Text(), nullable=True)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('tenants', schema=None) as batch_op: + batch_op.drop_column('custom_config') + + # ### end Alembic commands ### diff --git a/api/migrations/versions/89c7899ca936_.py b/api/migrations/versions/89c7899ca936_.py new file mode 100644 index 0000000000000000000000000000000000000000..0fad39fa57f1ad548056f3bdc26e531d34b3f3e4 --- /dev/null +++ b/api/migrations/versions/89c7899ca936_.py @@ -0,0 +1,37 @@ +"""empty message + +Revision ID: 89c7899ca936 +Revises: 187385f442fc +Create Date: 2024-01-21 04:10:23.192853 + +""" +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = '89c7899ca936' +down_revision = '187385f442fc' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('sites', schema=None) as batch_op: + batch_op.alter_column('description', + existing_type=sa.VARCHAR(length=255), + type_=sa.Text(), + existing_nullable=True) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('sites', schema=None) as batch_op: + batch_op.alter_column('description', + existing_type=sa.Text(), + type_=sa.VARCHAR(length=255), + existing_nullable=True) + + # ### end Alembic commands ### diff --git a/api/migrations/versions/8ae9bc661daa_add_tool_conversation_variables_idx.py b/api/migrations/versions/8ae9bc661daa_add_tool_conversation_variables_idx.py new file mode 100644 index 0000000000000000000000000000000000000000..f4c4ebb51bf0df13b51a6ef2c7e37e6023b44d13 --- /dev/null +++ b/api/migrations/versions/8ae9bc661daa_add_tool_conversation_variables_idx.py @@ -0,0 +1,32 @@ +"""add tool conversation variables idx + +Revision ID: 8ae9bc661daa +Revises: 9fafbd60eca1 +Create Date: 2024-01-15 14:22:03.597692 + +""" +from alembic import op + +# revision identifiers, used by Alembic. 
+revision = '8ae9bc661daa' +down_revision = '9fafbd60eca1' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('tool_conversation_variables', schema=None) as batch_op: + batch_op.create_index('conversation_id_idx', ['conversation_id'], unique=False) + batch_op.create_index('user_id_idx', ['user_id'], unique=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('tool_conversation_variables', schema=None) as batch_op: + batch_op.drop_index('user_id_idx') + batch_op.drop_index('conversation_id_idx') + + # ### end Alembic commands ### diff --git a/api/migrations/versions/8d2d099ceb74_add_qa_model_support.py b/api/migrations/versions/8d2d099ceb74_add_qa_model_support.py new file mode 100644 index 0000000000000000000000000000000000000000..849103b0711d1b9e254b90ffd13fc0c01de6e04a --- /dev/null +++ b/api/migrations/versions/8d2d099ceb74_add_qa_model_support.py @@ -0,0 +1,42 @@ +"""add_qa_model_support + +Revision ID: 8d2d099ceb74 +Revises: 7ce5a52e4eee +Create Date: 2023-07-18 15:25:15.293438 + +""" +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '8d2d099ceb74' +down_revision = '7ce5a52e4eee' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('document_segments', schema=None) as batch_op: + batch_op.add_column(sa.Column('answer', sa.Text(), nullable=True)) + batch_op.add_column(sa.Column('updated_by', postgresql.UUID(), nullable=True)) + batch_op.add_column(sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False)) + + with op.batch_alter_table('documents', schema=None) as batch_op: + batch_op.add_column(sa.Column('doc_form', sa.String(length=255), server_default=sa.text("'text_model'::character varying"), nullable=False)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('documents', schema=None) as batch_op: + batch_op.drop_column('doc_form') + + with op.batch_alter_table('document_segments', schema=None) as batch_op: + batch_op.drop_column('updated_at') + batch_op.drop_column('updated_by') + batch_op.drop_column('answer') + + # ### end Alembic commands ### diff --git a/api/migrations/versions/8e5588e6412e_add_environment_variable_to_workflow_.py b/api/migrations/versions/8e5588e6412e_add_environment_variable_to_workflow_.py new file mode 100644 index 0000000000000000000000000000000000000000..ec2336da4dec712145cfae4880a07445006837b1 --- /dev/null +++ b/api/migrations/versions/8e5588e6412e_add_environment_variable_to_workflow_.py @@ -0,0 +1,33 @@ +"""add environment variable to workflow model + +Revision ID: 8e5588e6412e +Revises: 6e957a32015b +Create Date: 2024-07-22 03:27:16.042533 + +""" +import sqlalchemy as sa +from alembic import op + +import models as models + +# revision identifiers, used by Alembic. +revision = '8e5588e6412e' +down_revision = '6e957a32015b' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('workflows', schema=None) as batch_op: + batch_op.add_column(sa.Column('environment_variables', sa.Text(), server_default='{}', nullable=False)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('workflows', schema=None) as batch_op: + batch_op.drop_column('environment_variables') + + # ### end Alembic commands ### diff --git a/api/migrations/versions/8ec536f3c800_rename_api_provider_credentails.py b/api/migrations/versions/8ec536f3c800_rename_api_provider_credentails.py new file mode 100644 index 0000000000000000000000000000000000000000..6cafc198aafa646095d50842b26a82238dc436e5 --- /dev/null +++ b/api/migrations/versions/8ec536f3c800_rename_api_provider_credentails.py @@ -0,0 +1,31 @@ +"""rename api provider credentials + +Revision ID: 8ec536f3c800 +Revises: ad472b61a054 +Create Date: 2024-01-07 03:57:35.257545 + +""" +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = '8ec536f3c800' +down_revision = 'ad472b61a054' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('tool_api_providers', schema=None) as batch_op: + batch_op.add_column(sa.Column('credentials_str', sa.Text(), nullable=False)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('tool_api_providers', schema=None) as batch_op: + batch_op.drop_column('credentials_str') + + # ### end Alembic commands ### diff --git a/api/migrations/versions/8fe468ba0ca5_add_gpt4v_supports.py b/api/migrations/versions/8fe468ba0ca5_add_gpt4v_supports.py new file mode 100644 index 0000000000000000000000000000000000000000..01d56315106c2a12bcccb44f17e04da6c3636fa9 --- /dev/null +++ b/api/migrations/versions/8fe468ba0ca5_add_gpt4v_supports.py @@ -0,0 +1,59 @@ +"""add gpt4v supports + +Revision ID: 8fe468ba0ca5 +Revises: a9836e3baeee +Create Date: 2023-11-09 11:39:00.006432 + +""" +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '8fe468ba0ca5' +down_revision = 'a9836e3baeee' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('message_files', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('message_id', postgresql.UUID(), nullable=False), + sa.Column('type', sa.String(length=255), nullable=False), + sa.Column('transfer_method', sa.String(length=255), nullable=False), + sa.Column('url', sa.Text(), nullable=True), + sa.Column('upload_file_id', postgresql.UUID(), nullable=True), + sa.Column('created_by_role', sa.String(length=255), nullable=False), + sa.Column('created_by', postgresql.UUID(), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='message_file_pkey') + ) + with op.batch_alter_table('message_files', schema=None) as batch_op: + batch_op.create_index('message_file_created_by_idx', ['created_by'], unique=False) + batch_op.create_index('message_file_message_idx', ['message_id'], unique=False) + + with op.batch_alter_table('app_model_configs', schema=None) as batch_op: + 
batch_op.add_column(sa.Column('file_upload', sa.Text(), nullable=True)) + + with op.batch_alter_table('upload_files', schema=None) as batch_op: + batch_op.add_column(sa.Column('created_by_role', sa.String(length=255), server_default=sa.text("'account'::character varying"), nullable=False)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('upload_files', schema=None) as batch_op: + batch_op.drop_column('created_by_role') + + with op.batch_alter_table('app_model_configs', schema=None) as batch_op: + batch_op.drop_column('file_upload') + + with op.batch_alter_table('message_files', schema=None) as batch_op: + batch_op.drop_index('message_file_message_idx') + batch_op.drop_index('message_file_created_by_idx') + + op.drop_table('message_files') + # ### end Alembic commands ### diff --git a/api/migrations/versions/968fff4c0ab9_add_api_based_extension.py b/api/migrations/versions/968fff4c0ab9_add_api_based_extension.py new file mode 100644 index 0000000000000000000000000000000000000000..207a9c841f03d1123a44aa18eb3b31e76b7fdc12 --- /dev/null +++ b/api/migrations/versions/968fff4c0ab9_add_api_based_extension.py @@ -0,0 +1,45 @@ +"""add_api_based_extension + +Revision ID: 968fff4c0ab9 +Revises: b3a09c049e8e +Create Date: 2023-10-27 13:05:58.901858 + +""" +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '968fff4c0ab9' +down_revision = 'b3a09c049e8e' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + + op.create_table('api_based_extensions', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', postgresql.UUID(), nullable=False), + sa.Column('name', sa.String(length=255), nullable=False), + sa.Column('api_endpoint', sa.String(length=255), nullable=False), + sa.Column('api_key', sa.Text(), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='api_based_extension_pkey') + ) + with op.batch_alter_table('api_based_extensions', schema=None) as batch_op: + batch_op.create_index('api_based_extension_tenant_idx', ['tenant_id'], unique=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + + with op.batch_alter_table('api_based_extensions', schema=None) as batch_op: + batch_op.drop_index('api_based_extension_tenant_idx') + + op.drop_table('api_based_extensions') + + # ### end Alembic commands ### diff --git a/api/migrations/versions/9e98fbaffb88_add_workflow_tool_version.py b/api/migrations/versions/9e98fbaffb88_add_workflow_tool_version.py new file mode 100644 index 0000000000000000000000000000000000000000..92f41f0abd0d91ce7de06cd76e93e62ccfeef198 --- /dev/null +++ b/api/migrations/versions/9e98fbaffb88_add_workflow_tool_version.py @@ -0,0 +1,27 @@ +"""add workflow tool version + +Revision ID: 9e98fbaffb88 +Revises: 3b18fea55204 +Create Date: 2024-05-21 10:25:40.434162 + +""" +import sqlalchemy as sa +from alembic import op + +import models as models + +# revision identifiers, used by Alembic. 
+revision = '9e98fbaffb88' +down_revision = '3b18fea55204' +branch_labels = None +depends_on = None + + +def upgrade(): + with op.batch_alter_table('tool_workflow_providers', schema=None) as batch_op: + batch_op.add_column(sa.Column('version', sa.String(length=255), server_default='', nullable=False)) + + +def downgrade(): + with op.batch_alter_table('tool_workflow_providers', schema=None) as batch_op: + batch_op.drop_column('version') diff --git a/api/migrations/versions/9f4e3427ea84_add_created_by_role.py b/api/migrations/versions/9f4e3427ea84_add_created_by_role.py new file mode 100644 index 0000000000000000000000000000000000000000..c7a98b4ac6880492763b3c13294675a5afe62072 --- /dev/null +++ b/api/migrations/versions/9f4e3427ea84_add_created_by_role.py @@ -0,0 +1,45 @@ +"""add created by role + +Revision ID: 9f4e3427ea84 +Revises: 64b051264f32 +Create Date: 2023-05-17 17:29:01.060435 + +""" +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = '9f4e3427ea84' +down_revision = '64b051264f32' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('pinned_conversations', schema=None) as batch_op: + batch_op.add_column(sa.Column('created_by_role', sa.String(length=255), server_default=sa.text("'end_user'::character varying"), nullable=False)) + batch_op.drop_index('pinned_conversation_conversation_idx') + batch_op.create_index('pinned_conversation_conversation_idx', ['app_id', 'conversation_id', 'created_by_role', 'created_by'], unique=False) + + with op.batch_alter_table('saved_messages', schema=None) as batch_op: + batch_op.add_column(sa.Column('created_by_role', sa.String(length=255), server_default=sa.text("'end_user'::character varying"), nullable=False)) + batch_op.drop_index('saved_message_message_idx') + batch_op.create_index('saved_message_message_idx', ['app_id', 'message_id', 'created_by_role', 'created_by'], unique=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('saved_messages', schema=None) as batch_op: + batch_op.drop_index('saved_message_message_idx') + batch_op.create_index('saved_message_message_idx', ['app_id', 'message_id', 'created_by'], unique=False) + batch_op.drop_column('created_by_role') + + with op.batch_alter_table('pinned_conversations', schema=None) as batch_op: + batch_op.drop_index('pinned_conversation_conversation_idx') + batch_op.create_index('pinned_conversation_conversation_idx', ['app_id', 'conversation_id', 'created_by'], unique=False) + batch_op.drop_column('created_by_role') + + # ### end Alembic commands ### diff --git a/api/migrations/versions/9fafbd60eca1_add_message_file_belongs_to.py b/api/migrations/versions/9fafbd60eca1_add_message_file_belongs_to.py new file mode 100644 index 0000000000000000000000000000000000000000..968906bdd7b959e66ffed1b78a423b822fd39024 --- /dev/null +++ b/api/migrations/versions/9fafbd60eca1_add_message_file_belongs_to.py @@ -0,0 +1,31 @@ +"""add message file belongs to + +Revision ID: 9fafbd60eca1 
+Revises: 4823da1d26cf +Create Date: 2024-01-15 13:07:20.340896 + +""" +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = '9fafbd60eca1' +down_revision = '4823da1d26cf' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('message_files', schema=None) as batch_op: + batch_op.add_column(sa.Column('belongs_to', sa.String(length=255), nullable=True)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('message_files', schema=None) as batch_op: + batch_op.drop_column('belongs_to') + + # ### end Alembic commands ### diff --git a/api/migrations/versions/a45f4dfde53b_add_language_to_recommend_apps.py b/api/migrations/versions/a45f4dfde53b_add_language_to_recommend_apps.py new file mode 100644 index 0000000000000000000000000000000000000000..3014978110840e71bf2d04f7937986b5c9ce82a7 --- /dev/null +++ b/api/migrations/versions/a45f4dfde53b_add_language_to_recommend_apps.py @@ -0,0 +1,35 @@ +"""add language to recommend apps + +Revision ID: a45f4dfde53b +Revises: 9f4e3427ea84 +Create Date: 2023-05-25 17:50:32.052335 + +""" +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = 'a45f4dfde53b' +down_revision = '9f4e3427ea84' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('recommended_apps', schema=None) as batch_op: + batch_op.add_column(sa.Column('language', sa.String(length=255), server_default=sa.text("'en-US'::character varying"), nullable=False)) + batch_op.drop_index('recommended_app_is_listed_idx') + batch_op.create_index('recommended_app_is_listed_idx', ['is_listed', 'language'], unique=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('recommended_apps', schema=None) as batch_op: + batch_op.drop_index('recommended_app_is_listed_idx') + batch_op.create_index('recommended_app_is_listed_idx', ['is_listed'], unique=False) + batch_op.drop_column('language') + + # ### end Alembic commands ### diff --git a/api/migrations/versions/a5b56fb053ef_app_config_add_speech_to_text.py b/api/migrations/versions/a5b56fb053ef_app_config_add_speech_to_text.py new file mode 100644 index 0000000000000000000000000000000000000000..acb681243445fd88fe123b1072313f1b2e32fc62 --- /dev/null +++ b/api/migrations/versions/a5b56fb053ef_app_config_add_speech_to_text.py @@ -0,0 +1,31 @@ +"""app config add speech_to_text + +Revision ID: a5b56fb053ef +Revises: d3d503a3471c +Create Date: 2023-07-06 17:55:20.894149 + +""" +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = 'a5b56fb053ef' +down_revision = 'd3d503a3471c' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('app_model_configs', schema=None) as batch_op: + batch_op.add_column(sa.Column('speech_to_text', sa.Text(), nullable=True)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('app_model_configs', schema=None) as batch_op: + batch_op.drop_column('speech_to_text') + + # ### end Alembic commands ### diff --git a/api/migrations/versions/a8d7385a7b66_add_embeddings_provider_name.py b/api/migrations/versions/a8d7385a7b66_add_embeddings_provider_name.py new file mode 100644 index 0000000000000000000000000000000000000000..1ee01381d8456d51ba64a0cffca9d3bff097ff03 --- /dev/null +++ b/api/migrations/versions/a8d7385a7b66_add_embeddings_provider_name.py @@ -0,0 +1,34 @@ +"""add-embeddings-provider-name + +Revision ID: a8d7385a7b66 +Revises: 17b5ab037c40 +Create Date: 2024-04-02 12:17:22.641525 + +""" +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = 'a8d7385a7b66' +down_revision = '17b5ab037c40' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('embeddings', schema=None) as batch_op: + batch_op.add_column(sa.Column('provider_name', sa.String(length=40), server_default=sa.text("''::character varying"), nullable=False)) + batch_op.drop_constraint('embedding_hash_idx', type_='unique') + batch_op.create_unique_constraint('embedding_hash_idx', ['model_name', 'hash', 'provider_name']) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('embeddings', schema=None) as batch_op: + batch_op.drop_constraint('embedding_hash_idx', type_='unique') + batch_op.create_unique_constraint('embedding_hash_idx', ['model_name', 'hash']) + batch_op.drop_column('provider_name') + # ### end Alembic commands ### diff --git a/api/migrations/versions/a8f9b3c45e4a_add_tenant_id_db_index.py b/api/migrations/versions/a8f9b3c45e4a_add_tenant_id_db_index.py new file mode 100644 index 0000000000000000000000000000000000000000..62d6faeb1d58b7aec0b4c9c0640cf115b4df6d4d --- /dev/null +++ b/api/migrations/versions/a8f9b3c45e4a_add_tenant_id_db_index.py @@ -0,0 +1,36 @@ +"""add_tenant_id_db_index + +Revision ID: a8f9b3c45e4a +Revises: 16830a790f0f +Create Date: 2024-03-18 05:07:35.588473 + +""" +from alembic import op + +# revision identifiers, used by Alembic. +revision = 'a8f9b3c45e4a' +down_revision = '16830a790f0f' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('document_segments', schema=None) as batch_op: + batch_op.create_index('document_segment_tenant_idx', ['tenant_id'], unique=False) + + with op.batch_alter_table('documents', schema=None) as batch_op: + batch_op.create_index('document_tenant_idx', ['tenant_id'], unique=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('documents', schema=None) as batch_op: + batch_op.drop_index('document_tenant_idx') + + with op.batch_alter_table('document_segments', schema=None) as batch_op: + batch_op.drop_index('document_segment_tenant_idx') + + # ### end Alembic commands ### diff --git a/api/migrations/versions/a9836e3baeee_add_external_data_tools_in_app_model_.py b/api/migrations/versions/a9836e3baeee_add_external_data_tools_in_app_model_.py new file mode 100644 index 0000000000000000000000000000000000000000..5dcb630aed1c4df19a5f3da0a01d395b411f9941 --- /dev/null +++ b/api/migrations/versions/a9836e3baeee_add_external_data_tools_in_app_model_.py @@ -0,0 +1,31 @@ +"""add external_data_tools in app model config + +Revision ID: a9836e3baeee +Revises: 968fff4c0ab9 +Create Date: 2023-11-02 04:04:57.609485 + +""" +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = 'a9836e3baeee' +down_revision = '968fff4c0ab9' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('app_model_configs', schema=None) as batch_op: + batch_op.add_column(sa.Column('external_data_tools', sa.Text(), nullable=True)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('app_model_configs', schema=None) as batch_op: + batch_op.drop_column('external_data_tools') + + # ### end Alembic commands ### diff --git a/api/migrations/versions/ab23c11305d4_add_dataset_query_variable_at_app_model_.py b/api/migrations/versions/ab23c11305d4_add_dataset_query_variable_at_app_model_.py new file mode 100644 index 0000000000000000000000000000000000000000..eee41bf4e09f9199d098cc023ea9068f56565f08 --- /dev/null +++ b/api/migrations/versions/ab23c11305d4_add_dataset_query_variable_at_app_model_.py @@ -0,0 +1,31 @@ +"""add dataset query variable at app model configs. 
+ +Revision ID: ab23c11305d4 +Revises: 6e2cfb077b04 +Create Date: 2023-09-26 12:22:59.044088 + +""" +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = 'ab23c11305d4' +down_revision = '6e2cfb077b04' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('app_model_configs', schema=None) as batch_op: + batch_op.add_column(sa.Column('dataset_query_variable', sa.String(length=255), nullable=True)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('app_model_configs', schema=None) as batch_op: + batch_op.drop_column('dataset_query_variable') + + # ### end Alembic commands ### diff --git a/api/migrations/versions/ad472b61a054_add_api_provider_icon.py b/api/migrations/versions/ad472b61a054_add_api_provider_icon.py new file mode 100644 index 0000000000000000000000000000000000000000..0ddaf1eb0adf327a4ed51ed7bd432bc1382a0871 --- /dev/null +++ b/api/migrations/versions/ad472b61a054_add_api_provider_icon.py @@ -0,0 +1,31 @@ +"""add api provider icon + +Revision ID: ad472b61a054 +Revises: 3ef9b2b6bee6 +Create Date: 2024-01-07 02:21:23.114790 + +""" +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = 'ad472b61a054' +down_revision = '3ef9b2b6bee6' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('tool_api_providers', schema=None) as batch_op: + batch_op.add_column(sa.Column('icon', sa.String(length=256), nullable=False)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('tool_api_providers', schema=None) as batch_op: + batch_op.drop_column('icon') + + # ### end Alembic commands ### diff --git a/api/migrations/versions/b24be59fbb04_.py b/api/migrations/versions/b24be59fbb04_.py new file mode 100644 index 0000000000000000000000000000000000000000..29ba859f2b7459bc99a511c28271a8752a7250f4 --- /dev/null +++ b/api/migrations/versions/b24be59fbb04_.py @@ -0,0 +1,31 @@ +"""empty message + +Revision ID: b24be59fbb04 +Revises: de95f5c77138 +Create Date: 2024-01-17 01:31:12.670556 + +""" +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = 'b24be59fbb04' +down_revision = 'de95f5c77138' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('app_model_configs', schema=None) as batch_op: + batch_op.add_column(sa.Column('text_to_speech', sa.Text(), nullable=True)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('app_model_configs', schema=None) as batch_op: + batch_op.drop_column('text_to_speech') + + # ### end Alembic commands ### diff --git a/api/migrations/versions/b2602e131636_add_workflow_run_id_index_for_message.py b/api/migrations/versions/b2602e131636_add_workflow_run_id_index_for_message.py new file mode 100644 index 0000000000000000000000000000000000000000..c9a6a5a5a7d90fa44324cec400bdd2916d399a32 --- /dev/null +++ b/api/migrations/versions/b2602e131636_add_workflow_run_id_index_for_message.py @@ -0,0 +1,32 @@ +"""add workflow_run_id index for message + +Revision ID: b2602e131636 +Revises: 63f9175e515b +Create Date: 2024-06-29 12:16:51.646346 + +""" +from alembic import op + +import models as models + +# revision identifiers, used by Alembic.
+revision = 'b2602e131636' +down_revision = '63f9175e515b' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('messages', schema=None) as batch_op: + batch_op.create_index('message_workflow_run_id_idx', ['conversation_id', 'workflow_run_id'], unique=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('messages', schema=None) as batch_op: + batch_op.drop_index('message_workflow_run_id_idx') + + # ### end Alembic commands ### diff --git a/api/migrations/versions/b289e2408ee2_add_workflow.py b/api/migrations/versions/b289e2408ee2_add_workflow.py new file mode 100644 index 0000000000000000000000000000000000000000..966f86c05fa9f316d116276596fca176bbc16c2e --- /dev/null +++ b/api/migrations/versions/b289e2408ee2_add_workflow.py @@ -0,0 +1,142 @@ +"""add workflow + +Revision ID: b289e2408ee2 +Revises: a8d7385a7b66 +Create Date: 2024-02-19 12:47:24.646954 + +""" +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = 'b289e2408ee2' +down_revision = 'a8d7385a7b66' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust!
### + op.create_table('workflow_app_logs', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', postgresql.UUID(), nullable=False), + sa.Column('app_id', postgresql.UUID(), nullable=False), + sa.Column('workflow_id', postgresql.UUID(), nullable=False), + sa.Column('workflow_run_id', postgresql.UUID(), nullable=False), + sa.Column('created_from', sa.String(length=255), nullable=False), + sa.Column('created_by_role', sa.String(length=255), nullable=False), + sa.Column('created_by', postgresql.UUID(), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='workflow_app_log_pkey') + ) + with op.batch_alter_table('workflow_app_logs', schema=None) as batch_op: + batch_op.create_index('workflow_app_log_app_idx', ['tenant_id', 'app_id'], unique=False) + + op.create_table('workflow_node_executions', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', postgresql.UUID(), nullable=False), + sa.Column('app_id', postgresql.UUID(), nullable=False), + sa.Column('workflow_id', postgresql.UUID(), nullable=False), + sa.Column('triggered_from', sa.String(length=255), nullable=False), + sa.Column('workflow_run_id', postgresql.UUID(), nullable=True), + sa.Column('index', sa.Integer(), nullable=False), + sa.Column('predecessor_node_id', sa.String(length=255), nullable=True), + sa.Column('node_id', sa.String(length=255), nullable=False), + sa.Column('node_type', sa.String(length=255), nullable=False), + sa.Column('title', sa.String(length=255), nullable=False), + sa.Column('inputs', sa.Text(), nullable=True), + sa.Column('process_data', sa.Text(), nullable=True), + sa.Column('outputs', sa.Text(), nullable=True), + sa.Column('status', sa.String(length=255), nullable=False), + sa.Column('error', sa.Text(), nullable=True), + 
sa.Column('elapsed_time', sa.Float(), server_default=sa.text('0'), nullable=False), + sa.Column('execution_metadata', sa.Text(), nullable=True), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('created_by_role', sa.String(length=255), nullable=False), + sa.Column('created_by', postgresql.UUID(), nullable=False), + sa.Column('finished_at', sa.DateTime(), nullable=True), + sa.PrimaryKeyConstraint('id', name='workflow_node_execution_pkey') + ) + with op.batch_alter_table('workflow_node_executions', schema=None) as batch_op: + batch_op.create_index('workflow_node_execution_node_run_idx', ['tenant_id', 'app_id', 'workflow_id', 'triggered_from', 'node_id'], unique=False) + batch_op.create_index('workflow_node_execution_workflow_run_idx', ['tenant_id', 'app_id', 'workflow_id', 'triggered_from', 'workflow_run_id'], unique=False) + + op.create_table('workflow_runs', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', postgresql.UUID(), nullable=False), + sa.Column('app_id', postgresql.UUID(), nullable=False), + sa.Column('sequence_number', sa.Integer(), nullable=False), + sa.Column('workflow_id', postgresql.UUID(), nullable=False), + sa.Column('type', sa.String(length=255), nullable=False), + sa.Column('triggered_from', sa.String(length=255), nullable=False), + sa.Column('version', sa.String(length=255), nullable=False), + sa.Column('graph', sa.Text(), nullable=True), + sa.Column('inputs', sa.Text(), nullable=True), + sa.Column('status', sa.String(length=255), nullable=False), + sa.Column('outputs', sa.Text(), nullable=True), + sa.Column('error', sa.Text(), nullable=True), + sa.Column('elapsed_time', sa.Float(), server_default=sa.text('0'), nullable=False), + sa.Column('total_tokens', sa.Integer(), server_default=sa.text('0'), nullable=False), + sa.Column('total_steps', sa.Integer(), server_default=sa.text('0'), nullable=True), + 
sa.Column('created_by_role', sa.String(length=255), nullable=False), + sa.Column('created_by', postgresql.UUID(), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('finished_at', sa.DateTime(), nullable=True), + sa.PrimaryKeyConstraint('id', name='workflow_run_pkey') + ) + with op.batch_alter_table('workflow_runs', schema=None) as batch_op: + batch_op.create_index('workflow_run_triggerd_from_idx', ['tenant_id', 'app_id', 'triggered_from'], unique=False) + + op.create_table('workflows', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', postgresql.UUID(), nullable=False), + sa.Column('app_id', postgresql.UUID(), nullable=False), + sa.Column('type', sa.String(length=255), nullable=False), + sa.Column('version', sa.String(length=255), nullable=False), + sa.Column('graph', sa.Text(), nullable=True), + sa.Column('features', sa.Text(), nullable=True), + sa.Column('created_by', postgresql.UUID(), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('updated_by', postgresql.UUID(), nullable=True), + sa.Column('updated_at', sa.DateTime(), nullable=True), + sa.PrimaryKeyConstraint('id', name='workflow_pkey') + ) + with op.batch_alter_table('workflows', schema=None) as batch_op: + batch_op.create_index('workflow_version_idx', ['tenant_id', 'app_id', 'version'], unique=False) + + with op.batch_alter_table('apps', schema=None) as batch_op: + batch_op.add_column(sa.Column('workflow_id', postgresql.UUID(), nullable=True)) + + with op.batch_alter_table('messages', schema=None) as batch_op: + batch_op.add_column(sa.Column('workflow_run_id', postgresql.UUID(), nullable=True)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('messages', schema=None) as batch_op: + batch_op.drop_column('workflow_run_id') + + with op.batch_alter_table('apps', schema=None) as batch_op: + batch_op.drop_column('workflow_id') + + with op.batch_alter_table('workflows', schema=None) as batch_op: + batch_op.drop_index('workflow_version_idx') + + op.drop_table('workflows') + with op.batch_alter_table('workflow_runs', schema=None) as batch_op: + batch_op.drop_index('workflow_run_triggerd_from_idx') + + op.drop_table('workflow_runs') + with op.batch_alter_table('workflow_node_executions', schema=None) as batch_op: + batch_op.drop_index('workflow_node_execution_workflow_run_idx') + batch_op.drop_index('workflow_node_execution_node_run_idx') + + op.drop_table('workflow_node_executions') + with op.batch_alter_table('workflow_app_logs', schema=None) as batch_op: + batch_op.drop_index('workflow_app_log_app_idx') + + op.drop_table('workflow_app_logs') + # ### end Alembic commands ### diff --git a/api/migrations/versions/b3a09c049e8e_add_advanced_prompt_templates.py b/api/migrations/versions/b3a09c049e8e_add_advanced_prompt_templates.py new file mode 100644 index 0000000000000000000000000000000000000000..5682eff0307dca5ea5cdd2ee2bac40852f1bed3e --- /dev/null +++ b/api/migrations/versions/b3a09c049e8e_add_advanced_prompt_templates.py @@ -0,0 +1,37 @@ +"""add advanced prompt templates + +Revision ID: b3a09c049e8e +Revises: 2e9819ca5b28 +Create Date: 2023-10-10 15:23:23.395420 + +""" +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = 'b3a09c049e8e' +down_revision = '2e9819ca5b28' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('app_model_configs', schema=None) as batch_op: + batch_op.add_column(sa.Column('prompt_type', sa.String(length=255), nullable=False, server_default='simple')) + batch_op.add_column(sa.Column('chat_prompt_config', sa.Text(), nullable=True)) + batch_op.add_column(sa.Column('completion_prompt_config', sa.Text(), nullable=True)) + batch_op.add_column(sa.Column('dataset_configs', sa.Text(), nullable=True)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('app_model_configs', schema=None) as batch_op: + batch_op.drop_column('dataset_configs') + batch_op.drop_column('completion_prompt_config') + batch_op.drop_column('chat_prompt_config') + batch_op.drop_column('prompt_type') + + # ### end Alembic commands ### diff --git a/api/migrations/versions/b5429b71023c_messages_columns_set_nullable.py b/api/migrations/versions/b5429b71023c_messages_columns_set_nullable.py new file mode 100644 index 0000000000000000000000000000000000000000..ee81fdab2872a29a7209c7e16ccb837b9ebe6820 --- /dev/null +++ b/api/migrations/versions/b5429b71023c_messages_columns_set_nullable.py @@ -0,0 +1,41 @@ +"""messages columns set nullable + +Revision ID: b5429b71023c +Revises: 42e85ed5564d +Create Date: 2024-03-07 09:52:00.846136 + +""" +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = 'b5429b71023c' +down_revision = '42e85ed5564d' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('messages', schema=None) as batch_op: + batch_op.alter_column('model_provider', + existing_type=sa.VARCHAR(length=255), + nullable=True) + batch_op.alter_column('model_id', + existing_type=sa.VARCHAR(length=255), + nullable=True) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('messages', schema=None) as batch_op: + batch_op.alter_column('model_id', + existing_type=sa.VARCHAR(length=255), + nullable=False) + batch_op.alter_column('model_provider', + existing_type=sa.VARCHAR(length=255), + nullable=False) + + # ### end Alembic commands ### diff --git a/api/migrations/versions/b69ca54b9208_add_chatbot_color_theme.py b/api/migrations/versions/b69ca54b9208_add_chatbot_color_theme.py new file mode 100644 index 0000000000000000000000000000000000000000..dd5a7495e475f3e3e8fc741de01517e76da196ab --- /dev/null +++ b/api/migrations/versions/b69ca54b9208_add_chatbot_color_theme.py @@ -0,0 +1,35 @@ +"""add chatbot color theme + +Revision ID: b69ca54b9208 +Revises: 4ff534e1eb11 +Create Date: 2024-06-25 01:14:21.523873 + +""" +import sqlalchemy as sa +from alembic import op + +import models as models + +# revision identifiers, used by Alembic. +revision = 'b69ca54b9208' +down_revision = '4ff534e1eb11' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('sites', schema=None) as batch_op: + batch_op.add_column(sa.Column('chat_color_theme', sa.String(length=255), nullable=True)) + batch_op.add_column(sa.Column('chat_color_theme_inverted', sa.Boolean(), server_default=sa.text('false'), nullable=False)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('sites', schema=None) as batch_op: + batch_op.drop_column('chat_color_theme_inverted') + batch_op.drop_column('chat_color_theme') + + # ### end Alembic commands ### diff --git a/api/migrations/versions/bf0aec5ba2cf_add_provider_order.py b/api/migrations/versions/bf0aec5ba2cf_add_provider_order.py new file mode 100644 index 0000000000000000000000000000000000000000..dfa1517462bbb90b30d2db7a5096f6e169f53abe --- /dev/null +++ b/api/migrations/versions/bf0aec5ba2cf_add_provider_order.py @@ -0,0 +1,52 @@ +"""add provider order + +Revision ID: bf0aec5ba2cf +Revises: e35ed59becda +Create Date: 2023-08-10 00:03:44.273430 + +""" +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = 'bf0aec5ba2cf' +down_revision = 'e35ed59becda' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('provider_orders', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', postgresql.UUID(), nullable=False), + sa.Column('provider_name', sa.String(length=40), nullable=False), + sa.Column('account_id', postgresql.UUID(), nullable=False), + sa.Column('payment_product_id', sa.String(length=191), nullable=False), + sa.Column('payment_id', sa.String(length=191), nullable=True), + sa.Column('transaction_id', sa.String(length=191), nullable=True), + sa.Column('quantity', sa.Integer(), server_default=sa.text('1'), nullable=False), + sa.Column('currency', sa.String(length=40), nullable=True), + sa.Column('total_amount', sa.Integer(), nullable=True), + sa.Column('payment_status', sa.String(length=40), server_default=sa.text("'wait_pay'::character varying"), nullable=False), + sa.Column('paid_at', sa.DateTime(), nullable=True), + sa.Column('pay_failed_at', sa.DateTime(), nullable=True), + sa.Column('refunded_at', 
sa.DateTime(), nullable=True), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='provider_order_pkey') + ) + with op.batch_alter_table('provider_orders', schema=None) as batch_op: + batch_op.create_index('provider_order_tenant_provider_idx', ['tenant_id', 'provider_name'], unique=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('provider_orders', schema=None) as batch_op: + batch_op.drop_index('provider_order_tenant_provider_idx') + + op.drop_table('provider_orders') + # ### end Alembic commands ### diff --git a/api/models/__init__.py b/api/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b0b9880ca42a9d02be3395abe7c01cf0876a9158 --- /dev/null +++ b/api/models/__init__.py @@ -0,0 +1,187 @@ +from .account import ( + Account, + AccountIntegrate, + AccountStatus, + InvitationCode, + Tenant, + TenantAccountJoin, + TenantAccountJoinRole, + TenantAccountRole, + TenantStatus, +) +from .api_based_extension import APIBasedExtension, APIBasedExtensionPoint +from .dataset import ( + AppDatasetJoin, + Dataset, + DatasetCollectionBinding, + DatasetKeywordTable, + DatasetPermission, + DatasetPermissionEnum, + DatasetProcessRule, + DatasetQuery, + Document, + DocumentSegment, + Embedding, + ExternalKnowledgeApis, + ExternalKnowledgeBindings, + TidbAuthBinding, + Whitelist, +) +from .engine import db +from .enums import CreatedByRole, UserFrom, WorkflowRunTriggeredFrom +from .model import ( + ApiRequest, + ApiToken, + App, + AppAnnotationHitHistory, + AppAnnotationSetting, + AppMode, + AppModelConfig, + Conversation, + DatasetRetrieverResource, + DifySetup, + EndUser, + IconType, + InstalledApp, + Message, + MessageAgentThought, + 
MessageAnnotation, + MessageChain, + MessageFeedback, + MessageFile, + OperationLog, + RecommendedApp, + Site, + Tag, + TagBinding, + TraceAppConfig, + UploadFile, +) +from .provider import ( + LoadBalancingModelConfig, + Provider, + ProviderModel, + ProviderModelSetting, + ProviderOrder, + ProviderQuotaType, + ProviderType, + TenantDefaultModel, + TenantPreferredModelProvider, +) +from .source import DataSourceApiKeyAuthBinding, DataSourceOauthBinding +from .task import CeleryTask, CeleryTaskSet +from .tools import ( + ApiToolProvider, + BuiltinToolProvider, + PublishedAppTool, + ToolConversationVariables, + ToolFile, + ToolLabelBinding, + ToolModelInvoke, + WorkflowToolProvider, +) +from .web import PinnedConversation, SavedMessage +from .workflow import ( + ConversationVariable, + Workflow, + WorkflowAppLog, + WorkflowAppLogCreatedFrom, + WorkflowNodeExecution, + WorkflowNodeExecutionStatus, + WorkflowNodeExecutionTriggeredFrom, + WorkflowRun, + WorkflowRunStatus, + WorkflowType, +) + +__all__ = [ + "APIBasedExtension", + "APIBasedExtensionPoint", + "Account", + "AccountIntegrate", + "AccountStatus", + "ApiRequest", + "ApiToken", + "ApiToolProvider", # Added + "App", + "AppAnnotationHitHistory", + "AppAnnotationSetting", + "AppDatasetJoin", + "AppMode", + "AppModelConfig", + "BuiltinToolProvider", # Added + "CeleryTask", + "CeleryTaskSet", + "Conversation", + "ConversationVariable", + "CreatedByRole", + "DataSourceApiKeyAuthBinding", + "DataSourceOauthBinding", + "Dataset", + "DatasetCollectionBinding", + "DatasetKeywordTable", + "DatasetPermission", + "DatasetPermissionEnum", + "DatasetProcessRule", + "DatasetQuery", + "DatasetRetrieverResource", + "DifySetup", + "Document", + "DocumentSegment", + "Embedding", + "EndUser", + "ExternalKnowledgeApis", + "ExternalKnowledgeBindings", + "IconType", + "InstalledApp", + "InvitationCode", + "LoadBalancingModelConfig", + "Message", + "MessageAgentThought", + "MessageAnnotation", + "MessageChain", + "MessageFeedback", + 
"MessageFile", + "OperationLog", + "PinnedConversation", + "Provider", + "ProviderModel", + "ProviderModelSetting", + "ProviderOrder", + "ProviderQuotaType", + "ProviderType", + "PublishedAppTool", + "RecommendedApp", + "SavedMessage", + "Site", + "Tag", + "TagBinding", + "Tenant", + "TenantAccountJoin", + "TenantAccountJoinRole", + "TenantAccountRole", + "TenantDefaultModel", + "TenantPreferredModelProvider", + "TenantStatus", + "TidbAuthBinding", + "ToolConversationVariables", + "ToolFile", + "ToolLabelBinding", + "ToolModelInvoke", + "TraceAppConfig", + "UploadFile", + "UserFrom", + "Whitelist", + "Workflow", + "WorkflowAppLog", + "WorkflowAppLogCreatedFrom", + "WorkflowNodeExecution", + "WorkflowNodeExecutionStatus", + "WorkflowNodeExecutionTriggeredFrom", + "WorkflowRun", + "WorkflowRunStatus", + "WorkflowRunTriggeredFrom", + "WorkflowToolProvider", + "WorkflowType", + "db", +] diff --git a/api/models/account.py b/api/models/account.py new file mode 100644 index 0000000000000000000000000000000000000000..35a28df7505943cc29c5388f020486e9f81a7ba2 --- /dev/null +++ b/api/models/account.py @@ -0,0 +1,267 @@ +import enum +import json + +from flask_login import UserMixin # type: ignore +from sqlalchemy import func +from sqlalchemy.orm import Mapped, mapped_column + +from .engine import db +from .types import StringUUID + + +class AccountStatus(enum.StrEnum): + PENDING = "pending" + UNINITIALIZED = "uninitialized" + ACTIVE = "active" + BANNED = "banned" + CLOSED = "closed" + + +class Account(UserMixin, db.Model): # type: ignore[name-defined] + __tablename__ = "accounts" + __table_args__ = (db.PrimaryKeyConstraint("id", name="account_pkey"), db.Index("account_email_idx", "email")) + + id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + name = db.Column(db.String(255), nullable=False) + email = db.Column(db.String(255), nullable=False) + password = db.Column(db.String(255), nullable=True) + password_salt = 
db.Column(db.String(255), nullable=True) + avatar = db.Column(db.String(255)) + interface_language = db.Column(db.String(255)) + interface_theme = db.Column(db.String(255)) + timezone = db.Column(db.String(255)) + last_login_at = db.Column(db.DateTime) + last_login_ip = db.Column(db.String(255)) + last_active_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + status = db.Column(db.String(16), nullable=False, server_default=db.text("'active'::character varying")) + initialized_at = db.Column(db.DateTime) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + + @property + def is_password_set(self): + return self.password is not None + + @property + def current_tenant(self): + # FIXME: fix the type error later, because the type is important maybe cause some bugs + return self._current_tenant # type: ignore + + @current_tenant.setter + def current_tenant(self, value: "Tenant"): + tenant = value + ta = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=self.id).first() + if ta: + tenant.current_role = ta.role + else: + # FIXME: fix the type error later, because the type is important maybe cause some bugs + tenant = None # type: ignore + self._current_tenant = tenant + + @property + def current_tenant_id(self) -> str | None: + return self._current_tenant.id if self._current_tenant else None + + @current_tenant_id.setter + def current_tenant_id(self, value: str): + try: + tenant_account_join = ( + db.session.query(Tenant, TenantAccountJoin) + .filter(Tenant.id == value) + .filter(TenantAccountJoin.tenant_id == Tenant.id) + .filter(TenantAccountJoin.account_id == self.id) + .one_or_none() + ) + + if tenant_account_join: + tenant, ta = tenant_account_join + tenant.current_role = ta.role + else: + tenant = None + except: + tenant = None + + self._current_tenant = tenant + + @property + def 
current_role(self): + return self._current_tenant.current_role + + def get_status(self) -> AccountStatus: + status_str = self.status + return AccountStatus(status_str) + + @classmethod + def get_by_openid(cls, provider: str, open_id: str): + account_integrate = ( + db.session.query(AccountIntegrate) + .filter(AccountIntegrate.provider == provider, AccountIntegrate.open_id == open_id) + .one_or_none() + ) + if account_integrate: + return db.session.query(Account).filter(Account.id == account_integrate.account_id).one_or_none() + return None + + @property + def is_admin_or_owner(self): + return TenantAccountRole.is_privileged_role(self._current_tenant.current_role) + + @property + def is_admin(self): + return TenantAccountRole.is_admin_role(self._current_tenant.current_role) + + @property + def is_editor(self): + return TenantAccountRole.is_editing_role(self._current_tenant.current_role) + + @property + def is_dataset_editor(self): + return TenantAccountRole.is_dataset_edit_role(self._current_tenant.current_role) + + @property + def is_dataset_operator(self): + return self._current_tenant.current_role == TenantAccountRole.DATASET_OPERATOR + + +class TenantStatus(enum.StrEnum): + NORMAL = "normal" + ARCHIVE = "archive" + + +class TenantAccountRole(enum.StrEnum): + OWNER = "owner" + ADMIN = "admin" + EDITOR = "editor" + NORMAL = "normal" + DATASET_OPERATOR = "dataset_operator" + + @staticmethod + def is_valid_role(role: str) -> bool: + return role in { + TenantAccountRole.OWNER, + TenantAccountRole.ADMIN, + TenantAccountRole.EDITOR, + TenantAccountRole.NORMAL, + TenantAccountRole.DATASET_OPERATOR, + } + + @staticmethod + def is_privileged_role(role: str) -> bool: + return role in {TenantAccountRole.OWNER, TenantAccountRole.ADMIN} + + @staticmethod + def is_admin_role(role: str) -> bool: + return role == TenantAccountRole.ADMIN + + @staticmethod + def is_non_owner_role(role: str) -> bool: + return role in { + TenantAccountRole.ADMIN, + TenantAccountRole.EDITOR, + 
TenantAccountRole.NORMAL, + TenantAccountRole.DATASET_OPERATOR, + } + + @staticmethod + def is_editing_role(role: str) -> bool: + return role in {TenantAccountRole.OWNER, TenantAccountRole.ADMIN, TenantAccountRole.EDITOR} + + @staticmethod + def is_dataset_edit_role(role: str) -> bool: + return role in { + TenantAccountRole.OWNER, + TenantAccountRole.ADMIN, + TenantAccountRole.EDITOR, + TenantAccountRole.DATASET_OPERATOR, + } + + +class Tenant(db.Model): # type: ignore[name-defined] + __tablename__ = "tenants" + __table_args__ = (db.PrimaryKeyConstraint("id", name="tenant_pkey"),) + + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + name = db.Column(db.String(255), nullable=False) + encrypt_public_key = db.Column(db.Text) + plan = db.Column(db.String(255), nullable=False, server_default=db.text("'basic'::character varying")) + status = db.Column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying")) + custom_config = db.Column(db.Text) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + + def get_accounts(self) -> list[Account]: + return ( + db.session.query(Account) + .filter(Account.id == TenantAccountJoin.account_id, TenantAccountJoin.tenant_id == self.id) + .all() + ) + + @property + def custom_config_dict(self) -> dict: + return json.loads(self.custom_config) if self.custom_config else {} + + @custom_config_dict.setter + def custom_config_dict(self, value: dict): + self.custom_config = json.dumps(value) + + +class TenantAccountJoinRole(enum.Enum): + OWNER = "owner" + ADMIN = "admin" + NORMAL = "normal" + DATASET_OPERATOR = "dataset_operator" + + +class TenantAccountJoin(db.Model): # type: ignore[name-defined] + __tablename__ = "tenant_account_joins" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="tenant_account_join_pkey"), + 
db.Index("tenant_account_join_account_id_idx", "account_id"), + db.Index("tenant_account_join_tenant_id_idx", "tenant_id"), + db.UniqueConstraint("tenant_id", "account_id", name="unique_tenant_account_join"), + ) + + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + tenant_id = db.Column(StringUUID, nullable=False) + account_id = db.Column(StringUUID, nullable=False) + current = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) + role = db.Column(db.String(16), nullable=False, server_default="normal") + invited_by = db.Column(StringUUID, nullable=True) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + + +class AccountIntegrate(db.Model): # type: ignore[name-defined] + __tablename__ = "account_integrates" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="account_integrate_pkey"), + db.UniqueConstraint("account_id", "provider", name="unique_account_provider"), + db.UniqueConstraint("provider", "open_id", name="unique_provider_open_id"), + ) + + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + account_id = db.Column(StringUUID, nullable=False) + provider = db.Column(db.String(16), nullable=False) + open_id = db.Column(db.String(255), nullable=False) + encrypted_token = db.Column(db.String(255), nullable=False) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + + +class InvitationCode(db.Model): # type: ignore[name-defined] + __tablename__ = "invitation_codes" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="invitation_code_pkey"), + db.Index("invitation_codes_batch_idx", "batch"), + db.Index("invitation_codes_code_idx", "code", "status"), + ) + + id = db.Column(db.Integer, nullable=False) + batch = 
db.Column(db.String(255), nullable=False) + code = db.Column(db.String(32), nullable=False) + status = db.Column(db.String(16), nullable=False, server_default=db.text("'unused'::character varying")) + used_at = db.Column(db.DateTime) + used_by_tenant_id = db.Column(StringUUID) + used_by_account_id = db.Column(StringUUID) + deprecated_at = db.Column(db.DateTime) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) diff --git a/api/models/api_based_extension.py b/api/models/api_based_extension.py new file mode 100644 index 0000000000000000000000000000000000000000..6b6d808710afc00933f9aad948407fd37b984a1f --- /dev/null +++ b/api/models/api_based_extension.py @@ -0,0 +1,28 @@ +import enum + +from sqlalchemy import func + +from .engine import db +from .types import StringUUID + + +class APIBasedExtensionPoint(enum.Enum): + APP_EXTERNAL_DATA_TOOL_QUERY = "app.external_data_tool.query" + PING = "ping" + APP_MODERATION_INPUT = "app.moderation.input" + APP_MODERATION_OUTPUT = "app.moderation.output" + + +class APIBasedExtension(db.Model): # type: ignore[name-defined] + __tablename__ = "api_based_extensions" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="api_based_extension_pkey"), + db.Index("api_based_extension_tenant_idx", "tenant_id"), + ) + + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + tenant_id = db.Column(StringUUID, nullable=False) + name = db.Column(db.String(255), nullable=False) + api_endpoint = db.Column(db.String(255), nullable=False) + api_key = db.Column(db.Text, nullable=False) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) diff --git a/api/models/dataset.py b/api/models/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..1cf3dc42fe8235ace8c0dea14c369b2dd9090395 --- /dev/null +++ b/api/models/dataset.py @@ -0,0 +1,928 @@ +import base64 +import enum +import hashlib +import hmac +import json +import 
logging +import os +import pickle +import re +import time +from json import JSONDecodeError +from typing import Any, cast + +from sqlalchemy import func +from sqlalchemy.dialects.postgresql import JSONB +from sqlalchemy.orm import Mapped + +from configs import dify_config +from core.rag.retrieval.retrieval_methods import RetrievalMethod +from extensions.ext_storage import storage +from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule + +from .account import Account +from .engine import db +from .model import App, Tag, TagBinding, UploadFile +from .types import StringUUID + + +class DatasetPermissionEnum(enum.StrEnum): + ONLY_ME = "only_me" + ALL_TEAM = "all_team_members" + PARTIAL_TEAM = "partial_members" + + +class Dataset(db.Model): # type: ignore[name-defined] + __tablename__ = "datasets" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="dataset_pkey"), + db.Index("dataset_tenant_idx", "tenant_id"), + db.Index("retrieval_model_idx", "retrieval_model", postgresql_using="gin"), + ) + + INDEXING_TECHNIQUE_LIST = ["high_quality", "economy", None] + PROVIDER_LIST = ["vendor", "external", None] + + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + tenant_id = db.Column(StringUUID, nullable=False) + name = db.Column(db.String(255), nullable=False) + description = db.Column(db.Text, nullable=True) + provider = db.Column(db.String(255), nullable=False, server_default=db.text("'vendor'::character varying")) + permission = db.Column(db.String(255), nullable=False, server_default=db.text("'only_me'::character varying")) + data_source_type = db.Column(db.String(255)) + indexing_technique = db.Column(db.String(255), nullable=True) + index_struct = db.Column(db.Text, nullable=True) + created_by = db.Column(StringUUID, nullable=False) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_by = db.Column(StringUUID, nullable=True) + updated_at = 
db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + embedding_model = db.Column(db.String(255), nullable=True) + embedding_model_provider = db.Column(db.String(255), nullable=True) + collection_binding_id = db.Column(StringUUID, nullable=True) + retrieval_model = db.Column(JSONB, nullable=True) + + @property + def dataset_keyword_table(self): + dataset_keyword_table = ( + db.session.query(DatasetKeywordTable).filter(DatasetKeywordTable.dataset_id == self.id).first() + ) + if dataset_keyword_table: + return dataset_keyword_table + + return None + + @property + def index_struct_dict(self): + return json.loads(self.index_struct) if self.index_struct else None + + @property + def external_retrieval_model(self): + default_retrieval_model = { + "top_k": 2, + "score_threshold": 0.0, + } + return self.retrieval_model or default_retrieval_model + + @property + def created_by_account(self): + return db.session.get(Account, self.created_by) + + @property + def latest_process_rule(self): + return ( + DatasetProcessRule.query.filter(DatasetProcessRule.dataset_id == self.id) + .order_by(DatasetProcessRule.created_at.desc()) + .first() + ) + + @property + def app_count(self): + return ( + db.session.query(func.count(AppDatasetJoin.id)) + .filter(AppDatasetJoin.dataset_id == self.id, App.id == AppDatasetJoin.app_id) + .scalar() + ) + + @property + def document_count(self): + return db.session.query(func.count(Document.id)).filter(Document.dataset_id == self.id).scalar() + + @property + def available_document_count(self): + return ( + db.session.query(func.count(Document.id)) + .filter( + Document.dataset_id == self.id, + Document.indexing_status == "completed", + Document.enabled == True, + Document.archived == False, + ) + .scalar() + ) + + @property + def available_segment_count(self): + return ( + db.session.query(func.count(DocumentSegment.id)) + .filter( + DocumentSegment.dataset_id == self.id, + DocumentSegment.status == "completed", + 
DocumentSegment.enabled == True, + ) + .scalar() + ) + + @property + def word_count(self): + return ( + Document.query.with_entities(func.coalesce(func.sum(Document.word_count))) + .filter(Document.dataset_id == self.id) + .scalar() + ) + + @property + def doc_form(self): + document = db.session.query(Document).filter(Document.dataset_id == self.id).first() + if document: + return document.doc_form + return None + + @property + def retrieval_model_dict(self): + default_retrieval_model = { + "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, + "reranking_enable": False, + "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, + "top_k": 2, + "score_threshold_enabled": False, + } + return self.retrieval_model or default_retrieval_model + + @property + def tags(self): + tags = ( + db.session.query(Tag) + .join(TagBinding, Tag.id == TagBinding.tag_id) + .filter( + TagBinding.target_id == self.id, + TagBinding.tenant_id == self.tenant_id, + Tag.tenant_id == self.tenant_id, + Tag.type == "knowledge", + ) + .all() + ) + + return tags or [] + + @property + def external_knowledge_info(self): + if self.provider != "external": + return None + external_knowledge_binding = ( + db.session.query(ExternalKnowledgeBindings).filter(ExternalKnowledgeBindings.dataset_id == self.id).first() + ) + if not external_knowledge_binding: + return None + external_knowledge_api = ( + db.session.query(ExternalKnowledgeApis) + .filter(ExternalKnowledgeApis.id == external_knowledge_binding.external_knowledge_api_id) + .first() + ) + if not external_knowledge_api: + return None + return { + "external_knowledge_id": external_knowledge_binding.external_knowledge_id, + "external_knowledge_api_id": external_knowledge_api.id, + "external_knowledge_api_name": external_knowledge_api.name, + "external_knowledge_api_endpoint": json.loads(external_knowledge_api.settings).get("endpoint", ""), + } + + @staticmethod + def gen_collection_name_by_id(dataset_id: str) -> str: + 
normalized_dataset_id = dataset_id.replace("-", "_") + return f"Vector_index_{normalized_dataset_id}_Node" + + +class DatasetProcessRule(db.Model): # type: ignore[name-defined] + __tablename__ = "dataset_process_rules" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="dataset_process_rule_pkey"), + db.Index("dataset_process_rule_dataset_id_idx", "dataset_id"), + ) + + id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) + dataset_id = db.Column(StringUUID, nullable=False) + mode = db.Column(db.String(255), nullable=False, server_default=db.text("'automatic'::character varying")) + rules = db.Column(db.Text, nullable=True) + created_by = db.Column(StringUUID, nullable=False) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + + MODES = ["automatic", "custom", "hierarchical"] + PRE_PROCESSING_RULES = ["remove_stopwords", "remove_extra_spaces", "remove_urls_emails"] + AUTOMATIC_RULES: dict[str, Any] = { + "pre_processing_rules": [ + {"id": "remove_extra_spaces", "enabled": True}, + {"id": "remove_urls_emails", "enabled": False}, + ], + "segmentation": {"delimiter": "\n", "max_tokens": 500, "chunk_overlap": 50}, + } + + def to_dict(self): + return { + "id": self.id, + "dataset_id": self.dataset_id, + "mode": self.mode, + "rules": self.rules_dict, + } + + @property + def rules_dict(self): + try: + return json.loads(self.rules) if self.rules else None + except JSONDecodeError: + return None + + +class Document(db.Model): # type: ignore[name-defined] + __tablename__ = "documents" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="document_pkey"), + db.Index("document_dataset_id_idx", "dataset_id"), + db.Index("document_is_paused_idx", "is_paused"), + db.Index("document_tenant_idx", "tenant_id"), + ) + + # initial fields + id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) + tenant_id = db.Column(StringUUID, nullable=False) + dataset_id = 
db.Column(StringUUID, nullable=False) + position = db.Column(db.Integer, nullable=False) + data_source_type = db.Column(db.String(255), nullable=False) + data_source_info = db.Column(db.Text, nullable=True) + dataset_process_rule_id = db.Column(StringUUID, nullable=True) + batch = db.Column(db.String(255), nullable=False) + name = db.Column(db.String(255), nullable=False) + created_from = db.Column(db.String(255), nullable=False) + created_by = db.Column(StringUUID, nullable=False) + created_api_request_id = db.Column(StringUUID, nullable=True) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + + # start processing + processing_started_at = db.Column(db.DateTime, nullable=True) + + # parsing + file_id = db.Column(db.Text, nullable=True) + word_count = db.Column(db.Integer, nullable=True) + parsing_completed_at = db.Column(db.DateTime, nullable=True) + + # cleaning + cleaning_completed_at = db.Column(db.DateTime, nullable=True) + + # split + splitting_completed_at = db.Column(db.DateTime, nullable=True) + + # indexing + tokens = db.Column(db.Integer, nullable=True) + indexing_latency = db.Column(db.Float, nullable=True) + completed_at = db.Column(db.DateTime, nullable=True) + + # pause + is_paused = db.Column(db.Boolean, nullable=True, server_default=db.text("false")) + paused_by = db.Column(StringUUID, nullable=True) + paused_at = db.Column(db.DateTime, nullable=True) + + # error + error = db.Column(db.Text, nullable=True) + stopped_at = db.Column(db.DateTime, nullable=True) + + # basic fields + indexing_status = db.Column(db.String(255), nullable=False, server_default=db.text("'waiting'::character varying")) + enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) + disabled_at = db.Column(db.DateTime, nullable=True) + disabled_by = db.Column(StringUUID, nullable=True) + archived = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) + archived_reason = 
db.Column(db.String(255), nullable=True) + archived_by = db.Column(StringUUID, nullable=True) + archived_at = db.Column(db.DateTime, nullable=True) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + doc_type = db.Column(db.String(40), nullable=True) + doc_metadata = db.Column(db.JSON, nullable=True) + doc_form = db.Column(db.String(255), nullable=False, server_default=db.text("'text_model'::character varying")) + doc_language = db.Column(db.String(255), nullable=True) + + DATA_SOURCES = ["upload_file", "notion_import", "website_crawl"] + + @property + def display_status(self): + status = None + if self.indexing_status == "waiting": + status = "queuing" + elif self.indexing_status not in {"completed", "error", "waiting"} and self.is_paused: + status = "paused" + elif self.indexing_status in {"parsing", "cleaning", "splitting", "indexing"}: + status = "indexing" + elif self.indexing_status == "error": + status = "error" + elif self.indexing_status == "completed" and not self.archived and self.enabled: + status = "available" + elif self.indexing_status == "completed" and not self.archived and not self.enabled: + status = "disabled" + elif self.indexing_status == "completed" and self.archived: + status = "archived" + return status + + @property + def data_source_info_dict(self): + if self.data_source_info: + try: + data_source_info_dict = json.loads(self.data_source_info) + except JSONDecodeError: + data_source_info_dict = {} + + return data_source_info_dict + return None + + @property + def data_source_detail_dict(self): + if self.data_source_info: + if self.data_source_type == "upload_file": + data_source_info_dict = json.loads(self.data_source_info) + file_detail = ( + db.session.query(UploadFile) + .filter(UploadFile.id == data_source_info_dict["upload_file_id"]) + .one_or_none() + ) + if file_detail: + return { + "upload_file": { + "id": file_detail.id, + "name": file_detail.name, + "size": file_detail.size, + 
"extension": file_detail.extension, + "mime_type": file_detail.mime_type, + "created_by": file_detail.created_by, + "created_at": file_detail.created_at.timestamp(), + } + } + elif self.data_source_type in {"notion_import", "website_crawl"}: + return json.loads(self.data_source_info) + return {} + + @property + def average_segment_length(self): + if self.word_count and self.word_count != 0 and self.segment_count and self.segment_count != 0: + return self.word_count // self.segment_count + return 0 + + @property + def dataset_process_rule(self): + if self.dataset_process_rule_id: + return db.session.get(DatasetProcessRule, self.dataset_process_rule_id) + return None + + @property + def dataset(self): + return db.session.query(Dataset).filter(Dataset.id == self.dataset_id).one_or_none() + + @property + def segment_count(self): + return DocumentSegment.query.filter(DocumentSegment.document_id == self.id).count() + + @property + def hit_count(self): + return ( + DocumentSegment.query.with_entities(func.coalesce(func.sum(DocumentSegment.hit_count))) + .filter(DocumentSegment.document_id == self.id) + .scalar() + ) + + @property + def process_rule_dict(self): + if self.dataset_process_rule_id: + return self.dataset_process_rule.to_dict() + return None + + def to_dict(self): + return { + "id": self.id, + "tenant_id": self.tenant_id, + "dataset_id": self.dataset_id, + "position": self.position, + "data_source_type": self.data_source_type, + "data_source_info": self.data_source_info, + "dataset_process_rule_id": self.dataset_process_rule_id, + "batch": self.batch, + "name": self.name, + "created_from": self.created_from, + "created_by": self.created_by, + "created_api_request_id": self.created_api_request_id, + "created_at": self.created_at, + "processing_started_at": self.processing_started_at, + "file_id": self.file_id, + "word_count": self.word_count, + "parsing_completed_at": self.parsing_completed_at, + "cleaning_completed_at": self.cleaning_completed_at, + 
"splitting_completed_at": self.splitting_completed_at, + "tokens": self.tokens, + "indexing_latency": self.indexing_latency, + "completed_at": self.completed_at, + "is_paused": self.is_paused, + "paused_by": self.paused_by, + "paused_at": self.paused_at, + "error": self.error, + "stopped_at": self.stopped_at, + "indexing_status": self.indexing_status, + "enabled": self.enabled, + "disabled_at": self.disabled_at, + "disabled_by": self.disabled_by, + "archived": self.archived, + "archived_reason": self.archived_reason, + "archived_by": self.archived_by, + "archived_at": self.archived_at, + "updated_at": self.updated_at, + "doc_type": self.doc_type, + "doc_metadata": self.doc_metadata, + "doc_form": self.doc_form, + "doc_language": self.doc_language, + "display_status": self.display_status, + "data_source_info_dict": self.data_source_info_dict, + "average_segment_length": self.average_segment_length, + "dataset_process_rule": self.dataset_process_rule.to_dict() if self.dataset_process_rule else None, + "dataset": self.dataset.to_dict() if self.dataset else None, + "segment_count": self.segment_count, + "hit_count": self.hit_count, + } + + @classmethod + def from_dict(cls, data: dict): + return cls( + id=data.get("id"), + tenant_id=data.get("tenant_id"), + dataset_id=data.get("dataset_id"), + position=data.get("position"), + data_source_type=data.get("data_source_type"), + data_source_info=data.get("data_source_info"), + dataset_process_rule_id=data.get("dataset_process_rule_id"), + batch=data.get("batch"), + name=data.get("name"), + created_from=data.get("created_from"), + created_by=data.get("created_by"), + created_api_request_id=data.get("created_api_request_id"), + created_at=data.get("created_at"), + processing_started_at=data.get("processing_started_at"), + file_id=data.get("file_id"), + word_count=data.get("word_count"), + parsing_completed_at=data.get("parsing_completed_at"), + cleaning_completed_at=data.get("cleaning_completed_at"), + 
splitting_completed_at=data.get("splitting_completed_at"), + tokens=data.get("tokens"), + indexing_latency=data.get("indexing_latency"), + completed_at=data.get("completed_at"), + is_paused=data.get("is_paused"), + paused_by=data.get("paused_by"), + paused_at=data.get("paused_at"), + error=data.get("error"), + stopped_at=data.get("stopped_at"), + indexing_status=data.get("indexing_status"), + enabled=data.get("enabled"), + disabled_at=data.get("disabled_at"), + disabled_by=data.get("disabled_by"), + archived=data.get("archived"), + archived_reason=data.get("archived_reason"), + archived_by=data.get("archived_by"), + archived_at=data.get("archived_at"), + updated_at=data.get("updated_at"), + doc_type=data.get("doc_type"), + doc_metadata=data.get("doc_metadata"), + doc_form=data.get("doc_form"), + doc_language=data.get("doc_language"), + ) + + +class DocumentSegment(db.Model): # type: ignore[name-defined] + __tablename__ = "document_segments" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="document_segment_pkey"), + db.Index("document_segment_dataset_id_idx", "dataset_id"), + db.Index("document_segment_document_id_idx", "document_id"), + db.Index("document_segment_tenant_dataset_idx", "dataset_id", "tenant_id"), + db.Index("document_segment_tenant_document_idx", "document_id", "tenant_id"), + db.Index("document_segment_dataset_node_idx", "dataset_id", "index_node_id"), + db.Index("document_segment_tenant_idx", "tenant_id"), + ) + + # initial fields + id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) + tenant_id = db.Column(StringUUID, nullable=False) + dataset_id = db.Column(StringUUID, nullable=False) + document_id = db.Column(StringUUID, nullable=False) + position: Mapped[int] + content = db.Column(db.Text, nullable=False) + answer = db.Column(db.Text, nullable=True) + word_count = db.Column(db.Integer, nullable=False) + tokens = db.Column(db.Integer, nullable=False) + + # indexing fields + keywords = 
db.Column(db.JSON, nullable=True) + index_node_id = db.Column(db.String(255), nullable=True) + index_node_hash = db.Column(db.String(255), nullable=True) + + # basic fields + hit_count = db.Column(db.Integer, nullable=False, default=0) + enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) + disabled_at = db.Column(db.DateTime, nullable=True) + disabled_by = db.Column(StringUUID, nullable=True) + status = db.Column(db.String(255), nullable=False, server_default=db.text("'waiting'::character varying")) + created_by = db.Column(StringUUID, nullable=False) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_by = db.Column(StringUUID, nullable=True) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + indexing_at = db.Column(db.DateTime, nullable=True) + completed_at = db.Column(db.DateTime, nullable=True) + error = db.Column(db.Text, nullable=True) + stopped_at = db.Column(db.DateTime, nullable=True) + + @property + def dataset(self): + return db.session.query(Dataset).filter(Dataset.id == self.dataset_id).first() + + @property + def document(self): + return db.session.query(Document).filter(Document.id == self.document_id).first() + + @property + def previous_segment(self): + return ( + db.session.query(DocumentSegment) + .filter(DocumentSegment.document_id == self.document_id, DocumentSegment.position == self.position - 1) + .first() + ) + + @property + def next_segment(self): + return ( + db.session.query(DocumentSegment) + .filter(DocumentSegment.document_id == self.document_id, DocumentSegment.position == self.position + 1) + .first() + ) + + @property + def child_chunks(self): + process_rule = self.document.dataset_process_rule + if process_rule.mode == "hierarchical": + rules = Rule(**process_rule.rules_dict) + if rules.parent_mode and rules.parent_mode != ParentMode.FULL_DOC: + child_chunks = ( + db.session.query(ChildChunk) + 
.filter(ChildChunk.segment_id == self.id) + .order_by(ChildChunk.position.asc()) + .all() + ) + return child_chunks or [] + else: + return [] + else: + return [] + + def get_sign_content(self): + signed_urls = [] + text = self.content + + # For data before v0.10.0 + pattern = r"/files/([a-f0-9\-]+)/image-preview" + matches = re.finditer(pattern, text) + for match in matches: + upload_file_id = match.group(1) + nonce = os.urandom(16).hex() + timestamp = str(int(time.time())) + data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}" + secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b"" + sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() + encoded_sign = base64.urlsafe_b64encode(sign).decode() + + params = f"timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}" + signed_url = f"{match.group(0)}?{params}" + signed_urls.append((match.start(), match.end(), signed_url)) + + # For data after v0.10.0 + pattern = r"/files/([a-f0-9\-]+)/file-preview" + matches = re.finditer(pattern, text) + for match in matches: + upload_file_id = match.group(1) + nonce = os.urandom(16).hex() + timestamp = str(int(time.time())) + data_to_sign = f"file-preview|{upload_file_id}|{timestamp}|{nonce}" + secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b"" + sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() + encoded_sign = base64.urlsafe_b64encode(sign).decode() + + params = f"timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}" + signed_url = f"{match.group(0)}?{params}" + signed_urls.append((match.start(), match.end(), signed_url)) + + # Reconstruct the text with signed URLs + offset = 0 + for start, end, signed_url in signed_urls: + text = text[: start + offset] + signed_url + text[end + offset :] + offset += len(signed_url) - (end - start) + + return text + + +class ChildChunk(db.Model): # type: ignore[name-defined] + __tablename__ = "child_chunks" + __table_args__ = ( 
+ db.PrimaryKeyConstraint("id", name="child_chunk_pkey"), + db.Index("child_chunk_dataset_id_idx", "tenant_id", "dataset_id", "document_id", "segment_id", "index_node_id"), + ) + + # initial fields + id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) + tenant_id = db.Column(StringUUID, nullable=False) + dataset_id = db.Column(StringUUID, nullable=False) + document_id = db.Column(StringUUID, nullable=False) + segment_id = db.Column(StringUUID, nullable=False) + position = db.Column(db.Integer, nullable=False) + content = db.Column(db.Text, nullable=False) + word_count = db.Column(db.Integer, nullable=False) + # indexing fields + index_node_id = db.Column(db.String(255), nullable=True) + index_node_hash = db.Column(db.String(255), nullable=True) + type = db.Column(db.String(255), nullable=False, server_default=db.text("'automatic'::character varying")) + created_by = db.Column(StringUUID, nullable=False) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_by = db.Column(StringUUID, nullable=True) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + indexing_at = db.Column(db.DateTime, nullable=True) + completed_at = db.Column(db.DateTime, nullable=True) + error = db.Column(db.Text, nullable=True) + + @property + def dataset(self): + return db.session.query(Dataset).filter(Dataset.id == self.dataset_id).first() + + @property + def document(self): + return db.session.query(Document).filter(Document.id == self.document_id).first() + + @property + def segment(self): + return db.session.query(DocumentSegment).filter(DocumentSegment.id == self.segment_id).first() + + +class AppDatasetJoin(db.Model): # type: ignore[name-defined] + __tablename__ = "app_dataset_joins" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="app_dataset_join_pkey"), + db.Index("app_dataset_join_app_dataset_idx", "dataset_id", "app_id"), + ) + + 
id = db.Column(StringUUID, primary_key=True, nullable=False, server_default=db.text("uuid_generate_v4()")) + app_id = db.Column(StringUUID, nullable=False) + dataset_id = db.Column(StringUUID, nullable=False) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp()) + + @property + def app(self): + return db.session.get(App, self.app_id) + + +class DatasetQuery(db.Model): # type: ignore[name-defined] + __tablename__ = "dataset_queries" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="dataset_query_pkey"), + db.Index("dataset_query_dataset_id_idx", "dataset_id"), + ) + + id = db.Column(StringUUID, primary_key=True, nullable=False, server_default=db.text("uuid_generate_v4()")) + dataset_id = db.Column(StringUUID, nullable=False) + content = db.Column(db.Text, nullable=False) + source = db.Column(db.String(255), nullable=False) + source_app_id = db.Column(StringUUID, nullable=True) + created_by_role = db.Column(db.String, nullable=False) + created_by = db.Column(StringUUID, nullable=False) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp()) + + +class DatasetKeywordTable(db.Model): # type: ignore[name-defined] + __tablename__ = "dataset_keyword_tables" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="dataset_keyword_table_pkey"), + db.Index("dataset_keyword_table_dataset_id_idx", "dataset_id"), + ) + + id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) + dataset_id = db.Column(StringUUID, nullable=False, unique=True) + keyword_table = db.Column(db.Text, nullable=False) + data_source_type = db.Column( + db.String(255), nullable=False, server_default=db.text("'database'::character varying") + ) + + @property + def keyword_table_dict(self): + class SetDecoder(json.JSONDecoder): + def __init__(self, *args, **kwargs): + super().__init__(object_hook=self.object_hook, *args, **kwargs) + + def object_hook(self, dct): + if 
isinstance(dct, dict): + for keyword, node_idxs in dct.items(): + if isinstance(node_idxs, list): + dct[keyword] = set(node_idxs) + return dct + + # get dataset + dataset = Dataset.query.filter_by(id=self.dataset_id).first() + if not dataset: + return None + if self.data_source_type == "database": + return json.loads(self.keyword_table, cls=SetDecoder) if self.keyword_table else None + else: + file_key = "keyword_files/" + dataset.tenant_id + "/" + self.dataset_id + ".txt" + try: + keyword_table_text = storage.load_once(file_key) + if keyword_table_text: + return json.loads(keyword_table_text.decode("utf-8"), cls=SetDecoder) + return None + except Exception as e: + logging.exception(f"Failed to load keyword table from file: {file_key}") + return None + + +class Embedding(db.Model): # type: ignore[name-defined] + __tablename__ = "embeddings" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="embedding_pkey"), + db.UniqueConstraint("model_name", "hash", "provider_name", name="embedding_hash_idx"), + db.Index("created_at_idx", "created_at"), + ) + + id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) + model_name = db.Column( + db.String(255), nullable=False, server_default=db.text("'text-embedding-ada-002'::character varying") + ) + hash = db.Column(db.String(64), nullable=False) + embedding = db.Column(db.LargeBinary, nullable=False) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + provider_name = db.Column(db.String(255), nullable=False, server_default=db.text("''::character varying")) + + def set_embedding(self, embedding_data: list[float]): + self.embedding = pickle.dumps(embedding_data, protocol=pickle.HIGHEST_PROTOCOL) + + def get_embedding(self) -> list[float]: + return cast(list[float], pickle.loads(self.embedding)) + + +class DatasetCollectionBinding(db.Model): # type: ignore[name-defined] + __tablename__ = "dataset_collection_bindings" + __table_args__ = ( + 
db.PrimaryKeyConstraint("id", name="dataset_collection_bindings_pkey"), + db.Index("provider_model_name_idx", "provider_name", "model_name"), + ) + + id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) + provider_name = db.Column(db.String(40), nullable=False) + model_name = db.Column(db.String(255), nullable=False) + type = db.Column(db.String(40), server_default=db.text("'dataset'::character varying"), nullable=False) + collection_name = db.Column(db.String(64), nullable=False) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + + +class TidbAuthBinding(db.Model): # type: ignore[name-defined] + __tablename__ = "tidb_auth_bindings" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="tidb_auth_bindings_pkey"), + db.Index("tidb_auth_bindings_tenant_idx", "tenant_id"), + db.Index("tidb_auth_bindings_active_idx", "active"), + db.Index("tidb_auth_bindings_created_at_idx", "created_at"), + db.Index("tidb_auth_bindings_status_idx", "status"), + ) + id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) + tenant_id = db.Column(StringUUID, nullable=True) + cluster_id = db.Column(db.String(255), nullable=False) + cluster_name = db.Column(db.String(255), nullable=False) + active = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) + status = db.Column(db.String(255), nullable=False, server_default=db.text("CREATING")) + account = db.Column(db.String(255), nullable=False) + password = db.Column(db.String(255), nullable=False) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + + +class Whitelist(db.Model): # type: ignore[name-defined] + __tablename__ = "whitelists" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="whitelists_pkey"), + db.Index("whitelists_tenant_idx", "tenant_id"), + ) + id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) + 
tenant_id = db.Column(StringUUID, nullable=True) + category = db.Column(db.String(255), nullable=False) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + + +class DatasetPermission(db.Model): # type: ignore[name-defined] + __tablename__ = "dataset_permissions" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="dataset_permission_pkey"), + db.Index("idx_dataset_permissions_dataset_id", "dataset_id"), + db.Index("idx_dataset_permissions_account_id", "account_id"), + db.Index("idx_dataset_permissions_tenant_id", "tenant_id"), + ) + + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"), primary_key=True) + dataset_id = db.Column(StringUUID, nullable=False) + account_id = db.Column(StringUUID, nullable=False) + tenant_id = db.Column(StringUUID, nullable=False) + has_permission = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + + +class ExternalKnowledgeApis(db.Model): # type: ignore[name-defined] + __tablename__ = "external_knowledge_apis" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="external_knowledge_apis_pkey"), + db.Index("external_knowledge_apis_tenant_idx", "tenant_id"), + db.Index("external_knowledge_apis_name_idx", "name"), + ) + + id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) + name = db.Column(db.String(255), nullable=False) + description = db.Column(db.String(255), nullable=False) + tenant_id = db.Column(StringUUID, nullable=False) + settings = db.Column(db.Text, nullable=True) + created_by = db.Column(StringUUID, nullable=False) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_by = db.Column(StringUUID, nullable=True) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + + def to_dict(self): + return { + "id": 
self.id, + "tenant_id": self.tenant_id, + "name": self.name, + "description": self.description, + "settings": self.settings_dict, + "dataset_bindings": self.dataset_bindings, + "created_by": self.created_by, + "created_at": self.created_at.isoformat(), + } + + @property + def settings_dict(self): + try: + return json.loads(self.settings) if self.settings else None + except JSONDecodeError: + return None + + @property + def dataset_bindings(self): + external_knowledge_bindings = ( + db.session.query(ExternalKnowledgeBindings) + .filter(ExternalKnowledgeBindings.external_knowledge_api_id == self.id) + .all() + ) + dataset_ids = [binding.dataset_id for binding in external_knowledge_bindings] + datasets = db.session.query(Dataset).filter(Dataset.id.in_(dataset_ids)).all() + dataset_bindings = [] + for dataset in datasets: + dataset_bindings.append({"id": dataset.id, "name": dataset.name}) + + return dataset_bindings + + +class ExternalKnowledgeBindings(db.Model): # type: ignore[name-defined] + __tablename__ = "external_knowledge_bindings" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="external_knowledge_bindings_pkey"), + db.Index("external_knowledge_bindings_tenant_idx", "tenant_id"), + db.Index("external_knowledge_bindings_dataset_idx", "dataset_id"), + db.Index("external_knowledge_bindings_external_knowledge_idx", "external_knowledge_id"), + db.Index("external_knowledge_bindings_external_knowledge_api_idx", "external_knowledge_api_id"), + ) + + id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) + tenant_id = db.Column(StringUUID, nullable=False) + external_knowledge_api_id = db.Column(StringUUID, nullable=False) + dataset_id = db.Column(StringUUID, nullable=False) + external_knowledge_id = db.Column(db.Text, nullable=False) + created_by = db.Column(StringUUID, nullable=False) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_by = db.Column(StringUUID, 
nullable=True) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + + +class DatasetAutoDisableLog(db.Model): # type: ignore[name-defined] + __tablename__ = "dataset_auto_disable_logs" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="dataset_auto_disable_log_pkey"), + db.Index("dataset_auto_disable_log_tenant_idx", "tenant_id"), + db.Index("dataset_auto_disable_log_dataset_idx", "dataset_id"), + db.Index("dataset_auto_disable_log_created_atx", "created_at"), + ) + + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + tenant_id = db.Column(StringUUID, nullable=False) + dataset_id = db.Column(StringUUID, nullable=False) + document_id = db.Column(StringUUID, nullable=False) + notified = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) diff --git a/api/models/engine.py b/api/models/engine.py new file mode 100644 index 0000000000000000000000000000000000000000..dda93bc9415cfcc890fdede9f6c7b476e8e4fbb9 --- /dev/null +++ b/api/models/engine.py @@ -0,0 +1,13 @@ +from flask_sqlalchemy import SQLAlchemy +from sqlalchemy import MetaData + +POSTGRES_INDEXES_NAMING_CONVENTION = { + "ix": "%(column_0_label)s_idx", + "uq": "%(table_name)s_%(column_0_name)s_key", + "ck": "%(table_name)s_%(constraint_name)s_check", + "fk": "%(table_name)s_%(column_0_name)s_fkey", + "pk": "%(table_name)s_pkey", +} + +metadata = MetaData(naming_convention=POSTGRES_INDEXES_NAMING_CONVENTION) +db = SQLAlchemy(metadata=metadata) diff --git a/api/models/enums.py b/api/models/enums.py new file mode 100644 index 0000000000000000000000000000000000000000..7b9500ebe4d36a5d4540ab28db13c5d942048311 --- /dev/null +++ b/api/models/enums.py @@ -0,0 +1,16 @@ +from enum import StrEnum + + +class CreatedByRole(StrEnum): + ACCOUNT = "account" + END_USER = "end_user" + + +class UserFrom(StrEnum): + ACCOUNT = "account" + 
END_USER = "end-user" + + +class WorkflowRunTriggeredFrom(StrEnum): + DEBUGGING = "debugging" + APP_RUN = "app-run" diff --git a/api/models/model.py b/api/models/model.py new file mode 100644 index 0000000000000000000000000000000000000000..2780b79c98e2b6791881403741607a7f06b387c8 --- /dev/null +++ b/api/models/model.py @@ -0,0 +1,1720 @@ +import json +import re +import uuid +from collections.abc import Mapping +from datetime import datetime +from enum import Enum, StrEnum +from typing import TYPE_CHECKING, Any, Literal, Optional, cast + +import sqlalchemy as sa +from flask import request +from flask_login import UserMixin # type: ignore +from sqlalchemy import Float, func, text +from sqlalchemy.orm import Mapped, mapped_column + +from configs import dify_config +from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType +from core.file import helpers as file_helpers +from core.file.tool_file_parser import ToolFileParser +from libs.helper import generate_string +from models.enums import CreatedByRole +from models.workflow import WorkflowRunStatus + +from .account import Account, Tenant +from .engine import db +from .types import StringUUID + +if TYPE_CHECKING: + from .workflow import Workflow + + +class DifySetup(db.Model): # type: ignore[name-defined] + __tablename__ = "dify_setups" + __table_args__ = (db.PrimaryKeyConstraint("version", name="dify_setup_pkey"),) + + version = db.Column(db.String(255), nullable=False) + setup_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + + +class AppMode(StrEnum): + COMPLETION = "completion" + WORKFLOW = "workflow" + CHAT = "chat" + ADVANCED_CHAT = "advanced-chat" + AGENT_CHAT = "agent-chat" + CHANNEL = "channel" + + @classmethod + def value_of(cls, value: str) -> "AppMode": + """ + Get value of given mode. 
+ + :param value: mode value + :return: mode + """ + for mode in cls: + if mode.value == value: + return mode + raise ValueError(f"invalid mode value {value}") + + +class IconType(Enum): + IMAGE = "image" + EMOJI = "emoji" + + +class App(db.Model): # type: ignore[name-defined] + __tablename__ = "apps" + __table_args__ = (db.PrimaryKeyConstraint("id", name="app_pkey"), db.Index("app_tenant_id_idx", "tenant_id")) + + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + tenant_id: Mapped[str] = db.Column(StringUUID, nullable=False) + name = db.Column(db.String(255), nullable=False) + description = db.Column(db.Text, nullable=False, server_default=db.text("''::character varying")) + mode = db.Column(db.String(255), nullable=False) + icon_type = db.Column(db.String(255), nullable=True) # image, emoji + icon = db.Column(db.String(255)) + icon_background = db.Column(db.String(255)) + app_model_config_id = db.Column(StringUUID, nullable=True) + workflow_id = db.Column(StringUUID, nullable=True) + status = db.Column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying")) + enable_site = db.Column(db.Boolean, nullable=False) + enable_api = db.Column(db.Boolean, nullable=False) + api_rpm = db.Column(db.Integer, nullable=False, server_default=db.text("0")) + api_rph = db.Column(db.Integer, nullable=False, server_default=db.text("0")) + is_demo = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) + is_public = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) + is_universal = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) + tracing = db.Column(db.Text, nullable=True) + max_active_requests: Mapped[Optional[int]] = mapped_column(nullable=True) + created_by = db.Column(StringUUID, nullable=True) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_by = db.Column(StringUUID, nullable=True) + updated_at = 
db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + use_icon_as_answer_icon = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) + + @property + def desc_or_prompt(self): + if self.description: + return self.description + else: + app_model_config = self.app_model_config + if app_model_config: + return app_model_config.pre_prompt + else: + return "" + + @property + def site(self): + site = db.session.query(Site).filter(Site.app_id == self.id).first() + return site + + @property + def app_model_config(self): + if self.app_model_config_id: + return db.session.query(AppModelConfig).filter(AppModelConfig.id == self.app_model_config_id).first() + + return None + + @property + def workflow(self) -> Optional["Workflow"]: + if self.workflow_id: + from .workflow import Workflow + + return db.session.query(Workflow).filter(Workflow.id == self.workflow_id).first() + + return None + + @property + def api_base_url(self): + return (dify_config.SERVICE_API_URL or request.host_url.rstrip("/")) + "/v1" + + @property + def tenant(self): + tenant = db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first() + return tenant + + @property + def is_agent(self) -> bool: + """ + Whether agent mode is enabled with a "function_call" or "react" strategy. + + Side effect: on a match, mode is set to agent-chat and the session is committed. + """ + app_model_config = self.app_model_config + if not app_model_config: + return False + if not app_model_config.agent_mode: + return False + # Reuse the row loaded above: each self.app_model_config access runs another query, + # and each agent_mode_dict access re-parses the stored JSON. + agent_mode = app_model_config.agent_mode_dict + if agent_mode.get("enabled", False) and agent_mode.get("strategy", "") in {"function_call", "react"}: + self.mode = AppMode.AGENT_CHAT.value + db.session.commit() + return True + return False + + @property + def mode_compatible_with_agent(self) -> str: + if self.mode == AppMode.CHAT.value and self.is_agent: + return AppMode.AGENT_CHAT.value + + return str(self.mode) + + @property + def deleted_tools(self) -> list: + # get agent mode tools + app_model_config = self.app_model_config + if not app_model_config: + return [] + if not app_model_config.agent_mode: + return []
+ agent_mode = app_model_config.agent_mode_dict + tools = agent_mode.get("tools", []) + + provider_ids = [] + + for tool in tools: + keys = list(tool.keys()) + if len(keys) >= 4: + provider_type = tool.get("provider_type", "") + provider_id = tool.get("provider_id", "") + if provider_type == "api": + # check if provider id is a uuid string, if not, skip + try: + uuid.UUID(provider_id) + except Exception: + continue + provider_ids.append(provider_id) + + if not provider_ids: + return [] + + api_providers = db.session.execute( + text("SELECT id FROM tool_api_providers WHERE id IN :provider_ids"), {"provider_ids": tuple(provider_ids)} + ).fetchall() + + deleted_tools = [] + current_api_provider_ids = [str(api_provider.id) for api_provider in api_providers] + + for tool in tools: + keys = list(tool.keys()) + if len(keys) >= 4: + provider_type = tool.get("provider_type", "") + provider_id = tool.get("provider_id", "") + if provider_type == "api" and provider_id not in current_api_provider_ids: + deleted_tools.append(tool["tool_name"]) + + return deleted_tools + + @property + def tags(self): + tags = ( + db.session.query(Tag) + .join(TagBinding, Tag.id == TagBinding.tag_id) + .filter( + TagBinding.target_id == self.id, + TagBinding.tenant_id == self.tenant_id, + Tag.tenant_id == self.tenant_id, + Tag.type == "app", + ) + .all() + ) + + return tags or [] + + +class AppModelConfig(db.Model): # type: ignore[name-defined] + __tablename__ = "app_model_configs" + __table_args__ = (db.PrimaryKeyConstraint("id", name="app_model_config_pkey"), db.Index("app_app_id_idx", "app_id")) + + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + app_id = db.Column(StringUUID, nullable=False) + provider = db.Column(db.String(255), nullable=True) + model_id = db.Column(db.String(255), nullable=True) + configs = db.Column(db.JSON, nullable=True) + created_by = db.Column(StringUUID, nullable=True) + created_at = db.Column(db.DateTime, nullable=False, 
server_default=func.current_timestamp()) + updated_by = db.Column(StringUUID, nullable=True) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + opening_statement = db.Column(db.Text) + suggested_questions = db.Column(db.Text) + suggested_questions_after_answer = db.Column(db.Text) + speech_to_text = db.Column(db.Text) + text_to_speech = db.Column(db.Text) + more_like_this = db.Column(db.Text) + model = db.Column(db.Text) + user_input_form = db.Column(db.Text) + dataset_query_variable = db.Column(db.String(255)) + pre_prompt = db.Column(db.Text) + agent_mode = db.Column(db.Text) + sensitive_word_avoidance = db.Column(db.Text) + retriever_resource = db.Column(db.Text) + prompt_type = db.Column(db.String(255), nullable=False, server_default=db.text("'simple'::character varying")) + chat_prompt_config = db.Column(db.Text) + completion_prompt_config = db.Column(db.Text) + dataset_configs = db.Column(db.Text) + external_data_tools = db.Column(db.Text) + file_upload = db.Column(db.Text) + + @property + def app(self): + app = db.session.query(App).filter(App.id == self.app_id).first() + return app + + @property + def model_dict(self) -> dict: + return json.loads(self.model) if self.model else {} + + @property + def suggested_questions_list(self) -> list: + return json.loads(self.suggested_questions) if self.suggested_questions else [] + + @property + def suggested_questions_after_answer_dict(self) -> dict: + return ( + json.loads(self.suggested_questions_after_answer) + if self.suggested_questions_after_answer + else {"enabled": False} + ) + + @property + def speech_to_text_dict(self) -> dict: + return json.loads(self.speech_to_text) if self.speech_to_text else {"enabled": False} + + @property + def text_to_speech_dict(self) -> dict: + return json.loads(self.text_to_speech) if self.text_to_speech else {"enabled": False} + + @property + def retriever_resource_dict(self) -> dict: + return json.loads(self.retriever_resource) if 
self.retriever_resource else {"enabled": True} + + @property + def annotation_reply_dict(self) -> dict: + annotation_setting = ( + db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == self.app_id).first() + ) + if annotation_setting: + collection_binding_detail = annotation_setting.collection_binding_detail + return { + "id": annotation_setting.id, + "enabled": True, + "score_threshold": annotation_setting.score_threshold, + "embedding_model": { + "embedding_provider_name": collection_binding_detail.provider_name, + "embedding_model_name": collection_binding_detail.model_name, + }, + } + + else: + return {"enabled": False} + + @property + def more_like_this_dict(self) -> dict: + return json.loads(self.more_like_this) if self.more_like_this else {"enabled": False} + + @property + def sensitive_word_avoidance_dict(self) -> dict: + return ( + json.loads(self.sensitive_word_avoidance) + if self.sensitive_word_avoidance + else {"enabled": False, "type": "", "configs": []} + ) + + @property + def external_data_tools_list(self) -> list[dict]: + return json.loads(self.external_data_tools) if self.external_data_tools else [] + + @property + def user_input_form_list(self) -> list[dict]: + return json.loads(self.user_input_form) if self.user_input_form else [] + + @property + def agent_mode_dict(self) -> dict: + return ( + json.loads(self.agent_mode) + if self.agent_mode + else {"enabled": False, "strategy": None, "tools": [], "prompt": None} + ) + + @property + def chat_prompt_config_dict(self) -> dict: + return json.loads(self.chat_prompt_config) if self.chat_prompt_config else {} + + @property + def completion_prompt_config_dict(self) -> dict: + return json.loads(self.completion_prompt_config) if self.completion_prompt_config else {} + + @property + def dataset_configs_dict(self) -> dict: + if self.dataset_configs: + dataset_configs: dict = json.loads(self.dataset_configs) + if "retrieval_model" not in dataset_configs: + return {"retrieval_model": 
"single"} + else: + return dataset_configs + return { + "retrieval_model": "multiple", + } + + @property + def file_upload_dict(self) -> dict: + return ( + json.loads(self.file_upload) + if self.file_upload + else { + "image": { + "enabled": False, + "number_limits": 3, + "detail": "high", + "transfer_methods": ["remote_url", "local_file"], + } + } + ) + + def to_dict(self) -> dict: + return { + "opening_statement": self.opening_statement, + "suggested_questions": self.suggested_questions_list, + "suggested_questions_after_answer": self.suggested_questions_after_answer_dict, + "speech_to_text": self.speech_to_text_dict, + "text_to_speech": self.text_to_speech_dict, + "retriever_resource": self.retriever_resource_dict, + "annotation_reply": self.annotation_reply_dict, + "more_like_this": self.more_like_this_dict, + "sensitive_word_avoidance": self.sensitive_word_avoidance_dict, + "external_data_tools": self.external_data_tools_list, + "model": self.model_dict, + "user_input_form": self.user_input_form_list, + "dataset_query_variable": self.dataset_query_variable, + "pre_prompt": self.pre_prompt, + "agent_mode": self.agent_mode_dict, + "prompt_type": self.prompt_type, + "chat_prompt_config": self.chat_prompt_config_dict, + "completion_prompt_config": self.completion_prompt_config_dict, + "dataset_configs": self.dataset_configs_dict, + "file_upload": self.file_upload_dict, + } + + def from_model_config_dict(self, model_config: Mapping[str, Any]): + self.opening_statement = model_config.get("opening_statement") + self.suggested_questions = ( + json.dumps(model_config["suggested_questions"]) if model_config.get("suggested_questions") else None + ) + self.suggested_questions_after_answer = ( + json.dumps(model_config["suggested_questions_after_answer"]) + if model_config.get("suggested_questions_after_answer") + else None + ) + self.speech_to_text = json.dumps(model_config["speech_to_text"]) if model_config.get("speech_to_text") else None + self.text_to_speech = 
json.dumps(model_config["text_to_speech"]) if model_config.get("text_to_speech") else None + self.more_like_this = json.dumps(model_config["more_like_this"]) if model_config.get("more_like_this") else None + self.sensitive_word_avoidance = ( + json.dumps(model_config["sensitive_word_avoidance"]) + if model_config.get("sensitive_word_avoidance") + else None + ) + self.external_data_tools = ( + json.dumps(model_config["external_data_tools"]) if model_config.get("external_data_tools") else None + ) + self.model = json.dumps(model_config["model"]) if model_config.get("model") else None + self.user_input_form = ( + json.dumps(model_config["user_input_form"]) if model_config.get("user_input_form") else None + ) + self.dataset_query_variable = model_config.get("dataset_query_variable") + self.pre_prompt = model_config["pre_prompt"] + self.agent_mode = json.dumps(model_config["agent_mode"]) if model_config.get("agent_mode") else None + self.retriever_resource = ( + json.dumps(model_config["retriever_resource"]) if model_config.get("retriever_resource") else None + ) + self.prompt_type = model_config.get("prompt_type", "simple") + self.chat_prompt_config = ( + json.dumps(model_config.get("chat_prompt_config")) if model_config.get("chat_prompt_config") else None + ) + self.completion_prompt_config = ( + json.dumps(model_config.get("completion_prompt_config")) + if model_config.get("completion_prompt_config") + else None + ) + self.dataset_configs = ( + json.dumps(model_config.get("dataset_configs")) if model_config.get("dataset_configs") else None + ) + self.file_upload = json.dumps(model_config.get("file_upload")) if model_config.get("file_upload") else None + return self + + def copy(self): + new_app_model_config = AppModelConfig( + id=self.id, + app_id=self.app_id, + opening_statement=self.opening_statement, + suggested_questions=self.suggested_questions, + suggested_questions_after_answer=self.suggested_questions_after_answer, + speech_to_text=self.speech_to_text, + 
text_to_speech=self.text_to_speech, + more_like_this=self.more_like_this, + sensitive_word_avoidance=self.sensitive_word_avoidance, + external_data_tools=self.external_data_tools, + model=self.model, + user_input_form=self.user_input_form, + dataset_query_variable=self.dataset_query_variable, + pre_prompt=self.pre_prompt, + agent_mode=self.agent_mode, + retriever_resource=self.retriever_resource, + prompt_type=self.prompt_type, + chat_prompt_config=self.chat_prompt_config, + completion_prompt_config=self.completion_prompt_config, + dataset_configs=self.dataset_configs, + file_upload=self.file_upload, + ) + + return new_app_model_config + + +class RecommendedApp(db.Model): # type: ignore[name-defined] + __tablename__ = "recommended_apps" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="recommended_app_pkey"), + db.Index("recommended_app_app_id_idx", "app_id"), + db.Index("recommended_app_is_listed_idx", "is_listed", "language"), + ) + + id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) + app_id = db.Column(StringUUID, nullable=False) + description = db.Column(db.JSON, nullable=False) + copyright = db.Column(db.String(255), nullable=False) + privacy_policy = db.Column(db.String(255), nullable=False) + custom_disclaimer: Mapped[str] = mapped_column(sa.TEXT, default="") + category = db.Column(db.String(255), nullable=False) + position = db.Column(db.Integer, nullable=False, default=0) + is_listed = db.Column(db.Boolean, nullable=False, default=True) + install_count = db.Column(db.Integer, nullable=False, default=0) + language = db.Column(db.String(255), nullable=False, server_default=db.text("'en-US'::character varying")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + + @property + def app(self): + app = db.session.query(App).filter(App.id == self.app_id).first() + return app 
+ + +class InstalledApp(db.Model): # type: ignore[name-defined] + __tablename__ = "installed_apps" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="installed_app_pkey"), + db.Index("installed_app_tenant_id_idx", "tenant_id"), + db.Index("installed_app_app_id_idx", "app_id"), + db.UniqueConstraint("tenant_id", "app_id", name="unique_tenant_app"), + ) + + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + tenant_id = db.Column(StringUUID, nullable=False) + app_id = db.Column(StringUUID, nullable=False) + app_owner_tenant_id = db.Column(StringUUID, nullable=False) + position = db.Column(db.Integer, nullable=False, default=0) + is_pinned = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) + last_used_at = db.Column(db.DateTime, nullable=True) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + + @property + def app(self): + app = db.session.query(App).filter(App.id == self.app_id).first() + return app + + @property + def tenant(self): + tenant = db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first() + return tenant + + +class Conversation(db.Model): # type: ignore[name-defined] + __tablename__ = "conversations" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="conversation_pkey"), + db.Index("conversation_app_from_user_idx", "app_id", "from_source", "from_end_user_id"), + ) + + id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + app_id = db.Column(StringUUID, nullable=False) + app_model_config_id = db.Column(StringUUID, nullable=True) + model_provider = db.Column(db.String(255), nullable=True) + override_model_configs = db.Column(db.Text) + model_id = db.Column(db.String(255), nullable=True) + mode: Mapped[str] = mapped_column(db.String(255)) + name = db.Column(db.String(255), nullable=False) + summary = db.Column(db.Text) + _inputs: Mapped[dict] = mapped_column("inputs", db.JSON) + introduction = 
db.Column(db.Text) + system_instruction = db.Column(db.Text) + system_instruction_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0")) + status = db.Column(db.String(255), nullable=False) + invoke_from = db.Column(db.String(255), nullable=True) + from_source = db.Column(db.String(255), nullable=False) + from_end_user_id = db.Column(StringUUID) + from_account_id = db.Column(StringUUID) + read_at = db.Column(db.DateTime) + read_account_id = db.Column(StringUUID) + dialogue_count: Mapped[int] = mapped_column(default=0) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + + messages = db.relationship("Message", backref="conversation", lazy="select", passive_deletes="all") + message_annotations = db.relationship( + "MessageAnnotation", backref="conversation", lazy="select", passive_deletes="all" + ) + + is_deleted = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) + + @property + def inputs(self): + inputs = self._inputs.copy() + + # Convert file mapping to File object + for key, value in inputs.items(): + # NOTE: It's not the best way to implement this, but it's the only way to avoid circular import for now. 
+ from factories import file_factory + + if isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY: + if value["transfer_method"] == FileTransferMethod.TOOL_FILE: + value["tool_file_id"] = value["related_id"] + elif value["transfer_method"] == FileTransferMethod.LOCAL_FILE: + value["upload_file_id"] = value["related_id"] + inputs[key] = file_factory.build_from_mapping(mapping=value, tenant_id=value["tenant_id"]) + elif isinstance(value, list) and all( + isinstance(item, dict) and item.get("dify_model_identity") == FILE_MODEL_IDENTITY for item in value + ): + inputs[key] = [] + for item in value: + if item["transfer_method"] == FileTransferMethod.TOOL_FILE: + item["tool_file_id"] = item["related_id"] + elif item["transfer_method"] == FileTransferMethod.LOCAL_FILE: + item["upload_file_id"] = item["related_id"] + inputs[key].append(file_factory.build_from_mapping(mapping=item, tenant_id=item["tenant_id"])) + + return inputs + + @inputs.setter + def inputs(self, value: Mapping[str, Any]): + inputs = dict(value) + for k, v in inputs.items(): + if isinstance(v, File): + inputs[k] = v.model_dump() + elif isinstance(v, list) and all(isinstance(item, File) for item in v): + inputs[k] = [item.model_dump() for item in v] + self._inputs = inputs + + @property + def model_config(self): + model_config = {} + app_model_config: Optional[AppModelConfig] = None + + if self.mode == AppMode.ADVANCED_CHAT.value: + if self.override_model_configs: + override_model_configs = json.loads(self.override_model_configs) + model_config = override_model_configs + else: + if self.override_model_configs: + override_model_configs = json.loads(self.override_model_configs) + + if "model" in override_model_configs: + app_model_config = AppModelConfig() + app_model_config = app_model_config.from_model_config_dict(override_model_configs) + assert app_model_config is not None, "app model config not found" + model_config = app_model_config.to_dict() + else: + 
model_config["configs"] = override_model_configs + else: + app_model_config = ( + db.session.query(AppModelConfig).filter(AppModelConfig.id == self.app_model_config_id).first() + ) + if app_model_config: + model_config = app_model_config.to_dict() + + model_config["model_id"] = self.model_id + model_config["provider"] = self.model_provider + + return model_config + + @property + def summary_or_query(self): + if self.summary: + return self.summary + else: + first_message = self.first_message + if first_message: + return first_message.query + else: + return "" + + @property + def annotated(self): + return db.session.query(MessageAnnotation).filter(MessageAnnotation.conversation_id == self.id).count() > 0 + + @property + def annotation(self): + return db.session.query(MessageAnnotation).filter(MessageAnnotation.conversation_id == self.id).first() + + @property + def message_count(self): + return db.session.query(Message).filter(Message.conversation_id == self.id).count() + + @property + def user_feedback_stats(self): + like = ( + db.session.query(MessageFeedback) + .filter( + MessageFeedback.conversation_id == self.id, + MessageFeedback.from_source == "user", + MessageFeedback.rating == "like", + ) + .count() + ) + + dislike = ( + db.session.query(MessageFeedback) + .filter( + MessageFeedback.conversation_id == self.id, + MessageFeedback.from_source == "user", + MessageFeedback.rating == "dislike", + ) + .count() + ) + + return {"like": like, "dislike": dislike} + + @property + def admin_feedback_stats(self): + like = ( + db.session.query(MessageFeedback) + .filter( + MessageFeedback.conversation_id == self.id, + MessageFeedback.from_source == "admin", + MessageFeedback.rating == "like", + ) + .count() + ) + + dislike = ( + db.session.query(MessageFeedback) + .filter( + MessageFeedback.conversation_id == self.id, + MessageFeedback.from_source == "admin", + MessageFeedback.rating == "dislike", + ) + .count() + ) + + return {"like": like, "dislike": dislike} + + 
@property + def status_count(self): + messages = db.session.query(Message).filter(Message.conversation_id == self.id).all() + status_counts = { + WorkflowRunStatus.RUNNING: 0, + WorkflowRunStatus.SUCCEEDED: 0, + WorkflowRunStatus.FAILED: 0, + WorkflowRunStatus.STOPPED: 0, + WorkflowRunStatus.PARTIAL_SUCCESSED: 0, + } + + for message in messages: + if message.workflow_run: + status_counts[message.workflow_run.status] += 1 + + return ( + { + "success": status_counts[WorkflowRunStatus.SUCCEEDED], + "failed": status_counts[WorkflowRunStatus.FAILED], + "partial_success": status_counts[WorkflowRunStatus.PARTIAL_SUCCESSED], + } + if messages + else None + ) + + @property + def first_message(self): + return db.session.query(Message).filter(Message.conversation_id == self.id).first() + + @property + def app(self): + return db.session.query(App).filter(App.id == self.app_id).first() + + @property + def from_end_user_session_id(self): + if self.from_end_user_id: + end_user = db.session.query(EndUser).filter(EndUser.id == self.from_end_user_id).first() + if end_user: + return end_user.session_id + + return None + + @property + def from_account_name(self): + if self.from_account_id: + account = db.session.query(Account).filter(Account.id == self.from_account_id).first() + if account: + return account.name + + return None + + @property + def in_debug_mode(self): + return self.override_model_configs is not None + + +class Message(db.Model): # type: ignore[name-defined] + __tablename__ = "messages" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="message_pkey"), + db.Index("message_app_id_idx", "app_id", "created_at"), + db.Index("message_conversation_id_idx", "conversation_id"), + db.Index("message_end_user_idx", "app_id", "from_source", "from_end_user_id"), + db.Index("message_account_idx", "app_id", "from_source", "from_account_id"), + db.Index("message_workflow_run_id_idx", "conversation_id", "workflow_run_id"), + db.Index("message_created_at_idx", "created_at"), + ) 
+ + id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + app_id = db.Column(StringUUID, nullable=False) + model_provider = db.Column(db.String(255), nullable=True) + model_id = db.Column(db.String(255), nullable=True) + override_model_configs = db.Column(db.Text) + conversation_id = db.Column(StringUUID, db.ForeignKey("conversations.id"), nullable=False) + _inputs: Mapped[dict] = mapped_column("inputs", db.JSON) + query: Mapped[str] = db.Column(db.Text, nullable=False) + message = db.Column(db.JSON, nullable=False) + message_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0")) + message_unit_price = db.Column(db.Numeric(10, 4), nullable=False) + message_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001")) + answer: Mapped[str] = db.Column(db.Text, nullable=False) + answer_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0")) + answer_unit_price = db.Column(db.Numeric(10, 4), nullable=False) + answer_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001")) + parent_message_id = db.Column(StringUUID, nullable=True) + provider_response_latency = db.Column(db.Float, nullable=False, server_default=db.text("0")) + total_price = db.Column(db.Numeric(10, 7)) + currency = db.Column(db.String(255), nullable=False) + status = db.Column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying")) + error = db.Column(db.Text) + message_metadata = db.Column(db.Text) + invoke_from: Mapped[Optional[str]] = db.Column(db.String(255), nullable=True) + from_source = db.Column(db.String(255), nullable=False) + from_end_user_id: Mapped[Optional[str]] = db.Column(StringUUID) + from_account_id: Mapped[Optional[str]] = db.Column(StringUUID) + created_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, 
server_default=func.current_timestamp()) + agent_based = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) + workflow_run_id = db.Column(StringUUID) + + @property + def inputs(self): + inputs = self._inputs.copy() + for key, value in inputs.items(): + # NOTE: It's not the best way to implement this, but it's the only way to avoid circular import for now. + from factories import file_factory + + if isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY: + if value["transfer_method"] == FileTransferMethod.TOOL_FILE: + value["tool_file_id"] = value["related_id"] + elif value["transfer_method"] == FileTransferMethod.LOCAL_FILE: + value["upload_file_id"] = value["related_id"] + inputs[key] = file_factory.build_from_mapping(mapping=value, tenant_id=value["tenant_id"]) + elif isinstance(value, list) and all( + isinstance(item, dict) and item.get("dify_model_identity") == FILE_MODEL_IDENTITY for item in value + ): + inputs[key] = [] + for item in value: + if item["transfer_method"] == FileTransferMethod.TOOL_FILE: + item["tool_file_id"] = item["related_id"] + elif item["transfer_method"] == FileTransferMethod.LOCAL_FILE: + item["upload_file_id"] = item["related_id"] + inputs[key].append(file_factory.build_from_mapping(mapping=item, tenant_id=item["tenant_id"])) + return inputs + + @inputs.setter + def inputs(self, value: Mapping[str, Any]): + inputs = dict(value) + for k, v in inputs.items(): + if isinstance(v, File): + inputs[k] = v.model_dump() + elif isinstance(v, list) and all(isinstance(item, File) for item in v): + inputs[k] = [item.model_dump() for item in v] + self._inputs = inputs + + @property + def re_sign_file_url_answer(self) -> str: + if not self.answer: + return self.answer + + pattern = r"\[!?.*?\]\((((http|https):\/\/.+)?\/files\/(tools\/)?[\w-]+.*?timestamp=.*&nonce=.*&sign=.*)\)" + matches = re.findall(pattern, self.answer) + + if not matches: + return self.answer + + urls = [match[0] for match in 
matches] + + # remove duplicate urls + urls = list(set(urls)) + + if not urls: + return self.answer + + re_sign_file_url_answer = self.answer + for url in urls: + if "files/tools" in url: + # get tool file id + tool_file_id_pattern = r"\/files\/tools\/([\.\w-]+)?\?timestamp=" + result = re.search(tool_file_id_pattern, url) + if not result: + continue + + tool_file_id = result.group(1) + + # get extension + if "." in tool_file_id: + split_result = tool_file_id.split(".") + extension = f".{split_result[-1]}" + if len(extension) > 10: + extension = ".bin" + tool_file_id = split_result[0] + else: + extension = ".bin" + + if not tool_file_id: + continue + + sign_url = ToolFileParser.get_tool_file_manager().sign_file( + tool_file_id=tool_file_id, extension=extension + ) + elif "file-preview" in url: + # get upload file id + upload_file_id_pattern = r"\/files\/([\w-]+)\/file-preview?\?timestamp=" + result = re.search(upload_file_id_pattern, url) + if not result: + continue + + upload_file_id = result.group(1) + if not upload_file_id: + continue + sign_url = file_helpers.get_signed_file_url(upload_file_id) + elif "image-preview" in url: + # image-preview is deprecated, use file-preview instead + upload_file_id_pattern = r"\/files\/([\w-]+)\/image-preview?\?timestamp=" + result = re.search(upload_file_id_pattern, url) + if not result: + continue + upload_file_id = result.group(1) + if not upload_file_id: + continue + sign_url = file_helpers.get_signed_file_url(upload_file_id) + else: + continue + + re_sign_file_url_answer = re_sign_file_url_answer.replace(url, sign_url) + + return re_sign_file_url_answer + + @property + def user_feedback(self): + feedback = ( + db.session.query(MessageFeedback) + .filter(MessageFeedback.message_id == self.id, MessageFeedback.from_source == "user") + .first() + ) + return feedback + + @property + def admin_feedback(self): + feedback = ( + db.session.query(MessageFeedback) + .filter(MessageFeedback.message_id == self.id, 
MessageFeedback.from_source == "admin") + .first() + ) + return feedback + + @property + def feedbacks(self): + feedbacks = db.session.query(MessageFeedback).filter(MessageFeedback.message_id == self.id).all() + return feedbacks + + @property + def annotation(self): + annotation = db.session.query(MessageAnnotation).filter(MessageAnnotation.message_id == self.id).first() + return annotation + + @property + def annotation_hit_history(self): + annotation_history = ( + db.session.query(AppAnnotationHitHistory).filter(AppAnnotationHitHistory.message_id == self.id).first() + ) + if annotation_history: + annotation = ( + db.session.query(MessageAnnotation) + .filter(MessageAnnotation.id == annotation_history.annotation_id) + .first() + ) + return annotation + return None + + @property + def app_model_config(self): + conversation = db.session.query(Conversation).filter(Conversation.id == self.conversation_id).first() + if conversation: + return ( + db.session.query(AppModelConfig).filter(AppModelConfig.id == conversation.app_model_config_id).first() + ) + + return None + + @property + def in_debug_mode(self): + return self.override_model_configs is not None + + @property + def message_metadata_dict(self) -> dict: + return json.loads(self.message_metadata) if self.message_metadata else {} + + @property + def agent_thoughts(self): + return ( + db.session.query(MessageAgentThought) + .filter(MessageAgentThought.message_id == self.id) + .order_by(MessageAgentThought.position.asc()) + .all() + ) + + @property + def retriever_resources(self): + return ( + db.session.query(DatasetRetrieverResource) + .filter(DatasetRetrieverResource.message_id == self.id) + .order_by(DatasetRetrieverResource.position.asc()) + .all() + ) + + @property + def message_files(self): + from factories import file_factory + + message_files = db.session.query(MessageFile).filter(MessageFile.message_id == self.id).all() + current_app = db.session.query(App).filter(App.id == self.app_id).first() + if not 
current_app: + raise ValueError(f"App {self.app_id} not found") + + files = [] + for message_file in message_files: + if message_file.transfer_method == "local_file": + if message_file.upload_file_id is None: + raise ValueError(f"MessageFile {message_file.id} is a local file but has no upload_file_id") + file = file_factory.build_from_mapping( + mapping={ + "id": message_file.id, + "upload_file_id": message_file.upload_file_id, + "transfer_method": message_file.transfer_method, + "type": message_file.type, + }, + tenant_id=current_app.tenant_id, + ) + elif message_file.transfer_method == "remote_url": + if message_file.url is None: + raise ValueError(f"MessageFile {message_file.id} is a remote url but has no url") + file = file_factory.build_from_mapping( + mapping={ + "id": message_file.id, + "type": message_file.type, + "transfer_method": message_file.transfer_method, + "url": message_file.url, + }, + tenant_id=current_app.tenant_id, + ) + elif message_file.transfer_method == "tool_file": + if message_file.upload_file_id is None: + assert message_file.url is not None + message_file.upload_file_id = message_file.url.split("/")[-1].split(".")[0] + mapping = { + "id": message_file.id, + "type": message_file.type, + "transfer_method": message_file.transfer_method, + "tool_file_id": message_file.upload_file_id, + } + file = file_factory.build_from_mapping( + mapping=mapping, + tenant_id=current_app.tenant_id, + ) + else: + raise ValueError( + f"MessageFile {message_file.id} has an invalid transfer_method {message_file.transfer_method}" + ) + files.append(file) + + result = [ + {"belongs_to": message_file.belongs_to, **file.to_dict()} + for (file, message_file) in zip(files, message_files) + ] + + db.session.commit() + return result + + @property + def workflow_run(self): + if self.workflow_run_id: + from .workflow import WorkflowRun + + return db.session.query(WorkflowRun).filter(WorkflowRun.id == self.workflow_run_id).first() + + return None + + def to_dict(self) -> 
dict: + return { + "id": self.id, + "app_id": self.app_id, + "conversation_id": self.conversation_id, + "model_id": self.model_id, + "inputs": self.inputs, + "query": self.query, + "total_price": self.total_price, + "message": self.message, + "answer": self.answer, + "status": self.status, + "error": self.error, + "message_metadata": self.message_metadata_dict, + "from_source": self.from_source, + "from_end_user_id": self.from_end_user_id, + "from_account_id": self.from_account_id, + "created_at": self.created_at.isoformat(), + "updated_at": self.updated_at.isoformat(), + "agent_based": self.agent_based, + "workflow_run_id": self.workflow_run_id, + } + + @classmethod + def from_dict(cls, data: dict): + return cls( + id=data["id"], + app_id=data["app_id"], + conversation_id=data["conversation_id"], + model_id=data["model_id"], + inputs=data["inputs"], + total_price=data["total_price"], + query=data["query"], + message=data["message"], + answer=data["answer"], + status=data["status"], + error=data["error"], + message_metadata=json.dumps(data["message_metadata"]), + from_source=data["from_source"], + from_end_user_id=data["from_end_user_id"], + from_account_id=data["from_account_id"], + created_at=data["created_at"], + updated_at=data["updated_at"], + agent_based=data["agent_based"], + workflow_run_id=data["workflow_run_id"], + ) + + +class MessageFeedback(db.Model): # type: ignore[name-defined] + __tablename__ = "message_feedbacks" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="message_feedback_pkey"), + db.Index("message_feedback_app_idx", "app_id"), + db.Index("message_feedback_message_idx", "message_id", "from_source"), + db.Index("message_feedback_conversation_idx", "conversation_id", "from_source", "rating"), + ) + + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + app_id = db.Column(StringUUID, nullable=False) + conversation_id = db.Column(StringUUID, nullable=False) + message_id = db.Column(StringUUID, nullable=False) + 
rating = db.Column(db.String(255), nullable=False) + content = db.Column(db.Text) + from_source = db.Column(db.String(255), nullable=False) + from_end_user_id = db.Column(StringUUID) + from_account_id = db.Column(StringUUID) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + + @property + def from_account(self): + account = db.session.query(Account).filter(Account.id == self.from_account_id).first() + return account + + +class MessageFile(db.Model): # type: ignore[name-defined] + __tablename__ = "message_files" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="message_file_pkey"), + db.Index("message_file_message_idx", "message_id"), + db.Index("message_file_created_by_idx", "created_by"), + ) + + def __init__( + self, + *, + message_id: str, + type: FileType, + transfer_method: FileTransferMethod, + url: str | None = None, + belongs_to: Literal["user", "assistant"] | None = None, + upload_file_id: str | None = None, + created_by_role: CreatedByRole, + created_by: str, + ): + self.message_id = message_id + self.type = type + self.transfer_method = transfer_method + self.url = url + self.belongs_to = belongs_to + self.upload_file_id = upload_file_id + self.created_by_role = created_by_role.value + self.created_by = created_by + + id: Mapped[str] = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + message_id: Mapped[str] = db.Column(StringUUID, nullable=False) + type: Mapped[str] = db.Column(db.String(255), nullable=False) + transfer_method: Mapped[str] = db.Column(db.String(255), nullable=False) + url: Mapped[Optional[str]] = db.Column(db.Text, nullable=True) + belongs_to: Mapped[Optional[str]] = db.Column(db.String(255), nullable=True) + upload_file_id: Mapped[Optional[str]] = db.Column(StringUUID, nullable=True) + created_by_role: Mapped[str] = db.Column(db.String(255), nullable=False) + 
created_by: Mapped[str] = db.Column(StringUUID, nullable=False) + created_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + + +class MessageAnnotation(db.Model): # type: ignore[name-defined] + __tablename__ = "message_annotations" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="message_annotation_pkey"), + db.Index("message_annotation_app_idx", "app_id"), + db.Index("message_annotation_conversation_idx", "conversation_id"), + db.Index("message_annotation_message_idx", "message_id"), + ) + + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + app_id = db.Column(StringUUID, nullable=False) + conversation_id = db.Column(StringUUID, db.ForeignKey("conversations.id"), nullable=True) + message_id = db.Column(StringUUID, nullable=True) + question = db.Column(db.Text, nullable=True) + content = db.Column(db.Text, nullable=False) + hit_count = db.Column(db.Integer, nullable=False, server_default=db.text("0")) + account_id = db.Column(StringUUID, nullable=False) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + + @property + def account(self): + account = db.session.query(Account).filter(Account.id == self.account_id).first() + return account + + @property + def annotation_create_account(self): + account = db.session.query(Account).filter(Account.id == self.account_id).first() + return account + + +class AppAnnotationHitHistory(db.Model): # type: ignore[name-defined] + __tablename__ = "app_annotation_hit_histories" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="app_annotation_hit_histories_pkey"), + db.Index("app_annotation_hit_histories_app_idx", "app_id"), + db.Index("app_annotation_hit_histories_account_idx", "account_id"), + db.Index("app_annotation_hit_histories_annotation_idx", "annotation_id"), + 
db.Index("app_annotation_hit_histories_message_idx", "message_id"), + ) + + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + app_id = db.Column(StringUUID, nullable=False) + annotation_id = db.Column(StringUUID, nullable=False) + source = db.Column(db.Text, nullable=False) + question = db.Column(db.Text, nullable=False) + account_id = db.Column(StringUUID, nullable=False) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + score = db.Column(Float, nullable=False, server_default=db.text("0")) + message_id = db.Column(StringUUID, nullable=False) + annotation_question = db.Column(db.Text, nullable=False) + annotation_content = db.Column(db.Text, nullable=False) + + @property + def account(self): + account = ( + db.session.query(Account) + .join(MessageAnnotation, MessageAnnotation.account_id == Account.id) + .filter(MessageAnnotation.id == self.annotation_id) + .first() + ) + return account + + @property + def annotation_create_account(self): + account = db.session.query(Account).filter(Account.id == self.account_id).first() + return account + + +class AppAnnotationSetting(db.Model): # type: ignore[name-defined] + __tablename__ = "app_annotation_settings" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="app_annotation_settings_pkey"), + db.Index("app_annotation_settings_app_idx", "app_id"), + ) + + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + app_id = db.Column(StringUUID, nullable=False) + score_threshold = db.Column(Float, nullable=False, server_default=db.text("0")) + collection_binding_id = db.Column(StringUUID, nullable=False) + created_user_id = db.Column(StringUUID, nullable=False) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_user_id = db.Column(StringUUID, nullable=False) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + + @property + def 
created_account(self): + account = ( + db.session.query(Account) + .join(AppAnnotationSetting, AppAnnotationSetting.created_user_id == Account.id) + .filter(AppAnnotationSetting.id == self.annotation_id) + .first() + ) + return account + + @property + def updated_account(self): + account = ( + db.session.query(Account) + .join(AppAnnotationSetting, AppAnnotationSetting.updated_user_id == Account.id) + .filter(AppAnnotationSetting.id == self.annotation_id) + .first() + ) + return account + + @property + def collection_binding_detail(self): + from .dataset import DatasetCollectionBinding + + collection_binding_detail = ( + db.session.query(DatasetCollectionBinding) + .filter(DatasetCollectionBinding.id == self.collection_binding_id) + .first() + ) + return collection_binding_detail + + +class OperationLog(db.Model): # type: ignore[name-defined] + __tablename__ = "operation_logs" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="operation_log_pkey"), + db.Index("operation_log_account_action_idx", "tenant_id", "account_id", "action"), + ) + + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + tenant_id = db.Column(StringUUID, nullable=False) + account_id = db.Column(StringUUID, nullable=False) + action = db.Column(db.String(255), nullable=False) + content = db.Column(db.JSON) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_ip = db.Column(db.String(255), nullable=False) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + + +class EndUser(UserMixin, db.Model): # type: ignore[name-defined] + __tablename__ = "end_users" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="end_user_pkey"), + db.Index("end_user_session_id_idx", "session_id", "type"), + db.Index("end_user_tenant_session_id_idx", "tenant_id", "session_id", "type"), + ) + + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + tenant_id = 
db.Column(StringUUID, nullable=False) + app_id = db.Column(StringUUID, nullable=True) + type = db.Column(db.String(255), nullable=False) + external_user_id = db.Column(db.String(255), nullable=True) + name = db.Column(db.String(255)) + is_anonymous = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) + session_id: Mapped[str] = mapped_column() + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + + +class Site(db.Model): # type: ignore[name-defined] + __tablename__ = "sites" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="site_pkey"), + db.Index("site_app_id_idx", "app_id"), + db.Index("site_code_idx", "code", "status"), + ) + + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + app_id = db.Column(StringUUID, nullable=False) + title = db.Column(db.String(255), nullable=False) + icon_type = db.Column(db.String(255), nullable=True) + icon = db.Column(db.String(255)) + icon_background = db.Column(db.String(255)) + description = db.Column(db.Text) + default_language = db.Column(db.String(255), nullable=False) + chat_color_theme = db.Column(db.String(255)) + chat_color_theme_inverted = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) + copyright = db.Column(db.String(255)) + privacy_policy = db.Column(db.String(255)) + show_workflow_steps = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) + use_icon_as_answer_icon = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) + _custom_disclaimer: Mapped[str] = mapped_column("custom_disclaimer", sa.TEXT, default="") + customize_domain = db.Column(db.String(255)) + customize_token_strategy = db.Column(db.String(255), nullable=False) + prompt_public = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) + status = db.Column(db.String(255), nullable=False, 
server_default=db.text("'normal'::character varying")) + created_by = db.Column(StringUUID, nullable=True) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_by = db.Column(StringUUID, nullable=True) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + code = db.Column(db.String(255)) + + @property + def custom_disclaimer(self): + return self._custom_disclaimer + + @custom_disclaimer.setter + def custom_disclaimer(self, value: str): + if len(value) > 512: + raise ValueError("Custom disclaimer cannot exceed 512 characters.") + self._custom_disclaimer = value + + @staticmethod + def generate_code(n): + while True: + result = generate_string(n) + while db.session.query(Site).filter(Site.code == result).count() > 0: + result = generate_string(n) + + return result + + @property + def app_base_url(self): + return dify_config.APP_WEB_URL or request.url_root.rstrip("/") + + +class ApiToken(db.Model): # type: ignore[name-defined] + __tablename__ = "api_tokens" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="api_token_pkey"), + db.Index("api_token_app_id_type_idx", "app_id", "type"), + db.Index("api_token_token_idx", "token", "type"), + db.Index("api_token_tenant_idx", "tenant_id", "type"), + ) + + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + app_id = db.Column(StringUUID, nullable=True) + tenant_id = db.Column(StringUUID, nullable=True) + type = db.Column(db.String(16), nullable=False) + token = db.Column(db.String(255), nullable=False) + last_used_at = db.Column(db.DateTime, nullable=True) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + + @staticmethod + def generate_api_key(prefix, n): + while True: + result = prefix + generate_string(n) + if db.session.query(ApiToken).filter(ApiToken.token == result).count() > 0: + continue + return result + + +class UploadFile(db.Model): # type: 
ignore[name-defined] + __tablename__ = "upload_files" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="upload_file_pkey"), + db.Index("upload_file_tenant_idx", "tenant_id"), + ) + + id: Mapped[str] = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + tenant_id: Mapped[str] = db.Column(StringUUID, nullable=False) + storage_type: Mapped[str] = db.Column(db.String(255), nullable=False) + key: Mapped[str] = db.Column(db.String(255), nullable=False) + name: Mapped[str] = db.Column(db.String(255), nullable=False) + size: Mapped[int] = db.Column(db.Integer, nullable=False) + extension: Mapped[str] = db.Column(db.String(255), nullable=False) + mime_type: Mapped[str] = db.Column(db.String(255), nullable=True) + created_by_role: Mapped[str] = db.Column( + db.String(255), nullable=False, server_default=db.text("'account'::character varying") + ) + created_by: Mapped[str] = db.Column(StringUUID, nullable=False) + created_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + used: Mapped[bool] = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) + used_by: Mapped[str | None] = db.Column(StringUUID, nullable=True) + used_at: Mapped[datetime | None] = db.Column(db.DateTime, nullable=True) + hash: Mapped[str | None] = db.Column(db.String(255), nullable=True) + source_url: Mapped[str] = mapped_column(sa.TEXT, default="") + + def __init__( + self, + *, + tenant_id: str, + storage_type: str, + key: str, + name: str, + size: int, + extension: str, + mime_type: str, + created_by_role: CreatedByRole, + created_by: str, + created_at: datetime, + used: bool, + used_by: str | None = None, + used_at: datetime | None = None, + hash: str | None = None, + source_url: str = "", + ): + self.tenant_id = tenant_id + self.storage_type = storage_type + self.key = key + self.name = name + self.size = size + self.extension = extension + self.mime_type = mime_type + self.created_by_role = 
created_by_role.value + self.created_by = created_by + self.created_at = created_at + self.used = used + self.used_by = used_by + self.used_at = used_at + self.hash = hash + self.source_url = source_url + + +class ApiRequest(db.Model): # type: ignore[name-defined] + __tablename__ = "api_requests" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="api_request_pkey"), + db.Index("api_request_token_idx", "tenant_id", "api_token_id"), + ) + + id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) + tenant_id = db.Column(StringUUID, nullable=False) + api_token_id = db.Column(StringUUID, nullable=False) + path = db.Column(db.String(255), nullable=False) + request = db.Column(db.Text, nullable=True) + response = db.Column(db.Text, nullable=True) + ip = db.Column(db.String(255), nullable=False) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + + +class MessageChain(db.Model): # type: ignore[name-defined] + __tablename__ = "message_chains" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="message_chain_pkey"), + db.Index("message_chain_message_id_idx", "message_id"), + ) + + id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) + message_id = db.Column(StringUUID, nullable=False) + type = db.Column(db.String(255), nullable=False) + input = db.Column(db.Text, nullable=True) + output = db.Column(db.Text, nullable=True) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp()) + + +class MessageAgentThought(db.Model): # type: ignore[name-defined] + __tablename__ = "message_agent_thoughts" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="message_agent_thought_pkey"), + db.Index("message_agent_thought_message_id_idx", "message_id"), + db.Index("message_agent_thought_message_chain_id_idx", "message_chain_id"), + ) + + id = db.Column(StringUUID, nullable=False, 
server_default=db.text("uuid_generate_v4()")) + message_id = db.Column(StringUUID, nullable=False) + message_chain_id = db.Column(StringUUID, nullable=True) + position = db.Column(db.Integer, nullable=False) + thought = db.Column(db.Text, nullable=True) + tool = db.Column(db.Text, nullable=True) + tool_labels_str = db.Column(db.Text, nullable=False, server_default=db.text("'{}'::text")) + tool_meta_str = db.Column(db.Text, nullable=False, server_default=db.text("'{}'::text")) + tool_input = db.Column(db.Text, nullable=True) + observation = db.Column(db.Text, nullable=True) + # plugin_id = db.Column(StringUUID, nullable=True) ## for future design + tool_process_data = db.Column(db.Text, nullable=True) + message = db.Column(db.Text, nullable=True) + message_token = db.Column(db.Integer, nullable=True) + message_unit_price = db.Column(db.Numeric, nullable=True) + message_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001")) + message_files = db.Column(db.Text, nullable=True) + answer = db.Column(db.Text, nullable=True) + answer_token = db.Column(db.Integer, nullable=True) + answer_unit_price = db.Column(db.Numeric, nullable=True) + answer_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001")) + tokens = db.Column(db.Integer, nullable=True) + total_price = db.Column(db.Numeric, nullable=True) + currency = db.Column(db.String, nullable=True) + latency = db.Column(db.Float, nullable=True) + created_by_role = db.Column(db.String, nullable=False) + created_by = db.Column(StringUUID, nullable=False) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp()) + + @property + def files(self) -> list: + if self.message_files: + return cast(list[Any], json.loads(self.message_files)) + else: + return [] + + @property + def tools(self) -> list[str]: + return self.tool.split(";") if self.tool else [] + + @property + def tool_labels(self) -> dict: + try: + if 
self.tool_labels_str: + return cast(dict, json.loads(self.tool_labels_str)) + else: + return {} + except Exception as e: + return {} + + @property + def tool_meta(self) -> dict: + try: + if self.tool_meta_str: + return cast(dict, json.loads(self.tool_meta_str)) + else: + return {} + except Exception as e: + return {} + + @property + def tool_inputs_dict(self) -> dict: + tools = self.tools + try: + if self.tool_input: + data = json.loads(self.tool_input) + result = {} + for tool in tools: + if tool in data: + result[tool] = data[tool] + else: + if len(tools) == 1: + result[tool] = data + else: + result[tool] = {} + return result + else: + return {tool: {} for tool in tools} + except Exception as e: + return {} + + @property + def tool_outputs_dict(self) -> dict: + tools = self.tools + try: + if self.observation: + data = json.loads(self.observation) + result = {} + for tool in tools: + if tool in data: + result[tool] = data[tool] + else: + if len(tools) == 1: + result[tool] = data + else: + result[tool] = {} + return result + else: + return {tool: {} for tool in tools} + except Exception as e: + if self.observation: + return dict.fromkeys(tools, self.observation) + else: + return {} + + +class DatasetRetrieverResource(db.Model): # type: ignore[name-defined] + __tablename__ = "dataset_retriever_resources" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="dataset_retriever_resource_pkey"), + db.Index("dataset_retriever_resource_message_id_idx", "message_id"), + ) + + id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) + message_id = db.Column(StringUUID, nullable=False) + position = db.Column(db.Integer, nullable=False) + dataset_id = db.Column(StringUUID, nullable=False) + dataset_name = db.Column(db.Text, nullable=False) + document_id = db.Column(StringUUID, nullable=True) + document_name = db.Column(db.Text, nullable=False) + data_source_type = db.Column(db.Text, nullable=True) + segment_id = db.Column(StringUUID, 
nullable=True) + score = db.Column(db.Float, nullable=True) + content = db.Column(db.Text, nullable=False) + hit_count = db.Column(db.Integer, nullable=True) + word_count = db.Column(db.Integer, nullable=True) + segment_position = db.Column(db.Integer, nullable=True) + index_node_hash = db.Column(db.Text, nullable=True) + retriever_from = db.Column(db.Text, nullable=False) + created_by = db.Column(StringUUID, nullable=False) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp()) + + +class Tag(db.Model): # type: ignore[name-defined] + __tablename__ = "tags" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="tag_pkey"), + db.Index("tag_type_idx", "type"), + db.Index("tag_name_idx", "name"), + ) + + TAG_TYPE_LIST = ["knowledge", "app"] + + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + tenant_id = db.Column(StringUUID, nullable=True) + type = db.Column(db.String(16), nullable=False) + name = db.Column(db.String(255), nullable=False) + created_by = db.Column(StringUUID, nullable=False) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + + +class TagBinding(db.Model): # type: ignore[name-defined] + __tablename__ = "tag_bindings" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="tag_binding_pkey"), + db.Index("tag_bind_target_id_idx", "target_id"), + db.Index("tag_bind_tag_id_idx", "tag_id"), + ) + + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + tenant_id = db.Column(StringUUID, nullable=True) + tag_id = db.Column(StringUUID, nullable=True) + target_id = db.Column(StringUUID, nullable=True) + created_by = db.Column(StringUUID, nullable=False) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + + +class TraceAppConfig(db.Model): # type: ignore[name-defined] + __tablename__ = "trace_app_config" + __table_args__ = ( + db.PrimaryKeyConstraint("id", 
name="tracing_app_config_pkey"), + db.Index("trace_app_config_app_id_idx", "app_id"), + ) + + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + app_id = db.Column(StringUUID, nullable=False) + tracing_provider = db.Column(db.String(255), nullable=True) + tracing_config = db.Column(db.JSON, nullable=True) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column( + db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() + ) + is_active = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) + + @property + def tracing_config_dict(self): + return self.tracing_config or {} + + @property + def tracing_config_str(self): + return json.dumps(self.tracing_config_dict) + + def to_dict(self): + return { + "id": self.id, + "app_id": self.app_id, + "tracing_provider": self.tracing_provider, + "tracing_config": self.tracing_config_dict, + "is_active": self.is_active, + "created_at": str(self.created_at) if self.created_at else None, + "updated_at": str(self.updated_at) if self.updated_at else None, + } diff --git a/api/models/provider.py b/api/models/provider.py new file mode 100644 index 0000000000000000000000000000000000000000..abe673975c1ccc0c39174498cb30e047df68c292 --- /dev/null +++ b/api/models/provider.py @@ -0,0 +1,215 @@ +from enum import Enum + +from sqlalchemy import func + +from .engine import db +from .types import StringUUID + + +class ProviderType(Enum): + CUSTOM = "custom" + SYSTEM = "system" + + @staticmethod + def value_of(value): + for member in ProviderType: + if member.value == value: + return member + raise ValueError(f"No matching enum found for value '{value}'") + + +class ProviderQuotaType(Enum): + PAID = "paid" + """hosted paid quota""" + + FREE = "free" + """third-party free quota""" + + TRIAL = "trial" + """hosted trial quota""" + + @staticmethod + def value_of(value): + for member in 
ProviderQuotaType: + if member.value == value: + return member + raise ValueError(f"No matching enum found for value '{value}'") + + +class Provider(db.Model): # type: ignore[name-defined] + """ + Provider model representing the API providers and their configurations. + """ + + __tablename__ = "providers" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="provider_pkey"), + db.Index("provider_tenant_id_provider_idx", "tenant_id", "provider_name"), + db.UniqueConstraint( + "tenant_id", "provider_name", "provider_type", "quota_type", name="unique_provider_name_type_quota" + ), + ) + + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + tenant_id = db.Column(StringUUID, nullable=False) + provider_name = db.Column(db.String(255), nullable=False) + provider_type = db.Column(db.String(40), nullable=False, server_default=db.text("'custom'::character varying")) + encrypted_config = db.Column(db.Text, nullable=True) + is_valid = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) + last_used = db.Column(db.DateTime, nullable=True) + + quota_type = db.Column(db.String(40), nullable=True, server_default=db.text("''::character varying")) + quota_limit = db.Column(db.BigInteger, nullable=True) + quota_used = db.Column(db.BigInteger, default=0) + + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + + def __repr__(self): + return ( + f"" + ) + + @property + def token_is_set(self): + """ + Returns True if the encrypted_config is not None, indicating that the token is set. + """ + return self.encrypted_config is not None + + @property + def is_enabled(self): + """ + Returns True if the provider is enabled. 
+ """ + if self.provider_type == ProviderType.SYSTEM.value: + return self.is_valid + else: + return self.is_valid and self.token_is_set + + +class ProviderModel(db.Model): # type: ignore[name-defined] + """ + Provider model representing the API provider_models and their configurations. + """ + + __tablename__ = "provider_models" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="provider_model_pkey"), + db.Index("provider_model_tenant_id_provider_idx", "tenant_id", "provider_name"), + db.UniqueConstraint( + "tenant_id", "provider_name", "model_name", "model_type", name="unique_provider_model_name" + ), + ) + + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + tenant_id = db.Column(StringUUID, nullable=False) + provider_name = db.Column(db.String(255), nullable=False) + model_name = db.Column(db.String(255), nullable=False) + model_type = db.Column(db.String(40), nullable=False) + encrypted_config = db.Column(db.Text, nullable=True) + is_valid = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + + +class TenantDefaultModel(db.Model): # type: ignore[name-defined] + __tablename__ = "tenant_default_models" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="tenant_default_model_pkey"), + db.Index("tenant_default_model_tenant_id_provider_type_idx", "tenant_id", "provider_name", "model_type"), + ) + + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + tenant_id = db.Column(StringUUID, nullable=False) + provider_name = db.Column(db.String(255), nullable=False) + model_name = db.Column(db.String(255), nullable=False) + model_type = db.Column(db.String(40), nullable=False) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, 
nullable=False, server_default=func.current_timestamp()) + + +class TenantPreferredModelProvider(db.Model): # type: ignore[name-defined] + __tablename__ = "tenant_preferred_model_providers" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="tenant_preferred_model_provider_pkey"), + db.Index("tenant_preferred_model_provider_tenant_provider_idx", "tenant_id", "provider_name"), + ) + + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + tenant_id = db.Column(StringUUID, nullable=False) + provider_name = db.Column(db.String(255), nullable=False) + preferred_provider_type = db.Column(db.String(40), nullable=False) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + + +class ProviderOrder(db.Model): # type: ignore[name-defined] + __tablename__ = "provider_orders" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="provider_order_pkey"), + db.Index("provider_order_tenant_provider_idx", "tenant_id", "provider_name"), + ) + + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + tenant_id = db.Column(StringUUID, nullable=False) + provider_name = db.Column(db.String(255), nullable=False) + account_id = db.Column(StringUUID, nullable=False) + payment_product_id = db.Column(db.String(191), nullable=False) + payment_id = db.Column(db.String(191)) + transaction_id = db.Column(db.String(191)) + quantity = db.Column(db.Integer, nullable=False, server_default=db.text("1")) + currency = db.Column(db.String(40)) + total_amount = db.Column(db.Integer) + payment_status = db.Column(db.String(40), nullable=False, server_default=db.text("'wait_pay'::character varying")) + paid_at = db.Column(db.DateTime) + pay_failed_at = db.Column(db.DateTime) + refunded_at = db.Column(db.DateTime) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = 
db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + + +class ProviderModelSetting(db.Model): # type: ignore[name-defined] + """ + Provider model settings for record the model enabled status and load balancing status. + """ + + __tablename__ = "provider_model_settings" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="provider_model_setting_pkey"), + db.Index("provider_model_setting_tenant_provider_model_idx", "tenant_id", "provider_name", "model_type"), + ) + + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + tenant_id = db.Column(StringUUID, nullable=False) + provider_name = db.Column(db.String(255), nullable=False) + model_name = db.Column(db.String(255), nullable=False) + model_type = db.Column(db.String(40), nullable=False) + enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) + load_balancing_enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + + +class LoadBalancingModelConfig(db.Model): # type: ignore[name-defined] + """ + Configurations for load balancing models. 
+ """ + + __tablename__ = "load_balancing_model_configs" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="load_balancing_model_config_pkey"), + db.Index("load_balancing_model_config_tenant_provider_model_idx", "tenant_id", "provider_name", "model_type"), + ) + + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + tenant_id = db.Column(StringUUID, nullable=False) + provider_name = db.Column(db.String(255), nullable=False) + model_name = db.Column(db.String(255), nullable=False) + model_type = db.Column(db.String(40), nullable=False) + name = db.Column(db.String(255), nullable=False) + encrypted_config = db.Column(db.Text, nullable=True) + enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) diff --git a/api/models/source.py b/api/models/source.py new file mode 100644 index 0000000000000000000000000000000000000000..881cfaac7d399846f509a711447a208dbaf961ab --- /dev/null +++ b/api/models/source.py @@ -0,0 +1,55 @@ +import json + +from sqlalchemy import func +from sqlalchemy.dialects.postgresql import JSONB + +from .engine import db +from .types import StringUUID + + +class DataSourceOauthBinding(db.Model): # type: ignore[name-defined] + __tablename__ = "data_source_oauth_bindings" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="source_binding_pkey"), + db.Index("source_binding_tenant_id_idx", "tenant_id"), + db.Index("source_info_idx", "source_info", postgresql_using="gin"), + ) + + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + tenant_id = db.Column(StringUUID, nullable=False) + access_token = db.Column(db.String(255), nullable=False) + provider = db.Column(db.String(255), nullable=False) + source_info = db.Column(JSONB, nullable=False) + created_at = db.Column(db.DateTime, 
nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + disabled = db.Column(db.Boolean, nullable=True, server_default=db.text("false")) + + +class DataSourceApiKeyAuthBinding(db.Model): # type: ignore[name-defined] + __tablename__ = "data_source_api_key_auth_bindings" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="data_source_api_key_auth_binding_pkey"), + db.Index("data_source_api_key_auth_binding_tenant_id_idx", "tenant_id"), + db.Index("data_source_api_key_auth_binding_provider_idx", "provider"), + ) + + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + tenant_id = db.Column(StringUUID, nullable=False) + category = db.Column(db.String(255), nullable=False) + provider = db.Column(db.String(255), nullable=False) + credentials = db.Column(db.Text, nullable=True) # JSON + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + disabled = db.Column(db.Boolean, nullable=True, server_default=db.text("false")) + + def to_dict(self): + return { + "id": self.id, + "tenant_id": self.tenant_id, + "category": self.category, + "provider": self.provider, + "credentials": json.loads(self.credentials), + "created_at": self.created_at.timestamp(), + "updated_at": self.updated_at.timestamp(), + "disabled": self.disabled, + } diff --git a/api/models/task.py b/api/models/task.py new file mode 100644 index 0000000000000000000000000000000000000000..0db1c632299fcbb0097fdc3fb61cbf528190048a --- /dev/null +++ b/api/models/task.py @@ -0,0 +1,40 @@ +from datetime import UTC, datetime + +from celery import states # type: ignore + +from .engine import db + + +class CeleryTask(db.Model): # type: ignore[name-defined] + """Task result/status.""" + + __tablename__ = "celery_taskmeta" + + id = db.Column(db.Integer, 
db.Sequence("task_id_sequence"), primary_key=True, autoincrement=True) + task_id = db.Column(db.String(155), unique=True) + status = db.Column(db.String(50), default=states.PENDING) + result = db.Column(db.PickleType, nullable=True) + date_done = db.Column( + db.DateTime, + default=lambda: datetime.now(UTC).replace(tzinfo=None), + onupdate=lambda: datetime.now(UTC).replace(tzinfo=None), + nullable=True, + ) + traceback = db.Column(db.Text, nullable=True) + name = db.Column(db.String(155), nullable=True) + args = db.Column(db.LargeBinary, nullable=True) + kwargs = db.Column(db.LargeBinary, nullable=True) + worker = db.Column(db.String(155), nullable=True) + retries = db.Column(db.Integer, nullable=True) + queue = db.Column(db.String(155), nullable=True) + + +class CeleryTaskSet(db.Model): # type: ignore[name-defined] + """TaskSet result.""" + + __tablename__ = "celery_tasksetmeta" + + id = db.Column(db.Integer, db.Sequence("taskset_id_sequence"), autoincrement=True, primary_key=True) + taskset_id = db.Column(db.String(155), unique=True) + result = db.Column(db.PickleType, nullable=True) + date_done = db.Column(db.DateTime, default=lambda: datetime.now(UTC).replace(tzinfo=None), nullable=True) diff --git a/api/models/tools.py b/api/models/tools.py new file mode 100644 index 0000000000000000000000000000000000000000..13a112ee83b5136fcf721b3dc85366806e8e1e11 --- /dev/null +++ b/api/models/tools.py @@ -0,0 +1,325 @@ +import json +from typing import Any, Optional + +import sqlalchemy as sa +from sqlalchemy import ForeignKey, func +from sqlalchemy.orm import Mapped, mapped_column + +from core.tools.entities.common_entities import I18nObject +from core.tools.entities.tool_bundle import ApiToolBundle +from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration + +from .engine import db +from .model import Account, App, Tenant +from .types import StringUUID + + +class BuiltinToolProvider(db.Model): # type: ignore[name-defined] + """ 
+ This table stores the tool provider information for built-in tools for each tenant. + """ + + __tablename__ = "tool_builtin_providers" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="tool_builtin_provider_pkey"), + # one tenant can only have one tool provider with the same name + db.UniqueConstraint("tenant_id", "provider", name="unique_builtin_tool_provider"), + ) + + # id of the tool provider + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + # id of the tenant + tenant_id = db.Column(StringUUID, nullable=True) + # who created this tool provider + user_id = db.Column(StringUUID, nullable=False) + # name of the tool provider + provider = db.Column(db.String(40), nullable=False) + # credential of the tool provider + encrypted_credentials = db.Column(db.Text, nullable=True) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + + @property + def credentials(self) -> dict: + return dict(json.loads(self.encrypted_credentials)) + + +class PublishedAppTool(db.Model): # type: ignore[name-defined] + """ + The table stores the apps published as a tool for each person. 
+ """ + + __tablename__ = "tool_published_apps" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="published_app_tool_pkey"), + db.UniqueConstraint("app_id", "user_id", name="unique_published_app_tool"), + ) + + # id of the tool provider + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + # id of the app + app_id = db.Column(StringUUID, ForeignKey("apps.id"), nullable=False) + # who published this tool + user_id = db.Column(StringUUID, nullable=False) + # description of the tool, stored in i18n format, for human + description = db.Column(db.Text, nullable=False) + # llm_description of the tool, for LLM + llm_description = db.Column(db.Text, nullable=False) + # query description, query will be seem as a parameter of the tool, + # to describe this parameter to llm, we need this field + query_description = db.Column(db.Text, nullable=False) + # query name, the name of the query parameter + query_name = db.Column(db.String(40), nullable=False) + # name of the tool provider + tool_name = db.Column(db.String(40), nullable=False) + # author + author = db.Column(db.String(40), nullable=False) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + + @property + def description_i18n(self) -> I18nObject: + return I18nObject(**json.loads(self.description)) + + @property + def app(self): + return db.session.query(App).filter(App.id == self.app_id).first() + + +class ApiToolProvider(db.Model): # type: ignore[name-defined] + """ + The table stores the api providers. 
+ """ + + __tablename__ = "tool_api_providers" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="tool_api_provider_pkey"), + db.UniqueConstraint("name", "tenant_id", name="unique_api_tool_provider"), + ) + + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + # name of the api provider + name = db.Column(db.String(40), nullable=False) + # icon + icon = db.Column(db.String(255), nullable=False) + # original schema + schema = db.Column(db.Text, nullable=False) + schema_type_str: Mapped[str] = db.Column(db.String(40), nullable=False) + # who created this tool + user_id = db.Column(StringUUID, nullable=False) + # tenant id + tenant_id = db.Column(StringUUID, nullable=False) + # description of the provider + description = db.Column(db.Text, nullable=False) + # json format tools + tools_str = db.Column(db.Text, nullable=False) + # json format credentials + credentials_str = db.Column(db.Text, nullable=False) + # privacy policy + privacy_policy = db.Column(db.String(255), nullable=True) + # custom_disclaimer + custom_disclaimer: Mapped[str] = mapped_column(sa.TEXT, default="") + + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + + @property + def schema_type(self) -> ApiProviderSchemaType: + return ApiProviderSchemaType.value_of(self.schema_type_str) + + @property + def tools(self) -> list[ApiToolBundle]: + return [ApiToolBundle(**tool) for tool in json.loads(self.tools_str)] + + @property + def credentials(self) -> dict: + return dict(json.loads(self.credentials_str)) + + @property + def user(self) -> Account | None: + return db.session.query(Account).filter(Account.id == self.user_id).first() + + @property + def tenant(self) -> Tenant | None: + return db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first() + + +class ToolLabelBinding(db.Model): # type: ignore[name-defined] + """ + The 
table stores the labels for tools. + """ + + __tablename__ = "tool_label_bindings" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="tool_label_bind_pkey"), + db.UniqueConstraint("tool_id", "label_name", name="unique_tool_label_bind"), + ) + + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + # tool id + tool_id = db.Column(db.String(64), nullable=False) + # tool type + tool_type = db.Column(db.String(40), nullable=False) + # label name + label_name = db.Column(db.String(40), nullable=False) + + +class WorkflowToolProvider(db.Model): # type: ignore[name-defined] + """ + The table stores the workflow providers. + """ + + __tablename__ = "tool_workflow_providers" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="tool_workflow_provider_pkey"), + db.UniqueConstraint("name", "tenant_id", name="unique_workflow_tool_provider"), + db.UniqueConstraint("tenant_id", "app_id", name="unique_workflow_tool_provider_app_id"), + ) + + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + # name of the workflow provider + name = db.Column(db.String(40), nullable=False) + # label of the workflow provider + label = db.Column(db.String(255), nullable=False, server_default="") + # icon + icon = db.Column(db.String(255), nullable=False) + # app id of the workflow provider + app_id = db.Column(StringUUID, nullable=False) + # version of the workflow provider + version = db.Column(db.String(255), nullable=False, server_default="") + # who created this tool + user_id = db.Column(StringUUID, nullable=False) + # tenant id + tenant_id = db.Column(StringUUID, nullable=False) + # description of the provider + description = db.Column(db.Text, nullable=False) + # parameter configuration + parameter_configuration = db.Column(db.Text, nullable=False, server_default="[]") + # privacy policy + privacy_policy = db.Column(db.String(255), nullable=True, server_default="") + + created_at = db.Column(db.DateTime, nullable=False, 
server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + + @property + def user(self) -> Account | None: + return db.session.query(Account).filter(Account.id == self.user_id).first() + + @property + def tenant(self) -> Tenant | None: + return db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first() + + @property + def parameter_configurations(self) -> list[WorkflowToolParameterConfiguration]: + return [WorkflowToolParameterConfiguration(**config) for config in json.loads(self.parameter_configuration)] + + @property + def app(self) -> App | None: + return db.session.query(App).filter(App.id == self.app_id).first() + + +class ToolModelInvoke(db.Model): # type: ignore[name-defined] + """ + store the invoke logs from tool invoke + """ + + __tablename__ = "tool_model_invokes" + __table_args__ = (db.PrimaryKeyConstraint("id", name="tool_model_invoke_pkey"),) + + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + # who invoke this tool + user_id = db.Column(StringUUID, nullable=False) + # tenant id + tenant_id = db.Column(StringUUID, nullable=False) + # provider + provider = db.Column(db.String(40), nullable=False) + # type + tool_type = db.Column(db.String(40), nullable=False) + # tool name + tool_name = db.Column(db.String(40), nullable=False) + # invoke parameters + model_parameters = db.Column(db.Text, nullable=False) + # prompt messages + prompt_messages = db.Column(db.Text, nullable=False) + # invoke response + model_response = db.Column(db.Text, nullable=False) + + prompt_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0")) + answer_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0")) + answer_unit_price = db.Column(db.Numeric(10, 4), nullable=False) + answer_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001")) + provider_response_latency = db.Column(db.Float, 
nullable=False, server_default=db.text("0")) + total_price = db.Column(db.Numeric(10, 7)) + currency = db.Column(db.String(255), nullable=False) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + + +class ToolConversationVariables(db.Model): # type: ignore[name-defined] + """ + store the conversation variables from tool invoke + """ + + __tablename__ = "tool_conversation_variables" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="tool_conversation_variables_pkey"), + # add index for user_id and conversation_id + db.Index("user_id_idx", "user_id"), + db.Index("conversation_id_idx", "conversation_id"), + ) + + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + # conversation user id + user_id = db.Column(StringUUID, nullable=False) + # tenant id + tenant_id = db.Column(StringUUID, nullable=False) + # conversation id + conversation_id = db.Column(StringUUID, nullable=False) + # variables pool + variables_str = db.Column(db.Text, nullable=False) + + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + + @property + def variables(self) -> Any: + return json.loads(self.variables_str) + + +class ToolFile(db.Model): # type: ignore[name-defined] + __tablename__ = "tool_files" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="tool_file_pkey"), + db.Index("tool_file_conversation_id_idx", "conversation_id"), + ) + + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + user_id: Mapped[str] = db.Column(StringUUID, nullable=False) + tenant_id: Mapped[str] = db.Column(StringUUID, nullable=False) + conversation_id: Mapped[Optional[str]] = db.Column(StringUUID, nullable=True) + file_key: Mapped[str] = db.Column(db.String(255), nullable=False) + 
mimetype: Mapped[str] = db.Column(db.String(255), nullable=False) + original_url: Mapped[Optional[str]] = db.Column(db.String(2048), nullable=True) + name: Mapped[str] = mapped_column(default="") + size: Mapped[int] = mapped_column(default=-1) + + def __init__( + self, + *, + user_id: str, + tenant_id: str, + conversation_id: Optional[str] = None, + file_key: str, + mimetype: str, + original_url: Optional[str] = None, + name: str, + size: int, + ): + self.user_id = user_id + self.tenant_id = tenant_id + self.conversation_id = conversation_id + self.file_key = file_key + self.mimetype = mimetype + self.original_url = original_url + self.name = name + self.size = size diff --git a/api/models/types.py b/api/models/types.py new file mode 100644 index 0000000000000000000000000000000000000000..cb6773e70cdd5fa6400f2e912cc0fb35d7c6f535 --- /dev/null +++ b/api/models/types.py @@ -0,0 +1,26 @@ +from sqlalchemy import CHAR, TypeDecorator +from sqlalchemy.dialects.postgresql import UUID + + +class StringUUID(TypeDecorator): + impl = CHAR + cache_ok = True + + def process_bind_param(self, value, dialect): + if value is None: + return value + elif dialect.name == "postgresql": + return str(value) + else: + return value.hex + + def load_dialect_impl(self, dialect): + if dialect.name == "postgresql": + return dialect.type_descriptor(UUID()) + else: + return dialect.type_descriptor(CHAR(36)) + + def process_result_value(self, value, dialect): + if value is None: + return value + return str(value) diff --git a/api/models/web.py b/api/models/web.py new file mode 100644 index 0000000000000000000000000000000000000000..864428fe0931b6a0f5dcc937dbe07e7eec15c4bb --- /dev/null +++ b/api/models/web.py @@ -0,0 +1,40 @@ +from sqlalchemy import func +from sqlalchemy.orm import Mapped, mapped_column + +from .engine import db +from .model import Message +from .types import StringUUID + + +class SavedMessage(db.Model): # type: ignore[name-defined] + __tablename__ = "saved_messages" + 
__table_args__ = ( + db.PrimaryKeyConstraint("id", name="saved_message_pkey"), + db.Index("saved_message_message_idx", "app_id", "message_id", "created_by_role", "created_by"), + ) + + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + app_id = db.Column(StringUUID, nullable=False) + message_id = db.Column(StringUUID, nullable=False) + created_by_role = db.Column(db.String(255), nullable=False, server_default=db.text("'end_user'::character varying")) + created_by = db.Column(StringUUID, nullable=False) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + + @property + def message(self): + return db.session.query(Message).filter(Message.id == self.message_id).first() + + +class PinnedConversation(db.Model): # type: ignore[name-defined] + __tablename__ = "pinned_conversations" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="pinned_conversation_pkey"), + db.Index("pinned_conversation_conversation_idx", "app_id", "conversation_id", "created_by_role", "created_by"), + ) + + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + app_id = db.Column(StringUUID, nullable=False) + conversation_id: Mapped[str] = mapped_column(StringUUID) + created_by_role = db.Column(db.String(255), nullable=False, server_default=db.text("'end_user'::character varying")) + created_by = db.Column(StringUUID, nullable=False) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) diff --git a/api/models/workflow.py b/api/models/workflow.py new file mode 100644 index 0000000000000000000000000000000000000000..e1776ce887340b7af29e4da03a206a46fd7e6b07 --- /dev/null +++ b/api/models/workflow.py @@ -0,0 +1,810 @@ +import json +from collections.abc import Mapping, Sequence +from datetime import UTC, datetime +from enum import Enum, StrEnum +from typing import TYPE_CHECKING, Any, Optional, Union + +import sqlalchemy as sa +from sqlalchemy import func +from 
class Workflow(db.Model):  # type: ignore[name-defined]
    """
    Workflow, for `Workflow App` and `Chat App workflow mode`.

    Attributes:

    - id (uuid) Workflow ID, pk
    - tenant_id (uuid) Workspace ID
    - app_id (uuid) App ID
    - type (string) Workflow type

        `workflow` for `Workflow App`

        `chat` for `Chat App workflow mode`

    - version (string) Version

        `draft` for draft version (only one for each app), other for version number (redundant)

    - graph (text) Workflow canvas configuration (JSON)

        The entire canvas configuration JSON, including Node, Edge, and other configurations

        - nodes (array[object]) Node list, see Node Schema

        - edges (array[object]) Edge list, see Edge Schema

    - features (text) Feature configuration (JSON), exposed through the `features` property
    - environment_variables (text) JSON object of environment variables keyed by variable name
    - conversation_variables (text) JSON object of conversation variables keyed by variable name
    - created_by (uuid) Creator ID
    - created_at (timestamp) Creation time
    - updated_by (uuid) `optional` Last updater ID
    - updated_at (timestamp) `optional` Last update time
    """

    __tablename__ = "workflows"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="workflow_pkey"),
        db.Index("workflow_version_idx", "tenant_id", "app_id", "version"),
    )

    id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
    tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
    app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
    type: Mapped[str] = mapped_column(db.String(255), nullable=False)
    version: Mapped[str] = mapped_column(db.String(255), nullable=False)
    graph: Mapped[str] = mapped_column(sa.Text)
    _features: Mapped[str] = mapped_column("features", sa.TEXT)
    created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
    created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
    updated_by: Mapped[Optional[str]] = mapped_column(StringUUID)
    # The default must be a callable: a bare `datetime.now(...)` expression is evaluated
    # once at import time, which would stamp every new row with the process start time.
    # NOTE(review): `server_onupdate` only marks the column as server-generated; it does
    # not by itself install a trigger -- confirm something refreshes this on UPDATE.
    updated_at: Mapped[datetime] = mapped_column(
        db.DateTime,
        nullable=False,
        default=lambda: datetime.now(UTC).replace(tzinfo=None),
        server_onupdate=func.current_timestamp(),
    )
    _environment_variables: Mapped[str] = mapped_column(
        "environment_variables", db.Text, nullable=False, server_default="{}"
    )
    _conversation_variables: Mapped[str] = mapped_column(
        "conversation_variables", db.Text, nullable=False, server_default="{}"
    )

    def __init__(
        self,
        *,
        tenant_id: str,
        app_id: str,
        type: str,
        version: str,
        graph: str,
        features: str,
        created_by: str,
        environment_variables: Sequence[Variable],
        conversation_variables: Sequence[Variable],
    ):
        """Create a workflow; variable sequences may be empty or None and are stored as JSON."""
        self.tenant_id = tenant_id
        self.app_id = app_id
        self.type = type
        self.version = version
        self.graph = graph
        self.features = features
        self.created_by = created_by
        self.environment_variables = environment_variables or []
        self.conversation_variables = conversation_variables or []

    @property
    def created_by_account(self):
        """Account row referenced by `created_by` (None when it does not exist)."""
        return db.session.get(Account, self.created_by)

    @property
    def updated_by_account(self):
        """Account row referenced by `updated_by`, or None when `updated_by` is unset."""
        return db.session.get(Account, self.updated_by) if self.updated_by else None

    @property
    def graph_dict(self) -> Mapping[str, Any]:
        """Parsed `graph` JSON, or an empty mapping when no graph is stored."""
        return json.loads(self.graph) if self.graph else {}

    @property
    def features(self) -> str:
        """
        Convert old features structure to new features structure.

        When the stored JSON still has an enabled `file_upload.image` section, it is
        rewritten into the flat `file_upload` layout and the converted JSON is stored
        back on the instance before being returned. An empty/None value is returned as is.
        """
        if not self._features:
            return self._features

        features = json.loads(self._features)
        if features.get("file_upload", {}).get("image", {}).get("enabled", False):
            file_upload = features["file_upload"]
            image_config = file_upload["image"]
            file_upload["enabled"] = True
            file_upload["number_limits"] = int(image_config.get("number_limits", 1))
            file_upload["allowed_file_upload_methods"] = image_config.get(
                "transfer_methods", ["remote_url", "local_file"]
            )
            file_upload["allowed_file_types"] = ["image"]
            file_upload["allowed_file_extensions"] = []
            del file_upload["image"]
            self._features = json.dumps(features)
        return self._features

    @features.setter
    def features(self, value: str) -> None:
        self._features = value

    @property
    def features_dict(self) -> dict[str, Any]:
        """Parsed (and migrated) features JSON, or an empty dict when unset."""
        features_json = self.features  # read once: the getter parses/migrates on each access
        return json.loads(features_json) if features_json else {}

    def user_input_form(self, to_old_structure: bool = False) -> list:
        """
        Return the variables declared on the graph's start node.

        :param to_old_structure: when True, wrap each variable as `{variable["type"]: variable}`
        :return: list of variables; empty when there is no graph, no nodes or no start node
        """
        # get start node from graph
        if not self.graph:
            return []

        graph_dict = self.graph_dict
        if "nodes" not in graph_dict:
            return []

        # Use .get so a node without `data` is skipped instead of raising KeyError,
        # matching the tolerant lookup used for the variables below.
        start_node = next(
            (node for node in graph_dict["nodes"] if node.get("data", {}).get("type") == "start"),
            None,
        )
        if not start_node:
            return []

        # get user_input_form from start node
        variables: list[Any] = start_node.get("data", {}).get("variables", [])

        if to_old_structure:
            return [{variable["type"]: variable} for variable in variables]

        return variables

    @property
    def unique_hash(self) -> str:
        """
        Get hash of workflow.

        :return: hash of the key-sorted JSON of graph and features
        """
        entity = {"graph": self.graph_dict, "features": self.features_dict}

        return helper.generate_text_hash(json.dumps(entity, sort_keys=True))

    @property
    def tool_published(self) -> bool:
        """True when a WorkflowToolProvider row exists for this tenant and app."""
        from models.tools import WorkflowToolProvider

        return (
            db.session.query(WorkflowToolProvider)
            .filter(WorkflowToolProvider.tenant_id == self.tenant_id, WorkflowToolProvider.app_id == self.app_id)
            .count()
            > 0
        )

    @property
    def environment_variables(self) -> Sequence[Variable]:
        """
        Environment variables rebuilt from the stored JSON, with secret values decrypted.

        Reads the tenant id from `contexts.tenant_id`, so that context variable must be set.
        """
        # TODO: find some way to init `self._environment_variables` when instance created.
        if self._environment_variables is None:
            self._environment_variables = "{}"

        tenant_id = contexts.tenant_id.get()

        environment_variables_dict: dict[str, Any] = json.loads(self._environment_variables)
        variables = [
            variable_factory.build_environment_variable_from_mapping(v) for v in environment_variables_dict.values()
        ]

        # decrypt secret variables value
        return [
            var.model_copy(update={"value": encrypter.decrypt_token(tenant_id=tenant_id, token=var.value)})
            if isinstance(var, SecretVariable)
            else var
            for var in variables
        ]

    @environment_variables.setter
    def environment_variables(self, value: Sequence[Variable]):
        """
        Store environment variables as JSON keyed by name, encrypting secret values.

        :raises ValueError: when any variable has no id
        """
        if not value:
            self._environment_variables = "{}"
            return

        tenant_id = contexts.tenant_id.get()

        value = list(value)
        if any(not var.id for var in value):
            raise ValueError("environment variable require a unique id")

        # Compare inputs and origin variables,
        # if the value is HIDDEN_VALUE, use the origin variable value (only update `name`).
        origin_variables_dictionary = {var.id: var for var in self.environment_variables}
        for i, variable in enumerate(value):
            if variable.id in origin_variables_dictionary and variable.value == HIDDEN_VALUE:
                value[i] = origin_variables_dictionary[variable.id].model_copy(update={"name": variable.name})

        # encrypt secret variables value
        encrypted_vars = [
            var.model_copy(update={"value": encrypter.encrypt_token(tenant_id=tenant_id, token=var.value)})
            if isinstance(var, SecretVariable)
            else var
            for var in value
        ]
        self._environment_variables = json.dumps(
            {var.name: var.model_dump() for var in encrypted_vars},
            ensure_ascii=False,
        )

    def to_dict(self, *, include_secret: bool = False) -> Mapping[str, Any]:
        """
        Serialize graph, features and variables.

        :param include_secret: when False, secret environment variable values are blanked
        """
        environment_variables = [
            v if not isinstance(v, SecretVariable) or include_secret else v.model_copy(update={"value": ""})
            for v in self.environment_variables
        ]

        return {
            "graph": self.graph_dict,
            "features": self.features_dict,
            "environment_variables": [var.model_dump(mode="json") for var in environment_variables],
            "conversation_variables": [var.model_dump(mode="json") for var in self.conversation_variables],
        }

    @property
    def conversation_variables(self) -> Sequence[Variable]:
        """Conversation variables rebuilt from the stored JSON."""
        # TODO: find some way to init `self._conversation_variables` when instance created.
        if self._conversation_variables is None:
            self._conversation_variables = "{}"

        variables_dict: dict[str, Any] = json.loads(self._conversation_variables)
        return [variable_factory.build_conversation_variable_from_mapping(v) for v in variables_dict.values()]

    @conversation_variables.setter
    def conversation_variables(self, value: Sequence[Variable]) -> None:
        # Treat None like an empty sequence, mirroring the environment_variables setter.
        self._conversation_variables = json.dumps(
            {var.name: var.model_dump() for var in value or []},
            ensure_ascii=False,
        )
Console account + + - `end_user` End user + + - created_by (uuid) Runner ID + - created_at (timestamp) Run time + - finished_at (timestamp) End time + """ + + __tablename__ = "workflow_runs" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="workflow_run_pkey"), + db.Index("workflow_run_triggerd_from_idx", "tenant_id", "app_id", "triggered_from"), + db.Index("workflow_run_tenant_app_sequence_idx", "tenant_id", "app_id", "sequence_number"), + ) + + id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + tenant_id: Mapped[str] = mapped_column(StringUUID) + app_id: Mapped[str] = mapped_column(StringUUID) + sequence_number: Mapped[int] = mapped_column() + workflow_id: Mapped[str] = mapped_column(StringUUID) + type: Mapped[str] = mapped_column(db.String(255)) + triggered_from: Mapped[str] = mapped_column(db.String(255)) + version: Mapped[str] = mapped_column(db.String(255)) + graph: Mapped[Optional[str]] = mapped_column(db.Text) + inputs: Mapped[Optional[str]] = mapped_column(db.Text) + status: Mapped[str] = mapped_column(db.String(255)) # running, succeeded, failed, stopped, partial-succeeded + outputs: Mapped[Optional[str]] = mapped_column(sa.Text, default="{}") + error: Mapped[Optional[str]] = mapped_column(db.Text) + elapsed_time = db.Column(db.Float, nullable=False, server_default=sa.text("0")) + total_tokens: Mapped[int] = mapped_column(sa.BigInteger, server_default=sa.text("0")) + total_steps = db.Column(db.Integer, server_default=db.text("0")) + created_by_role: Mapped[str] = mapped_column(db.String(255)) # account, end_user + created_by = db.Column(StringUUID, nullable=False) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + finished_at = db.Column(db.DateTime) + exceptions_count = db.Column(db.Integer, server_default=db.text("0")) + + @property + def created_by_account(self): + created_by_role = CreatedByRole(self.created_by_role) + return db.session.get(Account, 
self.created_by) if created_by_role == CreatedByRole.ACCOUNT else None + + @property + def created_by_end_user(self): + from models.model import EndUser + + created_by_role = CreatedByRole(self.created_by_role) + return db.session.get(EndUser, self.created_by) if created_by_role == CreatedByRole.END_USER else None + + @property + def graph_dict(self): + return json.loads(self.graph) if self.graph else {} + + @property + def inputs_dict(self) -> Mapping[str, Any]: + return json.loads(self.inputs) if self.inputs else {} + + @property + def outputs_dict(self) -> Mapping[str, Any]: + return json.loads(self.outputs) if self.outputs else {} + + @property + def message(self) -> Optional["Message"]: + from models.model import Message + + return ( + db.session.query(Message).filter(Message.app_id == self.app_id, Message.workflow_run_id == self.id).first() + ) + + @property + def workflow(self): + return db.session.query(Workflow).filter(Workflow.id == self.workflow_id).first() + + def to_dict(self): + return { + "id": self.id, + "tenant_id": self.tenant_id, + "app_id": self.app_id, + "sequence_number": self.sequence_number, + "workflow_id": self.workflow_id, + "type": self.type, + "triggered_from": self.triggered_from, + "version": self.version, + "graph": self.graph_dict, + "inputs": self.inputs_dict, + "status": self.status, + "outputs": self.outputs_dict, + "error": self.error, + "elapsed_time": self.elapsed_time, + "total_tokens": self.total_tokens, + "total_steps": self.total_steps, + "created_by_role": self.created_by_role, + "created_by": self.created_by, + "created_at": self.created_at, + "finished_at": self.finished_at, + "exceptions_count": self.exceptions_count, + } + + @classmethod + def from_dict(cls, data: dict) -> "WorkflowRun": + return cls( + id=data.get("id"), + tenant_id=data.get("tenant_id"), + app_id=data.get("app_id"), + sequence_number=data.get("sequence_number"), + workflow_id=data.get("workflow_id"), + type=data.get("type"), + 
class WorkflowNodeExecutionTriggeredFrom(Enum):
    """
    Sources that can trigger a workflow node execution.
    """

    SINGLE_STEP = "single-step"
    WORKFLOW_RUN = "workflow-run"

    @classmethod
    def value_of(cls, value: str) -> "WorkflowNodeExecutionTriggeredFrom":
        """
        Look up the member whose value equals the given string.

        :param value: raw member value
        :return: matching member
        :raises ValueError: when no member has that value
        """
        matched = next((member for member in cls if member.value == value), None)
        if matched is None:
            raise ValueError(f"invalid workflow node execution triggered from value {value}")
        return matched
class WorkflowNodeExecution(db.Model):  # type: ignore[name-defined]
    """
    Workflow Node Execution

    - id (uuid) Execution ID
    - tenant_id (uuid) Workspace ID
    - app_id (uuid) App ID
    - workflow_id (uuid) Workflow ID
    - triggered_from (string) Trigger source

        `single-step` for single-step debugging

        `workflow-run` for workflow execution (debugging / user execution)

    - workflow_run_id (uuid) `optional` Workflow run ID

        Null for single-step debugging.

    - index (int) Execution sequence number, used for displaying Tracing Node order
    - predecessor_node_id (string) `optional` Predecessor node ID, used for displaying execution path
    - node_execution_id (string) `optional` Node execution ID
    - node_id (string) Node ID
    - node_type (string) Node type, such as `start`
    - title (string) Node title
    - inputs (json) All predecessor node variable content used in the node
    - process_data (json) Node process data
    - outputs (json) `optional` Node output variables
    - status (string) Execution status, `running` / `succeeded` / `failed` / `exception` / `retry`
    - error (string) `optional` Error reason
    - elapsed_time (float) `optional` Time consumption (s)
    - execution_metadata (text) Metadata

        - total_tokens (int) `optional` Total tokens used

        - total_price (decimal) `optional` Total cost

        - currency (string) `optional` Currency, such as USD / RMB

    - created_at (timestamp) Run time
    - created_by_role (string) Creator role

        - `account` Console account

        - `end_user` End user

    - created_by (uuid) Runner ID
    - finished_at (timestamp) End time
    """

    __tablename__ = "workflow_node_executions"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="workflow_node_execution_pkey"),
        db.Index(
            "workflow_node_execution_workflow_run_idx",
            "tenant_id",
            "app_id",
            "workflow_id",
            "triggered_from",
            "workflow_run_id",
        ),
        db.Index(
            "workflow_node_execution_node_run_idx", "tenant_id", "app_id", "workflow_id", "triggered_from", "node_id"
        ),
        db.Index(
            "workflow_node_execution_id_idx",
            "tenant_id",
            "app_id",
            "workflow_id",
            "triggered_from",
            "node_execution_id",
        ),
    )

    id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
    tenant_id: Mapped[str] = mapped_column(StringUUID)
    app_id: Mapped[str] = mapped_column(StringUUID)
    workflow_id: Mapped[str] = mapped_column(StringUUID)
    triggered_from: Mapped[str] = mapped_column(db.String(255))
    workflow_run_id: Mapped[Optional[str]] = mapped_column(StringUUID)
    index: Mapped[int] = mapped_column(db.Integer)
    predecessor_node_id: Mapped[Optional[str]] = mapped_column(db.String(255))
    node_execution_id: Mapped[Optional[str]] = mapped_column(db.String(255))
    node_id: Mapped[str] = mapped_column(db.String(255))
    node_type: Mapped[str] = mapped_column(db.String(255))
    title: Mapped[str] = mapped_column(db.String(255))
    inputs: Mapped[Optional[str]] = mapped_column(db.Text)
    process_data: Mapped[Optional[str]] = mapped_column(db.Text)
    outputs: Mapped[Optional[str]] = mapped_column(db.Text)
    status: Mapped[str] = mapped_column(db.String(255))
    error: Mapped[Optional[str]] = mapped_column(db.Text)
    elapsed_time: Mapped[float] = mapped_column(db.Float, server_default=db.text("0"))
    execution_metadata: Mapped[Optional[str]] = mapped_column(db.Text)
    created_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp())
    created_by_role: Mapped[str] = mapped_column(db.String(255))
    created_by: Mapped[str] = mapped_column(StringUUID)
    finished_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime)

    @property
    def created_by_account(self):
        """Account that created this execution, or None when the creator role is not `account`."""
        created_by_role = CreatedByRole(self.created_by_role)
        return db.session.get(Account, self.created_by) if created_by_role == CreatedByRole.ACCOUNT else None

    @property
    def created_by_end_user(self):
        """End user that created this execution, or None when the creator role is not `end_user`."""
        from models.model import EndUser

        created_by_role = CreatedByRole(self.created_by_role)
        return db.session.get(EndUser, self.created_by) if created_by_role == CreatedByRole.END_USER else None

    @property
    def inputs_dict(self):
        """Parsed `inputs` JSON, or None when empty."""
        return json.loads(self.inputs) if self.inputs else None

    @property
    def outputs_dict(self):
        """Parsed `outputs` JSON, or None when empty."""
        return json.loads(self.outputs) if self.outputs else None

    @property
    def process_data_dict(self):
        """Parsed `process_data` JSON, or None when empty."""
        return json.loads(self.process_data) if self.process_data else None

    @property
    def execution_metadata_dict(self):
        """Parsed `execution_metadata` JSON, or None when empty."""
        return json.loads(self.execution_metadata) if self.execution_metadata else None

    @property
    def extras(self):
        """
        Extra display data for this execution.

        For tool nodes whose metadata carries `tool_info`, the result contains the tool
        icon under `icon`; otherwise it is an empty dict.
        """
        from core.tools.tool_manager import ToolManager

        extras = {}
        # Parse the metadata JSON once; the property re-parses on every access.
        metadata = self.execution_metadata_dict
        if metadata:
            from core.workflow.nodes import NodeType

            if self.node_type == NodeType.TOOL.value and "tool_info" in metadata:
                tool_info = metadata["tool_info"]
                extras["icon"] = ToolManager.get_tool_icon(
                    tenant_id=self.tenant_id,
                    provider_type=tool_info["provider_type"],
                    provider_id=tool_info["provider_id"],
                )

        return extras
class ConversationVariable(db.Model):  # type: ignore[name-defined]
    """
    Persisted value of one conversation variable.

    Rows live in `workflow_conversation_variables`; the primary key is the pair
    (id, conversation_id) and `data` holds the variable serialized as JSON.
    """

    __tablename__ = "workflow_conversation_variables"

    id: Mapped[str] = db.Column(StringUUID, primary_key=True)
    conversation_id: Mapped[str] = db.Column(StringUUID, nullable=False, primary_key=True)
    app_id: Mapped[str] = db.Column(StringUUID, nullable=False, index=True)
    data = db.Column(db.Text, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, index=True, server_default=func.current_timestamp())
    updated_at = db.Column(
        db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
    )

    def __init__(self, *, id: str, app_id: str, conversation_id: str, data: str) -> None:
        """Populate the row from keyword-only field values."""
        self.id, self.app_id = id, app_id
        self.conversation_id, self.data = conversation_id, data

    @classmethod
    def from_variable(cls, *, app_id: str, conversation_id: str, variable: Variable) -> "ConversationVariable":
        """Build a row for `variable`, storing its JSON dump and reusing its id."""
        serialized = variable.model_dump_json()
        return cls(id=variable.id, app_id=app_id, conversation_id=conversation_id, data=serialized)

    def to_variable(self) -> Variable:
        """Rebuild the variable object from the stored JSON."""
        return variable_factory.build_conversation_variable_from_mapping(json.loads(self.data))
def _get_tenant_plan(tenant_id):
    """Return the tenant's billing plan name, cached in Redis for 600 seconds."""
    features_cache_key = f"features:{tenant_id}"
    plan_cache = redis_client.get(features_cache_key)
    if plan_cache is not None:
        return plan_cache.decode()

    features = FeatureService.get_features(tenant_id)
    plan = features.billing.subscription.plan
    redis_client.setex(features_cache_key, 600, plan)
    return plan


def _delete_message_with_relations(message_id):
    """Delete one message and every row that references it (does not commit)."""
    related_models = (
        MessageFeedback,
        MessageAnnotation,
        MessageChain,
        MessageAgentThought,
        MessageFile,
        SavedMessage,
    )
    for model in related_models:
        db.session.query(model).filter(model.message_id == message_id).delete(synchronize_session=False)
    db.session.query(Message).filter(Message.id == message_id).delete()


@app.celery.task(queue="dataset")
def clean_messages():
    """
    Delete expired messages that belong to tenants on the `sandbox` plan.

    Messages older than PLAN_SANDBOX_CLEAN_MESSAGE_DAY_SETTING days are scanned
    newest-first in batches of 100. The cutoff is moved back to each visited message's
    creation time, so messages that are kept are not fetched again in the next batch.
    """
    click.echo(click.style("Start clean messages.", fg="green"))
    start_at = time.perf_counter()
    plan_sandbox_clean_message_day = datetime.datetime.now() - datetime.timedelta(
        days=dify_config.PLAN_SANDBOX_CLEAN_MESSAGE_DAY_SETTING
    )
    while True:
        # `.all()` returns an empty list when nothing matches (it never raises NotFound),
        # so the empty check below is the only loop exit that is needed.
        messages = (
            db.session.query(Message)  # type: ignore
            .filter(Message.created_at < plan_sandbox_clean_message_day)
            .order_by(Message.created_at.desc())
            .limit(100)
            .all()
        )
        if not messages:
            break
        for message in messages:
            plan_sandbox_clean_message_day = message.created_at
            # Named `app_model` so the imported `app` module is not shadowed.
            app_model = App.query.filter_by(id=message.app_id).first()
            if app_model is None:
                # The owning app no longer exists, so the tenant plan cannot be resolved;
                # skip this message instead of aborting the whole task.
                click.echo(
                    click.style(f"Skip message {message.id}: app {message.app_id} not found.", fg="yellow")
                )
                continue
            if _get_tenant_plan(app_model.tenant_id) == "sandbox":
                # clean related message
                _delete_message_with_relations(message.id)
                db.session.commit()
    end_at = time.perf_counter()
    click.echo(click.style(f"Cleaned messages from db success latency: {end_at - start_at}", fg="green"))
core.rag.index_processor.index_processor_factory import IndexProcessorFactory +from extensions.ext_database import db +from extensions.ext_redis import redis_client +from models.dataset import Dataset, DatasetAutoDisableLog, DatasetQuery, Document +from services.feature_service import FeatureService + + +@app.celery.task(queue="dataset") +def clean_unused_datasets_task(): + click.echo(click.style("Start clean unused datasets indexes.", fg="green")) + plan_sandbox_clean_day_setting = dify_config.PLAN_SANDBOX_CLEAN_DAY_SETTING + plan_pro_clean_day_setting = dify_config.PLAN_PRO_CLEAN_DAY_SETTING + start_at = time.perf_counter() + plan_sandbox_clean_day = datetime.datetime.now() - datetime.timedelta(days=plan_sandbox_clean_day_setting) + plan_pro_clean_day = datetime.datetime.now() - datetime.timedelta(days=plan_pro_clean_day_setting) + while True: + try: + # Subquery for counting new documents + document_subquery_new = ( + db.session.query(Document.dataset_id, func.count(Document.id).label("document_count")) + .filter( + Document.indexing_status == "completed", + Document.enabled == True, + Document.archived == False, + Document.updated_at > plan_sandbox_clean_day, + ) + .group_by(Document.dataset_id) + .subquery() + ) + + # Subquery for counting old documents + document_subquery_old = ( + db.session.query(Document.dataset_id, func.count(Document.id).label("document_count")) + .filter( + Document.indexing_status == "completed", + Document.enabled == True, + Document.archived == False, + Document.updated_at < plan_sandbox_clean_day, + ) + .group_by(Document.dataset_id) + .subquery() + ) + + # Main query with join and filter + datasets = ( + Dataset.query.outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id) + .outerjoin(document_subquery_old, Dataset.id == document_subquery_old.c.dataset_id) + .filter( + Dataset.created_at < plan_sandbox_clean_day, + func.coalesce(document_subquery_new.c.document_count, 0) == 0, + 
func.coalesce(document_subquery_old.c.document_count, 0) > 0, + ) + .order_by(Dataset.created_at.desc()) + .paginate(page=1, per_page=50) + ) + + except NotFound: + break + if datasets.items is None or len(datasets.items) == 0: + break + for dataset in datasets: + dataset_query = ( + db.session.query(DatasetQuery) + .filter(DatasetQuery.created_at > plan_sandbox_clean_day, DatasetQuery.dataset_id == dataset.id) + .all() + ) + if not dataset_query or len(dataset_query) == 0: + try: + # add auto disable log + documents = ( + db.session.query(Document) + .filter( + Document.dataset_id == dataset.id, + Document.enabled == True, + Document.archived == False, + ) + .all() + ) + for document in documents: + dataset_auto_disable_log = DatasetAutoDisableLog( + tenant_id=dataset.tenant_id, + dataset_id=dataset.id, + document_id=document.id, + ) + db.session.add(dataset_auto_disable_log) + # remove index + index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor() + index_processor.clean(dataset, None) + + # update document + update_params = {Document.enabled: False} + + Document.query.filter_by(dataset_id=dataset.id).update(update_params) + db.session.commit() + click.echo(click.style("Cleaned unused dataset {} from db success!".format(dataset.id), fg="green")) + except Exception as e: + click.echo( + click.style("clean dataset index error: {} {}".format(e.__class__.__name__, str(e)), fg="red") + ) + while True: + try: + # Subquery for counting new documents + document_subquery_new = ( + db.session.query(Document.dataset_id, func.count(Document.id).label("document_count")) + .filter( + Document.indexing_status == "completed", + Document.enabled == True, + Document.archived == False, + Document.updated_at > plan_pro_clean_day, + ) + .group_by(Document.dataset_id) + .subquery() + ) + + # Subquery for counting old documents + document_subquery_old = ( + db.session.query(Document.dataset_id, func.count(Document.id).label("document_count")) + .filter( + 
Document.indexing_status == "completed", + Document.enabled == True, + Document.archived == False, + Document.updated_at < plan_pro_clean_day, + ) + .group_by(Document.dataset_id) + .subquery() + ) + + # Main query with join and filter + datasets = ( + Dataset.query.outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id) + .outerjoin(document_subquery_old, Dataset.id == document_subquery_old.c.dataset_id) + .filter( + Dataset.created_at < plan_pro_clean_day, + func.coalesce(document_subquery_new.c.document_count, 0) == 0, + func.coalesce(document_subquery_old.c.document_count, 0) > 0, + ) + .order_by(Dataset.created_at.desc()) + .paginate(page=1, per_page=50) + ) + + except NotFound: + break + if datasets.items is None or len(datasets.items) == 0: + break + for dataset in datasets: + dataset_query = ( + db.session.query(DatasetQuery) + .filter(DatasetQuery.created_at > plan_pro_clean_day, DatasetQuery.dataset_id == dataset.id) + .all() + ) + if not dataset_query or len(dataset_query) == 0: + try: + features_cache_key = f"features:{dataset.tenant_id}" + plan_cache = redis_client.get(features_cache_key) + if plan_cache is None: + features = FeatureService.get_features(dataset.tenant_id) + redis_client.setex(features_cache_key, 600, features.billing.subscription.plan) + plan = features.billing.subscription.plan + else: + plan = plan_cache.decode() + if plan == "sandbox": + # remove index + index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor() + index_processor.clean(dataset, None) + + # update document + update_params = {Document.enabled: False} + + Document.query.filter_by(dataset_id=dataset.id).update(update_params) + db.session.commit() + click.echo( + click.style("Cleaned unused dataset {} from db success!".format(dataset.id), fg="green") + ) + except Exception as e: + click.echo( + click.style("clean dataset index error: {} {}".format(e.__class__.__name__, str(e)), fg="red") + ) + end_at = time.perf_counter() 
+ click.echo(click.style("Cleaned unused dataset from db success latency: {}".format(end_at - start_at), fg="green")) diff --git a/api/schedule/create_tidb_serverless_task.py b/api/schedule/create_tidb_serverless_task.py new file mode 100644 index 0000000000000000000000000000000000000000..1c985461c6aa2ed18eacb40f8fa6bd3019d2f2f7 --- /dev/null +++ b/api/schedule/create_tidb_serverless_task.py @@ -0,0 +1,59 @@ +import time + +import click + +import app +from configs import dify_config +from core.rag.datasource.vdb.tidb_on_qdrant.tidb_service import TidbService +from extensions.ext_database import db +from models.dataset import TidbAuthBinding + + +@app.celery.task(queue="dataset") +def create_tidb_serverless_task(): + click.echo(click.style("Start create tidb serverless task.", fg="green")) + if not dify_config.CREATE_TIDB_SERVICE_JOB_ENABLED: + return + tidb_serverless_number = dify_config.TIDB_SERVERLESS_NUMBER + start_at = time.perf_counter() + while True: + try: + # check the number of idle tidb serverless + idle_tidb_serverless_number = TidbAuthBinding.query.filter(TidbAuthBinding.active == False).count() + if idle_tidb_serverless_number >= tidb_serverless_number: + break + # create tidb serverless + iterations_per_thread = 20 + create_clusters(iterations_per_thread) + + except Exception as e: + click.echo(click.style(f"Error: {e}", fg="red")) + break + + end_at = time.perf_counter() + click.echo(click.style("Create tidb serverless task success latency: {}".format(end_at - start_at), fg="green")) + + +def create_clusters(batch_size): + try: + # TODO: maybe we can set the default value for the following parameters in the config file + new_clusters = TidbService.batch_create_tidb_serverless_cluster( + batch_size=batch_size, + project_id=dify_config.TIDB_PROJECT_ID or "", + api_url=dify_config.TIDB_API_URL or "", + iam_url=dify_config.TIDB_IAM_API_URL or "", + public_key=dify_config.TIDB_PUBLIC_KEY or "", + private_key=dify_config.TIDB_PRIVATE_KEY or "", + 
region=dify_config.TIDB_REGION or "", + ) + for new_cluster in new_clusters: + tidb_auth_binding = TidbAuthBinding( + cluster_id=new_cluster["cluster_id"], + cluster_name=new_cluster["cluster_name"], + account=new_cluster["account"], + password=new_cluster["password"], + ) + db.session.add(tidb_auth_binding) + db.session.commit() + except Exception as e: + click.echo(click.style(f"Error: {e}", fg="red")) diff --git a/api/schedule/mail_clean_document_notify_task.py b/api/schedule/mail_clean_document_notify_task.py new file mode 100644 index 0000000000000000000000000000000000000000..fe6839288d85034ab0260d6bbe2eefa8f95ebbde --- /dev/null +++ b/api/schedule/mail_clean_document_notify_task.py @@ -0,0 +1,90 @@ +import logging +import time +from collections import defaultdict + +import click +from flask import render_template # type: ignore + +import app +from configs import dify_config +from extensions.ext_database import db +from extensions.ext_mail import mail +from models.account import Account, Tenant, TenantAccountJoin +from models.dataset import Dataset, DatasetAutoDisableLog +from services.feature_service import FeatureService + + +@app.celery.task(queue="dataset") +def send_document_clean_notify_task(): + """ + Async Send document clean notify mail + + Usage: send_document_clean_notify_task.delay() + """ + if not mail.is_inited(): + return + + logging.info(click.style("Start send document clean notify mail", fg="green")) + start_at = time.perf_counter() + + # send document clean notify mail + try: + dataset_auto_disable_logs = DatasetAutoDisableLog.query.filter(DatasetAutoDisableLog.notified == False).all() + # group by tenant_id + dataset_auto_disable_logs_map: dict[str, list[DatasetAutoDisableLog]] = defaultdict(list) + for dataset_auto_disable_log in dataset_auto_disable_logs: + if dataset_auto_disable_log.tenant_id not in dataset_auto_disable_logs_map: + dataset_auto_disable_logs_map[dataset_auto_disable_log.tenant_id] = [] + 
dataset_auto_disable_logs_map[dataset_auto_disable_log.tenant_id].append(dataset_auto_disable_log) + url = f"{dify_config.CONSOLE_WEB_URL}/datasets" + for tenant_id, tenant_dataset_auto_disable_logs in dataset_auto_disable_logs_map.items(): + features = FeatureService.get_features(tenant_id) + plan = features.billing.subscription.plan + if plan != "sandbox": + knowledge_details = [] + # check tenant + tenant = Tenant.query.filter(Tenant.id == tenant_id).first() + if not tenant: + continue + # check current owner + current_owner_join = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, role="owner").first() + if not current_owner_join: + continue + account = Account.query.filter(Account.id == current_owner_join.account_id).first() + if not account: + continue + + dataset_auto_dataset_map = {} # type: ignore + for dataset_auto_disable_log in tenant_dataset_auto_disable_logs: + if dataset_auto_disable_log.dataset_id not in dataset_auto_dataset_map: + dataset_auto_dataset_map[dataset_auto_disable_log.dataset_id] = [] + dataset_auto_dataset_map[dataset_auto_disable_log.dataset_id].append( + dataset_auto_disable_log.document_id + ) + + for dataset_id, document_ids in dataset_auto_dataset_map.items(): + dataset = Dataset.query.filter(Dataset.id == dataset_id).first() + if dataset: + document_count = len(document_ids) + knowledge_details.append(rf"Knowledge base {dataset.name}: {document_count} documents") + if knowledge_details: + html_content = render_template( + "clean_document_job_mail_template-US.html", + userName=account.email, + knowledge_details=knowledge_details, + url=url, + ) + mail.send( + to=account.email, subject="Dify Knowledge base auto disable notification", html=html_content + ) + + # update notified to True + for dataset_auto_disable_log in tenant_dataset_auto_disable_logs: + dataset_auto_disable_log.notified = True + db.session.commit() + end_at = time.perf_counter() + logging.info( + click.style("Send document clean notify mail succeeded: latency: 
{}".format(end_at - start_at), fg="green") + ) + except Exception: + logging.exception("Send document clean notify mail failed") diff --git a/api/schedule/update_tidb_serverless_status_task.py b/api/schedule/update_tidb_serverless_status_task.py new file mode 100644 index 0000000000000000000000000000000000000000..11a39e60ee4ce55b625a2c66eb3c27142be094ea --- /dev/null +++ b/api/schedule/update_tidb_serverless_status_task.py @@ -0,0 +1,49 @@ +import time + +import click + +import app +from configs import dify_config +from core.rag.datasource.vdb.tidb_on_qdrant.tidb_service import TidbService +from models.dataset import TidbAuthBinding + + +@app.celery.task(queue="dataset") +def update_tidb_serverless_status_task(): + click.echo(click.style("Update tidb serverless status task.", fg="green")) + start_at = time.perf_counter() + try: + # check the number of idle tidb serverless + tidb_serverless_list = TidbAuthBinding.query.filter( + TidbAuthBinding.active == False, TidbAuthBinding.status == "CREATING" + ).all() + if len(tidb_serverless_list) == 0: + return + # update tidb serverless status + update_clusters(tidb_serverless_list) + + except Exception as e: + click.echo(click.style(f"Error: {e}", fg="red")) + + end_at = time.perf_counter() + click.echo( + click.style("Update tidb serverless status task success latency: {}".format(end_at - start_at), fg="green") + ) + + +def update_clusters(tidb_serverless_list: list[TidbAuthBinding]): + try: + # batch 20 + for i in range(0, len(tidb_serverless_list), 20): + items = tidb_serverless_list[i : i + 20] + # TODO: maybe we can set the default value for the following parameters in the config file + TidbService.batch_update_tidb_serverless_cluster_status( + tidb_serverless_list=items, + project_id=dify_config.TIDB_PROJECT_ID or "", + api_url=dify_config.TIDB_API_URL or "", + iam_url=dify_config.TIDB_IAM_API_URL or "", + public_key=dify_config.TIDB_PUBLIC_KEY or "", + private_key=dify_config.TIDB_PRIVATE_KEY or "", + ) + except 
Exception as e: + click.echo(click.style(f"Error: {e}", fg="red")) diff --git a/api/services/__init__.py b/api/services/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5163862cc12781ee028023d0ac4800c16af44c81 --- /dev/null +++ b/api/services/__init__.py @@ -0,0 +1,3 @@ +from . import errors + +__all__ = ["errors"] diff --git a/api/services/account_service.py b/api/services/account_service.py new file mode 100644 index 0000000000000000000000000000000000000000..5388e1878ed8422962c64315674b5c2b399e4d2b --- /dev/null +++ b/api/services/account_service.py @@ -0,0 +1,1061 @@ +import base64 +import json +import logging +import random +import secrets +import uuid +from datetime import UTC, datetime, timedelta +from hashlib import sha256 +from typing import Any, Optional, cast + +from pydantic import BaseModel +from sqlalchemy import func +from werkzeug.exceptions import Unauthorized + +from configs import dify_config +from constants.languages import language_timezone_mapping, languages +from events.tenant_event import tenant_was_created +from extensions.ext_database import db +from extensions.ext_redis import redis_client +from libs.helper import RateLimiter, TokenManager +from libs.passport import PassportService +from libs.password import compare_password, hash_password, valid_password +from libs.rsa import generate_key_pair +from models.account import ( + Account, + AccountIntegrate, + AccountStatus, + Tenant, + TenantAccountJoin, + TenantAccountJoinRole, + TenantAccountRole, + TenantStatus, +) +from models.model import DifySetup +from services.billing_service import BillingService +from services.errors.account import ( + AccountAlreadyInTenantError, + AccountLoginError, + AccountNotFoundError, + AccountNotLinkTenantError, + AccountPasswordError, + AccountRegisterError, + CannotOperateSelfError, + CurrentPasswordIncorrectError, + InvalidActionError, + LinkAccountIntegrateError, + MemberNotInTenantError, + NoPermissionError, + 
RoleAlreadyAssignedError, + TenantNotFoundError, +) +from services.errors.workspace import WorkSpaceNotAllowedCreateError +from services.feature_service import FeatureService +from tasks.delete_account_task import delete_account_task +from tasks.mail_account_deletion_task import send_account_deletion_verification_code +from tasks.mail_email_code_login import send_email_code_login_mail_task +from tasks.mail_invite_member_task import send_invite_member_mail_task +from tasks.mail_reset_password_task import send_reset_password_mail_task + + +class TokenPair(BaseModel): + access_token: str + refresh_token: str + + +REFRESH_TOKEN_PREFIX = "refresh_token:" +ACCOUNT_REFRESH_TOKEN_PREFIX = "account_refresh_token:" +REFRESH_TOKEN_EXPIRY = timedelta(days=dify_config.REFRESH_TOKEN_EXPIRE_DAYS) + + +class AccountService: + reset_password_rate_limiter = RateLimiter(prefix="reset_password_rate_limit", max_attempts=1, time_window=60 * 1) + email_code_login_rate_limiter = RateLimiter( + prefix="email_code_login_rate_limit", max_attempts=1, time_window=60 * 1 + ) + email_code_account_deletion_rate_limiter = RateLimiter( + prefix="email_code_account_deletion_rate_limit", max_attempts=1, time_window=60 * 1 + ) + LOGIN_MAX_ERROR_LIMITS = 5 + FORGOT_PASSWORD_MAX_ERROR_LIMITS = 5 + + @staticmethod + def _get_refresh_token_key(refresh_token: str) -> str: + return f"{REFRESH_TOKEN_PREFIX}{refresh_token}" + + @staticmethod + def _get_account_refresh_token_key(account_id: str) -> str: + return f"{ACCOUNT_REFRESH_TOKEN_PREFIX}{account_id}" + + @staticmethod + def _store_refresh_token(refresh_token: str, account_id: str) -> None: + redis_client.setex(AccountService._get_refresh_token_key(refresh_token), REFRESH_TOKEN_EXPIRY, account_id) + redis_client.setex( + AccountService._get_account_refresh_token_key(account_id), REFRESH_TOKEN_EXPIRY, refresh_token + ) + + @staticmethod + def _delete_refresh_token(refresh_token: str, account_id: str) -> None: + 
redis_client.delete(AccountService._get_refresh_token_key(refresh_token)) + redis_client.delete(AccountService._get_account_refresh_token_key(account_id)) + + @staticmethod + def load_user(user_id: str) -> None | Account: + account = Account.query.filter_by(id=user_id).first() + if not account: + return None + + if account.status == AccountStatus.BANNED.value: + raise Unauthorized("Account is banned.") + + current_tenant = TenantAccountJoin.query.filter_by(account_id=account.id, current=True).first() + if current_tenant: + account.current_tenant_id = current_tenant.tenant_id + else: + available_ta = ( + TenantAccountJoin.query.filter_by(account_id=account.id).order_by(TenantAccountJoin.id.asc()).first() + ) + if not available_ta: + return None + + account.current_tenant_id = available_ta.tenant_id + available_ta.current = True + db.session.commit() + + if datetime.now(UTC).replace(tzinfo=None) - account.last_active_at > timedelta(minutes=10): + account.last_active_at = datetime.now(UTC).replace(tzinfo=None) + db.session.commit() + + return cast(Account, account) + + @staticmethod + def get_account_jwt_token(account: Account) -> str: + exp_dt = datetime.now(UTC) + timedelta(minutes=dify_config.ACCESS_TOKEN_EXPIRE_MINUTES) + exp = int(exp_dt.timestamp()) + payload = { + "user_id": account.id, + "exp": exp, + "iss": dify_config.EDITION, + "sub": "Console API Passport", + } + + token: str = PassportService().issue(payload) + return token + + @staticmethod + def authenticate(email: str, password: str, invite_token: Optional[str] = None) -> Account: + """authenticate account with email and password""" + + account = Account.query.filter_by(email=email).first() + if not account: + raise AccountNotFoundError() + + if account.status == AccountStatus.BANNED.value: + raise AccountLoginError("Account is banned.") + + if password and invite_token and account.password is None: + # if invite_token is valid, set password and password_salt + salt = secrets.token_bytes(16) + 
base64_salt = base64.b64encode(salt).decode() + password_hashed = hash_password(password, salt) + base64_password_hashed = base64.b64encode(password_hashed).decode() + account.password = base64_password_hashed + account.password_salt = base64_salt + + if account.password is None or not compare_password(password, account.password, account.password_salt): + raise AccountPasswordError("Invalid email or password.") + + if account.status == AccountStatus.PENDING.value: + account.status = AccountStatus.ACTIVE.value + account.initialized_at = datetime.now(UTC).replace(tzinfo=None) + + db.session.commit() + + return cast(Account, account) + + @staticmethod + def update_account_password(account, password, new_password): + """update account password""" + if account.password and not compare_password(password, account.password, account.password_salt): + raise CurrentPasswordIncorrectError("Current password is incorrect.") + + # may be raised + valid_password(new_password) + + # generate password salt + salt = secrets.token_bytes(16) + base64_salt = base64.b64encode(salt).decode() + + # encrypt password with salt + password_hashed = hash_password(new_password, salt) + base64_password_hashed = base64.b64encode(password_hashed).decode() + account.password = base64_password_hashed + account.password_salt = base64_salt + db.session.commit() + return account + + @staticmethod + def create_account( + email: str, + name: str, + interface_language: str, + password: Optional[str] = None, + interface_theme: str = "light", + is_setup: Optional[bool] = False, + ) -> Account: + """create account""" + if not FeatureService.get_system_features().is_allow_register and not is_setup: + from controllers.console.error import AccountNotFound + + raise AccountNotFound() + + if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(email): + raise AccountRegisterError( + description=( + "This email account has been deleted within the past " + "30 days and is temporarily unavailable for new 
account registration" + ) + ) + + account = Account() + account.email = email + account.name = name + + if password: + # generate password salt + salt = secrets.token_bytes(16) + base64_salt = base64.b64encode(salt).decode() + + # encrypt password with salt + password_hashed = hash_password(password, salt) + base64_password_hashed = base64.b64encode(password_hashed).decode() + + account.password = base64_password_hashed + account.password_salt = base64_salt + + account.interface_language = interface_language + account.interface_theme = interface_theme + + # Set timezone based on language + account.timezone = language_timezone_mapping.get(interface_language, "UTC") + + db.session.add(account) + db.session.commit() + return account + + @staticmethod + def create_account_and_tenant( + email: str, name: str, interface_language: str, password: Optional[str] = None + ) -> Account: + """create account""" + account = AccountService.create_account( + email=email, name=name, interface_language=interface_language, password=password + ) + + TenantService.create_owner_tenant_if_not_exist(account=account) + + return account + + @staticmethod + def generate_account_deletion_verification_code(account: Account) -> tuple[str, str]: + code = "".join([str(random.randint(0, 9)) for _ in range(6)]) + token = TokenManager.generate_token( + account=account, token_type="account_deletion", additional_data={"code": code} + ) + return token, code + + @classmethod + def send_account_deletion_verification_email(cls, account: Account, code: str): + email = account.email + if cls.email_code_account_deletion_rate_limiter.is_rate_limited(email): + from controllers.console.auth.error import EmailCodeAccountDeletionRateLimitExceededError + + raise EmailCodeAccountDeletionRateLimitExceededError() + + send_account_deletion_verification_code.delay(to=email, code=code) + + cls.email_code_account_deletion_rate_limiter.increment_rate_limit(email) + + @staticmethod + def verify_account_deletion_code(token: 
str, code: str) -> bool: + token_data = TokenManager.get_token_data(token, "account_deletion") + if token_data is None: + return False + + if token_data["code"] != code: + return False + + return True + + @staticmethod + def delete_account(account: Account) -> None: + """Delete account. This method only adds a task to the queue for deletion.""" + delete_account_task.delay(account.id) + + @staticmethod + def link_account_integrate(provider: str, open_id: str, account: Account) -> None: + """Link account integrate""" + try: + # Query whether there is an existing binding record for the same provider + account_integrate: Optional[AccountIntegrate] = AccountIntegrate.query.filter_by( + account_id=account.id, provider=provider + ).first() + + if account_integrate: + # If it exists, update the record + account_integrate.open_id = open_id + account_integrate.encrypted_token = "" # todo + account_integrate.updated_at = datetime.now(UTC).replace(tzinfo=None) + else: + # If it does not exist, create a new record + account_integrate = AccountIntegrate( + account_id=account.id, provider=provider, open_id=open_id, encrypted_token="" + ) + db.session.add(account_integrate) + + db.session.commit() + logging.info(f"Account {account.id} linked {provider} account {open_id}.") + except Exception as e: + logging.exception(f"Failed to link {provider} account {open_id} to Account {account.id}") + raise LinkAccountIntegrateError("Failed to link account.") from e + + @staticmethod + def close_account(account: Account) -> None: + """Close account""" + account.status = AccountStatus.CLOSED.value + db.session.commit() + + @staticmethod + def update_account(account, **kwargs): + """Update account fields""" + for field, value in kwargs.items(): + if hasattr(account, field): + setattr(account, field, value) + else: + raise AttributeError(f"Invalid field: {field}") + + db.session.commit() + return account + + @staticmethod + def update_login_info(account: Account, *, ip_address: str) -> None: + 
"""Update last login time and ip""" + account.last_login_at = datetime.now(UTC).replace(tzinfo=None) + account.last_login_ip = ip_address + db.session.add(account) + db.session.commit() + + @staticmethod + def login(account: Account, *, ip_address: Optional[str] = None) -> TokenPair: + if ip_address: + AccountService.update_login_info(account=account, ip_address=ip_address) + + if account.status == AccountStatus.PENDING.value: + account.status = AccountStatus.ACTIVE.value + db.session.commit() + + access_token = AccountService.get_account_jwt_token(account=account) + refresh_token = _generate_refresh_token() + + AccountService._store_refresh_token(refresh_token, account.id) + + return TokenPair(access_token=access_token, refresh_token=refresh_token) + + @staticmethod + def logout(*, account: Account) -> None: + refresh_token = redis_client.get(AccountService._get_account_refresh_token_key(account.id)) + if refresh_token: + AccountService._delete_refresh_token(refresh_token.decode("utf-8"), account.id) + + @staticmethod + def refresh_token(refresh_token: str) -> TokenPair: + # Verify the refresh token + account_id = redis_client.get(AccountService._get_refresh_token_key(refresh_token)) + if not account_id: + raise ValueError("Invalid refresh token") + + account = AccountService.load_user(account_id.decode("utf-8")) + if not account: + raise ValueError("Invalid account") + + # Generate new access token and refresh token + new_access_token = AccountService.get_account_jwt_token(account) + new_refresh_token = _generate_refresh_token() + + AccountService._delete_refresh_token(refresh_token, account.id) + AccountService._store_refresh_token(new_refresh_token, account.id) + + return TokenPair(access_token=new_access_token, refresh_token=new_refresh_token) + + @staticmethod + def load_logged_in_account(*, account_id: str): + return AccountService.load_user(account_id) + + @classmethod + def send_reset_password_email( + cls, + account: Optional[Account] = None, + email: 
Optional[str] = None, + language: Optional[str] = "en-US", + ): + account_email = account.email if account else email + if account_email is None: + raise ValueError("Email must be provided.") + + if cls.reset_password_rate_limiter.is_rate_limited(account_email): + from controllers.console.auth.error import PasswordResetRateLimitExceededError + + raise PasswordResetRateLimitExceededError() + + code = "".join([str(random.randint(0, 9)) for _ in range(6)]) + token = TokenManager.generate_token( + account=account, email=email, token_type="reset_password", additional_data={"code": code} + ) + send_reset_password_mail_task.delay( + language=language, + to=account_email, + code=code, + ) + cls.reset_password_rate_limiter.increment_rate_limit(account_email) + return token + + @classmethod + def revoke_reset_password_token(cls, token: str): + TokenManager.revoke_token(token, "reset_password") + + @classmethod + def get_reset_password_data(cls, token: str) -> Optional[dict[str, Any]]: + return TokenManager.get_token_data(token, "reset_password") + + @classmethod + def send_email_code_login_email( + cls, account: Optional[Account] = None, email: Optional[str] = None, language: Optional[str] = "en-US" + ): + email = account.email if account else email + if email is None: + raise ValueError("Email must be provided.") + if cls.email_code_login_rate_limiter.is_rate_limited(email): + from controllers.console.auth.error import EmailCodeLoginRateLimitExceededError + + raise EmailCodeLoginRateLimitExceededError() + + code = "".join([str(random.randint(0, 9)) for _ in range(6)]) + token = TokenManager.generate_token( + account=account, email=email, token_type="email_code_login", additional_data={"code": code} + ) + send_email_code_login_mail_task.delay( + language=language, + to=account.email if account else email, + code=code, + ) + cls.email_code_login_rate_limiter.increment_rate_limit(email) + return token + + @classmethod + def get_email_code_login_data(cls, token: str) -> 
Optional[dict[str, Any]]: + return TokenManager.get_token_data(token, "email_code_login") + + @classmethod + def revoke_email_code_login_token(cls, token: str): + TokenManager.revoke_token(token, "email_code_login") + + @classmethod + def get_user_through_email(cls, email: str): + if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(email): + raise AccountRegisterError( + description=( + "This email account has been deleted within the past " + "30 days and is temporarily unavailable for new account registration" + ) + ) + + account = db.session.query(Account).filter(Account.email == email).first() + if not account: + return None + + if account.status == AccountStatus.BANNED.value: + raise Unauthorized("Account is banned.") + + return account + + @staticmethod + def add_login_error_rate_limit(email: str) -> None: + key = f"login_error_rate_limit:{email}" + count = redis_client.get(key) + if count is None: + count = 0 + count = int(count) + 1 + redis_client.setex(key, dify_config.LOGIN_LOCKOUT_DURATION, count) + + @staticmethod + def is_login_error_rate_limit(email: str) -> bool: + key = f"login_error_rate_limit:{email}" + count = redis_client.get(key) + if count is None: + return False + + count = int(count) + if count > AccountService.LOGIN_MAX_ERROR_LIMITS: + return True + return False + + @staticmethod + def reset_login_error_rate_limit(email: str): + key = f"login_error_rate_limit:{email}" + redis_client.delete(key) + + @staticmethod + def add_forgot_password_error_rate_limit(email: str) -> None: + key = f"forgot_password_error_rate_limit:{email}" + count = redis_client.get(key) + if count is None: + count = 0 + count = int(count) + 1 + redis_client.setex(key, dify_config.FORGOT_PASSWORD_LOCKOUT_DURATION, count) + + @staticmethod + def is_forgot_password_error_rate_limit(email: str) -> bool: + key = f"forgot_password_error_rate_limit:{email}" + count = redis_client.get(key) + if count is None: + return False + + count = int(count) + if count > 
class TenantService:
    """Workspace (tenant) creation, membership and role management."""

    @staticmethod
    def create_tenant(name: str, is_setup: Optional[bool] = False, is_from_dashboard: Optional[bool] = False) -> Tenant:
        """Create a tenant, persist it and attach a freshly generated public key.

        :param name: display name of the new workspace
        :param is_setup: bypass the "workspace creation allowed" feature check
        :param is_from_dashboard: bypass the same check for dashboard-initiated calls
        :raises NotAllowedCreateWorkspace: creation is disabled and neither bypass flag is set
        """
        if (
            not FeatureService.get_system_features().is_allow_create_workspace
            and not is_setup
            and not is_from_dashboard
        ):
            from controllers.console.error import NotAllowedCreateWorkspace

            raise NotAllowedCreateWorkspace()
        tenant = Tenant(name=name)

        db.session.add(tenant)
        db.session.commit()

        # generate_key_pair is called with tenant.id, which is read after the first commit.
        tenant.encrypt_public_key = generate_key_pair(tenant.id)
        db.session.commit()
        return tenant

    @staticmethod
    def create_owner_tenant_if_not_exist(
        account: Account, name: Optional[str] = None, is_setup: Optional[bool] = False
    ):
        """Create a workspace owned by ``account`` unless it already belongs to one.

        Does nothing when any TenantAccountJoin row exists for the account.
        Otherwise creates the tenant (named ``name`` or "<account.name>'s Workspace"),
        makes the account its owner and current tenant, and sends ``tenant_was_created``.

        :raises WorkSpaceNotAllowedCreateError: creation is disabled and ``is_setup`` is falsy
        """
        available_ta = (
            TenantAccountJoin.query.filter_by(account_id=account.id).order_by(TenantAccountJoin.id.asc()).first()
        )

        if available_ta:
            return

        # Create the owner tenant because none exists yet.
        if not FeatureService.get_system_features().is_allow_create_workspace and not is_setup:
            raise WorkSpaceNotAllowedCreateError()

        if name:
            tenant = TenantService.create_tenant(name=name, is_setup=is_setup)
        else:
            tenant = TenantService.create_tenant(name=f"{account.name}'s Workspace", is_setup=is_setup)
        TenantService.create_tenant_member(tenant, account, role="owner")
        account.current_tenant = tenant
        db.session.commit()
        tenant_was_created.send(tenant)

    @staticmethod
    def create_tenant_member(tenant: Tenant, account: Account, role: str = "normal") -> TenantAccountJoin:
        """Add ``account`` to ``tenant`` with ``role``, or update the role of an existing join row.

        :raises Exception: ``role`` is the owner role and the tenant already has an owner
        """
        if role == TenantAccountJoinRole.OWNER.value:
            if TenantService.has_roles(tenant, [TenantAccountJoinRole.OWNER]):
                logging.error(f"Tenant {tenant.id} has already an owner.")
                raise Exception("Tenant already has an owner.")

        ta = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first()
        if ta:
            ta.role = role
        else:
            ta = TenantAccountJoin(tenant_id=tenant.id, account_id=account.id, role=role)
            db.session.add(ta)

        db.session.commit()
        return ta

    @staticmethod
    def get_join_tenants(account: Account) -> list[Tenant]:
        """Return every tenant with NORMAL status that ``account`` has joined."""
        return (
            db.session.query(Tenant)
            .join(TenantAccountJoin, Tenant.id == TenantAccountJoin.tenant_id)
            .filter(TenantAccountJoin.account_id == account.id, Tenant.status == TenantStatus.NORMAL)
            .all()
        )

    @staticmethod
    def get_current_tenant_by_account(account: Account):
        """Return the account's current tenant with ``tenant.role`` set to the account's role.

        :raises TenantNotFoundError: no current tenant, or the account has no join row for it
        """
        tenant = account.current_tenant
        if not tenant:
            raise TenantNotFoundError("Tenant not found.")

        ta = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=account.id).first()
        if ta:
            tenant.role = ta.role
        else:
            raise TenantNotFoundError("Tenant not found for the account.")
        return tenant

    @staticmethod
    def switch_tenant(account: Account, tenant_id: Optional[str] = None) -> None:
        """Make ``tenant_id`` the current workspace of ``account``.

        :raises ValueError: ``tenant_id`` is None
        :raises AccountNotLinkTenantError: the account has no join row for a NORMAL tenant with that id
        """

        # Ensure tenant_id is provided
        if tenant_id is None:
            raise ValueError("Tenant ID must be provided.")

        tenant_account_join = (
            db.session.query(TenantAccountJoin)
            .join(Tenant, TenantAccountJoin.tenant_id == Tenant.id)
            .filter(
                TenantAccountJoin.account_id == account.id,
                TenantAccountJoin.tenant_id == tenant_id,
                Tenant.status == TenantStatus.NORMAL,
            )
            .first()
        )

        if not tenant_account_join:
            raise AccountNotLinkTenantError("Tenant not found or account is not a member of the tenant.")
        else:
            # Clear the "current" flag on every other membership of this account.
            TenantAccountJoin.query.filter(
                TenantAccountJoin.account_id == account.id, TenantAccountJoin.tenant_id != tenant_id
            ).update({"current": False})
            tenant_account_join.current = True
            # Set the current tenant for the account
            account.current_tenant_id = tenant_account_join.tenant_id
            db.session.commit()

    @staticmethod
    def get_tenant_members(tenant: Tenant) -> list[Account]:
        """Return all accounts joined to ``tenant``, each with ``account.role`` set from its join row."""
        query = (
            db.session.query(Account, TenantAccountJoin.role)
            .select_from(Account)
            .join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id)
            .filter(TenantAccountJoin.tenant_id == tenant.id)
        )

        updated_accounts = []

        for account, role in query:
            account.role = role
            updated_accounts.append(account)

        return updated_accounts

    @staticmethod
    def get_dataset_operator_members(tenant: Tenant) -> list[Account]:
        """Return the accounts whose role in ``tenant`` is "dataset_operator", with ``account.role`` set."""
        query = (
            db.session.query(Account, TenantAccountJoin.role)
            .select_from(Account)
            .join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id)
            .filter(TenantAccountJoin.tenant_id == tenant.id)
            .filter(TenantAccountJoin.role == "dataset_operator")
        )

        updated_accounts = []

        for account, role in query:
            account.role = role
            updated_accounts.append(account)

        return updated_accounts

    @staticmethod
    def has_roles(tenant: Tenant, roles: list[TenantAccountJoinRole]) -> bool:
        """Return True when at least one member of ``tenant`` holds any of ``roles``.

        :raises ValueError: an element of ``roles`` is not a TenantAccountJoinRole
        """
        if not all(isinstance(role, TenantAccountJoinRole) for role in roles):
            raise ValueError("all roles must be TenantAccountJoinRole")

        return (
            db.session.query(TenantAccountJoin)
            .filter(
                TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.role.in_([role.value for role in roles])
            )
            .first()
            is not None
        )

    @staticmethod
    def get_user_role(account: Account, tenant: Tenant) -> Optional[TenantAccountJoinRole]:
        """Return the ``role`` stored on the account's join row for ``tenant``, or None if not a member."""
        join = (
            db.session.query(TenantAccountJoin)
            .filter(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == account.id)
            .first()
        )
        return join.role if join else None

    @staticmethod
    def get_tenant_count() -> int:
        """Return the number of Tenant rows."""
        return cast(int, db.session.query(func.count(Tenant.id)).scalar())

    @staticmethod
    def check_member_permission(tenant: Tenant, operator: Account, member: Account | None, action: str) -> None:
        """Verify that ``operator`` may perform ``action`` ("add", "remove" or "update") in ``tenant``.

        Returns None on success; every failure is reported by raising.

        :raises InvalidActionError: ``action`` is not one of the supported actions
        :raises CannotOperateSelfError: ``member`` is given and is the operator
        :raises NoPermissionError: the operator is not a member or lacks a role allowed for ``action``
        """
        perms = {
            "add": [TenantAccountRole.OWNER, TenantAccountRole.ADMIN],
            "remove": [TenantAccountRole.OWNER],
            "update": [TenantAccountRole.OWNER],
        }
        if action not in perms:
            raise InvalidActionError("Invalid action.")

        if member:
            if operator.id == member.id:
                raise CannotOperateSelfError("Cannot operate self.")

        ta_operator = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=operator.id).first()

        if not ta_operator or ta_operator.role not in perms[action]:
            raise NoPermissionError(f"No permission to {action} member.")

    @staticmethod
    def remove_member_from_tenant(tenant: Tenant, account: Account, operator: Account) -> None:
        """Remove ``account`` from ``tenant`` on behalf of ``operator``.

        :raises CannotOperateSelfError: the operator tries to remove themselves
        :raises NoPermissionError: the operator is not allowed to remove members
        :raises MemberNotInTenantError: ``account`` has no join row for ``tenant``
        """
        if operator.id == account.id:
            raise CannotOperateSelfError("Cannot operate self.")

        # The check must run unconditionally: it signals failure only by raising and
        # previously sat behind an `and`, so it was skipped for every other member.
        TenantService.check_member_permission(tenant, operator, account, "remove")

        ta = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=account.id).first()
        if not ta:
            raise MemberNotInTenantError("Member not in tenant.")

        db.session.delete(ta)
        db.session.commit()

    @staticmethod
    def update_member_role(tenant: Tenant, member: Account, new_role: str, operator: Account) -> None:
        """Change the role of ``member`` in ``tenant``; promoting to "owner" demotes the current owner to "admin".

        :raises MemberNotInTenantError: ``member`` has no join row for ``tenant``
        :raises RoleAlreadyAssignedError: ``member`` already has ``new_role``
        """
        TenantService.check_member_permission(tenant, operator, member, "update")

        target_member_join = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=member.id).first()
        if not target_member_join:
            # Without this guard a non-member target surfaced as an AttributeError.
            raise MemberNotInTenantError("Member not in tenant.")

        if target_member_join.role == new_role:
            raise RoleAlreadyAssignedError("The provided role is already assigned to the member.")

        if new_role == "owner":
            # Find the current owner and change their role to 'admin'
            current_owner_join = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, role="owner").first()
            if current_owner_join:
                current_owner_join.role = "admin"

        # Update the role of the target member
        target_member_join.role = new_role
        db.session.commit()

    @staticmethod
    def dissolve_tenant(tenant: Tenant, operator: Account) -> None:
        """Delete ``tenant`` and all of its membership rows.

        :raises NoPermissionError: ``operator`` does not hold a role allowed to "remove"
        """
        # Pass no member: handing the operator in as the member always tripped the
        # self-operation guard, and the check returns None, so its result cannot be tested.
        try:
            TenantService.check_member_permission(tenant, operator, None, "remove")
        except NoPermissionError as exc:
            raise NoPermissionError("No permission to dissolve tenant.") from exc
        db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id).delete()
        db.session.delete(tenant)
        db.session.commit()

    @staticmethod
    def get_custom_config(tenant_id: str) -> dict:
        """Return ``custom_config_dict`` of the tenant with ``tenant_id`` (``one_or_404`` lookup)."""
        tenant = Tenant.query.filter(Tenant.id == tenant_id).one_or_404()

        return cast(dict, tenant.custom_config_dict)
@classmethod
def invite_new_member(
    cls, tenant: Tenant, email: str, language: str, role: str = "normal", inviter: Optional[Account] = None
) -> str:
    """Invite ``email`` into ``tenant`` and queue the invitation mail.

    An unknown e-mail gets a PENDING account registered, added to the tenant and
    switched to it. A known account is added to the tenant if it is not yet a
    member; the invitation is (re)sent only while that account is still PENDING.

    :return: the generated invitation token
    :raises AccountAlreadyInTenantError: the existing account is not in PENDING status
    """
    account = Account.query.filter_by(email=email).first()
    assert inviter is not None, "Inviter must be provided."

    if account:
        TenantService.check_member_permission(tenant, inviter, account, "add")
        membership = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=account.id).first()
        if not membership:
            TenantService.create_tenant_member(tenant, account, role)

        # Only pending accounts may receive the invitation mail again.
        if account.status != AccountStatus.PENDING.value:
            raise AccountAlreadyInTenantError("Account already in tenant.")
    else:
        TenantService.check_member_permission(tenant, inviter, None, "add")
        account = cls.register(
            email=email,
            name=email.split("@")[0],  # local part of the address doubles as the display name
            language=language,
            status=AccountStatus.PENDING,
            is_setup=True,
        )
        # Attach the new account to the inviting tenant and make it current.
        TenantService.create_tenant_member(tenant, account, role)
        TenantService.switch_tenant(account, tenant.id)

    invite_token = cls.generate_invite_token(tenant, account)

    send_invite_member_mail_task.delay(
        language=account.interface_language,
        to=email,
        token=invite_token,
        inviter_name=inviter.name if inviter else "Dify",
        workspace_name=tenant.name,
    )

    return invite_token
def _generate_refresh_token(length: int = 64):
    """Return a random hex string built from ``length`` random bytes (2 * length characters)."""
    return secrets.token_hex(length)
class AdvancedPromptTemplateService:
    """Select the advanced prompt template matching an app mode / model mode pair."""

    @classmethod
    def get_prompt(cls, args: dict) -> dict:
        """Return the template for ``args``; model names containing "baichuan" use the Baichuan set."""
        app_mode = args["app_mode"]
        model_mode = args["model_mode"]
        model_name = args["model_name"]
        has_context = args["has_context"]

        picker = cls.get_baichuan_prompt if "baichuan" in model_name.lower() else cls.get_common_prompt
        return picker(app_mode, model_mode, has_context)

    @classmethod
    def _build_for_model_mode(
        cls, model_mode: str, completion_template: dict, chat_template: dict, has_context: str, context: str
    ) -> dict:
        """Deep-copy and fill the template for ``model_mode``; unknown modes yield an empty dict."""
        if model_mode == "completion":
            return cls.get_completion_prompt(copy.deepcopy(completion_template), has_context, context)
        if model_mode == "chat":
            return cls.get_chat_prompt(copy.deepcopy(chat_template), has_context, context)
        return {}

    @classmethod
    def get_common_prompt(cls, app_mode: str, model_mode: str, has_context: str) -> dict:
        """Return the generic template for the given modes, or ``{}`` when no combination matches."""
        context_prompt = copy.deepcopy(CONTEXT)

        if app_mode == AppMode.CHAT.value:
            return cls._build_for_model_mode(
                model_mode,
                CHAT_APP_COMPLETION_PROMPT_CONFIG,
                CHAT_APP_CHAT_PROMPT_CONFIG,
                has_context,
                context_prompt,
            )
        if app_mode == AppMode.COMPLETION.value:
            return cls._build_for_model_mode(
                model_mode,
                COMPLETION_APP_COMPLETION_PROMPT_CONFIG,
                COMPLETION_APP_CHAT_PROMPT_CONFIG,
                has_context,
                context_prompt,
            )
        # default return empty dict
        return {}

    @classmethod
    def get_completion_prompt(cls, prompt_template: dict, has_context: str, context: str) -> dict:
        """Prepend ``context`` to the completion prompt text when ``has_context`` is "true".

        ``prompt_template`` is modified in place and returned.
        """
        if has_context != "true":
            return prompt_template

        prompt = prompt_template["completion_prompt_config"]["prompt"]
        prompt["text"] = context + prompt["text"]
        return prompt_template

    @classmethod
    def get_chat_prompt(cls, prompt_template: dict, has_context: str, context: str) -> dict:
        """Prepend ``context`` to the first chat prompt message when ``has_context`` is "true".

        ``prompt_template`` is modified in place and returned.
        """
        if has_context != "true":
            return prompt_template

        first_message = prompt_template["chat_prompt_config"]["prompt"][0]
        first_message["text"] = context + first_message["text"]
        return prompt_template

    @classmethod
    def get_baichuan_prompt(cls, app_mode: str, model_mode: str, has_context: str) -> dict:
        """Return the Baichuan-specific template for the given modes, or ``{}`` when none matches."""
        baichuan_context_prompt = copy.deepcopy(BAICHUAN_CONTEXT)

        if app_mode == AppMode.CHAT.value:
            return cls._build_for_model_mode(
                model_mode,
                BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG,
                BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG,
                has_context,
                baichuan_context_prompt,
            )
        if app_mode == AppMode.COMPLETION.value:
            return cls._build_for_model_mode(
                model_mode,
                BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG,
                BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG,
                has_context,
                baichuan_context_prompt,
            )
        # default return empty dict
        return {}
class AgentService:
    @classmethod
    def get_agent_logs(cls, app_model: App, conversation_id: str, message_id: str) -> dict:
        """
        Build the agent execution log of one message.

        The result holds a "meta" summary, one "iterations" entry per agent
        thought of the message (with its tool calls) and the message "files".

        :raises ValueError: the conversation is not found for this app, or the
            message is not found in that conversation
        """
        conversation: Optional[Conversation] = (
            db.session.query(Conversation)
            .filter(
                Conversation.id == conversation_id,
                Conversation.app_id == app_model.id,
            )
            .first()
        )
        if not conversation:
            raise ValueError(f"Conversation not found: {conversation_id}")

        message: Optional[Message] = (
            db.session.query(Message)
            .filter(
                Message.id == message_id,
                Message.conversation_id == conversation_id,
            )
            .first()
        )
        if not message:
            raise ValueError(f"Message not found: {message_id}")

        agent_thoughts: list[MessageAgentThought] = message.agent_thoughts

        # The executor is the end user if the conversation has one, else the account.
        if conversation.from_end_user_id:
            executor_row = (
                db.session.query(EndUser, EndUser.name).filter(EndUser.id == conversation.from_end_user_id).first()
            )
        else:
            executor_row = (
                db.session.query(Account, Account.name).filter(Account.id == conversation.from_account_id).first()
            )
        executor = executor_row.name if executor_row else "Unknown"

        timezone = pytz.timezone(current_user.timezone)

        result = {
            "meta": {
                "status": "success",
                "executor": executor,
                "start_time": message.created_at.astimezone(timezone).isoformat(),
                "elapsed_time": message.provider_response_latency,
                "total_tokens": message.answer_tokens + message.message_tokens,
                "agent_mode": app_model.app_model_config.agent_mode_dict.get("strategy", "react"),
                "iterations": len(agent_thoughts),
            },
            "iterations": [],
            "files": message.message_files,
        }

        agent_config = AgentConfigManager.convert(app_model.app_model_config.to_dict())
        if not agent_config:
            return result

        configured_tools = agent_config.tools or []

        def resolve_tool_icon(tool_name: str, tool_config: dict):
            """Look the icon up via the recorded provider, then via the configured agent tool."""
            if tool_config.get("tool_provider_type", "") == "dataset-retrieval":
                return ""

            icon = ToolManager.get_tool_icon(
                tenant_id=app_model.tenant_id,
                provider_type=tool_config.get("tool_provider_type", ""),
                provider_id=tool_config.get("tool_provider", ""),
            )
            if icon:
                return icon

            tool_entity = next((t for t in configured_tools if t.tool_name == tool_name), None)
            if tool_entity:
                return ToolManager.get_tool_icon(
                    tenant_id=app_model.tenant_id,
                    provider_type=tool_entity.provider_type,
                    provider_id=tool_entity.provider_id,
                )
            return icon

        for thought in agent_thoughts:
            tool_names = thought.tools
            labels = thought.tool_labels
            meta_by_tool = thought.tool_meta
            inputs_by_tool = thought.tool_inputs_dict
            outputs_by_tool = thought.tool_outputs_dict

            tool_calls = []
            for tool_name in tool_names:
                tool_label = labels.get(tool_name, tool_name)
                tool_input = inputs_by_tool.get(tool_name, {})
                tool_output = outputs_by_tool.get(tool_name, {})
                tool_meta_data = meta_by_tool.get(tool_name, {})
                tool_icon = resolve_tool_icon(tool_name, tool_meta_data.get("tool_config", {}))

                tool_calls.append(
                    {
                        "status": "error" if tool_meta_data.get("error") else "success",
                        "error": tool_meta_data.get("error"),
                        "time_cost": tool_meta_data.get("time_cost", 0),
                        "tool_name": tool_name,
                        "tool_label": tool_label,
                        "tool_input": tool_input,
                        "tool_output": tool_output,
                        "tool_parameters": tool_meta_data.get("tool_parameters", {}),
                        "tool_icon": tool_icon,
                    }
                )

            result["iterations"].append(
                {
                    "tokens": thought.tokens,
                    "tool_calls": tool_calls,
                    "tool_raw": {
                        "inputs": thought.tool_input,
                        "outputs": thought.observation,
                    },
                    "thought": thought.thought,
                    "created_at": thought.created_at.isoformat(),
                    "files": thought.files,
                }
            )

        return result
@classmethod
def up_insert_app_annotation_from_message(cls, args: dict, app_id: str) -> MessageAnnotation:
    """Create or update an annotation, optionally bound to the message in ``args["message_id"]``.

    With a message id, the message's existing annotation is updated or a new one
    is created for that message; without one, a standalone annotation is created.
    When the app has an annotation setting, the annotation is also queued for indexing.

    :raises NotFound: the app is not a normal app of the current tenant, or the
        message does not belong to it
    """
    app = (
        db.session.query(App)
        .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
        .first()
    )
    if not app:
        raise NotFound("App not found")

    raw_message_id = args.get("message_id")
    if not raw_message_id:
        annotation = MessageAnnotation(
            app_id=app.id, content=args["answer"], question=args["question"], account_id=current_user.id
        )
    else:
        message_id = str(raw_message_id)
        message = db.session.query(Message).filter(Message.id == message_id, Message.app_id == app.id).first()
        if not message:
            raise NotFound("Message Not Exists.")

        annotation = message.annotation
        if not annotation:
            annotation = MessageAnnotation(
                app_id=app.id,
                conversation_id=message.conversation_id,
                message_id=message.id,
                content=args["answer"],
                question=args["question"],
                account_id=current_user.id,
            )
        else:
            # Overwrite the annotation already attached to this message.
            annotation.content = args["answer"]
            annotation.question = args["question"]

    db.session.add(annotation)
    db.session.commit()

    # A setting row exists only when annotation reply is enabled for the app.
    setting = db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first()
    if setting:
        add_annotation_to_index_task.delay(
            annotation.id,
            args["question"],
            current_user.current_tenant_id,
            app_id,
            setting.collection_binding_id,
        )
    return cast(MessageAnnotation, annotation)
@classmethod
def get_annotation_list_by_app_id(cls, app_id: str, page: int, limit: int, keyword: str):
    """Return one page of the app's annotations, newest first, as ``(items, total)``.

    A non-empty ``keyword`` restricts the result to annotations whose question or
    content contains it (case-insensitive match).

    :raises NotFound: the app is not a normal app of the current tenant
    """
    app = (
        db.session.query(App)
        .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
        .first()
    )
    if not app:
        raise NotFound("App not found")

    query = MessageAnnotation.query.filter(MessageAnnotation.app_id == app_id)
    if keyword:
        pattern = "%{}%".format(keyword)
        query = query.filter(
            or_(
                MessageAnnotation.question.ilike(pattern),
                MessageAnnotation.content.ilike(pattern),
            )
        )

    page_result = query.order_by(MessageAnnotation.created_at.desc(), MessageAnnotation.id.desc()).paginate(
        page=page, per_page=limit, max_per_page=100, error_out=False
    )
    return page_result.items, page_result.total
@classmethod
def update_app_annotation_directly(cls, args: dict, app_id: str, annotation_id: str):
    """Update the question and answer of one annotation of the given app.

    When the app has an annotation setting, the updated annotation is also queued
    for re-indexing.

    :param args: must contain "answer" and "question"
    :return: the updated annotation
    :raises NotFound: the app is not a normal app of the current tenant, or the
        annotation does not exist for that app
    """
    app = (
        db.session.query(App)
        .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
        .first()
    )

    if not app:
        raise NotFound("App not found")

    # Scope the lookup to the verified app: matching on the id alone let a caller
    # edit annotations of apps that never passed the ownership check above.
    annotation = (
        db.session.query(MessageAnnotation)
        .filter(MessageAnnotation.id == annotation_id, MessageAnnotation.app_id == app.id)
        .first()
    )

    if not annotation:
        raise NotFound("Annotation not found")

    annotation.content = args["answer"]
    annotation.question = args["question"]

    db.session.commit()

    # A setting row exists only when annotation reply is enabled for the app.
    app_annotation_setting = (
        db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first()
    )

    if app_annotation_setting:
        update_annotation_to_index_task.delay(
            annotation.id,
            annotation.question,
            current_user.current_tenant_id,
            app_id,
            app_annotation_setting.collection_binding_id,
        )

    return annotation
@classmethod
def batch_import_app_annotations(cls, app_id, file: FileStorage) -> dict:
    """Queue a background import of question/answer pairs read from a CSV upload.

    The first CSV column is used as the question and the second as the answer.
    Any failure while reading, validating or queueing is reported as
    ``{"error_msg": ...}`` instead of being raised.

    :return: ``{"job_id": ..., "job_status": "waiting"}`` on success
    :raises NotFound: the app is not a normal app of the current tenant
    """
    app = (
        db.session.query(App)
        .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
        .first()
    )
    if not app:
        raise NotFound("App not found")

    try:
        # read_csv is called with defaults, so the first row is taken as the header, not as data.
        frame = pd.read_csv(file)
        records = [{"question": row.iloc[0], "answer": row.iloc[1]} for _, row in frame.iterrows()]
        if not records:
            raise ValueError("The CSV file is empty.")

        # Enforce the subscription quota only when billing is enabled.
        features = FeatureService.get_features(current_user.current_tenant_id)
        if features.billing.enabled:
            quota = features.annotation_quota_limit
            if quota.limit < len(records) + quota.size:
                raise ValueError("The number of annotations exceeds the limit of your subscription.")

        job_id = str(uuid.uuid4())
        status_key = "app_annotation_batch_import_{}".format(str(job_id))
        # Record the job as waiting before handing it to the task queue.
        redis_client.setnx(status_key, "waiting")
        batch_import_annotations_task.delay(
            str(job_id), records, app_id, current_user.current_tenant_id, current_user.id
        )
    except Exception as e:
        return {"error_msg": str(e)}
    else:
        return {"job_id": job_id, "job_status": "waiting"}
@classmethod
def get_app_annotation_setting_by_app_id(cls, app_id: str):
    """
    Return the annotation-reply setting of an app.

    :param app_id: app id
    :return: ``{"enabled": False}`` when the app has no annotation setting,
        otherwise a dict with the setting id, score threshold and embedding model
    :raises NotFound: if no app with status "normal" matches in the current tenant
    """
    # the app must belong to the current tenant
    app_query = db.session.query(App).filter(
        App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal"
    )
    if not app_query.first():
        raise NotFound("App not found")

    setting = db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first()
    if not setting:
        return {"enabled": False}

    binding = setting.collection_binding_detail
    embedding_model = {
        "embedding_provider_name": binding.provider_name,
        "embedding_model_name": binding.model_name,
    }
    return {
        "id": setting.id,
        "enabled": True,
        "score_threshold": setting.score_threshold,
        "embedding_model": embedding_model,
    }
class APIBasedExtensionService:
    """CRUD helpers for tenant-scoped API based extensions."""

    @staticmethod
    def get_all_by_tenant_id(tenant_id: str) -> list[APIBasedExtension]:
        """
        List the extensions of a tenant, newest first, with decrypted API keys.

        :param tenant_id: tenant id
        :return: extensions whose ``api_key`` attribute holds the decrypted key
        """
        extension_list = (
            db.session.query(APIBasedExtension)
            .filter_by(tenant_id=tenant_id)
            .order_by(APIBasedExtension.created_at.desc())
            .all()
        )

        # NOTE(review): the decrypted key is assigned onto the queried objects;
        # confirm that no later commit on this session can persist it.
        for extension in extension_list:
            extension.api_key = decrypt_token(extension.tenant_id, extension.api_key)

        return extension_list

    @classmethod
    def save(cls, extension_data: APIBasedExtension) -> APIBasedExtension:
        """
        Validate an extension, encrypt its API key and persist it.

        :param extension_data: extension carrying a plain-text ``api_key``
        :return: the saved extension (``api_key`` is now the encrypted value)
        :raises ValueError: if validation or the endpoint ping fails
        """
        cls._validation(extension_data)

        extension_data.api_key = encrypt_token(extension_data.tenant_id, extension_data.api_key)

        db.session.add(extension_data)
        db.session.commit()
        return extension_data

    @staticmethod
    def delete(extension_data: APIBasedExtension) -> None:
        """Delete an extension and commit."""
        db.session.delete(extension_data)
        db.session.commit()

    @staticmethod
    def get_with_tenant_id(tenant_id: str, api_based_extension_id: str) -> APIBasedExtension:
        """
        Fetch one extension of a tenant with its API key decrypted.

        :param tenant_id: tenant id
        :param api_based_extension_id: extension id
        :return: the extension whose ``api_key`` attribute holds the decrypted key
        :raises ValueError: if the extension does not exist for this tenant
        """
        extension = (
            db.session.query(APIBasedExtension)
            .filter_by(tenant_id=tenant_id)
            .filter_by(id=api_based_extension_id)
            .first()
        )

        if not extension:
            raise ValueError("API based extension is not found")

        extension.api_key = decrypt_token(extension.tenant_id, extension.api_key)

        return extension

    @classmethod
    def _validation(cls, extension_data: APIBasedExtension) -> None:
        """
        Validate name, endpoint and key of an extension, then ping the endpoint.

        :raises ValueError: on the first failed check
        """
        # name
        if not extension_data.name:
            raise ValueError("name must not be empty")

        # The name must be unique within the tenant; an existing record is
        # allowed to keep its own name, so it is excluded from the lookup.
        same_name_query = (
            db.session.query(APIBasedExtension)
            .filter_by(tenant_id=extension_data.tenant_id)
            .filter_by(name=extension_data.name)
        )
        if extension_data.id:
            same_name_query = same_name_query.filter(APIBasedExtension.id != extension_data.id)

        if same_name_query.first():
            raise ValueError("name must be unique, it is already existed")

        # api_endpoint
        if not extension_data.api_endpoint:
            raise ValueError("api_endpoint must not be empty")

        # api_key
        if not extension_data.api_key:
            raise ValueError("api_key must not be empty")

        if len(extension_data.api_key) < 5:
            raise ValueError("api_key must be at least 5 characters")

        # check endpoint
        cls._ping_connection(extension_data)

    @staticmethod
    def _ping_connection(extension_data: APIBasedExtension) -> None:
        """
        Send a PING to the extension endpoint and require a "pong" result.

        :raises ValueError: if the request fails or the reply is not "pong"
        """
        try:
            client = APIBasedExtensionRequestor(extension_data.api_endpoint, extension_data.api_key)
            resp = client.request(point=APIBasedExtensionPoint.PING, params={})
            if resp.get("result") != "pong":
                raise ValueError(resp)
        except Exception as e:
            # Keep the original failure attached as the explicit cause.
            raise ValueError(f"connection error: {e}") from e
def _check_version_compatibility(imported_version: str) -> ImportStatus:
    """
    Map the imported DSL version to an import status.

    An unparsable version fails, a different major or minor version needs
    confirmation, and a different micro version completes with warnings.
    """
    try:
        current = version.parse(CURRENT_DSL_VERSION)
        imported = version.parse(imported_version)
    except version.InvalidVersion:
        return ImportStatus.FAILED

    # major/minor drift requires explicit confirmation
    if (current.major, current.minor) != (imported.major, imported.minor):
        return ImportStatus.PENDING

    if current.micro == imported.micro:
        return ImportStatus.COMPLETED

    return ImportStatus.COMPLETED_WITH_WARNINGS
error=f"Error fetching YAML from URL: {str(e)}", + ) + elif mode == ImportMode.YAML_CONTENT: + if not yaml_content: + return Import( + id=import_id, + status=ImportStatus.FAILED, + error="yaml_content is required when import_mode is yaml-content", + ) + content = yaml_content + + # Process YAML content + try: + # Parse YAML to validate format + data = yaml.safe_load(content) + if not isinstance(data, dict): + return Import( + id=import_id, + status=ImportStatus.FAILED, + error="Invalid YAML format: content must be a mapping", + ) + + # Validate and fix DSL version + if not data.get("version"): + data["version"] = "0.1.0" + if not data.get("kind") or data.get("kind") != "app": + data["kind"] = "app" + + imported_version = data.get("version", "0.1.0") + # check if imported_version is a float-like string + if not isinstance(imported_version, str): + raise ValueError(f"Invalid version type, expected str, got {type(imported_version)}") + status = _check_version_compatibility(imported_version) + + # Extract app data + app_data = data.get("app") + if not app_data: + return Import( + id=import_id, + status=ImportStatus.FAILED, + error="Missing app data in YAML content", + ) + + # If app_id is provided, check if it exists + app = None + if app_id: + stmt = select(App).where(App.id == app_id, App.tenant_id == account.current_tenant_id) + app = self._session.scalar(stmt) + + if not app: + return Import( + id=import_id, + status=ImportStatus.FAILED, + error="App not found", + ) + + if app.mode not in [AppMode.WORKFLOW.value, AppMode.ADVANCED_CHAT.value]: + return Import( + id=import_id, + status=ImportStatus.FAILED, + error="Only workflow or advanced chat apps can be overwritten", + ) + + # If major version mismatch, store import info in Redis + if status == ImportStatus.PENDING: + panding_data = PendingData( + import_mode=import_mode, + yaml_content=content, + name=name, + description=description, + icon_type=icon_type, + icon=icon, + icon_background=icon_background, + 
def confirm_import(self, *, import_id: str, account: Account) -> Import:
    """
    Confirm an import that requires confirmation.

    Loads the pending import stored in Redis under ``import_id``, creates or
    updates the app from the stored YAML and removes the Redis entry.

    :param import_id: id returned by ``import_app`` for the pending import
    :param account: account confirming the import
    :return: an ``Import`` with status COMPLETED, or FAILED with ``error`` set
    """
    redis_key = f"{IMPORT_INFO_REDIS_KEY_PREFIX}{import_id}"
    raw_pending = redis_client.get(redis_key)

    if not raw_pending:
        return Import(
            id=import_id,
            status=ImportStatus.FAILED,
            error="Import information expired or does not exist",
        )

    try:
        if not isinstance(raw_pending, str | bytes):
            return Import(
                id=import_id,
                status=ImportStatus.FAILED,
                error="Invalid import information",
            )
        pending_data = PendingData.model_validate_json(raw_pending)
        data = yaml.safe_load(pending_data.yaml_content)

        app = None
        if pending_data.app_id:
            stmt = select(App).where(App.id == pending_data.app_id, App.tenant_id == account.current_tenant_id)
            app = self._session.scalar(stmt)

            # The import targeted an existing app; if it can no longer be
            # resolved for this tenant, fail like import_app does instead of
            # silently creating a brand-new app.
            if not app:
                return Import(
                    id=import_id,
                    status=ImportStatus.FAILED,
                    error="App not found",
                )

        # Create or update app
        app = self._create_or_update_app(
            app=app,
            data=data,
            account=account,
            name=pending_data.name,
            description=pending_data.description,
            icon_type=pending_data.icon_type,
            icon=pending_data.icon,
            icon_background=pending_data.icon_background,
        )

        # Delete import info from Redis
        redis_client.delete(redis_key)

        return Import(
            id=import_id,
            status=ImportStatus.COMPLETED,
            app_id=app.id,
            current_dsl_version=CURRENT_DSL_VERSION,
            imported_dsl_version=data.get("version", "0.1.0"),
        )

    except Exception as e:
        logger.exception("Error confirming import")
        return Import(
            id=import_id,
            status=ImportStatus.FAILED,
            error=str(e),
        )
app.enable_site = True + app.enable_api = True + app.use_icon_as_answer_icon = app_data.get("use_icon_as_answer_icon", False) + app.created_by = account.id + app.updated_by = account.id + + self._session.add(app) + self._session.commit() + app_was_created.send(app, account=account) + + # Initialize app based on mode + if app_mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: + workflow_data = data.get("workflow") + if not workflow_data or not isinstance(workflow_data, dict): + raise ValueError("Missing workflow data for workflow/advanced chat app") + + environment_variables_list = workflow_data.get("environment_variables", []) + environment_variables = [ + variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list + ] + conversation_variables_list = workflow_data.get("conversation_variables", []) + conversation_variables = [ + variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list + ] + + workflow_service = WorkflowService() + current_draft_workflow = workflow_service.get_draft_workflow(app_model=app) + if current_draft_workflow: + unique_hash = current_draft_workflow.unique_hash + else: + unique_hash = None + workflow_service.sync_draft_workflow( + app_model=app, + graph=workflow_data.get("graph", {}), + features=workflow_data.get("features", {}), + unique_hash=unique_hash, + account=account, + environment_variables=environment_variables, + conversation_variables=conversation_variables, + ) + elif app_mode in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.COMPLETION}: + # Initialize model config + model_config = data.get("model_config") + if not model_config or not isinstance(model_config, dict): + raise ValueError("Missing model_config for chat/agent-chat/completion app") + # Initialize or update model config + if not app.app_model_config: + app_model_config = AppModelConfig().from_model_config_dict(model_config) + app_model_config.id = str(uuid4()) + app_model_config.app_id = 
@classmethod
def export_dsl(cls, app_model: App, include_secret: bool = False) -> str:
    """
    Export an app as a DSL YAML document.

    :param app_model: App instance
    :param include_secret: forwarded to the workflow export for workflow and
        advanced-chat apps
    :return: YAML text of the exported app
    """
    app_mode = AppMode.value_of(app_model.mode)

    # image icons are replaced by a default emoji icon and background
    app_section = {
        "name": app_model.name,
        "mode": app_model.mode,
        "icon": "🤖" if app_model.icon_type == "image" else app_model.icon,
        "icon_background": "#FFEAD5" if app_model.icon_type == "image" else app_model.icon_background,
        "description": app_model.description,
        "use_icon_as_answer_icon": app_model.use_icon_as_answer_icon,
    }
    export_data = {"version": CURRENT_DSL_VERSION, "kind": "app", "app": app_section}

    is_workflow_based = app_mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}
    if not is_workflow_based:
        cls._append_model_config_export_data(export_data, app_model)
    else:
        cls._append_workflow_export_data(
            export_data=export_data, app_model=app_model, include_secret=include_secret
        )

    return yaml.dump(export_data, allow_unicode=True)  # type: ignore
@classmethod
def generate(
    cls,
    app_model: App,
    user: Union[Account, EndUser],
    args: Mapping[str, Any],
    invoke_from: InvokeFrom,
    streaming: bool = True,
):
    """
    App Content Generate

    Picks the generator matching the app mode and runs it under the app's
    active-request rate limit.

    :param app_model: app model
    :param user: user
    :param args: args
    :param invoke_from: invoke from
    :param streaming: streaming
    :return: the result of ``rate_limit.generate`` wrapping the generator output
    :raises InvokeRateLimitError: if the model provider reports a rate limit
    :raises ValueError: if the app mode is not supported
    """
    max_active_request = AppGenerateService._get_max_active_requests(app_model)
    rate_limit = RateLimit(app_model.id, max_active_request)
    request_id = RateLimit.gen_request_key()
    try:
        request_id = rate_limit.enter(request_id)
        # order matters: completion first, then anything flagged as agent, then plain chat
        if app_model.mode == AppMode.COMPLETION.value:
            generator = CompletionAppGenerator().generate(
                app_model=app_model,
                user=user,
                args=args,
                invoke_from=invoke_from,
                streaming=streaming,
            )
        elif app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent:
            generator = AgentChatAppGenerator().generate(
                app_model=app_model,
                user=user,
                args=args,
                invoke_from=invoke_from,
                streaming=streaming,
            )
        elif app_model.mode == AppMode.CHAT.value:
            generator = ChatAppGenerator().generate(
                app_model=app_model,
                user=user,
                args=args,
                invoke_from=invoke_from,
                streaming=streaming,
            )
        elif app_model.mode == AppMode.ADVANCED_CHAT.value:
            workflow = cls._get_workflow(app_model, invoke_from)
            generator = AdvancedChatAppGenerator().generate(
                app_model=app_model,
                workflow=workflow,
                user=user,
                args=args,
                invoke_from=invoke_from,
                streaming=streaming,
            )
        elif app_model.mode == AppMode.WORKFLOW.value:
            workflow = cls._get_workflow(app_model, invoke_from)
            generator = WorkflowAppGenerator().generate(
                app_model=app_model,
                workflow=workflow,
                user=user,
                args=args,
                invoke_from=invoke_from,
                streaming=streaming,
                call_depth=0,
                workflow_thread_pool_id=None,
            )
        else:
            raise ValueError(f"Invalid app mode {app_model.mode}")

        return rate_limit.generate(generator=generator, request_id=request_id)
    except RateLimitError as e:
        # A failed streaming request must give its slot back too; the
        # non-streaming case is released in ``finally``.
        if streaming:
            rate_limit.exit(request_id)
        raise InvokeRateLimitError(str(e)) from e
    except Exception:
        if streaming:
            rate_limit.exit(request_id)
        raise
    finally:
        # Non-streaming calls are finished here, so the slot is released once.
        # NOTE(review): a successful streaming call is not released here;
        # presumably the object returned by ``rate_limit.generate`` does it — confirm.
        if not streaming:
            rate_limit.exit(request_id)
max_active_requests + + @classmethod + def generate_single_iteration(cls, app_model: App, user: Account, node_id: str, args: Any, streaming: bool = True): + if app_model.mode == AppMode.ADVANCED_CHAT.value: + workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER) + return AdvancedChatAppGenerator().single_iteration_generate( + app_model=app_model, + workflow=workflow, + node_id=node_id, + user=user, + args=args, + streaming=streaming, + ) + elif app_model.mode == AppMode.WORKFLOW.value: + workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER) + return WorkflowAppGenerator().single_iteration_generate( + app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming + ) + else: + raise ValueError(f"Invalid app mode {app_model.mode}") + + @classmethod + def generate_more_like_this( + cls, + app_model: App, + user: Union[Account, EndUser], + message_id: str, + invoke_from: InvokeFrom, + streaming: bool = True, + ) -> Union[Mapping, Generator]: + """ + Generate more like this + :param app_model: app model + :param user: user + :param message_id: message id + :param invoke_from: invoke from + :param streaming: streaming + :return: + """ + return CompletionAppGenerator().generate_more_like_this( + app_model=app_model, message_id=message_id, user=user, invoke_from=invoke_from, stream=streaming + ) + + @classmethod + def _get_workflow(cls, app_model: App, invoke_from: InvokeFrom) -> Workflow: + """ + Get workflow + :param app_model: app model + :param invoke_from: invoke from + :return: + """ + workflow_service = WorkflowService() + if invoke_from == InvokeFrom.DEBUGGER: + # fetch draft workflow by app_model + workflow = workflow_service.get_draft_workflow(app_model=app_model) + + if not workflow: + raise ValueError("Workflow not initialized") + else: + # fetch published workflow by app_model + workflow = workflow_service.get_published_workflow(app_model=app_model) + + if not workflow: + raise ValueError("Workflow not 
class AppModelConfigService:
    """Validates app model configurations with the manager of the app mode."""

    @classmethod
    def validate_configuration(cls, tenant_id: str, config: dict, app_mode: AppMode) -> dict:
        """
        Validate ``config`` for a chat, agent-chat or completion app.

        :param tenant_id: tenant id
        :param config: model configuration to validate
        :param app_mode: mode of the app the configuration belongs to
        :return: the result of the matching manager's ``config_validate``
        :raises ValueError: if ``app_mode`` is not one of the three supported modes
        """
        managers_by_mode = (
            (AppMode.CHAT, ChatAppConfigManager),
            (AppMode.AGENT_CHAT, AgentChatAppConfigManager),
            (AppMode.COMPLETION, CompletionAppConfigManager),
        )
        for supported_mode, manager in managers_by_mode:
            if app_mode == supported_mode:
                return manager.config_validate(tenant_id, config)

        raise ValueError(f"Invalid app mode: {app_mode}")
core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from core.tools.tool_manager import ToolManager +from core.tools.utils.configuration import ToolParameterConfigurationManager +from events.app_event import app_was_created +from extensions.ext_database import db +from models.account import Account +from models.model import App, AppMode, AppModelConfig +from models.tools import ApiToolProvider +from services.tag_service import TagService +from tasks.remove_app_and_related_data_task import remove_app_and_related_data_task + + +class AppService: + def get_paginate_apps(self, user_id: str, tenant_id: str, args: dict) -> Pagination | None: + """ + Get app list with pagination + :param user_id: user id + :param tenant_id: tenant id + :param args: request args + :return: + """ + filters = [App.tenant_id == tenant_id, App.is_universal == False] + + if args["mode"] == "workflow": + filters.append(App.mode.in_([AppMode.WORKFLOW.value, AppMode.COMPLETION.value])) + elif args["mode"] == "chat": + filters.append(App.mode.in_([AppMode.CHAT.value, AppMode.ADVANCED_CHAT.value])) + elif args["mode"] == "agent-chat": + filters.append(App.mode == AppMode.AGENT_CHAT.value) + elif args["mode"] == "channel": + filters.append(App.mode == AppMode.CHANNEL.value) + + if args.get("is_created_by_me", False): + filters.append(App.created_by == user_id) + if args.get("name"): + name = args["name"][:30] + filters.append(App.name.ilike(f"%{name}%")) + if args.get("tag_ids"): + target_ids = TagService.get_target_ids_by_tag_ids("app", tenant_id, args["tag_ids"]) + if target_ids: + filters.append(App.id.in_(target_ids)) + else: + return None + + app_models = db.paginate( + db.select(App).where(*filters).order_by(App.created_at.desc()), + page=args["page"], + per_page=args["limit"], + error_out=False, + ) + + return app_models + + def create_app(self, tenant_id: str, args: dict, account: Account) -> App: + """ + Create app + :param tenant_id: tenant id + :param 
args: request args + :param account: Account instance + """ + app_mode = AppMode.value_of(args["mode"]) + app_template = default_app_templates[app_mode] + + # get model config + default_model_config = app_template.get("model_config") + default_model_config = default_model_config.copy() if default_model_config else None + if default_model_config and "model" in default_model_config: + # get model provider + model_manager = ModelManager() + + # get default model instance + try: + model_instance = model_manager.get_default_model_instance( + tenant_id=account.current_tenant_id or "", model_type=ModelType.LLM + ) + except (ProviderTokenNotInitError, LLMBadRequestError): + model_instance = None + except Exception as e: + logging.exception(f"Get default model instance failed, tenant_id: {tenant_id}") + model_instance = None + + if model_instance: + if ( + model_instance.model == default_model_config["model"]["name"] + and model_instance.provider == default_model_config["model"]["provider"] + ): + default_model_dict = default_model_config["model"] + else: + llm_model = cast(LargeLanguageModel, model_instance.model_type_instance) + model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials) + if model_schema is None: + raise ValueError(f"model schema not found for model {model_instance.model}") + + default_model_dict = { + "provider": model_instance.provider, + "name": model_instance.model, + "mode": model_schema.model_properties.get(ModelPropertyKey.MODE), + "completion_params": {}, + } + else: + provider, model = model_manager.get_default_provider_model_name( + tenant_id=account.current_tenant_id or "", model_type=ModelType.LLM + ) + default_model_config["model"]["provider"] = provider + default_model_config["model"]["name"] = model + default_model_dict = default_model_config["model"] + + default_model_config["model"] = json.dumps(default_model_dict) + + app = App(**app_template["app"]) + app.name = args["name"] + app.description = 
def get_app(self, app: App) -> App:
    """
    Get App

    For agent apps the tool parameters in the agent mode configuration are
    decrypted and masked, and a wrapper exposing that modified model config is
    returned. Other apps are returned unchanged.

    :param app: App instance
    :return: the app itself, or a ``ModifiedApp`` wrapper for agent apps
    """
    # get original app model config
    if app.mode == AppMode.AGENT_CHAT.value or app.is_agent:
        model_config = app.app_model_config
        agent_mode = model_config.agent_mode_dict
        # decrypt agent tool parameters if it's secret-input
        for tool in agent_mode.get("tools") or []:
            # entries that are not dicts, or have three keys or fewer, are skipped
            if not isinstance(tool, dict) or len(tool) <= 3:
                continue
            agent_tool_entity = AgentToolEntity(**tool)
            # get tool
            try:
                tool_runtime = ToolManager.get_agent_tool_runtime(
                    tenant_id=current_user.current_tenant_id,
                    app_id=app.id,
                    agent_tool=agent_tool_entity,
                )
                manager = ToolParameterConfigurationManager(
                    tenant_id=current_user.current_tenant_id,
                    tool_runtime=tool_runtime,
                    provider_name=agent_tool_entity.provider_id,
                    provider_type=agent_tool_entity.provider_type,
                    identity_id=f"AGENT.{app.id}",
                )

                # get decrypted parameters
                if agent_tool_entity.tool_parameters:
                    parameters = manager.decrypt_tool_parameters(agent_tool_entity.tool_parameters or {})
                    masked_parameter = manager.mask_tool_parameters(parameters or {})
                else:
                    masked_parameter = {}

                # override tool parameters
                tool["tool_parameters"] = masked_parameter
            except Exception:
                # Best effort: one failing tool must not break loading the app,
                # but the failure is logged instead of vanishing silently. The
                # tool keeps its stored parameters in that case.
                logging.warning(
                    "Failed to mask tool parameters, app_id: %s, provider_id: %s",
                    app.id,
                    agent_tool_entity.provider_id,
                    exc_info=True,
                )

        # override agent mode
        model_config.agent_mode = json.dumps(agent_mode)

        class ModifiedApp(App):
            """
            Modified App class
            """

            def __init__(self, app):
                # copy the instance state instead of running App's constructor
                self.__dict__.update(app.__dict__)

            @property
            def app_model_config(self):
                return model_config

        app = ModifiedApp(app)

    return app
enable_site: enable site status + :return: App instance + """ + if enable_site == app.enable_site: + return app + + app.enable_site = enable_site + app.updated_by = current_user.id + app.updated_at = datetime.now(UTC).replace(tzinfo=None) + db.session.commit() + + return app + + def update_app_api_status(self, app: App, enable_api: bool) -> App: + """ + Update app api status + :param app: App instance + :param enable_api: enable api status + :return: App instance + """ + if enable_api == app.enable_api: + return app + + app.enable_api = enable_api + app.updated_by = current_user.id + app.updated_at = datetime.now(UTC).replace(tzinfo=None) + db.session.commit() + + return app + + def delete_app(self, app: App) -> None: + """ + Delete app + :param app: App instance + """ + db.session.delete(app) + db.session.commit() + + # Trigger asynchronous deletion of app and related data + remove_app_and_related_data_task.delay(tenant_id=app.tenant_id, app_id=app.id) + + def get_app_meta(self, app_model: App) -> dict: + """ + Get app meta info + :param app_model: app model + :return: + """ + app_mode = AppMode.value_of(app_model.mode) + + meta: dict = {"tool_icons": {}} + + if app_mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: + workflow = app_model.workflow + if workflow is None: + return meta + + graph = workflow.graph_dict + nodes = graph.get("nodes", []) + tools = [] + for node in nodes: + if node.get("data", {}).get("type") == "tool": + node_data = node.get("data", {}) + tools.append( + { + "provider_type": node_data.get("provider_type"), + "provider_id": node_data.get("provider_id"), + "tool_name": node_data.get("tool_name"), + "tool_parameters": {}, + } + ) + else: + app_model_config: Optional[AppModelConfig] = app_model.app_model_config + + if not app_model_config: + return meta + + agent_config = app_model_config.agent_mode_dict + + # get all tools + tools = agent_config.get("tools", []) + + url_prefix = dify_config.CONSOLE_API_URL + 
"/console/api/workspaces/current/tool-provider/builtin/" + + for tool in tools: + keys = list(tool.keys()) + if len(keys) >= 4: + # current tool standard + provider_type = tool.get("provider_type", "") + provider_id = tool.get("provider_id", "") + tool_name = tool.get("tool_name", "") + if provider_type == "builtin": + meta["tool_icons"][tool_name] = url_prefix + provider_id + "/icon" + elif provider_type == "api": + try: + provider: Optional[ApiToolProvider] = ( + db.session.query(ApiToolProvider).filter(ApiToolProvider.id == provider_id).first() + ) + if provider is None: + raise ValueError(f"provider not found for tool {tool_name}") + meta["tool_icons"][tool_name] = json.loads(provider.icon) + except: + meta["tool_icons"][tool_name] = {"background": "#252525", "content": "\ud83d\ude01"} + + return meta diff --git a/api/services/audio_service.py b/api/services/audio_service.py new file mode 100644 index 0000000000000000000000000000000000000000..294dfe4c8c3ceb0efba8b18c4c23c648fc65560f --- /dev/null +++ b/api/services/audio_service.py @@ -0,0 +1,161 @@ +import io +import logging +import uuid +from typing import Optional + +from werkzeug.datastructures import FileStorage + +from core.model_manager import ModelManager +from core.model_runtime.entities.model_entities import ModelType +from models.model import App, AppMode, AppModelConfig, Message +from services.errors.audio import ( + AudioTooLargeServiceError, + NoAudioUploadedServiceError, + ProviderNotSupportSpeechToTextServiceError, + ProviderNotSupportTextToSpeechServiceError, + UnsupportedAudioTypeServiceError, +) + +FILE_SIZE = 30 +FILE_SIZE_LIMIT = FILE_SIZE * 1024 * 1024 +ALLOWED_EXTENSIONS = ["mp3", "mp4", "mpeg", "mpga", "m4a", "wav", "webm", "amr"] + +logger = logging.getLogger(__name__) + + +class AudioService: + @classmethod + def transcript_asr(cls, app_model: App, file: FileStorage, end_user: Optional[str] = None): + if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}: + 
workflow = app_model.workflow + if workflow is None: + raise ValueError("Speech to text is not enabled") + + features_dict = workflow.features_dict + if "speech_to_text" not in features_dict or not features_dict["speech_to_text"].get("enabled"): + raise ValueError("Speech to text is not enabled") + else: + app_model_config: AppModelConfig = app_model.app_model_config + + if not app_model_config.speech_to_text_dict["enabled"]: + raise ValueError("Speech to text is not enabled") + + if file is None: + raise NoAudioUploadedServiceError() + + extension = file.mimetype + if extension not in [f"audio/{ext}" for ext in ALLOWED_EXTENSIONS]: + raise UnsupportedAudioTypeServiceError() + + file_content = file.read() + file_size = len(file_content) + + if file_size > FILE_SIZE_LIMIT: + message = f"Audio size larger than {FILE_SIZE} mb" + raise AudioTooLargeServiceError(message) + + model_manager = ModelManager() + model_instance = model_manager.get_default_model_instance( + tenant_id=app_model.tenant_id, model_type=ModelType.SPEECH2TEXT + ) + if model_instance is None: + raise ProviderNotSupportSpeechToTextServiceError() + + buffer = io.BytesIO(file_content) + buffer.name = "temp.mp3" + + return {"text": model_instance.invoke_speech2text(file=buffer, user=end_user)} + + @classmethod + def transcript_tts( + cls, + app_model: App, + text: Optional[str] = None, + voice: Optional[str] = None, + end_user: Optional[str] = None, + message_id: Optional[str] = None, + ): + from collections.abc import Generator + + from flask import Response, stream_with_context + + from app import app + from extensions.ext_database import db + + def invoke_tts(text_content: str, app_model: App, voice: Optional[str] = None): + with app.app_context(): + if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}: + workflow = app_model.workflow + if workflow is None: + raise ValueError("TTS is not enabled") + + features_dict = workflow.features_dict + if "text_to_speech" not in 
features_dict or not features_dict["text_to_speech"].get("enabled"): + raise ValueError("TTS is not enabled") + + voice = features_dict["text_to_speech"].get("voice") if voice is None else voice + else: + if app_model.app_model_config is None: + raise ValueError("AppModelConfig not found") + text_to_speech_dict = app_model.app_model_config.text_to_speech_dict + + if not text_to_speech_dict.get("enabled"): + raise ValueError("TTS is not enabled") + + voice = text_to_speech_dict.get("voice") if voice is None else voice + + model_manager = ModelManager() + model_instance = model_manager.get_default_model_instance( + tenant_id=app_model.tenant_id, model_type=ModelType.TTS + ) + try: + if not voice: + voices = model_instance.get_tts_voices() + if voices: + voice = voices[0].get("value") + if not voice: + raise ValueError("Sorry, no voice available.") + else: + raise ValueError("Sorry, no voice available.") + + return model_instance.invoke_tts( + content_text=text_content.strip(), user=end_user, tenant_id=app_model.tenant_id, voice=voice + ) + except Exception as e: + raise e + + if message_id: + try: + uuid.UUID(message_id) + except ValueError: + return None + message = db.session.query(Message).filter(Message.id == message_id).first() + if message is None: + return None + if message.answer == "" and message.status == "normal": + return None + + else: + response = invoke_tts(message.answer, app_model=app_model, voice=voice) + if isinstance(response, Generator): + return Response(stream_with_context(response), content_type="audio/mpeg") + return response + else: + if text is None: + raise ValueError("Text is required") + response = invoke_tts(text, app_model, voice) + if isinstance(response, Generator): + return Response(stream_with_context(response), content_type="audio/mpeg") + return response + + @classmethod + def transcript_tts_voices(cls, tenant_id: str, language: str): + model_manager = ModelManager() + model_instance = 
model_manager.get_default_model_instance(tenant_id=tenant_id, model_type=ModelType.TTS) + if model_instance is None: + raise ProviderNotSupportTextToSpeechServiceError() + + try: + return model_instance.get_tts_voices(language) + except Exception as e: + raise e diff --git a/api/services/auth/__init__.py b/api/services/auth/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/services/auth/api_key_auth_base.py b/api/services/auth/api_key_auth_base.py new file mode 100644 index 0000000000000000000000000000000000000000..dd74a8f1b539a08266206a07e4f10c349473347c --- /dev/null +++ b/api/services/auth/api_key_auth_base.py @@ -0,0 +1,10 @@ +from abc import ABC, abstractmethod + + +class ApiKeyAuthBase(ABC): + def __init__(self, credentials: dict): + self.credentials = credentials + + @abstractmethod + def validate_credentials(self): + raise NotImplementedError diff --git a/api/services/auth/api_key_auth_factory.py b/api/services/auth/api_key_auth_factory.py new file mode 100644 index 0000000000000000000000000000000000000000..f91c448fb94a23541c0ee03251a642d9b0a1feea --- /dev/null +++ b/api/services/auth/api_key_auth_factory.py @@ -0,0 +1,25 @@ +from services.auth.api_key_auth_base import ApiKeyAuthBase +from services.auth.auth_type import AuthType + + +class ApiKeyAuthFactory: + def __init__(self, provider: str, credentials: dict): + auth_factory = self.get_apikey_auth_factory(provider) + self.auth = auth_factory(credentials) + + def validate_credentials(self): + return self.auth.validate_credentials() + + @staticmethod + def get_apikey_auth_factory(provider: str) -> type[ApiKeyAuthBase]: + match provider: + case AuthType.FIRECRAWL: + from services.auth.firecrawl.firecrawl import FirecrawlAuth + + return FirecrawlAuth + case AuthType.JINA: + from services.auth.jina.jina import JinaAuth + + return JinaAuth + case _: + raise ValueError("Invalid provider") diff --git 
a/api/services/auth/api_key_auth_service.py b/api/services/auth/api_key_auth_service.py new file mode 100644 index 0000000000000000000000000000000000000000..e5f4a3ef6e12d3a59a4462367ce8239cf6834089 --- /dev/null +++ b/api/services/auth/api_key_auth_service.py @@ -0,0 +1,74 @@ +import json + +from core.helper import encrypter +from extensions.ext_database import db +from models.source import DataSourceApiKeyAuthBinding +from services.auth.api_key_auth_factory import ApiKeyAuthFactory + + +class ApiKeyAuthService: + @staticmethod + def get_provider_auth_list(tenant_id: str) -> list: + data_source_api_key_bindings = ( + db.session.query(DataSourceApiKeyAuthBinding) + .filter(DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.disabled.is_(False)) + .all() + ) + return data_source_api_key_bindings + + @staticmethod + def create_provider_auth(tenant_id: str, args: dict): + auth_result = ApiKeyAuthFactory(args["provider"], args["credentials"]).validate_credentials() + if auth_result: + # Encrypt the api key + api_key = encrypter.encrypt_token(tenant_id, args["credentials"]["config"]["api_key"]) + args["credentials"]["config"]["api_key"] = api_key + + data_source_api_key_binding = DataSourceApiKeyAuthBinding() + data_source_api_key_binding.tenant_id = tenant_id + data_source_api_key_binding.category = args["category"] + data_source_api_key_binding.provider = args["provider"] + data_source_api_key_binding.credentials = json.dumps(args["credentials"], ensure_ascii=False) + db.session.add(data_source_api_key_binding) + db.session.commit() + + @staticmethod + def get_auth_credentials(tenant_id: str, category: str, provider: str): + data_source_api_key_bindings = ( + db.session.query(DataSourceApiKeyAuthBinding) + .filter( + DataSourceApiKeyAuthBinding.tenant_id == tenant_id, + DataSourceApiKeyAuthBinding.category == category, + DataSourceApiKeyAuthBinding.provider == provider, + DataSourceApiKeyAuthBinding.disabled.is_(False), + ) + .first() + ) + 
if not data_source_api_key_bindings: + return None + credentials = json.loads(data_source_api_key_bindings.credentials) + return credentials + + @staticmethod + def delete_provider_auth(tenant_id: str, binding_id: str): + data_source_api_key_binding = ( + db.session.query(DataSourceApiKeyAuthBinding) + .filter(DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.id == binding_id) + .first() + ) + if data_source_api_key_binding: + db.session.delete(data_source_api_key_binding) + db.session.commit() + + @classmethod + def validate_api_key_auth_args(cls, args): + if "category" not in args or not args["category"]: + raise ValueError("category is required") + if "provider" not in args or not args["provider"]: + raise ValueError("provider is required") + if "credentials" not in args or not args["credentials"]: + raise ValueError("credentials is required") + if not isinstance(args["credentials"], dict): + raise ValueError("credentials must be a dictionary") + if "auth_type" not in args["credentials"] or not args["credentials"]["auth_type"]: + raise ValueError("auth_type is required") diff --git a/api/services/auth/auth_type.py b/api/services/auth/auth_type.py new file mode 100644 index 0000000000000000000000000000000000000000..2e1946841fbd159c3b52c4ea93f167dadd3ff0c5 --- /dev/null +++ b/api/services/auth/auth_type.py @@ -0,0 +1,6 @@ +from enum import StrEnum + + +class AuthType(StrEnum): + FIRECRAWL = "firecrawl" + JINA = "jinareader" diff --git a/api/services/auth/firecrawl/__init__.py b/api/services/auth/firecrawl/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/services/auth/firecrawl/firecrawl.py b/api/services/auth/firecrawl/firecrawl.py new file mode 100644 index 0000000000000000000000000000000000000000..6ef034f2920a6f4c6cbb624a611665edb15eb4b5 --- /dev/null +++ b/api/services/auth/firecrawl/firecrawl.py @@ -0,0 +1,49 @@ +import json + +import 
requests + +from services.auth.api_key_auth_base import ApiKeyAuthBase + + +class FirecrawlAuth(ApiKeyAuthBase): + def __init__(self, credentials: dict): + super().__init__(credentials) + auth_type = credentials.get("auth_type") + if auth_type != "bearer": + raise ValueError("Invalid auth type, Firecrawl auth type must be Bearer") + self.api_key = credentials.get("config", {}).get("api_key", None) + self.base_url = credentials.get("config", {}).get("base_url", "https://api.firecrawl.dev") + + if not self.api_key: + raise ValueError("No API key provided") + + def validate_credentials(self): + headers = self._prepare_headers() + options = { + "url": "https://example.com", + "includePaths": [], + "excludePaths": [], + "limit": 1, + "scrapeOptions": {"onlyMainContent": True}, + } + response = self._post_request(f"{self.base_url}/v1/crawl", options, headers) + if response.status_code == 200: + return True + else: + self._handle_error(response) + + def _prepare_headers(self): + return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"} + + def _post_request(self, url, data, headers): + return requests.post(url, headers=headers, json=data) + + def _handle_error(self, response): + if response.status_code in {402, 409, 500}: + error_message = response.json().get("error", "Unknown error occurred") + raise Exception(f"Failed to authorize. Status code: {response.status_code}. Error: {error_message}") + else: + if response.text: + error_message = json.loads(response.text).get("error", "Unknown error occurred") + raise Exception(f"Failed to authorize. Status code: {response.status_code}. Error: {error_message}") + raise Exception(f"Unexpected error occurred while trying to authorize. 
Status code: {response.status_code}") diff --git a/api/services/auth/jina.py b/api/services/auth/jina.py new file mode 100644 index 0000000000000000000000000000000000000000..6100e9afc8f278b20385ca79af16aebaad70c762 --- /dev/null +++ b/api/services/auth/jina.py @@ -0,0 +1,44 @@ +import json + +import requests + +from services.auth.api_key_auth_base import ApiKeyAuthBase + + +class JinaAuth(ApiKeyAuthBase): + def __init__(self, credentials: dict): + super().__init__(credentials) + auth_type = credentials.get("auth_type") + if auth_type != "bearer": + raise ValueError("Invalid auth type, Jina Reader auth type must be Bearer") + self.api_key = credentials.get("config", {}).get("api_key", None) + + if not self.api_key: + raise ValueError("No API key provided") + + def validate_credentials(self): + headers = self._prepare_headers() + options = { + "url": "https://example.com", + } + response = self._post_request("https://r.jina.ai", options, headers) + if response.status_code == 200: + return True + else: + self._handle_error(response) + + def _prepare_headers(self): + return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"} + + def _post_request(self, url, data, headers): + return requests.post(url, headers=headers, json=data) + + def _handle_error(self, response): + if response.status_code in {402, 409, 500}: + error_message = response.json().get("error", "Unknown error occurred") + raise Exception(f"Failed to authorize. Status code: {response.status_code}. Error: {error_message}") + else: + if response.text: + error_message = json.loads(response.text).get("error", "Unknown error occurred") + raise Exception(f"Failed to authorize. Status code: {response.status_code}. Error: {error_message}") + raise Exception(f"Unexpected error occurred while trying to authorize. 
Status code: {response.status_code}") diff --git a/api/services/auth/jina/__init__.py b/api/services/auth/jina/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/services/auth/jina/jina.py b/api/services/auth/jina/jina.py new file mode 100644 index 0000000000000000000000000000000000000000..6100e9afc8f278b20385ca79af16aebaad70c762 --- /dev/null +++ b/api/services/auth/jina/jina.py @@ -0,0 +1,44 @@ +import json + +import requests + +from services.auth.api_key_auth_base import ApiKeyAuthBase + + +class JinaAuth(ApiKeyAuthBase): + def __init__(self, credentials: dict): + super().__init__(credentials) + auth_type = credentials.get("auth_type") + if auth_type != "bearer": + raise ValueError("Invalid auth type, Jina Reader auth type must be Bearer") + self.api_key = credentials.get("config", {}).get("api_key", None) + + if not self.api_key: + raise ValueError("No API key provided") + + def validate_credentials(self): + headers = self._prepare_headers() + options = { + "url": "https://example.com", + } + response = self._post_request("https://r.jina.ai", options, headers) + if response.status_code == 200: + return True + else: + self._handle_error(response) + + def _prepare_headers(self): + return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"} + + def _post_request(self, url, data, headers): + return requests.post(url, headers=headers, json=data) + + def _handle_error(self, response): + if response.status_code in {402, 409, 500}: + error_message = response.json().get("error", "Unknown error occurred") + raise Exception(f"Failed to authorize. Status code: {response.status_code}. Error: {error_message}") + else: + if response.text: + error_message = json.loads(response.text).get("error", "Unknown error occurred") + raise Exception(f"Failed to authorize. Status code: {response.status_code}. 
Error: {error_message}") + raise Exception(f"Unexpected error occurred while trying to authorize. Status code: {response.status_code}") diff --git a/api/services/billing_service.py b/api/services/billing_service.py new file mode 100644 index 0000000000000000000000000000000000000000..0d50a2aa8c34ffee6d333b23da8827f4a072c351 --- /dev/null +++ b/api/services/billing_service.py @@ -0,0 +1,93 @@ +import os +from typing import Literal, Optional + +import httpx +from tenacity import retry, retry_if_exception_type, stop_before_delay, wait_fixed + +from extensions.ext_database import db +from models.account import TenantAccountJoin, TenantAccountRole + + +class BillingService: + base_url = os.environ.get("BILLING_API_URL", "BILLING_API_URL") + secret_key = os.environ.get("BILLING_API_SECRET_KEY", "BILLING_API_SECRET_KEY") + + @classmethod + def get_info(cls, tenant_id: str): + params = {"tenant_id": tenant_id} + + billing_info = cls._send_request("GET", "/subscription/info", params=params) + return billing_info + + @classmethod + def get_subscription(cls, plan: str, interval: str, prefilled_email: str = "", tenant_id: str = ""): + params = {"plan": plan, "interval": interval, "prefilled_email": prefilled_email, "tenant_id": tenant_id} + return cls._send_request("GET", "/subscription/payment-link", params=params) + + @classmethod + def get_model_provider_payment_link(cls, provider_name: str, tenant_id: str, account_id: str, prefilled_email: str): + params = { + "provider_name": provider_name, + "tenant_id": tenant_id, + "account_id": account_id, + "prefilled_email": prefilled_email, + } + return cls._send_request("GET", "/model-provider/payment-link", params=params) + + @classmethod + def get_invoices(cls, prefilled_email: str = "", tenant_id: str = ""): + params = {"prefilled_email": prefilled_email, "tenant_id": tenant_id} + return cls._send_request("GET", "/invoices", params=params) + + @classmethod + @retry( + wait=wait_fixed(2), + stop=stop_before_delay(10), + 
retry=retry_if_exception_type(httpx.RequestError), + reraise=True, + ) + def _send_request(cls, method: Literal["GET", "POST", "DELETE"], endpoint: str, json=None, params=None): + headers = {"Content-Type": "application/json", "Billing-Api-Secret-Key": cls.secret_key} + + url = f"{cls.base_url}{endpoint}" + response = httpx.request(method, url, json=json, params=params, headers=headers) + if method == "GET" and response.status_code != httpx.codes.OK: + raise ValueError("Unable to retrieve billing information. Please try again later or contact support.") + return response.json() + + @staticmethod + def is_tenant_owner_or_admin(current_user): + tenant_id = current_user.current_tenant_id + + join: Optional[TenantAccountJoin] = ( + db.session.query(TenantAccountJoin) + .filter(TenantAccountJoin.tenant_id == tenant_id, TenantAccountJoin.account_id == current_user.id) + .first() + ) + + if not join: + raise ValueError("Tenant account join not found") + + if not TenantAccountRole.is_privileged_role(join.role): + raise ValueError("Only team owner or team admin can perform this action") + + @classmethod + def delete_account(cls, account_id: str): + """Delete account.""" + params = {"account_id": account_id} + return cls._send_request("DELETE", "/account/", params=params) + + @classmethod + def is_email_in_freeze(cls, email: str) -> bool: + params = {"email": email} + try: + response = cls._send_request("GET", "/account/in-freeze", params=params) + return bool(response.get("data", False)) + except Exception: + return False + + @classmethod + def update_account_deletion_feedback(cls, email: str, feedback: str): + """Update account deletion feedback.""" + json = {"email": email, "feedback": feedback} + return cls._send_request("POST", "/account/delete-feedback", json=json) diff --git a/api/services/code_based_extension_service.py b/api/services/code_based_extension_service.py new file mode 100644 index 
0000000000000000000000000000000000000000..f7597b7f1fcd45c44019f4bde41b2bf42c82c3da --- /dev/null +++ b/api/services/code_based_extension_service.py @@ -0,0 +1,16 @@ +from extensions.ext_code_based_extension import code_based_extension + + +class CodeBasedExtensionService: + @staticmethod + def get_code_based_extension(module: str) -> list[dict]: + module_extensions = code_based_extension.module_extensions(module) + return [ + { + "name": module_extension.name, + "label": module_extension.label, + "form_schema": module_extension.form_schema, + } + for module_extension in module_extensions + if not module_extension.builtin + ] diff --git a/api/services/conversation_service.py b/api/services/conversation_service.py new file mode 100644 index 0000000000000000000000000000000000000000..6485cbf37d5b7f19439f22177440d57662c227ad --- /dev/null +++ b/api/services/conversation_service.py @@ -0,0 +1,168 @@ +from collections.abc import Callable, Sequence +from datetime import UTC, datetime +from typing import Optional, Union + +from sqlalchemy import asc, desc, func, or_, select +from sqlalchemy.orm import Session + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.llm_generator.llm_generator import LLMGenerator +from extensions.ext_database import db +from libs.infinite_scroll_pagination import InfiniteScrollPagination +from models.account import Account +from models.model import App, Conversation, EndUser, Message +from services.errors.conversation import ConversationNotExistsError, LastConversationNotExistsError +from services.errors.message import MessageNotExistsError + + +class ConversationService: + @classmethod + def pagination_by_last_id( + cls, + *, + session: Session, + app_model: App, + user: Optional[Union[Account, EndUser]], + last_id: Optional[str], + limit: int, + invoke_from: InvokeFrom, + include_ids: Optional[Sequence[str]] = None, + exclude_ids: Optional[Sequence[str]] = None, + sort_by: str = "-updated_at", + ) -> 
InfiniteScrollPagination: + if not user: + return InfiniteScrollPagination(data=[], limit=limit, has_more=False) + + stmt = select(Conversation).where( + Conversation.is_deleted == False, + Conversation.app_id == app_model.id, + Conversation.from_source == ("api" if isinstance(user, EndUser) else "console"), + Conversation.from_end_user_id == (user.id if isinstance(user, EndUser) else None), + Conversation.from_account_id == (user.id if isinstance(user, Account) else None), + or_(Conversation.invoke_from.is_(None), Conversation.invoke_from == invoke_from.value), + ) + if include_ids is not None: + stmt = stmt.where(Conversation.id.in_(include_ids)) + if exclude_ids is not None: + stmt = stmt.where(~Conversation.id.in_(exclude_ids)) + + # define sort fields and directions + sort_field, sort_direction = cls._get_sort_params(sort_by) + + if last_id: + last_conversation = session.scalar(stmt.where(Conversation.id == last_id)) + if not last_conversation: + raise LastConversationNotExistsError() + + # build filters based on sorting + filter_condition = cls._build_filter_condition( + sort_field=sort_field, + sort_direction=sort_direction, + reference_conversation=last_conversation, + ) + stmt = stmt.where(filter_condition) + query_stmt = stmt.order_by(sort_direction(getattr(Conversation, sort_field))).limit(limit) + conversations = session.scalars(query_stmt).all() + + has_more = False + if len(conversations) == limit: + current_page_last_conversation = conversations[-1] + rest_filter_condition = cls._build_filter_condition( + sort_field=sort_field, + sort_direction=sort_direction, + reference_conversation=current_page_last_conversation, + ) + count_stmt = select(func.count()).select_from(stmt.where(rest_filter_condition).subquery()) + rest_count = session.scalar(count_stmt) or 0 + if rest_count > 0: + has_more = True + + return InfiniteScrollPagination(data=conversations, limit=limit, has_more=has_more) + + @classmethod + def _get_sort_params(cls, sort_by: str): + if 
sort_by.startswith("-"): + return sort_by[1:], desc + return sort_by, asc + + @classmethod + def _build_filter_condition(cls, sort_field: str, sort_direction: Callable, reference_conversation: Conversation): + field_value = getattr(reference_conversation, sort_field) + if sort_direction == desc: + return getattr(Conversation, sort_field) < field_value + else: + return getattr(Conversation, sort_field) > field_value + + @classmethod + def rename( + cls, + app_model: App, + conversation_id: str, + user: Optional[Union[Account, EndUser]], + name: str, + auto_generate: bool, + ): + conversation = cls.get_conversation(app_model, conversation_id, user) + + if auto_generate: + return cls.auto_generate_name(app_model, conversation) + else: + conversation.name = name + conversation.updated_at = datetime.now(UTC).replace(tzinfo=None) + db.session.commit() + + return conversation + + @classmethod + def auto_generate_name(cls, app_model: App, conversation: Conversation): + # get conversation first message + message = ( + db.session.query(Message) + .filter(Message.app_id == app_model.id, Message.conversation_id == conversation.id) + .order_by(Message.created_at.asc()) + .first() + ) + + if not message: + raise MessageNotExistsError() + + # generate conversation name + try: + name = LLMGenerator.generate_conversation_name( + app_model.tenant_id, message.query, conversation.id, app_model.id + ) + conversation.name = name + except: + pass + + db.session.commit() + + return conversation + + @classmethod + def get_conversation(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]): + conversation = ( + db.session.query(Conversation) + .filter( + Conversation.id == conversation_id, + Conversation.app_id == app_model.id, + Conversation.from_source == ("api" if isinstance(user, EndUser) else "console"), + Conversation.from_end_user_id == (user.id if isinstance(user, EndUser) else None), + Conversation.from_account_id == (user.id if isinstance(user, 
Account) else None), + Conversation.is_deleted == False, + ) + .first() + ) + + if not conversation: + raise ConversationNotExistsError() + + return conversation + + @classmethod + def delete(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]): + conversation = cls.get_conversation(app_model, conversation_id, user) + + conversation.is_deleted = True + conversation.updated_at = datetime.now(UTC).replace(tzinfo=None) + db.session.commit() diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py new file mode 100644 index 0000000000000000000000000000000000000000..38025b5213aaa3512a5545771270d1a3590f2a6e --- /dev/null +++ b/api/services/dataset_service.py @@ -0,0 +1,2144 @@ +import datetime +import json +import logging +import random +import time +import uuid +from collections import Counter +from typing import Any, Optional + +from flask_login import current_user # type: ignore +from sqlalchemy import func +from werkzeug.exceptions import NotFound + +from configs import dify_config +from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError +from core.model_manager import ModelManager +from core.model_runtime.entities.model_entities import ModelType +from core.rag.index_processor.constant.index_type import IndexType +from core.rag.retrieval.retrieval_methods import RetrievalMethod +from events.dataset_event import dataset_was_deleted +from events.document_event import document_was_deleted +from extensions.ext_database import db +from extensions.ext_redis import redis_client +from libs import helper +from models.account import Account, TenantAccountRole +from models.dataset import ( + AppDatasetJoin, + ChildChunk, + Dataset, + DatasetAutoDisableLog, + DatasetCollectionBinding, + DatasetPermission, + DatasetPermissionEnum, + DatasetProcessRule, + DatasetQuery, + Document, + DocumentSegment, + ExternalKnowledgeBindings, +) +from models.model import UploadFile +from models.source import 
DataSourceOauthBinding +from services.entities.knowledge_entities.knowledge_entities import ( + ChildChunkUpdateArgs, + KnowledgeConfig, + MetaDataConfig, + RerankingModel, + RetrievalModel, + SegmentUpdateArgs, +) +from services.errors.account import InvalidActionError, NoPermissionError +from services.errors.chunk import ChildChunkDeleteIndexError, ChildChunkIndexingError +from services.errors.dataset import DatasetNameDuplicateError +from services.errors.document import DocumentIndexingError +from services.errors.file import FileNotExistsError +from services.external_knowledge_service import ExternalDatasetService +from services.feature_service import FeatureModel, FeatureService +from services.tag_service import TagService +from services.vector_service import VectorService +from tasks.batch_clean_document_task import batch_clean_document_task +from tasks.clean_notion_document_task import clean_notion_document_task +from tasks.deal_dataset_vector_index_task import deal_dataset_vector_index_task +from tasks.delete_segment_from_index_task import delete_segment_from_index_task +from tasks.disable_segment_from_index_task import disable_segment_from_index_task +from tasks.disable_segments_from_index_task import disable_segments_from_index_task +from tasks.document_indexing_task import document_indexing_task +from tasks.document_indexing_update_task import document_indexing_update_task +from tasks.duplicate_document_indexing_task import duplicate_document_indexing_task +from tasks.enable_segments_to_index_task import enable_segments_to_index_task +from tasks.recover_document_indexing_task import recover_document_indexing_task +from tasks.retry_document_indexing_task import retry_document_indexing_task +from tasks.sync_website_document_indexing_task import sync_website_document_indexing_task + + +class DatasetService: + @staticmethod + def get_datasets(page, per_page, tenant_id=None, user=None, search=None, tag_ids=None, include_all=False): + query = 
Dataset.query.filter(Dataset.tenant_id == tenant_id).order_by(Dataset.created_at.desc()) + + if user: + # get permitted dataset ids + dataset_permission = DatasetPermission.query.filter_by(account_id=user.id, tenant_id=tenant_id).all() + permitted_dataset_ids = {dp.dataset_id for dp in dataset_permission} if dataset_permission else None + + if user.current_role == TenantAccountRole.DATASET_OPERATOR: + # only show datasets that the user has permission to access + if permitted_dataset_ids: + query = query.filter(Dataset.id.in_(permitted_dataset_ids)) + else: + return [], 0 + else: + if user.current_role != TenantAccountRole.OWNER or not include_all: + # show all datasets that the user has permission to access + if permitted_dataset_ids: + query = query.filter( + db.or_( + Dataset.permission == DatasetPermissionEnum.ALL_TEAM, + db.and_( + Dataset.permission == DatasetPermissionEnum.ONLY_ME, Dataset.created_by == user.id + ), + db.and_( + Dataset.permission == DatasetPermissionEnum.PARTIAL_TEAM, + Dataset.id.in_(permitted_dataset_ids), + ), + ) + ) + else: + query = query.filter( + db.or_( + Dataset.permission == DatasetPermissionEnum.ALL_TEAM, + db.and_( + Dataset.permission == DatasetPermissionEnum.ONLY_ME, Dataset.created_by == user.id + ), + ) + ) + else: + # if no user, only show datasets that are shared with all team members + query = query.filter(Dataset.permission == DatasetPermissionEnum.ALL_TEAM) + + if search: + query = query.filter(Dataset.name.ilike(f"%{search}%")) + + if tag_ids: + target_ids = TagService.get_target_ids_by_tag_ids("knowledge", tenant_id, tag_ids) + if target_ids: + query = query.filter(Dataset.id.in_(target_ids)) + else: + return [], 0 + + datasets = query.paginate(page=page, per_page=per_page, max_per_page=100, error_out=False) + + return datasets.items, datasets.total + + @staticmethod + def get_process_rules(dataset_id): + # get the latest process rule + dataset_process_rule = ( + db.session.query(DatasetProcessRule) + 
.filter(DatasetProcessRule.dataset_id == dataset_id) + .order_by(DatasetProcessRule.created_at.desc()) + .limit(1) + .one_or_none() + ) + if dataset_process_rule: + mode = dataset_process_rule.mode + rules = dataset_process_rule.rules_dict + else: + mode = DocumentService.DEFAULT_RULES["mode"] + rules = DocumentService.DEFAULT_RULES["rules"] + return {"mode": mode, "rules": rules} + + @staticmethod + def get_datasets_by_ids(ids, tenant_id): + datasets = Dataset.query.filter(Dataset.id.in_(ids), Dataset.tenant_id == tenant_id).paginate( + page=1, per_page=len(ids), max_per_page=len(ids), error_out=False + ) + return datasets.items, datasets.total + + @staticmethod + def create_empty_dataset( + tenant_id: str, + name: str, + description: Optional[str], + indexing_technique: Optional[str], + account: Account, + permission: Optional[str] = None, + provider: str = "vendor", + external_knowledge_api_id: Optional[str] = None, + external_knowledge_id: Optional[str] = None, + ): + # check if dataset name already exists + if Dataset.query.filter_by(name=name, tenant_id=tenant_id).first(): + raise DatasetNameDuplicateError(f"Dataset with name {name} already exists.") + embedding_model = None + if indexing_technique == "high_quality": + model_manager = ModelManager() + embedding_model = model_manager.get_default_model_instance( + tenant_id=tenant_id, model_type=ModelType.TEXT_EMBEDDING + ) + dataset = Dataset(name=name, indexing_technique=indexing_technique) + # dataset = Dataset(name=name, provider=provider, config=config) + dataset.description = description + dataset.created_by = account.id + dataset.updated_by = account.id + dataset.tenant_id = tenant_id + dataset.embedding_model_provider = embedding_model.provider if embedding_model else None + dataset.embedding_model = embedding_model.model if embedding_model else None + dataset.permission = permission or DatasetPermissionEnum.ONLY_ME + dataset.provider = provider + db.session.add(dataset) + db.session.flush() + + if 
provider == "external" and external_knowledge_api_id: + external_knowledge_api = ExternalDatasetService.get_external_knowledge_api(external_knowledge_api_id) + if not external_knowledge_api: + raise ValueError("External API template not found.") + external_knowledge_binding = ExternalKnowledgeBindings( + tenant_id=tenant_id, + dataset_id=dataset.id, + external_knowledge_api_id=external_knowledge_api_id, + external_knowledge_id=external_knowledge_id, + created_by=account.id, + ) + db.session.add(external_knowledge_binding) + + db.session.commit() + return dataset + + @staticmethod + def get_dataset(dataset_id) -> Optional[Dataset]: + dataset: Optional[Dataset] = Dataset.query.filter_by(id=dataset_id).first() + return dataset + + @staticmethod + def check_dataset_model_setting(dataset): + if dataset.indexing_technique == "high_quality": + try: + model_manager = ModelManager() + model_manager.get_model_instance( + tenant_id=dataset.tenant_id, + provider=dataset.embedding_model_provider, + model_type=ModelType.TEXT_EMBEDDING, + model=dataset.embedding_model, + ) + except LLMBadRequestError: + raise ValueError( + "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider." + ) + except ProviderTokenNotInitError as ex: + raise ValueError(f"The dataset in unavailable, due to: {ex.description}") + + @staticmethod + def check_embedding_model_setting(tenant_id: str, embedding_model_provider: str, embedding_model: str): + try: + model_manager = ModelManager() + model_manager.get_model_instance( + tenant_id=tenant_id, + provider=embedding_model_provider, + model_type=ModelType.TEXT_EMBEDDING, + model=embedding_model, + ) + except LLMBadRequestError: + raise ValueError( + "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider." 
+ ) + except ProviderTokenNotInitError as ex: + raise ValueError(f"The dataset in unavailable, due to: {ex.description}") + + @staticmethod + def update_dataset(dataset_id, data, user): + dataset = DatasetService.get_dataset(dataset_id) + if not dataset: + raise ValueError("Dataset not found") + + DatasetService.check_dataset_permission(dataset, user) + if dataset.provider == "external": + external_retrieval_model = data.get("external_retrieval_model", None) + if external_retrieval_model: + dataset.retrieval_model = external_retrieval_model + dataset.name = data.get("name", dataset.name) + dataset.description = data.get("description", "") + permission = data.get("permission") + if permission: + dataset.permission = permission + external_knowledge_id = data.get("external_knowledge_id", None) + db.session.add(dataset) + if not external_knowledge_id: + raise ValueError("External knowledge id is required.") + external_knowledge_api_id = data.get("external_knowledge_api_id", None) + if not external_knowledge_api_id: + raise ValueError("External knowledge api id is required.") + external_knowledge_binding = ExternalKnowledgeBindings.query.filter_by(dataset_id=dataset_id).first() + if ( + external_knowledge_binding.external_knowledge_id != external_knowledge_id + or external_knowledge_binding.external_knowledge_api_id != external_knowledge_api_id + ): + external_knowledge_binding.external_knowledge_id = external_knowledge_id + external_knowledge_binding.external_knowledge_api_id = external_knowledge_api_id + db.session.add(external_knowledge_binding) + db.session.commit() + else: + data.pop("partial_member_list", None) + data.pop("external_knowledge_api_id", None) + data.pop("external_knowledge_id", None) + data.pop("external_retrieval_model", None) + filtered_data = {k: v for k, v in data.items() if v is not None or k == "description"} + action = None + if dataset.indexing_technique != data["indexing_technique"]: + # if update indexing_technique + if 
data["indexing_technique"] == "economy": + action = "remove" + filtered_data["embedding_model"] = None + filtered_data["embedding_model_provider"] = None + filtered_data["collection_binding_id"] = None + elif data["indexing_technique"] == "high_quality": + action = "add" + # get embedding model setting + try: + model_manager = ModelManager() + embedding_model = model_manager.get_model_instance( + tenant_id=current_user.current_tenant_id, + provider=data["embedding_model_provider"], + model_type=ModelType.TEXT_EMBEDDING, + model=data["embedding_model"], + ) + filtered_data["embedding_model"] = embedding_model.model + filtered_data["embedding_model_provider"] = embedding_model.provider + dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( + embedding_model.provider, embedding_model.model + ) + filtered_data["collection_binding_id"] = dataset_collection_binding.id + except LLMBadRequestError: + raise ValueError( + "No Embedding Model available. Please configure a valid provider " + "in the Settings -> Model Provider." 
+ ) + except ProviderTokenNotInitError as ex: + raise ValueError(ex.description) + else: + if ( + data["embedding_model_provider"] != dataset.embedding_model_provider + or data["embedding_model"] != dataset.embedding_model + ): + action = "update" + try: + model_manager = ModelManager() + embedding_model = model_manager.get_model_instance( + tenant_id=current_user.current_tenant_id, + provider=data["embedding_model_provider"], + model_type=ModelType.TEXT_EMBEDDING, + model=data["embedding_model"], + ) + filtered_data["embedding_model"] = embedding_model.model + filtered_data["embedding_model_provider"] = embedding_model.provider + dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( + embedding_model.provider, embedding_model.model + ) + filtered_data["collection_binding_id"] = dataset_collection_binding.id + except LLMBadRequestError: + raise ValueError( + "No Embedding Model available. Please configure a valid provider " + "in the Settings -> Model Provider." 
+ ) + except ProviderTokenNotInitError as ex: + raise ValueError(ex.description) + + filtered_data["updated_by"] = user.id + filtered_data["updated_at"] = datetime.datetime.now() + + # update Retrieval model + filtered_data["retrieval_model"] = data["retrieval_model"] + + dataset.query.filter_by(id=dataset_id).update(filtered_data) + + db.session.commit() + if action: + deal_dataset_vector_index_task.delay(dataset_id, action) + return dataset + + @staticmethod + def delete_dataset(dataset_id, user): + dataset = DatasetService.get_dataset(dataset_id) + + if dataset is None: + return False + + DatasetService.check_dataset_permission(dataset, user) + + dataset_was_deleted.send(dataset) + + db.session.delete(dataset) + db.session.commit() + return True + + @staticmethod + def dataset_use_check(dataset_id) -> bool: + count = AppDatasetJoin.query.filter_by(dataset_id=dataset_id).count() + if count > 0: + return True + return False + + @staticmethod + def check_dataset_permission(dataset, user): + if dataset.tenant_id != user.current_tenant_id: + logging.debug(f"User {user.id} does not have permission to access dataset {dataset.id}") + raise NoPermissionError("You do not have permission to access this dataset.") + if user.current_role != TenantAccountRole.OWNER: + if dataset.permission == DatasetPermissionEnum.ONLY_ME and dataset.created_by != user.id: + logging.debug(f"User {user.id} does not have permission to access dataset {dataset.id}") + raise NoPermissionError("You do not have permission to access this dataset.") + if dataset.permission == "partial_members": + user_permission = DatasetPermission.query.filter_by(dataset_id=dataset.id, account_id=user.id).first() + if ( + not user_permission + and dataset.tenant_id != user.current_tenant_id + and dataset.created_by != user.id + ): + logging.debug(f"User {user.id} does not have permission to access dataset {dataset.id}") + raise NoPermissionError("You do not have permission to access this dataset.") + + @staticmethod 
+ def check_dataset_operator_permission(user: Optional[Account] = None, dataset: Optional[Dataset] = None): + if not dataset: + raise ValueError("Dataset not found") + + if not user: + raise ValueError("User not found") + + if user.current_role != TenantAccountRole.OWNER: + if dataset.permission == DatasetPermissionEnum.ONLY_ME: + if dataset.created_by != user.id: + raise NoPermissionError("You do not have permission to access this dataset.") + + elif dataset.permission == DatasetPermissionEnum.PARTIAL_TEAM: + if not any( + dp.dataset_id == dataset.id for dp in DatasetPermission.query.filter_by(account_id=user.id).all() + ): + raise NoPermissionError("You do not have permission to access this dataset.") + + @staticmethod + def get_dataset_queries(dataset_id: str, page: int, per_page: int): + dataset_queries = ( + DatasetQuery.query.filter_by(dataset_id=dataset_id) + .order_by(db.desc(DatasetQuery.created_at)) + .paginate(page=page, per_page=per_page, max_per_page=100, error_out=False) + ) + return dataset_queries.items, dataset_queries.total + + @staticmethod + def get_related_apps(dataset_id: str): + return ( + AppDatasetJoin.query.filter(AppDatasetJoin.dataset_id == dataset_id) + .order_by(db.desc(AppDatasetJoin.created_at)) + .all() + ) + + @staticmethod + def get_dataset_auto_disable_logs(dataset_id: str) -> dict: + features = FeatureService.get_features(current_user.current_tenant_id) + if not features.billing.enabled or features.billing.subscription.plan == "sandbox": + return { + "document_ids": [], + "count": 0, + } + # get recent 30 days auto disable logs + start_date = datetime.datetime.now() - datetime.timedelta(days=30) + dataset_auto_disable_logs = DatasetAutoDisableLog.query.filter( + DatasetAutoDisableLog.dataset_id == dataset_id, + DatasetAutoDisableLog.created_at >= start_date, + ).all() + if dataset_auto_disable_logs: + return { + "document_ids": [log.document_id for log in dataset_auto_disable_logs], + "count": len(dataset_auto_disable_logs), + } 
+ return { + "document_ids": [], + "count": 0, + } + + +class DocumentService: + DEFAULT_RULES: dict[str, Any] = { + "mode": "custom", + "rules": { + "pre_processing_rules": [ + {"id": "remove_extra_spaces", "enabled": True}, + {"id": "remove_urls_emails", "enabled": False}, + ], + "segmentation": {"delimiter": "\n", "max_tokens": 500, "chunk_overlap": 50}, + }, + "limits": { + "indexing_max_segmentation_tokens_length": dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH, + }, + } + + DOCUMENT_METADATA_SCHEMA: dict[str, Any] = { + "book": { + "title": str, + "language": str, + "author": str, + "publisher": str, + "publication_date": str, + "isbn": str, + "category": str, + }, + "web_page": { + "title": str, + "url": str, + "language": str, + "publish_date": str, + "author/publisher": str, + "topic/keywords": str, + "description": str, + }, + "paper": { + "title": str, + "language": str, + "author": str, + "publish_date": str, + "journal/conference_name": str, + "volume/issue/page_numbers": str, + "doi": str, + "topic/keywords": str, + "abstract": str, + }, + "social_media_post": { + "platform": str, + "author/username": str, + "publish_date": str, + "post_url": str, + "topic/tags": str, + }, + "wikipedia_entry": { + "title": str, + "language": str, + "web_page_url": str, + "last_edit_date": str, + "editor/contributor": str, + "summary/introduction": str, + }, + "personal_document": { + "title": str, + "author": str, + "creation_date": str, + "last_modified_date": str, + "document_type": str, + "tags/category": str, + }, + "business_document": { + "title": str, + "author": str, + "creation_date": str, + "last_modified_date": str, + "document_type": str, + "department/team": str, + }, + "im_chat_log": { + "chat_platform": str, + "chat_participants/group_name": str, + "start_date": str, + "end_date": str, + "summary": str, + }, + "synced_from_notion": { + "title": str, + "language": str, + "author/creator": str, + "creation_date": str, + "last_modified_date": str, + 
"notion_page_link": str, + "category/tags": str, + "description": str, + }, + "synced_from_github": { + "repository_name": str, + "repository_description": str, + "repository_owner/organization": str, + "code_filename": str, + "code_file_path": str, + "programming_language": str, + "github_link": str, + "open_source_license": str, + "commit_date": str, + "commit_author": str, + }, + "others": dict, + } + + @staticmethod + def get_document(dataset_id: str, document_id: Optional[str] = None) -> Optional[Document]: + if document_id: + document = ( + db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first() + ) + return document + else: + return None + + @staticmethod + def get_document_by_id(document_id: str) -> Optional[Document]: + document = db.session.query(Document).filter(Document.id == document_id).first() + + return document + + @staticmethod + def get_document_by_dataset_id(dataset_id: str) -> list[Document]: + documents = db.session.query(Document).filter(Document.dataset_id == dataset_id, Document.enabled == True).all() + + return documents + + @staticmethod + def get_error_documents_by_dataset_id(dataset_id: str) -> list[Document]: + documents = ( + db.session.query(Document) + .filter(Document.dataset_id == dataset_id, Document.indexing_status.in_(["error", "paused"])) + .all() + ) + return documents + + @staticmethod + def get_batch_documents(dataset_id: str, batch: str) -> list[Document]: + documents = ( + db.session.query(Document) + .filter( + Document.batch == batch, + Document.dataset_id == dataset_id, + Document.tenant_id == current_user.current_tenant_id, + ) + .all() + ) + + return documents + + @staticmethod + def get_document_file_detail(file_id: str): + file_detail = db.session.query(UploadFile).filter(UploadFile.id == file_id).one_or_none() + return file_detail + + @staticmethod + def check_archived(document): + if document.archived: + return True + else: + return False + + @staticmethod + def 
delete_document(document): + # trigger document_was_deleted signal + file_id = None + if document.data_source_type == "upload_file": + if document.data_source_info: + data_source_info = document.data_source_info_dict + if data_source_info and "upload_file_id" in data_source_info: + file_id = data_source_info["upload_file_id"] + document_was_deleted.send( + document.id, dataset_id=document.dataset_id, doc_form=document.doc_form, file_id=file_id + ) + + db.session.delete(document) + db.session.commit() + + @staticmethod + def delete_documents(dataset: Dataset, document_ids: list[str]): + documents = db.session.query(Document).filter(Document.id.in_(document_ids)).all() + file_ids = [ + document.data_source_info_dict["upload_file_id"] + for document in documents + if document.data_source_type == "upload_file" + ] + batch_clean_document_task.delay(document_ids, dataset.id, dataset.doc_form, file_ids) + + for document in documents: + db.session.delete(document) + db.session.commit() + + @staticmethod + def rename_document(dataset_id: str, document_id: str, name: str) -> Document: + dataset = DatasetService.get_dataset(dataset_id) + if not dataset: + raise ValueError("Dataset not found.") + + document = DocumentService.get_document(dataset_id, document_id) + + if not document: + raise ValueError("Document not found.") + + if document.tenant_id != current_user.current_tenant_id: + raise ValueError("No permission.") + + document.name = name + + db.session.add(document) + db.session.commit() + + return document + + @staticmethod + def pause_document(document): + if document.indexing_status not in {"waiting", "parsing", "cleaning", "splitting", "indexing"}: + raise DocumentIndexingError() + # update document to be paused + document.is_paused = True + document.paused_by = current_user.id + document.paused_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + + db.session.add(document) + db.session.commit() + # set document paused flag + indexing_cache_key = 
"document_{}_is_paused".format(document.id) + redis_client.setnx(indexing_cache_key, "True") + + @staticmethod + def recover_document(document): + if not document.is_paused: + raise DocumentIndexingError() + # update document to be recover + document.is_paused = False + document.paused_by = None + document.paused_at = None + + db.session.add(document) + db.session.commit() + # delete paused flag + indexing_cache_key = "document_{}_is_paused".format(document.id) + redis_client.delete(indexing_cache_key) + # trigger async task + recover_document_indexing_task.delay(document.dataset_id, document.id) + + @staticmethod + def retry_document(dataset_id: str, documents: list[Document]): + for document in documents: + # add retry flag + retry_indexing_cache_key = "document_{}_is_retried".format(document.id) + cache_result = redis_client.get(retry_indexing_cache_key) + if cache_result is not None: + raise ValueError("Document is being retried, please try again later") + # retry document indexing + document.indexing_status = "waiting" + db.session.add(document) + db.session.commit() + + redis_client.setex(retry_indexing_cache_key, 600, 1) + # trigger async task + document_ids = [document.id for document in documents] + retry_document_indexing_task.delay(dataset_id, document_ids) + + @staticmethod + def sync_website_document(dataset_id: str, document: Document): + # add sync flag + sync_indexing_cache_key = "document_{}_is_sync".format(document.id) + cache_result = redis_client.get(sync_indexing_cache_key) + if cache_result is not None: + raise ValueError("Document is being synced, please try again later") + # sync document indexing + document.indexing_status = "waiting" + data_source_info = document.data_source_info_dict + data_source_info["mode"] = "scrape" + document.data_source_info = json.dumps(data_source_info, ensure_ascii=False) + db.session.add(document) + db.session.commit() + + redis_client.setex(sync_indexing_cache_key, 600, 1) + + 
sync_website_document_indexing_task.delay(dataset_id, document.id) + + @staticmethod + def get_documents_position(dataset_id): + document = Document.query.filter_by(dataset_id=dataset_id).order_by(Document.position.desc()).first() + if document: + return document.position + 1 + else: + return 1 + + @staticmethod + def save_document_with_dataset_id( + dataset: Dataset, + knowledge_config: KnowledgeConfig, + account: Account | Any, + dataset_process_rule: Optional[DatasetProcessRule] = None, + created_from: str = "web", + ): + # check document limit + features = FeatureService.get_features(current_user.current_tenant_id) + + if features.billing.enabled: + if not knowledge_config.original_document_id: + count = 0 + if knowledge_config.data_source: + if knowledge_config.data_source.info_list.data_source_type == "upload_file": + upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids # type: ignore + count = len(upload_file_list) + elif knowledge_config.data_source.info_list.data_source_type == "notion_import": + notion_info_list = knowledge_config.data_source.info_list.notion_info_list + for notion_info in notion_info_list: # type: ignore + count = count + len(notion_info.pages) + elif knowledge_config.data_source.info_list.data_source_type == "website_crawl": + website_info = knowledge_config.data_source.info_list.website_info_list + count = len(website_info.urls) # type: ignore + batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT) + if count > batch_upload_limit: + raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.") + + DocumentService.check_documents_upload_quota(count, features) + + # if dataset is empty, update dataset data_source_type + if not dataset.data_source_type: + dataset.data_source_type = knowledge_config.data_source.info_list.data_source_type # type: ignore + + if not dataset.indexing_technique: + if knowledge_config.indexing_technique not in Dataset.INDEXING_TECHNIQUE_LIST: + raise 
ValueError("Indexing technique is invalid") + + dataset.indexing_technique = knowledge_config.indexing_technique + if knowledge_config.indexing_technique == "high_quality": + model_manager = ModelManager() + if knowledge_config.embedding_model and knowledge_config.embedding_model_provider: + dataset_embedding_model = knowledge_config.embedding_model + dataset_embedding_model_provider = knowledge_config.embedding_model_provider + else: + embedding_model = model_manager.get_default_model_instance( + tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING + ) + dataset_embedding_model = embedding_model.model + dataset_embedding_model_provider = embedding_model.provider + dataset.embedding_model = dataset_embedding_model + dataset.embedding_model_provider = dataset_embedding_model_provider + dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( + dataset_embedding_model_provider, dataset_embedding_model + ) + dataset.collection_binding_id = dataset_collection_binding.id + if not dataset.retrieval_model: + default_retrieval_model = { + "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, + "reranking_enable": False, + "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, + "top_k": 2, + "score_threshold_enabled": False, + } + + dataset.retrieval_model = ( + knowledge_config.retrieval_model.model_dump() + if knowledge_config.retrieval_model + else default_retrieval_model + ) # type: ignore + + documents = [] + if knowledge_config.original_document_id: + document = DocumentService.update_document_with_dataset_id(dataset, knowledge_config, account) + documents.append(document) + batch = document.batch + else: + batch = time.strftime("%Y%m%d%H%M%S") + str(random.randint(100000, 999999)) + # save process rule + if not dataset_process_rule: + process_rule = knowledge_config.process_rule + if process_rule: + if process_rule.mode in ("custom", "hierarchical"): + dataset_process_rule = 
DatasetProcessRule( + dataset_id=dataset.id, + mode=process_rule.mode, + rules=process_rule.rules.model_dump_json() if process_rule.rules else None, + created_by=account.id, + ) + elif process_rule.mode == "automatic": + dataset_process_rule = DatasetProcessRule( + dataset_id=dataset.id, + mode=process_rule.mode, + rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES), + created_by=account.id, + ) + else: + logging.warn( + f"Invalid process rule mode: {process_rule.mode}, can not find dataset process rule" + ) + return + db.session.add(dataset_process_rule) + db.session.commit() + lock_name = "add_document_lock_dataset_id_{}".format(dataset.id) + with redis_client.lock(lock_name, timeout=600): + position = DocumentService.get_documents_position(dataset.id) + document_ids = [] + duplicate_document_ids = [] + if knowledge_config.data_source.info_list.data_source_type == "upload_file": # type: ignore + upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids # type: ignore + for file_id in upload_file_list: + file = ( + db.session.query(UploadFile) + .filter(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id) + .first() + ) + + # raise error if file not found + if not file: + raise FileNotExistsError() + + file_name = file.name + data_source_info = { + "upload_file_id": file_id, + } + # check duplicate + if knowledge_config.duplicate: + document = Document.query.filter_by( + dataset_id=dataset.id, + tenant_id=current_user.current_tenant_id, + data_source_type="upload_file", + enabled=True, + name=file_name, + ).first() + if document: + document.dataset_process_rule_id = dataset_process_rule.id # type: ignore + document.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + document.created_from = created_from + document.doc_form = knowledge_config.doc_form + document.doc_language = knowledge_config.doc_language + document.data_source_info = json.dumps(data_source_info) + document.batch = batch + 
document.indexing_status = "waiting" + if knowledge_config.metadata: + document.doc_type = knowledge_config.metadata.doc_type + document.metadata = knowledge_config.metadata.doc_metadata + db.session.add(document) + documents.append(document) + duplicate_document_ids.append(document.id) + continue + document = DocumentService.build_document( + dataset, + dataset_process_rule.id, # type: ignore + knowledge_config.data_source.info_list.data_source_type, # type: ignore + knowledge_config.doc_form, + knowledge_config.doc_language, + data_source_info, + created_from, + position, + account, + file_name, + batch, + knowledge_config.metadata, + ) + db.session.add(document) + db.session.flush() + document_ids.append(document.id) + documents.append(document) + position += 1 + elif knowledge_config.data_source.info_list.data_source_type == "notion_import": # type: ignore + notion_info_list = knowledge_config.data_source.info_list.notion_info_list # type: ignore + if not notion_info_list: + raise ValueError("No notion info list found.") + exist_page_ids = [] + exist_document = {} + documents = Document.query.filter_by( + dataset_id=dataset.id, + tenant_id=current_user.current_tenant_id, + data_source_type="notion_import", + enabled=True, + ).all() + if documents: + for document in documents: + data_source_info = json.loads(document.data_source_info) + exist_page_ids.append(data_source_info["notion_page_id"]) + exist_document[data_source_info["notion_page_id"]] = document.id + for notion_info in notion_info_list: + workspace_id = notion_info.workspace_id + data_source_binding = DataSourceOauthBinding.query.filter( + db.and_( + DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, + DataSourceOauthBinding.provider == "notion", + DataSourceOauthBinding.disabled == False, + DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"', + ) + ).first() + if not data_source_binding: + raise ValueError("Data source binding not found.") + for page in 
notion_info.pages: + if page.page_id not in exist_page_ids: + data_source_info = { + "notion_workspace_id": workspace_id, + "notion_page_id": page.page_id, + "notion_page_icon": page.page_icon.model_dump() if page.page_icon else None, + "type": page.type, + } + document = DocumentService.build_document( + dataset, + dataset_process_rule.id, # type: ignore + knowledge_config.data_source.info_list.data_source_type, # type: ignore + knowledge_config.doc_form, + knowledge_config.doc_language, + data_source_info, + created_from, + position, + account, + page.page_name, + batch, + knowledge_config.metadata, + ) + db.session.add(document) + db.session.flush() + document_ids.append(document.id) + documents.append(document) + position += 1 + else: + exist_document.pop(page.page_id) + # delete not selected documents + if len(exist_document) > 0: + clean_notion_document_task.delay(list(exist_document.values()), dataset.id) + elif knowledge_config.data_source.info_list.data_source_type == "website_crawl": # type: ignore + website_info = knowledge_config.data_source.info_list.website_info_list # type: ignore + if not website_info: + raise ValueError("No website info list found.") + urls = website_info.urls + for url in urls: + data_source_info = { + "url": url, + "provider": website_info.provider, + "job_id": website_info.job_id, + "only_main_content": website_info.only_main_content, + "mode": "crawl", + } + if len(url) > 255: + document_name = url[:200] + "..." 
+ else: + document_name = url + document = DocumentService.build_document( + dataset, + dataset_process_rule.id, # type: ignore + knowledge_config.data_source.info_list.data_source_type, # type: ignore + knowledge_config.doc_form, + knowledge_config.doc_language, + data_source_info, + created_from, + position, + account, + document_name, + batch, + knowledge_config.metadata, + ) + db.session.add(document) + db.session.flush() + document_ids.append(document.id) + documents.append(document) + position += 1 + db.session.commit() + + # trigger async task + if document_ids: + document_indexing_task.delay(dataset.id, document_ids) + if duplicate_document_ids: + duplicate_document_indexing_task.delay(dataset.id, duplicate_document_ids) + + return documents, batch + + @staticmethod + def check_documents_upload_quota(count: int, features: FeatureModel): + can_upload_size = features.documents_upload_quota.limit - features.documents_upload_quota.size + if count > can_upload_size: + raise ValueError( + f"You have reached the limit of your subscription. Only {can_upload_size} documents can be uploaded." 
+ ) + + @staticmethod + def build_document( + dataset: Dataset, + process_rule_id: str, + data_source_type: str, + document_form: str, + document_language: str, + data_source_info: dict, + created_from: str, + position: int, + account: Account, + name: str, + batch: str, + metadata: Optional[MetaDataConfig] = None, + ): + document = Document( + tenant_id=dataset.tenant_id, + dataset_id=dataset.id, + position=position, + data_source_type=data_source_type, + data_source_info=json.dumps(data_source_info), + dataset_process_rule_id=process_rule_id, + batch=batch, + name=name, + created_from=created_from, + created_by=account.id, + doc_form=document_form, + doc_language=document_language, + ) + if metadata is not None: + document.doc_metadata = metadata.doc_metadata + document.doc_type = metadata.doc_type + return document + + @staticmethod + def get_tenant_documents_count(): + documents_count = Document.query.filter( + Document.completed_at.isnot(None), + Document.enabled == True, + Document.archived == False, + Document.tenant_id == current_user.current_tenant_id, + ).count() + return documents_count + + @staticmethod + def update_document_with_dataset_id( + dataset: Dataset, + document_data: KnowledgeConfig, + account: Account, + dataset_process_rule: Optional[DatasetProcessRule] = None, + created_from: str = "web", + ): + DatasetService.check_dataset_model_setting(dataset) + document = DocumentService.get_document(dataset.id, document_data.original_document_id) + if document is None: + raise NotFound("Document not found") + if document.display_status != "available": + raise ValueError("Document is not available") + # save process rule + if document_data.process_rule: + process_rule = document_data.process_rule + if process_rule.mode in {"custom", "hierarchical"}: + dataset_process_rule = DatasetProcessRule( + dataset_id=dataset.id, + mode=process_rule.mode, + rules=process_rule.rules.model_dump_json() if process_rule.rules else None, + created_by=account.id, + ) + 
elif process_rule.mode == "automatic": + dataset_process_rule = DatasetProcessRule( + dataset_id=dataset.id, + mode=process_rule.mode, + rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES), + created_by=account.id, + ) + if dataset_process_rule is not None: + db.session.add(dataset_process_rule) + db.session.commit() + document.dataset_process_rule_id = dataset_process_rule.id + # update document data source + if document_data.data_source: + file_name = "" + data_source_info = {} + if document_data.data_source.info_list.data_source_type == "upload_file": + if not document_data.data_source.info_list.file_info_list: + raise ValueError("No file info list found.") + upload_file_list = document_data.data_source.info_list.file_info_list.file_ids + for file_id in upload_file_list: + file = ( + db.session.query(UploadFile) + .filter(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id) + .first() + ) + + # raise error if file not found + if not file: + raise FileNotExistsError() + + file_name = file.name + data_source_info = { + "upload_file_id": file_id, + } + elif document_data.data_source.info_list.data_source_type == "notion_import": + if not document_data.data_source.info_list.notion_info_list: + raise ValueError("No notion info list found.") + notion_info_list = document_data.data_source.info_list.notion_info_list + for notion_info in notion_info_list: + workspace_id = notion_info.workspace_id + data_source_binding = DataSourceOauthBinding.query.filter( + db.and_( + DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, + DataSourceOauthBinding.provider == "notion", + DataSourceOauthBinding.disabled == False, + DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"', + ) + ).first() + if not data_source_binding: + raise ValueError("Data source binding not found.") + for page in notion_info.pages: + data_source_info = { + "notion_workspace_id": workspace_id, + "notion_page_id": page.page_id, + "notion_page_icon": 
page.page_icon.model_dump() if page.page_icon else None, # type: ignore + "type": page.type, + } + elif document_data.data_source.info_list.data_source_type == "website_crawl": + website_info = document_data.data_source.info_list.website_info_list + if website_info: + urls = website_info.urls + for url in urls: + data_source_info = { + "url": url, + "provider": website_info.provider, + "job_id": website_info.job_id, + "only_main_content": website_info.only_main_content, # type: ignore + "mode": "crawl", + } + document.data_source_type = document_data.data_source.info_list.data_source_type + document.data_source_info = json.dumps(data_source_info) + document.name = file_name + + # update document name + if document_data.name: + document.name = document_data.name + # update doc_type and doc_metadata if provided + if document_data.metadata is not None: + document.doc_metadata = document_data.metadata.doc_type + document.doc_type = document_data.metadata.doc_type + # update document to be waiting + document.indexing_status = "waiting" + document.completed_at = None + document.processing_started_at = None + document.parsing_completed_at = None + document.cleaning_completed_at = None + document.splitting_completed_at = None + document.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + document.created_from = created_from + document.doc_form = document_data.doc_form + db.session.add(document) + db.session.commit() + # update document segment + update_params = {DocumentSegment.status: "re_segment"} + DocumentSegment.query.filter_by(document_id=document.id).update(update_params) + db.session.commit() + # trigger async task + document_indexing_update_task.delay(document.dataset_id, document.id) + return document + + @staticmethod + def save_document_without_dataset_id(tenant_id: str, knowledge_config: KnowledgeConfig, account: Account): + features = FeatureService.get_features(current_user.current_tenant_id) + + if features.billing.enabled: + count = 0 + 
if knowledge_config.data_source.info_list.data_source_type == "upload_file": # type: ignore + upload_file_list = ( + knowledge_config.data_source.info_list.file_info_list.file_ids # type: ignore + if knowledge_config.data_source.info_list.file_info_list # type: ignore + else [] + ) + count = len(upload_file_list) + elif knowledge_config.data_source.info_list.data_source_type == "notion_import": # type: ignore + notion_info_list = knowledge_config.data_source.info_list.notion_info_list # type: ignore + if notion_info_list: + for notion_info in notion_info_list: + count = count + len(notion_info.pages) + elif knowledge_config.data_source.info_list.data_source_type == "website_crawl": # type: ignore + website_info = knowledge_config.data_source.info_list.website_info_list # type: ignore + if website_info: + count = len(website_info.urls) + batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT) + if count > batch_upload_limit: + raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.") + + DocumentService.check_documents_upload_quota(count, features) + + dataset_collection_binding_id = None + retrieval_model = None + if knowledge_config.indexing_technique == "high_quality": + dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( + knowledge_config.embedding_model_provider, # type: ignore + knowledge_config.embedding_model, # type: ignore + ) + dataset_collection_binding_id = dataset_collection_binding.id + if knowledge_config.retrieval_model: + retrieval_model = knowledge_config.retrieval_model + else: + retrieval_model = RetrievalModel( + search_method=RetrievalMethod.SEMANTIC_SEARCH.value, + reranking_enable=False, + reranking_model=RerankingModel(reranking_provider_name="", reranking_model_name=""), + top_k=2, + score_threshold_enabled=False, + ) + # save dataset + dataset = Dataset( + tenant_id=tenant_id, + name="", + data_source_type=knowledge_config.data_source.info_list.data_source_type, # 
type: ignore + indexing_technique=knowledge_config.indexing_technique, + created_by=account.id, + embedding_model=knowledge_config.embedding_model, + embedding_model_provider=knowledge_config.embedding_model_provider, + collection_binding_id=dataset_collection_binding_id, + retrieval_model=retrieval_model.model_dump() if retrieval_model else None, + ) + + db.session.add(dataset) # type: ignore + db.session.flush() + + documents, batch = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) + + cut_length = 18 + cut_name = documents[0].name[:cut_length] + dataset.name = cut_name + "..." + dataset.description = "useful for when you want to answer queries about the " + documents[0].name + db.session.commit() + + return dataset, documents, batch + + @classmethod + def document_create_args_validate(cls, knowledge_config: KnowledgeConfig): + if not knowledge_config.data_source and not knowledge_config.process_rule: + raise ValueError("Data source or Process rule is required") + else: + if knowledge_config.data_source: + DocumentService.data_source_args_validate(knowledge_config) + if knowledge_config.process_rule: + DocumentService.process_rule_args_validate(knowledge_config) + + @classmethod + def data_source_args_validate(cls, knowledge_config: KnowledgeConfig): + if not knowledge_config.data_source: + raise ValueError("Data source is required") + + if knowledge_config.data_source.info_list.data_source_type not in Document.DATA_SOURCES: + raise ValueError("Data source type is invalid") + + if not knowledge_config.data_source.info_list: + raise ValueError("Data source info is required") + + if knowledge_config.data_source.info_list.data_source_type == "upload_file": + if not knowledge_config.data_source.info_list.file_info_list: + raise ValueError("File source info is required") + if knowledge_config.data_source.info_list.data_source_type == "notion_import": + if not knowledge_config.data_source.info_list.notion_info_list: + raise 
ValueError("Notion source info is required") + if knowledge_config.data_source.info_list.data_source_type == "website_crawl": + if not knowledge_config.data_source.info_list.website_info_list: + raise ValueError("Website source info is required") + + @classmethod + def process_rule_args_validate(cls, knowledge_config: KnowledgeConfig): + if not knowledge_config.process_rule: + raise ValueError("Process rule is required") + + if not knowledge_config.process_rule.mode: + raise ValueError("Process rule mode is required") + + if knowledge_config.process_rule.mode not in DatasetProcessRule.MODES: + raise ValueError("Process rule mode is invalid") + + if knowledge_config.process_rule.mode == "automatic": + knowledge_config.process_rule.rules = None + else: + if not knowledge_config.process_rule.rules: + raise ValueError("Process rule rules is required") + + if knowledge_config.process_rule.rules.pre_processing_rules is None: + raise ValueError("Process rule pre_processing_rules is required") + + unique_pre_processing_rule_dicts = {} + for pre_processing_rule in knowledge_config.process_rule.rules.pre_processing_rules: + if not pre_processing_rule.id: + raise ValueError("Process rule pre_processing_rules id is required") + + if not isinstance(pre_processing_rule.enabled, bool): + raise ValueError("Process rule pre_processing_rules enabled is invalid") + + unique_pre_processing_rule_dicts[pre_processing_rule.id] = pre_processing_rule + + knowledge_config.process_rule.rules.pre_processing_rules = list(unique_pre_processing_rule_dicts.values()) + + if not knowledge_config.process_rule.rules.segmentation: + raise ValueError("Process rule segmentation is required") + + if not knowledge_config.process_rule.rules.segmentation.separator: + raise ValueError("Process rule segmentation separator is required") + + if not isinstance(knowledge_config.process_rule.rules.segmentation.separator, str): + raise ValueError("Process rule segmentation separator is invalid") + + if not ( + 
knowledge_config.process_rule.mode == "hierarchical" + and knowledge_config.process_rule.rules.parent_mode == "full-doc" + ): + if not knowledge_config.process_rule.rules.segmentation.max_tokens: + raise ValueError("Process rule segmentation max_tokens is required") + + if not isinstance(knowledge_config.process_rule.rules.segmentation.max_tokens, int): + raise ValueError("Process rule segmentation max_tokens is invalid") + + @classmethod + def estimate_args_validate(cls, args: dict): + if "info_list" not in args or not args["info_list"]: + raise ValueError("Data source info is required") + + if not isinstance(args["info_list"], dict): + raise ValueError("Data info is invalid") + + if "process_rule" not in args or not args["process_rule"]: + raise ValueError("Process rule is required") + + if not isinstance(args["process_rule"], dict): + raise ValueError("Process rule is invalid") + + if "mode" not in args["process_rule"] or not args["process_rule"]["mode"]: + raise ValueError("Process rule mode is required") + + if args["process_rule"]["mode"] not in DatasetProcessRule.MODES: + raise ValueError("Process rule mode is invalid") + + if args["process_rule"]["mode"] == "automatic": + args["process_rule"]["rules"] = {} + else: + if "rules" not in args["process_rule"] or not args["process_rule"]["rules"]: + raise ValueError("Process rule rules is required") + + if not isinstance(args["process_rule"]["rules"], dict): + raise ValueError("Process rule rules is invalid") + + if ( + "pre_processing_rules" not in args["process_rule"]["rules"] + or args["process_rule"]["rules"]["pre_processing_rules"] is None + ): + raise ValueError("Process rule pre_processing_rules is required") + + if not isinstance(args["process_rule"]["rules"]["pre_processing_rules"], list): + raise ValueError("Process rule pre_processing_rules is invalid") + + unique_pre_processing_rule_dicts = {} + for pre_processing_rule in args["process_rule"]["rules"]["pre_processing_rules"]: + if "id" not in 
pre_processing_rule or not pre_processing_rule["id"]: + raise ValueError("Process rule pre_processing_rules id is required") + + if pre_processing_rule["id"] not in DatasetProcessRule.PRE_PROCESSING_RULES: + raise ValueError("Process rule pre_processing_rules id is invalid") + + if "enabled" not in pre_processing_rule or pre_processing_rule["enabled"] is None: + raise ValueError("Process rule pre_processing_rules enabled is required") + + if not isinstance(pre_processing_rule["enabled"], bool): + raise ValueError("Process rule pre_processing_rules enabled is invalid") + + unique_pre_processing_rule_dicts[pre_processing_rule["id"]] = pre_processing_rule + + args["process_rule"]["rules"]["pre_processing_rules"] = list(unique_pre_processing_rule_dicts.values()) + + if ( + "segmentation" not in args["process_rule"]["rules"] + or args["process_rule"]["rules"]["segmentation"] is None + ): + raise ValueError("Process rule segmentation is required") + + if not isinstance(args["process_rule"]["rules"]["segmentation"], dict): + raise ValueError("Process rule segmentation is invalid") + + if ( + "separator" not in args["process_rule"]["rules"]["segmentation"] + or not args["process_rule"]["rules"]["segmentation"]["separator"] + ): + raise ValueError("Process rule segmentation separator is required") + + if not isinstance(args["process_rule"]["rules"]["segmentation"]["separator"], str): + raise ValueError("Process rule segmentation separator is invalid") + + if ( + "max_tokens" not in args["process_rule"]["rules"]["segmentation"] + or not args["process_rule"]["rules"]["segmentation"]["max_tokens"] + ): + raise ValueError("Process rule segmentation max_tokens is required") + + if not isinstance(args["process_rule"]["rules"]["segmentation"]["max_tokens"], int): + raise ValueError("Process rule segmentation max_tokens is invalid") + + +class SegmentService: + @classmethod + def segment_create_args_validate(cls, args: dict, document: Document): + if document.doc_form == 
"qa_model": + if "answer" not in args or not args["answer"]: + raise ValueError("Answer is required") + if not args["answer"].strip(): + raise ValueError("Answer is empty") + if "content" not in args or not args["content"] or not args["content"].strip(): + raise ValueError("Content is empty") + + @classmethod + def create_segment(cls, args: dict, document: Document, dataset: Dataset): + content = args["content"] + doc_id = str(uuid.uuid4()) + segment_hash = helper.generate_text_hash(content) + tokens = 0 + if dataset.indexing_technique == "high_quality": + model_manager = ModelManager() + embedding_model = model_manager.get_model_instance( + tenant_id=current_user.current_tenant_id, + provider=dataset.embedding_model_provider, + model_type=ModelType.TEXT_EMBEDDING, + model=dataset.embedding_model, + ) + # calc embedding use tokens + tokens = embedding_model.get_text_embedding_num_tokens(texts=[content]) + lock_name = "add_segment_lock_document_id_{}".format(document.id) + with redis_client.lock(lock_name, timeout=600): + max_position = ( + db.session.query(func.max(DocumentSegment.position)) + .filter(DocumentSegment.document_id == document.id) + .scalar() + ) + segment_document = DocumentSegment( + tenant_id=current_user.current_tenant_id, + dataset_id=document.dataset_id, + document_id=document.id, + index_node_id=doc_id, + index_node_hash=segment_hash, + position=max_position + 1 if max_position else 1, + content=content, + word_count=len(content), + tokens=tokens, + status="completed", + indexing_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None), + completed_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None), + created_by=current_user.id, + ) + if document.doc_form == "qa_model": + segment_document.word_count += len(args["answer"]) + segment_document.answer = args["answer"] + + db.session.add(segment_document) + # update document word count + document.word_count += segment_document.word_count + db.session.add(document) + db.session.commit() + 
+ # save vector index + try: + VectorService.create_segments_vector([args["keywords"]], [segment_document], dataset, document.doc_form) + except Exception as e: + logging.exception("create segment index failed") + segment_document.enabled = False + segment_document.disabled_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + segment_document.status = "error" + segment_document.error = str(e) + db.session.commit() + segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment_document.id).first() + return segment + + @classmethod + def multi_create_segment(cls, segments: list, document: Document, dataset: Dataset): + lock_name = "multi_add_segment_lock_document_id_{}".format(document.id) + increment_word_count = 0 + with redis_client.lock(lock_name, timeout=600): + embedding_model = None + if dataset.indexing_technique == "high_quality": + model_manager = ModelManager() + embedding_model = model_manager.get_model_instance( + tenant_id=current_user.current_tenant_id, + provider=dataset.embedding_model_provider, + model_type=ModelType.TEXT_EMBEDDING, + model=dataset.embedding_model, + ) + max_position = ( + db.session.query(func.max(DocumentSegment.position)) + .filter(DocumentSegment.document_id == document.id) + .scalar() + ) + pre_segment_data_list = [] + segment_data_list = [] + keywords_list = [] + position = max_position + 1 if max_position else 1 + for segment_item in segments: + content = segment_item["content"] + doc_id = str(uuid.uuid4()) + segment_hash = helper.generate_text_hash(content) + tokens = 0 + if dataset.indexing_technique == "high_quality" and embedding_model: + # calc embedding use tokens + if document.doc_form == "qa_model": + tokens = embedding_model.get_text_embedding_num_tokens(texts=[content + segment_item["answer"]]) + else: + tokens = embedding_model.get_text_embedding_num_tokens(texts=[content]) + segment_document = DocumentSegment( + tenant_id=current_user.current_tenant_id, + 
dataset_id=document.dataset_id, + document_id=document.id, + index_node_id=doc_id, + index_node_hash=segment_hash, + position=position, + content=content, + word_count=len(content), + tokens=tokens, + status="completed", + indexing_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None), + completed_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None), + created_by=current_user.id, + ) + if document.doc_form == "qa_model": + segment_document.answer = segment_item["answer"] + segment_document.word_count += len(segment_item["answer"]) + increment_word_count += segment_document.word_count + db.session.add(segment_document) + segment_data_list.append(segment_document) + position += 1 + + pre_segment_data_list.append(segment_document) + if "keywords" in segment_item: + keywords_list.append(segment_item["keywords"]) + else: + keywords_list.append(None) + # update document word count + document.word_count += increment_word_count + db.session.add(document) + try: + # save vector index + VectorService.create_segments_vector(keywords_list, pre_segment_data_list, dataset, document.doc_form) + except Exception as e: + logging.exception("create segment index failed") + for segment_document in segment_data_list: + segment_document.enabled = False + segment_document.disabled_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + segment_document.status = "error" + segment_document.error = str(e) + db.session.commit() + return segment_data_list + + @classmethod + def update_segment(cls, args: SegmentUpdateArgs, segment: DocumentSegment, document: Document, dataset: Dataset): + indexing_cache_key = "segment_{}_indexing".format(segment.id) + cache_result = redis_client.get(indexing_cache_key) + if cache_result is not None: + raise ValueError("Segment is indexing, please try again later") + if args.enabled is not None: + action = args.enabled + if segment.enabled != action: + if not action: + segment.enabled = action + segment.disabled_at = 
datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + segment.disabled_by = current_user.id + db.session.add(segment) + db.session.commit() + # Set cache to prevent indexing the same segment multiple times + redis_client.setex(indexing_cache_key, 600, 1) + disable_segment_from_index_task.delay(segment.id) + return segment + if not segment.enabled: + if args.enabled is not None: + if not args.enabled: + raise ValueError("Can't update disabled segment") + else: + raise ValueError("Can't update disabled segment") + try: + word_count_change = segment.word_count + content = args.content or segment.content + if segment.content == content: + segment.word_count = len(content) + if document.doc_form == "qa_model": + segment.answer = args.answer + segment.word_count += len(args.answer) if args.answer else 0 + word_count_change = segment.word_count - word_count_change + keyword_changed = False + if args.keywords: + if Counter(segment.keywords) != Counter(args.keywords): + segment.keywords = args.keywords + keyword_changed = True + segment.enabled = True + segment.disabled_at = None + segment.disabled_by = None + db.session.add(segment) + db.session.commit() + # update document word count + if word_count_change != 0: + document.word_count = max(0, document.word_count + word_count_change) + db.session.add(document) + # update segment index task + if document.doc_form == IndexType.PARENT_CHILD_INDEX and args.regenerate_child_chunks: + # regenerate child chunks + # get embedding model instance + if dataset.indexing_technique == "high_quality": + # check embedding model setting + model_manager = ModelManager() + + if dataset.embedding_model_provider: + embedding_model_instance = model_manager.get_model_instance( + tenant_id=dataset.tenant_id, + provider=dataset.embedding_model_provider, + model_type=ModelType.TEXT_EMBEDDING, + model=dataset.embedding_model, + ) + else: + embedding_model_instance = model_manager.get_default_model_instance( + tenant_id=dataset.tenant_id, + 
model_type=ModelType.TEXT_EMBEDDING, + ) + else: + raise ValueError("The knowledge base index technique is not high quality!") + # get the process rule + processing_rule = ( + db.session.query(DatasetProcessRule) + .filter(DatasetProcessRule.id == document.dataset_process_rule_id) + .first() + ) + if not processing_rule: + raise ValueError("No processing rule found.") + VectorService.generate_child_chunks( + segment, document, dataset, embedding_model_instance, processing_rule, True + ) + elif document.doc_form in (IndexType.PARAGRAPH_INDEX, IndexType.QA_INDEX): + if args.enabled or keyword_changed: + VectorService.create_segments_vector( + [args.keywords] if args.keywords else None, + [segment], + dataset, + document.doc_form, + ) + else: + segment_hash = helper.generate_text_hash(content) + tokens = 0 + if dataset.indexing_technique == "high_quality": + model_manager = ModelManager() + embedding_model = model_manager.get_model_instance( + tenant_id=current_user.current_tenant_id, + provider=dataset.embedding_model_provider, + model_type=ModelType.TEXT_EMBEDDING, + model=dataset.embedding_model, + ) + + # calc embedding use tokens + if document.doc_form == "qa_model": + tokens = embedding_model.get_text_embedding_num_tokens(texts=[content + segment.answer]) + else: + tokens = embedding_model.get_text_embedding_num_tokens(texts=[content]) + segment.content = content + segment.index_node_hash = segment_hash + segment.word_count = len(content) + segment.tokens = tokens + segment.status = "completed" + segment.indexing_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + segment.completed_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + segment.updated_by = current_user.id + segment.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + segment.enabled = True + segment.disabled_at = None + segment.disabled_by = None + if document.doc_form == "qa_model": + segment.answer = args.answer + segment.word_count += len(args.answer) 
if args.answer else 0 + word_count_change = segment.word_count - word_count_change + # update document word count + if word_count_change != 0: + document.word_count = max(0, document.word_count + word_count_change) + db.session.add(document) + db.session.add(segment) + db.session.commit() + if document.doc_form == IndexType.PARENT_CHILD_INDEX and args.regenerate_child_chunks: + # get embedding model instance + if dataset.indexing_technique == "high_quality": + # check embedding model setting + model_manager = ModelManager() + + if dataset.embedding_model_provider: + embedding_model_instance = model_manager.get_model_instance( + tenant_id=dataset.tenant_id, + provider=dataset.embedding_model_provider, + model_type=ModelType.TEXT_EMBEDDING, + model=dataset.embedding_model, + ) + else: + embedding_model_instance = model_manager.get_default_model_instance( + tenant_id=dataset.tenant_id, + model_type=ModelType.TEXT_EMBEDDING, + ) + else: + raise ValueError("The knowledge base index technique is not high quality!") + # get the process rule + processing_rule = ( + db.session.query(DatasetProcessRule) + .filter(DatasetProcessRule.id == document.dataset_process_rule_id) + .first() + ) + if not processing_rule: + raise ValueError("No processing rule found.") + VectorService.generate_child_chunks( + segment, document, dataset, embedding_model_instance, processing_rule, True + ) + elif document.doc_form in (IndexType.PARAGRAPH_INDEX, IndexType.QA_INDEX): + # update segment vector index + VectorService.update_segment_vector(args.keywords, segment, dataset) + + except Exception as e: + logging.exception("update segment index failed") + segment.enabled = False + segment.disabled_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + segment.status = "error" + segment.error = str(e) + db.session.commit() + new_segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment.id).first() + return new_segment + + @classmethod + def delete_segment(cls, 
segment: DocumentSegment, document: Document, dataset: Dataset): + indexing_cache_key = "segment_{}_delete_indexing".format(segment.id) + cache_result = redis_client.get(indexing_cache_key) + if cache_result is not None: + raise ValueError("Segment is deleting.") + + # enabled segment need to delete index + if segment.enabled: + # send delete segment index task + redis_client.setex(indexing_cache_key, 600, 1) + delete_segment_from_index_task.delay([segment.index_node_id], dataset.id, document.id) + db.session.delete(segment) + # update document word count + document.word_count -= segment.word_count + db.session.add(document) + db.session.commit() + + @classmethod + def delete_segments(cls, segment_ids: list, document: Document, dataset: Dataset): + index_node_ids = ( + DocumentSegment.query.with_entities(DocumentSegment.index_node_id) + .filter( + DocumentSegment.id.in_(segment_ids), + DocumentSegment.dataset_id == dataset.id, + DocumentSegment.document_id == document.id, + DocumentSegment.tenant_id == current_user.current_tenant_id, + ) + .all() + ) + index_node_ids = [index_node_id[0] for index_node_id in index_node_ids] + + delete_segment_from_index_task.delay(index_node_ids, dataset.id, document.id) + db.session.query(DocumentSegment).filter(DocumentSegment.id.in_(segment_ids)).delete() + db.session.commit() + + @classmethod + def update_segments_status(cls, segment_ids: list, action: str, dataset: Dataset, document: Document): + if action == "enable": + segments = ( + db.session.query(DocumentSegment) + .filter( + DocumentSegment.id.in_(segment_ids), + DocumentSegment.dataset_id == dataset.id, + DocumentSegment.document_id == document.id, + DocumentSegment.enabled == False, + ) + .all() + ) + if not segments: + return + real_deal_segmment_ids = [] + for segment in segments: + indexing_cache_key = "segment_{}_indexing".format(segment.id) + cache_result = redis_client.get(indexing_cache_key) + if cache_result is not None: + continue + segment.enabled = True + 
segment.disabled_at = None + segment.disabled_by = None + db.session.add(segment) + real_deal_segmment_ids.append(segment.id) + db.session.commit() + + enable_segments_to_index_task.delay(real_deal_segmment_ids, dataset.id, document.id) + elif action == "disable": + segments = ( + db.session.query(DocumentSegment) + .filter( + DocumentSegment.id.in_(segment_ids), + DocumentSegment.dataset_id == dataset.id, + DocumentSegment.document_id == document.id, + DocumentSegment.enabled == True, + ) + .all() + ) + if not segments: + return + real_deal_segmment_ids = [] + for segment in segments: + indexing_cache_key = "segment_{}_indexing".format(segment.id) + cache_result = redis_client.get(indexing_cache_key) + if cache_result is not None: + continue + segment.enabled = False + segment.disabled_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + segment.disabled_by = current_user.id + db.session.add(segment) + real_deal_segmment_ids.append(segment.id) + db.session.commit() + + disable_segments_from_index_task.delay(real_deal_segmment_ids, dataset.id, document.id) + else: + raise InvalidActionError() + + @classmethod + def create_child_chunk( + cls, content: str, segment: DocumentSegment, document: Document, dataset: Dataset + ) -> ChildChunk: + lock_name = "add_child_lock_{}".format(segment.id) + with redis_client.lock(lock_name, timeout=20): + index_node_id = str(uuid.uuid4()) + index_node_hash = helper.generate_text_hash(content) + child_chunk_count = ( + db.session.query(ChildChunk) + .filter( + ChildChunk.tenant_id == current_user.current_tenant_id, + ChildChunk.dataset_id == dataset.id, + ChildChunk.document_id == document.id, + ChildChunk.segment_id == segment.id, + ) + .count() + ) + max_position = ( + db.session.query(func.max(ChildChunk.position)) + .filter( + ChildChunk.tenant_id == current_user.current_tenant_id, + ChildChunk.dataset_id == dataset.id, + ChildChunk.document_id == document.id, + ChildChunk.segment_id == segment.id, + ) + 
.scalar() + ) + child_chunk = ChildChunk( + tenant_id=current_user.current_tenant_id, + dataset_id=dataset.id, + document_id=document.id, + segment_id=segment.id, + position=max_position + 1, + index_node_id=index_node_id, + index_node_hash=index_node_hash, + content=content, + word_count=len(content), + type="customized", + created_by=current_user.id, + ) + db.session.add(child_chunk) + # save vector index + try: + VectorService.create_child_chunk_vector(child_chunk, dataset) + except Exception as e: + logging.exception("create child chunk index failed") + db.session.rollback() + raise ChildChunkIndexingError(str(e)) + db.session.commit() + + return child_chunk + + @classmethod + def update_child_chunks( + cls, + child_chunks_update_args: list[ChildChunkUpdateArgs], + segment: DocumentSegment, + document: Document, + dataset: Dataset, + ) -> list[ChildChunk]: + child_chunks = ( + db.session.query(ChildChunk) + .filter( + ChildChunk.dataset_id == dataset.id, + ChildChunk.document_id == document.id, + ChildChunk.segment_id == segment.id, + ) + .all() + ) + child_chunks_map = {chunk.id: chunk for chunk in child_chunks} + + new_child_chunks, update_child_chunks, delete_child_chunks, new_child_chunks_args = [], [], [], [] + + for child_chunk_update_args in child_chunks_update_args: + if child_chunk_update_args.id: + child_chunk = child_chunks_map.pop(child_chunk_update_args.id, None) + if child_chunk: + if child_chunk.content != child_chunk_update_args.content: + child_chunk.content = child_chunk_update_args.content + child_chunk.word_count = len(child_chunk.content) + child_chunk.updated_by = current_user.id + child_chunk.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + child_chunk.type = "customized" + update_child_chunks.append(child_chunk) + else: + new_child_chunks_args.append(child_chunk_update_args) + if child_chunks_map: + delete_child_chunks = list(child_chunks_map.values()) + try: + if update_child_chunks: + 
db.session.bulk_save_objects(update_child_chunks) + + if delete_child_chunks: + for child_chunk in delete_child_chunks: + db.session.delete(child_chunk) + if new_child_chunks_args: + child_chunk_count = len(child_chunks) + for position, args in enumerate(new_child_chunks_args, start=child_chunk_count + 1): + index_node_id = str(uuid.uuid4()) + index_node_hash = helper.generate_text_hash(args.content) + child_chunk = ChildChunk( + tenant_id=current_user.current_tenant_id, + dataset_id=dataset.id, + document_id=document.id, + segment_id=segment.id, + position=position, + index_node_id=index_node_id, + index_node_hash=index_node_hash, + content=args.content, + word_count=len(args.content), + type="customized", + created_by=current_user.id, + ) + + db.session.add(child_chunk) + db.session.flush() + new_child_chunks.append(child_chunk) + VectorService.update_child_chunk_vector(new_child_chunks, update_child_chunks, delete_child_chunks, dataset) + db.session.commit() + except Exception as e: + logging.exception("update child chunk index failed") + db.session.rollback() + raise ChildChunkIndexingError(str(e)) + return sorted(new_child_chunks + update_child_chunks, key=lambda x: x.position) + + @classmethod + def update_child_chunk( + cls, + content: str, + child_chunk: ChildChunk, + segment: DocumentSegment, + document: Document, + dataset: Dataset, + ) -> ChildChunk: + try: + child_chunk.content = content + child_chunk.word_count = len(content) + child_chunk.updated_by = current_user.id + child_chunk.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + child_chunk.type = "customized" + db.session.add(child_chunk) + VectorService.update_child_chunk_vector([], [child_chunk], [], dataset) + db.session.commit() + except Exception as e: + logging.exception("update child chunk index failed") + db.session.rollback() + raise ChildChunkIndexingError(str(e)) + return child_chunk + + @classmethod + def delete_child_chunk(cls, child_chunk: ChildChunk, 
dataset: Dataset): + db.session.delete(child_chunk) + try: + VectorService.delete_child_chunk_vector(child_chunk, dataset) + except Exception as e: + logging.exception("delete child chunk index failed") + db.session.rollback() + raise ChildChunkDeleteIndexError(str(e)) + db.session.commit() + + @classmethod + def get_child_chunks( + cls, segment_id: str, document_id: str, dataset_id: str, page: int, limit: int, keyword: Optional[str] = None + ): + query = ChildChunk.query.filter_by( + tenant_id=current_user.current_tenant_id, + dataset_id=dataset_id, + document_id=document_id, + segment_id=segment_id, + ).order_by(ChildChunk.position.asc()) + if keyword: + query = query.where(ChildChunk.content.ilike(f"%{keyword}%")) + return query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False) + + +class DatasetCollectionBindingService: + @classmethod + def get_dataset_collection_binding( + cls, provider_name: str, model_name: str, collection_type: str = "dataset" + ) -> DatasetCollectionBinding: + dataset_collection_binding = ( + db.session.query(DatasetCollectionBinding) + .filter( + DatasetCollectionBinding.provider_name == provider_name, + DatasetCollectionBinding.model_name == model_name, + DatasetCollectionBinding.type == collection_type, + ) + .order_by(DatasetCollectionBinding.created_at) + .first() + ) + + if not dataset_collection_binding: + dataset_collection_binding = DatasetCollectionBinding( + provider_name=provider_name, + model_name=model_name, + collection_name=Dataset.gen_collection_name_by_id(str(uuid.uuid4())), + type=collection_type, + ) + db.session.add(dataset_collection_binding) + db.session.commit() + return dataset_collection_binding + + @classmethod + def get_dataset_collection_binding_by_id_and_type( + cls, collection_binding_id: str, collection_type: str = "dataset" + ) -> DatasetCollectionBinding: + dataset_collection_binding = ( + db.session.query(DatasetCollectionBinding) + .filter( + DatasetCollectionBinding.id == 
class DatasetPermissionService:
    """Read, replace and check the per-account ``DatasetPermission`` rows of a dataset."""

    @classmethod
    def get_dataset_partial_member_list(cls, dataset_id):
        """Return the account ids that hold a permission row for ``dataset_id``."""
        rows = db.session.query(DatasetPermission.account_id).filter(DatasetPermission.dataset_id == dataset_id).all()
        return [row.account_id for row in rows]

    @classmethod
    def update_partial_member_list(cls, tenant_id, dataset_id, user_list):
        """Replace the dataset's permission rows with one row per entry of ``user_list``.

        Each entry is a mapping with a ``"user_id"`` key. Any failure rolls the session
        back and re-raises the original exception.
        """
        try:
            db.session.query(DatasetPermission).filter(DatasetPermission.dataset_id == dataset_id).delete()
            db.session.add_all(
                [
                    DatasetPermission(tenant_id=tenant_id, dataset_id=dataset_id, account_id=member["user_id"])
                    for member in user_list
                ]
            )
            db.session.commit()
        except Exception:
            db.session.rollback()
            raise

    @classmethod
    def check_permission(cls, user, dataset, requested_permission, requested_partial_member_list):
        """Validate that ``user`` may apply the requested permission settings to ``dataset``.

        :raises NoPermissionError: if the user is not a dataset editor, or is a dataset
            operator trying to change the permission value
        :raises ValueError: if a dataset operator submits "partial_members" without a member
            list, or with a list that differs from the stored one
        """
        if not user.is_dataset_editor:
            raise NoPermissionError("User does not have permission to edit this dataset.")

        # Every remaining rule restricts dataset operators only.
        if not user.is_dataset_operator:
            return

        if dataset.permission != requested_permission:
            raise NoPermissionError("Dataset operators cannot change the dataset permissions.")

        if requested_permission != "partial_members":
            return

        if not requested_partial_member_list:
            raise ValueError("Partial member list is required when setting to partial members.")

        stored_members = set(cls.get_dataset_partial_member_list(dataset.id))
        requested_members = {member["user_id"] for member in requested_partial_member_list}
        if stored_members != requested_members:
            raise ValueError("Dataset operators cannot change the dataset permissions.")

    @classmethod
    def clear_partial_member_list(cls, dataset_id):
        """Delete every permission row of ``dataset_id``; roll back and re-raise on failure."""
        try:
            db.session.query(DatasetPermission).filter(DatasetPermission.dataset_id == dataset_id).delete()
            db.session.commit()
        except Exception:
            db.session.rollback()
            raise
# Request description for one call to an external knowledge API; consumed by
# ExternalDatasetService.process_external_api.
class ExternalKnowledgeApiSetting(BaseModel):
    # Target URL of the request.
    url: str
    # Name of the ssrf_proxy function used to send the request (looked up with getattr);
    # presumably an HTTP verb such as "get" or "post" -- TODO confirm against ssrf_proxy.
    request_method: str
    # Optional request headers, passed through unchanged.
    headers: Optional[dict] = None
    # Optional payload; JSON-serialized and sent as the request body.
    params: Optional[dict] = None
NotionPage(BaseModel): + page_id: str + page_name: str + page_icon: Optional[NotionIcon] = None + type: str + + +class NotionInfo(BaseModel): + workspace_id: str + pages: list[NotionPage] + + +class WebsiteInfo(BaseModel): + provider: str + job_id: str + urls: list[str] + only_main_content: bool = True + + +class FileInfo(BaseModel): + file_ids: list[str] + + +class InfoList(BaseModel): + data_source_type: Literal["upload_file", "notion_import", "website_crawl"] + notion_info_list: Optional[list[NotionInfo]] = None + file_info_list: Optional[FileInfo] = None + website_info_list: Optional[WebsiteInfo] = None + + +class DataSource(BaseModel): + info_list: InfoList + + +class PreProcessingRule(BaseModel): + id: str + enabled: bool + + +class Segmentation(BaseModel): + separator: str = "\n" + max_tokens: int + chunk_overlap: int = 0 + + +class Rule(BaseModel): + pre_processing_rules: Optional[list[PreProcessingRule]] = None + segmentation: Optional[Segmentation] = None + parent_mode: Optional[Literal["full-doc", "paragraph"]] = None + subchunk_segmentation: Optional[Segmentation] = None + + +class ProcessRule(BaseModel): + mode: Literal["automatic", "custom", "hierarchical"] + rules: Optional[Rule] = None + + +class RerankingModel(BaseModel): + reranking_provider_name: Optional[str] = None + reranking_model_name: Optional[str] = None + + +class RetrievalModel(BaseModel): + search_method: Literal["hybrid_search", "semantic_search", "full_text_search"] + reranking_enable: bool + reranking_model: Optional[RerankingModel] = None + top_k: int + score_threshold_enabled: bool + score_threshold: Optional[float] = None + + +class MetaDataConfig(BaseModel): + doc_type: str + doc_metadata: dict + + +class KnowledgeConfig(BaseModel): + original_document_id: Optional[str] = None + duplicate: bool = True + indexing_technique: Literal["high_quality", "economy"] + data_source: Optional[DataSource] = None + process_rule: Optional[ProcessRule] = None + retrieval_model: 
Optional[RetrievalModel] = None + doc_form: str = "text_model" + doc_language: str = "English" + embedding_model: Optional[str] = None + embedding_model_provider: Optional[str] = None + name: Optional[str] = None + metadata: Optional[MetaDataConfig] = None + + +class SegmentUpdateArgs(BaseModel): + content: Optional[str] = None + answer: Optional[str] = None + keywords: Optional[list[str]] = None + regenerate_child_chunks: bool = False + enabled: Optional[bool] = None + + +class ChildChunkUpdateArgs(BaseModel): + id: Optional[str] = None + content: str diff --git a/api/services/entities/model_provider_entities.py b/api/services/entities/model_provider_entities.py new file mode 100644 index 0000000000000000000000000000000000000000..f1417c6cb94b8032124c1f580c3dce90be8fd762 --- /dev/null +++ b/api/services/entities/model_provider_entities.py @@ -0,0 +1,158 @@ +from enum import Enum +from typing import Optional + +from pydantic import BaseModel, ConfigDict + +from configs import dify_config +from core.entities.model_entities import ( + ModelWithProviderEntity, + ProviderModelWithStatusEntity, +) +from core.entities.provider_entities import QuotaConfiguration +from core.model_runtime.entities.common_entities import I18nObject +from core.model_runtime.entities.model_entities import ModelType +from core.model_runtime.entities.provider_entities import ( + ConfigurateMethod, + ModelCredentialSchema, + ProviderCredentialSchema, + ProviderHelpEntity, + SimpleProviderEntity, +) +from models.provider import ProviderQuotaType, ProviderType + + +class CustomConfigurationStatus(Enum): + """ + Enum class for custom configuration status. + """ + + ACTIVE = "active" + NO_CONFIGURE = "no-configure" + + +class CustomConfigurationResponse(BaseModel): + """ + Model class for provider custom configuration response. + """ + + status: CustomConfigurationStatus + + +class SystemConfigurationResponse(BaseModel): + """ + Model class for provider system configuration response. 
+ """ + + enabled: bool + current_quota_type: Optional[ProviderQuotaType] = None + quota_configurations: list[QuotaConfiguration] = [] + + +class ProviderResponse(BaseModel): + """ + Model class for provider response. + """ + + provider: str + label: I18nObject + description: Optional[I18nObject] = None + icon_small: Optional[I18nObject] = None + icon_large: Optional[I18nObject] = None + background: Optional[str] = None + help: Optional[ProviderHelpEntity] = None + supported_model_types: list[ModelType] + configurate_methods: list[ConfigurateMethod] + provider_credential_schema: Optional[ProviderCredentialSchema] = None + model_credential_schema: Optional[ModelCredentialSchema] = None + preferred_provider_type: ProviderType + custom_configuration: CustomConfigurationResponse + system_configuration: SystemConfigurationResponse + + # pydantic configs + model_config = ConfigDict(protected_namespaces=()) + + def __init__(self, **data) -> None: + super().__init__(**data) + + url_prefix = dify_config.CONSOLE_API_URL + f"/console/api/workspaces/current/model-providers/{self.provider}" + if self.icon_small is not None: + self.icon_small = I18nObject( + en_US=f"{url_prefix}/icon_small/en_US", zh_Hans=f"{url_prefix}/icon_small/zh_Hans" + ) + + if self.icon_large is not None: + self.icon_large = I18nObject( + en_US=f"{url_prefix}/icon_large/en_US", zh_Hans=f"{url_prefix}/icon_large/zh_Hans" + ) + + +class ProviderWithModelsResponse(BaseModel): + """ + Model class for provider with models response. 
+ """ + + provider: str + label: I18nObject + icon_small: Optional[I18nObject] = None + icon_large: Optional[I18nObject] = None + status: CustomConfigurationStatus + models: list[ProviderModelWithStatusEntity] + + def __init__(self, **data) -> None: + super().__init__(**data) + + url_prefix = dify_config.CONSOLE_API_URL + f"/console/api/workspaces/current/model-providers/{self.provider}" + if self.icon_small is not None: + self.icon_small = I18nObject( + en_US=f"{url_prefix}/icon_small/en_US", zh_Hans=f"{url_prefix}/icon_small/zh_Hans" + ) + + if self.icon_large is not None: + self.icon_large = I18nObject( + en_US=f"{url_prefix}/icon_large/en_US", zh_Hans=f"{url_prefix}/icon_large/zh_Hans" + ) + + +class SimpleProviderEntityResponse(SimpleProviderEntity): + """ + Simple provider entity response. + """ + + def __init__(self, **data) -> None: + super().__init__(**data) + + url_prefix = dify_config.CONSOLE_API_URL + f"/console/api/workspaces/current/model-providers/{self.provider}" + if self.icon_small is not None: + self.icon_small = I18nObject( + en_US=f"{url_prefix}/icon_small/en_US", zh_Hans=f"{url_prefix}/icon_small/zh_Hans" + ) + + if self.icon_large is not None: + self.icon_large = I18nObject( + en_US=f"{url_prefix}/icon_large/en_US", zh_Hans=f"{url_prefix}/icon_large/zh_Hans" + ) + + +class DefaultModelResponse(BaseModel): + """ + Default model entity. + """ + + model: str + model_type: ModelType + provider: SimpleProviderEntityResponse + + # pydantic configs + model_config = ConfigDict(protected_namespaces=()) + + +class ModelWithProviderEntityResponse(ModelWithProviderEntity): + """ + Model with provider entity. 
+ """ + + # FIXME type error ignore here + provider: SimpleProviderEntityResponse # type: ignore + + def __init__(self, model: ModelWithProviderEntity) -> None: + super().__init__(**model.model_dump()) diff --git a/api/services/errors/__init__.py b/api/services/errors/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..eb1f055708cdc4d5ac4527cb943a40827ccdfbdf --- /dev/null +++ b/api/services/errors/__init__.py @@ -0,0 +1,29 @@ +from . import ( + account, + app, + app_model_config, + audio, + base, + completion, + conversation, + dataset, + document, + file, + index, + message, +) + +__all__ = [ + "account", + "app", + "app_model_config", + "audio", + "base", + "completion", + "conversation", + "dataset", + "document", + "file", + "index", + "message", +] diff --git a/api/services/errors/account.py b/api/services/errors/account.py new file mode 100644 index 0000000000000000000000000000000000000000..5aca12ffeb9891af9fbfa82e470135b785c81d6d --- /dev/null +++ b/api/services/errors/account.py @@ -0,0 +1,61 @@ +from services.errors.base import BaseServiceError + + +class AccountNotFoundError(BaseServiceError): + pass + + +class AccountRegisterError(BaseServiceError): + pass + + +class AccountLoginError(BaseServiceError): + pass + + +class AccountPasswordError(BaseServiceError): + pass + + +class AccountNotLinkTenantError(BaseServiceError): + pass + + +class CurrentPasswordIncorrectError(BaseServiceError): + pass + + +class LinkAccountIntegrateError(BaseServiceError): + pass + + +class TenantNotFoundError(BaseServiceError): + pass + + +class AccountAlreadyInTenantError(BaseServiceError): + pass + + +class InvalidActionError(BaseServiceError): + pass + + +class CannotOperateSelfError(BaseServiceError): + pass + + +class NoPermissionError(BaseServiceError): + pass + + +class MemberNotInTenantError(BaseServiceError): + pass + + +class RoleAlreadyAssignedError(BaseServiceError): + pass + + +class RateLimitExceededError(BaseServiceError): + pass 
class BaseServiceError(ValueError):
    """Base class for service-layer errors that carry a human-readable description.

    :param description: optional message describing the failure; stored on
        ``self.description`` and, when given, also used as the exception message.
    """

    def __init__(self, description: Optional[str] = None):
        # Forward the message to ValueError so str(error) and error.args carry it even when
        # ``description`` is passed by keyword (a keyword argument never reaches ``args``
        # unless it is forwarded explicitly).
        if description is not None:
            super().__init__(description)
        self.description = description
b/api/services/errors/chunk.py @@ -0,0 +1,9 @@ +from services.errors.base import BaseServiceError + + +class ChildChunkIndexingError(BaseServiceError): + description = "{message}" + + +class ChildChunkDeleteIndexError(BaseServiceError): + description = "{message}" diff --git a/api/services/errors/completion.py b/api/services/errors/completion.py new file mode 100644 index 0000000000000000000000000000000000000000..7fc50a588e6343900894d82029841e92231f95fc --- /dev/null +++ b/api/services/errors/completion.py @@ -0,0 +1,5 @@ +from services.errors.base import BaseServiceError + + +class CompletionStoppedError(BaseServiceError): + pass diff --git a/api/services/errors/conversation.py b/api/services/errors/conversation.py new file mode 100644 index 0000000000000000000000000000000000000000..139dd9a70aef294ad34d347b28face50f6962456 --- /dev/null +++ b/api/services/errors/conversation.py @@ -0,0 +1,13 @@ +from services.errors.base import BaseServiceError + + +class LastConversationNotExistsError(BaseServiceError): + pass + + +class ConversationNotExistsError(BaseServiceError): + pass + + +class ConversationCompletedError(Exception): + pass diff --git a/api/services/errors/dataset.py b/api/services/errors/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..d36cd1111c78f8738d19edee76d42db9b943aaae --- /dev/null +++ b/api/services/errors/dataset.py @@ -0,0 +1,9 @@ +from services.errors.base import BaseServiceError + + +class DatasetNameDuplicateError(BaseServiceError): + pass + + +class DatasetInUseError(BaseServiceError): + pass diff --git a/api/services/errors/document.py b/api/services/errors/document.py new file mode 100644 index 0000000000000000000000000000000000000000..7327b9d032b7dba971313c76c6b7561d85771e03 --- /dev/null +++ b/api/services/errors/document.py @@ -0,0 +1,5 @@ +from services.errors.base import BaseServiceError + + +class DocumentIndexingError(BaseServiceError): + pass diff --git a/api/services/errors/file.py 
class InvokeError(Exception):
    """Base class for all LLM exceptions.

    Subclasses may set a class-level ``description`` as their default message; passing
    ``description`` to the constructor overrides it for that instance.
    """

    description: Optional[str] = None

    def __init__(self, description: Optional[str] = None) -> None:
        # Assign only when a message was supplied: an unconditional assignment would shadow
        # a subclass's class-level default with None.
        if description is not None:
            self.description = description

    def __str__(self):
        # Fall back to the class name when no description is available.
        return self.description or self.__class__.__name__
+ pass + + +class SuggestedQuestionsAfterAnswerDisabledError(BaseServiceError): + pass diff --git a/api/services/errors/workspace.py b/api/services/errors/workspace.py new file mode 100644 index 0000000000000000000000000000000000000000..714064ffdf8c3da4bfa4159550e0d7c89734a7d8 --- /dev/null +++ b/api/services/errors/workspace.py @@ -0,0 +1,9 @@ +from services.errors.base import BaseServiceError + + +class WorkSpaceNotAllowedCreateError(BaseServiceError): + pass + + +class WorkSpaceNotFoundError(BaseServiceError): + pass diff --git a/api/services/external_knowledge_service.py b/api/services/external_knowledge_service.py new file mode 100644 index 0000000000000000000000000000000000000000..8916a951c7408e24fd1b27d92e074975bb544f97 --- /dev/null +++ b/api/services/external_knowledge_service.py @@ -0,0 +1,288 @@ +import json +from copy import deepcopy +from datetime import UTC, datetime +from typing import Any, Optional, Union, cast + +import httpx +import validators + +from constants import HIDDEN_VALUE +from core.helper import ssrf_proxy +from extensions.ext_database import db +from models.dataset import ( + Dataset, + ExternalKnowledgeApis, + ExternalKnowledgeBindings, +) +from services.entities.external_knowledge_entities.external_knowledge_entities import ( + Authorization, + ExternalKnowledgeApiSetting, +) +from services.errors.dataset import DatasetNameDuplicateError + + +class ExternalDatasetService: + @staticmethod + def get_external_knowledge_apis(page, per_page, tenant_id, search=None) -> tuple[list[ExternalKnowledgeApis], int]: + query = ExternalKnowledgeApis.query.filter(ExternalKnowledgeApis.tenant_id == tenant_id).order_by( + ExternalKnowledgeApis.created_at.desc() + ) + if search: + query = query.filter(ExternalKnowledgeApis.name.ilike(f"%{search}%")) + + external_knowledge_apis = query.paginate(page=page, per_page=per_page, max_per_page=100, error_out=False) + + return external_knowledge_apis.items, external_knowledge_apis.total + + @classmethod + def 
validate_api_list(cls, api_settings: dict): + if not api_settings: + raise ValueError("api list is empty") + if "endpoint" not in api_settings and not api_settings["endpoint"]: + raise ValueError("endpoint is required") + if "api_key" not in api_settings and not api_settings["api_key"]: + raise ValueError("api_key is required") + + @staticmethod + def create_external_knowledge_api(tenant_id: str, user_id: str, args: dict) -> ExternalKnowledgeApis: + settings = args.get("settings") + if settings is None: + raise ValueError("settings is required") + ExternalDatasetService.check_endpoint_and_api_key(settings) + external_knowledge_api = ExternalKnowledgeApis( + tenant_id=tenant_id, + created_by=user_id, + updated_by=user_id, + name=args.get("name"), + description=args.get("description", ""), + settings=json.dumps(args.get("settings"), ensure_ascii=False), + ) + + db.session.add(external_knowledge_api) + db.session.commit() + return external_knowledge_api + + @staticmethod + def check_endpoint_and_api_key(settings: dict): + if "endpoint" not in settings or not settings["endpoint"]: + raise ValueError("endpoint is required") + if "api_key" not in settings or not settings["api_key"]: + raise ValueError("api_key is required") + + endpoint = f"{settings['endpoint']}/retrieval" + api_key = settings["api_key"] + if not validators.url(endpoint, simple_host=True): + if not endpoint.startswith("http://") and not endpoint.startswith("https://"): + raise ValueError(f"invalid endpoint: {endpoint} must start with http:// or https://") + else: + raise ValueError(f"invalid endpoint: {endpoint}") + try: + response = httpx.post(endpoint, headers={"Authorization": f"Bearer {api_key}"}) + except Exception as e: + raise ValueError(f"failed to connect to the endpoint: {endpoint}") + if response.status_code == 502: + raise ValueError(f"Bad Gateway: failed to connect to the endpoint: {endpoint}") + if response.status_code == 404: + raise ValueError(f"Not Found: failed to connect to the 
endpoint: {endpoint}") + if response.status_code == 403: + raise ValueError(f"Forbidden: Authorization failed with api_key: {api_key}") + + @staticmethod + def get_external_knowledge_api(external_knowledge_api_id: str) -> ExternalKnowledgeApis: + external_knowledge_api: Optional[ExternalKnowledgeApis] = ExternalKnowledgeApis.query.filter_by( + id=external_knowledge_api_id + ).first() + if external_knowledge_api is None: + raise ValueError("api template not found") + return external_knowledge_api + + @staticmethod + def update_external_knowledge_api(tenant_id, user_id, external_knowledge_api_id, args) -> ExternalKnowledgeApis: + external_knowledge_api: Optional[ExternalKnowledgeApis] = ExternalKnowledgeApis.query.filter_by( + id=external_knowledge_api_id, tenant_id=tenant_id + ).first() + if external_knowledge_api is None: + raise ValueError("api template not found") + if args.get("settings") and args.get("settings").get("api_key") == HIDDEN_VALUE: + args.get("settings")["api_key"] = external_knowledge_api.settings_dict.get("api_key") + + external_knowledge_api.name = args.get("name") + external_knowledge_api.description = args.get("description", "") + external_knowledge_api.settings = json.dumps(args.get("settings"), ensure_ascii=False) + external_knowledge_api.updated_by = user_id + external_knowledge_api.updated_at = datetime.now(UTC).replace(tzinfo=None) + db.session.commit() + + return external_knowledge_api + + @staticmethod + def delete_external_knowledge_api(tenant_id: str, external_knowledge_api_id: str): + external_knowledge_api = ExternalKnowledgeApis.query.filter_by( + id=external_knowledge_api_id, tenant_id=tenant_id + ).first() + if external_knowledge_api is None: + raise ValueError("api template not found") + + db.session.delete(external_knowledge_api) + db.session.commit() + + @staticmethod + def external_knowledge_api_use_check(external_knowledge_api_id: str) -> tuple[bool, int]: + count = 
@staticmethod
def document_create_args_validate(tenant_id: str, external_knowledge_api_id: str, process_parameter: dict):
    """
    Check that every required document-process parameter has been supplied.

    :param tenant_id: workspace id used to scope the API template lookup
    :param external_knowledge_api_id: id of the external knowledge API template
    :param process_parameter: caller-supplied values keyed by parameter name
    :raises ValueError: when the template does not exist, or a parameter marked
        as required has no truthy value in ``process_parameter``
    """
    api_record = ExternalKnowledgeApis.query.filter_by(id=external_knowledge_api_id, tenant_id=tenant_id).first()
    if api_record is None:
        raise ValueError("api template not found")

    for entry in json.loads(api_record.settings):
        # A missing or empty "document_process_setting" means nothing to validate.
        for spec in entry.get("document_process_setting") or ():
            if not spec.get("required", False):
                continue
            param_name = spec.get("name")
            if not process_parameter.get(param_name):
                raise ValueError(f"{param_name} is required")
@staticmethod
def get_external_knowledge_api_settings(settings: dict) -> ExternalKnowledgeApiSetting:
    """
    Build an ``ExternalKnowledgeApiSetting`` from a plain settings mapping.

    :param settings: raw settings dictionary
    :return: the validated settings object
    """
    # ``parse_obj`` is deprecated in pydantic v2; ``model_validate`` is its
    # replacement. NOTE(review): assumes ExternalKnowledgeApiSetting is a
    # pydantic v2 model like the other models in this code base; confirm.
    return ExternalKnowledgeApiSetting.model_validate(settings)
@staticmethod
def fetch_external_knowledge_retrieval(
    tenant_id: str, dataset_id: str, query: str, external_retrieval_parameters: dict
) -> list:
    """
    Query the external knowledge API bound to a dataset.

    :param tenant_id: workspace id used to scope the binding lookup
    :param dataset_id: dataset whose external binding is used
    :param query: text sent to the external API
    :param external_retrieval_parameters: may carry "top_k", "score_threshold"
        and "score_threshold_enabled"
    :return: the "records" list of the JSON response on HTTP 200, otherwise ``[]``
    :raises ValueError: when the binding or the API template cannot be found
    """
    binding = ExternalKnowledgeBindings.query.filter_by(dataset_id=dataset_id, tenant_id=tenant_id).first()
    if not binding:
        raise ValueError("external knowledge binding not found")

    api_record = ExternalKnowledgeApis.query.filter_by(id=binding.external_knowledge_api_id).first()
    if not api_record:
        raise ValueError("external api template not found")

    api_settings = json.loads(api_record.settings)

    request_headers = {"Content-Type": "application/json"}
    api_key = api_settings.get("api_key")
    if api_key:
        request_headers["Authorization"] = f"Bearer {api_key}"

    # The threshold is only forwarded when explicitly enabled; otherwise 0.0 is sent.
    if external_retrieval_parameters.get("score_threshold_enabled"):
        threshold = external_retrieval_parameters.get("score_threshold", 0.0)
    else:
        threshold = 0.0

    payload = {
        "retrieval_setting": {
            "top_k": external_retrieval_parameters.get("top_k"),
            "score_threshold": threshold,
        },
        "query": query,
        "knowledge_id": binding.external_knowledge_id,
    }

    api_setting = ExternalKnowledgeApiSetting(
        url=f"{api_settings.get('endpoint')}/retrieval",
        request_method="post",
        headers=request_headers,
        params=payload,
    )
    response = ExternalDatasetService.process_external_api(api_setting, None)

    if response.status_code != 200:
        return []
    return cast(list[Any], response.json().get("records", []))
class LicenseStatus(StrEnum):
    """String-valued license states stored in ``LicenseModel.status``."""

    # Default value of ``LicenseModel.status``.
    NONE = "none"
    # Passed as the ``dict.get`` default when FeatureService reads the license status.
    INACTIVE = "inactive"
    ACTIVE = "active"
    EXPIRING = "expiring"
    EXPIRED = "expired"
    LOST = "lost"
@classmethod
def _fulfill_system_params_from_env(cls, system_features: SystemFeatureModel):
    """
    Copy login and registration switches from ``dify_config`` onto ``system_features``.

    :param system_features: model instance updated in place
    """
    # (attribute on system_features, attribute on dify_config)
    config_sources = (
        ("enable_email_code_login", "ENABLE_EMAIL_CODE_LOGIN"),
        ("enable_email_password_login", "ENABLE_EMAIL_PASSWORD_LOGIN"),
        ("enable_social_oauth_login", "ENABLE_SOCIAL_OAUTH_LOGIN"),
        ("is_allow_register", "ALLOW_REGISTER"),
        ("is_allow_create_workspace", "ALLOW_CREATE_WORKSPACE"),
    )
    for feature_attr, config_attr in config_sources:
        setattr(system_features, feature_attr, getattr(dify_config, config_attr))

    # Email counts as set up when MAIL_TYPE is neither None nor an empty string.
    mail_type = dify_config.MAIL_TYPE
    system_features.is_email_setup = mail_type is not None and mail_type != ""
@classmethod
def _fulfill_params_from_enterprise(cls, features):
    """
    Overlay values returned by ``EnterpriseService.get_info()`` onto ``features``.

    Only keys present in the returned mapping are applied; absent keys leave the
    corresponding attribute untouched.

    :param features: feature model updated in place
    """
    enterprise_info = EnterpriseService.get_info()

    # Keys copied verbatim onto the attribute of the same name.
    passthrough_keys = (
        "sso_enforced_for_signin",
        "sso_enforced_for_signin_protocol",
        "sso_enforced_for_web",
        "sso_enforced_for_web_protocol",
        "enable_email_code_login",
        "enable_email_password_login",
        "is_allow_register",
        "is_allow_create_workspace",
    )
    for key in passthrough_keys:
        if key in enterprise_info:
            setattr(features, key, enterprise_info[key])

    if "license" not in enterprise_info:
        return

    license_info = enterprise_info["license"]
    if "status" in license_info:
        features.license.status = LicenseStatus(license_info.get("status", LicenseStatus.INACTIVE))
    if "expired_at" in license_info:
        features.license.expired_at = license_info["expired_at"]
@staticmethod
def upload_text(text: str, text_name: str) -> UploadFile:
    """
    Store a text snippet as a ``.txt`` upload owned by the current user.

    :param text: content to store; it is saved UTF-8 encoded
    :param text_name: display name, truncated to 200 characters
    :return: the persisted ``UploadFile`` record, already marked as used
    """
    text_name = text_name[:200]

    # A random uuid (not the display name) is used as the storage file name.
    file_uuid = str(uuid.uuid4())
    file_key = "upload_files/" + current_user.current_tenant_id + "/" + file_uuid + ".txt"

    # Encode once so the stored payload and the recorded size always agree.
    payload = text.encode("utf-8")

    # save file to storage
    storage.save(file_key, payload)

    # One timestamp for both columns, instead of two separate clock reads.
    now = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)

    # save file to db
    upload_file = UploadFile(
        tenant_id=current_user.current_tenant_id,
        storage_type=dify_config.STORAGE_TYPE,
        key=file_key,
        name=text_name,
        # Byte length of what was written, matching ``upload_file``; ``len(text)``
        # would under-report any non-ASCII content.
        size=len(payload),
        extension="txt",
        mime_type="text/plain",
        created_by=current_user.id,
        created_by_role=CreatedByRole.ACCOUNT,
        created_at=now,
        used=True,
        used_by=current_user.id,
        used_at=now,
    )

    db.session.add(upload_file)
    db.session.commit()

    return upload_file
@classmethod
def retrieve(
    cls,
    dataset: Dataset,
    query: str,
    account: Account,
    retrieval_model: Any,  # FIXME drop this any
    external_retrieval_model: dict,
    limit: int = 10,
) -> dict:
    """
    Run a hit-testing retrieval against a dataset and record the query.

    :param dataset: dataset to search
    :param query: user query text
    :param account: account recorded as creator of the ``DatasetQuery`` row
    :param retrieval_model: retrieval settings mapping; when falsy the dataset's
        own settings are used, falling back to ``default_retrieval_model``
    :param external_retrieval_model: not used by this method
    :param limit: not used by this method
    :return: mapping with "query" and "records" entries
    """
    if dataset.available_document_count == 0 or dataset.available_segment_count == 0:
        return {
            "query": {
                "content": query,
                "tsne_position": {"x": 0, "y": 0},
            },
            "records": [],
        }

    start = time.perf_counter()

    # get retrieval model, if the model is not set, use the default
    if not retrieval_model:
        retrieval_model = dataset.retrieval_model or default_retrieval_model

    # Read the two switches with .get() like every other key, so a partial
    # retrieval_model disables the feature instead of raising KeyError.
    score_threshold_enabled = retrieval_model.get("score_threshold_enabled")
    reranking_enabled = retrieval_model.get("reranking_enable")

    all_documents = RetrievalService.retrieve(
        retrieval_method=retrieval_model.get("search_method", "semantic_search"),
        dataset_id=dataset.id,
        query=cls.escape_query_for_search(query),
        top_k=retrieval_model.get("top_k", 2),
        score_threshold=retrieval_model.get("score_threshold", 0.0) if score_threshold_enabled else 0.0,
        reranking_model=retrieval_model.get("reranking_model", None) if reranking_enabled else None,
        reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
        weights=retrieval_model.get("weights", None),
    )

    end = time.perf_counter()
    logging.debug(f"Hit testing retrieve in {end - start:0.4f} seconds")

    dataset_query = DatasetQuery(
        dataset_id=dataset.id, content=query, source="hit_testing", created_by_role="account", created_by=account.id
    )

    db.session.add(dataset_query)
    db.session.commit()

    return cls.compact_retrieve_response(query, all_documents)  # type: ignore
@classmethod
def hit_testing_args_check(cls, args):
    """
    Validate hit-testing request arguments.

    :param args: mapping expected to carry a "query" entry
    :raises ValueError: when the query is missing, empty, or longer than 250 characters
    """
    # .get() so that an absent key is reported with the same ValueError as an
    # empty query, rather than leaking a KeyError to the caller.
    query = args.get("query")

    if not query or len(query) > 250:
        raise ValueError("Query is required and cannot exceed 250 characters")
query: str, knowledge_id: str): + # get bedrock client + client = boto3.client( + "bedrock-agent-runtime", + aws_secret_access_key=dify_config.AWS_SECRET_ACCESS_KEY, + aws_access_key_id=dify_config.AWS_ACCESS_KEY_ID, + # example: us-east-1 + region_name="us-east-1", + ) + # fetch external knowledge retrieval + response = client.retrieve( + knowledgeBaseId=knowledge_id, + retrievalConfiguration={ + "vectorSearchConfiguration": { + "numberOfResults": retrieval_setting.get("top_k"), + "overrideSearchType": "HYBRID", + } + }, + retrievalQuery={"text": query}, + ) + # parse response + results = [] + if response.get("ResponseMetadata") and response.get("ResponseMetadata").get("HTTPStatusCode") == 200: + if response.get("retrievalResults"): + retrieval_results = response.get("retrievalResults") + for retrieval_result in retrieval_results: + # filter out results with score less than threshold + if retrieval_result.get("score") < retrieval_setting.get("score_threshold", 0.0): + continue + result = { + "metadata": retrieval_result.get("metadata"), + "score": retrieval_result.get("score"), + "title": retrieval_result.get("metadata").get("x-amz-bedrock-kb-source-uri"), + "content": retrieval_result.get("content").get("text"), + } + results.append(result) + return {"records": results} diff --git a/api/services/message_service.py b/api/services/message_service.py new file mode 100644 index 0000000000000000000000000000000000000000..c17122ef647ecd99bf82b7c66e6701b1b56095f4 --- /dev/null +++ b/api/services/message_service.py @@ -0,0 +1,303 @@ +import json +from typing import Optional, Union + +from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager +from core.app.entities.app_invoke_entities import InvokeFrom +from core.llm_generator.llm_generator import LLMGenerator +from core.memory.token_buffer_memory import TokenBufferMemory +from core.model_manager import ModelManager +from core.model_runtime.entities.model_entities import ModelType +from 
@classmethod
def pagination_by_first_id(
    cls,
    app_model: App,
    user: Optional[Union[Account, EndUser]],
    conversation_id: str,
    first_id: Optional[str],
    limit: int,
    order: str = "asc",
) -> InfiniteScrollPagination:
    """
    Page backwards through a conversation's messages.

    :param app_model: app that owns the conversation
    :param user: requesting user; an empty page is returned when falsy
    :param conversation_id: conversation to read; an empty page is returned when falsy
    :param first_id: id of the first message of the page already shown; only
        messages created before it are returned. When falsy the newest messages
        are returned.
    :param limit: maximum number of messages in the page
    :param order: "asc" returns the page oldest-first; any other value keeps
        the newest-first query order
    :return: the page together with a ``has_more`` flag
    :raises FirstMessageNotExistsError: when ``first_id`` is not in the conversation
    """
    if not user or not conversation_id:
        return InfiniteScrollPagination(data=[], limit=limit, has_more=False)

    conversation = ConversationService.get_conversation(
        app_model=app_model, user=user, conversation_id=conversation_id
    )

    # Every query below is scoped to this conversation.
    conversation_messages = db.session.query(Message).filter(Message.conversation_id == conversation.id)

    page_query = conversation_messages
    if first_id:
        first_message = conversation_messages.filter(Message.id == first_id).first()
        if not first_message:
            raise FirstMessageNotExistsError()

        page_query = conversation_messages.filter(
            Message.created_at < first_message.created_at,
            Message.id != first_message.id,
        )

    history_messages = page_query.order_by(Message.created_at.desc()).limit(limit).all()

    has_more = False
    # The emptiness guard avoids an IndexError when limit == 0.
    if history_messages and len(history_messages) == limit:
        oldest_in_page = history_messages[-1]
        remaining = conversation_messages.filter(
            Message.created_at < oldest_in_page.created_at,
            Message.id != oldest_in_page.id,
        ).count()
        has_more = remaining > 0

    if order == "asc":
        history_messages = list(reversed(history_messages))

    return InfiniteScrollPagination(data=history_messages, limit=limit, has_more=has_more)
@classmethod
def create_feedback(
    cls,
    *,
    app_model: App,
    message_id: str,
    user: Optional[Union[Account, EndUser]],
    rating: Optional[str],
    content: Optional[str],
):
    """
    Create, update or remove the caller's feedback on a message.

    An existing feedback is updated when ``rating`` is truthy and deleted when
    it is not; without an existing feedback a truthy ``rating`` creates one.

    :param app_model: app that owns the message
    :param message_id: id of the message receiving feedback
    :param user: end user or account giving the feedback
    :param rating: rating value; falsy removes an existing feedback
    :param content: optional free-text content stored with the rating
    :return: the created, updated or deleted feedback object
    :raises ValueError: when ``user`` is falsy, or when there is neither a
        rating nor an existing feedback
    """
    if not user:
        raise ValueError("user cannot be None")

    message = cls.get_message(app_model=app_model, user=user, message_id=message_id)

    from_end_user = isinstance(user, EndUser)
    feedback = message.user_feedback if from_end_user else message.admin_feedback

    if feedback:
        if rating:
            feedback.rating = rating
            feedback.content = content
        else:
            db.session.delete(feedback)
    else:
        if not rating:
            raise ValueError("rating cannot be None when feedback not exists")

        feedback = MessageFeedback(
            app_id=app_model.id,
            conversation_id=message.conversation_id,
            message_id=message.id,
            rating=rating,
            content=content,
            from_source="user" if from_end_user else "admin",
            from_end_user_id=user.id if from_end_user else None,
            from_account_id=user.id if isinstance(user, Account) else None,
        )
        db.session.add(feedback)

    db.session.commit()

    return feedback
user: + raise ValueError("user cannot be None") + + message = cls.get_message(app_model=app_model, user=user, message_id=message_id) + + conversation = ConversationService.get_conversation( + app_model=app_model, conversation_id=message.conversation_id, user=user + ) + + if not conversation: + raise ConversationNotExistsError() + + if conversation.status != "normal": + raise ConversationCompletedError() + + model_manager = ModelManager() + + if app_model.mode == AppMode.ADVANCED_CHAT.value: + workflow_service = WorkflowService() + if invoke_from == InvokeFrom.DEBUGGER: + workflow = workflow_service.get_draft_workflow(app_model=app_model) + else: + workflow = workflow_service.get_published_workflow(app_model=app_model) + + if workflow is None: + return [] + + app_config = AdvancedChatAppConfigManager.get_app_config(app_model=app_model, workflow=workflow) + + if not app_config.additional_features.suggested_questions_after_answer: + raise SuggestedQuestionsAfterAnswerDisabledError() + + model_instance = model_manager.get_default_model_instance( + tenant_id=app_model.tenant_id, model_type=ModelType.LLM + ) + else: + if not conversation.override_model_configs: + app_model_config = ( + db.session.query(AppModelConfig) + .filter( + AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id + ) + .first() + ) + else: + conversation_override_model_configs = json.loads(conversation.override_model_configs) + app_model_config = AppModelConfig( + id=conversation.app_model_config_id, + app_id=app_model.id, + ) + + app_model_config = app_model_config.from_model_config_dict(conversation_override_model_configs) + if not app_model_config: + raise ValueError("did not find app model config") + + suggested_questions_after_answer = app_model_config.suggested_questions_after_answer_dict + if suggested_questions_after_answer.get("enabled", False) is False: + raise SuggestedQuestionsAfterAnswerDisabledError() + + model_instance = 
model_manager.get_model_instance( + tenant_id=app_model.tenant_id, + provider=app_model_config.model_dict["provider"], + model_type=ModelType.LLM, + model=app_model_config.model_dict["name"], + ) + + # get memory of conversation (read-only) + memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance) + + histories = memory.get_history_prompt_text( + max_token_limit=3000, + message_limit=3, + ) + + with measure_time() as timer: + questions: list[Message] = LLMGenerator.generate_suggested_questions_after_answer( + tenant_id=app_model.tenant_id, histories=histories + ) + + # get tracing instance + trace_manager = TraceQueueManager(app_id=app_model.id) + trace_manager.add_trace_task( + TraceTask( + TraceTaskName.SUGGESTED_QUESTION_TRACE, message_id=message_id, suggested_question=questions, timer=timer + ) + ) + + return questions diff --git a/api/services/model_load_balancing_service.py b/api/services/model_load_balancing_service.py new file mode 100644 index 0000000000000000000000000000000000000000..bacd3a8ec3d04f15a729cae7b03bd062e39551ac --- /dev/null +++ b/api/services/model_load_balancing_service.py @@ -0,0 +1,570 @@ +import datetime +import json +import logging +from json import JSONDecodeError +from typing import Optional, Union + +from constants import HIDDEN_VALUE +from core.entities.provider_configuration import ProviderConfiguration +from core.helper import encrypter +from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType +from core.model_manager import LBModelManager +from core.model_runtime.entities.model_entities import ModelType +from core.model_runtime.entities.provider_entities import ( + ModelCredentialSchema, + ProviderCredentialSchema, +) +from core.model_runtime.model_providers import model_provider_factory +from core.provider_manager import ProviderManager +from extensions.ext_database import db +from models.provider import LoadBalancingModelConfig + +logger = 
class ModelLoadBalancingService:
    """
    Manage the load balancing configurations of a provider model in a workspace.

    Each configuration is a ``LoadBalancingModelConfig`` row holding a name, an enabled
    flag and a JSON document of credentials whose secret fields are stored encrypted.
    The row named ``"__inherit__"`` stands for the provider or model custom credentials.
    """

    def __init__(self) -> None:
        self.provider_manager = ProviderManager()

    def _get_provider_configuration(self, tenant_id: str, provider: str) -> ProviderConfiguration:
        """
        Look up the configuration of one provider in the workspace.

        :param tenant_id: workspace id
        :param provider: provider name
        :return: the provider configuration
        :raises ValueError: if the workspace has no configuration for the provider
        """
        # Get all provider configurations of the current workspace
        provider_configurations = self.provider_manager.get_configurations(tenant_id)

        provider_configuration = provider_configurations.get(provider)
        if not provider_configuration:
            raise ValueError(f"Provider {provider} does not exist.")

        return provider_configuration

    @staticmethod
    def _load_stored_credentials(load_balancing_model_config: LoadBalancingModelConfig) -> dict:
        """
        Decode the JSON credentials stored on a load balancing config.

        :param load_balancing_model_config: load balancing model config
        :return: the decoded document, or an empty dict when nothing is stored or the
            stored text is not valid JSON
        """
        if not load_balancing_model_config.encrypted_config:
            return {}

        try:
            return json.loads(load_balancing_model_config.encrypted_config)
        except JSONDecodeError:
            # fix origin data: unreadable stored credentials are treated as absent
            return {}

    def enable_model_load_balancing(self, tenant_id: str, provider: str, model: str, model_type: str) -> None:
        """
        enable model load balancing.

        :param tenant_id: workspace id
        :param provider: provider name
        :param model: model name
        :param model_type: model type
        :return:
        :raises ValueError: if the provider does not exist in the workspace
        """
        provider_configuration = self._get_provider_configuration(tenant_id, provider)

        # Enable model load balancing
        provider_configuration.enable_model_load_balancing(model=model, model_type=ModelType.value_of(model_type))

    def disable_model_load_balancing(self, tenant_id: str, provider: str, model: str, model_type: str) -> None:
        """
        disable model load balancing.

        :param tenant_id: workspace id
        :param provider: provider name
        :param model: model name
        :param model_type: model type
        :return:
        :raises ValueError: if the provider does not exist in the workspace
        """
        provider_configuration = self._get_provider_configuration(tenant_id, provider)

        # disable model load balancing
        provider_configuration.disable_model_load_balancing(model=model, model_type=ModelType.value_of(model_type))

    def get_load_balancing_configs(
        self, tenant_id: str, provider: str, model: str, model_type: str
    ) -> tuple[bool, list[dict]]:
        """
        Get load balancing configurations.

        :param tenant_id: workspace id
        :param provider: provider name
        :param model: model name
        :param model_type: model type
        :return: a pair of (load balancing enabled flag, list of config dicts with the keys
            ``id``, ``name``, ``credentials`` (obfuscated), ``enabled``, ``in_cooldown`` and ``ttl``)
        :raises ValueError: if the provider does not exist or exposes no credential schema
        """
        provider_configuration = self._get_provider_configuration(tenant_id, provider)

        # Convert model type to ModelType
        model_type_enum = ModelType.value_of(model_type)

        # Get provider model setting
        provider_model_setting = provider_configuration.get_provider_model_setting(
            model_type=model_type_enum,
            model=model,
        )
        is_load_balancing_enabled = bool(provider_model_setting and provider_model_setting.load_balancing_enabled)

        # Get load balancing configurations
        load_balancing_configs = (
            db.session.query(LoadBalancingModelConfig)
            .filter(
                LoadBalancingModelConfig.tenant_id == tenant_id,
                LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
                LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(),
                LoadBalancingModelConfig.model_name == model,
            )
            .order_by(LoadBalancingModelConfig.created_at)
            .all()
        )

        if provider_configuration.custom_configuration.provider:
            # inherit is represented for the provider or model custom credentials
            if any(config.name == "__inherit__" for config in load_balancing_configs):
                # move the inherit configuration to the first; the sort is stable, so the
                # created_at order of every other configuration is kept
                load_balancing_configs.sort(key=lambda config: config.name != "__inherit__")
            else:
                # Initialize the inherit configuration under the same provider name the
                # query above filters on, so that the next call finds it again
                inherit_config = self._init_inherit_config(
                    tenant_id, provider_configuration.provider.provider, model, model_type_enum
                )

                # prepend the inherit configuration
                load_balancing_configs.insert(0, inherit_config)

        # Get credential form schemas from model credential schema or provider credential schema
        credential_schemas = self._get_credential_schema(provider_configuration)

        # Get provider credential secret variables (identical for every config, so computed once)
        credential_secret_variables = provider_configuration.extract_secret_variables(
            credential_schemas.credential_form_schemas
        )

        # Get decoding rsa key and cipher for decrypting credentials
        decoding_rsa_key, decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id)

        # fetch status and ttl for each config
        datas = []
        for load_balancing_config in load_balancing_configs:
            in_cooldown, ttl = LBModelManager.get_config_in_cooldown_and_ttl(
                tenant_id=tenant_id,
                provider=provider,
                model=model,
                model_type=model_type_enum,
                config_id=load_balancing_config.id,
            )

            credentials = self._load_stored_credentials(load_balancing_config)

            # decrypt credentials
            for variable in credential_secret_variables:
                if variable in credentials:
                    try:
                        credentials[variable] = encrypter.decrypt_token_with_decoding(
                            credentials.get(variable), decoding_rsa_key, decoding_cipher_rsa
                        )
                    except ValueError:
                        # best effort: keep the stored value, it is obfuscated below anyway
                        pass

            # Obfuscate credentials
            credentials = provider_configuration.obfuscated_credentials(
                credentials=credentials, credential_form_schemas=credential_schemas.credential_form_schemas
            )

            datas.append(
                {
                    "id": load_balancing_config.id,
                    "name": load_balancing_config.name,
                    "credentials": credentials,
                    "enabled": load_balancing_config.enabled,
                    "in_cooldown": in_cooldown,
                    "ttl": ttl,
                }
            )

        return is_load_balancing_enabled, datas

    def get_load_balancing_config(
        self, tenant_id: str, provider: str, model: str, model_type: str, config_id: str
    ) -> Optional[dict]:
        """
        Get load balancing configuration.

        :param tenant_id: workspace id
        :param provider: provider name
        :param model: model name
        :param model_type: model type
        :param config_id: load balancing config id
        :return: a dict with the keys ``id``, ``name``, ``credentials`` (obfuscated) and
            ``enabled``, or None when no matching configuration exists
        :raises ValueError: if the provider does not exist or exposes no credential schema
        """
        provider_configuration = self._get_provider_configuration(tenant_id, provider)

        # Convert model type to ModelType
        model_type_enum = ModelType.value_of(model_type)

        # Get load balancing configurations
        load_balancing_model_config = (
            db.session.query(LoadBalancingModelConfig)
            .filter(
                LoadBalancingModelConfig.tenant_id == tenant_id,
                LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
                LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(),
                LoadBalancingModelConfig.model_name == model,
                LoadBalancingModelConfig.id == config_id,
            )
            .first()
        )

        if not load_balancing_model_config:
            return None

        credentials = self._load_stored_credentials(load_balancing_model_config)

        # Get credential form schemas from model credential schema or provider credential schema
        credential_schemas = self._get_credential_schema(provider_configuration)

        # Obfuscate credentials
        credentials = provider_configuration.obfuscated_credentials(
            credentials=credentials, credential_form_schemas=credential_schemas.credential_form_schemas
        )

        return {
            "id": load_balancing_model_config.id,
            "name": load_balancing_model_config.name,
            "credentials": credentials,
            "enabled": load_balancing_model_config.enabled,
        }

    def _init_inherit_config(
        self, tenant_id: str, provider: str, model: str, model_type: ModelType
    ) -> LoadBalancingModelConfig:
        """
        Initialize the inherit configuration.

        Adds a ``LoadBalancingModelConfig`` row named ``"__inherit__"`` and commits it.

        :param tenant_id: workspace id
        :param provider: provider name
        :param model: model name
        :param model_type: model type
        :return: the newly persisted inherit configuration
        """
        inherit_config = LoadBalancingModelConfig(
            tenant_id=tenant_id,
            provider_name=provider,
            model_type=model_type.to_origin_model_type(),
            model_name=model,
            name="__inherit__",
        )
        db.session.add(inherit_config)
        db.session.commit()

        return inherit_config

    def update_load_balancing_configs(
        self, tenant_id: str, provider: str, model: str, model_type: str, configs: list[dict]
    ) -> None:
        """
        Update load balancing configurations.

        Entries carrying an ``id`` update the matching stored configuration, entries without
        an ``id`` create a new one, and stored configurations whose id is not listed are deleted.

        :param tenant_id: workspace id
        :param provider: provider name
        :param model: model name
        :param model_type: model type
        :param configs: load balancing configs
        :return:
        :raises ValueError: if the provider does not exist or any entry is malformed,
            references an unknown id, or reuses an existing name
        """
        provider_configuration = self._get_provider_configuration(tenant_id, provider)

        # Convert model type to ModelType
        model_type_enum = ModelType.value_of(model_type)

        if not isinstance(configs, list):
            raise ValueError("Invalid load balancing configs")

        current_load_balancing_configs = (
            db.session.query(LoadBalancingModelConfig)
            .filter(
                LoadBalancingModelConfig.tenant_id == tenant_id,
                LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
                LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(),
                LoadBalancingModelConfig.model_name == model,
            )
            .all()
        )

        # id as key, config as value
        current_load_balancing_configs_dict = {config.id: config for config in current_load_balancing_configs}
        updated_config_ids = set()

        for config in configs:
            if not isinstance(config, dict):
                raise ValueError("Invalid load balancing config")

            config_id = config.get("id")
            name = config.get("name")
            credentials = config.get("credentials")
            enabled = config.get("enabled")

            if not name:
                raise ValueError("Invalid load balancing config name")

            if enabled is None:
                raise ValueError("Invalid load balancing config enabled")

            # is config exists
            if config_id:
                config_id = str(config_id)

                if config_id not in current_load_balancing_configs_dict:
                    raise ValueError(f"Invalid load balancing config id: {config_id}")

                updated_config_ids.add(config_id)

                load_balancing_config = current_load_balancing_configs_dict[config_id]

                # check duplicate name against every other stored configuration
                if any(
                    existing.id != config_id and existing.name == name for existing in current_load_balancing_configs
                ):
                    raise ValueError(f"Load balancing config name {name} already exists")

                # credentials are optional on update: when omitted the stored ones are kept
                if credentials:
                    if not isinstance(credentials, dict):
                        raise ValueError("Invalid load balancing config credentials")

                    # validate custom provider config
                    credentials = self._custom_credentials_validate(
                        tenant_id=tenant_id,
                        provider_configuration=provider_configuration,
                        model_type=model_type_enum,
                        model=model,
                        credentials=credentials,
                        load_balancing_model_config=load_balancing_config,
                        validate=False,
                    )

                    # update load balancing config
                    load_balancing_config.encrypted_config = json.dumps(credentials)

                load_balancing_config.name = name
                load_balancing_config.enabled = enabled
                # stored as a naive timestamp taken in UTC
                load_balancing_config.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
                db.session.commit()

                self._clear_credentials_cache(tenant_id, config_id)
            else:
                # create load balancing config
                if name == "__inherit__":
                    raise ValueError("Invalid load balancing config name")

                # check duplicate name
                if any(existing.name == name for existing in current_load_balancing_configs):
                    raise ValueError(f"Load balancing config name {name} already exists")

                if not credentials:
                    raise ValueError("Invalid load balancing config credentials")

                if not isinstance(credentials, dict):
                    raise ValueError("Invalid load balancing config credentials")

                # validate custom provider config
                credentials = self._custom_credentials_validate(
                    tenant_id=tenant_id,
                    provider_configuration=provider_configuration,
                    model_type=model_type_enum,
                    model=model,
                    credentials=credentials,
                    validate=False,
                )

                # create load balancing config
                # NOTE(review): the requested `enabled` flag is validated above but not passed
                # here, so a new row gets the model's default - confirm this is intended.
                load_balancing_model_config = LoadBalancingModelConfig(
                    tenant_id=tenant_id,
                    provider_name=provider_configuration.provider.provider,
                    model_type=model_type_enum.to_origin_model_type(),
                    model_name=model,
                    name=name,
                    encrypted_config=json.dumps(credentials),
                )

                db.session.add(load_balancing_model_config)
                db.session.commit()

        # get deleted config ids: stored configurations the request did not mention
        deleted_config_ids = set(current_load_balancing_configs_dict.keys()) - updated_config_ids
        for config_id in deleted_config_ids:
            db.session.delete(current_load_balancing_configs_dict[config_id])
            db.session.commit()

            self._clear_credentials_cache(tenant_id, config_id)

    def validate_load_balancing_credentials(
        self,
        tenant_id: str,
        provider: str,
        model: str,
        model_type: str,
        credentials: dict,
        config_id: Optional[str] = None,
    ) -> None:
        """
        Validate load balancing credentials.

        :param tenant_id: workspace id
        :param provider: provider name
        :param model_type: model type
        :param model: model name
        :param credentials: credentials
        :param config_id: load balancing config id
        :return:
        :raises ValueError: if the provider or the referenced config does not exist
        """
        provider_configuration = self._get_provider_configuration(tenant_id, provider)

        # Convert model type to ModelType
        model_type_enum = ModelType.value_of(model_type)

        load_balancing_model_config = None
        if config_id:
            # Get load balancing config, filtering on the same provider name the other
            # queries of this service use
            load_balancing_model_config = (
                db.session.query(LoadBalancingModelConfig)
                .filter(
                    LoadBalancingModelConfig.tenant_id == tenant_id,
                    LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
                    LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(),
                    LoadBalancingModelConfig.model_name == model,
                    LoadBalancingModelConfig.id == config_id,
                )
                .first()
            )

            if not load_balancing_model_config:
                raise ValueError(f"Load balancing config {config_id} does not exist.")

        # Validate custom provider config
        self._custom_credentials_validate(
            tenant_id=tenant_id,
            provider_configuration=provider_configuration,
            model_type=model_type_enum,
            model=model,
            credentials=credentials,
            load_balancing_model_config=load_balancing_model_config,
        )

    def _custom_credentials_validate(
        self,
        tenant_id: str,
        provider_configuration: ProviderConfiguration,
        model_type: ModelType,
        model: str,
        credentials: dict,
        load_balancing_model_config: Optional[LoadBalancingModelConfig] = None,
        validate: bool = True,
    ) -> dict:
        """
        Validate custom credentials.

        :param tenant_id: workspace id
        :param provider_configuration: provider configuration
        :param model_type: model type
        :param model: model name
        :param credentials: credentials
        :param load_balancing_model_config: load balancing model config
        :param validate: validate credentials
        :return: the credentials with every secret variable encrypted; the dict passed in
            is left untouched
        :raises ValueError: if the provider exposes no credential schema
        """
        # Get credential form schemas from model credential schema or provider credential schema
        credential_schemas = self._get_credential_schema(provider_configuration)

        # Get provider credential secret variables
        provider_credential_secret_variables = provider_configuration.extract_secret_variables(
            credential_schemas.credential_form_schemas
        )

        # Always bound, so the hidden-value check below is safe when no stored config is given.
        original_credentials: dict = {}
        if load_balancing_model_config:
            original_credentials = self._load_stored_credentials(load_balancing_model_config)

        # Work on a copy so the caller's dict never ends up holding decrypted or encrypted values.
        credentials = dict(credentials)

        for key, value in list(credentials.items()):
            # if send [__HIDDEN__] in secret input, it will be same as original value
            if (
                key in provider_credential_secret_variables
                and value == HIDDEN_VALUE
                and key in original_credentials
            ):
                credentials[key] = encrypter.decrypt_token(tenant_id, original_credentials[key])

        if validate:
            if isinstance(credential_schemas, ModelCredentialSchema):
                credentials = model_provider_factory.model_credentials_validate(
                    provider=provider_configuration.provider.provider,
                    model_type=model_type,
                    model=model,
                    credentials=credentials,
                )
            else:
                credentials = model_provider_factory.provider_credentials_validate(
                    provider=provider_configuration.provider.provider, credentials=credentials
                )

        # encrypt credentials
        for key, value in list(credentials.items()):
            if key in provider_credential_secret_variables:
                credentials[key] = encrypter.encrypt_token(tenant_id, value)

        return credentials

    def _get_credential_schema(
        self, provider_configuration: ProviderConfiguration
    ) -> Union[ModelCredentialSchema, ProviderCredentialSchema]:
        """
        Get form schemas.

        The model credential schema takes precedence over the provider credential schema.

        :param provider_configuration: provider configuration
        :return: the credential schema to use
        :raises ValueError: if the provider declares neither schema
        """
        provider_entity = provider_configuration.provider

        if provider_entity.model_credential_schema:
            return provider_entity.model_credential_schema

        if provider_entity.provider_credential_schema:
            return provider_entity.provider_credential_schema

        raise ValueError("No credential schema found")

    def _clear_credentials_cache(self, tenant_id: str, config_id: str) -> None:
        """
        Clear credentials cache.

        :param tenant_id: workspace id
        :param config_id: load balancing config id
        :return:
        """
        provider_model_credentials_cache = ProviderCredentialsCache(
            tenant_id=tenant_id, identity_id=config_id, cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL
        )

        provider_model_credentials_cache.delete()
def get_provider_list(self, tenant_id: str, model_type: Optional[str] = None) -> list[ProviderResponse]:
    """
    List the providers configured for a workspace.

    :param tenant_id: workspace id
    :param model_type: optional model type; when given, providers that do not
        support it are left out
    :return: one ProviderResponse per remaining provider
    """
    workspace_configurations = self.provider_manager.get_configurations(tenant_id)

    responses = []
    for configuration in workspace_configurations.values():
        provider_entity = configuration.provider

        # Skip providers that cannot serve the requested model type.
        if model_type and ModelType.value_of(model_type) not in provider_entity.supported_model_types:
            continue

        responses.append(
            ProviderResponse(
                provider=provider_entity.provider,
                label=provider_entity.label,
                description=provider_entity.description,
                icon_small=provider_entity.icon_small,
                icon_large=provider_entity.icon_large,
                background=provider_entity.background,
                help=provider_entity.help,
                supported_model_types=provider_entity.supported_model_types,
                configurate_methods=provider_entity.configurate_methods,
                provider_credential_schema=provider_entity.provider_credential_schema,
                model_credential_schema=provider_entity.model_credential_schema,
                preferred_provider_type=configuration.preferred_provider_type,
                custom_configuration=CustomConfigurationResponse(
                    status=CustomConfigurationStatus.ACTIVE
                    if configuration.is_custom_configuration_available()
                    else CustomConfigurationStatus.NO_CONFIGURE
                ),
                system_configuration=SystemConfigurationResponse(
                    enabled=configuration.system_configuration.enabled,
                    current_quota_type=configuration.system_configuration.current_quota_type,
                    quota_configurations=configuration.system_configuration.quota_configurations,
                ),
            )
        )

    return responses
def save_provider_credentials(self, tenant_id: str, provider: str, credentials: dict) -> None:
    """
    Store custom credentials for a provider of the workspace.

    :param tenant_id: workspace id
    :param provider: provider name
    :param credentials: provider credentials
    :return:
    :raises ValueError: if the workspace has no configuration for the provider
    """
    # Resolve the provider among the configurations of the current workspace
    target_configuration = self.provider_manager.get_configurations(tenant_id).get(provider)

    if target_configuration:
        # Add or update custom provider credentials.
        target_configuration.add_or_update_custom_credentials(credentials)
        return

    raise ValueError(f"Provider {provider} does not exist.")
def model_credentials_validate(
    self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict
) -> None:
    """
    Check custom credentials for one model of a provider.

    :param tenant_id: workspace id
    :param provider: provider name
    :param model_type: model type
    :param model: model name
    :param credentials: model credentials
    :return:
    :raises ValueError: if the workspace has no configuration for the provider
    """
    # Resolve the provider among the configurations of the current workspace
    target_configuration = self.provider_manager.get_configurations(tenant_id).get(provider)
    if not target_configuration:
        raise ValueError(f"Provider {provider} does not exist.")

    # Delegate the actual check to the provider configuration
    target_configuration.custom_model_credentials_validate(
        model_type=ModelType.value_of(model_type),
        model=model,
        credentials=credentials,
    )
def remove_model_credentials(self, tenant_id: str, provider: str, model_type: str, model: str) -> None:
    """
    Delete the custom credentials stored for one model of a provider.

    :param tenant_id: workspace id
    :param provider: provider name
    :param model_type: model type
    :param model: model name
    :return:
    :raises ValueError: if the workspace has no configuration for the provider
    """
    # Resolve the provider among the configurations of the current workspace
    target_configuration = self.provider_manager.get_configurations(tenant_id).get(provider)
    if not target_configuration:
        raise ValueError(f"Provider {provider} does not exist.")

    # Remove custom model credentials
    target_configuration.delete_custom_model_credentials(
        model_type=ModelType.value_of(model_type),
        model=model,
    )
def get_model_parameter_rules(self, tenant_id: str, provider: str, model: str) -> list[ParameterRule]:
    """
    Return the parameter rules of an LLM model.
    Only supports LLM.

    :param tenant_id: workspace id
    :param provider: provider name
    :param model: model name
    :return: the parameter rules, or an empty list when no credentials are
        currently available for the model
    :raises ValueError: if the workspace has no configuration for the provider
    """
    # Resolve the provider among the configurations of the current workspace
    target_configuration = self.provider_manager.get_configurations(tenant_id).get(provider)
    if not target_configuration:
        raise ValueError(f"Provider {provider} does not exist.")

    # Get model instance of LLM
    llm_instance = cast(LargeLanguageModel, target_configuration.get_model_type_instance(ModelType.LLM))

    # fetch credentials
    current_credentials = target_configuration.get_current_credentials(model_type=ModelType.LLM, model=model)
    if current_credentials:
        # Ask the model instance for the rules of this specific model
        return list(llm_instance.get_parameter_rules(model=model, credentials=current_credentials))

    return []
def update_default_model_of_model_type(self, tenant_id: str, model_type: str, provider: str, model: str) -> None:
    """
    Record ``model`` of ``provider`` as the workspace default for a model type.

    :param tenant_id: workspace id
    :param model_type: model type (string form, converted with ``ModelType.value_of``)
    :param provider: provider name
    :param model: model name
    :return:
    """
    # Convert the model type first so the conversion happens before the record is touched.
    record = {
        "tenant_id": tenant_id,
        "model_type": ModelType.value_of(model_type),
        "provider": provider,
        "model": model,
    }
    self.provider_manager.update_default_model_record(**record)
def switch_preferred_provider(self, tenant_id: str, provider: str, preferred_provider_type: str) -> None:
    """
    Switch the preferred provider type of one provider in a workspace.

    :param tenant_id: workspace id
    :param provider: provider name
    :param preferred_provider_type: preferred provider type (string form,
        converted with ``ProviderType.value_of``)
    :return:
    :raises ValueError: if the workspace has no configuration for ``provider``
    """
    workspace_configs = self.provider_manager.get_configurations(tenant_id)

    # The type string is converted before the provider lookup, as in the original flow.
    target_type = ProviderType.value_of(preferred_provider_type)

    if not (config := workspace_configs.get(provider)):
        raise ValueError(f"Provider {provider} does not exist.")

    config.switch_preferred_provider_type(target_type)
def free_quota_submit(self, tenant_id: str, provider: str):
    """
    Apply for a provider's free quota on behalf of a workspace.

    Posts the application to ``$FREE_QUOTA_APPLY_BASE_URL/api/v1/providers/apply``
    and authenticates with ``$FREE_QUOTA_APPLY_API_KEY``.

    :param tenant_id: workspace id
    :param provider: provider name
    :return: ``{"type": ..., "redirect_url": ...}`` when the server answers with
        type ``"redirect"``, otherwise ``{"type": ..., "result": "success"}``
    :raises ValueError: on a non-OK HTTP status, or when the response ``code``
        is not ``"success"``
    """
    api_key = os.environ.get("FREE_QUOTA_APPLY_API_KEY")
    api_base_url = os.environ.get("FREE_QUOTA_APPLY_BASE_URL", "")
    api_url = api_base_url + "/api/v1/providers/apply"

    headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
    # requests never times out unless told to; bound the call as (connect, read)
    # seconds so a stalled quota server cannot hang the calling worker forever.
    response = requests.post(
        api_url,
        headers=headers,
        json={"workspace_id": tenant_id, "provider_name": provider},
        timeout=(3, 10),
    )
    if not response.ok:
        logger.error(f"Request FREE QUOTA APPLY SERVER Error: {response.status_code} ")
        raise ValueError(f"Error: {response.status_code} ")

    # Decode the body once and reuse it instead of re-parsing it for every field.
    rst = response.json()
    if rst["code"] != "success":
        raise ValueError(f"error: {rst['message']}")

    if rst["type"] == "redirect":
        return {"type": rst["type"], "redirect_url": rst["redirect_url"]}
    return {"type": rst["type"], "result": "success"}
class OperationService:
    """HTTP client for the API configured through ``BILLING_API_URL``."""

    # Both settings are read once, when the class body executes. When a variable
    # is unset, its own name is used as a placeholder value.
    base_url = os.environ.get("BILLING_API_URL", "BILLING_API_URL")
    secret_key = os.environ.get("BILLING_API_SECRET_KEY", "BILLING_API_SECRET_KEY")

    # (connect, read) timeout in seconds. requests waits indefinitely without
    # one; it is a class attribute so deployments can override it.
    request_timeout = (3, 10)

    @classmethod
    def _send_request(cls, method, endpoint, json=None, params=None):
        """
        Send one request to ``base_url + endpoint`` and return the decoded JSON body.

        :param method: HTTP method name
        :param endpoint: path appended verbatim to ``base_url``
        :param json: optional JSON body
        :param params: optional query parameters
        :return: the response body decoded from JSON
        """
        headers = {"Content-Type": "application/json", "Billing-Api-Secret-Key": cls.secret_key}

        url = f"{cls.base_url}{endpoint}"
        response = requests.request(
            method, url, json=json, params=params, headers=headers, timeout=cls.request_timeout
        )

        return response.json()

    @classmethod
    def record_utm(cls, tenant_id: str, utm_info: dict):
        """
        Report a tenant's UTM attribution via ``POST /tenant_utms``.

        :param tenant_id: workspace id
        :param utm_info: mapping that may contain the ``utm_*`` keys; missing
            keys are sent as empty strings
        :return: the decoded JSON response
        """
        utm_keys = ("utm_source", "utm_medium", "utm_campaign", "utm_content", "utm_term")
        params = {"tenant_id": tenant_id, **{key: utm_info.get(key, "") for key in utm_keys}}
        return cls._send_request("POST", "/tenant_utms", params=params)
@classmethod
def create_tracing_app_config(cls, app_id: str, tracing_provider: str, tracing_config: dict):
    """
    Create tracing app config

    :param app_id: app id
    :param tracing_provider: tracing provider
    :param tracing_config: tracing config; empty-string values of the provider's
        ``other_keys`` are replaced in place with the config class defaults
    :return: ``{"error": ...}`` for an unknown provider or credentials that fail
        the effectiveness check, ``None`` when a config already exists for this
        app/provider or the app cannot be found, ``{"result": "success"}`` otherwise
    """
    # Reject every provider without an entry in provider_config_map, including an
    # empty or None name. The previous guard let falsy names through to the
    # direct map lookups below.
    if tracing_provider not in provider_config_map:
        return {"error": f"Invalid tracing provider: {tracing_provider}"}

    provider_entry = provider_config_map[tracing_provider]
    config_class, other_keys = provider_entry["config_class"], provider_entry["other_keys"]
    # FIXME: ignore type error
    default_config_instance = config_class(**tracing_config)  # type: ignore
    for key in other_keys:  # type: ignore
        # An empty string means "not provided": fall back to the config class default.
        if key in tracing_config and tracing_config[key] == "":
            tracing_config[key] = getattr(default_config_instance, key, None)

    # api check
    if not OpsTraceManager.check_trace_config_is_effective(tracing_config, tracing_provider):
        return {"error": "Invalid Credentials"}

    # get project url
    if tracing_provider == "langfuse":
        project_key = OpsTraceManager.get_trace_config_project_key(tracing_config, tracing_provider)
        project_url = "{host}/project/{key}".format(host=tracing_config.get("host"), key=project_key)
    elif tracing_provider in ("langsmith", "opik"):
        project_url = OpsTraceManager.get_trace_config_project_url(tracing_config, tracing_provider)
    else:
        project_url = None

    # check if trace config already exists
    trace_config_data: Optional[TraceAppConfig] = (
        db.session.query(TraceAppConfig)
        .filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
        .first()
    )
    if trace_config_data:
        return None

    # The tenant id is read from the app row; it is passed to encrypt_tracing_config.
    app = db.session.query(App).filter(App.id == app_id).first()
    if not app:
        return None
    tenant_id = app.tenant_id

    tracing_config = OpsTraceManager.encrypt_tracing_config(tenant_id, tracing_provider, tracing_config)
    if project_url:
        # Set after encryption, so the stored value is the URL computed above.
        tracing_config["project_url"] = project_url

    trace_config_data = TraceAppConfig(
        app_id=app_id,
        tracing_provider=tracing_provider,
        tracing_config=tracing_config,
    )
    db.session.add(trace_config_data)
    db.session.commit()

    return {"result": "success"}
@classmethod
def delete_tracing_app_config(cls, app_id: str, tracing_provider: str):
    """
    Delete the tracing config stored for an app/provider pair.

    :param app_id: app id
    :param tracing_provider: tracing provider
    :return: ``True`` after the row has been deleted and committed, ``None``
        when no matching config exists
    """
    lookup = db.session.query(TraceAppConfig).filter(
        TraceAppConfig.app_id == app_id,
        TraceAppConfig.tracing_provider == tracing_provider,
    )
    existing = lookup.first()

    if existing:
        db.session.delete(existing)
        db.session.commit()
        return True

    return None
class BuildInRecommendAppRetrieval(RecommendAppRetrievalBase):
    """
    Serve recommended apps from the bundled data file
    constants/recommended_apps.json under the Flask application root.
    """

    # Parsed content of recommended_apps.json, cached on the class; None until first load.
    builtin_data: Optional[dict] = None

    def get_type(self) -> str:
        return RecommendAppType.BUILDIN

    def get_recommended_apps_and_categories(self, language: str) -> dict:
        return self.fetch_recommended_apps_from_builtin(language)

    def get_recommend_app_detail(self, app_id: str):
        return self.fetch_recommended_app_detail_from_builtin(app_id)

    @classmethod
    def _get_builtin_data(cls) -> dict:
        """
        Return the bundled data, reading the JSON file while the cache is empty.

        :return: the parsed file content, or an empty dict if it is falsy
        """
        if not cls.builtin_data:
            json_file = Path(path.join(current_app.root_path, "constants", "recommended_apps.json"))
            cls.builtin_data = json.loads(json_file.read_text(encoding="utf-8"))

        return cls.builtin_data or {}

    @classmethod
    def fetch_recommended_apps_from_builtin(cls, language: str) -> dict:
        """
        Look up the recommended apps bundled for one language.

        :param language: language
        :return: the entry stored under ``recommended_apps[language]``, or an empty dict
        """
        apps_by_language: dict[str, dict] = cls._get_builtin_data().get("recommended_apps", {})
        return apps_by_language.get(language, {})

    @classmethod
    def fetch_recommended_app_detail_from_builtin(cls, app_id: str) -> Optional[dict]:
        """
        Look up the bundled detail of one recommended app.

        :param app_id: App ID
        :return: the entry stored under ``app_details[app_id]``, or None
        """
        details_by_app: dict[str, dict] = cls._get_builtin_data().get("app_details", {})
        return details_by_app.get(app_id)
@classmethod
def fetch_recommended_app_detail_from_db(cls, app_id: str) -> Optional[dict]:
    """
    Build the detail payload of a recommended app stored in the database.

    :param app_id: App ID
    :return: dict with id, name, icon, icon_background, mode and export_data,
        or None when the app is not a listed recommended app, does not exist,
        or is not public
    """
    # The app must be a listed recommendation before any of its data is loaded.
    listing_query = db.session.query(RecommendedApp).filter(
        RecommendedApp.is_listed == True,  # noqa: E712 - builds a SQL expression
        RecommendedApp.app_id == app_id,
    )
    if not listing_query.first():
        return None

    app_model = db.session.query(App).filter(App.id == app_id).first()
    if app_model and app_model.is_public:
        return {
            "id": app_model.id,
            "name": app_model.name,
            "icon": app_model.icon,
            "icon_background": app_model.icon_background,
            "mode": app_model.mode,
            "export_data": AppDslService.export_dsl(app_model=app_model),
        }

    return None
class RecommendAppRetrievalFactory:
    """Select the recommended-app retrieval class for a configured mode."""

    @staticmethod
    def get_recommend_app_factory(mode: str) -> type[RecommendAppRetrievalBase]:
        """
        Map a retrieval mode to its retrieval class.

        :param mode: one of the ``RecommendAppType`` values
        :return: the retrieval class (not an instance)
        :raises ValueError: if ``mode`` matches no ``RecommendAppType`` value
        """
        if mode == RecommendAppType.REMOTE:
            return RemoteRecommendAppRetrieval
        if mode == RecommendAppType.DATABASE:
            return DatabaseRecommendAppRetrieval
        if mode == RecommendAppType.BUILDIN:
            return BuildInRecommendAppRetrieval

        raise ValueError(f"invalid fetch recommended apps mode: {mode}")

    @staticmethod
    def get_buildin_recommend_app_retrieval():
        """Return the built-in retrieval class."""
        return BuildInRecommendAppRetrieval
class RemoteRecommendAppRetrieval(RecommendAppRetrievalBase):
    """
    Retrieve recommended apps from the remote template service configured in
    ``HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN``, using the built-in data when
    the remote fetch raises.
    """

    def get_recommend_app_detail(self, app_id: str):
        try:
            return self.fetch_recommended_app_detail_from_dify_official(app_id)
        except Exception as e:
            # Any failure of the remote call is answered from the bundled data.
            logger.warning(f"fetch recommended app detail from dify official failed: {e}, switch to built-in.")
            return BuildInRecommendAppRetrieval.fetch_recommended_app_detail_from_builtin(app_id)

    def get_recommended_apps_and_categories(self, language: str) -> dict:
        try:
            return self.fetch_recommended_apps_from_dify_official(language)
        except Exception as e:
            logger.warning(f"fetch recommended apps from dify official failed: {e}, switch to built-in.")
            return BuildInRecommendAppRetrieval.fetch_recommended_apps_from_builtin(language)

    def get_type(self) -> str:
        return RecommendAppType.REMOTE

    @classmethod
    def fetch_recommended_app_detail_from_dify_official(cls, app_id: str) -> Optional[dict]:
        """
        Fetch one app's detail from the remote service.

        :param app_id: App ID
        :return: the decoded JSON body, or None when the status code is not 200
        """
        remote_domain = dify_config.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN
        response = requests.get(f"{remote_domain}/apps/{app_id}", timeout=(3, 10))
        if response.status_code == 200:
            detail: dict = response.json()
            return detail

        # A non-200 answer is reported as "no detail" rather than raised.
        return None

    @classmethod
    def fetch_recommended_apps_from_dify_official(cls, language: str) -> dict:
        """
        Fetch the recommended apps of one language from the remote service.

        :param language: language
        :return: the decoded JSON body, with ``categories`` sorted when present
        :raises ValueError: when the status code is not 200
        """
        remote_domain = dify_config.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN
        response = requests.get(f"{remote_domain}/apps?language={language}", timeout=(3, 10))
        if response.status_code != 200:
            raise ValueError(f"fetch recommended apps failed, status code: {response.status_code}")

        payload: dict = response.json()
        if "categories" in payload:
            payload["categories"] = sorted(payload["categories"])

        return payload
class SavedMessageService:
    """Manage the messages a user has saved within an app."""

    @staticmethod
    def _creator_role(user) -> str:
        """Return the ``created_by_role`` value used for this kind of user."""
        return "account" if isinstance(user, Account) else "end_user"

    @classmethod
    def _find_saved_message(cls, app_model, user, message_id):
        """Return the user's SavedMessage row for ``message_id`` in this app, or None."""
        return (
            db.session.query(SavedMessage)
            .filter(
                SavedMessage.app_id == app_model.id,
                SavedMessage.message_id == message_id,
                SavedMessage.created_by_role == cls._creator_role(user),
                SavedMessage.created_by == user.id,
            )
            .first()
        )

    @classmethod
    def pagination_by_last_id(
        cls, app_model: App, user: Optional[Union[Account, EndUser]], last_id: Optional[str], limit: int
    ) -> InfiniteScrollPagination:
        """
        Page through the messages the user has saved in the app.

        :raises ValueError: if ``user`` is missing
        """
        if not user:
            raise ValueError("User is required")

        # Every saved row of the user is loaded, newest first; the ids then
        # restrict the regular message pagination.
        saved_rows = (
            db.session.query(SavedMessage)
            .filter(
                SavedMessage.app_id == app_model.id,
                SavedMessage.created_by_role == cls._creator_role(user),
                SavedMessage.created_by == user.id,
            )
            .order_by(SavedMessage.created_at.desc())
            .all()
        )

        return MessageService.pagination_by_last_id(
            app_model=app_model,
            user=user,
            last_id=last_id,
            limit=limit,
            include_ids=[row.message_id for row in saved_rows],
        )

    @classmethod
    def save(cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str):
        """Save a message for the user; a missing user or an existing save is a no-op."""
        if not user:
            return
        if cls._find_saved_message(app_model, user, message_id):
            return

        # The message is fetched through MessageService before the row is created.
        message = MessageService.get_message(app_model=app_model, user=user, message_id=message_id)

        db.session.add(
            SavedMessage(
                app_id=app_model.id,
                message_id=message.id,
                created_by_role=cls._creator_role(user),
                created_by=user.id,
            )
        )
        db.session.commit()

    @classmethod
    def delete(cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str):
        """Remove the user's save of a message; a missing user or save is a no-op."""
        if not user:
            return

        saved_message = cls._find_saved_message(app_model, user, message_id)
        if saved_message:
            db.session.delete(saved_message)
            db.session.commit()
query.order_by(Tag.created_at.desc()).all() + return results + + @staticmethod + def get_target_ids_by_tag_ids(tag_type: str, current_tenant_id: str, tag_ids: list) -> list: + tags = ( + db.session.query(Tag) + .filter(Tag.id.in_(tag_ids), Tag.tenant_id == current_tenant_id, Tag.type == tag_type) + .all() + ) + if not tags: + return [] + tag_ids = [tag.id for tag in tags] + tag_bindings = ( + db.session.query(TagBinding.target_id) + .filter(TagBinding.tag_id.in_(tag_ids), TagBinding.tenant_id == current_tenant_id) + .all() + ) + if not tag_bindings: + return [] + results = [tag_binding.target_id for tag_binding in tag_bindings] + return results + + @staticmethod + def get_tags_by_target_id(tag_type: str, current_tenant_id: str, target_id: str) -> list: + tags = ( + db.session.query(Tag) + .join(TagBinding, Tag.id == TagBinding.tag_id) + .filter( + TagBinding.target_id == target_id, + TagBinding.tenant_id == current_tenant_id, + Tag.tenant_id == current_tenant_id, + Tag.type == tag_type, + ) + .all() + ) + + return tags or [] + + @staticmethod + def save_tags(args: dict) -> Tag: + tag = Tag( + id=str(uuid.uuid4()), + name=args["name"], + type=args["type"], + created_by=current_user.id, + tenant_id=current_user.current_tenant_id, + ) + db.session.add(tag) + db.session.commit() + return tag + + @staticmethod + def update_tags(args: dict, tag_id: str) -> Tag: + tag = db.session.query(Tag).filter(Tag.id == tag_id).first() + if not tag: + raise NotFound("Tag not found") + tag.name = args["name"] + db.session.commit() + return tag + + @staticmethod + def get_tag_binding_count(tag_id: str) -> int: + count = db.session.query(TagBinding).filter(TagBinding.tag_id == tag_id).count() + return count + + @staticmethod + def delete_tag(tag_id: str): + tag = db.session.query(Tag).filter(Tag.id == tag_id).first() + if not tag: + raise NotFound("Tag not found") + db.session.delete(tag) + # delete tag binding + tag_bindings = db.session.query(TagBinding).filter(TagBinding.tag_id == 
tag_id).all() + if tag_bindings: + for tag_binding in tag_bindings: + db.session.delete(tag_binding) + db.session.commit() + + @staticmethod + def save_tag_binding(args): + # check if target exists + TagService.check_target_exists(args["type"], args["target_id"]) + # save tag binding + for tag_id in args["tag_ids"]: + tag_binding = ( + db.session.query(TagBinding) + .filter(TagBinding.tag_id == tag_id, TagBinding.target_id == args["target_id"]) + .first() + ) + if tag_binding: + continue + new_tag_binding = TagBinding( + tag_id=tag_id, + target_id=args["target_id"], + tenant_id=current_user.current_tenant_id, + created_by=current_user.id, + ) + db.session.add(new_tag_binding) + db.session.commit() + + @staticmethod + def delete_tag_binding(args): + # check if target exists + TagService.check_target_exists(args["type"], args["target_id"]) + # delete tag binding + tag_bindings = ( + db.session.query(TagBinding) + .filter(TagBinding.target_id == args["target_id"], TagBinding.tag_id == (args["tag_id"])) + .first() + ) + if tag_bindings: + db.session.delete(tag_bindings) + db.session.commit() + + @staticmethod + def check_target_exists(type: str, target_id: str): + if type == "knowledge": + dataset = ( + db.session.query(Dataset) + .filter(Dataset.tenant_id == current_user.current_tenant_id, Dataset.id == target_id) + .first() + ) + if not dataset: + raise NotFound("Dataset not found") + elif type == "app": + app = ( + db.session.query(App) + .filter(App.tenant_id == current_user.current_tenant_id, App.id == target_id) + .first() + ) + if not app: + raise NotFound("App not found") + else: + raise NotFound("Invalid binding type") diff --git a/api/services/tools/api_tools_manage_service.py b/api/services/tools/api_tools_manage_service.py new file mode 100644 index 0000000000000000000000000000000000000000..988f9df927e4f72287a886906280e33bbc419127 --- /dev/null +++ b/api/services/tools/api_tools_manage_service.py @@ -0,0 +1,469 @@ +import json +import logging +from 
collections.abc import Mapping +from typing import Any, Optional, cast + +from httpx import get + +from core.model_runtime.utils.encoders import jsonable_encoder +from core.tools.entities.api_entities import UserTool, UserToolProvider +from core.tools.entities.common_entities import I18nObject +from core.tools.entities.tool_bundle import ApiToolBundle +from core.tools.entities.tool_entities import ( + ApiProviderAuthType, + ApiProviderSchemaType, + ToolCredentialsOption, + ToolProviderCredentials, +) +from core.tools.provider.api_tool_provider import ApiToolProviderController +from core.tools.tool_label_manager import ToolLabelManager +from core.tools.tool_manager import ToolManager +from core.tools.utils.configuration import ToolConfigurationManager +from core.tools.utils.parser import ApiBasedToolSchemaParser +from extensions.ext_database import db +from models.tools import ApiToolProvider +from services.tools.tools_transform_service import ToolTransformService + +logger = logging.getLogger(__name__) + + +class ApiToolManageService: + @staticmethod + def parser_api_schema(schema: str) -> Mapping[str, Any]: + """ + parse api schema to tool bundle + """ + try: + warnings: dict[str, str] = {} + try: + tool_bundles, schema_type = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema, warning=warnings) + except Exception as e: + raise ValueError(f"invalid schema: {str(e)}") + + credentials_schema = [ + ToolProviderCredentials( + name="auth_type", + type=ToolProviderCredentials.CredentialsType.SELECT, + required=True, + default="none", + options=[ + ToolCredentialsOption(value="none", label=I18nObject(en_US="None", zh_Hans="无")), + ToolCredentialsOption(value="api_key", label=I18nObject(en_US="Api Key", zh_Hans="Api Key")), + ], + placeholder=I18nObject(en_US="Select auth type", zh_Hans="选择认证方式"), + ), + ToolProviderCredentials( + name="api_key_header", + type=ToolProviderCredentials.CredentialsType.TEXT_INPUT, + required=False, + 
placeholder=I18nObject(en_US="Enter api key header", zh_Hans="输入 api key header,如:X-API-KEY"), + default="api_key", + help=I18nObject(en_US="HTTP header name for api key", zh_Hans="HTTP 头部字段名,用于传递 api key"), + ), + ToolProviderCredentials( + name="api_key_value", + type=ToolProviderCredentials.CredentialsType.TEXT_INPUT, + required=False, + placeholder=I18nObject(en_US="Enter api key", zh_Hans="输入 api key"), + default="", + ), + ] + + return cast( + Mapping, + jsonable_encoder( + { + "schema_type": schema_type, + "parameters_schema": tool_bundles, + "credentials_schema": credentials_schema, + "warning": warnings, + } + ), + ) + except Exception as e: + raise ValueError(f"invalid schema: {str(e)}") + + @staticmethod + def convert_schema_to_tool_bundles( + schema: str, extra_info: Optional[dict] = None + ) -> tuple[list[ApiToolBundle], str]: + """ + convert schema to tool bundles + + :return: the list of tool bundles, description + """ + try: + tool_bundles = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema, extra_info=extra_info) + return tool_bundles + except Exception as e: + raise ValueError(f"invalid schema: {str(e)}") + + @staticmethod + def create_api_tool_provider( + user_id: str, + tenant_id: str, + provider_name: str, + icon: dict, + credentials: dict, + schema_type: str, + schema: str, + privacy_policy: str, + custom_disclaimer: str, + labels: list[str], + ): + """ + create api tool provider + """ + if schema_type not in [member.value for member in ApiProviderSchemaType]: + raise ValueError(f"invalid schema type {schema}") + + provider_name = provider_name.strip() + + # check if the provider exists + provider = ( + db.session.query(ApiToolProvider) + .filter( + ApiToolProvider.tenant_id == tenant_id, + ApiToolProvider.name == provider_name, + ) + .first() + ) + + if provider is not None: + raise ValueError(f"provider {provider_name} already exists") + + # parse openapi to tool bundle + extra_info: dict[str, str] = {} + # extra info like 
description will be set here + tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info) + + if len(tool_bundles) > 100: + raise ValueError("the number of apis should be less than 100") + + # create db provider + db_provider = ApiToolProvider( + tenant_id=tenant_id, + user_id=user_id, + name=provider_name, + icon=json.dumps(icon), + schema=schema, + description=extra_info.get("description", ""), + schema_type_str=schema_type, + tools_str=json.dumps(jsonable_encoder(tool_bundles)), + credentials_str={}, + privacy_policy=privacy_policy, + custom_disclaimer=custom_disclaimer, + ) + + if "auth_type" not in credentials: + raise ValueError("auth_type is required") + + # get auth type, none or api key + auth_type = ApiProviderAuthType.value_of(credentials["auth_type"]) + + # create provider entity + provider_controller = ApiToolProviderController.from_db(db_provider, auth_type) + # load tools into provider entity + provider_controller.load_bundled_tools(tool_bundles) + + # encrypt credentials + tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller) + encrypted_credentials = tool_configuration.encrypt_tool_credentials(credentials) + db_provider.credentials_str = json.dumps(encrypted_credentials) + + db.session.add(db_provider) + db.session.commit() + + # update labels + ToolLabelManager.update_tool_labels(provider_controller, labels) + + return {"result": "success"} + + @staticmethod + def get_api_tool_provider_remote_schema(user_id: str, tenant_id: str, url: str): + """ + get api tool provider remote schema + """ + headers = { + "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko)" + " Chrome/120.0.0.0 Safari/537.36 Edg/120.0.0.0", + "Accept": "*/*", + } + + try: + response = get(url, headers=headers, timeout=10) + if response.status_code != 200: + raise ValueError(f"Got status code {response.status_code}") + schema = 
response.text + + # try to parse schema, avoid SSRF attack + ApiToolManageService.parser_api_schema(schema) + except Exception as e: + logger.exception("parse api schema error") + raise ValueError("invalid schema, please check the url you provided") + + return {"schema": schema} + + @staticmethod + def list_api_tool_provider_tools(user_id: str, tenant_id: str, provider_name: str) -> list[UserTool]: + """ + list api tool provider tools + """ + provider = ( + db.session.query(ApiToolProvider) + .filter( + ApiToolProvider.tenant_id == tenant_id, + ApiToolProvider.name == provider_name, + ) + .first() + ) + + if provider is None: + raise ValueError(f"you have not added provider {provider_name}") + + controller = ToolTransformService.api_provider_to_controller(db_provider=provider) + labels = ToolLabelManager.get_tool_labels(controller) + + return [ + ToolTransformService.tool_to_user_tool( + tool_bundle, + labels=labels, + ) + for tool_bundle in provider.tools + ] + + @staticmethod + def update_api_tool_provider( + user_id: str, + tenant_id: str, + provider_name: str, + original_provider: str, + icon: dict, + credentials: dict, + schema_type: str, + schema: str, + privacy_policy: str, + custom_disclaimer: str, + labels: list[str], + ): + """ + update api tool provider + """ + if schema_type not in [member.value for member in ApiProviderSchemaType]: + raise ValueError(f"invalid schema type {schema}") + + provider_name = provider_name.strip() + + # check if the provider exists + provider = ( + db.session.query(ApiToolProvider) + .filter( + ApiToolProvider.tenant_id == tenant_id, + ApiToolProvider.name == original_provider, + ) + .first() + ) + + if provider is None: + raise ValueError(f"api provider {provider_name} does not exists") + # parse openapi to tool bundle + extra_info: dict[str, str] = {} + # extra info like description will be set here + tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info) + + # update db provider 
+ provider.name = provider_name + provider.icon = json.dumps(icon) + provider.schema = schema + provider.description = extra_info.get("description", "") + provider.schema_type_str = ApiProviderSchemaType.OPENAPI.value + provider.tools_str = json.dumps(jsonable_encoder(tool_bundles)) + provider.privacy_policy = privacy_policy + provider.custom_disclaimer = custom_disclaimer + + if "auth_type" not in credentials: + raise ValueError("auth_type is required") + + # get auth type, none or api key + auth_type = ApiProviderAuthType.value_of(credentials["auth_type"]) + + # create provider entity + provider_controller = ApiToolProviderController.from_db(provider, auth_type) + # load tools into provider entity + provider_controller.load_bundled_tools(tool_bundles) + + # get original credentials if exists + tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller) + + original_credentials = tool_configuration.decrypt_tool_credentials(provider.credentials) + masked_credentials = tool_configuration.mask_tool_credentials(original_credentials) + # check if the credential has changed, save the original credential + for name, value in credentials.items(): + if name in masked_credentials and value == masked_credentials[name]: + credentials[name] = original_credentials[name] + + credentials = tool_configuration.encrypt_tool_credentials(credentials) + provider.credentials_str = json.dumps(credentials) + + db.session.add(provider) + db.session.commit() + + # delete cache + tool_configuration.delete_tool_credentials_cache() + + # update labels + ToolLabelManager.update_tool_labels(provider_controller, labels) + + return {"result": "success"} + + @staticmethod + def delete_api_tool_provider(user_id: str, tenant_id: str, provider_name: str): + """ + delete tool provider + """ + provider = ( + db.session.query(ApiToolProvider) + .filter( + ApiToolProvider.tenant_id == tenant_id, + ApiToolProvider.name == provider_name, + ) + .first() + ) + + 
if provider is None: + raise ValueError(f"you have not added provider {provider_name}") + + db.session.delete(provider) + db.session.commit() + + return {"result": "success"} + + @staticmethod + def get_api_tool_provider(user_id: str, tenant_id: str, provider: str): + """ + get api tool provider + """ + return ToolManager.user_get_api_provider(provider=provider, tenant_id=tenant_id) + + @staticmethod + def test_api_tool_preview( + tenant_id: str, + provider_name: str, + tool_name: str, + credentials: dict, + parameters: dict, + schema_type: str, + schema: str, + ): + """ + test api tool before adding api tool provider + """ + if schema_type not in [member.value for member in ApiProviderSchemaType]: + raise ValueError(f"invalid schema type {schema_type}") + + try: + tool_bundles, _ = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema) + except Exception as e: + raise ValueError("invalid schema") + + # get tool bundle + tool_bundle = next(filter(lambda tb: tb.operation_id == tool_name, tool_bundles), None) + if tool_bundle is None: + raise ValueError(f"invalid tool name {tool_name}") + + db_provider = ( + db.session.query(ApiToolProvider) + .filter( + ApiToolProvider.tenant_id == tenant_id, + ApiToolProvider.name == provider_name, + ) + .first() + ) + + if not db_provider: + # create a fake db provider + db_provider = ApiToolProvider( + tenant_id="", + user_id="", + name="", + icon="", + schema=schema, + description="", + schema_type_str=ApiProviderSchemaType.OPENAPI.value, + tools_str=json.dumps(jsonable_encoder(tool_bundles)), + credentials_str=json.dumps(credentials), + ) + + if "auth_type" not in credentials: + raise ValueError("auth_type is required") + + # get auth type, none or api key + auth_type = ApiProviderAuthType.value_of(credentials["auth_type"]) + + # create provider entity + provider_controller = ApiToolProviderController.from_db(db_provider, auth_type) + # load tools into provider entity + provider_controller.load_bundled_tools(tool_bundles) + 
+ # decrypt credentials + if db_provider.id: + tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller) + decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials) + # check if the credential has changed, save the original credential + masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials) + for name, value in credentials.items(): + if name in masked_credentials and value == masked_credentials[name]: + credentials[name] = decrypted_credentials[name] + + try: + provider_controller.validate_credentials_format(credentials) + # get tool + tool = provider_controller.get_tool(tool_name) + runtime_tool = tool.fork_tool_runtime( + runtime={ + "credentials": credentials, + "tenant_id": tenant_id, + } + ) + result = runtime_tool.validate_credentials(credentials, parameters) + except Exception as e: + return {"error": str(e)} + + return {"result": result or "empty response"} + + @staticmethod + def list_api_tools(user_id: str, tenant_id: str) -> list[UserToolProvider]: + """ + list api tools + """ + # get all api providers + db_providers: list[ApiToolProvider] = ( + db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all() or [] + ) + + result: list[UserToolProvider] = [] + + for provider in db_providers: + # convert provider controller to user provider + provider_controller = ToolTransformService.api_provider_to_controller(db_provider=provider) + labels = ToolLabelManager.get_tool_labels(provider_controller) + user_provider = ToolTransformService.api_provider_to_user_provider( + provider_controller, db_provider=provider, decrypt_credentials=True + ) + user_provider.labels = labels + + # add icon + ToolTransformService.repack_provider(user_provider) + + tools = provider_controller.get_tools(user_id=user_id, tenant_id=tenant_id) + + for tool in tools or []: + user_provider.tools.append( + ToolTransformService.tool_to_user_tool( + 
tenant_id=tenant_id, tool=tool, credentials=user_provider.original_credentials, labels=labels + ) + ) + + result.append(user_provider) + + return result diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py new file mode 100644 index 0000000000000000000000000000000000000000..21adbb0074724eeada9424f7a1c3ec7bfdffd867 --- /dev/null +++ b/api/services/tools/builtin_tools_manage_service.py @@ -0,0 +1,248 @@ +import json +import logging +from pathlib import Path + +from sqlalchemy import select +from sqlalchemy.orm import Session + +from configs import dify_config +from core.helper.position_helper import is_filtered +from core.model_runtime.utils.encoders import jsonable_encoder +from core.tools.entities.api_entities import UserTool, UserToolProvider +from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidationError, ToolProviderNotFoundError +from core.tools.provider.builtin._positions import BuiltinToolProviderSort +from core.tools.provider.tool_provider import ToolProviderController +from core.tools.tool_label_manager import ToolLabelManager +from core.tools.tool_manager import ToolManager +from core.tools.utils.configuration import ToolConfigurationManager +from extensions.ext_database import db +from models.tools import BuiltinToolProvider +from services.tools.tools_transform_service import ToolTransformService + +logger = logging.getLogger(__name__) + + +class BuiltinToolManageService: + @staticmethod + def list_builtin_tool_provider_tools(user_id: str, tenant_id: str, provider: str) -> list[UserTool]: + """ + list builtin tool provider tools + """ + provider_controller: ToolProviderController = ToolManager.get_builtin_provider(provider) + tools = provider_controller.get_tools() + + tool_provider_configurations = ToolConfigurationManager( + tenant_id=tenant_id, provider_controller=provider_controller + ) + # check if user has added the provider + builtin_provider = ( + 
db.session.query(BuiltinToolProvider) + .filter( + BuiltinToolProvider.tenant_id == tenant_id, + BuiltinToolProvider.provider == provider, + ) + .first() + ) + + credentials = {} + if builtin_provider is not None: + # get credentials + credentials = builtin_provider.credentials + credentials = tool_provider_configurations.decrypt_tool_credentials(credentials) + + result: list[UserTool] = [] + for tool in tools or []: + result.append( + ToolTransformService.tool_to_user_tool( + tool=tool, + credentials=credentials, + tenant_id=tenant_id, + labels=ToolLabelManager.get_tool_labels(provider_controller), + ) + ) + + return result + + @staticmethod + def list_builtin_provider_credentials_schema(provider_name): + """ + list builtin provider credentials schema + + :return: the list of tool providers + """ + provider = ToolManager.get_builtin_provider(provider_name) + return jsonable_encoder([v for _, v in (provider.credentials_schema or {}).items()]) + + @staticmethod + def update_builtin_tool_provider( + session: Session, user_id: str, tenant_id: str, provider_name: str, credentials: dict + ): + """ + update builtin tool provider + """ + # get if the provider exists + stmt = select(BuiltinToolProvider).where( + BuiltinToolProvider.tenant_id == tenant_id, + BuiltinToolProvider.provider == provider_name, + ) + provider = session.scalar(stmt) + + try: + # get provider + provider_controller = ToolManager.get_builtin_provider(provider_name) + if not provider_controller.need_credentials: + raise ValueError(f"provider {provider_name} does not need credentials") + tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller) + # get original credentials if exists + if provider is not None: + original_credentials = tool_configuration.decrypt_tool_credentials(provider.credentials) + masked_credentials = tool_configuration.mask_tool_credentials(original_credentials) + # check if the credential has changed, save the original credential + 
for name, value in credentials.items(): + if name in masked_credentials and value == masked_credentials[name]: + credentials[name] = original_credentials[name] + # validate credentials + provider_controller.validate_credentials(credentials) + # encrypt credentials + credentials = tool_configuration.encrypt_tool_credentials(credentials) + except (ToolProviderNotFoundError, ToolNotFoundError, ToolProviderCredentialValidationError) as e: + raise ValueError(str(e)) + + if provider is None: + # create provider + provider = BuiltinToolProvider( + tenant_id=tenant_id, + user_id=user_id, + provider=provider_name, + encrypted_credentials=json.dumps(credentials), + ) + + session.add(provider) + + else: + provider.encrypted_credentials = json.dumps(credentials) + + # delete cache + tool_configuration.delete_tool_credentials_cache() + + return {"result": "success"} + + @staticmethod + def get_builtin_tool_provider_credentials(tenant_id: str, provider_name: str): + """ + get builtin tool provider credentials + """ + provider = ( + db.session.query(BuiltinToolProvider) + .filter( + BuiltinToolProvider.tenant_id == tenant_id, + BuiltinToolProvider.provider == provider_name, + ) + .first() + ) + + if provider is None: + return {} + + provider_controller = ToolManager.get_builtin_provider(provider.provider) + tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller) + credentials = tool_configuration.decrypt_tool_credentials(provider.credentials) + credentials = tool_configuration.mask_tool_credentials(credentials) + return credentials + + @staticmethod + def delete_builtin_tool_provider(user_id: str, tenant_id: str, provider_name: str): + """ + delete tool provider + """ + provider = ( + db.session.query(BuiltinToolProvider) + .filter( + BuiltinToolProvider.tenant_id == tenant_id, + BuiltinToolProvider.provider == provider_name, + ) + .first() + ) + + if provider is None: + raise ValueError(f"you have not added provider 
{provider_name}") + + db.session.delete(provider) + db.session.commit() + + # delete cache + provider_controller = ToolManager.get_builtin_provider(provider_name) + tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller) + tool_configuration.delete_tool_credentials_cache() + + return {"result": "success"} + + @staticmethod + def get_builtin_tool_provider_icon(provider: str): + """ + get tool provider icon and it's mimetype + """ + icon_path, mime_type = ToolManager.get_builtin_provider_icon(provider) + icon_bytes = Path(icon_path).read_bytes() + + return icon_bytes, mime_type + + @staticmethod + def list_builtin_tools(user_id: str, tenant_id: str) -> list[UserToolProvider]: + """ + list builtin tools + """ + # get all builtin providers + provider_controllers = ToolManager.list_builtin_providers() + + # get all user added providers + db_providers: list[BuiltinToolProvider] = ( + db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all() or [] + ) + + # find provider + find_provider = lambda provider: next( + filter(lambda db_provider: db_provider.provider == provider, db_providers), None + ) + + result: list[UserToolProvider] = [] + + for provider_controller in provider_controllers: + try: + # handle include, exclude + if is_filtered( + include_set=dify_config.POSITION_TOOL_INCLUDES_SET, + exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, + data=provider_controller, + name_func=lambda x: x.identity.name, + ): + continue + if provider_controller.identity is None: + continue + + # convert provider controller to user provider + user_builtin_provider = ToolTransformService.builtin_provider_to_user_provider( + provider_controller=provider_controller, + db_provider=find_provider(provider_controller.identity.name), + decrypt_credentials=True, + ) + + # add icon + ToolTransformService.repack_provider(user_builtin_provider) + + tools = provider_controller.get_tools() + for tool in tools 
or []: + user_builtin_provider.tools.append( + ToolTransformService.tool_to_user_tool( + tenant_id=tenant_id, + tool=tool, + credentials=user_builtin_provider.original_credentials, + labels=ToolLabelManager.get_tool_labels(provider_controller), + ) + ) + + result.append(user_builtin_provider) + except Exception as e: + raise e + + return BuiltinToolProviderSort.sort(result) diff --git a/api/services/tools/tool_labels_service.py b/api/services/tools/tool_labels_service.py new file mode 100644 index 0000000000000000000000000000000000000000..35e58b5adec58f1ad3b948eda72a59d7f9540ad0 --- /dev/null +++ b/api/services/tools/tool_labels_service.py @@ -0,0 +1,8 @@ +from core.tools.entities.tool_entities import ToolLabel +from core.tools.entities.values import default_tool_labels + + +class ToolLabelsService: + @classmethod + def list_tool_labels(cls) -> list[ToolLabel]: + return default_tool_labels diff --git a/api/services/tools/tools_manage_service.py b/api/services/tools/tools_manage_service.py new file mode 100644 index 0000000000000000000000000000000000000000..1c67f7648ca99f843e7b9b73fd8295d86fb1b80a --- /dev/null +++ b/api/services/tools/tools_manage_service.py @@ -0,0 +1,26 @@ +import logging + +from core.tools.entities.api_entities import UserToolProviderTypeLiteral +from core.tools.tool_manager import ToolManager +from services.tools.tools_transform_service import ToolTransformService + +logger = logging.getLogger(__name__) + + +class ToolCommonService: + @staticmethod + def list_tool_providers(user_id: str, tenant_id: str, typ: UserToolProviderTypeLiteral = None): + """ + list tool providers + + :return: the list of tool providers + """ + providers = ToolManager.user_list_providers(user_id, tenant_id, typ) + + # add icon + for provider in providers: + ToolTransformService.repack_provider(provider) + + result = [provider.to_dict() for provider in providers] + + return result diff --git a/api/services/tools/tools_transform_service.py 
b/api/services/tools/tools_transform_service.py new file mode 100644 index 0000000000000000000000000000000000000000..6e3a45be0da1c4569ec7f453f61ebd657745ee13 --- /dev/null +++ b/api/services/tools/tools_transform_service.py @@ -0,0 +1,290 @@ +import json +import logging +from typing import Optional, Union, cast + +from configs import dify_config +from core.tools.entities.api_entities import UserTool, UserToolProvider +from core.tools.entities.common_entities import I18nObject +from core.tools.entities.tool_bundle import ApiToolBundle +from core.tools.entities.tool_entities import ( + ApiProviderAuthType, + ToolParameter, + ToolProviderCredentials, + ToolProviderType, +) +from core.tools.provider.api_tool_provider import ApiToolProviderController +from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController +from core.tools.provider.workflow_tool_provider import WorkflowToolProviderController +from core.tools.tool.tool import Tool +from core.tools.tool.workflow_tool import WorkflowTool +from core.tools.utils.configuration import ToolConfigurationManager +from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider + +logger = logging.getLogger(__name__) + + +class ToolTransformService: + @staticmethod + def get_tool_provider_icon_url(provider_type: str, provider_name: str, icon: str) -> Union[str, dict]: + """ + get tool provider icon url + """ + url_prefix = dify_config.CONSOLE_API_URL + "/console/api/workspaces/current/tool-provider/" + + if provider_type == ToolProviderType.BUILT_IN.value: + return url_prefix + "builtin/" + provider_name + "/icon" + elif provider_type in {ToolProviderType.API.value, ToolProviderType.WORKFLOW.value}: + try: + return cast(dict, json.loads(icon)) + except: + return {"background": "#252525", "content": "\ud83d\ude01"} + + return "" + + @staticmethod + def repack_provider(provider: Union[dict, UserToolProvider]): + """ + repack provider + + :param provider: the provider dict + """ + if 
isinstance(provider, dict) and "icon" in provider: + provider["icon"] = ToolTransformService.get_tool_provider_icon_url( + provider_type=provider["type"], provider_name=provider["name"], icon=provider["icon"] + ) + elif isinstance(provider, UserToolProvider): + provider.icon = cast( + str, + ToolTransformService.get_tool_provider_icon_url( + provider_type=provider.type.value, provider_name=provider.name, icon=provider.icon + ), + ) + + @staticmethod + def builtin_provider_to_user_provider( + provider_controller: BuiltinToolProviderController, + db_provider: Optional[BuiltinToolProvider], + decrypt_credentials: bool = True, + ) -> UserToolProvider: + """ + convert provider controller to user provider + """ + if provider_controller.identity is None: + raise ValueError("provider identity is None") + + result = UserToolProvider( + id=provider_controller.identity.name, + author=provider_controller.identity.author, + name=provider_controller.identity.name, + description=I18nObject( + en_US=provider_controller.identity.description.en_US, + zh_Hans=provider_controller.identity.description.zh_Hans, + pt_BR=provider_controller.identity.description.pt_BR, + ja_JP=provider_controller.identity.description.ja_JP, + ), + icon=provider_controller.identity.icon, + label=I18nObject( + en_US=provider_controller.identity.label.en_US, + zh_Hans=provider_controller.identity.label.zh_Hans, + pt_BR=provider_controller.identity.label.pt_BR, + ja_JP=provider_controller.identity.label.ja_JP, + ), + type=ToolProviderType.BUILT_IN, + masked_credentials={}, + is_team_authorization=False, + tools=[], + labels=provider_controller.tool_labels, + ) + + # get credentials schema + schema = provider_controller.get_credentials_schema() + for name, value in schema.items(): + assert result.masked_credentials is not None, "masked credentials is None" + result.masked_credentials[name] = ToolProviderCredentials.CredentialsType.default(str(value.type)) + + # check if the provider need credentials + if not 
provider_controller.need_credentials: + result.is_team_authorization = True + result.allow_delete = False + elif db_provider: + result.is_team_authorization = True + + if decrypt_credentials: + credentials = db_provider.credentials + + # init tool configuration + tool_configuration = ToolConfigurationManager( + tenant_id=db_provider.tenant_id, provider_controller=provider_controller + ) + # decrypt the credentials and mask the credentials + decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials=credentials) + masked_credentials = tool_configuration.mask_tool_credentials(credentials=decrypted_credentials) + + result.masked_credentials = masked_credentials + result.original_credentials = decrypted_credentials + + return result + + @staticmethod + def api_provider_to_controller( + db_provider: ApiToolProvider, + ) -> ApiToolProviderController: + """ + convert provider controller to user provider + """ + # package tool provider controller + controller = ApiToolProviderController.from_db( + db_provider=db_provider, + auth_type=ApiProviderAuthType.API_KEY + if db_provider.credentials["auth_type"] == "api_key" + else ApiProviderAuthType.NONE, + ) + + return controller + + @staticmethod + def workflow_provider_to_controller(db_provider: WorkflowToolProvider) -> WorkflowToolProviderController: + """ + convert provider controller to provider + """ + return WorkflowToolProviderController.from_db(db_provider) + + @staticmethod + def workflow_provider_to_user_provider( + provider_controller: WorkflowToolProviderController, labels: Optional[list[str]] = None + ): + """ + convert provider controller to user provider + """ + if provider_controller.identity is None: + raise ValueError("provider identity is None") + + return UserToolProvider( + id=provider_controller.provider_id, + author=provider_controller.identity.author, + name=provider_controller.identity.name, + description=I18nObject( + en_US=provider_controller.identity.description.en_US, + 
zh_Hans=provider_controller.identity.description.zh_Hans, + ), + icon=provider_controller.identity.icon, + label=I18nObject( + en_US=provider_controller.identity.label.en_US, + zh_Hans=provider_controller.identity.label.zh_Hans, + ), + type=ToolProviderType.WORKFLOW, + masked_credentials={}, + is_team_authorization=True, + tools=[], + labels=labels or [], + ) + + @staticmethod + def api_provider_to_user_provider( + provider_controller: ApiToolProviderController, + db_provider: ApiToolProvider, + decrypt_credentials: bool = True, + labels: Optional[list[str]] = None, + ) -> UserToolProvider: + """ + convert provider controller to user provider + """ + username = "Anonymous" + if db_provider.user is None: + raise ValueError(f"user is None for api provider {db_provider.id}") + try: + username = db_provider.user.name + except Exception as e: + logger.exception(f"failed to get user name for api provider {db_provider.id}") + # add provider into providers + credentials = db_provider.credentials + result = UserToolProvider( + id=db_provider.id, + author=username, + name=db_provider.name, + description=I18nObject( + en_US=db_provider.description, + zh_Hans=db_provider.description, + ), + icon=db_provider.icon, + label=I18nObject( + en_US=db_provider.name, + zh_Hans=db_provider.name, + ), + type=ToolProviderType.API, + masked_credentials={}, + is_team_authorization=True, + tools=[], + labels=labels or [], + ) + + if decrypt_credentials: + # init tool configuration + tool_configuration = ToolConfigurationManager( + tenant_id=db_provider.tenant_id, provider_controller=provider_controller + ) + + # decrypt the credentials and mask the credentials + decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials=credentials) + masked_credentials = tool_configuration.mask_tool_credentials(credentials=decrypted_credentials) + + result.masked_credentials = masked_credentials + + return result + + @staticmethod + def tool_to_user_tool( + tool: Union[ApiToolBundle, 
WorkflowTool, Tool], + credentials: Optional[dict] = None, + tenant_id: Optional[str] = None, + labels: Optional[list[str]] = None, + ) -> UserTool: + """ + convert tool to user tool + """ + if isinstance(tool, Tool): + # fork tool runtime + tool = tool.fork_tool_runtime( + runtime={ + "credentials": credentials, + "tenant_id": tenant_id, + } + ) + + # get tool parameters + parameters = tool.parameters or [] + # get tool runtime parameters + runtime_parameters = tool.get_runtime_parameters() + # override parameters + current_parameters = parameters.copy() + for runtime_parameter in runtime_parameters: + found = False + for index, parameter in enumerate(current_parameters): + if parameter.name == runtime_parameter.name and parameter.form == runtime_parameter.form: + current_parameters[index] = runtime_parameter + found = True + break + + if not found and runtime_parameter.form == ToolParameter.ToolParameterForm.FORM: + current_parameters.append(runtime_parameter) + + if tool.identity is None: + raise ValueError("tool identity is None") + + return UserTool( + author=tool.identity.author, + name=tool.identity.name, + label=tool.identity.label, + description=tool.description.human if tool.description else "", # type: ignore + parameters=current_parameters, + labels=labels, + ) + if isinstance(tool, ApiToolBundle): + return UserTool( + author=tool.author, + name=tool.operation_id or "", + label=I18nObject(en_US=tool.operation_id or "", zh_Hans=tool.operation_id or ""), + description=I18nObject(en_US=tool.summary or "", zh_Hans=tool.summary or ""), + parameters=tool.parameters, + labels=labels, + ) diff --git a/api/services/tools/workflow_tools_manage_service.py b/api/services/tools/workflow_tools_manage_service.py new file mode 100644 index 0000000000000000000000000000000000000000..69430de432b143b73e51f2fb730d4f7b48817cb8 --- /dev/null +++ b/api/services/tools/workflow_tools_manage_service.py @@ -0,0 +1,347 @@ +import json +from collections.abc import Mapping +from 
datetime import datetime +from typing import Any, Optional + +from sqlalchemy import or_ + +from core.model_runtime.utils.encoders import jsonable_encoder +from core.tools.entities.api_entities import UserTool, UserToolProvider +from core.tools.provider.tool_provider import ToolProviderController +from core.tools.provider.workflow_tool_provider import WorkflowToolProviderController +from core.tools.tool.tool import Tool +from core.tools.tool_label_manager import ToolLabelManager +from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils +from extensions.ext_database import db +from models.model import App +from models.tools import WorkflowToolProvider +from models.workflow import Workflow +from services.tools.tools_transform_service import ToolTransformService + + +class WorkflowToolManageService: + """ + Service class for managing workflow tools. + """ + + @staticmethod + def create_workflow_tool( + *, + user_id: str, + tenant_id: str, + workflow_app_id: str, + name: str, + label: str, + icon: dict, + description: str, + parameters: list[Mapping[str, Any]], + privacy_policy: str = "", + labels: Optional[list[str]] = None, + ) -> dict: + WorkflowToolConfigurationUtils.check_parameter_configurations(parameters) + + # check if the name is unique + existing_workflow_tool_provider = ( + db.session.query(WorkflowToolProvider) + .filter( + WorkflowToolProvider.tenant_id == tenant_id, + # name or app_id + or_(WorkflowToolProvider.name == name, WorkflowToolProvider.app_id == workflow_app_id), + ) + .first() + ) + + if existing_workflow_tool_provider is not None: + raise ValueError(f"Tool with name {name} or app_id {workflow_app_id} already exists") + + app = db.session.query(App).filter(App.id == workflow_app_id, App.tenant_id == tenant_id).first() + if app is None: + raise ValueError(f"App {workflow_app_id} not found") + + workflow = app.workflow + if workflow is None: + raise ValueError(f"Workflow not found for app {workflow_app_id}") + + 
workflow_tool_provider = WorkflowToolProvider( + tenant_id=tenant_id, + user_id=user_id, + app_id=workflow_app_id, + name=name, + label=label, + icon=json.dumps(icon), + description=description, + parameter_configuration=json.dumps(parameters), + privacy_policy=privacy_policy, + version=workflow.version, + ) + + try: + WorkflowToolProviderController.from_db(workflow_tool_provider) + except Exception as e: + raise ValueError(str(e)) + + db.session.add(workflow_tool_provider) + db.session.commit() + + if labels is not None: + ToolLabelManager.update_tool_labels( + ToolTransformService.workflow_provider_to_controller(workflow_tool_provider), labels + ) + return {"result": "success"} + + @classmethod + def update_workflow_tool( + cls, + user_id: str, + tenant_id: str, + workflow_tool_id: str, + name: str, + label: str, + icon: dict, + description: str, + parameters: list[Mapping[str, Any]], + privacy_policy: str = "", + labels: Optional[list[str]] = None, + ) -> dict: + """ + Update a workflow tool. 
+ :param user_id: the user id + :param tenant_id: the tenant id + :param workflow_tool_id: workflow tool id + :param name: name + :param label: label + :param icon: icon + :param description: description + :param parameters: parameters + :param privacy_policy: privacy policy + :param labels: labels + :return: the updated tool + """ + WorkflowToolConfigurationUtils.check_parameter_configurations(parameters) + + # check if the name is unique + existing_workflow_tool_provider = ( + db.session.query(WorkflowToolProvider) + .filter( + WorkflowToolProvider.tenant_id == tenant_id, + WorkflowToolProvider.name == name, + WorkflowToolProvider.id != workflow_tool_id, + ) + .first() + ) + + if existing_workflow_tool_provider is not None: + raise ValueError(f"Tool with name {name} already exists") + + workflow_tool_provider: Optional[WorkflowToolProvider] = ( + db.session.query(WorkflowToolProvider) + .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id) + .first() + ) + + if workflow_tool_provider is None: + raise ValueError(f"Tool {workflow_tool_id} not found") + + app: Optional[App] = ( + db.session.query(App).filter(App.id == workflow_tool_provider.app_id, App.tenant_id == tenant_id).first() + ) + + if app is None: + raise ValueError(f"App {workflow_tool_provider.app_id} not found") + + workflow: Optional[Workflow] = app.workflow + if workflow is None: + raise ValueError(f"Workflow not found for app {workflow_tool_provider.app_id}") + + workflow_tool_provider.name = name + workflow_tool_provider.label = label + workflow_tool_provider.icon = json.dumps(icon) + workflow_tool_provider.description = description + workflow_tool_provider.parameter_configuration = json.dumps(parameters) + workflow_tool_provider.privacy_policy = privacy_policy + workflow_tool_provider.version = workflow.version + workflow_tool_provider.updated_at = datetime.now() + + try: + WorkflowToolProviderController.from_db(workflow_tool_provider) + except Exception 
as e: + raise ValueError(str(e)) + + db.session.add(workflow_tool_provider) + db.session.commit() + + if labels is not None: + ToolLabelManager.update_tool_labels( + ToolTransformService.workflow_provider_to_controller(workflow_tool_provider), labels + ) + + return {"result": "success"} + + @classmethod + def list_tenant_workflow_tools(cls, user_id: str, tenant_id: str) -> list[UserToolProvider]: + """ + List workflow tools. + :param user_id: the user id + :param tenant_id: the tenant id + :return: the list of tools + """ + db_tools = db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.tenant_id == tenant_id).all() + + tools = [] + for provider in db_tools: + try: + tools.append(ToolTransformService.workflow_provider_to_controller(provider)) + except: + # skip deleted tools + pass + + labels = ToolLabelManager.get_tools_labels([t for t in tools if isinstance(t, ToolProviderController)]) + + result = [] + + for tool in tools: + user_tool_provider = ToolTransformService.workflow_provider_to_user_provider( + provider_controller=tool, labels=labels.get(tool.provider_id, []) + ) + ToolTransformService.repack_provider(user_tool_provider) + to_user_tool: Optional[list[Tool]] = tool.get_tools(user_id, tenant_id) + if to_user_tool is None or len(to_user_tool) == 0: + continue + user_tool_provider.tools = [ + ToolTransformService.tool_to_user_tool(to_user_tool[0], labels=labels.get(tool.provider_id, [])) + ] + result.append(user_tool_provider) + + return result + + @classmethod + def delete_workflow_tool(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> dict: + """ + Delete a workflow tool. 
+ :param user_id: the user id + :param tenant_id: the tenant id + :param workflow_app_id: the workflow app id + """ + db.session.query(WorkflowToolProvider).filter( + WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id + ).delete() + + db.session.commit() + + return {"result": "success"} + + @classmethod + def get_workflow_tool_by_tool_id(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> dict: + """ + Get a workflow tool. + :param user_id: the user id + :param tenant_id: the tenant id + :param workflow_app_id: the workflow app id + :return: the tool + """ + db_tool: Optional[WorkflowToolProvider] = ( + db.session.query(WorkflowToolProvider) + .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id) + .first() + ) + + if db_tool is None: + raise ValueError(f"Tool {workflow_tool_id} not found") + + workflow_app: Optional[App] = ( + db.session.query(App).filter(App.id == db_tool.app_id, App.tenant_id == tenant_id).first() + ) + + if workflow_app is None: + raise ValueError(f"App {db_tool.app_id} not found") + + tool = ToolTransformService.workflow_provider_to_controller(db_tool) + + to_user_tool: Optional[list[Tool]] = tool.get_tools(user_id, tenant_id) + if to_user_tool is None or len(to_user_tool) == 0: + raise ValueError(f"Tool {workflow_tool_id} not found") + + return { + "name": db_tool.name, + "label": db_tool.label, + "workflow_tool_id": db_tool.id, + "workflow_app_id": db_tool.app_id, + "icon": json.loads(db_tool.icon), + "description": db_tool.description, + "parameters": jsonable_encoder(db_tool.parameter_configurations), + "tool": ToolTransformService.tool_to_user_tool( + to_user_tool[0], labels=ToolLabelManager.get_tool_labels(tool) + ), + "synced": workflow_app.workflow.version == db_tool.version if workflow_app.workflow else False, + "privacy_policy": db_tool.privacy_policy, + } + + @classmethod + def get_workflow_tool_by_app_id(cls, user_id: str, tenant_id: str, 
workflow_app_id: str) -> dict: + """ + Get a workflow tool. + :param user_id: the user id + :param tenant_id: the tenant id + :param workflow_app_id: the workflow app id + :return: the tool + """ + db_tool: Optional[WorkflowToolProvider] = ( + db.session.query(WorkflowToolProvider) + .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.app_id == workflow_app_id) + .first() + ) + + if db_tool is None: + raise ValueError(f"Tool {workflow_app_id} not found") + + workflow_app: Optional[App] = ( + db.session.query(App).filter(App.id == db_tool.app_id, App.tenant_id == tenant_id).first() + ) + + if workflow_app is None: + raise ValueError(f"App {db_tool.app_id} not found") + + tool = ToolTransformService.workflow_provider_to_controller(db_tool) + to_user_tool: Optional[list[Tool]] = tool.get_tools(user_id, tenant_id) + if to_user_tool is None or len(to_user_tool) == 0: + raise ValueError(f"Tool {workflow_app_id} not found") + + return { + "name": db_tool.name, + "label": db_tool.label, + "workflow_tool_id": db_tool.id, + "workflow_app_id": db_tool.app_id, + "icon": json.loads(db_tool.icon), + "description": db_tool.description, + "parameters": jsonable_encoder(db_tool.parameter_configurations), + "tool": ToolTransformService.tool_to_user_tool( + to_user_tool[0], labels=ToolLabelManager.get_tool_labels(tool) + ), + "synced": workflow_app.workflow.version == db_tool.version if workflow_app.workflow else False, + "privacy_policy": db_tool.privacy_policy, + } + + @classmethod + def list_single_workflow_tools(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> list[UserTool]: + """ + List workflow tool provider tools. 
+ :param user_id: the user id + :param tenant_id: the tenant id + :param workflow_app_id: the workflow app id + :return: the list of tools + """ + db_tool: Optional[WorkflowToolProvider] = ( + db.session.query(WorkflowToolProvider) + .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id) + .first() + ) + + if db_tool is None: + raise ValueError(f"Tool {workflow_tool_id} not found") + + tool = ToolTransformService.workflow_provider_to_controller(db_tool) + to_user_tool: Optional[list[Tool]] = tool.get_tools(user_id, tenant_id) + if to_user_tool is None or len(to_user_tool) == 0: + raise ValueError(f"Tool {workflow_tool_id} not found") + + return [ToolTransformService.tool_to_user_tool(to_user_tool[0], labels=ToolLabelManager.get_tool_labels(tool))] diff --git a/api/services/vector_service.py b/api/services/vector_service.py new file mode 100644 index 0000000000000000000000000000000000000000..92422bf29dc121ce36a48b945c4f6b7f1670865f --- /dev/null +++ b/api/services/vector_service.py @@ -0,0 +1,217 @@ +from typing import Optional + +from core.model_manager import ModelInstance, ModelManager +from core.model_runtime.entities.model_entities import ModelType +from core.rag.datasource.keyword.keyword_factory import Keyword +from core.rag.datasource.vdb.vector_factory import Vector +from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.index_processor_factory import IndexProcessorFactory +from core.rag.models.document import Document +from extensions.ext_database import db +from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment +from models.dataset import Document as DatasetDocument +from services.entities.knowledge_entities.knowledge_entities import ParentMode + + +class VectorService: + @classmethod + def create_segments_vector( + cls, keywords_list: Optional[list[list[str]]], segments: list[DocumentSegment], dataset: Dataset, doc_form: str + ): + documents 
= [] + + for segment in segments: + if doc_form == IndexType.PARENT_CHILD_INDEX: + document = DatasetDocument.query.filter_by(id=segment.document_id).first() + # get the process rule + processing_rule = ( + db.session.query(DatasetProcessRule) + .filter(DatasetProcessRule.id == document.dataset_process_rule_id) + .first() + ) + if not processing_rule: + raise ValueError("No processing rule found.") + # get embedding model instance + if dataset.indexing_technique == "high_quality": + # check embedding model setting + model_manager = ModelManager() + + if dataset.embedding_model_provider: + embedding_model_instance = model_manager.get_model_instance( + tenant_id=dataset.tenant_id, + provider=dataset.embedding_model_provider, + model_type=ModelType.TEXT_EMBEDDING, + model=dataset.embedding_model, + ) + else: + embedding_model_instance = model_manager.get_default_model_instance( + tenant_id=dataset.tenant_id, + model_type=ModelType.TEXT_EMBEDDING, + ) + else: + raise ValueError("The knowledge base index technique is not high quality!") + cls.generate_child_chunks(segment, document, dataset, embedding_model_instance, processing_rule, False) + else: + document = Document( + page_content=segment.content, + metadata={ + "doc_id": segment.index_node_id, + "doc_hash": segment.index_node_hash, + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + }, + ) + documents.append(document) + if len(documents) > 0: + index_processor = IndexProcessorFactory(doc_form).init_index_processor() + index_processor.load(dataset, documents, with_keywords=True, keywords_list=keywords_list) + + @classmethod + def update_segment_vector(cls, keywords: Optional[list[str]], segment: DocumentSegment, dataset: Dataset): + # update segment index task + + # format new index + document = Document( + page_content=segment.content, + metadata={ + "doc_id": segment.index_node_id, + "doc_hash": segment.index_node_hash, + "document_id": segment.document_id, + "dataset_id": 
segment.dataset_id, + }, + ) + if dataset.indexing_technique == "high_quality": + # update vector index + vector = Vector(dataset=dataset) + vector.delete_by_ids([segment.index_node_id]) + vector.add_texts([document], duplicate_check=True) + + # update keyword index + keyword = Keyword(dataset) + keyword.delete_by_ids([segment.index_node_id]) + + # save keyword index + if keywords and len(keywords) > 0: + keyword.add_texts([document], keywords_list=[keywords]) + else: + keyword.add_texts([document]) + + @classmethod + def generate_child_chunks( + cls, + segment: DocumentSegment, + dataset_document: DatasetDocument, + dataset: Dataset, + embedding_model_instance: ModelInstance, + processing_rule: DatasetProcessRule, + regenerate: bool = False, + ): + index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor() + if regenerate: + # delete child chunks + index_processor.clean(dataset, [segment.index_node_id], with_keywords=True, delete_child_chunks=True) + + # generate child chunks + document = Document( + page_content=segment.content, + metadata={ + "doc_id": segment.index_node_id, + "doc_hash": segment.index_node_hash, + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + }, + ) + # use full doc mode to generate segment's child chunk + processing_rule_dict = processing_rule.to_dict() + processing_rule_dict["rules"]["parent_mode"] = ParentMode.FULL_DOC.value + documents = index_processor.transform( + [document], + embedding_model_instance=embedding_model_instance, + process_rule=processing_rule_dict, + tenant_id=dataset.tenant_id, + doc_language=dataset_document.doc_language, + ) + # save child chunks + if documents and documents[0].children: + index_processor.load(dataset, documents) + + for position, child_chunk in enumerate(documents[0].children, start=1): + child_segment = ChildChunk( + tenant_id=dataset.tenant_id, + dataset_id=dataset.id, + document_id=dataset_document.id, + segment_id=segment.id, + position=position, + 
index_node_id=child_chunk.metadata["doc_id"], + index_node_hash=child_chunk.metadata["doc_hash"], + content=child_chunk.page_content, + word_count=len(child_chunk.page_content), + type="automatic", + created_by=dataset_document.created_by, + ) + db.session.add(child_segment) + db.session.commit() + + @classmethod + def create_child_chunk_vector(cls, child_segment: ChildChunk, dataset: Dataset): + child_document = Document( + page_content=child_segment.content, + metadata={ + "doc_id": child_segment.index_node_id, + "doc_hash": child_segment.index_node_hash, + "document_id": child_segment.document_id, + "dataset_id": child_segment.dataset_id, + }, + ) + if dataset.indexing_technique == "high_quality": + # save vector index + vector = Vector(dataset=dataset) + vector.add_texts([child_document], duplicate_check=True) + + @classmethod + def update_child_chunk_vector( + cls, + new_child_chunks: list[ChildChunk], + update_child_chunks: list[ChildChunk], + delete_child_chunks: list[ChildChunk], + dataset: Dataset, + ): + documents = [] + delete_node_ids = [] + for new_child_chunk in new_child_chunks: + new_child_document = Document( + page_content=new_child_chunk.content, + metadata={ + "doc_id": new_child_chunk.index_node_id, + "doc_hash": new_child_chunk.index_node_hash, + "document_id": new_child_chunk.document_id, + "dataset_id": new_child_chunk.dataset_id, + }, + ) + documents.append(new_child_document) + for update_child_chunk in update_child_chunks: + child_document = Document( + page_content=update_child_chunk.content, + metadata={ + "doc_id": update_child_chunk.index_node_id, + "doc_hash": update_child_chunk.index_node_hash, + "document_id": update_child_chunk.document_id, + "dataset_id": update_child_chunk.dataset_id, + }, + ) + documents.append(child_document) + delete_node_ids.append(update_child_chunk.index_node_id) + for delete_child_chunk in delete_child_chunks: + delete_node_ids.append(delete_child_chunk.index_node_id) + if dataset.indexing_technique == 
"high_quality": + # update vector index + vector = Vector(dataset=dataset) + if delete_node_ids: + vector.delete_by_ids(delete_node_ids) + if documents: + vector.add_texts(documents, duplicate_check=True) + + @classmethod + def delete_child_chunk_vector(cls, child_chunk: ChildChunk, dataset: Dataset): + vector = Vector(dataset=dataset) + vector.delete_by_ids([child_chunk.index_node_id]) diff --git a/api/services/web_conversation_service.py b/api/services/web_conversation_service.py new file mode 100644 index 0000000000000000000000000000000000000000..f698ed3084bdac2bb32cf36273a2e92fccfb6b5f --- /dev/null +++ b/api/services/web_conversation_service.py @@ -0,0 +1,113 @@ +from typing import Optional, Union + +from sqlalchemy import select +from sqlalchemy.orm import Session + +from core.app.entities.app_invoke_entities import InvokeFrom +from extensions.ext_database import db +from libs.infinite_scroll_pagination import InfiniteScrollPagination +from models.account import Account +from models.model import App, EndUser +from models.web import PinnedConversation +from services.conversation_service import ConversationService + + +class WebConversationService: + @classmethod + def pagination_by_last_id( + cls, + *, + session: Session, + app_model: App, + user: Optional[Union[Account, EndUser]], + last_id: Optional[str], + limit: int, + invoke_from: InvokeFrom, + pinned: Optional[bool] = None, + sort_by="-updated_at", + ) -> InfiniteScrollPagination: + if not user: + raise ValueError("User is required") + include_ids = None + exclude_ids = None + if pinned is not None and user: + stmt = ( + select(PinnedConversation.conversation_id) + .where( + PinnedConversation.app_id == app_model.id, + PinnedConversation.created_by_role == ("account" if isinstance(user, Account) else "end_user"), + PinnedConversation.created_by == user.id, + ) + .order_by(PinnedConversation.created_at.desc()) + ) + pinned_conversation_ids = session.scalars(stmt).all() + + if pinned: + include_ids = 
pinned_conversation_ids + else: + exclude_ids = pinned_conversation_ids + + return ConversationService.pagination_by_last_id( + session=session, + app_model=app_model, + user=user, + last_id=last_id, + limit=limit, + invoke_from=invoke_from, + include_ids=include_ids, + exclude_ids=exclude_ids, + sort_by=sort_by, + ) + + @classmethod + def pin(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]): + if not user: + return + pinned_conversation = ( + db.session.query(PinnedConversation) + .filter( + PinnedConversation.app_id == app_model.id, + PinnedConversation.conversation_id == conversation_id, + PinnedConversation.created_by_role == ("account" if isinstance(user, Account) else "end_user"), + PinnedConversation.created_by == user.id, + ) + .first() + ) + + if pinned_conversation: + return + + conversation = ConversationService.get_conversation( + app_model=app_model, conversation_id=conversation_id, user=user + ) + + pinned_conversation = PinnedConversation( + app_id=app_model.id, + conversation_id=conversation.id, + created_by_role="account" if isinstance(user, Account) else "end_user", + created_by=user.id, + ) + + db.session.add(pinned_conversation) + db.session.commit() + + @classmethod + def unpin(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]): + if not user: + return + pinned_conversation = ( + db.session.query(PinnedConversation) + .filter( + PinnedConversation.app_id == app_model.id, + PinnedConversation.conversation_id == conversation_id, + PinnedConversation.created_by_role == ("account" if isinstance(user, Account) else "end_user"), + PinnedConversation.created_by == user.id, + ) + .first() + ) + + if not pinned_conversation: + return + + db.session.delete(pinned_conversation) + db.session.commit() diff --git a/api/services/website_service.py b/api/services/website_service.py new file mode 100644 index 0000000000000000000000000000000000000000..85d32c9e8aed3264a0a7d5a6673ade43b9d6ebdb --- 
/dev/null +++ b/api/services/website_service.py @@ -0,0 +1,227 @@ +import datetime +import json +from typing import Any + +import requests +from flask_login import current_user # type: ignore + +from core.helper import encrypter +from core.rag.extractor.firecrawl.firecrawl_app import FirecrawlApp +from extensions.ext_redis import redis_client +from extensions.ext_storage import storage +from services.auth.api_key_auth_service import ApiKeyAuthService + + +class WebsiteService: + @classmethod + def document_create_args_validate(cls, args: dict): + if "url" not in args or not args["url"]: + raise ValueError("url is required") + if "options" not in args or not args["options"]: + raise ValueError("options is required") + if "limit" not in args["options"] or not args["options"]["limit"]: + raise ValueError("limit is required") + + @classmethod + def crawl_url(cls, args: dict) -> dict: + provider = args.get("provider", "") + url = args.get("url") + options = args.get("options", "") + credentials = ApiKeyAuthService.get_auth_credentials(current_user.current_tenant_id, "website", provider) + if provider == "firecrawl": + # decrypt api_key + api_key = encrypter.decrypt_token( + tenant_id=current_user.current_tenant_id, token=credentials.get("config").get("api_key") + ) + firecrawl_app = FirecrawlApp(api_key=api_key, base_url=credentials.get("config").get("base_url", None)) + crawl_sub_pages = options.get("crawl_sub_pages", False) + only_main_content = options.get("only_main_content", False) + if not crawl_sub_pages: + params = { + "includePaths": [], + "excludePaths": [], + "limit": 1, + "scrapeOptions": {"onlyMainContent": only_main_content}, + } + else: + includes = options.get("includes").split(",") if options.get("includes") else [] + excludes = options.get("excludes").split(",") if options.get("excludes") else [] + params = { + "includePaths": includes, + "excludePaths": excludes, + "limit": options.get("limit", 1), + "scrapeOptions": {"onlyMainContent": 
only_main_content}, + } + if options.get("max_depth"): + params["maxDepth"] = options.get("max_depth") + job_id = firecrawl_app.crawl_url(url, params) + website_crawl_time_cache_key = f"website_crawl_{job_id}" + time = str(datetime.datetime.now().timestamp()) + redis_client.setex(website_crawl_time_cache_key, 3600, time) + return {"status": "active", "job_id": job_id} + elif provider == "jinareader": + api_key = encrypter.decrypt_token( + tenant_id=current_user.current_tenant_id, token=credentials.get("config").get("api_key") + ) + crawl_sub_pages = options.get("crawl_sub_pages", False) + if not crawl_sub_pages: + response = requests.get( + f"https://r.jina.ai/{url}", + headers={"Accept": "application/json", "Authorization": f"Bearer {api_key}"}, + ) + if response.json().get("code") != 200: + raise ValueError("Failed to crawl") + return {"status": "active", "data": response.json().get("data")} + else: + response = requests.post( + "https://adaptivecrawl-kir3wx7b3a-uc.a.run.app", + json={ + "url": url, + "maxPages": options.get("limit", 1), + "useSitemap": options.get("use_sitemap", True), + }, + headers={ + "Content-Type": "application/json", + "Authorization": f"Bearer {api_key}", + }, + ) + if response.json().get("code") != 200: + raise ValueError("Failed to crawl") + return {"status": "active", "job_id": response.json().get("data", {}).get("taskId")} + else: + raise ValueError("Invalid provider") + + @classmethod + def get_crawl_status(cls, job_id: str, provider: str) -> dict: + credentials = ApiKeyAuthService.get_auth_credentials(current_user.current_tenant_id, "website", provider) + if provider == "firecrawl": + # decrypt api_key + api_key = encrypter.decrypt_token( + tenant_id=current_user.current_tenant_id, token=credentials.get("config").get("api_key") + ) + firecrawl_app = FirecrawlApp(api_key=api_key, base_url=credentials.get("config").get("base_url", None)) + result = firecrawl_app.check_crawl_status(job_id) + crawl_status_data = { + "status": 
result.get("status", "active"), + "job_id": job_id, + "total": result.get("total", 0), + "current": result.get("current", 0), + "data": result.get("data", []), + } + if crawl_status_data["status"] == "completed": + website_crawl_time_cache_key = f"website_crawl_{job_id}" + start_time = redis_client.get(website_crawl_time_cache_key) + if start_time: + end_time = datetime.datetime.now().timestamp() + time_consuming = abs(end_time - float(start_time)) + crawl_status_data["time_consuming"] = f"{time_consuming:.2f}" + redis_client.delete(website_crawl_time_cache_key) + elif provider == "jinareader": + api_key = encrypter.decrypt_token( + tenant_id=current_user.current_tenant_id, token=credentials.get("config").get("api_key") + ) + response = requests.post( + "https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app", + headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}, + json={"taskId": job_id}, + ) + data = response.json().get("data", {}) + crawl_status_data = { + "status": data.get("status", "active"), + "job_id": job_id, + "total": len(data.get("urls", [])), + "current": len(data.get("processed", [])) + len(data.get("failed", [])), + "data": [], + "time_consuming": data.get("duration", 0) / 1000, + } + + if crawl_status_data["status"] == "completed": + response = requests.post( + "https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app", + headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}, + json={"taskId": job_id, "urls": list(data.get("processed", {}).keys())}, + ) + data = response.json().get("data", {}) + formatted_data = [ + { + "title": item.get("data", {}).get("title"), + "source_url": item.get("data", {}).get("url"), + "description": item.get("data", {}).get("description"), + "markdown": item.get("data", {}).get("content"), + } + for item in data.get("processed", {}).values() + ] + crawl_status_data["data"] = formatted_data + else: + raise ValueError("Invalid provider") + return crawl_status_data + + 
@classmethod + def get_crawl_url_data(cls, job_id: str, provider: str, url: str, tenant_id: str) -> dict[Any, Any] | None: + credentials = ApiKeyAuthService.get_auth_credentials(tenant_id, "website", provider) + # decrypt api_key + api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=credentials.get("config").get("api_key")) + # FIXME data is redefine too many times here, use Any to ease the type checking, fix it later + data: Any + if provider == "firecrawl": + file_key = "website_files/" + job_id + ".txt" + if storage.exists(file_key): + d = storage.load_once(file_key) + if d: + data = json.loads(d.decode("utf-8")) + else: + firecrawl_app = FirecrawlApp(api_key=api_key, base_url=credentials.get("config").get("base_url", None)) + result = firecrawl_app.check_crawl_status(job_id) + if result.get("status") != "completed": + raise ValueError("Crawl job is not completed") + data = result.get("data") + if data: + for item in data: + if item.get("source_url") == url: + return dict(item) + return None + elif provider == "jinareader": + if not job_id: + response = requests.get( + f"https://r.jina.ai/{url}", + headers={"Accept": "application/json", "Authorization": f"Bearer {api_key}"}, + ) + if response.json().get("code") != 200: + raise ValueError("Failed to crawl") + return dict(response.json().get("data", {})) + else: + api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=credentials.get("config").get("api_key")) + response = requests.post( + "https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app", + headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}, + json={"taskId": job_id}, + ) + data = response.json().get("data", {}) + if data.get("status") != "completed": + raise ValueError("Crawl job is not completed") + + response = requests.post( + "https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app", + headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}, + json={"taskId": job_id, "urls": 
list(data.get("processed", {}).keys())}, + ) + data = response.json().get("data", {}) + for item in data.get("processed", {}).values(): + if item.get("data", {}).get("url") == url: + return dict(item.get("data", {})) + return None + else: + raise ValueError("Invalid provider") + + @classmethod + def get_scrape_url_data(cls, provider: str, url: str, tenant_id: str, only_main_content: bool) -> dict: + credentials = ApiKeyAuthService.get_auth_credentials(tenant_id, "website", provider) + if provider == "firecrawl": + # decrypt api_key + api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=credentials.get("config").get("api_key")) + firecrawl_app = FirecrawlApp(api_key=api_key, base_url=credentials.get("config").get("base_url", None)) + params = {"onlyMainContent": only_main_content} + result = firecrawl_app.scrape_url(url, params) + return result + else: + raise ValueError("Invalid provider") diff --git a/api/services/workflow/__init__.py b/api/services/workflow/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py new file mode 100644 index 0000000000000000000000000000000000000000..2b0d57bdfdeda3759f62302686501f24f42b13b7 --- /dev/null +++ b/api/services/workflow/workflow_converter.py @@ -0,0 +1,630 @@ +import json +from typing import Any, Optional + +from core.app.app_config.entities import ( + DatasetEntity, + DatasetRetrieveConfigEntity, + EasyUIBasedAppConfig, + ExternalDataVariableEntity, + ModelConfigEntity, + PromptTemplateEntity, + VariableEntity, +) +from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager +from core.app.apps.chat.app_config_manager import ChatAppConfigManager +from core.app.apps.completion.app_config_manager import CompletionAppConfigManager +from core.file.models import FileUploadConfig +from core.helper import encrypter +from 
core.model_runtime.entities.llm_entities import LLMMode +from core.model_runtime.utils.encoders import jsonable_encoder +from core.prompt.simple_prompt_transform import SimplePromptTransform +from core.workflow.nodes import NodeType +from events.app_event import app_was_created +from extensions.ext_database import db +from models.account import Account +from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint +from models.model import App, AppMode, AppModelConfig +from models.workflow import Workflow, WorkflowType + + +class WorkflowConverter: + """ + App Convert to Workflow Mode + """ + + def convert_to_workflow( + self, app_model: App, account: Account, name: str, icon_type: str, icon: str, icon_background: str + ): + """ + Convert app to workflow + + - basic mode of chatbot app + + - expert mode of chatbot app + + - completion app + + :param app_model: App instance + :param account: Account + :param name: new app name + :param icon: new app icon + :param icon_type: new app icon type + :param icon_background: new app icon background + :return: new App instance + """ + # convert app model config + if not app_model.app_model_config: + raise ValueError("App model config is required") + + workflow = self.convert_app_model_config_to_workflow( + app_model=app_model, app_model_config=app_model.app_model_config, account_id=account.id + ) + + # create new app + new_app = App() + new_app.tenant_id = app_model.tenant_id + new_app.name = name or app_model.name + "(workflow)" + new_app.mode = AppMode.ADVANCED_CHAT.value if app_model.mode == AppMode.CHAT.value else AppMode.WORKFLOW.value + new_app.icon_type = icon_type or app_model.icon_type + new_app.icon = icon or app_model.icon + new_app.icon_background = icon_background or app_model.icon_background + new_app.enable_site = app_model.enable_site + new_app.enable_api = app_model.enable_api + new_app.api_rpm = app_model.api_rpm + new_app.api_rph = app_model.api_rph + new_app.is_demo = False + 
new_app.is_public = app_model.is_public + new_app.created_by = account.id + new_app.updated_by = account.id + db.session.add(new_app) + db.session.flush() + db.session.commit() + + workflow.app_id = new_app.id + db.session.commit() + + app_was_created.send(new_app, account=account) + + return new_app + + def convert_app_model_config_to_workflow(self, app_model: App, app_model_config: AppModelConfig, account_id: str): + """ + Convert app model config to workflow mode + :param app_model: App instance + :param app_model_config: AppModelConfig instance + :param account_id: Account ID + """ + # get new app mode + new_app_mode = self._get_new_app_mode(app_model) + + # convert app model config + app_config = self._convert_to_app_config(app_model=app_model, app_model_config=app_model_config) + + # init workflow graph + graph: dict[str, Any] = {"nodes": [], "edges": []} + + # Convert list: + # - variables -> start + # - model_config -> llm + # - prompt_template -> llm + # - file_upload -> llm + # - external_data_variables -> http-request + # - dataset -> knowledge-retrieval + # - show_retrieve_source -> knowledge-retrieval + + # convert to start node + start_node = self._convert_to_start_node(variables=app_config.variables) + + graph["nodes"].append(start_node) + + # convert to http request node + external_data_variable_node_mapping: dict[str, str] = {} + if app_config.external_data_variables: + http_request_nodes, external_data_variable_node_mapping = self._convert_to_http_request_node( + app_model=app_model, + variables=app_config.variables, + external_data_variables=app_config.external_data_variables, + ) + + for http_request_node in http_request_nodes: + graph = self._append_node(graph, http_request_node) + + # convert to knowledge retrieval node + if app_config.dataset: + knowledge_retrieval_node = self._convert_to_knowledge_retrieval_node( + new_app_mode=new_app_mode, dataset_config=app_config.dataset, model_config=app_config.model + ) + + if knowledge_retrieval_node: 
+ graph = self._append_node(graph, knowledge_retrieval_node) + + # convert to llm node + llm_node = self._convert_to_llm_node( + original_app_mode=AppMode.value_of(app_model.mode), + new_app_mode=new_app_mode, + graph=graph, + model_config=app_config.model, + prompt_template=app_config.prompt_template, + file_upload=app_config.additional_features.file_upload, + external_data_variable_node_mapping=external_data_variable_node_mapping, + ) + + graph = self._append_node(graph, llm_node) + + if new_app_mode == AppMode.WORKFLOW: + # convert to end node by app mode + end_node = self._convert_to_end_node() + graph = self._append_node(graph, end_node) + else: + answer_node = self._convert_to_answer_node() + graph = self._append_node(graph, answer_node) + + app_model_config_dict = app_config.app_model_config_dict + + # features + if new_app_mode == AppMode.ADVANCED_CHAT: + features = { + "opening_statement": app_model_config_dict.get("opening_statement"), + "suggested_questions": app_model_config_dict.get("suggested_questions"), + "suggested_questions_after_answer": app_model_config_dict.get("suggested_questions_after_answer"), + "speech_to_text": app_model_config_dict.get("speech_to_text"), + "text_to_speech": app_model_config_dict.get("text_to_speech"), + "file_upload": app_model_config_dict.get("file_upload"), + "sensitive_word_avoidance": app_model_config_dict.get("sensitive_word_avoidance"), + "retriever_resource": app_model_config_dict.get("retriever_resource"), + } + else: + features = { + "text_to_speech": app_model_config_dict.get("text_to_speech"), + "file_upload": app_model_config_dict.get("file_upload"), + "sensitive_word_avoidance": app_model_config_dict.get("sensitive_word_avoidance"), + } + + # create workflow record + workflow = Workflow( + tenant_id=app_model.tenant_id, + app_id=app_model.id, + type=WorkflowType.from_app_mode(new_app_mode).value, + version="draft", + graph=json.dumps(graph), + features=json.dumps(features), + created_by=account_id, + 
environment_variables=[], + conversation_variables=[], + ) + + db.session.add(workflow) + db.session.commit() + + return workflow + + def _convert_to_app_config(self, app_model: App, app_model_config: AppModelConfig) -> EasyUIBasedAppConfig: + app_mode_enum = AppMode.value_of(app_model.mode) + app_config: EasyUIBasedAppConfig + if app_mode_enum == AppMode.AGENT_CHAT or app_model.is_agent: + app_model.mode = AppMode.AGENT_CHAT.value + app_config = AgentChatAppConfigManager.get_app_config( + app_model=app_model, app_model_config=app_model_config + ) + elif app_mode_enum == AppMode.CHAT: + app_config = ChatAppConfigManager.get_app_config(app_model=app_model, app_model_config=app_model_config) + elif app_mode_enum == AppMode.COMPLETION: + app_config = CompletionAppConfigManager.get_app_config( + app_model=app_model, app_model_config=app_model_config + ) + else: + raise ValueError("Invalid app mode") + + return app_config + + def _convert_to_start_node(self, variables: list[VariableEntity]) -> dict: + """ + Convert to Start Node + :param variables: list of variables + :return: + """ + return { + "id": "start", + "position": None, + "data": { + "title": "START", + "type": NodeType.START.value, + "variables": [jsonable_encoder(v) for v in variables], + }, + } + + def _convert_to_http_request_node( + self, app_model: App, variables: list[VariableEntity], external_data_variables: list[ExternalDataVariableEntity] + ) -> tuple[list[dict], dict[str, str]]: + """ + Convert API Based Extension to HTTP Request Node + :param app_model: App instance + :param variables: list of variables + :param external_data_variables: list of external data variables + :return: + """ + index = 1 + nodes = [] + external_data_variable_node_mapping = {} + tenant_id = app_model.tenant_id + for external_data_variable in external_data_variables: + tool_type = external_data_variable.type + if tool_type != "api": + continue + + tool_variable = external_data_variable.variable + tool_config = 
external_data_variable.config + + # get params from config + api_based_extension_id = tool_config.get("api_based_extension_id") + if not api_based_extension_id: + continue + + # get api_based_extension + api_based_extension = self._get_api_based_extension( + tenant_id=tenant_id, api_based_extension_id=api_based_extension_id + ) + + # decrypt api_key + api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=api_based_extension.api_key) + + inputs = {} + for v in variables: + inputs[v.variable] = "{{#start." + v.variable + "#}}" + + request_body = { + "point": APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY.value, + "params": { + "app_id": app_model.id, + "tool_variable": tool_variable, + "inputs": inputs, + "query": "{{#sys.query#}}" if app_model.mode == AppMode.CHAT.value else "", + }, + } + + request_body_json = json.dumps(request_body) + request_body_json = request_body_json.replace(r"\{\{", "{{").replace(r"\}\}", "}}") + + http_request_node = { + "id": f"http_request_{index}", + "position": None, + "data": { + "title": f"HTTP REQUEST {api_based_extension.name}", + "type": NodeType.HTTP_REQUEST.value, + "method": "post", + "url": api_based_extension.api_endpoint, + "authorization": {"type": "api-key", "config": {"type": "bearer", "api_key": api_key}}, + "headers": "", + "params": "", + "body": {"type": "json", "data": request_body_json}, + }, + } + + nodes.append(http_request_node) + + # append code node for response body parsing + code_node: dict[str, Any] = { + "id": f"code_{index}", + "position": None, + "data": { + "title": f"Parse {api_based_extension.name} Response", + "type": NodeType.CODE.value, + "variables": [{"variable": "response_json", "value_selector": [http_request_node["id"], "body"]}], + "code_language": "python3", + "code": "import json\n\ndef main(response_json: str) -> str:\n response_body = json.loads(" + 'response_json)\n return {\n "result": response_body["result"]\n }', + "outputs": {"result": {"type": "string"}}, + }, + } + + 
nodes.append(code_node) + + external_data_variable_node_mapping[external_data_variable.variable] = code_node["id"] + index += 1 + + return nodes, external_data_variable_node_mapping + + def _convert_to_knowledge_retrieval_node( + self, new_app_mode: AppMode, dataset_config: DatasetEntity, model_config: ModelConfigEntity + ) -> Optional[dict]: + """ + Convert datasets to Knowledge Retrieval Node + :param new_app_mode: new app mode + :param dataset_config: dataset + :param model_config: model config + :return: + """ + retrieve_config = dataset_config.retrieve_config + if new_app_mode == AppMode.ADVANCED_CHAT: + query_variable_selector = ["sys", "query"] + elif retrieve_config.query_variable: + # fetch query variable + query_variable_selector = ["start", retrieve_config.query_variable] + else: + return None + + return { + "id": "knowledge_retrieval", + "position": None, + "data": { + "title": "KNOWLEDGE RETRIEVAL", + "type": NodeType.KNOWLEDGE_RETRIEVAL.value, + "query_variable_selector": query_variable_selector, + "dataset_ids": dataset_config.dataset_ids, + "retrieval_mode": retrieve_config.retrieve_strategy.value, + "single_retrieval_config": { + "model": { + "provider": model_config.provider, + "name": model_config.model, + "mode": model_config.mode, + "completion_params": { + **model_config.parameters, + "stop": model_config.stop, + }, + } + } + if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE + else None, + "multiple_retrieval_config": { + "top_k": retrieve_config.top_k, + "score_threshold": retrieve_config.score_threshold, + "reranking_model": retrieve_config.reranking_model, + } + if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE + else None, + }, + } + + def _convert_to_llm_node( + self, + original_app_mode: AppMode, + new_app_mode: AppMode, + graph: dict, + model_config: ModelConfigEntity, + prompt_template: PromptTemplateEntity, + file_upload: Optional[FileUploadConfig] = 
None, + external_data_variable_node_mapping: dict[str, str] | None = None, + ) -> dict: + """ + Convert to LLM Node + :param original_app_mode: original app mode + :param new_app_mode: new app mode + :param graph: graph + :param model_config: model config + :param prompt_template: prompt template + :param file_upload: file upload config (optional) + :param external_data_variable_node_mapping: external data variable node mapping + """ + # fetch start and knowledge retrieval node + start_node = next(filter(lambda n: n["data"]["type"] == NodeType.START.value, graph["nodes"])) + knowledge_retrieval_node = next( + filter(lambda n: n["data"]["type"] == NodeType.KNOWLEDGE_RETRIEVAL.value, graph["nodes"]), None + ) + + role_prefix = None + prompts: Any = None + + # Chat Model + if model_config.mode == LLMMode.CHAT.value: + if prompt_template.prompt_type == PromptTemplateEntity.PromptType.SIMPLE: + if not prompt_template.simple_prompt_template: + raise ValueError("Simple prompt template is required") + # get prompt template + prompt_transform = SimplePromptTransform() + prompt_template_config = prompt_transform.get_prompt_template( + app_mode=original_app_mode, + provider=model_config.provider, + model=model_config.model, + pre_prompt=prompt_template.simple_prompt_template, + has_context=knowledge_retrieval_node is not None, + query_in_prompt=False, + ) + + template = prompt_template_config["prompt_template"].template + if not template: + prompts = [] + else: + template = self._replace_template_variables( + template, start_node["data"]["variables"], external_data_variable_node_mapping + ) + + prompts = [{"role": "user", "text": template}] + else: + advanced_chat_prompt_template = prompt_template.advanced_chat_prompt_template + + prompts = [] + if advanced_chat_prompt_template: + for m in advanced_chat_prompt_template.messages: + text = m.text + text = self._replace_template_variables( + text, start_node["data"]["variables"], external_data_variable_node_mapping + ) + + 
prompts.append({"role": m.role.value, "text": text}) + # Completion Model + else: + if prompt_template.prompt_type == PromptTemplateEntity.PromptType.SIMPLE: + if not prompt_template.simple_prompt_template: + raise ValueError("Simple prompt template is required") + # get prompt template + prompt_transform = SimplePromptTransform() + prompt_template_config = prompt_transform.get_prompt_template( + app_mode=original_app_mode, + provider=model_config.provider, + model=model_config.model, + pre_prompt=prompt_template.simple_prompt_template, + has_context=knowledge_retrieval_node is not None, + query_in_prompt=False, + ) + + template = prompt_template_config["prompt_template"].template + template = self._replace_template_variables( + template=template, + variables=start_node["data"]["variables"], + external_data_variable_node_mapping=external_data_variable_node_mapping, + ) + + prompts = {"text": template} + + prompt_rules = prompt_template_config["prompt_rules"] + role_prefix = { + "user": prompt_rules.get("human_prefix", "Human"), + "assistant": prompt_rules.get("assistant_prefix", "Assistant"), + } + else: + advanced_completion_prompt_template = prompt_template.advanced_completion_prompt_template + if advanced_completion_prompt_template: + text = advanced_completion_prompt_template.prompt + text = self._replace_template_variables( + template=text, + variables=start_node["data"]["variables"], + external_data_variable_node_mapping=external_data_variable_node_mapping, + ) + else: + text = "" + + text = text.replace("{{#query#}}", "{{#sys.query#}}") + + prompts = { + "text": text, + } + + if advanced_completion_prompt_template and advanced_completion_prompt_template.role_prefix: + role_prefix = { + "user": advanced_completion_prompt_template.role_prefix.user, + "assistant": advanced_completion_prompt_template.role_prefix.assistant, + } + + memory = None + if new_app_mode == AppMode.ADVANCED_CHAT: + memory = {"role_prefix": role_prefix, "window": {"enabled": False}} + + 
completion_params = model_config.parameters + completion_params.update({"stop": model_config.stop}) + return { + "id": "llm", + "position": None, + "data": { + "title": "LLM", + "type": NodeType.LLM.value, + "model": { + "provider": model_config.provider, + "name": model_config.model, + "mode": model_config.mode, + "completion_params": completion_params, + }, + "prompt_template": prompts, + "memory": memory, + "context": { + "enabled": knowledge_retrieval_node is not None, + "variable_selector": ["knowledge_retrieval", "result"] + if knowledge_retrieval_node is not None + else None, + }, + "vision": { + "enabled": file_upload is not None, + "variable_selector": ["sys", "files"] if file_upload is not None else None, + "configs": {"detail": file_upload.image_config.detail} + if file_upload is not None and file_upload.image_config is not None + else None, + }, + }, + } + + def _replace_template_variables( + self, template: str, variables: list[dict], external_data_variable_node_mapping: dict[str, str] | None = None + ) -> str: + """ + Replace Template Variables + :param template: template + :param variables: list of variables + :param external_data_variable_node_mapping: external data variable node mapping + :return: + """ + for v in variables: + template = template.replace("{{" + v["variable"] + "}}", "{{#start." 
+ v["variable"] + "#}}") + + if external_data_variable_node_mapping: + for variable, code_node_id in external_data_variable_node_mapping.items(): + template = template.replace("{{" + variable + "}}", "{{#" + code_node_id + ".result#}}") + + return template + + def _convert_to_end_node(self) -> dict: + """ + Convert to End Node + :return: + """ + # for original completion app + return { + "id": "end", + "position": None, + "data": { + "title": "END", + "type": NodeType.END.value, + "outputs": [{"variable": "result", "value_selector": ["llm", "text"]}], + }, + } + + def _convert_to_answer_node(self) -> dict: + """ + Convert to Answer Node + :return: + """ + # for original chat app + return { + "id": "answer", + "position": None, + "data": {"title": "ANSWER", "type": NodeType.ANSWER.value, "answer": "{{#llm.text#}}"}, + } + + def _create_edge(self, source: str, target: str) -> dict: + """ + Create Edge + :param source: source node id + :param target: target node id + :return: + """ + return {"id": f"{source}-{target}", "source": source, "target": target} + + def _append_node(self, graph: dict, node: dict) -> dict: + """ + Append Node to Graph + + :param graph: Graph, include: nodes, edges + :param node: Node to append + :return: + """ + previous_node = graph["nodes"][-1] + graph["nodes"].append(node) + graph["edges"].append(self._create_edge(previous_node["id"], node["id"])) + return graph + + def _get_new_app_mode(self, app_model: App) -> AppMode: + """ + Get new app mode + :param app_model: App instance + :return: AppMode + """ + if app_model.mode == AppMode.COMPLETION.value: + return AppMode.WORKFLOW + else: + return AppMode.ADVANCED_CHAT + + def _get_api_based_extension(self, tenant_id: str, api_based_extension_id: str): + """ + Get API Based Extension + :param tenant_id: tenant id + :param api_based_extension_id: api based extension id + :return: + """ + api_based_extension = ( + db.session.query(APIBasedExtension) + .filter(APIBasedExtension.tenant_id == 
tenant_id, APIBasedExtension.id == api_based_extension_id) + .first() + ) + + if not api_based_extension: + raise ValueError(f"API Based Extension not found, id: {api_based_extension_id}") + + return api_based_extension diff --git a/api/services/workflow_app_service.py b/api/services/workflow_app_service.py new file mode 100644 index 0000000000000000000000000000000000000000..7eab0ac1d8611fbf10060032f6363895724375ef --- /dev/null +++ b/api/services/workflow_app_service.py @@ -0,0 +1,67 @@ +import uuid + +from flask_sqlalchemy.pagination import Pagination +from sqlalchemy import and_, or_ + +from extensions.ext_database import db +from models import App, EndUser, WorkflowAppLog, WorkflowRun +from models.enums import CreatedByRole +from models.workflow import WorkflowRunStatus + + +class WorkflowAppService: + def get_paginate_workflow_app_logs(self, app_model: App, args: dict) -> Pagination: + """ + Get paginate workflow app logs + :param app: app model + :param args: request args + :return: + """ + query = db.select(WorkflowAppLog).where( + WorkflowAppLog.tenant_id == app_model.tenant_id, WorkflowAppLog.app_id == app_model.id + ) + + status = WorkflowRunStatus.value_of(args.get("status", "")) if args.get("status") else None + keyword = args["keyword"] + if keyword or status: + query = query.join(WorkflowRun, WorkflowRun.id == WorkflowAppLog.workflow_run_id) + + if keyword: + keyword_like_val = f"%{keyword[:30].encode('unicode_escape').decode('utf-8')}%".replace(r"\u", r"\\u") + keyword_conditions = [ + WorkflowRun.inputs.ilike(keyword_like_val), + WorkflowRun.outputs.ilike(keyword_like_val), + # filter keyword by end user session id if created by end user role + and_(WorkflowRun.created_by_role == "end_user", EndUser.session_id.ilike(keyword_like_val)), + ] + + # filter keyword by workflow run id + keyword_uuid = self._safe_parse_uuid(keyword) + if keyword_uuid: + keyword_conditions.append(WorkflowRun.id == keyword_uuid) + + query = query.outerjoin( + EndUser, + 
and_(WorkflowRun.created_by == EndUser.id, WorkflowRun.created_by_role == CreatedByRole.END_USER), + ).filter(or_(*keyword_conditions)) + + if status: + # join with workflow_run and filter by status + query = query.filter(WorkflowRun.status == status.value) + + query = query.order_by(WorkflowAppLog.created_at.desc()) + + pagination = db.paginate(query, page=args["page"], per_page=args["limit"], error_out=False) + + return pagination + + @staticmethod + def _safe_parse_uuid(value: str): + # fast check + if len(value) < 32: + return None + + try: + return uuid.UUID(value) + except ValueError: + return None diff --git a/api/services/workflow_run_service.py b/api/services/workflow_run_service.py new file mode 100644 index 0000000000000000000000000000000000000000..4343596a236f5fba194b8fb27eef2207709abbea --- /dev/null +++ b/api/services/workflow_run_service.py @@ -0,0 +1,138 @@ +from typing import Optional + +from extensions.ext_database import db +from libs.infinite_scroll_pagination import InfiniteScrollPagination +from models.enums import WorkflowRunTriggeredFrom +from models.model import App +from models.workflow import ( + WorkflowNodeExecution, + WorkflowNodeExecutionTriggeredFrom, + WorkflowRun, +) + + +class WorkflowRunService: + def get_paginate_advanced_chat_workflow_runs(self, app_model: App, args: dict) -> InfiniteScrollPagination: + """ + Get advanced chat app workflow run list + Only return triggered_from == advanced_chat + + :param app_model: app model + :param args: request args + """ + + class WorkflowWithMessage: + message_id: str + conversation_id: str + + def __init__(self, workflow_run: WorkflowRun): + self._workflow_run = workflow_run + + def __getattr__(self, item): + return getattr(self._workflow_run, item) + + pagination = self.get_paginate_workflow_runs(app_model, args) + + with_message_workflow_runs = [] + for workflow_run in pagination.data: + message = workflow_run.message + with_message_workflow_run = 
WorkflowWithMessage(workflow_run=workflow_run) + if message: + with_message_workflow_run.message_id = message.id + with_message_workflow_run.conversation_id = message.conversation_id + + with_message_workflow_runs.append(with_message_workflow_run) + + pagination.data = with_message_workflow_runs + return pagination + + def get_paginate_workflow_runs(self, app_model: App, args: dict) -> InfiniteScrollPagination: + """ + Get debug workflow run list + Only return triggered_from == debugging + + :param app_model: app model + :param args: request args + """ + limit = int(args.get("limit", 20)) + + base_query = db.session.query(WorkflowRun).filter( + WorkflowRun.tenant_id == app_model.tenant_id, + WorkflowRun.app_id == app_model.id, + WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.DEBUGGING.value, + ) + + if args.get("last_id"): + last_workflow_run = base_query.filter( + WorkflowRun.id == args.get("last_id"), + ).first() + + if not last_workflow_run: + raise ValueError("Last workflow run not exists") + + workflow_runs = ( + base_query.filter( + WorkflowRun.created_at < last_workflow_run.created_at, WorkflowRun.id != last_workflow_run.id + ) + .order_by(WorkflowRun.created_at.desc()) + .limit(limit) + .all() + ) + else: + workflow_runs = base_query.order_by(WorkflowRun.created_at.desc()).limit(limit).all() + + has_more = False + if len(workflow_runs) == limit: + current_page_first_workflow_run = workflow_runs[-1] + rest_count = base_query.filter( + WorkflowRun.created_at < current_page_first_workflow_run.created_at, + WorkflowRun.id != current_page_first_workflow_run.id, + ).count() + + if rest_count > 0: + has_more = True + + return InfiniteScrollPagination(data=workflow_runs, limit=limit, has_more=has_more) + + def get_workflow_run(self, app_model: App, run_id: str) -> Optional[WorkflowRun]: + """ + Get workflow run detail + + :param app_model: app model + :param run_id: workflow run id + """ + workflow_run = ( + db.session.query(WorkflowRun) + .filter( + 
WorkflowRun.tenant_id == app_model.tenant_id, + WorkflowRun.app_id == app_model.id, + WorkflowRun.id == run_id, + ) + .first() + ) + + return workflow_run + + def get_workflow_run_node_executions(self, app_model: App, run_id: str) -> list[WorkflowNodeExecution]: + """ + Get workflow run node execution list + """ + workflow_run = self.get_workflow_run(app_model, run_id) + + if not workflow_run: + return [] + + node_executions = ( + db.session.query(WorkflowNodeExecution) + .filter( + WorkflowNodeExecution.tenant_id == app_model.tenant_id, + WorkflowNodeExecution.app_id == app_model.id, + WorkflowNodeExecution.workflow_id == workflow_run.workflow_id, + WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, + WorkflowNodeExecution.workflow_run_id == run_id, + ) + .order_by(WorkflowNodeExecution.index.desc()) + .all() + ) + + return node_executions diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py new file mode 100644 index 0000000000000000000000000000000000000000..9f7a9c770d9306bc21dcae550bcb9503ab197e7c --- /dev/null +++ b/api/services/workflow_service.py @@ -0,0 +1,388 @@ +import json +import time +from collections.abc import Sequence +from datetime import UTC, datetime +from typing import Any, Optional, cast +from uuid import uuid4 + +from sqlalchemy import desc + +from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager +from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager +from core.model_runtime.utils.encoders import jsonable_encoder +from core.variables import Variable +from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.errors import WorkflowNodeRunFailedError +from core.workflow.nodes import NodeType +from core.workflow.nodes.base.entities import BaseNodeData +from core.workflow.nodes.base.node import BaseNode +from core.workflow.nodes.enums import ErrorStrategy +from core.workflow.nodes.event import 
RunCompletedEvent +from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING +from core.workflow.workflow_entry import WorkflowEntry +from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated +from extensions.ext_database import db +from models.account import Account +from models.enums import CreatedByRole +from models.model import App, AppMode +from models.workflow import ( + Workflow, + WorkflowNodeExecution, + WorkflowNodeExecutionStatus, + WorkflowNodeExecutionTriggeredFrom, + WorkflowType, +) +from services.errors.app import WorkflowHashNotEqualError +from services.workflow.workflow_converter import WorkflowConverter + + +class WorkflowService: + """ + Workflow Service + """ + + def get_draft_workflow(self, app_model: App) -> Optional[Workflow]: + """ + Get draft workflow + """ + # fetch draft workflow by app_model + workflow = ( + db.session.query(Workflow) + .filter( + Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.version == "draft" + ) + .first() + ) + + # return draft workflow + return workflow + + def get_published_workflow(self, app_model: App) -> Optional[Workflow]: + """ + Get published workflow + """ + + if not app_model.workflow_id: + return None + + # fetch published workflow by workflow_id + workflow = ( + db.session.query(Workflow) + .filter( + Workflow.tenant_id == app_model.tenant_id, + Workflow.app_id == app_model.id, + Workflow.id == app_model.workflow_id, + ) + .first() + ) + + return workflow + + def get_all_published_workflow(self, app_model: App, page: int, limit: int) -> tuple[list[Workflow], bool]: + """ + Get published workflow with pagination + """ + if not app_model.workflow_id: + return [], False + + workflows = ( + db.session.query(Workflow) + .filter(Workflow.app_id == app_model.id) + .order_by(desc(Workflow.version)) + .offset((page - 1) * limit) + .limit(limit + 1) + .all() + ) + + has_more = len(workflows) > limit + if 
has_more: + workflows = workflows[:-1] + + return workflows, has_more + + def sync_draft_workflow( + self, + *, + app_model: App, + graph: dict, + features: dict, + unique_hash: Optional[str], + account: Account, + environment_variables: Sequence[Variable], + conversation_variables: Sequence[Variable], + ) -> Workflow: + """ + Sync draft workflow + :raises WorkflowHashNotEqualError + """ + # fetch draft workflow by app_model + workflow = self.get_draft_workflow(app_model=app_model) + + if workflow and workflow.unique_hash != unique_hash: + raise WorkflowHashNotEqualError() + + # validate features structure + self.validate_features_structure(app_model=app_model, features=features) + + # create draft workflow if not found + if not workflow: + workflow = Workflow( + tenant_id=app_model.tenant_id, + app_id=app_model.id, + type=WorkflowType.from_app_mode(app_model.mode).value, + version="draft", + graph=json.dumps(graph), + features=json.dumps(features), + created_by=account.id, + environment_variables=environment_variables, + conversation_variables=conversation_variables, + ) + db.session.add(workflow) + # update draft workflow if found + else: + workflow.graph = json.dumps(graph) + workflow.features = json.dumps(features) + workflow.updated_by = account.id + workflow.updated_at = datetime.now(UTC).replace(tzinfo=None) + workflow.environment_variables = environment_variables + workflow.conversation_variables = conversation_variables + + # commit db session changes + db.session.commit() + + # trigger app workflow events + app_draft_workflow_was_synced.send(app_model, synced_draft_workflow=workflow) + + # return draft workflow + return workflow + + def publish_workflow(self, app_model: App, account: Account, draft_workflow: Optional[Workflow] = None) -> Workflow: + """ + Publish workflow from draft + + :param app_model: App instance + :param account: Account instance + :param draft_workflow: Workflow instance + """ + if not draft_workflow: + # fetch draft workflow by 
app_model + draft_workflow = self.get_draft_workflow(app_model=app_model) + + if not draft_workflow: + raise ValueError("No valid workflow found.") + + # create new workflow + workflow = Workflow( + tenant_id=app_model.tenant_id, + app_id=app_model.id, + type=draft_workflow.type, + version=str(datetime.now(UTC).replace(tzinfo=None)), + graph=draft_workflow.graph, + features=draft_workflow.features, + created_by=account.id, + environment_variables=draft_workflow.environment_variables, + conversation_variables=draft_workflow.conversation_variables, + ) + + # commit db session changes + db.session.add(workflow) + db.session.flush() + db.session.commit() + + app_model.workflow_id = workflow.id + db.session.commit() + + # trigger app workflow events + app_published_workflow_was_updated.send(app_model, published_workflow=workflow) + + # return new workflow + return workflow + + def get_default_block_configs(self) -> list[dict]: + """ + Get default block configs + """ + # return default block config + default_block_configs = [] + for node_class_mapping in NODE_TYPE_CLASSES_MAPPING.values(): + node_class = node_class_mapping[LATEST_VERSION] + default_config = node_class.get_default_config() + if default_config: + default_block_configs.append(default_config) + + return default_block_configs + + def get_default_block_config(self, node_type: str, filters: Optional[dict] = None) -> Optional[dict]: + """ + Get default config of node. + :param node_type: node type + :param filters: filter by node config parameters. 
+ :return: + """ + node_type_enum = NodeType(node_type) + + # return default block config + if node_type_enum not in NODE_TYPE_CLASSES_MAPPING: + return None + + node_class = NODE_TYPE_CLASSES_MAPPING[node_type_enum][LATEST_VERSION] + default_config = node_class.get_default_config(filters=filters) + if not default_config: + return None + + return default_config + + def run_draft_workflow_node( + self, app_model: App, node_id: str, user_inputs: dict, account: Account + ) -> WorkflowNodeExecution: + """ + Run draft workflow node + """ + # fetch draft workflow by app_model + draft_workflow = self.get_draft_workflow(app_model=app_model) + if not draft_workflow: + raise ValueError("Workflow not initialized") + + # run draft workflow node + start_at = time.perf_counter() + + try: + node_instance, generator = WorkflowEntry.single_step_run( + workflow=draft_workflow, + node_id=node_id, + user_inputs=user_inputs, + user_id=account.id, + ) + node_instance = cast(BaseNode[BaseNodeData], node_instance) + node_run_result: NodeRunResult | None = None + for event in generator: + if isinstance(event, RunCompletedEvent): + node_run_result = event.run_result + + # sign output files + node_run_result.outputs = WorkflowEntry.handle_special_values(node_run_result.outputs) + break + + if not node_run_result: + raise ValueError("Node run failed with no run result") + # single step debug mode error handling return + if node_run_result.status == WorkflowNodeExecutionStatus.FAILED and node_instance.should_continue_on_error: + node_error_args: dict[str, Any] = { + "status": WorkflowNodeExecutionStatus.EXCEPTION, + "error": node_run_result.error, + "inputs": node_run_result.inputs, + "metadata": {"error_strategy": node_instance.node_data.error_strategy}, + } + if node_instance.node_data.error_strategy is ErrorStrategy.DEFAULT_VALUE: + node_run_result = NodeRunResult( + **node_error_args, + outputs={ + **node_instance.node_data.default_value_dict, + "error_message": node_run_result.error, + 
"error_type": node_run_result.error_type, + }, + ) + else: + node_run_result = NodeRunResult( + **node_error_args, + outputs={ + "error_message": node_run_result.error, + "error_type": node_run_result.error_type, + }, + ) + run_succeeded = node_run_result.status in ( + WorkflowNodeExecutionStatus.SUCCEEDED, + WorkflowNodeExecutionStatus.EXCEPTION, + ) + error = node_run_result.error if not run_succeeded else None + except WorkflowNodeRunFailedError as e: + node_instance = e.node_instance + run_succeeded = False + node_run_result = None + error = e.error + + workflow_node_execution = WorkflowNodeExecution() + workflow_node_execution.id = str(uuid4()) + workflow_node_execution.tenant_id = app_model.tenant_id + workflow_node_execution.app_id = app_model.id + workflow_node_execution.workflow_id = draft_workflow.id + workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP.value + workflow_node_execution.index = 1 + workflow_node_execution.node_id = node_id + workflow_node_execution.node_type = node_instance.node_type + workflow_node_execution.title = node_instance.node_data.title + workflow_node_execution.elapsed_time = time.perf_counter() - start_at + workflow_node_execution.created_by_role = CreatedByRole.ACCOUNT.value + workflow_node_execution.created_by = account.id + workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None) + workflow_node_execution.finished_at = datetime.now(UTC).replace(tzinfo=None) + if run_succeeded and node_run_result: + # create workflow node execution + inputs = WorkflowEntry.handle_special_values(node_run_result.inputs) if node_run_result.inputs else None + process_data = ( + WorkflowEntry.handle_special_values(node_run_result.process_data) + if node_run_result.process_data + else None + ) + outputs = WorkflowEntry.handle_special_values(node_run_result.outputs) if node_run_result.outputs else None + + workflow_node_execution.inputs = json.dumps(inputs) + workflow_node_execution.process_data 
= json.dumps(process_data) + workflow_node_execution.outputs = json.dumps(outputs) + workflow_node_execution.execution_metadata = ( + json.dumps(jsonable_encoder(node_run_result.metadata)) if node_run_result.metadata else None + ) + if node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED: + workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value + elif node_run_result.status == WorkflowNodeExecutionStatus.EXCEPTION: + workflow_node_execution.status = WorkflowNodeExecutionStatus.EXCEPTION.value + workflow_node_execution.error = node_run_result.error + else: + # create workflow node execution + workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value + workflow_node_execution.error = error + + db.session.add(workflow_node_execution) + db.session.commit() + + return workflow_node_execution + + def convert_to_workflow(self, app_model: App, account: Account, args: dict) -> App: + """ + Basic mode of chatbot app(expert mode) to workflow + Completion App to Workflow App + + :param app_model: App instance + :param account: Account instance + :param args: dict + :return: + """ + # chatbot convert to workflow mode + workflow_converter = WorkflowConverter() + + if app_model.mode not in {AppMode.CHAT.value, AppMode.COMPLETION.value}: + raise ValueError(f"Current App mode: {app_model.mode} is not supported convert to workflow.") + + # convert to workflow + new_app: App = workflow_converter.convert_to_workflow( + app_model=app_model, + account=account, + name=args.get("name", "Default Name"), + icon_type=args.get("icon_type", "emoji"), + icon=args.get("icon", "🤖"), + icon_background=args.get("icon_background", "#FFEAD5"), + ) + + return new_app + + def validate_features_structure(self, app_model: App, features: dict) -> dict: + if app_model.mode == AppMode.ADVANCED_CHAT.value: + return AdvancedChatAppConfigManager.config_validate( + tenant_id=app_model.tenant_id, config=features, only_structure_validate=True + ) + elif 
app_model.mode == AppMode.WORKFLOW.value: + return WorkflowAppConfigManager.config_validate( + tenant_id=app_model.tenant_id, config=features, only_structure_validate=True + ) + else: + raise ValueError(f"Invalid app mode: {app_model.mode}") diff --git a/api/services/workspace_service.py b/api/services/workspace_service.py new file mode 100644 index 0000000000000000000000000000000000000000..7637b31454e556c3afe9ef788af71457310f86ff --- /dev/null +++ b/api/services/workspace_service.py @@ -0,0 +1,53 @@ +from flask_login import current_user # type: ignore + +from configs import dify_config +from extensions.ext_database import db +from models.account import Tenant, TenantAccountJoin, TenantAccountJoinRole +from services.account_service import TenantService +from services.feature_service import FeatureService + + +class WorkspaceService: + @classmethod + def get_tenant_info(cls, tenant: Tenant): + if not tenant: + return None + tenant_info = { + "id": tenant.id, + "name": tenant.name, + "plan": tenant.plan, + "status": tenant.status, + "created_at": tenant.created_at, + "in_trail": True, + "trial_end_reason": None, + "role": "normal", + } + + # Get role of user + tenant_account_join = ( + db.session.query(TenantAccountJoin) + .filter(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == current_user.id) + .first() + ) + assert tenant_account_join is not None, "TenantAccountJoin not found" + tenant_info["role"] = tenant_account_join.role + + can_replace_logo = FeatureService.get_features(tenant_info["id"]).can_replace_logo + + if can_replace_logo and TenantService.has_roles( + tenant, [TenantAccountJoinRole.OWNER, TenantAccountJoinRole.ADMIN] + ): + base_url = dify_config.FILES_URL + replace_webapp_logo = ( + f"{base_url}/files/workspaces/{tenant.id}/webapp-logo" + if tenant.custom_config_dict.get("replace_webapp_logo") + else None + ) + remove_webapp_brand = tenant.custom_config_dict.get("remove_webapp_brand", False) + + tenant_info["custom_config"] 
= { + "remove_webapp_brand": remove_webapp_brand, + "replace_webapp_logo": replace_webapp_logo, + } + + return tenant_info diff --git a/api/tasks/__init__.py b/api/tasks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/tasks/add_document_to_index_task.py b/api/tasks/add_document_to_index_task.py new file mode 100644 index 0000000000000000000000000000000000000000..bd7fcdadeaa374ac72d5799b7de96334a269df9f --- /dev/null +++ b/api/tasks/add_document_to_index_task.py @@ -0,0 +1,118 @@ +import datetime +import logging +import time + +import click +from celery import shared_task # type: ignore +from werkzeug.exceptions import NotFound + +from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.index_processor_factory import IndexProcessorFactory +from core.rag.models.document import ChildDocument, Document +from extensions.ext_database import db +from extensions.ext_redis import redis_client +from models.dataset import DatasetAutoDisableLog, DocumentSegment +from models.dataset import Document as DatasetDocument + + +@shared_task(queue="dataset") +def add_document_to_index_task(dataset_document_id: str): + """ + Async Add document to index + :param dataset_document_id: + + Usage: add_document_to_index.delay(dataset_document_id) + """ + logging.info(click.style("Start add document to index: {}".format(dataset_document_id), fg="green")) + start_at = time.perf_counter() + + dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document_id).first() + if not dataset_document: + raise NotFound("Document not found") + + if dataset_document.indexing_status != "completed": + return + + indexing_cache_key = "document_{}_indexing".format(dataset_document.id) + + try: + segments = ( + db.session.query(DocumentSegment) + .filter( + DocumentSegment.document_id == dataset_document.id, + DocumentSegment.enabled == False, + 
DocumentSegment.status == "completed", + ) + .order_by(DocumentSegment.position.asc()) + .all() + ) + + documents = [] + for segment in segments: + document = Document( + page_content=segment.content, + metadata={ + "doc_id": segment.index_node_id, + "doc_hash": segment.index_node_hash, + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + }, + ) + if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: + child_chunks = segment.child_chunks + if child_chunks: + child_documents = [] + for child_chunk in child_chunks: + child_document = ChildDocument( + page_content=child_chunk.content, + metadata={ + "doc_id": child_chunk.index_node_id, + "doc_hash": child_chunk.index_node_hash, + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + }, + ) + child_documents.append(child_document) + document.children = child_documents + documents.append(document) + + dataset = dataset_document.dataset + + if not dataset: + raise Exception("Document has no dataset") + + index_type = dataset.doc_form + index_processor = IndexProcessorFactory(index_type).init_index_processor() + index_processor.load(dataset, documents) + + # delete auto disable log + db.session.query(DatasetAutoDisableLog).filter( + DatasetAutoDisableLog.document_id == dataset_document.id + ).delete() + + # update segment to enable + db.session.query(DocumentSegment).filter(DocumentSegment.document_id == dataset_document.id).update( + { + DocumentSegment.enabled: True, + DocumentSegment.disabled_at: None, + DocumentSegment.disabled_by: None, + DocumentSegment.updated_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None), + } + ) + db.session.commit() + + end_at = time.perf_counter() + logging.info( + click.style( + "Document added to index: {} latency: {}".format(dataset_document.id, end_at - start_at), fg="green" + ) + ) + except Exception as e: + logging.exception("add document to index failed") + dataset_document.enabled = False + dataset_document.disabled_at = 
datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + dataset_document.status = "error" + dataset_document.error = str(e) + db.session.commit() + finally: + redis_client.delete(indexing_cache_key) diff --git a/api/tasks/annotation/add_annotation_to_index_task.py b/api/tasks/annotation/add_annotation_to_index_task.py new file mode 100644 index 0000000000000000000000000000000000000000..aab21a44109975c55aecac2a21960cb535c79223 --- /dev/null +++ b/api/tasks/annotation/add_annotation_to_index_task.py @@ -0,0 +1,57 @@ +import logging +import time + +import click +from celery import shared_task # type: ignore + +from core.rag.datasource.vdb.vector_factory import Vector +from core.rag.models.document import Document +from models.dataset import Dataset +from services.dataset_service import DatasetCollectionBindingService + + +@shared_task(queue="dataset") +def add_annotation_to_index_task( + annotation_id: str, question: str, tenant_id: str, app_id: str, collection_binding_id: str +): + """ + Add annotation to index. 
+ :param annotation_id: annotation id + :param question: question + :param tenant_id: tenant id + :param app_id: app id + :param collection_binding_id: embedding binding id + + Usage: clean_dataset_task.delay(dataset_id, tenant_id, indexing_technique, index_struct) + """ + logging.info(click.style("Start build index for annotation: {}".format(annotation_id), fg="green")) + start_at = time.perf_counter() + + try: + dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type( + collection_binding_id, "annotation" + ) + dataset = Dataset( + id=app_id, + tenant_id=tenant_id, + indexing_technique="high_quality", + embedding_model_provider=dataset_collection_binding.provider_name, + embedding_model=dataset_collection_binding.model_name, + collection_binding_id=dataset_collection_binding.id, + ) + + document = Document( + page_content=question, metadata={"annotation_id": annotation_id, "app_id": app_id, "doc_id": annotation_id} + ) + vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"]) + vector.create([document], duplicate_check=True) + + end_at = time.perf_counter() + logging.info( + click.style( + "Build index successful for annotation: {} latency: {}".format(annotation_id, end_at - start_at), + fg="green", + ) + ) + except Exception: + logging.exception("Build index for annotation failed") diff --git a/api/tasks/annotation/batch_import_annotations_task.py b/api/tasks/annotation/batch_import_annotations_task.py new file mode 100644 index 0000000000000000000000000000000000000000..06162b02d60f8b7c1c88963bc41632bcacdf7747 --- /dev/null +++ b/api/tasks/annotation/batch_import_annotations_task.py @@ -0,0 +1,90 @@ +import logging +import time + +import click +from celery import shared_task # type: ignore +from werkzeug.exceptions import NotFound + +from core.rag.datasource.vdb.vector_factory import Vector +from core.rag.models.document import Document +from extensions.ext_database import db +from 
extensions.ext_redis import redis_client +from models.dataset import Dataset +from models.model import App, AppAnnotationSetting, MessageAnnotation +from services.dataset_service import DatasetCollectionBindingService + + +@shared_task(queue="dataset") +def batch_import_annotations_task(job_id: str, content_list: list[dict], app_id: str, tenant_id: str, user_id: str): + """ + Add annotation to index. + :param job_id: job_id + :param content_list: content list + :param app_id: app id + :param tenant_id: tenant id + :param user_id: user_id + + """ + logging.info(click.style("Start batch import annotation: {}".format(job_id), fg="green")) + start_at = time.perf_counter() + indexing_cache_key = "app_annotation_batch_import_{}".format(str(job_id)) + # get app info + app = db.session.query(App).filter(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first() + + if app: + try: + documents = [] + for content in content_list: + annotation = MessageAnnotation( + app_id=app.id, content=content["answer"], question=content["question"], account_id=user_id + ) + db.session.add(annotation) + db.session.flush() + + document = Document( + page_content=content["question"], + metadata={"annotation_id": annotation.id, "app_id": app_id, "doc_id": annotation.id}, + ) + documents.append(document) + # if annotation reply is enabled , batch add annotations' index + app_annotation_setting = ( + db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first() + ) + + if app_annotation_setting: + dataset_collection_binding = ( + DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type( + app_annotation_setting.collection_binding_id, "annotation" + ) + ) + if not dataset_collection_binding: + raise NotFound("App annotation setting not found") + dataset = Dataset( + id=app_id, + tenant_id=tenant_id, + indexing_technique="high_quality", + embedding_model_provider=dataset_collection_binding.provider_name, + 
embedding_model=dataset_collection_binding.model_name, + collection_binding_id=dataset_collection_binding.id, + ) + + vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"]) + vector.create(documents, duplicate_check=True) + + db.session.commit() + redis_client.setex(indexing_cache_key, 600, "completed") + end_at = time.perf_counter() + logging.info( + click.style( + "Build index successful for batch import annotation: {} latency: {}".format( + job_id, end_at - start_at + ), + fg="green", + ) + ) + except Exception as e: + db.session.rollback() + redis_client.setex(indexing_cache_key, 600, "error") + indexing_error_msg_key = "app_annotation_batch_import_error_msg_{}".format(str(job_id)) + redis_client.setex(indexing_error_msg_key, 600, str(e)) + logging.exception("Build index for batch import annotations failed") diff --git a/api/tasks/annotation/delete_annotation_index_task.py b/api/tasks/annotation/delete_annotation_index_task.py new file mode 100644 index 0000000000000000000000000000000000000000..a6a598ce4b6bcaebfc61c27f6ae65afe42c9af3a --- /dev/null +++ b/api/tasks/annotation/delete_annotation_index_task.py @@ -0,0 +1,41 @@ +import logging +import time + +import click +from celery import shared_task # type: ignore + +from core.rag.datasource.vdb.vector_factory import Vector +from models.dataset import Dataset +from services.dataset_service import DatasetCollectionBindingService + + +@shared_task(queue="dataset") +def delete_annotation_index_task(annotation_id: str, app_id: str, tenant_id: str, collection_binding_id: str): + """ + Async delete annotation index task + """ + logging.info(click.style("Start delete app annotation index: {}".format(app_id), fg="green")) + start_at = time.perf_counter() + try: + dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type( + collection_binding_id, "annotation" + ) + + dataset = Dataset( + id=app_id, + tenant_id=tenant_id, + 
indexing_technique="high_quality", + collection_binding_id=dataset_collection_binding.id, + ) + + try: + vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"]) + vector.delete_by_metadata_field("annotation_id", annotation_id) + except Exception: + logging.exception("Delete annotation index failed when annotation deleted.") + end_at = time.perf_counter() + logging.info( + click.style("App annotations index deleted : {} latency: {}".format(app_id, end_at - start_at), fg="green") + ) + except Exception as e: + logging.exception("Annotation deleted index failed") diff --git a/api/tasks/annotation/disable_annotation_reply_task.py b/api/tasks/annotation/disable_annotation_reply_task.py new file mode 100644 index 0000000000000000000000000000000000000000..26bf1c7c9fa32e7623669a05a1f297aa47174a51 --- /dev/null +++ b/api/tasks/annotation/disable_annotation_reply_task.py @@ -0,0 +1,68 @@ +import logging +import time + +import click +from celery import shared_task # type: ignore +from werkzeug.exceptions import NotFound + +from core.rag.datasource.vdb.vector_factory import Vector +from extensions.ext_database import db +from extensions.ext_redis import redis_client +from models.dataset import Dataset +from models.model import App, AppAnnotationSetting, MessageAnnotation + + +@shared_task(queue="dataset") +def disable_annotation_reply_task(job_id: str, app_id: str, tenant_id: str): + """ + Async enable annotation reply task + """ + logging.info(click.style("Start delete app annotations index: {}".format(app_id), fg="green")) + start_at = time.perf_counter() + # get app info + app = db.session.query(App).filter(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first() + annotations_count = db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app_id).count() + if not app: + raise NotFound("App not found") + + app_annotation_setting = ( + db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == 
app_id).first() + ) + + if not app_annotation_setting: + raise NotFound("App annotation setting not found") + + disable_app_annotation_key = "disable_app_annotation_{}".format(str(app_id)) + disable_app_annotation_job_key = "disable_app_annotation_job_{}".format(str(job_id)) + + try: + dataset = Dataset( + id=app_id, + tenant_id=tenant_id, + indexing_technique="high_quality", + collection_binding_id=app_annotation_setting.collection_binding_id, + ) + + try: + if annotations_count > 0: + vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"]) + vector.delete_by_metadata_field("app_id", app_id) + except Exception: + logging.exception("Delete annotation index failed when annotation deleted.") + redis_client.setex(disable_app_annotation_job_key, 600, "completed") + + # delete annotation setting + db.session.delete(app_annotation_setting) + db.session.commit() + + end_at = time.perf_counter() + logging.info( + click.style("App annotations index deleted : {} latency: {}".format(app_id, end_at - start_at), fg="green") + ) + except Exception as e: + logging.exception("Annotation batch deleted index failed") + redis_client.setex(disable_app_annotation_job_key, 600, "error") + disable_app_annotation_error_key = "disable_app_annotation_error_{}".format(str(job_id)) + redis_client.setex(disable_app_annotation_error_key, 600, str(e)) + finally: + redis_client.delete(disable_app_annotation_key) diff --git a/api/tasks/annotation/enable_annotation_reply_task.py b/api/tasks/annotation/enable_annotation_reply_task.py new file mode 100644 index 0000000000000000000000000000000000000000..b42af0c7faf67e623d878590b010dd6d41681212 --- /dev/null +++ b/api/tasks/annotation/enable_annotation_reply_task.py @@ -0,0 +1,102 @@ +import datetime +import logging +import time + +import click +from celery import shared_task # type: ignore +from werkzeug.exceptions import NotFound + +from core.rag.datasource.vdb.vector_factory import Vector +from core.rag.models.document import 
Document +from extensions.ext_database import db +from extensions.ext_redis import redis_client +from models.dataset import Dataset +from models.model import App, AppAnnotationSetting, MessageAnnotation +from services.dataset_service import DatasetCollectionBindingService + + +@shared_task(queue="dataset") +def enable_annotation_reply_task( + job_id: str, + app_id: str, + user_id: str, + tenant_id: str, + score_threshold: float, + embedding_provider_name: str, + embedding_model_name: str, +): + """ + Async enable annotation reply task + """ + logging.info(click.style("Start add app annotation to index: {}".format(app_id), fg="green")) + start_at = time.perf_counter() + # get app info + app = db.session.query(App).filter(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first() + + if not app: + raise NotFound("App not found") + + annotations = db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app_id).all() + enable_app_annotation_key = "enable_app_annotation_{}".format(str(app_id)) + enable_app_annotation_job_key = "enable_app_annotation_job_{}".format(str(job_id)) + + try: + documents = [] + dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( + embedding_provider_name, embedding_model_name, "annotation" + ) + annotation_setting = ( + db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first() + ) + if annotation_setting: + annotation_setting.score_threshold = score_threshold + annotation_setting.collection_binding_id = dataset_collection_binding.id + annotation_setting.updated_user_id = user_id + annotation_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + db.session.add(annotation_setting) + else: + new_app_annotation_setting = AppAnnotationSetting( + app_id=app_id, + score_threshold=score_threshold, + collection_binding_id=dataset_collection_binding.id, + created_user_id=user_id, + updated_user_id=user_id, + ) + 
db.session.add(new_app_annotation_setting) + + dataset = Dataset( + id=app_id, + tenant_id=tenant_id, + indexing_technique="high_quality", + embedding_model_provider=embedding_provider_name, + embedding_model=embedding_model_name, + collection_binding_id=dataset_collection_binding.id, + ) + if annotations: + for annotation in annotations: + document = Document( + page_content=annotation.question, + metadata={"annotation_id": annotation.id, "app_id": app_id, "doc_id": annotation.id}, + ) + documents.append(document) + + vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"]) + try: + vector.delete_by_metadata_field("app_id", app_id) + except Exception as e: + logging.info(click.style("Delete annotation index error: {}".format(str(e)), fg="red")) + vector.create(documents) + db.session.commit() + redis_client.setex(enable_app_annotation_job_key, 600, "completed") + end_at = time.perf_counter() + logging.info( + click.style("App annotations added to index: {} latency: {}".format(app_id, end_at - start_at), fg="green") + ) + except Exception as e: + logging.exception("Annotation batch created index failed") + redis_client.setex(enable_app_annotation_job_key, 600, "error") + enable_app_annotation_error_key = "enable_app_annotation_error_{}".format(str(job_id)) + redis_client.setex(enable_app_annotation_error_key, 600, str(e)) + db.session.rollback() + finally: + redis_client.delete(enable_app_annotation_key) diff --git a/api/tasks/annotation/update_annotation_to_index_task.py b/api/tasks/annotation/update_annotation_to_index_task.py new file mode 100644 index 0000000000000000000000000000000000000000..8c675feaa6e06f80e73b23521f055de020178d32 --- /dev/null +++ b/api/tasks/annotation/update_annotation_to_index_task.py @@ -0,0 +1,58 @@ +import logging +import time + +import click +from celery import shared_task # type: ignore + +from core.rag.datasource.vdb.vector_factory import Vector +from core.rag.models.document import Document +from models.dataset 
@shared_task(queue="dataset")
def update_annotation_to_index_task(
    annotation_id: str, question: str, tenant_id: str, app_id: str, collection_binding_id: str
):
    """
    Replace the vector-index entry of one annotation with its current question.

    :param annotation_id: annotation id
    :param question: question text to index
    :param tenant_id: tenant id
    :param app_id: app id (also used as the id of the in-memory dataset)
    :param collection_binding_id: embedding binding id

    Failures are logged and not re-raised.

    Usage: update_annotation_to_index_task.delay(annotation_id, question, tenant_id, app_id, collection_binding_id)
    """
    logging.info(click.style(f"Start update index for annotation: {annotation_id}", fg="green"))
    began = time.perf_counter()

    try:
        binding = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
            collection_binding_id, "annotation"
        )

        # Dataset object built only to address the vector store; it is not added to a session here.
        annotation_dataset = Dataset(
            id=app_id,
            tenant_id=tenant_id,
            indexing_technique="high_quality",
            embedding_model_provider=binding.provider_name,
            embedding_model=binding.model_name,
            collection_binding_id=binding.id,
        )

        metadata = {"annotation_id": annotation_id, "app_id": app_id, "doc_id": annotation_id}
        annotation_document = Document(page_content=question, metadata=metadata)

        index = Vector(annotation_dataset, attributes=["doc_id", "annotation_id", "app_id"])
        # Drop the stale entry first, then insert the fresh text.
        index.delete_by_metadata_field("annotation_id", annotation_id)
        index.add_texts([annotation_document])

        finished = time.perf_counter()
        success_message = f"Build index successful for annotation: {annotation_id} latency: {finished - began}"
        logging.info(click.style(success_message, fg="green"))
    except Exception:
        logging.exception("Build index for annotation failed")
@shared_task(queue="dataset")
def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form: str, file_ids: list[str]):
    """
    Clean index entries, segments, inline images and upload files of deleted documents.

    :param document_ids: document ids
    :param dataset_id: dataset id
    :param doc_form: doc_form used to pick the index processor
    :param file_ids: ids of the UploadFile rows backing the documents

    Failures are logged and not re-raised.

    Usage: batch_clean_document_task.delay(document_ids, dataset_id, doc_form, file_ids)
    """
    logging.info(click.style("Start batch clean documents when documents deleted", fg="green"))
    start_at = time.perf_counter()

    try:
        dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()

        if not dataset:
            raise Exception("Document has no dataset")

        segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id.in_(document_ids)).all()
        # check segment is exist
        if segments:
            index_node_ids = [segment.index_node_id for segment in segments]
            index_processor = IndexProcessorFactory(doc_form).init_index_processor()
            index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)

            for segment in segments:
                image_upload_file_ids = get_image_upload_file_ids(segment.content)
                for upload_file_id in image_upload_file_ids:
                    image_file = db.session.query(UploadFile).filter(UploadFile.id == upload_file_id).first()
                    if image_file is None:
                        # Referenced image row is already gone; calling session.delete(None)
                        # would raise and abort the whole clean-up before the commit.
                        continue
                    try:
                        if image_file.key:
                            storage.delete(image_file.key)
                    except Exception:
                        # Storage removal is best effort; the DB row is still deleted below.
                        logging.exception(
                            "Delete image_files failed when storage deleted, image_upload_file_is: {}".format(
                                upload_file_id
                            )
                        )
                    db.session.delete(image_file)
                db.session.delete(segment)

            db.session.commit()
        if file_ids:
            files = db.session.query(UploadFile).filter(UploadFile.id.in_(file_ids)).all()
            for file in files:
                try:
                    storage.delete(file.key)
                except Exception:
                    logging.exception("Delete file failed when document deleted, file_id: {}".format(file.id))
                db.session.delete(file)
            db.session.commit()

        end_at = time.perf_counter()
        logging.info(
            click.style(
                "Cleaned documents when documents deleted latency: {}".format(end_at - start_at),
                fg="green",
            )
        )
    except Exception:
        logging.exception("Cleaned documents when documents deleted failed")
indexing_cache_key = "segment_batch_import_{}".format(job_id) + + try: + with Session(db.engine) as session: + dataset = session.get(Dataset, dataset_id) + if not dataset: + raise ValueError("Dataset not exist.") + + dataset_document = session.get(Document, document_id) + if not dataset_document: + raise ValueError("Document not exist.") + + if ( + not dataset_document.enabled + or dataset_document.archived + or dataset_document.indexing_status != "completed" + ): + raise ValueError("Document is not available.") + document_segments = [] + embedding_model = None + if dataset.indexing_technique == "high_quality": + model_manager = ModelManager() + embedding_model = model_manager.get_model_instance( + tenant_id=dataset.tenant_id, + provider=dataset.embedding_model_provider, + model_type=ModelType.TEXT_EMBEDDING, + model=dataset.embedding_model, + ) + word_count_change = 0 + segments_to_insert: list[str] = [] + max_position_stmt = select(func.max(DocumentSegment.position)).where( + DocumentSegment.document_id == dataset_document.id + ) + max_position = session.scalar(max_position_stmt) or 1 + for segment in content: + content_str = segment["content"] + doc_id = str(uuid.uuid4()) + segment_hash = helper.generate_text_hash(content_str) + # calc embedding use tokens + tokens = embedding_model.get_text_embedding_num_tokens(texts=[content_str]) if embedding_model else 0 + segment_document = DocumentSegment( + tenant_id=tenant_id, + dataset_id=dataset_id, + document_id=document_id, + index_node_id=doc_id, + index_node_hash=segment_hash, + position=max_position, + content=content_str, + word_count=len(content_str), + tokens=tokens, + created_by=user_id, + indexing_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None), + status="completed", + completed_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None), + ) + max_position += 1 + if dataset_document.doc_form == "qa_model": + segment_document.answer = segment["answer"] + segment_document.word_count += 
@shared_task(queue="dataset")
def clean_dataset_task(
    dataset_id: str,
    tenant_id: str,
    indexing_technique: str,
    index_struct: str,
    collection_binding_id: str,
    doc_form: str,
):
    """
    Clean dataset when dataset deleted.

    Removes the dataset's index, documents, segments, inline image files,
    process rules, queries, app joins and uploaded source files.

    :param dataset_id: dataset id
    :param tenant_id: tenant id
    :param indexing_technique: indexing technique
    :param index_struct: index struct dict
    :param collection_binding_id: collection binding id
    :param doc_form: dataset form

    Failures are logged and not re-raised.

    Usage: clean_dataset_task.delay(dataset_id, tenant_id, indexing_technique, index_struct,
           collection_binding_id, doc_form)
    """
    logging.info(click.style("Start clean dataset when dataset deleted: {}".format(dataset_id), fg="green"))
    start_at = time.perf_counter()

    try:
        # The dataset is rebuilt from the task arguments rather than queried.
        dataset = Dataset(
            id=dataset_id,
            tenant_id=tenant_id,
            indexing_technique=indexing_technique,
            index_struct=index_struct,
            collection_binding_id=collection_binding_id,
        )
        documents = db.session.query(Document).filter(Document.dataset_id == dataset_id).all()
        segments = db.session.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset_id).all()

        if not documents:
            logging.info(click.style("No documents found for dataset: {}".format(dataset_id), fg="green"))
        else:
            logging.info(click.style("Cleaning documents for dataset: {}".format(dataset_id), fg="green"))
            # Specify the index type before initializing the index processor
            if doc_form is None:
                raise ValueError("Index type must be specified.")
            index_processor = IndexProcessorFactory(doc_form).init_index_processor()
            index_processor.clean(dataset, None, with_keywords=True, delete_child_chunks=True)

            for document in documents:
                db.session.delete(document)

            for segment in segments:
                image_upload_file_ids = get_image_upload_file_ids(segment.content)
                for upload_file_id in image_upload_file_ids:
                    image_file = db.session.query(UploadFile).filter(UploadFile.id == upload_file_id).first()
                    if image_file is None:
                        continue
                    try:
                        storage.delete(image_file.key)
                    except Exception:
                        # Storage removal is best effort; the DB row is still deleted below.
                        logging.exception(
                            "Delete image_files failed when storage deleted, image_upload_file_is: {}".format(
                                upload_file_id
                            )
                        )
                    db.session.delete(image_file)
                db.session.delete(segment)

        db.session.query(DatasetProcessRule).filter(DatasetProcessRule.dataset_id == dataset_id).delete()
        db.session.query(DatasetQuery).filter(DatasetQuery.dataset_id == dataset_id).delete()
        db.session.query(AppDatasetJoin).filter(AppDatasetJoin.dataset_id == dataset_id).delete()

        # delete files
        for document in documents or []:
            try:
                if document.data_source_type != "upload_file" or not document.data_source_info:
                    continue
                data_source_info = document.data_source_info_dict
                if not data_source_info or "upload_file_id" not in data_source_info:
                    continue
                file_id = data_source_info["upload_file_id"]
                file = (
                    db.session.query(UploadFile)
                    .filter(UploadFile.tenant_id == document.tenant_id, UploadFile.id == file_id)
                    .first()
                )
                if not file:
                    continue
                storage.delete(file.key)
                db.session.delete(file)
            except Exception:
                # Still best effort per document, but no longer silent.
                logging.exception("Delete upload file failed when dataset deleted, document_id: {}".format(document.id))

        db.session.commit()
        end_at = time.perf_counter()
        logging.info(
            click.style(
                "Cleaned dataset when dataset deleted: {} latency: {}".format(dataset_id, end_at - start_at), fg="green"
            )
        )
    except Exception:
        logging.exception("Cleaned dataset when dataset deleted failed")
when document deleted. + :param document_id: document id + :param dataset_id: dataset id + :param doc_form: doc_form + :param file_id: file id + + Usage: clean_document_task.delay(document_id, dataset_id) + """ + logging.info(click.style("Start clean document when document deleted: {}".format(document_id), fg="green")) + start_at = time.perf_counter() + + try: + dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() + + if not dataset: + raise Exception("Document has no dataset") + + segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all() + # check segment is exist + if segments: + index_node_ids = [segment.index_node_id for segment in segments] + index_processor = IndexProcessorFactory(doc_form).init_index_processor() + index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) + + for segment in segments: + image_upload_file_ids = get_image_upload_file_ids(segment.content) + for upload_file_id in image_upload_file_ids: + image_file = db.session.query(UploadFile).filter(UploadFile.id == upload_file_id).first() + if image_file is None: + continue + try: + storage.delete(image_file.key) + except Exception: + logging.exception( + "Delete image_files failed when storage deleted, \ + image_upload_file_is: {}".format(upload_file_id) + ) + db.session.delete(image_file) + db.session.delete(segment) + + db.session.commit() + if file_id: + file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first() + if file: + try: + storage.delete(file.key) + except Exception: + logging.exception("Delete file failed when document deleted, file_id: {}".format(file_id)) + db.session.delete(file) + db.session.commit() + + end_at = time.perf_counter() + logging.info( + click.style( + "Cleaned document when document deleted: {} latency: {}".format(document_id, end_at - start_at), + fg="green", + ) + ) + except Exception: + logging.exception("Cleaned document when document 
@shared_task(queue="dataset")
def clean_notion_document_task(document_ids: list[str], dataset_id: str):
    """
    Clean documents, their segments and index entries when Notion-imported documents are deleted.

    :param document_ids: document ids
    :param dataset_id: dataset id

    Failures are logged and not re-raised.

    Usage: clean_notion_document_task.delay(document_ids, dataset_id)
    """
    logging.info(
        click.style("Start clean document when import form notion document deleted: {}".format(dataset_id), fg="green")
    )
    start_at = time.perf_counter()

    try:
        dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()

        if not dataset:
            raise Exception("Document has no dataset")
        index_type = dataset.doc_form
        index_processor = IndexProcessorFactory(index_type).init_index_processor()
        for document_id in document_ids:
            document = db.session.query(Document).filter(Document.id == document_id).first()
            if document is not None:
                # session.delete(None) raises, which used to abort the whole batch
                # when one id no longer existed.
                db.session.delete(document)

            segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
            if not segments:
                # Same guard the other clean tasks use before calling clean().
                # NOTE(review): clean() may treat an empty id list like "no ids given" — confirm.
                continue
            index_node_ids = [segment.index_node_id for segment in segments]

            index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)

            for segment in segments:
                db.session.delete(segment)
        db.session.commit()
        end_at = time.perf_counter()
        logging.info(
            click.style(
                "Clean document when import form notion document deleted end :: {} latency: {}".format(
                    dataset_id, end_at - start_at
                ),
                fg="green",
            )
        )
    except Exception:
        logging.exception("Cleaned document when import form notion document deleted failed")
"dataset_id": segment.dataset_id, + }, + ) + + dataset = segment.dataset + + if not dataset: + logging.info(click.style("Segment {} has no dataset, pass.".format(segment.id), fg="cyan")) + return + + dataset_document = segment.document + + if not dataset_document: + logging.info(click.style("Segment {} has no document, pass.".format(segment.id), fg="cyan")) + return + + if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed": + logging.info(click.style("Segment {} document status is invalid, pass.".format(segment.id), fg="cyan")) + return + + index_type = dataset.doc_form + index_processor = IndexProcessorFactory(index_type).init_index_processor() + index_processor.load(dataset, [document]) + + # update segment to completed + update_params = { + DocumentSegment.status: "completed", + DocumentSegment.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None), + } + DocumentSegment.query.filter_by(id=segment.id).update(update_params) + db.session.commit() + + end_at = time.perf_counter() + logging.info( + click.style("Segment created to index: {} latency: {}".format(segment.id, end_at - start_at), fg="green") + ) + except Exception as e: + logging.exception("create segment to index failed") + segment.enabled = False + segment.disabled_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + segment.status = "error" + segment.error = str(e) + db.session.commit() + finally: + redis_client.delete(indexing_cache_key) diff --git a/api/tasks/deal_dataset_vector_index_task.py b/api/tasks/deal_dataset_vector_index_task.py new file mode 100644 index 0000000000000000000000000000000000000000..a9b5ab91a8934caca1f4b6b6457ee1779045c36c --- /dev/null +++ b/api/tasks/deal_dataset_vector_index_task.py @@ -0,0 +1,169 @@ +import logging +import time + +import click +from celery import shared_task # type: ignore + +from core.rag.index_processor.constant.index_type import IndexType +from 
@shared_task(queue="dataset")
def deal_dataset_vector_index_task(dataset_id: str, action: str):
    """
    Async remove, build or rebuild the vector index of a dataset.

    :param dataset_id: dataset_id
    :param action: "remove", "add" or "update"; any other value only logs the latency

    Per-document failures are stored on the document ("error" status plus message);
    any other failure is logged. Nothing is re-raised.

    Usage: deal_dataset_vector_index_task.delay(dataset_id, action)
    """
    logging.info(click.style("Start deal dataset vector index: {}".format(dataset_id), fg="green"))
    start_at = time.perf_counter()

    try:
        dataset = Dataset.query.filter_by(id=dataset_id).first()

        if not dataset:
            raise Exception("Dataset not found")
        # Datasets without a doc_form fall back to the paragraph index.
        index_type = dataset.doc_form or IndexType.PARAGRAPH_INDEX
        index_processor = IndexProcessorFactory(index_type).init_index_processor()
        if action == "remove":
            # No node ids are passed — presumably this clears the whole dataset index; confirm in clean().
            index_processor.clean(dataset, None, with_keywords=False)
        elif action == "add":
            # Only documents that are completed, enabled and not archived are indexed.
            dataset_documents = (
                db.session.query(DatasetDocument)
                .filter(
                    DatasetDocument.dataset_id == dataset_id,
                    DatasetDocument.indexing_status == "completed",
                    DatasetDocument.enabled == True,
                    DatasetDocument.archived == False,
                )
                .all()
            )

            if dataset_documents:
                # Flag every selected document as "indexing" in one statement and commit
                # before the per-document work starts.
                dataset_documents_ids = [doc.id for doc in dataset_documents]
                db.session.query(DatasetDocument).filter(DatasetDocument.id.in_(dataset_documents_ids)).update(
                    {"indexing_status": "indexing"}, synchronize_session=False
                )
                db.session.commit()

                for dataset_document in dataset_documents:
                    try:
                        # add from vector index
                        segments = (
                            db.session.query(DocumentSegment)
                            .filter(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True)
                            .order_by(DocumentSegment.position.asc())
                            .all()
                        )
                        if segments:
                            documents = []
                            # NOTE(review): unlike the "update" branch, child chunks are not
                            # attached here — confirm this is intended.
                            for segment in segments:
                                document = Document(
                                    page_content=segment.content,
                                    metadata={
                                        "doc_id": segment.index_node_id,
                                        "doc_hash": segment.index_node_hash,
                                        "document_id": segment.document_id,
                                        "dataset_id": segment.dataset_id,
                                    },
                                )

                                documents.append(document)
                            # save vector index
                            index_processor.load(dataset, documents, with_keywords=False)
                        # Documents without enabled segments are also set back to "completed".
                        db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document.id).update(
                            {"indexing_status": "completed"}, synchronize_session=False
                        )
                        db.session.commit()
                    except Exception as e:
                        # Record the failure on this document and keep going with the next one.
                        db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document.id).update(
                            {"indexing_status": "error", "error": str(e)}, synchronize_session=False
                        )
                        db.session.commit()
        elif action == "update":
            dataset_documents = (
                db.session.query(DatasetDocument)
                .filter(
                    DatasetDocument.dataset_id == dataset_id,
                    DatasetDocument.indexing_status == "completed",
                    DatasetDocument.enabled == True,
                    DatasetDocument.archived == False,
                )
                .all()
            )
            # add new index
            if dataset_documents:
                # update document status
                dataset_documents_ids = [doc.id for doc in dataset_documents]
                db.session.query(DatasetDocument).filter(DatasetDocument.id.in_(dataset_documents_ids)).update(
                    {"indexing_status": "indexing"}, synchronize_session=False
                )
                db.session.commit()

                # clean index
                # The old index is cleaned once, before any document is re-loaded.
                index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False)

                for dataset_document in dataset_documents:
                    # update from vector index
                    try:
                        segments = (
                            db.session.query(DocumentSegment)
                            .filter(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True)
                            .order_by(DocumentSegment.position.asc())
                            .all()
                        )
                        if segments:
                            documents = []
                            for segment in segments:
                                document = Document(
                                    page_content=segment.content,
                                    metadata={
                                        "doc_id": segment.index_node_id,
                                        "doc_hash": segment.index_node_hash,
                                        "document_id": segment.document_id,
                                        "dataset_id": segment.dataset_id,
                                    },
                                )
                                # Parent-child documents carry their child chunks along so they
                                # are handed to load() together with the parent.
                                if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
                                    child_chunks = segment.child_chunks
                                    if child_chunks:
                                        child_documents = []
                                        for child_chunk in child_chunks:
                                            child_document = ChildDocument(
                                                page_content=child_chunk.content,
                                                metadata={
                                                    "doc_id": child_chunk.index_node_id,
                                                    "doc_hash": child_chunk.index_node_hash,
                                                    "document_id": segment.document_id,
                                                    "dataset_id": segment.dataset_id,
                                                },
                                            )
                                            child_documents.append(child_document)
                                        document.children = child_documents
                                documents.append(document)
                            # save vector index
                            index_processor.load(dataset, documents, with_keywords=False)
                        db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document.id).update(
                            {"indexing_status": "completed"}, synchronize_session=False
                        )
                        db.session.commit()
                    except Exception as e:
                        # Record the failure on this document and keep going with the next one.
                        db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document.id).update(
                            {"indexing_status": "error", "error": str(e)}, synchronize_session=False
                        )
                        db.session.commit()
            else:
                # clean collection
                # Nothing to re-index, so only the existing index is cleaned.
                index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False)

        end_at = time.perf_counter()
        logging.info(
            click.style("Deal dataset vector index: {} latency: {}".format(dataset_id, end_at - start_at), fg="green")
        )
    except Exception:
        logging.exception("Deal dataset vector index failed")
@shared_task(queue="dataset")
def delete_segment_from_index_task(index_node_ids: list, dataset_id: str, document_id: str):
    """
    Async removal of segment nodes from a dataset's index.

    Does nothing when the dataset or document is missing, or when the document
    is disabled, archived or not fully indexed. Failures are logged and not re-raised.

    :param index_node_ids: index node ids of the segments to remove
    :param dataset_id: dataset that owns the index
    :param document_id: document the segments belong to

    Usage: delete_segment_from_index_task.delay(index_node_ids, dataset_id, document_id)
    """
    logging.info(click.style("Start delete segment from index", fg="green"))
    began = time.perf_counter()
    try:
        target_dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
        if not target_dataset:
            return

        owner_document = db.session.query(Document).filter(Document.id == document_id).first()
        if not owner_document:
            return

        # Only documents that are enabled, not archived and completely indexed are processed.
        is_indexable = (
            owner_document.enabled
            and not owner_document.archived
            and owner_document.indexing_status == "completed"
        )
        if not is_indexable:
            return

        processor = IndexProcessorFactory(owner_document.doc_form).init_index_processor()
        processor.clean(target_dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)

        finished = time.perf_counter()
        logging.info(click.style(f"Segment deleted from index latency: {finished - began}", fg="green"))
    except Exception:
        logging.exception("delete segment from index failed")
@shared_task(queue="dataset")
def disable_segments_from_index_task(segment_ids: list, dataset_id: str, document_id: str):
    """
    Async disable segments from index.

    Removes the given segments' nodes from the dataset index. When removal fails,
    the segments are switched back to enabled in the database and the failure is logged.
    The per-segment Redis indexing keys are deleted in every case once segments were found.

    :param segment_ids: ids of the segments to remove from the index
    :param dataset_id: dataset the segments must belong to
    :param document_id: document the segments must belong to

    Usage: disable_segments_from_index_task.delay(segment_ids, dataset_id, document_id)
    """
    start_at = time.perf_counter()

    dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
    if not dataset:
        logging.info(click.style("Dataset {} not found, pass.".format(dataset_id), fg="cyan"))
        return

    dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == document_id).first()

    if not dataset_document:
        logging.info(click.style("Document {} not found, pass.".format(document_id), fg="cyan"))
        return
    if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
        logging.info(click.style("Document {} status is invalid, pass.".format(document_id), fg="cyan"))
        return
    # sync index processor
    index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor()

    segments = (
        db.session.query(DocumentSegment)
        .filter(
            DocumentSegment.id.in_(segment_ids),
            DocumentSegment.dataset_id == dataset_id,
            DocumentSegment.document_id == document_id,
        )
        .all()
    )

    if not segments:
        return

    try:
        index_node_ids = [segment.index_node_id for segment in segments]
        index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=False)

        end_at = time.perf_counter()
        logging.info(click.style("Segments removed from index latency: {}".format(end_at - start_at), fg="green"))
    except Exception:
        # Previously the failure was swallowed without any trace; log it before
        # rolling the segments back to "enabled".
        logging.exception("remove segments from index failed")
        db.session.query(DocumentSegment).filter(
            DocumentSegment.id.in_(segment_ids),
            DocumentSegment.dataset_id == dataset_id,
            DocumentSegment.document_id == document_id,
        ).update(
            {
                "disabled_at": None,
                "disabled_by": None,
                "enabled": True,
            }
        )
        db.session.commit()
    finally:
        for segment in segments:
            indexing_cache_key = "segment_{}_indexing".format(segment.id)
            redis_client.delete(indexing_cache_key)
NotionExtractor +from core.rag.index_processor.index_processor_factory import IndexProcessorFactory +from extensions.ext_database import db +from models.dataset import Dataset, Document, DocumentSegment +from models.source import DataSourceOauthBinding + + +@shared_task(queue="dataset") +def document_indexing_sync_task(dataset_id: str, document_id: str): + """ + Async update document + :param dataset_id: + :param document_id: + + Usage: document_indexing_sync_task.delay(dataset_id, document_id) + """ + logging.info(click.style("Start sync document: {}".format(document_id), fg="green")) + start_at = time.perf_counter() + + document = db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first() + + if not document: + raise NotFound("Document not found") + + data_source_info = document.data_source_info_dict + if document.data_source_type == "notion_import": + if ( + not data_source_info + or "notion_page_id" not in data_source_info + or "notion_workspace_id" not in data_source_info + ): + raise ValueError("no notion page found") + workspace_id = data_source_info["notion_workspace_id"] + page_id = data_source_info["notion_page_id"] + page_type = data_source_info["type"] + page_edited_time = data_source_info["last_edited_time"] + data_source_binding = DataSourceOauthBinding.query.filter( + db.and_( + DataSourceOauthBinding.tenant_id == document.tenant_id, + DataSourceOauthBinding.provider == "notion", + DataSourceOauthBinding.disabled == False, + DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"', + ) + ).first() + if not data_source_binding: + raise ValueError("Data source binding not found.") + + loader = NotionExtractor( + notion_workspace_id=workspace_id, + notion_obj_id=page_id, + notion_page_type=page_type, + notion_access_token=data_source_binding.access_token, + tenant_id=document.tenant_id, + ) + + last_edited_time = loader.get_notion_last_edited_time() + + # check the page is updated + if 
last_edited_time != page_edited_time: + document.indexing_status = "parsing" + document.processing_started_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + db.session.commit() + + # delete all document segment and index + try: + dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() + if not dataset: + raise Exception("Dataset not found") + index_type = document.doc_form + index_processor = IndexProcessorFactory(index_type).init_index_processor() + + segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all() + index_node_ids = [segment.index_node_id for segment in segments] + + # delete from vector index + index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) + + for segment in segments: + db.session.delete(segment) + + end_at = time.perf_counter() + logging.info( + click.style( + "Cleaned document when document update data source or process rule: {} latency: {}".format( + document_id, end_at - start_at + ), + fg="green", + ) + ) + except Exception: + logging.exception("Cleaned document when document update data source or process rule failed") + + try: + indexing_runner = IndexingRunner() + indexing_runner.run([document]) + end_at = time.perf_counter() + logging.info( + click.style("update document: {} latency: {}".format(document.id, end_at - start_at), fg="green") + ) + except DocumentIsPausedError as ex: + logging.info(click.style(str(ex), fg="yellow")) + except Exception: + pass diff --git a/api/tasks/document_indexing_task.py b/api/tasks/document_indexing_task.py new file mode 100644 index 0000000000000000000000000000000000000000..21b571b6cb5bd456793c1f1419eb9c393d1642a4 --- /dev/null +++ b/api/tasks/document_indexing_task.py @@ -0,0 +1,80 @@ +import datetime +import logging +import time + +import click +from celery import shared_task # type: ignore + +from configs import dify_config +from core.indexing_runner import DocumentIsPausedError, 
IndexingRunner +from extensions.ext_database import db +from models.dataset import Dataset, Document +from services.feature_service import FeatureService + + +@shared_task(queue="dataset") +def document_indexing_task(dataset_id: str, document_ids: list): + """ + Async process document + :param dataset_id: + :param document_ids: + + Usage: document_indexing_task.delay(dataset_id, document_id) + """ + documents = [] + start_at = time.perf_counter() + + dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() + if not dataset: + logging.info(click.style("Dataset is not found: {}".format(dataset_id), fg="yellow")) + return + # check document limit + features = FeatureService.get_features(dataset.tenant_id) + try: + if features.billing.enabled: + vector_space = features.vector_space + count = len(document_ids) + batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT) + if count > batch_upload_limit: + raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.") + if 0 < vector_space.limit <= vector_space.size: + raise ValueError( + "Your total number of documents plus the number of uploads have over the limit of " + "your subscription." 
+ ) + except Exception as e: + for document_id in document_ids: + document = ( + db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first() + ) + if document: + document.indexing_status = "error" + document.error = str(e) + document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + db.session.add(document) + db.session.commit() + return + + for document_id in document_ids: + logging.info(click.style("Start process document: {}".format(document_id), fg="green")) + + document = ( + db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first() + ) + + if document: + document.indexing_status = "parsing" + document.processing_started_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + documents.append(document) + db.session.add(document) + db.session.commit() + + try: + indexing_runner = IndexingRunner() + indexing_runner.run(documents) + end_at = time.perf_counter() + logging.info(click.style("Processed dataset: {} latency: {}".format(dataset_id, end_at - start_at), fg="green")) + except DocumentIsPausedError as ex: + logging.info(click.style(str(ex), fg="yellow")) + except Exception: + pass diff --git a/api/tasks/document_indexing_update_task.py b/api/tasks/document_indexing_update_task.py new file mode 100644 index 0000000000000000000000000000000000000000..d8f14830c979ada9d2d68201dc62f1727cb05aa5 --- /dev/null +++ b/api/tasks/document_indexing_update_task.py @@ -0,0 +1,75 @@ +import datetime +import logging +import time + +import click +from celery import shared_task # type: ignore +from werkzeug.exceptions import NotFound + +from core.indexing_runner import DocumentIsPausedError, IndexingRunner +from core.rag.index_processor.index_processor_factory import IndexProcessorFactory +from extensions.ext_database import db +from models.dataset import Dataset, Document, DocumentSegment + + +@shared_task(queue="dataset") +def 
document_indexing_update_task(dataset_id: str, document_id: str): + """ + Async update document + :param dataset_id: + :param document_id: + + Usage: document_indexing_update_task.delay(dataset_id, document_id) + """ + logging.info(click.style("Start update document: {}".format(document_id), fg="green")) + start_at = time.perf_counter() + + document = db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first() + + if not document: + raise NotFound("Document not found") + + document.indexing_status = "parsing" + document.processing_started_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + db.session.commit() + + # delete all document segment and index + try: + dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() + if not dataset: + raise Exception("Dataset not found") + + index_type = document.doc_form + index_processor = IndexProcessorFactory(index_type).init_index_processor() + + segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all() + if segments: + index_node_ids = [segment.index_node_id for segment in segments] + + # delete from vector index + index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) + + for segment in segments: + db.session.delete(segment) + db.session.commit() + end_at = time.perf_counter() + logging.info( + click.style( + "Cleaned document when document update data source or process rule: {} latency: {}".format( + document_id, end_at - start_at + ), + fg="green", + ) + ) + except Exception: + logging.exception("Cleaned document when document update data source or process rule failed") + + try: + indexing_runner = IndexingRunner() + indexing_runner.run([document]) + end_at = time.perf_counter() + logging.info(click.style("update document: {} latency: {}".format(document.id, end_at - start_at), fg="green")) + except DocumentIsPausedError as ex: + logging.info(click.style(str(ex), 
fg="yellow")) + except Exception: + pass diff --git a/api/tasks/duplicate_document_indexing_task.py b/api/tasks/duplicate_document_indexing_task.py new file mode 100644 index 0000000000000000000000000000000000000000..8e1d2b6b5d147e81660f60a18245ec12749232ee --- /dev/null +++ b/api/tasks/duplicate_document_indexing_task.py @@ -0,0 +1,96 @@ +import datetime +import logging +import time + +import click +from celery import shared_task # type: ignore + +from configs import dify_config +from core.indexing_runner import DocumentIsPausedError, IndexingRunner +from core.rag.index_processor.index_processor_factory import IndexProcessorFactory +from extensions.ext_database import db +from models.dataset import Dataset, Document, DocumentSegment +from services.feature_service import FeatureService + + +@shared_task(queue="dataset") +def duplicate_document_indexing_task(dataset_id: str, document_ids: list): + """ + Async process document + :param dataset_id: + :param document_ids: + + Usage: duplicate_document_indexing_task.delay(dataset_id, document_id) + """ + documents = [] + start_at = time.perf_counter() + + dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() + if dataset is None: + raise ValueError("Dataset not found") + + # check document limit + features = FeatureService.get_features(dataset.tenant_id) + try: + if features.billing.enabled: + vector_space = features.vector_space + count = len(document_ids) + batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT) + if count > batch_upload_limit: + raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.") + if 0 < vector_space.limit <= vector_space.size: + raise ValueError( + "Your total number of documents plus the number of uploads have over the limit of " + "your subscription." 
+ ) + except Exception as e: + for document_id in document_ids: + document = ( + db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first() + ) + if document: + document.indexing_status = "error" + document.error = str(e) + document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + db.session.add(document) + db.session.commit() + return + + for document_id in document_ids: + logging.info(click.style("Start process document: {}".format(document_id), fg="green")) + + document = ( + db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first() + ) + + if document: + # clean old data + index_type = document.doc_form + index_processor = IndexProcessorFactory(index_type).init_index_processor() + + segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all() + if segments: + index_node_ids = [segment.index_node_id for segment in segments] + + # delete from vector index + index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) + + for segment in segments: + db.session.delete(segment) + db.session.commit() + + document.indexing_status = "parsing" + document.processing_started_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + documents.append(document) + db.session.add(document) + db.session.commit() + + try: + indexing_runner = IndexingRunner() + indexing_runner.run(documents) + end_at = time.perf_counter() + logging.info(click.style("Processed dataset: {} latency: {}".format(dataset_id, end_at - start_at), fg="green")) + except DocumentIsPausedError as ex: + logging.info(click.style(str(ex), fg="yellow")) + except Exception: + pass diff --git a/api/tasks/enable_segment_to_index_task.py b/api/tasks/enable_segment_to_index_task.py new file mode 100644 index 0000000000000000000000000000000000000000..76522f4720cf952e4c798c021b831a0225649162 --- /dev/null +++ 
b/api/tasks/enable_segment_to_index_task.py @@ -0,0 +1,96 @@ +import datetime +import logging +import time + +import click +from celery import shared_task # type: ignore +from werkzeug.exceptions import NotFound + +from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.index_processor_factory import IndexProcessorFactory +from core.rag.models.document import ChildDocument, Document +from extensions.ext_database import db +from extensions.ext_redis import redis_client +from models.dataset import DocumentSegment + + +@shared_task(queue="dataset") +def enable_segment_to_index_task(segment_id: str): + """ + Async enable segment to index + :param segment_id: + + Usage: enable_segment_to_index_task.delay(segment_id) + """ + logging.info(click.style("Start enable segment to index: {}".format(segment_id), fg="green")) + start_at = time.perf_counter() + + segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment_id).first() + if not segment: + raise NotFound("Segment not found") + + if segment.status != "completed": + raise NotFound("Segment is not completed, enable action is not allowed.") + + indexing_cache_key = "segment_{}_indexing".format(segment.id) + + try: + document = Document( + page_content=segment.content, + metadata={ + "doc_id": segment.index_node_id, + "doc_hash": segment.index_node_hash, + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + }, + ) + + dataset = segment.dataset + + if not dataset: + logging.info(click.style("Segment {} has no dataset, pass.".format(segment.id), fg="cyan")) + return + + dataset_document = segment.document + + if not dataset_document: + logging.info(click.style("Segment {} has no document, pass.".format(segment.id), fg="cyan")) + return + + if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed": + logging.info(click.style("Segment {} document status is invalid, 
pass.".format(segment.id), fg="cyan")) + return + + index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor() + if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: + child_chunks = segment.child_chunks + if child_chunks: + child_documents = [] + for child_chunk in child_chunks: + child_document = ChildDocument( + page_content=child_chunk.content, + metadata={ + "doc_id": child_chunk.index_node_id, + "doc_hash": child_chunk.index_node_hash, + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + }, + ) + child_documents.append(child_document) + document.children = child_documents + # save vector index + index_processor.load(dataset, [document]) + + end_at = time.perf_counter() + logging.info( + click.style("Segment enabled to index: {} latency: {}".format(segment.id, end_at - start_at), fg="green") + ) + except Exception as e: + logging.exception("enable segment to index failed") + segment.enabled = False + segment.disabled_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + segment.status = "error" + segment.error = str(e) + db.session.commit() + finally: + redis_client.delete(indexing_cache_key) diff --git a/api/tasks/enable_segments_to_index_task.py b/api/tasks/enable_segments_to_index_task.py new file mode 100644 index 0000000000000000000000000000000000000000..0864e05e25f5a64228e35a46387f559f21bc8217 --- /dev/null +++ b/api/tasks/enable_segments_to_index_task.py @@ -0,0 +1,108 @@ +import datetime +import logging +import time + +import click +from celery import shared_task # type: ignore + +from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.index_processor_factory import IndexProcessorFactory +from core.rag.models.document import ChildDocument, Document +from extensions.ext_database import db +from extensions.ext_redis import redis_client +from models.dataset import Dataset, DocumentSegment +from models.dataset import Document as 
DatasetDocument + + +@shared_task(queue="dataset") +def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_id: str): + """ + Async enable segments to index + :param segment_ids: + + Usage: enable_segments_to_index_task.delay(segment_ids) + """ + start_at = time.perf_counter() + dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() + if not dataset: + logging.info(click.style("Dataset {} not found, pass.".format(dataset_id), fg="cyan")) + return + + dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == document_id).first() + + if not dataset_document: + logging.info(click.style("Document {} not found, pass.".format(document_id), fg="cyan")) + return + if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed": + logging.info(click.style("Document {} status is invalid, pass.".format(document_id), fg="cyan")) + return + # sync index processor + index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor() + + segments = ( + db.session.query(DocumentSegment) + .filter( + DocumentSegment.id.in_(segment_ids), + DocumentSegment.dataset_id == dataset_id, + DocumentSegment.document_id == document_id, + ) + .all() + ) + if not segments: + return + + try: + documents = [] + for segment in segments: + document = Document( + page_content=segment.content, + metadata={ + "doc_id": segment.index_node_id, + "doc_hash": segment.index_node_hash, + "document_id": document_id, + "dataset_id": dataset_id, + }, + ) + + if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: + child_chunks = segment.child_chunks + if child_chunks: + child_documents = [] + for child_chunk in child_chunks: + child_document = ChildDocument( + page_content=child_chunk.content, + metadata={ + "doc_id": child_chunk.index_node_id, + "doc_hash": child_chunk.index_node_hash, + "document_id": document_id, + "dataset_id": dataset_id, + }, + ) + 
child_documents.append(child_document) + document.children = child_documents + documents.append(document) + # save vector index + index_processor.load(dataset, documents) + + end_at = time.perf_counter() + logging.info(click.style("Segments enabled to index latency: {}".format(end_at - start_at), fg="green")) + except Exception as e: + logging.exception("enable segments to index failed") + # update segment error msg + db.session.query(DocumentSegment).filter( + DocumentSegment.id.in_(segment_ids), + DocumentSegment.dataset_id == dataset_id, + DocumentSegment.document_id == document_id, + ).update( + { + "error": str(e), + "status": "error", + "disabled_at": datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), + "enabled": False, + } + ) + db.session.commit() + finally: + for segment in segments: + indexing_cache_key = "segment_{}_indexing".format(segment.id) + redis_client.delete(indexing_cache_key) diff --git a/api/tasks/external_document_indexing_task.py b/api/tasks/external_document_indexing_task.py new file mode 100644 index 0000000000000000000000000000000000000000..a45b3030bf253a226df6708eb7dc8ba040d4b1ff --- /dev/null +++ b/api/tasks/external_document_indexing_task.py @@ -0,0 +1,91 @@ +import json +import logging +import time + +import click +from celery import shared_task # type: ignore + +from core.indexing_runner import DocumentIsPausedError +from extensions.ext_database import db +from extensions.ext_storage import storage +from models.dataset import Dataset, ExternalKnowledgeApis +from models.model import UploadFile +from services.external_knowledge_service import ExternalDatasetService + + +@shared_task(queue="dataset") +def external_document_indexing_task( + dataset_id: str, external_knowledge_api_id: str, data_source: dict, process_parameter: dict +): + """ + Async process document + :param dataset_id: + :param external_knowledge_api_id: + :param data_source: + :param process_parameter: + Usage: 
external_document_indexing_task.delay(dataset_id, document_id) + """ + start_at = time.perf_counter() + + dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() + if not dataset: + logging.info( + click.style("Processed external dataset: {} failed, dataset not exit.".format(dataset_id), fg="red") + ) + return + + # get external api template + external_knowledge_api = ( + db.session.query(ExternalKnowledgeApis) + .filter( + ExternalKnowledgeApis.id == external_knowledge_api_id, ExternalKnowledgeApis.tenant_id == dataset.tenant_id + ) + .first() + ) + + if not external_knowledge_api: + logging.info( + click.style( + "Processed external dataset: {} failed, api template: {} not exit.".format( + dataset_id, external_knowledge_api_id + ), + fg="red", + ) + ) + return + files = {} + if data_source["type"] == "upload_file": + upload_file_list = data_source["info_list"]["file_info_list"]["file_ids"] + for file_id in upload_file_list: + file = ( + db.session.query(UploadFile) + .filter(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id) + .first() + ) + if file: + files[file.id] = (file.name, storage.load_once(file.key), file.mime_type) + try: + settings = ExternalDatasetService.get_external_knowledge_api_settings( + json.loads(external_knowledge_api.settings) + ) + + # do http request + response = ExternalDatasetService.process_external_api(settings, files) + job_id = response.json().get("job_id") + if job_id: + # save job_id to dataset + dataset.job_id = job_id + db.session.commit() + + end_at = time.perf_counter() + logging.info( + click.style( + "Processed external dataset: {} successful, latency: {}".format(dataset.id, end_at - start_at), + fg="green", + ) + ) + except DocumentIsPausedError as ex: + logging.info(click.style(str(ex), fg="yellow")) + + except Exception: + pass diff --git a/api/tasks/mail_account_deletion_task.py b/api/tasks/mail_account_deletion_task.py new file mode 100644 index 
0000000000000000000000000000000000000000..49a3a6d280c1c95e7c2641de918fd75b94dd5c0f --- /dev/null +++ b/api/tasks/mail_account_deletion_task.py @@ -0,0 +1,70 @@ +import logging +import time + +import click +from celery import shared_task # type: ignore +from flask import render_template + +from extensions.ext_mail import mail + + +@shared_task(queue="mail") +def send_deletion_success_task(to): + """Send email to user regarding account deletion. + + Args: + log (AccountDeletionLog): Account deletion log object + """ + if not mail.is_inited(): + return + + logging.info(click.style(f"Start send account deletion success email to {to}", fg="green")) + start_at = time.perf_counter() + + try: + html_content = render_template( + "delete_account_success_template_en-US.html", + to=to, + email=to, + ) + mail.send(to=to, subject="Your Dify.AI Account Has Been Successfully Deleted", html=html_content) + + end_at = time.perf_counter() + logging.info( + click.style( + "Send account deletion success email to {}: latency: {}".format(to, end_at - start_at), fg="green" + ) + ) + except Exception: + logging.exception("Send account deletion success email to {} failed".format(to)) + + +@shared_task(queue="mail") +def send_account_deletion_verification_code(to, code): + """Send email to user regarding account deletion verification code. 
+ + Args: + to (str): Recipient email address + code (str): Verification code + """ + if not mail.is_inited(): + return + + logging.info(click.style(f"Start send account deletion verification code email to {to}", fg="green")) + start_at = time.perf_counter() + + try: + html_content = render_template("delete_account_code_email_template_en-US.html", to=to, code=code) + mail.send(to=to, subject="Dify.AI Account Deletion and Verification", html=html_content) + + end_at = time.perf_counter() + logging.info( + click.style( + "Send account deletion verification code email to {} succeeded: latency: {}".format( + to, end_at - start_at + ), + fg="green", + ) + ) + except Exception: + logging.exception("Send account deletion verification code email to {} failed".format(to)) diff --git a/api/tasks/mail_email_code_login.py b/api/tasks/mail_email_code_login.py new file mode 100644 index 0000000000000000000000000000000000000000..5dc935548f90b82e1b962b3c4c87c32cc95bbb2d --- /dev/null +++ b/api/tasks/mail_email_code_login.py @@ -0,0 +1,41 @@ +import logging +import time + +import click +from celery import shared_task # type: ignore +from flask import render_template + +from extensions.ext_mail import mail + + +@shared_task(queue="mail") +def send_email_code_login_mail_task(language: str, to: str, code: str): + """ + Async Send email code login mail + :param language: Language in which the email should be sent (e.g., 'en', 'zh') + :param to: Recipient email address + :param code: Email code to be included in the email + """ + if not mail.is_inited(): + return + + logging.info(click.style("Start email code login mail to {}".format(to), fg="green")) + start_at = time.perf_counter() + + # send email code login mail using different languages + try: + if language == "zh-Hans": + html_content = render_template("email_code_login_mail_template_zh-CN.html", to=to, code=code) + mail.send(to=to, subject="邮箱验证码", html=html_content) + else: + html_content = 
render_template("email_code_login_mail_template_en-US.html", to=to, code=code) + mail.send(to=to, subject="Email Code", html=html_content) + + end_at = time.perf_counter() + logging.info( + click.style( + "Send email code login mail to {} succeeded: latency: {}".format(to, end_at - start_at), fg="green" + ) + ) + except Exception: + logging.exception("Send email code login mail to {} failed".format(to)) diff --git a/api/tasks/mail_invite_member_task.py b/api/tasks/mail_invite_member_task.py new file mode 100644 index 0000000000000000000000000000000000000000..3094527fd4094598965c88935b0b7400cd558db7 --- /dev/null +++ b/api/tasks/mail_invite_member_task.py @@ -0,0 +1,61 @@ +import logging +import time + +import click +from celery import shared_task # type: ignore +from flask import render_template + +from configs import dify_config +from extensions.ext_mail import mail + + +@shared_task(queue="mail") +def send_invite_member_mail_task(language: str, to: str, token: str, inviter_name: str, workspace_name: str): + """ + Async Send invite member mail + :param language + :param to + :param token + :param inviter_name + :param workspace_name + + Usage: send_invite_member_mail_task.delay(language, to, token, inviter_name, workspace_name) + """ + if not mail.is_inited(): + return + + logging.info( + click.style("Start send invite member mail to {} in workspace {}".format(to, workspace_name), fg="green") + ) + start_at = time.perf_counter() + + # send invite member mail using different languages + try: + url = f"{dify_config.CONSOLE_WEB_URL}/activate?token={token}" + if language == "zh-Hans": + html_content = render_template( + "invite_member_mail_template_zh-CN.html", + to=to, + inviter_name=inviter_name, + workspace_name=workspace_name, + url=url, + ) + mail.send(to=to, subject="立即加入 Dify 工作空间", html=html_content) + else: + html_content = render_template( + "invite_member_mail_template_en-US.html", + to=to, + inviter_name=inviter_name, + workspace_name=workspace_name, + 
url=url, + ) + mail.send(to=to, subject="Join Dify Workspace Now", html=html_content) + + end_at = time.perf_counter() + logging.info( + click.style( + "Send invite member mail to {} succeeded: latency: {}".format(to, end_at - start_at), fg="green" + ) + ) + except Exception: + logging.exception("Send invite member mail to {} failed".format(to)) diff --git a/api/tasks/mail_reset_password_task.py b/api/tasks/mail_reset_password_task.py new file mode 100644 index 0000000000000000000000000000000000000000..d5be94431b62217a1778b2b5d822dcd7caf2678b --- /dev/null +++ b/api/tasks/mail_reset_password_task.py @@ -0,0 +1,41 @@ +import logging +import time + +import click +from celery import shared_task # type: ignore +from flask import render_template + +from extensions.ext_mail import mail + + +@shared_task(queue="mail") +def send_reset_password_mail_task(language: str, to: str, code: str): + """ + Async Send reset password mail + :param language: Language in which the email should be sent (e.g., 'en', 'zh') + :param to: Recipient email address + :param code: Reset password code + """ + if not mail.is_inited(): + return + + logging.info(click.style("Start password reset mail to {}".format(to), fg="green")) + start_at = time.perf_counter() + + # send reset password mail using different languages + try: + if language == "zh-Hans": + html_content = render_template("reset_password_mail_template_zh-CN.html", to=to, code=code) + mail.send(to=to, subject="设置您的 Dify 密码", html=html_content) + else: + html_content = render_template("reset_password_mail_template_en-US.html", to=to, code=code) + mail.send(to=to, subject="Set Your Dify Password", html=html_content) + + end_at = time.perf_counter() + logging.info( + click.style( + "Send password reset mail to {} succeeded: latency: {}".format(to, end_at - start_at), fg="green" + ) + ) + except Exception: + logging.exception("Send password reset mail to {} failed".format(to)) diff --git a/api/tasks/ops_trace_task.py 
# api/tasks/ops_trace_task.py
@shared_task(queue="ops_trace")
def process_trace_tasks(file_info):
    """
    Async process trace tasks
    :param file_info: mapping with the "app_id" and "file_id" keys that locate the
        stored trace payload at ``{OPS_FILE_PATH}{app_id}/{file_id}.json``

    The payload file is deleted once tracing has been attempted, whether it
    succeeded or failed. A tracing failure increments the per-app failure
    counter in Redis and is logged; it is not re-raised.

    Usage: process_trace_tasks.delay(file_info)
    """
    from core.ops.ops_trace_manager import OpsTraceManager

    app_id = file_info.get("app_id")
    file_id = file_info.get("file_id")
    file_path = f"{OPS_FILE_PATH}{app_id}/{file_id}.json"
    file_data = json.loads(storage.load(file_path))
    trace_info = file_data.get("trace_info")
    trace_info_type = file_data.get("trace_info_type")
    trace_instance = OpsTraceManager.get_ops_trace_instance(app_id)

    # Rebuild model/entity objects from their serialized dict form.
    if trace_info.get("message_data"):
        trace_info["message_data"] = Message.from_dict(data=trace_info["message_data"])
    if trace_info.get("workflow_data"):
        trace_info["workflow_data"] = WorkflowRun.from_dict(data=trace_info["workflow_data"])
    if trace_info.get("documents"):
        trace_info["documents"] = [Document(**doc) for doc in trace_info["documents"]]

    try:
        if trace_instance:
            with current_app.app_context():
                trace_type = trace_info_info_map.get(trace_info_type)
                if trace_type:
                    trace_info = trace_type(**trace_info)
                trace_instance.trace(trace_info)
        logging.info(f"Processing trace tasks success, app_id: {app_id}")
    except Exception:
        # Log first, with the traceback, so the cause is recorded even if the
        # Redis counter update below fails as well.
        logging.exception(f"Processing trace tasks failed, app_id: {app_id}")
        failed_key = f"{OPS_TRACE_FAILED_KEY}_{app_id}"
        redis_client.incr(failed_key)
    finally:
        storage.delete(file_path)


# api/tasks/recover_document_indexing_task.py
@shared_task(queue="dataset")
def recover_document_indexing_task(dataset_id: str, document_id: str):
    """
    Async recover document
    :param dataset_id: id of the dataset the document belongs to
    :param document_id: id of the document whose indexing should be resumed
    :raises NotFound: when no document matches both ids

    The indexing runner entry point is chosen from the document's current
    ``indexing_status``; other statuses are left untouched.

    Usage: recover_document_indexing_task.delay(dataset_id, document_id)
    """
    logging.info(click.style("Recover document: {}".format(document_id), fg="green"))
    start_at = time.perf_counter()

    document = db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first()

    if not document:
        raise NotFound("Document not found")

    try:
        indexing_runner = IndexingRunner()
        if document.indexing_status in {"waiting", "parsing", "cleaning"}:
            indexing_runner.run([document])
        elif document.indexing_status == "splitting":
            indexing_runner.run_in_splitting_status(document)
        elif document.indexing_status == "indexing":
            indexing_runner.run_in_indexing_status(document)
        end_at = time.perf_counter()
        logging.info(
            click.style("Processed document: {} latency: {}".format(document.id, end_at - start_at), fg="green")
        )
    except DocumentIsPausedError as ex:
        logging.info(click.style(str(ex), fg="yellow"))
    except Exception:
        # Still best-effort (nothing is re-raised), but no longer silent:
        # record the failure with its traceback.
        logging.exception("Recover document indexing failed: {}".format(document_id))
# api/tasks/remove_app_and_related_data_task.py
@shared_task(queue="app_deletion", bind=True, max_retries=3)
def remove_app_and_related_data_task(self, tenant_id: str, app_id: str):
    """
    Delete every record that belongs to an app.

    :param tenant_id: id of the tenant owning the app
    :param app_id: id of the app whose related data is removed

    Any error escaping the delete helpers is logged and the task is retried
    after 60 seconds (at most ``max_retries`` times).
    """
    logging.info(click.style(f"Start deleting app and related data: {tenant_id}:{app_id}", fg="green"))
    start_at = time.perf_counter()
    try:
        # Delete related data
        _delete_app_model_configs(tenant_id, app_id)
        _delete_app_site(tenant_id, app_id)
        _delete_app_api_tokens(tenant_id, app_id)
        _delete_installed_apps(tenant_id, app_id)
        _delete_recommended_apps(tenant_id, app_id)
        _delete_app_annotation_data(tenant_id, app_id)
        _delete_app_dataset_joins(tenant_id, app_id)
        _delete_app_workflows(tenant_id, app_id)
        _delete_app_workflow_runs(tenant_id, app_id)
        _delete_app_workflow_node_executions(tenant_id, app_id)
        _delete_app_workflow_app_logs(tenant_id, app_id)
        _delete_app_conversations(tenant_id, app_id)
        _delete_app_messages(tenant_id, app_id)
        _delete_workflow_tool_providers(tenant_id, app_id)
        _delete_app_tag_bindings(tenant_id, app_id)
        _delete_end_users(tenant_id, app_id)
        _delete_trace_app_configs(tenant_id, app_id)
        _delete_conversation_variables(app_id=app_id)

        end_at = time.perf_counter()
        logging.info(click.style(f"App and related data deleted: {app_id} latency: {end_at - start_at}", fg="green"))
    except SQLAlchemyError as e:
        logging.exception(
            click.style(f"Database error occurred while deleting app {app_id} and related data", fg="red")
        )
        raise self.retry(exc=e, countdown=60)  # Retry after 60 seconds
    except Exception as e:
        logging.exception(click.style(f"Error occurred while deleting app {app_id} and related data", fg="red"))
        raise self.retry(exc=e, countdown=60)  # Retry after 60 seconds


def _make_id_deleter(model) -> Callable[[str], None]:
    """Return a callable that deletes the row of ``model`` whose ``id`` equals the given id."""

    def _delete(record_id: str) -> None:
        db.session.query(model).filter(model.id == record_id).delete(synchronize_session=False)

    return _delete


def _delete_app_model_configs(tenant_id: str, app_id: str):
    """Delete all model configs of the app."""
    _delete_records(
        "select id from app_model_configs where app_id=:app_id limit 1000",
        {"app_id": app_id},
        _make_id_deleter(AppModelConfig),
        "app model config",
    )


def _delete_app_site(tenant_id: str, app_id: str):
    """Delete all site records of the app."""
    _delete_records(
        "select id from sites where app_id=:app_id limit 1000",
        {"app_id": app_id},
        _make_id_deleter(Site),
        "site",
    )


def _delete_app_api_tokens(tenant_id: str, app_id: str):
    """Delete all API tokens of the app."""
    _delete_records(
        "select id from api_tokens where app_id=:app_id limit 1000",
        {"app_id": app_id},
        _make_id_deleter(ApiToken),
        "api token",
    )


def _delete_installed_apps(tenant_id: str, app_id: str):
    """Delete the tenant's installed-app records for the app."""
    _delete_records(
        "select id from installed_apps where tenant_id=:tenant_id and app_id=:app_id limit 1000",
        {"tenant_id": tenant_id, "app_id": app_id},
        _make_id_deleter(InstalledApp),
        "installed app",
    )


def _delete_recommended_apps(tenant_id: str, app_id: str):
    """Delete all recommended-app records of the app."""
    _delete_records(
        "select id from recommended_apps where app_id=:app_id limit 1000",
        {"app_id": app_id},
        _make_id_deleter(RecommendedApp),
        "recommended app",
    )


def _delete_app_annotation_data(tenant_id: str, app_id: str):
    """Delete the app's annotation hit histories, then its annotation settings."""
    _delete_records(
        "select id from app_annotation_hit_histories where app_id=:app_id limit 1000",
        {"app_id": app_id},
        _make_id_deleter(AppAnnotationHitHistory),
        "annotation hit history",
    )

    _delete_records(
        "select id from app_annotation_settings where app_id=:app_id limit 1000",
        {"app_id": app_id},
        _make_id_deleter(AppAnnotationSetting),
        "annotation setting",
    )


def _delete_app_dataset_joins(tenant_id: str, app_id: str):
    """Delete all app-to-dataset join rows of the app."""
    _delete_records(
        "select id from app_dataset_joins where app_id=:app_id limit 1000",
        {"app_id": app_id},
        _make_id_deleter(AppDatasetJoin),
        "dataset join",
    )


def _delete_app_workflows(tenant_id: str, app_id: str):
    """Delete the tenant's workflows for the app."""
    _delete_records(
        "select id from workflows where tenant_id=:tenant_id and app_id=:app_id limit 1000",
        {"tenant_id": tenant_id, "app_id": app_id},
        _make_id_deleter(Workflow),
        "workflow",
    )


def _delete_app_workflow_runs(tenant_id: str, app_id: str):
    """Delete the tenant's workflow runs for the app."""
    _delete_records(
        "select id from workflow_runs where tenant_id=:tenant_id and app_id=:app_id limit 1000",
        {"tenant_id": tenant_id, "app_id": app_id},
        _make_id_deleter(WorkflowRun),
        "workflow run",
    )


def _delete_app_workflow_node_executions(tenant_id: str, app_id: str):
    """Delete the tenant's workflow node executions for the app."""
    _delete_records(
        "select id from workflow_node_executions where tenant_id=:tenant_id and app_id=:app_id limit 1000",
        {"tenant_id": tenant_id, "app_id": app_id},
        _make_id_deleter(WorkflowNodeExecution),
        "workflow node execution",
    )


def _delete_app_workflow_app_logs(tenant_id: str, app_id: str):
    """Delete the tenant's workflow app logs for the app."""
    _delete_records(
        "select id from workflow_app_logs where tenant_id=:tenant_id and app_id=:app_id limit 1000",
        {"tenant_id": tenant_id, "app_id": app_id},
        _make_id_deleter(WorkflowAppLog),
        "workflow app log",
    )


def _delete_app_conversations(tenant_id: str, app_id: str):
    """Delete the app's conversations together with their pinned-conversation rows."""

    def del_conversation(conversation_id: str):
        # Remove the rows that reference the conversation before the conversation itself.
        db.session.query(PinnedConversation).filter(PinnedConversation.conversation_id == conversation_id).delete(
            synchronize_session=False
        )
        db.session.query(Conversation).filter(Conversation.id == conversation_id).delete(synchronize_session=False)

    _delete_records(
        "select id from conversations where app_id=:app_id limit 1000",
        {"app_id": app_id},
        del_conversation,
        "conversation",
    )


def _delete_conversation_variables(*, app_id: str):
    """Delete all conversation variables of the app with a single DELETE statement."""
    stmt = delete(ConversationVariable).where(ConversationVariable.app_id == app_id)
    with db.engine.connect() as conn:
        conn.execute(stmt)
        conn.commit()
        logging.info(click.style(f"Deleted conversation variables for app {app_id}", fg="green"))


def _delete_app_messages(tenant_id: str, app_id: str):
    """Delete the app's messages together with every row that references them."""

    # Tables whose rows point at a message through ``message_id``.
    message_children = (
        MessageFeedback,
        MessageAnnotation,
        MessageChain,
        MessageAgentThought,
        MessageFile,
        SavedMessage,
    )

    def del_message(message_id: str):
        for child in message_children:
            db.session.query(child).filter(child.message_id == message_id).delete(synchronize_session=False)
        # synchronize_session=False matches every other bulk delete in this module.
        db.session.query(Message).filter(Message.id == message_id).delete(synchronize_session=False)

    _delete_records(
        "select id from messages where app_id=:app_id limit 1000",
        {"app_id": app_id},
        del_message,
        "message",
    )


def _delete_workflow_tool_providers(tenant_id: str, app_id: str):
    """Delete the tenant's workflow tool providers for the app."""
    _delete_records(
        "select id from tool_workflow_providers where tenant_id=:tenant_id and app_id=:app_id limit 1000",
        {"tenant_id": tenant_id, "app_id": app_id},
        _make_id_deleter(WorkflowToolProvider),
        "tool workflow provider",
    )


def _delete_app_tag_bindings(tenant_id: str, app_id: str):
    """Delete the tenant's tag bindings whose target is the app."""
    _delete_records(
        "select id from tag_bindings where tenant_id=:tenant_id and target_id=:app_id limit 1000",
        {"tenant_id": tenant_id, "app_id": app_id},
        _make_id_deleter(TagBinding),
        "tag binding",
    )


def _delete_end_users(tenant_id: str, app_id: str):
    """Delete the tenant's end users for the app."""
    _delete_records(
        "select id from end_users where tenant_id=:tenant_id and app_id=:app_id limit 1000",
        {"tenant_id": tenant_id, "app_id": app_id},
        _make_id_deleter(EndUser),
        "end user",
    )


def _delete_trace_app_configs(tenant_id: str, app_id: str):
    """Delete all trace configs of the app."""
    _delete_records(
        "select id from trace_app_config where app_id=:app_id limit 1000",
        {"app_id": app_id},
        _make_id_deleter(TraceAppConfig),
        "trace app config",
    )


def _delete_records(query_sql: str, params: dict, delete_func: Callable, name: str) -> None:
    """
    Repeatedly select a batch of ids and delete each record individually.

    :param query_sql: SELECT statement returning an ``id`` column for one batch
    :param params: bind parameters for ``query_sql``
    :param delete_func: callable receiving one record id (as ``str``) and issuing
        the deletes on ``db.session``; each call is committed separately
    :param name: human-readable record kind used in log messages

    A failing delete is logged and skipped (best effort). The loop ends when
    the query returns no rows, or when a whole batch could not be deleted.
    """
    while True:
        # Materialize the ids and release the connection before deleting. An
        # empty list is a reliable "no more rows" signal, whereas ``rowcount``
        # is not guaranteed to be meaningful for SELECT statements.
        with db.engine.begin() as conn:
            record_ids = [str(row.id) for row in conn.execute(db.text(query_sql), params)]

        if not record_ids:
            break

        deleted_count = 0
        for record_id in record_ids:
            try:
                delete_func(record_id)
                db.session.commit()
                deleted_count += 1
                logging.info(click.style(f"Deleted {name} {record_id}", fg="green"))
            except Exception:
                # Reset the session so the next record starts from a clean state.
                db.session.rollback()
                logging.exception(f"Error occurred while deleting {name} {record_id}")

        if deleted_count == 0:
            # Every record of this batch failed; the same ids would be selected
            # again, so stop instead of looping forever.
            logging.error(f"Stop deleting {name} records: no record of the last batch could be deleted")
            break
# api/tasks/remove_document_from_index_task.py
@shared_task(queue="dataset")
def remove_document_from_index_task(document_id: str):
    """
    Async Remove document from index
    :param document_id: document id
    :raises NotFound: when the document does not exist

    Does nothing unless the document's indexing status is "completed". The
    "document_{id}_indexing" Redis key is always cleared at the end.

    Usage: remove_document_from_index.delay(document_id)
    """
    logging.info(click.style("Start remove document segments from index: {}".format(document_id), fg="green"))
    start_at = time.perf_counter()

    document = db.session.query(Document).filter(Document.id == document_id).first()
    if not document:
        raise NotFound("Document not found")

    if document.indexing_status != "completed":
        return

    indexing_cache_key = "document_{}_indexing".format(document.id)

    try:
        dataset = document.dataset

        if not dataset:
            raise Exception("Document has no dataset")

        index_processor = IndexProcessorFactory(document.doc_form).init_index_processor()

        segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document.id).all()
        index_node_ids = [segment.index_node_id for segment in segments]
        if index_node_ids:
            try:
                index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=False)
            except Exception:
                # Index cleanup is best effort; the segments are disabled regardless.
                logging.exception(f"clean dataset {dataset.id} from index failed")
        # update segment to disable
        # One timestamp for both columns so disabled_at and updated_at agree.
        now = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
        db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document.id).update(
            {
                DocumentSegment.enabled: False,
                DocumentSegment.disabled_at: now,
                DocumentSegment.disabled_by: document.disabled_by,
                DocumentSegment.updated_at: now,
            }
        )
        db.session.commit()

        end_at = time.perf_counter()
        logging.info(
            click.style(
                "Document removed from index: {} latency: {}".format(document.id, end_at - start_at), fg="green"
            )
        )
    except Exception:
        logging.exception("remove document from index failed")
        # Removal failed: re-enable the document unless it is archived.
        if not document.archived:
            document.enabled = True
            db.session.commit()
    finally:
        redis_client.delete(indexing_cache_key)


# api/tasks/retry_document_indexing_task.py
@shared_task(queue="dataset")
def retry_document_indexing_task(dataset_id: str, document_ids: list[str]):
    """
    Async process document
    :param dataset_id: id of the dataset the documents belong to
    :param document_ids: ids of the documents to index again
    :raises ValueError: when the dataset does not exist

    For each document the old segments are removed, then indexing is run
    again. The "document_{id}_is_retried" Redis key is cleared on every exit
    path for that document. Processing stops at the first document that hits
    the subscription limit or cannot be found.

    Usage: retry_document_indexing_task.delay(dataset_id, document_ids)
    """
    start_at = time.perf_counter()

    dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
    if not dataset:
        raise ValueError("Dataset not found")

    for document_id in document_ids:
        retry_indexing_cache_key = "document_{}_is_retried".format(document_id)
        # check document limit
        features = FeatureService.get_features(dataset.tenant_id)
        try:
            if features.billing.enabled:
                vector_space = features.vector_space
                if 0 < vector_space.limit <= vector_space.size:
                    raise ValueError(
                        "Your total number of documents plus the number of uploads have over the limit of "
                        "your subscription."
                    )
        except Exception as e:
            document = (
                db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first()
            )
            if document:
                document.indexing_status = "error"
                document.error = str(e)
                document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
                db.session.add(document)
                db.session.commit()
            redis_client.delete(retry_indexing_cache_key)
            return

        logging.info(click.style("Start retry document: {}".format(document_id), fg="green"))
        document = (
            db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first()
        )
        if not document:
            logging.info(click.style("Document not found: {}".format(document_id), fg="yellow"))
            # Clear the retry flag here too, as every other exit path does.
            redis_client.delete(retry_indexing_cache_key)
            return
        try:
            # clean old data
            index_processor = IndexProcessorFactory(document.doc_form).init_index_processor()

            segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
            if segments:
                index_node_ids = [segment.index_node_id for segment in segments]
                # delete from vector index
                index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)

                for segment in segments:
                    db.session.delete(segment)
                db.session.commit()

            document.indexing_status = "parsing"
            document.processing_started_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
            db.session.add(document)
            db.session.commit()

            indexing_runner = IndexingRunner()
            indexing_runner.run([document])
            redis_client.delete(retry_indexing_cache_key)
        except Exception as ex:
            document.indexing_status = "error"
            document.error = str(ex)
            document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
            db.session.add(document)
            db.session.commit()
            logging.info(click.style(str(ex), fg="yellow"))
            redis_client.delete(retry_indexing_cache_key)
    end_at = time.perf_counter()
    logging.info(click.style("Retry dataset: {} latency: {}".format(dataset_id, end_at - start_at), fg="green"))


# api/tasks/sync_website_document_indexing_task.py
@shared_task(queue="dataset")
def sync_website_document_indexing_task(dataset_id: str, document_id: str):
    """
    Async process document
    :param dataset_id: id of the dataset the document belongs to
    :param document_id: id of the website document to index again
    :raises ValueError: when the dataset does not exist

    The old segments are removed, then indexing is run again. The
    "document_{id}_is_sync" Redis key is cleared on every exit path after the
    dataset has been found.

    Usage: sync_website_document_indexing_task.delay(dataset_id, document_id)
    """
    start_at = time.perf_counter()

    dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
    if dataset is None:
        raise ValueError("Dataset not found")

    sync_indexing_cache_key = "document_{}_is_sync".format(document_id)
    # check document limit
    features = FeatureService.get_features(dataset.tenant_id)
    try:
        if features.billing.enabled:
            vector_space = features.vector_space
            if 0 < vector_space.limit <= vector_space.size:
                raise ValueError(
                    "Your total number of documents plus the number of uploads have over the limit of "
                    "your subscription."
                )
    except Exception as e:
        document = (
            db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first()
        )
        if document:
            document.indexing_status = "error"
            document.error = str(e)
            document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
            db.session.add(document)
            db.session.commit()
        redis_client.delete(sync_indexing_cache_key)
        return

    logging.info(click.style("Start sync website document: {}".format(document_id), fg="green"))
    document = db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first()
    if not document:
        logging.info(click.style("Document not found: {}".format(document_id), fg="yellow"))
        # Clear the sync flag here too, as every other exit path does.
        redis_client.delete(sync_indexing_cache_key)
        return
    try:
        # clean old data
        index_processor = IndexProcessorFactory(document.doc_form).init_index_processor()

        segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
        if segments:
            index_node_ids = [segment.index_node_id for segment in segments]
            # delete from vector index
            index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)

            for segment in segments:
                db.session.delete(segment)
            db.session.commit()

        document.indexing_status = "parsing"
        document.processing_started_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
        db.session.add(document)
        db.session.commit()

        indexing_runner = IndexingRunner()
        indexing_runner.run([document])
        redis_client.delete(sync_indexing_cache_key)
    except Exception as ex:
        document.indexing_status = "error"
        document.error = str(ex)
        document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
        db.session.add(document)
        db.session.commit()
        logging.info(click.style(str(ex), fg="yellow"))
        redis_client.delete(sync_indexing_cache_key)
    end_at = time.perf_counter()
    logging.info(click.style("Sync document: {} latency: {}".format(document_id, end_at - start_at), fg="green"))
+
+ + Dify Logo +
+

Dify.AI Account Deletion and Verification

+

We received a request to delete your Dify account. To ensure the security of your account and + confirm this action, please use the verification code below:

+
+ {{code}} +
+
+

To complete the account deletion process:

+

1. Return to the account deletion page on our website

+

2. Enter the verification code above

+

3. Click "Confirm Deletion"

+
+

Please note:

+
    +
  • This code is valid for 5 minutes
  • +
  • As the Owner of any Workspaces, your workspaces will be scheduled in a queue for permanent deletion.
  • +
  • All your user data will be queued for permanent deletion.
  • +
+
+ + + \ No newline at end of file diff --git a/api/templates/delete_account_success_template_en-US.html b/api/templates/delete_account_success_template_en-US.html new file mode 100644 index 0000000000000000000000000000000000000000..c5df75cabce0931d6ff0d871342dfde7d9ff4b6d --- /dev/null +++ b/api/templates/delete_account_success_template_en-US.html @@ -0,0 +1,105 @@ + + + + + + + + +
+
+ + Dify Logo +
+

Your Dify.AI Account Has Been Successfully Deleted

+

We're writing to confirm that your Dify.AI account has been successfully deleted as per your request. Your + account is no longer accessible, and you can't log in using your previous credentials. If you decide to use + Dify.AI services in the future, you'll need to create a new account after 30 days. We appreciate the time you + spent with Dify.AI and are sorry to see you go. If you have any questions or concerns about the deletion process, + please don't hesitate to reach out to our support team.

+

Thank you for being a part of the Dify.AI community.

+

Best regards,

+

Dify.AI Team

+
+ + + \ No newline at end of file diff --git a/api/templates/email_code_login_mail_template_en-US.html b/api/templates/email_code_login_mail_template_en-US.html new file mode 100644 index 0000000000000000000000000000000000000000..066818d10c5a11b94001ec8a818ba9ac73539d0f --- /dev/null +++ b/api/templates/email_code_login_mail_template_en-US.html @@ -0,0 +1,74 @@ + + + + + + +
+
+ + Dify Logo +
+

Your login code for Dify

+

Copy and paste this code, this code will only be valid for the next 5 minutes.

+
+ {{code}} +
+

If you didn't request a login, don't worry. You can safely ignore this email.

+
+ + diff --git a/api/templates/email_code_login_mail_template_zh-CN.html b/api/templates/email_code_login_mail_template_zh-CN.html new file mode 100644 index 0000000000000000000000000000000000000000..0c2b63a1f1a6944119cbae5631ee91b7a0d12193 --- /dev/null +++ b/api/templates/email_code_login_mail_template_zh-CN.html @@ -0,0 +1,74 @@ + + + + + + +
+
+ + Dify Logo +
+

Dify 的登录验证码

+

复制并粘贴此验证码,注意验证码仅在接下来的 5 分钟内有效。

+
+ {{code}} +
+

如果您没有请求登录,请不要担心。您可以安全地忽略此电子邮件。

+
+ + diff --git a/api/templates/invite_member_mail_template_en-US.html b/api/templates/invite_member_mail_template_en-US.html new file mode 100644 index 0000000000000000000000000000000000000000..e8bf7f5a52a68994996b64c4bca00c8196a9963c --- /dev/null +++ b/api/templates/invite_member_mail_template_en-US.html @@ -0,0 +1,73 @@ + + + + + + +
+
+ + Dify Logo +
+
+

Dear {{ to }},

+

{{ inviter_name }} is pleased to invite you to join our workspace on Dify, a platform specifically designed for LLM application development. On Dify, you can explore, create, and collaborate to build and operate AI applications.

+

Click the button below to log in to Dify and join the workspace.

+

Login Here

+
+ +
+ + + diff --git a/api/templates/invite_member_mail_template_zh-CN.html b/api/templates/invite_member_mail_template_zh-CN.html new file mode 100644 index 0000000000000000000000000000000000000000..ccd9cdbaad9e95c6a7498b16699433d8a0f45b22 --- /dev/null +++ b/api/templates/invite_member_mail_template_zh-CN.html @@ -0,0 +1,72 @@ + + + + + + + +
+
+ Dify Logo +
+
+

尊敬的 {{ to }},

+

{{ inviter_name }} 现邀请您加入我们在 Dify 的工作区,这是一个专为 LLM 应用开发而设计的平台。在 Dify 上,您可以探索、创造和合作,构建和运营 AI 应用。

+

点击下方按钮即可登录 Dify 并且加入空间。

+

在此登录

+
+ +
+ + diff --git a/api/templates/reset_password_mail_template_en-US.html b/api/templates/reset_password_mail_template_en-US.html new file mode 100644 index 0000000000000000000000000000000000000000..d598fd191c5ff63729245aeee489a80fc89e2342 --- /dev/null +++ b/api/templates/reset_password_mail_template_en-US.html @@ -0,0 +1,74 @@ + + + + + + +
+
+ + Dify Logo +
+

Set your Dify password

+

Copy and paste this code, this code will only be valid for the next 5 minutes.

+
+ {{code}} +
+

If you didn't request a password reset, don't worry. You can safely ignore this email.

+
+ + diff --git a/api/templates/reset_password_mail_template_zh-CN.html b/api/templates/reset_password_mail_template_zh-CN.html new file mode 100644 index 0000000000000000000000000000000000000000..342c9057a7346afe5c53a5e5046ec0c4fe7ee08d --- /dev/null +++ b/api/templates/reset_password_mail_template_zh-CN.html @@ -0,0 +1,74 @@ + + + + + + +
+
+ + Dify Logo +
+

设置您的 Dify 账户密码

+

复制并粘贴此验证码,注意验证码仅在接下来的 5 分钟内有效。

+
+ {{code}} +
+

如果您没有请求,请不要担心。您可以安全地忽略此电子邮件。

+
# api/tests/artifact_tests/dependencies/test_dependencies_sorted.py
def load_api_poetry_configs() -> dict[str, Any]:
    """Load the ``[tool.poetry]`` table from ``api/pyproject.toml`` (path relative to the working directory)."""
    pyproject_toml = toml.load("api/pyproject.toml")
    return pyproject_toml["tool"]["poetry"]


def load_all_dependency_groups() -> dict[str, dict[str, dict[str, Any]]]:
    """
    Return the dependency tables keyed by group name.

    The top-level ``[tool.poetry.dependencies]`` table is reported under the
    name "main"; every ``[tool.poetry.group.<name>]`` table under ``<name>``.
    """
    configs = load_api_poetry_configs()
    configs_by_group = {"main": configs}
    # A project without optional groups has no "group" table at all.
    configs_by_group.update(configs.get("group", {}))
    return {group_name: base["dependencies"] for group_name, base in configs_by_group.items()}


def _dependencies_section(group_name: str) -> str:
    """Return the pyproject.toml section that holds the dependencies of ``group_name``."""
    if group_name == "main":
        return "tool.poetry.dependencies"
    return f"tool.poetry.group.{group_name}.dependencies"


def _version_specs(specification: Any) -> list[str]:
    """
    Return every version constraint string declared by a dependency specification.

    A specification may be a plain string, a table, or a list of tables
    (multiple constraints). Tables without a "version" key (e.g. git or path
    dependencies) contribute nothing.
    """
    entries = specification if isinstance(specification, list) else [specification]
    versions: list[str] = []
    for entry in entries:
        if isinstance(entry, str):
            versions.append(entry)
        elif isinstance(entry, dict) and "version" in entry:
            versions.append(entry["version"])
    return versions


def test_group_dependencies_sorted():
    """Each dependency table must list its packages in sorted order."""
    for group_name, dependencies in load_all_dependency_groups().items():
        dependency_names = list(dependencies.keys())
        expected_dependency_names = sorted(set(dependency_names))
        section = _dependencies_section(group_name)
        assert expected_dependency_names == dependency_names, (
            f"Dependencies in group {group_name} are not sorted. "
            f"Check and fix [{section}] section in pyproject.toml file"
        )


def test_group_dependencies_version_operator():
    """No dependency may use the caret (``^``) version operator."""
    for group_name, dependencies in load_all_dependency_groups().items():
        for dependency_name, specification in dependencies.items():
            for version_spec in _version_specs(specification):
                assert not version_spec.startswith("^"), (
                    f"Please replace '{dependency_name} = {version_spec}' with '{dependency_name} = ~{version_spec[1:]}' "
                    f"'^' operator is too wide and not allowed in the version specification."
                )


def test_duplicated_dependency_crossing_groups() -> None:
    """A package may be declared in one dependency group only."""
    all_dependency_names: list[str] = []
    for dependencies in load_all_dependency_groups().values():
        dependency_names = list(dependencies.keys())
        all_dependency_names.extend(dependency_names)
    expected_all_dependency_names = set(all_dependency_names)
    assert sorted(expected_all_dependency_names) == sorted(all_dependency_names), (
        "Duplicated dependencies crossing groups are found"
    )
Baichuan Credentials +BAICHUAN_API_KEY= +BAICHUAN_SECRET_KEY= + +# ChatGLM Credentials +CHATGLM_API_BASE= + +# Xinference Credentials +XINFERENCE_SERVER_URL= +XINFERENCE_GENERATION_MODEL_UID= +XINFERENCE_CHAT_MODEL_UID= +XINFERENCE_EMBEDDINGS_MODEL_UID= +XINFERENCE_RERANK_MODEL_UID= + +# OpenLLM Credentials +OPENLLM_SERVER_URL= + +# LocalAI Credentials +LOCALAI_SERVER_URL= + +# Cohere Credentials +COHERE_API_KEY= + +# Jina Credentials +JINA_API_KEY= + +# Ollama Credentials +OLLAMA_BASE_URL= + +# Together API Key +TOGETHER_API_KEY= + +# Mock Switch +MOCK_SWITCH=false + +# CODE EXECUTION CONFIGURATION +CODE_EXECUTION_ENDPOINT= +CODE_EXECUTION_API_KEY= + +# Volcengine MaaS Credentials +VOLC_API_KEY= +VOLC_SECRET_KEY= +VOLC_MODEL_ENDPOINT_ID= +VOLC_EMBEDDING_ENDPOINT_ID= + +# 360 AI Credentials +ZHINAO_API_KEY= + +# VESSL AI Credentials +VESSL_AI_MODEL_NAME= +VESSL_AI_API_KEY= +VESSL_AI_ENDPOINT_URL= + +# GPUStack Credentials +GPUSTACK_SERVER_URL= +GPUSTACK_API_KEY= + +# Gitee AI Credentials +GITEE_AI_API_KEY= + +# xAI Credentials +XAI_API_KEY= +XAI_API_BASE= diff --git a/api/tests/integration_tests/.gitignore b/api/tests/integration_tests/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..426667562b31dac736680e7aac2c76c06d98a688 --- /dev/null +++ b/api/tests/integration_tests/.gitignore @@ -0,0 +1 @@ +.env.test \ No newline at end of file diff --git a/api/tests/integration_tests/__init__.py b/api/tests/integration_tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/tests/integration_tests/conftest.py b/api/tests/integration_tests/conftest.py new file mode 100644 index 0000000000000000000000000000000000000000..6e3ab4b74be68f5fd0d7935f196ef35cffb0b905 --- /dev/null +++ b/api/tests/integration_tests/conftest.py @@ -0,0 +1,19 @@ +import os + +# Getting the absolute path of the current file's directory +ABS_PATH = 
os.path.dirname(os.path.abspath(__file__)) + +# Getting the absolute path of the project's root directory +PROJECT_DIR = os.path.abspath(os.path.join(ABS_PATH, os.pardir, os.pardir)) + + +# Loading the .env file if it exists +def _load_env() -> None: + dotenv_path = os.path.join(PROJECT_DIR, "tests", "integration_tests", ".env") + if os.path.exists(dotenv_path): + from dotenv import load_dotenv + + load_dotenv(dotenv_path) + + +_load_env() diff --git a/api/tests/integration_tests/controllers/app_fixture.py b/api/tests/integration_tests/controllers/app_fixture.py new file mode 100644 index 0000000000000000000000000000000000000000..32e8c11d19f3992524197ded96bb15669adbbcab --- /dev/null +++ b/api/tests/integration_tests/controllers/app_fixture.py @@ -0,0 +1,25 @@ +import pytest + +from app_factory import create_app +from configs import dify_config + +mock_user = type( + "MockUser", + (object,), + { + "is_authenticated": True, + "id": "123", + "is_editor": True, + "is_dataset_editor": True, + "status": "active", + "get_id": "123", + "current_tenant_id": "9d2074fc-6f86-45a9-b09d-6ecc63b9056b", + }, +) + + +@pytest.fixture +def app(): + app = create_app() + dify_config.LOGIN_DISABLED = True + return app diff --git a/api/tests/integration_tests/controllers/test_controllers.py b/api/tests/integration_tests/controllers/test_controllers.py new file mode 100644 index 0000000000000000000000000000000000000000..276ad3a7ed818afb7e9b6f00c286c272593fd114 --- /dev/null +++ b/api/tests/integration_tests/controllers/test_controllers.py @@ -0,0 +1,9 @@ +from unittest.mock import patch + +from app_fixture import mock_user # type: ignore + + +def test_post_requires_login(app): + with app.test_client() as client, patch("flask_login.utils._get_user", mock_user): + response = client.get("/console/api/data-source/integrates") + assert response.status_code == 200 diff --git a/api/tests/integration_tests/model_runtime/__init__.py b/api/tests/integration_tests/model_runtime/__init__.py new file 
mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/tests/integration_tests/model_runtime/__mock/anthropic.py b/api/tests/integration_tests/model_runtime/__mock/anthropic.py new file mode 100644 index 0000000000000000000000000000000000000000..5092af4f13b2ffb8f42837fc7e112d1f617e5ade --- /dev/null +++ b/api/tests/integration_tests/model_runtime/__mock/anthropic.py @@ -0,0 +1,98 @@ +import os +from collections.abc import Iterable +from typing import Any, Literal, Union + +import anthropic +import pytest +from _pytest.monkeypatch import MonkeyPatch +from anthropic import Stream +from anthropic.resources import Messages +from anthropic.types import ( + ContentBlock, + ContentBlockDeltaEvent, + Message, + MessageDeltaEvent, + MessageDeltaUsage, + MessageParam, + MessageStartEvent, + MessageStopEvent, + MessageStreamEvent, + TextDelta, + Usage, +) +from anthropic.types.message_delta_event import Delta + +MOCK = os.getenv("MOCK_SWITCH", "false") == "true" + + +class MockAnthropicClass: + @staticmethod + def mocked_anthropic_chat_create_sync(model: str) -> Message: + return Message( + id="msg-123", + type="message", + role="assistant", + content=[ContentBlock(text="hello, I'm a chatbot from anthropic", type="text")], + model=model, + stop_reason="stop_sequence", + usage=Usage(input_tokens=1, output_tokens=1), + ) + + @staticmethod + def mocked_anthropic_chat_create_stream(model: str) -> Stream[MessageStreamEvent]: + full_response_text = "hello, I'm a chatbot from anthropic" + + yield MessageStartEvent( + type="message_start", + message=Message( + id="msg-123", + content=[], + role="assistant", + model=model, + stop_reason=None, + type="message", + usage=Usage(input_tokens=1, output_tokens=1), + ), + ) + + index = 0 + for i in range(0, len(full_response_text)): + yield ContentBlockDeltaEvent( + type="content_block_delta", delta=TextDelta(text=full_response_text[i], type="text_delta"), index=index + ) + + 
index += 1 + + yield MessageDeltaEvent( + type="message_delta", delta=Delta(stop_reason="stop_sequence"), usage=MessageDeltaUsage(output_tokens=1) + ) + + yield MessageStopEvent(type="message_stop") + + def mocked_anthropic( + self: Messages, + *, + max_tokens: int, + messages: Iterable[MessageParam], + model: str, + stream: Literal[True], + **kwargs: Any, + ) -> Union[Message, Stream[MessageStreamEvent]]: + if len(self._client.api_key) < 18: + raise anthropic.AuthenticationError("Invalid API key") + + if stream: + return MockAnthropicClass.mocked_anthropic_chat_create_stream(model=model) + else: + return MockAnthropicClass.mocked_anthropic_chat_create_sync(model=model) + + +@pytest.fixture +def setup_anthropic_mock(request, monkeypatch: MonkeyPatch): + if MOCK: + monkeypatch.setattr(Messages, "create", MockAnthropicClass.mocked_anthropic) + + yield + + if MOCK: + monkeypatch.undo() diff --git a/api/tests/integration_tests/model_runtime/__mock/fishaudio.py b/api/tests/integration_tests/model_runtime/__mock/fishaudio.py new file mode 100644 index 0000000000000000000000000000000000000000..bec3babeafddab18f867e56eb8381bf28d349f93 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/__mock/fishaudio.py @@ -0,0 +1,82 @@ +import os +from collections.abc import Callable +from typing import Literal + +import httpx +import pytest +from _pytest.monkeypatch import MonkeyPatch + + +def mock_get(*args, **kwargs): + if kwargs.get("headers", {}).get("Authorization") != "Bearer test": + raise httpx.HTTPStatusError( + "Invalid API key", + request=httpx.Request("GET", ""), + response=httpx.Response(401), + ) + + return httpx.Response( + 200, + json={ + "items": [ + {"title": "Model 1", "_id": "model1"}, + {"title": "Model 2", "_id": "model2"}, + ] + }, + request=httpx.Request("GET", ""), + ) + + +def mock_stream(*args, **kwargs): + class MockStreamResponse: + def __init__(self): + self.status_code = 200 + + def __enter__(self): + return self + + def __exit__(self, exc_type, 
exc_val, exc_tb): + pass + + def iter_bytes(self): + yield b"Mocked audio data" + + return MockStreamResponse() + + +def mock_fishaudio( + monkeypatch: MonkeyPatch, + methods: list[Literal["list-models", "tts"]], +) -> Callable[[], None]: + """ + mock fishaudio module + + :param monkeypatch: pytest monkeypatch fixture + :return: unpatch function + """ + + def unpatch() -> None: + monkeypatch.undo() + + if "list-models" in methods: + monkeypatch.setattr(httpx, "get", mock_get) + + if "tts" in methods: + monkeypatch.setattr(httpx, "stream", mock_stream) + + return unpatch + + +MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true" + + +@pytest.fixture +def setup_fishaudio_mock(request, monkeypatch): + methods = request.param if hasattr(request, "param") else [] + if MOCK: + unpatch = mock_fishaudio(monkeypatch, methods=methods) + + yield + + if MOCK: + unpatch() diff --git a/api/tests/integration_tests/model_runtime/__mock/google.py b/api/tests/integration_tests/model_runtime/__mock/google.py new file mode 100644 index 0000000000000000000000000000000000000000..3a26b99e37507c99d3a60f69de4b1bbe8c45ec74 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/__mock/google.py @@ -0,0 +1,115 @@ +from unittest.mock import MagicMock + +import google.generativeai.types.generation_types as generation_config_types # type: ignore +import pytest +from _pytest.monkeypatch import MonkeyPatch +from google.ai import generativelanguage as glm +from google.ai.generativelanguage_v1beta.types import content as gag_content +from google.generativeai import GenerativeModel +from google.generativeai.types import GenerateContentResponse, content_types, safety_types +from google.generativeai.types.generation_types import BaseGenerateContentResponse + +from extensions import ext_redis + + +class MockGoogleResponseClass: + _done = False + + def __iter__(self): + full_response_text = "it's google!" 
+ + for i in range(0, len(full_response_text) + 1, 1): + if i == len(full_response_text): + self._done = True + yield GenerateContentResponse( + done=True, iterator=None, result=glm.GenerateContentResponse({}), chunks=[] + ) + else: + yield GenerateContentResponse( + done=False, iterator=None, result=glm.GenerateContentResponse({}), chunks=[] + ) + + +class MockGoogleResponseCandidateClass: + finish_reason = "stop" + + @property + def content(self) -> gag_content.Content: + return gag_content.Content(parts=[gag_content.Part(text="it's google!")]) + + +class MockGoogleClass: + @staticmethod + def generate_content_sync() -> GenerateContentResponse: + return GenerateContentResponse(done=True, iterator=None, result=glm.GenerateContentResponse({}), chunks=[]) + + @staticmethod + def generate_content_stream() -> MockGoogleResponseClass: + return MockGoogleResponseClass() + + def generate_content( + self: GenerativeModel, + contents: content_types.ContentsType, + *, + generation_config: generation_config_types.GenerationConfigType | None = None, + safety_settings: safety_types.SafetySettingOptions | None = None, + stream: bool = False, + **kwargs, + ) -> GenerateContentResponse: + if stream: + return MockGoogleClass.generate_content_stream() + + return MockGoogleClass.generate_content_sync() + + @property + def generative_response_text(self) -> str: + return "it's google!" 
+ + @property + def generative_response_candidates(self) -> list[MockGoogleResponseCandidateClass]: + return [MockGoogleResponseCandidateClass()] + + +def mock_configure(api_key: str): + if len(api_key) < 16: + raise Exception("Invalid API key") + + +class MockFileState: + def __init__(self): + self.name = "FINISHED" + + +class MockGoogleFile: + def __init__(self, name: str = "mock_file_name"): + self.name = name + self.state = MockFileState() + + +def mock_get_file(name: str) -> MockGoogleFile: + return MockGoogleFile(name) + + +def mock_upload_file(path: str, mime_type: str) -> MockGoogleFile: + return MockGoogleFile() + + +@pytest.fixture +def setup_google_mock(request, monkeypatch: MonkeyPatch): + monkeypatch.setattr(BaseGenerateContentResponse, "text", MockGoogleClass.generative_response_text) + monkeypatch.setattr(BaseGenerateContentResponse, "candidates", MockGoogleClass.generative_response_candidates) + monkeypatch.setattr(GenerativeModel, "generate_content", MockGoogleClass.generate_content) + monkeypatch.setattr("google.generativeai.configure", mock_configure) + monkeypatch.setattr("google.generativeai.get_file", mock_get_file) + monkeypatch.setattr("google.generativeai.upload_file", mock_upload_file) + + yield + + monkeypatch.undo() + + +@pytest.fixture +def setup_mock_redis() -> None: + ext_redis.redis_client.get = MagicMock(return_value=None) + ext_redis.redis_client.setex = MagicMock(return_value=None) + ext_redis.redis_client.exists = MagicMock(return_value=True) diff --git a/api/tests/integration_tests/model_runtime/__mock/huggingface.py b/api/tests/integration_tests/model_runtime/__mock/huggingface.py new file mode 100644 index 0000000000000000000000000000000000000000..4de52514408a06b2aaf4d6b3b6a41eb2dd8ca576 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/__mock/huggingface.py @@ -0,0 +1,20 @@ +import os + +import pytest +from _pytest.monkeypatch import MonkeyPatch +from huggingface_hub import InferenceClient # type: ignore + +from 
tests.integration_tests.model_runtime.__mock.huggingface_chat import MockHuggingfaceChatClass + +MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true" + + +@pytest.fixture +def setup_huggingface_mock(request, monkeypatch: MonkeyPatch): + if MOCK: + monkeypatch.setattr(InferenceClient, "text_generation", MockHuggingfaceChatClass.text_generation) + + yield + + if MOCK: + monkeypatch.undo() diff --git a/api/tests/integration_tests/model_runtime/__mock/huggingface_chat.py b/api/tests/integration_tests/model_runtime/__mock/huggingface_chat.py new file mode 100644 index 0000000000000000000000000000000000000000..77c7e7f5e4089cc104e36d31781c1341dbaf506e --- /dev/null +++ b/api/tests/integration_tests/model_runtime/__mock/huggingface_chat.py @@ -0,0 +1,56 @@ +import re +from collections.abc import Generator +from typing import Any, Literal, Optional, Union + +from _pytest.monkeypatch import MonkeyPatch +from huggingface_hub import InferenceClient # type: ignore +from huggingface_hub.inference._text_generation import ( # type: ignore + Details, + StreamDetails, + TextGenerationResponse, + TextGenerationStreamResponse, + Token, +) +from huggingface_hub.utils import BadRequestError # type: ignore + + +class MockHuggingfaceChatClass: + @staticmethod + def generate_create_sync(model: str) -> TextGenerationResponse: + response = TextGenerationResponse( + generated_text="You can call me Miku Miku o~e~o~", + details=Details( + finish_reason="length", + generated_tokens=6, + tokens=[Token(id=0, text="You", logprob=0.0, special=False) for i in range(0, 6)], + ), + ) + + return response + + @staticmethod + def generate_create_stream(model: str) -> Generator[TextGenerationStreamResponse, None, None]: + full_text = "You can call me Miku Miku o~e~o~" + + for i in range(0, len(full_text)): + response = TextGenerationStreamResponse( + token=Token(id=i, text=full_text[i], logprob=0.0, special=False), + ) + response.generated_text = full_text[i] + response.details = 
StreamDetails(finish_reason="stop_sequence", generated_tokens=1) + + yield response + + def text_generation( + self: InferenceClient, prompt: str, *, stream: Literal[False] = ..., model: Optional[str] = None, **kwargs: Any + ) -> Union[TextGenerationResponse, Generator[TextGenerationStreamResponse, None, None]]: + # check if key is valid + if not re.match(r"Bearer\shf\-[a-zA-Z0-9]{16,}", self.headers["authorization"]): + raise BadRequestError("Invalid API key") + + if model is None: + raise BadRequestError("Invalid model") + + if stream: + return MockHuggingfaceChatClass.generate_create_stream(model) + return MockHuggingfaceChatClass.generate_create_sync(model) diff --git a/api/tests/integration_tests/model_runtime/__mock/huggingface_tei.py b/api/tests/integration_tests/model_runtime/__mock/huggingface_tei.py new file mode 100644 index 0000000000000000000000000000000000000000..b9a721c803fc5234ed7c668df716f68cf56e59c3 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/__mock/huggingface_tei.py @@ -0,0 +1,94 @@ +from core.model_runtime.model_providers.huggingface_tei.tei_helper import TeiModelExtraParameter + + +class MockTEIClass: + @staticmethod + def get_tei_extra_parameter(server_url: str, model_name: str) -> TeiModelExtraParameter: + # During mock, we don't have a real server to query, so we just return a dummy value + if "rerank" in model_name: + model_type = "reranker" + else: + model_type = "embedding" + + return TeiModelExtraParameter(model_type=model_type, max_input_length=512, max_client_batch_size=1) + + @staticmethod + def invoke_tokenize(server_url: str, texts: list[str]) -> list[list[dict]]: + # Use space as token separator, and split the text into tokens + tokenized_texts = [] + for text in texts: + tokens = text.split(" ") + current_index = 0 + tokenized_text = [] + for idx, token in enumerate(tokens): + s_token = { + "id": idx, + "text": token, + "special": False, + "start": current_index, + "stop": current_index + len(token), + } + 
current_index += len(token) + 1 + tokenized_text.append(s_token) + tokenized_texts.append(tokenized_text) + return tokenized_texts + + @staticmethod + def invoke_embeddings(server_url: str, texts: list[str]) -> dict: + # { + # "object": "list", + # "data": [ + # { + # "object": "embedding", + # "embedding": [...], + # "index": 0 + # } + # ], + # "model": "MODEL_NAME", + # "usage": { + # "prompt_tokens": 3, + # "total_tokens": 3 + # } + # } + embeddings = [] + for idx in range(len(texts)): + embedding = [0.1] * 768 + embeddings.append( + { + "object": "embedding", + "embedding": embedding, + "index": idx, + } + ) + return { + "object": "list", + "data": embeddings, + "model": "MODEL_NAME", + "usage": { + "prompt_tokens": sum(len(text.split(" ")) for text in texts), + "total_tokens": sum(len(text.split(" ")) for text in texts), + }, + } + + @staticmethod + def invoke_rerank(server_url: str, query: str, texts: list[str]) -> list[dict]: + # Example response: + # [ + # { + # "index": 0, + # "text": "Deep Learning is ...", + # "score": 0.9950755 + # } + # ] + reranked_docs = [] + for idx, text in enumerate(texts): + reranked_docs.append( + { + "index": idx, + "text": text, + "score": 0.9, + } + ) + # For mock, only return the first document + break + return reranked_docs diff --git a/api/tests/integration_tests/model_runtime/__mock/nomic_embeddings.py b/api/tests/integration_tests/model_runtime/__mock/nomic_embeddings.py new file mode 100644 index 0000000000000000000000000000000000000000..4e00660a29162f1046dce3fc8fbce1badfb612c2 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/__mock/nomic_embeddings.py @@ -0,0 +1,59 @@ +import os +from collections.abc import Callable +from typing import Any, Literal + +import pytest + +# import monkeypatch +from _pytest.monkeypatch import MonkeyPatch +from nomic import embed # type: ignore + + +def create_embedding(texts: list[str], model: str, **kwargs: Any) -> dict: + texts_len = len(texts) + + foo_embedding_sample = 
0.123456 + + combined = { + "embeddings": [[foo_embedding_sample for _ in range(768)] for _ in range(texts_len)], + "usage": {"prompt_tokens": texts_len, "total_tokens": texts_len}, + "model": model, + "inference_mode": "remote", + } + + return combined + + +def mock_nomic( + monkeypatch: MonkeyPatch, + methods: list[Literal["text_embedding"]], +) -> Callable[[], None]: + """ + mock nomic module + + :param monkeypatch: pytest monkeypatch fixture + :return: unpatch function + """ + + def unpatch() -> None: + monkeypatch.undo() + + if "text_embedding" in methods: + monkeypatch.setattr(embed, "text", create_embedding) + + return unpatch + + +MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true" + + +@pytest.fixture +def setup_nomic_mock(request, monkeypatch): + methods = request.param if hasattr(request, "param") else [] + if MOCK: + unpatch = mock_nomic(monkeypatch, methods=methods) + + yield + + if MOCK: + unpatch() diff --git a/api/tests/integration_tests/model_runtime/__mock/openai.py b/api/tests/integration_tests/model_runtime/__mock/openai.py new file mode 100644 index 0000000000000000000000000000000000000000..6637f4f212a50e7ccd9f5eb126399571f8708b46 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/__mock/openai.py @@ -0,0 +1,71 @@ +import os +from collections.abc import Callable +from typing import Literal + +import pytest + +# import monkeypatch +from _pytest.monkeypatch import MonkeyPatch +from openai.resources.audio.transcriptions import Transcriptions +from openai.resources.chat import Completions as ChatCompletions +from openai.resources.completions import Completions +from openai.resources.embeddings import Embeddings +from openai.resources.models import Models +from openai.resources.moderations import Moderations + +from tests.integration_tests.model_runtime.__mock.openai_chat import MockChatClass +from tests.integration_tests.model_runtime.__mock.openai_completion import MockCompletionsClass +from 
tests.integration_tests.model_runtime.__mock.openai_embeddings import MockEmbeddingsClass +from tests.integration_tests.model_runtime.__mock.openai_moderation import MockModerationClass +from tests.integration_tests.model_runtime.__mock.openai_remote import MockModelClass +from tests.integration_tests.model_runtime.__mock.openai_speech2text import MockSpeech2TextClass + + +def mock_openai( + monkeypatch: MonkeyPatch, + methods: list[Literal["completion", "chat", "remote", "moderation", "speech2text", "text_embedding"]], +) -> Callable[[], None]: + """ + mock openai module + + :param monkeypatch: pytest monkeypatch fixture + :return: unpatch function + """ + + def unpatch() -> None: + monkeypatch.undo() + + if "completion" in methods: + monkeypatch.setattr(Completions, "create", MockCompletionsClass.completion_create) + + if "chat" in methods: + monkeypatch.setattr(ChatCompletions, "create", MockChatClass.chat_create) + + if "remote" in methods: + monkeypatch.setattr(Models, "list", MockModelClass.list) + + if "moderation" in methods: + monkeypatch.setattr(Moderations, "create", MockModerationClass.moderation_create) + + if "speech2text" in methods: + monkeypatch.setattr(Transcriptions, "create", MockSpeech2TextClass.speech2text_create) + + if "text_embedding" in methods: + monkeypatch.setattr(Embeddings, "create", MockEmbeddingsClass.create_embeddings) + + return unpatch + + +MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true" + + +@pytest.fixture +def setup_openai_mock(request, monkeypatch): + methods = request.param if hasattr(request, "param") else [] + if MOCK: + unpatch = mock_openai(monkeypatch, methods=methods) + + yield + + if MOCK: + unpatch() diff --git a/api/tests/integration_tests/model_runtime/__mock/openai_chat.py b/api/tests/integration_tests/model_runtime/__mock/openai_chat.py new file mode 100644 index 0000000000000000000000000000000000000000..1dc5df766748ec52f9477c86644706feff3667d7 --- /dev/null +++ 
b/api/tests/integration_tests/model_runtime/__mock/openai_chat.py @@ -0,0 +1,267 @@ +import re +from collections.abc import Generator +from json import dumps +from time import time + +# import monkeypatch +from typing import Any, Literal, Optional, Union + +from openai import AzureOpenAI, OpenAI +from openai._types import NOT_GIVEN, NotGiven +from openai.resources.chat.completions import Completions +from openai.types import Completion as CompletionMessage +from openai.types.chat import ( + ChatCompletionChunk, + ChatCompletionMessageParam, + ChatCompletionMessageToolCall, + ChatCompletionToolParam, + completion_create_params, +) +from openai.types.chat.chat_completion import ChatCompletion as _ChatCompletion +from openai.types.chat.chat_completion import Choice as _ChatCompletionChoice +from openai.types.chat.chat_completion_chunk import ( + Choice, + ChoiceDelta, + ChoiceDeltaFunctionCall, + ChoiceDeltaToolCall, + ChoiceDeltaToolCallFunction, +) +from openai.types.chat.chat_completion_message import ChatCompletionMessage, FunctionCall +from openai.types.chat.chat_completion_message_tool_call import Function +from openai.types.completion_usage import CompletionUsage + +from core.model_runtime.errors.invoke import InvokeAuthorizationError + + +class MockChatClass: + @staticmethod + def generate_function_call( + functions: list[completion_create_params.Function] | NotGiven = NOT_GIVEN, + ) -> Optional[FunctionCall]: + if not functions or len(functions) == 0: + return None + function: completion_create_params.Function = functions[0] + function_name = function["name"] + function_description = function["description"] + function_parameters = function["parameters"] + function_parameters_type = function_parameters["type"] + if function_parameters_type != "object": + return None + function_parameters_properties = function_parameters["properties"] + function_parameters_required = function_parameters["required"] + parameters = {} + for parameter_name, parameter in 
function_parameters_properties.items(): + if parameter_name not in function_parameters_required: + continue + parameter_type = parameter["type"] + if parameter_type == "string": + if "enum" in parameter: + if len(parameter["enum"]) == 0: + continue + parameters[parameter_name] = parameter["enum"][0] + else: + parameters[parameter_name] = "kawaii" + elif parameter_type == "integer": + parameters[parameter_name] = 114514 + elif parameter_type == "number": + parameters[parameter_name] = 1919810.0 + elif parameter_type == "boolean": + parameters[parameter_name] = True + + return FunctionCall(name=function_name, arguments=dumps(parameters)) + + @staticmethod + def generate_tool_calls(tools=NOT_GIVEN) -> Optional[list[ChatCompletionMessageToolCall]]: + list_tool_calls = [] + if not tools or len(tools) == 0: + return None + tool = tools[0] + + if "type" in tools and tools["type"] != "function": + return None + + function = tool["function"] + + function_call = MockChatClass.generate_function_call(functions=[function]) + if function_call is None: + return None + + list_tool_calls.append( + ChatCompletionMessageToolCall( + id="sakurajima-mai", + function=Function( + name=function_call.name, + arguments=function_call.arguments, + ), + type="function", + ) + ) + + return list_tool_calls + + @staticmethod + def mocked_openai_chat_create_sync( + model: str, + functions: list[completion_create_params.Function] | NotGiven = NOT_GIVEN, + tools: list[ChatCompletionToolParam] | NotGiven = NOT_GIVEN, + ) -> CompletionMessage: + tool_calls = [] + function_call = MockChatClass.generate_function_call(functions=functions) + if not function_call: + tool_calls = MockChatClass.generate_tool_calls(tools=tools) + + return _ChatCompletion( + id="cmpl-3QJQa5jXJ5Z5X", + choices=[ + _ChatCompletionChoice( + finish_reason="content_filter", + index=0, + message=ChatCompletionMessage( + content="elaina", role="assistant", function_call=function_call, tool_calls=tool_calls + ), + ) + ], + 
created=int(time()), + model=model, + object="chat.completion", + system_fingerprint="", + usage=CompletionUsage( + prompt_tokens=2, + completion_tokens=1, + total_tokens=3, + ), + ) + + @staticmethod + def mocked_openai_chat_create_stream( + model: str, + functions: list[completion_create_params.Function] | NotGiven = NOT_GIVEN, + tools: list[ChatCompletionToolParam] | NotGiven = NOT_GIVEN, + ) -> Generator[ChatCompletionChunk, None, None]: + tool_calls = [] + function_call = MockChatClass.generate_function_call(functions=functions) + if not function_call: + tool_calls = MockChatClass.generate_tool_calls(tools=tools) + + full_text = "Hello, world!\n\n```python\nprint('Hello, world!')\n```" + for i in range(0, len(full_text) + 1): + if i == len(full_text): + yield ChatCompletionChunk( + id="cmpl-3QJQa5jXJ5Z5X", + choices=[ + Choice( + delta=ChoiceDelta( + content="", + function_call=ChoiceDeltaFunctionCall( + name=function_call.name, + arguments=function_call.arguments, + ) + if function_call + else None, + role="assistant", + tool_calls=[ + ChoiceDeltaToolCall( + index=0, + id="misaka-mikoto", + function=ChoiceDeltaToolCallFunction( + name=tool_calls[0].function.name, + arguments=tool_calls[0].function.arguments, + ), + type="function", + ) + ] + if tool_calls and len(tool_calls) > 0 + else None, + ), + finish_reason="function_call", + index=0, + ) + ], + created=int(time()), + model=model, + object="chat.completion.chunk", + system_fingerprint="", + usage=CompletionUsage( + prompt_tokens=2, + completion_tokens=17, + total_tokens=19, + ), + ) + else: + yield ChatCompletionChunk( + id="cmpl-3QJQa5jXJ5Z5X", + choices=[ + Choice( + delta=ChoiceDelta( + content=full_text[i], + role="assistant", + ), + finish_reason="content_filter", + index=0, + ) + ], + created=int(time()), + model=model, + object="chat.completion.chunk", + system_fingerprint="", + ) + + def chat_create( + self: Completions, + *, + messages: list[ChatCompletionMessageParam], + model: Union[ + str, + 
Literal[ + "gpt-4-1106-preview", + "gpt-4-vision-preview", + "gpt-4", + "gpt-4-0314", + "gpt-4-0613", + "gpt-4-32k", + "gpt-4-32k-0314", + "gpt-4-32k-0613", + "gpt-3.5-turbo-1106", + "gpt-3.5-turbo", + "gpt-3.5-turbo-16k", + "gpt-3.5-turbo-0301", + "gpt-3.5-turbo-0613", + "gpt-3.5-turbo-16k-0613", + ], + ], + functions: list[completion_create_params.Function] | NotGiven = NOT_GIVEN, + response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN, + stream: Optional[Literal[False]] | NotGiven = NOT_GIVEN, + tools: list[ChatCompletionToolParam] | NotGiven = NOT_GIVEN, + **kwargs: Any, + ): + openai_models = [ + "gpt-4-1106-preview", + "gpt-4-vision-preview", + "gpt-4", + "gpt-4-0314", + "gpt-4-0613", + "gpt-4-32k", + "gpt-4-32k-0314", + "gpt-4-32k-0613", + "gpt-3.5-turbo-1106", + "gpt-3.5-turbo", + "gpt-3.5-turbo-16k", + "gpt-3.5-turbo-0301", + "gpt-3.5-turbo-0613", + "gpt-3.5-turbo-16k-0613", + ] + azure_openai_models = ["gpt35", "gpt-4v", "gpt-35-turbo"] + if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", str(self._client.base_url)): + raise InvokeAuthorizationError("Invalid base url") + if model in openai_models + azure_openai_models: + if not re.match(r"sk-[a-zA-Z0-9]{24,}$", self._client.api_key) and type(self._client) == OpenAI: + # sometime, provider use OpenAI compatible API will not have api key or have different api key format + # so we only check if model is in openai_models + raise InvokeAuthorizationError("Invalid api key") + if len(self._client.api_key) < 18 and type(self._client) == AzureOpenAI: + raise InvokeAuthorizationError("Invalid api key") + if stream: + return MockChatClass.mocked_openai_chat_create_stream(model=model, functions=functions, tools=tools) + + return MockChatClass.mocked_openai_chat_create_sync(model=model, functions=functions, tools=tools) diff --git a/api/tests/integration_tests/model_runtime/__mock/openai_completion.py b/api/tests/integration_tests/model_runtime/__mock/openai_completion.py new file mode 
100644 index 0000000000000000000000000000000000000000..14223668e036d922db896f1c3de06c9ec7df8190 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/__mock/openai_completion.py @@ -0,0 +1,130 @@ +import re +from collections.abc import Generator +from time import time + +# import monkeypatch +from typing import Any, Literal, Optional, Union + +from openai import AzureOpenAI, BadRequestError, OpenAI +from openai._types import NOT_GIVEN, NotGiven +from openai.resources.completions import Completions +from openai.types import Completion as CompletionMessage +from openai.types.completion import CompletionChoice +from openai.types.completion_usage import CompletionUsage + +from core.model_runtime.errors.invoke import InvokeAuthorizationError + + +class MockCompletionsClass: + @staticmethod + def mocked_openai_completion_create_sync(model: str) -> CompletionMessage: + return CompletionMessage( + id="cmpl-3QJQa5jXJ5Z5X", + object="text_completion", + created=int(time()), + model=model, + system_fingerprint="", + choices=[ + CompletionChoice( + text="mock", + index=0, + logprobs=None, + finish_reason="stop", + ) + ], + usage=CompletionUsage( + prompt_tokens=2, + completion_tokens=1, + total_tokens=3, + ), + ) + + @staticmethod + def mocked_openai_completion_create_stream(model: str) -> Generator[CompletionMessage, None, None]: + full_text = "Hello, world!\n\n```python\nprint('Hello, world!')\n```" + for i in range(0, len(full_text) + 1): + if i == len(full_text): + yield CompletionMessage( + id="cmpl-3QJQa5jXJ5Z5X", + object="text_completion", + created=int(time()), + model=model, + system_fingerprint="", + choices=[ + CompletionChoice( + text="", + index=0, + logprobs=None, + finish_reason="stop", + ) + ], + usage=CompletionUsage( + prompt_tokens=2, + completion_tokens=17, + total_tokens=19, + ), + ) + else: + yield CompletionMessage( + id="cmpl-3QJQa5jXJ5Z5X", + object="text_completion", + created=int(time()), + model=model, + system_fingerprint="", + choices=[ + 
CompletionChoice(text=full_text[i], index=0, logprobs=None, finish_reason="content_filter") + ], + ) + + def completion_create( + self: Completions, + *, + model: Union[ + str, + Literal[ + "babbage-002", + "davinci-002", + "gpt-3.5-turbo-instruct", + "text-davinci-003", + "text-davinci-002", + "text-davinci-001", + "code-davinci-002", + "text-curie-001", + "text-babbage-001", + "text-ada-001", + ], + ], + prompt: Union[str, list[str], list[int], list[list[int]], None], + stream: Optional[Literal[False]] | NotGiven = NOT_GIVEN, + **kwargs: Any, + ): + openai_models = [ + "babbage-002", + "davinci-002", + "gpt-3.5-turbo-instruct", + "text-davinci-003", + "text-davinci-002", + "text-davinci-001", + "code-davinci-002", + "text-curie-001", + "text-babbage-001", + "text-ada-001", + ] + azure_openai_models = ["gpt-35-turbo-instruct"] + + if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", str(self._client.base_url)): + raise InvokeAuthorizationError("Invalid base url") + if model in openai_models + azure_openai_models: + if not re.match(r"sk-[a-zA-Z0-9]{24,}$", self._client.api_key) and type(self._client) == OpenAI: + # sometime, provider use OpenAI compatible API will not have api key or have different api key format + # so we only check if model is in openai_models + raise InvokeAuthorizationError("Invalid api key") + if len(self._client.api_key) < 18 and type(self._client) == AzureOpenAI: + raise InvokeAuthorizationError("Invalid api key") + + if not prompt: + raise BadRequestError("Invalid prompt") + if stream: + return MockCompletionsClass.mocked_openai_completion_create_stream(model=model) + + return MockCompletionsClass.mocked_openai_completion_create_sync(model=model) diff --git a/api/tests/integration_tests/model_runtime/__mock/openai_embeddings.py b/api/tests/integration_tests/model_runtime/__mock/openai_embeddings.py new file mode 100644 index 0000000000000000000000000000000000000000..3cc1fa9ff1db4906b0f3b5e7eb49e3c287e2be01 --- /dev/null +++ 
b/api/tests/integration_tests/model_runtime/__mock/openai_embeddings.py @@ -0,0 +1,57 @@ +import re +from typing import Any, Literal, Union + +from openai._types import NOT_GIVEN, NotGiven +from openai.resources.embeddings import Embeddings +from openai.types.create_embedding_response import CreateEmbeddingResponse, Usage +from openai.types.embedding import Embedding + +from core.model_runtime.errors.invoke import InvokeAuthorizationError + + +class MockEmbeddingsClass: + def create_embeddings( + self: Embeddings, + *, + input: Union[str, list[str], list[int], list[list[int]]], + model: Union[str, Literal["text-embedding-ada-002"]], + encoding_format: Literal["float", "base64"] | NotGiven = NOT_GIVEN, + **kwargs: Any, + ) -> CreateEmbeddingResponse: + if isinstance(input, str): + input = [input] + + if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", str(self._client.base_url)): + raise InvokeAuthorizationError("Invalid base url") + + if len(self._client.api_key) < 18: + raise InvokeAuthorizationError("Invalid API key") + + if encoding_format == "float": + return CreateEmbeddingResponse( + data=[ + Embedding(embedding=[0.23333 for _ in range(233)], index=i, object="embedding") + for i in range(len(input)) + ], + model=model, + object="list", + # marked: usage of embeddings should equal the number of testcase + usage=Usage(prompt_tokens=2, total_tokens=2), + ) + + embeddings = 
"VEfNvMLUnrwFleO8hcj9vEE/yrzyjOA84E1MvNfoCrxjrI+8sZUKvNgrBT17uY07gJ/IvNvhHLrUemc8KXXGumalIT3YKwU7ZsnbPMhATrwTt6u8JEwRPNMmCjxGREW7TRKvu6/MG7zAyDU8wXLkuuMDZDsXsL28zHzaOw0IArzOiMO8LtASvPKM4Dul5l+80V0bPGVDZ7wYNrI89ucsvJZdYztzRm+8P8ysOyGbc7zrdgK9sdiEPKQ8sbulKdq7KIgdvKIMDj25dNc8k0AXPBn/oLzrdgK8IXe5uz0Dvrt50V68tTjLO4ZOcjoG9x29oGfZufiwmzwMDXy8EL6ZPHvdx7nKjzE8+LCbPG22hTs3EZq7TM+0POrRzTxVZo084wPkO8Nak7z8cpw8pDwxvA2T8LvBC7C72fltvC8Atjp3fYE8JHDLvEYgC7xAdls8YiabPPkEeTzPUbK8gOLCPEBSIbyt5Oy8CpreusNakzywUhA824vLPHRlr7zAhTs7IZtzvHd9AT2xY/O6ok8IvOihqrql5l88K4EvuknWorvYKwW9iXkbvGMTRLw5qPG7onPCPLgNIzwAbK67ftbZPMxYILvAyDW9TLB0vIid1buzCKi7u+d0u8iDSLxNVam8PZyJPNxnETvVANw8Oi5mu9nVszzl65I7DIKNvLGVirxsMJE7tPXQu2PvCT1zRm87p1l9uyRMkbsdfqe8U52ePHRlr7wt9Mw8/C8ivTu02rwJFGq8tpoFPWnC7blWumq7sfy+vG1zCzy9Nlg8iv+PuvxT3DuLU228kVhoOkmTqDrv1kg8ocmTu1WpBzsKml48DzglvI8ECzxwTd27I+pWvIWkQ7xUR007GqlPPBFEDrzGECu865q8PI7BkDwNxYc8tgG6ullMSLsIajs84lk1PNLjD70mv648ZmInO2tnIjzvb5Q8o5KCPLo9xrwKMyq9QqGEvI8ECzxO2508ATUdPRAlTry5kxc8KVGMPJyBHjxIUC476KGqvIU9DzwX87c88PUIParrWrzdlzS/G3K+uzEw2TxB2BU86AhfPAMiRj2dK808a85WPPCft7xU4Bg95Q9NPDxZjzwrpek7yNkZvHa0EjyQ0nM6Nq9fuyjvUbsRq8I7CAMHO3VSWLyuauE7U1qkvPkEeTxs7ZY7B6FMO48Eizy75/S7ieBPvB07rTxmyVu8onPCO5rc6Tu7XIa7oEMfPYngT7u24vk7/+W5PE8eGDxJ1iI9t4cuvBGHiLyH1GY7jfghu+oUSDwa7Mk7iXmbuut2grrq8I2563v8uyofdTxRTrs44lm1vMeWnzukf6s7r4khvEKhhDyhyZO8G5Z4Oy56wTz4sBs81Zknuz3fg7wnJuO74n1vvASEADu98128gUl3vBtyvrtZCU47yep8u5FYaDx2G0e8a85WO5cmUjz3kds8qgqbPCUaerx50d67WKIZPI7BkDua3Om74vKAvL3zXbzXpRA9CI51vLo9xryKzXg7tXtFO9RWLTwnJuM854LqPEIs8zuO5cq8d8V1u9P0cjrQ++C8cGwdPDdUlLoOGeW8auEtu8Z337nlzFK8aRg/vFCkDD0nRSM879bIvKUFID1iStU8EL6ZvLufgLtKgNE7KVEMvJOnSzwahRU895HbvJiIjLvc8n88bmC0PPLP2rywM9C7jTscOoS3mjy/Znu7dhvHuu5Q1Dyq61o6CI71u09hkry0jhw8gb6IPI8EC7uoVAM8gs9rvGM3fjx2G8e81FYtu/ojubyYRRK72Riuu83elDtNNmk70/TyuzUFsbvgKZI7onNCvAehzLumr8679R6+urr6SztX2So8Bl5SOwSEgLv5NpA8LwC2PGPvibzJ6vw7H2tQvOtXwrzXpRC8j0z/uxwcbTy2vr+8VWYNu+t2ArwKmt68NKN2O3XrIzw9A747UU47vaavzjwU+qW8YBqyvE02aTyEt5o8cCmjOxtyPrxs7ZY775NOu+SJWLxMJQY
8/bWWu6IMDrzSSsQ7GSPbPLlQnbpVzcE7Pka4PJ96sLycxJg8v/9GPO2HZTyeW3C8Vpawtx2iYTwWBg87/qI/OviwGzxyWcY7M9WNPIA4FD32C2e8tNGWPJ43trxCoYS8FGHavItTbbu7n4C80NemPLm30Ty1OMu7vG1pvG3aPztBP0o75Q/NPJhFEj2V9i683PL/O97+aLz6iu27cdPRum/mKLwvVgc89fqDu3LA+jvm2Ls8mVZ1PIuFBD3ZGK47Cpreut7+aLziWTU8XSEgPMvSKzzO73e5040+vBlmVTxS1K+8mQ4BPZZ8o7w8FpW6OR0DPSSPCz21Vwu99fqDOjMYiDy7XAY8oYaZO+aVwTyX49c84OaXOqdZfTunEQk7B8AMvMDs7zo/D6e8OP5CvN9gIzwNCII8FefOPE026TpzIjU8XsvOO+J9b7rkIiQ8is34O+e0AbxBpv67hcj9uiPq1jtCoQQ8JfY/u86nAz0Wkf28LnrBPJlW9Tt8P4K7BbSjO9grhbyAOJS8G3K+vJLe3LzXpZA7NQUxPJs+JDz6vAS8QHZbvYNVYDrj3yk88PWIPOJ97zuSIVc8ZUPnPMqPsbx2cZi7QfzPOxYGDz2hqtO6H2tQO543NjyFPY+7JRUAOt0wgDyJeZu8MpKTu6AApTtg1ze82JI5vKllZjvrV0I7HX6nu7vndDxg1ze8jwQLu1ZTNjuJvBU7BXGpvAP+C7xJk6g8j2u/vBABlLzlqBi8M9WNutRWLTx0zGM9sHbKPLoZDDtmyVu8tpqFOvPumjyuRqe87lBUvFU0drxs7Za8ejMZOzJPGbyC7qu863v8PDPVjTxJ1iI7Ca01PLuAQLuNHFy7At9LOwP+i7tYxlO80NemO9elkDx45LU8h9TmuzxZjzz/5bk8p84OurvndLwAkGi7XL9luCSzRTwMgg08vrxMPKIwyDwdomG8K6VpPGPvCTxkmTi7M/lHPGxUSzxwKSM8wQuwvOqtkzrLFSa8SbdivAMixjw2r9+7xWt2vAyCDT1NEi87B8CMvG1zi7xpwm27MrbNO9R6Z7xJt+K7jNnhu9ZiFrve/ug55CKkvCwHJLqsOr47+ortvPwvIr2v8NW8YmmVOE+FTLywUhA8MTBZvMiDyLtx8hG8OEE9vMDsbzroCF88DelBOobnPbx+b6U8sbnEOywr3ro93wO9dMzjup2xwbwnRaO7cRZMu8Z337vS44+7VpYwvFWphzxKgNE8L1aHPLPFLbunzo66zFggPN+jHbs7tFo8nW7HO9JKRLyoeD28Fm1DPGZip7u5dNe7KMsXvFnlkzxQpAw7MrZNPHpX0zwSyoK7ayQovPR0Dz3gClK8/juLPDjaCLvqrZO7a4vcO9HEzzvife88KKzXvDmocbwpMkw7t2huvaIMjjznguo7Gy/EOzxZjzoLuZ48qi5VvCjLFzuDmNo654LquyrXgDy7XAa8e7mNvJ7QAb0Rq8K7ojBIvBN0MTuOfha8GoUVveb89bxMsHS8jV9WPPKM4LyAOJS8me9AvZv7qbsbcr47tuL5uaXmXzweKNa7rkYnPINV4Lxcv+W8tVcLvI8oxbzvbxS7oYaZu9+jHT0cHO08c7uAPCSzRTywUhA85xu2u+wBcTuJvJU8PBYVusTghzsnAim8acJtPFQE0zzFIwI9C7meO1DIRry7XAY8MKpkPJZd47suN0e5JTm6u6BDn7zfx1e8AJDoOr9CQbwaQps7x/1TPLTRFryqLtU8JybjPIXI/Tz6I7k6mVb1PMWKNryd1fs8Ok0mPHt2kzy9Ep48TTZpvPS3ibwGOpi8Ns4fPBqFlbr3Kqc8+QR5vHLA+rt7uY289YXyPI6iULxL4gu8Tv/XuycCKbwCnFG8C7kevVG1b7zIXw68GoWVO4rNeDnrM4i8MxgIPUNLs7zSoJW86ScfO+rRzbs6Cqw8NxGautP0cjw0wjY8CGq7vAkU6rxKgNG5+uA+vJXXbrwKM6o
86vCNOu+yjjoQAZS8xATCOQVxKbynzo68wxcZvMhATjzS4488ArsRvNEaobwRh4i7t4euvAvd2DwnAik8UtQvvBFEDrz4sJs79gtnvOknnzy+vEy8D3sfPLH8vjzmLo28KVGMvOtXwjvpapm8HBxtPH3K8Lu753Q8/l9FvLvn9DomoG48fET8u9zy/7wMpke8zmQJu3oU2TzlD828KteAPAwNfLu+mBI5ldduPNZDVjq+vEy8eEvqvDHJpLwUPaC6qi7VPABsLjwFcSm72sJcu+bYO7v41NW8RiALvYB7DjzL0is7qLs3us1FSbzaf2K8MnNTuxABFDzF8Wo838fXvOBNzDzre3w8afQEvQE1nbulBaC78zEVvG5B9LzH/VM82Riuuwu5nrwsByQ8Y6yPvHXro7yQ0nM8nStNPJkyOzwnJmM80m7+O1VmjTzqrZM8dhvHOyAQBbz3baG8KTJMPOlqmbxsVEs8Pq3suy56QbzUVq08X3CDvAE1nTwUHuA7hue9vF8tCbvwOAO6F7A9ugd9kryqLtW7auEtu9ONPryPa7+8o9r2O570OzyFpEO8ntCBPOqtk7sykhO7lC1AOw2TcLswhiq6vx4HvP5fRbwuesG7Mk8ZvA4Z5TlfcAM9DrIwPL//xrzMm5q8JEwRPHBsnbxL4gu8jyjFu99gozrkZZ483GeRPLuAwDuYiIw8iv8PvK5Gpzx+b6W87Yflu3NGbzyE+hQ8a4tcPItT7bsoy5e8L1YHvWQyBDwrga86kPEzvBQ9oDxtl0W8lwKYvGpIYrxQ5wY8AJDovOLyALyw3f489JjJvMdTpTkKMyo8V9mqvH3K8LpyNYy8JHDLOixu2LpQ54Y8Q0uzu8LUnrs0wrY84vIAveihqjwfihA8DIKNvLDd/jywM1C7FB7gOxsLirxAUqE7sulnvH3K8DkAkGg8jsGQvO+TzrynWf287CCxvK4Drbwg8UQ8JRr6vFEqAbskjwu76q2TPNP0cjopDhK8dVJYvFIXKrxLn5G8AK8oPAb3HbxbOXE8Bvedun5Q5ThHyjk8QdiVvBXDlLw0o/Y7aLGKupkOgTxKPdc81kNWPtUAXLxUR827X1FDPf47izxsEVE8akhiPIhaWzxYX5+7hT0PPSrXgLxQC0E8i4WEvKUp2jtCLHM8DcWHO768zLxnK5a89R6+vH9czrorpem73h0pvAnwr7yKzXi8gDgUPf47Czq9zyO8728UOf34EDy6PUY76OSkvKZIGr2ZDgE8gzEmPG3av7v77Ce7/oP/O3MiNTtas/w8x1OlO/D1CDvDfs27ll1jO2Ufrbv1hXK8WINZuxN0sbuxlYq8OYS3uia/rjyiTwi9O7TaO+/WyDyiDA49E7erO3fF9bj6I7k7qHi9O3SoKbyBSfc7drSSvGPvCT2pQay7t2huPGnC7byUCQY8CEaBu6rHoDhx8hE8/fgQvCjLl7zdeHS8x/3TO0Isc7tas3y8jwQLvUKhhDz+foU8fCDCPC+ZgTywD5Y7ZR8tOla66rtCCLm8gWg3vDoKrLxbWDE76SefPBkj2zrlqJi7pebfuv6Df7zWQ9a7lHA6PGDXtzzMv1Q8mtxpOwJ4lzxKGZ28mGnMPDw6z7yxY/O7m2Leu7juYjwvVge8zFigPGpIYjtWumo5xs2wOgyCjbxrZ6K8bbaFvKzTCbsks8W7C7mePIU9DzxQyEY8posUvAW0ozrHlh88CyBTPJRwursxySQ757SBuqcRCbwNCIK8EL6ZvIG+iLsIRgE8rF74vOJZtbuUcDq8r/DVPMpMt7sL3Vi8eWqquww/kzqj2vY5auGtu85kiTwMPxM66KGqvBIxNzuwUpA8v2b7u09C0rx7ms08NUirvFYQPLxKPdc68mimvP5fRTtoPPm7XuqOOgOJ+jxfLYm7u58AvXz8B72PR4W6ldfuuys+tbvYKwW7pkiaPLB2SjvKj7G875POvA6yML7qFEg9Eu68O6Up2rz77Kc
84CmSPP6ivzz4sJu6/C+iOaUpWjwq14A84E3MOYB7Dr2d1Xu775NOvC6e+7spUYw8PzPhO5TGizt29ww9yNkZPY7lyrz020M7QRsQu3z8BzwkCZe79YXyO8jZmTzvGUM8HgQcO9kYrrzxBmy8hLeaPLYBOjz+oj88flBlO6GqUzuiMMi8fxlUvCr7ujz41NU8DA38PBeMAzx7uY28TTZpvFG1bzxtc4s89ucsPEereTwfipC82p4iPKtNFbzo5KQ7pcKlOW5gtDzO73c7B6FMOzRbgjxCXoo8v0JBOSl1RrwxDJ+7XWSaPD3Aw7sOsjA8tuJ5vKw6Pry5k5c8ZUNnvG/H6DyVTAA8Shkdvd7+aDvtpiW9qUGsPFTgmDwbcr68TTbpO1DnhryNX9a7mrivvIqpPjxsqhy81HrnOzv31Dvth+U6UtQvPBz4MrvtpqW84OYXvRz4sjxwkFe8zSGPuycCqbyFPY8818nKOw84JTy8bWk8USqBvBGHiLtosQo8BOs0u9skl7xQ54Y8uvrLPOknn7w705o8Jny0PAd9EjxhoKa8Iv2tu2M3/jtsVEs8DcUHPQSEADs3eE48GkKbupRR+rvdeHQ7Xy2JvO1jKz0xMFm8sWPzux07LbyrTZW7bdq/O6Pa9r0ahRW9CyDTOjSjdjyQ8bO8yaIIPfupLTz/CfQ7xndfvJs+JD0zPEK8KO/RvMpw8bwObzY7fm+lPJtiXrz5BHm8WmsIvKlBrLuDdKA7hWHJOgd9Ers0o/Y7nlvwu5NAl7u8BrW6utYRO2SZuDxyNYw8CppevAY6GDxVqQe9oGdZPFa6ary3RLS70NcmO2PQSb36ZrM86q2TPML42LwewaE8k2RRPDmocTsi/S29o/k2PHRlr7zjnC+8gHsOPUpcFzxtl8W6tuL5vHw/gry/2wy9yaIIvINV4Dx3fQG7ISFoPO7pnzwGXlK8HPiyPGAaMjzBC7A7MQyfu+eC6jyV1+67pDyxvBWkVLxrJKg754LqOScCKbwpUQy8KIgdOJDSc7zDfk08tLLWvNZDVjyh7c28ShmdvMnlgjs2NdS8ISHovP5+hbxGIIs8ayQouyKnXDzBcmS6zw44u86IQ7yl5l+7cngGvWvOVrsEhIC7yNkZPJODkbuAn0g8XN6lPOaVwbuTgxG8OR2DPAb3HTzlqJi8nUoNvCAVf73Mmxo9afSEu4FotzveHSk8c0ZvOMFOqjwP9Sq87iwavIEBg7xIUK68IbozuozZ4btg17c7vx4Hvarr2rtp9IQ8Rt0QO+1jqzyeNzY8kNLzO8sVpry98108OCL9uyisV7vhr4Y8FgaPvLFjczw42og8gWg3vPX6gzsNk/C83GeRPCUVgDy0jpw7yNkZu2VD5zvh93o81h+cuw3Fhzyl5t+86Y7TvHa0EjyzCCi7WmsIPIy1Jzy00Ra6NUiru50rTTx50d47/HKcO2wwETw0f7y8sFIQvNxnkbzS4w855pVBu9FdGzx9yvC6TM80vFQjkzy/Zvs7BhtYPLjKKLqPa787A/6LOyiInbzooSq8728UPIFJ97wq+7q8R6v5u1tYMbwdomG6iSPKPAb3HTx3oTu7fGO8POqtk7ze/ug84wNkPMnq/DsB8iK9ogwOu6lBrDznguo8NQUxvHKcwDo28tm7yNmZPN1UurxCoYS80m7+Oy+9OzzGzTC836MdvCDNCrtaawi7dVLYPEfKuTxzRm88cCmjOyXSBbwGOpi879ZIO8dTJbtqnrO8NMI2vR1+J7xwTV087umfPFG17zsC30s8oYaZPKllZrzZGK47zss9vP21FryZywa9bbYFPVNapDt2G0e7E3SxPMUjgry5dNc895Hbu0H8z7ueN7a7OccxPFhfH7vC1B48n3owvEhQLrzu6Z+8HTutvEBSITw6Taa5g1XgPCzEqbxfLYk9OYQ3vBlm1bvPUTI8wIU7PIy1pzyFyP07gzGmO3NGb7yS3ty7O5CguyEhaLyWoF2
8pmxUOaZImrz+g/87mnU1vFbsgTxvo668PFmPO2KNTzy09VC8LG5YPHhL6rsvJPC7kTQuvEGCxDlhB9s6u58AvfCAd7z0t4k7kVjoOCkOkrxMjDq8iPOmPL0SnrxsMJG7OEG9vCUa+rvx4rE7cpxAPDCGqjukf6u8TEnAvNn57TweBBw7JdKFvIy1p7vIg8i7" # noqa: E501 + + data = [] + for i, text in enumerate(input): + obj = Embedding(embedding=[], index=i, object="embedding") + obj.embedding = embeddings + + data.append(obj) + + return CreateEmbeddingResponse( + data=data, + model=model, + object="list", + # marked: usage of embeddings should equal the number of testcase + usage=Usage(prompt_tokens=2, total_tokens=2), + ) diff --git a/api/tests/integration_tests/model_runtime/__mock/openai_moderation.py b/api/tests/integration_tests/model_runtime/__mock/openai_moderation.py new file mode 100644 index 0000000000000000000000000000000000000000..e26855344e63bfa39482054f75a22602b5eda293 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/__mock/openai_moderation.py @@ -0,0 +1,140 @@ +import re +from typing import Any, Literal, Union + +from openai._types import NOT_GIVEN, NotGiven +from openai.resources.moderations import Moderations +from openai.types import ModerationCreateResponse +from openai.types.moderation import Categories, CategoryScores, Moderation + +from core.model_runtime.errors.invoke import InvokeAuthorizationError + + +class MockModerationClass: + def moderation_create( + self: Moderations, + *, + input: Union[str, list[str]], + model: Union[str, Literal["text-moderation-latest", "text-moderation-stable"]] | NotGiven = NOT_GIVEN, + **kwargs: Any, + ) -> ModerationCreateResponse: + if isinstance(input, str): + input = [input] + + if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", str(self._client.base_url)): + raise InvokeAuthorizationError("Invalid base url") + + if len(self._client.api_key) < 18: + raise InvokeAuthorizationError("Invalid API key") + + for text in input: + result = [] + if "kill" in text: + moderation_categories = { + "harassment": False, + "harassment/threatening": False, + "hate": 
False, + "hate/threatening": False, + "self-harm": False, + "self-harm/instructions": False, + "self-harm/intent": False, + "sexual": False, + "sexual/minors": False, + "violence": False, + "violence/graphic": False, + "illicit": False, + "illicit/violent": False, + } + moderation_categories_scores = { + "harassment": 1.0, + "harassment/threatening": 1.0, + "hate": 1.0, + "hate/threatening": 1.0, + "self-harm": 1.0, + "self-harm/instructions": 1.0, + "self-harm/intent": 1.0, + "sexual": 1.0, + "sexual/minors": 1.0, + "violence": 1.0, + "violence/graphic": 1.0, + "illicit": 1.0, + "illicit/violent": 1.0, + } + category_applied_input_types = { + "sexual": ["text", "image"], + "hate": ["text"], + "harassment": ["text"], + "self-harm": ["text", "image"], + "sexual/minors": ["text"], + "hate/threatening": ["text"], + "violence/graphic": ["text", "image"], + "self-harm/intent": ["text", "image"], + "self-harm/instructions": ["text", "image"], + "harassment/threatening": ["text"], + "violence": ["text", "image"], + "illicit": ["text"], + "illicit/violent": ["text"], + } + result.append( + Moderation( + flagged=True, + categories=Categories(**moderation_categories), + category_scores=CategoryScores(**moderation_categories_scores), + category_applied_input_types=category_applied_input_types, + ) + ) + else: + moderation_categories = { + "harassment": False, + "harassment/threatening": False, + "hate": False, + "hate/threatening": False, + "self-harm": False, + "self-harm/instructions": False, + "self-harm/intent": False, + "sexual": False, + "sexual/minors": False, + "violence": False, + "violence/graphic": False, + "illicit": False, + "illicit/violent": False, + } + moderation_categories_scores = { + "harassment": 0.0, + "harassment/threatening": 0.0, + "hate": 0.0, + "hate/threatening": 0.0, + "self-harm": 0.0, + "self-harm/instructions": 0.0, + "self-harm/intent": 0.0, + "sexual": 0.0, + "sexual/minors": 0.0, + "violence": 0.0, + "violence/graphic": 0.0, + "illicit": 
0.0, + "illicit/violent": 0.0, + } + category_applied_input_types = { + "sexual": ["text", "image"], + "hate": ["text"], + "harassment": ["text"], + "self-harm": ["text", "image"], + "sexual/minors": ["text"], + "hate/threatening": ["text"], + "violence/graphic": ["text", "image"], + "self-harm/intent": ["text", "image"], + "self-harm/instructions": ["text", "image"], + "harassment/threatening": ["text"], + "violence": ["text", "image"], + "illicit": ["text"], + "illicit/violent": ["text"], + } + result.append( + Moderation( + flagged=False, + categories=Categories(**moderation_categories), + category_scores=CategoryScores(**moderation_categories_scores), + category_applied_input_types=category_applied_input_types, + ) + ) + + return ModerationCreateResponse(id="shiroii kuloko", model=model, results=result) diff --git a/api/tests/integration_tests/model_runtime/__mock/openai_remote.py b/api/tests/integration_tests/model_runtime/__mock/openai_remote.py new file mode 100644 index 0000000000000000000000000000000000000000..704dbad5d288cae5fbebe4dbb417d45ced626213 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/__mock/openai_remote.py @@ -0,0 +1,22 @@ +from time import time + +from openai.types.model import Model + + +class MockModelClass: + """ + mock class for openai.models.Models + """ + + def list( + self, + **kwargs, + ) -> list[Model]: + return [ + Model( + id="ft:gpt-3.5-turbo-0613:personal::8GYJLPDQ", + created=int(time()), + object="model", + owned_by="organization:org-123", + ) + ] diff --git a/api/tests/integration_tests/model_runtime/__mock/openai_speech2text.py b/api/tests/integration_tests/model_runtime/__mock/openai_speech2text.py new file mode 100644 index 0000000000000000000000000000000000000000..a51dcab4be7467529dc588fde9daca616f3216e6 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/__mock/openai_speech2text.py @@ -0,0 +1,29 @@ +import re +from typing import Any, Literal, Union + +from openai._types import NOT_GIVEN, 
FileTypes, NotGiven +from openai.resources.audio.transcriptions import Transcriptions +from openai.types.audio.transcription import Transcription + +from core.model_runtime.errors.invoke import InvokeAuthorizationError + + +class MockSpeech2TextClass: + def speech2text_create( + self: Transcriptions, + *, + file: FileTypes, + model: Union[str, Literal["whisper-1"]], + language: str | NotGiven = NOT_GIVEN, + prompt: str | NotGiven = NOT_GIVEN, + response_format: Literal["json", "text", "srt", "verbose_json", "vtt"] | NotGiven = NOT_GIVEN, + temperature: float | NotGiven = NOT_GIVEN, + **kwargs: Any, + ) -> Transcription: + if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", str(self._client.base_url)): + raise InvokeAuthorizationError("Invalid base url") + + if len(self._client.api_key) < 18: + raise InvokeAuthorizationError("Invalid API key") + + return Transcription(text="1, 2, 3, 4, 5, 6, 7, 8, 9, 10") diff --git a/api/tests/integration_tests/model_runtime/__mock/xinference.py b/api/tests/integration_tests/model_runtime/__mock/xinference.py new file mode 100644 index 0000000000000000000000000000000000000000..e2abaa52b939a6cadc38f45e0b19fb237a4b2199 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/__mock/xinference.py @@ -0,0 +1,169 @@ +import os +import re +from typing import Union + +import pytest +from _pytest.monkeypatch import MonkeyPatch +from requests import Response +from requests.sessions import Session +from xinference_client.client.restful.restful_client import ( # type: ignore + Client, + RESTfulChatModelHandle, + RESTfulEmbeddingModelHandle, + RESTfulGenerateModelHandle, + RESTfulRerankModelHandle, +) +from xinference_client.types import Embedding, EmbeddingData, EmbeddingUsage # type: ignore + + +class MockXinferenceClass: + def get_chat_model(self: Client, model_uid: str) -> Union[RESTfulGenerateModelHandle, RESTfulChatModelHandle]: + if not re.match(r"https?:\/\/[^\s\/$.?#].[^\s]*$", self.base_url): + raise RuntimeError("404 Not 
Found") + + if model_uid == "generate": + return RESTfulGenerateModelHandle(model_uid, base_url=self.base_url, auth_headers={}) + if model_uid == "chat": + return RESTfulChatModelHandle(model_uid, base_url=self.base_url, auth_headers={}) + if model_uid == "embedding": + return RESTfulEmbeddingModelHandle(model_uid, base_url=self.base_url, auth_headers={}) + if model_uid == "rerank": + return RESTfulRerankModelHandle(model_uid, base_url=self.base_url, auth_headers={}) + raise RuntimeError("404 Not Found") + + def get(self: Session, url: str, **kwargs): + response = Response() + if "v1/models/" in url: + # get model uid + model_uid = url.split("/")[-1] or "" + if not re.match( + r"[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}", model_uid + ) and model_uid not in {"generate", "chat", "embedding", "rerank"}: + response.status_code = 404 + response._content = b"{}" + return response + + # check if url is valid + if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", url): + response.status_code = 404 + response._content = b"{}" + return response + + if model_uid in {"generate", "chat"}: + response.status_code = 200 + response._content = b"""{ + "model_type": "LLM", + "address": "127.0.0.1:43877", + "accelerators": [ + "0", + "1" + ], + "model_name": "chatglm3-6b", + "model_lang": [ + "en" + ], + "model_ability": [ + "generate", + "chat" + ], + "model_description": "latest chatglm3", + "model_format": "pytorch", + "model_size_in_billions": 7, + "quantization": "none", + "model_hub": "huggingface", + "revision": null, + "context_length": 2048, + "replica": 1 + }""" + return response + + elif model_uid == "embedding": + response.status_code = 200 + response._content = b"""{ + "model_type": "embedding", + "address": "127.0.0.1:43877", + "accelerators": [ + "0", + "1" + ], + "model_name": "bge", + "model_lang": [ + "en" + ], + "revision": null, + "max_tokens": 512 + }""" + return response + + elif "v1/cluster/auth" in url: + response.status_code = 200 + 
response._content = b"""{ + "auth": true + }""" + return response + + def _check_cluster_authenticated(self): + self._cluster_authed = True + + def rerank( + self: RESTfulRerankModelHandle, documents: list[str], query: str, top_n: int, return_documents: bool + ) -> dict: + # check if self._model_uid is a valid uuid + if ( + not re.match(r"[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}", self._model_uid) + and self._model_uid != "rerank" + ): + raise RuntimeError("404 Not Found") + + if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", self._base_url): + raise RuntimeError("404 Not Found") + + if top_n is None: + top_n = 1 + + return { + "results": [ + {"index": i, "document": doc, "relevance_score": 0.9} for i, doc in enumerate(documents[:top_n]) + ] + } + + def create_embedding(self: RESTfulGenerateModelHandle, input: Union[str, list[str]], **kwargs) -> dict: + # check if self._model_uid is a valid uuid + if ( + not re.match(r"[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}", self._model_uid) + and self._model_uid != "embedding" + ): + raise RuntimeError("404 Not Found") + + if isinstance(input, str): + input = [input] + ipt_len = len(input) + + embedding = Embedding( + object="list", + model=self._model_uid, + data=[ + EmbeddingData(index=i, object="embedding", embedding=[1919.810 for _ in range(768)]) + for i in range(ipt_len) + ], + usage=EmbeddingUsage(prompt_tokens=ipt_len, total_tokens=ipt_len), + ) + + return embedding + + +MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true" + + +@pytest.fixture +def setup_xinference_mock(request, monkeypatch: MonkeyPatch): + if MOCK: + monkeypatch.setattr(Client, "get_model", MockXinferenceClass.get_chat_model) + monkeypatch.setattr(Client, "_check_cluster_authenticated", MockXinferenceClass._check_cluster_authenticated) + monkeypatch.setattr(Session, "get", MockXinferenceClass.get) + monkeypatch.setattr(RESTfulEmbeddingModelHandle, "create_embedding", 
MockXinferenceClass.create_embedding) + monkeypatch.setattr(RESTfulRerankModelHandle, "rerank", MockXinferenceClass.rerank) + yield + + if MOCK: + monkeypatch.undo() diff --git a/api/tests/integration_tests/model_runtime/anthropic/__init__.py b/api/tests/integration_tests/model_runtime/anthropic/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/tests/integration_tests/model_runtime/anthropic/test_llm.py b/api/tests/integration_tests/model_runtime/anthropic/test_llm.py new file mode 100644 index 0000000000000000000000000000000000000000..8f7e9ec48743bf9967f74df2886c430cd760e883 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/anthropic/test_llm.py @@ -0,0 +1,92 @@ +import os +from collections.abc import Generator + +import pytest + +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import AssistantPromptMessage, SystemPromptMessage, UserPromptMessage +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.anthropic.llm.llm import AnthropicLargeLanguageModel +from tests.integration_tests.model_runtime.__mock.anthropic import setup_anthropic_mock + + +@pytest.mark.parametrize("setup_anthropic_mock", [["none"]], indirect=True) +def test_validate_credentials(setup_anthropic_mock): + model = AnthropicLargeLanguageModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials(model="claude-instant-1.2", credentials={"anthropic_api_key": "invalid_key"}) + + model.validate_credentials( + model="claude-instant-1.2", credentials={"anthropic_api_key": os.environ.get("ANTHROPIC_API_KEY")} + ) + + +@pytest.mark.parametrize("setup_anthropic_mock", [["none"]], indirect=True) +def test_invoke_model(setup_anthropic_mock): + model = AnthropicLargeLanguageModel() + + response = model.invoke( + 
model="claude-instant-1.2", + credentials={ + "anthropic_api_key": os.environ.get("ANTHROPIC_API_KEY"), + "anthropic_api_url": os.environ.get("ANTHROPIC_API_URL"), + }, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Hello World!"), + ], + model_parameters={"temperature": 0.0, "top_p": 1.0, "max_tokens": 10}, + stop=["How"], + stream=False, + user="abc-123", + ) + + assert isinstance(response, LLMResult) + assert len(response.message.content) > 0 + + +@pytest.mark.parametrize("setup_anthropic_mock", [["none"]], indirect=True) +def test_invoke_stream_model(setup_anthropic_mock): + model = AnthropicLargeLanguageModel() + + response = model.invoke( + model="claude-instant-1.2", + credentials={"anthropic_api_key": os.environ.get("ANTHROPIC_API_KEY")}, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Hello World!"), + ], + model_parameters={"temperature": 0.0, "max_tokens": 100}, + stream=True, + user="abc-123", + ) + + assert isinstance(response, Generator) + + for chunk in response: + assert isinstance(chunk, LLMResultChunk) + assert isinstance(chunk.delta, LLMResultChunkDelta) + assert isinstance(chunk.delta.message, AssistantPromptMessage) + assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True + + +def test_get_num_tokens(): + model = AnthropicLargeLanguageModel() + + num_tokens = model.get_num_tokens( + model="claude-instant-1.2", + credentials={"anthropic_api_key": os.environ.get("ANTHROPIC_API_KEY")}, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Hello World!"), + ], + ) + + assert num_tokens == 18 diff --git a/api/tests/integration_tests/model_runtime/anthropic/test_provider.py b/api/tests/integration_tests/model_runtime/anthropic/test_provider.py new file mode 100644 index 
0000000000000000000000000000000000000000..6f1e50f431849fd04414f9b4b539de81d6b341cd --- /dev/null +++ b/api/tests/integration_tests/model_runtime/anthropic/test_provider.py @@ -0,0 +1,17 @@ +import os + +import pytest + +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.anthropic.anthropic import AnthropicProvider +from tests.integration_tests.model_runtime.__mock.anthropic import setup_anthropic_mock + + +@pytest.mark.parametrize("setup_anthropic_mock", [["none"]], indirect=True) +def test_validate_provider_credentials(setup_anthropic_mock): + provider = AnthropicProvider() + + with pytest.raises(CredentialsValidateFailedError): + provider.validate_provider_credentials(credentials={}) + + provider.validate_provider_credentials(credentials={"anthropic_api_key": os.environ.get("ANTHROPIC_API_KEY")}) diff --git a/api/tests/integration_tests/model_runtime/assets/audio.mp3 b/api/tests/integration_tests/model_runtime/assets/audio.mp3 new file mode 100644 index 0000000000000000000000000000000000000000..6796e9e373b07d3da686fa8cf514534cb6716ec4 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/assets/audio.mp3 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:29b714073410fefc10ecb80526b5c7c33df73b0830ff0e7778d5065a6cfcae3e +size 218880 diff --git a/api/tests/integration_tests/model_runtime/azure_ai_studio/__init__.py b/api/tests/integration_tests/model_runtime/azure_ai_studio/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/tests/integration_tests/model_runtime/azure_ai_studio/test_llm.py b/api/tests/integration_tests/model_runtime/azure_ai_studio/test_llm.py new file mode 100644 index 0000000000000000000000000000000000000000..b995077984e9108d6b5e3d41036f7f92830a83be --- /dev/null +++ b/api/tests/integration_tests/model_runtime/azure_ai_studio/test_llm.py @@ -0,0 +1,109 @@ +import os 
+from collections.abc import Generator + +import pytest + +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + SystemPromptMessage, + UserPromptMessage, +) +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.azure_ai_studio.llm.llm import AzureAIStudioLargeLanguageModel + + +@pytest.mark.parametrize("setup_azure_ai_studio_mock", [["chat"]], indirect=True) +def test_validate_credentials(setup_azure_ai_studio_mock): + model = AzureAIStudioLargeLanguageModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model="gpt-35-turbo", + credentials={"api_key": "invalid_key", "api_base": os.getenv("AZURE_AI_STUDIO_API_BASE")}, + ) + + model.validate_credentials( + model="gpt-35-turbo", + credentials={ + "api_key": os.getenv("AZURE_AI_STUDIO_API_KEY"), + "api_base": os.getenv("AZURE_AI_STUDIO_API_BASE"), + }, + ) + + +@pytest.mark.parametrize("setup_azure_ai_studio_mock", [["chat"]], indirect=True) +def test_invoke_model(setup_azure_ai_studio_mock): + model = AzureAIStudioLargeLanguageModel() + + result = model.invoke( + model="gpt-35-turbo", + credentials={ + "api_key": os.getenv("AZURE_AI_STUDIO_API_KEY"), + "api_base": os.getenv("AZURE_AI_STUDIO_API_BASE"), + }, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Hello World!"), + ], + model_parameters={"temperature": 0.0, "max_tokens": 100}, + stream=False, + user="abc-123", + ) + + assert isinstance(result, LLMResult) + assert len(result.message.content) > 0 + + +@pytest.mark.parametrize("setup_azure_ai_studio_mock", [["chat"]], indirect=True) +def test_invoke_stream_model(setup_azure_ai_studio_mock): + model = AzureAIStudioLargeLanguageModel() + + result = model.invoke( + model="gpt-35-turbo", + credentials={ + 
"api_key": os.getenv("AZURE_AI_STUDIO_API_KEY"), + "api_base": os.getenv("AZURE_AI_STUDIO_API_BASE"), + }, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Hello World!"), + ], + model_parameters={"temperature": 0.0, "max_tokens": 100}, + stream=True, + user="abc-123", + ) + + assert isinstance(result, Generator) + + for chunk in result: + assert isinstance(chunk, LLMResultChunk) + assert isinstance(chunk.delta, LLMResultChunkDelta) + assert isinstance(chunk.delta.message, AssistantPromptMessage) + if chunk.delta.finish_reason is not None: + assert chunk.delta.usage is not None + assert chunk.delta.usage.completion_tokens > 0 + + +def test_get_num_tokens(): + model = AzureAIStudioLargeLanguageModel() + + num_tokens = model.get_num_tokens( + model="gpt-35-turbo", + credentials={ + "api_key": os.getenv("AZURE_AI_STUDIO_API_KEY"), + "api_base": os.getenv("AZURE_AI_STUDIO_API_BASE"), + }, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Hello World!"), + ], + ) + + assert num_tokens == 21 diff --git a/api/tests/integration_tests/model_runtime/azure_ai_studio/test_provider.py b/api/tests/integration_tests/model_runtime/azure_ai_studio/test_provider.py new file mode 100644 index 0000000000000000000000000000000000000000..8afe38b09b9f022c1934395fb850efbdc0b4f6be --- /dev/null +++ b/api/tests/integration_tests/model_runtime/azure_ai_studio/test_provider.py @@ -0,0 +1,17 @@ +import os + +import pytest + +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.azure_ai_studio.azure_ai_studio import AzureAIStudioProvider + + +def test_validate_provider_credentials(): + provider = AzureAIStudioProvider() + + with pytest.raises(CredentialsValidateFailedError): + provider.validate_provider_credentials(credentials={}) + + provider.validate_provider_credentials( + 
credentials={"api_key": os.getenv("AZURE_AI_STUDIO_API_KEY"), "api_base": os.getenv("AZURE_AI_STUDIO_API_BASE")} + ) diff --git a/api/tests/integration_tests/model_runtime/azure_ai_studio/test_rerank.py b/api/tests/integration_tests/model_runtime/azure_ai_studio/test_rerank.py new file mode 100644 index 0000000000000000000000000000000000000000..4d72327c0ec43c39f16f962adafd335fafb65c98 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/azure_ai_studio/test_rerank.py @@ -0,0 +1,42 @@ +import os + +import pytest + +from core.model_runtime.entities.rerank_entities import RerankResult +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.azure_ai_studio.rerank.rerank import AzureRerankModel + + +def test_validate_credentials(): + model = AzureRerankModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model="azure-ai-studio-rerank-v1", + credentials={"api_key": "invalid_key", "api_base": os.getenv("AZURE_AI_STUDIO_API_BASE")}, + ) + + +def test_invoke_model(): + model = AzureRerankModel() + + result = model.invoke( + model="azure-ai-studio-rerank-v1", + credentials={ + "api_key": os.getenv("AZURE_AI_STUDIO_JWT_TOKEN"), + "api_base": os.getenv("AZURE_AI_STUDIO_API_BASE"), + }, + query="What is the capital of the United States?", + docs=[ + "Carson City is the capital city of the American state of Nevada. At the 2010 United States " + "Census, Carson City had a population of 55,274.", + "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that " + "are a political division controlled by the United States. 
Its capital is Saipan.", + ], + score_threshold=0.8, + ) + + assert isinstance(result, RerankResult) + assert len(result.docs) == 1 + assert result.docs[0].index == 1 + assert result.docs[0].score >= 0.8 diff --git a/api/tests/integration_tests/model_runtime/azure_openai/__init__.py b/api/tests/integration_tests/model_runtime/azure_openai/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/tests/integration_tests/model_runtime/azure_openai/test_llm.py b/api/tests/integration_tests/model_runtime/azure_openai/test_llm.py new file mode 100644 index 0000000000000000000000000000000000000000..216c50a1823c8db270910fbfa3b0e053de105bc3 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/azure_openai/test_llm.py @@ -0,0 +1,292 @@ +import os +from collections.abc import Generator + +import pytest + +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + ImagePromptMessageContent, + PromptMessageTool, + SystemPromptMessage, + TextPromptMessageContent, + UserPromptMessage, +) +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.azure_openai.llm.llm import AzureOpenAILargeLanguageModel +from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock + + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) +def test_validate_credentials_for_chat_model(setup_openai_mock): + model = AzureOpenAILargeLanguageModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model="gpt35", + credentials={ + "openai_api_base": os.environ.get("AZURE_OPENAI_API_BASE"), + "openai_api_key": "invalid_key", + "base_model_name": "gpt-35-turbo", + }, + ) + + model.validate_credentials( + model="gpt35", + credentials={ + 
"openai_api_base": os.environ.get("AZURE_OPENAI_API_BASE"), + "openai_api_key": os.environ.get("AZURE_OPENAI_API_KEY"), + "base_model_name": "gpt-35-turbo", + }, + ) + + +@pytest.mark.parametrize("setup_openai_mock", [["completion"]], indirect=True) +def test_validate_credentials_for_completion_model(setup_openai_mock): + model = AzureOpenAILargeLanguageModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model="gpt-35-turbo-instruct", + credentials={ + "openai_api_base": os.environ.get("AZURE_OPENAI_API_BASE"), + "openai_api_key": "invalid_key", + "base_model_name": "gpt-35-turbo-instruct", + }, + ) + + model.validate_credentials( + model="gpt-35-turbo-instruct", + credentials={ + "openai_api_base": os.environ.get("AZURE_OPENAI_API_BASE"), + "openai_api_key": os.environ.get("AZURE_OPENAI_API_KEY"), + "base_model_name": "gpt-35-turbo-instruct", + }, + ) + + +@pytest.mark.parametrize("setup_openai_mock", [["completion"]], indirect=True) +def test_invoke_completion_model(setup_openai_mock): + model = AzureOpenAILargeLanguageModel() + + result = model.invoke( + model="gpt-35-turbo-instruct", + credentials={ + "openai_api_base": os.environ.get("AZURE_OPENAI_API_BASE"), + "openai_api_key": os.environ.get("AZURE_OPENAI_API_KEY"), + "base_model_name": "gpt-35-turbo-instruct", + }, + prompt_messages=[UserPromptMessage(content="Hello World!")], + model_parameters={"temperature": 0.0, "max_tokens": 1}, + stream=False, + user="abc-123", + ) + + assert isinstance(result, LLMResult) + assert len(result.message.content) > 0 + + +@pytest.mark.parametrize("setup_openai_mock", [["completion"]], indirect=True) +def test_invoke_stream_completion_model(setup_openai_mock): + model = AzureOpenAILargeLanguageModel() + + result = model.invoke( + model="gpt-35-turbo-instruct", + credentials={ + "openai_api_base": os.environ.get("AZURE_OPENAI_API_BASE"), + "openai_api_key": os.environ.get("AZURE_OPENAI_API_KEY"), + "base_model_name": 
"gpt-35-turbo-instruct", + }, + prompt_messages=[UserPromptMessage(content="Hello World!")], + model_parameters={"temperature": 0.0, "max_tokens": 100}, + stream=True, + user="abc-123", + ) + + assert isinstance(result, Generator) + + for chunk in result: + assert isinstance(chunk, LLMResultChunk) + assert isinstance(chunk.delta, LLMResultChunkDelta) + assert isinstance(chunk.delta.message, AssistantPromptMessage) + assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True + + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) +def test_invoke_chat_model(setup_openai_mock): + model = AzureOpenAILargeLanguageModel() + + result = model.invoke( + model="gpt35", + credentials={ + "openai_api_base": os.environ.get("AZURE_OPENAI_API_BASE"), + "openai_api_key": os.environ.get("AZURE_OPENAI_API_KEY"), + "base_model_name": "gpt-35-turbo", + }, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Hello World!"), + ], + model_parameters={ + "temperature": 0.0, + "top_p": 1.0, + "presence_penalty": 0.0, + "frequency_penalty": 0.0, + "max_tokens": 10, + }, + stop=["How"], + stream=False, + user="abc-123", + ) + + assert isinstance(result, LLMResult) + assert len(result.message.content) > 0 + + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) +def test_invoke_stream_chat_model(setup_openai_mock): + model = AzureOpenAILargeLanguageModel() + + result = model.invoke( + model="gpt35", + credentials={ + "openai_api_base": os.environ.get("AZURE_OPENAI_API_BASE"), + "openai_api_key": os.environ.get("AZURE_OPENAI_API_KEY"), + "base_model_name": "gpt-35-turbo", + }, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Hello World!"), + ], + model_parameters={"temperature": 0.0, "max_tokens": 100}, + stream=True, + user="abc-123", + ) + + assert isinstance(result, 
Generator) + + for chunk in result: + assert isinstance(chunk, LLMResultChunk) + assert isinstance(chunk.delta, LLMResultChunkDelta) + assert isinstance(chunk.delta.message, AssistantPromptMessage) + if chunk.delta.finish_reason is not None: + assert chunk.delta.usage is not None + assert chunk.delta.usage.completion_tokens > 0 + + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) +def test_invoke_chat_model_with_vision(setup_openai_mock): + model = AzureOpenAILargeLanguageModel() + + result = model.invoke( + model="gpt-4v", + credentials={ + "openai_api_base": os.environ.get("AZURE_OPENAI_API_BASE"), + "openai_api_key": os.environ.get("AZURE_OPENAI_API_KEY"), + "base_model_name": "gpt-4-vision-preview", + }, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage( + content=[ + TextPromptMessageContent( + data="Hello World!", + ), + ImagePromptMessageContent( + mime_type="image/png", + format="png", + base64_data="iVBORw0KGgoAAAANSUhEUgAAAE4AAABMCAYAAADDYoEWAAAMQGlDQ1BJQ0MgUHJvZmlsZQAASImVVwdYU8kWnluSkEBoAQSkhN4EkRpASggt9I4gKiEJEEqMgaBiRxcVXLuIgA1dFVGwAmJBETuLYu+LBRVlXSzYlTcpoOu+8r35vrnz33/O/OfMmbllAFA7zhGJclF1APKEBeLYYH/6uOQUOukpIAEdoAy0gA2Hmy9iRkeHA1iG2r+Xd9cBIm2v2Eu1/tn/X4sGj5/PBQCJhjidl8/Ng/gAAHg1VyQuAIAo5c2mFoikGFagJYYBQrxIijPluFqK0+V4j8wmPpYFcTsASiocjjgTANVLkKcXcjOhhmo/xI5CnkAIgBodYp+8vMk8iNMgtoY2Ioil+oz0H3Qy/6aZPqzJ4WQOY/lcZEUpQJAvyuVM/z/T8b9LXq5kyIclrCpZ4pBY6Zxh3m7mTA6TYhWI+4TpkVEQa0L8QcCT2UOMUrIkIQlye9SAm8+COYMrDVBHHicgDGIDiIOEuZHhCj49QxDEhhjuEHSaoIAdD7EuxIv4+YFxCptN4smxCl9oY4aYxVTwZzlimV+pr/uSnASmQv91Fp+t0MdUi7LikyCmQGxeKEiMhFgVYof8nLgwhc3YoixW5JCNWBIrjd8c4li+MNhfro8VZoiDYhX2pXn5Q/PFNmUJ2JEKvK8gKz5Enh+sncuRxQ/ngl3iC5kJQzr8/HHhQ3Ph8QMC5XPHnvGFCXEKnQ+iAv9Y+VicIsqNVtjjpvzcYClvCrFLfmGcYiyeWAA3pFwfzxAVRMfL48SLsjmh0fJ48OUgHLBAAKADCazpYDLIBoLOvqY+eCfvCQIcIAaZgA/sFczQiCRZjxBe40AR+BMiPsgfHucv6+WDQsh/HWblV3uQIestlI3IAU8gzgNhIBfeS2SjhMPeEsFjyAj+4Z0DKxfGmwurtP/f80Psd4YJmXAFIxnySFcbsiQGEg
OIIcQgog2uj/vgXng4vPrB6oQzcI+heXy3JzwhdBEeEq4Rugm3JgmKxT9FGQG6oX6QIhfpP+YCt4Sarrg/7g3VoTKug+sDe9wF+mHivtCzK2RZirilWaH/pP23GfywGgo7siMZJY8g+5Gtfx6paqvqOqwizfWP+ZHHmj6cb9Zwz8/+WT9knwfbsJ8tsUXYfuwMdgI7hx3BmgAda8WasQ7sqBQP767Hst015C1WFk8O1BH8w9/Qykozme9Y59jr+EXeV8CfJn1HA9Zk0XSxIDOrgM6EXwQ+nS3kOoyiOzk6OQMg/b7IX19vYmTfDUSn4zs3/w8AvFsHBwcPf+dCWwHY6w4f/0PfOWsG/HQoA3D2EFciLpRzuPRCgG8JNfik6QEjYAas4XycgBvwAn4gEISCKBAPksFEGH0W3OdiMBXMBPNACSgDy8EaUAk2gi1gB9gN9oEmcAScAKfBBXAJXAN34O7pAS9AP3gHPiMIQkKoCA3RQ4wRC8QOcUIYiA8SiIQjsUgykoZkIkJEgsxE5iNlyEqkEtmM1CJ7kUPICeQc0oXcQh4gvchr5BOKoSqoFmqIWqKjUQbKRMPQeHQCmolOQYvQBehStAKtQXehjegJ9AJ6De1GX6ADGMCUMR3MBLPHGBgLi8JSsAxMjM3GSrFyrAarx1rgOl/BurE+7CNOxGk4HbeHOzgET8C5+BR8Nr4Er8R34I14O34Ff4D3498IVIIBwY7gSWATxhEyCVMJJYRywjbCQcIp+Cz1EN4RiUQdohXRHT6LycRs4gziEuJ6YgPxOLGL+Ig4QCKR9Eh2JG9SFIlDKiCVkNaRdpFaSZdJPaQPSspKxkpOSkFKKUpCpWKlcqWdSseULis9VfpMVidbkD3JUWQeeTp5GXkruYV8kdxD/kzRoFhRvCnxlGzKPEoFpZ5yinKX8kZZWdlU2UM5RlmgPFe5QnmP8lnlB8ofVTRVbFVYKqkqEpWlKttVjqvcUnlDpVItqX7UFGoBdSm1lnqSep/6QZWm6qDKVuWpzlGtUm1Uvaz6Uo2sZqHGVJuoVqRWrrZf7aJanzpZ3VKdpc5Rn61epX5I/Yb6gAZNY4xGlEaexhKNnRrnNJ5pkjQtNQM1eZoLNLdontR8RMNoZjQWjUubT9tKO0Xr0SJqWWmxtbK1yrR2a3Vq9WtrartoJ2pP067SPqrdrYPpWOqwdXJ1luns07mu82mE4QjmCP6IxSPqR1we8V53pK6fLl+3VLdB95ruJz26XqBejt4KvSa9e/q4vq1+jP5U/Q36p/T7RmqN9BrJHVk6ct/I2waoga1BrMEMgy0GHQYDhkaGwYYiw3WGJw37jHSM/IyyjVYbHTPqNaYZ+xgLjFcbtxo/p2vTmfRcegW9nd5vYmASYiIx2WzSafLZ1Mo0wbTYtMH0nhnFjGGWYbbarM2s39zYPMJ8pnmd+W0LsgXDIstircUZi/eWVpZJlgstmyyfWelasa2KrOqs7lpTrX2tp1jXWF+1IdowbHJs1ttcskVtXW2zbKtsL9qhdm52Arv1dl2jCKM8RglH1Yy6Ya9iz7QvtK+zf+Cg4xDuUOzQ5PBytPnolNErRp8Z/c3R1THXcavjnTGaY0LHFI9pGfPaydaJ61TldNWZ6hzkPMe52fmVi50L32WDy01XmmuE60LXNtevbu5uYrd6t153c/c092r3GwwtRjRjCeOsB8HD32OOxxGPj55ungWe+zz/8rL3yvHa6fVsrNVY/titYx95m3pzvDd7d/vQfdJ8Nvl0+5r4cnxrfB/6mfnx/Lb5PWXaMLOZu5gv/R39xf4H/d+zPFmzWMcDsIDggNKAzkDNwITAysD7QaZBmUF1Qf3BrsEzgo+HEELCQlaE3GAbsrnsWnZ/qHvorND2MJWwuLDKsIfhtuHi8JYINCI0YlXE3UiLSGFkUxSIYketiroXbRU9JfpwDDEmOqYq5knsmNiZsWfiaHGT4nbGvYv3j18WfyfBOkGS0JaolpiaWJ
v4PikgaWVS97jR42aNu5CsnyxIbk4hpSSmbEsZGB84fs34nlTX1JLU6xOsJkybcG6i/sTciUcnqU3iTNqfRkhLStuZ9oUTxanhDKSz06vT+7ks7lruC54fbzWvl+/NX8l/muGdsTLjWaZ35qrM3izfrPKsPgFLUCl4lR2SvTH7fU5Uzvacwdyk3IY8pby0vENCTWGOsH2y0eRpk7tEdqISUfcUzylrpvSLw8Tb8pH8CfnNBVrwR75DYi35RfKg0KewqvDD1MSp+6dpTBNO65huO33x9KdFQUW/zcBncGe0zTSZOW/mg1nMWZtnI7PTZ7fNMZuzYE7P3OC5O+ZR5uXM+73YsXhl8dv5SfNbFhgumLvg0S/Bv9SVqJaIS24s9Fq4cRG+SLCoc7Hz4nWLv5XySs+XOZaVl31Zwl1y/tcxv1b8Org0Y2nnMrdlG5YTlwuXX1/hu2LHSo2VRSsfrYpY1biavrp09ds1k9acK3cp37iWslaytrsivKJ5nfm65eu+VGZVXqvyr2qoNqheXP1+PW/95Q1+G+o3Gm4s2/hpk2DTzc3BmxtrLGvKtxC3FG55sjVx65nfGL/VbtPfVrbt63bh9u4dsTvaa91ra3ca7FxWh9ZJ6np3pe66tDtgd3O9ff3mBp2Gsj1gj2TP871pe6/vC9vXtp+xv/6AxYHqg7SDpY1I4/TG/qaspu7m5OauQ6GH2lq8Wg4edji8/YjJkaqj2keXHaMcW3BssLWodeC46HjficwTj9omtd05Oe7k1faY9s5TYafOng46ffIM80zrWe+zR855njt0nnG+6YLbhcYO146Dv7v+frDTrbPxovvF5ksel1q6xnYdu+x7+cSVgCunr7KvXrgWea3resL1mzdSb3Tf5N18div31qvbhbc/35l7l3C39J76vfL7Bvdr/rD5o6Hbrfvog4AHHQ/jHt55xH304nH+4y89C55Qn5Q/NX5a+8zp2ZHeoN5Lz8c/73khevG5r+RPjT+rX1q/PPCX318d/eP6e16JXw2+XvJG7832ty5v2waiB+6/y3v3+X3pB70POz4yPp75lPTp6eepX0hfKr7afG35Fvbt7mDe4KCII+bIfgUwWNGMDABebweAmgwADZ7PKOPl5z9ZQeRnVhkC/wnLz4iy4gZAPfx/j+mDfzc3ANizFR6/oL5aKgDRVADiPQDq7Dxch85qsnOltBDhOWBT5Nf0vHTwb4r8zPlD3D+3QKrqAn5u/wWdZ3xtG7qP3QAAADhlWElmTU0AKgAAAAgAAYdpAAQAAAABAAAAGgAAAAAAAqACAAQAAAABAAAATqADAAQAAAABAAAATAAAAADhTXUdAAARnUlEQVR4Ae2c245bR3aGi4fulizFHgUzQAYIggBB5klymfeaZ8hDBYjvAiRxkMAGkowRWx7JktjcZL7vX1Uku62Burkl5YbV5q7Tqqq1/v3XqgMpL95tbvftEh6NwPLRLS4NgsAFuDOJcAHuAtyZCJzZ7MK4C3BnInBmswvjLsCdicCZzS6MOxO49Znt0uz3//CPbbv6srXFrq0W9Q6Wi0VbLPn4R8x/jSLiu3nrl8s9dcartlwtKdmTbm21XranN6v27Mm6XV8t25fP1+3Pn1+1r4if3Czbk+t9u1rR6f9jmAXc1P6sbaevQGbfdgGJeA8ke0AQsCYYgiYgPR1QyVO+3wvcMm2WO0G2PeWkX79btp839AG4//UjYC62gDsB2rI9f7pov3q2bX/9F1ftBWAufTufOcwCrnTtR90dOdHoNgCJeAbUkuM5TsWAW5W9gfkE83ZkUHg0oAyAwbm927a2ebVoP/xx2f7jD1uYuG9/89tF+/VXK1hq+88TZgG32O1g2r7tpRdBM8fUTM7pyR8SYddgxkJErUszHti7U44CpzyEo16syNtx+qgy+1og7RMetpev9+3rb3bt+c2u/ebFsv3uL1ftiqn+qcMs4HY7jNQpEfadNU5VqeHUTJkgUbaPDxRADdZ8jU
9LHoJYnwLUtgWN4ObDC7Kdr8Hp7d9qMTW8gt23V1zyvPrD1H56e9t+99vr9uJLprBDfaIw69U4dQRCIw2JdVIjbUzecj+7qYyPpZHiAbDaJwsXyMhQEQ0pq6sAp7hMS2XGqykdA2iy4EUtF6v206ur9k/fbNo//+frtt2OaW/rjxtmAaeNGqihBY5xfVQzQEZfoSH0KHgkrbD/CX6vPIqlSTU61vVCovRSbEwbIS851vj23Q+tff3vu/bzu5I7tvs4qVnADTa5FCbNC86qCLN2E1MxKKroYB2pgSz2RLbbVcVkSJhOKxIDjGxn+nSuqes2JlKuG8fA/IzPXazbj68X7et/27UfX7GifORwOuSju47h/c3beKfRFO74CNA04YP0ZT2/YzERFGojc9pmDG47/wyDZwJjiX4wwJNer1dZPJbs5/xzK5Ppzp7SQZBszNy22U7tX7/dtFdvJrv8aGE2cDJLoPycBgHSgICJUQLo8nmUo6y7oH0S5Lu/FGhDQULCfIooATw3yyOQQ46eYVpYiaBMTFtAFPR307r9y3fbdvsRfd5Rg6HJI2Lt1qaAF6TEqoxWdVdYSHawezCvAHLjW7Jh2QGcUkDDT4Og2OfSFRVkxipcAJUZARC5FVRbeRpB1hVY6r25XQHexIZ96Hfa++PTs4Dbi8rQg7imWQG27/uEgCTCssk/WWg7GwJWwDQ36PceGzQ+x7jOtgNogkIIpsZiFMdXoEfOPUlh3l5ulu2/X6bJ7Mc84Bw+xgOKzJqM0VKm8WYlVMqt61gFKNtQKeZ6o7Ls/aqEeYooJXDIZ9uiT0uZ5UxPUJNlYdoAK62qHfM7unz3/bb9/Ha+v3u/tn3AD0XOrnxAZdpNYZILgoxyGk4BqMCbssq66dXv6RdFkiB6Rj2u3N1npiMw1dQjF4oJW/kzy6VdMRFA9Xd8VvhCLxCyYUYkvhHZb7+fotvdUR6XmwXcYI1DangAA6yspgBj/dRjp6L+RbmSPaaxuuMnGEeVAhBF4pSapAFG5gUo60rAHmpVtcz0sR2aBZW8NAB9+W7dXr9N0dmPmUcu10pWrq7kQQvBQXn1dUsgoM4ej12TtyBknG51PEMGOV2TLLVZ/GLvLMBYHsYJhg7fuMBx6tq3LFu7aBxxD9jKFiO7Thbwcv7n5dS+/ML0eWEWcBqoptk+mEQp2aTG+rbmBYA+D6MyMwMAdepKsX5QpnglFZyZ5k4tDYsI/Y1pF7CRq22HoHXgGEOwgodvgH79INnW3tlFIVVQvkBXg1dvF3z27fkTGzw+zALOPZluVoVkV4yLHoBB3VBJUNyo6uEWXAyIkruC2OQjbVeppxkm8+iti2mySsM1EPYGKBcEyul3LKTW1+pr+wLRstwP0J8a2K95Txf/+6q1ZzeUDEXt/oFhHnA4fJYCBtawYlWmlsrJBEHhP43bi9Rq1Z0ymlK3Z/QCRqA5YfaNLZJWEACn929eluXlUGO8CgMrHWYi441S2tsFebLRL5RWL0e0nL64SEEf2sjMR4ZZwA0Ddfziclz1eN8yDn1qAaHSq3G0FEQXjABDo51sJVNyGnA0QlAPL4LOApzMo0mY1sUFbQBj8xTzYhKrROYF5VGIftR1uW3+3uiWU8XnBw7l3HIYVG/P/djYgMZoyrTJrci0n2qPZVnNFV913viW6btGzsXBT6aW3VKmsauVTFOc2DxpP5YJYLBBeCUixE71IlGBR2EF+6OugHbP12Ddoj29HgIPj+cxDiPDFGINzB8sKhLh0Ui4gOgDI8deb8FiwYxlteWhLHWTlmOzhkxLAObPIkFqS8+bbG5BdgWiAmJTwXdqZ7oysktzdKC/BWMWiAJNpyP0ZPTMItRy7fTi2RB4eDwLuIkpCma1gob/Dsw7zcKAMf3txiCot8c42ZCDPu3WAqRMJAGEk4cACaLzSZsFRhAE9QoAtXcwTX92XDT0sxTQXJYHdDJin0KfVN8PmzNvnOYBx5XNlik4giumihb7tJ
60ezgNhgXuXgRNttxunZYAj7uzbL3nUA67rm5KJWrJCyTfIVwBMh3bTkD8TqFYp6uv8RwrgJpAZmHHScqv0qWeKT48NujhAuELekyYBdz9gXJQ53DvDh3tU62xTtN8bQhzzE9OccAK8wA2ez2k3cNtN7wM/RZs9M5NkNZoee0H2rmhLr8miPV9roAZtN1RHV/gDb7EoUtXKeXjYXUBN0oeFs8CbrtlhZRGPZSSZNyI9gA+TBFkelFNWxgEgCtG3wDiFqEr5Jz6y/U1DAM4QLxi2l7DNhl3w/epNTUFWGbXC7HrMQMz7WUbf8AaDQ46DYXuxLoJX6CFRzvuiPyJzCzgZIoKyqgKAx1yAGPQUWfa+GoDsqwDJNnHLF9juSz0i5VrpvqSwmsQul5dtyfrfX1zL3i0WdHHSjaKVjf0T5k7ABtxlEHbwxusgjydAY8N84BjvAx5GLfMqBW0VJEZ+pwKskQnbpnFHPzpwWo/bzkGvX51296+bu1v/+qL9usXT9rTJ07Bzh9k9HEPsxNhwhh6xLXKo3fXWf3iMkrBBz9nAbflbHm6ONxhXp8/NW26lkSleIEV9FBVI+o6ihjmffPDt+3v/+5Z+82vnsZw/fyercweB2d7wzA8mfuPEknpXTnHvQsoPd1v/aD8LODw+AxbAw/QjnEfv69u5kz6dtOiW2R6YmW7vd0C3qK94wcjf/zxZ1bRXfvqGT6U3f2G/Z6AesqotgJX477PNVmTmxfiwTSS5irqz2ybEHD6PzbMAk7lS/0BxgkTqPAUYBiAkQpTLLdKxe1D4Lbsp968uW1vXk+ZrnpsN7yL1TbmbvCl4GcPPPStZWyNcM9s++9y92ruZu2CT21q7lZ9KDcLuC3WbmGG42uA30EISOVkFynt1BBialOliF/wZHqGTa1tOfq8fbMHPL6N2iBPW2d7HfxZdWnreiN49UL0dfhLR6tBSVVwNo+TQ1U5IsHvQU4Dcry7bGNOix+SngVcwAhYpZjTQxaNMABLLLtUFEAMEwi4kk63fGDbLTcVm82ubd7hNylzEXCa6SPdz2Vf5iUobe0jAFIq8+JHT8CjGeUjHFOj5E7MIO4THxvOaHIcwu2IOKiznyg89BTEXi6WssO8B36vkLa33Pv7/QRbEtm21c/BtIm9Yb4ho19PDg4g09aeucySdpzq3BfVx6WQqh7MkLOSkHLf2olEKni4n7xznh0VH4jnAYdy6hfVSZTvUmF54f2cU9d9XmlhvUyTlbkxIT0BWtgH4wRRgPMy7EFbAwi8ojzbNyqtH/7coWxnUHyE+rmYjbs3NCnqdwIbbM/GZ4RZwDleVskO3viSBhWjSu2Pxj7JU4bsqrzTU5YZQ7xKu73Bb8bAbo+s28NStxEyb8e+K1UAKXhOVivK7x0RUANf3zEw/smJpsr37cad9RlhFnCbzQYwfN36I+5qwxgVwRA/vOHxlneeMiaux9lymN5tTTttkZN5mbZwCYsLM550taA+zJM5gsdHsGSdQTbngN7ZlC/JrRhXIcorRJvVcp2pnjzdy+0nnErOCbOAE5x8d4oVCy4xMSFGetjfgWJ3MQFHdomxZbUwwC4B84YlzBNojUEmxmqO1tVC4VcVopUzKuXK+XArUeDVTyq85wv7xKqHsel1dfIUkl8zUXcFm8eUH7IPjWcBp8J5mYxWcWmbclhlyEIAMJm2HbSwDCHZGD9IuR1UH4MhaZ4HOAIQIJOrIxfjxOFRUMNQq8wI9EH5WNVJdcEje22ofxs3K6PlQ+OZwA2ghrFSKhiEVSqh/5JJcfodKBnntLac7wb5CKLpAs+0RguYuAhoNh2CRV1dTVFhqWhRn/u+tOsMtTph6JhOkAWsQDz1K3NHeHyYBZyK70BG5oy3SyqGumoaAhr1Aiggnm8FzXr3cQWSq++p8seM10v6LW9Elgh5kyGINXMdi1xspw2LRHwqMjJTV2KdU9c2eQ1SkXDDHL2aYf2MprVp1dFrtcBlAWB/sNuxMo
JIzEfRqhMk04qXfM0n8yVDaa/DRLp1GuGSKhNz65ZEOQUSdyD0Y/adRSojsxjoz2jnNFdN3l/S+sUvnqbDsx+zgCvQMJzhPaCrlouCLBvbA43x68DhsAc7DxpTr0y39VAMBCfpSlpSUMggzRe8X4bIAWRYJqVJj6t7feMV/9Bkfeb+bYw2Czg78S3GwWtEQEPRWFMMEDAZhVTiMaWLnZZRxSexfaStPR9DAXbMj5Qs479Dm8PqqYCNEpUTVAe/GpLC3vH16hI64zkLuB1XQVsdFkED8ps40oLjj2sMAdbFwGlKRjbW6UHAFZaRJVegIpeWVafZhQ4yHahUm+5VyfOwXYFHTX8DKUNSn+fCcsN3qOd8AT3GGPEs4EYnxho9YlOnU1WTUj98GbLKWCawI5wk71DiBMoh+qjYfgXUc+nNlW+rXuqjOrknPAs4sRoHcvvNguDZNEChYOoBUUZ175z9nMBZnQ6cnncgS7uDnt3BJ49Y8axqPYLZ0gVEb2DaICyHtOUM5t2eP7AJexWaGWYBVzcdsqneoAAViyzzo3ZsC1Jeq2qBKVhlkIxDsuSRrSY6/6S6eaaFjD+B4BGmMo9X9M06kcAdMq0qU5eT+lBBc8+GqaVmCc989iHP6yVvOcr4qE8ZLijVZ8VleC/5xWDWFmN6ow6aIKX75EfdL5rfKxBJgAcwwV/zeXrFjyqqo3uy52dnMa5oU4O7svo7YMNgWrFKdsk6WBXmmS82HuKsuADjHZFGi5iBIv+9qnn/qt+qSh3JTFNjPvWDiqpnA0SexYB/ijm6q5qP85wFnIZrXQHgillpVesHh9QVaAWWAJccfo/VNrOcbmrbYn/vCR9gy2m1aUH2WOa/rv4UoKnhPODowC2Gx6jQo4Nox4ZinDL392ssIHFSZWa1rTZJD/wSy0Kn34eDpwZvP1w96+dmH25zrsQs4KSLP4GAawWSjhnFZZQFmUZxOZSTj/ne2yUhIHCjRIlFKcIU0x852RjZTGGlDdaQrkxk7MPrJr/gzg17r4vgJ3rMAk4/wmQDE7wJhg+fFV1xaMGiMqnXaFc5jd4FjCCIRAEmAO5aPE7lzsw0ZelHYJB0PCWscErqOJcsrbllGmhmzE/7mAXcPof544Wlqg6wTuORtvKQzjV2gVC+shaNMhc24v8iIloGmS3ogc7bD9sS884Oi0kEP89jFnDX++/hCtPVtT7kwaxOkZpmxQ/L9vgdj1r+NCtAwQ6/A9DXMXnBqZgoHDdXP7Wna/Id6PRCum7DiREqcg1UPw9Yp6MsLv/HwlM4Hp7WQ1/CGQhcgDsDNJtcgLsAdyYCZza7MO4C3JkInNnswrgLcGcicGazC+POBO7/AH5zPa/ivytzAAAAAElFTkSuQmCC", + ), + ] + ), + ], + model_parameters={"temperature": 0.0, "max_tokens": 100}, + stream=False, + user="abc-123", + ) + + assert isinstance(result, LLMResult) + assert len(result.message.content) > 0 + + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) +def test_invoke_chat_model_with_tools(setup_openai_mock): + model = AzureOpenAILargeLanguageModel() + + result = model.invoke( + model="gpt-35-turbo", + credentials={ + "openai_api_base": os.environ.get("AZURE_OPENAI_API_BASE"), + "openai_api_key": os.environ.get("AZURE_OPENAI_API_KEY"), + "base_model_name": "gpt-35-turbo", + }, + 
prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage( + content="what's the weather today in London?", + ), + ], + model_parameters={"temperature": 0.0, "max_tokens": 100}, + tools=[ + PromptMessageTool( + name="get_weather", + description="Determine weather in my location", + parameters={ + "type": "object", + "properties": { + "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"}, + "unit": {"type": "string", "enum": ["c", "f"]}, + }, + "required": ["location"], + }, + ), + PromptMessageTool( + name="get_stock_price", + description="Get the current stock price", + parameters={ + "type": "object", + "properties": {"symbol": {"type": "string", "description": "The stock symbol"}}, + "required": ["symbol"], + }, + ), + ], + stream=False, + user="abc-123", + ) + + assert isinstance(result, LLMResult) + assert isinstance(result.message, AssistantPromptMessage) + assert len(result.message.tool_calls) > 0 + + +def test_get_num_tokens(): + model = AzureOpenAILargeLanguageModel() + + num_tokens = model.get_num_tokens( + model="gpt-35-turbo-instruct", + credentials={"base_model_name": "gpt-35-turbo-instruct"}, + prompt_messages=[UserPromptMessage(content="Hello World!")], + ) + + assert num_tokens == 3 + + num_tokens = model.get_num_tokens( + model="gpt35", + credentials={"base_model_name": "gpt-35-turbo"}, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Hello World!"), + ], + ) + + assert num_tokens == 21 diff --git a/api/tests/integration_tests/model_runtime/azure_openai/test_text_embedding.py b/api/tests/integration_tests/model_runtime/azure_openai/test_text_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..a1ae2b2e5b740cd25da1baf9e5ca7e097bb2a5af --- /dev/null +++ b/api/tests/integration_tests/model_runtime/azure_openai/test_text_embedding.py @@ -0,0 +1,62 @@ +import os + 
+import pytest + +from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.azure_openai.text_embedding.text_embedding import AzureOpenAITextEmbeddingModel +from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock + + +@pytest.mark.parametrize("setup_openai_mock", [["text_embedding"]], indirect=True) +def test_validate_credentials(setup_openai_mock): + model = AzureOpenAITextEmbeddingModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model="embedding", + credentials={ + "openai_api_base": os.environ.get("AZURE_OPENAI_API_BASE"), + "openai_api_key": "invalid_key", + "base_model_name": "text-embedding-ada-002", + }, + ) + + model.validate_credentials( + model="embedding", + credentials={ + "openai_api_base": os.environ.get("AZURE_OPENAI_API_BASE"), + "openai_api_key": os.environ.get("AZURE_OPENAI_API_KEY"), + "base_model_name": "text-embedding-ada-002", + }, + ) + + +@pytest.mark.parametrize("setup_openai_mock", [["text_embedding"]], indirect=True) +def test_invoke_model(setup_openai_mock): + model = AzureOpenAITextEmbeddingModel() + + result = model.invoke( + model="embedding", + credentials={ + "openai_api_base": os.environ.get("AZURE_OPENAI_API_BASE"), + "openai_api_key": os.environ.get("AZURE_OPENAI_API_KEY"), + "base_model_name": "text-embedding-ada-002", + }, + texts=["hello", "world"], + user="abc-123", + ) + + assert isinstance(result, TextEmbeddingResult) + assert len(result.embeddings) == 2 + assert result.usage.total_tokens == 2 + + +def test_get_num_tokens(): + model = AzureOpenAITextEmbeddingModel() + + num_tokens = model.get_num_tokens( + model="embedding", credentials={"base_model_name": "text-embedding-ada-002"}, texts=["hello", "world"] + ) + + assert num_tokens == 2 diff --git 
a/api/tests/integration_tests/model_runtime/baichuan/__init__.py b/api/tests/integration_tests/model_runtime/baichuan/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/tests/integration_tests/model_runtime/baichuan/test_llm.py b/api/tests/integration_tests/model_runtime/baichuan/test_llm.py new file mode 100644 index 0000000000000000000000000000000000000000..fe7fe96891143909292be0f72497a3bbe2920f6b --- /dev/null +++ b/api/tests/integration_tests/model_runtime/baichuan/test_llm.py @@ -0,0 +1,172 @@ +import os +from collections.abc import Generator +from time import sleep + +import pytest + +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import AssistantPromptMessage, SystemPromptMessage, UserPromptMessage +from core.model_runtime.entities.model_entities import AIModelEntity +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.baichuan.llm.llm import BaichuanLanguageModel + + +def test_predefined_models(): + model = BaichuanLanguageModel() + model_schemas = model.predefined_models() + assert len(model_schemas) >= 1 + assert isinstance(model_schemas[0], AIModelEntity) + + +def test_validate_credentials_for_chat_model(): + sleep(3) + model = BaichuanLanguageModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model="baichuan2-turbo", credentials={"api_key": "invalid_key", "secret_key": "invalid_key"} + ) + + model.validate_credentials( + model="baichuan2-turbo", + credentials={ + "api_key": os.environ.get("BAICHUAN_API_KEY"), + "secret_key": os.environ.get("BAICHUAN_SECRET_KEY"), + }, + ) + + +def test_invoke_model(): + sleep(3) + model = BaichuanLanguageModel() + + response = model.invoke( + model="baichuan2-turbo", + credentials={ + "api_key": 
os.environ.get("BAICHUAN_API_KEY"), + "secret_key": os.environ.get("BAICHUAN_SECRET_KEY"), + }, + prompt_messages=[UserPromptMessage(content="Hello World!")], + model_parameters={ + "temperature": 0.7, + "top_p": 1.0, + "top_k": 1, + }, + stop=["you"], + user="abc-123", + stream=False, + ) + + assert isinstance(response, LLMResult) + assert len(response.message.content) > 0 + assert response.usage.total_tokens > 0 + + +def test_invoke_model_with_system_message(): + sleep(3) + model = BaichuanLanguageModel() + + response = model.invoke( + model="baichuan2-turbo", + credentials={ + "api_key": os.environ.get("BAICHUAN_API_KEY"), + "secret_key": os.environ.get("BAICHUAN_SECRET_KEY"), + }, + prompt_messages=[ + SystemPromptMessage(content="请记住你是Kasumi。"), + UserPromptMessage(content="现在告诉我你是谁?"), + ], + model_parameters={ + "temperature": 0.7, + "top_p": 1.0, + "top_k": 1, + }, + stop=["you"], + user="abc-123", + stream=False, + ) + + assert isinstance(response, LLMResult) + assert len(response.message.content) > 0 + assert response.usage.total_tokens > 0 + + +def test_invoke_stream_model(): + sleep(3) + model = BaichuanLanguageModel() + + response = model.invoke( + model="baichuan2-turbo", + credentials={ + "api_key": os.environ.get("BAICHUAN_API_KEY"), + "secret_key": os.environ.get("BAICHUAN_SECRET_KEY"), + }, + prompt_messages=[UserPromptMessage(content="Hello World!")], + model_parameters={ + "temperature": 0.7, + "top_p": 1.0, + "top_k": 1, + }, + stop=["you"], + stream=True, + user="abc-123", + ) + + assert isinstance(response, Generator) + for chunk in response: + assert isinstance(chunk, LLMResultChunk) + assert isinstance(chunk.delta, LLMResultChunkDelta) + assert isinstance(chunk.delta.message, AssistantPromptMessage) + assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True + + +def test_invoke_with_search(): + sleep(3) + model = BaichuanLanguageModel() + + response = model.invoke( + model="baichuan2-turbo", + credentials={ 
+ "api_key": os.environ.get("BAICHUAN_API_KEY"), + "secret_key": os.environ.get("BAICHUAN_SECRET_KEY"), + }, + prompt_messages=[UserPromptMessage(content="北京今天的天气怎么样")], + model_parameters={ + "temperature": 0.7, + "top_p": 1.0, + "top_k": 1, + "with_search_enhance": True, + }, + stop=["you"], + stream=True, + user="abc-123", + ) + + assert isinstance(response, Generator) + total_message = "" + for chunk in response: + assert isinstance(chunk, LLMResultChunk) + assert isinstance(chunk.delta, LLMResultChunkDelta) + assert isinstance(chunk.delta.message, AssistantPromptMessage) + assert len(chunk.delta.message.content) > 0 if not chunk.delta.finish_reason else True + total_message += chunk.delta.message.content + + assert "不" not in total_message + + +def test_get_num_tokens(): + sleep(3) + model = BaichuanLanguageModel() + + response = model.get_num_tokens( + model="baichuan2-turbo", + credentials={ + "api_key": os.environ.get("BAICHUAN_API_KEY"), + "secret_key": os.environ.get("BAICHUAN_SECRET_KEY"), + }, + prompt_messages=[UserPromptMessage(content="Hello World!")], + tools=[], + ) + + assert isinstance(response, int) + assert response == 9 diff --git a/api/tests/integration_tests/model_runtime/baichuan/test_provider.py b/api/tests/integration_tests/model_runtime/baichuan/test_provider.py new file mode 100644 index 0000000000000000000000000000000000000000..4036edfb7a7062114ee9fe15f86ad96463e11690 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/baichuan/test_provider.py @@ -0,0 +1,15 @@ +import os + +import pytest + +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.baichuan.baichuan import BaichuanProvider + + +def test_validate_provider_credentials(): + provider = BaichuanProvider() + + with pytest.raises(CredentialsValidateFailedError): + provider.validate_provider_credentials(credentials={"api_key": "hahahaha"}) + + provider.validate_provider_credentials(credentials={"api_key": 
os.environ.get("BAICHUAN_API_KEY")}) diff --git a/api/tests/integration_tests/model_runtime/baichuan/test_text_embedding.py b/api/tests/integration_tests/model_runtime/baichuan/test_text_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..cbc63f3978fb99af51b0511d0edee4987e90cc6a --- /dev/null +++ b/api/tests/integration_tests/model_runtime/baichuan/test_text_embedding.py @@ -0,0 +1,87 @@ +import os + +import pytest + +from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.baichuan.text_embedding.text_embedding import BaichuanTextEmbeddingModel + + +def test_validate_credentials(): + model = BaichuanTextEmbeddingModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials(model="baichuan-text-embedding", credentials={"api_key": "invalid_key"}) + + model.validate_credentials( + model="baichuan-text-embedding", credentials={"api_key": os.environ.get("BAICHUAN_API_KEY")} + ) + + +def test_invoke_model(): + model = BaichuanTextEmbeddingModel() + + result = model.invoke( + model="baichuan-text-embedding", + credentials={ + "api_key": os.environ.get("BAICHUAN_API_KEY"), + }, + texts=["hello", "world"], + user="abc-123", + ) + + assert isinstance(result, TextEmbeddingResult) + assert len(result.embeddings) == 2 + assert result.usage.total_tokens == 6 + + +def test_get_num_tokens(): + model = BaichuanTextEmbeddingModel() + + num_tokens = model.get_num_tokens( + model="baichuan-text-embedding", + credentials={ + "api_key": os.environ.get("BAICHUAN_API_KEY"), + }, + texts=["hello", "world"], + ) + + assert num_tokens == 2 + + +def test_max_chunks(): + model = BaichuanTextEmbeddingModel() + + result = model.invoke( + model="baichuan-text-embedding", + credentials={ + "api_key": os.environ.get("BAICHUAN_API_KEY"), + }, + texts=[ + "hello", + "world", + "hello", + 
"world", + "hello", + "world", + "hello", + "world", + "hello", + "world", + "hello", + "world", + "hello", + "world", + "hello", + "world", + "hello", + "world", + "hello", + "world", + "hello", + "world", + ], + ) + + assert isinstance(result, TextEmbeddingResult) + assert len(result.embeddings) == 22 diff --git a/api/tests/integration_tests/model_runtime/bedrock/__init__.py b/api/tests/integration_tests/model_runtime/bedrock/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/tests/integration_tests/model_runtime/bedrock/test_llm.py b/api/tests/integration_tests/model_runtime/bedrock/test_llm.py new file mode 100644 index 0000000000000000000000000000000000000000..c19ec35a6e45fcc405dced9b31e28647acea0ed7 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/bedrock/test_llm.py @@ -0,0 +1,103 @@ +import os +from collections.abc import Generator + +import pytest + +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import AssistantPromptMessage, SystemPromptMessage, UserPromptMessage +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.bedrock.llm.llm import BedrockLargeLanguageModel + + +def test_validate_credentials(): + model = BedrockLargeLanguageModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials(model="meta.llama2-13b-chat-v1", credentials={"anthropic_api_key": "invalid_key"}) + + model.validate_credentials( + model="meta.llama2-13b-chat-v1", + credentials={ + "aws_region": os.getenv("AWS_REGION"), + "aws_access_key": os.getenv("AWS_ACCESS_KEY"), + "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"), + }, + ) + + +def test_invoke_model(): + model = BedrockLargeLanguageModel() + + response = model.invoke( + model="meta.llama2-13b-chat-v1", + credentials={ 
+ "aws_region": os.getenv("AWS_REGION"), + "aws_access_key": os.getenv("AWS_ACCESS_KEY"), + "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"), + }, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Hello World!"), + ], + model_parameters={"temperature": 0.0, "top_p": 1.0, "max_tokens_to_sample": 10}, + stop=["How"], + stream=False, + user="abc-123", + ) + + assert isinstance(response, LLMResult) + assert len(response.message.content) > 0 + + +def test_invoke_stream_model(): + model = BedrockLargeLanguageModel() + + response = model.invoke( + model="meta.llama2-13b-chat-v1", + credentials={ + "aws_region": os.getenv("AWS_REGION"), + "aws_access_key": os.getenv("AWS_ACCESS_KEY"), + "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"), + }, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Hello World!"), + ], + model_parameters={"temperature": 0.0, "max_tokens_to_sample": 100}, + stream=True, + user="abc-123", + ) + + assert isinstance(response, Generator) + + for chunk in response: + print(chunk) + assert isinstance(chunk, LLMResultChunk) + assert isinstance(chunk.delta, LLMResultChunkDelta) + assert isinstance(chunk.delta.message, AssistantPromptMessage) + assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True + + +def test_get_num_tokens(): + model = BedrockLargeLanguageModel() + + num_tokens = model.get_num_tokens( + model="meta.llama2-13b-chat-v1", + credentials={ + "aws_region": os.getenv("AWS_REGION"), + "aws_access_key": os.getenv("AWS_ACCESS_KEY"), + "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"), + }, + messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Hello World!"), + ], + ) + + assert num_tokens == 18 diff --git a/api/tests/integration_tests/model_runtime/bedrock/test_provider.py 
b/api/tests/integration_tests/model_runtime/bedrock/test_provider.py new file mode 100644 index 0000000000000000000000000000000000000000..080727829e9e2faf281f8c3418498834cff81193 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/bedrock/test_provider.py @@ -0,0 +1,21 @@ +import os + +import pytest + +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.bedrock.bedrock import BedrockProvider + + +def test_validate_provider_credentials(): + provider = BedrockProvider() + + with pytest.raises(CredentialsValidateFailedError): + provider.validate_provider_credentials(credentials={}) + + provider.validate_provider_credentials( + credentials={ + "aws_region": os.getenv("AWS_REGION"), + "aws_access_key": os.getenv("AWS_ACCESS_KEY"), + "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"), + } + ) diff --git a/api/tests/integration_tests/model_runtime/chatglm/__init__.py b/api/tests/integration_tests/model_runtime/chatglm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/tests/integration_tests/model_runtime/chatglm/test_llm.py b/api/tests/integration_tests/model_runtime/chatglm/test_llm.py new file mode 100644 index 0000000000000000000000000000000000000000..a7c5229e05141821b6ee20ccf7d7d3a6b1420b39 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/chatglm/test_llm.py @@ -0,0 +1,229 @@ +import os +from collections.abc import Generator + +import pytest + +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessageTool, + SystemPromptMessage, + UserPromptMessage, +) +from core.model_runtime.entities.model_entities import AIModelEntity +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from 
core.model_runtime.model_providers.chatglm.llm.llm import ChatGLMLargeLanguageModel +from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock + + +def test_predefined_models(): + model = ChatGLMLargeLanguageModel() + model_schemas = model.predefined_models() + assert len(model_schemas) >= 1 + assert isinstance(model_schemas[0], AIModelEntity) + + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) +def test_validate_credentials_for_chat_model(setup_openai_mock): + model = ChatGLMLargeLanguageModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials(model="chatglm2-6b", credentials={"api_base": "invalid_key"}) + + model.validate_credentials(model="chatglm2-6b", credentials={"api_base": os.environ.get("CHATGLM_API_BASE")}) + + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) +def test_invoke_model(setup_openai_mock): + model = ChatGLMLargeLanguageModel() + + response = model.invoke( + model="chatglm2-6b", + credentials={"api_base": os.environ.get("CHATGLM_API_BASE")}, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Hello World!"), + ], + model_parameters={ + "temperature": 0.7, + "top_p": 1.0, + }, + stop=["you"], + user="abc-123", + stream=False, + ) + + assert isinstance(response, LLMResult) + assert len(response.message.content) > 0 + assert response.usage.total_tokens > 0 + + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) +def test_invoke_stream_model(setup_openai_mock): + model = ChatGLMLargeLanguageModel() + + response = model.invoke( + model="chatglm2-6b", + credentials={"api_base": os.environ.get("CHATGLM_API_BASE")}, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Hello World!"), + ], + model_parameters={ + "temperature": 0.7, + "top_p": 1.0, + }, + stop=["you"], + stream=True, + 
user="abc-123", + ) + + assert isinstance(response, Generator) + for chunk in response: + assert isinstance(chunk, LLMResultChunk) + assert isinstance(chunk.delta, LLMResultChunkDelta) + assert isinstance(chunk.delta.message, AssistantPromptMessage) + assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True + + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) +def test_invoke_stream_model_with_functions(setup_openai_mock): + model = ChatGLMLargeLanguageModel() + + response = model.invoke( + model="chatglm3-6b", + credentials={"api_base": os.environ.get("CHATGLM_API_BASE")}, + prompt_messages=[ + SystemPromptMessage( + content="你是一个天气机器人,你不知道今天的天气怎么样,你需要通过调用一个函数来获取天气信息。" + ), + UserPromptMessage(content="波士顿天气如何?"), + ], + model_parameters={ + "temperature": 0, + "top_p": 1.0, + }, + stop=["you"], + user="abc-123", + stream=True, + tools=[ + PromptMessageTool( + name="get_current_weather", + description="Get the current weather in a given location", + parameters={ + "type": "object", + "properties": { + "location": {"type": "string", "description": "The city and state e.g. 
San Francisco, CA"}, + "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, + }, + "required": ["location"], + }, + ) + ], + ) + + assert isinstance(response, Generator) + + call: LLMResultChunk = None + chunks = [] + + for chunk in response: + chunks.append(chunk) + assert isinstance(chunk, LLMResultChunk) + assert isinstance(chunk.delta, LLMResultChunkDelta) + assert isinstance(chunk.delta.message, AssistantPromptMessage) + assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True + + if chunk.delta.message.tool_calls and len(chunk.delta.message.tool_calls) > 0: + call = chunk + break + + assert call is not None + assert call.delta.message.tool_calls[0].function.name == "get_current_weather" + + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) +def test_invoke_model_with_functions(setup_openai_mock): + model = ChatGLMLargeLanguageModel() + + response = model.invoke( + model="chatglm3-6b", + credentials={"api_base": os.environ.get("CHATGLM_API_BASE")}, + prompt_messages=[UserPromptMessage(content="What is the weather like in San Francisco?")], + model_parameters={ + "temperature": 0.7, + "top_p": 1.0, + }, + stop=["you"], + user="abc-123", + stream=False, + tools=[ + PromptMessageTool( + name="get_current_weather", + description="Get the current weather in a given location", + parameters={ + "type": "object", + "properties": { + "location": {"type": "string", "description": "The city and state e.g. 
San Francisco, CA"}, + "unit": {"type": "string", "enum": ["c", "f"]}, + }, + "required": ["location"], + }, + ) + ], + ) + + assert isinstance(response, LLMResult) + assert len(response.message.content) > 0 + assert response.usage.total_tokens > 0 + assert response.message.tool_calls[0].function.name == "get_current_weather" + + +def test_get_num_tokens(): + model = ChatGLMLargeLanguageModel() + + num_tokens = model.get_num_tokens( + model="chatglm2-6b", + credentials={"api_base": os.environ.get("CHATGLM_API_BASE")}, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Hello World!"), + ], + tools=[ + PromptMessageTool( + name="get_current_weather", + description="Get the current weather in a given location", + parameters={ + "type": "object", + "properties": { + "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"}, + "unit": {"type": "string", "enum": ["c", "f"]}, + }, + "required": ["location"], + }, + ) + ], + ) + + assert isinstance(num_tokens, int) + assert num_tokens == 77 + + num_tokens = model.get_num_tokens( + model="chatglm2-6b", + credentials={"api_base": os.environ.get("CHATGLM_API_BASE")}, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Hello World!"), + ], + ) + + assert isinstance(num_tokens, int) + assert num_tokens == 21 diff --git a/api/tests/integration_tests/model_runtime/chatglm/test_provider.py b/api/tests/integration_tests/model_runtime/chatglm/test_provider.py new file mode 100644 index 0000000000000000000000000000000000000000..7907805d0727725fed57c86a5dd2379355ca801f --- /dev/null +++ b/api/tests/integration_tests/model_runtime/chatglm/test_provider.py @@ -0,0 +1,17 @@ +import os + +import pytest + +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.chatglm.chatglm import ChatGLMProvider +from 
tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock + + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) +def test_validate_provider_credentials(setup_openai_mock): + provider = ChatGLMProvider() + + with pytest.raises(CredentialsValidateFailedError): + provider.validate_provider_credentials(credentials={"api_base": "hahahaha"}) + + provider.validate_provider_credentials(credentials={"api_base": os.environ.get("CHATGLM_API_BASE")}) diff --git a/api/tests/integration_tests/model_runtime/cohere/__init__.py b/api/tests/integration_tests/model_runtime/cohere/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/tests/integration_tests/model_runtime/cohere/test_llm.py b/api/tests/integration_tests/model_runtime/cohere/test_llm.py new file mode 100644 index 0000000000000000000000000000000000000000..b7f707e935dbeafa716954e6997721d497637528 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/cohere/test_llm.py @@ -0,0 +1,191 @@ +import os +from collections.abc import Generator + +import pytest + +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import AssistantPromptMessage, SystemPromptMessage, UserPromptMessage +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.cohere.llm.llm import CohereLargeLanguageModel + + +def test_validate_credentials_for_chat_model(): + model = CohereLargeLanguageModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials(model="command-light-chat", credentials={"api_key": "invalid_key"}) + + model.validate_credentials(model="command-light-chat", credentials={"api_key": os.environ.get("COHERE_API_KEY")}) + + +def test_validate_credentials_for_completion_model(): + model = CohereLargeLanguageModel() + + 
with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials(model="command-light", credentials={"api_key": "invalid_key"}) + + model.validate_credentials(model="command-light", credentials={"api_key": os.environ.get("COHERE_API_KEY")}) + + +def test_invoke_completion_model(): + model = CohereLargeLanguageModel() + + credentials = {"api_key": os.environ.get("COHERE_API_KEY")} + + result = model.invoke( + model="command-light", + credentials=credentials, + prompt_messages=[UserPromptMessage(content="Hello World!")], + model_parameters={"temperature": 0.0, "max_tokens": 1}, + stream=False, + user="abc-123", + ) + + assert isinstance(result, LLMResult) + assert len(result.message.content) > 0 + assert model._num_tokens_from_string("command-light", credentials, result.message.content) == 1 + + +def test_invoke_stream_completion_model(): + model = CohereLargeLanguageModel() + + result = model.invoke( + model="command-light", + credentials={"api_key": os.environ.get("COHERE_API_KEY")}, + prompt_messages=[UserPromptMessage(content="Hello World!")], + model_parameters={"temperature": 0.0, "max_tokens": 100}, + stream=True, + user="abc-123", + ) + + assert isinstance(result, Generator) + + for chunk in result: + assert isinstance(chunk, LLMResultChunk) + assert isinstance(chunk.delta, LLMResultChunkDelta) + assert isinstance(chunk.delta.message, AssistantPromptMessage) + assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True + + +def test_invoke_chat_model(): + model = CohereLargeLanguageModel() + + result = model.invoke( + model="command-light-chat", + credentials={"api_key": os.environ.get("COHERE_API_KEY")}, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Hello World!"), + ], + model_parameters={ + "temperature": 0.0, + "p": 0.99, + "presence_penalty": 0.0, + "frequency_penalty": 0.0, + "max_tokens": 10, + }, + stop=["How"], + stream=False, + 
user="abc-123", + ) + + assert isinstance(result, LLMResult) + assert len(result.message.content) > 0 + + +def test_invoke_stream_chat_model(): + model = CohereLargeLanguageModel() + + result = model.invoke( + model="command-light-chat", + credentials={"api_key": os.environ.get("COHERE_API_KEY")}, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Hello World!"), + ], + model_parameters={"temperature": 0.0, "max_tokens": 100}, + stream=True, + user="abc-123", + ) + + assert isinstance(result, Generator) + + for chunk in result: + assert isinstance(chunk, LLMResultChunk) + assert isinstance(chunk.delta, LLMResultChunkDelta) + assert isinstance(chunk.delta.message, AssistantPromptMessage) + assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True + if chunk.delta.finish_reason is not None: + assert chunk.delta.usage is not None + assert chunk.delta.usage.completion_tokens > 0 + + +def test_get_num_tokens(): + model = CohereLargeLanguageModel() + + num_tokens = model.get_num_tokens( + model="command-light", + credentials={"api_key": os.environ.get("COHERE_API_KEY")}, + prompt_messages=[UserPromptMessage(content="Hello World!")], + ) + + assert num_tokens == 3 + + num_tokens = model.get_num_tokens( + model="command-light-chat", + credentials={"api_key": os.environ.get("COHERE_API_KEY")}, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Hello World!"), + ], + ) + + assert num_tokens == 15 + + +def test_fine_tuned_model(): + model = CohereLargeLanguageModel() + + # test invoke + result = model.invoke( + model="85ec47be-6139-4f75-a4be-0f0ec1ef115c-ft", + credentials={"api_key": os.environ.get("COHERE_API_KEY"), "mode": "completion"}, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Hello World!"), + ], + 
model_parameters={"temperature": 0.0, "max_tokens": 100}, + stream=False, + user="abc-123", + ) + + assert isinstance(result, LLMResult) + + +def test_fine_tuned_chat_model(): + model = CohereLargeLanguageModel() + + # test invoke + result = model.invoke( + model="94f2d55a-4c79-4c00-bde4-23962e74b170-ft", + credentials={"api_key": os.environ.get("COHERE_API_KEY"), "mode": "chat"}, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Hello World!"), + ], + model_parameters={"temperature": 0.0, "max_tokens": 100}, + stream=False, + user="abc-123", + ) + + assert isinstance(result, LLMResult) diff --git a/api/tests/integration_tests/model_runtime/cohere/test_provider.py b/api/tests/integration_tests/model_runtime/cohere/test_provider.py new file mode 100644 index 0000000000000000000000000000000000000000..fb7e6d34984a618e3601ffbf4dcf6ab7df59e659 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/cohere/test_provider.py @@ -0,0 +1,15 @@ +import os + +import pytest + +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.cohere.cohere import CohereProvider + + +def test_validate_provider_credentials(): + provider = CohereProvider() + + with pytest.raises(CredentialsValidateFailedError): + provider.validate_provider_credentials(credentials={}) + + provider.validate_provider_credentials(credentials={"api_key": os.environ.get("COHERE_API_KEY")}) diff --git a/api/tests/integration_tests/model_runtime/cohere/test_rerank.py b/api/tests/integration_tests/model_runtime/cohere/test_rerank.py new file mode 100644 index 0000000000000000000000000000000000000000..a1b6922128570ec4c213866fb5172d8f091475eb --- /dev/null +++ b/api/tests/integration_tests/model_runtime/cohere/test_rerank.py @@ -0,0 +1,40 @@ +import os + +import pytest + +from core.model_runtime.entities.rerank_entities import RerankResult +from core.model_runtime.errors.validate 
import CredentialsValidateFailedError +from core.model_runtime.model_providers.cohere.rerank.rerank import CohereRerankModel + + +def test_validate_credentials(): + model = CohereRerankModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials(model="rerank-english-v2.0", credentials={"api_key": "invalid_key"}) + + model.validate_credentials(model="rerank-english-v2.0", credentials={"api_key": os.environ.get("COHERE_API_KEY")}) + + +def test_invoke_model(): + model = CohereRerankModel() + + result = model.invoke( + model="rerank-english-v2.0", + credentials={"api_key": os.environ.get("COHERE_API_KEY")}, + query="What is the capital of the United States?", + docs=[ + "Carson City is the capital city of the American state of Nevada. At the 2010 United States " + "Census, Carson City had a population of 55,274.", + "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) " + "is the capital of the United States. It is a federal district. The President of the USA and many major " + "national government offices are in the territory. 
This makes it the political center of the United " + "States of America.", + ], + score_threshold=0.8, + ) + + assert isinstance(result, RerankResult) + assert len(result.docs) == 1 + assert result.docs[0].index == 1 + assert result.docs[0].score >= 0.8 diff --git a/api/tests/integration_tests/model_runtime/cohere/test_text_embedding.py b/api/tests/integration_tests/model_runtime/cohere/test_text_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..ae26d36635d1b599752e25a1eef40c1119ed5e4a --- /dev/null +++ b/api/tests/integration_tests/model_runtime/cohere/test_text_embedding.py @@ -0,0 +1,45 @@ +import os + +import pytest + +from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.cohere.text_embedding.text_embedding import CohereTextEmbeddingModel + + +def test_validate_credentials(): + model = CohereTextEmbeddingModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials(model="embed-multilingual-v3.0", credentials={"api_key": "invalid_key"}) + + model.validate_credentials( + model="embed-multilingual-v3.0", credentials={"api_key": os.environ.get("COHERE_API_KEY")} + ) + + +def test_invoke_model(): + model = CohereTextEmbeddingModel() + + result = model.invoke( + model="embed-multilingual-v3.0", + credentials={"api_key": os.environ.get("COHERE_API_KEY")}, + texts=["hello", "world", " ".join(["long_text"] * 100), " ".join(["another_long_text"] * 100)], + user="abc-123", + ) + + assert isinstance(result, TextEmbeddingResult) + assert len(result.embeddings) == 4 + assert result.usage.total_tokens == 811 + + +def test_get_num_tokens(): + model = CohereTextEmbeddingModel() + + num_tokens = model.get_num_tokens( + model="embed-multilingual-v3.0", + credentials={"api_key": os.environ.get("COHERE_API_KEY")}, + texts=["hello", "world"], + ) + + assert num_tokens == 
3 diff --git a/api/tests/integration_tests/model_runtime/fireworks/__init__.py b/api/tests/integration_tests/model_runtime/fireworks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/tests/integration_tests/model_runtime/fireworks/test_llm.py b/api/tests/integration_tests/model_runtime/fireworks/test_llm.py new file mode 100644 index 0000000000000000000000000000000000000000..699ca293a2fca828a897aac9894beb0093e5249c --- /dev/null +++ b/api/tests/integration_tests/model_runtime/fireworks/test_llm.py @@ -0,0 +1,186 @@ +import os +from collections.abc import Generator + +import pytest + +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessageTool, + SystemPromptMessage, + UserPromptMessage, +) +from core.model_runtime.entities.model_entities import AIModelEntity +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.fireworks.llm.llm import FireworksLargeLanguageModel + +"""FOR MOCK FIXTURES, DO NOT REMOVE""" +from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock + + +def test_predefined_models(): + model = FireworksLargeLanguageModel() + model_schemas = model.predefined_models() + + assert len(model_schemas) >= 1 + assert isinstance(model_schemas[0], AIModelEntity) + + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) +def test_validate_credentials_for_chat_model(setup_openai_mock): + model = FireworksLargeLanguageModel() + + with pytest.raises(CredentialsValidateFailedError): + # model name to gpt-3.5-turbo because of mocking + model.validate_credentials(model="gpt-3.5-turbo", credentials={"fireworks_api_key": "invalid_key"}) + + model.validate_credentials( + model="accounts/fireworks/models/llama-v3p1-8b-instruct", + 
credentials={"fireworks_api_key": os.environ.get("FIREWORKS_API_KEY")}, + ) + + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) +def test_invoke_chat_model(setup_openai_mock): + model = FireworksLargeLanguageModel() + + result = model.invoke( + model="accounts/fireworks/models/llama-v3p1-8b-instruct", + credentials={"fireworks_api_key": os.environ.get("FIREWORKS_API_KEY")}, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Hello World!"), + ], + model_parameters={ + "temperature": 0.0, + "top_p": 1.0, + "presence_penalty": 0.0, + "frequency_penalty": 0.0, + "max_tokens": 10, + }, + stop=["How"], + stream=False, + user="foo", + ) + + assert isinstance(result, LLMResult) + assert len(result.message.content) > 0 + + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) +def test_invoke_chat_model_with_tools(setup_openai_mock): + model = FireworksLargeLanguageModel() + + result = model.invoke( + model="accounts/fireworks/models/llama-v3p1-8b-instruct", + credentials={"fireworks_api_key": os.environ.get("FIREWORKS_API_KEY")}, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage( + content="what's the weather today in London?", + ), + ], + model_parameters={"temperature": 0.0, "max_tokens": 100}, + tools=[ + PromptMessageTool( + name="get_weather", + description="Determine weather in my location", + parameters={ + "type": "object", + "properties": { + "location": {"type": "string", "description": "The city and state e.g. 
San Francisco, CA"}, + "unit": {"type": "string", "enum": ["c", "f"]}, + }, + "required": ["location"], + }, + ), + PromptMessageTool( + name="get_stock_price", + description="Get the current stock price", + parameters={ + "type": "object", + "properties": {"symbol": {"type": "string", "description": "The stock symbol"}}, + "required": ["symbol"], + }, + ), + ], + stream=False, + user="foo", + ) + + assert isinstance(result, LLMResult) + assert isinstance(result.message, AssistantPromptMessage) + assert len(result.message.tool_calls) > 0 + + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) +def test_invoke_stream_chat_model(setup_openai_mock): + model = FireworksLargeLanguageModel() + + result = model.invoke( + model="accounts/fireworks/models/llama-v3p1-8b-instruct", + credentials={"fireworks_api_key": os.environ.get("FIREWORKS_API_KEY")}, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Hello World!"), + ], + model_parameters={"temperature": 0.0, "max_tokens": 100}, + stream=True, + user="foo", + ) + + assert isinstance(result, Generator) + + for chunk in result: + assert isinstance(chunk, LLMResultChunk) + assert isinstance(chunk.delta, LLMResultChunkDelta) + assert isinstance(chunk.delta.message, AssistantPromptMessage) + assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True + if chunk.delta.finish_reason is not None: + assert chunk.delta.usage is not None + assert chunk.delta.usage.completion_tokens > 0 + + +def test_get_num_tokens(): + model = FireworksLargeLanguageModel() + + num_tokens = model.get_num_tokens( + model="accounts/fireworks/models/llama-v3p1-8b-instruct", + credentials={"fireworks_api_key": os.environ.get("FIREWORKS_API_KEY")}, + prompt_messages=[UserPromptMessage(content="Hello World!")], + ) + + assert num_tokens == 10 + + num_tokens = model.get_num_tokens( + 
model="accounts/fireworks/models/llama-v3p1-8b-instruct", + credentials={"fireworks_api_key": os.environ.get("FIREWORKS_API_KEY")}, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Hello World!"), + ], + tools=[ + PromptMessageTool( + name="get_weather", + description="Determine weather in my location", + parameters={ + "type": "object", + "properties": { + "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"}, + "unit": {"type": "string", "enum": ["c", "f"]}, + }, + "required": ["location"], + }, + ), + ], + ) + + assert num_tokens == 77 diff --git a/api/tests/integration_tests/model_runtime/fireworks/test_provider.py b/api/tests/integration_tests/model_runtime/fireworks/test_provider.py new file mode 100644 index 0000000000000000000000000000000000000000..a68cf1a1a8fbda3b396b3967864bbde3f155a828 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/fireworks/test_provider.py @@ -0,0 +1,17 @@ +import os + +import pytest + +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.fireworks.fireworks import FireworksProvider +from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock + + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) +def test_validate_provider_credentials(setup_openai_mock): + provider = FireworksProvider() + + with pytest.raises(CredentialsValidateFailedError): + provider.validate_provider_credentials(credentials={}) + + provider.validate_provider_credentials(credentials={"fireworks_api_key": os.environ.get("FIREWORKS_API_KEY")}) diff --git a/api/tests/integration_tests/model_runtime/fireworks/test_text_embedding.py b/api/tests/integration_tests/model_runtime/fireworks/test_text_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..7bf723b3a93742e5f457a0c13486ae498d7bfb8b --- /dev/null +++ 
b/api/tests/integration_tests/model_runtime/fireworks/test_text_embedding.py @@ -0,0 +1,54 @@ +import os + +import pytest + +from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.fireworks.text_embedding.text_embedding import FireworksTextEmbeddingModel +from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock + + +@pytest.mark.parametrize("setup_openai_mock", [["text_embedding"]], indirect=True) +def test_validate_credentials(setup_openai_mock): + model = FireworksTextEmbeddingModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model="nomic-ai/nomic-embed-text-v1.5", credentials={"fireworks_api_key": "invalid_key"} + ) + + model.validate_credentials( + model="nomic-ai/nomic-embed-text-v1.5", credentials={"fireworks_api_key": os.environ.get("FIREWORKS_API_KEY")} + ) + + +@pytest.mark.parametrize("setup_openai_mock", [["text_embedding"]], indirect=True) +def test_invoke_model(setup_openai_mock): + model = FireworksTextEmbeddingModel() + + result = model.invoke( + model="nomic-ai/nomic-embed-text-v1.5", + credentials={ + "fireworks_api_key": os.environ.get("FIREWORKS_API_KEY"), + }, + texts=["hello", "world", " ".join(["long_text"] * 100), " ".join(["another_long_text"] * 100)], + user="foo", + ) + + assert isinstance(result, TextEmbeddingResult) + assert len(result.embeddings) == 4 + assert result.usage.total_tokens == 2 + + +def test_get_num_tokens(): + model = FireworksTextEmbeddingModel() + + num_tokens = model.get_num_tokens( + model="nomic-ai/nomic-embed-text-v1.5", + credentials={ + "fireworks_api_key": os.environ.get("FIREWORKS_API_KEY"), + }, + texts=["hello", "world"], + ) + + assert num_tokens == 2 diff --git a/api/tests/integration_tests/model_runtime/fishaudio/__init__.py b/api/tests/integration_tests/model_runtime/fishaudio/__init__.py 
new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/tests/integration_tests/model_runtime/fishaudio/test_provider.py b/api/tests/integration_tests/model_runtime/fishaudio/test_provider.py new file mode 100644 index 0000000000000000000000000000000000000000..3526574b61323841d9ffd74beccf75669a64e180 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/fishaudio/test_provider.py @@ -0,0 +1,33 @@ +import os + +import httpx +import pytest + +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.fishaudio.fishaudio import FishAudioProvider +from tests.integration_tests.model_runtime.__mock.fishaudio import setup_fishaudio_mock + + +@pytest.mark.parametrize("setup_fishaudio_mock", [["list-models"]], indirect=True) +def test_validate_provider_credentials(setup_fishaudio_mock): + print("-----", httpx.get) + provider = FishAudioProvider() + + with pytest.raises(CredentialsValidateFailedError): + provider.validate_provider_credentials( + credentials={ + "api_key": "bad_api_key", + "api_base": os.environ.get("FISH_AUDIO_API_BASE", "https://api.fish.audio"), + "use_public_models": "false", + "latency": "normal", + } + ) + + provider.validate_provider_credentials( + credentials={ + "api_key": os.environ.get("FISH_AUDIO_API_KEY", "test"), + "api_base": os.environ.get("FISH_AUDIO_API_BASE", "https://api.fish.audio"), + "use_public_models": "false", + "latency": "normal", + } + ) diff --git a/api/tests/integration_tests/model_runtime/fishaudio/test_tts.py b/api/tests/integration_tests/model_runtime/fishaudio/test_tts.py new file mode 100644 index 0000000000000000000000000000000000000000..f61fee28b98e308c9fd0516ae4d391286d9beee6 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/fishaudio/test_tts.py @@ -0,0 +1,32 @@ +import os + +import pytest + +from core.model_runtime.model_providers.fishaudio.tts.tts import ( + 
FishAudioText2SpeechModel, +) +from tests.integration_tests.model_runtime.__mock.fishaudio import setup_fishaudio_mock + + +@pytest.mark.parametrize("setup_fishaudio_mock", [["tts"]], indirect=True) +def test_invoke_model(setup_fishaudio_mock): + model = FishAudioText2SpeechModel() + + result = model.invoke( + model="tts-default", + tenant_id="test", + credentials={ + "api_key": os.environ.get("FISH_AUDIO_API_KEY", "test"), + "api_base": os.environ.get("FISH_AUDIO_API_BASE", "https://api.fish.audio"), + "use_public_models": "false", + "latency": "normal", + }, + content_text="Hello, world!", + voice="03397b4c4be74759b72533b663fbd001", + ) + + content = b"" + for chunk in result: + content += chunk + + assert content != b"" diff --git a/api/tests/integration_tests/model_runtime/gitee_ai/__init__.py b/api/tests/integration_tests/model_runtime/gitee_ai/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/tests/integration_tests/model_runtime/gitee_ai/test_llm.py b/api/tests/integration_tests/model_runtime/gitee_ai/test_llm.py new file mode 100644 index 0000000000000000000000000000000000000000..753c52ce31d4524c55a5c45dfdc4e2774d1de649 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/gitee_ai/test_llm.py @@ -0,0 +1,132 @@ +import os +from collections.abc import Generator + +import pytest + +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessageTool, + SystemPromptMessage, + UserPromptMessage, +) +from core.model_runtime.entities.model_entities import AIModelEntity +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.gitee_ai.llm.llm import GiteeAILargeLanguageModel + + +def test_predefined_models(): + model = GiteeAILargeLanguageModel() + model_schemas = 
model.predefined_models() + + assert len(model_schemas) >= 1 + assert isinstance(model_schemas[0], AIModelEntity) + + +def test_validate_credentials_for_chat_model(): + model = GiteeAILargeLanguageModel() + + with pytest.raises(CredentialsValidateFailedError): + # model name to gpt-3.5-turbo because of mocking + model.validate_credentials(model="gpt-3.5-turbo", credentials={"api_key": "invalid_key"}) + + model.validate_credentials( + model="Qwen2-7B-Instruct", + credentials={"api_key": os.environ.get("GITEE_AI_API_KEY")}, + ) + + +def test_invoke_chat_model(): + model = GiteeAILargeLanguageModel() + + result = model.invoke( + model="Qwen2-7B-Instruct", + credentials={"api_key": os.environ.get("GITEE_AI_API_KEY")}, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Hello World!"), + ], + model_parameters={ + "temperature": 0.0, + "top_p": 1.0, + "presence_penalty": 0.0, + "frequency_penalty": 0.0, + "max_tokens": 10, + "stream": False, + }, + stop=["How"], + stream=False, + user="foo", + ) + + assert isinstance(result, LLMResult) + assert len(result.message.content) > 0 + + +def test_invoke_stream_chat_model(): + model = GiteeAILargeLanguageModel() + + result = model.invoke( + model="Qwen2-7B-Instruct", + credentials={"api_key": os.environ.get("GITEE_AI_API_KEY")}, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Hello World!"), + ], + model_parameters={"temperature": 0.0, "max_tokens": 100, "stream": False}, + stream=True, + user="foo", + ) + + assert isinstance(result, Generator) + + for chunk in result: + assert isinstance(chunk, LLMResultChunk) + assert isinstance(chunk.delta, LLMResultChunkDelta) + assert isinstance(chunk.delta.message, AssistantPromptMessage) + assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True + if chunk.delta.finish_reason is not None: + assert 
chunk.delta.usage is not None + + +def test_get_num_tokens(): + model = GiteeAILargeLanguageModel() + + num_tokens = model.get_num_tokens( + model="Qwen2-7B-Instruct", + credentials={"api_key": os.environ.get("GITEE_AI_API_KEY")}, + prompt_messages=[UserPromptMessage(content="Hello World!")], + ) + + assert num_tokens == 10 + + num_tokens = model.get_num_tokens( + model="Qwen2-7B-Instruct", + credentials={"api_key": os.environ.get("GITEE_AI_API_KEY")}, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Hello World!"), + ], + tools=[ + PromptMessageTool( + name="get_weather", + description="Determine weather in my location", + parameters={ + "type": "object", + "properties": { + "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"}, + "unit": {"type": "string", "enum": ["c", "f"]}, + }, + "required": ["location"], + }, + ), + ], + ) + + assert num_tokens == 77 diff --git a/api/tests/integration_tests/model_runtime/gitee_ai/test_provider.py b/api/tests/integration_tests/model_runtime/gitee_ai/test_provider.py new file mode 100644 index 0000000000000000000000000000000000000000..f12ed54a4578969c936ee65b3a8f86e807a33068 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/gitee_ai/test_provider.py @@ -0,0 +1,15 @@ +import os + +import pytest + +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.gitee_ai.gitee_ai import GiteeAIProvider + + +def test_validate_provider_credentials(): + provider = GiteeAIProvider() + + with pytest.raises(CredentialsValidateFailedError): + provider.validate_provider_credentials(credentials={"api_key": "invalid_key"}) + + provider.validate_provider_credentials(credentials={"api_key": os.environ.get("GITEE_AI_API_KEY")}) diff --git a/api/tests/integration_tests/model_runtime/gitee_ai/test_rerank.py 
b/api/tests/integration_tests/model_runtime/gitee_ai/test_rerank.py new file mode 100644 index 0000000000000000000000000000000000000000..0e5914a61f7f43842fe91313944c90fa4c67c4aa --- /dev/null +++ b/api/tests/integration_tests/model_runtime/gitee_ai/test_rerank.py @@ -0,0 +1,47 @@ +import os + +import pytest + +from core.model_runtime.entities.rerank_entities import RerankResult +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.gitee_ai.rerank.rerank import GiteeAIRerankModel + + +def test_validate_credentials(): + model = GiteeAIRerankModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model="bge-reranker-v2-m3", + credentials={"api_key": "invalid_key"}, + ) + + model.validate_credentials( + model="bge-reranker-v2-m3", + credentials={ + "api_key": os.environ.get("GITEE_AI_API_KEY"), + }, + ) + + +def test_invoke_model(): + model = GiteeAIRerankModel() + result = model.invoke( + model="bge-reranker-v2-m3", + credentials={ + "api_key": os.environ.get("GITEE_AI_API_KEY"), + }, + query="What is the capital of the United States?", + docs=[ + "Carson City is the capital city of the American state of Nevada. At the 2010 United States " + "Census, Carson City had a population of 55,274.", + "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that " + "are a political division controlled by the United States. 
Its capital is Saipan.", + ], + top_n=1, + score_threshold=0.01, + ) + + assert isinstance(result, RerankResult) + assert len(result.docs) == 1 + assert result.docs[0].score >= 0.01 diff --git a/api/tests/integration_tests/model_runtime/gitee_ai/test_speech2text.py b/api/tests/integration_tests/model_runtime/gitee_ai/test_speech2text.py new file mode 100644 index 0000000000000000000000000000000000000000..4a01453fdd1cdaeb553b8265a09c076345c17a8e --- /dev/null +++ b/api/tests/integration_tests/model_runtime/gitee_ai/test_speech2text.py @@ -0,0 +1,45 @@ +import os + +import pytest + +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.gitee_ai.speech2text.speech2text import GiteeAISpeech2TextModel + + +def test_validate_credentials(): + model = GiteeAISpeech2TextModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model="whisper-base", + credentials={"api_key": "invalid_key"}, + ) + + model.validate_credentials( + model="whisper-base", + credentials={"api_key": os.environ.get("GITEE_AI_API_KEY")}, + ) + + +def test_invoke_model(): + model = GiteeAISpeech2TextModel() + + # Get the directory of the current file + current_dir = os.path.dirname(os.path.abspath(__file__)) + + # Get assets directory + assets_dir = os.path.join(os.path.dirname(current_dir), "assets") + + # Construct the path to the audio file + audio_file_path = os.path.join(assets_dir, "audio.mp3") + + # Open the file and get the file object + with open(audio_file_path, "rb") as audio_file: + file = audio_file + + result = model.invoke( + model="whisper-base", credentials={"api_key": os.environ.get("GITEE_AI_API_KEY")}, file=file + ) + + assert isinstance(result, str) + assert result == "1 2 3 4 5 6 7 8 9 10" diff --git a/api/tests/integration_tests/model_runtime/gitee_ai/test_text_embedding.py b/api/tests/integration_tests/model_runtime/gitee_ai/test_text_embedding.py new file mode 100644 index 
0000000000000000000000000000000000000000..34648f0bc8ae78e69b17dd7944cc0f80f8af62f5 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/gitee_ai/test_text_embedding.py @@ -0,0 +1,46 @@ +import os + +import pytest + +from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.gitee_ai.text_embedding.text_embedding import GiteeAIEmbeddingModel + + +def test_validate_credentials(): + model = GiteeAIEmbeddingModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials(model="bge-large-zh-v1.5", credentials={"api_key": "invalid_key"}) + + model.validate_credentials(model="bge-large-zh-v1.5", credentials={"api_key": os.environ.get("GITEE_AI_API_KEY")}) + + +def test_invoke_model(): + model = GiteeAIEmbeddingModel() + + result = model.invoke( + model="bge-large-zh-v1.5", + credentials={ + "api_key": os.environ.get("GITEE_AI_API_KEY"), + }, + texts=["hello", "world"], + user="user", + ) + + assert isinstance(result, TextEmbeddingResult) + assert len(result.embeddings) == 2 + + +def test_get_num_tokens(): + model = GiteeAIEmbeddingModel() + + num_tokens = model.get_num_tokens( + model="bge-large-zh-v1.5", + credentials={ + "api_key": os.environ.get("GITEE_AI_API_KEY"), + }, + texts=["hello", "world"], + ) + + assert num_tokens == 2 diff --git a/api/tests/integration_tests/model_runtime/gitee_ai/test_tts.py b/api/tests/integration_tests/model_runtime/gitee_ai/test_tts.py new file mode 100644 index 0000000000000000000000000000000000000000..9f18161a7bb74e092077f716eb6191079e1ca9b1 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/gitee_ai/test_tts.py @@ -0,0 +1,23 @@ +import os + +from core.model_runtime.model_providers.gitee_ai.tts.tts import GiteeAIText2SpeechModel + + +def test_invoke_model(): + model = GiteeAIText2SpeechModel() + + result = model.invoke( + model="speecht5_tts", + 
tenant_id="test", + credentials={ + "api_key": os.environ.get("GITEE_AI_API_KEY"), + }, + content_text="Hello, world!", + voice="", + ) + + content = b"" + for chunk in result: + content += chunk + + assert content != b"" diff --git a/api/tests/integration_tests/model_runtime/google/__init__.py b/api/tests/integration_tests/model_runtime/google/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/tests/integration_tests/model_runtime/google/test_llm.py b/api/tests/integration_tests/model_runtime/google/test_llm.py new file mode 100644 index 0000000000000000000000000000000000000000..65357be6586143248ddbaf9829d422ca01aebc3c --- /dev/null +++ b/api/tests/integration_tests/model_runtime/google/test_llm.py @@ -0,0 +1,183 @@ +import os +from collections.abc import Generator + +import pytest + +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + ImagePromptMessageContent, + SystemPromptMessage, + TextPromptMessageContent, + UserPromptMessage, +) +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.google.llm.llm import GoogleLargeLanguageModel +from tests.integration_tests.model_runtime.__mock.google import setup_google_mock, setup_mock_redis + + +@pytest.mark.parametrize("setup_google_mock", [["none"]], indirect=True) +def test_validate_credentials(setup_google_mock): + model = GoogleLargeLanguageModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials(model="gemini-pro", credentials={"google_api_key": "invalid_key"}) + + model.validate_credentials(model="gemini-pro", credentials={"google_api_key": os.environ.get("GOOGLE_API_KEY")}) + + +@pytest.mark.parametrize("setup_google_mock", [["none"]], indirect=True) +def test_invoke_model(setup_google_mock): 
+ model = GoogleLargeLanguageModel() + + response = model.invoke( + model="gemini-1.5-pro", + credentials={"google_api_key": os.environ.get("GOOGLE_API_KEY")}, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Give me your worst dad joke or i will unplug you"), + AssistantPromptMessage( + content="Why did the scarecrow win an award? Because he was outstanding in his field!" + ), + UserPromptMessage( + content=[ + TextPromptMessageContent(data="ok something snarkier pls"), + TextPromptMessageContent(data="i may still unplug you"), + ] + ), + ], + model_parameters={"temperature": 0.5, "top_p": 1.0, "max_output_tokens": 2048}, + stop=["How"], + stream=False, + user="abc-123", + ) + + assert isinstance(response, LLMResult) + assert len(response.message.content) > 0 + + +@pytest.mark.parametrize("setup_google_mock", [["none"]], indirect=True) +def test_invoke_stream_model(setup_google_mock): + model = GoogleLargeLanguageModel() + + response = model.invoke( + model="gemini-1.5-pro", + credentials={"google_api_key": os.environ.get("GOOGLE_API_KEY")}, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Give me your worst dad joke or i will unplug you"), + AssistantPromptMessage( + content="Why did the scarecrow win an award? Because he was outstanding in his field!" 
+ ), + UserPromptMessage( + content=[ + TextPromptMessageContent(data="ok something snarkier pls"), + TextPromptMessageContent(data="i may still unplug you"), + ] + ), + ], + model_parameters={"temperature": 0.2, "top_k": 5, "max_tokens": 2048}, + stream=True, + user="abc-123", + ) + + assert isinstance(response, Generator) + + for chunk in response: + assert isinstance(chunk, LLMResultChunk) + assert isinstance(chunk.delta, LLMResultChunkDelta) + assert isinstance(chunk.delta.message, AssistantPromptMessage) + assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True + + +@pytest.mark.parametrize("setup_google_mock", [["none"]], indirect=True) +def test_invoke_chat_model_with_vision(setup_google_mock, setup_mock_redis): + model = GoogleLargeLanguageModel() + + result = model.invoke( + model="gemini-1.5-pro", + credentials={"google_api_key": os.environ.get("GOOGLE_API_KEY")}, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage( + content=[ + TextPromptMessageContent(data="what do you see?"), + ImagePromptMessageContent( + mime_type="image/png", + format="png", + 
base64_data="iVBORw0KGgoAAAANSUhEUgAAAE4AAABMCAYAAADDYoEWAAAMQGlDQ1BJQ0MgUHJvZmlsZQAASImVVwdYU8kWnluSkEBoAQSkhN4EkRpASggt9I4gKiEJEEqMgaBiRxcVXLuIgA1dFVGwAmJBETuLYu+LBRVlXSzYlTcpoOu+8r35vrnz33/O/OfMmbllAFA7zhGJclF1APKEBeLYYH/6uOQUOukpIAEdoAy0gA2Hmy9iRkeHA1iG2r+Xd9cBIm2v2Eu1/tn/X4sGj5/PBQCJhjidl8/Ng/gAAHg1VyQuAIAo5c2mFoikGFagJYYBQrxIijPluFqK0+V4j8wmPpYFcTsASiocjjgTANVLkKcXcjOhhmo/xI5CnkAIgBodYp+8vMk8iNMgtoY2Ioil+oz0H3Qy/6aZPqzJ4WQOY/lcZEUpQJAvyuVM/z/T8b9LXq5kyIclrCpZ4pBY6Zxh3m7mTA6TYhWI+4TpkVEQa0L8QcCT2UOMUrIkIQlye9SAm8+COYMrDVBHHicgDGIDiIOEuZHhCj49QxDEhhjuEHSaoIAdD7EuxIv4+YFxCptN4smxCl9oY4aYxVTwZzlimV+pr/uSnASmQv91Fp+t0MdUi7LikyCmQGxeKEiMhFgVYof8nLgwhc3YoixW5JCNWBIrjd8c4li+MNhfro8VZoiDYhX2pXn5Q/PFNmUJ2JEKvK8gKz5Enh+sncuRxQ/ngl3iC5kJQzr8/HHhQ3Ph8QMC5XPHnvGFCXEKnQ+iAv9Y+VicIsqNVtjjpvzcYClvCrFLfmGcYiyeWAA3pFwfzxAVRMfL48SLsjmh0fJ48OUgHLBAAKADCazpYDLIBoLOvqY+eCfvCQIcIAaZgA/sFczQiCRZjxBe40AR+BMiPsgfHucv6+WDQsh/HWblV3uQIestlI3IAU8gzgNhIBfeS2SjhMPeEsFjyAj+4Z0DKxfGmwurtP/f80Psd4YJmXAFIxnySFcbsiQGEgOIIcQgog2uj/vgXng4vPrB6oQzcI+heXy3JzwhdBEeEq4Rugm3JgmKxT9FGQG6oX6QIhfpP+YCt4Sarrg/7g3VoTKug+sDe9wF+mHivtCzK2RZirilWaH/pP23GfywGgo7siMZJY8g+5Gtfx6paqvqOqwizfWP+ZHHmj6cb9Zwz8/+WT9knwfbsJ8tsUXYfuwMdgI7hx3BmgAda8WasQ7sqBQP767Hst015C1WFk8O1BH8w9/Qykozme9Y59jr+EXeV8CfJn1HA9Zk0XSxIDOrgM6EXwQ+nS3kOoyiOzk6OQMg/b7IX19vYmTfDUSn4zs3/w8AvFsHBwcPf+dCWwHY6w4f/0PfOWsG/HQoA3D2EFciLpRzuPRCgG8JNfik6QEjYAas4XycgBvwAn4gEISCKBAPksFEGH0W3OdiMBXMBPNACSgDy8EaUAk2gi1gB9gN9oEmcAScAKfBBXAJXAN34O7pAS9AP3gHPiMIQkKoCA3RQ4wRC8QOcUIYiA8SiIQjsUgykoZkIkJEgsxE5iNlyEqkEtmM1CJ7kUPICeQc0oXcQh4gvchr5BOKoSqoFmqIWqKjUQbKRMPQeHQCmolOQYvQBehStAKtQXehjegJ9AJ6De1GX6ADGMCUMR3MBLPHGBgLi8JSsAxMjM3GSrFyrAarx1rgOl/BurE+7CNOxGk4HbeHOzgET8C5+BR8Nr4Er8R34I14O34Ff4D3498IVIIBwY7gSWATxhEyCVMJJYRywjbCQcIp+Cz1EN4RiUQdohXRHT6LycRs4gziEuJ6YgPxOLGL+Ig4QCKR9Eh2JG9SFIlDKiCVkNaRdpFaSZdJPaQPSspKxkpOSkFKKUpCpWKlcqWdSseULis9VfpMVidbkD3JUWQeeTp5GXkruYV8kdxD/kzRoFhRvCnxlGzKPEoFpZ5yinKX8kZZWdlU2UM5RlmgPFe5QnmP8lnlB8ofVTRVbFVYKqkqEpWlKttVjqvcUnlDpVItqX7
UFGoBdSm1lnqSep/6QZWm6qDKVuWpzlGtUm1Uvaz6Uo2sZqHGVJuoVqRWrrZf7aJanzpZ3VKdpc5Rn61epX5I/Yb6gAZNY4xGlEaexhKNnRrnNJ5pkjQtNQM1eZoLNLdontR8RMNoZjQWjUubT9tKO0Xr0SJqWWmxtbK1yrR2a3Vq9WtrartoJ2pP067SPqrdrYPpWOqwdXJ1luns07mu82mE4QjmCP6IxSPqR1we8V53pK6fLl+3VLdB95ruJz26XqBejt4KvSa9e/q4vq1+jP5U/Q36p/T7RmqN9BrJHVk6ct/I2waoga1BrMEMgy0GHQYDhkaGwYYiw3WGJw37jHSM/IyyjVYbHTPqNaYZ+xgLjFcbtxo/p2vTmfRcegW9nd5vYmASYiIx2WzSafLZ1Mo0wbTYtMH0nhnFjGGWYbbarM2s39zYPMJ8pnmd+W0LsgXDIstircUZi/eWVpZJlgstmyyfWelasa2KrOqs7lpTrX2tp1jXWF+1IdowbHJs1ttcskVtXW2zbKtsL9qhdm52Arv1dl2jCKM8RglH1Yy6Ya9iz7QvtK+zf+Cg4xDuUOzQ5PBytPnolNErRp8Z/c3R1THXcavjnTGaY0LHFI9pGfPaydaJ61TldNWZ6hzkPMe52fmVi50L32WDy01XmmuE60LXNtevbu5uYrd6t153c/c092r3GwwtRjRjCeOsB8HD32OOxxGPj55ungWe+zz/8rL3yvHa6fVsrNVY/titYx95m3pzvDd7d/vQfdJ8Nvl0+5r4cnxrfB/6mfnx/Lb5PWXaMLOZu5gv/R39xf4H/d+zPFmzWMcDsIDggNKAzkDNwITAysD7QaZBmUF1Qf3BrsEzgo+HEELCQlaE3GAbsrnsWnZ/qHvorND2MJWwuLDKsIfhtuHi8JYINCI0YlXE3UiLSGFkUxSIYketiroXbRU9JfpwDDEmOqYq5knsmNiZsWfiaHGT4nbGvYv3j18WfyfBOkGS0JaolpiaWJv4PikgaWVS97jR42aNu5CsnyxIbk4hpSSmbEsZGB84fs34nlTX1JLU6xOsJkybcG6i/sTciUcnqU3iTNqfRkhLStuZ9oUTxanhDKSz06vT+7ks7lruC54fbzWvl+/NX8l/muGdsTLjWaZ35qrM3izfrPKsPgFLUCl4lR2SvTH7fU5Uzvacwdyk3IY8pby0vENCTWGOsH2y0eRpk7tEdqISUfcUzylrpvSLw8Tb8pH8CfnNBVrwR75DYi35RfKg0KewqvDD1MSp+6dpTBNO65huO33x9KdFQUW/zcBncGe0zTSZOW/mg1nMWZtnI7PTZ7fNMZuzYE7P3OC5O+ZR5uXM+73YsXhl8dv5SfNbFhgumLvg0S/Bv9SVqJaIS24s9Fq4cRG+SLCoc7Hz4nWLv5XySs+XOZaVl31Zwl1y/tcxv1b8Org0Y2nnMrdlG5YTlwuXX1/hu2LHSo2VRSsfrYpY1biavrp09ds1k9acK3cp37iWslaytrsivKJ5nfm65eu+VGZVXqvyr2qoNqheXP1+PW/95Q1+G+o3Gm4s2/hpk2DTzc3BmxtrLGvKtxC3FG55sjVx65nfGL/VbtPfVrbt63bh9u4dsTvaa91ra3ca7FxWh9ZJ6np3pe66tDtgd3O9ff3mBp2Gsj1gj2TP871pe6/vC9vXtp+xv/6AxYHqg7SDpY1I4/TG/qaspu7m5OauQ6GH2lq8Wg4edji8/YjJkaqj2keXHaMcW3BssLWodeC46HjficwTj9omtd05Oe7k1faY9s5TYafOng46ffIM80zrWe+zR855njt0nnG+6YLbhcYO146Dv7v+frDTrbPxovvF5ksel1q6xnYdu+x7+cSVgCunr7KvXrgWea3resL1mzdSb3Tf5N18div31qvbhbc/35l7l3C39J76vfL7Bvdr/rD5o6Hbrfvog4AHHQ/jHt55xH304nH+4y89C55Qn5Q/NX5a+8zp2ZHeoN5Lz8c/73khevG5r+R
PjT+rX1q/PPCX318d/eP6e16JXw2+XvJG7832ty5v2waiB+6/y3v3+X3pB70POz4yPp75lPTp6eepX0hfKr7afG35Fvbt7mDe4KCII+bIfgUwWNGMDABebweAmgwADZ7PKOPl5z9ZQeRnVhkC/wnLz4iy4gZAPfx/j+mDfzc3ANizFR6/oL5aKgDRVADiPQDq7Dxch85qsnOltBDhOWBT5Nf0vHTwb4r8zPlD3D+3QKrqAn5u/wWdZ3xtG7qP3QAAADhlWElmTU0AKgAAAAgAAYdpAAQAAAABAAAAGgAAAAAAAqACAAQAAAABAAAATqADAAQAAAABAAAATAAAAADhTXUdAAARnUlEQVR4Ae2c245bR3aGi4fulizFHgUzQAYIggBB5klymfeaZ8hDBYjvAiRxkMAGkowRWx7JktjcZL7vX1Uku62Burkl5YbV5q7Tqqq1/v3XqgMpL95tbvftEh6NwPLRLS4NgsAFuDOJcAHuAtyZCJzZ7MK4C3BnInBmswvjLsCdicCZzS6MOxO49Znt0uz3//CPbbv6srXFrq0W9Q6Wi0VbLPn4R8x/jSLiu3nrl8s9dcartlwtKdmTbm21XranN6v27Mm6XV8t25fP1+3Pn1+1r4if3Czbk+t9u1rR6f9jmAXc1P6sbaevQGbfdgGJeA8ke0AQsCYYgiYgPR1QyVO+3wvcMm2WO0G2PeWkX79btp839AG4//UjYC62gDsB2rI9f7pov3q2bX/9F1ftBWAufTufOcwCrnTtR90dOdHoNgCJeAbUkuM5TsWAW5W9gfkE83ZkUHg0oAyAwbm927a2ebVoP/xx2f7jD1uYuG9/89tF+/VXK1hq+88TZgG32O1g2r7tpRdBM8fUTM7pyR8SYddgxkJErUszHti7U44CpzyEo16syNtx+qgy+1og7RMetpev9+3rb3bt+c2u/ebFsv3uL1ftiqn+qcMs4HY7jNQpEfadNU5VqeHUTJkgUbaPDxRADdZ8jU9LHoJYnwLUtgWN4ObDC7Kdr8Hp7d9qMTW8gt23V1zyvPrD1H56e9t+99vr9uJLprBDfaIw69U4dQRCIw2JdVIjbUzecj+7qYyPpZHiAbDaJwsXyMhQEQ0pq6sAp7hMS2XGqykdA2iy4EUtF6v206ur9k/fbNo//+frtt2OaW/rjxtmAaeNGqihBY5xfVQzQEZfoSH0KHgkrbD/CX6vPIqlSTU61vVCovRSbEwbIS851vj23Q+tff3vu/bzu5I7tvs4qVnADTa5FCbNC86qCLN2E1MxKKroYB2pgSz2RLbbVcVkSJhOKxIDjGxn+nSuqes2JlKuG8fA/IzPXazbj68X7et/27UfX7GifORwOuSju47h/c3beKfRFO74CNA04YP0ZT2/YzERFGojc9pmDG47/wyDZwJjiX4wwJNer1dZPJbs5/xzK5Ppzp7SQZBszNy22U7tX7/dtFdvJrv8aGE2cDJLoPycBgHSgICJUQLo8nmUo6y7oH0S5Lu/FGhDQULCfIooATw3yyOQQ46eYVpYiaBMTFtAFPR307r9y3fbdvsRfd5Rg6HJI2Lt1qaAF6TEqoxWdVdYSHawezCvAHLjW7Jh2QGcUkDDT4Og2OfSFRVkxipcAJUZARC5FVRbeRpB1hVY6r25XQHexIZ96Hfa++PTs4Dbi8rQg7imWQG27/uEgCTCssk/WWg7GwJWwDQ36PceGzQ+x7jOtgNogkIIpsZiFMdXoEfOPUlh3l5ulu2/X6bJ7Mc84Bw+xgOKzJqM0VKm8WYlVMqt61gFKNtQKeZ6o7Ls/aqEeYooJXDIZ9uiT0uZ5UxPUJNlYdoAK62qHfM7unz3/bb9/Ha+v3u/tn3AD0XOrnxAZdpNYZILgoxyGk4BqMCbssq66dXv6RdFkiB6Rj2u3N1npiMw1dQjF4oJW/kzy6VdMRFA9Xd8VvhCLxCyYUYkvhHZb7+fotvdUR6XmwXcYI1DangAA6yspgBj/dRjp6L+Rbm
SPaaxuuMnGEeVAhBF4pSapAFG5gUo60rAHmpVtcz0sR2aBZW8NAB9+W7dXr9N0dmPmUcu10pWrq7kQQvBQXn1dUsgoM4ej12TtyBknG51PEMGOV2TLLVZ/GLvLMBYHsYJhg7fuMBx6tq3LFu7aBxxD9jKFiO7Thbwcv7n5dS+/ML0eWEWcBqoptk+mEQp2aTG+rbmBYA+D6MyMwMAdepKsX5QpnglFZyZ5k4tDYsI/Y1pF7CRq22HoHXgGEOwgodvgH79INnW3tlFIVVQvkBXg1dvF3z27fkTGzw+zALOPZluVoVkV4yLHoBB3VBJUNyo6uEWXAyIkruC2OQjbVeppxkm8+iti2mySsM1EPYGKBcEyul3LKTW1+pr+wLRstwP0J8a2K95Txf/+6q1ZzeUDEXt/oFhHnA4fJYCBtawYlWmlsrJBEHhP43bi9Rq1Z0ymlK3Z/QCRqA5YfaNLZJWEACn929eluXlUGO8CgMrHWYi441S2tsFebLRL5RWL0e0nL64SEEf2sjMR4ZZwA0Ddfziclz1eN8yDn1qAaHSq3G0FEQXjABDo51sJVNyGnA0QlAPL4LOApzMo0mY1sUFbQBj8xTzYhKrROYF5VGIftR1uW3+3uiWU8XnBw7l3HIYVG/P/djYgMZoyrTJrci0n2qPZVnNFV913viW6btGzsXBT6aW3VKmsauVTFOc2DxpP5YJYLBBeCUixE71IlGBR2EF+6OugHbP12Ddoj29HgIPj+cxDiPDFGINzB8sKhLh0Ui4gOgDI8deb8FiwYxlteWhLHWTlmOzhkxLAObPIkFqS8+bbG5BdgWiAmJTwXdqZ7oysktzdKC/BWMWiAJNpyP0ZPTMItRy7fTi2RB4eDwLuIkpCma1gob/Dsw7zcKAMf3txiCot8c42ZCDPu3WAqRMJAGEk4cACaLzSZsFRhAE9QoAtXcwTX92XDT0sxTQXJYHdDJin0KfVN8PmzNvnOYBx5XNlik4giumihb7tJ60ezgNhgXuXgRNttxunZYAj7uzbL3nUA67rm5KJWrJCyTfIVwBMh3bTkD8TqFYp6uv8RwrgJpAZmHHScqv0qWeKT48NujhAuELekyYBdz9gXJQ53DvDh3tU62xTtN8bQhzzE9OccAK8wA2ez2k3cNtN7wM/RZs9M5NkNZoee0H2rmhLr8miPV9roAZtN1RHV/gDb7EoUtXKeXjYXUBN0oeFs8CbrtlhZRGPZSSZNyI9gA+TBFkelFNWxgEgCtG3wDiFqEr5Jz6y/U1DAM4QLxi2l7DNhl3w/epNTUFWGbXC7HrMQMz7WUbf8AaDQ46DYXuxLoJX6CFRzvuiPyJzCzgZIoKyqgKAx1yAGPQUWfa+GoDsqwDJNnHLF9juSz0i5VrpvqSwmsQul5dtyfrfX1zL3i0WdHHSjaKVjf0T5k7ABtxlEHbwxusgjydAY8N84BjvAx5GLfMqBW0VJEZ+pwKskQnbpnFHPzpwWo/bzkGvX51296+bu1v/+qL9usXT9rTJ07Bzh9k9HEPsxNhwhh6xLXKo3fXWf3iMkrBBz9nAbflbHm6ONxhXp8/NW26lkSleIEV9FBVI+o6ihjmffPDt+3v/+5Z+82vnsZw/fyercweB2d7wzA8mfuPEknpXTnHvQsoPd1v/aD8LODw+AxbAw/QjnEfv69u5kz6dtOiW2R6YmW7vd0C3qK94wcjf/zxZ1bRXfvqGT6U3f2G/Z6AesqotgJX477PNVmTmxfiwTSS5irqz2ybEHD6PzbMAk7lS/0BxgkTqPAUYBiAkQpTLLdKxe1D4Lbsp968uW1vXk+ZrnpsN7yL1TbmbvCl4GcPPPStZWyNcM9s++9y92ruZu2CT21q7lZ9KDcLuC3WbmGG42uA30EISOVkFynt1BBialOliF/wZHqGTa1tOfq8fbMHPL6N2iBPW2d7HfxZdWnreiN49UL0dfhLR6tBSVVwNo+TQ1U5IsHvQU4Dcry7bGNOix+SngVcwAh
YpZjTQxaNMABLLLtUFEAMEwi4kk63fGDbLTcVm82ubd7hNylzEXCa6SPdz2Vf5iUobe0jAFIq8+JHT8CjGeUjHFOj5E7MIO4THxvOaHIcwu2IOKiznyg89BTEXi6WssO8B36vkLa33Pv7/QRbEtm21c/BtIm9Yb4ho19PDg4g09aeucySdpzq3BfVx6WQqh7MkLOSkHLf2olEKni4n7xznh0VH4jnAYdy6hfVSZTvUmF54f2cU9d9XmlhvUyTlbkxIT0BWtgH4wRRgPMy7EFbAwi8ojzbNyqtH/7coWxnUHyE+rmYjbs3NCnqdwIbbM/GZ4RZwDleVskO3viSBhWjSu2Pxj7JU4bsqrzTU5YZQ7xKu73Bb8bAbo+s28NStxEyb8e+K1UAKXhOVivK7x0RUANf3zEw/smJpsr37cad9RlhFnCbzQYwfN36I+5qwxgVwRA/vOHxlneeMiaux9lymN5tTTttkZN5mbZwCYsLM550taA+zJM5gsdHsGSdQTbngN7ZlC/JrRhXIcorRJvVcp2pnjzdy+0nnErOCbOAE5x8d4oVCy4xMSFGetjfgWJ3MQFHdomxZbUwwC4B84YlzBNojUEmxmqO1tVC4VcVopUzKuXK+XArUeDVTyq85wv7xKqHsel1dfIUkl8zUXcFm8eUH7IPjWcBp8J5mYxWcWmbclhlyEIAMJm2HbSwDCHZGD9IuR1UH4MhaZ4HOAIQIJOrIxfjxOFRUMNQq8wI9EH5WNVJdcEje22ofxs3K6PlQ+OZwA2ghrFSKhiEVSqh/5JJcfodKBnntLac7wb5CKLpAs+0RguYuAhoNh2CRV1dTVFhqWhRn/u+tOsMtTph6JhOkAWsQDz1K3NHeHyYBZyK70BG5oy3SyqGumoaAhr1Aiggnm8FzXr3cQWSq++p8seM10v6LW9Elgh5kyGINXMdi1xspw2LRHwqMjJTV2KdU9c2eQ1SkXDDHL2aYf2MprVp1dFrtcBlAWB/sNuxMoJIzEfRqhMk04qXfM0n8yVDaa/DRLp1GuGSKhNz65ZEOQUSdyD0Y/adRSojsxjoz2jnNFdN3l/S+sUvnqbDsx+zgCvQMJzhPaCrlouCLBvbA43x68DhsAc7DxpTr0y39VAMBCfpSlpSUMggzRe8X4bIAWRYJqVJj6t7feMV/9Bkfeb+bYw2Czg78S3GwWtEQEPRWFMMEDAZhVTiMaWLnZZRxSexfaStPR9DAXbMj5Qs479Dm8PqqYCNEpUTVAe/GpLC3vH16hI64zkLuB1XQVsdFkED8ps40oLjj2sMAdbFwGlKRjbW6UHAFZaRJVegIpeWVafZhQ4yHahUm+5VyfOwXYFHTX8DKUNSn+fCcsN3qOd8AT3GGPEs4EYnxho9YlOnU1WTUj98GbLKWCawI5wk71DiBMoh+qjYfgXUc+nNlW+rXuqjOrknPAs4sRoHcvvNguDZNEChYOoBUUZ175z9nMBZnQ6cnncgS7uDnt3BJ49Y8axqPYLZ0gVEb2DaICyHtOUM5t2eP7AJexWaGWYBVzcdsqneoAAViyzzo3ZsC1Jeq2qBKVhlkIxDsuSRrSY6/6S6eaaFjD+B4BGmMo9X9M06kcAdMq0qU5eT+lBBc8+GqaVmCc989iHP6yVvOcr4qE8ZLijVZ8VleC/5xWDWFmN6ow6aIKX75EfdL5rfKxBJgAcwwV/zeXrFjyqqo3uy52dnMa5oU4O7svo7YMNgWrFKdsk6WBXmmS82HuKsuADjHZFGi5iBIv+9qnn/qt+qSh3JTFNjPvWDiqpnA0SexYB/ijm6q5qP85wFnIZrXQHgillpVesHh9QVaAWWAJccfo/VNrOcbmrbYn/vCR9gy2m1aUH2WOa/rv4UoKnhPODowC2Gx6jQo4Nox4ZinDL392ssIHFSZWa1rTZJD/wSy0Kn34eDpwZvP1w96+dmH25zrsQs4KSLP4GAawWSjhnFZZQFmUZxOZSTj/ne2yUhIHCjRIlFKcIU0x8
52RjZTGGlDdaQrkxk7MPrJr/gzg17r4vgJ3rMAk4/wmQDE7wJhg+fFV1xaMGiMqnXaFc5jd4FjCCIRAEmAO5aPE7lzsw0ZelHYJB0PCWscErqOJcsrbllGmhmzE/7mAXcPof544Wlqg6wTuORtvKQzjV2gVC+shaNMhc24v8iIloGmS3ogc7bD9sS884Oi0kEP89jFnDX++/hCtPVtT7kwaxOkZpmxQ/L9vgdj1r+NCtAwQ6/A9DXMXnBqZgoHDdXP7Wna/Id6PRCum7DiREqcg1UPw9Yp6MsLv/HwlM4Hp7WQ1/CGQhcgDsDNJtcgLsAdyYCZza7MO4C3JkInNnswrgLcGcicGazC+POBO7/AH5zPa/ivytzAAAAAElFTkSuQmCC", + ), + ] + ), + ], + model_parameters={"temperature": 0.3, "top_p": 0.2, "top_k": 3, "max_tokens": 100}, + stream=False, + user="abc-123", + ) + + assert isinstance(result, LLMResult) + assert len(result.message.content) > 0 + + +@pytest.mark.parametrize("setup_google_mock", [["none"]], indirect=True) +def test_invoke_chat_model_with_vision_multi_pics(setup_google_mock, setup_mock_redis): + model = GoogleLargeLanguageModel() + + result = model.invoke( + model="gemini-1.5-pro", + credentials={"google_api_key": os.environ.get("GOOGLE_API_KEY")}, + prompt_messages=[ + SystemPromptMessage(content="You are a helpful AI assistant."), + UserPromptMessage( + content=[ + TextPromptMessageContent(data="what do you see?"), + ImagePromptMessageContent( + mime_type="image/png", + format="png", + 
base64_data="iVBORw0KGgoAAAANSUhEUgAAAE4AAABMCAYAAADDYoEWAAAMQGlDQ1BJQ0MgUHJvZmlsZQAASImVVwdYU8kWnluSkEBoAQSkhN4EkRpASggt9I4gKiEJEEqMgaBiRxcVXLuIgA1dFVGwAmJBETuLYu+LBRVlXSzYlTcpoOu+8r35vrnz33/O/OfMmbllAFA7zhGJclF1APKEBeLYYH/6uOQUOukpIAEdoAy0gA2Hmy9iRkeHA1iG2r+Xd9cBIm2v2Eu1/tn/X4sGj5/PBQCJhjidl8/Ng/gAAHg1VyQuAIAo5c2mFoikGFagJYYBQrxIijPluFqK0+V4j8wmPpYFcTsASiocjjgTANVLkKcXcjOhhmo/xI5CnkAIgBodYp+8vMk8iNMgtoY2Ioil+oz0H3Qy/6aZPqzJ4WQOY/lcZEUpQJAvyuVM/z/T8b9LXq5kyIclrCpZ4pBY6Zxh3m7mTA6TYhWI+4TpkVEQa0L8QcCT2UOMUrIkIQlye9SAm8+COYMrDVBHHicgDGIDiIOEuZHhCj49QxDEhhjuEHSaoIAdD7EuxIv4+YFxCptN4smxCl9oY4aYxVTwZzlimV+pr/uSnASmQv91Fp+t0MdUi7LikyCmQGxeKEiMhFgVYof8nLgwhc3YoixW5JCNWBIrjd8c4li+MNhfro8VZoiDYhX2pXn5Q/PFNmUJ2JEKvK8gKz5Enh+sncuRxQ/ngl3iC5kJQzr8/HHhQ3Ph8QMC5XPHnvGFCXEKnQ+iAv9Y+VicIsqNVtjjpvzcYClvCrFLfmGcYiyeWAA3pFwfzxAVRMfL48SLsjmh0fJ48OUgHLBAAKADCazpYDLIBoLOvqY+eCfvCQIcIAaZgA/sFczQiCRZjxBe40AR+BMiPsgfHucv6+WDQsh/HWblV3uQIestlI3IAU8gzgNhIBfeS2SjhMPeEsFjyAj+4Z0DKxfGmwurtP/f80Psd4YJmXAFIxnySFcbsiQGEgOIIcQgog2uj/vgXng4vPrB6oQzcI+heXy3JzwhdBEeEq4Rugm3JgmKxT9FGQG6oX6QIhfpP+YCt4Sarrg/7g3VoTKug+sDe9wF+mHivtCzK2RZirilWaH/pP23GfywGgo7siMZJY8g+5Gtfx6paqvqOqwizfWP+ZHHmj6cb9Zwz8/+WT9knwfbsJ8tsUXYfuwMdgI7hx3BmgAda8WasQ7sqBQP767Hst015C1WFk8O1BH8w9/Qykozme9Y59jr+EXeV8CfJn1HA9Zk0XSxIDOrgM6EXwQ+nS3kOoyiOzk6OQMg/b7IX19vYmTfDUSn4zs3/w8AvFsHBwcPf+dCWwHY6w4f/0PfOWsG/HQoA3D2EFciLpRzuPRCgG8JNfik6QEjYAas4XycgBvwAn4gEISCKBAPksFEGH0W3OdiMBXMBPNACSgDy8EaUAk2gi1gB9gN9oEmcAScAKfBBXAJXAN34O7pAS9AP3gHPiMIQkKoCA3RQ4wRC8QOcUIYiA8SiIQjsUgykoZkIkJEgsxE5iNlyEqkEtmM1CJ7kUPICeQc0oXcQh4gvchr5BOKoSqoFmqIWqKjUQbKRMPQeHQCmolOQYvQBehStAKtQXehjegJ9AJ6De1GX6ADGMCUMR3MBLPHGBgLi8JSsAxMjM3GSrFyrAarx1rgOl/BurE+7CNOxGk4HbeHOzgET8C5+BR8Nr4Er8R34I14O34Ff4D3498IVIIBwY7gSWATxhEyCVMJJYRywjbCQcIp+Cz1EN4RiUQdohXRHT6LycRs4gziEuJ6YgPxOLGL+Ig4QCKR9Eh2JG9SFIlDKiCVkNaRdpFaSZdJPaQPSspKxkpOSkFKKUpCpWKlcqWdSseULis9VfpMVidbkD3JUWQeeTp5GXkruYV8kdxD/kzRoFhRvCnxlGzKPEoFpZ5yinKX8kZZWdlU2UM5RlmgPFe5QnmP8lnlB8ofVTRVbFVYKqkqEpWlKttVjqvcUnlDpVItqX7
UFGoBdSm1lnqSep/6QZWm6qDKVuWpzlGtUm1Uvaz6Uo2sZqHGVJuoVqRWrrZf7aJanzpZ3VKdpc5Rn61epX5I/Yb6gAZNY4xGlEaexhKNnRrnNJ5pkjQtNQM1eZoLNLdontR8RMNoZjQWjUubT9tKO0Xr0SJqWWmxtbK1yrR2a3Vq9WtrartoJ2pP067SPqrdrYPpWOqwdXJ1luns07mu82mE4QjmCP6IxSPqR1we8V53pK6fLl+3VLdB95ruJz26XqBejt4KvSa9e/q4vq1+jP5U/Q36p/T7RmqN9BrJHVk6ct/I2waoga1BrMEMgy0GHQYDhkaGwYYiw3WGJw37jHSM/IyyjVYbHTPqNaYZ+xgLjFcbtxo/p2vTmfRcegW9nd5vYmASYiIx2WzSafLZ1Mo0wbTYtMH0nhnFjGGWYbbarM2s39zYPMJ8pnmd+W0LsgXDIstircUZi/eWVpZJlgstmyyfWelasa2KrOqs7lpTrX2tp1jXWF+1IdowbHJs1ttcskVtXW2zbKtsL9qhdm52Arv1dl2jCKM8RglH1Yy6Ya9iz7QvtK+zf+Cg4xDuUOzQ5PBytPnolNErRp8Z/c3R1THXcavjnTGaY0LHFI9pGfPaydaJ61TldNWZ6hzkPMe52fmVi50L32WDy01XmmuE60LXNtevbu5uYrd6t153c/c092r3GwwtRjRjCeOsB8HD32OOxxGPj55ungWe+zz/8rL3yvHa6fVsrNVY/titYx95m3pzvDd7d/vQfdJ8Nvl0+5r4cnxrfB/6mfnx/Lb5PWXaMLOZu5gv/R39xf4H/d+zPFmzWMcDsIDggNKAzkDNwITAysD7QaZBmUF1Qf3BrsEzgo+HEELCQlaE3GAbsrnsWnZ/qHvorND2MJWwuLDKsIfhtuHi8JYINCI0YlXE3UiLSGFkUxSIYketiroXbRU9JfpwDDEmOqYq5knsmNiZsWfiaHGT4nbGvYv3j18WfyfBOkGS0JaolpiaWJv4PikgaWVS97jR42aNu5CsnyxIbk4hpSSmbEsZGB84fs34nlTX1JLU6xOsJkybcG6i/sTciUcnqU3iTNqfRkhLStuZ9oUTxanhDKSz06vT+7ks7lruC54fbzWvl+/NX8l/muGdsTLjWaZ35qrM3izfrPKsPgFLUCl4lR2SvTH7fU5Uzvacwdyk3IY8pby0vENCTWGOsH2y0eRpk7tEdqISUfcUzylrpvSLw8Tb8pH8CfnNBVrwR75DYi35RfKg0KewqvDD1MSp+6dpTBNO65huO33x9KdFQUW/zcBncGe0zTSZOW/mg1nMWZtnI7PTZ7fNMZuzYE7P3OC5O+ZR5uXM+73YsXhl8dv5SfNbFhgumLvg0S/Bv9SVqJaIS24s9Fq4cRG+SLCoc7Hz4nWLv5XySs+XOZaVl31Zwl1y/tcxv1b8Org0Y2nnMrdlG5YTlwuXX1/hu2LHSo2VRSsfrYpY1biavrp09ds1k9acK3cp37iWslaytrsivKJ5nfm65eu+VGZVXqvyr2qoNqheXP1+PW/95Q1+G+o3Gm4s2/hpk2DTzc3BmxtrLGvKtxC3FG55sjVx65nfGL/VbtPfVrbt63bh9u4dsTvaa91ra3ca7FxWh9ZJ6np3pe66tDtgd3O9ff3mBp2Gsj1gj2TP871pe6/vC9vXtp+xv/6AxYHqg7SDpY1I4/TG/qaspu7m5OauQ6GH2lq8Wg4edji8/YjJkaqj2keXHaMcW3BssLWodeC46HjficwTj9omtd05Oe7k1faY9s5TYafOng46ffIM80zrWe+zR855njt0nnG+6YLbhcYO146Dv7v+frDTrbPxovvF5ksel1q6xnYdu+x7+cSVgCunr7KvXrgWea3resL1mzdSb3Tf5N18div31qvbhbc/35l7l3C39J76vfL7Bvdr/rD5o6Hbrfvog4AHHQ/jHt55xH304nH+4y89C55Qn5Q/NX5a+8zp2ZHeoN5Lz8c/73khevG5r+R
PjT+rX1q/PPCX318d/eP6e16JXw2+XvJG7832ty5v2waiB+6/y3v3+X3pB70POz4yPp75lPTp6eepX0hfKr7afG35Fvbt7mDe4KCII+bIfgUwWNGMDABebweAmgwADZ7PKOPl5z9ZQeRnVhkC/wnLz4iy4gZAPfx/j+mDfzc3ANizFR6/oL5aKgDRVADiPQDq7Dxch85qsnOltBDhOWBT5Nf0vHTwb4r8zPlD3D+3QKrqAn5u/wWdZ3xtG7qP3QAAADhlWElmTU0AKgAAAAgAAYdpAAQAAAABAAAAGgAAAAAAAqACAAQAAAABAAAATqADAAQAAAABAAAATAAAAADhTXUdAAARnUlEQVR4Ae2c245bR3aGi4fulizFHgUzQAYIggBB5klymfeaZ8hDBYjvAiRxkMAGkowRWx7JktjcZL7vX1Uku62Burkl5YbV5q7Tqqq1/v3XqgMpL95tbvftEh6NwPLRLS4NgsAFuDOJcAHuAtyZCJzZ7MK4C3BnInBmswvjLsCdicCZzS6MOxO49Znt0uz3//CPbbv6srXFrq0W9Q6Wi0VbLPn4R8x/jSLiu3nrl8s9dcartlwtKdmTbm21XranN6v27Mm6XV8t25fP1+3Pn1+1r4if3Czbk+t9u1rR6f9jmAXc1P6sbaevQGbfdgGJeA8ke0AQsCYYgiYgPR1QyVO+3wvcMm2WO0G2PeWkX79btp839AG4//UjYC62gDsB2rI9f7pov3q2bX/9F1ftBWAufTufOcwCrnTtR90dOdHoNgCJeAbUkuM5TsWAW5W9gfkE83ZkUHg0oAyAwbm927a2ebVoP/xx2f7jD1uYuG9/89tF+/VXK1hq+88TZgG32O1g2r7tpRdBM8fUTM7pyR8SYddgxkJErUszHti7U44CpzyEo16syNtx+qgy+1og7RMetpev9+3rb3bt+c2u/ebFsv3uL1ftiqn+qcMs4HY7jNQpEfadNU5VqeHUTJkgUbaPDxRADdZ8jU9LHoJYnwLUtgWN4ObDC7Kdr8Hp7d9qMTW8gt23V1zyvPrD1H56e9t+99vr9uJLprBDfaIw69U4dQRCIw2JdVIjbUzecj+7qYyPpZHiAbDaJwsXyMhQEQ0pq6sAp7hMS2XGqykdA2iy4EUtF6v206ur9k/fbNo//+frtt2OaW/rjxtmAaeNGqihBY5xfVQzQEZfoSH0KHgkrbD/CX6vPIqlSTU61vVCovRSbEwbIS851vj23Q+tff3vu/bzu5I7tvs4qVnADTa5FCbNC86qCLN2E1MxKKroYB2pgSz2RLbbVcVkSJhOKxIDjGxn+nSuqes2JlKuG8fA/IzPXazbj68X7et/27UfX7GifORwOuSju47h/c3beKfRFO74CNA04YP0ZT2/YzERFGojc9pmDG47/wyDZwJjiX4wwJNer1dZPJbs5/xzK5Ppzp7SQZBszNy22U7tX7/dtFdvJrv8aGE2cDJLoPycBgHSgICJUQLo8nmUo6y7oH0S5Lu/FGhDQULCfIooATw3yyOQQ46eYVpYiaBMTFtAFPR307r9y3fbdvsRfd5Rg6HJI2Lt1qaAF6TEqoxWdVdYSHawezCvAHLjW7Jh2QGcUkDDT4Og2OfSFRVkxipcAJUZARC5FVRbeRpB1hVY6r25XQHexIZ96Hfa++PTs4Dbi8rQg7imWQG27/uEgCTCssk/WWg7GwJWwDQ36PceGzQ+x7jOtgNogkIIpsZiFMdXoEfOPUlh3l5ulu2/X6bJ7Mc84Bw+xgOKzJqM0VKm8WYlVMqt61gFKNtQKeZ6o7Ls/aqEeYooJXDIZ9uiT0uZ5UxPUJNlYdoAK62qHfM7unz3/bb9/Ha+v3u/tn3AD0XOrnxAZdpNYZILgoxyGk4BqMCbssq66dXv6RdFkiB6Rj2u3N1npiMw1dQjF4oJW/kzy6VdMRFA9Xd8VvhCLxCyYUYkvhHZb7+fotvdUR6XmwXcYI1DangAA6yspgBj/dRjp6L+Rbm
SPaaxuuMnGEeVAhBF4pSapAFG5gUo60rAHmpVtcz0sR2aBZW8NAB9+W7dXr9N0dmPmUcu10pWrq7kQQvBQXn1dUsgoM4ej12TtyBknG51PEMGOV2TLLVZ/GLvLMBYHsYJhg7fuMBx6tq3LFu7aBxxD9jKFiO7Thbwcv7n5dS+/ML0eWEWcBqoptk+mEQp2aTG+rbmBYA+D6MyMwMAdepKsX5QpnglFZyZ5k4tDYsI/Y1pF7CRq22HoHXgGEOwgodvgH79INnW3tlFIVVQvkBXg1dvF3z27fkTGzw+zALOPZluVoVkV4yLHoBB3VBJUNyo6uEWXAyIkruC2OQjbVeppxkm8+iti2mySsM1EPYGKBcEyul3LKTW1+pr+wLRstwP0J8a2K95Txf/+6q1ZzeUDEXt/oFhHnA4fJYCBtawYlWmlsrJBEHhP43bi9Rq1Z0ymlK3Z/QCRqA5YfaNLZJWEACn929eluXlUGO8CgMrHWYi441S2tsFebLRL5RWL0e0nL64SEEf2sjMR4ZZwA0Ddfziclz1eN8yDn1qAaHSq3G0FEQXjABDo51sJVNyGnA0QlAPL4LOApzMo0mY1sUFbQBj8xTzYhKrROYF5VGIftR1uW3+3uiWU8XnBw7l3HIYVG/P/djYgMZoyrTJrci0n2qPZVnNFV913viW6btGzsXBT6aW3VKmsauVTFOc2DxpP5YJYLBBeCUixE71IlGBR2EF+6OugHbP12Ddoj29HgIPj+cxDiPDFGINzB8sKhLh0Ui4gOgDI8deb8FiwYxlteWhLHWTlmOzhkxLAObPIkFqS8+bbG5BdgWiAmJTwXdqZ7oysktzdKC/BWMWiAJNpyP0ZPTMItRy7fTi2RB4eDwLuIkpCma1gob/Dsw7zcKAMf3txiCot8c42ZCDPu3WAqRMJAGEk4cACaLzSZsFRhAE9QoAtXcwTX92XDT0sxTQXJYHdDJin0KfVN8PmzNvnOYBx5XNlik4giumihb7tJ60ezgNhgXuXgRNttxunZYAj7uzbL3nUA67rm5KJWrJCyTfIVwBMh3bTkD8TqFYp6uv8RwrgJpAZmHHScqv0qWeKT48NujhAuELekyYBdz9gXJQ53DvDh3tU62xTtN8bQhzzE9OccAK8wA2ez2k3cNtN7wM/RZs9M5NkNZoee0H2rmhLr8miPV9roAZtN1RHV/gDb7EoUtXKeXjYXUBN0oeFs8CbrtlhZRGPZSSZNyI9gA+TBFkelFNWxgEgCtG3wDiFqEr5Jz6y/U1DAM4QLxi2l7DNhl3w/epNTUFWGbXC7HrMQMz7WUbf8AaDQ46DYXuxLoJX6CFRzvuiPyJzCzgZIoKyqgKAx1yAGPQUWfa+GoDsqwDJNnHLF9juSz0i5VrpvqSwmsQul5dtyfrfX1zL3i0WdHHSjaKVjf0T5k7ABtxlEHbwxusgjydAY8N84BjvAx5GLfMqBW0VJEZ+pwKskQnbpnFHPzpwWo/bzkGvX51296+bu1v/+qL9usXT9rTJ07Bzh9k9HEPsxNhwhh6xLXKo3fXWf3iMkrBBz9nAbflbHm6ONxhXp8/NW26lkSleIEV9FBVI+o6ihjmffPDt+3v/+5Z+82vnsZw/fyercweB2d7wzA8mfuPEknpXTnHvQsoPd1v/aD8LODw+AxbAw/QjnEfv69u5kz6dtOiW2R6YmW7vd0C3qK94wcjf/zxZ1bRXfvqGT6U3f2G/Z6AesqotgJX477PNVmTmxfiwTSS5irqz2ybEHD6PzbMAk7lS/0BxgkTqPAUYBiAkQpTLLdKxe1D4Lbsp968uW1vXk+ZrnpsN7yL1TbmbvCl4GcPPPStZWyNcM9s++9y92ruZu2CT21q7lZ9KDcLuC3WbmGG42uA30EISOVkFynt1BBialOliF/wZHqGTa1tOfq8fbMHPL6N2iBPW2d7HfxZdWnreiN49UL0dfhLR6tBSVVwNo+TQ1U5IsHvQU4Dcry7bGNOix+SngVcwAh
YpZjTQxaNMABLLLtUFEAMEwi4kk63fGDbLTcVm82ubd7hNylzEXCa6SPdz2Vf5iUobe0jAFIq8+JHT8CjGeUjHFOj5E7MIO4THxvOaHIcwu2IOKiznyg89BTEXi6WssO8B36vkLa33Pv7/QRbEtm21c/BtIm9Yb4ho19PDg4g09aeucySdpzq3BfVx6WQqh7MkLOSkHLf2olEKni4n7xznh0VH4jnAYdy6hfVSZTvUmF54f2cU9d9XmlhvUyTlbkxIT0BWtgH4wRRgPMy7EFbAwi8ojzbNyqtH/7coWxnUHyE+rmYjbs3NCnqdwIbbM/GZ4RZwDleVskO3viSBhWjSu2Pxj7JU4bsqrzTU5YZQ7xKu73Bb8bAbo+s28NStxEyb8e+K1UAKXhOVivK7x0RUANf3zEw/smJpsr37cad9RlhFnCbzQYwfN36I+5qwxgVwRA/vOHxlneeMiaux9lymN5tTTttkZN5mbZwCYsLM550taA+zJM5gsdHsGSdQTbngN7ZlC/JrRhXIcorRJvVcp2pnjzdy+0nnErOCbOAE5x8d4oVCy4xMSFGetjfgWJ3MQFHdomxZbUwwC4B84YlzBNojUEmxmqO1tVC4VcVopUzKuXK+XArUeDVTyq85wv7xKqHsel1dfIUkl8zUXcFm8eUH7IPjWcBp8J5mYxWcWmbclhlyEIAMJm2HbSwDCHZGD9IuR1UH4MhaZ4HOAIQIJOrIxfjxOFRUMNQq8wI9EH5WNVJdcEje22ofxs3K6PlQ+OZwA2ghrFSKhiEVSqh/5JJcfodKBnntLac7wb5CKLpAs+0RguYuAhoNh2CRV1dTVFhqWhRn/u+tOsMtTph6JhOkAWsQDz1K3NHeHyYBZyK70BG5oy3SyqGumoaAhr1Aiggnm8FzXr3cQWSq++p8seM10v6LW9Elgh5kyGINXMdi1xspw2LRHwqMjJTV2KdU9c2eQ1SkXDDHL2aYf2MprVp1dFrtcBlAWB/sNuxMoJIzEfRqhMk04qXfM0n8yVDaa/DRLp1GuGSKhNz65ZEOQUSdyD0Y/adRSojsxjoz2jnNFdN3l/S+sUvnqbDsx+zgCvQMJzhPaCrlouCLBvbA43x68DhsAc7DxpTr0y39VAMBCfpSlpSUMggzRe8X4bIAWRYJqVJj6t7feMV/9Bkfeb+bYw2Czg78S3GwWtEQEPRWFMMEDAZhVTiMaWLnZZRxSexfaStPR9DAXbMj5Qs479Dm8PqqYCNEpUTVAe/GpLC3vH16hI64zkLuB1XQVsdFkED8ps40oLjj2sMAdbFwGlKRjbW6UHAFZaRJVegIpeWVafZhQ4yHahUm+5VyfOwXYFHTX8DKUNSn+fCcsN3qOd8AT3GGPEs4EYnxho9YlOnU1WTUj98GbLKWCawI5wk71DiBMoh+qjYfgXUc+nNlW+rXuqjOrknPAs4sRoHcvvNguDZNEChYOoBUUZ175z9nMBZnQ6cnncgS7uDnt3BJ49Y8axqPYLZ0gVEb2DaICyHtOUM5t2eP7AJexWaGWYBVzcdsqneoAAViyzzo3ZsC1Jeq2qBKVhlkIxDsuSRrSY6/6S6eaaFjD+B4BGmMo9X9M06kcAdMq0qU5eT+lBBc8+GqaVmCc989iHP6yVvOcr4qE8ZLijVZ8VleC/5xWDWFmN6ow6aIKX75EfdL5rfKxBJgAcwwV/zeXrFjyqqo3uy52dnMa5oU4O7svo7YMNgWrFKdsk6WBXmmS82HuKsuADjHZFGi5iBIv+9qnn/qt+qSh3JTFNjPvWDiqpnA0SexYB/ijm6q5qP85wFnIZrXQHgillpVesHh9QVaAWWAJccfo/VNrOcbmrbYn/vCR9gy2m1aUH2WOa/rv4UoKnhPODowC2Gx6jQo4Nox4ZinDL392ssIHFSZWa1rTZJD/wSy0Kn34eDpwZvP1w96+dmH25zrsQs4KSLP4GAawWSjhnFZZQFmUZxOZSTj/ne2yUhIHCjRIlFKcIU0x8
52RjZTGGlDdaQrkxk7MPrJr/gzg17r4vgJ3rMAk4/wmQDE7wJhg+fFV1xaMGiMqnXaFc5jd4FjCCIRAEmAO5aPE7lzsw0ZelHYJB0PCWscErqOJcsrbllGmhmzE/7mAXcPof544Wlqg6wTuORtvKQzjV2gVC+shaNMhc24v8iIloGmS3ogc7bD9sS884Oi0kEP89jFnDX++/hCtPVtT7kwaxOkZpmxQ/L9vgdj1r+NCtAwQ6/A9DXMXnBqZgoHDdXP7Wna/Id6PRCum7DiREqcg1UPw9Yp6MsLv/HwlM4Hp7WQ1/CGQhcgDsDNJtcgLsAdyYCZza7MO4C3JkInNnswrgLcGcicGazC+POBO7/AH5zPa/ivytzAAAAAElFTkSuQmCC", + ), + ] + ), + AssistantPromptMessage(content="I see a blue letter 'D' with a gradient from light blue to dark blue."), + UserPromptMessage( + content=[ + TextPromptMessageContent(data="what about now?"), + ImagePromptMessageContent( + mime_type="image/png", + format="png", + base64_data="iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAAACXBIWXMAAAABAAAAAQBPJcTWAAADl0lEQVR4nC3Uf0zUdRjA8S9W6w//bGs1DUd5RT+gIY0oYeEqY0QCy5EbAnF4IEgyAnGuCBANWOjih6YOlK0BbtLAX+iAENFgUBLMkzs8uDuO+wEcxx3cgdx9v3fvvn/0x+v5PM+z56/n2T6CIAgIQUEECVsICnqOoC0v8PyLW3n5lW28GhLG9hAFwYowdoRsJ+Tzv3hdEcpOxVvsfDscheI1BIXKy5t7OwiPiCI8IZaIL+OISPKxK/IDdiU6ifwqjqj4WKISP5VN8mHSFNHJA7KnfJQYh7A7+g1i9hXw2dcX2JuSxhcJnxCfnEJ8ygESqtfYl3qA5O/1pKaX8E2Rn7R0JWnKXFkRaX0OhIOqUtJVRWQoj5ChyiOjb4XMQ0fIVB0lM6eEzMO5ZN5x8W1xD1nZh1Fm55OtzOdQTgEqZR6CSi5UjSI5hTnk3bWSX/gj+ccaKCgspaDkNIWlpygc3OTYtZc4fqKcE5Vn+eFkDWUp8ZS1ryOUn66lvGmCyt/8nLwxTlXZcapqL1Nd10B1Uy01FbnUnFVS+2sLvzTWUXfRRMOAgcb6KhovdSA0XnHRdL6Zcy1/0lyTS3NfgJbWNq6cu0nrPyu0FSlpu9pF21037ZFhXLtYT+eNIbp61+jq70bofv8drvf0c2vQz+3O3+nRrNI78JD+/psMfLefe0MG7p+a5v6tP3g48ojhC7mMXP2Y0YoZRitnEcbkMPaglzEnPAoNZrw4hXH1LBOtOiYfa3gcugO1+gnqZwGeaHRMTcyhaduKRjOBxiJfQSsnWq0W7YwVrd3PtH6BaeMST40adJ3V6OwBZlR7mNUvMWswYsiKxTA1gWHOgsGiRzCmRGOcW8QoD855JObWJUxmHSb5nfd4Mc+ZMFv1MjtmuWepSMNiMmAxz2LN2o1gbdmDdV6NdVnE1p6EzajHZp7BtjCLbSnAgsMtE1k8H8OiwyuTWPL4sLduwz5vRLA7XCzbLCw7PTiswzgWJnBsijhNwzhtw6xmRLLmdLC27sU9dBC324un/iieSyF4rPIS1/8eZOOego0NL898Epv14Wz2nMHrsOB12/Glh+Mrfg/fqgufKCHmxSC21SE6JxFdKwjihhFxw4O4aUf0bSKVRyN1pyKNXEcaDUbS3EZan5Sp/zeFtLGO5LUiSRKCJAXwZ0bg73oXv+kBfrsOv8uOXxIJ/JRG4N/9sjME1B3QXAjzd8CqhqWfkT8C4T8Z5+ciRtwo8gAAA
ABJRU5ErkJggg==", + ), + ] + ), + ], + model_parameters={"temperature": 0.3, "top_p": 0.2, "top_k": 3, "max_tokens": 100}, + stream=False, + user="abc-123", + ) + + print(f"result: {result.message.content}") + assert isinstance(result, LLMResult) + assert len(result.message.content) > 0 + + +def test_get_num_tokens(): + model = GoogleLargeLanguageModel() + + num_tokens = model.get_num_tokens( + model="gemini-1.5-pro", + credentials={"google_api_key": os.environ.get("GOOGLE_API_KEY")}, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Hello World!"), + ], + ) + + assert num_tokens > 0 # The exact number of tokens may vary based on the model's tokenization diff --git a/api/tests/integration_tests/model_runtime/google/test_provider.py b/api/tests/integration_tests/model_runtime/google/test_provider.py new file mode 100644 index 0000000000000000000000000000000000000000..c217e4fe058870ccafdec0a966794af24e9ff09c --- /dev/null +++ b/api/tests/integration_tests/model_runtime/google/test_provider.py @@ -0,0 +1,17 @@ +import os + +import pytest + +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.google.google import GoogleProvider +from tests.integration_tests.model_runtime.__mock.google import setup_google_mock + + +@pytest.mark.parametrize("setup_google_mock", [["none"]], indirect=True) +def test_validate_provider_credentials(setup_google_mock): + provider = GoogleProvider() + + with pytest.raises(CredentialsValidateFailedError): + provider.validate_provider_credentials(credentials={}) + + provider.validate_provider_credentials(credentials={"google_api_key": os.environ.get("GOOGLE_API_KEY")}) diff --git a/api/tests/integration_tests/model_runtime/gpustack/__init__.py b/api/tests/integration_tests/model_runtime/gpustack/__init__.py new file mode 100644 index 
0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/tests/integration_tests/model_runtime/gpustack/test_embedding.py b/api/tests/integration_tests/model_runtime/gpustack/test_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..f56ad0dadcbe2093a2567a30d862555740f966a9 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/gpustack/test_embedding.py @@ -0,0 +1,49 @@ +import os + +import pytest + +from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.gpustack.text_embedding.text_embedding import ( + GPUStackTextEmbeddingModel, +) + + +def test_validate_credentials(): + model = GPUStackTextEmbeddingModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model="bge-m3", + credentials={ + "endpoint_url": "invalid_url", + "api_key": "invalid_api_key", + }, + ) + + model.validate_credentials( + model="bge-m3", + credentials={ + "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"), + "api_key": os.environ.get("GPUSTACK_API_KEY"), + }, + ) + + +def test_invoke_model(): + model = GPUStackTextEmbeddingModel() + + result = model.invoke( + model="bge-m3", + credentials={ + "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"), + "api_key": os.environ.get("GPUSTACK_API_KEY"), + "context_size": 8192, + }, + texts=["hello", "world"], + user="abc-123", + ) + + assert isinstance(result, TextEmbeddingResult) + assert len(result.embeddings) == 2 + assert result.usage.total_tokens == 7 diff --git a/api/tests/integration_tests/model_runtime/gpustack/test_llm.py b/api/tests/integration_tests/model_runtime/gpustack/test_llm.py new file mode 100644 index 0000000000000000000000000000000000000000..326b7b16f04ddaf553e00835fe52404ba1f522af --- /dev/null +++ b/api/tests/integration_tests/model_runtime/gpustack/test_llm.py @@ 
-0,0 +1,162 @@ +import os +from collections.abc import Generator + +import pytest + +from core.model_runtime.entities.llm_entities import ( + LLMResult, + LLMResultChunk, + LLMResultChunkDelta, +) +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessageTool, + SystemPromptMessage, + UserPromptMessage, +) +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.gpustack.llm.llm import GPUStackLanguageModel + + +def test_validate_credentials_for_chat_model(): + model = GPUStackLanguageModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model="llama-3.2-1b-instruct", + credentials={ + "endpoint_url": "invalid_url", + "api_key": "invalid_api_key", + "mode": "chat", + }, + ) + + model.validate_credentials( + model="llama-3.2-1b-instruct", + credentials={ + "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"), + "api_key": os.environ.get("GPUSTACK_API_KEY"), + "mode": "chat", + }, + ) + + +def test_invoke_completion_model(): + model = GPUStackLanguageModel() + + response = model.invoke( + model="llama-3.2-1b-instruct", + credentials={ + "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"), + "api_key": os.environ.get("GPUSTACK_API_KEY"), + "mode": "completion", + }, + prompt_messages=[UserPromptMessage(content="ping")], + model_parameters={"temperature": 0.7, "top_p": 1.0, "max_tokens": 10}, + stop=[], + user="abc-123", + stream=False, + ) + + assert isinstance(response, LLMResult) + assert len(response.message.content) > 0 + assert response.usage.total_tokens > 0 + + +def test_invoke_chat_model(): + model = GPUStackLanguageModel() + + response = model.invoke( + model="llama-3.2-1b-instruct", + credentials={ + "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"), + "api_key": os.environ.get("GPUSTACK_API_KEY"), + "mode": "chat", + }, + prompt_messages=[UserPromptMessage(content="ping")], + 
model_parameters={"temperature": 0.7, "top_p": 1.0, "max_tokens": 10}, + stop=[], + user="abc-123", + stream=False, + ) + + assert isinstance(response, LLMResult) + assert len(response.message.content) > 0 + assert response.usage.total_tokens > 0 + + +def test_invoke_stream_chat_model(): + model = GPUStackLanguageModel() + + response = model.invoke( + model="llama-3.2-1b-instruct", + credentials={ + "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"), + "api_key": os.environ.get("GPUSTACK_API_KEY"), + "mode": "chat", + }, + prompt_messages=[UserPromptMessage(content="Hello World!")], + model_parameters={"temperature": 0.7, "top_p": 1.0, "max_tokens": 10}, + stop=["you"], + stream=True, + user="abc-123", + ) + + assert isinstance(response, Generator) + for chunk in response: + assert isinstance(chunk, LLMResultChunk) + assert isinstance(chunk.delta, LLMResultChunkDelta) + assert isinstance(chunk.delta.message, AssistantPromptMessage) + assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True + + +def test_get_num_tokens(): + model = GPUStackLanguageModel() + + num_tokens = model.get_num_tokens( + model="????", + credentials={ + "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"), + "api_key": os.environ.get("GPUSTACK_API_KEY"), + "mode": "chat", + }, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Hello World!"), + ], + tools=[ + PromptMessageTool( + name="get_current_weather", + description="Get the current weather in a given location", + parameters={ + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state e.g. 
San Francisco, CA", + }, + "unit": {"type": "string", "enum": ["c", "f"]}, + }, + "required": ["location"], + }, + ) + ], + ) + + assert isinstance(num_tokens, int) + assert num_tokens == 80 + + num_tokens = model.get_num_tokens( + model="????", + credentials={ + "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"), + "api_key": os.environ.get("GPUSTACK_API_KEY"), + "mode": "chat", + }, + prompt_messages=[UserPromptMessage(content="Hello World!")], + ) + + assert isinstance(num_tokens, int) + assert num_tokens == 10 diff --git a/api/tests/integration_tests/model_runtime/gpustack/test_rerank.py b/api/tests/integration_tests/model_runtime/gpustack/test_rerank.py new file mode 100644 index 0000000000000000000000000000000000000000..f5c2d2d21ca8259968617f95b6e50dc071e4ccb4 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/gpustack/test_rerank.py @@ -0,0 +1,107 @@ +import os + +import pytest + +from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.gpustack.rerank.rerank import ( + GPUStackRerankModel, +) + + +def test_validate_credentials_for_rerank_model(): + model = GPUStackRerankModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model="bge-reranker-v2-m3", + credentials={ + "endpoint_url": "invalid_url", + "api_key": "invalid_api_key", + }, + ) + + model.validate_credentials( + model="bge-reranker-v2-m3", + credentials={ + "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"), + "api_key": os.environ.get("GPUSTACK_API_KEY"), + }, + ) + + +def test_invoke_rerank_model(): + model = GPUStackRerankModel() + + response = model.invoke( + model="bge-reranker-v2-m3", + credentials={ + "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"), + "api_key": os.environ.get("GPUSTACK_API_KEY"), + }, + query="Organic skincare products for sensitive skin", + docs=[ + "Eco-friendly 
kitchenware for modern homes", + "Biodegradable cleaning supplies for eco-conscious consumers", + "Organic cotton baby clothes for sensitive skin", + "Natural organic skincare range for sensitive skin", + "Tech gadgets for smart homes: 2024 edition", + "Sustainable gardening tools and compost solutions", + "Sensitive skin-friendly facial cleansers and toners", + "Organic food wraps and storage solutions", + "Yoga mats made from recycled materials", + ], + top_n=3, + score_threshold=-0.75, + user="abc-123", + ) + + assert isinstance(response, RerankResult) + assert len(response.docs) == 3 + + +def test__invoke(): + model = GPUStackRerankModel() + + # Test case 1: Empty docs + result = model._invoke( + model="bge-reranker-v2-m3", + credentials={ + "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"), + "api_key": os.environ.get("GPUSTACK_API_KEY"), + }, + query="Organic skincare products for sensitive skin", + docs=[], + top_n=3, + score_threshold=0.75, + user="abc-123", + ) + assert isinstance(result, RerankResult) + assert len(result.docs) == 0 + + # Test case 2: Expected docs + result = model._invoke( + model="bge-reranker-v2-m3", + credentials={ + "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"), + "api_key": os.environ.get("GPUSTACK_API_KEY"), + }, + query="Organic skincare products for sensitive skin", + docs=[ + "Eco-friendly kitchenware for modern homes", + "Biodegradable cleaning supplies for eco-conscious consumers", + "Organic cotton baby clothes for sensitive skin", + "Natural organic skincare range for sensitive skin", + "Tech gadgets for smart homes: 2024 edition", + "Sustainable gardening tools and compost solutions", + "Sensitive skin-friendly facial cleansers and toners", + "Organic food wraps and storage solutions", + "Yoga mats made from recycled materials", + ], + top_n=3, + score_threshold=-0.75, + user="abc-123", + ) + assert isinstance(result, RerankResult) + assert len(result.docs) == 3 + assert all(isinstance(doc, RerankDocument) for 
doc in result.docs) diff --git a/api/tests/integration_tests/model_runtime/gpustack/test_speech2text.py b/api/tests/integration_tests/model_runtime/gpustack/test_speech2text.py new file mode 100644 index 0000000000000000000000000000000000000000..c215e9b73988be06d59f19aedaa48e6c1749c5cf --- /dev/null +++ b/api/tests/integration_tests/model_runtime/gpustack/test_speech2text.py @@ -0,0 +1,55 @@ +import os +from pathlib import Path + +import pytest + +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.gpustack.speech2text.speech2text import GPUStackSpeech2TextModel + + +def test_validate_credentials(): + model = GPUStackSpeech2TextModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model="faster-whisper-medium", + credentials={ + "endpoint_url": "invalid_url", + "api_key": "invalid_api_key", + }, + ) + + model.validate_credentials( + model="faster-whisper-medium", + credentials={ + "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"), + "api_key": os.environ.get("GPUSTACK_API_KEY"), + }, + ) + + +def test_invoke_model(): + model = GPUStackSpeech2TextModel() + + # Get the directory of the current file + current_dir = os.path.dirname(os.path.abspath(__file__)) + + # Get assets directory + assets_dir = os.path.join(os.path.dirname(current_dir), "assets") + + # Construct the path to the audio file + audio_file_path = os.path.join(assets_dir, "audio.mp3") + + file = Path(audio_file_path).read_bytes() + + result = model.invoke( + model="faster-whisper-medium", + credentials={ + "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"), + "api_key": os.environ.get("GPUSTACK_API_KEY"), + }, + file=file, + ) + + assert isinstance(result, str) + assert result == "1, 2, 3, 4, 5, 6, 7, 8, 9, 10" diff --git a/api/tests/integration_tests/model_runtime/gpustack/test_tts.py b/api/tests/integration_tests/model_runtime/gpustack/test_tts.py new file mode 100644 index 
0000000000000000000000000000000000000000..8997ad074cb1bde504f8dde7bc5127100cf7f196 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/gpustack/test_tts.py @@ -0,0 +1,24 @@ +import os + +from core.model_runtime.model_providers.gpustack.tts.tts import GPUStackText2SpeechModel + + +def test_invoke_model(): + model = GPUStackText2SpeechModel() + + result = model.invoke( + model="cosyvoice-300m-sft", + tenant_id="test", + credentials={ + "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"), + "api_key": os.environ.get("GPUSTACK_API_KEY"), + }, + content_text="Hello world", + voice="Chinese Female", + ) + + content = b"" + for chunk in result: + content += chunk + + assert content != b"" diff --git a/api/tests/integration_tests/model_runtime/huggingface_hub/__init__.py b/api/tests/integration_tests/model_runtime/huggingface_hub/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/tests/integration_tests/model_runtime/huggingface_hub/test_llm.py b/api/tests/integration_tests/model_runtime/huggingface_hub/test_llm.py new file mode 100644 index 0000000000000000000000000000000000000000..8f90c68029572769d83df6fd02cf6ffdcfa504be --- /dev/null +++ b/api/tests/integration_tests/model_runtime/huggingface_hub/test_llm.py @@ -0,0 +1,278 @@ +import os +from collections.abc import Generator + +import pytest + +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import AssistantPromptMessage, UserPromptMessage +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.huggingface_hub.llm.llm import HuggingfaceHubLargeLanguageModel +from tests.integration_tests.model_runtime.__mock.huggingface import setup_huggingface_mock + + +@pytest.mark.skip +@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True) +def 
test_hosted_inference_api_validate_credentials(setup_huggingface_mock): + model = HuggingfaceHubLargeLanguageModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model="HuggingFaceH4/zephyr-7b-beta", + credentials={"huggingfacehub_api_type": "hosted_inference_api", "huggingfacehub_api_token": "invalid_key"}, + ) + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model="fake-model", + credentials={"huggingfacehub_api_type": "hosted_inference_api", "huggingfacehub_api_token": "invalid_key"}, + ) + + model.validate_credentials( + model="HuggingFaceH4/zephyr-7b-beta", + credentials={ + "huggingfacehub_api_type": "hosted_inference_api", + "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"), + }, + ) + + +@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True) +def test_hosted_inference_api_invoke_model(setup_huggingface_mock): + model = HuggingfaceHubLargeLanguageModel() + + response = model.invoke( + model="HuggingFaceH4/zephyr-7b-beta", + credentials={ + "huggingfacehub_api_type": "hosted_inference_api", + "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"), + }, + prompt_messages=[UserPromptMessage(content="Who are you?")], + model_parameters={ + "temperature": 1.0, + "top_k": 2, + "top_p": 0.5, + }, + stop=["How"], + stream=False, + user="abc-123", + ) + + assert isinstance(response, LLMResult) + assert len(response.message.content) > 0 + + +@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True) +def test_hosted_inference_api_invoke_stream_model(setup_huggingface_mock): + model = HuggingfaceHubLargeLanguageModel() + + response = model.invoke( + model="HuggingFaceH4/zephyr-7b-beta", + credentials={ + "huggingfacehub_api_type": "hosted_inference_api", + "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"), + }, + prompt_messages=[UserPromptMessage(content="Who are you?")], + model_parameters={ + 
"temperature": 1.0, + "top_k": 2, + "top_p": 0.5, + }, + stop=["How"], + stream=True, + user="abc-123", + ) + + assert isinstance(response, Generator) + + for chunk in response: + assert isinstance(chunk, LLMResultChunk) + assert isinstance(chunk.delta, LLMResultChunkDelta) + assert isinstance(chunk.delta.message, AssistantPromptMessage) + assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True + + +@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True) +def test_inference_endpoints_text_generation_validate_credentials(setup_huggingface_mock): + model = HuggingfaceHubLargeLanguageModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model="openchat/openchat_3.5", + credentials={ + "huggingfacehub_api_type": "inference_endpoints", + "huggingfacehub_api_token": "invalid_key", + "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT_GEN_ENDPOINT_URL"), + "task_type": "text-generation", + }, + ) + + model.validate_credentials( + model="openchat/openchat_3.5", + credentials={ + "huggingfacehub_api_type": "inference_endpoints", + "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"), + "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT_GEN_ENDPOINT_URL"), + "task_type": "text-generation", + }, + ) + + +@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True) +def test_inference_endpoints_text_generation_invoke_model(setup_huggingface_mock): + model = HuggingfaceHubLargeLanguageModel() + + response = model.invoke( + model="openchat/openchat_3.5", + credentials={ + "huggingfacehub_api_type": "inference_endpoints", + "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"), + "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT_GEN_ENDPOINT_URL"), + "task_type": "text-generation", + }, + prompt_messages=[UserPromptMessage(content="Who are you?")], + model_parameters={ + "temperature": 1.0, + "top_k": 
2, + "top_p": 0.5, + }, + stop=["How"], + stream=False, + user="abc-123", + ) + + assert isinstance(response, LLMResult) + assert len(response.message.content) > 0 + + +@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True) +def test_inference_endpoints_text_generation_invoke_stream_model(setup_huggingface_mock): + model = HuggingfaceHubLargeLanguageModel() + + response = model.invoke( + model="openchat/openchat_3.5", + credentials={ + "huggingfacehub_api_type": "inference_endpoints", + "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"), + "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT_GEN_ENDPOINT_URL"), + "task_type": "text-generation", + }, + prompt_messages=[UserPromptMessage(content="Who are you?")], + model_parameters={ + "temperature": 1.0, + "top_k": 2, + "top_p": 0.5, + }, + stop=["How"], + stream=True, + user="abc-123", + ) + + assert isinstance(response, Generator) + + for chunk in response: + assert isinstance(chunk, LLMResultChunk) + assert isinstance(chunk.delta, LLMResultChunkDelta) + assert isinstance(chunk.delta.message, AssistantPromptMessage) + assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True + + +@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True) +def test_inference_endpoints_text2text_generation_validate_credentials(setup_huggingface_mock): + model = HuggingfaceHubLargeLanguageModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model="google/mt5-base", + credentials={ + "huggingfacehub_api_type": "inference_endpoints", + "huggingfacehub_api_token": "invalid_key", + "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL"), + "task_type": "text2text-generation", + }, + ) + + model.validate_credentials( + model="google/mt5-base", + credentials={ + "huggingfacehub_api_type": "inference_endpoints", + "huggingfacehub_api_token": 
os.environ.get("HUGGINGFACE_API_KEY"), + "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL"), + "task_type": "text2text-generation", + }, + ) + + +@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True) +def test_inference_endpoints_text2text_generation_invoke_model(setup_huggingface_mock): + model = HuggingfaceHubLargeLanguageModel() + + response = model.invoke( + model="google/mt5-base", + credentials={ + "huggingfacehub_api_type": "inference_endpoints", + "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"), + "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL"), + "task_type": "text2text-generation", + }, + prompt_messages=[UserPromptMessage(content="Who are you?")], + model_parameters={ + "temperature": 1.0, + "top_k": 2, + "top_p": 0.5, + }, + stop=["How"], + stream=False, + user="abc-123", + ) + + assert isinstance(response, LLMResult) + assert len(response.message.content) > 0 + + +@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True) +def test_inference_endpoints_text2text_generation_invoke_stream_model(setup_huggingface_mock): + model = HuggingfaceHubLargeLanguageModel() + + response = model.invoke( + model="google/mt5-base", + credentials={ + "huggingfacehub_api_type": "inference_endpoints", + "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"), + "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL"), + "task_type": "text2text-generation", + }, + prompt_messages=[UserPromptMessage(content="Who are you?")], + model_parameters={ + "temperature": 1.0, + "top_k": 2, + "top_p": 0.5, + }, + stop=["How"], + stream=True, + user="abc-123", + ) + + assert isinstance(response, Generator) + + for chunk in response: + assert isinstance(chunk, LLMResultChunk) + assert isinstance(chunk.delta, LLMResultChunkDelta) + assert isinstance(chunk.delta.message, AssistantPromptMessage) + assert 
len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True + + +def test_get_num_tokens(): + model = HuggingfaceHubLargeLanguageModel() + + num_tokens = model.get_num_tokens( + model="google/mt5-base", + credentials={ + "huggingfacehub_api_type": "inference_endpoints", + "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"), + "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL"), + "task_type": "text2text-generation", + }, + prompt_messages=[UserPromptMessage(content="Hello World!")], + ) + + assert num_tokens == 7 diff --git a/api/tests/integration_tests/model_runtime/huggingface_hub/test_text_embedding.py b/api/tests/integration_tests/model_runtime/huggingface_hub/test_text_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..0ee593f38a494a6721aa71ef0a11ba6027374363 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/huggingface_hub/test_text_embedding.py @@ -0,0 +1,112 @@ +import os + +import pytest + +from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.huggingface_hub.text_embedding.text_embedding import ( + HuggingfaceHubTextEmbeddingModel, +) + + +def test_hosted_inference_api_validate_credentials(): + model = HuggingfaceHubTextEmbeddingModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model="facebook/bart-base", + credentials={ + "huggingfacehub_api_type": "hosted_inference_api", + "huggingfacehub_api_token": "invalid_key", + }, + ) + + model.validate_credentials( + model="facebook/bart-base", + credentials={ + "huggingfacehub_api_type": "hosted_inference_api", + "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"), + }, + ) + + +def test_hosted_inference_api_invoke_model(): + model = HuggingfaceHubTextEmbeddingModel() + + result = 
model.invoke( + model="facebook/bart-base", + credentials={ + "huggingfacehub_api_type": "hosted_inference_api", + "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"), + }, + texts=["hello", "world"], + ) + + assert isinstance(result, TextEmbeddingResult) + assert len(result.embeddings) == 2 + assert result.usage.total_tokens == 2 + + +def test_inference_endpoints_validate_credentials(): + model = HuggingfaceHubTextEmbeddingModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model="all-MiniLM-L6-v2", + credentials={ + "huggingfacehub_api_type": "inference_endpoints", + "huggingfacehub_api_token": "invalid_key", + "huggingface_namespace": "Dify-AI", + "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL"), + "task_type": "feature-extraction", + }, + ) + + model.validate_credentials( + model="all-MiniLM-L6-v2", + credentials={ + "huggingfacehub_api_type": "inference_endpoints", + "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"), + "huggingface_namespace": "Dify-AI", + "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL"), + "task_type": "feature-extraction", + }, + ) + + +def test_inference_endpoints_invoke_model(): + model = HuggingfaceHubTextEmbeddingModel() + + result = model.invoke( + model="all-MiniLM-L6-v2", + credentials={ + "huggingfacehub_api_type": "inference_endpoints", + "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"), + "huggingface_namespace": "Dify-AI", + "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL"), + "task_type": "feature-extraction", + }, + texts=["hello", "world"], + ) + + assert isinstance(result, TextEmbeddingResult) + assert len(result.embeddings) == 2 + assert result.usage.total_tokens == 0 + + +def test_get_num_tokens(): + model = HuggingfaceHubTextEmbeddingModel() + + num_tokens = model.get_num_tokens( + model="all-MiniLM-L6-v2", + credentials={ + 
"huggingfacehub_api_type": "inference_endpoints", + "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"), + "huggingface_namespace": "Dify-AI", + "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL"), + "task_type": "feature-extraction", + }, + texts=["hello", "world"], + ) + + assert num_tokens == 2 diff --git a/api/tests/integration_tests/model_runtime/huggingface_tei/__init__.py b/api/tests/integration_tests/model_runtime/huggingface_tei/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/tests/integration_tests/model_runtime/huggingface_tei/test_embeddings.py b/api/tests/integration_tests/model_runtime/huggingface_tei/test_embeddings.py new file mode 100644 index 0000000000000000000000000000000000000000..33160062e5799605b3680f4ac443d5fdec5dcf9d --- /dev/null +++ b/api/tests/integration_tests/model_runtime/huggingface_tei/test_embeddings.py @@ -0,0 +1,73 @@ +import os + +import pytest + +from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.huggingface_tei.text_embedding.text_embedding import ( + HuggingfaceTeiTextEmbeddingModel, + TeiHelper, +) +from tests.integration_tests.model_runtime.__mock.huggingface_tei import MockTEIClass + +MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true" + + +@pytest.fixture +def setup_tei_mock(request, monkeypatch: pytest.MonkeyPatch): + if MOCK: + monkeypatch.setattr(TeiHelper, "get_tei_extra_parameter", MockTEIClass.get_tei_extra_parameter) + monkeypatch.setattr(TeiHelper, "invoke_tokenize", MockTEIClass.invoke_tokenize) + monkeypatch.setattr(TeiHelper, "invoke_embeddings", MockTEIClass.invoke_embeddings) + monkeypatch.setattr(TeiHelper, "invoke_rerank", MockTEIClass.invoke_rerank) + yield + + if MOCK: + monkeypatch.undo() + + 
+@pytest.mark.parametrize("setup_tei_mock", [["none"]], indirect=True) +def test_validate_credentials(setup_tei_mock): + model = HuggingfaceTeiTextEmbeddingModel() + # model name is only used in mock + model_name = "embedding" + + if MOCK: + # TEI Provider will check model type by API endpoint, at real server, the model type is correct. + # So we dont need to check model type here. Only check in mock + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model="reranker", + credentials={ + "server_url": os.environ.get("TEI_EMBEDDING_SERVER_URL", ""), + "api_key": os.environ.get("TEI_API_KEY", ""), + }, + ) + + model.validate_credentials( + model=model_name, + credentials={ + "server_url": os.environ.get("TEI_EMBEDDING_SERVER_URL", ""), + "api_key": os.environ.get("TEI_API_KEY", ""), + }, + ) + + +@pytest.mark.parametrize("setup_tei_mock", [["none"]], indirect=True) +def test_invoke_model(setup_tei_mock): + model = HuggingfaceTeiTextEmbeddingModel() + model_name = "embedding" + + result = model.invoke( + model=model_name, + credentials={ + "server_url": os.environ.get("TEI_EMBEDDING_SERVER_URL", ""), + "api_key": os.environ.get("TEI_API_KEY", ""), + }, + texts=["hello", "world"], + user="abc-123", + ) + + assert isinstance(result, TextEmbeddingResult) + assert len(result.embeddings) == 2 + assert result.usage.total_tokens > 0 diff --git a/api/tests/integration_tests/model_runtime/huggingface_tei/test_rerank.py b/api/tests/integration_tests/model_runtime/huggingface_tei/test_rerank.py new file mode 100644 index 0000000000000000000000000000000000000000..9777367063f5fa539436299ccff23a3b1866cbdd --- /dev/null +++ b/api/tests/integration_tests/model_runtime/huggingface_tei/test_rerank.py @@ -0,0 +1,80 @@ +import os + +import pytest + +from core.model_runtime.entities.rerank_entities import RerankResult +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from 
core.model_runtime.model_providers.huggingface_tei.rerank.rerank import ( + HuggingfaceTeiRerankModel, +) +from core.model_runtime.model_providers.huggingface_tei.text_embedding.text_embedding import TeiHelper +from tests.integration_tests.model_runtime.__mock.huggingface_tei import MockTEIClass + +MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true" + + +@pytest.fixture +def setup_tei_mock(request, monkeypatch: pytest.MonkeyPatch): + if MOCK: + monkeypatch.setattr(TeiHelper, "get_tei_extra_parameter", MockTEIClass.get_tei_extra_parameter) + monkeypatch.setattr(TeiHelper, "invoke_tokenize", MockTEIClass.invoke_tokenize) + monkeypatch.setattr(TeiHelper, "invoke_embeddings", MockTEIClass.invoke_embeddings) + monkeypatch.setattr(TeiHelper, "invoke_rerank", MockTEIClass.invoke_rerank) + yield + + if MOCK: + monkeypatch.undo() + + +@pytest.mark.parametrize("setup_tei_mock", [["none"]], indirect=True) +def test_validate_credentials(setup_tei_mock): + model = HuggingfaceTeiRerankModel() + # model name is only used in mock + model_name = "reranker" + + if MOCK: + # TEI Provider will check model type by API endpoint, at real server, the model type is correct. + # So we dont need to check model type here. 
Only check in mock + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model="embedding", + credentials={ + "server_url": os.environ.get("TEI_RERANK_SERVER_URL"), + "api_key": os.environ.get("TEI_API_KEY", ""), + }, + ) + + model.validate_credentials( + model=model_name, + credentials={ + "server_url": os.environ.get("TEI_RERANK_SERVER_URL"), + "api_key": os.environ.get("TEI_API_KEY", ""), + }, + ) + + +@pytest.mark.parametrize("setup_tei_mock", [["none"]], indirect=True) +def test_invoke_model(setup_tei_mock): + model = HuggingfaceTeiRerankModel() + # model name is only used in mock + model_name = "reranker" + + result = model.invoke( + model=model_name, + credentials={ + "server_url": os.environ.get("TEI_RERANK_SERVER_URL"), + "api_key": os.environ.get("TEI_API_KEY", ""), + }, + query="Who is Kasumi?", + docs=[ + 'Kasumi is a girl\'s name of Japanese origin meaning "mist".', + "Her music is a kawaii bass, a mix of future bass, pop, and kawaii music ", + "and she leads a team named PopiParty.", + ], + score_threshold=0.8, + ) + + assert isinstance(result, RerankResult) + assert len(result.docs) == 1 + assert result.docs[0].index == 0 + assert result.docs[0].score >= 0.8 diff --git a/api/tests/integration_tests/model_runtime/hunyuan/__init__.py b/api/tests/integration_tests/model_runtime/hunyuan/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/tests/integration_tests/model_runtime/hunyuan/test_llm.py b/api/tests/integration_tests/model_runtime/hunyuan/test_llm.py new file mode 100644 index 0000000000000000000000000000000000000000..b3049a06d9b98aaada74ad463ef8e70f7ffd3693 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/hunyuan/test_llm.py @@ -0,0 +1,90 @@ +import os +from collections.abc import Generator + +import pytest + +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from 
core.model_runtime.entities.message_entities import AssistantPromptMessage, SystemPromptMessage, UserPromptMessage +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.hunyuan.llm.llm import HunyuanLargeLanguageModel + + +def test_validate_credentials(): + model = HunyuanLargeLanguageModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model="hunyuan-standard", credentials={"secret_id": "invalid_key", "secret_key": "invalid_key"} + ) + + model.validate_credentials( + model="hunyuan-standard", + credentials={ + "secret_id": os.environ.get("HUNYUAN_SECRET_ID"), + "secret_key": os.environ.get("HUNYUAN_SECRET_KEY"), + }, + ) + + +def test_invoke_model(): + model = HunyuanLargeLanguageModel() + + response = model.invoke( + model="hunyuan-standard", + credentials={ + "secret_id": os.environ.get("HUNYUAN_SECRET_ID"), + "secret_key": os.environ.get("HUNYUAN_SECRET_KEY"), + }, + prompt_messages=[UserPromptMessage(content="Hi")], + model_parameters={"temperature": 0.5, "max_tokens": 10}, + stop=["How"], + stream=False, + user="abc-123", + ) + + assert isinstance(response, LLMResult) + assert len(response.message.content) > 0 + + +def test_invoke_stream_model(): + model = HunyuanLargeLanguageModel() + + response = model.invoke( + model="hunyuan-standard", + credentials={ + "secret_id": os.environ.get("HUNYUAN_SECRET_ID"), + "secret_key": os.environ.get("HUNYUAN_SECRET_KEY"), + }, + prompt_messages=[UserPromptMessage(content="Hi")], + model_parameters={"temperature": 0.5, "max_tokens": 100, "seed": 1234}, + stream=True, + user="abc-123", + ) + + assert isinstance(response, Generator) + + for chunk in response: + assert isinstance(chunk, LLMResultChunk) + assert isinstance(chunk.delta, LLMResultChunkDelta) + assert isinstance(chunk.delta.message, AssistantPromptMessage) + assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True + + +def 
test_get_num_tokens(): + model = HunyuanLargeLanguageModel() + + num_tokens = model.get_num_tokens( + model="hunyuan-standard", + credentials={ + "secret_id": os.environ.get("HUNYUAN_SECRET_ID"), + "secret_key": os.environ.get("HUNYUAN_SECRET_KEY"), + }, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Hello World!"), + ], + ) + + assert num_tokens == 14 diff --git a/api/tests/integration_tests/model_runtime/hunyuan/test_provider.py b/api/tests/integration_tests/model_runtime/hunyuan/test_provider.py new file mode 100644 index 0000000000000000000000000000000000000000..e3748c2ce713d4e9e8af85d453aecbbc31559fbd --- /dev/null +++ b/api/tests/integration_tests/model_runtime/hunyuan/test_provider.py @@ -0,0 +1,20 @@ +import os + +import pytest + +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.hunyuan.hunyuan import HunyuanProvider + + +def test_validate_provider_credentials(): + provider = HunyuanProvider() + + with pytest.raises(CredentialsValidateFailedError): + provider.validate_provider_credentials(credentials={"secret_id": "invalid_key", "secret_key": "invalid_key"}) + + provider.validate_provider_credentials( + credentials={ + "secret_id": os.environ.get("HUNYUAN_SECRET_ID"), + "secret_key": os.environ.get("HUNYUAN_SECRET_KEY"), + } + ) diff --git a/api/tests/integration_tests/model_runtime/hunyuan/test_text_embedding.py b/api/tests/integration_tests/model_runtime/hunyuan/test_text_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..69d14dffeebf358a2942ae9f5795a0bc6fc7450b --- /dev/null +++ b/api/tests/integration_tests/model_runtime/hunyuan/test_text_embedding.py @@ -0,0 +1,96 @@ +import os + +import pytest + +from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from 
core.model_runtime.model_providers.hunyuan.text_embedding.text_embedding import HunyuanTextEmbeddingModel + + +def test_validate_credentials(): + model = HunyuanTextEmbeddingModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model="hunyuan-embedding", credentials={"secret_id": "invalid_key", "secret_key": "invalid_key"} + ) + + model.validate_credentials( + model="hunyuan-embedding", + credentials={ + "secret_id": os.environ.get("HUNYUAN_SECRET_ID"), + "secret_key": os.environ.get("HUNYUAN_SECRET_KEY"), + }, + ) + + +def test_invoke_model(): + model = HunyuanTextEmbeddingModel() + + result = model.invoke( + model="hunyuan-embedding", + credentials={ + "secret_id": os.environ.get("HUNYUAN_SECRET_ID"), + "secret_key": os.environ.get("HUNYUAN_SECRET_KEY"), + }, + texts=["hello", "world"], + user="abc-123", + ) + + assert isinstance(result, TextEmbeddingResult) + assert len(result.embeddings) == 2 + assert result.usage.total_tokens == 6 + + +def test_get_num_tokens(): + model = HunyuanTextEmbeddingModel() + + num_tokens = model.get_num_tokens( + model="hunyuan-embedding", + credentials={ + "secret_id": os.environ.get("HUNYUAN_SECRET_ID"), + "secret_key": os.environ.get("HUNYUAN_SECRET_KEY"), + }, + texts=["hello", "world"], + ) + + assert num_tokens == 2 + + +def test_max_chunks(): + model = HunyuanTextEmbeddingModel() + + result = model.invoke( + model="hunyuan-embedding", + credentials={ + "secret_id": os.environ.get("HUNYUAN_SECRET_ID"), + "secret_key": os.environ.get("HUNYUAN_SECRET_KEY"), + }, + texts=[ + "hello", + "world", + "hello", + "world", + "hello", + "world", + "hello", + "world", + "hello", + "world", + "hello", + "world", + "hello", + "world", + "hello", + "world", + "hello", + "world", + "hello", + "world", + "hello", + "world", + ], + ) + + assert isinstance(result, TextEmbeddingResult) + assert len(result.embeddings) == 22 diff --git a/api/tests/integration_tests/model_runtime/jina/__init__.py 
b/api/tests/integration_tests/model_runtime/jina/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/tests/integration_tests/model_runtime/jina/test_provider.py b/api/tests/integration_tests/model_runtime/jina/test_provider.py new file mode 100644 index 0000000000000000000000000000000000000000..e3b6128c59d8df49b005ec3059cead3181c26d4f --- /dev/null +++ b/api/tests/integration_tests/model_runtime/jina/test_provider.py @@ -0,0 +1,15 @@ +import os + +import pytest + +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.jina.jina import JinaProvider + + +def test_validate_provider_credentials(): + provider = JinaProvider() + + with pytest.raises(CredentialsValidateFailedError): + provider.validate_provider_credentials(credentials={"api_key": "hahahaha"}) + + provider.validate_provider_credentials(credentials={"api_key": os.environ.get("JINA_API_KEY")}) diff --git a/api/tests/integration_tests/model_runtime/jina/test_text_embedding.py b/api/tests/integration_tests/model_runtime/jina/test_text_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..290735ec49e625c48496308e24e34a4f0bc0128f --- /dev/null +++ b/api/tests/integration_tests/model_runtime/jina/test_text_embedding.py @@ -0,0 +1,49 @@ +import os + +import pytest + +from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.jina.text_embedding.text_embedding import JinaTextEmbeddingModel + + +def test_validate_credentials(): + model = JinaTextEmbeddingModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials(model="jina-embeddings-v2-base-en", credentials={"api_key": "invalid_key"}) + + model.validate_credentials( + model="jina-embeddings-v2-base-en", 
credentials={"api_key": os.environ.get("JINA_API_KEY")} + ) + + +def test_invoke_model(): + model = JinaTextEmbeddingModel() + + result = model.invoke( + model="jina-embeddings-v2-base-en", + credentials={ + "api_key": os.environ.get("JINA_API_KEY"), + }, + texts=["hello", "world"], + user="abc-123", + ) + + assert isinstance(result, TextEmbeddingResult) + assert len(result.embeddings) == 2 + assert result.usage.total_tokens == 6 + + +def test_get_num_tokens(): + model = JinaTextEmbeddingModel() + + num_tokens = model.get_num_tokens( + model="jina-embeddings-v2-base-en", + credentials={ + "api_key": os.environ.get("JINA_API_KEY"), + }, + texts=["hello", "world"], + ) + + assert num_tokens == 6 diff --git a/api/tests/integration_tests/model_runtime/localai/__init__.py b/api/tests/integration_tests/model_runtime/localai/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/tests/integration_tests/model_runtime/localai/test_embedding.py b/api/tests/integration_tests/model_runtime/localai/test_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..7fd9f2b3000a317ceb2d7c9f0c85d8edc439166b --- /dev/null +++ b/api/tests/integration_tests/model_runtime/localai/test_embedding.py @@ -0,0 +1,4 @@ +""" +LocalAI Embedding Interface is temporarily unavailable due to +we could not find a way to test it for now. 
+""" diff --git a/api/tests/integration_tests/model_runtime/localai/test_llm.py b/api/tests/integration_tests/model_runtime/localai/test_llm.py new file mode 100644 index 0000000000000000000000000000000000000000..51e899fd5186cfa9fbcf329f6653fca661927fee --- /dev/null +++ b/api/tests/integration_tests/model_runtime/localai/test_llm.py @@ -0,0 +1,172 @@ +import os +from collections.abc import Generator + +import pytest + +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessageTool, + SystemPromptMessage, + UserPromptMessage, +) +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.localai.llm.llm import LocalAILanguageModel + + +def test_validate_credentials_for_chat_model(): + model = LocalAILanguageModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model="chinese-llama-2-7b", + credentials={ + "server_url": "hahahaha", + "completion_type": "completion", + }, + ) + + model.validate_credentials( + model="chinese-llama-2-7b", + credentials={ + "server_url": os.environ.get("LOCALAI_SERVER_URL"), + "completion_type": "completion", + }, + ) + + +def test_invoke_completion_model(): + model = LocalAILanguageModel() + + response = model.invoke( + model="chinese-llama-2-7b", + credentials={ + "server_url": os.environ.get("LOCALAI_SERVER_URL"), + "completion_type": "completion", + }, + prompt_messages=[UserPromptMessage(content="ping")], + model_parameters={"temperature": 0.7, "top_p": 1.0, "max_tokens": 10}, + stop=[], + user="abc-123", + stream=False, + ) + + assert isinstance(response, LLMResult) + assert len(response.message.content) > 0 + assert response.usage.total_tokens > 0 + + +def test_invoke_chat_model(): + model = LocalAILanguageModel() + + response = model.invoke( + model="chinese-llama-2-7b", + credentials={ + 
"server_url": os.environ.get("LOCALAI_SERVER_URL"), + "completion_type": "chat_completion", + }, + prompt_messages=[UserPromptMessage(content="ping")], + model_parameters={"temperature": 0.7, "top_p": 1.0, "max_tokens": 10}, + stop=[], + user="abc-123", + stream=False, + ) + + assert isinstance(response, LLMResult) + assert len(response.message.content) > 0 + assert response.usage.total_tokens > 0 + + +def test_invoke_stream_completion_model(): + model = LocalAILanguageModel() + + response = model.invoke( + model="chinese-llama-2-7b", + credentials={ + "server_url": os.environ.get("LOCALAI_SERVER_URL"), + "completion_type": "completion", + }, + prompt_messages=[UserPromptMessage(content="Hello World!")], + model_parameters={"temperature": 0.7, "top_p": 1.0, "max_tokens": 10}, + stop=["you"], + stream=True, + user="abc-123", + ) + + assert isinstance(response, Generator) + for chunk in response: + assert isinstance(chunk, LLMResultChunk) + assert isinstance(chunk.delta, LLMResultChunkDelta) + assert isinstance(chunk.delta.message, AssistantPromptMessage) + assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True + + +def test_invoke_stream_chat_model(): + model = LocalAILanguageModel() + + response = model.invoke( + model="chinese-llama-2-7b", + credentials={ + "server_url": os.environ.get("LOCALAI_SERVER_URL"), + "completion_type": "chat_completion", + }, + prompt_messages=[UserPromptMessage(content="Hello World!")], + model_parameters={"temperature": 0.7, "top_p": 1.0, "max_tokens": 10}, + stop=["you"], + stream=True, + user="abc-123", + ) + + assert isinstance(response, Generator) + for chunk in response: + assert isinstance(chunk, LLMResultChunk) + assert isinstance(chunk.delta, LLMResultChunkDelta) + assert isinstance(chunk.delta.message, AssistantPromptMessage) + assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True + + +def test_get_num_tokens(): + model = LocalAILanguageModel() + + 
num_tokens = model.get_num_tokens( + model="????", + credentials={ + "server_url": os.environ.get("LOCALAI_SERVER_URL"), + "completion_type": "chat_completion", + }, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Hello World!"), + ], + tools=[ + PromptMessageTool( + name="get_current_weather", + description="Get the current weather in a given location", + parameters={ + "type": "object", + "properties": { + "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"}, + "unit": {"type": "string", "enum": ["c", "f"]}, + }, + "required": ["location"], + }, + ) + ], + ) + + assert isinstance(num_tokens, int) + assert num_tokens == 77 + + num_tokens = model.get_num_tokens( + model="????", + credentials={ + "server_url": os.environ.get("LOCALAI_SERVER_URL"), + "completion_type": "chat_completion", + }, + prompt_messages=[UserPromptMessage(content="Hello World!")], + ) + + assert isinstance(num_tokens, int) + assert num_tokens == 10 diff --git a/api/tests/integration_tests/model_runtime/localai/test_rerank.py b/api/tests/integration_tests/model_runtime/localai/test_rerank.py new file mode 100644 index 0000000000000000000000000000000000000000..13c7df6d1473b0205f7077079f25ffb1621ad844 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/localai/test_rerank.py @@ -0,0 +1,96 @@ +import os + +import pytest + +from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.localai.rerank.rerank import LocalaiRerankModel + + +def test_validate_credentials_for_chat_model(): + model = LocalaiRerankModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model="bge-reranker-v2-m3", + credentials={ + "server_url": "hahahaha", + "completion_type": "completion", + }, + ) + + 
model.validate_credentials( + model="bge-reranker-base", + credentials={ + "server_url": os.environ.get("LOCALAI_SERVER_URL"), + "completion_type": "completion", + }, + ) + + +def test_invoke_rerank_model(): + model = LocalaiRerankModel() + + response = model.invoke( + model="bge-reranker-base", + credentials={"server_url": os.environ.get("LOCALAI_SERVER_URL")}, + query="Organic skincare products for sensitive skin", + docs=[ + "Eco-friendly kitchenware for modern homes", + "Biodegradable cleaning supplies for eco-conscious consumers", + "Organic cotton baby clothes for sensitive skin", + "Natural organic skincare range for sensitive skin", + "Tech gadgets for smart homes: 2024 edition", + "Sustainable gardening tools and compost solutions", + "Sensitive skin-friendly facial cleansers and toners", + "Organic food wraps and storage solutions", + "Yoga mats made from recycled materials", + ], + top_n=3, + score_threshold=0.75, + user="abc-123", + ) + + assert isinstance(response, RerankResult) + assert len(response.docs) == 3 + + +def test__invoke(): + model = LocalaiRerankModel() + + # Test case 1: Empty docs + result = model._invoke( + model="bge-reranker-base", + credentials={"server_url": "https://example.com", "api_key": "1234567890"}, + query="Organic skincare products for sensitive skin", + docs=[], + top_n=3, + score_threshold=0.75, + user="abc-123", + ) + assert isinstance(result, RerankResult) + assert len(result.docs) == 0 + + # Test case 2: Valid invocation + result = model._invoke( + model="bge-reranker-base", + credentials={"server_url": "https://example.com", "api_key": "1234567890"}, + query="Organic skincare products for sensitive skin", + docs=[ + "Eco-friendly kitchenware for modern homes", + "Biodegradable cleaning supplies for eco-conscious consumers", + "Organic cotton baby clothes for sensitive skin", + "Natural organic skincare range for sensitive skin", + "Tech gadgets for smart homes: 2024 edition", + "Sustainable gardening tools and compost 
solutions", + "Sensitive skin-friendly facial cleansers and toners", + "Organic food wraps and storage solutions", + "Yoga mats made from recycled materials", + ], + top_n=3, + score_threshold=0.75, + user="abc-123", + ) + assert isinstance(result, RerankResult) + assert len(result.docs) == 3 + assert all(isinstance(doc, RerankDocument) for doc in result.docs) diff --git a/api/tests/integration_tests/model_runtime/localai/test_speech2text.py b/api/tests/integration_tests/model_runtime/localai/test_speech2text.py new file mode 100644 index 0000000000000000000000000000000000000000..91b7a5752ce9733be44d1ff3f118ae5f5f33523f --- /dev/null +++ b/api/tests/integration_tests/model_runtime/localai/test_speech2text.py @@ -0,0 +1,42 @@ +import os + +import pytest + +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.localai.speech2text.speech2text import LocalAISpeech2text + + +def test_validate_credentials(): + model = LocalAISpeech2text() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials(model="whisper-1", credentials={"server_url": "invalid_url"}) + + model.validate_credentials(model="whisper-1", credentials={"server_url": os.environ.get("LOCALAI_SERVER_URL")}) + + +def test_invoke_model(): + model = LocalAISpeech2text() + + # Get the directory of the current file + current_dir = os.path.dirname(os.path.abspath(__file__)) + + # Get assets directory + assets_dir = os.path.join(os.path.dirname(current_dir), "assets") + + # Construct the path to the audio file + audio_file_path = os.path.join(assets_dir, "audio.mp3") + + # Open the file and get the file object + with open(audio_file_path, "rb") as audio_file: + file = audio_file + + result = model.invoke( + model="whisper-1", + credentials={"server_url": os.environ.get("LOCALAI_SERVER_URL")}, + file=file, + user="abc-123", + ) + + assert isinstance(result, str) + assert result == "1, 2, 3, 4, 5, 6, 7, 8, 9, 10" diff --git 
a/api/tests/integration_tests/model_runtime/minimax/__init__.py b/api/tests/integration_tests/model_runtime/minimax/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/tests/integration_tests/model_runtime/minimax/test_embedding.py b/api/tests/integration_tests/model_runtime/minimax/test_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..cf2a28eb9eb2fee68486299a8aa58ac164e7e579 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/minimax/test_embedding.py @@ -0,0 +1,58 @@ +import os + +import pytest + +from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.minimax.text_embedding.text_embedding import MinimaxTextEmbeddingModel + + +def test_validate_credentials(): + model = MinimaxTextEmbeddingModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model="embo-01", + credentials={"minimax_api_key": "invalid_key", "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID")}, + ) + + model.validate_credentials( + model="embo-01", + credentials={ + "minimax_api_key": os.environ.get("MINIMAX_API_KEY"), + "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"), + }, + ) + + +def test_invoke_model(): + model = MinimaxTextEmbeddingModel() + + result = model.invoke( + model="embo-01", + credentials={ + "minimax_api_key": os.environ.get("MINIMAX_API_KEY"), + "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"), + }, + texts=["hello", "world"], + user="abc-123", + ) + + assert isinstance(result, TextEmbeddingResult) + assert len(result.embeddings) == 2 + assert result.usage.total_tokens == 16 + + +def test_get_num_tokens(): + model = MinimaxTextEmbeddingModel() + + num_tokens = model.get_num_tokens( + model="embo-01", + credentials={ + "minimax_api_key": 
os.environ.get("MINIMAX_API_KEY"), + "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"), + }, + texts=["hello", "world"], + ) + + assert num_tokens == 2 diff --git a/api/tests/integration_tests/model_runtime/minimax/test_llm.py b/api/tests/integration_tests/model_runtime/minimax/test_llm.py new file mode 100644 index 0000000000000000000000000000000000000000..aacde04d326cafa542e0c05c5dc6fab045393b28 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/minimax/test_llm.py @@ -0,0 +1,143 @@ +import os +from collections.abc import Generator +from time import sleep + +import pytest + +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import AssistantPromptMessage, UserPromptMessage +from core.model_runtime.entities.model_entities import AIModelEntity +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.minimax.llm.llm import MinimaxLargeLanguageModel + + +def test_predefined_models(): + model = MinimaxLargeLanguageModel() + model_schemas = model.predefined_models() + assert len(model_schemas) >= 1 + assert isinstance(model_schemas[0], AIModelEntity) + + +def test_validate_credentials_for_chat_model(): + sleep(3) + model = MinimaxLargeLanguageModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model="abab5.5-chat", credentials={"minimax_api_key": "invalid_key", "minimax_group_id": "invalid_key"} + ) + + model.validate_credentials( + model="abab5.5-chat", + credentials={ + "minimax_api_key": os.environ.get("MINIMAX_API_KEY"), + "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"), + }, + ) + + +def test_invoke_model(): + sleep(3) + model = MinimaxLargeLanguageModel() + + response = model.invoke( + model="abab5-chat", + credentials={ + "minimax_api_key": os.environ.get("MINIMAX_API_KEY"), + "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"), + }, 
+ prompt_messages=[UserPromptMessage(content="Hello World!")], + model_parameters={ + "temperature": 0.7, + "top_p": 1.0, + "top_k": 1, + }, + stop=["you"], + user="abc-123", + stream=False, + ) + + assert isinstance(response, LLMResult) + assert len(response.message.content) > 0 + assert response.usage.total_tokens > 0 + + +def test_invoke_stream_model(): + sleep(3) + model = MinimaxLargeLanguageModel() + + response = model.invoke( + model="abab5.5-chat", + credentials={ + "minimax_api_key": os.environ.get("MINIMAX_API_KEY"), + "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"), + }, + prompt_messages=[UserPromptMessage(content="Hello World!")], + model_parameters={ + "temperature": 0.7, + "top_p": 1.0, + "top_k": 1, + }, + stop=["you"], + stream=True, + user="abc-123", + ) + + assert isinstance(response, Generator) + for chunk in response: + assert isinstance(chunk, LLMResultChunk) + assert isinstance(chunk.delta, LLMResultChunkDelta) + assert isinstance(chunk.delta.message, AssistantPromptMessage) + assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True + + +def test_invoke_with_search(): + sleep(3) + model = MinimaxLargeLanguageModel() + + response = model.invoke( + model="abab5.5-chat", + credentials={ + "minimax_api_key": os.environ.get("MINIMAX_API_KEY"), + "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"), + }, + prompt_messages=[UserPromptMessage(content="北京今天的天气怎么样")], + model_parameters={ + "temperature": 0.7, + "top_p": 1.0, + "top_k": 1, + "plugin_web_search": True, + }, + stop=["you"], + stream=True, + user="abc-123", + ) + + assert isinstance(response, Generator) + total_message = "" + for chunk in response: + assert isinstance(chunk, LLMResultChunk) + assert isinstance(chunk.delta, LLMResultChunkDelta) + assert isinstance(chunk.delta.message, AssistantPromptMessage) + total_message += chunk.delta.message.content + assert len(chunk.delta.message.content) > 0 if not chunk.delta.finish_reason else True + + 
assert "参考资料" in total_message + + +def test_get_num_tokens(): + sleep(3) + model = MinimaxLargeLanguageModel() + + response = model.get_num_tokens( + model="abab5.5-chat", + credentials={ + "minimax_api_key": os.environ.get("MINIMAX_API_KEY"), + "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"), + }, + prompt_messages=[UserPromptMessage(content="Hello World!")], + tools=[], + ) + + assert isinstance(response, int) + assert response == 30 diff --git a/api/tests/integration_tests/model_runtime/minimax/test_provider.py b/api/tests/integration_tests/model_runtime/minimax/test_provider.py new file mode 100644 index 0000000000000000000000000000000000000000..575ed13eef124a4179d059e275732406032526b9 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/minimax/test_provider.py @@ -0,0 +1,25 @@ +import os + +import pytest + +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.minimax.minimax import MinimaxProvider + + +def test_validate_provider_credentials(): + provider = MinimaxProvider() + + with pytest.raises(CredentialsValidateFailedError): + provider.validate_provider_credentials( + credentials={ + "minimax_api_key": "hahahaha", + "minimax_group_id": "123", + } + ) + + provider.validate_provider_credentials( + credentials={ + "minimax_api_key": os.environ.get("MINIMAX_API_KEY"), + "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"), + } + ) diff --git a/api/tests/integration_tests/model_runtime/mixedbread/__init__.py b/api/tests/integration_tests/model_runtime/mixedbread/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/tests/integration_tests/model_runtime/mixedbread/test_provider.py b/api/tests/integration_tests/model_runtime/mixedbread/test_provider.py new file mode 100644 index 0000000000000000000000000000000000000000..25c9f3ce8dffa9400279ca95a3839682472f8628 --- /dev/null +++ 
b/api/tests/integration_tests/model_runtime/mixedbread/test_provider.py @@ -0,0 +1,28 @@ +import os +from unittest.mock import Mock, patch + +import pytest + +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.mixedbread.mixedbread import MixedBreadProvider + + +def test_validate_provider_credentials(): + provider = MixedBreadProvider() + + with pytest.raises(CredentialsValidateFailedError): + provider.validate_provider_credentials(credentials={"api_key": "hahahaha"}) + with patch("requests.post") as mock_post: + mock_response = Mock() + mock_response.json.return_value = { + "usage": {"prompt_tokens": 3, "total_tokens": 3}, + "model": "mixedbread-ai/mxbai-embed-large-v1", + "data": [{"embedding": [0.23333 for _ in range(1024)], "index": 0, "object": "embedding"}], + "object": "list", + "normalized": "true", + "encoding_format": "float", + "dimensions": 1024, + } + mock_response.status_code = 200 + mock_post.return_value = mock_response + provider.validate_provider_credentials(credentials={"api_key": os.environ.get("MIXEDBREAD_API_KEY")}) diff --git a/api/tests/integration_tests/model_runtime/mixedbread/test_rerank.py b/api/tests/integration_tests/model_runtime/mixedbread/test_rerank.py new file mode 100644 index 0000000000000000000000000000000000000000..b65aab74aa96d3113bb3eaff024b701ab1152438 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/mixedbread/test_rerank.py @@ -0,0 +1,100 @@ +import os +from unittest.mock import Mock, patch + +import pytest + +from core.model_runtime.entities.rerank_entities import RerankResult +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.mixedbread.rerank.rerank import MixedBreadRerankModel + + +def test_validate_credentials(): + model = MixedBreadRerankModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model="mxbai-rerank-large-v1", + 
credentials={"api_key": "invalid_key"}, + ) + with patch("httpx.post") as mock_post: + mock_response = Mock() + mock_response.json.return_value = { + "usage": {"prompt_tokens": 86, "total_tokens": 86}, + "model": "mixedbread-ai/mxbai-rerank-large-v1", + "data": [ + { + "index": 0, + "score": 0.06762695, + "input": "Carson City is the capital city of the American state of Nevada. At the 2010 United " + "States Census, Carson City had a population of 55,274.", + "object": "text_document", + }, + { + "index": 1, + "score": 0.057403564, + "input": "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific " + "Ocean that are a political division controlled by the United States. Its capital is " + "Saipan.", + "object": "text_document", + }, + ], + "object": "list", + "top_k": 2, + "return_input": True, + } + mock_response.status_code = 200 + mock_post.return_value = mock_response + model.validate_credentials( + model="mxbai-rerank-large-v1", + credentials={ + "api_key": os.environ.get("MIXEDBREAD_API_KEY"), + }, + ) + + +def test_invoke_model(): + model = MixedBreadRerankModel() + with patch("httpx.post") as mock_post: + mock_response = Mock() + mock_response.json.return_value = { + "usage": {"prompt_tokens": 56, "total_tokens": 56}, + "model": "mixedbread-ai/mxbai-rerank-large-v1", + "data": [ + { + "index": 0, + "score": 0.6044922, + "input": "Kasumi is a girl name of Japanese origin meaning mist.", + "object": "text_document", + }, + { + "index": 1, + "score": 0.0703125, + "input": "Her music is a kawaii bass, a mix of future bass, pop, and kawaii music and she leads a " + "team named PopiParty.", + "object": "text_document", + }, + ], + "object": "list", + "top_k": 2, + "return_input": "true", + } + mock_response.status_code = 200 + mock_post.return_value = mock_response + result = model.invoke( + model="mxbai-rerank-large-v1", + credentials={ + "api_key": os.environ.get("MIXEDBREAD_API_KEY"), + }, + query="Who is Kasumi?", + docs=[ + 
"Kasumi is a girl name of Japanese origin meaning mist.", + "Her music is a kawaii bass, a mix of future bass, pop, and kawaii music and she leads a team named " + "PopiParty.", + ], + score_threshold=0.5, + ) + + assert isinstance(result, RerankResult) + assert len(result.docs) == 1 + assert result.docs[0].index == 0 + assert result.docs[0].score >= 0.5 diff --git a/api/tests/integration_tests/model_runtime/mixedbread/test_text_embedding.py b/api/tests/integration_tests/model_runtime/mixedbread/test_text_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..ca97a1895113f02910033ff524fa810e66e0e474 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/mixedbread/test_text_embedding.py @@ -0,0 +1,78 @@ +import os +from unittest.mock import Mock, patch + +import pytest + +from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.mixedbread.text_embedding.text_embedding import MixedBreadTextEmbeddingModel + + +def test_validate_credentials(): + model = MixedBreadTextEmbeddingModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials(model="mxbai-embed-large-v1", credentials={"api_key": "invalid_key"}) + with patch("requests.post") as mock_post: + mock_response = Mock() + mock_response.json.return_value = { + "usage": {"prompt_tokens": 3, "total_tokens": 3}, + "model": "mixedbread-ai/mxbai-embed-large-v1", + "data": [{"embedding": [0.23333 for _ in range(1024)], "index": 0, "object": "embedding"}], + "object": "list", + "normalized": "true", + "encoding_format": "float", + "dimensions": 1024, + } + mock_response.status_code = 200 + mock_post.return_value = mock_response + model.validate_credentials( + model="mxbai-embed-large-v1", credentials={"api_key": os.environ.get("MIXEDBREAD_API_KEY")} + ) + + +def test_invoke_model(): + model = 
MixedBreadTextEmbeddingModel() + + with patch("requests.post") as mock_post: + mock_response = Mock() + mock_response.json.return_value = { + "usage": {"prompt_tokens": 6, "total_tokens": 6}, + "model": "mixedbread-ai/mxbai-embed-large-v1", + "data": [ + {"embedding": [0.23333 for _ in range(1024)], "index": 0, "object": "embedding"}, + {"embedding": [0.23333 for _ in range(1024)], "index": 1, "object": "embedding"}, + ], + "object": "list", + "normalized": "true", + "encoding_format": "float", + "dimensions": 1024, + } + mock_response.status_code = 200 + mock_post.return_value = mock_response + result = model.invoke( + model="mxbai-embed-large-v1", + credentials={ + "api_key": os.environ.get("MIXEDBREAD_API_KEY"), + }, + texts=["hello", "world"], + user="abc-123", + ) + + assert isinstance(result, TextEmbeddingResult) + assert len(result.embeddings) == 2 + assert result.usage.total_tokens == 6 + + +def test_get_num_tokens(): + model = MixedBreadTextEmbeddingModel() + + num_tokens = model.get_num_tokens( + model="mxbai-embed-large-v1", + credentials={ + "api_key": os.environ.get("MIXEDBREAD_API_KEY"), + }, + texts=["ping"], + ) + + assert num_tokens == 1 diff --git a/api/tests/integration_tests/model_runtime/nomic/__init__.py b/api/tests/integration_tests/model_runtime/nomic/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/tests/integration_tests/model_runtime/nomic/test_embeddings.py b/api/tests/integration_tests/model_runtime/nomic/test_embeddings.py new file mode 100644 index 0000000000000000000000000000000000000000..52dc96ee95c1bc5911f8c301eadd8263d734b1c7 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/nomic/test_embeddings.py @@ -0,0 +1,62 @@ +import os + +import pytest + +from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from 
core.model_runtime.model_providers.nomic.text_embedding.text_embedding import NomicTextEmbeddingModel +from tests.integration_tests.model_runtime.__mock.nomic_embeddings import setup_nomic_mock + + +@pytest.mark.parametrize("setup_nomic_mock", [["text_embedding"]], indirect=True) +def test_validate_credentials(setup_nomic_mock): + model = NomicTextEmbeddingModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model="nomic-embed-text-v1.5", + credentials={ + "nomic_api_key": "invalid_key", + }, + ) + + model.validate_credentials( + model="nomic-embed-text-v1.5", + credentials={ + "nomic_api_key": os.environ.get("NOMIC_API_KEY"), + }, + ) + + +@pytest.mark.parametrize("setup_nomic_mock", [["text_embedding"]], indirect=True) +def test_invoke_model(setup_nomic_mock): + model = NomicTextEmbeddingModel() + + result = model.invoke( + model="nomic-embed-text-v1.5", + credentials={ + "nomic_api_key": os.environ.get("NOMIC_API_KEY"), + }, + texts=["hello", "world"], + user="foo", + ) + + assert isinstance(result, TextEmbeddingResult) + assert result.model == "nomic-embed-text-v1.5" + assert len(result.embeddings) == 2 + assert result.usage.total_tokens == 2 + + +@pytest.mark.parametrize("setup_nomic_mock", [["text_embedding"]], indirect=True) +def test_get_num_tokens(setup_nomic_mock): + model = NomicTextEmbeddingModel() + + num_tokens = model.get_num_tokens( + model="nomic-embed-text-v1.5", + credentials={ + "nomic_api_key": os.environ.get("NOMIC_API_KEY"), + }, + texts=["hello", "world"], + ) + + assert num_tokens == 2 diff --git a/api/tests/integration_tests/model_runtime/nomic/test_provider.py b/api/tests/integration_tests/model_runtime/nomic/test_provider.py new file mode 100644 index 0000000000000000000000000000000000000000..ece4bb920080b35bda4c332359c9df65ed84768a --- /dev/null +++ b/api/tests/integration_tests/model_runtime/nomic/test_provider.py @@ -0,0 +1,21 @@ +import os + +import pytest + +from 
core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.nomic.nomic import NomicAtlasProvider +from tests.integration_tests.model_runtime.__mock.nomic_embeddings import setup_nomic_mock + + +@pytest.mark.parametrize("setup_nomic_mock", [["text_embedding"]], indirect=True) +def test_validate_provider_credentials(setup_nomic_mock): + provider = NomicAtlasProvider() + + with pytest.raises(CredentialsValidateFailedError): + provider.validate_provider_credentials(credentials={}) + + provider.validate_provider_credentials( + credentials={ + "nomic_api_key": os.environ.get("NOMIC_API_KEY"), + }, + ) diff --git a/api/tests/integration_tests/model_runtime/novita/__init__.py b/api/tests/integration_tests/model_runtime/novita/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/tests/integration_tests/model_runtime/novita/test_llm.py b/api/tests/integration_tests/model_runtime/novita/test_llm.py new file mode 100644 index 0000000000000000000000000000000000000000..9f92679cd5ccfab3eb75e421585af09ae72b58cd --- /dev/null +++ b/api/tests/integration_tests/model_runtime/novita/test_llm.py @@ -0,0 +1,98 @@ +import os +from collections.abc import Generator + +import pytest + +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + SystemPromptMessage, + UserPromptMessage, +) +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.novita.llm.llm import NovitaLargeLanguageModel + + +def test_validate_credentials(): + model = NovitaLargeLanguageModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model="meta-llama/llama-3-8b-instruct", credentials={"api_key": "invalid_key", "mode": "chat"} + ) + + 
model.validate_credentials( + model="meta-llama/llama-3-8b-instruct", + credentials={"api_key": os.environ.get("NOVITA_API_KEY"), "mode": "chat"}, + ) + + +def test_invoke_model(): + model = NovitaLargeLanguageModel() + + response = model.invoke( + model="meta-llama/llama-3-8b-instruct", + credentials={"api_key": os.environ.get("NOVITA_API_KEY"), "mode": "completion"}, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Who are you?"), + ], + model_parameters={ + "temperature": 1.0, + "top_p": 0.5, + "max_tokens": 10, + }, + stop=["How"], + stream=False, + user="novita", + ) + + assert isinstance(response, LLMResult) + assert len(response.message.content) > 0 + + +def test_invoke_stream_model(): + model = NovitaLargeLanguageModel() + + response = model.invoke( + model="meta-llama/llama-3-8b-instruct", + credentials={"api_key": os.environ.get("NOVITA_API_KEY"), "mode": "chat"}, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Who are you?"), + ], + model_parameters={"temperature": 1.0, "top_k": 2, "top_p": 0.5, "max_tokens": 100}, + stream=True, + user="novita", + ) + + assert isinstance(response, Generator) + + for chunk in response: + assert isinstance(chunk, LLMResultChunk) + assert isinstance(chunk.delta, LLMResultChunkDelta) + assert isinstance(chunk.delta.message, AssistantPromptMessage) + + +def test_get_num_tokens(): + model = NovitaLargeLanguageModel() + + num_tokens = model.get_num_tokens( + model="meta-llama/llama-3-8b-instruct", + credentials={ + "api_key": os.environ.get("NOVITA_API_KEY"), + }, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Hello World!"), + ], + ) + + assert isinstance(num_tokens, int) + assert num_tokens == 21 diff --git a/api/tests/integration_tests/model_runtime/novita/test_provider.py 
b/api/tests/integration_tests/model_runtime/novita/test_provider.py new file mode 100644 index 0000000000000000000000000000000000000000..191af99db20bd974837384158dabc6347ca4c852 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/novita/test_provider.py @@ -0,0 +1,19 @@ +import os + +import pytest + +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.novita.novita import NovitaProvider + + +def test_validate_provider_credentials(): + provider = NovitaProvider() + + with pytest.raises(CredentialsValidateFailedError): + provider.validate_provider_credentials(credentials={}) + + provider.validate_provider_credentials( + credentials={ + "api_key": os.environ.get("NOVITA_API_KEY"), + } + ) diff --git a/api/tests/integration_tests/model_runtime/oci/__init__.py b/api/tests/integration_tests/model_runtime/oci/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/tests/integration_tests/model_runtime/oci/test_llm.py b/api/tests/integration_tests/model_runtime/oci/test_llm.py new file mode 100644 index 0000000000000000000000000000000000000000..bd5d27eb0f2a02cb631c0f548c0d76a0f64e9e0c --- /dev/null +++ b/api/tests/integration_tests/model_runtime/oci/test_llm.py @@ -0,0 +1,129 @@ +import os +from collections.abc import Generator + +import pytest + +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessageTool, + SystemPromptMessage, + UserPromptMessage, +) +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.oci.llm.llm import OCILargeLanguageModel + + +def test_validate_credentials(): + model = OCILargeLanguageModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + 
model="cohere.command-r-plus", + credentials={"oci_config_content": "invalid_key", "oci_key_content": "invalid_key"}, + ) + + model.validate_credentials( + model="cohere.command-r-plus", + credentials={ + "oci_config_content": os.environ.get("OCI_CONFIG_CONTENT"), + "oci_key_content": os.environ.get("OCI_KEY_CONTENT"), + }, + ) + + +def test_invoke_model(): + model = OCILargeLanguageModel() + + response = model.invoke( + model="cohere.command-r-plus", + credentials={ + "oci_config_content": os.environ.get("OCI_CONFIG_CONTENT"), + "oci_key_content": os.environ.get("OCI_KEY_CONTENT"), + }, + prompt_messages=[UserPromptMessage(content="Hi")], + model_parameters={"temperature": 0.5, "max_tokens": 10}, + stop=["How"], + stream=False, + user="abc-123", + ) + + assert isinstance(response, LLMResult) + assert len(response.message.content) > 0 + + +def test_invoke_stream_model(): + model = OCILargeLanguageModel() + + response = model.invoke( + model="meta.llama-3-70b-instruct", + credentials={ + "oci_config_content": os.environ.get("OCI_CONFIG_CONTENT"), + "oci_key_content": os.environ.get("OCI_KEY_CONTENT"), + }, + prompt_messages=[UserPromptMessage(content="Hi")], + model_parameters={"temperature": 0.5, "max_tokens": 100, "seed": 1234}, + stream=True, + user="abc-123", + ) + + assert isinstance(response, Generator) + + for chunk in response: + assert isinstance(chunk, LLMResultChunk) + assert isinstance(chunk.delta, LLMResultChunkDelta) + assert isinstance(chunk.delta.message, AssistantPromptMessage) + assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True + + +def test_invoke_model_with_function(): + model = OCILargeLanguageModel() + + response = model.invoke( + model="cohere.command-r-plus", + credentials={ + "oci_config_content": os.environ.get("OCI_CONFIG_CONTENT"), + "oci_key_content": os.environ.get("OCI_KEY_CONTENT"), + }, + prompt_messages=[UserPromptMessage(content="Hi")], + model_parameters={"temperature": 0.5, "max_tokens": 
100, "seed": 1234}, + stream=False, + user="abc-123", + tools=[ + PromptMessageTool( + name="get_current_weather", + description="Get the current weather in a given location", + parameters={ + "type": "object", + "properties": { + "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"}, + "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, + }, + "required": ["location"], + }, + ) + ], + ) + + assert isinstance(response, LLMResult) + assert len(response.message.content) > 0 + + +def test_get_num_tokens(): + model = OCILargeLanguageModel() + + num_tokens = model.get_num_tokens( + model="cohere.command-r-plus", + credentials={ + "oci_config_content": os.environ.get("OCI_CONFIG_CONTENT"), + "oci_key_content": os.environ.get("OCI_KEY_CONTENT"), + }, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Hello World!"), + ], + ) + + assert num_tokens == 18 diff --git a/api/tests/integration_tests/model_runtime/oci/test_provider.py b/api/tests/integration_tests/model_runtime/oci/test_provider.py new file mode 100644 index 0000000000000000000000000000000000000000..2c7107c7ccfe45215dcac91b8ff0637707c13f53 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/oci/test_provider.py @@ -0,0 +1,20 @@ +import os + +import pytest + +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.oci.oci import OCIGENAIProvider + + +def test_validate_provider_credentials(): + provider = OCIGENAIProvider() + + with pytest.raises(CredentialsValidateFailedError): + provider.validate_provider_credentials(credentials={}) + + provider.validate_provider_credentials( + credentials={ + "oci_config_content": os.environ.get("OCI_CONFIG_CONTENT"), + "oci_key_content": os.environ.get("OCI_KEY_CONTENT"), + } + ) diff --git a/api/tests/integration_tests/model_runtime/oci/test_text_embedding.py 
b/api/tests/integration_tests/model_runtime/oci/test_text_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..032c5c681a7aeb399401d7109eaf704cace386f5 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/oci/test_text_embedding.py @@ -0,0 +1,58 @@ +import os + +import pytest + +from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.oci.text_embedding.text_embedding import OCITextEmbeddingModel + + +def test_validate_credentials(): + model = OCITextEmbeddingModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model="cohere.embed-multilingual-v3.0", + credentials={"oci_config_content": "invalid_key", "oci_key_content": "invalid_key"}, + ) + + model.validate_credentials( + model="cohere.embed-multilingual-v3.0", + credentials={ + "oci_config_content": os.environ.get("OCI_CONFIG_CONTENT"), + "oci_key_content": os.environ.get("OCI_KEY_CONTENT"), + }, + ) + + +def test_invoke_model(): + model = OCITextEmbeddingModel() + + result = model.invoke( + model="cohere.embed-multilingual-v3.0", + credentials={ + "oci_config_content": os.environ.get("OCI_CONFIG_CONTENT"), + "oci_key_content": os.environ.get("OCI_KEY_CONTENT"), + }, + texts=["hello", "world", " ".join(["long_text"] * 100), " ".join(["another_long_text"] * 100)], + user="abc-123", + ) + + assert isinstance(result, TextEmbeddingResult) + assert len(result.embeddings) == 4 + # assert result.usage.total_tokens == 811 + + +def test_get_num_tokens(): + model = OCITextEmbeddingModel() + + num_tokens = model.get_num_tokens( + model="cohere.embed-multilingual-v3.0", + credentials={ + "oci_config_content": os.environ.get("OCI_CONFIG_CONTENT"), + "oci_key_content": os.environ.get("OCI_KEY_CONTENT"), + }, + texts=["hello", "world"], + ) + + assert num_tokens == 2 diff --git 
a/api/tests/integration_tests/model_runtime/ollama/__init__.py b/api/tests/integration_tests/model_runtime/ollama/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/tests/integration_tests/model_runtime/ollama/test_llm.py b/api/tests/integration_tests/model_runtime/ollama/test_llm.py new file mode 100644 index 0000000000000000000000000000000000000000..979751afceaca4188bbfe7bddbc046fce217a79f --- /dev/null +++ b/api/tests/integration_tests/model_runtime/ollama/test_llm.py @@ -0,0 +1,226 @@ +import os +from collections.abc import Generator + +import pytest + +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + ImagePromptMessageContent, + SystemPromptMessage, + TextPromptMessageContent, + UserPromptMessage, +) +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.ollama.llm.llm import OllamaLargeLanguageModel + + +def test_validate_credentials(): + model = OllamaLargeLanguageModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model="mistral:text", + credentials={ + "base_url": "http://localhost:21434", + "mode": "chat", + "context_size": 2048, + "max_tokens": 2048, + }, + ) + + model.validate_credentials( + model="mistral:text", + credentials={ + "base_url": os.environ.get("OLLAMA_BASE_URL"), + "mode": "chat", + "context_size": 2048, + "max_tokens": 2048, + }, + ) + + +def test_invoke_model(): + model = OllamaLargeLanguageModel() + + response = model.invoke( + model="mistral:text", + credentials={ + "base_url": os.environ.get("OLLAMA_BASE_URL"), + "mode": "chat", + "context_size": 2048, + "max_tokens": 2048, + }, + prompt_messages=[UserPromptMessage(content="Who are you?")], + model_parameters={"temperature": 1.0, "top_k": 2, "top_p": 0.5, 
"num_predict": 10}, + stop=["How"], + stream=False, + ) + + assert isinstance(response, LLMResult) + assert len(response.message.content) > 0 + + +def test_invoke_stream_model(): + model = OllamaLargeLanguageModel() + + response = model.invoke( + model="mistral:text", + credentials={ + "base_url": os.environ.get("OLLAMA_BASE_URL"), + "mode": "chat", + "context_size": 2048, + "max_tokens": 2048, + }, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Who are you?"), + ], + model_parameters={"temperature": 1.0, "top_k": 2, "top_p": 0.5, "num_predict": 10}, + stop=["How"], + stream=True, + ) + + assert isinstance(response, Generator) + + for chunk in response: + assert isinstance(chunk, LLMResultChunk) + assert isinstance(chunk.delta, LLMResultChunkDelta) + assert isinstance(chunk.delta.message, AssistantPromptMessage) + + +def test_invoke_completion_model(): + model = OllamaLargeLanguageModel() + + response = model.invoke( + model="mistral:text", + credentials={ + "base_url": os.environ.get("OLLAMA_BASE_URL"), + "mode": "completion", + "context_size": 2048, + "max_tokens": 2048, + }, + prompt_messages=[UserPromptMessage(content="Who are you?")], + model_parameters={"temperature": 1.0, "top_k": 2, "top_p": 0.5, "num_predict": 10}, + stop=["How"], + stream=False, + ) + + assert isinstance(response, LLMResult) + assert len(response.message.content) > 0 + + +def test_invoke_stream_completion_model(): + model = OllamaLargeLanguageModel() + + response = model.invoke( + model="mistral:text", + credentials={ + "base_url": os.environ.get("OLLAMA_BASE_URL"), + "mode": "completion", + "context_size": 2048, + "max_tokens": 2048, + }, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Who are you?"), + ], + model_parameters={"temperature": 1.0, "top_k": 2, "top_p": 0.5, "num_predict": 10}, + stop=["How"], + stream=True, + ) + + assert 
isinstance(response, Generator) + + for chunk in response: + assert isinstance(chunk, LLMResultChunk) + assert isinstance(chunk.delta, LLMResultChunkDelta) + assert isinstance(chunk.delta.message, AssistantPromptMessage) + + +def test_invoke_completion_model_with_vision(): + model = OllamaLargeLanguageModel() + + result = model.invoke( + model="llava", + credentials={ + "base_url": os.environ.get("OLLAMA_BASE_URL"), + "mode": "completion", + "context_size": 2048, + "max_tokens": 2048, + }, + prompt_messages=[ + UserPromptMessage( + content=[ + TextPromptMessageContent( + data="What is this in this picture?", + ), + ImagePromptMessageContent( + mime_type="image/png", + format="png", + base64_data="iVBORw0KGgoAAAANSUhEUgAAAE4AAABMCAYAAADDYoEWAAAMQGlDQ1BJQ0MgUHJvZmlsZQAASImVVwdYU8kWnluSkEBoAQSkhN4EkRpASggt9I4gKiEJEEqMgaBiRxcVXLuIgA1dFVGwAmJBETuLYu+LBRVlXSzYlTcpoOu+8r35vrnz33/O/OfMmbllAFA7zhGJclF1APKEBeLYYH/6uOQUOukpIAEdoAy0gA2Hmy9iRkeHA1iG2r+Xd9cBIm2v2Eu1/tn/X4sGj5/PBQCJhjidl8/Ng/gAAHg1VyQuAIAo5c2mFoikGFagJYYBQrxIijPluFqK0+V4j8wmPpYFcTsASiocjjgTANVLkKcXcjOhhmo/xI5CnkAIgBodYp+8vMk8iNMgtoY2Ioil+oz0H3Qy/6aZPqzJ4WQOY/lcZEUpQJAvyuVM/z/T8b9LXq5kyIclrCpZ4pBY6Zxh3m7mTA6TYhWI+4TpkVEQa0L8QcCT2UOMUrIkIQlye9SAm8+COYMrDVBHHicgDGIDiIOEuZHhCj49QxDEhhjuEHSaoIAdD7EuxIv4+YFxCptN4smxCl9oY4aYxVTwZzlimV+pr/uSnASmQv91Fp+t0MdUi7LikyCmQGxeKEiMhFgVYof8nLgwhc3YoixW5JCNWBIrjd8c4li+MNhfro8VZoiDYhX2pXn5Q/PFNmUJ2JEKvK8gKz5Enh+sncuRxQ/ngl3iC5kJQzr8/HHhQ3Ph8QMC5XPHnvGFCXEKnQ+iAv9Y+VicIsqNVtjjpvzcYClvCrFLfmGcYiyeWAA3pFwfzxAVRMfL48SLsjmh0fJ48OUgHLBAAKADCazpYDLIBoLOvqY+eCfvCQIcIAaZgA/sFczQiCRZjxBe40AR+BMiPsgfHucv6+WDQsh/HWblV3uQIestlI3IAU8gzgNhIBfeS2SjhMPeEsFjyAj+4Z0DKxfGmwurtP/f80Psd4YJmXAFIxnySFcbsiQGEgOIIcQgog2uj/vgXng4vPrB6oQzcI+heXy3JzwhdBEeEq4Rugm3JgmKxT9FGQG6oX6QIhfpP+YCt4Sarrg/7g3VoTKug+sDe9wF+mHivtCzK2RZirilWaH/pP23GfywGgo7siMZJY8g+5Gtfx6paqvqOqwizfWP+ZHHmj6cb9Zwz8/+WT9knwfbsJ8tsUXYfuwMdgI7hx3BmgAda8WasQ7sqBQP767Hst015C1WFk8O1BH8w9/Qykozme9Y59jr+EXeV8CfJn1HA9Zk0XSxIDOrgM6EXwQ+nS3kOoyiOzk6OQMg/b
7IX19vYmTfDUSn4zs3/w8AvFsHBwcPf+dCWwHY6w4f/0PfOWsG/HQoA3D2EFciLpRzuPRCgG8JNfik6QEjYAas4XycgBvwAn4gEISCKBAPksFEGH0W3OdiMBXMBPNACSgDy8EaUAk2gi1gB9gN9oEmcAScAKfBBXAJXAN34O7pAS9AP3gHPiMIQkKoCA3RQ4wRC8QOcUIYiA8SiIQjsUgykoZkIkJEgsxE5iNlyEqkEtmM1CJ7kUPICeQc0oXcQh4gvchr5BOKoSqoFmqIWqKjUQbKRMPQeHQCmolOQYvQBehStAKtQXehjegJ9AJ6De1GX6ADGMCUMR3MBLPHGBgLi8JSsAxMjM3GSrFyrAarx1rgOl/BurE+7CNOxGk4HbeHOzgET8C5+BR8Nr4Er8R34I14O34Ff4D3498IVIIBwY7gSWATxhEyCVMJJYRywjbCQcIp+Cz1EN4RiUQdohXRHT6LycRs4gziEuJ6YgPxOLGL+Ig4QCKR9Eh2JG9SFIlDKiCVkNaRdpFaSZdJPaQPSspKxkpOSkFKKUpCpWKlcqWdSseULis9VfpMVidbkD3JUWQeeTp5GXkruYV8kdxD/kzRoFhRvCnxlGzKPEoFpZ5yinKX8kZZWdlU2UM5RlmgPFe5QnmP8lnlB8ofVTRVbFVYKqkqEpWlKttVjqvcUnlDpVItqX7UFGoBdSm1lnqSep/6QZWm6qDKVuWpzlGtUm1Uvaz6Uo2sZqHGVJuoVqRWrrZf7aJanzpZ3VKdpc5Rn61epX5I/Yb6gAZNY4xGlEaexhKNnRrnNJ5pkjQtNQM1eZoLNLdontR8RMNoZjQWjUubT9tKO0Xr0SJqWWmxtbK1yrR2a3Vq9WtrartoJ2pP067SPqrdrYPpWOqwdXJ1luns07mu82mE4QjmCP6IxSPqR1we8V53pK6fLl+3VLdB95ruJz26XqBejt4KvSa9e/q4vq1+jP5U/Q36p/T7RmqN9BrJHVk6ct/I2waoga1BrMEMgy0GHQYDhkaGwYYiw3WGJw37jHSM/IyyjVYbHTPqNaYZ+xgLjFcbtxo/p2vTmfRcegW9nd5vYmASYiIx2WzSafLZ1Mo0wbTYtMH0nhnFjGGWYbbarM2s39zYPMJ8pnmd+W0LsgXDIstircUZi/eWVpZJlgstmyyfWelasa2KrOqs7lpTrX2tp1jXWF+1IdowbHJs1ttcskVtXW2zbKtsL9qhdm52Arv1dl2jCKM8RglH1Yy6Ya9iz7QvtK+zf+Cg4xDuUOzQ5PBytPnolNErRp8Z/c3R1THXcavjnTGaY0LHFI9pGfPaydaJ61TldNWZ6hzkPMe52fmVi50L32WDy01XmmuE60LXNtevbu5uYrd6t153c/c092r3GwwtRjRjCeOsB8HD32OOxxGPj55ungWe+zz/8rL3yvHa6fVsrNVY/titYx95m3pzvDd7d/vQfdJ8Nvl0+5r4cnxrfB/6mfnx/Lb5PWXaMLOZu5gv/R39xf4H/d+zPFmzWMcDsIDggNKAzkDNwITAysD7QaZBmUF1Qf3BrsEzgo+HEELCQlaE3GAbsrnsWnZ/qHvorND2MJWwuLDKsIfhtuHi8JYINCI0YlXE3UiLSGFkUxSIYketiroXbRU9JfpwDDEmOqYq5knsmNiZsWfiaHGT4nbGvYv3j18WfyfBOkGS0JaolpiaWJv4PikgaWVS97jR42aNu5CsnyxIbk4hpSSmbEsZGB84fs34nlTX1JLU6xOsJkybcG6i/sTciUcnqU3iTNqfRkhLStuZ9oUTxanhDKSz06vT+7ks7lruC54fbzWvl+/NX8l/muGdsTLjWaZ35qrM3izfrPKsPgFLUCl4lR2SvTH7fU5Uzvacwdyk3IY8pby0vENCTWGOsH2y0eRpk7tEdqISUfcUzylrpvSLw8Tb8pH8CfnNBVrwR75DYi35RfKg0KewqvDD1MSp+6dpTBNO65huO33x9KdFQUW/zcBncGe0zTSZOW
/mg1nMWZtnI7PTZ7fNMZuzYE7P3OC5O+ZR5uXM+73YsXhl8dv5SfNbFhgumLvg0S/Bv9SVqJaIS24s9Fq4cRG+SLCoc7Hz4nWLv5XySs+XOZaVl31Zwl1y/tcxv1b8Org0Y2nnMrdlG5YTlwuXX1/hu2LHSo2VRSsfrYpY1biavrp09ds1k9acK3cp37iWslaytrsivKJ5nfm65eu+VGZVXqvyr2qoNqheXP1+PW/95Q1+G+o3Gm4s2/hpk2DTzc3BmxtrLGvKtxC3FG55sjVx65nfGL/VbtPfVrbt63bh9u4dsTvaa91ra3ca7FxWh9ZJ6np3pe66tDtgd3O9ff3mBp2Gsj1gj2TP871pe6/vC9vXtp+xv/6AxYHqg7SDpY1I4/TG/qaspu7m5OauQ6GH2lq8Wg4edji8/YjJkaqj2keXHaMcW3BssLWodeC46HjficwTj9omtd05Oe7k1faY9s5TYafOng46ffIM80zrWe+zR855njt0nnG+6YLbhcYO146Dv7v+frDTrbPxovvF5ksel1q6xnYdu+x7+cSVgCunr7KvXrgWea3resL1mzdSb3Tf5N18div31qvbhbc/35l7l3C39J76vfL7Bvdr/rD5o6Hbrfvog4AHHQ/jHt55xH304nH+4y89C55Qn5Q/NX5a+8zp2ZHeoN5Lz8c/73khevG5r+RPjT+rX1q/PPCX318d/eP6e16JXw2+XvJG7832ty5v2waiB+6/y3v3+X3pB70POz4yPp75lPTp6eepX0hfKr7afG35Fvbt7mDe4KCII+bIfgUwWNGMDABebweAmgwADZ7PKOPl5z9ZQeRnVhkC/wnLz4iy4gZAPfx/j+mDfzc3ANizFR6/oL5aKgDRVADiPQDq7Dxch85qsnOltBDhOWBT5Nf0vHTwb4r8zPlD3D+3QKrqAn5u/wWdZ3xtG7qP3QAAADhlWElmTU0AKgAAAAgAAYdpAAQAAAABAAAAGgAAAAAAAqACAAQAAAABAAAATqADAAQAAAABAAAATAAAAADhTXUdAAARnUlEQVR4Ae2c245bR3aGi4fulizFHgUzQAYIggBB5klymfeaZ8hDBYjvAiRxkMAGkowRWx7JktjcZL7vX1Uku62Burkl5YbV5q7Tqqq1/v3XqgMpL95tbvftEh6NwPLRLS4NgsAFuDOJcAHuAtyZCJzZ7MK4C3BnInBmswvjLsCdicCZzS6MOxO49Znt0uz3//CPbbv6srXFrq0W9Q6Wi0VbLPn4R8x/jSLiu3nrl8s9dcartlwtKdmTbm21XranN6v27Mm6XV8t25fP1+3Pn1+1r4if3Czbk+t9u1rR6f9jmAXc1P6sbaevQGbfdgGJeA8ke0AQsCYYgiYgPR1QyVO+3wvcMm2WO0G2PeWkX79btp839AG4//UjYC62gDsB2rI9f7pov3q2bX/9F1ftBWAufTufOcwCrnTtR90dOdHoNgCJeAbUkuM5TsWAW5W9gfkE83ZkUHg0oAyAwbm927a2ebVoP/xx2f7jD1uYuG9/89tF+/VXK1hq+88TZgG32O1g2r7tpRdBM8fUTM7pyR8SYddgxkJErUszHti7U44CpzyEo16syNtx+qgy+1og7RMetpev9+3rb3bt+c2u/ebFsv3uL1ftiqn+qcMs4HY7jNQpEfadNU5VqeHUTJkgUbaPDxRADdZ8jU9LHoJYnwLUtgWN4ObDC7Kdr8Hp7d9qMTW8gt23V1zyvPrD1H56e9t+99vr9uJLprBDfaIw69U4dQRCIw2JdVIjbUzecj+7qYyPpZHiAbDaJwsXyMhQEQ0pq6sAp7hMS2XGqykdA2iy4EUtF6v206ur9k/fbNo//+frtt2OaW/rjxtmAaeNGqihBY5xfVQzQEZfoSH0KHgkrbD/CX6vPIqlSTU61vVCovRSbEwbIS851vj23Q+tff3vu/bzu5I7tvs4qVnADTa5FCbNC86qCLN2E1MxKKroYB2pgSz2RLbbVcVkSJ
hOKxIDjGxn+nSuqes2JlKuG8fA/IzPXazbj68X7et/27UfX7GifORwOuSju47h/c3beKfRFO74CNA04YP0ZT2/YzERFGojc9pmDG47/wyDZwJjiX4wwJNer1dZPJbs5/xzK5Ppzp7SQZBszNy22U7tX7/dtFdvJrv8aGE2cDJLoPycBgHSgICJUQLo8nmUo6y7oH0S5Lu/FGhDQULCfIooATw3yyOQQ46eYVpYiaBMTFtAFPR307r9y3fbdvsRfd5Rg6HJI2Lt1qaAF6TEqoxWdVdYSHawezCvAHLjW7Jh2QGcUkDDT4Og2OfSFRVkxipcAJUZARC5FVRbeRpB1hVY6r25XQHexIZ96Hfa++PTs4Dbi8rQg7imWQG27/uEgCTCssk/WWg7GwJWwDQ36PceGzQ+x7jOtgNogkIIpsZiFMdXoEfOPUlh3l5ulu2/X6bJ7Mc84Bw+xgOKzJqM0VKm8WYlVMqt61gFKNtQKeZ6o7Ls/aqEeYooJXDIZ9uiT0uZ5UxPUJNlYdoAK62qHfM7unz3/bb9/Ha+v3u/tn3AD0XOrnxAZdpNYZILgoxyGk4BqMCbssq66dXv6RdFkiB6Rj2u3N1npiMw1dQjF4oJW/kzy6VdMRFA9Xd8VvhCLxCyYUYkvhHZb7+fotvdUR6XmwXcYI1DangAA6yspgBj/dRjp6L+RbmSPaaxuuMnGEeVAhBF4pSapAFG5gUo60rAHmpVtcz0sR2aBZW8NAB9+W7dXr9N0dmPmUcu10pWrq7kQQvBQXn1dUsgoM4ej12TtyBknG51PEMGOV2TLLVZ/GLvLMBYHsYJhg7fuMBx6tq3LFu7aBxxD9jKFiO7Thbwcv7n5dS+/ML0eWEWcBqoptk+mEQp2aTG+rbmBYA+D6MyMwMAdepKsX5QpnglFZyZ5k4tDYsI/Y1pF7CRq22HoHXgGEOwgodvgH79INnW3tlFIVVQvkBXg1dvF3z27fkTGzw+zALOPZluVoVkV4yLHoBB3VBJUNyo6uEWXAyIkruC2OQjbVeppxkm8+iti2mySsM1EPYGKBcEyul3LKTW1+pr+wLRstwP0J8a2K95Txf/+6q1ZzeUDEXt/oFhHnA4fJYCBtawYlWmlsrJBEHhP43bi9Rq1Z0ymlK3Z/QCRqA5YfaNLZJWEACn929eluXlUGO8CgMrHWYi441S2tsFebLRL5RWL0e0nL64SEEf2sjMR4ZZwA0Ddfziclz1eN8yDn1qAaHSq3G0FEQXjABDo51sJVNyGnA0QlAPL4LOApzMo0mY1sUFbQBj8xTzYhKrROYF5VGIftR1uW3+3uiWU8XnBw7l3HIYVG/P/djYgMZoyrTJrci0n2qPZVnNFV913viW6btGzsXBT6aW3VKmsauVTFOc2DxpP5YJYLBBeCUixE71IlGBR2EF+6OugHbP12Ddoj29HgIPj+cxDiPDFGINzB8sKhLh0Ui4gOgDI8deb8FiwYxlteWhLHWTlmOzhkxLAObPIkFqS8+bbG5BdgWiAmJTwXdqZ7oysktzdKC/BWMWiAJNpyP0ZPTMItRy7fTi2RB4eDwLuIkpCma1gob/Dsw7zcKAMf3txiCot8c42ZCDPu3WAqRMJAGEk4cACaLzSZsFRhAE9QoAtXcwTX92XDT0sxTQXJYHdDJin0KfVN8PmzNvnOYBx5XNlik4giumihb7tJ60ezgNhgXuXgRNttxunZYAj7uzbL3nUA67rm5KJWrJCyTfIVwBMh3bTkD8TqFYp6uv8RwrgJpAZmHHScqv0qWeKT48NujhAuELekyYBdz9gXJQ53DvDh3tU62xTtN8bQhzzE9OccAK8wA2ez2k3cNtN7wM/RZs9M5NkNZoee0H2rmhLr8miPV9roAZtN1RHV/gDb7EoUtXKeXjYXUBN0oeFs8CbrtlhZRGPZSSZNyI9gA+TBFkelFNWxgEgCtG3wDiFqEr5Jz6y/U1DAM4QLxi2l7DNhl3w/epNTUFWGbXC7HrMQ
Mz7WUbf8AaDQ46DYXuxLoJX6CFRzvuiPyJzCzgZIoKyqgKAx1yAGPQUWfa+GoDsqwDJNnHLF9juSz0i5VrpvqSwmsQul5dtyfrfX1zL3i0WdHHSjaKVjf0T5k7ABtxlEHbwxusgjydAY8N84BjvAx5GLfMqBW0VJEZ+pwKskQnbpnFHPzpwWo/bzkGvX51296+bu1v/+qL9usXT9rTJ07Bzh9k9HEPsxNhwhh6xLXKo3fXWf3iMkrBBz9nAbflbHm6ONxhXp8/NW26lkSleIEV9FBVI+o6ihjmffPDt+3v/+5Z+82vnsZw/fyercweB2d7wzA8mfuPEknpXTnHvQsoPd1v/aD8LODw+AxbAw/QjnEfv69u5kz6dtOiW2R6YmW7vd0C3qK94wcjf/zxZ1bRXfvqGT6U3f2G/Z6AesqotgJX477PNVmTmxfiwTSS5irqz2ybEHD6PzbMAk7lS/0BxgkTqPAUYBiAkQpTLLdKxe1D4Lbsp968uW1vXk+ZrnpsN7yL1TbmbvCl4GcPPPStZWyNcM9s++9y92ruZu2CT21q7lZ9KDcLuC3WbmGG42uA30EISOVkFynt1BBialOliF/wZHqGTa1tOfq8fbMHPL6N2iBPW2d7HfxZdWnreiN49UL0dfhLR6tBSVVwNo+TQ1U5IsHvQU4Dcry7bGNOix+SngVcwAhYpZjTQxaNMABLLLtUFEAMEwi4kk63fGDbLTcVm82ubd7hNylzEXCa6SPdz2Vf5iUobe0jAFIq8+JHT8CjGeUjHFOj5E7MIO4THxvOaHIcwu2IOKiznyg89BTEXi6WssO8B36vkLa33Pv7/QRbEtm21c/BtIm9Yb4ho19PDg4g09aeucySdpzq3BfVx6WQqh7MkLOSkHLf2olEKni4n7xznh0VH4jnAYdy6hfVSZTvUmF54f2cU9d9XmlhvUyTlbkxIT0BWtgH4wRRgPMy7EFbAwi8ojzbNyqtH/7coWxnUHyE+rmYjbs3NCnqdwIbbM/GZ4RZwDleVskO3viSBhWjSu2Pxj7JU4bsqrzTU5YZQ7xKu73Bb8bAbo+s28NStxEyb8e+K1UAKXhOVivK7x0RUANf3zEw/smJpsr37cad9RlhFnCbzQYwfN36I+5qwxgVwRA/vOHxlneeMiaux9lymN5tTTttkZN5mbZwCYsLM550taA+zJM5gsdHsGSdQTbngN7ZlC/JrRhXIcorRJvVcp2pnjzdy+0nnErOCbOAE5x8d4oVCy4xMSFGetjfgWJ3MQFHdomxZbUwwC4B84YlzBNojUEmxmqO1tVC4VcVopUzKuXK+XArUeDVTyq85wv7xKqHsel1dfIUkl8zUXcFm8eUH7IPjWcBp8J5mYxWcWmbclhlyEIAMJm2HbSwDCHZGD9IuR1UH4MhaZ4HOAIQIJOrIxfjxOFRUMNQq8wI9EH5WNVJdcEje22ofxs3K6PlQ+OZwA2ghrFSKhiEVSqh/5JJcfodKBnntLac7wb5CKLpAs+0RguYuAhoNh2CRV1dTVFhqWhRn/u+tOsMtTph6JhOkAWsQDz1K3NHeHyYBZyK70BG5oy3SyqGumoaAhr1Aiggnm8FzXr3cQWSq++p8seM10v6LW9Elgh5kyGINXMdi1xspw2LRHwqMjJTV2KdU9c2eQ1SkXDDHL2aYf2MprVp1dFrtcBlAWB/sNuxMoJIzEfRqhMk04qXfM0n8yVDaa/DRLp1GuGSKhNz65ZEOQUSdyD0Y/adRSojsxjoz2jnNFdN3l/S+sUvnqbDsx+zgCvQMJzhPaCrlouCLBvbA43x68DhsAc7DxpTr0y39VAMBCfpSlpSUMggzRe8X4bIAWRYJqVJj6t7feMV/9Bkfeb+bYw2Czg78S3GwWtEQEPRWFMMEDAZhVTiMaWLnZZRxSexfaStPR9DAXbMj5Qs479Dm8PqqYCNEpUTVAe/GpLC3vH16hI64zkLuB1XQVsdFkED8ps40oLjj2sMAdbFwGlKRj
bW6UHAFZaRJVegIpeWVafZhQ4yHahUm+5VyfOwXYFHTX8DKUNSn+fCcsN3qOd8AT3GGPEs4EYnxho9YlOnU1WTUj98GbLKWCawI5wk71DiBMoh+qjYfgXUc+nNlW+rXuqjOrknPAs4sRoHcvvNguDZNEChYOoBUUZ175z9nMBZnQ6cnncgS7uDnt3BJ49Y8axqPYLZ0gVEb2DaICyHtOUM5t2eP7AJexWaGWYBVzcdsqneoAAViyzzo3ZsC1Jeq2qBKVhlkIxDsuSRrSY6/6S6eaaFjD+B4BGmMo9X9M06kcAdMq0qU5eT+lBBc8+GqaVmCc989iHP6yVvOcr4qE8ZLijVZ8VleC/5xWDWFmN6ow6aIKX75EfdL5rfKxBJgAcwwV/zeXrFjyqqo3uy52dnMa5oU4O7svo7YMNgWrFKdsk6WBXmmS82HuKsuADjHZFGi5iBIv+9qnn/qt+qSh3JTFNjPvWDiqpnA0SexYB/ijm6q5qP85wFnIZrXQHgillpVesHh9QVaAWWAJccfo/VNrOcbmrbYn/vCR9gy2m1aUH2WOa/rv4UoKnhPODowC2Gx6jQo4Nox4ZinDL392ssIHFSZWa1rTZJD/wSy0Kn34eDpwZvP1w96+dmH25zrsQs4KSLP4GAawWSjhnFZZQFmUZxOZSTj/ne2yUhIHCjRIlFKcIU0x852RjZTGGlDdaQrkxk7MPrJr/gzg17r4vgJ3rMAk4/wmQDE7wJhg+fFV1xaMGiMqnXaFc5jd4FjCCIRAEmAO5aPE7lzsw0ZelHYJB0PCWscErqOJcsrbllGmhmzE/7mAXcPof544Wlqg6wTuORtvKQzjV2gVC+shaNMhc24v8iIloGmS3ogc7bD9sS884Oi0kEP89jFnDX++/hCtPVtT7kwaxOkZpmxQ/L9vgdj1r+NCtAwQ6/A9DXMXnBqZgoHDdXP7Wna/Id6PRCum7DiREqcg1UPw9Yp6MsLv/HwlM4Hp7WQ1/CGQhcgDsDNJtcgLsAdyYCZza7MO4C3JkInNnswrgLcGcicGazC+POBO7/AH5zPa/ivytzAAAAAElFTkSuQmCC", + ), + ] + ) + ], + model_parameters={"temperature": 0.1, "num_predict": 100}, + stream=False, + ) + + assert isinstance(result, LLMResult) + assert len(result.message.content) > 0 + + +def test_invoke_chat_model_with_vision(): + model = OllamaLargeLanguageModel() + + result = model.invoke( + model="llava", + credentials={ + "base_url": os.environ.get("OLLAMA_BASE_URL"), + "mode": "chat", + "context_size": 2048, + "max_tokens": 2048, + }, + prompt_messages=[ + UserPromptMessage( + content=[ + TextPromptMessageContent( + data="What is this in this picture?", + ), + ImagePromptMessageContent( + mime_type="image/png", + format="png", + 
base64_data="iVBORw0KGgoAAAANSUhEUgAAAE4AAABMCAYAAADDYoEWAAAMQGlDQ1BJQ0MgUHJvZmlsZQAASImVVwdYU8kWnluSkEBoAQSkhN4EkRpASggt9I4gKiEJEEqMgaBiRxcVXLuIgA1dFVGwAmJBETuLYu+LBRVlXSzYlTcpoOu+8r35vrnz33/O/OfMmbllAFA7zhGJclF1APKEBeLYYH/6uOQUOukpIAEdoAy0gA2Hmy9iRkeHA1iG2r+Xd9cBIm2v2Eu1/tn/X4sGj5/PBQCJhjidl8/Ng/gAAHg1VyQuAIAo5c2mFoikGFagJYYBQrxIijPluFqK0+V4j8wmPpYFcTsASiocjjgTANVLkKcXcjOhhmo/xI5CnkAIgBodYp+8vMk8iNMgtoY2Ioil+oz0H3Qy/6aZPqzJ4WQOY/lcZEUpQJAvyuVM/z/T8b9LXq5kyIclrCpZ4pBY6Zxh3m7mTA6TYhWI+4TpkVEQa0L8QcCT2UOMUrIkIQlye9SAm8+COYMrDVBHHicgDGIDiIOEuZHhCj49QxDEhhjuEHSaoIAdD7EuxIv4+YFxCptN4smxCl9oY4aYxVTwZzlimV+pr/uSnASmQv91Fp+t0MdUi7LikyCmQGxeKEiMhFgVYof8nLgwhc3YoixW5JCNWBIrjd8c4li+MNhfro8VZoiDYhX2pXn5Q/PFNmUJ2JEKvK8gKz5Enh+sncuRxQ/ngl3iC5kJQzr8/HHhQ3Ph8QMC5XPHnvGFCXEKnQ+iAv9Y+VicIsqNVtjjpvzcYClvCrFLfmGcYiyeWAA3pFwfzxAVRMfL48SLsjmh0fJ48OUgHLBAAKADCazpYDLIBoLOvqY+eCfvCQIcIAaZgA/sFczQiCRZjxBe40AR+BMiPsgfHucv6+WDQsh/HWblV3uQIestlI3IAU8gzgNhIBfeS2SjhMPeEsFjyAj+4Z0DKxfGmwurtP/f80Psd4YJmXAFIxnySFcbsiQGEgOIIcQgog2uj/vgXng4vPrB6oQzcI+heXy3JzwhdBEeEq4Rugm3JgmKxT9FGQG6oX6QIhfpP+YCt4Sarrg/7g3VoTKug+sDe9wF+mHivtCzK2RZirilWaH/pP23GfywGgo7siMZJY8g+5Gtfx6paqvqOqwizfWP+ZHHmj6cb9Zwz8/+WT9knwfbsJ8tsUXYfuwMdgI7hx3BmgAda8WasQ7sqBQP767Hst015C1WFk8O1BH8w9/Qykozme9Y59jr+EXeV8CfJn1HA9Zk0XSxIDOrgM6EXwQ+nS3kOoyiOzk6OQMg/b7IX19vYmTfDUSn4zs3/w8AvFsHBwcPf+dCWwHY6w4f/0PfOWsG/HQoA3D2EFciLpRzuPRCgG8JNfik6QEjYAas4XycgBvwAn4gEISCKBAPksFEGH0W3OdiMBXMBPNACSgDy8EaUAk2gi1gB9gN9oEmcAScAKfBBXAJXAN34O7pAS9AP3gHPiMIQkKoCA3RQ4wRC8QOcUIYiA8SiIQjsUgykoZkIkJEgsxE5iNlyEqkEtmM1CJ7kUPICeQc0oXcQh4gvchr5BOKoSqoFmqIWqKjUQbKRMPQeHQCmolOQYvQBehStAKtQXehjegJ9AJ6De1GX6ADGMCUMR3MBLPHGBgLi8JSsAxMjM3GSrFyrAarx1rgOl/BurE+7CNOxGk4HbeHOzgET8C5+BR8Nr4Er8R34I14O34Ff4D3498IVIIBwY7gSWATxhEyCVMJJYRywjbCQcIp+Cz1EN4RiUQdohXRHT6LycRs4gziEuJ6YgPxOLGL+Ig4QCKR9Eh2JG9SFIlDKiCVkNaRdpFaSZdJPaQPSspKxkpOSkFKKUpCpWKlcqWdSseULis9VfpMVidbkD3JUWQeeTp5GXkruYV8kdxD/kzRoFhRvCnxlGzKPEoFpZ5yinKX8kZZWdlU2UM5RlmgPFe5QnmP8lnlB8ofVTRVbFVYKqkqEpWlKttVjqvcUnlDpVItqX7
UFGoBdSm1lnqSep/6QZWm6qDKVuWpzlGtUm1Uvaz6Uo2sZqHGVJuoVqRWrrZf7aJanzpZ3VKdpc5Rn61epX5I/Yb6gAZNY4xGlEaexhKNnRrnNJ5pkjQtNQM1eZoLNLdontR8RMNoZjQWjUubT9tKO0Xr0SJqWWmxtbK1yrR2a3Vq9WtrartoJ2pP067SPqrdrYPpWOqwdXJ1luns07mu82mE4QjmCP6IxSPqR1we8V53pK6fLl+3VLdB95ruJz26XqBejt4KvSa9e/q4vq1+jP5U/Q36p/T7RmqN9BrJHVk6ct/I2waoga1BrMEMgy0GHQYDhkaGwYYiw3WGJw37jHSM/IyyjVYbHTPqNaYZ+xgLjFcbtxo/p2vTmfRcegW9nd5vYmASYiIx2WzSafLZ1Mo0wbTYtMH0nhnFjGGWYbbarM2s39zYPMJ8pnmd+W0LsgXDIstircUZi/eWVpZJlgstmyyfWelasa2KrOqs7lpTrX2tp1jXWF+1IdowbHJs1ttcskVtXW2zbKtsL9qhdm52Arv1dl2jCKM8RglH1Yy6Ya9iz7QvtK+zf+Cg4xDuUOzQ5PBytPnolNErRp8Z/c3R1THXcavjnTGaY0LHFI9pGfPaydaJ61TldNWZ6hzkPMe52fmVi50L32WDy01XmmuE60LXNtevbu5uYrd6t153c/c092r3GwwtRjRjCeOsB8HD32OOxxGPj55ungWe+zz/8rL3yvHa6fVsrNVY/titYx95m3pzvDd7d/vQfdJ8Nvl0+5r4cnxrfB/6mfnx/Lb5PWXaMLOZu5gv/R39xf4H/d+zPFmzWMcDsIDggNKAzkDNwITAysD7QaZBmUF1Qf3BrsEzgo+HEELCQlaE3GAbsrnsWnZ/qHvorND2MJWwuLDKsIfhtuHi8JYINCI0YlXE3UiLSGFkUxSIYketiroXbRU9JfpwDDEmOqYq5knsmNiZsWfiaHGT4nbGvYv3j18WfyfBOkGS0JaolpiaWJv4PikgaWVS97jR42aNu5CsnyxIbk4hpSSmbEsZGB84fs34nlTX1JLU6xOsJkybcG6i/sTciUcnqU3iTNqfRkhLStuZ9oUTxanhDKSz06vT+7ks7lruC54fbzWvl+/NX8l/muGdsTLjWaZ35qrM3izfrPKsPgFLUCl4lR2SvTH7fU5Uzvacwdyk3IY8pby0vENCTWGOsH2y0eRpk7tEdqISUfcUzylrpvSLw8Tb8pH8CfnNBVrwR75DYi35RfKg0KewqvDD1MSp+6dpTBNO65huO33x9KdFQUW/zcBncGe0zTSZOW/mg1nMWZtnI7PTZ7fNMZuzYE7P3OC5O+ZR5uXM+73YsXhl8dv5SfNbFhgumLvg0S/Bv9SVqJaIS24s9Fq4cRG+SLCoc7Hz4nWLv5XySs+XOZaVl31Zwl1y/tcxv1b8Org0Y2nnMrdlG5YTlwuXX1/hu2LHSo2VRSsfrYpY1biavrp09ds1k9acK3cp37iWslaytrsivKJ5nfm65eu+VGZVXqvyr2qoNqheXP1+PW/95Q1+G+o3Gm4s2/hpk2DTzc3BmxtrLGvKtxC3FG55sjVx65nfGL/VbtPfVrbt63bh9u4dsTvaa91ra3ca7FxWh9ZJ6np3pe66tDtgd3O9ff3mBp2Gsj1gj2TP871pe6/vC9vXtp+xv/6AxYHqg7SDpY1I4/TG/qaspu7m5OauQ6GH2lq8Wg4edji8/YjJkaqj2keXHaMcW3BssLWodeC46HjficwTj9omtd05Oe7k1faY9s5TYafOng46ffIM80zrWe+zR855njt0nnG+6YLbhcYO146Dv7v+frDTrbPxovvF5ksel1q6xnYdu+x7+cSVgCunr7KvXrgWea3resL1mzdSb3Tf5N18div31qvbhbc/35l7l3C39J76vfL7Bvdr/rD5o6Hbrfvog4AHHQ/jHt55xH304nH+4y89C55Qn5Q/NX5a+8zp2ZHeoN5Lz8c/73khevG5r+R
PjT+rX1q/PPCX318d/eP6e16JXw2+XvJG7832ty5v2waiB+6/y3v3+X3pB70POz4yPp75lPTp6eepX0hfKr7afG35Fvbt7mDe4KCII+bIfgUwWNGMDABebweAmgwADZ7PKOPl5z9ZQeRnVhkC/wnLz4iy4gZAPfx/j+mDfzc3ANizFR6/oL5aKgDRVADiPQDq7Dxch85qsnOltBDhOWBT5Nf0vHTwb4r8zPlD3D+3QKrqAn5u/wWdZ3xtG7qP3QAAADhlWElmTU0AKgAAAAgAAYdpAAQAAAABAAAAGgAAAAAAAqACAAQAAAABAAAATqADAAQAAAABAAAATAAAAADhTXUdAAARnUlEQVR4Ae2c245bR3aGi4fulizFHgUzQAYIggBB5klymfeaZ8hDBYjvAiRxkMAGkowRWx7JktjcZL7vX1Uku62Burkl5YbV5q7Tqqq1/v3XqgMpL95tbvftEh6NwPLRLS4NgsAFuDOJcAHuAtyZCJzZ7MK4C3BnInBmswvjLsCdicCZzS6MOxO49Znt0uz3//CPbbv6srXFrq0W9Q6Wi0VbLPn4R8x/jSLiu3nrl8s9dcartlwtKdmTbm21XranN6v27Mm6XV8t25fP1+3Pn1+1r4if3Czbk+t9u1rR6f9jmAXc1P6sbaevQGbfdgGJeA8ke0AQsCYYgiYgPR1QyVO+3wvcMm2WO0G2PeWkX79btp839AG4//UjYC62gDsB2rI9f7pov3q2bX/9F1ftBWAufTufOcwCrnTtR90dOdHoNgCJeAbUkuM5TsWAW5W9gfkE83ZkUHg0oAyAwbm927a2ebVoP/xx2f7jD1uYuG9/89tF+/VXK1hq+88TZgG32O1g2r7tpRdBM8fUTM7pyR8SYddgxkJErUszHti7U44CpzyEo16syNtx+qgy+1og7RMetpev9+3rb3bt+c2u/ebFsv3uL1ftiqn+qcMs4HY7jNQpEfadNU5VqeHUTJkgUbaPDxRADdZ8jU9LHoJYnwLUtgWN4ObDC7Kdr8Hp7d9qMTW8gt23V1zyvPrD1H56e9t+99vr9uJLprBDfaIw69U4dQRCIw2JdVIjbUzecj+7qYyPpZHiAbDaJwsXyMhQEQ0pq6sAp7hMS2XGqykdA2iy4EUtF6v206ur9k/fbNo//+frtt2OaW/rjxtmAaeNGqihBY5xfVQzQEZfoSH0KHgkrbD/CX6vPIqlSTU61vVCovRSbEwbIS851vj23Q+tff3vu/bzu5I7tvs4qVnADTa5FCbNC86qCLN2E1MxKKroYB2pgSz2RLbbVcVkSJhOKxIDjGxn+nSuqes2JlKuG8fA/IzPXazbj68X7et/27UfX7GifORwOuSju47h/c3beKfRFO74CNA04YP0ZT2/YzERFGojc9pmDG47/wyDZwJjiX4wwJNer1dZPJbs5/xzK5Ppzp7SQZBszNy22U7tX7/dtFdvJrv8aGE2cDJLoPycBgHSgICJUQLo8nmUo6y7oH0S5Lu/FGhDQULCfIooATw3yyOQQ46eYVpYiaBMTFtAFPR307r9y3fbdvsRfd5Rg6HJI2Lt1qaAF6TEqoxWdVdYSHawezCvAHLjW7Jh2QGcUkDDT4Og2OfSFRVkxipcAJUZARC5FVRbeRpB1hVY6r25XQHexIZ96Hfa++PTs4Dbi8rQg7imWQG27/uEgCTCssk/WWg7GwJWwDQ36PceGzQ+x7jOtgNogkIIpsZiFMdXoEfOPUlh3l5ulu2/X6bJ7Mc84Bw+xgOKzJqM0VKm8WYlVMqt61gFKNtQKeZ6o7Ls/aqEeYooJXDIZ9uiT0uZ5UxPUJNlYdoAK62qHfM7unz3/bb9/Ha+v3u/tn3AD0XOrnxAZdpNYZILgoxyGk4BqMCbssq66dXv6RdFkiB6Rj2u3N1npiMw1dQjF4oJW/kzy6VdMRFA9Xd8VvhCLxCyYUYkvhHZb7+fotvdUR6XmwXcYI1DangAA6yspgBj/dRjp6L+Rbm
SPaaxuuMnGEeVAhBF4pSapAFG5gUo60rAHmpVtcz0sR2aBZW8NAB9+W7dXr9N0dmPmUcu10pWrq7kQQvBQXn1dUsgoM4ej12TtyBknG51PEMGOV2TLLVZ/GLvLMBYHsYJhg7fuMBx6tq3LFu7aBxxD9jKFiO7Thbwcv7n5dS+/ML0eWEWcBqoptk+mEQp2aTG+rbmBYA+D6MyMwMAdepKsX5QpnglFZyZ5k4tDYsI/Y1pF7CRq22HoHXgGEOwgodvgH79INnW3tlFIVVQvkBXg1dvF3z27fkTGzw+zALOPZluVoVkV4yLHoBB3VBJUNyo6uEWXAyIkruC2OQjbVeppxkm8+iti2mySsM1EPYGKBcEyul3LKTW1+pr+wLRstwP0J8a2K95Txf/+6q1ZzeUDEXt/oFhHnA4fJYCBtawYlWmlsrJBEHhP43bi9Rq1Z0ymlK3Z/QCRqA5YfaNLZJWEACn929eluXlUGO8CgMrHWYi441S2tsFebLRL5RWL0e0nL64SEEf2sjMR4ZZwA0Ddfziclz1eN8yDn1qAaHSq3G0FEQXjABDo51sJVNyGnA0QlAPL4LOApzMo0mY1sUFbQBj8xTzYhKrROYF5VGIftR1uW3+3uiWU8XnBw7l3HIYVG/P/djYgMZoyrTJrci0n2qPZVnNFV913viW6btGzsXBT6aW3VKmsauVTFOc2DxpP5YJYLBBeCUixE71IlGBR2EF+6OugHbP12Ddoj29HgIPj+cxDiPDFGINzB8sKhLh0Ui4gOgDI8deb8FiwYxlteWhLHWTlmOzhkxLAObPIkFqS8+bbG5BdgWiAmJTwXdqZ7oysktzdKC/BWMWiAJNpyP0ZPTMItRy7fTi2RB4eDwLuIkpCma1gob/Dsw7zcKAMf3txiCot8c42ZCDPu3WAqRMJAGEk4cACaLzSZsFRhAE9QoAtXcwTX92XDT0sxTQXJYHdDJin0KfVN8PmzNvnOYBx5XNlik4giumihb7tJ60ezgNhgXuXgRNttxunZYAj7uzbL3nUA67rm5KJWrJCyTfIVwBMh3bTkD8TqFYp6uv8RwrgJpAZmHHScqv0qWeKT48NujhAuELekyYBdz9gXJQ53DvDh3tU62xTtN8bQhzzE9OccAK8wA2ez2k3cNtN7wM/RZs9M5NkNZoee0H2rmhLr8miPV9roAZtN1RHV/gDb7EoUtXKeXjYXUBN0oeFs8CbrtlhZRGPZSSZNyI9gA+TBFkelFNWxgEgCtG3wDiFqEr5Jz6y/U1DAM4QLxi2l7DNhl3w/epNTUFWGbXC7HrMQMz7WUbf8AaDQ46DYXuxLoJX6CFRzvuiPyJzCzgZIoKyqgKAx1yAGPQUWfa+GoDsqwDJNnHLF9juSz0i5VrpvqSwmsQul5dtyfrfX1zL3i0WdHHSjaKVjf0T5k7ABtxlEHbwxusgjydAY8N84BjvAx5GLfMqBW0VJEZ+pwKskQnbpnFHPzpwWo/bzkGvX51296+bu1v/+qL9usXT9rTJ07Bzh9k9HEPsxNhwhh6xLXKo3fXWf3iMkrBBz9nAbflbHm6ONxhXp8/NW26lkSleIEV9FBVI+o6ihjmffPDt+3v/+5Z+82vnsZw/fyercweB2d7wzA8mfuPEknpXTnHvQsoPd1v/aD8LODw+AxbAw/QjnEfv69u5kz6dtOiW2R6YmW7vd0C3qK94wcjf/zxZ1bRXfvqGT6U3f2G/Z6AesqotgJX477PNVmTmxfiwTSS5irqz2ybEHD6PzbMAk7lS/0BxgkTqPAUYBiAkQpTLLdKxe1D4Lbsp968uW1vXk+ZrnpsN7yL1TbmbvCl4GcPPPStZWyNcM9s++9y92ruZu2CT21q7lZ9KDcLuC3WbmGG42uA30EISOVkFynt1BBialOliF/wZHqGTa1tOfq8fbMHPL6N2iBPW2d7HfxZdWnreiN49UL0dfhLR6tBSVVwNo+TQ1U5IsHvQU4Dcry7bGNOix+SngVcwAh
YpZjTQxaNMABLLLtUFEAMEwi4kk63fGDbLTcVm82ubd7hNylzEXCa6SPdz2Vf5iUobe0jAFIq8+JHT8CjGeUjHFOj5E7MIO4THxvOaHIcwu2IOKiznyg89BTEXi6WssO8B36vkLa33Pv7/QRbEtm21c/BtIm9Yb4ho19PDg4g09aeucySdpzq3BfVx6WQqh7MkLOSkHLf2olEKni4n7xznh0VH4jnAYdy6hfVSZTvUmF54f2cU9d9XmlhvUyTlbkxIT0BWtgH4wRRgPMy7EFbAwi8ojzbNyqtH/7coWxnUHyE+rmYjbs3NCnqdwIbbM/GZ4RZwDleVskO3viSBhWjSu2Pxj7JU4bsqrzTU5YZQ7xKu73Bb8bAbo+s28NStxEyb8e+K1UAKXhOVivK7x0RUANf3zEw/smJpsr37cad9RlhFnCbzQYwfN36I+5qwxgVwRA/vOHxlneeMiaux9lymN5tTTttkZN5mbZwCYsLM550taA+zJM5gsdHsGSdQTbngN7ZlC/JrRhXIcorRJvVcp2pnjzdy+0nnErOCbOAE5x8d4oVCy4xMSFGetjfgWJ3MQFHdomxZbUwwC4B84YlzBNojUEmxmqO1tVC4VcVopUzKuXK+XArUeDVTyq85wv7xKqHsel1dfIUkl8zUXcFm8eUH7IPjWcBp8J5mYxWcWmbclhlyEIAMJm2HbSwDCHZGD9IuR1UH4MhaZ4HOAIQIJOrIxfjxOFRUMNQq8wI9EH5WNVJdcEje22ofxs3K6PlQ+OZwA2ghrFSKhiEVSqh/5JJcfodKBnntLac7wb5CKLpAs+0RguYuAhoNh2CRV1dTVFhqWhRn/u+tOsMtTph6JhOkAWsQDz1K3NHeHyYBZyK70BG5oy3SyqGumoaAhr1Aiggnm8FzXr3cQWSq++p8seM10v6LW9Elgh5kyGINXMdi1xspw2LRHwqMjJTV2KdU9c2eQ1SkXDDHL2aYf2MprVp1dFrtcBlAWB/sNuxMoJIzEfRqhMk04qXfM0n8yVDaa/DRLp1GuGSKhNz65ZEOQUSdyD0Y/adRSojsxjoz2jnNFdN3l/S+sUvnqbDsx+zgCvQMJzhPaCrlouCLBvbA43x68DhsAc7DxpTr0y39VAMBCfpSlpSUMggzRe8X4bIAWRYJqVJj6t7feMV/9Bkfeb+bYw2Czg78S3GwWtEQEPRWFMMEDAZhVTiMaWLnZZRxSexfaStPR9DAXbMj5Qs479Dm8PqqYCNEpUTVAe/GpLC3vH16hI64zkLuB1XQVsdFkED8ps40oLjj2sMAdbFwGlKRjbW6UHAFZaRJVegIpeWVafZhQ4yHahUm+5VyfOwXYFHTX8DKUNSn+fCcsN3qOd8AT3GGPEs4EYnxho9YlOnU1WTUj98GbLKWCawI5wk71DiBMoh+qjYfgXUc+nNlW+rXuqjOrknPAs4sRoHcvvNguDZNEChYOoBUUZ175z9nMBZnQ6cnncgS7uDnt3BJ49Y8axqPYLZ0gVEb2DaICyHtOUM5t2eP7AJexWaGWYBVzcdsqneoAAViyzzo3ZsC1Jeq2qBKVhlkIxDsuSRrSY6/6S6eaaFjD+B4BGmMo9X9M06kcAdMq0qU5eT+lBBc8+GqaVmCc989iHP6yVvOcr4qE8ZLijVZ8VleC/5xWDWFmN6ow6aIKX75EfdL5rfKxBJgAcwwV/zeXrFjyqqo3uy52dnMa5oU4O7svo7YMNgWrFKdsk6WBXmmS82HuKsuADjHZFGi5iBIv+9qnn/qt+qSh3JTFNjPvWDiqpnA0SexYB/ijm6q5qP85wFnIZrXQHgillpVesHh9QVaAWWAJccfo/VNrOcbmrbYn/vCR9gy2m1aUH2WOa/rv4UoKnhPODowC2Gx6jQo4Nox4ZinDL392ssIHFSZWa1rTZJD/wSy0Kn34eDpwZvP1w96+dmH25zrsQs4KSLP4GAawWSjhnFZZQFmUZxOZSTj/ne2yUhIHCjRIlFKcIU0x8
52RjZTGGlDdaQrkxk7MPrJr/gzg17r4vgJ3rMAk4/wmQDE7wJhg+fFV1xaMGiMqnXaFc5jd4FjCCIRAEmAO5aPE7lzsw0ZelHYJB0PCWscErqOJcsrbllGmhmzE/7mAXcPof544Wlqg6wTuORtvKQzjV2gVC+shaNMhc24v8iIloGmS3ogc7bD9sS884Oi0kEP89jFnDX++/hCtPVtT7kwaxOkZpmxQ/L9vgdj1r+NCtAwQ6/A9DXMXnBqZgoHDdXP7Wna/Id6PRCum7DiREqcg1UPw9Yp6MsLv/HwlM4Hp7WQ1/CGQhcgDsDNJtcgLsAdyYCZza7MO4C3JkInNnswrgLcGcicGazC+POBO7/AH5zPa/ivytzAAAAAElFTkSuQmCC", + ), + ] + ) + ], + model_parameters={"temperature": 0.1, "num_predict": 100}, + stream=False, + ) + + assert isinstance(result, LLMResult) + assert len(result.message.content) > 0 + + +def test_get_num_tokens(): + model = OllamaLargeLanguageModel() + + num_tokens = model.get_num_tokens( + model="mistral:text", + credentials={ + "base_url": os.environ.get("OLLAMA_BASE_URL"), + "mode": "chat", + "context_size": 2048, + "max_tokens": 2048, + }, + prompt_messages=[UserPromptMessage(content="Hello World!")], + ) + + assert isinstance(num_tokens, int) + assert num_tokens == 6 diff --git a/api/tests/integration_tests/model_runtime/ollama/test_text_embedding.py b/api/tests/integration_tests/model_runtime/ollama/test_text_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..3c4f740a4fd09cad235dbde55b350ab4ac301b27 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/ollama/test_text_embedding.py @@ -0,0 +1,65 @@ +import os + +import pytest + +from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.ollama.text_embedding.text_embedding import OllamaEmbeddingModel + + +def test_validate_credentials(): + model = OllamaEmbeddingModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model="mistral:text", + credentials={ + "base_url": "http://localhost:21434", + "mode": "chat", + "context_size": 4096, + }, + ) + + model.validate_credentials( + model="mistral:text", + 
credentials={ + "base_url": os.environ.get("OLLAMA_BASE_URL"), + "mode": "chat", + "context_size": 4096, + }, + ) + + +def test_invoke_model(): + model = OllamaEmbeddingModel() + + result = model.invoke( + model="mistral:text", + credentials={ + "base_url": os.environ.get("OLLAMA_BASE_URL"), + "mode": "chat", + "context_size": 4096, + }, + texts=["hello", "world"], + user="abc-123", + ) + + assert isinstance(result, TextEmbeddingResult) + assert len(result.embeddings) == 2 + assert result.usage.total_tokens == 2 + + +def test_get_num_tokens(): + model = OllamaEmbeddingModel() + + num_tokens = model.get_num_tokens( + model="mistral:text", + credentials={ + "base_url": os.environ.get("OLLAMA_BASE_URL"), + "mode": "chat", + "context_size": 4096, + }, + texts=["hello", "world"], + ) + + assert num_tokens == 2 diff --git a/api/tests/integration_tests/model_runtime/openai/__init__.py b/api/tests/integration_tests/model_runtime/openai/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/tests/integration_tests/model_runtime/openai/test_llm.py b/api/tests/integration_tests/model_runtime/openai/test_llm.py new file mode 100644 index 0000000000000000000000000000000000000000..9e83b9d434359d0cf3901c794f84d857b17a3766 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/openai/test_llm.py @@ -0,0 +1,314 @@ +import os +from collections.abc import Generator + +import pytest + +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + ImagePromptMessageContent, + PromptMessageTool, + SystemPromptMessage, + TextPromptMessageContent, + UserPromptMessage, +) +from core.model_runtime.entities.model_entities import AIModelEntity, ModelType +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from 
core.model_runtime.model_providers.openai.llm.llm import OpenAILargeLanguageModel + +"""FOR MOCK FIXTURES, DO NOT REMOVE""" +from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock + + +def test_predefined_models(): + model = OpenAILargeLanguageModel() + model_schemas = model.predefined_models() + + assert len(model_schemas) >= 1 + assert isinstance(model_schemas[0], AIModelEntity) + + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) +def test_validate_credentials_for_chat_model(setup_openai_mock): + model = OpenAILargeLanguageModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials(model="gpt-3.5-turbo", credentials={"openai_api_key": "invalid_key"}) + + model.validate_credentials(model="gpt-3.5-turbo", credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}) + + +@pytest.mark.parametrize("setup_openai_mock", [["completion"]], indirect=True) +def test_validate_credentials_for_completion_model(setup_openai_mock): + model = OpenAILargeLanguageModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials(model="text-davinci-003", credentials={"openai_api_key": "invalid_key"}) + + model.validate_credentials( + model="text-davinci-003", credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")} + ) + + +@pytest.mark.parametrize("setup_openai_mock", [["completion"]], indirect=True) +def test_invoke_completion_model(setup_openai_mock): + model = OpenAILargeLanguageModel() + + result = model.invoke( + model="gpt-3.5-turbo-instruct", + credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY"), "openai_api_base": "https://api.openai.com"}, + prompt_messages=[UserPromptMessage(content="Hello World!")], + model_parameters={"temperature": 0.0, "max_tokens": 1}, + stream=False, + user="abc-123", + ) + + assert isinstance(result, LLMResult) + assert len(result.message.content) > 0 + assert 
model._num_tokens_from_string("gpt-3.5-turbo-instruct", result.message.content) == 1 + + +@pytest.mark.parametrize("setup_openai_mock", [["completion"]], indirect=True) +def test_invoke_stream_completion_model(setup_openai_mock): + model = OpenAILargeLanguageModel() + + result = model.invoke( + model="gpt-3.5-turbo-instruct", + credentials={ + "openai_api_key": os.environ.get("OPENAI_API_KEY"), + "openai_organization": os.environ.get("OPENAI_ORGANIZATION"), + }, + prompt_messages=[UserPromptMessage(content="Hello World!")], + model_parameters={"temperature": 0.0, "max_tokens": 100}, + stream=True, + user="abc-123", + ) + + assert isinstance(result, Generator) + + for chunk in result: + assert isinstance(chunk, LLMResultChunk) + assert isinstance(chunk.delta, LLMResultChunkDelta) + assert isinstance(chunk.delta.message, AssistantPromptMessage) + assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True + + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) +def test_invoke_chat_model(setup_openai_mock): + model = OpenAILargeLanguageModel() + + result = model.invoke( + model="gpt-3.5-turbo", + credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Hello World!"), + ], + model_parameters={ + "temperature": 0.0, + "top_p": 1.0, + "presence_penalty": 0.0, + "frequency_penalty": 0.0, + "max_tokens": 10, + }, + stop=["How"], + stream=False, + user="abc-123", + ) + + assert isinstance(result, LLMResult) + assert len(result.message.content) > 0 + + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) +def test_invoke_chat_model_with_vision(setup_openai_mock): + model = OpenAILargeLanguageModel() + + result = model.invoke( + model="gpt-4-vision-preview", + credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}, + prompt_messages=[ + SystemPromptMessage( + 
content="You are a helpful AI assistant.", + ), + UserPromptMessage( + content=[ + TextPromptMessageContent( + data="Hello World!", + ), + ImagePromptMessageContent( + mime_type="image/png", + format="png", + base64_data="iVBORw0KGgoAAAANSUhEUgAAAE4AAABMCAYAAADDYoEWAAAMQGlDQ1BJQ0MgUHJvZmlsZQAASImVVwdYU8kWnluSkEBoAQSkhN4EkRpASggt9I4gKiEJEEqMgaBiRxcVXLuIgA1dFVGwAmJBETuLYu+LBRVlXSzYlTcpoOu+8r35vrnz33/O/OfMmbllAFA7zhGJclF1APKEBeLYYH/6uOQUOukpIAEdoAy0gA2Hmy9iRkeHA1iG2r+Xd9cBIm2v2Eu1/tn/X4sGj5/PBQCJhjidl8/Ng/gAAHg1VyQuAIAo5c2mFoikGFagJYYBQrxIijPluFqK0+V4j8wmPpYFcTsASiocjjgTANVLkKcXcjOhhmo/xI5CnkAIgBodYp+8vMk8iNMgtoY2Ioil+oz0H3Qy/6aZPqzJ4WQOY/lcZEUpQJAvyuVM/z/T8b9LXq5kyIclrCpZ4pBY6Zxh3m7mTA6TYhWI+4TpkVEQa0L8QcCT2UOMUrIkIQlye9SAm8+COYMrDVBHHicgDGIDiIOEuZHhCj49QxDEhhjuEHSaoIAdD7EuxIv4+YFxCptN4smxCl9oY4aYxVTwZzlimV+pr/uSnASmQv91Fp+t0MdUi7LikyCmQGxeKEiMhFgVYof8nLgwhc3YoixW5JCNWBIrjd8c4li+MNhfro8VZoiDYhX2pXn5Q/PFNmUJ2JEKvK8gKz5Enh+sncuRxQ/ngl3iC5kJQzr8/HHhQ3Ph8QMC5XPHnvGFCXEKnQ+iAv9Y+VicIsqNVtjjpvzcYClvCrFLfmGcYiyeWAA3pFwfzxAVRMfL48SLsjmh0fJ48OUgHLBAAKADCazpYDLIBoLOvqY+eCfvCQIcIAaZgA/sFczQiCRZjxBe40AR+BMiPsgfHucv6+WDQsh/HWblV3uQIestlI3IAU8gzgNhIBfeS2SjhMPeEsFjyAj+4Z0DKxfGmwurtP/f80Psd4YJmXAFIxnySFcbsiQGEgOIIcQgog2uj/vgXng4vPrB6oQzcI+heXy3JzwhdBEeEq4Rugm3JgmKxT9FGQG6oX6QIhfpP+YCt4Sarrg/7g3VoTKug+sDe9wF+mHivtCzK2RZirilWaH/pP23GfywGgo7siMZJY8g+5Gtfx6paqvqOqwizfWP+ZHHmj6cb9Zwz8/+WT9knwfbsJ8tsUXYfuwMdgI7hx3BmgAda8WasQ7sqBQP767Hst015C1WFk8O1BH8w9/Qykozme9Y59jr+EXeV8CfJn1HA9Zk0XSxIDOrgM6EXwQ+nS3kOoyiOzk6OQMg/b7IX19vYmTfDUSn4zs3/w8AvFsHBwcPf+dCWwHY6w4f/0PfOWsG/HQoA3D2EFciLpRzuPRCgG8JNfik6QEjYAas4XycgBvwAn4gEISCKBAPksFEGH0W3OdiMBXMBPNACSgDy8EaUAk2gi1gB9gN9oEmcAScAKfBBXAJXAN34O7pAS9AP3gHPiMIQkKoCA3RQ4wRC8QOcUIYiA8SiIQjsUgykoZkIkJEgsxE5iNlyEqkEtmM1CJ7kUPICeQc0oXcQh4gvchr5BOKoSqoFmqIWqKjUQbKRMPQeHQCmolOQYvQBehStAKtQXehjegJ9AJ6De1GX6ADGMCUMR3MBLPHGBgLi8JSsAxMjM3GSrFyrAarx1rgOl/BurE+7CNOxGk4HbeHOzgET8C5+BR8Nr4Er8R34I14O34Ff4D3498IVIIBwY7gSWATxhEyCVMJJYRywjbCQcIp+Cz1EN4RiUQdohXRHT6LycRs4gziEu
J6YgPxOLGL+Ig4QCKR9Eh2JG9SFIlDKiCVkNaRdpFaSZdJPaQPSspKxkpOSkFKKUpCpWKlcqWdSseULis9VfpMVidbkD3JUWQeeTp5GXkruYV8kdxD/kzRoFhRvCnxlGzKPEoFpZ5yinKX8kZZWdlU2UM5RlmgPFe5QnmP8lnlB8ofVTRVbFVYKqkqEpWlKttVjqvcUnlDpVItqX7UFGoBdSm1lnqSep/6QZWm6qDKVuWpzlGtUm1Uvaz6Uo2sZqHGVJuoVqRWrrZf7aJanzpZ3VKdpc5Rn61epX5I/Yb6gAZNY4xGlEaexhKNnRrnNJ5pkjQtNQM1eZoLNLdontR8RMNoZjQWjUubT9tKO0Xr0SJqWWmxtbK1yrR2a3Vq9WtrartoJ2pP067SPqrdrYPpWOqwdXJ1luns07mu82mE4QjmCP6IxSPqR1we8V53pK6fLl+3VLdB95ruJz26XqBejt4KvSa9e/q4vq1+jP5U/Q36p/T7RmqN9BrJHVk6ct/I2waoga1BrMEMgy0GHQYDhkaGwYYiw3WGJw37jHSM/IyyjVYbHTPqNaYZ+xgLjFcbtxo/p2vTmfRcegW9nd5vYmASYiIx2WzSafLZ1Mo0wbTYtMH0nhnFjGGWYbbarM2s39zYPMJ8pnmd+W0LsgXDIstircUZi/eWVpZJlgstmyyfWelasa2KrOqs7lpTrX2tp1jXWF+1IdowbHJs1ttcskVtXW2zbKtsL9qhdm52Arv1dl2jCKM8RglH1Yy6Ya9iz7QvtK+zf+Cg4xDuUOzQ5PBytPnolNErRp8Z/c3R1THXcavjnTGaY0LHFI9pGfPaydaJ61TldNWZ6hzkPMe52fmVi50L32WDy01XmmuE60LXNtevbu5uYrd6t153c/c092r3GwwtRjRjCeOsB8HD32OOxxGPj55ungWe+zz/8rL3yvHa6fVsrNVY/titYx95m3pzvDd7d/vQfdJ8Nvl0+5r4cnxrfB/6mfnx/Lb5PWXaMLOZu5gv/R39xf4H/d+zPFmzWMcDsIDggNKAzkDNwITAysD7QaZBmUF1Qf3BrsEzgo+HEELCQlaE3GAbsrnsWnZ/qHvorND2MJWwuLDKsIfhtuHi8JYINCI0YlXE3UiLSGFkUxSIYketiroXbRU9JfpwDDEmOqYq5knsmNiZsWfiaHGT4nbGvYv3j18WfyfBOkGS0JaolpiaWJv4PikgaWVS97jR42aNu5CsnyxIbk4hpSSmbEsZGB84fs34nlTX1JLU6xOsJkybcG6i/sTciUcnqU3iTNqfRkhLStuZ9oUTxanhDKSz06vT+7ks7lruC54fbzWvl+/NX8l/muGdsTLjWaZ35qrM3izfrPKsPgFLUCl4lR2SvTH7fU5Uzvacwdyk3IY8pby0vENCTWGOsH2y0eRpk7tEdqISUfcUzylrpvSLw8Tb8pH8CfnNBVrwR75DYi35RfKg0KewqvDD1MSp+6dpTBNO65huO33x9KdFQUW/zcBncGe0zTSZOW/mg1nMWZtnI7PTZ7fNMZuzYE7P3OC5O+ZR5uXM+73YsXhl8dv5SfNbFhgumLvg0S/Bv9SVqJaIS24s9Fq4cRG+SLCoc7Hz4nWLv5XySs+XOZaVl31Zwl1y/tcxv1b8Org0Y2nnMrdlG5YTlwuXX1/hu2LHSo2VRSsfrYpY1biavrp09ds1k9acK3cp37iWslaytrsivKJ5nfm65eu+VGZVXqvyr2qoNqheXP1+PW/95Q1+G+o3Gm4s2/hpk2DTzc3BmxtrLGvKtxC3FG55sjVx65nfGL/VbtPfVrbt63bh9u4dsTvaa91ra3ca7FxWh9ZJ6np3pe66tDtgd3O9ff3mBp2Gsj1gj2TP871pe6/vC9vXtp+xv/6AxYHqg7SDpY1I4/TG/qaspu7m5OauQ6GH2lq8Wg4edji8/YjJkaqj2keXHaMcW3BssLWodeC46HjficwTj9omtd05Oe7k1faY9s5TYafOng46ff
IM80zrWe+zR855njt0nnG+6YLbhcYO146Dv7v+frDTrbPxovvF5ksel1q6xnYdu+x7+cSVgCunr7KvXrgWea3resL1mzdSb3Tf5N18div31qvbhbc/35l7l3C39J76vfL7Bvdr/rD5o6Hbrfvog4AHHQ/jHt55xH304nH+4y89C55Qn5Q/NX5a+8zp2ZHeoN5Lz8c/73khevG5r+RPjT+rX1q/PPCX318d/eP6e16JXw2+XvJG7832ty5v2waiB+6/y3v3+X3pB70POz4yPp75lPTp6eepX0hfKr7afG35Fvbt7mDe4KCII+bIfgUwWNGMDABebweAmgwADZ7PKOPl5z9ZQeRnVhkC/wnLz4iy4gZAPfx/j+mDfzc3ANizFR6/oL5aKgDRVADiPQDq7Dxch85qsnOltBDhOWBT5Nf0vHTwb4r8zPlD3D+3QKrqAn5u/wWdZ3xtG7qP3QAAADhlWElmTU0AKgAAAAgAAYdpAAQAAAABAAAAGgAAAAAAAqACAAQAAAABAAAATqADAAQAAAABAAAATAAAAADhTXUdAAARnUlEQVR4Ae2c245bR3aGi4fulizFHgUzQAYIggBB5klymfeaZ8hDBYjvAiRxkMAGkowRWx7JktjcZL7vX1Uku62Burkl5YbV5q7Tqqq1/v3XqgMpL95tbvftEh6NwPLRLS4NgsAFuDOJcAHuAtyZCJzZ7MK4C3BnInBmswvjLsCdicCZzS6MOxO49Znt0uz3//CPbbv6srXFrq0W9Q6Wi0VbLPn4R8x/jSLiu3nrl8s9dcartlwtKdmTbm21XranN6v27Mm6XV8t25fP1+3Pn1+1r4if3Czbk+t9u1rR6f9jmAXc1P6sbaevQGbfdgGJeA8ke0AQsCYYgiYgPR1QyVO+3wvcMm2WO0G2PeWkX79btp839AG4//UjYC62gDsB2rI9f7pov3q2bX/9F1ftBWAufTufOcwCrnTtR90dOdHoNgCJeAbUkuM5TsWAW5W9gfkE83ZkUHg0oAyAwbm927a2ebVoP/xx2f7jD1uYuG9/89tF+/VXK1hq+88TZgG32O1g2r7tpRdBM8fUTM7pyR8SYddgxkJErUszHti7U44CpzyEo16syNtx+qgy+1og7RMetpev9+3rb3bt+c2u/ebFsv3uL1ftiqn+qcMs4HY7jNQpEfadNU5VqeHUTJkgUbaPDxRADdZ8jU9LHoJYnwLUtgWN4ObDC7Kdr8Hp7d9qMTW8gt23V1zyvPrD1H56e9t+99vr9uJLprBDfaIw69U4dQRCIw2JdVIjbUzecj+7qYyPpZHiAbDaJwsXyMhQEQ0pq6sAp7hMS2XGqykdA2iy4EUtF6v206ur9k/fbNo//+frtt2OaW/rjxtmAaeNGqihBY5xfVQzQEZfoSH0KHgkrbD/CX6vPIqlSTU61vVCovRSbEwbIS851vj23Q+tff3vu/bzu5I7tvs4qVnADTa5FCbNC86qCLN2E1MxKKroYB2pgSz2RLbbVcVkSJhOKxIDjGxn+nSuqes2JlKuG8fA/IzPXazbj68X7et/27UfX7GifORwOuSju47h/c3beKfRFO74CNA04YP0ZT2/YzERFGojc9pmDG47/wyDZwJjiX4wwJNer1dZPJbs5/xzK5Ppzp7SQZBszNy22U7tX7/dtFdvJrv8aGE2cDJLoPycBgHSgICJUQLo8nmUo6y7oH0S5Lu/FGhDQULCfIooATw3yyOQQ46eYVpYiaBMTFtAFPR307r9y3fbdvsRfd5Rg6HJI2Lt1qaAF6TEqoxWdVdYSHawezCvAHLjW7Jh2QGcUkDDT4Og2OfSFRVkxipcAJUZARC5FVRbeRpB1hVY6r25XQHexIZ96Hfa++PTs4Dbi8rQg7imWQG27/uEgCTCssk/WWg7GwJWwDQ36PceGzQ+x7jOtgNogkIIpsZiFMdXoEfOPUlh3l5ulu2/X6bJ7Mc84Bw+xgOKzJqM0VKm8WYlVMqt61gFKN
tQKeZ6o7Ls/aqEeYooJXDIZ9uiT0uZ5UxPUJNlYdoAK62qHfM7unz3/bb9/Ha+v3u/tn3AD0XOrnxAZdpNYZILgoxyGk4BqMCbssq66dXv6RdFkiB6Rj2u3N1npiMw1dQjF4oJW/kzy6VdMRFA9Xd8VvhCLxCyYUYkvhHZb7+fotvdUR6XmwXcYI1DangAA6yspgBj/dRjp6L+RbmSPaaxuuMnGEeVAhBF4pSapAFG5gUo60rAHmpVtcz0sR2aBZW8NAB9+W7dXr9N0dmPmUcu10pWrq7kQQvBQXn1dUsgoM4ej12TtyBknG51PEMGOV2TLLVZ/GLvLMBYHsYJhg7fuMBx6tq3LFu7aBxxD9jKFiO7Thbwcv7n5dS+/ML0eWEWcBqoptk+mEQp2aTG+rbmBYA+D6MyMwMAdepKsX5QpnglFZyZ5k4tDYsI/Y1pF7CRq22HoHXgGEOwgodvgH79INnW3tlFIVVQvkBXg1dvF3z27fkTGzw+zALOPZluVoVkV4yLHoBB3VBJUNyo6uEWXAyIkruC2OQjbVeppxkm8+iti2mySsM1EPYGKBcEyul3LKTW1+pr+wLRstwP0J8a2K95Txf/+6q1ZzeUDEXt/oFhHnA4fJYCBtawYlWmlsrJBEHhP43bi9Rq1Z0ymlK3Z/QCRqA5YfaNLZJWEACn929eluXlUGO8CgMrHWYi441S2tsFebLRL5RWL0e0nL64SEEf2sjMR4ZZwA0Ddfziclz1eN8yDn1qAaHSq3G0FEQXjABDo51sJVNyGnA0QlAPL4LOApzMo0mY1sUFbQBj8xTzYhKrROYF5VGIftR1uW3+3uiWU8XnBw7l3HIYVG/P/djYgMZoyrTJrci0n2qPZVnNFV913viW6btGzsXBT6aW3VKmsauVTFOc2DxpP5YJYLBBeCUixE71IlGBR2EF+6OugHbP12Ddoj29HgIPj+cxDiPDFGINzB8sKhLh0Ui4gOgDI8deb8FiwYxlteWhLHWTlmOzhkxLAObPIkFqS8+bbG5BdgWiAmJTwXdqZ7oysktzdKC/BWMWiAJNpyP0ZPTMItRy7fTi2RB4eDwLuIkpCma1gob/Dsw7zcKAMf3txiCot8c42ZCDPu3WAqRMJAGEk4cACaLzSZsFRhAE9QoAtXcwTX92XDT0sxTQXJYHdDJin0KfVN8PmzNvnOYBx5XNlik4giumihb7tJ60ezgNhgXuXgRNttxunZYAj7uzbL3nUA67rm5KJWrJCyTfIVwBMh3bTkD8TqFYp6uv8RwrgJpAZmHHScqv0qWeKT48NujhAuELekyYBdz9gXJQ53DvDh3tU62xTtN8bQhzzE9OccAK8wA2ez2k3cNtN7wM/RZs9M5NkNZoee0H2rmhLr8miPV9roAZtN1RHV/gDb7EoUtXKeXjYXUBN0oeFs8CbrtlhZRGPZSSZNyI9gA+TBFkelFNWxgEgCtG3wDiFqEr5Jz6y/U1DAM4QLxi2l7DNhl3w/epNTUFWGbXC7HrMQMz7WUbf8AaDQ46DYXuxLoJX6CFRzvuiPyJzCzgZIoKyqgKAx1yAGPQUWfa+GoDsqwDJNnHLF9juSz0i5VrpvqSwmsQul5dtyfrfX1zL3i0WdHHSjaKVjf0T5k7ABtxlEHbwxusgjydAY8N84BjvAx5GLfMqBW0VJEZ+pwKskQnbpnFHPzpwWo/bzkGvX51296+bu1v/+qL9usXT9rTJ07Bzh9k9HEPsxNhwhh6xLXKo3fXWf3iMkrBBz9nAbflbHm6ONxhXp8/NW26lkSleIEV9FBVI+o6ihjmffPDt+3v/+5Z+82vnsZw/fyercweB2d7wzA8mfuPEknpXTnHvQsoPd1v/aD8LODw+AxbAw/QjnEfv69u5kz6dtOiW2R6YmW7vd0C3qK94wcjf/zxZ1bRXfvqGT6U3f2G/Z6AesqotgJX477PNVmTmxfiwTSS5irqz2ybEHD6PzbMAk7lS/0BxgkTqPAUYBiAkQ
pTLLdKxe1D4Lbsp968uW1vXk+ZrnpsN7yL1TbmbvCl4GcPPPStZWyNcM9s++9y92ruZu2CT21q7lZ9KDcLuC3WbmGG42uA30EISOVkFynt1BBialOliF/wZHqGTa1tOfq8fbMHPL6N2iBPW2d7HfxZdWnreiN49UL0dfhLR6tBSVVwNo+TQ1U5IsHvQU4Dcry7bGNOix+SngVcwAhYpZjTQxaNMABLLLtUFEAMEwi4kk63fGDbLTcVm82ubd7hNylzEXCa6SPdz2Vf5iUobe0jAFIq8+JHT8CjGeUjHFOj5E7MIO4THxvOaHIcwu2IOKiznyg89BTEXi6WssO8B36vkLa33Pv7/QRbEtm21c/BtIm9Yb4ho19PDg4g09aeucySdpzq3BfVx6WQqh7MkLOSkHLf2olEKni4n7xznh0VH4jnAYdy6hfVSZTvUmF54f2cU9d9XmlhvUyTlbkxIT0BWtgH4wRRgPMy7EFbAwi8ojzbNyqtH/7coWxnUHyE+rmYjbs3NCnqdwIbbM/GZ4RZwDleVskO3viSBhWjSu2Pxj7JU4bsqrzTU5YZQ7xKu73Bb8bAbo+s28NStxEyb8e+K1UAKXhOVivK7x0RUANf3zEw/smJpsr37cad9RlhFnCbzQYwfN36I+5qwxgVwRA/vOHxlneeMiaux9lymN5tTTttkZN5mbZwCYsLM550taA+zJM5gsdHsGSdQTbngN7ZlC/JrRhXIcorRJvVcp2pnjzdy+0nnErOCbOAE5x8d4oVCy4xMSFGetjfgWJ3MQFHdomxZbUwwC4B84YlzBNojUEmxmqO1tVC4VcVopUzKuXK+XArUeDVTyq85wv7xKqHsel1dfIUkl8zUXcFm8eUH7IPjWcBp8J5mYxWcWmbclhlyEIAMJm2HbSwDCHZGD9IuR1UH4MhaZ4HOAIQIJOrIxfjxOFRUMNQq8wI9EH5WNVJdcEje22ofxs3K6PlQ+OZwA2ghrFSKhiEVSqh/5JJcfodKBnntLac7wb5CKLpAs+0RguYuAhoNh2CRV1dTVFhqWhRn/u+tOsMtTph6JhOkAWsQDz1K3NHeHyYBZyK70BG5oy3SyqGumoaAhr1Aiggnm8FzXr3cQWSq++p8seM10v6LW9Elgh5kyGINXMdi1xspw2LRHwqMjJTV2KdU9c2eQ1SkXDDHL2aYf2MprVp1dFrtcBlAWB/sNuxMoJIzEfRqhMk04qXfM0n8yVDaa/DRLp1GuGSKhNz65ZEOQUSdyD0Y/adRSojsxjoz2jnNFdN3l/S+sUvnqbDsx+zgCvQMJzhPaCrlouCLBvbA43x68DhsAc7DxpTr0y39VAMBCfpSlpSUMggzRe8X4bIAWRYJqVJj6t7feMV/9Bkfeb+bYw2Czg78S3GwWtEQEPRWFMMEDAZhVTiMaWLnZZRxSexfaStPR9DAXbMj5Qs479Dm8PqqYCNEpUTVAe/GpLC3vH16hI64zkLuB1XQVsdFkED8ps40oLjj2sMAdbFwGlKRjbW6UHAFZaRJVegIpeWVafZhQ4yHahUm+5VyfOwXYFHTX8DKUNSn+fCcsN3qOd8AT3GGPEs4EYnxho9YlOnU1WTUj98GbLKWCawI5wk71DiBMoh+qjYfgXUc+nNlW+rXuqjOrknPAs4sRoHcvvNguDZNEChYOoBUUZ175z9nMBZnQ6cnncgS7uDnt3BJ49Y8axqPYLZ0gVEb2DaICyHtOUM5t2eP7AJexWaGWYBVzcdsqneoAAViyzzo3ZsC1Jeq2qBKVhlkIxDsuSRrSY6/6S6eaaFjD+B4BGmMo9X9M06kcAdMq0qU5eT+lBBc8+GqaVmCc989iHP6yVvOcr4qE8ZLijVZ8VleC/5xWDWFmN6ow6aIKX75EfdL5rfKxBJgAcwwV/zeXrFjyqqo3uy52dnMa5oU4O7svo7YMNgWrFKdsk6WBXmmS82HuKsuADjHZFGi5iBIv+9qnn/qt+qSh3JTFNjPvWDiqpnA0
SexYB/ijm6q5qP85wFnIZrXQHgillpVesHh9QVaAWWAJccfo/VNrOcbmrbYn/vCR9gy2m1aUH2WOa/rv4UoKnhPODowC2Gx6jQo4Nox4ZinDL392ssIHFSZWa1rTZJD/wSy0Kn34eDpwZvP1w96+dmH25zrsQs4KSLP4GAawWSjhnFZZQFmUZxOZSTj/ne2yUhIHCjRIlFKcIU0x852RjZTGGlDdaQrkxk7MPrJr/gzg17r4vgJ3rMAk4/wmQDE7wJhg+fFV1xaMGiMqnXaFc5jd4FjCCIRAEmAO5aPE7lzsw0ZelHYJB0PCWscErqOJcsrbllGmhmzE/7mAXcPof544Wlqg6wTuORtvKQzjV2gVC+shaNMhc24v8iIloGmS3ogc7bD9sS884Oi0kEP89jFnDX++/hCtPVtT7kwaxOkZpmxQ/L9vgdj1r+NCtAwQ6/A9DXMXnBqZgoHDdXP7Wna/Id6PRCum7DiREqcg1UPw9Yp6MsLv/HwlM4Hp7WQ1/CGQhcgDsDNJtcgLsAdyYCZza7MO4C3JkInNnswrgLcGcicGazC+POBO7/AH5zPa/ivytzAAAAAElFTkSuQmCC", + ), + ] + ), + ], + model_parameters={"temperature": 0.0, "max_tokens": 100}, + stream=False, + user="abc-123", + ) + + assert isinstance(result, LLMResult) + assert len(result.message.content) > 0 + + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) +def test_invoke_chat_model_with_tools(setup_openai_mock): + model = OpenAILargeLanguageModel() + + result = model.invoke( + model="gpt-3.5-turbo", + credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage( + content="what's the weather today in London?", + ), + ], + model_parameters={"temperature": 0.0, "max_tokens": 100}, + tools=[ + PromptMessageTool( + name="get_weather", + description="Determine weather in my location", + parameters={ + "type": "object", + "properties": { + "location": {"type": "string", "description": "The city and state e.g. 
San Francisco, CA"}, + "unit": {"type": "string", "enum": ["c", "f"]}, + }, + "required": ["location"], + }, + ), + PromptMessageTool( + name="get_stock_price", + description="Get the current stock price", + parameters={ + "type": "object", + "properties": {"symbol": {"type": "string", "description": "The stock symbol"}}, + "required": ["symbol"], + }, + ), + ], + stream=False, + user="abc-123", + ) + + assert isinstance(result, LLMResult) + assert isinstance(result.message, AssistantPromptMessage) + assert len(result.message.tool_calls) > 0 + + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) +def test_invoke_stream_chat_model(setup_openai_mock): + model = OpenAILargeLanguageModel() + + result = model.invoke( + model="gpt-3.5-turbo", + credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Hello World!"), + ], + model_parameters={"temperature": 0.0, "max_tokens": 100}, + stream=True, + user="abc-123", + ) + + assert isinstance(result, Generator) + + for chunk in result: + assert isinstance(chunk, LLMResultChunk) + assert isinstance(chunk.delta, LLMResultChunkDelta) + assert isinstance(chunk.delta.message, AssistantPromptMessage) + assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True + if chunk.delta.finish_reason is not None: + assert chunk.delta.usage is not None + assert chunk.delta.usage.completion_tokens > 0 + + +def test_get_num_tokens(): + model = OpenAILargeLanguageModel() + + num_tokens = model.get_num_tokens( + model="gpt-3.5-turbo-instruct", + credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}, + prompt_messages=[UserPromptMessage(content="Hello World!")], + ) + + assert num_tokens == 3 + + num_tokens = model.get_num_tokens( + model="gpt-3.5-turbo", + credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}, + prompt_messages=[ + 
SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Hello World!"), + ], + tools=[ + PromptMessageTool( + name="get_weather", + description="Determine weather in my location", + parameters={ + "type": "object", + "properties": { + "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"}, + "unit": {"type": "string", "enum": ["c", "f"]}, + }, + "required": ["location"], + }, + ), + ], + ) + + assert num_tokens == 72 + + +@pytest.mark.parametrize("setup_openai_mock", [["chat", "remote"]], indirect=True) +def test_fine_tuned_models(setup_openai_mock): + model = OpenAILargeLanguageModel() + + remote_models = model.remote_models(credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}) + + if not remote_models: + assert isinstance(remote_models, list) + else: + assert isinstance(remote_models[0], AIModelEntity) + + for llm_model in remote_models: + if llm_model.model_type == ModelType.LLM: + break + + assert isinstance(llm_model, AIModelEntity) + + # test invoke + result = model.invoke( + model=llm_model.model, + credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Hello World!"), + ], + model_parameters={"temperature": 0.0, "max_tokens": 100}, + stream=False, + user="abc-123", + ) + + assert isinstance(result, LLMResult) + + +def test__get_num_tokens_by_gpt2(): + model = OpenAILargeLanguageModel() + num_tokens = model._get_num_tokens_by_gpt2("Hello World!") + + assert num_tokens == 3 diff --git a/api/tests/integration_tests/model_runtime/openai/test_moderation.py b/api/tests/integration_tests/model_runtime/openai/test_moderation.py new file mode 100644 index 0000000000000000000000000000000000000000..6de262471798ad8cc5939009eddc9e2f42391066 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/openai/test_moderation.py @@ -0,0 +1,44 @@ 
+import os + +import pytest + +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.openai.moderation.moderation import OpenAIModerationModel +from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock + + +@pytest.mark.parametrize("setup_openai_mock", [["moderation"]], indirect=True) +def test_validate_credentials(setup_openai_mock): + model = OpenAIModerationModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials(model="text-moderation-stable", credentials={"openai_api_key": "invalid_key"}) + + model.validate_credentials( + model="text-moderation-stable", credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")} + ) + + +@pytest.mark.parametrize("setup_openai_mock", [["moderation"]], indirect=True) +def test_invoke_model(setup_openai_mock): + model = OpenAIModerationModel() + + result = model.invoke( + model="text-moderation-stable", + credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}, + text="hello", + user="abc-123", + ) + + assert isinstance(result, bool) + assert result is False + + result = model.invoke( + model="text-moderation-stable", + credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}, + text="i will kill you", + user="abc-123", + ) + + assert isinstance(result, bool) + assert result is True diff --git a/api/tests/integration_tests/model_runtime/openai/test_provider.py b/api/tests/integration_tests/model_runtime/openai/test_provider.py new file mode 100644 index 0000000000000000000000000000000000000000..4d56cfcf3c25f0db1429bd392f23e2108c158649 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/openai/test_provider.py @@ -0,0 +1,17 @@ +import os + +import pytest + +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.openai.openai import OpenAIProvider +from tests.integration_tests.model_runtime.__mock.openai import 
setup_openai_mock + + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) +def test_validate_provider_credentials(setup_openai_mock): + provider = OpenAIProvider() + + with pytest.raises(CredentialsValidateFailedError): + provider.validate_provider_credentials(credentials={}) + + provider.validate_provider_credentials(credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}) diff --git a/api/tests/integration_tests/model_runtime/openai/test_speech2text.py b/api/tests/integration_tests/model_runtime/openai/test_speech2text.py new file mode 100644 index 0000000000000000000000000000000000000000..aa92c8b61fb6842d22b7844f39e03bc8e22ba952 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/openai/test_speech2text.py @@ -0,0 +1,45 @@ +import os + +import pytest + +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.openai.speech2text.speech2text import OpenAISpeech2TextModel +from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock + + +@pytest.mark.parametrize("setup_openai_mock", [["speech2text"]], indirect=True) +def test_validate_credentials(setup_openai_mock): + model = OpenAISpeech2TextModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials(model="whisper-1", credentials={"openai_api_key": "invalid_key"}) + + model.validate_credentials(model="whisper-1", credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}) + + +@pytest.mark.parametrize("setup_openai_mock", [["speech2text"]], indirect=True) +def test_invoke_model(setup_openai_mock): + model = OpenAISpeech2TextModel() + + # Get the directory of the current file + current_dir = os.path.dirname(os.path.abspath(__file__)) + + # Get assets directory + assets_dir = os.path.join(os.path.dirname(current_dir), "assets") + + # Construct the path to the audio file + audio_file_path = os.path.join(assets_dir, "audio.mp3") + + # Open the file and get 
the file object + with open(audio_file_path, "rb") as audio_file: + file = audio_file + + result = model.invoke( + model="whisper-1", + credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}, + file=file, + user="abc-123", + ) + + assert isinstance(result, str) + assert result == "1, 2, 3, 4, 5, 6, 7, 8, 9, 10" diff --git a/api/tests/integration_tests/model_runtime/openai/test_text_embedding.py b/api/tests/integration_tests/model_runtime/openai/test_text_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..f5dd73f2d4cd6018597bc65992536e2568da3029 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/openai/test_text_embedding.py @@ -0,0 +1,48 @@ +import os + +import pytest + +from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.openai.text_embedding.text_embedding import OpenAITextEmbeddingModel +from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock + + +@pytest.mark.parametrize("setup_openai_mock", [["text_embedding"]], indirect=True) +def test_validate_credentials(setup_openai_mock): + model = OpenAITextEmbeddingModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials(model="text-embedding-ada-002", credentials={"openai_api_key": "invalid_key"}) + + model.validate_credentials( + model="text-embedding-ada-002", credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")} + ) + + +@pytest.mark.parametrize("setup_openai_mock", [["text_embedding"]], indirect=True) +def test_invoke_model(setup_openai_mock): + model = OpenAITextEmbeddingModel() + + result = model.invoke( + model="text-embedding-ada-002", + credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY"), "openai_api_base": "https://api.openai.com"}, + texts=["hello", "world", " ".join(["long_text"] * 100), " ".join(["another_long_text"] * 
100)], + user="abc-123", + ) + + assert isinstance(result, TextEmbeddingResult) + assert len(result.embeddings) == 4 + assert result.usage.total_tokens == 2 + + +def test_get_num_tokens(): + model = OpenAITextEmbeddingModel() + + num_tokens = model.get_num_tokens( + model="text-embedding-ada-002", + credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY"), "openai_api_base": "https://api.openai.com"}, + texts=["hello", "world"], + ) + + assert num_tokens == 2 diff --git a/api/tests/integration_tests/model_runtime/openai_api_compatible/__init__.py b/api/tests/integration_tests/model_runtime/openai_api_compatible/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/tests/integration_tests/model_runtime/openai_api_compatible/test_llm.py b/api/tests/integration_tests/model_runtime/openai_api_compatible/test_llm.py new file mode 100644 index 0000000000000000000000000000000000000000..f2302ef05e06dee1e2acadfd129bbb2578337d7a --- /dev/null +++ b/api/tests/integration_tests/model_runtime/openai_api_compatible/test_llm.py @@ -0,0 +1,197 @@ +import os +from collections.abc import Generator + +import pytest + +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessageTool, + SystemPromptMessage, + UserPromptMessage, +) +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel + +""" +Using Together.ai's OpenAI-compatible API as testing endpoint +""" + + +def test_validate_credentials(): + model = OAIAPICompatLargeLanguageModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model="mistralai/Mixtral-8x7B-Instruct-v0.1", + credentials={"api_key": "invalid_key", "endpoint_url": 
"https://api.together.xyz/v1/", "mode": "chat"}, + ) + + model.validate_credentials( + model="mistralai/Mixtral-8x7B-Instruct-v0.1", + credentials={ + "api_key": os.environ.get("TOGETHER_API_KEY"), + "endpoint_url": "https://api.together.xyz/v1/", + "mode": "chat", + }, + ) + + +def test_invoke_model(): + model = OAIAPICompatLargeLanguageModel() + + response = model.invoke( + model="mistralai/Mixtral-8x7B-Instruct-v0.1", + credentials={ + "api_key": os.environ.get("TOGETHER_API_KEY"), + "endpoint_url": "https://api.together.xyz/v1/", + "mode": "completion", + }, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Who are you?"), + ], + model_parameters={ + "temperature": 1.0, + "top_k": 2, + "top_p": 0.5, + }, + stop=["How"], + stream=False, + user="abc-123", + ) + + assert isinstance(response, LLMResult) + assert len(response.message.content) > 0 + + +def test_invoke_stream_model(): + model = OAIAPICompatLargeLanguageModel() + + response = model.invoke( + model="mistralai/Mixtral-8x7B-Instruct-v0.1", + credentials={ + "api_key": os.environ.get("TOGETHER_API_KEY"), + "endpoint_url": "https://api.together.xyz/v1/", + "mode": "chat", + "stream_mode_delimiter": "\\n\\n", + }, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Who are you?"), + ], + model_parameters={ + "temperature": 1.0, + "top_k": 2, + "top_p": 0.5, + }, + stop=["How"], + stream=True, + user="abc-123", + ) + + assert isinstance(response, Generator) + + for chunk in response: + assert isinstance(chunk, LLMResultChunk) + assert isinstance(chunk.delta, LLMResultChunkDelta) + assert isinstance(chunk.delta.message, AssistantPromptMessage) + + +def test_invoke_stream_model_without_delimiter(): + model = OAIAPICompatLargeLanguageModel() + + response = model.invoke( + model="mistralai/Mixtral-8x7B-Instruct-v0.1", + credentials={ + "api_key": 
os.environ.get("TOGETHER_API_KEY"), + "endpoint_url": "https://api.together.xyz/v1/", + "mode": "chat", + }, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Who are you?"), + ], + model_parameters={ + "temperature": 1.0, + "top_k": 2, + "top_p": 0.5, + }, + stop=["How"], + stream=True, + user="abc-123", + ) + + assert isinstance(response, Generator) + + for chunk in response: + assert isinstance(chunk, LLMResultChunk) + assert isinstance(chunk.delta, LLMResultChunkDelta) + assert isinstance(chunk.delta.message, AssistantPromptMessage) + + +# using OpenAI's ChatGPT-3.5 as testing endpoint +def test_invoke_chat_model_with_tools(): + model = OAIAPICompatLargeLanguageModel() + + result = model.invoke( + model="gpt-3.5-turbo", + credentials={ + "api_key": os.environ.get("OPENAI_API_KEY"), + "endpoint_url": "https://api.openai.com/v1/", + "mode": "chat", + }, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage( + content="what's the weather today in London?", + ), + ], + tools=[ + PromptMessageTool( + name="get_weather", + description="Determine weather in my location", + parameters={ + "type": "object", + "properties": { + "location": {"type": "string", "description": "The city and state e.g. 
San Francisco, CA"}, + "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, + }, + "required": ["location"], + }, + ), + ], + model_parameters={"temperature": 0.0, "max_tokens": 1024}, + stream=False, + user="abc-123", + ) + + assert isinstance(result, LLMResult) + assert isinstance(result.message, AssistantPromptMessage) + assert len(result.message.tool_calls) > 0 + + +def test_get_num_tokens(): + model = OAIAPICompatLargeLanguageModel() + + num_tokens = model.get_num_tokens( + model="mistralai/Mixtral-8x7B-Instruct-v0.1", + credentials={"api_key": os.environ.get("OPENAI_API_KEY"), "endpoint_url": "https://api.openai.com/v1/"}, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Hello World!"), + ], + ) + + assert isinstance(num_tokens, int) + assert num_tokens == 21 diff --git a/api/tests/integration_tests/model_runtime/openai_api_compatible/test_speech2text.py b/api/tests/integration_tests/model_runtime/openai_api_compatible/test_speech2text.py new file mode 100644 index 0000000000000000000000000000000000000000..cf805eafff496888f209cdcbaa81b4772ac07c96 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/openai_api_compatible/test_speech2text.py @@ -0,0 +1,50 @@ +import os + +import pytest + +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.openai_api_compatible.speech2text.speech2text import ( + OAICompatSpeech2TextModel, +) + + +def test_validate_credentials(): + model = OAICompatSpeech2TextModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model="whisper-1", + credentials={"api_key": "invalid_key", "endpoint_url": "https://api.openai.com/v1/"}, + ) + + model.validate_credentials( + model="whisper-1", + credentials={"api_key": os.environ.get("OPENAI_API_KEY"), "endpoint_url": "https://api.openai.com/v1/"}, + ) + + +def test_invoke_model(): + model = 
OAICompatSpeech2TextModel() + + # Get the directory of the current file + current_dir = os.path.dirname(os.path.abspath(__file__)) + + # Get assets directory + assets_dir = os.path.join(os.path.dirname(current_dir), "assets") + + # Construct the path to the audio file + audio_file_path = os.path.join(assets_dir, "audio.mp3") + + # Open the file and get the file object + with open(audio_file_path, "rb") as audio_file: + file = audio_file + + result = model.invoke( + model="whisper-1", + credentials={"api_key": os.environ.get("OPENAI_API_KEY"), "endpoint_url": "https://api.openai.com/v1/"}, + file=file, + user="abc-123", + ) + + assert isinstance(result, str) + assert result == "1, 2, 3, 4, 5, 6, 7, 8, 9, 10" diff --git a/api/tests/integration_tests/model_runtime/openai_api_compatible/test_text_embedding.py b/api/tests/integration_tests/model_runtime/openai_api_compatible/test_text_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..052b41605f6da258a48411b89ab64d63daa1b764 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/openai_api_compatible/test_text_embedding.py @@ -0,0 +1,67 @@ +import os + +import pytest + +from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.openai_api_compatible.text_embedding.text_embedding import ( + OAICompatEmbeddingModel, +) + +""" +Using OpenAI's API as testing endpoint +""" + + +def test_validate_credentials(): + model = OAICompatEmbeddingModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model="text-embedding-ada-002", + credentials={"api_key": "invalid_key", "endpoint_url": "https://api.openai.com/v1/", "context_size": 8184}, + ) + + model.validate_credentials( + model="text-embedding-ada-002", + credentials={ + "api_key": os.environ.get("OPENAI_API_KEY"), + "endpoint_url": 
"https://api.openai.com/v1/", + "context_size": 8184, + }, + ) + + +def test_invoke_model(): + model = OAICompatEmbeddingModel() + + result = model.invoke( + model="text-embedding-ada-002", + credentials={ + "api_key": os.environ.get("OPENAI_API_KEY"), + "endpoint_url": "https://api.openai.com/v1/", + "context_size": 8184, + }, + texts=["hello", "world", " ".join(["long_text"] * 100), " ".join(["another_long_text"] * 100)], + user="abc-123", + ) + + assert isinstance(result, TextEmbeddingResult) + assert len(result.embeddings) == 4 + assert result.usage.total_tokens == 502 + + +def test_get_num_tokens(): + model = OAICompatEmbeddingModel() + + num_tokens = model.get_num_tokens( + model="text-embedding-ada-002", + credentials={ + "api_key": os.environ.get("OPENAI_API_KEY"), + "endpoint_url": "https://api.openai.com/v1/embeddings", + "context_size": 8184, + }, + texts=["hello", "world"], + ) + + assert num_tokens == 2 diff --git a/api/tests/integration_tests/model_runtime/openllm/__init__.py b/api/tests/integration_tests/model_runtime/openllm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/tests/integration_tests/model_runtime/openllm/test_embedding.py b/api/tests/integration_tests/model_runtime/openllm/test_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..14d47217af62c87533c41c866551fec738cb6b8e --- /dev/null +++ b/api/tests/integration_tests/model_runtime/openllm/test_embedding.py @@ -0,0 +1,57 @@ +import os + +import pytest + +from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.openllm.text_embedding.text_embedding import OpenLLMTextEmbeddingModel + + +def test_validate_credentials(): + model = OpenLLMTextEmbeddingModel() + + with pytest.raises(CredentialsValidateFailedError): + 
model.validate_credentials( + model="NOT IMPORTANT", + credentials={ + "server_url": "ww" + os.environ.get("OPENLLM_SERVER_URL"), + }, + ) + + model.validate_credentials( + model="NOT IMPORTANT", + credentials={ + "server_url": os.environ.get("OPENLLM_SERVER_URL"), + }, + ) + + +def test_invoke_model(): + model = OpenLLMTextEmbeddingModel() + + result = model.invoke( + model="NOT IMPORTANT", + credentials={ + "server_url": os.environ.get("OPENLLM_SERVER_URL"), + }, + texts=["hello", "world"], + user="abc-123", + ) + + assert isinstance(result, TextEmbeddingResult) + assert len(result.embeddings) == 2 + assert result.usage.total_tokens > 0 + + +def test_get_num_tokens(): + model = OpenLLMTextEmbeddingModel() + + num_tokens = model.get_num_tokens( + model="NOT IMPORTANT", + credentials={ + "server_url": os.environ.get("OPENLLM_SERVER_URL"), + }, + texts=["hello", "world"], + ) + + assert num_tokens == 2 diff --git a/api/tests/integration_tests/model_runtime/openllm/test_llm.py b/api/tests/integration_tests/model_runtime/openllm/test_llm.py new file mode 100644 index 0000000000000000000000000000000000000000..35939e3cfe8bfd83997a9f3276f1e87de55bb705 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/openllm/test_llm.py @@ -0,0 +1,95 @@ +import os +from collections.abc import Generator + +import pytest + +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import AssistantPromptMessage, UserPromptMessage +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.openllm.llm.llm import OpenLLMLargeLanguageModel + + +def test_validate_credentials_for_chat_model(): + model = OpenLLMLargeLanguageModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model="NOT IMPORTANT", + credentials={ + "server_url": "invalid_key", + }, + ) + + model.validate_credentials( + 
model="NOT IMPORTANT", + credentials={ + "server_url": os.environ.get("OPENLLM_SERVER_URL"), + }, + ) + + +def test_invoke_model(): + model = OpenLLMLargeLanguageModel() + + response = model.invoke( + model="NOT IMPORTANT", + credentials={ + "server_url": os.environ.get("OPENLLM_SERVER_URL"), + }, + prompt_messages=[UserPromptMessage(content="Hello World!")], + model_parameters={ + "temperature": 0.7, + "top_p": 1.0, + "top_k": 1, + }, + stop=["you"], + user="abc-123", + stream=False, + ) + + assert isinstance(response, LLMResult) + assert len(response.message.content) > 0 + assert response.usage.total_tokens > 0 + + +def test_invoke_stream_model(): + model = OpenLLMLargeLanguageModel() + + response = model.invoke( + model="NOT IMPORTANT", + credentials={ + "server_url": os.environ.get("OPENLLM_SERVER_URL"), + }, + prompt_messages=[UserPromptMessage(content="Hello World!")], + model_parameters={ + "temperature": 0.7, + "top_p": 1.0, + "top_k": 1, + }, + stop=["you"], + stream=True, + user="abc-123", + ) + + assert isinstance(response, Generator) + for chunk in response: + assert isinstance(chunk, LLMResultChunk) + assert isinstance(chunk.delta, LLMResultChunkDelta) + assert isinstance(chunk.delta.message, AssistantPromptMessage) + assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True + + +def test_get_num_tokens(): + model = OpenLLMLargeLanguageModel() + + response = model.get_num_tokens( + model="NOT IMPORTANT", + credentials={ + "server_url": os.environ.get("OPENLLM_SERVER_URL"), + }, + prompt_messages=[UserPromptMessage(content="Hello World!")], + tools=[], + ) + + assert isinstance(response, int) + assert response == 3 diff --git a/api/tests/integration_tests/model_runtime/openrouter/__init__.py b/api/tests/integration_tests/model_runtime/openrouter/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git 
a/api/tests/integration_tests/model_runtime/openrouter/test_llm.py b/api/tests/integration_tests/model_runtime/openrouter/test_llm.py new file mode 100644 index 0000000000000000000000000000000000000000..1b0cc6bf4b8e76934cc18faf9c27a2862d4ac8d6 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/openrouter/test_llm.py @@ -0,0 +1,103 @@ +import os +from collections.abc import Generator + +import pytest + +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + SystemPromptMessage, + UserPromptMessage, +) +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.openrouter.llm.llm import OpenRouterLargeLanguageModel + + +def test_validate_credentials(): + model = OpenRouterLargeLanguageModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model="mistralai/mixtral-8x7b-instruct", credentials={"api_key": "invalid_key", "mode": "chat"} + ) + + model.validate_credentials( + model="mistralai/mixtral-8x7b-instruct", + credentials={"api_key": os.environ.get("TOGETHER_API_KEY"), "mode": "chat"}, + ) + + +def test_invoke_model(): + model = OpenRouterLargeLanguageModel() + + response = model.invoke( + model="mistralai/mixtral-8x7b-instruct", + credentials={"api_key": os.environ.get("TOGETHER_API_KEY"), "mode": "completion"}, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Who are you?"), + ], + model_parameters={ + "temperature": 1.0, + "top_k": 2, + "top_p": 0.5, + }, + stop=["How"], + stream=False, + user="abc-123", + ) + + assert isinstance(response, LLMResult) + assert len(response.message.content) > 0 + + +def test_invoke_stream_model(): + model = OpenRouterLargeLanguageModel() + + response = model.invoke( + model="mistralai/mixtral-8x7b-instruct", + 
credentials={"api_key": os.environ.get("TOGETHER_API_KEY"), "mode": "chat"}, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Who are you?"), + ], + model_parameters={ + "temperature": 1.0, + "top_k": 2, + "top_p": 0.5, + }, + stop=["How"], + stream=True, + user="abc-123", + ) + + assert isinstance(response, Generator) + + for chunk in response: + assert isinstance(chunk, LLMResultChunk) + assert isinstance(chunk.delta, LLMResultChunkDelta) + assert isinstance(chunk.delta.message, AssistantPromptMessage) + + +def test_get_num_tokens(): + model = OpenRouterLargeLanguageModel() + + num_tokens = model.get_num_tokens( + model="mistralai/mixtral-8x7b-instruct", + credentials={ + "api_key": os.environ.get("TOGETHER_API_KEY"), + }, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Hello World!"), + ], + ) + + assert isinstance(num_tokens, int) + assert num_tokens == 21 diff --git a/api/tests/integration_tests/model_runtime/replicate/__init__.py b/api/tests/integration_tests/model_runtime/replicate/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/tests/integration_tests/model_runtime/replicate/test_llm.py b/api/tests/integration_tests/model_runtime/replicate/test_llm.py new file mode 100644 index 0000000000000000000000000000000000000000..b940005b715760aff48241ac0cc0b37a6e7d5b3a --- /dev/null +++ b/api/tests/integration_tests/model_runtime/replicate/test_llm.py @@ -0,0 +1,112 @@ +import os +from collections.abc import Generator + +import pytest + +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import AssistantPromptMessage, SystemPromptMessage, UserPromptMessage +from core.model_runtime.errors.validate import CredentialsValidateFailedError 
+from core.model_runtime.model_providers.replicate.llm.llm import ReplicateLargeLanguageModel + + +def test_validate_credentials(): + model = ReplicateLargeLanguageModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model="meta/llama-2-13b-chat", + credentials={ + "replicate_api_token": "invalid_key", + "model_version": "f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d", + }, + ) + + model.validate_credentials( + model="meta/llama-2-13b-chat", + credentials={ + "replicate_api_token": os.environ.get("REPLICATE_API_KEY"), + "model_version": "f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d", + }, + ) + + +def test_invoke_model(): + model = ReplicateLargeLanguageModel() + + response = model.invoke( + model="meta/llama-2-13b-chat", + credentials={ + "replicate_api_token": os.environ.get("REPLICATE_API_KEY"), + "model_version": "f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d", + }, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Who are you?"), + ], + model_parameters={ + "temperature": 1.0, + "top_k": 2, + "top_p": 0.5, + }, + stop=["How"], + stream=False, + user="abc-123", + ) + + assert isinstance(response, LLMResult) + assert len(response.message.content) > 0 + + +def test_invoke_stream_model(): + model = ReplicateLargeLanguageModel() + + response = model.invoke( + model="mistralai/mixtral-8x7b-instruct-v0.1", + credentials={ + "replicate_api_token": os.environ.get("REPLICATE_API_KEY"), + "model_version": "2b56576fcfbe32fa0526897d8385dd3fb3d36ba6fd0dbe033c72886b81ade93e", + }, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Who are you?"), + ], + model_parameters={ + "temperature": 1.0, + "top_k": 2, + "top_p": 0.5, + }, + stop=["How"], + stream=True, + user="abc-123", + ) + + assert isinstance(response, Generator) + + for chunk 
in response: + assert isinstance(chunk, LLMResultChunk) + assert isinstance(chunk.delta, LLMResultChunkDelta) + assert isinstance(chunk.delta.message, AssistantPromptMessage) + + +def test_get_num_tokens(): + model = ReplicateLargeLanguageModel() + + num_tokens = model.get_num_tokens( + model="", + credentials={ + "replicate_api_token": os.environ.get("REPLICATE_API_KEY"), + "model_version": "2b56576fcfbe32fa0526897d8385dd3fb3d36ba6fd0dbe033c72886b81ade93e", + }, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Hello World!"), + ], + ) + + assert num_tokens == 14 diff --git a/api/tests/integration_tests/model_runtime/replicate/test_text_embedding.py b/api/tests/integration_tests/model_runtime/replicate/test_text_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..397715f225208364513ce35c9e2b67c526384791 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/replicate/test_text_embedding.py @@ -0,0 +1,136 @@ +import os + +import pytest + +from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.replicate.text_embedding.text_embedding import ReplicateEmbeddingModel + + +def test_validate_credentials_one(): + model = ReplicateEmbeddingModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model="replicate/all-mpnet-base-v2", + credentials={ + "replicate_api_token": "invalid_key", + "model_version": "b6b7585c9640cd7a9572c6e129c9549d79c9c31f0d3fdce7baac7c67ca38f305", + }, + ) + + model.validate_credentials( + model="replicate/all-mpnet-base-v2", + credentials={ + "replicate_api_token": os.environ.get("REPLICATE_API_KEY"), + "model_version": "b6b7585c9640cd7a9572c6e129c9549d79c9c31f0d3fdce7baac7c67ca38f305", + }, + ) + + +def test_validate_credentials_two(): + model = 
ReplicateEmbeddingModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model="nateraw/bge-large-en-v1.5", + credentials={ + "replicate_api_token": "invalid_key", + "model_version": "9cf9f015a9cb9c61d1a2610659cdac4a4ca222f2d3707a68517b18c198a9add1", + }, + ) + + model.validate_credentials( + model="nateraw/bge-large-en-v1.5", + credentials={ + "replicate_api_token": os.environ.get("REPLICATE_API_KEY"), + "model_version": "9cf9f015a9cb9c61d1a2610659cdac4a4ca222f2d3707a68517b18c198a9add1", + }, + ) + + +def test_invoke_model_one(): + model = ReplicateEmbeddingModel() + + result = model.invoke( + model="nateraw/bge-large-en-v1.5", + credentials={ + "replicate_api_token": os.environ.get("REPLICATE_API_KEY"), + "model_version": "9cf9f015a9cb9c61d1a2610659cdac4a4ca222f2d3707a68517b18c198a9add1", + }, + texts=["hello", "world"], + user="abc-123", + ) + + assert isinstance(result, TextEmbeddingResult) + assert len(result.embeddings) == 2 + assert result.usage.total_tokens == 2 + + +def test_invoke_model_two(): + model = ReplicateEmbeddingModel() + + result = model.invoke( + model="andreasjansson/clip-features", + credentials={ + "replicate_api_token": os.environ.get("REPLICATE_API_KEY"), + "model_version": "75b33f253f7714a281ad3e9b28f63e3232d583716ef6718f2e46641077ea040a", + }, + texts=["hello", "world"], + user="abc-123", + ) + + assert isinstance(result, TextEmbeddingResult) + assert len(result.embeddings) == 2 + assert result.usage.total_tokens == 2 + + +def test_invoke_model_three(): + model = ReplicateEmbeddingModel() + + result = model.invoke( + model="replicate/all-mpnet-base-v2", + credentials={ + "replicate_api_token": os.environ.get("REPLICATE_API_KEY"), + "model_version": "b6b7585c9640cd7a9572c6e129c9549d79c9c31f0d3fdce7baac7c67ca38f305", + }, + texts=["hello", "world"], + user="abc-123", + ) + + assert isinstance(result, TextEmbeddingResult) + assert len(result.embeddings) == 2 + assert result.usage.total_tokens == 2 + 
+ +def test_invoke_model_four(): + model = ReplicateEmbeddingModel() + + result = model.invoke( + model="nateraw/jina-embeddings-v2-base-en", + credentials={ + "replicate_api_token": os.environ.get("REPLICATE_API_KEY"), + "model_version": "f8367a1c072ba2bc28af549d1faeacfe9b88b3f0e475add7a75091dac507f79e", + }, + texts=["hello", "world"], + user="abc-123", + ) + + assert isinstance(result, TextEmbeddingResult) + assert len(result.embeddings) == 2 + assert result.usage.total_tokens == 2 + + +def test_get_num_tokens(): + model = ReplicateEmbeddingModel() + + num_tokens = model.get_num_tokens( + model="nateraw/jina-embeddings-v2-base-en", + credentials={ + "replicate_api_token": os.environ.get("REPLICATE_API_KEY"), + "model_version": "f8367a1c072ba2bc28af549d1faeacfe9b88b3f0e475add7a75091dac507f79e", + }, + texts=["hello", "world"], + ) + + assert num_tokens == 2 diff --git a/api/tests/integration_tests/model_runtime/sagemaker/__init__.py b/api/tests/integration_tests/model_runtime/sagemaker/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/tests/integration_tests/model_runtime/sagemaker/test_provider.py b/api/tests/integration_tests/model_runtime/sagemaker/test_provider.py new file mode 100644 index 0000000000000000000000000000000000000000..41de2a17fda047faedbcd199ff00bfd904b64796 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/sagemaker/test_provider.py @@ -0,0 +1,13 @@ +import pytest + +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.sagemaker.sagemaker import SageMakerProvider + + +def test_validate_provider_credentials(): + provider = SageMakerProvider() + + with pytest.raises(CredentialsValidateFailedError): + provider.validate_provider_credentials(credentials={}) + + provider.validate_provider_credentials(credentials={}) diff --git 
a/api/tests/integration_tests/model_runtime/sagemaker/test_rerank.py b/api/tests/integration_tests/model_runtime/sagemaker/test_rerank.py new file mode 100644 index 0000000000000000000000000000000000000000..d5a6798a1ef735f98ed787a4cca394f392fbfbf1 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/sagemaker/test_rerank.py @@ -0,0 +1,55 @@ +import os + +import pytest + +from core.model_runtime.entities.rerank_entities import RerankResult +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.sagemaker.rerank.rerank import SageMakerRerankModel + + +def test_validate_credentials(): + model = SageMakerRerankModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model="bge-m3-rerank-v2", + credentials={ + "aws_region": os.getenv("AWS_REGION"), + "aws_access_key": os.getenv("AWS_ACCESS_KEY"), + "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"), + }, + query="What is the capital of the United States?", + docs=[ + "Carson City is the capital city of the American state of Nevada. At the 2010 United States " + "Census, Carson City had a population of 55,274.", + "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that " + "are a political division controlled by the United States. Its capital is Saipan.", + ], + score_threshold=0.8, + ) + + +def test_invoke_model(): + model = SageMakerRerankModel() + + result = model.invoke( + model="bge-m3-rerank-v2", + credentials={ + "aws_region": os.getenv("AWS_REGION"), + "aws_access_key": os.getenv("AWS_ACCESS_KEY"), + "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"), + }, + query="What is the capital of the United States?", + docs=[ + "Carson City is the capital city of the American state of Nevada. 
At the 2010 United States " + "Census, Carson City had a population of 55,274.", + "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that " + "are a political division controlled by the United States. Its capital is Saipan.", + ], + score_threshold=0.8, + ) + + assert isinstance(result, RerankResult) + assert len(result.docs) == 1 + assert result.docs[0].index == 1 + assert result.docs[0].score >= 0.8 diff --git a/api/tests/integration_tests/model_runtime/sagemaker/test_text_embedding.py b/api/tests/integration_tests/model_runtime/sagemaker/test_text_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..f77601eea2c2637606fb07ec879b22fbd5ef7f17 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/sagemaker/test_text_embedding.py @@ -0,0 +1,31 @@ +import pytest + +from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.sagemaker.text_embedding.text_embedding import SageMakerEmbeddingModel + + +def test_validate_credentials(): + model = SageMakerEmbeddingModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials(model="bge-m3", credentials={}) + + model.validate_credentials(model="bge-m3-embedding", credentials={}) + + +def test_invoke_model(): + model = SageMakerEmbeddingModel() + + result = model.invoke(model="bge-m3-embedding", credentials={}, texts=["hello", "world"], user="abc-123") + + assert isinstance(result, TextEmbeddingResult) + assert len(result.embeddings) == 2 + + +def test_get_num_tokens(): + model = SageMakerEmbeddingModel() + + num_tokens = model.get_num_tokens(model="bge-m3-embedding", credentials={}, texts=[]) + + assert num_tokens == 0 diff --git a/api/tests/integration_tests/model_runtime/siliconflow/__init__.py b/api/tests/integration_tests/model_runtime/siliconflow/__init__.py new file 
mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/tests/integration_tests/model_runtime/siliconflow/test_llm.py b/api/tests/integration_tests/model_runtime/siliconflow/test_llm.py new file mode 100644 index 0000000000000000000000000000000000000000..f47c9c558808afe0e9bfbd6feb3672ceed51906c --- /dev/null +++ b/api/tests/integration_tests/model_runtime/siliconflow/test_llm.py @@ -0,0 +1,73 @@ +import os +from collections.abc import Generator + +import pytest + +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import AssistantPromptMessage, SystemPromptMessage, UserPromptMessage +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.siliconflow.llm.llm import SiliconflowLargeLanguageModel + + +def test_validate_credentials(): + model = SiliconflowLargeLanguageModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials(model="deepseek-ai/DeepSeek-V2-Chat", credentials={"api_key": "invalid_key"}) + + model.validate_credentials(model="deepseek-ai/DeepSeek-V2-Chat", credentials={"api_key": os.environ.get("API_KEY")}) + + +def test_invoke_model(): + model = SiliconflowLargeLanguageModel() + + response = model.invoke( + model="deepseek-ai/DeepSeek-V2-Chat", + credentials={"api_key": os.environ.get("API_KEY")}, + prompt_messages=[UserPromptMessage(content="Who are you?")], + model_parameters={"temperature": 0.5, "max_tokens": 10}, + stop=["How"], + stream=False, + user="abc-123", + ) + + assert isinstance(response, LLMResult) + assert len(response.message.content) > 0 + + +def test_invoke_stream_model(): + model = SiliconflowLargeLanguageModel() + + response = model.invoke( + model="deepseek-ai/DeepSeek-V2-Chat", + credentials={"api_key": os.environ.get("API_KEY")}, + 
prompt_messages=[UserPromptMessage(content="Hello World!")], + model_parameters={"temperature": 0.5, "max_tokens": 100, "seed": 1234}, + stream=True, + user="abc-123", + ) + + assert isinstance(response, Generator) + + for chunk in response: + assert isinstance(chunk, LLMResultChunk) + assert isinstance(chunk.delta, LLMResultChunkDelta) + assert isinstance(chunk.delta.message, AssistantPromptMessage) + assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True + + +def test_get_num_tokens(): + model = SiliconflowLargeLanguageModel() + + num_tokens = model.get_num_tokens( + model="deepseek-ai/DeepSeek-V2-Chat", + credentials={"api_key": os.environ.get("API_KEY")}, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Hello World!"), + ], + ) + + assert num_tokens == 12 diff --git a/api/tests/integration_tests/model_runtime/siliconflow/test_provider.py b/api/tests/integration_tests/model_runtime/siliconflow/test_provider.py new file mode 100644 index 0000000000000000000000000000000000000000..8f70210b7a2acee7d2a0d931d7e143e1b9c7ee14 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/siliconflow/test_provider.py @@ -0,0 +1,15 @@ +import os + +import pytest + +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.siliconflow.siliconflow import SiliconflowProvider + + +def test_validate_provider_credentials(): + provider = SiliconflowProvider() + + with pytest.raises(CredentialsValidateFailedError): + provider.validate_provider_credentials(credentials={}) + + provider.validate_provider_credentials(credentials={"api_key": os.environ.get("API_KEY")}) diff --git a/api/tests/integration_tests/model_runtime/siliconflow/test_rerank.py b/api/tests/integration_tests/model_runtime/siliconflow/test_rerank.py new file mode 100644 index 
0000000000000000000000000000000000000000..ad794613f910139ad648f0577eece030ea8718e9 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/siliconflow/test_rerank.py @@ -0,0 +1,47 @@ +import os + +import pytest + +from core.model_runtime.entities.rerank_entities import RerankResult +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.siliconflow.rerank.rerank import SiliconflowRerankModel + + +def test_validate_credentials(): + model = SiliconflowRerankModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model="BAAI/bge-reranker-v2-m3", + credentials={"api_key": "invalid_key"}, + ) + + model.validate_credentials( + model="BAAI/bge-reranker-v2-m3", + credentials={ + "api_key": os.environ.get("API_KEY"), + }, + ) + + +def test_invoke_model(): + model = SiliconflowRerankModel() + + result = model.invoke( + model="BAAI/bge-reranker-v2-m3", + credentials={ + "api_key": os.environ.get("API_KEY"), + }, + query="Who is Kasumi?", + docs=[ + 'Kasumi is a girl\'s name of Japanese origin meaning "mist".', + "Her music is a kawaii bass, a mix of future bass, pop, and kawaii music ", + "and she leads a team named PopiParty.", + ], + score_threshold=0.8, + ) + + assert isinstance(result, RerankResult) + assert len(result.docs) == 1 + assert result.docs[0].index == 0 + assert result.docs[0].score >= 0.8 diff --git a/api/tests/integration_tests/model_runtime/siliconflow/test_speech2text.py b/api/tests/integration_tests/model_runtime/siliconflow/test_speech2text.py new file mode 100644 index 0000000000000000000000000000000000000000..0502ba5ab404bcf4e63996688b0fd11005c13758 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/siliconflow/test_speech2text.py @@ -0,0 +1,45 @@ +import os + +import pytest + +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.siliconflow.speech2text.speech2text 
import SiliconflowSpeech2TextModel + + +def test_validate_credentials(): + model = SiliconflowSpeech2TextModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model="iic/SenseVoiceSmall", + credentials={"api_key": "invalid_key"}, + ) + + model.validate_credentials( + model="iic/SenseVoiceSmall", + credentials={"api_key": os.environ.get("API_KEY")}, + ) + + +def test_invoke_model(): + model = SiliconflowSpeech2TextModel() + + # Get the directory of the current file + current_dir = os.path.dirname(os.path.abspath(__file__)) + + # Get assets directory + assets_dir = os.path.join(os.path.dirname(current_dir), "assets") + + # Construct the path to the audio file + audio_file_path = os.path.join(assets_dir, "audio.mp3") + + # Open the file and get the file object + with open(audio_file_path, "rb") as audio_file: + file = audio_file + + result = model.invoke( + model="iic/SenseVoiceSmall", credentials={"api_key": os.environ.get("API_KEY")}, file=file + ) + + assert isinstance(result, str) + assert result == "1,2,3,4,5,6,7,8,9,10." 
diff --git a/api/tests/integration_tests/model_runtime/siliconflow/test_text_embedding.py b/api/tests/integration_tests/model_runtime/siliconflow/test_text_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..ab143c10613a88f532acba00270ca3254c6a0717 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/siliconflow/test_text_embedding.py @@ -0,0 +1,60 @@ +import os + +import pytest + +from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.siliconflow.text_embedding.text_embedding import ( + SiliconflowTextEmbeddingModel, +) + + +def test_validate_credentials(): + model = SiliconflowTextEmbeddingModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model="BAAI/bge-large-zh-v1.5", + credentials={"api_key": "invalid_key"}, + ) + + model.validate_credentials( + model="BAAI/bge-large-zh-v1.5", + credentials={ + "api_key": os.environ.get("API_KEY"), + }, + ) + + +def test_invoke_model(): + model = SiliconflowTextEmbeddingModel() + + result = model.invoke( + model="BAAI/bge-large-zh-v1.5", + credentials={ + "api_key": os.environ.get("API_KEY"), + }, + texts=[ + "hello", + "world", + ], + user="abc-123", + ) + + assert isinstance(result, TextEmbeddingResult) + assert len(result.embeddings) == 2 + assert result.usage.total_tokens == 6 + + +def test_get_num_tokens(): + model = SiliconflowTextEmbeddingModel() + + num_tokens = model.get_num_tokens( + model="BAAI/bge-large-zh-v1.5", + credentials={ + "api_key": os.environ.get("API_KEY"), + }, + texts=["hello", "world"], + ) + + assert num_tokens == 2 diff --git a/api/tests/integration_tests/model_runtime/spark/__init__.py b/api/tests/integration_tests/model_runtime/spark/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git 
a/api/tests/integration_tests/model_runtime/spark/test_llm.py b/api/tests/integration_tests/model_runtime/spark/test_llm.py new file mode 100644 index 0000000000000000000000000000000000000000..4fe2fd8c0a3eac0562da66c12a5b5ba4836e3ea7 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/spark/test_llm.py @@ -0,0 +1,92 @@ +import os +from collections.abc import Generator + +import pytest + +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import AssistantPromptMessage, SystemPromptMessage, UserPromptMessage +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.spark.llm.llm import SparkLargeLanguageModel + + +def test_validate_credentials(): + model = SparkLargeLanguageModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials(model="spark-1.5", credentials={"app_id": "invalid_key"}) + + model.validate_credentials( + model="spark-1.5", + credentials={ + "app_id": os.environ.get("SPARK_APP_ID"), + "api_secret": os.environ.get("SPARK_API_SECRET"), + "api_key": os.environ.get("SPARK_API_KEY"), + }, + ) + + +def test_invoke_model(): + model = SparkLargeLanguageModel() + + response = model.invoke( + model="spark-1.5", + credentials={ + "app_id": os.environ.get("SPARK_APP_ID"), + "api_secret": os.environ.get("SPARK_API_SECRET"), + "api_key": os.environ.get("SPARK_API_KEY"), + }, + prompt_messages=[UserPromptMessage(content="Who are you?")], + model_parameters={"temperature": 0.5, "max_tokens": 10}, + stop=["How"], + stream=False, + user="abc-123", + ) + + assert isinstance(response, LLMResult) + assert len(response.message.content) > 0 + + +def test_invoke_stream_model(): + model = SparkLargeLanguageModel() + + response = model.invoke( + model="spark-1.5", + credentials={ + "app_id": os.environ.get("SPARK_APP_ID"), + "api_secret": os.environ.get("SPARK_API_SECRET"), 
+ "api_key": os.environ.get("SPARK_API_KEY"), + }, + prompt_messages=[UserPromptMessage(content="Hello World!")], + model_parameters={"temperature": 0.5, "max_tokens": 100}, + stream=True, + user="abc-123", + ) + + assert isinstance(response, Generator) + + for chunk in response: + assert isinstance(chunk, LLMResultChunk) + assert isinstance(chunk.delta, LLMResultChunkDelta) + assert isinstance(chunk.delta.message, AssistantPromptMessage) + assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True + + +def test_get_num_tokens(): + model = SparkLargeLanguageModel() + + num_tokens = model.get_num_tokens( + model="spark-1.5", + credentials={ + "app_id": os.environ.get("SPARK_APP_ID"), + "api_secret": os.environ.get("SPARK_API_SECRET"), + "api_key": os.environ.get("SPARK_API_KEY"), + }, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Hello World!"), + ], + ) + + assert num_tokens == 14 diff --git a/api/tests/integration_tests/model_runtime/spark/test_provider.py b/api/tests/integration_tests/model_runtime/spark/test_provider.py new file mode 100644 index 0000000000000000000000000000000000000000..9da0df6bb3d556dadeb19ab794212d85c0676468 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/spark/test_provider.py @@ -0,0 +1,21 @@ +import os + +import pytest + +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.spark.spark import SparkProvider + + +def test_validate_provider_credentials(): + provider = SparkProvider() + + with pytest.raises(CredentialsValidateFailedError): + provider.validate_provider_credentials(credentials={}) + + provider.validate_provider_credentials( + credentials={ + "app_id": os.environ.get("SPARK_APP_ID"), + "api_secret": os.environ.get("SPARK_API_SECRET"), + "api_key": os.environ.get("SPARK_API_KEY"), + } + ) diff --git 
a/api/tests/integration_tests/model_runtime/stepfun/__init__.py b/api/tests/integration_tests/model_runtime/stepfun/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/tests/integration_tests/model_runtime/stepfun/test_llm.py b/api/tests/integration_tests/model_runtime/stepfun/test_llm.py new file mode 100644 index 0000000000000000000000000000000000000000..f9afca6f5945b589a77a19e852f5da4878a1d682 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/stepfun/test_llm.py @@ -0,0 +1,123 @@ +import os +from collections.abc import Generator + +import pytest + +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessageTool, + SystemPromptMessage, + UserPromptMessage, +) +from core.model_runtime.entities.model_entities import AIModelEntity +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.stepfun.llm.llm import StepfunLargeLanguageModel + + +def test_validate_credentials(): + model = StepfunLargeLanguageModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials(model="step-1-8k", credentials={"api_key": "invalid_key"}) + + model.validate_credentials(model="step-1-8k", credentials={"api_key": os.environ.get("STEPFUN_API_KEY")}) + + +def test_invoke_model(): + model = StepfunLargeLanguageModel() + + response = model.invoke( + model="step-1-8k", + credentials={"api_key": os.environ.get("STEPFUN_API_KEY")}, + prompt_messages=[UserPromptMessage(content="Hello World!")], + model_parameters={"temperature": 0.9, "top_p": 0.7}, + stop=["Hi"], + stream=False, + user="abc-123", + ) + + assert isinstance(response, LLMResult) + assert len(response.message.content) > 0 + + +def test_invoke_stream_model(): + model = StepfunLargeLanguageModel() + + 
response = model.invoke( + model="step-1-8k", + credentials={"api_key": os.environ.get("STEPFUN_API_KEY")}, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Hello World!"), + ], + model_parameters={"temperature": 0.9, "top_p": 0.7}, + stream=True, + user="abc-123", + ) + + assert isinstance(response, Generator) + + for chunk in response: + assert isinstance(chunk, LLMResultChunk) + assert isinstance(chunk.delta, LLMResultChunkDelta) + assert isinstance(chunk.delta.message, AssistantPromptMessage) + assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True + + +def test_get_customizable_model_schema(): + model = StepfunLargeLanguageModel() + + schema = model.get_customizable_model_schema( + model="step-1-8k", credentials={"api_key": os.environ.get("STEPFUN_API_KEY")} + ) + assert isinstance(schema, AIModelEntity) + + +def test_invoke_chat_model_with_tools(): + model = StepfunLargeLanguageModel() + + result = model.invoke( + model="step-1-8k", + credentials={"api_key": os.environ.get("STEPFUN_API_KEY")}, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage( + content="what's the weather today in Shanghai?", + ), + ], + model_parameters={"temperature": 0.9, "max_tokens": 100}, + tools=[ + PromptMessageTool( + name="get_weather", + description="Determine weather in my location", + parameters={ + "type": "object", + "properties": { + "location": {"type": "string", "description": "The city and state e.g. 
San Francisco, CA"}, + "unit": {"type": "string", "enum": ["c", "f"]}, + }, + "required": ["location"], + }, + ), + PromptMessageTool( + name="get_stock_price", + description="Get the current stock price", + parameters={ + "type": "object", + "properties": {"symbol": {"type": "string", "description": "The stock symbol"}}, + "required": ["symbol"], + }, + ), + ], + stream=False, + user="abc-123", + ) + + assert isinstance(result, LLMResult) + assert isinstance(result.message, AssistantPromptMessage) + assert len(result.message.tool_calls) > 0 diff --git a/api/tests/integration_tests/model_runtime/test_model_provider_factory.py b/api/tests/integration_tests/model_runtime/test_model_provider_factory.py new file mode 100644 index 0000000000000000000000000000000000000000..0ec4b0b7243176092eafe16891cbbf1a64c9384b --- /dev/null +++ b/api/tests/integration_tests/model_runtime/test_model_provider_factory.py @@ -0,0 +1,69 @@ +import logging +import os + +from core.model_runtime.entities.model_entities import ModelType +from core.model_runtime.entities.provider_entities import ProviderConfig, ProviderEntity, SimpleProviderEntity +from core.model_runtime.model_providers.model_provider_factory import ModelProviderExtension, ModelProviderFactory + +logger = logging.getLogger(__name__) + + +def test_get_providers(): + factory = ModelProviderFactory() + providers = factory.get_providers() + + for provider in providers: + logger.debug(provider) + + assert len(providers) >= 1 + assert isinstance(providers[0], ProviderEntity) + + +def test_get_models(): + factory = ModelProviderFactory() + providers = factory.get_models( + model_type=ModelType.LLM, + provider_configs=[ + ProviderConfig(provider="openai", credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}) + ], + ) + + logger.debug(providers) + + assert len(providers) >= 1 + assert isinstance(providers[0], SimpleProviderEntity) + + # all provider models type equals to ModelType.LLM + for provider in providers: + for 
provider_model in provider.models: + assert provider_model.model_type == ModelType.LLM + + providers = factory.get_models( + provider="openai", + provider_configs=[ + ProviderConfig(provider="openai", credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}) + ], + ) + + assert len(providers) == 1 + assert isinstance(providers[0], SimpleProviderEntity) + assert providers[0].provider == "openai" + + +def test_provider_credentials_validate(): + factory = ModelProviderFactory() + factory.provider_credentials_validate( + provider="openai", credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")} + ) + + +def test__get_model_provider_map(): + factory = ModelProviderFactory() + model_providers = factory._get_model_provider_map() + + for name, model_provider in model_providers.items(): + logger.debug(name) + logger.debug(model_provider.provider_instance) + + assert len(model_providers) >= 1 + assert isinstance(model_providers["openai"], ModelProviderExtension) diff --git a/api/tests/integration_tests/model_runtime/togetherai/__init__.py b/api/tests/integration_tests/model_runtime/togetherai/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/tests/integration_tests/model_runtime/togetherai/test_llm.py b/api/tests/integration_tests/model_runtime/togetherai/test_llm.py new file mode 100644 index 0000000000000000000000000000000000000000..5787e1bf6a8d99ed3a5a3f42313de133805e0730 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/togetherai/test_llm.py @@ -0,0 +1,103 @@ +import os +from collections.abc import Generator + +import pytest + +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + SystemPromptMessage, + UserPromptMessage, +) +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from 
core.model_runtime.model_providers.togetherai.llm.llm import TogetherAILargeLanguageModel + + +def test_validate_credentials(): + model = TogetherAILargeLanguageModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model="mistralai/Mixtral-8x7B-Instruct-v0.1", credentials={"api_key": "invalid_key", "mode": "chat"} + ) + + model.validate_credentials( + model="mistralai/Mixtral-8x7B-Instruct-v0.1", + credentials={"api_key": os.environ.get("TOGETHER_API_KEY"), "mode": "chat"}, + ) + + +def test_invoke_model(): + model = TogetherAILargeLanguageModel() + + response = model.invoke( + model="mistralai/Mixtral-8x7B-Instruct-v0.1", + credentials={"api_key": os.environ.get("TOGETHER_API_KEY"), "mode": "completion"}, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Who are you?"), + ], + model_parameters={ + "temperature": 1.0, + "top_k": 2, + "top_p": 0.5, + }, + stop=["How"], + stream=False, + user="abc-123", + ) + + assert isinstance(response, LLMResult) + assert len(response.message.content) > 0 + + +def test_invoke_stream_model(): + model = TogetherAILargeLanguageModel() + + response = model.invoke( + model="mistralai/Mixtral-8x7B-Instruct-v0.1", + credentials={"api_key": os.environ.get("TOGETHER_API_KEY"), "mode": "chat"}, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Who are you?"), + ], + model_parameters={ + "temperature": 1.0, + "top_k": 2, + "top_p": 0.5, + }, + stop=["How"], + stream=True, + user="abc-123", + ) + + assert isinstance(response, Generator) + + for chunk in response: + assert isinstance(chunk, LLMResultChunk) + assert isinstance(chunk.delta, LLMResultChunkDelta) + assert isinstance(chunk.delta.message, AssistantPromptMessage) + + +def test_get_num_tokens(): + model = TogetherAILargeLanguageModel() + + num_tokens = model.get_num_tokens( + 
model="mistralai/Mixtral-8x7B-Instruct-v0.1", + credentials={ + "api_key": os.environ.get("TOGETHER_API_KEY"), + }, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Hello World!"), + ], + ) + + assert isinstance(num_tokens, int) + assert num_tokens == 21 diff --git a/api/tests/integration_tests/model_runtime/tongyi/__init__.py b/api/tests/integration_tests/model_runtime/tongyi/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/tests/integration_tests/model_runtime/tongyi/test_llm.py b/api/tests/integration_tests/model_runtime/tongyi/test_llm.py new file mode 100644 index 0000000000000000000000000000000000000000..61650735f2ad3fc91c23bbcf1274abcf0924411e --- /dev/null +++ b/api/tests/integration_tests/model_runtime/tongyi/test_llm.py @@ -0,0 +1,75 @@ +import os +from collections.abc import Generator + +import pytest + +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import AssistantPromptMessage, SystemPromptMessage, UserPromptMessage +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.tongyi.llm.llm import TongyiLargeLanguageModel + + +def test_validate_credentials(): + model = TongyiLargeLanguageModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials(model="qwen-turbo", credentials={"dashscope_api_key": "invalid_key"}) + + model.validate_credentials( + model="qwen-turbo", credentials={"dashscope_api_key": os.environ.get("TONGYI_DASHSCOPE_API_KEY")} + ) + + +def test_invoke_model(): + model = TongyiLargeLanguageModel() + + response = model.invoke( + model="qwen-turbo", + credentials={"dashscope_api_key": os.environ.get("TONGYI_DASHSCOPE_API_KEY")}, + prompt_messages=[UserPromptMessage(content="Who are 
you?")], + model_parameters={"temperature": 0.5, "max_tokens": 10}, + stop=["How"], + stream=False, + user="abc-123", + ) + + assert isinstance(response, LLMResult) + assert len(response.message.content) > 0 + + +def test_invoke_stream_model(): + model = TongyiLargeLanguageModel() + + response = model.invoke( + model="qwen-turbo", + credentials={"dashscope_api_key": os.environ.get("TONGYI_DASHSCOPE_API_KEY")}, + prompt_messages=[UserPromptMessage(content="Hello World!")], + model_parameters={"temperature": 0.5, "max_tokens": 100, "seed": 1234}, + stream=True, + user="abc-123", + ) + + assert isinstance(response, Generator) + + for chunk in response: + assert isinstance(chunk, LLMResultChunk) + assert isinstance(chunk.delta, LLMResultChunkDelta) + assert isinstance(chunk.delta.message, AssistantPromptMessage) + assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True + + +def test_get_num_tokens(): + model = TongyiLargeLanguageModel() + + num_tokens = model.get_num_tokens( + model="qwen-turbo", + credentials={"dashscope_api_key": os.environ.get("TONGYI_DASHSCOPE_API_KEY")}, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Hello World!"), + ], + ) + + assert num_tokens == 12 diff --git a/api/tests/integration_tests/model_runtime/tongyi/test_provider.py b/api/tests/integration_tests/model_runtime/tongyi/test_provider.py new file mode 100644 index 0000000000000000000000000000000000000000..0bc96c84e73195becb13e6baeef09bdccc28fdbf --- /dev/null +++ b/api/tests/integration_tests/model_runtime/tongyi/test_provider.py @@ -0,0 +1,17 @@ +import os + +import pytest + +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.tongyi.tongyi import TongyiProvider + + +def test_validate_provider_credentials(): + provider = TongyiProvider() + + with pytest.raises(CredentialsValidateFailedError): + 
provider.validate_provider_credentials(credentials={}) + + provider.validate_provider_credentials( + credentials={"dashscope_api_key": os.environ.get("TONGYI_DASHSCOPE_API_KEY")} + ) diff --git a/api/tests/integration_tests/model_runtime/tongyi/test_rerank.py b/api/tests/integration_tests/model_runtime/tongyi/test_rerank.py new file mode 100644 index 0000000000000000000000000000000000000000..d37fcf897fc3a8ee82d464ec117c11ec51aa883d --- /dev/null +++ b/api/tests/integration_tests/model_runtime/tongyi/test_rerank.py @@ -0,0 +1,40 @@ +import os + +import dashscope # type: ignore +import pytest + +from core.model_runtime.entities.rerank_entities import RerankResult +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.tongyi.rerank.rerank import GTERerankModel + + +def test_validate_credentials(): + model = GTERerankModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials(model="get-rank", credentials={"dashscope_api_key": "invalid_key"}) + + model.validate_credentials( + model="get-rank", credentials={"dashscope_api_key": os.environ.get("TONGYI_DASHSCOPE_API_KEY")} + ) + + +def test_invoke_model(): + model = GTERerankModel() + + result = model.invoke( + model=dashscope.TextReRank.Models.gte_rerank, + credentials={"dashscope_api_key": os.environ.get("TONGYI_DASHSCOPE_API_KEY")}, + query="什么是文本排序模型", + docs=[ + "文本排序模型广泛用于搜索引擎和推荐系统中,它们根据文本相关性对候选文本进行排序", + "量子计算是计算科学的一个前沿领域", + "预训练语言模型的发展给文本排序模型带来了新的进展", + ], + score_threshold=0.7, + ) + + assert isinstance(result, RerankResult) + assert len(result.docs) == 1 + assert result.docs[0].index == 0 + assert result.docs[0].score >= 0.7 diff --git a/api/tests/integration_tests/model_runtime/tongyi/test_response_format.py b/api/tests/integration_tests/model_runtime/tongyi/test_response_format.py new file mode 100644 index 0000000000000000000000000000000000000000..905e7907fde5a8f1ae3746a2d68d634b8ae6c901 --- /dev/null +++ 
b/api/tests/integration_tests/model_runtime/tongyi/test_response_format.py @@ -0,0 +1,80 @@ +import json +import os +from collections.abc import Generator + +from core.model_runtime.entities.llm_entities import LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import AssistantPromptMessage, UserPromptMessage +from core.model_runtime.model_providers.tongyi.llm.llm import TongyiLargeLanguageModel + + +def test_invoke_model_with_json_response(): + """ + Test the invocation of a model with JSON response. + """ + model_list = [ + "qwen-max-0403", + "qwen-max-1201", + "qwen-max-longcontext", + "qwen-max", + "qwen-plus-chat", + "qwen-plus", + "qwen-turbo-chat", + "qwen-turbo", + ] + for model_name in model_list: + print("testing model: ", model_name) + invoke_model_with_json_response(model_name) + + +def invoke_model_with_json_response(model_name="qwen-max-0403"): + """ + Method to invoke the model with JSON response format. + Args: + model_name (str): The name of the model to invoke. Defaults to "qwen-max-0403". + + Returns: + None + """ + model = TongyiLargeLanguageModel() + + response = model.invoke( + model=model_name, + credentials={"dashscope_api_key": os.environ.get("TONGYI_DASHSCOPE_API_KEY")}, + prompt_messages=[ + UserPromptMessage(content='output json data with format `{"data": "test", "code": 200, "msg": "success"}') + ], + model_parameters={ + "temperature": 0.5, + "max_tokens": 50, + "response_format": "JSON", + }, + stream=True, + user="abc-123", + ) + print("=====================================") + print(response) + assert isinstance(response, Generator) + output = "" + for chunk in response: + assert isinstance(chunk, LLMResultChunk) + assert isinstance(chunk.delta, LLMResultChunkDelta) + assert isinstance(chunk.delta.message, AssistantPromptMessage) + output += chunk.delta.message.content + assert is_json(output) + + +def is_json(s): + """ + Check if a string is a valid JSON. + + Args: + s (str): The string to check. 
+ + Returns: + bool: True if the string is a valid JSON, False otherwise. + """ + try: + json.loads(s) + except ValueError: + return False + return True diff --git a/api/tests/integration_tests/model_runtime/upstage/__init__.py b/api/tests/integration_tests/model_runtime/upstage/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/tests/integration_tests/model_runtime/upstage/test_llm.py b/api/tests/integration_tests/model_runtime/upstage/test_llm.py new file mode 100644 index 0000000000000000000000000000000000000000..0f39e902f338035b2939c33bf8c9e63c5732f2ee --- /dev/null +++ b/api/tests/integration_tests/model_runtime/upstage/test_llm.py @@ -0,0 +1,185 @@ +import os +from collections.abc import Generator + +import pytest + +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessageTool, + SystemPromptMessage, + UserPromptMessage, +) +from core.model_runtime.entities.model_entities import AIModelEntity +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.upstage.llm.llm import UpstageLargeLanguageModel + +"""FOR MOCK FIXTURES, DO NOT REMOVE""" +from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock + + +def test_predefined_models(): + model = UpstageLargeLanguageModel() + model_schemas = model.predefined_models() + + assert len(model_schemas) >= 1 + assert isinstance(model_schemas[0], AIModelEntity) + + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) +def test_validate_credentials_for_chat_model(setup_openai_mock): + model = UpstageLargeLanguageModel() + + with pytest.raises(CredentialsValidateFailedError): + # model name to gpt-3.5-turbo because of mocking + model.validate_credentials(model="gpt-3.5-turbo", 
credentials={"upstage_api_key": "invalid_key"}) + + model.validate_credentials( + model="solar-1-mini-chat", credentials={"upstage_api_key": os.environ.get("UPSTAGE_API_KEY")} + ) + + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) +def test_invoke_chat_model(setup_openai_mock): + model = UpstageLargeLanguageModel() + + result = model.invoke( + model="solar-1-mini-chat", + credentials={"upstage_api_key": os.environ.get("UPSTAGE_API_KEY")}, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Hello World!"), + ], + model_parameters={ + "temperature": 0.0, + "top_p": 1.0, + "presence_penalty": 0.0, + "frequency_penalty": 0.0, + "max_tokens": 10, + }, + stop=["How"], + stream=False, + user="abc-123", + ) + + assert isinstance(result, LLMResult) + assert len(result.message.content) > 0 + + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) +def test_invoke_chat_model_with_tools(setup_openai_mock): + model = UpstageLargeLanguageModel() + + result = model.invoke( + model="solar-1-mini-chat", + credentials={"upstage_api_key": os.environ.get("UPSTAGE_API_KEY")}, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage( + content="what's the weather today in London?", + ), + ], + model_parameters={"temperature": 0.0, "max_tokens": 100}, + tools=[ + PromptMessageTool( + name="get_weather", + description="Determine weather in my location", + parameters={ + "type": "object", + "properties": { + "location": {"type": "string", "description": "The city and state e.g. 
San Francisco, CA"}, + "unit": {"type": "string", "enum": ["c", "f"]}, + }, + "required": ["location"], + }, + ), + PromptMessageTool( + name="get_stock_price", + description="Get the current stock price", + parameters={ + "type": "object", + "properties": {"symbol": {"type": "string", "description": "The stock symbol"}}, + "required": ["symbol"], + }, + ), + ], + stream=False, + user="abc-123", + ) + + assert isinstance(result, LLMResult) + assert isinstance(result.message, AssistantPromptMessage) + assert len(result.message.tool_calls) > 0 + + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) +def test_invoke_stream_chat_model(setup_openai_mock): + model = UpstageLargeLanguageModel() + + result = model.invoke( + model="solar-1-mini-chat", + credentials={"upstage_api_key": os.environ.get("UPSTAGE_API_KEY")}, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Hello World!"), + ], + model_parameters={"temperature": 0.0, "max_tokens": 100}, + stream=True, + user="abc-123", + ) + + assert isinstance(result, Generator) + + for chunk in result: + assert isinstance(chunk, LLMResultChunk) + assert isinstance(chunk.delta, LLMResultChunkDelta) + assert isinstance(chunk.delta.message, AssistantPromptMessage) + assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True + if chunk.delta.finish_reason is not None: + assert chunk.delta.usage is not None + assert chunk.delta.usage.completion_tokens > 0 + + +def test_get_num_tokens(): + model = UpstageLargeLanguageModel() + + num_tokens = model.get_num_tokens( + model="solar-1-mini-chat", + credentials={"upstage_api_key": os.environ.get("UPSTAGE_API_KEY")}, + prompt_messages=[UserPromptMessage(content="Hello World!")], + ) + + assert num_tokens == 13 + + num_tokens = model.get_num_tokens( + model="solar-1-mini-chat", + credentials={"upstage_api_key": os.environ.get("UPSTAGE_API_KEY")}, + prompt_messages=[ 
+ SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Hello World!"), + ], + tools=[ + PromptMessageTool( + name="get_weather", + description="Determine weather in my location", + parameters={ + "type": "object", + "properties": { + "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"}, + "unit": {"type": "string", "enum": ["c", "f"]}, + }, + "required": ["location"], + }, + ), + ], + ) + + assert num_tokens == 106 diff --git a/api/tests/integration_tests/model_runtime/upstage/test_provider.py b/api/tests/integration_tests/model_runtime/upstage/test_provider.py new file mode 100644 index 0000000000000000000000000000000000000000..9d83779aa00a495bf4d66593bf1fb73b47d7bed7 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/upstage/test_provider.py @@ -0,0 +1,17 @@ +import os + +import pytest + +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.upstage.upstage import UpstageProvider +from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock + + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) +def test_validate_provider_credentials(setup_openai_mock): + provider = UpstageProvider() + + with pytest.raises(CredentialsValidateFailedError): + provider.validate_provider_credentials(credentials={}) + + provider.validate_provider_credentials(credentials={"upstage_api_key": os.environ.get("UPSTAGE_API_KEY")}) diff --git a/api/tests/integration_tests/model_runtime/upstage/test_text_embedding.py b/api/tests/integration_tests/model_runtime/upstage/test_text_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..8c83172fa3ff7ef9aae499622791f94f97e06d7a --- /dev/null +++ b/api/tests/integration_tests/model_runtime/upstage/test_text_embedding.py @@ -0,0 +1,54 @@ +import os + +import pytest + +from 
core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.upstage.text_embedding.text_embedding import UpstageTextEmbeddingModel +from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock + + +@pytest.mark.parametrize("setup_openai_mock", [["text_embedding"]], indirect=True) +def test_validate_credentials(setup_openai_mock): + model = UpstageTextEmbeddingModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model="solar-embedding-1-large-passage", credentials={"upstage_api_key": "invalid_key"} + ) + + model.validate_credentials( + model="solar-embedding-1-large-passage", credentials={"upstage_api_key": os.environ.get("UPSTAGE_API_KEY")} + ) + + +@pytest.mark.parametrize("setup_openai_mock", [["text_embedding"]], indirect=True) +def test_invoke_model(setup_openai_mock): + model = UpstageTextEmbeddingModel() + + result = model.invoke( + model="solar-embedding-1-large-passage", + credentials={ + "upstage_api_key": os.environ.get("UPSTAGE_API_KEY"), + }, + texts=["hello", "world", " ".join(["long_text"] * 100), " ".join(["another_long_text"] * 100)], + user="abc-123", + ) + + assert isinstance(result, TextEmbeddingResult) + assert len(result.embeddings) == 4 + assert result.usage.total_tokens == 2 + + +def test_get_num_tokens(): + model = UpstageTextEmbeddingModel() + + num_tokens = model.get_num_tokens( + model="solar-embedding-1-large-passage", + credentials={ + "upstage_api_key": os.environ.get("UPSTAGE_API_KEY"), + }, + texts=["hello", "world"], + ) + + assert num_tokens == 5 diff --git a/api/tests/integration_tests/model_runtime/vessl_ai/__init__.py b/api/tests/integration_tests/model_runtime/vessl_ai/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git 
a/api/tests/integration_tests/model_runtime/vessl_ai/test_llm.py b/api/tests/integration_tests/model_runtime/vessl_ai/test_llm.py new file mode 100644 index 0000000000000000000000000000000000000000..7797d0f8e46a87c894dbad0ee10aa0ce40c61889 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/vessl_ai/test_llm.py @@ -0,0 +1,131 @@ +import os +from collections.abc import Generator + +import pytest + +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + SystemPromptMessage, + UserPromptMessage, +) +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.vessl_ai.llm.llm import VesslAILargeLanguageModel + + +def test_validate_credentials(): + model = VesslAILargeLanguageModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model=os.environ.get("VESSL_AI_MODEL_NAME"), + credentials={ + "api_key": "invalid_key", + "endpoint_url": os.environ.get("VESSL_AI_ENDPOINT_URL"), + "mode": "chat", + }, + ) + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model=os.environ.get("VESSL_AI_MODEL_NAME"), + credentials={ + "api_key": os.environ.get("VESSL_AI_API_KEY"), + "endpoint_url": "http://invalid_url", + "mode": "chat", + }, + ) + + model.validate_credentials( + model=os.environ.get("VESSL_AI_MODEL_NAME"), + credentials={ + "api_key": os.environ.get("VESSL_AI_API_KEY"), + "endpoint_url": os.environ.get("VESSL_AI_ENDPOINT_URL"), + "mode": "chat", + }, + ) + + +def test_invoke_model(): + model = VesslAILargeLanguageModel() + + response = model.invoke( + model=os.environ.get("VESSL_AI_MODEL_NAME"), + credentials={ + "api_key": os.environ.get("VESSL_AI_API_KEY"), + "endpoint_url": os.environ.get("VESSL_AI_ENDPOINT_URL"), + "mode": "chat", + }, + prompt_messages=[ + SystemPromptMessage( + content="You 
are a helpful AI assistant.", + ), + UserPromptMessage(content="Who are you?"), + ], + model_parameters={ + "temperature": 1.0, + "top_k": 2, + "top_p": 0.5, + }, + stop=["How"], + stream=False, + user="abc-123", + ) + + assert isinstance(response, LLMResult) + assert len(response.message.content) > 0 + + +def test_invoke_stream_model(): + model = VesslAILargeLanguageModel() + + response = model.invoke( + model=os.environ.get("VESSL_AI_MODEL_NAME"), + credentials={ + "api_key": os.environ.get("VESSL_AI_API_KEY"), + "endpoint_url": os.environ.get("VESSL_AI_ENDPOINT_URL"), + "mode": "chat", + }, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Who are you?"), + ], + model_parameters={ + "temperature": 1.0, + "top_k": 2, + "top_p": 0.5, + }, + stop=["How"], + stream=True, + user="abc-123", + ) + + assert isinstance(response, Generator) + + for chunk in response: + assert isinstance(chunk, LLMResultChunk) + assert isinstance(chunk.delta, LLMResultChunkDelta) + assert isinstance(chunk.delta.message, AssistantPromptMessage) + + +def test_get_num_tokens(): + model = VesslAILargeLanguageModel() + + num_tokens = model.get_num_tokens( + model=os.environ.get("VESSL_AI_MODEL_NAME"), + credentials={ + "api_key": os.environ.get("VESSL_AI_API_KEY"), + "endpoint_url": os.environ.get("VESSL_AI_ENDPOINT_URL"), + }, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Hello World!"), + ], + ) + + assert isinstance(num_tokens, int) + assert num_tokens == 21 diff --git a/api/tests/integration_tests/model_runtime/volcengine_maas/__init__.py b/api/tests/integration_tests/model_runtime/volcengine_maas/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/tests/integration_tests/model_runtime/volcengine_maas/test_embedding.py 
b/api/tests/integration_tests/model_runtime/volcengine_maas/test_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..f831c063a4263007fff444e90eea710e66862d5e --- /dev/null +++ b/api/tests/integration_tests/model_runtime/volcengine_maas/test_embedding.py @@ -0,0 +1,79 @@ +import os + +import pytest + +from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.volcengine_maas.text_embedding.text_embedding import ( + VolcengineMaaSTextEmbeddingModel, +) + + +def test_validate_credentials(): + model = VolcengineMaaSTextEmbeddingModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model="NOT IMPORTANT", + credentials={ + "api_endpoint_host": "maas-api.ml-platform-cn-beijing.volces.com", + "volc_region": "cn-beijing", + "volc_access_key_id": "INVALID", + "volc_secret_access_key": "INVALID", + "endpoint_id": "INVALID", + "base_model_name": "Doubao-embedding", + }, + ) + + model.validate_credentials( + model="NOT IMPORTANT", + credentials={ + "api_endpoint_host": "maas-api.ml-platform-cn-beijing.volces.com", + "volc_region": "cn-beijing", + "volc_access_key_id": os.environ.get("VOLC_API_KEY"), + "volc_secret_access_key": os.environ.get("VOLC_SECRET_KEY"), + "endpoint_id": os.environ.get("VOLC_EMBEDDING_ENDPOINT_ID"), + "base_model_name": "Doubao-embedding", + }, + ) + + +def test_invoke_model(): + model = VolcengineMaaSTextEmbeddingModel() + + result = model.invoke( + model="NOT IMPORTANT", + credentials={ + "api_endpoint_host": "maas-api.ml-platform-cn-beijing.volces.com", + "volc_region": "cn-beijing", + "volc_access_key_id": os.environ.get("VOLC_API_KEY"), + "volc_secret_access_key": os.environ.get("VOLC_SECRET_KEY"), + "endpoint_id": os.environ.get("VOLC_EMBEDDING_ENDPOINT_ID"), + "base_model_name": "Doubao-embedding", + }, + texts=["hello", "world"], + 
user="abc-123", + ) + + assert isinstance(result, TextEmbeddingResult) + assert len(result.embeddings) == 2 + assert result.usage.total_tokens > 0 + + +def test_get_num_tokens(): + model = VolcengineMaaSTextEmbeddingModel() + + num_tokens = model.get_num_tokens( + model="NOT IMPORTANT", + credentials={ + "api_endpoint_host": "maas-api.ml-platform-cn-beijing.volces.com", + "volc_region": "cn-beijing", + "volc_access_key_id": os.environ.get("VOLC_API_KEY"), + "volc_secret_access_key": os.environ.get("VOLC_SECRET_KEY"), + "endpoint_id": os.environ.get("VOLC_EMBEDDING_ENDPOINT_ID"), + "base_model_name": "Doubao-embedding", + }, + texts=["hello", "world"], + ) + + assert num_tokens == 2 diff --git a/api/tests/integration_tests/model_runtime/volcengine_maas/test_llm.py b/api/tests/integration_tests/model_runtime/volcengine_maas/test_llm.py new file mode 100644 index 0000000000000000000000000000000000000000..8ff9c414046e7da158d410c4be069e3cac3cbb27 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/volcengine_maas/test_llm.py @@ -0,0 +1,118 @@ +import os +from collections.abc import Generator + +import pytest + +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import AssistantPromptMessage, UserPromptMessage +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.volcengine_maas.llm.llm import VolcengineMaaSLargeLanguageModel + + +def test_validate_credentials_for_chat_model(): + model = VolcengineMaaSLargeLanguageModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model="NOT IMPORTANT", + credentials={ + "api_endpoint_host": "maas-api.ml-platform-cn-beijing.volces.com", + "volc_region": "cn-beijing", + "volc_access_key_id": "INVALID", + "volc_secret_access_key": "INVALID", + "endpoint_id": "INVALID", + }, + ) + + model.validate_credentials( + model="NOT 
IMPORTANT", + credentials={ + "api_endpoint_host": "maas-api.ml-platform-cn-beijing.volces.com", + "volc_region": "cn-beijing", + "volc_access_key_id": os.environ.get("VOLC_API_KEY"), + "volc_secret_access_key": os.environ.get("VOLC_SECRET_KEY"), + "endpoint_id": os.environ.get("VOLC_MODEL_ENDPOINT_ID"), + }, + ) + + +def test_invoke_model(): + model = VolcengineMaaSLargeLanguageModel() + + response = model.invoke( + model="NOT IMPORTANT", + credentials={ + "api_endpoint_host": "maas-api.ml-platform-cn-beijing.volces.com", + "volc_region": "cn-beijing", + "volc_access_key_id": os.environ.get("VOLC_API_KEY"), + "volc_secret_access_key": os.environ.get("VOLC_SECRET_KEY"), + "endpoint_id": os.environ.get("VOLC_MODEL_ENDPOINT_ID"), + "base_model_name": "Skylark2-pro-4k", + }, + prompt_messages=[UserPromptMessage(content="Hello World!")], + model_parameters={ + "temperature": 0.7, + "top_p": 1.0, + "top_k": 1, + }, + stop=["you"], + user="abc-123", + stream=False, + ) + + assert isinstance(response, LLMResult) + assert len(response.message.content) > 0 + assert response.usage.total_tokens > 0 + + +def test_invoke_stream_model(): + model = VolcengineMaaSLargeLanguageModel() + + response = model.invoke( + model="NOT IMPORTANT", + credentials={ + "api_endpoint_host": "maas-api.ml-platform-cn-beijing.volces.com", + "volc_region": "cn-beijing", + "volc_access_key_id": os.environ.get("VOLC_API_KEY"), + "volc_secret_access_key": os.environ.get("VOLC_SECRET_KEY"), + "endpoint_id": os.environ.get("VOLC_MODEL_ENDPOINT_ID"), + "base_model_name": "Skylark2-pro-4k", + }, + prompt_messages=[UserPromptMessage(content="Hello World!")], + model_parameters={ + "temperature": 0.7, + "top_p": 1.0, + "top_k": 1, + }, + stop=["you"], + stream=True, + user="abc-123", + ) + + assert isinstance(response, Generator) + for chunk in response: + assert isinstance(chunk, LLMResultChunk) + assert isinstance(chunk.delta, LLMResultChunkDelta) + assert isinstance(chunk.delta.message, 
AssistantPromptMessage) + assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True + + +def test_get_num_tokens(): + model = VolcengineMaaSLargeLanguageModel() + + response = model.get_num_tokens( + model="NOT IMPORTANT", + credentials={ + "api_endpoint_host": "maas-api.ml-platform-cn-beijing.volces.com", + "volc_region": "cn-beijing", + "volc_access_key_id": os.environ.get("VOLC_API_KEY"), + "volc_secret_access_key": os.environ.get("VOLC_SECRET_KEY"), + "endpoint_id": os.environ.get("VOLC_MODEL_ENDPOINT_ID"), + "base_model_name": "Skylark2-pro-4k", + }, + prompt_messages=[UserPromptMessage(content="Hello World!")], + tools=[], + ) + + assert isinstance(response, int) + assert response == 6 diff --git a/api/tests/integration_tests/model_runtime/voyage/__init__.py b/api/tests/integration_tests/model_runtime/voyage/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/tests/integration_tests/model_runtime/voyage/test_provider.py b/api/tests/integration_tests/model_runtime/voyage/test_provider.py new file mode 100644 index 0000000000000000000000000000000000000000..08978c88a961e748d440f20a35d5ca30332b4375 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/voyage/test_provider.py @@ -0,0 +1,25 @@ +import os +from unittest.mock import Mock, patch + +import pytest + +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.voyage.voyage import VoyageProvider + + +def test_validate_provider_credentials(): + provider = VoyageProvider() + + with pytest.raises(CredentialsValidateFailedError): + provider.validate_provider_credentials(credentials={"api_key": "hahahaha"}) + with patch("requests.post") as mock_post: + mock_response = Mock() + mock_response.json.return_value = { + "object": "list", + "data": [{"object": "embedding", "embedding": [0.23333 for _ in range(1024)], "index": 0}], + 
"model": "voyage-3", + "usage": {"total_tokens": 1}, + } + mock_response.status_code = 200 + mock_post.return_value = mock_response + provider.validate_provider_credentials(credentials={"api_key": os.environ.get("VOYAGE_API_KEY")}) diff --git a/api/tests/integration_tests/model_runtime/voyage/test_rerank.py b/api/tests/integration_tests/model_runtime/voyage/test_rerank.py new file mode 100644 index 0000000000000000000000000000000000000000..e97a9e4c811c827d74f1c3d63fa6e959813adb32 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/voyage/test_rerank.py @@ -0,0 +1,92 @@ +import os +from unittest.mock import Mock, patch + +import pytest + +from core.model_runtime.entities.rerank_entities import RerankResult +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.voyage.rerank.rerank import VoyageRerankModel + + +def test_validate_credentials(): + model = VoyageRerankModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model="rerank-lite-1", + credentials={"api_key": "invalid_key"}, + ) + with patch("httpx.post") as mock_post: + mock_response = Mock() + mock_response.json.return_value = { + "object": "list", + "data": [ + { + "relevance_score": 0.546875, + "index": 0, + "document": "Carson City is the capital city of the American state of Nevada. At the 2010 United " + "States Census, Carson City had a population of 55,274.", + }, + { + "relevance_score": 0.4765625, + "index": 1, + "document": "The Commonwealth of the Northern Mariana Islands is a group of islands in the " + "Pacific Ocean that are a political division controlled by the United States. 
Its " + "capital is Saipan.", + }, + ], + "model": "rerank-lite-1", + "usage": {"total_tokens": 96}, + } + mock_response.status_code = 200 + mock_post.return_value = mock_response + model.validate_credentials( + model="rerank-lite-1", + credentials={ + "api_key": os.environ.get("VOYAGE_API_KEY"), + }, + ) + + +def test_invoke_model(): + model = VoyageRerankModel() + with patch("httpx.post") as mock_post: + mock_response = Mock() + mock_response.json.return_value = { + "object": "list", + "data": [ + { + "relevance_score": 0.84375, + "index": 0, + "document": "Kasumi is a girl name of Japanese origin meaning mist.", + }, + { + "relevance_score": 0.4765625, + "index": 1, + "document": "Her music is a kawaii bass, a mix of future bass, pop, and kawaii music and she " + "leads a team named PopiParty.", + }, + ], + "model": "rerank-lite-1", + "usage": {"total_tokens": 59}, + } + mock_response.status_code = 200 + mock_post.return_value = mock_response + result = model.invoke( + model="rerank-lite-1", + credentials={ + "api_key": os.environ.get("VOYAGE_API_KEY"), + }, + query="Who is Kasumi?", + docs=[ + "Kasumi is a girl name of Japanese origin meaning mist.", + "Her music is a kawaii bass, a mix of future bass, pop, and kawaii music and she leads a team named " + "PopiParty.", + ], + score_threshold=0.5, + ) + + assert isinstance(result, RerankResult) + assert len(result.docs) == 1 + assert result.docs[0].index == 0 + assert result.docs[0].score >= 0.5 diff --git a/api/tests/integration_tests/model_runtime/voyage/test_text_embedding.py b/api/tests/integration_tests/model_runtime/voyage/test_text_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..75719672a9ecc931b8400cabbbb759850f6affb2 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/voyage/test_text_embedding.py @@ -0,0 +1,70 @@ +import os +from unittest.mock import Mock, patch + +import pytest + +from core.model_runtime.entities.text_embedding_entities import 
TextEmbeddingResult +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.voyage.text_embedding.text_embedding import VoyageTextEmbeddingModel + + +def test_validate_credentials(): + model = VoyageTextEmbeddingModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials(model="voyage-3", credentials={"api_key": "invalid_key"}) + with patch("requests.post") as mock_post: + mock_response = Mock() + mock_response.json.return_value = { + "object": "list", + "data": [{"object": "embedding", "embedding": [0.23333 for _ in range(1024)], "index": 0}], + "model": "voyage-3", + "usage": {"total_tokens": 1}, + } + mock_response.status_code = 200 + mock_post.return_value = mock_response + model.validate_credentials(model="voyage-3", credentials={"api_key": os.environ.get("VOYAGE_API_KEY")}) + + +def test_invoke_model(): + model = VoyageTextEmbeddingModel() + + with patch("requests.post") as mock_post: + mock_response = Mock() + mock_response.json.return_value = { + "object": "list", + "data": [ + {"object": "embedding", "embedding": [0.23333 for _ in range(1024)], "index": 0}, + {"object": "embedding", "embedding": [0.23333 for _ in range(1024)], "index": 1}, + ], + "model": "voyage-3", + "usage": {"total_tokens": 2}, + } + mock_response.status_code = 200 + mock_post.return_value = mock_response + result = model.invoke( + model="voyage-3", + credentials={ + "api_key": os.environ.get("VOYAGE_API_KEY"), + }, + texts=["hello", "world"], + user="abc-123", + ) + + assert isinstance(result, TextEmbeddingResult) + assert len(result.embeddings) == 2 + assert result.usage.total_tokens == 2 + + +def test_get_num_tokens(): + model = VoyageTextEmbeddingModel() + + num_tokens = model.get_num_tokens( + model="voyage-3", + credentials={ + "api_key": os.environ.get("VOYAGE_API_KEY"), + }, + texts=["ping"], + ) + + assert num_tokens == 1 diff --git 
a/api/tests/integration_tests/model_runtime/wenxin/__init__.py b/api/tests/integration_tests/model_runtime/wenxin/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/tests/integration_tests/model_runtime/wenxin/test_embedding.py b/api/tests/integration_tests/model_runtime/wenxin/test_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..ac38340aecf7d288881c8f74e2daef0d3fd9b459 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/wenxin/test_embedding.py @@ -0,0 +1,69 @@ +import os +from time import sleep + +from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult +from core.model_runtime.model_providers.wenxin.text_embedding.text_embedding import WenxinTextEmbeddingModel + + +def test_invoke_embedding_v1(): + sleep(3) + model = WenxinTextEmbeddingModel() + + response = model.invoke( + model="embedding-v1", + credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")}, + texts=["hello", "你好", "xxxxx"], + user="abc-123", + ) + + assert isinstance(response, TextEmbeddingResult) + assert len(response.embeddings) == 3 + assert isinstance(response.embeddings[0], list) + + +def test_invoke_embedding_bge_large_en(): + sleep(3) + model = WenxinTextEmbeddingModel() + + response = model.invoke( + model="bge-large-en", + credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")}, + texts=["hello", "你好", "xxxxx"], + user="abc-123", + ) + + assert isinstance(response, TextEmbeddingResult) + assert len(response.embeddings) == 3 + assert isinstance(response.embeddings[0], list) + + +def test_invoke_embedding_bge_large_zh(): + sleep(3) + model = WenxinTextEmbeddingModel() + + response = model.invoke( + model="bge-large-zh", + credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")}, + 
texts=["hello", "你好", "xxxxx"], + user="abc-123", + ) + + assert isinstance(response, TextEmbeddingResult) + assert len(response.embeddings) == 3 + assert isinstance(response.embeddings[0], list) + + +def test_invoke_embedding_tao_8k(): + sleep(3) + model = WenxinTextEmbeddingModel() + + response = model.invoke( + model="tao-8k", + credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")}, + texts=["hello", "你好", "xxxxx"], + user="abc-123", + ) + + assert isinstance(response, TextEmbeddingResult) + assert len(response.embeddings) == 3 + assert isinstance(response.embeddings[0], list) diff --git a/api/tests/integration_tests/model_runtime/wenxin/test_llm.py b/api/tests/integration_tests/model_runtime/wenxin/test_llm.py new file mode 100644 index 0000000000000000000000000000000000000000..e2e58f15e025d855f92f8134059fc5daf858d863 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/wenxin/test_llm.py @@ -0,0 +1,214 @@ +import os +from collections.abc import Generator +from time import sleep + +import pytest + +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import AssistantPromptMessage, SystemPromptMessage, UserPromptMessage +from core.model_runtime.entities.model_entities import AIModelEntity +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.wenxin.llm.llm import ErnieBotLargeLanguageModel + + +def test_predefined_models(): + model = ErnieBotLargeLanguageModel() + model_schemas = model.predefined_models() + assert len(model_schemas) >= 1 + assert isinstance(model_schemas[0], AIModelEntity) + + +def test_validate_credentials_for_chat_model(): + sleep(3) + model = ErnieBotLargeLanguageModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model="ernie-bot", credentials={"api_key": "invalid_key", 
"secret_key": "invalid_key"} + ) + + model.validate_credentials( + model="ernie-bot", + credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")}, + ) + + +def test_invoke_model_ernie_bot(): + sleep(3) + model = ErnieBotLargeLanguageModel() + + response = model.invoke( + model="ernie-bot", + credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")}, + prompt_messages=[UserPromptMessage(content="Hello World!")], + model_parameters={ + "temperature": 0.7, + "top_p": 1.0, + }, + stop=["you"], + user="abc-123", + stream=False, + ) + + assert isinstance(response, LLMResult) + assert len(response.message.content) > 0 + assert response.usage.total_tokens > 0 + + +def test_invoke_model_ernie_bot_turbo(): + sleep(3) + model = ErnieBotLargeLanguageModel() + + response = model.invoke( + model="ernie-bot-turbo", + credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")}, + prompt_messages=[UserPromptMessage(content="Hello World!")], + model_parameters={ + "temperature": 0.7, + "top_p": 1.0, + }, + stop=["you"], + user="abc-123", + stream=False, + ) + + assert isinstance(response, LLMResult) + assert len(response.message.content) > 0 + assert response.usage.total_tokens > 0 + + +def test_invoke_model_ernie_8k(): + sleep(3) + model = ErnieBotLargeLanguageModel() + + response = model.invoke( + model="ernie-bot-8k", + credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")}, + prompt_messages=[UserPromptMessage(content="Hello World!")], + model_parameters={ + "temperature": 0.7, + "top_p": 1.0, + }, + stop=["you"], + user="abc-123", + stream=False, + ) + + assert isinstance(response, LLMResult) + assert len(response.message.content) > 0 + assert response.usage.total_tokens > 0 + + +def test_invoke_model_ernie_bot_4(): + sleep(3) + model = ErnieBotLargeLanguageModel() + + 
response = model.invoke( + model="ernie-bot-4", + credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")}, + prompt_messages=[UserPromptMessage(content="Hello World!")], + model_parameters={ + "temperature": 0.7, + "top_p": 1.0, + }, + stop=["you"], + user="abc-123", + stream=False, + ) + + assert isinstance(response, LLMResult) + assert len(response.message.content) > 0 + assert response.usage.total_tokens > 0 + + +def test_invoke_stream_model(): + sleep(3) + model = ErnieBotLargeLanguageModel() + + response = model.invoke( + model="ernie-3.5-8k", + credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")}, + prompt_messages=[UserPromptMessage(content="Hello World!")], + model_parameters={ + "temperature": 0.7, + "top_p": 1.0, + }, + stop=["you"], + stream=True, + user="abc-123", + ) + + assert isinstance(response, Generator) + for chunk in response: + assert isinstance(chunk, LLMResultChunk) + assert isinstance(chunk.delta, LLMResultChunkDelta) + assert isinstance(chunk.delta.message, AssistantPromptMessage) + assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True + + +def test_invoke_model_with_system(): + sleep(3) + model = ErnieBotLargeLanguageModel() + + response = model.invoke( + model="ernie-bot", + credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")}, + prompt_messages=[SystemPromptMessage(content="你是Kasumi"), UserPromptMessage(content="你是谁?")], + model_parameters={ + "temperature": 0.7, + "top_p": 1.0, + }, + stop=["you"], + stream=False, + user="abc-123", + ) + + assert isinstance(response, LLMResult) + assert "kasumi" in response.message.content.lower() + + +def test_invoke_with_search(): + sleep(3) + model = ErnieBotLargeLanguageModel() + + response = model.invoke( + model="ernie-bot", + credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": 
os.environ.get("WENXIN_SECRET_KEY")}, + prompt_messages=[UserPromptMessage(content="北京今天的天气怎么样")], + model_parameters={ + "temperature": 0.7, + "top_p": 1.0, + "disable_search": True, + }, + stop=[], + stream=True, + user="abc-123", + ) + + assert isinstance(response, Generator) + total_message = "" + for chunk in response: + assert isinstance(chunk, LLMResultChunk) + assert isinstance(chunk.delta, LLMResultChunkDelta) + assert isinstance(chunk.delta.message, AssistantPromptMessage) + total_message += chunk.delta.message.content + print(chunk.delta.message.content) + assert len(chunk.delta.message.content) > 0 if not chunk.delta.finish_reason else True + + # there should be 对不起、我不能、不支持…… + assert "不" in total_message or "抱歉" in total_message or "无法" in total_message + + +def test_get_num_tokens(): + sleep(3) + model = ErnieBotLargeLanguageModel() + + response = model.get_num_tokens( + model="ernie-bot", + credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")}, + prompt_messages=[UserPromptMessage(content="Hello World!")], + tools=[], + ) + + assert isinstance(response, int) + assert response == 10 diff --git a/api/tests/integration_tests/model_runtime/wenxin/test_provider.py b/api/tests/integration_tests/model_runtime/wenxin/test_provider.py new file mode 100644 index 0000000000000000000000000000000000000000..337c3d2a8010dd474ba5781d57882a5ed3c5f951 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/wenxin/test_provider.py @@ -0,0 +1,17 @@ +import os + +import pytest + +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.wenxin.wenxin import WenxinProvider + + +def test_validate_provider_credentials(): + provider = WenxinProvider() + + with pytest.raises(CredentialsValidateFailedError): + provider.validate_provider_credentials(credentials={"api_key": "hahahaha", "secret_key": "hahahaha"}) + + provider.validate_provider_credentials( + 
credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")} + ) diff --git a/api/tests/integration_tests/model_runtime/wenxin/test_rerank.py b/api/tests/integration_tests/model_runtime/wenxin/test_rerank.py new file mode 100644 index 0000000000000000000000000000000000000000..33c803e8e1964eea764d3e268b172b91f1032b6e --- /dev/null +++ b/api/tests/integration_tests/model_runtime/wenxin/test_rerank.py @@ -0,0 +1,21 @@ +import os +from time import sleep + +from core.model_runtime.entities.rerank_entities import RerankResult +from core.model_runtime.model_providers.wenxin.rerank.rerank import WenxinRerankModel + + +def test_invoke_bce_reranker_base_v1(): + sleep(3) + model = WenxinRerankModel() + + response = model.invoke( + model="bce-reranker-base_v1", + credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")}, + query="What is Deep Learning?", + docs=["Deep Learning is ...", "My Book is ..."], + user="abc-123", + ) + + assert isinstance(response, RerankResult) + assert len(response.docs) == 2 diff --git a/api/tests/integration_tests/model_runtime/x/__init__.py b/api/tests/integration_tests/model_runtime/x/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/tests/integration_tests/model_runtime/x/test_llm.py b/api/tests/integration_tests/model_runtime/x/test_llm.py new file mode 100644 index 0000000000000000000000000000000000000000..647a2f648075e550118eba42fc28d511a06fbfa5 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/x/test_llm.py @@ -0,0 +1,204 @@ +import os +from collections.abc import Generator + +import pytest + +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessageTool, + SystemPromptMessage, + UserPromptMessage, +) 
+from core.model_runtime.entities.model_entities import AIModelEntity +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.x.llm.llm import XAILargeLanguageModel + +"""FOR MOCK FIXTURES, DO NOT REMOVE""" +from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock + + +def test_predefined_models(): + model = XAILargeLanguageModel() + model_schemas = model.predefined_models() + + assert len(model_schemas) >= 1 + assert isinstance(model_schemas[0], AIModelEntity) + + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) +def test_validate_credentials_for_chat_model(setup_openai_mock): + model = XAILargeLanguageModel() + + with pytest.raises(CredentialsValidateFailedError): + # model name to gpt-3.5-turbo because of mocking + model.validate_credentials( + model="gpt-3.5-turbo", + credentials={"api_key": "invalid_key", "endpoint_url": os.environ.get("XAI_API_BASE"), "mode": "chat"}, + ) + + model.validate_credentials( + model="grok-beta", + credentials={ + "api_key": os.environ.get("XAI_API_KEY"), + "endpoint_url": os.environ.get("XAI_API_BASE"), + "mode": "chat", + }, + ) + + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) +def test_invoke_chat_model(setup_openai_mock): + model = XAILargeLanguageModel() + + result = model.invoke( + model="grok-beta", + credentials={ + "api_key": os.environ.get("XAI_API_KEY"), + "endpoint_url": os.environ.get("XAI_API_BASE"), + "mode": "chat", + }, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Hello World!"), + ], + model_parameters={ + "temperature": 0.0, + "top_p": 1.0, + "presence_penalty": 0.0, + "frequency_penalty": 0.0, + "max_tokens": 10, + }, + stop=["How"], + stream=False, + user="foo", + ) + + assert isinstance(result, LLMResult) + assert len(result.message.content) > 0 + + 
+@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) +def test_invoke_chat_model_with_tools(setup_openai_mock): + model = XAILargeLanguageModel() + + result = model.invoke( + model="grok-beta", + credentials={ + "api_key": os.environ.get("XAI_API_KEY"), + "endpoint_url": os.environ.get("XAI_API_BASE"), + "mode": "chat", + }, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage( + content="what's the weather today in London?", + ), + ], + model_parameters={"temperature": 0.0, "max_tokens": 100}, + tools=[ + PromptMessageTool( + name="get_weather", + description="Determine weather in my location", + parameters={ + "type": "object", + "properties": { + "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"}, + "unit": {"type": "string", "enum": ["c", "f"]}, + }, + "required": ["location"], + }, + ), + PromptMessageTool( + name="get_stock_price", + description="Get the current stock price", + parameters={ + "type": "object", + "properties": {"symbol": {"type": "string", "description": "The stock symbol"}}, + "required": ["symbol"], + }, + ), + ], + stream=False, + user="foo", + ) + + assert isinstance(result, LLMResult) + assert isinstance(result.message, AssistantPromptMessage) + + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) +def test_invoke_stream_chat_model(setup_openai_mock): + model = XAILargeLanguageModel() + + result = model.invoke( + model="grok-beta", + credentials={ + "api_key": os.environ.get("XAI_API_KEY"), + "endpoint_url": os.environ.get("XAI_API_BASE"), + "mode": "chat", + }, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Hello World!"), + ], + model_parameters={"temperature": 0.0, "max_tokens": 100}, + stream=True, + user="foo", + ) + + assert isinstance(result, Generator) + + for chunk in result: + assert isinstance(chunk, 
LLMResultChunk) + assert isinstance(chunk.delta, LLMResultChunkDelta) + assert isinstance(chunk.delta.message, AssistantPromptMessage) + assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True + if chunk.delta.finish_reason is not None: + assert chunk.delta.usage is not None + assert chunk.delta.usage.completion_tokens > 0 + + +def test_get_num_tokens(): + model = XAILargeLanguageModel() + + num_tokens = model.get_num_tokens( + model="grok-beta", + credentials={"api_key": os.environ.get("XAI_API_KEY"), "endpoint_url": os.environ.get("XAI_API_BASE")}, + prompt_messages=[UserPromptMessage(content="Hello World!")], + ) + + assert num_tokens == 10 + + num_tokens = model.get_num_tokens( + model="grok-beta", + credentials={"api_key": os.environ.get("XAI_API_KEY"), "endpoint_url": os.environ.get("XAI_API_BASE")}, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Hello World!"), + ], + tools=[ + PromptMessageTool( + name="get_weather", + description="Determine weather in my location", + parameters={ + "type": "object", + "properties": { + "location": {"type": "string", "description": "The city and state e.g. 
San Francisco, CA"}, + "unit": {"type": "string", "enum": ["c", "f"]}, + }, + "required": ["location"], + }, + ), + ], + ) + + assert num_tokens == 77 diff --git a/api/tests/integration_tests/model_runtime/xinference/__init__.py b/api/tests/integration_tests/model_runtime/xinference/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/tests/integration_tests/model_runtime/xinference/test_embeddings.py b/api/tests/integration_tests/model_runtime/xinference/test_embeddings.py new file mode 100644 index 0000000000000000000000000000000000000000..8e778d005a4bc3c7876690f49e7c00949a5e0a4e --- /dev/null +++ b/api/tests/integration_tests/model_runtime/xinference/test_embeddings.py @@ -0,0 +1,64 @@ +import os + +import pytest + +from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.xinference.text_embedding.text_embedding import XinferenceTextEmbeddingModel +from tests.integration_tests.model_runtime.__mock.xinference import MOCK, setup_xinference_mock + + +@pytest.mark.parametrize("setup_xinference_mock", [["none"]], indirect=True) +def test_validate_credentials(setup_xinference_mock): + model = XinferenceTextEmbeddingModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model="bge-base-en", + credentials={ + "server_url": os.environ.get("XINFERENCE_SERVER_URL"), + "model_uid": "www " + os.environ.get("XINFERENCE_EMBEDDINGS_MODEL_UID"), + }, + ) + + model.validate_credentials( + model="bge-base-en", + credentials={ + "server_url": os.environ.get("XINFERENCE_SERVER_URL"), + "model_uid": os.environ.get("XINFERENCE_EMBEDDINGS_MODEL_UID"), + }, + ) + + +@pytest.mark.parametrize("setup_xinference_mock", [["none"]], indirect=True) +def test_invoke_model(setup_xinference_mock): + model = 
XinferenceTextEmbeddingModel() + + result = model.invoke( + model="bge-base-en", + credentials={ + "server_url": os.environ.get("XINFERENCE_SERVER_URL"), + "model_uid": os.environ.get("XINFERENCE_EMBEDDINGS_MODEL_UID"), + }, + texts=["hello", "world"], + user="abc-123", + ) + + assert isinstance(result, TextEmbeddingResult) + assert len(result.embeddings) == 2 + assert result.usage.total_tokens > 0 + + +def test_get_num_tokens(): + model = XinferenceTextEmbeddingModel() + + num_tokens = model.get_num_tokens( + model="bge-base-en", + credentials={ + "server_url": os.environ.get("XINFERENCE_SERVER_URL"), + "model_uid": os.environ.get("XINFERENCE_EMBEDDINGS_MODEL_UID"), + }, + texts=["hello", "world"], + ) + + assert num_tokens == 2 diff --git a/api/tests/integration_tests/model_runtime/xinference/test_llm.py b/api/tests/integration_tests/model_runtime/xinference/test_llm.py new file mode 100644 index 0000000000000000000000000000000000000000..5e4cde36389b5b7e06eb9d590370e53b182d0452 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/xinference/test_llm.py @@ -0,0 +1,364 @@ +import os +from collections.abc import Generator + +import pytest + +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessageTool, + SystemPromptMessage, + UserPromptMessage, +) +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.xinference.llm.llm import XinferenceAILargeLanguageModel + +"""FOR MOCK FIXTURES, DO NOT REMOVE""" +from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock +from tests.integration_tests.model_runtime.__mock.xinference import setup_xinference_mock + + +@pytest.mark.parametrize(("setup_openai_mock", "setup_xinference_mock"), [("chat", "none")], indirect=True) +def test_validate_credentials_for_chat_model(setup_openai_mock, 
setup_xinference_mock): + model = XinferenceAILargeLanguageModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model="ChatGLM3", + credentials={ + "server_url": os.environ.get("XINFERENCE_SERVER_URL"), + "model_uid": "www " + os.environ.get("XINFERENCE_CHAT_MODEL_UID"), + }, + ) + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials(model="aaaaa", credentials={"server_url": "", "model_uid": ""}) + + model.validate_credentials( + model="ChatGLM3", + credentials={ + "server_url": os.environ.get("XINFERENCE_SERVER_URL"), + "model_uid": os.environ.get("XINFERENCE_CHAT_MODEL_UID"), + }, + ) + + +@pytest.mark.parametrize(("setup_openai_mock", "setup_xinference_mock"), [("chat", "none")], indirect=True) +def test_invoke_chat_model(setup_openai_mock, setup_xinference_mock): + model = XinferenceAILargeLanguageModel() + + response = model.invoke( + model="ChatGLM3", + credentials={ + "server_url": os.environ.get("XINFERENCE_SERVER_URL"), + "model_uid": os.environ.get("XINFERENCE_CHAT_MODEL_UID"), + }, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Hello World!"), + ], + model_parameters={ + "temperature": 0.7, + "top_p": 1.0, + }, + stop=["you"], + user="abc-123", + stream=False, + ) + + assert isinstance(response, LLMResult) + assert len(response.message.content) > 0 + assert response.usage.total_tokens > 0 + + +@pytest.mark.parametrize(("setup_openai_mock", "setup_xinference_mock"), [("chat", "none")], indirect=True) +def test_invoke_stream_chat_model(setup_openai_mock, setup_xinference_mock): + model = XinferenceAILargeLanguageModel() + + response = model.invoke( + model="ChatGLM3", + credentials={ + "server_url": os.environ.get("XINFERENCE_SERVER_URL"), + "model_uid": os.environ.get("XINFERENCE_CHAT_MODEL_UID"), + }, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + 
UserPromptMessage(content="Hello World!"), + ], + model_parameters={ + "temperature": 0.7, + "top_p": 1.0, + }, + stop=["you"], + stream=True, + user="abc-123", + ) + + assert isinstance(response, Generator) + for chunk in response: + assert isinstance(chunk, LLMResultChunk) + assert isinstance(chunk.delta, LLMResultChunkDelta) + assert isinstance(chunk.delta.message, AssistantPromptMessage) + assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True + + +""" + Function calling of xinference does not support stream mode currently +""" +# def test_invoke_stream_chat_model_with_functions(): +# model = XinferenceAILargeLanguageModel() + +# response = model.invoke( +# model='ChatGLM3-6b', +# credentials={ +# 'server_url': os.environ.get('XINFERENCE_SERVER_URL'), +# 'model_type': 'text-generation', +# 'model_name': 'ChatGLM3', +# 'model_uid': os.environ.get('XINFERENCE_CHAT_MODEL_UID') +# }, +# prompt_messages=[ +# SystemPromptMessage( +# content='你是一个天气机器人,可以通过调用函数来获取天气信息', +# ), +# UserPromptMessage( +# content='波士顿天气如何?' +# ) +# ], +# model_parameters={ +# 'temperature': 0, +# 'top_p': 1.0, +# }, +# stop=['you'], +# user='abc-123', +# stream=True, +# tools=[ +# PromptMessageTool( +# name='get_current_weather', +# description='Get the current weather in a given location', +# parameters={ +# "type": "object", +# "properties": { +# "location": { +# "type": "string", +# "description": "The city and state e.g. 
San Francisco, CA" +# }, +# "unit": { +# "type": "string", +# "enum": ["celsius", "fahrenheit"] +# } +# }, +# "required": [ +# "location" +# ] +# } +# ) +# ] +# ) + +# assert isinstance(response, Generator) + +# call: LLMResultChunk = None +# chunks = [] + +# for chunk in response: +# chunks.append(chunk) +# assert isinstance(chunk, LLMResultChunk) +# assert isinstance(chunk.delta, LLMResultChunkDelta) +# assert isinstance(chunk.delta.message, AssistantPromptMessage) +# assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True + +# if chunk.delta.message.tool_calls and len(chunk.delta.message.tool_calls) > 0: +# call = chunk +# break + +# assert call is not None +# assert call.delta.message.tool_calls[0].function.name == 'get_current_weather' + +# def test_invoke_chat_model_with_functions(): +# model = XinferenceAILargeLanguageModel() + +# response = model.invoke( +# model='ChatGLM3-6b', +# credentials={ +# 'server_url': os.environ.get('XINFERENCE_SERVER_URL'), +# 'model_type': 'text-generation', +# 'model_name': 'ChatGLM3', +# 'model_uid': os.environ.get('XINFERENCE_CHAT_MODEL_UID') +# }, +# prompt_messages=[ +# UserPromptMessage( +# content='What is the weather like in San Francisco?' +# ) +# ], +# model_parameters={ +# 'temperature': 0.7, +# 'top_p': 1.0, +# }, +# stop=['you'], +# user='abc-123', +# stream=False, +# tools=[ +# PromptMessageTool( +# name='get_current_weather', +# description='Get the current weather in a given location', +# parameters={ +# "type": "object", +# "properties": { +# "location": { +# "type": "string", +# "description": "The city and state e.g. 
San Francisco, CA" +# }, +# "unit": { +# "type": "string", +# "enum": [ +# "c", +# "f" +# ] +# } +# }, +# "required": [ +# "location" +# ] +# } +# ) +# ] +# ) + +# assert isinstance(response, LLMResult) +# assert len(response.message.content) > 0 +# assert response.usage.total_tokens > 0 +# assert response.message.tool_calls[0].function.name == 'get_current_weather' + + +@pytest.mark.parametrize(("setup_openai_mock", "setup_xinference_mock"), [("completion", "none")], indirect=True) +def test_validate_credentials_for_generation_model(setup_openai_mock, setup_xinference_mock): + model = XinferenceAILargeLanguageModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model="alapaca", + credentials={ + "server_url": os.environ.get("XINFERENCE_SERVER_URL"), + "model_uid": "www " + os.environ.get("XINFERENCE_GENERATION_MODEL_UID"), + }, + ) + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials(model="alapaca", credentials={"server_url": "", "model_uid": ""}) + + model.validate_credentials( + model="alapaca", + credentials={ + "server_url": os.environ.get("XINFERENCE_SERVER_URL"), + "model_uid": os.environ.get("XINFERENCE_GENERATION_MODEL_UID"), + }, + ) + + +@pytest.mark.parametrize(("setup_openai_mock", "setup_xinference_mock"), [("completion", "none")], indirect=True) +def test_invoke_generation_model(setup_openai_mock, setup_xinference_mock): + model = XinferenceAILargeLanguageModel() + + response = model.invoke( + model="alapaca", + credentials={ + "server_url": os.environ.get("XINFERENCE_SERVER_URL"), + "model_uid": os.environ.get("XINFERENCE_GENERATION_MODEL_UID"), + }, + prompt_messages=[UserPromptMessage(content="the United States is")], + model_parameters={ + "temperature": 0.7, + "top_p": 1.0, + }, + stop=["you"], + user="abc-123", + stream=False, + ) + + assert isinstance(response, LLMResult) + assert len(response.message.content) > 0 + assert response.usage.total_tokens > 0 + + 
+@pytest.mark.parametrize(("setup_openai_mock", "setup_xinference_mock"), [("completion", "none")], indirect=True) +def test_invoke_stream_generation_model(setup_openai_mock, setup_xinference_mock): + model = XinferenceAILargeLanguageModel() + + response = model.invoke( + model="alapaca", + credentials={ + "server_url": os.environ.get("XINFERENCE_SERVER_URL"), + "model_uid": os.environ.get("XINFERENCE_GENERATION_MODEL_UID"), + }, + prompt_messages=[UserPromptMessage(content="the United States is")], + model_parameters={ + "temperature": 0.7, + "top_p": 1.0, + }, + stop=["you"], + stream=True, + user="abc-123", + ) + + assert isinstance(response, Generator) + for chunk in response: + assert isinstance(chunk, LLMResultChunk) + assert isinstance(chunk.delta, LLMResultChunkDelta) + assert isinstance(chunk.delta.message, AssistantPromptMessage) + assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True + + +def test_get_num_tokens(): + model = XinferenceAILargeLanguageModel() + + num_tokens = model.get_num_tokens( + model="ChatGLM3", + credentials={ + "server_url": os.environ.get("XINFERENCE_SERVER_URL"), + "model_uid": os.environ.get("XINFERENCE_GENERATION_MODEL_UID"), + }, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Hello World!"), + ], + tools=[ + PromptMessageTool( + name="get_current_weather", + description="Get the current weather in a given location", + parameters={ + "type": "object", + "properties": { + "location": {"type": "string", "description": "The city and state e.g. 
San Francisco, CA"}, + "unit": {"type": "string", "enum": ["c", "f"]}, + }, + "required": ["location"], + }, + ) + ], + ) + + assert isinstance(num_tokens, int) + assert num_tokens == 77 + + num_tokens = model.get_num_tokens( + model="ChatGLM3", + credentials={ + "server_url": os.environ.get("XINFERENCE_SERVER_URL"), + "model_uid": os.environ.get("XINFERENCE_GENERATION_MODEL_UID"), + }, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Hello World!"), + ], + ) + + assert isinstance(num_tokens, int) + assert num_tokens == 21 diff --git a/api/tests/integration_tests/model_runtime/xinference/test_rerank.py b/api/tests/integration_tests/model_runtime/xinference/test_rerank.py new file mode 100644 index 0000000000000000000000000000000000000000..71ac4eef7c22bededbfa71637bda452303cbeca6 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/xinference/test_rerank.py @@ -0,0 +1,52 @@ +import os + +import pytest + +from core.model_runtime.entities.rerank_entities import RerankResult +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.xinference.rerank.rerank import XinferenceRerankModel +from tests.integration_tests.model_runtime.__mock.xinference import MOCK, setup_xinference_mock + + +@pytest.mark.parametrize("setup_xinference_mock", [["none"]], indirect=True) +def test_validate_credentials(setup_xinference_mock): + model = XinferenceRerankModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model="bge-reranker-base", + credentials={"server_url": "awdawdaw", "model_uid": os.environ.get("XINFERENCE_RERANK_MODEL_UID")}, + ) + + model.validate_credentials( + model="bge-reranker-base", + credentials={ + "server_url": os.environ.get("XINFERENCE_SERVER_URL"), + "model_uid": os.environ.get("XINFERENCE_RERANK_MODEL_UID"), + }, + ) + + +@pytest.mark.parametrize("setup_xinference_mock", 
[["none"]], indirect=True) +def test_invoke_model(setup_xinference_mock): + model = XinferenceRerankModel() + + result = model.invoke( + model="bge-reranker-base", + credentials={ + "server_url": os.environ.get("XINFERENCE_SERVER_URL"), + "model_uid": os.environ.get("XINFERENCE_RERANK_MODEL_UID"), + }, + query="Who is Kasumi?", + docs=[ + 'Kasumi is a girl\'s name of Japanese origin meaning "mist".', + "Her music is a kawaii bass, a mix of future bass, pop, and kawaii music ", + "and she leads a team named PopiParty.", + ], + score_threshold=0.8, + ) + + assert isinstance(result, RerankResult) + assert len(result.docs) == 1 + assert result.docs[0].index == 0 + assert result.docs[0].score >= 0.8 diff --git a/api/tests/integration_tests/model_runtime/zhinao/__init__.py b/api/tests/integration_tests/model_runtime/zhinao/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/tests/integration_tests/model_runtime/zhinao/test_llm.py b/api/tests/integration_tests/model_runtime/zhinao/test_llm.py new file mode 100644 index 0000000000000000000000000000000000000000..4ca1b864764818f9901cba06963da47521c09a2c --- /dev/null +++ b/api/tests/integration_tests/model_runtime/zhinao/test_llm.py @@ -0,0 +1,73 @@ +import os +from collections.abc import Generator + +import pytest + +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import AssistantPromptMessage, SystemPromptMessage, UserPromptMessage +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.zhinao.llm.llm import ZhinaoLargeLanguageModel + + +def test_validate_credentials(): + model = ZhinaoLargeLanguageModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials(model="360gpt2-pro", credentials={"api_key": "invalid_key"}) + + 
model.validate_credentials(model="360gpt2-pro", credentials={"api_key": os.environ.get("ZHINAO_API_KEY")}) + + +def test_invoke_model(): + model = ZhinaoLargeLanguageModel() + + response = model.invoke( + model="360gpt2-pro", + credentials={"api_key": os.environ.get("ZHINAO_API_KEY")}, + prompt_messages=[UserPromptMessage(content="Who are you?")], + model_parameters={"temperature": 0.5, "max_tokens": 10}, + stop=["How"], + stream=False, + user="abc-123", + ) + + assert isinstance(response, LLMResult) + assert len(response.message.content) > 0 + + +def test_invoke_stream_model(): + model = ZhinaoLargeLanguageModel() + + response = model.invoke( + model="360gpt2-pro", + credentials={"api_key": os.environ.get("ZHINAO_API_KEY")}, + prompt_messages=[UserPromptMessage(content="Hello World!")], + model_parameters={"temperature": 0.5, "max_tokens": 100, "seed": 1234}, + stream=True, + user="abc-123", + ) + + assert isinstance(response, Generator) + + for chunk in response: + assert isinstance(chunk, LLMResultChunk) + assert isinstance(chunk.delta, LLMResultChunkDelta) + assert isinstance(chunk.delta.message, AssistantPromptMessage) + assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True + + +def test_get_num_tokens(): + model = ZhinaoLargeLanguageModel() + + num_tokens = model.get_num_tokens( + model="360gpt2-pro", + credentials={"api_key": os.environ.get("ZHINAO_API_KEY")}, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Hello World!"), + ], + ) + + assert num_tokens == 21 diff --git a/api/tests/integration_tests/model_runtime/zhinao/test_provider.py b/api/tests/integration_tests/model_runtime/zhinao/test_provider.py new file mode 100644 index 0000000000000000000000000000000000000000..c22f797919597c17be842ea71ba4631ff2c31532 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/zhinao/test_provider.py @@ -0,0 +1,15 @@ +import os + +import pytest + +from 
core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.zhinao.zhinao import ZhinaoProvider + + +def test_validate_provider_credentials(): + provider = ZhinaoProvider() + + with pytest.raises(CredentialsValidateFailedError): + provider.validate_provider_credentials(credentials={}) + + provider.validate_provider_credentials(credentials={"api_key": os.environ.get("ZHINAO_API_KEY")}) diff --git a/api/tests/integration_tests/model_runtime/zhipuai/__init__.py b/api/tests/integration_tests/model_runtime/zhipuai/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/tests/integration_tests/model_runtime/zhipuai/test_llm.py b/api/tests/integration_tests/model_runtime/zhipuai/test_llm.py new file mode 100644 index 0000000000000000000000000000000000000000..20380513eaa789ac8d3fe621e9850f0adbc117e3 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/zhipuai/test_llm.py @@ -0,0 +1,109 @@ +import os +from collections.abc import Generator + +import pytest + +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessageTool, + SystemPromptMessage, + UserPromptMessage, +) +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.zhipuai.llm.llm import ZhipuAILargeLanguageModel + + +def test_validate_credentials(): + model = ZhipuAILargeLanguageModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials(model="chatglm_turbo", credentials={"api_key": "invalid_key"}) + + model.validate_credentials(model="chatglm_turbo", credentials={"api_key": os.environ.get("ZHIPUAI_API_KEY")}) + + +def test_invoke_model(): + model = ZhipuAILargeLanguageModel() + + response = model.invoke( + model="chatglm_turbo", + 
credentials={"api_key": os.environ.get("ZHIPUAI_API_KEY")}, + prompt_messages=[UserPromptMessage(content="Who are you?")], + model_parameters={"temperature": 0.9, "top_p": 0.7}, + stop=["How"], + stream=False, + user="abc-123", + ) + + assert isinstance(response, LLMResult) + assert len(response.message.content) > 0 + + +def test_invoke_stream_model(): + model = ZhipuAILargeLanguageModel() + + response = model.invoke( + model="chatglm_turbo", + credentials={"api_key": os.environ.get("ZHIPUAI_API_KEY")}, + prompt_messages=[UserPromptMessage(content="Hello World!")], + model_parameters={"temperature": 0.9, "top_p": 0.7}, + stream=True, + user="abc-123", + ) + + assert isinstance(response, Generator) + + for chunk in response: + assert isinstance(chunk, LLMResultChunk) + assert isinstance(chunk.delta, LLMResultChunkDelta) + assert isinstance(chunk.delta.message, AssistantPromptMessage) + assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True + + +def test_get_num_tokens(): + model = ZhipuAILargeLanguageModel() + + num_tokens = model.get_num_tokens( + model="chatglm_turbo", + credentials={"api_key": os.environ.get("ZHIPUAI_API_KEY")}, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Hello World!"), + ], + ) + + assert num_tokens == 14 + + +def test_get_tools_num_tokens(): + model = ZhipuAILargeLanguageModel() + + num_tokens = model.get_num_tokens( + model="tools", + credentials={"api_key": os.environ.get("ZHIPUAI_API_KEY")}, + tools=[ + PromptMessageTool( + name="get_current_weather", + description="Get the current weather in a given location", + parameters={ + "type": "object", + "properties": { + "location": {"type": "string", "description": "The city and state e.g. 
San Francisco, CA"}, + "unit": {"type": "string", "enum": ["c", "f"]}, + }, + "required": ["location"], + }, + ) + ], + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Hello World!"), + ], + ) + + assert num_tokens == 88 diff --git a/api/tests/integration_tests/model_runtime/zhipuai/test_provider.py b/api/tests/integration_tests/model_runtime/zhipuai/test_provider.py new file mode 100644 index 0000000000000000000000000000000000000000..cb5bc0b20aafc19dcdebd02254fcaee8522e8e8d --- /dev/null +++ b/api/tests/integration_tests/model_runtime/zhipuai/test_provider.py @@ -0,0 +1,15 @@ +import os + +import pytest + +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.zhipuai.zhipuai import ZhipuaiProvider + + +def test_validate_provider_credentials(): + provider = ZhipuaiProvider() + + with pytest.raises(CredentialsValidateFailedError): + provider.validate_provider_credentials(credentials={}) + + provider.validate_provider_credentials(credentials={"api_key": os.environ.get("ZHIPUAI_API_KEY")}) diff --git a/api/tests/integration_tests/model_runtime/zhipuai/test_text_embedding.py b/api/tests/integration_tests/model_runtime/zhipuai/test_text_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..9c97c91ecbdd94a63c0d98b6f7574629d7c26517 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/zhipuai/test_text_embedding.py @@ -0,0 +1,41 @@ +import os + +import pytest + +from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.zhipuai.text_embedding.text_embedding import ZhipuAITextEmbeddingModel + + +def test_validate_credentials(): + model = ZhipuAITextEmbeddingModel() + + with pytest.raises(CredentialsValidateFailedError): + 
model.validate_credentials(model="text_embedding", credentials={"api_key": "invalid_key"}) + + model.validate_credentials(model="text_embedding", credentials={"api_key": os.environ.get("ZHIPUAI_API_KEY")}) + + +def test_invoke_model(): + model = ZhipuAITextEmbeddingModel() + + result = model.invoke( + model="text_embedding", + credentials={"api_key": os.environ.get("ZHIPUAI_API_KEY")}, + texts=["hello", "world"], + user="abc-123", + ) + + assert isinstance(result, TextEmbeddingResult) + assert len(result.embeddings) == 2 + assert result.usage.total_tokens > 0 + + +def test_get_num_tokens(): + model = ZhipuAITextEmbeddingModel() + + num_tokens = model.get_num_tokens( + model="text_embedding", credentials={"api_key": os.environ.get("ZHIPUAI_API_KEY")}, texts=["hello", "world"] + ) + + assert num_tokens == 2 diff --git a/api/tests/integration_tests/tools/__init__.py b/api/tests/integration_tests/tools/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/tests/integration_tests/tools/__mock/http.py b/api/tests/integration_tests/tools/__mock/http.py new file mode 100644 index 0000000000000000000000000000000000000000..42cf87e317ab6a04595da2ca94fc707b3dd00a74 --- /dev/null +++ b/api/tests/integration_tests/tools/__mock/http.py @@ -0,0 +1,34 @@ +import json +from typing import Literal + +import httpx +import pytest +from _pytest.monkeypatch import MonkeyPatch + + +class MockedHttp: + @staticmethod + def httpx_request( + method: Literal["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD"], url: str, **kwargs + ) -> httpx.Response: + """ + Mocked httpx.request + """ + request = httpx.Request( + method, url, params=kwargs.get("params"), headers=kwargs.get("headers"), cookies=kwargs.get("cookies") + ) + data = kwargs.get("data") + resp = json.dumps(data).encode("utf-8") if data else b"OK" + response = httpx.Response( + status_code=200, + request=request, + content=resp, + ) + return response + + 
+@pytest.fixture +def setup_http_mock(request, monkeypatch: MonkeyPatch): + monkeypatch.setattr(httpx, "request", MockedHttp.httpx_request) + yield + monkeypatch.undo() diff --git a/api/tests/integration_tests/tools/__mock_server/openapi_todo.py b/api/tests/integration_tests/tools/__mock_server/openapi_todo.py new file mode 100644 index 0000000000000000000000000000000000000000..2860739f0e30b36aebc43f7c6e8f4dd4d6bd391c --- /dev/null +++ b/api/tests/integration_tests/tools/__mock_server/openapi_todo.py @@ -0,0 +1,40 @@ +from flask import Flask, request +from flask_restful import Api, Resource # type: ignore + +app = Flask(__name__) +api = Api(app) + +# Mock data +todos_data = { + "global": ["Buy groceries", "Finish project"], + "user1": ["Go for a run", "Read a book"], +} + + +class TodosResource(Resource): + def get(self, username): + todos = todos_data.get(username, []) + return {"todos": todos} + + def post(self, username): + data = request.get_json() + new_todo = data.get("todo") + todos_data.setdefault(username, []).append(new_todo) + return {"message": "Todo added successfully"} + + def delete(self, username): + data = request.get_json() + todo_idx = data.get("todo_idx") + todos = todos_data.get(username, []) + + if 0 <= todo_idx < len(todos): + del todos[todo_idx] + return {"message": "Todo deleted successfully"} + + return {"error": "Invalid todo index"}, 400 + + +api.add_resource(TodosResource, "/todos/") + +if __name__ == "__main__": + app.run(port=5003, debug=True) diff --git a/api/tests/integration_tests/tools/api_tool/__init__.py b/api/tests/integration_tests/tools/api_tool/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/tests/integration_tests/tools/api_tool/test_api_tool.py b/api/tests/integration_tests/tools/api_tool/test_api_tool.py new file mode 100644 index 0000000000000000000000000000000000000000..1bd75b91f745d6af5434361577fc03ddc4b6ea5c --- /dev/null +++ 
b/api/tests/integration_tests/tools/api_tool/test_api_tool.py @@ -0,0 +1,42 @@ +from core.tools.tool.api_tool import ApiTool +from core.tools.tool.tool import Tool +from tests.integration_tests.tools.__mock.http import setup_http_mock + +tool_bundle = { + "server_url": "http://www.example.com/{path_param}", + "method": "post", + "author": "", + "openapi": { + "parameters": [ + {"in": "path", "name": "path_param"}, + {"in": "query", "name": "query_param"}, + {"in": "cookie", "name": "cookie_param"}, + {"in": "header", "name": "header_param"}, + ], + "requestBody": { + "content": {"application/json": {"schema": {"properties": {"body_param": {"type": "string"}}}}} + }, + }, + "parameters": [], +} +parameters = { + "path_param": "p_param", + "query_param": "q_param", + "cookie_param": "c_param", + "header_param": "h_param", + "body_param": "b_param", +} + + +def test_api_tool(setup_http_mock): + tool = ApiTool(api_bundle=tool_bundle, runtime=Tool.Runtime(credentials={"auth_type": "none"})) + headers = tool.assembling_request(parameters) + response = tool.do_http_request(tool.api_bundle.server_url, tool.api_bundle.method, headers, parameters) + + assert response.status_code == 200 + assert response.request.url.path == "/p_param" + assert response.request.url.query == b"query_param=q_param" + assert response.request.headers.get("header_param") == "h_param" + assert response.request.headers.get("content-type") == "application/json" + assert response.request.headers.get("cookie") == "cookie_param=c_param" + assert "b_param" in response.content.decode() diff --git a/api/tests/integration_tests/tools/code/__init__.py b/api/tests/integration_tests/tools/code/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/tests/integration_tests/tools/test_all_provider.py b/api/tests/integration_tests/tools/test_all_provider.py new file mode 100644 index 
0000000000000000000000000000000000000000..2dfce749b3e16f399457aa61e4c1ebc8ee85296a --- /dev/null +++ b/api/tests/integration_tests/tools/test_all_provider.py @@ -0,0 +1,23 @@ +import pytest + +from core.tools.tool_manager import ToolManager + +provider_generator = ToolManager.list_builtin_providers() +provider_names = [provider.identity.name for provider in provider_generator] +ToolManager.clear_builtin_providers_cache() +provider_generator = ToolManager.list_builtin_providers() + + +@pytest.mark.parametrize("name", provider_names) +def test_tool_providers(benchmark, name): + """ + Test that all tool providers can be loaded + """ + + def test(generator): + try: + return next(generator) + except StopIteration: + return None + + benchmark.pedantic(test, args=(provider_generator,), iterations=1, rounds=1) diff --git a/api/tests/integration_tests/utils/child_class.py b/api/tests/integration_tests/utils/child_class.py new file mode 100644 index 0000000000000000000000000000000000000000..f9e5f341ff76c0c036fd65f2ef244ea8da0c83cd --- /dev/null +++ b/api/tests/integration_tests/utils/child_class.py @@ -0,0 +1,7 @@ +from tests.integration_tests.utils.parent_class import ParentClass + + +class ChildClass(ParentClass): + def __init__(self, name: str): + super().__init__(name) + self.name = name diff --git a/api/tests/integration_tests/utils/lazy_load_class.py b/api/tests/integration_tests/utils/lazy_load_class.py new file mode 100644 index 0000000000000000000000000000000000000000..ec881a470a3a864441c86e45f5fb56a44c64bfec --- /dev/null +++ b/api/tests/integration_tests/utils/lazy_load_class.py @@ -0,0 +1,7 @@ +from tests.integration_tests.utils.parent_class import ParentClass + + +class LazyLoadChildClass(ParentClass): + def __init__(self, name: str): + super().__init__(name) + self.name = name diff --git a/api/tests/integration_tests/utils/parent_class.py b/api/tests/integration_tests/utils/parent_class.py new file mode 100644 index 
0000000000000000000000000000000000000000..6a6de1cc41aaf2ee5fbce92d54eb7fdc3ac0b27e --- /dev/null +++ b/api/tests/integration_tests/utils/parent_class.py @@ -0,0 +1,6 @@ +class ParentClass: + def __init__(self, name): + self.name = name + + def get_name(self): + return self.name diff --git a/api/tests/integration_tests/utils/test_module_import_helper.py b/api/tests/integration_tests/utils/test_module_import_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..50725415e4174e99b7caf4dd94c9b8c49da3ba23 --- /dev/null +++ b/api/tests/integration_tests/utils/test_module_import_helper.py @@ -0,0 +1,34 @@ +import os + +from core.helper.module_import_helper import import_module_from_source, load_single_subclass_from_source +from tests.integration_tests.utils.parent_class import ParentClass + + +def test_loading_subclass_from_source(): + current_path = os.getcwd() + module = load_single_subclass_from_source( + module_name="ChildClass", script_path=os.path.join(current_path, "child_class.py"), parent_type=ParentClass + ) + assert module + assert module.__name__ == "ChildClass" + + +def test_load_import_module_from_source(): + current_path = os.getcwd() + module = import_module_from_source( + module_name="ChildClass", py_file_path=os.path.join(current_path, "child_class.py") + ) + assert module + assert module.__name__ == "ChildClass" + + +def test_lazy_loading_subclass_from_source(): + current_path = os.getcwd() + clz = load_single_subclass_from_source( + module_name="LazyLoadChildClass", + script_path=os.path.join(current_path, "lazy_load_class.py"), + parent_type=ParentClass, + use_lazy_loader=True, + ) + instance = clz("dify") + assert instance.get_name() == "dify" diff --git a/api/tests/integration_tests/vdb/__init__.py b/api/tests/integration_tests/vdb/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/tests/integration_tests/vdb/__mock/__init__.py 
b/api/tests/integration_tests/vdb/__mock/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/tests/integration_tests/vdb/__mock/baiduvectordb.py b/api/tests/integration_tests/vdb/__mock/baiduvectordb.py new file mode 100644 index 0000000000000000000000000000000000000000..4af35a8befcaf887e9826b7fc935f80c2aa624ae --- /dev/null +++ b/api/tests/integration_tests/vdb/__mock/baiduvectordb.py @@ -0,0 +1,166 @@ +import os +from collections import UserDict +from unittest.mock import MagicMock + +import pytest +from _pytest.monkeypatch import MonkeyPatch +from pymochow import MochowClient # type: ignore +from pymochow.model.database import Database # type: ignore +from pymochow.model.enum import IndexState, IndexType, MetricType, ReadConsistency, TableState # type: ignore +from pymochow.model.schema import HNSWParams, VectorIndex # type: ignore +from pymochow.model.table import Table # type: ignore +from requests.adapters import HTTPAdapter + + +class AttrDict(UserDict): + def __getattr__(self, item): + return self.get(item) + + +class MockBaiduVectorDBClass: + def mock_vector_db_client( + self, + config=None, + adapter: HTTPAdapter = None, + ): + self.conn = MagicMock() + self._config = MagicMock() + + def list_databases(self, config=None) -> list[Database]: + return [ + Database( + conn=self.conn, + database_name="dify", + config=self._config, + ) + ] + + def create_database(self, database_name: str, config=None) -> Database: + return Database(conn=self.conn, database_name=database_name, config=config) + + def list_table(self, config=None) -> list[Table]: + return [] + + def drop_table(self, table_name: str, config=None): + return {"code": 0, "msg": "Success"} + + def create_table( + self, + table_name: str, + replication: int, + partition: int, + schema, + enable_dynamic_field=False, + description: str = "", + config=None, + ) -> Table: + return Table(self, table_name, replication, 
partition, schema, enable_dynamic_field, description, config) + + def describe_table(self, table_name: str, config=None) -> Table: + return Table( + self, + table_name, + 3, + 1, + None, + enable_dynamic_field=False, + description="table for dify", + config=config, + state=TableState.NORMAL, + ) + + def upsert(self, rows, config=None): + return {"code": 0, "msg": "operation success", "affectedCount": 1} + + def rebuild_index(self, index_name: str, config=None): + return {"code": 0, "msg": "Success"} + + def describe_index(self, index_name: str, config=None): + return VectorIndex( + index_name=index_name, + index_type=IndexType.HNSW, + field="vector", + metric_type=MetricType.L2, + params=HNSWParams(m=16, efconstruction=200), + auto_build=False, + state=IndexState.NORMAL, + ) + + def query( + self, + primary_key, + partition_key=None, + projections=None, + retrieve_vector=False, + read_consistency=ReadConsistency.EVENTUAL, + config=None, + ): + return AttrDict( + { + "row": { + "id": primary_key.get("id"), + "vector": [0.23432432, 0.8923744, 0.89238432], + "text": "text", + "metadata": '{"doc_id": "doc_id_001"}', + }, + "code": 0, + "msg": "Success", + } + ) + + def delete(self, primary_key=None, partition_key=None, filter=None, config=None): + return {"code": 0, "msg": "Success"} + + def search( + self, + anns, + partition_key=None, + projections=None, + retrieve_vector=False, + read_consistency=ReadConsistency.EVENTUAL, + config=None, + ): + return AttrDict( + { + "rows": [ + { + "row": { + "id": "doc_id_001", + "vector": [0.23432432, 0.8923744, 0.89238432], + "text": "text", + "metadata": '{"doc_id": "doc_id_001"}', + }, + "distance": 0.1, + "score": 0.5, + } + ], + "code": 0, + "msg": "Success", + } + ) + + +MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true" + + +@pytest.fixture +def setup_baiduvectordb_mock(request, monkeypatch: MonkeyPatch): + if MOCK: + monkeypatch.setattr(MochowClient, "__init__", MockBaiduVectorDBClass.mock_vector_db_client) + 
monkeypatch.setattr(MochowClient, "list_databases", MockBaiduVectorDBClass.list_databases) + monkeypatch.setattr(MochowClient, "create_database", MockBaiduVectorDBClass.create_database) + monkeypatch.setattr(Database, "table", MockBaiduVectorDBClass.describe_table) + monkeypatch.setattr(Database, "list_table", MockBaiduVectorDBClass.list_table) + monkeypatch.setattr(Database, "create_table", MockBaiduVectorDBClass.create_table) + monkeypatch.setattr(Database, "drop_table", MockBaiduVectorDBClass.drop_table) + monkeypatch.setattr(Database, "describe_table", MockBaiduVectorDBClass.describe_table) + monkeypatch.setattr(Table, "rebuild_index", MockBaiduVectorDBClass.rebuild_index) + monkeypatch.setattr(Table, "describe_index", MockBaiduVectorDBClass.describe_index) + monkeypatch.setattr(Table, "delete", MockBaiduVectorDBClass.delete) + monkeypatch.setattr(Table, "query", MockBaiduVectorDBClass.query) + monkeypatch.setattr(Table, "search", MockBaiduVectorDBClass.search) + + yield + + if MOCK: + monkeypatch.undo() diff --git a/api/tests/integration_tests/vdb/__mock/tcvectordb.py b/api/tests/integration_tests/vdb/__mock/tcvectordb.py new file mode 100644 index 0000000000000000000000000000000000000000..68a1e290adc1206fb7eb7470e17629bdff33d28b --- /dev/null +++ b/api/tests/integration_tests/vdb/__mock/tcvectordb.py @@ -0,0 +1,127 @@ +import os +from typing import Optional + +import pytest +from _pytest.monkeypatch import MonkeyPatch +from requests.adapters import HTTPAdapter +from tcvectordb import VectorDBClient # type: ignore +from tcvectordb.model.database import Collection, Database # type: ignore +from tcvectordb.model.document import Document, Filter # type: ignore +from tcvectordb.model.enum import ReadConsistency # type: ignore +from tcvectordb.model.index import Index # type: ignore +from xinference_client.types import Embedding # type: ignore + + +class MockTcvectordbClass: + def mock_vector_db_client( + self, + url=None, + username="", + key="", + 
read_consistency: ReadConsistency = ReadConsistency.EVENTUAL_CONSISTENCY, + timeout=5, + adapter: HTTPAdapter = None, + ): + self._conn = None + self._read_consistency = read_consistency + + def list_databases(self) -> list[Database]: + return [ + Database( + conn=self._conn, + read_consistency=self._read_consistency, + name="dify", + ) + ] + + def list_collections(self, timeout: Optional[float] = None) -> list[Collection]: + return [] + + def drop_collection(self, name: str, timeout: Optional[float] = None): + return {"code": 0, "msg": "operation success"} + + def create_collection( + self, + name: str, + shard: int, + replicas: int, + description: str, + index: Index, + embedding: Embedding = None, + timeout: Optional[float] = None, + ) -> Collection: + return Collection( + self, + name, + shard, + replicas, + description, + index, + embedding=embedding, + read_consistency=self._read_consistency, + timeout=timeout, + ) + + def describe_collection(self, name: str, timeout: Optional[float] = None) -> Collection: + collection = Collection(self, name, shard=1, replicas=2, description=name, timeout=timeout) + return collection + + def collection_upsert( + self, documents: list[Document], timeout: Optional[float] = None, build_index: bool = True, **kwargs + ): + return {"code": 0, "msg": "operation success"} + + def collection_search( + self, + vectors: list[list[float]], + filter: Filter = None, + params=None, + retrieve_vector: bool = False, + limit: int = 10, + output_fields: Optional[list[str]] = None, + timeout: Optional[float] = None, + ) -> list[list[dict]]: + return [[{"metadata": '{"doc_id":"foo1"}', "text": "text", "doc_id": "foo1", "score": 0.1}]] + + def collection_query( + self, + document_ids: Optional[list] = None, + retrieve_vector: bool = False, + limit: Optional[int] = None, + offset: Optional[int] = None, + filter: Optional[Filter] = None, + output_fields: Optional[list[str]] = None, + timeout: Optional[float] = None, + ) -> list[dict]: + return 
[{"metadata": '{"doc_id":"foo1"}', "text": "text", "doc_id": "foo1", "score": 0.1}] + + def collection_delete( + self, + document_ids: Optional[list[str]] = None, + filter: Filter = None, + timeout: Optional[float] = None, + ): + return {"code": 0, "msg": "operation success"} + + +MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true" + + +@pytest.fixture +def setup_tcvectordb_mock(request, monkeypatch: MonkeyPatch): + if MOCK: + monkeypatch.setattr(VectorDBClient, "__init__", MockTcvectordbClass.mock_vector_db_client) + monkeypatch.setattr(VectorDBClient, "list_databases", MockTcvectordbClass.list_databases) + monkeypatch.setattr(Database, "collection", MockTcvectordbClass.describe_collection) + monkeypatch.setattr(Database, "list_collections", MockTcvectordbClass.list_collections) + monkeypatch.setattr(Database, "drop_collection", MockTcvectordbClass.drop_collection) + monkeypatch.setattr(Database, "create_collection", MockTcvectordbClass.create_collection) + monkeypatch.setattr(Collection, "upsert", MockTcvectordbClass.collection_upsert) + monkeypatch.setattr(Collection, "search", MockTcvectordbClass.collection_search) + monkeypatch.setattr(Collection, "query", MockTcvectordbClass.collection_query) + monkeypatch.setattr(Collection, "delete", MockTcvectordbClass.collection_delete) + + yield + + if MOCK: + monkeypatch.undo() diff --git a/api/tests/integration_tests/vdb/__mock/upstashvectordb.py b/api/tests/integration_tests/vdb/__mock/upstashvectordb.py new file mode 100644 index 0000000000000000000000000000000000000000..4b251ba836e40d3015d06d7d1cc6ede6f096c2fc --- /dev/null +++ b/api/tests/integration_tests/vdb/__mock/upstashvectordb.py @@ -0,0 +1,76 @@ +import os +from collections import UserDict +from typing import Optional + +import pytest +from _pytest.monkeypatch import MonkeyPatch +from upstash_vector import Index + + +# Mocking the Index class from upstash_vector +class MockIndex: + def __init__(self, url="", token=""): + self.url = url + self.token = 
token + self.vectors = [] + + def upsert(self, vectors): + for vector in vectors: + vector.score = 0.5 + self.vectors.append(vector) + return {"code": 0, "msg": "operation success", "affectedCount": len(vectors)} + + def fetch(self, ids): + return [vector for vector in self.vectors if vector.id in ids] + + def delete(self, ids): + self.vectors = [vector for vector in self.vectors if vector.id not in ids] + return {"code": 0, "msg": "Success"} + + def query( + self, + vector: None, + top_k: int = 10, + include_vectors: bool = False, + include_metadata: bool = False, + filter: str = "", + data: Optional[str] = None, + namespace: str = "", + include_data: bool = False, + ): + # Simple mock query: returns the first top_k stored vectors; a real index would rank them by similarity + mock_result = [] + for vector_data in self.vectors: + mock_result.append(vector_data) + return mock_result[:top_k] + + def reset(self): + self.vectors = [] + + def info(self): + return AttrDict({"dimension": 1024}) + + +class AttrDict(UserDict): + def __getattr__(self, item): + return self.get(item) + + +MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true" + + +@pytest.fixture +def setup_upstashvector_mock(request, monkeypatch: MonkeyPatch): + if MOCK: + monkeypatch.setattr(Index, "__init__", MockIndex.__init__) + monkeypatch.setattr(Index, "upsert", MockIndex.upsert) + monkeypatch.setattr(Index, "fetch", MockIndex.fetch) + monkeypatch.setattr(Index, "delete", MockIndex.delete) + monkeypatch.setattr(Index, "query", MockIndex.query) + monkeypatch.setattr(Index, "reset", MockIndex.reset) + monkeypatch.setattr(Index, "info", MockIndex.info) + + yield + + if MOCK: + monkeypatch.undo() diff --git a/api/tests/integration_tests/vdb/__mock/vikingdb.py b/api/tests/integration_tests/vdb/__mock/vikingdb.py new file mode 100644 index 0000000000000000000000000000000000000000..3ad72e55501f58569e1dbe12a1ffa87dd19db71f --- /dev/null +++ b/api/tests/integration_tests/vdb/__mock/vikingdb.py @@ -0,0 +1,215 @@ +import os +from typing import 
Union +from unittest.mock import MagicMock + +import pytest +from _pytest.monkeypatch import MonkeyPatch +from volcengine.viking_db import ( # type: ignore + Collection, + Data, + DistanceType, + Field, + FieldType, + Index, + IndexType, + QuantType, + VectorIndexParams, + VikingDBService, +) + +from core.rag.datasource.vdb.field import Field as vdb_Field + + +class MockVikingDBClass: + def __init__( + self, + host="api-vikingdb.volces.com", + region="cn-north-1", + ak="", + sk="", + scheme="http", + connection_timeout=30, + socket_timeout=30, + proxy=None, + ): + self._viking_db_service = MagicMock() + self._viking_db_service.get_exception = MagicMock(return_value='{"data": {"primary_key": "test_id"}}') + + def get_collection(self, collection_name) -> Collection: + return Collection( + collection_name=collection_name, + description="Collection For Dify", + viking_db_service=self._viking_db_service, + primary_key=vdb_Field.PRIMARY_KEY.value, + fields=[ + Field(field_name=vdb_Field.PRIMARY_KEY.value, field_type=FieldType.String, is_primary_key=True), + Field(field_name=vdb_Field.METADATA_KEY.value, field_type=FieldType.String), + Field(field_name=vdb_Field.GROUP_KEY.value, field_type=FieldType.String), + Field(field_name=vdb_Field.CONTENT_KEY.value, field_type=FieldType.Text), + Field(field_name=vdb_Field.VECTOR.value, field_type=FieldType.Vector, dim=768), + ], + indexes=[ + Index( + collection_name=collection_name, + index_name=f"{collection_name}_idx", + vector_index=VectorIndexParams( + distance=DistanceType.L2, + index_type=IndexType.HNSW, + quant=QuantType.Float, + ), + scalar_index=None, + stat=None, + viking_db_service=self._viking_db_service, + ) + ], + ) + + def drop_collection(self, collection_name): + assert collection_name != "" + + def create_collection(self, collection_name, fields, description="") -> Collection: + return Collection( + collection_name=collection_name, + description=description, + primary_key=vdb_Field.PRIMARY_KEY.value, + 
viking_db_service=self._viking_db_service, + fields=fields, + ) + + def get_index(self, collection_name, index_name) -> Index: + return Index( + collection_name=collection_name, + index_name=index_name, + viking_db_service=self._viking_db_service, + stat=None, + scalar_index=None, + vector_index=VectorIndexParams( + distance=DistanceType.L2, + index_type=IndexType.HNSW, + quant=QuantType.Float, + ), + ) + + def create_index( + self, + collection_name, + index_name, + vector_index=None, + cpu_quota=2, + description="", + partition_by="", + scalar_index=None, + shard_count=None, + shard_policy=None, + ): + return Index( + collection_name=collection_name, + index_name=index_name, + vector_index=vector_index, + cpu_quota=cpu_quota, + description=description, + partition_by=partition_by, + scalar_index=scalar_index, + shard_count=shard_count, + shard_policy=shard_policy, + viking_db_service=self._viking_db_service, + stat=None, + ) + + def drop_index(self, collection_name, index_name): + assert collection_name != "" + assert index_name != "" + + def upsert_data(self, data: Union[Data, list[Data]]): + assert data is not None + + def fetch_data(self, id: Union[str, list[str], int, list[int]]): + return Data( + fields={ + vdb_Field.GROUP_KEY.value: "test_group", + vdb_Field.METADATA_KEY.value: "{}", + vdb_Field.CONTENT_KEY.value: "content", + vdb_Field.PRIMARY_KEY.value: id, + vdb_Field.VECTOR.value: [-0.00762577411336441, -0.01949881482151406, 0.008832383941428398], + }, + id=id, + ) + + def delete_data(self, id: Union[str, list[str], int, list[int]]): + assert id is not None + + def search_by_vector( + self, + vector, + sparse_vectors=None, + filter=None, + limit=10, + output_fields=None, + partition="default", + dense_weight=None, + ) -> list[Data]: + return [ + Data( + fields={ + vdb_Field.GROUP_KEY.value: "test_group", + vdb_Field.METADATA_KEY.value: '\ + {"source": "/var/folders/ml/xxx/xxx.txt", \ + "document_id": "test_document_id", \ + "dataset_id": 
"test_dataset_id", \ + "doc_id": "test_id", \ + "doc_hash": "test_hash"}', + vdb_Field.CONTENT_KEY.value: "content", + vdb_Field.PRIMARY_KEY.value: "test_id", + vdb_Field.VECTOR.value: vector, + }, + id="test_id", + score=0.10, + ) + ] + + def search( + self, order=None, filter=None, limit=10, output_fields=None, partition="default", dense_weight=None + ) -> list[Data]: + return [ + Data( + fields={ + vdb_Field.GROUP_KEY.value: "test_group", + vdb_Field.METADATA_KEY.value: '\ + {"source": "/var/folders/ml/xxx/xxx.txt", \ + "document_id": "test_document_id", \ + "dataset_id": "test_dataset_id", \ + "doc_id": "test_id", \ + "doc_hash": "test_hash"}', + vdb_Field.CONTENT_KEY.value: "content", + vdb_Field.PRIMARY_KEY.value: "test_id", + vdb_Field.VECTOR.value: [-0.00762577411336441, -0.01949881482151406, 0.008832383941428398], + }, + id="test_id", + score=0.10, + ) + ] + + +MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true" + + +@pytest.fixture +def setup_vikingdb_mock(monkeypatch: MonkeyPatch): + if MOCK: + monkeypatch.setattr(VikingDBService, "__init__", MockVikingDBClass.__init__) + monkeypatch.setattr(VikingDBService, "get_collection", MockVikingDBClass.get_collection) + monkeypatch.setattr(VikingDBService, "create_collection", MockVikingDBClass.create_collection) + monkeypatch.setattr(VikingDBService, "drop_collection", MockVikingDBClass.drop_collection) + monkeypatch.setattr(VikingDBService, "get_index", MockVikingDBClass.get_index) + monkeypatch.setattr(VikingDBService, "create_index", MockVikingDBClass.create_index) + monkeypatch.setattr(VikingDBService, "drop_index", MockVikingDBClass.drop_index) + monkeypatch.setattr(Collection, "upsert_data", MockVikingDBClass.upsert_data) + monkeypatch.setattr(Collection, "fetch_data", MockVikingDBClass.fetch_data) + monkeypatch.setattr(Collection, "delete_data", MockVikingDBClass.delete_data) + monkeypatch.setattr(Index, "search_by_vector", MockVikingDBClass.search_by_vector) + monkeypatch.setattr(Index, "search", 
MockVikingDBClass.search) + + yield + + if MOCK: + monkeypatch.undo() diff --git a/api/tests/integration_tests/vdb/analyticdb/__init__.py b/api/tests/integration_tests/vdb/analyticdb/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/tests/integration_tests/vdb/analyticdb/test_analyticdb.py b/api/tests/integration_tests/vdb/analyticdb/test_analyticdb.py new file mode 100644 index 0000000000000000000000000000000000000000..5dd4754e8ef19ce04f411aa16356f575cc227f38 --- /dev/null +++ b/api/tests/integration_tests/vdb/analyticdb/test_analyticdb.py @@ -0,0 +1,49 @@ +from core.rag.datasource.vdb.analyticdb.analyticdb_vector import AnalyticdbVector +from core.rag.datasource.vdb.analyticdb.analyticdb_vector_openapi import AnalyticdbVectorOpenAPIConfig +from core.rag.datasource.vdb.analyticdb.analyticdb_vector_sql import AnalyticdbVectorBySqlConfig +from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, setup_mock_redis + + +class AnalyticdbVectorTest(AbstractVectorTest): + def __init__(self, config_type: str): + super().__init__() + # Analyticdb requires collection_name length less than 60. + # it's ok for normal usage. 
+ self.collection_name = self.collection_name.replace("_test", "") + if config_type == "sql": + self.vector = AnalyticdbVector( + collection_name=self.collection_name, + sql_config=AnalyticdbVectorBySqlConfig( + host="test_host", + port=5432, + account="test_account", + account_password="test_passwd", + namespace="difytest_namespace", + ), + api_config=None, + ) + else: + self.vector = AnalyticdbVector( + collection_name=self.collection_name, + sql_config=None, + api_config=AnalyticdbVectorOpenAPIConfig( + access_key_id="test_key_id", + access_key_secret="test_key_secret", + region_id="test_region", + instance_id="test_id", + account="test_account", + account_password="test_passwd", + namespace="difytest_namespace", + collection="difytest_collection", + namespace_password="test_passwd", + ), + ) + + def run_all_tests(self): + self.vector.delete() + return super().run_all_tests() + + +def test_chroma_vector(setup_mock_redis): + AnalyticdbVectorTest("api").run_all_tests() + AnalyticdbVectorTest("sql").run_all_tests() diff --git a/api/tests/integration_tests/vdb/baidu/__init__.py b/api/tests/integration_tests/vdb/baidu/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/tests/integration_tests/vdb/baidu/test_baidu.py b/api/tests/integration_tests/vdb/baidu/test_baidu.py new file mode 100644 index 0000000000000000000000000000000000000000..25989958d9918c1c768b6599d9b4359490bcaec3 --- /dev/null +++ b/api/tests/integration_tests/vdb/baidu/test_baidu.py @@ -0,0 +1,31 @@ +from core.rag.datasource.vdb.baidu.baidu_vector import BaiduConfig, BaiduVector +from tests.integration_tests.vdb.__mock.baiduvectordb import setup_baiduvectordb_mock +from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text, setup_mock_redis + + +class BaiduVectorTest(AbstractVectorTest): + def __init__(self): + super().__init__() + self.vector = BaiduVector( + "dify", + 
BaiduConfig( + endpoint="http://127.0.0.1:5287", + account="root", + api_key="dify", + database="dify", + shard=1, + replicas=3, + ), + ) + + def search_by_vector(self): + hits_by_vector = self.vector.search_by_vector(query_vector=self.example_embedding) + assert len(hits_by_vector) == 1 + + def search_by_full_text(self): + hits_by_full_text = self.vector.search_by_full_text(query=get_example_text()) + assert len(hits_by_full_text) == 0 + + +def test_baidu_vector(setup_mock_redis, setup_baiduvectordb_mock): + BaiduVectorTest().run_all_tests() diff --git a/api/tests/integration_tests/vdb/chroma/__init__.py b/api/tests/integration_tests/vdb/chroma/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/tests/integration_tests/vdb/chroma/test_chroma.py b/api/tests/integration_tests/vdb/chroma/test_chroma.py new file mode 100644 index 0000000000000000000000000000000000000000..ac7b5cbda45b23526c1f4f71ccfc0248c267a676 --- /dev/null +++ b/api/tests/integration_tests/vdb/chroma/test_chroma.py @@ -0,0 +1,33 @@ +import chromadb + +from core.rag.datasource.vdb.chroma.chroma_vector import ChromaConfig, ChromaVector +from tests.integration_tests.vdb.test_vector_store import ( + AbstractVectorTest, + get_example_text, + setup_mock_redis, +) + + +class ChromaVectorTest(AbstractVectorTest): + def __init__(self): + super().__init__() + self.vector = ChromaVector( + collection_name=self.collection_name, + config=ChromaConfig( + host="localhost", + port=8000, + tenant=chromadb.DEFAULT_TENANT, + database=chromadb.DEFAULT_DATABASE, + auth_provider="chromadb.auth.token_authn.TokenAuthClientProvider", + auth_credentials="difyai123456", + ), + ) + + def search_by_full_text(self): + # Chroma does not support full-text search, so no hits are expected + hits_by_full_text = self.vector.search_by_full_text(query=get_example_text()) + assert len(hits_by_full_text) == 0 + + +def test_chroma_vector(setup_mock_redis): + 
ChromaVectorTest().run_all_tests() diff --git a/api/tests/integration_tests/vdb/couchbase/__init__.py b/api/tests/integration_tests/vdb/couchbase/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/tests/integration_tests/vdb/couchbase/test_couchbase.py b/api/tests/integration_tests/vdb/couchbase/test_couchbase.py new file mode 100644 index 0000000000000000000000000000000000000000..d76c34ba0e07c35889845d5241fe5f0ba3c3ab5a --- /dev/null +++ b/api/tests/integration_tests/vdb/couchbase/test_couchbase.py @@ -0,0 +1,50 @@ +import subprocess +import time + +from core.rag.datasource.vdb.couchbase.couchbase_vector import CouchbaseConfig, CouchbaseVector +from tests.integration_tests.vdb.test_vector_store import ( + AbstractVectorTest, + get_example_text, + setup_mock_redis, +) + + +def wait_for_healthy_container(service_name="couchbase-server", timeout=300): + start_time = time.time() + while time.time() - start_time < timeout: + result = subprocess.run( + ["docker", "inspect", "--format", "{{.State.Health.Status}}", service_name], capture_output=True, text=True + ) + if result.stdout.strip() == "healthy": + print(f"{service_name} is healthy!") + return True + else: + print(f"Waiting for {service_name} to be healthy...") + time.sleep(10) + raise TimeoutError(f"{service_name} did not become healthy in time") + + +class CouchbaseTest(AbstractVectorTest): + def __init__(self): + super().__init__() + self.vector = CouchbaseVector( + collection_name=self.collection_name, + config=CouchbaseConfig( + connection_string="couchbase://127.0.0.1", + user="Administrator", + password="password", + bucket_name="Embeddings", + scope_name="_default", + ), + ) + + def search_by_vector(self): + # brief sleep to ensure document is indexed + time.sleep(5) + hits_by_vector = self.vector.search_by_vector(query_vector=self.example_embedding) + assert len(hits_by_vector) == 1 + + +def 
test_couchbase(setup_mock_redis): + wait_for_healthy_container("couchbase-server", timeout=60) + CouchbaseTest().run_all_tests() diff --git a/api/tests/integration_tests/vdb/elasticsearch/__init__.py b/api/tests/integration_tests/vdb/elasticsearch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/tests/integration_tests/vdb/elasticsearch/test_elasticsearch.py b/api/tests/integration_tests/vdb/elasticsearch/test_elasticsearch.py new file mode 100644 index 0000000000000000000000000000000000000000..2a0c1bb03891875ec99cbeb3625a12d83f0fe313 --- /dev/null +++ b/api/tests/integration_tests/vdb/elasticsearch/test_elasticsearch.py @@ -0,0 +1,20 @@ +from core.rag.datasource.vdb.elasticsearch.elasticsearch_vector import ElasticSearchConfig, ElasticSearchVector +from tests.integration_tests.vdb.test_vector_store import ( + AbstractVectorTest, + setup_mock_redis, +) + + +class ElasticSearchVectorTest(AbstractVectorTest): + def __init__(self): + super().__init__() + self.attributes = ["doc_id", "dataset_id", "document_id", "doc_hash"] + self.vector = ElasticSearchVector( + index_name=self.collection_name.lower(), + config=ElasticSearchConfig(host="http://localhost", port="9200", username="elastic", password="elastic"), + attributes=self.attributes, + ) + + +def test_elasticsearch_vector(setup_mock_redis): + ElasticSearchVectorTest().run_all_tests() diff --git a/api/tests/integration_tests/vdb/lindorm/__init__.py b/api/tests/integration_tests/vdb/lindorm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/tests/integration_tests/vdb/lindorm/test_lindorm.py b/api/tests/integration_tests/vdb/lindorm/test_lindorm.py new file mode 100644 index 0000000000000000000000000000000000000000..0a26d3ea1c998774d9f354304c8037b35fd60cbe --- /dev/null +++ b/api/tests/integration_tests/vdb/lindorm/test_lindorm.py @@ -0,0 
+1,58 @@ +import environs + +from core.rag.datasource.vdb.lindorm.lindorm_vector import LindormVectorStore, LindormVectorStoreConfig +from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, setup_mock_redis + +env = environs.Env() + + +class Config: + SEARCH_ENDPOINT = env.str("SEARCH_ENDPOINT", "http://ld-************-proxy-search-pub.lindorm.aliyuncs.com:30070") + SEARCH_USERNAME = env.str("SEARCH_USERNAME", "ADMIN") + SEARCH_PWD = env.str("SEARCH_PWD", "ADMIN") + USING_UGC = env.bool("USING_UGC", True) + + +class TestLindormVectorStore(AbstractVectorTest): + def __init__(self): + super().__init__() + self.vector = LindormVectorStore( + collection_name=self.collection_name, + config=LindormVectorStoreConfig( + hosts=Config.SEARCH_ENDPOINT, + username=Config.SEARCH_USERNAME, + password=Config.SEARCH_PWD, + ), + ) + + def get_ids_by_metadata_field(self): + ids = self.vector.get_ids_by_metadata_field(key="doc_id", value=self.example_doc_id) + assert ids is not None + assert len(ids) == 1 + assert ids[0] == self.example_doc_id + + +class TestLindormVectorStoreUGC(AbstractVectorTest): + def __init__(self): + super().__init__() + self.vector = LindormVectorStore( + collection_name="ugc_index_test", + config=LindormVectorStoreConfig( + hosts=Config.SEARCH_ENDPOINT, + username=Config.SEARCH_USERNAME, + password=Config.SEARCH_PWD, + using_ugc=Config.USING_UGC, + ), + routing_value=self.collection_name, + ) + + def get_ids_by_metadata_field(self): + ids = self.vector.get_ids_by_metadata_field(key="doc_id", value=self.example_doc_id) + assert ids is not None + assert len(ids) == 1 + assert ids[0] == self.example_doc_id + + +def test_lindorm_vector_ugc(setup_mock_redis): + TestLindormVectorStore().run_all_tests() + TestLindormVectorStoreUGC().run_all_tests() diff --git a/api/tests/integration_tests/vdb/milvus/__init__.py b/api/tests/integration_tests/vdb/milvus/__init__.py new file mode 100644 index 
0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/tests/integration_tests/vdb/milvus/test_milvus.py b/api/tests/integration_tests/vdb/milvus/test_milvus.py new file mode 100644 index 0000000000000000000000000000000000000000..0e13f9369e8c94d77c3d62f4e37f4f2509f8066d --- /dev/null +++ b/api/tests/integration_tests/vdb/milvus/test_milvus.py @@ -0,0 +1,32 @@ +from core.rag.datasource.vdb.milvus.milvus_vector import MilvusConfig, MilvusVector +from tests.integration_tests.vdb.test_vector_store import ( + AbstractVectorTest, + get_example_text, + setup_mock_redis, +) + + +class MilvusVectorTest(AbstractVectorTest): + def __init__(self): + super().__init__() + self.vector = MilvusVector( + collection_name=self.collection_name, + config=MilvusConfig( + uri="http://localhost:19530", + user="root", + password="Milvus", + ), + ) + + def search_by_full_text(self): + # Milvus supports BM25 full-text search after version 2.5.0-beta; NOTE(review): `>= 0` is always true, so this only checks that the call does not raise + hits_by_full_text = self.vector.search_by_full_text(query=get_example_text()) + assert len(hits_by_full_text) >= 0 + + def get_ids_by_metadata_field(self): + ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id) + assert len(ids) == 1 + + +def test_milvus_vector(setup_mock_redis): + MilvusVectorTest().run_all_tests() diff --git a/api/tests/integration_tests/vdb/myscale/__init__.py b/api/tests/integration_tests/vdb/myscale/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/tests/integration_tests/vdb/myscale/test_myscale.py b/api/tests/integration_tests/vdb/myscale/test_myscale.py new file mode 100644 index 0000000000000000000000000000000000000000..55b2fde42761052b38ae6dfb6a3d65cc22e88e2d --- /dev/null +++ b/api/tests/integration_tests/vdb/myscale/test_myscale.py @@ -0,0 +1,29 @@ +from core.rag.datasource.vdb.myscale.myscale_vector import MyScaleConfig, MyScaleVector +from 
tests.integration_tests.vdb.test_vector_store import ( + AbstractVectorTest, + setup_mock_redis, +) + + +class MyScaleVectorTest(AbstractVectorTest): + def __init__(self): + super().__init__() + self.vector = MyScaleVector( + collection_name=self.collection_name, + config=MyScaleConfig( + host="localhost", + port=8123, + user="default", + password="", + database="dify", + fts_params="", + ), + ) + + def get_ids_by_metadata_field(self): + ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id) + assert len(ids) == 1 + + +def test_myscale_vector(setup_mock_redis): + MyScaleVectorTest().run_all_tests() diff --git a/api/tests/integration_tests/vdb/oceanbase/__init__.py b/api/tests/integration_tests/vdb/oceanbase/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/tests/integration_tests/vdb/oceanbase/test_oceanbase.py b/api/tests/integration_tests/vdb/oceanbase/test_oceanbase.py new file mode 100644 index 0000000000000000000000000000000000000000..ebcb1341683c3c1ea93231c6dc01bc748a02d881 --- /dev/null +++ b/api/tests/integration_tests/vdb/oceanbase/test_oceanbase.py @@ -0,0 +1,71 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from core.rag.datasource.vdb.oceanbase.oceanbase_vector import ( + OceanBaseVector, + OceanBaseVectorConfig, +) +from tests.integration_tests.vdb.__mock.tcvectordb import setup_tcvectordb_mock +from tests.integration_tests.vdb.test_vector_store import ( + AbstractVectorTest, + get_example_text, + setup_mock_redis, +) + + +@pytest.fixture +def oceanbase_vector(): + return OceanBaseVector( + "dify_test_collection", + config=OceanBaseVectorConfig( + host="127.0.0.1", + port="2881", + user="root@test", + database="test", + password="test", + ), + ) + + +class OceanBaseVectorTest(AbstractVectorTest): + def __init__(self, vector: OceanBaseVector): + super().__init__() + self.vector = vector + + def 
search_by_vector(self): + hits_by_vector = self.vector.search_by_vector(query_vector=self.example_embedding) + assert len(hits_by_vector) == 0 + + def search_by_full_text(self): + hits_by_full_text = self.vector.search_by_full_text(query=get_example_text()) + assert len(hits_by_full_text) == 0 + + def text_exists(self): + exist = self.vector.text_exists(self.example_doc_id) + assert exist == True + + def get_ids_by_metadata_field(self): + ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id) + assert len(ids) == 0 + + +@pytest.fixture +def setup_mock_oceanbase_client(): + with patch("core.rag.datasource.vdb.oceanbase.oceanbase_vector.ObVecClient", new_callable=MagicMock) as mock_client: + yield mock_client + + +@pytest.fixture +def setup_mock_oceanbase_vector(oceanbase_vector): + with patch.object(oceanbase_vector, "_client"): + yield oceanbase_vector + + +def test_oceanbase_vector( + setup_mock_redis, + setup_mock_oceanbase_client, + setup_mock_oceanbase_vector, + oceanbase_vector, +): + OceanBaseVectorTest(oceanbase_vector).run_all_tests() diff --git a/api/tests/integration_tests/vdb/opensearch/__init__.py b/api/tests/integration_tests/vdb/opensearch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/tests/integration_tests/vdb/opensearch/test_opensearch.py b/api/tests/integration_tests/vdb/opensearch/test_opensearch.py new file mode 100644 index 0000000000000000000000000000000000000000..35eed75c2f7e69cc8e560516439513092c615f7a --- /dev/null +++ b/api/tests/integration_tests/vdb/opensearch/test_opensearch.py @@ -0,0 +1,158 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from core.rag.datasource.vdb.field import Field +from core.rag.datasource.vdb.opensearch.opensearch_vector import OpenSearchConfig, OpenSearchVector +from core.rag.models.document import Document +from extensions import ext_redis + + +def get_example_text() 
-> str: + return "This is a sample text for testing purposes." + + +@pytest.fixture(scope="module") +def setup_mock_redis(): + ext_redis.redis_client.get = MagicMock(return_value=None) + ext_redis.redis_client.set = MagicMock(return_value=None) + + mock_redis_lock = MagicMock() + mock_redis_lock.__enter__ = MagicMock() + mock_redis_lock.__exit__ = MagicMock() + ext_redis.redis_client.lock = MagicMock(return_value=mock_redis_lock) + + +class TestOpenSearchVector: + def setup_method(self): + self.collection_name = "test_collection" + self.example_doc_id = "example_doc_id" + self.vector = OpenSearchVector( + collection_name=self.collection_name, + config=OpenSearchConfig(host="localhost", port=9200, user="admin", password="password", secure=False), + ) + self.vector._client = MagicMock() + + @pytest.mark.parametrize( + ("search_response", "expected_length", "expected_doc_id"), + [ + ( + { + "hits": { + "total": {"value": 1}, + "hits": [ + { + "_source": { + "page_content": get_example_text(), + "metadata": {"document_id": "example_doc_id"}, + } + } + ], + } + }, + 1, + "example_doc_id", + ), + ({"hits": {"total": {"value": 0}, "hits": []}}, 0, None), + ], + ) + def test_search_by_full_text(self, search_response, expected_length, expected_doc_id): + self.vector._client.search.return_value = search_response + + hits_by_full_text = self.vector.search_by_full_text(query=get_example_text()) + assert len(hits_by_full_text) == expected_length + if expected_length > 0: + assert hits_by_full_text[0].metadata["document_id"] == expected_doc_id + + def test_search_by_vector(self): + vector = [0.1] * 128 + mock_response = { + "hits": { + "total": {"value": 1}, + "hits": [ + { + "_source": { + Field.CONTENT_KEY.value: get_example_text(), + Field.METADATA_KEY.value: {"document_id": self.example_doc_id}, + }, + "_score": 1.0, + } + ], + } + } + self.vector._client.search.return_value = mock_response + + hits_by_vector = self.vector.search_by_vector(query_vector=vector) + + 
print("Hits by vector:", hits_by_vector) + print("Expected document ID:", self.example_doc_id) + print("Actual document ID:", hits_by_vector[0].metadata["document_id"] if hits_by_vector else "No hits") + + assert len(hits_by_vector) > 0, f"Expected at least one hit, got {len(hits_by_vector)}" + assert hits_by_vector[0].metadata["document_id"] == self.example_doc_id, ( + f"Expected document ID {self.example_doc_id}, got {hits_by_vector[0].metadata['document_id']}" + ) + + def test_get_ids_by_metadata_field(self): + mock_response = {"hits": {"total": {"value": 1}, "hits": [{"_id": "mock_id"}]}} + self.vector._client.search.return_value = mock_response + + doc = Document(page_content="Test content", metadata={"document_id": self.example_doc_id}) + embedding = [0.1] * 128 + + with patch("opensearchpy.helpers.bulk") as mock_bulk: + mock_bulk.return_value = ([], []) + self.vector.add_texts([doc], [embedding]) + + ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id) + assert len(ids) == 1 + assert ids[0] == "mock_id" + + def test_add_texts(self): + self.vector._client.index.return_value = {"result": "created"} + + doc = Document(page_content="Test content", metadata={"document_id": self.example_doc_id}) + embedding = [0.1] * 128 + + with patch("opensearchpy.helpers.bulk") as mock_bulk: + mock_bulk.return_value = ([], []) + self.vector.add_texts([doc], [embedding]) + + mock_response = {"hits": {"total": {"value": 1}, "hits": [{"_id": "mock_id"}]}} + self.vector._client.search.return_value = mock_response + + ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id) + assert len(ids) == 1 + assert ids[0] == "mock_id" + + +@pytest.mark.usefixtures("setup_mock_redis") +class TestOpenSearchVectorWithRedis: + def setup_method(self): + self.tester = TestOpenSearchVector() + + def test_search_by_full_text(self): + self.tester.setup_method() + search_response = { + "hits": { + "total": {"value": 1}, + "hits": [ + 
{"_source": {"page_content": get_example_text(), "metadata": {"document_id": "example_doc_id"}}} + ], + } + } + expected_length = 1 + expected_doc_id = "example_doc_id" + self.tester.test_search_by_full_text(search_response, expected_length, expected_doc_id) + + def test_get_ids_by_metadata_field(self): + self.tester.setup_method() + self.tester.test_get_ids_by_metadata_field() + + def test_add_texts(self): + self.tester.setup_method() + self.tester.test_add_texts() + + def test_search_by_vector(self): + self.tester.setup_method() + self.tester.test_search_by_vector() diff --git a/api/tests/integration_tests/vdb/oracle/__init__.py b/api/tests/integration_tests/vdb/oracle/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/tests/integration_tests/vdb/oracle/test_oraclevector.py b/api/tests/integration_tests/vdb/oracle/test_oraclevector.py new file mode 100644 index 0000000000000000000000000000000000000000..3252b0427609c67d52b01187af5afaa49226008e --- /dev/null +++ b/api/tests/integration_tests/vdb/oracle/test_oraclevector.py @@ -0,0 +1,30 @@ +from core.rag.datasource.vdb.oracle.oraclevector import OracleVector, OracleVectorConfig +from core.rag.models.document import Document +from tests.integration_tests.vdb.test_vector_store import ( + AbstractVectorTest, + get_example_text, + setup_mock_redis, +) + + +class OracleVectorTest(AbstractVectorTest): + def __init__(self): + super().__init__() + self.vector = OracleVector( + collection_name=self.collection_name, + config=OracleVectorConfig( + host="localhost", + port=1521, + user="dify", + password="dify", + database="FREEPDB1", + ), + ) + + def search_by_full_text(self): + hits_by_full_text: list[Document] = self.vector.search_by_full_text(query=get_example_text()) + assert len(hits_by_full_text) == 0 + + +def test_oraclevector(setup_mock_redis): + OracleVectorTest().run_all_tests() diff --git 
a/api/tests/integration_tests/vdb/pgvecto_rs/__init__.py b/api/tests/integration_tests/vdb/pgvecto_rs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/tests/integration_tests/vdb/pgvecto_rs/test_pgvecto_rs.py b/api/tests/integration_tests/vdb/pgvecto_rs/test_pgvecto_rs.py new file mode 100644 index 0000000000000000000000000000000000000000..6497f47deb99fe5b8c60c73ca9463941d8d525ec --- /dev/null +++ b/api/tests/integration_tests/vdb/pgvecto_rs/test_pgvecto_rs.py @@ -0,0 +1,35 @@ +from core.rag.datasource.vdb.pgvecto_rs.pgvecto_rs import PGVectoRS, PgvectoRSConfig +from tests.integration_tests.vdb.test_vector_store import ( + AbstractVectorTest, + get_example_text, + setup_mock_redis, +) + + +class PGVectoRSVectorTest(AbstractVectorTest): + def __init__(self): + super().__init__() + self.vector = PGVectoRS( + collection_name=self.collection_name.lower(), + config=PgvectoRSConfig( + host="localhost", + port=5431, + user="postgres", + password="difyai123456", + database="dify", + ), + dim=128, + ) + + def search_by_full_text(self): + # pgvecto.rs only supports English text search, so it is not enabled for now; no hits are expected + hits_by_full_text = self.vector.search_by_full_text(query=get_example_text()) + assert len(hits_by_full_text) == 0 + + def get_ids_by_metadata_field(self): + ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id) + assert len(ids) == 1 + + +def test_pgvecto_rs(setup_mock_redis): + PGVectoRSVectorTest().run_all_tests() diff --git a/api/tests/integration_tests/vdb/pgvector/__init__.py b/api/tests/integration_tests/vdb/pgvector/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/tests/integration_tests/vdb/pgvector/test_pgvector.py b/api/tests/integration_tests/vdb/pgvector/test_pgvector.py new file mode 100644 index 
0000000000000000000000000000000000000000..3d2cfde5d1dc45c58d76c3280dcb775df2d8c4ee --- /dev/null +++ b/api/tests/integration_tests/vdb/pgvector/test_pgvector.py @@ -0,0 +1,27 @@ +from core.rag.datasource.vdb.pgvector.pgvector import PGVector, PGVectorConfig +from tests.integration_tests.vdb.test_vector_store import ( + AbstractVectorTest, + get_example_text, + setup_mock_redis, +) + + +class PGVectorTest(AbstractVectorTest): + def __init__(self): + super().__init__() + self.vector = PGVector( + collection_name=self.collection_name, + config=PGVectorConfig( + host="localhost", + port=5433, + user="postgres", + password="difyai123456", + database="dify", + min_connection=1, + max_connection=5, + ), + ) + + +def test_pgvector(setup_mock_redis): + PGVectorTest().run_all_tests() diff --git a/api/tests/integration_tests/vdb/qdrant/__init__.py b/api/tests/integration_tests/vdb/qdrant/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/tests/integration_tests/vdb/qdrant/test_qdrant.py b/api/tests/integration_tests/vdb/qdrant/test_qdrant.py new file mode 100644 index 0000000000000000000000000000000000000000..61d9a9e712aade1f63ae2851bba3405e44468455 --- /dev/null +++ b/api/tests/integration_tests/vdb/qdrant/test_qdrant.py @@ -0,0 +1,23 @@ +from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantConfig, QdrantVector +from tests.integration_tests.vdb.test_vector_store import ( + AbstractVectorTest, + setup_mock_redis, +) + + +class QdrantVectorTest(AbstractVectorTest): + def __init__(self): + super().__init__() + self.attributes = ["doc_id", "dataset_id", "document_id", "doc_hash"] + self.vector = QdrantVector( + collection_name=self.collection_name, + group_id=self.dataset_id, + config=QdrantConfig( + endpoint="http://localhost:6333", + api_key="difyai123456", + ), + ) + + +def test_qdrant_vector(setup_mock_redis): + QdrantVectorTest().run_all_tests() diff --git 
a/api/tests/integration_tests/vdb/tcvectordb/__init__.py b/api/tests/integration_tests/vdb/tcvectordb/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/tests/integration_tests/vdb/tcvectordb/test_tencent.py b/api/tests/integration_tests/vdb/tcvectordb/test_tencent.py new file mode 100644 index 0000000000000000000000000000000000000000..1b9466e27f4c8ff5eb19de551510baec7b58c261 --- /dev/null +++ b/api/tests/integration_tests/vdb/tcvectordb/test_tencent.py @@ -0,0 +1,37 @@ +from unittest.mock import MagicMock + +from core.rag.datasource.vdb.tencent.tencent_vector import TencentConfig, TencentVector +from tests.integration_tests.vdb.__mock.tcvectordb import setup_tcvectordb_mock +from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text, setup_mock_redis + +mock_client = MagicMock() +mock_client.list_databases.return_value = [{"name": "test"}] + + +class TencentVectorTest(AbstractVectorTest): + def __init__(self): + super().__init__() + self.vector = TencentVector( + "dify", + TencentConfig( + url="http://127.0.0.1", + api_key="dify", + timeout=30, + username="dify", + database="dify", + shard=1, + replicas=2, + ), + ) + + def search_by_vector(self): + hits_by_vector = self.vector.search_by_vector(query_vector=self.example_embedding) + assert len(hits_by_vector) == 1 + + def search_by_full_text(self): + hits_by_full_text = self.vector.search_by_full_text(query=get_example_text()) + assert len(hits_by_full_text) == 0 + + +def test_tencent_vector(setup_mock_redis, setup_tcvectordb_mock): + TencentVectorTest().run_all_tests() diff --git a/api/tests/integration_tests/vdb/test_vector_store.py b/api/tests/integration_tests/vdb/test_vector_store.py new file mode 100644 index 0000000000000000000000000000000000000000..50519e2052cd889774a84b1a9303cf8169dfaaca --- /dev/null +++ b/api/tests/integration_tests/vdb/test_vector_store.py @@ -0,0 +1,95 @@ +import 
uuid +from unittest.mock import MagicMock + +import pytest + +from core.rag.models.document import Document +from extensions import ext_redis +from models.dataset import Dataset + + +def get_example_text() -> str: + return "test_text" + + +def get_example_document(doc_id: str) -> Document: + doc = Document( + page_content=get_example_text(), + metadata={ + "doc_id": doc_id, + "doc_hash": doc_id, + "document_id": doc_id, + "dataset_id": doc_id, + }, + ) + return doc + + +@pytest.fixture +def setup_mock_redis() -> None: + # get + ext_redis.redis_client.get = MagicMock(return_value=None) + + # set + ext_redis.redis_client.set = MagicMock(return_value=None) + + # lock + mock_redis_lock = MagicMock() + mock_redis_lock.__enter__ = MagicMock() + mock_redis_lock.__exit__ = MagicMock() + ext_redis.redis_client.lock = mock_redis_lock + + +class AbstractVectorTest: + def __init__(self): + self.vector = None + self.dataset_id = str(uuid.uuid4()) + self.collection_name = Dataset.gen_collection_name_by_id(self.dataset_id) + "_test" + self.example_doc_id = str(uuid.uuid4()) + self.example_embedding = [1.001 * i for i in range(128)] + + def create_vector(self) -> None: + self.vector.create( + texts=[get_example_document(doc_id=self.example_doc_id)], + embeddings=[self.example_embedding], + ) + + def search_by_vector(self): + hits_by_vector: list[Document] = self.vector.search_by_vector(query_vector=self.example_embedding) + assert len(hits_by_vector) == 1 + assert hits_by_vector[0].metadata["doc_id"] == self.example_doc_id + + def search_by_full_text(self): + hits_by_full_text: list[Document] = self.vector.search_by_full_text(query=get_example_text()) + assert len(hits_by_full_text) == 1 + assert hits_by_full_text[0].metadata["doc_id"] == self.example_doc_id + + def delete_vector(self): + self.vector.delete() + + def delete_by_ids(self, ids: list[str]): + self.vector.delete_by_ids(ids=ids) + + def add_texts(self) -> list[str]: + batch_size = 100 + documents = 
[get_example_document(doc_id=str(uuid.uuid4())) for _ in range(batch_size)] + embeddings = [self.example_embedding] * batch_size + self.vector.add_texts(documents=documents, embeddings=embeddings) + return [doc.metadata["doc_id"] for doc in documents] + + def text_exists(self): + assert self.vector.text_exists(self.example_doc_id) + + def get_ids_by_metadata_field(self): + with pytest.raises(NotImplementedError): + self.vector.get_ids_by_metadata_field(key="key", value="value") + + def run_all_tests(self): + self.create_vector() + self.search_by_vector() + self.search_by_full_text() + self.text_exists() + self.get_ids_by_metadata_field() + added_doc_ids = self.add_texts() + self.delete_by_ids(added_doc_ids) + self.delete_vector() diff --git a/api/tests/integration_tests/vdb/tidb_vector/__init__.py b/api/tests/integration_tests/vdb/tidb_vector/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/tests/integration_tests/vdb/tidb_vector/test_tidb_vector.py b/api/tests/integration_tests/vdb/tidb_vector/test_tidb_vector.py new file mode 100644 index 0000000000000000000000000000000000000000..df0bb3f81aea56ddee590f47df8038fbbb43f22d --- /dev/null +++ b/api/tests/integration_tests/vdb/tidb_vector/test_tidb_vector.py @@ -0,0 +1,38 @@ +import pytest + +from core.rag.datasource.vdb.tidb_vector.tidb_vector import TiDBVector, TiDBVectorConfig +from models.dataset import Document +from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text, setup_mock_redis + + +@pytest.fixture +def tidb_vector(): + return TiDBVector( + collection_name="test_collection", + config=TiDBVectorConfig( + host="localhost", + port=4000, + user="root", + password="", + database="test", + program_name="langgenius/dify", + ), + ) + + +class TiDBVectorTest(AbstractVectorTest): + def __init__(self, vector): + super().__init__() + self.vector = vector + + def search_by_full_text(self): + 
hits_by_full_text: list[Document] = self.vector.search_by_full_text(query=get_example_text()) + assert len(hits_by_full_text) == 0 + + def get_ids_by_metadata_field(self): + ids = self.vector.get_ids_by_metadata_field(key="doc_id", value=self.example_doc_id) + assert len(ids) == 1 + + +def test_tidb_vector(setup_mock_redis, tidb_vector): + TiDBVectorTest(vector=tidb_vector).run_all_tests() diff --git a/api/tests/integration_tests/vdb/upstash/__init__.py b/api/tests/integration_tests/vdb/upstash/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/tests/integration_tests/vdb/upstash/test_upstash_vector.py b/api/tests/integration_tests/vdb/upstash/test_upstash_vector.py new file mode 100644 index 0000000000000000000000000000000000000000..23470474ff647f26209388edfc2b7d466b6ec571 --- /dev/null +++ b/api/tests/integration_tests/vdb/upstash/test_upstash_vector.py @@ -0,0 +1,28 @@ +from core.rag.datasource.vdb.upstash.upstash_vector import UpstashVector, UpstashVectorConfig +from core.rag.models.document import Document +from tests.integration_tests.vdb.__mock.upstashvectordb import setup_upstashvector_mock +from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text + + +class UpstashVectorTest(AbstractVectorTest): + def __init__(self): + super().__init__() + self.vector = UpstashVector( + collection_name="test_collection", + config=UpstashVectorConfig( + url="your-server-url", + token="your-access-token", + ), + ) + + def get_ids_by_metadata_field(self): + ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id) + assert len(ids) != 0 + + def search_by_full_text(self): + hits_by_full_text: list[Document] = self.vector.search_by_full_text(query=get_example_text()) + assert len(hits_by_full_text) == 0 + + +def test_upstash_vector(setup_upstashvector_mock): + UpstashVectorTest().run_all_tests() diff --git 
a/api/tests/integration_tests/vdb/vikingdb/__init__.py b/api/tests/integration_tests/vdb/vikingdb/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/tests/integration_tests/vdb/vikingdb/test_vikingdb.py b/api/tests/integration_tests/vdb/vikingdb/test_vikingdb.py new file mode 100644 index 0000000000000000000000000000000000000000..2572012ea03aa1e4593495cf0c6190d0fc6c00fd --- /dev/null +++ b/api/tests/integration_tests/vdb/vikingdb/test_vikingdb.py @@ -0,0 +1,37 @@ +from core.rag.datasource.vdb.vikingdb.vikingdb_vector import VikingDBConfig, VikingDBVector +from tests.integration_tests.vdb.__mock.vikingdb import setup_vikingdb_mock +from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text, setup_mock_redis + + +class VikingDBVectorTest(AbstractVectorTest): + def __init__(self): + super().__init__() + self.vector = VikingDBVector( + "test_collection", + "test_group", + config=VikingDBConfig( + access_key="test_access_key", + host="test_host", + region="test_region", + scheme="test_scheme", + secret_key="test_secret_key", + connection_timeout=30, + socket_timeout=30, + ), + ) + + def search_by_vector(self): + hits_by_vector = self.vector.search_by_vector(query_vector=self.example_embedding) + assert len(hits_by_vector) == 1 + + def search_by_full_text(self): + hits_by_full_text = self.vector.search_by_full_text(query=get_example_text()) + assert len(hits_by_full_text) == 0 + + def get_ids_by_metadata_field(self): + ids = self.vector.get_ids_by_metadata_field(key="document_id", value="test_document_id") + assert len(ids) > 0 + + +def test_vikingdb_vector(setup_mock_redis, setup_vikingdb_mock): + VikingDBVectorTest().run_all_tests() diff --git a/api/tests/integration_tests/vdb/weaviate/__init__.py b/api/tests/integration_tests/vdb/weaviate/__init__.py new file mode 100644 index 
0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/tests/integration_tests/vdb/weaviate/test_weaviate.py b/api/tests/integration_tests/vdb/weaviate/test_weaviate.py new file mode 100644 index 0000000000000000000000000000000000000000..a6f55420d312eed2304b49b96ed592786a23a70c --- /dev/null +++ b/api/tests/integration_tests/vdb/weaviate/test_weaviate.py @@ -0,0 +1,23 @@ +from core.rag.datasource.vdb.weaviate.weaviate_vector import WeaviateConfig, WeaviateVector +from tests.integration_tests.vdb.test_vector_store import ( + AbstractVectorTest, + setup_mock_redis, +) + + +class WeaviateVectorTest(AbstractVectorTest): + def __init__(self): + super().__init__() + self.attributes = ["doc_id", "dataset_id", "document_id", "doc_hash"] + self.vector = WeaviateVector( + collection_name=self.collection_name, + config=WeaviateConfig( + endpoint="http://localhost:8080", + api_key="WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih", + ), + attributes=self.attributes, + ) + + +def test_weaviate_vector(setup_mock_redis): + WeaviateVectorTest().run_all_tests() diff --git a/api/tests/integration_tests/workflow/__init__.py b/api/tests/integration_tests/workflow/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/tests/integration_tests/workflow/nodes/__init__.py b/api/tests/integration_tests/workflow/nodes/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/tests/integration_tests/workflow/nodes/__mock/code_executor.py b/api/tests/integration_tests/workflow/nodes/__mock/code_executor.py new file mode 100644 index 0000000000000000000000000000000000000000..30414811ea7986a6bf5a91fbb14a4448b2854608 --- /dev/null +++ b/api/tests/integration_tests/workflow/nodes/__mock/code_executor.py @@ -0,0 +1,34 @@ +import os +from typing import Literal + +import pytest +from _pytest.monkeypatch 
import MonkeyPatch +from jinja2 import Template + +from core.helper.code_executor.code_executor import CodeExecutor, CodeLanguage + +MOCK = os.getenv("MOCK_SWITCH", "false") == "true" + + +class MockedCodeExecutor: + @classmethod + def invoke(cls, language: Literal["python3", "javascript", "jinja2"], code: str, inputs: dict) -> dict: + # invoke directly + match language: + case CodeLanguage.PYTHON3: + return {"result": 3} + case CodeLanguage.JINJA2: + return {"result": Template(code).render(inputs)} + case _: + raise Exception("Language not supported") + + +@pytest.fixture +def setup_code_executor_mock(request, monkeypatch: MonkeyPatch): + if not MOCK: + yield + return + + monkeypatch.setattr(CodeExecutor, "execute_workflow_code_template", MockedCodeExecutor.invoke) + yield + monkeypatch.undo() diff --git a/api/tests/integration_tests/workflow/nodes/__mock/http.py b/api/tests/integration_tests/workflow/nodes/__mock/http.py new file mode 100644 index 0000000000000000000000000000000000000000..f08d270b4bfce381853ecdcec958ec9457db226e --- /dev/null +++ b/api/tests/integration_tests/workflow/nodes/__mock/http.py @@ -0,0 +1,56 @@ +import os +from json import dumps +from typing import Literal + +import httpx +import pytest +from _pytest.monkeypatch import MonkeyPatch + +from core.helper import ssrf_proxy + +MOCK = os.getenv("MOCK_SWITCH", "false") == "true" + + +class MockedHttp: + @staticmethod + def httpx_request( + method: Literal["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD"], url: str, **kwargs + ) -> httpx.Response: + """ + Mocked httpx.request + """ + if url == "http://404.com": + response = httpx.Response(status_code=404, request=httpx.Request(method, url), content=b"Not Found") + return response + + # get data, files + data = kwargs.get("data") + files = kwargs.get("files") + json = kwargs.get("json") + content = kwargs.get("content") + if data is not None: + resp = dumps(data).encode("utf-8") + elif files is not None: + resp = dumps(files).encode("utf-8") + 
elif json is not None: + resp = dumps(json).encode("utf-8") + elif content is not None: + resp = content + else: + resp = b"OK" + + response = httpx.Response( + status_code=200, request=httpx.Request(method, url), headers=kwargs.get("headers", {}), content=resp + ) + return response + + +@pytest.fixture +def setup_http_mock(request, monkeypatch: MonkeyPatch): + if not MOCK: + yield + return + + monkeypatch.setattr(ssrf_proxy, "make_request", MockedHttp.httpx_request) + yield + monkeypatch.undo() diff --git a/api/tests/integration_tests/workflow/nodes/code_executor/__init__.py b/api/tests/integration_tests/workflow/nodes/code_executor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/tests/integration_tests/workflow/nodes/code_executor/test_code_executor.py b/api/tests/integration_tests/workflow/nodes/code_executor/test_code_executor.py new file mode 100644 index 0000000000000000000000000000000000000000..487178ff58066e79f43f2956050956cc357dbc4b --- /dev/null +++ b/api/tests/integration_tests/workflow/nodes/code_executor/test_code_executor.py @@ -0,0 +1,11 @@ +import pytest + +from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor + +CODE_LANGUAGE = "unsupported_language" + + +def test_unsupported_with_code_template(): + with pytest.raises(CodeExecutionError) as e: + CodeExecutor.execute_workflow_code_template(language=CODE_LANGUAGE, code="", inputs={}) + assert str(e.value) == f"Unsupported language {CODE_LANGUAGE}" diff --git a/api/tests/integration_tests/workflow/nodes/code_executor/test_code_javascript.py b/api/tests/integration_tests/workflow/nodes/code_executor/test_code_javascript.py new file mode 100644 index 0000000000000000000000000000000000000000..09fcb68cf032d4a3dac57399c715a74b6d04863e --- /dev/null +++ b/api/tests/integration_tests/workflow/nodes/code_executor/test_code_javascript.py @@ -0,0 +1,38 @@ +from textwrap import dedent + +from 
core.helper.code_executor.code_executor import CodeExecutor, CodeLanguage +from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider +from core.helper.code_executor.javascript.javascript_transformer import NodeJsTemplateTransformer + +CODE_LANGUAGE = CodeLanguage.JAVASCRIPT + + +def test_javascript_plain(): + code = 'console.log("Hello World")' + result_message = CodeExecutor.execute_code(language=CODE_LANGUAGE, preload="", code=code) + assert result_message == "Hello World\n" + + +def test_javascript_json(): + code = dedent(""" + obj = {'Hello': 'World'} + console.log(JSON.stringify(obj)) + """) + result = CodeExecutor.execute_code(language=CODE_LANGUAGE, preload="", code=code) + assert result == '{"Hello":"World"}\n' + + +def test_javascript_with_code_template(): + result = CodeExecutor.execute_workflow_code_template( + language=CODE_LANGUAGE, + code=JavascriptCodeProvider.get_default_code(), + inputs={"arg1": "Hello", "arg2": "World"}, + ) + assert result == {"result": "HelloWorld"} + + +def test_javascript_get_runner_script(): + runner_script = NodeJsTemplateTransformer.get_runner_script() + assert runner_script.count(NodeJsTemplateTransformer._code_placeholder) == 1 + assert runner_script.count(NodeJsTemplateTransformer._inputs_placeholder) == 1 + assert runner_script.count(NodeJsTemplateTransformer._result_tag) == 2 diff --git a/api/tests/integration_tests/workflow/nodes/code_executor/test_code_jinja2.py b/api/tests/integration_tests/workflow/nodes/code_executor/test_code_jinja2.py new file mode 100644 index 0000000000000000000000000000000000000000..94903cf79688e54bdb2bbd4eb2943a039dfa3174 --- /dev/null +++ b/api/tests/integration_tests/workflow/nodes/code_executor/test_code_jinja2.py @@ -0,0 +1,34 @@ +import base64 + +from core.helper.code_executor.code_executor import CodeExecutor, CodeLanguage +from core.helper.code_executor.jinja2.jinja2_transformer import Jinja2TemplateTransformer + +CODE_LANGUAGE = 
CodeLanguage.JINJA2 + + +def test_jinja2(): + template = "Hello {{template}}" + inputs = base64.b64encode(b'{"template": "World"}').decode("utf-8") + code = ( + Jinja2TemplateTransformer.get_runner_script() + .replace(Jinja2TemplateTransformer._code_placeholder, template) + .replace(Jinja2TemplateTransformer._inputs_placeholder, inputs) + ) + result = CodeExecutor.execute_code( + language=CODE_LANGUAGE, preload=Jinja2TemplateTransformer.get_preload_script(), code=code + ) + assert result == "<>Hello World<>\n" + + +def test_jinja2_with_code_template(): + result = CodeExecutor.execute_workflow_code_template( + language=CODE_LANGUAGE, code="Hello {{template}}", inputs={"template": "World"} + ) + assert result == {"result": "Hello World"} + + +def test_jinja2_get_runner_script(): + runner_script = Jinja2TemplateTransformer.get_runner_script() + assert runner_script.count(Jinja2TemplateTransformer._code_placeholder) == 1 + assert runner_script.count(Jinja2TemplateTransformer._inputs_placeholder) == 1 + assert runner_script.count(Jinja2TemplateTransformer._result_tag) == 2 diff --git a/api/tests/integration_tests/workflow/nodes/code_executor/test_code_python3.py b/api/tests/integration_tests/workflow/nodes/code_executor/test_code_python3.py new file mode 100644 index 0000000000000000000000000000000000000000..25af312afa4ea43a86b40ddb02a4b49e5ea5eec1 --- /dev/null +++ b/api/tests/integration_tests/workflow/nodes/code_executor/test_code_python3.py @@ -0,0 +1,36 @@ +from textwrap import dedent + +from core.helper.code_executor.code_executor import CodeExecutor, CodeLanguage +from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider +from core.helper.code_executor.python3.python3_transformer import Python3TemplateTransformer + +CODE_LANGUAGE = CodeLanguage.PYTHON3 + + +def test_python3_plain(): + code = 'print("Hello World")' + result = CodeExecutor.execute_code(language=CODE_LANGUAGE, preload="", code=code) + assert result == "Hello World\n" + 
+ +def test_python3_json(): + code = dedent(""" + import json + print(json.dumps({'Hello': 'World'})) + """) + result = CodeExecutor.execute_code(language=CODE_LANGUAGE, preload="", code=code) + assert result == '{"Hello": "World"}\n' + + +def test_python3_with_code_template(): + result = CodeExecutor.execute_workflow_code_template( + language=CODE_LANGUAGE, code=Python3CodeProvider.get_default_code(), inputs={"arg1": "Hello", "arg2": "World"} + ) + assert result == {"result": "HelloWorld"} + + +def test_python3_get_runner_script(): + runner_script = Python3TemplateTransformer.get_runner_script() + assert runner_script.count(Python3TemplateTransformer._code_placeholder) == 1 + assert runner_script.count(Python3TemplateTransformer._inputs_placeholder) == 1 + assert runner_script.count(Python3TemplateTransformer._result_tag) == 2 diff --git a/api/tests/integration_tests/workflow/nodes/test_code.py b/api/tests/integration_tests/workflow/nodes/test_code.py new file mode 100644 index 0000000000000000000000000000000000000000..4de985ae7c9dea98167551ca7afb92586e9fcda9 --- /dev/null +++ b/api/tests/integration_tests/workflow/nodes/test_code.py @@ -0,0 +1,355 @@ +import time +import uuid +from os import getenv +from typing import cast + +import pytest + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.enums import SystemVariableKey +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams +from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState +from core.workflow.nodes.code.code_node import CodeNode +from core.workflow.nodes.code.entities import CodeNodeData +from models.enums import UserFrom +from models.workflow import WorkflowNodeExecutionStatus, WorkflowType +from 
tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock + +CODE_MAX_STRING_LENGTH = int(getenv("CODE_MAX_STRING_LENGTH", "10000")) + + +def init_code_node(code_config: dict): + graph_config = { + "edges": [ + { + "id": "start-source-code-target", + "source": "start", + "target": "code", + }, + ], + "nodes": [{"data": {"type": "start"}, "id": "start"}, code_config], + } + + graph = Graph.init(graph_config=graph_config) + + init_params = GraphInitParams( + tenant_id="1", + app_id="1", + workflow_type=WorkflowType.WORKFLOW, + workflow_id="1", + graph_config=graph_config, + user_id="1", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ) + + # construct variable pool + variable_pool = VariablePool( + system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, + user_inputs={}, + environment_variables=[], + conversation_variables=[], + ) + variable_pool.add(["code", "123", "args1"], 1) + variable_pool.add(["code", "123", "args2"], 2) + + node = CodeNode( + id=str(uuid.uuid4()), + graph_init_params=init_params, + graph=graph, + graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), + config=code_config, + ) + + return node + + +@pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True) +def test_execute_code(setup_code_executor_mock): + code = """ + def main(args1: int, args2: int) -> dict: + return { + "result": args1 + args2, + } + """ + # trim first 4 spaces at the beginning of each line + code = "\n".join([line[4:] for line in code.split("\n")]) + + code_config = { + "id": "code", + "data": { + "outputs": { + "result": { + "type": "number", + }, + }, + "title": "123", + "variables": [ + { + "variable": "args1", + "value_selector": ["1", "123", "args1"], + }, + {"variable": "args2", "value_selector": ["1", "123", "args2"]}, + ], + "answer": "123", + "code_language": "python3", + "code": code, + }, + } + + node = 
init_code_node(code_config) + node.graph_runtime_state.variable_pool.add(["1", "123", "args1"], 1) + node.graph_runtime_state.variable_pool.add(["1", "123", "args2"], 2) + + # execute node + result = node._run() + assert isinstance(result, NodeRunResult) + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert result.outputs is not None + assert result.outputs["result"] == 3 + assert result.error is None + + +@pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True) +def test_execute_code_output_validator(setup_code_executor_mock): + code = """ + def main(args1: int, args2: int) -> dict: + return { + "result": args1 + args2, + } + """ + # trim first 4 spaces at the beginning of each line + code = "\n".join([line[4:] for line in code.split("\n")]) + + code_config = { + "id": "code", + "data": { + "outputs": { + "result": { + "type": "string", + }, + }, + "title": "123", + "variables": [ + { + "variable": "args1", + "value_selector": ["1", "123", "args1"], + }, + {"variable": "args2", "value_selector": ["1", "123", "args2"]}, + ], + "answer": "123", + "code_language": "python3", + "code": code, + }, + } + + node = init_code_node(code_config) + node.graph_runtime_state.variable_pool.add(["1", "123", "args1"], 1) + node.graph_runtime_state.variable_pool.add(["1", "123", "args2"], 2) + + # execute node + result = node._run() + assert isinstance(result, NodeRunResult) + assert result.status == WorkflowNodeExecutionStatus.FAILED + assert result.error == "Output variable `result` must be a string" + + +def test_execute_code_output_validator_depth(): + code = """ + def main(args1: int, args2: int) -> dict: + return { + "result": { + "result": args1 + args2, + } + } + """ + # trim first 4 spaces at the beginning of each line + code = "\n".join([line[4:] for line in code.split("\n")]) + + code_config = { + "id": "code", + "data": { + "outputs": { + "string_validator": { + "type": "string", + }, + "number_validator": { + "type": "number", 
+ }, + "number_array_validator": { + "type": "array[number]", + }, + "string_array_validator": { + "type": "array[string]", + }, + "object_validator": { + "type": "object", + "children": { + "result": { + "type": "number", + }, + "depth": { + "type": "object", + "children": { + "depth": { + "type": "object", + "children": { + "depth": { + "type": "number", + } + }, + } + }, + }, + }, + }, + }, + "title": "123", + "variables": [ + { + "variable": "args1", + "value_selector": ["1", "123", "args1"], + }, + {"variable": "args2", "value_selector": ["1", "123", "args2"]}, + ], + "answer": "123", + "code_language": "python3", + "code": code, + }, + } + + node = init_code_node(code_config) + + # construct result + result = { + "number_validator": 1, + "string_validator": "1", + "number_array_validator": [1, 2, 3, 3.333], + "string_array_validator": ["1", "2", "3"], + "object_validator": {"result": 1, "depth": {"depth": {"depth": 1}}}, + } + + node.node_data = cast(CodeNodeData, node.node_data) + + # validate + node._transform_result(result, node.node_data.outputs) + + # construct result + result = { + "number_validator": "1", + "string_validator": 1, + "number_array_validator": ["1", "2", "3", "3.333"], + "string_array_validator": [1, 2, 3], + "object_validator": {"result": "1", "depth": {"depth": {"depth": "1"}}}, + } + + # validate + with pytest.raises(ValueError): + node._transform_result(result, node.node_data.outputs) + + # construct result + result = { + "number_validator": 1, + "string_validator": (CODE_MAX_STRING_LENGTH + 1) * "1", + "number_array_validator": [1, 2, 3, 3.333], + "string_array_validator": ["1", "2", "3"], + "object_validator": {"result": 1, "depth": {"depth": {"depth": 1}}}, + } + + # validate + with pytest.raises(ValueError): + node._transform_result(result, node.node_data.outputs) + + # construct result + result = { + "number_validator": 1, + "string_validator": "1", + "number_array_validator": [1, 2, 3, 3.333] * 2000, + "string_array_validator": 
["1", "2", "3"], + "object_validator": {"result": 1, "depth": {"depth": {"depth": 1}}}, + } + + # validate + with pytest.raises(ValueError): + node._transform_result(result, node.node_data.outputs) + + +def test_execute_code_output_object_list(): + code = """ + def main(args1: int, args2: int) -> dict: + return { + "result": { + "result": args1 + args2, + } + } + """ + # trim first 4 spaces at the beginning of each line + code = "\n".join([line[4:] for line in code.split("\n")]) + + code_config = { + "id": "code", + "data": { + "outputs": { + "object_list": { + "type": "array[object]", + }, + }, + "title": "123", + "variables": [ + { + "variable": "args1", + "value_selector": ["1", "123", "args1"], + }, + {"variable": "args2", "value_selector": ["1", "123", "args2"]}, + ], + "answer": "123", + "code_language": "python3", + "code": code, + }, + } + + node = init_code_node(code_config) + + # construct result + result = { + "object_list": [ + { + "result": 1, + }, + { + "result": 2, + }, + { + "result": [1, 2, 3], + }, + ] + } + + node.node_data = cast(CodeNodeData, node.node_data) + + # validate + node._transform_result(result, node.node_data.outputs) + + # construct result + result = { + "object_list": [ + { + "result": 1, + }, + { + "result": 2, + }, + { + "result": [1, 2, 3], + }, + 1, + ] + } + + # validate + with pytest.raises(ValueError): + node._transform_result(result, node.node_data.outputs) diff --git a/api/tests/integration_tests/workflow/nodes/test_http.py b/api/tests/integration_tests/workflow/nodes/test_http.py new file mode 100644 index 0000000000000000000000000000000000000000..0507fc707564dd510d504489c0ff274f7db16060 --- /dev/null +++ b/api/tests/integration_tests/workflow/nodes/test_http.py @@ -0,0 +1,432 @@ +import time +import uuid +from urllib.parse import urlencode + +import pytest + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.enums import 
SystemVariableKey +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams +from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState +from core.workflow.nodes.http_request.node import HttpRequestNode +from models.enums import UserFrom +from models.workflow import WorkflowType +from tests.integration_tests.workflow.nodes.__mock.http import setup_http_mock + + +def init_http_node(config: dict): + graph_config = { + "edges": [ + { + "id": "start-source-next-target", + "source": "start", + "target": "1", + }, + ], + "nodes": [{"data": {"type": "start"}, "id": "start"}, config], + } + + graph = Graph.init(graph_config=graph_config) + + init_params = GraphInitParams( + tenant_id="1", + app_id="1", + workflow_type=WorkflowType.WORKFLOW, + workflow_id="1", + graph_config=graph_config, + user_id="1", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ) + + # construct variable pool + variable_pool = VariablePool( + system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, + user_inputs={}, + environment_variables=[], + conversation_variables=[], + ) + variable_pool.add(["a", "b123", "args1"], 1) + variable_pool.add(["a", "b123", "args2"], 2) + + return HttpRequestNode( + id=str(uuid.uuid4()), + graph_init_params=init_params, + graph=graph, + graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), + config=config, + ) + + +@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True) +def test_get(setup_http_mock): + node = init_http_node( + config={ + "id": "1", + "data": { + "title": "http", + "desc": "", + "method": "get", + "url": "http://example.com", + "authorization": { + "type": "api-key", + "config": { + "type": "basic", + "api_key": "ak-xxx", + "header": "api-key", + }, + }, + "headers": "X-Header:123", + "params": "A:b", + "body": None, + }, + } 
+ ) + + result = node._run() + assert result.process_data is not None + data = result.process_data.get("request", "") + + assert "?A=b" in data + assert "X-Header: 123" in data + + +@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True) +def test_no_auth(setup_http_mock): + node = init_http_node( + config={ + "id": "1", + "data": { + "title": "http", + "desc": "", + "method": "get", + "url": "http://example.com", + "authorization": { + "type": "no-auth", + "config": None, + }, + "headers": "X-Header:123", + "params": "A:b", + "body": None, + }, + } + ) + + result = node._run() + assert result.process_data is not None + data = result.process_data.get("request", "") + + assert "?A=b" in data + assert "X-Header: 123" in data + + +@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True) +def test_custom_authorization_header(setup_http_mock): + node = init_http_node( + config={ + "id": "1", + "data": { + "title": "http", + "desc": "", + "method": "get", + "url": "http://example.com", + "authorization": { + "type": "api-key", + "config": { + "type": "custom", + "api_key": "Auth", + "header": "X-Auth", + }, + }, + "headers": "X-Header:123", + "params": "A:b", + "body": None, + }, + } + ) + + result = node._run() + assert result.process_data is not None + data = result.process_data.get("request", "") + + assert "?A=b" in data + assert "X-Header: 123" in data + + +@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True) +def test_template(setup_http_mock): + node = init_http_node( + config={ + "id": "1", + "data": { + "title": "http", + "desc": "", + "method": "get", + "url": "http://example.com/{{#a.b123.args2#}}", + "authorization": { + "type": "api-key", + "config": { + "type": "basic", + "api_key": "ak-xxx", + "header": "api-key", + }, + }, + "headers": "X-Header:123\nX-Header2:{{#a.b123.args2#}}", + "params": "A:b\nTemplate:{{#a.b123.args2#}}", + "body": None, + }, + } + ) + + result = node._run() + assert result.process_data 
is not None + data = result.process_data.get("request", "") + + assert "?A=b" in data + assert "Template=2" in data + assert "X-Header: 123" in data + assert "X-Header2: 2" in data + + +@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True) +def test_json(setup_http_mock): + node = init_http_node( + config={ + "id": "1", + "data": { + "title": "http", + "desc": "", + "method": "post", + "url": "http://example.com", + "authorization": { + "type": "api-key", + "config": { + "type": "basic", + "api_key": "ak-xxx", + "header": "api-key", + }, + }, + "headers": "X-Header:123", + "params": "A:b", + "body": { + "type": "json", + "data": [ + { + "key": "", + "type": "text", + "value": '{"a": "{{#a.b123.args1#}}"}', + }, + ], + }, + }, + } + ) + + result = node._run() + assert result.process_data is not None + data = result.process_data.get("request", "") + + assert '{"a": "1"}' in data + assert "X-Header: 123" in data + + +def test_x_www_form_urlencoded(setup_http_mock): + node = init_http_node( + config={ + "id": "1", + "data": { + "title": "http", + "desc": "", + "method": "post", + "url": "http://example.com", + "authorization": { + "type": "api-key", + "config": { + "type": "basic", + "api_key": "ak-xxx", + "header": "api-key", + }, + }, + "headers": "X-Header:123", + "params": "A:b", + "body": { + "type": "x-www-form-urlencoded", + "data": [ + { + "key": "a", + "type": "text", + "value": "{{#a.b123.args1#}}", + }, + { + "key": "b", + "type": "text", + "value": "{{#a.b123.args2#}}", + }, + ], + }, + }, + } + ) + + result = node._run() + assert result.process_data is not None + data = result.process_data.get("request", "") + + assert "a=1&b=2" in data + assert "X-Header: 123" in data + + +def test_form_data(setup_http_mock): + node = init_http_node( + config={ + "id": "1", + "data": { + "title": "http", + "desc": "", + "method": "post", + "url": "http://example.com", + "authorization": { + "type": "api-key", + "config": { + "type": "basic", + "api_key": 
"ak-xxx", + "header": "api-key", + }, + }, + "headers": "X-Header:123", + "params": "A:b", + "body": { + "type": "form-data", + "data": [ + { + "key": "a", + "type": "text", + "value": "{{#a.b123.args1#}}", + }, + { + "key": "b", + "type": "text", + "value": "{{#a.b123.args2#}}", + }, + ], + }, + }, + } + ) + + result = node._run() + assert result.process_data is not None + data = result.process_data.get("request", "") + + assert 'form-data; name="a"' in data + assert "1" in data + assert 'form-data; name="b"' in data + assert "2" in data + assert "X-Header: 123" in data + + +def test_none_data(setup_http_mock): + node = init_http_node( + config={ + "id": "1", + "data": { + "title": "http", + "desc": "", + "method": "post", + "url": "http://example.com", + "authorization": { + "type": "api-key", + "config": { + "type": "basic", + "api_key": "ak-xxx", + "header": "api-key", + }, + }, + "headers": "X-Header:123", + "params": "A:b", + "body": {"type": "none", "data": []}, + }, + } + ) + + result = node._run() + assert result.process_data is not None + data = result.process_data.get("request", "") + + assert "X-Header: 123" in data + assert "123123123" not in data + + +def test_mock_404(setup_http_mock): + node = init_http_node( + config={ + "id": "1", + "data": { + "title": "http", + "desc": "", + "method": "get", + "url": "http://404.com", + "authorization": { + "type": "no-auth", + "config": None, + }, + "body": None, + "params": "", + "headers": "X-Header:123", + }, + } + ) + + result = node._run() + assert result.outputs is not None + resp = result.outputs + + assert resp.get("status_code") == 404 + assert "Not Found" in resp.get("body", "") + + +def test_multi_colons_parse(setup_http_mock): + node = init_http_node( + config={ + "id": "1", + "data": { + "title": "http", + "desc": "", + "method": "get", + "url": "http://example.com", + "authorization": { + "type": "no-auth", + "config": None, + }, + "params": 
"Referer:http://example1.com\nRedirect:http://example2.com", + "headers": "Referer:http://example3.com\nRedirect:http://example4.com", + "body": { + "type": "form-data", + "data": [ + { + "key": "Referer", + "type": "text", + "value": "http://example5.com", + }, + { + "key": "Redirect", + "type": "text", + "value": "http://example6.com", + }, + ], + }, + }, + } + ) + + result = node._run() + assert result.process_data is not None + assert result.outputs is not None + resp = result.outputs + + assert urlencode({"Redirect": "http://example2.com"}) in result.process_data.get("request", "") + assert 'form-data; name="Redirect"\r\n\r\nhttp://example6.com' in result.process_data.get("request", "") + # assert "http://example3.com" == resp.get("headers", {}).get("referer") diff --git a/api/tests/integration_tests/workflow/nodes/test_llm.py b/api/tests/integration_tests/workflow/nodes/test_llm.py new file mode 100644 index 0000000000000000000000000000000000000000..9a23949b38939ee45971470f88e36a25b2422520 --- /dev/null +++ b/api/tests/integration_tests/workflow/nodes/test_llm.py @@ -0,0 +1,236 @@ +import json +import os +import time +import uuid +from collections.abc import Generator +from unittest.mock import MagicMock + +import pytest + +from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity +from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle +from core.entities.provider_entities import CustomConfiguration, CustomProviderConfiguration, SystemConfiguration +from core.model_manager import ModelInstance +from core.model_runtime.entities.model_entities import ModelType +from core.model_runtime.model_providers import ModelProviderFactory +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.enums import SystemVariableKey +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams +from 
core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState +from core.workflow.nodes.event import RunCompletedEvent +from core.workflow.nodes.llm.node import LLMNode +from extensions.ext_database import db +from models.enums import UserFrom +from models.provider import ProviderType +from models.workflow import WorkflowNodeExecutionStatus, WorkflowType + +"""FOR MOCK FIXTURES, DO NOT REMOVE""" +from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock +from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock + + +def init_llm_node(config: dict) -> LLMNode: + graph_config = { + "edges": [ + { + "id": "start-source-next-target", + "source": "start", + "target": "llm", + }, + ], + "nodes": [{"data": {"type": "start"}, "id": "start"}, config], + } + + graph = Graph.init(graph_config=graph_config) + + init_params = GraphInitParams( + tenant_id="1", + app_id="1", + workflow_type=WorkflowType.WORKFLOW, + workflow_id="1", + graph_config=graph_config, + user_id="1", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ) + + # construct variable pool + variable_pool = VariablePool( + system_variables={ + SystemVariableKey.QUERY: "what's the weather today?", + SystemVariableKey.FILES: [], + SystemVariableKey.CONVERSATION_ID: "abababa", + SystemVariableKey.USER_ID: "aaa", + }, + user_inputs={}, + environment_variables=[], + conversation_variables=[], + ) + variable_pool.add(["abc", "output"], "sunny") + + node = LLMNode( + id=str(uuid.uuid4()), + graph_init_params=init_params, + graph=graph, + graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), + config=config, + ) + + return node + + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) +def test_execute_llm(setup_openai_mock): + node = init_llm_node( + config={ + "id": "llm", + "data": { + "title": "123", + "type": "llm", + "model": {"provider": 
"openai", "name": "gpt-3.5-turbo", "mode": "chat", "completion_params": {}}, + "prompt_template": [ + {"role": "system", "text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}."}, + {"role": "user", "text": "{{#sys.query#}}"}, + ], + "memory": None, + "context": {"enabled": False}, + "vision": {"enabled": False}, + }, + }, + ) + + credentials = {"openai_api_key": os.environ.get("OPENAI_API_KEY")} + + provider_instance = ModelProviderFactory().get_provider_instance("openai") + model_type_instance = provider_instance.get_model_instance(ModelType.LLM) + provider_model_bundle = ProviderModelBundle( + configuration=ProviderConfiguration( + tenant_id="1", + provider=provider_instance.get_provider_schema(), + preferred_provider_type=ProviderType.CUSTOM, + using_provider_type=ProviderType.CUSTOM, + system_configuration=SystemConfiguration(enabled=False), + custom_configuration=CustomConfiguration(provider=CustomProviderConfiguration(credentials=credentials)), + model_settings=[], + ), + provider_instance=provider_instance, + model_type_instance=model_type_instance, + ) + model_instance = ModelInstance(provider_model_bundle=provider_model_bundle, model="gpt-3.5-turbo") + model_schema = model_type_instance.get_model_schema("gpt-3.5-turbo") + assert model_schema is not None + model_config = ModelConfigWithCredentialsEntity( + model="gpt-3.5-turbo", + provider="openai", + mode="chat", + credentials=credentials, + parameters={}, + model_schema=model_schema, + provider_model_bundle=provider_model_bundle, + ) + + # Mock db.session.close() + db.session.close = MagicMock() + + node._fetch_model_config = MagicMock(return_value=(model_instance, model_config)) + + # execute node + result = node._run() + assert isinstance(result, Generator) + + for item in result: + if isinstance(item, RunCompletedEvent): + assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert item.run_result.process_data is not None + assert item.run_result.outputs is not 
None + assert item.run_result.outputs.get("text") is not None + assert item.run_result.outputs.get("usage", {})["total_tokens"] > 0 + + +@pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True) +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) +def test_execute_llm_with_jinja2(setup_code_executor_mock, setup_openai_mock): + """ + Test execute LLM node with jinja2 + """ + node = init_llm_node( + config={ + "id": "llm", + "data": { + "title": "123", + "type": "llm", + "model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "chat", "completion_params": {}}, + "prompt_config": { + "jinja2_variables": [ + {"variable": "sys_query", "value_selector": ["sys", "query"]}, + {"variable": "output", "value_selector": ["abc", "output"]}, + ] + }, + "prompt_template": [ + { + "role": "system", + "text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}", + "jinja2_text": "you are a helpful assistant.\ntoday's weather is {{output}}.", + "edition_type": "jinja2", + }, + { + "role": "user", + "text": "{{#sys.query#}}", + "jinja2_text": "{{sys_query}}", + "edition_type": "basic", + }, + ], + "memory": None, + "context": {"enabled": False}, + "vision": {"enabled": False}, + }, + }, + ) + + credentials = {"openai_api_key": os.environ.get("OPENAI_API_KEY")} + + provider_instance = ModelProviderFactory().get_provider_instance("openai") + model_type_instance = provider_instance.get_model_instance(ModelType.LLM) + provider_model_bundle = ProviderModelBundle( + configuration=ProviderConfiguration( + tenant_id="1", + provider=provider_instance.get_provider_schema(), + preferred_provider_type=ProviderType.CUSTOM, + using_provider_type=ProviderType.CUSTOM, + system_configuration=SystemConfiguration(enabled=False), + custom_configuration=CustomConfiguration(provider=CustomProviderConfiguration(credentials=credentials)), + model_settings=[], + ), + provider_instance=provider_instance, + 
model_type_instance=model_type_instance, + ) + + model_instance = ModelInstance(provider_model_bundle=provider_model_bundle, model="gpt-3.5-turbo") + model_schema = model_type_instance.get_model_schema("gpt-3.5-turbo") + assert model_schema is not None + model_config = ModelConfigWithCredentialsEntity( + model="gpt-3.5-turbo", + provider="openai", + mode="chat", + credentials=credentials, + parameters={}, + model_schema=model_schema, + provider_model_bundle=provider_model_bundle, + ) + + # Mock db.session.close() + db.session.close = MagicMock() + + node._fetch_model_config = MagicMock(return_value=(model_instance, model_config)) + + # execute node + result = node._run() + + for item in result: + if isinstance(item, RunCompletedEvent): + assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert item.run_result.process_data is not None + assert "sunny" in json.dumps(item.run_result.process_data) + assert "what's the weather today?" in json.dumps(item.run_result.process_data) diff --git a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py new file mode 100644 index 0000000000000000000000000000000000000000..42a058d29bae8d83d35bde00c16b93bb001a3827 --- /dev/null +++ b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py @@ -0,0 +1,414 @@ +import os +import time +import uuid +from typing import Optional +from unittest.mock import MagicMock + +import pytest + +from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity +from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle +from core.entities.provider_entities import CustomConfiguration, CustomProviderConfiguration, SystemConfiguration +from core.model_manager import ModelInstance +from core.model_runtime.entities.model_entities import ModelType +from core.model_runtime.model_providers.model_provider_factory import 
ModelProviderFactory +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.enums import SystemVariableKey +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams +from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState +from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode +from extensions.ext_database import db +from models.enums import UserFrom +from models.provider import ProviderType + +"""FOR MOCK FIXTURES, DO NOT REMOVE""" +from models.workflow import WorkflowNodeExecutionStatus, WorkflowType +from tests.integration_tests.model_runtime.__mock.anthropic import setup_anthropic_mock +from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock + + +def get_mocked_fetch_model_config( + provider: str, + model: str, + mode: str, + credentials: dict, +): + provider_instance = ModelProviderFactory().get_provider_instance(provider) + model_type_instance = provider_instance.get_model_instance(ModelType.LLM) + provider_model_bundle = ProviderModelBundle( + configuration=ProviderConfiguration( + tenant_id="1", + provider=provider_instance.get_provider_schema(), + preferred_provider_type=ProviderType.CUSTOM, + using_provider_type=ProviderType.CUSTOM, + system_configuration=SystemConfiguration(enabled=False), + custom_configuration=CustomConfiguration(provider=CustomProviderConfiguration(credentials=credentials)), + model_settings=[], + ), + provider_instance=provider_instance, + model_type_instance=model_type_instance, + ) + model_instance = ModelInstance(provider_model_bundle=provider_model_bundle, model=model) + model_schema = model_type_instance.get_model_schema(model) + assert model_schema is not None + model_config = ModelConfigWithCredentialsEntity( + model=model, + provider=provider, + mode=mode, + credentials=credentials, + parameters={}, + 
model_schema=model_schema, + provider_model_bundle=provider_model_bundle, + ) + + return MagicMock(return_value=(model_instance, model_config)) + + +def get_mocked_fetch_memory(memory_text: str): + class MemoryMock: + def get_history_prompt_text( + self, + human_prefix: str = "Human", + ai_prefix: str = "Assistant", + max_token_limit: int = 2000, + message_limit: Optional[int] = None, + ): + return memory_text + + return MagicMock(return_value=MemoryMock()) + + +def init_parameter_extractor_node(config: dict): + graph_config = { + "edges": [ + { + "id": "start-source-next-target", + "source": "start", + "target": "llm", + }, + ], + "nodes": [{"data": {"type": "start"}, "id": "start"}, config], + } + + graph = Graph.init(graph_config=graph_config) + + init_params = GraphInitParams( + tenant_id="1", + app_id="1", + workflow_type=WorkflowType.WORKFLOW, + workflow_id="1", + graph_config=graph_config, + user_id="1", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ) + + # construct variable pool + variable_pool = VariablePool( + system_variables={ + SystemVariableKey.QUERY: "what's the weather in SF", + SystemVariableKey.FILES: [], + SystemVariableKey.CONVERSATION_ID: "abababa", + SystemVariableKey.USER_ID: "aaa", + }, + user_inputs={}, + environment_variables=[], + conversation_variables=[], + ) + variable_pool.add(["a", "b123", "args1"], 1) + variable_pool.add(["a", "b123", "args2"], 2) + + return ParameterExtractorNode( + id=str(uuid.uuid4()), + graph_init_params=init_params, + graph=graph, + graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), + config=config, + ) + + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) +def test_function_calling_parameter_extractor(setup_openai_mock): + """ + Test function calling for parameter extractor. 
+ """ + node = init_parameter_extractor_node( + config={ + "id": "llm", + "data": { + "title": "123", + "type": "parameter-extractor", + "model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "chat", "completion_params": {}}, + "query": ["sys", "query"], + "parameters": [{"name": "location", "type": "string", "description": "location", "required": True}], + "instruction": "", + "reasoning_mode": "function_call", + "memory": None, + }, + } + ) + + node._fetch_model_config = get_mocked_fetch_model_config( + provider="openai", + model="gpt-3.5-turbo", + mode="chat", + credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}, + ) + db.session.close = MagicMock() + + # construct variable pool + pool = VariablePool( + system_variables={ + SystemVariableKey.QUERY: "what's the weather in SF", + SystemVariableKey.FILES: [], + SystemVariableKey.CONVERSATION_ID: "abababa", + SystemVariableKey.USER_ID: "aaa", + }, + user_inputs={}, + environment_variables=[], + ) + + result = node._run() + + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert result.outputs is not None + assert result.outputs.get("location") == "kawaii" + assert result.outputs.get("__reason") == None + + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) +def test_instructions(setup_openai_mock): + """ + Test chat parameter extractor. 
+ """ + node = init_parameter_extractor_node( + config={ + "id": "llm", + "data": { + "title": "123", + "type": "parameter-extractor", + "model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "chat", "completion_params": {}}, + "query": ["sys", "query"], + "parameters": [{"name": "location", "type": "string", "description": "location", "required": True}], + "reasoning_mode": "function_call", + "instruction": "{{#sys.query#}}", + "memory": None, + }, + }, + ) + + node._fetch_model_config = get_mocked_fetch_model_config( + provider="openai", + model="gpt-3.5-turbo", + mode="chat", + credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}, + ) + db.session.close = MagicMock() + + result = node._run() + + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert result.outputs is not None + assert result.outputs.get("location") == "kawaii" + assert result.outputs.get("__reason") == None + + process_data = result.process_data + + assert process_data is not None + process_data.get("prompts") + + for prompt in process_data.get("prompts", []): + if prompt.get("role") == "system": + assert "what's the weather in SF" in prompt.get("text") + + +@pytest.mark.parametrize("setup_anthropic_mock", [["none"]], indirect=True) +def test_chat_parameter_extractor(setup_anthropic_mock): + """ + Test chat parameter extractor. 
+ """ + node = init_parameter_extractor_node( + config={ + "id": "llm", + "data": { + "title": "123", + "type": "parameter-extractor", + "model": {"provider": "anthropic", "name": "claude-2", "mode": "chat", "completion_params": {}}, + "query": ["sys", "query"], + "parameters": [{"name": "location", "type": "string", "description": "location", "required": True}], + "reasoning_mode": "prompt", + "instruction": "", + "memory": None, + }, + }, + ) + + node._fetch_model_config = get_mocked_fetch_model_config( + provider="anthropic", + model="claude-2", + mode="chat", + credentials={"anthropic_api_key": os.environ.get("ANTHROPIC_API_KEY")}, + ) + db.session.close = MagicMock() + + result = node._run() + + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert result.outputs is not None + assert result.outputs.get("location") == "" + assert ( + result.outputs.get("__reason") + == "Failed to extract result from function call or text response, using empty result." + ) + assert result.process_data is not None + prompts = result.process_data.get("prompts", []) + + for prompt in prompts: + if prompt.get("role") == "user": + if "" in prompt.get("text"): + assert '\n{"type": "object"' in prompt.get("text") + + +@pytest.mark.parametrize("setup_openai_mock", [["completion"]], indirect=True) +def test_completion_parameter_extractor(setup_openai_mock): + """ + Test completion parameter extractor. 
+ """ + node = init_parameter_extractor_node( + config={ + "id": "llm", + "data": { + "title": "123", + "type": "parameter-extractor", + "model": { + "provider": "openai", + "name": "gpt-3.5-turbo-instruct", + "mode": "completion", + "completion_params": {}, + }, + "query": ["sys", "query"], + "parameters": [{"name": "location", "type": "string", "description": "location", "required": True}], + "reasoning_mode": "prompt", + "instruction": "{{#sys.query#}}", + "memory": None, + }, + }, + ) + + node._fetch_model_config = get_mocked_fetch_model_config( + provider="openai", + model="gpt-3.5-turbo-instruct", + mode="completion", + credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}, + ) + db.session.close = MagicMock() + + result = node._run() + + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert result.outputs is not None + assert result.outputs.get("location") == "" + assert ( + result.outputs.get("__reason") + == "Failed to extract result from function call or text response, using empty result." + ) + assert result.process_data is not None + assert len(result.process_data.get("prompts", [])) == 1 + assert "SF" in result.process_data.get("prompts", [])[0].get("text") + + +def test_extract_json_response(): + """ + Test extract json response. + """ + + node = init_parameter_extractor_node( + config={ + "id": "llm", + "data": { + "title": "123", + "type": "parameter-extractor", + "model": { + "provider": "openai", + "name": "gpt-3.5-turbo-instruct", + "mode": "completion", + "completion_params": {}, + }, + "query": ["sys", "query"], + "parameters": [{"name": "location", "type": "string", "description": "location", "required": True}], + "reasoning_mode": "prompt", + "instruction": "{{#sys.query#}}", + "memory": None, + }, + }, + ) + + result = node._extract_complete_json_response(""" + uwu{ovo} + { + "location": "kawaii" + } + hello world. 
+ """) + + assert result is not None + assert result["location"] == "kawaii" + + +@pytest.mark.parametrize("setup_anthropic_mock", [["none"]], indirect=True) +def test_chat_parameter_extractor_with_memory(setup_anthropic_mock): + """ + Test chat parameter extractor with memory. + """ + node = init_parameter_extractor_node( + config={ + "id": "llm", + "data": { + "title": "123", + "type": "parameter-extractor", + "model": {"provider": "anthropic", "name": "claude-2", "mode": "chat", "completion_params": {}}, + "query": ["sys", "query"], + "parameters": [{"name": "location", "type": "string", "description": "location", "required": True}], + "reasoning_mode": "prompt", + "instruction": "", + "memory": {"window": {"enabled": True, "size": 50}}, + }, + }, + ) + + node._fetch_model_config = get_mocked_fetch_model_config( + provider="anthropic", + model="claude-2", + mode="chat", + credentials={"anthropic_api_key": os.environ.get("ANTHROPIC_API_KEY")}, + ) + node._fetch_memory = get_mocked_fetch_memory("customized memory") + db.session.close = MagicMock() + + result = node._run() + + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert result.outputs is not None + assert result.outputs.get("location") == "" + assert ( + result.outputs.get("__reason") + == "Failed to extract result from function call or text response, using empty result." 
+ ) + assert result.process_data is not None + prompts = result.process_data.get("prompts", []) + + latest_role = None + for prompt in prompts: + if prompt.get("role") == "user": + if "" in prompt.get("text"): + assert '\n{"type": "object"' in prompt.get("text") + elif prompt.get("role") == "system": + assert "customized memory" in prompt.get("text") + + if latest_role is not None: + assert latest_role != prompt.get("role") + + if prompt.get("role") in {"user", "assistant"}: + latest_role = prompt.get("role") diff --git a/api/tests/integration_tests/workflow/nodes/test_template_transform.py b/api/tests/integration_tests/workflow/nodes/test_template_transform.py new file mode 100644 index 0000000000000000000000000000000000000000..51d61a95ea4698c4d48279452849dd54dc23dfad --- /dev/null +++ b/api/tests/integration_tests/workflow/nodes/test_template_transform.py @@ -0,0 +1,84 @@ +import time +import uuid + +import pytest + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.enums import SystemVariableKey +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams +from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState +from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode +from models.enums import UserFrom +from models.workflow import WorkflowNodeExecutionStatus, WorkflowType +from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock + + +@pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True) +def test_execute_code(setup_code_executor_mock): + code = """{{args2}}""" + config = { + "id": "1", + "data": { + "title": "123", + "variables": [ + { + "variable": "args1", + "value_selector": ["1", "123", "args1"], + }, + {"variable": "args2", "value_selector": ["1", 
"123", "args2"]}, + ], + "template": code, + }, + } + + graph_config = { + "edges": [ + { + "id": "start-source-next-target", + "source": "start", + "target": "1", + }, + ], + "nodes": [{"data": {"type": "start"}, "id": "start"}, config], + } + + graph = Graph.init(graph_config=graph_config) + + init_params = GraphInitParams( + tenant_id="1", + app_id="1", + workflow_type=WorkflowType.WORKFLOW, + workflow_id="1", + graph_config=graph_config, + user_id="1", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ) + + # construct variable pool + variable_pool = VariablePool( + system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, + user_inputs={}, + environment_variables=[], + conversation_variables=[], + ) + variable_pool.add(["1", "123", "args1"], 1) + variable_pool.add(["1", "123", "args2"], 3) + + node = TemplateTransformNode( + id=str(uuid.uuid4()), + graph_init_params=init_params, + graph=graph, + graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), + config=config, + ) + + # execute node + result = node._run() + + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert result.outputs is not None + assert result.outputs["output"] == "3" diff --git a/api/tests/integration_tests/workflow/nodes/test_tool.py b/api/tests/integration_tests/workflow/nodes/test_tool.py new file mode 100644 index 0000000000000000000000000000000000000000..4068e796b787ef5adb990c3d576d508b228dc0e6 --- /dev/null +++ b/api/tests/integration_tests/workflow/nodes/test_tool.py @@ -0,0 +1,124 @@ +import time +import uuid + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.enums import SystemVariableKey +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.entities.graph_init_params 
import GraphInitParams +from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState +from core.workflow.nodes.tool.tool_node import ToolNode +from models.enums import UserFrom +from models.workflow import WorkflowNodeExecutionStatus, WorkflowType + + +def init_tool_node(config: dict): + graph_config = { + "edges": [ + { + "id": "start-source-next-target", + "source": "start", + "target": "1", + }, + ], + "nodes": [{"data": {"type": "start"}, "id": "start"}, config], + } + + graph = Graph.init(graph_config=graph_config) + + init_params = GraphInitParams( + tenant_id="1", + app_id="1", + workflow_type=WorkflowType.WORKFLOW, + workflow_id="1", + graph_config=graph_config, + user_id="1", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ) + + # construct variable pool + variable_pool = VariablePool( + system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, + user_inputs={}, + environment_variables=[], + conversation_variables=[], + ) + + return ToolNode( + id=str(uuid.uuid4()), + graph_init_params=init_params, + graph=graph, + graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), + config=config, + ) + + +def test_tool_variable_invoke(): + node = init_tool_node( + config={ + "id": "1", + "data": { + "title": "a", + "desc": "a", + "provider_id": "maths", + "provider_type": "builtin", + "provider_name": "maths", + "tool_name": "eval_expression", + "tool_label": "eval_expression", + "tool_configurations": {}, + "tool_parameters": { + "expression": { + "type": "variable", + "value": ["1", "123", "args1"], + } + }, + }, + } + ) + + node.graph_runtime_state.variable_pool.add(["1", "123", "args1"], "1+1") + + # execute node + result = node._run() + assert isinstance(result, NodeRunResult) + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert result.outputs is not None + assert "2" in result.outputs["text"] + assert 
result.outputs["files"] == [] + + +def test_tool_mixed_invoke(): + node = init_tool_node( + config={ + "id": "1", + "data": { + "title": "a", + "desc": "a", + "provider_id": "maths", + "provider_type": "builtin", + "provider_name": "maths", + "tool_name": "eval_expression", + "tool_label": "eval_expression", + "tool_configurations": {}, + "tool_parameters": { + "expression": { + "type": "mixed", + "value": "{{#1.args1#}}", + } + }, + }, + } + ) + + node.graph_runtime_state.variable_pool.add(["1", "args1"], "1+1") + + # execute node + result = node._run() + assert isinstance(result, NodeRunResult) + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert result.outputs is not None + assert "2" in result.outputs["text"] + assert result.outputs["files"] == [] diff --git a/api/tests/integration_tests/workflow/test_sync_workflow.py b/api/tests/integration_tests/workflow/test_sync_workflow.py new file mode 100644 index 0000000000000000000000000000000000000000..be270cdc49c3450442e73231e8045e021be6775b --- /dev/null +++ b/api/tests/integration_tests/workflow/test_sync_workflow.py @@ -0,0 +1,57 @@ +""" +This test file is used to verify the compatibility of Workflow before and after supporting multiple file types. 
+""" + +import json + +from models import Workflow + +OLD_VERSION_WORKFLOW_FEATURES = { + "file_upload": { + "image": { + "enabled": True, + "number_limits": 6, + "transfer_methods": ["remote_url", "local_file"], + } + }, + "opening_statement": "", + "retriever_resource": {"enabled": True}, + "sensitive_word_avoidance": {"enabled": False}, + "speech_to_text": {"enabled": False}, + "suggested_questions": [], + "suggested_questions_after_answer": {"enabled": False}, + "text_to_speech": {"enabled": False, "language": "", "voice": ""}, +} + +NEW_VERSION_WORKFLOW_FEATURES = { + "file_upload": { + "enabled": True, + "allowed_file_types": ["image"], + "allowed_file_extensions": [], + "allowed_file_upload_methods": ["remote_url", "local_file"], + "number_limits": 6, + }, + "opening_statement": "", + "retriever_resource": {"enabled": True}, + "sensitive_word_avoidance": {"enabled": False}, + "speech_to_text": {"enabled": False}, + "suggested_questions": [], + "suggested_questions_after_answer": {"enabled": False}, + "text_to_speech": {"enabled": False, "language": "", "voice": ""}, +} + + +def test_workflow_features(): + workflow = Workflow( + tenant_id="", + app_id="", + type="", + version="", + graph="", + features=json.dumps(OLD_VERSION_WORKFLOW_FEATURES), + created_by="", + environment_variables=[], + conversation_variables=[], + ) + + assert workflow.features_dict == NEW_VERSION_WORKFLOW_FEATURES diff --git a/api/tests/unit_tests/.gitignore b/api/tests/unit_tests/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..426667562b31dac736680e7aac2c76c06d98a688 --- /dev/null +++ b/api/tests/unit_tests/.gitignore @@ -0,0 +1 @@ +.env.test \ No newline at end of file diff --git a/api/tests/unit_tests/__init__.py b/api/tests/unit_tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/tests/unit_tests/configs/test_dify_config.py 
b/api/tests/unit_tests/configs/test_dify_config.py new file mode 100644 index 0000000000000000000000000000000000000000..efa9ea89794b92e8f2c30a1befef6140b4e04fc3 --- /dev/null +++ b/api/tests/unit_tests/configs/test_dify_config.py @@ -0,0 +1,102 @@ +import os +from textwrap import dedent + +import pytest +from flask import Flask +from yarl import URL + +from configs.app_config import DifyConfig + +EXAMPLE_ENV_FILENAME = ".env" + + +@pytest.fixture +def example_env_file(tmp_path, monkeypatch) -> str: + monkeypatch.chdir(tmp_path) + file_path = tmp_path.joinpath(EXAMPLE_ENV_FILENAME) + file_path.write_text( + dedent( + """ + CONSOLE_API_URL=https://example.com + CONSOLE_WEB_URL=https://example.com + HTTP_REQUEST_MAX_WRITE_TIMEOUT=30 + """ + ) + ) + return str(file_path) + + +def test_dify_config_undefined_entry(example_env_file): + # NOTE: See https://github.com/microsoft/pylance-release/issues/6099 for more details about this type error. + # load dotenv file with pydantic-settings + config = DifyConfig(_env_file=example_env_file) + + # entries not defined in app settings + with pytest.raises(TypeError): + # TypeError: 'AppSettings' object is not subscriptable + assert config["LOG_LEVEL"] == "INFO" + + +# NOTE: If there is a `.env` file in your Workspace, this test might not succeed as expected. +# This is due to `pymilvus` loading all the variables from the `.env` file into `os.environ`. 
+def test_dify_config(example_env_file): + # clear system environment variables + os.environ.clear() + # load dotenv file with pydantic-settings + config = DifyConfig(_env_file=example_env_file) + + # constant values + assert config.COMMIT_SHA == "" + + # default values + assert config.EDITION == "SELF_HOSTED" + assert config.API_COMPRESSION_ENABLED is False + assert config.SENTRY_TRACES_SAMPLE_RATE == 1.0 + + # annotated field with default value + assert config.HTTP_REQUEST_MAX_READ_TIMEOUT == 60 + + # annotated field with configured value + assert config.HTTP_REQUEST_MAX_WRITE_TIMEOUT == 30 + + assert config.WORKFLOW_PARALLEL_DEPTH_LIMIT == 3 + + +# NOTE: If there is a `.env` file in your Workspace, this test might not succeed as expected. +# This is due to `pymilvus` loading all the variables from the `.env` file into `os.environ`. +def test_flask_configs(example_env_file): + flask_app = Flask("app") + # clear system environment variables + os.environ.clear() + flask_app.config.from_mapping(DifyConfig(_env_file=example_env_file).model_dump()) # pyright: ignore + config = flask_app.config + + # configs read from pydantic-settings + assert config["LOG_LEVEL"] == "INFO" + assert config["COMMIT_SHA"] == "" + assert config["EDITION"] == "SELF_HOSTED" + assert config["API_COMPRESSION_ENABLED"] is False + assert config["SENTRY_TRACES_SAMPLE_RATE"] == 1.0 + + # value from env file + assert config["CONSOLE_API_URL"] == "https://example.com" + # fallback to alias choices value as CONSOLE_API_URL + assert config["FILES_URL"] == "https://example.com" + + assert config["SQLALCHEMY_DATABASE_URI"] == "postgresql://postgres:@localhost:5432/dify" + assert config["SQLALCHEMY_ENGINE_OPTIONS"] == { + "connect_args": { + "options": "-c timezone=UTC", + }, + "max_overflow": 10, + "pool_pre_ping": False, + "pool_recycle": 3600, + "pool_size": 30, + } + + assert config["CONSOLE_WEB_URL"] == "https://example.com" + assert config["CONSOLE_CORS_ALLOW_ORIGINS"] == ["https://example.com"] + 
assert config["WEB_API_CORS_ALLOW_ORIGINS"] == ["*"] + + assert str(config["CODE_EXECUTION_ENDPOINT"]) == "http://sandbox:8194/" + assert str(URL(str(config["CODE_EXECUTION_ENDPOINT"])) / "v1") == "http://sandbox:8194/v1" diff --git a/api/tests/unit_tests/conftest.py b/api/tests/unit_tests/conftest.py new file mode 100644 index 0000000000000000000000000000000000000000..e09acc4c39d1c80fadcffef8faa5a9e71e3ffe23 --- /dev/null +++ b/api/tests/unit_tests/conftest.py @@ -0,0 +1,23 @@ +import os + +import pytest +from flask import Flask + +# Getting the absolute path of the current file's directory +ABS_PATH = os.path.dirname(os.path.abspath(__file__)) + +# Getting the absolute path of the project's root directory +PROJECT_DIR = os.path.abspath(os.path.join(ABS_PATH, os.pardir, os.pardir)) + +CACHED_APP = Flask(__name__) + + +@pytest.fixture +def app() -> Flask: + return CACHED_APP + + +@pytest.fixture(autouse=True) +def _provide_app_context(app: Flask): + with app.app_context(): + yield diff --git a/api/tests/unit_tests/controllers/test_compare_versions.py b/api/tests/unit_tests/controllers/test_compare_versions.py new file mode 100644 index 0000000000000000000000000000000000000000..9db57a84460c8bd95ed31003cfb083990d624f43 --- /dev/null +++ b/api/tests/unit_tests/controllers/test_compare_versions.py @@ -0,0 +1,24 @@ +import pytest + +from controllers.console.version import _has_new_version + + +@pytest.mark.parametrize( + ("latest_version", "current_version", "expected"), + [ + ("1.0.1", "1.0.0", True), + ("1.1.0", "1.0.0", True), + ("2.0.0", "1.9.9", True), + ("1.0.0", "1.0.0", False), + ("1.0.0", "1.0.1", False), + ("1.0.0", "2.0.0", False), + ("1.0.1", "1.0.0-beta", True), + ("1.0.0", "1.0.0-alpha", True), + ("1.0.0-beta", "1.0.0-alpha", True), + ("1.0.0", "1.0.0-rc1", True), + ("1.0.0", "0.9.9", True), + ("1.0.0", "1.0.0-dev", True), + ], +) +def test_has_new_version(latest_version, current_version, expected): + assert _has_new_version(latest_version=latest_version, 
current_version=current_version) == expected diff --git a/api/tests/unit_tests/core/__init__.py b/api/tests/unit_tests/core/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/tests/unit_tests/core/app/app_config/features/file_upload/test_manager.py b/api/tests/unit_tests/core/app/app_config/features/file_upload/test_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..50a612ec5fe4e47dd7512f997ba93a7d1504cf08 --- /dev/null +++ b/api/tests/unit_tests/core/app/app_config/features/file_upload/test_manager.py @@ -0,0 +1,61 @@ +from core.app.app_config.features.file_upload.manager import FileUploadConfigManager +from core.file.models import FileTransferMethod, FileUploadConfig, ImageConfig +from core.model_runtime.entities.message_entities import ImagePromptMessageContent + + +def test_convert_with_vision(): + config = { + "file_upload": { + "enabled": True, + "number_limits": 5, + "allowed_file_upload_methods": [FileTransferMethod.REMOTE_URL], + "image": {"detail": "high"}, + } + } + result = FileUploadConfigManager.convert(config, is_vision=True) + expected = FileUploadConfig( + image_config=ImageConfig( + number_limits=5, + transfer_methods=[FileTransferMethod.REMOTE_URL], + detail=ImagePromptMessageContent.DETAIL.HIGH, + ) + ) + assert result == expected + + +def test_convert_without_vision(): + config = { + "file_upload": { + "enabled": True, + "number_limits": 5, + "allowed_file_upload_methods": [FileTransferMethod.REMOTE_URL], + } + } + result = FileUploadConfigManager.convert(config, is_vision=False) + expected = FileUploadConfig( + image_config=ImageConfig(number_limits=5, transfer_methods=[FileTransferMethod.REMOTE_URL]) + ) + assert result == expected + + +def test_validate_and_set_defaults(): + config = {} + result, keys = FileUploadConfigManager.validate_and_set_defaults(config) + assert "file_upload" in result + assert keys == ["file_upload"] + 
+ +def test_validate_and_set_defaults_with_existing_config(): + config = { + "file_upload": { + "enabled": True, + "number_limits": 5, + "allowed_file_upload_methods": [FileTransferMethod.REMOTE_URL], + } + } + result, keys = FileUploadConfigManager.validate_and_set_defaults(config) + assert "file_upload" in result + assert keys == ["file_upload"] + assert result["file_upload"]["enabled"] is True + assert result["file_upload"]["number_limits"] == 5 + assert result["file_upload"]["allowed_file_upload_methods"] == [FileTransferMethod.REMOTE_URL] diff --git a/api/tests/unit_tests/core/app/apps/test_base_app_generator.py b/api/tests/unit_tests/core/app/apps/test_base_app_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..a6bf43ab0cf1e59f79c5bbcd6ec8978bead296b3 --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/test_base_app_generator.py @@ -0,0 +1,52 @@ +import pytest + +from core.app.app_config.entities import VariableEntity, VariableEntityType +from core.app.apps.base_app_generator import BaseAppGenerator + + +def test_validate_inputs_with_zero(): + base_app_generator = BaseAppGenerator() + + var = VariableEntity( + variable="test_var", + label="test_var", + type=VariableEntityType.NUMBER, + required=True, + ) + + # Test with input 0 + result = base_app_generator._validate_inputs( + variable_entity=var, + value=0, + ) + + assert result == 0 + + # Test with input "0" (string) + result = base_app_generator._validate_inputs( + variable_entity=var, + value="0", + ) + + assert result == 0 + + +def test_validate_input_with_none_for_required_variable(): + base_app_generator = BaseAppGenerator() + + for var_type in VariableEntityType: + var = VariableEntity( + variable="test_var", + label="test_var", + type=var_type, + required=True, + ) + + # Test with input None + with pytest.raises(ValueError) as exc_info: + base_app_generator._validate_inputs( + variable_entity=var, + value=None, + ) + + assert str(exc_info.value) == "test_var is 
required in input form" diff --git a/api/tests/unit_tests/core/app/segments/test_factory.py b/api/tests/unit_tests/core/app/segments/test_factory.py new file mode 100644 index 0000000000000000000000000000000000000000..e6e289c12acfe8c8ec1e5744693f2f0a7c8fed5d --- /dev/null +++ b/api/tests/unit_tests/core/app/segments/test_factory.py @@ -0,0 +1,165 @@ +from uuid import uuid4 + +import pytest + +from core.variables import ( + ArrayNumberVariable, + ArrayObjectVariable, + ArrayStringVariable, + FloatVariable, + IntegerVariable, + ObjectSegment, + SecretVariable, + StringVariable, +) +from core.variables.exc import VariableError +from core.variables.segments import ArrayAnySegment +from factories import variable_factory + + +def test_string_variable(): + test_data = {"value_type": "string", "name": "test_text", "value": "Hello, World!"} + result = variable_factory.build_conversation_variable_from_mapping(test_data) + assert isinstance(result, StringVariable) + + +def test_integer_variable(): + test_data = {"value_type": "number", "name": "test_int", "value": 42} + result = variable_factory.build_conversation_variable_from_mapping(test_data) + assert isinstance(result, IntegerVariable) + + +def test_float_variable(): + test_data = {"value_type": "number", "name": "test_float", "value": 3.14} + result = variable_factory.build_conversation_variable_from_mapping(test_data) + assert isinstance(result, FloatVariable) + + +def test_secret_variable(): + test_data = {"value_type": "secret", "name": "test_secret", "value": "secret_value"} + result = variable_factory.build_conversation_variable_from_mapping(test_data) + assert isinstance(result, SecretVariable) + + +def test_invalid_value_type(): + test_data = {"value_type": "unknown", "name": "test_invalid", "value": "value"} + with pytest.raises(VariableError): + variable_factory.build_conversation_variable_from_mapping(test_data) + + +def test_build_a_blank_string(): + result = 
variable_factory.build_conversation_variable_from_mapping( + { + "value_type": "string", + "name": "blank", + "value": "", + } + ) + assert isinstance(result, StringVariable) + assert result.value == "" + + +def test_build_a_object_variable_with_none_value(): + var = variable_factory.build_segment( + { + "key1": None, + } + ) + assert isinstance(var, ObjectSegment) + assert var.value["key1"] is None + + +def test_object_variable(): + mapping = { + "id": str(uuid4()), + "value_type": "object", + "name": "test_object", + "description": "Description of the variable.", + "value": { + "key1": "text", + "key2": 2, + }, + } + variable = variable_factory.build_conversation_variable_from_mapping(mapping) + assert isinstance(variable, ObjectSegment) + assert isinstance(variable.value["key1"], str) + assert isinstance(variable.value["key2"], int) + + +def test_array_string_variable(): + mapping = { + "id": str(uuid4()), + "value_type": "array[string]", + "name": "test_array", + "description": "Description of the variable.", + "value": [ + "text", + "text", + ], + } + variable = variable_factory.build_conversation_variable_from_mapping(mapping) + assert isinstance(variable, ArrayStringVariable) + assert isinstance(variable.value[0], str) + assert isinstance(variable.value[1], str) + + +def test_array_number_variable(): + mapping = { + "id": str(uuid4()), + "value_type": "array[number]", + "name": "test_array", + "description": "Description of the variable.", + "value": [ + 1, + 2.0, + ], + } + variable = variable_factory.build_conversation_variable_from_mapping(mapping) + assert isinstance(variable, ArrayNumberVariable) + assert isinstance(variable.value[0], int) + assert isinstance(variable.value[1], float) + + +def test_array_object_variable(): + mapping = { + "id": str(uuid4()), + "value_type": "array[object]", + "name": "test_array", + "description": "Description of the variable.", + "value": [ + { + "key1": "text", + "key2": 1, + }, + { + "key1": "text", + "key2": 1, + }, 
+ ], + } + variable = variable_factory.build_conversation_variable_from_mapping(mapping) + assert isinstance(variable, ArrayObjectVariable) + assert isinstance(variable.value[0], dict) + assert isinstance(variable.value[1], dict) + assert isinstance(variable.value[0]["key1"], str) + assert isinstance(variable.value[0]["key2"], int) + assert isinstance(variable.value[1]["key1"], str) + assert isinstance(variable.value[1]["key2"], int) + + +def test_variable_cannot_large_than_200_kb(): + with pytest.raises(VariableError): + variable_factory.build_conversation_variable_from_mapping( + { + "id": str(uuid4()), + "value_type": "string", + "name": "test_text", + "value": "a" * 1024 * 201, + } + ) + + +def test_array_none_variable(): + var = variable_factory.build_segment([None, None, None, None]) + assert isinstance(var, ArrayAnySegment) + assert var.value == [None, None, None, None] diff --git a/api/tests/unit_tests/core/app/segments/test_segment.py b/api/tests/unit_tests/core/app/segments/test_segment.py new file mode 100644 index 0000000000000000000000000000000000000000..1b035d01a7ad55c36626a5b5bb6e38707aac5c9f --- /dev/null +++ b/api/tests/unit_tests/core/app/segments/test_segment.py @@ -0,0 +1,58 @@ +from core.helper import encrypter +from core.variables import SecretVariable, StringVariable +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.enums import SystemVariableKey + + +def test_segment_group_to_text(): + variable_pool = VariablePool( + system_variables={ + SystemVariableKey("user_id"): "fake-user-id", + }, + user_inputs={}, + environment_variables=[ + SecretVariable(name="secret_key", value="fake-secret-key"), + ], + conversation_variables=[], + ) + variable_pool.add(("node_id", "custom_query"), "fake-user-query") + template = ( + "Hello, {{#sys.user_id#}}! Your query is {{#node_id.custom_query#}}. And your key is {{#env.secret_key#}}." 
+ ) + segments_group = variable_pool.convert_template(template) + + assert segments_group.text == "Hello, fake-user-id! Your query is fake-user-query. And your key is fake-secret-key." + assert segments_group.log == ( + f"Hello, fake-user-id! Your query is fake-user-query." + f" And your key is {encrypter.obfuscated_token('fake-secret-key')}." + ) + + +def test_convert_constant_to_segment_group(): + variable_pool = VariablePool( + system_variables={}, + user_inputs={}, + environment_variables=[], + conversation_variables=[], + ) + template = "Hello, world!" + segments_group = variable_pool.convert_template(template) + assert segments_group.text == "Hello, world!" + assert segments_group.log == "Hello, world!" + + +def test_convert_variable_to_segment_group(): + variable_pool = VariablePool( + system_variables={ + SystemVariableKey("user_id"): "fake-user-id", + }, + user_inputs={}, + environment_variables=[], + conversation_variables=[], + ) + template = "{{#sys.user_id#}}" + segments_group = variable_pool.convert_template(template) + assert segments_group.text == "fake-user-id" + assert segments_group.log == "fake-user-id" + assert isinstance(segments_group.value[0], StringVariable) + assert segments_group.value[0].value == "fake-user-id" diff --git a/api/tests/unit_tests/core/app/segments/test_variables.py b/api/tests/unit_tests/core/app/segments/test_variables.py new file mode 100644 index 0000000000000000000000000000000000000000..426557c71614190158b32ac2e557bd3828e9f2d9 --- /dev/null +++ b/api/tests/unit_tests/core/app/segments/test_variables.py @@ -0,0 +1,90 @@ +import pytest +from pydantic import ValidationError + +from core.variables import ( + ArrayFileVariable, + ArrayVariable, + FloatVariable, + IntegerVariable, + ObjectVariable, + SecretVariable, + SegmentType, + StringVariable, +) + + +def test_frozen_variables(): + var = StringVariable(name="text", value="text") + with pytest.raises(ValidationError): + var.value = "new value" + + int_var = 
IntegerVariable(name="integer", value=42) + with pytest.raises(ValidationError): + int_var.value = 100 + + float_var = FloatVariable(name="float", value=3.14) + with pytest.raises(ValidationError): + float_var.value = 2.718 + + secret_var = SecretVariable(name="secret", value="secret_value") + with pytest.raises(ValidationError): + secret_var.value = "new_secret_value" + + +def test_variable_value_type_immutable(): + with pytest.raises(ValidationError): + StringVariable(value_type=SegmentType.ARRAY_ANY, name="text", value="text") + + with pytest.raises(ValidationError): + StringVariable.model_validate({"value_type": "not text", "name": "text", "value": "text"}) + + var = IntegerVariable(name="integer", value=42) + with pytest.raises(ValidationError): + IntegerVariable(value_type=SegmentType.ARRAY_ANY, name=var.name, value=var.value) + + var = FloatVariable(name="float", value=3.14) + with pytest.raises(ValidationError): + FloatVariable(value_type=SegmentType.ARRAY_ANY, name=var.name, value=var.value) + + var = SecretVariable(name="secret", value="secret_value") + with pytest.raises(ValidationError): + SecretVariable(value_type=SegmentType.ARRAY_ANY, name=var.name, value=var.value) + + +def test_object_variable_to_object(): + var = ObjectVariable( + name="object", + value={ + "key1": { + "key2": "value2", + }, + "key2": ["value5_1", 42, {}], + }, + ) + + assert var.to_object() == { + "key1": { + "key2": "value2", + }, + "key2": [ + "value5_1", + 42, + {}, + ], + } + + +def test_variable_to_object(): + var = StringVariable(name="text", value="text") + assert var.to_object() == "text" + var = IntegerVariable(name="integer", value=42) + assert var.to_object() == 42 + var = FloatVariable(name="float", value=3.14) + assert var.to_object() == 3.14 + var = SecretVariable(name="secret", value="secret_value") + assert var.to_object() == "secret_value" + + +def test_array_file_variable_is_array_variable(): + var = ArrayFileVariable(name="files", value=[]) + assert 
isinstance(var, ArrayVariable) diff --git a/api/tests/unit_tests/core/helper/__init__.py b/api/tests/unit_tests/core/helper/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/tests/unit_tests/core/helper/test_ssrf_proxy.py b/api/tests/unit_tests/core/helper/test_ssrf_proxy.py new file mode 100644 index 0000000000000000000000000000000000000000..c688d3952b698bde37a32e654a49138dabcf4039 --- /dev/null +++ b/api/tests/unit_tests/core/helper/test_ssrf_proxy.py @@ -0,0 +1,52 @@ +import random +from unittest.mock import MagicMock, patch + +import pytest + +from core.helper.ssrf_proxy import SSRF_DEFAULT_MAX_RETRIES, STATUS_FORCELIST, make_request + + +@patch("httpx.Client.request") +def test_successful_request(mock_request): + mock_response = MagicMock() + mock_response.status_code = 200 + mock_request.return_value = mock_response + + response = make_request("GET", "http://example.com") + assert response.status_code == 200 + + +@patch("httpx.Client.request") +def test_retry_exceed_max_retries(mock_request): + mock_response = MagicMock() + mock_response.status_code = 500 + + side_effects = [mock_response] * SSRF_DEFAULT_MAX_RETRIES + mock_request.side_effect = side_effects + + with pytest.raises(Exception) as e: + make_request("GET", "http://example.com", max_retries=SSRF_DEFAULT_MAX_RETRIES - 1) + assert str(e.value) == f"Reached maximum retries ({SSRF_DEFAULT_MAX_RETRIES - 1}) for URL http://example.com" + + +@patch("httpx.Client.request") +def test_retry_logic_success(mock_request): + side_effects = [] + + for _ in range(SSRF_DEFAULT_MAX_RETRIES): + status_code = random.choice(STATUS_FORCELIST) + mock_response = MagicMock() + mock_response.status_code = status_code + side_effects.append(mock_response) + + mock_response_200 = MagicMock() + mock_response_200.status_code = 200 + side_effects.append(mock_response_200) + + mock_request.side_effect = side_effects + + response = 
make_request("GET", "http://example.com", max_retries=SSRF_DEFAULT_MAX_RETRIES) + + assert response.status_code == 200 + assert mock_request.call_count == SSRF_DEFAULT_MAX_RETRIES + 1 + assert mock_request.call_args_list[0][1].get("method") == "GET" diff --git a/api/tests/unit_tests/core/model_runtime/__init__.py b/api/tests/unit_tests/core/model_runtime/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/tests/unit_tests/core/model_runtime/model_providers/__init__.py b/api/tests/unit_tests/core/model_runtime/model_providers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/tests/unit_tests/core/model_runtime/model_providers/wenxin/__init__.py b/api/tests/unit_tests/core/model_runtime/model_providers/wenxin/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/tests/unit_tests/core/model_runtime/model_providers/wenxin/test_text_embedding.py b/api/tests/unit_tests/core/model_runtime/model_providers/wenxin/test_text_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..a301f56d75c506d2299ca5f1c9a273083ae56f3a --- /dev/null +++ b/api/tests/unit_tests/core/model_runtime/model_providers/wenxin/test_text_embedding.py @@ -0,0 +1,77 @@ +import string + +import numpy as np + +from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult +from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer +from core.model_runtime.model_providers.wenxin.text_embedding.text_embedding import ( + TextEmbedding, + WenxinTextEmbeddingModel, +) + + +def test_max_chunks(): + class _MockTextEmbedding(TextEmbedding): + def embed_documents(self, model: str, texts: list[str], user: str) -> (list[list[float]], int, int): + embeddings = [[1.0, 2.0, 3.0] for 
i in range(len(texts))] + tokens = 0 + for text in texts: + tokens += len(text) + + return embeddings, tokens, tokens + + def _create_text_embedding(api_key: str, secret_key: str) -> TextEmbedding: + return _MockTextEmbedding() + + model = "embedding-v1" + credentials = { + "api_key": "xxxx", + "secret_key": "yyyy", + } + embedding_model = WenxinTextEmbeddingModel() + context_size = embedding_model._get_context_size(model, credentials) + max_chunks = embedding_model._get_max_chunks(model, credentials) + embedding_model._create_text_embedding = _create_text_embedding + + texts = [string.digits for i in range(0, max_chunks * 2)] + result: TextEmbeddingResult = embedding_model.invoke(model, credentials, texts, "test") + assert len(result.embeddings) == max_chunks * 2 + + +def test_context_size(): + def get_num_tokens_by_gpt2(text: str) -> int: + return GPT2Tokenizer.get_num_tokens(text) + + def mock_text(token_size: int) -> str: + _text = "".join(["0" for i in range(token_size)]) + num_tokens = get_num_tokens_by_gpt2(_text) + ratio = int(np.floor(len(_text) / num_tokens)) + m_text = "".join([_text for i in range(ratio)]) + return m_text + + model = "embedding-v1" + credentials = { + "api_key": "xxxx", + "secret_key": "yyyy", + } + embedding_model = WenxinTextEmbeddingModel() + context_size = embedding_model._get_context_size(model, credentials) + + class _MockTextEmbedding(TextEmbedding): + def embed_documents(self, model: str, texts: list[str], user: str) -> (list[list[float]], int, int): + embeddings = [[1.0, 2.0, 3.0] for i in range(len(texts))] + tokens = 0 + for text in texts: + tokens += get_num_tokens_by_gpt2(text) + return embeddings, tokens, tokens + + def _create_text_embedding(api_key: str, secret_key: str) -> TextEmbedding: + return _MockTextEmbedding() + + embedding_model._create_text_embedding = _create_text_embedding + text = mock_text(context_size * 2) + assert get_num_tokens_by_gpt2(text) == context_size * 2 + + texts = [text] + result: 
TextEmbeddingResult = embedding_model.invoke(model, credentials, texts, "test") + assert result.usage.tokens == context_size diff --git a/api/tests/unit_tests/core/prompt/__init__.py b/api/tests/unit_tests/core/prompt/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py new file mode 100644 index 0000000000000000000000000000000000000000..f6d22690d18539d8210c7461f775801387857d4e --- /dev/null +++ b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py @@ -0,0 +1,190 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from configs import dify_config +from core.app.app_config.entities import ModelConfigEntity +from core.file import File, FileTransferMethod, FileType +from core.memory.token_buffer_memory import TokenBufferMemory +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + ImagePromptMessageContent, + PromptMessageRole, + UserPromptMessage, +) +from core.prompt.advanced_prompt_transform import AdvancedPromptTransform +from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig +from core.prompt.utils.prompt_template_parser import PromptTemplateParser +from models.model import Conversation + + +def test__get_completion_model_prompt_messages(): + model_config_mock = MagicMock(spec=ModelConfigEntity) + model_config_mock.provider = "openai" + model_config_mock.model = "gpt-3.5-turbo-instruct" + + prompt_template = "Context:\n{{#context#}}\n\nHistories:\n{{#histories#}}\n\nyou are {{name}}." 
+ prompt_template_config = CompletionModelPromptTemplate(text=prompt_template) + + memory_config = MemoryConfig( + role_prefix=MemoryConfig.RolePrefix(user="Human", assistant="Assistant"), + window=MemoryConfig.WindowConfig(enabled=False), + ) + + inputs = {"name": "John"} + files = [] + context = "I am superman." + + memory = TokenBufferMemory(conversation=Conversation(), model_instance=model_config_mock) + + history_prompt_messages = [UserPromptMessage(content="Hi"), AssistantPromptMessage(content="Hello")] + memory.get_history_prompt_messages = MagicMock(return_value=history_prompt_messages) + + prompt_transform = AdvancedPromptTransform() + prompt_transform._calculate_rest_token = MagicMock(return_value=2000) + prompt_messages = prompt_transform._get_completion_model_prompt_messages( + prompt_template=prompt_template_config, + inputs=inputs, + query=None, + files=files, + context=context, + memory_config=memory_config, + memory=memory, + model_config=model_config_mock, + ) + + assert len(prompt_messages) == 1 + assert prompt_messages[0].content == PromptTemplateParser(template=prompt_template).format( + { + "#context#": context, + "#histories#": "\n".join( + [ + f"{'Human' if prompt.role.value == 'user' else 'Assistant'}: {prompt.content}" + for prompt in history_prompt_messages + ] + ), + **inputs, + } + ) + + +def test__get_chat_model_prompt_messages(get_chat_model_args): + model_config_mock, memory_config, messages, inputs, context = get_chat_model_args + + files = [] + query = "Hi2." 
+ + memory = TokenBufferMemory(conversation=Conversation(), model_instance=model_config_mock) + + history_prompt_messages = [UserPromptMessage(content="Hi1."), AssistantPromptMessage(content="Hello1!")] + memory.get_history_prompt_messages = MagicMock(return_value=history_prompt_messages) + + prompt_transform = AdvancedPromptTransform() + prompt_transform._calculate_rest_token = MagicMock(return_value=2000) + prompt_messages = prompt_transform._get_chat_model_prompt_messages( + prompt_template=messages, + inputs=inputs, + query=query, + files=files, + context=context, + memory_config=memory_config, + memory=memory, + model_config=model_config_mock, + ) + + assert len(prompt_messages) == 6 + assert prompt_messages[0].role == PromptMessageRole.SYSTEM + assert prompt_messages[0].content == PromptTemplateParser(template=messages[0].text).format( + {**inputs, "#context#": context} + ) + assert prompt_messages[5].content == query + + +def test__get_chat_model_prompt_messages_no_memory(get_chat_model_args): + model_config_mock, _, messages, inputs, context = get_chat_model_args + + files = [] + + prompt_transform = AdvancedPromptTransform() + prompt_transform._calculate_rest_token = MagicMock(return_value=2000) + prompt_messages = prompt_transform._get_chat_model_prompt_messages( + prompt_template=messages, + inputs=inputs, + query=None, + files=files, + context=context, + memory_config=None, + memory=None, + model_config=model_config_mock, + ) + + assert len(prompt_messages) == 3 + assert prompt_messages[0].role == PromptMessageRole.SYSTEM + assert prompt_messages[0].content == PromptTemplateParser(template=messages[0].text).format( + {**inputs, "#context#": context} + ) + + +def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_args): + model_config_mock, _, messages, inputs, context = get_chat_model_args + dify_config.MULTIMODAL_SEND_FORMAT = "url" + + files = [ + File( + id="file1", + tenant_id="tenant1", + type=FileType.IMAGE, + 
transfer_method=FileTransferMethod.REMOTE_URL, + remote_url="https://example.com/image1.jpg", + storage_key="", + ) + ] + + prompt_transform = AdvancedPromptTransform() + prompt_transform._calculate_rest_token = MagicMock(return_value=2000) + with patch("core.file.file_manager.to_prompt_message_content") as mock_get_encoded_string: + mock_get_encoded_string.return_value = ImagePromptMessageContent( + url=str(files[0].remote_url), format="jpg", mime_type="image/jpg" + ) + prompt_messages = prompt_transform._get_chat_model_prompt_messages( + prompt_template=messages, + inputs=inputs, + query=None, + files=files, + context=context, + memory_config=None, + memory=None, + model_config=model_config_mock, + ) + + assert len(prompt_messages) == 4 + assert prompt_messages[0].role == PromptMessageRole.SYSTEM + assert prompt_messages[0].content == PromptTemplateParser(template=messages[0].text).format( + {**inputs, "#context#": context} + ) + assert isinstance(prompt_messages[3].content, list) + assert len(prompt_messages[3].content) == 2 + assert prompt_messages[3].content[1].data == files[0].remote_url + + +@pytest.fixture +def get_chat_model_args(): + model_config_mock = MagicMock(spec=ModelConfigEntity) + model_config_mock.provider = "openai" + model_config_mock.model = "gpt-4" + + memory_config = MemoryConfig(window=MemoryConfig.WindowConfig(enabled=False)) + + prompt_messages = [ + ChatModelMessage( + text="You are a helpful assistant named {{name}}.\n\nContext:\n{{#context#}}", role=PromptMessageRole.SYSTEM + ), + ChatModelMessage(text="Hi.", role=PromptMessageRole.USER), + ChatModelMessage(text="Hello!", role=PromptMessageRole.ASSISTANT), + ] + + inputs = {"name": "John"} + + context = "I am superman." 
+ + return model_config_mock, memory_config, prompt_messages, inputs, context diff --git a/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py new file mode 100644 index 0000000000000000000000000000000000000000..0fd176e65d02f88103c9df4f03a4df9dfdf46928 --- /dev/null +++ b/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py @@ -0,0 +1,75 @@ +from unittest.mock import MagicMock + +from core.app.entities.app_invoke_entities import ( + ModelConfigWithCredentialsEntity, +) +from core.entities.provider_configuration import ProviderModelBundle +from core.memory.token_buffer_memory import TokenBufferMemory +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + SystemPromptMessage, + ToolPromptMessage, + UserPromptMessage, +) +from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform +from models.model import Conversation + + +def test_get_prompt(): + prompt_messages = [ + SystemPromptMessage(content="System Template"), + UserPromptMessage(content="User Query"), + ] + history_messages = [ + SystemPromptMessage(content="System Prompt 1"), + UserPromptMessage(content="User Prompt 1"), + AssistantPromptMessage(content="Assistant Thought 1"), + ToolPromptMessage(content="Tool 1-1", name="Tool 1-1", tool_call_id="1"), + ToolPromptMessage(content="Tool 1-2", name="Tool 1-2", tool_call_id="2"), + SystemPromptMessage(content="System Prompt 2"), + UserPromptMessage(content="User Prompt 2"), + AssistantPromptMessage(content="Assistant Thought 2"), + ToolPromptMessage(content="Tool 2-1", name="Tool 2-1", tool_call_id="3"), + ToolPromptMessage(content="Tool 2-2", name="Tool 2-2", tool_call_id="4"), + UserPromptMessage(content="User Prompt 3"), + AssistantPromptMessage(content="Assistant Thought 3"), + ] + + # use message 
number instead of token for testing + def side_effect_get_num_tokens(*args): + return len(args[2]) + + large_language_model_mock = MagicMock(spec=LargeLanguageModel) + large_language_model_mock.get_num_tokens = MagicMock(side_effect=side_effect_get_num_tokens) + + provider_model_bundle_mock = MagicMock(spec=ProviderModelBundle) + provider_model_bundle_mock.model_type_instance = large_language_model_mock + + model_config_mock = MagicMock(spec=ModelConfigWithCredentialsEntity) + model_config_mock.model = "openai" + model_config_mock.credentials = {} + model_config_mock.provider_model_bundle = provider_model_bundle_mock + + memory = TokenBufferMemory(conversation=Conversation(), model_instance=model_config_mock) + + transform = AgentHistoryPromptTransform( + model_config=model_config_mock, + prompt_messages=prompt_messages, + history_messages=history_messages, + memory=memory, + ) + + max_token_limit = 5 + transform._calculate_rest_token = MagicMock(return_value=max_token_limit) + result = transform.get_prompt() + + assert len(result) <= max_token_limit + assert len(result) == 4 + + max_token_limit = 20 + transform._calculate_rest_token = MagicMock(return_value=max_token_limit) + result = transform.get_prompt() + + assert len(result) <= max_token_limit + assert len(result) == 12 diff --git a/api/tests/unit_tests/core/prompt/test_extract_thread_messages.py b/api/tests/unit_tests/core/prompt/test_extract_thread_messages.py new file mode 100644 index 0000000000000000000000000000000000000000..ba3c1eb5e032a002e285a20c3b719585a828f15c --- /dev/null +++ b/api/tests/unit_tests/core/prompt/test_extract_thread_messages.py @@ -0,0 +1,91 @@ +from uuid import uuid4 + +from constants import UUID_NIL +from core.prompt.utils.extract_thread_messages import extract_thread_messages + + +class TestMessage: + def __init__(self, id, parent_message_id): + self.id = id + self.parent_message_id = parent_message_id + + def __getitem__(self, item): + return getattr(self, item) + + +def 
test_extract_thread_messages_single_message(): + messages = [TestMessage(str(uuid4()), UUID_NIL)] + result = extract_thread_messages(messages) + assert len(result) == 1 + assert result[0] == messages[0] + + +def test_extract_thread_messages_linear_thread(): + id1, id2, id3, id4, id5 = str(uuid4()), str(uuid4()), str(uuid4()), str(uuid4()), str(uuid4()) + messages = [ + TestMessage(id5, id4), + TestMessage(id4, id3), + TestMessage(id3, id2), + TestMessage(id2, id1), + TestMessage(id1, UUID_NIL), + ] + result = extract_thread_messages(messages) + assert len(result) == 5 + assert [msg["id"] for msg in result] == [id5, id4, id3, id2, id1] + + +def test_extract_thread_messages_branched_thread(): + id1, id2, id3, id4 = str(uuid4()), str(uuid4()), str(uuid4()), str(uuid4()) + messages = [ + TestMessage(id4, id2), + TestMessage(id3, id2), + TestMessage(id2, id1), + TestMessage(id1, UUID_NIL), + ] + result = extract_thread_messages(messages) + assert len(result) == 3 + assert [msg["id"] for msg in result] == [id4, id2, id1] + + +def test_extract_thread_messages_empty_list(): + messages = [] + result = extract_thread_messages(messages) + assert len(result) == 0 + + +def test_extract_thread_messages_partially_loaded(): + id0, id1, id2, id3 = str(uuid4()), str(uuid4()), str(uuid4()), str(uuid4()) + messages = [ + TestMessage(id3, id2), + TestMessage(id2, id1), + TestMessage(id1, id0), + ] + result = extract_thread_messages(messages) + assert len(result) == 3 + assert [msg["id"] for msg in result] == [id3, id2, id1] + + +def test_extract_thread_messages_legacy_messages(): + id1, id2, id3 = str(uuid4()), str(uuid4()), str(uuid4()) + messages = [ + TestMessage(id3, UUID_NIL), + TestMessage(id2, UUID_NIL), + TestMessage(id1, UUID_NIL), + ] + result = extract_thread_messages(messages) + assert len(result) == 3 + assert [msg["id"] for msg in result] == [id3, id2, id1] + + +def test_extract_thread_messages_mixed_with_legacy_messages(): + id1, id2, id3, id4, id5 = str(uuid4()), 
str(uuid4()), str(uuid4()), str(uuid4()), str(uuid4()) + messages = [ + TestMessage(id5, id4), + TestMessage(id4, id2), + TestMessage(id3, id2), + TestMessage(id2, UUID_NIL), + TestMessage(id1, UUID_NIL), + ] + result = extract_thread_messages(messages) + assert len(result) == 4 + assert [msg["id"] for msg in result] == [id5, id4, id2, id1] diff --git a/api/tests/unit_tests/core/prompt/test_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_prompt_transform.py new file mode 100644 index 0000000000000000000000000000000000000000..89c14463bbfb948b7b3da4a62b856690ef3e4f06 --- /dev/null +++ b/api/tests/unit_tests/core/prompt/test_prompt_transform.py @@ -0,0 +1,52 @@ +from unittest.mock import MagicMock + +from core.app.app_config.entities import ModelConfigEntity +from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle +from core.model_runtime.entities.message_entities import UserPromptMessage +from core.model_runtime.entities.model_entities import AIModelEntity, ModelPropertyKey, ParameterRule +from core.model_runtime.entities.provider_entities import ProviderEntity +from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from core.prompt.prompt_transform import PromptTransform + + +def test__calculate_rest_token(): + model_schema_mock = MagicMock(spec=AIModelEntity) + parameter_rule_mock = MagicMock(spec=ParameterRule) + parameter_rule_mock.name = "max_tokens" + model_schema_mock.parameter_rules = [parameter_rule_mock] + model_schema_mock.model_properties = {ModelPropertyKey.CONTEXT_SIZE: 62} + + large_language_model_mock = MagicMock(spec=LargeLanguageModel) + large_language_model_mock.get_num_tokens.return_value = 6 + + provider_mock = MagicMock(spec=ProviderEntity) + provider_mock.provider = "openai" + + provider_configuration_mock = MagicMock(spec=ProviderConfiguration) + provider_configuration_mock.provider = provider_mock + provider_configuration_mock.model_settings = None + + 
provider_model_bundle_mock = MagicMock(spec=ProviderModelBundle) + provider_model_bundle_mock.model_type_instance = large_language_model_mock + provider_model_bundle_mock.configuration = provider_configuration_mock + + model_config_mock = MagicMock(spec=ModelConfigEntity) + model_config_mock.model = "gpt-4" + model_config_mock.credentials = {} + model_config_mock.parameters = {"max_tokens": 50} + model_config_mock.model_schema = model_schema_mock + model_config_mock.provider_model_bundle = provider_model_bundle_mock + + prompt_transform = PromptTransform() + + prompt_messages = [UserPromptMessage(content="Hello, how are you?")] + rest_tokens = prompt_transform._calculate_rest_token(prompt_messages, model_config_mock) + + # Validate based on the mock configuration and expected logic + expected_rest_tokens = ( + model_schema_mock.model_properties[ModelPropertyKey.CONTEXT_SIZE] + - model_config_mock.parameters["max_tokens"] + - large_language_model_mock.get_num_tokens.return_value + ) + assert rest_tokens == expected_rest_tokens + assert rest_tokens == 6 diff --git a/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py new file mode 100644 index 0000000000000000000000000000000000000000..c32fc2bc34813d468fd5a95adb09604a222c8d93 --- /dev/null +++ b/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py @@ -0,0 +1,247 @@ +from unittest.mock import MagicMock + +from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity +from core.memory.token_buffer_memory import TokenBufferMemory +from core.model_runtime.entities.message_entities import AssistantPromptMessage, UserPromptMessage +from core.prompt.simple_prompt_transform import SimplePromptTransform +from models.model import AppMode, Conversation + + +def test_get_common_chat_app_prompt_template_with_pcqm(): + prompt_transform = SimplePromptTransform() + pre_prompt = "You are a helpful assistant." 
+ prompt_template = prompt_transform.get_prompt_template( + app_mode=AppMode.CHAT, + provider="openai", + model="gpt-4", + pre_prompt=pre_prompt, + has_context=True, + query_in_prompt=True, + with_memory_prompt=True, + ) + prompt_rules = prompt_template["prompt_rules"] + assert prompt_template["prompt_template"].template == ( + prompt_rules["context_prompt"] + + pre_prompt + + "\n" + + prompt_rules["histories_prompt"] + + prompt_rules["query_prompt"] + ) + assert prompt_template["special_variable_keys"] == ["#context#", "#histories#", "#query#"] + + +def test_get_baichuan_chat_app_prompt_template_with_pcqm(): + prompt_transform = SimplePromptTransform() + pre_prompt = "You are a helpful assistant." + prompt_template = prompt_transform.get_prompt_template( + app_mode=AppMode.CHAT, + provider="baichuan", + model="Baichuan2-53B", + pre_prompt=pre_prompt, + has_context=True, + query_in_prompt=True, + with_memory_prompt=True, + ) + prompt_rules = prompt_template["prompt_rules"] + assert prompt_template["prompt_template"].template == ( + prompt_rules["context_prompt"] + + pre_prompt + + "\n" + + prompt_rules["histories_prompt"] + + prompt_rules["query_prompt"] + ) + assert prompt_template["special_variable_keys"] == ["#context#", "#histories#", "#query#"] + + +def test_get_common_completion_app_prompt_template_with_pcq(): + prompt_transform = SimplePromptTransform() + pre_prompt = "You are a helpful assistant." 
+ prompt_template = prompt_transform.get_prompt_template( + app_mode=AppMode.WORKFLOW, + provider="openai", + model="gpt-4", + pre_prompt=pre_prompt, + has_context=True, + query_in_prompt=True, + with_memory_prompt=False, + ) + prompt_rules = prompt_template["prompt_rules"] + assert prompt_template["prompt_template"].template == ( + prompt_rules["context_prompt"] + pre_prompt + "\n" + prompt_rules["query_prompt"] + ) + assert prompt_template["special_variable_keys"] == ["#context#", "#query#"] + + +def test_get_baichuan_completion_app_prompt_template_with_pcq(): + prompt_transform = SimplePromptTransform() + pre_prompt = "You are a helpful assistant." + prompt_template = prompt_transform.get_prompt_template( + app_mode=AppMode.WORKFLOW, + provider="baichuan", + model="Baichuan2-53B", + pre_prompt=pre_prompt, + has_context=True, + query_in_prompt=True, + with_memory_prompt=False, + ) + print(prompt_template["prompt_template"].template) + prompt_rules = prompt_template["prompt_rules"] + assert prompt_template["prompt_template"].template == ( + prompt_rules["context_prompt"] + pre_prompt + "\n" + prompt_rules["query_prompt"] + ) + assert prompt_template["special_variable_keys"] == ["#context#", "#query#"] + + +def test_get_common_chat_app_prompt_template_with_q(): + prompt_transform = SimplePromptTransform() + pre_prompt = "" + prompt_template = prompt_transform.get_prompt_template( + app_mode=AppMode.CHAT, + provider="openai", + model="gpt-4", + pre_prompt=pre_prompt, + has_context=False, + query_in_prompt=True, + with_memory_prompt=False, + ) + prompt_rules = prompt_template["prompt_rules"] + assert prompt_template["prompt_template"].template == prompt_rules["query_prompt"] + assert prompt_template["special_variable_keys"] == ["#query#"] + + +def test_get_common_chat_app_prompt_template_with_cq(): + prompt_transform = SimplePromptTransform() + pre_prompt = "" + prompt_template = prompt_transform.get_prompt_template( + app_mode=AppMode.CHAT, + provider="openai", + 
model="gpt-4", + pre_prompt=pre_prompt, + has_context=True, + query_in_prompt=True, + with_memory_prompt=False, + ) + prompt_rules = prompt_template["prompt_rules"] + assert prompt_template["prompt_template"].template == ( + prompt_rules["context_prompt"] + prompt_rules["query_prompt"] + ) + assert prompt_template["special_variable_keys"] == ["#context#", "#query#"] + + +def test_get_common_chat_app_prompt_template_with_p(): + prompt_transform = SimplePromptTransform() + pre_prompt = "you are {{name}}" + prompt_template = prompt_transform.get_prompt_template( + app_mode=AppMode.CHAT, + provider="openai", + model="gpt-4", + pre_prompt=pre_prompt, + has_context=False, + query_in_prompt=False, + with_memory_prompt=False, + ) + assert prompt_template["prompt_template"].template == pre_prompt + "\n" + assert prompt_template["custom_variable_keys"] == ["name"] + assert prompt_template["special_variable_keys"] == [] + + +def test__get_chat_model_prompt_messages(): + model_config_mock = MagicMock(spec=ModelConfigWithCredentialsEntity) + model_config_mock.provider = "openai" + model_config_mock.model = "gpt-4" + + memory_mock = MagicMock(spec=TokenBufferMemory) + history_prompt_messages = [UserPromptMessage(content="Hi"), AssistantPromptMessage(content="Hello")] + memory_mock.get_history_prompt_messages.return_value = history_prompt_messages + + prompt_transform = SimplePromptTransform() + prompt_transform._calculate_rest_token = MagicMock(return_value=2000) + + pre_prompt = "You are a helpful assistant {{name}}." + inputs = {"name": "John"} + context = "yes or no." + query = "How are you?" 
+ prompt_messages, _ = prompt_transform._get_chat_model_prompt_messages( + app_mode=AppMode.CHAT, + pre_prompt=pre_prompt, + inputs=inputs, + query=query, + files=[], + context=context, + memory=memory_mock, + model_config=model_config_mock, + ) + + prompt_template = prompt_transform.get_prompt_template( + app_mode=AppMode.CHAT, + provider=model_config_mock.provider, + model=model_config_mock.model, + pre_prompt=pre_prompt, + has_context=True, + query_in_prompt=False, + with_memory_prompt=False, + ) + + full_inputs = {**inputs, "#context#": context} + real_system_prompt = prompt_template["prompt_template"].format(full_inputs) + + assert len(prompt_messages) == 4 + assert prompt_messages[0].content == real_system_prompt + assert prompt_messages[1].content == history_prompt_messages[0].content + assert prompt_messages[2].content == history_prompt_messages[1].content + assert prompt_messages[3].content == query + + +def test__get_completion_model_prompt_messages(): + model_config_mock = MagicMock(spec=ModelConfigWithCredentialsEntity) + model_config_mock.provider = "openai" + model_config_mock.model = "gpt-3.5-turbo-instruct" + + memory = TokenBufferMemory(conversation=Conversation(), model_instance=model_config_mock) + + history_prompt_messages = [UserPromptMessage(content="Hi"), AssistantPromptMessage(content="Hello")] + memory.get_history_prompt_messages = MagicMock(return_value=history_prompt_messages) + + prompt_transform = SimplePromptTransform() + prompt_transform._calculate_rest_token = MagicMock(return_value=2000) + pre_prompt = "You are a helpful assistant {{name}}." + inputs = {"name": "John"} + context = "yes or no." + query = "How are you?" 
+ prompt_messages, stops = prompt_transform._get_completion_model_prompt_messages( + app_mode=AppMode.CHAT, + pre_prompt=pre_prompt, + inputs=inputs, + query=query, + files=[], + context=context, + memory=memory, + model_config=model_config_mock, + ) + + prompt_template = prompt_transform.get_prompt_template( + app_mode=AppMode.CHAT, + provider=model_config_mock.provider, + model=model_config_mock.model, + pre_prompt=pre_prompt, + has_context=True, + query_in_prompt=True, + with_memory_prompt=True, + ) + + prompt_rules = prompt_template["prompt_rules"] + full_inputs = { + **inputs, + "#context#": context, + "#query#": query, + "#histories#": memory.get_history_prompt_text( + max_token_limit=2000, + human_prefix=prompt_rules.get("human_prefix", "Human"), + ai_prefix=prompt_rules.get("assistant_prefix", "Assistant"), + ), + } + real_prompt = prompt_template["prompt_template"].format(full_inputs) + + assert len(prompt_messages) == 1 + assert stops == prompt_rules.get("stops") + assert prompt_messages[0].content == real_prompt diff --git a/api/tests/unit_tests/core/rag/__init__.py b/api/tests/unit_tests/core/rag/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/tests/unit_tests/core/rag/datasource/__init__.py b/api/tests/unit_tests/core/rag/datasource/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/__init__.py b/api/tests/unit_tests/core/rag/datasource/vdb/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/milvus/test_milvus.py b/api/tests/unit_tests/core/rag/datasource/vdb/milvus/test_milvus.py new file mode 100644 index 0000000000000000000000000000000000000000..bd414c88f4452cfd4b489056979fc2e31182d13d --- /dev/null +++ 
b/api/tests/unit_tests/core/rag/datasource/vdb/milvus/test_milvus.py @@ -0,0 +1,18 @@ +import pytest +from pydantic.error_wrappers import ValidationError + +from core.rag.datasource.vdb.milvus.milvus_vector import MilvusConfig + + +def test_default_value(): + valid_config = {"uri": "http://localhost:19530", "user": "root", "password": "Milvus"} + + for key in valid_config: + config = valid_config.copy() + del config[key] + with pytest.raises(ValidationError) as e: + MilvusConfig(**config) + assert e.value.errors()[0]["msg"] == f"Value error, config MILVUS_{key.upper()} is required" + + config = MilvusConfig(**valid_config) + assert config.database == "default" diff --git a/api/tests/unit_tests/core/rag/extractor/__init__.py b/api/tests/unit_tests/core/rag/extractor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/tests/unit_tests/core/rag/extractor/firecrawl/__init__.py b/api/tests/unit_tests/core/rag/extractor/firecrawl/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/tests/unit_tests/core/rag/extractor/firecrawl/test_firecrawl.py b/api/tests/unit_tests/core/rag/extractor/firecrawl/test_firecrawl.py new file mode 100644 index 0000000000000000000000000000000000000000..607728efd8e28ae112816f8f138387e71af64682 --- /dev/null +++ b/api/tests/unit_tests/core/rag/extractor/firecrawl/test_firecrawl.py @@ -0,0 +1,26 @@ +import os + +from core.rag.extractor.firecrawl.firecrawl_app import FirecrawlApp +from tests.unit_tests.core.rag.extractor.test_notion_extractor import _mock_response + + +def test_firecrawl_web_extractor_crawl_mode(mocker): + url = "https://firecrawl.dev" + api_key = os.getenv("FIRECRAWL_API_KEY") or "fc-" + base_url = "https://api.firecrawl.dev" + firecrawl_app = FirecrawlApp(api_key=api_key, base_url=base_url) + params = { + "includePaths": [], + "excludePaths": [], + 
"maxDepth": 1, + "limit": 1, + } + mocked_firecrawl = { + "id": "test", + } + mocker.patch("requests.post", return_value=_mock_response(mocked_firecrawl)) + job_id = firecrawl_app.crawl_url(url, params) + print(f"job_id: {job_id}") + + assert job_id is not None + assert isinstance(job_id, str) diff --git a/api/tests/unit_tests/core/rag/extractor/test_notion_extractor.py b/api/tests/unit_tests/core/rag/extractor/test_notion_extractor.py new file mode 100644 index 0000000000000000000000000000000000000000..eea584a2f8edc63d617ccb26d22430d40a7117d9 --- /dev/null +++ b/api/tests/unit_tests/core/rag/extractor/test_notion_extractor.py @@ -0,0 +1,91 @@ +from unittest import mock + +from core.rag.extractor import notion_extractor + +user_id = "user1" +database_id = "database1" +page_id = "page1" + + +extractor = notion_extractor.NotionExtractor( + notion_workspace_id="x", notion_obj_id="x", notion_page_type="page", tenant_id="x", notion_access_token="x" +) + + +def _generate_page(page_title: str): + return { + "object": "page", + "id": page_id, + "properties": { + "Page": { + "type": "title", + "title": [{"type": "text", "text": {"content": page_title}, "plain_text": page_title}], + } + }, + } + + +def _generate_block(block_id: str, block_type: str, block_text: str): + return { + "object": "block", + "id": block_id, + "parent": {"type": "page_id", "page_id": page_id}, + "type": block_type, + "has_children": False, + block_type: { + "rich_text": [ + { + "type": "text", + "text": {"content": block_text}, + "plain_text": block_text, + } + ] + }, + } + + +def _mock_response(data): + response = mock.Mock() + response.status_code = 200 + response.json.return_value = data + return response + + +def _remove_multiple_new_lines(text): + while "\n\n" in text: + text = text.replace("\n\n", "\n") + return text.strip() + + +def test_notion_page(mocker): + texts = ["Head 1", "1.1", "paragraph 1", "1.1.1"] + mocked_notion_page = { + "object": "list", + "results": [ + _generate_block("b1", 
"heading_1", texts[0]), + _generate_block("b2", "heading_2", texts[1]), + _generate_block("b3", "paragraph", texts[2]), + _generate_block("b4", "heading_3", texts[3]), + ], + "next_cursor": None, + } + mocker.patch("requests.request", return_value=_mock_response(mocked_notion_page)) + + page_docs = extractor._load_data_as_documents(page_id, "page") + assert len(page_docs) == 1 + content = _remove_multiple_new_lines(page_docs[0].page_content) + assert content == "# Head 1\n## 1.1\nparagraph 1\n### 1.1.1" + + +def test_notion_database(mocker): + page_title_list = ["page1", "page2", "page3"] + mocked_notion_database = { + "object": "list", + "results": [_generate_page(i) for i in page_title_list], + "next_cursor": None, + } + mocker.patch("requests.post", return_value=_mock_response(mocked_notion_database)) + database_docs = extractor._load_data_as_documents(database_id, "database") + assert len(database_docs) == 1 + content = _remove_multiple_new_lines(database_docs[0].page_content) + assert content == "\n".join([f"Page:{i}" for i in page_title_list]) diff --git a/api/tests/unit_tests/core/test_file.py b/api/tests/unit_tests/core/test_file.py new file mode 100644 index 0000000000000000000000000000000000000000..e02d882780900f373a553a3017d7551717bc5a90 --- /dev/null +++ b/api/tests/unit_tests/core/test_file.py @@ -0,0 +1,56 @@ +import json + +from core.file import File, FileTransferMethod, FileType, FileUploadConfig +from models.workflow import Workflow + + +def test_file_to_dict(): + file = File( + id="file1", + tenant_id="tenant1", + type=FileType.IMAGE, + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url="https://example.com/image1.jpg", + storage_key="storage_key", + ) + + file_dict = file.to_dict() + assert "_storage_key" not in file_dict + assert "url" in file_dict + + +def test_workflow_features_with_image(): + # Create a feature dict that mimics the old structure with image config + features = { + "file_upload": { + "image": {"enabled": True, 
"number_limits": 5, "transfer_methods": ["remote_url", "local_file"]} + } + } + + # Create a workflow instance with the features + workflow = Workflow( + tenant_id="tenant-1", + app_id="app-1", + type="chat", + version="1.0", + graph="{}", + features=json.dumps(features), + created_by="user-1", + environment_variables=[], + conversation_variables=[], + ) + + # Get the converted features through the property + converted_features = json.loads(workflow.features) + + # Create FileUploadConfig from the converted features + file_upload_config = FileUploadConfig.model_validate(converted_features["file_upload"]) + + # Validate the config + assert file_upload_config.number_limits == 5 + assert list(file_upload_config.allowed_file_types) == [FileType.IMAGE] + assert list(file_upload_config.allowed_file_upload_methods) == [ + FileTransferMethod.REMOTE_URL, + FileTransferMethod.LOCAL_FILE, + ] + assert list(file_upload_config.allowed_file_extensions) == [] diff --git a/api/tests/unit_tests/core/test_model_manager.py b/api/tests/unit_tests/core/test_model_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..d98e9f6badf57c6c6d8d90d40554b24093874000 --- /dev/null +++ b/api/tests/unit_tests/core/test_model_manager.py @@ -0,0 +1,72 @@ +from unittest.mock import MagicMock, patch + +import pytest +import redis + +from core.entities.provider_entities import ModelLoadBalancingConfiguration +from core.model_manager import LBModelManager +from core.model_runtime.entities.model_entities import ModelType +from extensions.ext_redis import redis_client + + +@pytest.fixture +def lb_model_manager(): + load_balancing_configs = [ + ModelLoadBalancingConfiguration(id="id1", name="__inherit__", credentials={}), + ModelLoadBalancingConfiguration(id="id2", name="first", credentials={"openai_api_key": "fake_key"}), + ModelLoadBalancingConfiguration(id="id3", name="second", credentials={"openai_api_key": "fake_key"}), + ] + + lb_model_manager = LBModelManager( + 
tenant_id="tenant_id", + provider="openai", + model_type=ModelType.LLM, + model="gpt-4", + load_balancing_configs=load_balancing_configs, + managed_credentials={"openai_api_key": "fake_key"}, + ) + + lb_model_manager.cooldown = MagicMock(return_value=None) + + def is_cooldown(config: ModelLoadBalancingConfiguration): + if config.id == "id1": + return True + + return False + + lb_model_manager.in_cooldown = MagicMock(side_effect=is_cooldown) + + return lb_model_manager + + +def test_lb_model_manager_fetch_next(mocker, lb_model_manager): + # initialize redis client + redis_client.initialize(redis.Redis()) + + assert len(lb_model_manager._load_balancing_configs) == 3 + + config1 = lb_model_manager._load_balancing_configs[0] + config2 = lb_model_manager._load_balancing_configs[1] + config3 = lb_model_manager._load_balancing_configs[2] + + assert lb_model_manager.in_cooldown(config1) is True + assert lb_model_manager.in_cooldown(config2) is False + assert lb_model_manager.in_cooldown(config3) is False + + start_index = 0 + + def incr(key): + nonlocal start_index + start_index += 1 + return start_index + + with ( + patch.object(redis_client, "incr", side_effect=incr), + patch.object(redis_client, "set", return_value=None), + patch.object(redis_client, "expire", return_value=None), + ): + config = lb_model_manager.fetch_next() + assert config == config2 + + config = lb_model_manager.fetch_next() + assert config == config3 diff --git a/api/tests/unit_tests/core/test_provider_manager.py b/api/tests/unit_tests/core/test_provider_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..2f4214a5801de1e9326519d110b14108f652c79c --- /dev/null +++ b/api/tests/unit_tests/core/test_provider_manager.py @@ -0,0 +1,183 @@ +from core.entities.provider_entities import ModelSettings +from core.model_runtime.entities.model_entities import ModelType +from core.model_runtime.model_providers import model_provider_factory +from core.provider_manager import 
ProviderManager +from models.provider import LoadBalancingModelConfig, ProviderModelSetting + + +def test__to_model_settings(mocker): + # Get all provider entities + provider_entities = model_provider_factory.get_providers() + + provider_entity = None + for provider in provider_entities: + if provider.provider == "openai": + provider_entity = provider + + # Mocking the inputs + provider_model_settings = [ + ProviderModelSetting( + id="id", + tenant_id="tenant_id", + provider_name="openai", + model_name="gpt-4", + model_type="text-generation", + enabled=True, + load_balancing_enabled=True, + ) + ] + load_balancing_model_configs = [ + LoadBalancingModelConfig( + id="id1", + tenant_id="tenant_id", + provider_name="openai", + model_name="gpt-4", + model_type="text-generation", + name="__inherit__", + encrypted_config=None, + enabled=True, + ), + LoadBalancingModelConfig( + id="id2", + tenant_id="tenant_id", + provider_name="openai", + model_name="gpt-4", + model_type="text-generation", + name="first", + encrypted_config='{"openai_api_key": "fake_key"}', + enabled=True, + ), + ] + + mocker.patch( + "core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"} + ) + + provider_manager = ProviderManager() + + # Running the method + result = provider_manager._to_model_settings(provider_entity, provider_model_settings, load_balancing_model_configs) + + # Asserting that the result is as expected + assert len(result) == 1 + assert isinstance(result[0], ModelSettings) + assert result[0].model == "gpt-4" + assert result[0].model_type == ModelType.LLM + assert result[0].enabled is True + assert len(result[0].load_balancing_configs) == 2 + assert result[0].load_balancing_configs[0].name == "__inherit__" + assert result[0].load_balancing_configs[1].name == "first" + + +def test__to_model_settings_only_one_lb(mocker): + # Get all provider entities + provider_entities = model_provider_factory.get_providers() + + provider_entity = None + 
for provider in provider_entities: + if provider.provider == "openai": + provider_entity = provider + + # Mocking the inputs + provider_model_settings = [ + ProviderModelSetting( + id="id", + tenant_id="tenant_id", + provider_name="openai", + model_name="gpt-4", + model_type="text-generation", + enabled=True, + load_balancing_enabled=True, + ) + ] + load_balancing_model_configs = [ + LoadBalancingModelConfig( + id="id1", + tenant_id="tenant_id", + provider_name="openai", + model_name="gpt-4", + model_type="text-generation", + name="__inherit__", + encrypted_config=None, + enabled=True, + ) + ] + + mocker.patch( + "core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"} + ) + + provider_manager = ProviderManager() + + # Running the method + result = provider_manager._to_model_settings(provider_entity, provider_model_settings, load_balancing_model_configs) + + # Asserting that the result is as expected + assert len(result) == 1 + assert isinstance(result[0], ModelSettings) + assert result[0].model == "gpt-4" + assert result[0].model_type == ModelType.LLM + assert result[0].enabled is True + assert len(result[0].load_balancing_configs) == 0 + + +def test__to_model_settings_lb_disabled(mocker): + # Get all provider entities + provider_entities = model_provider_factory.get_providers() + + provider_entity = None + for provider in provider_entities: + if provider.provider == "openai": + provider_entity = provider + + # Mocking the inputs + provider_model_settings = [ + ProviderModelSetting( + id="id", + tenant_id="tenant_id", + provider_name="openai", + model_name="gpt-4", + model_type="text-generation", + enabled=True, + load_balancing_enabled=False, + ) + ] + load_balancing_model_configs = [ + LoadBalancingModelConfig( + id="id1", + tenant_id="tenant_id", + provider_name="openai", + model_name="gpt-4", + model_type="text-generation", + name="__inherit__", + encrypted_config=None, + enabled=True, + ), + 
LoadBalancingModelConfig( + id="id2", + tenant_id="tenant_id", + provider_name="openai", + model_name="gpt-4", + model_type="text-generation", + name="first", + encrypted_config='{"openai_api_key": "fake_key"}', + enabled=True, + ), + ] + + mocker.patch( + "core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"} + ) + + provider_manager = ProviderManager() + + # Running the method + result = provider_manager._to_model_settings(provider_entity, provider_model_settings, load_balancing_model_configs) + + # Asserting that the result is as expected + assert len(result) == 1 + assert isinstance(result[0], ModelSettings) + assert result[0].model == "gpt-4" + assert result[0].model_type == ModelType.LLM + assert result[0].enabled is True + assert len(result[0].load_balancing_configs) == 0 diff --git a/api/tests/unit_tests/core/tools/test_tool_parameter_type.py b/api/tests/unit_tests/core/tools/test_tool_parameter_type.py new file mode 100644 index 0000000000000000000000000000000000000000..8a41678267a209f5484a571ec46904b90b5545bf --- /dev/null +++ b/api/tests/unit_tests/core/tools/test_tool_parameter_type.py @@ -0,0 +1,49 @@ +from core.tools.entities.tool_entities import ToolParameter + + +def test_get_parameter_type(): + assert ToolParameter.ToolParameterType.STRING.as_normal_type() == "string" + assert ToolParameter.ToolParameterType.SELECT.as_normal_type() == "string" + assert ToolParameter.ToolParameterType.SECRET_INPUT.as_normal_type() == "string" + assert ToolParameter.ToolParameterType.BOOLEAN.as_normal_type() == "boolean" + assert ToolParameter.ToolParameterType.NUMBER.as_normal_type() == "number" + assert ToolParameter.ToolParameterType.FILE.as_normal_type() == "file" + assert ToolParameter.ToolParameterType.FILES.as_normal_type() == "files" + + +def test_cast_parameter_by_type(): + # string + assert ToolParameter.ToolParameterType.STRING.cast_value("test") == "test" + assert 
ToolParameter.ToolParameterType.STRING.cast_value(1) == "1" + assert ToolParameter.ToolParameterType.STRING.cast_value(1.0) == "1.0" + assert ToolParameter.ToolParameterType.STRING.cast_value(None) == "" + + # secret input + assert ToolParameter.ToolParameterType.SECRET_INPUT.cast_value("test") == "test" + assert ToolParameter.ToolParameterType.SECRET_INPUT.cast_value(1) == "1" + assert ToolParameter.ToolParameterType.SECRET_INPUT.cast_value(1.0) == "1.0" + assert ToolParameter.ToolParameterType.SECRET_INPUT.cast_value(None) == "" + + # select + assert ToolParameter.ToolParameterType.SELECT.cast_value("test") == "test" + assert ToolParameter.ToolParameterType.SELECT.cast_value(1) == "1" + assert ToolParameter.ToolParameterType.SELECT.cast_value(1.0) == "1.0" + assert ToolParameter.ToolParameterType.SELECT.cast_value(None) == "" + + # boolean + true_values = [True, "True", "true", "1", "YES", "Yes", "yes", "y", "something"] + for value in true_values: + assert ToolParameter.ToolParameterType.BOOLEAN.cast_value(value) is True + + false_values = [False, "False", "false", "0", "NO", "No", "no", "n", None, ""] + for value in false_values: + assert ToolParameter.ToolParameterType.BOOLEAN.cast_value(value) is False + + # number + assert ToolParameter.ToolParameterType.NUMBER.cast_value("1") == 1 + assert ToolParameter.ToolParameterType.NUMBER.cast_value("1.0") == 1.0 + assert ToolParameter.ToolParameterType.NUMBER.cast_value("-1.0") == -1.0 + assert ToolParameter.ToolParameterType.NUMBER.cast_value(1) == 1 + assert ToolParameter.ToolParameterType.NUMBER.cast_value(1.0) == 1.0 + assert ToolParameter.ToolParameterType.NUMBER.cast_value(-1.0) == -1.0 + assert ToolParameter.ToolParameterType.NUMBER.cast_value(None) is None diff --git a/api/tests/unit_tests/core/workflow/__init__.py b/api/tests/unit_tests/core/workflow/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git 
a/api/tests/unit_tests/core/workflow/graph_engine/__init__.py b/api/tests/unit_tests/core/workflow/graph_engine/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_graph.py b/api/tests/unit_tests/core/workflow/graph_engine/test_graph.py new file mode 100644 index 0000000000000000000000000000000000000000..13ba11016a098fd60d6d54eda78848b9b1742d7a --- /dev/null +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_graph.py @@ -0,0 +1,791 @@ +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.entities.run_condition import RunCondition +from core.workflow.utils.condition.entities import Condition + + +def test_init(): + graph_config = { + "edges": [ + { + "id": "llm-source-answer-target", + "source": "llm", + "target": "answer", + }, + { + "id": "start-source-qc-target", + "source": "start", + "target": "qc", + }, + { + "id": "qc-1-llm-target", + "source": "qc", + "sourceHandle": "1", + "target": "llm", + }, + { + "id": "qc-2-http-target", + "source": "qc", + "sourceHandle": "2", + "target": "http", + }, + { + "id": "http-source-answer2-target", + "source": "http", + "target": "answer2", + }, + ], + "nodes": [ + {"data": {"type": "start"}, "id": "start"}, + { + "data": { + "type": "llm", + }, + "id": "llm", + }, + { + "data": {"type": "answer", "title": "answer", "answer": "1"}, + "id": "answer", + }, + { + "data": {"type": "question-classifier"}, + "id": "qc", + }, + { + "data": { + "type": "http-request", + }, + "id": "http", + }, + { + "data": {"type": "answer", "title": "answer", "answer": "1"}, + "id": "answer2", + }, + ], + } + + graph = Graph.init(graph_config=graph_config) + + start_node_id = "start" + + assert graph.root_node_id == start_node_id + assert graph.edge_mapping.get(start_node_id)[0].target_node_id == "qc" + assert {"llm", "http"} == {node.target_node_id for node in 
graph.edge_mapping.get("qc")} + + +def test__init_iteration_graph(): + graph_config = { + "edges": [ + { + "id": "llm-answer", + "source": "llm", + "sourceHandle": "source", + "target": "answer", + }, + { + "id": "iteration-source-llm-target", + "source": "iteration", + "sourceHandle": "source", + "target": "llm", + }, + { + "id": "template-transform-in-iteration-source-llm-in-iteration-target", + "source": "template-transform-in-iteration", + "sourceHandle": "source", + "target": "llm-in-iteration", + }, + { + "id": "llm-in-iteration-source-answer-in-iteration-target", + "source": "llm-in-iteration", + "sourceHandle": "source", + "target": "answer-in-iteration", + }, + { + "id": "start-source-code-target", + "source": "start", + "sourceHandle": "source", + "target": "code", + }, + { + "id": "code-source-iteration-target", + "source": "code", + "sourceHandle": "source", + "target": "iteration", + }, + ], + "nodes": [ + { + "data": { + "type": "start", + }, + "id": "start", + }, + { + "data": { + "type": "llm", + }, + "id": "llm", + }, + { + "data": {"type": "answer", "title": "answer", "answer": "1"}, + "id": "answer", + }, + { + "data": {"type": "iteration"}, + "id": "iteration", + }, + { + "data": { + "type": "template-transform", + }, + "id": "template-transform-in-iteration", + "parentId": "iteration", + }, + { + "data": { + "type": "llm", + }, + "id": "llm-in-iteration", + "parentId": "iteration", + }, + { + "data": {"type": "answer", "title": "answer", "answer": "1"}, + "id": "answer-in-iteration", + "parentId": "iteration", + }, + { + "data": { + "type": "code", + }, + "id": "code", + }, + ], + } + + graph = Graph.init(graph_config=graph_config, root_node_id="template-transform-in-iteration") + graph.add_extra_edge( + source_node_id="answer-in-iteration", + target_node_id="template-transform-in-iteration", + run_condition=RunCondition( + type="condition", + conditions=[Condition(variable_selector=["iteration", "index"], comparison_operator="≤", value="5")], 
+ ), + ) + + # iteration: + # [template-transform-in-iteration -> llm-in-iteration -> answer-in-iteration] + + assert graph.root_node_id == "template-transform-in-iteration" + assert graph.edge_mapping.get("template-transform-in-iteration")[0].target_node_id == "llm-in-iteration" + assert graph.edge_mapping.get("llm-in-iteration")[0].target_node_id == "answer-in-iteration" + assert graph.edge_mapping.get("answer-in-iteration")[0].target_node_id == "template-transform-in-iteration" + + +def test_parallels_graph(): + graph_config = { + "edges": [ + { + "id": "start-source-llm1-target", + "source": "start", + "target": "llm1", + }, + { + "id": "start-source-llm2-target", + "source": "start", + "target": "llm2", + }, + { + "id": "start-source-llm3-target", + "source": "start", + "target": "llm3", + }, + { + "id": "llm1-source-answer-target", + "source": "llm1", + "target": "answer", + }, + { + "id": "llm2-source-answer-target", + "source": "llm2", + "target": "answer", + }, + { + "id": "llm3-source-answer-target", + "source": "llm3", + "target": "answer", + }, + ], + "nodes": [ + {"data": {"type": "start"}, "id": "start"}, + { + "data": { + "type": "llm", + }, + "id": "llm1", + }, + { + "data": { + "type": "llm", + }, + "id": "llm2", + }, + { + "data": { + "type": "llm", + }, + "id": "llm3", + }, + { + "data": {"type": "answer", "title": "answer", "answer": "1"}, + "id": "answer", + }, + ], + } + + graph = Graph.init(graph_config=graph_config) + + assert graph.root_node_id == "start" + for i in range(3): + start_edges = graph.edge_mapping.get("start") + assert start_edges is not None + assert start_edges[i].target_node_id == f"llm{i + 1}" + + llm_edges = graph.edge_mapping.get(f"llm{i + 1}") + assert llm_edges is not None + assert llm_edges[0].target_node_id == "answer" + + assert len(graph.parallel_mapping) == 1 + assert len(graph.node_parallel_mapping) == 3 + + for node_id in ["llm1", "llm2", "llm3"]: + assert node_id in graph.node_parallel_mapping + + +def 
test_parallels_graph2(): + graph_config = { + "edges": [ + { + "id": "start-source-llm1-target", + "source": "start", + "target": "llm1", + }, + { + "id": "start-source-llm2-target", + "source": "start", + "target": "llm2", + }, + { + "id": "start-source-llm3-target", + "source": "start", + "target": "llm3", + }, + { + "id": "llm1-source-answer-target", + "source": "llm1", + "target": "answer", + }, + { + "id": "llm2-source-answer-target", + "source": "llm2", + "target": "answer", + }, + ], + "nodes": [ + {"data": {"type": "start"}, "id": "start"}, + { + "data": { + "type": "llm", + }, + "id": "llm1", + }, + { + "data": { + "type": "llm", + }, + "id": "llm2", + }, + { + "data": { + "type": "llm", + }, + "id": "llm3", + }, + { + "data": {"type": "answer", "title": "answer", "answer": "1"}, + "id": "answer", + }, + ], + } + + graph = Graph.init(graph_config=graph_config) + + assert graph.root_node_id == "start" + for i in range(3): + assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i + 1}" + + if i < 2: + assert graph.edge_mapping.get(f"llm{i + 1}") is not None + assert graph.edge_mapping.get(f"llm{i + 1}")[0].target_node_id == "answer" + + assert len(graph.parallel_mapping) == 1 + assert len(graph.node_parallel_mapping) == 3 + + for node_id in ["llm1", "llm2", "llm3"]: + assert node_id in graph.node_parallel_mapping + + +def test_parallels_graph3(): + graph_config = { + "edges": [ + { + "id": "start-source-llm1-target", + "source": "start", + "target": "llm1", + }, + { + "id": "start-source-llm2-target", + "source": "start", + "target": "llm2", + }, + { + "id": "start-source-llm3-target", + "source": "start", + "target": "llm3", + }, + ], + "nodes": [ + {"data": {"type": "start"}, "id": "start"}, + { + "data": { + "type": "llm", + }, + "id": "llm1", + }, + { + "data": { + "type": "llm", + }, + "id": "llm2", + }, + { + "data": { + "type": "llm", + }, + "id": "llm3", + }, + { + "data": {"type": "answer", "title": "answer", "answer": "1"}, + "id": 
"answer", + }, + ], + } + + graph = Graph.init(graph_config=graph_config) + + assert graph.root_node_id == "start" + for i in range(3): + assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i + 1}" + + assert len(graph.parallel_mapping) == 1 + assert len(graph.node_parallel_mapping) == 3 + + for node_id in ["llm1", "llm2", "llm3"]: + assert node_id in graph.node_parallel_mapping + + +def test_parallels_graph4(): + graph_config = { + "edges": [ + { + "id": "start-source-llm1-target", + "source": "start", + "target": "llm1", + }, + { + "id": "start-source-llm2-target", + "source": "start", + "target": "llm2", + }, + { + "id": "start-source-llm3-target", + "source": "start", + "target": "llm3", + }, + { + "id": "llm1-source-answer-target", + "source": "llm1", + "target": "code1", + }, + { + "id": "llm2-source-answer-target", + "source": "llm2", + "target": "code2", + }, + { + "id": "llm3-source-code3-target", + "source": "llm3", + "target": "code3", + }, + { + "id": "code1-source-answer-target", + "source": "code1", + "target": "answer", + }, + { + "id": "code2-source-answer-target", + "source": "code2", + "target": "answer", + }, + { + "id": "code3-source-answer-target", + "source": "code3", + "target": "answer", + }, + ], + "nodes": [ + {"data": {"type": "start"}, "id": "start"}, + { + "data": { + "type": "llm", + }, + "id": "llm1", + }, + { + "data": { + "type": "code", + }, + "id": "code1", + }, + { + "data": { + "type": "llm", + }, + "id": "llm2", + }, + { + "data": { + "type": "code", + }, + "id": "code2", + }, + { + "data": { + "type": "llm", + }, + "id": "llm3", + }, + { + "data": { + "type": "code", + }, + "id": "code3", + }, + { + "data": {"type": "answer", "title": "answer", "answer": "1"}, + "id": "answer", + }, + ], + } + + graph = Graph.init(graph_config=graph_config) + + assert graph.root_node_id == "start" + for i in range(3): + assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i + 1}" + assert 
graph.edge_mapping.get(f"llm{i + 1}") is not None + assert graph.edge_mapping.get(f"llm{i + 1}")[0].target_node_id == f"code{i + 1}" + assert graph.edge_mapping.get(f"code{i + 1}") is not None + assert graph.edge_mapping.get(f"code{i + 1}")[0].target_node_id == "answer" + + assert len(graph.parallel_mapping) == 1 + assert len(graph.node_parallel_mapping) == 6 + + for node_id in ["llm1", "llm2", "llm3", "code1", "code2", "code3"]: + assert node_id in graph.node_parallel_mapping + + +def test_parallels_graph5(): + graph_config = { + "edges": [ + { + "id": "start-source-llm1-target", + "source": "start", + "target": "llm1", + }, + { + "id": "start-source-llm2-target", + "source": "start", + "target": "llm2", + }, + { + "id": "start-source-llm3-target", + "source": "start", + "target": "llm3", + }, + { + "id": "start-source-llm3-target", + "source": "start", + "target": "llm4", + }, + { + "id": "start-source-llm3-target", + "source": "start", + "target": "llm5", + }, + { + "id": "llm1-source-code1-target", + "source": "llm1", + "target": "code1", + }, + { + "id": "llm2-source-code1-target", + "source": "llm2", + "target": "code1", + }, + { + "id": "llm3-source-code2-target", + "source": "llm3", + "target": "code2", + }, + { + "id": "llm4-source-code2-target", + "source": "llm4", + "target": "code2", + }, + { + "id": "llm5-source-code3-target", + "source": "llm5", + "target": "code3", + }, + { + "id": "code1-source-answer-target", + "source": "code1", + "target": "answer", + }, + { + "id": "code2-source-answer-target", + "source": "code2", + "target": "answer", + }, + ], + "nodes": [ + {"data": {"type": "start"}, "id": "start"}, + { + "data": { + "type": "llm", + }, + "id": "llm1", + }, + { + "data": { + "type": "code", + }, + "id": "code1", + }, + { + "data": { + "type": "llm", + }, + "id": "llm2", + }, + { + "data": { + "type": "code", + }, + "id": "code2", + }, + { + "data": { + "type": "llm", + }, + "id": "llm3", + }, + { + "data": { + "type": "code", + }, + "id": 
"code3", + }, + { + "data": {"type": "answer", "title": "answer", "answer": "1"}, + "id": "answer", + }, + { + "data": { + "type": "llm", + }, + "id": "llm4", + }, + { + "data": { + "type": "llm", + }, + "id": "llm5", + }, + ], + } + + graph = Graph.init(graph_config=graph_config) + + assert graph.root_node_id == "start" + for i in range(5): + assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i + 1}" + + assert graph.edge_mapping.get("llm1") is not None + assert graph.edge_mapping.get("llm1")[0].target_node_id == "code1" + assert graph.edge_mapping.get("llm2") is not None + assert graph.edge_mapping.get("llm2")[0].target_node_id == "code1" + assert graph.edge_mapping.get("llm3") is not None + assert graph.edge_mapping.get("llm3")[0].target_node_id == "code2" + assert graph.edge_mapping.get("llm4") is not None + assert graph.edge_mapping.get("llm4")[0].target_node_id == "code2" + assert graph.edge_mapping.get("llm5") is not None + assert graph.edge_mapping.get("llm5")[0].target_node_id == "code3" + assert graph.edge_mapping.get("code1") is not None + assert graph.edge_mapping.get("code1")[0].target_node_id == "answer" + assert graph.edge_mapping.get("code2") is not None + assert graph.edge_mapping.get("code2")[0].target_node_id == "answer" + + assert len(graph.parallel_mapping) == 1 + assert len(graph.node_parallel_mapping) == 8 + + for node_id in ["llm1", "llm2", "llm3", "llm4", "llm5", "code1", "code2", "code3"]: + assert node_id in graph.node_parallel_mapping + + +def test_parallels_graph6(): + graph_config = { + "edges": [ + { + "id": "start-source-llm1-target", + "source": "start", + "target": "llm1", + }, + { + "id": "start-source-llm2-target", + "source": "start", + "target": "llm2", + }, + { + "id": "start-source-llm3-target", + "source": "start", + "target": "llm3", + }, + { + "id": "llm1-source-code1-target", + "source": "llm1", + "target": "code1", + }, + { + "id": "llm1-source-code2-target", + "source": "llm1", + "target": "code2", + }, + 
{ + "id": "llm2-source-code3-target", + "source": "llm2", + "target": "code3", + }, + { + "id": "code1-source-answer-target", + "source": "code1", + "target": "answer", + }, + { + "id": "code2-source-answer-target", + "source": "code2", + "target": "answer", + }, + { + "id": "code3-source-answer-target", + "source": "code3", + "target": "answer", + }, + { + "id": "llm3-source-answer-target", + "source": "llm3", + "target": "answer", + }, + ], + "nodes": [ + {"data": {"type": "start"}, "id": "start"}, + { + "data": { + "type": "llm", + }, + "id": "llm1", + }, + { + "data": { + "type": "code", + }, + "id": "code1", + }, + { + "data": { + "type": "llm", + }, + "id": "llm2", + }, + { + "data": { + "type": "code", + }, + "id": "code2", + }, + { + "data": { + "type": "llm", + }, + "id": "llm3", + }, + { + "data": { + "type": "code", + }, + "id": "code3", + }, + { + "data": {"type": "answer", "title": "answer", "answer": "1"}, + "id": "answer", + }, + ], + } + + graph = Graph.init(graph_config=graph_config) + + assert graph.root_node_id == "start" + for i in range(3): + assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i + 1}" + + assert graph.edge_mapping.get("llm1") is not None + assert graph.edge_mapping.get("llm1")[0].target_node_id == "code1" + assert graph.edge_mapping.get("llm1") is not None + assert graph.edge_mapping.get("llm1")[1].target_node_id == "code2" + assert graph.edge_mapping.get("llm2") is not None + assert graph.edge_mapping.get("llm2")[0].target_node_id == "code3" + assert graph.edge_mapping.get("code1") is not None + assert graph.edge_mapping.get("code1")[0].target_node_id == "answer" + assert graph.edge_mapping.get("code2") is not None + assert graph.edge_mapping.get("code2")[0].target_node_id == "answer" + assert graph.edge_mapping.get("code3") is not None + assert graph.edge_mapping.get("code3")[0].target_node_id == "answer" + + assert len(graph.parallel_mapping) == 2 + assert len(graph.node_parallel_mapping) == 6 + + for node_id 
in ["llm1", "llm2", "llm3", "code1", "code2", "code3"]: + assert node_id in graph.node_parallel_mapping + + parent_parallel = None + child_parallel = None + for p_id, parallel in graph.parallel_mapping.items(): + if parallel.parent_parallel_id is None: + parent_parallel = parallel + else: + child_parallel = parallel + + for node_id in ["llm1", "llm2", "llm3", "code3"]: + assert graph.node_parallel_mapping[node_id] == parent_parallel.id + + for node_id in ["code1", "code2"]: + assert graph.node_parallel_mapping[node_id] == child_parallel.id diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py new file mode 100644 index 0000000000000000000000000000000000000000..b7d8f69e8c52ee656b17142332c5851905a74142 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py @@ -0,0 +1,504 @@ +from unittest.mock import patch + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.enums import SystemVariableKey +from core.workflow.graph_engine.entities.event import ( + BaseNodeEvent, + GraphRunFailedEvent, + GraphRunStartedEvent, + GraphRunSucceededEvent, + NodeRunFailedEvent, + NodeRunStartedEvent, + NodeRunStreamChunkEvent, + NodeRunSucceededEvent, +) +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState +from core.workflow.graph_engine.graph_engine import GraphEngine +from core.workflow.nodes.event import RunCompletedEvent, RunStreamChunkEvent +from core.workflow.nodes.llm.node import LLMNode +from models.enums import UserFrom +from models.workflow import WorkflowNodeExecutionStatus, WorkflowType + + +@patch("extensions.ext_database.db.session.remove") 
+@patch("extensions.ext_database.db.session.close") +def test_run_parallel_in_workflow(mock_close, mock_remove): + graph_config = { + "edges": [ + { + "id": "1", + "source": "start", + "target": "llm1", + }, + { + "id": "2", + "source": "llm1", + "target": "llm2", + }, + { + "id": "3", + "source": "llm1", + "target": "llm3", + }, + { + "id": "4", + "source": "llm2", + "target": "end1", + }, + { + "id": "5", + "source": "llm3", + "target": "end2", + }, + ], + "nodes": [ + { + "data": { + "type": "start", + "title": "start", + "variables": [ + { + "label": "query", + "max_length": 48, + "options": [], + "required": True, + "type": "text-input", + "variable": "query", + } + ], + }, + "id": "start", + }, + { + "data": { + "type": "llm", + "title": "llm1", + "context": {"enabled": False, "variable_selector": []}, + "model": { + "completion_params": {"temperature": 0.7}, + "mode": "chat", + "name": "gpt-4o", + "provider": "openai", + }, + "prompt_template": [ + {"role": "system", "text": "say hi"}, + {"role": "user", "text": "{{#start.query#}}"}, + ], + "vision": {"configs": {"detail": "high", "variable_selector": []}, "enabled": False}, + }, + "id": "llm1", + }, + { + "data": { + "type": "llm", + "title": "llm2", + "context": {"enabled": False, "variable_selector": []}, + "model": { + "completion_params": {"temperature": 0.7}, + "mode": "chat", + "name": "gpt-4o", + "provider": "openai", + }, + "prompt_template": [ + {"role": "system", "text": "say bye"}, + {"role": "user", "text": "{{#start.query#}}"}, + ], + "vision": {"configs": {"detail": "high", "variable_selector": []}, "enabled": False}, + }, + "id": "llm2", + }, + { + "data": { + "type": "llm", + "title": "llm3", + "context": {"enabled": False, "variable_selector": []}, + "model": { + "completion_params": {"temperature": 0.7}, + "mode": "chat", + "name": "gpt-4o", + "provider": "openai", + }, + "prompt_template": [ + {"role": "system", "text": "say good morning"}, + {"role": "user", "text": "{{#start.query#}}"}, 
+ ], + "vision": {"configs": {"detail": "high", "variable_selector": []}, "enabled": False}, + }, + "id": "llm3", + }, + { + "data": { + "type": "end", + "title": "end1", + "outputs": [ + {"value_selector": ["llm2", "text"], "variable": "result2"}, + {"value_selector": ["start", "query"], "variable": "query"}, + ], + }, + "id": "end1", + }, + { + "data": { + "type": "end", + "title": "end2", + "outputs": [ + {"value_selector": ["llm1", "text"], "variable": "result1"}, + {"value_selector": ["llm3", "text"], "variable": "result3"}, + ], + }, + "id": "end2", + }, + ], + } + + graph = Graph.init(graph_config=graph_config) + + variable_pool = VariablePool( + system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, user_inputs={"query": "hi"} + ) + + graph_engine = GraphEngine( + tenant_id="111", + app_id="222", + workflow_type=WorkflowType.WORKFLOW, + workflow_id="333", + graph_config=graph_config, + user_id="444", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.WEB_APP, + call_depth=0, + graph=graph, + variable_pool=variable_pool, + max_execution_steps=500, + max_execution_time=1200, + ) + + def llm_generator(self): + contents = ["hi", "bye", "good morning"] + + yield RunStreamChunkEvent( + chunk_content=contents[int(self.node_id[-1]) - 1], from_variable_selector=[self.node_id, "text"] + ) + + yield RunCompletedEvent( + run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs={}, + process_data={}, + outputs={}, + metadata={ + NodeRunMetadataKey.TOTAL_TOKENS: 1, + NodeRunMetadataKey.TOTAL_PRICE: 1, + NodeRunMetadataKey.CURRENCY: "USD", + }, + ) + ) + + # print("") + + with patch.object(LLMNode, "_run", new=llm_generator): + items = [] + generator = graph_engine.run() + for item in generator: + # print(type(item), item) + items.append(item) + if isinstance(item, NodeRunSucceededEvent): + assert item.route_node_state.status == RouteNodeState.Status.SUCCESS + + assert not isinstance(item, NodeRunFailedEvent) + 
assert not isinstance(item, GraphRunFailedEvent) + + if isinstance(item, BaseNodeEvent) and item.route_node_state.node_id in {"llm2", "llm3", "end1", "end2"}: + assert item.parallel_id is not None + + assert len(items) == 18 + assert isinstance(items[0], GraphRunStartedEvent) + assert isinstance(items[1], NodeRunStartedEvent) + assert items[1].route_node_state.node_id == "start" + assert isinstance(items[2], NodeRunSucceededEvent) + assert items[2].route_node_state.node_id == "start" + + +@patch("extensions.ext_database.db.session.remove") +@patch("extensions.ext_database.db.session.close") +def test_run_parallel_in_chatflow(mock_close, mock_remove): + graph_config = { + "edges": [ + { + "id": "1", + "source": "start", + "target": "answer1", + }, + { + "id": "2", + "source": "answer1", + "target": "answer2", + }, + { + "id": "3", + "source": "answer1", + "target": "answer3", + }, + { + "id": "4", + "source": "answer2", + "target": "answer4", + }, + { + "id": "5", + "source": "answer3", + "target": "answer5", + }, + ], + "nodes": [ + {"data": {"type": "start", "title": "start"}, "id": "start"}, + {"data": {"type": "answer", "title": "answer1", "answer": "1"}, "id": "answer1"}, + { + "data": {"type": "answer", "title": "answer2", "answer": "2"}, + "id": "answer2", + }, + { + "data": {"type": "answer", "title": "answer3", "answer": "3"}, + "id": "answer3", + }, + { + "data": {"type": "answer", "title": "answer4", "answer": "4"}, + "id": "answer4", + }, + { + "data": {"type": "answer", "title": "answer5", "answer": "5"}, + "id": "answer5", + }, + ], + } + + graph = Graph.init(graph_config=graph_config) + + variable_pool = VariablePool( + system_variables={ + SystemVariableKey.QUERY: "what's the weather in SF", + SystemVariableKey.FILES: [], + SystemVariableKey.CONVERSATION_ID: "abababa", + SystemVariableKey.USER_ID: "aaa", + }, + user_inputs={}, + ) + + graph_engine = GraphEngine( + tenant_id="111", + app_id="222", + workflow_type=WorkflowType.CHAT, + 
workflow_id="333", + graph_config=graph_config, + user_id="444", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.WEB_APP, + call_depth=0, + graph=graph, + variable_pool=variable_pool, + max_execution_steps=500, + max_execution_time=1200, + ) + + # print("") + + items = [] + generator = graph_engine.run() + for item in generator: + # print(type(item), item) + items.append(item) + if isinstance(item, NodeRunSucceededEvent): + assert item.route_node_state.status == RouteNodeState.Status.SUCCESS + + assert not isinstance(item, NodeRunFailedEvent) + assert not isinstance(item, GraphRunFailedEvent) + + if isinstance(item, BaseNodeEvent) and item.route_node_state.node_id in { + "answer2", + "answer3", + "answer4", + "answer5", + }: + assert item.parallel_id is not None + + assert len(items) == 23 + assert isinstance(items[0], GraphRunStartedEvent) + assert isinstance(items[1], NodeRunStartedEvent) + assert items[1].route_node_state.node_id == "start" + assert isinstance(items[2], NodeRunSucceededEvent) + assert items[2].route_node_state.node_id == "start" + + +@patch("extensions.ext_database.db.session.remove") +@patch("extensions.ext_database.db.session.close") +def test_run_branch(mock_close, mock_remove): + graph_config = { + "edges": [ + { + "id": "1", + "source": "start", + "target": "if-else-1", + }, + { + "id": "2", + "source": "if-else-1", + "sourceHandle": "true", + "target": "answer-1", + }, + { + "id": "3", + "source": "if-else-1", + "sourceHandle": "false", + "target": "if-else-2", + }, + { + "id": "4", + "source": "if-else-2", + "sourceHandle": "true", + "target": "answer-2", + }, + { + "id": "5", + "source": "if-else-2", + "sourceHandle": "false", + "target": "answer-3", + }, + ], + "nodes": [ + { + "data": { + "title": "Start", + "type": "start", + "variables": [ + { + "label": "uid", + "max_length": 48, + "options": [], + "required": True, + "type": "text-input", + "variable": "uid", + } + ], + }, + "id": "start", + }, + { + "data": {"answer": "1 
{{#start.uid#}}", "title": "Answer", "type": "answer", "variables": []}, + "id": "answer-1", + }, + { + "data": { + "cases": [ + { + "case_id": "true", + "conditions": [ + { + "comparison_operator": "contains", + "id": "b0f02473-08b6-4a81-af91-15345dcb2ec8", + "value": "hi", + "varType": "string", + "variable_selector": ["sys", "query"], + } + ], + "id": "true", + "logical_operator": "and", + } + ], + "desc": "", + "title": "IF/ELSE", + "type": "if-else", + }, + "id": "if-else-1", + }, + { + "data": { + "cases": [ + { + "case_id": "true", + "conditions": [ + { + "comparison_operator": "contains", + "id": "ae895199-5608-433b-b5f0-0997ae1431e4", + "value": "takatost", + "varType": "string", + "variable_selector": ["sys", "query"], + } + ], + "id": "true", + "logical_operator": "and", + } + ], + "title": "IF/ELSE 2", + "type": "if-else", + }, + "id": "if-else-2", + }, + { + "data": { + "answer": "2", + "title": "Answer 2", + "type": "answer", + }, + "id": "answer-2", + }, + { + "data": { + "answer": "3", + "title": "Answer 3", + "type": "answer", + }, + "id": "answer-3", + }, + ], + } + + graph = Graph.init(graph_config=graph_config) + + variable_pool = VariablePool( + system_variables={ + SystemVariableKey.QUERY: "hi", + SystemVariableKey.FILES: [], + SystemVariableKey.CONVERSATION_ID: "abababa", + SystemVariableKey.USER_ID: "aaa", + }, + user_inputs={"uid": "takato"}, + ) + + graph_engine = GraphEngine( + tenant_id="111", + app_id="222", + workflow_type=WorkflowType.CHAT, + workflow_id="333", + graph_config=graph_config, + user_id="444", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.WEB_APP, + call_depth=0, + graph=graph, + variable_pool=variable_pool, + max_execution_steps=500, + max_execution_time=1200, + ) + + # print("") + + items = [] + generator = graph_engine.run() + for item in generator: + items.append(item) + + assert len(items) == 10 + assert items[3].route_node_state.node_id == "if-else-1" + assert items[4].route_node_state.node_id == 
"if-else-1" + assert isinstance(items[5], NodeRunStreamChunkEvent) + assert isinstance(items[6], NodeRunStreamChunkEvent) + assert items[6].chunk_content == "takato" + assert items[7].route_node_state.node_id == "answer-1" + assert items[8].route_node_state.node_id == "answer-1" + assert items[8].route_node_state.node_run_result.outputs["answer"] == "1 takato" + assert isinstance(items[9], GraphRunSucceededEvent) + + # print(graph_engine.graph_runtime_state.model_dump_json(indent=2)) diff --git a/api/tests/unit_tests/core/workflow/nodes/__init__.py b/api/tests/unit_tests/core/workflow/nodes/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/tests/unit_tests/core/workflow/nodes/answer/__init__.py b/api/tests/unit_tests/core/workflow/nodes/answer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py new file mode 100644 index 0000000000000000000000000000000000000000..0369f3fa4447fe222052eb941769eb59519c0411 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py @@ -0,0 +1,82 @@ +import time +import uuid +from unittest.mock import MagicMock + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.enums import SystemVariableKey +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams +from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState +from core.workflow.nodes.answer.answer_node import AnswerNode +from extensions.ext_database import db +from models.enums import UserFrom +from models.workflow import WorkflowNodeExecutionStatus, WorkflowType + 
def test_execute_answer():
    """Check that AnswerNode renders its answer template from the variable pool.

    The template references ``start.weather`` and ``llm.text`` through
    ``{{#...#}}`` selectors, both of which are seeded into the pool below.
    The plain ``{{img}}`` token is not a selector and is expected to be
    emitted verbatim.
    """
    # Imported locally so this test does not depend on the module-level
    # import list; only ``MagicMock`` is imported at the top of the file.
    from unittest.mock import patch

    graph_config = {
        "edges": [
            {
                "id": "start-source-llm-target",
                "source": "start",
                "target": "llm",
            },
        ],
        "nodes": [
            {"data": {"type": "start"}, "id": "start"},
            {
                "data": {
                    "type": "llm",
                },
                "id": "llm",
            },
        ],
    }

    graph = Graph.init(graph_config=graph_config)

    init_params = GraphInitParams(
        tenant_id="1",
        app_id="1",
        workflow_type=WorkflowType.WORKFLOW,
        workflow_id="1",
        graph_config=graph_config,
        user_id="1",
        user_from=UserFrom.ACCOUNT,
        invoke_from=InvokeFrom.DEBUGGER,
        call_depth=0,
    )

    # Seed the variable pool with the values the answer template refers to.
    pool = VariablePool(
        system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"},
        user_inputs={},
        environment_variables=[],
    )
    pool.add(["start", "weather"], "sunny")
    pool.add(["llm", "text"], "You are a helpful AI.")

    node = AnswerNode(
        id=str(uuid.uuid4()),
        graph_init_params=init_params,
        graph=graph,
        graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
        config={
            "id": "answer",
            "data": {
                "title": "123",
                "type": "answer",
                "answer": "Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.",
            },
        },
    )

    # Replace db.session.close() only for the duration of the node run.
    # Assigning the mock directly onto db.session would leave it in place for
    # every test executed afterwards; patch.object restores the attribute on
    # exit, even if _run() raises.
    with patch.object(db.session, "close", MagicMock()):
        result = node._run()

    assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
    assert result.outputs["answer"] == "Today's weather is sunny\nYou are a helpful AI.\n{{img}}\nFin."
diff --git a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer_stream_generate_router.py b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer_stream_generate_router.py new file mode 100644 index 0000000000000000000000000000000000000000..bce87536d8e995fd36dfb4b0d235607d58a77689 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer_stream_generate_router.py @@ -0,0 +1,109 @@ +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.nodes.answer.answer_stream_generate_router import AnswerStreamGeneratorRouter + + +def test_init(): + graph_config = { + "edges": [ + { + "id": "start-source-llm1-target", + "source": "start", + "target": "llm1", + }, + { + "id": "start-source-llm2-target", + "source": "start", + "target": "llm2", + }, + { + "id": "start-source-llm3-target", + "source": "start", + "target": "llm3", + }, + { + "id": "llm3-source-llm4-target", + "source": "llm3", + "target": "llm4", + }, + { + "id": "llm3-source-llm5-target", + "source": "llm3", + "target": "llm5", + }, + { + "id": "llm4-source-answer2-target", + "source": "llm4", + "target": "answer2", + }, + { + "id": "llm5-source-answer-target", + "source": "llm5", + "target": "answer", + }, + { + "id": "answer2-source-answer-target", + "source": "answer2", + "target": "answer", + }, + { + "id": "llm2-source-answer-target", + "source": "llm2", + "target": "answer", + }, + { + "id": "llm1-source-answer-target", + "source": "llm1", + "target": "answer", + }, + ], + "nodes": [ + {"data": {"type": "start"}, "id": "start"}, + { + "data": { + "type": "llm", + }, + "id": "llm1", + }, + { + "data": { + "type": "llm", + }, + "id": "llm2", + }, + { + "data": { + "type": "llm", + }, + "id": "llm3", + }, + { + "data": { + "type": "llm", + }, + "id": "llm4", + }, + { + "data": { + "type": "llm", + }, + "id": "llm5", + }, + { + "data": {"type": "answer", "title": "answer", "answer": "1{{#llm2.text#}}2"}, + "id": "answer", + }, + { + "data": {"type": 
"answer", "title": "answer2", "answer": "1{{#llm3.text#}}2"}, + "id": "answer2", + }, + ], + } + + graph = Graph.init(graph_config=graph_config) + + answer_stream_generate_route = AnswerStreamGeneratorRouter.init( + node_id_config_mapping=graph.node_id_config_mapping, reverse_edge_mapping=graph.reverse_edge_mapping + ) + + assert answer_stream_generate_route.answer_dependencies["answer"] == ["answer2"] + assert answer_stream_generate_route.answer_dependencies["answer2"] == [] diff --git a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer_stream_processor.py b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer_stream_processor.py new file mode 100644 index 0000000000000000000000000000000000000000..c3a3818655520ac41245a0ef99dbeda43585fb55 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer_stream_processor.py @@ -0,0 +1,216 @@ +import uuid +from collections.abc import Generator +from datetime import UTC, datetime + +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.enums import SystemVariableKey +from core.workflow.graph_engine.entities.event import ( + GraphEngineEvent, + NodeRunStartedEvent, + NodeRunStreamChunkEvent, + NodeRunSucceededEvent, +) +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState +from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor +from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.start.entities import StartNodeData + + +def _recursive_process(graph: Graph, next_node_id: str) -> Generator[GraphEngineEvent, None, None]: + if next_node_id == "start": + yield from _publish_events(graph, next_node_id) + + for edge in graph.edge_mapping.get(next_node_id, []): + yield from _publish_events(graph, edge.target_node_id) + + for edge in graph.edge_mapping.get(next_node_id, []): + yield from _recursive_process(graph, 
edge.target_node_id) + + +def _publish_events(graph: Graph, next_node_id: str) -> Generator[GraphEngineEvent, None, None]: + route_node_state = RouteNodeState(node_id=next_node_id, start_at=datetime.now(UTC).replace(tzinfo=None)) + + parallel_id = graph.node_parallel_mapping.get(next_node_id) + parallel_start_node_id = None + if parallel_id: + parallel = graph.parallel_mapping.get(parallel_id) + parallel_start_node_id = parallel.start_from_node_id if parallel else None + + node_execution_id = str(uuid.uuid4()) + node_config = graph.node_id_config_mapping[next_node_id] + node_type = NodeType(node_config.get("data", {}).get("type")) + mock_node_data = StartNodeData(**{"title": "demo", "variables": []}) + + yield NodeRunStartedEvent( + id=node_execution_id, + node_id=next_node_id, + node_type=node_type, + node_data=mock_node_data, + route_node_state=route_node_state, + parallel_id=graph.node_parallel_mapping.get(next_node_id), + parallel_start_node_id=parallel_start_node_id, + ) + + if "llm" in next_node_id: + length = int(next_node_id[-1]) + for i in range(0, length): + yield NodeRunStreamChunkEvent( + id=node_execution_id, + node_id=next_node_id, + node_type=node_type, + node_data=mock_node_data, + chunk_content=str(i), + route_node_state=route_node_state, + from_variable_selector=[next_node_id, "text"], + parallel_id=parallel_id, + parallel_start_node_id=parallel_start_node_id, + ) + + route_node_state.status = RouteNodeState.Status.SUCCESS + route_node_state.finished_at = datetime.now(UTC).replace(tzinfo=None) + yield NodeRunSucceededEvent( + id=node_execution_id, + node_id=next_node_id, + node_type=node_type, + node_data=mock_node_data, + route_node_state=route_node_state, + parallel_id=parallel_id, + parallel_start_node_id=parallel_start_node_id, + ) + + +def test_process(): + graph_config = { + "edges": [ + { + "id": "start-source-llm1-target", + "source": "start", + "target": "llm1", + }, + { + "id": "start-source-llm2-target", + "source": "start", + "target": 
"llm2", + }, + { + "id": "start-source-llm3-target", + "source": "start", + "target": "llm3", + }, + { + "id": "llm3-source-llm4-target", + "source": "llm3", + "target": "llm4", + }, + { + "id": "llm3-source-llm5-target", + "source": "llm3", + "target": "llm5", + }, + { + "id": "llm4-source-answer2-target", + "source": "llm4", + "target": "answer2", + }, + { + "id": "llm5-source-answer-target", + "source": "llm5", + "target": "answer", + }, + { + "id": "answer2-source-answer-target", + "source": "answer2", + "target": "answer", + }, + { + "id": "llm2-source-answer-target", + "source": "llm2", + "target": "answer", + }, + { + "id": "llm1-source-answer-target", + "source": "llm1", + "target": "answer", + }, + ], + "nodes": [ + {"data": {"type": "start"}, "id": "start"}, + { + "data": { + "type": "llm", + }, + "id": "llm1", + }, + { + "data": { + "type": "llm", + }, + "id": "llm2", + }, + { + "data": { + "type": "llm", + }, + "id": "llm3", + }, + { + "data": { + "type": "llm", + }, + "id": "llm4", + }, + { + "data": { + "type": "llm", + }, + "id": "llm5", + }, + { + "data": {"type": "answer", "title": "answer", "answer": "a{{#llm2.text#}}b"}, + "id": "answer", + }, + { + "data": {"type": "answer", "title": "answer2", "answer": "c{{#llm3.text#}}d"}, + "id": "answer2", + }, + ], + } + + graph = Graph.init(graph_config=graph_config) + + variable_pool = VariablePool( + system_variables={ + SystemVariableKey.QUERY: "what's the weather in SF", + SystemVariableKey.FILES: [], + SystemVariableKey.CONVERSATION_ID: "abababa", + SystemVariableKey.USER_ID: "aaa", + }, + user_inputs={}, + ) + + answer_stream_processor = AnswerStreamProcessor(graph=graph, variable_pool=variable_pool) + + def graph_generator() -> Generator[GraphEngineEvent, None, None]: + # print("") + for event in _recursive_process(graph, "start"): + # print("[ORIGIN]", event.__class__.__name__ + ":", event.route_node_state.node_id, + # " " + (event.chunk_content if isinstance(event, NodeRunStreamChunkEvent) else 
"")) + if isinstance(event, NodeRunSucceededEvent): + if "llm" in event.route_node_state.node_id: + variable_pool.add( + [event.route_node_state.node_id, "text"], + "".join(str(i) for i in range(0, int(event.route_node_state.node_id[-1]))), + ) + yield event + + result_generator = answer_stream_processor.process(graph_generator()) + stream_contents = "" + for event in result_generator: + # print("[ANSWER]", event.__class__.__name__ + ":", event.route_node_state.node_id, + # " " + (event.chunk_content if isinstance(event, NodeRunStreamChunkEvent) else "")) + if isinstance(event, NodeRunStreamChunkEvent): + stream_contents += event.chunk_content + pass + + assert stream_contents == "c012da01b" diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_entities.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_entities.py new file mode 100644 index 0000000000000000000000000000000000000000..0f6b7e4ab6c1fd73c3e03aa39ef1b3a6a3f62aee --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/http_request/test_entities.py @@ -0,0 +1,140 @@ +from unittest.mock import Mock, PropertyMock, patch + +import httpx +import pytest + +from core.workflow.nodes.http_request.entities import Response + + +@pytest.fixture +def mock_response(): + response = Mock(spec=httpx.Response) + response.headers = {} + return response + + +def test_is_file_with_attachment_disposition(mock_response): + """Test is_file when content-disposition header contains 'attachment'""" + mock_response.headers = {"content-disposition": "attachment; filename=test.pdf", "content-type": "application/pdf"} + response = Response(mock_response) + assert response.is_file + + +def test_is_file_with_filename_disposition(mock_response): + """Test is_file when content-disposition header contains filename parameter""" + mock_response.headers = {"content-disposition": "inline; filename=test.pdf", "content-type": "application/pdf"} + response = Response(mock_response) + assert response.is_file + + 
+@pytest.mark.parametrize("content_type", ["application/pdf", "image/jpeg", "audio/mp3", "video/mp4"]) +def test_is_file_with_file_content_types(mock_response, content_type): + """Test is_file with various file content types""" + mock_response.headers = {"content-type": content_type} + # Mock binary content + type(mock_response).content = PropertyMock(return_value=bytes([0x00, 0xFF] * 512)) + response = Response(mock_response) + assert response.is_file, f"Content type {content_type} should be identified as a file" + + +@pytest.mark.parametrize( + "content_type", + [ + "application/json", + "application/xml", + "application/javascript", + "application/x-www-form-urlencoded", + "application/yaml", + "application/graphql", + ], +) +def test_text_based_application_types(mock_response, content_type): + """Test common text-based application types are not identified as files""" + mock_response.headers = {"content-type": content_type} + response = Response(mock_response) + assert not response.is_file, f"Content type {content_type} should not be identified as a file" + + +@pytest.mark.parametrize( + ("content", "content_type"), + [ + (b'{"key": "value"}', "application/octet-stream"), + (b"[1, 2, 3]", "application/unknown"), + (b"function test() {}", "application/x-unknown"), + (b"test", "application/binary"), + (b"var x = 1;", "application/data"), + ], +) +def test_content_based_detection(mock_response, content, content_type): + """Test content-based detection for text-like content""" + mock_response.headers = {"content-type": content_type} + type(mock_response).content = PropertyMock(return_value=content) + response = Response(mock_response) + assert not response.is_file, f"Content {content} with type {content_type} should not be identified as a file" + + +@pytest.mark.parametrize( + ("content", "content_type"), + [ + (bytes([0x00, 0xFF] * 512), "application/octet-stream"), + (bytes([0x89, 0x50, 0x4E, 0x47]), "application/unknown"), # PNG magic numbers + (bytes([0xFF, 
0xD8, 0xFF]), "application/binary"), # JPEG magic numbers + ], +) +def test_binary_content_detection(mock_response, content, content_type): + """Test content-based detection for binary content""" + mock_response.headers = {"content-type": content_type} + type(mock_response).content = PropertyMock(return_value=content) + response = Response(mock_response) + assert response.is_file, f"Binary content with type {content_type} should be identified as a file" + + +@pytest.mark.parametrize( + ("content_type", "expected_main_type"), + [ + ("x-world/x-vrml", "model"), # VRML 3D model + ("font/ttf", "application"), # TrueType font + ("text/csv", "text"), # CSV text file + ("unknown/xyz", None), # Unknown type + ], +) +def test_mimetype_based_detection(mock_response, content_type, expected_main_type): + """Test detection using mimetypes.guess_type for non-application content types""" + mock_response.headers = {"content-type": content_type} + type(mock_response).content = PropertyMock(return_value=bytes([0x00])) # Dummy content + + with patch("core.workflow.nodes.http_request.entities.mimetypes.guess_type") as mock_guess_type: + # Mock the return value based on expected_main_type + if expected_main_type: + mock_guess_type.return_value = (f"{expected_main_type}/subtype", None) + else: + mock_guess_type.return_value = (None, None) + + response = Response(mock_response) + + # Check if the result matches our expectation + if expected_main_type in ("application", "image", "audio", "video"): + assert response.is_file, f"Content type {content_type} should be identified as a file" + else: + assert not response.is_file, f"Content type {content_type} should not be identified as a file" + + # Verify that guess_type was called + mock_guess_type.assert_called_once() + + +def test_is_file_with_inline_disposition(mock_response): + """Test is_file when content-disposition is 'inline'""" + mock_response.headers = {"content-disposition": "inline", "content-type": "application/pdf"} + # Mock 
binary content + type(mock_response).content = PropertyMock(return_value=bytes([0x00, 0xFF] * 512)) + response = Response(mock_response) + assert response.is_file + + +def test_is_file_with_no_content_disposition(mock_response): + """Test is_file when no content-disposition header is present""" + mock_response.headers = {"content-type": "application/pdf"} + # Mock binary content + type(mock_response).content = PropertyMock(return_value=bytes([0x00, 0xFF] * 512)) + response = Response(mock_response) + assert response.is_file diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py new file mode 100644 index 0000000000000000000000000000000000000000..58b910e17bae4a10f622dab870e8ef787bcbf54e --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py @@ -0,0 +1,336 @@ +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.nodes.http_request import ( + BodyData, + HttpRequestNodeAuthorization, + HttpRequestNodeBody, + HttpRequestNodeData, +) +from core.workflow.nodes.http_request.entities import HttpRequestNodeTimeout +from core.workflow.nodes.http_request.executor import Executor + + +def test_executor_with_json_body_and_number_variable(): + # Prepare the variable pool + variable_pool = VariablePool( + system_variables={}, + user_inputs={}, + ) + variable_pool.add(["pre_node_id", "number"], 42) + + # Prepare the node data + node_data = HttpRequestNodeData( + title="Test JSON Body with Number Variable", + method="post", + url="https://api.example.com/data", + authorization=HttpRequestNodeAuthorization(type="no-auth"), + headers="Content-Type: application/json", + params="", + body=HttpRequestNodeBody( + type="json", + data=[ + BodyData( + key="", + type="text", + value='{"number": {{#pre_node_id.number#}}}', + ) + ], + ), + ) + + # Initialize the Executor + executor = 
Executor( + node_data=node_data, + timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30), + variable_pool=variable_pool, + ) + + # Check the executor's data + assert executor.method == "post" + assert executor.url == "https://api.example.com/data" + assert executor.headers == {"Content-Type": "application/json"} + assert executor.params == [] + assert executor.json == {"number": 42} + assert executor.data is None + assert executor.files is None + assert executor.content is None + + # Check the raw request (to_log method) + raw_request = executor.to_log() + assert "POST /data HTTP/1.1" in raw_request + assert "Host: api.example.com" in raw_request + assert "Content-Type: application/json" in raw_request + assert '{"number": 42}' in raw_request + + +def test_executor_with_json_body_and_object_variable(): + # Prepare the variable pool + variable_pool = VariablePool( + system_variables={}, + user_inputs={}, + ) + variable_pool.add(["pre_node_id", "object"], {"name": "John Doe", "age": 30, "email": "john@example.com"}) + + # Prepare the node data + node_data = HttpRequestNodeData( + title="Test JSON Body with Object Variable", + method="post", + url="https://api.example.com/data", + authorization=HttpRequestNodeAuthorization(type="no-auth"), + headers="Content-Type: application/json", + params="", + body=HttpRequestNodeBody( + type="json", + data=[ + BodyData( + key="", + type="text", + value="{{#pre_node_id.object#}}", + ) + ], + ), + ) + + # Initialize the Executor + executor = Executor( + node_data=node_data, + timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30), + variable_pool=variable_pool, + ) + + # Check the executor's data + assert executor.method == "post" + assert executor.url == "https://api.example.com/data" + assert executor.headers == {"Content-Type": "application/json"} + assert executor.params == [] + assert executor.json == {"name": "John Doe", "age": 30, "email": "john@example.com"} + assert executor.data is None + assert 
executor.files is None + assert executor.content is None + + # Check the raw request (to_log method) + raw_request = executor.to_log() + assert "POST /data HTTP/1.1" in raw_request + assert "Host: api.example.com" in raw_request + assert "Content-Type: application/json" in raw_request + assert '"name": "John Doe"' in raw_request + assert '"age": 30' in raw_request + assert '"email": "john@example.com"' in raw_request + + +def test_executor_with_json_body_and_nested_object_variable(): + # Prepare the variable pool + variable_pool = VariablePool( + system_variables={}, + user_inputs={}, + ) + variable_pool.add(["pre_node_id", "object"], {"name": "John Doe", "age": 30, "email": "john@example.com"}) + + # Prepare the node data + node_data = HttpRequestNodeData( + title="Test JSON Body with Nested Object Variable", + method="post", + url="https://api.example.com/data", + authorization=HttpRequestNodeAuthorization(type="no-auth"), + headers="Content-Type: application/json", + params="", + body=HttpRequestNodeBody( + type="json", + data=[ + BodyData( + key="", + type="text", + value='{"object": {{#pre_node_id.object#}}}', + ) + ], + ), + ) + + # Initialize the Executor + executor = Executor( + node_data=node_data, + timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30), + variable_pool=variable_pool, + ) + + # Check the executor's data + assert executor.method == "post" + assert executor.url == "https://api.example.com/data" + assert executor.headers == {"Content-Type": "application/json"} + assert executor.params == [] + assert executor.json == {"object": {"name": "John Doe", "age": 30, "email": "john@example.com"}} + assert executor.data is None + assert executor.files is None + assert executor.content is None + + # Check the raw request (to_log method) + raw_request = executor.to_log() + assert "POST /data HTTP/1.1" in raw_request + assert "Host: api.example.com" in raw_request + assert "Content-Type: application/json" in raw_request + assert '"object": {' in 
raw_request + assert '"name": "John Doe"' in raw_request + assert '"age": 30' in raw_request + assert '"email": "john@example.com"' in raw_request + + +def test_extract_selectors_from_template_with_newline(): + variable_pool = VariablePool() + variable_pool.add(("node_id", "custom_query"), "line1\nline2") + node_data = HttpRequestNodeData( + title="Test JSON Body with Nested Object Variable", + method="post", + url="https://api.example.com/data", + authorization=HttpRequestNodeAuthorization(type="no-auth"), + headers="Content-Type: application/json", + params="test: {{#node_id.custom_query#}}", + body=HttpRequestNodeBody( + type="none", + data=[], + ), + ) + + executor = Executor( + node_data=node_data, + timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30), + variable_pool=variable_pool, + ) + + assert executor.params == [("test", "line1\nline2")] + + +def test_executor_with_form_data(): + # Prepare the variable pool + variable_pool = VariablePool( + system_variables={}, + user_inputs={}, + ) + variable_pool.add(["pre_node_id", "text_field"], "Hello, World!") + variable_pool.add(["pre_node_id", "number_field"], 42) + + # Prepare the node data + node_data = HttpRequestNodeData( + title="Test Form Data", + method="post", + url="https://api.example.com/upload", + authorization=HttpRequestNodeAuthorization(type="no-auth"), + headers="Content-Type: multipart/form-data", + params="", + body=HttpRequestNodeBody( + type="form-data", + data=[ + BodyData( + key="text_field", + type="text", + value="{{#pre_node_id.text_field#}}", + ), + BodyData( + key="number_field", + type="text", + value="{{#pre_node_id.number_field#}}", + ), + ], + ), + ) + + # Initialize the Executor + executor = Executor( + node_data=node_data, + timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30), + variable_pool=variable_pool, + ) + + # Check the executor's data + assert executor.method == "post" + assert executor.url == "https://api.example.com/upload" + assert "Content-Type" in 
executor.headers + assert "multipart/form-data" in executor.headers["Content-Type"] + assert executor.params == [] + assert executor.json is None + assert executor.files is None + assert executor.content is None + + # Check that the form data is correctly loaded in executor.data + assert isinstance(executor.data, dict) + assert "text_field" in executor.data + assert executor.data["text_field"] == "Hello, World!" + assert "number_field" in executor.data + assert executor.data["number_field"] == "42" + + # Check the raw request (to_log method) + raw_request = executor.to_log() + assert "POST /upload HTTP/1.1" in raw_request + assert "Host: api.example.com" in raw_request + assert "Content-Type: multipart/form-data" in raw_request + assert "text_field" in raw_request + assert "Hello, World!" in raw_request + assert "number_field" in raw_request + assert "42" in raw_request + + +def test_init_headers(): + def create_executor(headers: str) -> Executor: + node_data = HttpRequestNodeData( + title="test", + method="get", + url="http://example.com", + headers=headers, + params="", + authorization=HttpRequestNodeAuthorization(type="no-auth"), + ) + timeout = HttpRequestNodeTimeout(connect=10, read=30, write=30) + return Executor(node_data=node_data, timeout=timeout, variable_pool=VariablePool()) + + executor = create_executor("aa\n cc:") + executor._init_headers() + assert executor.headers == {"aa": "", "cc": ""} + + executor = create_executor("aa:bb\n cc:dd") + executor._init_headers() + assert executor.headers == {"aa": "bb", "cc": "dd"} + + executor = create_executor("aa:bb\n cc:dd\n") + executor._init_headers() + assert executor.headers == {"aa": "bb", "cc": "dd"} + + executor = create_executor("aa:bb\n\n cc : dd\n\n") + executor._init_headers() + assert executor.headers == {"aa": "bb", "cc": "dd"} + + +def test_init_params(): + def create_executor(params: str) -> Executor: + node_data = HttpRequestNodeData( + title="test", + method="get", + url="http://example.com", + 
headers="", + params=params, + authorization=HttpRequestNodeAuthorization(type="no-auth"), + ) + timeout = HttpRequestNodeTimeout(connect=10, read=30, write=30) + return Executor(node_data=node_data, timeout=timeout, variable_pool=VariablePool()) + + # Test basic key-value pairs + executor = create_executor("key1:value1\nkey2:value2") + executor._init_params() + assert executor.params == [("key1", "value1"), ("key2", "value2")] + + # Test empty values + executor = create_executor("key1:\nkey2:") + executor._init_params() + assert executor.params == [("key1", ""), ("key2", "")] + + # Test duplicate keys (which is allowed for params) + executor = create_executor("key1:value1\nkey1:value2") + executor._init_params() + assert executor.params == [("key1", "value1"), ("key1", "value2")] + + # Test whitespace handling + executor = create_executor(" key1 : value1 \n key2 : value2 ") + executor._init_params() + assert executor.params == [("key1", "value1"), ("key2", "value2")] + + # Test empty lines and extra whitespace + executor = create_executor("key1:value1\n\nkey2:value2\n\n") + executor._init_params() + assert executor.params == [("key1", "value1"), ("key2", "value2")] diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py new file mode 100644 index 0000000000000000000000000000000000000000..97bacada74572d180b1dca0d61f630db90e93178 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py @@ -0,0 +1,196 @@ +import httpx + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.file import File, FileTransferMethod, FileType +from core.variables import FileVariable +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState +from core.workflow.nodes.answer import AnswerStreamGenerateRoute +from 
core.workflow.nodes.end import EndStreamParam +from core.workflow.nodes.http_request import ( + BodyData, + HttpRequestNode, + HttpRequestNodeAuthorization, + HttpRequestNodeBody, + HttpRequestNodeData, +) +from models.enums import UserFrom +from models.workflow import WorkflowNodeExecutionStatus, WorkflowType + + +def test_http_request_node_binary_file(monkeypatch): + data = HttpRequestNodeData( + title="test", + method="post", + url="http://example.org/post", + authorization=HttpRequestNodeAuthorization(type="no-auth"), + headers="", + params="", + body=HttpRequestNodeBody( + type="binary", + data=[ + BodyData( + key="file", + type="file", + value="", + file=["1111", "file"], + ) + ], + ), + ) + variable_pool = VariablePool( + system_variables={}, + user_inputs={}, + ) + variable_pool.add( + ["1111", "file"], + FileVariable( + name="file", + value=File( + tenant_id="1", + type=FileType.IMAGE, + transfer_method=FileTransferMethod.LOCAL_FILE, + related_id="1111", + storage_key="", + ), + ), + ) + node = HttpRequestNode( + id="1", + config={ + "id": "1", + "data": data.model_dump(), + }, + graph_init_params=GraphInitParams( + tenant_id="1", + app_id="1", + workflow_type=WorkflowType.WORKFLOW, + workflow_id="1", + graph_config={}, + user_id="1", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.SERVICE_API, + call_depth=0, + ), + graph=Graph( + root_node_id="1", + answer_stream_generate_routes=AnswerStreamGenerateRoute( + answer_dependencies={}, + answer_generate_route={}, + ), + end_stream_param=EndStreamParam( + end_dependencies={}, + end_stream_variable_selector_mapping={}, + ), + ), + graph_runtime_state=GraphRuntimeState( + variable_pool=variable_pool, + start_at=0, + ), + ) + monkeypatch.setattr( + "core.workflow.nodes.http_request.executor.file_manager.download", + lambda *args, **kwargs: b"test", + ) + monkeypatch.setattr( + "core.helper.ssrf_proxy.post", + lambda *args, **kwargs: httpx.Response(200, content=kwargs["content"]), + ) + result = node._run() 
+ assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert result.outputs is not None + assert result.outputs["body"] == "test" + + +def test_http_request_node_form_with_file(monkeypatch): + data = HttpRequestNodeData( + title="test", + method="post", + url="http://example.org/post", + authorization=HttpRequestNodeAuthorization(type="no-auth"), + headers="", + params="", + body=HttpRequestNodeBody( + type="form-data", + data=[ + BodyData( + key="file", + type="file", + file=["1111", "file"], + ), + BodyData( + key="name", + type="text", + value="test", + ), + ], + ), + ) + variable_pool = VariablePool( + system_variables={}, + user_inputs={}, + ) + variable_pool.add( + ["1111", "file"], + FileVariable( + name="file", + value=File( + tenant_id="1", + type=FileType.IMAGE, + transfer_method=FileTransferMethod.LOCAL_FILE, + related_id="1111", + storage_key="", + ), + ), + ) + node = HttpRequestNode( + id="1", + config={ + "id": "1", + "data": data.model_dump(), + }, + graph_init_params=GraphInitParams( + tenant_id="1", + app_id="1", + workflow_type=WorkflowType.WORKFLOW, + workflow_id="1", + graph_config={}, + user_id="1", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.SERVICE_API, + call_depth=0, + ), + graph=Graph( + root_node_id="1", + answer_stream_generate_routes=AnswerStreamGenerateRoute( + answer_dependencies={}, + answer_generate_route={}, + ), + end_stream_param=EndStreamParam( + end_dependencies={}, + end_stream_variable_selector_mapping={}, + ), + ), + graph_runtime_state=GraphRuntimeState( + variable_pool=variable_pool, + start_at=0, + ), + ) + monkeypatch.setattr( + "core.workflow.nodes.http_request.executor.file_manager.download", + lambda *args, **kwargs: b"test", + ) + + def attr_checker(*args, **kwargs): + assert kwargs["data"] == {"name": "test"} + assert kwargs["files"] == {"file": (None, b"test", "application/octet-stream")} + return httpx.Response(200, content=b"") + + monkeypatch.setattr( + "core.helper.ssrf_proxy.post", + 
attr_checker, + ) + result = node._run() + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert result.outputs is not None + assert result.outputs["body"] == "" diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/__init__.py b/api/tests/unit_tests/core/workflow/nodes/iteration/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py new file mode 100644 index 0000000000000000000000000000000000000000..29bd4d6c6ccab1a966a4b50b74f9b6c7504d2c00 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py @@ -0,0 +1,860 @@ +import time +import uuid +from unittest.mock import patch + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.enums import SystemVariableKey +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams +from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState +from core.workflow.nodes.event import RunCompletedEvent +from core.workflow.nodes.iteration.entities import ErrorHandleMode +from core.workflow.nodes.iteration.iteration_node import IterationNode +from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode +from models.enums import UserFrom +from models.workflow import WorkflowNodeExecutionStatus, WorkflowType + + +def test_run(): + graph_config = { + "edges": [ + { + "id": "start-source-pe-target", + "source": "start", + "target": "pe", + }, + { + "id": "iteration-1-source-answer-3-target", + "source": "iteration-1", + "target": "answer-3", + }, + { + "id": 
"tt-source-if-else-target", + "source": "tt", + "target": "if-else", + }, + { + "id": "if-else-true-answer-2-target", + "source": "if-else", + "sourceHandle": "true", + "target": "answer-2", + }, + { + "id": "if-else-false-answer-4-target", + "source": "if-else", + "sourceHandle": "false", + "target": "answer-4", + }, + { + "id": "pe-source-iteration-1-target", + "source": "pe", + "target": "iteration-1", + }, + ], + "nodes": [ + {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"}, + { + "data": { + "iterator_selector": ["pe", "list_output"], + "output_selector": ["tt", "output"], + "output_type": "array[string]", + "startNodeType": "template-transform", + "start_node_id": "tt", + "title": "iteration", + "type": "iteration", + }, + "id": "iteration-1", + }, + { + "data": { + "answer": "{{#tt.output#}}", + "iteration_id": "iteration-1", + "title": "answer 2", + "type": "answer", + }, + "id": "answer-2", + }, + { + "data": { + "iteration_id": "iteration-1", + "template": "{{ arg1 }} 123", + "title": "template transform", + "type": "template-transform", + "variables": [{"value_selector": ["sys", "query"], "variable": "arg1"}], + }, + "id": "tt", + }, + { + "data": {"answer": "{{#iteration-1.output#}}88888", "title": "answer 3", "type": "answer"}, + "id": "answer-3", + }, + { + "data": { + "conditions": [ + { + "comparison_operator": "is", + "id": "1721916275284", + "value": "hi", + "variable_selector": ["sys", "query"], + } + ], + "iteration_id": "iteration-1", + "logical_operator": "and", + "title": "if", + "type": "if-else", + }, + "id": "if-else", + }, + { + "data": {"answer": "no hi", "iteration_id": "iteration-1", "title": "answer 4", "type": "answer"}, + "id": "answer-4", + }, + { + "data": { + "instruction": "test1", + "model": { + "completion_params": {"temperature": 0.7}, + "mode": "chat", + "name": "gpt-4o", + "provider": "openai", + }, + "parameters": [ + {"description": "test", "name": "list_output", "required": False, "type": 
"array[string]"} + ], + "query": ["sys", "query"], + "reasoning_mode": "prompt", + "title": "pe", + "type": "parameter-extractor", + }, + "id": "pe", + }, + ], + } + + graph = Graph.init(graph_config=graph_config) + + init_params = GraphInitParams( + tenant_id="1", + app_id="1", + workflow_type=WorkflowType.CHAT, + workflow_id="1", + graph_config=graph_config, + user_id="1", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ) + + # construct variable pool + pool = VariablePool( + system_variables={ + SystemVariableKey.QUERY: "dify", + SystemVariableKey.FILES: [], + SystemVariableKey.CONVERSATION_ID: "abababa", + SystemVariableKey.USER_ID: "1", + }, + user_inputs={}, + environment_variables=[], + ) + pool.add(["pe", "list_output"], ["dify-1", "dify-2"]) + + iteration_node = IterationNode( + id=str(uuid.uuid4()), + graph_init_params=init_params, + graph=graph, + graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()), + config={ + "data": { + "iterator_selector": ["pe", "list_output"], + "output_selector": ["tt", "output"], + "output_type": "array[string]", + "startNodeType": "template-transform", + "start_node_id": "tt", + "title": "迭代", + "type": "iteration", + }, + "id": "iteration-1", + }, + ) + + def tt_generator(self): + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs={"iterator_selector": "dify"}, + outputs={"output": "dify 123"}, + ) + + with patch.object(TemplateTransformNode, "_run", new=tt_generator): + # execute node + result = iteration_node._run() + + count = 0 + for item in result: + # print(type(item), item) + count += 1 + if isinstance(item, RunCompletedEvent): + assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert item.run_result.outputs == {"output": ["dify 123", "dify 123"]} + + assert count == 20 + + +def test_run_parallel(): + graph_config = { + "edges": [ + { + "id": "start-source-pe-target", + "source": "start", + 
"target": "pe", + }, + { + "id": "iteration-1-source-answer-3-target", + "source": "iteration-1", + "target": "answer-3", + }, + { + "id": "iteration-start-source-tt-target", + "source": "iteration-start", + "target": "tt", + }, + { + "id": "iteration-start-source-tt-2-target", + "source": "iteration-start", + "target": "tt-2", + }, + { + "id": "tt-source-if-else-target", + "source": "tt", + "target": "if-else", + }, + { + "id": "tt-2-source-if-else-target", + "source": "tt-2", + "target": "if-else", + }, + { + "id": "if-else-true-answer-2-target", + "source": "if-else", + "sourceHandle": "true", + "target": "answer-2", + }, + { + "id": "if-else-false-answer-4-target", + "source": "if-else", + "sourceHandle": "false", + "target": "answer-4", + }, + { + "id": "pe-source-iteration-1-target", + "source": "pe", + "target": "iteration-1", + }, + ], + "nodes": [ + {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"}, + { + "data": { + "iterator_selector": ["pe", "list_output"], + "output_selector": ["tt", "output"], + "output_type": "array[string]", + "startNodeType": "template-transform", + "start_node_id": "iteration-start", + "title": "iteration", + "type": "iteration", + }, + "id": "iteration-1", + }, + { + "data": { + "answer": "{{#tt.output#}}", + "iteration_id": "iteration-1", + "title": "answer 2", + "type": "answer", + }, + "id": "answer-2", + }, + { + "data": { + "iteration_id": "iteration-1", + "title": "iteration-start", + "type": "iteration-start", + }, + "id": "iteration-start", + }, + { + "data": { + "iteration_id": "iteration-1", + "template": "{{ arg1 }} 123", + "title": "template transform", + "type": "template-transform", + "variables": [{"value_selector": ["sys", "query"], "variable": "arg1"}], + }, + "id": "tt", + }, + { + "data": { + "iteration_id": "iteration-1", + "template": "{{ arg1 }} 321", + "title": "template transform", + "type": "template-transform", + "variables": [{"value_selector": ["sys", "query"], "variable": 
"arg1"}], + }, + "id": "tt-2", + }, + { + "data": {"answer": "{{#iteration-1.output#}}88888", "title": "answer 3", "type": "answer"}, + "id": "answer-3", + }, + { + "data": { + "conditions": [ + { + "comparison_operator": "is", + "id": "1721916275284", + "value": "hi", + "variable_selector": ["sys", "query"], + } + ], + "iteration_id": "iteration-1", + "logical_operator": "and", + "title": "if", + "type": "if-else", + }, + "id": "if-else", + }, + { + "data": {"answer": "no hi", "iteration_id": "iteration-1", "title": "answer 4", "type": "answer"}, + "id": "answer-4", + }, + { + "data": { + "instruction": "test1", + "model": { + "completion_params": {"temperature": 0.7}, + "mode": "chat", + "name": "gpt-4o", + "provider": "openai", + }, + "parameters": [ + {"description": "test", "name": "list_output", "required": False, "type": "array[string]"} + ], + "query": ["sys", "query"], + "reasoning_mode": "prompt", + "title": "pe", + "type": "parameter-extractor", + }, + "id": "pe", + }, + ], + } + + graph = Graph.init(graph_config=graph_config) + + init_params = GraphInitParams( + tenant_id="1", + app_id="1", + workflow_type=WorkflowType.CHAT, + workflow_id="1", + graph_config=graph_config, + user_id="1", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ) + + # construct variable pool + pool = VariablePool( + system_variables={ + SystemVariableKey.QUERY: "dify", + SystemVariableKey.FILES: [], + SystemVariableKey.CONVERSATION_ID: "abababa", + SystemVariableKey.USER_ID: "1", + }, + user_inputs={}, + environment_variables=[], + ) + pool.add(["pe", "list_output"], ["dify-1", "dify-2"]) + + iteration_node = IterationNode( + id=str(uuid.uuid4()), + graph_init_params=init_params, + graph=graph, + graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()), + config={ + "data": { + "iterator_selector": ["pe", "list_output"], + "output_selector": ["tt", "output"], + "output_type": "array[string]", + "startNodeType": 
"template-transform", + "start_node_id": "iteration-start", + "title": "迭代", + "type": "iteration", + }, + "id": "iteration-1", + }, + ) + + def tt_generator(self): + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs={"iterator_selector": "dify"}, + outputs={"output": "dify 123"}, + ) + + with patch.object(TemplateTransformNode, "_run", new=tt_generator): + # execute node + result = iteration_node._run() + + count = 0 + for item in result: + count += 1 + if isinstance(item, RunCompletedEvent): + assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert item.run_result.outputs == {"output": ["dify 123", "dify 123"]} + + assert count == 32 + + +def test_iteration_run_in_parallel_mode(): + graph_config = { + "edges": [ + { + "id": "start-source-pe-target", + "source": "start", + "target": "pe", + }, + { + "id": "iteration-1-source-answer-3-target", + "source": "iteration-1", + "target": "answer-3", + }, + { + "id": "iteration-start-source-tt-target", + "source": "iteration-start", + "target": "tt", + }, + { + "id": "iteration-start-source-tt-2-target", + "source": "iteration-start", + "target": "tt-2", + }, + { + "id": "tt-source-if-else-target", + "source": "tt", + "target": "if-else", + }, + { + "id": "tt-2-source-if-else-target", + "source": "tt-2", + "target": "if-else", + }, + { + "id": "if-else-true-answer-2-target", + "source": "if-else", + "sourceHandle": "true", + "target": "answer-2", + }, + { + "id": "if-else-false-answer-4-target", + "source": "if-else", + "sourceHandle": "false", + "target": "answer-4", + }, + { + "id": "pe-source-iteration-1-target", + "source": "pe", + "target": "iteration-1", + }, + ], + "nodes": [ + {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"}, + { + "data": { + "iterator_selector": ["pe", "list_output"], + "output_selector": ["tt", "output"], + "output_type": "array[string]", + "startNodeType": "template-transform", + "start_node_id": 
"iteration-start", + "title": "iteration", + "type": "iteration", + }, + "id": "iteration-1", + }, + { + "data": { + "answer": "{{#tt.output#}}", + "iteration_id": "iteration-1", + "title": "answer 2", + "type": "answer", + }, + "id": "answer-2", + }, + { + "data": { + "iteration_id": "iteration-1", + "title": "iteration-start", + "type": "iteration-start", + }, + "id": "iteration-start", + }, + { + "data": { + "iteration_id": "iteration-1", + "template": "{{ arg1 }} 123", + "title": "template transform", + "type": "template-transform", + "variables": [{"value_selector": ["sys", "query"], "variable": "arg1"}], + }, + "id": "tt", + }, + { + "data": { + "iteration_id": "iteration-1", + "template": "{{ arg1 }} 321", + "title": "template transform", + "type": "template-transform", + "variables": [{"value_selector": ["sys", "query"], "variable": "arg1"}], + }, + "id": "tt-2", + }, + { + "data": {"answer": "{{#iteration-1.output#}}88888", "title": "answer 3", "type": "answer"}, + "id": "answer-3", + }, + { + "data": { + "conditions": [ + { + "comparison_operator": "is", + "id": "1721916275284", + "value": "hi", + "variable_selector": ["sys", "query"], + } + ], + "iteration_id": "iteration-1", + "logical_operator": "and", + "title": "if", + "type": "if-else", + }, + "id": "if-else", + }, + { + "data": {"answer": "no hi", "iteration_id": "iteration-1", "title": "answer 4", "type": "answer"}, + "id": "answer-4", + }, + { + "data": { + "instruction": "test1", + "model": { + "completion_params": {"temperature": 0.7}, + "mode": "chat", + "name": "gpt-4o", + "provider": "openai", + }, + "parameters": [ + {"description": "test", "name": "list_output", "required": False, "type": "array[string]"} + ], + "query": ["sys", "query"], + "reasoning_mode": "prompt", + "title": "pe", + "type": "parameter-extractor", + }, + "id": "pe", + }, + ], + } + + graph = Graph.init(graph_config=graph_config) + + init_params = GraphInitParams( + tenant_id="1", + app_id="1", + 
workflow_type=WorkflowType.CHAT, + workflow_id="1", + graph_config=graph_config, + user_id="1", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ) + + # construct variable pool + pool = VariablePool( + system_variables={ + SystemVariableKey.QUERY: "dify", + SystemVariableKey.FILES: [], + SystemVariableKey.CONVERSATION_ID: "abababa", + SystemVariableKey.USER_ID: "1", + }, + user_inputs={}, + environment_variables=[], + ) + pool.add(["pe", "list_output"], ["dify-1", "dify-2"]) + + parallel_iteration_node = IterationNode( + id=str(uuid.uuid4()), + graph_init_params=init_params, + graph=graph, + graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()), + config={ + "data": { + "iterator_selector": ["pe", "list_output"], + "output_selector": ["tt", "output"], + "output_type": "array[string]", + "startNodeType": "template-transform", + "start_node_id": "iteration-start", + "title": "迭代", + "type": "iteration", + "is_parallel": True, + }, + "id": "iteration-1", + }, + ) + sequential_iteration_node = IterationNode( + id=str(uuid.uuid4()), + graph_init_params=init_params, + graph=graph, + graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()), + config={ + "data": { + "iterator_selector": ["pe", "list_output"], + "output_selector": ["tt", "output"], + "output_type": "array[string]", + "startNodeType": "template-transform", + "start_node_id": "iteration-start", + "title": "迭代", + "type": "iteration", + "is_parallel": True, + }, + "id": "iteration-1", + }, + ) + + def tt_generator(self): + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs={"iterator_selector": "dify"}, + outputs={"output": "dify 123"}, + ) + + with patch.object(TemplateTransformNode, "_run", new=tt_generator): + # execute node + parallel_result = parallel_iteration_node._run() + sequential_result = sequential_iteration_node._run() + assert 
parallel_iteration_node.node_data.parallel_nums == 10 + assert parallel_iteration_node.node_data.error_handle_mode == ErrorHandleMode.TERMINATED + count = 0 + parallel_arr = [] + sequential_arr = [] + for item in parallel_result: + count += 1 + parallel_arr.append(item) + if isinstance(item, RunCompletedEvent): + assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert item.run_result.outputs == {"output": ["dify 123", "dify 123"]} + assert count == 32 + + for item in sequential_result: + sequential_arr.append(item) + count += 1 + if isinstance(item, RunCompletedEvent): + assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert item.run_result.outputs == {"output": ["dify 123", "dify 123"]} + assert count == 64 + + +def test_iteration_run_error_handle(): + graph_config = { + "edges": [ + { + "id": "start-source-pe-target", + "source": "start", + "target": "pe", + }, + { + "id": "iteration-1-source-answer-3-target", + "source": "iteration-1", + "target": "answer-3", + }, + { + "id": "tt-source-if-else-target", + "source": "iteration-start", + "target": "if-else", + }, + { + "id": "if-else-true-answer-2-target", + "source": "if-else", + "sourceHandle": "true", + "target": "tt", + }, + { + "id": "if-else-false-answer-4-target", + "source": "if-else", + "sourceHandle": "false", + "target": "tt2", + }, + { + "id": "pe-source-iteration-1-target", + "source": "pe", + "target": "iteration-1", + }, + ], + "nodes": [ + {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"}, + { + "data": { + "iterator_selector": ["pe", "list_output"], + "output_selector": ["tt2", "output"], + "output_type": "array[string]", + "start_node_id": "if-else", + "title": "iteration", + "type": "iteration", + }, + "id": "iteration-1", + }, + { + "data": { + "iteration_id": "iteration-1", + "template": "{{ arg1.split(arg2) }}", + "title": "template transform", + "type": "template-transform", + "variables": [ + {"value_selector": 
["iteration-1", "item"], "variable": "arg1"}, + {"value_selector": ["iteration-1", "index"], "variable": "arg2"}, + ], + }, + "id": "tt", + }, + { + "data": { + "iteration_id": "iteration-1", + "template": "{{ arg1 }}", + "title": "template transform", + "type": "template-transform", + "variables": [ + {"value_selector": ["iteration-1", "item"], "variable": "arg1"}, + ], + }, + "id": "tt2", + }, + { + "data": {"answer": "{{#iteration-1.output#}}88888", "title": "answer 3", "type": "answer"}, + "id": "answer-3", + }, + { + "data": { + "iteration_id": "iteration-1", + "title": "iteration-start", + "type": "iteration-start", + }, + "id": "iteration-start", + }, + { + "data": { + "conditions": [ + { + "comparison_operator": "is", + "id": "1721916275284", + "value": "1", + "variable_selector": ["iteration-1", "item"], + } + ], + "iteration_id": "iteration-1", + "logical_operator": "and", + "title": "if", + "type": "if-else", + }, + "id": "if-else", + }, + { + "data": { + "instruction": "test1", + "model": { + "completion_params": {"temperature": 0.7}, + "mode": "chat", + "name": "gpt-4o", + "provider": "openai", + }, + "parameters": [ + {"description": "test", "name": "list_output", "required": False, "type": "array[string]"} + ], + "query": ["sys", "query"], + "reasoning_mode": "prompt", + "title": "pe", + "type": "parameter-extractor", + }, + "id": "pe", + }, + ], + } + + graph = Graph.init(graph_config=graph_config) + + init_params = GraphInitParams( + tenant_id="1", + app_id="1", + workflow_type=WorkflowType.CHAT, + workflow_id="1", + graph_config=graph_config, + user_id="1", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ) + + # construct variable pool + pool = VariablePool( + system_variables={ + SystemVariableKey.QUERY: "dify", + SystemVariableKey.FILES: [], + SystemVariableKey.CONVERSATION_ID: "abababa", + SystemVariableKey.USER_ID: "1", + }, + user_inputs={}, + environment_variables=[], + ) + pool.add(["pe", "list_output"], 
["1", "1"]) + iteration_node = IterationNode( + id=str(uuid.uuid4()), + graph_init_params=init_params, + graph=graph, + graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()), + config={ + "data": { + "iterator_selector": ["pe", "list_output"], + "output_selector": ["tt", "output"], + "output_type": "array[string]", + "startNodeType": "template-transform", + "start_node_id": "iteration-start", + "title": "iteration", + "type": "iteration", + "is_parallel": True, + "error_handle_mode": ErrorHandleMode.CONTINUE_ON_ERROR, + }, + "id": "iteration-1", + }, + ) + # execute continue on error node + result = iteration_node._run() + result_arr = [] + count = 0 + for item in result: + result_arr.append(item) + count += 1 + if isinstance(item, RunCompletedEvent): + assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert item.run_result.outputs == {"output": [None, None]} + + assert count == 14 + # execute remove abnormal output + iteration_node.node_data.error_handle_mode = ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT + result = iteration_node._run() + count = 0 + for item in result: + count += 1 + if isinstance(item, RunCompletedEvent): + assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert item.run_result.outputs == {"output": []} + assert count == 14 diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py new file mode 100644 index 0000000000000000000000000000000000000000..74af5eb56b5038ae81b9c8a02c75ab09d4843e0b --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py @@ -0,0 +1,469 @@ +from collections.abc import Sequence +from typing import Optional + +import pytest + +from configs import dify_config +from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity +from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle +from 
core.entities.provider_entities import CustomConfiguration, SystemConfiguration +from core.file import File, FileTransferMethod, FileType +from core.model_runtime.entities.common_entities import I18nObject +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + ImagePromptMessageContent, + PromptMessage, + PromptMessageRole, + SystemPromptMessage, + TextPromptMessageContent, + UserPromptMessage, +) +from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelFeature, ModelType +from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory +from core.prompt.entities.advanced_prompt_entities import MemoryConfig +from core.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState +from core.workflow.nodes.answer import AnswerStreamGenerateRoute +from core.workflow.nodes.end import EndStreamParam +from core.workflow.nodes.llm.entities import ( + ContextConfig, + LLMNodeChatModelMessage, + LLMNodeData, + ModelConfig, + VisionConfig, + VisionConfigOptions, +) +from core.workflow.nodes.llm.node import LLMNode +from models.enums import UserFrom +from models.provider import ProviderType +from models.workflow import WorkflowType +from tests.unit_tests.core.workflow.nodes.llm.test_scenarios import LLMNodeTestScenario + + +class MockTokenBufferMemory: + def __init__(self, history_messages=None): + self.history_messages = history_messages or [] + + def get_history_prompt_messages( + self, max_token_limit: int = 2000, message_limit: Optional[int] = None + ) -> Sequence[PromptMessage]: + if message_limit is not None: + return self.history_messages[-message_limit * 2 :] + return self.history_messages + + +@pytest.fixture +def llm_node(): + data = LLMNodeData( + title="Test LLM", + model=ModelConfig(provider="openai", name="gpt-3.5-turbo", 
mode="chat", completion_params={}), + prompt_template=[], + memory=None, + context=ContextConfig(enabled=False), + vision=VisionConfig( + enabled=True, + configs=VisionConfigOptions( + variable_selector=["sys", "files"], + detail=ImagePromptMessageContent.DETAIL.HIGH, + ), + ), + ) + variable_pool = VariablePool( + system_variables={}, + user_inputs={}, + ) + node = LLMNode( + id="1", + config={ + "id": "1", + "data": data.model_dump(), + }, + graph_init_params=GraphInitParams( + tenant_id="1", + app_id="1", + workflow_type=WorkflowType.WORKFLOW, + workflow_id="1", + graph_config={}, + user_id="1", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.SERVICE_API, + call_depth=0, + ), + graph=Graph( + root_node_id="1", + answer_stream_generate_routes=AnswerStreamGenerateRoute( + answer_dependencies={}, + answer_generate_route={}, + ), + end_stream_param=EndStreamParam( + end_dependencies={}, + end_stream_variable_selector_mapping={}, + ), + ), + graph_runtime_state=GraphRuntimeState( + variable_pool=variable_pool, + start_at=0, + ), + ) + return node + + +@pytest.fixture +def model_config(): + # Create actual provider and model type instances + model_provider_factory = ModelProviderFactory() + provider_instance = model_provider_factory.get_provider_instance("openai") + model_type_instance = provider_instance.get_model_instance(ModelType.LLM) + + # Create a ProviderModelBundle + provider_model_bundle = ProviderModelBundle( + configuration=ProviderConfiguration( + tenant_id="1", + provider=provider_instance.get_provider_schema(), + preferred_provider_type=ProviderType.CUSTOM, + using_provider_type=ProviderType.CUSTOM, + system_configuration=SystemConfiguration(enabled=False), + custom_configuration=CustomConfiguration(provider=None), + model_settings=[], + ), + provider_instance=provider_instance, + model_type_instance=model_type_instance, + ) + + # Create and return a ModelConfigWithCredentialsEntity + return ModelConfigWithCredentialsEntity( + provider="openai", + 
model="gpt-3.5-turbo", + model_schema=AIModelEntity( + model="gpt-3.5-turbo", + label=I18nObject(en_US="GPT-3.5 Turbo"), + model_type=ModelType.LLM, + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_properties={}, + ), + mode="chat", + credentials={}, + parameters={}, + provider_model_bundle=provider_model_bundle, + ) + + +def test_fetch_files_with_file_segment(llm_node): + file = File( + id="1", + tenant_id="test", + type=FileType.IMAGE, + filename="test.jpg", + transfer_method=FileTransferMethod.LOCAL_FILE, + related_id="1", + storage_key="", + ) + llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], file) + + result = llm_node._fetch_files(selector=["sys", "files"]) + assert result == [file] + + +def test_fetch_files_with_array_file_segment(llm_node): + files = [ + File( + id="1", + tenant_id="test", + type=FileType.IMAGE, + filename="test1.jpg", + transfer_method=FileTransferMethod.LOCAL_FILE, + related_id="1", + storage_key="", + ), + File( + id="2", + tenant_id="test", + type=FileType.IMAGE, + filename="test2.jpg", + transfer_method=FileTransferMethod.LOCAL_FILE, + related_id="2", + storage_key="", + ), + ] + llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayFileSegment(value=files)) + + result = llm_node._fetch_files(selector=["sys", "files"]) + assert result == files + + +def test_fetch_files_with_none_segment(llm_node): + llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], NoneSegment()) + + result = llm_node._fetch_files(selector=["sys", "files"]) + assert result == [] + + +def test_fetch_files_with_array_any_segment(llm_node): + llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayAnySegment(value=[])) + + result = llm_node._fetch_files(selector=["sys", "files"]) + assert result == [] + + +def test_fetch_files_with_non_existent_variable(llm_node): + result = llm_node._fetch_files(selector=["sys", "files"]) + assert result == [] + + +def test_fetch_prompt_messages__vison_disabled(faker, 
llm_node, model_config): + prompt_template = [] + llm_node.node_data.prompt_template = prompt_template + + fake_vision_detail = faker.random_element( + [ImagePromptMessageContent.DETAIL.HIGH, ImagePromptMessageContent.DETAIL.LOW] + ) + fake_remote_url = faker.url() + files = [ + File( + id="1", + tenant_id="test", + type=FileType.IMAGE, + filename="test1.jpg", + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url=fake_remote_url, + storage_key="", + ) + ] + + fake_query = faker.sentence() + + prompt_messages, _ = llm_node._fetch_prompt_messages( + sys_query=fake_query, + sys_files=files, + context=None, + memory=None, + model_config=model_config, + prompt_template=prompt_template, + memory_config=None, + vision_enabled=False, + vision_detail=fake_vision_detail, + variable_pool=llm_node.graph_runtime_state.variable_pool, + jinja2_variables=[], + ) + + assert prompt_messages == [UserPromptMessage(content=fake_query)] + + +def test_fetch_prompt_messages__basic(faker, llm_node, model_config): + # Setup dify config + dify_config.MULTIMODAL_SEND_FORMAT = "url" + + # Generate fake values for prompt template + fake_assistant_prompt = faker.sentence() + fake_query = faker.sentence() + fake_context = faker.sentence() + fake_window_size = faker.random_int(min=1, max=3) + fake_vision_detail = faker.random_element( + [ImagePromptMessageContent.DETAIL.HIGH, ImagePromptMessageContent.DETAIL.LOW] + ) + fake_remote_url = faker.url() + + # Setup mock memory with history messages + mock_history = [ + UserPromptMessage(content=faker.sentence()), + AssistantPromptMessage(content=faker.sentence()), + UserPromptMessage(content=faker.sentence()), + AssistantPromptMessage(content=faker.sentence()), + UserPromptMessage(content=faker.sentence()), + AssistantPromptMessage(content=faker.sentence()), + ] + + # Setup memory configuration + memory_config = MemoryConfig( + role_prefix=MemoryConfig.RolePrefix(user="Human", assistant="Assistant"), + 
window=MemoryConfig.WindowConfig(enabled=True, size=fake_window_size), + query_prompt_template=None, + ) + + memory = MockTokenBufferMemory(history_messages=mock_history) + + # Test scenarios covering different file input combinations + test_scenarios = [ + LLMNodeTestScenario( + description="No files", + sys_query=fake_query, + sys_files=[], + features=[], + vision_enabled=False, + vision_detail=None, + window_size=fake_window_size, + prompt_template=[ + LLMNodeChatModelMessage( + text=fake_context, + role=PromptMessageRole.SYSTEM, + edition_type="basic", + ), + LLMNodeChatModelMessage( + text="{#context#}", + role=PromptMessageRole.USER, + edition_type="basic", + ), + LLMNodeChatModelMessage( + text=fake_assistant_prompt, + role=PromptMessageRole.ASSISTANT, + edition_type="basic", + ), + ], + expected_messages=[ + SystemPromptMessage(content=fake_context), + UserPromptMessage(content=fake_context), + AssistantPromptMessage(content=fake_assistant_prompt), + ] + + mock_history[fake_window_size * -2 :] + + [ + UserPromptMessage(content=fake_query), + ], + ), + LLMNodeTestScenario( + description="User files", + sys_query=fake_query, + sys_files=[ + File( + tenant_id="test", + type=FileType.IMAGE, + filename="test1.jpg", + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url=fake_remote_url, + extension=".jpg", + mime_type="image/jpg", + storage_key="", + ) + ], + vision_enabled=True, + vision_detail=fake_vision_detail, + features=[ModelFeature.VISION], + window_size=fake_window_size, + prompt_template=[ + LLMNodeChatModelMessage( + text=fake_context, + role=PromptMessageRole.SYSTEM, + edition_type="basic", + ), + LLMNodeChatModelMessage( + text="{#context#}", + role=PromptMessageRole.USER, + edition_type="basic", + ), + LLMNodeChatModelMessage( + text=fake_assistant_prompt, + role=PromptMessageRole.ASSISTANT, + edition_type="basic", + ), + ], + expected_messages=[ + SystemPromptMessage(content=fake_context), + UserPromptMessage(content=fake_context), + 
AssistantPromptMessage(content=fake_assistant_prompt), + ] + + mock_history[fake_window_size * -2 :] + + [ + UserPromptMessage( + content=[ + TextPromptMessageContent(data=fake_query), + ImagePromptMessageContent( + url=fake_remote_url, mime_type="image/jpg", format="jpg", detail=fake_vision_detail + ), + ] + ), + ], + ), + LLMNodeTestScenario( + description="Prompt template with variable selector of File", + sys_query=fake_query, + sys_files=[], + vision_enabled=False, + vision_detail=fake_vision_detail, + features=[ModelFeature.VISION], + window_size=fake_window_size, + prompt_template=[ + LLMNodeChatModelMessage( + text="{{#input.image#}}", + role=PromptMessageRole.USER, + edition_type="basic", + ), + ], + expected_messages=[ + UserPromptMessage( + content=[ + ImagePromptMessageContent( + url=fake_remote_url, mime_type="image/jpg", format="jpg", detail=fake_vision_detail + ), + ] + ), + ] + + mock_history[fake_window_size * -2 :] + + [UserPromptMessage(content=fake_query)], + file_variables={ + "input.image": File( + tenant_id="test", + type=FileType.IMAGE, + filename="test1.jpg", + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url=fake_remote_url, + extension=".jpg", + mime_type="image/jpg", + storage_key="", + ) + }, + ), + ] + + for scenario in test_scenarios: + model_config.model_schema.features = scenario.features + + for k, v in scenario.file_variables.items(): + selector = k.split(".") + llm_node.graph_runtime_state.variable_pool.add(selector, v) + + # Call the method under test + prompt_messages, _ = llm_node._fetch_prompt_messages( + sys_query=scenario.sys_query, + sys_files=scenario.sys_files, + context=fake_context, + memory=memory, + model_config=model_config, + prompt_template=scenario.prompt_template, + memory_config=memory_config, + vision_enabled=scenario.vision_enabled, + vision_detail=scenario.vision_detail, + variable_pool=llm_node.graph_runtime_state.variable_pool, + jinja2_variables=[], + ) + + # Verify the result + assert 
len(prompt_messages) == len(scenario.expected_messages), f"Scenario failed: {scenario.description}" + assert prompt_messages == scenario.expected_messages, ( + f"Message content mismatch in scenario: {scenario.description}" + ) + + +def test_handle_list_messages_basic(llm_node): + messages = [ + LLMNodeChatModelMessage( + text="Hello, {#context#}", + role=PromptMessageRole.USER, + edition_type="basic", + ) + ] + context = "world" + jinja2_variables = [] + variable_pool = llm_node.graph_runtime_state.variable_pool + vision_detail_config = ImagePromptMessageContent.DETAIL.HIGH + + result = llm_node._handle_list_messages( + messages=messages, + context=context, + jinja2_variables=jinja2_variables, + variable_pool=variable_pool, + vision_detail_config=vision_detail_config, + ) + + assert len(result) == 1 + assert isinstance(result[0], UserPromptMessage) + assert result[0].content == [TextPromptMessageContent(data="Hello, world")] diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_scenarios.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_scenarios.py new file mode 100644 index 0000000000000000000000000000000000000000..21bb857353262ca0f5fbef15ed1774dbb9a0c66e --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_scenarios.py @@ -0,0 +1,25 @@ +from collections.abc import Mapping, Sequence + +from pydantic import BaseModel, Field + +from core.file import File +from core.model_runtime.entities.message_entities import PromptMessage +from core.model_runtime.entities.model_entities import ModelFeature +from core.workflow.nodes.llm.entities import LLMNodeChatModelMessage + + +class LLMNodeTestScenario(BaseModel): + """Test scenario for LLM node testing.""" + + description: str = Field(..., description="Description of the test scenario") + sys_query: str = Field(..., description="User query input") + sys_files: Sequence[File] = Field(default_factory=list, description="List of user files") + vision_enabled: bool = Field(default=False, 
description="Whether vision is enabled") + vision_detail: str | None = Field(None, description="Vision detail level if vision is enabled") + features: Sequence[ModelFeature] = Field(default_factory=list, description="List of model features") + window_size: int = Field(..., description="Window size for memory") + prompt_template: Sequence[LLMNodeChatModelMessage] = Field(..., description="Template for prompt messages") + file_variables: Mapping[str, File | Sequence[File]] = Field( + default_factory=dict, description="List of file variables" + ) + expected_messages: Sequence[PromptMessage] = Field(..., description="Expected messages after processing") diff --git a/api/tests/unit_tests/core/workflow/nodes/test_answer.py b/api/tests/unit_tests/core/workflow/nodes/test_answer.py new file mode 100644 index 0000000000000000000000000000000000000000..2f0aa28b4847107edcc1ea4a22ebc8fa1997b487 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/test_answer.py @@ -0,0 +1,85 @@ +import time +import uuid +from unittest.mock import MagicMock + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.enums import SystemVariableKey +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams +from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState +from core.workflow.nodes.answer.answer_node import AnswerNode +from extensions.ext_database import db +from models.enums import UserFrom +from models.workflow import WorkflowNodeExecutionStatus, WorkflowType + + +def test_execute_answer(): + graph_config = { + "edges": [ + { + "id": "start-source-answer-target", + "source": "start", + "target": "answer", + }, + ], + "nodes": [ + {"data": {"type": "start"}, "id": "start"}, + { + "data": { + "title": "123", + "type": "answer", + "answer": "Today's weather is 
{{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.", + }, + "id": "answer", + }, + ], + } + + graph = Graph.init(graph_config=graph_config) + + init_params = GraphInitParams( + tenant_id="1", + app_id="1", + workflow_type=WorkflowType.WORKFLOW, + workflow_id="1", + graph_config=graph_config, + user_id="1", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ) + + # construct variable pool + variable_pool = VariablePool( + system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, + user_inputs={}, + environment_variables=[], + conversation_variables=[], + ) + variable_pool.add(["start", "weather"], "sunny") + variable_pool.add(["llm", "text"], "You are a helpful AI.") + + node = AnswerNode( + id=str(uuid.uuid4()), + graph_init_params=init_params, + graph=graph, + graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), + config={ + "id": "answer", + "data": { + "title": "123", + "type": "answer", + "answer": "Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.", + }, + }, + ) + + # Mock db.session.close() + db.session.close = MagicMock() + + # execute node + result = node._run() + + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert result.outputs["answer"] == "Today's weather is sunny\nYou are a helpful AI.\n{{img}}\nFin." 
diff --git a/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py b/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py new file mode 100644 index 0000000000000000000000000000000000000000..2d74be9da9a96c0fcef858c2a725f4580baeb8a4 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py @@ -0,0 +1,508 @@ +from core.app.entities.app_invoke_entities import InvokeFrom +from core.workflow.enums import SystemVariableKey +from core.workflow.graph_engine.entities.event import ( + GraphRunPartialSucceededEvent, + NodeRunExceptionEvent, + NodeRunStreamChunkEvent, +) +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.graph_engine import GraphEngine +from models.enums import UserFrom +from models.workflow import WorkflowType + + +class ContinueOnErrorTestHelper: + @staticmethod + def get_code_node( + code: str, error_strategy: str = "fail-branch", default_value: dict | None = None, retry_config: dict = {} + ): + """Helper method to create a code node configuration""" + node = { + "id": "node", + "data": { + "outputs": {"result": {"type": "number"}}, + "error_strategy": error_strategy, + "title": "code", + "variables": [], + "code_language": "python3", + "code": "\n".join([line[4:] for line in code.split("\n")]), + "type": "code", + **retry_config, + }, + } + if default_value: + node["data"]["default_value"] = default_value + return node + + @staticmethod + def get_http_node( + error_strategy: str = "fail-branch", + default_value: dict | None = None, + authorization_success: bool = False, + retry_config: dict = {}, + ): + """Helper method to create a http node configuration""" + authorization = ( + { + "type": "api-key", + "config": { + "type": "basic", + "api_key": "ak-xxx", + "header": "api-key", + }, + } + if authorization_success + else { + "type": "api-key", + # missing config field + } + ) + node = { + "id": "node", + "data": { + "title": "http", + "desc": "", + "method": 
"get", + "url": "http://example.com", + "authorization": authorization, + "headers": "X-Header:123", + "params": "A:b", + "body": None, + "type": "http-request", + "error_strategy": error_strategy, + **retry_config, + }, + } + if default_value: + node["data"]["default_value"] = default_value + return node + + @staticmethod + def get_error_status_code_http_node(error_strategy: str = "fail-branch", default_value: dict | None = None): + """Helper method to create a http node configuration""" + node = { + "id": "node", + "data": { + "type": "http-request", + "title": "HTTP Request", + "desc": "", + "variables": [], + "method": "get", + "url": "https://api.github.com/issues", + "authorization": {"type": "no-auth", "config": None}, + "headers": "", + "params": "", + "body": {"type": "none", "data": []}, + "timeout": {"max_connect_timeout": 0, "max_read_timeout": 0, "max_write_timeout": 0}, + "error_strategy": error_strategy, + }, + } + if default_value: + node["data"]["default_value"] = default_value + return node + + @staticmethod + def get_tool_node(error_strategy: str = "fail-branch", default_value: dict | None = None): + """Helper method to create a tool node configuration""" + node = { + "id": "node", + "data": { + "title": "a", + "desc": "a", + "provider_id": "maths", + "provider_type": "builtin", + "provider_name": "maths", + "tool_name": "eval_expression", + "tool_label": "eval_expression", + "tool_configurations": {}, + "tool_parameters": { + "expression": { + "type": "variable", + "value": ["1", "123", "args1"], + } + }, + "type": "tool", + "error_strategy": error_strategy, + }, + } + if default_value: + node["data"]["default_value"] = default_value + return node + + @staticmethod + def get_llm_node(error_strategy: str = "fail-branch", default_value: dict | None = None): + """Helper method to create a llm node configuration""" + node = { + "id": "node", + "data": { + "title": "123", + "type": "llm", + "model": {"provider": "openai", "name": "gpt-3.5-turbo", 
"mode": "chat", "completion_params": {}}, + "prompt_template": [ + {"role": "system", "text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}."}, + {"role": "user", "text": "{{#sys.query#}}"}, + ], + "memory": None, + "context": {"enabled": False}, + "vision": {"enabled": False}, + "error_strategy": error_strategy, + }, + } + if default_value: + node["data"]["default_value"] = default_value + return node + + @staticmethod + def create_test_graph_engine(graph_config: dict, user_inputs: dict | None = None): + """Helper method to create a graph engine instance for testing""" + graph = Graph.init(graph_config=graph_config) + variable_pool = { + "system_variables": { + SystemVariableKey.QUERY: "clear", + SystemVariableKey.FILES: [], + SystemVariableKey.CONVERSATION_ID: "abababa", + SystemVariableKey.USER_ID: "aaa", + }, + "user_inputs": user_inputs or {"uid": "takato"}, + } + + return GraphEngine( + tenant_id="111", + app_id="222", + workflow_type=WorkflowType.CHAT, + workflow_id="333", + graph_config=graph_config, + user_id="444", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.WEB_APP, + call_depth=0, + graph=graph, + variable_pool=variable_pool, + max_execution_steps=500, + max_execution_time=1200, + ) + + +DEFAULT_VALUE_EDGE = [ + { + "id": "start-source-node-target", + "source": "start", + "target": "node", + "sourceHandle": "source", + }, + { + "id": "node-source-answer-target", + "source": "node", + "target": "answer", + "sourceHandle": "source", + }, +] + +FAIL_BRANCH_EDGES = [ + { + "id": "start-source-node-target", + "source": "start", + "target": "node", + "sourceHandle": "source", + }, + { + "id": "node-true-success-target", + "source": "node", + "target": "success", + "sourceHandle": "source", + }, + { + "id": "node-false-error-target", + "source": "node", + "target": "error", + "sourceHandle": "fail-branch", + }, +] + + +def test_code_default_value_continue_on_error(): + error_code = """ + def main() -> dict: + return { + 
"result": 1 / 0, + } + """ + + graph_config = { + "edges": DEFAULT_VALUE_EDGE, + "nodes": [ + {"data": {"title": "start", "type": "start", "variables": []}, "id": "start"}, + {"data": {"title": "answer", "type": "answer", "answer": "{{#node.result#}}"}, "id": "answer"}, + ContinueOnErrorTestHelper.get_code_node( + error_code, "default-value", [{"key": "result", "type": "number", "value": 132123}] + ), + ], + } + + graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) + events = list(graph_engine.run()) + assert any(isinstance(e, NodeRunExceptionEvent) for e in events) + assert any(isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "132123"} for e in events) + assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1 + + +def test_code_fail_branch_continue_on_error(): + error_code = """ + def main() -> dict: + return { + "result": 1 / 0, + } + """ + + graph_config = { + "edges": FAIL_BRANCH_EDGES, + "nodes": [ + {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"}, + { + "data": {"title": "success", "type": "answer", "answer": "node node run successfully"}, + "id": "success", + }, + { + "data": {"title": "error", "type": "answer", "answer": "node node run failed"}, + "id": "error", + }, + ContinueOnErrorTestHelper.get_code_node(error_code), + ], + } + + graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) + events = list(graph_engine.run()) + assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1 + assert any(isinstance(e, NodeRunExceptionEvent) for e in events) + assert any( + isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "node node run failed"} for e in events + ) + + +def test_http_node_default_value_continue_on_error(): + """Test HTTP node with default value error strategy""" + graph_config = { + "edges": DEFAULT_VALUE_EDGE, + "nodes": [ + {"data": {"title": "start", "type": "start", "variables": 
[]}, "id": "start"}, + {"data": {"title": "answer", "type": "answer", "answer": "{{#node.response#}}"}, "id": "answer"}, + ContinueOnErrorTestHelper.get_http_node( + "default-value", [{"key": "response", "type": "string", "value": "http node got error response"}] + ), + ], + } + + graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) + events = list(graph_engine.run()) + + assert any(isinstance(e, NodeRunExceptionEvent) for e in events) + assert any( + isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "http node got error response"} + for e in events + ) + assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1 + + +def test_http_node_fail_branch_continue_on_error(): + """Test HTTP node with fail-branch error strategy""" + graph_config = { + "edges": FAIL_BRANCH_EDGES, + "nodes": [ + {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"}, + { + "data": {"title": "success", "type": "answer", "answer": "HTTP request successful"}, + "id": "success", + }, + { + "data": {"title": "error", "type": "answer", "answer": "HTTP request failed"}, + "id": "error", + }, + ContinueOnErrorTestHelper.get_http_node(), + ], + } + + graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) + events = list(graph_engine.run()) + + assert any(isinstance(e, NodeRunExceptionEvent) for e in events) + assert any( + isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "HTTP request failed"} for e in events + ) + assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1 + + +def test_tool_node_default_value_continue_on_error(): + """Test tool node with default value error strategy""" + graph_config = { + "edges": DEFAULT_VALUE_EDGE, + "nodes": [ + {"data": {"title": "start", "type": "start", "variables": []}, "id": "start"}, + {"data": {"title": "answer", "type": "answer", "answer": "{{#node.result#}}"}, "id": "answer"}, + 
ContinueOnErrorTestHelper.get_tool_node( + "default-value", [{"key": "result", "type": "string", "value": "default tool result"}] + ), + ], + } + + graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) + events = list(graph_engine.run()) + + assert any(isinstance(e, NodeRunExceptionEvent) for e in events) + assert any( + isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "default tool result"} for e in events + ) + assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1 + + +def test_tool_node_fail_branch_continue_on_error(): + """Test tool node with fail-branch error strategy""" + graph_config = { + "edges": FAIL_BRANCH_EDGES, + "nodes": [ + {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"}, + { + "data": {"title": "success", "type": "answer", "answer": "tool execute successful"}, + "id": "success", + }, + { + "data": {"title": "error", "type": "answer", "answer": "tool execute failed"}, + "id": "error", + }, + ContinueOnErrorTestHelper.get_tool_node(), + ], + } + + graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) + events = list(graph_engine.run()) + + assert any(isinstance(e, NodeRunExceptionEvent) for e in events) + assert any( + isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "tool execute failed"} for e in events + ) + assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1 + + +def test_llm_node_default_value_continue_on_error(): + """Test LLM node with default value error strategy""" + graph_config = { + "edges": DEFAULT_VALUE_EDGE, + "nodes": [ + {"data": {"title": "start", "type": "start", "variables": []}, "id": "start"}, + {"data": {"title": "answer", "type": "answer", "answer": "{{#node.answer#}}"}, "id": "answer"}, + ContinueOnErrorTestHelper.get_llm_node( + "default-value", [{"key": "answer", "type": "string", "value": "default LLM response"}] + ), + ], + } + + graph_engine =
ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) + events = list(graph_engine.run()) + + assert any(isinstance(e, NodeRunExceptionEvent) for e in events) + assert any( + isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "default LLM response"} for e in events + ) + assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1 + + +def test_llm_node_fail_branch_continue_on_error(): + """Test LLM node with fail-branch error strategy""" + graph_config = { + "edges": FAIL_BRANCH_EDGES, + "nodes": [ + {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"}, + { + "data": {"title": "success", "type": "answer", "answer": "LLM request successful"}, + "id": "success", + }, + { + "data": {"title": "error", "type": "answer", "answer": "LLM request failed"}, + "id": "error", + }, + ContinueOnErrorTestHelper.get_llm_node(), + ], + } + + graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) + events = list(graph_engine.run()) + + assert any(isinstance(e, NodeRunExceptionEvent) for e in events) + assert any( + isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "LLM request failed"} for e in events + ) + assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1 + + +def test_status_code_error_http_node_fail_branch_continue_on_error(): + """Test HTTP node with fail-branch error strategy""" + graph_config = { + "edges": FAIL_BRANCH_EDGES, + "nodes": [ + {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"}, + { + "data": {"title": "success", "type": "answer", "answer": "http execute successful"}, + "id": "success", + }, + { + "data": {"title": "error", "type": "answer", "answer": "http execute failed"}, + "id": "error", + }, + ContinueOnErrorTestHelper.get_error_status_code_http_node(), + ], + } + + graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) + events = list(graph_engine.run()) + + 
assert any(isinstance(e, NodeRunExceptionEvent) for e in events) + assert any( + isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "http execute failed"} for e in events + ) + assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1 + + +def test_variable_pool_error_type_variable(): + graph_config = { + "edges": FAIL_BRANCH_EDGES, + "nodes": [ + {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"}, + { + "data": {"title": "success", "type": "answer", "answer": "http execute successful"}, + "id": "success", + }, + { + "data": {"title": "error", "type": "answer", "answer": "http execute failed"}, + "id": "error", + }, + ContinueOnErrorTestHelper.get_error_status_code_http_node(), + ], + } + + graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) + list(graph_engine.run()) + error_message = graph_engine.graph_runtime_state.variable_pool.get(["node", "error_message"]) + error_type = graph_engine.graph_runtime_state.variable_pool.get(["node", "error_type"]) + assert error_message != None + assert error_type.value == "HTTPResponseCodeError" + + +def test_no_node_in_fail_branch_continue_on_error(): + """Test HTTP node with fail-branch error strategy when no node is connected to the fail branch""" + graph_config = { + "edges": FAIL_BRANCH_EDGES[:-1], + "nodes": [ + {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"}, + { + "data": {"title": "success", "type": "answer", "answer": "HTTP request successful"}, + "id": "success", + }, + ContinueOnErrorTestHelper.get_http_node(), + ], + } + + graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) + events = list(graph_engine.run()) + + assert any(isinstance(e, NodeRunExceptionEvent) for e in events) + assert any(isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {} for e in events) + assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 0 diff --git
a/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py b/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py new file mode 100644 index 0000000000000000000000000000000000000000..1a550ec5309aa376fd031ccfa4049a807eb8985d --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py @@ -0,0 +1,178 @@ +from unittest.mock import Mock, patch + +import pytest + +from core.file import File, FileTransferMethod +from core.variables import ArrayFileSegment +from core.variables.variables import StringVariable +from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.nodes.document_extractor import DocumentExtractorNode, DocumentExtractorNodeData +from core.workflow.nodes.document_extractor.node import ( + _extract_text_from_doc, + _extract_text_from_pdf, + _extract_text_from_plain_text, +) +from core.workflow.nodes.enums import NodeType +from models.workflow import WorkflowNodeExecutionStatus + + +@pytest.fixture +def document_extractor_node(): + node_data = DocumentExtractorNodeData( + title="Test Document Extractor", + variable_selector=["node_id", "variable_name"], + ) + return DocumentExtractorNode( + id="test_node_id", + config={"id": "test_node_id", "data": node_data.model_dump()}, + graph_init_params=Mock(), + graph=Mock(), + graph_runtime_state=Mock(), + ) + + +@pytest.fixture +def mock_graph_runtime_state(): + return Mock() + + +def test_run_variable_not_found(document_extractor_node, mock_graph_runtime_state): + document_extractor_node.graph_runtime_state = mock_graph_runtime_state + mock_graph_runtime_state.variable_pool.get.return_value = None + + result = document_extractor_node._run() + + assert isinstance(result, NodeRunResult) + assert result.status == WorkflowNodeExecutionStatus.FAILED + assert result.error is not None + assert "File variable not found" in result.error + + +def test_run_invalid_variable_type(document_extractor_node, mock_graph_runtime_state): + 
document_extractor_node.graph_runtime_state = mock_graph_runtime_state + mock_graph_runtime_state.variable_pool.get.return_value = StringVariable( + value="Not an ArrayFileSegment", name="test" + ) + + result = document_extractor_node._run() + + assert isinstance(result, NodeRunResult) + assert result.status == WorkflowNodeExecutionStatus.FAILED + assert result.error is not None + assert "is not an ArrayFileSegment" in result.error + + +@pytest.mark.parametrize( + ("mime_type", "file_content", "expected_text", "transfer_method", "extension"), + [ + ("text/plain", b"Hello, world!", ["Hello, world!"], FileTransferMethod.LOCAL_FILE, ".txt"), + ( + "application/pdf", + b"%PDF-1.5\n%Test PDF content", + ["Mocked PDF content"], + FileTransferMethod.LOCAL_FILE, + ".pdf", + ), + ( + "application/vnd.openxmlformats-officedocument.wordprocessingml.document", + b"PK\x03\x04", + ["Mocked DOCX content"], + FileTransferMethod.REMOTE_URL, + "", + ), + ("text/plain", b"Remote content", ["Remote content"], FileTransferMethod.REMOTE_URL, None), + ], +) +def test_run_extract_text( + document_extractor_node, + mock_graph_runtime_state, + mime_type, + file_content, + expected_text, + transfer_method, + extension, + monkeypatch, +): + document_extractor_node.graph_runtime_state = mock_graph_runtime_state + + mock_file = Mock(spec=File) + mock_file.mime_type = mime_type + mock_file.transfer_method = transfer_method + mock_file.related_id = "test_file_id" if transfer_method == FileTransferMethod.LOCAL_FILE else None + mock_file.remote_url = "https://example.com/file.txt" if transfer_method == FileTransferMethod.REMOTE_URL else None + mock_file.extension = extension + + mock_array_file_segment = Mock(spec=ArrayFileSegment) + mock_array_file_segment.value = [mock_file] + + mock_graph_runtime_state.variable_pool.get.return_value = mock_array_file_segment + + mock_download = Mock(return_value=file_content) + mock_ssrf_proxy_get = Mock() + mock_ssrf_proxy_get.return_value.content = 
file_content + mock_ssrf_proxy_get.return_value.raise_for_status = Mock() + + monkeypatch.setattr("core.file.file_manager.download", mock_download) + monkeypatch.setattr("core.helper.ssrf_proxy.get", mock_ssrf_proxy_get) + + if mime_type == "application/pdf": + mock_pdf_extract = Mock(return_value=expected_text[0]) + monkeypatch.setattr("core.workflow.nodes.document_extractor.node._extract_text_from_pdf", mock_pdf_extract) + elif mime_type.startswith("application/vnd.openxmlformats"): + mock_docx_extract = Mock(return_value=expected_text[0]) + monkeypatch.setattr("core.workflow.nodes.document_extractor.node._extract_text_from_doc", mock_docx_extract) + + result = document_extractor_node._run() + + assert isinstance(result, NodeRunResult) + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED, result.error + assert result.outputs is not None + assert result.outputs["text"] == expected_text + + if transfer_method == FileTransferMethod.REMOTE_URL: + mock_ssrf_proxy_get.assert_called_once_with("https://example.com/file.txt") + elif transfer_method == FileTransferMethod.LOCAL_FILE: + mock_download.assert_called_once_with(mock_file) + + +def test_extract_text_from_plain_text(): + text = _extract_text_from_plain_text(b"Hello, world!") + assert text == "Hello, world!" + + +def test_extract_text_from_plain_text_non_utf8(): + import tempfile + + non_utf8_content = b"Hello, world\xa9." # \xA9 represents © in Latin-1 + with tempfile.NamedTemporaryFile(delete=True) as temp_file: + temp_file.write(non_utf8_content) + temp_file.seek(0) + text = _extract_text_from_plain_text(temp_file.read()) + assert text == "Hello, world." 
+ + +@patch("pypdfium2.PdfDocument") +def test_extract_text_from_pdf(mock_pdf_document): + mock_page = Mock() + mock_text_page = Mock() + mock_text_page.get_text_range.return_value = "PDF content" + mock_page.get_textpage.return_value = mock_text_page + mock_pdf_document.return_value = [mock_page] + text = _extract_text_from_pdf(b"%PDF-1.5\n%Test PDF content") + assert text == "PDF content" + + +@patch("docx.Document") +def test_extract_text_from_doc(mock_document): + mock_paragraph1 = Mock() + mock_paragraph1.text = "Paragraph 1" + mock_paragraph2 = Mock() + mock_paragraph2.text = "Paragraph 2" + mock_document.return_value.paragraphs = [mock_paragraph1, mock_paragraph2] + + text = _extract_text_from_doc(b"PK\x03\x04") + assert text == "Paragraph 1\nParagraph 2" + + +def test_node_type(document_extractor_node): + assert document_extractor_node._node_type == NodeType.DOCUMENT_EXTRACTOR diff --git a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py new file mode 100644 index 0000000000000000000000000000000000000000..41e2c5d48468f6721a588b129ffad807023fd84e --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py @@ -0,0 +1,260 @@ +import time +import uuid +from unittest.mock import MagicMock, Mock + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.file import File, FileTransferMethod, FileType +from core.variables import ArrayFileSegment +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.enums import SystemVariableKey +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams +from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState +from core.workflow.nodes.if_else.entities import IfElseNodeData +from core.workflow.nodes.if_else.if_else_node import IfElseNode +from core.workflow.utils.condition.entities import 
Condition, SubCondition, SubVariableCondition +from extensions.ext_database import db +from models.enums import UserFrom +from models.workflow import WorkflowNodeExecutionStatus, WorkflowType + + +def test_execute_if_else_result_true(): + graph_config = {"edges": [], "nodes": [{"data": {"type": "start"}, "id": "start"}]} + + graph = Graph.init(graph_config=graph_config) + + init_params = GraphInitParams( + tenant_id="1", + app_id="1", + workflow_type=WorkflowType.WORKFLOW, + workflow_id="1", + graph_config=graph_config, + user_id="1", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ) + + # construct variable pool + pool = VariablePool( + system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, user_inputs={} + ) + pool.add(["start", "array_contains"], ["ab", "def"]) + pool.add(["start", "array_not_contains"], ["ac", "def"]) + pool.add(["start", "contains"], "cabcde") + pool.add(["start", "not_contains"], "zacde") + pool.add(["start", "start_with"], "abc") + pool.add(["start", "end_with"], "zzab") + pool.add(["start", "is"], "ab") + pool.add(["start", "is_not"], "aab") + pool.add(["start", "empty"], "") + pool.add(["start", "not_empty"], "aaa") + pool.add(["start", "equals"], 22) + pool.add(["start", "not_equals"], 23) + pool.add(["start", "greater_than"], 23) + pool.add(["start", "less_than"], 21) + pool.add(["start", "greater_than_or_equal"], 22) + pool.add(["start", "less_than_or_equal"], 21) + pool.add(["start", "null"], None) + pool.add(["start", "not_null"], "1212") + + node = IfElseNode( + id=str(uuid.uuid4()), + graph_init_params=init_params, + graph=graph, + graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()), + config={ + "id": "if-else", + "data": { + "title": "123", + "type": "if-else", + "logical_operator": "and", + "conditions": [ + { + "comparison_operator": "contains", + "variable_selector": ["start", "array_contains"], + "value": "ab", + }, + { + 
"comparison_operator": "not contains", + "variable_selector": ["start", "array_not_contains"], + "value": "ab", + }, + {"comparison_operator": "contains", "variable_selector": ["start", "contains"], "value": "ab"}, + { + "comparison_operator": "not contains", + "variable_selector": ["start", "not_contains"], + "value": "ab", + }, + {"comparison_operator": "start with", "variable_selector": ["start", "start_with"], "value": "ab"}, + {"comparison_operator": "end with", "variable_selector": ["start", "end_with"], "value": "ab"}, + {"comparison_operator": "is", "variable_selector": ["start", "is"], "value": "ab"}, + {"comparison_operator": "is not", "variable_selector": ["start", "is_not"], "value": "ab"}, + {"comparison_operator": "empty", "variable_selector": ["start", "empty"], "value": "ab"}, + {"comparison_operator": "not empty", "variable_selector": ["start", "not_empty"], "value": "ab"}, + {"comparison_operator": "=", "variable_selector": ["start", "equals"], "value": "22"}, + {"comparison_operator": "≠", "variable_selector": ["start", "not_equals"], "value": "22"}, + {"comparison_operator": ">", "variable_selector": ["start", "greater_than"], "value": "22"}, + {"comparison_operator": "<", "variable_selector": ["start", "less_than"], "value": "22"}, + { + "comparison_operator": "≥", + "variable_selector": ["start", "greater_than_or_equal"], + "value": "22", + }, + {"comparison_operator": "≤", "variable_selector": ["start", "less_than_or_equal"], "value": "22"}, + {"comparison_operator": "null", "variable_selector": ["start", "null"]}, + {"comparison_operator": "not null", "variable_selector": ["start", "not_null"]}, + ], + }, + }, + ) + + # Mock db.session.close() + db.session.close = MagicMock() + + # execute node + result = node._run() + + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert result.outputs is not None + assert result.outputs["result"] is True + + +def test_execute_if_else_result_false(): + graph_config = { + "edges": [ + { + 
"id": "start-source-llm-target", + "source": "start", + "target": "llm", + }, + ], + "nodes": [ + {"data": {"type": "start"}, "id": "start"}, + { + "data": { + "type": "llm", + }, + "id": "llm", + }, + ], + } + + graph = Graph.init(graph_config=graph_config) + + init_params = GraphInitParams( + tenant_id="1", + app_id="1", + workflow_type=WorkflowType.WORKFLOW, + workflow_id="1", + graph_config=graph_config, + user_id="1", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ) + + # construct variable pool + pool = VariablePool( + system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, + user_inputs={}, + environment_variables=[], + ) + pool.add(["start", "array_contains"], ["1ab", "def"]) + pool.add(["start", "array_not_contains"], ["ab", "def"]) + + node = IfElseNode( + id=str(uuid.uuid4()), + graph_init_params=init_params, + graph=graph, + graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()), + config={ + "id": "if-else", + "data": { + "title": "123", + "type": "if-else", + "logical_operator": "or", + "conditions": [ + { + "comparison_operator": "contains", + "variable_selector": ["start", "array_contains"], + "value": "ab", + }, + { + "comparison_operator": "not contains", + "variable_selector": ["start", "array_not_contains"], + "value": "ab", + }, + ], + }, + }, + ) + + # Mock db.session.close() + db.session.close = MagicMock() + + # execute node + result = node._run() + + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert result.outputs is not None + assert result.outputs["result"] is False + + +def test_array_file_contains_file_name(): + node_data = IfElseNodeData( + title="123", + logical_operator="and", + cases=[ + IfElseNodeData.Case( + case_id="true", + logical_operator="and", + conditions=[ + Condition( + comparison_operator="contains", + variable_selector=["start", "array_contains"], + sub_variable_condition=SubVariableCondition( + 
logical_operator="and", + conditions=[ + SubCondition( + key="name", + comparison_operator="contains", + value="ab", + ) + ], + ), + ) + ], + ) + ], + ) + + node = IfElseNode( + id=str(uuid.uuid4()), + graph_init_params=Mock(), + graph=Mock(), + graph_runtime_state=Mock(), + config={ + "id": "if-else", + "data": node_data.model_dump(), + }, + ) + + node.graph_runtime_state.variable_pool.get.return_value = ArrayFileSegment( + value=[ + File( + tenant_id="1", + type=FileType.IMAGE, + transfer_method=FileTransferMethod.LOCAL_FILE, + related_id="1", + filename="ab", + storage_key="", + ), + ], + ) + + result = node._run() + + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert result.outputs is not None + assert result.outputs["result"] is True diff --git a/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py b/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py new file mode 100644 index 0000000000000000000000000000000000000000..36116d35404cf50a3426a9ea8061d499180e336c --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py @@ -0,0 +1,168 @@ +from unittest.mock import MagicMock + +import pytest + +from core.file import File, FileTransferMethod, FileType +from core.variables import ArrayFileSegment +from core.workflow.nodes.list_operator.entities import ( + ExtractConfig, + FilterBy, + FilterCondition, + Limit, + ListOperatorNodeData, + OrderBy, +) +from core.workflow.nodes.list_operator.exc import InvalidKeyError +from core.workflow.nodes.list_operator.node import ListOperatorNode, _get_file_extract_string_func +from models.workflow import WorkflowNodeExecutionStatus + + +@pytest.fixture +def list_operator_node(): + config = { + "variable": ["test_variable"], + "filter_by": FilterBy( + enabled=True, + conditions=[ + FilterCondition(key="type", comparison_operator="in", value=[FileType.IMAGE, FileType.DOCUMENT]) + ], + ), + "order_by": OrderBy(enabled=False, value="asc"), + "limit": Limit(enabled=False, 
size=0), + "extract_by": ExtractConfig(enabled=False, serial="1"), + "title": "Test Title", + } + node_data = ListOperatorNodeData(**config) + node = ListOperatorNode( + id="test_node_id", + config={ + "id": "test_node_id", + "data": node_data.model_dump(), + }, + graph_init_params=MagicMock(), + graph=MagicMock(), + graph_runtime_state=MagicMock(), + ) + node.graph_runtime_state = MagicMock() + node.graph_runtime_state.variable_pool = MagicMock() + return node + + +def test_filter_files_by_type(list_operator_node): + # Setup test data + files = [ + File( + filename="image1.jpg", + type=FileType.IMAGE, + tenant_id="tenant1", + transfer_method=FileTransferMethod.LOCAL_FILE, + related_id="related1", + storage_key="", + ), + File( + filename="document1.pdf", + type=FileType.DOCUMENT, + tenant_id="tenant1", + transfer_method=FileTransferMethod.LOCAL_FILE, + related_id="related2", + storage_key="", + ), + File( + filename="image2.png", + type=FileType.IMAGE, + tenant_id="tenant1", + transfer_method=FileTransferMethod.LOCAL_FILE, + related_id="related3", + storage_key="", + ), + File( + filename="audio1.mp3", + type=FileType.AUDIO, + tenant_id="tenant1", + transfer_method=FileTransferMethod.LOCAL_FILE, + related_id="related4", + storage_key="", + ), + ] + variable = ArrayFileSegment(value=files) + list_operator_node.graph_runtime_state.variable_pool.get.return_value = variable + + # Run the node + result = list_operator_node._run() + + # Verify the result + expected_files = [ + { + "filename": "image1.jpg", + "type": FileType.IMAGE, + "tenant_id": "tenant1", + "transfer_method": FileTransferMethod.LOCAL_FILE, + "related_id": "related1", + }, + { + "filename": "document1.pdf", + "type": FileType.DOCUMENT, + "tenant_id": "tenant1", + "transfer_method": FileTransferMethod.LOCAL_FILE, + "related_id": "related2", + }, + { + "filename": "image2.png", + "type": FileType.IMAGE, + "tenant_id": "tenant1", + "transfer_method": FileTransferMethod.LOCAL_FILE, + "related_id": 
"related3", + }, + ] + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + for expected_file, result_file in zip(expected_files, result.outputs["result"]): + assert expected_file["filename"] == result_file.filename + assert expected_file["type"] == result_file.type + assert expected_file["tenant_id"] == result_file.tenant_id + assert expected_file["transfer_method"] == result_file.transfer_method + assert expected_file["related_id"] == result_file.related_id + + +def test_get_file_extract_string_func(): + # Create a File object + file = File( + tenant_id="test_tenant", + type=FileType.DOCUMENT, + transfer_method=FileTransferMethod.LOCAL_FILE, + filename="test_file.txt", + extension=".txt", + mime_type="text/plain", + remote_url="https://example.com/test_file.txt", + related_id="test_related_id", + storage_key="", + ) + + # Test each case + assert _get_file_extract_string_func(key="name")(file) == "test_file.txt" + assert _get_file_extract_string_func(key="type")(file) == "document" + assert _get_file_extract_string_func(key="extension")(file) == ".txt" + assert _get_file_extract_string_func(key="mime_type")(file) == "text/plain" + assert _get_file_extract_string_func(key="transfer_method")(file) == "local_file" + assert _get_file_extract_string_func(key="url")(file) == "https://example.com/test_file.txt" + + # Test with empty values + empty_file = File( + tenant_id="test_tenant", + type=FileType.DOCUMENT, + transfer_method=FileTransferMethod.LOCAL_FILE, + filename=None, + extension=None, + mime_type=None, + remote_url=None, + related_id="test_related_id", + storage_key="", + ) + + assert _get_file_extract_string_func(key="name")(empty_file) == "" + assert _get_file_extract_string_func(key="extension")(empty_file) == "" + assert _get_file_extract_string_func(key="mime_type")(empty_file) == "" + assert _get_file_extract_string_func(key="url")(empty_file) == "" + + # Test invalid key + with pytest.raises(InvalidKeyError): + 
_get_file_extract_string_func(key="invalid_key") diff --git a/api/tests/unit_tests/core/workflow/nodes/test_question_classifier_node.py b/api/tests/unit_tests/core/workflow/nodes/test_question_classifier_node.py new file mode 100644 index 0000000000000000000000000000000000000000..f990280c5f195123d0b57f6c9208a9e073922d6c --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/test_question_classifier_node.py @@ -0,0 +1,67 @@ +from core.model_runtime.entities import ImagePromptMessageContent +from core.workflow.nodes.question_classifier import QuestionClassifierNodeData + + +def test_init_question_classifier_node_data(): + data = { + "title": "test classifier node", + "query_variable_selector": ["id", "name"], + "model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "completion", "completion_params": {}}, + "classes": [{"id": "1", "name": "class 1"}], + "instruction": "This is a test instruction", + "memory": { + "role_prefix": {"user": "Human:", "assistant": "AI:"}, + "window": {"enabled": True, "size": 5}, + "query_prompt_template": "Previous conversation:\n{history}\n\nHuman: {query}\nAI:", + }, + "vision": {"enabled": True, "configs": {"variable_selector": ["image"], "detail": "low"}}, + } + + node_data = QuestionClassifierNodeData(**data) + + assert node_data.query_variable_selector == ["id", "name"] + assert node_data.model.provider == "openai" + assert node_data.classes[0].id == "1" + assert node_data.instruction == "This is a test instruction" + assert node_data.memory is not None + assert node_data.memory.role_prefix is not None + assert node_data.memory.role_prefix.user == "Human:" + assert node_data.memory.role_prefix.assistant == "AI:" + assert node_data.memory.window.enabled == True + assert node_data.memory.window.size == 5 + assert node_data.memory.query_prompt_template == "Previous conversation:\n{history}\n\nHuman: {query}\nAI:" + assert node_data.vision.enabled == True + assert node_data.vision.configs.variable_selector == ["image"] + 
assert node_data.vision.configs.detail == ImagePromptMessageContent.DETAIL.LOW + + +def test_init_question_classifier_node_data_without_vision_config(): + data = { + "title": "test classifier node", + "query_variable_selector": ["id", "name"], + "model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "completion", "completion_params": {}}, + "classes": [{"id": "1", "name": "class 1"}], + "instruction": "This is a test instruction", + "memory": { + "role_prefix": {"user": "Human:", "assistant": "AI:"}, + "window": {"enabled": True, "size": 5}, + "query_prompt_template": "Previous conversation:\n{history}\n\nHuman: {query}\nAI:", + }, + } + + node_data = QuestionClassifierNodeData(**data) + + assert node_data.query_variable_selector == ["id", "name"] + assert node_data.model.provider == "openai" + assert node_data.classes[0].id == "1" + assert node_data.instruction == "This is a test instruction" + assert node_data.memory is not None + assert node_data.memory.role_prefix is not None + assert node_data.memory.role_prefix.user == "Human:" + assert node_data.memory.role_prefix.assistant == "AI:" + assert node_data.memory.window.enabled == True + assert node_data.memory.window.size == 5 + assert node_data.memory.query_prompt_template == "Previous conversation:\n{history}\n\nHuman: {query}\nAI:" + assert node_data.vision.enabled == False + assert node_data.vision.configs.variable_selector == ["sys", "files"] + assert node_data.vision.configs.detail == ImagePromptMessageContent.DETAIL.HIGH diff --git a/api/tests/unit_tests/core/workflow/nodes/test_retry.py b/api/tests/unit_tests/core/workflow/nodes/test_retry.py new file mode 100644 index 0000000000000000000000000000000000000000..4ac79d7acd5825a93e283b967cb76db140cf3a9c --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/test_retry.py @@ -0,0 +1,72 @@ +from core.workflow.graph_engine.entities.event import ( + GraphRunFailedEvent, + GraphRunPartialSucceededEvent, + NodeRunRetryEvent, +) +from 
tests.unit_tests.core.workflow.nodes.test_continue_on_error import ContinueOnErrorTestHelper + +DEFAULT_VALUE_EDGE = [ + { + "id": "start-source-node-target", + "source": "start", + "target": "node", + "sourceHandle": "source", + }, + { + "id": "node-source-answer-target", + "source": "node", + "target": "answer", + "sourceHandle": "source", + }, +] + + +def test_retry_default_value_partial_success(): + """retry default value node with partial success status""" + graph_config = { + "edges": DEFAULT_VALUE_EDGE, + "nodes": [ + {"data": {"title": "start", "type": "start", "variables": []}, "id": "start"}, + {"data": {"title": "answer", "type": "answer", "answer": "{{#node.result#}}"}, "id": "answer"}, + ContinueOnErrorTestHelper.get_http_node( + "default-value", + [{"key": "result", "type": "string", "value": "http node got error response"}], + retry_config={"retry_config": {"max_retries": 2, "retry_interval": 1000, "retry_enabled": True}}, + ), + ], + } + + graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) + events = list(graph_engine.run()) + assert sum(1 for e in events if isinstance(e, NodeRunRetryEvent)) == 2 + assert events[-1].outputs == {"answer": "http node got error response"} + assert any(isinstance(e, GraphRunPartialSucceededEvent) for e in events) + assert len(events) == 11 + + +def test_retry_failed(): + """retry failed with success status""" + error_code = """ + def main() -> dict: + return { + "result": 1 / 0, + } + """ + + graph_config = { + "edges": DEFAULT_VALUE_EDGE, + "nodes": [ + {"data": {"title": "start", "type": "start", "variables": []}, "id": "start"}, + {"data": {"title": "answer", "type": "answer", "answer": "{{#node.result#}}"}, "id": "answer"}, + ContinueOnErrorTestHelper.get_http_node( + None, + None, + retry_config={"retry_config": {"max_retries": 2, "retry_interval": 1000, "retry_enabled": True}}, + ), + ], + } + graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) + events = 
list(graph_engine.run()) + assert sum(1 for e in events if isinstance(e, NodeRunRetryEvent)) == 2 + assert any(isinstance(e, GraphRunFailedEvent) for e in events) + assert len(events) == 8 diff --git a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py new file mode 100644 index 0000000000000000000000000000000000000000..9793da129d714e078c1290a967c3398b8a8c6ee7 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py @@ -0,0 +1,257 @@ +import time +import uuid +from unittest import mock +from uuid import uuid4 + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.variables import ArrayStringVariable, StringVariable +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.enums import SystemVariableKey +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams +from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState +from core.workflow.nodes.variable_assigner.v1 import VariableAssignerNode +from core.workflow.nodes.variable_assigner.v1.node_data import WriteMode +from models.enums import UserFrom +from models.workflow import WorkflowType + +DEFAULT_NODE_ID = "node_id" + + +def test_overwrite_string_variable(): + graph_config = { + "edges": [ + { + "id": "start-source-assigner-target", + "source": "start", + "target": "assigner", + }, + ], + "nodes": [ + {"data": {"type": "start"}, "id": "start"}, + { + "data": { + "type": "assigner", + }, + "id": "assigner", + }, + ], + } + + graph = Graph.init(graph_config=graph_config) + + init_params = GraphInitParams( + tenant_id="1", + app_id="1", + workflow_type=WorkflowType.WORKFLOW, + workflow_id="1", + graph_config=graph_config, + user_id="1", + user_from=UserFrom.ACCOUNT, + 
invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ) + + conversation_variable = StringVariable( + id=str(uuid4()), + name="test_conversation_variable", + value="the first value", + ) + + input_variable = StringVariable( + id=str(uuid4()), + name="test_string_variable", + value="the second value", + ) + + # construct variable pool + variable_pool = VariablePool( + system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"}, + user_inputs={}, + environment_variables=[], + conversation_variables=[conversation_variable], + ) + + variable_pool.add( + [DEFAULT_NODE_ID, input_variable.name], + input_variable, + ) + + node = VariableAssignerNode( + id=str(uuid.uuid4()), + graph_init_params=init_params, + graph=graph, + graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), + config={ + "id": "node_id", + "data": { + "title": "test", + "assigned_variable_selector": ["conversation", conversation_variable.name], + "write_mode": WriteMode.OVER_WRITE.value, + "input_variable_selector": [DEFAULT_NODE_ID, input_variable.name], + }, + }, + ) + + with mock.patch("core.workflow.nodes.variable_assigner.common.helpers.update_conversation_variable") as mock_run: + list(node.run()) + mock_run.assert_called_once() + + got = variable_pool.get(["conversation", conversation_variable.name]) + assert got is not None + assert got.value == "the second value" + assert got.to_object() == "the second value" + + +def test_append_variable_to_array(): + graph_config = { + "edges": [ + { + "id": "start-source-assigner-target", + "source": "start", + "target": "assigner", + }, + ], + "nodes": [ + {"data": {"type": "start"}, "id": "start"}, + { + "data": { + "type": "assigner", + }, + "id": "assigner", + }, + ], + } + + graph = Graph.init(graph_config=graph_config) + + init_params = GraphInitParams( + tenant_id="1", + app_id="1", + workflow_type=WorkflowType.WORKFLOW, + workflow_id="1", + graph_config=graph_config, + user_id="1", + 
user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ) + + conversation_variable = ArrayStringVariable( + id=str(uuid4()), + name="test_conversation_variable", + value=["the first value"], + ) + + input_variable = StringVariable( + id=str(uuid4()), + name="test_string_variable", + value="the second value", + ) + + variable_pool = VariablePool( + system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"}, + user_inputs={}, + environment_variables=[], + conversation_variables=[conversation_variable], + ) + variable_pool.add( + [DEFAULT_NODE_ID, input_variable.name], + input_variable, + ) + + node = VariableAssignerNode( + id=str(uuid.uuid4()), + graph_init_params=init_params, + graph=graph, + graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), + config={ + "id": "node_id", + "data": { + "title": "test", + "assigned_variable_selector": ["conversation", conversation_variable.name], + "write_mode": WriteMode.APPEND.value, + "input_variable_selector": [DEFAULT_NODE_ID, input_variable.name], + }, + }, + ) + + with mock.patch("core.workflow.nodes.variable_assigner.common.helpers.update_conversation_variable") as mock_run: + list(node.run()) + mock_run.assert_called_once() + + got = variable_pool.get(["conversation", conversation_variable.name]) + assert got is not None + assert got.to_object() == ["the first value", "the second value"] + + +def test_clear_array(): + graph_config = { + "edges": [ + { + "id": "start-source-assigner-target", + "source": "start", + "target": "assigner", + }, + ], + "nodes": [ + {"data": {"type": "start"}, "id": "start"}, + { + "data": { + "type": "assigner", + }, + "id": "assigner", + }, + ], + } + + graph = Graph.init(graph_config=graph_config) + + init_params = GraphInitParams( + tenant_id="1", + app_id="1", + workflow_type=WorkflowType.WORKFLOW, + workflow_id="1", + graph_config=graph_config, + user_id="1", + user_from=UserFrom.ACCOUNT, + 
invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ) + + conversation_variable = ArrayStringVariable( + id=str(uuid4()), + name="test_conversation_variable", + value=["the first value"], + ) + + variable_pool = VariablePool( + system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"}, + user_inputs={}, + environment_variables=[], + conversation_variables=[conversation_variable], + ) + + node = VariableAssignerNode( + id=str(uuid.uuid4()), + graph_init_params=init_params, + graph=graph, + graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), + config={ + "id": "node_id", + "data": { + "title": "test", + "assigned_variable_selector": ["conversation", conversation_variable.name], + "write_mode": WriteMode.CLEAR.value, + "input_variable_selector": [], + }, + }, + ) + + with mock.patch("core.workflow.nodes.variable_assigner.common.helpers.update_conversation_variable") as mock_run: + list(node.run()) + mock_run.assert_called_once() + + got = variable_pool.get(["conversation", conversation_variable.name]) + assert got is not None + assert got.to_object() == [] diff --git a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_helpers.py b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..1501722b82a52eb9934d77e59d20209f9b2b749f --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_helpers.py @@ -0,0 +1,22 @@ +from core.variables import SegmentType +from core.workflow.nodes.variable_assigner.v2.enums import Operation +from core.workflow.nodes.variable_assigner.v2.helpers import is_input_value_valid + + +def test_is_input_value_valid_overwrite_array_string(): + # Valid cases + assert is_input_value_valid( + variable_type=SegmentType.ARRAY_STRING, operation=Operation.OVER_WRITE, value=["hello", "world"] + ) + assert is_input_value_valid(variable_type=SegmentType.ARRAY_STRING, 
operation=Operation.OVER_WRITE, value=[]) + + # Invalid cases + assert not is_input_value_valid( + variable_type=SegmentType.ARRAY_STRING, operation=Operation.OVER_WRITE, value="not an array" + ) + assert not is_input_value_valid( + variable_type=SegmentType.ARRAY_STRING, operation=Operation.OVER_WRITE, value=[1, 2, 3] + ) + assert not is_input_value_valid( + variable_type=SegmentType.ARRAY_STRING, operation=Operation.OVER_WRITE, value=["valid", 123, "invalid"] + ) diff --git a/api/tests/unit_tests/core/workflow/test_variable_pool.py b/api/tests/unit_tests/core/workflow/test_variable_pool.py new file mode 100644 index 0000000000000000000000000000000000000000..efbcdc760c69953e75e5e6416b8808d13b825db9 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/test_variable_pool.py @@ -0,0 +1,46 @@ +import pytest + +from core.file import File, FileTransferMethod, FileType +from core.variables import FileSegment, StringSegment +from core.workflow.entities.variable_pool import VariablePool + + +@pytest.fixture +def pool(): + return VariablePool(system_variables={}, user_inputs={}) + + +@pytest.fixture +def file(): + return File( + tenant_id="test_tenant_id", + type=FileType.DOCUMENT, + transfer_method=FileTransferMethod.LOCAL_FILE, + related_id="test_related_id", + remote_url="test_url", + filename="test_file.txt", + storage_key="", + ) + + +def test_get_file_attribute(pool, file): + # Add a FileSegment to the pool + pool.add(("node_1", "file_var"), FileSegment(value=file)) + + # Test getting the 'name' attribute of the file + result = pool.get(("node_1", "file_var", "name")) + + assert result is not None + assert result.value == file.filename + + # Test getting a non-existent attribute + result = pool.get(("node_1", "file_var", "non_existent_attr")) + assert result is None + + +def test_use_long_selector(pool): + pool.add(("node_1", "part_1", "part_2"), StringSegment(value="test_value")) + + result = pool.get(("node_1", "part_1", "part_2")) + assert result is not None + 
assert result.value == "test_value" diff --git a/api/tests/unit_tests/core/workflow/utils/test_variable_template_parser.py b/api/tests/unit_tests/core/workflow/utils/test_variable_template_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..2f90afcf8908da5d854733cca783a7ef0b4fd557 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/utils/test_variable_template_parser.py @@ -0,0 +1,28 @@ +from core.variables import SecretVariable +from core.workflow.entities.variable_entities import VariableSelector +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.enums import SystemVariableKey +from core.workflow.utils import variable_template_parser + + +def test_extract_selectors_from_template(): + variable_pool = VariablePool( + system_variables={ + SystemVariableKey("user_id"): "fake-user-id", + }, + user_inputs={}, + environment_variables=[ + SecretVariable(name="secret_key", value="fake-secret-key"), + ], + conversation_variables=[], + ) + variable_pool.add(("node_id", "custom_query"), "fake-user-query") + template = ( + "Hello, {{#sys.user_id#}}! Your query is {{#node_id.custom_query#}}. And your key is {{#env.secret_key#}}." 
+ ) + selectors = variable_template_parser.extract_selectors_from_template(template) + assert selectors == [ + VariableSelector(variable="#sys.user_id#", value_selector=["sys", "user_id"]), + VariableSelector(variable="#node_id.custom_query#", value_selector=["node_id", "custom_query"]), + VariableSelector(variable="#env.secret_key#", value_selector=["env", "secret_key"]), + ] diff --git a/api/tests/unit_tests/libs/test_email.py b/api/tests/unit_tests/libs/test_email.py new file mode 100644 index 0000000000000000000000000000000000000000..ae0177791b5986ce489f20169f1e35193c7f87a8 --- /dev/null +++ b/api/tests/unit_tests/libs/test_email.py @@ -0,0 +1,21 @@ +import pytest + +from libs.helper import email + + +def test_email_with_valid_email(): + assert email("test@example.com") == "test@example.com" + assert email("TEST12345@example.com") == "TEST12345@example.com" + assert email("test+test@example.com") == "test+test@example.com" + assert email("!#$%&'*+-/=?^_{|}~`@example.com") == "!#$%&'*+-/=?^_{|}~`@example.com" + + +def test_email_with_invalid_email(): + with pytest.raises(ValueError, match="invalid_email is not a valid email."): + email("invalid_email") + + with pytest.raises(ValueError, match="@example.com is not a valid email."): + email("@example.com") + + with pytest.raises(ValueError, match="()@example.com is not a valid email."): + email("()@example.com") diff --git a/api/tests/unit_tests/libs/test_pandas.py b/api/tests/unit_tests/libs/test_pandas.py new file mode 100644 index 0000000000000000000000000000000000000000..21c2f0781d85f9e9ee550741ac2ee382dabd7531 --- /dev/null +++ b/api/tests/unit_tests/libs/test_pandas.py @@ -0,0 +1,58 @@ +import pandas as pd + + +def test_pandas_csv(tmp_path, monkeypatch): + monkeypatch.chdir(tmp_path) + data = {"col1": [1, 2.2, -3.3, 4.0, 5], "col2": ["A", "B", "C", "D", "E"]} + df1 = pd.DataFrame(data) + + # write to csv file + csv_file_path = tmp_path.joinpath("example.csv") + df1.to_csv(csv_file_path, index=False) + + # 
read from csv file + df2 = pd.read_csv(csv_file_path, on_bad_lines="skip") + assert df2[df2.columns[0]].to_list() == data["col1"] + assert df2[df2.columns[1]].to_list() == data["col2"] + + +def test_pandas_xlsx(tmp_path, monkeypatch): + monkeypatch.chdir(tmp_path) + data = {"col1": [1, 2.2, -3.3, 4.0, 5], "col2": ["A", "B", "C", "D", "E"]} + df1 = pd.DataFrame(data) + + # write to xlsx file + xlsx_file_path = tmp_path.joinpath("example.xlsx") + df1.to_excel(xlsx_file_path, index=False) + + # read from xlsx file + df2 = pd.read_excel(xlsx_file_path) + assert df2[df2.columns[0]].to_list() == data["col1"] + assert df2[df2.columns[1]].to_list() == data["col2"] + + +def test_pandas_xlsx_with_sheets(tmp_path, monkeypatch): + monkeypatch.chdir(tmp_path) + data1 = {"col1": [1, 2, 3, 4, 5], "col2": ["A", "B", "C", "D", "E"]} + df1 = pd.DataFrame(data1) + + data2 = {"col1": [6, 7, 8, 9, 10], "col2": ["F", "G", "H", "I", "J"]} + df2 = pd.DataFrame(data2) + + # write to xlsx file with sheets + xlsx_file_path = tmp_path.joinpath("example_with_sheets.xlsx") + sheet1 = "Sheet1" + sheet2 = "Sheet2" + with pd.ExcelWriter(xlsx_file_path) as excel_writer: + df1.to_excel(excel_writer, sheet_name=sheet1, index=False) + df2.to_excel(excel_writer, sheet_name=sheet2, index=False) + + # read from xlsx file with sheets + with pd.ExcelFile(xlsx_file_path) as excel_file: + df1 = pd.read_excel(excel_file, sheet_name=sheet1) + assert df1[df1.columns[0]].to_list() == data1["col1"] + assert df1[df1.columns[1]].to_list() == data1["col2"] + + df2 = pd.read_excel(excel_file, sheet_name=sheet2) + assert df2[df2.columns[0]].to_list() == data2["col1"] + assert df2[df2.columns[1]].to_list() == data2["col2"] diff --git a/api/tests/unit_tests/libs/test_rsa.py b/api/tests/unit_tests/libs/test_rsa.py new file mode 100644 index 0000000000000000000000000000000000000000..2dc51252f00e72a4c57e9a9c0d02e59c403edc01 --- /dev/null +++ b/api/tests/unit_tests/libs/test_rsa.py @@ -0,0 +1,29 @@ +import rsa as pyrsa 
+from Crypto.PublicKey import RSA + +from libs import gmpy2_pkcs10aep_cipher + + +def test_gmpy2_pkcs10aep_cipher() -> None: + rsa_key_pair = pyrsa.newkeys(2048) + public_key = rsa_key_pair[0].save_pkcs1() + private_key = rsa_key_pair[1].save_pkcs1() + + public_rsa_key = RSA.import_key(public_key) + public_cipher_rsa2 = gmpy2_pkcs10aep_cipher.new(public_rsa_key) + + private_rsa_key = RSA.import_key(private_key) + private_cipher_rsa = gmpy2_pkcs10aep_cipher.new(private_rsa_key) + + raw_text = "raw_text" + raw_text_bytes = raw_text.encode() + + # RSA encryption by public key and decryption by private key + encrypted_by_pub_key = public_cipher_rsa2.encrypt(message=raw_text_bytes) + decrypted_by_pub_key = private_cipher_rsa.decrypt(encrypted_by_pub_key) + assert decrypted_by_pub_key == raw_text_bytes + + # RSA encryption and decryption by private key + encrypted_by_private_key = private_cipher_rsa.encrypt(message=raw_text_bytes) + decrypted_by_private_key = private_cipher_rsa.decrypt(encrypted_by_private_key) + assert decrypted_by_private_key == raw_text_bytes diff --git a/api/tests/unit_tests/libs/test_yarl.py b/api/tests/unit_tests/libs/test_yarl.py new file mode 100644 index 0000000000000000000000000000000000000000..b9aee4af5f31c7b60d2a0ec85eedda09d82b371d --- /dev/null +++ b/api/tests/unit_tests/libs/test_yarl.py @@ -0,0 +1,23 @@ +import pytest +from yarl import URL + + +def test_yarl_urls(): + expected_1 = "https://dify.ai/api" + assert str(URL("https://dify.ai") / "api") == expected_1 + assert str(URL("https://dify.ai/") / "api") == expected_1 + + expected_2 = "http://dify.ai:12345/api" + assert str(URL("http://dify.ai:12345") / "api") == expected_2 + assert str(URL("http://dify.ai:12345/") / "api") == expected_2 + + expected_3 = "https://dify.ai/api/v1" + assert str(URL("https://dify.ai") / "api" / "v1") == expected_3 + assert str(URL("https://dify.ai") / "api/v1") == expected_3 + assert str(URL("https://dify.ai/") / "api/v1") == expected_3 + assert 
str(URL("https://dify.ai/api") / "v1") == expected_3 + assert str(URL("https://dify.ai/api/") / "v1") == expected_3 + + with pytest.raises(ValueError) as e1: + str(URL("https://dify.ai") / "/api") + assert str(e1.value) == "Appending path '/api' starting from slash is forbidden" diff --git a/api/tests/unit_tests/models/test_account.py b/api/tests/unit_tests/models/test_account.py new file mode 100644 index 0000000000000000000000000000000000000000..026912ffbed300347b256cbbe9caebdf6f8233d3 --- /dev/null +++ b/api/tests/unit_tests/models/test_account.py @@ -0,0 +1,14 @@ +from models.account import TenantAccountRole + + +def test_account_is_privileged_role() -> None: + assert TenantAccountRole.ADMIN == "admin" + assert TenantAccountRole.OWNER == "owner" + assert TenantAccountRole.EDITOR == "editor" + assert TenantAccountRole.NORMAL == "normal" + + assert TenantAccountRole.is_privileged_role(TenantAccountRole.ADMIN) + assert TenantAccountRole.is_privileged_role(TenantAccountRole.OWNER) + assert not TenantAccountRole.is_privileged_role(TenantAccountRole.NORMAL) + assert not TenantAccountRole.is_privileged_role(TenantAccountRole.EDITOR) + assert not TenantAccountRole.is_privileged_role("") diff --git a/api/tests/unit_tests/models/test_conversation_variable.py b/api/tests/unit_tests/models/test_conversation_variable.py new file mode 100644 index 0000000000000000000000000000000000000000..5d84a2ec85d39f544bd6ae32494c8857c09f1ca5 --- /dev/null +++ b/api/tests/unit_tests/models/test_conversation_variable.py @@ -0,0 +1,26 @@ +from uuid import uuid4 + +from core.variables import SegmentType +from factories import variable_factory +from models import ConversationVariable + + +def test_from_variable_and_to_variable(): + variable = variable_factory.build_conversation_variable_from_mapping( + { + "id": str(uuid4()), + "name": "name", + "value_type": SegmentType.OBJECT, + "value": { + "key": { + "key": "value", + } + }, + } + ) + + conversation_variable = 
ConversationVariable.from_variable( + app_id="app_id", conversation_id="conversation_id", variable=variable + ) + + assert conversation_variable.to_variable() == variable diff --git a/api/tests/unit_tests/models/test_workflow.py b/api/tests/unit_tests/models/test_workflow.py new file mode 100644 index 0000000000000000000000000000000000000000..fe56f18f1b1d05a7e8900fae39bead1fe278b3d5 --- /dev/null +++ b/api/tests/unit_tests/models/test_workflow.py @@ -0,0 +1,139 @@ +from unittest import mock +from uuid import uuid4 + +import contexts +from constants import HIDDEN_VALUE +from core.variables import FloatVariable, IntegerVariable, SecretVariable, StringVariable +from models.workflow import Workflow + + +def test_environment_variables(): + contexts.tenant_id.set("tenant_id") + + # Create a Workflow instance + workflow = Workflow( + tenant_id="tenant_id", + app_id="app_id", + type="workflow", + version="draft", + graph="{}", + features="{}", + created_by="account_id", + environment_variables=[], + conversation_variables=[], + ) + + # Create some EnvironmentVariable instances + variable1 = StringVariable.model_validate( + {"name": "var1", "value": "value1", "id": str(uuid4()), "selector": ["env", "var1"]} + ) + variable2 = IntegerVariable.model_validate( + {"name": "var2", "value": 123, "id": str(uuid4()), "selector": ["env", "var2"]} + ) + variable3 = SecretVariable.model_validate( + {"name": "var3", "value": "secret", "id": str(uuid4()), "selector": ["env", "var3"]} + ) + variable4 = FloatVariable.model_validate( + {"name": "var4", "value": 3.14, "id": str(uuid4()), "selector": ["env", "var4"]} + ) + + with ( + mock.patch("core.helper.encrypter.encrypt_token", return_value="encrypted_token"), + mock.patch("core.helper.encrypter.decrypt_token", return_value="secret"), + ): + # Set the environment_variables property of the Workflow instance + variables = [variable1, variable2, variable3, variable4] + workflow.environment_variables = variables + + # Get the 
environment_variables property and assert its value + assert workflow.environment_variables == variables + + +def test_update_environment_variables(): + contexts.tenant_id.set("tenant_id") + + # Create a Workflow instance + workflow = Workflow( + tenant_id="tenant_id", + app_id="app_id", + type="workflow", + version="draft", + graph="{}", + features="{}", + created_by="account_id", + environment_variables=[], + conversation_variables=[], + ) + + # Create some EnvironmentVariable instances + variable1 = StringVariable.model_validate( + {"name": "var1", "value": "value1", "id": str(uuid4()), "selector": ["env", "var1"]} + ) + variable2 = IntegerVariable.model_validate( + {"name": "var2", "value": 123, "id": str(uuid4()), "selector": ["env", "var2"]} + ) + variable3 = SecretVariable.model_validate( + {"name": "var3", "value": "secret", "id": str(uuid4()), "selector": ["env", "var3"]} + ) + variable4 = FloatVariable.model_validate( + {"name": "var4", "value": 3.14, "id": str(uuid4()), "selector": ["env", "var4"]} + ) + + with ( + mock.patch("core.helper.encrypter.encrypt_token", return_value="encrypted_token"), + mock.patch("core.helper.encrypter.decrypt_token", return_value="secret"), + ): + variables = [variable1, variable2, variable3, variable4] + + # Set the environment_variables property of the Workflow instance + workflow.environment_variables = variables + assert workflow.environment_variables == [variable1, variable2, variable3, variable4] + + # Update the name of variable3 and keep the value as it is + variables[2] = variable3.model_copy( + update={ + "name": "new name", + "value": HIDDEN_VALUE, + } + ) + + workflow.environment_variables = variables + assert workflow.environment_variables[2].name == "new name" + assert workflow.environment_variables[2].value == variable3.value + + +def test_to_dict(): + contexts.tenant_id.set("tenant_id") + + # Create a Workflow instance + workflow = Workflow( + tenant_id="tenant_id", + app_id="app_id", + type="workflow", + 
version="draft", + graph="{}", + features="{}", + created_by="account_id", + environment_variables=[], + conversation_variables=[], + ) + + # Create some EnvironmentVariable instances + + with ( + mock.patch("core.helper.encrypter.encrypt_token", return_value="encrypted_token"), + mock.patch("core.helper.encrypter.decrypt_token", return_value="secret"), + ): + # Set the environment_variables property of the Workflow instance + workflow.environment_variables = [ + SecretVariable.model_validate({"name": "secret", "value": "secret", "id": str(uuid4())}), + StringVariable.model_validate({"name": "text", "value": "text", "id": str(uuid4())}), + ] + + workflow_dict = workflow.to_dict() + assert workflow_dict["environment_variables"][0]["value"] == "" + assert workflow_dict["environment_variables"][1]["value"] == "text" + + workflow_dict = workflow.to_dict(include_secret=True) + assert workflow_dict["environment_variables"][0]["value"] == "secret" + assert workflow_dict["environment_variables"][1]["value"] == "text" diff --git a/api/tests/unit_tests/oss/__init__.py b/api/tests/unit_tests/oss/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/tests/unit_tests/oss/__mock/__init__.py b/api/tests/unit_tests/oss/__mock/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/tests/unit_tests/oss/__mock/aliyun_oss.py b/api/tests/unit_tests/oss/__mock/aliyun_oss.py new file mode 100644 index 0000000000000000000000000000000000000000..4f6d8a2f54a4fd92c0776c5ff17bfccaa6e4db7b --- /dev/null +++ b/api/tests/unit_tests/oss/__mock/aliyun_oss.py @@ -0,0 +1,100 @@ +import os +import posixpath +from unittest.mock import MagicMock + +import pytest +from _pytest.monkeypatch import MonkeyPatch +from oss2 import Bucket # type: ignore +from oss2.models import GetObjectResult, PutObjectResult # type: ignore + +from 
tests.unit_tests.oss.__mock.base import ( + get_example_bucket, + get_example_data, + get_example_filename, + get_example_filepath, + get_example_folder, +) + + +class MockResponse: + def __init__(self, status, headers, request_id): + self.status = status + self.headers = headers + self.request_id = request_id + + +class MockAliyunOssClass: + def __init__( + self, + auth, + endpoint, + bucket_name, + is_cname=False, + session=None, + connect_timeout=None, + app_name="", + enable_crc=True, + proxies=None, + region=None, + cloudbox_id=None, + is_path_style=False, + is_verify_object_strict=True, + ): + self.bucket_name = get_example_bucket() + self.key = posixpath.join(get_example_folder(), get_example_filename()) + self.content = get_example_data() + self.filepath = get_example_filepath() + self.resp = MockResponse( + 200, + { + "etag": "ee8de918d05640145b18f70f4c3aa602", + "x-oss-version-id": "CAEQNhiBgMDJgZCA0BYiIDc4MGZjZGI2OTBjOTRmNTE5NmU5NmFhZjhjYmY0****", + }, + "request_id", + ) + + def put_object(self, key, data, headers=None, progress_callback=None): + assert key == self.key + assert data == self.content + return PutObjectResult(self.resp) + + def get_object(self, key, byte_range=None, headers=None, progress_callback=None, process=None, params=None): + assert key == self.key + + get_object_output = MagicMock(GetObjectResult) + get_object_output.read.return_value = self.content + return get_object_output + + def get_object_to_file( + self, key, filename, byte_range=None, headers=None, progress_callback=None, process=None, params=None + ): + assert key == self.key + assert filename == self.filepath + + def object_exists(self, key, headers=None): + assert key == self.key + return True + + def delete_object(self, key, params=None, headers=None): + assert key == self.key + self.resp.headers["x-oss-delete-marker"] = True + return self.resp + + +MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true" + + +@pytest.fixture +def setup_aliyun_oss_mock(monkeypatch: 
MonkeyPatch): + if MOCK: + monkeypatch.setattr(Bucket, "__init__", MockAliyunOssClass.__init__) + monkeypatch.setattr(Bucket, "put_object", MockAliyunOssClass.put_object) + monkeypatch.setattr(Bucket, "get_object", MockAliyunOssClass.get_object) + monkeypatch.setattr(Bucket, "get_object_to_file", MockAliyunOssClass.get_object_to_file) + monkeypatch.setattr(Bucket, "object_exists", MockAliyunOssClass.object_exists) + monkeypatch.setattr(Bucket, "delete_object", MockAliyunOssClass.delete_object) + + yield + + if MOCK: + monkeypatch.undo() diff --git a/api/tests/unit_tests/oss/__mock/base.py b/api/tests/unit_tests/oss/__mock/base.py new file mode 100644 index 0000000000000000000000000000000000000000..bb3c9716c3c48b676d26d35dcbd56b056d4a3008 --- /dev/null +++ b/api/tests/unit_tests/oss/__mock/base.py @@ -0,0 +1,62 @@ +from collections.abc import Generator + +import pytest + +from extensions.storage.base_storage import BaseStorage + + +def get_example_folder() -> str: + return "~/dify" + + +def get_example_bucket() -> str: + return "dify" + + +def get_opendal_bucket() -> str: + return "./dify" + + +def get_example_filename() -> str: + return "test.txt" + + +def get_example_data() -> bytes: + return b"test" + + +def get_example_filepath() -> str: + return "~/test" + + +class BaseStorageTest: + @pytest.fixture(autouse=True) + def setup_method(self, *args, **kwargs): + """Should be implemented in child classes to setup specific storage.""" + self.storage: BaseStorage + + def test_save(self): + """Test saving data.""" + self.storage.save(get_example_filename(), get_example_data()) + + def test_load_once(self): + """Test loading data once.""" + assert self.storage.load_once(get_example_filename()) == get_example_data() + + def test_load_stream(self): + """Test loading data as a stream.""" + generator = self.storage.load_stream(get_example_filename()) + assert isinstance(generator, Generator) + assert next(generator) == get_example_data() + + def test_download(self): + 
"""Test downloading data.""" + self.storage.download(get_example_filename(), get_example_filepath()) + + def test_exists(self): + """Test checking if a file exists.""" + assert self.storage.exists(get_example_filename()) + + def test_delete(self): + """Test deleting a file.""" + self.storage.delete(get_example_filename()) diff --git a/api/tests/unit_tests/oss/__mock/local.py b/api/tests/unit_tests/oss/__mock/local.py new file mode 100644 index 0000000000000000000000000000000000000000..95cc06958c61b315cae3efa4421e0f83a0ba3c1a --- /dev/null +++ b/api/tests/unit_tests/oss/__mock/local.py @@ -0,0 +1,57 @@ +import os +import shutil +from pathlib import Path +from unittest.mock import MagicMock, mock_open, patch + +import pytest +from _pytest.monkeypatch import MonkeyPatch + +from tests.unit_tests.oss.__mock.base import ( + get_example_data, + get_example_filename, + get_example_filepath, + get_example_folder, +) + + +class MockLocalFSClass: + def write_bytes(self, data): + assert data == get_example_data() + + def read_bytes(self): + return get_example_data() + + @staticmethod + def copyfile(src, dst): + assert src == os.path.join(get_example_folder(), get_example_filename()) + assert dst == get_example_filepath() + + @staticmethod + def exists(path): + assert path == os.path.join(get_example_folder(), get_example_filename()) + return True + + @staticmethod + def remove(path): + assert path == os.path.join(get_example_folder(), get_example_filename()) + + +MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true" + + +@pytest.fixture +def setup_local_fs_mock(monkeypatch: MonkeyPatch): + if MOCK: + monkeypatch.setattr(Path, "write_bytes", MockLocalFSClass.write_bytes) + monkeypatch.setattr(Path, "read_bytes", MockLocalFSClass.read_bytes) + monkeypatch.setattr(shutil, "copyfile", MockLocalFSClass.copyfile) + monkeypatch.setattr(os.path, "exists", MockLocalFSClass.exists) + monkeypatch.setattr(os, "remove", MockLocalFSClass.remove) + + os.makedirs = MagicMock() + + with 
patch("builtins.open", mock_open(read_data=get_example_data())): + yield + + if MOCK: + monkeypatch.undo() diff --git a/api/tests/unit_tests/oss/__mock/tencent_cos.py b/api/tests/unit_tests/oss/__mock/tencent_cos.py new file mode 100644 index 0000000000000000000000000000000000000000..c77c5b08f37d153503edfd6ee07c8cfb3f77a057 --- /dev/null +++ b/api/tests/unit_tests/oss/__mock/tencent_cos.py @@ -0,0 +1,81 @@ +import os +from unittest.mock import MagicMock + +import pytest +from _pytest.monkeypatch import MonkeyPatch +from qcloud_cos import CosS3Client # type: ignore +from qcloud_cos.streambody import StreamBody # type: ignore + +from tests.unit_tests.oss.__mock.base import ( + get_example_bucket, + get_example_data, + get_example_filename, + get_example_filepath, +) + + +class MockTencentCosClass: + def __init__(self, conf, retry=1, session=None): + self.bucket_name = get_example_bucket() + self.key = get_example_filename() + self.content = get_example_data() + self.filepath = get_example_filepath() + self.resp = { + "ETag": "ee8de918d05640145b18f70f4c3aa602", + "Server": "tencent-cos", + "x-cos-hash-crc64ecma": 16749565679157681890, + "x-cos-request-id": "NWU5MDNkYzlfNjRiODJhMDlfMzFmYzhfMTFm****", + } + + def put_object(self, Bucket, Body, Key, EnableMD5=False, **kwargs): # noqa: N803 + assert Bucket == self.bucket_name + assert Key == self.key + assert Body == self.content + return self.resp + + def get_object(self, Bucket, Key, KeySimplifyCheck=True, **kwargs): # noqa: N803 + assert Bucket == self.bucket_name + assert Key == self.key + + mock_stream_body = MagicMock(StreamBody) + mock_raw_stream = MagicMock() + mock_stream_body.get_raw_stream.return_value = mock_raw_stream + mock_raw_stream.read.return_value = self.content + + mock_stream_body.get_stream_to_file = MagicMock() + + def chunk_generator(chunk_size=2): + for i in range(0, len(self.content), chunk_size): + yield self.content[i : i + chunk_size] + + mock_stream_body.get_stream.return_value = 
chunk_generator(chunk_size=4096) + return {"Body": mock_stream_body} + + def object_exists(self, Bucket, Key): # noqa: N803 + assert Bucket == self.bucket_name + assert Key == self.key + return True + + def delete_object(self, Bucket, Key, **kwargs): # noqa: N803 + assert Bucket == self.bucket_name + assert Key == self.key + self.resp.update({"x-cos-delete-marker": True}) + return self.resp + + +MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true" + + +@pytest.fixture +def setup_tencent_cos_mock(monkeypatch: MonkeyPatch): + if MOCK: + monkeypatch.setattr(CosS3Client, "__init__", MockTencentCosClass.__init__) + monkeypatch.setattr(CosS3Client, "put_object", MockTencentCosClass.put_object) + monkeypatch.setattr(CosS3Client, "get_object", MockTencentCosClass.get_object) + monkeypatch.setattr(CosS3Client, "object_exists", MockTencentCosClass.object_exists) + monkeypatch.setattr(CosS3Client, "delete_object", MockTencentCosClass.delete_object) + + yield + + if MOCK: + monkeypatch.undo() diff --git a/api/tests/unit_tests/oss/__mock/volcengine_tos.py b/api/tests/unit_tests/oss/__mock/volcengine_tos.py new file mode 100644 index 0000000000000000000000000000000000000000..88df59f91c307110402f273b2fb5a7c2ee4121cb --- /dev/null +++ b/api/tests/unit_tests/oss/__mock/volcengine_tos.py @@ -0,0 +1,91 @@ +import os +from collections import UserDict +from unittest.mock import MagicMock + +import pytest +from _pytest.monkeypatch import MonkeyPatch +from tos import TosClientV2 # type: ignore +from tos.clientv2 import DeleteObjectOutput, GetObjectOutput, HeadObjectOutput, PutObjectOutput # type: ignore + +from tests.unit_tests.oss.__mock.base import ( + get_example_bucket, + get_example_data, + get_example_filename, + get_example_filepath, +) + + +class AttrDict(UserDict): + def __getattr__(self, item): + return self.get(item) + + +class MockVolcengineTosClass: + def __init__(self, ak="", sk="", endpoint="", region=""): + self.bucket_name = get_example_bucket() + self.key = 
get_example_filename() + self.content = get_example_data() + self.filepath = get_example_filepath() + self.resp = AttrDict( + { + "x-tos-server-side-encryption": "kms", + "x-tos-server-side-encryption-kms-key-id": "trn:kms:cn-beijing:****:keyrings/ring-test/keys/key-test", + "x-tos-server-side-encryption-customer-algorithm": "AES256", + "x-tos-version-id": "test", + "x-tos-hash-crc64ecma": 123456, + "request_id": "test", + "headers": { + "x-tos-id-2": "test", + "ETag": "123456", + }, + "status": 200, + } + ) + + def put_object(self, bucket: str, key: str, content=None) -> PutObjectOutput: + assert bucket == self.bucket_name + assert key == self.key + assert content == self.content + return PutObjectOutput(self.resp) + + def get_object(self, bucket: str, key: str) -> GetObjectOutput: + assert bucket == self.bucket_name + assert key == self.key + + get_object_output = MagicMock(GetObjectOutput) + get_object_output.read.return_value = self.content + return get_object_output + + def get_object_to_file(self, bucket: str, key: str, file_path: str): + assert bucket == self.bucket_name + assert key == self.key + assert file_path == self.filepath + + def head_object(self, bucket: str, key: str) -> HeadObjectOutput: + assert bucket == self.bucket_name + assert key == self.key + return HeadObjectOutput(self.resp) + + def delete_object(self, bucket: str, key: str): + assert bucket == self.bucket_name + assert key == self.key + return DeleteObjectOutput(self.resp) + + +MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true" + + +@pytest.fixture +def setup_volcengine_tos_mock(monkeypatch: MonkeyPatch): + if MOCK: + monkeypatch.setattr(TosClientV2, "__init__", MockVolcengineTosClass.__init__) + monkeypatch.setattr(TosClientV2, "put_object", MockVolcengineTosClass.put_object) + monkeypatch.setattr(TosClientV2, "get_object", MockVolcengineTosClass.get_object) + monkeypatch.setattr(TosClientV2, "get_object_to_file", MockVolcengineTosClass.get_object_to_file) + 
monkeypatch.setattr(TosClientV2, "head_object", MockVolcengineTosClass.head_object) + monkeypatch.setattr(TosClientV2, "delete_object", MockVolcengineTosClass.delete_object) + + yield + + if MOCK: + monkeypatch.undo() diff --git a/api/tests/unit_tests/oss/aliyun_oss/aliyun_oss/__init__.py b/api/tests/unit_tests/oss/aliyun_oss/aliyun_oss/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/tests/unit_tests/oss/aliyun_oss/aliyun_oss/test_aliyun_oss.py b/api/tests/unit_tests/oss/aliyun_oss/aliyun_oss/test_aliyun_oss.py new file mode 100644 index 0000000000000000000000000000000000000000..f87a38569092f16555785accf7082b46fa04e787 --- /dev/null +++ b/api/tests/unit_tests/oss/aliyun_oss/aliyun_oss/test_aliyun_oss.py @@ -0,0 +1,22 @@ +from unittest.mock import patch + +import pytest +from oss2 import Auth # type: ignore + +from extensions.storage.aliyun_oss_storage import AliyunOssStorage +from tests.unit_tests.oss.__mock.aliyun_oss import setup_aliyun_oss_mock +from tests.unit_tests.oss.__mock.base import ( + BaseStorageTest, + get_example_bucket, + get_example_folder, +) + + +class TestAliyunOss(BaseStorageTest): + @pytest.fixture(autouse=True) + def setup_method(self, setup_aliyun_oss_mock): + """Executed before each test method.""" + with patch.object(Auth, "__init__", return_value=None): + self.storage = AliyunOssStorage() + self.storage.bucket_name = get_example_bucket() + self.storage.folder = get_example_folder() diff --git a/api/tests/unit_tests/oss/opendal/__init__.py b/api/tests/unit_tests/oss/opendal/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/tests/unit_tests/oss/opendal/test_opendal.py b/api/tests/unit_tests/oss/opendal/test_opendal.py new file mode 100644 index 0000000000000000000000000000000000000000..6acec6e579d2a6961dc11c2e90951ccd3a69ec45 --- /dev/null +++ 
b/api/tests/unit_tests/oss/opendal/test_opendal.py @@ -0,0 +1,85 @@ +from collections.abc import Generator +from pathlib import Path + +import pytest + +from extensions.storage.opendal_storage import OpenDALStorage +from tests.unit_tests.oss.__mock.base import ( + get_example_data, + get_example_filename, + get_opendal_bucket, +) + + +class TestOpenDAL: + @pytest.fixture(autouse=True) + def setup_method(self, *args, **kwargs): + """Executed before each test method.""" + self.storage = OpenDALStorage( + scheme="fs", + root=get_opendal_bucket(), + ) + + @pytest.fixture(scope="class", autouse=True) + def teardown_class(self, request): + """Clean up after all tests in the class.""" + + def cleanup(): + folder = Path(get_opendal_bucket()) + if folder.exists() and folder.is_dir(): + for item in folder.iterdir(): + if item.is_file(): + item.unlink() + elif item.is_dir(): + item.rmdir() + folder.rmdir() + + return cleanup() + + def test_save_and_exists(self): + """Test saving data and checking existence.""" + filename = get_example_filename() + data = get_example_data() + + assert not self.storage.exists(filename) + self.storage.save(filename, data) + assert self.storage.exists(filename) + + def test_load_once(self): + """Test loading data once.""" + filename = get_example_filename() + data = get_example_data() + + self.storage.save(filename, data) + loaded_data = self.storage.load_once(filename) + assert loaded_data == data + + def test_load_stream(self): + """Test loading data as a stream.""" + filename = get_example_filename() + data = get_example_data() + + self.storage.save(filename, data) + generator = self.storage.load_stream(filename) + assert isinstance(generator, Generator) + assert next(generator) == data + + def test_download(self): + """Test downloading data to a file.""" + filename = get_example_filename() + filepath = str(Path(get_opendal_bucket()) / filename) + data = get_example_data() + + self.storage.save(filename, data) + self.storage.download(filename, 
filepath) + + def test_delete(self): + """Test deleting a file.""" + filename = get_example_filename() + data = get_example_data() + + self.storage.save(filename, data) + assert self.storage.exists(filename) + + self.storage.delete(filename) + assert not self.storage.exists(filename) diff --git a/api/tests/unit_tests/oss/tencent_cos/__init__.py b/api/tests/unit_tests/oss/tencent_cos/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/tests/unit_tests/oss/tencent_cos/test_tencent_cos.py b/api/tests/unit_tests/oss/tencent_cos/test_tencent_cos.py new file mode 100644 index 0000000000000000000000000000000000000000..d289751800633a7507835af3f9dc98bc43678719 --- /dev/null +++ b/api/tests/unit_tests/oss/tencent_cos/test_tencent_cos.py @@ -0,0 +1,20 @@ +from unittest.mock import patch + +import pytest +from qcloud_cos import CosConfig # type: ignore + +from extensions.storage.tencent_cos_storage import TencentCosStorage +from tests.unit_tests.oss.__mock.base import ( + BaseStorageTest, + get_example_bucket, +) +from tests.unit_tests.oss.__mock.tencent_cos import setup_tencent_cos_mock + + +class TestTencentCos(BaseStorageTest): + @pytest.fixture(autouse=True) + def setup_method(self, setup_tencent_cos_mock): + """Executed before each test method.""" + with patch.object(CosConfig, "__init__", return_value=None): + self.storage = TencentCosStorage() + self.storage.bucket_name = get_example_bucket() diff --git a/api/tests/unit_tests/oss/volcengine_tos/__init__.py b/api/tests/unit_tests/oss/volcengine_tos/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/tests/unit_tests/oss/volcengine_tos/test_volcengine_tos.py b/api/tests/unit_tests/oss/volcengine_tos/test_volcengine_tos.py new file mode 100644 index 0000000000000000000000000000000000000000..04988e85d85881036fef8669914fbae53abdc582 --- /dev/null +++ 
b/api/tests/unit_tests/oss/volcengine_tos/test_volcengine_tos.py @@ -0,0 +1,23 @@ +import pytest +from tos import TosClientV2 # type: ignore + +from extensions.storage.volcengine_tos_storage import VolcengineTosStorage +from tests.unit_tests.oss.__mock.base import ( + BaseStorageTest, + get_example_bucket, +) +from tests.unit_tests.oss.__mock.volcengine_tos import setup_volcengine_tos_mock + + +class TestVolcengineTos(BaseStorageTest): + @pytest.fixture(autouse=True) + def setup_method(self, setup_volcengine_tos_mock): + """Executed before each test method.""" + self.storage = VolcengineTosStorage() + self.storage.bucket_name = get_example_bucket() + self.storage.client = TosClientV2( + ak="dify", + sk="dify", + endpoint="https://xxx.volces.com", + region="cn-beijing", + ) diff --git a/api/tests/unit_tests/services/__init__.py b/api/tests/unit_tests/services/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/tests/unit_tests/services/workflow/__init__.py b/api/tests/unit_tests/services/workflow/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/tests/unit_tests/services/workflow/test_workflow_converter.py b/api/tests/unit_tests/services/workflow/test_workflow_converter.py new file mode 100644 index 0000000000000000000000000000000000000000..0a0916734929a43d9eca84321652ec441f9a34bd --- /dev/null +++ b/api/tests/unit_tests/services/workflow/test_workflow_converter.py @@ -0,0 +1,424 @@ +# test for api/services/workflow/workflow_converter.py +import json +from unittest.mock import MagicMock + +import pytest + +from core.app.app_config.entities import ( + AdvancedChatMessageEntity, + AdvancedChatPromptTemplateEntity, + AdvancedCompletionPromptTemplateEntity, + DatasetEntity, + DatasetRetrieveConfigEntity, + ExternalDataVariableEntity, + ModelConfigEntity, + PromptTemplateEntity, + VariableEntity, + 
VariableEntityType, +) +from core.helper import encrypter +from core.model_runtime.entities.llm_entities import LLMMode +from core.model_runtime.entities.message_entities import PromptMessageRole +from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint +from models.model import AppMode +from services.workflow.workflow_converter import WorkflowConverter + + +@pytest.fixture +def default_variables(): + value = [ + VariableEntity( + variable="text_input", + label="text-input", + type=VariableEntityType.TEXT_INPUT, + ), + VariableEntity( + variable="paragraph", + label="paragraph", + type=VariableEntityType.PARAGRAPH, + ), + VariableEntity( + variable="select", + label="select", + type=VariableEntityType.SELECT, + ), + ] + return value + + +def test__convert_to_start_node(default_variables): + # act + result = WorkflowConverter()._convert_to_start_node(default_variables) + + # assert + assert isinstance(result["data"]["variables"][0]["type"], str) + assert result["data"]["variables"][0]["type"] == "text-input" + assert result["data"]["variables"][0]["variable"] == "text_input" + assert result["data"]["variables"][1]["variable"] == "paragraph" + assert result["data"]["variables"][2]["variable"] == "select" + + +def test__convert_to_http_request_node_for_chatbot(default_variables): + """ + Test convert to http request nodes for chatbot + :return: + """ + app_model = MagicMock() + app_model.id = "app_id" + app_model.tenant_id = "tenant_id" + app_model.mode = AppMode.CHAT.value + + api_based_extension_id = "api_based_extension_id" + mock_api_based_extension = APIBasedExtension( + id=api_based_extension_id, + name="api-1", + api_key="encrypted_api_key", + api_endpoint="https://dify.ai", + ) + + workflow_converter = WorkflowConverter() + workflow_converter._get_api_based_extension = MagicMock(return_value=mock_api_based_extension) + + encrypter.decrypt_token = MagicMock(return_value="api_key") + + external_data_variables = [ + 
ExternalDataVariableEntity( + variable="external_variable", type="api", config={"api_based_extension_id": api_based_extension_id} + ) + ] + + nodes, _ = workflow_converter._convert_to_http_request_node( + app_model=app_model, variables=default_variables, external_data_variables=external_data_variables + ) + + assert len(nodes) == 2 + assert nodes[0]["data"]["type"] == "http-request" + + http_request_node = nodes[0] + + assert http_request_node["data"]["method"] == "post" + assert http_request_node["data"]["url"] == mock_api_based_extension.api_endpoint + assert http_request_node["data"]["authorization"]["type"] == "api-key" + assert http_request_node["data"]["authorization"]["config"] == {"type": "bearer", "api_key": "api_key"} + assert http_request_node["data"]["body"]["type"] == "json" + + body_data = http_request_node["data"]["body"]["data"] + + assert body_data + + body_data_json = json.loads(body_data) + assert body_data_json["point"] == APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY.value + + body_params = body_data_json["params"] + assert body_params["app_id"] == app_model.id + assert body_params["tool_variable"] == external_data_variables[0].variable + assert len(body_params["inputs"]) == 3 + assert body_params["query"] == "{{#sys.query#}}" # for chatbot + + code_node = nodes[1] + assert code_node["data"]["type"] == "code" + + +def test__convert_to_http_request_node_for_workflow_app(default_variables): + """ + Test convert to http request nodes for workflow app + :return: + """ + app_model = MagicMock() + app_model.id = "app_id" + app_model.tenant_id = "tenant_id" + app_model.mode = AppMode.WORKFLOW.value + + api_based_extension_id = "api_based_extension_id" + mock_api_based_extension = APIBasedExtension( + id=api_based_extension_id, + name="api-1", + api_key="encrypted_api_key", + api_endpoint="https://dify.ai", + ) + + workflow_converter = WorkflowConverter() + workflow_converter._get_api_based_extension = 
MagicMock(return_value=mock_api_based_extension) + + encrypter.decrypt_token = MagicMock(return_value="api_key") + + external_data_variables = [ + ExternalDataVariableEntity( + variable="external_variable", type="api", config={"api_based_extension_id": api_based_extension_id} + ) + ] + + nodes, _ = workflow_converter._convert_to_http_request_node( + app_model=app_model, variables=default_variables, external_data_variables=external_data_variables + ) + + assert len(nodes) == 2 + assert nodes[0]["data"]["type"] == "http-request" + + http_request_node = nodes[0] + + assert http_request_node["data"]["method"] == "post" + assert http_request_node["data"]["url"] == mock_api_based_extension.api_endpoint + assert http_request_node["data"]["authorization"]["type"] == "api-key" + assert http_request_node["data"]["authorization"]["config"] == {"type": "bearer", "api_key": "api_key"} + assert http_request_node["data"]["body"]["type"] == "json" + + body_data = http_request_node["data"]["body"]["data"] + + assert body_data + + body_data_json = json.loads(body_data) + assert body_data_json["point"] == APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY.value + + body_params = body_data_json["params"] + assert body_params["app_id"] == app_model.id + assert body_params["tool_variable"] == external_data_variables[0].variable + assert len(body_params["inputs"]) == 3 + assert body_params["query"] == "" + + code_node = nodes[1] + assert code_node["data"]["type"] == "code" + + +def test__convert_to_knowledge_retrieval_node_for_chatbot(): + new_app_mode = AppMode.ADVANCED_CHAT + + dataset_config = DatasetEntity( + dataset_ids=["dataset_id_1", "dataset_id_2"], + retrieve_config=DatasetRetrieveConfigEntity( + retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE, + top_k=5, + score_threshold=0.8, + reranking_model={"reranking_provider_name": "cohere", "reranking_model_name": "rerank-english-v2.0"}, + reranking_enabled=True, + ), + ) + + model_config = 
ModelConfigEntity(provider="openai", model="gpt-4", mode="chat", parameters={}, stop=[]) + + node = WorkflowConverter()._convert_to_knowledge_retrieval_node( + new_app_mode=new_app_mode, dataset_config=dataset_config, model_config=model_config + ) + + assert node["data"]["type"] == "knowledge-retrieval" + assert node["data"]["query_variable_selector"] == ["sys", "query"] + assert node["data"]["dataset_ids"] == dataset_config.dataset_ids + assert node["data"]["retrieval_mode"] == dataset_config.retrieve_config.retrieve_strategy.value + assert node["data"]["multiple_retrieval_config"] == { + "top_k": dataset_config.retrieve_config.top_k, + "score_threshold": dataset_config.retrieve_config.score_threshold, + "reranking_model": dataset_config.retrieve_config.reranking_model, + } + + +def test__convert_to_knowledge_retrieval_node_for_workflow_app(): + new_app_mode = AppMode.WORKFLOW + + dataset_config = DatasetEntity( + dataset_ids=["dataset_id_1", "dataset_id_2"], + retrieve_config=DatasetRetrieveConfigEntity( + query_variable="query", + retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE, + top_k=5, + score_threshold=0.8, + reranking_model={"reranking_provider_name": "cohere", "reranking_model_name": "rerank-english-v2.0"}, + reranking_enabled=True, + ), + ) + + model_config = ModelConfigEntity(provider="openai", model="gpt-4", mode="chat", parameters={}, stop=[]) + + node = WorkflowConverter()._convert_to_knowledge_retrieval_node( + new_app_mode=new_app_mode, dataset_config=dataset_config, model_config=model_config + ) + + assert node["data"]["type"] == "knowledge-retrieval" + assert node["data"]["query_variable_selector"] == ["start", dataset_config.retrieve_config.query_variable] + assert node["data"]["dataset_ids"] == dataset_config.dataset_ids + assert node["data"]["retrieval_mode"] == dataset_config.retrieve_config.retrieve_strategy.value + assert node["data"]["multiple_retrieval_config"] == { + "top_k": dataset_config.retrieve_config.top_k, 
+ "score_threshold": dataset_config.retrieve_config.score_threshold, + "reranking_model": dataset_config.retrieve_config.reranking_model, + } + + +def test__convert_to_llm_node_for_chatbot_simple_chat_model(default_variables): + new_app_mode = AppMode.ADVANCED_CHAT + model = "gpt-4" + model_mode = LLMMode.CHAT + + workflow_converter = WorkflowConverter() + start_node = workflow_converter._convert_to_start_node(default_variables) + graph = { + "nodes": [start_node], + "edges": [], # no need + } + + model_config_mock = MagicMock(spec=ModelConfigEntity) + model_config_mock.provider = "openai" + model_config_mock.model = model + model_config_mock.mode = model_mode.value + model_config_mock.parameters = {} + model_config_mock.stop = [] + + prompt_template = PromptTemplateEntity( + prompt_type=PromptTemplateEntity.PromptType.SIMPLE, + simple_prompt_template="You are a helpful assistant {{text_input}}, {{paragraph}}, {{select}}.", + ) + + llm_node = workflow_converter._convert_to_llm_node( + original_app_mode=AppMode.CHAT, + new_app_mode=new_app_mode, + model_config=model_config_mock, + graph=graph, + prompt_template=prompt_template, + ) + + assert llm_node["data"]["type"] == "llm" + assert llm_node["data"]["model"]["name"] == model + assert llm_node["data"]["model"]["mode"] == model_mode.value + template = prompt_template.simple_prompt_template + for v in default_variables: + template = template.replace("{{" + v.variable + "}}", "{{#start." 
+ v.variable + "#}}") + assert llm_node["data"]["prompt_template"][0]["text"] == template + "\n" + assert llm_node["data"]["context"]["enabled"] is False + + +def test__convert_to_llm_node_for_chatbot_simple_completion_model(default_variables): + new_app_mode = AppMode.ADVANCED_CHAT + model = "gpt-3.5-turbo-instruct" + model_mode = LLMMode.COMPLETION + + workflow_converter = WorkflowConverter() + start_node = workflow_converter._convert_to_start_node(default_variables) + graph = { + "nodes": [start_node], + "edges": [], # no need + } + + model_config_mock = MagicMock(spec=ModelConfigEntity) + model_config_mock.provider = "openai" + model_config_mock.model = model + model_config_mock.mode = model_mode.value + model_config_mock.parameters = {} + model_config_mock.stop = [] + + prompt_template = PromptTemplateEntity( + prompt_type=PromptTemplateEntity.PromptType.SIMPLE, + simple_prompt_template="You are a helpful assistant {{text_input}}, {{paragraph}}, {{select}}.", + ) + + llm_node = workflow_converter._convert_to_llm_node( + original_app_mode=AppMode.CHAT, + new_app_mode=new_app_mode, + model_config=model_config_mock, + graph=graph, + prompt_template=prompt_template, + ) + + assert llm_node["data"]["type"] == "llm" + assert llm_node["data"]["model"]["name"] == model + assert llm_node["data"]["model"]["mode"] == model_mode.value + template = prompt_template.simple_prompt_template + for v in default_variables: + template = template.replace("{{" + v.variable + "}}", "{{#start." 
+ v.variable + "#}}") + assert llm_node["data"]["prompt_template"]["text"] == template + "\n" + assert llm_node["data"]["context"]["enabled"] is False + + +def test__convert_to_llm_node_for_chatbot_advanced_chat_model(default_variables): + new_app_mode = AppMode.ADVANCED_CHAT + model = "gpt-4" + model_mode = LLMMode.CHAT + + workflow_converter = WorkflowConverter() + start_node = workflow_converter._convert_to_start_node(default_variables) + graph = { + "nodes": [start_node], + "edges": [], # no need + } + + model_config_mock = MagicMock(spec=ModelConfigEntity) + model_config_mock.provider = "openai" + model_config_mock.model = model + model_config_mock.mode = model_mode.value + model_config_mock.parameters = {} + model_config_mock.stop = [] + + prompt_template = PromptTemplateEntity( + prompt_type=PromptTemplateEntity.PromptType.ADVANCED, + advanced_chat_prompt_template=AdvancedChatPromptTemplateEntity( + messages=[ + AdvancedChatMessageEntity( + text="You are a helpful assistant named {{name}}.\n\nContext:\n{{#context#}}", + role=PromptMessageRole.SYSTEM, + ), + AdvancedChatMessageEntity(text="Hi.", role=PromptMessageRole.USER), + AdvancedChatMessageEntity(text="Hello!", role=PromptMessageRole.ASSISTANT), + ] + ), + ) + + llm_node = workflow_converter._convert_to_llm_node( + original_app_mode=AppMode.CHAT, + new_app_mode=new_app_mode, + model_config=model_config_mock, + graph=graph, + prompt_template=prompt_template, + ) + + assert llm_node["data"]["type"] == "llm" + assert llm_node["data"]["model"]["name"] == model + assert llm_node["data"]["model"]["mode"] == model_mode.value + assert isinstance(llm_node["data"]["prompt_template"], list) + assert len(llm_node["data"]["prompt_template"]) == len(prompt_template.advanced_chat_prompt_template.messages) + template = prompt_template.advanced_chat_prompt_template.messages[0].text + for v in default_variables: + template = template.replace("{{" + v.variable + "}}", "{{#start." 
+ v.variable + "#}}") + assert llm_node["data"]["prompt_template"][0]["text"] == template + + +def test__convert_to_llm_node_for_workflow_advanced_completion_model(default_variables): + new_app_mode = AppMode.ADVANCED_CHAT + model = "gpt-3.5-turbo-instruct" + model_mode = LLMMode.COMPLETION + + workflow_converter = WorkflowConverter() + start_node = workflow_converter._convert_to_start_node(default_variables) + graph = { + "nodes": [start_node], + "edges": [], # no need + } + + model_config_mock = MagicMock(spec=ModelConfigEntity) + model_config_mock.provider = "openai" + model_config_mock.model = model + model_config_mock.mode = model_mode.value + model_config_mock.parameters = {} + model_config_mock.stop = [] + + prompt_template = PromptTemplateEntity( + prompt_type=PromptTemplateEntity.PromptType.ADVANCED, + advanced_completion_prompt_template=AdvancedCompletionPromptTemplateEntity( + prompt="You are a helpful assistant named {{name}}.\n\nContext:\n{{#context#}}\n\nHuman: hi\nAssistant: ", + role_prefix=AdvancedCompletionPromptTemplateEntity.RolePrefixEntity(user="Human", assistant="Assistant"), + ), + ) + + llm_node = workflow_converter._convert_to_llm_node( + original_app_mode=AppMode.CHAT, + new_app_mode=new_app_mode, + model_config=model_config_mock, + graph=graph, + prompt_template=prompt_template, + ) + + assert llm_node["data"]["type"] == "llm" + assert llm_node["data"]["model"]["name"] == model + assert llm_node["data"]["model"]["mode"] == model_mode.value + assert isinstance(llm_node["data"]["prompt_template"], dict) + template = prompt_template.advanced_completion_prompt_template.prompt + for v in default_variables: + template = template.replace("{{" + v.variable + "}}", "{{#start." 
+ v.variable + "#}}") + assert llm_node["data"]["prompt_template"]["text"] == template diff --git a/api/tests/unit_tests/utils/__init__.py b/api/tests/unit_tests/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/tests/unit_tests/utils/position_helper/test_position_helper.py b/api/tests/unit_tests/utils/position_helper/test_position_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..29558a93c242a8c72da340b49ba439d3b8ddc9aa --- /dev/null +++ b/api/tests/unit_tests/utils/position_helper/test_position_helper.py @@ -0,0 +1,123 @@ +from textwrap import dedent + +import pytest + +from core.helper.position_helper import get_position_map, is_filtered, pin_position_map, sort_by_position_map + + +@pytest.fixture +def prepare_example_positions_yaml(tmp_path, monkeypatch) -> str: + monkeypatch.chdir(tmp_path) + tmp_path.joinpath("example_positions.yaml").write_text( + dedent( + """\ + - first + - second + # - commented + - third + + - 9999999999999 + - forth + """ + ) + ) + return str(tmp_path) + + +@pytest.fixture +def prepare_empty_commented_positions_yaml(tmp_path, monkeypatch) -> str: + monkeypatch.chdir(tmp_path) + tmp_path.joinpath("example_positions_all_commented.yaml").write_text( + dedent( + """\ + # - commented1 + # - commented2 + - + - + + """ + ) + ) + return str(tmp_path) + + +def test_position_helper(prepare_example_positions_yaml): + position_map = get_position_map(folder_path=prepare_example_positions_yaml, file_name="example_positions.yaml") + assert len(position_map) == 4 + assert position_map == { + "first": 0, + "second": 1, + "third": 2, + "forth": 3, + } + + +def test_position_helper_with_all_commented(prepare_empty_commented_positions_yaml): + position_map = get_position_map( + folder_path=prepare_empty_commented_positions_yaml, file_name="example_positions_all_commented.yaml" + ) + assert position_map == {} + + +def 
test_excluded_position_data(prepare_example_positions_yaml): + position_map = get_position_map(folder_path=prepare_example_positions_yaml, file_name="example_positions.yaml") + pin_list = ["forth", "first"] + include_set = set() + exclude_set = {"9999999999999"} + + position_map = pin_position_map(original_position_map=position_map, pin_list=pin_list) + + data = [ + "forth", + "first", + "second", + "third", + "9999999999999", + "extra1", + "extra2", + ] + + # filter out the data + data = [item for item in data if not is_filtered(include_set, exclude_set, item, lambda x: x)] + + # sort data by position map + sorted_data = sort_by_position_map( + position_map=position_map, + data=data, + name_func=lambda x: x, + ) + + # assert the result in the correct order + assert sorted_data == ["forth", "first", "second", "third", "extra1", "extra2"] + + +def test_included_position_data(prepare_example_positions_yaml): + position_map = get_position_map(folder_path=prepare_example_positions_yaml, file_name="example_positions.yaml") + pin_list = ["forth", "first"] + include_set = {"forth", "first"} + exclude_set = {} + + position_map = pin_position_map(original_position_map=position_map, pin_list=pin_list) + + data = [ + "forth", + "first", + "second", + "third", + "9999999999999", + "extra1", + "extra2", + ] + + # filter out the data + data = [item for item in data if not is_filtered(include_set, exclude_set, item, lambda x: x)] + + # sort data by position map + sorted_data = sort_by_position_map( + position_map=position_map, + data=data, + name_func=lambda x: x, + ) + + # assert the result in the correct order + assert sorted_data == ["forth", "first"] diff --git a/api/tests/unit_tests/utils/test_text_processing.py b/api/tests/unit_tests/utils/test_text_processing.py new file mode 100644 index 0000000000000000000000000000000000000000..8bfc97ae63688b9ef2b9243c16222c81e4e7bfe3 --- /dev/null +++ b/api/tests/unit_tests/utils/test_text_processing.py @@ -0,0 +1,18 @@ +import pytest + 
+from core.tools.utils.text_processing_utils import remove_leading_symbols + + +@pytest.mark.parametrize( + ("input_text", "expected_output"), + [ + ("...Hello, World!", "Hello, World!"), + ("。测试中文标点", "测试中文标点"), + ("!@#Test symbols", "Test symbols"), + ("Hello, World!", "Hello, World!"), + ("", ""), + (" ", " "), + ], +) +def test_remove_leading_symbols(input_text, expected_output): + assert remove_leading_symbols(input_text) == expected_output diff --git a/api/tests/unit_tests/utils/yaml/__init__.py b/api/tests/unit_tests/utils/yaml/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/api/tests/unit_tests/utils/yaml/test_yaml_utils.py b/api/tests/unit_tests/utils/yaml/test_yaml_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8d645487278a5fc422802bf7f2a7b01f87a671a5 --- /dev/null +++ b/api/tests/unit_tests/utils/yaml/test_yaml_utils.py @@ -0,0 +1,83 @@ +from textwrap import dedent + +import pytest +from yaml import YAMLError # type: ignore + +from core.tools.utils.yaml_utils import load_yaml_file + +EXAMPLE_YAML_FILE = "example_yaml.yaml" +INVALID_YAML_FILE = "invalid_yaml.yaml" +NON_EXISTING_YAML_FILE = "non_existing_file.yaml" + + +@pytest.fixture +def prepare_example_yaml_file(tmp_path, monkeypatch) -> str: + monkeypatch.chdir(tmp_path) + file_path = tmp_path.joinpath(EXAMPLE_YAML_FILE) + file_path.write_text( + dedent( + """\ + address: + city: Example City + country: Example Country + age: 30 + gender: male + languages: + - Python + - Java + - C++ + empty_key: + """ + ) + ) + return str(file_path) + + +@pytest.fixture +def prepare_invalid_yaml_file(tmp_path, monkeypatch) -> str: + monkeypatch.chdir(tmp_path) + file_path = tmp_path.joinpath(INVALID_YAML_FILE) + file_path.write_text( + dedent( + """\ + address: + city: Example City + country: Example Country + age: 30 + gender: male + languages: + - Python + - Java + - C++ + """ + ) + ) + return 
str(file_path) + + +def test_load_yaml_non_existing_file(): + assert load_yaml_file(file_path=NON_EXISTING_YAML_FILE) == {} + assert load_yaml_file(file_path="") == {} + + with pytest.raises(FileNotFoundError): + load_yaml_file(file_path=NON_EXISTING_YAML_FILE, ignore_error=False) + + +def test_load_valid_yaml_file(prepare_example_yaml_file): + yaml_data = load_yaml_file(file_path=prepare_example_yaml_file) + assert len(yaml_data) > 0 + assert yaml_data["age"] == 30 + assert yaml_data["gender"] == "male" + assert yaml_data["address"]["city"] == "Example City" + assert set(yaml_data["languages"]) == {"Python", "Java", "C++"} + assert yaml_data.get("empty_key") is None + assert yaml_data.get("non_existed_key") is None + + +def test_load_invalid_yaml_file(prepare_invalid_yaml_file): + # yaml syntax error + with pytest.raises(YAMLError): + load_yaml_file(file_path=prepare_invalid_yaml_file, ignore_error=False) + + # ignore error + assert load_yaml_file(file_path=prepare_invalid_yaml_file) == {}