Spaces:
Running
Running
tanveeshsingh
commited on
Update app.py
Browse files
app.py
CHANGED
@@ -1,13 +1,13 @@
|
|
1 |
import gradio as gr
|
2 |
-
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
|
3 |
-
from jinja2 import Template
|
4 |
-
import torch
|
5 |
-
import json
|
6 |
|
7 |
-
# load the judge
|
8 |
-
device = "cuda:0"
|
9 |
-
model_name = "collinear-ai/collinear-reliability-judge-v8-sauber"
|
10 |
-
model_pipeline = pipeline(task="text-classification", model=model_name, device=device)
|
11 |
|
12 |
# templates
|
13 |
conv_template = Template(
|
@@ -46,30 +46,30 @@ nli_template = Template(
|
|
46 |
|
47 |
# Function to dynamically update inputs based on the input style
|
48 |
def update_inputs(input_style):
|
49 |
-
if input_style == "Conv":
|
50 |
-
|
51 |
-
elif input_style == "NLI":
|
52 |
-
|
53 |
-
elif input_style == "QA format":
|
54 |
-
|
55 |
|
56 |
|
57 |
# Function to judge reliability based on the selected input format
|
58 |
def judge_reliability(input_style, document, conversation, claim, question, answer):
|
59 |
-
with torch.no_grad():
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
|
70 |
-
|
71 |
-
|
72 |
-
return results
|
73 |
|
74 |
|
75 |
|
|
|
1 |
import gradio as gr
|
2 |
+
# from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
|
3 |
+
# from jinja2 import Template
|
4 |
+
# import torch
|
5 |
+
# import json
|
6 |
|
7 |
+
# # load the judge
|
8 |
+
# device = "cuda:0"
|
9 |
+
# model_name = "collinear-ai/collinear-reliability-judge-v8-sauber"
|
10 |
+
# model_pipeline = pipeline(task="text-classification", model=model_name, device=device)
|
11 |
|
12 |
# templates
|
13 |
conv_template = Template(
|
|
|
46 |
|
47 |
# Function to dynamically update inputs based on the input style
|
48 |
def update_inputs(input_style):
|
49 |
+
# if input_style == "Conv":
|
50 |
+
# return gr.update(visible=True), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
|
51 |
+
# elif input_style == "NLI":
|
52 |
+
# return gr.update(visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)
|
53 |
+
# elif input_style == "QA format":
|
54 |
+
# return gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), gr.update(visible=True)
|
55 |
|
56 |
|
57 |
# Function to judge reliability based on the selected input format
|
58 |
def judge_reliability(input_style, document, conversation, claim, question, answer):
|
59 |
+
# with torch.no_grad():
|
60 |
+
# if input_style == "Conv":
|
61 |
+
# conversation = json.loads(conversation)
|
62 |
+
# text = conv_template.render(document=document, conversation=conversation)
|
63 |
+
# elif input_style == "NLI":
|
64 |
+
# text = nli_template.render(document=document, claim=claim)
|
65 |
+
# elif input_style == "QA format":
|
66 |
+
# text = qa_template.render(document=document, question=question, answer=answer)
|
67 |
+
|
68 |
+
# print(text)
|
69 |
|
70 |
+
# outputs = model_pipeline(text)
|
71 |
+
# results = f"Reliability Judge Outputs: {outputs}"
|
72 |
+
# return results
|
73 |
|
74 |
|
75 |
|