Chris4K committed on
Commit 9f36b00 · verified · 1 Parent(s): abaf9f1

Update app.py

Files changed (1): app.py +135 -158
app.py CHANGED
@@ -1,161 +1,138 @@
- from transformers import AutoTokenizer, AutoModelForCausalLM
- import torch
  import gradio as gr
-
- # Load Llama 3.2 model
- model_name = "meta-llama/Llama-3.2-3B-Instruct"  # Replace with the exact model path
- tokenizer = AutoTokenizer.from_pretrained(model_name)
- #model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16)
- model = AutoModelForCausalLM.from_pretrained(model_name, device_map=None, torch_dtype=torch.float32)
-
- # Helper function to process long contexts
- MAX_TOKENS = 100000  # Replace with the max token limit of the Llama model
-
-
- #########
- ###
- #########
- import faiss
- import torch
- import pandas as pd
- from sentence_transformers import SentenceTransformer
- from transformers import AutoTokenizer, AutoModelForCausalLM
- import gradio as gr
-
- # Load Llama model
- #model_name = "meta-llama/Llama-3.2-3B-Instruct"  # Replace with the exact model path
- #tokenizer = AutoTokenizer.from_pretrained(model_name)
- #model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16)
-
- # Load Sentence Transformer model for embeddings
- embedder = SentenceTransformer('distiluse-base-multilingual-cased')  # Suitable for German text
-
- ########
- ###
- ###
- #####
- # Load the CSV data
- url = 'https://www.bofrost.de/datafeed/DE/products.csv'
- data = pd.read_csv(url, sep='|')
-
- # List of columns to keep
- columns_to_keep = [
-     'ID', 'Name', 'Description', 'Price',
-     'ProductCategory', 'Grammage',
-     'BasePriceText', 'Rating', 'RatingCount',
-     'Ingredients', 'CreationDate', 'Keywords', 'Brand'
- ]
-
- # Filter the DataFrame
- data_cleaned = data[columns_to_keep]
-
- # Remove unwanted characters from the 'Description' column
- data_cleaned['Description'] = data_cleaned['Description'].str.replace(r'[^\w\s.,;:\'"/?!€$%&()\[\]{}<>|=+\\-]', ' ', regex=True)
-
- # Combine relevant text columns for embedding
- data_cleaned['combined_text'] = data_cleaned.apply(lambda row: ' '.join([str(row[col]) for col in ['Name', 'Description', 'Keywords'] if pd.notnull(row[col])]), axis=1)
-
- ######
- ##
- #####
-
- # Generate embeddings for the combined text
- embeddings = embedder.encode(data_cleaned['combined_text'].tolist(), convert_to_tensor=True)
-
- # Convert embeddings to numpy array
- embeddings = embeddings.cpu().detach().numpy()
-
- # Initialize FAISS index
- d = embeddings.shape[1]  # Dimension of embeddings
- faiss_index = faiss.IndexFlatL2(d)
-
- # Add embeddings to the index
- faiss_index.add(embeddings)
-
- #######
- ##
- ######
- def search_products(query, top_k=7):
-     # Generate embedding for the query
-     query_embedding = embedder.encode([query], convert_to_tensor=True).cpu().detach().numpy()
-
-     # Search FAISS index
-     distances, indices = faiss_index.search(query_embedding, top_k)
-
-     # Retrieve corresponding products
-     results = data_cleaned.iloc[indices[0]].to_dict(orient='records')
-     return results
-
-
-
- # Update the prompt construction to include ChromaDB results
- def construct_system_prompt(context):
-     prompt = f"You are a friendly bot specializing in Bofrost products. Return comprehensive german answers. Always add product ids. Use the following product descriptions:\n\n{context}\n\n"
-     return prompt
-
- # Helper function to construct the prompt
- def construct_prompt(user_input, context, chat_history, max_history_turns=1):  # Added max_history_turns
-     system_message = construct_system_prompt(context)
-     prompt = f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system_message}<|eot_id|>"
-
-     # Limit history to the last max_history_turns
-     for i, (user_msg, assistant_msg) in enumerate(chat_history[-max_history_turns:]):
-         prompt += f"<|start_header_id|>user<|end_header_id|>\n\n{user_msg}<|eot_id|>"
-         prompt += f"<|start_header_id|>assistant<|end_header_id|>\n\n{assistant_msg}<|eot_id|>"
-
-     prompt += f"<|start_header_id|>user<|end_header_id|>\n\n{user_input}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
-     print("-------------------------")
-     print(prompt)
-     return prompt
-
- def chat_with_model(user_input, chat_history=[]):
-     # Search for relevant products
-     search_results = search_products(user_input)
-
-     # Create context with search results
-     if search_results:
-         context = "Product Context:\n"
-         for product in search_results:
-             context += f"Produkt ID: {product['ID']}\n"
-             context += f"Name: {product['Name']}\n"
-             context += f"Beschreibung: {product['Description']}\n"
-             context += f"Preis: {product['Price']}€\n"
-             context += f"Bewertung: {product['Rating']} ({product['RatingCount']} Bewertungen)\n"
-             context += f"Kategorie: {product['ProductCategory']}\n"
-             context += f"Marke: {product['Brand']}\n"
-             context += "---\n"
-     else:
-         context = "Das weiß ich nicht."
-     print("context: ------------------------------------- \n" + context)
-     # Pass both user_input and context to construct_prompt
-     prompt = construct_prompt(user_input, context, chat_history)  # This line is changed
-     print("prompt: ------------------------------------- \n" + prompt)
-     input_ids = tokenizer.encode(prompt, return_tensors="pt", truncation=True, max_length=4096).to("cpu")
-     tokenizer.pad_token = tokenizer.eos_token
-     attention_mask = torch.ones_like(input_ids).to("cpu")
-     outputs = model.generate(input_ids, attention_mask=attention_mask,
-                              max_new_tokens=1200, do_sample=True,
-                              top_k=50, temperature=0.7)
-     response = tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True)
-     print("respone: ------------------------------------- \n" + response)
-     chat_history.append((context, response))  # or chat_history.append((user_input, response)) if you want to store user input
-     return response, chat_history
-
- #####
- ###
- ###
- # Gradio Interface
- def gradio_interface(user_input, history):
-     response, updated_history = chat_with_model(user_input, history)
-     return response, updated_history
-
- with gr.Blocks() as demo:
-     gr.Markdown("# 🦙 Llama Instruct Chat with ChromaDB Integration")
-     with gr.Row():
-         user_input = gr.Textbox(label="Your Message", lines=2, placeholder="Type your message here...")
-         submit_btn = gr.Button("Send")
-     chat_history = gr.State([])
-     chat_display = gr.Textbox(label="Chat Response", lines=10, placeholder="Chat history will appear here...", interactive=False)
-     submit_btn.click(gradio_interface, inputs=[user_input, chat_history], outputs=[chat_display, chat_history])
-
- demo.launch(debug=True)
1
+ # main.py
+ from fastapi import FastAPI, Depends, HTTPException, BackgroundTasks
+ from fastapi.middleware.cors import CORSMiddleware
  import gradio as gr
+ from services.chat_service import ChatService
+ from services.data_service import DataService
+ from services.faq_service import FAQService
+ from services.model_service import ModelService  # used below but missing from the commit; module path assumed to match the other services
+ from auth.auth_handler import get_api_key
+ from models.base_models import UserInput, SearchQuery
+ import logging
+ import asyncio
+
+ # Configure logging
+ logging.basicConfig(
+     level=logging.INFO,
+     format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
+     handlers=[
+         logging.FileHandler('chatbot.log'),
+         logging.StreamHandler()
+     ]
+ )
+ logger = logging.getLogger(__name__)
+
+ # Initialize FastAPI app
+ app = FastAPI(title="Bofrost Chat API", version="2.0.0")
+
+ # Add CORS middleware
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+ # Initialize services
+ model_service = ModelService()
+ data_service = DataService(model_service)
+ faq_service = FAQService(model_service)
+ chat_service = ChatService(model_service, data_service, faq_service)
+
+ # API endpoints
+ @app.post("/api/chat")
+ async def chat_endpoint(
+     user_input: UserInput,
+     background_tasks: BackgroundTasks,  # non-default parameter must precede the defaulted api_key
+     api_key: str = Depends(get_api_key)
+ ):
+     try:
+         response, updated_history, search_results = await chat_service.chat(
+             user_input.user_input,
+             user_input.chat_history
+         )
+         return {
+             "status": "success",
+             "response": response,
+             "chat_history": updated_history,
+             "search_results": search_results
+         }
+     except Exception as e:
+         logger.error(f"Error in chat endpoint: {e}")
+         raise HTTPException(status_code=500, detail=str(e))
+
+ @app.post("/api/search")
+ async def search_endpoint(
+     query: SearchQuery,
+     api_key: str = Depends(get_api_key)
+ ):
+     try:
+         results = await data_service.search(query.query, query.top_k)
+         return {"results": results}
+     except Exception as e:
+         logger.error(f"Error in search endpoint: {e}")
+         raise HTTPException(status_code=500, detail=str(e))
+
+ @app.post("/api/faq/search")
+ async def faq_search_endpoint(
+     query: SearchQuery,
+     api_key: str = Depends(get_api_key)
+ ):
+     try:
+         results = await faq_service.search_faqs(query.query, query.top_k)
+         return {"results": results}
+     except Exception as e:
+         logger.error(f"Error in FAQ search endpoint: {e}")
+         raise HTTPException(status_code=500, detail=str(e))
+
+ # Gradio interface
+ def create_gradio_interface():
+     with gr.Blocks(theme=gr.themes.Soft()) as demo:
+         gr.Markdown("# 🦙 Bofrost Chat Assistant\nFragen Sie nach Produkten, Rezepten und mehr!")
+
+         with gr.Row():
+             with gr.Column(scale=4):
+                 chat_display = gr.Chatbot(label="Chat-Verlauf", height=400)
+                 user_input = gr.Textbox(
+                     label="Ihre Nachricht",
+                     placeholder="Stellen Sie Ihre Frage...",
+                     lines=2
+                 )
+
+             with gr.Column(scale=2):
+                 with gr.Accordion("Zusätzliche Informationen", open=False):
+                     product_info = gr.JSON(label="Produktdetails")
+
+         with gr.Row():
+             submit_btn = gr.Button("Senden", variant="primary")
+             clear_btn = gr.Button("Chat löschen")
+
+         chat_history = gr.State([])
+
+         async def respond(message, history):
+             response, updated_history, search_results = await chat_service.chat(message, history)
+             # gr.Chatbot renders the full history, so return it for the display as well as the state
+             return updated_history, updated_history, search_results
+
+         submit_btn.click(
+             respond,
+             inputs=[user_input, chat_history],
+             outputs=[chat_display, chat_history, product_info]
+         )
+
+         clear_btn.click(
+             lambda: ([], [], None),
+             outputs=[chat_display, chat_history, product_info]
+         )
+
+     demo.queue()
+     return demo
+
+ if __name__ == "__main__":
+     import uvicorn
+
+     # Create and launch the Gradio interface without blocking, so the FastAPI server below still starts
+     demo = create_gradio_interface()
+     demo.launch(server_name="0.0.0.0", server_port=7860, prevent_thread_lock=True)
+
+     # Start FastAPI server
+     uvicorn.run(app, host="0.0.0.0", port=8000)
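
The new endpoints depend on `models/base_models.py` and `auth/auth_handler.py`, which are not part of this commit. For orientation, here is a minimal sketch of what they might contain, inferred only from how the fields and the dependency are used above; the field names `user_input`, `chat_history`, `query`, and `top_k` come from the diff, while everything else (the `X-API-Key` header name, the `API_KEY` environment variable, the defaults) is an assumption, not the repository's actual code:

# models/base_models.py — hypothetical sketch
from typing import List, Tuple
from pydantic import BaseModel

class UserInput(BaseModel):
    user_input: str
    chat_history: List[Tuple[str, str]] = []  # (user, assistant) pairs, matching the old chat format

class SearchQuery(BaseModel):
    query: str
    top_k: int = 7  # default mirrors the removed search_products(query, top_k=7)

# auth/auth_handler.py — hypothetical sketch of the API-key dependency
import os
from fastapi import HTTPException, Security
from fastapi.security.api_key import APIKeyHeader

api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)  # header name is an assumption

async def get_api_key(api_key: str = Security(api_key_header)) -> str:
    # Compare against a key from the environment; reject missing or wrong keys
    if api_key and api_key == os.environ.get("API_KEY"):
        return api_key
    raise HTTPException(status_code=403, detail="Invalid or missing API key")

Under those assumptions, a request to the chat endpoint would look like:

curl -X POST http://localhost:8000/api/chat \
  -H "X-API-Key: $API_KEY" -H "Content-Type: application/json" \
  -d '{"user_input": "Welche Pizzen gibt es?", "chat_history": []}'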
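
A note on the `__main__` block: as committed, `demo.launch()` blocks the main thread, so `uvicorn.run(...)` is never reached; the `prevent_thread_lock=True` fix above works around that by running Gradio's server in the background. Gradio also supports mounting the UI directly onto an existing FastAPI app via `gr.mount_gradio_app`, which serves both from a single port. A sketch, assuming the `app` and `create_gradio_interface` defined above:

# Alternative single-process setup (sketch): serve the Gradio UI from the FastAPI app itself
import gradio as gr
import uvicorn

demo = create_gradio_interface()
app = gr.mount_gradio_app(app, demo, path="/ui")  # UI at /ui; the /api/* endpoints are unchanged
uvicorn.run(app, host="0.0.0.0", port=8000)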