# Hawkeye_AI / app.py
# Page header from the hosting site: author michaelmc1618, commit 36d7e68
# ("Update app.py", verified), 7.56 kB. Links: raw | history | blame.
import os
import subprocess
import sys

# Best-effort install of the runtime dependencies imported below.
# `sys.executable -m pip` installs into the interpreter that is running this
# script; a bare `pip` found on PATH may belong to a different Python.
# Each package is installed separately and failures are ignored (check=False),
# so a package that is already present still lets the app start.
# NOTE(review): on a hosted Space these usually belong in requirements.txt.
for _package in ("transformers", "datasets", "gradio", "minijinja"):
    subprocess.run([sys.executable, "-m", "pip", "install", _package], check=False)

import gradio as gr
from huggingface_hub import InferenceClient
from transformers import pipeline
from datasets import load_dataset

# NOTE(review): `dataset` and `pipe` are not referenced anywhere else in this
# file as shown, yet both are downloaded/loaded at import time — confirm they
# are needed before keeping the startup cost.
dataset = load_dataset("ibunescu/qa_legal_dataset_train")
# Use a pipeline as a high-level helper
pipe = pipeline("fill-mask", model="nlpaueb/legal-bert-base-uncased")
"""
For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
"""
# Shared remote chat model used by `respond` for both bots.
client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """Stream a chat completion for ``message`` from the module-level ``client``.

    The request is built in this order:

    - a system turn with ``system_message``;
    - each ``(user, assistant)`` pair in ``history``, skipping empty entries;
    - the new user ``message``.

    Args:
        message: The new user message.
        history: Prior ``(user, assistant)`` turns.
        system_message: System prompt placed first in the conversation.
        max_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature passed to ``chat_completion``.
        top_p: Nucleus-sampling value passed to ``chat_completion``.

    Yields:
        tuple: ``(response_so_far, history + [(message, response_so_far)])``.
        One tuple is yielded per streamed chunk. If the stream produces no
        chunks, one tuple with an empty response is yielded, so callers that
        keep only the last value always receive one.
    """
    messages = [{"role": "system", "content": system_message}]
    for turn in history:
        user_turn, assistant_turn = turn[0], turn[1]
        if user_turn:
            messages.append({"role": "user", "content": user_turn})
        if assistant_turn:
            messages.append({"role": "assistant", "content": assistant_turn})
    messages.append({"role": "user", "content": message})

    response = ""
    yielded = False
    # The loop variable must not be called `message`: that would overwrite the
    # user's message before it is recorded in the returned history.
    for chunk in client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        choices = chunk.choices
        # Tolerate chunks that carry no choices instead of raising IndexError.
        token = choices[0].delta.content if choices else None
        if token is not None:
            response += token
        yielded = True
        yield response, history + [(message, response)]
    if not yielded:
        yield response, history + [(message, response)]
def score_argument(argument):
    """Score a legal argument by legal-keyword presence plus its length.

    There are four keyword categories: merits, laws, precedents and verdict.
    Each distinct keyword from a category that appears in ``argument`` adds
    that category's weight: 2, 3, 4 and 5 points respectively. A keyword
    counts only as a whole word, case-insensitively. A trailing "s" is
    accepted, so "laws" and "cases" still count, but "fact" no longer
    matches "act" and "lawyer" no longer matches "law". On top of that,
    one point is added per whitespace-separated token.

    Args:
        argument: The argument text to score.

    Returns:
        int: The total score (0 for an empty string).
    """
    import re

    # Each entry: (keywords, points awarded per distinct keyword present).
    weighted_keywords = (
        (("compelling", "convincing", "strong", "solid"), 2),  # merits
        (("statute", "law", "regulation", "act"), 3),  # laws
        (("precedent", "case", "ruling", "decision"), 4),  # precedents
        (("guilty", "innocent", "verdict", "judgment"), 5),  # verdict
    )
    # Tokenize once; whole-word membership avoids substring false positives.
    words = set(re.findall(r"\w+", argument.lower()))
    keyword_score = sum(
        weight
        for keywords, weight in weighted_keywords
        for keyword in keywords
        if keyword in words or keyword + "s" in words
    )
    # Length bonus: one point per whitespace-separated token.
    length_score = len(argument.split())
    return keyword_score + length_score
def color_code(score):
    """Map a numeric score to a traffic-light colour name.

    Returns "green" when the score is above 50, "yellow" when it is above 30,
    and "red" otherwise.
    """
    # Thresholds are checked from highest to lowest; the first one exceeded wins.
    for floor, colour in ((50, "green"), (30, "yellow")):
        if score > floor:
            return colour
    return "red"
