Upload 697 files
This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
- .gitattributes +1 -0
- api/docker/entrypoint.sh +40 -0
- api/events/__init__.py +0 -0
- api/events/app_event.py +13 -0
- api/events/dataset_event.py +4 -0
- api/events/document_event.py +4 -0
- api/events/event_handlers/__init__.py +10 -0
- api/events/event_handlers/clean_when_dataset_deleted.py +15 -0
- api/events/event_handlers/clean_when_document_deleted.py +11 -0
- api/events/event_handlers/create_document_index.py +49 -0
- api/events/event_handlers/create_installed_app_when_app_created.py +16 -0
- api/events/event_handlers/create_site_record_when_app_created.py +26 -0
- api/events/event_handlers/deduct_quota_when_message_created.py +53 -0
- api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py +34 -0
- api/events/event_handlers/document_index_event.py +4 -0
- api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py +68 -0
- api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py +67 -0
- api/events/event_handlers/update_provider_last_used_at_when_message_created.py +21 -0
- api/events/message_event.py +4 -0
- api/events/tenant_event.py +7 -0
- api/extensions/__init__.py +0 -0
- api/extensions/ext_app_metrics.py +67 -0
- api/extensions/ext_blueprints.py +48 -0
- api/extensions/ext_celery.py +104 -0
- api/extensions/ext_code_based_extension.py +9 -0
- api/extensions/ext_commands.py +29 -0
- api/extensions/ext_compress.py +13 -0
- api/extensions/ext_database.py +6 -0
- api/extensions/ext_hosting_provider.py +10 -0
- api/extensions/ext_import_modules.py +5 -0
- api/extensions/ext_logging.py +71 -0
- api/extensions/ext_login.py +62 -0
- api/extensions/ext_mail.py +97 -0
- api/extensions/ext_migrate.py +9 -0
- api/extensions/ext_proxy_fix.py +9 -0
- api/extensions/ext_redis.py +98 -0
- api/extensions/ext_sentry.py +40 -0
- api/extensions/ext_set_secretkey.py +6 -0
- api/extensions/ext_storage.py +138 -0
- api/extensions/ext_timezone.py +11 -0
- api/extensions/ext_warnings.py +7 -0
- api/extensions/storage/aliyun_oss_storage.py +54 -0
- api/extensions/storage/aws_s3_storage.py +91 -0
- api/extensions/storage/azure_blob_storage.py +84 -0
- api/extensions/storage/baidu_obs_storage.py +57 -0
- api/extensions/storage/base_storage.py +32 -0
- api/extensions/storage/google_cloud_storage.py +60 -0
- api/extensions/storage/huawei_obs_storage.py +51 -0
- api/extensions/storage/opendal_storage.py +89 -0
- api/extensions/storage/oracle_oci_storage.py +59 -0
.gitattributes
CHANGED
@@ -5,3 +5,4 @@
 # them.
 
 *.sh text eol=lf
+api/tests/integration_tests/model_runtime/assets/audio.mp3 filter=lfs diff=lfs merge=lfs -text

api/docker/entrypoint.sh
ADDED
@@ -0,0 +1,40 @@
+#!/bin/bash
+
+set -e
+
+if [[ "${MIGRATION_ENABLED}" == "true" ]]; then
+  echo "Running migrations"
+  flask upgrade-db
+fi
+
+if [[ "${MODE}" == "worker" ]]; then
+
+  # Get the number of available CPU cores
+  if [ "${CELERY_AUTO_SCALE,,}" = "true" ]; then
+    # Set MAX_WORKERS to the number of available cores if not specified
+    AVAILABLE_CORES=$(nproc)
+    MAX_WORKERS=${CELERY_MAX_WORKERS:-$AVAILABLE_CORES}
+    MIN_WORKERS=${CELERY_MIN_WORKERS:-1}
+    CONCURRENCY_OPTION="--autoscale=${MAX_WORKERS},${MIN_WORKERS}"
+  else
+    CONCURRENCY_OPTION="-c ${CELERY_WORKER_AMOUNT:-1}"
+  fi
+
+  exec celery -A app.celery worker -P ${CELERY_WORKER_CLASS:-gevent} $CONCURRENCY_OPTION --loglevel ${LOG_LEVEL:-INFO} \
+    -Q ${CELERY_QUEUES:-dataset,mail,ops_trace,app_deletion}
+
+elif [[ "${MODE}" == "beat" ]]; then
+  exec celery -A app.celery beat --loglevel ${LOG_LEVEL:-INFO}
+else
+  if [[ "${DEBUG}" == "true" ]]; then
+    exec flask run --host=${DIFY_BIND_ADDRESS:-0.0.0.0} --port=${DIFY_PORT:-5001} --debug
+  else
+    exec gunicorn \
+      --bind "${DIFY_BIND_ADDRESS:-0.0.0.0}:${DIFY_PORT:-5001}" \
+      --workers ${SERVER_WORKER_AMOUNT:-1} \
+      --worker-class ${SERVER_WORKER_CLASS:-gevent} \
+      --worker-connections ${SERVER_WORKER_CONNECTIONS:-10} \
+      --timeout ${GUNICORN_TIMEOUT:-200} \
+      app:app
+  fi
+fi

api/events/__init__.py
ADDED
File without changes

api/events/app_event.py
ADDED
@@ -0,0 +1,13 @@
+from blinker import signal
+
+# sender: app
+app_was_created = signal("app-was-created")
+
+# sender: app, kwargs: app_model_config
+app_model_config_was_updated = signal("app-model-config-was-updated")
+
+# sender: app, kwargs: published_workflow
+app_published_workflow_was_updated = signal("app-published-workflow-was-updated")
+
+# sender: app, kwargs: synced_draft_workflow
+app_draft_workflow_was_synced = signal("app-draft-workflow-was-synced")

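These are plain blinker signals: handler modules subscribe with `.connect`, and the emitting code fires them with `.send`, passing the sender plus keyword context. A minimal sketch of that pattern, assuming a hypothetical `app` object and an `account` keyword (neither is defined in this diff):

from blinker import signal

app_was_created = signal("app-was-created")


@app_was_created.connect
def on_app_created(sender, **kwargs):
    # sender is the newly created app; extra context arrives as keyword arguments
    print(f"app created: {sender}, account={kwargs.get('account')}")


# an emitter somewhere else would fire the signal roughly like this:
# app_was_created.send(app, account=account)
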
api/events/dataset_event.py
ADDED
@@ -0,0 +1,4 @@
+from blinker import signal
+
+# sender: dataset
+dataset_was_deleted = signal("dataset-was-deleted")

api/events/document_event.py
ADDED
@@ -0,0 +1,4 @@
+from blinker import signal
+
+# sender: document
+document_was_deleted = signal("document-was-deleted")

api/events/event_handlers/__init__.py
ADDED
@@ -0,0 +1,10 @@
+from .clean_when_dataset_deleted import handle
+from .clean_when_document_deleted import handle
+from .create_document_index import handle
+from .create_installed_app_when_app_created import handle
+from .create_site_record_when_app_created import handle
+from .deduct_quota_when_message_created import handle
+from .delete_tool_parameters_cache_when_sync_draft_workflow import handle
+from .update_app_dataset_join_when_app_model_config_updated import handle
+from .update_app_dataset_join_when_app_published_workflow_updated import handle
+from .update_provider_last_used_at_when_message_created import handle

api/events/event_handlers/clean_when_dataset_deleted.py
ADDED
@@ -0,0 +1,15 @@
+from events.dataset_event import dataset_was_deleted
+from tasks.clean_dataset_task import clean_dataset_task
+
+
+@dataset_was_deleted.connect
+def handle(sender, **kwargs):
+    dataset = sender
+    clean_dataset_task.delay(
+        dataset.id,
+        dataset.tenant_id,
+        dataset.indexing_technique,
+        dataset.index_struct,
+        dataset.collection_binding_id,
+        dataset.doc_form,
+    )

api/events/event_handlers/clean_when_document_deleted.py
ADDED
@@ -0,0 +1,11 @@
+from events.document_event import document_was_deleted
+from tasks.clean_document_task import clean_document_task
+
+
+@document_was_deleted.connect
+def handle(sender, **kwargs):
+    document_id = sender
+    dataset_id = kwargs.get("dataset_id")
+    doc_form = kwargs.get("doc_form")
+    file_id = kwargs.get("file_id")
+    clean_document_task.delay(document_id, dataset_id, doc_form, file_id)

api/events/event_handlers/create_document_index.py
ADDED
@@ -0,0 +1,49 @@
+import datetime
+import logging
+import time
+
+import click
+from werkzeug.exceptions import NotFound
+
+from core.indexing_runner import DocumentIsPausedError, IndexingRunner
+from events.event_handlers.document_index_event import document_index_created
+from extensions.ext_database import db
+from models.dataset import Document
+
+
+@document_index_created.connect
+def handle(sender, **kwargs):
+    dataset_id = sender
+    document_ids = kwargs.get("document_ids", [])
+    documents = []
+    start_at = time.perf_counter()
+    for document_id in document_ids:
+        logging.info(click.style("Start process document: {}".format(document_id), fg="green"))
+
+        document = (
+            db.session.query(Document)
+            .filter(
+                Document.id == document_id,
+                Document.dataset_id == dataset_id,
+            )
+            .first()
+        )
+
+        if not document:
+            raise NotFound("Document not found")
+
+        document.indexing_status = "parsing"
+        document.processing_started_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+        documents.append(document)
+        db.session.add(document)
+    db.session.commit()
+
+    try:
+        indexing_runner = IndexingRunner()
+        indexing_runner.run(documents)
+        end_at = time.perf_counter()
+        logging.info(click.style("Processed dataset: {} latency: {}".format(dataset_id, end_at - start_at), fg="green"))
+    except DocumentIsPausedError as ex:
+        logging.info(click.style(str(ex), fg="yellow"))
+    except Exception:
+        pass

api/events/event_handlers/create_installed_app_when_app_created.py
ADDED
@@ -0,0 +1,16 @@
+from events.app_event import app_was_created
+from extensions.ext_database import db
+from models.model import InstalledApp
+
+
+@app_was_created.connect
+def handle(sender, **kwargs):
+    """Create an installed app when an app is created."""
+    app = sender
+    installed_app = InstalledApp(
+        tenant_id=app.tenant_id,
+        app_id=app.id,
+        app_owner_tenant_id=app.tenant_id,
+    )
+    db.session.add(installed_app)
+    db.session.commit()

api/events/event_handlers/create_site_record_when_app_created.py
ADDED
@@ -0,0 +1,26 @@
+from events.app_event import app_was_created
+from extensions.ext_database import db
+from models.model import Site
+
+
+@app_was_created.connect
+def handle(sender, **kwargs):
+    """Create site record when an app is created."""
+    app = sender
+    account = kwargs.get("account")
+    if account is not None:
+        site = Site(
+            app_id=app.id,
+            title=app.name,
+            icon_type=app.icon_type,
+            icon=app.icon,
+            icon_background=app.icon_background,
+            default_language=account.interface_language,
+            customize_token_strategy="not_allow",
+            code=Site.generate_code(16),
+            created_by=app.created_by,
+            updated_by=app.updated_by,
+        )
+
+        db.session.add(site)
+        db.session.commit()

api/events/event_handlers/deduct_quota_when_message_created.py
ADDED
@@ -0,0 +1,53 @@
+from configs import dify_config
+from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ChatAppGenerateEntity
+from core.entities.provider_entities import QuotaUnit
+from events.message_event import message_was_created
+from extensions.ext_database import db
+from models.provider import Provider, ProviderType
+
+
+@message_was_created.connect
+def handle(sender, **kwargs):
+    message = sender
+    application_generate_entity = kwargs.get("application_generate_entity")
+
+    if not isinstance(application_generate_entity, ChatAppGenerateEntity | AgentChatAppGenerateEntity):
+        return
+
+    model_config = application_generate_entity.model_conf
+    provider_model_bundle = model_config.provider_model_bundle
+    provider_configuration = provider_model_bundle.configuration
+
+    if provider_configuration.using_provider_type != ProviderType.SYSTEM:
+        return
+
+    system_configuration = provider_configuration.system_configuration
+
+    quota_unit = None
+    for quota_configuration in system_configuration.quota_configurations:
+        if quota_configuration.quota_type == system_configuration.current_quota_type:
+            quota_unit = quota_configuration.quota_unit
+
+            if quota_configuration.quota_limit == -1:
+                return
+
+            break
+
+    used_quota = None
+    if quota_unit:
+        if quota_unit == QuotaUnit.TOKENS:
+            used_quota = message.message_tokens + message.answer_tokens
+        elif quota_unit == QuotaUnit.CREDITS:
+            used_quota = dify_config.get_model_credits(model_config.model)
+        else:
+            used_quota = 1
+
+    if used_quota is not None and system_configuration.current_quota_type is not None:
+        db.session.query(Provider).filter(
+            Provider.tenant_id == application_generate_entity.app_config.tenant_id,
+            Provider.provider_name == model_config.provider,
+            Provider.provider_type == ProviderType.SYSTEM.value,
+            Provider.quota_type == system_configuration.current_quota_type.value,
+            Provider.quota_limit > Provider.quota_used,
+        ).update({"quota_used": Provider.quota_used + used_quota})
+        db.session.commit()

api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py
ADDED
@@ -0,0 +1,34 @@
+from core.tools.tool_manager import ToolManager
+from core.tools.utils.configuration import ToolParameterConfigurationManager
+from core.workflow.nodes import NodeType
+from core.workflow.nodes.tool.entities import ToolEntity
+from events.app_event import app_draft_workflow_was_synced
+
+
+@app_draft_workflow_was_synced.connect
+def handle(sender, **kwargs):
+    app = sender
+    synced_draft_workflow = kwargs.get("synced_draft_workflow")
+    if synced_draft_workflow is None:
+        return
+    for node_data in synced_draft_workflow.graph_dict.get("nodes", []):
+        if node_data.get("data", {}).get("type") == NodeType.TOOL.value:
+            try:
+                tool_entity = ToolEntity(**node_data["data"])
+                tool_runtime = ToolManager.get_tool_runtime(
+                    provider_type=tool_entity.provider_type,
+                    provider_id=tool_entity.provider_id,
+                    tool_name=tool_entity.tool_name,
+                    tenant_id=app.tenant_id,
+                )
+                manager = ToolParameterConfigurationManager(
+                    tenant_id=app.tenant_id,
+                    tool_runtime=tool_runtime,
+                    provider_name=tool_entity.provider_name,
+                    provider_type=tool_entity.provider_type,
+                    identity_id=f"WORKFLOW.{app.id}.{node_data.get('id')}",
+                )
+                manager.delete_tool_parameters_cache()
+            except:
+                # tool does not exist
+                pass

api/events/event_handlers/document_index_event.py
ADDED
@@ -0,0 +1,4 @@
+from blinker import signal
+
+# sender: document
+document_index_created = signal("document-index-created")

api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py
ADDED
@@ -0,0 +1,68 @@
+from events.app_event import app_model_config_was_updated
+from extensions.ext_database import db
+from models.dataset import AppDatasetJoin
+from models.model import AppModelConfig
+
+
+@app_model_config_was_updated.connect
+def handle(sender, **kwargs):
+    app = sender
+    app_model_config = kwargs.get("app_model_config")
+    if app_model_config is None:
+        return
+
+    dataset_ids = get_dataset_ids_from_model_config(app_model_config)
+
+    app_dataset_joins = db.session.query(AppDatasetJoin).filter(AppDatasetJoin.app_id == app.id).all()
+
+    removed_dataset_ids: set[str] = set()
+    if not app_dataset_joins:
+        added_dataset_ids = dataset_ids
+    else:
+        old_dataset_ids: set[str] = set()
+        old_dataset_ids.update(app_dataset_join.dataset_id for app_dataset_join in app_dataset_joins)
+
+        added_dataset_ids = dataset_ids - old_dataset_ids
+        removed_dataset_ids = old_dataset_ids - dataset_ids
+
+    if removed_dataset_ids:
+        for dataset_id in removed_dataset_ids:
+            db.session.query(AppDatasetJoin).filter(
+                AppDatasetJoin.app_id == app.id, AppDatasetJoin.dataset_id == dataset_id
+            ).delete()
+
+    if added_dataset_ids:
+        for dataset_id in added_dataset_ids:
+            app_dataset_join = AppDatasetJoin(app_id=app.id, dataset_id=dataset_id)
+            db.session.add(app_dataset_join)
+
+    db.session.commit()
+
+
+def get_dataset_ids_from_model_config(app_model_config: AppModelConfig) -> set[str]:
+    dataset_ids: set[str] = set()
+    if not app_model_config:
+        return dataset_ids
+
+    agent_mode = app_model_config.agent_mode_dict
+
+    tools = agent_mode.get("tools", []) or []
+    for tool in tools:
+        if len(list(tool.keys())) != 1:
+            continue
+
+        tool_type = list(tool.keys())[0]
+        tool_config = list(tool.values())[0]
+        if tool_type == "dataset":
+            dataset_ids.add(tool_config.get("id"))
+
+    # get dataset from dataset_configs
+    dataset_configs = app_model_config.dataset_configs_dict
+    datasets = dataset_configs.get("datasets", {}) or {}
+    for dataset in datasets.get("datasets", []) or []:
+        keys = list(dataset.keys())
+        if len(keys) == 1 and keys[0] == "dataset":
+            if dataset["dataset"].get("id"):
+                dataset_ids.add(dataset["dataset"].get("id"))
+
+    return dataset_ids

api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py
ADDED
@@ -0,0 +1,67 @@
+from typing import cast
+
+from core.workflow.nodes import NodeType
+from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData
+from events.app_event import app_published_workflow_was_updated
+from extensions.ext_database import db
+from models.dataset import AppDatasetJoin
+from models.workflow import Workflow
+
+
+@app_published_workflow_was_updated.connect
+def handle(sender, **kwargs):
+    app = sender
+    published_workflow = kwargs.get("published_workflow")
+    published_workflow = cast(Workflow, published_workflow)
+
+    dataset_ids = get_dataset_ids_from_workflow(published_workflow)
+    app_dataset_joins = db.session.query(AppDatasetJoin).filter(AppDatasetJoin.app_id == app.id).all()
+
+    removed_dataset_ids: set[str] = set()
+    if not app_dataset_joins:
+        added_dataset_ids = dataset_ids
+    else:
+        old_dataset_ids: set[str] = set()
+        old_dataset_ids.update(app_dataset_join.dataset_id for app_dataset_join in app_dataset_joins)
+
+        added_dataset_ids = dataset_ids - old_dataset_ids
+        removed_dataset_ids = old_dataset_ids - dataset_ids
+
+    if removed_dataset_ids:
+        for dataset_id in removed_dataset_ids:
+            db.session.query(AppDatasetJoin).filter(
+                AppDatasetJoin.app_id == app.id, AppDatasetJoin.dataset_id == dataset_id
+            ).delete()
+
+    if added_dataset_ids:
+        for dataset_id in added_dataset_ids:
+            app_dataset_join = AppDatasetJoin(app_id=app.id, dataset_id=dataset_id)
+            db.session.add(app_dataset_join)
+
+    db.session.commit()
+
+
+def get_dataset_ids_from_workflow(published_workflow: Workflow) -> set[str]:
+    dataset_ids: set[str] = set()
+    graph = published_workflow.graph_dict
+    if not graph:
+        return dataset_ids
+
+    nodes = graph.get("nodes", [])
+
+    # fetch all knowledge retrieval nodes
+    knowledge_retrieval_nodes = [
+        node for node in nodes if node.get("data", {}).get("type") == NodeType.KNOWLEDGE_RETRIEVAL.value
+    ]
+
+    if not knowledge_retrieval_nodes:
+        return dataset_ids
+
+    for node in knowledge_retrieval_nodes:
+        try:
+            node_data = KnowledgeRetrievalNodeData(**node.get("data", {}))
+            dataset_ids.update(dataset_id for dataset_id in node_data.dataset_ids)
+        except Exception as e:
+            continue
+
+    return dataset_ids

api/events/event_handlers/update_provider_last_used_at_when_message_created.py
ADDED
@@ -0,0 +1,21 @@
+from datetime import UTC, datetime
+
+from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ChatAppGenerateEntity
+from events.message_event import message_was_created
+from extensions.ext_database import db
+from models.provider import Provider
+
+
+@message_was_created.connect
+def handle(sender, **kwargs):
+    message = sender
+    application_generate_entity = kwargs.get("application_generate_entity")
+
+    if not isinstance(application_generate_entity, ChatAppGenerateEntity | AgentChatAppGenerateEntity):
+        return
+
+    db.session.query(Provider).filter(
+        Provider.tenant_id == application_generate_entity.app_config.tenant_id,
+        Provider.provider_name == application_generate_entity.model_conf.provider,
+    ).update({"last_used": datetime.now(UTC).replace(tzinfo=None)})
+    db.session.commit()

api/events/message_event.py
ADDED
@@ -0,0 +1,4 @@
+from blinker import signal
+
+# sender: message, kwargs: conversation
+message_was_created = signal("message-was-created")

api/events/tenant_event.py
ADDED
@@ -0,0 +1,7 @@
+from blinker import signal
+
+# sender: tenant
+tenant_was_created = signal("tenant-was-created")
+
+# sender: tenant
+tenant_was_updated = signal("tenant-was-updated")

api/extensions/__init__.py
ADDED
File without changes

api/extensions/ext_app_metrics.py
ADDED
@@ -0,0 +1,67 @@
+import json
+import os
+import threading
+
+from flask import Response
+
+from configs import dify_config
+from dify_app import DifyApp
+
+
+def init_app(app: DifyApp):
+    @app.after_request
+    def after_request(response):
+        """Add Version headers to the response."""
+        response.headers.add("X-Version", dify_config.CURRENT_VERSION)
+        response.headers.add("X-Env", dify_config.DEPLOY_ENV)
+        return response
+
+    @app.route("/health")
+    def health():
+        return Response(
+            json.dumps({"pid": os.getpid(), "status": "ok", "version": dify_config.CURRENT_VERSION}),
+            status=200,
+            content_type="application/json",
+        )
+
+    @app.route("/threads")
+    def threads():
+        num_threads = threading.active_count()
+        threads = threading.enumerate()
+
+        thread_list = []
+        for thread in threads:
+            thread_name = thread.name
+            thread_id = thread.ident
+            is_alive = thread.is_alive()
+
+            thread_list.append(
+                {
+                    "name": thread_name,
+                    "id": thread_id,
+                    "is_alive": is_alive,
+                }
+            )
+
+        return {
+            "pid": os.getpid(),
+            "thread_num": num_threads,
+            "threads": thread_list,
+        }
+
+    @app.route("/db-pool-stat")
+    def pool_stat():
+        from extensions.ext_database import db
+
+        engine = db.engine
+        # TODO: Fix the type error
+        # FIXME maybe its sqlalchemy issue
+        return {
+            "pid": os.getpid(),
+            "pool_size": engine.pool.size(),  # type: ignore
+            "checked_in_connections": engine.pool.checkedin(),  # type: ignore
+            "checked_out_connections": engine.pool.checkedout(),  # type: ignore
+            "overflow_connections": engine.pool.overflow(),  # type: ignore
+            "connection_timeout": engine.pool.timeout(),  # type: ignore
+            "recycle_time": db.engine.pool._recycle,  # type: ignore
+        }

api/extensions/ext_blueprints.py
ADDED
@@ -0,0 +1,48 @@
+from configs import dify_config
+from dify_app import DifyApp
+
+
+def init_app(app: DifyApp):
+    # register blueprint routers
+
+    from flask_cors import CORS  # type: ignore
+
+    from controllers.console import bp as console_app_bp
+    from controllers.files import bp as files_bp
+    from controllers.inner_api import bp as inner_api_bp
+    from controllers.service_api import bp as service_api_bp
+    from controllers.web import bp as web_bp
+
+    CORS(
+        service_api_bp,
+        allow_headers=["Content-Type", "Authorization", "X-App-Code"],
+        methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
+    )
+    app.register_blueprint(service_api_bp)
+
+    CORS(
+        web_bp,
+        resources={r"/*": {"origins": dify_config.WEB_API_CORS_ALLOW_ORIGINS}},
+        supports_credentials=True,
+        allow_headers=["Content-Type", "Authorization", "X-App-Code"],
+        methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
+        expose_headers=["X-Version", "X-Env"],
+    )
+
+    app.register_blueprint(web_bp)
+
+    CORS(
+        console_app_bp,
+        resources={r"/*": {"origins": dify_config.CONSOLE_CORS_ALLOW_ORIGINS}},
+        supports_credentials=True,
+        allow_headers=["Content-Type", "Authorization"],
+        methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
+        expose_headers=["X-Version", "X-Env"],
+    )
+
+    app.register_blueprint(console_app_bp)
+
+    CORS(files_bp, allow_headers=["Content-Type"], methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"])
+    app.register_blueprint(files_bp)
+
+    app.register_blueprint(inner_api_bp)

api/extensions/ext_celery.py
ADDED
@@ -0,0 +1,104 @@
+from datetime import timedelta
+
+import pytz
+from celery import Celery, Task  # type: ignore
+from celery.schedules import crontab  # type: ignore
+
+from configs import dify_config
+from dify_app import DifyApp
+
+
+def init_app(app: DifyApp) -> Celery:
+    class FlaskTask(Task):
+        def __call__(self, *args: object, **kwargs: object) -> object:
+            with app.app_context():
+                return self.run(*args, **kwargs)
+
+    broker_transport_options = {}
+
+    if dify_config.CELERY_USE_SENTINEL:
+        broker_transport_options = {
+            "master_name": dify_config.CELERY_SENTINEL_MASTER_NAME,
+            "sentinel_kwargs": {
+                "socket_timeout": dify_config.CELERY_SENTINEL_SOCKET_TIMEOUT,
+            },
+        }
+
+    celery_app = Celery(
+        app.name,
+        task_cls=FlaskTask,
+        broker=dify_config.CELERY_BROKER_URL,
+        backend=dify_config.CELERY_BACKEND,
+        task_ignore_result=True,
+    )
+
+    # Add SSL options to the Celery configuration
+    ssl_options = {
+        "ssl_cert_reqs": None,
+        "ssl_ca_certs": None,
+        "ssl_certfile": None,
+        "ssl_keyfile": None,
+    }
+
+    celery_app.conf.update(
+        result_backend=dify_config.CELERY_RESULT_BACKEND,
+        broker_transport_options=broker_transport_options,
+        broker_connection_retry_on_startup=True,
+        worker_log_format=dify_config.LOG_FORMAT,
+        worker_task_log_format=dify_config.LOG_FORMAT,
+        worker_hijack_root_logger=False,
+        timezone=pytz.timezone(dify_config.LOG_TZ or "UTC"),
+    )
+
+    if dify_config.BROKER_USE_SSL:
+        celery_app.conf.update(
+            broker_use_ssl=ssl_options,  # Add the SSL options to the broker configuration
+        )
+
+    if dify_config.LOG_FILE:
+        celery_app.conf.update(
+            worker_logfile=dify_config.LOG_FILE,
+        )
+
+    celery_app.set_default()
+    app.extensions["celery"] = celery_app
+
+    imports = [
+        "schedule.clean_embedding_cache_task",
+        "schedule.clean_unused_datasets_task",
+        "schedule.create_tidb_serverless_task",
+        "schedule.update_tidb_serverless_status_task",
+        "schedule.clean_messages",
+        "schedule.mail_clean_document_notify_task",
+    ]
+    day = dify_config.CELERY_BEAT_SCHEDULER_TIME
+    beat_schedule = {
+        "clean_embedding_cache_task": {
+            "task": "schedule.clean_embedding_cache_task.clean_embedding_cache_task",
+            "schedule": timedelta(days=day),
+        },
+        "clean_unused_datasets_task": {
+            "task": "schedule.clean_unused_datasets_task.clean_unused_datasets_task",
+            "schedule": timedelta(days=day),
+        },
+        "create_tidb_serverless_task": {
+            "task": "schedule.create_tidb_serverless_task.create_tidb_serverless_task",
+            "schedule": crontab(minute="0", hour="*"),
+        },
+        "update_tidb_serverless_status_task": {
+            "task": "schedule.update_tidb_serverless_status_task.update_tidb_serverless_status_task",
+            "schedule": timedelta(minutes=10),
+        },
+        "clean_messages": {
+            "task": "schedule.clean_messages.clean_messages",
+            "schedule": timedelta(days=day),
+        },
+        # every Monday
+        "mail_clean_document_notify_task": {
+            "task": "schedule.mail_clean_document_notify_task.mail_clean_document_notify_task",
+            "schedule": crontab(minute="0", hour="10", day_of_week="1"),
+        },
+    }
+    celery_app.conf.update(beat_schedule=beat_schedule, imports=imports)
+
+    return celery_app

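Because `FlaskTask` wraps every task invocation in `app.app_context()`, task code can use Flask-bound extensions (such as the SQLAlchemy session) exactly as request handlers do, and the worker started by docker/entrypoint.sh picks these tasks up. A rough illustrative sketch of a task module that would plug into this setup; the task name and body are hypothetical and not part of this diff:

from celery import shared_task

from extensions.ext_database import db


@shared_task
def count_documents_task() -> int:
    # Executed by the Celery worker; FlaskTask pushes an app context first,
    # so db.session is already bound to the configured engine.
    from models.dataset import Document

    return db.session.query(Document).count()


# callers enqueue it asynchronously:
# count_documents_task.delay()
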
api/extensions/ext_code_based_extension.py
ADDED
@@ -0,0 +1,9 @@
+from core.extension.extension import Extension
+from dify_app import DifyApp
+
+
+def init_app(app: DifyApp):
+    code_based_extension.init()
+
+
+code_based_extension = Extension()

api/extensions/ext_commands.py
ADDED
@@ -0,0 +1,29 @@
+from dify_app import DifyApp
+
+
+def init_app(app: DifyApp):
+    from commands import (
+        add_qdrant_doc_id_index,
+        convert_to_agent_apps,
+        create_tenant,
+        fix_app_site_missing,
+        reset_email,
+        reset_encrypt_key_pair,
+        reset_password,
+        upgrade_db,
+        vdb_migrate,
+    )
+
+    cmds_to_register = [
+        reset_password,
+        reset_email,
+        reset_encrypt_key_pair,
+        vdb_migrate,
+        convert_to_agent_apps,
+        add_qdrant_doc_id_index,
+        create_tenant,
+        upgrade_db,
+        fix_app_site_missing,
+    ]
+    for cmd in cmds_to_register:
+        app.cli.add_command(cmd)

api/extensions/ext_compress.py
ADDED
@@ -0,0 +1,13 @@
+from configs import dify_config
+from dify_app import DifyApp
+
+
+def is_enabled() -> bool:
+    return dify_config.API_COMPRESSION_ENABLED
+
+
+def init_app(app: DifyApp):
+    from flask_compress import Compress  # type: ignore
+
+    compress = Compress()
+    compress.init_app(app)

api/extensions/ext_database.py
ADDED
@@ -0,0 +1,6 @@
+from dify_app import DifyApp
+from models import db
+
+
+def init_app(app: DifyApp):
+    db.init_app(app)

api/extensions/ext_hosting_provider.py
ADDED
@@ -0,0 +1,10 @@
+from core.hosting_configuration import HostingConfiguration
+
+hosting_configuration = HostingConfiguration()
+
+
+from dify_app import DifyApp
+
+
+def init_app(app: DifyApp):
+    hosting_configuration.init_app(app)

api/extensions/ext_import_modules.py
ADDED
@@ -0,0 +1,5 @@
+from dify_app import DifyApp
+
+
+def init_app(app: DifyApp):
+    from events import event_handlers  # noqa: F401

api/extensions/ext_logging.py
ADDED
@@ -0,0 +1,71 @@
+import logging
+import os
+import sys
+import uuid
+from logging.handlers import RotatingFileHandler
+
+import flask
+
+from configs import dify_config
+from dify_app import DifyApp
+
+
+def init_app(app: DifyApp):
+    log_handlers: list[logging.Handler] = []
+    log_file = dify_config.LOG_FILE
+    if log_file:
+        log_dir = os.path.dirname(log_file)
+        os.makedirs(log_dir, exist_ok=True)
+        log_handlers.append(
+            RotatingFileHandler(
+                filename=log_file,
+                maxBytes=dify_config.LOG_FILE_MAX_SIZE * 1024 * 1024,
+                backupCount=dify_config.LOG_FILE_BACKUP_COUNT,
+            )
+        )
+
+    # Always add StreamHandler to log to console
+    sh = logging.StreamHandler(sys.stdout)
+    sh.addFilter(RequestIdFilter())
+    log_handlers.append(sh)
+
+    logging.basicConfig(
+        level=dify_config.LOG_LEVEL,
+        format=dify_config.LOG_FORMAT,
+        datefmt=dify_config.LOG_DATEFORMAT,
+        handlers=log_handlers,
+        force=True,
+    )
+    log_tz = dify_config.LOG_TZ
+    if log_tz:
+        from datetime import datetime
+
+        import pytz
+
+        timezone = pytz.timezone(log_tz)
+
+        def time_converter(seconds):
+            return datetime.fromtimestamp(seconds, tz=timezone).timetuple()
+
+        for handler in logging.root.handlers:
+            if handler.formatter:
+                handler.formatter.converter = time_converter
+
+
+def get_request_id():
+    if getattr(flask.g, "request_id", None):
+        return flask.g.request_id
+
+    new_uuid = uuid.uuid4().hex[:10]
+    flask.g.request_id = new_uuid
+
+    return new_uuid
+
+
+class RequestIdFilter(logging.Filter):
+    # This is a logging filter that makes the request ID available for use in
+    # the logging format. Note that we're checking if we're in a request
+    # context, as we may want to log things before Flask is fully loaded.
+    def filter(self, record):
+        record.req_id = get_request_id() if flask.has_request_context() else ""
+        return True

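The filter attaches a `req_id` attribute to every record emitted through the console handler, so a format string can reference it. A small sketch of that interaction; the exact `LOG_FORMAT` value is an assumption for illustration, not taken from this diff:

import logging

from extensions.ext_logging import RequestIdFilter

# hypothetical format string; %(req_id)s is filled in by RequestIdFilter
LOG_FORMAT = "%(asctime)s [%(req_id)s] %(levelname)s %(name)s: %(message)s"

handler = logging.StreamHandler()
handler.addFilter(RequestIdFilter())  # ensures every record carries a req_id attribute
handler.setFormatter(logging.Formatter(LOG_FORMAT))

logger = logging.getLogger("example")
logger.addHandler(handler)
logger.warning("hello")  # req_id renders empty outside a Flask request context
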
api/extensions/ext_login.py
ADDED
@@ -0,0 +1,62 @@
+import json
+
+import flask_login  # type: ignore
+from flask import Response, request
+from flask_login import user_loaded_from_request, user_logged_in
+from werkzeug.exceptions import Unauthorized
+
+import contexts
+from dify_app import DifyApp
+from libs.passport import PassportService
+from services.account_service import AccountService
+
+login_manager = flask_login.LoginManager()
+
+
+# Flask-Login configuration
+@login_manager.request_loader
+def load_user_from_request(request_from_flask_login):
+    """Load user based on the request."""
+    if request.blueprint not in {"console", "inner_api"}:
+        return None
+    # Check if the user_id contains a dot, indicating the old format
+    auth_header = request.headers.get("Authorization", "")
+    if not auth_header:
+        auth_token = request.args.get("_token")
+        if not auth_token:
+            raise Unauthorized("Invalid Authorization token.")
+    else:
+        if " " not in auth_header:
+            raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
+        auth_scheme, auth_token = auth_header.split(None, 1)
+        auth_scheme = auth_scheme.lower()
+        if auth_scheme != "bearer":
+            raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
+
+    decoded = PassportService().verify(auth_token)
+    user_id = decoded.get("user_id")
+
+    logged_in_account = AccountService.load_logged_in_account(account_id=user_id)
+    return logged_in_account
+
+
+@user_logged_in.connect
+@user_loaded_from_request.connect
+def on_user_logged_in(_sender, user):
+    """Called when a user logged in."""
+    if user:
+        contexts.tenant_id.set(user.current_tenant_id)
+
+
+@login_manager.unauthorized_handler
+def unauthorized_handler():
+    """Handle unauthorized requests."""
+    return Response(
+        json.dumps({"code": "unauthorized", "message": "Unauthorized."}),
+        status=401,
+        content_type="application/json",
+    )
+
+
+def init_app(app: DifyApp):
+    login_manager.init_app(app)

api/extensions/ext_mail.py
ADDED
@@ -0,0 +1,97 @@
+import logging
+from typing import Optional
+
+from flask import Flask
+
+from configs import dify_config
+from dify_app import DifyApp
+
+
+class Mail:
+    def __init__(self):
+        self._client = None
+        self._default_send_from = None
+
+    def is_inited(self) -> bool:
+        return self._client is not None
+
+    def init_app(self, app: Flask):
+        mail_type = dify_config.MAIL_TYPE
+        if not mail_type:
+            logging.warning("MAIL_TYPE is not set")
+            return
+
+        if dify_config.MAIL_DEFAULT_SEND_FROM:
+            self._default_send_from = dify_config.MAIL_DEFAULT_SEND_FROM
+
+        match mail_type:
+            case "resend":
+                import resend  # type: ignore
+
+                api_key = dify_config.RESEND_API_KEY
+                if not api_key:
+                    raise ValueError("RESEND_API_KEY is not set")
+
+                api_url = dify_config.RESEND_API_URL
+                if api_url:
+                    resend.api_url = api_url
+
+                resend.api_key = api_key
+                self._client = resend.Emails
+            case "smtp":
+                from libs.smtp import SMTPClient
+
+                if not dify_config.SMTP_SERVER or not dify_config.SMTP_PORT:
+                    raise ValueError("SMTP_SERVER and SMTP_PORT are required for smtp mail type")
+                if not dify_config.SMTP_USE_TLS and dify_config.SMTP_OPPORTUNISTIC_TLS:
+                    raise ValueError("SMTP_OPPORTUNISTIC_TLS is not supported without enabling SMTP_USE_TLS")
+                self._client = SMTPClient(
+                    server=dify_config.SMTP_SERVER,
+                    port=dify_config.SMTP_PORT,
+                    username=dify_config.SMTP_USERNAME or "",
+                    password=dify_config.SMTP_PASSWORD or "",
+                    _from=dify_config.MAIL_DEFAULT_SEND_FROM or "",
+                    use_tls=dify_config.SMTP_USE_TLS,
+                    opportunistic_tls=dify_config.SMTP_OPPORTUNISTIC_TLS,
+                )
+            case _:
+                raise ValueError("Unsupported mail type {}".format(mail_type))
+
+    def send(self, to: str, subject: str, html: str, from_: Optional[str] = None):
+        if not self._client:
+            raise ValueError("Mail client is not initialized")
+
+        if not from_ and self._default_send_from:
+            from_ = self._default_send_from
+
+        if not from_:
+            raise ValueError("mail from is not set")
+
+        if not to:
+            raise ValueError("mail to is not set")
+
+        if not subject:
+            raise ValueError("mail subject is not set")
+
+        if not html:
+            raise ValueError("mail html is not set")
+
+        self._client.send(
+            {
+                "from": from_,
+                "to": to,
+                "subject": subject,
+                "html": html,
+            }
+        )
+
+
+def is_enabled() -> bool:
+    return dify_config.MAIL_TYPE is not None and dify_config.MAIL_TYPE != ""
+
+
+def init_app(app: DifyApp):
+    mail.init_app(app)
+
+
+mail = Mail()

api/extensions/ext_migrate.py
ADDED
@@ -0,0 +1,9 @@
+from dify_app import DifyApp
+
+
+def init_app(app: DifyApp):
+    import flask_migrate  # type: ignore
+
+    from extensions.ext_database import db
+
+    flask_migrate.Migrate(app, db)

api/extensions/ext_proxy_fix.py
ADDED
@@ -0,0 +1,9 @@
+from configs import dify_config
+from dify_app import DifyApp
+
+
+def init_app(app: DifyApp):
+    if dify_config.RESPECT_XFORWARD_HEADERS_ENABLED:
+        from werkzeug.middleware.proxy_fix import ProxyFix
+
+        app.wsgi_app = ProxyFix(app.wsgi_app, x_port=1)  # type: ignore

api/extensions/ext_redis.py
ADDED
@@ -0,0 +1,98 @@
+from typing import Any, Union
+
+import redis
+from redis.cluster import ClusterNode, RedisCluster
+from redis.connection import Connection, SSLConnection
+from redis.sentinel import Sentinel
+
+from configs import dify_config
+from dify_app import DifyApp
+
+
+class RedisClientWrapper:
+    """
+    A wrapper class for the Redis client that addresses the issue where the global
+    `redis_client` variable cannot be updated when a new Redis instance is returned
+    by Sentinel.
+
+    This class allows for deferred initialization of the Redis client, enabling the
+    client to be re-initialized with a new instance when necessary. This is particularly
+    useful in scenarios where the Redis instance may change dynamically, such as during
+    a failover in a Sentinel-managed Redis setup.
+
+    Attributes:
+        _client (redis.Redis): The actual Redis client instance. It remains None until
+                               initialized with the `initialize` method.
+
+    Methods:
+        initialize(client): Initializes the Redis client if it hasn't been initialized already.
+        __getattr__(item): Delegates attribute access to the Redis client, raising an error
+                           if the client is not initialized.
+    """
+
+    def __init__(self):
+        self._client = None
+
+    def initialize(self, client):
+        if self._client is None:
+            self._client = client
+
+    def __getattr__(self, item):
+        if self._client is None:
+            raise RuntimeError("Redis client is not initialized. Call init_app first.")
+        return getattr(self._client, item)
+
+
+redis_client = RedisClientWrapper()
+
+
+def init_app(app: DifyApp):
+    global redis_client
+    connection_class: type[Union[Connection, SSLConnection]] = Connection
+    if dify_config.REDIS_USE_SSL:
+        connection_class = SSLConnection
+
+    redis_params: dict[str, Any] = {
+        "username": dify_config.REDIS_USERNAME,
+        "password": dify_config.REDIS_PASSWORD,
+        "db": dify_config.REDIS_DB,
+        "encoding": "utf-8",
+        "encoding_errors": "strict",
+        "decode_responses": False,
+    }
+
+    if dify_config.REDIS_USE_SENTINEL:
+        assert dify_config.REDIS_SENTINELS is not None, "REDIS_SENTINELS must be set when REDIS_USE_SENTINEL is True"
+        sentinel_hosts = [
+            (node.split(":")[0], int(node.split(":")[1])) for node in dify_config.REDIS_SENTINELS.split(",")
+        ]
+        sentinel = Sentinel(
+            sentinel_hosts,
+            sentinel_kwargs={
+                "socket_timeout": dify_config.REDIS_SENTINEL_SOCKET_TIMEOUT,
+                "username": dify_config.REDIS_SENTINEL_USERNAME,
+                "password": dify_config.REDIS_SENTINEL_PASSWORD,
+            },
+        )
+        master = sentinel.master_for(dify_config.REDIS_SENTINEL_SERVICE_NAME, **redis_params)
+        redis_client.initialize(master)
+    elif dify_config.REDIS_USE_CLUSTERS:
+        assert dify_config.REDIS_CLUSTERS is not None, "REDIS_CLUSTERS must be set when REDIS_USE_CLUSTERS is True"
+        nodes = [
+            ClusterNode(host=node.split(":")[0], port=int(node.split(":")[1]))
+            for node in dify_config.REDIS_CLUSTERS.split(",")
+        ]
+        # FIXME: mypy error here, try to figure out how to fix it
+        redis_client.initialize(RedisCluster(startup_nodes=nodes, password=dify_config.REDIS_CLUSTERS_PASSWORD))  # type: ignore
+    else:
+        redis_params.update(
+            {
+                "host": dify_config.REDIS_HOST,
+                "port": dify_config.REDIS_PORT,
+                "connection_class": connection_class,
+            }
+        )
+        pool = redis.ConnectionPool(**redis_params)
+        redis_client.initialize(redis.Redis(connection_pool=pool))
+
+    app.extensions["redis"] = redis_client

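Callers import the module-level `redis_client` and use it like an ordinary Redis client once `init_app` has run; `__getattr__` proxies every call to the underlying `redis.Redis`, Sentinel master, or `RedisCluster` instance. A minimal usage sketch, with an illustrative key name:

from extensions.ext_redis import redis_client

# proxied straight through to the underlying client picked in init_app
redis_client.setex("example:cache-key", 300, "value")
cached = redis_client.get("example:cache-key")

# before init_app(app) has been called, any attribute access raises:
# RuntimeError: Redis client is not initialized. Call init_app first.
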
api/extensions/ext_sentry.py
ADDED
@@ -0,0 +1,40 @@
+from configs import dify_config
+from dify_app import DifyApp
+
+
+def init_app(app: DifyApp):
+    if dify_config.SENTRY_DSN:
+        import openai
+        import sentry_sdk
+        from langfuse import parse_error  # type: ignore
+        from sentry_sdk.integrations.celery import CeleryIntegration
+        from sentry_sdk.integrations.flask import FlaskIntegration
+        from werkzeug.exceptions import HTTPException
+
+        from core.model_runtime.errors.invoke import InvokeRateLimitError
+
+        def before_send(event, hint):
+            if "exc_info" in hint:
+                exc_type, exc_value, tb = hint["exc_info"]
+                if parse_error.defaultErrorResponse in str(exc_value):
+                    return None
+
+            return event
+
+        sentry_sdk.init(
+            dsn=dify_config.SENTRY_DSN,
+            integrations=[FlaskIntegration(), CeleryIntegration()],
+            ignore_errors=[
+                HTTPException,
+                ValueError,
+                FileNotFoundError,
+                openai.APIStatusError,
+                InvokeRateLimitError,
+                parse_error.defaultErrorResponse,
+            ],
+            traces_sample_rate=dify_config.SENTRY_TRACES_SAMPLE_RATE,
+            profiles_sample_rate=dify_config.SENTRY_PROFILES_SAMPLE_RATE,
+            environment=dify_config.DEPLOY_ENV,
+            release=f"dify-{dify_config.CURRENT_VERSION}-{dify_config.COMMIT_SHA}",
+            before_send=before_send,
+        )

api/extensions/ext_set_secretkey.py
ADDED
@@ -0,0 +1,6 @@
+from configs import dify_config
+from dify_app import DifyApp
+
+
+def init_app(app: DifyApp):
+    app.secret_key = dify_config.SECRET_KEY

api/extensions/ext_storage.py
ADDED
@@ -0,0 +1,138 @@
+import logging
+from collections.abc import Callable, Generator
+from typing import Literal, Union, overload
+
+from flask import Flask
+
+from configs import dify_config
+from dify_app import DifyApp
+from extensions.storage.base_storage import BaseStorage
+from extensions.storage.storage_type import StorageType
+
+logger = logging.getLogger(__name__)
+
+
+class Storage:
+    def init_app(self, app: Flask):
+        storage_factory = self.get_storage_factory(dify_config.STORAGE_TYPE)
+        with app.app_context():
+            self.storage_runner = storage_factory()
+
+    @staticmethod
+    def get_storage_factory(storage_type: str) -> Callable[[], BaseStorage]:
+        match storage_type:
+            case StorageType.S3:
+                from extensions.storage.aws_s3_storage import AwsS3Storage
+
+                return AwsS3Storage
+            case StorageType.OPENDAL:
+                from extensions.storage.opendal_storage import OpenDALStorage
+
+                return lambda: OpenDALStorage(dify_config.OPENDAL_SCHEME)
+            case StorageType.LOCAL:
+                from extensions.storage.opendal_storage import OpenDALStorage
+
+                return lambda: OpenDALStorage(scheme="fs", root=dify_config.STORAGE_LOCAL_PATH)
+            case StorageType.AZURE_BLOB:
+                from extensions.storage.azure_blob_storage import AzureBlobStorage
+
+                return AzureBlobStorage
+            case StorageType.ALIYUN_OSS:
+                from extensions.storage.aliyun_oss_storage import AliyunOssStorage
+
+                return AliyunOssStorage
+            case StorageType.GOOGLE_STORAGE:
+                from extensions.storage.google_cloud_storage import GoogleCloudStorage
+
+                return GoogleCloudStorage
+            case StorageType.TENCENT_COS:
+                from extensions.storage.tencent_cos_storage import TencentCosStorage
+
+                return TencentCosStorage
+            case StorageType.OCI_STORAGE:
+                from extensions.storage.oracle_oci_storage import OracleOCIStorage
+
+                return OracleOCIStorage
+            case StorageType.HUAWEI_OBS:
+                from extensions.storage.huawei_obs_storage import HuaweiObsStorage
+
+                return HuaweiObsStorage
+            case StorageType.BAIDU_OBS:
+                from extensions.storage.baidu_obs_storage import BaiduObsStorage
+
+                return BaiduObsStorage
+            case StorageType.VOLCENGINE_TOS:
+                from extensions.storage.volcengine_tos_storage import VolcengineTosStorage
+
+                return VolcengineTosStorage
+            case StorageType.SUPBASE:
+                from extensions.storage.supabase_storage import SupabaseStorage
+
+                return SupabaseStorage
+            case _:
+                raise ValueError(f"unsupported storage type {storage_type}")
+
+    def save(self, filename, data):
+        try:
+            self.storage_runner.save(filename, data)
+        except Exception as e:
+            logger.exception(f"Failed to save file {filename}")
+            raise e
+
+    @overload
+    def load(self, filename: str, /, *, stream: Literal[False] = False) -> bytes: ...
+
+    @overload
+    def load(self, filename: str, /, *, stream: Literal[True]) -> Generator: ...
+
+    def load(self, filename: str, /, *, stream: bool = False) -> Union[bytes, Generator]:
+        try:
+            if stream:
+                return self.load_stream(filename)
+            else:
+                return self.load_once(filename)
+        except Exception as e:
+            logger.exception(f"Failed to load file {filename}")
+            raise e
+
+    def load_once(self, filename: str) -> bytes:
+        try:
+            return self.storage_runner.load_once(filename)
+        except Exception as e:
+            logger.exception(f"Failed to load_once file {filename}")
+            raise e
+
+    def load_stream(self, filename: str) -> Generator:
+        try:
+            return self.storage_runner.load_stream(filename)
+        except Exception as e:
+            logger.exception(f"Failed to load_stream file {filename}")
+            raise e
+
+    def download(self, filename, target_filepath):
+        try:
+            self.storage_runner.download(filename, target_filepath)
+        except Exception as e:
+            logger.exception(f"Failed to download file {filename}")
+            raise e
+
+    def exists(self, filename):
+        try:
+            return self.storage_runner.exists(filename)
+        except Exception as e:
+            logger.exception(f"Failed to check file exists {filename}")
+            raise e
+
+    def delete(self, filename):
+        try:
+            return self.storage_runner.delete(filename)
+        except Exception as e:
+            logger.exception(f"Failed to delete file {filename}")
+            raise e
+
+
+storage = Storage()
+
+
+def init_app(app: DifyApp):
+    storage.init_app(app)

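The `@overload` pair on `load` lets callers get the whole object as bytes by default, or a chunk generator when `stream=True`, while the facade delegates to whichever backend `STORAGE_TYPE` selected. A short usage sketch after `init_app` has run; the filename is illustrative only:

from extensions.ext_storage import storage

# whole object in memory
data: bytes = storage.load("uploads/report.pdf")

# or stream it chunk by chunk without buffering the full file
for chunk in storage.load("uploads/report.pdf", stream=True):
    print(len(chunk))  # placeholder for a real consumer
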
api/extensions/ext_timezone.py
ADDED
@@ -0,0 +1,11 @@
import os
import time

from dify_app import DifyApp


def init_app(app: DifyApp):
    os.environ["TZ"] = "UTC"
    # Windows does not support time.tzset()
    if hasattr(time, "tzset"):
        time.tzset()
api/extensions/ext_warnings.py
ADDED
@@ -0,0 +1,7 @@
from dify_app import DifyApp


def init_app(app: DifyApp):
    import warnings

    warnings.simplefilter("ignore", ResourceWarning)
api/extensions/storage/aliyun_oss_storage.py
ADDED
@@ -0,0 +1,54 @@
import posixpath
from collections.abc import Generator

import oss2 as aliyun_s3  # type: ignore

from configs import dify_config
from extensions.storage.base_storage import BaseStorage


class AliyunOssStorage(BaseStorage):
    """Implementation for Aliyun OSS storage."""

    def __init__(self):
        super().__init__()
        self.bucket_name = dify_config.ALIYUN_OSS_BUCKET_NAME
        self.folder = dify_config.ALIYUN_OSS_PATH
        oss_auth_method = aliyun_s3.Auth
        region = None
        if dify_config.ALIYUN_OSS_AUTH_VERSION == "v4":
            oss_auth_method = aliyun_s3.AuthV4
            region = dify_config.ALIYUN_OSS_REGION
        oss_auth = oss_auth_method(dify_config.ALIYUN_OSS_ACCESS_KEY, dify_config.ALIYUN_OSS_SECRET_KEY)
        self.client = aliyun_s3.Bucket(
            oss_auth,
            dify_config.ALIYUN_OSS_ENDPOINT,
            self.bucket_name,
            connect_timeout=30,
            region=region,
        )

    def save(self, filename, data):
        self.client.put_object(self.__wrapper_folder_filename(filename), data)

    def load_once(self, filename: str) -> bytes:
        obj = self.client.get_object(self.__wrapper_folder_filename(filename))
        data: bytes = obj.read()
        return data

    def load_stream(self, filename: str) -> Generator:
        obj = self.client.get_object(self.__wrapper_folder_filename(filename))
        while chunk := obj.read(4096):
            yield chunk

    def download(self, filename: str, target_filepath):
        self.client.get_object_to_file(self.__wrapper_folder_filename(filename), target_filepath)

    def exists(self, filename: str):
        return self.client.object_exists(self.__wrapper_folder_filename(filename))

    def delete(self, filename: str):
        self.client.delete_object(self.__wrapper_folder_filename(filename))

    def __wrapper_folder_filename(self, filename: str) -> str:
        return posixpath.join(self.folder, filename) if self.folder else filename
api/extensions/storage/aws_s3_storage.py
ADDED
@@ -0,0 +1,91 @@
import logging
from collections.abc import Generator

import boto3  # type: ignore
from botocore.client import Config  # type: ignore
from botocore.exceptions import ClientError  # type: ignore

from configs import dify_config
from extensions.storage.base_storage import BaseStorage

logger = logging.getLogger(__name__)


class AwsS3Storage(BaseStorage):
    """Implementation for Amazon Web Services S3 storage."""

    def __init__(self):
        super().__init__()
        self.bucket_name = dify_config.S3_BUCKET_NAME
        if dify_config.S3_USE_AWS_MANAGED_IAM:
            logger.info("Using AWS managed IAM role for S3")

            session = boto3.Session()
            region_name = dify_config.S3_REGION
            self.client = session.client(service_name="s3", region_name=region_name)
        else:
            logger.info("Using access key and secret key for S3")

            self.client = boto3.client(
                "s3",
                aws_secret_access_key=dify_config.S3_SECRET_KEY,
                aws_access_key_id=dify_config.S3_ACCESS_KEY,
                endpoint_url=dify_config.S3_ENDPOINT,
                region_name=dify_config.S3_REGION,
                config=Config(
                    s3={"addressing_style": dify_config.S3_ADDRESS_STYLE},
                    request_checksum_calculation="when_required",
                    response_checksum_validation="when_required",
                ),
            )
        # Ensure the bucket exists.
        try:
            self.client.head_bucket(Bucket=self.bucket_name)
        except ClientError as e:
            # If the bucket does not exist, create it.
            if e.response["Error"]["Code"] == "404":
                self.client.create_bucket(Bucket=self.bucket_name)
            # If the bucket is not accessible, skip the check; it may exist but not be accessible to us.
            elif e.response["Error"]["Code"] == "403":
                pass
            else:
                # Any other error is re-raised.
                raise

    def save(self, filename, data):
        self.client.put_object(Bucket=self.bucket_name, Key=filename, Body=data)

    def load_once(self, filename: str) -> bytes:
        try:
            data: bytes = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].read()
        except ClientError as ex:
            if ex.response["Error"]["Code"] == "NoSuchKey":
                raise FileNotFoundError("File not found")
            else:
                raise
        return data

    def load_stream(self, filename: str) -> Generator:
        try:
            response = self.client.get_object(Bucket=self.bucket_name, Key=filename)
            yield from response["Body"].iter_chunks()
        except ClientError as ex:
            if ex.response["Error"]["Code"] == "NoSuchKey":
                raise FileNotFoundError("File not found")
            elif "reached max retries" in str(ex):
                raise ValueError("please do not request the same file too frequently")
            else:
                raise

    def download(self, filename, target_filepath):
        self.client.download_file(self.bucket_name, filename, target_filepath)

    def exists(self, filename):
        try:
            self.client.head_object(Bucket=self.bucket_name, Key=filename)
            return True
        except ClientError:
            return False

    def delete(self, filename):
        self.client.delete_object(Bucket=self.bucket_name, Key=filename)
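One consequence of the error mapping above, sketched here from the caller's side (the key name is made up): because `load_once` and `load_stream` translate S3's `NoSuchKey` into `FileNotFoundError`, callers can handle missing objects without importing botocore.

    from extensions.ext_storage import storage

    try:
        payload = storage.load_once("missing-object.bin")  # hypothetical key
    except FileNotFoundError:
        payload = None  # the S3 runner re-mapped NoSuchKey for us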
api/extensions/storage/azure_blob_storage.py
ADDED
@@ -0,0 +1,84 @@
from collections.abc import Generator
from datetime import UTC, datetime, timedelta
from typing import Optional

from azure.identity import ChainedTokenCredential, DefaultAzureCredential
from azure.storage.blob import AccountSasPermissions, BlobServiceClient, ResourceTypes, generate_account_sas

from configs import dify_config
from extensions.ext_redis import redis_client
from extensions.storage.base_storage import BaseStorage


class AzureBlobStorage(BaseStorage):
    """Implementation for Azure Blob storage."""

    def __init__(self):
        super().__init__()
        self.bucket_name = dify_config.AZURE_BLOB_CONTAINER_NAME
        self.account_url = dify_config.AZURE_BLOB_ACCOUNT_URL
        self.account_name = dify_config.AZURE_BLOB_ACCOUNT_NAME
        self.account_key = dify_config.AZURE_BLOB_ACCOUNT_KEY

        self.credential: Optional[ChainedTokenCredential] = None
        if self.account_key == "managedidentity":
            self.credential = DefaultAzureCredential()
        else:
            self.credential = None

    def save(self, filename, data):
        client = self._sync_client()
        blob_container = client.get_container_client(container=self.bucket_name)
        blob_container.upload_blob(filename, data)

    def load_once(self, filename: str) -> bytes:
        client = self._sync_client()
        blob = client.get_container_client(container=self.bucket_name)
        blob = blob.get_blob_client(blob=filename)
        data: bytes = blob.download_blob().readall()
        return data

    def load_stream(self, filename: str) -> Generator:
        client = self._sync_client()
        blob = client.get_blob_client(container=self.bucket_name, blob=filename)
        blob_data = blob.download_blob()
        yield from blob_data.chunks()

    def download(self, filename, target_filepath):
        client = self._sync_client()

        blob = client.get_blob_client(container=self.bucket_name, blob=filename)
        with open(target_filepath, "wb") as my_blob:
            blob_data = blob.download_blob()
            blob_data.readinto(my_blob)

    def exists(self, filename):
        client = self._sync_client()

        blob = client.get_blob_client(container=self.bucket_name, blob=filename)
        return blob.exists()

    def delete(self, filename):
        client = self._sync_client()

        blob_container = client.get_container_client(container=self.bucket_name)
        blob_container.delete_blob(filename)

    def _sync_client(self):
        if self.account_key == "managedidentity":
            return BlobServiceClient(account_url=self.account_url, credential=self.credential)  # type: ignore

        cache_key = "azure_blob_sas_token_{}_{}".format(self.account_name, self.account_key)
        cache_result = redis_client.get(cache_key)
        if cache_result is not None:
            sas_token = cache_result.decode("utf-8")
        else:
            sas_token = generate_account_sas(
                account_name=self.account_name or "",
                account_key=self.account_key or "",
                resource_types=ResourceTypes(service=True, container=True, object=True),
                permission=AccountSasPermissions(read=True, write=True, delete=True, list=True, add=True, create=True),
                expiry=datetime.now(UTC).replace(tzinfo=None) + timedelta(hours=1),
            )
            redis_client.set(cache_key, sas_token, ex=3000)
        return BlobServiceClient(account_url=self.account_url or "", credential=sas_token)
api/extensions/storage/baidu_obs_storage.py
ADDED
@@ -0,0 +1,57 @@
import base64
import hashlib
from collections.abc import Generator

from baidubce.auth.bce_credentials import BceCredentials  # type: ignore
from baidubce.bce_client_configuration import BceClientConfiguration  # type: ignore
from baidubce.services.bos.bos_client import BosClient  # type: ignore

from configs import dify_config
from extensions.storage.base_storage import BaseStorage


class BaiduObsStorage(BaseStorage):
    """Implementation for Baidu OBS storage."""

    def __init__(self):
        super().__init__()
        self.bucket_name = dify_config.BAIDU_OBS_BUCKET_NAME
        client_config = BceClientConfiguration(
            credentials=BceCredentials(
                access_key_id=dify_config.BAIDU_OBS_ACCESS_KEY,
                secret_access_key=dify_config.BAIDU_OBS_SECRET_KEY,
            ),
            endpoint=dify_config.BAIDU_OBS_ENDPOINT,
        )

        self.client = BosClient(config=client_config)

    def save(self, filename, data):
        md5 = hashlib.md5()
        md5.update(data)
        content_md5 = base64.standard_b64encode(md5.digest())
        self.client.put_object(
            bucket_name=self.bucket_name, key=filename, data=data, content_length=len(data), content_md5=content_md5
        )

    def load_once(self, filename: str) -> bytes:
        response = self.client.get_object(bucket_name=self.bucket_name, key=filename)
        data: bytes = response.data.read()
        return data

    def load_stream(self, filename: str) -> Generator:
        response = self.client.get_object(bucket_name=self.bucket_name, key=filename).data
        while chunk := response.read(4096):
            yield chunk

    def download(self, filename, target_filepath):
        self.client.get_object_to_file(bucket_name=self.bucket_name, key=filename, file_name=target_filepath)

    def exists(self, filename):
        res = self.client.get_object_meta_data(bucket_name=self.bucket_name, key=filename)
        if res is None:
            return False
        return True

    def delete(self, filename):
        self.client.delete_object(bucket_name=self.bucket_name, key=filename)
api/extensions/storage/base_storage.py
ADDED
@@ -0,0 +1,32 @@
"""Abstract interface for file storage implementations."""

from abc import ABC, abstractmethod
from collections.abc import Generator


class BaseStorage(ABC):
    """Interface for file storage."""

    @abstractmethod
    def save(self, filename, data):
        raise NotImplementedError

    @abstractmethod
    def load_once(self, filename: str) -> bytes:
        raise NotImplementedError

    @abstractmethod
    def load_stream(self, filename: str) -> Generator:
        raise NotImplementedError

    @abstractmethod
    def download(self, filename, target_filepath):
        raise NotImplementedError

    @abstractmethod
    def exists(self, filename):
        raise NotImplementedError

    @abstractmethod
    def delete(self, filename):
        raise NotImplementedError
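For orientation, a minimal sketch (not part of the diff) of what a new backend has to provide: subclass BaseStorage and implement the six abstract methods. This hypothetical in-memory backend only illustrates the contract, e.g. for unit tests.

    from collections.abc import Generator

    from extensions.storage.base_storage import BaseStorage


    class InMemoryStorage(BaseStorage):
        """Hypothetical in-memory backend illustrating the BaseStorage contract."""

        def __init__(self):
            super().__init__()
            self._files: dict[str, bytes] = {}

        def save(self, filename, data):
            self._files[filename] = data

        def load_once(self, filename: str) -> bytes:
            if filename not in self._files:
                raise FileNotFoundError("File not found")
            return self._files[filename]

        def load_stream(self, filename: str) -> Generator:
            data = self.load_once(filename)
            for i in range(0, len(data), 4096):
                yield data[i : i + 4096]

        def download(self, filename, target_filepath):
            with open(target_filepath, "wb") as f:
                f.write(self.load_once(filename))

        def exists(self, filename):
            return filename in self._files

        def delete(self, filename):
            self._files.pop(filename, None)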
api/extensions/storage/google_cloud_storage.py
ADDED
@@ -0,0 +1,60 @@
import base64
import io
import json
from collections.abc import Generator

from google.cloud import storage as google_cloud_storage  # type: ignore

from configs import dify_config
from extensions.storage.base_storage import BaseStorage


class GoogleCloudStorage(BaseStorage):
    """Implementation for Google Cloud storage."""

    def __init__(self):
        super().__init__()

        self.bucket_name = dify_config.GOOGLE_STORAGE_BUCKET_NAME
        service_account_json_str = dify_config.GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64
        # If no service account JSON is configured, fall back to Application Default Credentials.
        if service_account_json_str:
            service_account_json = base64.b64decode(service_account_json_str).decode("utf-8")
            # Convert the JSON string into a dict.
            service_account_obj = json.loads(service_account_json)
            self.client = google_cloud_storage.Client.from_service_account_info(service_account_obj)
        else:
            self.client = google_cloud_storage.Client()

    def save(self, filename, data):
        bucket = self.client.get_bucket(self.bucket_name)
        blob = bucket.blob(filename)
        with io.BytesIO(data) as stream:
            blob.upload_from_file(stream)

    def load_once(self, filename: str) -> bytes:
        bucket = self.client.get_bucket(self.bucket_name)
        blob = bucket.get_blob(filename)
        data: bytes = blob.download_as_bytes()
        return data

    def load_stream(self, filename: str) -> Generator:
        bucket = self.client.get_bucket(self.bucket_name)
        blob = bucket.get_blob(filename)
        with blob.open(mode="rb") as blob_stream:
            while chunk := blob_stream.read(4096):
                yield chunk

    def download(self, filename, target_filepath):
        bucket = self.client.get_bucket(self.bucket_name)
        blob = bucket.get_blob(filename)
        blob.download_to_filename(target_filepath)

    def exists(self, filename):
        bucket = self.client.get_bucket(self.bucket_name)
        blob = bucket.blob(filename)
        return blob.exists()

    def delete(self, filename):
        bucket = self.client.get_bucket(self.bucket_name)
        bucket.delete_blob(filename)
api/extensions/storage/huawei_obs_storage.py
ADDED
@@ -0,0 +1,51 @@
from collections.abc import Generator

from obs import ObsClient  # type: ignore

from configs import dify_config
from extensions.storage.base_storage import BaseStorage


class HuaweiObsStorage(BaseStorage):
    """Implementation for Huawei OBS storage."""

    def __init__(self):
        super().__init__()

        self.bucket_name = dify_config.HUAWEI_OBS_BUCKET_NAME
        self.client = ObsClient(
            access_key_id=dify_config.HUAWEI_OBS_ACCESS_KEY,
            secret_access_key=dify_config.HUAWEI_OBS_SECRET_KEY,
            server=dify_config.HUAWEI_OBS_SERVER,
        )

    def save(self, filename, data):
        self.client.putObject(bucketName=self.bucket_name, objectKey=filename, content=data)

    def load_once(self, filename: str) -> bytes:
        data: bytes = self.client.getObject(bucketName=self.bucket_name, objectKey=filename)["body"].response.read()
        return data

    def load_stream(self, filename: str) -> Generator:
        response = self.client.getObject(bucketName=self.bucket_name, objectKey=filename)["body"].response
        while chunk := response.read(4096):
            yield chunk

    def download(self, filename, target_filepath):
        self.client.getObject(bucketName=self.bucket_name, objectKey=filename, downloadPath=target_filepath)

    def exists(self, filename):
        res = self._get_meta(filename)
        if res is None:
            return False
        return True

    def delete(self, filename):
        self.client.deleteObject(bucketName=self.bucket_name, objectKey=filename)

    def _get_meta(self, filename):
        res = self.client.getObjectMetadata(bucketName=self.bucket_name, objectKey=filename)
        if res.status < 300:
            return res
        else:
            return None
api/extensions/storage/opendal_storage.py
ADDED
@@ -0,0 +1,89 @@
import logging
import os
from collections.abc import Generator
from pathlib import Path

import opendal  # type: ignore[import]
from dotenv import dotenv_values

from extensions.storage.base_storage import BaseStorage

logger = logging.getLogger(__name__)


def _get_opendal_kwargs(*, scheme: str, env_file_path: str = ".env", prefix: str = "OPENDAL_"):
    kwargs = {}
    config_prefix = prefix + scheme.upper() + "_"
    for key, value in os.environ.items():
        if key.startswith(config_prefix):
            kwargs[key[len(config_prefix) :].lower()] = value

    file_env_vars: dict = dotenv_values(env_file_path) or {}
    for key, value in file_env_vars.items():
        if key.startswith(config_prefix) and key[len(config_prefix) :].lower() not in kwargs and value:
            kwargs[key[len(config_prefix) :].lower()] = value

    return kwargs


class OpenDALStorage(BaseStorage):
    def __init__(self, scheme: str, **kwargs):
        kwargs = kwargs or _get_opendal_kwargs(scheme=scheme)

        if scheme == "fs":
            root = kwargs.get("root", "storage")
            Path(root).mkdir(parents=True, exist_ok=True)

        self.op = opendal.Operator(scheme=scheme, **kwargs)
        logger.debug(f"opendal operator created with scheme {scheme}")
        retry_layer = opendal.layers.RetryLayer(max_times=3, factor=2.0, jitter=True)
        self.op = self.op.layer(retry_layer)
        logger.debug("added retry layer to opendal operator")

    def save(self, filename: str, data: bytes) -> None:
        self.op.write(path=filename, bs=data)
        logger.debug(f"file {filename} saved")

    def load_once(self, filename: str) -> bytes:
        if not self.exists(filename):
            raise FileNotFoundError("File not found")

        content: bytes = self.op.read(path=filename)
        logger.debug(f"file {filename} loaded")
        return content

    def load_stream(self, filename: str) -> Generator:
        if not self.exists(filename):
            raise FileNotFoundError("File not found")

        batch_size = 4096
        file = self.op.open(path=filename, mode="rb")
        while chunk := file.read(batch_size):
            yield chunk
        logger.debug(f"file {filename} loaded as stream")

    def download(self, filename: str, target_filepath: str):
        if not self.exists(filename):
            raise FileNotFoundError("File not found")

        with Path(target_filepath).open("wb") as f:
            f.write(self.op.read(path=filename))
        logger.debug(f"file {filename} downloaded to {target_filepath}")

    def exists(self, filename: str) -> bool:
        # FIXME: the OpenDAL Python binding does not expose an `exists` method yet, so stat() is used
        # as a workaround without finer-grained error handling; switch to `exists` once it is available.
        # See https://github.com/apache/opendal/blob/main/bindings/python/src/operator.rs
        try:
            res: bool = self.op.stat(path=filename).mode.is_file()
            logger.debug(f"file {filename} checked")
            return res
        except Exception:
            return False

    def delete(self, filename: str):
        if self.exists(filename):
            self.op.delete(path=filename)
            logger.debug(f"file {filename} deleted")
            return
        logger.debug(f"file {filename} not found, skip delete")
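A small sketch (not from the diff) of how the `OPENDAL_<SCHEME>_*` convention above feeds the operator: for the `fs` scheme, `OPENDAL_FS_ROOT` becomes the `root` kwarg via `_get_opendal_kwargs`, and passing kwargs explicitly bypasses the environment lookup entirely.

    from extensions.storage.opendal_storage import OpenDALStorage

    # Explicit kwargs; equivalent to exporting OPENDAL_FS_ROOT=storage before construction.
    op_storage = OpenDALStorage(scheme="fs", root="storage")
    op_storage.save("hello.txt", b"hello")
    assert op_storage.exists("hello.txt")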
api/extensions/storage/oracle_oci_storage.py
ADDED
@@ -0,0 +1,59 @@
from collections.abc import Generator

import boto3  # type: ignore
from botocore.exceptions import ClientError  # type: ignore

from configs import dify_config
from extensions.storage.base_storage import BaseStorage


class OracleOCIStorage(BaseStorage):
    """Implementation for Oracle OCI storage."""

    def __init__(self):
        super().__init__()

        self.bucket_name = dify_config.OCI_BUCKET_NAME
        self.client = boto3.client(
            "s3",
            aws_secret_access_key=dify_config.OCI_SECRET_KEY,
            aws_access_key_id=dify_config.OCI_ACCESS_KEY,
            endpoint_url=dify_config.OCI_ENDPOINT,
            region_name=dify_config.OCI_REGION,
        )

    def save(self, filename, data):
        self.client.put_object(Bucket=self.bucket_name, Key=filename, Body=data)

    def load_once(self, filename: str) -> bytes:
        try:
            data: bytes = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].read()
        except ClientError as ex:
            if ex.response["Error"]["Code"] == "NoSuchKey":
                raise FileNotFoundError("File not found")
            else:
                raise
        return data

    def load_stream(self, filename: str) -> Generator:
        try:
            response = self.client.get_object(Bucket=self.bucket_name, Key=filename)
            yield from response["Body"].iter_chunks()
        except ClientError as ex:
            if ex.response["Error"]["Code"] == "NoSuchKey":
                raise FileNotFoundError("File not found")
            else:
                raise

    def download(self, filename, target_filepath):
        self.client.download_file(self.bucket_name, filename, target_filepath)

    def exists(self, filename):
        try:
            self.client.head_object(Bucket=self.bucket_name, Key=filename)
            return True
        except ClientError:
            return False

    def delete(self, filename):
        self.client.delete_object(Bucket=self.bucket_name, Key=filename)