import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from sentence_transformers import SentenceTransformer
import chromadb
import gradio as gr
# Determine the device (CPU or GPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Initialize the tokenizer and model
tokenizer = AutoTokenizer.from_pretrained("./SMOL")
model = AutoModelForCausalLM.from_pretrained("./SMOL").to(device)
# Initialize the sentence transformer model
smodel = SentenceTransformer('./embed')
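# "./SMOL" and "./embed" are local directories, presumably bundled with this
# Space: a small causal LM checkpoint and a sentence-embedding model.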
# Initialize the chromadb client and collection
client = chromadb.PersistentClient(path="vectordb")
collection = client.get_or_create_collection("dogdb")
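# The "dogdb" collection lives in a persistent on-disk store, presumably
# built ahead of time, so no re-indexing is needed on startup.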
def clean_text_block(text):
    # Pull the retrieved documents out of the stringified chromadb result by
    # slicing between the "'documents': [[" and "]], 'uris':" markers.
    start_keyword = "'documents': [["
    end_keyword = "]], 'uris':"
    start_index = text.find(start_keyword)
    end_index = text.find(end_keyword)
    if start_index != -1 and end_index != -1:
        return text[start_index + len(start_keyword):end_index]
    else:
        return "Keywords not found in the text."
def remove_unwanted_parts(text):
    # Strip the echoed chat template (from the system marker through the
    # assistant marker) so only the model's answer remains.
    start_keyword = "system"
    end_keyword = """Respond in a friendly manner; you are an informational assistant about dogs.
assistant"""
    start_idx = text.find(start_keyword)
    end_idx = text.find(end_keyword)
    if start_idx != -1 and end_idx != -1:
        cleaned_text = text[:start_idx] + text[end_idx + len(end_keyword):]
        return cleaned_text.strip()
    else:
        return text
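# RAG pipeline: embed the question, retrieve the 3 nearest documents from
# chromadb, wrap them in a chat prompt, and generate a grounded answer.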
def generate_response(question):
    # sentence-transformers accepts dict inputs keyed by field name, so the
    # question is passed through as {'question': ...} before embedding.
    query = [{'question': f"{question}?"}]
    query_embeddings = smodel.encode(query)
    results = collection.query(
        query_embeddings=query_embeddings,
        n_results=3  # how many results to return
    )
    results = clean_text_block(str(results))
    messages = [{"role": "user", "content": f"""After the colon is a set of text with information about dogs, followed by a question about that text. Please answer the question based on the text, and do not talk about the documentation:
text - {results}
question - {question}
Respond in a friendly manner; you are an informational assistant about dogs."""}]
    input_text = tokenizer.apply_chat_template(messages, tokenize=False)
    encoded_inputs = tokenizer.encode_plus(input_text, return_tensors="pt", add_special_tokens=True)
    inputs, attention_mask = encoded_inputs["input_ids"].to(device), encoded_inputs["attention_mask"].to(device)
    # Low temperature / top_p keeps the answer close to the retrieved text.
    outputs = model.generate(inputs, attention_mask=attention_mask, max_new_tokens=150, temperature=0.4, top_p=0.6, do_sample=True)
    output_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    cleaned_output = remove_unwanted_parts(output_text)
    return cleaned_output
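# Example call (hypothetical question):
#   generate_response("How big do Beagles get?")
# -> a short, friendly answer grounded in the retrieved breed documents.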
# Create Gradio interface
iface = gr.Interface(fn=generate_response, inputs="text", outputs="text", title="Dog Breed Q&A", description="Ask a question about any of over 70 dog breeds! The full list is in this Space's files under 'Dog_List'. All done on a CPU!")
# Launch the interface
iface.launch()