# Custom CSS for white background and black text for input and output boxes.
# Also hides the flagging button and the footer, and sizes the `.score-box`
# badges that chat_between_bots renders next to each response.
# NOTE(review): `.gr-button` has a white background and white border on a white
# page, so buttons are only outlined on hover (#004085); focused inputs also get
# a white border and white focus shadow, i.e. no visible focus ring. Confirm
# this is intended.
# NOTE(review): selectors like `.gr-button` / `.gr-chatbox` assume Gradio's
# class names — verify they match the installed Gradio version's DOM.
custom_css = """
body {
background-color: #ffffff;
color: #000000;
font-family: Arial, sans-serif;
}
.gradio-container {
max-width: 1000px;
margin: 0 auto;
padding: 20px;
background-color: #ffffff;
border: 1px solid #e0e0e0;
border-radius: 8px;
box-shadow: 0 2px 5px rgba(0, 0, 0, 0.1);
}
.gr-button {
background-color: #ffffff !important;
border-color: #ffffff !important;
color: #000000 !important;
}
.gr-button:hover {
background-color: #ffffff !important;
border-color: #004085 !important;
}
.gr-input, .gr-textbox, .gr-slider, .gr-markdown, .gr-chatbox {
border-radius: 4px;
border: 1px solid #ced4da;
background-color: #ffffff !important;
color: #000000 !important;
}
.gr-input:focus, .gr-textbox:focus, .gr-slider:focus {
border-color: #ffffff;
outline: 0;
box-shadow: 0 0 0 0.2rem rgba(255, 255, 255, 1.0);
}
#flagging-button {
display: none;
}
footer {
display: none;
}
.chatbox .chat-container .chat-message {
background-color: #ffffff !important;
color: #000000 !important;
}
.chatbox .chat-container .chat-message-input {
background-color: #ffffff !important;
color: #000000 !important;
}
.gr-markdown {
background-color: #ffffff !important;
color: #000000 !important;
}
.gr-markdown h1, .gr-markdown h2, .gr-markdown h3, .gr-markdown h4, .gr-markdown h5, .gr-markdown h6, .gr-markdown p, .gr-markdown ul, .gr-markdown ol, .gr-markdown li {
color: #000000 !important;
}
.score-box {
width: 60px;
height: 60px;
display: flex;
align-items: center;
justify-content: center;
font-size: 12px;
font-weight: bold;
color: black;
margin: 5px;
}
"""
def _last_yield(generator, default):
    """Exhaust *generator* and return its final value, or *default* if it yields nothing."""
    result = default
    for result in generator:
        pass
    return result


def _score_badge(score):
    """Render *score* as an HTML `.score-box` div coloured by ``color_code``."""
    return f"<div class='score-box' style='background-color:{color_code(score)};'>Score: {score}</div>"


def chat_between_bots(system_message1, system_message2, max_tokens, temperature, top_p, history1, history2, shared_history, message):
    """Send one shared prompt to the prosecutor and defense bots and score both replies.

    Args:
        system_message1: System prompt for the prosecutor bot.
        system_message2: System prompt for the defense bot.
        max_tokens, temperature, top_p: Generation settings passed to ``respond``.
        history1: Prior ``(user, assistant)`` turns of the prosecutor bot.
        history2: Prior ``(user, assistant)`` turns of the defense bot.
        shared_history: Transcript list. It is appended to in place and also
            returned.
        message: The prompt both bots answer.

    Returns:
        tuple: ``(response1, response2, history1, history2, shared_history,
        combined_text, prosecutor_badge_html, defense_badge_html)``.
        ``response1`` and ``response2`` are the length-balanced replies.
    """
    # Run each bot's stream to completion (prosecutor first) and keep the final
    # (response, history). An empty stream yields an empty reply instead of
    # raising IndexError.
    response1, history1 = _last_yield(
        respond(message, history1, system_message1, max_tokens, temperature, top_p),
        ("", history1),
    )
    response2, history2 = _last_yield(
        respond(message, history2, system_message2, max_tokens, temperature, top_p),
        ("", history2),
    )

    # The transcript records each reply in full.
    shared_history.append(f"Prosecutor: {response1}")
    shared_history.append(f"Defense Attorney: {response2}")

    # Balance the displayed and scored replies by trimming both to the shorter
    # reply's length. The score includes a word-count term, so this keeps a
    # longer reply from winning on length alone. Trimming to the longer
    # length would be a no-op.
    balanced_length = min(len(response1), len(response2))
    response1 = response1[:balanced_length]
    response2 = response2[:balanced_length]

    score1 = score_argument(response1)
    score2 = score_argument(response2)
    return (
        response1,
        response2,
        history1,
        history2,
        shared_history,
        f"{response1}\n\n{response2}",
        _score_badge(score1),
        _score_badge(score2),
    )
# Gradio interface: one shared prompt goes to a "prosecutor" bot and a
# "defense" bot; each reply is shown beside an HTML score badge, and both
# replies are combined in the "Shared Argument" box.
with gr.Blocks(css=custom_css) as demo:
    # Per-session conversation state for each bot, plus the combined transcript.
    history1 = gr.State([])
    history2 = gr.State([])
    shared_history = gr.State([])

    message = gr.Textbox(label="Shared Input Box")

    # Fixed prompts and generation settings, fed to the handler as hidden state.
    system_message1 = gr.State("You are an expert at legal Prosecution.")
    system_message2 = gr.State("You are an expert at legal Defense.")
    max_tokens = gr.State(512)  # Adjusted to balance response length
    temperature = gr.State(0.7)
    top_p = gr.State(0.95)

    # Prosecutor row: wide response box plus a narrow score-badge column.
    with gr.Row():
        with gr.Column(scale=4):
            prosecutor_response = gr.Textbox(label="Prosecutor's Response", interactive=False)
        with gr.Column(scale=1):
            prosecutor_score_color = gr.HTML()

    # Defense row: same layout as the prosecutor row.
    with gr.Row():
        with gr.Column(scale=4):
            defense_response = gr.Textbox(label="Defense Attorney's Response", interactive=False)
        with gr.Column(scale=1):
            defense_score_color = gr.HTML()

    shared_argument = gr.Textbox(label="Shared Argument", interactive=False)
    submit_btn = gr.Button("Submit")

    # Order must match chat_between_bots' parameters and return tuple.
    _bot_inputs = [
        system_message1,
        system_message2,
        max_tokens,
        temperature,
        top_p,
        history1,
        history2,
        shared_history,
        message,
    ]
    _bot_outputs = [
        prosecutor_response,
        defense_response,
        history1,
        history2,
        shared_history,
        shared_argument,
        prosecutor_score_color,
        defense_score_color,
    ]
    submit_btn.click(chat_between_bots, inputs=_bot_inputs, outputs=_bot_outputs)

if __name__ == "__main__":
    demo.launch()