|
import torch |
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
from sentence_transformers import SentenceTransformer |
|
import chromadb |
|
import gradio as gr |
|
|
|
|
|
# Prefer the GPU when PyTorch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Causal language model and its tokenizer, both read from the "./SMOL" path.
tokenizer = AutoTokenizer.from_pretrained("./SMOL")
model = AutoModelForCausalLM.from_pretrained("./SMOL")
model = model.to(device)

# Sentence-embedding model used to embed incoming questions for retrieval.
smodel = SentenceTransformer("./embed")

# Persistent Chroma store under "vectordb"; the "dogdb" collection is created
# on first use and reused afterwards.
client = chromadb.PersistentClient(path="vectordb")
collection = client.get_or_create_collection(name="dogdb")
|
|
|
def clean_text_block(text):
    """Extract the document list from a stringified query result.

    Looks for the span between ``'documents': [[`` and the next
    ``]], 'uris':`` that follows it, and returns the text in between
    (without either marker).

    Args:
        text: String form of a query-result mapping.

    Returns:
        The text between the two markers, or the literal message
        ``"Keywords not found in the text."`` when either marker is missing.
    """
    start_keyword = "'documents': [["
    end_keyword = "]], 'uris':"
    not_found = "Keywords not found in the text."

    start_index = text.find(start_keyword)
    if start_index == -1:
        return not_found
    content_start = start_index + len(start_keyword)

    # Check the raw find() result *before* doing any arithmetic on it: adding
    # the keyword length first would turn the -1 "missing" sentinel into a
    # valid-looking offset. Searching from content_start also guarantees the
    # end marker belongs to this documents list, not to earlier text.
    end_index = text.find(end_keyword, content_start)
    if end_index == -1:
        return not_found

    return text[content_start:end_index]
|
|
|
def remove_unwanted_parts(text):
    """Strip the echoed prompt from decoded model output.

    Removes everything from the first occurrence of ``"system"`` through the
    end of the prompt's closing line followed by the ``assistant`` header,
    keeping whatever precedes and follows that span.

    Args:
        text: Decoded model output that may still contain the prompt.

    Returns:
        The text with the prompt span removed and surrounding whitespace
        stripped, or ``text`` unchanged when either marker is absent.
    """
    start_keyword = "system"
    # Last line of the user prompt, a newline, then the assistant header.
    end_keyword = (
        "Respond in a friendly manner; you are an informational about dogs.\n"
        "assistant"
    )

    start_idx = text.find(start_keyword)
    if start_idx == -1:
        return text

    # Test the raw find() result before adding the keyword length; otherwise a
    # missing marker (-1) becomes a positive offset and the slice below would
    # cut the text at an arbitrary position.
    end_pos = text.find(end_keyword, start_idx)
    if end_pos == -1:
        return text

    end_idx = end_pos + len(end_keyword)
    return (text[:start_idx] + text[end_idx:]).strip()
|
|
|
def generate_response(question):
    """Answer a question about dogs using retrieved passages and the local LLM.

    The question is embedded with ``smodel``, the three closest entries are
    fetched from ``collection``, and the passages plus the question are sent
    to ``model`` as a single user message.

    Args:
        question: The user's question as plain text.

    Returns:
        The generated answer text with surrounding whitespace stripped.
    """
    # Embed the question and fetch the three nearest stored passages.
    query = [{'question': f"{question}?"}]
    query_embeddings = smodel.encode(query)
    results = collection.query(
        query_embeddings=query_embeddings,
        n_results=3
    )

    # Read the passages from the structured result rather than scraping
    # str(results): the repr's key order is not a stable contract. Only one
    # query was issued, so the first (and only) inner list is the one we want.
    documents = (results.get("documents") or [[]])[0]
    # Quoted, comma-separated rendering, i.e. the repr of the list without
    # its brackets, which is the layout the prompt has always carried.
    context = ", ".join(repr(doc) for doc in documents)

    prompt = (
        "After the colon is a set of text with information about dogs, then a "
        "question about the given text. Please answer the question based off "
        "the text, and do not talk about the documentation:\n"
        f"text - {context}\n"
        f"question - {question}\n"
        "Respond in a friendly manner; you are an informational about dogs."
    )
    messages = [{"role": "user", "content": prompt}]

    # add_generation_prompt=True appends the assistant header so the model
    # starts answering instead of possibly continuing the user turn.
    input_text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    # The template already inserted the special tokens; do not add them twice.
    encoded_inputs = tokenizer(
        input_text, return_tensors="pt", add_special_tokens=False
    )
    input_ids = encoded_inputs["input_ids"].to(device)
    attention_mask = encoded_inputs["attention_mask"].to(device)

    outputs = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_new_tokens=150,
        temperature=0.4,
        top_p=0.6,
        do_sample=True,
    )

    # generate() returns prompt + completion for a causal LM. Decoding only
    # the tokens past the prompt length yields the answer directly, with no
    # need to strip the echoed prompt by string matching afterwards.
    new_tokens = outputs[0][input_ids.shape[-1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
|
|
|
|
|
# Gradio front end: one text box in, one text box out, backed by
# generate_response.
iface = gr.Interface(
    fn=generate_response,
    inputs="text",
    outputs="text",
    title="Dog Breed Q&A",
    description=(
        "Ask a question about your dog's breed! "
        "From over 70 different breeds. "
        "You can find the full list under this space's files 'Dog_List'. "
        "All done on a CPU!"
    ),
)

# Start the web server for the interface.
iface.launch()
|
|