Mubbashir Ahmed committed
Commit 697c8ae
Parent(s): e40bcd6

updates on app for evaluations

Files changed (1):
app.py +90 -129
app.py CHANGED
@@ -1,152 +1,113 @@
  import os
  import gradio as gr
  from huggingface_hub import InferenceClient
- from sqlalchemy import create_engine
-
- # Clients for each model provider
- llama_client = InferenceClient(provider="sambanova", api_key=os.environ["HF_TOKEN"])
- minimax_client = InferenceClient(provider="novita", api_key=os.environ["HF_TOKEN"])
- mistral_client = InferenceClient(provider="together", api_key=os.environ["HF_TOKEN"])
-
- # Global objects
- db_connection = None
-
- def get_sqlalchemy_connection():
-     server = os.getenv("SQL_SERVER")
-     database = os.getenv("SQL_DATABASE")
-     username = os.getenv("SQL_USERNAME")
-     password = os.getenv("SQL_PASSWORD")
-
-     connection_url = f"mssql+pymssql://{username}:{password}@{server}/{database}"

      try:
-         engine = create_engine(connection_url)
-         conn = engine.connect()
-         print("✅ SQLAlchemy + pymssql connection successful")
-         return conn
-     except Exception as e:
-         print(f"❌ SQLAlchemy connection failed: {e}")
-         return None
-
- def get_sql_connection():
-     global db_connection
-
-     if db_connection is not None:
-         try:
-             db_connection.cursor()  # test if still open
-             return db_connection
-         except Exception as e:
-             print(f"❌ SQL connection failed: {e}")
-             db_connection = None  # reset if broken
-
-     # Reconnect if needed
-     db_connection = get_sqlalchemy_connection()
-     return db_connection
-
- # Format chat history for Markdown display
- def format_chat_history(chat_history):
-     formatted = ""
-     for msg in chat_history:
-         role = msg["role"]
-         content = msg["content"]
-         if isinstance(content, list):  # For LLaMA image+text input
-             for item in content:
-                 if "text" in item:
-                     formatted += f"**{role.capitalize()}:** {item['text']}\n\n"
-                 elif "image_url" in item:
-                     formatted += f"**{role.capitalize()}:** 🖼️ Image: {item['image_url']['url']}\n\n"
-         else:
-             formatted += f"**{role.capitalize()}:** {content}\n\n"
-     return formatted.strip()
-
- # Main chat handler
- def chat_with_model(model_choice, prompt, image_url, chat_history):
-     if not prompt:
-         return "❌ Please enter a text prompt.", chat_history, "", ""
-
-     if chat_history is None:
-         chat_history = []
-
-     conn = get_sql_connection()
-     if conn is None:
-         return "❌ Failed to connect to database.", chat_history, "", ""
-
-     try:
-         # === LLaMA 4 ===
-         if model_choice == "LLaMA 4 (SambaNova)":
-             user_msg = [{"type": "text", "text": prompt}]
-             if image_url:
-                 user_msg.append({"type": "image_url", "image_url": {"url": image_url}})
-             chat_history.append({"role": "user", "content": user_msg})
-
-             response = llama_client.chat.completions.create(
                  model="meta-llama/Llama-4-Maverick-17B-128E-Instruct",
-                 messages=chat_history
              )
-             bot_msg = response.choices[0].message.content
-             chat_history.append({"role": "assistant", "content": bot_msg})
-
-         # === MiniMax ===
-         elif model_choice == "MiniMax M1 (Novita)":
-             chat_history.append({"role": "user", "content": prompt})
-             response = minimax_client.chat.completions.create(
-                 model="MiniMaxAI/MiniMax-M1-80k",
-                 messages=chat_history
              )
-             bot_msg = response.choices[0].message.content
-             chat_history.append({"role": "assistant", "content": bot_msg})
-
-         # === Mistral ===
-         elif model_choice == "Mistral Mixtral-8x7B (Together)":
-             chat_history.append({"role": "user", "content": prompt})
-             response = mistral_client.chat.completions.create(
-                 model="mistralai/Mixtral-8x7B-Instruct-v0.1",
-                 messages=chat_history
              )
-             bot_msg = response.choices[0].message.content
-             chat_history.append({"role": "assistant", "content": bot_msg})

          else:
-             return "❌ Unsupported model selected.", chat_history, "", ""
-
-         return format_chat_history(chat_history), chat_history, "", ""

      except Exception as e:
-         return f"❌ Error: {e}", chat_history, "", ""

- # Gradio interface
  with gr.Blocks() as demo:
-     gr.Markdown("## 🤖 Multi-Model Context-Aware Chatbot")
-     gr.Markdown("Supports LLaMA 4 (with optional image), MiniMax, and Mistral. Memory is preserved for multi-turn dialog.")
-
-     model_dropdown = gr.Dropdown(
-         choices=[
-             "LLaMA 4 (SambaNova)",
-             "MiniMax M1 (Novita)",
-             "Mistral Mixtral-8x7B (Together)"
-         ],
-         value="LLaMA 4 (SambaNova)",
-         label="Select Model"
      )

-     prompt_input = gr.Textbox(label="Text Prompt", placeholder="Ask something...", lines=2)
-     image_url_input = gr.Textbox(label="Optional Image URL (for LLaMA only)", placeholder="https://example.com/image.jpg")

-     submit_btn = gr.Button("💬 Generate Response")
-     reset_btn = gr.Button("🔄 Reset Conversation")
-     output_box = gr.Markdown(label="Chat History", value="")
-     state = gr.State([])

-     submit_btn.click(
-         fn=chat_with_model,
-         inputs=[model_dropdown, prompt_input, image_url_input, state],
-         outputs=[output_box, state, prompt_input, image_url_input]
-     )

-     reset_btn.click(
-         fn=lambda: ("🧹 Conversation reset. You can start a new one.", [], "", ""),
-         inputs=[],
-         outputs=[output_box, state, prompt_input, image_url_input]
      )

  demo.launch()
 
  import os
  import gradio as gr
  from huggingface_hub import InferenceClient
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ import torch
+
+ HF_TOKEN = os.environ.get("HF_TOKEN")
+
+ # ------------------------
+ # API Clients
+ # ------------------------
+ llama_client = InferenceClient(
+     provider="fireworks-ai",
+     api_key=HF_TOKEN,
+ )
+
+ qwen_client = InferenceClient(
+     provider="featherless-ai",
+     api_key=HF_TOKEN,
+ )
+
+ # ------------------------
+ # Mixtral Local Setup
+ # ------------------------
+ mixtral_model_id = "mistralai/Mixtral-8x7B-Instruct-v0.1"
+ mixtral_tokenizer = AutoTokenizer.from_pretrained(mixtral_model_id)
+ mixtral_model = AutoModelForCausalLM.from_pretrained(
+     mixtral_model_id, torch_dtype=torch.float16
+ ).to("cuda")
+
+ # ------------------------
+ # Unified Inference Function with Chat History
+ # ------------------------
+ def run_model_with_history(model_name, user_input, chat_history):
+     messages = chat_history + [{"role": "user", "content": user_input}]

      try:
+         if model_name == "LLaMA 4":
+             result = llama_client.chat.completions.create(
                  model="meta-llama/Llama-4-Maverick-17B-128E-Instruct",
+                 messages=messages
              )
+             reply = result.choices[0].message.content
+
+         elif model_name == "Qwen3 14B":
+             result = qwen_client.chat.completions.create(
+                 model="Qwen/Qwen3-14B",
+                 messages=messages
              )
+             reply = result.choices[0].message.content
+
+         elif model_name == "Mixtral 8x7B":
+             full_prompt = ""
+             for msg in messages:
+                 prefix = "User: " if msg["role"] == "user" else "Assistant: "
+                 full_prompt += f"{prefix}{msg['content']}\n"
+             inputs = mixtral_tokenizer(full_prompt, return_tensors="pt").to("cuda")
+             outputs = mixtral_model.generate(
+                 **inputs,
+                 max_new_tokens=512,
+                 do_sample=True,
+                 temperature=0.7,
+                 top_k=50,
+                 top_p=0.95
              )
+             # Decode only the newly generated tokens; generate() returns the
+             # prompt tokens too, which would otherwise be echoed into the reply.
+             new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
+             reply = mixtral_tokenizer.decode(new_tokens, skip_special_tokens=True)

          else:
+             reply = "❌ Invalid model selection."

      except Exception as e:
+         reply = f"⚠️ Error: {str(e)}"
+
+     # Update chat history
+     chat_history.append({"role": "user", "content": user_input})
+     chat_history.append({"role": "assistant", "content": reply})

+     # Format display
+     chat_transcript = "\n".join([
+         f"👤 User: {msg['content']}" if msg["role"] == "user" else f"🤖 Assistant: {msg['content']}"
+         for msg in chat_history
+     ])
+
+     return chat_transcript, chat_history
+
+ # ------------------------
+ # Gradio UI
+ # ------------------------
  with gr.Blocks() as demo:
+     gr.Markdown("## 🧠 Generative AI Model Evaluation with Context")
+
+     model_choice = gr.Dropdown(
+         choices=["LLaMA 4", "Qwen3 14B", "Mixtral 8x7B"],
+         label="Select Model",
+         value="LLaMA 4"
      )

+     chat_display = gr.Textbox(label="Chat History", lines=20, interactive=False)
+     prompt_input = gr.Textbox(label="Your Prompt", lines=3, placeholder="Ask your BI question...")

+     run_button = gr.Button("Send")

+     # Hidden chat history state
+     chat_memory = gr.State([])

+     run_button.click(
+         fn=run_model_with_history,
+         inputs=[model_choice, prompt_input, chat_memory],
+         outputs=[chat_display, chat_memory]
      )

+ # Launch app
  demo.launch()
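
The key contract introduced by this commit is that run_model_with_history receives the prior history list from gr.State and returns the updated list together with the rendered transcript, which Gradio feeds back in on the next click. Below is a minimal stand-alone sketch of that round-trip (not part of the commit; run_echo_with_history and its echo reply are hypothetical stand-ins for the real model calls):

# Sketch of the (transcript, updated_history) contract, with the model stubbed out.
def run_echo_with_history(user_input, chat_history):
    messages = chat_history + [{"role": "user", "content": user_input}]
    reply = f"(echo) {messages[-1]['content']}"  # hypothetical stand-in for a completion

    # Same history update as the committed function.
    chat_history.append({"role": "user", "content": user_input})
    chat_history.append({"role": "assistant", "content": reply})

    chat_transcript = "\n".join(
        f"👤 User: {m['content']}" if m["role"] == "user" else f"🤖 Assistant: {m['content']}"
        for m in chat_history
    )
    return chat_transcript, chat_history

history = []
for turn in ["hello", "and a follow-up?"]:
    transcript, history = run_echo_with_history(turn, history)
print(transcript)  # two user turns and two echoed replies, in order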