File size: 4,589 Bytes
854f61d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import ast
import json
import logging
from typing import Annotated

import gradio as gr
from decouple import config
from decouple import config as env_config
from fastapi import FastAPI, Form, UploadFile
from hamilton import driver
from pandas import DataFrame
from pydantic import BaseModel

from data_module import data_pipeline, embedding_pipeline, vectorstore
from classification_module import semantic_similarity, dio_support_detector
from enforcement_module import policy_enforcement_decider

app = FastAPI()

# Hamilton driver configuration shared by both DAGs below. Secrets and service
# URLs are read via python-decouple (environment variables / .env file). The
# decouple lookup is used under the alias ``env_config`` so that this
# module-level ``config`` dict does not leave a shadowed, unusable
# ``config()`` function behind.
# NOTE(review): "loader"/"embedding_service" presumably select
# @config.when variants inside the pipeline modules — confirm there.
config = {
    "loader": "pd",
    "embedding_service": "openai",
    "api_key": env_config("OPENAI_API_KEY"),
    "model_name": "text-embedding-ada-002",
    "mistral_public_url": env_config("MISTRAL_PUBLIC_URL"),
    "ner_public_url": env_config("NER_PUBLIC_URL"),
}

# Detection DAG (data loading, embeddings, vector store, classifiers);
# executed by the /detect_radicalization endpoint.
dr = (
    driver.Builder()
    .with_config(config)
    .with_modules(
        data_pipeline,
        embedding_pipeline,
        vectorstore,
        semantic_similarity,
        dio_support_detector,
    )
    .build()
)

# Enforcement DAG; executed by the /generate_policy_enforcement endpoint.
dr_enforcement = (
    driver.Builder()
    .with_config(config)
    .with_modules(policy_enforcement_decider)
    .build()
)

class RadicalizationDetectionRequest(BaseModel):
    """Request body for the /detect_radicalization endpoint."""

    user_text: str  # text to analyse


class PolicyEnforcementRequest(BaseModel):
    """Request body for the /generate_policy_enforcement endpoint."""

    user_text: str  # text the enforcement decision is about
    # Free-form context describing the violation; passed unchanged to the
    # enforcement DAG as its ``violation_context`` input.
    violation_context: dict


class RadicalizationDetectionResponse(BaseModel):
    """Response to the /detect_radicalization endpoint."""

    # Driver results for the ``detect_glorification`` node (a DataFrame
    # result is converted to a dict before being placed here).
    values: dict


class PolicyEnforcementResponse(BaseModel):
    """Response to the /generate_policy_enforcement endpoint."""

    # Driver results for the ``get_enforcement_decision`` node (a DataFrame
    # result is converted to a dict before being placed here).
    values: dict
    
@app.post("/detect_radicalization")
def detect_radicalization(
    request: RadicalizationDetectionRequest,
) -> RadicalizationDetectionResponse:
    """Run the detection DAG on ``request.user_text``.

    Executes the ``detect_glorification`` node of the detection driver with
    ``project_root="."`` and the user text as ``user_input``.

    Returns:
        A ``RadicalizationDetectionResponse`` whose ``values`` holds the driver
        output. When the driver returns a pandas ``DataFrame``, it is converted
        to a column-oriented dict (``{column: {index: value}}``).
    """
    results = dr.execute(
        final_vars=["detect_glorification"],
        inputs={"project_root": ".", "user_input": request.user_text},
    )
    # Use logging rather than print(): the results are request-specific (and
    # may echo user-submitted text), so only emit them when DEBUG is enabled.
    logging.getLogger(__name__).debug(
        "detect_glorification results (%s): %r", type(results).__name__, results
    )
    if isinstance(results, DataFrame):
        results = results.to_dict(orient="dict")
    return RadicalizationDetectionResponse(values=results)

