Poonawala committed
Commit e87d259 · verified · 1 Parent(s): f434761

Update app.py

Files changed (1): app.py +160 -151

app.py CHANGED
@@ -1,154 +1,163 @@
- # Configure page settings (MUST BE FIRST STREAMLIT COMMAND)
- import streamlit as st
- from streamlit_option_menu import option_menu
- from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
- from PyPDF2 import PdfReader
-
- # Set page config
- st.set_page_config(
-     page_title="Disease Analysis GPT",
-     layout="wide",
-     initial_sidebar_state="expanded"
- )
-
- # Load Hugging Face models and tokenizer for text generation
- @st.cache_resource
- def load_model():
-     tokenizer = AutoTokenizer.from_pretrained("harishussain12/Disease_Managment")
-     model = AutoModelForCausalLM.from_pretrained("harishussain12/Disease_Managment")
-     return tokenizer, model
-
- # Function to create a text generation pipeline
- @st.cache_resource
- def create_pipeline():
-     tokenizer, model = load_model()
-     return pipeline("text-generation", model=model, tokenizer=tokenizer)
-
- # Function to extract text from PDF file
- def read_pdf(file):
      try:
-         reader = PdfReader(file)
-         text = ""
-         for page in reader.pages:
-             text += page.extract_text()
-         return text
      except Exception as e:
-         return f"Error reading PDF: {e}"
-
- # Load pipelines
- text_pipeline = create_pipeline()
-
- # Custom CSS for styling
- st.markdown(
-     """
-     <style>
-     body {
-         font-family: 'Arial', sans-serif;
-     }
-     .stButton button {
-         background-color: #0b2545;
-         color: white;
-         border: none;
-         border-radius: 25px;
-         padding: 8px 20px;
-         font-size: 14px;
-         font-weight: bold;
-         cursor: pointer;
-     }
-     .stButton button:hover {
-         background-color: #0a1b35;
      }
-     .search-box {
-         border-radius: 20px;
-         border: 1px solid #ccc;
-         padding: 10px;
-         width: 100%;
-         font-size: 16px;
-         background-color: #ffffff;
-     }
-     .info-box {
-         background-color: #f8f9fa;
-         border-left: 5px solid #0b2545;
-         padding: 15px;
-         border-radius: 5px;
-         font-size: 14px;
-     }
-     </style>
-     """,
-     unsafe_allow_html=True
- )
-
- # Sidebar
- with st.sidebar:
-     new_chat_button = st.button("New Chat", key="new_chat", help="Start a new chat to ask a different question.")
-     if new_chat_button:
-         st.session_state.clear()  # Clear session state to simulate a new chat
-
-     selected = option_menu(
-         menu_title=None,
-         options=[" Home", " Discover"],
-         icons=["house", "search"],
-         menu_icon="cast",
-         default_index=0,
-         styles={
-             "container": {"padding": "0!important", "background-color": "#3e4a5b"},
-             "icon": {"color": "#ffffff", "font-size": "16px"},
-             "nav-link": {
-                 "font-size": "15px",
-                 "text-align": "left",
-                 "margin": "0px",
-                 "color": "#ffffff",
-                 "font-weight": "bold",
-                 "padding": "10px 20px",
-             },
-             "nav-link-selected": {"background-color": "#0b2545", "color": "white"},
-         }
-     )
-
- # Main content
- col1, col2, col3 = st.columns([1, 2, 1])
-
- with col2:
-     st.markdown("<h1 style='text-align: center;'>Disease Analysis GPT</h1>", unsafe_allow_html=True)
-     st.markdown("<h3 style='text-align: center;'>What do you want to know?</h3>", unsafe_allow_html=True)
-
-     # Model selection (now including Document Analysis)
-     model_selection = st.selectbox(
-         "Select a model",
-         options=["Disease Analysis", "Document Analysis"],
-         index=0
-     )
-
-     # If the user selects Document Analysis, show an error and prompt them to switch to Disease Analysis
-     if model_selection == "Document Analysis":
-         st.error("Please switch to 'Disease Analysis' model for generating responses. Document Analysis is not available in this version.")
-
-     # Search box
-     search_input = st.text_input(
-         "",
-         placeholder="Type your question here...",
-         label_visibility="collapsed",
-         help="Ask anything related to disease management."
-     )
-
-     # File upload below search box
-     uploaded_file = st.file_uploader("Upload a PDF file", type="pdf", help="Attach relevant files or documents to your query.")
-
-     if search_input:
-         with st.spinner("Generating response..."):
-             try:
-                 if model_selection == "Disease Analysis":
-                     context = ""
-                     if uploaded_file is not None:
-                         file_content = read_pdf(uploaded_file)
-                         if "Error" in file_content:
-                             st.error(file_content)
-                         else:
-                             context = file_content
-
-                     query_input = search_input + (f"\n\nContext:\n{context}" if context else "")
-                     response = text_pipeline(query_input, max_length=200, num_return_sequences=1)
-                     st.markdown(f"### Response:\n{response[0]['generated_text']}")
-
-             except Exception as e:
-                 st.error(f"Error generating response: {str(e)}")
+ import gradio as gr
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ import torch
+ from accelerate import Accelerator
+
+ # Check if GPU is available for better performance
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ print(f"Using device: {device}")
+
+ # Initialize the Accelerator for optimized inference
+ accelerator = Accelerator()
+
+ # Load models and tokenizers with FP16 for speed optimization if GPU is available
+ model_dirs = [
+     "Poonawala/gpt2",
+     "Poonawala/MiriFur",
+     "Poonawala/Llama-3.2-1B"
+ ]
+
+ models = {}
+ tokenizers = {}
+
+ def load_model(model_dir):
+     model = AutoModelForCausalLM.from_pretrained(model_dir, torch_dtype=torch.float16 if device.type == "cuda" else torch.float32)
+     tokenizer = AutoTokenizer.from_pretrained(model_dir)
+
+     if tokenizer.pad_token is None:
+         tokenizer.pad_token = tokenizer.eos_token
+
+     # Move model to GPU/CPU as per availability
+     model = model.to(device)
+     return model, tokenizer
+
+ # Load all models
+ for model_dir in model_dirs:
+     model_name = model_dir.split("/")[-1]
      try:
