rajkumarrrk committed
Commit 54a8d6c
1 Parent(s): b90ce03

Update app.py

Files changed (1)
  1. app.py +86 -26
app.py CHANGED
@@ -4,52 +4,112 @@ from jinja2 import Template
 import torch
 import json
 
-
 # load the judge
 device = "cuda:0"
 model_name = "collinear-ai/collinear-reliability-judge-v5"
 model = AutoModelForSequenceClassification.from_pretrained(model_name).to(device)
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 
-
-# template
-template = Template(
-"""
-# Document:
-{{ document }}
-
-# Conversation:
-{% for message in conversation %}
-{{ message.role }}: {{ message.content }}
-{% endfor %}
-"""
+# templates
+conv_template = Template(
+"""
+# Document:
+{{ document }}
+
+# Conversation:
+{% for message in conversation %}
+{{ message.role }}: {{ message.content }}
+{% endfor %}
+"""
+)
+
+qa_template = Template(
+"""
+# Document:
+{{ document }}
+
+# Question:
+{{ question }}
+
+# Answer:
+{{ answer }}
+"""
 )
 
+nli_template = Template(
+"""
+# Document:
+{{ document }}
+
+# Claim:
+{{ claim }}
+"""
+)
+
+
+# Function to dynamically update inputs based on the input style
+def update_inputs(input_style):
+    if input_style == "Conv":
+        return gr.update(visible=True), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
+    elif input_style == "NLI":
+        return gr.update(visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)
+    elif input_style == "QA format":
+        return gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), gr.update(visible=True)
+
 
-def judge_reliability(document: str, conversation: str):
+# Function to judge reliability based on the selected input format
+def judge_reliability(input_style, document, conversation, claim, question, answer):
     with torch.no_grad():
-        conversation = json.loads(conversation)
-        text = template.render(document=document, conversation=conversation)
+        if input_style == "Conv":
+            conversation = json.loads(conversation)
+            text = conv_template.render(document=document, conversation=conversation)
+        elif input_style == "NLI":
+            text = nli_template.render(document=document, claim=claim)
+        elif input_style == "QA format":
+            text = qa_template.render(document=document, question=question, answer=answer)
+
         print(text)
+
         encoded = tokenizer([text], padding=True)
         input_ids = torch.tensor(encoded.input_ids).to(device)
         attention_mask = torch.tensor(encoded.attention_mask).to(device)
         outputs = model.forward(input_ids=input_ids, attention_mask=attention_mask)
         outputs = torch.softmax(outputs.logits, axis=1)
-        results = f"Reliability Score: {outputs}"
+        results = f"Reliability Score: {outputs[0][1].item()}"
         return results
 
-demo = gr.Interface(
-    fn=judge_reliability,
-    inputs=[
-        gr.Textbox(label="Document", lines=5, value="Chris Voss, was born in Iowa, USA. He is the best negotiator in the world."),
-        gr.Textbox(label="Conversation", lines=5, value='[{"role": "user", "content": "Where are you born?"}, {"role": "assistant", "content": "I am born in Iowa"}]')
-    ],
-    outputs=gr.Textbox(label="Results"),
-    title="Collinear Reliability Judge",
-    description="Enter a document and conversation (json formatted) to judge reliability. Note: this judges if the last assistant turn is faithful according to the given document ",
-    theme="default"
-)
 
+
+# Create the interface using gr.Blocks
+with gr.Blocks() as demo:
+    with gr.Row():
+        input_style_dropdown = gr.Dropdown(label="Input Style", choices=["Conv", "NLI", "QA format"], value="Conv", visible=True)
+
+    with gr.Row():
+        document_input = gr.Textbox(label="Document", lines=5, visible=True, value="Chris Voss is one of the best negotiators in the world. And he was born in Iowa, USA.")
+        conversation_input = gr.Textbox(label="Conversation", lines=5, visible=True, value='[{"role": "user", "content": "Hi Chris Voss, Where are you born?"}, {"role": "assistant", "content": "I am born in Iowa"}]')
+        claim_input = gr.Textbox(label="Claim", lines=5, visible=False, value="CV was born in Iowa")
+        question_input = gr.Textbox(label="Question", lines=5, visible=False, value="Where is Chris Voss born?")
+        answer_input = gr.Textbox(label="Answer", lines=5, visible=False, value="CV was born in Iowa")
+
+    with gr.Row():
+        result_output = gr.Textbox(label="Results")
+
+
+    # Set the visibility of inputs based on the selected input style
+    input_style_dropdown.change(
+        fn=update_inputs,
+        inputs=[input_style_dropdown],
+        outputs=[document_input, conversation_input, claim_input, question_input, answer_input]
+    )
+
+    # Set the function to handle the reliability check
+    gr.Button("Submit").click(
+        fn=judge_reliability,
+        inputs=[input_style_dropdown, document_input, conversation_input, claim_input, question_input, answer_input],
+        outputs=result_output
+    )
+
+# Launch the demo
 if __name__ == "__main__":
     demo.launch()
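
For scoring pairs without the Gradio UI, the following is a minimal standalone sketch that mirrors the new judge_reliability logic: render the same Jinja prompt, tokenize, softmax over the classifier logits, and read the probability at index 1, exactly as the commit's score extraction does. The CPU fallback and the reliability_score helper name are additions for illustration, not part of app.py, and whether index 1 really is the "reliable/consistent" class should be confirmed against the model card.

import json

import torch
from jinja2 import Template
from transformers import AutoModelForSequenceClassification, AutoTokenizer

device = "cuda:0" if torch.cuda.is_available() else "cpu"  # CPU fallback added for illustration
model_name = "collinear-ai/collinear-reliability-judge-v5"
model = AutoModelForSequenceClassification.from_pretrained(model_name).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Same conversation-style prompt as conv_template in app.py
conv_template = Template(
"""
# Document:
{{ document }}

# Conversation:
{% for message in conversation %}
{{ message.role }}: {{ message.content }}
{% endfor %}
"""
)

def reliability_score(text: str) -> float:
    # Tokenize the rendered prompt, run the classifier, and return the softmax
    # probability at index 1, mirroring outputs[0][1].item() in the commit.
    with torch.no_grad():
        encoded = tokenizer([text], padding=True, truncation=True, return_tensors="pt").to(device)
        probs = torch.softmax(model(**encoded).logits, dim=1)
        return probs[0][1].item()

document = "Chris Voss is one of the best negotiators in the world. And he was born in Iowa, USA."
conversation = json.loads(
    '[{"role": "user", "content": "Where are you born?"}, '
    '{"role": "assistant", "content": "I am born in Iowa"}]'
)
print(reliability_score(conv_template.render(document=document, conversation=conversation)))

The NLI and QA styles work the same way: render nli_template or qa_template instead of conv_template and pass the result to the same scoring call.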
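
The Blocks click handler can also be exercised programmatically with gradio_client once the app is running (for example locally via python app.py). This sketch assumes Gradio registers the endpoint under /judge_reliability, derived from the handler's function name; client.view_api() reports the actual endpoint names and argument order, so check it first.

from gradio_client import Client

# Point at the running app; replace with the Space ID if it is hosted on the Hub.
client = Client("http://127.0.0.1:7860/")
client.view_api()  # prints endpoint names and parameter order

result = client.predict(
    "Conv",  # input_style_dropdown
    "Chris Voss is one of the best negotiators in the world. And he was born in Iowa, USA.",  # document
    '[{"role": "user", "content": "Where are you born?"}, {"role": "assistant", "content": "I am born in Iowa"}]',  # conversation
    "",  # claim (ignored for Conv)
    "",  # question (ignored for Conv)
    "",  # answer (ignored for Conv)
    api_name="/judge_reliability",  # assumed name; verify via view_api()
)
print(result)  # formatted as "Reliability Score: <probability>"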