File size: 2,300 Bytes
7d587fb
 
eb8aef0
 
 
 
 
 
 
 
 
 
7d587fb
 
 
eb8aef0
 
 
 
 
7d587fb
eb8aef0
 
 
 
 
 
 
 
 
7d587fb
eb8aef0
 
 
 
 
 
 
7d587fb
eb8aef0
7d587fb
eb8aef0
 
 
 
 
7d587fb
eb8aef0
7d587fb
 
 
 
 
eb8aef0
7d587fb
eb8aef0
7d587fb
eb5eb52
 
 
 
 
7d587fb
 
 
eb8aef0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import gradio as gr

from langchain.embeddings import SentenceTransformerEmbeddings
from langchain.vectorstores import Chroma

from transformers import T5Tokenizer, T5ForConditionalGeneration

# Checkpoint identifiers for the two models this app loads.
_EMBEDDING_MODEL_NAME = "msmarco-distilbert-base-v4"
_GENERATOR_MODEL_NAME = "google/flan-t5-large"

# Retrieval: open the Chroma store persisted in the "embeddings" directory,
# embedding queries with the sentence-transformer above.
# NOTE(review): this embedding model presumably must match the one the store
# was built with — confirm against the indexing script.
embeddings = SentenceTransformerEmbeddings(model_name=_EMBEDDING_MODEL_NAME)
db = Chroma(
    persist_directory="embeddings",
    embedding_function=embeddings,
)

# Generation: Flan-T5 tokenizer and seq2seq model from the same checkpoint.
tokenizer = T5Tokenizer.from_pretrained(_GENERATOR_MODEL_NAME)
model = T5ForConditionalGeneration.from_pretrained(_GENERATOR_MODEL_NAME)


def respond(
        message,
        history: list[tuple[str, str]],
        max_tokens,
        temperature,
        repetition_penalty,
        max_context_words=None,
):
    """Answer ``message`` with Flan-T5, grounded on documents retrieved from Chroma.

    This is the ``gr.ChatInterface`` callback. Gradio passes the user message,
    the chat history, and then the values of ``additional_inputs`` in order.

    Args:
        message: The user's query. It is also used as the similarity-search query.
        history: Previous chat turns. It is accepted because ChatInterface
            always passes it but is not used: each answer depends only on the
            current query and the retrieved documents.
        max_tokens: Passed to ``model.generate`` as ``max_new_tokens``.
        temperature: Sampling temperature for ``model.generate`` (sampling is on).
        repetition_penalty: Passed through to ``model.generate``.
        max_context_words: Optional budget, in whitespace-separated words, for
            the retrieved-document section of the prompt. Documents are added
            in retrieval order until the next one would exceed the budget.
            ``None`` (the default, and what ChatInterface uses) includes every
            retrieved document.

    Returns:
        The decoded model output with special tokens removed.
    """
    matching_docs = db.similarity_search(message)

    sections = []
    words_used = 0
    for i, doc in enumerate(matching_docs, start=1):
        section = f"Document {i}:\n{doc.page_content}\n\n"
        section_words = len(section.split())
        # The tokenizer call below does not truncate, and Flan-T5 is typically
        # run on inputs of at most 512 tokens, so an unbounded context can bury
        # the query. Let callers cap it.
        if (max_context_words is not None
                and words_used + section_words > max_context_words):
            break
        sections.append(section)
        words_used += section_words
    context = "".join(sections)

    prompt = (
        "You are an expert in summarizing and answering questions based on given documents. "
        "Please provide a detailed and well-explained answer to the following query in 4-6 sentences:\n\n"
        f"Query: {message}\n\n"
        f"Based on the following documents:\n{context}\n\n"
        "Answer:"
    )

    input_ids = tokenizer(prompt, return_tensors="pt").input_ids

    outputs = model.generate(input_ids,
                             do_sample=True,
                             max_new_tokens=max_tokens,
                             temperature=temperature,
                             repetition_penalty=repetition_penalty)

    return tokenizer.decode(outputs[0], skip_special_tokens=True)


# Extra controls shown with the chat box. Their values reach respond() after
# (message, history), in this order: max new tokens, temperature,
# repetition penalty.
_generation_controls = [
    gr.Slider(minimum=1, maximum=2048, value=1024, step=1, label="Max new tokens"),
    gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
    gr.Slider(minimum=0.1, maximum=10, value=1.5, step=0.1, label="Repetition penalty"),
]

# Clickable sample questions, wrapped below as {"text": ...} example dicts.
_sample_questions = (
    "What types of roles are in the system?",
    "How to import records into stock receipts in Boost.space?",
    "Is it possible to create a PDF export from the product?",
)

demo = gr.ChatInterface(
    respond,
    additional_inputs=_generation_controls,
    examples=[{"text": question} for question in _sample_questions],
)

if __name__ == "__main__":
    demo.launch()