+         model, tokenizer = load_model(model_dir)
+         models[model_name] = model
+         tokenizers[model_name] = tokenizer
+
+         # Batch warm-up inference to reduce initial response time
+         dummy_inputs = ["Hello", "What is a recipe?", "Explain cooking basics"]
+         for dummy_input in dummy_inputs:
+             input_ids = tokenizer.encode(dummy_input, return_tensors='pt').to(device)
+             with torch.no_grad():
+                 model.generate(input_ids, max_new_tokens=1)
+
+         print(f"Loaded model and tokenizer from {model_dir}.")
      except Exception as e:
+         print(f"Failed to load model from {model_dir}: {e}")
+         continue
+
+ def get_response(prompt, model_name, user_type):
+     if model_name not in models:
+         return "Model not loaded correctly."
+
+     model = models[model_name]
+     tokenizer = tokenizers[model_name]
+
+     # Define different prompt templates based on user type
+     user_type_templates = {
+         "Expert": f"As an Expert, {prompt}\nAnswer:",
+         "Intermediate": f"As an Intermediate, {prompt}\nAnswer:",
+         "Beginner": f"Explain in simple terms: {prompt}\nAnswer:",
+         "Professional": f"As a Professional, {prompt}\nAnswer:"
      }
+
+     # Get the appropriate prompt based on user type
+     prompt_template = user_type_templates.get(user_type, f"{prompt}\nAnswer:")
+
+     encoding = tokenizer(
+         prompt_template,
+         return_tensors='pt',
+         padding=True,
+         truncation=True,
+         max_length=500  # Increased length for larger inputs
+     ).to(device)
+
+     max_new_tokens = 200  # Increased to allow full-length answers
+
+     with torch.no_grad():
+         output = model.generate(
+             input_ids=encoding['input_ids'],
+             attention_mask=encoding['attention_mask'],
+             max_new_tokens=max_new_tokens,  # Higher value for longer answers
+             num_beams=3,  # Using beam search for better quality answers
+             repetition_penalty=1.2,  # Increased to reduce repetitive text
+             temperature=0.9,  # Slightly higher for creative outputs
+             top_p=0.9,  # Including more tokens for diverse generation
+             early_stopping=True,
+             pad_token_id=tokenizer.pad_token_id
+         )
+
+     response = tokenizer.decode(output[0], skip_special_tokens=True)
+     return response.strip()
+
+ def process_input(prompt, model_name, user_type):
+     if prompt and prompt.strip():
+         return get_response(prompt, model_name, user_type)
+     else:
+         return "Please provide a prompt."
+
+ # Gradio Interface with Modern Design
+ with gr.Blocks(css="""
+ body {
+     background-color: #faf3e0; /* Beige for a warm food-related theme */
+     font-family: 'Arial, sans-serif';
+ }
+ .title {
+     font-size: 2.5rem;
+     font-weight: bold;
+     color: #ff7f50; /* Coral color for a food-inspired look */
+     text-align: center;
+     margin-bottom: 1rem;
+ }
+ .container {
+     max-width: 900px;
+     margin: auto;
+     padding: 2rem;
+     background-color: #ffffff;
+     border-radius: 10px;
+     box-shadow: 0 4px 15px rgba(0, 0, 0, 0.1);
+ }
+ .button {
+     background-color: #ff7f50; /* Coral color for buttons */
+     color: white;
+     padding: 0.8rem 1.5rem;
+     font-size: 1rem;
+     border: none;
+     border-radius: 5px;
+     cursor: pointer;
+ }
+ .button:hover {
+     background-color: #ffa07a; /* Light salmon for hover effect */
+ }
+ """) as demo:
+
+     gr.Markdown("<div class='title'>Cookspert: Your Cooking Assistant</div>")
+
+     user_types = ["Expert", "Intermediate", "Beginner", "Professional"]
+
+     with gr.Tabs():
+         with gr.TabItem("Ask a Cooking Question"):
+             with gr.Row():
+                 with gr.Column(scale=2):
+                     prompt = gr.Textbox(label="Ask about any recipe", placeholder="Ask question related to cooking here...", lines=2)
+                     model_name = gr.Radio(label="Choose Model", choices=list(models.keys()), interactive=True)
+                     user_type = gr.Dropdown(label="User Type", choices=user_types, value="Beginner")
+                     submit_button = gr.Button("ChefGPT", elem_classes="button")
+
+                 response = gr.Textbox(
+                     label="🍽️ Response",
+                     placeholder="Your answer will appear here...",
+                     lines=10,
+                     interactive=False,
+                     show_copy_button=True
+                 )
+
+     submit_button.click(fn=process_input, inputs=[prompt, model_name, user_type], outputs=response)
+
+ if __name__ == "__main__":
+     demo.launch(server_name="0.0.0.0", share=True, debug=True)