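"""PTT Chatbot: a Gradio front end for Mixtral-8x7B-Instruct served through
the Hugging Face Inference API. Answers are grounded on info.md via a
simple keyword-overlap retriever. EXPERIMENTAL."""
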
import gradio as gr
from huggingface_hub import InferenceClient
import random
import textwrap
from collections import Counter
import re

# Define the model to be used
model = "mistralai/Mixtral-8x7B-Instruct-v0.1"
client = InferenceClient(model)
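# Note: InferenceClient picks up a Hugging Face token from the HF_TOKEN
# environment variable or a cached `huggingface-cli login`; without one,
# calls to the public Inference API may be rate-limited.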

# Embedded system prompt
system_prompt_text = "You are a smart and helpful co-worker at PTT and PTTEP, Thailand-based multinational companies. You help with any kind of request and provide a detailed answer to the question. But if you are asked about something unethical or dangerous, you must refuse and suggest a safe and respectful alternative."

# Read the knowledge base that grounds the chatbot's answers
with open("info.md", "r", encoding="utf-8") as file:
    info_md_content = file.read()

# Chunk the info.md content into smaller sections
chunk_size = 2000  # Adjust this size as needed
info_md_chunks = textwrap.wrap(info_md_content, chunk_size)
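# Caveat: textwrap.wrap splits on whitespace and collapses newlines, so
# Markdown structure in info.md (headings, code fences) is not preserved
# across chunk boundaries; for a simple keyword retriever this is acceptable.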

def get_relevant_chunks(query, chunks, top_k=2):
    """Rank chunks by raw term-frequency overlap with the query and return
    the top_k best-scoring chunks joined into one context string."""
    query_tokens = re.findall(r'\w+', query.lower())
    chunk_scores = []

    for chunk in chunks:
        chunk_tokens = re.findall(r'\w+', chunk.lower())
        chunk_counter = Counter(chunk_tokens)
        # Score = total occurrences of the query's tokens in this chunk
        score = sum(chunk_counter[token] for token in query_tokens)
        chunk_scores.append((score, chunk))

    # Sort chunks by score in descending order and keep the top_k
    chunk_scores.sort(reverse=True, key=lambda x: x[0])
    relevant_chunks = [chunk for score, chunk in chunk_scores[:top_k]]

    return "\n\n".join(relevant_chunks)

def format_prompt_mixtral(message, history, info_md_chunks):
    """Assemble the full prompt: retrieved context, system prompt, prior
    turns, then the new message in Mixtral's [INST] ... [/INST] format."""
    prompt = "<s>"
    relevant_chunks = get_relevant_chunks(message, info_md_chunks)
    prompt += f"{relevant_chunks}\n\n"  # Grounding context from info.md
    prompt += f"{system_prompt_text}\n\n"  # System prompt
    # Note: Mixtral's canonical template wraps every instruction in
    # [INST] ... [/INST]; here the context and system prompt are prepended
    # as raw text before the first instruction, which deviates from the
    # template but lets the retrieved context vary per message.

    if history:
        for user_prompt, bot_response in history:
            prompt += f"[INST] {user_prompt} [/INST]"
            prompt += f" {bot_response}</s> "
    prompt += f"[INST] {message} [/INST]"
    return prompt
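
# A rendered prompt looks roughly like:
#   <s>{retrieved chunks}
#
#   {system prompt}
#
#   [INST] earlier question [/INST] earlier answer</s> [INST] new question [/INST]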

def chat_inf(prompt, history, seed, temp, tokens, top_p, rep_p):
    history = history or []
    generate_kwargs = dict(
        temperature=temp,
        max_new_tokens=tokens,
        top_p=top_p,
        repetition_penalty=rep_p,
        do_sample=True,
        seed=int(seed),  # the slider may deliver a float; the API expects an int
    )

    formatted_prompt = format_prompt_mixtral(prompt, history, info_md_chunks)
    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
    output = ""
    for response in stream:
        output += response.token.text
        # Yield the prior history plus the in-progress turn so earlier
        # messages stay visible while the reply streams in
        yield history + [(prompt, output)]
    history.append((prompt, output))
    yield history

def clear_fn():
    """Reset the prompt box and the chat window."""
    return None, None

rand_val = random.randint(1, 1111111111111111)

def check_rand(inp, val):
    """Re-roll the seed when 'Random Seed' is checked; otherwise keep the user's value."""
    if inp:
        return gr.Slider(label="Seed", minimum=1, maximum=1111111111111111, value=random.randint(1, 1111111111111111))
    else:
        return gr.Slider(label="Seed", minimum=1, maximum=1111111111111111, value=int(val))

with gr.Blocks() as app:
    gr.HTML("""<center><h1 style='font-size:xx-large;'>PTT Chatbot</h1><br><h3>Ask anything about PTT</h3><br><h6>EXPERIMENTAL</h6></center>""")
    with gr.Row():
        chat = gr.Chatbot(height=500)
    with gr.Group():
        with gr.Row():
            with gr.Column(scale=3):
                inp = gr.Textbox(label="Prompt", lines=5, interactive=True)
                with gr.Row():
                    with gr.Column(scale=2):
                        btn = gr.Button("Chat")
                    with gr.Column(scale=1):
                        with gr.Group():
                            stop_btn = gr.Button("Stop")
                            clear_btn = gr.Button("Clear")
            with gr.Column(scale=1):
                with gr.Group():
                    rand = gr.Checkbox(label="Random Seed", value=True)
                    seed = gr.Slider(label="Seed", minimum=1, maximum=1111111111111111, step=1, value=rand_val)
                    tokens = gr.Slider(label="Max new tokens", value=3840, minimum=64, maximum=8000, step=64, interactive=True, visible=True, info="Maximum number of new tokens to generate")
                    temp = gr.Slider(label="Temperature", step=0.01, minimum=0.01, maximum=1.0, value=0.9)
                    top_p = gr.Slider(label="Top-P", step=0.01, minimum=0.01, maximum=1.0, value=0.9)
                    rep_p = gr.Slider(label="Repetition Penalty", step=0.1, minimum=0.1, maximum=2.0, value=1.0)

    go = btn.click(check_rand, [rand, seed], seed).then(chat_inf, [inp, chat, seed, temp, tokens, top_p, rep_p], chat)

    stop_btn.click(None, None, None, cancels=[go])
    clear_btn.click(clear_fn, None, [inp, chat])

app.queue(default_concurrency_limit=10).launch(share=True, auth=("admin", "0112358"))