@app.post("/generate_policy_enforcement")
def generate_policy_enforcement(
    request: PolicyEnforcementRequest,
) -> PolicyEnforcementResponse:
    """Run the enforcement DAG for ``request``.

    Executes the ``get_enforcement_decision`` node of the enforcement driver
    with ``project_root="."``, the user text as ``user_input`` and the
    request's ``violation_context``.

    Returns:
        A ``PolicyEnforcementResponse`` whose ``values`` holds the driver
        output. When the driver returns a pandas ``DataFrame``, it is converted
        to a column-oriented dict (``{column: {index: value}}``).
    """
    results = dr_enforcement.execute(
        final_vars=["get_enforcement_decision"],
        inputs={
            "project_root": ".",
            "user_input": request.user_text,
            "violation_context": request.violation_context,
        },
    )
    # Use logging rather than print(): the results are request-specific (and
    # may echo user-submitted text), so only emit them when DEBUG is enabled.
    logging.getLogger(__name__).debug(
        "get_enforcement_decision results (%s): %r", type(results).__name__, results
    )
    if isinstance(results, DataFrame):
        results = results.to_dict(orient="dict")
    return PolicyEnforcementResponse(values=results)

# Gradio Interface Functions
def gradio_detect_radicalization(user_text: str):
    """Gradio adapter for the detection endpoint.

    Wraps the raw textbox value in a ``RadicalizationDetectionRequest``, calls
    the endpoint function directly (no HTTP round-trip) and hands back the
    response's ``values`` dict for Gradio's JSON output.
    """
    return detect_radicalization(
        RadicalizationDetectionRequest(user_text=user_text)
    ).values

def gradio_generate_policy_enforcement(user_text: str, violation_context: str):
    """Gradio adapter for the policy-enforcement endpoint.

    Args:
        user_text: The text the enforcement decision is about.
        violation_context: Textbox contents describing the violation, as a
            JSON object. For backward compatibility, a Python dict literal
            (single quotes, ``True``/``None``) is also accepted. An empty or
            blank string is treated as an empty context ``{}``.

    Returns:
        The ``values`` dict of the enforcement response.

    Raises:
        ValueError: If ``violation_context`` cannot be parsed, or does not
            describe a mapping.
    """
    text = (violation_context or "").strip()
    if not text:
        context = {}
    else:
        # The context comes straight from an untrusted UI textbox, so it must
        # never be eval()'d. Parse it as JSON first; fall back to a
        # literal-only parse for Python-style dicts that the previous
        # eval()-based code accepted. literal_eval cannot call functions.
        try:
            context = json.loads(text)
        except json.JSONDecodeError:
            try:
                context = ast.literal_eval(text)
            except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError) as err:
                raise ValueError(
                    'violation_context must be a JSON object, e.g. {"key": "value"}'
                ) from err
    if not isinstance(context, dict):
        raise ValueError("violation_context must describe an object/mapping")
    request = PolicyEnforcementRequest(user_text=user_text, violation_context=context)
    return generate_policy_enforcement(request).values

# Define the Gradio interface
# Tab 1: free-text radicalization / glorification detection; the adapter's
# dict result is rendered by Gradio's JSON component.
iface = gr.Interface(
    title="Radicalization Detection",
    description="Enter text to detect glorification or radicalization.",
    fn=gradio_detect_radicalization,
    inputs="text",
    outputs="json",
)

# Tab 2: policy enforcement. The first textbox is the user text; the second is
# the violation context string, which the adapter parses into a dict.
iface2 = gr.Interface(
    title="Policy Enforcement Decision",
    description="Enter user text and context to generate a policy enforcement decision.",
    fn=gradio_generate_policy_enforcement,
    inputs=["text", "text"],
    outputs="json",
)

# Serve both tools from a single tabbed UI.
iface_combined = gr.TabbedInterface(
    [iface, iface2],
    tab_names=["Detect Radicalization", "Policy Enforcement"],
)

# The Gradio UI is started only from the __main__ guard below. By default,
# ``launch()`` blocks the calling thread while the UI server runs. Launching
# at import time would therefore stall any importer (e.g. ``uvicorn
# module:app``), and would keep the FastAPI thread from starting until the UI
# had stopped.


if __name__ == "__main__":
    import uvicorn
    from threading import Thread

    def run_fastapi():
        """Serve the FastAPI REST endpoints on 0.0.0.0:8000."""
        uvicorn.run(app, host="0.0.0.0", port=8000)

    # daemon=True: the API thread must not keep the process alive after the
    # Gradio server (which owns the main thread) shuts down, e.g. on Ctrl-C.
    fastapi_thread = Thread(target=run_fastapi, daemon=True)
    fastapi_thread.start()

    # Launch the tabbed Gradio UI on 0.0.0.0:7861. This call blocks until the
    # UI server is stopped.
    iface_combined.launch(server_name="0.0.0.0", server_port=7861)