import ast

import gradio as gr
from decouple import config
from fastapi import FastAPI
from hamilton import driver
from pandas import DataFrame
from pydantic import BaseModel

from classification_module import semantic_similarity, dio_support_detector
from data_module import data_pipeline, embedding_pipeline, vectorstore
from enforcement_module import policy_enforcement_decider

app = FastAPI()

# Configuration for the Hamilton drivers; secrets are read from the environment via python-decouple.
hamilton_config = {
    "loader": "pd",
    "embedding_service": "openai",
    "api_key": config("OPENAI_API_KEY"),
    "model_name": "text-embedding-ada-002",
    "mistral_public_url": config("MISTRAL_PUBLIC_URL"),
    "ner_public_url": config("NER_PUBLIC_URL"),
}

# Driver for the detection flow: data loading, embeddings, vector store, and classification.
dr = (
    driver.Builder()
    .with_config(hamilton_config)
    .with_modules(data_pipeline, embedding_pipeline, vectorstore, semantic_similarity, dio_support_detector)
    .build()
)

# Separate driver for the policy-enforcement flow.
dr_enforcement = (
    driver.Builder()
    .with_config(hamilton_config)
    .with_modules(policy_enforcement_decider)
    .build()
)


class RadicalizationDetectionRequest(BaseModel):
    """Request body for the /detect_radicalization endpoint."""
    user_text: str


class PolicyEnforcementRequest(BaseModel):
    """Request body for the /generate_policy_enforcement endpoint."""
    user_text: str
    violation_context: dict


class RadicalizationDetectionResponse(BaseModel):
    """Response from the /detect_radicalization endpoint."""
    values: dict


class PolicyEnforcementResponse(BaseModel):
    """Response from the /generate_policy_enforcement endpoint."""
    values: dict


@app.post("/detect_radicalization")
def detect_radicalization(
    request: RadicalizationDetectionRequest,
) -> RadicalizationDetectionResponse:
    results = dr.execute(
        final_vars=["detect_glorification"],
        inputs={"project_root": ".", "user_input": request.user_text},
    )
    print(results)
    print(type(results))
    # Hamilton may return a DataFrame; convert it to a plain dict for the JSON response.
    if isinstance(results, DataFrame):
        results = results.to_dict(orient="dict")
    return RadicalizationDetectionResponse(values=results)


@app.post("/generate_policy_enforcement")
def generate_policy_enforcement(
    request: PolicyEnforcementRequest,
) -> PolicyEnforcementResponse:
    results = dr_enforcement.execute(
        final_vars=["get_enforcement_decision"],
        inputs={
            "project_root": ".",
            "user_input": request.user_text,
            "violation_context": request.violation_context,
        },
    )
    print(results)
    print(type(results))
    # As above, normalize a DataFrame result into a plain dict before responding.
    if isinstance(results, DataFrame):
        results = results.to_dict(orient="dict")
    return PolicyEnforcementResponse(values=results)


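# Example requests against the running API (the port matches the uvicorn call in
# __main__ below; the payload values and the violation_context keys are illustrative
# only and depend on what the Hamilton modules expect):
#
#   curl -X POST http://localhost:8000/detect_radicalization \
#        -H "Content-Type: application/json" \
#        -d '{"user_text": "text to screen"}'
#
#   curl -X POST http://localhost:8000/generate_policy_enforcement \
#        -H "Content-Type: application/json" \
#        -d '{"user_text": "text to screen", "violation_context": {"violation": "glorification"}}'

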
def gradio_detect_radicalization(user_text: str):
    request = RadicalizationDetectionRequest(user_text=user_text)
    response = detect_radicalization(request)
    return response.values


def gradio_generate_policy_enforcement(user_text: str, violation_context: str):
    # Parse the context as a Python literal; ast.literal_eval avoids the arbitrary
    # code execution that eval() would allow on user-supplied text.
    context_dict = ast.literal_eval(violation_context)
    request = PolicyEnforcementRequest(user_text=user_text, violation_context=context_dict)
    response = generate_policy_enforcement(request)
    return response.values


iface = gr.Interface(
    fn=gradio_detect_radicalization,
    inputs="text",
    outputs="json",
    title="Radicalization Detection",
    description="Enter text to detect glorification or radicalization.",
)

iface2 = gr.Interface(
    fn=gradio_generate_policy_enforcement,
    inputs=["text", "text"],
    outputs="json",
    title="Policy Enforcement Decision",
    description="Enter user text and context to generate a policy enforcement decision.",
)

iface_combined = gr.TabbedInterface([iface, iface2], ["Detect Radicalization", "Policy Enforcement"])


if __name__ == "__main__":
    import uvicorn
    from threading import Thread

    def run_fastapi():
        # Serve the FastAPI endpoints on port 8000 in a background thread.
        uvicorn.run(app, host="0.0.0.0", port=8000)

    fastapi_thread = Thread(target=run_fastapi, daemon=True)
    fastapi_thread.start()

    # Launch the combined Gradio UI on port 7861; this call blocks the main thread.
    iface_combined.launch(server_name="0.0.0.0", server_port=7861)
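
# To run both servers (assuming this file is saved as app.py next to the
# data_module, classification_module, and enforcement_module files):
#   python app.py
# The FastAPI API then listens on port 8000 and the Gradio UI on port 7861.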