tanveeshsingh committed on
Commit
379e15d
·
verified ·
1 Parent(s): 595b81d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -27
app.py CHANGED
@@ -1,13 +1,13 @@
1
  import gradio as gr
2
- from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
3
- from jinja2 import Template
4
- import torch
5
- import json
6
 
7
- # load the judge
8
- device = "cuda:0"
9
- model_name = "collinear-ai/collinear-reliability-judge-v8-sauber"
10
- model_pipeline = pipeline(task="text-classification", model=model_name, device=device)
11
 
12
  # templates
13
  conv_template = Template(
@@ -46,30 +46,30 @@ nli_template = Template(
46
 
47
  # Function to dynamically update inputs based on the input style
48
  def update_inputs(input_style):
49
- if input_style == "Conv":
50
- return gr.update(visible=True), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
51
- elif input_style == "NLI":
52
- return gr.update(visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)
53
- elif input_style == "QA format":
54
- return gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), gr.update(visible=True)
55
 
56
 
57
  # Function to judge reliability based on the selected input format
58
  def judge_reliability(input_style, document, conversation, claim, question, answer):
59
- with torch.no_grad():
60
- if input_style == "Conv":
61
- conversation = json.loads(conversation)
62
- text = conv_template.render(document=document, conversation=conversation)
63
- elif input_style == "NLI":
64
- text = nli_template.render(document=document, claim=claim)
65
- elif input_style == "QA format":
66
- text = qa_template.render(document=document, question=question, answer=answer)
67
-
68
- print(text)
69
 
70
- outputs = model_pipeline(text)
71
- results = f"Reliability Judge Outputs: {outputs}"
72
- return results
73
 
74
 
75
 
 
1
  import gradio as gr
2
+ # from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
3
+ # from jinja2 import Template
4
+ # import torch
5
+ # import json
6
 
7
+ # # load the judge
8
+ # device = "cuda:0"
9
+ # model_name = "collinear-ai/collinear-reliability-judge-v8-sauber"
10
+ # model_pipeline = pipeline(task="text-classification", model=model_name, device=device)
11
 
12
  # templates
13
  conv_template = Template(
 
46
 
47
  # Function to dynamically update inputs based on the input style
48
def update_inputs(input_style):
    """Toggle visibility of the five input widgets for the chosen format.

    Args:
        input_style: one of "Conv", "NLI", or "QA format" (radio selection).

    Returns:
        Five ``gr.update(...)`` objects, in order:
        (document, conversation, claim, question, answer).
    """
    # NOTE(review): this commit commented out the whole body, leaving a
    # `def` with no statements -- a SyntaxError that prevents app.py from
    # importing at all. The visibility logic has no model/GPU dependency,
    # so it is restored here. The document box is visible in every mode.
    if input_style == "Conv":
        return gr.update(visible=True), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
    elif input_style == "NLI":
        return gr.update(visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)
    elif input_style == "QA format":
        return gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), gr.update(visible=True)
    # Unknown style: emit five no-op updates so the Gradio binding that
    # expects five outputs still receives them (the original fell through
    # and returned None, which breaks the event handler).
    return gr.update(), gr.update(), gr.update(), gr.update(), gr.update()
55
 
56
 
57
  # Function to judge reliability based on the selected input format
58
def judge_reliability(input_style, document, conversation, claim, question, answer):
    """Render the selected prompt template and return the judge verdict.

    Args:
        input_style: one of "Conv", "NLI", or "QA format".
        document: source document text (used by all formats).
        conversation: JSON-encoded conversation (Conv format only).
        claim: claim text (NLI format only).
        question: question text (QA format only).
        answer: answer text (QA format only).

    Returns:
        A human-readable result string for the Gradio output box.
    """
    # NOTE(review): this commit disabled the judge (the `model_pipeline`
    # load at the top of the file is commented out), but it did so by
    # commenting out the entire function body -- a `def` with no
    # statements is a SyntaxError, so the app cannot even start. Keep the
    # judge disabled, but return an explicit message so the UI still works.
    #
    # Original implementation, kept for reference until the model is
    # re-enabled:
    # with torch.no_grad():
    #     if input_style == "Conv":
    #         conversation = json.loads(conversation)
    #         text = conv_template.render(document=document, conversation=conversation)
    #     elif input_style == "NLI":
    #         text = nli_template.render(document=document, claim=claim)
    #     elif input_style == "QA format":
    #         text = qa_template.render(document=document, question=question, answer=answer)
    #     outputs = model_pipeline(text)
    #     results = f"Reliability Judge Outputs: {outputs}"
    #     return results
    return "Reliability Judge Outputs: judge model is disabled in this build"
73
 
74
 
75