Spaces:
Sleeping
Sleeping
import os
import subprocess
import sys

# Install runtime dependencies at start-up.
# `sys.executable -m pip` targets the interpreter that is running this script
# (a bare `pip` on PATH may belong to a different environment), and the argv
# list avoids going through a shell.
for _package in (
    "transformers",
    "datasets",
    "gradio",
    "minijinja",
    "PyMuPDF",
    "beautifulsoup4",
    "requests",
):
    # check=False keeps the installs best-effort: a failed install does not
    # raise here; a missing package surfaces at the imports that follow.
    subprocess.run([sys.executable, "-m", "pip", "install", _package], check=False)
import requests | |
from bs4 import BeautifulSoup | |
import gradio as gr | |
from huggingface_hub import InferenceClient | |
from transformers import pipeline | |
from datasets import load_dataset | |
import fitz # PyMuPDF | |
# Legal question-answering dataset, loaded once at import time.
dataset = load_dataset("ibunescu/qa_legal_dataset_train")

# One pipeline per task, each pinned to an explicit model checkpoint.
qa_pipeline = pipeline(
    task="question-answering",
    model="deepset/roberta-base-squad2",
)
summarization_pipeline = pipeline(
    task="summarization",
    model="facebook/bart-large-cnn",
)
mask_filling_pipeline = pipeline(
    task="fill-mask",
    model="nlpaueb/legal-bert-base-uncased",
)

# Shared inference client used by every chat-completion call in this file.
client = InferenceClient(model="HuggingFaceH4/zephyr-7b-beta")
def respond(message, history, system_message, max_tokens, temperature, top_p):
    """Send one user turn to the chat model and return the reply.

    Args:
        message: The new user message.
        history: Earlier turns as ``(user_text, assistant_text)`` pairs;
            falsy entries in a pair are skipped.
        system_message: System prompt placed first in the conversation.
        max_tokens: Forwarded to ``client.chat_completion``.
        temperature: Forwarded to ``client.chat_completion``.
        top_p: Forwarded to ``client.chat_completion``.

    Returns:
        A ``(response, new_history)`` tuple. ``new_history`` is a new list
        made of ``history`` plus the ``(message, response)`` pair; the
        ``history`` argument itself is not mutated.
    """
    messages = [{"role": "system", "content": system_message}]
    for turn in history:
        if turn[0]:
            messages.append({"role": "user", "content": turn[0]})
        if turn[1]:
            messages.append({"role": "assistant", "content": turn[1]})
    messages.append({"role": "user", "content": message})

    response = ""
    # The stream variable is deliberately not called `message`: rebinding the
    # parameter would make the history entry below store the last stream
    # chunk instead of the user's text.
    for chunk in client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        token = chunk.choices[0].delta.content
        if token is not None:
            response += token
    return response, history + [(message, response)]
def generate_case_outcome(prosecutor_response, defense_response):
    """Ask the chat model for a verdict on the two arguments.

    Both arguments are embedded in a single user prompt, the completion is
    streamed with fixed sampling settings, and the streamed text is returned
    as one string.
    """
    user_prompt = (
        f"Prosecutor's argument: {prosecutor_response}\n"
        f"Defense Attorney's argument: {defense_response}\n"
        "Based on verified sources, provide the case details and give the outcome along with reasons."
    )
    conversation = [
        {"role": "system", "content": "Analyze the case and provide the outcome based on verified sources."},
        {"role": "user", "content": user_prompt},
    ]
    stream = client.chat_completion(
        conversation,
        max_tokens=512,
        stream=True,
        temperature=0.6,
        top_p=0.95,
    )
    # Chunks whose delta carries no content (None) are skipped.
    pieces = (chunk.choices[0].delta.content for chunk in stream)
    return "".join(piece for piece in pieces if piece is not None)
