File size: 4,340 Bytes
bfda157
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1b67131
 
 
 
 
 
 
 
2a0f9e2
1b67131
 
 
 
2a0f9e2
1b67131
 
 
 
2a0f9e2
1b67131
 
 
 
2a0f9e2
1b67131
 
 
 
2a0f9e2
1b67131
 
 
 
2a0f9e2
1b67131
 
 
 
2a0f9e2
1b67131
 
 
 
2a0f9e2
1b67131
 
 
 
2a0f9e2
1b67131
 
 
 
2952328
2a0f9e2
 
 
1b67131
 
5752674
1b67131
2a0f9e2
1b67131
2a0f9e2
1b67131
2a0f9e2
 
1b67131
2a0f9e2
1b67131
bfda157
1b67131
bfda157
1b67131
bfda157
 
 
1b67131
2a0f9e2
 
 
 
 
bfda157
2a0f9e2
bfda157
2a0f9e2
bfda157
2a0f9e2
 
bfda157
2a0f9e2
bfda157
 
 
 
 
98ef618
 
da64073
bfda157
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import numpy as np
import pandas as pd
import requests
import os
import gradio as gr
import json
from dotenv import load_dotenv, find_dotenv
_ = load_dotenv(find_dotenv())

from predibase import Predibase, FinetuningConfig, DeploymentConfig

# Predibase API key (create one at https://app.predibase.com/). It is read from
# the PREDIBASE_API_KEY environment variable, which the load_dotenv() call above
# may populate from a .env file; if the variable is unset, None is passed on.
api_token = os.environ.get('PREDIBASE_API_KEY')
pb = Predibase(api_token=api_token)

# Deployment that serves generation requests, plus the adapter id that
# get_completion() sends along with every request.
_deployment_name = "solar-1-mini-chat-240612"
adapter_id = 'tour-assistant-model/14'
lorax_client = pb.deployments.client(_deployment_name)


def extract_json(gen_text, n_shot_learning=0):
    """Pull a ``{...}`` object out of generated text.

    The text is expected to contain sections that start with the line
    "### Response:" immediately followed by "{", each section ending with a
    "}" followed by a blank line and the next "### " header.

    Args:
        gen_text: Raw generated text.
        n_shot_learning: Number of leading "### Response:" sections to skip
            (e.g. few-shot examples echoed back in the output). ``0`` takes the
            first response. ``-1`` means ``gen_text`` already begins with the
            object, so no response marker is searched for.

    Returns:
        The substring from the opening "{" up to and including the "}" that
        precedes the next "### " header. If no such header follows, returns
        everything from the opening "{" to the end of the text.

    Raises:
        ValueError: If ``n_shot_learning >= 0`` and fewer than
            ``n_shot_learning + 1`` response markers are present.
    """
    marker = "### Response:\n{"
    # Land on the "{" itself (one char before the end of the marker) so the
    # opening brace is kept in the result.
    brace_offset = len(marker) - 1
    terminator = "}\n\n### "

    if n_shot_learning == -1:
        start_index = 0
    else:
        start_index = gen_text.index(marker) + brace_offset
    for _ in range(max(n_shot_learning, 0)):
        gen_text = gen_text[start_index:]
        start_index = gen_text.index(marker) + brace_offset

    # Search only from the chosen start: searching from 0 would hit the end of
    # an earlier (skipped) response and produce an empty slice.
    end_index = gen_text.find(terminator, start_index)
    if end_index == -1:
        # No trailing header (e.g. generation stopped right after "}"):
        # keep the remainder instead of returning an empty string.
        return gen_text[start_index:]
    return gen_text[start_index:end_index + 1]

def get_completion(prompt):
    """Generate a completion for ``prompt`` using the configured adapter.

    Sends the prompt through ``lorax_client`` with ``adapter_id`` applied and a
    cap of 1000 new tokens, and returns only the generated text.
    """
    response = lorax_client.generate(
        prompt,
        adapter_id=adapter_id,
        max_new_tokens=1000,
    )
    return response.generated_text

_SYSTEM_PROMPT = "You are a helpful support assistant. Answer the following question."

# Few-shot (question, answer) pairs rendered ahead of every user question.
# Built once at import time instead of being rebuilt on every request.
_FEW_SHOT_QA = (
    ("What are the benefits of joining a union?",
     "Collective bargaining of salary."),
    ("How much are union dues, and what do they cover?",
     "The union dues for our union is 3%."),
    ("How does the union handle grievances and disputes?",
     "There will be a panel to oversee disputes"),
    ("Will joining a union affect my job security?",
     "No."),
    ("What is the process for joining a union?",
     "Please use the contact form."),
    ("How do unions negotiate contracts with employers?",
     "Our dear leader will handle the negotiations."),
    ("What role do I play as a union member?",
     "You will be invited to our monthly picnics"),
    ("How do unions ensure that employers comply with agreements?",
     "We will have a monthly meeting for members"),
    ("Can I be forced to join a union?",
     "What kind of questions is that! Of course no!"),
    ("What happens if I disagree with the union’s decisions?",
     "We will agree to disagree"),
)


def _format_example(question, answer):
    """Render one few-shot example as system/question/answer chat turns."""
    return (
        f"\n<|im_start|>system\n{_SYSTEM_PROMPT}<|im_end|>\n"
        f"<|im_start|>question\n{question}<|im_end|>\n"
        f"<|im_start|>answer\n{answer}<|im_end|>\n"
    )


def _build_prompt(question):
    """Return the full few-shot prompt, ending with an open ``answer`` turn."""
    examples = "\n" + "".join(
        _format_example(q, a) + "\n" for q, a in _FEW_SHOT_QA
    )
    # NOTE(review): the live question gets an extra newline before <|im_end|>,
    # unlike the few-shot examples. Kept as-is so the text sent to the model is
    # byte-identical to the previous implementation.
    return (
        f"\n{examples}\n\n"
        f"<|im_start|>system\n{_SYSTEM_PROMPT}<|im_end|>\n"
        f"<|im_start|>question\n{question}\n<|im_end|>\n"
        "<|im_start|>answer\n"
    )


def greet(input):
    """Answer a user question with the model, using a few-shot chat prompt.

    Args:
        input: The user's question text (the Gradio ``Question`` textbox value).

    Returns:
        The generated text returned by ``get_completion``.

    Side effects:
        Prints the full prompt to stdout before calling the model.
    """
    total_prompt = _build_prompt(input)
    print("***total_prompt:")
    print(total_prompt)
    return get_completion(total_prompt)

# Single-textbox UI: the question goes to greet() and the reply is shown as text.
iface = gr.Interface(fn=greet, inputs=[gr.Textbox(label="Question", lines=3)], outputs="text")
# NOTE(review): per the Gradio docs, api_open=True leaves the backend REST routes
# open so direct API requests can bypass the queue. Confirm this against the
# installed Gradio version.
iface.queue(api_open=True)

# Launch only when run as a script, so importing this module does not start a server.
if __name__ == "__main__":
    iface.launch()