File size: 7,564 Bytes
fb37d48
36d7e68
 
 
 
d625ced
 
 
36d7e68
 
d625ced
36d7e68
8b780e6
36d7e68
 
3ebfdcb
36d7e68
 
 
8b780e6
 
36d7e68
 
 
 
 
 
 
 
8b780e6
36d7e68
8b780e6
 
 
 
 
36d7e68
8b780e6
 
 
 
 
 
 
 
 
 
 
36d7e68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8b780e6
3ebfdcb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
import os
# NOTE(review): these installs run through a shell on every start-up and their
# exit statuses are discarded, so a failed install is only noticed when the
# imports below fail. They must stay ABOVE the third-party imports. Declaring
# the packages in requirements.txt would be the usual alternative — confirm.
os.system('pip install transformers')
os.system('pip install datasets')
os.system('pip install gradio')
os.system('pip install minijinja')

import gradio as gr
from huggingface_hub import InferenceClient
from transformers import pipeline
from datasets import load_dataset

# Loaded at import time; NOTE(review): `dataset` is not referenced by any of the
# code visible in this file — confirm it is still needed.
dataset = load_dataset("ibunescu/qa_legal_dataset_train")

# Use a pipeline as a high-level helper
# Fill-mask pipeline over a legal-domain BERT model, built at import time.
# NOTE(review): `pipe` is not referenced by the visible code either; presumably
# it only costs start-up time and memory — confirm before removing.
pipe = pipeline("fill-mask", model="nlpaueb/legal-bert-base-uncased")

"""
For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
"""
# Shared chat-completion client for the zephyr-7b-beta model; `respond` streams
# completions from it.
client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")

def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
    *,
    chat_client=None,
):
    """Stream a chat completion for ``message`` and yield partial results.

    The conversation sent to the model is the system prompt, then the prior
    turns from ``history`` (empty entries skipped), then the new user message.
    The request is made with ``stream=True``.

    Args:
        message: The new user message.
        history: Prior ``(user, assistant)`` turns. It is not mutated.
        system_message: System prompt placed first in the conversation.
        max_tokens: Forwarded unchanged to ``chat_completion``.
        temperature: Forwarded unchanged to ``chat_completion``.
        top_p: Forwarded unchanged to ``chat_completion``.
        chat_client: Optional object exposing ``chat_completion``. Defaults to
            the module-level ``client``; passing another one lets each bot use
            a different model and makes the function testable.

    Yields:
        ``(response, new_history)`` after every streamed chunk. ``response``
        is the text accumulated so far. ``new_history`` is ``history`` plus
        the ``(message, response)`` turn.
    """
    if chat_client is None:
        chat_client = client

    messages = [{"role": "system", "content": system_message}]

    for turn in history:
        if turn[0]:
            messages.append({"role": "user", "content": turn[0]})
        if turn[1]:
            messages.append({"role": "assistant", "content": turn[1]})

    messages.append({"role": "user", "content": message})

    response = ""
    # The loop variable must NOT be called ``message``. Reusing that name
    # overwrote the user's message, so the stream chunk object leaked into the
    # returned history.
    for chunk in chat_client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        # Tolerate chunks without choices instead of raising IndexError.
        token = chunk.choices[0].delta.content if chunk.choices else None
        if token is not None:
            response += token
        yield response, history + [(message, response)]

