mobenta committed on
Commit
5996956
·
verified ·
1 Parent(s): ef65105

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -44
app.py CHANGED
@@ -6,23 +6,40 @@ from PIL import Image
6
  import gradio as gr
7
  import logging
8
  from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
 
9
 
10
  # Configure logging to write to a file
11
  logging.basicConfig(filename='debug.log', level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
12
 
13
# Check GPU availability and initialize device.
# torch.device picks CUDA when available, otherwise falls back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("CUDA Available:", torch.cuda.is_available())
print("Using device:", device)

# Load the ChartGemma model and processor.
# float16 halves model memory; the model is moved to `device` eagerly at import time.
try:
    model = PaliGemmaForConditionalGeneration.from_pretrained("ahmed-masry/chartgemma", torch_dtype=torch.float16).to(device)
    processor = AutoProcessor.from_pretrained("ahmed-masry/chartgemma")
    print("Model and processor loaded successfully.")
except Exception as e:
    # Best-effort startup: a load failure is logged but does not abort the app,
    # so later calls that touch `model`/`processor` will raise NameError.
    print("Error loading model or processor:", e)
    logging.error(f"Error loading model or processor: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
  # Function to fetch stock data with different intervals
28
  def fetch_stock_data(ticker='TSLA', start='2023-01-01', end='2024-01-01', interval='1d'):
@@ -80,36 +97,6 @@ def combine_images(image1_path, image2_path, output_path='combined_chart.png'):
80
  logging.error(f"Error combining images: {e}")
81
  raise
82
 
83
# Function to generate insights
def generate_insights(image, query, ticker1=None, ticker2=None):
    """Run the ChartGemma model on a chart image and return generated insights.

    Args:
        image: Path (or file-like object) of the chart image to open with PIL.
        query: Text prompt describing what to analyze in the chart.
        ticker1: Optional ticker symbol substituted for "[First Ticker]".
        ticker2: Optional ticker symbol substituted for "[Second Ticker]".

    Returns:
        The decoded model output string, or an error message string if
        generation fails (errors are swallowed and logged, not raised).
    """
    try:
        logging.debug(f"Generating insights for query: {query}")

        # Open and process the image
        image = Image.open(image).convert('RGB')
        inputs = processor(text=query, images=image, return_tensors="pt")
        logging.debug(f"Inputs prepared with shapes {inputs['input_ids'].shape} and {inputs['pixel_values'].shape}")

        # Remember the prompt length so only newly generated tokens are decoded.
        prompt_length = inputs['input_ids'].shape[1]
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # Generate insights using the model
        generate_ids = model.generate(**inputs, num_beams=4, max_new_tokens=512)
        output_text = processor.batch_decode(generate_ids[:, prompt_length:], skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]

        # Replace placeholders with actual ticker names in the insights
        if ticker1:
            output_text = output_text.replace("[First Ticker]", ticker1)
        if ticker2:
            output_text = output_text.replace("[Second Ticker]", ticker2)

        logging.debug(f"Generated insights: {output_text}")

        return output_text
    except Exception as e:
        # Swallow the error and return it as text so the Gradio UI still renders.
        logging.error(f"Error generating insights: {e}")
        return f"Error generating insights: {e}"
112
-
113
  # Function to handle the Gradio interface
114
  def gradio_interface(ticker1, start_date, end_date, ticker2, query, analysis_type, interval):
115
  try:
@@ -127,10 +114,10 @@ def gradio_interface(ticker1, start_date, end_date, ticker2, query, analysis_typ
127
 
128
  # Combine the two charts into one image
129
  combined_chart_path = combine_images(chart_path1, chart_path2)
130
- insights = generate_insights(combined_chart_path, query, ticker1, ticker2)
131
  return insights, combined_chart_path
132
 
133
- insights = generate_insights(chart_path1, query, ticker1)
134
  return insights, chart_path1
135
  except Exception as e:
136
  logging.error(f"Error processing image or query: {e}")
 
6
  import gradio as gr
7
  import logging
8
  from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
9
+ import spaces
10
 
11
  # Configure logging to write to a file
12
  logging.basicConfig(filename='debug.log', level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
13
 
14
# Load the ChartGemma model and processor outside the GPU context
# (presumably required by the `spaces` ZeroGPU runtime: heavy init happens
# at import time on CPU; the model is moved to GPU inside the
# @spaces.GPU-decorated function — TODO confirm against Spaces docs).
model = PaliGemmaForConditionalGeneration.from_pretrained("ahmed-masry/chartgemma")
processor = AutoProcessor.from_pretrained("ahmed-masry/chartgemma")
 
17
 
18
@spaces.GPU
def predict(image, input_text, ticker1=None, ticker2=None):
    """Run the ChartGemma model on a chart image and return generated insights.

    Args:
        image: PIL.Image of the chart to analyze (converted to RGB internally).
        input_text: Text prompt describing what to analyze in the chart.
        ticker1: Optional ticker symbol substituted for "[First Ticker]".
        ticker2: Optional ticker symbol substituted for "[Second Ticker]".

    Returns:
        The decoded model output string with ticker placeholders replaced.
    """
    # Device is resolved per call: inside the @spaces.GPU context CUDA is
    # available; elsewhere this falls back to CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    image = image.convert("RGB")

    inputs = processor(text=input_text, images=image, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Remember the prompt length so only newly generated tokens are decoded.
    prompt_length = inputs['input_ids'].shape[1]

    # Inference only: disable autograd tracking so generation does not build
    # a computation graph (saves memory and time).
    with torch.inference_mode():
        generate_ids = model.generate(**inputs, num_beams=4, max_new_tokens=512)
    output_text = processor.batch_decode(
        generate_ids[:, prompt_length:],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )[0]

    # Replace placeholders with actual ticker names in the insights
    if ticker1:
        output_text = output_text.replace("[First Ticker]", ticker1)
    if ticker2:
        output_text = output_text.replace("[Second Ticker]", ticker2)

    logging.debug(f"Generated insights: {output_text}")

    return output_text
43
 
44
  # Function to fetch stock data with different intervals
45
  def fetch_stock_data(ticker='TSLA', start='2023-01-01', end='2024-01-01', interval='1d'):
 
97
  logging.error(f"Error combining images: {e}")
98
  raise
99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  # Function to handle the Gradio interface
101
  def gradio_interface(ticker1, start_date, end_date, ticker2, query, analysis_type, interval):
102
  try:
 
114
 
115
  # Combine the two charts into one image
116
  combined_chart_path = combine_images(chart_path1, chart_path2)
117
+ insights = predict(Image.open(combined_chart_path), query, ticker1, ticker2)
118
  return insights, combined_chart_path
119
 
120
+ insights = predict(Image.open(chart_path1), query, ticker1)
121
  return insights, chart_path1
122
  except Exception as e:
123
  logging.error(f"Error processing image or query: {e}")