import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from sentence_transformers import SentenceTransformer
import chromadb
import gradio as gr

# Determine the device (CPU or GPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize the tokenizer and model
tokenizer = AutoTokenizer.from_pretrained("./SMOL")
model = AutoModelForCausalLM.from_pretrained("./SMOL").to(device)

# Initialize the sentence transformer model
smodel = SentenceTransformer('./embed')

# Initialize the chromadb client and collection
client = chromadb.PersistentClient(path="vectordb")
collection = client.get_or_create_collection("dogdb")

def clean_text_block(text):
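    # Pull the retrieved document text out of a stringified ChromaDB query result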
    start_keyword = "'documents': [["
    end_keyword = "]], 'uris':"

    start_index = text.find(start_keyword)
    end_index = text.find(end_keyword)

    if start_index != -1 and end_index != -1:
        cleaned_text = text[start_index + len(start_keyword):end_index]
        return cleaned_text
    else:
        return "Keywords not found in the text."

def remove_unwanted_parts(text):
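    # Strip the echoed prompt (system/user turns) from the decoded output, keeping only the model's answer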
    start_keyword = "system"
    end_keyword = """Respond in a friendly manner; you are an informational about dogs.
assistant"""
    start_idx = text.find(start_keyword)
    end_idx = text.find(end_keyword)

    if start_idx != -1 and end_idx != -1:
        cleaned_text = text[:start_idx] + text[end_idx + len(end_keyword):]
        return cleaned_text.strip()
    else:
        return text

def generate_response(question):
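    # Embed the question, retrieve the most similar dog documents from ChromaDB, and answer from them with the language model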
    query = [{'question': f"{question}?"}]
    query_embeddings = smodel.encode(query)
    results = collection.query(
        query_embeddings=query_embeddings,
        n_results=3  # how many results to return
    )
    
    # Extract just the document text from the stringified query result
    results = clean_text_block(str(results))
    
    messages = [{"role": "user", "content": f"""After the colon is a set of text with information about dogs, then a question about the given text. Please answer the question based off the text, and do not talk about the documentation:
    text - {results}
    question - {question}
    Respond in a friendly manner; you are an informational about dogs."""}]
    
    # Render the chat template, tokenize it, and move the inputs to the selected device
    input_text = tokenizer.apply_chat_template(messages, tokenize=False)
    encoded_inputs = tokenizer.encode_plus(input_text, return_tensors="pt", add_special_tokens=True)
    inputs, attention_mask = encoded_inputs["input_ids"].to(device), encoded_inputs["attention_mask"].to(device)
    
    outputs = model.generate(inputs, attention_mask=attention_mask, max_new_tokens=150, temperature=0.4, top_p=0.6, do_sample=True)
    output_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    
    cleaned_output = remove_unwanted_parts(output_text)
    return cleaned_output

# Create Gradio interface
iface = gr.Interface(fn=generate_response, inputs="text", outputs="text", title="Dog Breed Q&A", description="Ask a question about your dog's breed! Over 70 different breeds are covered; you can find the full list in this Space's files under 'Dog_List'. All running on a CPU!")

# Launch the interface
iface.launch()