CatPtain committed on
Commit 20f348c · verified · 1 Parent(s): ee58329

Upload 697 files

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitattributes +1 -0
  2. api/docker/entrypoint.sh +40 -0
  3. api/events/__init__.py +0 -0
  4. api/events/app_event.py +13 -0
  5. api/events/dataset_event.py +4 -0
  6. api/events/document_event.py +4 -0
  7. api/events/event_handlers/__init__.py +10 -0
  8. api/events/event_handlers/clean_when_dataset_deleted.py +15 -0
  9. api/events/event_handlers/clean_when_document_deleted.py +11 -0
  10. api/events/event_handlers/create_document_index.py +49 -0
  11. api/events/event_handlers/create_installed_app_when_app_created.py +16 -0
  12. api/events/event_handlers/create_site_record_when_app_created.py +26 -0
  13. api/events/event_handlers/deduct_quota_when_message_created.py +53 -0
  14. api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py +34 -0
  15. api/events/event_handlers/document_index_event.py +4 -0
  16. api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py +68 -0
  17. api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py +67 -0
  18. api/events/event_handlers/update_provider_last_used_at_when_message_created.py +21 -0
  19. api/events/message_event.py +4 -0
  20. api/events/tenant_event.py +7 -0
  21. api/extensions/__init__.py +0 -0
  22. api/extensions/ext_app_metrics.py +67 -0
  23. api/extensions/ext_blueprints.py +48 -0
  24. api/extensions/ext_celery.py +104 -0
  25. api/extensions/ext_code_based_extension.py +9 -0
  26. api/extensions/ext_commands.py +29 -0
  27. api/extensions/ext_compress.py +13 -0
  28. api/extensions/ext_database.py +6 -0
  29. api/extensions/ext_hosting_provider.py +10 -0
  30. api/extensions/ext_import_modules.py +5 -0
  31. api/extensions/ext_logging.py +71 -0
  32. api/extensions/ext_login.py +62 -0
  33. api/extensions/ext_mail.py +97 -0
  34. api/extensions/ext_migrate.py +9 -0
  35. api/extensions/ext_proxy_fix.py +9 -0
  36. api/extensions/ext_redis.py +98 -0
  37. api/extensions/ext_sentry.py +40 -0
  38. api/extensions/ext_set_secretkey.py +6 -0
  39. api/extensions/ext_storage.py +138 -0
  40. api/extensions/ext_timezone.py +11 -0
  41. api/extensions/ext_warnings.py +7 -0
  42. api/extensions/storage/aliyun_oss_storage.py +54 -0
  43. api/extensions/storage/aws_s3_storage.py +91 -0
  44. api/extensions/storage/azure_blob_storage.py +84 -0
  45. api/extensions/storage/baidu_obs_storage.py +57 -0
  46. api/extensions/storage/base_storage.py +32 -0
  47. api/extensions/storage/google_cloud_storage.py +60 -0
  48. api/extensions/storage/huawei_obs_storage.py +51 -0
  49. api/extensions/storage/opendal_storage.py +89 -0
  50. api/extensions/storage/oracle_oci_storage.py +59 -0
.gitattributes CHANGED
@@ -5,3 +5,4 @@
 # them.
 
 *.sh text eol=lf
+api/tests/integration_tests/model_runtime/assets/audio.mp3 filter=lfs diff=lfs merge=lfs -text
api/docker/entrypoint.sh ADDED
@@ -0,0 +1,40 @@
#!/bin/bash

set -e

if [[ "${MIGRATION_ENABLED}" == "true" ]]; then
  echo "Running migrations"
  flask upgrade-db
fi

if [[ "${MODE}" == "worker" ]]; then

  # Get the number of available CPU cores
  if [ "${CELERY_AUTO_SCALE,,}" = "true" ]; then
    # Set MAX_WORKERS to the number of available cores if not specified
    AVAILABLE_CORES=$(nproc)
    MAX_WORKERS=${CELERY_MAX_WORKERS:-$AVAILABLE_CORES}
    MIN_WORKERS=${CELERY_MIN_WORKERS:-1}
    CONCURRENCY_OPTION="--autoscale=${MAX_WORKERS},${MIN_WORKERS}"
  else
    CONCURRENCY_OPTION="-c ${CELERY_WORKER_AMOUNT:-1}"
  fi

  exec celery -A app.celery worker -P ${CELERY_WORKER_CLASS:-gevent} $CONCURRENCY_OPTION --loglevel ${LOG_LEVEL:-INFO} \
    -Q ${CELERY_QUEUES:-dataset,mail,ops_trace,app_deletion}

elif [[ "${MODE}" == "beat" ]]; then
  exec celery -A app.celery beat --loglevel ${LOG_LEVEL:-INFO}
else
  if [[ "${DEBUG}" == "true" ]]; then
    exec flask run --host=${DIFY_BIND_ADDRESS:-0.0.0.0} --port=${DIFY_PORT:-5001} --debug
  else
    exec gunicorn \
      --bind "${DIFY_BIND_ADDRESS:-0.0.0.0}:${DIFY_PORT:-5001}" \
      --workers ${SERVER_WORKER_AMOUNT:-1} \
      --worker-class ${SERVER_WORKER_CLASS:-gevent} \
      --worker-connections ${SERVER_WORKER_CONNECTIONS:-10} \
      --timeout ${GUNICORN_TIMEOUT:-200} \
      app:app
  fi
fi
api/events/__init__.py ADDED
File without changes
api/events/app_event.py ADDED
@@ -0,0 +1,13 @@
from blinker import signal

# sender: app
app_was_created = signal("app-was-created")

# sender: app, kwargs: app_model_config
app_model_config_was_updated = signal("app-model-config-was-updated")

# sender: app, kwargs: published_workflow
app_published_workflow_was_updated = signal("app-published-workflow-was-updated")

# sender: app, kwargs: synced_draft_workflow
app_draft_workflow_was_synced = signal("app-draft-workflow-was-synced")
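These are plain blinker signals: the app model is always the sender, and the commented kwargs travel as keyword arguments. A minimal sketch of how one of them is typically emitted and consumed (the `app` and `account` objects below are illustrative stand-ins, not part of this file):

from events.app_event import app_was_created

# receiving side: importing a module containing this decorator registers the handler
@app_was_created.connect
def handle(sender, **kwargs):
    print(f"app {sender.id} was created")

# emitting side (e.g. an app-creation service): sender first, extra context as kwargs
app_was_created.send(app, account=account)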
api/events/dataset_event.py ADDED
@@ -0,0 +1,4 @@
from blinker import signal

# sender: dataset
dataset_was_deleted = signal("dataset-was-deleted")
api/events/document_event.py ADDED
@@ -0,0 +1,4 @@
from blinker import signal

# sender: document
document_was_deleted = signal("document-was-deleted")
api/events/event_handlers/__init__.py ADDED
@@ -0,0 +1,10 @@
from .clean_when_dataset_deleted import handle
from .clean_when_document_deleted import handle
from .create_document_index import handle
from .create_installed_app_when_app_created import handle
from .create_site_record_when_app_created import handle
from .deduct_quota_when_message_created import handle
from .delete_tool_parameters_cache_when_sync_draft_workflow import handle
from .update_app_dataset_join_when_app_model_config_updated import handle
from .update_app_dataset_join_when_app_published_workflow_updated import handle
from .update_provider_last_used_at_when_message_created import handle
api/events/event_handlers/clean_when_dataset_deleted.py ADDED
@@ -0,0 +1,15 @@
from events.dataset_event import dataset_was_deleted
from tasks.clean_dataset_task import clean_dataset_task


@dataset_was_deleted.connect
def handle(sender, **kwargs):
    dataset = sender
    clean_dataset_task.delay(
        dataset.id,
        dataset.tenant_id,
        dataset.indexing_technique,
        dataset.index_struct,
        dataset.collection_binding_id,
        dataset.doc_form,
    )
api/events/event_handlers/clean_when_document_deleted.py ADDED
@@ -0,0 +1,11 @@
from events.document_event import document_was_deleted
from tasks.clean_document_task import clean_document_task


@document_was_deleted.connect
def handle(sender, **kwargs):
    document_id = sender
    dataset_id = kwargs.get("dataset_id")
    doc_form = kwargs.get("doc_form")
    file_id = kwargs.get("file_id")
    clean_document_task.delay(document_id, dataset_id, doc_form, file_id)
api/events/event_handlers/create_document_index.py ADDED
@@ -0,0 +1,49 @@
import datetime
import logging
import time

import click
from werkzeug.exceptions import NotFound

from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from events.event_handlers.document_index_event import document_index_created
from extensions.ext_database import db
from models.dataset import Document


@document_index_created.connect
def handle(sender, **kwargs):
    dataset_id = sender
    document_ids = kwargs.get("document_ids", [])
    documents = []
    start_at = time.perf_counter()
    for document_id in document_ids:
        logging.info(click.style("Start process document: {}".format(document_id), fg="green"))

        document = (
            db.session.query(Document)
            .filter(
                Document.id == document_id,
                Document.dataset_id == dataset_id,
            )
            .first()
        )

        if not document:
            raise NotFound("Document not found")

        document.indexing_status = "parsing"
        document.processing_started_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
        documents.append(document)
        db.session.add(document)
    db.session.commit()

    try:
        indexing_runner = IndexingRunner()
        indexing_runner.run(documents)
        end_at = time.perf_counter()
        logging.info(click.style("Processed dataset: {} latency: {}".format(dataset_id, end_at - start_at), fg="green"))
    except DocumentIsPausedError as ex:
        logging.info(click.style(str(ex), fg="yellow"))
    except Exception:
        pass
api/events/event_handlers/create_installed_app_when_app_created.py ADDED
@@ -0,0 +1,16 @@
from events.app_event import app_was_created
from extensions.ext_database import db
from models.model import InstalledApp


@app_was_created.connect
def handle(sender, **kwargs):
    """Create an installed app when an app is created."""
    app = sender
    installed_app = InstalledApp(
        tenant_id=app.tenant_id,
        app_id=app.id,
        app_owner_tenant_id=app.tenant_id,
    )
    db.session.add(installed_app)
    db.session.commit()
api/events/event_handlers/create_site_record_when_app_created.py ADDED
@@ -0,0 +1,26 @@
from events.app_event import app_was_created
from extensions.ext_database import db
from models.model import Site


@app_was_created.connect
def handle(sender, **kwargs):
    """Create site record when an app is created."""
    app = sender
    account = kwargs.get("account")
    if account is not None:
        site = Site(
            app_id=app.id,
            title=app.name,
            icon_type=app.icon_type,
            icon=app.icon,
            icon_background=app.icon_background,
            default_language=account.interface_language,
            customize_token_strategy="not_allow",
            code=Site.generate_code(16),
            created_by=app.created_by,
            updated_by=app.updated_by,
        )

        db.session.add(site)
        db.session.commit()
api/events/event_handlers/deduct_quota_when_message_created.py ADDED
@@ -0,0 +1,53 @@
from configs import dify_config
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ChatAppGenerateEntity
from core.entities.provider_entities import QuotaUnit
from events.message_event import message_was_created
from extensions.ext_database import db
from models.provider import Provider, ProviderType


@message_was_created.connect
def handle(sender, **kwargs):
    message = sender
    application_generate_entity = kwargs.get("application_generate_entity")

    if not isinstance(application_generate_entity, ChatAppGenerateEntity | AgentChatAppGenerateEntity):
        return

    model_config = application_generate_entity.model_conf
    provider_model_bundle = model_config.provider_model_bundle
    provider_configuration = provider_model_bundle.configuration

    if provider_configuration.using_provider_type != ProviderType.SYSTEM:
        return

    system_configuration = provider_configuration.system_configuration

    quota_unit = None
    for quota_configuration in system_configuration.quota_configurations:
        if quota_configuration.quota_type == system_configuration.current_quota_type:
            quota_unit = quota_configuration.quota_unit

            if quota_configuration.quota_limit == -1:
                return

            break

    used_quota = None
    if quota_unit:
        if quota_unit == QuotaUnit.TOKENS:
            used_quota = message.message_tokens + message.answer_tokens
        elif quota_unit == QuotaUnit.CREDITS:
            used_quota = dify_config.get_model_credits(model_config.model)
        else:
            used_quota = 1

    if used_quota is not None and system_configuration.current_quota_type is not None:
        db.session.query(Provider).filter(
            Provider.tenant_id == application_generate_entity.app_config.tenant_id,
            Provider.provider_name == model_config.provider,
            Provider.provider_type == ProviderType.SYSTEM.value,
            Provider.quota_type == system_configuration.current_quota_type.value,
            Provider.quota_limit > Provider.quota_used,
        ).update({"quota_used": Provider.quota_used + used_quota})
        db.session.commit()
api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py ADDED
@@ -0,0 +1,34 @@
from core.tools.tool_manager import ToolManager
from core.tools.utils.configuration import ToolParameterConfigurationManager
from core.workflow.nodes import NodeType
from core.workflow.nodes.tool.entities import ToolEntity
from events.app_event import app_draft_workflow_was_synced


@app_draft_workflow_was_synced.connect
def handle(sender, **kwargs):
    app = sender
    synced_draft_workflow = kwargs.get("synced_draft_workflow")
    if synced_draft_workflow is None:
        return
    for node_data in synced_draft_workflow.graph_dict.get("nodes", []):
        if node_data.get("data", {}).get("type") == NodeType.TOOL.value:
            try:
                tool_entity = ToolEntity(**node_data["data"])
                tool_runtime = ToolManager.get_tool_runtime(
                    provider_type=tool_entity.provider_type,
                    provider_id=tool_entity.provider_id,
                    tool_name=tool_entity.tool_name,
                    tenant_id=app.tenant_id,
                )
                manager = ToolParameterConfigurationManager(
                    tenant_id=app.tenant_id,
                    tool_runtime=tool_runtime,
                    provider_name=tool_entity.provider_name,
                    provider_type=tool_entity.provider_type,
                    identity_id=f"WORKFLOW.{app.id}.{node_data.get('id')}",
                )
                manager.delete_tool_parameters_cache()
            except Exception:
                # tool does not exist
                pass
api/events/event_handlers/document_index_event.py ADDED
@@ -0,0 +1,4 @@
from blinker import signal

# sender: document
document_index_created = signal("document-index-created")
api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py ADDED
@@ -0,0 +1,68 @@
from events.app_event import app_model_config_was_updated
from extensions.ext_database import db
from models.dataset import AppDatasetJoin
from models.model import AppModelConfig


@app_model_config_was_updated.connect
def handle(sender, **kwargs):
    app = sender
    app_model_config = kwargs.get("app_model_config")
    if app_model_config is None:
        return

    dataset_ids = get_dataset_ids_from_model_config(app_model_config)

    app_dataset_joins = db.session.query(AppDatasetJoin).filter(AppDatasetJoin.app_id == app.id).all()

    removed_dataset_ids: set[str] = set()
    if not app_dataset_joins:
        added_dataset_ids = dataset_ids
    else:
        old_dataset_ids: set[str] = set()
        old_dataset_ids.update(app_dataset_join.dataset_id for app_dataset_join in app_dataset_joins)

        added_dataset_ids = dataset_ids - old_dataset_ids
        removed_dataset_ids = old_dataset_ids - dataset_ids

    if removed_dataset_ids:
        for dataset_id in removed_dataset_ids:
            db.session.query(AppDatasetJoin).filter(
                AppDatasetJoin.app_id == app.id, AppDatasetJoin.dataset_id == dataset_id
            ).delete()

    if added_dataset_ids:
        for dataset_id in added_dataset_ids:
            app_dataset_join = AppDatasetJoin(app_id=app.id, dataset_id=dataset_id)
            db.session.add(app_dataset_join)

    db.session.commit()


def get_dataset_ids_from_model_config(app_model_config: AppModelConfig) -> set[str]:
    dataset_ids: set[str] = set()
    if not app_model_config:
        return dataset_ids

    agent_mode = app_model_config.agent_mode_dict

    tools = agent_mode.get("tools", []) or []
    for tool in tools:
        if len(list(tool.keys())) != 1:
            continue

        tool_type = list(tool.keys())[0]
        tool_config = list(tool.values())[0]
        if tool_type == "dataset":
            dataset_ids.add(tool_config.get("id"))

    # get dataset from dataset_configs
    dataset_configs = app_model_config.dataset_configs_dict
    datasets = dataset_configs.get("datasets", {}) or {}
    for dataset in datasets.get("datasets", []) or []:
        keys = list(dataset.keys())
        if len(keys) == 1 and keys[0] == "dataset":
            if dataset["dataset"].get("id"):
                dataset_ids.add(dataset["dataset"].get("id"))

    return dataset_ids
api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py ADDED
@@ -0,0 +1,67 @@
from typing import cast

from core.workflow.nodes import NodeType
from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData
from events.app_event import app_published_workflow_was_updated
from extensions.ext_database import db
from models.dataset import AppDatasetJoin
from models.workflow import Workflow


@app_published_workflow_was_updated.connect
def handle(sender, **kwargs):
    app = sender
    published_workflow = kwargs.get("published_workflow")
    published_workflow = cast(Workflow, published_workflow)

    dataset_ids = get_dataset_ids_from_workflow(published_workflow)
    app_dataset_joins = db.session.query(AppDatasetJoin).filter(AppDatasetJoin.app_id == app.id).all()

    removed_dataset_ids: set[str] = set()
    if not app_dataset_joins:
        added_dataset_ids = dataset_ids
    else:
        old_dataset_ids: set[str] = set()
        old_dataset_ids.update(app_dataset_join.dataset_id for app_dataset_join in app_dataset_joins)

        added_dataset_ids = dataset_ids - old_dataset_ids
        removed_dataset_ids = old_dataset_ids - dataset_ids

    if removed_dataset_ids:
        for dataset_id in removed_dataset_ids:
            db.session.query(AppDatasetJoin).filter(
                AppDatasetJoin.app_id == app.id, AppDatasetJoin.dataset_id == dataset_id
            ).delete()

    if added_dataset_ids:
        for dataset_id in added_dataset_ids:
            app_dataset_join = AppDatasetJoin(app_id=app.id, dataset_id=dataset_id)
            db.session.add(app_dataset_join)

    db.session.commit()


def get_dataset_ids_from_workflow(published_workflow: Workflow) -> set[str]:
    dataset_ids: set[str] = set()
    graph = published_workflow.graph_dict
    if not graph:
        return dataset_ids

    nodes = graph.get("nodes", [])

    # fetch all knowledge retrieval nodes
    knowledge_retrieval_nodes = [
        node for node in nodes if node.get("data", {}).get("type") == NodeType.KNOWLEDGE_RETRIEVAL.value
    ]

    if not knowledge_retrieval_nodes:
        return dataset_ids

    for node in knowledge_retrieval_nodes:
        try:
            node_data = KnowledgeRetrievalNodeData(**node.get("data", {}))
            dataset_ids.update(dataset_id for dataset_id in node_data.dataset_ids)
        except Exception:
            continue

    return dataset_ids
api/events/event_handlers/update_provider_last_used_at_when_message_created.py ADDED
@@ -0,0 +1,21 @@
from datetime import UTC, datetime

from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ChatAppGenerateEntity
from events.message_event import message_was_created
from extensions.ext_database import db
from models.provider import Provider


@message_was_created.connect
def handle(sender, **kwargs):
    message = sender
    application_generate_entity = kwargs.get("application_generate_entity")

    if not isinstance(application_generate_entity, ChatAppGenerateEntity | AgentChatAppGenerateEntity):
        return

    db.session.query(Provider).filter(
        Provider.tenant_id == application_generate_entity.app_config.tenant_id,
        Provider.provider_name == application_generate_entity.model_conf.provider,
    ).update({"last_used": datetime.now(UTC).replace(tzinfo=None)})
    db.session.commit()
api/events/message_event.py ADDED
@@ -0,0 +1,4 @@
from blinker import signal

# sender: message, kwargs: conversation
message_was_created = signal("message-was-created")
api/events/tenant_event.py ADDED
@@ -0,0 +1,7 @@
from blinker import signal

# sender: tenant
tenant_was_created = signal("tenant-was-created")

# sender: tenant
tenant_was_updated = signal("tenant-was-updated")
api/extensions/__init__.py ADDED
File without changes
api/extensions/ext_app_metrics.py ADDED
@@ -0,0 +1,67 @@
import json
import os
import threading

from flask import Response

from configs import dify_config
from dify_app import DifyApp


def init_app(app: DifyApp):
    @app.after_request
    def after_request(response):
        """Add Version headers to the response."""
        response.headers.add("X-Version", dify_config.CURRENT_VERSION)
        response.headers.add("X-Env", dify_config.DEPLOY_ENV)
        return response

    @app.route("/health")
    def health():
        return Response(
            json.dumps({"pid": os.getpid(), "status": "ok", "version": dify_config.CURRENT_VERSION}),
            status=200,
            content_type="application/json",
        )

    @app.route("/threads")
    def threads():
        num_threads = threading.active_count()
        threads = threading.enumerate()

        thread_list = []
        for thread in threads:
            thread_name = thread.name
            thread_id = thread.ident
            is_alive = thread.is_alive()

            thread_list.append(
                {
                    "name": thread_name,
                    "id": thread_id,
                    "is_alive": is_alive,
                }
            )

        return {
            "pid": os.getpid(),
            "thread_num": num_threads,
            "threads": thread_list,
        }

    @app.route("/db-pool-stat")
    def pool_stat():
        from extensions.ext_database import db

        engine = db.engine
        # TODO: Fix the type error
        # FIXME maybe its sqlalchemy issue
        return {
            "pid": os.getpid(),
            "pool_size": engine.pool.size(),  # type: ignore
            "checked_in_connections": engine.pool.checkedin(),  # type: ignore
            "checked_out_connections": engine.pool.checkedout(),  # type: ignore
            "overflow_connections": engine.pool.overflow(),  # type: ignore
            "connection_timeout": engine.pool.timeout(),  # type: ignore
            "recycle_time": db.engine.pool._recycle,  # type: ignore
        }
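A quick sketch of probing these diagnostic endpoints once the API is running; the host and port are assumptions taken from the defaults in the entrypoint script above, not something this file defines:

import requests

# /health returns pid, status and version as JSON; the X-Version header comes from after_request
resp = requests.get("http://localhost:5001/health", timeout=5)
print(resp.headers.get("X-Version"), resp.json()["status"])

# /threads and /db-pool-stat expose thread and SQLAlchemy pool diagnostics
print(requests.get("http://localhost:5001/db-pool-stat", timeout=5).json())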
api/extensions/ext_blueprints.py ADDED
@@ -0,0 +1,48 @@
from configs import dify_config
from dify_app import DifyApp


def init_app(app: DifyApp):
    # register blueprint routers

    from flask_cors import CORS  # type: ignore

    from controllers.console import bp as console_app_bp
    from controllers.files import bp as files_bp
    from controllers.inner_api import bp as inner_api_bp
    from controllers.service_api import bp as service_api_bp
    from controllers.web import bp as web_bp

    CORS(
        service_api_bp,
        allow_headers=["Content-Type", "Authorization", "X-App-Code"],
        methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
    )
    app.register_blueprint(service_api_bp)

    CORS(
        web_bp,
        resources={r"/*": {"origins": dify_config.WEB_API_CORS_ALLOW_ORIGINS}},
        supports_credentials=True,
        allow_headers=["Content-Type", "Authorization", "X-App-Code"],
        methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
        expose_headers=["X-Version", "X-Env"],
    )

    app.register_blueprint(web_bp)

    CORS(
        console_app_bp,
        resources={r"/*": {"origins": dify_config.CONSOLE_CORS_ALLOW_ORIGINS}},
        supports_credentials=True,
        allow_headers=["Content-Type", "Authorization"],
        methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
        expose_headers=["X-Version", "X-Env"],
    )

    app.register_blueprint(console_app_bp)

    CORS(files_bp, allow_headers=["Content-Type"], methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"])
    app.register_blueprint(files_bp)

    app.register_blueprint(inner_api_bp)
api/extensions/ext_celery.py ADDED
@@ -0,0 +1,104 @@
from datetime import timedelta

import pytz
from celery import Celery, Task  # type: ignore
from celery.schedules import crontab  # type: ignore

from configs import dify_config
from dify_app import DifyApp


def init_app(app: DifyApp) -> Celery:
    class FlaskTask(Task):
        def __call__(self, *args: object, **kwargs: object) -> object:
            with app.app_context():
                return self.run(*args, **kwargs)

    broker_transport_options = {}

    if dify_config.CELERY_USE_SENTINEL:
        broker_transport_options = {
            "master_name": dify_config.CELERY_SENTINEL_MASTER_NAME,
            "sentinel_kwargs": {
                "socket_timeout": dify_config.CELERY_SENTINEL_SOCKET_TIMEOUT,
            },
        }

    celery_app = Celery(
        app.name,
        task_cls=FlaskTask,
        broker=dify_config.CELERY_BROKER_URL,
        backend=dify_config.CELERY_BACKEND,
        task_ignore_result=True,
    )

    # Add SSL options to the Celery configuration
    ssl_options = {
        "ssl_cert_reqs": None,
        "ssl_ca_certs": None,
        "ssl_certfile": None,
        "ssl_keyfile": None,
    }

    celery_app.conf.update(
        result_backend=dify_config.CELERY_RESULT_BACKEND,
        broker_transport_options=broker_transport_options,
        broker_connection_retry_on_startup=True,
        worker_log_format=dify_config.LOG_FORMAT,
        worker_task_log_format=dify_config.LOG_FORMAT,
        worker_hijack_root_logger=False,
        timezone=pytz.timezone(dify_config.LOG_TZ or "UTC"),
    )

    if dify_config.BROKER_USE_SSL:
        celery_app.conf.update(
            broker_use_ssl=ssl_options,  # Add the SSL options to the broker configuration
        )

    if dify_config.LOG_FILE:
        celery_app.conf.update(
            worker_logfile=dify_config.LOG_FILE,
        )

    celery_app.set_default()
    app.extensions["celery"] = celery_app

    imports = [
        "schedule.clean_embedding_cache_task",
        "schedule.clean_unused_datasets_task",
        "schedule.create_tidb_serverless_task",
        "schedule.update_tidb_serverless_status_task",
        "schedule.clean_messages",
        "schedule.mail_clean_document_notify_task",
    ]
    day = dify_config.CELERY_BEAT_SCHEDULER_TIME
    beat_schedule = {
        "clean_embedding_cache_task": {
            "task": "schedule.clean_embedding_cache_task.clean_embedding_cache_task",
            "schedule": timedelta(days=day),
        },
        "clean_unused_datasets_task": {
            "task": "schedule.clean_unused_datasets_task.clean_unused_datasets_task",
            "schedule": timedelta(days=day),
        },
        "create_tidb_serverless_task": {
            "task": "schedule.create_tidb_serverless_task.create_tidb_serverless_task",
            "schedule": crontab(minute="0", hour="*"),
        },
        "update_tidb_serverless_status_task": {
            "task": "schedule.update_tidb_serverless_status_task.update_tidb_serverless_status_task",
            "schedule": timedelta(minutes=10),
        },
        "clean_messages": {
            "task": "schedule.clean_messages.clean_messages",
            "schedule": timedelta(days=day),
        },
        # every Monday
        "mail_clean_document_notify_task": {
            "task": "schedule.mail_clean_document_notify_task.mail_clean_document_notify_task",
            "schedule": crontab(minute="0", hour="10", day_of_week="1"),
        },
    }
    celery_app.conf.update(beat_schedule=beat_schedule, imports=imports)

    return celery_app
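The `FlaskTask` base class is what lets task code touch Flask extensions: every task body runs inside `app.app_context()`. A minimal sketch of a task that relies on this behaviour; the task module, queue choice and model query are illustrative, not part of this commit:

from celery import shared_task

from extensions.ext_database import db
from models.model import App


@shared_task(queue="dataset")
def count_apps_task() -> int:
    # db.session is usable here only because FlaskTask wraps run() in an app context
    return db.session.query(App).count()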
api/extensions/ext_code_based_extension.py ADDED
@@ -0,0 +1,9 @@
from core.extension.extension import Extension
from dify_app import DifyApp


def init_app(app: DifyApp):
    code_based_extension.init()


code_based_extension = Extension()
api/extensions/ext_commands.py ADDED
@@ -0,0 +1,29 @@
from dify_app import DifyApp


def init_app(app: DifyApp):
    from commands import (
        add_qdrant_doc_id_index,
        convert_to_agent_apps,
        create_tenant,
        fix_app_site_missing,
        reset_email,
        reset_encrypt_key_pair,
        reset_password,
        upgrade_db,
        vdb_migrate,
    )

    cmds_to_register = [
        reset_password,
        reset_email,
        reset_encrypt_key_pair,
        vdb_migrate,
        convert_to_agent_apps,
        add_qdrant_doc_id_index,
        create_tenant,
        upgrade_db,
        fix_app_site_missing,
    ]
    for cmd in cmds_to_register:
        app.cli.add_command(cmd)
api/extensions/ext_compress.py ADDED
@@ -0,0 +1,13 @@
from configs import dify_config
from dify_app import DifyApp


def is_enabled() -> bool:
    return dify_config.API_COMPRESSION_ENABLED


def init_app(app: DifyApp):
    from flask_compress import Compress  # type: ignore

    compress = Compress()
    compress.init_app(app)
api/extensions/ext_database.py ADDED
@@ -0,0 +1,6 @@
from dify_app import DifyApp
from models import db


def init_app(app: DifyApp):
    db.init_app(app)
api/extensions/ext_hosting_provider.py ADDED
@@ -0,0 +1,10 @@
from core.hosting_configuration import HostingConfiguration

hosting_configuration = HostingConfiguration()


from dify_app import DifyApp


def init_app(app: DifyApp):
    hosting_configuration.init_app(app)
api/extensions/ext_import_modules.py ADDED
@@ -0,0 +1,5 @@
from dify_app import DifyApp


def init_app(app: DifyApp):
    from events import event_handlers  # noqa: F401
api/extensions/ext_logging.py ADDED
@@ -0,0 +1,71 @@
import logging
import os
import sys
import uuid
from logging.handlers import RotatingFileHandler

import flask

from configs import dify_config
from dify_app import DifyApp


def init_app(app: DifyApp):
    log_handlers: list[logging.Handler] = []
    log_file = dify_config.LOG_FILE
    if log_file:
        log_dir = os.path.dirname(log_file)
        os.makedirs(log_dir, exist_ok=True)
        log_handlers.append(
            RotatingFileHandler(
                filename=log_file,
                maxBytes=dify_config.LOG_FILE_MAX_SIZE * 1024 * 1024,
                backupCount=dify_config.LOG_FILE_BACKUP_COUNT,
            )
        )

    # Always add StreamHandler to log to console
    sh = logging.StreamHandler(sys.stdout)
    sh.addFilter(RequestIdFilter())
    log_handlers.append(sh)

    logging.basicConfig(
        level=dify_config.LOG_LEVEL,
        format=dify_config.LOG_FORMAT,
        datefmt=dify_config.LOG_DATEFORMAT,
        handlers=log_handlers,
        force=True,
    )
    log_tz = dify_config.LOG_TZ
    if log_tz:
        from datetime import datetime

        import pytz

        timezone = pytz.timezone(log_tz)

        def time_converter(seconds):
            return datetime.fromtimestamp(seconds, tz=timezone).timetuple()

        for handler in logging.root.handlers:
            if handler.formatter:
                handler.formatter.converter = time_converter


def get_request_id():
    if getattr(flask.g, "request_id", None):
        return flask.g.request_id

    new_uuid = uuid.uuid4().hex[:10]
    flask.g.request_id = new_uuid

    return new_uuid


class RequestIdFilter(logging.Filter):
    # This is a logging filter that makes the request ID available for use in
    # the logging format. Note that we're checking if we're in a request
    # context, as we may want to log things before Flask is fully loaded.
    def filter(self, record):
        record.req_id = get_request_id() if flask.has_request_context() else ""
        return True
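`RequestIdFilter` only injects a `req_id` attribute on each record; it shows up in the output only if the configured format string references it. A hedged example of what `LOG_FORMAT` might look like to take advantage of that (the actual default lives in the configs package and is not part of this file):

# hypothetical .env / config value; %(req_id)s is the attribute added by RequestIdFilter
LOG_FORMAT = "%(asctime)s.%(msecs)03d %(levelname)s [%(threadName)s] [%(filename)s:%(lineno)d] %(req_id)s %(message)s"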
api/extensions/ext_login.py ADDED
@@ -0,0 +1,62 @@
import json

import flask_login  # type: ignore
from flask import Response, request
from flask_login import user_loaded_from_request, user_logged_in
from werkzeug.exceptions import Unauthorized

import contexts
from dify_app import DifyApp
from libs.passport import PassportService
from services.account_service import AccountService

login_manager = flask_login.LoginManager()


# Flask-Login configuration
@login_manager.request_loader
def load_user_from_request(request_from_flask_login):
    """Load user based on the request."""
    if request.blueprint not in {"console", "inner_api"}:
        return None
    # Check if the user_id contains a dot, indicating the old format
    auth_header = request.headers.get("Authorization", "")
    if not auth_header:
        auth_token = request.args.get("_token")
        if not auth_token:
            raise Unauthorized("Invalid Authorization token.")
    else:
        if " " not in auth_header:
            raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
        auth_scheme, auth_token = auth_header.split(None, 1)
        auth_scheme = auth_scheme.lower()
        if auth_scheme != "bearer":
            raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")

    decoded = PassportService().verify(auth_token)
    user_id = decoded.get("user_id")

    logged_in_account = AccountService.load_logged_in_account(account_id=user_id)
    return logged_in_account


@user_logged_in.connect
@user_loaded_from_request.connect
def on_user_logged_in(_sender, user):
    """Called when a user logged in."""
    if user:
        contexts.tenant_id.set(user.current_tenant_id)


@login_manager.unauthorized_handler
def unauthorized_handler():
    """Handle unauthorized requests."""
    return Response(
        json.dumps({"code": "unauthorized", "message": "Unauthorized."}),
        status=401,
        content_type="application/json",
    )


def init_app(app: DifyApp):
    login_manager.init_app(app)
api/extensions/ext_mail.py ADDED
@@ -0,0 +1,97 @@
import logging
from typing import Optional

from flask import Flask

from configs import dify_config
from dify_app import DifyApp


class Mail:
    def __init__(self):
        self._client = None
        self._default_send_from = None

    def is_inited(self) -> bool:
        return self._client is not None

    def init_app(self, app: Flask):
        mail_type = dify_config.MAIL_TYPE
        if not mail_type:
            logging.warning("MAIL_TYPE is not set")
            return

        if dify_config.MAIL_DEFAULT_SEND_FROM:
            self._default_send_from = dify_config.MAIL_DEFAULT_SEND_FROM

        match mail_type:
            case "resend":
                import resend  # type: ignore

                api_key = dify_config.RESEND_API_KEY
                if not api_key:
                    raise ValueError("RESEND_API_KEY is not set")

                api_url = dify_config.RESEND_API_URL
                if api_url:
                    resend.api_url = api_url

                resend.api_key = api_key
                self._client = resend.Emails
            case "smtp":
                from libs.smtp import SMTPClient

                if not dify_config.SMTP_SERVER or not dify_config.SMTP_PORT:
                    raise ValueError("SMTP_SERVER and SMTP_PORT are required for smtp mail type")
                if not dify_config.SMTP_USE_TLS and dify_config.SMTP_OPPORTUNISTIC_TLS:
                    raise ValueError("SMTP_OPPORTUNISTIC_TLS is not supported without enabling SMTP_USE_TLS")
                self._client = SMTPClient(
                    server=dify_config.SMTP_SERVER,
                    port=dify_config.SMTP_PORT,
                    username=dify_config.SMTP_USERNAME or "",
                    password=dify_config.SMTP_PASSWORD or "",
                    _from=dify_config.MAIL_DEFAULT_SEND_FROM or "",
                    use_tls=dify_config.SMTP_USE_TLS,
                    opportunistic_tls=dify_config.SMTP_OPPORTUNISTIC_TLS,
                )
            case _:
                raise ValueError("Unsupported mail type {}".format(mail_type))

    def send(self, to: str, subject: str, html: str, from_: Optional[str] = None):
        if not self._client:
            raise ValueError("Mail client is not initialized")

        if not from_ and self._default_send_from:
            from_ = self._default_send_from

        if not from_:
            raise ValueError("mail from is not set")

        if not to:
            raise ValueError("mail to is not set")

        if not subject:
            raise ValueError("mail subject is not set")

        if not html:
            raise ValueError("mail html is not set")

        self._client.send(
            {
                "from": from_,
                "to": to,
                "subject": subject,
                "html": html,
            }
        )


def is_enabled() -> bool:
    return dify_config.MAIL_TYPE is not None and dify_config.MAIL_TYPE != ""


def init_app(app: DifyApp):
    mail.init_app(app)


mail = Mail()
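Callers are expected to check `is_inited()` (or the module-level `is_enabled()`) before sending, since `send()` raises when no client was configured. A small usage sketch; the recipient, subject and body below are placeholders:

from extensions.ext_mail import mail

if mail.is_inited():
    mail.send(
        to="user@example.com",
        subject="Document cleanup notice",
        html="<p>Some of your documents were archived.</p>",
    )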
api/extensions/ext_migrate.py ADDED
@@ -0,0 +1,9 @@
from dify_app import DifyApp


def init_app(app: DifyApp):
    import flask_migrate  # type: ignore

    from extensions.ext_database import db

    flask_migrate.Migrate(app, db)
api/extensions/ext_proxy_fix.py ADDED
@@ -0,0 +1,9 @@
from configs import dify_config
from dify_app import DifyApp


def init_app(app: DifyApp):
    if dify_config.RESPECT_XFORWARD_HEADERS_ENABLED:
        from werkzeug.middleware.proxy_fix import ProxyFix

        app.wsgi_app = ProxyFix(app.wsgi_app, x_port=1)  # type: ignore
api/extensions/ext_redis.py ADDED
@@ -0,0 +1,98 @@
from typing import Any, Union

import redis
from redis.cluster import ClusterNode, RedisCluster
from redis.connection import Connection, SSLConnection
from redis.sentinel import Sentinel

from configs import dify_config
from dify_app import DifyApp


class RedisClientWrapper:
    """
    A wrapper class for the Redis client that addresses the issue where the global
    `redis_client` variable cannot be updated when a new Redis instance is returned
    by Sentinel.

    This class allows for deferred initialization of the Redis client, enabling the
    client to be re-initialized with a new instance when necessary. This is particularly
    useful in scenarios where the Redis instance may change dynamically, such as during
    a failover in a Sentinel-managed Redis setup.

    Attributes:
        _client (redis.Redis): The actual Redis client instance. It remains None until
            initialized with the `initialize` method.

    Methods:
        initialize(client): Initializes the Redis client if it hasn't been initialized already.
        __getattr__(item): Delegates attribute access to the Redis client, raising an error
            if the client is not initialized.
    """

    def __init__(self):
        self._client = None

    def initialize(self, client):
        if self._client is None:
            self._client = client

    def __getattr__(self, item):
        if self._client is None:
            raise RuntimeError("Redis client is not initialized. Call init_app first.")
        return getattr(self._client, item)


redis_client = RedisClientWrapper()


def init_app(app: DifyApp):
    global redis_client
    connection_class: type[Union[Connection, SSLConnection]] = Connection
    if dify_config.REDIS_USE_SSL:
        connection_class = SSLConnection

    redis_params: dict[str, Any] = {
        "username": dify_config.REDIS_USERNAME,
        "password": dify_config.REDIS_PASSWORD,
        "db": dify_config.REDIS_DB,
        "encoding": "utf-8",
        "encoding_errors": "strict",
        "decode_responses": False,
    }

    if dify_config.REDIS_USE_SENTINEL:
        assert dify_config.REDIS_SENTINELS is not None, "REDIS_SENTINELS must be set when REDIS_USE_SENTINEL is True"
        sentinel_hosts = [
            (node.split(":")[0], int(node.split(":")[1])) for node in dify_config.REDIS_SENTINELS.split(",")
        ]
        sentinel = Sentinel(
            sentinel_hosts,
            sentinel_kwargs={
                "socket_timeout": dify_config.REDIS_SENTINEL_SOCKET_TIMEOUT,
                "username": dify_config.REDIS_SENTINEL_USERNAME,
                "password": dify_config.REDIS_SENTINEL_PASSWORD,
            },
        )
        master = sentinel.master_for(dify_config.REDIS_SENTINEL_SERVICE_NAME, **redis_params)
        redis_client.initialize(master)
    elif dify_config.REDIS_USE_CLUSTERS:
        assert dify_config.REDIS_CLUSTERS is not None, "REDIS_CLUSTERS must be set when REDIS_USE_CLUSTERS is True"
        nodes = [
            ClusterNode(host=node.split(":")[0], port=int(node.split(":")[1]))
            for node in dify_config.REDIS_CLUSTERS.split(",")
        ]
        # FIXME: mypy error here, try to figure out how to fix it
        redis_client.initialize(RedisCluster(startup_nodes=nodes, password=dify_config.REDIS_CLUSTERS_PASSWORD))  # type: ignore
    else:
        redis_params.update(
            {
                "host": dify_config.REDIS_HOST,
                "port": dify_config.REDIS_PORT,
                "connection_class": connection_class,
            }
        )
        pool = redis.ConnectionPool(**redis_params)
        redis_client.initialize(redis.Redis(connection_pool=pool))

    app.extensions["redis"] = redis_client
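Because `__getattr__` proxies to the real client, application code uses `redis_client` exactly like a `redis.Redis` instance once `init_app` has run; before that, any call raises the `RuntimeError` above. A minimal sketch (the key name is illustrative):

from extensions.ext_redis import redis_client

redis_client.setex("example:cache-key", 300, "cached-value")
value = redis_client.get("example:cache-key")  # bytes, since decode_responses=False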
api/extensions/ext_sentry.py ADDED
@@ -0,0 +1,40 @@
from configs import dify_config
from dify_app import DifyApp


def init_app(app: DifyApp):
    if dify_config.SENTRY_DSN:
        import openai
        import sentry_sdk
        from langfuse import parse_error  # type: ignore
        from sentry_sdk.integrations.celery import CeleryIntegration
        from sentry_sdk.integrations.flask import FlaskIntegration
        from werkzeug.exceptions import HTTPException

        from core.model_runtime.errors.invoke import InvokeRateLimitError

        def before_send(event, hint):
            if "exc_info" in hint:
                exc_type, exc_value, tb = hint["exc_info"]
                if parse_error.defaultErrorResponse in str(exc_value):
                    return None

            return event

        sentry_sdk.init(
            dsn=dify_config.SENTRY_DSN,
            integrations=[FlaskIntegration(), CeleryIntegration()],
            ignore_errors=[
                HTTPException,
                ValueError,
                FileNotFoundError,
                openai.APIStatusError,
                InvokeRateLimitError,
                parse_error.defaultErrorResponse,
            ],
            traces_sample_rate=dify_config.SENTRY_TRACES_SAMPLE_RATE,
            profiles_sample_rate=dify_config.SENTRY_PROFILES_SAMPLE_RATE,
            environment=dify_config.DEPLOY_ENV,
            release=f"dify-{dify_config.CURRENT_VERSION}-{dify_config.COMMIT_SHA}",
            before_send=before_send,
        )
api/extensions/ext_set_secretkey.py ADDED
@@ -0,0 +1,6 @@
from configs import dify_config
from dify_app import DifyApp


def init_app(app: DifyApp):
    app.secret_key = dify_config.SECRET_KEY
api/extensions/ext_storage.py ADDED
@@ -0,0 +1,138 @@
import logging
from collections.abc import Callable, Generator
from typing import Literal, Union, overload

from flask import Flask

from configs import dify_config
from dify_app import DifyApp
from extensions.storage.base_storage import BaseStorage
from extensions.storage.storage_type import StorageType

logger = logging.getLogger(__name__)


class Storage:
    def init_app(self, app: Flask):
        storage_factory = self.get_storage_factory(dify_config.STORAGE_TYPE)
        with app.app_context():
            self.storage_runner = storage_factory()

    @staticmethod
    def get_storage_factory(storage_type: str) -> Callable[[], BaseStorage]:
        match storage_type:
            case StorageType.S3:
                from extensions.storage.aws_s3_storage import AwsS3Storage

                return AwsS3Storage
            case StorageType.OPENDAL:
                from extensions.storage.opendal_storage import OpenDALStorage

                return lambda: OpenDALStorage(dify_config.OPENDAL_SCHEME)
            case StorageType.LOCAL:
                from extensions.storage.opendal_storage import OpenDALStorage

                return lambda: OpenDALStorage(scheme="fs", root=dify_config.STORAGE_LOCAL_PATH)
            case StorageType.AZURE_BLOB:
                from extensions.storage.azure_blob_storage import AzureBlobStorage

                return AzureBlobStorage
            case StorageType.ALIYUN_OSS:
                from extensions.storage.aliyun_oss_storage import AliyunOssStorage

                return AliyunOssStorage
            case StorageType.GOOGLE_STORAGE:
                from extensions.storage.google_cloud_storage import GoogleCloudStorage

                return GoogleCloudStorage
            case StorageType.TENCENT_COS:
                from extensions.storage.tencent_cos_storage import TencentCosStorage

                return TencentCosStorage
            case StorageType.OCI_STORAGE:
                from extensions.storage.oracle_oci_storage import OracleOCIStorage

                return OracleOCIStorage
            case StorageType.HUAWEI_OBS:
                from extensions.storage.huawei_obs_storage import HuaweiObsStorage

                return HuaweiObsStorage
            case StorageType.BAIDU_OBS:
                from extensions.storage.baidu_obs_storage import BaiduObsStorage

                return BaiduObsStorage
            case StorageType.VOLCENGINE_TOS:
                from extensions.storage.volcengine_tos_storage import VolcengineTosStorage

                return VolcengineTosStorage
            case StorageType.SUPBASE:
                from extensions.storage.supabase_storage import SupabaseStorage

                return SupabaseStorage
            case _:
                raise ValueError(f"unsupported storage type {storage_type}")

    def save(self, filename, data):
        try:
            self.storage_runner.save(filename, data)
        except Exception as e:
            logger.exception(f"Failed to save file {filename}")
            raise e

    @overload
    def load(self, filename: str, /, *, stream: Literal[False] = False) -> bytes: ...

    @overload
    def load(self, filename: str, /, *, stream: Literal[True]) -> Generator: ...

    def load(self, filename: str, /, *, stream: bool = False) -> Union[bytes, Generator]:
        try:
            if stream:
                return self.load_stream(filename)
            else:
                return self.load_once(filename)
        except Exception as e:
            logger.exception(f"Failed to load file {filename}")
            raise e

    def load_once(self, filename: str) -> bytes:
        try:
            return self.storage_runner.load_once(filename)
        except Exception as e:
            logger.exception(f"Failed to load_once file {filename}")
            raise e

    def load_stream(self, filename: str) -> Generator:
        try:
            return self.storage_runner.load_stream(filename)
        except Exception as e:
            logger.exception(f"Failed to load_stream file {filename}")
            raise e

    def download(self, filename, target_filepath):
        try:
            self.storage_runner.download(filename, target_filepath)
        except Exception as e:
            logger.exception(f"Failed to download file {filename}")
            raise e

    def exists(self, filename):
        try:
            return self.storage_runner.exists(filename)
        except Exception as e:
            logger.exception(f"Failed to check file exists {filename}")
            raise e

    def delete(self, filename):
        try:
            return self.storage_runner.delete(filename)
        except Exception as e:
            logger.exception(f"Failed to delete file {filename}")
            raise e


storage = Storage()


def init_app(app: DifyApp):
    storage.init_app(app)
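A usage sketch of the `storage` facade once `init_app` has selected a backend; the file name is illustrative, and failures follow the log-then-reraise behaviour above:

from extensions.ext_storage import storage

storage.save("upload_files/example.txt", b"hello")

data = storage.load("upload_files/example.txt")  # bytes, via load_once()

for chunk in storage.load("upload_files/example.txt", stream=True):  # generator, via load_stream()
    handle_chunk(chunk)  # handle_chunk is a placeholder for the caller's logic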
api/extensions/ext_timezone.py ADDED
@@ -0,0 +1,11 @@
import os
import time

from dify_app import DifyApp


def init_app(app: DifyApp):
    os.environ["TZ"] = "UTC"
    # windows platform not support tzset
    if hasattr(time, "tzset"):
        time.tzset()
api/extensions/ext_warnings.py ADDED
@@ -0,0 +1,7 @@
from dify_app import DifyApp


def init_app(app: DifyApp):
    import warnings

    warnings.simplefilter("ignore", ResourceWarning)
api/extensions/storage/aliyun_oss_storage.py ADDED
@@ -0,0 +1,54 @@
import posixpath
from collections.abc import Generator

import oss2 as aliyun_s3  # type: ignore

from configs import dify_config
from extensions.storage.base_storage import BaseStorage


class AliyunOssStorage(BaseStorage):
    """Implementation for Aliyun OSS storage."""

    def __init__(self):
        super().__init__()
        self.bucket_name = dify_config.ALIYUN_OSS_BUCKET_NAME
        self.folder = dify_config.ALIYUN_OSS_PATH
        oss_auth_method = aliyun_s3.Auth
        region = None
        if dify_config.ALIYUN_OSS_AUTH_VERSION == "v4":
            oss_auth_method = aliyun_s3.AuthV4
            region = dify_config.ALIYUN_OSS_REGION
        oss_auth = oss_auth_method(dify_config.ALIYUN_OSS_ACCESS_KEY, dify_config.ALIYUN_OSS_SECRET_KEY)
        self.client = aliyun_s3.Bucket(
            oss_auth,
            dify_config.ALIYUN_OSS_ENDPOINT,
            self.bucket_name,
            connect_timeout=30,
            region=region,
        )

    def save(self, filename, data):
        self.client.put_object(self.__wrapper_folder_filename(filename), data)

    def load_once(self, filename: str) -> bytes:
        obj = self.client.get_object(self.__wrapper_folder_filename(filename))
        data: bytes = obj.read()
        return data

    def load_stream(self, filename: str) -> Generator:
        obj = self.client.get_object(self.__wrapper_folder_filename(filename))
        while chunk := obj.read(4096):
            yield chunk

    def download(self, filename: str, target_filepath):
        self.client.get_object_to_file(self.__wrapper_folder_filename(filename), target_filepath)

    def exists(self, filename: str):
        return self.client.object_exists(self.__wrapper_folder_filename(filename))

    def delete(self, filename: str):
        self.client.delete_object(self.__wrapper_folder_filename(filename))

    def __wrapper_folder_filename(self, filename: str) -> str:
        return posixpath.join(self.folder, filename) if self.folder else filename
api/extensions/storage/aws_s3_storage.py ADDED
@@ -0,0 +1,91 @@
import logging
from collections.abc import Generator

import boto3  # type: ignore
from botocore.client import Config  # type: ignore
from botocore.exceptions import ClientError  # type: ignore

from configs import dify_config
from extensions.storage.base_storage import BaseStorage

logger = logging.getLogger(__name__)


class AwsS3Storage(BaseStorage):
    """Implementation for Amazon Web Services S3 storage."""

    def __init__(self):
        super().__init__()
        self.bucket_name = dify_config.S3_BUCKET_NAME
        if dify_config.S3_USE_AWS_MANAGED_IAM:
            logger.info("Using AWS managed IAM role for S3")

            session = boto3.Session()
            region_name = dify_config.S3_REGION
            self.client = session.client(service_name="s3", region_name=region_name)
        else:
            logger.info("Using ak and sk for S3")

            self.client = boto3.client(
                "s3",
                aws_secret_access_key=dify_config.S3_SECRET_KEY,
                aws_access_key_id=dify_config.S3_ACCESS_KEY,
                endpoint_url=dify_config.S3_ENDPOINT,
                region_name=dify_config.S3_REGION,
                config=Config(
                    s3={"addressing_style": dify_config.S3_ADDRESS_STYLE},
                    request_checksum_calculation="when_required",
                    response_checksum_validation="when_required",
                ),
            )
        # create bucket
        try:
            self.client.head_bucket(Bucket=self.bucket_name)
        except ClientError as e:
            # if bucket not exists, create it
            if e.response["Error"]["Code"] == "404":
                self.client.create_bucket(Bucket=self.bucket_name)
            # if bucket is not accessible, pass, maybe the bucket is existing but not accessible
            elif e.response["Error"]["Code"] == "403":
                pass
            else:
                # other error, raise exception
                raise

    def save(self, filename, data):
        self.client.put_object(Bucket=self.bucket_name, Key=filename, Body=data)

    def load_once(self, filename: str) -> bytes:
        try:
            data: bytes = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].read()
        except ClientError as ex:
            if ex.response["Error"]["Code"] == "NoSuchKey":
                raise FileNotFoundError("File not found")
            else:
                raise
        return data

    def load_stream(self, filename: str) -> Generator:
        try:
            response = self.client.get_object(Bucket=self.bucket_name, Key=filename)
            yield from response["Body"].iter_chunks()
        except ClientError as ex:
            if ex.response["Error"]["Code"] == "NoSuchKey":
                raise FileNotFoundError("file not found")
            elif "reached max retries" in str(ex):
                raise ValueError("please do not request the same file too frequently")
            else:
                raise

    def download(self, filename, target_filepath):
        self.client.download_file(self.bucket_name, filename, target_filepath)

    def exists(self, filename):
        try:
            self.client.head_object(Bucket=self.bucket_name, Key=filename)
            return True
        except Exception:
            return False

    def delete(self, filename):
        self.client.delete_object(Bucket=self.bucket_name, Key=filename)
api/extensions/storage/azure_blob_storage.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections.abc import Generator
2
+ from datetime import UTC, datetime, timedelta
3
+ from typing import Optional
4
+
5
+ from azure.identity import ChainedTokenCredential, DefaultAzureCredential
6
+ from azure.storage.blob import AccountSasPermissions, BlobServiceClient, ResourceTypes, generate_account_sas
7
+
8
+ from configs import dify_config
9
+ from extensions.ext_redis import redis_client
10
+ from extensions.storage.base_storage import BaseStorage
11
+
12
+
13
+ class AzureBlobStorage(BaseStorage):
14
+ """Implementation for Azure Blob storage."""
15
+
16
+ def __init__(self):
17
+ super().__init__()
18
+ self.bucket_name = dify_config.AZURE_BLOB_CONTAINER_NAME
19
+ self.account_url = dify_config.AZURE_BLOB_ACCOUNT_URL
20
+ self.account_name = dify_config.AZURE_BLOB_ACCOUNT_NAME
21
+ self.account_key = dify_config.AZURE_BLOB_ACCOUNT_KEY
22
+
23
+ self.credential: Optional[ChainedTokenCredential] = None
24
+ if self.account_key == "managedidentity":
25
+ self.credential = DefaultAzureCredential()
26
+ else:
27
+ self.credential = None
28
+
29
+ def save(self, filename, data):
30
+ client = self._sync_client()
31
+ blob_container = client.get_container_client(container=self.bucket_name)
32
+ blob_container.upload_blob(filename, data)
33
+
34
+ def load_once(self, filename: str) -> bytes:
35
+ client = self._sync_client()
36
+ blob = client.get_container_client(container=self.bucket_name)
37
+ blob = blob.get_blob_client(blob=filename)
38
+ data: bytes = blob.download_blob().readall()
39
+ return data
40
+
41
+ def load_stream(self, filename: str) -> Generator:
42
+ client = self._sync_client()
43
+ blob = client.get_blob_client(container=self.bucket_name, blob=filename)
44
+ blob_data = blob.download_blob()
45
+ yield from blob_data.chunks()
46
+
47
+ def download(self, filename, target_filepath):
48
+ client = self._sync_client()
49
+
50
+ blob = client.get_blob_client(container=self.bucket_name, blob=filename)
51
+ with open(target_filepath, "wb") as my_blob:
52
+ blob_data = blob.download_blob()
53
+ blob_data.readinto(my_blob)
54
+
55
+ def exists(self, filename):
56
+ client = self._sync_client()
57
+
58
+ blob = client.get_blob_client(container=self.bucket_name, blob=filename)
59
+ return blob.exists()
60
+
61
+ def delete(self, filename):
62
+ client = self._sync_client()
63
+
64
+ blob_container = client.get_container_client(container=self.bucket_name)
65
+ blob_container.delete_blob(filename)
66
+
67
+ def _sync_client(self):
68
+ if self.account_key == "managedidentity":
69
+ return BlobServiceClient(account_url=self.account_url, credential=self.credential) # type: ignore
70
+
71
+ cache_key = "azure_blob_sas_token_{}_{}".format(self.account_name, self.account_key)
72
+ cache_result = redis_client.get(cache_key)
73
+ if cache_result is not None:
74
+ sas_token = cache_result.decode("utf-8")
75
+ else:
76
+ sas_token = generate_account_sas(
77
+ account_name=self.account_name or "",
78
+ account_key=self.account_key or "",
79
+ resource_types=ResourceTypes(service=True, container=True, object=True),
80
+ permission=AccountSasPermissions(read=True, write=True, delete=True, list=True, add=True, create=True),
81
+ expiry=datetime.now(UTC).replace(tzinfo=None) + timedelta(hours=1),
82
+ )
83
+ redis_client.set(cache_key, sas_token, ex=3000)
84
+ return BlobServiceClient(account_url=self.account_url or "", credential=sas_token)
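The SAS token above is issued with a one-hour expiry but cached in Redis for only 3000 seconds, so a cached token still has roughly ten minutes of validity when reused. A sketch of that margin, with both TTLs copied from the code above:

    SAS_EXPIRY_SECONDS = 3600   # generate_account_sas(..., expiry=now + timedelta(hours=1))
    REDIS_CACHE_TTL = 3000      # redis_client.set(cache_key, sas_token, ex=3000)
    assert REDIS_CACHE_TTL < SAS_EXPIRY_SECONDS  # the cache entry expires before the token does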
api/extensions/storage/baidu_obs_storage.py ADDED
@@ -0,0 +1,57 @@
1
+ import base64
2
+ import hashlib
3
+ from collections.abc import Generator
4
+
5
+ from baidubce.auth.bce_credentials import BceCredentials # type: ignore
6
+ from baidubce.bce_client_configuration import BceClientConfiguration # type: ignore
7
+ from baidubce.services.bos.bos_client import BosClient # type: ignore
8
+
9
+ from configs import dify_config
10
+ from extensions.storage.base_storage import BaseStorage
11
+
12
+
13
+ class BaiduObsStorage(BaseStorage):
14
+ """Implementation for Baidu OBS storage."""
15
+
16
+ def __init__(self):
17
+ super().__init__()
18
+ self.bucket_name = dify_config.BAIDU_OBS_BUCKET_NAME
19
+ client_config = BceClientConfiguration(
20
+ credentials=BceCredentials(
21
+ access_key_id=dify_config.BAIDU_OBS_ACCESS_KEY,
22
+ secret_access_key=dify_config.BAIDU_OBS_SECRET_KEY,
23
+ ),
24
+ endpoint=dify_config.BAIDU_OBS_ENDPOINT,
25
+ )
26
+
27
+ self.client = BosClient(config=client_config)
28
+
29
+ def save(self, filename, data):
30
+ md5 = hashlib.md5()
31
+ md5.update(data)
32
+ content_md5 = base64.standard_b64encode(md5.digest())
33
+ self.client.put_object(
34
+ bucket_name=self.bucket_name, key=filename, data=data, content_length=len(data), content_md5=content_md5
35
+ )
36
+
37
+ def load_once(self, filename: str) -> bytes:
38
+ response = self.client.get_object(bucket_name=self.bucket_name, key=filename)
39
+ data: bytes = response.data.read()
40
+ return data
41
+
42
+ def load_stream(self, filename: str) -> Generator:
43
+ response = self.client.get_object(bucket_name=self.bucket_name, key=filename).data
44
+ while chunk := response.read(4096):
45
+ yield chunk
46
+
47
+ def download(self, filename, target_filepath):
48
+ self.client.get_object_to_file(bucket_name=self.bucket_name, key=filename, file_name=target_filepath)
49
+
50
+ def exists(self, filename):
51
+ res = self.client.get_object_meta_data(bucket_name=self.bucket_name, key=filename)
52
+ if res is None:
53
+ return False
54
+ return True
55
+
56
+ def delete(self, filename):
57
+ self.client.delete_object(bucket_name=self.bucket_name, key=filename)
api/extensions/storage/base_storage.py ADDED
@@ -0,0 +1,32 @@
1
+ """Abstract interface for file storage implementations."""
2
+
3
+ from abc import ABC, abstractmethod
4
+ from collections.abc import Generator
5
+
6
+
7
+ class BaseStorage(ABC):
8
+ """Interface for file storage."""
9
+
10
+ @abstractmethod
11
+ def save(self, filename, data):
12
+ raise NotImplementedError
13
+
14
+ @abstractmethod
15
+ def load_once(self, filename: str) -> bytes:
16
+ raise NotImplementedError
17
+
18
+ @abstractmethod
19
+ def load_stream(self, filename: str) -> Generator:
20
+ raise NotImplementedError
21
+
22
+ @abstractmethod
23
+ def download(self, filename, target_filepath):
24
+ raise NotImplementedError
25
+
26
+ @abstractmethod
27
+ def exists(self, filename):
28
+ raise NotImplementedError
29
+
30
+ @abstractmethod
31
+ def delete(self, filename):
32
+ raise NotImplementedError
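For unit tests it can help to see the contract satisfied by a trivial backend; the in-memory implementation below is an illustrative sketch only (not part of this commit) and assumes the imports from base_storage.py:

    class InMemoryStorage(BaseStorage):
        """Dict-backed storage, useful for tests."""

        def __init__(self):
            self._files: dict[str, bytes] = {}

        def save(self, filename, data):
            self._files[filename] = data

        def load_once(self, filename: str) -> bytes:
            if filename not in self._files:
                raise FileNotFoundError("File not found")
            return self._files[filename]

        def load_stream(self, filename: str) -> Generator:
            yield self.load_once(filename)

        def download(self, filename, target_filepath):
            with open(target_filepath, "wb") as f:
                f.write(self.load_once(filename))

        def exists(self, filename):
            return filename in self._files

        def delete(self, filename):
            self._files.pop(filename, None)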
api/extensions/storage/google_cloud_storage.py ADDED
@@ -0,0 +1,60 @@
1
+ import base64
2
+ import io
3
+ import json
4
+ from collections.abc import Generator
5
+
6
+ from google.cloud import storage as google_cloud_storage # type: ignore
7
+
8
+ from configs import dify_config
9
+ from extensions.storage.base_storage import BaseStorage
10
+
11
+
12
+ class GoogleCloudStorage(BaseStorage):
13
+ """Implementation for Google Cloud storage."""
14
+
15
+ def __init__(self):
16
+ super().__init__()
17
+
18
+ self.bucket_name = dify_config.GOOGLE_STORAGE_BUCKET_NAME
19
+ service_account_json_str = dify_config.GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64
20
+ # if service_account_json_str is empty, use Application Default Credentials
21
+ if service_account_json_str:
22
+ service_account_json = base64.b64decode(service_account_json_str).decode("utf-8")
23
+ # parse the decoded JSON string into a credentials dict
24
+ service_account_obj = json.loads(service_account_json)
25
+ self.client = google_cloud_storage.Client.from_service_account_info(service_account_obj)
26
+ else:
27
+ self.client = google_cloud_storage.Client()
28
+
29
+ def save(self, filename, data):
30
+ bucket = self.client.get_bucket(self.bucket_name)
31
+ blob = bucket.blob(filename)
32
+ with io.BytesIO(data) as stream:
33
+ blob.upload_from_file(stream)
34
+
35
+ def load_once(self, filename: str) -> bytes:
36
+ bucket = self.client.get_bucket(self.bucket_name)
37
+ blob = bucket.get_blob(filename)
38
+ data: bytes = blob.download_as_bytes()
39
+ return data
40
+
41
+ def load_stream(self, filename: str) -> Generator:
42
+ bucket = self.client.get_bucket(self.bucket_name)
43
+ blob = bucket.get_blob(filename)
44
+ with blob.open(mode="rb") as blob_stream:
45
+ while chunk := blob_stream.read(4096):
46
+ yield chunk
47
+
48
+ def download(self, filename, target_filepath):
49
+ bucket = self.client.get_bucket(self.bucket_name)
50
+ blob = bucket.get_blob(filename)
51
+ blob.download_to_filename(target_filepath)
52
+
53
+ def exists(self, filename):
54
+ bucket = self.client.get_bucket(self.bucket_name)
55
+ blob = bucket.blob(filename)
56
+ return blob.exists()
57
+
58
+ def delete(self, filename):
59
+ bucket = self.client.get_bucket(self.bucket_name)
60
+ bucket.delete_blob(filename)
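Because the service-account credentials are read base64-encoded and then decoded with base64.b64decode and json.loads above, the value for GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64 can be produced from a key file like this (the file path is illustrative):

    import base64
    from pathlib import Path

    key_json = Path("service-account.json").read_bytes()
    print(base64.b64encode(key_json).decode())  # paste into GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64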
api/extensions/storage/huawei_obs_storage.py ADDED
@@ -0,0 +1,51 @@
1
+ from collections.abc import Generator
2
+
3
+ from obs import ObsClient # type: ignore
4
+
5
+ from configs import dify_config
6
+ from extensions.storage.base_storage import BaseStorage
7
+
8
+
9
+ class HuaweiObsStorage(BaseStorage):
10
+ """Implementation for Huawei OBS storage."""
11
+
12
+ def __init__(self):
13
+ super().__init__()
14
+
15
+ self.bucket_name = dify_config.HUAWEI_OBS_BUCKET_NAME
16
+ self.client = ObsClient(
17
+ access_key_id=dify_config.HUAWEI_OBS_ACCESS_KEY,
18
+ secret_access_key=dify_config.HUAWEI_OBS_SECRET_KEY,
19
+ server=dify_config.HUAWEI_OBS_SERVER,
20
+ )
21
+
22
+ def save(self, filename, data):
23
+ self.client.putObject(bucketName=self.bucket_name, objectKey=filename, content=data)
24
+
25
+ def load_once(self, filename: str) -> bytes:
26
+ data: bytes = self.client.getObject(bucketName=self.bucket_name, objectKey=filename)["body"].response.read()
27
+ return data
28
+
29
+ def load_stream(self, filename: str) -> Generator:
30
+ response = self.client.getObject(bucketName=self.bucket_name, objectKey=filename)["body"].response
31
+ while chunk := response.read(4096):
32
+ yield chunk
33
+
34
+ def download(self, filename, target_filepath):
35
+ self.client.getObject(bucketName=self.bucket_name, objectKey=filename, downloadPath=target_filepath)
36
+
37
+ def exists(self, filename):
38
+ res = self._get_meta(filename)
39
+ if res is None:
40
+ return False
41
+ return True
42
+
43
+ def delete(self, filename):
44
+ self.client.deleteObject(bucketName=self.bucket_name, objectKey=filename)
45
+
46
+ def _get_meta(self, filename):
47
+ res = self.client.getObjectMetadata(bucketName=self.bucket_name, objectKey=filename)
48
+ if res.status < 300:
49
+ return res
50
+ else:
51
+ return None
api/extensions/storage/opendal_storage.py ADDED
@@ -0,0 +1,89 @@
1
+ import logging
2
+ import os
3
+ from collections.abc import Generator
4
+ from pathlib import Path
5
+
6
+ import opendal # type: ignore[import]
7
+ from dotenv import dotenv_values
8
+
9
+ from extensions.storage.base_storage import BaseStorage
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+
14
+ def _get_opendal_kwargs(*, scheme: str, env_file_path: str = ".env", prefix: str = "OPENDAL_"):
15
+ kwargs = {}
16
+ config_prefix = prefix + scheme.upper() + "_"
17
+ for key, value in os.environ.items():
18
+ if key.startswith(config_prefix):
19
+ kwargs[key[len(config_prefix) :].lower()] = value
20
+
21
+ file_env_vars: dict = dotenv_values(env_file_path) or {}
22
+ for key, value in file_env_vars.items():
23
+ if key.startswith(config_prefix) and key[len(config_prefix) :].lower() not in kwargs and value:
24
+ kwargs[key[len(config_prefix) :].lower()] = value
25
+
26
+ return kwargs
27
+
28
+
29
+ class OpenDALStorage(BaseStorage):
30
+ def __init__(self, scheme: str, **kwargs):
31
+ kwargs = kwargs or _get_opendal_kwargs(scheme=scheme)
32
+
33
+ if scheme == "fs":
34
+ root = kwargs.get("root", "storage")
35
+ Path(root).mkdir(parents=True, exist_ok=True)
36
+
37
+ self.op = opendal.Operator(scheme=scheme, **kwargs)
38
+ logger.debug(f"opendal operator created with scheme {scheme}")
39
+ retry_layer = opendal.layers.RetryLayer(max_times=3, factor=2.0, jitter=True)
40
+ self.op = self.op.layer(retry_layer)
41
+ logger.debug("added retry layer to opendal operator")
42
+
43
+ def save(self, filename: str, data: bytes) -> None:
44
+ self.op.write(path=filename, bs=data)
45
+ logger.debug(f"file {filename} saved")
46
+
47
+ def load_once(self, filename: str) -> bytes:
48
+ if not self.exists(filename):
49
+ raise FileNotFoundError("File not found")
50
+
51
+ content: bytes = self.op.read(path=filename)
52
+ logger.debug(f"file {filename} loaded")
53
+ return content
54
+
55
+ def load_stream(self, filename: str) -> Generator:
56
+ if not self.exists(filename):
57
+ raise FileNotFoundError("File not found")
58
+
59
+ batch_size = 4096
60
+ file = self.op.open(path=filename, mode="rb")
61
+ while chunk := file.read(batch_size):
62
+ yield chunk
63
+ logger.debug(f"file {filename} loaded as stream")
64
+
65
+ def download(self, filename: str, target_filepath: str):
66
+ if not self.exists(filename):
67
+ raise FileNotFoundError("File not found")
68
+
69
+ with Path(target_filepath).open("wb") as f:
70
+ f.write(self.op.read(path=filename))
71
+ logger.debug(f"file {filename} downloaded to {target_filepath}")
72
+
73
+ def exists(self, filename: str) -> bool:
74
+ # FIXME: workaround, since the opendal Python binding does not yet expose an `exists` method and there is
75
+ # no better error handling available here; once the binding provides `exists`, switch to it.
76
+ # See https://github.com/apache/opendal/blob/main/bindings/python/src/operator.rs
77
+ try:
78
+ res: bool = self.op.stat(path=filename).mode.is_file()
79
+ logger.debug(f"file {filename} checked")
80
+ return res
81
+ except Exception:
82
+ return False
83
+
84
+ def delete(self, filename: str):
85
+ if self.exists(filename):
86
+ self.op.delete(path=filename)
87
+ logger.debug(f"file {filename} deleted")
88
+ return
89
+ logger.debug(f"file {filename} not found, skip delete")
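_get_opendal_kwargs above strips the OPENDAL_<SCHEME>_ prefix and lower-cases the remainder, so operator options come straight from the environment; the variable and path below are illustrative:

    # export OPENDAL_FS_ROOT=/var/dify/storage
    storage = OpenDALStorage(scheme="fs")
    # _get_opendal_kwargs(scheme="fs") returns {"root": "/var/dify/storage"}, so this is
    # equivalent to opendal.Operator("fs", root="/var/dify/storage") plus the retry layer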
api/extensions/storage/oracle_oci_storage.py ADDED
@@ -0,0 +1,59 @@
1
+ from collections.abc import Generator
2
+
3
+ import boto3 # type: ignore
4
+ from botocore.exceptions import ClientError # type: ignore
5
+
6
+ from configs import dify_config
7
+ from extensions.storage.base_storage import BaseStorage
8
+
9
+
10
+ class OracleOCIStorage(BaseStorage):
11
+ """Implementation for Oracle OCI storage."""
12
+
13
+ def __init__(self):
14
+ super().__init__()
15
+
16
+ self.bucket_name = dify_config.OCI_BUCKET_NAME
17
+ self.client = boto3.client(
18
+ "s3",
19
+ aws_secret_access_key=dify_config.OCI_SECRET_KEY,
20
+ aws_access_key_id=dify_config.OCI_ACCESS_KEY,
21
+ endpoint_url=dify_config.OCI_ENDPOINT,
22
+ region_name=dify_config.OCI_REGION,
23
+ )
24
+
25
+ def save(self, filename, data):
26
+ self.client.put_object(Bucket=self.bucket_name, Key=filename, Body=data)
27
+
28
+ def load_once(self, filename: str) -> bytes:
29
+ try:
30
+ data: bytes = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].read()
31
+ except ClientError as ex:
32
+ if ex.response["Error"]["Code"] == "NoSuchKey":
33
+ raise FileNotFoundError("File not found")
34
+ else:
35
+ raise
36
+ return data
37
+
38
+ def load_stream(self, filename: str) -> Generator:
39
+ try:
40
+ response = self.client.get_object(Bucket=self.bucket_name, Key=filename)
41
+ yield from response["Body"].iter_chunks()
42
+ except ClientError as ex:
43
+ if ex.response["Error"]["Code"] == "NoSuchKey":
44
+ raise FileNotFoundError("File not found")
45
+ else:
46
+ raise
47
+
48
+ def download(self, filename, target_filepath):
49
+ self.client.download_file(self.bucket_name, filename, target_filepath)
50
+
51
+ def exists(self, filename):
52
+ try:
53
+ self.client.head_object(Bucket=self.bucket_name, Key=filename)
54
+ return True
55
+ except Exception:
56
+ return False
57
+
58
+ def delete(self, filename):
59
+ self.client.delete_object(Bucket=self.bucket_name, Key=filename)
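OracleOCIStorage drives OCI Object Storage through its S3-compatibility API with boto3, differing from the S3 backend mainly in the dify_config keys it reads, in omitting the custom botocore Config, and in not attempting bucket creation. Assuming those keys come from environment variables of the same name, a typical placeholder configuration looks like:

    # OCI_ENDPOINT=https://<namespace>.compat.objectstorage.<region>.oraclecloud.com
    # OCI_REGION=<region>
    # OCI_BUCKET_NAME=<bucket>
    # OCI_ACCESS_KEY=<customer-secret-key-id>
    # OCI_SECRET_KEY=<customer-secret-key>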