def determine_winner(outcome):
    """Pick a winner by counting side mentions in the outcome text.

    The side whose name ("Prosecutor" / "Defense", case-sensitive) occurs
    more often wins; a tie with at least one mention goes to the defense, and
    text mentioning neither yields "No clear winner". The full outcome text
    is appended after the verdict line.
    """
    prosecutor_mentions = outcome.count("Prosecutor")
    defense_mentions = outcome.count("Defense")

    if prosecutor_mentions == 0 and defense_mentions == 0:
        verdict = "No clear winner"
    elif prosecutor_mentions > defense_mentions:
        verdict = "Prosecutor Wins"
    else:
        # Covers both "defense mentioned more often" and an exact tie.
        verdict = "Defense Wins"

    return f"{verdict}\n\nDetailed result: {outcome}"
def chat_between_bots(system_message1, system_message2, max_tokens, temperature, top_p, history1, history2, shared_history, message):
    """Run one debate round: both bots answer `message`, then a verdict is produced.

    Each bot is queried through `respond` with its own system prompt and
    history. Both replies are appended to `shared_history` (mutated in place),
    and the verdict comes from `generate_case_outcome` passed through
    `determine_winner`.

    Returns:
        ``(prosecutor_response, defense_response, history1, history2,
        shared_history, winner)``.
    """
    sampling = (max_tokens, temperature, top_p)
    prosecutor_response, history1 = respond(message, history1, system_message1, *sampling)
    defense_response, history2 = respond(message, history2, system_message2, *sampling)

    shared_history.extend(
        [
            f"Prosecutor: {prosecutor_response}",
            f"Defense Attorney: {defense_response}",
        ]
    )

    winner = determine_winner(generate_case_outcome(prosecutor_response, defense_response))
    return (
        prosecutor_response,
        defense_response,
        history1,
        history2,
        shared_history,
        winner,
    )
def extract_text_from_pdf(pdf_file):
    """Return the concatenated text of every page in a PDF.

    Args:
        pdf_file: Value accepted by ``fitz.open`` (here: a path to the PDF).

    Returns:
        The page texts joined in page order, with no separator added.
    """
    # The context manager closes the document even if a page fails to parse,
    # so the file handle is not leaked.
    with fitz.open(pdf_file) as doc:
        return "".join(page.get_text() for page in doc)
def ask_about_pdf(pdf_text, question):
    """Answer a question using the extracted PDF text as context.

    Args:
        pdf_text: Text previously extracted from the uploaded PDF.
        question: The user's question.

    Returns:
        The ``'answer'`` field of the QA pipeline result, or an empty string
        when either the text or the question is missing/blank.
    """
    # The question-answering pipeline rejects empty inputs, so bail out
    # before calling it rather than surfacing an exception in the UI.
    if not pdf_text or not pdf_text.strip():
        return ""
    if not question or not question.strip():
        return ""
    result = qa_pipeline(question=question, context=pdf_text)
    return result['answer']
def update_pdf_gallery_and_extract_text(pdf_files):
    """Return the uploaded files plus the text of the first one.

    Args:
        pdf_files: ``None``, a single uploaded file (path string or object
            with a ``name`` attribute holding the path), or a list of them.

    Returns:
        ``(files, pdf_text)`` where ``files`` is always a list and
        ``pdf_text`` is the text of the first file, or ``""`` when nothing
        was uploaded.
    """
    # The upload component may hand over nothing, one file, or a list;
    # normalise to a list so the rest of the function has a single shape.
    if pdf_files is None:
        files = []
    elif isinstance(pdf_files, (list, tuple)):
        files = list(pdf_files)
    else:
        files = [pdf_files]

    if not files:
        return files, ""

    first = files[0]
    # Accept both file-like wrappers (path in `.name`) and plain path strings.
    path = getattr(first, "name", first)
    return files, extract_text_from_pdf(path)