def score_argument(argument):
    """Heuristically score a legal argument from keyword hits and its length.

    There are four keyword groups. Each keyword counts at most once, and only
    as a whole word; an optional plural "s"/"es" suffix is accepted. So "act"
    no longer matches "fact" or "contract", and "law" no longer matches
    "lawyer". Each hit is weighted by group:
    merits 2, laws 3, precedents 4, verdict 5.
    The number of whitespace-separated words is added on top.

    Args:
        argument: The argument text.

    Returns:
        int: Weighted keyword score plus word count.
    """
    import re

    # Keywords related to legal arguments
    merits_keywords = ["compelling", "convincing", "strong", "solid"]
    laws_keywords = ["statute", "law", "regulation", "act"]
    precedents_keywords = ["precedent", "case", "ruling", "decision"]
    verdict_keywords = ["guilty", "innocent", "verdict", "judgment"]

    text = argument.lower()  # lowercase once rather than once per keyword

    def _present(keywords):
        """Count how many of ``keywords`` appear as whole words in ``text``."""
        return sum(
            1
            for kw in keywords
            if re.search(rf"\b{re.escape(kw)}(?:e?s)?\b", text)
        )

    length_score = len(argument.split())

    # Weighted keyword contributions
    merits_value = _present(merits_keywords) * 2
    laws_value = _present(laws_keywords) * 3
    precedents_value = _present(precedents_keywords) * 4
    verdict_value = _present(verdict_keywords) * 5

    # Total score: Sum of all individual scores
    return merits_value + laws_value + precedents_value + verdict_value + length_score

def color_code(score):
    """Translate a numeric argument score into a traffic-light colour name.

    Returns "green" above 50, "yellow" above 30, and "red" otherwise.
    """
    # Thresholds are checked from highest to lowest; the first one exceeded wins.
    for threshold, colour in ((50, "green"), (30, "yellow")):
        if score > threshold:
            return colour
    return "red"

# Custom CSS for white background and black text for input and output boxes.
# The focus rule keeps a visible ring (in the same #004085 used for button
# hover). The previous white-on-white border/shadow combined with `outline: 0`
# left keyboard users with no visible focus indicator.
custom_css = """
body {
    background-color: #ffffff;
    color: #000000;
    font-family: Arial, sans-serif;
}
.gradio-container {
    max-width: 1000px;
    margin: 0 auto;
    padding: 20px;
    background-color: #ffffff;
    border: 1px solid #e0e0e0;
    border-radius: 8px;
    box-shadow: 0 2px 5px rgba(0, 0, 0, 0.1);
}
.gr-button {
    background-color: #ffffff !important;
    border-color: #ffffff !important;
    color: #000000 !important;
}
.gr-button:hover {
    background-color: #ffffff !important;
    border-color: #004085 !important;
}
.gr-input, .gr-textbox, .gr-slider, .gr-markdown, .gr-chatbox {
    border-radius: 4px;
    border: 1px solid #ced4da;
    background-color: #ffffff !important;
    color: #000000 !important;
}
.gr-input:focus, .gr-textbox:focus, .gr-slider:focus {
    border-color: #004085;
    outline: 0;
    box-shadow: 0 0 0 0.2rem rgba(0, 64, 133, 0.25);
}
#flagging-button {
    display: none;
}
footer {
    display: none;
}
.chatbox .chat-container .chat-message {
    background-color: #ffffff !important;
    color: #000000 !important;
}
.chatbox .chat-container .chat-message-input {
    background-color: #ffffff !important;
    color: #000000 !important;
}
.gr-markdown {
    background-color: #ffffff !important;
    color: #000000 !important;
}
.gr-markdown h1, .gr-markdown h2, .gr-markdown h3, .gr-markdown h4, .gr-markdown h5, .gr-markdown h6, .gr-markdown p, .gr-markdown ul, .gr-markdown ol, .gr-markdown li {
    color: #000000 !important;
}
.score-box {
    width: 60px;
    height: 60px;
    display: flex;
    align-items: center;
    justify-content: center;
    font-size: 12px;
    font-weight: bold;
    color: black;
    margin: 5px;
}
"""

def _final_output(stream, history):
    """Drain a ``respond`` generator and return its last ``(response, history)``.

    Only the latest item is held in memory. An empty stream yields
    ``("", history)`` instead of raising ``IndexError``.
    """
    result = ("", history)
    for result in stream:
        pass
    return result


