# Hugging Face Space: two-bot legal debate demo (prosecutor vs. defense) built
# with Gradio and the Hugging Face Inference API.
# --- Environment setup -------------------------------------------------------
# NOTE(review): installing packages at runtime is fragile; prefer declaring
# these in requirements.txt. Kept because Hugging Face Spaces sometimes relies
# on this pattern. subprocess.run with an argument list replaces
# os.system("pip install ...") — no shell string, same effect, and the current
# interpreter's pip is used explicitly.
import os
import subprocess
import sys

for _pkg in ("transformers", "datasets", "gradio", "minijinja"):
    subprocess.run([sys.executable, "-m", "pip", "install", _pkg], check=False)

import gradio as gr
from huggingface_hub import InferenceClient
from transformers import pipeline
from datasets import load_dataset

# Training split of a legal Q&A dataset. Loaded at import time but not
# referenced anywhere below — TODO confirm whether it can be removed.
dataset = load_dataset("ibunescu/qa_legal_dataset_train")

# Fill-mask pipeline on Legal-BERT, used as a high-level helper. Also not
# referenced below — TODO confirm whether it can be removed.
pipe = pipeline("fill-mask", model="nlpaueb/legal-bert-base-uncased")

"""
For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
"""
# Shared inference client; both debate roles stream from the same model.
client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """Stream a chat completion for *message* given the prior *history*.

    Builds an OpenAI-style message list (system prompt, alternating
    user/assistant turns, then the new user message) and streams tokens from
    the shared ``client``.

    Yields:
        ``(partial_response, updated_history)`` pairs, where
        ``updated_history`` is *history* plus the in-progress
        ``(message, partial_response)`` turn.
    """
    messages = [{"role": "system", "content": system_message}]
    for user_turn, assistant_turn in history:
        # Skip empty halves of a turn so the API never sees blank content.
        if user_turn:
            messages.append({"role": "user", "content": user_turn})
        if assistant_turn:
            messages.append({"role": "assistant", "content": assistant_turn})
    messages.append({"role": "user", "content": message})

    response = ""
    # BUG FIX: the loop variable was previously named `message`, shadowing the
    # parameter — the yielded history pair then contained a streaming chunk
    # object instead of the user's text. A distinct name fixes that.
    for chunk in client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        token = chunk.choices[0].delta.content
        if token is not None:
            response += token
        yield response, history + [(message, response)]
def score_argument(argument):
    """Heuristically score a legal argument string.

    Counts keyword hits in four categories (merits, laws, precedents,
    verdict), weights them at 2/3/4/5 points respectively, and adds the
    argument's word count. Matching is case-insensitive substring matching,
    so e.g. "act" also matches inside longer words — TODO confirm that is
    acceptable.

    Returns:
        int: the combined heuristic score (higher is "stronger").
    """
    # Keywords related to legal arguments, grouped by category.
    merits_keywords = ["compelling", "convincing", "strong", "solid"]
    laws_keywords = ["statute", "law", "regulation", "act"]
    precedents_keywords = ["precedent", "case", "ruling", "decision"]
    verdict_keywords = ["guilty", "innocent", "verdict", "judgment"]

    # Lowercase once instead of once per keyword group.
    text = argument.lower()

    merits_score = sum(1 for word in merits_keywords if word in text)
    laws_score = sum(1 for word in laws_keywords if word in text)
    precedents_score = sum(1 for word in precedents_keywords if word in text)
    verdict_score = sum(1 for word in verdict_keywords if word in text)
    length_score = len(argument.split())

    # Category weights: merits 2, laws 3, precedents 4, verdict 5.
    return (
        merits_score * 2
        + laws_score * 3
        + precedents_score * 4
        + verdict_score * 5
        + length_score
    )
def color_code(score):
    """Map an argument score to a traffic-light colour name.

    Green for a high score (> 50), yellow for a medium one (> 30),
    red otherwise.
    """
    # Thresholds checked highest-first; first match wins.
    for floor, colour in ((50, "green"), (30, "yellow")):
        if score > floor:
            return colour
    return "red"