def get_top_10_cases():
    """Scrape up to ten case entries from the CourtListener listing page.

    Returns:
        One line per case in the form ``"<name> - Case Number: <number>"``,
        joined with newlines. Entries missing a title link or meta text are
        skipped; the result is an empty string when nothing matches.

    Raises:
        requests.RequestException: If the request fails or times out.
    """
    url = "https://www.courtlistener.com/?order_by=dateFiled+desc"
    # Without a timeout a stalled connection would block this UI callback
    # indefinitely.
    response = requests.get(url, timeout=10)
    soup = BeautifulSoup(response.content, 'html.parser')
    cases = []
    for item in soup.select('.search-result', limit=10):
        title_link = item.select_one('.search-result-title a')
        meta = item.select_one('.search-result-meta')
        # select_one returns None when the selector does not match.
        if title_link is None or meta is None:
            continue
        meta_tokens = meta.text.split()
        if not meta_tokens:
            continue
        case_name = title_link.text.strip()
        # The last whitespace-separated token of the meta line is used as
        # the case number.
        case_number = meta_tokens[-1]
        cases.append(f"{case_name} - Case Number: {case_number}")
    return "\n".join(cases)
def add_message(history, message):
    """Append a multimodal submission to the chat history.

    Every uploaded file becomes its own ``((path,), None)`` turn, followed by
    a ``(text, None)`` turn when the text is not ``None``. Returns the
    mutated history and a cleared, non-interactive textbox.
    """
    history.extend(((uploaded,), None) for uploaded in message["files"])
    text = message["text"]
    if text is not None:
        history.append((text, None))
    locked_input = gr.MultimodalTextbox(value=None, interactive=False)
    return history, locked_input
def bot(history):
    """Generate the assistant reply for the most recent chat turn.

    Args:
        history: Chat turns as ``[user_part, assistant_part]`` pairs. A user
            part may be plain text or a file entry (non-string).

    Returns:
        ``history`` with the assistant part of its last turn set to the
        generated reply. An empty history is returned unchanged.
    """
    # Nothing to answer; also avoids indexing history[-1] on an empty list.
    if not history:
        return history

    system_message = "You are a helpful assistant."
    messages = [{"role": "system", "content": system_message}]
    for turn in history:
        user_part, assistant_part = turn[0], turn[1]
        # File uploads are stored as non-string entries; only plain text is
        # forwarded as chat message content.
        if user_part and isinstance(user_part, str):
            messages.append({"role": "user", "content": user_part})
        if assistant_part and isinstance(assistant_part, str):
            messages.append({"role": "assistant", "content": assistant_part})

    response = ""
    for chunk in client.chat_completion(
        messages,
        max_tokens=150,
        stream=True,
        temperature=0.6,
        top_p=0.95,
    ):
        token = chunk.choices[0].delta.content
        if token is not None:
            response += token

    # Replace the whole turn instead of assigning into it, so this works
    # whether the turn is a list or an immutable tuple.
    history[-1] = [history[-1][0], response]
    return history
def print_like_dislike(x: gr.LikeData):
    """Print the index, value and liked flag of a chatbot like/dislike event."""
    event_fields = (x.index, x.value, x.liked)
    print(*event_fields)
def reset_conversation():
    """Return cleared values for the six components wired to the reset button.

    Returns:
        ``(history1, history2, shared_history, prosecutor_response,
        defense_response, winner)``: three empty lists followed by three
        empty strings. ``shared_history`` must stay a list because the debate
        handler appends to it.
    """
    return [], [], [], "", "", ""
def save_conversation(history1, history2, shared_history):
    """Pass the three conversation states through unchanged.

    Nothing is written anywhere; the same objects are handed back as a tuple.
    """
    states = (history1, history2, shared_history)
    return states
def ask_about_case_outcome(shared_history, question):
    """Answer a question using the debate transcript as context.

    Args:
        shared_history: The transcript, either a list of strings (one per
            statement) or an already-joined string.
        question: The user's question.

    Returns:
        The ``'answer'`` field of the QA pipeline result, or an empty string
        when the transcript or the question is missing/blank.
    """
    # The QA pipeline needs a single string as context, while the transcript
    # is kept as a list of statements.
    if isinstance(shared_history, str):
        context = shared_history
    else:
        context = "\n".join(shared_history or [])

    # The pipeline rejects empty inputs, so bail out before calling it.
    if not context.strip():
        return ""
    if not question or not question.strip():
        return ""

    result = qa_pipeline(question=question, context=context)
    return result['answer']