# Function to facilitate the conversation between the two chatbots
def chat_between_bots(system_message1, system_message2, max_tokens, temperature, top_p, history1, history2, shared_history, message):
    """Send ``message`` to the prosecutor and defense bots and score both replies.

    Args:
        system_message1: System prompt for the prosecutor bot.
        system_message2: System prompt for the defense bot.
        max_tokens: Forwarded to ``respond``.
        temperature: Forwarded to ``respond``.
        top_p: Forwarded to ``respond``.
        history1: Prior turns of the prosecutor bot.
        history2: Prior turns of the defense bot.
        shared_history: Combined transcript of previous rounds. It is not
            mutated; an extended copy is returned.
        message: The user's prompt, sent unchanged to both bots.

    Returns:
        An 8-tuple in the order of the Gradio outputs:
        ``(response1, response2, history1, history2, shared_history,
        combined_text, prosecutor_score_html, defense_score_html)``.
    """
    response1, history1 = _final_output(
        respond(message, history1, system_message1, max_tokens, temperature, top_p), history1
    )
    response2, history2 = _final_output(
        respond(message, history2, system_message2, max_tokens, temperature, top_p), history2
    )

    # Ensure the responses are balanced by trimming both to the shorter
    # length. The previous max() made this a no-op. Skip when either reply is
    # empty so a blank/failed reply does not erase the other one.
    if response1 and response2:
        max_length = min(len(response1), len(response2))
        response1 = response1[:max_length]
        response2 = response2[:max_length]

    # Record the displayed (balanced) text in the shared transcript.
    shared_history = shared_history + [
        f"Prosecutor: {response1}",
        f"Defense Attorney: {response2}",
    ]

    # Calculate scores and scoring matrices
    score1 = score_argument(response1)
    score2 = score_argument(response2)

    prosecutor_color = color_code(score1)
    defense_color = color_code(score2)

    prosecutor_score_color = f"<div class='score-box' style='background-color:{prosecutor_color};'>Score: {score1}</div>"
    defense_score_color = f"<div class='score-box' style='background-color:{defense_color};'>Score: {score2}</div>"

    return response1, response2, history1, history2, shared_history, f"{response1}\n\n{response2}", prosecutor_score_color, defense_score_color

# Gradio interface
# Components render in creation order, top to bottom: shared input box,
# prosecutor response + score badge, defense response + score badge,
# combined argument box, Submit button.
with gr.Blocks(css=custom_css) as demo:
    # Conversation memory for the prosecutor bot (history1), the defense bot
    # (history2) and the combined transcript, kept in gr.State.
    history1 = gr.State([])
    history2 = gr.State([])
    shared_history = gr.State([])
    message = gr.Textbox(label="Shared Input Box")
    # Fixed prompts and generation settings; held in State, so they are not
    # editable from this UI.
    system_message1 = gr.State("You are an expert at legal Prosecution.")
    system_message2 = gr.State("You are an expert at legal Defense.")
    max_tokens = gr.State(512)  # Adjusted to balance response length
    temperature = gr.State(0.7)
    top_p = gr.State(0.95)
    
    # Each bot gets a wide response column and a narrow column for the
    # colour-coded score HTML produced by chat_between_bots.
    with gr.Row():
        with gr.Column(scale=4):
            prosecutor_response = gr.Textbox(label="Prosecutor's Response", interactive=False)
        with gr.Column(scale=1):
            prosecutor_score_color = gr.HTML()
            
    with gr.Row():
        with gr.Column(scale=4):
            defense_response = gr.Textbox(label="Defense Attorney's Response", interactive=False)
        with gr.Column(scale=1):
            defense_score_color = gr.HTML()
    
    shared_argument = gr.Textbox(label="Shared Argument", interactive=False)
    
    submit_btn = gr.Button("Submit")

    # `inputs` must follow chat_between_bots' parameter order and `outputs`
    # its return-tuple order; both are positional. Only the button is wired;
    # no submit handler is attached to the textbox here.
    submit_btn.click(chat_between_bots, inputs=[system_message1, system_message2, max_tokens, temperature, top_p, history1, history2, shared_history, message], outputs=[prosecutor_response, defense_response, history1, history2, shared_history, shared_argument, prosecutor_score_color, defense_score_color])
    
# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()