mariiapaik committed on
Commit 0e62273 · verified · 1 Parent(s): 1609564

Create app.py

Files changed (1)
  1. app.py +58 -0
app.py ADDED
@@ -0,0 +1,58 @@
+ import gradio as gr
+ import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ from sentence_transformers import SentenceTransformer
+ import faiss
+ import os
+
+ # 📌 1. Load LLaMA 3
+ MODEL_NAME = "meta-llama/Meta-Llama-3-8B-Instruct"
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+ model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.float16, device_map="auto")
+
+ # 📌 2. Load a Sentence Transformer for embeddings
+ embedder = SentenceTransformer("all-MiniLM-L6-v2")
+
+ # 📌 3. Load the local knowledge base
+ def load_documents():
+     knowledge_base = []
+     for file_name in os.listdir("files"):
+         file_path = os.path.join("files", file_name)
+         with open(file_path, "r", encoding="utf-8") as file:
+             text = file.read()
+         knowledge_base.append(text)
+     return knowledge_base
+
+ documents = load_documents()
+ document_embeddings = embedder.encode(documents, convert_to_tensor=True)
+
+ # 📌 4. Build a FAISS index (exact L2 search over the document embeddings)
+ index = faiss.IndexFlatL2(document_embeddings.shape[1])
+ index.add(document_embeddings.cpu().numpy())
+
+ # 📌 5. Retrieve the documents most relevant to a query
+ def retrieve_relevant_info(query, top_k=2):
+     query_embedding = embedder.encode([query], convert_to_tensor=True)
+     query_embedding = query_embedding.cpu().numpy()
+     distances, indices = index.search(query_embedding, top_k)
+     retrieved_docs = [documents[idx] for idx in indices[0]]
+     return " ".join(retrieved_docs)
+
+ # 📌 6. Generate an answer grounded in the retrieved context
+ def generate_response(query):
+     relevant_info = retrieve_relevant_info(query)
+     input_text = f"Context: {relevant_info}\nQuestion: {query}\nAnswer:"
+     inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
+     outputs = model.generate(**inputs, max_new_tokens=200)
+     return tokenizer.decode(outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)
+
+ # 📌 7. Gradio interface
+ interface = gr.Interface(
+     fn=generate_response,
+     inputs=gr.Textbox(lines=2, placeholder="Enter your question..."),
+     outputs=gr.Textbox(),
+     title="RAG with LLaMA 3",
+     description="This chatbot uses RAG (Retrieval-Augmented Generation) with LLaMA 3 and your documents."
+ )
+
+ interface.launch()
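
A note on the retrieval side for anyone adapting this file: all-MiniLM-L6-v2 truncates its input at 256 word-piece tokens, so encoding each file in files/ as a single vector effectively indexes only the beginning of a long document. A common refinement is to split each file into overlapping chunks and index the chunks instead. A minimal sketch, reusing load_documents and embedder from app.py; chunk_text is a hypothetical helper, not part of this commit:

# Hypothetical helper (not in the commit): fixed-size character chunks with overlap.
def chunk_text(text, chunk_size=500, overlap=100):
    chunks = []
    start = 0
    while start < len(text):
        chunks.append(text[start:start + chunk_size])
        start += chunk_size - overlap  # step forward, keeping 100 chars of overlap
    return chunks

# Index chunks instead of whole files, so long documents stay searchable past
# the embedder's truncation limit.
documents = [chunk for doc in load_documents() for chunk in chunk_text(doc)]
document_embeddings = embedder.encode(documents, convert_to_tensor=True)

The rest of the pipeline (the FAISS index and retrieve_relevant_info) is unchanged; retrieval then returns the top-k chunks rather than whole files.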
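
On the generation side: Meta-Llama-3-8B-Instruct is an instruction-tuned chat model (and a gated repository, so the Space needs an access token for a Hugging Face account that has accepted the license). The plain Context/Question/Answer prompt above works, but wrapping the query in the model's chat format via the tokenizer's apply_chat_template usually yields cleaner answers. A minimal sketch under the same globals as app.py; generate_response_chat is an illustrative variant, not part of the commit:

# Sketch: generate_response rewritten to use Llama 3's chat template.
def generate_response_chat(query):
    relevant_info = retrieve_relevant_info(query)
    messages = [
        {"role": "system", "content": "Answer the question using only the provided context."},
        {"role": "user", "content": f"Context: {relevant_info}\n\nQuestion: {query}"},
    ]
    # apply_chat_template adds Llama 3's special tokens around each turn.
    input_ids = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)
    outputs = model.generate(input_ids, max_new_tokens=200, pad_token_id=tokenizer.eos_token_id)
    # Decode only the newly generated tokens, not the echoed prompt.
    return tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True)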