# Custom CSS passed to gr.Blocks(css=...).
# It forces a white background with black text on the page, buttons, inputs,
# markdown and chat messages, hides the footer and the #flagging-button
# element, and defines two helper classes: .score-box (fixed 60x60 centred
# box) and .scroll-box (max-height 200px with vertical scrolling).
custom_css = """
body {
background-color: #ffffff;
color: #000000;
font-family: Arial, sans-serif;
}
.gradio-container {
max-width: 1000px;
margin: 0 auto;
padding: 20px;
background-color: #ffffff;
border: 1px solid #e0e0e0;
border-radius: 8px;
box-shadow: 0 2px 5px rgba(0, 0, 0, 0.1);
}
.gr-button {
background-color: #ffffff !important;
border-color: #ffffff !important;
color: #000000 !important;
margin: 5px;
}
.gr-button:hover {
background-color: #ffffff !important;
border-color: #004085 !important;
}
.gr-input, .gr-textbox, .gr-slider, .gr-markdown, .gr-chatbox {
border-radius: 4px;
border: 1px solid #ced4da;
background-color: #ffffff !important;
color: #000000 !important;
}
.gr-input:focus, .gr-textbox:focus, .gr-slider:focus {
border-color: #ffffff;
outline: 0;
box-shadow: 0 0 0 0.2rem rgba(255, 255, 255, 1.0);
}
#flagging-button {
display: none;
}
footer {
display: none;
}
.chatbox .chat-container .chat-message {
background-color: #ffffff !important;
color: #000000 !important;
}
.chatbox .chat-container .chat-message-input {
background-color: #ffffff !important;
color: #000000 !important;
}
.gr-markdown {
background-color: #ffffff !important;
color: #000000 !important;
}
.gr-markdown h1, .gr-markdown h2, .gr-markdown h3, .gr-markdown h4, .gr-markdown h5, .gr-markdown h6, .gr-markdown p, .gr-markdown ul, .gr-markdown ol, .gr-markdown li {
color: #000000 !important;
}
.score-box {
width: 60px;
height: 60px;
display: flex;
align-items: center;
justify-content: center;
font-size: 12px;
font-weight: bold;
color: black;
margin: 5px;
}
.scroll-box {
max-height: 200px;
overflow-y: scroll;
border: 1px solid #ced4da;
padding: 10px;
border-radius: 4px;
}
"""
# Gradio UI: three tabs (argument evaluation, PDF management, chatbot) built
# inside a single Blocks app, then queued and launched.
# NOTE(review): the container nesting below is inferred from statement order;
# confirm it against the intended layout.
with gr.Blocks(css=custom_css) as demo:
    # Session state shared between callbacks.
    history1 = gr.State([])  # prosecutor bot history
    history2 = gr.State([])  # defense bot history
    shared_history = gr.State([])  # transcript lines from both bots
    pdf_files = gr.State([])
    pdf_text = gr.State("")  # text extracted from the uploaded PDF
    top_10_cases = gr.State("")
    with gr.Tab("Argument Evaluation"):
        gr.Markdown("# Argument Evaluation", elem_classes=["gr-title"])
        gr.Markdown("## Prosecutor vs. Defense Attorney", elem_classes=["gr-subtitle"])
        with gr.Row():
            with gr.Column(scale=1):
                # Left column: scraped case list shown in a scrollable markdown box.
                top_10_btn = gr.Button("Give me the top 10 cases")
                top_10_output = gr.Markdown(elem_classes=["scroll-box"])
                top_10_btn.click(get_top_10_cases, outputs=top_10_output)
            with gr.Column(scale=2):
                message = gr.Textbox(label="Enter Case Details to Argue", placeholder="Enter case details here...")
                # Fixed prompts and sampling settings, held as State so they can
                # be passed as event inputs without visible controls.
                system_message1 = gr.State("You are an expert Prosecutor. Give your best arguments for the case on behalf of the prosecution.")
                system_message2 = gr.State("You are an expert Defense Attorney. Give your best arguments for the case on behalf of the Defense.")
                max_tokens = gr.State(512)
                temperature = gr.State(0.6)
                top_p = gr.State(0.95)
                with gr.Row():
                    with gr.Column(scale=4):
                        prosecutor_response = gr.Textbox(label="Prosecutor's Response", interactive=True, elem_classes=["scroll-box"])
                    with gr.Column(scale=1):
                        # No callback in this section writes to this component.
                        prosecutor_score_color = gr.HTML()
                    with gr.Column(scale=4):
                        defense_response = gr.Textbox(label="Defense Attorney's Response", interactive=True, elem_classes=["scroll-box"])
                    with gr.Column(scale=1):
                        # No callback in this section writes to this component.
                        defense_score_color = gr.HTML()
                winner = gr.Textbox(label="Winner", interactive=False, elem_classes=["scroll-box"])
                with gr.Row():
                    submit_btn = gr.Button("Argue")
                    clear_btn = gr.Button("Clear and Reset")
                    save_btn = gr.Button("Save Conversation")
                # "Argue" runs one debate round and updates both responses, all
                # three histories and the winner box.
                submit_btn.click(chat_between_bots, inputs=[system_message1, system_message2, max_tokens, temperature, top_p, history1, history2, shared_history, message], outputs=[prosecutor_response, defense_response, history1, history2, shared_history, winner])
                clear_btn.click(reset_conversation, outputs=[history1, history2, shared_history, prosecutor_response, defense_response, winner])
                save_btn.click(save_conversation, inputs=[history1, history2, shared_history], outputs=[history1, history2, shared_history])
                # Follow-up question about the debate, answered from shared_history.
                with gr.Row():
                    case_question = gr.Textbox(label="Ask a Question about the Case Outcome", placeholder="Enter your question here...")
                    case_answer = gr.Textbox(label="Answer", interactive=False, elem_classes=["scroll-box"])
                    ask_case_btn = gr.Button("Ask")
                ask_case_btn.click(ask_about_case_outcome, inputs=[shared_history, case_question], outputs=case_answer)
    with gr.Tab("PDF Management"):
        gr.Markdown("# PDF Management", elem_classes=["gr-title"])
        pdf_upload = gr.File(label="Upload Case Files (PDF)", file_types=[".pdf"])
        pdf_gallery = gr.Gallery(label="PDF Gallery")
        pdf_view = gr.Textbox(label="PDF Content", interactive=False, elem_classes=["scroll-box"])
        pdf_question = gr.Textbox(label="Ask a Question about the PDF", placeholder="Enter your question here...")
        pdf_answer = gr.Textbox(label="Answer", interactive=False, elem_classes=["scroll-box"])
        pdf_upload_btn = gr.Button("Update PDF Gallery")
        pdf_ask_btn = gr.Button("Ask")
        # The button extracts text into the pdf_text state; the change handler
        # below mirrors that state into the read-only viewer.
        pdf_upload_btn.click(update_pdf_gallery_and_extract_text, inputs=[pdf_upload], outputs=[pdf_gallery, pdf_text])
        pdf_text.change(fn=lambda x: x, inputs=pdf_text, outputs=pdf_view)
        pdf_ask_btn.click(ask_about_pdf, inputs=[pdf_text, pdf_question], outputs=pdf_answer)
    with gr.Tab("Chatbot"):
        gr.Markdown("# Chatbot", elem_classes=["gr-title"])
        chatbot = gr.Chatbot(
            [],
            elem_id="chatbot",
            bubble_full_width=False
        )
        chat_input = gr.MultimodalTextbox(interactive=True, file_types=["image"], placeholder="Enter message or upload file...", show_label=False)
        # Event chain: add the user turn (input is locked), generate the bot
        # reply, then make the input interactive again.
        chat_msg = chat_input.submit(add_message, [chatbot, chat_input], [chatbot, chat_input])
        bot_msg = chat_msg.then(bot, chatbot, chatbot, api_name="bot_response")
        bot_msg.then(lambda: gr.MultimodalTextbox(interactive=True), None, [chat_input])
        chatbot.like(print_like_dislike, None, None)
# Enable the request queue, then start the server.
demo.queue()
demo.launch()