# Custom CSS: white background and black text for the container, inputs,
# buttons, chat messages and markdown; hides the flag button and footer;
# defines the small coloured .score-box used for the per-bot score badge.
custom_css = """
body {
    background-color: #ffffff;
    color: #000000;
    font-family: Arial, sans-serif;
}
.gradio-container {
    max-width: 1000px;
    margin: 0 auto;
    padding: 20px;
    background-color: #ffffff;
    border: 1px solid #e0e0e0;
    border-radius: 8px;
    box-shadow: 0 2px 5px rgba(0, 0, 0, 0.1);
}
.gr-button {
    background-color: #ffffff !important;
    border-color: #ffffff !important;
    color: #000000 !important;
}
.gr-button:hover {
    background-color: #ffffff !important;
    border-color: #004085 !important;
}
.gr-input, .gr-textbox, .gr-slider, .gr-markdown, .gr-chatbox {
    border-radius: 4px;
    border: 1px solid #ced4da;
    background-color: #ffffff !important;
    color: #000000 !important;
}
.gr-input:focus, .gr-textbox:focus, .gr-slider:focus {
    border-color: #ffffff;
    outline: 0;
    box-shadow: 0 0 0 0.2rem rgba(255, 255, 255, 1.0);
}
#flagging-button {
    display: none;
}
footer {
    display: none;
}
.chatbox .chat-container .chat-message {
    background-color: #ffffff !important;
    color: #000000 !important;
}
.chatbox .chat-container .chat-message-input {
    background-color: #ffffff !important;
    color: #000000 !important;
}
.gr-markdown {
    background-color: #ffffff !important;
    color: #000000 !important;
}
.gr-markdown h1, .gr-markdown h2, .gr-markdown h3, .gr-markdown h4, .gr-markdown h5, .gr-markdown h6, .gr-markdown p, .gr-markdown ul, .gr-markdown ol, .gr-markdown li {
    color: #000000 !important;
}
.score-box {
    width: 60px;
    height: 60px;
    display: flex;
    align-items: center;
    justify-content: center;
    font-size: 12px;
    font-weight: bold;
    color: black;
    margin: 5px;
}
"""
# Runs one debate round between the two role-prompted bots.
def chat_between_bots(system_message1, system_message2, max_tokens, temperature, top_p, history1, history2, shared_history, message):
    """Ask both bots (prosecutor and defense) the same *message* and score them.

    Drains each bot's streaming generator from ``respond`` and keeps the final
    yielded pair, which holds the complete response and the updated history.
    Appends both labelled responses to *shared_history* (mutated in place).

    Returns:
        Tuple of (response1, response2, history1, history2, shared_history,
        combined_text, prosecutor_score_html, defense_score_html).
    """
    response1, history1 = list(respond(message, history1, system_message1, max_tokens, temperature, top_p))[-1]
    response2, history2 = list(respond(message, history2, system_message2, max_tokens, temperature, top_p))[-1]

    shared_history.append(f"Prosecutor: {response1}")
    shared_history.append(f"Defense Attorney: {response2}")

    # NOTE(review): the original "balance the lengths" step sliced each
    # response to max(len(r1), len(r2)), which can never shorten either
    # string — a no-op, removed here with behavior unchanged. Truncating to
    # the min length instead would cut a response mid-sentence, so no
    # truncation is performed.

    # Score each argument and render a coloured badge for it.
    score1 = score_argument(response1)
    score2 = score_argument(response2)
    prosecutor_color = color_code(score1)
    defense_color = color_code(score2)

    prosecutor_score_color = f"<div class='score-box' style='background-color:{prosecutor_color};'>Score: {score1}</div>"
    defense_score_color = f"<div class='score-box' style='background-color:{defense_color};'>Score: {score2}</div>"

    return response1, response2, history1, history2, shared_history, f"{response1}\n\n{response2}", prosecutor_score_color, defense_score_color
# --- Gradio interface --------------------------------------------------------
with gr.Blocks(css=custom_css) as demo:
    # Per-session conversational state for each bot, plus the shared log.
    history1 = gr.State([])
    history2 = gr.State([])
    shared_history = gr.State([])

    # Single input box feeding both bots the same prompt.
    message = gr.Textbox(label="Shared Input Box")

    # Fixed role prompts and generation parameters held as (constant) state.
    system_message1 = gr.State("You are an expert at legal Prosecution.")
    system_message2 = gr.State("You are an expert at legal Defense.")
    max_tokens = gr.State(512)  # Adjusted to balance response length
    temperature = gr.State(0.7)
    top_p = gr.State(0.95)

    with gr.Row():
        with gr.Column(scale=4):
            prosecutor_response = gr.Textbox(label="Prosecutor's Response", interactive=False)
        with gr.Column(scale=1):
            prosecutor_score_color = gr.HTML()

    with gr.Row():
        with gr.Column(scale=4):
            defense_response = gr.Textbox(label="Defense Attorney's Response", interactive=False)
        with gr.Column(scale=1):
            defense_score_color = gr.HTML()

    shared_argument = gr.Textbox(label="Shared Argument", interactive=False)
    submit_btn = gr.Button("Submit")

    # One click runs a full debate round and updates every output widget.
    submit_btn.click(
        chat_between_bots,
        inputs=[
            system_message1, system_message2, max_tokens, temperature, top_p,
            history1, history2, shared_history, message,
        ],
        outputs=[
            prosecutor_response, defense_response, history1, history2,
            shared_history, shared_argument, prosecutor_score_color,
            defense_score_color,
        ],
    )

if __name__ == "__main__":
    demo.launch()